From e7a150e708935282b47e7e107faf4d75686a3889 Mon Sep 17 00:00:00 2001 From: Vlad Stan Date: Wed, 11 Jan 2023 10:57:19 +0200 Subject: [PATCH] refactor: move more logic to `InstallableExtension` --- lnbits/core/helpers.py | 4 +++- lnbits/core/views/api.py | 29 +++++++---------------------- lnbits/helpers.py | 27 ++++++++++++++++++++++++++- 3 files changed, 36 insertions(+), 24 deletions(-) diff --git a/lnbits/core/helpers.py b/lnbits/core/helpers.py index 9c198b2b..4cd5edbc 100644 --- a/lnbits/core/helpers.py +++ b/lnbits/core/helpers.py @@ -3,11 +3,13 @@ import re from loguru import logger +from lnbits.helpers import Extension + from . import db as core_db from .crud import update_migration_version -async def migrate_extension_database(ext, current_version): +async def migrate_extension_database(ext: Extension, current_version): try: ext_migrations = importlib.import_module(f"{ext.module_name}.migrations") ext_db = importlib.import_module(ext.module_name).db diff --git a/lnbits/core/views/api.py b/lnbits/core/views/api.py index fde40275..649305e0 100644 --- a/lnbits/core/views/api.py +++ b/lnbits/core/views/api.py @@ -7,7 +7,6 @@ import shutil import sys import time import uuid -import zipfile from http import HTTPStatus from io import BytesIO from pathlib import Path @@ -731,42 +730,28 @@ async def api_install_extension( ext_info: InstallableExtension = await InstallableExtension.get_extension_info( ext_id, hash ) - ext_info.download_archive() try: - ext_dir = os.path.join("lnbits/extensions", ext_id) - shutil.rmtree(ext_dir, True) - with zipfile.ZipFile(ext_info.zip_path, "r") as zip_ref: - zip_ref.extractall("lnbits/extensions") + ext_info.download_archive() + ext_info.extract_archive() - ext_upgrade_dir = os.path.join( - "lnbits/upgrades", f"{ext_info.id}-{ext_info.hash}" - ) - os.makedirs("lnbits/upgrades", exist_ok=True) - shutil.rmtree(ext_upgrade_dir, True) - with zipfile.ZipFile(ext_info.zip_path, "r") as zip_ref: - zip_ref.extractall(ext_upgrade_dir) - - module_name = f"lnbits.extensions.{ext_id}" - module_installed = module_name in sys.modules # todo: is admin only ext = Extension( code=ext_info.id, is_valid=True, is_admin_only=False, name=ext_info.name, - hash=ext_info.hash if module_installed else "", + hash=ext_info.hash if ext_info.module_installed else "", ) - current_versions = await get_dbversions() - current_version = current_versions.get(ext.code, 0) - await migrate_extension_database(ext, current_version) # todo: use new module + db_version = (await get_dbversions()).get(ext.code, 0) + await migrate_extension_database(ext, db_version) # todo: use new module # disable by default await update_user_extension(user_id=USER_ID_ALL, extension=ext_id, active=False) settings.lnbits_disabled_extensions += [ext_id] - if module_installed: + if ext_info.module_installed: # update upgraded extensions list if module already installed ext_temp_path = f"{ext.hash}/{ext.code}" clean_upgraded_exts = list( @@ -786,7 +771,7 @@ async def api_install_extension( os.remove(ext_info.zip_path) # remove module from extensions - shutil.rmtree(ext_dir, True) + shutil.rmtree(ext_info.ext_dir, True) raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(ex) ) diff --git a/lnbits/helpers.py b/lnbits/helpers.py index cece852d..d8d4bd3d 100644 --- a/lnbits/helpers.py +++ b/lnbits/helpers.py @@ -3,7 +3,9 @@ import hashlib import json import os import shutil +import sys import urllib.request +import zipfile from http import HTTPStatus from typing import Any, List, NamedTuple, Optional @@ -55,11 +57,23 @@ class InstallableExtension(NamedTuple): version: Optional[int] = 0 @property - def zip_path(self): + def zip_path(self) -> str: extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions") os.makedirs(extensions_data_dir, exist_ok=True) return os.path.join(extensions_data_dir, f"{self.id}.zip") + @property + def ext_dir(self) -> str: + return os.path.join("lnbits", "extensions", self.id) + + @property + def module_name(self) -> str: + return f"lnbits.extensions.{self.id}" + + @property + def module_installed(self) -> bool: + return self.module_name in sys.modules + def download_archive(self): ext_zip_file = self.zip_path if os.path.isfile(ext_zip_file): @@ -83,6 +97,17 @@ class InstallableExtension(NamedTuple): detail="File hash missmatch. Will not install.", ) + def extract_archive(self): + shutil.rmtree(self.ext_dir, True) + with zipfile.ZipFile(self.zip_path, "r") as zip_ref: + zip_ref.extractall(os.path.join("lnbits", "extensions")) + + ext_upgrade_dir = os.path.join("lnbits", "upgrades", f"{self.id}-{self.hash}") + os.makedirs(os.path.join("lnbits", "upgrades"), exist_ok=True) + shutil.rmtree(ext_upgrade_dir, True) + with zipfile.ZipFile(self.zip_path, "r") as zip_ref: + zip_ref.extractall(ext_upgrade_dir) + @classmethod async def get_extension_info(cls, ext_id: str, hash: str) -> "InstallableExtension": installable_extensions: List[