72 lines
2.5 KiB
Python
72 lines
2.5 KiB
Python
import types
|
|
import typing
|
|
from datetime import datetime, timezone
|
|
|
|
import pytest
|
|
from beanie import Document, init_beanie
|
|
from mongomock_motor import AsyncMongoMockClient
|
|
from pydantic import BaseModel
|
|
from pydantic.fields import FieldInfo
|
|
|
|
from jarvis_core.db.models import Pin, all_models
|
|
from jarvis_core.db.utils import Snowflake
|
|
|
|
# 2**64 - 1: the largest unsigned 64-bit integer. get_default() uses it as the
# placeholder value for every Snowflake-typed field.
MAX_SNOWFLAKE = 2**64 - 1
|
|
|
|
|
|
async def get_default(annotation: type):
    """Build a placeholder value that satisfies ``annotation``.

    Used to fill required model fields with something valid so the model can
    be instantiated and saved in tests.

    Resolution order:

    * ``Snowflake`` -> ``MAX_SNOWFLAKE``.
    * ``X | Y`` / ``Optional[X]`` / ``Union[X, ...]`` -> the default for the
      first member, resolved recursively so members such as ``datetime``,
      ``Snowflake`` or a pydantic model get the same special handling as a
      top-level annotation.
    * pydantic ``BaseModel`` subclass -> an instance with every required
      field filled in recursively.
    * ``datetime`` -> the current time as a UTC-aware datetime.
    * anything else -> ``annotation()`` (e.g. ``0`` for ``int``, ``""`` for
      ``str``, ``[]`` for ``list[str]``).

    Args:
        annotation: The field's type annotation.

    Returns:
        A placeholder value for ``annotation``.
    """
    if annotation is Snowflake:
        return MAX_SNOWFLAKE

    origin = typing.get_origin(annotation)

    # get_origin() is the public way to recognise both typing.Union/Optional
    # and PEP 604 ``X | Y`` unions. Recurse on the first member instead of
    # calling it blindly, so e.g. Optional[datetime] doesn't call datetime().
    if origin is typing.Union or origin is types.UnionType:
        return await get_default(typing.get_args(annotation)[0])

    # issubclass() raises TypeError for non-class annotations such as
    # ``list[str]``. Parameterised generics are excluded via ``origin`` because
    # some Python versions report isinstance(list[str], type) as True.
    if origin is None and isinstance(annotation, type) and issubclass(annotation, BaseModel):
        data = {
            name: await get_default(info.annotation)
            for name, info in annotation.model_fields.items()
            if info.is_required()
        }
        return annotation(**data)

    if annotation is datetime:
        return datetime.now(tz=timezone.utc)

    return annotation()
|
|
|
|
|
|
async def create_data_dict(model_fields: dict[str, FieldInfo]):
    """Return keyword arguments that fill every required field in ``model_fields``.

    A required field whose annotation is a ``typing`` generic alias whose first
    type argument is one of ``all_models`` (presumably a Beanie ``Link[...]`` —
    confirm against the models) receives a real linked document. That document
    is built recursively, saved, and then fetched back via ``find_one``. Every
    other required field gets a placeholder from ``get_default``. Fields that
    are not required are omitted.
    """
    kwargs = {}
    for field_name, field_info in model_fields.items():
        if not field_info.is_required():
            continue

        annotation = field_info.annotation
        linked_model = None
        if type(annotation) is typing._GenericAlias:
            first_arg = annotation.__args__[0]
            if first_arg in all_models:
                linked_model = first_arg

        if linked_model is None:
            kwargs[field_name] = await get_default(annotation)
            continue

        # Persist the linked document first so the field references an
        # instance that actually exists in the database.
        child_kwargs = await create_data_dict(linked_model.model_fields)
        child = linked_model(**child_kwargs)
        await child.save()
        kwargs[field_name] = await linked_model.find_one(linked_model.id == child.id)

    return kwargs
|
|
|
|
|
|
@pytest.fixture(autouse=True)
async def my_fixture():
    """Initialise Beanie with all models on a new mongomock client before each test.

    The client is timezone-aware (UTC), so datetimes read back from the mock
    database carry tzinfo.
    """
    mock_client = AsyncMongoMockClient(tz_aware=True, tzinfo=timezone.utc)
    database = mock_client.get_database(name="test_models")
    await init_beanie(document_models=all_models, database=database)
|
|
|
|
|
|
async def test_models():
    """Round-trip a placeholder instance of every model through the mock database.

    For each model in ``all_models``:
    1. Build a document from ``create_data_dict``.
    2. Save it.
    3. Read it back by its id.
    4. Compare each populated field with the value that was written.
    """
    for model in all_models:
        data = await create_data_dict(model.model_fields)
        document = model(**data)
        await document.save()

        # Look the document up by id. create_data_dict may already have saved
        # other documents of this model (as linked references for earlier
        # models), so an unfiltered find_one() could return one of those.
        saved = await model.find_one(model.id == document.id)

        if model is Pin:
            # Pin is saved and fetched above, but its fields are not compared:
            # per the original author it works, yet can't be tested this way.
            continue

        assert saved is not None

        for key, value in data.items():
            saved_value = getattr(saved, key)

            if isinstance(saved_value, datetime):
                # mongomock rounds stored datetimes, so exact equality fails and
                # microseconds don't matter here. Compare with a sub-second
                # tolerance instead of truncating both to whole seconds, which
                # fails when the two values straddle a second boundary.
                assert abs(saved_value.timestamp() - value.timestamp()) < 1
                continue

            assert value == saved_value
|