Compare commits

..

No commits in common. "e863f5c07778cc424d7e5682d7a1651deb93fbe3" and "ec4219e5a54bea78ff19f23f1754a036e8d0eae3" have entirely different histories.

26 changed files with 1623 additions and 2053 deletions

22
.flake8
View file

@ -3,16 +3,18 @@ exclude =
tests/*
extend-ignore =
Q0, E501, C812, E203, W503,
ANN1, ANN003,
ANN204, ANN206,
D105, D107,
S311,
D401,
D400,
D101, D102,
D106,
R503,
Q0, E501, C812, E203, W503, # These default to arguing with Black. We might configure some of them eventually
ANN1, # Ignore self and cls annotations
ANN204, ANN206, # return annotations for special methods and class methods
D105, D107, # Missing Docstrings in magic method and __init__
S311, # Standard pseudo-random generators are not suitable for security/cryptographic purposes.
D401, # First line should be in imperative mood; try rephrasing
D400, # First line should end with a period
D101, # Missing docstring in public class
D106, # Missing docstring in public nested class
# Plugins we don't currently include: flake8-return
R503, # missing explicit return at the end of function ableto return non-None value.
max-line-length=100

View file

@ -1,50 +0,0 @@
precommit:
stage: test
image: python:3.12-bookworm
before_script:
- apt-get update && apt-get install -y --no-install-recommends git
script:
- pip install -r requirements.precommit.txt
- pre-commit run --all-files
rules:
- if: $CI_PIPELINE_SOURCE == 'merge_request_event'
.test_template: &test_template
stage: test
script:
- pip install poetry
- poetry install
- source `poetry env info --path`/bin/activate
- python -m pytest
test python3.10:
<<: *test_template
image: python:3.10-slim
test python3.11:
<<: *test_template
image: python:3.11-slim
test python3.12:
<<: *test_template
image: python:3.12-slim
coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
artifacts:
reports:
coverage_report:
coverage_format: cobertura
path: coverage.xml
release:
stage: build
rules:
- if: $CI_COMMIT_TAG
script:
- pip install poetry
- poetry build
- poetry config repositories.gitlab "${CI_API_V4_URL}/projects/${CI_PROJECT_ID}/packages/pypi"
- poetry config http-basic.gitlab gitlab-ci-token "$CI_JOB_TOKEN"
- poetry publish --repository gitlab
include:
- template: Jobs/SAST.gitlab-ci.yml

View file

@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.1.0
hooks:
- id: check-toml
- id: check-yaml
@ -9,19 +9,21 @@ repos:
- id: requirements-txt-fixer
- id: end-of-file-fixer
- id: debug-statements
language_version: python3.10
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
rev: v1.9.0
hooks:
- id: python-check-blanket-noqa
- repo: https://github.com/psf/black
rev: 23.7.0
rev: 22.3.0
hooks:
- id: black
args: [--line-length=100]
args: [--line-length=100, --target-version=py310]
language_version: python3.10
- repo: https://github.com/pre-commit/mirrors-isort
rev: v5.10.1
@ -30,7 +32,7 @@ repos:
args: ["--profile", "black"]
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
rev: 4.0.1
hooks:
- id: flake8
additional_dependencies:
@ -44,3 +46,4 @@ repos:
- flake8-deprecated
- flake8-print
- flake8-return
language_version: python3.10

View file

@ -1,4 +1,3 @@
{
"python.formatting.provider": "black",
"python.analysis.typeCheckingMode": "off"
"python.formatting.provider": "black"
}

74
jarvis_core/config.py Normal file
View file

@ -0,0 +1,74 @@
"""Load global config."""
import os
from lib2to3.pgen2 import token
from pathlib import Path
from typing import Union
from dotenv import load_dotenv
from yaml import load
from jarvis_core.util import Singleton, find_all
try:
from yaml import CLoader as Loader
except ImportError:
from yaml import Loader
DEFAULT_YAML_PATH = Path("config.yaml")
DEFAULT_ENV_PATH = Path(".env")
class Config(Singleton):
REQUIRED = []
OPTIONAL = {}
ENV_REQUIRED = []
ENV_OPTIONAL = {}
@classmethod
def _process_env(cls, **kwargs) -> dict:
"""Process environment variables into standard arguments"""
@classmethod
def from_env(cls, filepath: Union[Path, str] = DEFAULT_ENV_PATH) -> "Config":
"""Loag the environment config."""
if inst := cls.__dict__.get("inst"):
return inst
load_dotenv(filepath)
data = {}
for item in cls.ENV_REQUIRED:
data[item] = os.environ.get(item, None)
for item, default in cls.ENV_OPTIONAL.items():
data[item] = os.environ.get(item, default)
data = cls._process_env(**data)
return cls(**data)
@classmethod
def from_yaml(cls, filepath: Union[Path, str] = DEFAULT_YAML_PATH) -> "Config":
"""Load the yaml config file."""
if inst := cls.__dict__.get("inst"):
return inst
if isinstance(filepath, str):
filepath = Path(filepath)
with filepath.open() as f:
raw = f.read()
y = load(raw, Loader=Loader)
return cls(**y)
@classmethod
def load(cls) -> "Config":
if DEFAULT_ENV_PATH.exists():
return cls.from_env()
else:
return cls.from_yaml()
@classmethod
def reload(cls) -> bool:
"""Reload the config."""
return cls.__dict__.pop("inst", None) is None

View file

@ -1,34 +1,101 @@
"""JARVIS database models and utilities."""
from datetime import timezone
from beanie import init_beanie
from bson import ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from pytz import utc
from umongo.frameworks import MotorAsyncIOInstance
from jarvis_core.db.models import all_models
from jarvis_core.util import find
CLIENT = None
JARVISDB = None
CTC2DB = None
JARVIS_INST = MotorAsyncIOInstance()
CTC2_INST = MotorAsyncIOInstance()
async def connect(
host: list[str] | str,
def connect(
username: str,
password: str,
port: int = 27017,
testing: bool = False,
extra_models: list = None,
host: str = None,
hosts: list[str] = None,
replicaset: str = None,
) -> None:
"""
Connect to MongoDB.
Args:
host: Hostname/IP, or list of hosts for replica sets
host: Hostname/IP
username: Username
password: Password
port: Port
testing: Whether or not to use jarvis_dev
extra_models: Extra beanie models to register
"""
extra_models = extra_models or []
client = AsyncIOMotorClient(
host, username=username, password=password, port=port, tz_aware=True, tzinfo=timezone.utc
global CLIENT, JARVISDB, CTC2DB, JARVIS_INST, CTC2_INST
if not replicaset:
CLIENT = AsyncIOMotorClient(
host=host, username=username, password=password, port=port, tz_aware=True, tzinfo=utc
)
db = client.jarvis_dev if testing else client.jarvis
await init_beanie(database=db, document_models=all_models + extra_models)
else:
CLIENT = AsyncIOMotorClient(
hosts, username=username, password=password, tz_aware=True, tzinfo=utc, replicaset=replicaset
)
JARVISDB = CLIENT.narvis if testing else CLIENT.jarvis
CTC2DB = CLIENT.ctc2
JARVIS_INST.set_db(JARVISDB)
CTC2_INST.set_db(CTC2DB)
QUERY_OPS = ["ne", "lt", "lte", "gt", "gte", "not", "in", "nin", "mod", "all", "size"]
STRING_OPS = [
"exact",
"iexact",
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
"wholeword",
"iwholeword",
"regex",
"iregex" "match",
]
GEO_OPS = [
"get_within",
"geo_within_box",
"geo_within_polygon",
"geo_within_center",
"geo_within_sphere",
"geo_intersects",
"near",
"within_distance",
"within_spherical_distance",
"near_sphere",
"within_box",
"within_polygon",
"max_distance",
"min_distance",
]
ALL_OPS = QUERY_OPS + STRING_OPS + GEO_OPS
def q(**kwargs: dict) -> dict:
"""uMongo query wrapper.""" # noqa: D403
query = {}
for key, value in kwargs.items():
if key == "_id":
value = ObjectId(value)
elif "__" in key:
args = key.split("__")
if not any(x in ALL_OPS for x in args):
key = ".".join(args)
else:
idx = args.index(find(lambda x: x in ALL_OPS, args))
key = ".".join(args[:idx])
value = {f"${args[idx]}": value}
query[key] = value
return query

26
jarvis_core/db/fields.py Normal file
View file

@ -0,0 +1,26 @@
import bson
import marshmallow as ma
from marshmallow import fields as ma_fields
from umongo import fields
class BinaryField(fields.BaseField, ma_fields.Field):
default_error_messages = {"invalid": "Not a valid byte sequence."}
def _serialize(self, value, attr, data, **kwargs):
return bytes(value)
def _deserialize(self, value, attr, data, **kwargs):
if not isinstance(value, bytes):
self.fail("invalid")
return value
def _serialize_to_mongo(self, obj):
return bson.binary.Binary(obj)
def _deserialize_from_mongo(self, value):
return bytes(value)
class RawField(fields.BaseField, ma_fields.Raw):
pass

View file

@ -1,279 +1,257 @@
"""JARVIS database models."""
import re
from datetime import datetime
from typing import Optional
from typing import Any, List
from beanie import Document, Link
from pydantic import BaseModel, Field
import marshmallow as ma
from umongo import Document, EmbeddedDocument, fields
from jarvis_core.db.models.actions import Ban, Kick, Mute, Unban, Warning
from jarvis_core.db.models.captcha import Captcha
from jarvis_core.db.models.modlog import Action, Modlog, Note
from jarvis_core.db.models.reddit import Subreddit, SubredditFollow
from jarvis_core.db.models.twitter import TwitterAccount, TwitterFollow
from jarvis_core.db.utils import NowField, Snowflake, SnowflakeDocument
__all__ = [
"Action",
"Autopurge",
"Autoreact",
"Ban",
"Captcha" "Config",
"Filter",
"Guess",
"Kick",
"Lock",
"Lockdown",
"Modlog",
"Mute",
"Note",
"Pin",
"Pinboard",
"Phishlist",
"Purge",
"Reminder",
"Rolegiver",
"Bypass",
"Roleping",
"Setting",
"Subreddit",
"SubredditFollow",
"Tag",
"Temprole",
"TwitterAccount",
"TwitterFollow",
"Unban",
"UserSetting",
"Warning",
"all_models",
]
from jarvis_core.db import CTC2_INST, JARVIS_INST
from jarvis_core.db.fields import RawField
from jarvis_core.db.models.actions import *
from jarvis_core.db.models.backups import *
from jarvis_core.db.models.mastodon import *
from jarvis_core.db.models.modlog import *
from jarvis_core.db.models.reddit import *
from jarvis_core.db.models.twitter import *
from jarvis_core.db.utils import get_now
class Autopurge(SnowflakeDocument):
guild: Snowflake
channel: Snowflake
delay: int = 30
admin: Snowflake
created_at: datetime = NowField()
@JARVIS_INST.register
class Autopurge(Document):
guild: int = fields.IntegerField(required=True)
channel: int = fields.IntegerField(required=True)
delay: int = fields.IntegerField(default=30)
admin: int = fields.IntegerField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Autoreact(SnowflakeDocument):
guild: Snowflake
channel: Snowflake
reactions: list[str] = Field(default_factory=list)
admin: Snowflake
thread: bool
created_at: datetime = NowField()
@JARVIS_INST.register
class Autoreact(Document):
guild: int = fields.IntegerField(required=True)
channel: int = fields.IntegerField(required=True)
reactions: List[str] = fields.ListField(fields.StringField())
admin: int = fields.IntegerField(required=True)
thread: bool = fields.BooleanField(default=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Config(SnowflakeDocument):
@JARVIS_INST.register
class Config(Document):
"""Config database object."""
key: str
value: str | int | bool
key: str = fields.StringField(required=True)
value: Any = RawField(required=True)
class Filter(SnowflakeDocument):
"""Filter database object."""
@JARVIS_INST.register
class Filter(Document):
"""Regex Filter database object."""
guild: Snowflake
name: str
filters: list[str] = Field(default_factory=list)
def _validate_filters(value):
for v in value:
try:
re.compile(v)
except re.error:
raise ValueError(f"Invalid regex: {v}")
guild: int = fields.IntegerField(required=True)
name: str = fields.StringField(required=True)
filters: List[str] = fields.ListField(fields.StringField(), validate=[_validate_filters])
class Guess(SnowflakeDocument):
@CTC2_INST.register
class Guess(Document):
"""Guess database object."""
correct: bool
guess: str
user: Snowflake
correct: bool = fields.BooleanField(default=False)
guess: str = fields.StringField(required=True)
user: int = fields.IntegerField(required=True)
class Permission(BaseModel):
@JARVIS_INST.register
class Permission(EmbeddedDocument):
"""Embedded Permissions document."""
id: Snowflake
allow: Optional[Snowflake] = 0
deny: Optional[Snowflake] = 0
id: int = fields.IntegerField(required=True)
allow: int = fields.IntegerField(default=0)
deny: int = fields.IntegerField(default=0)
class Lock(SnowflakeDocument):
@JARVIS_INST.register
class Lock(Document):
"""Lock database object."""
active: bool = True
admin: Snowflake
channel: Snowflake
duration: int = 10
reason: str
original_perms: Permission
created_at: datetime = NowField()
active: bool = fields.BooleanField(default=True)
admin: int = fields.IntegerField(required=True)
channel: int = fields.IntegerField(required=True)
duration: int = fields.IntegerField(default=10)
guild: int = fields.IntegerField(required=True)
reason: str = fields.StringField(required=True)
original_perms: Permission = fields.EmbeddedField(Permission, required=False)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Lockdown(SnowflakeDocument):
@JARVIS_INST.register
class Lockdown(Document):
"""Lockdown database object."""
active: bool = True
admin: Snowflake
duration: int = 10
guild: Snowflake
reason: str
original_perms: Snowflake
created_at: datetime = NowField()
active: bool = fields.BooleanField(default=True)
admin: int = fields.IntegerField(required=True)
duration: int = fields.IntegerField(default=10)
guild: int = fields.IntegerField(required=True)
reason: str = fields.StringField(required=True)
original_perms: int = fields.IntegerField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Purge(SnowflakeDocument):
@JARVIS_INST.register
class Event(Document):
"""Event Meetup Object."""
user: int = fields.IntegerField(required=True)
going: bool = fields.BooleanField(required=True)
travel_method: str = fields.StringField()
before_flight: str = fields.StringField()
before_arrival_time: datetime = fields.AwareDateTimeField()
before_departure_time: datetime = fields.AwareDateTimeField()
after_flight: str = fields.StringField()
after_arrival_time: datetime = fields.AwareDateTimeField()
after_departure_time: datetime = fields.AwareDateTimeField()
hotel: str = fields.StringField()
event_name: str = fields.StringField()
@JARVIS_INST.register
class Phishlist(Document):
"""Phishing safelist."""
url: str = fields.StringField(required=True)
confirmed: bool = fields.BooleanField(default=False)
valid: bool = fields.BooleanField(default=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
@JARVIS_INST.register
class Purge(Document):
"""Purge database object."""
admin: Snowflake
channel: Snowflake
guild: Snowflake
count_: int = Field(10, alias="count")
created_at: datetime = NowField()
admin: int = fields.IntegerField(required=True)
channel: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
count: int = fields.IntegerField(default=10)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Reminder(SnowflakeDocument):
@JARVIS_INST.register
class Reminder(Document):
"""Reminder database object."""
active: bool = True
user: Snowflake
guild: Snowflake
channel: Snowflake
message: str
remind_at: datetime
created_at: datetime = NowField()
repeat: Optional[str] = None
timezone: str = "UTC"
total_reminders: int = 0
parent: Optional[str] = None
private: bool = False
active: bool = fields.BooleanField(default=True)
user: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
channel: int = fields.IntegerField(required=True)
message: str = fields.StringField(required=True)
remind_at: datetime = fields.AwareDateTimeField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
private: bool = fields.BooleanField(default=False)
class Rolegiver(SnowflakeDocument):
@JARVIS_INST.register
class Rolegiver(Document):
"""Rolegiver database object."""
guild: Snowflake
roles: Optional[list[Snowflake]] = Field(default_factory=list)
group: Optional[str] = None
guild: int = fields.IntegerField(required=True)
roles: List[int] = fields.ListField(fields.IntegerField())
class Bypass(BaseModel):
@JARVIS_INST.register
class Bypass(EmbeddedDocument):
"""Roleping bypass embedded object."""
users: Optional[list[Snowflake]] = Field(default_factory=list)
roles: Optional[list[Snowflake]] = Field(default_factory=list)
users: List[int] = fields.ListField(fields.IntegerField())
roles: List[int] = fields.ListField(fields.IntegerField())
class Roleping(SnowflakeDocument):
@JARVIS_INST.register
class Roleping(Document):
"""Roleping database object."""
active: bool = True
role: Snowflake
guild: Snowflake
admin: Snowflake
bypass: Bypass
created_at: datetime = NowField()
active: bool = fields.BooleanField(default=True)
role: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
admin: int = fields.IntegerField(required=True)
bypass: Bypass = fields.EmbeddedField(Bypass)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Setting(SnowflakeDocument):
@JARVIS_INST.register
class Setting(Document):
"""Setting database object."""
guild: Snowflake
setting: str
value: str | int | bool | list[int | str]
guild: int = fields.IntegerField(required=True)
setting: str = fields.StringField(required=True)
value: Any = RawField()
class Phishlist(SnowflakeDocument):
"""Phishlist database object."""
@JARVIS_INST.register
class Star(Document):
"""Star database object."""
url: str
confirmed: bool = False
valid: bool = True
created_at: datetime = NowField()
active: bool = fields.BooleanField(default=True)
index: int = fields.IntegerField(required=True)
message: int = fields.IntegerField(required=True)
channel: int = fields.IntegerField(required=True)
starboard: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
admin: int = fields.IntegerField(required=True)
star: int = fields.IntegerField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Pinboard(SnowflakeDocument):
"""Pinboard database object."""
@JARVIS_INST.register
class Starboard(Document):
"""Starboard database object."""
channel: Snowflake
guild: Snowflake
admin: Snowflake
created_at: datetime = NowField()
channel: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
admin: int = fields.IntegerField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Pin(SnowflakeDocument):
"""Pin database object."""
active: bool = True
index: int
message: Snowflake
channel: Snowflake
pinboard: Link[Pinboard]
guild: Snowflake
admin: Snowflake
pin: Snowflake
created_at: datetime = NowField()
class Tag(SnowflakeDocument):
@JARVIS_INST.register
class Tag(Document):
"""Tag database object."""
creator: Snowflake
name: str
content: str
guild: Snowflake
created_at: datetime = NowField()
edited_at: Optional[datetime] = None
editor: Optional[Snowflake] = None
creator: int = fields.IntegerField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
editor: int = fields.IntegerField()
edited_at: datetime = fields.AwareDateTimeField()
name: str = fields.StringField(required=True)
content: str = fields.StringField(required=True)
guild: int = fields.IntegerField(required=True)
class Temprole(SnowflakeDocument):
@JARVIS_INST.register
class Temprole(Document):
"""Temporary role object."""
guild: Snowflake
user: Snowflake
role: Snowflake
admin: Snowflake
expires_at: datetime
reapply_on_rejoin: bool = True
created_at: datetime = NowField()
guild: int = fields.IntegerField(required=True)
user: int = fields.IntegerField(required=True)
role: int = fields.IntegerField(required=True)
admin: int = fields.IntegerField(required=True)
expires_at: datetime = fields.AwareDateTimeField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class UserSetting(SnowflakeDocument):
@JARVIS_INST.register
class UserSetting(Document):
"""User Setting object."""
user: Snowflake
type: str
setting: str
value: str | int | bool
user: int = fields.IntegerField(required=True)
type: str = fields.StringField(required=True)
setting: str = fields.StringField(required=True)
value: Any = RawField()
all_models: list[Document] = [
Autopurge,
Autoreact,
Ban,
Captcha,
Config,
Filter,
Guess,
Kick,
Lock,
Lockdown,
Modlog,
Mute,
Pin,
Pinboard,
Phishlist,
Purge,
Reminder,
Rolegiver,
Roleping,
Setting,
Subreddit,
SubredditFollow,
Tag,
Temprole,
TwitterAccount,
TwitterFollow,
Unban,
UserSetting,
Warning,
]
class Meta:
collection_name = "usersetting"

View file

@ -1,64 +1,72 @@
"""User action models."""
from datetime import datetime
from typing import Optional
from datetime import datetime, timezone
from jarvis_core.db.utils import NowField, Snowflake, SnowflakeDocument
from umongo import Document, fields
from jarvis_core.db import JARVIS_INST
from jarvis_core.db.utils import get_now
class Ban(SnowflakeDocument):
active: bool = True
admin: Snowflake
user: Snowflake
username: str
discrim: Optional[int]
duration: Optional[int]
guild: Snowflake
type: str = "perm"
reason: str
created_at: datetime = NowField()
@JARVIS_INST.register
class Ban(Document):
active: bool = fields.BooleanField(default=True)
admin: int = fields.IntegerField(required=True)
user: int = fields.IntegerField(required=True)
username: str = fields.StringField(required=True)
discrim: int = fields.IntegerField(required=True)
duration: int = fields.IntegerField(required=False, default=None)
guild: int = fields.IntegerField(required=True)
type: str = fields.StringField(default="perm")
reason: str = fields.StringField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Kick(SnowflakeDocument):
@JARVIS_INST.register
class Kick(Document):
"""Kick database object."""
admin: Snowflake
guild: Snowflake
reason: str
user: Snowflake
created_at: datetime = NowField()
admin: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
reason: str = fields.StringField(required=True)
user: int = fields.IntegerField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Mute(SnowflakeDocument):
@JARVIS_INST.register
class Mute(Document):
"""Mute database object."""
active: bool = True
user: Snowflake
admin: Snowflake
duration: int = 10
guild: Snowflake
reason: str
created_at: datetime = NowField()
active: bool = fields.BooleanField(default=True)
user: int = fields.IntegerField(required=True)
admin: int = fields.IntegerField(required=True)
duration: int = fields.IntegerField(default=10)
guild: int = fields.IntegerField(required=True)
reason: str = fields.StringField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Unban(SnowflakeDocument):
@JARVIS_INST.register
class Unban(Document):
"""Unban database object."""
user: Snowflake
username: str
discrim: Optional[str]
guild: Snowflake
reason: str
created_at: datetime = NowField()
user: int = fields.IntegerField(required=True)
username: str = fields.StringField(required=True)
discrim: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
admin: int = fields.IntegerField(required=True)
reason: str = fields.StringField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Warning(SnowflakeDocument):
@JARVIS_INST.register
class Warning(Document):
"""Warning database object."""
active: bool = True
admin: Snowflake
user: Snowflake
guild: Snowflake
duration: int = 24
reason: str
expires_at: datetime
created_at: datetime = NowField()
active: bool = fields.BooleanField(default=True)
admin: int = fields.IntegerField(required=True)
user: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
duration: int = fields.IntegerField(default=24)
reason: str = fields.StringField(required=True)
expires_at: datetime = fields.AwareDateTimeField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)

View file

@ -1,105 +1,122 @@
"""JARVIS Backup Models (NYI)."""
from datetime import datetime
from typing import Optional
from typing import List, Optional
from beanie import Document, Indexed, Link
from pydantic import BaseModel, Field
from umongo import Document, EmbeddedDocument, fields
from jarvis_core import __version__
from jarvis_core.db.utils import NanoField, NowField
from jarvis_core.db import JARVIS_INST
from jarvis_core.db.fields import BinaryField
from jarvis_core.db.utils import get_id, get_now
@JARVIS_INST.register
class Image(Document):
discord_id: int = Indexed(int, unique=True)
image_data: list[bytes]
image_ext: str
created_at: datetime = NowField()
discord_id: int = fields.IntegerField(unique=True)
image_data: List[bytes] = BinaryField()
image_ext: str = fields.StringField()
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class PermissionOverwriteBackup(BaseModel):
id: int
type: int
allow: int
deny: int
@JARVIS_INST.register
class PermissionOverwriteBackup(EmbeddedDocument):
id: int = fields.IntegerField()
type: int = fields.IntegerField()
allow: int = fields.IntegerField()
deny: int = fields.IntegerField()
class WebhookBackup(BaseModel):
id: int
channel_id: int
type: int
avatar: Link[Image]
name: str
@JARVIS_INST.register
class WebhookBackup(EmbeddedDocument):
id: int = fields.IntegerField()
channel_id: int = fields.IntegerField()
type: int = fields.IntegerField()
avatar: Image = fields.ReferenceField(Image)
name: str = fields.StringField()
class ChannelBackup(BaseModel):
id: int
name: str
type: int
position: int
topic: Optional[str] = None
nsfw: bool = False
rate_limit_per_user: Optional[int] = None
bitrate: Optional[int] = None
user_limit: Optional[int] = None
permission_overwrites: list[PermissionOverwriteBackup] = Field(default_factory=list)
parent_id: Optional[int] = None
rtc_region: Optional[str] = None
video_quality_mode: Optional[int] = None
default_auto_archive_duration: Optional[int] = None
webhooks: list[WebhookBackup] = Field(default_factory=list)
@JARVIS_INST.register
class ChannelBackup(EmbeddedDocument):
id: int = fields.IntegerField()
name: str = fields.StringField()
type: int = fields.IntegerField()
position: int = fields.IntegerField()
topic: Optional[str] = fields.StringField(default=None)
nsfw: bool = fields.BooleanField(default=False)
rate_limit_per_user: int = fields.IntegerField(default=None)
bitrate: Optional[int] = fields.IntegerField(default=None)
user_limit: Optional[int] = fields.IntegerField(default=None)
permission_overwrites: List[PermissionOverwriteBackup] = fields.ListField(
fields.EmbeddedField(PermissionOverwriteBackup), factory=list
)
parent_id: Optional[int] = fields.IntegerField(default=None)
rtc_region: Optional[str] = fields.StringField(default=None)
video_quality_mode: Optional[int] = fields.IntegerField(default=None)
default_auto_archive_duration: Optional[int] = fields.IntegerField(default=None)
webhooks: List[WebhookBackup] = fields.ListField(
fields.EmbeddedField(WebhookBackup), factory=list
)
class RoleBackup(BaseModel):
id: int
name: str
permissions: int
color: str
hoist: bool
mentionable: bool
@JARVIS_INST.register
class RoleBackup(EmbeddedDocument):
id: int = fields.IntegerField()
name: str = fields.StringField()
permissions: int = fields.IntegerField()
color: str = fields.StringField()
hoist: bool = fields.BooleanField()
mentionable: bool = fields.BooleanField()
class EmojiBackup(BaseModel):
id: int
name: str
image: Link[Image]
@JARVIS_INST.register
class EmojiBackup(EmbeddedDocument):
id: int = fields.IntegerField()
name: str = fields.StringField()
image: Image = fields.ReferenceField(Image)
class StickerBackup(BaseModel):
id: int
name: str
format_type: int
tags: str
type: int
image: Link[Image]
@JARVIS_INST.register
class StickerBackup(EmbeddedDocument):
id: int = fields.IntegerField()
name: str = fields.StringField()
format_type: int = fields.IntegerField()
tags: str = fields.StringField()
type: int = fields.IntegerField()
image: Image = fields.ReferenceField(Image)
class GuildBackup(BaseModel):
name: str
description: Optional[str] = None
default_message_notifications: Optional[int] = None
explicit_content_filter: Optional[int] = None
afk_channel: Optional[int] = None
afk_timeout: Optional[int] = None
icon: Optional[Link[Image]] = None
owner: int
splash: Optional[Link[Image]] = None
discovery_splash: Optional[Link[Image]] = None
banner: Optional[Link[Image]] = None
system_channel: Optional[int] = None
system_channel_flags: Optional[int] = None
rules_channel: Optional[int] = None
public_updates_channel: Optional[int] = None
preferred_locale: Optional[str] = None
features: list[str] = Field(default_factory=list)
channels: list[ChannelBackup] = Field(default_factory=list)
roles: list[RoleBackup] = Field(default_factory=list)
emojis: list[EmojiBackup] = Field(default_factory=list)
stickers: list[StickerBackup] = Field(default_factory=list)
@JARVIS_INST.register
class GuildBackup(EmbeddedDocument):
name: str = fields.StringField(required=True)
description: str = fields.StringField(default=None)
default_message_notifications: Optional[int] = fields.IntegerField(default=None)
explicit_content_filter: Optional[int] = fields.IntegerField(default=None)
afk_channel: Optional[int] = fields.IntegerField(default=None)
afk_timeout: Optional[int] = fields.IntegerField(default=None)
icon: Optional[Image] = fields.ReferenceField(Image, default=None)
owner: int = fields.IntegerField(required=True)
splash: Optional[Image] = fields.ReferenceField(Image, default=None)
discovery_splash: Optional[Image] = fields.ReferenceField(Image, default=None)
banner: Optional[Image] = fields.ReferenceField(Image, default=None)
system_channel: Optional[int] = fields.IntegerField(default=None)
system_channel_flags: Optional[int] = fields.IntegerField(default=None)
rules_channel: Optional[int] = fields.IntegerField(default=None)
public_updates_channel: Optional[int] = fields.IntegerField(default=None)
preferred_locale: Optional[str] = fields.StringField(default=None)
features: List[str] = fields.ListField(fields.StringField, factory=list)
channels: List[ChannelBackup] = fields.ListField(
fields.EmbeddedField(ChannelBackup), factory=list
)
roles: List[RoleBackup] = fields.ListField(fields.EmbeddedField(RoleBackup), factory=list)
emojis: List[EmojiBackup] = fields.ListField(fields.EmbeddedField(EmojiBackup), factory=list)
stickers: List[StickerBackup] = fields.ListField(
fields.EmbeddedField(StickerBackup), factory=list
)
@JARVIS_INST.register
class Backup(Document):
created_at: datetime = NowField()
guild_id: int
bkid: str = NanoField()
guild: GuildBackup
version: str = Field(default=__version__)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
guild_id: int = fields.IntegerField()
bkid: str = fields.StringField(default=get_id)
guild: GuildBackup = fields.EmbeddedField(GuildBackup, required=True)
version: str = fields.StringField(default=__version__)

View file

@ -1,11 +0,0 @@
"""JARVIS Verification Captcha."""
from datetime import datetime
from jarvis_core.db.utils import NowField, Snowflake, SnowflakeDocument
class Captcha(SnowflakeDocument):
user: Snowflake
guild: Snowflake
correct: str
created_at: datetime = NowField()

View file

@ -1,30 +1,36 @@
"""Mastodon databaes models."""
from datetime import datetime
from datetime import datetime, timezone
from beanie import Document
from umongo import Document, fields
from jarvis_core.db import JARVIS_INST
from jarvis_core.db.utils import NowField
from jarvis_core.db.utils import get_now
@JARVIS_INST.register
class MastodonUser(Document):
"""User object."""
user_id: int
acct: str
username: str
last_sync: datetime = NowField()
user_id: int = fields.IntegerField(required=True)
acct: str = fields.StringField(required=True)
username: str = fields.StringField(required=True)
last_sync: datetime = fields.AwareDateTimeField(default=get_now)
class Meta:
collection_name = "mastodonuser"
@JARVIS_INST.register
class MastodonFollow(Document):
"""User Follow object."""
active: bool = True
user_id: int
channel: int
guild: int
reblogged: bool = True
admin: int
created_at: datetime = NowField()
active: bool = fields.BooleanField(default=True)
user_id: int = fields.IntegerField(required=True)
channel: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
reblogged: bool = fields.BooleanField(default=True)
admin: int = fields.IntegerField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Meta:
collection_name = "mastodonfollow"

View file

@ -1,36 +1,41 @@
"""Modlog database models."""
from datetime import datetime
from datetime import datetime, timezone
from typing import List
from beanie import PydanticObjectId
from pydantic import BaseModel, Field
from bson import ObjectId
from umongo import Document, EmbeddedDocument, fields
from jarvis_core.db.utils import NanoField, NowField, Snowflake, SnowflakeDocument
from jarvis_core.db import JARVIS_INST
from jarvis_core.db.utils import get_id, get_now
class Action(BaseModel):
@JARVIS_INST.register
class Action(EmbeddedDocument):
"""Modlog embedded action document."""
action_type: str
parent: PydanticObjectId
orphaned: bool = False
action_type: str = fields.StringField(required=True)
parent: ObjectId = fields.ObjectIdField(required=True)
orphaned: bool = fields.BoolField(default=False)
class Note(BaseModel):
@JARVIS_INST.register
class Note(EmbeddedDocument):
"""Modlog embedded note document."""
admin: Snowflake
content: str
created_at: datetime = NowField()
admin: int = fields.IntegerField(required=True)
content: str = fields.StrField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Modlog(SnowflakeDocument):
@JARVIS_INST.register
class Modlog(Document):
"""Modlog database object."""
user: Snowflake
nanoid: str = NanoField()
guild: Snowflake
admin: Snowflake
actions: list[Action] = Field(default_factory=list)
notes: list[Note] = Field(default_factory=list)
open: bool = True
created_at: datetime = NowField()
user: int = fields.IntegerField(required=True)
nanoid: str = fields.StringField(default=get_id)
guild: int = fields.IntegerField(required=True)
admin: int = fields.IntegerField(required=True)
actions: List[Action] = fields.ListField(fields.EmbeddedField(Action), factory=list)
open: bool = fields.BoolField(default=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
notes: List[Note] = fields.ListField(fields.EmbeddedField(Note), factory=list)

View file

@ -1,27 +1,52 @@
"""Reddit databaes models."""
from datetime import datetime
from datetime import datetime, timezone
from beanie import Document
from umongo import Document, fields
from jarvis_core.db.utils import NowField
from jarvis_core.db import JARVIS_INST
from jarvis_core.db.utils import get_now
@JARVIS_INST.register
class Subreddit(Document):
"""Subreddit object."""
display_name: str
over18: bool = False
display_name: str = fields.StringField(required=True)
over18: bool = fields.BooleanField(default=False)
@JARVIS_INST.register
class SubredditFollow(Document):
"""Subreddit Follow object."""
active: bool = True
display_name: str
channel: int
guild: int
admin: int
created_at: datetime = NowField()
active: bool = fields.BooleanField(default=True)
display_name: str = fields.StringField(required=True)
channel: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
admin: int = fields.IntegerField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Setting:
name = "subredditfollow"
class Meta:
collection_name = "subredditfollow"
@JARVIS_INST.register
class Redditor(Document):
"""Reddit User object."""
name: str = fields.StringField(required=True)
@JARVIS_INST.register
class RedditorFollow(Document):
"""Reditor Follow object."""
active: bool = fields.BooleanField(default=True)
name: str = fields.StringField(required=True)
channel: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
admin: int = fields.IntegerField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Meta:
collection_name = "redditorfollow"

View file

@ -1,33 +1,36 @@
"""Twitter database models."""
from datetime import datetime
from datetime import datetime, timezone
from beanie import Document
from umongo import Document, fields
from jarvis_core.db.utils import NowField
from jarvis_core.db import JARVIS_INST
from jarvis_core.db.utils import get_now
@JARVIS_INST.register
class TwitterAccount(Document):
"""Twitter Account object."""
handle: str
twitter_id: int
last_tweet: int
last_sync: datetime = NowField()
handle: str = fields.StringField(required=True)
twitter_id: int = fields.IntegerField(required=True)
last_tweet: int = fields.IntegerField(required=True)
last_sync: datetime = fields.AwareDateTimeField(default=get_now)
class Setting:
name = "twitteraccount"
class Meta:
collection_name = "twitteraccount"
@JARVIS_INST.register
class TwitterFollow(Document):
"""Twitter Follow object."""
active: bool = True
twitter_id: int
channel: int
guild: int
retweets: bool = True
admin: int
created_at: datetime = NowField()
active: bool = fields.BooleanField(default=True)
twitter_id: int = fields.IntegerField(required=True)
channel: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
retweets: bool = fields.BooleanField(default=True)
admin: int = fields.IntegerField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Setting:
name = "twitterfollow"
class Meta:
collection_name = "twitterfollow"

View file

@ -1,12 +1,6 @@
"""JARVIS Core Database utilities."""
from datetime import datetime, timezone
from functools import partial
from typing import Any
import nanoid
from beanie import Document
from pydantic import Field, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
NANOID_ALPHA = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
@ -19,20 +13,3 @@ def get_now() -> datetime:
def get_id() -> str:
"""Get nanoid."""
return nanoid.generate(NANOID_ALPHA, 12)
NowField = partial(Field, default_factory=get_now)
NanoField = partial(Field, default_factory=get_id)
class Snowflake(int):
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_after_validator_function(cls, handler(int))
class SnowflakeDocument(Document):
class Settings:
bson_encoders = {Snowflake: str}

View file

@ -7,7 +7,7 @@ invites = re.compile(
flags=re.IGNORECASE,
)
custom_emote = re.compile(r"<a?:\w+:(\d+)>$", flags=re.IGNORECASE)
custom_emote = re.compile(r"<:\w+:(\d+)>$", flags=re.IGNORECASE)
valid_text = re.compile(
r"[\w\s\-\\/.!@#$:;\[\]%^*'\"()+=<>,\u0080-\U000E0FFF]*", flags=re.IGNORECASE

View file

@ -9,6 +9,38 @@ from jarvis_core.filters import url
DEFAULT_BLOCKSIZE = 8 * 1024 * 1024
class Singleton(object):
REQUIRED = []
OPTIONAL = {}
def __new__(cls, *args: list, **kwargs: dict):
"""Create a new singleton."""
inst = cls.__dict__.get("inst")
if inst is not None:
return inst
inst = object.__new__(cls)
inst.init(*args, **kwargs)
inst._validate()
cls.__inst__ = inst
return inst
def _validate(self) -> None:
for key in self.REQUIRED:
if not getattr(self, key, None):
raise ValueError(f"Missing required key: {key}")
def init(self, **kwargs: dict) -> None:
"""Initialize the object."""
for key, value in kwargs.items():
setattr(self, key, value)
for key, value in self.OPTIONAL.items():
if not getattr(self, key, None):
setattr(self, key, value)
async def hash(
data: str, method: Union[Callable, str] = hashlib.sha256, size: int = DEFAULT_BLOCKSIZE
) -> Tuple[str, int, str]:

View file

@ -68,7 +68,7 @@ def fmt(*formats: List[Format | Fore | Back] | int) -> str:
ret = fmt + fore + back
if not any([ret, fore, back]):
return RESET
ret = RESET
if ret[-1] == ";":
ret = ret[:-1]

1848
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,148 +1,27 @@
[tool.poetry]
name = "jarvis-core"
version = "1.0.1"
version = "0.16.1"
description = "JARVIS core"
authors = ["Zevaryx <zevaryx@gmail.com>"]
[tool.poetry.dependencies]
python = ">=3.10,<4"
orjson = { version = ">=3.6.6,<4" }
motor = ">=3.1.1,<4"
PyYAML = { version = ">=6.0,<7" }
aiohttp = ">=3.8.1,<4"
rich = ">=13.7.1"
nanoid = ">=2.0.0,<3"
python-dotenv = "1.0.1"
beanie = ">=1.17.0,<2"
pydantic = ">=2.3.0,<3"
python-dateutil = ">=2.9.0.post0,<3"
setuptools = ">=69.2.0,<70"
python = "^3.10"
orjson = "^3.6.6"
motor = "^3.1.1"
umongo = "^3.1.0"
PyYAML = "^6.0"
pytz = "^2022.1"
aiohttp = "^3.8.1"
rich = "^12.3.0"
nanoid = "^2.0.0"
python-dotenv = "^0.21.0"
[tool.poetry.group.dev.dependencies]
black = "^23.1.0"
[tool.poetry.dev-dependencies]
pytest = "^7.1"
ipython = "^8.5.0"
mongomock_motor = "^0.0.29"
pytest-asyncio = "^0.23.5.post1"
pytest-cov = "^4.1.0"
faker = "^24.3.0"
rich = "^12.6.0"
black = {version = "^22.10.0", allow-prereleases = true}
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
minversion = "8.0"
asyncio_mode = "auto"
testpaths = ["tests"]
addopts = "--cov=jarvis_core --cov-report term-missing --cov-report xml:coverage.xml"
filterwarnings = [
'ignore:`general_plain_validator_function` is deprecated',
'ignore:pkg_resources is deprecated as an API',
]
[tool.coverage.run]
omit = [
"tests/",
"jarvis_core/db/models/backups.py",
"jarvis_core/db/models/mastodon.py",
"jarvis_core/db/models/reddit.py",
"jarvis_core/db/models/twitter.py",
]
[tool.black]
line-length = 120
[tool.isort]
profile = "black"
skip = ["__init__.py"]
[tool.mypy]
ignore_missing_imports = true
[tool.pyright]
useLibraryCodeForTypes = true
reportMissingImports = false
[tool.ruff]
line-length = 120
target-version = "py312"
output-format = "full"
[tool.ruff.lint]
task-tags = ["TODO", "FIXME", "XXX", "HACK", "REVIEW", "NOTE"]
select = ["E", "F", "B", "Q", "RUF", "D", "ANN", "RET", "C"]
ignore-init-module-imports = true
ignore = [
"Q0",
"E501",
# These default to arguing with Black. We might configure some of them eventually
"ANN1",
# These insist that we have Type Annotations for self and cls.
"D105",
"D107",
# Missing Docstrings in magic method and __init__
"D401",
# First line should be in imperative mood; try rephrasing
"D400",
"D415",
# First line should end with a period
"D106",
# Missing docstring in public nested class. This doesn't work well with Metadata classes.
"D417",
# Missing argument in the docstring
"D406",
# Section name should end with a newline
"D407",
# Missing dashed underline after section
"D212",
# Multi-line docstring summary should start at the first line
"D404",
# First word of the docstring should not be This
"D203",
# 1 blank line required before class docstring
# Everything below this line is something we care about, but don't currently meet
"ANN001",
# Missing type annotation for function argument 'token'
"ANN002",
# Missing type annotation for *args
"ANN003",
# Missing type annotation for **kwargs
"ANN401",
# Dynamically typed expressions (typing.Any) are disallowed
# "B009",
# Do not call getattr with a constant attribute value, it is not any safer than normal property access.
"B010",
# Do not call setattr with a constant attribute value, it is not any safer than normal property access.
"D100",
# Missing docstring in public module
"D101",
# ... class
"D102",
# ... method
"D103",
# ... function
"D104",
# ... package
"E712",
# Ignore == True because of Beanie
# Plugins we don't currently include: flake8-return
"RET503",
# missing explicit return at the end of function ableto return non-None value.
"RET504",
# unecessary variable assignement before return statement.
]
[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"
[tool.ruff.lint.flake8-annotations]
mypy-init-return = true
suppress-dummy-args = true
suppress-none-returning = true
[tool.ruff.lint.flake8-errmsg]
max-string-length = 20
[tool.ruff.lint.mccabe]
max-complexity = 13

View file

@ -1 +0,0 @@
pre-commit==3.6.2

View file

@ -1,47 +0,0 @@
import pytest
from jarvis_core import filters
@pytest.fixture()
def faker_locale():
return ["en_US"]
def test_invites(faker):
invites = ["discord.gg/asdf", "discord.com/invite/asdf", "discord://asdf/invite/asdf"]
for invite in invites:
assert filters.invites.match(invite)
for _ in range(100):
assert not filters.invites.match(faker.url())
def test_custom_emotes():
emotes = ["<:test:000>", "<a:animated:000>"]
not_emotes = ["<invalid:000>", "<:a:invalid:000>", "<invalid:000:>"]
for emote in emotes:
print(emote)
assert filters.custom_emote.match(emote)
for not_emote in not_emotes:
assert not filters.custom_emote.match(not_emote)
def test_url(faker):
for _ in range(100):
assert filters.url.match(faker.url())
def test_email(faker):
for _ in range(100):
assert filters.email.match(faker.ascii_email())
def test_ipv4(faker):
for _ in range(100):
assert filters.ipv4.match(faker.ipv4())
def test_ipv4(faker):
for _ in range(100):
assert filters.ipv6.match(faker.ipv6())

View file

@ -0,0 +1,5 @@
from jarvis_core import __version__
def test_version():
assert __version__ == "0.1.0"

View file

@ -1,72 +0,0 @@
import types
import typing
from datetime import datetime, timezone
import pytest
from beanie import Document, init_beanie
from mongomock_motor import AsyncMongoMockClient
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from jarvis_core.db.models import Pin, all_models
from jarvis_core.db.utils import Snowflake
MAX_SNOWFLAKE = 18446744073709551615
async def get_default(annotation: type):
if annotation is Snowflake:
return MAX_SNOWFLAKE
if annotation.__class__ is typing._UnionGenericAlias or annotation.__class__ is types.UnionType:
return annotation.__args__[0]()
if issubclass(annotation, BaseModel):
data = {}
for name, info in annotation.model_fields.items():
if info.is_required():
data[name] = await get_default(info.annotation)
return annotation(**data)
if annotation is datetime:
return datetime.now(tz=timezone.utc)
return annotation()
async def create_data_dict(model_fields: dict[str, FieldInfo]):
data = {}
for name, info in model_fields.items():
if info.is_required():
if (
type(info.annotation) is typing._GenericAlias
and (link := info.annotation.__args__[0]) in all_models
):
reference = await create_data_dict(link.model_fields)
nested = link(**reference)
await nested.save()
nested = await link.find_one(link.id == nested.id)
data[name] = nested
else:
data[name] = await get_default(info.annotation)
return data
@pytest.fixture(autouse=True)
async def my_fixture():
client = AsyncMongoMockClient(tz_aware=True, tzinfo=timezone.utc)
await init_beanie(document_models=all_models, database=client.get_database(name="test_models"))
async def test_models():
for model in all_models:
data = await create_data_dict(model.model_fields)
await model(**data).save()
saved = await model.find_one()
for key, value in data.items():
if model is Pin:
continue # This is broken af, it works but I can't test it
saved_value = getattr(saved, key)
# Don't care about microseconds for these tests
# Mongosock tends to round, so we
if isinstance(saved_value, datetime):
saved_value = int(saved_value.astimezone(timezone.utc).timestamp())
value = int(value.timestamp())
assert value == saved_value

View file

@ -1,83 +0,0 @@
from dataclasses import dataclass
import pytest
from aiohttp import ClientConnectionError, ClientResponseError
from jarvis_core import util
from jarvis_core.util import ansi, http
async def test_hash():
hashes: dict[str, dict[str, str]] = {
"sha256": {
"hello": "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824",
"https://zevaryx.com/media/logo.png": "668ddf4ec8b0c7315c8a8bfdedc36b242ff8f4bba5debccd8f5fa07776234b6a",
},
"sha1": {
"hello": "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d",
"https://zevaryx.com/media/logo.png": "989f8065819c6946493797209f73ffe37103f988",
},
}
for hash_method, items in hashes.items():
for value, correct in items.items():
print(value)
assert (await util.hash(data=value, method=hash_method))[0] == correct
with pytest.raises(ClientResponseError):
await util.hash("https://zevaryx.com/known-not-to-exist")
with pytest.raises(ClientConnectionError):
await util.hash("https://known-to-not-exist.zevaryx.com")
def test_bytesize():
size = 4503599627370496
converted = util.convert_bytesize(size)
assert converted == "4.000 PB"
assert util.unconvert_bytesize(4, "PB") == size
assert util.convert_bytesize(None) == "??? B"
assert util.unconvert_bytesize(4, "B") == 4
def test_find_get():
@dataclass
class TestModel:
x: int
models = [TestModel(3), TestModel(9), TestModel(100), TestModel(-2)]
assert util.find(lambda x: x.x > 0, models).x == 3
assert util.find(lambda x: x.x > 100, models) is None
assert len(util.find_all(lambda x: x.x % 2 == 0, models)) == 2
assert util.get(models, x=3).x == 3
assert util.get(models, x=11) is None
assert util.get(models).x == 3
assert util.get(models, y=3) is None
assert len(util.get_all(models, x=9)) == 1
assert len(util.get_all(models, y=1)) == 0
assert util.get_all(models) == models
async def test_http_get_size():
url = "http://ipv4.download.thinkbroadband.com/100MB.zip"
size = 104857600
assert await http.get_size(url) == size
with pytest.raises(ValueError):
await http.get_size("invalid")
def test_ansi():
known = "\x1b[0;35;41m"
assert ansi.fmt(1, ansi.Format.NORMAL, ansi.Fore.PINK, ansi.Back.ORANGE) == known
assert 4 in ansi.Format
assert 2 not in ansi.Format
assert ansi.fmt() == ansi.RESET