diff --git a/lnbits/core/crud/payments.py b/lnbits/core/crud/payments.py index e11cd035..f9978db9 100644 --- a/lnbits/core/crud/payments.py +++ b/lnbits/core/crud/payments.py @@ -33,9 +33,9 @@ async def get_payment(checking_id: str, conn: Connection | None = None) -> Payme async def get_standalone_payment( checking_id_or_hash: str, - conn: Connection | None = None, incoming: bool | None = False, wallet_id: str | None = None, + conn: Connection | None = None, ) -> Payment | None: clause: str = "checking_id = :checking_id OR payment_hash = :hash" values = { @@ -46,7 +46,7 @@ async def get_standalone_payment( clause = f"({clause}) AND amount > 0" if wallet_id: - wallet = await get_wallet(wallet_id) + wallet = await get_wallet(wallet_id, conn=conn) if not wallet or not wallet.can_view_payments: return None values["wallet_id"] = wallet.source_wallet_id @@ -69,7 +69,7 @@ async def get_standalone_payment( async def get_wallet_payment( wallet_id: str, payment_hash: str, conn: Connection | None = None ) -> Payment | None: - wallet = await get_wallet(wallet_id) + wallet = await get_wallet(wallet_id, conn=conn) if not wallet or not wallet.can_view_payments: return None payment = await (conn or db).fetchone( @@ -124,7 +124,6 @@ async def get_payments_paginated( # noqa: C901 Filters payments to be returned by: - complete | pending | failed | outgoing | incoming. """ - values: dict[str, Any] = { "time": since, } @@ -134,7 +133,7 @@ async def get_payments_paginated( # noqa: C901 clause.append(f"time > {db.timestamp_placeholder('time')}") if wallet_id: - wallet = await get_wallet(wallet_id) + wallet = await get_wallet(wallet_id, conn=conn) if not wallet or not wallet.can_view_payments: return Page(data=[], total=0) @@ -326,6 +325,7 @@ async def get_payments_history( wallet_id: str | None = None, group: DateTrunc = "day", filters: Filters | None = None, + conn: Connection | None = None, ) -> list[PaymentHistoryPoint]: if not filters: filters = Filters() @@ -361,13 +361,13 @@ async def get_payments_history( filters.values(values), ) if wallet_id: - wallet = await get_wallet(wallet_id) + wallet = await get_wallet(wallet_id, conn=conn) if not wallet or not wallet.can_view_payments: return [] balance = wallet.balance_msat values["wallet_id"] = wallet.source_wallet_id else: - balance = await get_total_balance() + balance = await get_total_balance(conn=conn) # since we dont know the balance at the starting point, # we take the current balance and walk backwards diff --git a/lnbits/core/crud/users.py b/lnbits/core/crud/users.py index df3ef97f..36359332 100644 --- a/lnbits/core/crud/users.py +++ b/lnbits/core/crud/users.py @@ -30,9 +30,9 @@ async def create_account( return account -async def update_account(account: Account) -> Account: +async def update_account(account: Account, conn: Connection | None = None) -> Account: account.updated_at = datetime.now(timezone.utc) - await db.update("accounts", account) + await (conn or db).update("accounts", account) return account @@ -171,21 +171,23 @@ async def get_account_by_username_or_email( async def get_user(user_id: str, conn: Connection | None = None) -> User | None: - account = await get_account(user_id, conn) - if not account: - return None - return await get_user_from_account(account, conn) + async with db.reuse_conn(conn) if conn else db.connect() as conn: + account = await get_account(user_id, conn=conn) + if not account: + return None + return await get_user_from_account(account, conn=conn) async def get_user_from_account( account: Account, conn: Connection | None = None ) -> User | None: - extensions = await get_user_active_extensions_ids(account.id, conn=conn) - wallets = await get_wallets(account.id, deleted=False, conn=conn) + async with db.reuse_conn(conn) if conn else db.connect() as conn: + extensions = await get_user_active_extensions_ids(account.id, conn=conn) + wallets = await get_wallets(account.id, deleted=False, conn=conn) - if len(wallets) == 0: - wallet = await create_wallet(user_id=account.id, conn=conn) - wallets.append(wallet) + if len(wallets) == 0: + wallet = await create_wallet(user_id=account.id, conn=conn) + wallets.append(wallet) return User( id=account.id, @@ -205,9 +207,11 @@ async def get_user_from_account( ) -async def update_user_access_control_list(user_acls: UserAcls): +async def update_user_access_control_list( + user_acls: UserAcls, conn: Connection | None = None +): user_acls.updated_at = datetime.now(timezone.utc) - await db.update("accounts", user_acls) + await (conn or db).update("accounts", user_acls) async def get_user_access_control_lists( diff --git a/lnbits/core/migrations.py b/lnbits/core/migrations.py index f9f39005..669e281b 100644 --- a/lnbits/core/migrations.py +++ b/lnbits/core/migrations.py @@ -786,3 +786,63 @@ async def m038_add_labels_for_payments(db: Connection): ALTER TABLE apipayments ADD COLUMN labels TEXT """ ) + + +async def m039_index_payments(db: Connection): + indexes = [ + "wallet_id", + "checking_id", + "payment_hash", + "amount", + "labels", + "time", + "status", + "created_at", + "updated_at", + ] + for index in indexes: + logger.debug(f"Creating index idx_payments_{index}...") + await db.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_payments_{index} ON apipayments ({index}); + """ + ) + + +async def m040_index_wallets(db: Connection): + indexes = [ + "id", + "user", + "deleted", + "adminkey", + "inkey", + "wallet_type", + "created_at", + "updated_at", + ] + + for index in indexes: + logger.debug(f"Creating index idx_wallets_{index}...") + await db.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_wallets_{index} ON wallets ("{index}"); + """ + ) + + +async def m042_index_accounts(db: Connection): + indexes = [ + "id", + "email", + "username", + "pubkey", + "external_id", + ] + + for index in indexes: + logger.debug(f"Creating index idx_wallets_{index}...") + await db.execute( + f""" + CREATE INDEX IF NOT EXISTS idx_accounts_{index} ON accounts ("{index}"); + """ + ) diff --git a/lnbits/core/models/users.py b/lnbits/core/models/users.py index c06f3a08..11ce2bac 100644 --- a/lnbits/core/models/users.py +++ b/lnbits/core/models/users.py @@ -201,6 +201,10 @@ class Account(AccountId): self.is_admin = settings.is_admin_user(self.id) self.fiat_providers = settings.get_fiat_providers_for_user(self.id) + @property + def has_password(self) -> bool: + return self.password_hash is not None + def hash_password(self, password: str) -> str: """sets and returns the hashed password""" salt = gensalt() diff --git a/lnbits/core/services/extensions.py b/lnbits/core/services/extensions.py index b2a561d8..8743efbc 100644 --- a/lnbits/core/services/extensions.py +++ b/lnbits/core/services/extensions.py @@ -16,6 +16,7 @@ from lnbits.core.crud.extensions import ( update_installed_extension, ) from lnbits.core.helpers import migrate_extension_database +from lnbits.db import Connection from lnbits.settings import settings from ..models.extensions import Extension, ExtensionMeta, InstallableExtension @@ -149,9 +150,9 @@ async def start_extension_background_work(ext_id: str) -> bool: async def get_valid_extensions( - include_deactivated: bool | None = True, + include_deactivated: bool | None = True, conn: Connection | None = None ) -> list[Extension]: - installed_extensions = await get_installed_extensions() + installed_extensions = await get_installed_extensions(conn=conn) valid_extensions = [Extension.from_installable_ext(e) for e in installed_extensions] if include_deactivated: diff --git a/lnbits/core/services/payments.py b/lnbits/core/services/payments.py index ae029948..80be5620 100644 --- a/lnbits/core/services/payments.py +++ b/lnbits/core/services/payments.py @@ -67,12 +67,14 @@ async def pay_invoice( if settings.lnbits_only_allow_incoming_payments: raise PaymentError("Only incoming payments allowed.", status="failed") invoice = _validate_payment_request(payment_request, max_sat) + if not invoice.amount_msat: raise ValueError("Missig invoice amount.") async with db.reuse_conn(conn) if conn else db.connect() as new_conn: amount_msat = invoice.amount_msat wallet = await _check_wallet_for_payment(wallet_id, tag, amount_msat, new_conn) + if not wallet.can_send_payments: raise PaymentError( "Wallet does not have permission to pay invoices.", @@ -95,10 +97,12 @@ async def pay_invoice( labels=labels, ) - payment = await _pay_invoice(wallet.source_wallet_id, create_payment_model, conn) - async with db.reuse_conn(conn) if conn else db.connect() as new_conn: - await _credit_service_fee_wallet(wallet, payment, new_conn) + payment = await _pay_invoice( + wallet.source_wallet_id, create_payment_model, conn=new_conn + ) + + await _credit_service_fee_wallet(wallet, payment, conn=new_conn) return payment @@ -197,20 +201,22 @@ async def create_wallet_invoice(wallet_id: str, data: CreateInvoice) -> Payment: # do not save memo if description_hash or unhashed_description is set memo = "" - payment = await create_invoice( - wallet_id=wallet_id, - amount=data.amount, - memo=memo, - currency=data.unit, - description_hash=description_hash, - unhashed_description=unhashed_description, - expiry=data.expiry, - extra=data.extra, - webhook=data.webhook, - internal=data.internal, - payment_hash=data.payment_hash, - labels=data.labels, - ) + async with db.connect() as conn: + payment = await create_invoice( + wallet_id=wallet_id, + amount=data.amount, + memo=memo, + currency=data.unit, + description_hash=description_hash, + unhashed_description=unhashed_description, + expiry=data.expiry, + extra=data.extra, + webhook=data.webhook, + internal=data.internal, + payment_hash=data.payment_hash, + labels=data.labels, + conn=conn, + ) if data.lnurl_withdraw: try: @@ -355,13 +361,15 @@ async def update_pending_payments(wallet_id: str): await update_pending_payment(payment) -async def update_pending_payment(payment: Payment) -> Payment: +async def update_pending_payment( + payment: Payment, conn: Connection | None = None +) -> Payment: status = await payment.check_status() if status.failed: payment.status = PaymentState.FAILED - await update_payment(payment) + await update_payment(payment, conn=conn) elif status.success: - payment = await update_payment_success_status(payment, status) + payment = await update_payment_success_status(payment, status, conn=conn) return payment @@ -698,6 +706,7 @@ async def _pay_internal_invoice( internal_payment = await check_internal( create_payment_model.payment_hash, conn=conn ) + if not internal_payment: return None @@ -706,6 +715,7 @@ async def _pay_internal_invoice( internal_invoice = await get_standalone_payment( internal_payment.checking_id, incoming=True, conn=conn ) + if not internal_invoice: raise PaymentError("Internal payment not found.", status="failed") @@ -727,6 +737,7 @@ async def _pay_internal_invoice( internal_id = f"internal_{create_payment_model.payment_hash}" logger.debug(f"creating temporary internal payment with id {internal_id}") + payment = await create_payment( checking_id=internal_id, data=create_payment_model, diff --git a/lnbits/core/services/users.py b/lnbits/core/services/users.py index c9586ff1..a3c9b800 100644 --- a/lnbits/core/services/users.py +++ b/lnbits/core/services/users.py @@ -3,7 +3,9 @@ from uuid import uuid4 from loguru import logger +from lnbits.core.db import db from lnbits.core.models.extensions import UserExtension +from lnbits.db import Connection from lnbits.settings import ( EditableSettings, SuperSettings, @@ -48,37 +50,43 @@ async def create_user_account_no_ckeck( account: Account | None = None, wallet_name: str | None = None, default_exts: list[str] | None = None, + conn: Connection | None = None, ) -> User: + async with db.reuse_conn(conn) if conn else db.connect() as conn: + if account: + account.validate_fields() + if account.username and await get_account_by_username( + account.username, conn=conn + ): + raise ValueError("Username already exists.") - if account: - account.validate_fields() - if account.username and await get_account_by_username(account.username): - raise ValueError("Username already exists.") + if account.email and await get_account_by_email(account.email, conn=conn): + raise ValueError("Email already exists.") - if account.email and await get_account_by_email(account.email): - raise ValueError("Email already exists.") + if account.pubkey and await get_account_by_pubkey( + account.pubkey, conn=conn + ): + raise ValueError("Pubkey already exists.") - if account.pubkey and await get_account_by_pubkey(account.pubkey): - raise ValueError("Pubkey already exists.") + if not account.id: + account.id = uuid4().hex - if not account.id: - account.id = uuid4().hex + account = await create_account(account, conn=conn) + await create_wallet( + user_id=account.id, + wallet_name=wallet_name or settings.lnbits_default_wallet_name, + conn=conn, + ) - account = await create_account(account) - await create_wallet( - user_id=account.id, - wallet_name=wallet_name or settings.lnbits_default_wallet_name, - ) + user_extensions = (default_exts or []) + settings.lnbits_user_default_extensions + for ext_id in user_extensions: + try: + user_ext = UserExtension(user=account.id, extension=ext_id, active=True) + await create_user_extension(user_ext, conn=conn) + except Exception as e: + logger.error(f"Error enabeling default extension {ext_id}: {e}") - user_extensions = (default_exts or []) + settings.lnbits_user_default_extensions - for ext_id in user_extensions: - try: - user_ext = UserExtension(user=account.id, extension=ext_id, active=True) - await create_user_extension(user_ext) - except Exception as e: - logger.error(f"Error enabeling default extension {ext_id}: {e}") - - user = await get_user_from_account(account) + user = await get_user_from_account(account, conn=conn) if not user: raise ValueError("Cannot find user for account.") diff --git a/lnbits/core/views/admin_api.py b/lnbits/core/views/admin_api.py index 7df89bc0..6e69d07c 100644 --- a/lnbits/core/views/admin_api.py +++ b/lnbits/core/views/admin_api.py @@ -8,8 +8,8 @@ from urllib.parse import urlparse from fastapi import APIRouter, Depends, File from fastapi.responses import FileResponse -from lnbits.core.models import User from lnbits.core.models.notifications import NotificationType +from lnbits.core.models.users import Account from lnbits.core.services import ( enqueue_admin_notification, get_balance_delta, @@ -67,9 +67,9 @@ async def api_test_email(): @admin_router.get("/api/v1/settings") async def api_get_settings( - user: User = Depends(check_admin), + account: Account = Depends(check_admin), ) -> AdminSettings | None: - admin_settings = await get_admin_settings(user.super_user) + admin_settings = await get_admin_settings(account.is_super_user) return admin_settings @@ -77,12 +77,14 @@ async def api_get_settings( "/api/v1/settings", status_code=HTTPStatus.OK, ) -async def api_update_settings(data: UpdateSettings, user: User = Depends(check_admin)): +async def api_update_settings( + data: UpdateSettings, account: Account = Depends(check_admin) +): enqueue_admin_notification( - NotificationType.settings_update, {"username": user.username} + NotificationType.settings_update, {"username": account.username} ) await update_admin_settings(data) - admin_settings = await get_admin_settings(user.super_user) + admin_settings = await get_admin_settings(account.is_super_user) if not admin_settings: raise ValueError("Updated admin settings not found.") update_cached_settings(admin_settings.dict()) @@ -94,9 +96,11 @@ async def api_update_settings(data: UpdateSettings, user: User = Depends(check_a "/api/v1/settings", status_code=HTTPStatus.OK, ) -async def api_update_settings_partial(data: dict, user: User = Depends(check_admin)): +async def api_update_settings_partial( + data: dict, account: Account = Depends(check_admin) +): updatable_settings = dict_to_settings({**settings.dict(), **data}) - return await api_update_settings(updatable_settings, user) + return await api_update_settings(updatable_settings, account) @admin_router.get( @@ -110,9 +114,9 @@ async def api_reset_settings(field_name: str): @admin_router.delete("/api/v1/settings", status_code=HTTPStatus.OK) -async def api_delete_settings(user: User = Depends(check_super_user)) -> None: +async def api_delete_settings(account: Account = Depends(check_super_user)) -> None: enqueue_admin_notification( - NotificationType.settings_update, {"username": user.username} + NotificationType.settings_update, {"username": account.username} ) await reset_core_settings() server_restart.set() diff --git a/lnbits/core/views/extension_api.py b/lnbits/core/views/extension_api.py index 63a3a3c3..784c8b3e 100644 --- a/lnbits/core/views/extension_api.py +++ b/lnbits/core/views/extension_api.py @@ -7,9 +7,10 @@ from fastapi import APIRouter, Depends, HTTPException from loguru import logger from lnbits.core.crud.extensions import get_user_extensions +from lnbits.core.crud.wallets import get_wallets_ids +from lnbits.core.db import db from lnbits.core.models import ( SimpleStatus, - User, ) from lnbits.core.models.extensions import ( CreateExtension, @@ -23,7 +24,7 @@ from lnbits.core.models.extensions import ( UserExtension, UserExtensionInfo, ) -from lnbits.core.models.users import AccountId +from lnbits.core.models.users import Account, AccountId from lnbits.core.services import check_transaction_status, create_invoice from lnbits.core.services.extensions import ( activate_extension, @@ -144,9 +145,10 @@ async def api_extension_details( async def api_update_pay_to_enable( ext_id: str, data: PayToEnableInfo, - user: User = Depends(check_admin), + account: Account = Depends(check_admin), ) -> SimpleStatus: - if data.wallet not in user.wallet_ids: + user_wallet_ids = await get_wallets_ids(account.id, deleted=False) + if data.wallet not in user_wallet_ids: raise HTTPException( HTTPStatus.BAD_REQUEST, "Wallet does not belong to this admin user." ) @@ -462,15 +464,16 @@ async def get_extension_release(org: str, repo: str, tag_name: str): async def api_get_user_extensions( account_id: AccountId = Depends(check_account_id_exists), ) -> list[Extension]: - - user_extensions_ids = [ - ue.extension for ue in await get_user_extensions(account_id.id) - ] - return [ - ext - for ext in await get_valid_extensions(False) - if ext.code in user_extensions_ids - ] + async with db.connect() as conn: + user_extensions_ids = [ + ue.extension for ue in await get_user_extensions(account_id.id, conn=conn) + ] + valid_extensions = [ + ext + for ext in await get_valid_extensions(False, conn=conn) + if ext.code in user_extensions_ids + ] + return valid_extensions @extension_router.delete( @@ -505,7 +508,16 @@ async def delete_extension_db(ext_id: str): # TODO: create a response model for this @extension_router.get("/all") async def extensions(account_id: AccountId = Depends(check_account_id_exists)): - installed_exts: list[InstallableExtension] = await get_installed_extensions() + async with db.connect() as conn: + installed_exts: list[InstallableExtension] = await get_installed_extensions( + conn=conn + ) + all_ext_ids = [ext.code for ext in await get_valid_extensions(conn=conn)] + inactive_extensions = [ + e.id for e in await get_installed_extensions(active=False, conn=conn) + ] + db_versions = await get_db_versions(conn=conn) + installed_exts_ids = [e.id for e in installed_exts] installable_exts = await InstallableExtension.get_installable_extensions( @@ -536,10 +548,6 @@ async def extensions(account_id: AccountId = Depends(check_account_id_exists)): e.short_description = installed_ext.short_description e.icon = installed_ext.icon - all_ext_ids = [ext.code for ext in await get_valid_extensions()] - inactive_extensions = [e.id for e in await get_installed_extensions(active=False)] - db_versions = await get_db_versions() - extension_data = [ { "id": ext.id, diff --git a/lnbits/core/views/extensions_builder_api.py b/lnbits/core/views/extensions_builder_api.py index 584d7387..4ff5a1c3 100644 --- a/lnbits/core/views/extensions_builder_api.py +++ b/lnbits/core/views/extensions_builder_api.py @@ -7,7 +7,6 @@ from fastapi.responses import FileResponse from lnbits.core.models import ( SimpleStatus, - User, ) from lnbits.core.models.extensions import ( Extension, @@ -16,7 +15,7 @@ from lnbits.core.models.extensions import ( UserExtension, ) from lnbits.core.models.extensions_builder import ExtensionData -from lnbits.core.models.users import AccountId +from lnbits.core.models.users import Account, AccountId from lnbits.core.services.extensions import ( activate_extension, install_extension, @@ -85,9 +84,9 @@ async def api_build_extension(data: ExtensionData) -> FileResponse: ) async def api_deploy_extension( data: ExtensionData, - user: User = Depends(check_admin), + account: Account = Depends(check_admin), ) -> SimpleStatus: - working_dir_name = "deploy_" + sha256(user.id.encode("utf-8")).hexdigest() + working_dir_name = "deploy_" + sha256(account.id.encode("utf-8")).hexdigest() stub_ext_id = "extension_builder_stub" release, build_dir = await build_extension_from_data( data, stub_ext_id, working_dir_name @@ -111,9 +110,9 @@ async def api_deploy_extension( await activate_extension(Extension.from_installable_ext(ext_info)) - user_ext = await get_user_extension(user.id, data.id) + user_ext = await get_user_extension(account.id, data.id) if not user_ext: - user_ext = UserExtension(user=user.id, extension=data.id, active=True) + user_ext = UserExtension(user=account.id, extension=data.id, active=True) await create_user_extension(user_ext) elif not user_ext.active: user_ext.active = True diff --git a/lnbits/core/views/payment_api.py b/lnbits/core/views/payment_api.py index 0759022d..97cba732 100644 --- a/lnbits/core/views/payment_api.py +++ b/lnbits/core/views/payment_api.py @@ -18,6 +18,7 @@ from lnbits.core.crud.payments import ( update_payment, ) from lnbits.core.crud.users import get_account +from lnbits.core.db import db from lnbits.core.models import ( CancelInvoice, CreateInvoice, @@ -69,7 +70,6 @@ from ..services import ( perform_withdraw, settle_hold_invoice, update_pending_payment, - update_pending_payments, ) payment_router = APIRouter(prefix="/api/v1/payments", tags=["Payments"]) @@ -87,7 +87,6 @@ async def api_payments( key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key), filters: Filters = Depends(parse_filters(PaymentFilters)), ): - await update_pending_payments(key_info.wallet.id) return await get_payments( wallet_id=key_info.wallet.id, pending=True, @@ -107,7 +106,6 @@ async def api_payments_history( group: DateTrunc = Query("day"), filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)), ): - await update_pending_payments(key_info.wallet.id) return await get_payments_history(key_info.wallet.id, group, filters) @@ -186,18 +184,24 @@ async def api_payments_paginated( ), filters: Filters = Depends(parse_filters(PaymentFilters)), ) -> Page[Payment]: - page = await get_payments_paginated( - wallet_id=key_info.wallet.id, - filters=filters, - ) - if not recheck_pending: - return page + async with db.connect() as conn: + page = await get_payments_paginated( + wallet_id=key_info.wallet.id, + filters=filters, + conn=conn, + ) + if not recheck_pending: + return page - for payment in page.data: - if payment.pending: - await update_pending_payment(payment) + payments = [] + for payment in page.data: + if payment.pending: + refreshed_payment = await update_pending_payment(payment, conn=conn) + payments.append(refreshed_payment) + else: + payments.append(payment) - return page + return Page(data=payments, total=page.total) @payment_router.get( @@ -219,10 +223,10 @@ async def api_all_payments_paginated( # regular user can only see payments from their wallets for_user_id = account_id.id - return await get_payments_paginated( - filters=filters, - user_id=for_user_id, - ) + async with db.connect() as conn: + return await get_payments_paginated( + filters=filters, user_id=for_user_id, conn=conn + ) @payment_router.post( diff --git a/lnbits/core/views/user_api.py b/lnbits/core/views/user_api.py index 4f1725a3..012e4f26 100644 --- a/lnbits/core/views/user_api.py +++ b/lnbits/core/views/user_api.py @@ -115,12 +115,12 @@ async def api_create_user(data: CreateUser) -> CreateUser: @users_router.put("/user/{user_id}", name="Update user") async def api_update_user( - user_id: str, data: CreateUser, user: User = Depends(check_admin) + user_id: str, data: CreateUser, account: Account = Depends(check_admin) ) -> CreateUser: if user_id != data.id: raise HTTPException(HTTPStatus.BAD_REQUEST, "User Id missmatch.") - if user_id == settings.super_user and user.id != settings.super_user: + if user_id == settings.super_user and account.id != settings.super_user: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail="Action only allowed for super user.", @@ -154,7 +154,7 @@ async def api_update_user( name="Delete user by Id", ) async def api_users_delete_user( - user_id: str, user: User = Depends(check_admin) + user_id: str, account: Account = Depends(check_admin) ) -> SimpleStatus: wallets = await get_wallets(user_id, deleted=False) if len(wallets) > 0: @@ -169,7 +169,7 @@ async def api_users_delete_user( detail="Cannot delete super user.", ) - if user_id in settings.lnbits_admin_users and not user.super_user: + if user_id in settings.lnbits_admin_users and not account.is_super_user: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail="Only super_user can delete admin user.", @@ -295,7 +295,7 @@ async def api_users_delete_all_user_wallet(user_id: str) -> SimpleStatus: "The second time it is called will delete the entry from the DB", ) async def api_users_delete_user_wallet( - user_id: str, wallet: str, user: User = Depends(check_admin) + user_id: str, wallet: str, account: Account = Depends(check_admin) ) -> SimpleStatus: wal = await get_wallet(wallet) if not wal: @@ -304,7 +304,7 @@ async def api_users_delete_user_wallet( detail="Wallet does not exist.", ) - if user_id == settings.super_user and user.id != settings.super_user: + if user_id == settings.super_user and account.id != settings.super_user: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail="Action only allowed for super user.", diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 7d886529..f7280826 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -20,6 +20,7 @@ from lnbits.core.crud import ( ) from lnbits.core.crud.users import get_user_access_control_lists from lnbits.core.crud.wallets import get_base_wallet_for_key +from lnbits.core.db import db from lnbits.core.models import ( AccessTokenPayload, Account, @@ -119,16 +120,17 @@ class KeyChecker(BaseKeyChecker): async def __call__(self, request: Request) -> WalletTypeInfo: key_value = self._extract_key_value(request) - wallet = await get_wallet_for_key(key_value) + async with db.connect() as conn: + wallet = await get_wallet_for_key(key_value, conn=conn) - if not wallet: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail="Wallet not found.", - ) + if not wallet: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="Wallet not found.", + ) - request.scope["user_id"] = wallet.user - await _check_user_access(request, wallet.user) + request.scope["user_id"] = wallet.user + await _check_user_access(request, wallet.user, conn=conn) key_type = await self._extract_key_type(key_value, wallet) return WalletTypeInfo(key_type, wallet) @@ -148,22 +150,23 @@ class LightKeyChecker(BaseKeyChecker): cache_key = f"auth:x-api-key:{key_value}" cache_time = settings.auth_authentication_cache_minutes * 60 - if cache_time > 0: - key_info: BaseWalletTypeInfo | None = cache.get(cache_key) - if key_info: - request.scope["user_id"] = key_info.wallet.user - await _check_user_access(request, key_info.wallet.user) - return key_info + async with db.connect() as conn: + if cache_time > 0: + key_info: BaseWalletTypeInfo | None = cache.get(cache_key) + if key_info: + request.scope["user_id"] = key_info.wallet.user + await _check_user_access(request, key_info.wallet.user, conn=conn) + return key_info - wallet = await get_base_wallet_for_key(key_value) + wallet = await get_base_wallet_for_key(key_value, conn=conn) - if not wallet: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail="Wallet not found.", - ) - request.scope["user_id"] = wallet.user - await _check_user_access(request, wallet.user) + if not wallet: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="Wallet not found.", + ) + request.scope["user_id"] = wallet.user + await _check_user_access(request, wallet.user, conn=conn) key_type = await self._extract_key_type(key_value, wallet) key_info = BaseWalletTypeInfo(key_type, wallet) @@ -241,15 +244,16 @@ async def check_account_id_exists( elif usr: cache_key = f"auth:user_id:{sha256s(usr.hex)}" - if cache_key and settings.auth_authentication_cache_minutes > 0: - account_id = cache.get(cache_key) - if account_id: - r.scope["user_id"] = account_id.id - await _check_user_access(r, account_id) - return account_id + async with db.connect() as conn: + if cache_key and settings.auth_authentication_cache_minutes > 0: + account_id = cache.get(cache_key) + if account_id: + r.scope["user_id"] = account_id.id + await _check_user_access(r, account_id, conn=conn) + return account_id - account = await check_account_exists(r, access_token, usr) - account_id = AccountId(id=account.id) + account = await _check_account_exists(r, access_token, usr, conn=conn) + account_id = AccountId(id=account.id) if cache_key and settings.auth_authentication_cache_minutes > 0: cache.set( @@ -265,6 +269,15 @@ async def check_account_exists( r: Request, access_token: Annotated[str | None, Depends(check_access_token)], usr: UUID4 | None = None, +) -> Account: + return await _check_account_exists(r, access_token, usr) + + +async def _check_account_exists( + r: Request, + access_token: Annotated[str | None, Depends(check_access_token)], + usr: UUID4 | None = None, + conn: Connection | None = None, ) -> Account: """ Check that the account exists based on access token or user id. @@ -273,22 +286,27 @@ async def check_account_exists( - does not fetch the user wallets - caches the account info based on settings cache time """ - if access_token: - account = await _get_account_from_token(access_token, r["path"], r["method"]) - elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only): - account = await get_account(usr.hex) - if account and account.is_admin: - raise HTTPException( - HTTPStatus.FORBIDDEN, "User id only access for admins is forbidden." + async with db.reuse_conn(conn) if conn else db.connect() as new_conn: + if access_token: + account = await _get_account_from_token( + access_token, r["path"], r["method"], conn=new_conn + ) + elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only): + account = await get_account(usr.hex, conn=new_conn) + if account and account.is_admin: + raise HTTPException( + HTTPStatus.FORBIDDEN, "User id only access for admins is forbidden." + ) + else: + raise HTTPException( + HTTPStatus.UNAUTHORIZED, "Missing user ID or access token." ) - else: - raise HTTPException(HTTPStatus.UNAUTHORIZED, "Missing user ID or access token.") - if not account: - raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.") + if not account: + raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.") - r.scope["user_id"] = account.id - await _check_user_access(r, account.id) + r.scope["user_id"] = account.id + await _check_user_access(r, account.id, conn=new_conn) return account @@ -298,8 +316,9 @@ async def check_user_exists( access_token: Annotated[str | None, Depends(check_access_token)], usr: UUID4 | None = None, ) -> User: - account = await check_account_exists(r, access_token, usr) - user = await get_user_from_account(account) + async with db.connect() as conn: + account = await _check_account_exists(r, access_token, usr, conn=conn) + user = await get_user_from_account(account, conn=conn) if not user: raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.") @@ -330,29 +349,36 @@ async def access_token_payload( return AccessTokenPayload(**payload) -async def check_admin(user: Annotated[User, Depends(check_user_exists)]) -> User: - if user.id != settings.super_user and user.id not in settings.lnbits_admin_users: +async def check_admin( + account: Annotated[Account, Depends(check_account_exists)], +) -> Account: + if ( + account.id != settings.super_user + and account.id not in settings.lnbits_admin_users + ): raise HTTPException( HTTPStatus.FORBIDDEN, "User not authorized. No admin privileges." ) - if not user.has_password: + if not account.has_password: raise HTTPException( HTTPStatus.FORBIDDEN, "Admin users must have credentials configured." ) - return user + return account -async def check_super_user(user: Annotated[User, Depends(check_user_exists)]) -> User: - if user.id != settings.super_user: +async def check_super_user( + account: Annotated[Account, Depends(check_admin)], +) -> Account: + if account.id != settings.super_user: raise HTTPException( HTTPStatus.FORBIDDEN, "User not authorized. No super user privileges." ) - if not user.has_password: + if not account.has_password: raise HTTPException( HTTPStatus.FORBIDDEN, "Super user must have credentials configured." ) - return user + return account def parse_filters(model: type[TFilterModel]): @@ -412,15 +438,17 @@ async def check_user_extension_access( return SimpleStatus(success=True, message="OK") -async def _check_user_access(r: Request, user_id: str): +async def _check_user_access(r: Request, user_id: str, conn: Connection | None = None): if not settings.is_user_allowed(user_id): raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.") - await _check_user_extension_access(user_id, r["path"]) + await _check_user_extension_access(user_id, r["path"], conn=conn) -async def _check_user_extension_access(user_id: str, path: str): +async def _check_user_extension_access( + user_id: str, path: str, conn: Connection | None = None +): ext_id = path_segments(path)[0] - status = await check_user_extension_access(user_id, ext_id) + status = await check_user_extension_access(user_id, ext_id, conn=conn) if not status.success: raise HTTPException( HTTPStatus.FORBIDDEN, @@ -429,12 +457,12 @@ async def _check_user_extension_access(user_id: str, path: str): async def _get_account_from_token( - access_token: str, path: str, method: str + access_token: str, path: str, method: str, conn: Connection | None = None ) -> Account | None: try: payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"]) return await _get_account_from_jwt_payload( - AccessTokenPayload(**payload), path, method + AccessTokenPayload(**payload), path, method, conn=conn ) except jwt.ExpiredSignatureError as exc: @@ -447,33 +475,35 @@ async def _get_account_from_token( async def _get_account_from_jwt_payload( - payload: AccessTokenPayload, path: str, method: str + payload: AccessTokenPayload, path: str, method: str, conn: Connection | None = None ) -> Account | None: account = None if payload.sub: - account = await get_account_by_username(payload.sub) + account = await get_account_by_username(payload.sub, conn=conn) elif payload.usr: - account = await get_account(payload.usr) + account = await get_account(payload.usr, conn=conn) elif payload.email: - account = await get_account_by_email(payload.email) + account = await get_account_by_email(payload.email, conn=conn) if not account: return None if payload.api_token_id: - await _check_account_api_access(account.id, payload.api_token_id, path, method) + await _check_account_api_access( + account.id, payload.api_token_id, path, method, conn=conn + ) return account async def _check_account_api_access( - user_id: str, token_id: str, path: str, method: str + user_id: str, token_id: str, path: str, method: str, conn: Connection | None = None ): segments = path.split("/") if len(segments) < 3: raise HTTPException(HTTPStatus.FORBIDDEN, "Not an API endpoint.") - acls = await get_user_access_control_lists(user_id) + acls = await get_user_access_control_lists(user_id, conn=conn) acl = acls.get_acl_by_token_id(token_id) if not acl: raise HTTPException(HTTPStatus.FORBIDDEN, "Invalid token id.")