From 3e0fd39175af556870a333a6b9a976c6c9d60ef2 Mon Sep 17 00:00:00 2001 From: Tiago vasconcelos Date: Sun, 22 Aug 2021 12:16:31 +0100 Subject: [PATCH 1/4] some syntax refactoring --- lnbits/extensions/lnurlp/views_api.py | 18 ++++++++--------- lnbits/extensions/offlineshop/views_api.py | 8 ++++---- lnbits/extensions/paywall/views_api.py | 22 ++++++++++----------- lnbits/extensions/streamalerts/views_api.py | 12 +++++------ lnbits/extensions/subdomains/views_api.py | 14 ++++++------- lnbits/extensions/watchonly/views_api.py | 8 ++++---- lnbits/extensions/withdraw/views_api.py | 10 +++++----- 7 files changed, 46 insertions(+), 46 deletions(-) diff --git a/lnbits/extensions/lnurlp/views_api.py b/lnbits/extensions/lnurlp/views_api.py index b08a2162..a41e510e 100644 --- a/lnbits/extensions/lnurlp/views_api.py +++ b/lnbits/extensions/lnurlp/views_api.py @@ -48,7 +48,7 @@ async def api_links(): ) -@lnurlp_ext.get("/api/v1/links/") +@lnurlp_ext.get("/api/v1/links/{link_id}") @api_check_wallet_key("invoice") async def api_link_retrieve(link_id): link = await get_pay_link(link_id) @@ -65,14 +65,14 @@ class CreateData(BaseModel): description: str min: int = Query(0.01, ge=0.01) max: int = Query(0.01, ge=0.01) - currency: Optional[str] - comment_chars: int = Query(0, ge=0, lt=800) - webhook_url: Optional[str] - success_text: Optional[str] - success_url: Optional[str] + currency: Optional[str] + comment_chars: int = Query(0, ge=0, lt=800) + webhook_url: Optional[str] + success_text: Optional[str] + success_url: Optional[str] @lnurlp_ext.post("/api/v1/links") -@lnurlp_ext.put("/api/v1/links/") +@lnurlp_ext.put("/api/v1/links/{link_id}") @api_check_wallet_key("invoice") async def api_link_create_or_update(data: CreateData, link_id=None): if data.min > data.max: @@ -111,7 +111,7 @@ async def api_link_create_or_update(data: CreateData, link_id=None): ) -@lnurlp_ext.delete("/api/v1/links/") +@lnurlp_ext.delete("/api/v1/links/{link_id}") @api_check_wallet_key("invoice") async def api_link_delete(link_id): link = await get_pay_link(link_id) @@ -127,7 +127,7 @@ async def api_link_delete(link_id): return "", HTTPStatus.NO_CONTENT -@lnurlp_ext.get("/api/v1/rate/") +@lnurlp_ext.get("/api/v1/rate/{currency}") async def api_check_fiat_rate(currency): try: rate = await get_fiat_rate_satoshis(currency) diff --git a/lnbits/extensions/offlineshop/views_api.py b/lnbits/extensions/offlineshop/views_api.py index af2150cb..be860bc0 100644 --- a/lnbits/extensions/offlineshop/views_api.py +++ b/lnbits/extensions/offlineshop/views_api.py @@ -53,11 +53,11 @@ class CreateItemsData(BaseModel): name: str description: str image: Optional[str] - price: int - unit: str + price: int + unit: str @offlineshop_ext.post("/api/v1/offlineshop/items") -@offlineshop_ext.put("/api/v1/offlineshop/items/") +@offlineshop_ext.put("/api/v1/offlineshop/items/{item_id}") @api_check_wallet_key("invoice") async def api_add_or_update_item(data: CreateItemsData, item_id=None): shop = await get_or_create_shop_by_wallet(g.wallet.id) @@ -84,7 +84,7 @@ async def api_add_or_update_item(data: CreateItemsData, item_id=None): return "", HTTPStatus.OK -@offlineshop_ext.delete("/api/v1/offlineshop/items/") +@offlineshop_ext.delete("/api/v1/offlineshop/items/{item_id}") @api_check_wallet_key("invoice") async def api_delete_item(item_id): shop = await get_or_create_shop_by_wallet(g.wallet.id) diff --git a/lnbits/extensions/paywall/views_api.py b/lnbits/extensions/paywall/views_api.py index 94103899..0d4b181f 100644 --- a/lnbits/extensions/paywall/views_api.py +++ b/lnbits/extensions/paywall/views_api.py @@ -29,8 +29,8 @@ class CreateData(BaseModel): url: Optional[str] = Query(...) memo: Optional[str] = Query(...) description: str - amount: int - remembers: bool + amount: int + remembers: bool @paywall_ext.post("/api/v1/paywalls") @api_check_wallet_key("invoice") @@ -39,7 +39,7 @@ async def api_paywall_create(data: CreateData): return paywall, HTTPStatus.CREATED -@paywall_ext.delete("/api/v1/paywalls/") +@paywall_ext.delete("/api/v1/paywalls/{paywall_id}") @api_check_wallet_key("invoice") async def api_paywall_delete(paywall_id): paywall = await get_paywall(paywall_id) @@ -55,7 +55,7 @@ async def api_paywall_delete(paywall_id): return "", HTTPStatus.NO_CONTENT -@paywall_ext.post("/api/v1/paywalls//invoice") +@paywall_ext.post("/api/v1/paywalls/{paywall_id}/invoice") async def api_paywall_create_invoice(amount: int = Query(..., ge=1), paywall_id = None): paywall = await get_paywall(paywall_id) @@ -76,26 +76,26 @@ async def api_paywall_create_invoice(amount: int = Query(..., ge=1), paywall_id extra={"tag": "paywall"}, ) except Exception as e: - return jsonable_encoder({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR + return {"message": str(e)}, HTTPStatus.INTERNAL_SERVER_ERROR return ( - jsonable_encoder({"payment_hash": payment_hash, "payment_request": payment_request}), + {"payment_hash": payment_hash, "payment_request": payment_request}, HTTPStatus.CREATED, ) -@paywall_ext.post("/api/v1/paywalls//check_invoice") +@paywall_ext.post("/api/v1/paywalls/{paywall_id}/check_invoice") async def api_paywal_check_invoice(payment_hash: str = Query(...), paywall_id = None): paywall = await get_paywall(paywall_id) if not paywall: - return jsonable_encoder({"message": "Paywall does not exist."}), HTTPStatus.NOT_FOUND + return {"message": "Paywall does not exist."}, HTTPStatus.NOT_FOUND try: status = await check_invoice_status(paywall.wallet, payment_hash) is_paid = not status.pending except Exception: - return jsonable_encoder({"paid": False}), HTTPStatus.OK + return {"paid": False}, HTTPStatus.OK if is_paid: wallet = await get_wallet(paywall.wallet) @@ -103,8 +103,8 @@ async def api_paywal_check_invoice(payment_hash: str = Query(...), paywall_id = await payment.set_pending(False) return ( - jsonable_encoder({"paid": True, "url": paywall.url, "remembers": paywall.remembers}), + {"paid": True, "url": paywall.url, "remembers": paywall.remembers}, HTTPStatus.OK, ) - return jsonable_encoder({"paid": False}), HTTPStatus.OK + return {"paid": False}, HTTPStatus.OK diff --git a/lnbits/extensions/streamalerts/views_api.py b/lnbits/extensions/streamalerts/views_api.py index 9e34ff36..992e5188 100644 --- a/lnbits/extensions/streamalerts/views_api.py +++ b/lnbits/extensions/streamalerts/views_api.py @@ -46,7 +46,7 @@ async def api_create_service(data: CreateServicesData): return service._asdict(), HTTPStatus.CREATED -@streamalerts_ext.get("/api/v1/getaccess/") +@streamalerts_ext.get("/api/v1/getaccess/{service_id}") async def api_get_access(service_id): """Redirect to Streamlabs' Approve/Decline page for API access for Service with service_id @@ -69,7 +69,7 @@ async def api_get_access(service_id): return ({"message": "Service does not exist!"}, HTTPStatus.BAD_REQUEST) -@streamalerts_ext.get("/api/v1/authenticate/") +@streamalerts_ext.get("/api/v1/authenticate/{service_id}") async def api_authenticate_service(Code: str, State: str, service_id): """Endpoint visited via redirect during third party API authentication @@ -183,7 +183,7 @@ async def api_get_donations(): ) -@streamalerts_ext.put("/api/v1/donations/") +@streamalerts_ext.put("/api/v1/donations/{donation_id}") @api_check_wallet_key("invoice") async def api_update_donation(donation_id=None): """Update a donation with the data given in the request""" @@ -208,7 +208,7 @@ async def api_update_donation(donation_id=None): return donation._asdict(), HTTPStatus.CREATED -@streamalerts_ext.put("/api/v1/services/") +@streamalerts_ext.put("/api/v1/services/{service_id}") @api_check_wallet_key("invoice") async def api_update_service(service_id=None): """Update a service with the data given in the request""" @@ -229,7 +229,7 @@ async def api_update_service(service_id=None): return service._asdict(), HTTPStatus.CREATED -@streamalerts_ext.delete("/api/v1/donations/") +@streamalerts_ext.delete("/api/v1/donations/{donation_id}") @api_check_wallet_key("invoice") async def api_delete_donation(donation_id): """Delete the donation with the given donation_id""" @@ -245,7 +245,7 @@ async def api_delete_donation(donation_id): return "", HTTPStatus.NO_CONTENT -@streamalerts_ext.delete("/api/v1/services/") +@streamalerts_ext.delete("/api/v1/services/{service_id}") @api_check_wallet_key("invoice") async def api_delete_service(service_id): """Delete the service with the given service_id""" diff --git a/lnbits/extensions/subdomains/views_api.py b/lnbits/extensions/subdomains/views_api.py index 3684c4ad..72000238 100644 --- a/lnbits/extensions/subdomains/views_api.py +++ b/lnbits/extensions/subdomains/views_api.py @@ -51,7 +51,7 @@ class CreateDomainsData(BaseModel): allowed_record_types: str @subdomains_ext.post("/api/v1/domains") -@subdomains_ext.put("/api/v1/domains/") +@subdomains_ext.put("/api/v1/domains/{domain_id}") @api_check_wallet_key("invoice") async def api_domain_create(data: CreateDomainsData, domain_id=None): if domain_id: @@ -66,10 +66,10 @@ async def api_domain_create(data: CreateDomainsData, domain_id=None): domain = await update_domain(domain_id, **data) else: domain = await create_domain(**data) - return jsonify(domain._asdict()), HTTPStatus.CREATED + return domain._asdict(), HTTPStatus.CREATED -@subdomains_ext.delete("/api/v1/domains/") +@subdomains_ext.delete("/api/v1/domains/{domain_id}") @api_check_wallet_key("invoice") async def api_domain_delete(domain_id): domain = await get_domain(domain_id) @@ -110,14 +110,14 @@ class CreateDomainsData(BaseModel): duration: int record_type: str -@subdomains_ext.post("/api/v1/subdomains/") +@subdomains_ext.post("/api/v1/subdomains/{domain_id}") async def api_subdomain_make_subdomain(data: CreateDomainsData, domain_id): domain = await get_domain(domain_id) # If the request is coming for the non-existant domain if not domain: - return jsonify({"message": "LNsubdomain does not exist."}), HTTPStatus.NOT_FOUND + return {"message": "LNsubdomain does not exist."}, HTTPStatus.NOT_FOUND ## If record_type is not one of the allowed ones reject the request if data.record_type not in domain.allowed_record_types: @@ -184,7 +184,7 @@ async def api_subdomain_make_subdomain(data: CreateDomainsData, domain_id): ) -@subdomains_ext.get("/api/v1/subdomains/") +@subdomains_ext.get("/api/v1/subdomains/{payment_hash}") async def api_subdomain_send_subdomain(payment_hash): subdomain = await get_subdomain(payment_hash) try: @@ -199,7 +199,7 @@ async def api_subdomain_send_subdomain(payment_hash): return {"paid": False}, HTTPStatus.OK -@subdomains_ext.delete("/api/v1/subdomains/") +@subdomains_ext.delete("/api/v1/subdomains/{subdomain_id}") @api_check_wallet_key("invoice") async def api_subdomain_delete(subdomain_id): subdomain = await get_subdomain(subdomain_id) diff --git a/lnbits/extensions/watchonly/views_api.py b/lnbits/extensions/watchonly/views_api.py index 24a269bf..e4598a97 100644 --- a/lnbits/extensions/watchonly/views_api.py +++ b/lnbits/extensions/watchonly/views_api.py @@ -37,7 +37,7 @@ async def api_wallets_retrieve(): return "" -@watchonly_ext.get("/api/v1/wallet/") +@watchonly_ext.get("/api/v1/wallet/{wallet_id}") @api_check_wallet_key("invoice") async def api_wallet_retrieve(wallet_id): wallet = await get_watch_wallet(wallet_id) @@ -63,7 +63,7 @@ async def api_wallet_create_or_update(masterPub: str, Title: str, wallet_id=None return wallet._asdict(), HTTPStatus.CREATED -@watchonly_ext.delete("/api/v1/wallet/") +@watchonly_ext.delete("/api/v1/wallet/{wallet_id}") @api_check_wallet_key("admin") async def api_wallet_delete(wallet_id): wallet = await get_watch_wallet(wallet_id) @@ -79,7 +79,7 @@ async def api_wallet_delete(wallet_id): #############################ADDRESSES########################## -@watchonly_ext.get("/api/v1/address/") +@watchonly_ext.get("/api/v1/address/{wallet_id}") @api_check_wallet_key("invoice") async def api_fresh_address(wallet_id): await get_fresh_address(wallet_id) @@ -89,7 +89,7 @@ async def api_fresh_address(wallet_id): return [address._asdict() for address in addresses], HTTPStatus.OK -@watchonly_ext.get("/api/v1/addresses/") +@watchonly_ext.get("/api/v1/addresses/{wallet_id}") @api_check_wallet_key("invoice") async def api_get_addresses(wallet_id): wallet = await get_watch_wallet(wallet_id) diff --git a/lnbits/extensions/withdraw/views_api.py b/lnbits/extensions/withdraw/views_api.py index 97d7a9de..b34b9fbb 100644 --- a/lnbits/extensions/withdraw/views_api.py +++ b/lnbits/extensions/withdraw/views_api.py @@ -46,7 +46,7 @@ async def api_links(): ) -@withdraw_ext.get("/api/v1/links/") +@withdraw_ext.get("/api/v1/links/{link_id}") @api_check_wallet_key("invoice") async def api_link_retrieve(link_id): link = await get_withdraw_link(link_id, 0) @@ -70,7 +70,7 @@ class CreateData(BaseModel): is_unique: bool @withdraw_ext.post("/api/v1/links") -@withdraw_ext.put("/api/v1/links/") +@withdraw_ext.put("/api/v1/links/{link_id}") @api_check_wallet_key("admin") async def api_link_create_or_update(data: CreateData, link_id: str = None): if data.max_withdrawable < data.min_withdrawable: @@ -97,7 +97,7 @@ async def api_link_create_or_update(data: CreateData, link_id: str = None): HTTPStatus.NOT_FOUND, ) if link.wallet != g.wallet.id: - return jsonify({"message": "Not your withdraw link."}), HTTPStatus.FORBIDDEN + return {"message": "Not your withdraw link."}, HTTPStatus.FORBIDDEN link = await update_withdraw_link(link_id, **data, usescsv=usescsv, used=0) else: link = await create_withdraw_link( @@ -109,7 +109,7 @@ async def api_link_create_or_update(data: CreateData, link_id: str = None): ) -@withdraw_ext.delete("/api/v1/links/") +@withdraw_ext.delete("/api/v1/links/{link_id}") @api_check_wallet_key("admin") async def api_link_delete(link_id): link = await get_withdraw_link(link_id) @@ -127,7 +127,7 @@ async def api_link_delete(link_id): return "", HTTPStatus.NO_CONTENT -@withdraw_ext.get("/api/v1/links//") +@withdraw_ext.get("/api/v1/links/{the_hash}/{lnurl_id}") @api_check_wallet_key("invoice") async def api_hash_retrieve(the_hash, lnurl_id): hashCheck = await get_hash_check(the_hash, lnurl_id) From fc68e0a6dad1b1e6f2948785858864aff38a569c Mon Sep 17 00:00:00 2001 From: Tiago vasconcelos Date: Sun, 22 Aug 2021 12:17:43 +0100 Subject: [PATCH 2/4] fastAPI refactoring --- lnbits/extensions/lndhub/views_api.py | 121 ++++++++++++++------------ 1 file changed, 63 insertions(+), 58 deletions(-) diff --git a/lnbits/extensions/lndhub/views_api.py b/lnbits/extensions/lndhub/views_api.py index de61820a..057c8d08 100644 --- a/lnbits/extensions/lndhub/views_api.py +++ b/lnbits/extensions/lndhub/views_api.py @@ -2,6 +2,11 @@ import time from base64 import urlsafe_b64encode from quart import jsonify, g, request +from fastapi import FastAPI, Query +from fastapi.encoders import jsonable_encoder +from fastapi.responses import JSONResponse +from pydantic import BaseModel + from lnbits.core.services import pay_invoice, create_invoice from lnbits.core.crud import get_payments, delete_expired_invoices from lnbits.decorators import api_validate_post_request @@ -13,62 +18,62 @@ from .decorators import check_wallet from .utils import to_buffer, decoded_as_lndhub -@lndhub_ext.route("/ext/getinfo", methods=["GET"]) +@lndhub_ext.get("/ext/getinfo") async def lndhub_getinfo(): - return jsonify({"error": True, "code": 1, "message": "bad auth"}) + return {"error": True, "code": 1, "message": "bad auth"} -@lndhub_ext.route("/ext/auth", methods=["POST"]) -@api_validate_post_request( - schema={ - "login": {"type": "string", "required": True, "excludes": "refresh_token"}, - "password": {"type": "string", "required": True, "excludes": "refresh_token"}, - "refresh_token": { - "type": "string", - "required": True, - "excludes": ["login", "password"], - }, - } -) -async def lndhub_auth(): +@lndhub_ext.post("/ext/auth") +# @api_validate_post_request( +# schema={ +# "login": {"type": "string", "required": True, "excludes": "refresh_token"}, +# "password": {"type": "string", "required": True, "excludes": "refresh_token"}, +# "refresh_token": { +# "type": "string", +# "required": True, +# "excludes": ["login", "password"], +# }, +# } +# ) +async def lndhub_auth(login: str, password: str, refresh_token: str): #missing the "excludes" thing token = ( - g.data["refresh_token"] - if "refresh_token" in g.data and g.data["refresh_token"] + refresh_token + if refresh_token else urlsafe_b64encode( - (g.data["login"] + ":" + g.data["password"]).encode("utf-8") + (login + ":" + password).encode("utf-8") ).decode("ascii") ) - return jsonify({"refresh_token": token, "access_token": token}) + return {"refresh_token": token, "access_token": token} -@lndhub_ext.route("/ext/addinvoice", methods=["POST"]) +@lndhub_ext.post("/ext/addinvoice") @check_wallet() -@api_validate_post_request( - schema={ - "amt": {"type": "string", "required": True}, - "memo": {"type": "string", "required": True}, - "preimage": {"type": "string", "required": False}, - } -) -async def lndhub_addinvoice(): +# @api_validate_post_request( +# schema={ +# "amt": {"type": "string", "required": True}, +# "memo": {"type": "string", "required": True}, +# "preimage": {"type": "string", "required": False}, +# } +# ) +async def lndhub_addinvoice(amt: str, memo: str, preimage: str = ""): try: _, pr = await create_invoice( wallet_id=g.wallet.id, - amount=int(g.data["amt"]), - memo=g.data["memo"], + amount=int(amt), + memo=memo, extra={"tag": "lndhub"}, ) except Exception as e: - return jsonify( + return { "error": True, "code": 7, "message": "Failed to create invoice: " + str(e), } - ) + invoice = bolt11.decode(pr) - return jsonify( + return { "pay_req": pr, "payment_request": pr, @@ -76,30 +81,30 @@ async def lndhub_addinvoice(): "r_hash": to_buffer(invoice.payment_hash), "hash": invoice.payment_hash, } - ) -@lndhub_ext.route("/ext/payinvoice", methods=["POST"]) + +@lndhub_ext.post("/ext/payinvoice") @check_wallet(requires_admin=True) -@api_validate_post_request(schema={"invoice": {"type": "string", "required": True}}) -async def lndhub_payinvoice(): +# @api_validate_post_request(schema={"invoice": {"type": "string", "required": True}}) +async def lndhub_payinvoice(invoice: str): try: await pay_invoice( wallet_id=g.wallet.id, - payment_request=g.data["invoice"], + payment_request=invoice, extra={"tag": "lndhub"}, ) except Exception as e: - return jsonify( + return { "error": True, "code": 10, "message": "Payment failed: " + str(e), } - ) - invoice: bolt11.Invoice = bolt11.decode(g.data["invoice"]) - return jsonify( + + invoice: bolt11.Invoice = bolt11.decode(invoice) + return { "payment_error": "", "payment_preimage": "0" * 64, @@ -113,16 +118,16 @@ async def lndhub_payinvoice(): "timestamp": int(time.time()), "memo": invoice.description, } - ) -@lndhub_ext.route("/ext/balance", methods=["GET"]) + +@lndhub_ext.get("/ext/balance") @check_wallet() async def lndhub_balance(): - return jsonify({"BTC": {"AvailableBalance": g.wallet.balance}}) + return {"BTC": {"AvailableBalance": g.wallet.balance}} -@lndhub_ext.route("/ext/gettxs", methods=["GET"]) +@lndhub_ext.get("/ext/gettxs") @check_wallet() async def lndhub_gettxs(): for payment in await get_payments( @@ -138,7 +143,7 @@ async def lndhub_gettxs(): ) limit = int(request.args.get("limit", 200)) - return jsonify( + return [ { "payment_preimage": payment.preimage, @@ -164,10 +169,10 @@ async def lndhub_gettxs(): )[:limit] ) ] - ) -@lndhub_ext.route("/ext/getuserinvoices", methods=["GET"]) + +@lndhub_ext.get("/ext/getuserinvoices") @check_wallet() async def lndhub_getuserinvoices(): await delete_expired_invoices() @@ -184,7 +189,7 @@ async def lndhub_getuserinvoices(): ) limit = int(request.args.get("limit", 200)) - return jsonify( + return [ { "r_hash": to_buffer(invoice.payment_hash), @@ -210,31 +215,31 @@ async def lndhub_getuserinvoices(): )[:limit] ) ] - ) -@lndhub_ext.route("/ext/getbtc", methods=["GET"]) + +@lndhub_ext.get("/ext/getbtc") @check_wallet() async def lndhub_getbtc(): "load an address for incoming onchain btc" - return jsonify([]) + return [] -@lndhub_ext.route("/ext/getpending", methods=["GET"]) +@lndhub_ext.get("/ext/getpending") @check_wallet() async def lndhub_getpending(): "pending onchain transactions" - return jsonify([]) + return [] -@lndhub_ext.route("/ext/decodeinvoice", methods=["GET"]) +@lndhub_ext.get("/ext/decodeinvoice") async def lndhub_decodeinvoice(): invoice = request.args.get("invoice") inv = bolt11.decode(invoice) - return jsonify(decoded_as_lndhub(inv)) + return decoded_as_lndhub(inv) -@lndhub_ext.route("/ext/checkrouteinvoice", methods=["GET"]) +@lndhub_ext.get("/ext/checkrouteinvoice") async def lndhub_checkrouteinvoice(): "not implemented on canonical lndhub" pass From 938fc54af3479996a95de08e0553b02f93c56ca9 Mon Sep 17 00:00:00 2001 From: Stefan Stammberger Date: Sun, 22 Aug 2021 20:07:24 +0200 Subject: [PATCH 3/4] feat: switch from Quart to FastAPI part I --- lnbits/__main__.py | 11 +- lnbits/app.py | 123 ++++++++++++++-------- lnbits/auth_bearer.py | 49 +++++++++ lnbits/core/__init__.py | 27 +++-- lnbits/core/views/api.py | 46 ++++---- lnbits/core/views/generic.py | 29 ++--- lnbits/decorators.py | 14 +-- lnbits/extensions/offlineshop/__init__.py | 7 +- lnbits/jinja2_templating.py | 36 +++++++ lnbits/requestvars.py | 9 ++ lnbits/templates/base.html | 4 +- lnbits/templates/print.html | 4 +- 12 files changed, 245 insertions(+), 114 deletions(-) create mode 100644 lnbits/auth_bearer.py create mode 100644 lnbits/jinja2_templating.py create mode 100644 lnbits/requestvars.py diff --git a/lnbits/__main__.py b/lnbits/__main__.py index 90b08642..bd564787 100644 --- a/lnbits/__main__.py +++ b/lnbits/__main__.py @@ -1,4 +1,7 @@ +from hypercorn.trio import serve import trio +import trio_asyncio +from hypercorn.config import Config from .commands import migrate_databases, transpile_scss, bundle_vendored @@ -8,7 +11,7 @@ bundle_vendored() from .app import create_app -app = create_app() +app = trio.run(create_app) from .settings import ( LNBITS_SITE_TITLE, @@ -17,6 +20,8 @@ from .settings import ( LNBITS_DATA_FOLDER, WALLET, LNBITS_COMMIT, + HOST, + PORT ) print( @@ -30,4 +35,6 @@ print( """ ) -app.run(host=app.config["HOST"], port=app.config["PORT"]) +config = Config() +config.bind = [f"{HOST}:{PORT}"] +trio_asyncio.run(serve, app, config) diff --git a/lnbits/app.py b/lnbits/app.py index 5da9a195..a686d5f7 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -1,12 +1,15 @@ +import jinja2 +from lnbits.jinja2_templating import Jinja2Templates import sys import warnings import importlib import traceback +import trio -from quart import g -from quart_trio import QuartTrio -from quart_cors import cors # type: ignore -from quart_compress import Compress # type: ignore +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.gzip import GZipMiddleware +from fastapi.staticfiles import StaticFiles from .commands import db_migrate, handle_assets from .core import core_app @@ -26,32 +29,66 @@ from .tasks import ( catch_everything_and_restart, ) from .settings import WALLET +from .requestvars import g, request_global +import lnbits.settings - -def create_app(config_object="lnbits.settings") -> QuartTrio: +async def create_app(config_object="lnbits.settings") -> FastAPI: """Create application factory. :param config_object: The configuration object to use. """ - app = QuartTrio(__name__, static_folder="static") - app.config.from_object(config_object) - app.asgi_http_class = ASGIProxyFix + app = FastAPI() + app.mount("/static", StaticFiles(directory="lnbits/static"), name="static") - cors(app) - Compress(app) + origins = [ + "http://localhost", + "http://localhost:5000", + ] + + app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + g().config = lnbits.settings + g().templates = build_standard_jinja_templates() + + app.add_middleware(GZipMiddleware, minimum_size=1000) + # app.add_middleware(ASGIProxyFix) check_funding_source(app) register_assets(app) - register_blueprints(app) - register_filters(app) - register_commands(app) + register_routes(app) + # register_commands(app) register_async_tasks(app) - register_exception_handlers(app) + # register_exception_handlers(app) return app +def build_standard_jinja_templates(): + t = Jinja2Templates( + loader=jinja2.FileSystemLoader(["lnbits/templates", "lnbits/core/templates"]), + ) + t.env.globals["SITE_TITLE"] = lnbits.settings.LNBITS_SITE_TITLE + t.env.globals["SITE_TAGLINE"] = lnbits.settings.LNBITS_SITE_TAGLINE + t.env.globals["SITE_DESCRIPTION"] = lnbits.settings.LNBITS_SITE_DESCRIPTION + t.env.globals["LNBITS_THEME_OPTIONS"] = lnbits.settings.LNBITS_THEME_OPTIONS + t.env.globals["LNBITS_VERSION"] = lnbits.settings.LNBITS_COMMIT + t.env.globals["EXTENSIONS"] = get_valid_extensions() + + if g().config.DEBUG: + t.env.globals["VENDORED_JS"] = map(url_for_vendored, get_js_vendored()) + t.env.globals["VENDORED_CSS"] = map(url_for_vendored, get_css_vendored()) + else: + t.env.globals["VENDORED_JS"] = ["/static/bundle.js"] + t.env.globals["VENDORED_CSS"] = ["/static/bundle.css"] -def check_funding_source(app: QuartTrio) -> None: - @app.before_serving + return t + +def check_funding_source(app: FastAPI) -> None: + @app.on_event("startup") async def check_wallet_status(): error_message, balance = await WALLET.status() if error_message: @@ -67,64 +104,60 @@ def check_funding_source(app: QuartTrio) -> None: ) -def register_blueprints(app: QuartTrio) -> None: +def register_routes(app: FastAPI) -> None: """Register Flask blueprints / LNbits extensions.""" - app.register_blueprint(core_app) + app.include_router(core_app) for ext in get_valid_extensions(): try: ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}") - bp = getattr(ext_module, f"{ext.code}_ext") + ext_route = getattr(ext_module, f"{ext.code}_ext") - app.register_blueprint(bp, url_prefix=f"/{ext.code}") + app.include_router(ext_route) except Exception: raise ImportError( f"Please make sure that the extension `{ext.code}` follows conventions." ) -def register_commands(app: QuartTrio): +def register_commands(app: FastAPI): """Register Click commands.""" app.cli.add_command(db_migrate) app.cli.add_command(handle_assets) -def register_assets(app: QuartTrio): +def register_assets(app: FastAPI): """Serve each vendored asset separately or a bundle.""" - @app.before_request + @app.on_event("startup") async def vendored_assets_variable(): - if app.config["DEBUG"]: - g.VENDORED_JS = map(url_for_vendored, get_js_vendored()) - g.VENDORED_CSS = map(url_for_vendored, get_css_vendored()) + if g().config.DEBUG: + g().VENDORED_JS = map(url_for_vendored, get_js_vendored()) + g().VENDORED_CSS = map(url_for_vendored, get_css_vendored()) else: - g.VENDORED_JS = ["/static/bundle.js"] - g.VENDORED_CSS = ["/static/bundle.css"] - - -def register_filters(app: QuartTrio): - """Jinja filters.""" - app.jinja_env.globals["SITE_TITLE"] = app.config["LNBITS_SITE_TITLE"] - app.jinja_env.globals["SITE_TAGLINE"] = app.config["LNBITS_SITE_TAGLINE"] - app.jinja_env.globals["SITE_DESCRIPTION"] = app.config["LNBITS_SITE_DESCRIPTION"] - app.jinja_env.globals["LNBITS_THEME_OPTIONS"] = app.config["LNBITS_THEME_OPTIONS"] - app.jinja_env.globals["LNBITS_VERSION"] = app.config["LNBITS_COMMIT"] - app.jinja_env.globals["EXTENSIONS"] = get_valid_extensions() + g().VENDORED_JS = ["/static/bundle.js"] + g().VENDORED_CSS = ["/static/bundle.css"] def register_async_tasks(app): - @app.route("/wallet/webhook", methods=["GET", "POST", "PUT", "PATCH", "DELETE"]) + @app.route("/wallet/webhook") async def webhook_listener(): return await webhook_handler() - @app.before_serving + @app.on_event("startup") async def listeners(): run_deferred_async() - app.nursery.start_soon(catch_everything_and_restart, check_pending_payments) - app.nursery.start_soon(catch_everything_and_restart, invoice_listener) - app.nursery.start_soon(catch_everything_and_restart, internal_invoice_listener) + trio.open_process(check_pending_payments) + trio.open_process(invoice_listener) + trio.open_process(internal_invoice_listener) + + async with trio.open_nursery() as n: + pass + # n.start_soon(catch_everything_and_restart, check_pending_payments) + # n.start_soon(catch_everything_and_restart, invoice_listener) + # n.start_soon(catch_everything_and_restart, internal_invoice_listener) - @app.after_serving + @app.on_event("shutdown") async def stop_listeners(): pass diff --git a/lnbits/auth_bearer.py b/lnbits/auth_bearer.py new file mode 100644 index 00000000..81b93427 --- /dev/null +++ b/lnbits/auth_bearer.py @@ -0,0 +1,49 @@ +from fastapi import Request, HTTPException +from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey + +# https://medium.com/data-rebels/fastapi-authentication-revisited-enabling-api-key-authentication-122dc5975680 + +from fastapi import Security, Depends, FastAPI, HTTPException +from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey +from fastapi.security.base import SecurityBase + + + +API_KEY = "usr" +API_KEY_NAME = "X-API-key" + +api_key_query = APIKeyQuery(name=API_KEY_NAME, auto_error=False) +api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) + + + +class AuthBearer(SecurityBase): + def __init__(self, scheme_name: str = None, auto_error: bool = True): + self.scheme_name = scheme_name or self.__class__.__name__ + self.auto_error = auto_error + + async def __call__(self, request: Request): + key = await self.get_api_key() + print(key) + # credentials: HTTPAuthorizationCredentials = await super(AuthBearer, self).__call__(request) + # if credentials: + # if not credentials.scheme == "Bearer": + # raise HTTPException( + # status_code=403, detail="Invalid authentication scheme.") + # if not self.verify_jwt(credentials.credentials): + # raise HTTPException( + # status_code=403, detail="Invalid token or expired token.") + # return credentials.credentials + # else: + # raise HTTPException( + # status_code=403, detail="Invalid authorization code.") + async def get_api_key(self, + api_key_query: str = Security(api_key_query), + api_key_header: str = Security(api_key_header), + ): + if api_key_query == API_KEY: + return api_key_query + elif api_key_header == API_KEY: + return api_key_header + else: + raise HTTPException(status_code=403, detail="Could not validate credentials") \ No newline at end of file diff --git a/lnbits/core/__init__.py b/lnbits/core/__init__.py index 12dcded8..d988d573 100644 --- a/lnbits/core/__init__.py +++ b/lnbits/core/__init__.py @@ -1,22 +1,19 @@ -from quart import Blueprint +from fastapi.routing import APIRouter + from lnbits.db import Database db = Database("database") -core_app: Blueprint = Blueprint( - "core", - __name__, - template_folder="templates", - static_folder="static", - static_url_path="/core/static", -) - - -from .views.api import * # noqa -from .views.generic import * # noqa -from .views.public_api import * # noqa -from .tasks import register_listeners +core_app: APIRouter = APIRouter() from lnbits.tasks import record_async -core_app.record(record_async(register_listeners)) +from .tasks import register_listeners +from .views.api import * # noqa +from .views.generic import * # noqa +from .views.public_api import * # noqa + + +@core_app.on_event("startup") +def do_startup(): + record_async(register_listeners) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index b47ba51e..c266d72d 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -1,10 +1,12 @@ +from fastapi.param_functions import Depends +from lnbits.auth_bearer import AuthBearer from pydantic import BaseModel import trio import json import httpx import hashlib from urllib.parse import urlparse, urlunparse, urlencode, parse_qs, ParseResult -from quart import g, current_app, make_response, url_for +from quart import current_app, make_response, url_for from fastapi import Query @@ -15,6 +17,7 @@ from typing import Dict, List, Optional, Union from lnbits import bolt11, lnurl from lnbits.decorators import api_check_wallet_key, api_validate_post_request from lnbits.utils.exchange_rates import currencies, fiat_amount_as_satoshis +from lnbits.requestvars import g from .. import core_app, db from ..crud import get_payments, save_balance_check, update_wallet @@ -28,11 +31,14 @@ from ..services import ( from ..tasks import api_invoice_listeners -@core_app.get("/api/v1/wallet") -@api_check_wallet_key("invoice") +@core_app.get( + "/api/v1/wallet", + # dependencies=[Depends(AuthBearer())] +) +# @api_check_wallet_key("invoice") async def api_wallet(): return ( - {"id": g.wallet.id, "name": g.wallet.name, "balance": g.wallet.balance_msat}, + {"id": g().wallet.id, "name": g().wallet.name, "balance": g().wallet.balance_msat}, HTTPStatus.OK, ) @@ -40,12 +46,12 @@ async def api_wallet(): @core_app.put("/api/v1/wallet/") @api_check_wallet_key("invoice") async def api_update_wallet(new_name: str): - await update_wallet(g.wallet.id, new_name) + await update_wallet(g().wallet.id, new_name) return ( { - "id": g.wallet.id, - "name": g.wallet.name, - "balance": g.wallet.balance_msat, + "id": g().wallet.id, + "name": g().wallet.name, + "balance": g().wallet.balance_msat, }, HTTPStatus.OK, ) @@ -55,7 +61,7 @@ async def api_update_wallet(new_name: str): @api_check_wallet_key("invoice") async def api_payments(): return ( - await get_payments(wallet_id=g.wallet.id, pending=True, complete=True), + await get_payments(wallet_id=g().wallet.id, pending=True, complete=True), HTTPStatus.OK, ) @@ -88,7 +94,7 @@ async def api_payments_create_invoice(data: CreateInvoiceData): async with db.connect() as conn: try: payment_hash, payment_request = await create_invoice( - wallet_id=g.wallet.id, + wallet_id=g().wallet.id, amount=amount, memo=memo, description_hash=description_hash, @@ -105,8 +111,8 @@ async def api_payments_create_invoice(data: CreateInvoiceData): lnurl_response: Union[None, bool, str] = None if data.lnurl_callback: - if "lnurl_balance_check" in g.data: - save_balance_check(g.wallet.id, data.lnurl_balance_check) + if "lnurl_balance_check" in g().data: + save_balance_check(g().wallet.id, data.lnurl_balance_check) async with httpx.AsyncClient() as client: try: @@ -117,7 +123,7 @@ async def api_payments_create_invoice(data: CreateInvoiceData): "balanceNotify": url_for( "core.lnurl_balance_notify", service=urlparse(data.lnurl_callback).netloc, - wal=g.wallet.id, + wal=g().wallet.id, _external=True, ), }, @@ -217,14 +223,14 @@ async def api_payments_pay_lnurl(data: CreateLNURLData): if invoice.amount_msat != data.amount: return ( { - "message": f"{domain} returned an invalid invoice. Expected {g.data['amount']} msat, got {invoice.amount_msat}." + "message": f"{domain} returned an invalid invoice. Expected {g().data['amount']} msat, got {invoice.amount_msat}." }, HTTPStatus.BAD_REQUEST, ) - if invoice.description_hash != g.data["description_hash"]: + if invoice.description_hash != g().data["description_hash"]: return ( { - "message": f"{domain} returned an invalid invoice. Expected description_hash == {g.data['description_hash']}, got {invoice.description_hash}." + "message": f"{domain} returned an invalid invoice. Expected description_hash == {g().data['description_hash']}, got {invoice.description_hash}." }, HTTPStatus.BAD_REQUEST, ) @@ -237,7 +243,7 @@ async def api_payments_pay_lnurl(data: CreateLNURLData): extra["comment"] = data.comment payment_hash = await pay_invoice( - wallet_id=g.wallet.id, + wallet_id=g().wallet.id, payment_request=params["pr"], description=data.description, extra=extra, @@ -257,7 +263,7 @@ async def api_payments_pay_lnurl(data: CreateLNURLData): @core_app.get("/api/v1/payments/") @api_check_wallet_key("invoice") async def api_payment(payment_hash): - payment = await g.wallet.get_payment(payment_hash) + payment = await g().wallet.get_payment(payment_hash) if not payment: return {"message": "Payment does not exist."}, HTTPStatus.NOT_FOUND @@ -278,7 +284,7 @@ async def api_payment(payment_hash): @core_app.get("/api/v1/payments/sse") @api_check_wallet_key("invoice", accept_querystring=True) async def api_payments_sse(): - this_wallet_id = g.wallet.id + this_wallet_id = g().wallet.id send_payment, receive_payment = trio.open_memory_channel(0) @@ -356,7 +362,7 @@ async def api_lnurlscan(code: str): params.update(kind="auth") params.update(callback=url) # with k1 already in it - lnurlauth_key = g.wallet.lnurlauth_key(domain) + lnurlauth_key = g().wallet.lnurlauth_key(domain) params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex()) else: async with httpx.AsyncClient() as client: diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index c781fb92..cd997e3b 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -1,15 +1,10 @@ +from lnbits.requestvars import g from os import path from http import HTTPStatus -from quart import ( - g, - current_app, - abort, - request, - redirect, - render_template, - send_from_directory, - url_for, -) +from typing import Optional +import jinja2 + +from starlette.responses import HTMLResponse from lnbits.core import core_app, db from lnbits.decorators import check_user_exists, validate_uuids @@ -26,20 +21,18 @@ from ..crud import ( ) from ..services import redeem_lnurl_withdraw, pay_invoice from fastapi import FastAPI, Request -from fastapi.templating import Jinja2Templates +from fastapi.responses import FileResponse +from lnbits.jinja2_templating import Jinja2Templates -templates = Jinja2Templates(directory="templates") @core_app.get("/favicon.ico") async def favicon(): - return await send_from_directory( - path.join(core_app.root_path, "static"), "favicon.ico" - ) + return FileResponse("lnbits/core/static/favicon.ico") + - -@core_app.get("/") +@core_app.get("/", response_class=HTMLResponse) async def home(request: Request, lightning: str = None): - return templates.TemplateResponse("core/index.html", {"request": request, "lnurl": lightning}) + return g().templates.TemplateResponse("core/index.html", {"request": request, "lnurl": lightning}) @core_app.get("/extensions") diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 5d923c35..a5a270e1 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -7,7 +7,7 @@ from uuid import UUID from lnbits.core.crud import get_user, get_wallet_for_key from lnbits.settings import LNBITS_ALLOWED_USERS - +from lnbits.requestvars import g def api_check_wallet_key(key_type: str = "invoice", accept_querystring=False): def wrap(view): @@ -15,14 +15,14 @@ def api_check_wallet_key(key_type: str = "invoice", accept_querystring=False): async def wrapped_view(**kwargs): try: key_value = request.headers.get("X-Api-Key") or request.args["api-key"] - g.wallet = await get_wallet_for_key(key_value, key_type) + g().wallet = await get_wallet_for_key(key_value, key_type) except KeyError: return ( jsonify({"message": "`X-Api-Key` header missing."}), HTTPStatus.BAD_REQUEST, ) - if not g.wallet: + if not g().wallet: return jsonify({"message": "Wrong keys."}), HTTPStatus.UNAUTHORIZED return await view(**kwargs) @@ -44,9 +44,9 @@ def api_validate_post_request(*, schema: dict): v = Validator(schema) data = await request.get_json() - g.data = {key: data[key] for key in schema.keys() if key in data} + g().data = {key: data[key] for key in schema.keys() if key in data} - if not v.validate(g.data): + if not v.validate(g().data): return ( jsonify({"message": f"Errors in request data: {v.errors}"}), HTTPStatus.BAD_REQUEST, @@ -63,11 +63,11 @@ def check_user_exists(param: str = "usr"): def wrap(view): @wraps(view) async def wrapped_view(**kwargs): - g.user = await get_user(request.args.get(param, type=str)) or abort( + g().user = await get_user(request.args.get(param, type=str)) or abort( HTTPStatus.NOT_FOUND, "User does not exist." ) - if LNBITS_ALLOWED_USERS and g.user.id not in LNBITS_ALLOWED_USERS: + if LNBITS_ALLOWED_USERS and g().user.id not in LNBITS_ALLOWED_USERS: abort(HTTPStatus.UNAUTHORIZED, "User not authorized.") return await view(**kwargs) diff --git a/lnbits/extensions/offlineshop/__init__.py b/lnbits/extensions/offlineshop/__init__.py index 1f9dd123..bde90f3f 100644 --- a/lnbits/extensions/offlineshop/__init__.py +++ b/lnbits/extensions/offlineshop/__init__.py @@ -1,11 +1,12 @@ -from quart import Blueprint +from fastapi import APIRouter from lnbits.db import Database db = Database("ext_offlineshop") -offlineshop_ext: Blueprint = Blueprint( - "offlineshop", __name__, static_folder="static", template_folder="templates" +offlineshop_ext: APIRouter = APIRouter( + prefix="/Extension", + tags=["Apps", "Offlineshop"] ) diff --git a/lnbits/jinja2_templating.py b/lnbits/jinja2_templating.py new file mode 100644 index 00000000..f3303445 --- /dev/null +++ b/lnbits/jinja2_templating.py @@ -0,0 +1,36 @@ +# Borrowed from the excellent accent-starlette +# https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py + +import typing + +from starlette import templating +from starlette.datastructures import QueryParams + +from lnbits.requestvars import g + +try: + import jinja2 +except ImportError: # pragma: nocover + jinja2 = None # type: ignore + + +class Jinja2Templates(templating.Jinja2Templates): + def __init__(self, loader: jinja2.BaseLoader) -> None: + assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates" + self.env = self.get_environment(loader) + + def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment": + @jinja2.contextfunction + def url_for(context: dict, name: str, **path_params: typing.Any) -> str: + request = context["request"] + return request.url_for(name, **path_params) + + def url_params_update(init: QueryParams, **new: typing.Any) -> QueryParams: + values = dict(init) + values.update(new) + return QueryParams(**values) + + env = jinja2.Environment(loader=loader, autoescape=True) + env.globals["url_for"] = url_for + env.globals["url_params_update"] = url_params_update + return env diff --git a/lnbits/requestvars.py b/lnbits/requestvars.py new file mode 100644 index 00000000..7dcf9203 --- /dev/null +++ b/lnbits/requestvars.py @@ -0,0 +1,9 @@ +import contextvars +import types + +request_global = contextvars.ContextVar("request_global", + default=types.SimpleNamespace()) + + +def g() -> types.SimpleNamespace: + return request_global.get() diff --git a/lnbits/templates/base.html b/lnbits/templates/base.html index aa673bd1..b49cb3b9 100644 --- a/lnbits/templates/base.html +++ b/lnbits/templates/base.html @@ -2,7 +2,7 @@ - {% for url in g.VENDORED_CSS %} + {% for url in VENDORED_CSS %} {% endfor %} @@ -184,7 +184,7 @@ {% block vue_templates %}{% endblock %} - {% for url in g.VENDORED_JS %} + {% for url in VENDORED_JS %} {% endfor %} diff --git a/lnbits/templates/print.html b/lnbits/templates/print.html index 3b0d0782..0cfc6c64 100644 --- a/lnbits/templates/print.html +++ b/lnbits/templates/print.html @@ -2,7 +2,7 @@ - {% for url in g.VENDORED_CSS %} + {% for url in VENDORED_CSS %} {% endfor %}