215 lines
8.7 KiB
Python
215 lines
8.7 KiB
Python
"""JARVIS bot utility commands."""
|
|
import asyncio
|
|
import logging
|
|
import platform
|
|
from io import BytesIO
|
|
|
|
import git
|
|
import psutil
|
|
from aiofile import AIOFile, LineReader
|
|
from dis_snek import MessageContext, Scale, Snake
|
|
from dis_snek.client.utils.misc_utils import find
|
|
from dis_snek.models.discord.embed import EmbedField
|
|
from dis_snek.models.discord.file import File
|
|
from molter import msg_command
|
|
from rich.console import Console
|
|
from rich.table import Table
|
|
|
|
from jarvis.utils import build_embed, get_all_commands
|
|
|
|
|
|
class BotutilCog(Scale):
    """JARVIS Bot Utility Cog.

    Owner-only message commands for reading the bot log, showing system
    information, and pulling/hot-loading code updates from git.
    """

    # Discord rejects message content longer than this many characters.
    _MESSAGE_LIMIT = 2000

    def __init__(self, bot: Snake):
        self.bot = bot
        self.logger = logging.getLogger(__name__)
        # Every command in this scale is gated on the author being the bot owner.
        self.add_scale_check(self.is_owner)

    async def is_owner(self, ctx: MessageContext) -> bool:
        """Checks if author is bot owner."""
        return ctx.author.id == self.bot.owner.id

    @msg_command(name="tail")
    async def _tail(self, ctx: MessageContext, count: int = 10) -> None:
        """Reply with the last ``count`` lines of ``jarvis.log``.

        Output up to 1500 characters is sent inline in a code block; longer
        output is attached as ``tail_<count>.log``. A non-positive ``count``
        is rejected with an error reply.
        """
        from collections import deque

        if count < 1:
            await ctx.reply(content="Count must be a positive number")
            return

        # A bounded deque keeps only the newest `count` lines while streaming,
        # avoiding the O(n) list.pop(0) on every line read.
        lines = deque(maxlen=count)
        async with AIOFile("jarvis.log", "r") as af:
            async for line in LineReader(af):
                lines.append(line)

        log = "".join(lines)
        if len(log) > 1500:
            with BytesIO(log.encode("UTF8")) as file_bytes:
                log_file = File(file_bytes, file_name=f"tail_{count}.log")
                await ctx.reply(content=f"Here's the last {count} lines of the log", file=log_file)
        else:
            await ctx.reply(content=f"```\n{log}\n```")

    @msg_command(name="log")
    async def _log(self, ctx: MessageContext) -> None:
        """Reply with the whole ``jarvis.log`` as a file attachment."""
        async with AIOFile("jarvis.log", "r") as af:
            raw = await af.read_bytes()
        # The log file is closed before the (potentially slow) upload starts.
        with BytesIO(raw) as file_bytes:
            log = File(file_bytes, file_name="jarvis.log")
            await ctx.reply(content="Here's the latest log", file=log)

    @msg_command(name="crash")
    async def _crash(self, ctx: MessageContext) -> None:
        """Deliberately raise, e.g. to exercise the bot's error handling."""
        raise Exception("As you wish")

    @msg_command(name="sysinfo")
    async def _sysinfo(self, ctx: MessageContext) -> None:
        """Send an embed with OS, Python version, and system/bot start times."""
        st_ts = int(self.bot.start_time.timestamp())
        ut_ts = int(psutil.boot_time())
        fields = [
            EmbedField(name="Operating System", value=platform.system() or "Unknown", inline=False),
            EmbedField(name="Version", value=platform.release() or "N/A", inline=False),
            EmbedField(name="System Start Time", value=f"<t:{ut_ts}:F> (<t:{ut_ts}:R>)"),
            EmbedField(name="Python Version", value=platform.python_version()),
            EmbedField(name="Bot Start Time", value=f"<t:{st_ts}:F> (<t:{st_ts}:R>)"),
        ]
        embed = build_embed(title="System Info", description="", fields=fields)
        embed.set_image(url=self.bot.user.avatar.url)
        await ctx.send(embed=embed)

    @msg_command(name="update")
    async def _update(self, ctx: MessageContext) -> None:
        """Pull from ``origin`` and hot-load the extensions that changed.

        Compares local HEAD with the fetched remote branch of the same name.
        If they differ, the method pulls, waits 3 seconds, and diffs the
        ``get_all_commands()`` snapshots taken before and after. It then loads
        new modules, unloads removed ones, and reloads any module whose
        commands or command options changed. Finally it replies with a
        per-file change table and a summary embed.
        """
        repo = git.Repo(".")
        current_hash = repo.head.object.hexsha
        origin = repo.remotes.origin
        origin.fetch()
        remote_hash = origin.refs[repo.active_branch.name].object.hexsha

        if current_hash == remote_hash:
            embed = build_embed(title="Update Status", description="No changes applied", fields=[])
            embed.set_footer(text=current_hash)
            embed.set_image(url="https://dev.zevaryx.com/git.png")
            await ctx.reply(embed=embed)
            return

        self.logger.info("Updating...")
        current_commands = get_all_commands()
        origin.pull()

        self.logger.info("Changes pulled...")
        self.logger.debug("Sleeping for 3 seconds to allow changes")
        await asyncio.sleep(3)
        self.logger.debug("Finished sleeping, loading new commands")

        new_commands = get_all_commands()
        loaded, unloaded, reloaded = self._sync_extensions(current_commands, new_commands)

        file_changes = self._collect_file_changes(repo, current_hash, remote_hash)
        table_text = self._render_change_table(file_changes)
        self.logger.debug("File changes:\n%s", table_text)

        fields = [
            EmbedField(name="Old Commit", value=current_hash),
            EmbedField(name="New Commit", value=remote_hash),
        ]
        if loaded:
            new = "\n".join(loaded)
            fields.append(EmbedField(name="New Modules", value=f"```\n{new}\n```"))
        if unloaded:
            removed = "\n".join(unloaded)
            fields.append(EmbedField(name="Removed Modules", value=f"```\n{removed}\n```"))
        if reloaded:
            changed = "\n".join(reloaded)
            fields.append(EmbedField(name="Changed Modules", value=f"```\n{changed}\n```"))

        embed = build_embed("Update Status", description="Updates have been applied", fields=fields)
        embed.set_image(url="https://dev.zevaryx.com/git.png")

        self.logger.info("Updates applied")
        content = f"```ansi\n{table_text}\n```"
        if len(content) > self._MESSAGE_LIMIT:
            # The update has already been applied, so never let an oversized
            # report fail the reply; ship the table as an attachment instead.
            with BytesIO(table_text.encode("UTF8")) as file_bytes:
                changes_file = File(file_bytes, file_name="changes.txt")
                await ctx.reply(embed=embed, file=changes_file)
        else:
            await ctx.reply(content, embed=embed)

    def _sync_extensions(self, current_commands: dict, new_commands: dict) -> tuple:
        """Load/unload/reload extensions so the bot matches ``new_commands``.

        Both arguments map module name -> list of commands, as returned by
        ``get_all_commands()``. Returns ``(loaded, unloaded, reloaded)`` lists
        of module names.
        """
        loaded, unloaded, reloaded = [], [], []
        for module in current_commands:
            if module not in new_commands:
                self.bot.unload_extension(module)
                unloaded.append(module)
        for module, commands in new_commands.items():
            self.logger.debug(f"Processing {module}")
            if module not in current_commands:
                self.bot.load_extension(module)
                loaded.append(module)
            elif self._module_changed(current_commands[module], commands):
                # Reload once per module, no matter how many of its commands changed.
                self.bot.reload_extension(module)
                reloaded.append(module)
        return loaded, unloaded, reloaded

    @staticmethod
    def _module_changed(old_commands: list, new_commands: list) -> bool:
        """Return True if a module's commands differ between two snapshots.

        A change is a different command count or a command with no same-named
        old counterpart. It is also a change when an option's count, set of
        names, or type (compared positionally) differs.
        """
        if len(old_commands) != len(new_commands):
            return True
        for command in new_commands:
            name = command.resolved_name
            old_command = find(lambda x: x.resolved_name == name, old_commands)
            if old_command is None:
                # Same count but an unknown name: a command was renamed/replaced.
                return True
            # Treat missing/None options as "no arguments".
            old_args = old_command.options or []
            new_args = command.options or []
            if len(old_args) != len(new_args):
                return True
            if {x.name for x in old_args} != {x.name for x in new_args}:
                return True
            if any(new.type != old.type for old, new in zip(old_args, new_args)):
                return True
        return False

    @staticmethod
    def _collect_file_changes(repo, old_hash: str, new_hash: str) -> dict:
        """Sum per-file insertions/deletions/lines over commits in ``old_hash..new_hash``.

        Returns a dict mapping file path -> {"insertions", "deletions", "lines"}.
        """
        file_changes = {}
        for commit in repo.iter_commits(f"{old_hash}..{new_hash}"):
            for file, stats in commit.stats.files.items():
                totals = file_changes.setdefault(
                    file, {"insertions": 0, "deletions": 0, "lines": 0}
                )
                # Sum only the three counters so any extra per-file keys are ignored.
                for key in totals:
                    totals[key] += stats.get(key, 0)
        return file_changes

    @staticmethod
    def _render_change_table(file_changes: dict) -> str:
        """Render per-file change counts plus a total row; return the captured rich output."""
        table = Table(title="File Changes")
        table.add_column("File", justify="left", style="white", no_wrap=True)
        table.add_column("Insertions", justify="center", style="green")
        table.add_column("Deletions", justify="center", style="red")
        table.add_column("Lines", justify="center", style="magenta")

        i_total = d_total = l_total = 0
        for file, stats in file_changes.items():
            i_total += stats["insertions"]
            d_total += stats["deletions"]
            l_total += stats["lines"]
            table.add_row(
                file,
                str(stats["insertions"]),
                str(stats["deletions"]),
                str(stats["lines"]),
            )
        table.add_row("Total", str(i_total), str(d_total), str(l_total))

        console = Console()
        with console.capture() as capture:
            console.print(table)
        return capture.get()
|
|
|
|
|
|
def setup(bot: Snake) -> None:
    """Entry point used when JARVIS loads this extension.

    Constructs the bot-utility cog around ``bot``; nothing is returned.
    """
    # NOTE(review): registration is presumably a side effect of constructing
    # the Scale (the instance is not stored) — confirm against dis_snek.
    BotutilCog(bot)
|