fix many mypy complaints, specially on bolt11.py
This commit is contained in:
parent
ce28db76c9
commit
dc3d96c6a8
5 changed files with 99 additions and 96 deletions
|
|
@ -85,4 +85,3 @@ def migrate_databases():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app.run()
|
app.run()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,10 @@
|
||||||
# type: ignore
|
import bitstring # type: ignore
|
||||||
|
|
||||||
import bitstring
|
|
||||||
import re
|
import re
|
||||||
import hashlib
|
import hashlib
|
||||||
from typing import List, NamedTuple
|
from typing import List, NamedTuple, Optional
|
||||||
from bech32 import bech32_decode, CHARSET
|
from bech32 import bech32_decode, CHARSET
|
||||||
from ecdsa import SECP256k1, VerifyingKey
|
from ecdsa import SECP256k1, VerifyingKey # type: ignore
|
||||||
from ecdsa.util import sigdecode_string
|
from ecdsa.util import sigdecode_string # type: ignore
|
||||||
from binascii import unhexlify
|
from binascii import unhexlify
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -19,40 +17,40 @@ class Route(NamedTuple):
|
||||||
|
|
||||||
|
|
||||||
class Invoice(object):
|
class Invoice(object):
|
||||||
payment_hash: str = None
|
payment_hash: str
|
||||||
amount_msat: int = 0
|
amount_msat: int = 0
|
||||||
description: str = None
|
description: Optional[str] = None
|
||||||
payee: str = None
|
description_hash: Optional[str] = None
|
||||||
date: int = None
|
payee: str
|
||||||
|
date: int
|
||||||
expiry: int = 3600
|
expiry: int = 3600
|
||||||
secret: str = None
|
secret: Optional[str] = None
|
||||||
route_hints: List[Route] = []
|
route_hints: List[Route] = []
|
||||||
min_final_cltv_expiry: int = 18
|
min_final_cltv_expiry: int = 18
|
||||||
|
|
||||||
|
|
||||||
def decode(pr: str) -> Invoice:
|
def decode(pr: str) -> Invoice:
|
||||||
""" Super naïve bolt11 decoder,
|
"""bolt11 decoder,
|
||||||
only gets payment_hash, description/description_hash and amount in msatoshi.
|
|
||||||
based on https://github.com/rustyrussell/lightning-payencode/blob/master/lnaddr.py
|
based on https://github.com/rustyrussell/lightning-payencode/blob/master/lnaddr.py
|
||||||
"""
|
"""
|
||||||
hrp, data = bech32_decode(pr)
|
|
||||||
if not hrp:
|
|
||||||
raise ValueError("Bad bech32 checksum")
|
|
||||||
|
|
||||||
|
hrp, decoded_data = bech32_decode(pr)
|
||||||
|
if hrp is None or decoded_data is None:
|
||||||
|
raise ValueError("Bad bech32 checksum")
|
||||||
if not hrp.startswith("ln"):
|
if not hrp.startswith("ln"):
|
||||||
raise ValueError("Does not start with ln")
|
raise ValueError("Does not start with ln")
|
||||||
|
|
||||||
data = u5_to_bitarray(data)
|
bitarray = _u5_to_bitarray(decoded_data)
|
||||||
|
|
||||||
# final signature 65 bytes, split it off.
|
# final signature 65 bytes, split it off.
|
||||||
if len(data) < 65 * 8:
|
if len(bitarray) < 65 * 8:
|
||||||
raise ValueError("Too short to contain signature")
|
raise ValueError("Too short to contain signature")
|
||||||
|
|
||||||
# extract the signature
|
# extract the signature
|
||||||
signature = data[-65 * 8 :].tobytes()
|
signature = bitarray[-65 * 8 :].tobytes()
|
||||||
|
|
||||||
# the tagged fields as a bitstream
|
# the tagged fields as a bitstream
|
||||||
data = bitstring.ConstBitStream(data[: -65 * 8])
|
data = bitstring.ConstBitStream(bitarray[: -65 * 8])
|
||||||
|
|
||||||
# build the invoice object
|
# build the invoice object
|
||||||
invoice = Invoice()
|
invoice = Invoice()
|
||||||
|
|
@ -62,35 +60,35 @@ def decode(pr: str) -> Invoice:
|
||||||
if m:
|
if m:
|
||||||
amountstr = hrp[2 + m.end() :]
|
amountstr = hrp[2 + m.end() :]
|
||||||
if amountstr != "":
|
if amountstr != "":
|
||||||
invoice.amount_msat = unshorten_amount(amountstr)
|
invoice.amount_msat = _unshorten_amount(amountstr)
|
||||||
|
|
||||||
# pull out date
|
# pull out date
|
||||||
invoice.date = data.read(35).uint
|
invoice.date = data.read(35).uint
|
||||||
|
|
||||||
while data.pos != data.len:
|
while data.pos != data.len:
|
||||||
tag, tagdata, data = pull_tagged(data)
|
tag, tagdata, data = _pull_tagged(data)
|
||||||
data_length = len(tagdata) / 5
|
data_length = len(tagdata) / 5
|
||||||
|
|
||||||
if tag == "d":
|
if tag == "d":
|
||||||
invoice.description = trim_to_bytes(tagdata).decode("utf-8")
|
invoice.description = _trim_to_bytes(tagdata).decode("utf-8")
|
||||||
elif tag == "h" and data_length == 52:
|
elif tag == "h" and data_length == 52:
|
||||||
invoice.description = trim_to_bytes(tagdata).hex()
|
invoice.description_hash = _trim_to_bytes(tagdata).hex()
|
||||||
elif tag == "p" and data_length == 52:
|
elif tag == "p" and data_length == 52:
|
||||||
invoice.payment_hash = trim_to_bytes(tagdata).hex()
|
invoice.payment_hash = _trim_to_bytes(tagdata).hex()
|
||||||
elif tag == "x":
|
elif tag == "x":
|
||||||
invoice.expiry = tagdata.uint
|
invoice.expiry = tagdata.uint
|
||||||
elif tag == "n":
|
elif tag == "n":
|
||||||
invoice.payee = trim_to_bytes(tagdata).hex()
|
invoice.payee = _trim_to_bytes(tagdata).hex()
|
||||||
# this won't work in most cases, we must extract the payee
|
# this won't work in most cases, we must extract the payee
|
||||||
# from the signature
|
# from the signature
|
||||||
elif tag == "s":
|
elif tag == "s":
|
||||||
invoice.secret = trim_to_bytes(tagdata).hex()
|
invoice.secret = _trim_to_bytes(tagdata).hex()
|
||||||
elif tag == "r":
|
elif tag == "r":
|
||||||
s = bitstring.ConstBitStream(tagdata)
|
s = bitstring.ConstBitStream(tagdata)
|
||||||
while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
|
while s.pos + 264 + 64 + 32 + 32 + 16 < s.len:
|
||||||
route = Route(
|
route = Route(
|
||||||
pubkey=s.read(264).tobytes().hex(),
|
pubkey=s.read(264).tobytes().hex(),
|
||||||
short_channel_id=readable_scid(s.read(64).intbe),
|
short_channel_id=_readable_scid(s.read(64).intbe),
|
||||||
base_fee_msat=s.read(32).intbe,
|
base_fee_msat=s.read(32).intbe,
|
||||||
ppm_fee=s.read(32).intbe,
|
ppm_fee=s.read(32).intbe,
|
||||||
cltv=s.read(16).intbe,
|
cltv=s.read(16).intbe,
|
||||||
|
|
@ -116,7 +114,7 @@ def decode(pr: str) -> Invoice:
|
||||||
return invoice
|
return invoice
|
||||||
|
|
||||||
|
|
||||||
def unshorten_amount(amount: str) -> int:
|
def _unshorten_amount(amount: str) -> int:
|
||||||
""" Given a shortened amount, return millisatoshis
|
""" Given a shortened amount, return millisatoshis
|
||||||
"""
|
"""
|
||||||
# BOLT #11:
|
# BOLT #11:
|
||||||
|
|
@ -141,18 +139,18 @@ def unshorten_amount(amount: str) -> int:
|
||||||
raise ValueError("Invalid amount '{}'".format(amount))
|
raise ValueError("Invalid amount '{}'".format(amount))
|
||||||
|
|
||||||
if unit in units:
|
if unit in units:
|
||||||
return int(amount[:-1]) * 100_000_000_000 / units[unit]
|
return int(int(amount[:-1]) * 100_000_000_000 / units[unit])
|
||||||
else:
|
else:
|
||||||
return int(amount) * 100_000_000_000
|
return int(amount) * 100_000_000_000
|
||||||
|
|
||||||
|
|
||||||
def pull_tagged(stream):
|
def _pull_tagged(stream):
|
||||||
tag = stream.read(5).uint
|
tag = stream.read(5).uint
|
||||||
length = stream.read(5).uint * 32 + stream.read(5).uint
|
length = stream.read(5).uint * 32 + stream.read(5).uint
|
||||||
return (CHARSET[tag], stream.read(length * 5), stream)
|
return (CHARSET[tag], stream.read(length * 5), stream)
|
||||||
|
|
||||||
|
|
||||||
def trim_to_bytes(barr):
|
def _trim_to_bytes(barr):
|
||||||
# Adds a byte if necessary.
|
# Adds a byte if necessary.
|
||||||
b = barr.tobytes()
|
b = barr.tobytes()
|
||||||
if barr.len % 8 != 0:
|
if barr.len % 8 != 0:
|
||||||
|
|
@ -160,7 +158,7 @@ def trim_to_bytes(barr):
|
||||||
return b
|
return b
|
||||||
|
|
||||||
|
|
||||||
def readable_scid(short_channel_id: int) -> str:
|
def _readable_scid(short_channel_id: int) -> str:
|
||||||
return "{blockheight}x{transactionindex}x{outputindex}".format(
|
return "{blockheight}x{transactionindex}x{outputindex}".format(
|
||||||
blockheight=((short_channel_id >> 40) & 0xFFFFFF),
|
blockheight=((short_channel_id >> 40) & 0xFFFFFF),
|
||||||
transactionindex=((short_channel_id >> 16) & 0xFFFFFF),
|
transactionindex=((short_channel_id >> 16) & 0xFFFFFF),
|
||||||
|
|
@ -168,7 +166,7 @@ def readable_scid(short_channel_id: int) -> str:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def u5_to_bitarray(arr):
|
def _u5_to_bitarray(arr: List[int]) -> bitstring.BitArray:
|
||||||
ret = bitstring.BitArray()
|
ret = bitstring.BitArray()
|
||||||
for a in arr:
|
for a in arr:
|
||||||
ret += bitstring.pack("uint:5", a)
|
ret += bitstring.pack("uint:5", a)
|
||||||
|
|
|
||||||
|
|
@ -276,10 +276,10 @@ def delete_payment(checking_id: str) -> None:
|
||||||
db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,))
|
db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,))
|
||||||
|
|
||||||
|
|
||||||
def check_internal(payment_hash: str) -> None:
|
def check_internal(payment_hash: str) -> Optional[str]:
|
||||||
with open_db() as db:
|
with open_db() as db:
|
||||||
row = db.fetchone("SELECT checking_id FROM apipayments WHERE hash = ?", (payment_hash,))
|
row = db.fetchone("SELECT checking_id FROM apipayments WHERE hash = ?", (payment_hash,))
|
||||||
if not row:
|
if not row:
|
||||||
return False
|
return None
|
||||||
else:
|
else:
|
||||||
return row["checking_id"]
|
return row["checking_id"]
|
||||||
|
|
|
||||||
|
|
@ -1,27 +1,22 @@
|
||||||
from typing import Optional, Tuple, Dict
|
from typing import Optional, Tuple, Dict, TypedDict
|
||||||
|
|
||||||
from lnbits import bolt11
|
from lnbits import bolt11
|
||||||
from lnbits.helpers import urlsafe_short_hash
|
from lnbits.helpers import urlsafe_short_hash
|
||||||
from lnbits.settings import WALLET
|
from lnbits.settings import WALLET
|
||||||
|
from lnbits.wallets.base import PaymentStatus
|
||||||
|
|
||||||
from .crud import get_wallet, create_payment, delete_payment, check_internal, update_payment_status, get_wallet_payment
|
from .crud import get_wallet, create_payment, delete_payment, check_internal, update_payment_status, get_wallet_payment
|
||||||
|
|
||||||
|
|
||||||
def create_invoice(
|
def create_invoice(
|
||||||
*,
|
*, wallet_id: str, amount: int, memo: str, description_hash: Optional[bytes] = None, extra: Optional[Dict] = None,
|
||||||
wallet_id: str,
|
|
||||||
amount: int,
|
|
||||||
memo: Optional[str] = None,
|
|
||||||
description_hash: Optional[bytes] = None,
|
|
||||||
extra: Optional[Dict] = None,
|
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
try:
|
invoice_memo = None if description_hash else memo
|
||||||
ok, checking_id, payment_request, error_message = WALLET.create_invoice(
|
storeable_memo = memo
|
||||||
amount=amount, memo=memo, description_hash=description_hash
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
ok, error_message = False, str(e)
|
|
||||||
|
|
||||||
|
ok, checking_id, payment_request, error_message = WALLET.create_invoice(
|
||||||
|
amount=amount, memo=invoice_memo, description_hash=description_hash
|
||||||
|
)
|
||||||
if not ok:
|
if not ok:
|
||||||
raise Exception(error_message or "Unexpected backend error.")
|
raise Exception(error_message or "Unexpected backend error.")
|
||||||
|
|
||||||
|
|
@ -34,7 +29,7 @@ def create_invoice(
|
||||||
payment_request=payment_request,
|
payment_request=payment_request,
|
||||||
payment_hash=invoice.payment_hash,
|
payment_hash=invoice.payment_hash,
|
||||||
amount=amount_msat,
|
amount=amount_msat,
|
||||||
memo=memo,
|
memo=storeable_memo,
|
||||||
extra=extra,
|
extra=extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -47,60 +42,71 @@ def pay_invoice(
|
||||||
temp_id = f"temp_{urlsafe_short_hash()}"
|
temp_id = f"temp_{urlsafe_short_hash()}"
|
||||||
internal_id = f"internal_{urlsafe_short_hash()}"
|
internal_id = f"internal_{urlsafe_short_hash()}"
|
||||||
|
|
||||||
try:
|
invoice = bolt11.decode(payment_request)
|
||||||
invoice = bolt11.decode(payment_request)
|
if invoice.amount_msat == 0:
|
||||||
if invoice.amount_msat == 0:
|
raise ValueError("Amountless invoices not supported.")
|
||||||
raise ValueError("Amountless invoices not supported.")
|
if max_sat and invoice.amount_msat > max_sat * 1000:
|
||||||
if max_sat and invoice.amount_msat > max_sat * 1000:
|
raise ValueError("Amount in invoice is too high.")
|
||||||
raise ValueError("Amount in invoice is too high.")
|
|
||||||
|
|
||||||
# put all parameters that don't change here
|
# put all parameters that don't change here
|
||||||
payment_kwargs = dict(
|
PaymentKwargs = TypedDict(
|
||||||
wallet_id=wallet_id,
|
"PaymentKwargs",
|
||||||
payment_request=payment_request,
|
{
|
||||||
payment_hash=invoice.payment_hash,
|
"wallet_id": str,
|
||||||
amount=-invoice.amount_msat,
|
"payment_request": str,
|
||||||
memo=invoice.description,
|
"payment_hash": str,
|
||||||
extra=extra,
|
"amount": int,
|
||||||
)
|
"memo": str,
|
||||||
|
"extra": Optional[Dict],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
payment_kwargs: PaymentKwargs = dict(
|
||||||
|
wallet_id=wallet_id,
|
||||||
|
payment_request=payment_request,
|
||||||
|
payment_hash=invoice.payment_hash,
|
||||||
|
amount=-invoice.amount_msat,
|
||||||
|
memo=invoice.description or "",
|
||||||
|
extra=extra,
|
||||||
|
)
|
||||||
|
|
||||||
# check_internal() returns the checking_id of the invoice we're waiting for
|
# check_internal() returns the checking_id of the invoice we're waiting for
|
||||||
internal = check_internal(invoice.payment_hash)
|
internal = check_internal(invoice.payment_hash)
|
||||||
if internal:
|
if internal:
|
||||||
# create a new payment from this wallet
|
# create a new payment from this wallet
|
||||||
create_payment(checking_id=internal_id, fee=0, pending=False, **payment_kwargs)
|
create_payment(checking_id=internal_id, fee=0, pending=False, **payment_kwargs)
|
||||||
else:
|
else:
|
||||||
# create a temporary payment here so we can check if
|
# create a temporary payment here so we can check if
|
||||||
# the balance is enough in the next step
|
# the balance is enough in the next step
|
||||||
fee_reserve = max(1000, int(invoice.amount_msat * 0.01))
|
fee_reserve = max(1000, int(invoice.amount_msat * 0.01))
|
||||||
create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs)
|
create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs)
|
||||||
|
|
||||||
# do the balance check
|
# do the balance check
|
||||||
wallet = get_wallet(wallet_id)
|
wallet = get_wallet(wallet_id)
|
||||||
assert wallet, "invalid wallet id"
|
assert wallet, "invalid wallet id"
|
||||||
if wallet.balance_msat < 0:
|
if wallet.balance_msat < 0:
|
||||||
raise PermissionError("Insufficient balance.")
|
raise PermissionError("Insufficient balance.")
|
||||||
|
|
||||||
if internal:
|
if internal:
|
||||||
# mark the invoice from the other side as not pending anymore
|
# mark the invoice from the other side as not pending anymore
|
||||||
# so the other side only has access to his new money when we are sure
|
# so the other side only has access to his new money when we are sure
|
||||||
# the payer has enough to deduct from
|
# the payer has enough to deduct from
|
||||||
update_payment_status(checking_id=internal, pending=False)
|
update_payment_status(checking_id=internal, pending=False)
|
||||||
else:
|
else:
|
||||||
# actually pay the external invoice
|
# actually pay the external invoice
|
||||||
ok, checking_id, fee_msat, error_message = WALLET.pay_invoice(payment_request)
|
ok, checking_id, fee_msat, error_message = WALLET.pay_invoice(payment_request)
|
||||||
if ok:
|
if ok:
|
||||||
create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs)
|
create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs)
|
||||||
delete_payment(temp_id)
|
delete_payment(temp_id)
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
ok, error_message = False, str(e)
|
|
||||||
if not ok:
|
if not ok:
|
||||||
raise Exception(error_message or "Unexpected backend error.")
|
raise Exception(error_message or "Unexpected backend error.")
|
||||||
|
|
||||||
return invoice.payment_hash
|
return invoice.payment_hash
|
||||||
|
|
||||||
|
|
||||||
def check_invoice_status(wallet_id: str, payment_hash: str) -> str:
|
def check_invoice_status(wallet_id: str, payment_hash: str) -> PaymentStatus:
|
||||||
payment = get_wallet_payment(wallet_id, payment_hash)
|
payment = get_wallet_payment(wallet_id, payment_hash)
|
||||||
|
if not payment:
|
||||||
|
return PaymentStatus(None)
|
||||||
|
|
||||||
return WALLET.get_invoice_status(payment.checking_id)
|
return WALLET.get_invoice_status(payment.checking_id)
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ def api_paywall_create_invoice(paywall_id):
|
||||||
try:
|
try:
|
||||||
amount = g.data["amount"] if g.data["amount"] > paywall.amount else paywall.amount
|
amount = g.data["amount"] if g.data["amount"] > paywall.amount else paywall.amount
|
||||||
payment_hash, payment_request = create_invoice(
|
payment_hash, payment_request = create_invoice(
|
||||||
wallet_id=paywall.wallet, amount=amount, memo=f"{paywall.memo}", extra={'tag': 'paywall'}
|
wallet_id=paywall.wallet, amount=amount, memo=f"{paywall.memo}", extra={"tag": "paywall"}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR
|
return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue