diff --git a/lnbits/__init__.py b/lnbits/__init__.py index 30ddf071..0c49e9a4 100644 --- a/lnbits/__init__.py +++ b/lnbits/__init__.py @@ -1,5 +1,6 @@ from .core.services import create_invoice, pay_invoice from .decorators import ( + check_account_exists, check_admin, check_super_user, check_user_exists, @@ -11,6 +12,7 @@ from .exceptions import InvoiceError, PaymentError __all__ = [ "InvoiceError", "PaymentError", + "check_account_exists", "check_admin", "check_super_user", "check_user_exists", diff --git a/lnbits/core/models/users.py b/lnbits/core/models/users.py index 1484bacf..c06f3a08 100644 --- a/lnbits/core/models/users.py +++ b/lnbits/core/models/users.py @@ -172,8 +172,15 @@ class UserAcls(BaseModel): return None -class Account(BaseModel): +class AccountId(BaseModel): id: str + + @property + def is_admin_id(self) -> bool: + return settings.is_admin_user(self.id) + + +class Account(AccountId): external_id: str | None = None # for external account linking username: str | None = None password_hash: str | None = None diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index a75dc474..49beca97 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -14,7 +14,12 @@ from lnbits.core.models import ( User, Wallet, ) -from lnbits.decorators import check_user_exists +from lnbits.core.models.users import AccountId +from lnbits.decorators import ( + check_account_exists, + check_account_id_exists, + check_user_exists, +) from lnbits.settings import settings from lnbits.utils.exchange_rates import ( allowed_currencies, @@ -39,7 +44,9 @@ async def health() -> dict: @api_router.get("/api/v1/status", status_code=HTTPStatus.OK) -async def health_check(user: User = Depends(check_user_exists)) -> dict: +async def health_check( + account_id: AccountId = Depends(check_account_id_exists), +) -> dict: stat: dict[str, Any] = { "server_time": int(time()), "up_time": settings.lnbits_server_up_time, @@ -47,7 +54,7 @@ async def health_check(user: User = Depends(check_user_exists)) -> dict: } stat["version"] = settings.version - if not user.admin: + if not account_id.is_admin_id: return stat funding_source = get_funding_source() @@ -78,7 +85,7 @@ async def api_create_account(data: CreateWallet) -> Wallet: @api_router.get( "/api/v1/rate/history", - dependencies=[Depends(check_user_exists)], + dependencies=[Depends(check_account_exists)], ) async def api_exchange_rate_history() -> list[dict]: return settings.lnbits_exchange_rate_history diff --git a/lnbits/core/views/asset_api.py b/lnbits/core/views/asset_api.py index cd318653..3a357e33 100644 --- a/lnbits/core/views/asset_api.py +++ b/lnbits/core/views/asset_api.py @@ -15,11 +15,11 @@ from lnbits.core.crud.assets import ( ) from lnbits.core.models.assets import AssetFilters, AssetInfo, AssetUpdate from lnbits.core.models.misc import SimpleStatus -from lnbits.core.models.users import User +from lnbits.core.models.users import AccountId from lnbits.core.services.assets import create_user_asset from lnbits.db import Filters, Page from lnbits.decorators import ( - check_user_exists, + check_account_id_exists, optional_user_id, parse_filters, ) @@ -35,10 +35,10 @@ upload_file_param = File(...) summary="Get paginated list user assets", ) async def api_get_user_assets( - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), filters: Filters = Depends(parse_filters(AssetFilters)), ) -> Page[AssetInfo]: - return await get_user_assets(user.id, filters=filters) + return await get_user_assets(account_id.id, filters=filters) @asset_router.get( @@ -48,9 +48,9 @@ async def api_get_user_assets( ) async def api_get_asset( asset_id: str, - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), ) -> AssetInfo: - asset_info = await get_user_asset_info(user.id, asset_id) + asset_info = await get_user_asset_info(account_id.id, asset_id) if not asset_info: raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.") return asset_info @@ -120,12 +120,12 @@ async def api_get_asset_thumbnail( async def api_update_asset( asset_id: str, data: AssetUpdate, - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), ) -> AssetInfo: - if user.admin: + if account_id.is_admin_id: asset_info = await get_asset_info(asset_id) else: - asset_info = await get_user_asset_info(user.id, asset_id) + asset_info = await get_user_asset_info(account_id.id, asset_id) if not asset_info: raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.") @@ -144,13 +144,13 @@ async def api_update_asset( summary="Upload user assets", ) async def api_upload_asset( - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), file: UploadFile = upload_file_param, public_asset: bool = False, ) -> AssetInfo: - asset = await create_user_asset(user.id, file, public_asset) + asset = await create_user_asset(account_id.id, file, public_asset) - asset_info = await get_user_asset_info(user.id, asset.id) + asset_info = await get_user_asset_info(account_id.id, asset.id) if not asset_info: raise ValueError("Failed to retrieve asset info after upload.") @@ -164,11 +164,11 @@ async def api_upload_asset( ) async def api_delete_asset( asset_id: str, - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), ) -> SimpleStatus: - asset = await get_user_asset(user.id, asset_id) + asset = await get_user_asset(account_id.id, asset_id) if not asset: raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.") - await delete_user_asset(user.id, asset_id) + await delete_user_asset(account_id.id, asset_id) return SimpleStatus(success=True, message="Asset deleted successfully.") diff --git a/lnbits/core/views/auth_api.py b/lnbits/core/views/auth_api.py index c607d5e0..a7fbb17c 100644 --- a/lnbits/core/views/auth_api.py +++ b/lnbits/core/views/auth_api.py @@ -26,7 +26,11 @@ from lnbits.core.models.users import ( ) from lnbits.core.services import create_user_account from lnbits.core.services.users import update_user_account -from lnbits.decorators import access_token_payload, check_user_exists +from lnbits.decorators import ( + access_token_payload, + check_account_exists, + check_user_exists, +) from lnbits.helpers import ( create_access_token, decrypt_internal_message, @@ -119,11 +123,11 @@ async def login_usr(data: LoginUsr) -> JSONResponse: @auth_router.get("/acl") async def api_get_user_acls( request: Request, - user: User = Depends(check_user_exists), + account: Account = Depends(check_account_exists), ) -> UserAcls: api_routes = get_api_routes(request.app.router.routes) - acls = await get_user_access_control_lists(user.id) + acls = await get_user_access_control_lists(account.id) # Add missing/new endpoints to the ACLs for acl in acls.access_control_list: @@ -136,7 +140,7 @@ async def api_get_user_acls( acl.endpoints.append(EndpointAccess(path=path, name=name)) acl.endpoints.sort(key=lambda e: e.name.lower()) - return UserAcls(id=user.id, access_control_list=acls.access_control_list) + return UserAcls(id=account.id, access_control_list=acls.access_control_list) @auth_router.put("/acl") @@ -144,13 +148,13 @@ async def api_get_user_acls( async def api_update_user_acl( request: Request, data: UpdateAccessControlList, - user: User = Depends(check_user_exists), + account: Account = Depends(check_account_exists), ) -> UserAcls: - account = await get_account(user.id) + if not account or not account.verify_password(data.password): raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.") - user_acls = await get_user_access_control_lists(user.id) + user_acls = await get_user_access_control_lists(account.id) acl = user_acls.get_acl_by_id(data.id) if acl: user_acls.access_control_list.remove(acl) @@ -175,33 +179,30 @@ async def api_update_user_acl( @auth_router.delete("/acl") async def api_delete_user_acl( - data: DeleteAccessControlList, - user: User = Depends(check_user_exists), + data: DeleteAccessControlList, account: Account = Depends(check_account_exists) ): - account = await get_account(user.id) if not account or not account.verify_password(data.password): raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.") - user_acls = await get_user_access_control_lists(user.id) + user_acls = await get_user_access_control_lists(account.id) user_acls.delete_acl_by_id(data.id) await update_user_access_control_list(user_acls) @auth_router.post("/acl/token") async def api_create_user_api_token( - data: ApiTokenRequest, - user: User = Depends(check_user_exists), + data: ApiTokenRequest, account: Account = Depends(check_account_exists) ) -> ApiTokenResponse: if not data.expiration_time_minutes > 0: raise ValueError("Expiration time must be in the future.") - account = await get_account(user.id) + if not account or not account.verify_password(data.password): raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.") if not account.username: raise ValueError("Username must be configured.") - acls = await get_user_access_control_lists(user.id) + acls = await get_user_access_control_lists(account.id) acl = acls.get_acl_by_id(data.acl_id) if not acl: raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid ACL id.") @@ -218,18 +219,16 @@ async def api_create_user_api_token( @auth_router.delete("/acl/token") async def api_delete_user_api_token( - data: DeleteTokenRequest, - user: User = Depends(check_user_exists), + data: DeleteTokenRequest, account: Account = Depends(check_account_exists) ): - account = await get_account(user.id) if not account or not account.verify_password(data.password): raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.") if not account.username: raise ValueError("Username must be configured.") - acls = await get_user_access_control_lists(user.id) + acls = await get_user_access_control_lists(account.id) acl = acls.get_acl_by_id(data.acl_id) if not acl: raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid ACL id.") @@ -318,24 +317,20 @@ async def register(data: RegisterUser) -> JSONResponse: @auth_router.put("/pubkey") async def update_pubkey( data: UpdateUserPubkey, - user: User = Depends(check_user_exists), + account: Account = Depends(check_account_exists), payload: AccessTokenPayload = Depends(access_token_payload), ) -> User | None: - if data.user_id != user.id: + if data.user_id != account.id: raise ValueError("Invalid user ID.") _validate_auth_timeout(payload.auth_time) if ( data.pubkey - and data.pubkey != user.pubkey + and data.pubkey != account.pubkey and await get_account_by_pubkey(data.pubkey) ): raise ValueError("Public key already in use.") - account = await get_account(user.id) - if not account: - raise HTTPException(HTTPStatus.NOT_FOUND, "Account not found.") - account.pubkey = normalize_public_key(data.pubkey) await update_account(account) return await get_user_from_account(account) @@ -344,23 +339,19 @@ async def update_pubkey( @auth_router.put("/password") async def update_password( data: UpdateUserPassword, - user: User = Depends(check_user_exists), + account: Account = Depends(check_account_exists), payload: AccessTokenPayload = Depends(access_token_payload), ) -> User | None: _validate_auth_timeout(payload.auth_time) - if data.user_id != user.id: + if data.user_id != account.id: raise ValueError("Invalid user ID.") if ( data.username - and user.username != data.username + and account.username != data.username and await get_account_by_username(data.username) ): raise HTTPException(HTTPStatus.BAD_REQUEST, "Username already exists.") - account = await get_account(user.id) - if not account: - raise ValueError("Account not found.") - # old accounts do not have a password if account.password_hash: if not data.password_old: @@ -419,15 +410,11 @@ async def reset_password(data: ResetUserPassword) -> JSONResponse: @auth_router.put("/update") async def update( - data: UpdateUser, user: User = Depends(check_user_exists) + data: UpdateUser, account: Account = Depends(check_account_exists) ) -> User | None: - if data.user_id != user.id: + if data.user_id != account.id: raise HTTPException(HTTPStatus.BAD_REQUEST, "Invalid user ID.") - account = await get_account(user.id) - if not account: - raise HTTPException(HTTPStatus.NOT_FOUND, "Account not found.") - if data.username: account.username = data.username if data.extra: diff --git a/lnbits/core/views/extension_api.py b/lnbits/core/views/extension_api.py index dcbde2a3..c9373635 100644 --- a/lnbits/core/views/extension_api.py +++ b/lnbits/core/views/extension_api.py @@ -23,6 +23,7 @@ from lnbits.core.models.extensions import ( UserExtension, UserExtensionInfo, ) +from lnbits.core.models.users import AccountId from lnbits.core.services import check_transaction_status, create_invoice from lnbits.core.services.extensions import ( activate_extension, @@ -33,8 +34,9 @@ from lnbits.core.services.extensions import ( uninstall_extension, ) from lnbits.decorators import ( + check_account_exists, + check_account_id_exists, check_admin, - check_user_exists, ) from lnbits.settings import settings @@ -163,7 +165,7 @@ async def api_update_pay_to_enable( @extension_router.put("/{ext_id}/enable") async def api_enable_extension( - ext_id: str, user: User = Depends(check_user_exists) + ext_id: str, account_id: AccountId = Depends(check_account_id_exists) ) -> SimpleStatus: if ext_id not in [e.code for e in await get_valid_extensions()]: raise HTTPException( @@ -177,12 +179,12 @@ async def api_enable_extension( if not ext.active: raise ValueError(f"Extension '{ext_id}' is not activated.") - user_ext = await get_user_extension(user.id, ext_id) + user_ext = await get_user_extension(account_id.id, ext_id) if not user_ext: - user_ext = UserExtension(user=user.id, extension=ext_id, active=False) + user_ext = UserExtension(user=account_id.id, extension=ext_id, active=False) await create_user_extension(user_ext) - if user.admin or not ext.requires_payment: + if account_id.is_admin_id or not ext.requires_payment: user_ext.active = True await update_user_extension(user_ext) return SimpleStatus(success=True, message=f"Extension '{ext_id}' enabled.") @@ -219,13 +221,13 @@ async def api_enable_extension( @extension_router.put("/{ext_id}/disable") async def api_disable_extension( - ext_id: str, user: User = Depends(check_user_exists) + ext_id: str, account_id: AccountId = Depends(check_account_id_exists) ) -> SimpleStatus: if ext_id not in [e.code for e in await get_valid_extensions()]: raise HTTPException( HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist." ) - user_ext = await get_user_extension(user.id, ext_id) + user_ext = await get_user_extension(account_id.id, ext_id) if not user_ext or not user_ext.active: return SimpleStatus( success=True, message=f"Extension '{ext_id}' already disabled." @@ -376,7 +378,9 @@ async def get_pay_to_install_invoice( @extension_router.put("/{ext_id}/invoice/enable") async def get_pay_to_enable_invoice( - ext_id: str, data: PayToEnableInfo, user: User = Depends(check_user_exists) + ext_id: str, + data: PayToEnableInfo, + account_id: AccountId = Depends(check_account_id_exists), ): if not data.amount or data.amount <= 0: raise HTTPException( @@ -422,9 +426,9 @@ async def get_pay_to_enable_invoice( memo=f"Enable '{ext.name}' extension.", ) - user_ext = await get_user_extension(user.id, ext_id) + user_ext = await get_user_extension(account_id.id, ext_id) if not user_ext: - user_ext = UserExtension(user=user.id, extension=ext_id, active=False) + user_ext = UserExtension(user=account_id.id, extension=ext_id, active=False) await create_user_extension(user_ext) user_ext_info = user_ext.extra if user_ext.extra else UserExtensionInfo() user_ext_info.payment_hash_to_enable = payment.payment_hash @@ -435,7 +439,7 @@ async def get_pay_to_enable_invoice( @extension_router.get( "/release/{org}/{repo}/{tag_name}", - dependencies=[Depends(check_user_exists)], + dependencies=[Depends(check_account_exists)], ) async def get_extension_release(org: str, repo: str, tag_name: str): try: @@ -456,10 +460,12 @@ async def get_extension_release(org: str, repo: str, tag_name: str): @extension_router.get("") async def api_get_user_extensions( - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), ) -> list[Extension]: - user_extensions_ids = [ue.extension for ue in await get_user_extensions(user.id)] + user_extensions_ids = [ + ue.extension for ue in await get_user_extensions(account_id.id) + ] return [ ext for ext in await get_valid_extensions(False) @@ -498,7 +504,7 @@ async def delete_extension_db(ext_id: str): # TODO: create a response model for this @extension_router.get("/all") -async def extensions(user: User = Depends(check_user_exists)): +async def extensions(account_id: AccountId = Depends(check_account_id_exists)): installed_exts: list[InstallableExtension] = await get_installed_extensions() installed_exts_ids = [e.id for e in installed_exts] @@ -510,7 +516,7 @@ async def extensions(user: User = Depends(check_user_exists)): installed_ext = next((ie for ie in installed_exts if e.id == ie.id), None) if installed_ext and installed_ext.meta: installed_release = installed_ext.meta.installed_release - if installed_ext.meta.pay_to_enable and not user.admin: + if installed_ext.meta.pay_to_enable and not account_id.is_admin_id: # not a security leak, but better not to share the wallet id installed_ext.meta.pay_to_enable.wallet = None pay_to_enable = installed_ext.meta.pay_to_enable diff --git a/lnbits/core/views/extensions_builder_api.py b/lnbits/core/views/extensions_builder_api.py index 5cdac982..584d7387 100644 --- a/lnbits/core/views/extensions_builder_api.py +++ b/lnbits/core/views/extensions_builder_api.py @@ -16,6 +16,7 @@ from lnbits.core.models.extensions import ( UserExtension, ) from lnbits.core.models.extensions_builder import ExtensionData +from lnbits.core.models.users import AccountId from lnbits.core.services.extensions import ( activate_extension, install_extension, @@ -26,9 +27,9 @@ from lnbits.core.services.extensions_builder import ( zip_directory, ) from lnbits.decorators import ( + check_account_id_exists, check_admin, check_extension_builder, - check_user_exists, ) from ..crud import ( @@ -128,10 +129,10 @@ async def api_deploy_extension( ) async def api_preview_extension( data: ExtensionData, - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), ) -> SimpleStatus: stub_ext_id = "extension_builder_stub" - working_dir_name = "preview_" + sha256(user.id.encode("utf-8")).hexdigest() + working_dir_name = "preview_" + sha256(account_id.id.encode("utf-8")).hexdigest() await build_extension_from_data(data, stub_ext_id, working_dir_name) return SimpleStatus(success=True, message=f"Extension '{data.id}' preview ready.") diff --git a/lnbits/core/views/payment_api.py b/lnbits/core/views/payment_api.py index e9aea1c8..98d615c4 100644 --- a/lnbits/core/views/payment_api.py +++ b/lnbits/core/views/payment_api.py @@ -35,11 +35,11 @@ from lnbits.core.models import ( SimpleStatus, ) from lnbits.core.models.payments import UpdatePaymentLabels -from lnbits.core.models.users import User +from lnbits.core.models.users import AccountId from lnbits.db import Filters, Page from lnbits.decorators import ( WalletTypeInfo, - check_user_exists, + check_account_id_exists, parse_filters, require_admin_key, require_invoice_key, @@ -118,14 +118,14 @@ async def api_payments_history( async def api_payments_counting_stats( count_by: PaymentCountField = Query("tag"), filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)), - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), ): - if user.admin: + if account_id.is_admin_id: # admin user can see payments from all wallets for_user_id = None else: # regular user can only see payments from their wallets - for_user_id = user.id + for_user_id = account_id.id return await get_payment_count_stats(count_by, filters=filters, user_id=for_user_id) @@ -138,14 +138,14 @@ async def api_payments_counting_stats( ) async def api_payments_wallets_stats( filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)), - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), ): - if user.admin: + if account_id.is_admin_id: # admin user can see payments from all wallets for_user_id = None else: # regular user can only see payments from their wallets - for_user_id = user.id + for_user_id = account_id.id return await get_wallets_stats(filters, user_id=for_user_id) @@ -157,15 +157,15 @@ async def api_payments_wallets_stats( openapi_extra=generate_filter_params_openapi(PaymentFilters), ) async def api_payments_daily_stats( - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)), ): - if user.admin: + if account_id.is_admin_id: # admin user can see payments from all wallets for_user_id = None else: # regular user can only see payments from their wallets - for_user_id = user.id + for_user_id = account_id.id return await get_payments_daily_stats(filters, user_id=for_user_id) @@ -208,14 +208,14 @@ async def api_payments_paginated( ) async def api_all_payments_paginated( filters: Filters = Depends(parse_filters(PaymentFilters)), - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), ): - if user.admin: + if account_id.is_admin_id: # admin user can see payments from all wallets for_user_id = None else: # regular user can only see payments from their wallets - for_user_id = user.id + for_user_id = account_id.id return await get_payments_paginated( filters=filters, diff --git a/lnbits/core/views/wallet_api.py b/lnbits/core/views/wallet_api.py index ce6d81b6..54c5e2cb 100644 --- a/lnbits/core/views/wallet_api.py +++ b/lnbits/core/views/wallet_api.py @@ -12,9 +12,10 @@ from lnbits.core.crud.wallets import ( create_wallet, get_wallets_paginated, ) -from lnbits.core.models import CreateWallet, KeyType, User, Wallet, WalletTypeInfo +from lnbits.core.models import CreateWallet, KeyType, Wallet, WalletTypeInfo from lnbits.core.models.lnurl import StoredPayLink, StoredPayLinks from lnbits.core.models.misc import SimpleStatus +from lnbits.core.models.users import Account, AccountId from lnbits.core.models.wallets import ( WalletsFilters, WalletSharePermission, @@ -29,7 +30,8 @@ from lnbits.core.services.wallets import ( ) from lnbits.db import Filters, Page from lnbits.decorators import ( - check_user_exists, + check_account_exists, + check_account_id_exists, parse_filters, require_admin_key, require_invoice_key, @@ -65,11 +67,11 @@ async def api_wallet(key_info: WalletTypeInfo = Depends(require_invoice_key)): openapi_extra=generate_filter_params_openapi(WalletsFilters), ) async def api_wallets_paginated( - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), filters: Filters = Depends(parse_filters(WalletsFilters)), ): page = await get_wallets_paginated( - user_id=user.id, + user_id=account_id.id, filters=filters, ) @@ -85,7 +87,7 @@ async def api_invite_wallet_share( @wallet_router.delete("/share/invite/{share_request_id}") async def api_reject_wallet_invitation( - share_request_id: str, invited_user: User = Depends(check_user_exists) + share_request_id: str, invited_user: Account = Depends(check_account_exists) ) -> SimpleStatus: await reject_wallet_invitation(invited_user.id, share_request_id) return SimpleStatus(success=True, message="Invitation rejected.") @@ -124,10 +126,11 @@ async def api_update_wallet_name( @wallet_router.put("/reset/{wallet_id}") async def api_reset_wallet_keys( - wallet_id: str, user: User = Depends(check_user_exists) + wallet_id: str, + account_id: AccountId = Depends(check_account_id_exists), ) -> Wallet: wallet = await get_wallet(wallet_id) - if not wallet or wallet.user != user.id: + if not wallet or wallet.user != account_id.id: raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found") wallet.adminkey = uuid4().hex @@ -175,10 +178,10 @@ async def api_update_wallet( @wallet_router.delete("/{wallet_id}") async def api_delete_wallet( - wallet_id: str, user: User = Depends(check_user_exists) + wallet_id: str, account_id: AccountId = Depends(check_account_id_exists) ) -> None: wallet = await get_wallet(wallet_id) - if not wallet or wallet.user != user.id: + if not wallet or wallet.user != account_id.id: raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found") await delete_wallet( diff --git a/lnbits/core/views/webpush_api.py b/lnbits/core/views/webpush_api.py index 72094d05..98db75e1 100644 --- a/lnbits/core/views/webpush_api.py +++ b/lnbits/core/views/webpush_api.py @@ -15,9 +15,9 @@ from lnbits.core.models import ( CreateWebPushSubscription, WebPushSubscription, ) -from lnbits.core.models.users import User +from lnbits.core.models.users import AccountId from lnbits.decorators import ( - check_user_exists, + check_account_id_exists, ) from ..crud import ( @@ -33,20 +33,20 @@ webpush_router = APIRouter(prefix="/api/v1/webpush", tags=["Webpush"]) async def api_create_webpush_subscription( request: Request, data: CreateWebPushSubscription, - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), ) -> WebPushSubscription: try: subscription = json.loads(data.subscription) endpoint = subscription["endpoint"] host = urlparse(str(request.url)).netloc - subscription = await get_webpush_subscription(endpoint, user.id) + subscription = await get_webpush_subscription(endpoint, account_id.id) if subscription: return subscription else: return await create_webpush_subscription( endpoint, - user.id, + account_id.id, data.subscription, host, ) @@ -61,13 +61,13 @@ async def api_create_webpush_subscription( @webpush_router.delete("", status_code=HTTPStatus.OK) async def api_delete_webpush_subscription( request: Request, - user: User = Depends(check_user_exists), + account_id: AccountId = Depends(check_account_id_exists), ): try: endpoint = unquote( base64.b64decode(str(request.query_params.get("endpoint"))).decode("utf-8") ) - count = await delete_webpush_subscription(endpoint, user.id) + count = await delete_webpush_subscription(endpoint, account_id.id) return {"count": count} except Exception as exc: logger.debug(exc) diff --git a/lnbits/decorators.py b/lnbits/decorators.py index d70be6be..51a7bff8 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -27,9 +27,11 @@ from lnbits.core.models import ( User, WalletTypeInfo, ) +from lnbits.core.models.users import AccountId from lnbits.db import Connection, Filter, Filters, TFilterModel -from lnbits.helpers import normalize_path, path_segments +from lnbits.helpers import normalize_path, path_segments, sha256s from lnbits.settings import AuthMethods, settings +from lnbits.utils.cache import cache oauth2_scheme = OAuth2PasswordBearer( tokenUrl="api/v1/auth", @@ -106,7 +108,7 @@ class KeyChecker(SecurityBase): detail="Invalid adminkey.", ) - await _check_user_extension_access(wallet.user, request["path"]) + await _check_user_access(request, wallet.user) key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice return WalletTypeInfo(key_type, wallet) @@ -144,11 +146,49 @@ async def check_access_token( return header_access_token or cookie_access_token or bearer_access_token -async def check_user_exists( +async def check_account_id_exists( r: Request, access_token: Annotated[str | None, Depends(check_access_token)], usr: UUID4 | None = None, -) -> User: +) -> AccountId: + cache_key: str | None = None + if access_token: + cache_key = f"auth:access_token:{sha256s(access_token)}" + elif usr: + cache_key = f"auth:user_id:{sha256s(usr.hex)}" + + if cache_key and settings.auth_authentication_cache_minutes > 0: + account_id = cache.get(cache_key) + if account_id: + r.scope["user_id"] = account_id.id + await _check_user_access(r, account_id) + return account_id + + account = await check_account_exists(r, access_token, usr) + account_id = AccountId(id=account.id) + + if cache_key and settings.auth_authentication_cache_minutes > 0: + cache.set( + cache_key, + account_id, + expiry=settings.auth_authentication_cache_minutes * 60, + ) + + return account_id + + +async def check_account_exists( + r: Request, + access_token: Annotated[str | None, Depends(check_access_token)], + usr: UUID4 | None = None, +) -> Account: + """ + Check that the account exists based on access token or user id. + More performant version of `check_user_exists`. + Unlike `check_user_exists`, this function: + - does not fetch the user wallets + - caches the account info based on settings cache time + """ if access_token: account = await _get_account_from_token(access_token, r["path"], r["method"]) elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only): @@ -164,13 +204,21 @@ async def check_user_exists( raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.") r.scope["user_id"] = account.id - if not settings.is_user_allowed(account.id): - raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.") + await _check_user_access(r, account.id) + return account + + +async def check_user_exists( + r: Request, + access_token: Annotated[str | None, Depends(check_access_token)], + usr: UUID4 | None = None, +) -> User: + account = await check_account_exists(r, access_token, usr) user = await get_user_from_account(account) if not user: raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.") - await _check_user_extension_access(user.id, r["path"]) + return user @@ -280,6 +328,12 @@ async def check_user_extension_access( return SimpleStatus(success=True, message="OK") +async def _check_user_access(r: Request, user_id: str): + if not settings.is_user_allowed(user_id): + raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.") + await _check_user_extension_access(user_id, r["path"]) + + async def _check_user_extension_access(user_id: str, path: str): ext_id = path_segments(path)[0] status = await check_user_extension_access(user_id, ext_id) diff --git a/lnbits/settings.py b/lnbits/settings.py index 700859e9..76ae92e0 100644 --- a/lnbits/settings.py +++ b/lnbits/settings.py @@ -756,6 +756,7 @@ class AuthSettings(LNbitsSettings): # How many seconds after login the user is allowed to update its credentials. # A fresh login is required afterwards. auth_credetials_update_threshold: int = Field(default=120, gt=0) + auth_authentication_cache_minutes: int = Field(default=10, ge=0) def is_auth_method_allowed(self, method: AuthMethods): return method.value in self.auth_allowed_methods