Started WebSocket cleanup

This commit is contained in:
RogueException
2015-12-08 09:35:18 -04:00
parent d2de658df0
commit 3fe7124f4d
6 changed files with 100 additions and 110 deletions

View File

@@ -74,7 +74,7 @@ namespace Discord.Audio
var client = _service.Client; var client = _service.Client;
string token = e.Payload.Value<string>("token"); string token = e.Payload.Value<string>("token");
_voiceSocket.Host = "wss://" + e.Payload.Value<string>("endpoint").Split(':')[0]; _voiceSocket.Host = "wss://" + e.Payload.Value<string>("endpoint").Split(':')[0];
await _voiceSocket.Login(client.CurrentUser.Id, _gatewaySocket.SessionId, token, client.CancelToken).ConfigureAwait(false); await _voiceSocket.Connect(client.CurrentUser.Id, _gatewaySocket.SessionId, token/*, client.CancelToken*/).ConfigureAwait(false);
} }
} }
break; break;

View File

@@ -58,20 +58,13 @@ namespace Discord.Net.WebSockets
_sendBuffer = new VoiceBuffer((int)Math.Ceiling(_audioConfig.BufferLength / (double)_encoder.FrameLength), _encoder.FrameSize); _sendBuffer = new VoiceBuffer((int)Math.Ceiling(_audioConfig.BufferLength / (double)_encoder.FrameLength), _encoder.FrameSize);
} }
public async Task Login(long userId, string sessionId, string token, CancellationToken cancelToken) public async Task Connect(long userId, string sessionId, string token)
{ {
if ((WebSocketState)_state == WebSocketState.Connected)
{
//Adjust the host and tell the system to reconnect
await DisconnectInternal(new Exception("Server transfer occurred."), isUnexpected: false).ConfigureAwait(false);
return;
}
_userId = userId; _userId = userId;
_sessionId = sessionId; _sessionId = sessionId;
_token = token; _token = token;
await Start().ConfigureAwait(false); await BeginConnect().ConfigureAwait(false);
} }
public async Task Reconnect() public async Task Reconnect()
{ {
@@ -83,9 +76,7 @@ namespace Discord.Net.WebSockets
{ {
try try
{ {
//This check is needed in case we start a reconnect before the initial login completes await Connect(_userId.Value, _sessionId, _token).ConfigureAwait(false);
if (_state != (int)WebSocketState.Disconnected)
await Start().ConfigureAwait(false);
break; break;
} }
catch (OperationCanceledException) { throw; } catch (OperationCanceledException) { throw; }
@@ -99,13 +90,15 @@ namespace Discord.Net.WebSockets
} }
catch (OperationCanceledException) { } catch (OperationCanceledException) { }
} }
public Task Disconnect()
{
return SignalDisconnect(wait: true);
}
protected override IEnumerable<Task> GetTasks() protected override async Task Run()
{ {
_udp = new UdpClient(new IPEndPoint(IPAddress.Any, 0)); _udp = new UdpClient(new IPEndPoint(IPAddress.Any, 0));
SendIdentify();
List<Task> tasks = new List<Task>(); List<Task> tasks = new List<Task>();
if ((_audioConfig.Mode & AudioMode.Outgoing) != 0) if ((_audioConfig.Mode & AudioMode.Outgoing) != 0)
{ {
@@ -113,34 +106,23 @@ namespace Discord.Net.WebSockets
_sendThread.IsBackground = true; _sendThread.IsBackground = true;
_sendThread.Start(); _sendThread.Start();
} }
//This thread is required to establish a connection even if we're outgoing only
if ((_audioConfig.Mode & AudioMode.Incoming) != 0) if ((_audioConfig.Mode & AudioMode.Incoming) != 0)
{ {
_receiveThread = new Thread(new ThreadStart(() => ReceiveVoiceAsync(_cancelToken))); _receiveThread = new Thread(new ThreadStart(() => ReceiveVoiceAsync(_cancelToken)));
_receiveThread.IsBackground = true; _receiveThread.IsBackground = true;
_receiveThread.Start(); _receiveThread.Start();
} }
else //Dont make an OS thread if we only want to capture one packet...
tasks.Add(Task.Run(() => ReceiveVoiceAsync(_cancelToken))); SendIdentify();
#if !DOTNET5_4 #if !DOTNET5_4
tasks.Add(WatcherAsync()); tasks.Add(WatcherAsync());
#endif #endif
if (tasks.Count > 0) await RunTasks(tasks.ToArray());
{
// We need to combine tasks into one because receiveThread is
// supposed to exit early if it's an outgoing-only client
// and we dont want the main thread to think we errored
var task = Task.WhenAll(tasks);
tasks.Clear();
tasks.Add(task);
}
tasks.AddRange(base.GetTasks());
return new Task[] { Task.WhenAll(tasks.ToArray()) }; await Cleanup();
} }
protected override Task Stop() protected override Task Cleanup()
{ {
if (_sendThread != null) if (_sendThread != null)
_sendThread.Join(); _sendThread.Join();
@@ -165,7 +147,7 @@ namespace Discord.Net.WebSockets
} }
_udp = null; _udp = null;
return base.Stop(); return base.Cleanup();
} }
private void ReceiveVoiceAsync(CancellationToken cancelToken) private void ReceiveVoiceAsync(CancellationToken cancelToken)
@@ -474,7 +456,7 @@ namespace Discord.Net.WebSockets
var payload = (msg.Payload as JToken).ToObject<JoinServerEvent>(_serializer); var payload = (msg.Payload as JToken).ToObject<JoinServerEvent>(_serializer);
_secretKey = payload.SecretKey; _secretKey = payload.SecretKey;
SendIsTalking(true); SendIsTalking(true);
EndConnect(); await EndConnect();
} }
break; break;
case VoiceOpCodes.Speaking: case VoiceOpCodes.Speaking:

View File

@@ -266,11 +266,9 @@ namespace Discord
if (_state == (int)DiscordClientState.Connecting) if (_state == (int)DiscordClientState.Connecting)
CompleteConnect(); CompleteConnect();
}; };
socket.Disconnected += async (s, e) => socket.Disconnected += (s, e) =>
{ {
RaiseDisconnected(e); RaiseDisconnected(e);
if (e.WasUnexpected)
await socket.Reconnect(_token).ConfigureAwait(false);
}; };
socket.ReceivedDispatch += async (s, e) => await OnReceivedEvent(e).ConfigureAwait(false); socket.ReceivedDispatch += async (s, e) => await OnReceivedEvent(e).ConfigureAwait(false);
@@ -329,7 +327,7 @@ namespace Discord
_webSocket.Host = gateway; _webSocket.Host = gateway;
_webSocket.ParentCancelToken = _cancelToken; _webSocket.ParentCancelToken = _cancelToken;
await _webSocket.Login(token).ConfigureAwait(false); await _webSocket.Connect(token).ConfigureAwait(false);
_runTask = RunTasks(); _runTask = RunTasks();
@@ -422,7 +420,7 @@ namespace Discord
var wasDisconnectUnexpected = _wasDisconnectUnexpected; var wasDisconnectUnexpected = _wasDisconnectUnexpected;
_wasDisconnectUnexpected = false; _wasDisconnectUnexpected = false;
await _webSocket.Disconnect().ConfigureAwait(false); await _webSocket.SignalDisconnect().ConfigureAwait(false);
_userId = null; _userId = null;
_gateway = null; _gateway = null;

View File

@@ -8,7 +8,11 @@ namespace Discord.Net.WebSockets
{ {
public partial class GatewayWebSocket : WebSocket public partial class GatewayWebSocket : WebSocket
{ {
private int _lastSeq; public int LastSequence => _lastSeq;
private int _lastSeq;
public string Token => _token;
private string _token;
public string SessionId => _sessionId; public string SessionId => _sessionId;
private string _sessionId; private string _sessionId;
@@ -16,25 +20,25 @@ namespace Discord.Net.WebSockets
public GatewayWebSocket(DiscordConfig config, Logger logger) public GatewayWebSocket(DiscordConfig config, Logger logger)
: base(config, logger) : base(config, logger)
{ {
Disconnected += async (s, e) =>
{
if (e.WasUnexpected)
await Reconnect().ConfigureAwait(false);
};
} }
public async Task Login(string token) public async Task Connect(string token)
{ {
_token = token;
await BeginConnect().ConfigureAwait(false); await BeginConnect().ConfigureAwait(false);
await Start().ConfigureAwait(false);
SendIdentify(token); SendIdentify(token);
} }
private async Task Redirect(string server) private async Task Redirect(string server)
{ {
await DisconnectInternal(isUnexpected: false).ConfigureAwait(false);
await BeginConnect().ConfigureAwait(false); await BeginConnect().ConfigureAwait(false);
await Start().ConfigureAwait(false);
SendResume(); SendResume();
} }
public async Task Reconnect(string token) private async Task Reconnect()
{ {
try try
{ {
@@ -44,7 +48,7 @@ namespace Discord.Net.WebSockets
{ {
try try
{ {
await Login(token).ConfigureAwait(false); await Connect(_token).ConfigureAwait(false);
break; break;
} }
catch (OperationCanceledException) { throw; } catch (OperationCanceledException) { throw; }
@@ -58,6 +62,15 @@ namespace Discord.Net.WebSockets
} }
catch (OperationCanceledException) { } catch (OperationCanceledException) { }
} }
public Task Disconnect()
{
return SignalDisconnect(wait: true);
}
protected override async Task Run()
{
await RunTasks();
}
protected override async Task ProcessMessage(string json) protected override async Task ProcessMessage(string json)
{ {
@@ -85,7 +98,7 @@ namespace Discord.Net.WebSockets
} }
RaiseReceivedDispatch(msg.Type, token); RaiseReceivedDispatch(msg.Type, token);
if (msg.Type == "READY" || msg.Type == "RESUMED") if (msg.Type == "READY" || msg.Type == "RESUMED")
EndConnect(); await EndConnect(); //Complete the connect
} }
break; break;
case GatewayOpCodes.Redirect: case GatewayOpCodes.Redirect:

View File

@@ -114,38 +114,50 @@ namespace Discord.Net.WebSockets
{ {
try try
{ {
await Disconnect().ConfigureAwait(false); await SignalDisconnect(wait: true).ConfigureAwait(false);
_state = (int)WebSocketState.Connecting;
if (ParentCancelToken == null) if (ParentCancelToken == null)
throw new InvalidOperationException("Parent cancel token was never set."); throw new InvalidOperationException("Parent cancel token was never set.");
_cancelTokenSource = new CancellationTokenSource(); _cancelTokenSource = new CancellationTokenSource();
_cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelTokenSource.Token, ParentCancelToken.Value).Token; _cancelToken = CancellationTokenSource.CreateLinkedTokenSource(_cancelTokenSource.Token, ParentCancelToken.Value).Token;
_state = (int)WebSocketState.Connecting; if (_state != (int)WebSocketState.Connecting)
throw new InvalidOperationException("Socket is in the wrong state.");
_lastHeartbeat = DateTime.UtcNow;
await _engine.Connect(Host, _cancelToken).ConfigureAwait(false);
_runTask = Run();
} }
catch (Exception ex) catch (Exception ex)
{ {
await DisconnectInternal(ex, isUnexpected: false).ConfigureAwait(false); await SignalDisconnect(ex, isUnexpected: false).ConfigureAwait(false);
throw; throw;
} }
} }
protected void EndConnect() protected async Task EndConnect()
{ {
_state = (int)WebSocketState.Connected; try
_connectedEvent.Set(); {
RaiseConnected(); _state = (int)WebSocketState.Connected;
}
public Task Disconnect() => DisconnectInternal(new Exception("Disconnect was requested by user."), isUnexpected: false); _connectedEvent.Set();
protected internal async Task DisconnectInternal(Exception ex = null, bool isUnexpected = true, bool skipAwait = false) RaiseConnected();
}
catch (Exception ex)
{
await SignalDisconnect(ex, isUnexpected: false).ConfigureAwait(false);
throw;
}
}
protected internal async Task SignalDisconnect(Exception ex = null, bool isUnexpected = false, bool wait = false)
{ {
int oldState;
bool hasWriterLock;
//If in either connecting or connected state, get a lock by being the first to switch to disconnecting //If in either connecting or connected state, get a lock by being the first to switch to disconnecting
oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connecting); int oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connecting);
if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected if (oldState == (int)WebSocketState.Disconnected) return; //Already disconnected
hasWriterLock = oldState == (int)WebSocketState.Connecting; //Caused state change bool hasWriterLock = oldState == (int)WebSocketState.Connecting; //Caused state change
if (!hasWriterLock) if (!hasWriterLock)
{ {
oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connected); oldState = Interlocked.CompareExchange(ref _state, (int)WebSocketState.Disconnecting, (int)WebSocketState.Connected);
@@ -155,70 +167,55 @@ namespace Discord.Net.WebSockets
if (hasWriterLock) if (hasWriterLock)
{ {
_wasDisconnectUnexpected = isUnexpected; CaptureError(ex ?? new Exception("Disconnect was requested."), isUnexpected);
_disconnectState = (WebSocketState)oldState;
_disconnectReason = ex != null ? ExceptionDispatchInfo.Capture(ex) : null;
_cancelTokenSource.Cancel(); _cancelTokenSource.Cancel();
if (_disconnectState == WebSocketState.Connecting) //_runTask was never made if (_disconnectState == WebSocketState.Connecting) //_runTask was never made
await Stop().ConfigureAwait(false); await Cleanup().ConfigureAwait(false);
} }
if (!skipAwait) if (!wait)
{ {
Task task = _runTask; Task task = _runTask;
if (_runTask != null) if (_runTask != null)
await task.ConfigureAwait(false); await task.ConfigureAwait(false);
} }
} }
private void CaptureError(Exception ex, bool isUnexpected)
protected virtual async Task Start()
{ {
try _disconnectReason = ExceptionDispatchInfo.Capture(ex);
{ _wasDisconnectUnexpected = isUnexpected;
if (_state != (int)WebSocketState.Connecting)
throw new InvalidOperationException("Socket is in the wrong state.");
_lastHeartbeat = DateTime.UtcNow;
await _engine.Connect(Host, _cancelToken).ConfigureAwait(false);
_runTask = RunTasks();
}
catch (Exception ex)
{
await DisconnectInternal(ex, isUnexpected: false).ConfigureAwait(false);
throw;
}
} }
protected virtual async Task RunTasks() protected abstract Task Run();
protected async Task RunTasks(params Task[] tasks)
{ {
Task[] tasks = GetTasks().ToArray(); //Get all async tasks
tasks = tasks
.Concat(_engine.GetTasks(_cancelToken))
.Concat(new Task[] { HeartbeatAsync(_cancelToken) })
.ToArray();
//Create group tasks
Task firstTask = Task.WhenAny(tasks); Task firstTask = Task.WhenAny(tasks);
Task allTasks = Task.WhenAll(tasks); Task allTasks = Task.WhenAll(tasks);
//Wait until the first task ends/errors and capture the error //Wait until the first task ends/errors and capture the error
try { await firstTask.ConfigureAwait(false); } Exception ex = null;
catch (Exception ex) { await DisconnectInternal(ex: ex, skipAwait: true).ConfigureAwait(false); } try { await firstTask.ConfigureAwait(false); }
catch (Exception ex2) { ex = ex2; }
//Ensure all other tasks are signaled to end. //Ensure all other tasks are signaled to end.
await DisconnectInternal(skipAwait: true).ConfigureAwait(false); await SignalDisconnect(ex, ex != null, true).ConfigureAwait(false);
//Wait for the remaining tasks to complete //Wait for the remaining tasks to complete
try { await allTasks.ConfigureAwait(false); } try { await allTasks.ConfigureAwait(false); }
catch { } catch { }
//Start cleanup //Start cleanup
await Stop().ConfigureAwait(false); await Cleanup().ConfigureAwait(false);
}
protected virtual IEnumerable<Task> GetTasks()
{
var cancelToken = _cancelToken;
return _engine.GetTasks(cancelToken)
.Concat(new Task[] { HeartbeatAsync(cancelToken) });
} }
protected virtual async Task Stop() protected virtual async Task Cleanup()
{ {
var disconnectState = _disconnectState; var disconnectState = _disconnectState;
_disconnectState = WebSocketState.Disconnected; _disconnectState = WebSocketState.Disconnected;
@@ -254,7 +251,7 @@ namespace Discord.Net.WebSockets
private Task HeartbeatAsync(CancellationToken cancelToken) private Task HeartbeatAsync(CancellationToken cancelToken)
{ {
return Task.Run((Func<Task>)(async () => return Task.Run(async () =>
{ {
try try
{ {
@@ -270,7 +267,7 @@ namespace Discord.Net.WebSockets
} }
} }
catch (OperationCanceledException) { } catch (OperationCanceledException) { }
})); });
} }
protected internal void ThrowError() protected internal void ThrowError()

View File

@@ -55,14 +55,14 @@ namespace Discord.Net.WebSockets
_webSocket.OnError += async (s, e) => _webSocket.OnError += async (s, e) =>
{ {
_logger.Log(LogSeverity.Error, "WebSocket Error", e.Exception); _logger.Log(LogSeverity.Error, "WebSocket Error", e.Exception);
await _parent.DisconnectInternal(e.Exception, skipAwait: true).ConfigureAwait(false); await _parent.SignalDisconnect(e.Exception, isUnexpected: true).ConfigureAwait(false);
}; };
_webSocket.OnClose += async (s, e) => _webSocket.OnClose += async (s, e) =>
{ {
string code = e.WasClean ? e.Code.ToString() : "Unexpected"; string code = e.WasClean ? e.Code.ToString() : "Unexpected";
string reason = e.Reason != "" ? e.Reason : "No Reason"; string reason = e.Reason != "" ? e.Reason : "No Reason";
Exception ex = new Exception($"Got Close Message ({code}): {reason}"); var ex = new Exception($"Got Close Message ({code}): {reason}");
await _parent.DisconnectInternal(ex, skipAwait: true).ConfigureAwait(false); await _parent.SignalDisconnect(ex, isUnexpected: true).ConfigureAwait(false);
}; };
_webSocket.Log.Output = (e, m) => { }; //Dont let websocket-sharp print to console directly _webSocket.Log.Output = (e, m) => { }; //Dont let websocket-sharp print to console directly
_webSocket.Connect(); _webSocket.Connect();