Migrate permission checks, utils, and init
This commit is contained in:
parent
9055897965
commit
9fa3e2c26b
9 changed files with 1565 additions and 98 deletions
|
@ -1,15 +1,8 @@
|
||||||
"""Main J.A.R.V.I.S. package."""
|
"""Main J.A.R.V.I.S. package."""
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from discord import Intents
|
from dis_snek import Intents, Snake
|
||||||
from discord.ext import commands
|
|
||||||
from discord.utils import find
|
|
||||||
from discord_slash import SlashCommand
|
|
||||||
from mongoengine import connect
|
from mongoengine import connect
|
||||||
from psutil import Process
|
|
||||||
|
|
||||||
from jarvis import logo # noqa: F401
|
from jarvis import logo # noqa: F401
|
||||||
from jarvis import tasks, utils
|
from jarvis import tasks, utils
|
||||||
|
@ -24,53 +17,26 @@ file_handler = logging.FileHandler(filename="jarvis.log", encoding="UTF-8", mode
|
||||||
file_handler.setFormatter(logging.Formatter("[%(asctime)s][%(levelname)s][%(name)s] %(message)s"))
|
file_handler.setFormatter(logging.Formatter("[%(asctime)s][%(levelname)s][%(name)s] %(message)s"))
|
||||||
logger.addHandler(file_handler)
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
if asyncio.get_event_loop().is_closed():
|
|
||||||
asyncio.set_event_loop(asyncio.new_event_loop())
|
|
||||||
|
|
||||||
intents = Intents.default()
|
intents = Intents.default()
|
||||||
intents.members = True
|
intents.members = True
|
||||||
restart_ctx = None
|
restart_ctx = None
|
||||||
|
|
||||||
|
|
||||||
jarvis = commands.Bot(
|
jarvis = Snake(intents=intents, default_prefix=utils.get_prefix, sync_interactions=jconfig.sync)
|
||||||
command_prefix=utils.get_prefix,
|
|
||||||
intents=intents,
|
|
||||||
help_command=None,
|
|
||||||
max_messages=jconfig.max_messages,
|
|
||||||
)
|
|
||||||
|
|
||||||
slash = SlashCommand(jarvis, sync_commands=False, sync_on_cog_reload=True)
|
__version__ = "2.0.0a0"
|
||||||
jarvis_self = Process()
|
|
||||||
__version__ = "1.11.2"
|
|
||||||
|
|
||||||
|
|
||||||
@jarvis.event
|
@jarvis.add_listener
|
||||||
async def on_ready() -> None:
|
async def on_ready() -> None:
|
||||||
"""d.py on_ready override."""
|
"""Lepton on_ready override."""
|
||||||
global restart_ctx
|
global restart_ctx
|
||||||
print(" Logged in as {0.user}".format(jarvis))
|
print(" Logged in as {0.user}".format(jarvis))
|
||||||
print(" Connected to {} guild(s)".format(len(jarvis.guilds)))
|
print(" Connected to {} guild(s)".format(len(jarvis.guilds)))
|
||||||
with jarvis_self.oneshot():
|
|
||||||
print(f" Current PID: {jarvis_self.pid}")
|
|
||||||
Path(f"jarvis.{jarvis_self.pid}.pid").touch()
|
|
||||||
if restart_ctx:
|
|
||||||
channel = None
|
|
||||||
if "guild" in restart_ctx:
|
|
||||||
guild = find(lambda x: x.id == restart_ctx["guild"], jarvis.guilds)
|
|
||||||
if guild:
|
|
||||||
channel = find(lambda x: x.id == restart_ctx["channel"], guild.channels)
|
|
||||||
elif "user" in restart_ctx:
|
|
||||||
channel = jarvis.get_user(restart_ctx["user"])
|
|
||||||
if channel:
|
|
||||||
await channel.send("Core systems restarted and back online.")
|
|
||||||
restart_ctx = None
|
|
||||||
|
|
||||||
|
|
||||||
def run(ctx: dict = None) -> Optional[dict]:
|
def run() -> None:
|
||||||
"""Run J.A.R.V.I.S."""
|
"""Run J.A.R.V.I.S."""
|
||||||
global restart_ctx
|
|
||||||
if ctx:
|
|
||||||
restart_ctx = ctx
|
|
||||||
connect(
|
connect(
|
||||||
db="ctc2",
|
db="ctc2",
|
||||||
alias="ctc2",
|
alias="ctc2",
|
||||||
|
@ -84,8 +50,10 @@ def run(ctx: dict = None) -> Optional[dict]:
|
||||||
**jconfig.mongo["connect"],
|
**jconfig.mongo["connect"],
|
||||||
)
|
)
|
||||||
jconfig.get_db_config()
|
jconfig.get_db_config()
|
||||||
|
|
||||||
for extension in utils.get_extensions():
|
for extension in utils.get_extensions():
|
||||||
jarvis.load_extension(extension)
|
jarvis.load_extension(extension)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
" https://discord.com/api/oauth2/authorize?client_id="
|
" https://discord.com/api/oauth2/authorize?client_id="
|
||||||
+ "{}&permissions=8&scope=bot%20applications.commands".format(jconfig.client_id) # noqa: W503
|
+ "{}&permissions=8&scope=bot%20applications.commands".format(jconfig.client_id) # noqa: W503
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
"""J.A.R.V.I.S. Admin Cogs."""
|
"""J.A.R.V.I.S. Admin Cogs."""
|
||||||
from discord.ext.commands import Bot
|
from dis_snek import Snake
|
||||||
|
|
||||||
from jarvis.cogs.admin import ban, kick, lock, lockdown, mute, purge, roleping, warning
|
from jarvis.cogs.admin import ban, kick, lock, lockdown, mute, purge, roleping, warning
|
||||||
|
|
||||||
|
|
||||||
def setup(bot: Bot) -> None:
|
def setup(bot: Snake) -> None:
|
||||||
"""Add admin cogs to J.A.R.V.I.S."""
|
"""Add admin cogs to J.A.R.V.I.S."""
|
||||||
bot.add_cog(ban.BanCog(bot))
|
bot.add_cog(ban.BanCog(bot))
|
||||||
bot.add_cog(kick.KickCog(bot))
|
bot.add_cog(kick.KickCog(bot))
|
||||||
|
|
|
@ -1,4 +1,6 @@
|
||||||
"""Load the config for J.A.R.V.I.S."""
|
"""Load the config for J.A.R.V.I.S."""
|
||||||
|
import os
|
||||||
|
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from yaml import load
|
from yaml import load
|
||||||
|
|
||||||
|
@ -27,6 +29,7 @@ class Config(object):
|
||||||
logo: str,
|
logo: str,
|
||||||
mongo: dict,
|
mongo: dict,
|
||||||
urls: dict,
|
urls: dict,
|
||||||
|
sync: bool,
|
||||||
log_level: str = "WARNING",
|
log_level: str = "WARNING",
|
||||||
cogs: list = None,
|
cogs: list = None,
|
||||||
events: bool = True,
|
events: bool = True,
|
||||||
|
@ -46,6 +49,7 @@ class Config(object):
|
||||||
self.max_messages = max_messages
|
self.max_messages = max_messages
|
||||||
self.gitlab_token = gitlab_token
|
self.gitlab_token = gitlab_token
|
||||||
self.twitter = twitter
|
self.twitter = twitter
|
||||||
|
self.sync = sync or os.environ("SYNC_COMMANDS", False)
|
||||||
self.__db_loaded = False
|
self.__db_loaded = False
|
||||||
self.__mongo = MongoClient(**self.mongo["connect"])
|
self.__mongo = MongoClient(**self.mongo["connect"])
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
"""J.A.R.V.I.S. Utility Functions."""
|
"""J.A.R.V.I.S. Utility Functions."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pkgutil import iter_modules
|
from pkgutil import iter_modules
|
||||||
|
from typing import Any, Callable, Iterable, Optional, TypeVar
|
||||||
|
|
||||||
import git
|
import git
|
||||||
from discord import Color, Embed, Message
|
from dis_snek.models.discord.embed import Color, Embed
|
||||||
|
from dis_snek.models.discord.message import Message
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
import jarvis.cogs
|
import jarvis.cogs
|
||||||
|
@ -12,6 +14,30 @@ from jarvis.config import get_config
|
||||||
|
|
||||||
__all__ = ["field", "db", "cachecog", "permissions"]
|
__all__ = ["field", "db", "cachecog", "permissions"]
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def build_embed(
|
||||||
|
title: str,
|
||||||
|
description: str,
|
||||||
|
fields: list,
|
||||||
|
color: str = "#FF0000",
|
||||||
|
timestamp: datetime = None,
|
||||||
|
**kwargs: dict,
|
||||||
|
) -> Embed:
|
||||||
|
"""Embed builder utility function."""
|
||||||
|
if not timestamp:
|
||||||
|
timestamp = datetime.utcnow()
|
||||||
|
embed = Embed(
|
||||||
|
title=title,
|
||||||
|
description=description,
|
||||||
|
color=parse_color_hex(color),
|
||||||
|
timestamp=timestamp,
|
||||||
|
fields=fields,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return embed
|
||||||
|
|
||||||
|
|
||||||
def convert_bytesize(b: int) -> str:
|
def convert_bytesize(b: int) -> str:
|
||||||
"""Convert bytes amount to human readable."""
|
"""Convert bytes amount to human readable."""
|
||||||
|
@ -57,29 +83,6 @@ def parse_color_hex(hex: str) -> Color:
|
||||||
return Color.from_rgb(*rgb)
|
return Color.from_rgb(*rgb)
|
||||||
|
|
||||||
|
|
||||||
def build_embed(
|
|
||||||
title: str,
|
|
||||||
description: str,
|
|
||||||
fields: list,
|
|
||||||
color: str = "#FF0000",
|
|
||||||
timestamp: datetime = None,
|
|
||||||
**kwargs: dict,
|
|
||||||
) -> Embed:
|
|
||||||
"""Embed builder utility function."""
|
|
||||||
if not timestamp:
|
|
||||||
timestamp = datetime.utcnow()
|
|
||||||
embed = Embed(
|
|
||||||
title=title,
|
|
||||||
description=description,
|
|
||||||
color=parse_color_hex(color),
|
|
||||||
timestamp=timestamp,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
for field in fields:
|
|
||||||
embed.add_field(**field.to_dict())
|
|
||||||
return embed
|
|
||||||
|
|
||||||
|
|
||||||
def update() -> int:
|
def update() -> int:
|
||||||
"""J.A.R.V.I.S. update utility."""
|
"""J.A.R.V.I.S. update utility."""
|
||||||
repo = git.Repo(".")
|
repo = git.Repo(".")
|
||||||
|
@ -99,3 +102,10 @@ def get_repo_hash() -> str:
|
||||||
"""J.A.R.V.I.S. current branch hash."""
|
"""J.A.R.V.I.S. current branch hash."""
|
||||||
repo = git.Repo(".")
|
repo = git.Repo(".")
|
||||||
return repo.head.object.hexsha
|
return repo.head.object.hexsha
|
||||||
|
|
||||||
|
|
||||||
|
def find(predicate: Callable[[T], Any], seq: Iterable[T]) -> Optional[T]:
|
||||||
|
for element in seq:
|
||||||
|
if predicate(element):
|
||||||
|
return element
|
||||||
|
return None
|
||||||
|
|
|
@ -1,21 +1,22 @@
|
||||||
"""Cog wrapper for command caching."""
|
"""Cog wrapper for command caching."""
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from discord.ext import commands
|
from dis_snek import InteractionContext, Scale, Snek
|
||||||
from discord.ext.tasks import loop
|
from dis_snek.ext.tasks.task import Task
|
||||||
from discord.utils import find
|
from dis_snek.ext.tasks.triggers import IntervalTrigger
|
||||||
from discord_slash import SlashContext
|
|
||||||
|
from jarvis.utils import find
|
||||||
|
|
||||||
|
|
||||||
class CacheCog(commands.Cog):
|
class CacheCog(Scale):
|
||||||
"""Cog wrapper for command caching."""
|
"""Cog wrapper for command caching."""
|
||||||
|
|
||||||
def __init__(self, bot: commands.Bot):
|
def __init__(self, bot: Snek):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
self._expire_interaction.start()
|
self._expire_interaction.start()
|
||||||
|
|
||||||
def check_cache(self, ctx: SlashContext, **kwargs: dict) -> dict:
|
def check_cache(self, ctx: InteractionContext, **kwargs: dict) -> dict:
|
||||||
"""Check the cache."""
|
"""Check the cache."""
|
||||||
if not kwargs:
|
if not kwargs:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
@ -27,7 +28,7 @@ class CacheCog(commands.Cog):
|
||||||
self.cache.values(),
|
self.cache.values(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@loop(minutes=1)
|
@Task.create(IntervalTrigger(minutes=1))
|
||||||
async def _expire_interaction(self) -> None:
|
async def _expire_interaction(self) -> None:
|
||||||
keys = list(self.cache.keys())
|
keys = list(self.cache.keys())
|
||||||
for key in keys:
|
for key in keys:
|
||||||
|
|
|
@ -1,16 +0,0 @@
|
||||||
"""Embed field helper."""
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Field:
|
|
||||||
"""Embed Field."""
|
|
||||||
|
|
||||||
name: Any
|
|
||||||
value: Any
|
|
||||||
inline: bool = True
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
"""Convert Field to d.py field dict."""
|
|
||||||
return {"name": self.name, "value": self.value, "inline": self.inline}
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Permissions wrappers."""
|
"""Permissions wrappers."""
|
||||||
from discord.ext import commands
|
from dis_snek import Context, Permissions
|
||||||
|
|
||||||
from jarvis.config import get_config
|
from jarvis.config import get_config
|
||||||
|
|
||||||
|
@ -7,22 +7,21 @@ from jarvis.config import get_config
|
||||||
def user_is_bot_admin() -> bool:
|
def user_is_bot_admin() -> bool:
|
||||||
"""Check if a user is a J.A.R.V.I.S. admin."""
|
"""Check if a user is a J.A.R.V.I.S. admin."""
|
||||||
|
|
||||||
def predicate(ctx: commands.Context) -> bool:
|
def predicate(ctx: Context) -> bool:
|
||||||
"""Command check predicate."""
|
"""Command check predicate."""
|
||||||
if getattr(get_config(), "admins", None):
|
if getattr(get_config(), "admins", None):
|
||||||
return ctx.author.id in get_config().admins
|
return ctx.author.id in get_config().admins
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return commands.check(predicate)
|
return predicate
|
||||||
|
|
||||||
|
|
||||||
def admin_or_permissions(**perms: dict) -> bool:
|
def admin_or_permissions(*perms: list) -> bool:
|
||||||
"""Check if a user is an admin or has other perms."""
|
"""Check if a user is an admin or has other perms."""
|
||||||
original = commands.has_permissions(**perms).predicate
|
|
||||||
|
|
||||||
async def extended_check(ctx: commands.Context) -> bool:
|
async def predicate(ctx: Context) -> bool:
|
||||||
"""Extended check predicate.""" # noqa: D401
|
"""Extended check predicate.""" # noqa: D401
|
||||||
return await commands.has_permissions(administrator=True).predicate(ctx) or await original(ctx)
|
return ctx.author.has_permission(Permissions.ADMINISTRATOR) or ctx.author.has_permission(*perms)
|
||||||
|
|
||||||
return commands.check(extended_check)
|
return predicate
|
||||||
|
|
1476
poetry.lock
generated
Normal file
1476
poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
25
pyproject.toml
Normal file
25
pyproject.toml
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
[tool.poetry]
|
||||||
|
name = "jarvis"
|
||||||
|
version = "2.0.0a0"
|
||||||
|
description = "J.A.R.V.I.S. admin bot"
|
||||||
|
authors = ["Zevaryx <zevaryx@gmail.com>"]
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = "^3.10"
|
||||||
|
PyYAML = "^6.0"
|
||||||
|
dis-snek = "^5.0.0"
|
||||||
|
GitPython = "^3.1.26"
|
||||||
|
mongoengine = "^0.23.1"
|
||||||
|
opencv-python = "^4.5.5"
|
||||||
|
Pillow = "^9.0.0"
|
||||||
|
psutil = "^5.9.0"
|
||||||
|
python-gitlab = "^3.1.1"
|
||||||
|
ulid-py = "^1.1.0"
|
||||||
|
|
||||||
|
[tool.poetry.dev-dependencies]
|
||||||
|
python-lsp-server = {extras = ["all"], version = "^1.3.3"}
|
||||||
|
black = "^22.1.0"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core>=1.0.0"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
Loading…
Add table
Reference in a new issue