Added initial websocket support

This commit is contained in:
RogueException
2016-05-27 01:52:11 -03:00
parent 50392d3688
commit b93abcc95b
118 changed files with 1725 additions and 776 deletions

View File

@@ -1,7 +1,9 @@
using Discord.API.Rest;
using Discord.Net;
using Discord.Net.Converters;
using Discord.Net.Queue;
using Discord.Net.Rest;
using Discord.Net.WebSockets;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
@@ -17,77 +19,202 @@ using System.Threading.Tasks;
namespace Discord.API
{
public class DiscordApiClient
public class DiscordApiClient : IDisposable
{
internal event EventHandler<SentRequestEventArgs> SentRequest;
internal event Func<SentRequestEventArgs, Task> SentRequest;
private readonly RequestQueue _requestQueue;
private readonly JsonSerializer _serializer;
private readonly IRestClient _restClient;
private CancellationToken _cancelToken;
private readonly IWebSocketClient _gatewayClient;
private readonly SemaphoreSlim _connectionLock;
private CancellationTokenSource _loginCancelToken, _connectCancelToken;
private bool _isDisposed;
public LoginState LoginState { get; private set; }
public ConnectionState ConnectionState { get; private set; }
public TokenType AuthTokenType { get; private set; }
public IRestClient RestClient { get; private set; }
public IRequestQueue RequestQueue { get; private set; }
public DiscordApiClient(RestClientProvider restClientProvider)
public DiscordApiClient(RestClientProvider restClientProvider, WebSocketProvider webSocketProvider = null, JsonSerializer serializer = null, RequestQueue requestQueue = null)
{
_connectionLock = new SemaphoreSlim(1, 1);
_requestQueue = requestQueue ?? new RequestQueue();
_restClient = restClientProvider(DiscordConfig.ClientAPIUrl);
_restClient.SetHeader("accept", "*/*");
_restClient.SetHeader("user-agent", DiscordConfig.UserAgent);
_requestQueue = new RequestQueue(_restClient);
_serializer = new JsonSerializer()
if (webSocketProvider != null)
{
ContractResolver = new DiscordContractResolver()
};
}
public async Task Login(TokenType tokenType, string token, CancellationToken cancelToken)
{
AuthTokenType = tokenType;
_cancelToken = cancelToken;
await _requestQueue.SetCancelToken(cancelToken).ConfigureAwait(false);
switch (tokenType)
{
case TokenType.Bot:
token = $"Bot {token}";
break;
case TokenType.Bearer:
token = $"Bearer {token}";
break;
case TokenType.User:
break;
default:
throw new ArgumentException("Unknown oauth token type", nameof(tokenType));
_gatewayClient = webSocketProvider();
_gatewayClient.SetHeader("user-agent", DiscordConfig.UserAgent);
}
_restClient.SetHeader("authorization", token);
_serializer = serializer ?? new JsonSerializer { ContractResolver = new DiscordContractResolver() };
}
public async Task Login(LoginParams args, CancellationToken cancelToken)
void Dispose(bool disposing)
{
AuthTokenType = TokenType.User;
_restClient.SetHeader("authorization", null);
_cancelToken = cancelToken;
if (!_isDisposed)
{
if (disposing)
{
_loginCancelToken?.Dispose();
_connectCancelToken?.Dispose();
}
_isDisposed = true;
}
}
public void Dispose() => Dispose(true);
LoginResponse response;
public async Task Login(LoginParams args)
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
response = await Send<LoginResponse>("POST", "auth/login", args, GlobalBucket.Login).ConfigureAwait(false);
await LoginInternal(TokenType.User, null, args, true).ConfigureAwait(false);
}
catch
finally { _connectionLock.Release(); }
}
public async Task Login(TokenType tokenType, string token)
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
_cancelToken = CancellationToken.None;
await LoginInternal(tokenType, token, null, false).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
private async Task LoginInternal(TokenType tokenType, string token, LoginParams args, bool doLogin)
{
if (LoginState != LoginState.LoggedOut)
await LogoutInternal().ConfigureAwait(false);
LoginState = LoginState.LoggingIn;
try
{
_loginCancelToken = new CancellationTokenSource();
AuthTokenType = TokenType.User;
_restClient.SetHeader("authorization", null);
await _requestQueue.SetCancelToken(_loginCancelToken.Token).ConfigureAwait(false);
_restClient.SetCancelToken(_loginCancelToken.Token);
if (doLogin)
{
var response = await Send<LoginResponse>("POST", "auth/login", args, GlobalBucket.Login).ConfigureAwait(false);
token = response.Token;
}
AuthTokenType = tokenType;
switch (tokenType)
{
case TokenType.Bot:
token = $"Bot {token}";
break;
case TokenType.Bearer:
token = $"Bearer {token}";
break;
case TokenType.User:
break;
default:
throw new ArgumentException("Unknown oauth token type", nameof(tokenType));
}
_restClient.SetHeader("authorization", token);
LoginState = LoginState.LoggedIn;
}
catch (Exception)
{
await LogoutInternal().ConfigureAwait(false);
throw;
}
_restClient.SetHeader("authorization", response.Token);
}
public async Task Logout()
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await LogoutInternal().ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
private async Task LogoutInternal()
{
//TODO: An exception here will lock the client into the unusable LoggingOut state. How should we handle? (Add same solution to both DiscordClients too)
if (LoginState == LoginState.LoggedOut) return;
LoginState = LoginState.LoggingOut;
try { _loginCancelToken?.Cancel(false); }
catch { }
await DisconnectInternal().ConfigureAwait(false);
await _requestQueue.Clear().ConfigureAwait(false);
await _requestQueue.SetCancelToken(CancellationToken.None).ConfigureAwait(false);
_restClient.SetCancelToken(CancellationToken.None);
LoginState = LoginState.LoggedOut;
}
public async Task Connect()
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await ConnectInternal().ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
private async Task ConnectInternal()
{
if (LoginState != LoginState.LoggedIn)
throw new InvalidOperationException("You must log in before connecting.");
if (_gatewayClient == null)
throw new NotSupportedException("This client is not configured with websocket support.");
ConnectionState = ConnectionState.Connecting;
try
{
_connectCancelToken = new CancellationTokenSource();
if (_gatewayClient != null)
_gatewayClient.SetCancelToken(_connectCancelToken.Token);
var gatewayResponse = await GetGateway().ConfigureAwait(false);
await _gatewayClient.Connect(gatewayResponse.Url).ConfigureAwait(false);
ConnectionState = ConnectionState.Connected;
}
catch (Exception)
{
await DisconnectInternal().ConfigureAwait(false);
throw;
}
}
public async Task Disconnect()
{
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await DisconnectInternal().ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
private async Task DisconnectInternal()
{
if (_gatewayClient == null)
throw new NotSupportedException("This client is not configured with websocket support.");
if (ConnectionState == ConnectionState.Disconnected) return;
ConnectionState = ConnectionState.Disconnecting;
try { _connectCancelToken?.Cancel(false); }
catch { }
await _gatewayClient.Disconnect().ConfigureAwait(false);
ConnectionState = ConnectionState.Disconnected;
}
//Core
@@ -134,32 +261,28 @@ namespace Discord.API
private async Task<Stream> SendInternal(string method, string endpoint, object payload, bool headerOnly, BucketGroup group, int bucketId, ulong guildId)
{
_cancelToken.ThrowIfCancellationRequested();
var stopwatch = Stopwatch.StartNew();
string json = null;
if (payload != null)
json = Serialize(payload);
var responseStream = await _requestQueue.Send(new RestRequest(method, endpoint, json, headerOnly), group, bucketId, guildId).ConfigureAwait(false);
var responseStream = await _requestQueue.Send(new RestRequest(_restClient, method, endpoint, json, headerOnly), group, bucketId, guildId).ConfigureAwait(false);
int bytes = headerOnly ? 0 : (int)responseStream.Length;
stopwatch.Stop();
double milliseconds = ToMilliseconds(stopwatch);
SentRequest?.Invoke(this, new SentRequestEventArgs(method, endpoint, bytes, milliseconds));
await SentRequest.Raise(new SentRequestEventArgs(method, endpoint, bytes, milliseconds)).ConfigureAwait(false);
return responseStream;
}
private async Task<Stream> SendInternal(string method, string endpoint, IReadOnlyDictionary<string, object> multipartArgs, bool headerOnly, BucketGroup group, int bucketId, ulong guildId)
{
_cancelToken.ThrowIfCancellationRequested();
var stopwatch = Stopwatch.StartNew();
var responseStream = await _requestQueue.Send(new RestRequest(method, endpoint, multipartArgs, headerOnly), group, bucketId, guildId).ConfigureAwait(false);
var responseStream = await _requestQueue.Send(new RestRequest(_restClient, method, endpoint, multipartArgs, headerOnly), group, bucketId, guildId).ConfigureAwait(false);
int bytes = headerOnly ? 0 : (int)responseStream.Length;
stopwatch.Stop();
double milliseconds = ToMilliseconds(stopwatch);
SentRequest?.Invoke(this, new SentRequestEventArgs(method, endpoint, bytes, milliseconds));
await SentRequest.Raise(new SentRequestEventArgs(method, endpoint, bytes, milliseconds)).ConfigureAwait(false);
return responseStream;
}

View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace Discord.API
{
public class DiscordAPISocketClient
{
}
}

View File

@@ -0,0 +1,12 @@
using Newtonsoft.Json;
namespace Discord.API.Gateway
{
public class GuildMembersChunkEvent
{
[JsonProperty("guild_id")]
public ulong GuildId { get; set; }
[JsonProperty("members")]
public GuildMember[] Members { get; set; }
}
}

View File

@@ -0,0 +1,12 @@
using Newtonsoft.Json;
namespace Discord.API.Gateway
{
public class GuildRoleCreateEvent
{
[JsonProperty("guild_id")]
public ulong GuildId { get; set; }
[JsonProperty("role")]
public Role Data { get; set; }
}
}

View File

@@ -0,0 +1,12 @@
using Newtonsoft.Json;
namespace Discord.API.Gateway
{
public class GuildRoleUpdateEvent
{
[JsonProperty("guild_id")]
public ulong GuildId { get; set; }
[JsonProperty("role")]
public Role Data { get; set; }
}
}

View File

@@ -0,0 +1,17 @@
using Newtonsoft.Json;
using System.Collections.Generic;
namespace Discord.API.Gateway
{
public class IdentifyCommand
{
[JsonProperty("token")]
public string Token { get; set; }
[JsonProperty("properties")]
public IDictionary<string, string> Properties { get; set; }
[JsonProperty("large_threshold")]
public int LargeThreshold { get; set; }
[JsonProperty("compress")]
public bool UseCompression { get; set; }
}
}

View File

@@ -0,0 +1,24 @@
namespace Discord.API.Gateway
{
public enum OpCodes : byte
{
/// <summary> C←S - Used to send most events. </summary>
Dispatch = 0,
/// <summary> C↔S - Used to keep the connection alive and measure latency. </summary>
Heartbeat = 1,
/// <summary> C→S - Used to associate a connection with a token and specify configuration. </summary>
Identify = 2,
/// <summary> C→S - Used to update client's status and current game id. </summary>
StatusUpdate = 3,
/// <summary> C→S - Used to join a particular voice channel. </summary>
VoiceStateUpdate = 4,
/// <summary> C→S - Used to ensure the server's voice server is alive. Only send this if voice connection fails or suddenly drops. </summary>
VoiceServerPing = 5,
/// <summary> C→S - Used to resume a connection after a redirect occurs. </summary>
Resume = 6,
/// <summary> C←S - Used to notify a client that they must reconnect to another gateway. </summary>
Reconnect = 7,
/// <summary> C→S - Used to request all members that were withheld by large_threshold </summary>
RequestGuildMembers = 8
}
}

View File

@@ -0,0 +1,40 @@
using Newtonsoft.Json;
namespace Discord.API.Gateway
{
public class ReadyEvent
{
public class ReadState
{
[JsonProperty("id")]
public string ChannelId { get; set; }
[JsonProperty("mention_count")]
public int MentionCount { get; set; }
[JsonProperty("last_message_id")]
public string LastMessageId { get; set; }
}
[JsonProperty("v")]
public int Version { get; set; }
[JsonProperty("user")]
public User User { get; set; }
[JsonProperty("session_id")]
public string SessionId { get; set; }
[JsonProperty("read_state")]
public ReadState[] ReadStates { get; set; }
[JsonProperty("guilds")]
public Guild[] Guilds { get; set; }
[JsonProperty("private_channels")]
public Channel[] PrivateChannels { get; set; }
[JsonProperty("heartbeat_interval")]
public int HeartbeatInterval { get; set; }
//Ignored
[JsonProperty("user_settings")]
public object UserSettings { get; set; }
[JsonProperty("user_guild_settings")]
public object UserGuildSettings { get; set; }
[JsonProperty("tutorial")]
public object Tutorial { get; set; }
}
}

View File

@@ -0,0 +1,14 @@
using Newtonsoft.Json;
namespace Discord.API.Gateway
{
public class RequestMembersCommand
{
[JsonProperty("guild_id")]
public ulong[] GuildId { get; set; }
[JsonProperty("query")]
public string Query { get; set; }
[JsonProperty("limit")]
public int Limit { get; set; }
}
}

View File

@@ -0,0 +1,12 @@
using Newtonsoft.Json;
namespace Discord.API.Gateway
{
public class ResumeCommand
{
[JsonProperty("session_id")]
public string SessionId { get; set; }
[JsonProperty("seq")]
public uint Sequence { get; set; }
}
}

View File

@@ -0,0 +1,10 @@
using Newtonsoft.Json;
namespace Discord.API.Gateway
{
public class ResumedEvent
{
[JsonProperty("heartbeat_interval")]
public int HeartbeatInterval { get; set; }
}
}

View File

@@ -0,0 +1,14 @@
using Newtonsoft.Json;
namespace Discord.API.Gateway
{
public class TypingStartEvent
{
[JsonProperty("user_id")]
public ulong UserId { get; set; }
[JsonProperty("channel_id")]
public ulong ChannelId { get; set; }
[JsonProperty("timestamp")]
public int Timestamp { get; set; }
}
}

View File

@@ -0,0 +1,12 @@
using Newtonsoft.Json;
namespace Discord.API.Gateway
{
public class UpdateStatusCommand
{
[JsonProperty("idle_since")]
public long? IdleSince { get; set; }
[JsonProperty("game")]
public Game Game { get; set; }
}
}

View File

@@ -0,0 +1,16 @@
using Newtonsoft.Json;
namespace Discord.API.Gateway
{
public class UpdateVoiceCommand
{
[JsonProperty("guild_id")]
public ulong? GuildId { get; set; }
[JsonProperty("channel_id")]
public ulong? ChannelId { get; set; }
[JsonProperty("self_mute")]
public bool IsSelfMuted { get; set; }
[JsonProperty("self_deaf")]
public bool IsSelfDeafened { get; set; }
}
}

View File

@@ -0,0 +1,14 @@
using Newtonsoft.Json;
namespace Discord.API.Gateway
{
public class VoiceServerUpdateEvent
{
[JsonProperty("guild_id")]
public ulong GuildId { get; set; }
[JsonProperty("endpoint")]
public string Endpoint { get; set; }
[JsonProperty("token")]
public string Token { get; set; }
}
}

View File

@@ -1,9 +0,0 @@
namespace Discord.API
{
public interface IWebSocketMessage
{
int OpCode { get; }
object Payload { get; }
bool IsPrivate { get; }
}
}

View File

@@ -1,4 +1,5 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
namespace Discord.API
{
@@ -11,13 +12,6 @@ namespace Discord.API
[JsonProperty("s", NullValueHandling = NullValueHandling.Ignore)]
public uint? Sequence { get; set; }
[JsonProperty("d")]
public object Payload { get; set; }
public WebSocketMessage() { }
public WebSocketMessage(IWebSocketMessage msg)
{
Operation = msg.OpCode;
Payload = msg.Payload;
}
public JToken Payload { get; set; }
}
}

View File

@@ -1,13 +0,0 @@
using System;
namespace Discord
{
internal static class EventExtensions
{
public static void Raise(this EventHandler eventHandler, object sender)
=> eventHandler?.Invoke(sender, EventArgs.Empty);
public static void Raise<T>(this EventHandler<T> eventHandler, object sender, T eventArgs)
where T : EventArgs
=> eventHandler?.Invoke(sender, eventArgs);
}
}

View File

@@ -0,0 +1,10 @@
namespace Discord
{
public enum ConnectionState : byte
{
Disconnected,
Connecting,
Connected,
Disconnecting
}
}

View File

@@ -5,6 +5,7 @@ namespace Discord
internal static class DateTimeUtils
{
private const ulong EpochTicks = 621355968000000000UL;
private const ulong DiscordEpochMillis = 1420070400000UL;
public static DateTime FromEpochMilliseconds(ulong value)
=> new DateTime((long)(value * TimeSpan.TicksPerMillisecond + EpochTicks), DateTimeKind.Utc);
@@ -12,6 +13,6 @@ namespace Discord
=> new DateTime((long)(value * TimeSpan.TicksPerSecond + EpochTicks), DateTimeKind.Utc);
public static DateTime FromSnowflake(ulong value)
=> FromEpochMilliseconds((value >> 22) + 1420070400000UL);
=> FromEpochMilliseconds((value >> 22) + DiscordEpochMillis);
}
}

View File

@@ -0,0 +1,45 @@
using System;
using System.Threading.Tasks;
namespace Discord
{
internal static class EventExtensions
{
public static async Task Raise(this Func<Task> eventHandler)
{
var subscriptions = eventHandler?.GetInvocationList();
if (subscriptions != null)
{
for (int i = 0; i < subscriptions.Length; i++)
await (subscriptions[i] as Func<Task>).Invoke().ConfigureAwait(false);
}
}
public static async Task Raise<T>(this Func<T, Task> eventHandler, T arg)
{
var subscriptions = eventHandler?.GetInvocationList();
if (subscriptions != null)
{
for (int i = 0; i < subscriptions.Length; i++)
await (subscriptions[i] as Func<T, Task>).Invoke(arg).ConfigureAwait(false);
}
}
public static async Task Raise<T1, T2>(this Func<T1, T2, Task> eventHandler, T1 arg1, T2 arg2)
{
var subscriptions = eventHandler?.GetInvocationList();
if (subscriptions != null)
{
for (int i = 0; i < subscriptions.Length; i++)
await (subscriptions[i] as Func<T1, T2, Task>).Invoke(arg1, arg2).ConfigureAwait(false);
}
}
public static async Task Raise<T1, T2, T3>(this Func<T1, T2, Task> eventHandler, T1 arg1, T2 arg2, T3 arg3)
{
var subscriptions = eventHandler?.GetInvocationList();
if (subscriptions != null)
{
for (int i = 0; i < subscriptions.Length; i++)
await (subscriptions[i] as Func<T1, T2, T3, Task>).Invoke(arg1, arg2, arg3).ConfigureAwait(false);
}
}
}
}

View File

@@ -1,5 +1,6 @@
using Discord.API;
using Discord.Net.Rest;
using Discord.Net.Queue;
using Discord.WebSocket.Data;
using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;
@@ -9,15 +10,20 @@ namespace Discord
//TODO: Add docstrings
public interface IDiscordClient
{
TokenType AuthTokenType { get; }
LoginState LoginState { get; }
ConnectionState ConnectionState { get; }
DiscordApiClient ApiClient { get; }
IRestClient RestClient { get; }
IRequestQueue RequestQueue { get; }
IDataStore DataStore { get; }
Task Login(string email, string password);
Task Login(TokenType tokenType, string token, bool validateToken = true);
Task Logout();
Task Connect();
Task Disconnect();
Task<IChannel> GetChannel(ulong id);
Task<IEnumerable<IDMChannel>> GetDMChannels();

View File

@@ -1,4 +1,5 @@
using System;
using System.Threading.Tasks;
namespace Discord.Logging
{
@@ -6,28 +7,28 @@ namespace Discord.Logging
{
LogSeverity Level { get; }
void Log(LogSeverity severity, string message, Exception exception = null);
void Log(LogSeverity severity, FormattableString message, Exception exception = null);
void Log(LogSeverity severity, Exception exception);
Task Log(LogSeverity severity, string message, Exception exception = null);
Task Log(LogSeverity severity, FormattableString message, Exception exception = null);
Task Log(LogSeverity severity, Exception exception);
void Error(string message, Exception exception = null);
void Error(FormattableString message, Exception exception = null);
void Error(Exception exception);
Task Error(string message, Exception exception = null);
Task Error(FormattableString message, Exception exception = null);
Task Error(Exception exception);
void Warning(string message, Exception exception = null);
void Warning(FormattableString message, Exception exception = null);
void Warning(Exception exception);
Task Warning(string message, Exception exception = null);
Task Warning(FormattableString message, Exception exception = null);
Task Warning(Exception exception);
void Info(string message, Exception exception = null);
void Info(FormattableString message, Exception exception = null);
void Info(Exception exception);
Task Info(string message, Exception exception = null);
Task Info(FormattableString message, Exception exception = null);
Task Info(Exception exception);
void Verbose(string message, Exception exception = null);
void Verbose(FormattableString message, Exception exception = null);
void Verbose(Exception exception);
Task Verbose(string message, Exception exception = null);
Task Verbose(FormattableString message, Exception exception = null);
Task Verbose(Exception exception);
void Debug(string message, Exception exception = null);
void Debug(FormattableString message, Exception exception = null);
void Debug(Exception exception);
Task Debug(string message, Exception exception = null);
Task Debug(FormattableString message, Exception exception = null);
Task Debug(Exception exception);
}
}

View File

@@ -1,4 +1,5 @@
using System;
using System.Threading.Tasks;
namespace Discord.Logging
{
@@ -6,107 +7,107 @@ namespace Discord.Logging
{
public LogSeverity Level { get; }
public event EventHandler<LogMessageEventArgs> Message = delegate { };
public event Func<LogMessageEventArgs, Task> Message;
internal LogManager(LogSeverity minSeverity)
{
Level = minSeverity;
}
public void Log(LogSeverity severity, string source, string message, Exception ex = null)
public async Task Log(LogSeverity severity, string source, string message, Exception ex = null)
{
if (severity <= Level)
Message(this, new LogMessageEventArgs(severity, source, message, ex));
await Message.Raise(new LogMessageEventArgs(severity, source, message, ex)).ConfigureAwait(false);
}
public void Log(LogSeverity severity, string source, FormattableString message, Exception ex = null)
public async Task Log(LogSeverity severity, string source, FormattableString message, Exception ex = null)
{
if (severity <= Level)
Message(this, new LogMessageEventArgs(severity, source, message.ToString(), ex));
await Message.Raise(new LogMessageEventArgs(severity, source, message.ToString(), ex)).ConfigureAwait(false);
}
public void Log(LogSeverity severity, string source, Exception ex)
public async Task Log(LogSeverity severity, string source, Exception ex)
{
if (severity <= Level)
Message(this, new LogMessageEventArgs(severity, source, null, ex));
await Message.Raise(new LogMessageEventArgs(severity, source, null, ex)).ConfigureAwait(false);
}
void ILogger.Log(LogSeverity severity, string message, Exception ex)
async Task ILogger.Log(LogSeverity severity, string message, Exception ex)
{
if (severity <= Level)
Message(this, new LogMessageEventArgs(severity, "Discord", message, ex));
await Message.Raise(new LogMessageEventArgs(severity, "Discord", message, ex)).ConfigureAwait(false);
}
void ILogger.Log(LogSeverity severity, FormattableString message, Exception ex)
async Task ILogger.Log(LogSeverity severity, FormattableString message, Exception ex)
{
if (severity <= Level)
Message(this, new LogMessageEventArgs(severity, "Discord", message.ToString(), ex));
await Message.Raise(new LogMessageEventArgs(severity, "Discord", message.ToString(), ex)).ConfigureAwait(false);
}
void ILogger.Log(LogSeverity severity, Exception ex)
async Task ILogger.Log(LogSeverity severity, Exception ex)
{
if (severity <= Level)
Message(this, new LogMessageEventArgs(severity, "Discord", null, ex));
await Message.Raise(new LogMessageEventArgs(severity, "Discord", null, ex)).ConfigureAwait(false);
}
public void Error(string source, string message, Exception ex = null)
public Task Error(string source, string message, Exception ex = null)
=> Log(LogSeverity.Error, source, message, ex);
public void Error(string source, FormattableString message, Exception ex = null)
public Task Error(string source, FormattableString message, Exception ex = null)
=> Log(LogSeverity.Error, source, message, ex);
public void Error(string source, Exception ex)
public Task Error(string source, Exception ex)
=> Log(LogSeverity.Error, source, ex);
void ILogger.Error(string message, Exception ex)
Task ILogger.Error(string message, Exception ex)
=> Log(LogSeverity.Error, "Discord", message, ex);
void ILogger.Error(FormattableString message, Exception ex)
Task ILogger.Error(FormattableString message, Exception ex)
=> Log(LogSeverity.Error, "Discord", message, ex);
void ILogger.Error(Exception ex)
Task ILogger.Error(Exception ex)
=> Log(LogSeverity.Error, "Discord", ex);
public void Warning(string source, string message, Exception ex = null)
public Task Warning(string source, string message, Exception ex = null)
=> Log(LogSeverity.Warning, source, message, ex);
public void Warning(string source, FormattableString message, Exception ex = null)
public Task Warning(string source, FormattableString message, Exception ex = null)
=> Log(LogSeverity.Warning, source, message, ex);
public void Warning(string source, Exception ex)
public Task Warning(string source, Exception ex)
=> Log(LogSeverity.Warning, source, ex);
void ILogger.Warning(string message, Exception ex)
Task ILogger.Warning(string message, Exception ex)
=> Log(LogSeverity.Warning, "Discord", message, ex);
void ILogger.Warning(FormattableString message, Exception ex)
Task ILogger.Warning(FormattableString message, Exception ex)
=> Log(LogSeverity.Warning, "Discord", message, ex);
void ILogger.Warning(Exception ex)
Task ILogger.Warning(Exception ex)
=> Log(LogSeverity.Warning, "Discord", ex);
public void Info(string source, string message, Exception ex = null)
public Task Info(string source, string message, Exception ex = null)
=> Log(LogSeverity.Info, source, message, ex);
public void Info(string source, FormattableString message, Exception ex = null)
public Task Info(string source, FormattableString message, Exception ex = null)
=> Log(LogSeverity.Info, source, message, ex);
public void Info(string source, Exception ex)
public Task Info(string source, Exception ex)
=> Log(LogSeverity.Info, source, ex);
void ILogger.Info(string message, Exception ex)
Task ILogger.Info(string message, Exception ex)
=> Log(LogSeverity.Info, "Discord", message, ex);
void ILogger.Info(FormattableString message, Exception ex)
Task ILogger.Info(FormattableString message, Exception ex)
=> Log(LogSeverity.Info, "Discord", message, ex);
void ILogger.Info(Exception ex)
Task ILogger.Info(Exception ex)
=> Log(LogSeverity.Info, "Discord", ex);
public void Verbose(string source, string message, Exception ex = null)
public Task Verbose(string source, string message, Exception ex = null)
=> Log(LogSeverity.Verbose, source, message, ex);
public void Verbose(string source, FormattableString message, Exception ex = null)
public Task Verbose(string source, FormattableString message, Exception ex = null)
=> Log(LogSeverity.Verbose, source, message, ex);
public void Verbose(string source, Exception ex)
public Task Verbose(string source, Exception ex)
=> Log(LogSeverity.Verbose, source, ex);
void ILogger.Verbose(string message, Exception ex)
Task ILogger.Verbose(string message, Exception ex)
=> Log(LogSeverity.Verbose, "Discord", message, ex);
void ILogger.Verbose(FormattableString message, Exception ex)
Task ILogger.Verbose(FormattableString message, Exception ex)
=> Log(LogSeverity.Verbose, "Discord", message, ex);
void ILogger.Verbose(Exception ex)
Task ILogger.Verbose(Exception ex)
=> Log(LogSeverity.Verbose, "Discord", ex);
public void Debug(string source, string message, Exception ex = null)
public Task Debug(string source, string message, Exception ex = null)
=> Log(LogSeverity.Debug, source, message, ex);
public void Debug(string source, FormattableString message, Exception ex = null)
public Task Debug(string source, FormattableString message, Exception ex = null)
=> Log(LogSeverity.Debug, source, message, ex);
public void Debug(string source, Exception ex)
public Task Debug(string source, Exception ex)
=> Log(LogSeverity.Debug, source, ex);
void ILogger.Debug(string message, Exception ex)
Task ILogger.Debug(string message, Exception ex)
=> Log(LogSeverity.Debug, "Discord", message, ex);
void ILogger.Debug(FormattableString message, Exception ex)
Task ILogger.Debug(FormattableString message, Exception ex)
=> Log(LogSeverity.Debug, "Discord", message, ex);
void ILogger.Debug(Exception ex)
Task ILogger.Debug(Exception ex)
=> Log(LogSeverity.Debug, "Discord", ex);
internal Logger CreateLogger(string name) => new Logger(this, name);

View File

@@ -0,0 +1,10 @@
namespace Discord
{
public enum LoginState : byte
{
LoggedOut,
LoggingIn,
LoggedIn,
LoggingOut
}
}

View File

@@ -1,4 +1,4 @@
namespace Discord.Net.Rest
namespace Discord.Net.Queue
{
internal enum BucketGroup
{

View File

@@ -0,0 +1,12 @@
namespace Discord.Net.Queue
{
public enum GlobalBucket
{
General,
Login,
DirectMessage,
SendEditMessage,
Gateway,
UpdateStatus
}
}

View File

@@ -1,4 +1,4 @@
namespace Discord.Net.Rest
namespace Discord.Net.Queue
{
public enum GuildBucket
{

View File

@@ -0,0 +1,13 @@
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Discord.Net.Queue
{
internal interface IQueuedRequest
{
TaskCompletionSource<Stream> Promise { get; }
CancellationToken CancelToken { get; }
Task<Stream> Send();
}
}

View File

@@ -1,6 +1,6 @@
using System.Threading.Tasks;
namespace Discord.Net.Rest
namespace Discord.Net.Queue
{
//TODO: Add docstrings
public interface IRequestQueue

View File

@@ -4,7 +4,7 @@ using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Discord.Net.Rest
namespace Discord.Net.Queue
{
public class RequestQueue : IRequestQueue
{
@@ -15,12 +15,8 @@ namespace Discord.Net.Rest
private CancellationToken? _parentToken;
private CancellationToken _cancelToken;
public IRestClient RestClient { get; }
public RequestQueue(IRestClient restClient)
public RequestQueue()
{
RestClient = restClient;
_lock = new SemaphoreSlim(1, 1);
_globalBuckets = new RequestQueueBucket[Enum.GetValues(typeof(GlobalBucket)).Length];
_guildBuckets = new Dictionary<ulong, RequestQueueBucket>[Enum.GetValues(typeof(GuildBucket)).Length];
@@ -38,12 +34,10 @@ namespace Discord.Net.Rest
finally { Unlock(); }
}
internal async Task<Stream> Send(RestRequest request, BucketGroup group, int bucketId, ulong guildId)
internal async Task<Stream> Send(IQueuedRequest request, BucketGroup group, int bucketId, ulong guildId)
{
RequestQueueBucket bucket;
request.CancelToken = _cancelToken;
await Lock().ConfigureAwait(false);
try
{
@@ -66,6 +60,9 @@ namespace Discord.Net.Rest
case GlobalBucket.General: return new RequestQueueBucket(this, bucket, int.MaxValue, 0); //Catch-all
case GlobalBucket.Login: return new RequestQueueBucket(this, bucket, 1, 1); //TODO: Is this actual logins or token validations too?
case GlobalBucket.DirectMessage: return new RequestQueueBucket(this, bucket, 5, 5);
case GlobalBucket.SendEditMessage: return new RequestQueueBucket(this, bucket, 50, 10);
case GlobalBucket.Gateway: return new RequestQueueBucket(this, bucket, 120, 60);
case GlobalBucket.UpdateStatus: return new RequestQueueBucket(this, bucket, 5, 1, GlobalBucket.Gateway);
default: throw new ArgumentException($"Unknown global bucket: {bucket}", nameof(bucket));
}
@@ -75,7 +72,7 @@ namespace Discord.Net.Rest
switch (bucket)
{
//Per Guild
case GuildBucket.SendEditMessage: return new RequestQueueBucket(this, bucket, guildId, 5, 5);
case GuildBucket.SendEditMessage: return new RequestQueueBucket(this, bucket, guildId, 5, 5, GlobalBucket.SendEditMessage);
case GuildBucket.DeleteMessage: return new RequestQueueBucket(this, bucket, guildId, 5, 1);
case GuildBucket.DeleteMessages: return new RequestQueueBucket(this, bucket, guildId, 1, 1);
case GuildBucket.ModifyMember: return new RequestQueueBucket(this, bucket, guildId, 10, 10); //TODO: Is this all users or just roles?

View File

@@ -5,15 +5,17 @@ using System.Net;
using System.Threading;
using System.Threading.Tasks;
namespace Discord.Net.Rest
namespace Discord.Net.Queue
{
//TODO: Implement bucket chaining
internal class RequestQueueBucket
{
private readonly RequestQueue _parent;
private readonly BucketGroup _bucketGroup;
private readonly GlobalBucket? _chainedBucket;
private readonly int _bucketId;
private readonly ulong _guildId;
private readonly ConcurrentQueue<RestRequest> _queue;
private readonly ConcurrentQueue<IQueuedRequest> _queue;
private readonly SemaphoreSlim _lock;
private Task _resetTask;
private bool _waitingToProcess;
@@ -23,31 +25,32 @@ namespace Discord.Net.Rest
public int WindowSeconds { get; }
public int WindowCount { get; private set; }
public RequestQueueBucket(RequestQueue parent, GlobalBucket bucket, int windowMaxCount, int windowSeconds)
: this(parent, windowMaxCount, windowSeconds)
public RequestQueueBucket(RequestQueue parent, GlobalBucket bucket, int windowMaxCount, int windowSeconds, GlobalBucket? chainedBucket = null)
: this(parent, windowMaxCount, windowSeconds, chainedBucket)
{
_bucketGroup = BucketGroup.Global;
_bucketId = (int)bucket;
_guildId = 0;
}
public RequestQueueBucket(RequestQueue parent, GuildBucket bucket, ulong guildId, int windowMaxCount, int windowSeconds)
: this(parent, windowMaxCount, windowSeconds)
public RequestQueueBucket(RequestQueue parent, GuildBucket bucket, ulong guildId, int windowMaxCount, int windowSeconds, GlobalBucket? chainedBucket = null)
: this(parent, windowMaxCount, windowSeconds, chainedBucket)
{
_bucketGroup = BucketGroup.Guild;
_bucketId = (int)bucket;
_guildId = guildId;
}
private RequestQueueBucket(RequestQueue parent, int windowMaxCount, int windowSeconds)
private RequestQueueBucket(RequestQueue parent, int windowMaxCount, int windowSeconds, GlobalBucket? chainedBucket = null)
{
_parent = parent;
WindowMaxCount = windowMaxCount;
WindowSeconds = windowSeconds;
_queue = new ConcurrentQueue<RestRequest>();
_chainedBucket = chainedBucket;
_queue = new ConcurrentQueue<IQueuedRequest>();
_lock = new SemaphoreSlim(1, 1);
_id = new System.Random().Next(0, int.MaxValue);
}
public void Queue(RestRequest request)
public void Queue(IQueuedRequest request)
{
_queue.Enqueue(request);
}
@@ -68,7 +71,7 @@ namespace Discord.Net.Rest
_waitingToProcess = false;
while (true)
{
RestRequest request;
IQueuedRequest request;
//If we're waiting to reset (due to a rate limit exception, or preemptive check), abort
if (WindowCount == WindowMaxCount) return;
@@ -81,11 +84,7 @@ namespace Discord.Net.Rest
request.Promise.SetException(new OperationCanceledException(request.CancelToken));
else
{
Stream stream;
if (request.IsMultipart)
stream = await _parent.RestClient.Send(request.Method, request.Endpoint, request.CancelToken, request.MultipartParams, request.HeaderOnly).ConfigureAwait(false);
else
stream = await _parent.RestClient.Send(request.Method, request.Endpoint, request.CancelToken, request.Json, request.HeaderOnly).ConfigureAwait(false);
Stream stream = await request.Send().ConfigureAwait(false);
request.Promise.SetResult(stream);
}
}
@@ -157,7 +156,7 @@ namespace Discord.Net.Rest
public void Clear()
{
//Assume this obj is under lock
RestRequest request;
IQueuedRequest request;
while (_queue.TryDequeue(out request)) { }
}

View File

@@ -0,0 +1,53 @@
using Discord.Net.Rest;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Discord.Net.Queue
{
internal class RestRequest : IQueuedRequest
{
public IRestClient Client { get; }
public string Method { get; }
public string Endpoint { get; }
public string Json { get; }
public bool HeaderOnly { get; }
public IReadOnlyDictionary<string, object> MultipartParams { get; }
public TaskCompletionSource<Stream> Promise { get; }
public CancellationToken CancelToken { get; internal set; }
public bool IsMultipart => MultipartParams != null;
public RestRequest(IRestClient client, string method, string endpoint, string json, bool headerOnly)
: this(client, method, endpoint, headerOnly)
{
Json = json;
}
public RestRequest(IRestClient client, string method, string endpoint, IReadOnlyDictionary<string, object> multipartParams, bool headerOnly)
: this(client, method, endpoint, headerOnly)
{
MultipartParams = multipartParams;
}
private RestRequest(IRestClient client, string method, string endpoint, bool headerOnly)
{
Client = client;
Method = method;
Endpoint = endpoint;
Json = null;
MultipartParams = null;
HeaderOnly = headerOnly;
Promise = new TaskCompletionSource<Stream>();
}
public async Task<Stream> Send()
{
if (IsMultipart)
return await Client.Send(Method, Endpoint, MultipartParams, HeaderOnly).ConfigureAwait(false);
else
return await Client.Send(Method, Endpoint, Json, HeaderOnly).ConfigureAwait(false);
}
}
}

View File

@@ -0,0 +1,34 @@
using Discord.Net.WebSockets;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Discord.Net.Queue
{
internal class WebSocketRequest : IQueuedRequest
{
public IWebSocketClient Client { get; }
public byte[] Data { get; }
public int Offset { get; }
public int Bytes { get; }
public bool IsText { get; }
public CancellationToken CancelToken { get; }
public TaskCompletionSource<Stream> Promise { get; }
public WebSocketRequest(byte[] data, bool isText, CancellationToken cancelToken) : this(data, 0, data.Length, isText, cancelToken) { }
public WebSocketRequest(byte[] data, int offset, int length, bool isText, CancellationToken cancelToken)
{
Data = data;
Offset = offset;
Bytes = length;
IsText = isText;
Promise = new TaskCompletionSource<Stream>();
}
public async Task<Stream> Send()
{
await Client.Send(Data, Offset, Bytes, IsText).ConfigureAwait(false);
return null;
}
}
}

View File

@@ -17,6 +17,8 @@ namespace Discord.Net.Rest
protected readonly HttpClient _client;
protected readonly string _baseUrl;
private CancellationTokenSource _cancelTokenSource;
private CancellationToken _cancelToken, _parentToken;
protected bool _isDisposed;
public DefaultRestClient(string baseUrl)
@@ -32,6 +34,7 @@ namespace Discord.Net.Rest
});
SetHeader("accept-encoding", "gzip, deflate");
_parentToken = CancellationToken.None;
}
protected virtual void Dispose(bool disposing)
{
@@ -53,19 +56,28 @@ namespace Discord.Net.Rest
if (value != null)
_client.DefaultRequestHeaders.Add(key, value);
}
public void SetCancelToken(CancellationToken cancelToken)
{
_parentToken = cancelToken;
_cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_parentToken, _cancelTokenSource.Token).Token;
}
public async Task<Stream> Send(string method, string endpoint, CancellationToken cancelToken, string json = null, bool headerOnly = false)
public async Task<Stream> Send(string method, string endpoint, bool headerOnly = false)
{
string uri = Path.Combine(_baseUrl, endpoint);
using (var restRequest = new HttpRequestMessage(GetMethod(method), uri))
return await SendInternal(restRequest, headerOnly).ConfigureAwait(false);
}
public async Task<Stream> Send(string method, string endpoint, string json, bool headerOnly = false)
{
string uri = Path.Combine(_baseUrl, endpoint);
using (var restRequest = new HttpRequestMessage(GetMethod(method), uri))
{
if (json != null)
restRequest.Content = new StringContent(json, Encoding.UTF8, "application/json");
return await SendInternal(restRequest, cancelToken, headerOnly).ConfigureAwait(false);
restRequest.Content = new StringContent(json, Encoding.UTF8, "application/json");
return await SendInternal(restRequest, headerOnly).ConfigureAwait(false);
}
}
public async Task<Stream> Send(string method, string endpoint, CancellationToken cancelToken, IReadOnlyDictionary<string, object> multipartParams, bool headerOnly = false)
public async Task<Stream> Send(string method, string endpoint, IReadOnlyDictionary<string, object> multipartParams, bool headerOnly = false)
{
string uri = Path.Combine(_baseUrl, endpoint);
using (var restRequest = new HttpRequestMessage(GetMethod(method), uri))
@@ -112,14 +124,15 @@ namespace Discord.Net.Rest
}
}
restRequest.Content = content;
return await SendInternal(restRequest, cancelToken, headerOnly).ConfigureAwait(false);
return await SendInternal(restRequest, headerOnly).ConfigureAwait(false);
}
}
private async Task<Stream> SendInternal(HttpRequestMessage request, CancellationToken cancelToken, bool headerOnly)
private async Task<Stream> SendInternal(HttpRequestMessage request, bool headerOnly)
{
while (true)
{
var cancelToken = _cancelToken; //It's okay if another thread changes this, causes a retry to abort
HttpResponseMessage response = await _client.SendAsync(request, cancelToken).ConfigureAwait(false);
int statusCode = (int)response.StatusCode;

View File

@@ -9,8 +9,10 @@ namespace Discord.Net.Rest
public interface IRestClient
{
void SetHeader(string key, string value);
void SetCancelToken(CancellationToken cancelToken);
Task<Stream> Send(string method, string endpoint, CancellationToken cancelToken, string json = null, bool headerOnly = false);
Task<Stream> Send(string method, string endpoint, CancellationToken cancelToken, IReadOnlyDictionary<string, object> multipartParams, bool headerOnly = false);
Task<Stream> Send(string method, string endpoint, bool headerOnly = false);
Task<Stream> Send(string method, string endpoint, string json, bool headerOnly = false);
Task<Stream> Send(string method, string endpoint, IReadOnlyDictionary<string, object> multipartParams, bool headerOnly = false);
}
}

View File

@@ -1,9 +0,0 @@
namespace Discord.Net.Rest
{
public enum GlobalBucket
{
General,
Login,
DirectMessage
}
}

View File

@@ -1,42 +0,0 @@
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Discord.Net.Rest
{
internal class RestRequest
{
public string Method { get; }
public string Endpoint { get; }
public string Json { get; }
public bool HeaderOnly { get; }
public CancellationToken CancelToken { get; internal set; }
public IReadOnlyDictionary<string, object> MultipartParams { get; }
public TaskCompletionSource<Stream> Promise { get; }
public bool IsMultipart => MultipartParams != null;
public RestRequest(string method, string endpoint, string json, bool headerOnly)
: this(method, endpoint, headerOnly)
{
Json = json;
}
public RestRequest(string method, string endpoint, IReadOnlyDictionary<string, object> multipartParams, bool headerOnly)
: this(method, endpoint, headerOnly)
{
MultipartParams = multipartParams;
}
private RestRequest(string method, string endpoint, bool headerOnly)
{
Method = method;
Endpoint = endpoint;
Json = null;
MultipartParams = null;
HeaderOnly = headerOnly;
Promise = new TaskCompletionSource<Stream>();
}
}
}

View File

@@ -1,5 +1,4 @@
using System;
using System.Collections.Concurrent;
using System.ComponentModel;
using System.IO;
using System.Net.WebSockets;
@@ -13,26 +12,25 @@ namespace Discord.Net.WebSockets
{
public const int ReceiveChunkSize = 12 * 1024; //12KB
public const int SendChunkSize = 4 * 1024; //4KB
protected const int HR_TIMEOUT = -2147012894;
private const int HR_TIMEOUT = -2147012894;
public event EventHandler<BinaryMessageEventArgs> BinaryMessage = delegate { };
public event EventHandler<TextMessageEventArgs> TextMessage = delegate { };
protected readonly ConcurrentQueue<string> _sendQueue;
protected readonly ClientWebSocket _client;
protected Task _receiveTask, _sendTask;
protected CancellationTokenSource _cancelToken;
protected bool _isDisposed;
public event Func<BinaryMessageEventArgs, Task> BinaryMessage;
public event Func<TextMessageEventArgs, Task> TextMessage;
private readonly ClientWebSocket _client;
private Task _task;
private CancellationTokenSource _cancelTokenSource;
private CancellationToken _cancelToken, _parentToken;
private bool _isDisposed;
public DefaultWebSocketClient()
{
_sendQueue = new ConcurrentQueue<string>();
_client = new ClientWebSocket();
_client.Options.Proxy = null;
_client.Options.KeepAliveInterval = TimeSpan.Zero;
_parentToken = CancellationToken.None;
}
protected virtual void Dispose(bool disposing)
private void Dispose(bool disposing)
{
if (!_isDisposed)
{
@@ -46,135 +44,106 @@ namespace Discord.Net.WebSockets
Dispose(true);
}
public async Task Connect(string host, CancellationToken cancelToken)
public async Task Connect(string host)
{
await Disconnect().ConfigureAwait(false);
_cancelToken = new CancellationTokenSource();
var combinedToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelToken.Token, cancelToken).Token;
_cancelTokenSource = new CancellationTokenSource();
_cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_parentToken, _cancelTokenSource.Token).Token;
await _client.ConnectAsync(new Uri(host), combinedToken).ConfigureAwait(false);
_receiveTask = ReceiveAsync(combinedToken);
_sendTask = SendAsync(combinedToken);
await _client.ConnectAsync(new Uri(host), _cancelToken).ConfigureAwait(false);
_task = Run(_cancelToken);
}
public async Task Disconnect()
{
_cancelToken.Cancel();
string ignored;
while (_sendQueue.TryDequeue(out ignored)) { }
_cancelTokenSource.Cancel();
_client.Abort();
var receiveTask = _receiveTask ?? Task.CompletedTask;
var sendTask = _sendTask ?? Task.CompletedTask;
await Task.WhenAll(receiveTask, sendTask).ConfigureAwait(false);
await (_task ?? Task.CompletedTask).ConfigureAwait(false);
}
public void SetHeader(string key, string value)
{
_client.Options.SetRequestHeader(key, value);
}
public void QueueMessage(string message)
public void SetCancelToken(CancellationToken cancelToken)
{
_sendQueue.Enqueue(message);
_parentToken = cancelToken;
_cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_parentToken, _cancelTokenSource.Token).Token;
}
//TODO: Check this code
private Task ReceiveAsync(CancellationToken cancelToken)
public async Task Send(byte[] data, int offset, int count, bool isText)
{
return Task.Run(async () =>
int frameCount = (int)Math.Ceiling((double)count / SendChunkSize);
for (int i = 0; i < frameCount; i++, offset += SendChunkSize)
{
var buffer = new ArraySegment<byte>(new byte[ReceiveChunkSize]);
var stream = new MemoryStream();
bool isLast = i == (frameCount - 1);
int frameSize;
if (isLast)
frameSize = count - (i * SendChunkSize);
else
frameSize = SendChunkSize;
try
{
while (!cancelToken.IsCancellationRequested)
{
WebSocketReceiveResult result = null;
do
{
if (cancelToken.IsCancellationRequested) return;
try
{
result = await _client.ReceiveAsync(buffer, cancelToken).ConfigureAwait(false);
}
catch (Win32Exception ex) when (ex.HResult == HR_TIMEOUT)
{
throw new Exception($"Connection timed out.");
}
if (result.MessageType == WebSocketMessageType.Close)
throw new WebSocketException((int)result.CloseStatus.Value, result.CloseStatusDescription);
else
stream.Write(buffer.Array, 0, result.Count);
}
while (result == null || !result.EndOfMessage);
var array = stream.ToArray();
if (result.MessageType == WebSocketMessageType.Binary)
BinaryMessage(this, new BinaryMessageEventArgs(array));
else if (result.MessageType == WebSocketMessageType.Text)
{
string text = Encoding.UTF8.GetString(array, 0, array.Length);
TextMessage(this, new TextMessageEventArgs(text));
}
stream.Position = 0;
stream.SetLength(0);
}
await _client.SendAsync(new ArraySegment<byte>(data, offset, count), isText ? WebSocketMessageType.Text : WebSocketMessageType.Binary, isLast, _cancelToken).ConfigureAwait(false);
}
catch (OperationCanceledException) { }
});
catch (Win32Exception ex) when (ex.HResult == HR_TIMEOUT)
{
return;
}
}
}
//TODO: Check this code
private Task SendAsync(CancellationToken cancelToken)
private async Task Run(CancellationToken cancelToken)
{
return Task.Run(async () =>
var buffer = new ArraySegment<byte>(new byte[ReceiveChunkSize]);
var stream = new MemoryStream();
try
{
byte[] bytes = new byte[SendChunkSize];
try
while (!cancelToken.IsCancellationRequested)
{
while (!cancelToken.IsCancellationRequested)
WebSocketReceiveResult result = null;
do
{
string json;
while (_sendQueue.TryDequeue(out json))
if (cancelToken.IsCancellationRequested) return;
try
{
int byteCount = Encoding.UTF8.GetBytes(json, 0, json.Length, bytes, 0);
int frameCount = (int)Math.Ceiling((double)byteCount / SendChunkSize);
int offset = 0;
for (int i = 0; i < frameCount; i++, offset += SendChunkSize)
{
bool isLast = i == (frameCount - 1);
int count;
if (isLast)
count = byteCount - (i * SendChunkSize);
else
count = SendChunkSize;
try
{
await _client.SendAsync(new ArraySegment<byte>(bytes, offset, count), WebSocketMessageType.Text, isLast, cancelToken).ConfigureAwait(false);
}
catch (Win32Exception ex) when (ex.HResult == HR_TIMEOUT)
{
return;
}
}
result = await _client.ReceiveAsync(buffer, cancelToken).ConfigureAwait(false);
}
await Task.Delay(DiscordConfig.WebSocketQueueInterval, cancelToken).ConfigureAwait(false);
catch (Win32Exception ex) when (ex.HResult == HR_TIMEOUT)
{
throw new Exception("Connection timed out.");
}
if (result.MessageType == WebSocketMessageType.Close)
throw new WebSocketException((int)result.CloseStatus.Value, result.CloseStatusDescription);
else
stream.Write(buffer.Array, 0, result.Count);
}
while (result == null || !result.EndOfMessage);
var array = stream.ToArray();
if (result.MessageType == WebSocketMessageType.Binary)
await BinaryMessage.Raise(new BinaryMessageEventArgs(array)).ConfigureAwait(false);
else if (result.MessageType == WebSocketMessageType.Text)
{
string text = Encoding.UTF8.GetString(array, 0, array.Length);
await TextMessage.Raise(new TextMessageEventArgs(text)).ConfigureAwait(false);
}
stream.Position = 0;
stream.SetLength(0);
}
catch (OperationCanceledException) { }
});
}
catch (OperationCanceledException) { }
}
}
}

View File

@@ -7,13 +7,15 @@ namespace Discord.Net.WebSockets
//TODO: Add ETF
public interface IWebSocketClient
{
event EventHandler<BinaryMessageEventArgs> BinaryMessage;
event EventHandler<TextMessageEventArgs> TextMessage;
event Func<BinaryMessageEventArgs, Task> BinaryMessage;
event Func<TextMessageEventArgs, Task> TextMessage;
void SetHeader(string key, string value);
void SetCancelToken(CancellationToken cancelToken);
Task Connect(string host, CancellationToken cancelToken);
Task Connect(string host);
Task Disconnect();
void QueueMessage(string message);
Task Send(byte[] data, int offset, int length, bool isText);
}
}

View File

@@ -1,4 +1,4 @@
namespace Discord.Net.WebSockets
{
public delegate IWebSocketClient WebSocketProvider(string baseUrl);
public delegate IWebSocketClient WebSocketProvider();
}

View File

@@ -1,5 +1,7 @@
using Discord.API.Rest;
using Discord.Logging;
using Discord.Net;
using Discord.Net.Queue;
using Discord.Net.Rest;
using System;
using System.Collections.Generic;
@@ -15,39 +17,37 @@ namespace Discord.Rest
//TODO: Log Logins/Logouts
public sealed class DiscordClient : IDiscordClient, IDisposable
{
public event EventHandler<LogMessageEventArgs> Log;
public event EventHandler LoggedIn, LoggedOut;
public event Func<LogMessageEventArgs, Task> Log;
public event Func<Task> LoggedIn, LoggedOut;
private readonly Logger _discordLogger, _restLogger;
private readonly SemaphoreSlim _connectionLock;
private readonly RestClientProvider _restClientProvider;
private readonly LogManager _log;
private CancellationTokenSource _cancelTokenSource;
private readonly RequestQueue _requestQueue;
private bool _isDisposed;
private SelfUser _currentUser;
public bool IsLoggedIn { get; private set; }
public LoginState LoginState { get; private set; }
public API.DiscordApiClient ApiClient { get; private set; }
public TokenType AuthTokenType => ApiClient.AuthTokenType;
public IRestClient RestClient => ApiClient.RestClient;
public IRequestQueue RequestQueue => ApiClient.RequestQueue;
public IRequestQueue RequestQueue => _requestQueue;
public DiscordClient(DiscordConfig config = null)
{
if (config == null)
config = new DiscordConfig();
_restClientProvider = config.RestClientProvider;
_log = new LogManager(config.LogLevel);
_log.Message += (s, e) => Log.Raise(this, e);
_log.Message += async e => await Log.Raise(e).ConfigureAwait(false);
_discordLogger = _log.CreateLogger("Discord");
_restLogger = _log.CreateLogger("Rest");
_connectionLock = new SemaphoreSlim(1, 1);
ApiClient = new API.DiscordApiClient(_restClientProvider);
ApiClient.SentRequest += (s, e) => _log.Verbose("Rest", $"{e.Method} {e.Endpoint}: {e.Milliseconds} ms");
_requestQueue = new RequestQueue();
ApiClient = new API.DiscordApiClient(config.RestClientProvider, requestQueue: _requestQueue);
ApiClient.SentRequest += async e => await _log.Verbose("Rest", $"{e.Method} {e.Endpoint}: {e.Milliseconds} ms").ConfigureAwait(false);
}
public async Task Login(string email, string password)
@@ -55,7 +55,7 @@ namespace Discord.Rest
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await LoginInternal(email, password).ConfigureAwait(false);
await LoginInternal(TokenType.User, null, email, password, true, false).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
@@ -64,55 +64,51 @@ namespace Discord.Rest
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
await LoginInternal(tokenType, token, validateToken).ConfigureAwait(false);
await LoginInternal(tokenType, token, null, null, false, validateToken).ConfigureAwait(false);
}
finally { _connectionLock.Release(); }
}
private async Task LoginInternal(string email, string password)
private async Task LoginInternal(TokenType tokenType, string token, string email, string password, bool useEmail, bool validateToken)
{
if (IsLoggedIn)
if (LoginState != LoginState.LoggedOut)
await LogoutInternal().ConfigureAwait(false);
LoginState = LoginState.LoggingIn;
try
{
_cancelTokenSource = new CancellationTokenSource();
var args = new LoginParams { Email = email, Password = password };
await ApiClient.Login(args, _cancelTokenSource.Token).ConfigureAwait(false);
await CompleteLogin(false).ConfigureAwait(false);
}
catch { await LogoutInternal().ConfigureAwait(false); throw; }
}
private async Task LoginInternal(TokenType tokenType, string token, bool validateToken)
{
if (IsLoggedIn)
await LogoutInternal().ConfigureAwait(false);
try
{
_cancelTokenSource = new CancellationTokenSource();
await ApiClient.Login(tokenType, token, _cancelTokenSource.Token).ConfigureAwait(false);
await CompleteLogin(validateToken).ConfigureAwait(false);
}
catch { await LogoutInternal().ConfigureAwait(false); throw; }
}
private async Task CompleteLogin(bool validateToken)
{
if (validateToken)
{
try
if (useEmail)
{
await ApiClient.ValidateToken().ConfigureAwait(false);
var args = new LoginParams { Email = email, Password = password };
await ApiClient.Login(args).ConfigureAwait(false);
}
catch { await ApiClient.Logout().ConfigureAwait(false); }
else
await ApiClient.Login(tokenType, token).ConfigureAwait(false);
if (validateToken)
{
try
{
await ApiClient.ValidateToken().ConfigureAwait(false);
}
catch (HttpException ex)
{
throw new ArgumentException("Token validation failed", nameof(token), ex);
}
}
LoginState = LoginState.LoggedIn;
}
IsLoggedIn = true;
LoggedIn.Raise(this);
catch (Exception)
{
await LogoutInternal().ConfigureAwait(false);
throw;
}
await LoggedIn.Raise().ConfigureAwait(false);
}
public async Task Logout()
{
_cancelTokenSource?.Cancel();
await _connectionLock.WaitAsync().ConfigureAwait(false);
try
{
@@ -122,22 +118,16 @@ namespace Discord.Rest
}
private async Task LogoutInternal()
{
bool wasLoggedIn = IsLoggedIn;
if (_cancelTokenSource != null)
{
try { _cancelTokenSource.Cancel(false); }
catch { }
}
if (LoginState == LoginState.LoggedOut) return;
LoginState = LoginState.LoggingOut;
await ApiClient.Logout().ConfigureAwait(false);
_currentUser = null;
if (wasLoggedIn)
{
IsLoggedIn = false;
LoggedOut.Raise(this);
}
LoginState = LoginState.LoggedOut;
await LoggedOut.Raise().ConfigureAwait(false);
}
public async Task<IEnumerable<Connection>> GetConnections()
@@ -251,16 +241,15 @@ namespace Discord.Rest
void Dispose(bool disposing)
{
if (!_isDisposed)
{
if (disposing)
_cancelTokenSource.Dispose();
_isDisposed = true;
}
}
public void Dispose() => Dispose(true);
API.DiscordApiClient IDiscordClient.ApiClient => ApiClient;
ConnectionState IDiscordClient.ConnectionState => ConnectionState.Disconnected;
WebSocket.Data.IDataStore IDiscordClient.DataStore => null;
Task IDiscordClient.Connect() { return Task.FromException(new NotSupportedException("This client does not support websocket connections.")); }
Task IDiscordClient.Disconnect() { return Task.FromException(new NotSupportedException("This client does not support websocket connections.")); }
async Task<IChannel> IDiscordClient.GetChannel(ulong id)
=> await GetChannel(id).ConfigureAwait(false);
async Task<IEnumerable<IDMChannel>> IDiscordClient.GetDMChannels()

View File

@@ -0,0 +1,4 @@
namespace Discord.WebSocket.Data
{
public delegate IDataStore DataStoreProvider(int shardId, int totalShards, int guildCount, int dmCount);
}

View File

@@ -0,0 +1,110 @@
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace Discord.WebSocket.Data
{
public class DefaultDataStore : IDataStore
{
private const double AverageChannelsPerGuild = 10.22; //Source: Googie2149
private const double AverageRolesPerGuild = 5; //Source: Googie2149 //TODO: Get a real value
private const double AverageUsersPerGuild = 47.78; //Source: Googie2149
private const double CollectionMultiplier = 1.05; //Add buffer to handle growth
private const double CollectionConcurrencyLevel = 1; //WebSocket updater/event handler. //TODO: Needs profiling, increase to 2?
private ConcurrentDictionary<ulong, Channel> _channels;
private ConcurrentDictionary<ulong, Guild> _guilds;
private ConcurrentDictionary<ulong, Role> _roles;
private ConcurrentDictionary<ulong, User> _users;
public IEnumerable<Channel> Channels => _channels.Select(x => x.Value);
public IEnumerable<Guild> Guilds => _guilds.Select(x => x.Value);
public IEnumerable<Role> Roles => _roles.Select(x => x.Value);
public IEnumerable<User> Users => _users.Select(x => x.Value);
public DefaultDataStore(int guildCount, int dmChannelCount)
{
_channels = new ConcurrentDictionary<ulong, Channel>(1, (int)((guildCount * AverageChannelsPerGuild + dmChannelCount) * CollectionMultiplier));
_guilds = new ConcurrentDictionary<ulong, Guild>(1, (int)(guildCount * CollectionMultiplier));
_roles = new ConcurrentDictionary<ulong, Role>(1, (int)(guildCount * AverageRolesPerGuild * CollectionMultiplier));
_users = new ConcurrentDictionary<ulong, User>(1, (int)(guildCount * AverageUsersPerGuild * CollectionMultiplier));
}
public Channel GetChannel(ulong id)
{
Channel channel;
if (_channels.TryGetValue(id, out channel))
return channel;
return null;
}
public void AddChannel(Channel channel)
{
_channels[channel.Id] = channel;
}
public Channel RemoveChannel(ulong id)
{
Channel channel;
if (_channels.TryRemove(id, out channel))
return channel;
return null;
}
public Guild GetGuild(ulong id)
{
Guild guild;
if (_guilds.TryGetValue(id, out guild))
return guild;
return null;
}
public void AddGuild(Guild guild)
{
_guilds[guild.Id] = guild;
}
public Guild RemoveGuild(ulong id)
{
Guild guild;
if (_guilds.TryRemove(id, out guild))
return guild;
return null;
}
public Role GetRole(ulong id)
{
Role role;
if (_roles.TryGetValue(id, out role))
return role;
return null;
}
public void AddRole(Role role)
{
_roles[role.Id] = role;
}
public Role RemoveRole(ulong id)
{
Role role;
if (_roles.TryRemove(id, out role))
return role;
return null;
}
public User GetUser(ulong id)
{
User user;
if (_users.TryGetValue(id, out user))
return user;
return null;
}
public void AddUser(User user)
{
_users[user.Id] = user;
}
public User RemoveUser(ulong id)
{
User user;
if (_users.TryRemove(id, out user))
return user;
return null;
}
}
}

View File

@@ -0,0 +1,28 @@
using System.Collections.Generic;
namespace Discord.WebSocket.Data
{
public interface IDataStore
{
IEnumerable<Channel> Channels { get; }
IEnumerable<Guild> Guilds { get; }
IEnumerable<Role> Roles { get; }
IEnumerable<User> Users { get; }
Channel GetChannel(ulong id);
void AddChannel(Channel channel);
Channel RemoveChannel(ulong id);
Guild GetGuild(ulong id);
void AddGuild(Guild guild);
Guild RemoveGuild(ulong id);
Role GetRole(ulong id);
void AddRole(Role role);
Role RemoveRole(ulong id);
User GetUser(ulong id);
void AddUser(User user);
User RemoveUser(ulong id);
}
}

View File

@@ -0,0 +1,7 @@
namespace Discord.WebSocket.Data
{
//TODO: Implement
/*public class SharedDataStore
{
}*/
}

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More