Source code for scripts.helpers.validators

"""
PokeGambler - A Pokemon themed gambling bot for Discord.
Copyright (C) 2021 Harshith Thota

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as
published by the Free Software Foundation, either version 3 of the
License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
----------------------------------------------------------------------------

This module is a compilation of user input validators.
"""

# pylint: disable=too-few-public-methods

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
import re

from .utils import dm_send, get_embed

if TYPE_CHECKING:
    from discord import Embed, Message


[docs]class Validator(ABC): """ Base class for all validators. :param message: The message which triggered the validation. :type message: :class:`discord.Message` :param on_error: A dictionary containing error heading and content. :type on_error: Dict :param dm_user: Route the message to user\'s DM? :type dm_user: bool :param notify: Notify the user if the validation fails? :type notify: bool """ error_embed_title = "" error_embed_desc = "" error_embed_kwargs = {} null_embed_title = "No Value specified." null_embed_desc = "You need to provide a value." null_embed_kwargs = {} cleaner = None # pylint: disable=too-many-arguments def __init__( self, message: Message, on_error: Optional[Dict[str, str]] = None, on_null: Optional[Dict[str, str]] = None, dm_user: bool = False, notify: bool = True ): self.message = message on_error = on_error or {} on_null = on_null or {} if on_error.get("title"): self.error_embed_title = on_error.pop("title") if on_error.get("description"): self.error_embed_desc = on_error.pop("description") if on_null.get("title"): self.null_embed_title = on_null.pop("title") if on_null.get("description"): self.null_embed_desc = on_null.pop("description") self.error_embed_kwargs.update(on_error) self.null_embed_kwargs.update(on_null) self.dm_user = dm_user self.notify = notify @abstractmethod def check(self, value) -> bool: """ Subclass specific logic for validation. """ @property def error_embed(self) -> Embed: """Returns an error embed. :return: An error embed. :rtype: :class:`discord.Embed` """ return get_embed( title=self.error_embed_title, embed_type="error", content=( self.error_embed_desc + "\nPlease retry the command." ), **self.error_embed_kwargs ) @property def null_embed(self) -> Embed: """Returns a null embed. :return: A null embed. :rtype: :class:`discord.Embed` """ return get_embed( title=self.null_embed_title, embed_type="error", content=( self.null_embed_desc + "\nPlease retry the command." ), **self.null_embed_kwargs ) async def __notify(self, is_null: bool = False): """ Performs the notifier function. """ if not self.notify: return embed = self.null_embed if is_null else self.error_embed if not self.dm_user: await self.message.reply(embed=embed) else: await dm_send( self.message, self.message.author, embed=embed )
[docs] async def validate(self, value: Any) -> bool: """Validates the given value. :param value: The value to be validated. :type value: str :return: True if the value is valid, False otherwise. :rtype: bool """ if value is None: await self.__notify(is_null=True) return False if not self.check(value): await self.__notify() return False return True
[docs] async def cleaned(self, value: Any) -> Any: """ Cleans the given value after validation. """ valid = await self.validate(value) if not valid: return None if self.cleaner is not None: # pylint: disable=not-callable return self.cleaner(value) return value
[docs]class ChainValidator(Validator): """ Chain multiple validators together. :param validator_spec: Validator and Kwargs mapping. :type validator_spec: Dict[Validator, Dict[str, Any]] """ def __init__( self, message: Message, validator_spec: Dict[Validator, Dict[str, Any]], **kwargs ): super().__init__(message=message, **kwargs) self.validators = [ validator(**kwgs) for validator, kwgs in validator_spec.items() ] def check(self, value) -> bool: for validator in self.validators: if not validator.check(value): self.error_embed_title = validator.error_embed_title self.error_embed_desc = validator.error_embed_desc return False return True
[docs]class IntegerValidator(Validator): """ Validates an integer value. """ error_embed_title = "Invalid input" #: error_embed_desc = "Please enter a valid integer." cleaner = int def __init__(self, **kwargs): super().__init__(**kwargs) def check(self, value: Union[str, int]) -> bool: return str(value).replace(',', '').isdigit()
[docs]class MaxValidator(IntegerValidator): """Validates a value is less than or equal to the maximum. :param max_value: The maximum value to validate against. :type max_value: int :kwargs: Additional keyword arguments to pass to the superclass. :type kwargs: Dict """ def __init__( self, max_value: int, **kwargs ): super().__init__(**kwargs) self.max_value = max_value if not kwargs.get("on_error"): self.error_embed_title = "Input value is too high." self.error_embed_desc = "Value must be less than " + \ f"than **{max_value}**." def check(self, value: Union[str, int]) -> bool: if not super().check(value): self.error_embed_title = super().error_embed_title self.error_embed_desc = super().error_embed_desc return False return int(str(value).replace(',', '')) <= self.max_value
[docs]class MinValidator(IntegerValidator): """ Validates a value is greater than or equal to the minimum. :param min_value: The minimum value to validate against. :type min_value: int :kwargs: Additional keyword arguments to pass to the superclass. :type kwargs: Dict """ def __init__( self, min_value: int, **kwargs ): super().__init__(**kwargs) self.min_value = min_value if not kwargs.get("on_error"): self.error_embed_title = "Input value is too low." self.error_embed_desc = "Value must be greater than " + \ f"than **{min_value}**." def check(self, value: Union[str, int]) -> bool: if not super().check(value): self.error_embed_title = super().error_embed_title self.error_embed_desc = super().error_embed_desc return False return int(str(value).replace(',', '')) >= self.min_value
[docs]class MaxLengthValidator(Validator): """Validates a string to be of a certain length. :param max_length: The maximum permissible length of the string. :type max_length: int :kwargs: Additional keyword arguments to pass to the superclass. :type kwargs: Dict """ error_embed_title = "Invalid Input" #: error_embed_desc = "Input value has too many characters." def __init__(self, max_length: int, **kwargs): super().__init__(**kwargs) self.max_length = max_length if not kwargs.get("on_error"): self.error_embed_desc += f"\nMax length: **{max_length}**" def check(self, value) -> bool: return len(value) <= self.max_length
[docs]class MinLengthValidator(Validator): """Validates a string to be of a certain length. :param min_length: The minimum permissible length of the string. :type min_length: int :kwargs: Additional keyword arguments to pass to the superclass. :type kwargs: Dict """ error_embed_title = "Invalid Input" #: error_embed_desc = "Input value has too few characters." def __init__(self, min_length: int, **kwargs): super().__init__(**kwargs) self.min_length = min_length if not kwargs.get("on_error"): self.error_embed_desc += f"\nMin length: **{min_length}**" def check(self, value) -> bool: return len(value) >= self.min_length
[docs]class RegexValidator(Validator): """Validates a string against a regular expression. :param pattern: The regular expression to use for validation. :type pattern: str """ def __init__(self, pattern: str, **kwargs): super().__init__(**kwargs) self.pattern = re.compile(pattern) def check(self, value) -> bool: return self.pattern.match(value) is not None
[docs]class HexValidator(RegexValidator): """ Validates if a string is a hexadecimal value. """ #: error_embed_title = "Invalid Hexadecimal value" def __init__(self, **kwargs): super().__init__(r'#?[0-9a-fA-F]{6}', **kwargs)
[docs] @staticmethod def cleaner(hex_str: str) -> int: """ Converts a hexadecimal string to an integer. :param hex_str: The hexadecimal string to convert. :type hex_str: str :return: The integer value of the hexadecimal string. :rtype: int """ return int(hex_str.lstrip('#'), 16)
[docs]class ImageUrlValidator(RegexValidator): """ Validates if a strings is an image URL. """ #: error_embed_title = "Invalid image URL" #: error_embed_desc = "Only Png and JPG images are supported." def __init__(self, **kwargs): super().__init__( r'(?:http|https)://[^\s]+\.(?:png|jpg|jpeg)', **kwargs )
[docs]class UrlValidator(RegexValidator): """ Validates if a string is a URL. """ #: error_embed_title = "Invalid URL" def __init__(self, **kwargs): super().__init__(r'(?:http|https)://[^\s]+', **kwargs)
[docs]class ItemNameValidator(RegexValidator): """ Validates if a string is a valid item name. """ #: error_embed_title = "Invalid Item Name" def __init__(self, **kwargs): super().__init__( r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', **kwargs )