From 3ed2b3cdeb471526255cf7a1e9712a485f4272c4 Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Wed, 11 Jan 2023 10:25:09 +0200 Subject: [PATCH] refactor: move more logic to `InstallableExtension` --- lnbits/core/helpers.py | 104 ----------------------------------- lnbits/core/views/api.py | 41 ++++++-------- lnbits/core/views/generic.py | 5 +- lnbits/helpers.py | 103 +++++++++++++++++++++++++++++++++- 4 files changed, 121 insertions(+), 132 deletions(-) diff --git a/lnbits/core/helpers.py b/lnbits/core/helpers.py index 1dac5955..9c198b2b 100644 --- a/lnbits/core/helpers.py +++ b/lnbits/core/helpers.py @@ -1,18 +1,8 @@ -import hashlib import importlib -import os import re -import urllib.request -from http import HTTPStatus -from typing import List -import httpx -from fastapi.exceptions import HTTPException from loguru import logger -from lnbits.helpers import InstallableExtension, get_valid_extensions -from lnbits.settings import settings - from . import db as core_db from .crud import update_migration_version @@ -48,97 +38,3 @@ async def run_migration(db, migrations_module, current_version): else: async with core_db.connect() as conn: await update_migration_version(conn, db_name, version) - - -async def get_installable_extensions() -> List[InstallableExtension]: - extension_list: List[InstallableExtension] = [] - - async with httpx.AsyncClient() as client: - for url in settings.lnbits_extensions_manifests: - resp = await client.get(url) - if resp.status_code != 200: - raise HTTPException( - status_code=404, - detail=f"Unable to fetch extension list for repository: {url}", - ) - for e in resp.json()["extensions"]: - extension_list += [ - InstallableExtension( - id=e["id"], - name=e["name"], - archive=e["archive"], - hash=e["hash"], - short_description=e["shortDescription"], - details=e["details"] if "details" in e else "", - icon=e["icon"], - dependencies=e["dependencies"] if "dependencies" in e else [], - ) - ] - - return extension_list - - -async def get_installable_extension_meta( - ext_id: str, hash: str -) -> InstallableExtension: - installable_extensions: List[ - InstallableExtension - ] = await get_installable_extensions() - - valid_extensions = [ - e for e in installable_extensions if e.id == ext_id and e.hash == hash - ] - if len(valid_extensions) == 0: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail=f"Unknown extension id: {ext_id}", - ) - extension = valid_extensions[0] - - # check that all dependecies are installed - installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True))) - if not set(extension.dependencies).issubset(installed_extensions): - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail=f"Not all dependencies are installed: {extension.dependencies}", - ) - - return extension - - -def download_extension_archive(archive: str, ext_zip_file: str, hash: str): - if os.path.isfile(ext_zip_file): - os.remove(ext_zip_file) - try: - download_url(archive, ext_zip_file) - except Exception as ex: - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail="Cannot fetch extension archive file", - ) - - archive_hash = file_hash(ext_zip_file) - if hash != archive_hash: - # remove downloaded archive - if os.path.isfile(ext_zip_file): - os.remove(ext_zip_file) - raise HTTPException( - status_code=HTTPStatus.NOT_FOUND, - detail="File hash missmatch. Will not install.", - ) - - -def download_url(url, save_path): - with urllib.request.urlopen(url) as dl_file: - with open(save_path, "wb") as out_file: - out_file.write(dl_file.read()) - - -def file_hash(filename): - h = hashlib.sha256() - b = bytearray(128 * 1024) - mv = memoryview(b) - with open(filename, "rb", buffering=0) as f: - while n := f.readinto(mv): - h.update(mv[:n]) - return h.hexdigest() diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index b22ca838..fde40275 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -1,6 +1,5 @@ import asyncio import hashlib -import importlib import inspect import json import os @@ -30,26 +29,18 @@ from fastapi import ( ) from fastapi.exceptions import HTTPException from fastapi.params import Body -from genericpath import isfile from loguru import logger from pydantic import BaseModel from pydantic.fields import Field -from sse_starlette.sse import EventSourceResponse, ServerSentEvent +from sse_starlette.sse import EventSourceResponse from starlette.responses import StreamingResponse from lnbits import bolt11, lnurl -from lnbits.core.helpers import ( - download_extension_archive, - file_hash, - get_installable_extension_meta, - get_installable_extensions, - migrate_extension_database, -) +from lnbits.core.helpers import migrate_extension_database from lnbits.core.models import Payment, User, Wallet from lnbits.decorators import ( WalletTypeInfo, check_admin, - check_user_exists, get_key_type, require_admin_key, require_invoice_key, @@ -737,34 +728,34 @@ async def websocket_update_get(item_id: str, data: str): async def api_install_extension( ext_id: str, hash: str, user: User = Depends(check_admin) ): - - ext_meta: InstallableExtension = await get_installable_extension_meta(ext_id, hash) - - download_extension_archive(ext_meta.archive, ext_meta.zip_path, ext_meta.hash) + ext_info: InstallableExtension = await InstallableExtension.get_extension_info( + ext_id, hash + ) + ext_info.download_archive() try: ext_dir = os.path.join("lnbits/extensions", ext_id) shutil.rmtree(ext_dir, True) - with zipfile.ZipFile(ext_meta.zip_path, "r") as zip_ref: + with zipfile.ZipFile(ext_info.zip_path, "r") as zip_ref: zip_ref.extractall("lnbits/extensions") ext_upgrade_dir = os.path.join( - "lnbits/upgrades", f"{ext_meta.id}-{ext_meta.hash}" + "lnbits/upgrades", f"{ext_info.id}-{ext_info.hash}" ) os.makedirs("lnbits/upgrades", exist_ok=True) shutil.rmtree(ext_upgrade_dir, True) - with zipfile.ZipFile(ext_meta.zip_path, "r") as zip_ref: + with zipfile.ZipFile(ext_info.zip_path, "r") as zip_ref: zip_ref.extractall(ext_upgrade_dir) module_name = f"lnbits.extensions.{ext_id}" module_installed = module_name in sys.modules # todo: is admin only ext = Extension( - code=ext_meta.id, + code=ext_info.id, is_valid=True, is_admin_only=False, - name=ext_meta.name, - hash=ext_meta.hash if module_installed else "", + name=ext_info.name, + hash=ext_info.hash if module_installed else "", ) current_versions = await get_dbversions() @@ -791,8 +782,8 @@ async def api_install_extension( except Exception as ex: logger.warning(ex) # remove downloaded archive - if os.path.isfile(ext_meta.zip_path): - os.remove(ext_meta.zip_path) + if os.path.isfile(ext_info.zip_path): + os.remove(ext_info.zip_path) # remove module from extensions shutil.rmtree(ext_dir, True) @@ -804,7 +795,9 @@ async def api_install_extension( @core_app.delete("/api/v1/extension/{ext_id}") async def api_uninstall_extension(ext_id: str, user: User = Depends(check_admin)): try: - extension_list: List[InstallableExtension] = await get_installable_extensions() + extension_list: List[ + InstallableExtension + ] = await InstallableExtension.get_installable_extensions() except Exception as ex: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index d14a43f6..04862e62 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -11,7 +11,6 @@ from pydantic.types import UUID4 from starlette.responses import HTMLResponse, JSONResponse from lnbits.core import db -from lnbits.core.helpers import get_installable_extensions from lnbits.core.models import User from lnbits.decorators import check_admin, check_user_exists from lnbits.helpers import template_renderer, url_for @@ -81,7 +80,9 @@ async def extensions_install( ) try: - extension_list: List[InstallableExtension] = await get_installable_extensions() + extension_list: List[ + InstallableExtension + ] = await InstallableExtension.get_installable_extensions() except Exception as ex: logger.warning(ex) raise HTTPException( diff --git a/lnbits/helpers.py b/lnbits/helpers.py index 522e9b00..cece852d 100644 --- a/lnbits/helpers.py +++ b/lnbits/helpers.py @@ -1,12 +1,18 @@ import glob +import hashlib import json import os +import shutil +import urllib.request from http import HTTPStatus from typing import Any, List, NamedTuple, Optional +import httpx import jinja2 import shortuuid # type: ignore +from fastapi.exceptions import HTTPException from fastapi.responses import JSONResponse +from loguru import logger from starlette.types import ASGIApp, Receive, Scope, Send from lnbits.jinja2_templating import Jinja2Templates @@ -52,10 +58,87 @@ class InstallableExtension(NamedTuple): def zip_path(self): extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions") os.makedirs(extensions_data_dir, exist_ok=True) - ext_data_dir = os.path.join(extensions_data_dir, self.id) - shutil.rmtree(ext_data_dir, True) return os.path.join(extensions_data_dir, f"{self.id}.zip") + def download_archive(self): + ext_zip_file = self.zip_path + if os.path.isfile(ext_zip_file): + os.remove(ext_zip_file) + try: + download_url(self.archive, ext_zip_file) + except Exception as ex: + logger.warning(ex) + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="Cannot fetch extension archive file", + ) + + archive_hash = file_hash(ext_zip_file) + if self.hash != archive_hash: + # remove downloaded archive + if os.path.isfile(ext_zip_file): + os.remove(ext_zip_file) + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="File hash missmatch. Will not install.", + ) + + @classmethod + async def get_extension_info(cls, ext_id: str, hash: str) -> "InstallableExtension": + installable_extensions: List[ + InstallableExtension + ] = await InstallableExtension.get_installable_extensions() + + valid_extensions = [ + e for e in installable_extensions if e.id == ext_id and e.hash == hash + ] + if len(valid_extensions) == 0: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"Unknown extension id: {ext_id}", + ) + extension = valid_extensions[0] + + # check that all dependecies are installed + installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True))) + if not set(extension.dependencies).issubset(installed_extensions): + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail=f"Not all dependencies are installed: {extension.dependencies}", + ) + + return extension + + @classmethod + async def get_installable_extensions(cls) -> List["InstallableExtension"]: + extension_list: List[InstallableExtension] = [] + + async with httpx.AsyncClient() as client: + for url in settings.lnbits_extensions_manifests: + resp = await client.get(url) + if resp.status_code != 200: + raise HTTPException( + status_code=404, + detail=f"Unable to fetch extension list for repository: {url}", + ) + for e in resp.json()["extensions"]: + extension_list += [ + InstallableExtension( + id=e["id"], + name=e["name"], + archive=e["archive"], + hash=e["hash"], + short_description=e["shortDescription"], + details=e["details"] if "details" in e else "", + icon=e["icon"], + dependencies=e["dependencies"] + if "dependencies" in e + else [], + ) + ] + + return extension_list + class ExtensionManager: def __init__(self, include_disabled_exts=False): @@ -289,3 +372,19 @@ def get_current_extension_name() -> str: except: ext_name = extension_director_name return ext_name + + +def download_url(url, save_path): + with urllib.request.urlopen(url) as dl_file: + with open(save_path, "wb") as out_file: + out_file.write(dl_file.read()) + + +def file_hash(filename): + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(filename, "rb", buffering=0) as f: + while n := f.readinto(mv): + h.update(mv[:n]) + return h.hexdigest()