diff --git a/jarvis/utils/updates.py b/jarvis/utils/updates.py index da9ba68..726c777 100644 --- a/jarvis/utils/updates.py +++ b/jarvis/utils/updates.py @@ -6,7 +6,7 @@ from importlib import import_module from inspect import getmembers, isclass from pkgutil import iter_modules from types import FunctionType, ModuleType -from typing import TYPE_CHECKING, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import git from dis_snek.client.utils.misc_utils import find, find_all @@ -39,8 +39,8 @@ def get_all_commands(module: ModuleType = jarvis.cogs) -> Dict[str, Callable]: """Get all SlashCommands from a specified module.""" commands = {} - def validate_ires(entry: tuple) -> bool: - return isclass(entry[1]) and issubclass(entry[1], Scale) and entry[1] is not Scale + def validate_ires(entry: Any) -> bool: + return isclass(entry) and issubclass(entry, Scale) and entry is not Scale def validate_cog(cog: FunctionType) -> bool: return isinstance(cog, SlashCommand) @@ -52,13 +52,13 @@ def get_all_commands(module: ModuleType = jarvis.cogs) -> Dict[str, Callable]: commands.update(cmds) else: inspect_result = getmembers(new_module) - cogs = find_all(validate_ires, inspect_result) - commands.update( - { - commands[cog.__module__]: find_all(validate_cog, cog.__dict__.values()) - for cog in cogs - } - ) + cogs = [] + for _, val in inspect_result: + if validate_ires(val): + cogs.append(val) + for cog in cogs: + values = cog.__dict__.values() + commands[cog.__module__] = find_all(lambda x: isinstance(x, SlashCommand), values) return {k: v for k, v in commands.items() if v}