"""JARVIS Reddit cog.""" import asyncio import logging from asyncpraw import Reddit from asyncprawcore.exceptions import Forbidden, NotFound, Redirect from dis_snek import InteractionContext, Permissions, Scale, Snake from dis_snek.client.utils.misc_utils import get from dis_snek.models.discord.channel import ChannelTypes, GuildText from dis_snek.models.discord.components import ActionRow, Select, SelectOption from dis_snek.models.snek.application_commands import ( OptionTypes, SlashCommand, slash_option, ) from dis_snek.models.snek.command import check from jarvis_core.db import q from jarvis_core.db.models import Subreddit, SubredditFollow from jarvis import const from jarvis.config import JarvisConfig from jarvis.utils.permissions import admin_or_permissions DEFAULT_USER_AGENT = f"python:JARVIS:{const.__version__} (by u/zevaryx)" class RedditCog(Scale): """JARVIS Reddit Cog.""" def __init__(self, bot: Snake): self.bot = bot self.logger = logging.getLogger(__name__) config = JarvisConfig.from_yaml() config.reddit["user_agent"] = config.reddit.get("user_agent", DEFAULT_USER_AGENT) self.api = Reddit(**config.reddit) reddit = SlashCommand(name="reddit", description="Manage Reddit follows") @reddit.subcommand(sub_cmd_name="follow", sub_cmd_description="Follow a Subreddit") @slash_option( name="name", description="Subreddit display name", opt_type=OptionTypes.STRING, required=True, ) @slash_option( name="channel", description="Channel to post to", opt_type=OptionTypes.CHANNEL, channel_types=[ChannelTypes.GUILD_TEXT], required=True, ) @check(admin_or_permissions(Permissions.MANAGE_GUILD)) async def _reddit_follow(self, ctx: InteractionContext, name: str, channel: GuildText) -> None: name = name.replace("r/", "") if len(name) > 20 or len(name) < 3: await ctx.send("Invalid Subreddit name", ephemeral=True) return if not isinstance(channel, GuildText): await ctx.send("Channel must be a text channel", ephemeral=True) return try: subreddit = await self.api.subreddit(name) await subreddit.load() except (NotFound, Forbidden, Redirect) as e: self.logger.debug(f"Subreddit {name} raised {e.__class__.__name__} on add") await ctx.send("Subreddit may be private, quarantined, or nonexistent.", ephemeral=True) return exists = await SubredditFollow.find_one( q(display_name=subreddit.display_name, guild=ctx.guild.id) ) if exists: await ctx.send("Subreddit already being followed in this guild", ephemeral=True) return count = len([i async for i in SubredditFollow.find(q(guild=ctx.guild.id))]) if count >= 12: await ctx.send("Cannot follow more than 12 Subreddits", ephemeral=True) return if subreddit.over18 and not channel.nsfw: await ctx.send( "Subreddit is nsfw, but channel is not. Mark the channel NSFW first.", ephemeral=True, ) return sr = await Subreddit.find_one(q(display_name=subreddit.display_name)) if not sr: sr = Subreddit(display_name=subreddit.display_name, over18=subreddit.over18) await sr.commit() srf = SubredditFollow( display_name=subreddit.display_name, channel=channel.id, guild=ctx.guild.id, admin=ctx.author.id, ) await srf.commit() await ctx.send(f"Now following `r/{name}` in {channel.mention}") @reddit.subcommand(sub_cmd_name="unfollow", sub_cmd_description="Unfollow Subreddits") @check(admin_or_permissions(Permissions.MANAGE_GUILD)) async def _subreddit_unfollow(self, ctx: InteractionContext) -> None: subs = SubredditFollow.find(q(guild=ctx.guild.id)) subreddits = [] async for sub in subs: subreddits.append(sub) if not subreddits: await ctx.send("You need to follow a Subreddit first", ephemeral=True) return options = [] names = [] for idx, subreddit in enumerate(subreddits): sub = await Subreddit.find_one(q(display_name=subreddit.display_name)) names.append(sub.display_name) option = SelectOption(label=sub.display_name, value=str(idx)) options.append(option) select = Select( options=options, custom_id="to_delete", min_values=1, max_values=len(subreddits) ) components = [ActionRow(select)] block = "\n".join(x for x in names) message = await ctx.send( content=f"You are following the following subreddits:\n```\n{block}\n```\n\n" "Please choose subreddits to unfollow", components=components, ) try: context = await self.bot.wait_for_component( check=lambda x: ctx.author.id == x.context.author.id, messages=message, timeout=60 * 5, ) for to_delete in context.context.values: follow = get(subreddits, guild=ctx.guild.id, display_name=names[int(to_delete)]) try: await follow.delete() except Exception: self.logger.debug("Ignoring deletion error") for row in components: for component in row.components: component.disabled = True block = "\n".join(names[int(x)] for x in context.context.values) await context.context.edit_origin( content=f"Unfollowed the following:\n```\n{block}\n```", components=components ) except asyncio.TimeoutError: for row in components: for component in row.components: component.disabled = True await message.edit(components=components) def setup(bot: Snake) -> None: """Add RedditCog to JARVIS""" if JarvisConfig.from_yaml().reddit: RedditCog(bot)