195 lines
6.2 KiB
Python
195 lines
6.2 KiB
Python
"""JARVIS update handler."""
|
|
import asyncio
|
|
import importlib
|
|
import inspect
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from pkgutil import iter_modules
|
|
from types import ModuleType
|
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
|
|
|
import git
|
|
from dis_snek.client.utils.misc_utils import find, find_all
|
|
from dis_snek.models.snek.application_commands import SlashCommand
|
|
from dis_snek.models.snek.scale import Scale
|
|
from rich.table import Table
|
|
|
|
import jarvis.cogs
|
|
|
|
if TYPE_CHECKING:
|
|
from dis_snek.client.client import Snake
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@dataclass
class UpdateResult:
    """
    JARVIS update result.

    Built by ``update()`` after a pull. It records the commit hashes before and after,
    the extension changes applied to the bot, and the diff statistics produced by
    ``get_git_changes()``.
    """

    # Local HEAD commit hash before the pull
    old_hash: str
    # Hash of the remote branch commit that was pulled
    new_hash: str
    # rich Table with one row per changed file (insertions/deletions/lines) plus a Total row
    table: Table
    # Extension modules loaded because they first appeared after the update
    added: List[str]
    # Extension modules unloaded because they no longer appear (with slash commands) after the update
    removed: List[str]
    # Extension modules reloaded because their slash commands changed
    changed: List[str]
    # Sum of inserted lines across all changed files
    inserted_lines: int
    # Sum of deleted lines across all changed files
    deleted_lines: int
    # Sum of the per-file "lines" counter reported by git stats
    total_lines: int
|
|
|
|
|
|
def get_all_commands(module: ModuleType = jarvis.cogs) -> Dict[str, List[SlashCommand]]:
    """
    Get all SlashCommands from a specified module.

    Walks every submodule of ``module``, recursing into sub-packages. For each
    ``Scale`` subclass found, it collects the ``SlashCommand`` objects defined
    directly on that class.

    Args:
        module: Package to scan (defaults to ``jarvis.cogs``)

    Returns:
        Mapping of the module that defines each cog (``cog.__module__``) to the
        SlashCommands of all cogs defined there. Modules without commands are omitted.
    """
    commands: Dict[str, List[SlashCommand]] = {}
    _collect_commands(module, commands, set())
    return {k: v for k, v in commands.items() if v}


def _collect_commands(module: ModuleType, commands: Dict[str, List[SlashCommand]], seen_cogs: set) -> None:
    """Recursively add the SlashCommands of every Scale under ``module`` into ``commands``."""
    for item in iter_modules(module.__path__):
        new_module = importlib.import_module(f"{module.__name__}.{item.name}")
        if item.ispkg:
            _collect_commands(new_module, commands, seen_cogs)
            continue

        for _, cog in inspect.getmembers(new_module, inspect.isclass):
            # Skip non-cogs, the base class itself, and cogs already counted
            # (the same class can be visible from several modules via imports).
            if not issubclass(cog, Scale) or cog is Scale or cog in seen_cogs:
                continue
            seen_cogs.add(cog)
            # Accumulate rather than assign, so a module defining several cogs keeps
            # the commands of all of them instead of only the last one.
            commands.setdefault(cog.__module__, []).extend(
                find_all(lambda x: isinstance(x, SlashCommand), cog.__dict__.values())
            )
|
|
|
|
|
|
def get_git_changes() -> dict:
    """
    Get all Git changes between the local HEAD and its remote branch.

    Fetches ``origin``. It then sums the per-file stats of every commit that is
    reachable from ``origin/<active branch>`` but not from the local HEAD, i.e.
    the commits a pull would bring in.

    Returns:
        Dict with keys ``table`` (rich Table, one row per file plus a Total row),
        ``inserted_lines``, ``deleted_lines`` and ``total_lines``.
    """
    repo = git.Repo(".")
    current_hash = repo.head.object.hexsha
    origin = repo.remotes.origin
    origin.fetch()
    remote_hash = origin.refs[repo.active_branch.name].object.hexsha

    stat_keys = ("insertions", "deletions", "lines")
    file_changes: Dict[str, Dict[str, int]] = {}
    # fetch() returns one FetchInfo per remote ref, not the incoming commits,
    # so walk the commit range explicitly.
    for commit in repo.iter_commits(f"{current_hash}..{remote_hash}"):
        for file, stats in commit.stats.files.items():
            totals = file_changes.setdefault(file, dict.fromkeys(stat_keys, 0))
            # Only sum the numeric counters; the stats dict may carry other,
            # non-numeric entries depending on the GitPython version.
            for key in stat_keys:
                totals[key] += stats.get(key, 0)

    table = Table(title="File Changes")

    table.add_column("File", justify="left", style="white", no_wrap=True)
    table.add_column("Insertions", justify="center", style="green")
    table.add_column("Deletions", justify="center", style="red")
    table.add_column("Lines", justify="center", style="magenta")

    i_total = 0
    d_total = 0
    l_total = 0
    for file, stats in file_changes.items():
        i_total += stats["insertions"]
        d_total += stats["deletions"]
        l_total += stats["lines"]
        table.add_row(
            file,
            str(stats["insertions"]),
            str(stats["deletions"]),
            str(stats["lines"]),
        )

    table.add_row("Total", str(i_total), str(d_total), str(l_total))
    return {
        "table": table,
        "inserted_lines": i_total,
        "deleted_lines": d_total,
        "total_lines": l_total,
    }
|
|
|
|
|
|
def _option_signature(command: SlashCommand) -> List[tuple]:
    """Return the ordered ``(name, type)`` pairs of a command's options (empty if it has none)."""
    return [(option.name, option.type) for option in command.options or []]


def _commands_changed(old: List[SlashCommand], new: List[SlashCommand]) -> bool:
    """Return True if a module's commands differ in count, names, or option names/types/order."""
    if len(old) != len(new):
        return True
    for command in new:
        previous = find(lambda x: x.resolved_name == command.resolved_name, old)
        if previous is None:
            # Same number of commands, but this name did not exist before (renamed command)
            return True
        if _option_signature(previous) != _option_signature(command):
            return True
    return False


async def update(bot: "Snake") -> Optional[UpdateResult]:
    """
    Update JARVIS and return an UpdateResult.

    Fetches ``origin``. If the remote branch differs from the local HEAD, it pulls
    and then loads, unloads or reloads extension modules whose slash commands
    were added, removed or changed.

    Args:
        bot: Bot instance

    Returns:
        UpdateResult object, or None if HEAD already matches the remote branch
    """
    repo = git.Repo(".")
    current_hash = repo.head.object.hexsha
    origin = repo.remotes.origin
    origin.fetch()
    remote_hash = origin.refs[repo.active_branch.name].object.hexsha

    if current_hash == remote_hash:
        return None

    logger.info(f"Updating from {current_hash} to {remote_hash}")
    current_commands = get_all_commands()
    changes = get_git_changes()

    # NOTE(review): GitPython calls are synchronous and block the event loop here.
    origin.pull()
    await asyncio.sleep(3)

    # NOTE(review): get_all_commands() resolves modules via importlib.import_module,
    # which returns the entries the scan above already put in sys.modules. Modules that
    # existed before the pull are therefore inspected as pre-pull code (unless
    # something evicts them), so only added/removed modules are reliably detected.
    # Confirm, and consider importlib.reload, before relying on `changed`.
    new_commands = get_all_commands()

    logger.info("Checking if any modules need reloaded...")

    reloaded = []
    loaded = []
    unloaded = []

    logger.debug("Checking for removed cogs")
    for module in current_commands:
        if module not in new_commands:
            logger.debug(f"Module {module} removed after update")
            bot.unload_extension(module)
            unloaded.append(module)

    logger.debug("Checking for new/modified commands")
    for module, commands in new_commands.items():
        logger.debug(f"Processing {module}")
        if module not in current_commands:
            bot.load_extension(module)
            loaded.append(module)
        elif _commands_changed(current_commands[module], commands):
            # Reload at most once per module, however many of its commands changed
            bot.reload_extension(module)
            reloaded.append(module)

    return UpdateResult(
        old_hash=current_hash,
        new_hash=remote_hash,
        added=loaded,
        removed=unloaded,
        changed=reloaded,
        **changes,
    )
|