Files
Discord.Net/src/Discord.Net.Commands/CommandService.cs
2016-06-27 06:56:24 -03:00

297 lines
13 KiB
C#

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Globalization;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
namespace Discord.Commands
{
public class CommandService
{
// Serializes module load/unload operations; created with a single permit in the constructor.
private readonly SemaphoreSlim _moduleLock;
// Loaded modules, keyed by the module instance object that was loaded.
private readonly ConcurrentDictionary<object, Module> _modules;
// Command lookup table: command text -> every loaded command registered under that text.
// The lists themselves are guarded with lock(list) wherever this class reads or mutates them.
private readonly ConcurrentDictionary<string, List<Command>> _map;
// Type readers, keyed by the type they were registered for.
private readonly ConcurrentDictionary<Type, TypeReader> _typeReaders;
/// <summary>Lazily enumerates every currently loaded module.</summary>
public IEnumerable<Module> Modules => _modules.Select(x => x.Value);
/// <summary>Lazily enumerates the commands of every currently loaded module.</summary>
public IEnumerable<Command> Commands => _modules.SelectMany(x => x.Value.Commands);
/// <summary>
/// Creates a command service with no modules loaded and registers the built-in
/// type readers for strings, numeric primitives, date/time values and Discord entities.
/// </summary>
public CommandService()
{
    _moduleLock = new SemaphoreSlim(1, 1);
    _modules = new ConcurrentDictionary<object, Module>();
    _map = new ConcurrentDictionary<string, List<Command>>();

    var readers = new ConcurrentDictionary<Type, TypeReader>();

    // Strings are handed through unchanged.
    readers[typeof(string)] = new GenericTypeReader((msg, text) => Task.FromResult(TypeReaderResult.FromSuccess(text)));

    // Primitive readers: each uses the TryParse overload that takes no culture argument,
    // so parsing follows the current culture of the calling thread.
    readers[typeof(byte)] = new GenericTypeReader((msg, text) =>
    {
        byte parsed;
        return Task.FromResult(byte.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse Byte"));
    });
    readers[typeof(sbyte)] = new GenericTypeReader((msg, text) =>
    {
        sbyte parsed;
        return Task.FromResult(sbyte.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse SByte"));
    });
    readers[typeof(ushort)] = new GenericTypeReader((msg, text) =>
    {
        ushort parsed;
        return Task.FromResult(ushort.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse UInt16"));
    });
    readers[typeof(short)] = new GenericTypeReader((msg, text) =>
    {
        short parsed;
        return Task.FromResult(short.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse Int16"));
    });
    readers[typeof(uint)] = new GenericTypeReader((msg, text) =>
    {
        uint parsed;
        return Task.FromResult(uint.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse UInt32"));
    });
    readers[typeof(int)] = new GenericTypeReader((msg, text) =>
    {
        int parsed;
        return Task.FromResult(int.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse Int32"));
    });
    readers[typeof(ulong)] = new GenericTypeReader((msg, text) =>
    {
        ulong parsed;
        return Task.FromResult(ulong.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse UInt64"));
    });
    readers[typeof(long)] = new GenericTypeReader((msg, text) =>
    {
        long parsed;
        return Task.FromResult(long.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse Int64"));
    });
    readers[typeof(float)] = new GenericTypeReader((msg, text) =>
    {
        float parsed;
        return Task.FromResult(float.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse Single"));
    });
    readers[typeof(double)] = new GenericTypeReader((msg, text) =>
    {
        double parsed;
        return Task.FromResult(double.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse Double"));
    });
    readers[typeof(decimal)] = new GenericTypeReader((msg, text) =>
    {
        decimal parsed;
        return Task.FromResult(decimal.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse Decimal"));
    });

    // Date/time readers.
    readers[typeof(DateTime)] = new GenericTypeReader((msg, text) =>
    {
        DateTime parsed;
        return Task.FromResult(DateTime.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse DateTime"));
    });
    readers[typeof(DateTimeOffset)] = new GenericTypeReader((msg, text) =>
    {
        DateTimeOffset parsed;
        return Task.FromResult(DateTimeOffset.TryParse(text, out parsed)
            ? TypeReaderResult.FromSuccess(parsed)
            : TypeReaderResult.FromError(CommandError.ParseFailed, "Failed to parse DateTimeOffset"));
    });

    // Discord entity readers.
    readers[typeof(IMessage)] = new MessageTypeReader();
    readers[typeof(IChannel)] = new ChannelTypeReader<IChannel>();
    readers[typeof(IGuildChannel)] = new ChannelTypeReader<IGuildChannel>();
    readers[typeof(ITextChannel)] = new ChannelTypeReader<ITextChannel>();
    readers[typeof(IVoiceChannel)] = new ChannelTypeReader<IVoiceChannel>();
    readers[typeof(IRole)] = new RoleTypeReader();
    readers[typeof(IUser)] = new UserTypeReader<IUser>();
    readers[typeof(IGuildUser)] = new UserTypeReader<IGuildUser>();

    _typeReaders = readers;
}
/// <summary>
/// Registers <paramref name="reader"/> for <typeparamref name="T"/>, replacing any reader
/// already registered for that type.
/// </summary>
/// <typeparam name="T">The type the reader is registered for.</typeparam>
/// <param name="reader">The reader to store.</param>
public void AddTypeReader<T>(TypeReader reader) => AddTypeReader(typeof(T), reader);
/// <summary>
/// Registers <paramref name="reader"/> for <paramref name="type"/>, replacing any reader
/// already registered for that type.
/// </summary>
/// <param name="type">The type the reader is registered for.</param>
/// <param name="reader">The reader to store.</param>
public void AddTypeReader(Type type, TypeReader reader) => _typeReaders[type] = reader;
/// <summary>Looks up the reader registered for <paramref name="type"/>.</summary>
/// <param name="type">The type to look up.</param>
/// <returns>The registered reader, or <c>null</c> when none is registered.</returns>
internal TypeReader GetTypeReader(Type type)
{
    TypeReader found;
    return _typeReaders.TryGetValue(type, out found) ? found : null;
}
/// <summary>
/// Loads an already-constructed module instance and indexes its commands.
/// </summary>
/// <param name="moduleInstance">The module object; its type must be marked with <see cref="ModuleAttribute"/>.</param>
/// <returns>The <see cref="Module"/> created for the instance.</returns>
/// <exception cref="ArgumentNullException"><paramref name="moduleInstance"/> is <c>null</c>.</exception>
/// <exception cref="ArgumentException">
/// The instance has already been loaded, or its type is not marked with <see cref="ModuleAttribute"/>.
/// </exception>
public async Task<Module> Load(object moduleInstance)
{
    if (moduleInstance == null)
        throw new ArgumentNullException(nameof(moduleInstance));

    await _moduleLock.WaitAsync().ConfigureAwait(false);
    try
    {
        if (_modules.ContainsKey(moduleInstance))
            throw new ArgumentException("This module has already been loaded.");

        var typeInfo = moduleInstance.GetType().GetTypeInfo();
        var moduleAttr = typeInfo.GetCustomAttribute<ModuleAttribute>();

        // GetCustomAttribute returns null when the attribute is absent; that is the
        // case to reject (the check used to be inverted and rejected valid modules).
        if (moduleAttr == null)
            throw new ArgumentException("Modules must be marked with ModuleAttribute.");

        return LoadInternal(moduleInstance, moduleAttr, typeInfo);
    }
    finally
    {
        _moduleLock.Release();
    }
}
/// <summary>
/// Builds the <see cref="Module"/> wrapper for an instance, stores it in the module table
/// and adds each of its commands to the text lookup table.
/// Both public callers in this class invoke it while holding the module lock.
/// </summary>
private Module LoadInternal(object moduleInstance, ModuleAttribute moduleAttr, TypeInfo typeInfo)
{
    var module = new Module(this, moduleInstance, moduleAttr, typeInfo);
    _modules[moduleInstance] = module;

    foreach (var command in module.Commands)
    {
        // Commands sharing the same text are collected in one list per text.
        List<Command> bucket = _map.GetOrAdd(command.Text, key => new List<Command>());
        lock (bucket)
        {
            bucket.Add(command);
        }
    }

    return module;
}
/// <summary>
/// Instantiates and loads every exported type in <paramref name="assembly"/> that is marked
/// with <see cref="ModuleAttribute"/>. Types without the attribute are skipped.
/// </summary>
/// <param name="assembly">The assembly whose exported types are scanned.</param>
/// <returns>The modules that were loaded, in scan order.</returns>
public async Task<IEnumerable<Module>> LoadAssembly(Assembly assembly)
{
    var loaded = ImmutableArray.CreateBuilder<Module>();
    await _moduleLock.WaitAsync().ConfigureAwait(false);
    try
    {
        foreach (var exportedType in assembly.ExportedTypes)
        {
            var info = exportedType.GetTypeInfo();
            var attribute = info.GetCustomAttribute<ModuleAttribute>();
            if (attribute == null)
                continue;

            var instance = ReflectionUtils.CreateObject(info);
            loaded.Add(LoadInternal(instance, attribute, info));
        }

        return loaded.ToImmutable();
    }
    finally
    {
        _moduleLock.Release();
    }
}
/// <summary>Unloads a module using the instance it wraps.</summary>
/// <param name="module">The module to unload.</param>
/// <returns><c>true</c> if the module was loaded and has been removed; otherwise <c>false</c>.</returns>
public async Task<bool> Unload(Module module)
{
    await _moduleLock.WaitAsync().ConfigureAwait(false);
    try
    {
        var instance = module.Instance;
        bool removed = UnloadInternal(instance);
        return removed;
    }
    finally
    {
        _moduleLock.Release();
    }
}
/// <summary>Unloads the module that was loaded from <paramref name="moduleInstance"/>.</summary>
/// <param name="moduleInstance">The instance the module was loaded from.</param>
/// <returns><c>true</c> if the module was loaded and has been removed; otherwise <c>false</c>.</returns>
public async Task<bool> Unload(object moduleInstance)
{
    await _moduleLock.WaitAsync().ConfigureAwait(false);
    try
    {
        bool removed = UnloadInternal(moduleInstance);
        return removed;
    }
    finally
    {
        _moduleLock.Release();
    }
}
/// <summary>
/// Removes the module registered for <paramref name="module"/> and takes each of its
/// commands out of the text lookup table. Lists that become empty are left in the table.
/// </summary>
/// <param name="module">The module instance used as the key in the module table.</param>
/// <returns><c>true</c> if a module was removed; <c>false</c> if none was registered for the key.</returns>
private bool UnloadInternal(object module)
{
    Module removed;
    if (!_modules.TryRemove(module, out removed))
        return false;

    foreach (var command in removed.Commands)
    {
        List<Command> bucket;
        if (!_map.TryGetValue(command.Text, out bucket))
            continue;

        lock (bucket)
        {
            bucket.Remove(command);
        }
    }

    return true;
}
/// <summary>
/// Searches for commands using the message's raw text starting at <paramref name="argPos"/>.
/// </summary>
/// <param name="message">The message whose raw text is searched.</param>
/// <param name="argPos">Index into the raw text at which the command input begins.</param>
public SearchResult Search(IMessage message, int argPos)
{
    var input = message.RawText.Substring(argPos);
    return Search(message, input);
}
/// <summary>
/// Finds every loaded command whose text matches a leading, space-delimited prefix of
/// <paramref name="input"/>. Lookup uses the command table's default string comparer,
/// so matching is case-sensitive.
/// </summary>
/// <param name="message">The message being processed (not used by the lookup itself).</param>
/// <param name="input">The command input, e.g. the message text after the prefix.</param>
/// <returns>
/// A successful result holding the matches, ordered from the shortest matching text to the
/// longest, or an <see cref="CommandError.UnknownCommand"/> error when nothing matched.
/// </returns>
public SearchResult Search(IMessage message, string input)
{
    ImmutableArray<Command>.Builder matches = null;
    int pos = -1;
    while (true)
    {
        // Extend the candidate text by one more word on each pass; once no further
        // space is found the whole input is the final candidate.
        pos = input.IndexOf(' ', pos + 1);
        string cmdText = pos == -1 ? input : input.Substring(0, pos);

        List<Command> group;
        if (!_map.TryGetValue(cmdText, out group))
            break;

        lock (group)
        {
            // Unloading leaves emptied lists behind in the table; ignoring them keeps an
            // all-empty lookup from being reported as a success with zero commands.
            if (group.Count > 0)
            {
                if (matches == null)
                    matches = ImmutableArray.CreateBuilder<Command>(group.Count);
                matches.AddRange(group);
            }
        }

        if (pos == -1)
            break;
    }

    if (matches == null)
        return SearchResult.FromError(CommandError.UnknownCommand, "Unknown command.");
    return SearchResult.FromSuccess(input, matches.ToImmutable());
}
/// <summary>
/// Executes the command found in the message's raw text starting at <paramref name="argPos"/>.
/// </summary>
/// <param name="message">The message whose raw text is executed.</param>
/// <param name="argPos">Index into the raw text at which the command input begins.</param>
public Task<IResult> Execute(IMessage message, int argPos)
{
    var input = message.RawText.Substring(argPos);
    return Execute(message, input);
}
/// <summary>
/// Searches for commands matching <paramref name="input"/> and executes the first one whose
/// parse succeeds.
/// </summary>
/// <param name="message">The message that triggered the command.</param>
/// <param name="input">The command input, e.g. the message text after the prefix.</param>
/// <returns>
/// The failed search result when no command matched; the execution result of the first
/// command that parsed successfully; otherwise a <see cref="CommandError.ParseFailed"/> error.
/// </returns>
public async Task<IResult> Execute(IMessage message, string input)
{
    var searchResult = Search(message, input);
    if (!searchResult.IsSuccess)
        return searchResult;

    var commands = searchResult.Commands;

    // Walk from the last match to the first. The index must count DOWN: it used to be
    // incremented, which ran past the end of the list as soon as one parse failed.
    for (int i = commands.Count - 1; i >= 0; i--)
    {
        var parseResult = await commands[i].Parse(message, searchResult);
        if (!parseResult.IsSuccess)
            continue;

        var executeResult = await commands[i].Execute(message, parseResult);
        return executeResult;
    }

    return ParseResult.FromError(CommandError.ParseFailed, "This input does not match any overload.");
}
}
}