diff --git a/lnbits/app.py b/lnbits/app.py index 3eff6451..6161d764 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -32,18 +32,14 @@ from .core import ( ) from .core.services import check_admin_settings from .core.views.generic import core_html_routes -from .extension_manager import ( - Extension, - InstallableExtension, - InstalledExtensionMiddleware, - get_valid_extensions, -) +from .extension_manager import Extension, InstallableExtension, get_valid_extensions from .helpers import ( get_css_vendored, get_js_vendored, template_renderer, url_for_vendored, ) +from .middleware import ExtensionsRedirectMiddleware, InstalledExtensionMiddleware from .requestvars import g from .tasks import ( catch_everything_and_restart, @@ -81,7 +77,10 @@ def create_app() -> FastAPI: ) app.add_middleware(GZipMiddleware, minimum_size=1000) + + # order of these two middlewares is important app.add_middleware(InstalledExtensionMiddleware) + app.add_middleware(ExtensionsRedirectMiddleware) register_startup(app) register_assets(app) @@ -240,6 +239,15 @@ def register_ext_routes(app: FastAPI, ext: Extension) -> None: for s in ext_statics: app.mount(s["path"], s["app"], s["name"]) + if hasattr(ext_module, f"{ext.code}_redirect_paths"): + ext_redirects = getattr(ext_module, f"{ext.code}_redirect_paths") + settings.lnbits_extensions_redirects = [ + r for r in settings.lnbits_extensions_redirects if r["ext_id"] != ext.code + ] + for r in ext_redirects: + r["ext_id"] = ext.code + settings.lnbits_extensions_redirects.append(r) + logger.trace(f"adding route for extension {ext_module}") prefix = f"/upgrades/{ext.upgrade_hash}" if ext.upgrade_hash != "" else "" diff --git a/lnbits/extension_manager.py b/lnbits/extension_manager.py index b08b6b80..9f34d181 100644 --- a/lnbits/extension_manager.py +++ b/lnbits/extension_manager.py @@ -11,10 +11,8 @@ from typing import Any, List, NamedTuple, Optional, Tuple import httpx from fastapi.exceptions import HTTPException -from fastapi.responses import JSONResponse from loguru import logger from pydantic import BaseModel -from starlette.types import ASGIApp, Receive, Scope, Send from lnbits.settings import settings @@ -461,51 +459,6 @@ class InstallableExtension(BaseModel): return selected_release[0] if len(selected_release) != 0 else None -class InstalledExtensionMiddleware: - # This middleware class intercepts calls made to the extensions API and: - # - it blocks the calls if the extension has been disabled or uninstalled. - # - it redirects the calls to the latest version of the extension if the extension has been upgraded. - # - otherwise it has no effect - def __init__(self, app: ASGIApp) -> None: - self.app = app - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if "path" not in scope: - await self.app(scope, receive, send) - return - - path_elements = scope["path"].split("/") - if len(path_elements) > 2: - _, path_name, path_type, *rest = path_elements - else: - _, path_name = path_elements - path_type = None - - # block path for all users if the extension is disabled - if path_name in settings.lnbits_deactivated_extensions: - response = JSONResponse( - status_code=HTTPStatus.NOT_FOUND, - content={"detail": f"Extension '{path_name}' disabled"}, - ) - await response(scope, receive, send) - return - - # re-route API trafic if the extension has been upgraded - if path_type == "api": - upgraded_extensions = list( - filter( - lambda ext: ext.endswith(f"/{path_name}"), - settings.lnbits_upgraded_extensions, - ) - ) - if len(upgraded_extensions) != 0: - upgrade_path = upgraded_extensions[0] - tail = "/".join(rest) - scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}" - - await self.app(scope, receive, send) - - class CreateExtension(BaseModel): ext_id: str archive: str diff --git a/lnbits/middleware.py b/lnbits/middleware.py new file mode 100644 index 00000000..daac03bf --- /dev/null +++ b/lnbits/middleware.py @@ -0,0 +1,133 @@ +from http import HTTPStatus +from typing import List, Tuple + +from fastapi.responses import JSONResponse +from starlette.types import ASGIApp, Receive, Scope, Send + +from lnbits.settings import settings + + +class InstalledExtensionMiddleware: + # This middleware class intercepts calls made to the extensions API and: + # - it blocks the calls if the extension has been disabled or uninstalled. + # - it redirects the calls to the latest version of the extension if the extension has been upgraded. + # - otherwise it has no effect + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if "path" not in scope: + await self.app(scope, receive, send) + return + + path_elements = scope["path"].split("/") + if len(path_elements) > 2: + _, path_name, path_type, *rest = path_elements + else: + _, path_name = path_elements + path_type = None + + # block path for all users if the extension is disabled + if path_name in settings.lnbits_deactivated_extensions: + response = JSONResponse( + status_code=HTTPStatus.NOT_FOUND, + content={"detail": f"Extension '{path_name}' disabled"}, + ) + await response(scope, receive, send) + return + + # re-route API trafic if the extension has been upgraded + if path_type == "api": + upgraded_extensions = list( + filter( + lambda ext: ext.endswith(f"/{path_name}"), + settings.lnbits_upgraded_extensions, + ) + ) + if len(upgraded_extensions) != 0: + upgrade_path = upgraded_extensions[0] + tail = "/".join(rest) + scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}" + + await self.app(scope, receive, send) + + +class ExtensionsRedirectMiddleware: + # Extensions are allowed to specify redirect paths. + # A call to a path outside the scope of the extension can be redirected to one of the extension's endpoints. + # Eg: redirect `GET /.well-known` to `GET /lnurlp/api/v1/well-known` + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if "path" not in scope: + await self.app(scope, receive, send) + return + + req_headers = scope["headers"] if "headers" in scope else [] + redirect = self._find_redirect(scope["path"], req_headers) + if redirect: + scope["path"] = self._new_path(redirect, scope["path"]) + + await self.app(scope, receive, send) + + def _find_redirect(self, path: str, req_headers: List[Tuple[bytes, bytes]]): + return next( + ( + r + for r in settings.lnbits_extensions_redirects + if self._redirect_matches(r, path, req_headers) + ), + None, + ) + + def _redirect_matches( + self, redirect: dict, path: str, req_headers: List[Tuple[bytes, bytes]] + ) -> bool: + if "from_path" not in redirect: + return False + header_filters = ( + redirect["header_filters"] if "header_filters" in redirect else [] + ) + return self._has_common_path(redirect["from_path"], path) and self._has_headers( + header_filters, req_headers + ) + + def _has_headers( + self, filter_headers: dict, req_headers: List[Tuple[bytes, bytes]] + ) -> bool: + for h in filter_headers: + if not self._has_header(req_headers, (str(h), str(filter_headers[h]))): + return False + return True + + def _has_header( + self, req_headers: List[Tuple[bytes, bytes]], header: Tuple[str, str] + ) -> bool: + for h in req_headers: + if ( + h[0].decode().lower() == header[0].lower() + and h[1].decode() == header[1] + ): + return True + return False + + def _has_common_path(self, redirect_path: str, req_path: str) -> bool: + redirect_path_elements = redirect_path.split("/") + req_path_elements = req_path.split("/") + if len(redirect_path) > len(req_path): + return False + sub_path = req_path_elements[: len(redirect_path_elements)] + return redirect_path == "/".join(sub_path) + + def _new_path(self, redirect: dict, req_path: str) -> str: + from_path = redirect["from_path"].split("/") + redirect_to = redirect["redirect_to_path"].split("/") + req_tail_path = req_path.split("/")[len(from_path) :] + + elements = [ + e for e in ([redirect["ext_id"]] + redirect_to + req_tail_path) if e != "" + ] + + return "/" + "/".join(elements) diff --git a/lnbits/settings.py b/lnbits/settings.py index 30dd6f62..0f103682 100644 --- a/lnbits/settings.py +++ b/lnbits/settings.py @@ -4,7 +4,7 @@ import json import subprocess from os import path from sqlite3 import Row -from typing import List, Optional +from typing import Any, List, Optional import httpx from loguru import logger @@ -59,6 +59,8 @@ class InstalledExtensionsSettings(LNbitsSettings): lnbits_deactivated_extensions: List[str] = Field(default=[]) # upgraded extensions that require API redirects lnbits_upgraded_extensions: List[str] = Field(default=[]) + # list of redirects that extensions want to perform + lnbits_extensions_redirects: List[Any] = Field(default=[]) class ThemesSettings(LNbitsSettings): @@ -264,7 +266,7 @@ class SuperUserSettings(LNbitsSettings): class TransientSettings(InstalledExtensionsSettings): # Transient Settings: # - are initialized, updated and used at runtime - # - are not read from a file or from the `setings` table + # - are not read from a file or from the `settings` table # - are not persisted in the `settings` table when the settings are updated # - are cleared on server restart