From e339bb6181789748861cca50d708289009d7e2f8 Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Tue, 29 Apr 2025 13:52:07 +0300 Subject: [PATCH] [feat] fetch all payments for user (#3132) --- lnbits/core/crud/payments.py | 10 +++++- lnbits/core/crud/wallets.py | 14 ++++++++ lnbits/core/views/payment_api.py | 10 +++++- tests/unit/test_pay_invoice.py | 58 ++++++++++++++++++++++++++++++-- 4 files changed, 87 insertions(+), 5 deletions(-) diff --git a/lnbits/core/crud/payments.py b/lnbits/core/crud/payments.py index 0b6e61fa..79a01f6b 100644 --- a/lnbits/core/crud/payments.py +++ b/lnbits/core/crud/payments.py @@ -1,7 +1,7 @@ from time import time from typing import Any, Optional, Tuple -from lnbits.core.crud.wallets import get_total_balance, get_wallet +from lnbits.core.crud.wallets import get_total_balance, get_wallet, get_wallets_ids from lnbits.core.db import db from lnbits.core.models import PaymentState from lnbits.db import Connection, DateTrunc, Filters, Page @@ -95,6 +95,7 @@ async def get_latest_payments_by_extension( async def get_payments_paginated( *, wallet_id: Optional[str] = None, + user_id: Optional[str] = None, complete: bool = False, pending: bool = False, failed: bool = False, @@ -121,6 +122,13 @@ async def get_payments_paginated( if wallet_id: values["wallet_id"] = wallet_id clause.append("wallet_id = :wallet_id") + elif user_id: + wallet_ids = await get_wallets_ids(user_id=user_id, conn=conn) or [ + "no-wallets-for-user" + ] + # wallet ids are safe to use in sql queries + wallet_ids_str = [f"'{w}'" for w in wallet_ids] + clause.append(f""" wallet_id IN ({", ".join(wallet_ids_str)}) """) if complete and pending: clause.append( diff --git a/lnbits/core/crud/wallets.py b/lnbits/core/crud/wallets.py index 68fbda8a..5ac50fb2 100644 --- a/lnbits/core/crud/wallets.py +++ b/lnbits/core/crud/wallets.py @@ -135,6 +135,20 @@ async def get_wallets( ) +async def get_wallets_ids( + user_id: str, deleted: Optional[bool] = None, conn: Optional[Connection] = None +) -> list[str]: + where = "AND deleted = :deleted" if deleted is not None else "" + result: list[dict] = await (conn or db).fetchall( + f""" + SELECT id FROM wallets + WHERE "user" = :user {where} + """, + {"user": user_id, "deleted": deleted}, + ) + return [row["id"] for row in result] + + async def get_wallets_count(): result = await db.execute("SELECT COUNT(*) as count FROM wallets") row = result.mappings().first() diff --git a/lnbits/core/views/payment_api.py b/lnbits/core/views/payment_api.py index 18ee97a3..1bff6d76 100644 --- a/lnbits/core/views/payment_api.py +++ b/lnbits/core/views/payment_api.py @@ -264,13 +264,21 @@ async def _api_payments_create_invoice(data: CreateInvoice, wallet: Wallet): response_description="list of payments", response_model=Page[Payment], openapi_extra=generate_filter_params_openapi(PaymentFilters), - dependencies=[Depends(check_admin)], ) async def api_all_payments_paginated( filters: Filters = Depends(parse_filters(PaymentFilters)), + user: User = Depends(check_user_exists), ): + if user.admin: + # admin user can see payments from all wallets + for_user_id = None + else: + # regular user can only see payments from their wallets + for_user_id = user.id + return await get_payments_paginated( filters=filters, + user_id=for_user_id, ) diff --git a/tests/unit/test_pay_invoice.py b/tests/unit/test_pay_invoice.py index 774824cf..387b7987 100644 --- a/tests/unit/test_pay_invoice.py +++ b/tests/unit/test_pay_invoice.py @@ -8,10 +8,10 @@ from bolt11 import encode as bolt11_encode from bolt11.types import MilliSatoshi from pytest_mock.plugin import MockerFixture -from lnbits.core.crud import get_standalone_payment, get_wallet -from lnbits.core.crud.payments import get_payment +from lnbits.core.crud import create_wallet, get_standalone_payment, get_wallet +from lnbits.core.crud.payments import get_payment, get_payments_paginated from lnbits.core.models import Payment, PaymentState, Wallet -from lnbits.core.services import create_invoice, pay_invoice +from lnbits.core.services import create_invoice, create_user_account, pay_invoice from lnbits.exceptions import InvoiceError, PaymentError from lnbits.settings import Settings from lnbits.tasks import ( @@ -596,3 +596,55 @@ async def test_service_fee( assert service_fee_payment.amount == 422_400 assert service_fee_payment.bolt11 == external_invoice.payment_request assert service_fee_payment.preimage is None + + +@pytest.mark.anyio +async def test_get_payments_for_user(to_wallet: Wallet): + all_payments = await get_payments_paginated() + total_before = all_payments.total + + user = await create_user_account() + wallet_one = await create_wallet(user_id=user.id, wallet_name="first wallet") + wallet_two = await create_wallet(user_id=user.id, wallet_name="second wallet") + + user_payments = await get_payments_paginated(user_id=user.id) + assert user_payments.total == 0 + + payment = await create_invoice(wallet_id=wallet_one.id, amount=100, memo="one") + user_payments = await get_payments_paginated(user_id=user.id) + assert user_payments.total == 1 + # this will create a payment in the to_wallet that we need to count for at the end + await pay_invoice( + wallet_id=to_wallet.id, + payment_request=payment.bolt11, + ) + user_payments = await get_payments_paginated(user_id=user.id) + assert user_payments.total == 1 + + payment = await create_invoice(wallet_id=wallet_one.id, amount=3, memo="two") + user_payments = await get_payments_paginated(user_id=user.id) + assert user_payments.total == 2 + + payment = await create_invoice(wallet_id=wallet_two.id, amount=3, memo="three") + user_payments = await get_payments_paginated(user_id=user.id) + assert user_payments.total == 3 + + await pay_invoice( + wallet_id=wallet_one.id, + payment_request=payment.bolt11, + ) + user_payments = await get_payments_paginated(user_id=user.id) + assert user_payments.total == 4 + + all_payments = await get_payments_paginated() + total_after = all_payments.total + + assert total_after == total_before + 5, "Total payments should be updated." + + +@pytest.mark.anyio +async def test_get_payments_for_non_user(): + user_payments = await get_payments_paginated(user_id="nonexistent") + assert ( + user_payments.total == 0 + ), "No payments should be found for non-existent user."