Refactor: Calculate invoice expiry outside of the crud (#1849)

* expiry before crud
* using Bolt11 as argument instead of payment_request
* fix missing expiry

---------

Co-authored-by: dni  <office@dnilabs.com>
This commit is contained in:
callebtc 2023-10-10 12:03:24 +02:00 committed by GitHub
parent 80529dee4b
commit b2384c10cc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 13 deletions

View file

@ -5,7 +5,6 @@ from urllib.parse import urlparse
from uuid import UUID, uuid4 from uuid import UUID, uuid4
import shortuuid import shortuuid
from bolt11.decode import decode
from lnbits.core.db import db from lnbits.core.db import db
from lnbits.core.models import WalletType from lnbits.core.models import WalletType
@ -535,6 +534,7 @@ async def create_payment(
memo: str, memo: str,
fee: int = 0, fee: int = 0,
preimage: Optional[str] = None, preimage: Optional[str] = None,
expiry: Optional[datetime.datetime] = None,
pending: bool = True, pending: bool = True,
extra: Optional[Dict] = None, extra: Optional[Dict] = None,
webhook: Optional[str] = None, webhook: Optional[str] = None,
@ -545,14 +545,6 @@ async def create_payment(
previous_payment = await get_standalone_payment(checking_id, conn=conn) previous_payment = await get_standalone_payment(checking_id, conn=conn)
assert previous_payment is None, "Payment already exists" assert previous_payment is None, "Payment already exists"
invoice = decode(payment_request)
if invoice.expiry:
expiration_date = datetime.datetime.fromtimestamp(invoice.date + invoice.expiry)
else:
# assume maximum bolt11 expiry of 31 days to be on the safe side
expiration_date = datetime.datetime.now() + datetime.timedelta(days=31)
await (conn or db).execute( await (conn or db).execute(
""" """
INSERT INTO apipayments INSERT INTO apipayments
@ -576,7 +568,7 @@ async def create_payment(
else None else None
), ),
webhook, webhook,
db.datetime_to_timestamp(expiration_date), db.datetime_to_timestamp(expiry) if expiry else None,
), ),
) )

View file

@ -1,4 +1,5 @@
import asyncio import asyncio
import datetime
import json import json
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
@ -6,6 +7,8 @@ from typing import Dict, List, Optional, Tuple, TypedDict
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
import httpx import httpx
from bolt11 import Bolt11
from bolt11 import decode as bolt11_decode
from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives import serialization
from fastapi import Depends, WebSocket from fastapi import Depends, WebSocket
from lnurl import LnurlErrorResponse from lnurl import LnurlErrorResponse
@ -14,7 +17,6 @@ from loguru import logger
from py_vapid import Vapid from py_vapid import Vapid
from py_vapid.utils import b64urlencode from py_vapid.utils import b64urlencode
from lnbits import bolt11
from lnbits.core.db import db from lnbits.core.db import db
from lnbits.db import Connection from lnbits.db import Connection
from lnbits.decorators import WalletTypeInfo, require_admin_key from lnbits.decorators import WalletTypeInfo, require_admin_key
@ -138,7 +140,7 @@ async def create_invoice(
if not ok or not payment_request or not checking_id: if not ok or not payment_request or not checking_id:
raise InvoiceFailure(error_message or "unexpected backend error.") raise InvoiceFailure(error_message or "unexpected backend error.")
invoice = bolt11.decode(payment_request) invoice = bolt11_decode(payment_request)
amount_msat = 1000 * amount_sat amount_msat = 1000 * amount_sat
await create_payment( await create_payment(
@ -147,6 +149,7 @@ async def create_invoice(
payment_request=payment_request, payment_request=payment_request,
payment_hash=invoice.payment_hash, payment_hash=invoice.payment_hash,
amount=amount_msat, amount=amount_msat,
expiry=get_bolt11_expiry(invoice),
memo=memo, memo=memo,
extra=extra, extra=extra,
webhook=webhook, webhook=webhook,
@ -175,7 +178,7 @@ async def pay_invoice(
If the payment is still in flight, we hope that some other process If the payment is still in flight, we hope that some other process
will regularly check for the payment. will regularly check for the payment.
""" """
invoice = bolt11.decode(payment_request) invoice = bolt11_decode(payment_request)
if not invoice.amount_msat or not invoice.amount_msat > 0: if not invoice.amount_msat or not invoice.amount_msat > 0:
raise ValueError("Amountless invoices not supported.") raise ValueError("Amountless invoices not supported.")
@ -203,6 +206,7 @@ async def pay_invoice(
payment_hash: str payment_hash: str
amount: int amount: int
memo: str memo: str
expiry: Optional[datetime.datetime]
extra: Optional[Dict] extra: Optional[Dict]
payment_kwargs: PaymentKwargs = PaymentKwargs( payment_kwargs: PaymentKwargs = PaymentKwargs(
@ -210,6 +214,7 @@ async def pay_invoice(
payment_request=payment_request, payment_request=payment_request,
payment_hash=invoice.payment_hash, payment_hash=invoice.payment_hash,
amount=-invoice.amount_msat, amount=-invoice.amount_msat,
expiry=get_bolt11_expiry(invoice),
memo=description or invoice.description or "", memo=description or invoice.description or "",
extra=extra, extra=extra,
) )
@ -650,3 +655,11 @@ async def get_balance_delta() -> Tuple[int, int, int]:
if error_message: if error_message:
raise Exception(error_message) raise Exception(error_message)
return node_balance - total_balance, node_balance, total_balance return node_balance - total_balance, node_balance, total_balance
def get_bolt11_expiry(invoice: Bolt11) -> datetime.datetime:
if invoice.expiry:
return datetime.datetime.fromtimestamp(invoice.date + invoice.expiry)
else:
# assume maximum bolt11 expiry of 31 days to be on the safe side
return datetime.datetime.now() + datetime.timedelta(days=31)