jarvis-bot/jarvis/cogs/reddit.py

591 lines
23 KiB
Python

"""JARVIS Reddit cog."""
import asyncio
import logging
import re
from typing import List, Optional
from asyncpraw import Reddit
from asyncpraw.models.reddit.submission import Submission
from asyncpraw.models.reddit.submission import Subreddit as Sub
from asyncprawcore.exceptions import Forbidden, NotFound, Redirect
from jarvis_core.db import q
from jarvis_core.db.models import (
Redditor,
RedditorFollow,
Subreddit,
SubredditFollow,
UserSetting,
)
from naff import Client, Extension, InteractionContext, Permissions
from naff.client.utils.misc_utils import get
from naff.models.discord.channel import ChannelTypes, GuildText
from naff.models.discord.components import ActionRow, Select, SelectOption
from naff.models.discord.embed import Embed, EmbedField
from naff.models.naff.application_commands import (
OptionTypes,
SlashCommand,
SlashCommandChoice,
slash_option,
)
from naff.models.naff.command import check
from jarvis import const
from jarvis.config import JarvisConfig
from jarvis.utils import build_embed
from jarvis.utils.permissions import admin_or_permissions
# User agent used for the Reddit client when the loaded config does not
# provide a "user_agent" entry of its own.
DEFAULT_USER_AGENT = f"python:JARVIS:{const.__version__} (by u/zevaryx)"
# Subreddit display name: one alphanumeric character followed by 2-20
# alphanumerics/underscores. Anchored, so the whole string has to match.
sub_name = re.compile(r"\A[A-Za-z0-9][A-Za-z0-9_]{2,20}\Z")
# Redditor name: letters, digits, underscores and hyphens only. Anchored like
# sub_name so `.match()` rejects input with invalid trailing characters
# instead of accepting anything that merely starts with a valid character.
user_name = re.compile(r"\A[A-Za-z0-9_-]+\Z")
# preview.redd.it link with a query string; group 1 captures the path portion
# in front of the "?" (it must contain a dot).
image_link = re.compile(r"https?://(?:www)?\.?preview\.redd\.it\/(.*\..*)\?.*")
class RedditCog(Extension):
"""JARVIS Reddit Cog."""
def __init__(self, bot: Client):
    """
    Set up the cog.

    Stores the bot, creates the module logger and builds the Reddit API
    client from the ``reddit`` section of the YAML config.

    Args:
        bot: Client this cog is attached to
    """
    self.bot = bot
    self.logger = logging.getLogger(__name__)
    reddit_settings = JarvisConfig.from_yaml().reddit
    # Keep a configured user agent, otherwise fall back to the built-in one
    reddit_settings["user_agent"] = reddit_settings.get("user_agent", DEFAULT_USER_AGENT)
    self.api = Reddit(**reddit_settings)
async def post_embeds(self, sub: Sub, post: Submission) -> Optional[List[Embed]]:
    """
    Build the embeds for a Reddit post.

    Args:
        sub: Subreddit the post is displayed for; provides the footer name
            and, when it has a ``primary_color`` attribute, the embed color
        post: Post to build embeds for

    Returns:
        A list of embeds, or None when the post has neither text content nor
        images. The first embed carries title, text, author and footer; up to
        three further embeds only carry additional images.
    """
    url = "https://reddit.com" + post.permalink
    # NOTE(review): post.author is presumably None for deleted accounts, which
    # would fail here - confirm against asyncpraw before relying on it.
    await post.author.load()
    author_url = f"https://reddit.com/u/{post.author.name}"
    author_icon = post.author.icon_img
    images = []
    title = post.title
    if len(title) > 256:
        # Keep the first 253 characters so title plus ellipsis is 256 long
        title = title[:253] + "..."
    fields = []
    content = ""
    if "crosspost_parent_list" in vars(post):
        # Crossposts are rendered from the first parent submission
        post = await self.api.submission(post.crosspost_parent_list[0]["id"])
        await post.load()
        fields.append(EmbedField(name="Crossposted From", value=post.subreddit_name_prefixed))
        content = f"> **{post.title}**"
    if "url" in vars(post):
        if post.url.endswith(("jpeg", "jpg", "png", "gif")):
            images = [post.url]
    if "media_metadata" in vars(post):
        for k, v in post.media_metadata.items():
            if v["status"] != "valid" or v["m"] not in ["image/jpg", "image/png", "image/gif"]:
                continue
            ext = v["m"].split("/")[-1]
            i_url = f"https://i.redd.it/{k}.{ext}"
            images.append(i_url)
            # Only the first four images are ever attached below
            if len(images) == 4:
                break
    if "selftext" in vars(post) and post.selftext:
        text = post.selftext
        if post.spoiler:
            text = "||" + text + "||"
        # Append the (possibly spoiler-wrapped) text, not the raw selftext,
        # so the opening spoiler marker is not lost
        content += "\n\n" + text
        if len(content) > 900:
            content = content[:900] + "..."
            if post.spoiler:
                # Truncation cut off the closing spoiler marker; restore it
                content += "||"
        content += f"\n\n[View this post]({url})"
    # Rewrite preview.redd.it links line by line to their i.redd.it form
    content = "\n".join(image_link.sub(r"https://i.redd.it/\1", x) for x in content.split("\n"))
    if not images and not content:
        self.logger.debug(f"Post {post.id} had neither content nor images?")
        return None
    color = "#FF4500"
    if "primary_color" in vars(sub):
        color = sub.primary_color
    base_embed = build_embed(
        title=title,
        description=content,
        fields=fields,
        timestamp=post.created_utc,
        url=url,
        color=color,
    )
    base_embed.set_author(name="u/" + post.author.name, url=author_url, icon_url=author_icon)
    base_embed.set_footer(
        text=f"r/{sub.display_name}",
        icon_url="https://www.redditinc.com/assets/images/site/reddit-logo.png",
    )
    embeds = [base_embed]
    if images:
        embeds[0].set_image(url=images[0])
        # Extra images go into bare embeds that share the post URL
        for image in images[1:4]:
            embed = Embed(url=url)
            embed.set_image(url=image)
            embeds.append(embed)
    return embeds
# Base "reddit" slash command; the subcommands below are registered on it
# either directly or through one of its two groups.
reddit = SlashCommand(name="reddit", description="Manage Reddit follows")
# Group holding the "follow redditor" / "follow subreddit" subcommands
follow = reddit.group(name="follow", description="Add a follow")
# Group holding the "unfollow redditor" / "unfollow subreddit" subcommands
unfollow = reddit.group(name="unfollow", description="Remove a follow")
@follow.subcommand(sub_cmd_name="redditor", sub_cmd_description="Follow a Redditor")
@slash_option(
    name="name",
    description="Redditor name",
    opt_type=OptionTypes.STRING,
    required=True,
)
@slash_option(
    name="channel",
    description="Channel to post to",
    opt_type=OptionTypes.CHANNEL,
    channel_types=[ChannelTypes.GUILD_TEXT],
    required=True,
)
@check(admin_or_permissions(Permissions.MANAGE_GUILD))
async def _redditor_follow(
    self, ctx: InteractionContext, name: str, channel: GuildText
) -> None:
    """
    Follow a Redditor for the current guild.

    Validates the name and channel, confirms the Redditor can be loaded from
    Reddit, enforces the per-guild limit of 12 Redditor follows and then
    stores the follow record.

    Args:
        ctx: Interaction context of the command
        name: Redditor name to follow
        channel: Text channel saved on the follow record
    """
    # fullmatch: the whole name must be valid, not only a prefix of it
    if not user_name.fullmatch(name):
        await ctx.send("Invalid Redditor name", ephemeral=True)
        return
    if not isinstance(channel, GuildText):
        await ctx.send("Channel must be a text channel", ephemeral=True)
        return
    try:
        redditor = await self.api.redditor(name)
        await redditor.load()
    except (NotFound, Forbidden, Redirect) as e:
        self.logger.debug(f"Redditor {name} raised {e.__class__.__name__} on add")
        await ctx.send("Redditor may be deleted or nonexistent.", ephemeral=True)
        return
    exists = await RedditorFollow.find_one(q(name=redditor.name, guild=ctx.guild.id))
    if exists:
        await ctx.send("Redditor already being followed in this guild", ephemeral=True)
        return
    # Count Redditor follows (not Subreddit follows) against the Redditor limit
    count = len([i async for i in RedditorFollow.find(q(guild=ctx.guild.id))])
    if count >= 12:
        await ctx.send("Cannot follow more than 12 Redditors", ephemeral=True)
        return
    # Make sure a Redditor record exists before the follow refers to it by name
    sr = await Redditor.find_one(q(name=redditor.name))
    if not sr:
        sr = Redditor(name=redditor.name)
        await sr.commit()
    srf = RedditorFollow(
        name=redditor.name,
        channel=channel.id,
        guild=ctx.guild.id,
        admin=ctx.author.id,
    )
    await srf.commit()
    await ctx.send(f"Now following `u/{name}` in {channel.mention}")
@unfollow.subcommand(sub_cmd_name="redditor", sub_cmd_description="Unfollow Redditor")
@check(admin_or_permissions(Permissions.MANAGE_GUILD))
async def _redditor_unfollow(self, ctx: InteractionContext) -> None:
    """
    Let the invoking user pick followed Redditors to unfollow.

    Sends a select menu listing the guild's Redditor follows, waits up to
    five minutes for the command author to choose, deletes the chosen follow
    records and disables the menu afterwards (also on timeout).

    Args:
        ctx: Interaction context of the command
    """
    redditors = [entry async for entry in RedditorFollow.find(q(guild=ctx.guild.id))]
    if not redditors:
        await ctx.send("You need to follow a redditor first", ephemeral=True)
        return
    # The follow record already holds the name; a second lookup of the
    # Redditor record by that same name could only return it again or None.
    names = [entry.name for entry in redditors]
    # Option values are indexes into `redditors` / `names`
    options = [SelectOption(label=label, value=str(idx)) for idx, label in enumerate(names)]
    select = Select(
        options=options, custom_id="to_delete", min_values=1, max_values=len(redditors)
    )
    components = [ActionRow(select)]
    block = "\n".join(names)
    message = await ctx.send(
        content=f"You are following the following redditors:\n```\n{block}\n```\n\n"
        "Please choose redditors to unfollow",
        components=components,
    )
    try:
        context = await self.bot.wait_for_component(
            check=lambda x: ctx.author.id == x.context.author.id,
            messages=message,
            timeout=60 * 5,
        )
        for to_delete in context.context.values:
            try:
                await redditors[int(to_delete)].delete()
            except Exception:
                # Deliberately best effort; keep the traceback for debugging
                self.logger.debug("Ignoring deletion error", exc_info=True)
        for row in components:
            for component in row.components:
                component.disabled = True
        block = "\n".join(names[int(x)] for x in context.context.values)
        await context.context.edit_origin(
            content=f"Unfollowed the following:\n```\n{block}\n```", components=components
        )
    except asyncio.TimeoutError:
        # Nobody answered in time: just lock the menu
        for row in components:
            for component in row.components:
                component.disabled = True
        await message.edit(components=components)
@follow.subcommand(sub_cmd_name="subreddit", sub_cmd_description="Follow a Subreddit")
@slash_option(
    name="name",
    description="Subreddit display name",
    opt_type=OptionTypes.STRING,
    required=True,
)
@slash_option(
    name="channel",
    description="Channel to post to",
    opt_type=OptionTypes.CHANNEL,
    channel_types=[ChannelTypes.GUILD_TEXT],
    required=True,
)
@check(admin_or_permissions(Permissions.MANAGE_GUILD))
async def _subreddit_follow(
    self, ctx: InteractionContext, name: str, channel: GuildText
) -> None:
    """
    Follow a Subreddit for the current guild.

    Rejects invalid names, non-text channels, subreddits that cannot be
    loaded, duplicates, guilds already at 12 Subreddit follows and NSFW
    subreddits aimed at a non-NSFW channel; otherwise stores the follow.

    Args:
        ctx: Interaction context of the command
        name: Subreddit display name to follow
        channel: Text channel saved on the follow record
    """
    if sub_name.match(name) is None:
        await ctx.send("Invalid Subreddit name", ephemeral=True)
        return

    if not isinstance(channel, GuildText):
        await ctx.send("Channel must be a text channel", ephemeral=True)
        return

    try:
        subreddit = await self.api.subreddit(name)
        await subreddit.load()
    except (NotFound, Forbidden, Redirect) as err:
        self.logger.debug(f"Subreddit {name} raised {err.__class__.__name__} on add")
        await ctx.send("Subreddit may be private, quarantined, or nonexistent.", ephemeral=True)
        return

    already_followed = await SubredditFollow.find_one(
        q(display_name=subreddit.display_name, guild=ctx.guild.id)
    )
    if already_followed:
        await ctx.send("Subreddit already being followed in this guild", ephemeral=True)
        return

    # Tally this guild's existing Subreddit follows
    follow_count = 0
    async for _ in SubredditFollow.find(q(guild=ctx.guild.id)):
        follow_count += 1
    if follow_count >= 12:
        await ctx.send("Cannot follow more than 12 Subreddits", ephemeral=True)
        return

    if subreddit.over18 and not channel.nsfw:
        await ctx.send(
            "Subreddit is nsfw, but channel is not. Mark the channel NSFW first.",
            ephemeral=True,
        )
        return

    # Create the Subreddit record on first use
    known = await Subreddit.find_one(q(display_name=subreddit.display_name))
    if not known:
        known = Subreddit(display_name=subreddit.display_name, over18=subreddit.over18)
        await known.commit()

    new_follow = SubredditFollow(
        display_name=subreddit.display_name,
        channel=channel.id,
        guild=ctx.guild.id,
        admin=ctx.author.id,
    )
    await new_follow.commit()
    await ctx.send(f"Now following `r/{name}` in {channel.mention}")
@unfollow.subcommand(sub_cmd_name="subreddit", sub_cmd_description="Unfollow Subreddits")
@check(admin_or_permissions(Permissions.MANAGE_GUILD))
async def _subreddit_unfollow(self, ctx: InteractionContext) -> None:
    """
    Let the invoking user pick followed Subreddits to unfollow.

    Sends a select menu listing the guild's Subreddit follows, waits up to
    five minutes for the command author to choose, deletes the chosen follow
    records and disables the menu afterwards (also on timeout).

    Args:
        ctx: Interaction context of the command
    """
    subreddits = [entry async for entry in SubredditFollow.find(q(guild=ctx.guild.id))]
    if not subreddits:
        await ctx.send("You need to follow a Subreddit first", ephemeral=True)
        return
    # The follow record already holds the display name; a second lookup of the
    # Subreddit record by that same name could only return it again or None.
    names = [entry.display_name for entry in subreddits]
    # Option values are indexes into `subreddits` / `names`
    options = [SelectOption(label=label, value=str(idx)) for idx, label in enumerate(names)]
    select = Select(
        options=options, custom_id="to_delete", min_values=1, max_values=len(subreddits)
    )
    components = [ActionRow(select)]
    block = "\n".join(names)
    message = await ctx.send(
        content=f"You are following the following subreddits:\n```\n{block}\n```\n\n"
        "Please choose subreddits to unfollow",
        components=components,
    )
    try:
        context = await self.bot.wait_for_component(
            check=lambda x: ctx.author.id == x.context.author.id,
            messages=message,
            timeout=60 * 5,
        )
        for to_delete in context.context.values:
            try:
                await subreddits[int(to_delete)].delete()
            except Exception:
                # Deliberately best effort; keep the traceback for debugging
                self.logger.debug("Ignoring deletion error", exc_info=True)
        for row in components:
            for component in row.components:
                component.disabled = True
        block = "\n".join(names[int(x)] for x in context.context.values)
        await context.context.edit_origin(
            content=f"Unfollowed the following:\n```\n{block}\n```", components=components
        )
    except asyncio.TimeoutError:
        # Nobody answered in time: just lock the menu
        for row in components:
            for component in row.components:
                component.disabled = True
        await message.edit(components=components)
@reddit.subcommand(sub_cmd_name="hot", sub_cmd_description="Get the hot post of a subreddit")
@slash_option(
    name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
)
async def _subreddit_hot(self, ctx: InteractionContext, name: str) -> None:
    """
    Send the first post of a subreddit's hot listing.

    Posts flagged over_18 are not sent to a non-NSFW channel; they are sent
    by DM instead when the user enabled the ``dm_nsfw`` setting.

    Args:
        ctx: Interaction context of the command
        name: Subreddit name
    """
    if not sub_name.match(name):
        await ctx.send("Invalid Subreddit name", ephemeral=True)
        return
    try:
        await ctx.defer()
        subreddit = await self.api.subreddit(name)
        await subreddit.load()
    except (NotFound, Forbidden, Redirect) as e:
        self.logger.debug(f"Subreddit {name} raised {e.__class__.__name__} in hot")
        await ctx.send("Subreddit may be private, quarantined, or nonexistent.", ephemeral=True)
        return
    try:
        post = [x async for x in subreddit.hot(limit=1)][0]
    except Exception as e:
        self.logger.error(f"Failed to get post from {name}", exc_info=e)
        await ctx.send("Well, this is awkward. Something went wrong", ephemeral=True)
        return
    embeds = await self.post_embeds(subreddit, post)
    if not embeds:
        # post_embeds returns None for posts without text and images
        await ctx.send("That post has no text or images I can show", ephemeral=True)
        return
    if post.over_18 and not ctx.channel.nsfw:
        setting = await UserSetting.find_one(
            q(user=ctx.author.id, type="reddit", setting="dm_nsfw")
        )
        if setting and setting.value:
            try:
                await ctx.author.send(embeds=embeds)
            except Exception as e:
                # Best effort: a failed DM must not prevent the channel reply
                self.logger.debug(f"Could not DM post {post.id}", exc_info=e)
        await ctx.send("Hey! Due to content, I cannot share the result")
    else:
        await ctx.send(embeds=embeds)
@reddit.subcommand(sub_cmd_name="top", sub_cmd_description="Get the top post of a subreddit")
@slash_option(
    name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
)
@slash_option(
    name="time",
    description="Top time",
    opt_type=OptionTypes.STRING,
    required=False,
    choices=[
        SlashCommandChoice(name="All", value="all"),
        SlashCommandChoice(name="Day", value="day"),
        SlashCommandChoice(name="Hour", value="hour"),
        SlashCommandChoice(name="Month", value="month"),
        SlashCommandChoice(name="Week", value="week"),
        SlashCommandChoice(name="Year", value="year"),
    ],
)
async def _subreddit_top(self, ctx: InteractionContext, name: str, time: str = "all") -> None:
    """
    Send the first post of a subreddit's top listing.

    Posts flagged over_18 are not sent to a non-NSFW channel; they are sent
    by DM instead when the user enabled the ``dm_nsfw`` setting.

    Args:
        ctx: Interaction context of the command
        name: Subreddit name
        time: Time filter passed to the top listing, defaults to "all"
    """
    if not sub_name.match(name):
        await ctx.send("Invalid Subreddit name", ephemeral=True)
        return
    try:
        await ctx.defer()
        subreddit = await self.api.subreddit(name)
        await subreddit.load()
    except (NotFound, Forbidden, Redirect) as e:
        self.logger.debug(f"Subreddit {name} raised {e.__class__.__name__} in top")
        await ctx.send("Subreddit may be private, quarantined, or nonexistent.", ephemeral=True)
        return
    try:
        post = [x async for x in subreddit.top(time_filter=time, limit=1)][0]
    except Exception as e:
        self.logger.error(f"Failed to get post from {name}", exc_info=e)
        await ctx.send("Well, this is awkward. Something went wrong", ephemeral=True)
        return
    embeds = await self.post_embeds(subreddit, post)
    if not embeds:
        # post_embeds returns None for posts without text and images
        await ctx.send("That post has no text or images I can show", ephemeral=True)
        return
    if post.over_18 and not ctx.channel.nsfw:
        setting = await UserSetting.find_one(
            q(user=ctx.author.id, type="reddit", setting="dm_nsfw")
        )
        if setting and setting.value:
            try:
                await ctx.author.send(embeds=embeds)
            except Exception as e:
                # Best effort: a failed DM must not prevent the channel reply
                self.logger.debug(f"Could not DM post {post.id}", exc_info=e)
        await ctx.send("Hey! Due to content, I cannot share the result")
    else:
        await ctx.send(embeds=embeds)
@reddit.subcommand(
    sub_cmd_name="random", sub_cmd_description="Get a random post of a subreddit"
)
@slash_option(
    name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
)
async def _subreddit_random(self, ctx: InteractionContext, name: str) -> None:
    """
    Send a random post of a subreddit.

    Posts flagged over_18 are not sent to a non-NSFW channel; they are sent
    by DM instead when the user enabled the ``dm_nsfw`` setting.

    Args:
        ctx: Interaction context of the command
        name: Subreddit name
    """
    if not sub_name.match(name):
        await ctx.send("Invalid Subreddit name", ephemeral=True)
        return
    try:
        await ctx.defer()
        subreddit = await self.api.subreddit(name)
        await subreddit.load()
    except (NotFound, Forbidden, Redirect) as e:
        self.logger.debug(f"Subreddit {name} raised {e.__class__.__name__} in random")
        await ctx.send("Subreddit may be private, quarantined, or nonexistent.", ephemeral=True)
        return
    try:
        post = await subreddit.random()
    except Exception as e:
        self.logger.error(f"Failed to get post from {name}", exc_info=e)
        await ctx.send("Well, this is awkward. Something went wrong", ephemeral=True)
        return
    if post is None:
        # NOTE(review): asyncpraw is understood to return None from random()
        # for subreddits without the random feature - confirm in its docs.
        await ctx.send("That subreddit does not support random posts", ephemeral=True)
        return
    embeds = await self.post_embeds(subreddit, post)
    if not embeds:
        # post_embeds returns None for posts without text and images
        await ctx.send("That post has no text or images I can show", ephemeral=True)
        return
    if post.over_18 and not ctx.channel.nsfw:
        setting = await UserSetting.find_one(
            q(user=ctx.author.id, type="reddit", setting="dm_nsfw")
        )
        if setting and setting.value:
            try:
                await ctx.author.send(embeds=embeds)
            except Exception as e:
                # Best effort: a failed DM must not prevent the channel reply
                self.logger.debug(f"Could not DM post {post.id}", exc_info=e)
        await ctx.send("Hey! Due to content, I cannot share the result")
    else:
        await ctx.send(embeds=embeds)
@reddit.subcommand(
    sub_cmd_name="rising", sub_cmd_description="Get a rising post of a subreddit"
)
@slash_option(
    name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
)
async def _subreddit_rising(self, ctx: InteractionContext, name: str) -> None:
    """
    Send the first post of a subreddit's rising listing.

    Posts flagged over_18 are not sent to a non-NSFW channel; they are sent
    by DM instead when the user enabled the ``dm_nsfw`` setting.

    Args:
        ctx: Interaction context of the command
        name: Subreddit name
    """
    if not sub_name.match(name):
        await ctx.send("Invalid Subreddit name", ephemeral=True)
        return
    try:
        await ctx.defer()
        subreddit = await self.api.subreddit(name)
        await subreddit.load()
    except (NotFound, Forbidden, Redirect) as e:
        self.logger.debug(f"Subreddit {name} raised {e.__class__.__name__} in rising")
        await ctx.send("Subreddit may be private, quarantined, or nonexistent.", ephemeral=True)
        return
    try:
        post = [x async for x in subreddit.rising(limit=1)][0]
    except Exception as e:
        self.logger.error(f"Failed to get post from {name}", exc_info=e)
        await ctx.send("Well, this is awkward. Something went wrong", ephemeral=True)
        return
    embeds = await self.post_embeds(subreddit, post)
    if not embeds:
        # post_embeds returns None for posts without text and images
        await ctx.send("That post has no text or images I can show", ephemeral=True)
        return
    if post.over_18 and not ctx.channel.nsfw:
        setting = await UserSetting.find_one(
            q(user=ctx.author.id, type="reddit", setting="dm_nsfw")
        )
        if setting and setting.value:
            try:
                await ctx.author.send(embeds=embeds)
            except Exception as e:
                # Best effort: a failed DM must not prevent the channel reply
                self.logger.debug(f"Could not DM post {post.id}", exc_info=e)
        await ctx.send("Hey! Due to content, I cannot share the result")
    else:
        await ctx.send(embeds=embeds)
@reddit.subcommand(sub_cmd_name="post", sub_cmd_description="Get a specific submission")
@slash_option(
    name="sid", description="Submission ID", opt_type=OptionTypes.STRING, required=True
)
async def _reddit_post(self, ctx: InteractionContext, sid: str) -> None:
    """
    Send a specific submission looked up by its ID.

    Posts flagged over_18 are not sent to a non-NSFW channel; they are sent
    by DM instead when the user enabled the ``dm_nsfw`` setting.

    Args:
        ctx: Interaction context of the command
        sid: Submission ID
    """
    await ctx.defer()
    try:
        post = await self.api.submission(sid)
        await post.load()
    except (NotFound, Forbidden, Redirect) as e:
        self.logger.debug(f"Submission {sid} raised {e.__class__.__name__} in post")
        await ctx.send("Post could not be found.", ephemeral=True)
        return
    embeds = await self.post_embeds(post.subreddit, post)
    if not embeds:
        # post_embeds returns None for posts without text and images
        await ctx.send("That post has no text or images I can show", ephemeral=True)
        return
    if post.over_18 and not ctx.channel.nsfw:
        setting = await UserSetting.find_one(
            q(user=ctx.author.id, type="reddit", setting="dm_nsfw")
        )
        if setting and setting.value:
            try:
                await ctx.author.send(embeds=embeds)
            except Exception as e:
                # Best effort: a failed DM must not prevent the channel reply
                self.logger.debug(f"Could not DM post {post.id}", exc_info=e)
        await ctx.send("Hey! Due to content, I cannot share the result")
    else:
        await ctx.send(embeds=embeds)
@reddit.subcommand(
    sub_cmd_name="dm_nsfw", sub_cmd_description="DM NSFW posts if channel isn't NSFW"
)
@slash_option(name="dm", description="Send DM?", opt_type=OptionTypes.BOOLEAN, required=True)
async def _reddit_dm(self, ctx: InteractionContext, dm: bool) -> None:
    """
    Store the invoking user's ``dm_nsfw`` Reddit setting.

    Args:
        ctx: Interaction context of the command
        dm: New value of the setting
    """
    lookup = q(user=ctx.author.id, type="reddit", setting="dm_nsfw")
    stored = await UserSetting.find_one(lookup)
    if not stored:
        # First time the user sets this: start a fresh record
        stored = UserSetting(user=ctx.author.id, type="reddit", setting="dm_nsfw", value=dm)
    stored.value = dm
    await stored.commit()
    await ctx.send(f"Reddit DM NSFW setting is now set to {dm}", ephemeral=True)
def setup(bot: Client) -> None:
    """Add RedditCog to JARVIS, but only when the config has a reddit section."""
    reddit_config = JarvisConfig.from_yaml().reddit
    if not reddit_config:
        return
    RedditCog(bot)