[perf] reuse connection (#3624)

This commit is contained in:
Vlad Stan 2025-12-06 15:52:06 +02:00 committed by GitHub
parent 71e0b396d2
commit 5d79327906
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 322 additions and 189 deletions

View file

@ -33,9 +33,9 @@ async def get_payment(checking_id: str, conn: Connection | None = None) -> Payme
async def get_standalone_payment( async def get_standalone_payment(
checking_id_or_hash: str, checking_id_or_hash: str,
conn: Connection | None = None,
incoming: bool | None = False, incoming: bool | None = False,
wallet_id: str | None = None, wallet_id: str | None = None,
conn: Connection | None = None,
) -> Payment | None: ) -> Payment | None:
clause: str = "checking_id = :checking_id OR payment_hash = :hash" clause: str = "checking_id = :checking_id OR payment_hash = :hash"
values = { values = {
@ -46,7 +46,7 @@ async def get_standalone_payment(
clause = f"({clause}) AND amount > 0" clause = f"({clause}) AND amount > 0"
if wallet_id: if wallet_id:
wallet = await get_wallet(wallet_id) wallet = await get_wallet(wallet_id, conn=conn)
if not wallet or not wallet.can_view_payments: if not wallet or not wallet.can_view_payments:
return None return None
values["wallet_id"] = wallet.source_wallet_id values["wallet_id"] = wallet.source_wallet_id
@ -69,7 +69,7 @@ async def get_standalone_payment(
async def get_wallet_payment( async def get_wallet_payment(
wallet_id: str, payment_hash: str, conn: Connection | None = None wallet_id: str, payment_hash: str, conn: Connection | None = None
) -> Payment | None: ) -> Payment | None:
wallet = await get_wallet(wallet_id) wallet = await get_wallet(wallet_id, conn=conn)
if not wallet or not wallet.can_view_payments: if not wallet or not wallet.can_view_payments:
return None return None
payment = await (conn or db).fetchone( payment = await (conn or db).fetchone(
@ -124,7 +124,6 @@ async def get_payments_paginated( # noqa: C901
Filters payments to be returned by: Filters payments to be returned by:
- complete | pending | failed | outgoing | incoming. - complete | pending | failed | outgoing | incoming.
""" """
values: dict[str, Any] = { values: dict[str, Any] = {
"time": since, "time": since,
} }
@ -134,7 +133,7 @@ async def get_payments_paginated( # noqa: C901
clause.append(f"time > {db.timestamp_placeholder('time')}") clause.append(f"time > {db.timestamp_placeholder('time')}")
if wallet_id: if wallet_id:
wallet = await get_wallet(wallet_id) wallet = await get_wallet(wallet_id, conn=conn)
if not wallet or not wallet.can_view_payments: if not wallet or not wallet.can_view_payments:
return Page(data=[], total=0) return Page(data=[], total=0)
@ -326,6 +325,7 @@ async def get_payments_history(
wallet_id: str | None = None, wallet_id: str | None = None,
group: DateTrunc = "day", group: DateTrunc = "day",
filters: Filters | None = None, filters: Filters | None = None,
conn: Connection | None = None,
) -> list[PaymentHistoryPoint]: ) -> list[PaymentHistoryPoint]:
if not filters: if not filters:
filters = Filters() filters = Filters()
@ -361,13 +361,13 @@ async def get_payments_history(
filters.values(values), filters.values(values),
) )
if wallet_id: if wallet_id:
wallet = await get_wallet(wallet_id) wallet = await get_wallet(wallet_id, conn=conn)
if not wallet or not wallet.can_view_payments: if not wallet or not wallet.can_view_payments:
return [] return []
balance = wallet.balance_msat balance = wallet.balance_msat
values["wallet_id"] = wallet.source_wallet_id values["wallet_id"] = wallet.source_wallet_id
else: else:
balance = await get_total_balance() balance = await get_total_balance(conn=conn)
# since we dont know the balance at the starting point, # since we dont know the balance at the starting point,
# we take the current balance and walk backwards # we take the current balance and walk backwards

View file

@ -30,9 +30,9 @@ async def create_account(
return account return account
async def update_account(account: Account) -> Account: async def update_account(account: Account, conn: Connection | None = None) -> Account:
account.updated_at = datetime.now(timezone.utc) account.updated_at = datetime.now(timezone.utc)
await db.update("accounts", account) await (conn or db).update("accounts", account)
return account return account
@ -171,21 +171,23 @@ async def get_account_by_username_or_email(
async def get_user(user_id: str, conn: Connection | None = None) -> User | None: async def get_user(user_id: str, conn: Connection | None = None) -> User | None:
account = await get_account(user_id, conn) async with db.reuse_conn(conn) if conn else db.connect() as conn:
if not account: account = await get_account(user_id, conn=conn)
return None if not account:
return await get_user_from_account(account, conn) return None
return await get_user_from_account(account, conn=conn)
async def get_user_from_account( async def get_user_from_account(
account: Account, conn: Connection | None = None account: Account, conn: Connection | None = None
) -> User | None: ) -> User | None:
extensions = await get_user_active_extensions_ids(account.id, conn=conn) async with db.reuse_conn(conn) if conn else db.connect() as conn:
wallets = await get_wallets(account.id, deleted=False, conn=conn) extensions = await get_user_active_extensions_ids(account.id, conn=conn)
wallets = await get_wallets(account.id, deleted=False, conn=conn)
if len(wallets) == 0: if len(wallets) == 0:
wallet = await create_wallet(user_id=account.id, conn=conn) wallet = await create_wallet(user_id=account.id, conn=conn)
wallets.append(wallet) wallets.append(wallet)
return User( return User(
id=account.id, id=account.id,
@ -205,9 +207,11 @@ async def get_user_from_account(
) )
async def update_user_access_control_list(user_acls: UserAcls): async def update_user_access_control_list(
user_acls: UserAcls, conn: Connection | None = None
):
user_acls.updated_at = datetime.now(timezone.utc) user_acls.updated_at = datetime.now(timezone.utc)
await db.update("accounts", user_acls) await (conn or db).update("accounts", user_acls)
async def get_user_access_control_lists( async def get_user_access_control_lists(

View file

@ -786,3 +786,63 @@ async def m038_add_labels_for_payments(db: Connection):
ALTER TABLE apipayments ADD COLUMN labels TEXT ALTER TABLE apipayments ADD COLUMN labels TEXT
""" """
) )
async def m039_index_payments(db: Connection):
indexes = [
"wallet_id",
"checking_id",
"payment_hash",
"amount",
"labels",
"time",
"status",
"created_at",
"updated_at",
]
for index in indexes:
logger.debug(f"Creating index idx_payments_{index}...")
await db.execute(
f"""
CREATE INDEX IF NOT EXISTS idx_payments_{index} ON apipayments ({index});
"""
)
async def m040_index_wallets(db: Connection):
indexes = [
"id",
"user",
"deleted",
"adminkey",
"inkey",
"wallet_type",
"created_at",
"updated_at",
]
for index in indexes:
logger.debug(f"Creating index idx_wallets_{index}...")
await db.execute(
f"""
CREATE INDEX IF NOT EXISTS idx_wallets_{index} ON wallets ("{index}");
"""
)
async def m042_index_accounts(db: Connection):
indexes = [
"id",
"email",
"username",
"pubkey",
"external_id",
]
for index in indexes:
logger.debug(f"Creating index idx_wallets_{index}...")
await db.execute(
f"""
CREATE INDEX IF NOT EXISTS idx_accounts_{index} ON accounts ("{index}");
"""
)

View file

@ -201,6 +201,10 @@ class Account(AccountId):
self.is_admin = settings.is_admin_user(self.id) self.is_admin = settings.is_admin_user(self.id)
self.fiat_providers = settings.get_fiat_providers_for_user(self.id) self.fiat_providers = settings.get_fiat_providers_for_user(self.id)
@property
def has_password(self) -> bool:
return self.password_hash is not None
def hash_password(self, password: str) -> str: def hash_password(self, password: str) -> str:
"""sets and returns the hashed password""" """sets and returns the hashed password"""
salt = gensalt() salt = gensalt()

View file

@ -16,6 +16,7 @@ from lnbits.core.crud.extensions import (
update_installed_extension, update_installed_extension,
) )
from lnbits.core.helpers import migrate_extension_database from lnbits.core.helpers import migrate_extension_database
from lnbits.db import Connection
from lnbits.settings import settings from lnbits.settings import settings
from ..models.extensions import Extension, ExtensionMeta, InstallableExtension from ..models.extensions import Extension, ExtensionMeta, InstallableExtension
@ -149,9 +150,9 @@ async def start_extension_background_work(ext_id: str) -> bool:
async def get_valid_extensions( async def get_valid_extensions(
include_deactivated: bool | None = True, include_deactivated: bool | None = True, conn: Connection | None = None
) -> list[Extension]: ) -> list[Extension]:
installed_extensions = await get_installed_extensions() installed_extensions = await get_installed_extensions(conn=conn)
valid_extensions = [Extension.from_installable_ext(e) for e in installed_extensions] valid_extensions = [Extension.from_installable_ext(e) for e in installed_extensions]
if include_deactivated: if include_deactivated:

View file

@ -67,12 +67,14 @@ async def pay_invoice(
if settings.lnbits_only_allow_incoming_payments: if settings.lnbits_only_allow_incoming_payments:
raise PaymentError("Only incoming payments allowed.", status="failed") raise PaymentError("Only incoming payments allowed.", status="failed")
invoice = _validate_payment_request(payment_request, max_sat) invoice = _validate_payment_request(payment_request, max_sat)
if not invoice.amount_msat: if not invoice.amount_msat:
raise ValueError("Missig invoice amount.") raise ValueError("Missig invoice amount.")
async with db.reuse_conn(conn) if conn else db.connect() as new_conn: async with db.reuse_conn(conn) if conn else db.connect() as new_conn:
amount_msat = invoice.amount_msat amount_msat = invoice.amount_msat
wallet = await _check_wallet_for_payment(wallet_id, tag, amount_msat, new_conn) wallet = await _check_wallet_for_payment(wallet_id, tag, amount_msat, new_conn)
if not wallet.can_send_payments: if not wallet.can_send_payments:
raise PaymentError( raise PaymentError(
"Wallet does not have permission to pay invoices.", "Wallet does not have permission to pay invoices.",
@ -95,10 +97,12 @@ async def pay_invoice(
labels=labels, labels=labels,
) )
payment = await _pay_invoice(wallet.source_wallet_id, create_payment_model, conn)
async with db.reuse_conn(conn) if conn else db.connect() as new_conn: async with db.reuse_conn(conn) if conn else db.connect() as new_conn:
await _credit_service_fee_wallet(wallet, payment, new_conn) payment = await _pay_invoice(
wallet.source_wallet_id, create_payment_model, conn=new_conn
)
await _credit_service_fee_wallet(wallet, payment, conn=new_conn)
return payment return payment
@ -197,20 +201,22 @@ async def create_wallet_invoice(wallet_id: str, data: CreateInvoice) -> Payment:
# do not save memo if description_hash or unhashed_description is set # do not save memo if description_hash or unhashed_description is set
memo = "" memo = ""
payment = await create_invoice( async with db.connect() as conn:
wallet_id=wallet_id, payment = await create_invoice(
amount=data.amount, wallet_id=wallet_id,
memo=memo, amount=data.amount,
currency=data.unit, memo=memo,
description_hash=description_hash, currency=data.unit,
unhashed_description=unhashed_description, description_hash=description_hash,
expiry=data.expiry, unhashed_description=unhashed_description,
extra=data.extra, expiry=data.expiry,
webhook=data.webhook, extra=data.extra,
internal=data.internal, webhook=data.webhook,
payment_hash=data.payment_hash, internal=data.internal,
labels=data.labels, payment_hash=data.payment_hash,
) labels=data.labels,
conn=conn,
)
if data.lnurl_withdraw: if data.lnurl_withdraw:
try: try:
@ -355,13 +361,15 @@ async def update_pending_payments(wallet_id: str):
await update_pending_payment(payment) await update_pending_payment(payment)
async def update_pending_payment(payment: Payment) -> Payment: async def update_pending_payment(
payment: Payment, conn: Connection | None = None
) -> Payment:
status = await payment.check_status() status = await payment.check_status()
if status.failed: if status.failed:
payment.status = PaymentState.FAILED payment.status = PaymentState.FAILED
await update_payment(payment) await update_payment(payment, conn=conn)
elif status.success: elif status.success:
payment = await update_payment_success_status(payment, status) payment = await update_payment_success_status(payment, status, conn=conn)
return payment return payment
@ -698,6 +706,7 @@ async def _pay_internal_invoice(
internal_payment = await check_internal( internal_payment = await check_internal(
create_payment_model.payment_hash, conn=conn create_payment_model.payment_hash, conn=conn
) )
if not internal_payment: if not internal_payment:
return None return None
@ -706,6 +715,7 @@ async def _pay_internal_invoice(
internal_invoice = await get_standalone_payment( internal_invoice = await get_standalone_payment(
internal_payment.checking_id, incoming=True, conn=conn internal_payment.checking_id, incoming=True, conn=conn
) )
if not internal_invoice: if not internal_invoice:
raise PaymentError("Internal payment not found.", status="failed") raise PaymentError("Internal payment not found.", status="failed")
@ -727,6 +737,7 @@ async def _pay_internal_invoice(
internal_id = f"internal_{create_payment_model.payment_hash}" internal_id = f"internal_{create_payment_model.payment_hash}"
logger.debug(f"creating temporary internal payment with id {internal_id}") logger.debug(f"creating temporary internal payment with id {internal_id}")
payment = await create_payment( payment = await create_payment(
checking_id=internal_id, checking_id=internal_id,
data=create_payment_model, data=create_payment_model,

View file

@ -3,7 +3,9 @@ from uuid import uuid4
from loguru import logger from loguru import logger
from lnbits.core.db import db
from lnbits.core.models.extensions import UserExtension from lnbits.core.models.extensions import UserExtension
from lnbits.db import Connection
from lnbits.settings import ( from lnbits.settings import (
EditableSettings, EditableSettings,
SuperSettings, SuperSettings,
@ -48,37 +50,43 @@ async def create_user_account_no_ckeck(
account: Account | None = None, account: Account | None = None,
wallet_name: str | None = None, wallet_name: str | None = None,
default_exts: list[str] | None = None, default_exts: list[str] | None = None,
conn: Connection | None = None,
) -> User: ) -> User:
async with db.reuse_conn(conn) if conn else db.connect() as conn:
if account:
account.validate_fields()
if account.username and await get_account_by_username(
account.username, conn=conn
):
raise ValueError("Username already exists.")
if account: if account.email and await get_account_by_email(account.email, conn=conn):
account.validate_fields() raise ValueError("Email already exists.")
if account.username and await get_account_by_username(account.username):
raise ValueError("Username already exists.")
if account.email and await get_account_by_email(account.email): if account.pubkey and await get_account_by_pubkey(
raise ValueError("Email already exists.") account.pubkey, conn=conn
):
raise ValueError("Pubkey already exists.")
if account.pubkey and await get_account_by_pubkey(account.pubkey): if not account.id:
raise ValueError("Pubkey already exists.") account.id = uuid4().hex
if not account.id: account = await create_account(account, conn=conn)
account.id = uuid4().hex await create_wallet(
user_id=account.id,
wallet_name=wallet_name or settings.lnbits_default_wallet_name,
conn=conn,
)
account = await create_account(account) user_extensions = (default_exts or []) + settings.lnbits_user_default_extensions
await create_wallet( for ext_id in user_extensions:
user_id=account.id, try:
wallet_name=wallet_name or settings.lnbits_default_wallet_name, user_ext = UserExtension(user=account.id, extension=ext_id, active=True)
) await create_user_extension(user_ext, conn=conn)
except Exception as e:
logger.error(f"Error enabeling default extension {ext_id}: {e}")
user_extensions = (default_exts or []) + settings.lnbits_user_default_extensions user = await get_user_from_account(account, conn=conn)
for ext_id in user_extensions:
try:
user_ext = UserExtension(user=account.id, extension=ext_id, active=True)
await create_user_extension(user_ext)
except Exception as e:
logger.error(f"Error enabeling default extension {ext_id}: {e}")
user = await get_user_from_account(account)
if not user: if not user:
raise ValueError("Cannot find user for account.") raise ValueError("Cannot find user for account.")

View file

@ -8,8 +8,8 @@ from urllib.parse import urlparse
from fastapi import APIRouter, Depends, File from fastapi import APIRouter, Depends, File
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from lnbits.core.models import User
from lnbits.core.models.notifications import NotificationType from lnbits.core.models.notifications import NotificationType
from lnbits.core.models.users import Account
from lnbits.core.services import ( from lnbits.core.services import (
enqueue_admin_notification, enqueue_admin_notification,
get_balance_delta, get_balance_delta,
@ -67,9 +67,9 @@ async def api_test_email():
@admin_router.get("/api/v1/settings") @admin_router.get("/api/v1/settings")
async def api_get_settings( async def api_get_settings(
user: User = Depends(check_admin), account: Account = Depends(check_admin),
) -> AdminSettings | None: ) -> AdminSettings | None:
admin_settings = await get_admin_settings(user.super_user) admin_settings = await get_admin_settings(account.is_super_user)
return admin_settings return admin_settings
@ -77,12 +77,14 @@ async def api_get_settings(
"/api/v1/settings", "/api/v1/settings",
status_code=HTTPStatus.OK, status_code=HTTPStatus.OK,
) )
async def api_update_settings(data: UpdateSettings, user: User = Depends(check_admin)): async def api_update_settings(
data: UpdateSettings, account: Account = Depends(check_admin)
):
enqueue_admin_notification( enqueue_admin_notification(
NotificationType.settings_update, {"username": user.username} NotificationType.settings_update, {"username": account.username}
) )
await update_admin_settings(data) await update_admin_settings(data)
admin_settings = await get_admin_settings(user.super_user) admin_settings = await get_admin_settings(account.is_super_user)
if not admin_settings: if not admin_settings:
raise ValueError("Updated admin settings not found.") raise ValueError("Updated admin settings not found.")
update_cached_settings(admin_settings.dict()) update_cached_settings(admin_settings.dict())
@ -94,9 +96,11 @@ async def api_update_settings(data: UpdateSettings, user: User = Depends(check_a
"/api/v1/settings", "/api/v1/settings",
status_code=HTTPStatus.OK, status_code=HTTPStatus.OK,
) )
async def api_update_settings_partial(data: dict, user: User = Depends(check_admin)): async def api_update_settings_partial(
data: dict, account: Account = Depends(check_admin)
):
updatable_settings = dict_to_settings({**settings.dict(), **data}) updatable_settings = dict_to_settings({**settings.dict(), **data})
return await api_update_settings(updatable_settings, user) return await api_update_settings(updatable_settings, account)
@admin_router.get( @admin_router.get(
@ -110,9 +114,9 @@ async def api_reset_settings(field_name: str):
@admin_router.delete("/api/v1/settings", status_code=HTTPStatus.OK) @admin_router.delete("/api/v1/settings", status_code=HTTPStatus.OK)
async def api_delete_settings(user: User = Depends(check_super_user)) -> None: async def api_delete_settings(account: Account = Depends(check_super_user)) -> None:
enqueue_admin_notification( enqueue_admin_notification(
NotificationType.settings_update, {"username": user.username} NotificationType.settings_update, {"username": account.username}
) )
await reset_core_settings() await reset_core_settings()
server_restart.set() server_restart.set()

View file

@ -7,9 +7,10 @@ from fastapi import APIRouter, Depends, HTTPException
from loguru import logger from loguru import logger
from lnbits.core.crud.extensions import get_user_extensions from lnbits.core.crud.extensions import get_user_extensions
from lnbits.core.crud.wallets import get_wallets_ids
from lnbits.core.db import db
from lnbits.core.models import ( from lnbits.core.models import (
SimpleStatus, SimpleStatus,
User,
) )
from lnbits.core.models.extensions import ( from lnbits.core.models.extensions import (
CreateExtension, CreateExtension,
@ -23,7 +24,7 @@ from lnbits.core.models.extensions import (
UserExtension, UserExtension,
UserExtensionInfo, UserExtensionInfo,
) )
from lnbits.core.models.users import AccountId from lnbits.core.models.users import Account, AccountId
from lnbits.core.services import check_transaction_status, create_invoice from lnbits.core.services import check_transaction_status, create_invoice
from lnbits.core.services.extensions import ( from lnbits.core.services.extensions import (
activate_extension, activate_extension,
@ -144,9 +145,10 @@ async def api_extension_details(
async def api_update_pay_to_enable( async def api_update_pay_to_enable(
ext_id: str, ext_id: str,
data: PayToEnableInfo, data: PayToEnableInfo,
user: User = Depends(check_admin), account: Account = Depends(check_admin),
) -> SimpleStatus: ) -> SimpleStatus:
if data.wallet not in user.wallet_ids: user_wallet_ids = await get_wallets_ids(account.id, deleted=False)
if data.wallet not in user_wallet_ids:
raise HTTPException( raise HTTPException(
HTTPStatus.BAD_REQUEST, "Wallet does not belong to this admin user." HTTPStatus.BAD_REQUEST, "Wallet does not belong to this admin user."
) )
@ -462,15 +464,16 @@ async def get_extension_release(org: str, repo: str, tag_name: str):
async def api_get_user_extensions( async def api_get_user_extensions(
account_id: AccountId = Depends(check_account_id_exists), account_id: AccountId = Depends(check_account_id_exists),
) -> list[Extension]: ) -> list[Extension]:
async with db.connect() as conn:
user_extensions_ids = [ user_extensions_ids = [
ue.extension for ue in await get_user_extensions(account_id.id) ue.extension for ue in await get_user_extensions(account_id.id, conn=conn)
] ]
return [ valid_extensions = [
ext ext
for ext in await get_valid_extensions(False) for ext in await get_valid_extensions(False, conn=conn)
if ext.code in user_extensions_ids if ext.code in user_extensions_ids
] ]
return valid_extensions
@extension_router.delete( @extension_router.delete(
@ -505,7 +508,16 @@ async def delete_extension_db(ext_id: str):
# TODO: create a response model for this # TODO: create a response model for this
@extension_router.get("/all") @extension_router.get("/all")
async def extensions(account_id: AccountId = Depends(check_account_id_exists)): async def extensions(account_id: AccountId = Depends(check_account_id_exists)):
installed_exts: list[InstallableExtension] = await get_installed_extensions() async with db.connect() as conn:
installed_exts: list[InstallableExtension] = await get_installed_extensions(
conn=conn
)
all_ext_ids = [ext.code for ext in await get_valid_extensions(conn=conn)]
inactive_extensions = [
e.id for e in await get_installed_extensions(active=False, conn=conn)
]
db_versions = await get_db_versions(conn=conn)
installed_exts_ids = [e.id for e in installed_exts] installed_exts_ids = [e.id for e in installed_exts]
installable_exts = await InstallableExtension.get_installable_extensions( installable_exts = await InstallableExtension.get_installable_extensions(
@ -536,10 +548,6 @@ async def extensions(account_id: AccountId = Depends(check_account_id_exists)):
e.short_description = installed_ext.short_description e.short_description = installed_ext.short_description
e.icon = installed_ext.icon e.icon = installed_ext.icon
all_ext_ids = [ext.code for ext in await get_valid_extensions()]
inactive_extensions = [e.id for e in await get_installed_extensions(active=False)]
db_versions = await get_db_versions()
extension_data = [ extension_data = [
{ {
"id": ext.id, "id": ext.id,

View file

@ -7,7 +7,6 @@ from fastapi.responses import FileResponse
from lnbits.core.models import ( from lnbits.core.models import (
SimpleStatus, SimpleStatus,
User,
) )
from lnbits.core.models.extensions import ( from lnbits.core.models.extensions import (
Extension, Extension,
@ -16,7 +15,7 @@ from lnbits.core.models.extensions import (
UserExtension, UserExtension,
) )
from lnbits.core.models.extensions_builder import ExtensionData from lnbits.core.models.extensions_builder import ExtensionData
from lnbits.core.models.users import AccountId from lnbits.core.models.users import Account, AccountId
from lnbits.core.services.extensions import ( from lnbits.core.services.extensions import (
activate_extension, activate_extension,
install_extension, install_extension,
@ -85,9 +84,9 @@ async def api_build_extension(data: ExtensionData) -> FileResponse:
) )
async def api_deploy_extension( async def api_deploy_extension(
data: ExtensionData, data: ExtensionData,
user: User = Depends(check_admin), account: Account = Depends(check_admin),
) -> SimpleStatus: ) -> SimpleStatus:
working_dir_name = "deploy_" + sha256(user.id.encode("utf-8")).hexdigest() working_dir_name = "deploy_" + sha256(account.id.encode("utf-8")).hexdigest()
stub_ext_id = "extension_builder_stub" stub_ext_id = "extension_builder_stub"
release, build_dir = await build_extension_from_data( release, build_dir = await build_extension_from_data(
data, stub_ext_id, working_dir_name data, stub_ext_id, working_dir_name
@ -111,9 +110,9 @@ async def api_deploy_extension(
await activate_extension(Extension.from_installable_ext(ext_info)) await activate_extension(Extension.from_installable_ext(ext_info))
user_ext = await get_user_extension(user.id, data.id) user_ext = await get_user_extension(account.id, data.id)
if not user_ext: if not user_ext:
user_ext = UserExtension(user=user.id, extension=data.id, active=True) user_ext = UserExtension(user=account.id, extension=data.id, active=True)
await create_user_extension(user_ext) await create_user_extension(user_ext)
elif not user_ext.active: elif not user_ext.active:
user_ext.active = True user_ext.active = True

View file

@ -18,6 +18,7 @@ from lnbits.core.crud.payments import (
update_payment, update_payment,
) )
from lnbits.core.crud.users import get_account from lnbits.core.crud.users import get_account
from lnbits.core.db import db
from lnbits.core.models import ( from lnbits.core.models import (
CancelInvoice, CancelInvoice,
CreateInvoice, CreateInvoice,
@ -69,7 +70,6 @@ from ..services import (
perform_withdraw, perform_withdraw,
settle_hold_invoice, settle_hold_invoice,
update_pending_payment, update_pending_payment,
update_pending_payments,
) )
payment_router = APIRouter(prefix="/api/v1/payments", tags=["Payments"]) payment_router = APIRouter(prefix="/api/v1/payments", tags=["Payments"])
@ -87,7 +87,6 @@ async def api_payments(
key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key), key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
filters: Filters = Depends(parse_filters(PaymentFilters)), filters: Filters = Depends(parse_filters(PaymentFilters)),
): ):
await update_pending_payments(key_info.wallet.id)
return await get_payments( return await get_payments(
wallet_id=key_info.wallet.id, wallet_id=key_info.wallet.id,
pending=True, pending=True,
@ -107,7 +106,6 @@ async def api_payments_history(
group: DateTrunc = Query("day"), group: DateTrunc = Query("day"),
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)), filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
): ):
await update_pending_payments(key_info.wallet.id)
return await get_payments_history(key_info.wallet.id, group, filters) return await get_payments_history(key_info.wallet.id, group, filters)
@ -186,18 +184,24 @@ async def api_payments_paginated(
), ),
filters: Filters = Depends(parse_filters(PaymentFilters)), filters: Filters = Depends(parse_filters(PaymentFilters)),
) -> Page[Payment]: ) -> Page[Payment]:
page = await get_payments_paginated( async with db.connect() as conn:
wallet_id=key_info.wallet.id, page = await get_payments_paginated(
filters=filters, wallet_id=key_info.wallet.id,
) filters=filters,
if not recheck_pending: conn=conn,
return page )
if not recheck_pending:
return page
for payment in page.data: payments = []
if payment.pending: for payment in page.data:
await update_pending_payment(payment) if payment.pending:
refreshed_payment = await update_pending_payment(payment, conn=conn)
payments.append(refreshed_payment)
else:
payments.append(payment)
return page return Page(data=payments, total=page.total)
@payment_router.get( @payment_router.get(
@ -219,10 +223,10 @@ async def api_all_payments_paginated(
# regular user can only see payments from their wallets # regular user can only see payments from their wallets
for_user_id = account_id.id for_user_id = account_id.id
return await get_payments_paginated( async with db.connect() as conn:
filters=filters, return await get_payments_paginated(
user_id=for_user_id, filters=filters, user_id=for_user_id, conn=conn
) )
@payment_router.post( @payment_router.post(

View file

@ -115,12 +115,12 @@ async def api_create_user(data: CreateUser) -> CreateUser:
@users_router.put("/user/{user_id}", name="Update user") @users_router.put("/user/{user_id}", name="Update user")
async def api_update_user( async def api_update_user(
user_id: str, data: CreateUser, user: User = Depends(check_admin) user_id: str, data: CreateUser, account: Account = Depends(check_admin)
) -> CreateUser: ) -> CreateUser:
if user_id != data.id: if user_id != data.id:
raise HTTPException(HTTPStatus.BAD_REQUEST, "User Id missmatch.") raise HTTPException(HTTPStatus.BAD_REQUEST, "User Id missmatch.")
if user_id == settings.super_user and user.id != settings.super_user: if user_id == settings.super_user and account.id != settings.super_user:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail="Action only allowed for super user.", detail="Action only allowed for super user.",
@ -154,7 +154,7 @@ async def api_update_user(
name="Delete user by Id", name="Delete user by Id",
) )
async def api_users_delete_user( async def api_users_delete_user(
user_id: str, user: User = Depends(check_admin) user_id: str, account: Account = Depends(check_admin)
) -> SimpleStatus: ) -> SimpleStatus:
wallets = await get_wallets(user_id, deleted=False) wallets = await get_wallets(user_id, deleted=False)
if len(wallets) > 0: if len(wallets) > 0:
@ -169,7 +169,7 @@ async def api_users_delete_user(
detail="Cannot delete super user.", detail="Cannot delete super user.",
) )
if user_id in settings.lnbits_admin_users and not user.super_user: if user_id in settings.lnbits_admin_users and not account.is_super_user:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail="Only super_user can delete admin user.", detail="Only super_user can delete admin user.",
@ -295,7 +295,7 @@ async def api_users_delete_all_user_wallet(user_id: str) -> SimpleStatus:
"The second time it is called will delete the entry from the DB", "The second time it is called will delete the entry from the DB",
) )
async def api_users_delete_user_wallet( async def api_users_delete_user_wallet(
user_id: str, wallet: str, user: User = Depends(check_admin) user_id: str, wallet: str, account: Account = Depends(check_admin)
) -> SimpleStatus: ) -> SimpleStatus:
wal = await get_wallet(wallet) wal = await get_wallet(wallet)
if not wal: if not wal:
@ -304,7 +304,7 @@ async def api_users_delete_user_wallet(
detail="Wallet does not exist.", detail="Wallet does not exist.",
) )
if user_id == settings.super_user and user.id != settings.super_user: if user_id == settings.super_user and account.id != settings.super_user:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail="Action only allowed for super user.", detail="Action only allowed for super user.",

View file

@ -20,6 +20,7 @@ from lnbits.core.crud import (
) )
from lnbits.core.crud.users import get_user_access_control_lists from lnbits.core.crud.users import get_user_access_control_lists
from lnbits.core.crud.wallets import get_base_wallet_for_key from lnbits.core.crud.wallets import get_base_wallet_for_key
from lnbits.core.db import db
from lnbits.core.models import ( from lnbits.core.models import (
AccessTokenPayload, AccessTokenPayload,
Account, Account,
@ -119,16 +120,17 @@ class KeyChecker(BaseKeyChecker):
async def __call__(self, request: Request) -> WalletTypeInfo: async def __call__(self, request: Request) -> WalletTypeInfo:
key_value = self._extract_key_value(request) key_value = self._extract_key_value(request)
wallet = await get_wallet_for_key(key_value) async with db.connect() as conn:
wallet = await get_wallet_for_key(key_value, conn=conn)
if not wallet: if not wallet:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, status_code=HTTPStatus.NOT_FOUND,
detail="Wallet not found.", detail="Wallet not found.",
) )
request.scope["user_id"] = wallet.user request.scope["user_id"] = wallet.user
await _check_user_access(request, wallet.user) await _check_user_access(request, wallet.user, conn=conn)
key_type = await self._extract_key_type(key_value, wallet) key_type = await self._extract_key_type(key_value, wallet)
return WalletTypeInfo(key_type, wallet) return WalletTypeInfo(key_type, wallet)
@ -148,22 +150,23 @@ class LightKeyChecker(BaseKeyChecker):
cache_key = f"auth:x-api-key:{key_value}" cache_key = f"auth:x-api-key:{key_value}"
cache_time = settings.auth_authentication_cache_minutes * 60 cache_time = settings.auth_authentication_cache_minutes * 60
if cache_time > 0: async with db.connect() as conn:
key_info: BaseWalletTypeInfo | None = cache.get(cache_key) if cache_time > 0:
if key_info: key_info: BaseWalletTypeInfo | None = cache.get(cache_key)
request.scope["user_id"] = key_info.wallet.user if key_info:
await _check_user_access(request, key_info.wallet.user) request.scope["user_id"] = key_info.wallet.user
return key_info await _check_user_access(request, key_info.wallet.user, conn=conn)
return key_info
wallet = await get_base_wallet_for_key(key_value) wallet = await get_base_wallet_for_key(key_value, conn=conn)
if not wallet: if not wallet:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, status_code=HTTPStatus.NOT_FOUND,
detail="Wallet not found.", detail="Wallet not found.",
) )
request.scope["user_id"] = wallet.user request.scope["user_id"] = wallet.user
await _check_user_access(request, wallet.user) await _check_user_access(request, wallet.user, conn=conn)
key_type = await self._extract_key_type(key_value, wallet) key_type = await self._extract_key_type(key_value, wallet)
key_info = BaseWalletTypeInfo(key_type, wallet) key_info = BaseWalletTypeInfo(key_type, wallet)
@ -241,15 +244,16 @@ async def check_account_id_exists(
elif usr: elif usr:
cache_key = f"auth:user_id:{sha256s(usr.hex)}" cache_key = f"auth:user_id:{sha256s(usr.hex)}"
if cache_key and settings.auth_authentication_cache_minutes > 0: async with db.connect() as conn:
account_id = cache.get(cache_key) if cache_key and settings.auth_authentication_cache_minutes > 0:
if account_id: account_id = cache.get(cache_key)
r.scope["user_id"] = account_id.id if account_id:
await _check_user_access(r, account_id) r.scope["user_id"] = account_id.id
return account_id await _check_user_access(r, account_id, conn=conn)
return account_id
account = await check_account_exists(r, access_token, usr) account = await _check_account_exists(r, access_token, usr, conn=conn)
account_id = AccountId(id=account.id) account_id = AccountId(id=account.id)
if cache_key and settings.auth_authentication_cache_minutes > 0: if cache_key and settings.auth_authentication_cache_minutes > 0:
cache.set( cache.set(
@ -265,6 +269,15 @@ async def check_account_exists(
r: Request, r: Request,
access_token: Annotated[str | None, Depends(check_access_token)], access_token: Annotated[str | None, Depends(check_access_token)],
usr: UUID4 | None = None, usr: UUID4 | None = None,
) -> Account:
return await _check_account_exists(r, access_token, usr)
async def _check_account_exists(
r: Request,
access_token: Annotated[str | None, Depends(check_access_token)],
usr: UUID4 | None = None,
conn: Connection | None = None,
) -> Account: ) -> Account:
""" """
Check that the account exists based on access token or user id. Check that the account exists based on access token or user id.
@ -273,22 +286,27 @@ async def check_account_exists(
- does not fetch the user wallets - does not fetch the user wallets
- caches the account info based on settings cache time - caches the account info based on settings cache time
""" """
if access_token: async with db.reuse_conn(conn) if conn else db.connect() as new_conn:
account = await _get_account_from_token(access_token, r["path"], r["method"]) if access_token:
elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only): account = await _get_account_from_token(
account = await get_account(usr.hex) access_token, r["path"], r["method"], conn=new_conn
if account and account.is_admin: )
raise HTTPException( elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only):
HTTPStatus.FORBIDDEN, "User id only access for admins is forbidden." account = await get_account(usr.hex, conn=new_conn)
if account and account.is_admin:
raise HTTPException(
HTTPStatus.FORBIDDEN, "User id only access for admins is forbidden."
)
else:
raise HTTPException(
HTTPStatus.UNAUTHORIZED, "Missing user ID or access token."
) )
else:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Missing user ID or access token.")
if not account: if not account:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.") raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
r.scope["user_id"] = account.id r.scope["user_id"] = account.id
await _check_user_access(r, account.id) await _check_user_access(r, account.id, conn=new_conn)
return account return account
@ -298,8 +316,9 @@ async def check_user_exists(
access_token: Annotated[str | None, Depends(check_access_token)], access_token: Annotated[str | None, Depends(check_access_token)],
usr: UUID4 | None = None, usr: UUID4 | None = None,
) -> User: ) -> User:
account = await check_account_exists(r, access_token, usr) async with db.connect() as conn:
user = await get_user_from_account(account) account = await _check_account_exists(r, access_token, usr, conn=conn)
user = await get_user_from_account(account, conn=conn)
if not user: if not user:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.") raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
@ -330,29 +349,36 @@ async def access_token_payload(
return AccessTokenPayload(**payload) return AccessTokenPayload(**payload)
async def check_admin(user: Annotated[User, Depends(check_user_exists)]) -> User: async def check_admin(
if user.id != settings.super_user and user.id not in settings.lnbits_admin_users: account: Annotated[Account, Depends(check_account_exists)],
) -> Account:
if (
account.id != settings.super_user
and account.id not in settings.lnbits_admin_users
):
raise HTTPException( raise HTTPException(
HTTPStatus.FORBIDDEN, "User not authorized. No admin privileges." HTTPStatus.FORBIDDEN, "User not authorized. No admin privileges."
) )
if not user.has_password: if not account.has_password:
raise HTTPException( raise HTTPException(
HTTPStatus.FORBIDDEN, "Admin users must have credentials configured." HTTPStatus.FORBIDDEN, "Admin users must have credentials configured."
) )
return user return account
async def check_super_user(user: Annotated[User, Depends(check_user_exists)]) -> User: async def check_super_user(
if user.id != settings.super_user: account: Annotated[Account, Depends(check_admin)],
) -> Account:
if account.id != settings.super_user:
raise HTTPException( raise HTTPException(
HTTPStatus.FORBIDDEN, "User not authorized. No super user privileges." HTTPStatus.FORBIDDEN, "User not authorized. No super user privileges."
) )
if not user.has_password: if not account.has_password:
raise HTTPException( raise HTTPException(
HTTPStatus.FORBIDDEN, "Super user must have credentials configured." HTTPStatus.FORBIDDEN, "Super user must have credentials configured."
) )
return user return account
def parse_filters(model: type[TFilterModel]): def parse_filters(model: type[TFilterModel]):
@ -412,15 +438,17 @@ async def check_user_extension_access(
return SimpleStatus(success=True, message="OK") return SimpleStatus(success=True, message="OK")
async def _check_user_access(r: Request, user_id: str): async def _check_user_access(r: Request, user_id: str, conn: Connection | None = None):
if not settings.is_user_allowed(user_id): if not settings.is_user_allowed(user_id):
raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.") raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.")
await _check_user_extension_access(user_id, r["path"]) await _check_user_extension_access(user_id, r["path"], conn=conn)
async def _check_user_extension_access(user_id: str, path: str): async def _check_user_extension_access(
user_id: str, path: str, conn: Connection | None = None
):
ext_id = path_segments(path)[0] ext_id = path_segments(path)[0]
status = await check_user_extension_access(user_id, ext_id) status = await check_user_extension_access(user_id, ext_id, conn=conn)
if not status.success: if not status.success:
raise HTTPException( raise HTTPException(
HTTPStatus.FORBIDDEN, HTTPStatus.FORBIDDEN,
@ -429,12 +457,12 @@ async def _check_user_extension_access(user_id: str, path: str):
async def _get_account_from_token( async def _get_account_from_token(
access_token: str, path: str, method: str access_token: str, path: str, method: str, conn: Connection | None = None
) -> Account | None: ) -> Account | None:
try: try:
payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"]) payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"])
return await _get_account_from_jwt_payload( return await _get_account_from_jwt_payload(
AccessTokenPayload(**payload), path, method AccessTokenPayload(**payload), path, method, conn=conn
) )
except jwt.ExpiredSignatureError as exc: except jwt.ExpiredSignatureError as exc:
@ -447,33 +475,35 @@ async def _get_account_from_token(
async def _get_account_from_jwt_payload( async def _get_account_from_jwt_payload(
payload: AccessTokenPayload, path: str, method: str payload: AccessTokenPayload, path: str, method: str, conn: Connection | None = None
) -> Account | None: ) -> Account | None:
account = None account = None
if payload.sub: if payload.sub:
account = await get_account_by_username(payload.sub) account = await get_account_by_username(payload.sub, conn=conn)
elif payload.usr: elif payload.usr:
account = await get_account(payload.usr) account = await get_account(payload.usr, conn=conn)
elif payload.email: elif payload.email:
account = await get_account_by_email(payload.email) account = await get_account_by_email(payload.email, conn=conn)
if not account: if not account:
return None return None
if payload.api_token_id: if payload.api_token_id:
await _check_account_api_access(account.id, payload.api_token_id, path, method) await _check_account_api_access(
account.id, payload.api_token_id, path, method, conn=conn
)
return account return account
async def _check_account_api_access( async def _check_account_api_access(
user_id: str, token_id: str, path: str, method: str user_id: str, token_id: str, path: str, method: str, conn: Connection | None = None
): ):
segments = path.split("/") segments = path.split("/")
if len(segments) < 3: if len(segments) < 3:
raise HTTPException(HTTPStatus.FORBIDDEN, "Not an API endpoint.") raise HTTPException(HTTPStatus.FORBIDDEN, "Not an API endpoint.")
acls = await get_user_access_control_lists(user_id) acls = await get_user_access_control_lists(user_id, conn=conn)
acl = acls.get_acl_by_token_id(token_id) acl = acls.get_acl_by_token_id(token_id)
if not acl: if not acl:
raise HTTPException(HTTPStatus.FORBIDDEN, "Invalid token id.") raise HTTPException(HTTPStatus.FORBIDDEN, "Invalid token id.")