refactor: extract dome methods to helpers
This commit is contained in:
parent
cae71faf37
commit
cb6349fd76
3 changed files with 73 additions and 58 deletions
|
|
@ -1,14 +1,16 @@
|
|||
import hashlib
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import urllib.request
|
||||
from http import HTTPStatus
|
||||
from typing import List
|
||||
|
||||
import httpx
|
||||
from fastapi.exceptions import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from lnbits.helpers import InstallableExtension
|
||||
from lnbits.helpers import InstallableExtension, get_valid_extensions
|
||||
from lnbits.settings import settings
|
||||
|
||||
from . import db as core_db
|
||||
|
|
@ -76,6 +78,56 @@ async def get_installable_extensions() -> List[InstallableExtension]:
|
|||
return extension_list
|
||||
|
||||
|
||||
async def get_installable_extension_meta(
|
||||
ext_id: str, hash: str
|
||||
) -> InstallableExtension:
|
||||
installable_extensions: List[
|
||||
InstallableExtension
|
||||
] = await get_installable_extensions()
|
||||
|
||||
valid_extensions = [
|
||||
e for e in installable_extensions if e.id == ext_id and e.hash == hash
|
||||
]
|
||||
if len(valid_extensions) == 0:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
detail=f"Unknown extension id: {ext_id}",
|
||||
)
|
||||
extension = valid_extensions[0]
|
||||
|
||||
# check that all dependecies are installed
|
||||
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
|
||||
if not set(extension.dependencies).issubset(installed_extensions):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail=f"Not all dependencies are installed: {extension.dependencies}",
|
||||
)
|
||||
|
||||
return extension
|
||||
|
||||
|
||||
def download_extension_archive(archive: str, ext_zip_file: str, hash: str):
|
||||
if os.path.isfile(ext_zip_file):
|
||||
os.remove(ext_zip_file)
|
||||
try:
|
||||
download_url(archive, ext_zip_file)
|
||||
except Exception as ex:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="Cannot fetch extension archive file",
|
||||
)
|
||||
|
||||
archive_hash = file_hash(ext_zip_file)
|
||||
if hash != archive_hash:
|
||||
# remove downloaded archive
|
||||
if os.path.isfile(ext_zip_file):
|
||||
os.remove(ext_zip_file)
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="File hash missmatch. Will not install.",
|
||||
)
|
||||
|
||||
|
||||
def download_url(url, save_path):
|
||||
with urllib.request.urlopen(url) as dl_file:
|
||||
with open(save_path, "wb") as out_file:
|
||||
|
|
|
|||
|
|
@ -39,8 +39,9 @@ from starlette.responses import StreamingResponse
|
|||
|
||||
from lnbits import bolt11, lnurl
|
||||
from lnbits.core.helpers import (
|
||||
download_url,
|
||||
download_extension_archive,
|
||||
file_hash,
|
||||
get_installable_extension_meta,
|
||||
get_installable_extensions,
|
||||
migrate_extension_database,
|
||||
)
|
||||
|
|
@ -736,80 +737,34 @@ async def websocket_update_get(item_id: str, data: str):
|
|||
async def api_install_extension(
|
||||
ext_id: str, hash: str, user: User = Depends(check_admin)
|
||||
):
|
||||
try:
|
||||
extension_list: List[InstallableExtension] = await get_installable_extensions()
|
||||
except Exception as ex:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="Cannot fetch installable extension list",
|
||||
)
|
||||
|
||||
extensions = [e for e in extension_list if e.id == ext_id and e.hash == hash]
|
||||
if len(extensions) == 0:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
detail=f"Unknown extension id: {ext_id}",
|
||||
)
|
||||
extension = extensions[0]
|
||||
ext_meta: InstallableExtension = await get_installable_extension_meta(ext_id, hash)
|
||||
|
||||
# check that all dependecies are installed
|
||||
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
|
||||
if not set(extension.dependencies).issubset(installed_extensions):
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail=f"Not all dependencies are installed: {extension.dependencies}",
|
||||
)
|
||||
|
||||
# move files to the right location
|
||||
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
|
||||
os.makedirs(extensions_data_dir, exist_ok=True)
|
||||
ext_data_dir = os.path.join(extensions_data_dir, ext_id)
|
||||
shutil.rmtree(ext_data_dir, True)
|
||||
ext_zip_file = os.path.join(extensions_data_dir, f"{ext_id}.zip")
|
||||
if os.path.isfile(ext_zip_file):
|
||||
os.remove(ext_zip_file)
|
||||
|
||||
try:
|
||||
download_url(extension.archive, ext_zip_file)
|
||||
except Exception as ex:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="Cannot fetch extension archive file",
|
||||
)
|
||||
|
||||
archive_hash = file_hash(ext_zip_file)
|
||||
if extension.hash != archive_hash:
|
||||
# remove downloaded archive
|
||||
if os.path.isfile(ext_zip_file):
|
||||
os.remove(ext_zip_file)
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
detail="File hash missmatch. Will not install.",
|
||||
)
|
||||
download_extension_archive(ext_meta.archive, ext_meta.zip_path, ext_meta.hash)
|
||||
|
||||
try:
|
||||
ext_dir = os.path.join("lnbits/extensions", ext_id)
|
||||
shutil.rmtree(ext_dir, True)
|
||||
with zipfile.ZipFile(ext_zip_file, "r") as zip_ref:
|
||||
with zipfile.ZipFile(ext_meta.zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall("lnbits/extensions")
|
||||
|
||||
ext_upgrade_dir = os.path.join(
|
||||
"lnbits/upgrades", f"{extension.id}-{extension.hash}"
|
||||
"lnbits/upgrades", f"{ext_meta.id}-{ext_meta.hash}"
|
||||
)
|
||||
os.makedirs("lnbits/upgrades", exist_ok=True)
|
||||
shutil.rmtree(ext_upgrade_dir, True)
|
||||
with zipfile.ZipFile(ext_zip_file, "r") as zip_ref:
|
||||
with zipfile.ZipFile(ext_meta.zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(ext_upgrade_dir)
|
||||
|
||||
module_name = f"lnbits.extensions.{ext_id}"
|
||||
module_installed = module_name in sys.modules
|
||||
# todo: is admin only
|
||||
ext = Extension(
|
||||
code=extension.id,
|
||||
code=ext_meta.id,
|
||||
is_valid=True,
|
||||
is_admin_only=False,
|
||||
name=extension.name,
|
||||
hash=extension.hash if module_installed else "",
|
||||
name=ext_meta.name,
|
||||
hash=ext_meta.hash if module_installed else "",
|
||||
)
|
||||
|
||||
current_versions = await get_dbversions()
|
||||
|
|
@ -836,8 +791,8 @@ async def api_install_extension(
|
|||
except Exception as ex:
|
||||
logger.warning(ex)
|
||||
# remove downloaded archive
|
||||
if os.path.isfile(ext_zip_file):
|
||||
os.remove(ext_zip_file)
|
||||
if os.path.isfile(ext_meta.zip_path):
|
||||
os.remove(ext_meta.zip_path)
|
||||
|
||||
# remove module from extensions
|
||||
shutil.rmtree(ext_dir, True)
|
||||
|
|
|
|||
|
|
@ -48,6 +48,14 @@ class InstallableExtension(NamedTuple):
|
|||
is_admin_only: bool = False
|
||||
version: Optional[int] = 0
|
||||
|
||||
@property
|
||||
def zip_path(self):
|
||||
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
|
||||
os.makedirs(extensions_data_dir, exist_ok=True)
|
||||
ext_data_dir = os.path.join(extensions_data_dir, self.id)
|
||||
shutil.rmtree(ext_data_dir, True)
|
||||
return os.path.join(extensions_data_dir, f"{self.id}.zip")
|
||||
|
||||
|
||||
class ExtensionManager:
|
||||
def __init__(self, include_disabled_exts=False):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue