diff --git a/jarvis/cogs/botutil.py b/jarvis/cogs/botutil.py index 8e808a0..a02c1f8 100644 --- a/jarvis/cogs/botutil.py +++ b/jarvis/cogs/botutil.py @@ -1,16 +1,20 @@ """JARVIS bot utility commands.""" +import asyncio import logging import platform from io import BytesIO +from typing import get_type_hints +import git import psutil from aiofile import AIOFile, LineReader from dis_snek import MessageContext, Scale, Snake +from dis_snek.client.utils.misc_utils import find from dis_snek.models.discord.embed import EmbedField from dis_snek.models.discord.file import File from molter import msg_command -from jarvis.utils import build_embed +from jarvis.utils import build_embed, get_all_commands class BotutilCog(Scale): @@ -72,6 +76,38 @@ class BotutilCog(Scale): embed.set_image(url=self.bot.user.avatar.url) await ctx.send(embed=embed) + @msg_command(name="update") + async def _update(self, ctx: MessageContext) -> None: + repo = git.Repo(".") + current_hash = repo.head.object.hexsha + origin = repo.remotes.origin + + if current_hash != origin.refs[repo.active_branch.name].object.hexsha: + current_commands = get_all_commands() + _ = origin.pull() + await asyncio.sleep(3) + new_commands = get_all_commands() + for module, commands in new_commands: + if module not in current_commands: + self.bot.load_extension(module) + elif len(current_commands[module]) != len(commands): + self.bot.reload_extension(module) + else: + for command in commands: + old_command = find( + lambda x: x.__name__ == command.__name__, current_commands + ) + old_args = get_type_hints(old_command) + new_args = get_type_hints(command) + if len(old_args) != len(new_args): + self.bot.reload_extension(module) + elif any(x not in old_args for x in new_args) or any( + x not in new_args for x in old_args + ): + self.bot.reload_extension(module) + elif any(new_args[x] != y for x, y in old_args): + self.bot.reload_extension(module) + def setup(bot: Snake) -> None: """Add BotutilCog to JARVIS""" diff --git a/jarvis/utils/__init__.py b/jarvis/utils/__init__.py index 751bfe6..2937c8a 100644 --- a/jarvis/utils/__init__.py +++ b/jarvis/utils/__init__.py @@ -1,11 +1,18 @@ """JARVIS Utility Functions.""" +import importlib +import inspect from datetime import datetime, timezone from pkgutil import iter_modules +from types import ModuleType +from typing import Callable, Dict import git +from dis_snek.client.utils.misc_utils import find_all from dis_snek.models.discord.embed import Embed, EmbedField from dis_snek.models.discord.guild import AuditLogEntry from dis_snek.models.discord.user import Member +from dis_snek.models.snek import Scale +from dis_snek.models.snek.application_commands import SlashCommand import jarvis.cogs from jarvis.config import get_config @@ -71,6 +78,25 @@ def get_extensions(path: str = jarvis.cogs.__path__) -> list: return ["jarvis.cogs.{}".format(x) for x in vals] +def get_all_commands(module: ModuleType = jarvis.cogs) -> Dict[str, Callable]: + commands = {} + for item in iter_modules(module.__path__): + new_module = importlib.import_module(f"{module.__name__}.{item.name}") + if item.ispkg: + if cmds := get_all_commands(new_module): + commands.update(cmds) + else: + inspect_result = inspect.getmembers(new_module) + cogs = [] + for _, val in inspect_result: + if inspect.isclass(val) and issubclass(val, Scale) and val is not Scale: + cogs.append(val) + for cog in cogs: + values = cog.__dict__.values() + commands[cog.__module__] = find_all(lambda x: isinstance(x, SlashCommand), values) + return {k: v for k, v in commands.items() if v} + + def update() -> int: """JARVIS update utility.""" repo = git.Repo(".")