[feat] Check payment tag (#2522)
* feat: check if the payment is made for an extension that the user disabed
This commit is contained in:
parent
93965bc5b6
commit
7c68a02eee
3 changed files with 52 additions and 23 deletions
|
|
@ -453,3 +453,8 @@ class BalanceDelta(BaseModel):
|
||||||
@property
|
@property
|
||||||
def delta_msats(self):
|
def delta_msats(self):
|
||||||
return self.node_balance_msats - self.lnbits_balance_msats
|
return self.node_balance_msats - self.lnbits_balance_msats
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleStatus(BaseModel):
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,11 @@ from py_vapid.utils import b64urlencode
|
||||||
|
|
||||||
from lnbits.core.db import db
|
from lnbits.core.db import db
|
||||||
from lnbits.db import Connection
|
from lnbits.db import Connection
|
||||||
from lnbits.decorators import WalletTypeInfo, require_admin_key
|
from lnbits.decorators import (
|
||||||
|
WalletTypeInfo,
|
||||||
|
check_user_extension_access,
|
||||||
|
require_admin_key,
|
||||||
|
)
|
||||||
from lnbits.helpers import url_for
|
from lnbits.helpers import url_for
|
||||||
from lnbits.lnurl import LnurlErrorResponse
|
from lnbits.lnurl import LnurlErrorResponse
|
||||||
from lnbits.lnurl import decode as decode_lnurl
|
from lnbits.lnurl import decode as decode_lnurl
|
||||||
|
|
@ -300,18 +304,13 @@ async def pay_invoice(
|
||||||
# do the balance check
|
# do the balance check
|
||||||
wallet = await get_wallet(wallet_id, conn=conn)
|
wallet = await get_wallet(wallet_id, conn=conn)
|
||||||
assert wallet, "Wallet for balancecheck could not be fetched"
|
assert wallet, "Wallet for balancecheck could not be fetched"
|
||||||
if wallet.balance_msat < 0:
|
_check_wallet_balance(wallet, fee_reserve_total_msat, internal_checking_id)
|
||||||
logger.debug("balance is too low, deleting temporary payment")
|
|
||||||
if (
|
if extra and "tag" in extra:
|
||||||
not internal_checking_id
|
# check if the payment is made for an extension that the user disabled
|
||||||
and wallet.balance_msat > -fee_reserve_total_msat
|
status = await check_user_extension_access(wallet.user, extra["tag"])
|
||||||
):
|
if not status.success:
|
||||||
raise PaymentError(
|
raise PaymentError(status.message)
|
||||||
f"You must reserve at least ({round(fee_reserve_total_msat/1000)}"
|
|
||||||
" sat) to cover potential routing fees.",
|
|
||||||
status="failed",
|
|
||||||
)
|
|
||||||
raise PaymentError("Insufficient balance.", status="failed")
|
|
||||||
|
|
||||||
if internal_checking_id:
|
if internal_checking_id:
|
||||||
service_fee_msat = service_fee(invoice.amount_msat, internal=True)
|
service_fee_msat = service_fee(invoice.amount_msat, internal=True)
|
||||||
|
|
@ -402,6 +401,22 @@ async def pay_invoice(
|
||||||
return invoice.payment_hash
|
return invoice.payment_hash
|
||||||
|
|
||||||
|
|
||||||
|
def _check_wallet_balance(
|
||||||
|
wallet: Wallet,
|
||||||
|
fee_reserve_total_msat: int,
|
||||||
|
internal_checking_id: Optional[str] = None,
|
||||||
|
):
|
||||||
|
if wallet.balance_msat < 0:
|
||||||
|
logger.debug("balance is too low, deleting temporary payment")
|
||||||
|
if not internal_checking_id and wallet.balance_msat > -fee_reserve_total_msat:
|
||||||
|
raise PaymentError(
|
||||||
|
f"You must reserve at least ({round(fee_reserve_total_msat/1000)}"
|
||||||
|
" sat) to cover potential routing fees.",
|
||||||
|
status="failed",
|
||||||
|
)
|
||||||
|
raise PaymentError("Insufficient balance.", status="failed")
|
||||||
|
|
||||||
|
|
||||||
async def check_wallet_limits(wallet_id, conn, amount_msat):
|
async def check_wallet_limits(wallet_id, conn, amount_msat):
|
||||||
await check_time_limit_between_transactions(conn, wallet_id)
|
await check_time_limit_between_transactions(conn, wallet_id)
|
||||||
await check_wallet_daily_withdraw_limit(conn, wallet_id, amount_msat)
|
await check_wallet_daily_withdraw_limit(conn, wallet_id, amount_msat)
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ from lnbits.core.crud import (
|
||||||
get_user_active_extensions_ids,
|
get_user_active_extensions_ids,
|
||||||
get_wallet_for_key,
|
get_wallet_for_key,
|
||||||
)
|
)
|
||||||
from lnbits.core.models import KeyType, User, WalletTypeInfo
|
from lnbits.core.models import KeyType, SimpleStatus, User, WalletTypeInfo
|
||||||
from lnbits.db import Filter, Filters, TFilterModel
|
from lnbits.db import Filter, Filters, TFilterModel
|
||||||
from lnbits.settings import AuthMethods, settings
|
from lnbits.settings import AuthMethods, settings
|
||||||
|
|
||||||
|
|
@ -210,27 +210,36 @@ def parse_filters(model: Type[TFilterModel]):
|
||||||
return dependency
|
return dependency
|
||||||
|
|
||||||
|
|
||||||
async def _check_user_extension_access(user_id: str, current_path: str):
|
async def check_user_extension_access(user_id: str, ext_id: str) -> SimpleStatus:
|
||||||
"""
|
"""
|
||||||
Check if the user has access to a particular extension.
|
Check if the user has access to a particular extension.
|
||||||
Raises HTTP Forbidden if the user is not allowed.
|
Raises HTTP Forbidden if the user is not allowed.
|
||||||
"""
|
"""
|
||||||
path = current_path.split("/")
|
|
||||||
ext_id = path[3] if path[1] == "upgrades" else path[1]
|
|
||||||
if settings.is_admin_extension(ext_id) and not settings.is_admin_user(user_id):
|
if settings.is_admin_extension(ext_id) and not settings.is_admin_user(user_id):
|
||||||
raise HTTPException(
|
return SimpleStatus(
|
||||||
HTTPStatus.FORBIDDEN,
|
success=False, message=f"User not authorized for extension '{ext_id}'."
|
||||||
f"User not authorized for extension '{ext_id}'.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if settings.is_extension_id(ext_id):
|
if settings.is_extension_id(ext_id):
|
||||||
ext_ids = await get_user_active_extensions_ids(user_id)
|
ext_ids = await get_user_active_extensions_ids(user_id)
|
||||||
if ext_id not in ext_ids:
|
if ext_id not in ext_ids:
|
||||||
raise HTTPException(
|
return SimpleStatus(
|
||||||
HTTPStatus.FORBIDDEN,
|
success=False, message=f"User extension '{ext_id}' not enabled."
|
||||||
f"User extension '{ext_id}' not enabled.",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return SimpleStatus(success=True, message="OK")
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_user_extension_access(user_id: str, current_path: str):
|
||||||
|
path = current_path.split("/")
|
||||||
|
ext_id = path[3] if path[1] == "upgrades" else path[1]
|
||||||
|
status = await check_user_extension_access(user_id, ext_id)
|
||||||
|
if not status.success:
|
||||||
|
raise HTTPException(
|
||||||
|
HTTPStatus.FORBIDDEN,
|
||||||
|
status.message,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _get_account_from_token(access_token):
|
async def _get_account_from_token(access_token):
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue