diff --git a/lnbits/core/helpers.py b/lnbits/core/helpers.py index 7ff72d32..1dac5955 100644 --- a/lnbits/core/helpers.py +++ b/lnbits/core/helpers.py @@ -1,14 +1,16 @@ import hashlib import importlib +import os import re import urllib.request +from http import HTTPStatus from typing import List import httpx from fastapi.exceptions import HTTPException from loguru import logger -from lnbits.helpers import InstallableExtension +from lnbits.helpers import InstallableExtension, get_valid_extensions from lnbits.settings import settings from . import db as core_db @@ -76,6 +78,56 @@ async def get_installable_extensions() -> List[InstallableExtension]: return extension_list +async def get_installable_extension_meta( + ext_id: str, hash: str +) -> InstallableExtension: + installable_extensions: List[ + InstallableExtension + ] = await get_installable_extensions() + + valid_extensions = [ + e for e in installable_extensions if e.id == ext_id and e.hash == hash + ] + if len(valid_extensions) == 0: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"Unknown extension id: {ext_id}", + ) + extension = valid_extensions[0] + + # check that all dependecies are installed + installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True))) + if not set(extension.dependencies).issubset(installed_extensions): + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail=f"Not all dependencies are installed: {extension.dependencies}", + ) + + return extension + + +def download_extension_archive(archive: str, ext_zip_file: str, hash: str): + if os.path.isfile(ext_zip_file): + os.remove(ext_zip_file) + try: + download_url(archive, ext_zip_file) + except Exception as ex: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="Cannot fetch extension archive file", + ) + + archive_hash = file_hash(ext_zip_file) + if hash != archive_hash: + # remove downloaded archive + if os.path.isfile(ext_zip_file): + os.remove(ext_zip_file) + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="File hash missmatch. Will not install.", + ) + + def download_url(url, save_path): with urllib.request.urlopen(url) as dl_file: with open(save_path, "wb") as out_file: diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index 690a3a7e..b22ca838 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -39,8 +39,9 @@ from starlette.responses import StreamingResponse from lnbits import bolt11, lnurl from lnbits.core.helpers import ( - download_url, + download_extension_archive, file_hash, + get_installable_extension_meta, get_installable_extensions, migrate_extension_database, ) @@ -736,80 +737,34 @@ async def websocket_update_get(item_id: str, data: str): async def api_install_extension( ext_id: str, hash: str, user: User = Depends(check_admin) ): - try: - extension_list: List[InstallableExtension] = await get_installable_extensions() - except Exception as ex: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail="Cannot fetch installable extension list", - ) - extensions = [e for e in extension_list if e.id == ext_id and e.hash == hash] - if len(extensions) == 0: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail=f"Unknown extension id: {ext_id}", - ) - extension = extensions[0] + ext_meta: InstallableExtension = await get_installable_extension_meta(ext_id, hash) - # check that all dependecies are installed - installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True))) - if not set(extension.dependencies).issubset(installed_extensions): - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail=f"Not all dependencies are installed: {extension.dependencies}", - ) - - # move files to the right location - extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions") - os.makedirs(extensions_data_dir, exist_ok=True) - ext_data_dir = os.path.join(extensions_data_dir, ext_id) - shutil.rmtree(ext_data_dir, True) - ext_zip_file = os.path.join(extensions_data_dir, f"{ext_id}.zip") - if os.path.isfile(ext_zip_file): - os.remove(ext_zip_file) - - try: - download_url(extension.archive, ext_zip_file) - except Exception as ex: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail="Cannot fetch extension archive file", - ) - - archive_hash = file_hash(ext_zip_file) - if extension.hash != archive_hash: - # remove downloaded archive - if os.path.isfile(ext_zip_file): - os.remove(ext_zip_file) - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail="File hash missmatch. Will not install.", - ) + download_extension_archive(ext_meta.archive, ext_meta.zip_path, ext_meta.hash) try: ext_dir = os.path.join("lnbits/extensions", ext_id) shutil.rmtree(ext_dir, True) - with zipfile.ZipFile(ext_zip_file, "r") as zip_ref: + with zipfile.ZipFile(ext_meta.zip_path, "r") as zip_ref: zip_ref.extractall("lnbits/extensions") ext_upgrade_dir = os.path.join( - "lnbits/upgrades", f"{extension.id}-{extension.hash}" + "lnbits/upgrades", f"{ext_meta.id}-{ext_meta.hash}" ) os.makedirs("lnbits/upgrades", exist_ok=True) shutil.rmtree(ext_upgrade_dir, True) - with zipfile.ZipFile(ext_zip_file, "r") as zip_ref: + with zipfile.ZipFile(ext_meta.zip_path, "r") as zip_ref: zip_ref.extractall(ext_upgrade_dir) module_name = f"lnbits.extensions.{ext_id}" module_installed = module_name in sys.modules # todo: is admin only ext = Extension( - code=extension.id, + code=ext_meta.id, is_valid=True, is_admin_only=False, - name=extension.name, - hash=extension.hash if module_installed else "", + name=ext_meta.name, + hash=ext_meta.hash if module_installed else "", ) current_versions = await get_dbversions() @@ -836,8 +791,8 @@ async def api_install_extension( except Exception as ex: logger.warning(ex) # remove downloaded archive - if os.path.isfile(ext_zip_file): - os.remove(ext_zip_file) + if os.path.isfile(ext_meta.zip_path): + os.remove(ext_meta.zip_path) # remove module from extensions shutil.rmtree(ext_dir, True) diff --git a/lnbits/helpers.py b/lnbits/helpers.py index a9e583a7..522e9b00 100644 --- a/lnbits/helpers.py +++ b/lnbits/helpers.py @@ -48,6 +48,14 @@ class InstallableExtension(NamedTuple): is_admin_only: bool = False version: Optional[int] = 0 + @property + def zip_path(self): + extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions") + os.makedirs(extensions_data_dir, exist_ok=True) + ext_data_dir = os.path.join(extensions_data_dir, self.id) + shutil.rmtree(ext_data_dir, True) + return os.path.join(extensions_data_dir, f"{self.id}.zip") + class ExtensionManager: def __init__(self, include_disabled_exts=False):