using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Discord.Commands.Builders;

namespace Discord.Commands
{
    /// <summary>
    /// Stores loaded command modules and type readers, and resolves and executes commands from text input.
    /// </summary>
    public class CommandService
    {
        // Serializes module add/remove operations. Modules, Commands and Search read without taking it.
        private readonly SemaphoreSlim _moduleLock;
        // Modules added from a CLR type (AddModuleAsync<T> / AddModulesAsync), keyed by the type the builder reported.
        private readonly ConcurrentDictionary<Type, ModuleInfo> _typedModuleDefs;
        // User-registered readers: target type -> (concrete reader type -> reader instance).
        private readonly ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>> _typeReaders;
        // Built-in readers, one per supported target type.
        private readonly ConcurrentDictionary<Type, TypeReader> _defaultTypeReaders;
        // Every loaded module, submodules included, used as a concurrent set (the byte value is unused).
        // A ConcurrentBag cannot remove a specific item: TryTake removes an arbitrary one.
        private readonly ConcurrentDictionary<ModuleInfo, byte> _moduleDefs;
        private readonly CommandMap _map;

        internal readonly bool _caseSensitive;
        // Copied from the config; not read within this file.
        internal readonly RunMode _defaultRunMode;
        internal readonly char _splitCharacter;

        /// <summary>All loaded modules, including submodules.</summary>
        public IEnumerable<ModuleInfo> Modules => _moduleDefs.Select(x => x.Key);

        /// <summary>All commands declared by the loaded modules.</summary>
        public IEnumerable<CommandInfo> Commands => _moduleDefs.SelectMany(x => x.Key.Commands);

        /// <summary>User-registered type readers, grouped by the type they read.</summary>
        public ILookup<Type, TypeReader> TypeReaders => _typeReaders
            .SelectMany(x => x.Value.Select(y => new { x.Key, y.Value }))
            .ToLookup(x => x.Key, x => x.Value);

        /// <summary>Creates a service using a default <see cref="CommandServiceConfig"/>.</summary>
        public CommandService() : this(new CommandServiceConfig()) { }

        /// <summary>Creates a service using the given configuration.</summary>
        /// <param name="config">Supplies case sensitivity, default run mode and split character.</param>
        /// <exception cref="ArgumentNullException"><paramref name="config"/> is null.</exception>
        public CommandService(CommandServiceConfig config)
        {
            if (config == null)
                throw new ArgumentNullException(nameof(config));

            _moduleLock = new SemaphoreSlim(1, 1);
            _typedModuleDefs = new ConcurrentDictionary<Type, ModuleInfo>();
            _moduleDefs = new ConcurrentDictionary<ModuleInfo, byte>();
            _map = new CommandMap();
            _typeReaders = new ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>>();

            _defaultTypeReaders = new ConcurrentDictionary<Type, TypeReader>
            {
                [typeof(bool)] = new SimpleTypeReader<bool>(),
                [typeof(char)] = new SimpleTypeReader<char>(),
                [typeof(string)] = new SimpleTypeReader<string>(),
                [typeof(byte)] = new SimpleTypeReader<byte>(),
                [typeof(sbyte)] = new SimpleTypeReader<sbyte>(),
                [typeof(ushort)] = new SimpleTypeReader<ushort>(),
                [typeof(short)] = new SimpleTypeReader<short>(),
                [typeof(uint)] = new SimpleTypeReader<uint>(),
                [typeof(int)] = new SimpleTypeReader<int>(),
                [typeof(ulong)] = new SimpleTypeReader<ulong>(),
                [typeof(long)] = new SimpleTypeReader<long>(),
                [typeof(float)] = new SimpleTypeReader<float>(),
                [typeof(double)] = new SimpleTypeReader<double>(),
                [typeof(decimal)] = new SimpleTypeReader<decimal>(),
                [typeof(DateTime)] = new SimpleTypeReader<DateTime>(),
                [typeof(DateTimeOffset)] = new SimpleTypeReader<DateTimeOffset>(),

                [typeof(IMessage)] = new MessageTypeReader<IMessage>(),
                [typeof(IUserMessage)] = new MessageTypeReader<IUserMessage>(),
                [typeof(IChannel)] = new ChannelTypeReader<IChannel>(),
                [typeof(IDMChannel)] = new ChannelTypeReader<IDMChannel>(),
                [typeof(IGroupChannel)] = new ChannelTypeReader<IGroupChannel>(),
                [typeof(IGuildChannel)] = new ChannelTypeReader<IGuildChannel>(),
                [typeof(IMessageChannel)] = new ChannelTypeReader<IMessageChannel>(),
                [typeof(IPrivateChannel)] = new ChannelTypeReader<IPrivateChannel>(),
                [typeof(ITextChannel)] = new ChannelTypeReader<ITextChannel>(),
                [typeof(IVoiceChannel)] = new ChannelTypeReader<IVoiceChannel>(),

                [typeof(IRole)] = new RoleTypeReader<IRole>(),

                [typeof(IUser)] = new UserTypeReader<IUser>(),
                [typeof(IGroupUser)] = new UserTypeReader<IGroupUser>(),
                [typeof(IGuildUser)] = new UserTypeReader<IGuildUser>(),
            };

            _caseSensitive = config.CaseSensitiveCommands;
            _defaultRunMode = config.DefaultRunMode;
            _splitCharacter = config.CommandSplitCharacter;
        }

        //Modules

        /// <summary>Builds a module through a <see cref="ModuleBuilder"/> callback and loads it.</summary>
        /// <param name="primaryAlias">Primary alias passed to the new <see cref="ModuleBuilder"/>.</param>
        /// <param name="buildFunc">Callback that configures the builder before it is built.</param>
        /// <returns>The loaded module.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="buildFunc"/> is null.</exception>
        public async Task<ModuleInfo> CreateModuleAsync(string primaryAlias, Action<ModuleBuilder> buildFunc)
        {
            if (buildFunc == null)
                throw new ArgumentNullException(nameof(buildFunc));

            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                var builder = new ModuleBuilder(this, null, primaryAlias);
                buildFunc(builder);

                var module = builder.Build(this);
                return LoadModuleInternal(module);
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        /// <summary>Builds a module from type <typeparamref name="T"/> and loads it.</summary>
        /// <returns>The loaded module.</returns>
        /// <exception cref="ArgumentException"><typeparamref name="T"/> has already been added.</exception>
        /// <exception cref="InvalidOperationException">No module could be built from <typeparamref name="T"/>.</exception>
        public async Task<ModuleInfo> AddModuleAsync<T>()
        {
            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                var typeInfo = typeof(T).GetTypeInfo();

                if (_typedModuleDefs.ContainsKey(typeof(T)))
                    throw new ArgumentException($"This module has already been added.");

                var module = ModuleClassBuilder.Build(this, typeInfo).FirstOrDefault();

                if (module.Value == default(ModuleInfo))
                    throw new InvalidOperationException($"Could not build the module {typeof(T).FullName}, did you pass an invalid type?");

                _typedModuleDefs[module.Key] = module.Value;

                return LoadModuleInternal(module.Value);
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        /// <summary>Builds and loads every module type that <c>ModuleClassBuilder.Search</c> finds in an assembly.</summary>
        /// <param name="assembly">Assembly to scan.</param>
        /// <returns>The modules that were built and loaded.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="assembly"/> is null.</exception>
        public async Task<IEnumerable<ModuleInfo>> AddModulesAsync(Assembly assembly)
        {
            if (assembly == null)
                throw new ArgumentNullException(nameof(assembly));

            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                var types = ModuleClassBuilder.Search(assembly).ToArray();
                var moduleDefs = ModuleClassBuilder.Build(types, this);

                foreach (var info in moduleDefs)
                {
                    _typedModuleDefs[info.Key] = info.Value;
                    LoadModuleInternal(info.Value);
                }

                return moduleDefs.Select(x => x.Value).ToImmutableArray();
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        // Registers a module, maps its commands, then recurses into its submodules. Caller must hold _moduleLock.
        private ModuleInfo LoadModuleInternal(ModuleInfo module)
        {
            _moduleDefs.TryAdd(module, 0);

            foreach (var command in module.Commands)
                _map.AddCommand(command, this);

            foreach (var submodule in module.Submodules)
                LoadModuleInternal(submodule);

            return module;
        }

        /// <summary>Unloads a module together with its submodules and commands.</summary>
        /// <param name="module">The module to unload.</param>
        /// <returns><c>true</c> if the module was loaded and has been removed; otherwise <c>false</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="module"/> is null.</exception>
        public async Task<bool> RemoveModuleAsync(ModuleInfo module)
        {
            if (module == null)
                throw new ArgumentNullException(nameof(module));

            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                if (!RemoveModuleInternal(module))
                    return false;

                // If the module was added by type, forget that registration too. Otherwise
                // AddModuleAsync<T> would keep reporting the type as already added.
                foreach (var pair in _typedModuleDefs)
                {
                    if (pair.Value == module)
                    {
                        ModuleInfo removed;
                        _typedModuleDefs.TryRemove(pair.Key, out removed);
                    }
                }
                return true;
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        /// <summary>Unloads the module previously added for type <typeparamref name="T"/>.</summary>
        /// <returns><c>true</c> if a module for <typeparamref name="T"/> was loaded and has been removed; otherwise <c>false</c>.</returns>
        public async Task<bool> RemoveModuleAsync<T>()
        {
            await _moduleLock.WaitAsync().ConfigureAwait(false);
            try
            {
                // Remove (not just read) the type registration so T can be added again later.
                ModuleInfo module;
                if (!_typedModuleDefs.TryRemove(typeof(T), out module) || module == default(ModuleInfo))
                    return false;

                return RemoveModuleInternal(module);
            }
            finally
            {
                _moduleLock.Release();
            }
        }

        // Removes exactly this module (and, recursively, its submodules) and unmaps its commands.
        // Returns false when the module is not loaded. Caller must hold _moduleLock.
        private bool RemoveModuleInternal(ModuleInfo module)
        {
            byte unused;
            if (!_moduleDefs.TryRemove(module, out unused))
                return false;

            foreach (var cmd in module.Commands)
                _map.RemoveCommand(cmd, this);

            foreach (var submodule in module.Submodules)
                RemoveModuleInternal(submodule);

            return true;
        }

        //Type Readers

        /// <summary>Registers a reader for values of type <typeparamref name="T"/>; replaces an earlier reader of the same concrete class.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="reader"/> is null.</exception>
        public void AddTypeReader<T>(TypeReader reader)
            => AddTypeReader(typeof(T), reader);

        /// <summary>Registers a reader for values of <paramref name="type"/>; replaces an earlier reader of the same concrete class.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> or <paramref name="reader"/> is null.</exception>
        public void AddTypeReader(Type type, TypeReader reader)
        {
            if (type == null)
                throw new ArgumentNullException(nameof(type));
            // Validate before GetOrAdd so a bad call leaves no empty reader table behind.
            if (reader == null)
                throw new ArgumentNullException(nameof(reader));

            var readers = _typeReaders.GetOrAdd(type, x => new ConcurrentDictionary<Type, TypeReader>());
            readers[reader.GetType()] = reader;
        }

        // User-registered readers for the target type, keyed by reader class; null when none are registered.
        internal IDictionary<Type, TypeReader> GetTypeReaders(Type type)
        {
            ConcurrentDictionary<Type, TypeReader> definedTypeReaders;
            if (_typeReaders.TryGetValue(type, out definedTypeReaders))
                return definedTypeReaders;
            return null;
        }

        // Built-in reader for the target type, or null when there is none.
        internal TypeReader GetDefaultTypeReader(Type type)
        {
            TypeReader reader;
            if (_defaultTypeReaders.TryGetValue(type, out reader))
                return reader;
            return null;
        }

        //Execution

        /// <summary>Searches for commands matching the message content starting at <paramref name="argPos"/>.</summary>
        public SearchResult Search(CommandContext context, int argPos)
            => Search(context, context.Message.Content.Substring(argPos));

        /// <summary>
        /// Finds the commands matching <paramref name="input"/>. Matching is done on lower-cased input
        /// unless the service is case-sensitive.
        /// </summary>
        /// <returns>
        /// A successful result whose commands are ordered by descending <c>Priority</c>, or an
        /// <see cref="CommandError.UnknownCommand"/> error when nothing matches.
        /// </returns>
        public SearchResult Search(CommandContext context, string input)
        {
            string searchInput = _caseSensitive ? input : input.ToLowerInvariant();
            var matches = _map.GetCommands(searchInput, this).OrderByDescending(x => x.Priority).ToImmutableArray();

            if (matches.Length > 0)
                return SearchResult.FromSuccess(input, matches);
            else
                return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.");
        }

        /// <summary>Executes the command found in the message content starting at <paramref name="argPos"/>.</summary>
        public Task<IResult> ExecuteAsync(CommandContext context, int argPos, IDependencyMap dependencyMap = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
            => ExecuteAsync(context, context.Message.Content.Substring(argPos), dependencyMap, multiMatchHandling);

        /// <summary>
        /// Resolves <paramref name="input"/> to a command and executes it. Candidates are tried from highest
        /// to lowest priority. For each one, preconditions are checked and arguments parsed; the first candidate
        /// passing both is executed.
        /// </summary>
        /// <param name="dependencyMap">Dependencies for preconditions and execution; <c>DependencyMap.Empty</c> when null.</param>
        /// <param name="multiMatchHandling">
        /// With <see cref="MultiMatchHandling.Best"/>, an ambiguous parse is resolved by taking the highest-scoring
        /// value for every argument and parameter. Otherwise an ambiguous parse counts as a failure.
        /// </param>
        /// <returns>
        /// The execution result. When only one command matched, its precondition or parse failure is returned as is.
        /// When several matched and all failed, an <see cref="CommandError.UnknownCommand"/> error is returned.
        /// </returns>
        public async Task<IResult> ExecuteAsync(CommandContext context, string input, IDependencyMap dependencyMap = null, MultiMatchHandling multiMatchHandling = MultiMatchHandling.Exception)
        {
            dependencyMap = dependencyMap ?? DependencyMap.Empty;

            var searchResult = Search(context, input);
            if (!searchResult.IsSuccess)
                return searchResult;

            // Search sorts by descending Priority, so walk forward to try the highest-priority overload first.
            var commands = searchResult.Commands;
            for (int i = 0; i < commands.Count; i++)
            {
                var preconditionResult = await commands[i].CheckPreconditionsAsync(context, dependencyMap).ConfigureAwait(false);
                if (!preconditionResult.IsSuccess)
                {
                    if (commands.Count == 1)
                        return preconditionResult;
                    else
                        continue;
                }

                var parseResult = await commands[i].ParseAsync(context, searchResult, preconditionResult).ConfigureAwait(false);
                if (!parseResult.IsSuccess)
                {
                    if (parseResult.Error == CommandError.MultipleMatches && multiMatchHandling == MultiMatchHandling.Best)
                    {
                        var argList = parseResult.ArgValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToImmutableArray();
                        var paramList = parseResult.ParamValues.Select(x => x.Values.OrderByDescending(y => y.Score).First()).ToImmutableArray();
                        parseResult = ParseResult.FromSuccess(argList, paramList);
                    }

                    if (!parseResult.IsSuccess)
                    {
                        if (commands.Count == 1)
                            return parseResult;
                        else
                            continue;
                    }
                }

                return await commands[i].Execute(context, parseResult, dependencyMap).ConfigureAwait(false);
            }

            return SearchResult.FromError(CommandError.UnknownCommand, "This input does not match any overload.");
        }
    }
}