jarvis-bot/jarvis/cogs/reddit.py
2022-05-01 15:38:41 -06:00

167 lines
6.1 KiB
Python

"""JARVIS Reddit cog."""
import asyncio
import logging
from asyncpraw import Reddit
from asyncprawcore.exceptions import Forbidden, NotFound, Redirect
from dis_snek import InteractionContext, Permissions, Scale, Snake
from dis_snek.client.utils.misc_utils import get
from dis_snek.models.discord.channel import ChannelTypes, GuildText
from dis_snek.models.discord.components import ActionRow, Select, SelectOption
from dis_snek.models.snek.application_commands import (
OptionTypes,
SlashCommand,
slash_option,
)
from dis_snek.models.snek.command import check
from jarvis_core.db import q
from jarvis_core.db.models import Subreddit, SubredditFollow
from jarvis import const
from jarvis.config import JarvisConfig
from jarvis.utils.permissions import admin_or_permissions
# Fallback user agent: RedditCog.__init__ uses this value only when the
# "reddit" config section has no "user_agent" entry of its own.
DEFAULT_USER_AGENT = f"python:JARVIS:{const.__version__} (by u/zevaryx)"
class RedditCog(Scale):
    """JARVIS Reddit Cog.

    Exposes the ``reddit follow`` and ``reddit unfollow`` slash subcommands,
    which create and delete per-guild ``SubredditFollow`` records.
    """

    def __init__(self, bot: Snake):
        """Store the bot and build the asyncpraw client from the YAML config.

        Args:
            bot: The bot instance this cog belongs to.
        """
        self.bot = bot
        self.logger = logging.getLogger(__name__)
        config = JarvisConfig.from_yaml()
        # Only fall back to the module default when the config omits a user agent.
        config.reddit["user_agent"] = config.reddit.get("user_agent", DEFAULT_USER_AGENT)
        self.api = Reddit(**config.reddit)

    reddit = SlashCommand(name="reddit", description="Manage Reddit follows")

    @staticmethod
    def _disable_components(components: list) -> None:
        """Mark every component in every action row of *components* as disabled."""
        for row in components:
            for component in row.components:
                component.disabled = True

    @reddit.subcommand(sub_cmd_name="follow", sub_cmd_description="Follow a Subreddit")
    @slash_option(
        name="name",
        description="Subreddit display name",
        opt_type=OptionTypes.STRING,
        required=True,
    )
    @slash_option(
        name="channel",
        description="Channel to post to",
        opt_type=OptionTypes.CHANNEL,
        channel_types=[ChannelTypes.GUILD_TEXT],
        required=True,
    )
    @check(admin_or_permissions(Permissions.MANAGE_GUILD))
    async def _reddit_follow(self, ctx: InteractionContext, name: str, channel: GuildText) -> None:
        """Follow a subreddit in the invoking guild and post it to *channel*.

        Sends an ephemeral error and returns early when the name length is
        out of range, the channel is not a text channel, the subreddit cannot
        be loaded, the guild already follows it, the guild already has 12
        follows, or the subreddit is NSFW while the channel is not.

        Args:
            ctx: Interaction context of the slash command.
            name: Subreddit name, optionally prefixed with ``r/``.
            channel: Channel recorded as the follow's target.
        """
        # Strip only a leading "r/"; str.replace would also mangle input such
        # as "user/foo" into "usefoo" and silently look up the wrong subreddit.
        if name.startswith("r/"):
            name = name[2:]
        if len(name) > 20 or len(name) < 3:
            await ctx.send("Invalid Subreddit name", ephemeral=True)
            return

        if not isinstance(channel, GuildText):
            await ctx.send("Channel must be a text channel", ephemeral=True)
            return

        try:
            subreddit = await self.api.subreddit(name)
            await subreddit.load()
        except (NotFound, Forbidden, Redirect) as e:
            self.logger.debug(f"Subreddit {name} raised {e.__class__.__name__} on add")
            await ctx.send("Subreddit may be private, quarantined, or nonexistent.", ephemeral=True)
            return

        exists = await SubredditFollow.find_one(
            q(display_name=subreddit.display_name, guild=ctx.guild.id)
        )
        if exists:
            await ctx.send("Subreddit already being followed in this guild", ephemeral=True)
            return

        count = len([i async for i in SubredditFollow.find(q(guild=ctx.guild.id))])
        if count >= 12:
            await ctx.send("Cannot follow more than 12 Subreddits", ephemeral=True)
            return

        if subreddit.over18 and not channel.nsfw:
            await ctx.send(
                "Subreddit is nsfw, but channel is not. Mark the channel NSFW first.",
                ephemeral=True,
            )
            return

        # Create the shared Subreddit record the first time any guild follows it.
        sr = await Subreddit.find_one(q(display_name=subreddit.display_name))
        if not sr:
            sr = Subreddit(display_name=subreddit.display_name, over18=subreddit.over18)
            await sr.commit()

        srf = SubredditFollow(
            display_name=subreddit.display_name,
            channel=channel.id,
            guild=ctx.guild.id,
            admin=ctx.author.id,
        )
        await srf.commit()

        await ctx.send(f"Now following `r/{name}` in {channel.mention}")

    @reddit.subcommand(sub_cmd_name="unfollow", sub_cmd_description="Unfollow Subreddits")
    @check(admin_or_permissions(Permissions.MANAGE_GUILD))
    async def _subreddit_unfollow(self, ctx: InteractionContext) -> None:
        """Let the invoker pick followed subreddits to remove from this guild.

        Sends a select menu listing the guild's follows, waits up to five
        minutes for the invoking user to choose, deletes the chosen follows,
        and disables the menu afterwards (also on timeout).

        Args:
            ctx: Interaction context of the slash command.
        """
        follows = [follow async for follow in SubredditFollow.find(q(guild=ctx.guild.id))]
        if not follows:
            await ctx.send("You need to follow a Subreddit first", ephemeral=True)
            return

        # The follow record already carries the display name, so no extra
        # Subreddit lookup is needed (and a missing Subreddit row cannot crash us).
        names = [follow.display_name for follow in follows]
        # Each option's value is the index into ``follows`` / ``names``.
        options = [SelectOption(label=label, value=str(idx)) for idx, label in enumerate(names)]

        select = Select(
            options=options, custom_id="to_delete", min_values=1, max_values=len(follows)
        )
        components = [ActionRow(select)]
        block = "\n".join(names)
        message = await ctx.send(
            content=f"You are following the following subreddits:\n```\n{block}\n```\n\n"
            "Please choose subreddits to unfollow",
            components=components,
        )

        try:
            context = await self.bot.wait_for_component(
                # Only the user who ran the command may answer the menu.
                check=lambda x: ctx.author.id == x.context.author.id,
                messages=message,
                timeout=60 * 5,
            )
            chosen = [int(value) for value in context.context.values]
            for idx in chosen:
                try:
                    await follows[idx].delete()
                except Exception:
                    # Deletion stays best-effort, but keep the traceback for debugging.
                    self.logger.debug("Ignoring deletion error", exc_info=True)

            self._disable_components(components)
            block = "\n".join(names[idx] for idx in chosen)
            await context.context.edit_origin(
                content=f"Unfollowed the following:\n```\n{block}\n```", components=components
            )
        except asyncio.TimeoutError:
            # Nobody answered in time: just lock the menu.
            self._disable_components(components)
            await message.edit(components=components)
def setup(bot: Snake) -> None:
    """Attach RedditCog to the bot, but only when a Reddit config is present."""
    reddit_settings = JarvisConfig.from_yaml().reddit
    if not reddit_settings:
        return
    RedditCog(bot)