"""
PokeGambler - A Pokemon themed gambling bot for Discord.
Copyright (C) 2021 Harshith Thota
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
----------------------------------------------------------------------------
Module which extends Discord.py to allow for custom application commands.
"""
# pylint: disable=no-member, too-many-lines
from __future__ import annotations

import inspect
import itertools
from typing import (
    TYPE_CHECKING, Any, Callable, Coroutine,
    Dict, Iterator, List, Optional, Tuple, Type, Union
)

import discord
from cachetools import TTLCache
from discord.app_commands import Choice
from discord.http import Route

from scripts.base.enums import OptionTypes

from ..helpers.parsers import CustomRstParser
from ..helpers.utils import get_modules
from .components import (
    AppCommand, ContextMenu, GuildEvent,
    GuildEventType, SlashCommand
)

if TYPE_CHECKING:
    from bot import PokeGambler
class OverriddenChannel:
    """
    A modified version of :class:`discord.TextChannel` which
    overrides :meth:`discord.TextChannel.send` to
    suppress stray deferred ephemeral messages.

    Attributes which are not defined on the wrapper itself are
    looked up on the wrapped channel.

    :param parent: The original interaction class.
    :type parent: :class:`discord.Interaction`
    """
    def __init__(self, parent: discord.Interaction):
        self.parent = parent
        self.channel = parent.channel

    def __getattr__(self, item: str) -> Any:
        # __getattr__ only runs after normal lookup has failed, so the
        # instance dict never holds ``item`` at this point.
        # If ``channel`` itself is missing (instance built without
        # __init__, e.g. by copy/pickle), bail out instead of recursing.
        if item == "channel":
            raise AttributeError(item)
        return getattr(self.channel, item)

    def __str__(self) -> str:
        return str(self.channel)

    def __eq__(self, other: object) -> bool:
        # Accept anything exposing ``.channel.id`` (as before) and also
        # a bare channel-like object exposing ``.id``.
        other_id = getattr(getattr(other, "channel", other), "id", None)
        if other_id is None:
            return NotImplemented
        return self.channel.id == other_id

    def __hash__(self) -> int:
        # Defining __eq__ alone would make instances unhashable.
        return hash(self.channel.id)

    async def send(self, *args, **kwargs):
        """
        :meth:`discord.TextChannel.send` like behavior for
        :class:`~scripts.base.handlers.OverriddenChannel`.

        :return: Whatever the wrapped channel's ``send`` returns.
        """
        return await self.channel.send(*args, **kwargs)
class CustomInteraction:
    """
    An overridden class of :class:`discord.Interaction` to
    provide hybrid behavior between interaction and Message.

    Attributes which are not defined on the wrapper itself are
    looked up on the wrapped interaction.

    :param interaction: The original interaction class.
    :type interaction: :class:`discord.Interaction`
    """
    def __init__(self, interaction: discord.Interaction):
        self.interaction = interaction
        self.author = interaction.user
        self.channel = OverriddenChannel(interaction)

    def __getattr__(self, item: str) -> Any:
        # Only reached when normal lookup fails. Guard against the
        # wrapped object being absent (instance built without __init__)
        # so that the delegation below cannot recurse forever.
        if item == "interaction":
            raise AttributeError(item)
        return getattr(self.interaction, item)

    def __str__(self) -> str:
        return str(self.interaction)

    def __eq__(self, other: object) -> bool:
        # Duck-typed like the original: anything carrying
        # ``.interaction.id`` is comparable; everything else is not.
        other_id = getattr(getattr(other, "interaction", None), "id", None)
        if other_id is None:
            return NotImplemented
        return self.interaction.id == other_id

    def __hash__(self) -> int:
        # Defining __eq__ alone would make instances unhashable.
        return hash(self.interaction.id)

    async def reply(self, *args, **kwargs):
        """
        :meth:`discord.Message.reply` like behavior for Interactions.

        Sends the initial response when none was sent yet, otherwise
        falls back to a followup message.

        :return: The message produced by the response/followup.
        """
        try:
            await self.interaction.response.send_message(
                *args, **kwargs
            )
            msg = await self.interaction.original_message()
        except discord.InteractionResponded:
            msg = await self.interaction.followup.send(
                *args, **kwargs
            )
        return msg

    async def add_reaction(self, reaction: str):
        """:meth:`discord.Message.add_reaction` like behavior for Interactions.

        The reaction is rendered as an ephemeral message holding an
        ``EmojiButton`` view.

        :param reaction: The reaction to add.
        :type reaction: :class:`str`
        """
        # pylint: disable=import-outside-toplevel
        from ..base.views import EmojiButton

        payload = {
            "content": '\u200B',
            "view": EmojiButton(reaction),
            "ephemeral": True
        }
        try:
            await self.interaction.response.send_message(**payload)
        except discord.InteractionResponded:
            await self.interaction.followup.send(**payload)
class CommandListing(list):
    """An extending hybrid list to hold instances of
    :class:`~.components.AppCommand`.

    Supports accessing a command by name as well as by position.

    :param handler: The command handler object.
    :type handler: Union[SlashHandler, ContextHandler]
    """
    def __init__(self, handler: Union[SlashHandler, ContextHandler]):
        super().__init__()
        self.handler = handler

    def __iter__(self) -> Iterator[AppCommand]:
        # Iteration silently skips anything which is not an AppCommand.
        for item in super().__iter__():
            if isinstance(item, AppCommand):
                yield item

    def __getitem__(
        self, item: Union[str, int, slice]
    ) -> Optional[AppCommand]:
        # Positional access keeps the regular list semantics, so that
        # ``listing[-1]`` returns the last element instead of None.
        if isinstance(item, (int, slice)):
            return super().__getitem__(item)
        return next(
            (
                command
                for command in self
                if command.name == item
            ), None
        )

    def __setitem__(
        self, key: Union[str, int, slice],
        value: Union[AppCommand, Dict]
    ) -> None:
        if isinstance(value, dict):
            # from_dict takes the dictionary positionally everywhere
            # else in this module (see append/refresh).
            value = self.Component.from_dict(value)
        if isinstance(key, str):
            key = self.__index_of(key)
        super().__setitem__(key, value)

    def __contains__(self, item: Union[str, AppCommand, Dict]) -> bool:
        if isinstance(item, str):
            return item in self.names
        if isinstance(item, dict):
            item = self.Component.from_dict(item)
        return any(
            command.id == item.id
            for command in self
        )

    def __index_of(self, name: str) -> int:
        """Returns the raw list index of the command called ``name``.

        :raises KeyError: When no stored command has that name.
        """
        for index, existing in enumerate(list.__iter__(self)):
            if isinstance(existing, AppCommand) and existing.name == name:
                return index
        raise KeyError(name)

    # pylint: disable=invalid-name
    @property
    def Component(self) -> Type[AppCommand]:
        """Returns the component class of the handler.

        :return: The component class of the handler.
        :rtype: Type[AppCommand]
        """
        return self.handler.component_class

    def append(self, command: Union[AppCommand, Dict]):
        """Transform/Ensure that the command is a
        :class:`~.components.AppCommand` and append it to the list.

        :param command: The command to add to the list.
        :type command: Union[:class:`~.components.AppCommand`, Dict]
        """
        if isinstance(command, dict):
            command = self.Component.from_dict(command)
        super().append(command)

    def remove(self, command: Union[str, Dict, AppCommand]):
        """
        | Remove a command from the list.
        | Supports lookups by name and ID.
        | Does nothing when no stored command matches.

        :param command: The command to remove from the list.
        :type command: Union[str, Dict, :class:`~.components.AppCommand`]
        """
        # Walk the raw list so that the index used for deletion always
        # refers to the element that was actually inspected, even when
        # non-AppCommand items are stored.
        for index, item in enumerate(list.__iter__(self)):
            if not isinstance(item, AppCommand):
                continue
            if any([
                isinstance(command, str) and item.name == command,
                isinstance(command, dict) and item.id == command['id'],
                isinstance(
                    command, AppCommand
                ) and item.id == command.id,
            ]):
                del self[index]
                return

    @property
    def names(self) -> List[str]:
        """
        :return: Stored command names.
        :rtype: List
        """
        return [command.name for command in self]

    async def refresh(self, fresh=False, **kwargs):
        """
        Refreshes the command list from the handler's GET route.

        :param fresh: Whether to empty the list before fetching.
        :type fresh: bool
        :param kwargs: Keyword arguments to pass to the route.
        :type kwargs: Dict
        """
        if fresh:
            self.clear()
        route = self.handler.get_route('GET', **kwargs)
        commands = await self.handler.http.request(route)
        supported = self.Component.types()
        # Track the names once instead of rebuilding them per command.
        known = set(self.names)
        for command in commands:
            if command['type'] not in supported:
                continue
            for option in command.get('options', ()):
                option.setdefault('autocomplete', False)
            if command['name'] in known:
                continue
            self.append(self.Component.from_dict(command))
            known.add(command['name'])
def command_to_dict(
    command: Callable,
    local: Optional[bool] = False
) -> Dict[str, Any]:
    """Convert a command to a dictionary.

    The description is the parsed docstring description when it is
    non-empty and fits in 100 characters, otherwise the first non-blank
    docstring line (cut to 100 characters).

    :param command: Command to convert.
    :type command: Callable
    :param local: Whether to use local or official server.
    :type local: Optional[bool]
    :return: Dictionary representation of the command.
    :rtype: Dict[str, Any]
    """
    docstring = command.__doc__ or ""
    # Fallback description: first line that actually holds text, so a
    # docstring opening with a newline does not yield an empty string.
    description = next(
        (
            line.strip()
            for line in docstring.splitlines()
            if line.strip()
        ), ""
    )
    with CustomRstParser() as rst_parser:
        rst_parser.parse(docstring)
        meta = rst_parser.meta
        options = rst_parser.parsed_params
    desc = meta.description
    if desc and len(desc) <= 100:
        description = desc
    # Same 100 character limit which the check above already enforces.
    description = description[:100]
    # Required options first; the sort is stable so the relative
    # order inside each group is kept.
    options = sorted(
        options,
        key=lambda opt: not opt.get('required', False)
    )
    # Fix for Discord API not supporting Default Option
    for option in options:
        option.pop('default', None)
    cmd_name = command.__name__.replace("cmd_", "")
    if local:
        cmd_name += "_"
    return {
        "name": cmd_name,
        "description": description,
        "type": 1,
        "options": options
    }
class SlashHandler:
    """
    | Class which handles custom slash commands.
    | It is an extension of Discord.py's http module.

    :param ctx: The pokegambler client.
    :type ctx: :class:`bot.PokeGambler`
    """
    def __init__(self, ctx: PokeGambler):
        self.ctx = ctx
        self.http = ctx.http
        self.component_class = SlashCommand
        self.registered = CommandListing(self)
        self.official_roles = {}
        self.update_counter = 0

    async def add_slash_commands(self, **kwargs):
        """Add all slash commands to the guild/globally.

        Does nothing unless the client runs in prod or local mode.

        :param kwargs: Keyword arguments to pass to the route.
        :type kwargs: Dict
        """
        if not (self.ctx.is_prod or self.ctx.is_local):
            return
        if not self.official_roles:
            guild = self.ctx.get_guild(self.ctx.official_server)
            # The lookup may come back empty; leave the roles
            # unset in that case instead of crashing on ``.roles``.
            if guild is not None:
                self.official_roles = {
                    role.name: role.id
                    for role in guild.roles
                }
        current_commands = []
        for module in get_modules(self.ctx):
            current_commands.extend(
                getattr(module, attr)
                for attr in dir(module)
                if (
                    attr.startswith("cmd_")
                    and "no_slash" not in dir(
                        getattr(module, attr)
                    )
                )
            )
        # Drop duplicates while keeping the first-seen order.
        current_commands = list(dict.fromkeys(current_commands))
        await self.__sync_commands(current_commands, **kwargs)
        for command in current_commands:
            await self.register_command(command, **kwargs)
        msg = f"Succesfully synced {len(self.registered)} slash commands."
        if not self.update_counter:
            msg = 'No commands require sync.'
        self.ctx.logger.pprint(msg, color='green')

    async def delete_command(
        self, command: Union[SlashCommand, Dict],
        **kwargs
    ):
        """Deletes a Slash command to the guild/globally.

        :param command: The Command object/dictionary to delete.
        :type command: Union[:class:`~.components.SlashCommand`, Dict]
        :param kwargs: Keyword arguments to pass to the route.
        :type kwargs: Dict[str, Any]
        """
        # Dictionaries are used as-is; command objects are serialized.
        cmd = command if isinstance(command, dict) else command.to_dict()
        route_kwargs = {
            "method": "DELETE",
            "endpoint": cmd['id']
        }
        if self.ctx.is_prod:
            route_kwargs["guild_id"] = self.ctx.official_server
        kwargs |= route_kwargs
        route = self.get_route(**kwargs)
        try:
            await self.http.request(route)
            if cmd['name'] in self.registered.names:
                self.registered.remove(cmd['name'])
            self.ctx.logger.pprint(
                f"Unregistered command: {cmd['name']}",
                color='yellow'
            )
        except discord.HTTPException as excp:
            self.ctx.logger.pprint(
                excp, color='red'
            )

    async def get_command(self, command: str) -> AppCommand:
        """Gets the Command Object from its name.

        :param command: The command name
        :type command: str
        :return: The corresponding Slash Command object, if registered.
        :rtype: :class:`~.components.SlashCommand`
        """
        for cmd in self.registered:
            if cmd.name == command:
                return cmd
        return None

    def get_route(self, method="POST", **kwargs):
        """Get the route for the query based on the kwargs.

        ``guild_id`` scopes the route to a guild and ``endpoint`` is
        appended after ``/commands``.

        :param method: HTTP Method to use.
        :type method: str
        :param kwargs: Keyword arguments to pass to the route.
        :type kwargs: Dict[str, Any]
        :return: The route for the request.
        :rtype: :class:`discord.http.Route`
        """
        path = f"/applications/{self.ctx.user.id}"
        if kwargs.get('guild_id'):
            path += f'/guilds/{kwargs["guild_id"]}'
        path += '/commands'
        if kwargs.get('endpoint'):
            path += f'/{kwargs["endpoint"]}'
        return Route(
            method=method,
            path=path,
            **kwargs
        )

    async def parse_response(
        self, interaction: discord.Interaction
    ) -> Tuple[Callable, Dict[str, Any]]:
        """Parse the response from the server.

        :param interaction: Interaction to parse.
        :type interaction: :class:`discord.Interaction`
        :return: Parsed command method and additional details.
            The method is ``None`` when no enabled module provides it.
        :rtype: Tuple[Callable, Dict[str, Any]]
        """
        # await interaction.response.defer()
        interaction = CustomInteraction(interaction)
        data = interaction.data
        kwargs = {
            'message': interaction,
            'args': [],
            'mentions': []
        }
        cmd = f'cmd_{data["name"]}'
        if self.ctx.is_local:
            # Local registrations get exactly one trailing underscore.
            cmd = cmd.removesuffix('_')
        method = None
        for com in get_modules(self.ctx):
            if com.enabled:
                method = getattr(com, cmd, None)
                if method:
                    break
        if not method:
            # Nothing to parse the options against.
            return (None, kwargs)
        with CustomRstParser() as rst_parser:
            rst_parser.parse(method.__doc__)
            param_names = rst_parser.param_names
        for opt in data.get('options', {}):
            if opt['name'] not in param_names:
                continue
            getter = alt_getter = None
            if opt['type'] == OptionTypes.USER.value:
                getter = interaction.guild.get_member
                alt_getter = self.ctx.get_user
            elif opt['type'] == OptionTypes.CHANNEL.value:
                getter = interaction.guild.get_channel
                alt_getter = self.ctx.get_channel
            elif opt['type'] == OptionTypes.ROLE.value:
                getter = interaction.guild.get_role
            elif opt['type'] == OptionTypes.ATTACHMENT.value:
                opt['value'] = discord.Attachment(
                    data=data['resolved']['attachments'][opt['value']],
                    # pylint: disable=protected-access
                    state=interaction._state
                )
            if getter:
                opt['value'] = self.__get_entity(
                    int(opt['value']),
                    getter, alt_getter
                )
            kwargs[opt['name']] = opt['value']
        return (method, kwargs)

    async def register_command(
        self, command: Callable,
        **kwargs
    ):
        """Register a Slash command to the guild/globally.

        :param command: Command to register.
        :type command: Callable
        :param kwargs: Keyword arguments to pass to the route.
        :type kwargs: Dict
        :return: The API response, or an empty dict when nothing
            was registered.
        :rtype: Dict
        """
        cmd_name = command.__name__.replace('cmd_', '')
        if self.ctx.is_local:
            cmd_name += '_'
        if cmd_name in self.registered.names and self.__params_matched(
            command, cmd_name
        ):
            return {}
        # Build the payload first so that skipped commands do not
        # count as (or get logged as) a registration.
        payload = self.__prep_payload(command)
        if not payload:
            return {}
        self.update_counter += 1
        self.ctx.logger.pprint(
            f"[{self.update_counter}] Registering the command: {cmd_name}",
            color='blue'
        )
        restricted = any(
            hasattr(command, perm)
            for perm in (
                'admin_only',
                'owner_only'
            )
        )
        if self.ctx.is_prod and (
            restricted or hasattr(command, "os_only")
        ):
            route = self.get_route(guild_id=self.ctx.official_server)
        elif self.ctx.is_local:
            route = self.get_route(**kwargs)
        else:
            route = self.get_route()
        if restricted and self.ctx.is_prod:
            # The API field is plural; the singular spelling is not
            # the documented key.
            payload["default_member_permissions"] = "0"
        try:
            resp = await self.http.request(route, json=payload)
            self.registered.remove(cmd_name)
            self.registered.append(resp)
            return resp
        except discord.HTTPException as excp:
            self.ctx.logger.pprint(
                (str(excp).splitlines() or [''])[-1],
                color='red'
            )
            self.ctx.logger.pprint(
                f"{command}\n{payload}",
                color='red'
            )
            return {}

    def __params_matched(self, command, cmd_name):
        """Checks if the docstring parameters of the command are the
        same as the parameters of its registered counterpart.
        """
        params = {}
        with CustomRstParser() as rst_parser:
            rst_parser.parse(command.__doc__)
            if rst_parser.params:
                params = {
                    key: {
                        field: val
                        for field, val in param.parse().items()
                        # Fix for Discord API not supporting
                        # Default option
                        if val is not None and field != 'default'
                    }
                    for key, param in rst_parser.params.items()
                    if key != 'message'
                }
        registered_cmd_params = self.registered[cmd_name].parameters
        return params == registered_cmd_params

    def __prep_payload(self, command: Callable) -> Dict:
        """Prepare payload for slash command registration.

        :param command: Command to register.
        :type command: Callable
        :return: Payload for slash command registration.
            Empty for disabled or undocumented commands.
        :rtype: Dict
        """
        if getattr(command, "disabled", False):
            return {}
        if not command.__doc__:
            return {}
        return command_to_dict(
            command, local=self.ctx.is_local
        )

    async def __sync_commands(self, current_commands, **kwargs):
        """Refreshes the registered list and deletes every registered
        command which has no counterpart in ``current_commands``.
        """
        if not self.ctx.is_local:
            await self.registered.refresh(fresh=True)
        await self.registered.refresh(**kwargs)
        pad = '_' if self.ctx.is_local else ''
        cmds = {
            cmd.__name__.replace('cmd_', '') + pad
            for cmd in current_commands
        }
        # delete_command removes entries from self.registered, so
        # iterate over a snapshot to avoid skipping elements.
        for command in list(self.registered):
            if command.name not in cmds:
                await self.delete_command(command, **kwargs)

    @staticmethod
    def __get_entity(val, getter, alt_getter=None):
        """Resolves an ID through ``getter``, falling back to
        ``alt_getter`` when the first lookup returns nothing.
        """
        id_ = int(val)
        alt_getter = alt_getter or getter
        return getter(id_) or alt_getter(id_)
class ContextHandler:
    """Class which handles context menus.

    .. warning::
        Currently does not support autosync.

    :param ctx: The pokegambler client.
    :type ctx: :class:`bot.PokeGambler`
    """
    def __init__(self, ctx: PokeGambler):
        self.ctx = ctx
        self.http = ctx.http
        self.component_class = ContextMenu
        self.registered = CommandListing(self)
        self.update_counter = 0

    async def delete_command(
        self, command: Union[ContextMenu, Dict],
        **kwargs
    ):
        """Deletes a Context Menu command.

        :param command: The Command object/dictionary to delete.
        :type command: Union[:class:`~.components.ContextMenu`, Dict]
        :param kwargs: Keyword arguments to pass to the route.
        :type kwargs: Dict[str, Any]
        """
        # Dictionaries are used as-is; command objects are serialized.
        cmd = command if isinstance(command, dict) else command.to_dict()
        kwargs |= {
            "method": "DELETE",
            "endpoint": cmd['id']
        }
        route = self.get_route(**kwargs)
        try:
            await self.http.request(route)
            if cmd['name'] in self.registered.names:
                self.registered.remove(cmd['name'])
            self.ctx.logger.pprint(
                f"Unregistered command: {cmd['name']}",
                color='yellow'
            )
        except discord.HTTPException as excp:
            self.ctx.logger.pprint(
                excp, color='red'
            )

    async def execute(
        self, interaction: discord.Interaction
    ) -> Callable:
        """Parse the data into :class:`.components.ContextMenu`
        and executes its callback.

        When the command is unknown or has no callback attached,
        the problem is logged and nothing is executed.

        :param interaction: Interaction to parse.
        :type interaction: :class:`discord.Interaction`
        """
        await interaction.response.defer()
        interaction = CustomInteraction(interaction)
        name = interaction.data['name']
        command = self.registered[name]
        callback = getattr(command, 'callback', None)
        if callback is None:
            self.ctx.logger.pprint(
                f"No callback registered for the context menu: {name}",
                color='red'
            )
            return
        await callback(message=interaction)

    def get_route(self, method="POST", **kwargs):
        """Get the route for the query.

        :param method: HTTP Method to use.
        :type method: str
        :param kwargs: Only ``endpoint`` is used; it gets appended
            to the commands path.
        :type kwargs: Dict[str, Any]
        :return: The route for the request.
        :rtype: :class:`discord.http.Route`
        """
        path = f"/applications/{self.ctx.user.id}/commands"
        if kwargs.get("endpoint"):
            path += f"/{kwargs['endpoint']}"
        return Route(
            method=method,
            path=path
        )

    async def register_command(
        self, callback: Callable,
        type_: Optional[int] = 2
    ) -> ContextMenu:
        """
        Register a context menu command.

        :param callback: The callback which gets executed upon interaction.
        :type callback: Callable
        :param type_: The type of the context menu command.
        :type type_: Optional[int]
        :return: The registered context menu command, an empty dict
            when it was registered already, or None upon a HTTP error.
        :rtype: :class:`~.components.ContextMenu`
        """
        if not self.registered:
            await self.registered.refresh()
        cmd_name = callback.__name__.title().replace('Cmd_', '')
        if cmd_name in self.registered.names:
            existing = self.registered[cmd_name]
            # Commands fetched from the API carry no callback yet.
            if not existing.callback:
                existing.callback = callback
            return {}
        payload = {
            "name": cmd_name,
            "type": type_
        }
        # Same POST route as before, minus the stray ``json`` route
        # parameter (the payload goes to ``request`` instead).
        route = self.get_route()
        self.ctx.logger.pprint(
            f"Registering the command: {payload['name']}",
            color='blue'
        )
        try:
            resp = await self.http.request(route, json=payload)
        except discord.HTTPException as excp:
            self.ctx.logger.pprint(
                excp, color='red'
            )
            return None
        menu = ContextMenu.from_dict(
            {**resp, "callback": callback}
        )
        self.update_counter += 1
        self.registered.append(menu)
        self.ctx.logger.pprint(
            f"Succesfully registered: {payload['name']}",
            color='green'
        )
        # Return the object itself: indexing the listing looks
        # commands up by name, so ``[-1]`` would not find it.
        return menu

    async def register_all(self):
        """
        Registers all the decorated Context Menu commands.
        """
        for module in get_modules(self.ctx):
            for cmd in dir(module):
                if not cmd.startswith('cmd_'):
                    continue
                candidate = getattr(module, cmd)
                if hasattr(candidate, 'ctx_command'):
                    await self.register_command(candidate)
        msg = f"Registered {len(self.registered)} commands."
        if not self.update_counter:
            msg = "No new context menu commands found."
        self.ctx.logger.pprint(msg, color='green')
class AutocompleteHandler:
    """Autocomplete handler for commands.

    :param ctx: The pokegambler client.
    :type ctx: :class:`bot.PokeGambler`
    """
    def __init__(self, ctx: PokeGambler):
        self.ctx = ctx
        self.commands = {}
        self.cache = TTLCache(maxsize=10, ttl=60)

    def register(
        self, cmd: Coroutine,
        callback_dict: Dict[str, Callable[
            [discord.Interaction], List[Any]
        ]]
    ):
        """Register a command with its choices.

        :param cmd: The command to register.
        :param callback_dict: The callback dictionary, keyed by
            option name.
        """
        cmd_dict = command_to_dict(cmd, local=self.ctx.is_local)
        cmd = SlashCommand.from_dict(cmd_dict)
        self.commands[cmd] = callback_dict

    def unregister(self, cmd: SlashCommand):
        """Unregister a command.

        Unknown commands are ignored.

        :param cmd: The command to unregister.
        :type cmd: :class:`~scripts.base.components.SlashCommand`
        """
        self.commands.pop(cmd, None)

    async def parse(
        self, interaction: discord.Interaction
    ) -> List[Choice]:
        """Get the choices for a command.

        At most 20 choices are sent back per call. An empty list is
        returned (and no response is sent) when the command is not
        registered, no option is focused or the focused option has
        no callback.

        :param interaction: The interaction to parse.
        :type interaction: :class:`discord.Interaction`
        :return: The choices for the command.
        :rtype: List[:class:`discord.app_commands.Choice`]
        """
        cmd = SlashCommand.from_dict(
            interaction.data
        )
        if cmd not in self.registered:
            return []
        focused_opt = next(
            (
                opt
                for opt in cmd.options
                if opt.get('focused', False)
            ),
            None
        )
        if not focused_opt:
            return []
        option_name = focused_opt.name
        # One key for every cache access; the reset below previously
        # used the option object and so never matched a stored entry.
        cache_key = (cmd, option_name, interaction.user.id)
        # Reset on empty input
        if focused_opt.get('value', '') == '':
            self.cache.pop(cache_key, None)
        choices = self.cache.get(cache_key)
        if not choices:
            callable_ = self.commands.get(cmd, {}).get(option_name)
            if callable_ is None:
                return []
            if inspect.iscoroutinefunction(callable_):
                choices = await callable_(interaction)
            else:
                choices = callable_(interaction)
            if len(choices) > 20:
                # Cached as an iterator: each later call continues
                # from where the previous slice stopped.
                choices = iter(choices)
            self.cache[cache_key] = choices
        choice_list = [
            Choice(
                name=elem['name'],
                value=elem['value']
            )
            for elem in itertools.islice(choices, 20)
        ]
        await interaction.response.autocomplete(choice_list)
        return choice_list

    @property
    def registered(self) -> List[SlashCommand]:
        """Get the registered commands.

        :return: The registered commands.
        :rtype: List[:class:`~scripts.base.components.SlashCommand`]
        """
        return list(self.commands)
class GuildEventHandler:
    """Class which handles guild events.

    :param ctx: The pokegambler client.
    :type ctx: :class:`bot.PokeGambler`
    """
    def __init__(self, ctx: PokeGambler):
        self.ctx = ctx
        self.http = ctx.http

    async def list_events(
        self, guild_id: Union[int, str]
    ) -> Optional[List[GuildEvent]]:
        """List all the events for a guild.

        :param guild_id: The guild id.
        :type guild_id: Union[int, str]
        :return: The list of events, or None upon a HTTP error.
        :rtype: Optional[List[:class:`~scripts.base.components.GuildEvent`]]
        """
        route = Route(
            method="GET",
            path=f"/guilds/{guild_id}/scheduled-events"
        )
        try:
            resp = await self.http.request(route)
            return [
                GuildEvent.from_dict(event)
                for event in resp
            ]
        except discord.HTTPException as excp:
            self.ctx.logger.pprint(
                excp, color='red'
            )
            return None

    async def register_event(
        self, guild_id: Union[int, str],
        event: GuildEvent
    ) -> Optional[GuildEvent]:
        """Register a guild event.

        :param guild_id: The guild id.
        :type guild_id: Union[int, str]
        :param event: The event to register. Anything which is not a
            :class:`~scripts.base.components.GuildEvent` is converted
            through :meth:`dict_to_event`.
        :type event: :class:`~scripts.base.components.GuildEvent`
        :return: The registered event, or None upon a HTTP error.
        :rtype: Optional[:class:`~scripts.base.components.GuildEvent`]
        """
        route = Route(
            method="POST",
            path=f"/guilds/{guild_id}/scheduled-events"
        )
        if not isinstance(event, GuildEvent):
            event = self.dict_to_event(event)
        try:
            resp = await self.http.request(route, json=event.to_payload())
            return GuildEvent.from_dict(resp)
        except discord.HTTPException as excp:
            self.ctx.logger.pprint(
                excp, color='red'
            )
            return None

    @staticmethod
    def dict_to_event(event: Dict) -> GuildEvent:
        """Convert a dict to a :class:`.components.GuildEvent`

        The ``start``/``end`` keys are renamed to
        ``scheduled_start_time``/``scheduled_end_time`` and the
        ``text_channel``/``voice_channel`` keys are applied through the
        event's setters. The passed dictionary is left untouched.

        :param event: The dict to convert.
        :type event: Dict
        :return: The converted event.
        :rtype: :class:`~scripts.base.components.GuildEvent`
        """
        # Work on a shallow copy so the caller's dictionary does not
        # lose the keys which get renamed/popped below.
        event = dict(event)
        if "start" in event:
            event["scheduled_start_time"] = event.pop("start")
        if "end" in event:
            event["scheduled_end_time"] = event.pop("end")
        text_channel = event.pop("text_channel", None)
        voice_channel = event.pop("voice_channel", None)
        gevent = GuildEvent.from_dict(event)
        if text_channel:
            gevent.set_text_channel(str(text_channel))
        elif voice_channel:
            gevent.set_voice_channel(str(voice_channel))
        elif (
            gevent.entity_type == GuildEventType.EXTERNAL
            and (
                not gevent.entity_metadata
                or not gevent.entity_metadata.get('location')
            )
        ):
            # External events without a location get a placeholder.
            gevent.entity_metadata = {
                'location': "Here"
            }
        return gevent