[perf] reuse connection (#3624)
This commit is contained in:
parent
71e0b396d2
commit
5d79327906
13 changed files with 322 additions and 189 deletions
|
|
@ -33,9 +33,9 @@ async def get_payment(checking_id: str, conn: Connection | None = None) -> Payme
|
||||||
|
|
||||||
async def get_standalone_payment(
|
async def get_standalone_payment(
|
||||||
checking_id_or_hash: str,
|
checking_id_or_hash: str,
|
||||||
conn: Connection | None = None,
|
|
||||||
incoming: bool | None = False,
|
incoming: bool | None = False,
|
||||||
wallet_id: str | None = None,
|
wallet_id: str | None = None,
|
||||||
|
conn: Connection | None = None,
|
||||||
) -> Payment | None:
|
) -> Payment | None:
|
||||||
clause: str = "checking_id = :checking_id OR payment_hash = :hash"
|
clause: str = "checking_id = :checking_id OR payment_hash = :hash"
|
||||||
values = {
|
values = {
|
||||||
|
|
@ -46,7 +46,7 @@ async def get_standalone_payment(
|
||||||
clause = f"({clause}) AND amount > 0"
|
clause = f"({clause}) AND amount > 0"
|
||||||
|
|
||||||
if wallet_id:
|
if wallet_id:
|
||||||
wallet = await get_wallet(wallet_id)
|
wallet = await get_wallet(wallet_id, conn=conn)
|
||||||
if not wallet or not wallet.can_view_payments:
|
if not wallet or not wallet.can_view_payments:
|
||||||
return None
|
return None
|
||||||
values["wallet_id"] = wallet.source_wallet_id
|
values["wallet_id"] = wallet.source_wallet_id
|
||||||
|
|
@ -69,7 +69,7 @@ async def get_standalone_payment(
|
||||||
async def get_wallet_payment(
|
async def get_wallet_payment(
|
||||||
wallet_id: str, payment_hash: str, conn: Connection | None = None
|
wallet_id: str, payment_hash: str, conn: Connection | None = None
|
||||||
) -> Payment | None:
|
) -> Payment | None:
|
||||||
wallet = await get_wallet(wallet_id)
|
wallet = await get_wallet(wallet_id, conn=conn)
|
||||||
if not wallet or not wallet.can_view_payments:
|
if not wallet or not wallet.can_view_payments:
|
||||||
return None
|
return None
|
||||||
payment = await (conn or db).fetchone(
|
payment = await (conn or db).fetchone(
|
||||||
|
|
@ -124,7 +124,6 @@ async def get_payments_paginated( # noqa: C901
|
||||||
Filters payments to be returned by:
|
Filters payments to be returned by:
|
||||||
- complete | pending | failed | outgoing | incoming.
|
- complete | pending | failed | outgoing | incoming.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
values: dict[str, Any] = {
|
values: dict[str, Any] = {
|
||||||
"time": since,
|
"time": since,
|
||||||
}
|
}
|
||||||
|
|
@ -134,7 +133,7 @@ async def get_payments_paginated( # noqa: C901
|
||||||
clause.append(f"time > {db.timestamp_placeholder('time')}")
|
clause.append(f"time > {db.timestamp_placeholder('time')}")
|
||||||
|
|
||||||
if wallet_id:
|
if wallet_id:
|
||||||
wallet = await get_wallet(wallet_id)
|
wallet = await get_wallet(wallet_id, conn=conn)
|
||||||
if not wallet or not wallet.can_view_payments:
|
if not wallet or not wallet.can_view_payments:
|
||||||
return Page(data=[], total=0)
|
return Page(data=[], total=0)
|
||||||
|
|
||||||
|
|
@ -326,6 +325,7 @@ async def get_payments_history(
|
||||||
wallet_id: str | None = None,
|
wallet_id: str | None = None,
|
||||||
group: DateTrunc = "day",
|
group: DateTrunc = "day",
|
||||||
filters: Filters | None = None,
|
filters: Filters | None = None,
|
||||||
|
conn: Connection | None = None,
|
||||||
) -> list[PaymentHistoryPoint]:
|
) -> list[PaymentHistoryPoint]:
|
||||||
if not filters:
|
if not filters:
|
||||||
filters = Filters()
|
filters = Filters()
|
||||||
|
|
@ -361,13 +361,13 @@ async def get_payments_history(
|
||||||
filters.values(values),
|
filters.values(values),
|
||||||
)
|
)
|
||||||
if wallet_id:
|
if wallet_id:
|
||||||
wallet = await get_wallet(wallet_id)
|
wallet = await get_wallet(wallet_id, conn=conn)
|
||||||
if not wallet or not wallet.can_view_payments:
|
if not wallet or not wallet.can_view_payments:
|
||||||
return []
|
return []
|
||||||
balance = wallet.balance_msat
|
balance = wallet.balance_msat
|
||||||
values["wallet_id"] = wallet.source_wallet_id
|
values["wallet_id"] = wallet.source_wallet_id
|
||||||
else:
|
else:
|
||||||
balance = await get_total_balance()
|
balance = await get_total_balance(conn=conn)
|
||||||
|
|
||||||
# since we dont know the balance at the starting point,
|
# since we dont know the balance at the starting point,
|
||||||
# we take the current balance and walk backwards
|
# we take the current balance and walk backwards
|
||||||
|
|
|
||||||
|
|
@ -30,9 +30,9 @@ async def create_account(
|
||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
||||||
async def update_account(account: Account) -> Account:
|
async def update_account(account: Account, conn: Connection | None = None) -> Account:
|
||||||
account.updated_at = datetime.now(timezone.utc)
|
account.updated_at = datetime.now(timezone.utc)
|
||||||
await db.update("accounts", account)
|
await (conn or db).update("accounts", account)
|
||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -171,21 +171,23 @@ async def get_account_by_username_or_email(
|
||||||
|
|
||||||
|
|
||||||
async def get_user(user_id: str, conn: Connection | None = None) -> User | None:
|
async def get_user(user_id: str, conn: Connection | None = None) -> User | None:
|
||||||
account = await get_account(user_id, conn)
|
async with db.reuse_conn(conn) if conn else db.connect() as conn:
|
||||||
if not account:
|
account = await get_account(user_id, conn=conn)
|
||||||
return None
|
if not account:
|
||||||
return await get_user_from_account(account, conn)
|
return None
|
||||||
|
return await get_user_from_account(account, conn=conn)
|
||||||
|
|
||||||
|
|
||||||
async def get_user_from_account(
|
async def get_user_from_account(
|
||||||
account: Account, conn: Connection | None = None
|
account: Account, conn: Connection | None = None
|
||||||
) -> User | None:
|
) -> User | None:
|
||||||
extensions = await get_user_active_extensions_ids(account.id, conn=conn)
|
async with db.reuse_conn(conn) if conn else db.connect() as conn:
|
||||||
wallets = await get_wallets(account.id, deleted=False, conn=conn)
|
extensions = await get_user_active_extensions_ids(account.id, conn=conn)
|
||||||
|
wallets = await get_wallets(account.id, deleted=False, conn=conn)
|
||||||
|
|
||||||
if len(wallets) == 0:
|
if len(wallets) == 0:
|
||||||
wallet = await create_wallet(user_id=account.id, conn=conn)
|
wallet = await create_wallet(user_id=account.id, conn=conn)
|
||||||
wallets.append(wallet)
|
wallets.append(wallet)
|
||||||
|
|
||||||
return User(
|
return User(
|
||||||
id=account.id,
|
id=account.id,
|
||||||
|
|
@ -205,9 +207,11 @@ async def get_user_from_account(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def update_user_access_control_list(user_acls: UserAcls):
|
async def update_user_access_control_list(
|
||||||
|
user_acls: UserAcls, conn: Connection | None = None
|
||||||
|
):
|
||||||
user_acls.updated_at = datetime.now(timezone.utc)
|
user_acls.updated_at = datetime.now(timezone.utc)
|
||||||
await db.update("accounts", user_acls)
|
await (conn or db).update("accounts", user_acls)
|
||||||
|
|
||||||
|
|
||||||
async def get_user_access_control_lists(
|
async def get_user_access_control_lists(
|
||||||
|
|
|
||||||
|
|
@ -786,3 +786,63 @@ async def m038_add_labels_for_payments(db: Connection):
|
||||||
ALTER TABLE apipayments ADD COLUMN labels TEXT
|
ALTER TABLE apipayments ADD COLUMN labels TEXT
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def m039_index_payments(db: Connection):
|
||||||
|
indexes = [
|
||||||
|
"wallet_id",
|
||||||
|
"checking_id",
|
||||||
|
"payment_hash",
|
||||||
|
"amount",
|
||||||
|
"labels",
|
||||||
|
"time",
|
||||||
|
"status",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
for index in indexes:
|
||||||
|
logger.debug(f"Creating index idx_payments_{index}...")
|
||||||
|
await db.execute(
|
||||||
|
f"""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_payments_{index} ON apipayments ({index});
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def m040_index_wallets(db: Connection):
|
||||||
|
indexes = [
|
||||||
|
"id",
|
||||||
|
"user",
|
||||||
|
"deleted",
|
||||||
|
"adminkey",
|
||||||
|
"inkey",
|
||||||
|
"wallet_type",
|
||||||
|
"created_at",
|
||||||
|
"updated_at",
|
||||||
|
]
|
||||||
|
|
||||||
|
for index in indexes:
|
||||||
|
logger.debug(f"Creating index idx_wallets_{index}...")
|
||||||
|
await db.execute(
|
||||||
|
f"""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_wallets_{index} ON wallets ("{index}");
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def m042_index_accounts(db: Connection):
|
||||||
|
indexes = [
|
||||||
|
"id",
|
||||||
|
"email",
|
||||||
|
"username",
|
||||||
|
"pubkey",
|
||||||
|
"external_id",
|
||||||
|
]
|
||||||
|
|
||||||
|
for index in indexes:
|
||||||
|
logger.debug(f"Creating index idx_wallets_{index}...")
|
||||||
|
await db.execute(
|
||||||
|
f"""
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_accounts_{index} ON accounts ("{index}");
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -201,6 +201,10 @@ class Account(AccountId):
|
||||||
self.is_admin = settings.is_admin_user(self.id)
|
self.is_admin = settings.is_admin_user(self.id)
|
||||||
self.fiat_providers = settings.get_fiat_providers_for_user(self.id)
|
self.fiat_providers = settings.get_fiat_providers_for_user(self.id)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_password(self) -> bool:
|
||||||
|
return self.password_hash is not None
|
||||||
|
|
||||||
def hash_password(self, password: str) -> str:
|
def hash_password(self, password: str) -> str:
|
||||||
"""sets and returns the hashed password"""
|
"""sets and returns the hashed password"""
|
||||||
salt = gensalt()
|
salt = gensalt()
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from lnbits.core.crud.extensions import (
|
||||||
update_installed_extension,
|
update_installed_extension,
|
||||||
)
|
)
|
||||||
from lnbits.core.helpers import migrate_extension_database
|
from lnbits.core.helpers import migrate_extension_database
|
||||||
|
from lnbits.db import Connection
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
|
||||||
from ..models.extensions import Extension, ExtensionMeta, InstallableExtension
|
from ..models.extensions import Extension, ExtensionMeta, InstallableExtension
|
||||||
|
|
@ -149,9 +150,9 @@ async def start_extension_background_work(ext_id: str) -> bool:
|
||||||
|
|
||||||
|
|
||||||
async def get_valid_extensions(
|
async def get_valid_extensions(
|
||||||
include_deactivated: bool | None = True,
|
include_deactivated: bool | None = True, conn: Connection | None = None
|
||||||
) -> list[Extension]:
|
) -> list[Extension]:
|
||||||
installed_extensions = await get_installed_extensions()
|
installed_extensions = await get_installed_extensions(conn=conn)
|
||||||
valid_extensions = [Extension.from_installable_ext(e) for e in installed_extensions]
|
valid_extensions = [Extension.from_installable_ext(e) for e in installed_extensions]
|
||||||
|
|
||||||
if include_deactivated:
|
if include_deactivated:
|
||||||
|
|
|
||||||
|
|
@ -67,12 +67,14 @@ async def pay_invoice(
|
||||||
if settings.lnbits_only_allow_incoming_payments:
|
if settings.lnbits_only_allow_incoming_payments:
|
||||||
raise PaymentError("Only incoming payments allowed.", status="failed")
|
raise PaymentError("Only incoming payments allowed.", status="failed")
|
||||||
invoice = _validate_payment_request(payment_request, max_sat)
|
invoice = _validate_payment_request(payment_request, max_sat)
|
||||||
|
|
||||||
if not invoice.amount_msat:
|
if not invoice.amount_msat:
|
||||||
raise ValueError("Missig invoice amount.")
|
raise ValueError("Missig invoice amount.")
|
||||||
|
|
||||||
async with db.reuse_conn(conn) if conn else db.connect() as new_conn:
|
async with db.reuse_conn(conn) if conn else db.connect() as new_conn:
|
||||||
amount_msat = invoice.amount_msat
|
amount_msat = invoice.amount_msat
|
||||||
wallet = await _check_wallet_for_payment(wallet_id, tag, amount_msat, new_conn)
|
wallet = await _check_wallet_for_payment(wallet_id, tag, amount_msat, new_conn)
|
||||||
|
|
||||||
if not wallet.can_send_payments:
|
if not wallet.can_send_payments:
|
||||||
raise PaymentError(
|
raise PaymentError(
|
||||||
"Wallet does not have permission to pay invoices.",
|
"Wallet does not have permission to pay invoices.",
|
||||||
|
|
@ -95,10 +97,12 @@ async def pay_invoice(
|
||||||
labels=labels,
|
labels=labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
payment = await _pay_invoice(wallet.source_wallet_id, create_payment_model, conn)
|
|
||||||
|
|
||||||
async with db.reuse_conn(conn) if conn else db.connect() as new_conn:
|
async with db.reuse_conn(conn) if conn else db.connect() as new_conn:
|
||||||
await _credit_service_fee_wallet(wallet, payment, new_conn)
|
payment = await _pay_invoice(
|
||||||
|
wallet.source_wallet_id, create_payment_model, conn=new_conn
|
||||||
|
)
|
||||||
|
|
||||||
|
await _credit_service_fee_wallet(wallet, payment, conn=new_conn)
|
||||||
|
|
||||||
return payment
|
return payment
|
||||||
|
|
||||||
|
|
@ -197,20 +201,22 @@ async def create_wallet_invoice(wallet_id: str, data: CreateInvoice) -> Payment:
|
||||||
# do not save memo if description_hash or unhashed_description is set
|
# do not save memo if description_hash or unhashed_description is set
|
||||||
memo = ""
|
memo = ""
|
||||||
|
|
||||||
payment = await create_invoice(
|
async with db.connect() as conn:
|
||||||
wallet_id=wallet_id,
|
payment = await create_invoice(
|
||||||
amount=data.amount,
|
wallet_id=wallet_id,
|
||||||
memo=memo,
|
amount=data.amount,
|
||||||
currency=data.unit,
|
memo=memo,
|
||||||
description_hash=description_hash,
|
currency=data.unit,
|
||||||
unhashed_description=unhashed_description,
|
description_hash=description_hash,
|
||||||
expiry=data.expiry,
|
unhashed_description=unhashed_description,
|
||||||
extra=data.extra,
|
expiry=data.expiry,
|
||||||
webhook=data.webhook,
|
extra=data.extra,
|
||||||
internal=data.internal,
|
webhook=data.webhook,
|
||||||
payment_hash=data.payment_hash,
|
internal=data.internal,
|
||||||
labels=data.labels,
|
payment_hash=data.payment_hash,
|
||||||
)
|
labels=data.labels,
|
||||||
|
conn=conn,
|
||||||
|
)
|
||||||
|
|
||||||
if data.lnurl_withdraw:
|
if data.lnurl_withdraw:
|
||||||
try:
|
try:
|
||||||
|
|
@ -355,13 +361,15 @@ async def update_pending_payments(wallet_id: str):
|
||||||
await update_pending_payment(payment)
|
await update_pending_payment(payment)
|
||||||
|
|
||||||
|
|
||||||
async def update_pending_payment(payment: Payment) -> Payment:
|
async def update_pending_payment(
|
||||||
|
payment: Payment, conn: Connection | None = None
|
||||||
|
) -> Payment:
|
||||||
status = await payment.check_status()
|
status = await payment.check_status()
|
||||||
if status.failed:
|
if status.failed:
|
||||||
payment.status = PaymentState.FAILED
|
payment.status = PaymentState.FAILED
|
||||||
await update_payment(payment)
|
await update_payment(payment, conn=conn)
|
||||||
elif status.success:
|
elif status.success:
|
||||||
payment = await update_payment_success_status(payment, status)
|
payment = await update_payment_success_status(payment, status, conn=conn)
|
||||||
return payment
|
return payment
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -698,6 +706,7 @@ async def _pay_internal_invoice(
|
||||||
internal_payment = await check_internal(
|
internal_payment = await check_internal(
|
||||||
create_payment_model.payment_hash, conn=conn
|
create_payment_model.payment_hash, conn=conn
|
||||||
)
|
)
|
||||||
|
|
||||||
if not internal_payment:
|
if not internal_payment:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -706,6 +715,7 @@ async def _pay_internal_invoice(
|
||||||
internal_invoice = await get_standalone_payment(
|
internal_invoice = await get_standalone_payment(
|
||||||
internal_payment.checking_id, incoming=True, conn=conn
|
internal_payment.checking_id, incoming=True, conn=conn
|
||||||
)
|
)
|
||||||
|
|
||||||
if not internal_invoice:
|
if not internal_invoice:
|
||||||
raise PaymentError("Internal payment not found.", status="failed")
|
raise PaymentError("Internal payment not found.", status="failed")
|
||||||
|
|
||||||
|
|
@ -727,6 +737,7 @@ async def _pay_internal_invoice(
|
||||||
|
|
||||||
internal_id = f"internal_{create_payment_model.payment_hash}"
|
internal_id = f"internal_{create_payment_model.payment_hash}"
|
||||||
logger.debug(f"creating temporary internal payment with id {internal_id}")
|
logger.debug(f"creating temporary internal payment with id {internal_id}")
|
||||||
|
|
||||||
payment = await create_payment(
|
payment = await create_payment(
|
||||||
checking_id=internal_id,
|
checking_id=internal_id,
|
||||||
data=create_payment_model,
|
data=create_payment_model,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,9 @@ from uuid import uuid4
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from lnbits.core.db import db
|
||||||
from lnbits.core.models.extensions import UserExtension
|
from lnbits.core.models.extensions import UserExtension
|
||||||
|
from lnbits.db import Connection
|
||||||
from lnbits.settings import (
|
from lnbits.settings import (
|
||||||
EditableSettings,
|
EditableSettings,
|
||||||
SuperSettings,
|
SuperSettings,
|
||||||
|
|
@ -48,37 +50,43 @@ async def create_user_account_no_ckeck(
|
||||||
account: Account | None = None,
|
account: Account | None = None,
|
||||||
wallet_name: str | None = None,
|
wallet_name: str | None = None,
|
||||||
default_exts: list[str] | None = None,
|
default_exts: list[str] | None = None,
|
||||||
|
conn: Connection | None = None,
|
||||||
) -> User:
|
) -> User:
|
||||||
|
async with db.reuse_conn(conn) if conn else db.connect() as conn:
|
||||||
|
if account:
|
||||||
|
account.validate_fields()
|
||||||
|
if account.username and await get_account_by_username(
|
||||||
|
account.username, conn=conn
|
||||||
|
):
|
||||||
|
raise ValueError("Username already exists.")
|
||||||
|
|
||||||
if account:
|
if account.email and await get_account_by_email(account.email, conn=conn):
|
||||||
account.validate_fields()
|
raise ValueError("Email already exists.")
|
||||||
if account.username and await get_account_by_username(account.username):
|
|
||||||
raise ValueError("Username already exists.")
|
|
||||||
|
|
||||||
if account.email and await get_account_by_email(account.email):
|
if account.pubkey and await get_account_by_pubkey(
|
||||||
raise ValueError("Email already exists.")
|
account.pubkey, conn=conn
|
||||||
|
):
|
||||||
|
raise ValueError("Pubkey already exists.")
|
||||||
|
|
||||||
if account.pubkey and await get_account_by_pubkey(account.pubkey):
|
if not account.id:
|
||||||
raise ValueError("Pubkey already exists.")
|
account.id = uuid4().hex
|
||||||
|
|
||||||
if not account.id:
|
account = await create_account(account, conn=conn)
|
||||||
account.id = uuid4().hex
|
await create_wallet(
|
||||||
|
user_id=account.id,
|
||||||
|
wallet_name=wallet_name or settings.lnbits_default_wallet_name,
|
||||||
|
conn=conn,
|
||||||
|
)
|
||||||
|
|
||||||
account = await create_account(account)
|
user_extensions = (default_exts or []) + settings.lnbits_user_default_extensions
|
||||||
await create_wallet(
|
for ext_id in user_extensions:
|
||||||
user_id=account.id,
|
try:
|
||||||
wallet_name=wallet_name or settings.lnbits_default_wallet_name,
|
user_ext = UserExtension(user=account.id, extension=ext_id, active=True)
|
||||||
)
|
await create_user_extension(user_ext, conn=conn)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error enabeling default extension {ext_id}: {e}")
|
||||||
|
|
||||||
user_extensions = (default_exts or []) + settings.lnbits_user_default_extensions
|
user = await get_user_from_account(account, conn=conn)
|
||||||
for ext_id in user_extensions:
|
|
||||||
try:
|
|
||||||
user_ext = UserExtension(user=account.id, extension=ext_id, active=True)
|
|
||||||
await create_user_extension(user_ext)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error enabeling default extension {ext_id}: {e}")
|
|
||||||
|
|
||||||
user = await get_user_from_account(account)
|
|
||||||
if not user:
|
if not user:
|
||||||
raise ValueError("Cannot find user for account.")
|
raise ValueError("Cannot find user for account.")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ from urllib.parse import urlparse
|
||||||
from fastapi import APIRouter, Depends, File
|
from fastapi import APIRouter, Depends, File
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
from lnbits.core.models import User
|
|
||||||
from lnbits.core.models.notifications import NotificationType
|
from lnbits.core.models.notifications import NotificationType
|
||||||
|
from lnbits.core.models.users import Account
|
||||||
from lnbits.core.services import (
|
from lnbits.core.services import (
|
||||||
enqueue_admin_notification,
|
enqueue_admin_notification,
|
||||||
get_balance_delta,
|
get_balance_delta,
|
||||||
|
|
@ -67,9 +67,9 @@ async def api_test_email():
|
||||||
|
|
||||||
@admin_router.get("/api/v1/settings")
|
@admin_router.get("/api/v1/settings")
|
||||||
async def api_get_settings(
|
async def api_get_settings(
|
||||||
user: User = Depends(check_admin),
|
account: Account = Depends(check_admin),
|
||||||
) -> AdminSettings | None:
|
) -> AdminSettings | None:
|
||||||
admin_settings = await get_admin_settings(user.super_user)
|
admin_settings = await get_admin_settings(account.is_super_user)
|
||||||
return admin_settings
|
return admin_settings
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -77,12 +77,14 @@ async def api_get_settings(
|
||||||
"/api/v1/settings",
|
"/api/v1/settings",
|
||||||
status_code=HTTPStatus.OK,
|
status_code=HTTPStatus.OK,
|
||||||
)
|
)
|
||||||
async def api_update_settings(data: UpdateSettings, user: User = Depends(check_admin)):
|
async def api_update_settings(
|
||||||
|
data: UpdateSettings, account: Account = Depends(check_admin)
|
||||||
|
):
|
||||||
enqueue_admin_notification(
|
enqueue_admin_notification(
|
||||||
NotificationType.settings_update, {"username": user.username}
|
NotificationType.settings_update, {"username": account.username}
|
||||||
)
|
)
|
||||||
await update_admin_settings(data)
|
await update_admin_settings(data)
|
||||||
admin_settings = await get_admin_settings(user.super_user)
|
admin_settings = await get_admin_settings(account.is_super_user)
|
||||||
if not admin_settings:
|
if not admin_settings:
|
||||||
raise ValueError("Updated admin settings not found.")
|
raise ValueError("Updated admin settings not found.")
|
||||||
update_cached_settings(admin_settings.dict())
|
update_cached_settings(admin_settings.dict())
|
||||||
|
|
@ -94,9 +96,11 @@ async def api_update_settings(data: UpdateSettings, user: User = Depends(check_a
|
||||||
"/api/v1/settings",
|
"/api/v1/settings",
|
||||||
status_code=HTTPStatus.OK,
|
status_code=HTTPStatus.OK,
|
||||||
)
|
)
|
||||||
async def api_update_settings_partial(data: dict, user: User = Depends(check_admin)):
|
async def api_update_settings_partial(
|
||||||
|
data: dict, account: Account = Depends(check_admin)
|
||||||
|
):
|
||||||
updatable_settings = dict_to_settings({**settings.dict(), **data})
|
updatable_settings = dict_to_settings({**settings.dict(), **data})
|
||||||
return await api_update_settings(updatable_settings, user)
|
return await api_update_settings(updatable_settings, account)
|
||||||
|
|
||||||
|
|
||||||
@admin_router.get(
|
@admin_router.get(
|
||||||
|
|
@ -110,9 +114,9 @@ async def api_reset_settings(field_name: str):
|
||||||
|
|
||||||
|
|
||||||
@admin_router.delete("/api/v1/settings", status_code=HTTPStatus.OK)
|
@admin_router.delete("/api/v1/settings", status_code=HTTPStatus.OK)
|
||||||
async def api_delete_settings(user: User = Depends(check_super_user)) -> None:
|
async def api_delete_settings(account: Account = Depends(check_super_user)) -> None:
|
||||||
enqueue_admin_notification(
|
enqueue_admin_notification(
|
||||||
NotificationType.settings_update, {"username": user.username}
|
NotificationType.settings_update, {"username": account.username}
|
||||||
)
|
)
|
||||||
await reset_core_settings()
|
await reset_core_settings()
|
||||||
server_restart.set()
|
server_restart.set()
|
||||||
|
|
|
||||||
|
|
@ -7,9 +7,10 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lnbits.core.crud.extensions import get_user_extensions
|
from lnbits.core.crud.extensions import get_user_extensions
|
||||||
|
from lnbits.core.crud.wallets import get_wallets_ids
|
||||||
|
from lnbits.core.db import db
|
||||||
from lnbits.core.models import (
|
from lnbits.core.models import (
|
||||||
SimpleStatus,
|
SimpleStatus,
|
||||||
User,
|
|
||||||
)
|
)
|
||||||
from lnbits.core.models.extensions import (
|
from lnbits.core.models.extensions import (
|
||||||
CreateExtension,
|
CreateExtension,
|
||||||
|
|
@ -23,7 +24,7 @@ from lnbits.core.models.extensions import (
|
||||||
UserExtension,
|
UserExtension,
|
||||||
UserExtensionInfo,
|
UserExtensionInfo,
|
||||||
)
|
)
|
||||||
from lnbits.core.models.users import AccountId
|
from lnbits.core.models.users import Account, AccountId
|
||||||
from lnbits.core.services import check_transaction_status, create_invoice
|
from lnbits.core.services import check_transaction_status, create_invoice
|
||||||
from lnbits.core.services.extensions import (
|
from lnbits.core.services.extensions import (
|
||||||
activate_extension,
|
activate_extension,
|
||||||
|
|
@ -144,9 +145,10 @@ async def api_extension_details(
|
||||||
async def api_update_pay_to_enable(
|
async def api_update_pay_to_enable(
|
||||||
ext_id: str,
|
ext_id: str,
|
||||||
data: PayToEnableInfo,
|
data: PayToEnableInfo,
|
||||||
user: User = Depends(check_admin),
|
account: Account = Depends(check_admin),
|
||||||
) -> SimpleStatus:
|
) -> SimpleStatus:
|
||||||
if data.wallet not in user.wallet_ids:
|
user_wallet_ids = await get_wallets_ids(account.id, deleted=False)
|
||||||
|
if data.wallet not in user_wallet_ids:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
HTTPStatus.BAD_REQUEST, "Wallet does not belong to this admin user."
|
HTTPStatus.BAD_REQUEST, "Wallet does not belong to this admin user."
|
||||||
)
|
)
|
||||||
|
|
@ -462,15 +464,16 @@ async def get_extension_release(org: str, repo: str, tag_name: str):
|
||||||
async def api_get_user_extensions(
|
async def api_get_user_extensions(
|
||||||
account_id: AccountId = Depends(check_account_id_exists),
|
account_id: AccountId = Depends(check_account_id_exists),
|
||||||
) -> list[Extension]:
|
) -> list[Extension]:
|
||||||
|
async with db.connect() as conn:
|
||||||
user_extensions_ids = [
|
user_extensions_ids = [
|
||||||
ue.extension for ue in await get_user_extensions(account_id.id)
|
ue.extension for ue in await get_user_extensions(account_id.id, conn=conn)
|
||||||
]
|
]
|
||||||
return [
|
valid_extensions = [
|
||||||
ext
|
ext
|
||||||
for ext in await get_valid_extensions(False)
|
for ext in await get_valid_extensions(False, conn=conn)
|
||||||
if ext.code in user_extensions_ids
|
if ext.code in user_extensions_ids
|
||||||
]
|
]
|
||||||
|
return valid_extensions
|
||||||
|
|
||||||
|
|
||||||
@extension_router.delete(
|
@extension_router.delete(
|
||||||
|
|
@ -505,7 +508,16 @@ async def delete_extension_db(ext_id: str):
|
||||||
# TODO: create a response model for this
|
# TODO: create a response model for this
|
||||||
@extension_router.get("/all")
|
@extension_router.get("/all")
|
||||||
async def extensions(account_id: AccountId = Depends(check_account_id_exists)):
|
async def extensions(account_id: AccountId = Depends(check_account_id_exists)):
|
||||||
installed_exts: list[InstallableExtension] = await get_installed_extensions()
|
async with db.connect() as conn:
|
||||||
|
installed_exts: list[InstallableExtension] = await get_installed_extensions(
|
||||||
|
conn=conn
|
||||||
|
)
|
||||||
|
all_ext_ids = [ext.code for ext in await get_valid_extensions(conn=conn)]
|
||||||
|
inactive_extensions = [
|
||||||
|
e.id for e in await get_installed_extensions(active=False, conn=conn)
|
||||||
|
]
|
||||||
|
db_versions = await get_db_versions(conn=conn)
|
||||||
|
|
||||||
installed_exts_ids = [e.id for e in installed_exts]
|
installed_exts_ids = [e.id for e in installed_exts]
|
||||||
|
|
||||||
installable_exts = await InstallableExtension.get_installable_extensions(
|
installable_exts = await InstallableExtension.get_installable_extensions(
|
||||||
|
|
@ -536,10 +548,6 @@ async def extensions(account_id: AccountId = Depends(check_account_id_exists)):
|
||||||
e.short_description = installed_ext.short_description
|
e.short_description = installed_ext.short_description
|
||||||
e.icon = installed_ext.icon
|
e.icon = installed_ext.icon
|
||||||
|
|
||||||
all_ext_ids = [ext.code for ext in await get_valid_extensions()]
|
|
||||||
inactive_extensions = [e.id for e in await get_installed_extensions(active=False)]
|
|
||||||
db_versions = await get_db_versions()
|
|
||||||
|
|
||||||
extension_data = [
|
extension_data = [
|
||||||
{
|
{
|
||||||
"id": ext.id,
|
"id": ext.id,
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ from fastapi.responses import FileResponse
|
||||||
|
|
||||||
from lnbits.core.models import (
|
from lnbits.core.models import (
|
||||||
SimpleStatus,
|
SimpleStatus,
|
||||||
User,
|
|
||||||
)
|
)
|
||||||
from lnbits.core.models.extensions import (
|
from lnbits.core.models.extensions import (
|
||||||
Extension,
|
Extension,
|
||||||
|
|
@ -16,7 +15,7 @@ from lnbits.core.models.extensions import (
|
||||||
UserExtension,
|
UserExtension,
|
||||||
)
|
)
|
||||||
from lnbits.core.models.extensions_builder import ExtensionData
|
from lnbits.core.models.extensions_builder import ExtensionData
|
||||||
from lnbits.core.models.users import AccountId
|
from lnbits.core.models.users import Account, AccountId
|
||||||
from lnbits.core.services.extensions import (
|
from lnbits.core.services.extensions import (
|
||||||
activate_extension,
|
activate_extension,
|
||||||
install_extension,
|
install_extension,
|
||||||
|
|
@ -85,9 +84,9 @@ async def api_build_extension(data: ExtensionData) -> FileResponse:
|
||||||
)
|
)
|
||||||
async def api_deploy_extension(
|
async def api_deploy_extension(
|
||||||
data: ExtensionData,
|
data: ExtensionData,
|
||||||
user: User = Depends(check_admin),
|
account: Account = Depends(check_admin),
|
||||||
) -> SimpleStatus:
|
) -> SimpleStatus:
|
||||||
working_dir_name = "deploy_" + sha256(user.id.encode("utf-8")).hexdigest()
|
working_dir_name = "deploy_" + sha256(account.id.encode("utf-8")).hexdigest()
|
||||||
stub_ext_id = "extension_builder_stub"
|
stub_ext_id = "extension_builder_stub"
|
||||||
release, build_dir = await build_extension_from_data(
|
release, build_dir = await build_extension_from_data(
|
||||||
data, stub_ext_id, working_dir_name
|
data, stub_ext_id, working_dir_name
|
||||||
|
|
@ -111,9 +110,9 @@ async def api_deploy_extension(
|
||||||
|
|
||||||
await activate_extension(Extension.from_installable_ext(ext_info))
|
await activate_extension(Extension.from_installable_ext(ext_info))
|
||||||
|
|
||||||
user_ext = await get_user_extension(user.id, data.id)
|
user_ext = await get_user_extension(account.id, data.id)
|
||||||
if not user_ext:
|
if not user_ext:
|
||||||
user_ext = UserExtension(user=user.id, extension=data.id, active=True)
|
user_ext = UserExtension(user=account.id, extension=data.id, active=True)
|
||||||
await create_user_extension(user_ext)
|
await create_user_extension(user_ext)
|
||||||
elif not user_ext.active:
|
elif not user_ext.active:
|
||||||
user_ext.active = True
|
user_ext.active = True
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ from lnbits.core.crud.payments import (
|
||||||
update_payment,
|
update_payment,
|
||||||
)
|
)
|
||||||
from lnbits.core.crud.users import get_account
|
from lnbits.core.crud.users import get_account
|
||||||
|
from lnbits.core.db import db
|
||||||
from lnbits.core.models import (
|
from lnbits.core.models import (
|
||||||
CancelInvoice,
|
CancelInvoice,
|
||||||
CreateInvoice,
|
CreateInvoice,
|
||||||
|
|
@ -69,7 +70,6 @@ from ..services import (
|
||||||
perform_withdraw,
|
perform_withdraw,
|
||||||
settle_hold_invoice,
|
settle_hold_invoice,
|
||||||
update_pending_payment,
|
update_pending_payment,
|
||||||
update_pending_payments,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
payment_router = APIRouter(prefix="/api/v1/payments", tags=["Payments"])
|
payment_router = APIRouter(prefix="/api/v1/payments", tags=["Payments"])
|
||||||
|
|
@ -87,7 +87,6 @@ async def api_payments(
|
||||||
key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
|
key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
|
||||||
filters: Filters = Depends(parse_filters(PaymentFilters)),
|
filters: Filters = Depends(parse_filters(PaymentFilters)),
|
||||||
):
|
):
|
||||||
await update_pending_payments(key_info.wallet.id)
|
|
||||||
return await get_payments(
|
return await get_payments(
|
||||||
wallet_id=key_info.wallet.id,
|
wallet_id=key_info.wallet.id,
|
||||||
pending=True,
|
pending=True,
|
||||||
|
|
@ -107,7 +106,6 @@ async def api_payments_history(
|
||||||
group: DateTrunc = Query("day"),
|
group: DateTrunc = Query("day"),
|
||||||
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
|
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
|
||||||
):
|
):
|
||||||
await update_pending_payments(key_info.wallet.id)
|
|
||||||
return await get_payments_history(key_info.wallet.id, group, filters)
|
return await get_payments_history(key_info.wallet.id, group, filters)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -186,18 +184,24 @@ async def api_payments_paginated(
|
||||||
),
|
),
|
||||||
filters: Filters = Depends(parse_filters(PaymentFilters)),
|
filters: Filters = Depends(parse_filters(PaymentFilters)),
|
||||||
) -> Page[Payment]:
|
) -> Page[Payment]:
|
||||||
page = await get_payments_paginated(
|
async with db.connect() as conn:
|
||||||
wallet_id=key_info.wallet.id,
|
page = await get_payments_paginated(
|
||||||
filters=filters,
|
wallet_id=key_info.wallet.id,
|
||||||
)
|
filters=filters,
|
||||||
if not recheck_pending:
|
conn=conn,
|
||||||
return page
|
)
|
||||||
|
if not recheck_pending:
|
||||||
|
return page
|
||||||
|
|
||||||
for payment in page.data:
|
payments = []
|
||||||
if payment.pending:
|
for payment in page.data:
|
||||||
await update_pending_payment(payment)
|
if payment.pending:
|
||||||
|
refreshed_payment = await update_pending_payment(payment, conn=conn)
|
||||||
|
payments.append(refreshed_payment)
|
||||||
|
else:
|
||||||
|
payments.append(payment)
|
||||||
|
|
||||||
return page
|
return Page(data=payments, total=page.total)
|
||||||
|
|
||||||
|
|
||||||
@payment_router.get(
|
@payment_router.get(
|
||||||
|
|
@ -219,10 +223,10 @@ async def api_all_payments_paginated(
|
||||||
# regular user can only see payments from their wallets
|
# regular user can only see payments from their wallets
|
||||||
for_user_id = account_id.id
|
for_user_id = account_id.id
|
||||||
|
|
||||||
return await get_payments_paginated(
|
async with db.connect() as conn:
|
||||||
filters=filters,
|
return await get_payments_paginated(
|
||||||
user_id=for_user_id,
|
filters=filters, user_id=for_user_id, conn=conn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@payment_router.post(
|
@payment_router.post(
|
||||||
|
|
|
||||||
|
|
@ -115,12 +115,12 @@ async def api_create_user(data: CreateUser) -> CreateUser:
|
||||||
|
|
||||||
@users_router.put("/user/{user_id}", name="Update user")
|
@users_router.put("/user/{user_id}", name="Update user")
|
||||||
async def api_update_user(
|
async def api_update_user(
|
||||||
user_id: str, data: CreateUser, user: User = Depends(check_admin)
|
user_id: str, data: CreateUser, account: Account = Depends(check_admin)
|
||||||
) -> CreateUser:
|
) -> CreateUser:
|
||||||
if user_id != data.id:
|
if user_id != data.id:
|
||||||
raise HTTPException(HTTPStatus.BAD_REQUEST, "User Id missmatch.")
|
raise HTTPException(HTTPStatus.BAD_REQUEST, "User Id missmatch.")
|
||||||
|
|
||||||
if user_id == settings.super_user and user.id != settings.super_user:
|
if user_id == settings.super_user and account.id != settings.super_user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
detail="Action only allowed for super user.",
|
detail="Action only allowed for super user.",
|
||||||
|
|
@ -154,7 +154,7 @@ async def api_update_user(
|
||||||
name="Delete user by Id",
|
name="Delete user by Id",
|
||||||
)
|
)
|
||||||
async def api_users_delete_user(
|
async def api_users_delete_user(
|
||||||
user_id: str, user: User = Depends(check_admin)
|
user_id: str, account: Account = Depends(check_admin)
|
||||||
) -> SimpleStatus:
|
) -> SimpleStatus:
|
||||||
wallets = await get_wallets(user_id, deleted=False)
|
wallets = await get_wallets(user_id, deleted=False)
|
||||||
if len(wallets) > 0:
|
if len(wallets) > 0:
|
||||||
|
|
@ -169,7 +169,7 @@ async def api_users_delete_user(
|
||||||
detail="Cannot delete super user.",
|
detail="Cannot delete super user.",
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_id in settings.lnbits_admin_users and not user.super_user:
|
if user_id in settings.lnbits_admin_users and not account.is_super_user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
detail="Only super_user can delete admin user.",
|
detail="Only super_user can delete admin user.",
|
||||||
|
|
@ -295,7 +295,7 @@ async def api_users_delete_all_user_wallet(user_id: str) -> SimpleStatus:
|
||||||
"The second time it is called will delete the entry from the DB",
|
"The second time it is called will delete the entry from the DB",
|
||||||
)
|
)
|
||||||
async def api_users_delete_user_wallet(
|
async def api_users_delete_user_wallet(
|
||||||
user_id: str, wallet: str, user: User = Depends(check_admin)
|
user_id: str, wallet: str, account: Account = Depends(check_admin)
|
||||||
) -> SimpleStatus:
|
) -> SimpleStatus:
|
||||||
wal = await get_wallet(wallet)
|
wal = await get_wallet(wallet)
|
||||||
if not wal:
|
if not wal:
|
||||||
|
|
@ -304,7 +304,7 @@ async def api_users_delete_user_wallet(
|
||||||
detail="Wallet does not exist.",
|
detail="Wallet does not exist.",
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_id == settings.super_user and user.id != settings.super_user:
|
if user_id == settings.super_user and account.id != settings.super_user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
status_code=HTTPStatus.BAD_REQUEST,
|
||||||
detail="Action only allowed for super user.",
|
detail="Action only allowed for super user.",
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ from lnbits.core.crud import (
|
||||||
)
|
)
|
||||||
from lnbits.core.crud.users import get_user_access_control_lists
|
from lnbits.core.crud.users import get_user_access_control_lists
|
||||||
from lnbits.core.crud.wallets import get_base_wallet_for_key
|
from lnbits.core.crud.wallets import get_base_wallet_for_key
|
||||||
|
from lnbits.core.db import db
|
||||||
from lnbits.core.models import (
|
from lnbits.core.models import (
|
||||||
AccessTokenPayload,
|
AccessTokenPayload,
|
||||||
Account,
|
Account,
|
||||||
|
|
@ -119,16 +120,17 @@ class KeyChecker(BaseKeyChecker):
|
||||||
async def __call__(self, request: Request) -> WalletTypeInfo:
|
async def __call__(self, request: Request) -> WalletTypeInfo:
|
||||||
key_value = self._extract_key_value(request)
|
key_value = self._extract_key_value(request)
|
||||||
|
|
||||||
wallet = await get_wallet_for_key(key_value)
|
async with db.connect() as conn:
|
||||||
|
wallet = await get_wallet_for_key(key_value, conn=conn)
|
||||||
|
|
||||||
if not wallet:
|
if not wallet:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.NOT_FOUND,
|
status_code=HTTPStatus.NOT_FOUND,
|
||||||
detail="Wallet not found.",
|
detail="Wallet not found.",
|
||||||
)
|
)
|
||||||
|
|
||||||
request.scope["user_id"] = wallet.user
|
request.scope["user_id"] = wallet.user
|
||||||
await _check_user_access(request, wallet.user)
|
await _check_user_access(request, wallet.user, conn=conn)
|
||||||
|
|
||||||
key_type = await self._extract_key_type(key_value, wallet)
|
key_type = await self._extract_key_type(key_value, wallet)
|
||||||
return WalletTypeInfo(key_type, wallet)
|
return WalletTypeInfo(key_type, wallet)
|
||||||
|
|
@ -148,22 +150,23 @@ class LightKeyChecker(BaseKeyChecker):
|
||||||
cache_key = f"auth:x-api-key:{key_value}"
|
cache_key = f"auth:x-api-key:{key_value}"
|
||||||
cache_time = settings.auth_authentication_cache_minutes * 60
|
cache_time = settings.auth_authentication_cache_minutes * 60
|
||||||
|
|
||||||
if cache_time > 0:
|
async with db.connect() as conn:
|
||||||
key_info: BaseWalletTypeInfo | None = cache.get(cache_key)
|
if cache_time > 0:
|
||||||
if key_info:
|
key_info: BaseWalletTypeInfo | None = cache.get(cache_key)
|
||||||
request.scope["user_id"] = key_info.wallet.user
|
if key_info:
|
||||||
await _check_user_access(request, key_info.wallet.user)
|
request.scope["user_id"] = key_info.wallet.user
|
||||||
return key_info
|
await _check_user_access(request, key_info.wallet.user, conn=conn)
|
||||||
|
return key_info
|
||||||
|
|
||||||
wallet = await get_base_wallet_for_key(key_value)
|
wallet = await get_base_wallet_for_key(key_value, conn=conn)
|
||||||
|
|
||||||
if not wallet:
|
if not wallet:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=HTTPStatus.NOT_FOUND,
|
status_code=HTTPStatus.NOT_FOUND,
|
||||||
detail="Wallet not found.",
|
detail="Wallet not found.",
|
||||||
)
|
)
|
||||||
request.scope["user_id"] = wallet.user
|
request.scope["user_id"] = wallet.user
|
||||||
await _check_user_access(request, wallet.user)
|
await _check_user_access(request, wallet.user, conn=conn)
|
||||||
|
|
||||||
key_type = await self._extract_key_type(key_value, wallet)
|
key_type = await self._extract_key_type(key_value, wallet)
|
||||||
key_info = BaseWalletTypeInfo(key_type, wallet)
|
key_info = BaseWalletTypeInfo(key_type, wallet)
|
||||||
|
|
@ -241,15 +244,16 @@ async def check_account_id_exists(
|
||||||
elif usr:
|
elif usr:
|
||||||
cache_key = f"auth:user_id:{sha256s(usr.hex)}"
|
cache_key = f"auth:user_id:{sha256s(usr.hex)}"
|
||||||
|
|
||||||
if cache_key and settings.auth_authentication_cache_minutes > 0:
|
async with db.connect() as conn:
|
||||||
account_id = cache.get(cache_key)
|
if cache_key and settings.auth_authentication_cache_minutes > 0:
|
||||||
if account_id:
|
account_id = cache.get(cache_key)
|
||||||
r.scope["user_id"] = account_id.id
|
if account_id:
|
||||||
await _check_user_access(r, account_id)
|
r.scope["user_id"] = account_id.id
|
||||||
return account_id
|
await _check_user_access(r, account_id, conn=conn)
|
||||||
|
return account_id
|
||||||
|
|
||||||
account = await check_account_exists(r, access_token, usr)
|
account = await _check_account_exists(r, access_token, usr, conn=conn)
|
||||||
account_id = AccountId(id=account.id)
|
account_id = AccountId(id=account.id)
|
||||||
|
|
||||||
if cache_key and settings.auth_authentication_cache_minutes > 0:
|
if cache_key and settings.auth_authentication_cache_minutes > 0:
|
||||||
cache.set(
|
cache.set(
|
||||||
|
|
@ -265,6 +269,15 @@ async def check_account_exists(
|
||||||
r: Request,
|
r: Request,
|
||||||
access_token: Annotated[str | None, Depends(check_access_token)],
|
access_token: Annotated[str | None, Depends(check_access_token)],
|
||||||
usr: UUID4 | None = None,
|
usr: UUID4 | None = None,
|
||||||
|
) -> Account:
|
||||||
|
return await _check_account_exists(r, access_token, usr)
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_account_exists(
|
||||||
|
r: Request,
|
||||||
|
access_token: Annotated[str | None, Depends(check_access_token)],
|
||||||
|
usr: UUID4 | None = None,
|
||||||
|
conn: Connection | None = None,
|
||||||
) -> Account:
|
) -> Account:
|
||||||
"""
|
"""
|
||||||
Check that the account exists based on access token or user id.
|
Check that the account exists based on access token or user id.
|
||||||
|
|
@ -273,22 +286,27 @@ async def check_account_exists(
|
||||||
- does not fetch the user wallets
|
- does not fetch the user wallets
|
||||||
- caches the account info based on settings cache time
|
- caches the account info based on settings cache time
|
||||||
"""
|
"""
|
||||||
if access_token:
|
async with db.reuse_conn(conn) if conn else db.connect() as new_conn:
|
||||||
account = await _get_account_from_token(access_token, r["path"], r["method"])
|
if access_token:
|
||||||
elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only):
|
account = await _get_account_from_token(
|
||||||
account = await get_account(usr.hex)
|
access_token, r["path"], r["method"], conn=new_conn
|
||||||
if account and account.is_admin:
|
)
|
||||||
raise HTTPException(
|
elif usr and settings.is_auth_method_allowed(AuthMethods.user_id_only):
|
||||||
HTTPStatus.FORBIDDEN, "User id only access for admins is forbidden."
|
account = await get_account(usr.hex, conn=new_conn)
|
||||||
|
if account and account.is_admin:
|
||||||
|
raise HTTPException(
|
||||||
|
HTTPStatus.FORBIDDEN, "User id only access for admins is forbidden."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
HTTPStatus.UNAUTHORIZED, "Missing user ID or access token."
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Missing user ID or access token.")
|
|
||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
|
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
|
||||||
|
|
||||||
r.scope["user_id"] = account.id
|
r.scope["user_id"] = account.id
|
||||||
await _check_user_access(r, account.id)
|
await _check_user_access(r, account.id, conn=new_conn)
|
||||||
|
|
||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
@ -298,8 +316,9 @@ async def check_user_exists(
|
||||||
access_token: Annotated[str | None, Depends(check_access_token)],
|
access_token: Annotated[str | None, Depends(check_access_token)],
|
||||||
usr: UUID4 | None = None,
|
usr: UUID4 | None = None,
|
||||||
) -> User:
|
) -> User:
|
||||||
account = await check_account_exists(r, access_token, usr)
|
async with db.connect() as conn:
|
||||||
user = await get_user_from_account(account)
|
account = await _check_account_exists(r, access_token, usr, conn=conn)
|
||||||
|
user = await get_user_from_account(account, conn=conn)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
|
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
|
||||||
|
|
||||||
|
|
@ -330,29 +349,36 @@ async def access_token_payload(
|
||||||
return AccessTokenPayload(**payload)
|
return AccessTokenPayload(**payload)
|
||||||
|
|
||||||
|
|
||||||
async def check_admin(user: Annotated[User, Depends(check_user_exists)]) -> User:
|
async def check_admin(
|
||||||
if user.id != settings.super_user and user.id not in settings.lnbits_admin_users:
|
account: Annotated[Account, Depends(check_account_exists)],
|
||||||
|
) -> Account:
|
||||||
|
if (
|
||||||
|
account.id != settings.super_user
|
||||||
|
and account.id not in settings.lnbits_admin_users
|
||||||
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
HTTPStatus.FORBIDDEN, "User not authorized. No admin privileges."
|
HTTPStatus.FORBIDDEN, "User not authorized. No admin privileges."
|
||||||
)
|
)
|
||||||
if not user.has_password:
|
if not account.has_password:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
HTTPStatus.FORBIDDEN, "Admin users must have credentials configured."
|
HTTPStatus.FORBIDDEN, "Admin users must have credentials configured."
|
||||||
)
|
)
|
||||||
|
|
||||||
return user
|
return account
|
||||||
|
|
||||||
|
|
||||||
async def check_super_user(user: Annotated[User, Depends(check_user_exists)]) -> User:
|
async def check_super_user(
|
||||||
if user.id != settings.super_user:
|
account: Annotated[Account, Depends(check_admin)],
|
||||||
|
) -> Account:
|
||||||
|
if account.id != settings.super_user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
HTTPStatus.FORBIDDEN, "User not authorized. No super user privileges."
|
HTTPStatus.FORBIDDEN, "User not authorized. No super user privileges."
|
||||||
)
|
)
|
||||||
if not user.has_password:
|
if not account.has_password:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
HTTPStatus.FORBIDDEN, "Super user must have credentials configured."
|
HTTPStatus.FORBIDDEN, "Super user must have credentials configured."
|
||||||
)
|
)
|
||||||
return user
|
return account
|
||||||
|
|
||||||
|
|
||||||
def parse_filters(model: type[TFilterModel]):
|
def parse_filters(model: type[TFilterModel]):
|
||||||
|
|
@ -412,15 +438,17 @@ async def check_user_extension_access(
|
||||||
return SimpleStatus(success=True, message="OK")
|
return SimpleStatus(success=True, message="OK")
|
||||||
|
|
||||||
|
|
||||||
async def _check_user_access(r: Request, user_id: str):
|
async def _check_user_access(r: Request, user_id: str, conn: Connection | None = None):
|
||||||
if not settings.is_user_allowed(user_id):
|
if not settings.is_user_allowed(user_id):
|
||||||
raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.")
|
raise HTTPException(HTTPStatus.FORBIDDEN, "User not allowed.")
|
||||||
await _check_user_extension_access(user_id, r["path"])
|
await _check_user_extension_access(user_id, r["path"], conn=conn)
|
||||||
|
|
||||||
|
|
||||||
async def _check_user_extension_access(user_id: str, path: str):
|
async def _check_user_extension_access(
|
||||||
|
user_id: str, path: str, conn: Connection | None = None
|
||||||
|
):
|
||||||
ext_id = path_segments(path)[0]
|
ext_id = path_segments(path)[0]
|
||||||
status = await check_user_extension_access(user_id, ext_id)
|
status = await check_user_extension_access(user_id, ext_id, conn=conn)
|
||||||
if not status.success:
|
if not status.success:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
HTTPStatus.FORBIDDEN,
|
HTTPStatus.FORBIDDEN,
|
||||||
|
|
@ -429,12 +457,12 @@ async def _check_user_extension_access(user_id: str, path: str):
|
||||||
|
|
||||||
|
|
||||||
async def _get_account_from_token(
|
async def _get_account_from_token(
|
||||||
access_token: str, path: str, method: str
|
access_token: str, path: str, method: str, conn: Connection | None = None
|
||||||
) -> Account | None:
|
) -> Account | None:
|
||||||
try:
|
try:
|
||||||
payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"])
|
payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"])
|
||||||
return await _get_account_from_jwt_payload(
|
return await _get_account_from_jwt_payload(
|
||||||
AccessTokenPayload(**payload), path, method
|
AccessTokenPayload(**payload), path, method, conn=conn
|
||||||
)
|
)
|
||||||
|
|
||||||
except jwt.ExpiredSignatureError as exc:
|
except jwt.ExpiredSignatureError as exc:
|
||||||
|
|
@ -447,33 +475,35 @@ async def _get_account_from_token(
|
||||||
|
|
||||||
|
|
||||||
async def _get_account_from_jwt_payload(
|
async def _get_account_from_jwt_payload(
|
||||||
payload: AccessTokenPayload, path: str, method: str
|
payload: AccessTokenPayload, path: str, method: str, conn: Connection | None = None
|
||||||
) -> Account | None:
|
) -> Account | None:
|
||||||
account = None
|
account = None
|
||||||
if payload.sub:
|
if payload.sub:
|
||||||
account = await get_account_by_username(payload.sub)
|
account = await get_account_by_username(payload.sub, conn=conn)
|
||||||
elif payload.usr:
|
elif payload.usr:
|
||||||
account = await get_account(payload.usr)
|
account = await get_account(payload.usr, conn=conn)
|
||||||
elif payload.email:
|
elif payload.email:
|
||||||
account = await get_account_by_email(payload.email)
|
account = await get_account_by_email(payload.email, conn=conn)
|
||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if payload.api_token_id:
|
if payload.api_token_id:
|
||||||
await _check_account_api_access(account.id, payload.api_token_id, path, method)
|
await _check_account_api_access(
|
||||||
|
account.id, payload.api_token_id, path, method, conn=conn
|
||||||
|
)
|
||||||
|
|
||||||
return account
|
return account
|
||||||
|
|
||||||
|
|
||||||
async def _check_account_api_access(
|
async def _check_account_api_access(
|
||||||
user_id: str, token_id: str, path: str, method: str
|
user_id: str, token_id: str, path: str, method: str, conn: Connection | None = None
|
||||||
):
|
):
|
||||||
segments = path.split("/")
|
segments = path.split("/")
|
||||||
if len(segments) < 3:
|
if len(segments) < 3:
|
||||||
raise HTTPException(HTTPStatus.FORBIDDEN, "Not an API endpoint.")
|
raise HTTPException(HTTPStatus.FORBIDDEN, "Not an API endpoint.")
|
||||||
|
|
||||||
acls = await get_user_access_control_lists(user_id)
|
acls = await get_user_access_control_lists(user_id, conn=conn)
|
||||||
acl = acls.get_acl_by_token_id(token_id)
|
acl = acls.get_acl_by_token_id(token_id)
|
||||||
if not acl:
|
if not acl:
|
||||||
raise HTTPException(HTTPStatus.FORBIDDEN, "Invalid token id.")
|
raise HTTPException(HTTPStatus.FORBIDDEN, "Invalid token id.")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue