diff --git a/lnbits/db.py b/lnbits/db.py index 8703adf6..b1434a6f 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -8,7 +8,7 @@ import time from contextlib import asynccontextmanager from datetime import datetime, timezone from enum import Enum -from typing import Any, Generic, Literal, Optional, TypeVar, Union +from typing import Any, Generic, Literal, Optional, TypeVar, Union, get_origin from loguru import logger from pydantic import BaseModel, ValidationError, root_validator @@ -605,12 +605,17 @@ def model_to_dict(model: BaseModel) -> dict: _dict: dict = {} for key, value in model.dict().items(): type_ = model.__fields__[key].type_ + outertype_ = model.__fields__[key].outer_type_ if model.__fields__[key].field_info.extra.get("no_database", False): continue if isinstance(value, datetime): _dict[key] = value.timestamp() continue - if type(type_) is type(BaseModel) or type_ is dict: + if ( + type(type_) is type(BaseModel) + or type_ is dict + or get_origin(outertype_) is list + ): _dict[key] = json.dumps(value) continue _dict[key] = value @@ -645,10 +650,12 @@ def dict_to_model(_row: dict, model: type[TModel]) -> TModel: logger.warning(f"Converting {key} to model `{model}`.") continue type_ = model.__fields__[key].type_ - if isinstance(value, list): + outertype_ = model.__fields__[key].outer_type_ + if get_origin(outertype_) is list: + _items = json.loads(value) if isinstance(value, str) else value _dict[key] = [ dict_to_submodel(type_, v) if issubclass(type_, BaseModel) else v - for v in value + for v in _items ] continue if issubclass(type_, bool): diff --git a/tests/helpers.py b/tests/helpers.py index 4a401fd2..49c1d680 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -22,6 +22,7 @@ class DbTestModel2(BaseModel): label: str description: Optional[str] = None child: DbTestModel + child_list: list[DbTestModel] class DbTestModel3(BaseModel): @@ -29,6 +30,8 @@ class DbTestModel3(BaseModel): user: str child: DbTestModel2 active: bool = False + children: list[DbTestModel] + children_ids: list[int] = [] def get_random_string(iterations: int = 10): diff --git a/tests/unit/test_helpers_query.py b/tests/unit/test_helpers_query.py index abea0803..81af0be6 100644 --- a/tests/unit/test_helpers_query.py +++ b/tests/unit/test_helpers_query.py @@ -18,7 +18,10 @@ test_data = DbTestModel3( label="test", description="mydesc", child=DbTestModel(id=3, name="myname", value="myvalue"), + child_list=[DbTestModel(id=6, name="myname", value="myvalue")], ), + children=[DbTestModel(id=4, name="myname", value="myvalue")], + children_ids=[4, 1, 3], active=True, ) @@ -27,8 +30,9 @@ test_data = DbTestModel3( async def test_helpers_insert_query(): q = insert_query("test_helpers_query", test_data) assert q == ( - """INSERT INTO test_helpers_query ("id", "user", "child", "active") """ - "VALUES (:id, :user, :child, :active)" + "INSERT INTO test_helpers_query " + """("id", "user", "child", "active", "children", "children_ids") """ + "VALUES (:id, :user, :child, :active, :children, :children_ids)" ) @@ -37,7 +41,8 @@ async def test_helpers_update_query(): q = update_query("test_helpers_query", test_data) assert q == ( """UPDATE test_helpers_query SET "id" = :id, "user" = """ - """:user, "child" = :child, "active" = :active WHERE id = :id""" + """:user, "child" = :child, "active" = :active, "children" = """ + """:children, "children_ids" = :children_ids WHERE id = :id""" ) @@ -47,9 +52,17 @@ child_json = json.dumps( "label": "test", "description": "mydesc", "child": {"id": 3, "name": "myname", "value": "myvalue"}, + "child_list": [{"id": 6, "name": "myname", "value": "myvalue"}], } ) -test_dict = {"id": 1, "user": "userid", "child": child_json, "active": True} +test_dict = { + "id": 1, + "user": "userid", + "child": child_json, + "active": True, + "children": '[{"id": 4, "name": "myname", "value": "myvalue"}]', + "children_ids": "[4, 1, 3]", +} @pytest.mark.asyncio