From 9a661b510122f73a438176e0e9844dbab5fda044 Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Thu, 20 Mar 2025 17:54:04 +0100 Subject: [PATCH] fix(backend/ws): Add `user_id` to websocket event subscription key (#9660) - Add `user_id` to WS subscription key - Add error catching to WS message handler --- .../backend/backend/data/execution.py | 33 ++++--- .../backend/backend/data/graph.py | 3 +- .../backend/backend/server/conn_manager.py | 12 ++- .../backend/backend/server/ws_api.py | 97 ++++++++++++++++--- .../backend/test/server/test_con_manager.py | 51 ++++++++-- .../backend/test/server/test_ws_api.py | 49 ++++++++-- 6 files changed, 197 insertions(+), 48 deletions(-) diff --git a/autogpt_platform/backend/backend/data/execution.py b/autogpt_platform/backend/backend/data/execution.py index aef43bc9ae..397187d5f5 100644 --- a/autogpt_platform/backend/backend/data/execution.py +++ b/autogpt_platform/backend/backend/data/execution.py @@ -1,7 +1,7 @@ from collections import defaultdict from datetime import datetime, timezone from multiprocessing import Manager -from typing import Any, AsyncGenerator, Generator, Generic, Type, TypeVar +from typing import Any, AsyncGenerator, Generator, Generic, Optional, Type, TypeVar from prisma import Json from prisma.enums import AgentExecutionStatus @@ -65,6 +65,7 @@ class ExecutionQueue(Generic[T]): class ExecutionResult(BaseModel): + user_id: str graph_id: str graph_version: int graph_exec_id: str @@ -80,27 +81,28 @@ class ExecutionResult(BaseModel): end_time: datetime | None @staticmethod - def from_graph(graph: AgentGraphExecution): + def from_graph(graph_exec: AgentGraphExecution): return ExecutionResult( - graph_id=graph.agentGraphId, - graph_version=graph.agentGraphVersion, - graph_exec_id=graph.id, + user_id=graph_exec.userId, + graph_id=graph_exec.agentGraphId, + graph_version=graph_exec.agentGraphVersion, + graph_exec_id=graph_exec.id, node_exec_id="", node_id="", block_id="", - status=graph.executionStatus, + status=graph_exec.executionStatus, # TODO: Populate input_data & output_data from AgentNodeExecutions # Input & Output comes AgentInputBlock & AgentOutputBlock. input_data={}, output_data={}, - add_time=graph.createdAt, - queue_time=graph.createdAt, - start_time=graph.startedAt, - end_time=graph.updatedAt, + add_time=graph_exec.createdAt, + queue_time=graph_exec.createdAt, + start_time=graph_exec.startedAt, + end_time=graph_exec.updatedAt, ) @staticmethod - def from_db(execution: AgentNodeExecution): + def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None): if execution.executionData: # Execution that has been queued for execution will persist its data. input_data = type.convert(execution.executionData, dict[str, Any]) @@ -115,8 +117,15 @@ class ExecutionResult(BaseModel): output_data[data.name].append(type.convert(data.data, Type[Any])) graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution + if graph_execution: + user_id = graph_execution.userId + elif not user_id: + raise ValueError( + "AgentGraphExecution must be included or user_id passed in" + ) return ExecutionResult( + user_id=user_id, graph_id=graph_execution.agentGraphId if graph_execution else "", graph_version=graph_execution.agentGraphVersion if graph_execution else 0, graph_exec_id=execution.agentGraphExecutionId, @@ -175,7 +184,7 @@ async def create_graph_execution( ) return result.id, [ - ExecutionResult.from_db(execution) + ExecutionResult.from_db(execution, result.userId) for execution in result.AgentNodeExecutions or [] ] diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index e44f331a5d..693ca14023 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -217,7 +217,8 @@ class GraphExecution(GraphExecutionMeta): graph_exec = GraphExecutionMeta.from_db(_graph_exec) node_executions = [ - ExecutionResult.from_db(ne) for ne in _graph_exec.AgentNodeExecutions + ExecutionResult.from_db(ne, _graph_exec.userId) + for ne in _graph_exec.AgentNodeExecutions ] inputs = { diff --git a/autogpt_platform/backend/backend/server/conn_manager.py b/autogpt_platform/backend/backend/server/conn_manager.py index da62108540..b128debf30 100644 --- a/autogpt_platform/backend/backend/server/conn_manager.py +++ b/autogpt_platform/backend/backend/server/conn_manager.py @@ -20,23 +20,25 @@ class ConnectionManager: for subscribers in self.subscriptions.values(): subscribers.discard(websocket) - async def subscribe(self, graph_id: str, graph_version: int, websocket: WebSocket): - key = f"{graph_id}_{graph_version}" + async def subscribe( + self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket + ): + key = f"{user_id}_{graph_id}_{graph_version}" if key not in self.subscriptions: self.subscriptions[key] = set() self.subscriptions[key].add(websocket) async def unsubscribe( - self, graph_id: str, graph_version: int, websocket: WebSocket + self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket ): - key = f"{graph_id}_{graph_version}" + key = f"{user_id}_{graph_id}_{graph_version}" if key in self.subscriptions: self.subscriptions[key].discard(websocket) if not self.subscriptions[key]: del self.subscriptions[key] async def send_execution_result(self, result: execution.ExecutionResult): - key = f"{result.graph_id}_{result.graph_version}" + key = f"{result.user_id}_{result.graph_id}_{result.graph_version}" if key in self.subscriptions: message = WsMessage( method=Methods.EXECUTION_EVENT, diff --git a/autogpt_platform/backend/backend/server/ws_api.py b/autogpt_platform/backend/backend/server/ws_api.py index 25492dd304..d826ee8364 100644 --- a/autogpt_platform/backend/backend/server/ws_api.py +++ b/autogpt_platform/backend/backend/server/ws_api.py @@ -4,6 +4,7 @@ from contextlib import asynccontextmanager import uvicorn from autogpt_libs.auth import parse_jwt_token +from autogpt_libs.utils.cache import thread_cached from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect from starlette.middleware.cors import CORSMiddleware @@ -12,7 +13,7 @@ from backend.data.execution import AsyncRedisExecutionEventBus from backend.data.user import DEFAULT_USER_ID from backend.server.conn_manager import ConnectionManager from backend.server.model import ExecutionSubscription, Methods, WsMessage -from backend.util.service import AppProcess +from backend.util.service import AppProcess, get_service_client from backend.util.settings import AppEnvironment, Config, Settings logger = logging.getLogger(__name__) @@ -39,6 +40,13 @@ def get_connection_manager(): return _connection_manager +@thread_cached +def get_db_client(): + from backend.executor import DatabaseManager + + return get_service_client(DatabaseManager) + + async def event_broadcaster(manager: ConnectionManager): try: redis.connect() @@ -74,7 +82,10 @@ async def authenticate_websocket(websocket: WebSocket) -> str: async def handle_subscribe( - websocket: WebSocket, manager: ConnectionManager, message: WsMessage + connection_manager: ConnectionManager, + websocket: WebSocket, + user_id: str, + message: WsMessage, ): if not message.data: await websocket.send_text( @@ -85,20 +96,47 @@ async def handle_subscribe( ).model_dump_json() ) else: - ex_sub = ExecutionSubscription.model_validate(message.data) - await manager.subscribe(ex_sub.graph_id, ex_sub.graph_version, websocket) - logger.debug(f"New execution subscription for graph {ex_sub.graph_id}") + sub_req = ExecutionSubscription.model_validate(message.data) + + # Verify that user has read access to graph + # if not get_db_client().get_graph( + # graph_id=sub_req.graph_id, + # version=sub_req.graph_version, + # user_id=user_id, + # ): + # await websocket.send_text( + # WsMessage( + # method=Methods.ERROR, + # success=False, + # error="Access denied", + # ).model_dump_json() + # ) + # return + + await connection_manager.subscribe( + user_id=user_id, + graph_id=sub_req.graph_id, + graph_version=sub_req.graph_version, + websocket=websocket, + ) + logger.debug( + f"New execution subscription for user #{user_id} " + f"graph #{sub_req.graph_id}v{sub_req.graph_version}" + ) await websocket.send_text( WsMessage( method=Methods.SUBSCRIBE, success=True, - channel=f"{ex_sub.graph_id}_{ex_sub.graph_version}", + channel=f"{user_id}_{sub_req.graph_id}_{sub_req.graph_version}", ).model_dump_json() ) async def handle_unsubscribe( - websocket: WebSocket, manager: ConnectionManager, message: WsMessage + connection_manager: ConnectionManager, + websocket: WebSocket, + user_id: str, + message: WsMessage, ): if not message.data: await websocket.send_text( @@ -109,14 +147,22 @@ async def handle_unsubscribe( ).model_dump_json() ) else: - ex_sub = ExecutionSubscription.model_validate(message.data) - await manager.unsubscribe(ex_sub.graph_id, ex_sub.graph_version, websocket) - logger.debug(f"Removed execution subscription for graph {ex_sub.graph_id}") + unsub_req = ExecutionSubscription.model_validate(message.data) + await connection_manager.unsubscribe( + user_id=user_id, + graph_id=unsub_req.graph_id, + graph_version=unsub_req.graph_version, + websocket=websocket, + ) + logger.debug( + f"Removed execution subscription for user #{user_id} " + f"graph #{unsub_req.graph_id}v{unsub_req.graph_version}" + ) await websocket.send_text( WsMessage( method=Methods.UNSUBSCRIBE, success=True, - channel=f"{ex_sub.graph_id}_{ex_sub.graph_version}", + channel=f"{unsub_req.graph_id}_{unsub_req.graph_version}", ).model_dump_json() ) @@ -145,13 +191,32 @@ async def websocket_router( ) continue - if message.method == Methods.SUBSCRIBE: - await handle_subscribe(websocket, manager, message) + try: + if message.method == Methods.SUBSCRIBE: + await handle_subscribe( + connection_manager=manager, + websocket=websocket, + user_id=user_id, + message=message, + ) + continue - elif message.method == Methods.UNSUBSCRIBE: - await handle_unsubscribe(websocket, manager, message) + elif message.method == Methods.UNSUBSCRIBE: + await handle_unsubscribe( + connection_manager=manager, + websocket=websocket, + user_id=user_id, + message=message, + ) + continue + except Exception as e: + logger.error( + f"Error while handling '{message.method}' message " + f"for user #{user_id}: {e}" + ) + continue - elif message.method == Methods.ERROR: + if message.method == Methods.ERROR: logger.error(f"WebSocket Error message received: {message.data}") else: diff --git a/autogpt_platform/backend/test/server/test_con_manager.py b/autogpt_platform/backend/test/server/test_con_manager.py index e9acab9dd7..774f8c4a50 100644 --- a/autogpt_platform/backend/test/server/test_con_manager.py +++ b/autogpt_platform/backend/test/server/test_con_manager.py @@ -46,17 +46,27 @@ def test_disconnect( async def test_subscribe( connection_manager: ConnectionManager, mock_websocket: AsyncMock ) -> None: - await connection_manager.subscribe("test_graph", 1, mock_websocket) - assert mock_websocket in connection_manager.subscriptions["test_graph_1"] + await connection_manager.subscribe( + user_id="user-1", + graph_id="test_graph", + graph_version=1, + websocket=mock_websocket, + ) + assert mock_websocket in connection_manager.subscriptions["user-1_test_graph_1"] @pytest.mark.asyncio async def test_unsubscribe( connection_manager: ConnectionManager, mock_websocket: AsyncMock ) -> None: - connection_manager.subscriptions["test_graph_1"] = {mock_websocket} + connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket} - await connection_manager.unsubscribe("test_graph", 1, mock_websocket) + await connection_manager.unsubscribe( + user_id="user-1", + graph_id="test_graph", + graph_version=1, + websocket=mock_websocket, + ) assert "test_graph" not in connection_manager.subscriptions @@ -65,8 +75,9 @@ async def test_unsubscribe( async def test_send_execution_result( connection_manager: ConnectionManager, mock_websocket: AsyncMock ) -> None: - connection_manager.subscriptions["test_graph_1"] = {mock_websocket} + connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket} result: ExecutionResult = ExecutionResult( + user_id="user-1", graph_id="test_graph", graph_version=1, graph_exec_id="test_exec_id", @@ -87,17 +98,45 @@ async def test_send_execution_result( mock_websocket.send_text.assert_called_once_with( WsMessage( method=Methods.EXECUTION_EVENT, - channel="test_graph_1", + channel="user-1_test_graph_1", data=result.model_dump(), ).model_dump_json() ) +@pytest.mark.asyncio +async def test_send_execution_result_user_mismatch( + connection_manager: ConnectionManager, mock_websocket: AsyncMock +) -> None: + connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket} + result: ExecutionResult = ExecutionResult( + user_id="user-2", + graph_id="test_graph", + graph_version=1, + graph_exec_id="test_exec_id", + node_exec_id="test_node_exec_id", + node_id="test_node_id", + block_id="test_block_id", + status=ExecutionStatus.COMPLETED, + input_data={"input1": "value1"}, + output_data={"output1": ["result1"]}, + add_time=datetime.now(tz=timezone.utc), + queue_time=None, + start_time=datetime.now(tz=timezone.utc), + end_time=datetime.now(tz=timezone.utc), + ) + + await connection_manager.send_execution_result(result) + + mock_websocket.send_text.assert_not_called() + + @pytest.mark.asyncio async def test_send_execution_result_no_subscribers( connection_manager: ConnectionManager, mock_websocket: AsyncMock ) -> None: result: ExecutionResult = ExecutionResult( + user_id="user-1", graph_id="test_graph", graph_version=1, graph_exec_id="test_exec_id", diff --git a/autogpt_platform/backend/test/server/test_ws_api.py b/autogpt_platform/backend/test/server/test_ws_api.py index 069b6c772f..e689742fd5 100644 --- a/autogpt_platform/backend/test/server/test_ws_api.py +++ b/autogpt_platform/backend/test/server/test_ws_api.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock import pytest from fastapi import WebSocket, WebSocketDisconnect +from backend.data.user import DEFAULT_USER_ID from backend.server.conn_manager import ConnectionManager from backend.server.ws_api import ( Methods, @@ -41,7 +42,12 @@ async def test_websocket_router_subscribe( ) mock_manager.connect.assert_called_once_with(mock_websocket) - mock_manager.subscribe.assert_called_once_with("test_graph", 1, mock_websocket) + mock_manager.subscribe.assert_called_once_with( + user_id=DEFAULT_USER_ID, + graph_id="test_graph", + graph_version=1, + websocket=mock_websocket, + ) mock_websocket.send_text.assert_called_once() assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0] assert '"success":true' in mock_websocket.send_text.call_args[0][0] @@ -65,7 +71,12 @@ async def test_websocket_router_unsubscribe( ) mock_manager.connect.assert_called_once_with(mock_websocket) - mock_manager.unsubscribe.assert_called_once_with("test_graph", 1, mock_websocket) + mock_manager.unsubscribe.assert_called_once_with( + user_id=DEFAULT_USER_ID, + graph_id="test_graph", + graph_version=1, + websocket=mock_websocket, + ) mock_websocket.send_text.assert_called_once() assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0] assert '"success":true' in mock_websocket.send_text.call_args[0][0] @@ -101,10 +112,18 @@ async def test_handle_subscribe_success( ) await handle_subscribe( - cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message + connection_manager=cast(ConnectionManager, mock_manager), + websocket=cast(WebSocket, mock_websocket), + user_id="user-1", + message=message, ) - mock_manager.subscribe.assert_called_once_with("test_graph", 1, mock_websocket) + mock_manager.subscribe.assert_called_once_with( + user_id="user-1", + graph_id="test_graph", + graph_version=1, + websocket=mock_websocket, + ) mock_websocket.send_text.assert_called_once() assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0] assert '"success":true' in mock_websocket.send_text.call_args[0][0] @@ -117,7 +136,10 @@ async def test_handle_subscribe_missing_data( message = WsMessage(method=Methods.SUBSCRIBE) await handle_subscribe( - cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message + connection_manager=cast(ConnectionManager, mock_manager), + websocket=cast(WebSocket, mock_websocket), + user_id="user-1", + message=message, ) mock_manager.subscribe.assert_not_called() @@ -135,10 +157,18 @@ async def test_handle_unsubscribe_success( ) await handle_unsubscribe( - cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message + connection_manager=cast(ConnectionManager, mock_manager), + websocket=cast(WebSocket, mock_websocket), + user_id="user-1", + message=message, ) - mock_manager.unsubscribe.assert_called_once_with("test_graph", 1, mock_websocket) + mock_manager.unsubscribe.assert_called_once_with( + user_id="user-1", + graph_id="test_graph", + graph_version=1, + websocket=mock_websocket, + ) mock_websocket.send_text.assert_called_once() assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0] assert '"success":true' in mock_websocket.send_text.call_args[0][0] @@ -151,7 +181,10 @@ async def test_handle_unsubscribe_missing_data( message = WsMessage(method=Methods.UNSUBSCRIBE) await handle_unsubscribe( - cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message + connection_manager=cast(ConnectionManager, mock_manager), + websocket=cast(WebSocket, mock_websocket), + user_id="user-1", + message=message, ) mock_manager.unsubscribe.assert_not_called()