feat: add typing for tasks (#2629)

* feat: add typing for tasks

* fixup!
This commit is contained in:
dni ⚡ 2024-08-07 09:57:15 +02:00 committed by GitHub
parent 27b9e8254c
commit ddb8fcb986
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -4,7 +4,13 @@ import time
import traceback import traceback
import uuid import uuid
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List, Optional from typing import (
Callable,
Coroutine,
Dict,
List,
Optional,
)
from loguru import logger from loguru import logger
from py_vapid import Vapid from py_vapid import Vapid
@ -17,7 +23,7 @@ from lnbits.core.crud import (
update_payment_details, update_payment_details,
update_payment_status, update_payment_status,
) )
from lnbits.core.models import PaymentState from lnbits.core.models import Payment, PaymentState
from lnbits.settings import settings from lnbits.settings import settings
from lnbits.wallets import get_funding_source from lnbits.wallets import get_funding_source
@ -25,17 +31,13 @@ tasks: List[asyncio.Task] = []
unique_tasks: Dict[str, asyncio.Task] = {} unique_tasks: Dict[str, asyncio.Task] = {}
def create_task(coro): def create_task(coro: Coroutine) -> asyncio.Task:
task = asyncio.create_task(coro) task = asyncio.create_task(coro)
tasks.append(task) tasks.append(task)
return task return task
def create_permanent_task(func): def create_unique_task(name: str, coro: Coroutine) -> asyncio.Task:
return create_task(catch_everything_and_restart(func))
def create_unique_task(name: str, coro):
if unique_tasks.get(name): if unique_tasks.get(name):
logger.warning(f"task `{name}` already exists, cancelling it") logger.warning(f"task `{name}` already exists, cancelling it")
try: try:
@ -47,11 +49,17 @@ def create_unique_task(name: str, coro):
return task return task
def create_permanent_unique_task(name: str, coro): def create_permanent_task(func: Callable[[], Coroutine]) -> asyncio.Task:
return create_task(catch_everything_and_restart(func))
def create_permanent_unique_task(
name: str, coro: Callable[[], Coroutine]
) -> asyncio.Task:
return create_unique_task(name, catch_everything_and_restart(coro, name)) return create_unique_task(name, catch_everything_and_restart(coro, name))
def cancel_all_tasks(): def cancel_all_tasks() -> None:
for task in tasks: for task in tasks:
try: try:
task.cancel() task.cancel()
@ -64,9 +72,12 @@ def cancel_all_tasks():
logger.warning(f"error while cancelling task `{name}`: {exc!s}") logger.warning(f"error while cancelling task `{name}`: {exc!s}")
async def catch_everything_and_restart(func, name: str = "unnamed"): async def catch_everything_and_restart(
func: Callable[[], Coroutine],
name: str = "unnamed",
) -> Coroutine:
try: try:
await func() return await func()
except asyncio.CancelledError: except asyncio.CancelledError:
raise # because we must pass this up raise # because we must pass this up
except Exception as exc: except Exception as exc:
@ -74,7 +85,7 @@ async def catch_everything_and_restart(func, name: str = "unnamed"):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
logger.error("will restart the task in 5 seconds.") logger.error("will restart the task in 5 seconds.")
await asyncio.sleep(5) await asyncio.sleep(5)
await catch_everything_and_restart(func, name) return catch_everything_and_restart(func, name)
invoice_listeners: Dict[str, asyncio.Queue] = {} invoice_listeners: Dict[str, asyncio.Queue] = {}
@ -101,7 +112,7 @@ def register_invoice_listener(send_chan: asyncio.Queue, name: Optional[str] = No
internal_invoice_queue: asyncio.Queue = asyncio.Queue(0) internal_invoice_queue: asyncio.Queue = asyncio.Queue(0)
async def internal_invoice_listener(): async def internal_invoice_listener() -> None:
""" """
internal_invoice_queue will be filled directly in core/services.py internal_invoice_queue will be filled directly in core/services.py
after the payment was deemed to be settled internally. after the payment was deemed to be settled internally.
@ -111,10 +122,10 @@ async def internal_invoice_listener():
while settings.lnbits_running: while settings.lnbits_running:
checking_id = await internal_invoice_queue.get() checking_id = await internal_invoice_queue.get()
logger.info(f"got an internal payment notification {checking_id}") logger.info(f"got an internal payment notification {checking_id}")
create_task(invoice_callback_dispatcher(checking_id, is_internal=True)) await invoice_callback_dispatcher(checking_id, is_internal=True)
async def invoice_listener(): async def invoice_listener() -> None:
""" """
invoice_listener will collect all invoices that come directly invoice_listener will collect all invoices that come directly
from the backend wallet. from the backend wallet.
@ -124,7 +135,22 @@ async def invoice_listener():
funding_source = get_funding_source() funding_source = get_funding_source()
async for checking_id in funding_source.paid_invoices_stream(): async for checking_id in funding_source.paid_invoices_stream():
logger.info(f"got a payment notification {checking_id}") logger.info(f"got a payment notification {checking_id}")
create_task(invoice_callback_dispatcher(checking_id)) await invoice_callback_dispatcher(checking_id)
def wait_for_paid_invoices(
invoice_listener_name: str,
func: Callable[[Payment], Coroutine],
) -> Callable[[], Coroutine]:
async def wrapper() -> None:
invoice_queue: asyncio.Queue = asyncio.Queue()
register_invoice_listener(invoice_queue, invoice_listener_name)
while settings.lnbits_running:
payment = await invoice_queue.get()
await func(payment)
return wrapper
async def check_pending_payments(): async def check_pending_payments():