Create helper function for updates for easier debugging

This commit is contained in:
Zeva Rose 2022-05-01 15:01:46 -06:00
parent 92cd0b1eae
commit b1d5415625
2 changed files with 200 additions and 112 deletions

View file

@ -1,22 +1,19 @@
"""JARVIS bot utility commands.""" """JARVIS bot utility commands."""
import asyncio
import logging import logging
import platform import platform
from io import BytesIO from io import BytesIO
import git
import psutil import psutil
from aiofile import AIOFile, LineReader from aiofile import AIOFile, LineReader
from dis_snek import MessageContext, Scale, Snake from dis_snek import MessageContext, Scale, Snake
from dis_snek.client.errors import HTTPException from dis_snek.client.errors import HTTPException
from dis_snek.client.utils.misc_utils import find
from dis_snek.models.discord.embed import EmbedField from dis_snek.models.discord.embed import EmbedField
from dis_snek.models.discord.file import File from dis_snek.models.discord.file import File
from molter import msg_command from molter import msg_command
from rich.console import Console from rich.console import Console
from rich.table import Table
from jarvis.utils import build_embed, get_all_commands from jarvis.utils import build_embed
from jarvis.utils.updates import update
class BotutilCog(Scale): class BotutilCog(Scale):
@ -80,118 +77,23 @@ class BotutilCog(Scale):
@msg_command(name="update") @msg_command(name="update")
async def _update(self, ctx: MessageContext) -> None: async def _update(self, ctx: MessageContext) -> None:
repo = git.Repo(".") status = update(self.bot, self.logger)
current_hash = repo.head.object.hexsha if status:
origin = repo.remotes.origin
origin.fetch()
remote_hash = origin.refs[repo.active_branch.name].object.hexsha
if current_hash != remote_hash:
self.logger.info("Updating...")
current_commands = get_all_commands()
changes = origin.pull()
self.logger.info("Changes pulled...")
self.logger.debug("Sleeping for 3 seconds to allow changes")
await asyncio.sleep(3)
self.logger.debug("Finished sleeping, loading new commands")
reloaded = []
loaded = []
unloaded = []
new_commands = get_all_commands()
for module in current_commands:
if module not in new_commands:
self.bot.unload_extension(module)
unloaded.append(module)
for module, commands in new_commands.items():
self.logger.debug(f"Processing {module}")
if module not in current_commands:
self.bot.load_extension(module)
loaded.append(module)
elif len(current_commands[module]) != len(commands):
self.bot.reload_extension(module)
reloaded.append(module)
else:
for command in commands:
old_command = find(
lambda x: x.resolved_name == command.resolved_name,
current_commands[module],
)
# Extract useful info
old_args = old_command.options
if old_args:
old_arg_names = [x.name for x in old_args]
new_args = command.options
if new_args:
new_arg_names = [x.name for x in new_args]
# No changes
if not old_args and not new_args:
continue
# Check if number arguments have changed
if len(old_args) != len(new_args):
self.bot.reload_extension(module)
reloaded.append(module)
elif any(x not in old_arg_names for x in new_arg_names) or any(
x not in new_arg_names for x in old_arg_names
):
self.bot.reload_extension(module)
reloaded.append(module)
elif any(new_args[idx].type != x.type for idx, x in enumerate(old_args)):
self.bot.reload_extension(module)
reloaded.append(module)
file_changes = {}
for change in sorted(changes, key=lambda x: x.commit.committed_datetime, reverse=True):
if change.commit.hexsha == current_hash:
break
files = change.commit.stats.files
for file, stats in files.items():
if file not in file_changes:
file_changes[file] = {"insertions": 0, "deletions": 0, "lines": 0}
for k, v in stats.items():
file_changes[file][k] += v
table = Table(title="File Changes")
table.add_column("File", justify="left", style="white", no_wrap=True)
table.add_column("Insertions", justify="center", style="green")
table.add_column("Deletions", justify="center", style="red")
table.add_column("Lines", justify="center", style="magenta")
i_total = 0
d_total = 0
l_total = 0
for file, stats in file_changes.items():
i_total += stats["insertions"]
d_total += stats["deletions"]
l_total += stats["lines"]
table.add_row(
file,
str(stats["insertions"]),
str(stats["deletions"]),
str(stats["lines"]),
)
table.add_row("Total", str(i_total), str(d_total), str(l_total))
console = Console() console = Console()
with console.capture() as capture: with console.capture() as capture:
console.print(table) console.print(status.table)
self.logger.debug(capture.get()) self.logger.debug(capture.get())
self.logger.debug(len(capture.get())) self.logger.debug(len(capture.get()))
new = "\n".join(loaded) added = "\n".join(status.added)
removed = "\n".join(unloaded) removed = "\n".join(status.removed)
changed = "\n".join(reloaded) changed = "\n".join(status.changed)
fields = [ fields = [
EmbedField(name="Old Commit", value=current_hash), EmbedField(name="Old Commit", value=status.old_hash),
EmbedField(name="New Commit", value=remote_hash), EmbedField(name="New Commit", value=status.new_hash),
] ]
if loaded: if added:
fields.append(EmbedField(name="New Modules", value=f"```\n{new}\n```")) fields.append(EmbedField(name="New Modules", value=f"```\n{added}\n```"))
if removed: if removed:
fields.append(EmbedField(name="Removed Modules", value=f"```\n{removed}\n```")) fields.append(EmbedField(name="Removed Modules", value=f"```\n{removed}\n```"))
if changed: if changed:
@ -206,11 +108,10 @@ class BotutilCog(Scale):
try: try:
await ctx.reply(f"```ansi\n{capture.get()}\n```", embed=embed) await ctx.reply(f"```ansi\n{capture.get()}\n```", embed=embed)
except HTTPException: except HTTPException:
await ctx.reply(f"Total Changes: {l_total}", embed=embed) await ctx.reply(f"Total Changes: {status.total_lines}", embed=embed)
else: else:
embed = build_embed(title="Update Status", description="No changes applied", fields=[]) embed = build_embed(title="Update Status", description="No changes applied", fields=[])
embed.set_footer(text=current_hash)
embed.set_thumbnail(url="https://dev.zevaryx.com/git.png") embed.set_thumbnail(url="https://dev.zevaryx.com/git.png")
await ctx.reply(embed=embed) await ctx.reply(embed=embed)

187
jarvis/utils/updates.py Normal file
View file

@ -0,0 +1,187 @@
"""JARVIS update handler."""
import asyncio
import importlib
import inspect
import logging
from dataclasses import dataclass
from pkgutil import iter_modules
from types import ModuleType
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import git
from dis_snek.client.utils.misc_utils import find, find_all
from dis_snek.models.snek.application_commands import SlashCommand
from dis_snek.models.snek.scale import Scale
from rich.table import Table
import jarvis.cogs
if TYPE_CHECKING:
from logging import Logger
from dis_snek.client.client import Snake
@dataclass
class UpdateResult:
"""JARVIS update result."""
old_hash: str
new_hash: str
table: Table
added: List[str]
removed: List[str]
changed: List[str]
inserted_lines: int
deleted_lines: int
total_lines: int
def get_all_commands(module: ModuleType = jarvis.cogs) -> Dict[str, Callable]:
"""Get all SlashCommands from a specified module."""
commands = {}
for item in iter_modules(module.__path__):
new_module = importlib.import_module(f"{module.__name__}.{item.name}")
if item.ispkg:
if cmds := get_all_commands(new_module):
commands.update(cmds)
else:
inspect_result = inspect.getmembers(new_module)
cogs = []
for _, val in inspect_result:
if inspect.isclass(val) and issubclass(val, Scale) and val is not Scale:
cogs.append(val)
for cog in cogs:
values = cog.__dict__.values()
commands[cog.__module__] = find_all(lambda x: isinstance(x, SlashCommand), values)
return {k: v for k, v in commands.items() if v}
async def update(bot: "Snake", logger: "Logger" = None) -> Optional[UpdateResult]:
"""
Update JARVIS and return an UpdateResult.
Args:
bot: Bot instance
logger: Logger instance
Returns:
UpdateResult object
"""
if not logger:
logger = logging.getLogger(__name__)
repo = git.Repo(".")
current_hash = repo.head.object.hexsha
origin = repo.remotes.origin
origin.fetch()
remote_hash = origin.refs[repo.active_branch.name].object.hexsha
if current_hash != remote_hash:
logger.info(f"Updating from {current_hash} to {remote_hash}")
current_commands = get_all_commands()
changes = origin.pull()
logger.info(f"Pulled {len(changes)} changes")
await asyncio.sleep(3)
new_commands = get_all_commands()
logger.info("Checking if any modules need reloaded...")
reloaded = []
loaded = []
unloaded = []
logger.debug("Checking for removed cogs")
for module in current_commands.keys():
if module not in new_commands:
logger.debug(f"Module {module} removed after update")
bot.unload_extension(module)
unloaded.append(module)
logger.debug("Checking for new/modified commands")
for module, commands in new_commands.items():
logger.debug(f"Processing {module}")
if module not in current_commands:
bot.load_extension(module)
loaded.append(module)
elif len(current_commands[module]) != len(commands):
bot.reload_extension(module)
reloaded.append(module)
else:
for command in commands:
old_command = find(
lambda x: x.resolved_name == command.resolved_name, current_commands[module]
)
# Extract useful info
old_args = old_command.options
if old_args:
old_arg_names = [x.name for x in old_args]
new_args = command.options
if new_args:
new_arg_names = [x.name for x in new_args]
# No changes
if not old_args and not new_args:
continue
# Check if number arguments have changed
if len(old_args) != len(new_args):
bot.reload_extension(module)
reloaded.append(module)
elif any(x not in old_arg_names for x in new_arg_names) or any(
x not in new_arg_names for x in old_arg_names
):
bot.reload_extension(module)
reloaded.append(module)
elif any(new_args[idx].type != x.type for idx, x in enumerate(old_args)):
bot.reload_extension(module)
reloaded.append(module)
file_changes = {}
for change in sorted(changes, key=lambda x: x.commit.committed_datetime, reverse=True):
if change.commit.hexsha == current_hash:
break
files = change.commit.stats.files
for file, stats in files.items():
if file not in file_changes:
file_changes[file] = {"insertions": 0, "deletions": 0, "lines": 0}
for k, v in stats.items():
file_changes[file][k] += v
table = Table(title="File Changes")
table.add_column("File", justify="left", style="white", no_wrap=True)
table.add_column("Insertions", justify="center", style="green")
table.add_column("Deletions", justify="center", style="red")
table.add_column("Lines", justify="center", style="magenta")
i_total = 0
d_total = 0
l_total = 0
for file, stats in file_changes.items():
i_total += stats["insertions"]
d_total += stats["deletions"]
l_total += stats["lines"]
table.add_row(
file,
str(stats["insertions"]),
str(stats["deletions"]),
str(stats["lines"]),
)
table.add_row("Total", str(i_total), str(d_total), str(l_total))
return UpdateResult(
table=table,
old_hash=current_hash,
new_hash=remote_hash,
added=loaded,
removed=unloaded,
changed=reloaded,
inserted_lines=i_total,
deleted_lines=d_total,
total_lines=l_total,
)
return None