From dc3d96c6a8fcb3925a5851ee66de4f6d92e6d0c2 Mon Sep 17 00:00:00 2001 From: fiatjaf Date: Wed, 2 Sep 2020 21:11:08 -0300 Subject: [PATCH] fix many mypy complaints, specially on bolt11.py --- lnbits/__init__.py | 1 - lnbits/bolt11.py | 66 +++++++------ lnbits/core/crud.py | 4 +- lnbits/core/services.py | 122 +++++++++++++------------ lnbits/extensions/paywall/views_api.py | 2 +- 5 files changed, 99 insertions(+), 96 deletions(-) diff --git a/lnbits/__init__.py b/lnbits/__init__.py index 479338ed..b2433760 100644 --- a/lnbits/__init__.py +++ b/lnbits/__init__.py @@ -85,4 +85,3 @@ def migrate_databases(): if __name__ == "__main__": app.run() - diff --git a/lnbits/bolt11.py b/lnbits/bolt11.py index d9344617..184844f7 100644 --- a/lnbits/bolt11.py +++ b/lnbits/bolt11.py @@ -1,12 +1,10 @@ -# type: ignore - -import bitstring +import bitstring # type: ignore import re import hashlib -from typing import List, NamedTuple +from typing import List, NamedTuple, Optional from bech32 import bech32_decode, CHARSET -from ecdsa import SECP256k1, VerifyingKey -from ecdsa.util import sigdecode_string +from ecdsa import SECP256k1, VerifyingKey # type: ignore +from ecdsa.util import sigdecode_string # type: ignore from binascii import unhexlify @@ -19,40 +17,40 @@ class Route(NamedTuple): class Invoice(object): - payment_hash: str = None + payment_hash: str amount_msat: int = 0 - description: str = None - payee: str = None - date: int = None + description: Optional[str] = None + description_hash: Optional[str] = None + payee: str + date: int expiry: int = 3600 - secret: str = None + secret: Optional[str] = None route_hints: List[Route] = [] min_final_cltv_expiry: int = 18 def decode(pr: str) -> Invoice: - """ Super naïve bolt11 decoder, - only gets payment_hash, description/description_hash and amount in msatoshi. + """bolt11 decoder, based on https://github.com/rustyrussell/lightning-payencode/blob/master/lnaddr.py """ - hrp, data = bech32_decode(pr) - if not hrp: - raise ValueError("Bad bech32 checksum") + hrp, decoded_data = bech32_decode(pr) + if hrp is None or decoded_data is None: + raise ValueError("Bad bech32 checksum") if not hrp.startswith("ln"): raise ValueError("Does not start with ln") - data = u5_to_bitarray(data) + bitarray = _u5_to_bitarray(decoded_data) # final signature 65 bytes, split it off. - if len(data) < 65 * 8: + if len(bitarray) < 65 * 8: raise ValueError("Too short to contain signature") # extract the signature - signature = data[-65 * 8 :].tobytes() + signature = bitarray[-65 * 8 :].tobytes() # the tagged fields as a bitstream - data = bitstring.ConstBitStream(data[: -65 * 8]) + data = bitstring.ConstBitStream(bitarray[: -65 * 8]) # build the invoice object invoice = Invoice() @@ -62,35 +60,35 @@ def decode(pr: str) -> Invoice: if m: amountstr = hrp[2 + m.end() :] if amountstr != "": - invoice.amount_msat = unshorten_amount(amountstr) + invoice.amount_msat = _unshorten_amount(amountstr) # pull out date invoice.date = data.read(35).uint while data.pos != data.len: - tag, tagdata, data = pull_tagged(data) + tag, tagdata, data = _pull_tagged(data) data_length = len(tagdata) / 5 if tag == "d": - invoice.description = trim_to_bytes(tagdata).decode("utf-8") + invoice.description = _trim_to_bytes(tagdata).decode("utf-8") elif tag == "h" and data_length == 52: - invoice.description = trim_to_bytes(tagdata).hex() + invoice.description_hash = _trim_to_bytes(tagdata).hex() elif tag == "p" and data_length == 52: - invoice.payment_hash = trim_to_bytes(tagdata).hex() + invoice.payment_hash = _trim_to_bytes(tagdata).hex() elif tag == "x": invoice.expiry = tagdata.uint elif tag == "n": - invoice.payee = trim_to_bytes(tagdata).hex() + invoice.payee = _trim_to_bytes(tagdata).hex() # this won't work in most cases, we must extract the payee # from the signature elif tag == "s": - invoice.secret = trim_to_bytes(tagdata).hex() + invoice.secret = _trim_to_bytes(tagdata).hex() elif tag == "r": s = bitstring.ConstBitStream(tagdata) while s.pos + 264 + 64 + 32 + 32 + 16 < s.len: route = Route( pubkey=s.read(264).tobytes().hex(), - short_channel_id=readable_scid(s.read(64).intbe), + short_channel_id=_readable_scid(s.read(64).intbe), base_fee_msat=s.read(32).intbe, ppm_fee=s.read(32).intbe, cltv=s.read(16).intbe, @@ -116,7 +114,7 @@ def decode(pr: str) -> Invoice: return invoice -def unshorten_amount(amount: str) -> int: +def _unshorten_amount(amount: str) -> int: """ Given a shortened amount, return millisatoshis """ # BOLT #11: @@ -141,18 +139,18 @@ def unshorten_amount(amount: str) -> int: raise ValueError("Invalid amount '{}'".format(amount)) if unit in units: - return int(amount[:-1]) * 100_000_000_000 / units[unit] + return int(int(amount[:-1]) * 100_000_000_000 / units[unit]) else: return int(amount) * 100_000_000_000 -def pull_tagged(stream): +def _pull_tagged(stream): tag = stream.read(5).uint length = stream.read(5).uint * 32 + stream.read(5).uint return (CHARSET[tag], stream.read(length * 5), stream) -def trim_to_bytes(barr): +def _trim_to_bytes(barr): # Adds a byte if necessary. b = barr.tobytes() if barr.len % 8 != 0: @@ -160,7 +158,7 @@ def trim_to_bytes(barr): return b -def readable_scid(short_channel_id: int) -> str: +def _readable_scid(short_channel_id: int) -> str: return "{blockheight}x{transactionindex}x{outputindex}".format( blockheight=((short_channel_id >> 40) & 0xFFFFFF), transactionindex=((short_channel_id >> 16) & 0xFFFFFF), @@ -168,7 +166,7 @@ def readable_scid(short_channel_id: int) -> str: ) -def u5_to_bitarray(arr): +def _u5_to_bitarray(arr: List[int]) -> bitstring.BitArray: ret = bitstring.BitArray() for a in arr: ret += bitstring.pack("uint:5", a) diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py index 2d064af1..4733a493 100644 --- a/lnbits/core/crud.py +++ b/lnbits/core/crud.py @@ -276,10 +276,10 @@ def delete_payment(checking_id: str) -> None: db.execute("DELETE FROM apipayments WHERE checking_id = ?", (checking_id,)) -def check_internal(payment_hash: str) -> None: +def check_internal(payment_hash: str) -> Optional[str]: with open_db() as db: row = db.fetchone("SELECT checking_id FROM apipayments WHERE hash = ?", (payment_hash,)) if not row: - return False + return None else: return row["checking_id"] diff --git a/lnbits/core/services.py b/lnbits/core/services.py index c15f78a0..6f39111f 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -1,27 +1,22 @@ -from typing import Optional, Tuple, Dict +from typing import Optional, Tuple, Dict, TypedDict from lnbits import bolt11 from lnbits.helpers import urlsafe_short_hash from lnbits.settings import WALLET +from lnbits.wallets.base import PaymentStatus from .crud import get_wallet, create_payment, delete_payment, check_internal, update_payment_status, get_wallet_payment def create_invoice( - *, - wallet_id: str, - amount: int, - memo: Optional[str] = None, - description_hash: Optional[bytes] = None, - extra: Optional[Dict] = None, + *, wallet_id: str, amount: int, memo: str, description_hash: Optional[bytes] = None, extra: Optional[Dict] = None, ) -> Tuple[str, str]: - try: - ok, checking_id, payment_request, error_message = WALLET.create_invoice( - amount=amount, memo=memo, description_hash=description_hash - ) - except Exception as e: - ok, error_message = False, str(e) + invoice_memo = None if description_hash else memo + storeable_memo = memo + ok, checking_id, payment_request, error_message = WALLET.create_invoice( + amount=amount, memo=invoice_memo, description_hash=description_hash + ) if not ok: raise Exception(error_message or "Unexpected backend error.") @@ -34,7 +29,7 @@ def create_invoice( payment_request=payment_request, payment_hash=invoice.payment_hash, amount=amount_msat, - memo=memo, + memo=storeable_memo, extra=extra, ) @@ -47,60 +42,71 @@ def pay_invoice( temp_id = f"temp_{urlsafe_short_hash()}" internal_id = f"internal_{urlsafe_short_hash()}" - try: - invoice = bolt11.decode(payment_request) - if invoice.amount_msat == 0: - raise ValueError("Amountless invoices not supported.") - if max_sat and invoice.amount_msat > max_sat * 1000: - raise ValueError("Amount in invoice is too high.") + invoice = bolt11.decode(payment_request) + if invoice.amount_msat == 0: + raise ValueError("Amountless invoices not supported.") + if max_sat and invoice.amount_msat > max_sat * 1000: + raise ValueError("Amount in invoice is too high.") - # put all parameters that don't change here - payment_kwargs = dict( - wallet_id=wallet_id, - payment_request=payment_request, - payment_hash=invoice.payment_hash, - amount=-invoice.amount_msat, - memo=invoice.description, - extra=extra, - ) + # put all parameters that don't change here + PaymentKwargs = TypedDict( + "PaymentKwargs", + { + "wallet_id": str, + "payment_request": str, + "payment_hash": str, + "amount": int, + "memo": str, + "extra": Optional[Dict], + }, + ) + payment_kwargs: PaymentKwargs = dict( + wallet_id=wallet_id, + payment_request=payment_request, + payment_hash=invoice.payment_hash, + amount=-invoice.amount_msat, + memo=invoice.description or "", + extra=extra, + ) - # check_internal() returns the checking_id of the invoice we're waiting for - internal = check_internal(invoice.payment_hash) - if internal: - # create a new payment from this wallet - create_payment(checking_id=internal_id, fee=0, pending=False, **payment_kwargs) - else: - # create a temporary payment here so we can check if - # the balance is enough in the next step - fee_reserve = max(1000, int(invoice.amount_msat * 0.01)) - create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs) + # check_internal() returns the checking_id of the invoice we're waiting for + internal = check_internal(invoice.payment_hash) + if internal: + # create a new payment from this wallet + create_payment(checking_id=internal_id, fee=0, pending=False, **payment_kwargs) + else: + # create a temporary payment here so we can check if + # the balance is enough in the next step + fee_reserve = max(1000, int(invoice.amount_msat * 0.01)) + create_payment(checking_id=temp_id, fee=-fee_reserve, **payment_kwargs) - # do the balance check - wallet = get_wallet(wallet_id) - assert wallet, "invalid wallet id" - if wallet.balance_msat < 0: - raise PermissionError("Insufficient balance.") + # do the balance check + wallet = get_wallet(wallet_id) + assert wallet, "invalid wallet id" + if wallet.balance_msat < 0: + raise PermissionError("Insufficient balance.") - if internal: - # mark the invoice from the other side as not pending anymore - # so the other side only has access to his new money when we are sure - # the payer has enough to deduct from - update_payment_status(checking_id=internal, pending=False) - else: - # actually pay the external invoice - ok, checking_id, fee_msat, error_message = WALLET.pay_invoice(payment_request) - if ok: - create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs) - delete_payment(temp_id) + if internal: + # mark the invoice from the other side as not pending anymore + # so the other side only has access to his new money when we are sure + # the payer has enough to deduct from + update_payment_status(checking_id=internal, pending=False) + else: + # actually pay the external invoice + ok, checking_id, fee_msat, error_message = WALLET.pay_invoice(payment_request) + if ok: + create_payment(checking_id=checking_id, fee=fee_msat, **payment_kwargs) + delete_payment(temp_id) - except Exception as e: - ok, error_message = False, str(e) if not ok: raise Exception(error_message or "Unexpected backend error.") return invoice.payment_hash -def check_invoice_status(wallet_id: str, payment_hash: str) -> str: +def check_invoice_status(wallet_id: str, payment_hash: str) -> PaymentStatus: payment = get_wallet_payment(wallet_id, payment_hash) + if not payment: + return PaymentStatus(None) + return WALLET.get_invoice_status(payment.checking_id) diff --git a/lnbits/extensions/paywall/views_api.py b/lnbits/extensions/paywall/views_api.py index f00ce795..85786a5f 100644 --- a/lnbits/extensions/paywall/views_api.py +++ b/lnbits/extensions/paywall/views_api.py @@ -64,7 +64,7 @@ def api_paywall_create_invoice(paywall_id): try: amount = g.data["amount"] if g.data["amount"] > paywall.amount else paywall.amount payment_hash, payment_request = create_invoice( - wallet_id=paywall.wallet, amount=amount, memo=f"{paywall.memo}", extra={'tag': 'paywall'} + wallet_id=paywall.wallet, amount=amount, memo=f"{paywall.memo}", extra={"tag": "paywall"} ) except Exception as e: return jsonify({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR