perf: use check_account_exists decorator (#3600)
This commit is contained in:
parent
5213508dc1
commit
b3efb4d378
12 changed files with 182 additions and 114 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
from .core.services import create_invoice, pay_invoice
|
from .core.services import create_invoice, pay_invoice
|
||||||
from .decorators import (
|
from .decorators import (
|
||||||
|
check_account_exists,
|
||||||
check_admin,
|
check_admin,
|
||||||
check_super_user,
|
check_super_user,
|
||||||
check_user_exists,
|
check_user_exists,
|
||||||
|
|
@ -11,6 +12,7 @@ from .exceptions import InvoiceError, PaymentError
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"InvoiceError",
|
"InvoiceError",
|
||||||
"PaymentError",
|
"PaymentError",
|
||||||
|
"check_account_exists",
|
||||||
"check_admin",
|
"check_admin",
|
||||||
"check_super_user",
|
"check_super_user",
|
||||||
"check_user_exists",
|
"check_user_exists",
|
||||||
|
|
|
||||||
|
|
@ -172,8 +172,15 @@ class UserAcls(BaseModel):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Account(BaseModel):
|
class AccountId(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_admin_id(self) -> bool:
|
||||||
|
return settings.is_admin_user(self.id)
|
||||||
|
|
||||||
|
|
||||||
|
class Account(AccountId):
|
||||||
external_id: str | None = None # for external account linking
|
external_id: str | None = None # for external account linking
|
||||||
username: str | None = None
|
username: str | None = None
|
||||||
password_hash: str | None = None
|
password_hash: str | None = None
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,12 @@ from lnbits.core.models import (
|
||||||
User,
|
User,
|
||||||
Wallet,
|
Wallet,
|
||||||
)
|
)
|
||||||
from lnbits.decorators import check_user_exists
|
from lnbits.core.models.users import AccountId
|
||||||
|
from lnbits.decorators import (
|
||||||
|
check_account_exists,
|
||||||
|
check_account_id_exists,
|
||||||
|
check_user_exists,
|
||||||
|
)
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
from lnbits.utils.exchange_rates import (
|
from lnbits.utils.exchange_rates import (
|
||||||
allowed_currencies,
|
allowed_currencies,
|
||||||
|
|
@ -39,7 +44,9 @@ async def health() -> dict:
|
||||||
|
|
||||||
|
|
||||||
@api_router.get("/api/v1/status", status_code=HTTPStatus.OK)
|
@api_router.get("/api/v1/status", status_code=HTTPStatus.OK)
|
||||||
async def health_check(user: User = Depends(check_user_exists)) -> dict:
|
async def health_check(
|
||||||
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
|
) -> dict:
|
||||||
stat: dict[str, Any] = {
|
stat: dict[str, Any] = {
|
||||||
"server_time": int(time()),
|
"server_time": int(time()),
|
||||||
"up_time": settings.lnbits_server_up_time,
|
"up_time": settings.lnbits_server_up_time,
|
||||||
|
|
@ -47,7 +54,7 @@ async def health_check(user: User = Depends(check_user_exists)) -> dict:
|
||||||
}
|
}
|
||||||
|
|
||||||
stat["version"] = settings.version
|
stat["version"] = settings.version
|
||||||
if not user.admin:
|
if not account_id.is_admin_id:
|
||||||
return stat
|
return stat
|
||||||
|
|
||||||
funding_source = get_funding_source()
|
funding_source = get_funding_source()
|
||||||
|
|
@ -78,7 +85,7 @@ async def api_create_account(data: CreateWallet) -> Wallet:
|
||||||
|
|
||||||
@api_router.get(
|
@api_router.get(
|
||||||
"/api/v1/rate/history",
|
"/api/v1/rate/history",
|
||||||
dependencies=[Depends(check_user_exists)],
|
dependencies=[Depends(check_account_exists)],
|
||||||
)
|
)
|
||||||
async def api_exchange_rate_history() -> list[dict]:
|
async def api_exchange_rate_history() -> list[dict]:
|
||||||
return settings.lnbits_exchange_rate_history
|
return settings.lnbits_exchange_rate_history
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,11 @@ from lnbits.core.crud.assets import (
|
||||||
)
|
)
|
||||||
from lnbits.core.models.assets import AssetFilters, AssetInfo, AssetUpdate
|
from lnbits.core.models.assets import AssetFilters, AssetInfo, AssetUpdate
|
||||||
from lnbits.core.models.misc import SimpleStatus
|
from lnbits.core.models.misc import SimpleStatus
|
||||||
from lnbits.core.models.users import User
|
from lnbits.core.models.users import AccountId
|
||||||
from lnbits.core.services.assets import create_user_asset
|
from lnbits.core.services.assets import create_user_asset
|
||||||
from lnbits.db import Filters, Page
|
from lnbits.db import Filters, Page
|
||||||
from lnbits.decorators import (
|
from lnbits.decorators import (
|
||||||
check_user_exists,
|
check_account_id_exists,
|
||||||
optional_user_id,
|
optional_user_id,
|
||||||
parse_filters,
|
parse_filters,
|
||||||
)
|
)
|
||||||
|
|
@ -35,10 +35,10 @@ upload_file_param = File(...)
|
||||||
summary="Get paginated list user assets",
|
summary="Get paginated list user assets",
|
||||||
)
|
)
|
||||||
async def api_get_user_assets(
|
async def api_get_user_assets(
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
filters: Filters = Depends(parse_filters(AssetFilters)),
|
filters: Filters = Depends(parse_filters(AssetFilters)),
|
||||||
) -> Page[AssetInfo]:
|
) -> Page[AssetInfo]:
|
||||||
return await get_user_assets(user.id, filters=filters)
|
return await get_user_assets(account_id.id, filters=filters)
|
||||||
|
|
||||||
|
|
||||||
@asset_router.get(
|
@asset_router.get(
|
||||||
|
|
@ -48,9 +48,9 @@ async def api_get_user_assets(
|
||||||
)
|
)
|
||||||
async def api_get_asset(
|
async def api_get_asset(
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
) -> AssetInfo:
|
) -> AssetInfo:
|
||||||
asset_info = await get_user_asset_info(user.id, asset_id)
|
asset_info = await get_user_asset_info(account_id.id, asset_id)
|
||||||
if not asset_info:
|
if not asset_info:
|
||||||
raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.")
|
raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.")
|
||||||
return asset_info
|
return asset_info
|
||||||
|
|
@ -120,12 +120,12 @@ async def api_get_asset_thumbnail(
|
||||||
async def api_update_asset(
|
async def api_update_asset(
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
data: AssetUpdate,
|
data: AssetUpdate,
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
) -> AssetInfo:
|
) -> AssetInfo:
|
||||||
if user.admin:
|
if account_id.is_admin_id:
|
||||||
asset_info = await get_asset_info(asset_id)
|
asset_info = await get_asset_info(asset_id)
|
||||||
else:
|
else:
|
||||||
asset_info = await get_user_asset_info(user.id, asset_id)
|
asset_info = await get_user_asset_info(account_id.id, asset_id)
|
||||||
|
|
||||||
if not asset_info:
|
if not asset_info:
|
||||||
raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.")
|
raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.")
|
||||||
|
|
@ -144,13 +144,13 @@ async def api_update_asset(
|
||||||
summary="Upload user assets",
|
summary="Upload user assets",
|
||||||
)
|
)
|
||||||
async def api_upload_asset(
|
async def api_upload_asset(
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
file: UploadFile = upload_file_param,
|
file: UploadFile = upload_file_param,
|
||||||
public_asset: bool = False,
|
public_asset: bool = False,
|
||||||
) -> AssetInfo:
|
) -> AssetInfo:
|
||||||
asset = await create_user_asset(user.id, file, public_asset)
|
asset = await create_user_asset(account_id.id, file, public_asset)
|
||||||
|
|
||||||
asset_info = await get_user_asset_info(user.id, asset.id)
|
asset_info = await get_user_asset_info(account_id.id, asset.id)
|
||||||
if not asset_info:
|
if not asset_info:
|
||||||
raise ValueError("Failed to retrieve asset info after upload.")
|
raise ValueError("Failed to retrieve asset info after upload.")
|
||||||
|
|
||||||
|
|
@ -164,11 +164,11 @@ async def api_upload_asset(
|
||||||
)
|
)
|
||||||
async def api_delete_asset(
|
async def api_delete_asset(
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
) -> SimpleStatus:
|
) -> SimpleStatus:
|
||||||
asset = await get_user_asset(user.id, asset_id)
|
asset = await get_user_asset(account_id.id, asset_id)
|
||||||
if not asset:
|
if not asset:
|
||||||
raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.")
|
raise HTTPException(HTTPStatus.NOT_FOUND, "Asset not found.")
|
||||||
|
|
||||||
await delete_user_asset(user.id, asset_id)
|
await delete_user_asset(account_id.id, asset_id)
|
||||||
return SimpleStatus(success=True, message="Asset deleted successfully.")
|
return SimpleStatus(success=True, message="Asset deleted successfully.")
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,11 @@ from lnbits.core.models.users import (
|
||||||
)
|
)
|
||||||
from lnbits.core.services import create_user_account
|
from lnbits.core.services import create_user_account
|
||||||
from lnbits.core.services.users import update_user_account
|
from lnbits.core.services.users import update_user_account
|
||||||
from lnbits.decorators import access_token_payload, check_user_exists
|
from lnbits.decorators import (
|
||||||
|
access_token_payload,
|
||||||
|
check_account_exists,
|
||||||
|
check_user_exists,
|
||||||
|
)
|
||||||
from lnbits.helpers import (
|
from lnbits.helpers import (
|
||||||
create_access_token,
|
create_access_token,
|
||||||
decrypt_internal_message,
|
decrypt_internal_message,
|
||||||
|
|
@ -119,11 +123,11 @@ async def login_usr(data: LoginUsr) -> JSONResponse:
|
||||||
@auth_router.get("/acl")
|
@auth_router.get("/acl")
|
||||||
async def api_get_user_acls(
|
async def api_get_user_acls(
|
||||||
request: Request,
|
request: Request,
|
||||||
user: User = Depends(check_user_exists),
|
account: Account = Depends(check_account_exists),
|
||||||
) -> UserAcls:
|
) -> UserAcls:
|
||||||
api_routes = get_api_routes(request.app.router.routes)
|
api_routes = get_api_routes(request.app.router.routes)
|
||||||
|
|
||||||
acls = await get_user_access_control_lists(user.id)
|
acls = await get_user_access_control_lists(account.id)
|
||||||
|
|
||||||
# Add missing/new endpoints to the ACLs
|
# Add missing/new endpoints to the ACLs
|
||||||
for acl in acls.access_control_list:
|
for acl in acls.access_control_list:
|
||||||
|
|
@ -136,7 +140,7 @@ async def api_get_user_acls(
|
||||||
acl.endpoints.append(EndpointAccess(path=path, name=name))
|
acl.endpoints.append(EndpointAccess(path=path, name=name))
|
||||||
acl.endpoints.sort(key=lambda e: e.name.lower())
|
acl.endpoints.sort(key=lambda e: e.name.lower())
|
||||||
|
|
||||||
return UserAcls(id=user.id, access_control_list=acls.access_control_list)
|
return UserAcls(id=account.id, access_control_list=acls.access_control_list)
|
||||||
|
|
||||||
|
|
||||||
@auth_router.put("/acl")
|
@auth_router.put("/acl")
|
||||||
|
|
@ -144,13 +148,13 @@ async def api_get_user_acls(
|
||||||
async def api_update_user_acl(
|
async def api_update_user_acl(
|
||||||
request: Request,
|
request: Request,
|
||||||
data: UpdateAccessControlList,
|
data: UpdateAccessControlList,
|
||||||
user: User = Depends(check_user_exists),
|
account: Account = Depends(check_account_exists),
|
||||||
) -> UserAcls:
|
) -> UserAcls:
|
||||||
account = await get_account(user.id)
|
|
||||||
if not account or not account.verify_password(data.password):
|
if not account or not account.verify_password(data.password):
|
||||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
|
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
|
||||||
|
|
||||||
user_acls = await get_user_access_control_lists(user.id)
|
user_acls = await get_user_access_control_lists(account.id)
|
||||||
acl = user_acls.get_acl_by_id(data.id)
|
acl = user_acls.get_acl_by_id(data.id)
|
||||||
if acl:
|
if acl:
|
||||||
user_acls.access_control_list.remove(acl)
|
user_acls.access_control_list.remove(acl)
|
||||||
|
|
@ -175,33 +179,30 @@ async def api_update_user_acl(
|
||||||
|
|
||||||
@auth_router.delete("/acl")
|
@auth_router.delete("/acl")
|
||||||
async def api_delete_user_acl(
|
async def api_delete_user_acl(
|
||||||
data: DeleteAccessControlList,
|
data: DeleteAccessControlList, account: Account = Depends(check_account_exists)
|
||||||
user: User = Depends(check_user_exists),
|
|
||||||
):
|
):
|
||||||
account = await get_account(user.id)
|
|
||||||
if not account or not account.verify_password(data.password):
|
if not account or not account.verify_password(data.password):
|
||||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
|
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
|
||||||
|
|
||||||
user_acls = await get_user_access_control_lists(user.id)
|
user_acls = await get_user_access_control_lists(account.id)
|
||||||
user_acls.delete_acl_by_id(data.id)
|
user_acls.delete_acl_by_id(data.id)
|
||||||
await update_user_access_control_list(user_acls)
|
await update_user_access_control_list(user_acls)
|
||||||
|
|
||||||
|
|
||||||
@auth_router.post("/acl/token")
|
@auth_router.post("/acl/token")
|
||||||
async def api_create_user_api_token(
|
async def api_create_user_api_token(
|
||||||
data: ApiTokenRequest,
|
data: ApiTokenRequest, account: Account = Depends(check_account_exists)
|
||||||
user: User = Depends(check_user_exists),
|
|
||||||
) -> ApiTokenResponse:
|
) -> ApiTokenResponse:
|
||||||
if not data.expiration_time_minutes > 0:
|
if not data.expiration_time_minutes > 0:
|
||||||
raise ValueError("Expiration time must be in the future.")
|
raise ValueError("Expiration time must be in the future.")
|
||||||
account = await get_account(user.id)
|
|
||||||
if not account or not account.verify_password(data.password):
|
if not account or not account.verify_password(data.password):
|
||||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
|
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
|
||||||
|
|
||||||
if not account.username:
|
if not account.username:
|
||||||
raise ValueError("Username must be configured.")
|
raise ValueError("Username must be configured.")
|
||||||
|
|
||||||
acls = await get_user_access_control_lists(user.id)
|
acls = await get_user_access_control_lists(account.id)
|
||||||
acl = acls.get_acl_by_id(data.acl_id)
|
acl = acls.get_acl_by_id(data.acl_id)
|
||||||
if not acl:
|
if not acl:
|
||||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid ACL id.")
|
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid ACL id.")
|
||||||
|
|
@ -218,18 +219,16 @@ async def api_create_user_api_token(
|
||||||
|
|
||||||
@auth_router.delete("/acl/token")
|
@auth_router.delete("/acl/token")
|
||||||
async def api_delete_user_api_token(
|
async def api_delete_user_api_token(
|
||||||
data: DeleteTokenRequest,
|
data: DeleteTokenRequest, account: Account = Depends(check_account_exists)
|
||||||
user: User = Depends(check_user_exists),
|
|
||||||
):
|
):
|
||||||
|
|
||||||
account = await get_account(user.id)
|
|
||||||
if not account or not account.verify_password(data.password):
|
if not account or not account.verify_password(data.password):
|
||||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
|
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
|
||||||
|
|
||||||
if not account.username:
|
if not account.username:
|
||||||
raise ValueError("Username must be configured.")
|
raise ValueError("Username must be configured.")
|
||||||
|
|
||||||
acls = await get_user_access_control_lists(user.id)
|
acls = await get_user_access_control_lists(account.id)
|
||||||
acl = acls.get_acl_by_id(data.acl_id)
|
acl = acls.get_acl_by_id(data.acl_id)
|
||||||
if not acl:
|
if not acl:
|
||||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid ACL id.")
|
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid ACL id.")
|
||||||
|
|
@ -318,24 +317,20 @@ async def register(data: RegisterUser) -> JSONResponse:
|
||||||
@auth_router.put("/pubkey")
|
@auth_router.put("/pubkey")
|
||||||
async def update_pubkey(
|
async def update_pubkey(
|
||||||
data: UpdateUserPubkey,
|
data: UpdateUserPubkey,
|
||||||
user: User = Depends(check_user_exists),
|
account: Account = Depends(check_account_exists),
|
||||||
payload: AccessTokenPayload = Depends(access_token_payload),
|
payload: AccessTokenPayload = Depends(access_token_payload),
|
||||||
) -> User | None:
|
) -> User | None:
|
||||||
if data.user_id != user.id:
|
if data.user_id != account.id:
|
||||||
raise ValueError("Invalid user ID.")
|
raise ValueError("Invalid user ID.")
|
||||||
|
|
||||||
_validate_auth_timeout(payload.auth_time)
|
_validate_auth_timeout(payload.auth_time)
|
||||||
if (
|
if (
|
||||||
data.pubkey
|
data.pubkey
|
||||||
and data.pubkey != user.pubkey
|
and data.pubkey != account.pubkey
|
||||||
and await get_account_by_pubkey(data.pubkey)
|
and await get_account_by_pubkey(data.pubkey)
|
||||||
):
|
):
|
||||||
raise ValueError("Public key already in use.")
|
raise ValueError("Public key already in use.")
|
||||||
|
|
||||||
account = await get_account(user.id)
|
|
||||||
if not account:
|
|
||||||
raise HTTPException(HTTPStatus.NOT_FOUND, "Account not found.")
|
|
||||||
|
|
||||||
account.pubkey = normalize_public_key(data.pubkey)
|
account.pubkey = normalize_public_key(data.pubkey)
|
||||||
await update_account(account)
|
await update_account(account)
|
||||||
return await get_user_from_account(account)
|
return await get_user_from_account(account)
|
||||||
|
|
@ -344,23 +339,19 @@ async def update_pubkey(
|
||||||
@auth_router.put("/password")
|
@auth_router.put("/password")
|
||||||
async def update_password(
|
async def update_password(
|
||||||
data: UpdateUserPassword,
|
data: UpdateUserPassword,
|
||||||
user: User = Depends(check_user_exists),
|
account: Account = Depends(check_account_exists),
|
||||||
payload: AccessTokenPayload = Depends(access_token_payload),
|
payload: AccessTokenPayload = Depends(access_token_payload),
|
||||||
) -> User | None:
|
) -> User | None:
|
||||||
_validate_auth_timeout(payload.auth_time)
|
_validate_auth_timeout(payload.auth_time)
|
||||||
if data.user_id != user.id:
|
if data.user_id != account.id:
|
||||||
raise ValueError("Invalid user ID.")
|
raise ValueError("Invalid user ID.")
|
||||||
if (
|
if (
|
||||||
data.username
|
data.username
|
||||||
and user.username != data.username
|
and account.username != data.username
|
||||||
and await get_account_by_username(data.username)
|
and await get_account_by_username(data.username)
|
||||||
):
|
):
|
||||||
raise HTTPException(HTTPStatus.BAD_REQUEST, "Username already exists.")
|
raise HTTPException(HTTPStatus.BAD_REQUEST, "Username already exists.")
|
||||||
|
|
||||||
account = await get_account(user.id)
|
|
||||||
if not account:
|
|
||||||
raise ValueError("Account not found.")
|
|
||||||
|
|
||||||
# old accounts do not have a password
|
# old accounts do not have a password
|
||||||
if account.password_hash:
|
if account.password_hash:
|
||||||
if not data.password_old:
|
if not data.password_old:
|
||||||
|
|
@ -419,15 +410,11 @@ async def reset_password(data: ResetUserPassword) -> JSONResponse:
|
||||||
|
|
||||||
@auth_router.put("/update")
|
@auth_router.put("/update")
|
||||||
async def update(
|
async def update(
|
||||||
data: UpdateUser, user: User = Depends(check_user_exists)
|
data: UpdateUser, account: Account = Depends(check_account_exists)
|
||||||
) -> User | None:
|
) -> User | None:
|
||||||
if data.user_id != user.id:
|
if data.user_id != account.id:
|
||||||
raise HTTPException(HTTPStatus.BAD_REQUEST, "Invalid user ID.")
|
raise HTTPException(HTTPStatus.BAD_REQUEST, "Invalid user ID.")
|
||||||
|
|
||||||
account = await get_account(user.id)
|
|
||||||
if not account:
|
|
||||||
raise HTTPException(HTTPStatus.NOT_FOUND, "Account not found.")
|
|
||||||
|
|
||||||
if data.username:
|
if data.username:
|
||||||
account.username = data.username
|
account.username = data.username
|
||||||
if data.extra:
|
if data.extra:
|
||||||
|
|
|
||||||
|
|
@ -23,6 +23,7 @@ from lnbits.core.models.extensions import (
|
||||||
UserExtension,
|
UserExtension,
|
||||||
UserExtensionInfo,
|
UserExtensionInfo,
|
||||||
)
|
)
|
||||||
|
from lnbits.core.models.users import AccountId
|
||||||
from lnbits.core.services import check_transaction_status, create_invoice
|
from lnbits.core.services import check_transaction_status, create_invoice
|
||||||
from lnbits.core.services.extensions import (
|
from lnbits.core.services.extensions import (
|
||||||
activate_extension,
|
activate_extension,
|
||||||
|
|
@ -33,8 +34,9 @@ from lnbits.core.services.extensions import (
|
||||||
uninstall_extension,
|
uninstall_extension,
|
||||||
)
|
)
|
||||||
from lnbits.decorators import (
|
from lnbits.decorators import (
|
||||||
|
check_account_exists,
|
||||||
|
check_account_id_exists,
|
||||||
check_admin,
|
check_admin,
|
||||||
check_user_exists,
|
|
||||||
)
|
)
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
|
||||||
|
|
@ -163,7 +165,7 @@ async def api_update_pay_to_enable(
|
||||||
|
|
||||||
@extension_router.put("/{ext_id}/enable")
|
@extension_router.put("/{ext_id}/enable")
|
||||||
async def api_enable_extension(
|
async def api_enable_extension(
|
||||||
ext_id: str, user: User = Depends(check_user_exists)
|
ext_id: str, account_id: AccountId = Depends(check_account_id_exists)
|
||||||
) -> SimpleStatus:
|
) -> SimpleStatus:
|
||||||
if ext_id not in [e.code for e in await get_valid_extensions()]:
|
if ext_id not in [e.code for e in await get_valid_extensions()]:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -177,12 +179,12 @@ async def api_enable_extension(
|
||||||
if not ext.active:
|
if not ext.active:
|
||||||
raise ValueError(f"Extension '{ext_id}' is not activated.")
|
raise ValueError(f"Extension '{ext_id}' is not activated.")
|
||||||
|
|
||||||
user_ext = await get_user_extension(user.id, ext_id)
|
user_ext = await get_user_extension(account_id.id, ext_id)
|
||||||
if not user_ext:
|
if not user_ext:
|
||||||
user_ext = UserExtension(user=user.id, extension=ext_id, active=False)
|
user_ext = UserExtension(user=account_id.id, extension=ext_id, active=False)
|
||||||
await create_user_extension(user_ext)
|
await create_user_extension(user_ext)
|
||||||
|
|
||||||
if user.admin or not ext.requires_payment:
|
if account_id.is_admin_id or not ext.requires_payment:
|
||||||
user_ext.active = True
|
user_ext.active = True
|
||||||
await update_user_extension(user_ext)
|
await update_user_extension(user_ext)
|
||||||
return SimpleStatus(success=True, message=f"Extension '{ext_id}' enabled.")
|
return SimpleStatus(success=True, message=f"Extension '{ext_id}' enabled.")
|
||||||
|
|
@ -219,13 +221,13 @@ async def api_enable_extension(
|
||||||
|
|
||||||
@extension_router.put("/{ext_id}/disable")
|
@extension_router.put("/{ext_id}/disable")
|
||||||
async def api_disable_extension(
|
async def api_disable_extension(
|
||||||
ext_id: str, user: User = Depends(check_user_exists)
|
ext_id: str, account_id: AccountId = Depends(check_account_id_exists)
|
||||||
) -> SimpleStatus:
|
) -> SimpleStatus:
|
||||||
if ext_id not in [e.code for e in await get_valid_extensions()]:
|
if ext_id not in [e.code for e in await get_valid_extensions()]:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist."
|
HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist."
|
||||||
)
|
)
|
||||||
user_ext = await get_user_extension(user.id, ext_id)
|
user_ext = await get_user_extension(account_id.id, ext_id)
|
||||||
if not user_ext or not user_ext.active:
|
if not user_ext or not user_ext.active:
|
||||||
return SimpleStatus(
|
return SimpleStatus(
|
||||||
success=True, message=f"Extension '{ext_id}' already disabled."
|
success=True, message=f"Extension '{ext_id}' already disabled."
|
||||||
|
|
@ -376,7 +378,9 @@ async def get_pay_to_install_invoice(
|
||||||
|
|
||||||
@extension_router.put("/{ext_id}/invoice/enable")
|
@extension_router.put("/{ext_id}/invoice/enable")
|
||||||
async def get_pay_to_enable_invoice(
|
async def get_pay_to_enable_invoice(
|
||||||
ext_id: str, data: PayToEnableInfo, user: User = Depends(check_user_exists)
|
ext_id: str,
|
||||||
|
data: PayToEnableInfo,
|
||||||
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
):
|
):
|
||||||
if not data.amount or data.amount <= 0:
|
if not data.amount or data.amount <= 0:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
@ -422,9 +426,9 @@ async def get_pay_to_enable_invoice(
|
||||||
memo=f"Enable '{ext.name}' extension.",
|
memo=f"Enable '{ext.name}' extension.",
|
||||||
)
|
)
|
||||||
|
|
||||||
user_ext = await get_user_extension(user.id, ext_id)
|
user_ext = await get_user_extension(account_id.id, ext_id)
|
||||||
if not user_ext:
|
if not user_ext:
|
||||||
user_ext = UserExtension(user=user.id, extension=ext_id, active=False)
|
user_ext = UserExtension(user=account_id.id, extension=ext_id, active=False)
|
||||||
await create_user_extension(user_ext)
|
await create_user_extension(user_ext)
|
||||||
user_ext_info = user_ext.extra if user_ext.extra else UserExtensionInfo()
|
user_ext_info = user_ext.extra if user_ext.extra else UserExtensionInfo()
|
||||||
user_ext_info.payment_hash_to_enable = payment.payment_hash
|
user_ext_info.payment_hash_to_enable = payment.payment_hash
|
||||||
|
|
@ -435,7 +439,7 @@ async def get_pay_to_enable_invoice(
|
||||||
|
|
||||||
@extension_router.get(
|
@extension_router.get(
|
||||||
"/release/{org}/{repo}/{tag_name}",
|
"/release/{org}/{repo}/{tag_name}",
|
||||||
dependencies=[Depends(check_user_exists)],
|
dependencies=[Depends(check_account_exists)],
|
||||||
)
|
)
|
||||||
async def get_extension_release(org: str, repo: str, tag_name: str):
|
async def get_extension_release(org: str, repo: str, tag_name: str):
|
||||||
try:
|
try:
|
||||||
|
|
@ -456,10 +460,12 @@ async def get_extension_release(org: str, repo: str, tag_name: str):
|
||||||
|
|
||||||
@extension_router.get("")
|
@extension_router.get("")
|
||||||
async def api_get_user_extensions(
|
async def api_get_user_extensions(
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
) -> list[Extension]:
|
) -> list[Extension]:
|
||||||
|
|
||||||
user_extensions_ids = [ue.extension for ue in await get_user_extensions(user.id)]
|
user_extensions_ids = [
|
||||||
|
ue.extension for ue in await get_user_extensions(account_id.id)
|
||||||
|
]
|
||||||
return [
|
return [
|
||||||
ext
|
ext
|
||||||
for ext in await get_valid_extensions(False)
|
for ext in await get_valid_extensions(False)
|
||||||
|
|
@ -498,7 +504,7 @@ async def delete_extension_db(ext_id: str):
|
||||||
|
|
||||||
# TODO: create a response model for this
|
# TODO: create a response model for this
|
||||||
@extension_router.get("/all")
|
@extension_router.get("/all")
|
||||||
async def extensions(user: User = Depends(check_user_exists)):
|
async def extensions(account_id: AccountId = Depends(check_account_id_exists)):
|
||||||
installed_exts: list[InstallableExtension] = await get_installed_extensions()
|
installed_exts: list[InstallableExtension] = await get_installed_extensions()
|
||||||
installed_exts_ids = [e.id for e in installed_exts]
|
installed_exts_ids = [e.id for e in installed_exts]
|
||||||
|
|
||||||
|
|
@ -510,7 +516,7 @@ async def extensions(user: User = Depends(check_user_exists)):
|
||||||
installed_ext = next((ie for ie in installed_exts if e.id == ie.id), None)
|
installed_ext = next((ie for ie in installed_exts if e.id == ie.id), None)
|
||||||
if installed_ext and installed_ext.meta:
|
if installed_ext and installed_ext.meta:
|
||||||
installed_release = installed_ext.meta.installed_release
|
installed_release = installed_ext.meta.installed_release
|
||||||
if installed_ext.meta.pay_to_enable and not user.admin:
|
if installed_ext.meta.pay_to_enable and not account_id.is_admin_id:
|
||||||
# not a security leak, but better not to share the wallet id
|
# not a security leak, but better not to share the wallet id
|
||||||
installed_ext.meta.pay_to_enable.wallet = None
|
installed_ext.meta.pay_to_enable.wallet = None
|
||||||
pay_to_enable = installed_ext.meta.pay_to_enable
|
pay_to_enable = installed_ext.meta.pay_to_enable
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from lnbits.core.models.extensions import (
|
||||||
UserExtension,
|
UserExtension,
|
||||||
)
|
)
|
||||||
from lnbits.core.models.extensions_builder import ExtensionData
|
from lnbits.core.models.extensions_builder import ExtensionData
|
||||||
|
from lnbits.core.models.users import AccountId
|
||||||
from lnbits.core.services.extensions import (
|
from lnbits.core.services.extensions import (
|
||||||
activate_extension,
|
activate_extension,
|
||||||
install_extension,
|
install_extension,
|
||||||
|
|
@ -26,9 +27,9 @@ from lnbits.core.services.extensions_builder import (
|
||||||
zip_directory,
|
zip_directory,
|
||||||
)
|
)
|
||||||
from lnbits.decorators import (
|
from lnbits.decorators import (
|
||||||
|
check_account_id_exists,
|
||||||
check_admin,
|
check_admin,
|
||||||
check_extension_builder,
|
check_extension_builder,
|
||||||
check_user_exists,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..crud import (
|
from ..crud import (
|
||||||
|
|
@ -128,10 +129,10 @@ async def api_deploy_extension(
|
||||||
)
|
)
|
||||||
async def api_preview_extension(
|
async def api_preview_extension(
|
||||||
data: ExtensionData,
|
data: ExtensionData,
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
) -> SimpleStatus:
|
) -> SimpleStatus:
|
||||||
stub_ext_id = "extension_builder_stub"
|
stub_ext_id = "extension_builder_stub"
|
||||||
working_dir_name = "preview_" + sha256(user.id.encode("utf-8")).hexdigest()
|
working_dir_name = "preview_" + sha256(account_id.id.encode("utf-8")).hexdigest()
|
||||||
await build_extension_from_data(data, stub_ext_id, working_dir_name)
|
await build_extension_from_data(data, stub_ext_id, working_dir_name)
|
||||||
|
|
||||||
return SimpleStatus(success=True, message=f"Extension '{data.id}' preview ready.")
|
return SimpleStatus(success=True, message=f"Extension '{data.id}' preview ready.")
|
||||||
|
|
|
||||||
|
|
@ -35,11 +35,11 @@ from lnbits.core.models import (
|
||||||
SimpleStatus,
|
SimpleStatus,
|
||||||
)
|
)
|
||||||
from lnbits.core.models.payments import UpdatePaymentLabels
|
from lnbits.core.models.payments import UpdatePaymentLabels
|
||||||
from lnbits.core.models.users import User
|
from lnbits.core.models.users import AccountId
|
||||||
from lnbits.db import Filters, Page
|
from lnbits.db import Filters, Page
|
||||||
from lnbits.decorators import (
|
from lnbits.decorators import (
|
||||||
WalletTypeInfo,
|
WalletTypeInfo,
|
||||||
check_user_exists,
|
check_account_id_exists,
|
||||||
parse_filters,
|
parse_filters,
|
||||||
require_admin_key,
|
require_admin_key,
|
||||||
require_invoice_key,
|
require_invoice_key,
|
||||||
|
|
@ -118,14 +118,14 @@ async def api_payments_history(
|
||||||
async def api_payments_counting_stats(
|
async def api_payments_counting_stats(
|
||||||
count_by: PaymentCountField = Query("tag"),
|
count_by: PaymentCountField = Query("tag"),
|
||||||
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
|
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
):
|
):
|
||||||
if user.admin:
|
if account_id.is_admin_id:
|
||||||
# admin user can see payments from all wallets
|
# admin user can see payments from all wallets
|
||||||
for_user_id = None
|
for_user_id = None
|
||||||
else:
|
else:
|
||||||
# regular user can only see payments from their wallets
|
# regular user can only see payments from their wallets
|
||||||
for_user_id = user.id
|
for_user_id = account_id.id
|
||||||
|
|
||||||
return await get_payment_count_stats(count_by, filters=filters, user_id=for_user_id)
|
return await get_payment_count_stats(count_by, filters=filters, user_id=for_user_id)
|
||||||
|
|
||||||
|
|
@ -138,14 +138,14 @@ async def api_payments_counting_stats(
|
||||||
)
|
)
|
||||||
async def api_payments_wallets_stats(
|
async def api_payments_wallets_stats(
|
||||||
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
|
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
):
|
):
|
||||||
if user.admin:
|
if account_id.is_admin_id:
|
||||||
# admin user can see payments from all wallets
|
# admin user can see payments from all wallets
|
||||||
for_user_id = None
|
for_user_id = None
|
||||||
else:
|
else:
|
||||||
# regular user can only see payments from their wallets
|
# regular user can only see payments from their wallets
|
||||||
for_user_id = user.id
|
for_user_id = account_id.id
|
||||||
|
|
||||||
return await get_wallets_stats(filters, user_id=for_user_id)
|
return await get_wallets_stats(filters, user_id=for_user_id)
|
||||||
|
|
||||||
|
|
@ -157,15 +157,15 @@ async def api_payments_wallets_stats(
|
||||||
openapi_extra=generate_filter_params_openapi(PaymentFilters),
|
openapi_extra=generate_filter_params_openapi(PaymentFilters),
|
||||||
)
|
)
|
||||||
async def api_payments_daily_stats(
|
async def api_payments_daily_stats(
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
|
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
|
||||||
):
|
):
|
||||||
if user.admin:
|
if account_id.is_admin_id:
|
||||||
# admin user can see payments from all wallets
|
# admin user can see payments from all wallets
|
||||||
for_user_id = None
|
for_user_id = None
|
||||||
else:
|
else:
|
||||||
# regular user can only see payments from their wallets
|
# regular user can only see payments from their wallets
|
||||||
for_user_id = user.id
|
for_user_id = account_id.id
|
||||||
return await get_payments_daily_stats(filters, user_id=for_user_id)
|
return await get_payments_daily_stats(filters, user_id=for_user_id)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -208,14 +208,14 @@ async def api_payments_paginated(
|
||||||
)
|
)
|
||||||
async def api_all_payments_paginated(
|
async def api_all_payments_paginated(
|
||||||
filters: Filters = Depends(parse_filters(PaymentFilters)),
|
filters: Filters = Depends(parse_filters(PaymentFilters)),
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
):
|
):
|
||||||
if user.admin:
|
if account_id.is_admin_id:
|
||||||
# admin user can see payments from all wallets
|
# admin user can see payments from all wallets
|
||||||
for_user_id = None
|
for_user_id = None
|
||||||
else:
|
else:
|
||||||
# regular user can only see payments from their wallets
|
# regular user can only see payments from their wallets
|
||||||
for_user_id = user.id
|
for_user_id = account_id.id
|
||||||
|
|
||||||
return await get_payments_paginated(
|
return await get_payments_paginated(
|
||||||
filters=filters,
|
filters=filters,
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,10 @@ from lnbits.core.crud.wallets import (
|
||||||
create_wallet,
|
create_wallet,
|
||||||
get_wallets_paginated,
|
get_wallets_paginated,
|
||||||
)
|
)
|
||||||
from lnbits.core.models import CreateWallet, KeyType, User, Wallet, WalletTypeInfo
|
from lnbits.core.models import CreateWallet, KeyType, Wallet, WalletTypeInfo
|
||||||
from lnbits.core.models.lnurl import StoredPayLink, StoredPayLinks
|
from lnbits.core.models.lnurl import StoredPayLink, StoredPayLinks
|
||||||
from lnbits.core.models.misc import SimpleStatus
|
from lnbits.core.models.misc import SimpleStatus
|
||||||
|
from lnbits.core.models.users import Account, AccountId
|
||||||
from lnbits.core.models.wallets import (
|
from lnbits.core.models.wallets import (
|
||||||
WalletsFilters,
|
WalletsFilters,
|
||||||
WalletSharePermission,
|
WalletSharePermission,
|
||||||
|
|
@ -29,7 +30,8 @@ from lnbits.core.services.wallets import (
|
||||||
)
|
)
|
||||||
from lnbits.db import Filters, Page
|
from lnbits.db import Filters, Page
|
||||||
from lnbits.decorators import (
|
from lnbits.decorators import (
|
||||||
check_user_exists,
|
check_account_exists,
|
||||||
|
check_account_id_exists,
|
||||||
parse_filters,
|
parse_filters,
|
||||||
require_admin_key,
|
require_admin_key,
|
||||||
require_invoice_key,
|
require_invoice_key,
|
||||||
|
|
@ -65,11 +67,11 @@ async def api_wallet(key_info: WalletTypeInfo = Depends(require_invoice_key)):
|
||||||
openapi_extra=generate_filter_params_openapi(WalletsFilters),
|
openapi_extra=generate_filter_params_openapi(WalletsFilters),
|
||||||
)
|
)
|
||||||
async def api_wallets_paginated(
|
async def api_wallets_paginated(
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
filters: Filters = Depends(parse_filters(WalletsFilters)),
|
filters: Filters = Depends(parse_filters(WalletsFilters)),
|
||||||
):
|
):
|
||||||
page = await get_wallets_paginated(
|
page = await get_wallets_paginated(
|
||||||
user_id=user.id,
|
user_id=account_id.id,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -85,7 +87,7 @@ async def api_invite_wallet_share(
|
||||||
|
|
||||||
@wallet_router.delete("/share/invite/{share_request_id}")
|
@wallet_router.delete("/share/invite/{share_request_id}")
|
||||||
async def api_reject_wallet_invitation(
|
async def api_reject_wallet_invitation(
|
||||||
share_request_id: str, invited_user: User = Depends(check_user_exists)
|
share_request_id: str, invited_user: Account = Depends(check_account_exists)
|
||||||
) -> SimpleStatus:
|
) -> SimpleStatus:
|
||||||
await reject_wallet_invitation(invited_user.id, share_request_id)
|
await reject_wallet_invitation(invited_user.id, share_request_id)
|
||||||
return SimpleStatus(success=True, message="Invitation rejected.")
|
return SimpleStatus(success=True, message="Invitation rejected.")
|
||||||
|
|
@ -124,10 +126,11 @@ async def api_update_wallet_name(
|
||||||
|
|
||||||
@wallet_router.put("/reset/{wallet_id}")
|
@wallet_router.put("/reset/{wallet_id}")
|
||||||
async def api_reset_wallet_keys(
|
async def api_reset_wallet_keys(
|
||||||
wallet_id: str, user: User = Depends(check_user_exists)
|
wallet_id: str,
|
||||||
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
) -> Wallet:
|
) -> Wallet:
|
||||||
wallet = await get_wallet(wallet_id)
|
wallet = await get_wallet(wallet_id)
|
||||||
if not wallet or wallet.user != user.id:
|
if not wallet or wallet.user != account_id.id:
|
||||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
|
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
|
||||||
|
|
||||||
wallet.adminkey = uuid4().hex
|
wallet.adminkey = uuid4().hex
|
||||||
|
|
@ -175,10 +178,10 @@ async def api_update_wallet(
|
||||||
|
|
||||||
@wallet_router.delete("/{wallet_id}")
|
@wallet_router.delete("/{wallet_id}")
|
||||||
async def api_delete_wallet(
|
async def api_delete_wallet(
|
||||||
wallet_id: str, user: User = Depends(check_user_exists)
|
wallet_id: str, account_id: AccountId = Depends(check_account_id_exists)
|
||||||
) -> None:
|
) -> None:
|
||||||
wallet = await get_wallet(wallet_id)
|
wallet = await get_wallet(wallet_id)
|
||||||
if not wallet or wallet.user != user.id:
|
if not wallet or wallet.user != account_id.id:
|
||||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
|
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
|
||||||
|
|
||||||
await delete_wallet(
|
await delete_wallet(
|
||||||
|
|
|
||||||
|
|
@ -15,9 +15,9 @@ from lnbits.core.models import (
|
||||||
CreateWebPushSubscription,
|
CreateWebPushSubscription,
|
||||||
WebPushSubscription,
|
WebPushSubscription,
|
||||||
)
|
)
|
||||||
from lnbits.core.models.users import User
|
from lnbits.core.models.users import AccountId
|
||||||
from lnbits.decorators import (
|
from lnbits.decorators import (
|
||||||
check_user_exists,
|
check_account_id_exists,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..crud import (
|
from ..crud import (
|
||||||
|
|
@ -33,20 +33,20 @@ webpush_router = APIRouter(prefix="/api/v1/webpush", tags=["Webpush"])
|
||||||
async def api_create_webpush_subscription(
|
async def api_create_webpush_subscription(
|
||||||
request: Request,
|
request: Request,
|
||||||
data: CreateWebPushSubscription,
|
data: CreateWebPushSubscription,
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
) -> WebPushSubscription:
|
) -> WebPushSubscription:
|
||||||
try:
|
try:
|
||||||
subscription = json.loads(data.subscription)
|
subscription = json.loads(data.subscription)
|
||||||
endpoint = subscription["endpoint"]
|
endpoint = subscription["endpoint"]
|
||||||
host = urlparse(str(request.url)).netloc
|
host = urlparse(str(request.url)).netloc
|
||||||
|
|
||||||
subscription = await get_webpush_subscription(endpoint, user.id)
|
subscription = await get_webpush_subscription(endpoint, account_id.id)
|
||||||
if subscription:
|
if subscription:
|
||||||
return subscription
|
return subscription
|
||||||
else:
|
else:
|
||||||
return await create_webpush_subscription(
|
return await create_webpush_subscription(
|
||||||
endpoint,
|
endpoint,
|
||||||
user.id,
|
account_id.id,
|
||||||
data.subscription,
|
data.subscription,
|
||||||
host,
|
host,
|
||||||
)
|
)
|
||||||
|
|
@ -61,13 +61,13 @@ async def api_create_webpush_subscription(
|
||||||
@webpush_router.delete("", status_code=HTTPStatus.OK)
|
@webpush_router.delete("", status_code=HTTPStatus.OK)
|
||||||
async def api_delete_webpush_subscription(
|
async def api_delete_webpush_subscription(
|
||||||
request: Request,
|
request: Request,
|
||||||
user: User = Depends(check_user_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
endpoint = unquote(
|
endpoint = unquote(
|
||||||
base64.b64decode(str(request.query_params.get("endpoint"))).decode("utf-8")
|
base64.b64decode(str(request.query_params.get("endpoint"))).decode("utf-8")
|
||||||
)
|
)
|
||||||
count = await delete_webpush_subscription(endpoint, user.id)
|
count = await delete_webpush_subscription(endpoint, account_id.id)
|
||||||
return {"count": count}
|
return {"count": count}
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug(exc)
|
logger.debug(exc)
|
||||||
|
|
|
||||||
|
|
@ -27,9 +27,11 @@ from lnbits.core.models import (
|
||||||
User,
|
User,
|
||||||
WalletTypeInfo,
|
WalletTypeInfo,
|
||||||
)
|
)
|
||||||
|
from lnbits.core.models.users import AccountId
|
||||||
from lnbits.db import Connection, Filter, Filters, TFilterModel
|
from lnbits.db import Connection, Filter, Filters, TFilterModel
|
||||||
from lnbits.helpers import normalize_path, path_segments
|
from lnbits.helpers import normalize_path, path_segments, sha256s
|
||||||
from lnbits.settings import AuthMethods, settings
|
from lnbits.settings import AuthMethods, settings
|
||||||
|
from lnbits.utils.cache import cache
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(
|
oauth2_scheme = OAuth2PasswordBearer(
|
||||||
tokenUrl="api/v1/auth",
|
tokenUrl="api/v1/auth",
|
||||||
|
|
@ -106,7 +108,7 @@ class KeyChecker(SecurityBase):
|
||||||
detail="Invalid adminkey.",
|
detail="Invalid adminkey.",
|
||||||
)
|
)
|
||||||
|
|
||||||
await _check_user_extension_access(wallet.user, request["path"])
|
await _check_user_access(request, wallet.user)
|
||||||
|
|
||||||
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
|
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
|
||||||
return WalletTypeInfo(key_type, wallet)
|
return WalletTypeInfo(key_type, wallet)
|
||||||
|
|
@ -144,11 +146,49 @@ async def check_access_token(
|
||||||
return header_access_token or cookie_access_token or bearer_access_token
|
return header_access_token or cookie_access_token or bearer_access_token
|
||||||
|
|
||||||
|
|
||||||
async def check_user_exists(
|
async def check_account_id_exists(
|
||||||
r: Request,
|
r: Request,
|
||||||
access_token: Annotated[str | None, Depends(check_access_token)],
|
access_token: Annotated[str | None, Depends(check_access_token)],
|
||||||
usr: UUID4 | None = None,
|
usr: UUID4 | None = None,
|
||||||
) -> User:
|
) -> AccountId:
|
||||||
|
cache_key: str | None = None
|
||||||
|
if access_token:
|
||||||
|
cache_key = f"auth:access_token:{sha256s(access_token)}"
|
||||||
|
elif usr:
|
||||||
|
cache_key = f"auth:user_id:{sha256s(usr.hex)}"
|
||||||
|
|
||||||
|
if cache_key and settings.auth_authentication_cache_minutes > 0:
|
||||||
|
account_id = cache.get(cache_key)
|
||||||
|
if account_id:
|
||||||
|
r.scope["user_id"] = account_id.id
|
||||||
|
await _check_user_access(r, account_id)
|
||||||
|
return account_id
|
||||||
|
|
||||||
|
account = await check_account_exists(r, access_token, usr)
|
||||||
|
account_id = AccountId(id=account.id)
|
||||||
|
|
||||||
|
if cache_key and settings.auth_authentication_cache_minutes > 0:
|
||||||
|
cache.set(
|
||||||
|
cache_key,
|
||||||
|
account_id,
|
||||||
|
expiry=settings.auth_authentication_cache_minutes * 60,
|
||||||
|
)
|
||||||
|
|
||||||
|
return account_id
|
||||||
|
|
||||||
|
|
||||||
|
async def check_account_exists(
|
||||||
|
r: Request,
|
||||||
|
access_token: Annotated[str | None, Depends(check_access_token)],
|
||||||
|
usr: UUID4 | None = None,
|
||||||
|
) -> Account:
|
||||||
|
"""
|
||||||
|
Check that the account exists based on access token or user id.
|
||||||
|
More performant version of `check_user_exists`.
|
||||||
|
Unlike `check_user_exists`, this function:
|
||||||
|
- does not fetch the user wallets
|
||||||
|
- caches the account info based on settings cache time
|
||||||
|
"""
|
||||||
if access_token:
|
if access_token:
|
||||||
account = await _get_account_from_token(access_token, r["path"], r["method"])
|
account = await _get_account_from_token(access_token, r["path"], r["method"])
|
||||||
elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only):
|
elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only):
|
||||||
|
|
@ -164,13 +204,21 @@ async def check_user_exists(
|
||||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
|
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
|
||||||
|
|
||||||
r.scope["user_id"] = account.id
|
r.scope["user_id"] = account.id
|
||||||
if not settings.is_user_allowed(account.id):
|
await _check_user_access(r, account.id)
|
||||||
raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.")
|
|
||||||
|
|
||||||
|
return account
|
||||||
|
|
||||||
|
|
||||||
|
async def check_user_exists(
|
||||||
|
r: Request,
|
||||||
|
access_token: Annotated[str | None, Depends(check_access_token)],
|
||||||
|
usr: UUID4 | None = None,
|
||||||
|
) -> User:
|
||||||
|
account = await check_account_exists(r, access_token, usr)
|
||||||
user = await get_user_from_account(account)
|
user = await get_user_from_account(account)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
|
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
|
||||||
await _check_user_extension_access(user.id, r["path"])
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -280,6 +328,12 @@ async def check_user_extension_access(
|
||||||
return SimpleStatus(success=True, message="OK")
|
return SimpleStatus(success=True, message="OK")
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_user_access(r: Request, user_id: str):
|
||||||
|
if not settings.is_user_allowed(user_id):
|
||||||
|
raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.")
|
||||||
|
await _check_user_extension_access(user_id, r["path"])
|
||||||
|
|
||||||
|
|
||||||
async def _check_user_extension_access(user_id: str, path: str):
|
async def _check_user_extension_access(user_id: str, path: str):
|
||||||
ext_id = path_segments(path)[0]
|
ext_id = path_segments(path)[0]
|
||||||
status = await check_user_extension_access(user_id, ext_id)
|
status = await check_user_extension_access(user_id, ext_id)
|
||||||
|
|
|
||||||
|
|
@ -756,6 +756,7 @@ class AuthSettings(LNbitsSettings):
|
||||||
# How many seconds after login the user is allowed to update its credentials.
|
# How many seconds after login the user is allowed to update its credentials.
|
||||||
# A fresh login is required afterwards.
|
# A fresh login is required afterwards.
|
||||||
auth_credetials_update_threshold: int = Field(default=120, gt=0)
|
auth_credetials_update_threshold: int = Field(default=120, gt=0)
|
||||||
|
auth_authentication_cache_minutes: int = Field(default=10, ge=0)
|
||||||
|
|
||||||
def is_auth_method_allowed(self, method: AuthMethods):
|
def is_auth_method_allowed(self, method: AuthMethods):
|
||||||
return method.value in self.auth_allowed_methods
|
return method.value in self.auth_allowed_methods
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue