fix: make startup extension check sync (#2819)

This commit is contained in:
dni ⚡ 2024-12-16 09:45:39 +01:00 committed by GitHub
parent b3351f2b17
commit b66a8b3de9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 56 additions and 19 deletions

View file

@ -31,7 +31,7 @@ jobs:
LNBITS_BACKEND_WALLET_CLASS: FakeWallet LNBITS_BACKEND_WALLET_CLASS: FakeWallet
run: | run: |
poetry run lnbits & poetry run lnbits &
sleep 5 sleep 10
- name: setup java version - name: setup java version
run: | run: |

View file

@ -65,7 +65,6 @@ from .middleware import (
from .requestvars import g from .requestvars import g
from .tasks import ( from .tasks import (
check_pending_payments, check_pending_payments,
create_task,
internal_invoice_listener, internal_invoice_listener,
invoice_listener, invoice_listener,
) )
@ -81,6 +80,10 @@ async def startup(app: FastAPI):
await check_admin_settings() await check_admin_settings()
await check_webpush_settings() await check_webpush_settings()
# check extensions after restart
if not settings.lnbits_extensions_deactivate_all:
await check_and_register_extensions(app)
log_server_info() log_server_info()
# initialize WALLET # initialize WALLET
@ -97,7 +100,7 @@ async def startup(app: FastAPI):
init_core_routers(app) init_core_routers(app)
# initialize tasks # initialize tasks
register_async_tasks(app) register_async_tasks()
async def shutdown(): async def shutdown():
@ -395,16 +398,21 @@ def register_new_ratelimiter(app: FastAPI) -> Callable:
return register_new_ratelimiter_fn return register_new_ratelimiter_fn
def register_ext_tasks(ext: Extension) -> None:
"""Register extension async tasks."""
ext_module = importlib.import_module(ext.module_name)
if hasattr(ext_module, f"{ext.code}_start"):
ext_start_func = getattr(ext_module, f"{ext.code}_start")
ext_start_func()
def register_ext_routes(app: FastAPI, ext: Extension) -> None: def register_ext_routes(app: FastAPI, ext: Extension) -> None:
"""Register FastAPI routes for extension.""" """Register FastAPI routes for extension."""
ext_module = importlib.import_module(ext.module_name) ext_module = importlib.import_module(ext.module_name)
ext_route = getattr(ext_module, f"{ext.code}_ext") ext_route = getattr(ext_module, f"{ext.code}_ext")
if hasattr(ext_module, f"{ext.code}_start"):
ext_start_func = getattr(ext_module, f"{ext.code}_start")
ext_start_func()
if hasattr(ext_module, f"{ext.code}_static_files"): if hasattr(ext_module, f"{ext.code}_static_files"):
ext_statics = getattr(ext_module, f"{ext.code}_static_files") ext_statics = getattr(ext_module, f"{ext.code}_static_files")
for s in ext_statics: for s in ext_statics:
@ -431,15 +439,12 @@ async def check_and_register_extensions(app: FastAPI):
for ext in await get_valid_extensions(False): for ext in await get_valid_extensions(False):
try: try:
register_ext_routes(app, ext) register_ext_routes(app, ext)
register_ext_tasks(ext)
except Exception as exc: except Exception as exc:
logger.error(f"Could not load extension `{ext.code}`: {exc!s}") logger.error(f"Could not load extension `{ext.code}`: {exc!s}")
def register_async_tasks(app: FastAPI): def register_async_tasks():
# check extensions after restart
if not settings.lnbits_extensions_deactivate_all:
create_task(check_and_register_extensions(app))
create_permanent_task(wait_for_audit_data) create_permanent_task(wait_for_audit_data)
create_permanent_task(check_pending_payments) create_permanent_task(check_pending_payments)

View file

@ -49,6 +49,8 @@ async def install_extension(ext_info: InstallableExtension) -> Extension:
# call stop while the old routes are still active # call stop while the old routes are still active
await stop_extension_background_work(ext_id) await stop_extension_background_work(ext_id)
await start_extension_background_work(ext_id)
return extension return extension
@ -76,7 +78,7 @@ async def deactivate_extension(ext_id: str):
async def stop_extension_background_work(ext_id: str) -> bool: async def stop_extension_background_work(ext_id: str) -> bool:
""" """
Stop background work for extension (like asyncio.Tasks, WebSockets, etc). Stop background work for extension (like asyncio.Tasks, WebSockets, etc).
Extensions SHOULD expose a `api_stop()` function. Extension must expose a `myextension_stop()` function if it is starting tasks.
""" """
upgrade_hash = settings.extension_upgrade_hash(ext_id) upgrade_hash = settings.extension_upgrade_hash(ext_id)
ext = Extension(code=ext_id, is_valid=True, upgrade_hash=upgrade_hash) ext = Extension(code=ext_id, is_valid=True, upgrade_hash=upgrade_hash)
@ -85,11 +87,10 @@ async def stop_extension_background_work(ext_id: str) -> bool:
logger.info(f"Stopping background work for extension '{ext.module_name}'.") logger.info(f"Stopping background work for extension '{ext.module_name}'.")
old_module = importlib.import_module(ext.module_name) old_module = importlib.import_module(ext.module_name)
# Extensions must expose an `{ext_id}_stop()` function at the module level stop_fn_name = f"{ext_id}_stop"
# The `api_stop()` function is for backwards compatibility (will be deprecated) assert hasattr(
stop_fns = [f"{ext_id}_stop", "api_stop"] old_module, stop_fn_name
stop_fn_name = next((fn for fn in stop_fns if hasattr(old_module, fn)), None) ), f"No stop function found for '{ext.module_name}'."
assert stop_fn_name, f"No stop function found for '{ext.module_name}'."
stop_fn = getattr(old_module, stop_fn_name) stop_fn = getattr(old_module, stop_fn_name)
if stop_fn: if stop_fn:
@ -97,7 +98,6 @@ async def stop_extension_background_work(ext_id: str) -> bool:
await stop_fn() await stop_fn()
else: else:
stop_fn() stop_fn()
logger.info(f"Stopped background work for extension '{ext.module_name}'.") logger.info(f"Stopped background work for extension '{ext.module_name}'.")
except Exception as ex: except Exception as ex:
logger.warning(f"Failed to stop background work for '{ext.module_name}'.") logger.warning(f"Failed to stop background work for '{ext.module_name}'.")
@ -107,6 +107,38 @@ async def stop_extension_background_work(ext_id: str) -> bool:
return True return True
async def start_extension_background_work(ext_id: str) -> bool:
"""
Start background work for extension (like asyncio.Tasks, WebSockets, etc).
Extension CAN expose a `myextension_start()` function if it is starting tasks.
Extension MUST expose a `myextension_stop()` in that case.
"""
upgrade_hash = settings.extension_upgrade_hash(ext_id)
ext = Extension(code=ext_id, is_valid=True, upgrade_hash=upgrade_hash)
try:
logger.info(f"Starting background work for extension '{ext.module_name}'.")
new_module = importlib.import_module(ext.module_name)
start_fn_name = f"{ext_id}_start"
# start function is optional, return False if not found
if not hasattr(new_module, start_fn_name):
return False
start_fn = getattr(new_module, start_fn_name)
if start_fn:
if asyncio.iscoroutinefunction(start_fn):
await start_fn()
else:
start_fn()
logger.info(f"Started background work for extension '{ext.module_name}'.")
return True
except Exception as ex:
logger.warning(f"Failed to start background work for '{ext.module_name}'.")
logger.warning(ex)
return False
async def get_valid_extensions( async def get_valid_extensions(
include_deactivated: Optional[bool] = True, include_deactivated: Optional[bool] = True,
) -> list[Extension]: ) -> list[Extension]: