using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;

namespace Discord.Commands
{
    /// <summary>
    /// Holds the registered command modules and argument type readers, and resolves and executes
    /// commands from message text.
    /// </summary>
    public class CommandService
    {
        private static readonly TypeInfo _moduleTypeInfo = typeof(ModuleBase).GetTypeInfo();

        // Serializes module registration/removal so _moduleDefs and _map are updated together.
        private readonly SemaphoreSlim _moduleLock;
        // Registered modules, keyed by the module's own type (the same key used by
        // AddModule<T>, AddModules and RemoveModule).
        private readonly ConcurrentDictionary<Type, ModuleInfo> _moduleDefs;
        private readonly ConcurrentDictionary<Type, TypeReader> _typeReaders;
        private readonly CommandMap _map;

        /// <summary>All currently registered modules.</summary>
        public IEnumerable<ModuleInfo> Modules => _moduleDefs.Select(x => x.Value);

        /// <summary>Every command of every currently registered module.</summary>
        public IEnumerable<CommandInfo> Commands => _moduleDefs.SelectMany(x => x.Value.Commands);

        /// <summary>
        /// Creates an empty service with type readers pre-registered for primitive types,
        /// <see cref="DateTime"/>/<see cref="DateTimeOffset"/>, and the message, channel, role and
        /// user interfaces.
        /// </summary>
        public CommandService()
        {
            _moduleLock = new SemaphoreSlim(1, 1);
            _moduleDefs = new ConcurrentDictionary<Type, ModuleInfo>();
            _map = new CommandMap();
            _typeReaders = new ConcurrentDictionary<Type, TypeReader>
            {
                [typeof(bool)] = new SimpleTypeReader<bool>(),
                [typeof(char)] = new SimpleTypeReader<char>(),
                [typeof(string)] = new SimpleTypeReader<string>(),
                [typeof(byte)] = new SimpleTypeReader<byte>(),
                [typeof(sbyte)] = new SimpleTypeReader<sbyte>(),
                [typeof(ushort)] = new SimpleTypeReader<ushort>(),
                [typeof(short)] = new SimpleTypeReader<short>(),
                [typeof(uint)] = new SimpleTypeReader<uint>(),
                [typeof(int)] = new SimpleTypeReader<int>(),
                [typeof(ulong)] = new SimpleTypeReader<ulong>(),
                [typeof(long)] = new SimpleTypeReader<long>(),
                [typeof(float)] = new SimpleTypeReader<float>(),
                [typeof(double)] = new SimpleTypeReader<double>(),
                [typeof(decimal)] = new SimpleTypeReader<decimal>(),
                [typeof(DateTime)] = new SimpleTypeReader<DateTime>(),
                [typeof(DateTimeOffset)] = new SimpleTypeReader<DateTimeOffset>(),

                [typeof(IMessage)] = new MessageTypeReader<IMessage>(),
                [typeof(IUserMessage)] = new MessageTypeReader<IUserMessage>(),

                [typeof(IChannel)] = new ChannelTypeReader<IChannel>(),
                [typeof(IDMChannel)] = new ChannelTypeReader<IDMChannel>(),
                [typeof(IGroupChannel)] = new ChannelTypeReader<IGroupChannel>(),
                [typeof(IGuildChannel)] = new ChannelTypeReader<IGuildChannel>(),
                [typeof(IMessageChannel)] = new ChannelTypeReader<IMessageChannel>(),
                [typeof(IPrivateChannel)] = new ChannelTypeReader<IPrivateChannel>(),
                [typeof(ITextChannel)] = new ChannelTypeReader<ITextChannel>(),
                [typeof(IVoiceChannel)] = new ChannelTypeReader<IVoiceChannel>(),

                [typeof(IRole)] = new RoleTypeReader<IRole>(),

                [typeof(IUser)] = new UserTypeReader<IUser>(),
                [typeof(IGroupUser)] = new UserTypeReader<IGroupUser>(),
                [typeof(IGuildUser)] = new UserTypeReader<IGuildUser>(),
            };
        }

        //Modules

        /// <summary>Registers module type <typeparamref name="T"/> and all of its commands.</summary>
        /// <param name="dependencyMap">Optional dependency map handed to the module definition.</param>
        /// <returns>The newly created module definition.</returns>
        /// <exception cref="ArgumentException">
        /// <typeparamref name="T"/> is already registered, or does not inherit <see cref="ModuleBase"/>.
        /// </exception>
        public async Task<ModuleInfo> AddModule<T>(IDependencyMap dependencyMap = null)
        {
            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                if (_moduleDefs.ContainsKey(typeof(T)))
                    throw new ArgumentException("This module has already been added.");

                var typeInfo = typeof(T).GetTypeInfo();
                if (!_moduleTypeInfo.IsAssignableFrom(typeInfo))
                    throw new ArgumentException("Modules must inherit ModuleBase.");

                return AddModuleInternal(typeInfo, dependencyMap);
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        /// <summary>
        /// Registers every exported, non-abstract <see cref="ModuleBase"/> subclass in
        /// <paramref name="assembly"/> that is not already registered and is not marked with
        /// <see cref="DontAutoLoadAttribute"/>.
        /// </summary>
        /// <param name="assembly">Assembly whose exported types are scanned.</param>
        /// <param name="dependencyMap">Optional dependency map handed to each module definition.</param>
        /// <returns>The module definitions added by this call.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="assembly"/> is null.</exception>
        public async Task<IEnumerable<ModuleInfo>> AddModules(Assembly assembly, IDependencyMap dependencyMap = null)
        {
            if (assembly == null)
                throw new ArgumentNullException(nameof(assembly));

            var moduleDefs = ImmutableArray.CreateBuilder<ModuleInfo>();
            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                foreach (var type in assembly.ExportedTypes)
                {
                    if (_moduleDefs.ContainsKey(type))
                        continue;

                    var typeInfo = type.GetTypeInfo();
                    // An abstract class cannot be instantiated, so it cannot act as a module itself
                    // (typically it is a shared base for concrete modules).
                    if (!_moduleTypeInfo.IsAssignableFrom(typeInfo) || typeInfo.IsAbstract)
                        continue;

                    if (typeInfo.GetCustomAttribute<DontAutoLoadAttribute>() != null)
                        continue;

                    moduleDefs.Add(AddModuleInternal(typeInfo, dependencyMap));
                }
                return moduleDefs.ToImmutable();
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        // Builds the module definition, records it and registers its commands. Caller must hold _moduleLock.
        private ModuleInfo AddModuleInternal(TypeInfo typeInfo, IDependencyMap dependencyMap)
        {
            var moduleDef = new ModuleInfo(typeInfo, this, dependencyMap);
            // Key by the module's own type, not its base type: the lookups in AddModule<T>, AddModules
            // and RemoveModule<T> all use the module type, and keying by BaseType made modules that
            // share a base class overwrite one another.
            _moduleDefs[typeInfo.AsType()] = moduleDef;
            foreach (var cmd in moduleDef.Commands)
                _map.AddCommand(cmd);
            return moduleDef;
        }

        /// <summary>Unregisters <paramref name="module"/> and its commands.</summary>
        /// <returns><c>true</c> if a module registered under that type was removed.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="module"/> is null.</exception>
        public async Task<bool> RemoveModule(ModuleInfo module)
        {
            if (module == null)
                throw new ArgumentNullException(nameof(module));

            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                // Must match the key written by AddModuleInternal. NOTE(review): assumes Source is the
                // TypeInfo the ModuleInfo was constructed from - confirm in ModuleInfo.
                return RemoveModuleInternal(module.Source.AsType());
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        /// <summary>Unregisters module type <typeparamref name="T"/> and its commands.</summary>
        /// <returns><c>true</c> if the module was registered and has been removed.</returns>
        public async Task<bool> RemoveModule<T>()
        {
            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                return RemoveModuleInternal(typeof(T));
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        // Removes the module stored under 'type' and its commands. Caller must hold _moduleLock.
        private bool RemoveModuleInternal(Type type)
        {
            ModuleInfo unloadedModule;
            if (!_moduleDefs.TryRemove(type, out unloadedModule))
                return false;

            foreach (var cmd in unloadedModule.Commands)
                _map.RemoveCommand(cmd);
            return true;
        }

        //Type Readers

        /// <summary>Registers (or replaces) the reader used for arguments of type <typeparamref name="T"/>.</summary>
        public void AddTypeReader<T>(TypeReader reader)
        {
            _typeReaders[typeof(T)] = reader;
        }

        /// <summary>Registers (or replaces) the reader used for arguments of <paramref name="type"/>.</summary>
        public void AddTypeReader(Type type, TypeReader reader)
        {
            _typeReaders[type] = reader;
        }

        // Returns the reader registered for 'type', or null if none is registered.
        internal TypeReader GetTypeReader(Type type)
        {
            TypeReader reader;
            return _typeReaders.TryGetValue(type, out reader) ? reader : null;
        }

        //Execution

        /// <summary>Searches the message content starting at <paramref name="argPos"/>.</summary>
        public SearchResult Search(CommandContext context, int argPos)
            => Search(context, context.Message.Content.Substring(argPos));

        /// <summary>
        /// Finds the commands matching <paramref name="input"/>, ordered by descending
        /// <c>Priority</c>; fails with <see cref="CommandError.UnknownCommand"/> when none match.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
        public SearchResult Search(CommandContext context, string input)
        {
            if (input == null)
                throw new ArgumentNullException(nameof(input));

            var matches = _map.GetCommands(input).OrderByDescending(x => x.Priority).ToImmutableArray();
            if (matches.Length > 0)
                return SearchResult.FromSuccess(input, matches);
            return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.");
        }

        /// <summary>Executes the message content starting at <paramref name="argPos"/>.</summary>
        public Task<IResult> Execute(CommandContext context, int argPos, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
            => Execute(context, context.Message.Content.Substring(argPos), multiMatchHandling);

        /// <summary>
        /// Searches for commands matching <paramref name="input"/> and runs the first candidate whose
        /// preconditions and argument parsing succeed.
        /// </summary>
        /// <param name="multiMatchHandling">
        /// With <see cref="MultiMatchHandling.Best"/>, ambiguous arguments are resolved to their
        /// highest-scoring value; otherwise an ambiguous parse counts as a failure.
        /// </param>
        /// <returns>
        /// The failed search result; the precondition/parse failure when there is exactly one
        /// candidate; the executed command's result; or an UnknownCommand error when no candidate
        /// succeeds.
        /// </returns>
        public async Task<IResult> Execute(CommandContext context, string input, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
        {
            var searchResult = Search(context, input);
            if (!searchResult.IsSuccess)
                return searchResult;

            var commands = searchResult.Commands;
            // NOTE(review): Search orders matches by descending Priority, but this loop walks the list
            // from the end, i.e. the lowest Priority value is tried first. Confirm which order is intended.
            for (int i = commands.Count - 1; i >= 0; i--)
            {
                var command = commands[i];

                var preconditionResult = await command.CheckPreconditions(context).ConfigureAwait(false);
                if (!preconditionResult.IsSuccess)
                {
                    // A lone candidate reports its own failure; otherwise try the next overload.
                    if (commands.Count == 1)
                        return preconditionResult;
                    continue;
                }

                var parseResult = await command.Parse(context, searchResult, preconditionResult).ConfigureAwait(false);
                if (!parseResult.IsSuccess
                    && parseResult.Error == CommandError.MultipleMatches
                    && multiMatchHandling == MultiMatchHandling.Best)
                {
                    // Resolve each ambiguous argument/parameter to its highest-scoring candidate value.
                    var argList = parseResult.ArgValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToImmutableArray();
                    var paramList = parseResult.ParamValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToImmutableArray();
                    parseResult = ParseResult.FromSuccess(argList, paramList);
                }

                if (!parseResult.IsSuccess)
                {
                    if (commands.Count == 1)
                        return parseResult;
                    continue;
                }

                return await command.Execute(context, parseResult).ConfigureAwait(false);
            }

            return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload.");
        }
    }
}