Add new scheduler for better accuracy

This commit is contained in:
Zeva Rose 2022-03-20 14:32:43 -06:00
parent 75338b4a79
commit 4017c358a4
4 changed files with 113 additions and 63 deletions

View file

@ -5,9 +5,32 @@ from logging import Logger
from dis_snek import Snake from dis_snek import Snake
from dis_snek.client.errors import NotFound from dis_snek.client.errors import NotFound
from dis_snek.models.discord.guild import Guild
from dis_snek.models.discord.user import User
from jarvis_core.db import q from jarvis_core.db import q
from jarvis_core.db.models import Ban, Unban from jarvis_core.db.models import Ban, Unban
from jarvis_tasks.util import runat
async def _unban(bot: int, guild: Guild, user: User, ban: Ban, logger: Logger) -> None:
if guild and user:
logger.debug(f"Unbanning user {user.id} from guild {guild.id}")
try:
await guild.unban(user=user, reason="JARVIS tempban expired")
except NotFound:
logger.debug(f"User {user.id} not banned from guild {guild.id}")
ban.active = False
await ban.commit()
await Unban(
user=user.id,
guild=guild.id,
username=user.username,
discrim=user.discriminator,
admin=bot,
reason="Ban expired",
).commit()
async def unban(bot: Snake, logger: Logger) -> None: async def unban(bot: Snake, logger: Logger) -> None:
""" """
@ -18,30 +41,14 @@ async def unban(bot: Snake, logger: Logger) -> None:
logger: Global logger logger: Global logger
""" """
while True: while True:
max_time = datetime.utcnow() + timedelta(minutes=10) max_ts = datetime.utcnow() + timedelta(minutes=9)
bans = Ban.find(q(type="temp", active=True)) bans = Ban.find(q(type="temp", active=True, duration__lte=max_ts))
async for ban in bans: async for ban in bans:
if ban.created_at + timedelta(hours=ban.duration) < max_time: guild = await bot.fetch_guild(ban.guild)
guild = await bot.fetch_guild(ban.guild) user = await bot.fetch_user(ban.user)
user = await bot.fetch_user(ban.user) coro = _unban(bot.user.id, guild, user, ban, logger)
if guild and user: when = ban.created_at + timedelta(hours=ban.duration)
logger.debug(f"Unbanned user {user.id} from guild {guild.id}") asyncio.create_task(runat(when, coro, logger))
try:
await guild.unban(user=user, reason="JARVIS tempban expired")
except NotFound:
logger.debug(f"User {user.id} not banned from guild {guild.id}")
ban.update(q(active=False))
await ban.commit()
u = Unban(
user=user.id,
guild=guild.id,
username=user.username,
discrim=user.discriminator,
admin=bot.user.id,
reason="Ban expired",
)
await u.commit()
# Check ever 10 minutes # Check ever 10 minutes
await asyncio.sleep(600) await asyncio.sleep(600)

View file

@ -2,12 +2,57 @@
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
from logging import Logger from logging import Logger
from typing import Optional
from dis_snek import Snake from dis_snek import Snake
from dis_snek.models.discord.channel import GuildText
from dis_snek.models.discord.embed import Embed
from dis_snek.models.discord.user import User
from jarvis_core.db import q from jarvis_core.db import q
from jarvis_core.db.models import Reminder from jarvis_core.db.models import Reminder
from jarvis_core.util import build_embed from jarvis_core.util import build_embed
from jarvis_tasks.util import runat
async def _remind(
user: User,
reminder: Reminder,
embed: Embed,
logger: Logger,
channel: Optional[GuildText] = None,
) -> None:
delete = True
try:
await user.send(embed=embed)
logger.debug(f"Reminder {reminder.id} send to user")
except Exception:
logger.debug("Failed to DM user, falling back to channel")
if channel:
member = await channel.guild.fetch_member(user.id)
if not member:
logger.debug("User no longer in origin guild")
else:
if channel and not reminder.private:
await channel.send(f"{member.mention}", embed=embed)
logger.debug(f"Reminder {reminder.id} sent to origin channel")
elif channel:
await channel.send(
f"{member.mention}, you had a private reminder set for now,"
" but I couldn't send it to you.\n"
f"Use `/reminder fetch {str(reminder.id)}` to view"
)
logger.debug(
f"Reminder {reminder.id} private, sent notification to origin channel"
)
reminder.active = False
await reminder.commit()
delete = False
else:
logger.warning(f"Reminder {reminder.id} failed, no way to contact user.")
if delete:
await reminder.delete()
async def remind(bot: Snake, logger: Logger) -> None: async def remind(bot: Snake, logger: Logger) -> None:
""" """
@ -18,9 +63,8 @@ async def remind(bot: Snake, logger: Logger) -> None:
logger: Global logger logger: Global logger
""" """
while True: while True:
reminders = Reminder.find( max_ts = datetime.utcnow() + timedelta(seconds=5)
q(remind_at__lte=datetime.utcnow() + timedelta(seconds=5), active=True) reminders = Reminder.find(q(remind_at__lte=max_ts, active=True))
)
async for reminder in reminders: async for reminder in reminders:
user = await bot.fetch_user(reminder.user) user = await bot.fetch_user(reminder.user)
if not user: if not user:
@ -37,37 +81,9 @@ async def remind(bot: Snake, logger: Logger) -> None:
embed.set_thumbnail(url=user.avatar.url) embed.set_thumbnail(url=user.avatar.url)
try: channel = await bot.fetch_channel(reminder.channel)
await user.send(embed=embed) coro = _remind(user, reminder, embed, logger, channel)
logger.info(f"Reminder {reminder.id} sent to user") asyncio.create_task(runat(reminder.remind_at, coro, logger))
await reminder.delete()
except Exception:
logger.info("User has closed DMs")
guild = await bot.fetch_guild(reminder.guild)
member = await bot.fetch_member(user.id)
if not member:
logger.warning("User no longer member of origin guild, deleting reminder")
await reminder.delete()
continue
channel = await guild.fetch_channel(reminder.channel) if guild else None
if channel and not reminder.private:
await channel.send(f"{member.mention}", embed=embed)
logger.debug(f"Reminder {reminder.id} sent to origin channel")
await reminder.delete()
elif channel:
await channel.send(
f"{member.mention}, you had a private reminder set for now,"
" but I couldn't send it to you.\n"
f"Use `/reminder fetch {str(reminder.id)}` to view"
)
logger.info(
f"Reminder {reminder.id} private, sent notification to origin channel"
)
reminder.update(q(active=False))
await reminder.commit()
else:
logger.warning("No way to contact user, deleting reminder")
await reminder.delete()
# Check every 5 seconds # Check every 5 seconds
await asyncio.sleep(5) await asyncio.sleep(5)

View file

@ -7,6 +7,14 @@ from dis_snek import Snake
from jarvis_core.db import q from jarvis_core.db import q
from jarvis_core.db.models import Warning from jarvis_core.db.models import Warning
from jarvis_tasks.util import runat
async def _unwarn(warn: Warning, logger: Logger) -> None:
logger.debug(f"Deactivating warning {warn.id}")
warn.update(q(active=False))
await warn.commit()
async def unwarn(bot: Snake, logger: Logger) -> None: async def unwarn(bot: Snake, logger: Logger) -> None:
""" """
@ -17,12 +25,11 @@ async def unwarn(bot: Snake, logger: Logger) -> None:
logger: Global logger logger: Global logger
""" """
while True: while True:
warns = Warning.find(q(active=True)) max_ts = datetime.utcnow() + timedelta(minutes=55)
warns = Warning.find(q(active=True, created_at__lte=max_ts))
async for warn in warns: async for warn in warns:
if warn.created_at + timedelta(hours=warn.duration) < datetime.utcnow(): coro = _unwarn(warn, logger)
logger.debug(f"Deactivating warning {warn.id}") when = warn.created_at + timedelta(hours=warn.duration)
warn.update(q(active=False)) asyncio.create_task(runat(when, coro, logger))
await warn.commit()
# Check every hour # Check every hour
await asyncio.sleep(3600) await asyncio.sleep(3600)

20
jarvis_tasks/util.py Normal file
View file

@ -0,0 +1,20 @@
"""JARVIS task utilities."""
import asyncio
from datetime import datetime
from logging import Logger
from typing import Coroutine
async def runat(when: datetime, coro: Coroutine, logger: Logger) -> None:
"""
Run a task at a scheduled time.
Args:
when: When to run the task
coro: Coroutine to execute
logger: Global logger
"""
logger.debug(f"Scheduling task {coro.__name__} for {when.isoformat()}")
delay = when - datetime.utcnow()
await asyncio.sleep(delay.total_seconds())
await coro