Singleton usage, migrate to more mongodb storage
This commit is contained in:
parent
7e489c5646
commit
af544974c6
4 changed files with 94 additions and 64 deletions
|
@ -3,7 +3,9 @@ import discord
|
|||
from random import randint
|
||||
import html
|
||||
import re
|
||||
from jarvis.utils import build_embed, db
|
||||
import traceback
|
||||
from jarvis.utils import build_embed
|
||||
from jarvis.utils.db import DBManager
|
||||
from jarvis.utils.field import Field
|
||||
from discord.ext import commands
|
||||
from datetime import datetime
|
||||
|
@ -18,30 +20,23 @@ class JokeCog(commands.Cog):
|
|||
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.db = db.create_connection()
|
||||
config = jarvis.config.get_config()
|
||||
self.db = DBManager(config.mongo)
|
||||
|
||||
# TODO: Make this a command group with subcommands
|
||||
@commands.command(name="joke", help="Hear a joke")
|
||||
async def _joke(self, ctx, id: str = None):
|
||||
if randint(1, 100_000) == 5779 and id is None:
|
||||
await ctx.send(f"<@{ctx.message.author.id}>")
|
||||
return
|
||||
# TODO: Add this as a parameter that can be passed in
|
||||
threshold = 500 # Minimum score
|
||||
coll = self.db.jokes.reddit
|
||||
result = None
|
||||
if id:
|
||||
result = coll.find_one({"id": id})
|
||||
else:
|
||||
result = list(
|
||||
coll.aggregate(
|
||||
[
|
||||
{"$match": {"score": {"$gt": threshold}}},
|
||||
{"$sample": {"size": 1}},
|
||||
]
|
||||
)
|
||||
)[0]
|
||||
while result["body"] in ["[removed]", "[deleted]"]:
|
||||
try:
|
||||
if randint(1, 100_000) == 5779 and id is None:
|
||||
await ctx.send(f"<@{ctx.message.author.id}>")
|
||||
return
|
||||
# TODO: Add this as a parameter that can be passed in
|
||||
threshold = 500 # Minimum score
|
||||
coll = self.db.mongo.jarvis.jokes
|
||||
result = None
|
||||
if id:
|
||||
result = coll.find_one({"id": id})
|
||||
else:
|
||||
result = list(
|
||||
coll.aggregate(
|
||||
[
|
||||
|
@ -50,37 +45,50 @@ class JokeCog(commands.Cog):
|
|||
]
|
||||
)
|
||||
)[0]
|
||||
# TODO: Build a custom embed to show the joke
|
||||
if result is None:
|
||||
await ctx.send("Humor module failed. Please try again later.")
|
||||
return
|
||||
emotes = re.findall(r"(&#x[a-fA-F0-9]*;)", result["body"])
|
||||
for match in emotes:
|
||||
result["body"] = result["body"].replace(
|
||||
match, html.unescape(match)
|
||||
while result["body"] in ["[removed]", "[deleted]"]:
|
||||
result = list(
|
||||
coll.aggregate(
|
||||
[
|
||||
{"$match": {"score": {"$gt": threshold}}},
|
||||
{"$sample": {"size": 1}},
|
||||
]
|
||||
)
|
||||
)[0]
|
||||
# TODO: Build a custom embed to show the joke
|
||||
if result is None:
|
||||
await ctx.send("Humor module failed. Please try again later.")
|
||||
return
|
||||
emotes = re.findall(r"(&#x[a-fA-F0-9]*;)", result["body"])
|
||||
for match in emotes:
|
||||
result["body"] = result["body"].replace(
|
||||
match, html.unescape(match)
|
||||
)
|
||||
emotes = re.findall(r"(&#x[a-fA-F0-9]*;)", result["title"])
|
||||
for match in emotes:
|
||||
result["title"] = result["title"].replace(
|
||||
match, html.unescape(match)
|
||||
)
|
||||
fields = [
|
||||
Field("", result["body"], False),
|
||||
Field("Score", result["score"]),
|
||||
# Field(
|
||||
# "Created At",
|
||||
# str(datetime.fromtimestamp(result["created_utc"])),
|
||||
# ),
|
||||
Field("ID", result["id"]),
|
||||
]
|
||||
embed = build_embed(
|
||||
title=result["title"],
|
||||
description="",
|
||||
fields=fields,
|
||||
url=f"https://reddit.com/r/jokes/comments/{result['id']}",
|
||||
timestamp=datetime.fromtimestamp(result["created_utc"]),
|
||||
)
|
||||
emotes = re.findall(r"(&#x[a-fA-F0-9]*;)", result["title"])
|
||||
for match in emotes:
|
||||
result["title"] = result["title"].replace(
|
||||
match, html.unescape(match)
|
||||
await ctx.send(embed=embed)
|
||||
except Exception:
|
||||
await ctx.send(
|
||||
"Encountered error:\n```\n" + traceback.format_exc() + "\n```"
|
||||
)
|
||||
fields = [
|
||||
Field("", result["body"], False),
|
||||
Field("Score", result["score"]),
|
||||
# Field(
|
||||
# "Created At",
|
||||
# str(datetime.fromtimestamp(result["created_utc"])),
|
||||
# ),
|
||||
Field("ID", result["id"]),
|
||||
]
|
||||
embed = build_embed(
|
||||
title=result["title"],
|
||||
description="",
|
||||
fields=fields,
|
||||
url=f"https://reddit.com/r/jokes/comments/{result['id']}",
|
||||
timestamp=datetime.fromtimestamp(result["created_utc"]),
|
||||
)
|
||||
await ctx.send(embed=embed)
|
||||
# await ctx.send(f"**{result['title']}**\n\n{result['body']}")
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass
|
||||
from yaml import load
|
||||
from jarvis.utils.db import DBManager
|
||||
|
||||
try:
|
||||
from yaml import CLoader as Loader
|
||||
|
@ -7,14 +7,27 @@ except ImportError:
|
|||
from yaml import Loader
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
token: str
|
||||
client_id: str
|
||||
admins: list
|
||||
logo: str
|
||||
mongo: dict
|
||||
urls: dict
|
||||
def __new__(cls, *args, **kwargs):
|
||||
it = cls.__dict__.get("it")
|
||||
if it is not None:
|
||||
return it
|
||||
cls.__it__ = it = object.__new__(cls)
|
||||
it.init(*args, **kwargs)
|
||||
return it
|
||||
|
||||
def init(
|
||||
self, token: str, client_id: str, logo: str, mongo: dict, urls: dict
|
||||
):
|
||||
self.token = token
|
||||
self.client_id = client_id
|
||||
self.logo = logo
|
||||
self.mongo = mongo
|
||||
self.urls = urls
|
||||
db = DBManager(config=mongo).mongo.jarvis.config
|
||||
db_config = db.find()
|
||||
for item in db_config:
|
||||
setattr(self, item["key"], item["value"])
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, y):
|
||||
|
@ -23,6 +36,8 @@ class Config:
|
|||
|
||||
|
||||
def get_config(path: str = "config.yaml") -> Config:
|
||||
if Config.__dict__.get("it"):
|
||||
return Config()
|
||||
with open(path) as f:
|
||||
raw = f.read()
|
||||
y = load(raw, Loader=Loader)
|
||||
|
|
|
@ -30,8 +30,8 @@ def unconvert_bytesize(size, ending: str):
|
|||
|
||||
def get_prefix(bot, message):
|
||||
prefixes = ["$", ">", "?", "!"]
|
||||
if not message.guild:
|
||||
return "?"
|
||||
# if not message.guild:
|
||||
# return "?"
|
||||
|
||||
return commands.when_mentioned_or(*prefixes)(bot, message)
|
||||
|
||||
|
|
|
@ -1,7 +1,14 @@
|
|||
from pymongo import MongoClient
|
||||
from jarvis.config import get_config
|
||||
|
||||
|
||||
def create_connection() -> MongoClient:
|
||||
config = get_config()
|
||||
return MongoClient(**config.mongo)
|
||||
class DBManager(object):
|
||||
def __new__(cls, *args, **kwargs):
|
||||
it = cls.__dict__.get("it")
|
||||
if it is not None:
|
||||
return it
|
||||
cls.__it__ = it = object.__new__(cls)
|
||||
it.init(*args, **kwargs)
|
||||
return it
|
||||
|
||||
def init(self, config: dict):
|
||||
self.mongo = MongoClient(**config)
|
||||
|
|
Loading…
Add table
Reference in a new issue