refactor: use CreatePayment model instead of a lot of kwargs (#2667)

- refactoring create_payment a bit to use a model instead of 10 kwargs
This commit is contained in:
dni ⚡ 2024-09-24 11:13:30 +02:00 committed by GitHub
parent 053ea20508
commit 9d7e54f6b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 84 additions and 78 deletions

View file

@ -1,7 +1,6 @@
import datetime
import json import json
from time import time from time import time
from typing import Dict, List, Literal, Optional from typing import Literal, Optional
from uuid import uuid4 from uuid import uuid4
import shortuuid import shortuuid
@ -27,6 +26,7 @@ from lnbits.settings import (
from .models import ( from .models import (
Account, Account,
AccountFilters, AccountFilters,
CreatePayment,
Payment, Payment,
PaymentFilters, PaymentFilters,
PaymentHistoryPoint, PaymentHistoryPoint,
@ -38,9 +38,6 @@ from .models import (
WebPushSubscription, WebPushSubscription,
) )
# accounts
# --------
async def create_account( async def create_account(
user_id: Optional[str] = None, user_id: Optional[str] = None,
@ -430,7 +427,7 @@ async def get_installed_extension(
async def get_installed_extensions( async def get_installed_extensions(
active: Optional[bool] = None, active: Optional[bool] = None,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> List[InstallableExtension]: ) -> list[InstallableExtension]:
rows = await (conn or db).fetchall( rows = await (conn or db).fetchall(
"SELECT * FROM installed_extensions", "SELECT * FROM installed_extensions",
) )
@ -456,7 +453,7 @@ async def get_user_extension(
async def get_user_extensions( async def get_user_extensions(
user_id: str, conn: Optional[Connection] = None user_id: str, conn: Optional[Connection] = None
) -> List[UserExtension]: ) -> list[UserExtension]:
rows = await (conn or db).fetchall( rows = await (conn or db).fetchall(
""" """
SELECT extension, active, extra as _extra FROM extensions SELECT extension, active, extra as _extra FROM extensions
@ -481,7 +478,7 @@ async def update_user_extension(
async def get_user_active_extensions_ids( async def get_user_active_extensions_ids(
user_id: str, conn: Optional[Connection] = None user_id: str, conn: Optional[Connection] = None
) -> List[str]: ) -> list[str]:
rows = await (conn or db).fetchall( rows = await (conn or db).fetchall(
""" """
SELECT extension FROM extensions WHERE "user" = :user AND active SELECT extension FROM extensions WHERE "user" = :user AND active
@ -650,7 +647,7 @@ async def get_wallet(
return Wallet(**row) if row else None return Wallet(**row) if row else None
async def get_wallets(user_id: str, conn: Optional[Connection] = None) -> List[Wallet]: async def get_wallets(user_id: str, conn: Optional[Connection] = None) -> list[Wallet]:
rows = await (conn or db).fetchall( rows = await (conn or db).fetchall(
""" """
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0) SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0)
@ -772,7 +769,7 @@ async def get_payments_paginated(
"wallet": wallet_id, "wallet": wallet_id,
"time": since, "time": since,
} }
clause: List[str] = [] clause: list[str] = []
if since is not None: if since is not None:
clause.append(f"time > {db.timestamp_placeholder('time')}") clause.append(f"time > {db.timestamp_placeholder('time')}")
@ -882,23 +879,13 @@ async def delete_expired_invoices(
async def create_payment( async def create_payment(
*,
wallet_id: str,
checking_id: str, checking_id: str,
payment_request: str, data: CreatePayment,
payment_hash: str,
amount: int,
memo: str,
fee: int = 0,
status: PaymentState = PaymentState.PENDING, status: PaymentState = PaymentState.PENDING,
preimage: Optional[str] = None,
expiry: Optional[datetime.datetime] = None,
extra: Optional[Dict] = None,
webhook: Optional[str] = None,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> Payment: ) -> Payment:
# we don't allow the creation of the same invoice twice # we don't allow the creation of the same invoice twice
# note: this can be removed if the db uniquess constarints are set appropriately # note: this can be removed if the db uniqueness constraints are set appropriately
previous_payment = await get_standalone_payment(checking_id, conn=conn) previous_payment = await get_standalone_payment(checking_id, conn=conn)
assert previous_payment is None, "Payment already exists" assert previous_payment is None, "Payment already exists"
@ -912,27 +899,27 @@ async def create_payment(
:amount, :status, :memo, :fee, :extra, :webhook, {expiry_ph}, :pending) :amount, :status, :memo, :fee, :extra, :webhook, {expiry_ph}, :pending)
""", """,
{ {
"wallet": wallet_id, "wallet": data.wallet_id,
"checking_id": checking_id, "checking_id": checking_id,
"bolt11": payment_request, "bolt11": data.payment_request,
"hash": payment_hash, "hash": data.payment_hash,
"preimage": preimage, "preimage": data.preimage,
"amount": amount, "amount": data.amount,
"status": status.value, "status": status.value,
"memo": memo, "memo": data.memo,
"fee": fee, "fee": data.fee,
"extra": ( "extra": (
json.dumps(extra) json.dumps(data.extra)
if extra and extra != {} and isinstance(extra, dict) if data.extra and data.extra != {} and isinstance(data.extra, dict)
else None else None
), ),
"webhook": webhook, "webhook": data.webhook,
"expiry": expiry if expiry else None, "expiry": data.expiry if data.expiry else None,
"pending": False, # TODO: remove this in next release "pending": False, # TODO: remove this in next release
}, },
) )
new_payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn) new_payment = await get_wallet_payment(data.wallet_id, data.payment_hash, conn=conn)
assert new_payment, "Newly created payment couldn't be retrieved" assert new_payment, "Newly created payment couldn't be retrieved"
return new_payment return new_payment
@ -963,7 +950,7 @@ async def update_payment_details(
"preimage": preimage, "preimage": preimage,
} }
set_clause: List[str] = [] set_clause: list[str] = []
if new_checking_id is not None: if new_checking_id is not None:
set_clause.append("checking_id = :checking_id") set_clause.append("checking_id = :checking_id")
if status is not None: if status is not None:
@ -1023,7 +1010,7 @@ async def get_payments_history(
wallet_id: Optional[str] = None, wallet_id: Optional[str] = None,
group: DateTrunc = "day", group: DateTrunc = "day",
filters: Optional[Filters] = None, filters: Optional[Filters] = None,
) -> List[PaymentHistoryPoint]: ) -> list[PaymentHistoryPoint]:
if not filters: if not filters:
filters = Filters() filters = Filters()
@ -1249,7 +1236,7 @@ async def get_tinyurl(tinyurl_id: str) -> Optional[TinyURL]:
return TinyURL.from_row(row) if row else None return TinyURL.from_row(row) if row else None
async def get_tinyurl_by_url(url: str) -> List[TinyURL]: async def get_tinyurl_by_url(url: str) -> list[TinyURL]:
rows = await db.fetchall( rows = await db.fetchall(
"SELECT * FROM tiny_url WHERE url = :url", "SELECT * FROM tiny_url WHERE url = :url",
{"url": url}, {"url": url},
@ -1301,7 +1288,7 @@ async def get_webpush_subscription(
async def get_webpush_subscriptions_for_user( async def get_webpush_subscriptions_for_user(
user: str, user: str,
) -> List[WebPushSubscription]: ) -> list[WebPushSubscription]:
rows = await db.fetchall( rows = await db.fetchall(
"""SELECT * FROM webpush_subscriptions WHERE "user" = :user""", """SELECT * FROM webpush_subscriptions WHERE "user" = :user""",
{"user": user}, {"user": user},

View file

@ -212,6 +212,19 @@ class PaymentState(str, Enum):
return self.value return self.value
class CreatePayment(BaseModel):
wallet_id: str
payment_request: str
payment_hash: str
amount: int
memo: str
preimage: Optional[str] = None
expiry: Optional[datetime.datetime] = None
extra: Optional[dict] = None
webhook: Optional[str] = None
fee: int = 0
class Payment(FromRowModel): class Payment(FromRowModel):
status: str status: str
# TODO should be removed in the future, backward compatibility # TODO should be removed in the future, backward compatibility
@ -225,7 +238,7 @@ class Payment(FromRowModel):
preimage: str preimage: str
payment_hash: str payment_hash: str
expiry: Optional[float] expiry: Optional[float]
extra: dict = {} extra: Optional[dict]
wallet_id: str wallet_id: str
webhook: Optional[str] webhook: Optional[str]
webhook_status: Optional[int] webhook_status: Optional[int]

View file

@ -1,10 +1,9 @@
import asyncio import asyncio
import datetime
import json import json
import time import time
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, TypedDict from typing import Optional
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
from uuid import UUID, uuid4 from uuid import UUID, uuid4
@ -68,16 +67,24 @@ from .crud import (
update_user_extension, update_user_extension,
) )
from .helpers import to_valid_user_id from .helpers import to_valid_user_id
from .models import BalanceDelta, Payment, PaymentState, User, UserConfig, Wallet from .models import (
BalanceDelta,
CreatePayment,
Payment,
PaymentState,
User,
UserConfig,
Wallet,
)
async def calculate_fiat_amounts( async def calculate_fiat_amounts(
amount: float, amount: float,
wallet_id: str, wallet_id: str,
currency: Optional[str] = None, currency: Optional[str] = None,
extra: Optional[Dict] = None, extra: Optional[dict] = None,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> Tuple[int, Optional[Dict]]: ) -> tuple[int, Optional[dict]]:
wallet = await get_wallet(wallet_id, conn=conn) wallet = await get_wallet(wallet_id, conn=conn)
assert wallet, "invalid wallet_id" assert wallet, "invalid wallet_id"
wallet_currency = wallet.currency or settings.lnbits_default_accounting_currency wallet_currency = wallet.currency or settings.lnbits_default_accounting_currency
@ -118,11 +125,11 @@ async def create_invoice(
description_hash: Optional[bytes] = None, description_hash: Optional[bytes] = None,
unhashed_description: Optional[bytes] = None, unhashed_description: Optional[bytes] = None,
expiry: Optional[int] = None, expiry: Optional[int] = None,
extra: Optional[Dict] = None, extra: Optional[dict] = None,
webhook: Optional[str] = None, webhook: Optional[str] = None,
internal: Optional[bool] = False, internal: Optional[bool] = False,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> Tuple[str, str]: ) -> tuple[str, str]:
if not amount > 0: if not amount > 0:
raise InvoiceError("Amountless invoices not supported.", status="failed") raise InvoiceError("Amountless invoices not supported.", status="failed")
@ -167,17 +174,20 @@ async def create_invoice(
invoice = bolt11_decode(payment_request) invoice = bolt11_decode(payment_request)
amount_msat = 1000 * amount_sat create_payment_model = CreatePayment(
await create_payment(
wallet_id=wallet_id, wallet_id=wallet_id,
checking_id=checking_id,
payment_request=payment_request, payment_request=payment_request,
payment_hash=invoice.payment_hash, payment_hash=invoice.payment_hash,
amount=amount_msat, amount=amount_sat * 1000,
expiry=invoice.expiry_date, expiry=invoice.expiry_date,
memo=memo, memo=memo,
extra=extra, extra=extra,
webhook=webhook, webhook=webhook,
)
await create_payment(
checking_id=checking_id,
data=create_payment_model,
conn=conn, conn=conn,
) )
@ -189,7 +199,7 @@ async def pay_invoice(
wallet_id: str, wallet_id: str,
payment_request: str, payment_request: str,
max_sat: Optional[int] = None, max_sat: Optional[int] = None,
extra: Optional[Dict] = None, extra: Optional[dict] = None,
description: str = "", description: str = "",
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> str: ) -> str:
@ -223,17 +233,7 @@ async def pay_invoice(
invoice.amount_msat / 1000, wallet_id, extra=extra, conn=conn invoice.amount_msat / 1000, wallet_id, extra=extra, conn=conn
) )
# put all parameters that don't change here create_payment_model = CreatePayment(
class PaymentKwargs(TypedDict):
wallet_id: str
payment_request: str
payment_hash: str
amount: int
memo: str
expiry: Optional[datetime.datetime]
extra: Optional[Dict]
payment_kwargs: PaymentKwargs = PaymentKwargs(
wallet_id=wallet_id, wallet_id=wallet_id,
payment_request=payment_request, payment_request=payment_request,
payment_hash=invoice.payment_hash, payment_hash=invoice.payment_hash,
@ -252,9 +252,6 @@ async def pay_invoice(
# (pending only) # (pending only)
internal_checking_id = await check_internal(invoice.payment_hash, conn=conn) internal_checking_id = await check_internal(invoice.payment_hash, conn=conn)
if internal_checking_id: if internal_checking_id:
fee_reserve_total_msat = fee_reserve_total(
invoice.amount_msat, internal=True
)
# perform additional checks on the internal payment # perform additional checks on the internal payment
# the payment hash is not enough to make sure that this is the same invoice # the payment hash is not enough to make sure that this is the same invoice
internal_invoice = await get_standalone_payment( internal_invoice = await get_standalone_payment(
@ -269,16 +266,23 @@ async def pay_invoice(
logger.debug(f"creating temporary internal payment with id {internal_id}") logger.debug(f"creating temporary internal payment with id {internal_id}")
# create a new payment from this wallet # create a new payment from this wallet
fee_reserve_total_msat = fee_reserve_total(
invoice.amount_msat, internal=True
)
create_payment_model.fee = abs(fee_reserve_total_msat)
new_payment = await create_payment( new_payment = await create_payment(
checking_id=internal_id, checking_id=internal_id,
fee=0 + abs(fee_reserve_total_msat), data=create_payment_model,
status=PaymentState.SUCCESS, status=PaymentState.SUCCESS,
conn=conn, conn=conn,
**payment_kwargs,
) )
else: else:
new_payment = await _create_external_payment( new_payment = await _create_external_payment(
temp_id, invoice.amount_msat, conn=conn, **payment_kwargs temp_id=temp_id,
amount_msat=invoice.amount_msat,
data=create_payment_model,
conn=conn,
) )
# do the balance check # do the balance check
@ -377,14 +381,16 @@ async def pay_invoice(
# credit service fee wallet # credit service fee wallet
if settings.lnbits_service_fee_wallet and service_fee_msat: if settings.lnbits_service_fee_wallet and service_fee_msat:
new_payment = await create_payment( create_payment_model = CreatePayment(
wallet_id=settings.lnbits_service_fee_wallet, wallet_id=settings.lnbits_service_fee_wallet,
fee=0,
amount=abs(service_fee_msat),
memo="Service fee",
checking_id="service_fee" + temp_id,
payment_request=payment_request, payment_request=payment_request,
payment_hash=invoice.payment_hash, payment_hash=invoice.payment_hash,
amount=abs(service_fee_msat),
memo="Service fee",
)
new_payment = await create_payment(
checking_id=f"service_fee_{temp_id}",
data=create_payment_model,
status=PaymentState.SUCCESS, status=PaymentState.SUCCESS,
) )
return invoice.payment_hash return invoice.payment_hash
@ -393,8 +399,8 @@ async def pay_invoice(
async def _create_external_payment( async def _create_external_payment(
temp_id: str, temp_id: str,
amount_msat: MilliSatoshi, amount_msat: MilliSatoshi,
data: CreatePayment,
conn: Optional[Connection], conn: Optional[Connection],
**payment_kwargs,
) -> Payment: ) -> Payment:
fee_reserve_total_msat = fee_reserve_total(amount_msat, internal=False) fee_reserve_total_msat = fee_reserve_total(amount_msat, internal=False)
@ -428,11 +434,11 @@ async def _create_external_payment(
# create a temporary payment here so we can check if # create a temporary payment here so we can check if
# the balance is enough in the next step # the balance is enough in the next step
try: try:
data.fee = -abs(fee_reserve_total_msat)
new_payment = await create_payment( new_payment = await create_payment(
checking_id=temp_id, checking_id=temp_id,
fee=-abs(fee_reserve_total_msat), data=data,
conn=conn, conn=conn,
**payment_kwargs,
) )
return new_payment return new_payment
except Exception as exc: except Exception as exc:
@ -514,7 +520,7 @@ async def redeem_lnurl_withdraw(
wallet_id: str, wallet_id: str,
lnurl_request: str, lnurl_request: str,
memo: Optional[str] = None, memo: Optional[str] = None,
extra: Optional[Dict] = None, extra: Optional[dict] = None,
wait_seconds: int = 0, wait_seconds: int = 0,
conn: Optional[Connection] = None, conn: Optional[Connection] = None,
) -> None: ) -> None:
@ -853,7 +859,7 @@ async def create_user_account(
class WebsocketConnectionManager: class WebsocketConnectionManager:
def __init__(self) -> None: def __init__(self) -> None:
self.active_connections: List[WebSocket] = [] self.active_connections: list[WebSocket] = []
async def connect(self, websocket: WebSocket, item_id: str): async def connect(self, websocket: WebSocket, item_id: str):
logger.debug(f"Websocket connected to {item_id}") logger.debug(f"Websocket connected to {item_id}")