From f5ccf5c157555bfb410e9578f428c04c9dc134a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dni=20=E2=9A=A1?= Date: Thu, 28 Nov 2024 11:50:15 +0100 Subject: [PATCH] fix: stuck pay_invoice (#2783) db connection wasnt passed to `get_user_active_extensions_id`. and in some contexts nested db connection will stuck the server --- lnbits/core/services/payments.py | 4 ++-- lnbits/decorators.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/lnbits/core/services/payments.py b/lnbits/core/services/payments.py index 10c7e401..118da736 100644 --- a/lnbits/core/services/payments.py +++ b/lnbits/core/services/payments.py @@ -520,14 +520,14 @@ async def _check_wallet_for_payment( wallet_id: str, tag: str, amount_msat: int, - conn: Optional[Connection], + conn: Optional[Connection] = None, ): wallet = await get_wallet(wallet_id, conn=conn) if not wallet: raise PaymentError(f"Could not fetch wallet '{wallet_id}'.", status="failed") # check if the payment is made for an extension that the user disabled - status = await check_user_extension_access(wallet.user, tag) + status = await check_user_extension_access(wallet.user, tag, conn=conn) if not status.success: raise PaymentError(status.message) diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 4f8ed115..83f55b8a 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -26,7 +26,7 @@ from lnbits.core.models import ( User, WalletTypeInfo, ) -from lnbits.db import Filter, Filters, TFilterModel +from lnbits.db import Connection, Filter, Filters, TFilterModel from lnbits.settings import AuthMethods, settings oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth", auto_error=False) @@ -235,7 +235,9 @@ def parse_filters(model: Type[TFilterModel]): return dependency -async def check_user_extension_access(user_id: str, ext_id: str) -> SimpleStatus: +async def check_user_extension_access( + user_id: str, ext_id: str, conn: Optional[Connection] = None +) -> SimpleStatus: """ Check if the user has access to a particular extension. Raises HTTP Forbidden if the user is not allowed. @@ -246,7 +248,7 @@ async def check_user_extension_access(user_id: str, ext_id: str) -> SimpleStatus ) if settings.is_extension_id(ext_id): - ext_ids = await get_user_active_extensions_ids(user_id) + ext_ids = await get_user_active_extensions_ids(user_id, conn=conn) if ext_id not in ext_ids: return SimpleStatus( success=False, message=f"User extension '{ext_id}' not enabled."