167 lines
6.1 KiB
Python
167 lines
6.1 KiB
Python
"""JARVIS Reddit cog."""
|
|
import asyncio
|
|
import logging
|
|
|
|
from asyncpraw import Reddit
|
|
from asyncprawcore.exceptions import Forbidden, NotFound, Redirect
|
|
from dis_snek import InteractionContext, Permissions, Scale, Snake
|
|
from dis_snek.client.utils.misc_utils import get
|
|
from dis_snek.models.discord.channel import ChannelTypes, GuildText
|
|
from dis_snek.models.discord.components import ActionRow, Select, SelectOption
|
|
from dis_snek.models.snek.application_commands import (
|
|
OptionTypes,
|
|
SlashCommand,
|
|
slash_option,
|
|
)
|
|
from dis_snek.models.snek.command import check
|
|
from jarvis_core.db import q
|
|
from jarvis_core.db.models import Subreddit, SubredditFollow
|
|
|
|
from jarvis import const
|
|
from jarvis.config import JarvisConfig
|
|
from jarvis.utils.permissions import admin_or_permissions
|
|
|
|
# Fallback User-Agent for the Reddit client: RedditCog uses it only when the
# `reddit` config section has no `user_agent` key. Embeds the JARVIS version.
DEFAULT_USER_AGENT = f"python:JARVIS:{const.__version__} (by u/zevaryx)"
|
|
|
|
|
|
class RedditCog(Scale):
    """JARVIS Reddit Cog.

    Exposes the ``/reddit follow`` and ``/reddit unfollow`` slash commands, which
    create and delete ``SubredditFollow`` documents for the invoking guild.
    """

    def __init__(self, bot: Snake):
        """Store the bot and build the Reddit API client from the YAML config.

        Args:
            bot: The running bot instance this cog is attached to.
        """
        self.bot = bot
        self.logger = logging.getLogger(__name__)
        config = JarvisConfig.from_yaml()
        # Only fall back to the default User-Agent when the config sets none.
        config.reddit["user_agent"] = config.reddit.get("user_agent", DEFAULT_USER_AGENT)
        self.api = Reddit(**config.reddit)

    reddit = SlashCommand(name="reddit", description="Manage Reddit follows")

    @reddit.subcommand(sub_cmd_name="follow", sub_cmd_description="Follow a Subreddit")
    @slash_option(
        name="name",
        description="Subreddit display name",
        opt_type=OptionTypes.STRING,
        required=True,
    )
    @slash_option(
        name="channel",
        description="Channel to post to",
        opt_type=OptionTypes.CHANNEL,
        channel_types=[ChannelTypes.GUILD_TEXT],
        required=True,
    )
    @check(admin_or_permissions(Permissions.MANAGE_GUILD))
    async def _reddit_follow(self, ctx: InteractionContext, name: str, channel: GuildText) -> None:
        """Record a subreddit follow for the invoking guild.

        Validation happens in this order, each failure answering ephemerally and
        returning early: name length, channel type, subreddit reachability,
        duplicate follow, per-guild follow cap, NSFW mismatch.

        Args:
            ctx: Interaction context of the slash command.
            name: Subreddit name as typed; a leading ``r/`` or ``/r/`` is ignored.
            channel: Text channel the follow is recorded against.
        """
        # Strip only a *leading* prefix. A blanket ``replace("r/", "")`` would also
        # rewrite the middle of the input ("car/pics" -> "capics") and could resolve
        # to a different subreddit than the one the user typed.
        name = name.strip().strip("/")
        if name[:2].lower() == "r/":
            name = name[2:]

        # NOTE(review): Reddit itself may allow 21 characters — confirm before
        # changing the bound; the original 3..20 window is kept here.
        if not 3 <= len(name) <= 20:
            await ctx.send("Invalid Subreddit name", ephemeral=True)
            return

        if not isinstance(channel, GuildText):
            await ctx.send("Channel must be a text channel", ephemeral=True)
            return

        try:
            subreddit = await self.api.subreddit(name)
            await subreddit.load()
        except (NotFound, Forbidden, Redirect) as e:
            self.logger.debug("Subreddit %s raised %s on add", name, e.__class__.__name__)
            await ctx.send("Subreddit may be private, quarantined, or nonexistent.", ephemeral=True)
            return

        exists = await SubredditFollow.find_one(
            q(display_name=subreddit.display_name, guild=ctx.guild.id)
        )
        if exists:
            await ctx.send("Subreddit already being followed in this guild", ephemeral=True)
            return

        count = len([i async for i in SubredditFollow.find(q(guild=ctx.guild.id))])
        if count >= 12:
            await ctx.send("Cannot follow more than 12 Subreddits", ephemeral=True)
            return

        if subreddit.over18 and not channel.nsfw:
            await ctx.send(
                "Subreddit is nsfw, but channel is not. Mark the channel NSFW first.",
                ephemeral=True,
            )
            return

        # Create the shared Subreddit document on first follow by any guild.
        sr = await Subreddit.find_one(q(display_name=subreddit.display_name))
        if not sr:
            sr = Subreddit(display_name=subreddit.display_name, over18=subreddit.over18)
            await sr.commit()

        srf = SubredditFollow(
            display_name=subreddit.display_name,
            channel=channel.id,
            guild=ctx.guild.id,
            admin=ctx.author.id,
        )
        await srf.commit()

        # Report the name that was actually stored, not the raw user input.
        await ctx.send(f"Now following `r/{subreddit.display_name}` in {channel.mention}")

    @reddit.subcommand(sub_cmd_name="unfollow", sub_cmd_description="Unfollow Subreddits")
    @check(admin_or_permissions(Permissions.MANAGE_GUILD))
    async def _subreddit_unfollow(self, ctx: InteractionContext) -> None:
        """Let the invoker pick followed subreddits from a select menu and delete them.

        The menu accepts a response from the invoking user for five minutes; on
        timeout it is simply disabled. Deletions are best-effort: a failure on one
        follow is logged and the remaining selections are still processed.

        Args:
            ctx: Interaction context of the slash command.
        """
        follows = [follow async for follow in SubredditFollow.find(q(guild=ctx.guild.id))]
        if not follows:
            await ctx.send("You need to follow a Subreddit first", ephemeral=True)
            return

        # The follow document already carries the display name. Looking up the
        # Subreddit document just to re-read it crashed with AttributeError when
        # that document was missing, and cost one query per follow.
        names = [follow.display_name for follow in follows]
        options = [
            SelectOption(label=display_name, value=str(idx))
            for idx, display_name in enumerate(names)
        ]

        select = Select(
            options=options, custom_id="to_delete", min_values=1, max_values=len(follows)
        )

        components = [ActionRow(select)]
        block = "\n".join(names)
        message = await ctx.send(
            content=f"You are following the following subreddits:\n```\n{block}\n```\n\n"
            "Please choose subreddits to unfollow",
            components=components,
        )

        # Keep the try narrow: only the wait itself is expected to time out.
        try:
            context = await self.bot.wait_for_component(
                check=lambda x: ctx.author.id == x.context.author.id,
                messages=message,
                timeout=60 * 5,
            )
        except asyncio.TimeoutError:
            context = None

        # The menu is finished whether the user answered or not.
        for row in components:
            for component in row.components:
                component.disabled = True

        if context is None:
            await message.edit(components=components)
            return

        # Option values are indexes into ``follows``/``names`` (see above).
        selected = [int(value) for value in context.context.values]
        for idx in selected:
            try:
                await follows[idx].delete()
            except Exception:
                # Deliberately best-effort, but keep the traceback for debugging.
                self.logger.debug("Ignoring deletion error", exc_info=True)

        block = "\n".join(names[idx] for idx in selected)
        await context.context.edit_origin(
            content=f"Unfollowed the following:\n```\n{block}\n```", components=components
        )
|
|
|
|
|
|
def setup(bot: Snake) -> None:
    """Add RedditCog to JARVIS, but only if the YAML config has a `reddit` section."""
    reddit_settings = JarvisConfig.from_yaml().reddit
    if not reddit_settings:
        # No Reddit credentials configured: leave the cog unloaded.
        return
    RedditCog(bot)
|