Create helper function for updates for easier debugging
This commit is contained in:
parent
92cd0b1eae
commit
b1d5415625
2 changed files with 200 additions and 112 deletions
|
@ -1,22 +1,19 @@
|
|||
"""JARVIS bot utility commands."""
|
||||
import asyncio
|
||||
import logging
|
||||
import platform
|
||||
from io import BytesIO
|
||||
|
||||
import git
|
||||
import psutil
|
||||
from aiofile import AIOFile, LineReader
|
||||
from dis_snek import MessageContext, Scale, Snake
|
||||
from dis_snek.client.errors import HTTPException
|
||||
from dis_snek.client.utils.misc_utils import find
|
||||
from dis_snek.models.discord.embed import EmbedField
|
||||
from dis_snek.models.discord.file import File
|
||||
from molter import msg_command
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from jarvis.utils import build_embed, get_all_commands
|
||||
from jarvis.utils import build_embed
|
||||
from jarvis.utils.updates import update
|
||||
|
||||
|
||||
class BotutilCog(Scale):
|
||||
|
@ -80,118 +77,23 @@ class BotutilCog(Scale):
|
|||
|
||||
@msg_command(name="update")
|
||||
async def _update(self, ctx: MessageContext) -> None:
|
||||
repo = git.Repo(".")
|
||||
current_hash = repo.head.object.hexsha
|
||||
origin = repo.remotes.origin
|
||||
origin.fetch()
|
||||
remote_hash = origin.refs[repo.active_branch.name].object.hexsha
|
||||
|
||||
if current_hash != remote_hash:
|
||||
self.logger.info("Updating...")
|
||||
current_commands = get_all_commands()
|
||||
changes = origin.pull()
|
||||
|
||||
self.logger.info("Changes pulled...")
|
||||
self.logger.debug("Sleeping for 3 seconds to allow changes")
|
||||
await asyncio.sleep(3)
|
||||
self.logger.debug("Finished sleeping, loading new commands")
|
||||
|
||||
reloaded = []
|
||||
loaded = []
|
||||
unloaded = []
|
||||
new_commands = get_all_commands()
|
||||
for module in current_commands:
|
||||
if module not in new_commands:
|
||||
self.bot.unload_extension(module)
|
||||
unloaded.append(module)
|
||||
for module, commands in new_commands.items():
|
||||
self.logger.debug(f"Processing {module}")
|
||||
if module not in current_commands:
|
||||
self.bot.load_extension(module)
|
||||
loaded.append(module)
|
||||
elif len(current_commands[module]) != len(commands):
|
||||
self.bot.reload_extension(module)
|
||||
reloaded.append(module)
|
||||
else:
|
||||
for command in commands:
|
||||
old_command = find(
|
||||
lambda x: x.resolved_name == command.resolved_name,
|
||||
current_commands[module],
|
||||
)
|
||||
|
||||
# Extract useful info
|
||||
old_args = old_command.options
|
||||
if old_args:
|
||||
old_arg_names = [x.name for x in old_args]
|
||||
new_args = command.options
|
||||
if new_args:
|
||||
new_arg_names = [x.name for x in new_args]
|
||||
|
||||
# No changes
|
||||
if not old_args and not new_args:
|
||||
continue
|
||||
|
||||
# Check if number arguments have changed
|
||||
if len(old_args) != len(new_args):
|
||||
self.bot.reload_extension(module)
|
||||
reloaded.append(module)
|
||||
elif any(x not in old_arg_names for x in new_arg_names) or any(
|
||||
x not in new_arg_names for x in old_arg_names
|
||||
):
|
||||
self.bot.reload_extension(module)
|
||||
reloaded.append(module)
|
||||
elif any(new_args[idx].type != x.type for idx, x in enumerate(old_args)):
|
||||
self.bot.reload_extension(module)
|
||||
reloaded.append(module)
|
||||
|
||||
file_changes = {}
|
||||
for change in sorted(changes, key=lambda x: x.commit.committed_datetime, reverse=True):
|
||||
if change.commit.hexsha == current_hash:
|
||||
break
|
||||
files = change.commit.stats.files
|
||||
for file, stats in files.items():
|
||||
if file not in file_changes:
|
||||
file_changes[file] = {"insertions": 0, "deletions": 0, "lines": 0}
|
||||
for k, v in stats.items():
|
||||
file_changes[file][k] += v
|
||||
|
||||
table = Table(title="File Changes")
|
||||
|
||||
table.add_column("File", justify="left", style="white", no_wrap=True)
|
||||
table.add_column("Insertions", justify="center", style="green")
|
||||
table.add_column("Deletions", justify="center", style="red")
|
||||
table.add_column("Lines", justify="center", style="magenta")
|
||||
|
||||
i_total = 0
|
||||
d_total = 0
|
||||
l_total = 0
|
||||
for file, stats in file_changes.items():
|
||||
i_total += stats["insertions"]
|
||||
d_total += stats["deletions"]
|
||||
l_total += stats["lines"]
|
||||
table.add_row(
|
||||
file,
|
||||
str(stats["insertions"]),
|
||||
str(stats["deletions"]),
|
||||
str(stats["lines"]),
|
||||
)
|
||||
|
||||
table.add_row("Total", str(i_total), str(d_total), str(l_total))
|
||||
status = update(self.bot, self.logger)
|
||||
if status:
|
||||
console = Console()
|
||||
with console.capture() as capture:
|
||||
console.print(table)
|
||||
console.print(status.table)
|
||||
self.logger.debug(capture.get())
|
||||
self.logger.debug(len(capture.get()))
|
||||
new = "\n".join(loaded)
|
||||
removed = "\n".join(unloaded)
|
||||
changed = "\n".join(reloaded)
|
||||
added = "\n".join(status.added)
|
||||
removed = "\n".join(status.removed)
|
||||
changed = "\n".join(status.changed)
|
||||
|
||||
fields = [
|
||||
EmbedField(name="Old Commit", value=current_hash),
|
||||
EmbedField(name="New Commit", value=remote_hash),
|
||||
EmbedField(name="Old Commit", value=status.old_hash),
|
||||
EmbedField(name="New Commit", value=status.new_hash),
|
||||
]
|
||||
if loaded:
|
||||
fields.append(EmbedField(name="New Modules", value=f"```\n{new}\n```"))
|
||||
if added:
|
||||
fields.append(EmbedField(name="New Modules", value=f"```\n{added}\n```"))
|
||||
if removed:
|
||||
fields.append(EmbedField(name="Removed Modules", value=f"```\n{removed}\n```"))
|
||||
if changed:
|
||||
|
@ -206,11 +108,10 @@ class BotutilCog(Scale):
|
|||
try:
|
||||
await ctx.reply(f"```ansi\n{capture.get()}\n```", embed=embed)
|
||||
except HTTPException:
|
||||
await ctx.reply(f"Total Changes: {l_total}", embed=embed)
|
||||
await ctx.reply(f"Total Changes: {status.total_lines}", embed=embed)
|
||||
|
||||
else:
|
||||
embed = build_embed(title="Update Status", description="No changes applied", fields=[])
|
||||
embed.set_footer(text=current_hash)
|
||||
embed.set_thumbnail(url="https://dev.zevaryx.com/git.png")
|
||||
await ctx.reply(embed=embed)
|
||||
|
||||
|
|
187
jarvis/utils/updates.py
Normal file
187
jarvis/utils/updates.py
Normal file
|
@ -0,0 +1,187 @@
|
|||
"""JARVIS update handler."""
|
||||
import asyncio
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pkgutil import iter_modules
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||
|
||||
import git
|
||||
from dis_snek.client.utils.misc_utils import find, find_all
|
||||
from dis_snek.models.snek.application_commands import SlashCommand
|
||||
from dis_snek.models.snek.scale import Scale
|
||||
from rich.table import Table
|
||||
|
||||
import jarvis.cogs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
from dis_snek.client.client import Snake
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateResult:
|
||||
"""JARVIS update result."""
|
||||
|
||||
old_hash: str
|
||||
new_hash: str
|
||||
table: Table
|
||||
added: List[str]
|
||||
removed: List[str]
|
||||
changed: List[str]
|
||||
inserted_lines: int
|
||||
deleted_lines: int
|
||||
total_lines: int
|
||||
|
||||
|
||||
def get_all_commands(module: ModuleType = jarvis.cogs) -> Dict[str, Callable]:
|
||||
"""Get all SlashCommands from a specified module."""
|
||||
commands = {}
|
||||
for item in iter_modules(module.__path__):
|
||||
new_module = importlib.import_module(f"{module.__name__}.{item.name}")
|
||||
if item.ispkg:
|
||||
if cmds := get_all_commands(new_module):
|
||||
commands.update(cmds)
|
||||
else:
|
||||
inspect_result = inspect.getmembers(new_module)
|
||||
cogs = []
|
||||
for _, val in inspect_result:
|
||||
if inspect.isclass(val) and issubclass(val, Scale) and val is not Scale:
|
||||
cogs.append(val)
|
||||
for cog in cogs:
|
||||
values = cog.__dict__.values()
|
||||
commands[cog.__module__] = find_all(lambda x: isinstance(x, SlashCommand), values)
|
||||
return {k: v for k, v in commands.items() if v}
|
||||
|
||||
|
||||
async def update(bot: "Snake", logger: "Logger" = None) -> Optional[UpdateResult]:
|
||||
"""
|
||||
Update JARVIS and return an UpdateResult.
|
||||
|
||||
Args:
|
||||
bot: Bot instance
|
||||
logger: Logger instance
|
||||
|
||||
Returns:
|
||||
UpdateResult object
|
||||
"""
|
||||
if not logger:
|
||||
logger = logging.getLogger(__name__)
|
||||
repo = git.Repo(".")
|
||||
current_hash = repo.head.object.hexsha
|
||||
origin = repo.remotes.origin
|
||||
origin.fetch()
|
||||
remote_hash = origin.refs[repo.active_branch.name].object.hexsha
|
||||
|
||||
if current_hash != remote_hash:
|
||||
logger.info(f"Updating from {current_hash} to {remote_hash}")
|
||||
current_commands = get_all_commands()
|
||||
|
||||
changes = origin.pull()
|
||||
logger.info(f"Pulled {len(changes)} changes")
|
||||
await asyncio.sleep(3)
|
||||
|
||||
new_commands = get_all_commands()
|
||||
|
||||
logger.info("Checking if any modules need reloaded...")
|
||||
|
||||
reloaded = []
|
||||
loaded = []
|
||||
unloaded = []
|
||||
|
||||
logger.debug("Checking for removed cogs")
|
||||
for module in current_commands.keys():
|
||||
if module not in new_commands:
|
||||
logger.debug(f"Module {module} removed after update")
|
||||
bot.unload_extension(module)
|
||||
unloaded.append(module)
|
||||
|
||||
logger.debug("Checking for new/modified commands")
|
||||
for module, commands in new_commands.items():
|
||||
logger.debug(f"Processing {module}")
|
||||
if module not in current_commands:
|
||||
bot.load_extension(module)
|
||||
loaded.append(module)
|
||||
elif len(current_commands[module]) != len(commands):
|
||||
bot.reload_extension(module)
|
||||
reloaded.append(module)
|
||||
else:
|
||||
for command in commands:
|
||||
old_command = find(
|
||||
lambda x: x.resolved_name == command.resolved_name, current_commands[module]
|
||||
)
|
||||
|
||||
# Extract useful info
|
||||
old_args = old_command.options
|
||||
if old_args:
|
||||
old_arg_names = [x.name for x in old_args]
|
||||
new_args = command.options
|
||||
if new_args:
|
||||
new_arg_names = [x.name for x in new_args]
|
||||
|
||||
# No changes
|
||||
if not old_args and not new_args:
|
||||
continue
|
||||
|
||||
# Check if number arguments have changed
|
||||
if len(old_args) != len(new_args):
|
||||
bot.reload_extension(module)
|
||||
reloaded.append(module)
|
||||
elif any(x not in old_arg_names for x in new_arg_names) or any(
|
||||
x not in new_arg_names for x in old_arg_names
|
||||
):
|
||||
bot.reload_extension(module)
|
||||
reloaded.append(module)
|
||||
elif any(new_args[idx].type != x.type for idx, x in enumerate(old_args)):
|
||||
bot.reload_extension(module)
|
||||
reloaded.append(module)
|
||||
|
||||
file_changes = {}
|
||||
for change in sorted(changes, key=lambda x: x.commit.committed_datetime, reverse=True):
|
||||
if change.commit.hexsha == current_hash:
|
||||
break
|
||||
files = change.commit.stats.files
|
||||
for file, stats in files.items():
|
||||
if file not in file_changes:
|
||||
file_changes[file] = {"insertions": 0, "deletions": 0, "lines": 0}
|
||||
for k, v in stats.items():
|
||||
file_changes[file][k] += v
|
||||
|
||||
table = Table(title="File Changes")
|
||||
|
||||
table.add_column("File", justify="left", style="white", no_wrap=True)
|
||||
table.add_column("Insertions", justify="center", style="green")
|
||||
table.add_column("Deletions", justify="center", style="red")
|
||||
table.add_column("Lines", justify="center", style="magenta")
|
||||
|
||||
i_total = 0
|
||||
d_total = 0
|
||||
l_total = 0
|
||||
for file, stats in file_changes.items():
|
||||
i_total += stats["insertions"]
|
||||
d_total += stats["deletions"]
|
||||
l_total += stats["lines"]
|
||||
table.add_row(
|
||||
file,
|
||||
str(stats["insertions"]),
|
||||
str(stats["deletions"]),
|
||||
str(stats["lines"]),
|
||||
)
|
||||
|
||||
table.add_row("Total", str(i_total), str(d_total), str(l_total))
|
||||
|
||||
return UpdateResult(
|
||||
table=table,
|
||||
old_hash=current_hash,
|
||||
new_hash=remote_hash,
|
||||
added=loaded,
|
||||
removed=unloaded,
|
||||
changed=reloaded,
|
||||
inserted_lines=i_total,
|
||||
deleted_lines=d_total,
|
||||
total_lines=l_total,
|
||||
)
|
||||
return None
|
Loading…
Add table
Reference in a new issue