[perf] reuse connection (#3624)

This commit is contained in:
Vlad Stan 2025-12-06 15:52:06 +02:00 committed by GitHub
parent 71e0b396d2
commit 5d79327906
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 322 additions and 189 deletions

View file

@ -33,9 +33,9 @@ async def get_payment(checking_id: str, conn: Connection | None = None) -> Payme
async def get_standalone_payment(
checking_id_or_hash: str,
conn: Connection | None = None,
incoming: bool | None = False,
wallet_id: str | None = None,
conn: Connection | None = None,
) -> Payment | None:
clause: str = "checking_id = :checking_id OR payment_hash = :hash"
values = {
@ -46,7 +46,7 @@ async def get_standalone_payment(
clause = f"({clause}) AND amount > 0"
if wallet_id:
wallet = await get_wallet(wallet_id)
wallet = await get_wallet(wallet_id, conn=conn)
if not wallet or not wallet.can_view_payments:
return None
values["wallet_id"] = wallet.source_wallet_id
@ -69,7 +69,7 @@ async def get_standalone_payment(
async def get_wallet_payment(
wallet_id: str, payment_hash: str, conn: Connection | None = None
) -> Payment | None:
wallet = await get_wallet(wallet_id)
wallet = await get_wallet(wallet_id, conn=conn)
if not wallet or not wallet.can_view_payments:
return None
payment = await (conn or db).fetchone(
@ -124,7 +124,6 @@ async def get_payments_paginated( # noqa: C901
Filters payments to be returned by:
- complete | pending | failed | outgoing | incoming.
"""
values: dict[str, Any] = {
"time": since,
}
@ -134,7 +133,7 @@ async def get_payments_paginated( # noqa: C901
clause.append(f"time > {db.timestamp_placeholder('time')}")
if wallet_id:
wallet = await get_wallet(wallet_id)
wallet = await get_wallet(wallet_id, conn=conn)
if not wallet or not wallet.can_view_payments:
return Page(data=[], total=0)
@ -326,6 +325,7 @@ async def get_payments_history(
wallet_id: str | None = None,
group: DateTrunc = "day",
filters: Filters | None = None,
conn: Connection | None = None,
) -> list[PaymentHistoryPoint]:
if not filters:
filters = Filters()
@ -361,13 +361,13 @@ async def get_payments_history(
filters.values(values),
)
if wallet_id:
wallet = await get_wallet(wallet_id)
wallet = await get_wallet(wallet_id, conn=conn)
if not wallet or not wallet.can_view_payments:
return []
balance = wallet.balance_msat
values["wallet_id"] = wallet.source_wallet_id
else:
balance = await get_total_balance()
balance = await get_total_balance(conn=conn)
# since we dont know the balance at the starting point,
# we take the current balance and walk backwards

View file

@ -30,9 +30,9 @@ async def create_account(
return account
async def update_account(account: Account) -> Account:
async def update_account(account: Account, conn: Connection | None = None) -> Account:
account.updated_at = datetime.now(timezone.utc)
await db.update("accounts", account)
await (conn or db).update("accounts", account)
return account
@ -171,21 +171,23 @@ async def get_account_by_username_or_email(
async def get_user(user_id: str, conn: Connection | None = None) -> User | None:
account = await get_account(user_id, conn)
if not account:
return None
return await get_user_from_account(account, conn)
async with db.reuse_conn(conn) if conn else db.connect() as conn:
account = await get_account(user_id, conn=conn)
if not account:
return None
return await get_user_from_account(account, conn=conn)
async def get_user_from_account(
account: Account, conn: Connection | None = None
) -> User | None:
extensions = await get_user_active_extensions_ids(account.id, conn=conn)
wallets = await get_wallets(account.id, deleted=False, conn=conn)
async with db.reuse_conn(conn) if conn else db.connect() as conn:
extensions = await get_user_active_extensions_ids(account.id, conn=conn)
wallets = await get_wallets(account.id, deleted=False, conn=conn)
if len(wallets) == 0:
wallet = await create_wallet(user_id=account.id, conn=conn)
wallets.append(wallet)
if len(wallets) == 0:
wallet = await create_wallet(user_id=account.id, conn=conn)
wallets.append(wallet)
return User(
id=account.id,
@ -205,9 +207,11 @@ async def get_user_from_account(
)
async def update_user_access_control_list(user_acls: UserAcls):
async def update_user_access_control_list(
user_acls: UserAcls, conn: Connection | None = None
):
user_acls.updated_at = datetime.now(timezone.utc)
await db.update("accounts", user_acls)
await (conn or db).update("accounts", user_acls)
async def get_user_access_control_lists(

View file

@ -786,3 +786,63 @@ async def m038_add_labels_for_payments(db: Connection):
ALTER TABLE apipayments ADD COLUMN labels TEXT
"""
)
async def m039_index_payments(db: Connection):
indexes = [
"wallet_id",
"checking_id",
"payment_hash",
"amount",
"labels",
"time",
"status",
"created_at",
"updated_at",
]
for index in indexes:
logger.debug(f"Creating index idx_payments_{index}...")
await db.execute(
f"""
CREATE INDEX IF NOT EXISTS idx_payments_{index} ON apipayments ({index});
"""
)
async def m040_index_wallets(db: Connection):
indexes = [
"id",
"user",
"deleted",
"adminkey",
"inkey",
"wallet_type",
"created_at",
"updated_at",
]
for index in indexes:
logger.debug(f"Creating index idx_wallets_{index}...")
await db.execute(
f"""
CREATE INDEX IF NOT EXISTS idx_wallets_{index} ON wallets ("{index}");
"""
)
async def m042_index_accounts(db: Connection):
indexes = [
"id",
"email",
"username",
"pubkey",
"external_id",
]
for index in indexes:
logger.debug(f"Creating index idx_wallets_{index}...")
await db.execute(
f"""
CREATE INDEX IF NOT EXISTS idx_accounts_{index} ON accounts ("{index}");
"""
)

View file

@ -201,6 +201,10 @@ class Account(AccountId):
self.is_admin = settings.is_admin_user(self.id)
self.fiat_providers = settings.get_fiat_providers_for_user(self.id)
@property
def has_password(self) -> bool:
return self.password_hash is not None
def hash_password(self, password: str) -> str:
"""sets and returns the hashed password"""
salt = gensalt()

View file

@ -16,6 +16,7 @@ from lnbits.core.crud.extensions import (
update_installed_extension,
)
from lnbits.core.helpers import migrate_extension_database
from lnbits.db import Connection
from lnbits.settings import settings
from ..models.extensions import Extension, ExtensionMeta, InstallableExtension
@ -149,9 +150,9 @@ async def start_extension_background_work(ext_id: str) -> bool:
async def get_valid_extensions(
include_deactivated: bool | None = True,
include_deactivated: bool | None = True, conn: Connection | None = None
) -> list[Extension]:
installed_extensions = await get_installed_extensions()
installed_extensions = await get_installed_extensions(conn=conn)
valid_extensions = [Extension.from_installable_ext(e) for e in installed_extensions]
if include_deactivated:

View file

@ -67,12 +67,14 @@ async def pay_invoice(
if settings.lnbits_only_allow_incoming_payments:
raise PaymentError("Only incoming payments allowed.", status="failed")
invoice = _validate_payment_request(payment_request, max_sat)
if not invoice.amount_msat:
raise ValueError("Missig invoice amount.")
async with db.reuse_conn(conn) if conn else db.connect() as new_conn:
amount_msat = invoice.amount_msat
wallet = await _check_wallet_for_payment(wallet_id, tag, amount_msat, new_conn)
if not wallet.can_send_payments:
raise PaymentError(
"Wallet does not have permission to pay invoices.",
@ -95,10 +97,12 @@ async def pay_invoice(
labels=labels,
)
payment = await _pay_invoice(wallet.source_wallet_id, create_payment_model, conn)
async with db.reuse_conn(conn) if conn else db.connect() as new_conn:
await _credit_service_fee_wallet(wallet, payment, new_conn)
payment = await _pay_invoice(
wallet.source_wallet_id, create_payment_model, conn=new_conn
)
await _credit_service_fee_wallet(wallet, payment, conn=new_conn)
return payment
@ -197,20 +201,22 @@ async def create_wallet_invoice(wallet_id: str, data: CreateInvoice) -> Payment:
# do not save memo if description_hash or unhashed_description is set
memo = ""
payment = await create_invoice(
wallet_id=wallet_id,
amount=data.amount,
memo=memo,
currency=data.unit,
description_hash=description_hash,
unhashed_description=unhashed_description,
expiry=data.expiry,
extra=data.extra,
webhook=data.webhook,
internal=data.internal,
payment_hash=data.payment_hash,
labels=data.labels,
)
async with db.connect() as conn:
payment = await create_invoice(
wallet_id=wallet_id,
amount=data.amount,
memo=memo,
currency=data.unit,
description_hash=description_hash,
unhashed_description=unhashed_description,
expiry=data.expiry,
extra=data.extra,
webhook=data.webhook,
internal=data.internal,
payment_hash=data.payment_hash,
labels=data.labels,
conn=conn,
)
if data.lnurl_withdraw:
try:
@ -355,13 +361,15 @@ async def update_pending_payments(wallet_id: str):
await update_pending_payment(payment)
async def update_pending_payment(payment: Payment) -> Payment:
async def update_pending_payment(
payment: Payment, conn: Connection | None = None
) -> Payment:
status = await payment.check_status()
if status.failed:
payment.status = PaymentState.FAILED
await update_payment(payment)
await update_payment(payment, conn=conn)
elif status.success:
payment = await update_payment_success_status(payment, status)
payment = await update_payment_success_status(payment, status, conn=conn)
return payment
@ -698,6 +706,7 @@ async def _pay_internal_invoice(
internal_payment = await check_internal(
create_payment_model.payment_hash, conn=conn
)
if not internal_payment:
return None
@ -706,6 +715,7 @@ async def _pay_internal_invoice(
internal_invoice = await get_standalone_payment(
internal_payment.checking_id, incoming=True, conn=conn
)
if not internal_invoice:
raise PaymentError("Internal payment not found.", status="failed")
@ -727,6 +737,7 @@ async def _pay_internal_invoice(
internal_id = f"internal_{create_payment_model.payment_hash}"
logger.debug(f"creating temporary internal payment with id {internal_id}")
payment = await create_payment(
checking_id=internal_id,
data=create_payment_model,

View file

@ -3,7 +3,9 @@ from uuid import uuid4
from loguru import logger
from lnbits.core.db import db
from lnbits.core.models.extensions import UserExtension
from lnbits.db import Connection
from lnbits.settings import (
EditableSettings,
SuperSettings,
@ -48,37 +50,43 @@ async def create_user_account_no_ckeck(
account: Account | None = None,
wallet_name: str | None = None,
default_exts: list[str] | None = None,
conn: Connection | None = None,
) -> User:
async with db.reuse_conn(conn) if conn else db.connect() as conn:
if account:
account.validate_fields()
if account.username and await get_account_by_username(
account.username, conn=conn
):
raise ValueError("Username already exists.")
if account:
account.validate_fields()
if account.username and await get_account_by_username(account.username):
raise ValueError("Username already exists.")
if account.email and await get_account_by_email(account.email, conn=conn):
raise ValueError("Email already exists.")
if account.email and await get_account_by_email(account.email):
raise ValueError("Email already exists.")
if account.pubkey and await get_account_by_pubkey(
account.pubkey, conn=conn
):
raise ValueError("Pubkey already exists.")
if account.pubkey and await get_account_by_pubkey(account.pubkey):
raise ValueError("Pubkey already exists.")
if not account.id:
account.id = uuid4().hex
if not account.id:
account.id = uuid4().hex
account = await create_account(account, conn=conn)
await create_wallet(
user_id=account.id,
wallet_name=wallet_name or settings.lnbits_default_wallet_name,
conn=conn,
)
account = await create_account(account)
await create_wallet(
user_id=account.id,
wallet_name=wallet_name or settings.lnbits_default_wallet_name,
)
user_extensions = (default_exts or []) + settings.lnbits_user_default_extensions
for ext_id in user_extensions:
try:
user_ext = UserExtension(user=account.id, extension=ext_id, active=True)
await create_user_extension(user_ext, conn=conn)
except Exception as e:
logger.error(f"Error enabeling default extension {ext_id}: {e}")
user_extensions = (default_exts or []) + settings.lnbits_user_default_extensions
for ext_id in user_extensions:
try:
user_ext = UserExtension(user=account.id, extension=ext_id, active=True)
await create_user_extension(user_ext)
except Exception as e:
logger.error(f"Error enabeling default extension {ext_id}: {e}")
user = await get_user_from_account(account)
user = await get_user_from_account(account, conn=conn)
if not user:
raise ValueError("Cannot find user for account.")

View file

@ -8,8 +8,8 @@ from urllib.parse import urlparse
from fastapi import APIRouter, Depends, File
from fastapi.responses import FileResponse
from lnbits.core.models import User
from lnbits.core.models.notifications import NotificationType
from lnbits.core.models.users import Account
from lnbits.core.services import (
enqueue_admin_notification,
get_balance_delta,
@ -67,9 +67,9 @@ async def api_test_email():
@admin_router.get("/api/v1/settings")
async def api_get_settings(
user: User = Depends(check_admin),
account: Account = Depends(check_admin),
) -> AdminSettings | None:
admin_settings = await get_admin_settings(user.super_user)
admin_settings = await get_admin_settings(account.is_super_user)
return admin_settings
@ -77,12 +77,14 @@ async def api_get_settings(
"/api/v1/settings",
status_code=HTTPStatus.OK,
)
async def api_update_settings(data: UpdateSettings, user: User = Depends(check_admin)):
async def api_update_settings(
data: UpdateSettings, account: Account = Depends(check_admin)
):
enqueue_admin_notification(
NotificationType.settings_update, {"username": user.username}
NotificationType.settings_update, {"username": account.username}
)
await update_admin_settings(data)
admin_settings = await get_admin_settings(user.super_user)
admin_settings = await get_admin_settings(account.is_super_user)
if not admin_settings:
raise ValueError("Updated admin settings not found.")
update_cached_settings(admin_settings.dict())
@ -94,9 +96,11 @@ async def api_update_settings(data: UpdateSettings, user: User = Depends(check_a
"/api/v1/settings",
status_code=HTTPStatus.OK,
)
async def api_update_settings_partial(data: dict, user: User = Depends(check_admin)):
async def api_update_settings_partial(
data: dict, account: Account = Depends(check_admin)
):
updatable_settings = dict_to_settings({**settings.dict(), **data})
return await api_update_settings(updatable_settings, user)
return await api_update_settings(updatable_settings, account)
@admin_router.get(
@ -110,9 +114,9 @@ async def api_reset_settings(field_name: str):
@admin_router.delete("/api/v1/settings", status_code=HTTPStatus.OK)
async def api_delete_settings(user: User = Depends(check_super_user)) -> None:
async def api_delete_settings(account: Account = Depends(check_super_user)) -> None:
enqueue_admin_notification(
NotificationType.settings_update, {"username": user.username}
NotificationType.settings_update, {"username": account.username}
)
await reset_core_settings()
server_restart.set()

View file

@ -7,9 +7,10 @@ from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from lnbits.core.crud.extensions import get_user_extensions
from lnbits.core.crud.wallets import get_wallets_ids
from lnbits.core.db import db
from lnbits.core.models import (
SimpleStatus,
User,
)
from lnbits.core.models.extensions import (
CreateExtension,
@ -23,7 +24,7 @@ from lnbits.core.models.extensions import (
UserExtension,
UserExtensionInfo,
)
from lnbits.core.models.users import AccountId
from lnbits.core.models.users import Account, AccountId
from lnbits.core.services import check_transaction_status, create_invoice
from lnbits.core.services.extensions import (
activate_extension,
@ -144,9 +145,10 @@ async def api_extension_details(
async def api_update_pay_to_enable(
ext_id: str,
data: PayToEnableInfo,
user: User = Depends(check_admin),
account: Account = Depends(check_admin),
) -> SimpleStatus:
if data.wallet not in user.wallet_ids:
user_wallet_ids = await get_wallets_ids(account.id, deleted=False)
if data.wallet not in user_wallet_ids:
raise HTTPException(
HTTPStatus.BAD_REQUEST, "Wallet does not belong to this admin user."
)
@ -462,15 +464,16 @@ async def get_extension_release(org: str, repo: str, tag_name: str):
async def api_get_user_extensions(
account_id: AccountId = Depends(check_account_id_exists),
) -> list[Extension]:
user_extensions_ids = [
ue.extension for ue in await get_user_extensions(account_id.id)
]
return [
ext
for ext in await get_valid_extensions(False)
if ext.code in user_extensions_ids
]
async with db.connect() as conn:
user_extensions_ids = [
ue.extension for ue in await get_user_extensions(account_id.id, conn=conn)
]
valid_extensions = [
ext
for ext in await get_valid_extensions(False, conn=conn)
if ext.code in user_extensions_ids
]
return valid_extensions
@extension_router.delete(
@ -505,7 +508,16 @@ async def delete_extension_db(ext_id: str):
# TODO: create a response model for this
@extension_router.get("/all")
async def extensions(account_id: AccountId = Depends(check_account_id_exists)):
installed_exts: list[InstallableExtension] = await get_installed_extensions()
async with db.connect() as conn:
installed_exts: list[InstallableExtension] = await get_installed_extensions(
conn=conn
)
all_ext_ids = [ext.code for ext in await get_valid_extensions(conn=conn)]
inactive_extensions = [
e.id for e in await get_installed_extensions(active=False, conn=conn)
]
db_versions = await get_db_versions(conn=conn)
installed_exts_ids = [e.id for e in installed_exts]
installable_exts = await InstallableExtension.get_installable_extensions(
@ -536,10 +548,6 @@ async def extensions(account_id: AccountId = Depends(check_account_id_exists)):
e.short_description = installed_ext.short_description
e.icon = installed_ext.icon
all_ext_ids = [ext.code for ext in await get_valid_extensions()]
inactive_extensions = [e.id for e in await get_installed_extensions(active=False)]
db_versions = await get_db_versions()
extension_data = [
{
"id": ext.id,

View file

@ -7,7 +7,6 @@ from fastapi.responses import FileResponse
from lnbits.core.models import (
SimpleStatus,
User,
)
from lnbits.core.models.extensions import (
Extension,
@ -16,7 +15,7 @@ from lnbits.core.models.extensions import (
UserExtension,
)
from lnbits.core.models.extensions_builder import ExtensionData
from lnbits.core.models.users import AccountId
from lnbits.core.models.users import Account, AccountId
from lnbits.core.services.extensions import (
activate_extension,
install_extension,
@ -85,9 +84,9 @@ async def api_build_extension(data: ExtensionData) -> FileResponse:
)
async def api_deploy_extension(
data: ExtensionData,
user: User = Depends(check_admin),
account: Account = Depends(check_admin),
) -> SimpleStatus:
working_dir_name = "deploy_" + sha256(user.id.encode("utf-8")).hexdigest()
working_dir_name = "deploy_" + sha256(account.id.encode("utf-8")).hexdigest()
stub_ext_id = "extension_builder_stub"
release, build_dir = await build_extension_from_data(
data, stub_ext_id, working_dir_name
@ -111,9 +110,9 @@ async def api_deploy_extension(
await activate_extension(Extension.from_installable_ext(ext_info))
user_ext = await get_user_extension(user.id, data.id)
user_ext = await get_user_extension(account.id, data.id)
if not user_ext:
user_ext = UserExtension(user=user.id, extension=data.id, active=True)
user_ext = UserExtension(user=account.id, extension=data.id, active=True)
await create_user_extension(user_ext)
elif not user_ext.active:
user_ext.active = True

View file

@ -18,6 +18,7 @@ from lnbits.core.crud.payments import (
update_payment,
)
from lnbits.core.crud.users import get_account
from lnbits.core.db import db
from lnbits.core.models import (
CancelInvoice,
CreateInvoice,
@ -69,7 +70,6 @@ from ..services import (
perform_withdraw,
settle_hold_invoice,
update_pending_payment,
update_pending_payments,
)
payment_router = APIRouter(prefix="/api/v1/payments", tags=["Payments"])
@ -87,7 +87,6 @@ async def api_payments(
key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
filters: Filters = Depends(parse_filters(PaymentFilters)),
):
await update_pending_payments(key_info.wallet.id)
return await get_payments(
wallet_id=key_info.wallet.id,
pending=True,
@ -107,7 +106,6 @@ async def api_payments_history(
group: DateTrunc = Query("day"),
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
):
await update_pending_payments(key_info.wallet.id)
return await get_payments_history(key_info.wallet.id, group, filters)
@ -186,18 +184,24 @@ async def api_payments_paginated(
),
filters: Filters = Depends(parse_filters(PaymentFilters)),
) -> Page[Payment]:
page = await get_payments_paginated(
wallet_id=key_info.wallet.id,
filters=filters,
)
if not recheck_pending:
return page
async with db.connect() as conn:
page = await get_payments_paginated(
wallet_id=key_info.wallet.id,
filters=filters,
conn=conn,
)
if not recheck_pending:
return page
for payment in page.data:
if payment.pending:
await update_pending_payment(payment)
payments = []
for payment in page.data:
if payment.pending:
refreshed_payment = await update_pending_payment(payment, conn=conn)
payments.append(refreshed_payment)
else:
payments.append(payment)
return page
return Page(data=payments, total=page.total)
@payment_router.get(
@ -219,10 +223,10 @@ async def api_all_payments_paginated(
# regular user can only see payments from their wallets
for_user_id = account_id.id
return await get_payments_paginated(
filters=filters,
user_id=for_user_id,
)
async with db.connect() as conn:
return await get_payments_paginated(
filters=filters, user_id=for_user_id, conn=conn
)
@payment_router.post(

View file

@ -115,12 +115,12 @@ async def api_create_user(data: CreateUser) -> CreateUser:
@users_router.put("/user/{user_id}", name="Update user")
async def api_update_user(
user_id: str, data: CreateUser, user: User = Depends(check_admin)
user_id: str, data: CreateUser, account: Account = Depends(check_admin)
) -> CreateUser:
if user_id != data.id:
raise HTTPException(HTTPStatus.BAD_REQUEST, "User Id missmatch.")
if user_id == settings.super_user and user.id != settings.super_user:
if user_id == settings.super_user and account.id != settings.super_user:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Action only allowed for super user.",
@ -154,7 +154,7 @@ async def api_update_user(
name="Delete user by Id",
)
async def api_users_delete_user(
user_id: str, user: User = Depends(check_admin)
user_id: str, account: Account = Depends(check_admin)
) -> SimpleStatus:
wallets = await get_wallets(user_id, deleted=False)
if len(wallets) > 0:
@ -169,7 +169,7 @@ async def api_users_delete_user(
detail="Cannot delete super user.",
)
if user_id in settings.lnbits_admin_users and not user.super_user:
if user_id in settings.lnbits_admin_users and not account.is_super_user:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Only super_user can delete admin user.",
@ -295,7 +295,7 @@ async def api_users_delete_all_user_wallet(user_id: str) -> SimpleStatus:
"The second time it is called will delete the entry from the DB",
)
async def api_users_delete_user_wallet(
user_id: str, wallet: str, user: User = Depends(check_admin)
user_id: str, wallet: str, account: Account = Depends(check_admin)
) -> SimpleStatus:
wal = await get_wallet(wallet)
if not wal:
@ -304,7 +304,7 @@ async def api_users_delete_user_wallet(
detail="Wallet does not exist.",
)
if user_id == settings.super_user and user.id != settings.super_user:
if user_id == settings.super_user and account.id != settings.super_user:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Action only allowed for super user.",

View file

@ -20,6 +20,7 @@ from lnbits.core.crud import (
)
from lnbits.core.crud.users import get_user_access_control_lists
from lnbits.core.crud.wallets import get_base_wallet_for_key
from lnbits.core.db import db
from lnbits.core.models import (
AccessTokenPayload,
Account,
@ -119,16 +120,17 @@ class KeyChecker(BaseKeyChecker):
async def __call__(self, request: Request) -> WalletTypeInfo:
key_value = self._extract_key_value(request)
wallet = await get_wallet_for_key(key_value)
async with db.connect() as conn:
wallet = await get_wallet_for_key(key_value, conn=conn)
if not wallet:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Wallet not found.",
)
if not wallet:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Wallet not found.",
)
request.scope["user_id"] = wallet.user
await _check_user_access(request, wallet.user)
request.scope["user_id"] = wallet.user
await _check_user_access(request, wallet.user, conn=conn)
key_type = await self._extract_key_type(key_value, wallet)
return WalletTypeInfo(key_type, wallet)
@ -148,22 +150,23 @@ class LightKeyChecker(BaseKeyChecker):
cache_key = f"auth:x-api-key:{key_value}"
cache_time = settings.auth_authentication_cache_minutes * 60
if cache_time > 0:
key_info: BaseWalletTypeInfo | None = cache.get(cache_key)
if key_info:
request.scope["user_id"] = key_info.wallet.user
await _check_user_access(request, key_info.wallet.user)
return key_info
async with db.connect() as conn:
if cache_time > 0:
key_info: BaseWalletTypeInfo | None = cache.get(cache_key)
if key_info:
request.scope["user_id"] = key_info.wallet.user
await _check_user_access(request, key_info.wallet.user, conn=conn)
return key_info
wallet = await get_base_wallet_for_key(key_value)
wallet = await get_base_wallet_for_key(key_value, conn=conn)
if not wallet:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Wallet not found.",
)
request.scope["user_id"] = wallet.user
await _check_user_access(request, wallet.user)
if not wallet:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Wallet not found.",
)
request.scope["user_id"] = wallet.user
await _check_user_access(request, wallet.user, conn=conn)
key_type = await self._extract_key_type(key_value, wallet)
key_info = BaseWalletTypeInfo(key_type, wallet)
@ -241,15 +244,16 @@ async def check_account_id_exists(
elif usr:
cache_key = f"auth:user_id:{sha256s(usr.hex)}"
if cache_key and settings.auth_authentication_cache_minutes > 0:
account_id = cache.get(cache_key)
if account_id:
r.scope["user_id"] = account_id.id
await _check_user_access(r, account_id)
return account_id
async with db.connect() as conn:
if cache_key and settings.auth_authentication_cache_minutes > 0:
account_id = cache.get(cache_key)
if account_id:
r.scope["user_id"] = account_id.id
await _check_user_access(r, account_id, conn=conn)
return account_id
account = await check_account_exists(r, access_token, usr)
account_id = AccountId(id=account.id)
account = await _check_account_exists(r, access_token, usr, conn=conn)
account_id = AccountId(id=account.id)
if cache_key and settings.auth_authentication_cache_minutes > 0:
cache.set(
@ -265,6 +269,15 @@ async def check_account_exists(
r: Request,
access_token: Annotated[str | None, Depends(check_access_token)],
usr: UUID4 | None = None,
) -> Account:
return await _check_account_exists(r, access_token, usr)
async def _check_account_exists(
r: Request,
access_token: Annotated[str | None, Depends(check_access_token)],
usr: UUID4 | None = None,
conn: Connection | None = None,
) -> Account:
"""
Check that the account exists based on access token or user id.
@ -273,22 +286,27 @@ async def check_account_exists(
- does not fetch the user wallets
- caches the account info based on settings cache time
"""
if access_token:
account = await _get_account_from_token(access_token, r["path"], r["method"])
elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only):
account = await get_account(usr.hex)
if account and account.is_admin:
raise HTTPException(
HTTPStatus.FORBIDDEN, "User id only access for admins is forbidden."
async with db.reuse_conn(conn) if conn else db.connect() as new_conn:
if access_token:
account = await _get_account_from_token(
access_token, r["path"], r["method"], conn=new_conn
)
elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only):
account = await get_account(usr.hex, conn=new_conn)
if account and account.is_admin:
raise HTTPException(
HTTPStatus.FORBIDDEN, "User id only access for admins is forbidden."
)
else:
raise HTTPException(
HTTPStatus.UNAUTHORIZED, "Missing user ID or access token."
)
else:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Missing user ID or access token.")
if not account:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
if not account:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
r.scope["user_id"] = account.id
await _check_user_access(r, account.id)
r.scope["user_id"] = account.id
await _check_user_access(r, account.id, conn=new_conn)
return account
@ -298,8 +316,9 @@ async def check_user_exists(
access_token: Annotated[str | None, Depends(check_access_token)],
usr: UUID4 | None = None,
) -> User:
account = await check_account_exists(r, access_token, usr)
user = await get_user_from_account(account)
async with db.connect() as conn:
account = await _check_account_exists(r, access_token, usr, conn=conn)
user = await get_user_from_account(account, conn=conn)
if not user:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
@ -330,29 +349,36 @@ async def access_token_payload(
return AccessTokenPayload(**payload)
async def check_admin(user: Annotated[User, Depends(check_user_exists)]) -> User:
if user.id != settings.super_user and user.id not in settings.lnbits_admin_users:
async def check_admin(
account: Annotated[Account, Depends(check_account_exists)],
) -> Account:
if (
account.id != settings.super_user
and account.id not in settings.lnbits_admin_users
):
raise HTTPException(
HTTPStatus.FORBIDDEN, "User not authorized. No admin privileges."
)
if not user.has_password:
if not account.has_password:
raise HTTPException(
HTTPStatus.FORBIDDEN, "Admin users must have credentials configured."
)
return user
return account
async def check_super_user(user: Annotated[User, Depends(check_user_exists)]) -> User:
if user.id != settings.super_user:
async def check_super_user(
account: Annotated[Account, Depends(check_admin)],
) -> Account:
if account.id != settings.super_user:
raise HTTPException(
HTTPStatus.FORBIDDEN, "User not authorized. No super user privileges."
)
if not user.has_password:
if not account.has_password:
raise HTTPException(
HTTPStatus.FORBIDDEN, "Super user must have credentials configured."
)
return user
return account
def parse_filters(model: type[TFilterModel]):
@ -412,15 +438,17 @@ async def check_user_extension_access(
return SimpleStatus(success=True, message="OK")
async def _check_user_access(r: Request, user_id: str):
async def _check_user_access(r: Request, user_id: str, conn: Connection | None = None):
if not settings.is_user_allowed(user_id):
raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.")
await _check_user_extension_access(user_id, r["path"])
await _check_user_extension_access(user_id, r["path"], conn=conn)
async def _check_user_extension_access(user_id: str, path: str):
async def _check_user_extension_access(
user_id: str, path: str, conn: Connection | None = None
):
ext_id = path_segments(path)[0]
status = await check_user_extension_access(user_id, ext_id)
status = await check_user_extension_access(user_id, ext_id, conn=conn)
if not status.success:
raise HTTPException(
HTTPStatus.FORBIDDEN,
@ -429,12 +457,12 @@ async def _check_user_extension_access(user_id: str, path: str):
async def _get_account_from_token(
access_token: str, path: str, method: str
access_token: str, path: str, method: str, conn: Connection | None = None
) -> Account | None:
try:
payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"])
return await _get_account_from_jwt_payload(
AccessTokenPayload(**payload), path, method
AccessTokenPayload(**payload), path, method, conn=conn
)
except jwt.ExpiredSignatureError as exc:
@ -447,33 +475,35 @@ async def _get_account_from_token(
async def _get_account_from_jwt_payload(
payload: AccessTokenPayload, path: str, method: str
payload: AccessTokenPayload, path: str, method: str, conn: Connection | None = None
) -> Account | None:
account = None
if payload.sub:
account = await get_account_by_username(payload.sub)
account = await get_account_by_username(payload.sub, conn=conn)
elif payload.usr:
account = await get_account(payload.usr)
account = await get_account(payload.usr, conn=conn)
elif payload.email:
account = await get_account_by_email(payload.email)
account = await get_account_by_email(payload.email, conn=conn)
if not account:
return None
if payload.api_token_id:
await _check_account_api_access(account.id, payload.api_token_id, path, method)
await _check_account_api_access(
account.id, payload.api_token_id, path, method, conn=conn
)
return account
async def _check_account_api_access(
user_id: str, token_id: str, path: str, method: str
user_id: str, token_id: str, path: str, method: str, conn: Connection | None = None
):
segments = path.split("/")
if len(segments) < 3:
raise HTTPException(HTTPStatus.FORBIDDEN, "Not an API endpoint.")
acls = await get_user_access_control_lists(user_id)
acls = await get_user_access_control_lists(user_id, conn=conn)
acl = acls.get_acl_by_token_id(token_id)
if not acl:
raise HTTPException(HTTPStatus.FORBIDDEN, "Invalid token id.")