[feat] Check payment tag (#2522)

* feat: check if the payment is made for an extension that the user disabed
This commit is contained in:
Vlad Stan 2024-05-24 17:24:59 +03:00 committed by GitHub
parent 93965bc5b6
commit 7c68a02eee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 52 additions and 23 deletions

View file

@ -453,3 +453,8 @@ class BalanceDelta(BaseModel):
@property @property
def delta_msats(self): def delta_msats(self):
return self.node_balance_msats - self.lnbits_balance_msats return self.node_balance_msats - self.lnbits_balance_msats
class SimpleStatus(BaseModel):
success: bool
message: str

View file

@ -17,7 +17,11 @@ from py_vapid.utils import b64urlencode
from lnbits.core.db import db from lnbits.core.db import db
from lnbits.db import Connection from lnbits.db import Connection
from lnbits.decorators import WalletTypeInfo, require_admin_key from lnbits.decorators import (
WalletTypeInfo,
check_user_extension_access,
require_admin_key,
)
from lnbits.helpers import url_for from lnbits.helpers import url_for
from lnbits.lnurl import LnurlErrorResponse from lnbits.lnurl import LnurlErrorResponse
from lnbits.lnurl import decode as decode_lnurl from lnbits.lnurl import decode as decode_lnurl
@ -300,18 +304,13 @@ async def pay_invoice(
# do the balance check # do the balance check
wallet = await get_wallet(wallet_id, conn=conn) wallet = await get_wallet(wallet_id, conn=conn)
assert wallet, "Wallet for balancecheck could not be fetched" assert wallet, "Wallet for balancecheck could not be fetched"
if wallet.balance_msat < 0: _check_wallet_balance(wallet, fee_reserve_total_msat, internal_checking_id)
logger.debug("balance is too low, deleting temporary payment")
if ( if extra and "tag" in extra:
not internal_checking_id # check if the payment is made for an extension that the user disabled
and wallet.balance_msat > -fee_reserve_total_msat status = await check_user_extension_access(wallet.user, extra["tag"])
): if not status.success:
raise PaymentError( raise PaymentError(status.message)
f"You must reserve at least ({round(fee_reserve_total_msat/1000)}"
" sat) to cover potential routing fees.",
status="failed",
)
raise PaymentError("Insufficient balance.", status="failed")
if internal_checking_id: if internal_checking_id:
service_fee_msat = service_fee(invoice.amount_msat, internal=True) service_fee_msat = service_fee(invoice.amount_msat, internal=True)
@ -402,6 +401,22 @@ async def pay_invoice(
return invoice.payment_hash return invoice.payment_hash
def _check_wallet_balance(
wallet: Wallet,
fee_reserve_total_msat: int,
internal_checking_id: Optional[str] = None,
):
if wallet.balance_msat < 0:
logger.debug("balance is too low, deleting temporary payment")
if not internal_checking_id and wallet.balance_msat > -fee_reserve_total_msat:
raise PaymentError(
f"You must reserve at least ({round(fee_reserve_total_msat/1000)}"
" sat) to cover potential routing fees.",
status="failed",
)
raise PaymentError("Insufficient balance.", status="failed")
async def check_wallet_limits(wallet_id, conn, amount_msat): async def check_wallet_limits(wallet_id, conn, amount_msat):
await check_time_limit_between_transactions(conn, wallet_id) await check_time_limit_between_transactions(conn, wallet_id)
await check_wallet_daily_withdraw_limit(conn, wallet_id, amount_msat) await check_wallet_daily_withdraw_limit(conn, wallet_id, amount_msat)

View file

@ -18,7 +18,7 @@ from lnbits.core.crud import (
get_user_active_extensions_ids, get_user_active_extensions_ids,
get_wallet_for_key, get_wallet_for_key,
) )
from lnbits.core.models import KeyType, User, WalletTypeInfo from lnbits.core.models import KeyType, SimpleStatus, User, WalletTypeInfo
from lnbits.db import Filter, Filters, TFilterModel from lnbits.db import Filter, Filters, TFilterModel
from lnbits.settings import AuthMethods, settings from lnbits.settings import AuthMethods, settings
@ -210,27 +210,36 @@ def parse_filters(model: Type[TFilterModel]):
return dependency return dependency
async def _check_user_extension_access(user_id: str, current_path: str): async def check_user_extension_access(user_id: str, ext_id: str) -> SimpleStatus:
""" """
Check if the user has access to a particular extension. Check if the user has access to a particular extension.
Raises HTTP Forbidden if the user is not allowed. Raises HTTP Forbidden if the user is not allowed.
""" """
path = current_path.split("/")
ext_id = path[3] if path[1] == "upgrades" else path[1]
if settings.is_admin_extension(ext_id) and not settings.is_admin_user(user_id): if settings.is_admin_extension(ext_id) and not settings.is_admin_user(user_id):
raise HTTPException( return SimpleStatus(
HTTPStatus.FORBIDDEN, success=False, message=f"User not authorized for extension '{ext_id}'."
f"User not authorized for extension '{ext_id}'.",
) )
if settings.is_extension_id(ext_id): if settings.is_extension_id(ext_id):
ext_ids = await get_user_active_extensions_ids(user_id) ext_ids = await get_user_active_extensions_ids(user_id)
if ext_id not in ext_ids: if ext_id not in ext_ids:
raise HTTPException( return SimpleStatus(
HTTPStatus.FORBIDDEN, success=False, message=f"User extension '{ext_id}' not enabled."
f"User extension '{ext_id}' not enabled.",
) )
return SimpleStatus(success=True, message="OK")
async def _check_user_extension_access(user_id: str, current_path: str):
path = current_path.split("/")
ext_id = path[3] if path[1] == "upgrades" else path[1]
status = await check_user_extension_access(user_id, ext_id)
if not status.success:
raise HTTPException(
HTTPStatus.FORBIDDEN,
status.message,
)
async def _get_account_from_token(access_token): async def _get_account_from_token(access_token):
try: try: