diff --git a/lnbits/core/services/payments.py b/lnbits/core/services/payments.py index a00dd464..d0f3c15a 100644 --- a/lnbits/core/services/payments.py +++ b/lnbits/core/services/payments.py @@ -44,6 +44,9 @@ from ..models import ( ) from .notifications import send_payment_notification +payment_lock = asyncio.Lock() +wallets_payments_lock: dict[str, asyncio.Lock] = {} + async def pay_invoice( *, @@ -79,12 +82,12 @@ async def pay_invoice( extra=extra, ) - payment = await _pay_invoice(wallet, create_payment_model, conn) + payment = await _pay_invoice(wallet.id, create_payment_model, conn) async with db.reuse_conn(conn) if conn else db.connect() as new_conn: await _credit_service_fee_wallet(payment, new_conn) - return payment + return payment async def create_invoice( @@ -441,11 +444,27 @@ async def get_payments_daily_stats( return data -async def _pay_invoice(wallet, create_payment_model, conn): - payment = await _pay_internal_invoice(wallet, create_payment_model, conn) - if not payment: - payment = await _pay_external_invoice(wallet, create_payment_model, conn) - return payment +async def _pay_invoice( + wallet_id: str, + create_payment_model: CreatePayment, + conn: Optional[Connection] = None, +): + async with payment_lock: + if wallet_id not in wallets_payments_lock: + wallets_payments_lock[wallet_id] = asyncio.Lock() + + async with wallets_payments_lock[wallet_id]: + # get the wallet again to make sure we have the latest balance + wallet = await get_wallet(wallet_id, conn=conn) + if not wallet: + raise PaymentError( + f"Could not fetch wallet '{wallet_id}'.", status="failed" + ) + + payment = await _pay_internal_invoice(wallet, create_payment_model, conn) + if not payment: + payment = await _pay_external_invoice(wallet, create_payment_model, conn) + return payment async def _pay_internal_invoice( @@ -526,6 +545,8 @@ async def _pay_external_invoice( fee_reserve_total_msat = fee_reserve_total(amount_msat, internal=False) + if wallet.balance_msat < abs(amount_msat): + raise PaymentError("Insufficient balance.", status="failed") if wallet.balance_msat < abs(amount_msat) + fee_reserve_total_msat: raise PaymentError( f"You must reserve at least ({round(fee_reserve_total_msat/1000)}" diff --git a/tests/regtest/test_real_invoice.py b/tests/regtest/test_real_invoice.py index c799c1f3..22707aca 100644 --- a/tests/regtest/test_real_invoice.py +++ b/tests/regtest/test_real_invoice.py @@ -5,8 +5,12 @@ import pytest from lnbits import bolt11 from lnbits.core.crud import get_standalone_payment, update_payment +from lnbits.core.crud.wallets import create_wallet, get_wallet from lnbits.core.models import CreateInvoice, Payment, PaymentState from lnbits.core.services import fee_reserve_total, get_balance_delta +from lnbits.core.services.payments import pay_invoice, update_wallet_balance +from lnbits.core.services.users import create_user_account +from lnbits.exceptions import PaymentError from lnbits.tasks import create_task, wait_for_paid_invoices from lnbits.wallets import get_funding_source @@ -153,6 +157,39 @@ async def test_pay_real_invoice_set_pending_and_check_state( assert payment.success +@pytest.mark.anyio +@pytest.mark.skipif(is_fake, reason="this only works in regtest") +async def test_pay_real_invoices_in_parallel(): + user = await create_user_account() + wallet = await create_wallet(user_id=user.id) + + # more to cover routing feems + await update_wallet_balance(wallet, 1100) + + # these must be external invoices + real_invoice_one = get_real_invoice(1000) + real_invoice_two = get_real_invoice(1000) + + async def pay_first(): + return await pay_invoice( + wallet_id=wallet.id, + payment_request=real_invoice_one["payment_request"], + ) + + async def pay_second(): + return await pay_invoice( + wallet_id=wallet.id, + payment_request=real_invoice_two["payment_request"], + ) + + with pytest.raises(PaymentError, match="Insufficient balance."): + await asyncio.gather(pay_first(), pay_second()) + + wallet_after = await get_wallet(wallet.id) + assert wallet_after + assert 0 <= wallet_after.balance <= 100, "One payment should be deducted." + + @pytest.mark.anyio @pytest.mark.skipif(is_fake, reason="this only works in regtest") async def test_pay_hold_invoice_check_pending( diff --git a/tests/unit/test_pay_invoice.py b/tests/unit/test_pay_invoice.py index 387b7987..dc4e019b 100644 --- a/tests/unit/test_pay_invoice.py +++ b/tests/unit/test_pay_invoice.py @@ -12,6 +12,7 @@ from lnbits.core.crud import create_wallet, get_standalone_payment, get_wallet from lnbits.core.crud.payments import get_payment, get_payments_paginated from lnbits.core.models import Payment, PaymentState, Wallet from lnbits.core.services import create_invoice, create_user_account, pay_invoice +from lnbits.core.services.payments import update_wallet_balance from lnbits.exceptions import InvoiceError, PaymentError from lnbits.settings import Settings from lnbits.tasks import ( @@ -114,6 +115,62 @@ async def test_pay_twice(to_wallet: Wallet): ) +@pytest.mark.anyio +async def test_pay_twice_fast(): + user = await create_user_account() + wallet_one = await create_wallet(user_id=user.id) + wallet_two = await create_wallet(user_id=user.id) + + await update_wallet_balance(wallet_one, 1000) + payment_a = await create_invoice(wallet_id=wallet_two.id, amount=1000, memo="AAA") + payment_b = await create_invoice(wallet_id=wallet_two.id, amount=1000, memo="BBB") + + async def pay_first(): + return await pay_invoice( + wallet_id=wallet_one.id, + payment_request=payment_a.bolt11, + ) + + async def pay_second(): + return await pay_invoice( + wallet_id=wallet_one.id, + payment_request=payment_b.bolt11, + ) + + with pytest.raises(PaymentError, match="Insufficient balance."): + await asyncio.gather(pay_first(), pay_second()) + + wallet_one_after = await get_wallet(wallet_one.id) + assert wallet_one_after + assert wallet_one_after.balance == 0, "One payment should be deducted." + + wallet_two_after = await get_wallet(wallet_two.id) + assert wallet_two_after + assert wallet_two_after.balance == 1000, "One payment received." + + +@pytest.mark.anyio +async def test_pay_twice_fast_same_invoice(to_wallet: Wallet): + payment = await create_invoice( + wallet_id=to_wallet.id, amount=3, memo="Twice fast same invoice" + ) + + async def pay_first(): + return await pay_invoice( + wallet_id=to_wallet.id, + payment_request=payment.bolt11, + ) + + async def pay_second(): + return await pay_invoice( + wallet_id=to_wallet.id, + payment_request=payment.bolt11, + ) + + with pytest.raises(PaymentError, match="Payment already paid."): + await asyncio.gather(pay_first(), pay_second()) + + @pytest.mark.anyio async def test_fake_wallet_pay_external( to_wallet: Wallet, external_funding_source: FakeWallet