refactor: move more logic to InstallableExtension

This commit is contained in:
Vlad Stan 2023-01-11 10:57:19 +02:00
parent 1b07768b76
commit e7a150e708
3 changed files with 36 additions and 24 deletions

View file

@ -3,11 +3,13 @@ import re
from loguru import logger
from lnbits.helpers import Extension
from . import db as core_db
from .crud import update_migration_version
async def migrate_extension_database(ext, current_version):
async def migrate_extension_database(ext: Extension, current_version):
try:
ext_migrations = importlib.import_module(f"{ext.module_name}.migrations")
ext_db = importlib.import_module(ext.module_name).db

View file

@ -7,7 +7,6 @@ import shutil
import sys
import time
import uuid
import zipfile
from http import HTTPStatus
from io import BytesIO
from pathlib import Path
@ -731,42 +730,28 @@ async def api_install_extension(
ext_info: InstallableExtension = await InstallableExtension.get_extension_info(
ext_id, hash
)
ext_info.download_archive()
try:
ext_dir = os.path.join("lnbits/extensions", ext_id)
shutil.rmtree(ext_dir, True)
with zipfile.ZipFile(ext_info.zip_path, "r") as zip_ref:
zip_ref.extractall("lnbits/extensions")
ext_info.download_archive()
ext_info.extract_archive()
ext_upgrade_dir = os.path.join(
"lnbits/upgrades", f"{ext_info.id}-{ext_info.hash}"
)
os.makedirs("lnbits/upgrades", exist_ok=True)
shutil.rmtree(ext_upgrade_dir, True)
with zipfile.ZipFile(ext_info.zip_path, "r") as zip_ref:
zip_ref.extractall(ext_upgrade_dir)
module_name = f"lnbits.extensions.{ext_id}"
module_installed = module_name in sys.modules
# todo: is admin only
ext = Extension(
code=ext_info.id,
is_valid=True,
is_admin_only=False,
name=ext_info.name,
hash=ext_info.hash if module_installed else "",
hash=ext_info.hash if ext_info.module_installed else "",
)
current_versions = await get_dbversions()
current_version = current_versions.get(ext.code, 0)
await migrate_extension_database(ext, current_version) # todo: use new module
db_version = (await get_dbversions()).get(ext.code, 0)
await migrate_extension_database(ext, db_version) # todo: use new module
# disable by default
await update_user_extension(user_id=USER_ID_ALL, extension=ext_id, active=False)
settings.lnbits_disabled_extensions += [ext_id]
if module_installed:
if ext_info.module_installed:
# update upgraded extensions list if module already installed
ext_temp_path = f"{ext.hash}/{ext.code}"
clean_upgraded_exts = list(
@ -786,7 +771,7 @@ async def api_install_extension(
os.remove(ext_info.zip_path)
# remove module from extensions
shutil.rmtree(ext_dir, True)
shutil.rmtree(ext_info.ext_dir, True)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(ex)
)

View file

@ -3,7 +3,9 @@ import hashlib
import json
import os
import shutil
import sys
import urllib.request
import zipfile
from http import HTTPStatus
from typing import Any, List, NamedTuple, Optional
@ -55,11 +57,23 @@ class InstallableExtension(NamedTuple):
version: Optional[int] = 0
@property
def zip_path(self):
def zip_path(self) -> str:
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
os.makedirs(extensions_data_dir, exist_ok=True)
return os.path.join(extensions_data_dir, f"{self.id}.zip")
@property
def ext_dir(self) -> str:
return os.path.join("lnbits", "extensions", self.id)
@property
def module_name(self) -> str:
return f"lnbits.extensions.{self.id}"
@property
def module_installed(self) -> bool:
return self.module_name in sys.modules
def download_archive(self):
ext_zip_file = self.zip_path
if os.path.isfile(ext_zip_file):
@ -83,6 +97,17 @@ class InstallableExtension(NamedTuple):
detail="File hash missmatch. Will not install.",
)
def extract_archive(self):
shutil.rmtree(self.ext_dir, True)
with zipfile.ZipFile(self.zip_path, "r") as zip_ref:
zip_ref.extractall(os.path.join("lnbits", "extensions"))
ext_upgrade_dir = os.path.join("lnbits", "upgrades", f"{self.id}-{self.hash}")
os.makedirs(os.path.join("lnbits", "upgrades"), exist_ok=True)
shutil.rmtree(ext_upgrade_dir, True)
with zipfile.ZipFile(self.zip_path, "r") as zip_ref:
zip_ref.extractall(ext_upgrade_dir)
@classmethod
async def get_extension_info(cls, ext_id: str, hash: str) -> "InstallableExtension":
installable_extensions: List[