jarvis-bot/jarvis/utils/__init__.py

101 lines
2.6 KiB
Python

"""J.A.R.V.I.S. Utility Functions."""
from datetime import datetime
from pkgutil import iter_modules
import git
from discord import Color, Embed, Message
from discord.ext import commands
import jarvis.cogs
import jarvis.db
from jarvis.config import get_config
# Names exported by ``from jarvis.utils import *``.
# NOTE(review): none of these names is defined in the visible part of this
# module; presumably they are submodules of this package -- confirm.
__all__ = ["field", "db", "cachecog", "permissions"]
def convert_bytesize(b: int) -> str:
    """Render a byte count as a human-readable string.

    The amount is divided by 1024 until it drops below 1024 or the largest
    known unit (PB) is reached, then formatted with three decimal places
    followed by the unit, e.g. ``"1.500 KB"``.
    """
    units = ("B", "KB", "MB", "GB", "TB", "PB")
    amount = float(b)
    unit = units[0]
    # Step up one unit at a time for as long as the amount still warrants it.
    for larger in units[1:]:
        if amount >= 1024:
            amount = amount / 1024
            unit = larger
        else:
            break
    return "{:0.3f} {}".format(amount, unit)
def unconvert_bytesize(size: int, ending: str) -> int:
    """Convert a human-readable size back to a byte count.

    Args:
        size: Numeric amount expressed in the unit named by ``ending``.
        ending: Unit suffix, matched case-insensitively and with surrounding
            whitespace ignored. One of B, KB, MB, GB, TB, PB.

    Returns:
        The amount in bytes as a whole number. Every unit, including plain
        bytes, is rounded so the result is always an ``int``.

    Raises:
        ValueError: If ``ending`` is not one of the known units.
    """
    units = ("B", "KB", "MB", "GB", "TB", "PB")
    unit = ending.strip().upper()
    try:
        exponent = units.index(unit)
    except ValueError:
        raise ValueError(
            "Unknown byte size unit: {!r}".format(ending)
        ) from None
    # Rounding is only because bytes cannot be partial; it is applied to the
    # "B" case as well so a fractional input never leaks out as a float.
    return round(size * (1024 ** exponent))
def get_prefix(bot: commands.Bot, message: Message) -> list:
    """Resolve the command prefixes that apply to ``message``.

    Builds a ``commands.when_mentioned_or`` resolver for the fixed prefixes
    "!", "-" and "%" and returns what it yields for this bot and message.
    """
    # A separate "?" prefix for messages without a guild existed here as
    # commented-out code; it is currently not applied.
    resolver = commands.when_mentioned_or("!", "-", "%")
    return resolver(bot, message)
def get_extensions(path: str = jarvis.cogs.__path__) -> list:
    """List the dotted module paths of the J.A.R.V.I.S. cogs.

    The cog names come from ``get_config().cogs`` when that value is truthy;
    otherwise every module that ``iter_modules`` finds under ``path`` is
    used. Each name is returned prefixed with ``jarvis.cogs.``.
    """
    names = get_config().cogs
    if not names:
        # Nothing configured: fall back to discovering modules under `path`.
        names = [module.name for module in iter_modules(path)]
    return ["jarvis.cogs.{}".format(name) for name in names]
def parse_color_hex(hex: str) -> Color:
    """Convert a hex color string to a d.py Color.

    Args:
        hex: Color as ``RRGGBB`` or the three-digit shorthand ``RGB``, with
            any leading ``#`` characters ignored.

    Returns:
        The ``Color`` built by ``Color.from_rgb`` from the three channels.

    Raises:
        ValueError: If a two-character channel slice is empty or is not
            valid hexadecimal.
    """
    digits = hex.lstrip("#")
    if len(digits) == 3:
        # Expand shorthand ("f0a" -> "ff00aa") so each channel has two digits.
        digits = "".join(ch * 2 for ch in digits)
    red, green, blue = (
        int(digits[i : i + 2], 16) for i in (0, 2, 4)  # noqa: E203
    )
    return Color.from_rgb(red, green, blue)
def build_embed(
    title: str,
    description: str,
    fields: list,
    color: str = "#FF0000",
    timestamp: datetime = None,
    **kwargs: dict,
) -> Embed:
    """Embed builder utility function.

    Args:
        title: Embed title.
        description: Embed description.
        fields: Objects exposing ``to_dict()``; each resulting mapping is
            passed as keyword arguments to ``Embed.add_field``. ``None`` is
            treated as no fields.
        color: Hex color string, converted with ``parse_color_hex``.
        timestamp: Embed timestamp; the current UTC time is used when this
            is not provided.
        **kwargs: Extra keyword arguments forwarded to the ``Embed``
            constructor.

    Returns:
        The constructed ``Embed`` with all fields added.
    """
    if not timestamp:
        # Local import keeps this block self-contained; the module only
        # imports `datetime` from the datetime package.
        from datetime import timezone

        # datetime.utcnow() is deprecated and returns a naive value; use a
        # timezone-aware UTC timestamp so it cannot be misread as local time.
        timestamp = datetime.now(timezone.utc)
    embed = Embed(
        title=title,
        description=description,
        color=parse_color_hex(color),
        timestamp=timestamp,
        **kwargs,
    )
    for entry in fields or ():
        embed.add_field(**entry.to_dict())
    return embed
def update() -> int:
    """Bring the checkout in the current directory up to date with origin.

    Fetches ``origin`` and compares the local HEAD commit hash with the one
    at ``origin``'s ``main`` ref.

    Returns:
        0 if the hashes differed and a pull was performed,
        1 if the hashes already match (nothing to do),
        2 if the hashes differ but the repository is dirty, so no pull
        was attempted.
    """
    repo = git.Repo(".")
    has_local_changes = repo.is_dirty()
    local_sha = repo.head.object.hexsha
    remote = repo.remotes.origin
    remote.fetch()
    remote_sha = remote.refs["main"].object.hexsha
    if local_sha == remote_sha:
        return 1
    if has_local_changes:
        return 2
    remote.pull()
    return 0
def get_repo_hash() -> str:
    """Return the commit hash that HEAD points at in the current directory's repo."""
    head_commit = git.Repo(".").head.object
    return head_commit.hexsha