refactor: use CreatePayment model instead of a lot of kwargs (#2667)

- refactoring create_payment a bit to use a model instead of 10 kwargs
This commit is contained in:
dni ⚡ 2024-09-24 11:13:30 +02:00 committed by GitHub
parent 053ea20508
commit 9d7e54f6b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 84 additions and 78 deletions

View file

@ -1,7 +1,6 @@
import datetime
import json
from time import time
from typing import Dict, List, Literal, Optional
from typing import Literal, Optional
from uuid import uuid4
import shortuuid
@ -27,6 +26,7 @@ from lnbits.settings import (
from .models import (
Account,
AccountFilters,
CreatePayment,
Payment,
PaymentFilters,
PaymentHistoryPoint,
@ -38,9 +38,6 @@ from .models import (
WebPushSubscription,
)
# accounts
# --------
async def create_account(
user_id: Optional[str] = None,
@ -430,7 +427,7 @@ async def get_installed_extension(
async def get_installed_extensions(
active: Optional[bool] = None,
conn: Optional[Connection] = None,
) -> List[InstallableExtension]:
) -> list[InstallableExtension]:
rows = await (conn or db).fetchall(
"SELECT * FROM installed_extensions",
)
@ -456,7 +453,7 @@ async def get_user_extension(
async def get_user_extensions(
user_id: str, conn: Optional[Connection] = None
) -> List[UserExtension]:
) -> list[UserExtension]:
rows = await (conn or db).fetchall(
"""
SELECT extension, active, extra as _extra FROM extensions
@ -481,7 +478,7 @@ async def update_user_extension(
async def get_user_active_extensions_ids(
user_id: str, conn: Optional[Connection] = None
) -> List[str]:
) -> list[str]:
rows = await (conn or db).fetchall(
"""
SELECT extension FROM extensions WHERE "user" = :user AND active
@ -650,7 +647,7 @@ async def get_wallet(
return Wallet(**row) if row else None
async def get_wallets(user_id: str, conn: Optional[Connection] = None) -> List[Wallet]:
async def get_wallets(user_id: str, conn: Optional[Connection] = None) -> list[Wallet]:
rows = await (conn or db).fetchall(
"""
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0)
@ -772,7 +769,7 @@ async def get_payments_paginated(
"wallet": wallet_id,
"time": since,
}
clause: List[str] = []
clause: list[str] = []
if since is not None:
clause.append(f"time > {db.timestamp_placeholder('time')}")
@ -882,23 +879,13 @@ async def delete_expired_invoices(
async def create_payment(
*,
wallet_id: str,
checking_id: str,
payment_request: str,
payment_hash: str,
amount: int,
memo: str,
fee: int = 0,
data: CreatePayment,
status: PaymentState = PaymentState.PENDING,
preimage: Optional[str] = None,
expiry: Optional[datetime.datetime] = None,
extra: Optional[Dict] = None,
webhook: Optional[str] = None,
conn: Optional[Connection] = None,
) -> Payment:
# we don't allow the creation of the same invoice twice
# note: this can be removed if the db uniquess constarints are set appropriately
# note: this can be removed if the db uniqueness constraints are set appropriately
previous_payment = await get_standalone_payment(checking_id, conn=conn)
assert previous_payment is None, "Payment already exists"
@ -912,27 +899,27 @@ async def create_payment(
:amount, :status, :memo, :fee, :extra, :webhook, {expiry_ph}, :pending)
""",
{
"wallet": wallet_id,
"wallet": data.wallet_id,
"checking_id": checking_id,
"bolt11": payment_request,
"hash": payment_hash,
"preimage": preimage,
"amount": amount,
"bolt11": data.payment_request,
"hash": data.payment_hash,
"preimage": data.preimage,
"amount": data.amount,
"status": status.value,
"memo": memo,
"fee": fee,
"memo": data.memo,
"fee": data.fee,
"extra": (
json.dumps(extra)
if extra and extra != {} and isinstance(extra, dict)
json.dumps(data.extra)
if data.extra and data.extra != {} and isinstance(data.extra, dict)
else None
),
"webhook": webhook,
"expiry": expiry if expiry else None,
"webhook": data.webhook,
"expiry": data.expiry if data.expiry else None,
"pending": False, # TODO: remove this in next release
},
)
new_payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn)
new_payment = await get_wallet_payment(data.wallet_id, data.payment_hash, conn=conn)
assert new_payment, "Newly created payment couldn't be retrieved"
return new_payment
@ -963,7 +950,7 @@ async def update_payment_details(
"preimage": preimage,
}
set_clause: List[str] = []
set_clause: list[str] = []
if new_checking_id is not None:
set_clause.append("checking_id = :checking_id")
if status is not None:
@ -1023,7 +1010,7 @@ async def get_payments_history(
wallet_id: Optional[str] = None,
group: DateTrunc = "day",
filters: Optional[Filters] = None,
) -> List[PaymentHistoryPoint]:
) -> list[PaymentHistoryPoint]:
if not filters:
filters = Filters()
@ -1249,7 +1236,7 @@ async def get_tinyurl(tinyurl_id: str) -> Optional[TinyURL]:
return TinyURL.from_row(row) if row else None
async def get_tinyurl_by_url(url: str) -> List[TinyURL]:
async def get_tinyurl_by_url(url: str) -> list[TinyURL]:
rows = await db.fetchall(
"SELECT * FROM tiny_url WHERE url = :url",
{"url": url},
@ -1301,7 +1288,7 @@ async def get_webpush_subscription(
async def get_webpush_subscriptions_for_user(
user: str,
) -> List[WebPushSubscription]:
) -> list[WebPushSubscription]:
rows = await db.fetchall(
"""SELECT * FROM webpush_subscriptions WHERE "user" = :user""",
{"user": user},

View file

@ -212,6 +212,19 @@ class PaymentState(str, Enum):
return self.value
class CreatePayment(BaseModel):
wallet_id: str
payment_request: str
payment_hash: str
amount: int
memo: str
preimage: Optional[str] = None
expiry: Optional[datetime.datetime] = None
extra: Optional[dict] = None
webhook: Optional[str] = None
fee: int = 0
class Payment(FromRowModel):
status: str
# TODO should be removed in the future, backward compatibility
@ -225,7 +238,7 @@ class Payment(FromRowModel):
preimage: str
payment_hash: str
expiry: Optional[float]
extra: dict = {}
extra: Optional[dict]
wallet_id: str
webhook: Optional[str]
webhook_status: Optional[int]

View file

@ -1,10 +1,9 @@
import asyncio
import datetime
import json
import time
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional, Tuple, TypedDict
from typing import Optional
from urllib.parse import parse_qs, urlparse
from uuid import UUID, uuid4
@ -68,16 +67,24 @@ from .crud import (
update_user_extension,
)
from .helpers import to_valid_user_id
from .models import BalanceDelta, Payment, PaymentState, User, UserConfig, Wallet
from .models import (
BalanceDelta,
CreatePayment,
Payment,
PaymentState,
User,
UserConfig,
Wallet,
)
async def calculate_fiat_amounts(
amount: float,
wallet_id: str,
currency: Optional[str] = None,
extra: Optional[Dict] = None,
extra: Optional[dict] = None,
conn: Optional[Connection] = None,
) -> Tuple[int, Optional[Dict]]:
) -> tuple[int, Optional[dict]]:
wallet = await get_wallet(wallet_id, conn=conn)
assert wallet, "invalid wallet_id"
wallet_currency = wallet.currency or settings.lnbits_default_accounting_currency
@ -118,11 +125,11 @@ async def create_invoice(
description_hash: Optional[bytes] = None,
unhashed_description: Optional[bytes] = None,
expiry: Optional[int] = None,
extra: Optional[Dict] = None,
extra: Optional[dict] = None,
webhook: Optional[str] = None,
internal: Optional[bool] = False,
conn: Optional[Connection] = None,
) -> Tuple[str, str]:
) -> tuple[str, str]:
if not amount > 0:
raise InvoiceError("Amountless invoices not supported.", status="failed")
@ -167,17 +174,20 @@ async def create_invoice(
invoice = bolt11_decode(payment_request)
amount_msat = 1000 * amount_sat
await create_payment(
create_payment_model = CreatePayment(
wallet_id=wallet_id,
checking_id=checking_id,
payment_request=payment_request,
payment_hash=invoice.payment_hash,
amount=amount_msat,
amount=amount_sat * 1000,
expiry=invoice.expiry_date,
memo=memo,
extra=extra,
webhook=webhook,
)
await create_payment(
checking_id=checking_id,
data=create_payment_model,
conn=conn,
)
@ -189,7 +199,7 @@ async def pay_invoice(
wallet_id: str,
payment_request: str,
max_sat: Optional[int] = None,
extra: Optional[Dict] = None,
extra: Optional[dict] = None,
description: str = "",
conn: Optional[Connection] = None,
) -> str:
@ -223,17 +233,7 @@ async def pay_invoice(
invoice.amount_msat / 1000, wallet_id, extra=extra, conn=conn
)
# put all parameters that don't change here
class PaymentKwargs(TypedDict):
wallet_id: str
payment_request: str
payment_hash: str
amount: int
memo: str
expiry: Optional[datetime.datetime]
extra: Optional[Dict]
payment_kwargs: PaymentKwargs = PaymentKwargs(
create_payment_model = CreatePayment(
wallet_id=wallet_id,
payment_request=payment_request,
payment_hash=invoice.payment_hash,
@ -252,9 +252,6 @@ async def pay_invoice(
# (pending only)
internal_checking_id = await check_internal(invoice.payment_hash, conn=conn)
if internal_checking_id:
fee_reserve_total_msat = fee_reserve_total(
invoice.amount_msat, internal=True
)
# perform additional checks on the internal payment
# the payment hash is not enough to make sure that this is the same invoice
internal_invoice = await get_standalone_payment(
@ -269,16 +266,23 @@ async def pay_invoice(
logger.debug(f"creating temporary internal payment with id {internal_id}")
# create a new payment from this wallet
fee_reserve_total_msat = fee_reserve_total(
invoice.amount_msat, internal=True
)
create_payment_model.fee = abs(fee_reserve_total_msat)
new_payment = await create_payment(
checking_id=internal_id,
fee=0 + abs(fee_reserve_total_msat),
data=create_payment_model,
status=PaymentState.SUCCESS,
conn=conn,
**payment_kwargs,
)
else:
new_payment = await _create_external_payment(
temp_id, invoice.amount_msat, conn=conn, **payment_kwargs
temp_id=temp_id,
amount_msat=invoice.amount_msat,
data=create_payment_model,
conn=conn,
)
# do the balance check
@ -377,14 +381,16 @@ async def pay_invoice(
# credit service fee wallet
if settings.lnbits_service_fee_wallet and service_fee_msat:
new_payment = await create_payment(
create_payment_model = CreatePayment(
wallet_id=settings.lnbits_service_fee_wallet,
fee=0,
amount=abs(service_fee_msat),
memo="Service fee",
checking_id="service_fee" + temp_id,
payment_request=payment_request,
payment_hash=invoice.payment_hash,
amount=abs(service_fee_msat),
memo="Service fee",
)
new_payment = await create_payment(
checking_id=f"service_fee_{temp_id}",
data=create_payment_model,
status=PaymentState.SUCCESS,
)
return invoice.payment_hash
@ -393,8 +399,8 @@ async def pay_invoice(
async def _create_external_payment(
temp_id: str,
amount_msat: MilliSatoshi,
data: CreatePayment,
conn: Optional[Connection],
**payment_kwargs,
) -> Payment:
fee_reserve_total_msat = fee_reserve_total(amount_msat, internal=False)
@ -428,11 +434,11 @@ async def _create_external_payment(
# create a temporary payment here so we can check if
# the balance is enough in the next step
try:
data.fee = -abs(fee_reserve_total_msat)
new_payment = await create_payment(
checking_id=temp_id,
fee=-abs(fee_reserve_total_msat),
data=data,
conn=conn,
**payment_kwargs,
)
return new_payment
except Exception as exc:
@ -514,7 +520,7 @@ async def redeem_lnurl_withdraw(
wallet_id: str,
lnurl_request: str,
memo: Optional[str] = None,
extra: Optional[Dict] = None,
extra: Optional[dict] = None,
wait_seconds: int = 0,
conn: Optional[Connection] = None,
) -> None:
@ -853,7 +859,7 @@ async def create_user_account(
class WebsocketConnectionManager:
def __init__(self) -> None:
self.active_connections: List[WebSocket] = []
self.active_connections: list[WebSocket] = []
async def connect(self, websocket: WebSocket, item_id: str):
logger.debug(f"Websocket connected to {item_id}")