refactor: move more logic to InstallableExtension
This commit is contained in:
parent
1b07768b76
commit
e7a150e708
3 changed files with 36 additions and 24 deletions
|
|
@ -3,11 +3,13 @@ import re
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from lnbits.helpers import Extension
|
||||
|
||||
from . import db as core_db
|
||||
from .crud import update_migration_version
|
||||
|
||||
|
||||
async def migrate_extension_database(ext, current_version):
|
||||
async def migrate_extension_database(ext: Extension, current_version):
|
||||
try:
|
||||
ext_migrations = importlib.import_module(f"{ext.module_name}.migrations")
|
||||
ext_db = importlib.import_module(ext.module_name).db
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import shutil
|
|||
import sys
|
||||
import time
|
||||
import uuid
|
||||
import zipfile
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
|
@ -731,42 +730,28 @@ async def api_install_extension(
|
|||
ext_info: InstallableExtension = await InstallableExtension.get_extension_info(
|
||||
ext_id, hash
|
||||
)
|
||||
ext_info.download_archive()
|
||||
|
||||
try:
|
||||
ext_dir = os.path.join("lnbits/extensions", ext_id)
|
||||
shutil.rmtree(ext_dir, True)
|
||||
with zipfile.ZipFile(ext_info.zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall("lnbits/extensions")
|
||||
ext_info.download_archive()
|
||||
ext_info.extract_archive()
|
||||
|
||||
ext_upgrade_dir = os.path.join(
|
||||
"lnbits/upgrades", f"{ext_info.id}-{ext_info.hash}"
|
||||
)
|
||||
os.makedirs("lnbits/upgrades", exist_ok=True)
|
||||
shutil.rmtree(ext_upgrade_dir, True)
|
||||
with zipfile.ZipFile(ext_info.zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(ext_upgrade_dir)
|
||||
|
||||
module_name = f"lnbits.extensions.{ext_id}"
|
||||
module_installed = module_name in sys.modules
|
||||
# todo: is admin only
|
||||
ext = Extension(
|
||||
code=ext_info.id,
|
||||
is_valid=True,
|
||||
is_admin_only=False,
|
||||
name=ext_info.name,
|
||||
hash=ext_info.hash if module_installed else "",
|
||||
hash=ext_info.hash if ext_info.module_installed else "",
|
||||
)
|
||||
|
||||
current_versions = await get_dbversions()
|
||||
current_version = current_versions.get(ext.code, 0)
|
||||
await migrate_extension_database(ext, current_version) # todo: use new module
|
||||
db_version = (await get_dbversions()).get(ext.code, 0)
|
||||
await migrate_extension_database(ext, db_version) # todo: use new module
|
||||
|
||||
# disable by default
|
||||
await update_user_extension(user_id=USER_ID_ALL, extension=ext_id, active=False)
|
||||
settings.lnbits_disabled_extensions += [ext_id]
|
||||
|
||||
if module_installed:
|
||||
if ext_info.module_installed:
|
||||
# update upgraded extensions list if module already installed
|
||||
ext_temp_path = f"{ext.hash}/{ext.code}"
|
||||
clean_upgraded_exts = list(
|
||||
|
|
@ -786,7 +771,7 @@ async def api_install_extension(
|
|||
os.remove(ext_info.zip_path)
|
||||
|
||||
# remove module from extensions
|
||||
shutil.rmtree(ext_dir, True)
|
||||
shutil.rmtree(ext_info.ext_dir, True)
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(ex)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ import hashlib
|
|||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from http import HTTPStatus
|
||||
from typing import Any, List, NamedTuple, Optional
|
||||
|
||||
|
|
@ -55,11 +57,23 @@ class InstallableExtension(NamedTuple):
|
|||
version: Optional[int] = 0
|
||||
|
||||
@property
|
||||
def zip_path(self):
|
||||
def zip_path(self) -> str:
|
||||
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
|
||||
os.makedirs(extensions_data_dir, exist_ok=True)
|
||||
return os.path.join(extensions_data_dir, f"{self.id}.zip")
|
||||
|
||||
@property
|
||||
def ext_dir(self) -> str:
|
||||
return os.path.join("lnbits", "extensions", self.id)
|
||||
|
||||
@property
|
||||
def module_name(self) -> str:
|
||||
return f"lnbits.extensions.{self.id}"
|
||||
|
||||
@property
|
||||
def module_installed(self) -> bool:
|
||||
return self.module_name in sys.modules
|
||||
|
||||
def download_archive(self):
|
||||
ext_zip_file = self.zip_path
|
||||
if os.path.isfile(ext_zip_file):
|
||||
|
|
@ -83,6 +97,17 @@ class InstallableExtension(NamedTuple):
|
|||
detail="File hash missmatch. Will not install.",
|
||||
)
|
||||
|
||||
def extract_archive(self):
|
||||
shutil.rmtree(self.ext_dir, True)
|
||||
with zipfile.ZipFile(self.zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(os.path.join("lnbits", "extensions"))
|
||||
|
||||
ext_upgrade_dir = os.path.join("lnbits", "upgrades", f"{self.id}-{self.hash}")
|
||||
os.makedirs(os.path.join("lnbits", "upgrades"), exist_ok=True)
|
||||
shutil.rmtree(ext_upgrade_dir, True)
|
||||
with zipfile.ZipFile(self.zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(ext_upgrade_dir)
|
||||
|
||||
@classmethod
|
||||
async def get_extension_info(cls, ext_id: str, hash: str) -> "InstallableExtension":
|
||||
installable_extensions: List[
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue