[perf] Faster require invoice key (#3603)

This commit is contained in:
Vlad Stan 2025-12-05 10:03:51 +02:00 committed by GitHub
parent d9b045c526
commit 850087a8ec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 175 additions and 40 deletions

View file

@ -3,9 +3,10 @@ from time import time
from uuid import uuid4
from lnbits.core.db import db
from lnbits.core.models.wallets import WalletsFilters, WalletType
from lnbits.core.models.wallets import BaseWallet, WalletsFilters, WalletType
from lnbits.db import Connection, Filters, Page
from lnbits.settings import settings
from lnbits.utils.cache import cache
from ..models import Wallet
@ -51,6 +52,11 @@ async def delete_wallet(
conn: Connection | None = None,
) -> None:
now = int(time())
cached_wallet: BaseWallet | None = cache.pop(f"auth:wallet:{wallet_id}")
if cached_wallet:
cache.pop(f"auth:x-api-key:{cached_wallet.adminkey}")
cache.pop(f"auth:x-api-key:{cached_wallet.inkey}")
await (conn or db).execute(
# Timestamp placeholder is safe from SQL injection (not user input)
f"""
@ -240,6 +246,24 @@ async def get_wallet_for_key(
return wallet
async def get_base_wallet_for_key(
key: str,
conn: Connection | None = None,
) -> BaseWallet | None:
wallet = await (conn or db).fetchone(
"""
SELECT id, "user", wallet_type, adminkey, inkey FROM wallets
WHERE (adminkey = :key OR inkey = :key) AND deleted = false
""",
{"key": key},
BaseWallet,
)
if not wallet:
return None
return wallet
async def get_source_wallet(
wallet: Wallet, conn: Connection | None = None
) -> Wallet | None:

View file

@ -46,7 +46,7 @@ from .users import (
UserAcls,
UserExtra,
)
from .wallets import BaseWallet, CreateWallet, KeyType, Wallet, WalletTypeInfo
from .wallets import CreateWallet, KeyType, Wallet, WalletInfo, WalletTypeInfo
from .webpush import CreateWebPushSubscription, WebPushSubscription
__all__ = [
@ -57,7 +57,6 @@ __all__ = [
"AuditEntry",
"AuditFilters",
"BalanceDelta",
"BaseWallet",
"Callback",
"CancelInvoice",
"ConversionData",
@ -99,6 +98,7 @@ __all__ = [
"UserAcls",
"UserExtra",
"Wallet",
"WalletInfo",
"WalletTypeInfo",
"WebPushSubscription",
]

View file

@ -11,7 +11,7 @@ from lnbits.db import FilterModel
from lnbits.settings import settings
class BaseWallet(BaseModel):
class WalletInfo(BaseModel):
id: str
name: str
adminkey: str
@ -110,13 +110,16 @@ class WalletExtra(BaseModel):
]
class Wallet(BaseModel):
class BaseWallet(BaseModel):
id: str
user: str
name: str
wallet_type: str = WalletType.LIGHTNING.value
adminkey: str
inkey: str
wallet_type: str = WalletType.LIGHTNING.value
class Wallet(BaseWallet):
name: str
# Must be set only for shared wallets
shared_wallet_id: str | None = None
deleted: bool = False
@ -230,6 +233,12 @@ class WalletTypeInfo:
wallet: Wallet
@dataclass
class BaseWalletTypeInfo:
key_type: KeyType
wallet: BaseWallet
class WalletsFilters(FilterModel):
__search_fields__ = ["id", "name", "currency"]

View file

@ -8,7 +8,6 @@ from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from lnbits.core.models import (
BaseWallet,
ConversionData,
CreateWallet,
User,
@ -71,7 +70,6 @@ async def health_check(
"/api/v1/wallets",
name="Wallets",
description="Get basic info for all of user's wallets.",
response_model=list[BaseWallet],
)
async def api_wallets(user: User = Depends(check_user_exists)) -> list[Wallet]:
return user.wallets

View file

@ -24,7 +24,7 @@ from lnbits.core.models.lnurl import CreateLnurlPayment, LnurlScan
from lnbits.decorators import (
WalletTypeInfo,
require_admin_key,
require_invoice_key,
require_base_invoice_key,
)
from lnbits.helpers import check_callback_url
from lnbits.settings import settings
@ -48,7 +48,7 @@ async def _handle(lnurl: str) -> LnurlResponseModel:
@lnurl_router.get(
"/api/v1/lnurlscan/{code}",
dependencies=[Depends(require_invoice_key)],
dependencies=[Depends(require_base_invoice_key)],
deprecated=True,
response_model=LnurlPayResponse
| LnurlWithdrawResponse
@ -64,7 +64,7 @@ async def api_lnurlscan(code: str) -> LnurlResponseModel:
@lnurl_router.post(
"/api/v1/lnurlscan",
dependencies=[Depends(require_invoice_key)],
dependencies=[Depends(require_base_invoice_key)],
response_model=LnurlPayResponse
| LnurlWithdrawResponse
| LnurlAuthResponse

View file

@ -36,13 +36,15 @@ from lnbits.core.models import (
)
from lnbits.core.models.payments import UpdatePaymentLabels
from lnbits.core.models.users import AccountId
from lnbits.core.models.wallets import BaseWalletTypeInfo
from lnbits.db import Filters, Page
from lnbits.decorators import (
WalletTypeInfo,
check_account_id_exists,
parse_filters,
require_admin_key,
require_invoice_key,
require_base_admin_key,
require_base_invoice_key,
)
from lnbits.helpers import (
filter_dict_keys,
@ -82,7 +84,7 @@ payment_router = APIRouter(prefix="/api/v1/payments", tags=["Payments"])
openapi_extra=generate_filter_params_openapi(PaymentFilters),
)
async def api_payments(
key_info: WalletTypeInfo = Depends(require_invoice_key),
key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
filters: Filters = Depends(parse_filters(PaymentFilters)),
):
await update_pending_payments(key_info.wallet.id)
@ -101,7 +103,7 @@ async def api_payments(
openapi_extra=generate_filter_params_openapi(PaymentFilters),
)
async def api_payments_history(
key_info: WalletTypeInfo = Depends(require_invoice_key),
key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
group: DateTrunc = Query("day"),
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
):
@ -178,7 +180,7 @@ async def api_payments_daily_stats(
openapi_extra=generate_filter_params_openapi(PaymentFilters),
)
async def api_payments_paginated(
key_info: WalletTypeInfo = Depends(require_invoice_key),
key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
recheck_pending: bool = Query(
False, description="Force check and update of pending payments."
),
@ -243,10 +245,10 @@ async def api_all_payments_paginated(
)
async def api_payments_create(
invoice_data: CreateInvoice,
wallet: WalletTypeInfo = Depends(require_invoice_key),
key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
) -> Payment:
wallet_id = wallet.wallet.id
if invoice_data.out is True and wallet.key_type == KeyType.admin:
wallet_id = key_info.wallet.id
if invoice_data.out is True and key_info.key_type == KeyType.admin:
if not invoice_data.bolt11:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
@ -274,9 +276,8 @@ async def api_payments_create(
async def api_update_payment_labels(
payment_hash: str,
data: UpdatePaymentLabels,
key_type: WalletTypeInfo = Depends(require_admin_key),
key_type: BaseWalletTypeInfo = Depends(require_base_admin_key),
) -> SimpleStatus:
payment = await get_standalone_payment(payment_hash, wallet_id=key_type.wallet.id)
if payment is None:
raise HTTPException(HTTPStatus.NOT_FOUND, "Payment does not exist.")

View file

@ -7,10 +7,11 @@ from fastapi import (
)
from starlette.responses import RedirectResponse
from lnbits.core.models.wallets import BaseWalletTypeInfo
from lnbits.decorators import (
WalletTypeInfo,
require_admin_key,
require_invoice_key,
require_base_invoice_key,
)
from ..crud import (
@ -50,12 +51,12 @@ async def api_create_tinyurl(
description="get a tinyurl by id",
)
async def api_get_tinyurl(
tinyurl_id: str, wallet: WalletTypeInfo = Depends(require_invoice_key)
tinyurl_id: str, key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key)
):
try:
tinyurl = await get_tinyurl(tinyurl_id)
if tinyurl:
if tinyurl.wallet == wallet.wallet.id:
if tinyurl.wallet == key_info.wallet.id:
return tinyurl
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN, detail="Wrong key provided."

View file

@ -37,6 +37,7 @@ from lnbits.decorators import (
require_invoice_key,
)
from lnbits.helpers import generate_filter_params_openapi
from lnbits.utils.cache import cache
from ..crud import (
delete_wallet,
@ -133,6 +134,10 @@ async def api_reset_wallet_keys(
if not wallet or wallet.user != account_id.id:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
cache.pop(f"auth:wallet:{wallet.id}")
cache.pop(f"auth:x-api-key:{wallet.adminkey}")
cache.pop(f"auth:x-api-key:{wallet.inkey}")
wallet.adminkey = uuid4().hex
wallet.inkey = uuid4().hex
await update_wallet(wallet)

View file

@ -19,6 +19,7 @@ from lnbits.core.crud import (
get_wallet_for_key,
)
from lnbits.core.crud.users import get_user_access_control_lists
from lnbits.core.crud.wallets import get_base_wallet_for_key
from lnbits.core.models import (
AccessTokenPayload,
Account,
@ -28,6 +29,7 @@ from lnbits.core.models import (
WalletTypeInfo,
)
from lnbits.core.models.users import AccountId
from lnbits.core.models.wallets import BaseWallet, BaseWalletTypeInfo
from lnbits.db import Connection, Filter, Filters, TFilterModel
from lnbits.helpers import normalize_path, path_segments, sha256s
from lnbits.settings import AuthMethods, settings
@ -54,7 +56,7 @@ api_key_query = APIKeyQuery(
)
class KeyChecker(SecurityBase):
class BaseKeyChecker(SecurityBase):
def __init__(
self,
api_key: str | None = None,
@ -79,8 +81,7 @@ class KeyChecker(SecurityBase):
)
self.model: APIKey = openapi_model # type: ignore
async def __call__(self, request: Request) -> WalletTypeInfo:
def _extract_key_value(self, request):
key_value = (
self._api_key
if self._api_key
@ -93,6 +94,31 @@ class KeyChecker(SecurityBase):
detail="No Api Key provided.",
)
return key_value
async def _extract_key_type(self, key_value: str, wallet: BaseWallet) -> KeyType:
if self.expected_key_type is KeyType.admin and wallet.adminkey != key_value:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Invalid adminkey.",
)
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
return key_type
class KeyChecker(BaseKeyChecker):
def __init__(
self,
api_key: str | None = None,
expected_key_type: KeyType | None = None,
):
super().__init__(api_key, expected_key_type)
async def __call__(self, request: Request) -> WalletTypeInfo:
key_value = self._extract_key_value(request)
wallet = await get_wallet_for_key(key_value)
if not wallet:
@ -102,18 +128,52 @@ class KeyChecker(SecurityBase):
)
request.scope["user_id"] = wallet.user
if self.expected_key_type is KeyType.admin and wallet.adminkey != key_value:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Invalid adminkey.",
)
await _check_user_access(request, wallet.user)
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
key_type = await self._extract_key_type(key_value, wallet)
return WalletTypeInfo(key_type, wallet)
class LightKeyChecker(BaseKeyChecker):
def __init__(
self,
api_key: str | None = None,
expected_key_type: KeyType | None = None,
):
super().__init__(api_key, expected_key_type)
async def __call__(self, request: Request) -> BaseWalletTypeInfo:
key_value = self._extract_key_value(request)
cache_key = f"auth:x-api-key:{key_value}"
cache_time = settings.auth_authentication_cache_minutes * 60
if cache_time > 0:
key_info: BaseWalletTypeInfo | None = cache.get(cache_key)
if key_info:
request.scope["user_id"] = key_info.wallet.user
await _check_user_access(request, key_info.wallet.user)
return key_info
wallet = await get_base_wallet_for_key(key_value)
if not wallet:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Wallet not found.",
)
request.scope["user_id"] = wallet.user
await _check_user_access(request, wallet.user)
key_type = await self._extract_key_type(key_value, wallet)
key_info = BaseWalletTypeInfo(key_type, wallet)
if cache_time > 0:
cache.set(cache_key, key_info, expiry=cache_time)
cache.set(f"auth:wallet:{wallet.id}", wallet, expiry=cache_time)
return key_info
async def require_admin_key(
request: Request,
api_key_header: str = Security(api_key_header),
@ -126,6 +186,18 @@ async def require_admin_key(
return await check(request)
async def require_base_admin_key(
request: Request,
api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query),
) -> BaseWalletTypeInfo:
check: LightKeyChecker = LightKeyChecker(
api_key=api_key_header or api_key_query,
expected_key_type=KeyType.admin,
)
return await check(request)
async def require_invoice_key(
request: Request,
api_key_header: str = Security(api_key_header),
@ -138,6 +210,18 @@ async def require_invoice_key(
return await check(request)
async def require_base_invoice_key(
request: Request,
api_key_header: str = Security(api_key_header),
api_key_query: str = Security(api_key_query),
) -> BaseWalletTypeInfo:
check: LightKeyChecker = LightKeyChecker(
api_key=api_key_header or api_key_query,
expected_key_type=KeyType.invoice,
)
return await check(request)
async def check_access_token(
header_access_token: Annotated[str | None, Depends(oauth2_scheme)],
cookie_access_token: Annotated[str | None, Cookie()] = None,

File diff suppressed because one or more lines are too long

View file

@ -574,6 +574,9 @@ window.localisation.en = {
authentication: 'Authentication',
auth_token_expiry_label: 'Token expire minutes',
auth_token_expiry_hint: 'Time in minutes until the token expires',
auth_authentication_cache_label: 'Cache time (minutes)',
auth_authentication_cache_hint:
'Time in minutes to cache successful authentication (0 to disable)',
auth_allowed_methods_label: 'Allowed authorization methods',
auth_allowed_methods_hint: 'Select authorization methods',
auth_nostr_label: 'Nostr Request URL',

View file

@ -28,6 +28,16 @@
>
</q-input>
</div>
<div class="col-12 col-md-6">
<q-input
filled
v-model="formData.auth_authentication_cache_minutes"
type="number"
:label="$t('auth_authentication_cache_label')"
:hint="$t('auth_authentication_cache_hint')"
>
</q-input>
</div>
<div class="col-12 col-md-6">
<q-select
filled

View file

@ -5,7 +5,7 @@ import operator
import pytest
from lnbits.core.models import BaseWallet
from lnbits.core.models import WalletInfo
from tests.wallets.fixtures.models import FundingSourceConfig, WalletTest
wallets_module = importlib.import_module("lnbits.wallets")
@ -57,7 +57,7 @@ def build_test_id(test: WalletTest):
return f"{test.funding_source.name}.{test.function}({test.description})"
def load_funding_source(funding_source: FundingSourceConfig) -> BaseWallet:
def load_funding_source(funding_source: FundingSourceConfig) -> WalletInfo:
custom_settings = funding_source.settings
original_settings = {}
@ -67,7 +67,7 @@ def load_funding_source(funding_source: FundingSourceConfig) -> BaseWallet:
original_settings[s] = getattr(settings, s)
setattr(settings, s, custom_settings[s])
fs_instance: BaseWallet = getattr(wallets_module, funding_source.wallet_class)()
fs_instance: WalletInfo = getattr(wallets_module, funding_source.wallet_class)()
# rollback settings (global variable)
for s in original_settings:

View file

@ -5,7 +5,7 @@ import pytest
from loguru import logger
from pytest_mock.plugin import MockerFixture
from lnbits.core.models import BaseWallet
from lnbits.core.models import WalletInfo
from tests.wallets.fixtures.models import DataObject
from tests.wallets.fixtures.models import Mock as RpcMock
from tests.wallets.helpers import (
@ -91,7 +91,7 @@ def _check_calls(expected_calls):
func_call["spy"].assert_called_with(*args, **kwargs)
def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: BaseWallet):
def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: WalletInfo):
expected_calls: dict[str, list] = {}
for mock in test_data.mocks:
client_field = getattr(wallet, mock.name)