Add new scheduler for better accuracy
This commit is contained in:
parent
75338b4a79
commit
4017c358a4
4 changed files with 113 additions and 63 deletions
|
@ -5,9 +5,32 @@ from logging import Logger
|
|||
|
||||
from dis_snek import Snake
|
||||
from dis_snek.client.errors import NotFound
|
||||
from dis_snek.models.discord.guild import Guild
|
||||
from dis_snek.models.discord.user import User
|
||||
from jarvis_core.db import q
|
||||
from jarvis_core.db.models import Ban, Unban
|
||||
|
||||
from jarvis_tasks.util import runat
|
||||
|
||||
|
||||
async def _unban(bot: int, guild: Guild, user: User, ban: Ban, logger: Logger) -> None:
|
||||
if guild and user:
|
||||
logger.debug(f"Unbanning user {user.id} from guild {guild.id}")
|
||||
try:
|
||||
await guild.unban(user=user, reason="JARVIS tempban expired")
|
||||
except NotFound:
|
||||
logger.debug(f"User {user.id} not banned from guild {guild.id}")
|
||||
ban.active = False
|
||||
await ban.commit()
|
||||
await Unban(
|
||||
user=user.id,
|
||||
guild=guild.id,
|
||||
username=user.username,
|
||||
discrim=user.discriminator,
|
||||
admin=bot,
|
||||
reason="Ban expired",
|
||||
).commit()
|
||||
|
||||
|
||||
async def unban(bot: Snake, logger: Logger) -> None:
|
||||
"""
|
||||
|
@ -18,30 +41,14 @@ async def unban(bot: Snake, logger: Logger) -> None:
|
|||
logger: Global logger
|
||||
"""
|
||||
while True:
|
||||
max_time = datetime.utcnow() + timedelta(minutes=10)
|
||||
bans = Ban.find(q(type="temp", active=True))
|
||||
max_ts = datetime.utcnow() + timedelta(minutes=9)
|
||||
bans = Ban.find(q(type="temp", active=True, duration__lte=max_ts))
|
||||
async for ban in bans:
|
||||
if ban.created_at + timedelta(hours=ban.duration) < max_time:
|
||||
guild = await bot.fetch_guild(ban.guild)
|
||||
user = await bot.fetch_user(ban.user)
|
||||
if guild and user:
|
||||
logger.debug(f"Unbanned user {user.id} from guild {guild.id}")
|
||||
try:
|
||||
await guild.unban(user=user, reason="JARVIS tempban expired")
|
||||
except NotFound:
|
||||
logger.debug(f"User {user.id} not banned from guild {guild.id}")
|
||||
|
||||
ban.update(q(active=False))
|
||||
await ban.commit()
|
||||
u = Unban(
|
||||
user=user.id,
|
||||
guild=guild.id,
|
||||
username=user.username,
|
||||
discrim=user.discriminator,
|
||||
admin=bot.user.id,
|
||||
reason="Ban expired",
|
||||
)
|
||||
await u.commit()
|
||||
coro = _unban(bot.user.id, guild, user, ban, logger)
|
||||
when = ban.created_at + timedelta(hours=ban.duration)
|
||||
asyncio.create_task(runat(when, coro, logger))
|
||||
|
||||
# Check ever 10 minutes
|
||||
await asyncio.sleep(600)
|
||||
|
|
|
@ -2,12 +2,57 @@
|
|||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from logging import Logger
|
||||
from typing import Optional
|
||||
|
||||
from dis_snek import Snake
|
||||
from dis_snek.models.discord.channel import GuildText
|
||||
from dis_snek.models.discord.embed import Embed
|
||||
from dis_snek.models.discord.user import User
|
||||
from jarvis_core.db import q
|
||||
from jarvis_core.db.models import Reminder
|
||||
from jarvis_core.util import build_embed
|
||||
|
||||
from jarvis_tasks.util import runat
|
||||
|
||||
|
||||
async def _remind(
|
||||
user: User,
|
||||
reminder: Reminder,
|
||||
embed: Embed,
|
||||
logger: Logger,
|
||||
channel: Optional[GuildText] = None,
|
||||
) -> None:
|
||||
delete = True
|
||||
try:
|
||||
await user.send(embed=embed)
|
||||
logger.debug(f"Reminder {reminder.id} send to user")
|
||||
except Exception:
|
||||
logger.debug("Failed to DM user, falling back to channel")
|
||||
if channel:
|
||||
member = await channel.guild.fetch_member(user.id)
|
||||
if not member:
|
||||
logger.debug("User no longer in origin guild")
|
||||
else:
|
||||
if channel and not reminder.private:
|
||||
await channel.send(f"{member.mention}", embed=embed)
|
||||
logger.debug(f"Reminder {reminder.id} sent to origin channel")
|
||||
elif channel:
|
||||
await channel.send(
|
||||
f"{member.mention}, you had a private reminder set for now,"
|
||||
" but I couldn't send it to you.\n"
|
||||
f"Use `/reminder fetch {str(reminder.id)}` to view"
|
||||
)
|
||||
logger.debug(
|
||||
f"Reminder {reminder.id} private, sent notification to origin channel"
|
||||
)
|
||||
reminder.active = False
|
||||
await reminder.commit()
|
||||
delete = False
|
||||
else:
|
||||
logger.warning(f"Reminder {reminder.id} failed, no way to contact user.")
|
||||
if delete:
|
||||
await reminder.delete()
|
||||
|
||||
|
||||
async def remind(bot: Snake, logger: Logger) -> None:
|
||||
"""
|
||||
|
@ -18,9 +63,8 @@ async def remind(bot: Snake, logger: Logger) -> None:
|
|||
logger: Global logger
|
||||
"""
|
||||
while True:
|
||||
reminders = Reminder.find(
|
||||
q(remind_at__lte=datetime.utcnow() + timedelta(seconds=5), active=True)
|
||||
)
|
||||
max_ts = datetime.utcnow() + timedelta(seconds=5)
|
||||
reminders = Reminder.find(q(remind_at__lte=max_ts, active=True))
|
||||
async for reminder in reminders:
|
||||
user = await bot.fetch_user(reminder.user)
|
||||
if not user:
|
||||
|
@ -37,37 +81,9 @@ async def remind(bot: Snake, logger: Logger) -> None:
|
|||
|
||||
embed.set_thumbnail(url=user.avatar.url)
|
||||
|
||||
try:
|
||||
await user.send(embed=embed)
|
||||
logger.info(f"Reminder {reminder.id} sent to user")
|
||||
await reminder.delete()
|
||||
except Exception:
|
||||
logger.info("User has closed DMs")
|
||||
guild = await bot.fetch_guild(reminder.guild)
|
||||
member = await bot.fetch_member(user.id)
|
||||
if not member:
|
||||
logger.warning("User no longer member of origin guild, deleting reminder")
|
||||
await reminder.delete()
|
||||
continue
|
||||
channel = await guild.fetch_channel(reminder.channel) if guild else None
|
||||
if channel and not reminder.private:
|
||||
await channel.send(f"{member.mention}", embed=embed)
|
||||
logger.debug(f"Reminder {reminder.id} sent to origin channel")
|
||||
await reminder.delete()
|
||||
elif channel:
|
||||
await channel.send(
|
||||
f"{member.mention}, you had a private reminder set for now,"
|
||||
" but I couldn't send it to you.\n"
|
||||
f"Use `/reminder fetch {str(reminder.id)}` to view"
|
||||
)
|
||||
logger.info(
|
||||
f"Reminder {reminder.id} private, sent notification to origin channel"
|
||||
)
|
||||
reminder.update(q(active=False))
|
||||
await reminder.commit()
|
||||
else:
|
||||
logger.warning("No way to contact user, deleting reminder")
|
||||
await reminder.delete()
|
||||
channel = await bot.fetch_channel(reminder.channel)
|
||||
coro = _remind(user, reminder, embed, logger, channel)
|
||||
asyncio.create_task(runat(reminder.remind_at, coro, logger))
|
||||
|
||||
# Check every 5 seconds
|
||||
await asyncio.sleep(5)
|
||||
|
|
|
@ -7,6 +7,14 @@ from dis_snek import Snake
|
|||
from jarvis_core.db import q
|
||||
from jarvis_core.db.models import Warning
|
||||
|
||||
from jarvis_tasks.util import runat
|
||||
|
||||
|
||||
async def _unwarn(warn: Warning, logger: Logger) -> None:
|
||||
logger.debug(f"Deactivating warning {warn.id}")
|
||||
warn.update(q(active=False))
|
||||
await warn.commit()
|
||||
|
||||
|
||||
async def unwarn(bot: Snake, logger: Logger) -> None:
|
||||
"""
|
||||
|
@ -17,12 +25,11 @@ async def unwarn(bot: Snake, logger: Logger) -> None:
|
|||
logger: Global logger
|
||||
"""
|
||||
while True:
|
||||
warns = Warning.find(q(active=True))
|
||||
max_ts = datetime.utcnow() + timedelta(minutes=55)
|
||||
warns = Warning.find(q(active=True, created_at__lte=max_ts))
|
||||
async for warn in warns:
|
||||
if warn.created_at + timedelta(hours=warn.duration) < datetime.utcnow():
|
||||
logger.debug(f"Deactivating warning {warn.id}")
|
||||
warn.update(q(active=False))
|
||||
await warn.commit()
|
||||
|
||||
coro = _unwarn(warn, logger)
|
||||
when = warn.created_at + timedelta(hours=warn.duration)
|
||||
asyncio.create_task(runat(when, coro, logger))
|
||||
# Check every hour
|
||||
await asyncio.sleep(3600)
|
||||
|
|
20
jarvis_tasks/util.py
Normal file
20
jarvis_tasks/util.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
"""JARVIS task utilities."""
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from logging import Logger
|
||||
from typing import Coroutine
|
||||
|
||||
|
||||
async def runat(when: datetime, coro: Coroutine, logger: Logger) -> None:
|
||||
"""
|
||||
Run a task at a scheduled time.
|
||||
|
||||
Args:
|
||||
when: When to run the task
|
||||
coro: Coroutine to execute
|
||||
logger: Global logger
|
||||
"""
|
||||
logger.debug(f"Scheduling task {coro.__name__} for {when.isoformat()}")
|
||||
delay = when - datetime.utcnow()
|
||||
await asyncio.sleep(delay.total_seconds())
|
||||
await coro
|
Loading…
Add table
Reference in a new issue