diff --git a/lnbits/extensions/copilot/views.py b/lnbits/extensions/copilot/views.py index 99e38c9c..516bf4e3 100644 --- a/lnbits/extensions/copilot/views.py +++ b/lnbits/extensions/copilot/views.py @@ -2,7 +2,6 @@ from http import HTTPStatus import httpx from collections import defaultdict from lnbits.decorators import check_user_exists -import asyncio from .crud import get_copilot from functools import wraps @@ -10,7 +9,7 @@ from functools import wraps from lnbits.decorators import check_user_exists from . import copilot_ext, copilot_renderer -from fastapi import FastAPI, Request, WebSocket +from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect from fastapi.params import Depends from fastapi.templating import Jinja2Templates from fastapi.param_functions import Query @@ -48,28 +47,44 @@ async def panel(request: Request): # socket_relay is a list where the control panel or # lnurl endpoints can leave a message for the compose window -connected_websockets = defaultdict(set) + +class ConnectionManager: + def __init__(self): + self.active_connections: List[WebSocket] = [] + + async def connect(self, websocket: WebSocket): + await websocket.accept() + self.active_connections.append(websocket) + + def disconnect(self, websocket: WebSocket): + self.active_connections.remove(websocket) + + async def send_personal_message(self, message: str, websocket: WebSocket): + await websocket.send_text(message) + + async def broadcast(self, message: str): + for connection in self.active_connections: + await connection.send_text(message) -@copilot_ext.websocket("/ws/{id}/") -async def websocket_endpoint(websocket: WebSocket, id: str = Query(None)): - copilot = await get_copilot(id) - if not copilot: - return "", HTTPStatus.FORBIDDEN - await websocket.accept() - invoice_queue = asyncio.Queue() - connected_websockets[id].add(invoice_queue) +manager = ConnectionManager() + + +@copilot_ext.websocket("/ws/{socket_id}") +async def websocket_endpoint(websocket: WebSocket, socket_id: str): + await manager.connect(websocket) try: while True: data = await websocket.receive_text() - await websocket.send_text(f"Message text was: {data}") - finally: - connected_websockets[id].remove(invoice_queue) + await manager.send_personal_message(f"You wrote: {data}", websocket) + await manager.broadcast(f"Client #{socket_id} says: {data}") + except WebSocketDisconnect: + manager.disconnect(websocket) + await manager.broadcast(f"Client #{socket_id} left the chat") async def updater(copilot_id, data, comment): copilot = await get_copilot(copilot_id) if not copilot: return - for queue in connected_websockets[copilot_id]: - await queue.send(f"{data + '-' + comment}") + manager.broadcast(f"{data + '-' + comment}") diff --git a/lnbits/extensions/copilot/views_api.py b/lnbits/extensions/copilot/views_api.py index 74ce6291..58b486bf 100644 --- a/lnbits/extensions/copilot/views_api.py +++ b/lnbits/extensions/copilot/views_api.py @@ -106,6 +106,7 @@ async def api_copilot_ws_relay( data: str = Query(None), ): copilot = await get_copilot(copilot_id) + print(copilot) if not copilot: raise HTTPException( status_code=HTTPStatus.NOT_FOUND,