diff --git a/.gitignore b/.gitignore index 5164e92..2a5529f 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,6 @@ docs/build # local env .env* + +# Virtual environment +.venv diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..7bd209d --- /dev/null +++ b/alembic.ini @@ -0,0 +1,113 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..2500aa1 --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..c2b3e2c --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,80 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool +from bot import constants +from bot.orm.models import Base +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# modify the config here because ConfigParser can't handle default values +config.set_main_option("sqlalchemy.url", constants.Bot.database_dsn) + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/14dcb08ebf2e_initial_revision.py b/alembic/versions/14dcb08ebf2e_initial_revision.py new file mode 100644 index 0000000..cee568e --- /dev/null +++ b/alembic/versions/14dcb08ebf2e_initial_revision.py @@ -0,0 +1,45 @@ +"""initial revision + +Revision ID: 14dcb08ebf2e +Revises: +Create Date: 2023-10-15 21:13:49.716003 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "14dcb08ebf2e" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "guilds", sa.Column("guild_id", sa.BigInteger(), nullable=False), sa.PrimaryKeyConstraint("guild_id") + ) + op.create_table( + "reminders", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("channel_id", sa.BigInteger(), nullable=False), + sa.Column("message_id", sa.BigInteger(), nullable=False), + sa.Column("author_id", sa.BigInteger(), nullable=False), + sa.Column("mention_ids", sa.ARRAY(sa.BigInteger()), nullable=False), + sa.Column("content", sa.String(), nullable=False), + sa.Column("expiration", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("reminders") + op.drop_table("guilds") + # ### end Alembic commands ### diff --git a/pyproject.toml b/pyproject.toml index 5921406..2a3501d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,6 @@ line-length = 120 [tool.ruff] target-version = "py311" line-length = 120 -select = ["ALL"] ignore = [ "ERA001", # (Found commented-out code) - Porting features a piece at a time "G004", # (Logging statement uses f-string) - Developer UX diff --git a/requirements/requirements.in b/requirements/requirements.in index e4195ba..ff37cb3 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -16,6 +16,7 @@ coloredlogs # Database psycopg[binary] SQLAlchemy +alembic # Utilities # utils/helpers diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 8ca6218..09378d3 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.11 # by the following command: # -# pip-compile --output-file=requirements/requirements.txt requirements/requirements.in +# pip-compile requirements/requirements.in # aiodns==3.0.0 # via pydis-core @@ -12,6 +12,8 @@ aiohttp==3.8.6 # discord-py aiosignal==1.3.1 # via aiohttp +alembic==1.12.0 + # via -r requirements/requirements.in annotated-types==0.6.0 # via pydantic arrow==1.3.0 @@ -53,6 +55,10 @@ idna==3.4 # yarl imsosorry==1.2.1 # via -r requirements/requirements.in +mako==1.2.4 + # via alembic +markupsafe==2.1.3 + # via mako multidict==6.0.4 # via # aiohttp @@ -96,7 +102,9 @@ six==1.16.0 # python-dateutil # requests-file sqlalchemy==2.0.22 - # via -r requirements/requirements.in + # via + # -r requirements/requirements.in + # alembic statsd==4.0.1 # via pydis-core tldextract==5.0.0 @@ -105,6 +113,7 @@ types-python-dateutil==2.8.19.14 # via arrow typing-extensions==4.8.0 # via + # alembic # psycopg # pydantic # pydantic-core diff --git a/src/bot/__main__.py b/src/bot/__main__.py index 64dc029..b742a57 100644 --- a/src/bot/__main__.py +++ b/src/bot/__main__.py @@ -5,6 +5,7 @@ import aiohttp import discord from discord.ext import commands +from sqlalchemy.ext.asyncio import create_async_engine from bot import constants from bot.bot import Bot @@ -19,6 +20,7 @@ async def main() -> None: guild_id=constants.Guild.id, http_session=aiohttp.ClientSession(), allowed_roles=list({discord.Object(id_) for id_ in constants.MODERATION_ROLES}), + engine=create_async_engine(constants.Bot.database_dsn), command_prefix=commands.when_mentioned, intents=intents, ) diff --git a/src/bot/bot.py b/src/bot/bot.py index a6d02be..bb8b292 100644 --- a/src/bot/bot.py +++ b/src/bot/bot.py @@ -1,9 +1,13 @@ """Bot subclass.""" +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from typing import Self +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from pydis_core import BotBase from sentry_sdk import push_scope +from sqlalchemy.ext.asyncio.session import async_sessionmaker from bot import exts from bot.log import get_logger @@ -22,9 +26,25 @@ def __init__(self: Self, base: Exception) -> None: class Bot(BotBase): """A subclass of `pydis_core.BotBase` that implements bot-specific functions.""" - def __init__(self: Self, *args: list, **kwargs: dict) -> None: + def __init__(self: Self, *args: list, engine: AsyncEngine, **kwargs: dict) -> None: + self._engine = engine + self._sessionmaker = async_sessionmaker(bind=engine) + super().__init__(*args, **kwargs) + @asynccontextmanager + async def get_session(self) -> AsyncIterator[AsyncSession]: + """Return a session in async context manager style. + + Will automatically commit changes to database and close session at end of context + """ + session = self._sessionmaker() + try: + yield session + finally: + await session.commit() + await session.close() + async def setup_hook(self: Self) -> None: """Default async initialisation method for discord.py.""" # noqa: D401 await super().setup_hook() diff --git a/src/bot/converters.py b/src/bot/converters.py new file mode 100644 index 0000000..97fa6ee --- /dev/null +++ b/src/bot/converters.py @@ -0,0 +1,49 @@ +from datetime import UTC, datetime, timedelta + +from discord.ext.commands import BadArgument, Context, Converter + +from bot.utils.time import parse_duration_string + + +class DeltaConverter(Converter): + """Convert duration strings into dateutil.relativedelta.relativedelta objects.""" + + async def convert(self, ctx: Context, duration: str) -> timedelta: + """ + Converts a `duration` string to a timedelta object. + + The converter supports the following symbols for each unit of time: + - years: `Y`, `y`, `year`, `years` + - months: `m`, `month`, `months` + - weeks: `w`, `W`, `week`, `weeks` + - days: `d`, `D`, `day`, `days` + - hours: `H`, `h`, `hour`, `hours` + - minutes: `M`, `minute`, `minutes` + - seconds: `S`, `s`, `second`, `seconds` + + The units need to be provided in descending order of magnitude. + """ + if not (delta := parse_duration_string(duration)): + msg = f"`{duration}` is not a valid duration string." + raise BadArgument(msg) + + return delta + + +class DurationConverter(DeltaConverter): + """Convert duration strings into UTC datetime.datetime objects.""" + + async def convert(self, ctx: Context, duration: str) -> datetime: + """ + Converts a `duration` string to a datetime object that's `duration` in the future. + + The converter supports the same symbols for each unit of time as its parent class. + """ + delta = await super().convert(ctx, duration) + now = datetime.now(UTC) + + try: + return now + delta + except (ValueError, OverflowError): + msg = f"`{duration}` results in a datetime outside the supported range." + raise BadArgument(msg) diff --git a/src/bot/errors.py b/src/bot/errors.py new file mode 100644 index 0000000..6505cf5 --- /dev/null +++ b/src/bot/errors.py @@ -0,0 +1,21 @@ +from collections.abc import Hashable + + +class LockedResourceError(RuntimeError): + """ + Exception raised when an operation is attempted on a locked resource. + + Attributes + ---------- + `type` -- name of the locked resource's type + `id` -- ID of the locked resource + """ + + def __init__(self, resource_type: str, resource_id: Hashable) -> None: + self.type = resource_type + self.id = resource_id + + super().__init__( + f"Cannot operate on {self.type.lower()} `{self.id}`; " + "it is currently locked and in use by another operation.", + ) diff --git a/src/bot/exts/utilities/reminder.py b/src/bot/exts/utilities/reminder.py new file mode 100644 index 0000000..9a4c133 --- /dev/null +++ b/src/bot/exts/utilities/reminder.py @@ -0,0 +1,573 @@ +import random +import textwrap +import typing as t +from datetime import UTC, datetime, timedelta + +import discord +from discord import Interaction +from discord.ext.commands import Cog, Context, Greedy, group +from pydis_core.utils.channel import get_or_fetch_channel +from pydis_core.utils.scheduling import Scheduler +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from bot.bot import Bot +from bot.constants import ( + NEGATIVE_REPLIES, + POSITIVE_REPLIES, + Roles, +) +from bot.converters import DurationConverter +from bot.orm.models import Reminder +from bot.errors import LockedResourceError +from bot.log import get_logger +from bot.utils.checks import has_no_roles_check +from bot.utils.lock import lock_arg +from bot.utils.messages import send_denial +from bot.utils.paginator import LinePaginator + +log = get_logger(__name__) + +LOCK_NAMESPACE = "reminder" +MAXIMUM_REMINDERS = 5 +REMINDER_EDIT_CONFIRMATION_TIMEOUT = 60 + + +class ModifyReminderConfirmationView(discord.ui.View): + """A view to confirm modifying someone else's reminder by admins.""" + + def __init__(self, author: discord.Member | discord.User) -> None: + super().__init__(timeout=REMINDER_EDIT_CONFIRMATION_TIMEOUT) + self.author = author + self.result: bool = False + + async def interaction_check(self, interaction: Interaction) -> bool: + """Only allow interactions from the command invoker.""" + return interaction.user.id == self.author.id + + @discord.ui.button(label="Confirm", style=discord.ButtonStyle.blurple, row=0) + async def confirm(self, interaction: Interaction, _: discord.ui.Button) -> None: + """Confirm the reminder modification.""" + await interaction.response.edit_message(view=None) + self.result = True + self.stop() + + @discord.ui.button(label="Cancel", row=0) + async def cancel(self, interaction: Interaction, _: discord.ui.Button) -> None: + """Cancel the reminder modification.""" + await interaction.response.edit_message(view=None) + self.stop() + + +class Reminders(Cog): + """Provide in-channel reminder functionality.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + self.scheduler = Scheduler(self.__class__.__name__) + + async def cog_unload(self) -> None: + """Cancel scheduled tasks.""" + self.scheduler.cancel_all() + + async def cog_load(self) -> None: + """Get all current reminders from the API and reschedule them.""" + await self.bot.wait_until_guild_available() + + async with self.bot.get_session() as session: + query = select(Reminder) + scalars = await session.scalars(query) + reminders = scalars.all() + + for reminder in reminders: + await self.ensure_valid_reminder(reminder, session=session) + self.schedule_reminder(reminder) + + async def ensure_valid_reminder(self, reminder: Reminder, *, session: AsyncSession) -> discord.Message | None: + """Ensure reminder channel and message can be fetched. Otherwise delete the reminder.""" + channel = await get_or_fetch_channel(self.bot, reminder.channel_id) + if isinstance(channel, discord.abc.Messageable): + try: + return await channel.fetch_message(reminder.message_id) + except (discord.NotFound, discord.HTTPException, discord.Forbidden): + log.exception("Error while ensuring validity of reminder") + else: + log.info(f"Could not access channel ID {reminder.channel_id} for reminder {reminder.id}") + + log.warning("Deleting reminder {reminder.id} as it is invalid") + await session.merge(reminder) + await session.delete(reminder) + return None + + @staticmethod + async def _send_confirmation( + ctx: Context, + on_success: str, + reminder_id: int, + ) -> None: + """Send an embed confirming the reminder change was made successfully.""" + embed = discord.Embed( + description=on_success, + colour=discord.Colour.green(), + title=random.choice(POSITIVE_REPLIES), + ) + + footer_str = f"ID: {reminder_id}" + + embed.set_footer(text=footer_str) + + await ctx.send(embed=embed) + + def schedule_reminder(self, reminder: Reminder) -> None: + """A coroutine which sends the reminder once the time is reached, and cancels the running task.""" + self.scheduler.schedule_at(reminder.expiration, reminder.id, self.send_reminder(reminder.id)) + + async def _reschedule_reminder(self, reminder: Reminder) -> None: + """Reschedule a reminder object.""" + log.trace(f"Cancelling old task #{reminder.id}") + self.scheduler.cancel(reminder.id) + + log.trace(f"Scheduling new task #{reminder.id}") + self.schedule_reminder(reminder) + + @lock_arg(LOCK_NAMESPACE, "reminder_id", lambda id: id, raise_error=True) + async def send_reminder(self, reminder_id: int) -> None: + """Send the reminder, then delete it.""" + async with self.bot.get_session() as session: + query = select(Reminder).where(Reminder.id == reminder_id) + + reminder = await session.scalar(query) + if reminder is None: + log.error(f"Reminder {reminder_id} not found while sending reminder") + return + + try: + channel = await get_or_fetch_channel(self.bot, reminder.channel_id) + except (discord.NotFound, discord.Forbidden, discord.HTTPException): + log.exception( + f"Unable to find message {reminder.message_id} " f"while sending reminder {reminder.id}, deleting", + ) + + await session.delete(reminder) + return + + channel = self.bot.get_partial_messageable(reminder.channel_id) + embed = discord.Embed() + if datetime.now(UTC) > reminder.expiration + timedelta(seconds=30): + embed.colour = discord.Colour.red() + embed.set_author(name="Sorry, your reminder should have arrived earlier!") + else: + embed.colour = discord.Colour.og_blurple() + embed.set_author(name="It has arrived!") + + # Let's not use a codeblock to keep emojis and mentions working. Embeds are safe anyway. + embed.description = f"Here's your reminder: {reminder.content}" + + additional_mentions = " ".join(f"<@{target}>" for target in reminder.mention_ids) + + partial_message = channel.get_partial_message(reminder.message_id) + jump_button = discord.ui.Button( + label="Click here to go to your reminder", + style=discord.ButtonStyle.link, + url=partial_message.jump_url, + ) + + view = discord.ui.View() + view.add_item(jump_button) + + try: + await partial_message.reply(content=f"{additional_mentions}", embed=embed, view=view) + except discord.HTTPException as e: + log.info( + f"There was an error when trying to reply to a reminder invocation message, {e}, " + "fall back to using jump_url", + ) + await channel.send(content=f"<@{reminder.author_id}> {additional_mentions}", embed=embed, view=view) + + await session.delete(reminder) + log.debug(f"Deleting reminder #{reminder.id} (the user has been reminded).") + + @staticmethod + def try_get_content_from_rely(message: discord.Message) -> str | None: + """ + Attempts to get content from a message's reply, if it exists. + + Differs from `pydis_core.utils.commands.clean_text_or_reply` as allows for messages with no content. + """ + if (reference := message.reference) and isinstance((resolved_message := reference.resolved), discord.Message): + if resolved_message.content: + return resolved_message.content + else: + # If the replied message has no content (e.g. only attachments/embeds) + return "*See referenced message.*" + + return None + + @group(name="remind", aliases=("reminder", "reminders", "remindme"), invoke_without_command=True) + async def remind_group( + self, + ctx: Context, + mentions: Greedy[discord.Member | discord.User], + expiration: t.Annotated[datetime, DurationConverter], + *, + content: str | None = None, + ) -> None: + """ + Commands for managing your reminders. + + The `expiration` duration of `!remind new` supports the following symbols for each unit of time: + - years: `Y`, `y`, `year`, `years` + - months: `m`, `month`, `months` + - weeks: `w`, `W`, `week`, `weeks` + - days: `d`, `D`, `day`, `days` + - hours: `H`, `h`, `hour`, `hours` + - minutes: `M`, `minute`, `minutes` + - seconds: `S`, `s`, `second`, `seconds` + + For example, to set a reminder that expires in 3 days and 1 minute, you can do `!remind new 3d1M Do something`. + """ + await self.new_reminder(ctx, mentions=mentions, expiration=expiration, content=content) + + @remind_group.command(name="new", aliases=("add", "create")) + async def new_reminder( + self, + ctx: Context, + mentions: Greedy[discord.Member | discord.User], + expiration: t.Annotated[datetime, DurationConverter], + *, + content: str | None = None, + ) -> None: + """ + Set yourself a simple reminder. + + The `expiration` duration supports the following symbols for each unit of time: + - years: `Y`, `y`, `year`, `years` + - months: `m`, `month`, `months` + - weeks: `w`, `W`, `week`, `weeks` + - days: `d`, `D`, `day`, `days` + - hours: `H`, `h`, `hour`, `hours` + - minutes: `M`, `minute`, `minutes` + - seconds: `S`, `s`, `second`, `seconds` + + For example, to set a reminder that expires in 3 days and 1 minute, you can do `!remind new 3d1M Do something`. + """ + # Get their current active reminders + async with self.bot.get_session() as session: + query = select(Reminder).where(Reminder.author_id == ctx.author.id) + + scalars = await session.scalars(query) + reminders = scalars.all() + + if len(reminders) > MAXIMUM_REMINDERS: + await send_denial(ctx, f"You have too many active reminders! ({MAXIMUM_REMINDERS})") + return + + # Remove duplicate mentions + mention_ids = {mention.id for mention in mentions} + mention_ids.discard(ctx.author.id) + mention_ids = list(mention_ids) + + content = content or self.try_get_content_from_rely(ctx.message) + if not content: + await send_denial(ctx, "You must have content in your message or reply to a message!") + return + + # Now we can attempt to actually set the reminder. + async with self.bot.get_session() as session: + reminder = Reminder( + channel_id=ctx.channel.id, + message_id=ctx.message.id, + author_id=ctx.author.id, + expiration=expiration, + mention_ids=mention_ids, + content=content, + ) + + session.add(reminder) + + await session.flush() + + formatted_expiry = discord.utils.format_dt(expiration, style="F") + mention_string = f"Your reminder will arrive on {formatted_expiry}" + if mentions: + mention_string += f" and will mention {len(mentions)} other(s)" + mention_string += "!" + + # Confirm to the user that it worked. + await self._send_confirmation( + ctx, + on_success=mention_string, + reminder_id=reminder.id, + ) + + self.schedule_reminder(reminder) + + @remind_group.command(name="list") + async def list_reminders(self, ctx: Context) -> None: + """View a paginated embed of all reminders for your user.""" + # Get all the user's reminders from the database. + async with self.bot.get_session() as session: + query = select(Reminder).where(Reminder.author_id == ctx.author.id).order_by(Reminder.expiration) + scalars = await session.scalars(query) + reminders = scalars.all() + + lines = [] + for reminder in reminders: + expiry = discord.utils.format_dt(reminder.expiration, style="R") + + message = await self.ensure_valid_reminder(reminder, session=session) + if message is None: + log.warning("Invalid reminder {reminder.id} while listing, deleting") + continue + + mention_string = "**Mentions:** " + ", ".join(target.mention for target in message.mentions) + + text = textwrap.dedent( + f""" + **Reminder #{reminder.id}:** *expires {expiry}* {mention_string} + {message.content} + """ + ) + + lines.append(text) + + embed = discord.Embed() + embed.colour = discord.Colour.og_blurple() + embed.title = f"Reminders for {ctx.author}" + + # Remind the user that they have no reminders :^) + if not lines: + embed.description = "No active reminders could be found." + await ctx.send(embed=embed) + return + + # Construct the embed and paginate it. + embed.colour = discord.Colour.og_blurple() + + await LinePaginator.paginate( + lines, + ctx, + embed, + max_lines=3, + ) + + @remind_group.group(name="edit", aliases=("change", "modify"), invoke_without_command=True) + async def edit_reminder_group(self, ctx: Context) -> None: + """Commands for modifying your current reminders.""" + await ctx.send_help(ctx.command) + + @edit_reminder_group.command(name="duration", aliases=("time",)) + async def edit_reminder_duration( + self, + ctx: Context, + reminder_id: int, + expiration: t.Annotated[datetime, DurationConverter], + ) -> None: + """ + Edit one of your reminder's expiration. + + The `expiration` duration supports the following symbols for each unit of time: + - years: `Y`, `y`, `year`, `years` + - months: `m`, `month`, `months` + - weeks: `w`, `W`, `week`, `weeks` + - days: `d`, `D`, `day`, `days` + - hours: `H`, `h`, `hour`, `hours` + - minutes: `M`, `minute`, `minutes` + - seconds: `S`, `s`, `second`, `seconds` + + For example, to edit a reminder to expire in 3 days and 1 minute, you can do `!remind edit duration 1234 3d1M`. + """ + await self.edit_reminder(ctx, reminder_id=reminder_id, new_expiration=expiration) + + @edit_reminder_group.command(name="content", aliases=("reason",)) + async def edit_reminder_content( + self, + ctx: Context, + reminder_id: int, + *, + content: str | None = None, + ) -> None: + """ + Edit one of your reminder's content. + + You can either supply the new content yourself, or reply to a message to use its content. + """ + content = content or self.try_get_content_from_rely(ctx.message) + if not content: + await send_denial(ctx, "You must have content in your message or reply to a message!") + return + + await self.edit_reminder(ctx, reminder_id=reminder_id, new_content=content) + + @edit_reminder_group.command(name="mentions", aliases=("pings",)) + async def edit_reminder_mentions( + self, + ctx: Context, + reminder_id: int, + mentions: Greedy[discord.User | discord.Member], + ) -> None: + """Edit one of your reminder's mentions.""" + # Remove duplicate mentions + mention_ids = {mention.id for mention in mentions} + mention_ids.discard(ctx.author.id) + mention_ids = list(mention_ids) + + await self.edit_reminder(ctx, reminder_id=reminder_id, new_mention_ids=mention_ids) + + @lock_arg(LOCK_NAMESPACE, "reminder_id", raise_error=True) + async def edit_reminder( + self, + ctx: Context, + *, + reminder_id: int, + new_mention_ids: list[int] | None = None, + new_content: str | None = None, + new_expiration: datetime | None = None, + ) -> None: + """Edits a reminder with the given new data, then sends a confirmation message.""" + if not await self._can_modify(ctx, reminder_id=reminder_id): + return + + async with self.bot.get_session() as session: + query = select(Reminder).where(Reminder.id == reminder_id) + + reminder = await session.scalar(query) + + if reminder is None: + await send_denial(ctx, f"Unable to find reminder `{reminder_id}`.") + return + + if new_mention_ids: + reminder.mention_ids = new_mention_ids + + if new_content: + reminder.content = new_content + + if new_expiration: + reminder.expiration = new_expiration + await self._reschedule_reminder(reminder) + + # Send a confirmation message to the channel + await self._send_confirmation( + ctx, + on_success="That reminder has been edited successfully!", + reminder_id=reminder_id, + ) + + @lock_arg(LOCK_NAMESPACE, "reminder_id", raise_error=True) + async def _delete_reminder(self, ctx: Context, reminder_id: int) -> bool: + """Acquires a lock on `reminder_id` and returns `True` if reminder is deleted, otherwise `False`.""" + if not await self._can_modify(ctx, reminder_id=reminder_id): + return False + + async with self.bot.get_session() as session: + query = select(Reminder).where(Reminder.id == reminder_id) + + reminder = await session.scalar(query) + + if reminder is None: + await send_denial(ctx, f"Unable to find reminder `{reminder_id}`.") + return False + + await session.delete(reminder) + + self.scheduler.cancel(reminder_id) + return True + + @remind_group.command("delete", aliases=("remove", "cancel")) + async def delete_reminder(self, ctx: Context, reminder_ids: Greedy[int]) -> None: + """Delete up to (and including) 5 of your active reminders.""" + if len(reminder_ids) > 5: + await send_denial(ctx, "You can only delete a maximum of 5 reminders at once.") + return + + deleted_ids: list[str] = [] + for reminder_id in set(reminder_ids): + try: + reminder_deleted = await self._delete_reminder(ctx, reminder_id) + except LockedResourceError: + continue + else: + if reminder_deleted: + deleted_ids.append(str(reminder_id)) + + if deleted_ids: + colour = discord.Colour.green() + title = random.choice(POSITIVE_REPLIES) + deletion_message = f"Successfully deleted the following reminder(s): {', '.join(deleted_ids)}" + + if len(deleted_ids) != len(reminder_ids): + deletion_message += ( + "\n\nThe other reminder(s) could not be deleted as they're either locked, " + "belong to someone else, or don't exist." + ) + else: + colour = discord.Colour.red() + title = random.choice(NEGATIVE_REPLIES) + deletion_message = ( + "Could not delete the reminder(s) as they're either locked, " "belong to someone else, or don't exist." + ) + + embed = discord.Embed( + description=deletion_message, + colour=colour, + title=title, + ) + await ctx.send(embed=embed) + + async def _can_modify( + self, + ctx: Context, + *, + reminder_id: int, + send_on_denial: bool = True, + ) -> bool: + """ + Check whether the reminder can be modified by the ctx author. + + The check passes if the user created the reminder, or if they are an admin (with confirmation). + """ + async with self.bot.get_session() as session: + query = select(Reminder).where(Reminder.id == reminder_id) + + reminder = await session.scalar(query) + + if reminder is None: + log.warning(f"Reminder {reminder_id} not found when checking if user can modify") + return False + + if reminder.author_id != ctx.author.id: + if await has_no_roles_check(ctx, Roles.administrators): + log.warning(f"{ctx.author} is not the reminder's author and thus does not pass the check.") + if send_on_denial: + await send_denial(ctx, "You can't modify reminders of other users!") + return False + else: + log.debug(f"{ctx.author} is an admin, asking for confirmation to modify someone else's.") + + modify_action = "delete" if ctx.command == self.delete_reminder else "edit" + + confirmation_view = ModifyReminderConfirmationView(ctx.author) + confirmation_message = await ctx.reply( + f"Are you sure you want to {modify_action} <@{reminder.author_id}>'s reminder?", + view=confirmation_view, + ) + view_timed_out = await confirmation_view.wait() + # We don't have access to the message in `on_timeout` so we have to delete the view here + if view_timed_out: + await confirmation_message.edit(view=None) + + if confirmation_view.result: + log.debug(f"{ctx.author} has confirmed reminder modification.") + else: + await ctx.send("🚫 Operation canceled.") + log.debug(f"{ctx.author} has cancelled reminder modification.") + return confirmation_view.result or False + else: + log.debug(f"{ctx.author} is the reminder's author and passes the check.") + return True + + +async def setup(bot: Bot) -> None: + """Load the Reminders cog.""" + await bot.add_cog(Reminders(bot)) diff --git a/src/bot/orm/models.py b/src/bot/orm/models.py index 952a6af..ded6fc7 100644 --- a/src/bot/orm/models.py +++ b/src/bot/orm/models.py @@ -1,7 +1,8 @@ """Database models.""" -from sqlalchemy import BigInteger -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from datetime import datetime +from sqlalchemy import ARRAY, BigInteger, DateTime +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column class Base(DeclarativeBase): @@ -14,3 +15,28 @@ class Guild(Base): __tablename__ = "guilds" guild_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + + +class Reminder(MappedAsDataclass, Base): + """Represents a Reminder in the database.""" + + __tablename__ = "reminders" + + id: Mapped[int] = mapped_column( # noqa: A003 # shadowing id() function is fine + init=False, + primary_key=True, + autoincrement=True, + unique=True, + ) + + channel_id: Mapped[int] = mapped_column(BigInteger) + + message_id: Mapped[int] = mapped_column(BigInteger) + + author_id: Mapped[int] = mapped_column(BigInteger) + + mention_ids: Mapped[list[int]] = mapped_column(ARRAY(BigInteger)) + + content: Mapped[str] + + expiration: Mapped[datetime] = mapped_column(DateTime(timezone=True)) diff --git a/src/bot/utils/checks.py b/src/bot/utils/checks.py new file mode 100644 index 0000000..dfef5e8 --- /dev/null +++ b/src/bot/utils/checks.py @@ -0,0 +1,29 @@ +from discord.ext.commands import CheckFailure, Context, NoPrivateMessage, has_any_role + + +async def has_any_role_check(ctx: Context, *roles: str | int) -> bool: + """ + Returns True if the context's author has any of the specified roles. + + `roles` are the names or IDs of the roles for which to check. + False is always returns if the context is outside a guild. + """ + try: + return await has_any_role(*roles).predicate(ctx) + except CheckFailure: + return False + + +async def has_no_roles_check(ctx: Context, *roles: str | int) -> bool: + """ + Returns True if the context's author doesn't have any of the specified roles. + + `roles` are the names or IDs of the roles for which to check. + False is always returns if the context is outside a guild. + """ + try: + return not await has_any_role(*roles).predicate(ctx) + except NoPrivateMessage: + return False + except CheckFailure: + return True diff --git a/src/bot/utils/messages.py b/src/bot/utils/messages.py index cae934e..83e0f97 100644 --- a/src/bot/utils/messages.py +++ b/src/bot/utils/messages.py @@ -2,18 +2,32 @@ import contextlib from collections.abc import Callable +from random import random import discord from discord import Embed, Message from discord.ext import commands from discord.ext.commands import Context, MessageConverter +from bot.constants import NEGATIVE_REPLIES + def format_user(user: discord.abc.User) -> str: """Return a string for `user` which has their mention and ID.""" return f"{user.mention} (`{user.id}`)" +async def send_denial(ctx: Context, reason: str) -> discord.Message: + """Send an embed denying the user with the given reason.""" + embed = discord.Embed( + title=random.choice(NEGATIVE_REPLIES), + description=reason, + colour=discord.Colour.red(), + ) + + return await ctx.send(embed=embed) + + async def get_discord_message(ctx: Context, text: str) -> Message | str: """ Attempt to convert a given `text` to a discord Message object and return it. diff --git a/src/bot/utils/paginator.py b/src/bot/utils/paginator.py new file mode 100644 index 0000000..a5aaa97 --- /dev/null +++ b/src/bot/utils/paginator.py @@ -0,0 +1,370 @@ +import asyncio +from contextlib import suppress +from functools import partial + +import discord +from discord.abc import User +from discord.ext.commands import Context, Paginator + +from bot.log import get_logger +from bot.utils import messages + +FIRST_EMOJI = "\u23EE" # [:track_previous:] +LEFT_EMOJI = "\u2B05" # [:arrow_left:] +RIGHT_EMOJI = "\u27A1" # [:arrow_right:] +LAST_EMOJI = "\u23ED" # [:track_next:] +DELETE_EMOJI = "\u274C" # [:x:] + +PAGINATION_EMOJI = (FIRST_EMOJI, LEFT_EMOJI, RIGHT_EMOJI, LAST_EMOJI, DELETE_EMOJI) + +log = get_logger(__name__) + + +class EmptyPaginatorEmbedError(Exception): + """Raised when attempting to paginate with empty contents.""" + + +class LinePaginator(Paginator): + """ + A class that aids in paginating code blocks for Discord messages. + + Available attributes include: + * prefix: `str` + The prefix inserted to every page. e.g. three backticks. + * suffix: `str` + The suffix appended at the end of every page. e.g. three backticks. + * max_size: `int` + The maximum amount of codepoints allowed in a page. + * scale_to_size: `int` + The maximum amount of characters a single line can scale up to. + * max_lines: `int` + The maximum amount of lines allowed in a page. + """ + + def __init__( + self, + prefix: str = "```", + suffix: str = "```", + max_size: int = 4000, + scale_to_size: int = 4000, + max_lines: int | None = None, + linesep: str = "\n", + ) -> None: + """ + This function overrides the Paginator.__init__ from inside discord.ext.commands. + + It overrides in order to allow us to configure the maximum number of lines per page. + """ + # Embeds that exceed 4096 characters will result in an HTTPException + # (Discord API limit), so we've set a limit of 4000 + if max_size > 4000: + msg = f"max_size must be <= 4,000 characters. ({max_size} > 4000)" + raise ValueError(msg) + + super().__init__( + prefix, + suffix, + max_size - len(suffix), + linesep, + ) + + if scale_to_size < max_size: + msg = f"scale_to_size must be >= max_size. ({scale_to_size} < {max_size})" + raise ValueError(msg) + + if scale_to_size > 4000: + msg = f"scale_to_size must be <= 4,000 characters. ({scale_to_size} > 4000)" + raise ValueError(msg) + + self.scale_to_size = scale_to_size - len(suffix) + self.max_lines = max_lines + self._current_page = [prefix] + self._linecount = 0 + self._count = len(prefix) + 1 # prefix + newline + self._pages = [] + + def add_line(self, line: str = "", *, empty: bool = False) -> None: + """ + Adds a line to the current page. + + If a line on a page exceeds `max_size` characters, then `max_size` will go up to + `scale_to_size` for a single line before creating a new page for the overflow words. If it + is still exceeded, the excess characters are stored and placed on the next pages unti + there are none remaining (by word boundary). The line is truncated if `scale_to_size` is + still exceeded after attempting to continue onto the next page. + + In the case that the page already contains one or more lines and the new lines would cause + `max_size` to be exceeded, a new page is created. This is done in order to make a best + effort to avoid breaking up single lines across pages, while keeping the total length of the + page at a reasonable size. + + This function overrides the `Paginator.add_line` from inside `discord.ext.commands`. + + It overrides in order to allow us to configure the maximum number of lines per page. + """ + remaining_words = None + if len(line) > (max_chars := self.max_size - len(self.prefix) - 2) and len(line) > self.scale_to_size: + line, remaining_words = self._split_remaining_words(line, max_chars) + if len(line) > self.scale_to_size: + log.debug("Could not continue to next page, truncating line.") + line = line[: self.scale_to_size] + + # Check if we should start a new page or continue the line on the current one + if self.max_lines is not None and self._linecount >= self.max_lines: + log.debug("max_lines exceeded, creating new page.") + self._new_page() + elif self._count + len(line) + 1 > self.max_size and self._linecount > 0: + log.debug("max_size exceeded on page with lines, creating new page.") + self._new_page() + + self._linecount += 1 + + self._count += len(line) + 1 + self._current_page.append(line) + + if empty: + self._current_page.append("") + self._count += 1 + + # Start a new page if there were any overflow words + if remaining_words: + self._new_page() + self.add_line(remaining_words) + + def _new_page(self) -> None: + """ + Internal: start a new page for the paginator. + + This closes the current page and resets the counters for the new page's line count and + character count. + """ + self._linecount = 0 + self._count = len(self.prefix) + 1 + self.close_page() + + def _split_remaining_words(self, line: str, max_chars: int) -> tuple[str, str | None]: + """ + Internal: split a line into two strings -- reduced_words and remaining_words. + + reduced_words: the remaining words in `line`, after attempting to remove all words that + exceed `max_chars` (rounding down to the nearest word boundary). + + remaining_words: the words in `line` which exceed `max_chars`. This value is None if + no words could be split from `line`. + + If there are any remaining_words, an ellipses is appended to reduced_words and a + continuation header is inserted before remaining_words to visually communicate the line + continuation. + + Return a tuple in the format (reduced_words, remaining_words). + """ + reduced_words = [] + remaining_words = [] + + # "(Continued)" is used on a line by itself to indicate the continuation of last page + continuation_header = "(Continued)\n-----------\n" + reduced_char_count = 0 + is_full = False + + for word in line.split(" "): + if not is_full: + if len(word) + reduced_char_count <= max_chars: + reduced_words.append(word) + reduced_char_count += len(word) + 1 + else: + # If reduced_words is empty, we were unable to split the words across pages + if not reduced_words: + return line, None + is_full = True + remaining_words.append(word) + else: + remaining_words.append(word) + + return ( + " ".join(reduced_words) + "..." if remaining_words else "", + continuation_header + " ".join(remaining_words) if remaining_words else None, + ) + + @classmethod + async def paginate( + cls, + lines: list[str], + ctx: Context | discord.Interaction, + embed: discord.Embed, + prefix: str = "", + suffix: str = "", + max_lines: int | None = None, + max_size: int = 500, + scale_to_size: int = 4000, + empty: bool = True, + restrict_to_user: User | None = None, + timeout: int = 300, + footer_text: str | None = None, + url: str | None = None, + exception_on_empty_embed: bool = False, + reply: bool = False, + ) -> discord.Message | None: + """ + Use a paginator and set of reactions to provide pagination over a set of lines. + + The reactions are used to switch page, or to finish with pagination. + + When used, this will send a message using `ctx.send()` and apply a set of reactions to it. These reactions may + be used to change page, or to remove pagination from the message. + + Pagination will also be removed automatically if no reaction is added for five minutes (300 seconds). + + The interaction will be limited to `restrict_to_user` (ctx.author by default) or + to any user with a moderation role. + + Example: + >>> embed = discord.Embed() + >>> embed.set_author(name="Some Operation", url=url, icon_url=icon) + >>> await LinePaginator.paginate([line for line in lines], ctx, embed) + """ + paginator = cls( + prefix=prefix, suffix=suffix, max_size=max_size, max_lines=max_lines, scale_to_size=scale_to_size + ) + current_page = 0 + + if not restrict_to_user: + restrict_to_user = ctx.user if isinstance(ctx, discord.Interaction) else ctx.author + + if not lines: + if exception_on_empty_embed: + log.exception("Pagination asked for empty lines iterable") + msg = "No lines to paginate" + raise EmptyPaginatorEmbedError(msg) + + log.debug("No lines to add to paginator, adding '(nothing to display)' message") + lines.append("*(nothing to display)*") + + for line in lines: + try: + paginator.add_line(line, empty=empty) + except Exception: + log.exception(f"Failed to add line to paginator: '{line}'") + raise # Should propagate + else: + log.trace(f"Added line to paginator: '{line}'") + + log.debug(f"Paginator created with {len(paginator.pages)} pages") + + embed.description = paginator.pages[current_page] + + reference = ctx.message if reply else None + + if len(paginator.pages) <= 1: + if footer_text: + embed.set_footer(text=footer_text) + log.trace(f"Setting embed footer to '{footer_text}'") + + if url: + embed.url = url + log.trace(f"Setting embed url to '{url}'") + + log.debug("There's less than two pages, so we won't paginate - sending single page on its own") + + if isinstance(ctx, discord.Interaction): + return await ctx.response.send_message(embed=embed) + return await ctx.send(embed=embed, reference=reference) + + if footer_text: + embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})") + else: + embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}") + log.trace(f"Setting embed footer to '{embed.footer.text}'") + + if url: + embed.url = url + log.trace(f"Setting embed url to '{url}'") + + log.debug("Sending first page to channel...") + + if isinstance(ctx, discord.Interaction): + await ctx.response.send_message(embed=embed) + message = await ctx.original_response() + else: + message = await ctx.send(embed=embed, reference=reference) + + log.debug("Adding emoji reactions to message...") + + for emoji in PAGINATION_EMOJI: + # Add all the applicable emoji to the message + log.trace(f"Adding reaction: {emoji!r}") + await message.add_reaction(emoji) + + check = partial( + messages.reaction_check, + message_id=message.id, + allowed_emoji=PAGINATION_EMOJI, + allowed_users=(restrict_to_user.id,), + ) + + while True: + try: + if isinstance(ctx, discord.Interaction): + reaction, user = await ctx.client.wait_for("reaction_add", timeout=timeout, check=check) + else: + reaction, user = await ctx.bot.wait_for("reaction_add", timeout=timeout, check=check) + log.trace(f"Got reaction: {reaction}") + except asyncio.TimeoutError: + log.debug("Timed out waiting for a reaction") + break # We're done, no reactions for the last 5 minutes + + if str(reaction.emoji) == DELETE_EMOJI: + log.debug("Got delete reaction") + return await message.delete() + if reaction.emoji in PAGINATION_EMOJI: + total_pages = len(paginator.pages) + try: + await message.remove_reaction(reaction.emoji, user) + except discord.HTTPException as e: + # Suppress if trying to act on an archived thread. + if e.code != 50083: + raise + + if reaction.emoji == FIRST_EMOJI: + current_page = 0 + log.debug(f"Got first page reaction - changing to page 1/{total_pages}") + elif reaction.emoji == LAST_EMOJI: + current_page = len(paginator.pages) - 1 + log.debug(f"Got last page reaction - changing to page {current_page + 1}/{total_pages}") + elif reaction.emoji == LEFT_EMOJI: + if current_page <= 0: + log.debug("Got previous page reaction, but we're on the first page - ignoring") + continue + + current_page -= 1 + log.debug(f"Got previous page reaction - changing to page {current_page + 1}/{total_pages}") + elif reaction.emoji == RIGHT_EMOJI: + if current_page >= len(paginator.pages) - 1: + log.debug("Got next page reaction, but we're on the last page - ignoring") + continue + + current_page += 1 + log.debug(f"Got next page reaction - changing to page {current_page + 1}/{total_pages}") + + embed.description = paginator.pages[current_page] + + if footer_text: + embed.set_footer(text=f"{footer_text} (Page {current_page + 1}/{len(paginator.pages)})") + else: + embed.set_footer(text=f"Page {current_page + 1}/{len(paginator.pages)}") + + try: + await message.edit(embed=embed) + except discord.HTTPException as e: + if e.code == 50083: + # Trying to act on an archived thread, just ignore and abort + break + raise + + log.debug("Ending pagination and clearing reactions.") + with suppress(discord.NotFound): + try: + await message.clear_reactions() + except discord.HTTPException as e: + # Suppress if trying to act on an archived thread. + if e.code != 50083: + raise diff --git a/src/bot/utils/time.py b/src/bot/utils/time.py new file mode 100644 index 0000000..f7a2cc9 --- /dev/null +++ b/src/bot/utils/time.py @@ -0,0 +1,46 @@ +import re +from datetime import timedelta + +DURATION_REGEX = re.compile( + r"((?P\d+?) ?(years|year|Y|y) ?)?" + r"((?P\d+?) ?(months|month|M) ?)?" + r"((?P\d+?) ?(weeks|week|W|w) ?)?" + r"((?P\d+?) ?(days|day|D|d) ?)?" + r"((?P\d+?) ?(hours|hour|H|h) ?)?" + r"((?P\d+?) ?(minutes|minute|m) ?)?" + r"((?P\d+?) ?(seconds|second|S|s))?", +) + + +def parse_duration_string(duration: str) -> timedelta | None: + """ + Convert a `duration` string to a relativedelta object. + + The following symbols are supported for each unit of time: + + - years: `Y`, `y`, `year`, `years` + - months: `M`, `month`, `months` + - weeks: `w`, `W`, `week`, `weeks` + - days: `d`, `D`, `day`, `days` + - hours: `H`, `h`, `hour`, `hours` + - minutes: `m`, `minute`, `minutes` + - seconds: `S`, `s`, `second`, `seconds` + + The units need to be provided in descending order of magnitude. + Return None if the `duration` string cannot be parsed according to the symbols above. + """ + match = DURATION_REGEX.fullmatch(duration) + if not match: + return None + + duration_dict = {unit: int(amount) for unit, amount in match.groupdict(default="0").items()} + + # since timedelta doesn't support months, let's just say 1 month = 30 days + months = duration_dict.pop("months") + duration_dict["days"] += int(months) * 30 + + # since timedelta doesn't support years, let's just say 1 year = 365 days + years = duration_dict.pop("years") + duration_dict["days"] += int(years) * 365 + + return timedelta(**duration_dict)