From b1d5415625119820c760eaa655910b42c3ce9ce8 Mon Sep 17 00:00:00 2001 From: Zevaryx Date: Sun, 1 May 2022 15:01:46 -0600 Subject: [PATCH] Create helper function for updates for easier debugging --- jarvis/cogs/botutil.py | 125 +++------------------------ jarvis/utils/updates.py | 187 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 200 insertions(+), 112 deletions(-) create mode 100644 jarvis/utils/updates.py diff --git a/jarvis/cogs/botutil.py b/jarvis/cogs/botutil.py index 3b901d4..5b60d2c 100644 --- a/jarvis/cogs/botutil.py +++ b/jarvis/cogs/botutil.py @@ -1,22 +1,19 @@ """JARVIS bot utility commands.""" -import asyncio import logging import platform from io import BytesIO -import git import psutil from aiofile import AIOFile, LineReader from dis_snek import MessageContext, Scale, Snake from dis_snek.client.errors import HTTPException -from dis_snek.client.utils.misc_utils import find from dis_snek.models.discord.embed import EmbedField from dis_snek.models.discord.file import File from molter import msg_command from rich.console import Console -from rich.table import Table -from jarvis.utils import build_embed, get_all_commands +from jarvis.utils import build_embed +from jarvis.utils.updates import update class BotutilCog(Scale): @@ -80,118 +77,23 @@ class BotutilCog(Scale): @msg_command(name="update") async def _update(self, ctx: MessageContext) -> None: - repo = git.Repo(".") - current_hash = repo.head.object.hexsha - origin = repo.remotes.origin - origin.fetch() - remote_hash = origin.refs[repo.active_branch.name].object.hexsha - - if current_hash != remote_hash: - self.logger.info("Updating...") - current_commands = get_all_commands() - changes = origin.pull() - - self.logger.info("Changes pulled...") - self.logger.debug("Sleeping for 3 seconds to allow changes") - await asyncio.sleep(3) - self.logger.debug("Finished sleeping, loading new commands") - - reloaded = [] - loaded = [] - unloaded = [] - new_commands = get_all_commands() - for module in current_commands: - if module not in new_commands: - self.bot.unload_extension(module) - unloaded.append(module) - for module, commands in new_commands.items(): - self.logger.debug(f"Processing {module}") - if module not in current_commands: - self.bot.load_extension(module) - loaded.append(module) - elif len(current_commands[module]) != len(commands): - self.bot.reload_extension(module) - reloaded.append(module) - else: - for command in commands: - old_command = find( - lambda x: x.resolved_name == command.resolved_name, - current_commands[module], - ) - - # Extract useful info - old_args = old_command.options - if old_args: - old_arg_names = [x.name for x in old_args] - new_args = command.options - if new_args: - new_arg_names = [x.name for x in new_args] - - # No changes - if not old_args and not new_args: - continue - - # Check if number arguments have changed - if len(old_args) != len(new_args): - self.bot.reload_extension(module) - reloaded.append(module) - elif any(x not in old_arg_names for x in new_arg_names) or any( - x not in new_arg_names for x in old_arg_names - ): - self.bot.reload_extension(module) - reloaded.append(module) - elif any(new_args[idx].type != x.type for idx, x in enumerate(old_args)): - self.bot.reload_extension(module) - reloaded.append(module) - - file_changes = {} - for change in sorted(changes, key=lambda x: x.commit.committed_datetime, reverse=True): - if change.commit.hexsha == current_hash: - break - files = change.commit.stats.files - for file, stats in files.items(): - if file not in file_changes: - file_changes[file] = {"insertions": 0, "deletions": 0, "lines": 0} - for k, v in stats.items(): - file_changes[file][k] += v - - table = Table(title="File Changes") - - table.add_column("File", justify="left", style="white", no_wrap=True) - table.add_column("Insertions", justify="center", style="green") - table.add_column("Deletions", justify="center", style="red") - table.add_column("Lines", justify="center", style="magenta") - - i_total = 0 - d_total = 0 - l_total = 0 - for file, stats in file_changes.items(): - i_total += stats["insertions"] - d_total += stats["deletions"] - l_total += stats["lines"] - table.add_row( - file, - str(stats["insertions"]), - str(stats["deletions"]), - str(stats["lines"]), - ) - - table.add_row("Total", str(i_total), str(d_total), str(l_total)) + status = update(self.bot, self.logger) + if status: console = Console() with console.capture() as capture: - console.print(table) + console.print(status.table) self.logger.debug(capture.get()) self.logger.debug(len(capture.get())) - new = "\n".join(loaded) - removed = "\n".join(unloaded) - changed = "\n".join(reloaded) + added = "\n".join(status.added) + removed = "\n".join(status.removed) + changed = "\n".join(status.changed) fields = [ - EmbedField(name="Old Commit", value=current_hash), - EmbedField(name="New Commit", value=remote_hash), + EmbedField(name="Old Commit", value=status.old_hash), + EmbedField(name="New Commit", value=status.new_hash), ] - if loaded: - fields.append(EmbedField(name="New Modules", value=f"```\n{new}\n```")) + if added: + fields.append(EmbedField(name="New Modules", value=f"```\n{added}\n```")) if removed: fields.append(EmbedField(name="Removed Modules", value=f"```\n{removed}\n```")) if changed: @@ -206,11 +108,10 @@ class BotutilCog(Scale): try: await ctx.reply(f"```ansi\n{capture.get()}\n```", embed=embed) except HTTPException: - await ctx.reply(f"Total Changes: {l_total}", embed=embed) + await ctx.reply(f"Total Changes: {status.total_lines}", embed=embed) else: embed = build_embed(title="Update Status", description="No changes applied", fields=[]) - embed.set_footer(text=current_hash) embed.set_thumbnail(url="https://dev.zevaryx.com/git.png") await ctx.reply(embed=embed) diff --git a/jarvis/utils/updates.py b/jarvis/utils/updates.py new file mode 100644 index 0000000..b1a1cad --- /dev/null +++ b/jarvis/utils/updates.py @@ -0,0 +1,187 @@ +"""JARVIS update handler.""" +import asyncio +import importlib +import inspect +import logging +from dataclasses import dataclass +from pkgutil import iter_modules +from types import ModuleType +from typing import TYPE_CHECKING, Callable, Dict, List, Optional + +import git +from dis_snek.client.utils.misc_utils import find, find_all +from dis_snek.models.snek.application_commands import SlashCommand +from dis_snek.models.snek.scale import Scale +from rich.table import Table + +import jarvis.cogs + +if TYPE_CHECKING: + from logging import Logger + + from dis_snek.client.client import Snake + + +@dataclass +class UpdateResult: + """JARVIS update result.""" + + old_hash: str + new_hash: str + table: Table + added: List[str] + removed: List[str] + changed: List[str] + inserted_lines: int + deleted_lines: int + total_lines: int + + +def get_all_commands(module: ModuleType = jarvis.cogs) -> Dict[str, Callable]: + """Get all SlashCommands from a specified module.""" + commands = {} + for item in iter_modules(module.__path__): + new_module = importlib.import_module(f"{module.__name__}.{item.name}") + if item.ispkg: + if cmds := get_all_commands(new_module): + commands.update(cmds) + else: + inspect_result = inspect.getmembers(new_module) + cogs = [] + for _, val in inspect_result: + if inspect.isclass(val) and issubclass(val, Scale) and val is not Scale: + cogs.append(val) + for cog in cogs: + values = cog.__dict__.values() + commands[cog.__module__] = find_all(lambda x: isinstance(x, SlashCommand), values) + return {k: v for k, v in commands.items() if v} + + +async def update(bot: "Snake", logger: "Logger" = None) -> Optional[UpdateResult]: + """ + Update JARVIS and return an UpdateResult. + + Args: + bot: Bot instance + logger: Logger instance + + Returns: + UpdateResult object + """ + if not logger: + logger = logging.getLogger(__name__) + repo = git.Repo(".") + current_hash = repo.head.object.hexsha + origin = repo.remotes.origin + origin.fetch() + remote_hash = origin.refs[repo.active_branch.name].object.hexsha + + if current_hash != remote_hash: + logger.info(f"Updating from {current_hash} to {remote_hash}") + current_commands = get_all_commands() + + changes = origin.pull() + logger.info(f"Pulled {len(changes)} changes") + await asyncio.sleep(3) + + new_commands = get_all_commands() + + logger.info("Checking if any modules need reloaded...") + + reloaded = [] + loaded = [] + unloaded = [] + + logger.debug("Checking for removed cogs") + for module in current_commands.keys(): + if module not in new_commands: + logger.debug(f"Module {module} removed after update") + bot.unload_extension(module) + unloaded.append(module) + + logger.debug("Checking for new/modified commands") + for module, commands in new_commands.items(): + logger.debug(f"Processing {module}") + if module not in current_commands: + bot.load_extension(module) + loaded.append(module) + elif len(current_commands[module]) != len(commands): + bot.reload_extension(module) + reloaded.append(module) + else: + for command in commands: + old_command = find( + lambda x: x.resolved_name == command.resolved_name, current_commands[module] + ) + + # Extract useful info + old_args = old_command.options + if old_args: + old_arg_names = [x.name for x in old_args] + new_args = command.options + if new_args: + new_arg_names = [x.name for x in new_args] + + # No changes + if not old_args and not new_args: + continue + + # Check if number arguments have changed + if len(old_args) != len(new_args): + bot.reload_extension(module) + reloaded.append(module) + elif any(x not in old_arg_names for x in new_arg_names) or any( + x not in new_arg_names for x in old_arg_names + ): + bot.reload_extension(module) + reloaded.append(module) + elif any(new_args[idx].type != x.type for idx, x in enumerate(old_args)): + bot.reload_extension(module) + reloaded.append(module) + + file_changes = {} + for change in sorted(changes, key=lambda x: x.commit.committed_datetime, reverse=True): + if change.commit.hexsha == current_hash: + break + files = change.commit.stats.files + for file, stats in files.items(): + if file not in file_changes: + file_changes[file] = {"insertions": 0, "deletions": 0, "lines": 0} + for k, v in stats.items(): + file_changes[file][k] += v + + table = Table(title="File Changes") + + table.add_column("File", justify="left", style="white", no_wrap=True) + table.add_column("Insertions", justify="center", style="green") + table.add_column("Deletions", justify="center", style="red") + table.add_column("Lines", justify="center", style="magenta") + + i_total = 0 + d_total = 0 + l_total = 0 + for file, stats in file_changes.items(): + i_total += stats["insertions"] + d_total += stats["deletions"] + l_total += stats["lines"] + table.add_row( + file, + str(stats["insertions"]), + str(stats["deletions"]), + str(stats["lines"]), + ) + + table.add_row("Total", str(i_total), str(d_total), str(l_total)) + + return UpdateResult( + table=table, + old_hash=current_hash, + new_hash=remote_hash, + added=loaded, + removed=unloaded, + changed=reloaded, + inserted_lines=i_total, + deleted_lines=d_total, + total_lines=l_total, + ) + return None