fix: stuck pay_invoice (#2783)
db connection wasnt passed to `get_user_active_extensions_id`. and in some contexts nested db connection will stuck the server
This commit is contained in:
parent
e3d6b6befa
commit
f5ccf5c157
2 changed files with 7 additions and 5 deletions
|
|
@ -520,14 +520,14 @@ async def _check_wallet_for_payment(
|
||||||
wallet_id: str,
|
wallet_id: str,
|
||||||
tag: str,
|
tag: str,
|
||||||
amount_msat: int,
|
amount_msat: int,
|
||||||
conn: Optional[Connection],
|
conn: Optional[Connection] = None,
|
||||||
):
|
):
|
||||||
wallet = await get_wallet(wallet_id, conn=conn)
|
wallet = await get_wallet(wallet_id, conn=conn)
|
||||||
if not wallet:
|
if not wallet:
|
||||||
raise PaymentError(f"Could not fetch wallet '{wallet_id}'.", status="failed")
|
raise PaymentError(f"Could not fetch wallet '{wallet_id}'.", status="failed")
|
||||||
|
|
||||||
# check if the payment is made for an extension that the user disabled
|
# check if the payment is made for an extension that the user disabled
|
||||||
status = await check_user_extension_access(wallet.user, tag)
|
status = await check_user_extension_access(wallet.user, tag, conn=conn)
|
||||||
if not status.success:
|
if not status.success:
|
||||||
raise PaymentError(status.message)
|
raise PaymentError(status.message)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ from lnbits.core.models import (
|
||||||
User,
|
User,
|
||||||
WalletTypeInfo,
|
WalletTypeInfo,
|
||||||
)
|
)
|
||||||
from lnbits.db import Filter, Filters, TFilterModel
|
from lnbits.db import Connection, Filter, Filters, TFilterModel
|
||||||
from lnbits.settings import AuthMethods, settings
|
from lnbits.settings import AuthMethods, settings
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth", auto_error=False)
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth", auto_error=False)
|
||||||
|
|
@ -235,7 +235,9 @@ def parse_filters(model: Type[TFilterModel]):
|
||||||
return dependency
|
return dependency
|
||||||
|
|
||||||
|
|
||||||
async def check_user_extension_access(user_id: str, ext_id: str) -> SimpleStatus:
|
async def check_user_extension_access(
|
||||||
|
user_id: str, ext_id: str, conn: Optional[Connection] = None
|
||||||
|
) -> SimpleStatus:
|
||||||
"""
|
"""
|
||||||
Check if the user has access to a particular extension.
|
Check if the user has access to a particular extension.
|
||||||
Raises HTTP Forbidden if the user is not allowed.
|
Raises HTTP Forbidden if the user is not allowed.
|
||||||
|
|
@ -246,7 +248,7 @@ async def check_user_extension_access(user_id: str, ext_id: str) -> SimpleStatus
|
||||||
)
|
)
|
||||||
|
|
||||||
if settings.is_extension_id(ext_id):
|
if settings.is_extension_id(ext_id):
|
||||||
ext_ids = await get_user_active_extensions_ids(user_id)
|
ext_ids = await get_user_active_extensions_ids(user_id, conn=conn)
|
||||||
if ext_id not in ext_ids:
|
if ext_id not in ext_ids:
|
||||||
return SimpleStatus(
|
return SimpleStatus(
|
||||||
success=False, message=f"User extension '{ext_id}' not enabled."
|
success=False, message=f"User extension '{ext_id}' not enabled."
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue