fix(backend/ws): Add user_id to websocket event subscription key (#9660)

- Add `user_id` to WS subscription key
- Add error catching to WS message handler
This commit is contained in:
Reinier van der Leer
2025-03-20 17:54:04 +01:00
committed by GitHub
parent 90b147ff51
commit 9a661b5101
6 changed files with 197 additions and 48 deletions

View File

@@ -1,7 +1,7 @@
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
from multiprocessing import Manager from multiprocessing import Manager
from typing import Any, AsyncGenerator, Generator, Generic, Type, TypeVar from typing import Any, AsyncGenerator, Generator, Generic, Optional, Type, TypeVar
from prisma import Json from prisma import Json
from prisma.enums import AgentExecutionStatus from prisma.enums import AgentExecutionStatus
@@ -65,6 +65,7 @@ class ExecutionQueue(Generic[T]):
class ExecutionResult(BaseModel): class ExecutionResult(BaseModel):
user_id: str
graph_id: str graph_id: str
graph_version: int graph_version: int
graph_exec_id: str graph_exec_id: str
@@ -80,27 +81,28 @@ class ExecutionResult(BaseModel):
end_time: datetime | None end_time: datetime | None
@staticmethod @staticmethod
def from_graph(graph: AgentGraphExecution): def from_graph(graph_exec: AgentGraphExecution):
return ExecutionResult( return ExecutionResult(
graph_id=graph.agentGraphId, user_id=graph_exec.userId,
graph_version=graph.agentGraphVersion, graph_id=graph_exec.agentGraphId,
graph_exec_id=graph.id, graph_version=graph_exec.agentGraphVersion,
graph_exec_id=graph_exec.id,
node_exec_id="", node_exec_id="",
node_id="", node_id="",
block_id="", block_id="",
status=graph.executionStatus, status=graph_exec.executionStatus,
# TODO: Populate input_data & output_data from AgentNodeExecutions # TODO: Populate input_data & output_data from AgentNodeExecutions
# Input & Output comes AgentInputBlock & AgentOutputBlock. # Input & Output comes AgentInputBlock & AgentOutputBlock.
input_data={}, input_data={},
output_data={}, output_data={},
add_time=graph.createdAt, add_time=graph_exec.createdAt,
queue_time=graph.createdAt, queue_time=graph_exec.createdAt,
start_time=graph.startedAt, start_time=graph_exec.startedAt,
end_time=graph.updatedAt, end_time=graph_exec.updatedAt,
) )
@staticmethod @staticmethod
def from_db(execution: AgentNodeExecution): def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
if execution.executionData: if execution.executionData:
# Execution that has been queued for execution will persist its data. # Execution that has been queued for execution will persist its data.
input_data = type.convert(execution.executionData, dict[str, Any]) input_data = type.convert(execution.executionData, dict[str, Any])
@@ -115,8 +117,15 @@ class ExecutionResult(BaseModel):
output_data[data.name].append(type.convert(data.data, Type[Any])) output_data[data.name].append(type.convert(data.data, Type[Any]))
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
if graph_execution:
user_id = graph_execution.userId
elif not user_id:
raise ValueError(
"AgentGraphExecution must be included or user_id passed in"
)
return ExecutionResult( return ExecutionResult(
user_id=user_id,
graph_id=graph_execution.agentGraphId if graph_execution else "", graph_id=graph_execution.agentGraphId if graph_execution else "",
graph_version=graph_execution.agentGraphVersion if graph_execution else 0, graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
graph_exec_id=execution.agentGraphExecutionId, graph_exec_id=execution.agentGraphExecutionId,
@@ -175,7 +184,7 @@ async def create_graph_execution(
) )
return result.id, [ return result.id, [
ExecutionResult.from_db(execution) ExecutionResult.from_db(execution, result.userId)
for execution in result.AgentNodeExecutions or [] for execution in result.AgentNodeExecutions or []
] ]

View File

@@ -217,7 +217,8 @@ class GraphExecution(GraphExecutionMeta):
graph_exec = GraphExecutionMeta.from_db(_graph_exec) graph_exec = GraphExecutionMeta.from_db(_graph_exec)
node_executions = [ node_executions = [
ExecutionResult.from_db(ne) for ne in _graph_exec.AgentNodeExecutions ExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
] ]
inputs = { inputs = {

View File

@@ -20,23 +20,25 @@ class ConnectionManager:
for subscribers in self.subscriptions.values(): for subscribers in self.subscriptions.values():
subscribers.discard(websocket) subscribers.discard(websocket)
async def subscribe(self, graph_id: str, graph_version: int, websocket: WebSocket): async def subscribe(
key = f"{graph_id}_{graph_version}" self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
):
key = f"{user_id}_{graph_id}_{graph_version}"
if key not in self.subscriptions: if key not in self.subscriptions:
self.subscriptions[key] = set() self.subscriptions[key] = set()
self.subscriptions[key].add(websocket) self.subscriptions[key].add(websocket)
async def unsubscribe( async def unsubscribe(
self, graph_id: str, graph_version: int, websocket: WebSocket self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
): ):
key = f"{graph_id}_{graph_version}" key = f"{user_id}_{graph_id}_{graph_version}"
if key in self.subscriptions: if key in self.subscriptions:
self.subscriptions[key].discard(websocket) self.subscriptions[key].discard(websocket)
if not self.subscriptions[key]: if not self.subscriptions[key]:
del self.subscriptions[key] del self.subscriptions[key]
async def send_execution_result(self, result: execution.ExecutionResult): async def send_execution_result(self, result: execution.ExecutionResult):
key = f"{result.graph_id}_{result.graph_version}" key = f"{result.user_id}_{result.graph_id}_{result.graph_version}"
if key in self.subscriptions: if key in self.subscriptions:
message = WsMessage( message = WsMessage(
method=Methods.EXECUTION_EVENT, method=Methods.EXECUTION_EVENT,

View File

@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
import uvicorn import uvicorn
from autogpt_libs.auth import parse_jwt_token from autogpt_libs.auth import parse_jwt_token
from autogpt_libs.utils.cache import thread_cached
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
@@ -12,7 +13,7 @@ from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.user import DEFAULT_USER_ID from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager from backend.server.conn_manager import ConnectionManager
from backend.server.model import ExecutionSubscription, Methods, WsMessage from backend.server.model import ExecutionSubscription, Methods, WsMessage
from backend.util.service import AppProcess from backend.util.service import AppProcess, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings from backend.util.settings import AppEnvironment, Config, Settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -39,6 +40,13 @@ def get_connection_manager():
return _connection_manager return _connection_manager
@thread_cached
def get_db_client():
from backend.executor import DatabaseManager
return get_service_client(DatabaseManager)
async def event_broadcaster(manager: ConnectionManager): async def event_broadcaster(manager: ConnectionManager):
try: try:
redis.connect() redis.connect()
@@ -74,7 +82,10 @@ async def authenticate_websocket(websocket: WebSocket) -> str:
async def handle_subscribe( async def handle_subscribe(
websocket: WebSocket, manager: ConnectionManager, message: WsMessage connection_manager: ConnectionManager,
websocket: WebSocket,
user_id: str,
message: WsMessage,
): ):
if not message.data: if not message.data:
await websocket.send_text( await websocket.send_text(
@@ -85,20 +96,47 @@ async def handle_subscribe(
).model_dump_json() ).model_dump_json()
) )
else: else:
ex_sub = ExecutionSubscription.model_validate(message.data) sub_req = ExecutionSubscription.model_validate(message.data)
await manager.subscribe(ex_sub.graph_id, ex_sub.graph_version, websocket)
logger.debug(f"New execution subscription for graph {ex_sub.graph_id}") # Verify that user has read access to graph
# if not get_db_client().get_graph(
# graph_id=sub_req.graph_id,
# version=sub_req.graph_version,
# user_id=user_id,
# ):
# await websocket.send_text(
# WsMessage(
# method=Methods.ERROR,
# success=False,
# error="Access denied",
# ).model_dump_json()
# )
# return
await connection_manager.subscribe(
user_id=user_id,
graph_id=sub_req.graph_id,
graph_version=sub_req.graph_version,
websocket=websocket,
)
logger.debug(
f"New execution subscription for user #{user_id} "
f"graph #{sub_req.graph_id}v{sub_req.graph_version}"
)
await websocket.send_text( await websocket.send_text(
WsMessage( WsMessage(
method=Methods.SUBSCRIBE, method=Methods.SUBSCRIBE,
success=True, success=True,
channel=f"{ex_sub.graph_id}_{ex_sub.graph_version}", channel=f"{user_id}_{sub_req.graph_id}_{sub_req.graph_version}",
).model_dump_json() ).model_dump_json()
) )
async def handle_unsubscribe( async def handle_unsubscribe(
websocket: WebSocket, manager: ConnectionManager, message: WsMessage connection_manager: ConnectionManager,
websocket: WebSocket,
user_id: str,
message: WsMessage,
): ):
if not message.data: if not message.data:
await websocket.send_text( await websocket.send_text(
@@ -109,14 +147,22 @@ async def handle_unsubscribe(
).model_dump_json() ).model_dump_json()
) )
else: else:
ex_sub = ExecutionSubscription.model_validate(message.data) unsub_req = ExecutionSubscription.model_validate(message.data)
await manager.unsubscribe(ex_sub.graph_id, ex_sub.graph_version, websocket) await connection_manager.unsubscribe(
logger.debug(f"Removed execution subscription for graph {ex_sub.graph_id}") user_id=user_id,
graph_id=unsub_req.graph_id,
graph_version=unsub_req.graph_version,
websocket=websocket,
)
logger.debug(
f"Removed execution subscription for user #{user_id} "
f"graph #{unsub_req.graph_id}v{unsub_req.graph_version}"
)
await websocket.send_text( await websocket.send_text(
WsMessage( WsMessage(
method=Methods.UNSUBSCRIBE, method=Methods.UNSUBSCRIBE,
success=True, success=True,
channel=f"{ex_sub.graph_id}_{ex_sub.graph_version}", channel=f"{unsub_req.graph_id}_{unsub_req.graph_version}",
).model_dump_json() ).model_dump_json()
) )
@@ -145,13 +191,32 @@ async def websocket_router(
) )
continue continue
try:
if message.method == Methods.SUBSCRIBE: if message.method == Methods.SUBSCRIBE:
await handle_subscribe(websocket, manager, message) await handle_subscribe(
connection_manager=manager,
websocket=websocket,
user_id=user_id,
message=message,
)
continue
elif message.method == Methods.UNSUBSCRIBE: elif message.method == Methods.UNSUBSCRIBE:
await handle_unsubscribe(websocket, manager, message) await handle_unsubscribe(
connection_manager=manager,
websocket=websocket,
user_id=user_id,
message=message,
)
continue
except Exception as e:
logger.error(
f"Error while handling '{message.method}' message "
f"for user #{user_id}: {e}"
)
continue
elif message.method == Methods.ERROR: if message.method == Methods.ERROR:
logger.error(f"WebSocket Error message received: {message.data}") logger.error(f"WebSocket Error message received: {message.data}")
else: else:

View File

@@ -46,17 +46,27 @@ def test_disconnect(
async def test_subscribe( async def test_subscribe(
connection_manager: ConnectionManager, mock_websocket: AsyncMock connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None: ) -> None:
await connection_manager.subscribe("test_graph", 1, mock_websocket) await connection_manager.subscribe(
assert mock_websocket in connection_manager.subscriptions["test_graph_1"] user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
assert mock_websocket in connection_manager.subscriptions["user-1_test_graph_1"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unsubscribe( async def test_unsubscribe(
connection_manager: ConnectionManager, mock_websocket: AsyncMock connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None: ) -> None:
connection_manager.subscriptions["test_graph_1"] = {mock_websocket} connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
await connection_manager.unsubscribe("test_graph", 1, mock_websocket) await connection_manager.unsubscribe(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
assert "test_graph" not in connection_manager.subscriptions assert "test_graph" not in connection_manager.subscriptions
@@ -65,8 +75,9 @@ async def test_unsubscribe(
async def test_send_execution_result( async def test_send_execution_result(
connection_manager: ConnectionManager, mock_websocket: AsyncMock connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None: ) -> None:
connection_manager.subscriptions["test_graph_1"] = {mock_websocket} connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
result: ExecutionResult = ExecutionResult( result: ExecutionResult = ExecutionResult(
user_id="user-1",
graph_id="test_graph", graph_id="test_graph",
graph_version=1, graph_version=1,
graph_exec_id="test_exec_id", graph_exec_id="test_exec_id",
@@ -87,17 +98,45 @@ async def test_send_execution_result(
mock_websocket.send_text.assert_called_once_with( mock_websocket.send_text.assert_called_once_with(
WsMessage( WsMessage(
method=Methods.EXECUTION_EVENT, method=Methods.EXECUTION_EVENT,
channel="test_graph_1", channel="user-1_test_graph_1",
data=result.model_dump(), data=result.model_dump(),
).model_dump_json() ).model_dump_json()
) )
@pytest.mark.asyncio
async def test_send_execution_result_user_mismatch(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
result: ExecutionResult = ExecutionResult(
user_id="user-2",
graph_id="test_graph",
graph_version=1,
graph_exec_id="test_exec_id",
node_exec_id="test_node_exec_id",
node_id="test_node_id",
block_id="test_block_id",
status=ExecutionStatus.COMPLETED,
input_data={"input1": "value1"},
output_data={"output1": ["result1"]},
add_time=datetime.now(tz=timezone.utc),
queue_time=None,
start_time=datetime.now(tz=timezone.utc),
end_time=datetime.now(tz=timezone.utc),
)
await connection_manager.send_execution_result(result)
mock_websocket.send_text.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_execution_result_no_subscribers( async def test_send_execution_result_no_subscribers(
connection_manager: ConnectionManager, mock_websocket: AsyncMock connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None: ) -> None:
result: ExecutionResult = ExecutionResult( result: ExecutionResult = ExecutionResult(
user_id="user-1",
graph_id="test_graph", graph_id="test_graph",
graph_version=1, graph_version=1,
graph_exec_id="test_exec_id", graph_exec_id="test_exec_id",

View File

@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock
import pytest import pytest
from fastapi import WebSocket, WebSocketDisconnect from fastapi import WebSocket, WebSocketDisconnect
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager from backend.server.conn_manager import ConnectionManager
from backend.server.ws_api import ( from backend.server.ws_api import (
Methods, Methods,
@@ -41,7 +42,12 @@ async def test_websocket_router_subscribe(
) )
mock_manager.connect.assert_called_once_with(mock_websocket) mock_manager.connect.assert_called_once_with(mock_websocket)
mock_manager.subscribe.assert_called_once_with("test_graph", 1, mock_websocket) mock_manager.subscribe.assert_called_once_with(
user_id=DEFAULT_USER_ID,
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once() mock_websocket.send_text.assert_called_once()
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0] assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0] assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -65,7 +71,12 @@ async def test_websocket_router_unsubscribe(
) )
mock_manager.connect.assert_called_once_with(mock_websocket) mock_manager.connect.assert_called_once_with(mock_websocket)
mock_manager.unsubscribe.assert_called_once_with("test_graph", 1, mock_websocket) mock_manager.unsubscribe.assert_called_once_with(
user_id=DEFAULT_USER_ID,
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once() mock_websocket.send_text.assert_called_once()
assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0] assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0] assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -101,10 +112,18 @@ async def test_handle_subscribe_success(
) )
await handle_subscribe( await handle_subscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
) )
mock_manager.subscribe.assert_called_once_with("test_graph", 1, mock_websocket) mock_manager.subscribe.assert_called_once_with(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once() mock_websocket.send_text.assert_called_once()
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0] assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0] assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -117,7 +136,10 @@ async def test_handle_subscribe_missing_data(
message = WsMessage(method=Methods.SUBSCRIBE) message = WsMessage(method=Methods.SUBSCRIBE)
await handle_subscribe( await handle_subscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
) )
mock_manager.subscribe.assert_not_called() mock_manager.subscribe.assert_not_called()
@@ -135,10 +157,18 @@ async def test_handle_unsubscribe_success(
) )
await handle_unsubscribe( await handle_unsubscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
) )
mock_manager.unsubscribe.assert_called_once_with("test_graph", 1, mock_websocket) mock_manager.unsubscribe.assert_called_once_with(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once() mock_websocket.send_text.assert_called_once()
assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0] assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0] assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -151,7 +181,10 @@ async def test_handle_unsubscribe_missing_data(
message = WsMessage(method=Methods.UNSUBSCRIBE) message = WsMessage(method=Methods.UNSUBSCRIBE)
await handle_unsubscribe( await handle_unsubscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
) )
mock_manager.unsubscribe.assert_not_called() mock_manager.unsubscribe.assert_not_called()