[Feature] Voice reconnection and resuming (#2873)

* Voice receive fix (use system-selected port)

* Update SocketGuild.cs

* Reconnect voice after moved, resume voice connection, don't invoke Disconnected event when is going to reconnect

* no more collection primitives

* Disconnected event rallback & dispose audio client after finished

* Update src/Discord.Net.WebSocket/Audio/AudioClient.cs

* Update src/Discord.Net.WebSocket/Audio/AudioClient.cs

---------
This commit is contained in:
Богдан Петренко
2024-03-14 11:33:41 +02:00
committed by GitHub
parent d68e06e27d
commit 09680c51ac
5 changed files with 203 additions and 15 deletions

View File

@@ -0,0 +1,14 @@
using Newtonsoft.Json;
namespace Discord.API.Voice
{
public class ResumeParams
{
[JsonProperty("server_id")]
public ulong ServerId { get; set; }
[JsonProperty("session_id")]
public string SessionId { get; set; }
[JsonProperty("token")]
public string Token { get; set; }
}
}

View File

@@ -1,6 +1,7 @@
using Discord.API.Voice;
using Discord.Audio.Streams;
using Discord.Logging;
using Discord.Net;
using Discord.Net.Converters;
using Discord.WebSocket;
using Newtonsoft.Json;
@@ -9,18 +10,23 @@ using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace Discord.Audio
{
//TODO: Add audio reconnecting
internal partial class AudioClient : IAudioClient
{
private static readonly int ConnectionTimeoutMs = 30000; // 30 seconds
private static readonly int KeepAliveIntervalMs = 5000; // 5 seconds
private static readonly int[] BlacklistedResumeCodes = new int[]
{
4001, 4002, 4003, 4004, 4005, 4006, 4009, 4012, 1014, 4016
};
private struct StreamPair
{
public AudioInStream Reader;
@@ -49,6 +55,8 @@ namespace Discord.Audio
private ulong _userId;
private uint _ssrc;
private bool _isSpeaking;
private StopReason _stopReason;
private bool _resuming;
public SocketGuild Guild { get; }
public DiscordVoiceAPIClient ApiClient { get; private set; }
@@ -56,6 +64,7 @@ namespace Discord.Audio
public int UdpLatency { get; private set; }
public ulong ChannelId { get; internal set; }
internal byte[] SecretKey { get; private set; }
internal bool IsFinished { get; private set; }
private DiscordSocketClient Discord => Guild.Discord;
public ConnectionState ConnectionState => _connection.State;
@@ -78,7 +87,7 @@ namespace Discord.Audio
_connection = new ConnectionManager(_stateLock, _audioLogger, ConnectionTimeoutMs,
OnConnectingAsync, OnDisconnectingAsync, x => ApiClient.Disconnected += x);
_connection.Connected += () => _connectedEvent.InvokeAsync();
_connection.Disconnected += (ex, recon) => _disconnectedEvent.InvokeAsync(ex);
_connection.Disconnected += (exception, _) => _disconnectedEvent.InvokeAsync(exception);
_heartbeatTimes = new ConcurrentQueue<long>();
_keepaliveTimes = new ConcurrentQueue<KeyValuePair<ulong, int>>();
_ssrcMap = new ConcurrentDictionary<uint, ulong>();
@@ -110,15 +119,30 @@ namespace Discord.Audio
}
public Task StopAsync()
=> _connection.StopAsync();
=> StopAsync(StopReason.Normal);
internal Task StopAsync(StopReason stopReason)
{
_stopReason = stopReason;
return _connection.StopAsync();
}
private async Task OnConnectingAsync()
{
await _audioLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
await _audioLogger.DebugAsync($"Connecting ApiClient. Voice server: wss://{_url}").ConfigureAwait(false);
await ApiClient.ConnectAsync($"wss://{_url}?v={DiscordConfig.VoiceAPIVersion}").ConfigureAwait(false);
await _audioLogger.DebugAsync($"Listening on port {ApiClient.UdpPort}").ConfigureAwait(false);
await _audioLogger.DebugAsync("Sending Identity").ConfigureAwait(false);
await ApiClient.SendIdentityAsync(_userId, _sessionId, _token).ConfigureAwait(false);
if (!_resuming)
{
await _audioLogger.DebugAsync("Sending Identity").ConfigureAwait(false);
await ApiClient.SendIdentityAsync(_userId, _sessionId, _token).ConfigureAwait(false);
}
else
{
await _audioLogger.DebugAsync("Sending Resume").ConfigureAwait(false);
await ApiClient.SendResume(_token, _sessionId).ConfigureAwait(false);
}
//Wait for READY
await _connection.WaitAsync().ConfigureAwait(false);
@@ -128,6 +152,63 @@ namespace Discord.Audio
await _audioLogger.DebugAsync("Disconnecting ApiClient").ConfigureAwait(false);
await ApiClient.DisconnectAsync().ConfigureAwait(false);
if (_stopReason == StopReason.Unknown && ex.InnerException is WebSocketException exception)
{
await _audioLogger.WarningAsync(
$"Audio connection terminated with unknown reason. Code: {exception.ErrorCode} - {exception.Message}",
exception);
if (_resuming)
{
await _audioLogger.WarningAsync("Resume failed");
_resuming = false;
await FinishDisconnect(ex, true);
return;
}
if (BlacklistedResumeCodes.Contains(exception.ErrorCode))
{
await FinishDisconnect(ex, true);
return;
}
await ClearHeartBeaters();
_resuming = true;
return;
}
await FinishDisconnect(ex, _stopReason != StopReason.Moved);
if (_stopReason == StopReason.Normal)
{
await _audioLogger.DebugAsync("Sending Voice State").ConfigureAwait(false);
await Discord.ApiClient.SendVoiceStateUpdateAsync(Guild.Id, null, false, false).ConfigureAwait(false);
}
_stopReason = StopReason.Unknown;
}
private async Task FinishDisconnect(Exception ex, bool wontTryReconnect)
{
await _audioLogger.DebugAsync("Finishing audio connection").ConfigureAwait(false);
await ClearHeartBeaters().ConfigureAwait(false);
if (wontTryReconnect)
{
await _connection.StopAsync().ConfigureAwait(false);
await ClearInputStreamsAsync().ConfigureAwait(false);
IsFinished = true;
}
}
private async Task ClearHeartBeaters()
{
//Wait for tasks to complete
await _audioLogger.DebugAsync("Waiting for heartbeater").ConfigureAwait(false);
@@ -143,12 +224,11 @@ namespace Discord.Audio
{ }
_lastMessageTime = 0;
await ClearInputStreamsAsync().ConfigureAwait(false);
await _audioLogger.DebugAsync("Sending Voice State").ConfigureAwait(false);
await Discord.ApiClient.SendVoiceStateUpdateAsync(Guild.Id, null, false, false).ConfigureAwait(false);
while (_keepaliveTimes.TryDequeue(out _))
{ }
}
#region Streams
public AudioOutStream CreateOpusStream(int bufferMillis)
{
var outputStream = new OutputStream(ApiClient); //Ignores header
@@ -217,6 +297,7 @@ namespace Discord.Audio
_ssrcMap.Clear();
_streams.Clear();
}
#endregion
private async Task ProcessMessageAsync(VoiceOpCode opCode, object payload)
{
@@ -285,7 +366,7 @@ namespace Discord.Audio
await _audioLogger.DebugAsync("Received Speaking").ConfigureAwait(false);
var data = (payload as JToken).ToObject<SpeakingEvent>(_serializer);
_ssrcMap[data.Ssrc] = data.UserId; //TODO: Memory Leak: SSRCs are never cleaned up
_ssrcMap[data.Ssrc] = data.UserId;
await _speakingUpdatedEvent.InvokeAsync(data.UserId, data.Speaking);
}
@@ -299,6 +380,17 @@ namespace Discord.Audio
await _clientDisconnectedEvent.InvokeAsync(data.UserId);
}
break;
case VoiceOpCode.Resumed:
{
await _audioLogger.DebugAsync($"Voice connection resumed: wss://{_url}");
_resuming = false;
_heartbeatTask = RunHeartbeatAsync(_heartbeatInterval, _connection.CancelToken);
_keepaliveTask = RunKeepaliveAsync(_connection.CancelToken);
_ = _connection.CompleteAsync();
}
break;
default:
await _audioLogger.WarningAsync($"Unknown OpCode ({opCode})").ConfigureAwait(false);
break;
@@ -485,6 +577,49 @@ namespace Discord.Audio
}
}
/// <summary>
/// Waits until all post-disconnect actions are done.
/// </summary>
/// <param name="timeout">Maximum time to wait.</param>
/// <returns>
/// A <see cref="Task"/> that represents an asynchronous process of waiting.
/// </returns>
internal async Task WaitForDisconnectAsync(TimeSpan timeout)
{
if (ConnectionState == ConnectionState.Disconnected)
return;
var completion = new TaskCompletionSource<Exception>();
var cts = new CancellationTokenSource();
var _ = Task.Delay(timeout, cts.Token).ContinueWith(_ =>
{
completion.TrySetException(new TimeoutException("Exceeded maximum time to wait"));
cts.Dispose();
}, cts.Token);
_connection.Disconnected += HandleDisconnectSubscription;
await completion.Task.ConfigureAwait(false);
Task HandleDisconnectSubscription(Exception exception, bool reconnect)
{
try
{
cts.Cancel();
completion.TrySetResult(exception);
}
finally
{
_connection.Disconnected -= HandleDisconnectSubscription;
cts.Dispose();
}
return Task.CompletedTask;
}
}
internal void Dispose(bool disposing)
{
if (disposing)
@@ -496,5 +631,13 @@ namespace Discord.Audio
}
/// <inheritdoc />
public void Dispose() => Dispose(true);
internal enum StopReason
{
Unknown = 0,
Normal,
Disconnected,
Moved
}
}
}

View File

@@ -172,8 +172,8 @@ namespace Discord
await _onDisconnecting(ex).ConfigureAwait(false);
await _disconnectedEvent.InvokeAsync(ex, isReconnecting).ConfigureAwait(false);
State = ConnectionState.Disconnected;
await _disconnectedEvent.InvokeAsync(ex, isReconnecting).ConfigureAwait(false);
await _logger.InfoAsync("Disconnected").ConfigureAwait(false);
}

View File

@@ -166,6 +166,16 @@ namespace Discord.Audio
});
}
public Task SendResume(string token, string sessionId)
{
return SendAsync(VoiceOpCode.Resume, new ResumeParams
{
ServerId = GuildId,
SessionId = sessionId,
Token = token
});
}
public async Task ConnectAsync(string url)
{
await _connectionLock.WaitAsync().ConfigureAwait(false);

View File

@@ -1450,7 +1450,7 @@ namespace Discord.WebSocket
/// <returns>
/// A task that represents the asynchronous get operation. The task result contains a read-only collection
/// of the requested audit log entries.
/// </returns>
/// </returns>
public IAsyncEnumerable<IReadOnlyCollection<RestAuditLogEntry>> GetAuditLogsAsync(int limit, RequestOptions options = null, ulong? beforeId = null, ulong? userId = null, ActionType? actionType = null, ulong? afterId = null)
=> GuildHelper.GetAuditLogsAsync(this, Discord, beforeId, limit, options, userId: userId, actionType: actionType, afterId: afterId);
@@ -1687,7 +1687,7 @@ namespace Discord.WebSocket
if (after.VoiceChannel != null && _audioClient.ChannelId != after.VoiceChannel?.Id)
{
_audioClient.ChannelId = after.VoiceChannel.Id;
await RepopulateAudioStreamsAsync().ConfigureAwait(false);
await _audioClient.StopAsync(Audio.AudioClient.StopReason.Moved);
}
}
else
@@ -1711,7 +1711,13 @@ namespace Discord.WebSocket
if (_voiceStates.TryRemove(id, out SocketVoiceState voiceState))
{
if (_audioClient != null)
{
await _audioClient.RemoveInputStreamAsync(id).ConfigureAwait(false); //User changed channels, end their stream
if (id == CurrentUser.Id)
await _audioClient.StopAsync(Audio.AudioClient.StopReason.Disconnected);
}
return voiceState;
}
return null;
@@ -1755,7 +1761,7 @@ namespace Discord.WebSocket
var audioClient = new AudioClient(this, Discord.GetAudioId(), channelId);
audioClient.Disconnected += async ex =>
{
if (!promise.Task.IsCompleted)
if (promise.Task.IsCompleted && audioClient.IsFinished)
{
try
{ audioClient.Dispose(); }
@@ -1866,6 +1872,21 @@ namespace Discord.WebSocket
if (_audioClient != null)
{
await RepopulateAudioStreamsAsync().ConfigureAwait(false);
if (_audioClient.ConnectionState != ConnectionState.Disconnected)
{
try
{
await _audioClient.WaitForDisconnectAsync(TimeSpan.FromSeconds(5)).ConfigureAwait(false);
}
catch (TimeoutException)
{
await Discord.LogManager.WarningAsync("Failed to wait for disconnect audio client in time", null).ConfigureAwait(false);
}
}
await Task.Delay(TimeSpan.FromMilliseconds(5)).ConfigureAwait(false);
await _audioClient.StartAsync(url, Discord.CurrentUser.Id, voiceState.VoiceSessionId, token).ConfigureAwait(false);
}
}