Add update command, get_all_commands helper

This commit is contained in:
Zeva Rose 2022-04-30 23:07:02 -06:00
parent 73cb150339
commit dc47a44650
2 changed files with 63 additions and 1 deletions

View file

@ -1,16 +1,20 @@
"""JARVIS bot utility commands."""
import asyncio
import logging
import platform
from io import BytesIO
from typing import get_type_hints
import git
import psutil
from aiofile import AIOFile, LineReader
from dis_snek import MessageContext, Scale, Snake
from dis_snek.client.utils.misc_utils import find
from dis_snek.models.discord.embed import EmbedField
from dis_snek.models.discord.file import File
from molter import msg_command
from jarvis.utils import build_embed
from jarvis.utils import build_embed, get_all_commands
class BotutilCog(Scale):
@ -72,6 +76,38 @@ class BotutilCog(Scale):
embed.set_image(url=self.bot.user.avatar.url)
await ctx.send(embed=embed)
@msg_command(name="update")
async def _update(self, ctx: MessageContext) -> None:
repo = git.Repo(".")
current_hash = repo.head.object.hexsha
origin = repo.remotes.origin
if current_hash != origin.refs[repo.active_branch.name].object.hexsha:
current_commands = get_all_commands()
_ = origin.pull()
await asyncio.sleep(3)
new_commands = get_all_commands()
for module, commands in new_commands:
if module not in current_commands:
self.bot.load_extension(module)
elif len(current_commands[module]) != len(commands):
self.bot.reload_extension(module)
else:
for command in commands:
old_command = find(
lambda x: x.__name__ == command.__name__, current_commands
)
old_args = get_type_hints(old_command)
new_args = get_type_hints(command)
if len(old_args) != len(new_args):
self.bot.reload_extension(module)
elif any(x not in old_args for x in new_args) or any(
x not in new_args for x in old_args
):
self.bot.reload_extension(module)
elif any(new_args[x] != y for x, y in old_args):
self.bot.reload_extension(module)
def setup(bot: Snake) -> None:
"""Add BotutilCog to JARVIS"""

View file

@ -1,11 +1,18 @@
"""JARVIS Utility Functions."""
import importlib
import inspect
from datetime import datetime, timezone
from pkgutil import iter_modules
from types import ModuleType
from typing import Callable, Dict
import git
from dis_snek.client.utils.misc_utils import find_all
from dis_snek.models.discord.embed import Embed, EmbedField
from dis_snek.models.discord.guild import AuditLogEntry
from dis_snek.models.discord.user import Member
from dis_snek.models.snek import Scale
from dis_snek.models.snek.application_commands import SlashCommand
import jarvis.cogs
from jarvis.config import get_config
@ -71,6 +78,25 @@ def get_extensions(path: str = jarvis.cogs.__path__) -> list:
return ["jarvis.cogs.{}".format(x) for x in vals]
def get_all_commands(module: ModuleType = jarvis.cogs) -> Dict[str, Callable]:
commands = {}
for item in iter_modules(module.__path__):
new_module = importlib.import_module(f"{module.__name__}.{item.name}")
if item.ispkg:
if cmds := get_all_commands(new_module):
commands.update(cmds)
else:
inspect_result = inspect.getmembers(new_module)
cogs = []
for _, val in inspect_result:
if inspect.isclass(val) and issubclass(val, Scale) and val is not Scale:
cogs.append(val)
for cog in cogs:
values = cog.__dict__.values()
commands[cog.__module__] = find_all(lambda x: isinstance(x, SlashCommand), values)
return {k: v for k, v in commands.items() if v}
def update() -> int:
"""JARVIS update utility."""
repo = git.Repo(".")