Fix channel being null in DMs on Interactions (#2098)

This commit is contained in:
Quin Lynch
2022-02-11 21:43:32 -04:00
committed by GitHub
parent 0352374e74
commit 7e1b8c9db0
12 changed files with 111 additions and 64 deletions

View File

@@ -62,6 +62,11 @@ namespace Discord
/// </remarks> /// </remarks>
string GuildLocale { get; } string GuildLocale { get; }
/// <summary>
/// Gets whether or not this interaction was executed in a dm channel.
/// </summary>
bool IsDMInteraction { get; }
/// <summary> /// <summary>
/// Responds to an Interaction with type <see cref="InteractionResponseType.ChannelMessageWithSource"/>. /// Responds to an Interaction with type <see cref="InteractionResponseType.ChannelMessageWithSource"/>.
/// </summary> /// </summary>

View File

@@ -57,11 +57,9 @@ namespace Discord.Interactions
bool isValid = false; bool isValid = false;
if ((Contexts & ContextType.Guild) != 0) if ((Contexts & ContextType.Guild) != 0)
isValid = context.Channel is IGuildChannel; isValid = !context.Interaction.IsDMInteraction;
if ((Contexts & ContextType.DM) != 0) if ((Contexts & ContextType.DM) != 0 && (Contexts & ContextType.Group) != 0)
isValid = isValid || context.Channel is IDMChannel; isValid = context.Interaction.IsDMInteraction;
if ((Contexts & ContextType.Group) != 0)
isValid = isValid || context.Channel is IGroupChannel;
if (isValid) if (isValid)
return Task.FromResult(PreconditionResult.FromSuccess()); return Task.FromResult(PreconditionResult.FromSuccess());

View File

@@ -61,6 +61,9 @@ namespace Discord.Rest
/// <inheritdoc/> /// <inheritdoc/>
public bool HasResponded { get; protected set; } public bool HasResponded { get; protected set; }
/// <inheritdoc/>
public bool IsDMInteraction { get; private set; }
internal RestInteraction(BaseDiscordClient discord, ulong id) internal RestInteraction(BaseDiscordClient discord, ulong id)
: base(discord, id) : base(discord, id)
{ {
@@ -108,6 +111,8 @@ namespace Discord.Rest
internal virtual async Task UpdateAsync(DiscordRestClient discord, Model model) internal virtual async Task UpdateAsync(DiscordRestClient discord, Model model)
{ {
IsDMInteraction = !model.GuildId.IsSpecified;
Data = model.Data.IsSpecified Data = model.Data.IsSpecified
? model.Data.Value ? model.Data.Value
: null; : null;

View File

@@ -2233,24 +2233,42 @@ namespace Discord.WebSocket
var data = (payload as JToken).ToObject<API.Interaction>(_serializer); var data = (payload as JToken).ToObject<API.Interaction>(_serializer);
SocketChannel channel = null; var guild = data.GuildId.IsSpecified ? GetGuild(data.GuildId.Value) : null;
if(data.ChannelId.IsSpecified)
{
channel = State.GetChannel(data.ChannelId.Value);
}
else if (data.User.IsSpecified)
{
channel = State.GetDMChannel(data.User.Value.Id);
}
var guild = (channel as SocketGuildChannel)?.Guild;
if (guild != null && !guild.IsSynced) if (guild != null && !guild.IsSynced)
{ {
await UnsyncedGuildAsync(type, guild.Id).ConfigureAwait(false); await UnsyncedGuildAsync(type, guild.Id).ConfigureAwait(false);
return; return;
} }
var interaction = SocketInteraction.Create(this, data, channel as ISocketMessageChannel); SocketUser user = data.User.IsSpecified
? State.GetOrAddUser(data.User.Value.Id, (_) => SocketGlobalUser.Create(this, State, data.User.Value))
: guild.AddOrUpdateUser(data.Member.Value);
SocketChannel channel = null;
if(data.ChannelId.IsSpecified)
{
channel = State.GetChannel(data.ChannelId.Value);
if (channel == null)
{
if (!data.GuildId.IsSpecified) // assume it is a DM
{
channel = CreateDMChannel(data.ChannelId.Value, user, State);
}
else
{
await UnknownChannelAsync(type, data.ChannelId.Value).ConfigureAwait(false);
return;
}
}
}
else if (data.User.IsSpecified)
{
channel = State.GetDMChannel(data.User.Value.Id);
}
var interaction = SocketInteraction.Create(this, data, channel as ISocketMessageChannel, user);
await TimedInvokeAsync(_interactionCreatedEvent, nameof(InteractionCreated), interaction).ConfigureAwait(false); await TimedInvokeAsync(_interactionCreatedEvent, nameof(InteractionCreated), interaction).ConfigureAwait(false);

View File

@@ -13,8 +13,8 @@ namespace Discord.WebSocket
/// </summary> /// </summary>
public new SocketMessageCommandData Data { get; } public new SocketMessageCommandData Data { get; }
internal SocketMessageCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal SocketMessageCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
: base(client, model, channel) : base(client, model, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -27,9 +27,9 @@ namespace Discord.WebSocket
Data = SocketMessageCommandData.Create(client, dataModel, model.Id, guildId); Data = SocketMessageCommandData.Create(client, dataModel, model.Id, guildId);
} }
internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketMessageCommand(client, model, channel); var entity = new SocketMessageCommand(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }

View File

@@ -13,8 +13,8 @@ namespace Discord.WebSocket
/// </summary> /// </summary>
public new SocketUserCommandData Data { get; } public new SocketUserCommandData Data { get; }
internal SocketUserCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal SocketUserCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
: base(client, model, channel) : base(client, model, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -27,9 +27,9 @@ namespace Discord.WebSocket
Data = SocketUserCommandData.Create(client, dataModel, model.Id, guildId); Data = SocketUserCommandData.Create(client, dataModel, model.Id, guildId);
} }
internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketUserCommand(client, model, channel); var entity = new SocketUserCommand(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }

View File

@@ -28,8 +28,8 @@ namespace Discord.WebSocket
private object _lock = new object(); private object _lock = new object();
public override bool HasResponded { get; internal set; } = false; public override bool HasResponded { get; internal set; } = false;
internal SocketMessageComponent(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal SocketMessageComponent(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
: base(client, model.Id, channel) : base(client, model.Id, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -38,9 +38,9 @@ namespace Discord.WebSocket
Data = new SocketMessageComponentData(dataModel); Data = new SocketMessageComponentData(dataModel);
} }
internal new static SocketMessageComponent Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal new static SocketMessageComponent Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketMessageComponent(client, model, channel); var entity = new SocketMessageComponent(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }

View File

@@ -22,8 +22,8 @@ namespace Discord.WebSocket
/// <value></value> /// <value></value>
public new SocketModalData Data { get; set; } public new SocketModalData Data { get; set; }
internal SocketModal(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel) internal SocketModal(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel, SocketUser user)
: base(client, model.Id, channel) : base(client, model.Id, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -32,9 +32,9 @@ namespace Discord.WebSocket
Data = new SocketModalData(dataModel); Data = new SocketModalData(dataModel);
} }
internal new static SocketModal Create(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel) internal new static SocketModal Create(DiscordSocketClient client, ModelBase model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketModal(client, model, channel); var entity = new SocketModal(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }

View File

@@ -21,8 +21,8 @@ namespace Discord.WebSocket
public override bool HasResponded { get; internal set; } public override bool HasResponded { get; internal set; }
private object _lock = new object(); private object _lock = new object();
internal SocketAutocompleteInteraction(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal SocketAutocompleteInteraction(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
: base(client, model.Id, channel) : base(client, model.Id, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -32,9 +32,9 @@ namespace Discord.WebSocket
Data = new SocketAutocompleteInteractionData(dataModel); Data = new SocketAutocompleteInteractionData(dataModel);
} }
internal new static SocketAutocompleteInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal new static SocketAutocompleteInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketAutocompleteInteraction(client, model, channel); var entity = new SocketAutocompleteInteraction(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }

View File

@@ -13,8 +13,8 @@ namespace Discord.WebSocket
/// </summary> /// </summary>
public new SocketSlashCommandData Data { get; } public new SocketSlashCommandData Data { get; }
internal SocketSlashCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal SocketSlashCommand(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
: base(client, model, channel) : base(client, model, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -27,9 +27,9 @@ namespace Discord.WebSocket
Data = SocketSlashCommandData.Create(client, dataModel, guildId); Data = SocketSlashCommandData.Create(client, dataModel, guildId);
} }
internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketSlashCommand(client, model, channel); var entity = new SocketSlashCommand(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }

View File

@@ -35,8 +35,8 @@ namespace Discord.WebSocket
private object _lock = new object(); private object _lock = new object();
internal SocketCommandBase(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal SocketCommandBase(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
: base(client, model.Id, channel) : base(client, model.Id, channel, user)
{ {
var dataModel = model.Data.IsSpecified var dataModel = model.Data.IsSpecified
? (DataModel)model.Data.Value ? (DataModel)model.Data.Value
@@ -49,9 +49,9 @@ namespace Discord.WebSocket
Data = SocketCommandBaseData.Create(client, dataModel, model.Id, guildId); Data = SocketCommandBaseData.Create(client, dataModel, model.Id, guildId);
} }
internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal new static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
var entity = new SocketCommandBase(client, model, channel); var entity = new SocketCommandBase(client, model, channel, user);
entity.Update(model); entity.Update(model);
return entity; return entity;
} }

View File

@@ -5,6 +5,7 @@ using Model = Discord.API.Interaction;
using DataModel = Discord.API.ApplicationCommandInteractionData; using DataModel = Discord.API.ApplicationCommandInteractionData;
using System.IO; using System.IO;
using System.Collections.Generic; using System.Collections.Generic;
using Discord.Net;
namespace Discord.WebSocket namespace Discord.WebSocket
{ {
@@ -72,17 +73,23 @@ namespace Discord.WebSocket
public bool IsValidToken public bool IsValidToken
=> InteractionHelper.CanRespondOrFollowup(this); => InteractionHelper.CanRespondOrFollowup(this);
internal SocketInteraction(DiscordSocketClient client, ulong id, ISocketMessageChannel channel) /// <inheritdoc/>
public bool IsDMInteraction { get; private set; }
private ulong? _channelId;
internal SocketInteraction(DiscordSocketClient client, ulong id, ISocketMessageChannel channel, SocketUser user)
: base(client, id) : base(client, id)
{ {
Channel = channel; Channel = channel;
User = user;
CreatedAt = client.UseInteractionSnowflakeDate CreatedAt = client.UseInteractionSnowflakeDate
? SnowflakeUtils.FromSnowflake(Id) ? SnowflakeUtils.FromSnowflake(Id)
: DateTime.UtcNow; : DateTime.UtcNow;
} }
internal static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel) internal static SocketInteraction Create(DiscordSocketClient client, Model model, ISocketMessageChannel channel, SocketUser user)
{ {
if (model.Type == InteractionType.ApplicationCommand) if (model.Type == InteractionType.ApplicationCommand)
{ {
@@ -95,27 +102,31 @@ namespace Discord.WebSocket
return dataModel.Type switch return dataModel.Type switch
{ {
ApplicationCommandType.Slash => SocketSlashCommand.Create(client, model, channel), ApplicationCommandType.Slash => SocketSlashCommand.Create(client, model, channel, user),
ApplicationCommandType.Message => SocketMessageCommand.Create(client, model, channel), ApplicationCommandType.Message => SocketMessageCommand.Create(client, model, channel, user),
ApplicationCommandType.User => SocketUserCommand.Create(client, model, channel), ApplicationCommandType.User => SocketUserCommand.Create(client, model, channel, user),
_ => null _ => null
}; };
} }
if (model.Type == InteractionType.MessageComponent) if (model.Type == InteractionType.MessageComponent)
return SocketMessageComponent.Create(client, model, channel); return SocketMessageComponent.Create(client, model, channel, user);
if (model.Type == InteractionType.ApplicationCommandAutocomplete) if (model.Type == InteractionType.ApplicationCommandAutocomplete)
return SocketAutocompleteInteraction.Create(client, model, channel); return SocketAutocompleteInteraction.Create(client, model, channel, user);
if (model.Type == InteractionType.ModalSubmit) if (model.Type == InteractionType.ModalSubmit)
return SocketModal.Create(client, model, channel); return SocketModal.Create(client, model, channel, user);
return null; return null;
} }
internal virtual void Update(Model model) internal virtual void Update(Model model)
{ {
IsDMInteraction = !model.GuildId.IsSpecified;
_channelId = model.ChannelId.ToNullable();
Data = model.Data.IsSpecified Data = model.Data.IsSpecified
? model.Data.Value ? model.Data.Value
: null; : null;
@@ -123,18 +134,6 @@ namespace Discord.WebSocket
Version = model.Version; Version = model.Version;
Type = model.Type; Type = model.Type;
if (User == null)
{
if (model.Member.IsSpecified && model.GuildId.IsSpecified)
{
User = SocketGuildUser.Create(Discord.State.GetGuild(model.GuildId.Value), Discord.State, model.Member.Value);
}
else
{
User = SocketGlobalUser.Create(Discord, Discord.State, model.User.Value);
}
}
UserLocale = model.UserLocale.IsSpecified UserLocale = model.UserLocale.IsSpecified
? model.UserLocale.Value ? model.UserLocale.Value
: null; : null;
@@ -399,6 +398,28 @@ namespace Discord.WebSocket
public abstract Task RespondWithModalAsync(Modal modal, RequestOptions options = null); public abstract Task RespondWithModalAsync(Modal modal, RequestOptions options = null);
#endregion #endregion
/// <summary>
/// Attepts to get the channel this interaction was executed in.
/// </summary>
/// <param name="options">The request options for this <see langword="async"/> request.</param>
/// <returns>
/// A task that represents the asynchronous operation of fetching the channel.
/// </returns>
public async ValueTask<IMessageChannel> GetChannelAsync(RequestOptions options = null)
{
if (Channel != null)
return Channel;
if (!_channelId.HasValue)
return null;
try
{
return (IMessageChannel)await Discord.GetChannelAsync(_channelId.Value, options).ConfigureAwait(false);
}
catch(HttpException ex) when (ex.DiscordCode == DiscordErrorCode.MissingPermissions) { return null; } // bot can't view that channel, return null instead of throwing.
}
#region IDiscordInteraction #region IDiscordInteraction
/// <inheritdoc/> /// <inheritdoc/>
IUser IDiscordInteraction.User => User; IUser IDiscordInteraction.User => User;