diff --git a/jarvis/cogs/jokes.py b/jarvis/cogs/jokes.py index 3e6123d..7510145 100644 --- a/jarvis/cogs/jokes.py +++ b/jarvis/cogs/jokes.py @@ -3,7 +3,9 @@ import discord from random import randint import html import re -from jarvis.utils import build_embed, db +import traceback +from jarvis.utils import build_embed +from jarvis.utils.db import DBManager from jarvis.utils.field import Field from discord.ext import commands from datetime import datetime @@ -18,30 +20,23 @@ class JokeCog(commands.Cog): def __init__(self, bot): self.bot = bot - self.db = db.create_connection() + config = jarvis.config.get_config() + self.db = DBManager(config.mongo) # TODO: Make this a command group with subcommands @commands.command(name="joke", help="Hear a joke") async def _joke(self, ctx, id: str = None): - if randint(1, 100_000) == 5779 and id is None: - await ctx.send(f"<@{ctx.message.author.id}>") - return - # TODO: Add this as a parameter that can be passed in - threshold = 500 # Minimum score - coll = self.db.jokes.reddit - result = None - if id: - result = coll.find_one({"id": id}) - else: - result = list( - coll.aggregate( - [ - {"$match": {"score": {"$gt": threshold}}}, - {"$sample": {"size": 1}}, - ] - ) - )[0] - while result["body"] in ["[removed]", "[deleted]"]: + try: + if randint(1, 100_000) == 5779 and id is None: + await ctx.send(f"<@{ctx.message.author.id}>") + return + # TODO: Add this as a parameter that can be passed in + threshold = 500 # Minimum score + coll = self.db.mongo.jarvis.jokes + result = None + if id: + result = coll.find_one({"id": id}) + else: result = list( coll.aggregate( [ @@ -50,37 +45,50 @@ class JokeCog(commands.Cog): ] ) )[0] - # TODO: Build a custom embed to show the joke - if result is None: - await ctx.send("Humor module failed. Please try again later.") - return - emotes = re.findall(r"(&#x[a-fA-F0-9]*;)", result["body"]) - for match in emotes: - result["body"] = result["body"].replace( - match, html.unescape(match) + while result["body"] in ["[removed]", "[deleted]"]: + result = list( + coll.aggregate( + [ + {"$match": {"score": {"$gt": threshold}}}, + {"$sample": {"size": 1}}, + ] + ) + )[0] + # TODO: Build a custom embed to show the joke + if result is None: + await ctx.send("Humor module failed. Please try again later.") + return + emotes = re.findall(r"(&#x[a-fA-F0-9]*;)", result["body"]) + for match in emotes: + result["body"] = result["body"].replace( + match, html.unescape(match) + ) + emotes = re.findall(r"(&#x[a-fA-F0-9]*;)", result["title"]) + for match in emotes: + result["title"] = result["title"].replace( + match, html.unescape(match) + ) + fields = [ + Field("​", result["body"], False), + Field("Score", result["score"]), + # Field( + # "Created At", + # str(datetime.fromtimestamp(result["created_utc"])), + # ), + Field("ID", result["id"]), + ] + embed = build_embed( + title=result["title"], + description="", + fields=fields, + url=f"https://reddit.com/r/jokes/comments/{result['id']}", + timestamp=datetime.fromtimestamp(result["created_utc"]), ) - emotes = re.findall(r"(&#x[a-fA-F0-9]*;)", result["title"]) - for match in emotes: - result["title"] = result["title"].replace( - match, html.unescape(match) + await ctx.send(embed=embed) + except Exception: + await ctx.send( + "Encountered error:\n```\n" + traceback.format_exc() + "\n```" ) - fields = [ - Field("​", result["body"], False), - Field("Score", result["score"]), - # Field( - # "Created At", - # str(datetime.fromtimestamp(result["created_utc"])), - # ), - Field("ID", result["id"]), - ] - embed = build_embed( - title=result["title"], - description="", - fields=fields, - url=f"https://reddit.com/r/jokes/comments/{result['id']}", - timestamp=datetime.fromtimestamp(result["created_utc"]), - ) - await ctx.send(embed=embed) # await ctx.send(f"**{result['title']}**\n\n{result['body']}") diff --git a/jarvis/config.py b/jarvis/config.py index 52dabb9..dc44da2 100644 --- a/jarvis/config.py +++ b/jarvis/config.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass from yaml import load +from jarvis.utils.db import DBManager try: from yaml import CLoader as Loader @@ -7,14 +7,27 @@ except ImportError: from yaml import Loader -@dataclass class Config: - token: str - client_id: str - admins: list - logo: str - mongo: dict - urls: dict + def __new__(cls, *args, **kwargs): + it = cls.__dict__.get("it") + if it is not None: + return it + cls.__it__ = it = object.__new__(cls) + it.init(*args, **kwargs) + return it + + def init( + self, token: str, client_id: str, logo: str, mongo: dict, urls: dict + ): + self.token = token + self.client_id = client_id + self.logo = logo + self.mongo = mongo + self.urls = urls + db = DBManager(config=mongo).mongo.jarvis.config + db_config = db.find() + for item in db_config: + setattr(self, item["key"], item["value"]) @classmethod def from_yaml(cls, y): @@ -23,6 +36,8 @@ class Config: def get_config(path: str = "config.yaml") -> Config: + if Config.__dict__.get("it"): + return Config() with open(path) as f: raw = f.read() y = load(raw, Loader=Loader) diff --git a/jarvis/utils/__init__.py b/jarvis/utils/__init__.py index 881b8ab..499037a 100644 --- a/jarvis/utils/__init__.py +++ b/jarvis/utils/__init__.py @@ -30,8 +30,8 @@ def unconvert_bytesize(size, ending: str): def get_prefix(bot, message): prefixes = ["$", ">", "?", "!"] - if not message.guild: - return "?" + # if not message.guild: + # return "?" return commands.when_mentioned_or(*prefixes)(bot, message) diff --git a/jarvis/utils/db.py b/jarvis/utils/db.py index 20a79d6..d6c564c 100644 --- a/jarvis/utils/db.py +++ b/jarvis/utils/db.py @@ -1,7 +1,14 @@ from pymongo import MongoClient -from jarvis.config import get_config -def create_connection() -> MongoClient: - config = get_config() - return MongoClient(**config.mongo) +class DBManager(object): + def __new__(cls, *args, **kwargs): + it = cls.__dict__.get("it") + if it is not None: + return it + cls.__it__ = it = object.__new__(cls) + it.init(*args, **kwargs) + return it + + def init(self, config: dict): + self.mongo = MongoClient(**config)