From 938fc54af3479996a95de08e0553b02f93c56ca9 Mon Sep 17 00:00:00 2001 From: Stefan Stammberger Date: Sun, 22 Aug 2021 20:07:24 +0200 Subject: [PATCH] feat: switch from Quart to FastAPI part I --- lnbits/__main__.py | 11 +- lnbits/app.py | 123 ++++++++++++++-------- lnbits/auth_bearer.py | 49 +++++++++ lnbits/core/__init__.py | 27 +++-- lnbits/core/views/api.py | 46 ++++---- lnbits/core/views/generic.py | 29 ++--- lnbits/decorators.py | 14 +-- lnbits/extensions/offlineshop/__init__.py | 7 +- lnbits/jinja2_templating.py | 36 +++++++ lnbits/requestvars.py | 9 ++ lnbits/templates/base.html | 4 +- lnbits/templates/print.html | 4 +- 12 files changed, 245 insertions(+), 114 deletions(-) create mode 100644 lnbits/auth_bearer.py create mode 100644 lnbits/jinja2_templating.py create mode 100644 lnbits/requestvars.py diff --git a/lnbits/__main__.py b/lnbits/__main__.py index 90b08642..bd564787 100644 --- a/lnbits/__main__.py +++ b/lnbits/__main__.py @@ -1,4 +1,7 @@ +from hypercorn.trio import serve import trio +import trio_asyncio +from hypercorn.config import Config from .commands import migrate_databases, transpile_scss, bundle_vendored @@ -8,7 +11,7 @@ bundle_vendored() from .app import create_app -app = create_app() +app = trio.run(create_app) from .settings import ( LNBITS_SITE_TITLE, @@ -17,6 +20,8 @@ from .settings import ( LNBITS_DATA_FOLDER, WALLET, LNBITS_COMMIT, + HOST, + PORT ) print( @@ -30,4 +35,6 @@ print( """ ) -app.run(host=app.config["HOST"], port=app.config["PORT"]) +config = Config() +config.bind = [f"{HOST}:{PORT}"] +trio_asyncio.run(serve, app, config) diff --git a/lnbits/app.py b/lnbits/app.py index 5da9a195..a686d5f7 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -1,12 +1,15 @@ +import jinja2 +from lnbits.jinja2_templating import Jinja2Templates import sys import warnings import importlib import traceback +import trio -from quart import g -from quart_trio import QuartTrio -from quart_cors import cors # type: ignore -from quart_compress import Compress # type: ignore +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.gzip import GZipMiddleware +from fastapi.staticfiles import StaticFiles from .commands import db_migrate, handle_assets from .core import core_app @@ -26,32 +29,66 @@ from .tasks import ( catch_everything_and_restart, ) from .settings import WALLET +from .requestvars import g, request_global +import lnbits.settings - -def create_app(config_object="lnbits.settings") -> QuartTrio: +async def create_app(config_object="lnbits.settings") -> FastAPI: """Create application factory. :param config_object: The configuration object to use. """ - app = QuartTrio(__name__, static_folder="static") - app.config.from_object(config_object) - app.asgi_http_class = ASGIProxyFix + app = FastAPI() + app.mount("/static", StaticFiles(directory="lnbits/static"), name="static") - cors(app) - Compress(app) + origins = [ + "http://localhost", + "http://localhost:5000", + ] + + app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + g().config = lnbits.settings + g().templates = build_standard_jinja_templates() + + app.add_middleware(GZipMiddleware, minimum_size=1000) + # app.add_middleware(ASGIProxyFix) check_funding_source(app) register_assets(app) - register_blueprints(app) - register_filters(app) - register_commands(app) + register_routes(app) + # register_commands(app) register_async_tasks(app) - register_exception_handlers(app) + # register_exception_handlers(app) return app +def build_standard_jinja_templates(): + t = Jinja2Templates( + loader=jinja2.FileSystemLoader(["lnbits/templates", "lnbits/core/templates"]), + ) + t.env.globals["SITE_TITLE"] = lnbits.settings.LNBITS_SITE_TITLE + t.env.globals["SITE_TAGLINE"] = lnbits.settings.LNBITS_SITE_TAGLINE + t.env.globals["SITE_DESCRIPTION"] = lnbits.settings.LNBITS_SITE_DESCRIPTION + t.env.globals["LNBITS_THEME_OPTIONS"] = lnbits.settings.LNBITS_THEME_OPTIONS + t.env.globals["LNBITS_VERSION"] = lnbits.settings.LNBITS_COMMIT + t.env.globals["EXTENSIONS"] = get_valid_extensions() + + if g().config.DEBUG: + t.env.globals["VENDORED_JS"] = map(url_for_vendored, get_js_vendored()) + t.env.globals["VENDORED_CSS"] = map(url_for_vendored, get_css_vendored()) + else: + t.env.globals["VENDORED_JS"] = ["/static/bundle.js"] + t.env.globals["VENDORED_CSS"] = ["/static/bundle.css"] -def check_funding_source(app: QuartTrio) -> None: - @app.before_serving + return t + +def check_funding_source(app: FastAPI) -> None: + @app.on_event("startup") async def check_wallet_status(): error_message, balance = await WALLET.status() if error_message: @@ -67,64 +104,60 @@ def check_funding_source(app: QuartTrio) -> None: ) -def register_blueprints(app: QuartTrio) -> None: +def register_routes(app: FastAPI) -> None: """Register Flask blueprints / LNbits extensions.""" - app.register_blueprint(core_app) + app.include_router(core_app) for ext in get_valid_extensions(): try: ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}") - bp = getattr(ext_module, f"{ext.code}_ext") + ext_route = getattr(ext_module, f"{ext.code}_ext") - app.register_blueprint(bp, url_prefix=f"/{ext.code}") + app.include_router(ext_route) except Exception: raise ImportError( f"Please make sure that the extension `{ext.code}` follows conventions." ) -def register_commands(app: QuartTrio): +def register_commands(app: FastAPI): """Register Click commands.""" app.cli.add_command(db_migrate) app.cli.add_command(handle_assets) -def register_assets(app: QuartTrio): +def register_assets(app: FastAPI): """Serve each vendored asset separately or a bundle.""" - @app.before_request + @app.on_event("startup") async def vendored_assets_variable(): - if app.config["DEBUG"]: - g.VENDORED_JS = map(url_for_vendored, get_js_vendored()) - g.VENDORED_CSS = map(url_for_vendored, get_css_vendored()) + if g().config.DEBUG: + g().VENDORED_JS = map(url_for_vendored, get_js_vendored()) + g().VENDORED_CSS = map(url_for_vendored, get_css_vendored()) else: - g.VENDORED_JS = ["/static/bundle.js"] - g.VENDORED_CSS = ["/static/bundle.css"] - - -def register_filters(app: QuartTrio): - """Jinja filters.""" - app.jinja_env.globals["SITE_TITLE"] = app.config["LNBITS_SITE_TITLE"] - app.jinja_env.globals["SITE_TAGLINE"] = app.config["LNBITS_SITE_TAGLINE"] - app.jinja_env.globals["SITE_DESCRIPTION"] = app.config["LNBITS_SITE_DESCRIPTION"] - app.jinja_env.globals["LNBITS_THEME_OPTIONS"] = app.config["LNBITS_THEME_OPTIONS"] - app.jinja_env.globals["LNBITS_VERSION"] = app.config["LNBITS_COMMIT"] - app.jinja_env.globals["EXTENSIONS"] = get_valid_extensions() + g().VENDORED_JS = ["/static/bundle.js"] + g().VENDORED_CSS = ["/static/bundle.css"] def register_async_tasks(app): - @app.route("/wallet/webhook", methods=["GET", "POST", "PUT", "PATCH", "DELETE"]) + @app.route("/wallet/webhook") async def webhook_listener(): return await webhook_handler() - @app.before_serving + @app.on_event("startup") async def listeners(): run_deferred_async() - app.nursery.start_soon(catch_everything_and_restart, check_pending_payments) - app.nursery.start_soon(catch_everything_and_restart, invoice_listener) - app.nursery.start_soon(catch_everything_and_restart, internal_invoice_listener) + trio.open_process(check_pending_payments) + trio.open_process(invoice_listener) + trio.open_process(internal_invoice_listener) + + async with trio.open_nursery() as n: + pass + # n.start_soon(catch_everything_and_restart, check_pending_payments) + # n.start_soon(catch_everything_and_restart, invoice_listener) + # n.start_soon(catch_everything_and_restart, internal_invoice_listener) - @app.after_serving + @app.on_event("shutdown") async def stop_listeners(): pass diff --git a/lnbits/auth_bearer.py b/lnbits/auth_bearer.py new file mode 100644 index 00000000..81b93427 --- /dev/null +++ b/lnbits/auth_bearer.py @@ -0,0 +1,49 @@ +from fastapi import Request, HTTPException +from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey + +# https://medium.com/data-rebels/fastapi-authentication-revisited-enabling-api-key-authentication-122dc5975680 + +from fastapi import Security, Depends, FastAPI, HTTPException +from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey +from fastapi.security.base import SecurityBase + + + +API_KEY = "usr" +API_KEY_NAME = "X-API-key" + +api_key_query = APIKeyQuery(name=API_KEY_NAME, auto_error=False) +api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) + + + +class AuthBearer(SecurityBase): + def __init__(self, scheme_name: str = None, auto_error: bool = True): + self.scheme_name = scheme_name or self.__class__.__name__ + self.auto_error = auto_error + + async def __call__(self, request: Request): + key = await self.get_api_key() + print(key) + # credentials: HTTPAuthorizationCredentials = await super(AuthBearer, self).__call__(request) + # if credentials: + # if not credentials.scheme == "Bearer": + # raise HTTPException( + # status_code=403, detail="Invalid authentication scheme.") + # if not self.verify_jwt(credentials.credentials): + # raise HTTPException( + # status_code=403, detail="Invalid token or expired token.") + # return credentials.credentials + # else: + # raise HTTPException( + # status_code=403, detail="Invalid authorization code.") + async def get_api_key(self, + api_key_query: str = Security(api_key_query), + api_key_header: str = Security(api_key_header), + ): + if api_key_query == API_KEY: + return api_key_query + elif api_key_header == API_KEY: + return api_key_header + else: + raise HTTPException(status_code=403, detail="Could not validate credentials") \ No newline at end of file diff --git a/lnbits/core/__init__.py b/lnbits/core/__init__.py index 12dcded8..d988d573 100644 --- a/lnbits/core/__init__.py +++ b/lnbits/core/__init__.py @@ -1,22 +1,19 @@ -from quart import Blueprint +from fastapi.routing import APIRouter + from lnbits.db import Database db = Database("database") -core_app: Blueprint = Blueprint( - "core", - __name__, - template_folder="templates", - static_folder="static", - static_url_path="/core/static", -) - - -from .views.api import * # noqa -from .views.generic import * # noqa -from .views.public_api import * # noqa -from .tasks import register_listeners +core_app: APIRouter = APIRouter() from lnbits.tasks import record_async -core_app.record(record_async(register_listeners)) +from .tasks import register_listeners +from .views.api import * # noqa +from .views.generic import * # noqa +from .views.public_api import * # noqa + + +@core_app.on_event("startup") +def do_startup(): + record_async(register_listeners) diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index b47ba51e..c266d72d 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -1,10 +1,12 @@ +from fastapi.param_functions import Depends +from lnbits.auth_bearer import AuthBearer from pydantic import BaseModel import trio import json import httpx import hashlib from urllib.parse import urlparse, urlunparse, urlencode, parse_qs, ParseResult -from quart import g, current_app, make_response, url_for +from quart import current_app, make_response, url_for from fastapi import Query @@ -15,6 +17,7 @@ from typing import Dict, List, Optional, Union from lnbits import bolt11, lnurl from lnbits.decorators import api_check_wallet_key, api_validate_post_request from lnbits.utils.exchange_rates import currencies, fiat_amount_as_satoshis +from lnbits.requestvars import g from .. import core_app, db from ..crud import get_payments, save_balance_check, update_wallet @@ -28,11 +31,14 @@ from ..services import ( from ..tasks import api_invoice_listeners -@core_app.get("/api/v1/wallet") -@api_check_wallet_key("invoice") +@core_app.get( + "/api/v1/wallet", + # dependencies=[Depends(AuthBearer())] +) +# @api_check_wallet_key("invoice") async def api_wallet(): return ( - {"id": g.wallet.id, "name": g.wallet.name, "balance": g.wallet.balance_msat}, + {"id": g().wallet.id, "name": g().wallet.name, "balance": g().wallet.balance_msat}, HTTPStatus.OK, ) @@ -40,12 +46,12 @@ async def api_wallet(): @core_app.put("/api/v1/wallet/") @api_check_wallet_key("invoice") async def api_update_wallet(new_name: str): - await update_wallet(g.wallet.id, new_name) + await update_wallet(g().wallet.id, new_name) return ( { - "id": g.wallet.id, - "name": g.wallet.name, - "balance": g.wallet.balance_msat, + "id": g().wallet.id, + "name": g().wallet.name, + "balance": g().wallet.balance_msat, }, HTTPStatus.OK, ) @@ -55,7 +61,7 @@ async def api_update_wallet(new_name: str): @api_check_wallet_key("invoice") async def api_payments(): return ( - await get_payments(wallet_id=g.wallet.id, pending=True, complete=True), + await get_payments(wallet_id=g().wallet.id, pending=True, complete=True), HTTPStatus.OK, ) @@ -88,7 +94,7 @@ async def api_payments_create_invoice(data: CreateInvoiceData): async with db.connect() as conn: try: payment_hash, payment_request = await create_invoice( - wallet_id=g.wallet.id, + wallet_id=g().wallet.id, amount=amount, memo=memo, description_hash=description_hash, @@ -105,8 +111,8 @@ async def api_payments_create_invoice(data: CreateInvoiceData): lnurl_response: Union[None, bool, str] = None if data.lnurl_callback: - if "lnurl_balance_check" in g.data: - save_balance_check(g.wallet.id, data.lnurl_balance_check) + if "lnurl_balance_check" in g().data: + save_balance_check(g().wallet.id, data.lnurl_balance_check) async with httpx.AsyncClient() as client: try: @@ -117,7 +123,7 @@ async def api_payments_create_invoice(data: CreateInvoiceData): "balanceNotify": url_for( "core.lnurl_balance_notify", service=urlparse(data.lnurl_callback).netloc, - wal=g.wallet.id, + wal=g().wallet.id, _external=True, ), }, @@ -217,14 +223,14 @@ async def api_payments_pay_lnurl(data: CreateLNURLData): if invoice.amount_msat != data.amount: return ( { - "message": f"{domain} returned an invalid invoice. Expected {g.data['amount']} msat, got {invoice.amount_msat}." + "message": f"{domain} returned an invalid invoice. Expected {g().data['amount']} msat, got {invoice.amount_msat}." }, HTTPStatus.BAD_REQUEST, ) - if invoice.description_hash != g.data["description_hash"]: + if invoice.description_hash != g().data["description_hash"]: return ( { - "message": f"{domain} returned an invalid invoice. Expected description_hash == {g.data['description_hash']}, got {invoice.description_hash}." + "message": f"{domain} returned an invalid invoice. Expected description_hash == {g().data['description_hash']}, got {invoice.description_hash}." }, HTTPStatus.BAD_REQUEST, ) @@ -237,7 +243,7 @@ async def api_payments_pay_lnurl(data: CreateLNURLData): extra["comment"] = data.comment payment_hash = await pay_invoice( - wallet_id=g.wallet.id, + wallet_id=g().wallet.id, payment_request=params["pr"], description=data.description, extra=extra, @@ -257,7 +263,7 @@ async def api_payments_pay_lnurl(data: CreateLNURLData): @core_app.get("/api/v1/payments/") @api_check_wallet_key("invoice") async def api_payment(payment_hash): - payment = await g.wallet.get_payment(payment_hash) + payment = await g().wallet.get_payment(payment_hash) if not payment: return {"message": "Payment does not exist."}, HTTPStatus.NOT_FOUND @@ -278,7 +284,7 @@ async def api_payment(payment_hash): @core_app.get("/api/v1/payments/sse") @api_check_wallet_key("invoice", accept_querystring=True) async def api_payments_sse(): - this_wallet_id = g.wallet.id + this_wallet_id = g().wallet.id send_payment, receive_payment = trio.open_memory_channel(0) @@ -356,7 +362,7 @@ async def api_lnurlscan(code: str): params.update(kind="auth") params.update(callback=url) # with k1 already in it - lnurlauth_key = g.wallet.lnurlauth_key(domain) + lnurlauth_key = g().wallet.lnurlauth_key(domain) params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex()) else: async with httpx.AsyncClient() as client: diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index c781fb92..cd997e3b 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -1,15 +1,10 @@ +from lnbits.requestvars import g from os import path from http import HTTPStatus -from quart import ( - g, - current_app, - abort, - request, - redirect, - render_template, - send_from_directory, - url_for, -) +from typing import Optional +import jinja2 + +from starlette.responses import HTMLResponse from lnbits.core import core_app, db from lnbits.decorators import check_user_exists, validate_uuids @@ -26,20 +21,18 @@ from ..crud import ( ) from ..services import redeem_lnurl_withdraw, pay_invoice from fastapi import FastAPI, Request -from fastapi.templating import Jinja2Templates +from fastapi.responses import FileResponse +from lnbits.jinja2_templating import Jinja2Templates -templates = Jinja2Templates(directory="templates") @core_app.get("/favicon.ico") async def favicon(): - return await send_from_directory( - path.join(core_app.root_path, "static"), "favicon.ico" - ) + return FileResponse("lnbits/core/static/favicon.ico") + - -@core_app.get("/") +@core_app.get("/", response_class=HTMLResponse) async def home(request: Request, lightning: str = None): - return templates.TemplateResponse("core/index.html", {"request": request, "lnurl": lightning}) + return g().templates.TemplateResponse("core/index.html", {"request": request, "lnurl": lightning}) @core_app.get("/extensions") diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 5d923c35..a5a270e1 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -7,7 +7,7 @@ from uuid import UUID from lnbits.core.crud import get_user, get_wallet_for_key from lnbits.settings import LNBITS_ALLOWED_USERS - +from lnbits.requestvars import g def api_check_wallet_key(key_type: str = "invoice", accept_querystring=False): def wrap(view): @@ -15,14 +15,14 @@ def api_check_wallet_key(key_type: str = "invoice", accept_querystring=False): async def wrapped_view(**kwargs): try: key_value = request.headers.get("X-Api-Key") or request.args["api-key"] - g.wallet = await get_wallet_for_key(key_value, key_type) + g().wallet = await get_wallet_for_key(key_value, key_type) except KeyError: return ( jsonify({"message": "`X-Api-Key` header missing."}), HTTPStatus.BAD_REQUEST, ) - if not g.wallet: + if not g().wallet: return jsonify({"message": "Wrong keys."}), HTTPStatus.UNAUTHORIZED return await view(**kwargs) @@ -44,9 +44,9 @@ def api_validate_post_request(*, schema: dict): v = Validator(schema) data = await request.get_json() - g.data = {key: data[key] for key in schema.keys() if key in data} + g().data = {key: data[key] for key in schema.keys() if key in data} - if not v.validate(g.data): + if not v.validate(g().data): return ( jsonify({"message": f"Errors in request data: {v.errors}"}), HTTPStatus.BAD_REQUEST, @@ -63,11 +63,11 @@ def check_user_exists(param: str = "usr"): def wrap(view): @wraps(view) async def wrapped_view(**kwargs): - g.user = await get_user(request.args.get(param, type=str)) or abort( + g().user = await get_user(request.args.get(param, type=str)) or abort( HTTPStatus.NOT_FOUND, "User does not exist." ) - if LNBITS_ALLOWED_USERS and g.user.id not in LNBITS_ALLOWED_USERS: + if LNBITS_ALLOWED_USERS and g().user.id not in LNBITS_ALLOWED_USERS: abort(HTTPStatus.UNAUTHORIZED, "User not authorized.") return await view(**kwargs) diff --git a/lnbits/extensions/offlineshop/__init__.py b/lnbits/extensions/offlineshop/__init__.py index 1f9dd123..bde90f3f 100644 --- a/lnbits/extensions/offlineshop/__init__.py +++ b/lnbits/extensions/offlineshop/__init__.py @@ -1,11 +1,12 @@ -from quart import Blueprint +from fastapi import APIRouter from lnbits.db import Database db = Database("ext_offlineshop") -offlineshop_ext: Blueprint = Blueprint( - "offlineshop", __name__, static_folder="static", template_folder="templates" +offlineshop_ext: APIRouter = APIRouter( + prefix="/Extension", + tags=["Apps", "Offlineshop"] ) diff --git a/lnbits/jinja2_templating.py b/lnbits/jinja2_templating.py new file mode 100644 index 00000000..f3303445 --- /dev/null +++ b/lnbits/jinja2_templating.py @@ -0,0 +1,36 @@ +# Borrowed from the excellent accent-starlette +# https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py + +import typing + +from starlette import templating +from starlette.datastructures import QueryParams + +from lnbits.requestvars import g + +try: + import jinja2 +except ImportError: # pragma: nocover + jinja2 = None # type: ignore + + +class Jinja2Templates(templating.Jinja2Templates): + def __init__(self, loader: jinja2.BaseLoader) -> None: + assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates" + self.env = self.get_environment(loader) + + def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment": + @jinja2.contextfunction + def url_for(context: dict, name: str, **path_params: typing.Any) -> str: + request = context["request"] + return request.url_for(name, **path_params) + + def url_params_update(init: QueryParams, **new: typing.Any) -> QueryParams: + values = dict(init) + values.update(new) + return QueryParams(**values) + + env = jinja2.Environment(loader=loader, autoescape=True) + env.globals["url_for"] = url_for + env.globals["url_params_update"] = url_params_update + return env diff --git a/lnbits/requestvars.py b/lnbits/requestvars.py new file mode 100644 index 00000000..7dcf9203 --- /dev/null +++ b/lnbits/requestvars.py @@ -0,0 +1,9 @@ +import contextvars +import types + +request_global = contextvars.ContextVar("request_global", + default=types.SimpleNamespace()) + + +def g() -> types.SimpleNamespace: + return request_global.get() diff --git a/lnbits/templates/base.html b/lnbits/templates/base.html index aa673bd1..b49cb3b9 100644 --- a/lnbits/templates/base.html +++ b/lnbits/templates/base.html @@ -2,7 +2,7 @@ - {% for url in g.VENDORED_CSS %} + {% for url in VENDORED_CSS %} {% endfor %} @@ -184,7 +184,7 @@ {% block vue_templates %}{% endblock %} - {% for url in g.VENDORED_JS %} + {% for url in VENDORED_JS %} {% endfor %} diff --git a/lnbits/templates/print.html b/lnbits/templates/print.html index 3b0d0782..0cfc6c64 100644 --- a/lnbits/templates/print.html +++ b/lnbits/templates/print.html @@ -2,7 +2,7 @@ - {% for url in g.VENDORED_CSS %} + {% for url in VENDORED_CSS %} {% endfor %}