Merge branch 'main' into beanie

This commit is contained in:
Zeva Rose 2023-03-23 22:05:15 -06:00
commit 0e9a155ab7
15 changed files with 1443 additions and 827 deletions

View file

@ -1,6 +1,6 @@
"""Load global config.""" """Load global config."""
import ast import os
from os import environ from lib2to3.pgen2 import token
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Type from typing import Any, Dict, List, Optional, Type

View file

@ -7,7 +7,14 @@ from jarvis_core.db.models import all_models
async def connect( async def connect(
host: str, username: str, password: str, port: int = 27017, testing: bool = False, extra_models: list = [] username: str,
password: str,
port: int = 27017,
testing: bool = False,
host: str = None,
hosts: list[str] = None,
replicaset: str = None,
extra_models: list = [],
) -> None: ) -> None:
""" """
Connect to MongoDB. Connect to MongoDB.
@ -20,7 +27,6 @@ async def connect(
testing: Whether or not to use jarvis_dev testing: Whether or not to use jarvis_dev
extra_models: Extra beanie models to register extra_models: Extra beanie models to register
""" """
global CLIENT, JARVISDB, CTC2DB, JARVIS_INST, CTC2_INST
client = AsyncIOMotorClient(host=host, username=username, password=password, port=port, tz_aware=True, tzinfo=utc) client = AsyncIOMotorClient(host=host, username=username, password=password, port=port, tz_aware=True, tzinfo=utc)
db = client.jarvis_dev if testing else client.jarvis db = client.jarvis_dev if testing else client.jarvis
await init_beanie(database=db, document_models=all_models + extra_models) await init_beanie(database=db, document_models=all_models + extra_models)

26
jarvis_core/db/fields.py Normal file
View file

@ -0,0 +1,26 @@
import bson
import marshmallow as ma
from marshmallow import fields as ma_fields
from umongo import fields
class BinaryField(fields.BaseField, ma_fields.Field):
default_error_messages = {"invalid": "Not a valid byte sequence."}
def _serialize(self, value, attr, data, **kwargs):
return bytes(value)
def _deserialize(self, value, attr, data, **kwargs):
if not isinstance(value, bytes):
self.fail("invalid")
return value
def _serialize_to_mongo(self, obj):
return bson.binary.Binary(obj)
def _deserialize_from_mongo(self, value):
return bytes(value)
class RawField(fields.BaseField, ma_fields.Raw):
pass

View file

@ -0,0 +1,122 @@
from datetime import datetime
from typing import List, Optional
from umongo import Document, EmbeddedDocument, fields
from jarvis_core import __version__
from jarvis_core.db import JARVIS_INST
from jarvis_core.db.fields import BinaryField
from jarvis_core.db.utils import get_id, get_now
@JARVIS_INST.register
class Image(Document):
discord_id: int = fields.IntegerField(unique=True)
image_data: List[bytes] = BinaryField()
image_ext: str = fields.StringField()
created_at: datetime = fields.AwareDateTimeField(default=get_now)
@JARVIS_INST.register
class PermissionOverwriteBackup(EmbeddedDocument):
id: int = fields.IntegerField()
type: int = fields.IntegerField()
allow: int = fields.IntegerField()
deny: int = fields.IntegerField()
@JARVIS_INST.register
class WebhookBackup(EmbeddedDocument):
id: int = fields.IntegerField()
channel_id: int = fields.IntegerField()
type: int = fields.IntegerField()
avatar: Image = fields.ReferenceField(Image)
name: str = fields.StringField()
@JARVIS_INST.register
class ChannelBackup(EmbeddedDocument):
id: int = fields.IntegerField()
name: str = fields.StringField()
type: int = fields.IntegerField()
position: int = fields.IntegerField()
topic: Optional[str] = fields.StringField(default=None)
nsfw: bool = fields.BooleanField(default=False)
rate_limit_per_user: int = fields.IntegerField(default=None)
bitrate: Optional[int] = fields.IntegerField(default=None)
user_limit: Optional[int] = fields.IntegerField(default=None)
permission_overwrites: List[PermissionOverwriteBackup] = fields.ListField(
fields.EmbeddedField(PermissionOverwriteBackup), factory=list
)
parent_id: Optional[int] = fields.IntegerField(default=None)
rtc_region: Optional[str] = fields.StringField(default=None)
video_quality_mode: Optional[int] = fields.IntegerField(default=None)
default_auto_archive_duration: Optional[int] = fields.IntegerField(default=None)
webhooks: List[WebhookBackup] = fields.ListField(
fields.EmbeddedField(WebhookBackup), factory=list
)
@JARVIS_INST.register
class RoleBackup(EmbeddedDocument):
id: int = fields.IntegerField()
name: str = fields.StringField()
permissions: int = fields.IntegerField()
color: str = fields.StringField()
hoist: bool = fields.BooleanField()
mentionable: bool = fields.BooleanField()
@JARVIS_INST.register
class EmojiBackup(EmbeddedDocument):
id: int = fields.IntegerField()
name: str = fields.StringField()
image: Image = fields.ReferenceField(Image)
@JARVIS_INST.register
class StickerBackup(EmbeddedDocument):
id: int = fields.IntegerField()
name: str = fields.StringField()
format_type: int = fields.IntegerField()
tags: str = fields.StringField()
type: int = fields.IntegerField()
image: Image = fields.ReferenceField(Image)
@JARVIS_INST.register
class GuildBackup(EmbeddedDocument):
name: str = fields.StringField(required=True)
description: str = fields.StringField(default=None)
default_message_notifications: Optional[int] = fields.IntegerField(default=None)
explicit_content_filter: Optional[int] = fields.IntegerField(default=None)
afk_channel: Optional[int] = fields.IntegerField(default=None)
afk_timeout: Optional[int] = fields.IntegerField(default=None)
icon: Optional[Image] = fields.ReferenceField(Image, default=None)
owner: int = fields.IntegerField(required=True)
splash: Optional[Image] = fields.ReferenceField(Image, default=None)
discovery_splash: Optional[Image] = fields.ReferenceField(Image, default=None)
banner: Optional[Image] = fields.ReferenceField(Image, default=None)
system_channel: Optional[int] = fields.IntegerField(default=None)
system_channel_flags: Optional[int] = fields.IntegerField(default=None)
rules_channel: Optional[int] = fields.IntegerField(default=None)
public_updates_channel: Optional[int] = fields.IntegerField(default=None)
preferred_locale: Optional[str] = fields.StringField(default=None)
features: List[str] = fields.ListField(fields.StringField, factory=list)
channels: List[ChannelBackup] = fields.ListField(
fields.EmbeddedField(ChannelBackup), factory=list
)
roles: List[RoleBackup] = fields.ListField(fields.EmbeddedField(RoleBackup), factory=list)
emojis: List[EmojiBackup] = fields.ListField(fields.EmbeddedField(EmojiBackup), factory=list)
stickers: List[StickerBackup] = fields.ListField(
fields.EmbeddedField(StickerBackup), factory=list
)
@JARVIS_INST.register
class Backup(Document):
created_at: datetime = fields.AwareDateTimeField(default=get_now)
guild_id: int = fields.IntegerField()
bkid: str = fields.StringField(default=get_id)
guild: GuildBackup = fields.EmbeddedField(GuildBackup, required=True)
version: str = fields.StringField(default=__version__)

View file

@ -0,0 +1,36 @@
"""Mastodon databaes models."""
from datetime import datetime, timezone
from umongo import Document, fields
from jarvis_core.db import JARVIS_INST
from jarvis_core.db.utils import get_now
@JARVIS_INST.register
class MastodonUser(Document):
"""User object."""
user_id: int = fields.IntegerField(required=True)
acct: str = fields.StringField(required=True)
username: str = fields.StringField(required=True)
last_sync: datetime = fields.AwareDateTimeField(default=get_now)
class Meta:
collection_name = "mastodonuser"
@JARVIS_INST.register
class MastodonFollow(Document):
"""User Follow object."""
active: bool = fields.BooleanField(default=True)
user_id: int = fields.IntegerField(required=True)
channel: int = fields.IntegerField(required=True)
guild: int = fields.IntegerField(required=True)
reblogged: bool = fields.BooleanField(default=True)
admin: int = fields.IntegerField(required=True)
created_at: datetime = fields.AwareDateTimeField(default=get_now)
class Meta:
collection_name = "mastodonfollow"

View file

@ -2,7 +2,6 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial from functools import partial
import nanoid
from bson import ObjectId from bson import ObjectId
from beanie import Document from beanie import Document
from pydantic import BaseModel, Field from pydantic import BaseModel, Field

15
jarvis_core/db/utils.py Normal file
View file

@ -0,0 +1,15 @@
from datetime import datetime, timezone
import nanoid
NANOID_ALPHA = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
def get_now() -> datetime:
"""Get proper timestamp."""
return datetime.now(tz=timezone.utc)
def get_id() -> str:
"""Get nanoid."""
return nanoid.generate(NANOID_ALPHA, 12)

1410
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,19 +1,19 @@
[tool.poetry] [tool.poetry]
name = "jarvis-core" name = "jarvis-core"
version = "0.10.2" version = "0.17.0"
description = "JARVIS core" description = "JARVIS core"
authors = ["Zevaryx <zevaryx@gmail.com>"] authors = ["Zevaryx <zevaryx@gmail.com>"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.10" python = "^3.10"
orjson = {version = "^3.6.6"} orjson = {version = "^3.6.6"}
motor = "^2.5.1" motor = "^3.1.1"
PyYAML = {version = "^6.0"} PyYAML = {version = "^6.0"}
pytz = "^2022.1" pytz = "^2022.1"
aiohttp = "^3.8.1" aiohttp = "^3.8.1"
rich = "^12.3.0" rich = "^12.3.0"
nanoid = "^2.0.0" nanoid = "^2.0.0"
python-dotenv = {version = "^0.20.0"} python-dotenv = "^0.21.0"
beanie = "^1.17.0" beanie = "^1.17.0"
pydantic = "^1.10.7" pydantic = "^1.10.7"
@ -22,6 +22,8 @@ pytest = "^7.1"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
black = "^23.1.0" black = "^23.1.0"
ipython = "^8.5.0"
rich = "^12.6.0"
[build-system] [build-system]
requires = ["poetry-core>=1.0.0"] requires = ["poetry-core>=1.0.0"]