Add redirect paths for extensions (#1532)
* feat: basic redirect to extension endpoint * feat: filter by headers * refactor: extract `middleware.py` * fix: do not add twice the same extension to redirects * chore: code clean-up
This commit is contained in:
parent
84e369aad2
commit
0d5fef1cb9
4 changed files with 151 additions and 55 deletions
|
|
@ -32,18 +32,14 @@ from .core import (
|
|||
)
|
||||
from .core.services import check_admin_settings
|
||||
from .core.views.generic import core_html_routes
|
||||
from .extension_manager import (
|
||||
Extension,
|
||||
InstallableExtension,
|
||||
InstalledExtensionMiddleware,
|
||||
get_valid_extensions,
|
||||
)
|
||||
from .extension_manager import Extension, InstallableExtension, get_valid_extensions
|
||||
from .helpers import (
|
||||
get_css_vendored,
|
||||
get_js_vendored,
|
||||
template_renderer,
|
||||
url_for_vendored,
|
||||
)
|
||||
from .middleware import ExtensionsRedirectMiddleware, InstalledExtensionMiddleware
|
||||
from .requestvars import g
|
||||
from .tasks import (
|
||||
catch_everything_and_restart,
|
||||
|
|
@ -81,7 +77,10 @@ def create_app() -> FastAPI:
|
|||
)
|
||||
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
# order of these two middlewares is important
|
||||
app.add_middleware(InstalledExtensionMiddleware)
|
||||
app.add_middleware(ExtensionsRedirectMiddleware)
|
||||
|
||||
register_startup(app)
|
||||
register_assets(app)
|
||||
|
|
@ -240,6 +239,15 @@ def register_ext_routes(app: FastAPI, ext: Extension) -> None:
|
|||
for s in ext_statics:
|
||||
app.mount(s["path"], s["app"], s["name"])
|
||||
|
||||
if hasattr(ext_module, f"{ext.code}_redirect_paths"):
|
||||
ext_redirects = getattr(ext_module, f"{ext.code}_redirect_paths")
|
||||
settings.lnbits_extensions_redirects = [
|
||||
r for r in settings.lnbits_extensions_redirects if r["ext_id"] != ext.code
|
||||
]
|
||||
for r in ext_redirects:
|
||||
r["ext_id"] = ext.code
|
||||
settings.lnbits_extensions_redirects.append(r)
|
||||
|
||||
logger.trace(f"adding route for extension {ext_module}")
|
||||
|
||||
prefix = f"/upgrades/{ext.upgrade_hash}" if ext.upgrade_hash != "" else ""
|
||||
|
|
|
|||
|
|
@ -11,10 +11,8 @@ from typing import Any, List, NamedTuple, Optional, Tuple
|
|||
|
||||
import httpx
|
||||
from fastapi.exceptions import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from lnbits.settings import settings
|
||||
|
||||
|
|
@ -461,51 +459,6 @@ class InstallableExtension(BaseModel):
|
|||
return selected_release[0] if len(selected_release) != 0 else None
|
||||
|
||||
|
||||
class InstalledExtensionMiddleware:
|
||||
# This middleware class intercepts calls made to the extensions API and:
|
||||
# - it blocks the calls if the extension has been disabled or uninstalled.
|
||||
# - it redirects the calls to the latest version of the extension if the extension has been upgraded.
|
||||
# - otherwise it has no effect
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if "path" not in scope:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
path_elements = scope["path"].split("/")
|
||||
if len(path_elements) > 2:
|
||||
_, path_name, path_type, *rest = path_elements
|
||||
else:
|
||||
_, path_name = path_elements
|
||||
path_type = None
|
||||
|
||||
# block path for all users if the extension is disabled
|
||||
if path_name in settings.lnbits_deactivated_extensions:
|
||||
response = JSONResponse(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
content={"detail": f"Extension '{path_name}' disabled"},
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# re-route API trafic if the extension has been upgraded
|
||||
if path_type == "api":
|
||||
upgraded_extensions = list(
|
||||
filter(
|
||||
lambda ext: ext.endswith(f"/{path_name}"),
|
||||
settings.lnbits_upgraded_extensions,
|
||||
)
|
||||
)
|
||||
if len(upgraded_extensions) != 0:
|
||||
upgrade_path = upgraded_extensions[0]
|
||||
tail = "/".join(rest)
|
||||
scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}"
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
class CreateExtension(BaseModel):
|
||||
ext_id: str
|
||||
archive: str
|
||||
|
|
|
|||
133
lnbits/middleware.py
Normal file
133
lnbits/middleware.py
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
from http import HTTPStatus
|
||||
from typing import List, Tuple
|
||||
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from lnbits.settings import settings
|
||||
|
||||
|
||||
class InstalledExtensionMiddleware:
|
||||
# This middleware class intercepts calls made to the extensions API and:
|
||||
# - it blocks the calls if the extension has been disabled or uninstalled.
|
||||
# - it redirects the calls to the latest version of the extension if the extension has been upgraded.
|
||||
# - otherwise it has no effect
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if "path" not in scope:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
path_elements = scope["path"].split("/")
|
||||
if len(path_elements) > 2:
|
||||
_, path_name, path_type, *rest = path_elements
|
||||
else:
|
||||
_, path_name = path_elements
|
||||
path_type = None
|
||||
|
||||
# block path for all users if the extension is disabled
|
||||
if path_name in settings.lnbits_deactivated_extensions:
|
||||
response = JSONResponse(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
content={"detail": f"Extension '{path_name}' disabled"},
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# re-route API trafic if the extension has been upgraded
|
||||
if path_type == "api":
|
||||
upgraded_extensions = list(
|
||||
filter(
|
||||
lambda ext: ext.endswith(f"/{path_name}"),
|
||||
settings.lnbits_upgraded_extensions,
|
||||
)
|
||||
)
|
||||
if len(upgraded_extensions) != 0:
|
||||
upgrade_path = upgraded_extensions[0]
|
||||
tail = "/".join(rest)
|
||||
scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}"
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
class ExtensionsRedirectMiddleware:
|
||||
# Extensions are allowed to specify redirect paths.
|
||||
# A call to a path outside the scope of the extension can be redirected to one of the extension's endpoints.
|
||||
# Eg: redirect `GET /.well-known` to `GET /lnurlp/api/v1/well-known`
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if "path" not in scope:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
req_headers = scope["headers"] if "headers" in scope else []
|
||||
redirect = self._find_redirect(scope["path"], req_headers)
|
||||
if redirect:
|
||||
scope["path"] = self._new_path(redirect, scope["path"])
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
def _find_redirect(self, path: str, req_headers: List[Tuple[bytes, bytes]]):
|
||||
return next(
|
||||
(
|
||||
r
|
||||
for r in settings.lnbits_extensions_redirects
|
||||
if self._redirect_matches(r, path, req_headers)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
def _redirect_matches(
|
||||
self, redirect: dict, path: str, req_headers: List[Tuple[bytes, bytes]]
|
||||
) -> bool:
|
||||
if "from_path" not in redirect:
|
||||
return False
|
||||
header_filters = (
|
||||
redirect["header_filters"] if "header_filters" in redirect else []
|
||||
)
|
||||
return self._has_common_path(redirect["from_path"], path) and self._has_headers(
|
||||
header_filters, req_headers
|
||||
)
|
||||
|
||||
def _has_headers(
|
||||
self, filter_headers: dict, req_headers: List[Tuple[bytes, bytes]]
|
||||
) -> bool:
|
||||
for h in filter_headers:
|
||||
if not self._has_header(req_headers, (str(h), str(filter_headers[h]))):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _has_header(
|
||||
self, req_headers: List[Tuple[bytes, bytes]], header: Tuple[str, str]
|
||||
) -> bool:
|
||||
for h in req_headers:
|
||||
if (
|
||||
h[0].decode().lower() == header[0].lower()
|
||||
and h[1].decode() == header[1]
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _has_common_path(self, redirect_path: str, req_path: str) -> bool:
|
||||
redirect_path_elements = redirect_path.split("/")
|
||||
req_path_elements = req_path.split("/")
|
||||
if len(redirect_path) > len(req_path):
|
||||
return False
|
||||
sub_path = req_path_elements[: len(redirect_path_elements)]
|
||||
return redirect_path == "/".join(sub_path)
|
||||
|
||||
def _new_path(self, redirect: dict, req_path: str) -> str:
|
||||
from_path = redirect["from_path"].split("/")
|
||||
redirect_to = redirect["redirect_to_path"].split("/")
|
||||
req_tail_path = req_path.split("/")[len(from_path) :]
|
||||
|
||||
elements = [
|
||||
e for e in ([redirect["ext_id"]] + redirect_to + req_tail_path) if e != ""
|
||||
]
|
||||
|
||||
return "/" + "/".join(elements)
|
||||
|
|
@ -4,7 +4,7 @@ import json
|
|||
import subprocess
|
||||
from os import path
|
||||
from sqlite3 import Row
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
|
@ -59,6 +59,8 @@ class InstalledExtensionsSettings(LNbitsSettings):
|
|||
lnbits_deactivated_extensions: List[str] = Field(default=[])
|
||||
# upgraded extensions that require API redirects
|
||||
lnbits_upgraded_extensions: List[str] = Field(default=[])
|
||||
# list of redirects that extensions want to perform
|
||||
lnbits_extensions_redirects: List[Any] = Field(default=[])
|
||||
|
||||
|
||||
class ThemesSettings(LNbitsSettings):
|
||||
|
|
@ -264,7 +266,7 @@ class SuperUserSettings(LNbitsSettings):
|
|||
class TransientSettings(InstalledExtensionsSettings):
|
||||
# Transient Settings:
|
||||
# - are initialized, updated and used at runtime
|
||||
# - are not read from a file or from the `setings` table
|
||||
# - are not read from a file or from the `settings` table
|
||||
# - are not persisted in the `settings` table when the settings are updated
|
||||
# - are cleared on server restart
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue