[feat] fetch all payments for user (#3132)
This commit is contained in:
parent
2dee26b728
commit
e339bb6181
4 changed files with 87 additions and 5 deletions
|
|
@ -1,7 +1,7 @@
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any, Optional, Tuple
|
from typing import Any, Optional, Tuple
|
||||||
|
|
||||||
from lnbits.core.crud.wallets import get_total_balance, get_wallet
|
from lnbits.core.crud.wallets import get_total_balance, get_wallet, get_wallets_ids
|
||||||
from lnbits.core.db import db
|
from lnbits.core.db import db
|
||||||
from lnbits.core.models import PaymentState
|
from lnbits.core.models import PaymentState
|
||||||
from lnbits.db import Connection, DateTrunc, Filters, Page
|
from lnbits.db import Connection, DateTrunc, Filters, Page
|
||||||
|
|
@ -95,6 +95,7 @@ async def get_latest_payments_by_extension(
|
||||||
async def get_payments_paginated(
|
async def get_payments_paginated(
|
||||||
*,
|
*,
|
||||||
wallet_id: Optional[str] = None,
|
wallet_id: Optional[str] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
complete: bool = False,
|
complete: bool = False,
|
||||||
pending: bool = False,
|
pending: bool = False,
|
||||||
failed: bool = False,
|
failed: bool = False,
|
||||||
|
|
@ -121,6 +122,13 @@ async def get_payments_paginated(
|
||||||
if wallet_id:
|
if wallet_id:
|
||||||
values["wallet_id"] = wallet_id
|
values["wallet_id"] = wallet_id
|
||||||
clause.append("wallet_id = :wallet_id")
|
clause.append("wallet_id = :wallet_id")
|
||||||
|
elif user_id:
|
||||||
|
wallet_ids = await get_wallets_ids(user_id=user_id, conn=conn) or [
|
||||||
|
"no-wallets-for-user"
|
||||||
|
]
|
||||||
|
# wallet ids are safe to use in sql queries
|
||||||
|
wallet_ids_str = [f"'{w}'" for w in wallet_ids]
|
||||||
|
clause.append(f""" wallet_id IN ({", ".join(wallet_ids_str)}) """)
|
||||||
|
|
||||||
if complete and pending:
|
if complete and pending:
|
||||||
clause.append(
|
clause.append(
|
||||||
|
|
|
||||||
|
|
@ -135,6 +135,20 @@ async def get_wallets(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_wallets_ids(
|
||||||
|
user_id: str, deleted: Optional[bool] = None, conn: Optional[Connection] = None
|
||||||
|
) -> list[str]:
|
||||||
|
where = "AND deleted = :deleted" if deleted is not None else ""
|
||||||
|
result: list[dict] = await (conn or db).fetchall(
|
||||||
|
f"""
|
||||||
|
SELECT id FROM wallets
|
||||||
|
WHERE "user" = :user {where}
|
||||||
|
""",
|
||||||
|
{"user": user_id, "deleted": deleted},
|
||||||
|
)
|
||||||
|
return [row["id"] for row in result]
|
||||||
|
|
||||||
|
|
||||||
async def get_wallets_count():
|
async def get_wallets_count():
|
||||||
result = await db.execute("SELECT COUNT(*) as count FROM wallets")
|
result = await db.execute("SELECT COUNT(*) as count FROM wallets")
|
||||||
row = result.mappings().first()
|
row = result.mappings().first()
|
||||||
|
|
|
||||||
|
|
@ -264,13 +264,21 @@ async def _api_payments_create_invoice(data: CreateInvoice, wallet: Wallet):
|
||||||
response_description="list of payments",
|
response_description="list of payments",
|
||||||
response_model=Page[Payment],
|
response_model=Page[Payment],
|
||||||
openapi_extra=generate_filter_params_openapi(PaymentFilters),
|
openapi_extra=generate_filter_params_openapi(PaymentFilters),
|
||||||
dependencies=[Depends(check_admin)],
|
|
||||||
)
|
)
|
||||||
async def api_all_payments_paginated(
|
async def api_all_payments_paginated(
|
||||||
filters: Filters = Depends(parse_filters(PaymentFilters)),
|
filters: Filters = Depends(parse_filters(PaymentFilters)),
|
||||||
|
user: User = Depends(check_user_exists),
|
||||||
):
|
):
|
||||||
|
if user.admin:
|
||||||
|
# admin user can see payments from all wallets
|
||||||
|
for_user_id = None
|
||||||
|
else:
|
||||||
|
# regular user can only see payments from their wallets
|
||||||
|
for_user_id = user.id
|
||||||
|
|
||||||
return await get_payments_paginated(
|
return await get_payments_paginated(
|
||||||
filters=filters,
|
filters=filters,
|
||||||
|
user_id=for_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,10 +8,10 @@ from bolt11 import encode as bolt11_encode
|
||||||
from bolt11.types import MilliSatoshi
|
from bolt11.types import MilliSatoshi
|
||||||
from pytest_mock.plugin import MockerFixture
|
from pytest_mock.plugin import MockerFixture
|
||||||
|
|
||||||
from lnbits.core.crud import get_standalone_payment, get_wallet
|
from lnbits.core.crud import create_wallet, get_standalone_payment, get_wallet
|
||||||
from lnbits.core.crud.payments import get_payment
|
from lnbits.core.crud.payments import get_payment, get_payments_paginated
|
||||||
from lnbits.core.models import Payment, PaymentState, Wallet
|
from lnbits.core.models import Payment, PaymentState, Wallet
|
||||||
from lnbits.core.services import create_invoice, pay_invoice
|
from lnbits.core.services import create_invoice, create_user_account, pay_invoice
|
||||||
from lnbits.exceptions import InvoiceError, PaymentError
|
from lnbits.exceptions import InvoiceError, PaymentError
|
||||||
from lnbits.settings import Settings
|
from lnbits.settings import Settings
|
||||||
from lnbits.tasks import (
|
from lnbits.tasks import (
|
||||||
|
|
@ -596,3 +596,55 @@ async def test_service_fee(
|
||||||
assert service_fee_payment.amount == 422_400
|
assert service_fee_payment.amount == 422_400
|
||||||
assert service_fee_payment.bolt11 == external_invoice.payment_request
|
assert service_fee_payment.bolt11 == external_invoice.payment_request
|
||||||
assert service_fee_payment.preimage is None
|
assert service_fee_payment.preimage is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_payments_for_user(to_wallet: Wallet):
|
||||||
|
all_payments = await get_payments_paginated()
|
||||||
|
total_before = all_payments.total
|
||||||
|
|
||||||
|
user = await create_user_account()
|
||||||
|
wallet_one = await create_wallet(user_id=user.id, wallet_name="first wallet")
|
||||||
|
wallet_two = await create_wallet(user_id=user.id, wallet_name="second wallet")
|
||||||
|
|
||||||
|
user_payments = await get_payments_paginated(user_id=user.id)
|
||||||
|
assert user_payments.total == 0
|
||||||
|
|
||||||
|
payment = await create_invoice(wallet_id=wallet_one.id, amount=100, memo="one")
|
||||||
|
user_payments = await get_payments_paginated(user_id=user.id)
|
||||||
|
assert user_payments.total == 1
|
||||||
|
# this will create a payment in the to_wallet that we need to count for at the end
|
||||||
|
await pay_invoice(
|
||||||
|
wallet_id=to_wallet.id,
|
||||||
|
payment_request=payment.bolt11,
|
||||||
|
)
|
||||||
|
user_payments = await get_payments_paginated(user_id=user.id)
|
||||||
|
assert user_payments.total == 1
|
||||||
|
|
||||||
|
payment = await create_invoice(wallet_id=wallet_one.id, amount=3, memo="two")
|
||||||
|
user_payments = await get_payments_paginated(user_id=user.id)
|
||||||
|
assert user_payments.total == 2
|
||||||
|
|
||||||
|
payment = await create_invoice(wallet_id=wallet_two.id, amount=3, memo="three")
|
||||||
|
user_payments = await get_payments_paginated(user_id=user.id)
|
||||||
|
assert user_payments.total == 3
|
||||||
|
|
||||||
|
await pay_invoice(
|
||||||
|
wallet_id=wallet_one.id,
|
||||||
|
payment_request=payment.bolt11,
|
||||||
|
)
|
||||||
|
user_payments = await get_payments_paginated(user_id=user.id)
|
||||||
|
assert user_payments.total == 4
|
||||||
|
|
||||||
|
all_payments = await get_payments_paginated()
|
||||||
|
total_after = all_payments.total
|
||||||
|
|
||||||
|
assert total_after == total_before + 5, "Total payments should be updated."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_get_payments_for_non_user():
|
||||||
|
user_payments = await get_payments_paginated(user_id="nonexistent")
|
||||||
|
assert (
|
||||||
|
user_payments.total == 0
|
||||||
|
), "No payments should be found for non-existent user."
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue