Cleaned up permission checks, added Permissions.Has(enum) methods

This commit is contained in:
RogueException
2016-08-09 18:18:50 -03:00
parent e452aa9662
commit dcb603acd7
4 changed files with 40 additions and 35 deletions

View File

@@ -1,6 +1,4 @@
using System; using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Discord.Commands namespace Discord.Commands
@@ -8,35 +6,36 @@ namespace Discord.Commands
[Flags] [Flags]
public enum ContextType public enum ContextType
{ {
Guild = 1, // 01 Guild = 0x01,
DM = 2 // 10 DM = 0x02,
Group = 0x04
} }
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)] [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = true)]
public class RequireContextAttribute : PreconditionAttribute public class RequireContextAttribute : PreconditionAttribute
{ {
public ContextType Context { get; set; } public ContextType Contexts { get; }
public RequireContextAttribute(ContextType context) public RequireContextAttribute(ContextType contexts)
{ {
Context = context; Contexts = contexts;
} }
public override Task<PreconditionResult> CheckPermissions(IMessage context, Command executingCommand, object moduleInstance) public override Task<PreconditionResult> CheckPermissions(IMessage context, Command executingCommand, object moduleInstance)
{ {
var validContext = false; bool isValid = false;
if (Context.HasFlag(ContextType.Guild)) if ((Contexts & ContextType.Guild) != 0)
validContext = validContext || context.Channel is IGuildChannel; isValid = isValid || context.Channel is IGuildChannel;
if ((Contexts & ContextType.DM) != 0)
isValid = isValid || context.Channel is IDMChannel;
if ((Contexts & ContextType.Group) != 0)
isValid = isValid || context.Channel is IGroupChannel;
if (Context.HasFlag(ContextType.DM)) if (isValid)
validContext = validContext || context.Channel is IDMChannel;
if (validContext)
return Task.FromResult(PreconditionResult.FromSuccess()); return Task.FromResult(PreconditionResult.FromSuccess());
else else
return Task.FromResult(PreconditionResult.FromError($"Invalid context for command; accepted contexts: {Context}")); return Task.FromResult(PreconditionResult.FromError($"Invalid context for command; accepted contexts: {Contexts}"));
} }
} }
} }

View File

@@ -1,6 +1,4 @@
using System; using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Discord.Commands.Attributes.Preconditions namespace Discord.Commands.Attributes.Preconditions
@@ -8,15 +6,14 @@ namespace Discord.Commands.Attributes.Preconditions
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)] [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)]
public class RequirePermission : PreconditionAttribute public class RequirePermission : PreconditionAttribute
{ {
public GuildPermission? GuildPermission { get; set; } public GuildPermission? GuildPermission { get; }
public ChannelPermission? ChannelPermission { get; set; } public ChannelPermission? ChannelPermission { get; }
public RequirePermission(GuildPermission permission) public RequirePermission(GuildPermission permission)
{ {
GuildPermission = permission; GuildPermission = permission;
ChannelPermission = null; ChannelPermission = null;
} }
public RequirePermission(ChannelPermission permission) public RequirePermission(ChannelPermission permission)
{ {
ChannelPermission = permission; ChannelPermission = permission;
@@ -25,25 +22,28 @@ namespace Discord.Commands.Attributes.Preconditions
public override Task<PreconditionResult> CheckPermissions(IMessage context, Command executingCommand, object moduleInstance) public override Task<PreconditionResult> CheckPermissions(IMessage context, Command executingCommand, object moduleInstance)
{ {
if (!(context.Channel is IGuildChannel)) var guildUser = context.Author as IGuildUser;
return Task.FromResult(PreconditionResult.FromError("Command must be used in a guild channel"));
var author = context.Author as IGuildUser;
if (GuildPermission.HasValue) if (GuildPermission.HasValue)
{ {
var guildPerms = author.GuildPermissions.ToList(); if (guildUser == null)
if (!guildPerms.Contains(GuildPermission.Value)) return Task.FromResult(PreconditionResult.FromError("Command must be used in a guild channel"));
return Task.FromResult(PreconditionResult.FromError($"User is missing guild permission {GuildPermission.Value}")); if (!guildUser.GuildPermissions.Has(GuildPermission.Value))
return Task.FromResult(PreconditionResult.FromError($"Command requires guild permission {GuildPermission.Value}"));
} }
if (ChannelPermission.HasValue) if (ChannelPermission.HasValue)
{ {
var channel = context.Channel as IGuildChannel; var guildChannel = context.Channel as IGuildChannel;
var channelPerms = author.GetPermissions(channel).ToList();
if (!channelPerms.Contains(ChannelPermission.Value)) ChannelPermissions perms;
return Task.FromResult(PreconditionResult.FromError($"User is missing channel permission {ChannelPermission.Value}")); if (guildChannel != null)
perms = guildUser.GetPermissions(guildChannel);
else
perms = ChannelPermissions.All(guildChannel);
if (!perms.Has(ChannelPermission.Value))
return Task.FromResult(PreconditionResult.FromError($"Command requires channel permission {ChannelPermission.Value}"));
} }
return Task.FromResult(PreconditionResult.FromSuccess()); return Task.FromResult(PreconditionResult.FromSuccess());

View File

@@ -8,9 +8,10 @@ namespace Discord
public struct ChannelPermissions public struct ChannelPermissions
{ {
//TODO: C#7 Candidate for binary literals //TODO: C#7 Candidate for binary literals
private static ChannelPermissions _allDM { get; } = new ChannelPermissions(Convert.ToUInt64("00010000000000111111110000011001", 2));
private static ChannelPermissions _allText { get; } = new ChannelPermissions(Convert.ToUInt64("00000000000000011100110000000000", 2)); private static ChannelPermissions _allText { get; } = new ChannelPermissions(Convert.ToUInt64("00000000000000011100110000000000", 2));
private static ChannelPermissions _allVoice { get; } = new ChannelPermissions(Convert.ToUInt64("00010011111100000000000000011001", 2)); private static ChannelPermissions _allVoice { get; } = new ChannelPermissions(Convert.ToUInt64("00010011111100000000000000011001", 2));
private static ChannelPermissions _allDM { get; } = new ChannelPermissions(Convert.ToUInt64("00010000000000111111110000011001", 2));
private static ChannelPermissions _allGroup { get; } = new ChannelPermissions(Convert.ToUInt64("00010000000000111111110000011001", 2));
/// <summary> Gets a blank ChannelPermissions that grants no permissions. </summary> /// <summary> Gets a blank ChannelPermissions that grants no permissions. </summary>
public static ChannelPermissions None { get; } = new ChannelPermissions(); public static ChannelPermissions None { get; } = new ChannelPermissions();
@@ -21,6 +22,7 @@ namespace Discord
if (channel is ITextChannel) return _allText; if (channel is ITextChannel) return _allText;
if (channel is IVoiceChannel) return _allVoice; if (channel is IVoiceChannel) return _allVoice;
if (channel is IDMChannel) return _allDM; if (channel is IDMChannel) return _allDM;
if (channel is IGroupChannel) return _allGroup;
throw new ArgumentException("Unknown channel type", nameof(channel)); throw new ArgumentException("Unknown channel type", nameof(channel));
} }
@@ -118,6 +120,8 @@ namespace Discord
embedLinks, attachFiles, readMessageHistory, mentionEveryone, connect, speak, muteMembers, deafenMembers, embedLinks, attachFiles, readMessageHistory, mentionEveryone, connect, speak, muteMembers, deafenMembers,
moveMembers, useVoiceActivation, managePermissions); moveMembers, useVoiceActivation, managePermissions);
public bool Has(ChannelPermission permission) => Permissions.GetValue(RawValue, permission);
public List<ChannelPermission> ToList() public List<ChannelPermission> ToList()
{ {
var perms = new List<ChannelPermission>(); var perms = new List<ChannelPermission>();

View File

@@ -130,6 +130,8 @@ namespace Discord
sendMessages, sendTTSMessages, manageMessages, embedLinks, attachFiles, mentionEveryone, connect, speak, muteMembers, deafenMembers, sendMessages, sendTTSMessages, manageMessages, embedLinks, attachFiles, mentionEveryone, connect, speak, muteMembers, deafenMembers,
moveMembers, useVoiceActivation, changeNickname, manageNicknames, manageRoles); moveMembers, useVoiceActivation, changeNickname, manageNicknames, manageRoles);
public bool Has(GuildPermission permission) => Permissions.GetValue(RawValue, permission);
public List<GuildPermission> ToList() public List<GuildPermission> ToList()
{ {
var perms = new List<GuildPermission>(); var perms = new List<GuildPermission>();