423 lines
17 KiB
Python
423 lines
17 KiB
Python
"""JARVIS Reddit cog."""
|
|
import asyncio
|
|
import logging
|
|
from typing import List, Optional
|
|
|
|
from asyncpraw import Reddit
|
|
from asyncpraw.models.reddit.submission import Submission
|
|
from asyncpraw.models.reddit.submission import Subreddit as Sub
|
|
from asyncprawcore.exceptions import Forbidden, NotFound, Redirect
|
|
from jarvis_core.db import q
|
|
from jarvis_core.db.models import Subreddit, SubredditFollow
|
|
from naff import Client, Extension, InteractionContext, Permissions
|
|
from naff.client.utils.misc_utils import get
|
|
from naff.models.discord.channel import ChannelTypes, GuildText
|
|
from naff.models.discord.components import ActionRow, Select, SelectOption
|
|
from naff.models.discord.embed import Embed, EmbedField
|
|
from naff.models.naff.application_commands import (
|
|
OptionTypes,
|
|
SlashCommand,
|
|
SlashCommandChoice,
|
|
slash_option,
|
|
)
|
|
from naff.models.naff.command import check
|
|
|
|
from jarvis import const
|
|
from jarvis.config import JarvisConfig
|
|
from jarvis.utils import build_embed
|
|
from jarvis.utils.permissions import admin_or_permissions
|
|
|
|
# User agent sent to Reddit when the config does not supply its own `user_agent`.
DEFAULT_USER_AGENT = "python:JARVIS:{} (by u/zevaryx)".format(const.__version__)
|
|
|
|
|
|
class RedditCog(Extension):
    """JARVIS Reddit Cog: manage Subreddit follows and show individual posts."""

    # Reddit Subreddit names are 3-21 characters long.
    _MIN_NAME_LENGTH = 3
    _MAX_NAME_LENGTH = 21
    # Most images shown for a single post (one image per embed).
    _MAX_IMAGES = 4
    # Most Subreddits a single guild may follow.
    _MAX_FOLLOWS = 12
    # Direct image links are recognised by their file extension.
    _IMAGE_EXTENSIONS = ("jpeg", "jpg", "png", "gif")
    # Gallery (media_metadata) entries are recognised by their MIME type.
    _GALLERY_MIME_TYPES = ("image/jpg", "image/png", "image/gif")
    # Embed colour used when the Subreddit has no primary colour.
    _DEFAULT_COLOR = "#FF4500"

    def __init__(self, bot: Client):
        self.bot = bot
        self.logger = logging.getLogger(__name__)
        config = JarvisConfig.from_yaml()
        # Fall back to the bot's own user agent when the config does not set one.
        config.reddit.setdefault("user_agent", DEFAULT_USER_AGENT)
        self.api = Reddit(**config.reddit)

    async def post_embeds(self, sub: Sub, post: Submission) -> Optional[List[Embed]]:
        """
        Build the embeds that display a Reddit post.

        Crossposts are resolved to their parent submission, whose body and images
        are shown under a "Crossposted From" field; the title, link, and author
        stay those of `post` itself.

        Args:
            sub: Subreddit whose `primary_color` (if loaded and non-empty) colours the embed
            post: Post to build embeds

        Returns:
            Up to four embeds: the first carries the text and first image, the rest
            carry additional images. (The Optional return type is kept for callers;
            the current implementation always returns at least one embed.)
        """
        url = "https://reddit.com" + post.permalink
        title = f"{post.title}"

        # Author details are taken from `post` before any crosspost swap so that
        # the displayed name, profile link, and icon describe the same account.
        author = post.author
        if author is None:
            # asyncpraw exposes deleted accounts as None.
            author_name, author_url, author_icon = "[deleted]", None, None
        else:
            try:
                await author.load()
            except (NotFound, Forbidden):
                # Profile could not be fetched; still show the name, just without an icon.
                self.logger.debug(f"Could not load author {author.name} of post {post.id}")
            author_name = "u/" + author.name
            author_url = f"https://reddit.com/u/{author.name}"
            author_icon = vars(author).get("icon_img")

        images: List[str] = []
        fields: List[EmbedField] = []
        content = ""

        # Only crossposts carry a parent list. Plain link posts are not self posts
        # either, so `is_self` alone cannot be used to detect a crosspost.
        parents = vars(post).get("crosspost_parent_list")
        if parents:
            post = await self.api.submission(parents[0]["id"])
            await post.load()
            fields.append(EmbedField(name="Crossposted From", value=post.subreddit_name_prefixed))
            content = f"> **{post.title}**"

        post_data = vars(post)

        post_url = post_data.get("url")
        if post_url and post_url.lower().endswith(self._IMAGE_EXTENSIONS):
            images.append(post_url)

        # Gallery posts list their images in media_metadata; the key is the media id.
        for media_id, meta in (post_data.get("media_metadata") or {}).items():
            if len(images) >= self._MAX_IMAGES:
                break
            if meta.get("status") != "valid" or meta.get("m") not in self._GALLERY_MIME_TYPES:
                continue
            ext = meta["m"].split("/")[-1]
            images.append(f"https://i.redd.it/{media_id}.{ext}")

        selftext = post_data.get("selftext")
        if selftext:
            content += "\n\n" + selftext
        if len(content) > 900:
            # Keep the description short; the link below is always appended.
            content = content[:900] + "..."
        content += f"\n\n[View this post]({url})"

        color = vars(sub).get("primary_color") or self._DEFAULT_COLOR
        base_embed = build_embed(
            title=title,
            description=content,
            fields=fields,
            timestamp=post.created_utc,
            url=url,
            color=color,
        )
        base_embed.set_author(name=author_name, url=author_url, icon_url=author_icon)
        base_embed.set_footer(
            text="Reddit", icon_url="https://www.redditinc.com/assets/images/site/reddit-logo.png"
        )

        embeds = [base_embed]
        if images:
            base_embed.set_image(url=images[0])
            # Extra image-only embeds reuse the post URL so Discord groups them with the first.
            for image in images[1 : self._MAX_IMAGES]:
                embed = Embed(url=url)
                embed.set_image(url=image)
                embeds.append(embed)

        return embeds

    @classmethod
    def _clean_name(cls, name: str) -> Optional[str]:
        """
        Normalise a user-supplied Subreddit name.

        Surrounding whitespace, a leading "/", and "r/" are removed.

        Returns:
            The bare name, or None if its length is outside the allowed 3-21 characters
        """
        name = name.strip().lstrip("/").replace("r/", "")
        if not cls._MIN_NAME_LENGTH <= len(name) <= cls._MAX_NAME_LENGTH:
            return None
        return name

    async def _load_subreddit(self, ctx: InteractionContext, name: str, where: str) -> Optional[Sub]:
        """
        Fetch a Subreddit, telling the user if it is unavailable.

        Args:
            ctx: Interaction to report failures to
            name: Cleaned Subreddit name
            where: Suffix for the debug log line (e.g. "in hot")

        Returns:
            The loaded Subreddit, or None after the user has been notified
        """
        try:
            subreddit = await self.api.subreddit(name)
            await subreddit.load()
        except (NotFound, Forbidden, Redirect) as e:
            self.logger.debug(f"Subreddit {name} raised {e.__class__.__name__} {where}")
            await ctx.send("Subreddit may be private, quarantined, or nonexistent.", ephemeral=True)
            return None
        return subreddit

    async def _resolve_subreddit(
        self, ctx: InteractionContext, name: str, where: str
    ) -> Optional[Sub]:
        """Validate a user-supplied name and load its Subreddit; None if either step fails."""
        clean = self._clean_name(name)
        if clean is None:
            await ctx.send("Invalid Subreddit name", ephemeral=True)
            return None
        return await self._load_subreddit(ctx, clean, where)

    @staticmethod
    async def _first_post(listing) -> Optional[Submission]:
        """Return the first submission an async listing yields, or None if it yields none."""
        async for submission in listing:
            return submission
        return None

    async def _fetch_post(
        self, ctx: InteractionContext, subreddit: Sub, fetch
    ) -> Optional[Submission]:
        """
        Fetch a single submission, telling the user if that fails.

        Args:
            ctx: Interaction to report failures to
            subreddit: Subreddit being queried (used for logging)
            fetch: Zero-argument callable returning an awaitable that resolves
                to a submission or None

        Returns:
            The submission, or None after the user has been notified
        """
        name = subreddit.display_name
        try:
            post = await fetch()
        except Exception as e:
            self.logger.error(f"Failed to get post from {name}", exc_info=e)
        else:
            if post is not None:
                return post
            # Empty listing, or `random()` on a Subreddit that does not support it.
            self.logger.debug(f"No post returned from {name}")
        await ctx.send("Well, this is awkward. Something went wrong", ephemeral=True)
        return None

    async def _send_post(self, ctx: InteractionContext, sub: Sub, post: Submission) -> None:
        """Reply with a post's embeds, sending NSFW posts by DM unless the channel is NSFW."""
        embeds = await self.post_embeds(sub, post)
        if not embeds:
            await ctx.send("Well, this is awkward. Something went wrong", ephemeral=True)
            return
        # Channels without an `nsfw` attribute are treated as non-NSFW.
        if post.over_18 and not getattr(ctx.channel, "nsfw", False):
            try:
                await ctx.author.send(embeds=embeds)
                await ctx.send("Hey! Due to content, I had to DM the result to you")
            except Exception:
                await ctx.send("Hey! Due to content, I cannot share the result")
        else:
            await ctx.send(embeds=embeds)

    @staticmethod
    def _disable_components(components: List[ActionRow]) -> None:
        """Disable every component in the given action rows, in place."""
        for row in components:
            for component in row.components:
                component.disabled = True

    reddit = SlashCommand(name="reddit", description="Manage Reddit follows")

    @reddit.subcommand(sub_cmd_name="follow", sub_cmd_description="Follow a Subreddit")
    @slash_option(
        name="name",
        description="Subreddit display name",
        opt_type=OptionTypes.STRING,
        required=True,
    )
    @slash_option(
        name="channel",
        description="Channel to post to",
        opt_type=OptionTypes.CHANNEL,
        channel_types=[ChannelTypes.GUILD_TEXT],
        required=True,
    )
    @check(admin_or_permissions(Permissions.MANAGE_GUILD))
    async def _reddit_follow(self, ctx: InteractionContext, name: str, channel: GuildText) -> None:
        """Record a Subreddit follow for this guild, targeting `channel`."""
        name = self._clean_name(name)
        if name is None:
            await ctx.send("Invalid Subreddit name", ephemeral=True)
            return

        if not isinstance(channel, GuildText):
            await ctx.send("Channel must be a text channel", ephemeral=True)
            return

        subreddit = await self._load_subreddit(ctx, name, "on add")
        if subreddit is None:
            return

        exists = await SubredditFollow.find_one(
            q(display_name=subreddit.display_name, guild=ctx.guild.id)
        )
        if exists:
            await ctx.send("Subreddit already being followed in this guild", ephemeral=True)
            return

        count = len([i async for i in SubredditFollow.find(q(guild=ctx.guild.id))])
        if count >= self._MAX_FOLLOWS:
            await ctx.send(f"Cannot follow more than {self._MAX_FOLLOWS} Subreddits", ephemeral=True)
            return

        if subreddit.over18 and not channel.nsfw:
            await ctx.send(
                "Subreddit is nsfw, but channel is not. Mark the channel NSFW first.",
                ephemeral=True,
            )
            return

        # Make sure a shared Subreddit record exists before adding the guild's follow.
        sr = await Subreddit.find_one(q(display_name=subreddit.display_name))
        if not sr:
            sr = Subreddit(display_name=subreddit.display_name, over18=subreddit.over18)
            await sr.commit()

        srf = SubredditFollow(
            display_name=subreddit.display_name,
            channel=channel.id,
            guild=ctx.guild.id,
            admin=ctx.author.id,
        )
        await srf.commit()

        await ctx.send(f"Now following `r/{name}` in {channel.mention}")

    @reddit.subcommand(sub_cmd_name="unfollow", sub_cmd_description="Unfollow Subreddits")
    @check(admin_or_permissions(Permissions.MANAGE_GUILD))
    async def _subreddit_unfollow(self, ctx: InteractionContext) -> None:
        """Let the invoking user pick which of this guild's Subreddit follows to delete."""
        follows = [follow async for follow in SubredditFollow.find(q(guild=ctx.guild.id))]
        if not follows:
            await ctx.send("You need to follow a Subreddit first", ephemeral=True)
            return

        # Option values are indices into `follows`/`names`.
        names = [follow.display_name for follow in follows]
        options = [SelectOption(label=label, value=str(idx)) for idx, label in enumerate(names)]
        select = Select(
            options=options, custom_id="to_delete", min_values=1, max_values=len(follows)
        )
        components = [ActionRow(select)]

        block = "\n".join(names)
        message = await ctx.send(
            content=f"You are following the following subreddits:\n```\n{block}\n```\n\n"
            "Please choose subreddits to unfollow",
            components=components,
        )

        try:
            context = await self.bot.wait_for_component(
                check=lambda x: ctx.author.id == x.context.author.id,
                messages=message,
                timeout=60 * 5,
            )
        except asyncio.TimeoutError:
            self._disable_components(components)
            await message.edit(components=components)
            return

        selected = [int(value) for value in context.context.values]
        for idx in selected:
            try:
                await follows[idx].delete()
            except Exception:
                self.logger.debug("Ignoring deletion error")

        self._disable_components(components)
        block = "\n".join(names[idx] for idx in selected)
        await context.context.edit_origin(
            content=f"Unfollowed the following:\n```\n{block}\n```", components=components
        )

    @reddit.subcommand(sub_cmd_name="hot", sub_cmd_description="Get the hot post of a subreddit")
    @slash_option(
        name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
    )
    async def _subreddit_hot(self, ctx: InteractionContext, name: str) -> None:
        """Show the first post of a Subreddit's hot listing."""
        await ctx.defer()
        subreddit = await self._resolve_subreddit(ctx, name, "in hot")
        if subreddit is None:
            return
        post = await self._fetch_post(
            ctx, subreddit, lambda: self._first_post(subreddit.hot(limit=1))
        )
        if post is not None:
            await self._send_post(ctx, subreddit, post)

    @reddit.subcommand(sub_cmd_name="top", sub_cmd_description="Get the top post of a subreddit")
    @slash_option(
        name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
    )
    @slash_option(
        name="time",
        description="Top time",
        opt_type=OptionTypes.STRING,
        required=False,
        choices=[
            SlashCommandChoice(name="All", value="all"),
            SlashCommandChoice(name="Day", value="day"),
            SlashCommandChoice(name="Hour", value="hour"),
            SlashCommandChoice(name="Month", value="month"),
            SlashCommandChoice(name="Week", value="week"),
            SlashCommandChoice(name="Year", value="year"),
        ],
    )
    async def _subreddit_top(self, ctx: InteractionContext, name: str, time: str = "all") -> None:
        """Show the first post of a Subreddit's top listing for the chosen time filter."""
        await ctx.defer()
        subreddit = await self._resolve_subreddit(ctx, name, "in top")
        if subreddit is None:
            return
        post = await self._fetch_post(
            ctx, subreddit, lambda: self._first_post(subreddit.top(time_filter=time, limit=1))
        )
        if post is not None:
            await self._send_post(ctx, subreddit, post)

    @reddit.subcommand(
        sub_cmd_name="random", sub_cmd_description="Get a random post of a subreddit"
    )
    @slash_option(
        name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
    )
    async def _subreddit_random(self, ctx: InteractionContext, name: str) -> None:
        """Show a random post from a Subreddit."""
        await ctx.defer()
        subreddit = await self._resolve_subreddit(ctx, name, "in random")
        if subreddit is None:
            return
        # `random()` may return None; `_fetch_post` reports that to the user.
        post = await self._fetch_post(ctx, subreddit, subreddit.random)
        if post is not None:
            await self._send_post(ctx, subreddit, post)

    @reddit.subcommand(
        sub_cmd_name="rising", sub_cmd_description="Get a rising post of a subreddit"
    )
    @slash_option(
        name="name", description="Subreddit name", opt_type=OptionTypes.STRING, required=True
    )
    async def _subreddit_rising(self, ctx: InteractionContext, name: str) -> None:
        """Show the first post of a Subreddit's rising listing."""
        await ctx.defer()
        subreddit = await self._resolve_subreddit(ctx, name, "in rising")
        if subreddit is None:
            return
        post = await self._fetch_post(
            ctx, subreddit, lambda: self._first_post(subreddit.rising(limit=1))
        )
        if post is not None:
            await self._send_post(ctx, subreddit, post)

    @reddit.subcommand(sub_cmd_name="post", sub_cmd_description="Get a specific submission")
    @slash_option(
        name="sid", description="Submission ID", opt_type=OptionTypes.STRING, required=True
    )
    async def _reddit_post(self, ctx: InteractionContext, sid: str) -> None:
        """Show a specific submission by its ID."""
        await ctx.defer()
        try:
            post = await self.api.submission(sid)
            await post.load()
        except (NotFound, Forbidden, Redirect) as e:
            self.logger.debug(f"Submission {sid} raised {e.__class__.__name__} in post")
            await ctx.send("Subreddit may be private, quarantined, or nonexistent.", ephemeral=True)
            return

        await self._send_post(ctx, post.subreddit, post)
|
|
|
|
|
|
def setup(bot: Client) -> None:
    """Register RedditCog with JARVIS, but only when the config has a non-empty `reddit` section."""
    reddit_config = JarvisConfig.from_yaml().reddit
    if not reddit_config:
        return
    RedditCog(bot)
|