[perf] Faster require invoice key (#3603)
This commit is contained in:
parent
d9b045c526
commit
850087a8ec
14 changed files with 175 additions and 40 deletions
|
|
@ -3,9 +3,10 @@ from time import time
|
|||
from uuid import uuid4
|
||||
|
||||
from lnbits.core.db import db
|
||||
from lnbits.core.models.wallets import WalletsFilters, WalletType
|
||||
from lnbits.core.models.wallets import BaseWallet, WalletsFilters, WalletType
|
||||
from lnbits.db import Connection, Filters, Page
|
||||
from lnbits.settings import settings
|
||||
from lnbits.utils.cache import cache
|
||||
|
||||
from ..models import Wallet
|
||||
|
||||
|
|
@ -51,6 +52,11 @@ async def delete_wallet(
|
|||
conn: Connection | None = None,
|
||||
) -> None:
|
||||
now = int(time())
|
||||
cached_wallet: BaseWallet | None = cache.pop(f"auth:wallet:{wallet_id}")
|
||||
if cached_wallet:
|
||||
cache.pop(f"auth:x-api-key:{cached_wallet.adminkey}")
|
||||
cache.pop(f"auth:x-api-key:{cached_wallet.inkey}")
|
||||
|
||||
await (conn or db).execute(
|
||||
# Timestamp placeholder is safe from SQL injection (not user input)
|
||||
f"""
|
||||
|
|
@ -240,6 +246,24 @@ async def get_wallet_for_key(
|
|||
return wallet
|
||||
|
||||
|
||||
async def get_base_wallet_for_key(
|
||||
key: str,
|
||||
conn: Connection | None = None,
|
||||
) -> BaseWallet | None:
|
||||
wallet = await (conn or db).fetchone(
|
||||
"""
|
||||
SELECT id, "user", wallet_type, adminkey, inkey FROM wallets
|
||||
WHERE (adminkey = :key OR inkey = :key) AND deleted = false
|
||||
""",
|
||||
{"key": key},
|
||||
BaseWallet,
|
||||
)
|
||||
if not wallet:
|
||||
return None
|
||||
|
||||
return wallet
|
||||
|
||||
|
||||
async def get_source_wallet(
|
||||
wallet: Wallet, conn: Connection | None = None
|
||||
) -> Wallet | None:
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ from .users import (
|
|||
UserAcls,
|
||||
UserExtra,
|
||||
)
|
||||
from .wallets import BaseWallet, CreateWallet, KeyType, Wallet, WalletTypeInfo
|
||||
from .wallets import CreateWallet, KeyType, Wallet, WalletInfo, WalletTypeInfo
|
||||
from .webpush import CreateWebPushSubscription, WebPushSubscription
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -57,7 +57,6 @@ __all__ = [
|
|||
"AuditEntry",
|
||||
"AuditFilters",
|
||||
"BalanceDelta",
|
||||
"BaseWallet",
|
||||
"Callback",
|
||||
"CancelInvoice",
|
||||
"ConversionData",
|
||||
|
|
@ -99,6 +98,7 @@ __all__ = [
|
|||
"UserAcls",
|
||||
"UserExtra",
|
||||
"Wallet",
|
||||
"WalletInfo",
|
||||
"WalletTypeInfo",
|
||||
"WebPushSubscription",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from lnbits.db import FilterModel
|
|||
from lnbits.settings import settings
|
||||
|
||||
|
||||
class BaseWallet(BaseModel):
|
||||
class WalletInfo(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
adminkey: str
|
||||
|
|
@ -110,13 +110,16 @@ class WalletExtra(BaseModel):
|
|||
]
|
||||
|
||||
|
||||
class Wallet(BaseModel):
|
||||
class BaseWallet(BaseModel):
|
||||
id: str
|
||||
user: str
|
||||
name: str
|
||||
wallet_type: str = WalletType.LIGHTNING.value
|
||||
adminkey: str
|
||||
inkey: str
|
||||
wallet_type: str = WalletType.LIGHTNING.value
|
||||
|
||||
|
||||
class Wallet(BaseWallet):
|
||||
name: str
|
||||
# Must be set only for shared wallets
|
||||
shared_wallet_id: str | None = None
|
||||
deleted: bool = False
|
||||
|
|
@ -230,6 +233,12 @@ class WalletTypeInfo:
|
|||
wallet: Wallet
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseWalletTypeInfo:
|
||||
key_type: KeyType
|
||||
wallet: BaseWallet
|
||||
|
||||
|
||||
class WalletsFilters(FilterModel):
|
||||
__search_fields__ = ["id", "name", "currency"]
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from fastapi import APIRouter, Depends
|
|||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from lnbits.core.models import (
|
||||
BaseWallet,
|
||||
ConversionData,
|
||||
CreateWallet,
|
||||
User,
|
||||
|
|
@ -71,7 +70,6 @@ async def health_check(
|
|||
"/api/v1/wallets",
|
||||
name="Wallets",
|
||||
description="Get basic info for all of user's wallets.",
|
||||
response_model=list[BaseWallet],
|
||||
)
|
||||
async def api_wallets(user: User = Depends(check_user_exists)) -> list[Wallet]:
|
||||
return user.wallets
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from lnbits.core.models.lnurl import CreateLnurlPayment, LnurlScan
|
|||
from lnbits.decorators import (
|
||||
WalletTypeInfo,
|
||||
require_admin_key,
|
||||
require_invoice_key,
|
||||
require_base_invoice_key,
|
||||
)
|
||||
from lnbits.helpers import check_callback_url
|
||||
from lnbits.settings import settings
|
||||
|
|
@ -48,7 +48,7 @@ async def _handle(lnurl: str) -> LnurlResponseModel:
|
|||
|
||||
@lnurl_router.get(
|
||||
"/api/v1/lnurlscan/{code}",
|
||||
dependencies=[Depends(require_invoice_key)],
|
||||
dependencies=[Depends(require_base_invoice_key)],
|
||||
deprecated=True,
|
||||
response_model=LnurlPayResponse
|
||||
| LnurlWithdrawResponse
|
||||
|
|
@ -64,7 +64,7 @@ async def api_lnurlscan(code: str) -> LnurlResponseModel:
|
|||
|
||||
@lnurl_router.post(
|
||||
"/api/v1/lnurlscan",
|
||||
dependencies=[Depends(require_invoice_key)],
|
||||
dependencies=[Depends(require_base_invoice_key)],
|
||||
response_model=LnurlPayResponse
|
||||
| LnurlWithdrawResponse
|
||||
| LnurlAuthResponse
|
||||
|
|
|
|||
|
|
@ -36,13 +36,15 @@ from lnbits.core.models import (
|
|||
)
|
||||
from lnbits.core.models.payments import UpdatePaymentLabels
|
||||
from lnbits.core.models.users import AccountId
|
||||
from lnbits.core.models.wallets import BaseWalletTypeInfo
|
||||
from lnbits.db import Filters, Page
|
||||
from lnbits.decorators import (
|
||||
WalletTypeInfo,
|
||||
check_account_id_exists,
|
||||
parse_filters,
|
||||
require_admin_key,
|
||||
require_invoice_key,
|
||||
require_base_admin_key,
|
||||
require_base_invoice_key,
|
||||
)
|
||||
from lnbits.helpers import (
|
||||
filter_dict_keys,
|
||||
|
|
@ -82,7 +84,7 @@ payment_router = APIRouter(prefix="/api/v1/payments", tags=["Payments"])
|
|||
openapi_extra=generate_filter_params_openapi(PaymentFilters),
|
||||
)
|
||||
async def api_payments(
|
||||
key_info: WalletTypeInfo = Depends(require_invoice_key),
|
||||
key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
|
||||
filters: Filters = Depends(parse_filters(PaymentFilters)),
|
||||
):
|
||||
await update_pending_payments(key_info.wallet.id)
|
||||
|
|
@ -101,7 +103,7 @@ async def api_payments(
|
|||
openapi_extra=generate_filter_params_openapi(PaymentFilters),
|
||||
)
|
||||
async def api_payments_history(
|
||||
key_info: WalletTypeInfo = Depends(require_invoice_key),
|
||||
key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
|
||||
group: DateTrunc = Query("day"),
|
||||
filters: Filters[PaymentFilters] = Depends(parse_filters(PaymentFilters)),
|
||||
):
|
||||
|
|
@ -178,7 +180,7 @@ async def api_payments_daily_stats(
|
|||
openapi_extra=generate_filter_params_openapi(PaymentFilters),
|
||||
)
|
||||
async def api_payments_paginated(
|
||||
key_info: WalletTypeInfo = Depends(require_invoice_key),
|
||||
key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
|
||||
recheck_pending: bool = Query(
|
||||
False, description="Force check and update of pending payments."
|
||||
),
|
||||
|
|
@ -243,10 +245,10 @@ async def api_all_payments_paginated(
|
|||
)
|
||||
async def api_payments_create(
|
||||
invoice_data: CreateInvoice,
|
||||
wallet: WalletTypeInfo = Depends(require_invoice_key),
|
||||
key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key),
|
||||
) -> Payment:
|
||||
wallet_id = wallet.wallet.id
|
||||
if invoice_data.out is True and wallet.key_type == KeyType.admin:
|
||||
wallet_id = key_info.wallet.id
|
||||
if invoice_data.out is True and key_info.key_type == KeyType.admin:
|
||||
if not invoice_data.bolt11:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
|
|
@ -274,9 +276,8 @@ async def api_payments_create(
|
|||
async def api_update_payment_labels(
|
||||
payment_hash: str,
|
||||
data: UpdatePaymentLabels,
|
||||
key_type: WalletTypeInfo = Depends(require_admin_key),
|
||||
key_type: BaseWalletTypeInfo = Depends(require_base_admin_key),
|
||||
) -> SimpleStatus:
|
||||
|
||||
payment = await get_standalone_payment(payment_hash, wallet_id=key_type.wallet.id)
|
||||
if payment is None:
|
||||
raise HTTPException(HTTPStatus.NOT_FOUND, "Payment does not exist.")
|
||||
|
|
|
|||
|
|
@ -7,10 +7,11 @@ from fastapi import (
|
|||
)
|
||||
from starlette.responses import RedirectResponse
|
||||
|
||||
from lnbits.core.models.wallets import BaseWalletTypeInfo
|
||||
from lnbits.decorators import (
|
||||
WalletTypeInfo,
|
||||
require_admin_key,
|
||||
require_invoice_key,
|
||||
require_base_invoice_key,
|
||||
)
|
||||
|
||||
from ..crud import (
|
||||
|
|
@ -50,12 +51,12 @@ async def api_create_tinyurl(
|
|||
description="get a tinyurl by id",
|
||||
)
|
||||
async def api_get_tinyurl(
|
||||
tinyurl_id: str, wallet: WalletTypeInfo = Depends(require_invoice_key)
|
||||
tinyurl_id: str, key_info: BaseWalletTypeInfo = Depends(require_base_invoice_key)
|
||||
):
|
||||
try:
|
||||
tinyurl = await get_tinyurl(tinyurl_id)
|
||||
if tinyurl:
|
||||
if tinyurl.wallet == wallet.wallet.id:
|
||||
if tinyurl.wallet == key_info.wallet.id:
|
||||
return tinyurl
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN, detail="Wrong key provided."
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ from lnbits.decorators import (
|
|||
require_invoice_key,
|
||||
)
|
||||
from lnbits.helpers import generate_filter_params_openapi
|
||||
from lnbits.utils.cache import cache
|
||||
|
||||
from ..crud import (
|
||||
delete_wallet,
|
||||
|
|
@ -133,6 +134,10 @@ async def api_reset_wallet_keys(
|
|||
if not wallet or wallet.user != account_id.id:
|
||||
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
|
||||
|
||||
cache.pop(f"auth:wallet:{wallet.id}")
|
||||
cache.pop(f"auth:x-api-key:{wallet.adminkey}")
|
||||
cache.pop(f"auth:x-api-key:{wallet.inkey}")
|
||||
|
||||
wallet.adminkey = uuid4().hex
|
||||
wallet.inkey = uuid4().hex
|
||||
await update_wallet(wallet)
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from lnbits.core.crud import (
|
|||
get_wallet_for_key,
|
||||
)
|
||||
from lnbits.core.crud.users import get_user_access_control_lists
|
||||
from lnbits.core.crud.wallets import get_base_wallet_for_key
|
||||
from lnbits.core.models import (
|
||||
AccessTokenPayload,
|
||||
Account,
|
||||
|
|
@ -28,6 +29,7 @@ from lnbits.core.models import (
|
|||
WalletTypeInfo,
|
||||
)
|
||||
from lnbits.core.models.users import AccountId
|
||||
from lnbits.core.models.wallets import BaseWallet, BaseWalletTypeInfo
|
||||
from lnbits.db import Connection, Filter, Filters, TFilterModel
|
||||
from lnbits.helpers import normalize_path, path_segments, sha256s
|
||||
from lnbits.settings import AuthMethods, settings
|
||||
|
|
@ -54,7 +56,7 @@ api_key_query = APIKeyQuery(
|
|||
)
|
||||
|
||||
|
||||
class KeyChecker(SecurityBase):
|
||||
class BaseKeyChecker(SecurityBase):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
|
|
@ -79,8 +81,7 @@ class KeyChecker(SecurityBase):
|
|||
)
|
||||
self.model: APIKey = openapi_model # type: ignore
|
||||
|
||||
async def __call__(self, request: Request) -> WalletTypeInfo:
|
||||
|
||||
def _extract_key_value(self, request):
|
||||
key_value = (
|
||||
self._api_key
|
||||
if self._api_key
|
||||
|
|
@ -93,6 +94,31 @@ class KeyChecker(SecurityBase):
|
|||
detail="No Api Key provided.",
|
||||
)
|
||||
|
||||
return key_value
|
||||
|
||||
async def _extract_key_type(self, key_value: str, wallet: BaseWallet) -> KeyType:
|
||||
if self.expected_key_type is KeyType.admin and wallet.adminkey != key_value:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Invalid adminkey.",
|
||||
)
|
||||
|
||||
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
|
||||
return key_type
|
||||
|
||||
|
||||
class KeyChecker(BaseKeyChecker):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
expected_key_type: KeyType | None = None,
|
||||
):
|
||||
super().__init__(api_key, expected_key_type)
|
||||
|
||||
async def __call__(self, request: Request) -> WalletTypeInfo:
|
||||
key_value = self._extract_key_value(request)
|
||||
|
||||
wallet = await get_wallet_for_key(key_value)
|
||||
|
||||
if not wallet:
|
||||
|
|
@ -102,18 +128,52 @@ class KeyChecker(SecurityBase):
|
|||
)
|
||||
|
||||
request.scope["user_id"] = wallet.user
|
||||
if self.expected_key_type is KeyType.admin and wallet.adminkey != key_value:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Invalid adminkey.",
|
||||
)
|
||||
|
||||
await _check_user_access(request, wallet.user)
|
||||
|
||||
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
|
||||
key_type = await self._extract_key_type(key_value, wallet)
|
||||
return WalletTypeInfo(key_type, wallet)
|
||||
|
||||
|
||||
class LightKeyChecker(BaseKeyChecker):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
expected_key_type: KeyType | None = None,
|
||||
):
|
||||
super().__init__(api_key, expected_key_type)
|
||||
|
||||
async def __call__(self, request: Request) -> BaseWalletTypeInfo:
|
||||
key_value = self._extract_key_value(request)
|
||||
cache_key = f"auth:x-api-key:{key_value}"
|
||||
cache_time = settings.auth_authentication_cache_minutes * 60
|
||||
|
||||
if cache_time > 0:
|
||||
key_info: BaseWalletTypeInfo | None = cache.get(cache_key)
|
||||
if key_info:
|
||||
request.scope["user_id"] = key_info.wallet.user
|
||||
await _check_user_access(request, key_info.wallet.user)
|
||||
return key_info
|
||||
|
||||
wallet = await get_base_wallet_for_key(key_value)
|
||||
|
||||
if not wallet:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="Wallet not found.",
|
||||
)
|
||||
request.scope["user_id"] = wallet.user
|
||||
await _check_user_access(request, wallet.user)
|
||||
|
||||
key_type = await self._extract_key_type(key_value, wallet)
|
||||
key_info = BaseWalletTypeInfo(key_type, wallet)
|
||||
|
||||
if cache_time > 0:
|
||||
cache.set(cache_key, key_info, expiry=cache_time)
|
||||
cache.set(f"auth:wallet:{wallet.id}", wallet, expiry=cache_time)
|
||||
return key_info
|
||||
|
||||
|
||||
async def require_admin_key(
|
||||
request: Request,
|
||||
api_key_header: str = Security(api_key_header),
|
||||
|
|
@ -126,6 +186,18 @@ async def require_admin_key(
|
|||
return await check(request)
|
||||
|
||||
|
||||
async def require_base_admin_key(
|
||||
request: Request,
|
||||
api_key_header: str = Security(api_key_header),
|
||||
api_key_query: str = Security(api_key_query),
|
||||
) -> BaseWalletTypeInfo:
|
||||
check: LightKeyChecker = LightKeyChecker(
|
||||
api_key=api_key_header or api_key_query,
|
||||
expected_key_type=KeyType.admin,
|
||||
)
|
||||
return await check(request)
|
||||
|
||||
|
||||
async def require_invoice_key(
|
||||
request: Request,
|
||||
api_key_header: str = Security(api_key_header),
|
||||
|
|
@ -138,6 +210,18 @@ async def require_invoice_key(
|
|||
return await check(request)
|
||||
|
||||
|
||||
async def require_base_invoice_key(
|
||||
request: Request,
|
||||
api_key_header: str = Security(api_key_header),
|
||||
api_key_query: str = Security(api_key_query),
|
||||
) -> BaseWalletTypeInfo:
|
||||
check: LightKeyChecker = LightKeyChecker(
|
||||
api_key=api_key_header or api_key_query,
|
||||
expected_key_type=KeyType.invoice,
|
||||
)
|
||||
return await check(request)
|
||||
|
||||
|
||||
async def check_access_token(
|
||||
header_access_token: Annotated[str | None, Depends(oauth2_scheme)],
|
||||
cookie_access_token: Annotated[str | None, Cookie()] = None,
|
||||
|
|
|
|||
2
lnbits/static/bundle.min.js
vendored
2
lnbits/static/bundle.min.js
vendored
File diff suppressed because one or more lines are too long
|
|
@ -574,6 +574,9 @@ window.localisation.en = {
|
|||
authentication: 'Authentication',
|
||||
auth_token_expiry_label: 'Token expire minutes',
|
||||
auth_token_expiry_hint: 'Time in minutes until the token expires',
|
||||
auth_authentication_cache_label: 'Cache time (minutes)',
|
||||
auth_authentication_cache_hint:
|
||||
'Time in minutes to cache successful authentication (0 to disable)',
|
||||
auth_allowed_methods_label: 'Allowed authorization methods',
|
||||
auth_allowed_methods_hint: 'Select authorization methods',
|
||||
auth_nostr_label: 'Nostr Request URL',
|
||||
|
|
|
|||
|
|
@ -28,6 +28,16 @@
|
|||
>
|
||||
</q-input>
|
||||
</div>
|
||||
<div class="col-12 col-md-6">
|
||||
<q-input
|
||||
filled
|
||||
v-model="formData.auth_authentication_cache_minutes"
|
||||
type="number"
|
||||
:label="$t('auth_authentication_cache_label')"
|
||||
:hint="$t('auth_authentication_cache_hint')"
|
||||
>
|
||||
</q-input>
|
||||
</div>
|
||||
<div class="col-12 col-md-6">
|
||||
<q-select
|
||||
filled
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import operator
|
|||
|
||||
import pytest
|
||||
|
||||
from lnbits.core.models import BaseWallet
|
||||
from lnbits.core.models import WalletInfo
|
||||
from tests.wallets.fixtures.models import FundingSourceConfig, WalletTest
|
||||
|
||||
wallets_module = importlib.import_module("lnbits.wallets")
|
||||
|
|
@ -57,7 +57,7 @@ def build_test_id(test: WalletTest):
|
|||
return f"{test.funding_source.name}.{test.function}({test.description})"
|
||||
|
||||
|
||||
def load_funding_source(funding_source: FundingSourceConfig) -> BaseWallet:
|
||||
def load_funding_source(funding_source: FundingSourceConfig) -> WalletInfo:
|
||||
custom_settings = funding_source.settings
|
||||
original_settings = {}
|
||||
|
||||
|
|
@ -67,7 +67,7 @@ def load_funding_source(funding_source: FundingSourceConfig) -> BaseWallet:
|
|||
original_settings[s] = getattr(settings, s)
|
||||
setattr(settings, s, custom_settings[s])
|
||||
|
||||
fs_instance: BaseWallet = getattr(wallets_module, funding_source.wallet_class)()
|
||||
fs_instance: WalletInfo = getattr(wallets_module, funding_source.wallet_class)()
|
||||
|
||||
# rollback settings (global variable)
|
||||
for s in original_settings:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
from loguru import logger
|
||||
from pytest_mock.plugin import MockerFixture
|
||||
|
||||
from lnbits.core.models import BaseWallet
|
||||
from lnbits.core.models import WalletInfo
|
||||
from tests.wallets.fixtures.models import DataObject
|
||||
from tests.wallets.fixtures.models import Mock as RpcMock
|
||||
from tests.wallets.helpers import (
|
||||
|
|
@ -91,7 +91,7 @@ def _check_calls(expected_calls):
|
|||
func_call["spy"].assert_called_with(*args, **kwargs)
|
||||
|
||||
|
||||
def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: BaseWallet):
|
||||
def _spy_mocks(mocker: MockerFixture, test_data: WalletTest, wallet: WalletInfo):
|
||||
expected_calls: dict[str, list] = {}
|
||||
for mock in test_data.mocks:
|
||||
client_field = getattr(wallet, mock.name)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue