diff --git a/autogpt_platform/backend/backend/data/notification_bus.py b/autogpt_platform/backend/backend/data/notification_bus.py new file mode 100644 index 0000000000..ddd0681a2c --- /dev/null +++ b/autogpt_platform/backend/backend/data/notification_bus.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import AsyncGenerator + +from pydantic import BaseModel + +from backend.data.event_bus import AsyncRedisEventBus +from backend.server.model import NotificationPayload +from backend.util.settings import Settings + + +class NotificationEvent(BaseModel): + """Generic notification event destined for websocket delivery.""" + + user_id: str + payload: NotificationPayload + + +class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]): + Model = NotificationEvent # type: ignore + + @property + def event_bus_name(self) -> str: + return Settings().config.notification_event_bus_name + + async def publish(self, event: NotificationEvent) -> None: + await self.publish_event(event, event.user_id) + + async def listen( + self, user_id: str = "*" + ) -> AsyncGenerator[NotificationEvent, None]: + async for event in self.listen_events(user_id): + yield event diff --git a/autogpt_platform/backend/backend/server/conn_manager.py b/autogpt_platform/backend/backend/server/conn_manager.py index 0430028610..8d65117564 100644 --- a/autogpt_platform/backend/backend/server/conn_manager.py +++ b/autogpt_platform/backend/backend/server/conn_manager.py @@ -1,3 +1,4 @@ +import asyncio from typing import Dict, Set from fastapi import WebSocket @@ -7,7 +8,7 @@ from backend.data.execution import ( GraphExecutionEvent, NodeExecutionEvent, ) -from backend.server.model import WSMessage, WSMethod +from backend.server.model import NotificationPayload, WSMessage, WSMethod _EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = { ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT, @@ -19,15 +20,24 @@ class ConnectionManager: def __init__(self): self.active_connections: Set[WebSocket] = set() self.subscriptions: Dict[str, Set[WebSocket]] = {} + self.user_connections: Dict[str, Set[WebSocket]] = {} - async def connect_socket(self, websocket: WebSocket): + async def connect_socket(self, websocket: WebSocket, *, user_id: str): await websocket.accept() self.active_connections.add(websocket) + if user_id not in self.user_connections: + self.user_connections[user_id] = set() + self.user_connections[user_id].add(websocket) - def disconnect_socket(self, websocket: WebSocket): - self.active_connections.remove(websocket) + def disconnect_socket(self, websocket: WebSocket, *, user_id: str): + self.active_connections.discard(websocket) for subscribers in self.subscriptions.values(): subscribers.discard(websocket) + user_conns = self.user_connections.get(user_id) + if user_conns is not None: + user_conns.discard(websocket) + if not user_conns: + self.user_connections.pop(user_id, None) async def subscribe_graph_exec( self, *, user_id: str, graph_exec_id: str, websocket: WebSocket @@ -92,6 +102,26 @@ class ConnectionManager: return n_sent + async def send_notification( + self, *, user_id: str, payload: NotificationPayload + ) -> int: + """Send a notification to all websocket connections belonging to a user.""" + message = WSMessage( + method=WSMethod.NOTIFICATION, + data=payload.model_dump(), + ).model_dump_json() + + connections = tuple(self.user_connections.get(user_id, set())) + if not connections: + return 0 + + await asyncio.gather( + *(connection.send_text(message) for connection in connections), + return_exceptions=True, + ) + + return len(connections) + async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str: if channel_key not in self.subscriptions: self.subscriptions[channel_key] = set() diff --git a/autogpt_platform/backend/backend/server/conn_manager_test.py b/autogpt_platform/backend/backend/server/conn_manager_test.py index 401a9eaf81..379928fae7 100644 --- a/autogpt_platform/backend/backend/server/conn_manager_test.py +++ b/autogpt_platform/backend/backend/server/conn_manager_test.py @@ -10,7 +10,7 @@ from backend.data.execution import ( NodeExecutionEvent, ) from backend.server.conn_manager import ConnectionManager -from backend.server.model import WSMessage, WSMethod +from backend.server.model import NotificationPayload, WSMessage, WSMethod @pytest.fixture @@ -29,8 +29,9 @@ def mock_websocket() -> AsyncMock: async def test_connect( connection_manager: ConnectionManager, mock_websocket: AsyncMock ) -> None: - await connection_manager.connect_socket(mock_websocket) + await connection_manager.connect_socket(mock_websocket, user_id="user-1") assert mock_websocket in connection_manager.active_connections + assert mock_websocket in connection_manager.user_connections["user-1"] mock_websocket.accept.assert_called_once() @@ -39,11 +40,13 @@ def test_disconnect( ) -> None: connection_manager.active_connections.add(mock_websocket) connection_manager.subscriptions["test_channel_42"] = {mock_websocket} + connection_manager.user_connections["user-1"] = {mock_websocket} - connection_manager.disconnect_socket(mock_websocket) + connection_manager.disconnect_socket(mock_websocket, user_id="user-1") assert mock_websocket not in connection_manager.active_connections assert mock_websocket not in connection_manager.subscriptions["test_channel_42"] + assert "user-1" not in connection_manager.user_connections @pytest.mark.asyncio @@ -207,3 +210,22 @@ async def test_send_execution_result_no_subscribers( await connection_manager.send_execution_update(result) mock_websocket.send_text.assert_not_called() + + +@pytest.mark.asyncio +async def test_send_notification( + connection_manager: ConnectionManager, mock_websocket: AsyncMock +) -> None: + connection_manager.user_connections["user-1"] = {mock_websocket} + + await connection_manager.send_notification( + user_id="user-1", payload=NotificationPayload(type="info", event="hey") + ) + + mock_websocket.send_text.assert_called_once() + sent_message = mock_websocket.send_text.call_args[0][0] + expected_message = WSMessage( + method=WSMethod.NOTIFICATION, + data={"type": "info", "event": "hey"}, + ).model_dump_json() + assert sent_message == expected_message diff --git a/autogpt_platform/backend/backend/server/model.py b/autogpt_platform/backend/backend/server/model.py index bbb904a794..24ba4fa7ee 100644 --- a/autogpt_platform/backend/backend/server/model.py +++ b/autogpt_platform/backend/backend/server/model.py @@ -14,6 +14,7 @@ class WSMethod(enum.Enum): UNSUBSCRIBE = "unsubscribe" GRAPH_EXECUTION_EVENT = "graph_execution_event" NODE_EXECUTION_EVENT = "node_execution_event" + NOTIFICATION = "notification" ERROR = "error" HEARTBEAT = "heartbeat" @@ -76,3 +77,12 @@ class TimezoneResponse(pydantic.BaseModel): class UpdateTimezoneRequest(pydantic.BaseModel): timezone: TimeZoneName + + +class NotificationPayload(pydantic.BaseModel): + type: str + event: str + + +class OnboardingNotificationPayload(NotificationPayload): + step: str diff --git a/autogpt_platform/backend/backend/server/ws_api.py b/autogpt_platform/backend/backend/server/ws_api.py index f55bdf284a..344fd7e1a6 100644 --- a/autogpt_platform/backend/backend/server/ws_api.py +++ b/autogpt_platform/backend/backend/server/ws_api.py @@ -10,6 +10,7 @@ from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect from starlette.middleware.cors import CORSMiddleware from backend.data.execution import AsyncRedisExecutionEventBus +from backend.data.notification_bus import AsyncRedisNotificationEventBus from backend.data.user import DEFAULT_USER_ID from backend.monitoring.instrumentation import ( instrument_fastapi, @@ -62,9 +63,21 @@ def get_connection_manager(): @continuous_retry() async def event_broadcaster(manager: ConnectionManager): - event_queue = AsyncRedisExecutionEventBus() - async for event in event_queue.listen("*"): - await manager.send_execution_update(event) + execution_bus = AsyncRedisExecutionEventBus() + notification_bus = AsyncRedisNotificationEventBus() + + async def execution_worker(): + async for event in execution_bus.listen("*"): + await manager.send_execution_update(event) + + async def notification_worker(): + async for notification in notification_bus.listen("*"): + await manager.send_notification( + user_id=notification.user_id, + payload=notification.payload, + ) + + await asyncio.gather(execution_worker(), notification_worker()) async def authenticate_websocket(websocket: WebSocket) -> str: @@ -229,7 +242,7 @@ async def websocket_router( user_id = await authenticate_websocket(websocket) if not user_id: return - await manager.connect_socket(websocket) + await manager.connect_socket(websocket, user_id=user_id) # Track WebSocket connection update_websocket_connections(user_id, 1) @@ -302,7 +315,7 @@ async def websocket_router( ) except WebSocketDisconnect: - manager.disconnect_socket(websocket) + manager.disconnect_socket(websocket, user_id=user_id) logger.debug("WebSocket client disconnected") finally: update_websocket_connections(user_id, -1) diff --git a/autogpt_platform/backend/backend/server/ws_api_test.py b/autogpt_platform/backend/backend/server/ws_api_test.py index 51a4722fb1..0bc9902145 100644 --- a/autogpt_platform/backend/backend/server/ws_api_test.py +++ b/autogpt_platform/backend/backend/server/ws_api_test.py @@ -96,7 +96,9 @@ async def test_websocket_router_subscribe( cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager) ) - mock_manager.connect_socket.assert_called_once_with(mock_websocket) + mock_manager.connect_socket.assert_called_once_with( + mock_websocket, user_id=DEFAULT_USER_ID + ) mock_manager.subscribe_graph_exec.assert_called_once_with( user_id=DEFAULT_USER_ID, graph_exec_id="test-graph-exec-1", @@ -115,7 +117,9 @@ async def test_websocket_router_subscribe( snapshot.snapshot_dir = "snapshots" snapshot.assert_match(json.dumps(parsed_message, indent=2, sort_keys=True), "sub") - mock_manager.disconnect_socket.assert_called_once_with(mock_websocket) + mock_manager.disconnect_socket.assert_called_once_with( + mock_websocket, user_id=DEFAULT_USER_ID + ) @pytest.mark.asyncio @@ -142,7 +146,9 @@ async def test_websocket_router_unsubscribe( cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager) ) - mock_manager.connect_socket.assert_called_once_with(mock_websocket) + mock_manager.connect_socket.assert_called_once_with( + mock_websocket, user_id=DEFAULT_USER_ID + ) mock_manager.unsubscribe_graph_exec.assert_called_once_with( user_id=DEFAULT_USER_ID, graph_exec_id="test-graph-exec-1", @@ -158,7 +164,9 @@ async def test_websocket_router_unsubscribe( snapshot.snapshot_dir = "snapshots" snapshot.assert_match(json.dumps(parsed_message, indent=2, sort_keys=True), "unsub") - mock_manager.disconnect_socket.assert_called_once_with(mock_websocket) + mock_manager.disconnect_socket.assert_called_once_with( + mock_websocket, user_id=DEFAULT_USER_ID + ) @pytest.mark.asyncio @@ -179,11 +187,15 @@ async def test_websocket_router_invalid_method( cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager) ) - mock_manager.connect_socket.assert_called_once_with(mock_websocket) + mock_manager.connect_socket.assert_called_once_with( + mock_websocket, user_id=DEFAULT_USER_ID + ) mock_websocket.send_text.assert_called_once() assert '"method":"error"' in mock_websocket.send_text.call_args[0][0] assert '"success":false' in mock_websocket.send_text.call_args[0][0] - mock_manager.disconnect_socket.assert_called_once_with(mock_websocket) + mock_manager.disconnect_socket.assert_called_once_with( + mock_websocket, user_id=DEFAULT_USER_ID + ) @pytest.mark.asyncio diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index 4a59df760c..42014f9698 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -418,6 +418,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): description="Name of the event bus", ) + notification_event_bus_name: str = Field( + default="notification_event", + description="Name of the websocket notification event bus", + ) + trust_endpoints_for_requests: List[str] = Field( default_factory=list, description="A whitelist of trusted internal endpoints for the backend to make requests to.", diff --git a/autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts b/autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts index 8072e81489..bec9c1749c 100644 --- a/autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts +++ b/autogpt_platform/frontend/src/lib/autogpt-server-api/client.ts @@ -70,6 +70,7 @@ import type { UserOnboarding, UserPasswordCredentials, UsersBalanceHistoryResponse, + WebSocketNotification, } from "./types"; import { environment } from "@/services/environment"; @@ -1346,6 +1347,7 @@ type WebsocketMessageTypeMap = { subscribe_graph_executions: { graph_id: GraphID }; graph_execution_event: GraphExecution; node_execution_event: NodeExecutionResult; + notification: WebSocketNotification; heartbeat: "ping" | "pong"; }; diff --git a/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts b/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts index 1419ffb686..ea82e34d3e 100644 --- a/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts +++ b/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts @@ -976,6 +976,20 @@ export interface UserOnboarding { agentRuns: number; } +export interface OnboardingNotificationPayload { + type: "onboarding"; + event: string; + step: OnboardingStep; +} + +export type WebSocketNotification = + | OnboardingNotificationPayload + | { + type: string; + event: string; + [key: string]: unknown; + }; + /* *** UTILITIES *** */ /** Use branded types for IDs -> deny mixing IDs between different object classes */