refactor: extract dome methods to helpers

This commit is contained in:
Vlad Stan 2023-01-11 09:15:35 +02:00
parent cae71faf37
commit cb6349fd76
3 changed files with 73 additions and 58 deletions

View file

@ -1,14 +1,16 @@
import hashlib
import importlib
import os
import re
import urllib.request
from http import HTTPStatus
from typing import List
import httpx
from fastapi.exceptions import HTTPException
from loguru import logger
from lnbits.helpers import InstallableExtension
from lnbits.helpers import InstallableExtension, get_valid_extensions
from lnbits.settings import settings
from . import db as core_db
@ -76,6 +78,56 @@ async def get_installable_extensions() -> List[InstallableExtension]:
return extension_list
async def get_installable_extension_meta(
ext_id: str, hash: str
) -> InstallableExtension:
installable_extensions: List[
InstallableExtension
] = await get_installable_extensions()
valid_extensions = [
e for e in installable_extensions if e.id == ext_id and e.hash == hash
]
if len(valid_extensions) == 0:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Unknown extension id: {ext_id}",
)
extension = valid_extensions[0]
# check that all dependecies are installed
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
if not set(extension.dependencies).issubset(installed_extensions):
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail=f"Not all dependencies are installed: {extension.dependencies}",
)
return extension
def download_extension_archive(archive: str, ext_zip_file: str, hash: str):
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
try:
download_url(archive, ext_zip_file)
except Exception as ex:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Cannot fetch extension archive file",
)
archive_hash = file_hash(ext_zip_file)
if hash != archive_hash:
# remove downloaded archive
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="File hash missmatch. Will not install.",
)
def download_url(url, save_path):
with urllib.request.urlopen(url) as dl_file:
with open(save_path, "wb") as out_file:

View file

@ -39,8 +39,9 @@ from starlette.responses import StreamingResponse
from lnbits import bolt11, lnurl
from lnbits.core.helpers import (
download_url,
download_extension_archive,
file_hash,
get_installable_extension_meta,
get_installable_extensions,
migrate_extension_database,
)
@ -736,80 +737,34 @@ async def websocket_update_get(item_id: str, data: str):
async def api_install_extension(
ext_id: str, hash: str, user: User = Depends(check_admin)
):
try:
extension_list: List[InstallableExtension] = await get_installable_extensions()
except Exception as ex:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Cannot fetch installable extension list",
)
extensions = [e for e in extension_list if e.id == ext_id and e.hash == hash]
if len(extensions) == 0:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Unknown extension id: {ext_id}",
)
extension = extensions[0]
ext_meta: InstallableExtension = await get_installable_extension_meta(ext_id, hash)
# check that all dependecies are installed
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
if not set(extension.dependencies).issubset(installed_extensions):
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail=f"Not all dependencies are installed: {extension.dependencies}",
)
# move files to the right location
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
os.makedirs(extensions_data_dir, exist_ok=True)
ext_data_dir = os.path.join(extensions_data_dir, ext_id)
shutil.rmtree(ext_data_dir, True)
ext_zip_file = os.path.join(extensions_data_dir, f"{ext_id}.zip")
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
try:
download_url(extension.archive, ext_zip_file)
except Exception as ex:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Cannot fetch extension archive file",
)
archive_hash = file_hash(ext_zip_file)
if extension.hash != archive_hash:
# remove downloaded archive
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="File hash missmatch. Will not install.",
)
download_extension_archive(ext_meta.archive, ext_meta.zip_path, ext_meta.hash)
try:
ext_dir = os.path.join("lnbits/extensions", ext_id)
shutil.rmtree(ext_dir, True)
with zipfile.ZipFile(ext_zip_file, "r") as zip_ref:
with zipfile.ZipFile(ext_meta.zip_path, "r") as zip_ref:
zip_ref.extractall("lnbits/extensions")
ext_upgrade_dir = os.path.join(
"lnbits/upgrades", f"{extension.id}-{extension.hash}"
"lnbits/upgrades", f"{ext_meta.id}-{ext_meta.hash}"
)
os.makedirs("lnbits/upgrades", exist_ok=True)
shutil.rmtree(ext_upgrade_dir, True)
with zipfile.ZipFile(ext_zip_file, "r") as zip_ref:
with zipfile.ZipFile(ext_meta.zip_path, "r") as zip_ref:
zip_ref.extractall(ext_upgrade_dir)
module_name = f"lnbits.extensions.{ext_id}"
module_installed = module_name in sys.modules
# todo: is admin only
ext = Extension(
code=extension.id,
code=ext_meta.id,
is_valid=True,
is_admin_only=False,
name=extension.name,
hash=extension.hash if module_installed else "",
name=ext_meta.name,
hash=ext_meta.hash if module_installed else "",
)
current_versions = await get_dbversions()
@ -836,8 +791,8 @@ async def api_install_extension(
except Exception as ex:
logger.warning(ex)
# remove downloaded archive
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
if os.path.isfile(ext_meta.zip_path):
os.remove(ext_meta.zip_path)
# remove module from extensions
shutil.rmtree(ext_dir, True)

View file

@ -48,6 +48,14 @@ class InstallableExtension(NamedTuple):
is_admin_only: bool = False
version: Optional[int] = 0
@property
def zip_path(self):
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
os.makedirs(extensions_data_dir, exist_ok=True)
ext_data_dir = os.path.join(extensions_data_dir, self.id)
shutil.rmtree(ext_data_dir, True)
return os.path.join(extensions_data_dir, f"{self.id}.zip")
class ExtensionManager:
def __init__(self, include_disabled_exts=False):