Add redirect paths for extensions (#1532)
* feat: basic redirect to extension endpoint * feat: filter by headers * refactor: extract `middleware.py` * fix: do not add twice the same extension to redirects * chore: code clean-up
This commit is contained in:
parent
84e369aad2
commit
0d5fef1cb9
4 changed files with 151 additions and 55 deletions
|
|
@ -32,18 +32,14 @@ from .core import (
|
||||||
)
|
)
|
||||||
from .core.services import check_admin_settings
|
from .core.services import check_admin_settings
|
||||||
from .core.views.generic import core_html_routes
|
from .core.views.generic import core_html_routes
|
||||||
from .extension_manager import (
|
from .extension_manager import Extension, InstallableExtension, get_valid_extensions
|
||||||
Extension,
|
|
||||||
InstallableExtension,
|
|
||||||
InstalledExtensionMiddleware,
|
|
||||||
get_valid_extensions,
|
|
||||||
)
|
|
||||||
from .helpers import (
|
from .helpers import (
|
||||||
get_css_vendored,
|
get_css_vendored,
|
||||||
get_js_vendored,
|
get_js_vendored,
|
||||||
template_renderer,
|
template_renderer,
|
||||||
url_for_vendored,
|
url_for_vendored,
|
||||||
)
|
)
|
||||||
|
from .middleware import ExtensionsRedirectMiddleware, InstalledExtensionMiddleware
|
||||||
from .requestvars import g
|
from .requestvars import g
|
||||||
from .tasks import (
|
from .tasks import (
|
||||||
catch_everything_and_restart,
|
catch_everything_and_restart,
|
||||||
|
|
@ -81,7 +77,10 @@ def create_app() -> FastAPI:
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||||
|
|
||||||
|
# order of these two middlewares is important
|
||||||
app.add_middleware(InstalledExtensionMiddleware)
|
app.add_middleware(InstalledExtensionMiddleware)
|
||||||
|
app.add_middleware(ExtensionsRedirectMiddleware)
|
||||||
|
|
||||||
register_startup(app)
|
register_startup(app)
|
||||||
register_assets(app)
|
register_assets(app)
|
||||||
|
|
@ -240,6 +239,15 @@ def register_ext_routes(app: FastAPI, ext: Extension) -> None:
|
||||||
for s in ext_statics:
|
for s in ext_statics:
|
||||||
app.mount(s["path"], s["app"], s["name"])
|
app.mount(s["path"], s["app"], s["name"])
|
||||||
|
|
||||||
|
if hasattr(ext_module, f"{ext.code}_redirect_paths"):
|
||||||
|
ext_redirects = getattr(ext_module, f"{ext.code}_redirect_paths")
|
||||||
|
settings.lnbits_extensions_redirects = [
|
||||||
|
r for r in settings.lnbits_extensions_redirects if r["ext_id"] != ext.code
|
||||||
|
]
|
||||||
|
for r in ext_redirects:
|
||||||
|
r["ext_id"] = ext.code
|
||||||
|
settings.lnbits_extensions_redirects.append(r)
|
||||||
|
|
||||||
logger.trace(f"adding route for extension {ext_module}")
|
logger.trace(f"adding route for extension {ext_module}")
|
||||||
|
|
||||||
prefix = f"/upgrades/{ext.upgrade_hash}" if ext.upgrade_hash != "" else ""
|
prefix = f"/upgrades/{ext.upgrade_hash}" if ext.upgrade_hash != "" else ""
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,8 @@ from typing import Any, List, NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi.exceptions import HTTPException
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
|
||||||
|
|
||||||
from lnbits.settings import settings
|
from lnbits.settings import settings
|
||||||
|
|
||||||
|
|
@ -461,51 +459,6 @@ class InstallableExtension(BaseModel):
|
||||||
return selected_release[0] if len(selected_release) != 0 else None
|
return selected_release[0] if len(selected_release) != 0 else None
|
||||||
|
|
||||||
|
|
||||||
class InstalledExtensionMiddleware:
|
|
||||||
# This middleware class intercepts calls made to the extensions API and:
|
|
||||||
# - it blocks the calls if the extension has been disabled or uninstalled.
|
|
||||||
# - it redirects the calls to the latest version of the extension if the extension has been upgraded.
|
|
||||||
# - otherwise it has no effect
|
|
||||||
def __init__(self, app: ASGIApp) -> None:
|
|
||||||
self.app = app
|
|
||||||
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
|
||||||
if "path" not in scope:
|
|
||||||
await self.app(scope, receive, send)
|
|
||||||
return
|
|
||||||
|
|
||||||
path_elements = scope["path"].split("/")
|
|
||||||
if len(path_elements) > 2:
|
|
||||||
_, path_name, path_type, *rest = path_elements
|
|
||||||
else:
|
|
||||||
_, path_name = path_elements
|
|
||||||
path_type = None
|
|
||||||
|
|
||||||
# block path for all users if the extension is disabled
|
|
||||||
if path_name in settings.lnbits_deactivated_extensions:
|
|
||||||
response = JSONResponse(
|
|
||||||
status_code=HTTPStatus.NOT_FOUND,
|
|
||||||
content={"detail": f"Extension '{path_name}' disabled"},
|
|
||||||
)
|
|
||||||
await response(scope, receive, send)
|
|
||||||
return
|
|
||||||
|
|
||||||
# re-route API trafic if the extension has been upgraded
|
|
||||||
if path_type == "api":
|
|
||||||
upgraded_extensions = list(
|
|
||||||
filter(
|
|
||||||
lambda ext: ext.endswith(f"/{path_name}"),
|
|
||||||
settings.lnbits_upgraded_extensions,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if len(upgraded_extensions) != 0:
|
|
||||||
upgrade_path = upgraded_extensions[0]
|
|
||||||
tail = "/".join(rest)
|
|
||||||
scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}"
|
|
||||||
|
|
||||||
await self.app(scope, receive, send)
|
|
||||||
|
|
||||||
|
|
||||||
class CreateExtension(BaseModel):
|
class CreateExtension(BaseModel):
|
||||||
ext_id: str
|
ext_id: str
|
||||||
archive: str
|
archive: str
|
||||||
|
|
|
||||||
133
lnbits/middleware.py
Normal file
133
lnbits/middleware.py
Normal file
|
|
@ -0,0 +1,133 @@
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
|
from lnbits.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class InstalledExtensionMiddleware:
|
||||||
|
# This middleware class intercepts calls made to the extensions API and:
|
||||||
|
# - it blocks the calls if the extension has been disabled or uninstalled.
|
||||||
|
# - it redirects the calls to the latest version of the extension if the extension has been upgraded.
|
||||||
|
# - otherwise it has no effect
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
if "path" not in scope:
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
path_elements = scope["path"].split("/")
|
||||||
|
if len(path_elements) > 2:
|
||||||
|
_, path_name, path_type, *rest = path_elements
|
||||||
|
else:
|
||||||
|
_, path_name = path_elements
|
||||||
|
path_type = None
|
||||||
|
|
||||||
|
# block path for all users if the extension is disabled
|
||||||
|
if path_name in settings.lnbits_deactivated_extensions:
|
||||||
|
response = JSONResponse(
|
||||||
|
status_code=HTTPStatus.NOT_FOUND,
|
||||||
|
content={"detail": f"Extension '{path_name}' disabled"},
|
||||||
|
)
|
||||||
|
await response(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
# re-route API trafic if the extension has been upgraded
|
||||||
|
if path_type == "api":
|
||||||
|
upgraded_extensions = list(
|
||||||
|
filter(
|
||||||
|
lambda ext: ext.endswith(f"/{path_name}"),
|
||||||
|
settings.lnbits_upgraded_extensions,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if len(upgraded_extensions) != 0:
|
||||||
|
upgrade_path = upgraded_extensions[0]
|
||||||
|
tail = "/".join(rest)
|
||||||
|
scope["path"] = f"/upgrades/{upgrade_path}/{path_type}/{tail}"
|
||||||
|
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionsRedirectMiddleware:
|
||||||
|
# Extensions are allowed to specify redirect paths.
|
||||||
|
# A call to a path outside the scope of the extension can be redirected to one of the extension's endpoints.
|
||||||
|
# Eg: redirect `GET /.well-known` to `GET /lnurlp/api/v1/well-known`
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
if "path" not in scope:
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
req_headers = scope["headers"] if "headers" in scope else []
|
||||||
|
redirect = self._find_redirect(scope["path"], req_headers)
|
||||||
|
if redirect:
|
||||||
|
scope["path"] = self._new_path(redirect, scope["path"])
|
||||||
|
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
def _find_redirect(self, path: str, req_headers: List[Tuple[bytes, bytes]]):
|
||||||
|
return next(
|
||||||
|
(
|
||||||
|
r
|
||||||
|
for r in settings.lnbits_extensions_redirects
|
||||||
|
if self._redirect_matches(r, path, req_headers)
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _redirect_matches(
|
||||||
|
self, redirect: dict, path: str, req_headers: List[Tuple[bytes, bytes]]
|
||||||
|
) -> bool:
|
||||||
|
if "from_path" not in redirect:
|
||||||
|
return False
|
||||||
|
header_filters = (
|
||||||
|
redirect["header_filters"] if "header_filters" in redirect else []
|
||||||
|
)
|
||||||
|
return self._has_common_path(redirect["from_path"], path) and self._has_headers(
|
||||||
|
header_filters, req_headers
|
||||||
|
)
|
||||||
|
|
||||||
|
def _has_headers(
|
||||||
|
self, filter_headers: dict, req_headers: List[Tuple[bytes, bytes]]
|
||||||
|
) -> bool:
|
||||||
|
for h in filter_headers:
|
||||||
|
if not self._has_header(req_headers, (str(h), str(filter_headers[h]))):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _has_header(
|
||||||
|
self, req_headers: List[Tuple[bytes, bytes]], header: Tuple[str, str]
|
||||||
|
) -> bool:
|
||||||
|
for h in req_headers:
|
||||||
|
if (
|
||||||
|
h[0].decode().lower() == header[0].lower()
|
||||||
|
and h[1].decode() == header[1]
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _has_common_path(self, redirect_path: str, req_path: str) -> bool:
|
||||||
|
redirect_path_elements = redirect_path.split("/")
|
||||||
|
req_path_elements = req_path.split("/")
|
||||||
|
if len(redirect_path) > len(req_path):
|
||||||
|
return False
|
||||||
|
sub_path = req_path_elements[: len(redirect_path_elements)]
|
||||||
|
return redirect_path == "/".join(sub_path)
|
||||||
|
|
||||||
|
def _new_path(self, redirect: dict, req_path: str) -> str:
|
||||||
|
from_path = redirect["from_path"].split("/")
|
||||||
|
redirect_to = redirect["redirect_to_path"].split("/")
|
||||||
|
req_tail_path = req_path.split("/")[len(from_path) :]
|
||||||
|
|
||||||
|
elements = [
|
||||||
|
e for e in ([redirect["ext_id"]] + redirect_to + req_tail_path) if e != ""
|
||||||
|
]
|
||||||
|
|
||||||
|
return "/" + "/".join(elements)
|
||||||
|
|
@ -4,7 +4,7 @@ import json
|
||||||
import subprocess
|
import subprocess
|
||||||
from os import path
|
from os import path
|
||||||
from sqlite3 import Row
|
from sqlite3 import Row
|
||||||
from typing import List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
@ -59,6 +59,8 @@ class InstalledExtensionsSettings(LNbitsSettings):
|
||||||
lnbits_deactivated_extensions: List[str] = Field(default=[])
|
lnbits_deactivated_extensions: List[str] = Field(default=[])
|
||||||
# upgraded extensions that require API redirects
|
# upgraded extensions that require API redirects
|
||||||
lnbits_upgraded_extensions: List[str] = Field(default=[])
|
lnbits_upgraded_extensions: List[str] = Field(default=[])
|
||||||
|
# list of redirects that extensions want to perform
|
||||||
|
lnbits_extensions_redirects: List[Any] = Field(default=[])
|
||||||
|
|
||||||
|
|
||||||
class ThemesSettings(LNbitsSettings):
|
class ThemesSettings(LNbitsSettings):
|
||||||
|
|
@ -264,7 +266,7 @@ class SuperUserSettings(LNbitsSettings):
|
||||||
class TransientSettings(InstalledExtensionsSettings):
|
class TransientSettings(InstalledExtensionsSettings):
|
||||||
# Transient Settings:
|
# Transient Settings:
|
||||||
# - are initialized, updated and used at runtime
|
# - are initialized, updated and used at runtime
|
||||||
# - are not read from a file or from the `setings` table
|
# - are not read from a file or from the `settings` table
|
||||||
# - are not persisted in the `settings` table when the settings are updated
|
# - are not persisted in the `settings` table when the settings are updated
|
||||||
# - are cleared on server restart
|
# - are cleared on server restart
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue