From be3d5387a5abd094c2847f1828b227426ff738fa Mon Sep 17 00:00:00 2001 From: zevaryx Date: Tue, 19 Apr 2022 14:59:31 -0600 Subject: [PATCH] Add queues to prevent re-processing tasks --- jarvis_tasks/tasks/ban.py | 6 ++++++ jarvis_tasks/tasks/lock.py | 6 ++++++ jarvis_tasks/tasks/lockdown.py | 6 ++++++ jarvis_tasks/tasks/reminder.py | 6 ++++++ jarvis_tasks/tasks/warning.py | 6 ++++++ 5 files changed, 30 insertions(+) diff --git a/jarvis_tasks/tasks/ban.py b/jarvis_tasks/tasks/ban.py index 9c9ca10..9b5a3c5 100644 --- a/jarvis_tasks/tasks/ban.py +++ b/jarvis_tasks/tasks/ban.py @@ -12,6 +12,8 @@ from jarvis_core.db.models import Ban, Unban from jarvis_tasks.util import runat +queue = [] + async def _unban(bot: int, guild: Guild, user: User, ban: Ban, logger: Logger) -> None: if guild and user: @@ -30,6 +32,7 @@ async def _unban(bot: int, guild: Guild, user: User, ban: Ban, logger: Logger) - admin=bot, reason="Ban expired", ).commit() + queue.remove(ban.id) async def unban(bot: Snake, logger: Logger) -> None: @@ -44,11 +47,14 @@ async def unban(bot: Snake, logger: Logger) -> None: max_ts = datetime.now(tz=timezone.utc) + timedelta(minutes=9) bans = Ban.find(q(type="temp", active=True, duration__lte=max_ts)) async for ban in bans: + if ban.id in queue: + continue guild = await bot.fetch_guild(ban.guild) user = await bot.fetch_user(ban.user) coro = _unban(bot.user.id, guild, user, ban, logger) when = ban.created_at + timedelta(hours=ban.duration) asyncio.create_task(runat(when, coro, logger)) + queue.append(ban.id) # Check ever 10 minutes await asyncio.sleep(600) diff --git a/jarvis_tasks/tasks/lock.py b/jarvis_tasks/tasks/lock.py index beb468a..df607af 100644 --- a/jarvis_tasks/tasks/lock.py +++ b/jarvis_tasks/tasks/lock.py @@ -11,6 +11,8 @@ from jarvis_core.db.models import Lock from jarvis_tasks.util import runat +queue = [] + async def _unlock(channel: GuildChannel, lock: Lock, logger: Logger) -> None: logger.debug(f"Deactivating lock {lock.id}") @@ -31,6 +33,7 @@ async def _unlock(channel: GuildChannel, lock: Lock, logger: Logger) -> None: logger.debug("Locked channel deleted, ignoring error") lock.active = False await lock.commit() + queue.remove(lock.id) async def unlock(bot: Snake, logger: Logger) -> None: @@ -45,10 +48,13 @@ async def unlock(bot: Snake, logger: Logger) -> None: max_ts = datetime.now(tz=timezone.utc) + timedelta(seconds=55) locks = Lock.find(q(active=True, created_at__lte=max_ts)) async for lock in locks: + if lock.id in queue: + continue guild = await bot.fetch_guild(lock.guild) channel = await guild.fetch_channel(lock.channel) coro = _unlock(channel, lock, logger) when = lock.created_at + timedelta(minutes=lock.duration) asyncio.create_task(runat(when, coro, logger)) + queue.append(lock.id) await asyncio.sleep(delay=60) diff --git a/jarvis_tasks/tasks/lockdown.py b/jarvis_tasks/tasks/lockdown.py index 2e2087b..13c4061 100644 --- a/jarvis_tasks/tasks/lockdown.py +++ b/jarvis_tasks/tasks/lockdown.py @@ -11,6 +11,8 @@ from jarvis_core.db.models import Lockdown from jarvis_tasks.util import runat +queue = [] + async def _lift(role: Role, lock: Lockdown, logger: Logger) -> None: logger.debug(f"Lifting lockdown {lock.id}") @@ -18,6 +20,7 @@ async def _lift(role: Role, lock: Lockdown, logger: Logger) -> None: await role.edit(permissions=original_perms) lock.active = False await lock.commit() + queue.remove(lock.id) async def lift(bot: Snake, logger: Logger) -> None: @@ -32,10 +35,13 @@ async def lift(bot: Snake, logger: Logger) -> None: max_ts = datetime.now(tz=timezone.utc) + timedelta(seconds=55) locks = Lockdown.find(q(active=True, created_at__lte=max_ts)) async for lock in locks: + if lock.id in queue: + continue guild = await bot.fetch_guild(lock.guild) role = await guild.fetch_role(guild.id) coro = _lift(role, lock, logger) when = lock.created_at + timedelta(minutes=lock.duration) asyncio.create_task(runat(when, coro, logger)) + queue.append(lock.id) await asyncio.sleep(delay=60) diff --git a/jarvis_tasks/tasks/reminder.py b/jarvis_tasks/tasks/reminder.py index 8196793..e823bda 100644 --- a/jarvis_tasks/tasks/reminder.py +++ b/jarvis_tasks/tasks/reminder.py @@ -14,6 +14,8 @@ from jarvis_core.util import build_embed from jarvis_tasks.util import runat +queue = [] + async def _remind( user: User, @@ -52,6 +54,7 @@ async def _remind( logger.warning(f"Reminder {reminder.id} failed, no way to contact user.") if delete: await reminder.delete() + queue.remove(reminder.id) async def remind(bot: Snake, logger: Logger) -> None: @@ -66,6 +69,8 @@ async def remind(bot: Snake, logger: Logger) -> None: max_ts = datetime.now(tz=timezone.utc) + timedelta(seconds=5) reminders = Reminder.find(q(remind_at__lte=max_ts, active=True)) async for reminder in reminders: + if reminder.id in queue: + continue user = await bot.fetch_user(reminder.user) if not user: logger.warning(f"Failed to get user with ID {reminder.user}") @@ -84,6 +89,7 @@ async def remind(bot: Snake, logger: Logger) -> None: channel = await bot.fetch_channel(reminder.channel) coro = _remind(user, reminder, embed, logger, channel) asyncio.create_task(runat(reminder.remind_at, coro, logger)) + queue.append(reminder.id) # Check every 5 seconds await asyncio.sleep(5) diff --git a/jarvis_tasks/tasks/warning.py b/jarvis_tasks/tasks/warning.py index c895635..c8012a3 100644 --- a/jarvis_tasks/tasks/warning.py +++ b/jarvis_tasks/tasks/warning.py @@ -9,11 +9,14 @@ from jarvis_core.db.models import Warning from jarvis_tasks.util import runat +queue = [] + async def _unwarn(warn: Warning, logger: Logger) -> None: logger.debug(f"Deactivating warning {warn.id}") warn.active = False await warn.commit() + queue.remove(warn.id) async def unwarn(bot: Snake, logger: Logger) -> None: @@ -28,8 +31,11 @@ async def unwarn(bot: Snake, logger: Logger) -> None: max_ts = datetime.now(tz=timezone.utc) + timedelta(minutes=55) warns = Warning.find(q(active=True, created_at__lte=max_ts)) async for warn in warns: + if warn.id in queue: + continue coro = _unwarn(warn, logger) when = warn.created_at + timedelta(hours=warn.duration) asyncio.create_task(runat(when, coro, logger)) + queue.append(warn.id) # Check every hour await asyncio.sleep(3600)