perf: use check_account_exists decorator (#3600)
This commit is contained in:
parent
5213508dc1
commit
b3efb4d378
12 changed files with 182 additions and 114 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from .core.services import create_invoice, pay_invoice
|
||||
from .decorators import (
|
||||
check_account_exists,
|
||||
check_admin,
|
||||
check_super_user,
|
||||
check_user_exists,
|
||||
|
|
@ -11,6 +12,7 @@ from .exceptions import InvoiceError, PaymentError
|
|||
__all__ = [
|
||||
"InvoiceError",
|
||||
"PaymentError",
|
||||
"check_account_exists",
|
||||
"check_admin",
|
||||
"check_super_user",
|
||||
"check_user_exists",
|
||||
|
|
|
|||
|
|
@ -172,8 +172,15 @@ class UserAcls(BaseModel):
|
|||
return None
|
||||
|
||||
|
||||
class Account(BaseModel):
|
||||
class AccountId(BaseModel):
|
||||
id: str
|
||||
|
||||
@property
|
||||
def is_admin_id(self) -> bool:
|
||||
return settings.is_admin_user(self.id)
|
||||
|
||||
|
||||
class Account(AccountId):
|
||||
external_id: str | None = None # for external account linking
|
||||
username: str | None = None
|
||||
password_hash: str | None = None
|
||||
|
|
|
|||
|
|
@ -14,7 +14,12 @@ from lnbits.core.models import (
|
|||
User,
|
||||
Wallet,
|
||||
)
|
||||
from lnbits.decorators import check_user_exists
|
||||
from lnbits.core.models.users import AccountId
|
||||
from lnbits.decorators import (
|
||||
check_account_exists,
|
||||
check_account_id_exists,
|
||||
check_user_exists,
|
||||
)
|
||||
from lnbits.settings import settings
|
||||
from lnbits.utils.exchange_rates import (
|
||||
allowed_currencies,
|
||||
|
|
@ -39,7 +44,9 @@ async def health() -> dict:
|
|||
|
||||
|
||||
@api_router.get("/api/v1/status", status_code=HTTPStatus.OK)
|
||||
async def health_check(user: User = Depends(check_user_exists)) -> dict:
|
||||
async def health_check(
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
) -> dict:
|
||||
stat: dict[str, Any] = {
|
||||
"server_time": int(time()),
|
||||
"up_time": settings.lnbits_server_up_time,
|
||||
|
|
@ -47,7 +54,7 @@ async def health_check(user: User = Depends(check_user_exists)) -> dict:
|
|||
}
|
||||
|
||||
stat["version"] = settings.version
|
||||
if not user.admin:
|
||||
if not account_id.is_admin_id:
|
||||
return stat
|
||||
|
||||
funding_source = get_funding_source()
|
||||
|
|
@ -78,7 +85,7 @@ async def api_create_account(data: CreateWallet) -> Wallet:
|
|||
|
||||
@api_router.get(
|
||||
"/api/v1/rate/history",
|
||||
dependencies=[Depends(check_user_exists)],
|
||||
dependencies=[Depends(check_account_exists)],
|
||||
)
|
||||
async def api_exchange_rate_history() -> list[dict]:
|
||||
return settings.lnbits_exchange_rate_history
|
||||
|
|
|
|||
|
|
@ -15,11 +15,11 @@ from lnbits.core.crud.assets import (
|
|||
)
|
||||
from lnbits.core.models.assets import AssetFilters, AssetInfo, AssetUpdate
|
||||
from lnbits.core.models.misc import SimpleStatus
|
||||
from lnbits.core.models.users import User
|
||||
from lnbits.core.models.users import AccountId
|
||||
from lnbits.core.services.assets import create_user_asset
|
||||
from lnbits.db import Filters, Page
|
||||
from lnbits.decorators import (
|
||||
check_user_exists,
|
||||
check_account_id_exists,
|
||||
optional_user_id,
|
||||
parse_filters,
|
||||
)
|
||||
|
|
@ -35,10 +35,10 @@ upload_file_param = File(...)
|
|||
summary="Get paginated list user assets",
|
||||
)
|
||||
async def api_get_user_assets(
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
filters: Filters = Depends(parse_filters(AssetFilters)),
|
||||
) -> Page[AssetInfo]:
|
||||
return await get_user_assets(user.id, filters=filters)
|
||||
return await get_user_assets(account_id.id, filters=filters)
|
||||
|
||||
|
||||
@asset_router.get(
|
||||
|
|
@ -48,9 +48,9 @@ async def api_get_user_assets(
|
|||
)
|
||||
async def api_get_asset(
|
||||
asset_id: str,
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
) -> AssetInfo:
|
||||
asset_info = await get_user_asset_info(user.id, asset_id)
|
||||
asset_info = await get_user_asset_info(account_id.id, asset_id)
|
||||
if not asset_info:
|
||||
raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.")
|
||||
return asset_info
|
||||
|
|
@ -120,12 +120,12 @@ async def api_get_asset_thumbnail(
|
|||
async def api_update_asset(
|
||||
asset_id: str,
|
||||
data: AssetUpdate,
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
) -> AssetInfo:
|
||||
if user.admin:
|
||||
if account_id.is_admin_id:
|
||||
asset_info = await get_asset_info(asset_id)
|
||||
else:
|
||||
asset_info = await get_user_asset_info(user.id, asset_id)
|
||||
asset_info = await get_user_asset_info(account_id.id, asset_id)
|
||||
|
||||
if not asset_info:
|
||||
raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.")
|
||||
|
|
@ -144,13 +144,13 @@ async def api_update_asset(
|
|||
summary="Upload user assets",
|
||||
)
|
||||
async def api_upload_asset(
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
file: UploadFile = upload_file_param,
|
||||
public_asset: bool = False,
|
||||
) -> AssetInfo:
|
||||
asset = await create_user_asset(user.id, file, public_asset)
|
||||
asset = await create_user_asset(account_id.id, file, public_asset)
|
||||
|
||||
asset_info = await get_user_asset_info(user.id, asset.id)
|
||||
asset_info = await get_user_asset_info(account_id.id, asset.id)
|
||||
if not asset_info:
|
||||
raise ValueError("Failed to retrieve asset info after upload.")
|
||||
|
||||
|
|
@ -164,11 +164,11 @@ async def api_upload_asset(
|
|||
)
|
||||
async def api_delete_asset(
|
||||
asset_id: str,
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
) -> SimpleStatus:
|
||||
asset = await get_user_asset(user.id, asset_id)
|
||||
asset = await get_user_asset(account_id.id, asset_id)
|
||||
if not asset:
|
||||
raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.")
|
||||
|
||||
await delete_user_asset(user.id, asset_id)
|
||||
await delete_user_asset(account_id.id, asset_id)
|
||||
return SimpleStatus(success=True, message="Asset deleted successfully.")
|
||||
|
|
|
|||
|
|
@ -26,7 +26,11 @@ from lnbits.core.models.users import (
|
|||
)
|
||||
from lnbits.core.services import create_user_account
|
||||
from lnbits.core.services.users import update_user_account
|
||||
from lnbits.decorators import access_token_payload, check_user_exists
|
||||
from lnbits.decorators import (
|
||||
access_token_payload,
|
||||
check_account_exists,
|
||||
check_user_exists,
|
||||
)
|
||||
from lnbits.helpers import (
|
||||
create_access_token,
|
||||
decrypt_internal_message,
|
||||
|
|
@ -119,11 +123,11 @@ async def login_usr(data: LoginUsr) -> JSONResponse:
|
|||
@auth_router.get("/acl")
|
||||
async def api_get_user_acls(
|
||||
request: Request,
|
||||
user: User = Depends(check_user_exists),
|
||||
account: Account = Depends(check_account_exists),
|
||||
) -> UserAcls:
|
||||
api_routes = get_api_routes(request.app.router.routes)
|
||||
|
||||
acls = await get_user_access_control_lists(user.id)
|
||||
acls = await get_user_access_control_lists(account.id)
|
||||
|
||||
# Add missing/new endpoints to the ACLs
|
||||
for acl in acls.access_control_list:
|
||||
|
|
@ -136,7 +140,7 @@ async def api_get_user_acls(
|
|||
acl.endpoints.append(EndpointAccess(path=path, name=name))
|
||||
acl.endpoints.sort(key=lambda e: e.name.lower())
|
||||
|
||||
return UserAcls(id=user.id, access_control_list=acls.access_control_list)
|
||||
return UserAcls(id=account.id, access_control_list=acls.access_control_list)
|
||||
|
||||
|
||||
@auth_router.put("/acl")
|
||||
|
|
@ -144,13 +148,13 @@ async def api_get_user_acls(
|
|||
async def api_update_user_acl(
|
||||
request: Request,
|
||||
data: UpdateAccessControlList,
|
||||
user: User = Depends(check_user_exists),
|
||||
account: Account = Depends(check_account_exists),
|
||||
) -> UserAcls:
|
||||
account = await get_account(user.id)
|
||||
|
||||
if not account or not account.verify_password(data.password):
|
||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
|
||||
|
||||
user_acls = await get_user_access_control_lists(user.id)
|
||||
user_acls = await get_user_access_control_lists(account.id)
|
||||
acl = user_acls.get_acl_by_id(data.id)
|
||||
if acl:
|
||||
user_acls.access_control_list.remove(acl)
|
||||
|
|
@ -175,33 +179,30 @@ async def api_update_user_acl(
|
|||
|
||||
@auth_router.delete("/acl")
|
||||
async def api_delete_user_acl(
|
||||
data: DeleteAccessControlList,
|
||||
user: User = Depends(check_user_exists),
|
||||
data: DeleteAccessControlList, account: Account = Depends(check_account_exists)
|
||||
):
|
||||
account = await get_account(user.id)
|
||||
if not account or not account.verify_password(data.password):
|
||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
|
||||
|
||||
user_acls = await get_user_access_control_lists(user.id)
|
||||
user_acls = await get_user_access_control_lists(account.id)
|
||||
user_acls.delete_acl_by_id(data.id)
|
||||
await update_user_access_control_list(user_acls)
|
||||
|
||||
|
||||
@auth_router.post("/acl/token")
|
||||
async def api_create_user_api_token(
|
||||
data: ApiTokenRequest,
|
||||
user: User = Depends(check_user_exists),
|
||||
data: ApiTokenRequest, account: Account = Depends(check_account_exists)
|
||||
) -> ApiTokenResponse:
|
||||
if not data.expiration_time_minutes > 0:
|
||||
raise ValueError("Expiration time must be in the future.")
|
||||
account = await get_account(user.id)
|
||||
|
||||
if not account or not account.verify_password(data.password):
|
||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
|
||||
|
||||
if not account.username:
|
||||
raise ValueError("Username must be configured.")
|
||||
|
||||
acls = await get_user_access_control_lists(user.id)
|
||||
acls = await get_user_access_control_lists(account.id)
|
||||
acl = acls.get_acl_by_id(data.acl_id)
|
||||
if not acl:
|
||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid ACL id.")
|
||||
|
|
@ -218,18 +219,16 @@ async def api_create_user_api_token(
|
|||
|
||||
@auth_router.delete("/acl/token")
|
||||
async def api_delete_user_api_token(
|
||||
data: DeleteTokenRequest,
|
||||
user: User = Depends(check_user_exists),
|
||||
data: DeleteTokenRequest, account: Account = Depends(check_account_exists)
|
||||
):
|
||||
|
||||
account = await get_account(user.id)
|
||||
if not account or not account.verify_password(data.password):
|
||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
|
||||
|
||||
if not account.username:
|
||||
raise ValueError("Username must be configured.")
|
||||
|
||||
acls = await get_user_access_control_lists(user.id)
|
||||
acls = await get_user_access_control_lists(account.id)
|
||||
acl = acls.get_acl_by_id(data.acl_id)
|
||||
if not acl:
|
||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid ACL id.")
|
||||
|
|
@ -318,24 +317,20 @@ async def register(data: RegisterUser) -> JSONResponse:
|
|||
@auth_router.put("/pubkey")
|
||||
async def update_pubkey(
|
||||
data: UpdateUserPubkey,
|
||||
user: User = Depends(check_user_exists),
|
||||
account: Account = Depends(check_account_exists),
|
||||
payload: AccessTokenPayload = Depends(access_token_payload),
|
||||
) -> User | None:
|
||||
if data.user_id != user.id:
|
||||
if data.user_id != account.id:
|
||||
raise ValueError("Invalid user ID.")
|
||||
|
||||
_validate_auth_timeout(payload.auth_time)
|
||||
if (
|
||||
data.pubkey
|
||||
and data.pubkey != user.pubkey
|
||||
and data.pubkey != account.pubkey
|
||||
and await get_account_by_pubkey(data.pubkey)
|
||||
):
|
||||
raise ValueError("Public key already in use.")
|
||||
|
||||
account = await get_account(user.id)
|
||||
if not account:
|
||||
raise HTTPException(HTTPStatus.NOT_FOUND, "Account not found.")
|
||||
|
||||
account.pubkey = normalize_public_key(data.pubkey)
|
||||
await update_account(account)
|
||||
return await get_user_from_account(account)
|
||||
|
|
@ -344,23 +339,19 @@ async def update_pubkey(
|
|||
@auth_router.put("/password")
|
||||
async def update_password(
|
||||
data: UpdateUserPassword,
|
||||
user: User = Depends(check_user_exists),
|
||||
account: Account = Depends(check_account_exists),
|
||||
payload: AccessTokenPayload = Depends(access_token_payload),
|
||||
) -> User | None:
|
||||
_validate_auth_timeout(payload.auth_time)
|
||||
if data.user_id != user.id:
|
||||
if data.user_id != account.id:
|
||||
raise ValueError("Invalid user ID.")
|
||||
if (
|
||||
data.username
|
||||
and user.username != data.username
|
||||
and account.username != data.username
|
||||
and await get_account_by_username(data.username)
|
||||
):
|
||||
raise HTTPException(HTTPStatus.BAD_REQUEST, "Username already exists.")
|
||||
|
||||
account = await get_account(user.id)
|
||||
if not account:
|
||||
raise ValueError("Account not found.")
|
||||
|
||||
# old accounts do not have a password
|
||||
if account.password_hash:
|
||||
if not data.password_old:
|
||||
|
|
@ -419,15 +410,11 @@ async def reset_password(data: ResetUserPassword) -> JSONResponse:
|
|||
|
||||
@auth_router.put("/update")
|
||||
async def update(
|
||||
data: UpdateUser, user: User = Depends(check_user_exists)
|
||||
data: UpdateUser, account: Account = Depends(check_account_exists)
|
||||
) -> User | None:
|
||||
if data.user_id != user.id:
|
||||
if data.user_id != account.id:
|
||||
raise HTTPException(HTTPStatus.BAD_REQUEST, "Invalid user ID.")
|
||||
|
||||
account = await get_account(user.id)
|
||||
if not account:
|
||||
raise HTTPException(HTTPStatus.NOT_FOUND, "Account not found.")
|
||||
|
||||
if data.username:
|
||||
account.username = data.username
|
||||
if data.extra:
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from lnbits.core.models.extensions import (
|
|||
UserExtension,
|
||||
UserExtensionInfo,
|
||||
)
|
||||
from lnbits.core.models.users import AccountId
|
||||
from lnbits.core.services import check_transaction_status, create_invoice
|
||||
from lnbits.core.services.extensions import (
|
||||
activate_extension,
|
||||
|
|
@ -33,8 +34,9 @@ from lnbits.core.services.extensions import (
|
|||
uninstall_extension,
|
||||
)
|
||||
from lnbits.decorators import (
|
||||
check_account_exists,
|
||||
check_account_id_exists,
|
||||
check_admin,
|
||||
check_user_exists,
|
||||
)
|
||||
from lnbits.settings import settings
|
||||
|
||||
|
|
@ -163,7 +165,7 @@ async def api_update_pay_to_enable(
|
|||
|
||||
@extension_router.put("/{ext_id}/enable")
|
||||
async def api_enable_extension(
|
||||
ext_id: str, user: User = Depends(check_user_exists)
|
||||
ext_id: str, account_id: AccountId = Depends(check_account_id_exists)
|
||||
) -> SimpleStatus:
|
||||
if ext_id not in [e.code for e in await get_valid_extensions()]:
|
||||
raise HTTPException(
|
||||
|
|
@ -177,12 +179,12 @@ async def api_enable_extension(
|
|||
if not ext.active:
|
||||
raise ValueError(f"Extension '{ext_id}' is not activated.")
|
||||
|
||||
user_ext = await get_user_extension(user.id, ext_id)
|
||||
user_ext = await get_user_extension(account_id.id, ext_id)
|
||||
if not user_ext:
|
||||
user_ext = UserExtension(user=user.id, extension=ext_id, active=False)
|
||||
user_ext = UserExtension(user=account_id.id, extension=ext_id, active=False)
|
||||
await create_user_extension(user_ext)
|
||||
|
||||
if user.admin or not ext.requires_payment:
|
||||
if account_id.is_admin_id or not ext.requires_payment:
|
||||
user_ext.active = True
|
||||
await update_user_extension(user_ext)
|
||||
return SimpleStatus(success=True, message=f"Extension '{ext_id}' enabled.")
|
||||
|
|
@ -219,13 +221,13 @@ async def api_enable_extension(
|
|||
|
||||
@extension_router.put("/{ext_id}/disable")
|
||||
async def api_disable_extension(
|
||||
ext_id: str, user: User = Depends(check_user_exists)
|
||||
ext_id: str, account_id: AccountId = Depends(check_account_id_exists)
|
||||
) -> SimpleStatus:
|
||||
if ext_id not in [e.code for e in await get_valid_extensions()]:
|
||||
raise HTTPException(
|
||||
HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist."
|
||||
)
|
||||
user_ext = await get_user_extension(user.id, ext_id)
|
||||
user_ext = await get_user_extension(account_id.id, ext_id)
|
||||
if not user_ext or not user_ext.active:
|
||||
return SimpleStatus(
|
||||
success=True, message=f"Extension '{ext_id}' already disabled."
|
||||
|
|
@ -376,7 +378,9 @@ async def get_pay_to_install_invoice(
|
|||
|
||||
@extension_router.put("/{ext_id}/invoice/enable")
|
||||
async def get_pay_to_enable_invoice(
|
||||
ext_id: str, data: PayToEnableInfo, user: User = Depends(check_user_exists)
|
||||
ext_id: str,
|
||||
data: PayToEnableInfo,
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
):
|
||||
if not data.amount or data.amount <= 0:
|
||||
raise HTTPException(
|
||||
|
|
@ -422,9 +426,9 @@ async def get_pay_to_enable_invoice(
|
|||
memo=f"Enable '{ext.name}' extension.",
|
||||
)
|
||||
|
||||
user_ext = await get_user_extension(user.id, ext_id)
|
||||
user_ext = await get_user_extension(account_id.id, ext_id)
|
||||
if not user_ext:
|
||||
user_ext = UserExtension(user=user.id, extension=ext_id, active=False)
|
||||
user_ext = UserExtension(user=account_id.id, extension=ext_id, active=False)
|
||||
await create_user_extension(user_ext)
|
||||
user_ext_info = user_ext.extra if user_ext.extra else UserExtensionInfo()
|
||||
user_ext_info.payment_hash_to_enable = payment.payment_hash
|
||||
|
|
@ -435,7 +439,7 @@ async def get_pay_to_enable_invoice(
|
|||
|
||||
@extension_router.get(
|
||||
"/release/{org}/{repo}/{tag_name}",
|
||||
dependencies=[Depends(check_user_exists)],
|
||||
dependencies=[Depends(check_account_exists)],
|
||||
)
|
||||
async def get_extension_release(org: str, repo: str, tag_name: str):
|
||||
try:
|
||||
|
|
@ -456,10 +460,12 @@ async def get_extension_release(org: str, repo: str, tag_name: str):
|
|||
|
||||
@extension_router.get("")
|
||||
async def api_get_user_extensions(
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
) -> list[Extension]:
|
||||
|
||||
user_extensions_ids = [ue.extension for ue in await get_user_extensions(user.id)]
|
||||
user_extensions_ids = [
|
||||
ue.extension for ue in await get_user_extensions(account_id.id)
|
||||
]
|
||||
return [
|
||||
ext
|
||||
for ext in await get_valid_extensions(False)
|
||||
|
|
@ -498,7 +504,7 @@ async def delete_extension_db(ext_id: str):
|
|||
|
||||
# TODO: create a response model for this
|
||||
@extension_router.get("/all")
|
||||
async def extensions(user: User = Depends(check_user_exists)):
|
||||
async def extensions(account_id: AccountId = Depends(check_account_id_exists)):
|
||||
installed_exts: list[InstallableExtension] = await get_installed_extensions()
|
||||
installed_exts_ids = [e.id for e in installed_exts]
|
||||
|
||||
|
|
@ -510,7 +516,7 @@ async def extensions(user: User = Depends(check_user_exists)):
|
|||
installed_ext = next((ie for ie in installed_exts if e.id == ie.id), None)
|
||||
if installed_ext and installed_ext.meta:
|
||||
installed_release = installed_ext.meta.installed_release
|
||||
if installed_ext.meta.pay_to_enable and not user.admin:
|
||||
if installed_ext.meta.pay_to_enable and not account_id.is_admin_id:
|
||||
# not a security leak, but better not to share the wallet id
|
||||
installed_ext.meta.pay_to_enable.wallet = None
|
||||
pay_to_enable = installed_ext.meta.pay_to_enable
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from lnbits.core.models.extensions import (
|
|||
UserExtension,
|
||||
)
|
||||
from lnbits.core.models.extensions_builder import ExtensionData
|
||||
from lnbits.core.models.users import AccountId
|
||||
from lnbits.core.services.extensions import (
|
||||
activate_extension,
|
||||
install_extension,
|
||||
|
|
@ -26,9 +27,9 @@ from lnbits.core.services.extensions_builder import (
|
|||
zip_directory,
|
||||
)
|
||||
from lnbits.decorators import (
|
||||
check_account_id_exists,
|
||||
check_admin,
|
||||
check_extension_builder,
|
||||
check_user_exists,
|
||||
)
|
||||
|
||||
from ..crud import (
|
||||
|
|
@ -128,10 +129,10 @@ async def api_deploy_extension(
|
|||
)
|
||||
async def api_preview_extension(
|
||||
data: ExtensionData,
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
) -> SimpleStatus:
|
||||
stub_ext_id = "extension_builder_stub"
|
||||
working_dir_name = "preview_" + sha256(user.id.encode("utf-8")).hexdigest()
|
||||
working_dir_name = "preview_" + sha256(account_id.id.encode("utf-8")).hexdigest()
|
||||
await build_extension_from_data(data, stub_ext_id, working_dir_name)
|
||||
|
||||
return SimpleStatus(success=True, message=f"Extension '{data.id}' preview ready.")
|
||||
|
|
|
|||
|
|
@ -35,11 +35,11 @@ from lnbits.core.models import (
|
|||
SimpleStatus,
|
||||
)
|
||||
from lnbits.core.models.payments import UpdatePaymentLabels
|
||||
from lnbits.core.models.users import User
|
||||
from lnbits.core.models.users import AccountId
|
||||
from lnbits.db import Filters, Page
|
||||
from lnbits.decorators import (
|
||||
WalletTypeInfo,
|
||||
check_user_exists,
|
||||
check_account_id_exists,
|
||||
parse_filters,
|
||||
require_admin_key,
|
||||
require_invoice_key,
|
||||
|
|
@ -118,14 +118,14 @@ async def api_payments_history(
|
|||
async def api_payments_counting_stats(
|
||||
count_by: PaymentCountField = Query("tag"),
|
||||
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
):
|
||||
if user.admin:
|
||||
if account_id.is_admin_id:
|
||||
# admin user can see payments from all wallets
|
||||
for_user_id = None
|
||||
else:
|
||||
# regular user can only see payments from their wallets
|
||||
for_user_id = user.id
|
||||
for_user_id = account_id.id
|
||||
|
||||
return await get_payment_count_stats(count_by, filters=filters, user_id=for_user_id)
|
||||
|
||||
|
|
@ -138,14 +138,14 @@ async def api_payments_counting_stats(
|
|||
)
|
||||
async def api_payments_wallets_stats(
|
||||
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
):
|
||||
if user.admin:
|
||||
if account_id.is_admin_id:
|
||||
# admin user can see payments from all wallets
|
||||
for_user_id = None
|
||||
else:
|
||||
# regular user can only see payments from their wallets
|
||||
for_user_id = user.id
|
||||
for_user_id = account_id.id
|
||||
|
||||
return await get_wallets_stats(filters, user_id=for_user_id)
|
||||
|
||||
|
|
@ -157,15 +157,15 @@ async def api_payments_wallets_stats(
|
|||
openapi_extra=generate_filter_params_openapi(PaymentFilters),
|
||||
)
|
||||
async def api_payments_daily_stats(
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
|
||||
):
|
||||
if user.admin:
|
||||
if account_id.is_admin_id:
|
||||
# admin user can see payments from all wallets
|
||||
for_user_id = None
|
||||
else:
|
||||
# regular user can only see payments from their wallets
|
||||
for_user_id = user.id
|
||||
for_user_id = account_id.id
|
||||
return await get_payments_daily_stats(filters, user_id=for_user_id)
|
||||
|
||||
|
||||
|
|
@ -208,14 +208,14 @@ async def api_payments_paginated(
|
|||
)
|
||||
async def api_all_payments_paginated(
|
||||
filters: Filters = Depends(parse_filters(PaymentFilters)),
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
):
|
||||
if user.admin:
|
||||
if account_id.is_admin_id:
|
||||
# admin user can see payments from all wallets
|
||||
for_user_id = None
|
||||
else:
|
||||
# regular user can only see payments from their wallets
|
||||
for_user_id = user.id
|
||||
for_user_id = account_id.id
|
||||
|
||||
return await get_payments_paginated(
|
||||
filters=filters,
|
||||
|
|
|
|||
|
|
@ -12,9 +12,10 @@ from lnbits.core.crud.wallets import (
|
|||
create_wallet,
|
||||
get_wallets_paginated,
|
||||
)
|
||||
from lnbits.core.models import CreateWallet, KeyType, User, Wallet, WalletTypeInfo
|
||||
from lnbits.core.models import CreateWallet, KeyType, Wallet, WalletTypeInfo
|
||||
from lnbits.core.models.lnurl import StoredPayLink, StoredPayLinks
|
||||
from lnbits.core.models.misc import SimpleStatus
|
||||
from lnbits.core.models.users import Account, AccountId
|
||||
from lnbits.core.models.wallets import (
|
||||
WalletsFilters,
|
||||
WalletSharePermission,
|
||||
|
|
@ -29,7 +30,8 @@ from lnbits.core.services.wallets import (
|
|||
)
|
||||
from lnbits.db import Filters, Page
|
||||
from lnbits.decorators import (
|
||||
check_user_exists,
|
||||
check_account_exists,
|
||||
check_account_id_exists,
|
||||
parse_filters,
|
||||
require_admin_key,
|
||||
require_invoice_key,
|
||||
|
|
@ -65,11 +67,11 @@ async def api_wallet(key_info: WalletTypeInfo = Depends(require_invoice_key)):
|
|||
openapi_extra=generate_filter_params_openapi(WalletsFilters),
|
||||
)
|
||||
async def api_wallets_paginated(
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
filters: Filters = Depends(parse_filters(WalletsFilters)),
|
||||
):
|
||||
page = await get_wallets_paginated(
|
||||
user_id=user.id,
|
||||
user_id=account_id.id,
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
|
|
@ -85,7 +87,7 @@ async def api_invite_wallet_share(
|
|||
|
||||
@wallet_router.delete("/share/invite/{share_request_id}")
|
||||
async def api_reject_wallet_invitation(
|
||||
share_request_id: str, invited_user: User = Depends(check_user_exists)
|
||||
share_request_id: str, invited_user: Account = Depends(check_account_exists)
|
||||
) -> SimpleStatus:
|
||||
await reject_wallet_invitation(invited_user.id, share_request_id)
|
||||
return SimpleStatus(success=True, message="Invitation rejected.")
|
||||
|
|
@ -124,10 +126,11 @@ async def api_update_wallet_name(
|
|||
|
||||
@wallet_router.put("/reset/{wallet_id}")
|
||||
async def api_reset_wallet_keys(
|
||||
wallet_id: str, user: User = Depends(check_user_exists)
|
||||
wallet_id: str,
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
) -> Wallet:
|
||||
wallet = await get_wallet(wallet_id)
|
||||
if not wallet or wallet.user != user.id:
|
||||
if not wallet or wallet.user != account_id.id:
|
||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
|
||||
|
||||
wallet.adminkey = uuid4().hex
|
||||
|
|
@ -175,10 +178,10 @@ async def api_update_wallet(
|
|||
|
||||
@wallet_router.delete("/{wallet_id}")
|
||||
async def api_delete_wallet(
|
||||
wallet_id: str, user: User = Depends(check_user_exists)
|
||||
wallet_id: str, account_id: AccountId = Depends(check_account_id_exists)
|
||||
) -> None:
|
||||
wallet = await get_wallet(wallet_id)
|
||||
if not wallet or wallet.user != user.id:
|
||||
if not wallet or wallet.user != account_id.id:
|
||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
|
||||
|
||||
await delete_wallet(
|
||||
|
|
|
|||
|
|
@ -15,9 +15,9 @@ from lnbits.core.models import (
|
|||
CreateWebPushSubscription,
|
||||
WebPushSubscription,
|
||||
)
|
||||
from lnbits.core.models.users import User
|
||||
from lnbits.core.models.users import AccountId
|
||||
from lnbits.decorators import (
|
||||
check_user_exists,
|
||||
check_account_id_exists,
|
||||
)
|
||||
|
||||
from ..crud import (
|
||||
|
|
@ -33,20 +33,20 @@ webpush_router = APIRouter(prefix="/api/v1/webpush", tags=["Webpush"])
|
|||
async def api_create_webpush_subscription(
|
||||
request: Request,
|
||||
data: CreateWebPushSubscription,
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
) -> WebPushSubscription:
|
||||
try:
|
||||
subscription = json.loads(data.subscription)
|
||||
endpoint = subscription["endpoint"]
|
||||
host = urlparse(str(request.url)).netloc
|
||||
|
||||
subscription = await get_webpush_subscription(endpoint, user.id)
|
||||
subscription = await get_webpush_subscription(endpoint, account_id.id)
|
||||
if subscription:
|
||||
return subscription
|
||||
else:
|
||||
return await create_webpush_subscription(
|
||||
endpoint,
|
||||
user.id,
|
||||
account_id.id,
|
||||
data.subscription,
|
||||
host,
|
||||
)
|
||||
|
|
@ -61,13 +61,13 @@ async def api_create_webpush_subscription(
|
|||
@webpush_router.delete("", status_code=HTTPStatus.OK)
|
||||
async def api_delete_webpush_subscription(
|
||||
request: Request,
|
||||
user: User = Depends(check_user_exists),
|
||||
account_id: AccountId = Depends(check_account_id_exists),
|
||||
):
|
||||
try:
|
||||
endpoint = unquote(
|
||||
base64.b64decode(str(request.query_params.get("endpoint"))).decode("utf-8")
|
||||
)
|
||||
count = await delete_webpush_subscription(endpoint, user.id)
|
||||
count = await delete_webpush_subscription(endpoint, account_id.id)
|
||||
return {"count": count}
|
||||
except Exception as exc:
|
||||
logger.debug(exc)
|
||||
|
|
|
|||
|
|
@ -27,9 +27,11 @@ from lnbits.core.models import (
|
|||
User,
|
||||
WalletTypeInfo,
|
||||
)
|
||||
from lnbits.core.models.users import AccountId
|
||||
from lnbits.db import Connection, Filter, Filters, TFilterModel
|
||||
from lnbits.helpers import normalize_path, path_segments
|
||||
from lnbits.helpers import normalize_path, path_segments, sha256s
|
||||
from lnbits.settings import AuthMethods, settings
|
||||
from lnbits.utils.cache import cache
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
tokenUrl="api/v1/auth",
|
||||
|
|
@ -106,7 +108,7 @@ class KeyChecker(SecurityBase):
|
|||
detail="Invalid adminkey.",
|
||||
)
|
||||
|
||||
await _check_user_extension_access(wallet.user, request["path"])
|
||||
await _check_user_access(request, wallet.user)
|
||||
|
||||
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
|
||||
return WalletTypeInfo(key_type, wallet)
|
||||
|
|
@ -144,11 +146,49 @@ async def check_access_token(
|
|||
return header_access_token or cookie_access_token or bearer_access_token
|
||||
|
||||
|
||||
async def check_user_exists(
|
||||
async def check_account_id_exists(
|
||||
r: Request,
|
||||
access_token: Annotated[str | None, Depends(check_access_token)],
|
||||
usr: UUID4 | None = None,
|
||||
) -> User:
|
||||
) -> AccountId:
|
||||
cache_key: str | None = None
|
||||
if access_token:
|
||||
cache_key = f"auth:access_token:{sha256s(access_token)}"
|
||||
elif usr:
|
||||
cache_key = f"auth:user_id:{sha256s(usr.hex)}"
|
||||
|
||||
if cache_key and settings.auth_authentication_cache_minutes > 0:
|
||||
account_id = cache.get(cache_key)
|
||||
if account_id:
|
||||
r.scope["user_id"] = account_id.id
|
||||
await _check_user_access(r, account_id)
|
||||
return account_id
|
||||
|
||||
account = await check_account_exists(r, access_token, usr)
|
||||
account_id = AccountId(id=account.id)
|
||||
|
||||
if cache_key and settings.auth_authentication_cache_minutes > 0:
|
||||
cache.set(
|
||||
cache_key,
|
||||
account_id,
|
||||
expiry=settings.auth_authentication_cache_minutes * 60,
|
||||
)
|
||||
|
||||
return account_id
|
||||
|
||||
|
||||
async def check_account_exists(
|
||||
r: Request,
|
||||
access_token: Annotated[str | None, Depends(check_access_token)],
|
||||
usr: UUID4 | None = None,
|
||||
) -> Account:
|
||||
"""
|
||||
Check that the account exists based on access token or user id.
|
||||
More performant version of `check_user_exists`.
|
||||
Unlike `check_user_exists`, this function:
|
||||
- does not fetch the user wallets
|
||||
- caches the account info based on settings cache time
|
||||
"""
|
||||
if access_token:
|
||||
account = await _get_account_from_token(access_token, r["path"], r["method"])
|
||||
elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only):
|
||||
|
|
@ -164,13 +204,21 @@ async def check_user_exists(
|
|||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
|
||||
|
||||
r.scope["user_id"] = account.id
|
||||
if not settings.is_user_allowed(account.id):
|
||||
raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.")
|
||||
await _check_user_access(r, account.id)
|
||||
|
||||
return account
|
||||
|
||||
|
||||
async def check_user_exists(
|
||||
r: Request,
|
||||
access_token: Annotated[str | None, Depends(check_access_token)],
|
||||
usr: UUID4 | None = None,
|
||||
) -> User:
|
||||
account = await check_account_exists(r, access_token, usr)
|
||||
user = await get_user_from_account(account)
|
||||
if not user:
|
||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
|
||||
await _check_user_extension_access(user.id, r["path"])
|
||||
|
||||
return user
|
||||
|
||||
|
||||
|
|
@ -280,6 +328,12 @@ async def check_user_extension_access(
|
|||
return SimpleStatus(success=True, message="OK")
|
||||
|
||||
|
||||
async def _check_user_access(r: Request, user_id: str):
|
||||
if not settings.is_user_allowed(user_id):
|
||||
raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.")
|
||||
await _check_user_extension_access(user_id, r["path"])
|
||||
|
||||
|
||||
async def _check_user_extension_access(user_id: str, path: str):
|
||||
ext_id = path_segments(path)[0]
|
||||
status = await check_user_extension_access(user_id, ext_id)
|
||||
|
|
|
|||
|
|
@ -756,6 +756,7 @@ class AuthSettings(LNbitsSettings):
|
|||
# How many seconds after login the user is allowed to update its credentials.
|
||||
# A fresh login is required afterwards.
|
||||
auth_credetials_update_threshold: int = Field(default=120, gt=0)
|
||||
auth_authentication_cache_minutes: int = Field(default=10, ge=0)
|
||||
|
||||
def is_auth_method_allowed(self, method: AuthMethods):
|
||||
return method.value in self.auth_allowed_methods
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue