fix: add list handling (#2758)

This commit is contained in:
dni ⚡ 2024-11-01 10:12:18 +01:00 committed by GitHub
parent 2fa0a3c995
commit acb1b1ed91
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 31 additions and 8 deletions

View file

@ -8,7 +8,7 @@ import time
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Generic, Literal, Optional, TypeVar, Union
from typing import Any, Generic, Literal, Optional, TypeVar, Union, get_origin
from loguru import logger
from pydantic import BaseModel, ValidationError, root_validator
@ -605,12 +605,17 @@ def model_to_dict(model: BaseModel) -> dict:
_dict: dict = {}
for key, value in model.dict().items():
type_ = model.__fields__[key].type_
outertype_ = model.__fields__[key].outer_type_
if model.__fields__[key].field_info.extra.get("no_database", False):
continue
if isinstance(value, datetime):
_dict[key] = value.timestamp()
continue
if type(type_) is type(BaseModel) or type_ is dict:
if (
type(type_) is type(BaseModel)
or type_ is dict
or get_origin(outertype_) is list
):
_dict[key] = json.dumps(value)
continue
_dict[key] = value
@ -645,10 +650,12 @@ def dict_to_model(_row: dict, model: type[TModel]) -> TModel:
logger.warning(f"Converting {key} to model `{model}`.")
continue
type_ = model.__fields__[key].type_
if isinstance(value, list):
outertype_ = model.__fields__[key].outer_type_
if get_origin(outertype_) is list:
_items = json.loads(value) if isinstance(value, str) else value
_dict[key] = [
dict_to_submodel(type_, v) if issubclass(type_, BaseModel) else v
for v in value
for v in _items
]
continue
if issubclass(type_, bool):

View file

@ -22,6 +22,7 @@ class DbTestModel2(BaseModel):
label: str
description: Optional[str] = None
child: DbTestModel
child_list: list[DbTestModel]
class DbTestModel3(BaseModel):
@ -29,6 +30,8 @@ class DbTestModel3(BaseModel):
user: str
child: DbTestModel2
active: bool = False
children: list[DbTestModel]
children_ids: list[int] = []
def get_random_string(iterations: int = 10):

View file

@ -18,7 +18,10 @@ test_data = DbTestModel3(
label="test",
description="mydesc",
child=DbTestModel(id=3, name="myname", value="myvalue"),
child_list=[DbTestModel(id=6, name="myname", value="myvalue")],
),
children=[DbTestModel(id=4, name="myname", value="myvalue")],
children_ids=[4, 1, 3],
active=True,
)
@ -27,8 +30,9 @@ test_data = DbTestModel3(
async def test_helpers_insert_query():
q = insert_query("test_helpers_query", test_data)
assert q == (
"""INSERT INTO test_helpers_query ("id", "user", "child", "active") """
"VALUES (:id, :user, :child, :active)"
"INSERT INTO test_helpers_query "
"""("id", "user", "child", "active", "children", "children_ids") """
"VALUES (:id, :user, :child, :active, :children, :children_ids)"
)
@ -37,7 +41,8 @@ async def test_helpers_update_query():
q = update_query("test_helpers_query", test_data)
assert q == (
"""UPDATE test_helpers_query SET "id" = :id, "user" = """
""":user, "child" = :child, "active" = :active WHERE id = :id"""
""":user, "child" = :child, "active" = :active, "children" = """
""":children, "children_ids" = :children_ids WHERE id = :id"""
)
@ -47,9 +52,17 @@ child_json = json.dumps(
"label": "test",
"description": "mydesc",
"child": {"id": 3, "name": "myname", "value": "myvalue"},
"child_list": [{"id": 6, "name": "myname", "value": "myvalue"}],
}
)
test_dict = {"id": 1, "user": "userid", "child": child_json, "active": True}
test_dict = {
"id": 1,
"user": "userid",
"child": child_json,
"active": True,
"children": '[{"id": 4, "name": "myname", "value": "myvalue"}]',
"children_ids": "[4, 1, 3]",
}
@pytest.mark.asyncio