fix: stuck pay_invoice (#2783)

db connection wasnt passed to `get_user_active_extensions_id`. and in
some contexts nested db connection will stuck the server
This commit is contained in:
dni ⚡ 2024-11-28 11:50:15 +01:00 committed by GitHub
parent e3d6b6befa
commit f5ccf5c157
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 7 additions and 5 deletions

View file

@ -520,14 +520,14 @@ async def _check_wallet_for_payment(
wallet_id: str, wallet_id: str,
tag: str, tag: str,
amount_msat: int, amount_msat: int,
conn: Optional[Connection], conn: Optional[Connection] = None,
): ):
wallet = await get_wallet(wallet_id, conn=conn) wallet = await get_wallet(wallet_id, conn=conn)
if not wallet: if not wallet:
raise PaymentError(f"Could not fetch wallet '{wallet_id}'.", status="failed") raise PaymentError(f"Could not fetch wallet '{wallet_id}'.", status="failed")
# check if the payment is made for an extension that the user disabled # check if the payment is made for an extension that the user disabled
status = await check_user_extension_access(wallet.user, tag) status = await check_user_extension_access(wallet.user, tag, conn=conn)
if not status.success: if not status.success:
raise PaymentError(status.message) raise PaymentError(status.message)

View file

@ -26,7 +26,7 @@ from lnbits.core.models import (
User, User,
WalletTypeInfo, WalletTypeInfo,
) )
from lnbits.db import Filter, Filters, TFilterModel from lnbits.db import Connection, Filter, Filters, TFilterModel
from lnbits.settings import AuthMethods, settings from lnbits.settings import AuthMethods, settings
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth", auto_error=False) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth", auto_error=False)
@ -235,7 +235,9 @@ def parse_filters(model: Type[TFilterModel]):
return dependency return dependency
async def check_user_extension_access(user_id: str, ext_id: str) -> SimpleStatus: async def check_user_extension_access(
user_id: str, ext_id: str, conn: Optional[Connection] = None
) -> SimpleStatus:
""" """
Check if the user has access to a particular extension. Check if the user has access to a particular extension.
Raises HTTP Forbidden if the user is not allowed. Raises HTTP Forbidden if the user is not allowed.
@ -246,7 +248,7 @@ async def check_user_extension_access(user_id: str, ext_id: str) -> SimpleStatus
) )
if settings.is_extension_id(ext_id): if settings.is_extension_id(ext_id):
ext_ids = await get_user_active_extensions_ids(user_id) ext_ids = await get_user_active_extensions_ids(user_id, conn=conn)
if ext_id not in ext_ids: if ext_id not in ext_ids:
return SimpleStatus( return SimpleStatus(
success=False, message=f"User extension '{ext_id}' not enabled." success=False, message=f"User extension '{ext_id}' not enabled."