diff --git a/lnbits/extensions/lnurlpos/lnurl.py b/lnbits/extensions/lnurlpos/lnurl.py index a98bcdf4..5a0de90c 100644 --- a/lnbits/extensions/lnurlpos/lnurl.py +++ b/lnbits/extensions/lnurlpos/lnurl.py @@ -24,24 +24,27 @@ from .crud import ( update_lnurlpospayment, ) + def bech32_decode(bech): """tweaked version of bech32_decode that ignores length limitations""" - if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or - (bech.lower() != bech and bech.upper() != bech)): + if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or ( + bech.lower() != bech and bech.upper() != bech + ): return bech = bech.lower() - pos = bech.rfind('1') + pos = bech.rfind("1") if pos < 1 or pos + 7 > len(bech): return - if not all(x in bech32.CHARSET for x in bech[pos+1:]): + if not all(x in bech32.CHARSET for x in bech[pos + 1 :]): return hrp = bech[:pos] - data = [bech32.CHARSET.find(x) for x in bech[pos+1:]] + data = [bech32.CHARSET.find(x) for x in bech[pos + 1 :]] encoding = bech32.bech32_verify_checksum(hrp, data) if encoding is None: return return bytes(bech32.convertbits(data[:-6], 5, 8, False)) + def xor_decrypt(key, blob): s = BytesIO(blob) variant = s.read(1)[0] @@ -62,10 +65,12 @@ def xor_decrypt(key, blob): if len(payload) != l: raise RuntimeError("Missing payload bytes") hmacval = s.read() - expected = hmac.new(key, b"Data:" + blob[:-len(hmacval)], digestmod="sha256").digest() + expected = hmac.new( + key, b"Data:" + blob[: -len(hmacval)], digestmod="sha256" + ).digest() if len(hmacval) < 8: raise RuntimeError("HMAC is too short") - if hmacval != expected[:len(hmacval)]: + if hmacval != expected[: len(hmacval)]: raise RuntimeError("HMAC is invalid") secret = hmac.new(key, b"Round secret:" + nonce, digestmod="sha256").digest() payload = bytearray(payload) @@ -76,6 +81,7 @@ def xor_decrypt(key, blob): amount_in_cent = compact.read_from(s) return pin, amount_in_cent + @lnurlpos_ext.get( "/api/v1/lnurl/{pos_id}", status_code=HTTPStatus.OK, @@ -85,12 +91,6 @@ async def lnurl_v1_params( request: Request, pos_id: str = Query(None), p: str = Query(None), -): - return await handle_lnurl_firstrequest(request, pos_id, p) - - -async def handle_lnurl_firstrequest( - request: Request, pos_id: str, payload: str ): pos = await get_lnurlpos(pos_id) if not pos: @@ -100,10 +100,18 @@ async def handle_lnurl_firstrequest( } if len(payload) % 4 > 0: - payload += "="*(4-(len(payload)%4)) + payload += "=" * (4 - (len(payload) % 4)) data = base64.urlsafe_b64decode(payload) - pin, amount_in_cent = xor_decrypt(pos.key.encode(), data) + pin = 0 + amount_in_cent = 0 + try: + result = xor_decrypt(pos.key.encode(), data) + pin = result[0] + amount_in_cent = result[1] + except Exception as exc: + return {"status": "ERROR", "reason": str(exc)} + price_msat = ( await fiat_amount_as_satoshis(float(amount_in_cent) / 100, pos.currency) if pos.currency != "sat"