import types
import typing
from datetime import datetime, timezone

import pytest
from beanie import Document, init_beanie
from mongomock_motor import AsyncMongoMockClient
from pydantic import BaseModel
from pydantic.fields import FieldInfo

from jarvis_core.db.models import Pin, all_models
from jarvis_core.db.utils import Snowflake

# Largest unsigned 64-bit integer (2**64 - 1); used as the placeholder value
# for every Snowflake-typed field.
MAX_SNOWFLAKE = 18446744073709551615


def _is_model_class(annotation) -> bool:
    """Return True if ``annotation`` is a pydantic ``BaseModel`` subclass."""
    try:
        return issubclass(annotation, BaseModel)
    except TypeError:
        # issubclass() rejects non-class annotations such as list[str].
        return False


async def get_default(annotation: type):
    """Build a placeholder value for a field annotated with ``annotation``.

    - ``Snowflake`` -> ``MAX_SNOWFLAKE``
    - ``Union[X, ...]`` / ``Optional[X]`` / ``X | Y`` -> placeholder for the
      first member ``X`` (resolved recursively)
    - pydantic ``BaseModel`` subclass -> an instance with every required
      field filled in recursively
    - ``datetime`` -> the current time as a timezone-aware UTC datetime
    - anything else -> ``annotation()`` called with no arguments
    """
    if annotation is Snowflake:
        return MAX_SNOWFLAKE
    if typing.get_origin(annotation) in (typing.Union, types.UnionType):
        # Recurse so that members which cannot be built with a bare no-arg
        # call (datetime, Snowflake, nested models) still get a valid value.
        return await get_default(typing.get_args(annotation)[0])
    if _is_model_class(annotation):
        data = {}
        for name, info in annotation.model_fields.items():
            if info.is_required():
                data[name] = await get_default(info.annotation)
        return annotation(**data)
    if annotation is datetime:
        return datetime.now(tz=timezone.utc)
    return annotation()


async def create_data_dict(model_fields: dict[str, FieldInfo]):
    """Build constructor kwargs covering every required field in ``model_fields``.

    A field whose annotation is a plain ``typing._GenericAlias`` with a first
    type argument in ``all_models`` (presumably ``Link[SomeDocument]`` --
    confirm) gets a real instance of that model. The instance is saved and
    then re-read from the database. Every other required field gets a
    placeholder from ``get_default``. Fields that are not required are
    omitted.
    """
    data = {}
    for name, info in model_fields.items():
        if not info.is_required():
            continue
        annotation = info.annotation
        if (
            type(annotation) is typing._GenericAlias
            and (link := annotation.__args__[0]) in all_models
        ):
            reference = await create_data_dict(link.model_fields)
            nested = link(**reference)
            await nested.save()
            data[name] = await link.find_one(link.id == nested.id)
        else:
            data[name] = await get_default(annotation)
    return data


@pytest.fixture(autouse=True)
async def my_fixture():
    """Initialise Beanie against a new AsyncMongoMockClient before each test."""
    client = AsyncMongoMockClient(tz_aware=True, tzinfo=timezone.utc)
    await init_beanie(
        document_models=all_models,
        database=client.get_database(name="test_models"),
    )


async def test_models():
    """Round-trip every model in ``all_models`` through the mock database.

    Each model is built from ``create_data_dict``, saved, and read back by
    its id. Every field that was written is then compared with the stored
    value.
    """
    for model in all_models:
        data = await create_data_dict(model.model_fields)
        document = model(**data)
        await document.save()
        # Read back by id: the collection may already hold other documents,
        # e.g. ones inserted as links while building an earlier model's data.
        saved = await model.get(document.id)
        assert saved is not None, f"{model.__name__} was not persisted"

        if model is Pin:
            # Pin is saved and read back above, but its fields are not
            # compared: per the original author the comparison is broken for
            # Pin even though the model itself works.
            continue

        for key, value in data.items():
            saved_value = getattr(saved, key)
            if isinstance(saved_value, datetime):
                # Don't care about sub-second precision here: mongomock tends
                # to round stored datetimes. A tolerance avoids the flakiness
                # of truncating both sides to whole seconds, which fails when
                # rounding crosses a second boundary.
                saved_ts = saved_value.astimezone(timezone.utc).timestamp()
                assert abs(value.timestamp() - saved_ts) < 1, f"{model.__name__}.{key}"
            else:
                assert value == saved_value, f"{model.__name__}.{key}"