[perf] Faster require invoice key (#3603)

This commit is contained in:
Vlad Stan 2025-12-05 10:03:51 +02:00 committed by GitHub
parent d9b045c526
commit 850087a8ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 175 additions and 40 deletions

View file

@ -3,9 +3,10 @@ from time import time
from uuid import uuid4 from uuid import uuid4
from lnbits.core.db import db from lnbits.core.db import db
from lnbits.core.models.wallets import WalletsFilters, WalletType from lnbits.core.models.wallets import BaseWallet, WalletsFilters, WalletType
from lnbits.db import Connection, Filters, Page from lnbits.db import Connection, Filters, Page
from lnbits.settings import settings from lnbits.settings import settings
from lnbits.utils.cache import cache
from ..models import Wallet from ..models import Wallet
@ -51,6 +52,11 @@ async def delete_wallet(
conn: Connection | None = None, conn: Connection | None = None,
) -> None: ) -> None:
now = int(time()) now = int(time())
cached_wallet: BaseWallet | None = cache.pop(f"auth:wallet:{wallet_id}")
if cached_wallet:
cache.pop(f"auth:x-api-key:{cached_wallet.adminkey}")
cache.pop(f"auth:x-api-key:{cached_wallet.inkey}")
await (conn or db).execute( await (conn or db).execute(
# Timestamp placeholder is safe from SQL injection (not user input) # Timestamp placeholder is safe from SQL injection (not user input)
f""" f"""
@ -240,6 +246,24 @@ async def get_wallet_for_key(
return wallet return wallet
async def get_base_wallet_for_key(
key: str,
conn: Connection | None = None,
) -> BaseWallet | None:
wallet = await (conn or db).fetchone(
"""
SELECT id, "user", wallet_type, adminkey, inkey FROM wallets
WHERE (adminkey = :key OR inkey = :key) AND deleted = false
""",
{"key": key},
BaseWallet,
)
if not wallet:
return None
return wallet
async def get_source_wallet( async def get_source_wallet(
wallet: Wallet, conn: Connection | None = None wallet: Wallet, conn: Connection | None = None
) -> Wallet | None: ) -> Wallet | None:

View file

@ -46,7 +46,7 @@ from .users import (
UserAcls, UserAcls,
UserExtra, UserExtra,
) )
from .wallets import BaseWallet, CreateWallet, KeyType, Wallet, WalletTypeInfo from .wallets import CreateWallet, KeyType, Wallet, WalletInfo, WalletTypeInfo
from .webpush import CreateWebPushSubscription, WebPushSubscription from .webpush import CreateWebPushSubscription, WebPushSubscription
__all__ = [ __all__ = [
@ -57,7 +57,6 @@ __all__ = [
"AuditEntry", "AuditEntry",
"AuditFilters", "AuditFilters",
"BalanceDelta", "BalanceDelta",
"BaseWallet",
"Callback", "Callback",
"CancelInvoice", "CancelInvoice",
"ConversionData", "ConversionData",
@ -99,6 +98,7 @@ __all__ = [
"UserAcls", "UserAcls",
"UserExtra", "UserExtra",
"Wallet", "Wallet",
"WalletInfo",
"WalletTypeInfo", "WalletTypeInfo",
"WebPushSubscription", "WebPushSubscription",
] ]

View file

@ -11,7 +11,7 @@ from lnbits.db import FilterModel
from lnbits.settings import settings from lnbits.settings import settings
class BaseWallet(BaseModel): class WalletInfo(BaseModel):
id: str id: str
name: str name: str
adminkey: str adminkey: str
@ -110,13 +110,16 @@ class WalletExtra(BaseModel):
] ]
class Wallet(BaseModel): class BaseWallet(BaseModel):
id: str id: str
user: str user: str
name: str wallet_type: str = WalletType.LIGHTNING.value
adminkey: str adminkey: str
inkey: str inkey: str
wallet_type: str = WalletType.LIGHTNING.value
class Wallet(BaseWallet):
name: str
# Must be set only for shared wallets # Must be set only for shared wallets
shared_wallet_id: str | None = None shared_wallet_id: str | None = None
deleted: bool = False deleted: bool = False
@ -230,6 +233,12 @@ class WalletTypeInfo:
wallet: Wallet wallet: Wallet
@dataclass
class BaseWalletTypeInfo:
key_type: KeyType
wallet: BaseWallet
class WalletsFilters(FilterModel): class WalletsFilters(FilterModel):
__search_fields__ = ["id", "name", "currency"] __search_fields__ = ["id", "name", "currency"]

View file

@ -8,7 +8,6 @@ from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from lnbits.core.models import ( from lnbits.core.models import (
BaseWallet,
ConversionData, ConversionData,
CreateWallet, CreateWallet,
User, User,
@ -71,7 +70,6 @@ async def health_check(
"/api/v1/wallets", "/api/v1/wallets",
name="Wallets", name="Wallets",
description="Get basic info for all of user's wallets.", description="Get basic info for all of user's wallets.",
response_model=list[BaseWallet],
) )
async def api_wallets(user: User = Depends(check_user_exists)) -> list[Wallet]: async def api_wallets(user: User = Depends(check_user_exists)) -> list[Wallet]:
return user.wallets return user.wallets

View file

@ -24,7 +24,7 @@ from lnbits.core.models.lnurl import CreateLnurlPayment, LnurlScan
from lnbits.decorators import ( from lnbits.decorators import (
WalletTypeInfo, WalletTypeInfo,
require_admin_key, require_admin_key,
require_invoice_key, require_base_invoice_key,
) )
from lnbits.helpers import check_callback_url from lnbits.helpers import check_callback_url
from lnbits.settings import settings from lnbits.settings import settings
@ -48,7 +48,7 @@ async def _handle(lnurl: str) -> LnurlResponseModel:
@lnurl_router.get( @lnurl_router.get(
"/api/v1/lnurlscan/{code}", "/api/v1/lnurlscan/{code}",
dependencies=[Depends(require_invoice_key)], dependencies=[Depends(require_base_invoice_key)],
deprecated=True, deprecated=True,
response_model=LnurlPayResponse response_model=LnurlPayResponse
| LnurlWithdrawResponse | LnurlWithdrawResponse
@ -64,7 +64,7 @@ async def api_lnurlscan(code: str) -> LnurlResponseModel:
@lnurl_router.post( @lnurl_router.post(
"/api/v1/lnurlscan", "/api/v1/lnurlscan",
dependencies=[Depends(require_invoice_key)], dependencies=[Depends(require_base_invoice_key)],
response_model=LnurlPayResponse response_model=LnurlPayResponse
| LnurlWithdrawResponse | LnurlWithdrawResponse
| LnurlAuthResponse | LnurlAuthResponse

View file

@ -36,13 +36,15 @@ from lnbits.core.models import (
) )
from lnbits.core.models.payments import UpdatePaymentLabels from lnbits.core.models.payments import UpdatePaymentLabels
from lnbits.core.models.users import AccountId from lnbits.core.models.users import AccountId
from lnbits.core.models.wallets import BaseWalletTypeInfo
from lnbits.db import Filters, Page from lnbits.db import Filters, Page
from lnbits.decorators import ( from lnbits.decorators import (
WalletTypeInfo, WalletTypeInfo,
check_account_id_exists, check_account_id_exists,
parse_filters, parse_filters,
require_admin_key, require_admin_key,
require_invoice_key, require_base_admin_key,
require_base_invoice_key,
) )
from lnbits.helpers import ( from lnbits.helpers import (
filter_dict_keys, filter_dict_keys,
@ -82,7 +84,7 @@ payment_router = APIRouter(prefix="/api/v1/payments", tags=["Payments"])
openapi_extra=generate_filter_params_openapi(PaymentFilters), openapi_extra=generate_filter_params_openapi(PaymentFilters),
) )
async def api_payments( async def api_payments(
key_info: WalletTypeInfo = Depends(require_invoice_key), key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
filters: Filters = Depends(parse_filters(PaymentFilters)), filters: Filters = Depends(parse_filters(PaymentFilters)),
): ):
await update_pending_payments(key_info.wallet.id) await update_pending_payments(key_info.wallet.id)
@ -101,7 +103,7 @@ async def api_payments(
openapi_extra=generate_filter_params_openapi(PaymentFilters), openapi_extra=generate_filter_params_openapi(PaymentFilters),
) )
async def api_payments_history( async def api_payments_history(
key_info: WalletTypeInfo = Depends(require_invoice_key), key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
group: DateTrunc = Query("day"), group: DateTrunc = Query("day"),
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)), filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
): ):
@ -178,7 +180,7 @@ async def api_payments_daily_stats(
openapi_extra=generate_filter_params_openapi(PaymentFilters), openapi_extra=generate_filter_params_openapi(PaymentFilters),
) )
async def api_payments_paginated( async def api_payments_paginated(
key_info: WalletTypeInfo = Depends(require_invoice_key), key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
recheck_pending: bool = Query( recheck_pending: bool = Query(
False, description="Force check and update of pending payments." False, description="Force check and update of pending payments."
), ),
@ -243,10 +245,10 @@ async def api_all_payments_paginated(
) )
async def api_payments_create( async def api_payments_create(
invoice_data: CreateInvoice, invoice_data: CreateInvoice,
wallet: WalletTypeInfo = Depends(require_invoice_key), key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
) -> Payment: ) -> Payment:
wallet_id = wallet.wallet.id wallet_id = key_info.wallet.id
if invoice_data.out is True and wallet.key_type == KeyType.admin: if invoice_data.out is True and key_info.key_type == KeyType.admin:
if not invoice_data.bolt11: if not invoice_data.bolt11:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
@ -274,9 +276,8 @@ async def api_payments_create(
async def api_update_payment_labels( async def api_update_payment_labels(
payment_hash: str, payment_hash: str,
data: UpdatePaymentLabels, data: UpdatePaymentLabels,
key_type: WalletTypeInfo = Depends(require_admin_key), key_type: BaseWalletTypeInfo = Depends(require_base_admin_key),
) -> SimpleStatus: ) -> SimpleStatus:
payment = await get_standalone_payment(payment_hash, wallet_id=key_type.wallet.id) payment = await get_standalone_payment(payment_hash, wallet_id=key_type.wallet.id)
if payment is None: if payment is None:
raise HTTPException(HTTPStatus.NOT_FOUND, "Payment does not exist.") raise HTTPException(HTTPStatus.NOT_FOUND, "Payment does not exist.")

View file

@ -7,10 +7,11 @@ from fastapi import (
) )
from starlette.responses import RedirectResponse from starlette.responses import RedirectResponse
from lnbits.core.models.wallets import BaseWalletTypeInfo
from lnbits.decorators import ( from lnbits.decorators import (
WalletTypeInfo, WalletTypeInfo,
require_admin_key, require_admin_key,
require_invoice_key, require_base_invoice_key,
) )
from ..crud import ( from ..crud import (
@ -50,12 +51,12 @@ async def api_create_tinyurl(
description="get a tinyurl by id", description="get a tinyurl by id",
) )
async def api_get_tinyurl( async def api_get_tinyurl(
tinyurl_id: str, wallet: WalletTypeInfo = Depends(require_invoice_key) tinyurl_id: str, key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key)
): ):
try: try:
tinyurl = await get_tinyurl(tinyurl_id) tinyurl = await get_tinyurl(tinyurl_id)
if tinyurl: if tinyurl:
if tinyurl.wallet == wallet.wallet.id: if tinyurl.wallet == key_info.wallet.id:
return tinyurl return tinyurl
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="Wrong key provided." status_code=HTTPStatus.FORBIDDEN, detail="Wrong key provided."

View file

@ -37,6 +37,7 @@ from lnbits.decorators import (
require_invoice_key, require_invoice_key,
) )
from lnbits.helpers import generate_filter_params_openapi from lnbits.helpers import generate_filter_params_openapi
from lnbits.utils.cache import cache
from ..crud import ( from ..crud import (
delete_wallet, delete_wallet,
@ -133,6 +134,10 @@ async def api_reset_wallet_keys(
if not wallet or wallet.user != account_id.id: if not wallet or wallet.user != account_id.id:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found") raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
cache.pop(f"auth:wallet:{wallet.id}")
cache.pop(f"auth:x-api-key:{wallet.adminkey}")
cache.pop(f"auth:x-api-key:{wallet.inkey}")
wallet.adminkey = uuid4().hex wallet.adminkey = uuid4().hex
wallet.inkey = uuid4().hex wallet.inkey = uuid4().hex
await update_wallet(wallet) await update_wallet(wallet)

View file

@ -19,6 +19,7 @@ from lnbits.core.crud import (
get_wallet_for_key, get_wallet_for_key,
) )
from lnbits.core.crud.users import get_user_access_control_lists from lnbits.core.crud.users import get_user_access_control_lists
from lnbits.core.crud.wallets import get_base_wallet_for_key
from lnbits.core.models import ( from lnbits.core.models import (
AccessTokenPayload, AccessTokenPayload,
Account, Account,
@ -28,6 +29,7 @@ from lnbits.core.models import (
WalletTypeInfo, WalletTypeInfo,
) )
from lnbits.core.models.users import AccountId from lnbits.core.models.users import AccountId
from lnbits.core.models.wallets import BaseWallet, BaseWalletTypeInfo
from lnbits.db import Connection, Filter, Filters, TFilterModel from lnbits.db import Connection, Filter, Filters, TFilterModel
from lnbits.helpers import normalize_path, path_segments, sha256s from lnbits.helpers import normalize_path, path_segments, sha256s
from lnbits.settings import AuthMethods, settings from lnbits.settings import AuthMethods, settings
@ -54,7 +56,7 @@ api_key_query = APIKeyQuery(
) )
class KeyChecker(SecurityBase): class BaseKeyChecker(SecurityBase):
def __init__( def __init__(
self, self,
api_key: str | None = None, api_key: str | None = None,
@ -79,8 +81,7 @@ class KeyChecker(SecurityBase):
) )
self.model: APIKey = openapi_model # type: ignore self.model: APIKey = openapi_model # type: ignore
async def __call__(self, request: Request) -> WalletTypeInfo: def _extract_key_value(self, request):
key_value = ( key_value = (
self._api_key self._api_key
if self._api_key if self._api_key
@ -93,6 +94,31 @@ class KeyChecker(SecurityBase):
detail="No Api Key provided.", detail="No Api Key provided.",
) )
return key_value
async def _extract_key_type(self, key_value: str, wallet: BaseWallet) -> KeyType:
if self.expected_key_type is KeyType.admin and wallet.adminkey != key_value:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Invalid adminkey.",
)
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
return key_type
class KeyChecker(BaseKeyChecker):
def __init__(
self,
api_key: str | None = None,
expected_key_type: KeyType | None = None,
):
super().__init__(api_key, expected_key_type)
async def __call__(self, request: Request) -> WalletTypeInfo:
key_value = self._extract_key_value(request)
wallet = await get_wallet_for_key(key_value) wallet = await get_wallet_for_key(key_value)
if not wallet: if not wallet:
@ -102,18 +128,52 @@ class KeyChecker(SecurityBase):
) )
request.scope["user_id"] = wallet.user request.scope["user_id"] = wallet.user
if self.expected_key_type is KeyType.admin and wallet.adminkey != key_value:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Invalid adminkey.",
)
await _check_user_access(request, wallet.user) await _check_user_access(request, wallet.user)
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice key_type = await self._extract_key_type(key_value, wallet)
return WalletTypeInfo(key_type, wallet) return WalletTypeInfo(key_type, wallet)
class LightKeyChecker(BaseKeyChecker):
def __init__(
self,
api_key: str | None = None,
expected_key_type: KeyType | None = None,
):
super().__init__(api_key, expected_key_type)
async def __call__(self, request: Request) -> BaseWalletTypeInfo:
key_value = self._extract_key_value(request)
cache_key = f"auth:x-api-key:{key_value}"
cache_time = settings.auth_authentication_cache_minutes * 60
if cache_time > 0:
key_info: BaseWalletTypeInfo | None = cache.get(cache_key)
if key_info:
request.scope["user_id"] = key_info.wallet.user
await _check_user_access(request, key_info.wallet.user)
return key_info
wallet = await get_base_wallet_for_key(key_value)
if not wallet:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Wallet not found.",
)
request.scope["user_id"] = wallet.user
await _check_user_access(request, wallet.user)
key_type = await self._extract_key_type(key_value, wallet)
key_info = BaseWalletTypeInfo(key_type, wallet)
if cache_time > 0:
cache.set(cache_key, key_info, expiry=cache_time)
cache.set(f"auth:wallet:{wallet.id}", wallet, expiry=cache_time)
return key_info
async def require_admin_key( async def require_admin_key(
request: Request, request: Request,
api_key_header: str = Security(api_key_header), api_key_header: str = Security(api_key_header),
@ -126,6 +186,18 @@ async def require_admin_key(
return await check(request) return await check(request)
async def require_base_admin_key(
request: Request,
api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query),
) -> BaseWalletTypeInfo:
check: LightKeyChecker = LightKeyChecker(
api_key=api_key_header or api_key_query,
expected_key_type=KeyType.admin,
)
return await check(request)
async def require_invoice_key( async def require_invoice_key(
request: Request, request: Request,
api_key_header: str = Security(api_key_header), api_key_header: str = Security(api_key_header),
@ -138,6 +210,18 @@ async def require_invoice_key(
return await check(request) return await check(request)
async def require_base_invoice_key(
request: Request,
api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query),
) -> BaseWalletTypeInfo:
check: LightKeyChecker = LightKeyChecker(
api_key=api_key_header or api_key_query,
expected_key_type=KeyType.invoice,
)
return await check(request)
async def check_access_token( async def check_access_token(
header_access_token: Annotated[str | None, Depends(oauth2_scheme)], header_access_token: Annotated[str | None, Depends(oauth2_scheme)],
cookie_access_token: Annotated[str | None, Cookie()] = None, cookie_access_token: Annotated[str | None, Cookie()] = None,

File diff suppressed because one or more lines are too long

View file

@ -574,6 +574,9 @@ window.localisation.en = {
authentication: 'Authentication', authentication: 'Authentication',
auth_token_expiry_label: 'Token expire minutes', auth_token_expiry_label: 'Token expire minutes',
auth_token_expiry_hint: 'Time in minutes until the token expires', auth_token_expiry_hint: 'Time in minutes until the token expires',
auth_authentication_cache_label: 'Cache time (minutes)',
auth_authentication_cache_hint:
'Time in minutes to cache successful authentication (0 to disable)',
auth_allowed_methods_label: 'Allowed authorization methods', auth_allowed_methods_label: 'Allowed authorization methods',
auth_allowed_methods_hint: 'Select authorization methods', auth_allowed_methods_hint: 'Select authorization methods',
auth_nostr_label: 'Nostr Request URL', auth_nostr_label: 'Nostr Request URL',

View file

@ -28,6 +28,16 @@
> >
</q-input> </q-input>
</div> </div>
<div class="col-12 col-md-6">
<q-input
filled
v-model="formData.auth_authentication_cache_minutes"
type="number"
:label="$t('auth_authentication_cache_label')"
:hint="$t('auth_authentication_cache_hint')"
>
</q-input>
</div>
<div class="col-12 col-md-6"> <div class="col-12 col-md-6">
<q-select <q-select
filled filled

View file

@ -5,7 +5,7 @@ import operator
import pytest import pytest
from lnbits.core.models import BaseWallet from lnbits.core.models import WalletInfo
from tests.wallets.fixtures.models import FundingSourceConfig, WalletTest from tests.wallets.fixtures.models import FundingSourceConfig, WalletTest
wallets_module = importlib.import_module("lnbits.wallets") wallets_module = importlib.import_module("lnbits.wallets")
@ -57,7 +57,7 @@ def build_test_id(test: WalletTest):
return f"{test.funding_source.name}.{test.function}({test.description})" return f"{test.funding_source.name}.{test.function}({test.description})"
def load_funding_source(funding_source: FundingSourceConfig) -> BaseWallet: def load_funding_source(funding_source: FundingSourceConfig) -> WalletInfo:
custom_settings = funding_source.settings custom_settings = funding_source.settings
original_settings = {} original_settings = {}
@ -67,7 +67,7 @@ def load_funding_source(funding_source: FundingSourceConfig) -> BaseWallet:
original_settings[s] = getattr(settings, s) original_settings[s] = getattr(settings, s)
setattr(settings, s, custom_settings[s]) setattr(settings, s, custom_settings[s])
fs_instance: BaseWallet = getattr(wallets_module, funding_source.wallet_class)() fs_instance: WalletInfo = getattr(wallets_module, funding_source.wallet_class)()
# rollback settings (global variable) # rollback settings (global variable)
for s in original_settings: for s in original_settings:

View file

@ -5,7 +5,7 @@ import pytest
from loguru import logger from loguru import logger
from pytest_mock.plugin import MockerFixture from pytest_mock.plugin import MockerFixture
from lnbits.core.models import BaseWallet from lnbits.core.models import WalletInfo
from tests.wallets.fixtures.models import DataObject from tests.wallets.fixtures.models import DataObject
from tests.wallets.fixtures.models import Mock as RpcMock from tests.wallets.fixtures.models import Mock as RpcMock
from tests.wallets.helpers import ( from tests.wallets.helpers import (
@ -91,7 +91,7 @@ def _check_calls(expected_calls):
func_call["spy"].assert_called_with(*args, **kwargs) func_call["spy"].assert_called_with(*args, **kwargs)
def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: BaseWallet): def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: WalletInfo):
expected_calls: dict[str, list] = {} expected_calls: dict[str, list] = {}
for mock in test_data.mocks: for mock in test_data.mocks:
client_field = getattr(wallet, mock.name) client_field = getattr(wallet, mock.name)