Added new parameter scoring, support multiple matches

This commit is contained in:
RogueException
2016-08-21 11:03:50 -03:00
parent ed7710fbef
commit 324664917d
14 changed files with 316 additions and 177 deletions

View File

@@ -1,7 +1,9 @@
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Diagnostics; using System.Diagnostics;
using System.Linq;
using System.Reflection; using System.Reflection;
using System.Threading.Tasks; using System.Threading.Tasks;
@@ -10,6 +12,9 @@ namespace Discord.Commands
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] [DebuggerDisplay(@"{DebuggerDisplay,nq}")]
public class Command public class Command
{ {
private static readonly MethodInfo _convertParamsMethod = typeof(Command).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList));
private static readonly ConcurrentDictionary<Type, Func<IEnumerable<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<IEnumerable<object>, object>>();
private readonly object _instance; private readonly object _instance;
private readonly Func<IMessage, IReadOnlyList<object>, Task> _action; private readonly Func<IMessage, IReadOnlyList<object>, Task> _action;
@@ -19,6 +24,7 @@ namespace Discord.Commands
public string Description { get; } public string Description { get; }
public string Summary { get; } public string Summary { get; }
public string Text { get; } public string Text { get; }
public bool HasVarArgs { get; }
public IReadOnlyList<CommandParameter> Parameters { get; } public IReadOnlyList<CommandParameter> Parameters { get; }
public IReadOnlyList<PreconditionAttribute> Preconditions { get; } public IReadOnlyList<PreconditionAttribute> Preconditions { get; }
@@ -42,8 +48,9 @@ namespace Discord.Commands
var summary = source.GetCustomAttribute<SummaryAttribute>(); var summary = source.GetCustomAttribute<SummaryAttribute>();
if (summary != null) if (summary != null)
Summary = summary.Text; Summary = summary.Text;
Parameters = BuildParameters(source); Parameters = BuildParameters(source);
HasVarArgs = Parameters.Count > 0 ? Parameters[Parameters.Count - 1].IsMultiple : false;
Preconditions = BuildPreconditions(source); Preconditions = BuildPreconditions(source);
_action = BuildAction(source); _action = BuildAction(source);
} }
@@ -76,14 +83,38 @@ namespace Discord.Commands
return await CommandParser.ParseArgs(this, msg, searchResult.Text.Substring(Text.Length), 0).ConfigureAwait(false); return await CommandParser.ParseArgs(this, msg, searchResult.Text.Substring(Text.Length), 0).ConfigureAwait(false);
} }
public async Task<ExecuteResult> Execute(IMessage msg, ParseResult parseResult) public Task<ExecuteResult> Execute(IMessage msg, ParseResult parseResult)
{ {
if (!parseResult.IsSuccess) if (!parseResult.IsSuccess)
return ExecuteResult.FromError(parseResult); return Task.FromResult(ExecuteResult.FromError(parseResult));
var argList = new object[parseResult.ArgValues.Count];
for (int i = 0; i < parseResult.ArgValues.Count; i++)
{
if (!parseResult.ArgValues[i].IsSuccess)
return Task.FromResult(ExecuteResult.FromError(parseResult.ArgValues[i]));
argList[i] = parseResult.ArgValues[i].Values.First().Value;
}
object[] paramList = null;
if (parseResult.ParamValues != null)
{
paramList = new object[parseResult.ParamValues.Count];
for (int i = 0; i < parseResult.ParamValues.Count; i++)
{
if (!parseResult.ParamValues[i].IsSuccess)
return Task.FromResult(ExecuteResult.FromError(parseResult.ParamValues[i]));
paramList[i] = parseResult.ParamValues[i].Values.First().Value;
}
}
return Execute(msg, argList, paramList);
}
public async Task<ExecuteResult> Execute(IMessage msg, IEnumerable<object> argList, IEnumerable<object> paramList)
{
try try
{ {
await _action.Invoke(msg, parseResult.Values);//Note: This code may need context await _action.Invoke(msg, GenerateArgs(argList, paramList)).ConfigureAwait(false);//Note: This code may need context
return ExecuteResult.FromSuccess(); return ExecuteResult.FromSuccess();
} }
catch (Exception ex) catch (Exception ex)
@@ -108,7 +139,7 @@ namespace Discord.Commands
{ {
var parameter = parameters[i]; var parameter = parameters[i];
var type = parameter.ParameterType; var type = parameter.ParameterType;
//Detect 'params' //Detect 'params'
bool isMultiple = parameter.GetCustomAttribute<ParamArrayAttribute>() != null; bool isMultiple = parameter.GetCustomAttribute<ParamArrayAttribute>() != null;
if (isMultiple) if (isMultiple)
@@ -156,6 +187,39 @@ namespace Discord.Commands
}; };
} }
private object[] GenerateArgs(IEnumerable<object> argList, IEnumerable<object> paramsList)
{
int argCount = Parameters.Count;
var array = new object[Parameters.Count];
if (HasVarArgs)
argCount--;
int i = 0;
foreach (var arg in argList)
{
if (i == argCount)
throw new InvalidOperationException("Command was invoked with too many parameters");
array[i++] = arg;
}
if (i < argCount)
throw new InvalidOperationException("Command was invoked with too few parameters");
if (HasVarArgs)
{
var func = _arrayConverters.GetOrAdd(Parameters[Parameters.Count - 1].ElementType, t =>
{
var method = _convertParamsMethod.MakeGenericMethod(t);
return (Func<IEnumerable<object>, object>)method.CreateDelegate(typeof(Func<IEnumerable<object>, object>));
});
array[i] = func(paramsList);
}
return array;
}
private static T[] ConvertParamsList<T>(IEnumerable<object> paramsList)
=> paramsList.Cast<T>().ToArray();
public override string ToString() => Name; public override string ToString() => Name;
private string DebuggerDisplay => $"{Module.Name}.{Name} ({Text})"; private string DebuggerDisplay => $"{Module.Name}.{Name} ({Text})";
} }

View File

@@ -3,14 +3,14 @@
public enum CommandError public enum CommandError
{ {
//Search //Search
UnknownCommand, UnknownCommand = 1,
//Parse //Parse
ParseFailed, ParseFailed,
BadArgCount, BadArgCount,
//Parse (Type Reader) //Parse (Type Reader)
CastFailed, //CastFailed,
ObjectNotFound, ObjectNotFound,
MultipleMatches, MultipleMatches,

View File

@@ -17,7 +17,7 @@ namespace Discord.Commands
public bool IsRemainder { get; } public bool IsRemainder { get; }
public bool IsMultiple { get; } public bool IsMultiple { get; }
public Type ElementType { get; } public Type ElementType { get; }
internal object DefaultValue { get; } public object DefaultValue { get; }
public CommandParameter(ParameterInfo source, string name, string summary, Type type, TypeReader reader, bool isOptional, bool isRemainder, bool isMultiple, object defaultValue) public CommandParameter(ParameterInfo source, string name, string summary, Type type, TypeReader reader, bool isOptional, bool isRemainder, bool isMultiple, object defaultValue)
{ {

View File

@@ -1,8 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable; using System.Collections.Immutable;
using System.Reflection;
using System.Text; using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
@@ -16,9 +13,6 @@ namespace Discord.Commands
Parameter, Parameter,
QuotedParameter QuotedParameter
} }
private static readonly MethodInfo _convertArrayMethod = typeof(CommandParser).GetTypeInfo().GetDeclaredMethod(nameof(ConvertParamsList));
private static readonly ConcurrentDictionary<Type, Func<List<object>, object>> _arrayConverters = new ConcurrentDictionary<Type, Func<List<object>, object>>();
public static async Task<ParseResult> ParseArgs(Command command, IMessage context, string input, int startPos) public static async Task<ParseResult> ParseArgs(Command command, IMessage context, string input, int startPos)
{ {
@@ -27,9 +21,10 @@ namespace Discord.Commands
int endPos = input.Length; int endPos = input.Length;
var curPart = ParserPart.None; var curPart = ParserPart.None;
int lastArgEndPos = int.MinValue; int lastArgEndPos = int.MinValue;
var argList = ImmutableArray.CreateBuilder<object>(); var argList = ImmutableArray.CreateBuilder<TypeReaderResult>();
List<object> paramsList = null; // TODO: could we use a better type? ImmutableArray<TypeReaderResult>.Builder paramList = null;
bool isEscaping = false; bool isEscaping = false;
bool hasMultipleMatches = false;
char c; char c;
for (int curPos = startPos; curPos <= endPos; curPos++) for (int curPos = startPos; curPos <= endPos; curPos++)
@@ -117,30 +112,28 @@ namespace Discord.Commands
var typeReaderResult = await curParam.Parse(context, argString).ConfigureAwait(false); var typeReaderResult = await curParam.Parse(context, argString).ConfigureAwait(false);
if (!typeReaderResult.IsSuccess) if (!typeReaderResult.IsSuccess)
return ParseResult.FromError(typeReaderResult); {
if (typeReaderResult.Error == CommandError.MultipleMatches)
hasMultipleMatches = true;
else
return ParseResult.FromError(typeReaderResult);
}
if (curParam.IsMultiple) if (curParam.IsMultiple)
{ {
if (paramsList == null) if (paramList == null)
paramsList = new List<object>(); paramList = ImmutableArray.CreateBuilder<TypeReaderResult>();
paramsList.Add(typeReaderResult.Value); paramList.Add(typeReaderResult);
if (curPos == endPos) if (curPos == endPos)
{ {
var func = _arrayConverters.GetOrAdd(curParam.ElementType, t =>
{
var method = _convertArrayMethod.MakeGenericMethod(t);
return (Func<List<object>, object>)method.CreateDelegate(typeof(Func<List<object>, object>));
});
argList.Add(func.Invoke(paramsList));
curParam = null; curParam = null;
curPart = ParserPart.None; curPart = ParserPart.None;
} }
} }
else else
{ {
argList.Add(typeReaderResult.Value); argList.Add(typeReaderResult);
curParam = null; curParam = null;
curPart = ParserPart.None; curPart = ParserPart.None;
@@ -154,34 +147,24 @@ namespace Discord.Commands
var typeReaderResult = await curParam.Parse(context, argBuilder.ToString()).ConfigureAwait(false); var typeReaderResult = await curParam.Parse(context, argBuilder.ToString()).ConfigureAwait(false);
if (!typeReaderResult.IsSuccess) if (!typeReaderResult.IsSuccess)
return ParseResult.FromError(typeReaderResult); return ParseResult.FromError(typeReaderResult);
argList.Add(typeReaderResult.Value); argList.Add(typeReaderResult);
} }
if (isEscaping) if (isEscaping)
return ParseResult.FromError(CommandError.ParseFailed, "Input text may not end on an incomplete escape."); return ParseResult.FromError(CommandError.ParseFailed, "Input text may not end on an incomplete escape.");
if (curPart == ParserPart.QuotedParameter) if (curPart == ParserPart.QuotedParameter)
return ParseResult.FromError(CommandError.ParseFailed, "A quoted parameter is incomplete"); return ParseResult.FromError(CommandError.ParseFailed, "A quoted parameter is incomplete");
if (argList.Count < command.Parameters.Count) //Add missing optionals
for (int i = paramList.Count; i < command.Parameters.Count; i++)
{ {
for (int i = argList.Count; i < command.Parameters.Count; i++) var param = command.Parameters[i];
{ if (!param.IsOptional)
var param = command.Parameters[i]; return ParseResult.FromError(CommandError.BadArgCount, "The input text has too few parameters.");
if (!param.IsOptional) argList.Add(TypeReaderResult.FromSuccess(param.DefaultValue));
return ParseResult.FromError(CommandError.BadArgCount, "The input text has too few parameters.");
argList.Add(param.DefaultValue);
}
} }
return ParseResult.FromSuccess(argList.ToImmutable()); return ParseResult.FromSuccess(argList.ToImmutable(), paramList?.ToImmutable());
}
private static T[] ConvertParamsList<T>(List<object> paramsList)
{
var array = new T[paramsList.Count];
for (int i = 0; i < array.Length; i++)
array[i] = (T)paramsList[i];
return array;
} }
} }
} }

View File

@@ -40,16 +40,8 @@ namespace Discord.Commands
[typeof(decimal)] = new SimpleTypeReader<decimal>(), [typeof(decimal)] = new SimpleTypeReader<decimal>(),
[typeof(DateTime)] = new SimpleTypeReader<DateTime>(), [typeof(DateTime)] = new SimpleTypeReader<DateTime>(),
[typeof(DateTimeOffset)] = new SimpleTypeReader<DateTimeOffset>(), [typeof(DateTimeOffset)] = new SimpleTypeReader<DateTimeOffset>(),
//TODO: Do we want to support any other interfaces?
//[typeof(IMentionable)] = new GeneralTypeReader(),
//[typeof(ISnowflakeEntity)] = new GeneralTypeReader(),
//[typeof(IEntity<ulong>)] = new GeneralTypeReader(),
[typeof(IMessage)] = new MessageTypeReader(), [typeof(IMessage)] = new MessageTypeReader(),
//[typeof(IAttachment)] = new xxx(),
//[typeof(IEmbed)] = new xxx(),
[typeof(IChannel)] = new ChannelTypeReader<IChannel>(), [typeof(IChannel)] = new ChannelTypeReader<IChannel>(),
[typeof(IDMChannel)] = new ChannelTypeReader<IDMChannel>(), [typeof(IDMChannel)] = new ChannelTypeReader<IDMChannel>(),
@@ -61,10 +53,8 @@ namespace Discord.Commands
[typeof(IVoiceChannel)] = new ChannelTypeReader<IVoiceChannel>(), [typeof(IVoiceChannel)] = new ChannelTypeReader<IVoiceChannel>(),
//[typeof(IGuild)] = new GuildTypeReader<IGuild>(), //[typeof(IGuild)] = new GuildTypeReader<IGuild>(),
//[typeof(IUserGuild)] = new GuildTypeReader<IUserGuild>(),
//[typeof(IGuildIntegration)] = new xxx(),
[typeof(IRole)] = new RoleTypeReader(), [typeof(IRole)] = new RoleTypeReader<IRole>(),
//[typeof(IInvite)] = new InviteTypeReader<IInvite>(), //[typeof(IInvite)] = new InviteTypeReader<IInvite>(),
//[typeof(IInviteMetadata)] = new InviteTypeReader<IInviteMetadata>(), //[typeof(IInviteMetadata)] = new InviteTypeReader<IInviteMetadata>(),
@@ -72,10 +62,6 @@ namespace Discord.Commands
[typeof(IUser)] = new UserTypeReader<IUser>(), [typeof(IUser)] = new UserTypeReader<IUser>(),
[typeof(IGroupUser)] = new UserTypeReader<IGroupUser>(), [typeof(IGroupUser)] = new UserTypeReader<IGroupUser>(),
[typeof(IGuildUser)] = new UserTypeReader<IGuildUser>(), [typeof(IGuildUser)] = new UserTypeReader<IGuildUser>(),
//[typeof(ISelfUser)] = new UserTypeReader<ISelfUser>(),
//[typeof(IPresence)] = new UserTypeReader<IPresence>(),
//[typeof(IVoiceState)] = new UserTypeReader<IVoiceState>(),
//[typeof(IConnection)] = new xxx(),
}; };
} }
@@ -201,8 +187,9 @@ namespace Discord.Commands
return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command."); return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.");
} }
public Task<IResult> Execute(IMessage message, int argPos) => Execute(message, message.Content.Substring(argPos)); public Task<IResult> Execute(IMessage message, int argPos, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
public async Task<IResult> Execute(IMessage message, string input) => Execute(message, message.Content.Substring(argPos), multiMatchHandling);
public async Task<IResult> Execute(IMessage message, string input, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
{ {
var searchResult = Search(message, input); var searchResult = Search(message, input);
if (!searchResult.IsSuccess) if (!searchResult.IsSuccess)
@@ -223,14 +210,29 @@ namespace Discord.Commands
var parseResult = await commands[i].Parse(message, searchResult, preconditionResult); var parseResult = await commands[i].Parse(message, searchResult, preconditionResult);
if (!parseResult.IsSuccess) if (!parseResult.IsSuccess)
{ {
if (commands.Count == 1) if (parseResult.Error == CommandError.MultipleMatches)
return parseResult; {
else TypeReaderValue[] argList, paramList;
continue; switch (multiMatchHandling)
{
case MultiMatchHandling.Best:
argList = parseResult.ArgValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToArray();
paramList = parseResult.ParamValues?.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToArray();
parseResult = ParseResult.FromSuccess(argList, paramList);
break;
}
}
if (!parseResult.IsSuccess)
{
if (commands.Count == 1)
return parseResult;
else
continue;
}
} }
var executeResult = await commands[i].Execute(message, parseResult); return await commands[i].Execute(message, parseResult);
return executeResult;
} }
return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload."); return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload.");

View File

@@ -0,0 +1,8 @@
namespace Discord.Commands
{
public enum MultiMatchHandling
{
Exception,
Best
}
}

View File

@@ -1,4 +1,6 @@
using System; using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
@@ -9,40 +11,37 @@ namespace Discord.Commands
{ {
public override async Task<TypeReaderResult> Read(IMessage context, string input) public override async Task<TypeReaderResult> Read(IMessage context, string input)
{ {
IGuildChannel guildChannel = context.Channel as IGuildChannel; var guild = (context.Channel as IGuildChannel)?.Guild;
IChannel result = null;
if (guildChannel != null) if (guild != null)
{ {
//By Id var results = new Dictionary<ulong, TypeReaderValue>();
var channels = await guild.GetChannelsAsync().ConfigureAwait(false);
ulong id; ulong id;
if (MentionUtils.TryParseChannel(input, out id) || ulong.TryParse(input, out id))
{
var channel = await guildChannel.Guild.GetChannelAsync(id).ConfigureAwait(false);
if (channel != null)
result = channel;
}
//By Name //By Mention (1.0)
if (result == null) if (MentionUtils.TryParseChannel(input, out id))
{ AddResult(results, await guild.GetUserAsync(id).ConfigureAwait(false) as T, 1.00f);
var channels = await guildChannel.Guild.GetChannelsAsync().ConfigureAwait(false);
var filteredChannels = channels.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase)).ToArray(); //By Id (0.9)
if (filteredChannels.Length > 1) if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id))
return TypeReaderResult.FromError(CommandError.MultipleMatches, "Multiple channels found."); AddResult(results, await guild.GetChannelAsync(id).ConfigureAwait(false) as T, 0.90f);
else if (filteredChannels.Length == 1)
result = filteredChannels[0]; //By Name (0.7-0.8)
} foreach (var channel in channels.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase)))
AddResult(results, channel as T, channel.Name == input ? 0.80f : 0.70f);
if (results.Count > 0)
return TypeReaderResult.FromSuccess(results.Values);
} }
if (result == null) return TypeReaderResult.FromError(CommandError.ObjectNotFound, "Channel not found.");
return TypeReaderResult.FromError(CommandError.ObjectNotFound, "Channel not found."); }
T castResult = result as T; private void AddResult(Dictionary<ulong, TypeReaderValue> results, T channel, float score)
if (castResult == null) {
return TypeReaderResult.FromError(CommandError.CastFailed, $"Channel is not a {typeof(T).Name}."); if (channel != null && !results.ContainsKey(channel.Id))
else results.Add(channel.Id, new TypeReaderValue(channel, score));
return TypeReaderResult.FromSuccess(castResult);
} }
} }
} }

View File

@@ -52,14 +52,14 @@ namespace Discord.Commands
if (_enumsByValue.TryGetValue(baseValue, out enumValue)) if (_enumsByValue.TryGetValue(baseValue, out enumValue))
return Task.FromResult(TypeReaderResult.FromSuccess(enumValue)); return Task.FromResult(TypeReaderResult.FromSuccess(enumValue));
else else
return Task.FromResult(TypeReaderResult.FromError(CommandError.CastFailed, $"Value is not a {_enumType.Name}")); return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Value is not a {_enumType.Name}"));
} }
else else
{ {
if (_enumsByName.TryGetValue(input.ToLower(), out enumValue)) if (_enumsByName.TryGetValue(input.ToLower(), out enumValue))
return Task.FromResult(TypeReaderResult.FromSuccess(enumValue)); return Task.FromResult(TypeReaderResult.FromSuccess(enumValue));
else else
return Task.FromResult(TypeReaderResult.FromError(CommandError.CastFailed, $"Value is not a {_enumType.Name}")); return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Value is not a {_enumType.Name}"));
} }
} }
} }

View File

@@ -7,18 +7,17 @@ namespace Discord.Commands
{ {
public override Task<TypeReaderResult> Read(IMessage context, string input) public override Task<TypeReaderResult> Read(IMessage context, string input)
{ {
//By Id
ulong id; ulong id;
//By Id (1.0)
if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id)) if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id))
{ {
var msg = context.Channel.GetCachedMessage(id); var msg = context.Channel.GetCachedMessage(id);
if (msg == null) if (msg != null)
return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Message not found."));
else
return Task.FromResult(TypeReaderResult.FromSuccess(msg)); return Task.FromResult(TypeReaderResult.FromSuccess(msg));
} }
return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse Message Id.")); return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Message not found."));
} }
} }
} }

View File

@@ -1,36 +1,46 @@
using System; using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Discord.Commands namespace Discord.Commands
{ {
internal class RoleTypeReader : TypeReader internal class RoleTypeReader<T> : TypeReader
where T : class, IRole
{ {
public override Task<TypeReaderResult> Read(IMessage context, string input) public override Task<TypeReaderResult> Read(IMessage context, string input)
{ {
IGuildChannel guildChannel = context.Channel as IGuildChannel; var guild = (context.Channel as IGuildChannel)?.Guild;
ulong id;
if (guildChannel != null) if (guild != null)
{ {
//By Id var results = new Dictionary<ulong, TypeReaderValue>();
ulong id; var roles = guild.Roles;
if (MentionUtils.TryParseRole(input, out id) || ulong.TryParse(input, out id))
{
var channel = guildChannel.Guild.GetRole(id);
if (channel != null)
return Task.FromResult(TypeReaderResult.FromSuccess(channel));
}
//By Name //By Mention (1.0)
var roles = guildChannel.Guild.Roles; if (MentionUtils.TryParseRole(input, out id))
var filteredRoles = roles.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase)).ToArray(); AddResult(results, guild.GetRole(id) as T, 1.00f);
if (filteredRoles.Length > 1)
return Task.FromResult(TypeReaderResult.FromError(CommandError.MultipleMatches, "Multiple roles found.")); //By Id (0.9)
else if (filteredRoles.Length == 1) if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id))
return Task.FromResult(TypeReaderResult.FromSuccess(filteredRoles[0])); AddResult(results, guild.GetRole(id) as T, 0.90f);
//By Name (0.7-0.8)
foreach (var role in roles.Where(x => string.Equals(input, x.Name, StringComparison.OrdinalIgnoreCase)))
AddResult(results, role as T, role.Name == input ? 0.80f : 0.70f);
if (results.Count > 0)
return Task.FromResult(TypeReaderResult.FromSuccess(results));
} }
return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Role not found.")); return Task.FromResult(TypeReaderResult.FromError(CommandError.ObjectNotFound, "Role not found."));
} }
private void AddResult(Dictionary<ulong, TypeReaderValue> results, T role, float score)
{
if (role != null && !results.ContainsKey(role.Id))
results.Add(role.Id, new TypeReaderValue(role, score));
}
} }
} }

View File

@@ -16,8 +16,7 @@ namespace Discord.Commands
T value; T value;
if (_tryParse(input, out value)) if (_tryParse(input, out value))
return Task.FromResult(TypeReaderResult.FromSuccess(value)); return Task.FromResult(TypeReaderResult.FromSuccess(value));
else return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Failed to parse {typeof(T).Name}"));
return Task.FromResult(TypeReaderResult.FromError(CommandError.ParseFailed, $"Failed to parse {typeof(T).Name}"));
} }
} }
} }

View File

@@ -1,4 +1,6 @@
using System; using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq; using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
@@ -9,54 +11,78 @@ namespace Discord.Commands
{ {
public override async Task<TypeReaderResult> Read(IMessage context, string input) public override async Task<TypeReaderResult> Read(IMessage context, string input)
{ {
IUser result = null; var results = new Dictionary<ulong, TypeReaderValue>();
var guild = (context.Channel as IGuildChannel)?.Guild;
//By Id IReadOnlyCollection<IUser> channelUsers = await context.Channel.GetUsersAsync().ConfigureAwait(false);
IReadOnlyCollection<IGuildUser> guildUsers = null;
ulong id; ulong id;
if (MentionUtils.TryParseUser(input, out id) || ulong.TryParse(input, out id))
if (guild != null)
guildUsers = await guild.GetUsersAsync().ConfigureAwait(false);
//By Mention (1.0)
if (MentionUtils.TryParseUser(input, out id))
{ {
var user = await context.Channel.GetUserAsync(id).ConfigureAwait(false); if (guild != null)
if (user != null) AddResult(results, await guild.GetUserAsync(id).ConfigureAwait(false) as T, 1.00f);
result = user; else
AddResult(results, await context.Channel.GetUserAsync(id).ConfigureAwait(false) as T, 1.00f);
} }
//By Username + Discriminator //By Id (0.9)
if (result == null) if (ulong.TryParse(input, NumberStyles.None, CultureInfo.InvariantCulture, out id))
{ {
int index = input.LastIndexOf('#'); if (guild != null)
if (index >= 0) AddResult(results, await guild.GetUserAsync(id).ConfigureAwait(false) as T, 0.90f);
else
AddResult(results, await context.Channel.GetUserAsync(id).ConfigureAwait(false) as T, 0.90f);
}
//By Username + Discriminator (0.7-0.85)
int index = input.LastIndexOf('#');
if (index >= 0)
{
string username = input.Substring(0, index);
ushort discriminator;
if (ushort.TryParse(input.Substring(index + 1), out discriminator))
{ {
string username = input.Substring(0, index); var channelUser = channelUsers.Where(x => x.DiscriminatorValue == discriminator &&
ushort discriminator; string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)).FirstOrDefault();
if (ushort.TryParse(input.Substring(index + 1), out discriminator)) AddResult(results, channelUser as T, channelUser.Username == username ? 0.85f : 0.75f);
{
var users = await context.Channel.GetUsersAsync().ConfigureAwait(false); var guildUser = channelUsers.Where(x => x.DiscriminatorValue == discriminator &&
result = users.Where(x => string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)).FirstOrDefault();
x.DiscriminatorValue == discriminator && AddResult(results, guildUser as T, guildUser.Username == username ? 0.80f : 0.70f);
string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase)).FirstOrDefault();
}
} }
} }
//By Username //By Username (0.5-0.6)
if (result == null)
{ {
var users = await context.Channel.GetUsersAsync().ConfigureAwait(false); foreach (var channelUser in channelUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)))
var filteredUsers = users.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)).ToArray(); AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f);
if (filteredUsers.Length > 1)
return TypeReaderResult.FromError(CommandError.MultipleMatches, "Multiple users found."); foreach (var guildUser in guildUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)))
else if (filteredUsers.Length == 1) AddResult(results, guildUser as T, guildUser.Username == input ? 0.60f : 0.50f);
result = filteredUsers[0];
} }
if (result == null) //By Nickname (0.5-0.6)
return TypeReaderResult.FromError(CommandError.ObjectNotFound, "User not found."); {
foreach (var channelUser in channelUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase)))
AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f);
T castResult = result as T; foreach (var guildUser in guildUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase)))
if (castResult == null) AddResult(results, guildUser as T, (guildUser as IGuildUser).Nickname == input ? 0.60f : 0.50f);
return TypeReaderResult.FromError(CommandError.CastFailed, $"User is not a {typeof(T).Name}."); }
else
return TypeReaderResult.FromSuccess(castResult); if (results.Count > 0)
return TypeReaderResult.FromSuccess(results.Values.ToArray());
return TypeReaderResult.FromError(CommandError.ObjectNotFound, "User not found.");
}
private void AddResult(Dictionary<ulong, TypeReaderValue> results, T user, float score)
{
if (user != null && !results.ContainsKey(user.Id))
results.Add(user.Id, new TypeReaderValue(user, score));
} }
} }
} }

View File

@@ -6,28 +6,53 @@ namespace Discord.Commands
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] [DebuggerDisplay(@"{DebuggerDisplay,nq}")]
public struct ParseResult : IResult public struct ParseResult : IResult
{ {
public IReadOnlyList<object> Values { get; } public IReadOnlyList<TypeReaderResult> ArgValues { get; }
public IReadOnlyList<TypeReaderResult> ParamValues { get; }
public CommandError? Error { get; } public CommandError? Error { get; }
public string ErrorReason { get; } public string ErrorReason { get; }
public bool IsSuccess => !Error.HasValue; public bool IsSuccess => !Error.HasValue;
private ParseResult(IReadOnlyList<object> values, CommandError? error, string errorReason) private ParseResult(IReadOnlyList<TypeReaderResult> argValues, IReadOnlyList<TypeReaderResult> paramValue, CommandError? error, string errorReason)
{ {
Values = values; ArgValues = argValues;
ParamValues = paramValue;
Error = error; Error = error;
ErrorReason = errorReason; ErrorReason = errorReason;
} }
public static ParseResult FromSuccess(IReadOnlyList<object> values) public static ParseResult FromSuccess(IReadOnlyList<TypeReaderResult> argValues, IReadOnlyList<TypeReaderResult> paramValues)
=> new ParseResult(values, null, null); {
for (int i = 0; i < argValues.Count; i++)
{
if (argValues[i].Values.Count > 1)
return new ParseResult(argValues, paramValues, CommandError.MultipleMatches, "Multiple matches found.");
}
for (int i = 0; i < paramValues.Count; i++)
{
if (paramValues[i].Values.Count > 1)
return new ParseResult(argValues, paramValues, CommandError.MultipleMatches, "Multiple matches found.");
}
return new ParseResult(argValues, paramValues, null, null);
}
public static ParseResult FromSuccess(IReadOnlyList<TypeReaderValue> argValues, IReadOnlyList<TypeReaderValue> paramValues)
{
var argList = new TypeReaderResult[argValues.Count];
for (int i = 0; i < argValues.Count; i++)
argList[i] = TypeReaderResult.FromSuccess(argValues[i]);
var paramList = new TypeReaderResult[paramValues.Count];
for (int i = 0; i < paramValues.Count; i++)
paramList[i] = TypeReaderResult.FromSuccess(paramValues[i]);
return new ParseResult(argList, paramList, null, null);
}
public static ParseResult FromError(CommandError error, string reason) public static ParseResult FromError(CommandError error, string reason)
=> new ParseResult(null, error, reason); => new ParseResult(null, null, error, reason);
public static ParseResult FromError(IResult result) public static ParseResult FromError(IResult result)
=> new ParseResult(null, result.Error, result.ErrorReason); => new ParseResult(null, null, result.Error, result.ErrorReason);
public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}";
private string DebuggerDisplay => IsSuccess ? $"Success ({Values.Count} Values)" : $"{Error}: {ErrorReason}"; private string DebuggerDisplay => IsSuccess ? $"Success ({ArgValues.Count}{(ParamValues != null ? $" +{ParamValues.Count} Values" : "")})" : $"{Error}: {ErrorReason}";
} }
} }

View File

@@ -1,32 +1,56 @@
using System.Diagnostics; using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Diagnostics;
using System.Linq;
namespace Discord.Commands namespace Discord.Commands
{ {
[DebuggerDisplay(@"{DebuggerDisplay,nq}")] [DebuggerDisplay(@"{DebuggerDisplay,nq}")]
public struct TypeReaderResult : IResult public struct TypeReaderValue
{ {
public object Value { get; } public object Value { get; }
public float Score { get; }
public TypeReaderValue(object value, float score)
{
Value = value;
Score = score;
}
public override string ToString() => Value?.ToString();
private string DebuggerDisplay => $"[{Value}, {Math.Round(Score, 2)}]";
}
[DebuggerDisplay(@"{DebuggerDisplay,nq}")]
public struct TypeReaderResult : IResult
{
public IReadOnlyCollection<TypeReaderValue> Values { get; }
public CommandError? Error { get; } public CommandError? Error { get; }
public string ErrorReason { get; } public string ErrorReason { get; }
public bool IsSuccess => !Error.HasValue; public bool IsSuccess => !Error.HasValue;
private TypeReaderResult(object value, CommandError? error, string errorReason) private TypeReaderResult(IReadOnlyCollection<TypeReaderValue> values, CommandError? error, string errorReason)
{ {
Value = value; Values = values;
Error = error; Error = error;
ErrorReason = errorReason; ErrorReason = errorReason;
} }
public static TypeReaderResult FromSuccess(object value) public static TypeReaderResult FromSuccess(object value)
=> new TypeReaderResult(value, null, null); => new TypeReaderResult(ImmutableArray.Create(new TypeReaderValue(value, 1.0f)), null, null);
public static TypeReaderResult FromSuccess(TypeReaderValue value)
=> new TypeReaderResult(ImmutableArray.Create(value), null, null);
public static TypeReaderResult FromSuccess(IReadOnlyCollection<TypeReaderValue> values)
=> new TypeReaderResult(values, null, null);
public static TypeReaderResult FromError(CommandError error, string reason) public static TypeReaderResult FromError(CommandError error, string reason)
=> new TypeReaderResult(null, error, reason); => new TypeReaderResult(null, error, reason);
public static TypeReaderResult FromError(IResult result) public static TypeReaderResult FromError(IResult result)
=> new TypeReaderResult(null, result.Error, result.ErrorReason); => new TypeReaderResult(null, result.Error, result.ErrorReason);
public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}"; public override string ToString() => IsSuccess ? "Success" : $"{Error}: {ErrorReason}";
private string DebuggerDisplay => IsSuccess ? $"Success ({Value})" : $"{Error}: {ErrorReason}"; private string DebuggerDisplay => IsSuccess ? $"Success ({string.Join(", ", Values)})" : $"{Error}: {ErrorReason}";
} }
} }