From 4017c358a454e457e1e95ab52afaf3d14475e91a Mon Sep 17 00:00:00 2001 From: zevaryx Date: Sun, 20 Mar 2022 14:32:43 -0600 Subject: [PATCH] Add new scheduler for better accuracy --- jarvis_tasks/tasks/ban.py | 53 +++++++++++---------- jarvis_tasks/tasks/reminder.py | 84 ++++++++++++++++++++-------------- jarvis_tasks/tasks/warning.py | 19 +++++--- jarvis_tasks/util.py | 20 ++++++++ 4 files changed, 113 insertions(+), 63 deletions(-) create mode 100644 jarvis_tasks/util.py diff --git a/jarvis_tasks/tasks/ban.py b/jarvis_tasks/tasks/ban.py index d5af4f0..3e94ef5 100644 --- a/jarvis_tasks/tasks/ban.py +++ b/jarvis_tasks/tasks/ban.py @@ -5,9 +5,32 @@ from logging import Logger from dis_snek import Snake from dis_snek.client.errors import NotFound +from dis_snek.models.discord.guild import Guild +from dis_snek.models.discord.user import User from jarvis_core.db import q from jarvis_core.db.models import Ban, Unban +from jarvis_tasks.util import runat + + +async def _unban(bot: int, guild: Guild, user: User, ban: Ban, logger: Logger) -> None: + if guild and user: + logger.debug(f"Unbanning user {user.id} from guild {guild.id}") + try: + await guild.unban(user=user, reason="JARVIS tempban expired") + except NotFound: + logger.debug(f"User {user.id} not banned from guild {guild.id}") + ban.active = False + await ban.commit() + await Unban( + user=user.id, + guild=guild.id, + username=user.username, + discrim=user.discriminator, + admin=bot, + reason="Ban expired", + ).commit() + async def unban(bot: Snake, logger: Logger) -> None: """ @@ -18,30 +41,14 @@ async def unban(bot: Snake, logger: Logger) -> None: logger: Global logger """ while True: - max_time = datetime.utcnow() + timedelta(minutes=10) - bans = Ban.find(q(type="temp", active=True)) + max_ts = datetime.utcnow() + timedelta(minutes=9) + bans = Ban.find(q(type="temp", active=True, duration__lte=max_ts)) async for ban in bans: - if ban.created_at + timedelta(hours=ban.duration) < max_time: - guild = await bot.fetch_guild(ban.guild) - user = await bot.fetch_user(ban.user) - if guild and user: - logger.debug(f"Unbanned user {user.id} from guild {guild.id}") - try: - await guild.unban(user=user, reason="JARVIS tempban expired") - except NotFound: - logger.debug(f"User {user.id} not banned from guild {guild.id}") - - ban.update(q(active=False)) - await ban.commit() - u = Unban( - user=user.id, - guild=guild.id, - username=user.username, - discrim=user.discriminator, - admin=bot.user.id, - reason="Ban expired", - ) - await u.commit() + guild = await bot.fetch_guild(ban.guild) + user = await bot.fetch_user(ban.user) + coro = _unban(bot.user.id, guild, user, ban, logger) + when = ban.created_at + timedelta(hours=ban.duration) + asyncio.create_task(runat(when, coro, logger)) # Check ever 10 minutes await asyncio.sleep(600) diff --git a/jarvis_tasks/tasks/reminder.py b/jarvis_tasks/tasks/reminder.py index bc01460..9357a6d 100644 --- a/jarvis_tasks/tasks/reminder.py +++ b/jarvis_tasks/tasks/reminder.py @@ -2,12 +2,57 @@ import asyncio from datetime import datetime, timedelta from logging import Logger +from typing import Optional from dis_snek import Snake +from dis_snek.models.discord.channel import GuildText +from dis_snek.models.discord.embed import Embed +from dis_snek.models.discord.user import User from jarvis_core.db import q from jarvis_core.db.models import Reminder from jarvis_core.util import build_embed +from jarvis_tasks.util import runat + + +async def _remind( + user: User, + reminder: Reminder, + embed: Embed, + logger: Logger, + channel: Optional[GuildText] = None, +) -> None: + delete = True + try: + await user.send(embed=embed) + logger.debug(f"Reminder {reminder.id} send to user") + except Exception: + logger.debug("Failed to DM user, falling back to channel") + if channel: + member = await channel.guild.fetch_member(user.id) + if not member: + logger.debug("User no longer in origin guild") + else: + if channel and not reminder.private: + await channel.send(f"{member.mention}", embed=embed) + logger.debug(f"Reminder {reminder.id} sent to origin channel") + elif channel: + await channel.send( + f"{member.mention}, you had a private reminder set for now," + " but I couldn't send it to you.\n" + f"Use `/reminder fetch {str(reminder.id)}` to view" + ) + logger.debug( + f"Reminder {reminder.id} private, sent notification to origin channel" + ) + reminder.active = False + await reminder.commit() + delete = False + else: + logger.warning(f"Reminder {reminder.id} failed, no way to contact user.") + if delete: + await reminder.delete() + async def remind(bot: Snake, logger: Logger) -> None: """ @@ -18,9 +63,8 @@ async def remind(bot: Snake, logger: Logger) -> None: logger: Global logger """ while True: - reminders = Reminder.find( - q(remind_at__lte=datetime.utcnow() + timedelta(seconds=5), active=True) - ) + max_ts = datetime.utcnow() + timedelta(seconds=5) + reminders = Reminder.find(q(remind_at__lte=max_ts, active=True)) async for reminder in reminders: user = await bot.fetch_user(reminder.user) if not user: @@ -37,37 +81,9 @@ async def remind(bot: Snake, logger: Logger) -> None: embed.set_thumbnail(url=user.avatar.url) - try: - await user.send(embed=embed) - logger.info(f"Reminder {reminder.id} sent to user") - await reminder.delete() - except Exception: - logger.info("User has closed DMs") - guild = await bot.fetch_guild(reminder.guild) - member = await bot.fetch_member(user.id) - if not member: - logger.warning("User no longer member of origin guild, deleting reminder") - await reminder.delete() - continue - channel = await guild.fetch_channel(reminder.channel) if guild else None - if channel and not reminder.private: - await channel.send(f"{member.mention}", embed=embed) - logger.debug(f"Reminder {reminder.id} sent to origin channel") - await reminder.delete() - elif channel: - await channel.send( - f"{member.mention}, you had a private reminder set for now," - " but I couldn't send it to you.\n" - f"Use `/reminder fetch {str(reminder.id)}` to view" - ) - logger.info( - f"Reminder {reminder.id} private, sent notification to origin channel" - ) - reminder.update(q(active=False)) - await reminder.commit() - else: - logger.warning("No way to contact user, deleting reminder") - await reminder.delete() + channel = await bot.fetch_channel(reminder.channel) + coro = _remind(user, reminder, embed, logger, channel) + asyncio.create_task(runat(reminder.remind_at, coro, logger)) # Check every 5 seconds await asyncio.sleep(5) diff --git a/jarvis_tasks/tasks/warning.py b/jarvis_tasks/tasks/warning.py index 5a17f0f..669ff93 100644 --- a/jarvis_tasks/tasks/warning.py +++ b/jarvis_tasks/tasks/warning.py @@ -7,6 +7,14 @@ from dis_snek import Snake from jarvis_core.db import q from jarvis_core.db.models import Warning +from jarvis_tasks.util import runat + + +async def _unwarn(warn: Warning, logger: Logger) -> None: + logger.debug(f"Deactivating warning {warn.id}") + warn.update(q(active=False)) + await warn.commit() + async def unwarn(bot: Snake, logger: Logger) -> None: """ @@ -17,12 +25,11 @@ async def unwarn(bot: Snake, logger: Logger) -> None: logger: Global logger """ while True: - warns = Warning.find(q(active=True)) + max_ts = datetime.utcnow() + timedelta(minutes=55) + warns = Warning.find(q(active=True, created_at__lte=max_ts)) async for warn in warns: - if warn.created_at + timedelta(hours=warn.duration) < datetime.utcnow(): - logger.debug(f"Deactivating warning {warn.id}") - warn.update(q(active=False)) - await warn.commit() - + coro = _unwarn(warn, logger) + when = warn.created_at + timedelta(hours=warn.duration) + asyncio.create_task(runat(when, coro, logger)) # Check every hour await asyncio.sleep(3600) diff --git a/jarvis_tasks/util.py b/jarvis_tasks/util.py new file mode 100644 index 0000000..5910065 --- /dev/null +++ b/jarvis_tasks/util.py @@ -0,0 +1,20 @@ +"""JARVIS task utilities.""" +import asyncio +from datetime import datetime +from logging import Logger +from typing import Coroutine + + +async def runat(when: datetime, coro: Coroutine, logger: Logger) -> None: + """ + Run a task at a scheduled time. + + Args: + when: When to run the task + coro: Coroutine to execute + logger: Global logger + """ + logger.debug(f"Scheduling task {coro.__name__} for {when.isoformat()}") + delay = when - datetime.utcnow() + await asyncio.sleep(delay.total_seconds()) + await coro