Migrate permission checks, utils, and init

This commit is contained in:
Zeva Rose 2022-02-01 17:54:13 -07:00
parent 9055897965
commit 9fa3e2c26b
9 changed files with 1565 additions and 98 deletions

View file

@ -1,15 +1,8 @@
"""Main J.A.R.V.I.S. package.""" """Main J.A.R.V.I.S. package."""
import asyncio
import logging import logging
from pathlib import Path
from typing import Optional
from discord import Intents from dis_snek import Intents, Snake
from discord.ext import commands
from discord.utils import find
from discord_slash import SlashCommand
from mongoengine import connect from mongoengine import connect
from psutil import Process
from jarvis import logo # noqa: F401 from jarvis import logo # noqa: F401
from jarvis import tasks, utils from jarvis import tasks, utils
@ -24,53 +17,26 @@ file_handler = logging.FileHandler(filename="jarvis.log", encoding="UTF-8", mode
file_handler.setFormatter(logging.Formatter("[%(asctime)s][%(levelname)s][%(name)s] %(message)s")) file_handler.setFormatter(logging.Formatter("[%(asctime)s][%(levelname)s][%(name)s] %(message)s"))
logger.addHandler(file_handler) logger.addHandler(file_handler)
if asyncio.get_event_loop().is_closed():
asyncio.set_event_loop(asyncio.new_event_loop())
intents = Intents.default() intents = Intents.default()
intents.members = True intents.members = True
restart_ctx = None restart_ctx = None
jarvis = commands.Bot( jarvis = Snake(intents=intents, default_prefix=utils.get_prefix, sync_interactions=jconfig.sync)
command_prefix=utils.get_prefix,
intents=intents,
help_command=None,
max_messages=jconfig.max_messages,
)
slash = SlashCommand(jarvis, sync_commands=False, sync_on_cog_reload=True) __version__ = "2.0.0a0"
jarvis_self = Process()
__version__ = "1.11.2"
@jarvis.event @jarvis.add_listener
async def on_ready() -> None: async def on_ready() -> None:
"""d.py on_ready override.""" """Lepton on_ready override."""
global restart_ctx global restart_ctx
print(" Logged in as {0.user}".format(jarvis)) print(" Logged in as {0.user}".format(jarvis))
print(" Connected to {} guild(s)".format(len(jarvis.guilds))) print(" Connected to {} guild(s)".format(len(jarvis.guilds)))
with jarvis_self.oneshot():
print(f" Current PID: {jarvis_self.pid}")
Path(f"jarvis.{jarvis_self.pid}.pid").touch()
if restart_ctx:
channel = None
if "guild" in restart_ctx:
guild = find(lambda x: x.id == restart_ctx["guild"], jarvis.guilds)
if guild:
channel = find(lambda x: x.id == restart_ctx["channel"], guild.channels)
elif "user" in restart_ctx:
channel = jarvis.get_user(restart_ctx["user"])
if channel:
await channel.send("Core systems restarted and back online.")
restart_ctx = None
def run(ctx: dict = None) -> Optional[dict]: def run() -> None:
"""Run J.A.R.V.I.S.""" """Run J.A.R.V.I.S."""
global restart_ctx
if ctx:
restart_ctx = ctx
connect( connect(
db="ctc2", db="ctc2",
alias="ctc2", alias="ctc2",
@ -84,8 +50,10 @@ def run(ctx: dict = None) -> Optional[dict]:
**jconfig.mongo["connect"], **jconfig.mongo["connect"],
) )
jconfig.get_db_config() jconfig.get_db_config()
for extension in utils.get_extensions(): for extension in utils.get_extensions():
jarvis.load_extension(extension) jarvis.load_extension(extension)
print( print(
" https://discord.com/api/oauth2/authorize?client_id=" " https://discord.com/api/oauth2/authorize?client_id="
+ "{}&permissions=8&scope=bot%20applications.commands".format(jconfig.client_id) # noqa: W503 + "{}&permissions=8&scope=bot%20applications.commands".format(jconfig.client_id) # noqa: W503

View file

@ -1,10 +1,10 @@
"""J.A.R.V.I.S. Admin Cogs.""" """J.A.R.V.I.S. Admin Cogs."""
from discord.ext.commands import Bot from dis_snek import Snake
from jarvis.cogs.admin import ban, kick, lock, lockdown, mute, purge, roleping, warning from jarvis.cogs.admin import ban, kick, lock, lockdown, mute, purge, roleping, warning
def setup(bot: Bot) -> None: def setup(bot: Snake) -> None:
"""Add admin cogs to J.A.R.V.I.S.""" """Add admin cogs to J.A.R.V.I.S."""
bot.add_cog(ban.BanCog(bot)) bot.add_cog(ban.BanCog(bot))
bot.add_cog(kick.KickCog(bot)) bot.add_cog(kick.KickCog(bot))

View file

@ -1,4 +1,6 @@
"""Load the config for J.A.R.V.I.S.""" """Load the config for J.A.R.V.I.S."""
import os
from pymongo import MongoClient from pymongo import MongoClient
from yaml import load from yaml import load
@ -27,6 +29,7 @@ class Config(object):
logo: str, logo: str,
mongo: dict, mongo: dict,
urls: dict, urls: dict,
sync: bool,
log_level: str = "WARNING", log_level: str = "WARNING",
cogs: list = None, cogs: list = None,
events: bool = True, events: bool = True,
@ -46,6 +49,7 @@ class Config(object):
self.max_messages = max_messages self.max_messages = max_messages
self.gitlab_token = gitlab_token self.gitlab_token = gitlab_token
self.twitter = twitter self.twitter = twitter
self.sync = sync or os.environ("SYNC_COMMANDS", False)
self.__db_loaded = False self.__db_loaded = False
self.__mongo = MongoClient(**self.mongo["connect"]) self.__mongo = MongoClient(**self.mongo["connect"])

View file

@ -1,9 +1,11 @@
"""J.A.R.V.I.S. Utility Functions.""" """J.A.R.V.I.S. Utility Functions."""
from datetime import datetime from datetime import datetime
from pkgutil import iter_modules from pkgutil import iter_modules
from typing import Any, Callable, Iterable, Optional, TypeVar
import git import git
from discord import Color, Embed, Message from dis_snek.models.discord.embed import Color, Embed
from dis_snek.models.discord.message import Message
from discord.ext import commands from discord.ext import commands
import jarvis.cogs import jarvis.cogs
@ -12,6 +14,30 @@ from jarvis.config import get_config
__all__ = ["field", "db", "cachecog", "permissions"] __all__ = ["field", "db", "cachecog", "permissions"]
T = TypeVar("T")
def build_embed(
title: str,
description: str,
fields: list,
color: str = "#FF0000",
timestamp: datetime = None,
**kwargs: dict,
) -> Embed:
"""Embed builder utility function."""
if not timestamp:
timestamp = datetime.utcnow()
embed = Embed(
title=title,
description=description,
color=parse_color_hex(color),
timestamp=timestamp,
fields=fields,
**kwargs,
)
return embed
def convert_bytesize(b: int) -> str: def convert_bytesize(b: int) -> str:
"""Convert bytes amount to human readable.""" """Convert bytes amount to human readable."""
@ -57,29 +83,6 @@ def parse_color_hex(hex: str) -> Color:
return Color.from_rgb(*rgb) return Color.from_rgb(*rgb)
def build_embed(
title: str,
description: str,
fields: list,
color: str = "#FF0000",
timestamp: datetime = None,
**kwargs: dict,
) -> Embed:
"""Embed builder utility function."""
if not timestamp:
timestamp = datetime.utcnow()
embed = Embed(
title=title,
description=description,
color=parse_color_hex(color),
timestamp=timestamp,
**kwargs,
)
for field in fields:
embed.add_field(**field.to_dict())
return embed
def update() -> int: def update() -> int:
"""J.A.R.V.I.S. update utility.""" """J.A.R.V.I.S. update utility."""
repo = git.Repo(".") repo = git.Repo(".")
@ -99,3 +102,10 @@ def get_repo_hash() -> str:
"""J.A.R.V.I.S. current branch hash.""" """J.A.R.V.I.S. current branch hash."""
repo = git.Repo(".") repo = git.Repo(".")
return repo.head.object.hexsha return repo.head.object.hexsha
def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]:
for element in seq:
if predicate(element):
return element
return None

View file

@ -1,21 +1,22 @@
"""Cog wrapper for command caching.""" """Cog wrapper for command caching."""
from datetime import datetime, timedelta from datetime import datetime, timedelta
from discord.ext import commands from dis_snek import InteractionContext, Scale, Snek
from discord.ext.tasks import loop from dis_snek.ext.tasks.task import Task
from discord.utils import find from dis_snek.ext.tasks.triggers import IntervalTrigger
from discord_slash import SlashContext
from jarvis.utils import find
class CacheCog(commands.Cog): class CacheCog(Scale):
"""Cog wrapper for command caching.""" """Cog wrapper for command caching."""
def __init__(self, bot: commands.Bot): def __init__(self, bot: Snek):
self.bot = bot self.bot = bot
self.cache = {} self.cache = {}
self._expire_interaction.start() self._expire_interaction.start()
def check_cache(self, ctx: SlashContext, **kwargs: dict) -> dict: def check_cache(self, ctx: InteractionContext, **kwargs: dict) -> dict:
"""Check the cache.""" """Check the cache."""
if not kwargs: if not kwargs:
kwargs = {} kwargs = {}
@ -27,7 +28,7 @@ class CacheCog(commands.Cog):
self.cache.values(), self.cache.values(),
) )
@loop(minutes=1) @Task.create(IntervalTrigger(minutes=1))
async def _expire_interaction(self) -> None: async def _expire_interaction(self) -> None:
keys = list(self.cache.keys()) keys = list(self.cache.keys())
for key in keys: for key in keys:

View file

@ -1,16 +0,0 @@
"""Embed field helper."""
from dataclasses import dataclass
from typing import Any
@dataclass
class Field:
"""Embed Field."""
name: Any
value: Any
inline: bool = True
def to_dict(self) -> dict:
"""Convert Field to d.py field dict."""
return {"name": self.name, "value": self.value, "inline": self.inline}

View file

@ -1,5 +1,5 @@
"""Permissions wrappers.""" """Permissions wrappers."""
from discord.ext import commands from dis_snek import Context, Permissions
from jarvis.config import get_config from jarvis.config import get_config
@ -7,22 +7,21 @@ from jarvis.config import get_config
def user_is_bot_admin() -> bool: def user_is_bot_admin() -> bool:
"""Check if a user is a J.A.R.V.I.S. admin.""" """Check if a user is a J.A.R.V.I.S. admin."""
def predicate(ctx: commands.Context) -> bool: def predicate(ctx: Context) -> bool:
"""Command check predicate.""" """Command check predicate."""
if getattr(get_config(), "admins", None): if getattr(get_config(), "admins", None):
return ctx.author.id in get_config().admins return ctx.author.id in get_config().admins
else: else:
return False return False
return commands.check(predicate) return predicate
def admin_or_permissions(**perms: dict) -> bool: def admin_or_permissions(*perms: list) -> bool:
"""Check if a user is an admin or has other perms.""" """Check if a user is an admin or has other perms."""
original = commands.has_permissions(**perms).predicate
async def extended_check(ctx: commands.Context) -> bool: async def predicate(ctx: Context) -> bool:
"""Extended check predicate.""" # noqa: D401 """Extended check predicate.""" # noqa: D401
return await commands.has_permissions(administrator=True).predicate(ctx) or await original(ctx) return ctx.author.has_permission(Permissions.ADMINISTRATOR) or ctx.author.has_permission(*perms)
return commands.check(extended_check) return predicate

1476
poetry.lock generated Normal file

File diff suppressed because it is too large Load diff

25
pyproject.toml Normal file
View file

@ -0,0 +1,25 @@
[tool.poetry]
name = "jarvis"
version = "2.0.0a0"
description = "J.A.R.V.I.S. admin bot"
authors = ["Zevaryx <zevaryx@gmail.com>"]
[tool.poetry.dependencies]
python = "^3.10"
PyYAML = "^6.0"
dis-snek = "^5.0.0"
GitPython = "^3.1.26"
mongoengine = "^0.23.1"
opencv-python = "^4.5.5"
Pillow = "^9.0.0"
psutil = "^5.9.0"
python-gitlab = "^3.1.1"
ulid-py = "^1.1.0"
[tool.poetry.dev-dependencies]
python-lsp-server = {extras = ["all"], version = "^1.3.3"}
black = "^22.1.0"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"