perf: use check_account_exists decorator (#3600)

This commit is contained in:
Vlad Stan 2025-12-04 10:17:47 +02:00 committed by GitHub
parent 5213508dc1
commit b3efb4d378
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 182 additions and 114 deletions

View file

@ -1,5 +1,6 @@
from .core.services import create_invoice, pay_invoice
from .decorators import (
check_account_exists,
check_admin,
check_super_user,
check_user_exists,
@ -11,6 +12,7 @@ from .exceptions import InvoiceError, PaymentError
__all__ = [
"InvoiceError",
"PaymentError",
"check_account_exists",
"check_admin",
"check_super_user",
"check_user_exists",

View file

@ -172,8 +172,15 @@ class UserAcls(BaseModel):
return None
class Account(BaseModel):
class AccountId(BaseModel):
id: str
@property
def is_admin_id(self) -> bool:
return settings.is_admin_user(self.id)
class Account(AccountId):
external_id: str | None = None # for external account linking
username: str | None = None
password_hash: str | None = None

View file

@ -14,7 +14,12 @@ from lnbits.core.models import (
User,
Wallet,
)
from lnbits.decorators import check_user_exists
from lnbits.core.models.users import AccountId
from lnbits.decorators import (
check_account_exists,
check_account_id_exists,
check_user_exists,
)
from lnbits.settings import settings
from lnbits.utils.exchange_rates import (
allowed_currencies,
@ -39,7 +44,9 @@ async def health() -> dict:
@api_router.get("/api/v1/status", status_code=HTTPStatus.OK)
async def health_check(user: User = Depends(check_user_exists)) -> dict:
async def health_check(
account_id: AccountId = Depends(check_account_id_exists),
) -> dict:
stat: dict[str, Any] = {
"server_time": int(time()),
"up_time": settings.lnbits_server_up_time,
@ -47,7 +54,7 @@ async def health_check(user: User = Depends(check_user_exists)) -> dict:
}
stat["version"] = settings.version
if not user.admin:
if not account_id.is_admin_id:
return stat
funding_source = get_funding_source()
@ -78,7 +85,7 @@ async def api_create_account(data: CreateWallet) -> Wallet:
@api_router.get(
"/api/v1/rate/history",
dependencies=[Depends(check_user_exists)],
dependencies=[Depends(check_account_exists)],
)
async def api_exchange_rate_history() -> list[dict]:
return settings.lnbits_exchange_rate_history

View file

@ -15,11 +15,11 @@ from lnbits.core.crud.assets import (
)
from lnbits.core.models.assets import AssetFilters, AssetInfo, AssetUpdate
from lnbits.core.models.misc import SimpleStatus
from lnbits.core.models.users import User
from lnbits.core.models.users import AccountId
from lnbits.core.services.assets import create_user_asset
from lnbits.db import Filters, Page
from lnbits.decorators import (
check_user_exists,
check_account_id_exists,
optional_user_id,
parse_filters,
)
@ -35,10 +35,10 @@ upload_file_param = File(...)
summary="Get paginated list user assets",
)
async def api_get_user_assets(
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
filters: Filters = Depends(parse_filters(AssetFilters)),
) -> Page[AssetInfo]:
return await get_user_assets(user.id, filters=filters)
return await get_user_assets(account_id.id, filters=filters)
@asset_router.get(
@ -48,9 +48,9 @@ async def api_get_user_assets(
)
async def api_get_asset(
asset_id: str,
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
) -> AssetInfo:
asset_info = await get_user_asset_info(user.id, asset_id)
asset_info = await get_user_asset_info(account_id.id, asset_id)
if not asset_info:
raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.")
return asset_info
@ -120,12 +120,12 @@ async def api_get_asset_thumbnail(
async def api_update_asset(
asset_id: str,
data: AssetUpdate,
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
) -> AssetInfo:
if user.admin:
if account_id.is_admin_id:
asset_info = await get_asset_info(asset_id)
else:
asset_info = await get_user_asset_info(user.id, asset_id)
asset_info = await get_user_asset_info(account_id.id, asset_id)
if not asset_info:
raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.")
@ -144,13 +144,13 @@ async def api_update_asset(
summary="Upload user assets",
)
async def api_upload_asset(
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
file: UploadFile = upload_file_param,
public_asset: bool = False,
) -> AssetInfo:
asset = await create_user_asset(user.id, file, public_asset)
asset = await create_user_asset(account_id.id, file, public_asset)
asset_info = await get_user_asset_info(user.id, asset.id)
asset_info = await get_user_asset_info(account_id.id, asset.id)
if not asset_info:
raise ValueError("Failed to retrieve asset info after upload.")
@ -164,11 +164,11 @@ async def api_upload_asset(
)
async def api_delete_asset(
asset_id: str,
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
) -> SimpleStatus:
asset = await get_user_asset(user.id, asset_id)
asset = await get_user_asset(account_id.id, asset_id)
if not asset:
raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.")
await delete_user_asset(user.id, asset_id)
await delete_user_asset(account_id.id, asset_id)
return SimpleStatus(success=True, message="Asset deleted successfully.")

View file

@ -26,7 +26,11 @@ from lnbits.core.models.users import (
)
from lnbits.core.services import create_user_account
from lnbits.core.services.users import update_user_account
from lnbits.decorators import access_token_payload, check_user_exists
from lnbits.decorators import (
access_token_payload,
check_account_exists,
check_user_exists,
)
from lnbits.helpers import (
create_access_token,
decrypt_internal_message,
@ -119,11 +123,11 @@ async def login_usr(data: LoginUsr) -> JSONResponse:
@auth_router.get("/acl")
async def api_get_user_acls(
request: Request,
user: User = Depends(check_user_exists),
account: Account = Depends(check_account_exists),
) -> UserAcls:
api_routes = get_api_routes(request.app.router.routes)
acls = await get_user_access_control_lists(user.id)
acls = await get_user_access_control_lists(account.id)
# Add missing/new endpoints to the ACLs
for acl in acls.access_control_list:
@ -136,7 +140,7 @@ async def api_get_user_acls(
acl.endpoints.append(EndpointAccess(path=path, name=name))
acl.endpoints.sort(key=lambda e: e.name.lower())
return UserAcls(id=user.id, access_control_list=acls.access_control_list)
return UserAcls(id=account.id, access_control_list=acls.access_control_list)
@auth_router.put("/acl")
@ -144,13 +148,13 @@ async def api_get_user_acls(
async def api_update_user_acl(
request: Request,
data: UpdateAccessControlList,
user: User = Depends(check_user_exists),
account: Account = Depends(check_account_exists),
) -> UserAcls:
account = await get_account(user.id)
if not account or not account.verify_password(data.password):
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
user_acls = await get_user_access_control_lists(user.id)
user_acls = await get_user_access_control_lists(account.id)
acl = user_acls.get_acl_by_id(data.id)
if acl:
user_acls.access_control_list.remove(acl)
@ -175,33 +179,30 @@ async def api_update_user_acl(
@auth_router.delete("/acl")
async def api_delete_user_acl(
data: DeleteAccessControlList,
user: User = Depends(check_user_exists),
data: DeleteAccessControlList, account: Account = Depends(check_account_exists)
):
account = await get_account(user.id)
if not account or not account.verify_password(data.password):
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
user_acls = await get_user_access_control_lists(user.id)
user_acls = await get_user_access_control_lists(account.id)
user_acls.delete_acl_by_id(data.id)
await update_user_access_control_list(user_acls)
@auth_router.post("/acl/token")
async def api_create_user_api_token(
data: ApiTokenRequest,
user: User = Depends(check_user_exists),
data: ApiTokenRequest, account: Account = Depends(check_account_exists)
) -> ApiTokenResponse:
if not data.expiration_time_minutes > 0:
raise ValueError("Expiration time must be in the future.")
account = await get_account(user.id)
if not account or not account.verify_password(data.password):
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
if not account.username:
raise ValueError("Username must be configured.")
acls = await get_user_access_control_lists(user.id)
acls = await get_user_access_control_lists(account.id)
acl = acls.get_acl_by_id(data.acl_id)
if not acl:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid ACL id.")
@ -218,18 +219,16 @@ async def api_create_user_api_token(
@auth_router.delete("/acl/token")
async def api_delete_user_api_token(
data: DeleteTokenRequest,
user: User = Depends(check_user_exists),
data: DeleteTokenRequest, account: Account = Depends(check_account_exists)
):
account = await get_account(user.id)
if not account or not account.verify_password(data.password):
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
if not account.username:
raise ValueError("Username must be configured.")
acls = await get_user_access_control_lists(user.id)
acls = await get_user_access_control_lists(account.id)
acl = acls.get_acl_by_id(data.acl_id)
if not acl:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid ACL id.")
@ -318,24 +317,20 @@ async def register(data: RegisterUser) -> JSONResponse:
@auth_router.put("/pubkey")
async def update_pubkey(
data: UpdateUserPubkey,
user: User = Depends(check_user_exists),
account: Account = Depends(check_account_exists),
payload: AccessTokenPayload = Depends(access_token_payload),
) -> User | None:
if data.user_id != user.id:
if data.user_id != account.id:
raise ValueError("Invalid user ID.")
_validate_auth_timeout(payload.auth_time)
if (
data.pubkey
and data.pubkey != user.pubkey
and data.pubkey != account.pubkey
and await get_account_by_pubkey(data.pubkey)
):
raise ValueError("Public key already in use.")
account = await get_account(user.id)
if not account:
raise HTTPException(HTTPStatus.NOT_FOUND, "Account not found.")
account.pubkey = normalize_public_key(data.pubkey)
await update_account(account)
return await get_user_from_account(account)
@ -344,23 +339,19 @@ async def update_pubkey(
@auth_router.put("/password")
async def update_password(
data: UpdateUserPassword,
user: User = Depends(check_user_exists),
account: Account = Depends(check_account_exists),
payload: AccessTokenPayload = Depends(access_token_payload),
) -> User | None:
_validate_auth_timeout(payload.auth_time)
if data.user_id != user.id:
if data.user_id != account.id:
raise ValueError("Invalid user ID.")
if (
data.username
and user.username != data.username
and account.username != data.username
and await get_account_by_username(data.username)
):
raise HTTPException(HTTPStatus.BAD_REQUEST, "Username already exists.")
account = await get_account(user.id)
if not account:
raise ValueError("Account not found.")
# old accounts do not have a password
if account.password_hash:
if not data.password_old:
@ -419,15 +410,11 @@ async def reset_password(data: ResetUserPassword) -> JSONResponse:
@auth_router.put("/update")
async def update(
data: UpdateUser, user: User = Depends(check_user_exists)
data: UpdateUser, account: Account = Depends(check_account_exists)
) -> User | None:
if data.user_id != user.id:
if data.user_id != account.id:
raise HTTPException(HTTPStatus.BAD_REQUEST, "Invalid user ID.")
account = await get_account(user.id)
if not account:
raise HTTPException(HTTPStatus.NOT_FOUND, "Account not found.")
if data.username:
account.username = data.username
if data.extra:

View file

@ -23,6 +23,7 @@ from lnbits.core.models.extensions import (
UserExtension,
UserExtensionInfo,
)
from lnbits.core.models.users import AccountId
from lnbits.core.services import check_transaction_status, create_invoice
from lnbits.core.services.extensions import (
activate_extension,
@ -33,8 +34,9 @@ from lnbits.core.services.extensions import (
uninstall_extension,
)
from lnbits.decorators import (
check_account_exists,
check_account_id_exists,
check_admin,
check_user_exists,
)
from lnbits.settings import settings
@ -163,7 +165,7 @@ async def api_update_pay_to_enable(
@extension_router.put("/{ext_id}/enable")
async def api_enable_extension(
ext_id: str, user: User = Depends(check_user_exists)
ext_id: str, account_id: AccountId = Depends(check_account_id_exists)
) -> SimpleStatus:
if ext_id not in [e.code for e in await get_valid_extensions()]:
raise HTTPException(
@ -177,12 +179,12 @@ async def api_enable_extension(
if not ext.active:
raise ValueError(f"Extension '{ext_id}' is not activated.")
user_ext = await get_user_extension(user.id, ext_id)
user_ext = await get_user_extension(account_id.id, ext_id)
if not user_ext:
user_ext = UserExtension(user=user.id, extension=ext_id, active=False)
user_ext = UserExtension(user=account_id.id, extension=ext_id, active=False)
await create_user_extension(user_ext)
if user.admin or not ext.requires_payment:
if account_id.is_admin_id or not ext.requires_payment:
user_ext.active = True
await update_user_extension(user_ext)
return SimpleStatus(success=True, message=f"Extension '{ext_id}' enabled.")
@ -219,13 +221,13 @@ async def api_enable_extension(
@extension_router.put("/{ext_id}/disable")
async def api_disable_extension(
ext_id: str, user: User = Depends(check_user_exists)
ext_id: str, account_id: AccountId = Depends(check_account_id_exists)
) -> SimpleStatus:
if ext_id not in [e.code for e in await get_valid_extensions()]:
raise HTTPException(
HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist."
)
user_ext = await get_user_extension(user.id, ext_id)
user_ext = await get_user_extension(account_id.id, ext_id)
if not user_ext or not user_ext.active:
return SimpleStatus(
success=True, message=f"Extension '{ext_id}' already disabled."
@ -376,7 +378,9 @@ async def get_pay_to_install_invoice(
@extension_router.put("/{ext_id}/invoice/enable")
async def get_pay_to_enable_invoice(
ext_id: str, data: PayToEnableInfo, user: User = Depends(check_user_exists)
ext_id: str,
data: PayToEnableInfo,
account_id: AccountId = Depends(check_account_id_exists),
):
if not data.amount or data.amount <= 0:
raise HTTPException(
@ -422,9 +426,9 @@ async def get_pay_to_enable_invoice(
memo=f"Enable '{ext.name}' extension.",
)
user_ext = await get_user_extension(user.id, ext_id)
user_ext = await get_user_extension(account_id.id, ext_id)
if not user_ext:
user_ext = UserExtension(user=user.id, extension=ext_id, active=False)
user_ext = UserExtension(user=account_id.id, extension=ext_id, active=False)
await create_user_extension(user_ext)
user_ext_info = user_ext.extra if user_ext.extra else UserExtensionInfo()
user_ext_info.payment_hash_to_enable = payment.payment_hash
@ -435,7 +439,7 @@ async def get_pay_to_enable_invoice(
@extension_router.get(
"/release/{org}/{repo}/{tag_name}",
dependencies=[Depends(check_user_exists)],
dependencies=[Depends(check_account_exists)],
)
async def get_extension_release(org: str, repo: str, tag_name: str):
try:
@ -456,10 +460,12 @@ async def get_extension_release(org: str, repo: str, tag_name: str):
@extension_router.get("")
async def api_get_user_extensions(
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
) -> list[Extension]:
user_extensions_ids = [ue.extension for ue in await get_user_extensions(user.id)]
user_extensions_ids = [
ue.extension for ue in await get_user_extensions(account_id.id)
]
return [
ext
for ext in await get_valid_extensions(False)
@ -498,7 +504,7 @@ async def delete_extension_db(ext_id: str):
# TODO: create a response model for this
@extension_router.get("/all")
async def extensions(user: User = Depends(check_user_exists)):
async def extensions(account_id: AccountId = Depends(check_account_id_exists)):
installed_exts: list[InstallableExtension] = await get_installed_extensions()
installed_exts_ids = [e.id for e in installed_exts]
@ -510,7 +516,7 @@ async def extensions(user: User = Depends(check_user_exists)):
installed_ext = next((ie for ie in installed_exts if e.id == ie.id), None)
if installed_ext and installed_ext.meta:
installed_release = installed_ext.meta.installed_release
if installed_ext.meta.pay_to_enable and not user.admin:
if installed_ext.meta.pay_to_enable and not account_id.is_admin_id:
# not a security leak, but better not to share the wallet id
installed_ext.meta.pay_to_enable.wallet = None
pay_to_enable = installed_ext.meta.pay_to_enable

View file

@ -16,6 +16,7 @@ from lnbits.core.models.extensions import (
UserExtension,
)
from lnbits.core.models.extensions_builder import ExtensionData
from lnbits.core.models.users import AccountId
from lnbits.core.services.extensions import (
activate_extension,
install_extension,
@ -26,9 +27,9 @@ from lnbits.core.services.extensions_builder import (
zip_directory,
)
from lnbits.decorators import (
check_account_id_exists,
check_admin,
check_extension_builder,
check_user_exists,
)
from ..crud import (
@ -128,10 +129,10 @@ async def api_deploy_extension(
)
async def api_preview_extension(
data: ExtensionData,
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
) -> SimpleStatus:
stub_ext_id = "extension_builder_stub"
working_dir_name = "preview_" + sha256(user.id.encode("utf-8")).hexdigest()
working_dir_name = "preview_" + sha256(account_id.id.encode("utf-8")).hexdigest()
await build_extension_from_data(data, stub_ext_id, working_dir_name)
return SimpleStatus(success=True, message=f"Extension '{data.id}' preview ready.")

View file

@ -35,11 +35,11 @@ from lnbits.core.models import (
SimpleStatus,
)
from lnbits.core.models.payments import UpdatePaymentLabels
from lnbits.core.models.users import User
from lnbits.core.models.users import AccountId
from lnbits.db import Filters, Page
from lnbits.decorators import (
WalletTypeInfo,
check_user_exists,
check_account_id_exists,
parse_filters,
require_admin_key,
require_invoice_key,
@ -118,14 +118,14 @@ async def api_payments_history(
async def api_payments_counting_stats(
count_by: PaymentCountField = Query("tag"),
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
):
if user.admin:
if account_id.is_admin_id:
# admin user can see payments from all wallets
for_user_id = None
else:
# regular user can only see payments from their wallets
for_user_id = user.id
for_user_id = account_id.id
return await get_payment_count_stats(count_by, filters=filters, user_id=for_user_id)
@ -138,14 +138,14 @@ async def api_payments_counting_stats(
)
async def api_payments_wallets_stats(
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
):
if user.admin:
if account_id.is_admin_id:
# admin user can see payments from all wallets
for_user_id = None
else:
# regular user can only see payments from their wallets
for_user_id = user.id
for_user_id = account_id.id
return await get_wallets_stats(filters, user_id=for_user_id)
@ -157,15 +157,15 @@ async def api_payments_wallets_stats(
openapi_extra=generate_filter_params_openapi(PaymentFilters),
)
async def api_payments_daily_stats(
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
):
if user.admin:
if account_id.is_admin_id:
# admin user can see payments from all wallets
for_user_id = None
else:
# regular user can only see payments from their wallets
for_user_id = user.id
for_user_id = account_id.id
return await get_payments_daily_stats(filters, user_id=for_user_id)
@ -208,14 +208,14 @@ async def api_payments_paginated(
)
async def api_all_payments_paginated(
filters: Filters = Depends(parse_filters(PaymentFilters)),
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
):
if user.admin:
if account_id.is_admin_id:
# admin user can see payments from all wallets
for_user_id = None
else:
# regular user can only see payments from their wallets
for_user_id = user.id
for_user_id = account_id.id
return await get_payments_paginated(
filters=filters,

View file

@ -12,9 +12,10 @@ from lnbits.core.crud.wallets import (
create_wallet,
get_wallets_paginated,
)
from lnbits.core.models import CreateWallet, KeyType, User, Wallet, WalletTypeInfo
from lnbits.core.models import CreateWallet, KeyType, Wallet, WalletTypeInfo
from lnbits.core.models.lnurl import StoredPayLink, StoredPayLinks
from lnbits.core.models.misc import SimpleStatus
from lnbits.core.models.users import Account, AccountId
from lnbits.core.models.wallets import (
WalletsFilters,
WalletSharePermission,
@ -29,7 +30,8 @@ from lnbits.core.services.wallets import (
)
from lnbits.db import Filters, Page
from lnbits.decorators import (
check_user_exists,
check_account_exists,
check_account_id_exists,
parse_filters,
require_admin_key,
require_invoice_key,
@ -65,11 +67,11 @@ async def api_wallet(key_info: WalletTypeInfo = Depends(require_invoice_key)):
openapi_extra=generate_filter_params_openapi(WalletsFilters),
)
async def api_wallets_paginated(
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
filters: Filters = Depends(parse_filters(WalletsFilters)),
):
page = await get_wallets_paginated(
user_id=user.id,
user_id=account_id.id,
filters=filters,
)
@ -85,7 +87,7 @@ async def api_invite_wallet_share(
@wallet_router.delete("/share/invite/{share_request_id}")
async def api_reject_wallet_invitation(
share_request_id: str, invited_user: User = Depends(check_user_exists)
share_request_id: str, invited_user: Account = Depends(check_account_exists)
) -> SimpleStatus:
await reject_wallet_invitation(invited_user.id, share_request_id)
return SimpleStatus(success=True, message="Invitation rejected.")
@ -124,10 +126,11 @@ async def api_update_wallet_name(
@wallet_router.put("/reset/{wallet_id}")
async def api_reset_wallet_keys(
wallet_id: str, user: User = Depends(check_user_exists)
wallet_id: str,
account_id: AccountId = Depends(check_account_id_exists),
) -> Wallet:
wallet = await get_wallet(wallet_id)
if not wallet or wallet.user != user.id:
if not wallet or wallet.user != account_id.id:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
wallet.adminkey = uuid4().hex
@ -175,10 +178,10 @@ async def api_update_wallet(
@wallet_router.delete("/{wallet_id}")
async def api_delete_wallet(
wallet_id: str, user: User = Depends(check_user_exists)
wallet_id: str, account_id: AccountId = Depends(check_account_id_exists)
) -> None:
wallet = await get_wallet(wallet_id)
if not wallet or wallet.user != user.id:
if not wallet or wallet.user != account_id.id:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
await delete_wallet(

View file

@ -15,9 +15,9 @@ from lnbits.core.models import (
CreateWebPushSubscription,
WebPushSubscription,
)
from lnbits.core.models.users import User
from lnbits.core.models.users import AccountId
from lnbits.decorators import (
check_user_exists,
check_account_id_exists,
)
from ..crud import (
@ -33,20 +33,20 @@ webpush_router = APIRouter(prefix="/api/v1/webpush", tags=["Webpush"])
async def api_create_webpush_subscription(
request: Request,
data: CreateWebPushSubscription,
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
) -> WebPushSubscription:
try:
subscription = json.loads(data.subscription)
endpoint = subscription["endpoint"]
host = urlparse(str(request.url)).netloc
subscription = await get_webpush_subscription(endpoint, user.id)
subscription = await get_webpush_subscription(endpoint, account_id.id)
if subscription:
return subscription
else:
return await create_webpush_subscription(
endpoint,
user.id,
account_id.id,
data.subscription,
host,
)
@ -61,13 +61,13 @@ async def api_create_webpush_subscription(
@webpush_router.delete("", status_code=HTTPStatus.OK)
async def api_delete_webpush_subscription(
request: Request,
user: User = Depends(check_user_exists),
account_id: AccountId = Depends(check_account_id_exists),
):
try:
endpoint = unquote(
base64.b64decode(str(request.query_params.get("endpoint"))).decode("utf-8")
)
count = await delete_webpush_subscription(endpoint, user.id)
count = await delete_webpush_subscription(endpoint, account_id.id)
return {"count": count}
except Exception as exc:
logger.debug(exc)

View file

@ -27,9 +27,11 @@ from lnbits.core.models import (
User,
WalletTypeInfo,
)
from lnbits.core.models.users import AccountId
from lnbits.db import Connection, Filter, Filters, TFilterModel
from lnbits.helpers import normalize_path, path_segments
from lnbits.helpers import normalize_path, path_segments, sha256s
from lnbits.settings import AuthMethods, settings
from lnbits.utils.cache import cache
oauth2_scheme = OAuth2PasswordBearer(
tokenUrl="api/v1/auth",
@ -106,7 +108,7 @@ class KeyChecker(SecurityBase):
detail="Invalid adminkey.",
)
await _check_user_extension_access(wallet.user, request["path"])
await _check_user_access(request, wallet.user)
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
return WalletTypeInfo(key_type, wallet)
@ -144,11 +146,49 @@ async def check_access_token(
return header_access_token or cookie_access_token or bearer_access_token
async def check_user_exists(
async def check_account_id_exists(
r: Request,
access_token: Annotated[str | None, Depends(check_access_token)],
usr: UUID4 | None = None,
) -> User:
) -> AccountId:
cache_key: str | None = None
if access_token:
cache_key = f"auth:access_token:{sha256s(access_token)}"
elif usr:
cache_key = f"auth:user_id:{sha256s(usr.hex)}"
if cache_key and settings.auth_authentication_cache_minutes > 0:
account_id = cache.get(cache_key)
if account_id:
r.scope["user_id"] = account_id.id
await _check_user_access(r, account_id)
return account_id
account = await check_account_exists(r, access_token, usr)
account_id = AccountId(id=account.id)
if cache_key and settings.auth_authentication_cache_minutes > 0:
cache.set(
cache_key,
account_id,
expiry=settings.auth_authentication_cache_minutes * 60,
)
return account_id
async def check_account_exists(
r: Request,
access_token: Annotated[str | None, Depends(check_access_token)],
usr: UUID4 | None = None,
) -> Account:
"""
Check that the account exists based on access token or user id.
More performant version of `check_user_exists`.
Unlike `check_user_exists`, this function:
- does not fetch the user wallets
- caches the account info based on settings cache time
"""
if access_token:
account = await _get_account_from_token(access_token, r["path"], r["method"])
elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only):
@ -164,13 +204,21 @@ async def check_user_exists(
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
r.scope["user_id"] = account.id
if not settings.is_user_allowed(account.id):
raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.")
await _check_user_access(r, account.id)
return account
async def check_user_exists(
r: Request,
access_token: Annotated[str | None, Depends(check_access_token)],
usr: UUID4 | None = None,
) -> User:
account = await check_account_exists(r, access_token, usr)
user = await get_user_from_account(account)
if not user:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
await _check_user_extension_access(user.id, r["path"])
return user
@ -280,6 +328,12 @@ async def check_user_extension_access(
return SimpleStatus(success=True, message="OK")
async def _check_user_access(r: Request, user_id: str):
if not settings.is_user_allowed(user_id):
raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.")
await _check_user_extension_access(user_id, r["path"])
async def _check_user_extension_access(user_id: str, path: str):
ext_id = path_segments(path)[0]
status = await check_user_extension_access(user_id, ext_id)

View file

@ -756,6 +756,7 @@ class AuthSettings(LNbitsSettings):
# How many seconds after login the user is allowed to update its credentials.
# A fresh login is required afterwards.
auth_credetials_update_threshold: int = Field(default=120, gt=0)
auth_authentication_cache_minutes: int = Field(default=10, ge=0)
def is_auth_method_allowed(self, method: AuthMethods):
return method.value in self.auth_allowed_methods