Add update command, get_all_commands helper
This commit is contained in:
parent
73cb150339
commit
dc47a44650
2 changed files with 63 additions and 1 deletions
|
@ -1,16 +1,20 @@
|
|||
"""JARVIS bot utility commands."""
|
||||
import asyncio
|
||||
import logging
|
||||
import platform
|
||||
from io import BytesIO
|
||||
from typing import get_type_hints
|
||||
|
||||
import git
|
||||
import psutil
|
||||
from aiofile import AIOFile, LineReader
|
||||
from dis_snek import MessageContext, Scale, Snake
|
||||
from dis_snek.client.utils.misc_utils import find
|
||||
from dis_snek.models.discord.embed import EmbedField
|
||||
from dis_snek.models.discord.file import File
|
||||
from molter import msg_command
|
||||
|
||||
from jarvis.utils import build_embed
|
||||
from jarvis.utils import build_embed, get_all_commands
|
||||
|
||||
|
||||
class BotutilCog(Scale):
|
||||
|
@ -72,6 +76,38 @@ class BotutilCog(Scale):
|
|||
embed.set_image(url=self.bot.user.avatar.url)
|
||||
await ctx.send(embed=embed)
|
||||
|
||||
@msg_command(name="update")
|
||||
async def _update(self, ctx: MessageContext) -> None:
|
||||
repo = git.Repo(".")
|
||||
current_hash = repo.head.object.hexsha
|
||||
origin = repo.remotes.origin
|
||||
|
||||
if current_hash != origin.refs[repo.active_branch.name].object.hexsha:
|
||||
current_commands = get_all_commands()
|
||||
_ = origin.pull()
|
||||
await asyncio.sleep(3)
|
||||
new_commands = get_all_commands()
|
||||
for module, commands in new_commands:
|
||||
if module not in current_commands:
|
||||
self.bot.load_extension(module)
|
||||
elif len(current_commands[module]) != len(commands):
|
||||
self.bot.reload_extension(module)
|
||||
else:
|
||||
for command in commands:
|
||||
old_command = find(
|
||||
lambda x: x.__name__ == command.__name__, current_commands
|
||||
)
|
||||
old_args = get_type_hints(old_command)
|
||||
new_args = get_type_hints(command)
|
||||
if len(old_args) != len(new_args):
|
||||
self.bot.reload_extension(module)
|
||||
elif any(x not in old_args for x in new_args) or any(
|
||||
x not in new_args for x in old_args
|
||||
):
|
||||
self.bot.reload_extension(module)
|
||||
elif any(new_args[x] != y for x, y in old_args):
|
||||
self.bot.reload_extension(module)
|
||||
|
||||
|
||||
def setup(bot: Snake) -> None:
|
||||
"""Add BotutilCog to JARVIS"""
|
||||
|
|
|
@ -1,11 +1,18 @@
|
|||
"""JARVIS Utility Functions."""
|
||||
import importlib
|
||||
import inspect
|
||||
from datetime import datetime, timezone
|
||||
from pkgutil import iter_modules
|
||||
from types import ModuleType
|
||||
from typing import Callable, Dict
|
||||
|
||||
import git
|
||||
from dis_snek.client.utils.misc_utils import find_all
|
||||
from dis_snek.models.discord.embed import Embed, EmbedField
|
||||
from dis_snek.models.discord.guild import AuditLogEntry
|
||||
from dis_snek.models.discord.user import Member
|
||||
from dis_snek.models.snek import Scale
|
||||
from dis_snek.models.snek.application_commands import SlashCommand
|
||||
|
||||
import jarvis.cogs
|
||||
from jarvis.config import get_config
|
||||
|
@ -71,6 +78,25 @@ def get_extensions(path: str = jarvis.cogs.__path__) -> list:
|
|||
return ["jarvis.cogs.{}".format(x) for x in vals]
|
||||
|
||||
|
||||
def get_all_commands(module: ModuleType = jarvis.cogs) -> Dict[str, Callable]:
|
||||
commands = {}
|
||||
for item in iter_modules(module.__path__):
|
||||
new_module = importlib.import_module(f"{module.__name__}.{item.name}")
|
||||
if item.ispkg:
|
||||
if cmds := get_all_commands(new_module):
|
||||
commands.update(cmds)
|
||||
else:
|
||||
inspect_result = inspect.getmembers(new_module)
|
||||
cogs = []
|
||||
for _, val in inspect_result:
|
||||
if inspect.isclass(val) and issubclass(val, Scale) and val is not Scale:
|
||||
cogs.append(val)
|
||||
for cog in cogs:
|
||||
values = cog.__dict__.values()
|
||||
commands[cog.__module__] = find_all(lambda x: isinstance(x, SlashCommand), values)
|
||||
return {k: v for k, v in commands.items() if v}
|
||||
|
||||
|
||||
def update() -> int:
|
||||
"""JARVIS update utility."""
|
||||
repo = git.Repo(".")
|
||||
|
|
Loading…
Add table
Reference in a new issue