[feat] fetch all payments for user (#3132)

This commit is contained in:
Vlad Stan 2025-04-29 13:52:07 +03:00 committed by GitHub
parent 2dee26b728
commit e339bb6181
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 87 additions and 5 deletions

View file

@ -1,7 +1,7 @@
from time import time from time import time
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
from lnbits.core.crud.wallets import get_total_balance, get_wallet from lnbits.core.crud.wallets import get_total_balance, get_wallet, get_wallets_ids
from lnbits.core.db import db from lnbits.core.db import db
from lnbits.core.models import PaymentState from lnbits.core.models import PaymentState
from lnbits.db import Connection, DateTrunc, Filters, Page from lnbits.db import Connection, DateTrunc, Filters, Page
@ -95,6 +95,7 @@ async def get_latest_payments_by_extension(
async def get_payments_paginated( async def get_payments_paginated(
*, *,
wallet_id: Optional[str] = None, wallet_id: Optional[str] = None,
user_id: Optional[str] = None,
complete: bool = False, complete: bool = False,
pending: bool = False, pending: bool = False,
failed: bool = False, failed: bool = False,
@ -121,6 +122,13 @@ async def get_payments_paginated(
if wallet_id: if wallet_id:
values["wallet_id"] = wallet_id values["wallet_id"] = wallet_id
clause.append("wallet_id = :wallet_id") clause.append("wallet_id = :wallet_id")
elif user_id:
wallet_ids = await get_wallets_ids(user_id=user_id, conn=conn) or [
"no-wallets-for-user"
]
# wallet ids are safe to use in sql queries
wallet_ids_str = [f"'{w}'" for w in wallet_ids]
clause.append(f""" wallet_id IN ({", ".join(wallet_ids_str)}) """)
if complete and pending: if complete and pending:
clause.append( clause.append(

View file

@ -135,6 +135,20 @@ async def get_wallets(
) )
async def get_wallets_ids(
user_id: str, deleted: Optional[bool] = None, conn: Optional[Connection] = None
) -> list[str]:
where = "AND deleted = :deleted" if deleted is not None else ""
result: list[dict] = await (conn or db).fetchall(
f"""
SELECT id FROM wallets
WHERE "user" = :user {where}
""",
{"user": user_id, "deleted": deleted},
)
return [row["id"] for row in result]
async def get_wallets_count(): async def get_wallets_count():
result = await db.execute("SELECT COUNT(*) as count FROM wallets") result = await db.execute("SELECT COUNT(*) as count FROM wallets")
row = result.mappings().first() row = result.mappings().first()

View file

@ -264,13 +264,21 @@ async def _api_payments_create_invoice(data: CreateInvoice, wallet: Wallet):
response_description="list of payments", response_description="list of payments",
response_model=Page[Payment], response_model=Page[Payment],
openapi_extra=generate_filter_params_openapi(PaymentFilters), openapi_extra=generate_filter_params_openapi(PaymentFilters),
dependencies=[Depends(check_admin)],
) )
async def api_all_payments_paginated( async def api_all_payments_paginated(
filters: Filters = Depends(parse_filters(PaymentFilters)), filters: Filters = Depends(parse_filters(PaymentFilters)),
user: User = Depends(check_user_exists),
): ):
if user.admin:
# admin user can see payments from all wallets
for_user_id = None
else:
# regular user can only see payments from their wallets
for_user_id = user.id
return await get_payments_paginated( return await get_payments_paginated(
filters=filters, filters=filters,
user_id=for_user_id,
) )

View file

@ -8,10 +8,10 @@ from bolt11 import encode as bolt11_encode
from bolt11.types import MilliSatoshi from bolt11.types import MilliSatoshi
from pytest_mock.plugin import MockerFixture from pytest_mock.plugin import MockerFixture
from lnbits.core.crud import get_standalone_payment, get_wallet from lnbits.core.crud import create_wallet, get_standalone_payment, get_wallet
from lnbits.core.crud.payments import get_payment from lnbits.core.crud.payments import get_payment, get_payments_paginated
from lnbits.core.models import Payment, PaymentState, Wallet from lnbits.core.models import Payment, PaymentState, Wallet
from lnbits.core.services import create_invoice, pay_invoice from lnbits.core.services import create_invoice, create_user_account, pay_invoice
from lnbits.exceptions import InvoiceError, PaymentError from lnbits.exceptions import InvoiceError, PaymentError
from lnbits.settings import Settings from lnbits.settings import Settings
from lnbits.tasks import ( from lnbits.tasks import (
@ -596,3 +596,55 @@ async def test_service_fee(
assert service_fee_payment.amount == 422_400 assert service_fee_payment.amount == 422_400
assert service_fee_payment.bolt11 == external_invoice.payment_request assert service_fee_payment.bolt11 == external_invoice.payment_request
assert service_fee_payment.preimage is None assert service_fee_payment.preimage is None
@pytest.mark.anyio
async def test_get_payments_for_user(to_wallet: Wallet):
all_payments = await get_payments_paginated()
total_before = all_payments.total
user = await create_user_account()
wallet_one = await create_wallet(user_id=user.id, wallet_name="first wallet")
wallet_two = await create_wallet(user_id=user.id, wallet_name="second wallet")
user_payments = await get_payments_paginated(user_id=user.id)
assert user_payments.total == 0
payment = await create_invoice(wallet_id=wallet_one.id, amount=100, memo="one")
user_payments = await get_payments_paginated(user_id=user.id)
assert user_payments.total == 1
# this will create a payment in the to_wallet that we need to count for at the end
await pay_invoice(
wallet_id=to_wallet.id,
payment_request=payment.bolt11,
)
user_payments = await get_payments_paginated(user_id=user.id)
assert user_payments.total == 1
payment = await create_invoice(wallet_id=wallet_one.id, amount=3, memo="two")
user_payments = await get_payments_paginated(user_id=user.id)
assert user_payments.total == 2
payment = await create_invoice(wallet_id=wallet_two.id, amount=3, memo="three")
user_payments = await get_payments_paginated(user_id=user.id)
assert user_payments.total == 3
await pay_invoice(
wallet_id=wallet_one.id,
payment_request=payment.bolt11,
)
user_payments = await get_payments_paginated(user_id=user.id)
assert user_payments.total == 4
all_payments = await get_payments_paginated()
total_after = all_payments.total
assert total_after == total_before + 5, "Total payments should be updated."
@pytest.mark.anyio
async def test_get_payments_for_non_user():
user_payments = await get_payments_paginated(user_id="nonexistent")
assert (
user_payments.total == 0
), "No payments should be found for non-existent user."