using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;

namespace Discord.WebSocket
{
    /// <summary>
    /// In-memory cache of the channels, guilds and global users known to the WebSocket client,
    /// each indexed by its <see cref="ulong"/> id. DM channels are additionally indexed by the id
    /// of their recipient, and group channels are tracked as a set of channel ids that are
    /// resolved through the main channel map on access.
    /// </summary>
    internal class DataStore
    {
        // Estimated number of threads updating the collections concurrently
        // (the WebSocket updater/event handler).
        // TODO: Needs profiling, increase to 2?
        private const int CollectionConcurrencyLevel = 1;
        private const double AverageChannelsPerGuild = 10.22; // Source: Googie2149
        private const double AverageUsersPerGuild = 47.78; // Source: Googie2149
        private const double CollectionMultiplier = 1.05; // Add 5% buffer to handle growth
        // Initial capacity guess for group channels; the constructor receives no group count.
        private const int EstimatedGroupChannelCount = 10;

        private readonly ConcurrentDictionary<ulong, SocketChannel> _channels;
        // Keyed by the recipient's user id, not by the channel id.
        private readonly ConcurrentDictionary<ulong, SocketDMChannel> _dmChannels;
        private readonly ConcurrentDictionary<ulong, SocketGuild> _guilds;
        private readonly ConcurrentDictionary<ulong, SocketGlobalUser> _users;
        // Ids of group channels; the channel objects themselves live in _channels.
        private readonly ConcurrentHashSet<ulong> _groupChannels;

        /// <summary>All cached channels, of every kind.</summary>
        internal IReadOnlyCollection<SocketChannel> Channels => _channels.ToReadOnlyCollection();

        /// <summary>All cached DM channels.</summary>
        internal IReadOnlyCollection<SocketDMChannel> DMChannels => _dmChannels.ToReadOnlyCollection();

        /// <summary>
        /// All cached group channels, resolved from their ids through <see cref="GetChannel"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): <see cref="RemoveChannel"/> removes from the channel map before the
        /// group-id set, so an enumeration racing with a removal may yield a null entry.
        /// Confirm whether callers tolerate that.
        /// </remarks>
        internal IReadOnlyCollection<SocketGroupChannel> GroupChannels =>
            _groupChannels
                .Select(x => GetChannel(x) as SocketGroupChannel)
                .ToReadOnlyCollection(_groupChannels);

        /// <summary>All cached guilds.</summary>
        internal IReadOnlyCollection<SocketGuild> Guilds => _guilds.ToReadOnlyCollection();

        /// <summary>All cached global users.</summary>
        internal IReadOnlyCollection<SocketGlobalUser> Users => _users.ToReadOnlyCollection();

        /// <summary>
        /// DM channels followed by group channels. The reported count is the sum of the two
        /// underlying collection sizes at the time it is read.
        /// </summary>
        internal IReadOnlyCollection<IPrivateChannel> PrivateChannels =>
            _dmChannels
                .Select(x => x.Value as IPrivateChannel)
                .Concat(_groupChannels.Select(x => GetChannel(x) as IPrivateChannel))
                .ToReadOnlyCollection(() => _dmChannels.Count + _groupChannels.Count);

        /// <summary>
        /// Creates the store, presizing each collection from the expected number of guilds and
        /// DM channels, plus a 5% growth buffer.
        /// </summary>
        /// <param name="guildCount">Expected number of guilds.</param>
        /// <param name="dmChannelCount">Expected number of DM channels.</param>
        /// <exception cref="ArgumentOutOfRangeException">
        /// Thrown by <see cref="ConcurrentDictionary{TKey,TValue}"/> if a negative count yields a
        /// negative initial capacity.
        /// </exception>
        public DataStore(int guildCount, int dmChannelCount)
        {
            double estimatedChannelCount = guildCount * AverageChannelsPerGuild + dmChannelCount;
            double estimatedUsersCount = guildCount * AverageUsersPerGuild;

            _channels = new ConcurrentDictionary<ulong, SocketChannel>(
                CollectionConcurrencyLevel, (int)(estimatedChannelCount * CollectionMultiplier));
            _dmChannels = new ConcurrentDictionary<ulong, SocketDMChannel>(
                CollectionConcurrencyLevel, (int)(dmChannelCount * CollectionMultiplier));
            _guilds = new ConcurrentDictionary<ulong, SocketGuild>(
                CollectionConcurrencyLevel, (int)(guildCount * CollectionMultiplier));
            _users = new ConcurrentDictionary<ulong, SocketGlobalUser>(
                CollectionConcurrencyLevel, (int)(estimatedUsersCount * CollectionMultiplier));
            _groupChannels = new ConcurrentHashSet<ulong>(
                CollectionConcurrencyLevel, (int)(EstimatedGroupChannelCount * CollectionMultiplier));
        }

        /// <summary>Returns the cached channel with the given id, or null if none is cached.</summary>
        internal SocketChannel GetChannel(ulong id)
        {
            SocketChannel channel;
            return _channels.TryGetValue(id, out channel) ? channel : null;
        }

        /// <summary>
        /// Returns the cached DM channel whose recipient has the given user id, or null if none.
        /// </summary>
        internal SocketDMChannel GetDMChannel(ulong userId)
        {
            SocketDMChannel channel;
            return _dmChannels.TryGetValue(userId, out channel) ? channel : null;
        }

        /// <summary>
        /// Adds or replaces a channel. DM channels are also indexed by recipient id, and group
        /// channels have their id recorded in the group set.
        /// </summary>
        internal void AddChannel(SocketChannel channel)
        {
            _channels[channel.Id] = channel;

            var dmChannel = channel as SocketDMChannel;
            if (dmChannel != null)
            {
                _dmChannels[dmChannel.Recipient.Id] = dmChannel;
                return;
            }

            var groupChannel = channel as SocketGroupChannel;
            if (groupChannel != null)
                _groupChannels.TryAdd(groupChannel.Id);
        }

        /// <summary>
        /// Removes a channel and its DM/group index entry.
        /// </summary>
        /// <returns>The removed channel, or null if no channel with that id was cached.</returns>
        internal SocketChannel RemoveChannel(ulong id)
        {
            SocketChannel channel;
            if (!_channels.TryRemove(id, out channel))
                return null;

            var dmChannel = channel as SocketDMChannel;
            if (dmChannel != null)
            {
                SocketDMChannel ignored;
                _dmChannels.TryRemove(dmChannel.Recipient.Id, out ignored);
            }
            else if (channel is SocketGroupChannel)
            {
                _groupChannels.TryRemove(id);
            }
            return channel;
        }

        /// <summary>Returns the cached guild with the given id, or null if none is cached.</summary>
        internal SocketGuild GetGuild(ulong id)
        {
            SocketGuild guild;
            return _guilds.TryGetValue(id, out guild) ? guild : null;
        }

        /// <summary>Adds or replaces a guild, keyed by its id.</summary>
        internal void AddGuild(SocketGuild guild)
        {
            _guilds[guild.Id] = guild;
        }

        /// <summary>Removes a guild.</summary>
        /// <returns>The removed guild, or null if no guild with that id was cached.</returns>
        internal SocketGuild RemoveGuild(ulong id)
        {
            SocketGuild guild;
            return _guilds.TryRemove(id, out guild) ? guild : null;
        }

        /// <summary>Returns the cached global user with the given id, or null if none is cached.</summary>
        internal SocketGlobalUser GetUser(ulong id)
        {
            SocketGlobalUser user;
            return _users.TryGetValue(id, out user) ? user : null;
        }

        /// <summary>
        /// Returns the cached user with the given id, creating and caching one via
        /// <paramref name="userFactory"/> if absent.
        /// </summary>
        /// <remarks>
        /// Per <see cref="ConcurrentDictionary{TKey,TValue}.GetOrAdd(TKey, Func{TKey,TValue})"/>,
        /// the factory may run more than once under contention, and only one result is stored.
        /// </remarks>
        internal SocketGlobalUser GetOrAddUser(ulong id, Func<ulong, SocketGlobalUser> userFactory)
        {
            return _users.GetOrAdd(id, userFactory);
        }

        /// <summary>Removes a global user.</summary>
        /// <returns>The removed user, or null if no user with that id was cached.</returns>
        internal SocketGlobalUser RemoveUser(ulong id)
        {
            SocketGlobalUser user;
            return _users.TryRemove(id, out user) ? user : null;
        }
    }
}