refactor: use CreatePayment model instead of a lot of kwargs (#2667)
- refactoring create_payment a bit to use a model instead of 10 kwargs
This commit is contained in:
parent
053ea20508
commit
9d7e54f6b2
3 changed files with 84 additions and 78 deletions
|
|
@ -1,7 +1,6 @@
|
|||
import datetime
|
||||
import json
|
||||
from time import time
|
||||
from typing import Dict, List, Literal, Optional
|
||||
from typing import Literal, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import shortuuid
|
||||
|
|
@ -27,6 +26,7 @@ from lnbits.settings import (
|
|||
from .models import (
|
||||
Account,
|
||||
AccountFilters,
|
||||
CreatePayment,
|
||||
Payment,
|
||||
PaymentFilters,
|
||||
PaymentHistoryPoint,
|
||||
|
|
@ -38,9 +38,6 @@ from .models import (
|
|||
WebPushSubscription,
|
||||
)
|
||||
|
||||
# accounts
|
||||
# --------
|
||||
|
||||
|
||||
async def create_account(
|
||||
user_id: Optional[str] = None,
|
||||
|
|
@ -430,7 +427,7 @@ async def get_installed_extension(
|
|||
async def get_installed_extensions(
|
||||
active: Optional[bool] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> List[InstallableExtension]:
|
||||
) -> list[InstallableExtension]:
|
||||
rows = await (conn or db).fetchall(
|
||||
"SELECT * FROM installed_extensions",
|
||||
)
|
||||
|
|
@ -456,7 +453,7 @@ async def get_user_extension(
|
|||
|
||||
async def get_user_extensions(
|
||||
user_id: str, conn: Optional[Connection] = None
|
||||
) -> List[UserExtension]:
|
||||
) -> list[UserExtension]:
|
||||
rows = await (conn or db).fetchall(
|
||||
"""
|
||||
SELECT extension, active, extra as _extra FROM extensions
|
||||
|
|
@ -481,7 +478,7 @@ async def update_user_extension(
|
|||
|
||||
async def get_user_active_extensions_ids(
|
||||
user_id: str, conn: Optional[Connection] = None
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
rows = await (conn or db).fetchall(
|
||||
"""
|
||||
SELECT extension FROM extensions WHERE "user" = :user AND active
|
||||
|
|
@ -650,7 +647,7 @@ async def get_wallet(
|
|||
return Wallet(**row) if row else None
|
||||
|
||||
|
||||
async def get_wallets(user_id: str, conn: Optional[Connection] = None) -> List[Wallet]:
|
||||
async def get_wallets(user_id: str, conn: Optional[Connection] = None) -> list[Wallet]:
|
||||
rows = await (conn or db).fetchall(
|
||||
"""
|
||||
SELECT *, COALESCE((SELECT balance FROM balances WHERE wallet = wallets.id), 0)
|
||||
|
|
@ -772,7 +769,7 @@ async def get_payments_paginated(
|
|||
"wallet": wallet_id,
|
||||
"time": since,
|
||||
}
|
||||
clause: List[str] = []
|
||||
clause: list[str] = []
|
||||
|
||||
if since is not None:
|
||||
clause.append(f"time > {db.timestamp_placeholder('time')}")
|
||||
|
|
@ -882,23 +879,13 @@ async def delete_expired_invoices(
|
|||
|
||||
|
||||
async def create_payment(
|
||||
*,
|
||||
wallet_id: str,
|
||||
checking_id: str,
|
||||
payment_request: str,
|
||||
payment_hash: str,
|
||||
amount: int,
|
||||
memo: str,
|
||||
fee: int = 0,
|
||||
data: CreatePayment,
|
||||
status: PaymentState = PaymentState.PENDING,
|
||||
preimage: Optional[str] = None,
|
||||
expiry: Optional[datetime.datetime] = None,
|
||||
extra: Optional[Dict] = None,
|
||||
webhook: Optional[str] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Payment:
|
||||
# we don't allow the creation of the same invoice twice
|
||||
# note: this can be removed if the db uniquess constarints are set appropriately
|
||||
# note: this can be removed if the db uniqueness constraints are set appropriately
|
||||
previous_payment = await get_standalone_payment(checking_id, conn=conn)
|
||||
assert previous_payment is None, "Payment already exists"
|
||||
|
||||
|
|
@ -912,27 +899,27 @@ async def create_payment(
|
|||
:amount, :status, :memo, :fee, :extra, :webhook, {expiry_ph}, :pending)
|
||||
""",
|
||||
{
|
||||
"wallet": wallet_id,
|
||||
"wallet": data.wallet_id,
|
||||
"checking_id": checking_id,
|
||||
"bolt11": payment_request,
|
||||
"hash": payment_hash,
|
||||
"preimage": preimage,
|
||||
"amount": amount,
|
||||
"bolt11": data.payment_request,
|
||||
"hash": data.payment_hash,
|
||||
"preimage": data.preimage,
|
||||
"amount": data.amount,
|
||||
"status": status.value,
|
||||
"memo": memo,
|
||||
"fee": fee,
|
||||
"memo": data.memo,
|
||||
"fee": data.fee,
|
||||
"extra": (
|
||||
json.dumps(extra)
|
||||
if extra and extra != {} and isinstance(extra, dict)
|
||||
json.dumps(data.extra)
|
||||
if data.extra and data.extra != {} and isinstance(data.extra, dict)
|
||||
else None
|
||||
),
|
||||
"webhook": webhook,
|
||||
"expiry": expiry if expiry else None,
|
||||
"webhook": data.webhook,
|
||||
"expiry": data.expiry if data.expiry else None,
|
||||
"pending": False, # TODO: remove this in next release
|
||||
},
|
||||
)
|
||||
|
||||
new_payment = await get_wallet_payment(wallet_id, payment_hash, conn=conn)
|
||||
new_payment = await get_wallet_payment(data.wallet_id, data.payment_hash, conn=conn)
|
||||
assert new_payment, "Newly created payment couldn't be retrieved"
|
||||
|
||||
return new_payment
|
||||
|
|
@ -963,7 +950,7 @@ async def update_payment_details(
|
|||
"preimage": preimage,
|
||||
}
|
||||
|
||||
set_clause: List[str] = []
|
||||
set_clause: list[str] = []
|
||||
if new_checking_id is not None:
|
||||
set_clause.append("checking_id = :checking_id")
|
||||
if status is not None:
|
||||
|
|
@ -1023,7 +1010,7 @@ async def get_payments_history(
|
|||
wallet_id: Optional[str] = None,
|
||||
group: DateTrunc = "day",
|
||||
filters: Optional[Filters] = None,
|
||||
) -> List[PaymentHistoryPoint]:
|
||||
) -> list[PaymentHistoryPoint]:
|
||||
if not filters:
|
||||
filters = Filters()
|
||||
|
||||
|
|
@ -1249,7 +1236,7 @@ async def get_tinyurl(tinyurl_id: str) -> Optional[TinyURL]:
|
|||
return TinyURL.from_row(row) if row else None
|
||||
|
||||
|
||||
async def get_tinyurl_by_url(url: str) -> List[TinyURL]:
|
||||
async def get_tinyurl_by_url(url: str) -> list[TinyURL]:
|
||||
rows = await db.fetchall(
|
||||
"SELECT * FROM tiny_url WHERE url = :url",
|
||||
{"url": url},
|
||||
|
|
@ -1301,7 +1288,7 @@ async def get_webpush_subscription(
|
|||
|
||||
async def get_webpush_subscriptions_for_user(
|
||||
user: str,
|
||||
) -> List[WebPushSubscription]:
|
||||
) -> list[WebPushSubscription]:
|
||||
rows = await db.fetchall(
|
||||
"""SELECT * FROM webpush_subscriptions WHERE "user" = :user""",
|
||||
{"user": user},
|
||||
|
|
|
|||
|
|
@ -212,6 +212,19 @@ class PaymentState(str, Enum):
|
|||
return self.value
|
||||
|
||||
|
||||
class CreatePayment(BaseModel):
|
||||
wallet_id: str
|
||||
payment_request: str
|
||||
payment_hash: str
|
||||
amount: int
|
||||
memo: str
|
||||
preimage: Optional[str] = None
|
||||
expiry: Optional[datetime.datetime] = None
|
||||
extra: Optional[dict] = None
|
||||
webhook: Optional[str] = None
|
||||
fee: int = 0
|
||||
|
||||
|
||||
class Payment(FromRowModel):
|
||||
status: str
|
||||
# TODO should be removed in the future, backward compatibility
|
||||
|
|
@ -225,7 +238,7 @@ class Payment(FromRowModel):
|
|||
preimage: str
|
||||
payment_hash: str
|
||||
expiry: Optional[float]
|
||||
extra: dict = {}
|
||||
extra: Optional[dict]
|
||||
wallet_id: str
|
||||
webhook: Optional[str]
|
||||
webhook_status: Optional[int]
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, TypedDict
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
|
|
@ -68,16 +67,24 @@ from .crud import (
|
|||
update_user_extension,
|
||||
)
|
||||
from .helpers import to_valid_user_id
|
||||
from .models import BalanceDelta, Payment, PaymentState, User, UserConfig, Wallet
|
||||
from .models import (
|
||||
BalanceDelta,
|
||||
CreatePayment,
|
||||
Payment,
|
||||
PaymentState,
|
||||
User,
|
||||
UserConfig,
|
||||
Wallet,
|
||||
)
|
||||
|
||||
|
||||
async def calculate_fiat_amounts(
|
||||
amount: float,
|
||||
wallet_id: str,
|
||||
currency: Optional[str] = None,
|
||||
extra: Optional[Dict] = None,
|
||||
extra: Optional[dict] = None,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Tuple[int, Optional[Dict]]:
|
||||
) -> tuple[int, Optional[dict]]:
|
||||
wallet = await get_wallet(wallet_id, conn=conn)
|
||||
assert wallet, "invalid wallet_id"
|
||||
wallet_currency = wallet.currency or settings.lnbits_default_accounting_currency
|
||||
|
|
@ -118,11 +125,11 @@ async def create_invoice(
|
|||
description_hash: Optional[bytes] = None,
|
||||
unhashed_description: Optional[bytes] = None,
|
||||
expiry: Optional[int] = None,
|
||||
extra: Optional[Dict] = None,
|
||||
extra: Optional[dict] = None,
|
||||
webhook: Optional[str] = None,
|
||||
internal: Optional[bool] = False,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> Tuple[str, str]:
|
||||
) -> tuple[str, str]:
|
||||
if not amount > 0:
|
||||
raise InvoiceError("Amountless invoices not supported.", status="failed")
|
||||
|
||||
|
|
@ -167,17 +174,20 @@ async def create_invoice(
|
|||
|
||||
invoice = bolt11_decode(payment_request)
|
||||
|
||||
amount_msat = 1000 * amount_sat
|
||||
await create_payment(
|
||||
create_payment_model = CreatePayment(
|
||||
wallet_id=wallet_id,
|
||||
checking_id=checking_id,
|
||||
payment_request=payment_request,
|
||||
payment_hash=invoice.payment_hash,
|
||||
amount=amount_msat,
|
||||
amount=amount_sat * 1000,
|
||||
expiry=invoice.expiry_date,
|
||||
memo=memo,
|
||||
extra=extra,
|
||||
webhook=webhook,
|
||||
)
|
||||
|
||||
await create_payment(
|
||||
checking_id=checking_id,
|
||||
data=create_payment_model,
|
||||
conn=conn,
|
||||
)
|
||||
|
||||
|
|
@ -189,7 +199,7 @@ async def pay_invoice(
|
|||
wallet_id: str,
|
||||
payment_request: str,
|
||||
max_sat: Optional[int] = None,
|
||||
extra: Optional[Dict] = None,
|
||||
extra: Optional[dict] = None,
|
||||
description: str = "",
|
||||
conn: Optional[Connection] = None,
|
||||
) -> str:
|
||||
|
|
@ -223,17 +233,7 @@ async def pay_invoice(
|
|||
invoice.amount_msat / 1000, wallet_id, extra=extra, conn=conn
|
||||
)
|
||||
|
||||
# put all parameters that don't change here
|
||||
class PaymentKwargs(TypedDict):
|
||||
wallet_id: str
|
||||
payment_request: str
|
||||
payment_hash: str
|
||||
amount: int
|
||||
memo: str
|
||||
expiry: Optional[datetime.datetime]
|
||||
extra: Optional[Dict]
|
||||
|
||||
payment_kwargs: PaymentKwargs = PaymentKwargs(
|
||||
create_payment_model = CreatePayment(
|
||||
wallet_id=wallet_id,
|
||||
payment_request=payment_request,
|
||||
payment_hash=invoice.payment_hash,
|
||||
|
|
@ -252,9 +252,6 @@ async def pay_invoice(
|
|||
# (pending only)
|
||||
internal_checking_id = await check_internal(invoice.payment_hash, conn=conn)
|
||||
if internal_checking_id:
|
||||
fee_reserve_total_msat = fee_reserve_total(
|
||||
invoice.amount_msat, internal=True
|
||||
)
|
||||
# perform additional checks on the internal payment
|
||||
# the payment hash is not enough to make sure that this is the same invoice
|
||||
internal_invoice = await get_standalone_payment(
|
||||
|
|
@ -269,16 +266,23 @@ async def pay_invoice(
|
|||
|
||||
logger.debug(f"creating temporary internal payment with id {internal_id}")
|
||||
# create a new payment from this wallet
|
||||
|
||||
fee_reserve_total_msat = fee_reserve_total(
|
||||
invoice.amount_msat, internal=True
|
||||
)
|
||||
create_payment_model.fee = abs(fee_reserve_total_msat)
|
||||
new_payment = await create_payment(
|
||||
checking_id=internal_id,
|
||||
fee=0 + abs(fee_reserve_total_msat),
|
||||
data=create_payment_model,
|
||||
status=PaymentState.SUCCESS,
|
||||
conn=conn,
|
||||
**payment_kwargs,
|
||||
)
|
||||
else:
|
||||
new_payment = await _create_external_payment(
|
||||
temp_id, invoice.amount_msat, conn=conn, **payment_kwargs
|
||||
temp_id=temp_id,
|
||||
amount_msat=invoice.amount_msat,
|
||||
data=create_payment_model,
|
||||
conn=conn,
|
||||
)
|
||||
|
||||
# do the balance check
|
||||
|
|
@ -377,14 +381,16 @@ async def pay_invoice(
|
|||
|
||||
# credit service fee wallet
|
||||
if settings.lnbits_service_fee_wallet and service_fee_msat:
|
||||
new_payment = await create_payment(
|
||||
create_payment_model = CreatePayment(
|
||||
wallet_id=settings.lnbits_service_fee_wallet,
|
||||
fee=0,
|
||||
amount=abs(service_fee_msat),
|
||||
memo="Service fee",
|
||||
checking_id="service_fee" + temp_id,
|
||||
payment_request=payment_request,
|
||||
payment_hash=invoice.payment_hash,
|
||||
amount=abs(service_fee_msat),
|
||||
memo="Service fee",
|
||||
)
|
||||
new_payment = await create_payment(
|
||||
checking_id=f"service_fee_{temp_id}",
|
||||
data=create_payment_model,
|
||||
status=PaymentState.SUCCESS,
|
||||
)
|
||||
return invoice.payment_hash
|
||||
|
|
@ -393,8 +399,8 @@ async def pay_invoice(
|
|||
async def _create_external_payment(
|
||||
temp_id: str,
|
||||
amount_msat: MilliSatoshi,
|
||||
data: CreatePayment,
|
||||
conn: Optional[Connection],
|
||||
**payment_kwargs,
|
||||
) -> Payment:
|
||||
fee_reserve_total_msat = fee_reserve_total(amount_msat, internal=False)
|
||||
|
||||
|
|
@ -428,11 +434,11 @@ async def _create_external_payment(
|
|||
# create a temporary payment here so we can check if
|
||||
# the balance is enough in the next step
|
||||
try:
|
||||
data.fee = -abs(fee_reserve_total_msat)
|
||||
new_payment = await create_payment(
|
||||
checking_id=temp_id,
|
||||
fee=-abs(fee_reserve_total_msat),
|
||||
data=data,
|
||||
conn=conn,
|
||||
**payment_kwargs,
|
||||
)
|
||||
return new_payment
|
||||
except Exception as exc:
|
||||
|
|
@ -514,7 +520,7 @@ async def redeem_lnurl_withdraw(
|
|||
wallet_id: str,
|
||||
lnurl_request: str,
|
||||
memo: Optional[str] = None,
|
||||
extra: Optional[Dict] = None,
|
||||
extra: Optional[dict] = None,
|
||||
wait_seconds: int = 0,
|
||||
conn: Optional[Connection] = None,
|
||||
) -> None:
|
||||
|
|
@ -853,7 +859,7 @@ async def create_user_account(
|
|||
|
||||
class WebsocketConnectionManager:
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: List[WebSocket] = []
|
||||
self.active_connections: list[WebSocket] = []
|
||||
|
||||
async def connect(self, websocket: WebSocket, item_id: str):
|
||||
logger.debug(f"Websocket connected to {item_id}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue