refactor: move more logic to InstallableExtension

This commit is contained in:
Vlad Stan 2023-01-11 10:57:19 +02:00
parent 1b07768b76
commit e7a150e708
3 changed files with 36 additions and 24 deletions

View file

@ -3,11 +3,13 @@ import re
from loguru import logger from loguru import logger
from lnbits.helpers import Extension
from . import db as core_db from . import db as core_db
from .crud import update_migration_version from .crud import update_migration_version
async def migrate_extension_database(ext, current_version): async def migrate_extension_database(ext: Extension, current_version):
try: try:
ext_migrations = importlib.import_module(f"{ext.module_name}.migrations") ext_migrations = importlib.import_module(f"{ext.module_name}.migrations")
ext_db = importlib.import_module(ext.module_name).db ext_db = importlib.import_module(ext.module_name).db

View file

@ -7,7 +7,6 @@ import shutil
import sys import sys
import time import time
import uuid import uuid
import zipfile
from http import HTTPStatus from http import HTTPStatus
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
@ -731,42 +730,28 @@ async def api_install_extension(
ext_info: InstallableExtension = await InstallableExtension.get_extension_info( ext_info: InstallableExtension = await InstallableExtension.get_extension_info(
ext_id, hash ext_id, hash
) )
ext_info.download_archive()
try: try:
ext_dir = os.path.join("lnbits/extensions", ext_id) ext_info.download_archive()
shutil.rmtree(ext_dir, True) ext_info.extract_archive()
with zipfile.ZipFile(ext_info.zip_path, "r") as zip_ref:
zip_ref.extractall("lnbits/extensions")
ext_upgrade_dir = os.path.join(
"lnbits/upgrades", f"{ext_info.id}-{ext_info.hash}"
)
os.makedirs("lnbits/upgrades", exist_ok=True)
shutil.rmtree(ext_upgrade_dir, True)
with zipfile.ZipFile(ext_info.zip_path, "r") as zip_ref:
zip_ref.extractall(ext_upgrade_dir)
module_name = f"lnbits.extensions.{ext_id}"
module_installed = module_name in sys.modules
# todo: is admin only # todo: is admin only
ext = Extension( ext = Extension(
code=ext_info.id, code=ext_info.id,
is_valid=True, is_valid=True,
is_admin_only=False, is_admin_only=False,
name=ext_info.name, name=ext_info.name,
hash=ext_info.hash if module_installed else "", hash=ext_info.hash if ext_info.module_installed else "",
) )
current_versions = await get_dbversions() db_version = (await get_dbversions()).get(ext.code, 0)
current_version = current_versions.get(ext.code, 0) await migrate_extension_database(ext, db_version) # todo: use new module
await migrate_extension_database(ext, current_version) # todo: use new module
# disable by default # disable by default
await update_user_extension(user_id=USER_ID_ALL, extension=ext_id, active=False) await update_user_extension(user_id=USER_ID_ALL, extension=ext_id, active=False)
settings.lnbits_disabled_extensions += [ext_id] settings.lnbits_disabled_extensions += [ext_id]
if module_installed: if ext_info.module_installed:
# update upgraded extensions list if module already installed # update upgraded extensions list if module already installed
ext_temp_path = f"{ext.hash}/{ext.code}" ext_temp_path = f"{ext.hash}/{ext.code}"
clean_upgraded_exts = list( clean_upgraded_exts = list(
@ -786,7 +771,7 @@ async def api_install_extension(
os.remove(ext_info.zip_path) os.remove(ext_info.zip_path)
# remove module from extensions # remove module from extensions
shutil.rmtree(ext_dir, True) shutil.rmtree(ext_info.ext_dir, True)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(ex) status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(ex)
) )

View file

@ -3,7 +3,9 @@ import hashlib
import json import json
import os import os
import shutil import shutil
import sys
import urllib.request import urllib.request
import zipfile
from http import HTTPStatus from http import HTTPStatus
from typing import Any, List, NamedTuple, Optional from typing import Any, List, NamedTuple, Optional
@ -55,11 +57,23 @@ class InstallableExtension(NamedTuple):
version: Optional[int] = 0 version: Optional[int] = 0
@property @property
def zip_path(self): def zip_path(self) -> str:
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions") extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
os.makedirs(extensions_data_dir, exist_ok=True) os.makedirs(extensions_data_dir, exist_ok=True)
return os.path.join(extensions_data_dir, f"{self.id}.zip") return os.path.join(extensions_data_dir, f"{self.id}.zip")
@property
def ext_dir(self) -> str:
return os.path.join("lnbits", "extensions", self.id)
@property
def module_name(self) -> str:
return f"lnbits.extensions.{self.id}"
@property
def module_installed(self) -> bool:
return self.module_name in sys.modules
def download_archive(self): def download_archive(self):
ext_zip_file = self.zip_path ext_zip_file = self.zip_path
if os.path.isfile(ext_zip_file): if os.path.isfile(ext_zip_file):
@ -83,6 +97,17 @@ class InstallableExtension(NamedTuple):
detail="File hash missmatch. Will not install.", detail="File hash missmatch. Will not install.",
) )
def extract_archive(self):
shutil.rmtree(self.ext_dir, True)
with zipfile.ZipFile(self.zip_path, "r") as zip_ref:
zip_ref.extractall(os.path.join("lnbits", "extensions"))
ext_upgrade_dir = os.path.join("lnbits", "upgrades", f"{self.id}-{self.hash}")
os.makedirs(os.path.join("lnbits", "upgrades"), exist_ok=True)
shutil.rmtree(ext_upgrade_dir, True)
with zipfile.ZipFile(self.zip_path, "r") as zip_ref:
zip_ref.extractall(ext_upgrade_dir)
@classmethod @classmethod
async def get_extension_info(cls, ext_id: str, hash: str) -> "InstallableExtension": async def get_extension_info(cls, ext_id: str, hash: str) -> "InstallableExtension":
installable_extensions: List[ installable_extensions: List[