[feat] access control lists (with access tokens) (#2864)

This commit is contained in:
Vlad Stan 2025-01-16 17:25:27 +02:00 committed by GitHub
parent f415a92914
commit b164317121
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 2131 additions and 67 deletions

View file

@ -58,6 +58,7 @@ from .users import (
get_account_by_username_or_email,
get_accounts,
get_user,
get_user_access_control_lists,
get_user_from_account,
update_account,
)
@ -145,6 +146,7 @@ __all__ = [
"get_accounts",
"get_user",
"get_user_from_account",
"get_user_access_control_lists",
"update_account",
# wallets
"create_wallet",

View file

@ -6,6 +6,7 @@ from uuid import uuid4
from lnbits.core.crud.extensions import get_user_active_extensions_ids
from lnbits.core.crud.wallets import get_wallets
from lnbits.core.db import db
from lnbits.core.models import UserAcls
from lnbits.db import Connection, Filters, Page
from ..models import (
@ -185,3 +186,20 @@ async def get_user_from_account(
super_user=account.is_super_user,
has_password=account.password_hash is not None,
)
async def update_user_access_control_list(user_acls: UserAcls):
user_acls.updated_at = datetime.now(timezone.utc)
await db.update("accounts", user_acls)
async def get_user_access_control_lists(
user_id: str, conn: Optional[Connection] = None
) -> UserAcls:
user_acls = await (conn or db).fetchone(
"SELECT id, access_control_list FROM accounts WHERE id = :id",
{"id": user_id},
UserAcls,
)
return user_acls or UserAcls(id=user_id)

View file

@ -684,3 +684,11 @@ async def m029_create_audit_table(db: Connection):
);
"""
)
async def m030_add_user_api_tokens_column(db: Connection):
await db.execute(
"""
ALTER TABLE accounts ADD COLUMN access_control_list TEXT
"""
)

View file

@ -36,6 +36,7 @@ from .users import (
UpdateUserPassword,
UpdateUserPubkey,
User,
UserAcls,
UserExtra,
)
from .wallets import BaseWallet, CreateWallet, KeyType, Wallet, WalletTypeInfo
@ -73,6 +74,7 @@ __all__ = [
"Account",
"AccountFilters",
"AccountOverview",
"UserAcls",
"CreateUser",
"RegisterUser",
"LoginUsernamePassword",

View file

@ -38,6 +38,11 @@ class SimpleStatus(BaseModel):
message: str
class SimpleItem(BaseModel):
id: str
name: str
class DbVersion(BaseModel):
db: str
version: int

View file

@ -8,6 +8,7 @@ from fastapi import Query
from passlib.context import CryptContext
from pydantic import BaseModel, Field
from lnbits.core.models.misc import SimpleItem
from lnbits.db import FilterModel
from lnbits.helpers import is_valid_email_address, is_valid_pubkey, is_valid_username
from lnbits.settings import settings
@ -28,6 +29,66 @@ class UserExtra(BaseModel):
provider: Optional[str] = "lnbits" # auth provider
class EndpointAccess(BaseModel):
path: str
name: str
read: bool = False
write: bool = False
def supports_method(self, method: str) -> bool:
# all http methods
if method in ["GET", "OPTIONS", "HEAD"]:
return self.read
if method in ["POST", "PUT", "PATCH", "DELETE"]:
return self.write
return False
class AccessControlList(BaseModel):
id: str
name: str
endpoints: list[EndpointAccess] = []
token_id_list: list[SimpleItem] = []
def get_endpoint(self, path: str) -> Optional[EndpointAccess]:
for e in self.endpoints:
if e.path == path:
return e
return None
def get_token_by_id(self, token_id: str) -> Optional[SimpleItem]:
for t in self.token_id_list:
if t.id == token_id:
return t
return None
def delete_token_by_id(self, token_id: str):
self.token_id_list = [t for t in self.token_id_list if t.id != token_id]
class UserAcls(BaseModel):
id: str
access_control_list: list[AccessControlList] = []
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
def get_acl_by_id(self, acl_id: str) -> Optional[AccessControlList]:
for acl in self.access_control_list:
if acl.id == acl_id:
return acl
return None
def delete_acl_by_id(self, acl_id: str):
self.access_control_list = [
acl for acl in self.access_control_list if acl.id != acl_id
]
def get_acl_by_token_id(self, token_id: str) -> Optional[AccessControlList]:
for acl in self.access_control_list:
if acl.get_token_by_id(token_id):
return acl
return None
class Account(BaseModel):
id: str
username: Optional[str] = None
@ -35,6 +96,7 @@ class Account(BaseModel):
pubkey: Optional[str] = None
email: Optional[str] = None
extra: UserExtra = UserExtra()
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@ -193,8 +255,36 @@ class AccessTokenPayload(BaseModel):
usr: Optional[str] = None
email: Optional[str] = None
auth_time: Optional[int] = 0
api_token_id: Optional[str] = None
class UpdateBalance(BaseModel):
id: str
amount: int
class ApiTokenRequest(BaseModel):
acl_id: str
token_name: str
password: str
expiration_time_minutes: int
class ApiTokenResponse(BaseModel):
id: str
api_token: str
class UpdateAccessControlList(AccessControlList):
password: str
class DeleteAccessControlList(BaseModel):
id: str
password: str
class DeleteTokenRequest(BaseModel):
id: str
acl_id: str
password: str

View file

@ -5,7 +5,7 @@
{% block scripts %} {{ window_vars(user) }}{% endblock %} {% block page %}
<div class="row q-col-gutter-md">
<div v-if="user" class="col-12 col-md-6 q-gutter-y-md">
<div v-if="user" class="col-md-12 col-lg-6 q-gutter-y-md">
<q-card>
<q-card-section>
<div class="row">
@ -21,6 +21,11 @@
:label="$t('look_and_feel')"
@update="val => tab = val.name"
></q-tab>
<q-tab
name="api_acls"
:label="$t('access_control_list')"
@update="val => tab = val.name"
></q-tab>
</q-tabs>
<q-tab-panels v-model="tab">
<q-tab-panel name="user">
@ -505,6 +510,230 @@
</div>
</div>
</q-tab-panel>
<q-tab-panel name="api_acls">
<div class="row q-mb-md">
<q-badge v-if="user.admin">
<span
v-text="$t('access_control_list_admin_warning')"
></span>
</q-badge>
</div>
<q-card-section>
<div class="row q-mb-md q-gutter-y-md">
<div class="col-sm-12 col-md-6">
<q-select
v-model="selectedApiAcl.id"
emit-value
map-options
@update:model-value="handleApiACLSelected"
:options="apiAcl.data.map(t => ({label: t.name, value: t.id}))"
:label="$t('access_control_list')"
dense
>
</q-select>
</div>
<div class="col-sm-12 col-md-6">
<q-btn
@click="askPasswordAndRunFunction('newApiAclDialog')"
filled
outline
icon="add"
:label="$t('access_control_list')"
color="grey"
class="float-right"
></q-btn>
</div>
</div>
<div v-if="selectedApiAcl.id">
<div class="row q-mb-md">
<div class="col-sm-12 col-md-6">
<q-select
:options="selectedApiAcl.token_id_list.map(t => ({label: t.name, value: t.id}))"
v-model="apiAcl.selectedTokenId"
emit-value
map-options
:label="$t('api_tokens')"
dense
>
</q-select>
</div>
<div class="col-sm-12 col-md-6 q-pl-sm">
<q-btn
v-if="apiAcl.selectedTokenId"
@click="askPasswordAndRunFunction('deleteToken')"
icon="delete"
filled
color="negative"
class="float-left"
></q-btn>
<q-btn
@click="askPasswordAndRunFunction('newTokenAclDialog')"
outline
icon="add"
:label="$t('api_token')"
filled
color="grey"
class="float-right"
></q-btn>
</div>
</div>
<div v-if="apiAcl.apiToken" class="row q-mb-md">
<div class="col-12">
<q-badge>
<span>Use this token in the HTTP</span>
<strong>
&nbsp;<code>Authorization</code>
&nbsp;
</strong>
<span> header.</span>
</q-badge>
</div>
<div class="col-12">
<table
class="full-width lnbits__table-bordered"
style="
border-collapse: collapse;
background-color: grey;
"
>
<thead>
<tr>
<th>
<span class="float-left">Header Name</span>
</th>
<th>
<span class="float-left">Header Value</span>
</th>
</tr>
</thead>
<tbody>
<tr>
<td>
<strong>Authorization</strong>
</td>
<td>
<div class="row q-pt-sm">
<div class="col-2 q-mt-sm">
<strong>Bearer &nbsp;</strong>
</div>
<div class="col-10">
<q-input
v-model="apiAcl.apiToken"
:label="$t('api_token_id')"
filled
dense
readonly
:type="selectedApiAcl.showId ? 'text': 'password'"
class="q-mb-md"
>
<q-btn
@click="selectedApiAcl.showId = !selectedApiAcl.showId"
dense
flat
:icon="selectedApiAcl.showId ? 'visibility_off' : 'visibility'"
color="black"
></q-btn>
<q-btn
@click="copyText(apiAcl.apiToken)"
icon="content_copy"
color="black"
flat
dense
></q-btn>
</q-input>
</div>
</div>
</td>
</tr>
</tbody>
</table>
</div>
<div class="col-12">
<q-badge>
<span
>Please store this token. It cannot be later
retrieved, only revoked.</span
>
</q-badge>
</div>
</div>
<q-table
row-key="path"
:rows="selectedApiAcl.endpoints"
:columns="apiAcl.columns"
v-model:pagination="apiAcl.pagination"
>
<template v-slot:header="props">
<q-tr :props="props">
<q-th
v-for="col in props.cols"
:key="col.name"
:props="props"
>
<q-toggle
v-if="col.name == 'read'"
v-model="selectedApiAcl.allRead"
@update:model-value="handleAllEndpointsReadAccess()"
:label="$t('read')"
size="sm"
></q-toggle>
<q-toggle
v-else-if="col.name == 'write'"
v-model="selectedApiAcl.allWrite"
@update:model-value="handleAllEndpointsWriteAccess()"
:label="$t('write')"
size="sm"
></q-toggle>
<span v-else v-text="col.label"></span>
</q-th>
</q-tr>
</template>
<template v-slot:body="props">
<q-tr :props="props">
<q-td>
<span v-text="props.row.name"></span>
</q-td>
<q-td>
<span v-text="props.row.path"></span>
</q-td>
<q-td>
<q-toggle size="sm" v-model="props.row.read" />
</q-td>
<q-td>
<q-toggle size="sm" v-model="props.row.write" />
</q-td>
</q-tr>
</template>
</q-table>
<q-separator></q-separator>
</div>
<div v-if="selectedApiAcl.id" class="row q-mt-md">
<div class="col-sm-12 col-md-6">
<q-btn
@click="askPasswordAndRunFunction('updateApiACLs')"
:label="$t('update')"
filled
color="primary"
></q-btn>
</div>
<div class="col-sm-12 col-md-6">
<q-btn
@click="askPasswordAndRunFunction('deleteApiACL')"
:label="$t('delete')"
icon="delete"
color="negative"
class="float-right"
>
</q-btn>
</div>
</div>
</q-card-section>
</q-tab-panel>
</q-tab-panels>
</div>
</div>
@ -519,4 +748,126 @@
</q-card>
</div>
</div>
<q-dialog v-model="apiAcl.showPasswordDialog" position="top">
<q-card class="q-pa-md q-pt-md lnbits__dialog-card">
<strong>User Password</strong>
<div class="row q-mt-md q-col-gutter-md">
<div class="col-12">
<q-input
v-model="apiAcl.password"
type="password"
dense
filled
label="Password"
hint="User password is required for this action."
>
</q-input>
</div>
</div>
<div class="row q-mt-lg">
<q-btn
@click="runPasswordGuardedFunction()"
:label="$t('ok')"
color="primary"
></q-btn>
<q-btn
v-close-popup
flat
color="grey"
class="q-ml-auto"
v-text="$t('cancel')"
></q-btn>
</div>
</q-card>
</q-dialog>
<q-dialog v-model="apiAcl.showNewAclDialog" position="top">
<q-card class="q-pa-md q-pt-md lnbits__dialog-card">
<strong>New API Access Control List</strong>
<div class="row q-mt-md q-col-gutter-md">
<div class="col-12">
<q-input v-model="apiAcl.newAclName" dense filled label="ACL Name">
</q-input>
</div>
</div>
<div class="row q-mt-lg">
<q-btn @click="addApiACL()" label="Create" color="primary"></q-btn>
<q-btn
v-close-popup
flat
color="grey"
class="q-ml-auto"
v-text="$t('close')"
></q-btn>
</div>
</q-card>
</q-dialog>
<q-dialog v-model="apiAcl.showNewTokenDialog" position="top">
<q-card class="q-pa-md q-pt-md lnbits__dialog-card">
<strong>New API Token</strong>
<div class="row q-col-gutter-md q-mt-md">
<div class="col-12">
<q-input v-model="apiAcl.newTokenName" dense filled label="Token Name">
</q-input>
</div>
<div class="col-12">
<q-input
v-model="apiAcl.newTokenExpiry"
dense
filled
label="Expiration"
hit="Expiration time in minutes (default xxx)"
>
<template v-slot:prepend>
<q-icon name="event" class="cursor-pointer">
<q-popup-proxy
cover
transition-show="scale"
transition-hide="scale"
>
<q-date v-model="apiAcl.newTokenExpiry" mask="YYYY-MM-DD HH:mm">
<div class="row items-center justify-end">
<q-btn v-close-popup label="Close" color="primary" flat />
</div>
</q-date>
</q-popup-proxy>
</q-icon>
</template>
<template v-slot:append>
<q-icon name="access_time" class="cursor-pointer">
<q-popup-proxy
cover
transition-show="scale"
transition-hide="scale"
>
<q-time
v-model="apiAcl.newTokenExpiry"
mask="YYYY-MM-DD HH:mm"
format24h
>
<div class="row items-center justify-end">
<q-btn v-close-popup label="Close" color="primary" flat />
</div>
</q-time>
</q-popup-proxy>
</q-icon>
</template>
</q-input>
</div>
</div>
<div class="row q-mt-lg">
<q-btn @click="generateApiToken()" label="Create" color="primary"></q-btn>
<q-btn
v-close-popup
flat
color="grey"
class="q-ml-auto"
v-text="$t('close')"
></q-btn>
</div>
</q-card>
</q-dialog>
{% endblock %}

View file

@ -15,7 +15,6 @@ from fastapi import (
from fastapi.exceptions import HTTPException
from fastapi.responses import StreamingResponse
from lnbits.core.crud import get_user
from lnbits.core.models import (
BaseWallet,
ConversionData,
@ -55,16 +54,12 @@ async def health() -> dict:
@api_router.get("/api/v1/status", status_code=HTTPStatus.OK)
async def health_check(wallet: WalletTypeInfo = Depends(require_invoice_key)) -> dict:
async def health_check(user: User = Depends(check_user_exists)) -> dict:
stat: dict[str, Any] = {
"server_time": int(time()),
"up_time": int(time() - settings.server_startup_time),
}
user = await get_user(wallet.wallet.user)
if not user:
return stat
stat["version"] = settings.version
if not user.admin:
return stat
@ -227,7 +222,7 @@ async def api_perform_lnurlauth(
@api_router.get(
"/api/v1/rate/history",
dependencies=[Depends(require_invoice_key)],
dependencies=[Depends(check_user_exists)],
)
async def api_exchange_rate_history() -> list[dict]:
return settings.lnbits_exchange_rate_history

View file

@ -11,12 +11,26 @@ from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_sso.sso.base import OpenID, SSOBase
from loguru import logger
from lnbits.core.crud.users import (
get_user_access_control_lists,
update_user_access_control_list,
)
from lnbits.core.models.misc import SimpleItem
from lnbits.core.models.users import (
ApiTokenRequest,
ApiTokenResponse,
DeleteAccessControlList,
DeleteTokenRequest,
EndpointAccess,
UpdateAccessControlList,
)
from lnbits.core.services import create_user_account
from lnbits.decorators import access_token_payload, check_user_exists
from lnbits.helpers import (
create_access_token,
decrypt_internal_message,
encrypt_internal_message,
get_api_routes,
is_valid_email_address,
is_valid_username,
)
@ -44,6 +58,7 @@ from ..models import (
UpdateUserPassword,
UpdateUserPubkey,
User,
UserAcls,
UserExtra,
)
@ -98,6 +113,124 @@ async def login_usr(data: LoginUsr) -> JSONResponse:
return _auth_success_response(account.username, account.id, account.email)
@auth_router.get("/acl")
async def api_get_user_acls(
request: Request,
user: User = Depends(check_user_exists),
) -> UserAcls:
api_routes = get_api_routes(request.app.router.routes)
acls = await get_user_access_control_lists(user.id)
# Add missing/new endpoints to the ACLs
for acl in acls.access_control_list:
acl_api_routes = {**api_routes}
for route in api_routes.keys():
if acl.get_endpoint(route):
acl_api_routes.pop(route, None)
for path, name in acl_api_routes.items():
acl.endpoints.append(EndpointAccess(path=path, name=name))
acl.endpoints.sort(key=lambda e: e.name.lower())
return UserAcls(id=user.id, access_control_list=acls.access_control_list)
@auth_router.put("/acl")
@auth_router.patch("/acl")
async def api_update_user_acl(
request: Request,
data: UpdateAccessControlList,
user: User = Depends(check_user_exists),
) -> UserAcls:
account = await get_account(user.id)
if not account or not account.verify_password(data.password):
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
user_acls = await get_user_access_control_lists(user.id)
acl = user_acls.get_acl_by_id(data.id)
if acl:
user_acls.access_control_list.remove(acl)
else:
data.endpoints = []
data.id = uuid4().hex
api_routes = get_api_routes(request.app.router.routes)
for path, name in api_routes.items():
data.endpoints.append(EndpointAccess(path=path, name=name))
api_paths = get_api_routes(request.app.router.routes).keys()
data.endpoints = [e for e in data.endpoints if e.path in api_paths]
data.endpoints.sort(key=lambda e: e.name.lower())
user_acls.access_control_list.append(data)
user_acls.access_control_list.sort(key=lambda t: t.name.lower())
await update_user_access_control_list(user_acls)
return user_acls
@auth_router.delete("/acl")
async def api_delete_user_acl(
data: DeleteAccessControlList,
user: User = Depends(check_user_exists),
):
account = await get_account(user.id)
if not account or not account.verify_password(data.password):
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
user_acls = await get_user_access_control_lists(user.id)
user_acls.delete_acl_by_id(data.id)
await update_user_access_control_list(user_acls)
@auth_router.post("/acl/token")
async def api_create_user_api_token(
data: ApiTokenRequest,
user: User = Depends(check_user_exists),
) -> ApiTokenResponse:
assert data.expiration_time_minutes > 0, "Expiration time must be in the future."
account = await get_account(user.id)
if not account or not account.verify_password(data.password):
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
assert account.username, "Username must be configured."
acls = await get_user_access_control_lists(user.id)
acl = acls.get_acl_by_id(data.acl_id)
if not acl:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid ACL id.")
api_token_id = uuid4().hex
api_token = _auth_api_token_response(
account.username, api_token_id, data.expiration_time_minutes
)
acl.token_id_list.append(SimpleItem(id=api_token_id, name=data.token_name))
await update_user_access_control_list(acls)
return ApiTokenResponse(id=api_token_id, api_token=api_token)
@auth_router.delete("/acl/token")
async def api_delete_user_api_token(
data: DeleteTokenRequest,
user: User = Depends(check_user_exists),
):
account = await get_account(user.id)
if not account or not account.verify_password(data.password):
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
assert account.username, "Username must be configured."
acls = await get_user_access_control_lists(user.id)
acl = acls.get_acl_by_id(data.acl_id)
if not acl:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid ACL id.")
acl.delete_token_by_id(data.id)
await update_user_access_control_list(acls)
@auth_router.get("/{provider}", description="SSO Provider")
async def login_with_sso_provider(
request: Request, provider: str, user_id: Optional[str] = None
@ -370,6 +503,17 @@ def _auth_success_response(
return response
def _auth_api_token_response(
username: str, api_token_id: str, token_expire_minutes: int
):
payload = AccessTokenPayload(
sub=username, api_token_id=api_token_id, auth_time=int(time())
)
return create_access_token(
data=payload.dict(), token_expire_minutes=token_expire_minutes
)
def _auth_redirect_response(path: str, email: str) -> RedirectResponse:
payload = AccessTokenPayload(sub="" or "", email=email, auth_time=int(time()))
access_token = create_access_token(data=payload.dict())

View file

@ -15,9 +15,9 @@ from lnbits.core.models import (
CreateWebPushSubscription,
WebPushSubscription,
)
from lnbits.core.models.users import User
from lnbits.decorators import (
WalletTypeInfo,
require_admin_key,
check_user_exists,
)
from ..crud import (
@ -33,20 +33,20 @@ webpush_router = APIRouter(prefix="/api/v1/webpush", tags=["Webpush"])
async def api_create_webpush_subscription(
request: Request,
data: CreateWebPushSubscription,
wallet: WalletTypeInfo = Depends(require_admin_key),
user: User = Depends(check_user_exists),
) -> WebPushSubscription:
try:
subscription = json.loads(data.subscription)
endpoint = subscription["endpoint"]
host = urlparse(str(request.url)).netloc
subscription = await get_webpush_subscription(endpoint, wallet.wallet.user)
subscription = await get_webpush_subscription(endpoint, user.id)
if subscription:
return subscription
else:
return await create_webpush_subscription(
endpoint,
wallet.wallet.user,
user.id,
data.subscription,
host,
)
@ -61,13 +61,13 @@ async def api_create_webpush_subscription(
@webpush_router.delete("", status_code=HTTPStatus.OK)
async def api_delete_webpush_subscription(
request: Request,
wallet: WalletTypeInfo = Depends(require_admin_key),
user: User = Depends(check_user_exists),
):
try:
endpoint = unquote(
base64.b64decode(str(request.query_params.get("endpoint"))).decode("utf-8")
)
count = await delete_webpush_subscription(endpoint, wallet.wallet.user)
count = await delete_webpush_subscription(endpoint, user.id)
return {"count": count}
except Exception as exc:
logger.debug(exc)

View file

@ -18,6 +18,7 @@ from lnbits.core.crud import (
get_user_from_account,
get_wallet_for_key,
)
from lnbits.core.crud.users import get_user_access_control_lists
from lnbits.core.models import (
AccessTokenPayload,
Account,
@ -140,7 +141,7 @@ async def check_user_exists(
usr: Optional[UUID4] = None,
) -> User:
if access_token:
account = await _get_account_from_token(access_token)
account = await _get_account_from_token(access_token, r["path"], r["method"])
elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only):
account = await get_account(usr.hex)
else:
@ -161,13 +162,14 @@ async def check_user_exists(
async def optional_user_id(
r: Request,
access_token: Annotated[Optional[str], Depends(check_access_token)],
usr: Optional[UUID4] = None,
) -> Optional[str]:
if usr and settings.is_auth_method_allowed(AuthMethods.user_id_only):
return usr.hex
if access_token:
account = await _get_account_from_token(access_token)
account = await _get_account_from_token(access_token, r["path"], r["method"])
return account.id if account else None
return None
@ -257,9 +259,8 @@ async def check_user_extension_access(
return SimpleStatus(success=True, message="OK")
async def _check_user_extension_access(user_id: str, current_path: str):
path = current_path.split("/")
ext_id = path[3] if path[1] == "upgrades" else path[1]
async def _check_user_extension_access(user_id: str, path: str):
ext_id = _path_segments(path)[0]
status = await check_user_extension_access(user_id, ext_id)
if not status.success:
raise HTTPException(
@ -268,17 +269,15 @@ async def _check_user_extension_access(user_id: str, current_path: str):
)
async def _get_account_from_token(access_token) -> Optional[Account]:
async def _get_account_from_token(
access_token: str, path: str, method: str
) -> Optional[Account]:
try:
payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"])
user = await _get_user_from_jwt_payload(payload)
if not user:
raise HTTPException(
HTTPStatus.UNAUTHORIZED, "Data missing for access token."
return await _get_account_from_jwt_payload(
AccessTokenPayload(**payload), path, method
)
return user
except jwt.ExpiredSignatureError as exc:
raise HTTPException(
HTTPStatus.UNAUTHORIZED, "Session expired.", {"token-expired": "true"}
@ -288,11 +287,48 @@ async def _get_account_from_token(access_token) -> Optional[Account]:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid access token.") from exc
async def _get_user_from_jwt_payload(payload) -> Optional[Account]:
if "sub" in payload and payload.get("sub"):
return await get_account_by_username(str(payload.get("sub")))
if "usr" in payload and payload.get("usr"):
return await get_account(str(payload.get("usr")))
if "email" in payload and payload.get("email"):
return await get_account_by_email(str(payload.get("email")))
async def _get_account_from_jwt_payload(
payload: AccessTokenPayload, path: str, method: str
) -> Optional[Account]:
account = None
if payload.sub is not None:
account = await get_account_by_username(payload.sub)
if payload.usr is not None:
account = await get_account(payload.usr)
if payload.email is not None:
account = await get_account_by_email(payload.email)
if not account:
return None
if payload.api_token_id:
await _check_account_api_access(account.id, payload.api_token_id, path, method)
return account
async def _check_account_api_access(
user_id: str, token_id: str, path: str, method: str
):
segments = path.split("/")
if len(segments) < 3:
raise HTTPException(HTTPStatus.FORBIDDEN, "Not an API endpoint.")
acls = await get_user_access_control_lists(user_id)
acl = acls.get_acl_by_token_id(token_id)
if not acl:
raise HTTPException(HTTPStatus.FORBIDDEN, "Invalid token id.")
path = "/" + "/".join(_path_segments(path)[:3])
endpoint = acl.get_endpoint(path)
if not endpoint:
raise HTTPException(HTTPStatus.FORBIDDEN, "Path not allowed.")
if not endpoint.supports_method(method):
raise HTTPException(HTTPStatus.FORBIDDEN, "Method not allowed.")
def _path_segments(path: str) -> list[str]:
segments = path.split("/")
if segments[1] == "upgrades":
return segments[3:]
return segments[1:]

View file

@ -9,6 +9,7 @@ from urllib import request
import jinja2
import jwt
import shortuuid
from fastapi.routing import APIRoute
from packaging import version
from pydantic.schema import field_schema
@ -198,12 +199,11 @@ def is_valid_pubkey(pubkey: str) -> bool:
return False
def create_access_token(data: dict):
expire = datetime.now(timezone.utc) + timedelta(
minutes=settings.auth_token_expire_minutes
)
to_encode = data.copy()
to_encode.update({"exp": expire})
def create_access_token(data: dict, token_expire_minutes: Optional[int] = None) -> str:
minutes = token_expire_minutes or settings.auth_token_expire_minutes
expire = datetime.now(timezone.utc) + timedelta(minutes=minutes)
to_encode = {k: v for k, v in data.items() if v is not None}
to_encode.update({"exp": expire}) # todo:check expiration
return jwt.encode(to_encode, settings.auth_secret_key, "HS256")
@ -270,3 +270,19 @@ def file_hash(filename):
while n := f.readinto(mv):
h.update(mv[:n])
return h.hexdigest()
def get_api_routes(routes: list) -> dict[str, str]:
data = {}
for route in routes:
if not isinstance(route, APIRoute):
continue
segments = route.path.split("/")
if len(segments) < 3:
continue
if "/".join(segments[1:3]) == "api/v1":
data["/".join(segments[0:4])] = segments[3].capitalize()
elif "/".join(segments[2:4]) == "api/v1":
data["/".join(segments[0:4])] = segments[1].capitalize()
return data

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -482,6 +482,12 @@ body.body--dark .q-field--error .q-field__messages {
width: 500px;
}
.lnbits__table-bordered td,
.lnbits__table-bordered th {
border: 1px solid black;
border-collapse: collapse;
}
.q-table--dense th:first-child,
.q-table--dense td:first-child,
.q-table--dense .q-table__bottom {

View file

@ -85,9 +85,11 @@ window.localisation.en = {
cancel: 'Cancel',
scan: 'Scan',
read: 'Read',
write: 'Write',
pay: 'Pay',
memo: 'Memo',
date: 'Date',
path: 'Path',
processing_payment: 'Processing payment...',
not_enough_funds: 'Not enough funds!',
search_by_tag_memo_amount: 'Search by tag, memo, amount',
@ -256,6 +258,13 @@ window.localisation.en = {
back: 'Back',
logout: 'Logout',
look_and_feel: 'Look and Feel',
api_token: 'API Token',
api_tokens: 'API Tokens',
access_control_list: 'Access Control List',
access_control_list_admin_warning:
'This is an admin account. The generated tokens will have admin privileges.',
new_api_acl: 'New Access Control List',
api_token_id: 'Token Id',
toggle_gradient: 'Toggle Gradient',
gradient_background: 'Gradient Background',
language: 'Language',

View file

@ -21,6 +21,60 @@ window.AccountPageLogic = {
newPasswordRepeat: null,
username: null,
pubkey: null
},
apiAcl: {
showNewAclDialog: false,
showPasswordDialog: false,
showNewTokenDialog: false,
data: [],
passwordGuardedFunction: null,
newAclName: '',
newTokenName: '',
password: '',
apiToken: null,
selectedTokenId: null,
columns: [
{
name: 'Name',
align: 'left',
label: this.$t('Name'),
field: 'Name',
sortable: false
},
{
name: 'path',
align: 'left',
label: this.$t('path'),
field: 'path',
sortable: false
},
{
name: 'read',
align: 'left',
label: this.$t('read'),
field: 'read',
sortable: false
},
{
name: 'write',
align: 'left',
label: this.$t('write'),
field: 'write',
sortable: false
}
],
pagination: {
rowsPerPage: 100,
page: 1
}
},
selectedApiAcl: {
id: null,
name: null,
endpoints: [],
token_id_list: [],
allRead: false,
allWrite: false
}
}
},
@ -151,6 +205,215 @@ window.AccountPageLogic = {
newPassword: null,
newPasswordRepeat: null
}
},
newApiAclDialog() {
this.apiAcl.newAclName = null
this.apiAcl.showNewAclDialog = true
},
newTokenAclDialog() {
this.apiAcl.newTokenName = null
this.apiAcl.newTokenExpiry = null
this.apiAcl.showNewTokenDialog = true
},
handleApiACLSelected(aclId) {
this.selectedApiAcl = {
id: null,
name: null,
endpoints: [],
token_id_list: []
}
this.apiAcl.selectedTokenId = null
if (!aclId) {
return
}
setTimeout(() => {
const selectedApiAcl = this.apiAcl.data.find(t => t.id === aclId)
if (!this.selectedApiAcl) {
return
}
this.selectedApiAcl = {...selectedApiAcl}
this.selectedApiAcl.allRead = this.selectedApiAcl.endpoints.every(
e => e.read
)
this.selectedApiAcl.allWrite = this.selectedApiAcl.endpoints.every(
e => e.write
)
})
},
handleAllEndpointsReadAccess() {
this.selectedApiAcl.endpoints.forEach(
e => (e.read = this.selectedApiAcl.allRead)
)
},
handleAllEndpointsWriteAccess() {
this.selectedApiAcl.endpoints.forEach(
e => (e.write = this.selectedApiAcl.allWrite)
)
},
async getApiACLs() {
try {
const {data} = await LNbits.api.request('GET', '/api/v1/auth/acl', null)
this.apiAcl.data = data.access_control_list
} catch (e) {
LNbits.utils.notifyApiError(e)
}
},
askPasswordAndRunFunction(func) {
this.apiAcl.passwordGuardedFunction = func
this.apiAcl.showPasswordDialog = true
},
runPasswordGuardedFunction() {
this.apiAcl.showPasswordDialog = false
const func = this.apiAcl.passwordGuardedFunction
if (func) {
this[func]()
}
},
async addApiACL() {
if (!this.apiAcl.newAclName) {
this.$q.notify({
type: 'warning',
message: 'Name is required.'
})
return
}
try {
const {data} = await LNbits.api.request(
'PUT',
'/api/v1/auth/acl',
null,
{
id: this.apiAcl.newAclName,
name: this.apiAcl.newAclName,
password: this.apiAcl.password
}
)
this.apiAcl.data = data.access_control_list
const acl = this.apiAcl.data.find(
t => t.name === this.apiAcl.newAclName
)
this.handleApiACLSelected(acl.id)
this.apiAcl.showNewAclDialog = false
this.$q.notify({
type: 'positive',
message: 'Access Control List created.'
})
} catch (e) {
LNbits.utils.notifyApiError(e)
} finally {
this.apiAcl.name = ''
this.apiAcl.password = ''
}
this.apiAcl.showNewAclDialog = false
},
async updateApiACLs() {
try {
const {data} = await LNbits.api.request(
'PUT',
'/api/v1/auth/acl',
null,
{
id: this.user.id,
password: this.apiAcl.password,
...this.selectedApiAcl
}
)
this.apiAcl.data = data.access_control_list
} catch (e) {
LNbits.utils.notifyApiError(e)
} finally {
this.apiAcl.password = ''
}
},
async deleteApiACL() {
if (!this.selectedApiAcl.id) {
return
}
try {
await LNbits.api.request('DELETE', '/api/v1/auth/acl', null, {
id: this.selectedApiAcl.id,
password: this.apiAcl.password
})
this.$q.notify({
type: 'positive',
message: 'Access Control List deleted.'
})
} catch (e) {
LNbits.utils.notifyApiError(e)
} finally {
this.apiAcl.password = ''
}
this.apiAcl.data = this.apiAcl.data.filter(
t => t.id !== this.selectedApiAcl.id
)
this.handleApiACLSelected(this.apiAcl.data[0]?.id)
},
async generateApiToken() {
if (!this.selectedApiAcl.id) {
return
}
const expirationTimeMilliseconds =
new Date(this.apiAcl.newTokenExpiry) - new Date()
try {
const {data} = await LNbits.api.request(
'POST',
'/api/v1/auth/acl/token',
null,
{
acl_id: this.selectedApiAcl.id,
token_name: this.apiAcl.newTokenName,
password: this.apiAcl.password,
expiration_time_minutes: Math.trunc(
expirationTimeMilliseconds / 60000
)
}
)
this.apiAcl.apiToken = data.api_token
this.apiAcl.selectedTokenId = data.id
Quasar.Notify.create({
type: 'positive',
message: 'Token Generated.'
})
await this.getApiACLs()
this.handleApiACLSelected(this.selectedApiAcl.id)
this.apiAcl.showNewTokenDialog = false
} catch (e) {
LNbits.utils.notifyApiError(e)
} finally {
this.apiAcl.password = ''
}
},
async deleteToken() {
if (!this.apiAcl.selectedTokenId) {
return
}
try {
await LNbits.api.request('DELETE', '/api/v1/auth/acl/token', null, {
id: this.apiAcl.selectedTokenId,
acl_id: this.selectedApiAcl.id,
password: this.apiAcl.password
})
this.$q.notify({
type: 'positive',
message: 'Token deleted.'
})
this.selectedApiAcl.token_id_list =
this.selectedApiAcl.token_id_list.filter(
t => t.id !== this.apiAcl.selectedTokenId
)
this.apiAcl.selectedTokenId = null
} catch (e) {
LNbits.utils.notifyApiError(e)
} finally {
this.apiAcl.password = ''
}
}
},
async created() {
@ -166,5 +429,6 @@ window.AccountPageLogic = {
if (hash) {
this.tab = hash
}
await this.getApiACLs()
}
}

View file

@ -308,14 +308,9 @@ window.app.component('lnbits-notifications-btn', {
.subscribe(options)
.then(subscription => {
LNbits.api
.request(
'POST',
'/api/v1/webpush',
this.g.user.wallets[0].adminkey,
{
.request('POST', '/api/v1/webpush', null, {
subscription: JSON.stringify(subscription)
}
)
})
.then(response => {
this.saveUserSubscribed(response.data.user)
this.isSubscribed = true
@ -337,7 +332,7 @@ window.app.component('lnbits-notifications-btn', {
.request(
'DELETE',
'/api/v1/webpush?endpoint=' + btoa(subscription.endpoint),
this.g.user.wallets[0].adminkey
null
)
.then(() => {
this.removeUserSubscribed(this.g.user.id)

View file

@ -152,6 +152,12 @@ body.body--dark .q-field--error {
width: 500px;
}
.lnbits__table-bordered td,
.lnbits__table-bordered th {
border: 1px solid black;
border-collapse: collapse;
}
.q-table--dense {
th:first-child,
td:first-child,

View file

@ -2,6 +2,7 @@ import base64
import json
import os
import time
from uuid import uuid4
import jwt
import pytest
@ -9,8 +10,22 @@ import secp256k1
import shortuuid
from httpx import AsyncClient
from lnbits.core.crud.users import (
get_user_access_control_lists,
update_user_access_control_list,
)
from lnbits.core.models import AccessTokenPayload, User
from lnbits.core.models.misc import SimpleItem
from lnbits.core.models.users import (
AccessControlList,
ApiTokenRequest,
DeleteTokenRequest,
EndpointAccess,
UpdateAccessControlList,
UserAcls,
)
from lnbits.core.views.user_api import api_users_reset_password
from lnbits.helpers import create_access_token
from lnbits.settings import AuthMethods, Settings
from lnbits.utils.nostr import hex_to_npub, sign_event
@ -1030,3 +1045,907 @@ async def test_reset_password_auth_threshold_expired(
" in the first 1 seconds."
" Please login again or ask a new reset key!"
)
################################ ACL ################################
@pytest.mark.anyio
async def test_api_update_user_acl_success(http_client: AsyncClient, user_alan: User):
# Login to get access token
response = await http_client.post(
"/api/v1/auth", json={"username": user_alan.username, "password": "secret1234"}
)
assert response.status_code == 200, "Alan logs in OK"
access_token = response.json().get("access_token")
assert access_token is not None
# Create a new ACL
data = UpdateAccessControlList(
id="", name="New ACL", password="secret1234", endpoints=[]
)
response = await http_client.put(
"/api/v1/auth/acl",
json=data.dict(),
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == 200, "ACL should be created successfully."
user_acls = UserAcls(**response.json())
assert any(
acl.name == "New ACL" for acl in user_acls.access_control_list
), "ACL should be in the list."
@pytest.mark.anyio
async def test_api_update_user_acl_invalid_password(
http_client: AsyncClient, user_alan: User
):
# Login to get access token
response = await http_client.post(
"/api/v1/auth", json={"username": user_alan.username, "password": "secret1234"}
)
assert response.status_code == 200, "Alan logs in OK"
access_token = response.json().get("access_token")
assert access_token is not None
# Attempt to create a new ACL with an invalid password
data = UpdateAccessControlList(
id="", name="New ACL", password="wrong_password", endpoints=[]
)
response = await http_client.put(
"/api/v1/auth/acl",
json=data.dict(),
headers={"Authorization": f"Bearer {access_token}"},
)
assert (
response.status_code == 401
), "Invalid password should result in unauthorized error."
assert response.json().get("detail") == "Invalid credentials."
@pytest.mark.anyio
async def test_api_update_user_acl_update_existing(
http_client: AsyncClient, user_alan: User
):
# Login to get access token
response = await http_client.post(
"/api/v1/auth", json={"username": user_alan.username, "password": "secret1234"}
)
assert response.status_code == 200, "Alan logs in OK"
access_token = response.json().get("access_token")
assert access_token is not None
# Create a new ACL
data = UpdateAccessControlList(
id="", name="New ACL", password="secret1234", endpoints=[]
)
response = await http_client.put(
"/api/v1/auth/acl",
json=data.dict(),
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == 200, "ACL should be created successfully."
user_acls = UserAcls(**response.json())
acl = next(acl for acl in user_acls.access_control_list if acl.name == "New ACL")
# Update the existing ACL
data = UpdateAccessControlList(
id=acl.id, name="Updated ACL", password="secret1234", endpoints=[]
)
response = await http_client.put(
"/api/v1/auth/acl",
json=data.dict(),
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == 200, "ACL should be updated successfully."
user_acls = UserAcls(**response.json())
assert any(
acl.name == "Updated ACL" for acl in user_acls.access_control_list
), "ACL should be updated in the list."
@pytest.mark.anyio
async def test_api_update_user_acl_missing_password(
http_client: AsyncClient, user_alan: User
):
# Login to get access token
response = await http_client.post(
"/api/v1/auth", json={"username": user_alan.username, "password": "secret1234"}
)
assert response.status_code == 200, "Alan logs in OK"
access_token = response.json().get("access_token")
assert access_token is not None
# Attempt to create a new ACL with a missing password
data = UpdateAccessControlList(id="", name="New ACL", password="", endpoints=[])
response = await http_client.put(
"/api/v1/auth/acl",
json=data.dict(),
headers={"Authorization": f"Bearer {access_token}"},
)
assert (
response.status_code == 401
), "Missing password should result in unauthorized error."
assert response.json().get("detail") == "Invalid credentials."
@pytest.mark.anyio
async def test_api_get_user_acls_success(http_client: AsyncClient):
# Register a new user to obtain the access token
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Get user ACLs
response = await http_client.get(
"/api/v1/auth/acl", headers={"Authorization": f"Bearer {access_token}"}
)
assert response.status_code == 200, "ACLs fetched successfully."
user_acls = UserAcls(**response.json())
assert user_acls.id is not None, "User ID should be set."
assert isinstance(user_acls.access_control_list, list), "ACL should be a list."
@pytest.mark.anyio
async def test_api_get_user_acls_no_auth(http_client: AsyncClient):
# Attempt to get user ACLs without authentication
response = await http_client.get("/api/v1/auth/acl")
assert response.status_code == 401, "Unauthorized access."
@pytest.mark.anyio
async def test_api_get_user_acls_invalid_token(http_client: AsyncClient):
# Attempt to get user ACLs with an invalid token
response = await http_client.get(
"/api/v1/auth/acl", headers={"Authorization": "Bearer invalid_token"}
)
assert response.status_code == 401, "Unauthorized access."
@pytest.mark.anyio
async def test_api_get_user_acls_empty_acl(http_client: AsyncClient):
# Register a new user to obtain the access token
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Get user ACLs
response = await http_client.get(
"/api/v1/auth/acl", headers={"Authorization": f"Bearer {access_token}"}
)
assert response.status_code == 200, "ACLs fetched successfully."
user_acls = UserAcls(**response.json())
assert user_acls.id is not None, "User ID should be set."
assert len(user_acls.access_control_list) == 0, "ACL should be empty."
@pytest.mark.anyio
async def test_api_get_user_acls_with_acl(http_client: AsyncClient):
# Register a new user to obtain the access token
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Create a new ACL for the user
acl_data = UpdateAccessControlList(
id="",
name="Test ACL",
endpoints=[],
password="secret1234",
)
response = await http_client.put(
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
json=acl_data.dict(),
)
assert response.status_code == 200, "ACL created successfully."
# Get user ACLs
response = await http_client.get(
"/api/v1/auth/acl", headers={"Authorization": f"Bearer {access_token}"}
)
assert response.status_code == 200, "ACLs fetched successfully."
user_acls = UserAcls(**response.json())
assert user_acls.id is not None, "User ID should be set."
assert len(user_acls.access_control_list) == 1, "ACL should contain one item."
assert user_acls.access_control_list[0].name == "Test ACL", "ACL name should match."
@pytest.mark.anyio
async def test_api_get_user_acls_sorted(http_client: AsyncClient):
# Register a new user to obtain the access token
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Create some ACLs for the user
acl_names = ["zeta", "alpha", "gamma"]
for name in acl_names:
response = await http_client.put(
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
json={"id": name, "name": name, "password": "secret1234"},
)
assert (
response.status_code == 200
), f"ACL '{name}' should be created successfully."
# Get the user's ACLs
response = await http_client.get(
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == 200, "ACLs retrieved."
user_acls = UserAcls(**response.json())
# Check that the ACLs are sorted alphabetically by name
acl_names_sorted = sorted(acl_names)
retrieved_acl_names = [acl.name for acl in user_acls.access_control_list]
assert (
retrieved_acl_names == acl_names_sorted
), "ACLs are not sorted alphabetically by name."
@pytest.mark.anyio
async def test_api_delete_user_acl_success(http_client: AsyncClient):
# Register a new user to obtain the access token
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Create an ACL for the user
response = await http_client.put(
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
json={
"id": "Test ACL",
"name": "Test ACL",
"password": "secret1234",
},
)
assert response.status_code == 200, "ACL created."
acl_id = response.json()["access_control_list"][0]["id"]
# Delete the ACL
response = await http_client.request(
"DELETE",
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
json={
"id": acl_id,
"password": "secret1234",
},
)
assert response.status_code == 200, "ACL deleted."
@pytest.mark.anyio
async def test_api_delete_user_acl_invalid_password(http_client: AsyncClient):
# Register a new user to obtain the access token
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Create an ACL for the user
response = await http_client.put(
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
json={
"id": "Test ACL",
"name": "Test ACL",
"password": "secret1234",
},
)
assert response.status_code == 200, "ACL created."
acl_id = response.json()["access_control_list"][0]["id"]
# Attempt to delete the ACL with an invalid password
response = await http_client.request(
"DELETE",
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
json={
"id": acl_id,
"password": "wrongpassword",
},
)
assert response.status_code == 401, "Invalid credentials."
@pytest.mark.anyio
async def test_api_delete_user_acl_nonexistent_acl(http_client: AsyncClient):
# Register a new user to obtain the access token
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Attempt to delete a nonexistent ACL
response = await http_client.request(
"DELETE",
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
json={
"id": "nonexistent_acl_id",
"password": "secret1234",
},
)
assert response.status_code == 200, "ACL deleted."
@pytest.mark.anyio
async def test_api_delete_user_acl_missing_password(http_client: AsyncClient):
# Register a new user to obtain the access token
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Create an ACL for the user
response = await http_client.put(
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
json={
"id": "Test ACL",
"name": "Test ACL",
"password": "secret1234",
},
)
assert response.status_code == 200, "ACL created."
acl_id = response.json()["access_control_list"][0]["id"]
# Attempt to delete the ACL without providing a password
response = await http_client.request(
"DELETE",
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
json={
"id": acl_id,
},
)
assert response.status_code == 400, "Missing password."
################################ TOKEN ################################
@pytest.mark.anyio
async def test_api_create_user_api_token_success(
http_client: AsyncClient, settings: Settings
):
# Register a new user
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Create a new ACL
acl_data = UpdateAccessControlList(
id="", password="secret1234", name="Test ACL", endpoints=[]
)
response = await http_client.put(
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
json=acl_data.dict(),
)
assert response.status_code == 200, "ACL created."
acl_id = response.json()["access_control_list"][0]["id"]
# Create API token
token_request = ApiTokenRequest(
acl_id=acl_id,
token_name="Test Token",
expiration_time_minutes=60,
password="secret1234",
)
response = await http_client.post(
"/api/v1/auth/acl/token",
headers={"Authorization": f"Bearer {access_token}"},
json=token_request.dict(),
)
assert response.status_code == 200, "API token created."
api_token = response.json().get("api_token")
assert api_token is not None
# Verify the token exists
response = await http_client.get(
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
)
assert response.status_code == 200, "ACLs fetched successfully."
acls = UserAcls(**response.json())
# Decode the access token to get the user ID
payload: dict = jwt.decode(api_token, settings.auth_secret_key, ["HS256"])
# Check the expiration time
expiration_time = payload.get("exp")
assert expiration_time is not None, "Expiration time should be set."
assert (
0 <= 3600 - (expiration_time - time.time()) <= 5
), "Expiration time should be 60 minutes from now."
token_id = payload["api_token_id"]
assert any(
token_id in [token.id for token in acl.token_id_list]
for acl in acls.access_control_list
), "API token should be part of at least one ACL."
@pytest.mark.anyio
async def test_acl_api_token_access(user_alan: User, http_client: AsyncClient):
user_acls = await get_user_access_control_lists(user_alan.id)
acl = AccessControlList(id=uuid4().hex, name="Test ACL", endpoints=[])
user_acls.access_control_list = [acl]
api_token_id = uuid4().hex
payload = AccessTokenPayload(
sub=user_alan.username or user_alan.id,
api_token_id=api_token_id,
auth_time=int(time.time()),
)
api_token = create_access_token(data=payload.dict(), token_expire_minutes=10)
acl.token_id_list.append(SimpleItem(id=api_token_id, name="Test Token"))
await update_user_access_control_list(user_acls)
headers = {"Authorization": f"Bearer {api_token}"}
response = await http_client.get("/api/v1/auth/acl", headers=headers)
assert response.status_code == 403, "Path not allowed."
assert response.json()["detail"] == "Path not allowed."
# Grant read access
endpoint = EndpointAccess(path="/api/v1/auth", name="Get User ACLs", read=True)
acl.endpoints.append(endpoint)
await update_user_access_control_list(user_acls)
response = await http_client.get("/api/v1/auth/acl", headers=headers)
assert response.status_code == 200, "Access granted."
response = await http_client.put("/api/v1/auth/acl", headers=headers)
assert response.status_code == 403, "Method not allowed."
response = await http_client.post(
"/api/v1/auth/acl/token", headers=headers, json={}
)
assert response.status_code == 403, "Method not allowed."
response = await http_client.patch("/api/v1/auth/acl", headers=headers)
assert response.status_code == 403, "Method not allowed."
response = await http_client.delete("/api/v1/auth/acl", headers=headers)
assert response.status_code == 403, "Method not allowed."
# Grant write access
endpoint.write = True
await update_user_access_control_list(user_acls)
response = await http_client.get("/api/v1/auth/acl", headers=headers)
assert response.status_code == 200, "Access granted."
response = await http_client.put("/api/v1/auth/acl", headers=headers)
assert response.status_code == 400, "Access granted, validation error expected."
response = await http_client.post(
"/api/v1/auth/acl/token", headers=headers, json={}
)
assert response.status_code == 400, "Access granted, validation error expected."
response = await http_client.patch("/api/v1/auth/acl", headers=headers)
assert response.status_code == 400, "Access granted, validation error expected."
response = await http_client.delete("/api/v1/auth/acl", headers=headers)
assert response.status_code == 400, "Access granted, validation error expected."
# Revoke read access
endpoint.read = False
await update_user_access_control_list(user_acls)
response = await http_client.get("/api/v1/auth/acl", headers=headers)
assert response.status_code == 403, "Method not allowed."
response = await http_client.put("/api/v1/auth/acl", headers=headers)
assert (
response.status_code == 400
), "Access still granted, validation error expected."
@pytest.mark.anyio
async def test_api_create_user_api_token_invalid_password(http_client: AsyncClient):
# Register a new user
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Create a new ACL
acl_data = UpdateAccessControlList(
password="secret1234", id="", name="Test ACL", endpoints=[]
)
response = await http_client.put(
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
json=acl_data.dict(),
)
assert response.status_code == 200, "ACL created."
acl_id = response.json()["access_control_list"][0]["id"]
# Create API token with invalid password
token_request = ApiTokenRequest(
acl_id=acl_id,
token_name="Test Token",
expiration_time_minutes=60,
password="wrongpassword",
)
response = await http_client.post(
"/api/v1/auth/acl/token",
headers={"Authorization": f"Bearer {access_token}"},
json=token_request.dict(),
)
assert response.status_code == 401, "Invalid credentials."
@pytest.mark.anyio
async def test_api_create_user_api_token_invalid_acl_id(http_client: AsyncClient):
# Register a new user
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Create API token with invalid ACL ID
token_request = ApiTokenRequest(
acl_id="invalid_acl_id",
token_name="Test Token",
expiration_time_minutes=60,
password="secret1234",
)
response = await http_client.post(
"/api/v1/auth/acl/token",
headers={"Authorization": f"Bearer {access_token}"},
json=token_request.dict(),
)
assert response.status_code == 401, "Invalid ACL id."
@pytest.mark.anyio
async def test_api_create_user_api_token_expiration_time_invalid(
http_client: AsyncClient,
):
# Register a new user
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Create a new ACL
acl_data = UpdateAccessControlList(
id="", password="secret1234", name="Test ACL", endpoints=[]
)
response = await http_client.put(
"/api/v1/auth/acl",
headers={"Authorization": f"Bearer {access_token}"},
json=acl_data.dict(),
)
assert response.status_code == 200, "ACL created."
acl_id = response.json()["access_control_list"][0]["id"]
# Create API token with invalid expiration time
token_request = ApiTokenRequest(
acl_id=acl_id,
token_name="Test Token",
expiration_time_minutes=-1,
password="secret1234",
)
response = await http_client.post(
"/api/v1/auth/acl/token",
headers={"Authorization": f"Bearer {access_token}"},
json=token_request.dict(),
)
assert response.status_code == 400, "Expiration time must be in the future."
@pytest.mark.anyio
async def test_api_delete_user_api_token_success(
http_client: AsyncClient, settings: Settings
):
# Register a new user
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Decode the access token to get the user ID
payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"])
user_id = payload["usr"]
# Create a new ACL
acl_data = UpdateAccessControlList(
id="", name="Test ACL", endpoints=[], password="secret1234"
)
user_acls = await get_user_access_control_lists(user_id)
user_acls.access_control_list.append(acl_data)
await update_user_access_control_list(user_acls)
# Create a new API token
api_token_request = ApiTokenRequest(
acl_id=acl_data.id,
token_name="Test Token",
expiration_time_minutes=60,
password="secret1234",
)
response = await http_client.post(
"/api/v1/auth/acl/token",
headers={"Authorization": f"Bearer {access_token}"},
json=api_token_request.dict(),
)
assert response.status_code == 200, "API token created."
api_token_id = response.json().get("id")
assert api_token_id is not None
# Delete the API token
delete_token_request = DeleteTokenRequest(
acl_id=acl_data.id, id=api_token_id, password="secret1234"
)
response = await http_client.request(
"DELETE",
"/api/v1/auth/acl/token",
headers={"Authorization": f"Bearer {access_token}"},
json=delete_token_request.dict(),
)
assert response.status_code == 200, "API token deleted."
@pytest.mark.anyio
async def test_api_delete_user_api_token_invalid_password(
http_client: AsyncClient, settings: Settings
):
# Register a new user
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Decode the access token to get the user ID
payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"])
user_id = payload["usr"]
# Create a new ACL
acl_data = UpdateAccessControlList(
id="", name="Test ACL", endpoints=[], password="secret1234"
)
user_acls = await get_user_access_control_lists(user_id)
user_acls.access_control_list.append(acl_data)
await update_user_access_control_list(user_acls)
# Create a new API token
api_token_request = ApiTokenRequest(
acl_id=acl_data.id,
token_name="Test Token",
expiration_time_minutes=60,
password="secret1234",
)
response = await http_client.post(
"/api/v1/auth/acl/token",
headers={"Authorization": f"Bearer {access_token}"},
json=api_token_request.dict(),
)
assert response.status_code == 200, "API token created."
api_token_id = response.json().get("id")
assert api_token_id is not None
# Attempt to delete the API token with an invalid password
delete_token_request = DeleteTokenRequest(
acl_id=acl_data.id, id=api_token_id, password="wrong_password"
)
response = await http_client.request(
"DELETE",
"/api/v1/auth/acl/token",
headers={"Authorization": f"Bearer {access_token}"},
json=delete_token_request.dict(),
)
assert response.status_code == 401, "Invalid credentials."
@pytest.mark.anyio
async def test_api_delete_user_api_token_invalid_acl_id(
http_client: AsyncClient,
):
# Register a new user
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Attempt to delete an API token with an invalid ACL ID
delete_token_request = DeleteTokenRequest(
acl_id="invalid_acl_id", id="invalid_token_id", password="secret1234"
)
response = await http_client.request(
"DELETE",
"/api/v1/auth/acl/token",
headers={"Authorization": f"Bearer {access_token}"},
json=delete_token_request.dict(),
)
assert response.status_code == 401, "Invalid ACL id."
@pytest.mark.anyio
async def test_api_delete_user_api_token_missing_token_id(
http_client: AsyncClient, settings: Settings
):
# Register a new user
tiny_id = shortuuid.uuid()[:8]
response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert response.status_code == 200, "User created."
access_token = response.json().get("access_token")
assert access_token is not None
# Decode the access token to get the user ID
payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"])
user_id = payload["usr"]
# Create a new ACL
acl_data = UpdateAccessControlList(
id="", name="Test ACL", endpoints=[], password="secret1234"
)
user_acls = await get_user_access_control_lists(user_id)
user_acls.access_control_list.append(acl_data)
await update_user_access_control_list(user_acls)
# Attempt to delete an API token with a missing token ID
delete_token_request = DeleteTokenRequest(
acl_id=acl_data.id, id="", password="secret1234"
)
response = await http_client.request(
"DELETE",
"/api/v1/auth/acl/token",
headers={"Authorization": f"Bearer {access_token}"},
json=delete_token_request.dict(),
)
assert response.status_code == 200, "Does noting if token not found."

View file

@ -4,20 +4,20 @@ import pytest
@pytest.mark.anyio
async def test_create___bad_body(client, adminkey_headers_from):
async def test_create___bad_body(client, user_headers_from):
response = await client.post(
"/api/v1/webpush",
headers=adminkey_headers_from,
headers=user_headers_from,
json={"subscription": "bad_json"},
)
assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.anyio
async def test_create___missing_fields(client, adminkey_headers_from):
async def test_create___missing_fields(client, user_headers_from):
response = await client.post(
"/api/v1/webpush",
headers=adminkey_headers_from,
headers=user_headers_from,
json={"subscription": """{"a": "x"}"""},
)
assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
@ -34,30 +34,30 @@ async def test_create___bad_access_key(client, inkey_headers_from):
@pytest.mark.anyio
async def test_delete__bad_endpoint_format(client, adminkey_headers_from):
async def test_delete__bad_endpoint_format(client, user_headers_from):
response = await client.delete(
"/api/v1/webpush",
params={"endpoint": "https://this.should.be.base64.com"},
headers=adminkey_headers_from,
headers=user_headers_from,
)
assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.anyio
async def test_delete__no_endpoint_param(client, adminkey_headers_from):
async def test_delete__no_endpoint_param(client, user_headers_from):
response = await client.delete(
"/api/v1/webpush",
headers=adminkey_headers_from,
headers=user_headers_from,
)
assert response.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
@pytest.mark.anyio
async def test_delete__no_endpoint_found(client, adminkey_headers_from):
async def test_delete__no_endpoint_found(client, user_headers_from):
response = await client.delete(
"/api/v1/webpush",
params={"endpoint": "aHR0cHM6Ly9kZW1vLmxuYml0cy5jb20="},
headers=adminkey_headers_from,
headers=user_headers_from,
)
assert response.status_code == HTTPStatus.OK
assert response.json()["count"] == 0
@ -73,17 +73,17 @@ async def test_delete__bad_access_key(client, inkey_headers_from):
@pytest.mark.anyio
async def test_create_and_delete(client, adminkey_headers_from):
async def test_create_and_delete(client, user_headers_from):
response = await client.post(
"/api/v1/webpush",
headers=adminkey_headers_from,
headers=user_headers_from,
json={"subscription": """{"endpoint": "https://demo.lnbits.com"}"""},
)
assert response.status_code == HTTPStatus.CREATED
response = await client.delete(
"/api/v1/webpush",
params={"endpoint": "aHR0cHM6Ly9kZW1vLmxuYml0cy5jb20="},
headers=adminkey_headers_from,
headers=user_headers_from,
)
assert response.status_code == HTTPStatus.OK
assert response.json()["count"] == 1

View file

@ -219,6 +219,18 @@ async def adminkey_headers_from(from_wallet):
}
@pytest.fixture(scope="session")
async def user_headers_from(client: AsyncClient, from_user: User):
response = await client.post("/api/v1/auth/usr", json={"usr": from_user.id})
client.cookies.clear()
access_token = response.json().get("access_token")
yield {
"Authorization": "Bearer " + access_token,
"Content-type": "application/json",
}
@pytest.fixture(scope="session")
async def inkey_headers_to(to_wallet):
wallet = to_wallet

48
tests/copilot_prompt.md Normal file
View file

@ -0,0 +1,48 @@
# GitHub Copilot Prompts
Make sure to:
- select the code that you want to test. The prompt specifies the name of the file and the function to be tested (this redundancy is needed)
- open tabs with relevant files for the tests, for example: `conftest.py`, `test_auth.py`. This helps Copilot with context.
## Examples
### Create Comprehensive suite of unit tests
_Sample 1_
@workspace /tests Develop a comprehensive suite of unit tests for the selected code (only the function (only the function api_create_user_api_token in auth_api.py file) in auth_api.py file).
Requirements:
- use register endpoint to obtain the access token (see example in test_register_ok)
- write multiple test functions that cover a wide range of scenarios, including the succes flow, edge cases, exception handling, and data validation
- for the success case create a new ACL before creating the token
_Sample 2_
@workspace /tests Develop a comprehensive suite of unit tests for the selected code (only the function check_user_exists in decorators.py file) .
Requirements:
- write multiple test functions that cover a wide range of scenarios, including the succes flow, edge cases, security vulnerabilities, exception handling, and data validation
- use the login endpoint to obtain a valid access token. Use the `user_alan: User` fixture for the login params. Check the `test_login_alan_username_password_ok` function in the `test_auth.py` file as an example for login.
- do not use mocks. For the request parameter initialize the fastapi.Request class.
- make sure to cover all if-then-else branches
### Create tests for a particular usecase
_Sample 1_
@workspace /tests Develop a test for the selected code (only the function api_get_user_acls in auth_api.py file).
Requirements:
- use register endpoint to obtain the access token (see example in test_register_ok)
- the test should only check that the ACLs are sorted alphabeticaly by name
_Sample 1_
@workspace /tests Develop a test for the selected code (only the function check_user_exists in decorators.py file).
Requirements:
- use register endpoint to obtain the access token (see example in the file test_auth.py the function test_register_ok())
- the test should register a new user, obtain the access token then delete the user. Then check that check_user_exists() fails as expected
@workspace /tests Develop a test for the selected code (only the function check_user_exists in decorators.py file).
Requirements:
- check only the branch where user_id_only login is allowed

View file

@ -0,0 +1,138 @@
from uuid import uuid4
import jwt
import pytest
import shortuuid
from fastapi import Request
from fastapi.exceptions import HTTPException
from httpx import AsyncClient
from pydantic.types import UUID4
from lnbits.core.crud.users import delete_account
from lnbits.core.models import User
from lnbits.core.models.users import AccessTokenPayload
from lnbits.decorators import check_user_exists
from lnbits.settings import AuthMethods, Settings, settings
@pytest.mark.anyio
async def test_check_user_exists_with_valid_access_token(
http_client: AsyncClient, user_alan: User
):
# Login to get a valid access token
response = await http_client.post(
"/api/v1/auth", json={"username": user_alan.username, "password": "secret1234"}
)
assert response.status_code == 200, "Alan logs in OK"
access_token = response.json()["access_token"]
assert access_token is not None
request = Request({"type": "http", "path": "/some/path", "method": "GET"})
user = await check_user_exists(request, access_token=access_token)
assert user.id == user_alan.id
assert request.scope["user_id"] == user.id
@pytest.mark.anyio
async def test_check_user_exists_with_invalid_access_token():
request = Request({"type": "http", "path": "/some/path", "method": "GET"})
with pytest.raises(HTTPException) as exc_info:
await check_user_exists(request, access_token="invalid_token")
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Invalid access token."
@pytest.mark.anyio
async def test_check_user_exists_with_missing_access_token():
request = Request({"type": "http", "path": "/some/path", "method": "GET"})
with pytest.raises(HTTPException) as exc_info:
await check_user_exists(request, access_token=None)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Missing user ID or access token."
@pytest.mark.anyio
async def test_check_user_exists_with_valid_user_id(user_alan: User):
request = Request({"type": "http", "path": "/some/path", "method": "GET"})
user = await check_user_exists(request, access_token=None, usr=UUID4(user_alan.id))
assert user.id == user_alan.id
@pytest.mark.anyio
async def test_check_user_exists_with_invalid_user_id():
request = Request({"type": "http", "path": "/some/path", "method": "GET"})
with pytest.raises(HTTPException) as exc_info:
await check_user_exists(request, access_token=None, usr=uuid4())
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "User not found."
@pytest.mark.anyio
async def test_check_user_exists_with_user_not_allowed(user_alan: User):
settings.lnbits_admin_users = []
request = Request({"type": "http", "path": "/some/path", "method": "GET"})
settings.lnbits_allowed_users = ["only_this_user_id"]
with pytest.raises(HTTPException) as exc_info:
await check_user_exists(request, access_token=None, usr=UUID4(user_alan.id))
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "User not allowed."
@pytest.mark.anyio
async def test_check_user_exists_after_user_deletion(http_client: AsyncClient):
# Register a new user
tiny_id = shortuuid.uuid()[:8]
register_response = await http_client.post(
"/api/v1/auth/register",
json={
"username": f"u21.{tiny_id}",
"password": "secret1234",
"password_repeat": "secret1234",
"email": f"u21.{tiny_id}@lnbits.com",
},
)
assert register_response.status_code == 200, "User registers OK"
access_token = register_response.json()["access_token"]
assert access_token is not None
payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"])
access_token_payload = AccessTokenPayload(**payload)
# Get the user ID
user_id = access_token_payload.usr
assert user_id, "User ID is not None"
# Delete the user
await delete_account(user_id)
# Attempt to check user existence with the deleted user's access token
request = Request({"type": "http", "path": "/some/path", "method": "GET"})
with pytest.raises(HTTPException) as exc_info:
await check_user_exists(request, access_token=access_token)
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "User not found."
@pytest.mark.anyio
async def test_check_user_exists_with_user_id_only_allowed(
user_alan: User, settings: Settings
):
settings.auth_allowed_methods = [AuthMethods.user_id_only.value]
request = Request({"type": "http", "path": "/some/path", "method": "GET"})
user = await check_user_exists(request, access_token=None, usr=UUID4(user_alan.id))
assert user.id == user_alan.id
assert request.scope["user_id"] == user.id
@pytest.mark.anyio
async def test_check_user_exists_with_user_id_only_not_allowed(user_alan: User):
settings.auth_allowed_methods = []
request = Request({"type": "http", "path": "/some/path", "method": "GET"})
with pytest.raises(HTTPException) as exc_info:
await check_user_exists(request, access_token=None, usr=UUID4(user_alan.id))
assert exc_info.value.status_code == 401
assert exc_info.value.detail == "Missing user ID or access token."