187 lines
4.7 KiB
Python
187 lines
4.7 KiB
Python
"""J.A.R.V.I.S. Utility Functions."""
|
|
from datetime import datetime
|
|
from pkgutil import iter_modules
|
|
from typing import Any, Callable, Iterable, List, Optional, TypeVar
|
|
|
|
import git
|
|
from dis_snek.models.discord.embed import Embed
|
|
|
|
import jarvis.cogs
|
|
import jarvis.db
|
|
from jarvis.config import get_config
|
|
|
|
# Names exported by a star-import of this module. None of them is defined in
# this file; presumably they are sibling submodules — verify before editing.
__all__ = [
    "field",
    "db",
    "cachecog",
    "permissions",
]

# Generic type variable; not used in this file, presumably imported elsewhere.
T = TypeVar("T")
|
|
|
|
|
|
def build_embed(
    title: str,
    description: str,
    fields: Iterable,
    color: str = "#FF0000",
    timestamp: Optional[datetime] = None,
    **kwargs: Any,
) -> Embed:
    """
    Embed builder utility function.

    Args:
        title: Embed title
        description: Embed description
        fields: Field objects to attach, in order; each must provide a
            ``to_dict()`` method whose result is passed as keyword arguments
            to ``Embed.add_field``
        color: Embed color, ``"#FF0000"`` by default
        timestamp: Embed timestamp; when omitted, ``datetime.now()`` (a naive
            local-time datetime) is used
        **kwargs: Extra keyword arguments passed straight through to ``Embed``

    Returns:
        The constructed Embed with all fields added
    """
    if timestamp is None:
        timestamp = datetime.now()
    embed = Embed(
        title=title,
        description=description,
        color=color,
        timestamp=timestamp,
        **kwargs,
    )
    for item in fields:
        embed.add_field(**item.to_dict())
    return embed
|
|
|
|
|
|
def convert_bytesize(b: int) -> str:
    """
    Convert a byte count to a human-readable string.

    The value is divided by 1024 until it is below 1024 or the largest unit
    (PB) is reached, then formatted with three decimal places.

    Args:
        b: Number of bytes

    Returns:
        A string such as ``"1.500 KB"``
    """
    units = ["B", "KB", "MB", "GB", "TB", "PB"]
    last_index = len(units) - 1
    value = float(b)
    index = 0
    while value >= 1024 and index < last_index:
        value /= 1024
        index += 1
    return f"{value:0.3f} {units[index]}"
|
|
|
|
|
|
def unconvert_bytesize(size: float, ending: str) -> int:
    """
    Convert a human-readable size back to a whole number of bytes.

    Args:
        size: Amount expressed in ``ending`` units; may be fractional
            (e.g. ``1.5`` with ``"GB"``)
        ending: Unit suffix, case-insensitive; one of B, KB, MB, GB, TB, PB

    Returns:
        The size in bytes, rounded to the nearest integer

    Raises:
        ValueError: If ``ending`` is not one of the supported units
    """
    sizes = ["B", "KB", "MB", "GB", "TB", "PB"]
    try:
        power = sizes.index(ending.upper())
    except ValueError:
        raise ValueError(f"Unknown size unit: {ending!r}") from None
    # Rounding is only because bytes cannot be partial. This also applies to
    # plain "B" input (power 0), so fractional byte counts still yield an int.
    return round(size * (1024 ** power))
|
|
|
|
|
|
def get_extensions(path: Iterable[str] = jarvis.cogs.__path__) -> List[str]:
    """
    Get J.A.R.V.I.S. cogs.

    If the config's ``cogs`` setting is truthy, those names are used.
    Otherwise every module discovered under ``path`` is included.

    Args:
        path: Directories to scan for cog modules, by default the
            ``jarvis.cogs`` package path. A single path (``str`` or
            ``os.PathLike``) is also accepted.

    Returns:
        Dotted extension names of the form ``jarvis.cogs.<name>``
    """
    import os

    # pkgutil.iter_modules expects an iterable of paths; a bare string would
    # be iterated character by character, so wrap a single path in a list.
    if isinstance(path, (str, os.PathLike)):
        path = [os.fspath(path)]
    config = get_config()
    vals = config.cogs or [x.name for x in iter_modules(path)]
    return ["jarvis.cogs.{}".format(x) for x in vals]
|
|
|
|
|
|
def update() -> int:
    """
    J.A.R.V.I.S. update utility.

    Fetches ``origin`` and compares the current HEAD commit with
    ``origin/main``. The check is plain inequality, so any difference
    (including local commits not on ``origin/main``) counts as an update.

    Returns:
        0 if HEAD differed from ``origin/main`` and ``origin.pull()`` was run,
        1 if HEAD already matches ``origin/main`` (nothing to do),
        2 if HEAD differs but the working tree is dirty, so no pull was attempted
    """
    # Close the Repo when done so GitPython releases the helper git processes
    # it spawns; presumably this is called from a long-running bot process.
    with git.Repo(".") as repo:
        dirty = repo.is_dirty()
        current_hash = repo.head.object.hexsha
        origin = repo.remotes.origin
        origin.fetch()
        if current_hash == origin.refs["main"].object.hexsha:
            return 1
        if dirty:
            return 2
        origin.pull()
        return 0
|
|
|
|
|
|
def get_repo_hash() -> str:
    """
    J.A.R.V.I.S. current commit hash.

    Returns:
        Hex SHA of the commit HEAD points to, for the repository in the
        current working directory
    """
    # Close the Repo afterwards so GitPython does not keep helper git
    # processes alive between calls.
    with git.Repo(".") as repo:
        return repo.head.object.hexsha
|
|
|
|
|
|
def find(predicate: Callable, sequence: Iterable) -> Optional[Any]:
    """
    Return the first element of a sequence for which the predicate is truthy.

    ??? Hint "Example Usage:"
        ```python
        member = find(lambda m: m.name == "UserName", guild.members)
        ```

    Args:
        predicate: A callable applied to each element, returning a boolean value
        sequence: A sequence to be searched; iteration stops at the first match

    Returns:
        The first matching element, otherwise None
    """
    for candidate in sequence:
        if not predicate(candidate):
            continue
        return candidate
    return None
|
|
|
|
|
|
def find_all(predicate: Callable, sequence: Iterable) -> List[Any]:
    """
    Collect every element of a sequence for which the predicate is truthy.

    ??? Hint "Example Usage:"
        ```python
        members = find_all(lambda m: m.name == "UserName", guild.members)
        ```

    Args:
        predicate: A callable applied to each element, returning a boolean value
        sequence: A sequence to be searched

    Returns:
        A new list of the matching elements, in their original order
    """
    return [item for item in sequence if predicate(item)]
|
|
|
|
|
|
def get(sequence: Iterable, **kwargs: Any) -> Optional[Any]:
    """
    Find the first element in a sequence that matches all attrs.

    ??? Hint "Example Usage:"
        ```python
        channel = get(guild.channels, nsfw=False, category="General")
        ```

    Args:
        sequence: A sequence (any iterable) to be searched
        kwargs: Attribute names and the values they must equal

    Returns:
        A match if found, otherwise None. With no kwargs, the first element
        of ``sequence``, or None if it is empty.
    """
    if not kwargs:
        # next(iter(...)) works for any iterable, not just indexable ones,
        # and yields None for empty input instead of raising IndexError.
        return next(iter(sequence), None)

    for el in sequence:
        # Elements lacking any requested attribute are skipped, not errors.
        if any(not hasattr(el, attr) for attr in kwargs):
            continue
        if all(getattr(el, attr) == value for attr, value in kwargs.items()):
            return el
    return None
|
|
|
|
|
|
def get_all(sequence: Iterable, **kwargs: Any) -> List[Any]:
    """
    Find all elements in a sequence that match all attrs.

    ??? Hint "Example Usage:"
        ```python
        channels = get_all(guild.channels, nsfw=False, category="General")
        ```

    Args:
        sequence: A sequence (any iterable) to be searched
        kwargs: Attribute names and the values they must equal

    Returns:
        A new list of matches, in their original order. With no kwargs, a
        list of every element of ``sequence``.
    """
    if not kwargs:
        # Always return a fresh list: the old code returned the input object
        # itself (a generator stayed a generator, a list was aliased).
        return list(sequence)

    # Elements lacking any requested attribute are skipped, not errors.
    return [
        el
        for el in sequence
        if all(hasattr(el, attr) for attr in kwargs)
        and all(getattr(el, attr) == value for attr, value in kwargs.items())
    ]
|