Add redirect paths for extensions (#1532)

* feat: basic redirect to extension endpoint

* feat: filter by headers

* refactor: extract `middleware.py`

* fix: do not add twice the same extension to redirects

* chore: code clean-up
This commit is contained in:
Vlad Stan 2023-02-22 11:12:16 +02:00 committed by GitHub
parent 84e369aad2
commit 0d5fef1cb9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 151 additions and 55 deletions

View file

@ -32,18 +32,14 @@ from .core import (
) )
from .core.services import check_admin_settings from .core.services import check_admin_settings
from .core.views.generic import core_html_routes from .core.views.generic import core_html_routes
from .extension_manager import ( from .extension_manager import Extension, InstallableExtension, get_valid_extensions
Extension,
InstallableExtension,
InstalledExtensionMiddleware,
get_valid_extensions,
)
from .helpers import ( from .helpers import (
get_css_vendored, get_css_vendored,
get_js_vendored, get_js_vendored,
template_renderer, template_renderer,
url_for_vendored, url_for_vendored,
) )
from .middleware import ExtensionsRedirectMiddleware, InstalledExtensionMiddleware
from .requestvars import g from .requestvars import g
from .tasks import ( from .tasks import (
catch_everything_and_restart, catch_everything_and_restart,
@ -81,7 +77,10 @@ def create_app() -> FastAPI:
) )
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
# order of these two middlewares is important
app.add_middleware(InstalledExtensionMiddleware) app.add_middleware(InstalledExtensionMiddleware)
app.add_middleware(ExtensionsRedirectMiddleware)
register_startup(app) register_startup(app)
register_assets(app) register_assets(app)
@ -240,6 +239,15 @@ def register_ext_routes(app: FastAPI, ext: Extension) -> None:
for s in ext_statics: for s in ext_statics:
app.mount(s["path"], s["app"], s["name"]) app.mount(s["path"], s["app"], s["name"])
if hasattr(ext_module, f"{ext.code}_redirect_paths"):
ext_redirects = getattr(ext_module, f"{ext.code}_redirect_paths")
settings.lnbits_extensions_redirects = [
r for r in settings.lnbits_extensions_redirects if r["ext_id"] != ext.code
]
for r in ext_redirects:
r["ext_id"] = ext.code
settings.lnbits_extensions_redirects.append(r)
logger.trace(f"adding route for extension {ext_module}") logger.trace(f"adding route for extension {ext_module}")
prefix = f"/upgrades/{ext.upgrade_hash}" if ext.upgrade_hash != "" else "" prefix = f"/upgrades/{ext.upgrade_hash}" if ext.upgrade_hash != "" else ""

View file

@ -11,10 +11,8 @@ from typing import Any, List, NamedTuple, Optional, Tuple
import httpx import httpx
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from fastapi.responses import JSONResponse
from loguru import logger from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
from starlette.types import ASGIApp, Receive, Scope, Send
from lnbits.settings import settings from lnbits.settings import settings
@ -461,51 +459,6 @@ class InstallableExtension(BaseModel):
return selected_release[0] if len(selected_release) != 0 else None return selected_release[0] if len(selected_release) != 0 else None
class InstalledExtensionMiddleware:
# This middleware class intercepts calls made to the extensions API and:
# - it blocks the calls if the extension has been disabled or uninstalled.
# - it redirects the calls to the latest version of the extension if the extension has been upgraded.
# - otherwise it has no effect
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if "path" not in scope:
await self.app(scope, receive, send)
return
path_elements = scope["path"].split("/")
if len(path_elements) > 2:
_, path_name, path_type, *rest = path_elements
else:
_, path_name = path_elements
path_type = None
# block path for all users if the extension is disabled
if path_name in settings.lnbits_deactivated_extensions:
response = JSONResponse(
status_code=HTTPStatus.NOT_FOUND,
content={"detail": f"Extension '{path_name}' disabled"},
)
await response(scope, receive, send)
return
# re-route API trafic if the extension has been upgraded
if path_type == "api":
upgraded_extensions = list(
filter(
lambda ext: ext.endswith(f"/{path_name}"),
settings.lnbits_upgraded_extensions,
)
)
if len(upgraded_extensions) != 0:
upgrade_path = upgraded_extensions[0]
tail = "/".join(rest)
scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}"
await self.app(scope, receive, send)
class CreateExtension(BaseModel): class CreateExtension(BaseModel):
ext_id: str ext_id: str
archive: str archive: str

133
lnbits/middleware.py Normal file
View file

@ -0,0 +1,133 @@
from http import HTTPStatus
from typing import List, Tuple
from fastapi.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from lnbits.settings import settings
class InstalledExtensionMiddleware:
# This middleware class intercepts calls made to the extensions API and:
# - it blocks the calls if the extension has been disabled or uninstalled.
# - it redirects the calls to the latest version of the extension if the extension has been upgraded.
# - otherwise it has no effect
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if "path" not in scope:
await self.app(scope, receive, send)
return
path_elements = scope["path"].split("/")
if len(path_elements) > 2:
_, path_name, path_type, *rest = path_elements
else:
_, path_name = path_elements
path_type = None
# block path for all users if the extension is disabled
if path_name in settings.lnbits_deactivated_extensions:
response = JSONResponse(
status_code=HTTPStatus.NOT_FOUND,
content={"detail": f"Extension '{path_name}' disabled"},
)
await response(scope, receive, send)
return
# re-route API trafic if the extension has been upgraded
if path_type == "api":
upgraded_extensions = list(
filter(
lambda ext: ext.endswith(f"/{path_name}"),
settings.lnbits_upgraded_extensions,
)
)
if len(upgraded_extensions) != 0:
upgrade_path = upgraded_extensions[0]
tail = "/".join(rest)
scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}"
await self.app(scope, receive, send)
class ExtensionsRedirectMiddleware:
# Extensions are allowed to specify redirect paths.
# A call to a path outside the scope of the extension can be redirected to one of the extension's endpoints.
# Eg: redirect `GET /.well-known` to `GET /lnurlp/api/v1/well-known`
def __init__(self, app: ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if "path" not in scope:
await self.app(scope, receive, send)
return
req_headers = scope["headers"] if "headers" in scope else []
redirect = self._find_redirect(scope["path"], req_headers)
if redirect:
scope["path"] = self._new_path(redirect, scope["path"])
await self.app(scope, receive, send)
def _find_redirect(self, path: str, req_headers: List[Tuple[bytes, bytes]]):
return next(
(
r
for r in settings.lnbits_extensions_redirects
if self._redirect_matches(r, path, req_headers)
),
None,
)
def _redirect_matches(
self, redirect: dict, path: str, req_headers: List[Tuple[bytes, bytes]]
) -> bool:
if "from_path" not in redirect:
return False
header_filters = (
redirect["header_filters"] if "header_filters" in redirect else []
)
return self._has_common_path(redirect["from_path"], path) and self._has_headers(
header_filters, req_headers
)
def _has_headers(
self, filter_headers: dict, req_headers: List[Tuple[bytes, bytes]]
) -> bool:
for h in filter_headers:
if not self._has_header(req_headers, (str(h), str(filter_headers[h]))):
return False
return True
def _has_header(
self, req_headers: List[Tuple[bytes, bytes]], header: Tuple[str, str]
) -> bool:
for h in req_headers:
if (
h[0].decode().lower() == header[0].lower()
and h[1].decode() == header[1]
):
return True
return False
def _has_common_path(self, redirect_path: str, req_path: str) -> bool:
redirect_path_elements = redirect_path.split("/")
req_path_elements = req_path.split("/")
if len(redirect_path) > len(req_path):
return False
sub_path = req_path_elements[: len(redirect_path_elements)]
return redirect_path == "/".join(sub_path)
def _new_path(self, redirect: dict, req_path: str) -> str:
from_path = redirect["from_path"].split("/")
redirect_to = redirect["redirect_to_path"].split("/")
req_tail_path = req_path.split("/")[len(from_path) :]
elements = [
e for e in ([redirect["ext_id"]] + redirect_to + req_tail_path) if e != ""
]
return "/" + "/".join(elements)

View file

@ -4,7 +4,7 @@ import json
import subprocess import subprocess
from os import path from os import path
from sqlite3 import Row from sqlite3 import Row
from typing import List, Optional from typing import Any, List, Optional
import httpx import httpx
from loguru import logger from loguru import logger
@ -59,6 +59,8 @@ class InstalledExtensionsSettings(LNbitsSettings):
lnbits_deactivated_extensions: List[str] = Field(default=[]) lnbits_deactivated_extensions: List[str] = Field(default=[])
# upgraded extensions that require API redirects # upgraded extensions that require API redirects
lnbits_upgraded_extensions: List[str] = Field(default=[]) lnbits_upgraded_extensions: List[str] = Field(default=[])
# list of redirects that extensions want to perform
lnbits_extensions_redirects: List[Any] = Field(default=[])
class ThemesSettings(LNbitsSettings): class ThemesSettings(LNbitsSettings):
@ -264,7 +266,7 @@ class SuperUserSettings(LNbitsSettings):
class TransientSettings(InstalledExtensionsSettings): class TransientSettings(InstalledExtensionsSettings):
# Transient Settings: # Transient Settings:
# - are initialized, updated and used at runtime # - are initialized, updated and used at runtime
# - are not read from a file or from the `setings` table # - are not read from a file or from the `settings` table
# - are not persisted in the `settings` table when the settings are updated # - are not persisted in the `settings` table when the settings are updated
# - are cleared on server restart # - are cleared on server restart