diff --git a/lnbits/core/services/notifications.py b/lnbits/core/services/notifications.py index 9eadc646..02c4a5e1 100644 --- a/lnbits/core/services/notifications.py +++ b/lnbits/core/services/notifications.py @@ -277,7 +277,7 @@ async def send_payment_notification(wallet: Wallet, payment: Payment): async def send_ws_payment_notification(wallet: Wallet, payment: Payment): # TODO: websocket message should be a clean payment model - # await websocket_manager.send_data(payment.json(), wallet.inkey) + # await websocket_manager.send(wallet.inkey, payment.json()) # TODO: figure out why we send the balance with the payment here. # cleaner would be to have a separate message for the balance # and send it with the id of the wallet so wallets can subscribe to it @@ -288,12 +288,11 @@ async def send_ws_payment_notification(wallet: Wallet, payment: Payment): "payment": json.loads(payment.json()), }, ) - await websocket_manager.send_data(payment_notification, wallet.inkey) - await websocket_manager.send_data(payment_notification, wallet.adminkey) - - await websocket_manager.send_data( - json.dumps({"pending": payment.pending, "status": payment.status}), + await websocket_manager.send(wallet.inkey, payment_notification) + await websocket_manager.send(wallet.adminkey, payment_notification) + await websocket_manager.send( payment.payment_hash, + json.dumps({"pending": payment.pending, "status": payment.status}), ) diff --git a/lnbits/core/services/websockets.py b/lnbits/core/services/websockets.py index 509e060a..f4a74002 100644 --- a/lnbits/core/services/websockets.py +++ b/lnbits/core/services/websockets.py @@ -1,27 +1,65 @@ -from fastapi import WebSocket +from asyncio import Queue +from dataclasses import dataclass + +from fastapi import WebSocket, WebSocketDisconnect from loguru import logger +from lnbits.settings import settings + + +@dataclass +class WebsocketConnection: + item_id: str + websocket: WebSocket + receive_queue: Queue[str] + class WebsocketConnectionManager: def __init__(self) -> None: - self.active_connections: list[WebSocket] = [] + self.active_connections: list[WebsocketConnection] = [] - async def connect(self, websocket: WebSocket, item_id: str): + async def connect(self, item_id: str, websocket: WebSocket) -> WebsocketConnection: logger.debug(f"Websocket connected to {item_id}") await websocket.accept() - self.active_connections.append(websocket) + conn = WebsocketConnection( + item_id=item_id, + websocket=websocket, + receive_queue=Queue(), + ) + self.active_connections.append(conn) + return conn - def disconnect(self, websocket: WebSocket): - self.active_connections.remove(websocket) + async def listen(self, conn: WebsocketConnection) -> None: + while settings.lnbits_running: + try: + data = await conn.websocket.receive_text() + logger.debug(f"WS received data from {conn.item_id}: {data}") + conn.receive_queue.put_nowait(data) + except WebSocketDisconnect: + for _conn in self.active_connections: + if _conn.websocket == conn.websocket: + self.active_connections.remove(_conn) + logger.debug(f"WS disconnected from {conn.item_id}") + break # out of the listen and the fastapi route - async def send_data(self, message: str, item_id: str): - for connection in self.active_connections: - if connection.path_params["item_id"] == item_id: - await connection.send_text(message) + def get_connections(self, item_id: str) -> list[WebsocketConnection]: + conns = [] + for conn in self.active_connections: + if conn.item_id == item_id: + conns.append(conn) + return conns + + def has_connection(self, item_id: str) -> bool: + return len(self.get_connections(item_id)) > 0 + + async def send(self, item_id: str, data: str) -> None: + for conn in self.get_connections(item_id): + await conn.websocket.send_text(data) websocket_manager = WebsocketConnectionManager() -async def websocket_updater(item_id: str, data: str): - return await websocket_manager.send_data(data, item_id) +# deprecated import and use `websocket_manager.send()` instead +async def websocket_updater(item_id: str, data: str) -> None: + return await websocket_manager.send(item_id, data) diff --git a/lnbits/core/views/websocket_api.py b/lnbits/core/views/websocket_api.py index e0ea5392..956e2cbd 100644 --- a/lnbits/core/views/websocket_api.py +++ b/lnbits/core/views/websocket_api.py @@ -1,33 +1,20 @@ -from fastapi import ( - APIRouter, - WebSocket, - WebSocketDisconnect, -) +from fastapi import APIRouter, WebSocket -from lnbits.settings import settings - -from ..services import ( - websocket_manager, - websocket_updater, -) +from ..services import websocket_manager websocket_router = APIRouter(prefix="/api/v1/ws", tags=["Websocket"]) @websocket_router.websocket("/{item_id}") -async def websocket_connect(websocket: WebSocket, item_id: str): - await websocket_manager.connect(websocket, item_id) - try: - while settings.lnbits_running: - await websocket.receive_text() - except WebSocketDisconnect: - websocket_manager.disconnect(websocket) +async def websocket_connect(websocket: WebSocket, item_id: str) -> None: + conn = await websocket_manager.connect(item_id, websocket) + await websocket_manager.listen(conn) @websocket_router.post("/{item_id}") async def websocket_update_post(item_id: str, data: str): try: - await websocket_updater(item_id, data) + await websocket_manager.send(item_id, data) return {"sent": True, "data": data} except Exception: return {"sent": False, "data": data} @@ -36,7 +23,7 @@ async def websocket_update_post(item_id: str, data: str): @websocket_router.get("/{item_id}/{data}") async def websocket_update_get(item_id: str, data: str): try: - await websocket_updater(item_id, data) + await websocket_manager.send(item_id, data) return {"sent": True, "data": data} except Exception: return {"sent": False, "data": data}