201 lines
6.6 KiB
Python
201 lines
6.6 KiB
Python
"""JARVIS Reddit sync."""
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from typing import List, Optional
|
|
|
|
from asyncpraw import Reddit
|
|
from asyncpraw.models.reddit.submission import Submission
|
|
from asyncpraw.models.reddit.submission import Subreddit as Sub
|
|
from asyncprawcore.exceptions import Forbidden, NotFound
|
|
from dis_snek import Snake
|
|
from dis_snek.client.errors import NotFound as DNotFound
|
|
from dis_snek.models.discord.embed import Embed
|
|
from jarvis_core.db import q
|
|
from jarvis_core.db.models import Subreddit, SubredditFollow
|
|
|
|
from jarvis_tasks import const
|
|
from jarvis_tasks.config import TaskConfig
|
|
from jarvis_tasks.util import build_embed
|
|
|
|
# User agent sent to Reddit when the config does not provide one.
DEFAULT_USER_AGENT = f"python:JARVIS-Tasks:{const.__version__} (by u/zevaryx)"

config = TaskConfig.from_yaml()
# Only apply the default user agent when a Reddit config is actually present.
# Writing into an empty section would make it truthy and defeat the
# "Missing Reddit config" guard in `reddit()`; a missing section would raise here.
if config.reddit:
    config.reddit["user_agent"] = config.reddit.get("user_agent", DEFAULT_USER_AGENT)
# Display names of subreddits that currently have a running stream.
running = []
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def post_embeds(sub: Sub, post: Submission) -> Optional[List[Embed]]:
    """
    Build the embeds for a post.

    The first embed carries the post title, the self text (truncated to 600
    characters), a link back to the post, the author and the first image.
    Up to three more image-only embeds follow, sharing the post URL.

    Args:
        sub: Subreddit the post was made in, used for the embed color
        post: Post to build embeds for

    Returns:
        List of one to four embeds, or None if the post has neither content nor images
    """
    url = "https://reddit.com" + post.permalink

    # Posts from deleted accounts have no author object, so only load it when present.
    author = post.author
    if author is not None:
        await author.load()
        author_name = "u/" + author.name
        author_url = f"https://reddit.com/u/{author.name}"
        author_icon = author.icon_img
    else:
        author_name = "u/[deleted]"
        author_url = None
        author_icon = None

    # Attributes are looked up through vars() so that only fields actually
    # present on the object are read.
    post_vars = vars(post)

    images = []
    link = post_vars.get("url")
    if link and link.endswith(("jpeg", "jpg", "png", "gif")):
        images = [link]

    # Gallery posts list their media here; the field can be present but null.
    media = post_vars.get("media_metadata") or {}
    for media_id, meta in media.items():
        if meta.get("status") != "valid":
            continue
        mime = meta.get("m")
        if mime not in ("image/jpg", "image/png", "image/gif"):
            continue
        ext = mime.split("/")[-1]
        images.append(f"https://i.redd.it/{media_id}.{ext}")
        # At most four images are shown (one per embed).
        if len(images) == 4:
            break

    content = f"**{post.title}**"
    selftext = post_vars.get("selftext")
    if selftext:
        content += "\n\n" + selftext
        if len(content) > 600:
            content = content[:600] + "..."
    content += f"\n\n[View this post]({url})"

    # NOTE(review): content always contains the title and link, so this guard
    # does not trigger as written; kept so the Optional return contract is unchanged.
    if not images and not content:
        logger.debug(f"Post {post.id} had neither content nor images?")
        return None

    # Fall back to Reddit orange when the subreddit has no (or an empty) primary color.
    color = vars(sub).get("primary_color") or "#FF4500"

    base_embed = build_embed(
        title="", description=content, fields=[], timestamp=post.created_utc, url=url, color=color
    )
    base_embed.set_author(name=author_name, url=author_url, icon_url=author_icon)
    base_embed.set_footer(
        text="Reddit", icon_url="https://www.redditinc.com/assets/images/site/reddit-logo.png"
    )

    embeds = [base_embed]

    if images:
        base_embed.set_image(url=images[0])
        # Extra images go in their own embeds that share the post URL.
        for image in images[1:4]:
            extra = Embed(url=url)
            extra.set_image(url=image)
            embeds.append(extra)

    return embeds
|
|
|
|
|
|
async def _stream(sub: Sub, bot: Snake) -> None:
    """
    Stream a subreddit.

    Every post created after this stream started is sent to the channel of
    each follow of the subreddit. Follows whose guild or channel can no longer
    be found are deleted. When a post arrives and the subreddit has no follows
    left, its Subreddit document is deleted and the stream stops.

    Args:
        sub: Subreddit to stream
        bot: Snake instance
    """
    started = datetime.now(tz=timezone.utc).timestamp()
    await sub.load()
    name = sub.display_name
    running.append(name)
    logger.debug(f"Streaming subreddit {name}")
    try:
        async for post in sub.stream.submissions():
            if not post:
                logger.debug(f"Got None for post in {name}")
                continue
            logger.debug(f"Got new post in {name}")
            # Skip posts that were created before this stream started.
            if post.created_utc < started:
                continue

            follows = SubredditFollow.find(q(display_name=name))
            follows_to_delete = []
            num_follows = 0

            # The embeds are the same for every follower, so build them at
            # most once per post, and only when there is a channel to send to.
            embeds = None
            embeds_built = False
            timestamp = int(post.created_utc)

            async for follow in follows:
                num_follows += 1
                guild = await bot.fetch_guild(follow.guild)
                if not guild:
                    logger.warning(f"Follow {follow.id}'s guild no longer exists, deleting")
                    follows_to_delete.append(follow)
                    continue

                channel = await bot.fetch_channel(follow.channel)
                if not channel:
                    logger.warning(f"Follow {follow.id}'s channel no longer exists, deleting")
                    follows_to_delete.append(follow)
                    continue

                if not embeds_built:
                    embeds = await post_embeds(sub, post)
                    embeds_built = True

                try:
                    await channel.send(
                        f"`r/{name}` was posted to at <t:{timestamp}:f>",
                        embeds=embeds,
                    )
                except DNotFound:
                    logger.warning(f"Follow {follow.id}'s channel no longer exists, deleting")
                    follows_to_delete.append(follow)
                    continue
                except Exception:
                    # Best effort: one failing channel must not stop the others.
                    logger.error(
                        f"Failed to send message to {channel.id} in {channel.guild.name}", exc_info=True
                    )

            # Delete invalid follows
            for follow in follows_to_delete:
                await follow.delete()

            # Nobody follows this subreddit anymore: drop it and stop streaming.
            if num_follows == 0:
                s = await Subreddit.find_one(q(display_name=name))
                if s:
                    await s.delete()
                break
    except Exception:
        logger.error(f"Stream for subreddit {name} failed", exc_info=True)
        raise
    finally:
        # Always unregister, even on error or cancellation, so the subreddit
        # is not left marked as running and can be picked up again.
        running.remove(name)
|
|
|
|
|
|
async def reddit(bot: Snake) -> None:
    """
    Sync Reddit posts in the background.

    Every 60 seconds, looks up stored subreddits that are not currently being
    streamed and starts a stream task for each one that still has a follow.
    Subreddits without follows, or that Reddit reports as not found or
    forbidden, are deleted instead.

    Args:
        bot: Snake instance
    """
    if not config.reddit:
        logger.warning("Missing Reddit config, not starting")
        return
    logger.debug("Starting Task-reddit")
    red = Reddit(**config.reddit)

    # The event loop only keeps weak references to tasks, so hold strong
    # references here until each stream finishes.
    stream_tasks = set()

    while True:
        subs = Subreddit.find(q(display_name__nin=running))

        # Go through all actively followed subreddits
        async for sub in subs:
            logger.debug(f"Creating stream for {sub.display_name}")
            if sub.display_name in running:
                logger.debug(f"Follow {sub.display_name} was found despite filter")
                continue

            is_followed = await SubredditFollow.find_one(q(display_name=sub.display_name))
            if not is_followed:
                logger.warning(f"Subreddit {sub.display_name} has no followers, removing")
                await sub.delete()
                continue

            # Get subreddit. fetch=True makes the request happen here, so an
            # inaccessible subreddit is caught below instead of failing later
            # inside the stream task.
            try:
                stream_sub = await red.subreddit(sub.display_name, fetch=True)
            except (NotFound, Forbidden) as e:
                # Subreddit is either quarantined, deleted, or private
                logger.warning(f"Subreddit {sub.display_name} raised {e.__class__.__name__}, removing")
                try:
                    await sub.delete()
                except Exception:
                    logger.debug("Ignoring deletion error", exc_info=True)
                continue

            # Create and run stream
            task = asyncio.create_task(_stream(stream_sub, bot))
            stream_tasks.add(task)
            task.add_done_callback(stream_tasks.discard)

        # Check every 60 seconds
        await asyncio.sleep(60)
|