diff --git a/lnbits/app.py b/lnbits/app.py
index 1b1292ce..9fb388c0 100644
--- a/lnbits/app.py
+++ b/lnbits/app.py
@@ -1,10 +1,15 @@
import asyncio
+import glob
import importlib
import logging
+import os
import signal
import sys
import traceback
+import zipfile
from http import HTTPStatus
+from pathlib import Path
+from typing import Callable
from fastapi import FastAPI, Request
from fastapi.exceptions import HTTPException, RequestValidationError
@@ -18,10 +23,12 @@ from lnbits.core.tasks import register_task_listeners
from lnbits.settings import get_wallet_class, set_wallet_class, settings
from .commands import migrate_databases
-from .core import core_app
+from .core import core_app, core_app_extra
from .core.services import check_admin_settings
from .core.views.generic import core_html_routes
from .helpers import (
+ EnabledExtensionMiddleware,
+ Extension,
get_css_vendored,
get_js_vendored,
get_valid_extensions,
@@ -65,6 +72,7 @@ def create_app() -> FastAPI:
)
app.add_middleware(GZipMiddleware, minimum_size=1000)
+ app.add_middleware(EnabledExtensionMiddleware)
register_startup(app)
register_assets(app)
@@ -72,6 +80,8 @@ def create_app() -> FastAPI:
register_async_tasks(app)
register_exception_handlers(app)
+ setattr(core_app_extra, "register_new_ext_routes", register_new_ext_routes(app))
+
return app
@@ -105,6 +115,22 @@ async def check_funding_source() -> None:
)
+def check_installed_extensions():
+ """
+ Check extensions that have been installed, but for some reason no longer present in the 'lnbits/extensions' directory.
+ One reason might be a docker-container that was re-created.
+ The 'data' directory (where the '.zip' files live) is expected to persist state.
+ """
+ extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
+
+ zip_files = glob.glob(f"{extensions_data_dir}/*.zip")
+ for zip_file in zip_files:
+ ext_name = Path(zip_file).stem
+ if not Path(f"lnbits/extensions/{ext_name}").is_dir():
+ with zipfile.ZipFile(zip_file, "r") as zip_ref:
+ zip_ref.extractall("lnbits/extensions/")
+
+
def register_routes(app: FastAPI) -> None:
"""Register FastAPI routes / LNbits extensions."""
app.include_router(core_app)
@@ -112,20 +138,7 @@ def register_routes(app: FastAPI) -> None:
for ext in get_valid_extensions():
try:
- ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
- ext_route = getattr(ext_module, f"{ext.code}_ext")
-
- if hasattr(ext_module, f"{ext.code}_start"):
- ext_start_func = getattr(ext_module, f"{ext.code}_start")
- ext_start_func()
-
- if hasattr(ext_module, f"{ext.code}_static_files"):
- ext_statics = getattr(ext_module, f"{ext.code}_static_files")
- for s in ext_statics:
- app.mount(s["path"], s["app"], s["name"])
-
- logger.trace(f"adding route for extension {ext_module}")
- app.include_router(ext_route)
+ register_ext_routes(app, ext)
except Exception as e:
logger.error(str(e))
raise ImportError(
@@ -133,6 +146,31 @@ def register_routes(app: FastAPI) -> None:
)
+def register_new_ext_routes(app: FastAPI) -> Callable:
+ def register_new_ext_routes_fn(ext: Extension):
+ register_ext_routes(app, ext)
+
+ return register_new_ext_routes_fn
+
+
+def register_ext_routes(app: FastAPI, ext: Extension) -> None:
+ """Register FastAPI routes for extension."""
+ ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
+ ext_route = getattr(ext_module, f"{ext.code}_ext")
+
+ if hasattr(ext_module, f"{ext.code}_start"):
+ ext_start_func = getattr(ext_module, f"{ext.code}_start")
+ ext_start_func()
+
+ if hasattr(ext_module, f"{ext.code}_static_files"):
+ ext_statics = getattr(ext_module, f"{ext.code}_static_files")
+ for s in ext_statics:
+ app.mount(s["path"], s["app"], s["name"])
+
+ logger.trace(f"adding route for extension {ext_module}")
+ app.include_router(ext_route)
+
+
def register_startup(app: FastAPI):
@app.on_event("startup")
async def lnbits_startup():
@@ -151,6 +189,9 @@ def register_startup(app: FastAPI):
# 4. initialize funding source
await check_funding_source()
+
+ # 5. check extensions in `data` directory
+ await check_installed_extensions()
except Exception as e:
logger.error(str(e))
raise ImportError("Failed to run 'startup' event.")
diff --git a/lnbits/commands.py b/lnbits/commands.py
index 82ea1430..66b45b93 100644
--- a/lnbits/commands.py
+++ b/lnbits/commands.py
@@ -11,6 +11,8 @@ from lnbits.settings import settings
from .core import db as core_db
from .core import migrations as core_migrations
+from .core.crud import USER_ID_ALL, get_dbversions, get_inactive_extensions
+from .core.helpers import migrate_extension_database, run_migration
from .db import COCKROACH, POSTGRES, SQLITE
from .helpers import (
get_css_vendored,
@@ -59,30 +61,6 @@ def bundle_vendored():
async def migrate_databases():
"""Creates the necessary databases if they don't exist already; or migrates them."""
- async def set_migration_version(conn, db_name, version):
- await conn.execute(
- """
- INSERT INTO dbversions (db, version) VALUES (?, ?)
- ON CONFLICT (db) DO UPDATE SET version = ?
- """,
- (db_name, version, version),
- )
-
- async def run_migration(db, migrations_module, db_name):
- for key, migrate in migrations_module.__dict__.items():
- match = match = matcher.match(key)
- if match:
- version = int(match.group(1))
- if version > current_versions.get(db_name, 0):
- logger.debug(f"running migration {db_name}.{version}")
- await migrate(db)
-
- if db.schema == None:
- await set_migration_version(db, db_name, version)
- else:
- async with core_db.connect() as conn:
- await set_migration_version(conn, db_name, version)
-
async with core_db.connect() as conn:
if conn.type == SQLITE:
exists = await conn.fetchone(
@@ -96,27 +74,18 @@ async def migrate_databases():
if not exists:
await core_migrations.m000_create_migrations_table(conn)
- rows = await (await conn.execute("SELECT * FROM dbversions")).fetchall()
- current_versions = {row["db"]: row["version"] for row in rows}
- matcher = re.compile(r"^m(\d\d\d)_")
- db_name = core_migrations.__name__.split(".")[-2]
- await run_migration(conn, core_migrations, db_name)
+ current_versions = await get_dbversions(conn)
+ core_version = current_versions.get("core", 0)
+ await run_migration(conn, core_migrations, core_version)
for ext in get_valid_extensions():
- try:
-
- module_str = (
- ext.migration_module or f"lnbits.extensions.{ext.code}.migrations"
- )
- ext_migrations = importlib.import_module(module_str)
- ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db
- db_name = ext.db_name or module_str.split(".")[-2]
- except ImportError:
- raise ImportError(
- f"Please make sure that the extension `{ext.code}` has a migrations file."
- )
-
- async with ext_db.connect() as ext_conn:
- await run_migration(ext_conn, ext_migrations, db_name)
+ current_version = current_versions.get(ext.code, 0)
+ await migrate_extension_database(ext, current_version)
logger.info("✔️ All migrations done.")
+
+
+async def load_disabled_extension_list() -> None:
+ """Update list of extensions that have been explicitly disabled"""
+ inactive_extensions = await get_inactive_extensions(user_id=USER_ID_ALL)
+ settings.lnbits_disabled_extensions += inactive_extensions
diff --git a/lnbits/core/__init__.py b/lnbits/core/__init__.py
index dec15d78..75b6d587 100644
--- a/lnbits/core/__init__.py
+++ b/lnbits/core/__init__.py
@@ -1,11 +1,14 @@
from fastapi.routing import APIRouter
+from lnbits.core.models import CoreAppExtra
from lnbits.db import Database
db = Database("database")
core_app: APIRouter = APIRouter()
+core_app_extra: CoreAppExtra = CoreAppExtra()
+
from .views.admin_api import * # noqa
from .views.api import * # noqa
from .views.generic import * # noqa
diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py
index a80fadf2..1289c33a 100644
--- a/lnbits/core/crud.py
+++ b/lnbits/core/crud.py
@@ -11,6 +11,8 @@ from lnbits.settings import AdminSettings, EditableSettings, SuperSettings, sett
from . import db
from .models import BalanceCheck, Payment, User, Wallet
+USER_ID_ALL = "all"
+
# accounts
# --------
@@ -78,6 +80,18 @@ async def update_user_extension(
)
+async def get_inactive_extensions(
+ *, user_id: str, conn: Optional[Connection] = None
+) -> List[str]:
+ inactive_extensions = await (conn or db).fetchall(
+ """SELECT extension FROM extensions WHERE "user" = ? AND NOT active""",
+ (user_id,),
+ )
+ return (
+ [ext[0] for ext in inactive_extensions] if len(inactive_extensions) != 0 else []
+ )
+
+
# wallets
# -------
@@ -620,3 +634,20 @@ async def create_admin_settings(super_user: str, new_settings: dict):
sql = f"INSERT INTO settings (super_user, editable_settings) VALUES (?, ?)"
await db.execute(sql, (super_user, json.dumps(new_settings)))
return await get_super_settings()
+
+
+# db versions
+# --------------
+async def get_dbversions(conn: Optional[Connection] = None):
+ rows = await (conn or db).fetchall("SELECT * FROM dbversions")
+ return {row["db"]: row["version"] for row in rows}
+
+
+async def update_migration_version(conn, db_name, version):
+ await (conn or db).execute(
+ """
+ INSERT INTO dbversions (db, version) VALUES (?, ?)
+ ON CONFLICT (db) DO UPDATE SET version = ?
+ """,
+ (db_name, version, version),
+ )
diff --git a/lnbits/core/helpers.py b/lnbits/core/helpers.py
new file mode 100644
index 00000000..3675d438
--- /dev/null
+++ b/lnbits/core/helpers.py
@@ -0,0 +1,93 @@
+import hashlib
+import importlib
+import re
+import urllib.request
+from typing import List
+
+import httpx
+from fastapi.exceptions import HTTPException
+from loguru import logger
+
+from lnbits.helpers import InstallableExtension
+from lnbits.settings import settings
+
+from . import db as core_db
+from .crud import update_migration_version
+
+
+async def migrate_extension_database(ext, current_version):
+ try:
+ ext_migrations = importlib.import_module(
+ f"lnbits.extensions.{ext.code}.migrations"
+ )
+ ext_db = importlib.import_module(f"lnbits.extensions.{ext.code}").db
+ except ImportError:
+ raise ImportError(
+ f"Please make sure that the extension `{ext.code}` has a migrations file."
+ )
+
+ async with ext_db.connect() as ext_conn:
+ await run_migration(ext_conn, ext_migrations, current_version)
+
+
+async def run_migration(db, migrations_module, current_version):
+ matcher = re.compile(r"^m(\d\d\d)_")
+ db_name = migrations_module.__name__.split(".")[-2]
+ for key, migrate in migrations_module.__dict__.items():
+ match = match = matcher.match(key)
+ if match:
+ version = int(match.group(1))
+ if version > current_version:
+ logger.debug(f"running migration {db_name}.{version}")
+ print(f"running migration {db_name}.{version}")
+ await migrate(db)
+
+ if db.schema == None:
+ await update_migration_version(db, db_name, version)
+ else:
+ async with core_db.connect() as conn:
+ await update_migration_version(conn, db_name, version)
+
+
+async def get_installable_extensions() -> List[InstallableExtension]:
+ extension_list: List[InstallableExtension] = []
+
+ async with httpx.AsyncClient() as client:
+ for url in settings.lnbits_extensions_manifests:
+ resp = await client.get(url)
+ if resp.status_code != 200:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Unable to fetch extension list for repository: {url}",
+ )
+ for e in resp.json()["extensions"]:
+ extension_list += [
+ InstallableExtension(
+ id=e["id"],
+ name=e["name"],
+ archive=e["archive"],
+ hash=e["hash"],
+ short_description=e["shortDescription"],
+ details=e["details"] if "details" in e else "",
+ icon=e["icon"],
+ dependencies=e["dependencies"] if "dependencies" in e else [],
+ )
+ ]
+
+ return extension_list
+
+
+def download_url(url, save_path):
+ with urllib.request.urlopen(url) as dl_file:
+ with open(save_path, "wb") as out_file:
+ out_file.write(dl_file.read())
+
+
+def file_hash(filename):
+ h = hashlib.sha256()
+ b = bytearray(128 * 1024)
+ mv = memoryview(b)
+ with open(filename, "rb", buffering=0) as f:
+ while n := f.readinto(mv):
+ h.update(mv[:n])
+ return h.hexdigest()
diff --git a/lnbits/core/models.py b/lnbits/core/models.py
index eca1bf50..7b147208 100644
--- a/lnbits/core/models.py
+++ b/lnbits/core/models.py
@@ -4,7 +4,7 @@ import hmac
import json
import time
from sqlite3 import Row
-from typing import Dict, List, Optional
+from typing import Callable, Dict, List, Optional
from ecdsa import SECP256k1, SigningKey
from fastapi import Query
@@ -213,3 +213,7 @@ class BalanceCheck(BaseModel):
@classmethod
def from_row(cls, row: Row):
return cls(wallet=row["wallet"], service=row["service"], url=row["url"])
+
+
+class CoreAppExtra:
+ register_new_ext_routes: Callable
diff --git a/lnbits/core/templates/core/extensions.html b/lnbits/core/templates/core/extensions.html
index 88e50269..dc0037e2 100644
--- a/lnbits/core/templates/core/extensions.html
+++ b/lnbits/core/templates/core/extensions.html
@@ -4,6 +4,15 @@
{% endblock %} {% block page %}
+
+ Add or Remove Extensions
+
+
+ Back
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {% raw %}
+ {{ extension.name}}
+ {{ extension.shortDescription }}
+
+ Depends on:
+
+
+
+
+
+ {% endraw %}
+
+
+
+
+
+
+ Uninstall
+
+
+
+
+
+ Install
+
+
+
+
+
+
+ User Review Comming Soon
+
+
+ User Review Comming Soon
+
+
+ User Review Comming Soon
+
+
+ User Review Comming Soon
+
+
+ User Review Comming Soon
+
+
+
+
+
+
+
+
+ Warning
+
+ You are about to remove the extension for all users.
+ Are you sure you want to continue?
+
+
+
+ Yes, Uninstall
+ Cancel
+
+
+
+
+ {%raw%}
+
+
+ {{selectedExtension.name}}
+
+
+
+ Done
+
+
+
+ {%endraw%}
+
+{% endblock %} {% block scripts %} {{ window_vars(user) }}
+
+{% endblock %}
diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py
index d545df9a..7e552e8a 100644
--- a/lnbits/core/views/api.py
+++ b/lnbits/core/views/api.py
@@ -1,11 +1,18 @@
import asyncio
import hashlib
+import importlib
+import inspect
import json
+import os
+import shutil
+import sys
import time
import uuid
+import zipfile
from http import HTTPStatus
from io import BytesIO
-from typing import Dict, Optional, Tuple, Union
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse
import async_timeout
@@ -22,6 +29,8 @@ from fastapi import (
WebSocketDisconnect,
)
from fastapi.exceptions import HTTPException
+from fastapi.params import Body
+from genericpath import isfile
from loguru import logger
from pydantic import BaseModel
from pydantic.fields import Field
@@ -29,15 +38,27 @@ from sse_starlette.sse import EventSourceResponse
from starlette.responses import StreamingResponse
from lnbits import bolt11, lnurl
-from lnbits.core.models import Payment, Wallet
+from lnbits.core.helpers import (
+ download_url,
+ file_hash,
+ get_installable_extensions,
+ migrate_extension_database,
+)
+from lnbits.core.models import Payment, User, Wallet
from lnbits.decorators import (
WalletTypeInfo,
check_admin,
+ check_user_exists,
get_key_type,
require_admin_key,
require_invoice_key,
)
-from lnbits.helpers import url_for
+from lnbits.helpers import (
+ Extension,
+ InstallableExtension,
+ get_valid_extensions,
+ url_for,
+)
from lnbits.settings import get_wallet_class, settings
from lnbits.utils.exchange_rates import (
currencies,
@@ -45,13 +66,16 @@ from lnbits.utils.exchange_rates import (
satoshis_amount_as_fiat,
)
-from .. import core_app, db
+from .. import core_app, core_app_extra, db
from ..crud import (
+ USER_ID_ALL,
+ get_dbversions,
get_payments,
get_standalone_payment,
get_total_balance,
get_wallet_for_key,
save_balance_check,
+ update_user_extension,
update_wallet,
)
from ..services import (
@@ -706,3 +730,185 @@ async def websocket_update_get(item_id: str, data: str):
return {"sent": True, "data": data}
except:
return {"sent": False, "data": data}
+
+
+@core_app.post("/api/v1/extension/{ext_id}/{hash}")
+async def api_install_extension(
+ ext_id: str, hash: str, user: User = Depends(check_user_exists)
+):
+ if not user.admin:
+ raise HTTPException(
+ status_code=HTTPStatus.UNAUTHORIZED, detail="Only for admin users"
+ )
+
+ try:
+ extension_list: List[InstallableExtension] = await get_installable_extensions()
+ except Exception as ex:
+ raise HTTPException(
+ status_code=HTTPStatus.NOT_FOUND,
+ detail="Cannot fetch installable extension list",
+ )
+
+ extensions = [e for e in extension_list if e.id == ext_id and e.hash == hash]
+ if len(extensions) == 0:
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST,
+ detail=f"Unknown extension id: {ext_id}",
+ )
+ extension = extensions[0]
+
+ # check that all dependecies are installed
+ installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
+ if not set(extension.dependencies).issubset(installed_extensions):
+ raise HTTPException(
+ status_code=HTTPStatus.NOT_FOUND,
+ detail=f"Not all dependencies are installed: {extension.dependencies}",
+ )
+
+ # move files to the right location
+ extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
+ os.makedirs(extensions_data_dir, exist_ok=True)
+ ext_data_dir = os.path.join(extensions_data_dir, ext_id)
+ shutil.rmtree(ext_data_dir, True)
+ ext_zip_file = os.path.join(extensions_data_dir, f"{ext_id}.zip")
+ if os.path.isfile(ext_zip_file):
+ os.remove(ext_zip_file)
+
+ try:
+ download_url(extension.archive, ext_zip_file)
+ except Exception as ex:
+ raise HTTPException(
+ status_code=HTTPStatus.NOT_FOUND,
+ detail="Cannot fetch extension archive file",
+ )
+
+ archive_hash = file_hash(ext_zip_file)
+ if extension.hash != archive_hash:
+ # remove downloaded archive
+ if os.path.isfile(ext_zip_file):
+ os.remove(ext_zip_file)
+ raise HTTPException(
+ status_code=HTTPStatus.NOT_FOUND,
+ detail="File hash missmatch. Will not install.",
+ )
+
+ try:
+ ext_dir = os.path.join("lnbits/extensions", ext_id)
+ shutil.rmtree(ext_dir, True)
+ with zipfile.ZipFile(ext_zip_file, "r") as zip_ref:
+ zip_ref.extractall("lnbits/extensions")
+
+ # todo: is admin only
+ ext = Extension(extension.id, True, extension.is_admin_only, extension.name)
+
+ current_versions = await get_dbversions()
+ current_version = current_versions.get(ext.code, 0)
+
+ module_name = f"lnbits.extensions.{ext.code}"
+ # if module_name in sys.modules:
+ # importlib.reload(sys.modules[module_name])
+ ext_module = importlib.import_module(module_name)
+ # sys.modules[module_name] = importlib.reload(ext_module)
+
+ modules_to_reload = list_modules_for_extension(ext_id)
+ print("### modules_to_reload", modules_to_reload)
+ for m in modules_to_reload:
+ importlib.reload(sys.modules[m])
+
+ await migrate_extension_database(ext, current_version)
+
+ # disable by default
+ await update_user_extension(user_id=USER_ID_ALL, extension=ext_id, active=False)
+ settings.lnbits_disabled_extensions += [ext_id]
+
+ # mount routes at the very end
+ core_app_extra.register_new_ext_routes(ext)
+ except Exception as ex:
+ logger.warning(ex)
+ # remove downloaded archive
+ if os.path.isfile(ext_zip_file):
+ os.remove(ext_zip_file)
+
+ # remove module from extensions
+ shutil.rmtree(ext_dir, True)
+ raise HTTPException(
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(ex)
+ )
+
+
+@core_app.delete("/api/v1/extension/{ext_id}")
+async def api_uninstall_extension(ext_id: str, user: User = Depends(check_user_exists)):
+ if not user.admin:
+ raise HTTPException(
+ status_code=HTTPStatus.UNAUTHORIZED, detail="Only for admin users"
+ )
+
+ try:
+ extension_list: List[InstallableExtension] = await get_installable_extensions()
+ except Exception as ex:
+ raise HTTPException(
+ status_code=HTTPStatus.NOT_FOUND,
+ detail="Cannot fetch installable extension list",
+ )
+
+ extensions = [e for e in extension_list if e.id == ext_id]
+ if len(extensions) == 0:
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST,
+ detail=f"Unknown extension id: {ext_id}",
+ )
+
+ # check that other extensions do not depend on this one
+ for active_ext_id in list(map(lambda e: e.code, get_valid_extensions(True))):
+ active_ext = next(
+ (ext for ext in extension_list if ext.id == active_ext_id), None
+ )
+ if active_ext and ext_id in active_ext.dependencies:
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST,
+ detail=f"Cannot uninstall. Extension '{active_ext.name}' depends on this one.",
+ )
+
+ try:
+ settings.lnbits_disabled_extensions += [ext_id]
+
+ # remove downloaded archive
+ ext_zip_file = os.path.join(
+ settings.lnbits_data_folder, "extensions", f"{ext_id}.zip"
+ )
+ if os.path.isfile(ext_zip_file):
+ os.remove(ext_zip_file)
+
+ # module_name = f"lnbits.extensions.{ext_id}"
+
+ # modules_to_delete = list_modules_for_extension(ext_id)
+ # print('### modules_to_delete', modules_to_delete)
+ # for m in modules_to_delete:
+ # module = sys.modules[m]
+ # del sys.modules[m]
+ # del module
+
+ # remove module from extensions
+ ext_dir = os.path.join("lnbits/extensions", ext_id)
+ shutil.rmtree(ext_dir, True)
+ except Exception as ex:
+ raise HTTPException(
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(ex)
+ )
+
+
+def list_modules_for_extension(ext_id: str) -> List[str]:
+ modules_for_extension = []
+ for key in sys.modules.keys():
+ try:
+ module = sys.modules[key]
+ moduleFilePath = inspect.getfile(module).lower()
+
+ dir_name = str(Path(moduleFilePath).parent.absolute())
+ if dir_name.endswith(f"lnbits/extensions/{ext_id}"):
+ print("## moduleFilePath", moduleFilePath)
+ modules_for_extension += [key]
+
+ except:
+ pass # built in modules throw if queried
+ return modules_for_extension
diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py
index ab19feb8..d14a43f6 100644
--- a/lnbits/core/views/generic.py
+++ b/lnbits/core/views/generic.py
@@ -1,6 +1,6 @@
import asyncio
from http import HTTPStatus
-from typing import Optional
+from typing import List, Optional
from fastapi import Depends, Query, Request, status
from fastapi.exceptions import HTTPException
@@ -11,17 +11,20 @@ from pydantic.types import UUID4
from starlette.responses import HTMLResponse, JSONResponse
from lnbits.core import db
+from lnbits.core.helpers import get_installable_extensions
from lnbits.core.models import User
from lnbits.decorators import check_admin, check_user_exists
from lnbits.helpers import template_renderer, url_for
from lnbits.settings import get_wallet_class, settings
-from ...helpers import get_valid_extensions
+from ...helpers import InstallableExtension, get_valid_extensions
from ..crud import (
+ USER_ID_ALL,
create_account,
create_wallet,
delete_wallet,
get_balance_check,
+ get_inactive_extensions,
get_user,
save_balance_notify,
update_user_extension,
@@ -52,35 +55,10 @@ async def extensions(
enable: str = Query(None),
disable: str = Query(None),
):
- extension_to_enable = enable
- extension_to_disable = disable
-
- if extension_to_enable and extension_to_disable:
- raise HTTPException(
- HTTPStatus.BAD_REQUEST, "You can either `enable` or `disable` an extension."
- )
-
- # check if extension exists
- if extension_to_enable or extension_to_disable:
- ext = extension_to_enable or extension_to_disable
- if ext not in [e.code for e in get_valid_extensions()]:
- raise HTTPException(
- HTTPStatus.BAD_REQUEST, f"Extension '{ext}' doesn't exist."
- )
-
- if extension_to_enable:
- logger.info(f"Enabling extension: {extension_to_enable} for user {user.id}")
- await update_user_extension(
- user_id=user.id, extension=extension_to_enable, active=True
- )
- elif extension_to_disable:
- logger.info(f"Disabling extension: {extension_to_disable} for user {user.id}")
- await update_user_extension(
- user_id=user.id, extension=extension_to_disable, active=False
- )
+ await toggle_extension(enable, disable, user.id)
# Update user as his extensions have been updated
- if extension_to_enable or extension_to_disable:
+ if enable or disable:
user = await get_user(user.id) # type: ignore
return template_renderer().TemplateResponse(
@@ -88,6 +66,70 @@ async def extensions(
)
+@core_html_routes.get(
+ "/install", name="install.extensions", response_class=HTMLResponse
+)
+async def extensions_install(
+ request: Request,
+ user: User = Depends(check_user_exists), # type: ignore
+ activate: str = Query(None), # type: ignore
+ deactivate: str = Query(None), # type: ignore
+):
+ if not user.admin:
+ raise HTTPException(
+ status_code=HTTPStatus.UNAUTHORIZED, detail="Only for admin users"
+ )
+
+ try:
+ extension_list: List[InstallableExtension] = await get_installable_extensions()
+ except Exception as ex:
+ logger.warning(ex)
+ raise HTTPException(
+ status_code=HTTPStatus.NOT_FOUND,
+ detail="Cannot fetch installable extension list",
+ )
+
+ try:
+ if deactivate:
+ settings.lnbits_disabled_extensions += [deactivate]
+ elif activate:
+ settings.lnbits_disabled_extensions = list(
+ filter(lambda e: e != activate, settings.lnbits_disabled_extensions)
+ )
+ await toggle_extension(activate, deactivate, USER_ID_ALL)
+
+ installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
+ inactive_extensions = await get_inactive_extensions(user_id=USER_ID_ALL)
+ extensions = list(
+ map(
+ lambda ext: {
+ "id": ext.id,
+ "name": ext.name,
+ "hash": ext.hash,
+ "icon": ext.icon,
+ "shortDescription": ext.short_description,
+ "details": ext.details,
+ "dependencies": ext.dependencies,
+ "isInstalled": ext.id in installed_extensions,
+ "isActive": not ext.id in inactive_extensions,
+ },
+ extension_list,
+ )
+ )
+
+ return template_renderer().TemplateResponse(
+ "core/install.html",
+ {
+ "request": request,
+ "user": user.dict(),
+ "extensions": extensions,
+ },
+ )
+ except Exception as e:
+ logger.warning(e)
+ raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(e))
+
+
@core_html_routes.get(
"/wallet",
response_class=HTMLResponse,
@@ -327,3 +369,29 @@ async def index(request: Request, user: User = Depends(check_admin)):
"balance": balance,
},
)
+
+
+async def toggle_extension(extension_to_enable, extension_to_disable, user_id):
+ if extension_to_enable and extension_to_disable:
+ raise HTTPException(
+ HTTPStatus.BAD_REQUEST, "You can either `enable` or `disable` an extension."
+ )
+
+ # check if extension exists
+ if extension_to_enable or extension_to_disable:
+ ext = extension_to_enable or extension_to_disable
+ if ext not in [e.code for e in get_valid_extensions(True)]:
+ raise HTTPException(
+ HTTPStatus.BAD_REQUEST, f"Extension '{ext}' doesn't exist."
+ )
+
+ if extension_to_enable:
+ logger.info(f"Enabling extension: {extension_to_enable} for user {user_id}")
+ await update_user_extension(
+ user_id=user_id, extension=extension_to_enable, active=True
+ )
+ elif extension_to_disable:
+ logger.info(f"Disabling extension: {extension_to_disable} for user {user_id}")
+ await update_user_extension(
+ user_id=user_id, extension=extension_to_disable, active=False
+ )
diff --git a/lnbits/helpers.py b/lnbits/helpers.py
index 4804bdea..52a7f6ab 100644
--- a/lnbits/helpers.py
+++ b/lnbits/helpers.py
@@ -1,10 +1,13 @@
import glob
import json
import os
+from http import HTTPStatus
from typing import Any, List, NamedTuple, Optional
import jinja2
-import shortuuid
+import shortuuid # type: ignore
+from fastapi.responses import JSONResponse
+from starlette.types import ASGIApp, Receive, Scope, Send
from lnbits.jinja2_templating import Jinja2Templates
from lnbits.requestvars import g
@@ -25,8 +28,10 @@ class Extension(NamedTuple):
class ExtensionManager:
- def __init__(self):
- self._disabled: List[str] = settings.lnbits_disabled_extensions
+ def __init__(self, include_disabled_exts=False):
+ self._disabled: List[str] = (
+ [] if include_disabled_exts else settings.lnbits_disabled_extensions
+ )
self._admin_only: List[str] = settings.lnbits_admin_extensions
self._extension_folders: List[str] = [
x[1] for x in os.walk(os.path.join(settings.lnbits_path, "extensions"))
@@ -74,9 +79,40 @@ class ExtensionManager:
return output
-def get_valid_extensions() -> List[Extension]:
+class EnabledExtensionMiddleware:
+ def __init__(self, app: ASGIApp) -> None:
+ self.app = app
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ pathname = scope["path"].split("/")[1]
+ if pathname in settings.lnbits_disabled_extensions:
+ response = JSONResponse(
+ status_code=HTTPStatus.NOT_FOUND,
+ content={"detail": f"Extension '{pathname}' disabled"},
+ )
+ await response(scope, receive, send)
+ return
+
+ await self.app(scope, receive, send)
+
+
+class InstallableExtension(NamedTuple):
+ id: str
+ name: str
+ archive: str
+ hash: str
+ short_description: Optional[str] = None
+ details: Optional[str] = None
+ icon: Optional[str] = None
+ dependencies: List[str] = []
+ is_admin_only: bool = False
+
+
+def get_valid_extensions(include_disabled_exts=False) -> List[Extension]:
return [
- extension for extension in ExtensionManager().extensions if extension.is_valid
+ extension
+ for extension in ExtensionManager(include_disabled_exts).extensions
+ if extension.is_valid
]
diff --git a/lnbits/settings.py b/lnbits/settings.py
index 6ec4db0c..d00d038d 100644
--- a/lnbits/settings.py
+++ b/lnbits/settings.py
@@ -40,6 +40,7 @@ class UsersSettings(LNbitsSettings):
lnbits_allowed_users: List[str] = Field(default=[])
lnbits_admin_extensions: List[str] = Field(default=[])
lnbits_disabled_extensions: List[str] = Field(default=[])
+ lnbits_extensions_manifests: List[str] = Field(default=[])
class ThemesSettings(LNbitsSettings):
diff --git a/lnbits/static/js/base.js b/lnbits/static/js/base.js
index 32b075b7..d424d563 100644
--- a/lnbits/static/js/base.js
+++ b/lnbits/static/js/base.js
@@ -141,7 +141,8 @@ window.LNbits = {
admin: data.admin,
email: data.email,
extensions: data.extensions,
- wallets: data.wallets
+ wallets: data.wallets,
+ admin: data.admin
}
var mapWallet = this.wallet
obj.wallets = obj.wallets