Merge branch 'beanie' into 'main'
Beanie See merge request stark-industries/jarvis/jarvis-tasks!2
This commit is contained in:
commit
0819d4bb85
17 changed files with 1400 additions and 1014 deletions
28
.flake8
28
.flake8
|
@ -1,20 +1,20 @@
|
||||||
[flake8]
|
[flake8]
|
||||||
exclude =
|
exclude =
|
||||||
run.py
|
tests/*
|
||||||
|
|
||||||
extend-ignore =
|
extend-ignore =
|
||||||
Q0, E501, C812, E203, W503, # These default to arguing with Black. We might configure some of them eventually
|
Q0, E501, C812, E203, W503,
|
||||||
ANN001, # Ignore self and cls annotations
|
ANN1, ANN003,
|
||||||
ANN002, ANN003, # Ignore *args and **kwargs
|
ANN204, ANN206,
|
||||||
ANN101, # Ignore self
|
D105, D107,
|
||||||
ANN204, ANN206, # return annotations for special methods and class methods
|
S311,
|
||||||
D105, D107, # Missing Docstrings in magic method and __init__
|
D401,
|
||||||
S311, # Standard pseudo-random generators are not suitable for security/cryptographic purposes.
|
D400,
|
||||||
D401, # First line should be in imperative mood; try rephrasing
|
D101, D102,
|
||||||
D400, # First line should end with a period
|
D106,
|
||||||
D101, # Missing docstring in public class
|
R503, E712
|
||||||
|
|
||||||
# Plugins we don't currently include: flake8-return
|
|
||||||
R503, # missing explicit return at the end of function ableto return non-None value.
|
|
||||||
|
|
||||||
max-line-length=100
|
max-line-length=100
|
||||||
|
|
||||||
|
per-file-ignores =
|
||||||
|
jarvis_core/db/models/__init__.py:F401
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
repos:
|
repos:
|
||||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
rev: v4.1.0
|
rev: v4.4.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: check-toml
|
- id: check-toml
|
||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
|
@ -9,21 +9,19 @@ repos:
|
||||||
- id: requirements-txt-fixer
|
- id: requirements-txt-fixer
|
||||||
- id: end-of-file-fixer
|
- id: end-of-file-fixer
|
||||||
- id: debug-statements
|
- id: debug-statements
|
||||||
language_version: python3.10
|
|
||||||
- id: trailing-whitespace
|
- id: trailing-whitespace
|
||||||
args: [--markdown-linebreak-ext=md]
|
args: [--markdown-linebreak-ext=md]
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/pygrep-hooks
|
- repo: https://github.com/pre-commit/pygrep-hooks
|
||||||
rev: v1.9.0
|
rev: v1.10.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: python-check-blanket-noqa
|
- id: python-check-blanket-noqa
|
||||||
|
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
rev: 22.3.0
|
rev: 23.7.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
args: [--line-length=100, --target-version=py310]
|
args: [--line-length=100]
|
||||||
language_version: python3.10
|
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-isort
|
- repo: https://github.com/pre-commit/mirrors-isort
|
||||||
rev: v5.10.1
|
rev: v5.10.1
|
||||||
|
@ -32,12 +30,12 @@ repos:
|
||||||
args: ["--profile", "black"]
|
args: ["--profile", "black"]
|
||||||
|
|
||||||
- repo: https://github.com/pycqa/flake8
|
- repo: https://github.com/pycqa/flake8
|
||||||
rev: 4.0.1
|
rev: 6.1.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
- flake8-annotations~=2.0
|
- flake8-annotations~=2.0
|
||||||
#- flake8-bandit~=2.1
|
#- flake8-bandit # Uncomment once works again
|
||||||
- flake8-docstrings~=1.5
|
- flake8-docstrings~=1.5
|
||||||
- flake8-bugbear
|
- flake8-bugbear
|
||||||
- flake8-comprehensions
|
- flake8-comprehensions
|
||||||
|
@ -46,4 +44,3 @@ repos:
|
||||||
- flake8-deprecated
|
- flake8-deprecated
|
||||||
- flake8-print
|
- flake8-print
|
||||||
- flake8-return
|
- flake8-return
|
||||||
language_version: python3.10
|
|
||||||
|
|
|
@ -1,24 +1,19 @@
|
||||||
"""JARVIS background tasks."""
|
"""JARVIS background tasks."""
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
#import rook
|
from interactions import Client, Intents
|
||||||
from jarvis_core.db import connect
|
from jarvis_core.db import connect
|
||||||
from jarvis_core.log import get_logger
|
from jarvis_core.log import get_logger
|
||||||
from naff import Client, Intents
|
|
||||||
|
|
||||||
from jarvis_tasks import const
|
from jarvis_tasks import const
|
||||||
from jarvis_tasks.config import TaskConfig
|
from jarvis_tasks.config import load_config
|
||||||
from jarvis_tasks.prometheus.serve import StatTracker
|
|
||||||
from jarvis_tasks.tasks import (
|
from jarvis_tasks.tasks import (
|
||||||
autokick,
|
autokick,
|
||||||
ban,
|
ban,
|
||||||
lock,
|
lock,
|
||||||
lockdown,
|
lockdown,
|
||||||
reddit,
|
|
||||||
reminder,
|
reminder,
|
||||||
temprole,
|
temprole,
|
||||||
twitter,
|
|
||||||
warning,
|
warning,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,7 +21,7 @@ __version__ = const.__version__
|
||||||
logger = None
|
logger = None
|
||||||
|
|
||||||
|
|
||||||
async def _start(config: Optional[str] = "config.yaml") -> None:
|
async def _start() -> None:
|
||||||
"""
|
"""
|
||||||
Main start function.
|
Main start function.
|
||||||
|
|
||||||
|
@ -34,15 +29,11 @@ async def _start(config: Optional[str] = "config.yaml") -> None:
|
||||||
config: Config path
|
config: Config path
|
||||||
"""
|
"""
|
||||||
# Load config
|
# Load config
|
||||||
config = TaskConfig.from_yaml(config)
|
config = load_config()
|
||||||
|
|
||||||
# if config.rook_token:
|
|
||||||
# rook.start(token=config.rook_token, labels={"env": "dev"})
|
|
||||||
|
|
||||||
# Connect to database
|
# Connect to database
|
||||||
testing = config.mongo["database"] != "jarvis"
|
logger.debug(f"Connecting to database, environ={config.environment.value}")
|
||||||
logger.debug(f"Connecting to database, testing={testing}")
|
await connect(**config.mongo.dict(), testing=config.environment.value == "develop")
|
||||||
connect(**config.mongo["connect"], testing=testing)
|
|
||||||
|
|
||||||
# Get event loop
|
# Get event loop
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
@ -53,7 +44,7 @@ async def _start(config: Optional[str] = "config.yaml") -> None:
|
||||||
bot = Client(intents=intents, loop=loop)
|
bot = Client(intents=intents, loop=loop)
|
||||||
await bot.login(config.token)
|
await bot.login(config.token)
|
||||||
logger.info(f"Logged in as {bot.user.username}#{bot.user.discriminator}")
|
logger.info(f"Logged in as {bot.user.username}#{bot.user.discriminator}")
|
||||||
tracker = StatTracker()
|
# tracker = StatTracker()
|
||||||
|
|
||||||
# Start tasks
|
# Start tasks
|
||||||
try:
|
try:
|
||||||
|
@ -63,13 +54,13 @@ async def _start(config: Optional[str] = "config.yaml") -> None:
|
||||||
ban.unban,
|
ban.unban,
|
||||||
lock.unlock,
|
lock.unlock,
|
||||||
lockdown.lift,
|
lockdown.lift,
|
||||||
reddit.reddit,
|
|
||||||
reminder.remind,
|
reminder.remind,
|
||||||
temprole.remove,
|
temprole.remove,
|
||||||
twitter.twitter,
|
|
||||||
warning.unwarn,
|
warning.unwarn,
|
||||||
]
|
]
|
||||||
tasks = [loop.create_task(f(bot)) for f in functions] + [loop.create_task(tracker.start())]
|
tasks = [loop.create_task(f(bot)) for f in functions] + [
|
||||||
|
# loop.create_task(tracker.start())
|
||||||
|
]
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
await task
|
await task
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
@ -77,7 +68,7 @@ async def _start(config: Optional[str] = "config.yaml") -> None:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
|
|
||||||
def start(config: Optional[str] = "config.yaml") -> None:
|
def start() -> None:
|
||||||
"""
|
"""
|
||||||
Start the background tasks.
|
Start the background tasks.
|
||||||
|
|
||||||
|
@ -86,10 +77,10 @@ def start(config: Optional[str] = "config.yaml") -> None:
|
||||||
"""
|
"""
|
||||||
global logger, debug
|
global logger, debug
|
||||||
# Set log level
|
# Set log level
|
||||||
_config = TaskConfig.from_yaml(config)
|
_config = load_config()
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
logger.setLevel(_config.log_level)
|
logger.setLevel(_config.log_level)
|
||||||
|
|
||||||
# Run the main tasks
|
# Run the main tasks
|
||||||
logger.debug("Starting asyncio")
|
logger.debug("Starting asyncio")
|
||||||
asyncio.run(_start(config))
|
asyncio.run(_start())
|
||||||
|
|
|
@ -1,7 +1,107 @@
|
||||||
"""Task config."""
|
"""Task config."""
|
||||||
from jarvis_core.config import Config
|
from enum import Enum
|
||||||
|
from os import environ
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import orjson as json
|
||||||
|
import yaml
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from jarvis_core.util import find_all
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
try:
|
||||||
|
from yaml import CLoader as Loader
|
||||||
|
except ImportError:
|
||||||
|
from yaml import Loader
|
||||||
|
|
||||||
|
|
||||||
class TaskConfig(Config):
|
class Environment(Enum):
|
||||||
REQUIRED = ["token", "mongo"]
|
"""JARVIS running environment."""
|
||||||
OPTIONAL = {"log_level": "WARNING", "twitter": None, "reddit": None, "rook_token": None}
|
|
||||||
|
production = "production"
|
||||||
|
develop = "develop"
|
||||||
|
|
||||||
|
|
||||||
|
class Mongo(BaseModel):
|
||||||
|
"""MongoDB config."""
|
||||||
|
|
||||||
|
host: list[str] | str = "localhost"
|
||||||
|
username: Optional[str] = None
|
||||||
|
password: Optional[str] = None
|
||||||
|
port: int = 27017
|
||||||
|
|
||||||
|
|
||||||
|
class Config(BaseModel):
|
||||||
|
"""Tasks config model."""
|
||||||
|
|
||||||
|
token: str
|
||||||
|
mongo: Mongo
|
||||||
|
log_level: str = "INFO"
|
||||||
|
environment: Environment = Environment.develop
|
||||||
|
|
||||||
|
|
||||||
|
_config: Config = None
|
||||||
|
|
||||||
|
|
||||||
|
def _load_json() -> Config | None:
|
||||||
|
path = Path("config.json")
|
||||||
|
if path.exists():
|
||||||
|
with path.open() as f:
|
||||||
|
j = json.loads(f.read())
|
||||||
|
return Config(**j)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_yaml() -> Config | None:
|
||||||
|
path = Path("config.yaml")
|
||||||
|
if path.exists():
|
||||||
|
with path.open() as f:
|
||||||
|
y = yaml.load(f.read(), Loader=Loader)
|
||||||
|
return Config(**y)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_env() -> Config | None:
|
||||||
|
load_dotenv()
|
||||||
|
data = {}
|
||||||
|
mongo = {}
|
||||||
|
mongo_keys = find_all(lambda x: x.upper().startswith("MONGO"), environ.keys())
|
||||||
|
|
||||||
|
config_keys = mongo_keys + ["TOKEN", "LOG_LEVEL", "ENVIRONMENT"]
|
||||||
|
|
||||||
|
for item, value in environ.items():
|
||||||
|
if item not in config_keys:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if item in mongo_keys:
|
||||||
|
key = "_".join(item.split("_")[1:]).lower()
|
||||||
|
mongo[key] = value
|
||||||
|
else:
|
||||||
|
data[item.lower()] = value
|
||||||
|
|
||||||
|
data["mongo"] = mongo
|
||||||
|
|
||||||
|
return Config(**data)
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(method: Optional[str] = None) -> Config:
|
||||||
|
"""
|
||||||
|
Load the config using the specified method first
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: Method to use first
|
||||||
|
"""
|
||||||
|
global _config
|
||||||
|
if _config is not None:
|
||||||
|
return _config
|
||||||
|
|
||||||
|
methods = {"yaml": _load_yaml, "json": _load_json, "env": _load_env}
|
||||||
|
method_names = list(methods.keys())
|
||||||
|
if method and method in method_names:
|
||||||
|
method_names.remove(method)
|
||||||
|
method_names.insert(0, method)
|
||||||
|
|
||||||
|
for method in method_names:
|
||||||
|
if _config := methods[method]():
|
||||||
|
return _config
|
||||||
|
|
||||||
|
raise FileNotFoundError("Missing one of: config.yaml, config.json, .env")
|
||||||
|
|
|
@ -3,9 +3,9 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from jarvis_core.db import q
|
from beanie.operators import Exists
|
||||||
from jarvis_core.db.models import Setting
|
from jarvis_core.db.models import Setting
|
||||||
from naff import Client
|
from interactions import Client
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -22,17 +22,21 @@ async def autokick(bot: Client) -> None:
|
||||||
logger.debug("Starting Task-autokick")
|
logger.debug("Starting Task-autokick")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
autokicks = Setting.find(q(setting="autokick", value__exists=True))
|
autokicks = Setting.find(Setting.setting == "autokick", Exists(Setting.value))
|
||||||
async for auto in autokicks:
|
async for auto in autokicks:
|
||||||
if auto.value <= 0:
|
if auto.value <= 0:
|
||||||
logger.warn("Autokick setting <= 0, deleting")
|
logger.warn("Autokick setting <= 0, deleting")
|
||||||
await auto.delete()
|
await auto.delete()
|
||||||
continue
|
continue
|
||||||
verified = await Setting.find_one(
|
verified = await Setting.find_one(
|
||||||
q(setting="verified", value__exists=True, guild=auto.guild)
|
Setting.setting == "verified",
|
||||||
|
Exists(Setting.value),
|
||||||
|
Setting.guild == auto.guild,
|
||||||
)
|
)
|
||||||
unverified = await Setting.find_one(
|
unverified = await Setting.find_one(
|
||||||
q(setting="unverified", value__exists=True, guild=auto.guild)
|
Setting.setting == "unverified",
|
||||||
|
Exists(Setting.value),
|
||||||
|
Setting.guild == auto.guild,
|
||||||
)
|
)
|
||||||
if not verified or not unverified:
|
if not verified or not unverified:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -47,7 +51,9 @@ async def autokick(bot: Client) -> None:
|
||||||
await unverified.delete()
|
await unverified.delete()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if guild.id not in resync or resync[guild.id] >= datetime.now(tz=timezone.utc):
|
if guild.id not in resync or resync[guild.id] >= datetime.now(
|
||||||
|
tz=timezone.utc
|
||||||
|
):
|
||||||
logger.info(f"Guild {guild.id} out of date, resyncing")
|
logger.info(f"Guild {guild.id} out of date, resyncing")
|
||||||
limit = 1000
|
limit = 1000
|
||||||
guild_id = guild.id
|
guild_id = guild.id
|
||||||
|
@ -60,7 +66,9 @@ async def autokick(bot: Client) -> None:
|
||||||
|
|
||||||
role = await guild.fetch_role(unverified.value)
|
role = await guild.fetch_role(unverified.value)
|
||||||
for member in role.members:
|
for member in role.members:
|
||||||
if member.joined_at + timedelta(days=auto.value) >= datetime.now(tz=timezone.utc):
|
if member.joined_at + timedelta(days=auto.value) >= datetime.now(
|
||||||
|
tz=timezone.utc
|
||||||
|
):
|
||||||
await member.kick(reason="Failed to verify in {auto.value} days")
|
await member.kick(reason="Failed to verify in {auto.value} days")
|
||||||
|
|
||||||
await asyncio.sleep(14400)
|
await asyncio.sleep(14400)
|
||||||
|
|
|
@ -3,12 +3,12 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from jarvis_core.db import q
|
from beanie.operators import LTE
|
||||||
from jarvis_core.db.models import Ban, Unban
|
from jarvis_core.db.models import Ban, Unban
|
||||||
from naff import Client
|
from interactions import Client
|
||||||
from naff.client.errors import NotFound
|
from interactions.client.errors import NotFound
|
||||||
from naff.models.discord.guild import Guild
|
from interactions.models.discord.guild import Guild
|
||||||
from naff.models.discord.user import User
|
from interactions.models.discord.user import User
|
||||||
|
|
||||||
from jarvis_tasks.util import runat
|
from jarvis_tasks.util import runat
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ async def _unban(bot: int, guild: Guild, user: User, ban: Ban) -> None:
|
||||||
except NotFound:
|
except NotFound:
|
||||||
logger.debug(f"User {user.id} not banned from guild {guild.id}")
|
logger.debug(f"User {user.id} not banned from guild {guild.id}")
|
||||||
ban.active = False
|
ban.active = False
|
||||||
await ban.commit()
|
await ban.save()
|
||||||
await Unban(
|
await Unban(
|
||||||
user=user.id,
|
user=user.id,
|
||||||
guild=guild.id,
|
guild=guild.id,
|
||||||
|
@ -32,7 +32,7 @@ async def _unban(bot: int, guild: Guild, user: User, ban: Ban) -> None:
|
||||||
discrim=user.discriminator,
|
discrim=user.discriminator,
|
||||||
admin=bot,
|
admin=bot,
|
||||||
reason="Ban expired",
|
reason="Ban expired",
|
||||||
).commit()
|
).save()
|
||||||
queue.remove(ban.id)
|
queue.remove(ban.id)
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,7 +46,9 @@ async def unban(bot: Client) -> None:
|
||||||
logger.debug("Starting Task-ban")
|
logger.debug("Starting Task-ban")
|
||||||
while True:
|
while True:
|
||||||
max_ts = datetime.now(tz=timezone.utc) + timedelta(minutes=9)
|
max_ts = datetime.now(tz=timezone.utc) + timedelta(minutes=9)
|
||||||
bans = Ban.find(q(type="temp", active=True, duration__lte=max_ts))
|
bans = Ban.find(
|
||||||
|
Ban.type == "temp", Ban.active == True, LTE(Ban.duration, max_ts)
|
||||||
|
)
|
||||||
async for ban in bans:
|
async for ban in bans:
|
||||||
if ban.id in queue:
|
if ban.id in queue:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -3,11 +3,11 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from jarvis_core.db import q
|
from beanie.operators import LTE
|
||||||
from jarvis_core.db.models import Lock
|
from jarvis_core.db.models import Lock
|
||||||
from naff import Client
|
from interactions import Client
|
||||||
from naff.client.utils.misc_utils import get
|
from interactions.client.utils.misc_utils import get
|
||||||
from naff.models.discord.channel import GuildChannel
|
from interactions.models.discord.channel import GuildChannel
|
||||||
|
|
||||||
from jarvis_tasks.util import runat
|
from jarvis_tasks.util import runat
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ async def _unlock(channel: GuildChannel, lock: Lock) -> None:
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Locked channel deleted, ignoring error")
|
logger.debug("Locked channel deleted, ignoring error")
|
||||||
lock.active = False
|
lock.active = False
|
||||||
await lock.commit()
|
await lock.save()
|
||||||
queue.remove(lock.id)
|
queue.remove(lock.id)
|
||||||
|
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ async def unlock(bot: Client) -> None:
|
||||||
logger.debug("Starting Task-lock")
|
logger.debug("Starting Task-lock")
|
||||||
while True:
|
while True:
|
||||||
max_ts = datetime.now(tz=timezone.utc) + timedelta(seconds=55)
|
max_ts = datetime.now(tz=timezone.utc) + timedelta(seconds=55)
|
||||||
locks = Lock.find(q(active=True, created_at__lte=max_ts))
|
locks = Lock.find(Lock.active == True, LTE(Lock.created_at, max_ts))
|
||||||
async for lock in locks:
|
async for lock in locks:
|
||||||
if lock.id in queue:
|
if lock.id in queue:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -3,11 +3,11 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from jarvis_core.db import q
|
from beanie.operators import LTE
|
||||||
from jarvis_core.db.models import Lockdown
|
from jarvis_core.db.models import Lockdown
|
||||||
from naff import Client
|
from interactions import Client
|
||||||
from naff.models.discord.enums import Permissions
|
from interactions.models.discord.enums import Permissions
|
||||||
from naff.models.discord.role import Role
|
from interactions.models.discord.role import Role
|
||||||
|
|
||||||
from jarvis_tasks.util import runat
|
from jarvis_tasks.util import runat
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ async def _lift(role: Role, lock: Lockdown) -> None:
|
||||||
original_perms = Permissions(lock.original_perms)
|
original_perms = Permissions(lock.original_perms)
|
||||||
await role.edit(permissions=original_perms)
|
await role.edit(permissions=original_perms)
|
||||||
lock.active = False
|
lock.active = False
|
||||||
await lock.commit()
|
await lock.save()
|
||||||
queue.remove(lock.id)
|
queue.remove(lock.id)
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ async def lift(bot: Client) -> None:
|
||||||
logger.debug("Starting Task-lift")
|
logger.debug("Starting Task-lift")
|
||||||
while True:
|
while True:
|
||||||
max_ts = datetime.now(tz=timezone.utc) + timedelta(seconds=55)
|
max_ts = datetime.now(tz=timezone.utc) + timedelta(seconds=55)
|
||||||
locks = Lockdown.find(q(active=True, created_at__lte=max_ts))
|
locks = Lockdown.find(Lockdown.active == True, LTE(Lockdown.created_at, max_ts))
|
||||||
async for lock in locks:
|
async for lock in locks:
|
||||||
if lock.id in queue:
|
if lock.id in queue:
|
||||||
continue
|
continue
|
||||||
|
|
|
@ -10,27 +10,31 @@ from asyncpraw.models.reddit.redditor import Redditor as Ruser
|
||||||
from asyncpraw.models.reddit.submission import Submission
|
from asyncpraw.models.reddit.submission import Submission
|
||||||
from asyncpraw.models.reddit.submission import Subreddit as Sub
|
from asyncpraw.models.reddit.submission import Subreddit as Sub
|
||||||
from asyncprawcore.exceptions import Forbidden, NotFound
|
from asyncprawcore.exceptions import Forbidden, NotFound
|
||||||
from jarvis_core.db import q
|
from beanie.operators import NotIn
|
||||||
from jarvis_core.db.models import Redditor, RedditorFollow, Subreddit, SubredditFollow
|
from jarvis_core.db.models import Subreddit, SubredditFollow
|
||||||
from naff import Client
|
|
||||||
from naff.client.errors import NotFound as DNotFound
|
# from jarvis_core.db.models import Redditor, RedditorFollow, Subreddit, SubredditFollow
|
||||||
from naff.models.discord.embed import Embed, EmbedField
|
from interactions import Client
|
||||||
|
from interactions.client.errors import NotFound as DNotFound
|
||||||
|
from interactions.models.discord.embed import Embed, EmbedField
|
||||||
|
|
||||||
from jarvis_tasks import const
|
from jarvis_tasks import const
|
||||||
from jarvis_tasks.config import TaskConfig
|
from jarvis_tasks.config import load_config
|
||||||
from jarvis_tasks.prometheus.stats import reddit_count, reddit_gauge
|
from jarvis_tasks.prometheus.stats import reddit_count, reddit_gauge
|
||||||
from jarvis_tasks.util import build_embed
|
from jarvis_tasks.util import build_embed
|
||||||
|
|
||||||
DEFAULT_USER_AGENT = f"python:JARVIS-Tasks:{const.__version__} (by u/zevaryx)"
|
DEFAULT_USER_AGENT = f"python:JARVIS-Tasks:{const.__version__} (by u/zevaryx)"
|
||||||
|
|
||||||
config = TaskConfig.from_yaml()
|
config = load_config()
|
||||||
config.reddit["user_agent"] = config.reddit.get("user_agent", DEFAULT_USER_AGENT)
|
config.reddit.user_agent = config.reddit.dict().get("user_agent", DEFAULT_USER_AGENT)
|
||||||
running = []
|
running = []
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
image_link = re.compile(r"https?://(?:www)?\.?preview\.redd\.it\/(.*\..*)\?.*")
|
image_link = re.compile(r"https?://(?:www)?\.?preview\.redd\.it\/(.*\..*)\?.*")
|
||||||
|
|
||||||
|
|
||||||
async def post_embeds(sub: Sub, post: Submission, reddit: Reddit) -> Optional[List[Embed]]:
|
async def post_embeds(
|
||||||
|
sub: Sub, post: Submission, reddit: Reddit
|
||||||
|
) -> Optional[List[Embed]]:
|
||||||
"""
|
"""
|
||||||
Build a post embeds.
|
Build a post embeds.
|
||||||
|
|
||||||
|
@ -52,14 +56,20 @@ async def post_embeds(sub: Sub, post: Submission, reddit: Reddit) -> Optional[Li
|
||||||
og_post = post # noqa: F841
|
og_post = post # noqa: F841
|
||||||
post = await reddit.submission(post.crosspost_parent_list[0]["id"])
|
post = await reddit.submission(post.crosspost_parent_list[0]["id"])
|
||||||
await post.load()
|
await post.load()
|
||||||
fields.append(EmbedField(name="Crossposted From", value=post.subreddit_name_prefixed))
|
fields.append(
|
||||||
|
EmbedField(name="Crossposted From", value=post.subreddit_name_prefixed)
|
||||||
|
)
|
||||||
content = f"> **{post.title}**"
|
content = f"> **{post.title}**"
|
||||||
if "url" in vars(post):
|
if "url" in vars(post):
|
||||||
if any(post.url.endswith(x) for x in ["jpeg", "jpg", "png", "gif"]):
|
if any(post.url.endswith(x) for x in ["jpeg", "jpg", "png", "gif"]):
|
||||||
images = [post.url]
|
images = [post.url]
|
||||||
if "media_metadata" in vars(post):
|
if "media_metadata" in vars(post):
|
||||||
for k, v in post.media_metadata.items():
|
for k, v in post.media_metadata.items():
|
||||||
if v["status"] != "valid" or v["m"] not in ["image/jpg", "image/png", "image/gif"]:
|
if v["status"] != "valid" or v["m"] not in [
|
||||||
|
"image/jpg",
|
||||||
|
"image/png",
|
||||||
|
"image/gif",
|
||||||
|
]:
|
||||||
continue
|
continue
|
||||||
ext = v["m"].split("/")[-1]
|
ext = v["m"].split("/")[-1]
|
||||||
i_url = f"https://i.redd.it/{k}.{ext}"
|
i_url = f"https://i.redd.it/{k}.{ext}"
|
||||||
|
@ -77,7 +87,9 @@ async def post_embeds(sub: Sub, post: Submission, reddit: Reddit) -> Optional[Li
|
||||||
if post.spoiler:
|
if post.spoiler:
|
||||||
content += "||"
|
content += "||"
|
||||||
content += f"\n\n[View this post]({url})"
|
content += f"\n\n[View this post]({url})"
|
||||||
content = "\n".join(image_link.sub(r"https://i.redd.it/\1", x) for x in content.split("\n"))
|
content = "\n".join(
|
||||||
|
image_link.sub(r"https://i.redd.it/\1", x) for x in content.split("\n")
|
||||||
|
)
|
||||||
|
|
||||||
if not images and not content:
|
if not images and not content:
|
||||||
logger.debug(f"Post {post.id} had neither content nor images?")
|
logger.debug(f"Post {post.id} had neither content nor images?")
|
||||||
|
@ -94,9 +106,12 @@ async def post_embeds(sub: Sub, post: Submission, reddit: Reddit) -> Optional[Li
|
||||||
url=url,
|
url=url,
|
||||||
color=color,
|
color=color,
|
||||||
)
|
)
|
||||||
base_embed.set_author(name="u/" + post.author.name, url=author_url, icon_url=author_icon)
|
base_embed.set_author(
|
||||||
|
name="u/" + post.author.name, url=author_url, icon_url=author_icon
|
||||||
|
)
|
||||||
base_embed.set_footer(
|
base_embed.set_footer(
|
||||||
text="Reddit", icon_url="https://www.redditinc.com/assets/images/site/reddit-logo.png"
|
text="Reddit",
|
||||||
|
icon_url="https://www.redditinc.com/assets/images/site/reddit-logo.png",
|
||||||
)
|
)
|
||||||
|
|
||||||
embeds = [base_embed]
|
embeds = [base_embed]
|
||||||
|
@ -111,84 +126,92 @@ async def post_embeds(sub: Sub, post: Submission, reddit: Reddit) -> Optional[Li
|
||||||
return embeds
|
return embeds
|
||||||
|
|
||||||
|
|
||||||
async def _stream_user(sub: Ruser, bot: Client, reddit: Reddit) -> None:
|
# async def _stream_user(sub: Ruser, bot: Client, reddit: Reddit) -> None:
|
||||||
"""
|
# """
|
||||||
Stream a redditor
|
# Stream a redditor
|
||||||
|
|
||||||
Args:
|
# Args:
|
||||||
sub: Redditor to stream
|
# sub: Redditor to stream
|
||||||
bot: Client instance
|
# bot: Client instance
|
||||||
"""
|
# """
|
||||||
now = datetime.now(tz=timezone.utc)
|
# now = datetime.now(tz=timezone.utc)
|
||||||
await sub.load()
|
# await sub.load()
|
||||||
running.append(sub.name)
|
# running.append(sub.name)
|
||||||
logger.debug(f"Streaming user {sub.name}")
|
# logger.debug(f"Streaming user {sub.name}")
|
||||||
try:
|
# try:
|
||||||
async for post in sub.stream.submissions():
|
# async for post in sub.stream.submissions():
|
||||||
if not post:
|
# if not post:
|
||||||
logger.debug(f"Got None for post from {sub.name}")
|
# logger.debug(f"Got None for post from {sub.name}")
|
||||||
continue
|
# continue
|
||||||
await post.subreddit.load()
|
# await post.subreddit.load()
|
||||||
if post.created_utc < now.timestamp():
|
# if post.created_utc < now.timestamp():
|
||||||
continue
|
# continue
|
||||||
logger.debug(f"Got new post from {sub.name} in r/{post.subreddit.display_name}")
|
# logger.debug(
|
||||||
follows = RedditorFollow.find(q(name=sub.name))
|
# f"Got new post from {sub.name} in r/{post.subreddit.display_name}"
|
||||||
num_follows = 0
|
# )
|
||||||
|
# follows = RedditorFollow.find(RedditorFollow.name == sub.name)
|
||||||
|
# num_follows = 0
|
||||||
|
|
||||||
async for follow in follows:
|
# async for follow in follows:
|
||||||
num_follows += 1
|
# num_follows += 1
|
||||||
|
|
||||||
guild = await bot.fetch_guild(follow.guild)
|
# guild = await bot.fetch_guild(follow.guild)
|
||||||
if not guild:
|
# if not guild:
|
||||||
logger.warning(f"Follow {follow.id}'s guild no longer exists, deleting")
|
# logger.warning(
|
||||||
await follow.delete()
|
# f"Follow {follow.id}'s guild no longer exists, deleting"
|
||||||
num_follows -= 1
|
# )
|
||||||
continue
|
# await follow.delete()
|
||||||
|
# num_follows -= 1
|
||||||
|
# continue
|
||||||
|
|
||||||
channel = await bot.fetch_channel(follow.channel)
|
# channel = await bot.fetch_channel(follow.channel)
|
||||||
if not channel:
|
# if not channel:
|
||||||
logger.warning(f"Follow {follow.id}'s channel no longer exists, deleting")
|
# logger.warning(
|
||||||
await follow.delete()
|
# f"Follow {follow.id}'s channel no longer exists, deleting"
|
||||||
num_follows -= 1
|
# )
|
||||||
continue
|
# await follow.delete()
|
||||||
|
# num_follows -= 1
|
||||||
|
# continue
|
||||||
|
|
||||||
embeds = await post_embeds(post.subreddit, post, reddit)
|
# embeds = await post_embeds(post.subreddit, post, reddit)
|
||||||
timestamp = int(post.created_utc)
|
# timestamp = int(post.created_utc)
|
||||||
|
|
||||||
try:
|
# try:
|
||||||
await channel.send(
|
# await channel.send(
|
||||||
f"`u/{sub.name}` posted to r/{post.subreddit.display_name} at <t:{timestamp}:f>",
|
# f"`u/{sub.name}` posted to r/{post.subreddit.display_name} at <t:{timestamp}:f>",
|
||||||
embeds=embeds,
|
# embeds=embeds,
|
||||||
)
|
# )
|
||||||
count = reddit_count.labels(
|
# count = reddit_count.labels(
|
||||||
guild_id=guild.id,
|
# guild_id=guild.id,
|
||||||
guild_name=guild.name,
|
# guild_name=guild.name,
|
||||||
subreddit_name=post.subreddit.display_name,
|
# subreddit_name=post.subreddit.display_name,
|
||||||
redditor_name=sub.name,
|
# redditor_name=sub.name,
|
||||||
)
|
# )
|
||||||
count.inc()
|
# count.inc()
|
||||||
except DNotFound:
|
# except DNotFound:
|
||||||
logger.warning(f"Follow {follow.id}'s channel no longer exists, deleting")
|
# logger.warning(
|
||||||
await follow.delete()
|
# f"Follow {follow.id}'s channel no longer exists, deleting"
|
||||||
num_follows -= 1
|
# )
|
||||||
continue
|
# await follow.delete()
|
||||||
except Exception:
|
# num_follows -= 1
|
||||||
logger.error(
|
# continue
|
||||||
f"Failed to send message to {channel.id} in {channel.guild.name}",
|
# except Exception:
|
||||||
exc_info=True,
|
# logger.error(
|
||||||
)
|
# f"Failed to send message to {channel.id} in {channel.guild.name}",
|
||||||
|
# exc_info=True,
|
||||||
|
# )
|
||||||
|
|
||||||
gauge = reddit_gauge.labels(redditor_name=sub.name)
|
# gauge = reddit_gauge.labels(redditor_name=sub.name)
|
||||||
gauge.set(num_follows)
|
# gauge.set(num_follows)
|
||||||
|
|
||||||
if num_follows == 0:
|
# if num_follows == 0:
|
||||||
s = await Redditor.find_one(q(name=sub.name))
|
# s = await Redditor.find_one(Redditor.name == sub.name)
|
||||||
if s:
|
# if s:
|
||||||
await s.delete()
|
# await s.delete()
|
||||||
break
|
# break
|
||||||
except Exception:
|
# except Exception:
|
||||||
logger.error(f"Redditor stream {sub.name} failed", exc_info=True)
|
# logger.error(f"Redditor stream {sub.name} failed", exc_info=True)
|
||||||
running.remove(sub.name)
|
# running.remove(sub.name)
|
||||||
|
|
||||||
|
|
||||||
async def _stream_subreddit(sub: Sub, bot: Client, reddit: Reddit) -> None:
|
async def _stream_subreddit(sub: Sub, bot: Client, reddit: Reddit) -> None:
|
||||||
|
@ -211,7 +234,9 @@ async def _stream_subreddit(sub: Sub, bot: Client, reddit: Reddit) -> None:
|
||||||
if post.created_utc < now.timestamp():
|
if post.created_utc < now.timestamp():
|
||||||
continue
|
continue
|
||||||
logger.debug(f"Got new post in {sub.display_name}")
|
logger.debug(f"Got new post in {sub.display_name}")
|
||||||
follows = SubredditFollow.find(q(display_name=sub.display_name))
|
follows = SubredditFollow.find(
|
||||||
|
SubredditFollow.display_name == sub.display_name
|
||||||
|
)
|
||||||
num_follows = 0
|
num_follows = 0
|
||||||
|
|
||||||
async for follow in follows:
|
async for follow in follows:
|
||||||
|
@ -219,14 +244,18 @@ async def _stream_subreddit(sub: Sub, bot: Client, reddit: Reddit) -> None:
|
||||||
|
|
||||||
guild = await bot.fetch_guild(follow.guild)
|
guild = await bot.fetch_guild(follow.guild)
|
||||||
if not guild:
|
if not guild:
|
||||||
logger.warning(f"Follow {follow.id}'s guild no longer exists, deleting")
|
logger.warning(
|
||||||
|
f"Follow {follow.id}'s guild no longer exists, deleting"
|
||||||
|
)
|
||||||
await follow.delete()
|
await follow.delete()
|
||||||
num_follows -= 1
|
num_follows -= 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
channel = await bot.fetch_channel(follow.channel)
|
channel = await bot.fetch_channel(follow.channel)
|
||||||
if not channel:
|
if not channel:
|
||||||
logger.warning(f"Follow {follow.id}'s channel no longer exists, deleting")
|
logger.warning(
|
||||||
|
f"Follow {follow.id}'s channel no longer exists, deleting"
|
||||||
|
)
|
||||||
await follow.delete()
|
await follow.delete()
|
||||||
num_follows -= 1
|
num_follows -= 1
|
||||||
continue
|
continue
|
||||||
|
@ -247,7 +276,9 @@ async def _stream_subreddit(sub: Sub, bot: Client, reddit: Reddit) -> None:
|
||||||
)
|
)
|
||||||
count.inc()
|
count.inc()
|
||||||
except DNotFound:
|
except DNotFound:
|
||||||
logger.warning(f"Follow {follow.id}'s channel no longer exists, deleting")
|
logger.warning(
|
||||||
|
f"Follow {follow.id}'s channel no longer exists, deleting"
|
||||||
|
)
|
||||||
await follow.delete()
|
await follow.delete()
|
||||||
num_follows -= 1
|
num_follows -= 1
|
||||||
continue
|
continue
|
||||||
|
@ -263,7 +294,7 @@ async def _stream_subreddit(sub: Sub, bot: Client, reddit: Reddit) -> None:
|
||||||
gauge.set(num_follows)
|
gauge.set(num_follows)
|
||||||
|
|
||||||
if num_follows == 0:
|
if num_follows == 0:
|
||||||
s = await Subreddit.find_one(q(display_name=sub.display_name))
|
s = await Subreddit.find_one(Subreddit.display_name == sub.display_name)
|
||||||
if s:
|
if s:
|
||||||
await s.delete()
|
await s.delete()
|
||||||
break
|
break
|
||||||
|
@ -276,12 +307,12 @@ async def _stream(sub: Sub | Ruser, bot: Client, reddit: Reddit) -> None:
|
||||||
"""
|
"""
|
||||||
Stream handler.
|
Stream handler.
|
||||||
|
|
||||||
Decides what type of stream to launch based on `type(sub)`
|
Decides what type of stream to launch based on `isinstance(sub, Sub)`
|
||||||
"""
|
"""
|
||||||
if isinstance(sub, Sub):
|
if isinstance(sub, Sub):
|
||||||
await _stream_subreddit(sub, bot, reddit)
|
await _stream_subreddit(sub, bot, reddit)
|
||||||
else:
|
# else:
|
||||||
await _stream_user(sub, bot, reddit)
|
# await _stream_user(sub, bot, reddit)
|
||||||
|
|
||||||
|
|
||||||
async def reddit(bot: Client) -> None:
|
async def reddit(bot: Client) -> None:
|
||||||
|
@ -301,7 +332,9 @@ async def reddit(bot: Client) -> None:
|
||||||
async for sub in Subreddit.find():
|
async for sub in Subreddit.find():
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
async for follow in SubredditFollow.find(q(display_name=sub.display_name)):
|
async for follow in SubredditFollow.find(
|
||||||
|
SubredditFollow.display_name == sub.display_name
|
||||||
|
):
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
guild = await bot.fetch_guild(follow.guild)
|
guild = await bot.fetch_guild(follow.guild)
|
||||||
|
@ -316,30 +349,30 @@ async def reddit(bot: Client) -> None:
|
||||||
logger.debug(f"Subreddit {sub.display_name} has no followers, removing")
|
logger.debug(f"Subreddit {sub.display_name} has no followers, removing")
|
||||||
await sub.delete()
|
await sub.delete()
|
||||||
|
|
||||||
logger.debug("Validating redditor follows")
|
# logger.debug("Validating redditor follows")
|
||||||
async for sub in Redditor.find():
|
# async for sub in Redditor.find():
|
||||||
count = 0
|
# count = 0
|
||||||
|
|
||||||
async for follow in RedditorFollow.find(q(name=sub.name)):
|
# async for follow in RedditorFollow.find(RedditorFollow.name == sub.name):
|
||||||
count += 1
|
# count += 1
|
||||||
|
|
||||||
guild = await bot.fetch_guild(follow.guild)
|
# guild = await bot.fetch_guild(follow.guild)
|
||||||
channel = await bot.fetch_channel(follow.channel)
|
# channel = await bot.fetch_channel(follow.channel)
|
||||||
if not guild or not channel:
|
# if not guild or not channel:
|
||||||
logger.debug(f"Follow {follow.id} invalid, deleting")
|
# logger.debug(f"Follow {follow.id} invalid, deleting")
|
||||||
await follow.delete()
|
# await follow.delete()
|
||||||
count -= 1
|
# count -= 1
|
||||||
continue
|
# continue
|
||||||
|
|
||||||
if count == 0:
|
# if count == 0:
|
||||||
logger.debug(f"Redditor {sub.name} has no followers, removing")
|
# logger.debug(f"Redditor {sub.name} has no followers, removing")
|
||||||
await sub.delete()
|
# await sub.delete()
|
||||||
|
|
||||||
old_count = 0
|
old_count = 0
|
||||||
while True:
|
while True:
|
||||||
count = len(running)
|
count = len(running)
|
||||||
subs = Subreddit.find(q(display_name__nin=running))
|
subs = Subreddit.find(NotIn(Subreddit.display_name, running))
|
||||||
users = Redditor.find(q(name__nin=running))
|
# users = Redditor.find(NotIn(Redditor.name, running))
|
||||||
|
|
||||||
# Go through all actively followed subreddits
|
# Go through all actively followed subreddits
|
||||||
async for sub in subs:
|
async for sub in subs:
|
||||||
|
@ -348,7 +381,9 @@ async def reddit(bot: Client) -> None:
|
||||||
logger.debug(f"Follow {sub.display_name} was found despite filter")
|
logger.debug(f"Follow {sub.display_name} was found despite filter")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
is_followed = await SubredditFollow.find_one(q(display_name=sub.display_name))
|
is_followed = await SubredditFollow.find_one(
|
||||||
|
SubredditFollow.display_name == sub.display_name
|
||||||
|
)
|
||||||
if not is_followed:
|
if not is_followed:
|
||||||
logger.warn(f"Subreddit {sub.display_name} has no followers, removing")
|
logger.warn(f"Subreddit {sub.display_name} has no followers, removing")
|
||||||
await sub.delete()
|
await sub.delete()
|
||||||
|
@ -359,7 +394,9 @@ async def reddit(bot: Client) -> None:
|
||||||
sub = await red.subreddit(sub.display_name)
|
sub = await red.subreddit(sub.display_name)
|
||||||
except (NotFound, Forbidden) as e:
|
except (NotFound, Forbidden) as e:
|
||||||
# Subreddit is either quarantined, deleted, or private
|
# Subreddit is either quarantined, deleted, or private
|
||||||
logger.warn(f"Subreddit {sub.display_name} raised {e.__class__.__name__}, removing")
|
logger.warn(
|
||||||
|
f"Subreddit {sub.display_name} raised {e.__class__.__name__}, removing"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
await sub.delete()
|
await sub.delete()
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -372,34 +409,38 @@ async def reddit(bot: Client) -> None:
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
# Go through all actively followed redditors
|
# Go through all actively followed redditors
|
||||||
async for sub in users:
|
# async for sub in users:
|
||||||
logger.debug(f"Creating stream for {sub.name}")
|
# logger.debug(f"Creating stream for {sub.name}")
|
||||||
if sub.name in running:
|
# if sub.name in running:
|
||||||
logger.debug(f"Follow {sub.name} was found despite filter")
|
# logger.debug(f"Follow {sub.name} was found despite filter")
|
||||||
continue
|
# continue
|
||||||
|
|
||||||
is_followed = await SubredditFollow.find_one(q(name=sub.name))
|
# is_followed = await SubredditFollow.find_one(
|
||||||
if not is_followed:
|
# SubredditFollow.name == sub.name
|
||||||
logger.warn(f"Redditor {sub.name} has no followers, removing")
|
# )
|
||||||
await sub.delete()
|
# if not is_followed:
|
||||||
continue
|
# logger.warn(f"Redditor {sub.name} has no followers, removing")
|
||||||
|
# await sub.delete()
|
||||||
|
# continue
|
||||||
|
|
||||||
# Get subreddit
|
# # Get subreddit
|
||||||
try:
|
# try:
|
||||||
sub = await red.user(sub.name)
|
# sub = await red.user(sub.name)
|
||||||
except (NotFound, Forbidden) as e:
|
# except (NotFound, Forbidden) as e:
|
||||||
# Subreddit is either quarantined, deleted, or private
|
# # Subreddit is either quarantined, deleted, or private
|
||||||
logger.warn(f"Redditor {sub.display_name} raised {e.__class__.__name__}, removing")
|
# logger.warn(
|
||||||
try:
|
# f"Redditor {sub.display_name} raised {e.__class__.__name__}, removing"
|
||||||
await sub.delete()
|
# )
|
||||||
except Exception:
|
# try:
|
||||||
logger.debug("Ignoring deletion error")
|
# await sub.delete()
|
||||||
continue
|
# except Exception:
|
||||||
|
# logger.debug("Ignoring deletion error")
|
||||||
|
# continue
|
||||||
|
|
||||||
# Create and run stream
|
# # Create and run stream
|
||||||
coro = _stream(sub, bot, red)
|
# coro = _stream(sub, bot, red)
|
||||||
asyncio.create_task(coro)
|
# asyncio.create_task(coro)
|
||||||
count += 1
|
# count += 1
|
||||||
|
|
||||||
if old_count != count:
|
if old_count != count:
|
||||||
logger.debug(f"Now streaming {count} subreddits")
|
logger.debug(f"Now streaming {count} subreddits")
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
"""JARVIS reminders."""
|
"""JARVIS reminders."""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from jarvis_core.db import q
|
import pytz
|
||||||
|
from beanie.operators import LTE, NotIn
|
||||||
|
from croniter import croniter
|
||||||
|
from interactions import Client
|
||||||
|
from interactions.models.discord.channel import GuildText
|
||||||
|
from interactions.models.discord.embed import Embed
|
||||||
|
from interactions.models.discord.user import User
|
||||||
from jarvis_core.db.models import Reminder
|
from jarvis_core.db.models import Reminder
|
||||||
from naff import Client
|
|
||||||
from naff.models.discord.channel import GuildText
|
|
||||||
from naff.models.discord.embed import Embed
|
|
||||||
from naff.models.discord.user import User
|
|
||||||
|
|
||||||
from jarvis_tasks.prometheus.stats import reminder_count
|
from jarvis_tasks.prometheus.stats import reminder_count
|
||||||
from jarvis_tasks.util import build_embed, runat
|
from jarvis_tasks.util import build_embed, runat
|
||||||
|
@ -52,14 +54,24 @@ async def _remind(
|
||||||
f"Reminder {reminder.id} private, sent notification to origin channel"
|
f"Reminder {reminder.id} private, sent notification to origin channel"
|
||||||
)
|
)
|
||||||
reminder.active = False
|
reminder.active = False
|
||||||
await reminder.commit()
|
await reminder.save()
|
||||||
delete = False
|
delete = False
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Reminder {reminder.id} failed, no way to contact user.")
|
logger.warning(f"Reminder {reminder.id} failed, no way to contact user.")
|
||||||
|
if reminder.repeat:
|
||||||
|
now = datetime.now(tz=pytz.timezone(reminder.timezone))
|
||||||
|
cron = croniter(reminder.repeat, now)
|
||||||
|
reminder.remind_at = cron.next(datetime)
|
||||||
|
reminder.total_reminders += 1
|
||||||
|
delete = False
|
||||||
if delete:
|
if delete:
|
||||||
await reminder.delete()
|
await reminder.delete()
|
||||||
|
else:
|
||||||
|
await reminder.save()
|
||||||
if reminded:
|
if reminded:
|
||||||
count = reminder_count.labels(guild_id=channel.guild.id, guild_name=channel.guild.name)
|
guild_id = channel.guild.id if channel.guild else user.id
|
||||||
|
guild_name = channel.guild.name if channel.guild else user.username
|
||||||
|
count = reminder_count.labels(guild_id=guild_id, guild_name=guild_name)
|
||||||
count.inc()
|
count.inc()
|
||||||
queue.remove(reminder.id)
|
queue.remove(reminder.id)
|
||||||
|
|
||||||
|
@ -73,8 +85,13 @@ async def remind(bot: Client) -> None:
|
||||||
"""
|
"""
|
||||||
logger.debug("Starting Task-remind")
|
logger.debug("Starting Task-remind")
|
||||||
while True:
|
while True:
|
||||||
max_ts = datetime.now(tz=timezone.utc) + timedelta(seconds=5)
|
max_ts = datetime.now(tz=pytz.utc) + timedelta(seconds=5)
|
||||||
reminders = Reminder.find(q(id__nin=queue, remind_at__lte=max_ts, active=True))
|
reminders = Reminder.find(
|
||||||
|
NotIn(Reminder.id, queue),
|
||||||
|
LTE(Reminder.remind_at, max_ts),
|
||||||
|
Reminder.active == True,
|
||||||
|
)
|
||||||
|
|
||||||
async for reminder in reminders:
|
async for reminder in reminders:
|
||||||
if reminder.id in queue:
|
if reminder.id in queue:
|
||||||
logger.debug(f"Reminder {reminder.id} was found despite filter")
|
logger.debug(f"Reminder {reminder.id} was found despite filter")
|
||||||
|
|
|
@ -3,9 +3,9 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from jarvis_core.db import q
|
from beanie.operators import LTE, NotIn
|
||||||
from jarvis_core.db.models import Temprole
|
from jarvis_core.db.models import Temprole
|
||||||
from naff import Client
|
from interactions import Client
|
||||||
|
|
||||||
from jarvis_tasks.util import runat
|
from jarvis_tasks.util import runat
|
||||||
|
|
||||||
|
@ -52,7 +52,9 @@ async def remove(bot: Client) -> None:
|
||||||
logger.debug("Starting Task-remove")
|
logger.debug("Starting Task-remove")
|
||||||
while True:
|
while True:
|
||||||
max_ts = datetime.now(tz=timezone.utc) + timedelta(seconds=45)
|
max_ts = datetime.now(tz=timezone.utc) + timedelta(seconds=45)
|
||||||
temproles = Temprole.find(q(expires_at__lte=max_ts, id__nin=queue))
|
temproles = Temprole.find(
|
||||||
|
LTE(Temprole.expires_at, max_ts), NotIn(Temprole.id, queue)
|
||||||
|
)
|
||||||
async for temprole in temproles:
|
async for temprole in temproles:
|
||||||
if temprole.id in queue:
|
if temprole.id in queue:
|
||||||
logger.warn("Temprole found despite filter")
|
logger.warn("Temprole found despite filter")
|
||||||
|
|
|
@ -5,20 +5,19 @@ from datetime import datetime, timedelta, timezone
|
||||||
from html import unescape
|
from html import unescape
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from jarvis_core.db import q
|
|
||||||
from jarvis_core.db.models import TwitterAccount, TwitterFollow
|
from jarvis_core.db.models import TwitterAccount, TwitterFollow
|
||||||
from naff import Client
|
from interactions import Client
|
||||||
from naff.client.errors import NotFound
|
from interactions.client.errors import NotFound
|
||||||
from naff.models.discord.embed import Embed
|
from interactions.models.discord.embed import Embed
|
||||||
from tweepy.streaming import StreamRule
|
from tweepy.streaming import StreamRule
|
||||||
from tweepy.asynchronous import AsyncClient, AsyncStreamingClient
|
from tweepy.asynchronous import AsyncClient, AsyncStreamingClient
|
||||||
from tweepy import Media, Tweet, User
|
from tweepy import Media, Tweet, User
|
||||||
|
|
||||||
from jarvis_tasks.config import TaskConfig
|
from jarvis_tasks.config import load_config
|
||||||
from jarvis_tasks.prometheus.stats import twitter_count, twitter_error, twitter_gauge
|
from jarvis_tasks.prometheus.stats import twitter_count, twitter_error, twitter_gauge
|
||||||
from jarvis_tasks.util import build_embed
|
from jarvis_tasks.util import build_embed
|
||||||
|
|
||||||
config = TaskConfig.from_yaml()
|
config = load_config()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
tlogger = logging.getLogger("Tweepy")
|
tlogger = logging.getLogger("Tweepy")
|
||||||
tlogger.setLevel(logging.DEBUG)
|
tlogger.setLevel(logging.DEBUG)
|
||||||
|
@ -29,7 +28,9 @@ DEFAULT_TWEET_FIELDS = "created_at"
|
||||||
DEFAULT_USER_FIELDS = "url,profile_image_url"
|
DEFAULT_USER_FIELDS = "url,profile_image_url"
|
||||||
|
|
||||||
|
|
||||||
async def tweet_embeds(tweet: Tweet, retweet: bool, quoted: bool, api: AsyncClient) -> List[Embed]:
|
async def tweet_embeds(
|
||||||
|
tweet: Tweet, retweet: bool, quoted: bool, api: AsyncClient
|
||||||
|
) -> List[Embed]:
|
||||||
"""
|
"""
|
||||||
Build a tweet embeds.
|
Build a tweet embeds.
|
||||||
|
|
||||||
|
@ -119,7 +120,10 @@ class JARVISTwitterStream(AsyncStreamingClient):
|
||||||
Args:
|
Args:
|
||||||
status_code: HTTP Status Code
|
status_code: HTTP Status Code
|
||||||
"""
|
"""
|
||||||
logger.error(f"Received status code {status_code} while streaming, restarting", exc_info=True)
|
logger.error(
|
||||||
|
f"Received status code {status_code} while streaming, restarting",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
errors = twitter_error.labels(error_code=status_code)
|
errors = twitter_error.labels(error_code=status_code)
|
||||||
errors.inc()
|
errors.inc()
|
||||||
self.disconnect()
|
self.disconnect()
|
||||||
|
@ -142,7 +146,7 @@ class JARVISTwitterStream(AsyncStreamingClient):
|
||||||
)
|
)
|
||||||
author = status.includes.get("users")[0]
|
author = status.includes.get("users")[0]
|
||||||
logger.debug(f"{author.username} sent new tweet")
|
logger.debug(f"{author.username} sent new tweet")
|
||||||
follows = TwitterFollow.find(q(twitter_id=author.id))
|
follows = TwitterFollow.find(TwitterFollow.twitter_id == author.id)
|
||||||
num_follows = 0
|
num_follows = 0
|
||||||
|
|
||||||
retweet = False
|
retweet = False
|
||||||
|
@ -184,7 +188,11 @@ class JARVISTwitterStream(AsyncStreamingClient):
|
||||||
f"`@{author.username}` {mod} this at <t:{timestamp}:f>",
|
f"`@{author.username}` {mod} this at <t:{timestamp}:f>",
|
||||||
embeds=embeds,
|
embeds=embeds,
|
||||||
)
|
)
|
||||||
count = twitter_count.labels(guild_id=guild.id, guild_name=guild.name, twitter_handle=author.username)
|
count = twitter_count.labels(
|
||||||
|
guild_id=guild.id,
|
||||||
|
guild_name=guild.name,
|
||||||
|
twitter_handle=author.username,
|
||||||
|
)
|
||||||
count.inc()
|
count.inc()
|
||||||
except NotFound:
|
except NotFound:
|
||||||
logger.warn(f"Follow {follow.id} invalid, deleting")
|
logger.warn(f"Follow {follow.id} invalid, deleting")
|
||||||
|
@ -192,11 +200,17 @@ class JARVISTwitterStream(AsyncStreamingClient):
|
||||||
num_follows -= 1
|
num_follows -= 1
|
||||||
continue
|
continue
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug(f"Failed to send message to {channel.id} in {channel.guild.name}")
|
logger.debug(
|
||||||
|
f"Failed to send message to {channel.id} in {channel.guild.name}"
|
||||||
|
)
|
||||||
|
|
||||||
if num_follows == 0:
|
if num_follows == 0:
|
||||||
logger.warning(f"Account {author.username} no longer has followers, removing")
|
logger.warning(
|
||||||
account = await TwitterAccount.find_one(q(twitter_id=author.id))
|
f"Account {author.username} no longer has followers, removing"
|
||||||
|
)
|
||||||
|
account = await TwitterAccount.find_one(
|
||||||
|
TwitterAccount.twitter_id == author.id
|
||||||
|
)
|
||||||
if account:
|
if account:
|
||||||
await account.delete()
|
await account.delete()
|
||||||
self.disconnect()
|
self.disconnect()
|
||||||
|
@ -216,7 +230,9 @@ async def twitter(bot: Client) -> None:
|
||||||
logger.warn("Missing Twitter config, not starting")
|
logger.warn("Missing Twitter config, not starting")
|
||||||
return
|
return
|
||||||
api = AsyncClient(bearer_token=config.twitter["bearer_token"])
|
api = AsyncClient(bearer_token=config.twitter["bearer_token"])
|
||||||
stream = JARVISTwitterStream(bot=bot, bearer_token=config.twitter["bearer_token"], api=api)
|
stream = JARVISTwitterStream(
|
||||||
|
bot=bot, bearer_token=config.twitter["bearer_token"], api=api
|
||||||
|
)
|
||||||
rules = await stream.get_rules()
|
rules = await stream.get_rules()
|
||||||
if rules.data:
|
if rules.data:
|
||||||
await stream.delete_rules(rules.data)
|
await stream.delete_rules(rules.data)
|
||||||
|
@ -226,7 +242,9 @@ async def twitter(bot: Client) -> None:
|
||||||
async for account in TwitterAccount.find():
|
async for account in TwitterAccount.find():
|
||||||
count = 0
|
count = 0
|
||||||
|
|
||||||
async for follow in TwitterFollow.find(q(twitter_id=account.twitter_id)):
|
async for follow in TwitterFollow.find(
|
||||||
|
TwitterFollow.twitter_id == account.twitter_id
|
||||||
|
):
|
||||||
count += 1
|
count += 1
|
||||||
try:
|
try:
|
||||||
guild = await bot.fetch_guild(follow.guild)
|
guild = await bot.fetch_guild(follow.guild)
|
||||||
|
@ -270,7 +288,7 @@ async def twitter(bot: Client) -> None:
|
||||||
continue
|
continue
|
||||||
account.handle = user.data.username
|
account.handle = user.data.username
|
||||||
account.last_sync = datetime.now(tz=timezone.utc)
|
account.last_sync = datetime.now(tz=timezone.utc)
|
||||||
await account.commit()
|
await account.save()
|
||||||
ids.append(account.twitter_id)
|
ids.append(account.twitter_id)
|
||||||
|
|
||||||
# Get new tweets
|
# Get new tweets
|
||||||
|
|
|
@ -3,9 +3,9 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from jarvis_core.db import q
|
from beanie.operators import LTE, NotIn
|
||||||
from jarvis_core.db.models import Warning
|
from jarvis_core.db.models import Warning
|
||||||
from naff import Client
|
from interactions import Client
|
||||||
|
|
||||||
from jarvis_tasks.util import runat
|
from jarvis_tasks.util import runat
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
|
||||||
async def _unwarn(warn: Warning) -> None:
|
async def _unwarn(warn: Warning) -> None:
|
||||||
logger.debug(f"Deactivating warning {warn.id}")
|
logger.debug(f"Deactivating warning {warn.id}")
|
||||||
warn.active = False
|
warn.active = False
|
||||||
await warn.commit()
|
await warn.save()
|
||||||
queue.remove(warn.id)
|
queue.remove(warn.id)
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,7 +30,11 @@ async def unwarn(bot: Client) -> None:
|
||||||
logger.debug("Starting Task-unwarn")
|
logger.debug("Starting Task-unwarn")
|
||||||
while True:
|
while True:
|
||||||
max_ts = datetime.now(tz=timezone.utc) + timedelta(minutes=55)
|
max_ts = datetime.now(tz=timezone.utc) + timedelta(minutes=55)
|
||||||
warns = Warning.find(q(active=True, expires_at__lte=max_ts, id__nin=queue))
|
warns = Warning.find(
|
||||||
|
Warning.active == True,
|
||||||
|
LTE(Warning.expires_at, max_ts),
|
||||||
|
NotIn(Warning.id, queue),
|
||||||
|
)
|
||||||
async for warn in warns:
|
async for warn in warns:
|
||||||
if warn.id in queue:
|
if warn.id in queue:
|
||||||
logger.warn("Warning found despite filter")
|
logger.warn("Warning found despite filter")
|
||||||
|
|
|
@ -4,7 +4,7 @@ from datetime import datetime, timezone
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from typing import Coroutine
|
from typing import Coroutine
|
||||||
|
|
||||||
from naff.models.discord.embed import Embed
|
from interactions.models.discord.embed import Embed
|
||||||
|
|
||||||
|
|
||||||
async def runat(when: datetime, coro: Coroutine, logger: Logger) -> None:
|
async def runat(when: datetime, coro: Coroutine, logger: Logger) -> None:
|
||||||
|
|
1593
poetry.lock
generated
1593
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -1,23 +1,26 @@
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "jarvis-tasks"
|
name = "jarvis-tasks"
|
||||||
version = "0.9.1"
|
version = "0.11.0"
|
||||||
description = ""
|
description = ""
|
||||||
authors = ["Your Name <you@example.com>"]
|
authors = ["Zevaryx <zevaryx@gmail.com>"]
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.10,<4"
|
python = ">=3.10,<4"
|
||||||
jarvis-core = {git = "https://git.zevaryx.com/stark-industries/jarvis/jarvis-core.git", rev = "main"}
|
jarvis-core = { git = "https://git.zevaryx.com/stark-industries/jarvis/jarvis-core.git", rev = "main" }
|
||||||
naff = ">=2.1.0"
|
|
||||||
aiohttp = "^3.8.3"
|
aiohttp = "^3.8.3"
|
||||||
tweepy = {extras = ["async"], version = "^4.13.0"}
|
tweepy = { extras = ["async"], version = "^4.13.0" }
|
||||||
asyncpraw = "^7.5.0"
|
asyncpraw = "^7.5.0"
|
||||||
#rook = "^0.1.170"
|
|
||||||
uvicorn = "^0.17.6"
|
uvicorn = "^0.17.6"
|
||||||
prometheus-client = "^0.14.1"
|
prometheus-client = "^0.14.1"
|
||||||
|
interactions-py = "^5.3.1"
|
||||||
|
pydantic = ">=2.3.0,<3"
|
||||||
|
# rocketry = "^2.5.1"
|
||||||
|
croniter = "^1.4.1"
|
||||||
|
beanie = ">=1.21.0"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
pytest = "^7.1"
|
pytest = "^7.1"
|
||||||
black = {version = "^22.3.0", allow-prereleases = true}
|
black = { version = "^22.3.0", allow-prereleases = true }
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core>=1.0.0"]
|
requires = ["poetry-core>=1.0.0"]
|
||||||
|
|
38
sample.env
Normal file
38
sample.env
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
# Base Config, required
|
||||||
|
TOKEN=
|
||||||
|
|
||||||
|
# Base Config, optional
|
||||||
|
ENVIRONMENT=develop
|
||||||
|
SYNC=false
|
||||||
|
LOG_LEVEL=INFO
|
||||||
|
JURIGGED=false
|
||||||
|
|
||||||
|
# MongoDB, required
|
||||||
|
MONGO_HOST=localhost
|
||||||
|
MONGO_USERNAME=
|
||||||
|
MONGO_PASSWORD=
|
||||||
|
MONGO_PORT=27017
|
||||||
|
|
||||||
|
# Redis, required
|
||||||
|
REDIS_HOST=localhost
|
||||||
|
REDIS_USERNAME=
|
||||||
|
REDIS_PASSWORD=
|
||||||
|
|
||||||
|
# Mastodon, optional
|
||||||
|
MASTODON_TOKEN=
|
||||||
|
MASTODON_URL=
|
||||||
|
|
||||||
|
# Reddit, optional
|
||||||
|
REDDIT_USER_AGENT=
|
||||||
|
REDDIT_CLIENT_SECRET=
|
||||||
|
REDDIT_CLIENT_ID=
|
||||||
|
|
||||||
|
# Twitter, optional
|
||||||
|
TWITTER_CONSUMER_KEY=
|
||||||
|
TWITTER_CONSUMER_SECRET=
|
||||||
|
TWITTER_ACCESS_TOKEN=
|
||||||
|
TWITTER_ACCESS_SECRET=
|
||||||
|
TWITTER_BEARER_TOKEN=
|
||||||
|
|
||||||
|
# URLs, optional
|
||||||
|
URL_DBRAND=
|
Loading…
Add table
Reference in a new issue