diff --git a/lnbits/db.py b/lnbits/db.py index 91cfb56f..6b25f4b5 100644 --- a/lnbits/db.py +++ b/lnbits/db.py @@ -554,9 +554,13 @@ class Filters(BaseModel, Generic[TFilterModel]): for page_filter in self.filters: where_stmts.append(page_filter.statement) if self.search and self.model and self.model.__search_fields__: - where_stmts.append( - f"lower(concat({', '.join(self.model.__search_fields__)})) LIKE :search" + # Use `COALESCE` to handle `NULL` values and `||` + # for cross-database compatible string concatenation + _fields = self.model.__search_fields__ + search_expr = " || ".join( + f"COALESCE(CAST({field} AS TEXT), '')" for field in _fields ) + where_stmts.append(f"lower({search_expr}) LIKE :search") if where_stmts: return "WHERE " + " AND ".join(where_stmts) @@ -576,7 +580,7 @@ class Filters(BaseModel, Generic[TFilterModel]): for key, value in page_filter.values.items(): values[key] = value if self.search and self.model: - values["search"] = f"%{self.search}%" + values["search"] = f"%{self.search.lower()}%" return values diff --git a/tests/unit/test_crud_payments.py b/tests/unit/test_crud_payments.py index b177ed73..c5be2410 100644 --- a/tests/unit/test_crud_payments.py +++ b/tests/unit/test_crud_payments.py @@ -1,8 +1,18 @@ import pytest -from lnbits.core.crud import create_wallet, get_payments, update_payment -from lnbits.core.models import PaymentState -from lnbits.core.services import create_user_account, update_wallet_balance +from lnbits.core.crud import ( + create_wallet, + get_payments, + get_payments_paginated, + update_payment, +) +from lnbits.core.models import PaymentFilters, PaymentState +from lnbits.core.services import ( + create_invoice, + create_user_account, + update_wallet_balance, +) +from lnbits.db import Filters async def update_payments(payments): @@ -64,3 +74,62 @@ async def test_crud_get_payments(app): # both false should return failed payments # payments = await get_payments(wallet_id=wallet.id, complete=False, pending=False) # assert len(payments) == 2, "should return 2 failed payment" + + +@pytest.mark.anyio +async def test_crud_search_payments(): + + user = await create_user_account() + wallet = await create_wallet(user_id=user.id) + filters: Filters = Filters( + search="", + model=PaymentFilters, + ) + # no memo + await create_invoice(wallet_id=wallet.id, amount=30, memo="") + await create_invoice(wallet_id=wallet.id, amount=30, memo="Invoice A") + filters.search = "Invoice A" + page = await get_payments_paginated( + wallet_id=wallet.id, + filters=filters, + ) + assert page.total == 1, "should return only Invoice A" + + filters.search = "Invoice B" + page = await get_payments_paginated( + wallet_id=wallet.id, + filters=filters, + ) + assert page.total == 0, "no Invoice B yet" + + for i in range(15): + await create_invoice(wallet_id=wallet.id, amount=30 + i, memo="Invoice A") + await create_invoice(wallet_id=wallet.id, amount=30 + i, memo="Invoice B") + + filters.search = None + page = await get_payments_paginated( + wallet_id=wallet.id, + filters=filters, + ) + assert page.total == 32, "should return all payments" + + filters.search = "Invoice A" + page = await get_payments_paginated( + wallet_id=wallet.id, + filters=filters, + ) + assert page.total == 16 + + filters.search = "Invoice B" + page = await get_payments_paginated( + wallet_id=wallet.id, + filters=filters, + ) + assert page.total == 15 + + filters.search = "Invoice" + page = await get_payments_paginated( + wallet_id=wallet.id, + filters=filters, + ) + assert page.total == 31