Renamed existing Flatten method to FlattenAsync and added new Flatten method. Also fixed ClientHelper using incorrect guild batch count. (#744)

This commit is contained in:
ObsidianMinor
2018-01-06 21:43:11 -06:00
committed by Christopher F
parent edfbd055bb
commit 5bbd9bba82
4 changed files with 63 additions and 11 deletions

View File

@@ -13,7 +13,7 @@ namespace Discord.Commands
public override async Task<TypeReaderResult> ReadAsync(ICommandContext context, string input, IServiceProvider services)
{
var results = new Dictionary<ulong, TypeReaderValue>();
IReadOnlyCollection<IUser> channelUsers = (await context.Channel.GetUsersAsync(CacheMode.CacheOnly).Flatten().ConfigureAwait(false)).ToArray(); //TODO: must be a better way?
IAsyncEnumerable<IUser> channelUsers = context.Channel.GetUsersAsync(CacheMode.CacheOnly).Flatten(); // it's better
IReadOnlyCollection<IGuildUser> guildUsers = ImmutableArray.Create<IGuildUser>();
ulong id;
@@ -45,7 +45,7 @@ namespace Discord.Commands
string username = input.Substring(0, index);
if (ushort.TryParse(input.Substring(index + 1), out ushort discriminator))
{
var channelUser = channelUsers.FirstOrDefault(x => x.DiscriminatorValue == discriminator &&
var channelUser = await channelUsers.FirstOrDefault(x => x.DiscriminatorValue == discriminator &&
string.Equals(username, x.Username, StringComparison.OrdinalIgnoreCase));
AddResult(results, channelUser as T, channelUser?.Username == username ? 0.85f : 0.75f);
@@ -57,8 +57,9 @@ namespace Discord.Commands
//By Username (0.5-0.6)
{
foreach (var channelUser in channelUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)))
AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f);
await channelUsers
.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase))
.ForEachAsync(channelUser => AddResult(results, channelUser as T, channelUser.Username == input ? 0.65f : 0.55f));
foreach (var guildUser in guildUsers.Where(x => string.Equals(input, x.Username, StringComparison.OrdinalIgnoreCase)))
AddResult(results, guildUser as T, guildUser.Username == input ? 0.60f : 0.50f);
@@ -66,8 +67,9 @@ namespace Discord.Commands
//By Nickname (0.5-0.6)
{
foreach (var channelUser in channelUsers.Where(x => string.Equals(input, (x as IGuildUser)?.Nickname, StringComparison.OrdinalIgnoreCase)))
AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f);
await channelUsers
.Where(x => string.Equals(input, (x as IGuildUser)?.Nickname, StringComparison.OrdinalIgnoreCase))
.ForEachAsync(channelUser => AddResult(results, channelUser as T, (channelUser as IGuildUser).Nickname == input ? 0.65f : 0.55f));
foreach (var guildUser in guildUsers.Where(x => string.Equals(input, (x as IGuildUser).Nickname, StringComparison.OrdinalIgnoreCase)))
AddResult(results, guildUser as T, (guildUser as IGuildUser).Nickname == input ? 0.60f : 0.50f);

View File

@@ -1,14 +1,64 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace Discord
{
public static class AsyncEnumerableExtensions
{
public static async Task<IEnumerable<T>> Flatten<T>(this IAsyncEnumerable<IReadOnlyCollection<T>> source)
/// <summary>
/// Flattens the specified pages into one <see cref="IEnumerable{T}"/> asynchronously
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="source"></param>
/// <returns></returns>
public static async Task<IEnumerable<T>> FlattenAsync<T>(this IAsyncEnumerable<IEnumerable<T>> source)
{
return (await source.ToArray().ConfigureAwait(false)).SelectMany(x => x);
return await source.Flatten().ToArray().ConfigureAwait(false);
}
public static IAsyncEnumerable<T> Flatten<T>(this IAsyncEnumerable<IEnumerable<T>> source)
{
return new PagedCollectionEnumerator<T>(source);
}
internal class PagedCollectionEnumerator<T> : IAsyncEnumerator<T>, IAsyncEnumerable<T>
{
readonly IAsyncEnumerator<IEnumerable<T>> _source;
IEnumerator<T> _enumerator;
public IAsyncEnumerator<T> GetEnumerator() => this;
internal PagedCollectionEnumerator(IAsyncEnumerable<IEnumerable<T>> source)
{
_source = source.GetEnumerator();
}
public T Current => _enumerator.Current;
public void Dispose()
{
_enumerator?.Dispose();
_source.Dispose();
}
public async Task<bool> MoveNext(CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
if(!_enumerator?.MoveNext() ?? true)
{
if (!await _source.MoveNext(cancellationToken).ConfigureAwait(false))
return false;
_enumerator?.Dispose();
_enumerator = _source.Current.GetEnumerator();
return _enumerator.MoveNext();
}
return true;
}
}
}
}

View File

@@ -79,7 +79,7 @@ namespace Discord.Rest
ulong? fromGuildId, int? limit, RequestOptions options)
{
return new PagedAsyncEnumerable<RestUserGuild>(
DiscordConfig.MaxUsersPerBatch,
DiscordConfig.MaxGuildsPerBatch,
async (info, ct) =>
{
var args = new GetGuildSummariesParams
@@ -106,7 +106,7 @@ namespace Discord.Rest
}
public static async Task<IReadOnlyCollection<RestGuild>> GetGuildsAsync(BaseDiscordClient client, RequestOptions options)
{
var summaryModels = await GetGuildSummariesAsync(client, null, null, options).Flatten();
var summaryModels = await GetGuildSummariesAsync(client, null, null, options).FlattenAsync().ConfigureAwait(false);
var guilds = ImmutableArray.CreateBuilder<RestGuild>();
foreach (var summaryModel in summaryModels)
{

View File

@@ -413,7 +413,7 @@ namespace Discord.Rest
async Task<IReadOnlyCollection<IGuildUser>> IGuild.GetUsersAsync(CacheMode mode, RequestOptions options)
{
if (mode == CacheMode.AllowDownload)
return (await GetUsersAsync(options).Flatten().ConfigureAwait(false)).ToImmutableArray();
return (await GetUsersAsync(options).FlattenAsync().ConfigureAwait(false)).ToImmutableArray();
else
return ImmutableArray.Create<IGuildUser>();
}