feature: Add Direction.Around to GetMessagesAsync (#1526)
* Add Direction.Around to GetMessagesAsync * Reuse the method * Reuse GetMany * Fix limit when getting from cache without message id * Fix limit when getting from rest without message id * Change cache return It will return in a similar way to REST
This commit is contained in:
@@ -109,12 +109,19 @@ namespace Discord.Rest
|
|||||||
public static IAsyncEnumerable<IReadOnlyCollection<RestMessage>> GetMessagesAsync(IMessageChannel channel, BaseDiscordClient client,
|
public static IAsyncEnumerable<IReadOnlyCollection<RestMessage>> GetMessagesAsync(IMessageChannel channel, BaseDiscordClient client,
|
||||||
ulong? fromMessageId, Direction dir, int limit, RequestOptions options)
|
ulong? fromMessageId, Direction dir, int limit, RequestOptions options)
|
||||||
{
|
{
|
||||||
if (dir == Direction.Around)
|
|
||||||
throw new NotImplementedException(); //TODO: Impl
|
|
||||||
|
|
||||||
var guildId = (channel as IGuildChannel)?.GuildId;
|
var guildId = (channel as IGuildChannel)?.GuildId;
|
||||||
var guild = guildId != null ? (client as IDiscordClient).GetGuildAsync(guildId.Value, CacheMode.CacheOnly).Result : null;
|
var guild = guildId != null ? (client as IDiscordClient).GetGuildAsync(guildId.Value, CacheMode.CacheOnly).Result : null;
|
||||||
|
|
||||||
|
if (dir == Direction.Around && limit > DiscordConfig.MaxMessagesPerBatch)
|
||||||
|
{
|
||||||
|
int around = limit / 2;
|
||||||
|
if (fromMessageId.HasValue)
|
||||||
|
return GetMessagesAsync(channel, client, fromMessageId.Value + 1, Direction.Before, around + 1, options) //Need to include the message itself
|
||||||
|
.Concat(GetMessagesAsync(channel, client, fromMessageId, Direction.After, around, options));
|
||||||
|
else //Shouldn't happen since there's no public overload for ulong? and Direction
|
||||||
|
return GetMessagesAsync(channel, client, null, Direction.Before, around + 1, options);
|
||||||
|
}
|
||||||
|
|
||||||
return new PagedAsyncEnumerable<RestMessage>(
|
return new PagedAsyncEnumerable<RestMessage>(
|
||||||
DiscordConfig.MaxMessagesPerBatch,
|
DiscordConfig.MaxMessagesPerBatch,
|
||||||
async (info, ct) =>
|
async (info, ct) =>
|
||||||
|
|||||||
@@ -11,23 +11,11 @@ namespace Discord.WebSocket
|
|||||||
public static IAsyncEnumerable<IReadOnlyCollection<IMessage>> GetMessagesAsync(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages,
|
public static IAsyncEnumerable<IReadOnlyCollection<IMessage>> GetMessagesAsync(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages,
|
||||||
ulong? fromMessageId, Direction dir, int limit, CacheMode mode, RequestOptions options)
|
ulong? fromMessageId, Direction dir, int limit, CacheMode mode, RequestOptions options)
|
||||||
{
|
{
|
||||||
if (dir == Direction.Around)
|
|
||||||
throw new NotImplementedException(); //TODO: Impl
|
|
||||||
|
|
||||||
IReadOnlyCollection<SocketMessage> cachedMessages = null;
|
|
||||||
IAsyncEnumerable<IReadOnlyCollection<IMessage>> result = null;
|
|
||||||
|
|
||||||
if (dir == Direction.After && fromMessageId == null)
|
if (dir == Direction.After && fromMessageId == null)
|
||||||
return AsyncEnumerable.Empty<IReadOnlyCollection<IMessage>>();
|
return AsyncEnumerable.Empty<IReadOnlyCollection<IMessage>>();
|
||||||
|
|
||||||
if (dir == Direction.Before || mode == CacheMode.CacheOnly)
|
var cachedMessages = GetCachedMessages(channel, discord, messages, fromMessageId, dir, limit);
|
||||||
{
|
var result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable<IReadOnlyCollection<IMessage>>();
|
||||||
if (messages != null) //Cache enabled
|
|
||||||
cachedMessages = messages.GetMany(fromMessageId, dir, limit);
|
|
||||||
else
|
|
||||||
cachedMessages = ImmutableArray.Create<SocketMessage>();
|
|
||||||
result = ImmutableArray.Create(cachedMessages).ToAsyncEnumerable<IReadOnlyCollection<IMessage>>();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (dir == Direction.Before)
|
if (dir == Direction.Before)
|
||||||
{
|
{
|
||||||
@@ -38,18 +26,35 @@ namespace Discord.WebSocket
|
|||||||
//Download remaining messages
|
//Download remaining messages
|
||||||
ulong? minId = cachedMessages.Count > 0 ? cachedMessages.Min(x => x.Id) : fromMessageId;
|
ulong? minId = cachedMessages.Count > 0 ? cachedMessages.Min(x => x.Id) : fromMessageId;
|
||||||
var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, minId, dir, limit, options);
|
var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, minId, dir, limit, options);
|
||||||
return result.Concat(downloadedMessages);
|
if (cachedMessages.Count != 0)
|
||||||
|
return result.Concat(downloadedMessages);
|
||||||
|
else
|
||||||
|
return downloadedMessages;
|
||||||
}
|
}
|
||||||
else
|
else if (dir == Direction.After)
|
||||||
{
|
{
|
||||||
if (mode == CacheMode.CacheOnly)
|
limit -= cachedMessages.Count;
|
||||||
|
if (mode == CacheMode.CacheOnly || limit <= 0)
|
||||||
return result;
|
return result;
|
||||||
|
|
||||||
//Dont use cache in this case
|
//Download remaining messages
|
||||||
|
ulong maxId = cachedMessages.Count > 0 ? cachedMessages.Max(x => x.Id) : fromMessageId.Value;
|
||||||
|
var downloadedMessages = ChannelHelper.GetMessagesAsync(channel, discord, maxId, dir, limit, options);
|
||||||
|
if (cachedMessages.Count != 0)
|
||||||
|
return result.Concat(downloadedMessages);
|
||||||
|
else
|
||||||
|
return downloadedMessages;
|
||||||
|
}
|
||||||
|
else //Direction.Around
|
||||||
|
{
|
||||||
|
if (mode == CacheMode.CacheOnly || limit <= cachedMessages.Count)
|
||||||
|
return result;
|
||||||
|
|
||||||
|
//Cache isn't useful here since Discord will send them anyways
|
||||||
return ChannelHelper.GetMessagesAsync(channel, discord, fromMessageId, dir, limit, options);
|
return ChannelHelper.GetMessagesAsync(channel, discord, fromMessageId, dir, limit, options);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
public static IReadOnlyCollection<SocketMessage> GetCachedMessages(SocketChannel channel, DiscordSocketClient discord, MessageCache messages,
|
public static IReadOnlyCollection<SocketMessage> GetCachedMessages(ISocketMessageChannel channel, DiscordSocketClient discord, MessageCache messages,
|
||||||
ulong? fromMessageId, Direction dir, int limit)
|
ulong? fromMessageId, Direction dir, int limit)
|
||||||
{
|
{
|
||||||
if (messages != null) //Cache enabled
|
if (messages != null) //Cache enabled
|
||||||
|
|||||||
@@ -56,11 +56,23 @@ namespace Discord.WebSocket
|
|||||||
cachedMessageIds = _orderedMessages;
|
cachedMessageIds = _orderedMessages;
|
||||||
else if (dir == Direction.Before)
|
else if (dir == Direction.Before)
|
||||||
cachedMessageIds = _orderedMessages.Where(x => x < fromMessageId.Value);
|
cachedMessageIds = _orderedMessages.Where(x => x < fromMessageId.Value);
|
||||||
else
|
else if (dir == Direction.After)
|
||||||
cachedMessageIds = _orderedMessages.Where(x => x > fromMessageId.Value);
|
cachedMessageIds = _orderedMessages.Where(x => x > fromMessageId.Value);
|
||||||
|
else //Direction.Around
|
||||||
|
{
|
||||||
|
if (!_messages.TryGetValue(fromMessageId.Value, out SocketMessage msg))
|
||||||
|
return ImmutableArray<SocketMessage>.Empty;
|
||||||
|
int around = limit / 2;
|
||||||
|
var before = GetMany(fromMessageId, Direction.Before, around);
|
||||||
|
var after = GetMany(fromMessageId, Direction.After, around).Reverse();
|
||||||
|
|
||||||
|
return after.Concat(new SocketMessage[] { msg }).Concat(before).ToImmutableArray();
|
||||||
|
}
|
||||||
|
|
||||||
if (dir == Direction.Before)
|
if (dir == Direction.Before)
|
||||||
cachedMessageIds = cachedMessageIds.Reverse();
|
cachedMessageIds = cachedMessageIds.Reverse();
|
||||||
|
if (dir == Direction.Around) //Only happens if fromMessageId is null, should only get "around" and itself (+1)
|
||||||
|
limit = limit / 2 + 1;
|
||||||
|
|
||||||
return cachedMessageIds
|
return cachedMessageIds
|
||||||
.Select(x =>
|
.Select(x =>
|
||||||
|
|||||||
Reference in New Issue
Block a user