refactor: extract dome methods to helpers

This commit is contained in:
Vlad Stan 2023-01-11 09:15:35 +02:00
parent cae71faf37
commit cb6349fd76
3 changed files with 73 additions and 58 deletions

View file

@ -1,14 +1,16 @@
import hashlib import hashlib
import importlib import importlib
import os
import re import re
import urllib.request import urllib.request
from http import HTTPStatus
from typing import List from typing import List
import httpx import httpx
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from loguru import logger from loguru import logger
from lnbits.helpers import InstallableExtension from lnbits.helpers import InstallableExtension, get_valid_extensions
from lnbits.settings import settings from lnbits.settings import settings
from . import db as core_db from . import db as core_db
@ -76,6 +78,56 @@ async def get_installable_extensions() -> List[InstallableExtension]:
return extension_list return extension_list
async def get_installable_extension_meta(
ext_id: str, hash: str
) -> InstallableExtension:
installable_extensions: List[
InstallableExtension
] = await get_installable_extensions()
valid_extensions = [
e for e in installable_extensions if e.id == ext_id and e.hash == hash
]
if len(valid_extensions) == 0:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Unknown extension id: {ext_id}",
)
extension = valid_extensions[0]
# check that all dependecies are installed
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
if not set(extension.dependencies).issubset(installed_extensions):
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail=f"Not all dependencies are installed: {extension.dependencies}",
)
return extension
def download_extension_archive(archive: str, ext_zip_file: str, hash: str):
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
try:
download_url(archive, ext_zip_file)
except Exception as ex:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Cannot fetch extension archive file",
)
archive_hash = file_hash(ext_zip_file)
if hash != archive_hash:
# remove downloaded archive
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="File hash missmatch. Will not install.",
)
def download_url(url, save_path): def download_url(url, save_path):
with urllib.request.urlopen(url) as dl_file: with urllib.request.urlopen(url) as dl_file:
with open(save_path, "wb") as out_file: with open(save_path, "wb") as out_file:

View file

@ -39,8 +39,9 @@ from starlette.responses import StreamingResponse
from lnbits import bolt11, lnurl from lnbits import bolt11, lnurl
from lnbits.core.helpers import ( from lnbits.core.helpers import (
download_url, download_extension_archive,
file_hash, file_hash,
get_installable_extension_meta,
get_installable_extensions, get_installable_extensions,
migrate_extension_database, migrate_extension_database,
) )
@ -736,80 +737,34 @@ async def websocket_update_get(item_id: str, data: str):
async def api_install_extension( async def api_install_extension(
ext_id: str, hash: str, user: User = Depends(check_admin) ext_id: str, hash: str, user: User = Depends(check_admin)
): ):
try:
extension_list: List[InstallableExtension] = await get_installable_extensions()
except Exception as ex:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Cannot fetch installable extension list",
)
extensions = [e for e in extension_list if e.id == ext_id and e.hash == hash] ext_meta: InstallableExtension = await get_installable_extension_meta(ext_id, hash)
if len(extensions) == 0:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Unknown extension id: {ext_id}",
)
extension = extensions[0]
# check that all dependecies are installed download_extension_archive(ext_meta.archive, ext_meta.zip_path, ext_meta.hash)
installed_extensions = list(map(lambda e: e.code, get_valid_extensions(True)))
if not set(extension.dependencies).issubset(installed_extensions):
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail=f"Not all dependencies are installed: {extension.dependencies}",
)
# move files to the right location
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
os.makedirs(extensions_data_dir, exist_ok=True)
ext_data_dir = os.path.join(extensions_data_dir, ext_id)
shutil.rmtree(ext_data_dir, True)
ext_zip_file = os.path.join(extensions_data_dir, f"{ext_id}.zip")
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
try:
download_url(extension.archive, ext_zip_file)
except Exception as ex:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="Cannot fetch extension archive file",
)
archive_hash = file_hash(ext_zip_file)
if extension.hash != archive_hash:
# remove downloaded archive
if os.path.isfile(ext_zip_file):
os.remove(ext_zip_file)
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail="File hash missmatch. Will not install.",
)
try: try:
ext_dir = os.path.join("lnbits/extensions", ext_id) ext_dir = os.path.join("lnbits/extensions", ext_id)
shutil.rmtree(ext_dir, True) shutil.rmtree(ext_dir, True)
with zipfile.ZipFile(ext_zip_file, "r") as zip_ref: with zipfile.ZipFile(ext_meta.zip_path, "r") as zip_ref:
zip_ref.extractall("lnbits/extensions") zip_ref.extractall("lnbits/extensions")
ext_upgrade_dir = os.path.join( ext_upgrade_dir = os.path.join(
"lnbits/upgrades", f"{extension.id}-{extension.hash}" "lnbits/upgrades", f"{ext_meta.id}-{ext_meta.hash}"
) )
os.makedirs("lnbits/upgrades", exist_ok=True) os.makedirs("lnbits/upgrades", exist_ok=True)
shutil.rmtree(ext_upgrade_dir, True) shutil.rmtree(ext_upgrade_dir, True)
with zipfile.ZipFile(ext_zip_file, "r") as zip_ref: with zipfile.ZipFile(ext_meta.zip_path, "r") as zip_ref:
zip_ref.extractall(ext_upgrade_dir) zip_ref.extractall(ext_upgrade_dir)
module_name = f"lnbits.extensions.{ext_id}" module_name = f"lnbits.extensions.{ext_id}"
module_installed = module_name in sys.modules module_installed = module_name in sys.modules
# todo: is admin only # todo: is admin only
ext = Extension( ext = Extension(
code=extension.id, code=ext_meta.id,
is_valid=True, is_valid=True,
is_admin_only=False, is_admin_only=False,
name=extension.name, name=ext_meta.name,
hash=extension.hash if module_installed else "", hash=ext_meta.hash if module_installed else "",
) )
current_versions = await get_dbversions() current_versions = await get_dbversions()
@ -836,8 +791,8 @@ async def api_install_extension(
except Exception as ex: except Exception as ex:
logger.warning(ex) logger.warning(ex)
# remove downloaded archive # remove downloaded archive
if os.path.isfile(ext_zip_file): if os.path.isfile(ext_meta.zip_path):
os.remove(ext_zip_file) os.remove(ext_meta.zip_path)
# remove module from extensions # remove module from extensions
shutil.rmtree(ext_dir, True) shutil.rmtree(ext_dir, True)

View file

@ -48,6 +48,14 @@ class InstallableExtension(NamedTuple):
is_admin_only: bool = False is_admin_only: bool = False
version: Optional[int] = 0 version: Optional[int] = 0
@property
def zip_path(self):
extensions_data_dir = os.path.join(settings.lnbits_data_folder, "extensions")
os.makedirs(extensions_data_dir, exist_ok=True)
ext_data_dir = os.path.join(extensions_data_dir, self.id)
shutil.rmtree(ext_data_dir, True)
return os.path.join(extensions_data_dir, f"{self.id}.zip")
class ExtensionManager: class ExtensionManager:
def __init__(self, include_disabled_exts=False): def __init__(self, include_disabled_exts=False):