perflint and pylint recommendations

This commit is contained in:
Zeva Rose 2022-05-02 01:18:33 -06:00
parent af42b385ab
commit d76d030bfd
15 changed files with 100 additions and 168 deletions

View file

@ -7,34 +7,37 @@ from dis_snek import Intents
from jarvis_core.db import connect from jarvis_core.db import connect
from jarvis_core.log import get_logger from jarvis_core.log import get_logger
from jarvis import const, utils from jarvis import const
from jarvis.client import Jarvis from jarvis.client import Jarvis
from jarvis.cogs import __path__ as cogs_path
from jarvis.config import JarvisConfig from jarvis.config import JarvisConfig
from jarvis.utils import get_extensions
__version__ = const.__version__ __version__ = const.__version__
jconfig = JarvisConfig.from_yaml()
logger = get_logger("jarvis", show_locals=jconfig.log_level == "DEBUG")
logger.setLevel(jconfig.log_level)
file_handler = logging.FileHandler(filename="jarvis.log", encoding="UTF-8", mode="w")
file_handler.setFormatter(
logging.Formatter("[%(asctime)s] [%(name)s] [%(levelname)8s] %(message)s")
)
logger.addHandler(file_handler)
intents = Intents.DEFAULT | Intents.MESSAGES | Intents.GUILD_MEMBERS | Intents.GUILD_MESSAGE_CONTENT
restart_ctx = None
jarvis = Jarvis(
intents=intents,
sync_interactions=jconfig.sync,
delete_unused_application_cmds=True,
send_command_tracebacks=False,
)
async def run() -> None: async def run() -> None:
"""Run JARVIS""" """Run JARVIS"""
jconfig = JarvisConfig.from_yaml()
logger = get_logger("jarvis", show_locals=jconfig.log_level == "DEBUG")
logger.setLevel(jconfig.log_level)
file_handler = logging.FileHandler(filename="jarvis.log", encoding="UTF-8", mode="w")
file_handler.setFormatter(
logging.Formatter("[%(asctime)s] [%(name)s] [%(levelname)8s] %(message)s")
)
logger.addHandler(file_handler)
intents = (
Intents.DEFAULT | Intents.MESSAGES | Intents.GUILD_MEMBERS | Intents.GUILD_MESSAGE_CONTENT
)
jarvis = Jarvis(
intents=intents,
sync_interactions=jconfig.sync,
delete_unused_application_cmds=True,
send_command_tracebacks=False,
)
if jconfig.log_level == "DEBUG": if jconfig.log_level == "DEBUG":
jurigged.watch() jurigged.watch()
if jconfig.rook_token: if jconfig.rook_token:
@ -46,9 +49,9 @@ async def run() -> None:
# jconfig.get_db_config() # jconfig.get_db_config()
logger.debug("Loading extensions") logger.debug("Loading extensions")
for extension in utils.get_extensions(): for extension in get_extensions(cogs_path):
jarvis.load_extension(extension) jarvis.load_extension(extension)
logger.debug(f"Loaded {extension}") logger.debug("Loaded %s", extension)
jarvis.max_messages = jconfig.max_messages jarvis.max_messages = jconfig.max_messages
logger.debug("Running JARVIS") logger.debug("Running JARVIS")

View file

@ -270,7 +270,7 @@ class Jarvis(Snake):
channel = await guild.fetch_channel(log.channel) channel = await guild.fetch_channel(log.channel)
embed = build_embed( embed = build_embed(
title="Member Left", title="Member Left",
desciption=f"{user.username}#{user.discriminator} left {guild.name}", description=f"{user.username}#{user.discriminator} left {guild.name}",
fields=[], fields=[],
) )
embed.set_author(name=user.username, icon_url=user.avatar.url) embed.set_author(name=user.username, icon_url=user.avatar.url)
@ -394,12 +394,9 @@ class Jarvis(Snake):
rolepings = await Roleping.find(q(guild=message.guild.id, active=True)).to_list(None) rolepings = await Roleping.find(q(guild=message.guild.id, active=True)).to_list(None)
# Get all role IDs involved with message # Get all role IDs involved with message
roles = [] roles = [x.id async for x in message.mention_roles]
async for mention in message.mention_roles:
roles.append(mention.id)
async for mention in message.mention_users: async for mention in message.mention_users:
for role in mention.roles: roles += [x.id for x in mention.roles]
roles.append(role.id)
if not roles: if not roles:
return return
@ -417,12 +414,15 @@ class Jarvis(Snake):
user_is_admin = message.author.has_permission(Permissions.ADMINISTRATOR) user_is_admin = message.author.has_permission(Permissions.ADMINISTRATOR)
# Check if user in a bypass list # Check if user in a bypass list
def check_has_role(roleping: Roleping) -> bool:
return any(role.id in roleping.bypass["roles"] for role in message.author.roles)
user_has_bypass = False user_has_bypass = False
for roleping in rolepings: for roleping in rolepings:
if message.author.id in roleping.bypass["users"]: if message.author.id in roleping.bypass["users"]:
user_has_bypass = True user_has_bypass = True
break break
if any(role.id in roleping.bypass["roles"] for role in message.author.roles): if check_has_role(roleping):
user_has_bypass = True user_has_bypass = True
break break
@ -553,7 +553,7 @@ class Jarvis(Snake):
) )
await channel.send(embed=embed) await channel.send(embed=embed)
except Exception as e: except Exception as e:
self.logger.warn( self.logger.warning(
f"Failed to process edit {before.guild.id}/{before.channel.id}/{before.id}: {e}" f"Failed to process edit {before.guild.id}/{before.channel.id}/{before.id}: {e}"
) )
if not isinstance(after.channel, DMChannel) and not after.author.bot: if not isinstance(after.channel, DMChannel) and not after.author.bot:
@ -629,6 +629,6 @@ class Jarvis(Snake):
) )
await channel.send(embed=embed) await channel.send(embed=embed)
except Exception as e: except Exception as e:
self.logger.warn( self.logger.warning(
f"Failed to process edit {message.guild.id}/{message.channel.id}/{message.id}: {e}" f"Failed to process edit {message.guild.id}/{message.channel.id}/{message.id}: {e}"
) )

View file

@ -1,8 +1,7 @@
"""JARVIS BanCog.""" """JARVIS BanCog."""
import logging
import re import re
from dis_snek import InteractionContext, Permissions, Snake from dis_snek import InteractionContext, Permissions
from dis_snek.client.utils.misc_utils import find, find_all from dis_snek.client.utils.misc_utils import find, find_all
from dis_snek.ext.paginators import Paginator from dis_snek.ext.paginators import Paginator
from dis_snek.models.discord.embed import EmbedField from dis_snek.models.discord.embed import EmbedField
@ -26,10 +25,6 @@ from jarvis.utils.permissions import admin_or_permissions
class BanCog(ModcaseCog): class BanCog(ModcaseCog):
"""JARVIS BanCog.""" """JARVIS BanCog."""
def __init__(self, bot: Snake):
super().__init__(bot)
self.logger = logging.getLogger(__name__)
async def discord_apply_ban( async def discord_apply_ban(
self, self,
ctx: InteractionContext, ctx: InteractionContext,

View file

@ -1,7 +1,5 @@
"""JARVIS KickCog.""" """JARVIS KickCog."""
import logging from dis_snek import InteractionContext, Permissions
from dis_snek import InteractionContext, Permissions, Snake
from dis_snek.models.discord.embed import EmbedField from dis_snek.models.discord.embed import EmbedField
from dis_snek.models.discord.user import User from dis_snek.models.discord.user import User
from dis_snek.models.snek.application_commands import ( from dis_snek.models.snek.application_commands import (
@ -20,10 +18,6 @@ from jarvis.utils.permissions import admin_or_permissions
class KickCog(ModcaseCog): class KickCog(ModcaseCog):
"""JARVIS KickCog.""" """JARVIS KickCog."""
def __init__(self, bot: Snake):
super().__init__(bot)
self.logger = logging.getLogger(__name__)
@slash_command(name="kick", description="Kick a user") @slash_command(name="kick", description="Kick a user")
@slash_option(name="user", description="User to kick", opt_type=OptionTypes.USER, required=True) @slash_option(name="user", description="User to kick", opt_type=OptionTypes.USER, required=True)
@slash_option( @slash_option(

View file

@ -1,11 +1,10 @@
"""JARVIS MuteCog.""" """JARVIS MuteCog."""
import asyncio import asyncio
import logging
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from dateparser import parse from dateparser import parse
from dateparser_data.settings import default_parsers from dateparser_data.settings import default_parsers
from dis_snek import InteractionContext, Permissions, Snake from dis_snek import InteractionContext, Permissions
from dis_snek.client.errors import Forbidden from dis_snek.client.errors import Forbidden
from dis_snek.models.discord.embed import EmbedField from dis_snek.models.discord.embed import EmbedField
from dis_snek.models.discord.modal import InputText, Modal, TextStyles from dis_snek.models.discord.modal import InputText, Modal, TextStyles
@ -29,10 +28,6 @@ from jarvis.utils.permissions import admin_or_permissions
class MuteCog(ModcaseCog): class MuteCog(ModcaseCog):
"""JARVIS MuteCog.""" """JARVIS MuteCog."""
def __init__(self, bot: Snake):
super().__init__(bot)
self.logger = logging.getLogger(__name__)
async def _apply_timeout( async def _apply_timeout(
self, ctx: InteractionContext, user: Member, reason: str, until: datetime self, ctx: InteractionContext, user: Member, reason: str, until: datetime
) -> None: ) -> None:

View file

@ -1,8 +1,7 @@
"""JARVIS WarningCog.""" """JARVIS WarningCog."""
import logging
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from dis_snek import InteractionContext, Permissions, Snake from dis_snek import InteractionContext, Permissions
from dis_snek.client.utils.misc_utils import get_all from dis_snek.client.utils.misc_utils import get_all
from dis_snek.ext.paginators import Paginator from dis_snek.ext.paginators import Paginator
from dis_snek.models.discord.embed import EmbedField from dis_snek.models.discord.embed import EmbedField
@ -25,10 +24,6 @@ from jarvis.utils.permissions import admin_or_permissions
class WarningCog(ModcaseCog): class WarningCog(ModcaseCog):
"""JARVIS WarningCog.""" """JARVIS WarningCog."""
def __init__(self, bot: Snake):
super().__init__(bot)
self.logger = logging.getLogger(__name__)
@slash_command(name="warn", description="Warn a user") @slash_command(name="warn", description="Warn a user")
@slash_option(name="user", description="User to warn", opt_type=OptionTypes.USER, required=True) @slash_option(name="user", description="User to warn", opt_type=OptionTypes.USER, required=True)
@slash_option( @slash_option(

View file

@ -5,11 +5,11 @@ from io import BytesIO
import psutil import psutil
from aiofile import AIOFile, LineReader from aiofile import AIOFile, LineReader
from dis_snek import MessageContext, Scale, Snake from dis_snek import Scale, Snake
from dis_snek.client.errors import HTTPException from dis_snek.client.errors import HTTPException
from dis_snek.models.discord.embed import EmbedField from dis_snek.models.discord.embed import EmbedField
from dis_snek.models.discord.file import File from dis_snek.models.discord.file import File
from molter import msg_command from molter import MessageContext, msg_command
from rich.console import Console from rich.console import Console
from jarvis.utils import build_embed from jarvis.utils import build_embed
@ -64,13 +64,13 @@ class BotutilCog(Scale):
async def _sysinfo(self, ctx: MessageContext) -> None: async def _sysinfo(self, ctx: MessageContext) -> None:
st_ts = int(self.bot.start_time.timestamp()) st_ts = int(self.bot.start_time.timestamp())
ut_ts = int(psutil.boot_time()) ut_ts = int(psutil.boot_time())
fields = [ fields = (
EmbedField(name="Operation System", value=platform.system() or "Unknown", inline=False), EmbedField(name="Operation System", value=platform.system() or "Unknown", inline=False),
EmbedField(name="Version", value=platform.release() or "N/A", inline=False), EmbedField(name="Version", value=platform.release() or "N/A", inline=False),
EmbedField(name="System Start Time", value=f"<t:{ut_ts}:F> (<t:{ut_ts}:R>)"), EmbedField(name="System Start Time", value=f"<t:{ut_ts}:F> (<t:{ut_ts}:R>)"),
EmbedField(name="Python Version", value=platform.python_version()), EmbedField(name="Python Version", value=platform.python_version()),
EmbedField(name="Bot Start Time", value=f"<t:{st_ts}:F> (<t:{st_ts}:R>)"), EmbedField(name="Bot Start Time", value=f"<t:{st_ts}:F> (<t:{st_ts}:R>)"),
] )
embed = build_embed(title="System Info", description="", fields=fields) embed = build_embed(title="System Info", description="", fields=fields)
embed.set_image(url=self.bot.user.avatar.url) embed.set_image(url=self.bot.user.avatar.url)
await ctx.send(embed=embed) await ctx.send(embed=embed)
@ -108,7 +108,7 @@ class BotutilCog(Scale):
try: try:
await ctx.reply(f"```ansi\n{capture.get()}\n```", embed=embed) await ctx.reply(f"```ansi\n{capture.get()}\n```", embed=embed)
except HTTPException: except HTTPException:
await ctx.reply(f"Total Changes: {status.total_lines}", embed=embed) await ctx.reply(f"Total Changes: {status.lines['total_lines']}", embed=embed)
else: else:
embed = build_embed(title="Update Status", description="No changes applied", fields=[]) embed = build_embed(title="Update Status", description="No changes applied", fields=[])

View file

@ -85,7 +85,7 @@ class ImageCog(Scale):
if tgt_size > unconvert_bytesize(8, "MB"): if tgt_size > unconvert_bytesize(8, "MB"):
await ctx.send("Target too large to send. Please make target < 8MB", ephemeral=True) await ctx.send("Target too large to send. Please make target < 8MB", ephemeral=True)
return return
elif tgt_size < 1024: if tgt_size < 1024:
await ctx.send("Sizes < 1KB are extremely unreliable and are disabled", ephemeral=True) await ctx.send("Sizes < 1KB are extremely unreliable and are disabled", ephemeral=True)
return return

View file

@ -135,7 +135,7 @@ class StarboardCog(Scale):
if c and isinstance(c, GuildText): if c and isinstance(c, GuildText):
channel_list.append(c) channel_list.append(c)
else: else:
self.logger.warn( self.logger.warning(
f"Starboard {starboard.channel} no longer valid in {ctx.guild.name}" f"Starboard {starboard.channel} no longer valid in {ctx.guild.name}"
) )
to_delete.append(starboard) to_delete.append(starboard)

View file

@ -79,14 +79,16 @@ class VerifyCog(Scale):
role = await ctx.guild.fetch_role(setting.value) role = await ctx.guild.fetch_role(setting.value)
await ctx.author.add_role(role, reason="Verification passed") await ctx.author.add_role(role, reason="Verification passed")
except AttributeError: except AttributeError:
self.logger.warn("Verified role deleted before verification finished") self.logger.warning("Verified role deleted before verification finished")
setting = await Setting.find_one(q(guild=ctx.guild.id, setting="unverified")) setting = await Setting.find_one(q(guild=ctx.guild.id, setting="unverified"))
if setting: if setting:
try: try:
role = await ctx.guild.fetch_role(setting.value) role = await ctx.guild.fetch_role(setting.value)
await ctx.author.remove_role(role, reason="Verification passed") await ctx.author.remove_role(role, reason="Verification passed")
except AttributeError: except AttributeError:
self.logger.warn("Unverified role deleted before verification finished") self.logger.warning(
"Unverified role deleted before verification finished"
)
await response.context.edit_origin( await response.context.edit_origin(
content=f"Welcome, {ctx.author.mention}. Please enjoy your stay.", content=f"Welcome, {ctx.author.mention}. Please enjoy your stay.",

View file

@ -12,7 +12,7 @@ except ImportError:
class JarvisConfig(CConfig): class JarvisConfig(CConfig):
REQUIRED = ["token", "mongo", "urls"] REQUIRED = ("token", "mongo", "urls")
OPTIONAL = { OPTIONAL = {
"sync": False, "sync": False,
"log_level": "WARNING", "log_level": "WARNING",

View file

@ -1,24 +1,14 @@
"""JARVIS Utility Functions.""" """JARVIS Utility Functions."""
import importlib
import inspect
from datetime import datetime, timezone from datetime import datetime, timezone
from pkgutil import iter_modules from pkgutil import iter_modules
from types import ModuleType
from typing import Callable, Dict
import git import git
from dis_snek.client.utils.misc_utils import find_all
from dis_snek.models.discord.embed import Embed, EmbedField from dis_snek.models.discord.embed import Embed, EmbedField
from dis_snek.models.discord.guild import AuditLogEntry from dis_snek.models.discord.guild import AuditLogEntry
from dis_snek.models.discord.user import Member from dis_snek.models.discord.user import Member
from dis_snek.models.snek import Scale
from dis_snek.models.snek.application_commands import SlashCommand
import jarvis.cogs
from jarvis.config import get_config from jarvis.config import get_config
__all__ = ["cachecog", "permissions"]
def build_embed( def build_embed(
title: str, title: str,
@ -71,30 +61,11 @@ def modlog_embed(
return embed return embed
def get_extensions(path: str = jarvis.cogs.__path__) -> list: def get_extensions(path: str) -> list:
"""Get JARVIS cogs.""" """Get JARVIS cogs."""
config = get_config() config = get_config()
vals = config.cogs or [x.name for x in iter_modules(path)] vals = config.cogs or [x.name for x in iter_modules(path)]
return ["jarvis.cogs.{}".format(x) for x in vals] return [f"jarvis.cogs.{x}" for x in vals]
def get_all_commands(module: ModuleType = jarvis.cogs) -> Dict[str, Callable]:
commands = {}
for item in iter_modules(module.__path__):
new_module = importlib.import_module(f"{module.__name__}.{item.name}")
if item.ispkg:
if cmds := get_all_commands(new_module):
commands.update(cmds)
else:
inspect_result = inspect.getmembers(new_module)
cogs = []
for _, val in inspect_result:
if inspect.isclass(val) and issubclass(val, Scale) and val is not Scale:
cogs.append(val)
for cog in cogs:
values = cog.__dict__.values()
commands[cog.__module__] = find_all(lambda x: isinstance(x, SlashCommand), values)
return {k: v for k, v in commands.items() if v}
def update() -> int: def update() -> int:

View file

@ -1,11 +1,8 @@
"""Cog wrapper for command caching.""" """Cog wrapper for command caching."""
from datetime import datetime, timedelta, timezone import logging
from dis_snek import InteractionContext, Scale, Snake from dis_snek import InteractionContext, Scale, Snake
from dis_snek.client.utils.misc_utils import find
from dis_snek.models.discord.embed import EmbedField from dis_snek.models.discord.embed import EmbedField
from dis_snek.models.snek.tasks.task import Task
from dis_snek.models.snek.tasks.triggers import IntervalTrigger
from jarvis_core.db import q from jarvis_core.db import q
from jarvis_core.db.models import ( from jarvis_core.db.models import (
Action, Action,
@ -24,42 +21,15 @@ MODLOG_LOOKUP = {"Ban": Ban, "Kick": Kick, "Mute": Mute, "Warning": Warning}
IGNORE_COMMANDS = {"Ban": ["bans"], "Kick": [], "Mute": ["unmute"], "Warning": ["warnings"]} IGNORE_COMMANDS = {"Ban": ["bans"], "Kick": [], "Mute": ["unmute"], "Warning": ["warnings"]}
class CacheCog(Scale):
"""Cog wrapper for command caching."""
def __init__(self, bot: Snake):
self.bot = bot
self.cache = {}
self._expire_interaction.start()
def check_cache(self, ctx: InteractionContext, **kwargs: dict) -> dict:
"""Check the cache."""
if not kwargs:
kwargs = {}
return find(
lambda x: x["command"] == ctx.subcommand_name # noqa: W503
and x["user"] == ctx.author.id # noqa: W503
and x["guild"] == ctx.guild.id # noqa: W503
and all(x[k] == v for k, v in kwargs.items()), # noqa: W503
self.cache.values(),
)
@Task.create(IntervalTrigger(minutes=1))
async def _expire_interaction(self) -> None:
keys = list(self.cache.keys())
for key in keys:
if self.cache[key]["timeout"] <= datetime.now(tz=timezone.utc) + timedelta(minutes=1):
del self.cache[key]
class ModcaseCog(Scale): class ModcaseCog(Scale):
"""Cog wrapper for moderation case logging.""" """Cog wrapper for moderation case logging."""
def __init__(self, bot: Snake): def __init__(self, bot: Snake):
self.bot = bot self.bot = bot
self.logger = logging.getLogger(__name__)
self.add_scale_postrun(self.log) self.add_scale_postrun(self.log)
async def log(self, ctx: InteractionContext, *args: list, **kwargs: dict) -> None: async def log(self, ctx: InteractionContext, *_args: list, **kwargs: dict) -> None:
""" """
Log a moderation activity in a moderation case. Log a moderation activity in a moderation case.
@ -71,31 +41,31 @@ class ModcaseCog(Scale):
if name in MODLOG_LOOKUP and ctx.command not in IGNORE_COMMANDS[name]: if name in MODLOG_LOOKUP and ctx.command not in IGNORE_COMMANDS[name]:
user = kwargs.pop("user", None) user = kwargs.pop("user", None)
if not user and not ctx.target_id: if not user and not ctx.target_id:
self.logger.warn(f"Admin action {name} missing user, exiting") self.logger.warning("Admin action %s missing user, exiting", name)
return return
elif ctx.target_id: if ctx.target_id:
user = ctx.target user = ctx.target
coll = MODLOG_LOOKUP.get(name, None) coll = MODLOG_LOOKUP.get(name, None)
if not coll: if not coll:
self.logger.warn(f"Unsupported action {name}, exiting") self.logger.warning("Unsupported action %s, exiting", name)
return return
action = await coll.find_one(q(user=user.id, guild=ctx.guild_id, active=True)) action = await coll.find_one(q(user=user.id, guild=ctx.guild_id, active=True))
if not action: if not action:
self.logger.warn(f"Missing action {name}, exiting") self.logger.warning("Missing action %s, exiting", name)
return return
action = Action(action_type=name.lower(), parent=action.id) action = Action(action_type=name.lower(), parent=action.id)
note = Note(admin=self.bot.user.id, content="Moderation case opened automatically") note = Note(admin=self.bot.user.id, content="Moderation case opened automatically")
await Modlog(user=user.id, admin=ctx.author.id, actions=[action], notes=[note]).commit() await Modlog(user=user.id, admin=ctx.author.id, actions=[action], notes=[note]).commit()
notify = await Setting.find_one(q(guild=ctx.guild.id, setting="notify", value=True)) notify = await Setting.find_one(q(guild=ctx.guild.id, setting="notify", value=True))
if notify and name not in ["Kick", "Ban"]: # Ignore Kick and Ban, as these are unique if notify and name not in ("Kick", "Ban"): # Ignore Kick and Ban, as these are unique
fields = [ fields = (
EmbedField(name="Action Type", value=name, inline=False), EmbedField(name="Action Type", value=name, inline=False),
EmbedField( EmbedField(
name="Reason", value=kwargs.get("reason", None) or "N/A", inline=False name="Reason", value=kwargs.get("reason", None) or "N/A", inline=False
), ),
] )
embed = build_embed( embed = build_embed(
title="Admin action taken", title="Admin action taken",
description=f"Admin action has been taken against you in {ctx.guild.name}", description=f"Admin action has been taken against you in {ctx.guild.name}",

View file

@ -13,7 +13,7 @@ def warning_embed(user: Member, reason: str) -> Embed:
user: User to warn user: User to warn
reason: Warning reason reason: Warning reason
""" """
fields = [EmbedField(name="Reason", value=reason, inline=False)] fields = (EmbedField(name="Reason", value=reason, inline=False),)
embed = build_embed( embed = build_embed(
title="Warning", description=f"{user.mention} has been warned", fields=fields title="Warning", description=f"{user.mention} has been warned", fields=fields
) )

View file

@ -1,11 +1,11 @@
"""JARVIS update handler.""" """JARVIS update handler."""
import asyncio import asyncio
import importlib
import inspect
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from importlib import import_module
from inspect import getmembers, isclass
from pkgutil import iter_modules from pkgutil import iter_modules
from types import ModuleType from types import FunctionType, ModuleType
from typing import TYPE_CHECKING, Callable, Dict, List, Optional from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import git import git
@ -19,7 +19,7 @@ import jarvis.cogs
if TYPE_CHECKING: if TYPE_CHECKING:
from dis_snek.client.client import Snake from dis_snek.client.client import Snake
logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@dataclass @dataclass
@ -32,49 +32,57 @@ class UpdateResult:
added: List[str] added: List[str]
removed: List[str] removed: List[str]
changed: List[str] changed: List[str]
inserted_lines: int lines: Dict[str, int]
deleted_lines: int
total_lines: int
def get_all_commands(module: ModuleType = jarvis.cogs) -> Dict[str, Callable]: def get_all_commands(module: ModuleType = jarvis.cogs) -> Dict[str, Callable]:
"""Get all SlashCommands from a specified module.""" """Get all SlashCommands from a specified module."""
commands = {} commands = {}
def validate_ires(entry: tuple) -> bool:
return isclass(entry[1]) and issubclass(entry[1], Scale) and entry[1] is not Scale
def validate_cog(cog: FunctionType) -> bool:
return isinstance(cog, SlashCommand)
for item in iter_modules(module.__path__): for item in iter_modules(module.__path__):
new_module = importlib.import_module(f"{module.__name__}.{item.name}") new_module = import_module(f"{module.__name__}.{item.name}")
if item.ispkg: if item.ispkg:
if cmds := get_all_commands(new_module): if cmds := get_all_commands(new_module):
commands.update(cmds) commands.update(cmds)
else: else:
inspect_result = inspect.getmembers(new_module) inspect_result = getmembers(new_module)
cogs = [] cogs = find_all(validate_ires, inspect_result)
for _, val in inspect_result: commands.update(
if inspect.isclass(val) and issubclass(val, Scale) and val is not Scale: {
cogs.append(val) commands[cog.__module__]: find_all(validate_cog, cog.__dict__.values())
for cog in cogs: for cog in cogs
values = cog.__dict__.values() }
commands[cog.__module__] = find_all(lambda x: isinstance(x, SlashCommand), values) )
return {k: v for k, v in commands.items() if v} return {k: v for k, v in commands.items() if v}
def get_git_changes(repo: git.Repo) -> dict: def get_git_changes(repo: git.Repo) -> dict:
"""Get all Git changes""" """Get all Git changes"""
logger = _logger
logger.debug("Getting all git changes") logger.debug("Getting all git changes")
head = repo.head.ref current_hash = repo.head.ref.object.hexsha
current_hash = head.object.hexsha tracking = repo.head.ref.tracking_branch()
tracking = head.tracking_branch()
file_changes = {} file_changes = {}
for commit in tracking.commit.iter_items(repo, f"{head.path}..{tracking.path}"): for commit in tracking.commit.iter_items(repo, f"{repo.head.ref.path}..{tracking.path}"):
if commit.hexsha == current_hash: if commit.hexsha == current_hash:
break break
files = commit.stats.files files = commit.stats.files
file_changes.update(
{key: {"insertions": 0, "deletions": 0, "lines": 0} for key in files.keys()}
)
for file, stats in files.items(): for file, stats in files.items():
if file not in file_changes: if file not in file_changes:
file_changes[file] = {"insertions": 0, "deletions": 0, "lines": 0} file_changes[file] = {"insertions": 0, "deletions": 0, "lines": 0}
for k, v in stats.items(): for key, val in stats.items():
file_changes[file][k] += v file_changes[file][key] += val
logger.debug(f"Found {len(file_changes)} changed files") logger.debug("Found %i changed files", len(file_changes))
table = Table(title="File Changes") table = Table(title="File Changes")
@ -96,14 +104,12 @@ def get_git_changes(repo: git.Repo) -> dict:
str(stats["deletions"]), str(stats["deletions"]),
str(stats["lines"]), str(stats["lines"]),
) )
logger.debug(f"{i_total} insertions, {d_total} deletions, {l_total} total") logger.debug("%i insertions, %i deletions, %i total", i_total, d_total, l_total)
table.add_row("Total", str(i_total), str(d_total), str(l_total)) table.add_row("Total", str(i_total), str(d_total), str(l_total))
return { return {
"table": table, "table": table,
"inserted_lines": i_total, "lines": {"inserted_lines": i_total, "deleted_lines": d_total, "total_lines": l_total},
"deleted_lines": d_total,
"total_lines": l_total,
} }
@ -117,6 +123,7 @@ async def update(bot: "Snake") -> Optional[UpdateResult]:
Returns: Returns:
UpdateResult object UpdateResult object
""" """
logger = _logger
repo = git.Repo(".") repo = git.Repo(".")
current_hash = repo.head.object.hexsha current_hash = repo.head.object.hexsha
origin = repo.remotes.origin origin = repo.remotes.origin
@ -124,7 +131,7 @@ async def update(bot: "Snake") -> Optional[UpdateResult]:
remote_hash = origin.refs[repo.active_branch.name].object.hexsha remote_hash = origin.refs[repo.active_branch.name].object.hexsha
if current_hash != remote_hash: if current_hash != remote_hash:
logger.info(f"Updating from {current_hash} to {remote_hash}") logger.info("Updating from %s to %s", current_hash, remote_hash)
current_commands = get_all_commands() current_commands = get_all_commands()
changes = get_git_changes(repo) changes = get_git_changes(repo)
@ -142,13 +149,13 @@ async def update(bot: "Snake") -> Optional[UpdateResult]:
logger.debug("Checking for removed cogs") logger.debug("Checking for removed cogs")
for module in current_commands.keys(): for module in current_commands.keys():
if module not in new_commands: if module not in new_commands:
logger.debug(f"Module {module} removed after update") logger.debug("Module %s removed after update", module)
bot.shed_scale(module) bot.shed_scale(module)
unloaded.append(module) unloaded.append(module)
logger.debug("Checking for new/modified commands") logger.debug("Checking for new/modified commands")
for module, commands in new_commands.items(): for module, commands in new_commands.items():
logger.debug(f"Processing {module}") logger.debug("Processing %s", module)
if module not in current_commands: if module not in current_commands:
bot.grow_scale(module) bot.grow_scale(module)
loaded.append(module) loaded.append(module)