diff --git a/lnbits/app.py b/lnbits/app.py index a7f41ea9..53f6f4b5 100644 --- a/lnbits/app.py +++ b/lnbits/app.py @@ -1,5 +1,5 @@ -import importlib import asyncio +import importlib from quart import Quart, g from quart_cors import cors # type: ignore @@ -8,7 +8,7 @@ from secure import SecureHeaders # type: ignore from .commands import db_migrate, handle_assets from .core import core_app -from .db import open_db +from .db import open_db, open_ext_db from .helpers import get_valid_extensions, get_js_vendored, get_css_vendored, url_for_vendored from .proxy_fix import ASGIProxyFix @@ -43,7 +43,17 @@ def register_blueprints(app: Quart) -> None: for ext in get_valid_extensions(): try: ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}") - app.register_blueprint(getattr(ext_module, f"{ext.code}_ext"), url_prefix=f"/{ext.code}") + bp = getattr(ext_module, f"{ext.code}_ext") + + @bp.before_request + async def before_request(): + g.ext_db = open_ext_db(ext.code) + + @bp.teardown_request + async def after_request(exc): + g.ext_db.__exit__(type(exc), exc, None) + + app.register_blueprint(bp, url_prefix=f"/{ext.code}") except Exception: raise ImportError(f"Please make sure that the extension `{ext.code}` follows conventions.") @@ -99,8 +109,8 @@ def register_async_tasks(app): @app.before_serving async def listeners(): - loop = asyncio.get_event_loop() - loop.create_task(invoice_listener(app)) + loop = asyncio.get_running_loop() + loop.create_task(invoice_listener()) @app.after_serving async def stop_listeners(): diff --git a/lnbits/core/tasks.py b/lnbits/core/tasks.py index b13f9f5c..1dff052e 100644 --- a/lnbits/core/tasks.py +++ b/lnbits/core/tasks.py @@ -31,8 +31,8 @@ def run_on_pseudo_request(awaitable: Awaitable): send_push_promise=lambda x, h: None, ) async with main_app.request_context(fk): - g.db = open_db() - await awaitable + with open_db() as g.db: + await awaitable loop = asyncio.get_event_loop() loop.create_task(run(awaitable)) @@ -57,16 +57,15 @@ async def webhook_handler(): return "", HTTPStatus.NO_CONTENT -async def invoice_listener(app): - run_on_pseudo_request(_invoice_listener()) - - -async def _invoice_listener(): +async def invoice_listener(): async for checking_id in WALLET.paid_invoices_stream(): - g.db = open_db() - payment = get_standalone_payment(checking_id) - if payment.is_in: - payment.set_pending(False) - for ext_name, cb in invoice_listeners: - g.ext_db = open_ext_db(ext_name) + run_on_pseudo_request(invoice_callback_dispatcher(checking_id)) + + +async def invoice_callback_dispatcher(checking_id: str): + payment = get_standalone_payment(checking_id) + if payment and payment.is_in: + payment.set_pending(False) + for ext_name, cb in invoice_listeners: + with open_ext_db(ext_name) as g.ext_db: # type: ignore await cb(payment)