feature: Implement gateway ratelimit (#1537)

* Implement gateway ratelimit

* Remove unused code

* Share WebSocketRequestQueue between clients

* Add global limit and a way to change gateway limits

* Refactoring variable to fit lib standards

* Update xml docs

* Update xml docs

* Move warning to remarks

* Remove specific RequestQueue for WebSocket and other changes

The only account limit is for identify that is dealt in a different way (exclusive semaphore), so websocket queues can be shared with REST and don't need to be shared between clients anymore.

Also added the ratelimit for presence updates.

* Add summary to IdentifySemaphoreName

* Fix spacing

* Add max_concurrency and other fixes

- Add session_start_limit to GetBotGatewayResponse
- Add GetBotGatewayAsync to IDiscordClient
- Add master/slave semaphores to enable concurrency
- Not store semaphore name as static
- Clone GatewayLimits when cloning the Config

* Add missing RequestQueue parameter and wrong nullable

* Add RequeueQueue paramater to Webhook

* Better xml documentation

* Remove GatewayLimits class and other changes

- Remove GatewayLimits
- Transfer a few properties to DiscordSocketConfig
- Remove unnecessary usings

* Remove unnecessary using and wording

* Remove more unnecessary usings

* Change named Semaphores to SemaphoreSlim

* Remove unused using

* Update branch

* Fix merge conflicts and update to new ratelimit

* Fixing merge, ignore limit for heartbeat, and dispose

* Missed one place and better xml docs.

* Wait identify before opening the connection

* Only request identify ticket when needed

* Move identify control to sharded client

* Better description for IdentifyMaxConcurrency

* Add lock to InvalidSession
This commit is contained in:
Paulo
2020-11-18 23:40:09 -03:00
committed by GitHub
parent 97e71cd5e5
commit ec673e1863
17 changed files with 397 additions and 40 deletions

View File

@@ -12,12 +12,14 @@ namespace Discord.WebSocket
public partial class DiscordShardedClient : BaseSocketClient, IDiscordClient
{
private readonly DiscordSocketConfig _baseConfig;
private readonly SemaphoreSlim _connectionGroupLock;
private readonly Dictionary<int, int> _shardIdsToIndex;
private readonly bool _automaticShards;
private int[] _shardIds;
private DiscordSocketClient[] _shards;
private int _totalShards;
private SemaphoreSlim[] _identifySemaphores;
private object _semaphoreResetLock;
private Task _semaphoreResetTask;
private bool _isDisposed;
@@ -62,10 +64,10 @@ namespace Discord.WebSocket
if (ids != null && config.TotalShards == null)
throw new ArgumentException($"Custom ids are not supported when {nameof(config.TotalShards)} is not specified.");
_semaphoreResetLock = new object();
_shardIdsToIndex = new Dictionary<int, int>();
config.DisplayInitialLog = false;
_baseConfig = config;
_connectionGroupLock = new SemaphoreSlim(1, 1);
if (config.TotalShards == null)
_automaticShards = true;
@@ -74,12 +76,15 @@ namespace Discord.WebSocket
_totalShards = config.TotalShards.Value;
_shardIds = ids ?? Enumerable.Range(0, _totalShards).ToArray();
_shards = new DiscordSocketClient[_shardIds.Length];
_identifySemaphores = new SemaphoreSlim[config.IdentifyMaxConcurrency];
for (int i = 0; i < config.IdentifyMaxConcurrency; i++)
_identifySemaphores[i] = new SemaphoreSlim(1, 1);
for (int i = 0; i < _shardIds.Length; i++)
{
_shardIdsToIndex.Add(_shardIds[i], i);
var newConfig = config.Clone();
newConfig.ShardId = _shardIds[i];
_shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null);
_shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null);
RegisterEvents(_shards[i], i == 0);
}
}
@@ -88,21 +93,53 @@ namespace Discord.WebSocket
=> new API.DiscordSocketApiClient(config.RestClientProvider, config.WebSocketProvider, DiscordRestConfig.UserAgent,
rateLimitPrecision: config.RateLimitPrecision);
internal async Task AcquireIdentifyLockAsync(int shardId, CancellationToken token)
{
int semaphoreIdx = shardId % _baseConfig.IdentifyMaxConcurrency;
await _identifySemaphores[semaphoreIdx].WaitAsync(token).ConfigureAwait(false);
}
internal void ReleaseIdentifyLock()
{
lock (_semaphoreResetLock)
{
if (_semaphoreResetTask == null)
_semaphoreResetTask = ResetSemaphoresAsync();
}
}
private async Task ResetSemaphoresAsync()
{
await Task.Delay(5000).ConfigureAwait(false);
lock (_semaphoreResetLock)
{
foreach (var semaphore in _identifySemaphores)
if (semaphore.CurrentCount == 0)
semaphore.Release();
_semaphoreResetTask = null;
}
}
internal override async Task OnLoginAsync(TokenType tokenType, string token)
{
if (_automaticShards)
{
var shardCount = await GetRecommendedShardCountAsync().ConfigureAwait(false);
_shardIds = Enumerable.Range(0, shardCount).ToArray();
var botGateway = await GetBotGatewayAsync().ConfigureAwait(false);
_shardIds = Enumerable.Range(0, botGateway.Shards).ToArray();
_totalShards = _shardIds.Length;
_shards = new DiscordSocketClient[_shardIds.Length];
int maxConcurrency = botGateway.SessionStartLimit.MaxConcurrency;
_baseConfig.IdentifyMaxConcurrency = maxConcurrency;
_identifySemaphores = new SemaphoreSlim[maxConcurrency];
for (int i = 0; i < maxConcurrency; i++)
_identifySemaphores[i] = new SemaphoreSlim(1, 1);
for (int i = 0; i < _shardIds.Length; i++)
{
_shardIdsToIndex.Add(_shardIds[i], i);
var newConfig = _baseConfig.Clone();
newConfig.ShardId = _shardIds[i];
newConfig.TotalShards = _totalShards;
_shards[i] = new DiscordSocketClient(newConfig, _connectionGroupLock, i != 0 ? _shards[0] : null);
_shards[i] = new DiscordSocketClient(newConfig, this, i != 0 ? _shards[0] : null);
RegisterEvents(_shards[i], i == 0);
}
}
@@ -398,7 +435,6 @@ namespace Discord.WebSocket
foreach (var client in _shards)
client?.Dispose();
}
_connectionGroupLock?.Dispose();
}
_isDisposed = true;

View File

@@ -132,6 +132,8 @@ namespace Discord.API
if (WebSocketClient == null)
throw new NotSupportedException("This client is not configured with WebSocket support.");
RequestQueue.ClearGatewayBuckets();
//Re-create streams to reset the zlib state
_compressed?.Dispose();
_decompressor?.Dispose();
@@ -205,7 +207,11 @@ namespace Discord.API
payload = new SocketFrame { Operation = (int)opCode, Payload = payload };
if (payload != null)
bytes = Encoding.UTF8.GetBytes(SerializeJson(payload));
await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, null, bytes, true, options)).ConfigureAwait(false);
options.IsGatewayBucket = true;
if (options.BucketId == null)
options.BucketId = GatewayBucket.Get(GatewayBucketType.Unbucketed).Id;
await RequestQueue.SendAsync(new WebSocketRequest(WebSocketClient, bytes, true, opCode == GatewayOpCode.Heartbeat, options)).ConfigureAwait(false);
await _sentGatewayMessageEvent.InvokeAsync(opCode).ConfigureAwait(false);
}
@@ -225,6 +231,8 @@ namespace Discord.API
if (totalShards > 1)
msg.ShardingParams = new int[] { shardID, totalShards };
options.BucketId = GatewayBucket.Get(GatewayBucketType.Identify).Id;
if (gatewayIntents.HasValue)
msg.Intents = (int)gatewayIntents.Value;
else
@@ -258,6 +266,7 @@ namespace Discord.API
IsAFK = isAFK,
Game = game
};
options.BucketId = GatewayBucket.Get(GatewayBucketType.PresenceUpdate).Id;
await SendGatewayAsync(GatewayOpCode.StatusUpdate, args, options: options).ConfigureAwait(false);
}
public async Task SendRequestMembersAsync(IEnumerable<ulong> guildIds, RequestOptions options = null)

View File

@@ -26,7 +26,7 @@ namespace Discord.WebSocket
{
private readonly ConcurrentQueue<ulong> _largeGuilds;
private readonly JsonSerializer _serializer;
private readonly SemaphoreSlim _connectionGroupLock;
private readonly DiscordShardedClient _shardedClient;
private readonly DiscordSocketClient _parentClient;
private readonly ConcurrentQueue<long> _heartbeatTimes;
private readonly ConnectionManager _connection;
@@ -120,9 +120,9 @@ namespace Discord.WebSocket
/// <param name="config">The configuration to be used with the client.</param>
#pragma warning disable IDISP004
public DiscordSocketClient(DiscordSocketConfig config) : this(config, CreateApiClient(config), null, null) { }
internal DiscordSocketClient(DiscordSocketConfig config, SemaphoreSlim groupLock, DiscordSocketClient parentClient) : this(config, CreateApiClient(config), groupLock, parentClient) { }
internal DiscordSocketClient(DiscordSocketConfig config, DiscordShardedClient shardedClient, DiscordSocketClient parentClient) : this(config, CreateApiClient(config), shardedClient, parentClient) { }
#pragma warning restore IDISP004
private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, SemaphoreSlim groupLock, DiscordSocketClient parentClient)
private DiscordSocketClient(DiscordSocketConfig config, API.DiscordSocketApiClient client, DiscordShardedClient shardedClient, DiscordSocketClient parentClient)
: base(config, client)
{
ShardId = config.ShardId ?? 0;
@@ -148,7 +148,7 @@ namespace Discord.WebSocket
_connection.Disconnected += (ex, recon) => TimedInvokeAsync(_disconnectedEvent, nameof(Disconnected), ex);
_nextAudioId = 1;
_connectionGroupLock = groupLock;
_shardedClient = shardedClient;
_parentClient = parentClient;
_serializer = new JsonSerializer { ContractResolver = new DiscordContractResolver() };
@@ -229,8 +229,12 @@ namespace Discord.WebSocket
private async Task OnConnectingAsync()
{
if (_connectionGroupLock != null)
await _connectionGroupLock.WaitAsync(_connection.CancelToken).ConfigureAwait(false);
bool locked = false;
if (_shardedClient != null && _sessionId == null)
{
await _shardedClient.AcquireIdentifyLockAsync(ShardId, _connection.CancelToken).ConfigureAwait(false);
locked = true;
}
try
{
await _gatewayLogger.DebugAsync("Connecting ApiClient").ConfigureAwait(false);
@@ -255,11 +259,8 @@ namespace Discord.WebSocket
}
finally
{
if (_connectionGroupLock != null)
{
await Task.Delay(5000).ConfigureAwait(false);
_connectionGroupLock.Release();
}
if (locked)
_shardedClient.ReleaseIdentifyLock();
}
}
private async Task OnDisconnectingAsync(Exception ex)
@@ -519,7 +520,15 @@ namespace Discord.WebSocket
_sessionId = null;
_lastSeq = 0;
await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false);
await _shardedClient.AcquireIdentifyLockAsync(ShardId, _connection.CancelToken).ConfigureAwait(false);
try
{
await ApiClient.SendIdentifyAsync(shardID: ShardId, totalShards: TotalShards, guildSubscriptions: _guildSubscriptions, gatewayIntents: _gatewayIntents).ConfigureAwait(false);
}
finally
{
_shardedClient.ReleaseIdentifyLock();
}
}
break;
case GatewayOpCode.Reconnect:

View File

@@ -126,6 +126,14 @@ namespace Discord.WebSocket
public bool GuildSubscriptions { get; set; } = true;
/// <summary>
/// Gets or sets the maximum identify concurrency.
/// </summary>
/// <remarks>
/// This information is provided by Discord.
/// It is only used when using a <see cref="DiscordShardedClient"/> and auto-sharding is disabled.
/// </remarks>
public int IdentifyMaxConcurrency { get; set; } = 1;
/// Gets or sets the maximum wait time in milliseconds between GUILD_AVAILABLE events before firing READY.
///
/// If zero, READY will fire as soon as it is received and all guilds will be unavailable.