fix(backend/ws): Add user_id to websocket event subscription key (#9660)

- Add `user_id` to WS subscription key
- Add error catching to WS message handler
This commit is contained in:
Reinier van der Leer
2025-03-20 17:54:04 +01:00
committed by GitHub
parent 90b147ff51
commit 9a661b5101
6 changed files with 197 additions and 48 deletions

View File

@@ -1,7 +1,7 @@
from collections import defaultdict
from datetime import datetime, timezone
from multiprocessing import Manager
from typing import Any, AsyncGenerator, Generator, Generic, Type, TypeVar
from typing import Any, AsyncGenerator, Generator, Generic, Optional, Type, TypeVar
from prisma import Json
from prisma.enums import AgentExecutionStatus
@@ -65,6 +65,7 @@ class ExecutionQueue(Generic[T]):
class ExecutionResult(BaseModel):
user_id: str
graph_id: str
graph_version: int
graph_exec_id: str
@@ -80,27 +81,28 @@ class ExecutionResult(BaseModel):
end_time: datetime | None
@staticmethod
def from_graph(graph: AgentGraphExecution):
def from_graph(graph_exec: AgentGraphExecution):
return ExecutionResult(
graph_id=graph.agentGraphId,
graph_version=graph.agentGraphVersion,
graph_exec_id=graph.id,
user_id=graph_exec.userId,
graph_id=graph_exec.agentGraphId,
graph_version=graph_exec.agentGraphVersion,
graph_exec_id=graph_exec.id,
node_exec_id="",
node_id="",
block_id="",
status=graph.executionStatus,
status=graph_exec.executionStatus,
# TODO: Populate input_data & output_data from AgentNodeExecutions
# Input & Output comes AgentInputBlock & AgentOutputBlock.
input_data={},
output_data={},
add_time=graph.createdAt,
queue_time=graph.createdAt,
start_time=graph.startedAt,
end_time=graph.updatedAt,
add_time=graph_exec.createdAt,
queue_time=graph_exec.createdAt,
start_time=graph_exec.startedAt,
end_time=graph_exec.updatedAt,
)
@staticmethod
def from_db(execution: AgentNodeExecution):
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
if execution.executionData:
# Execution that has been queued for execution will persist its data.
input_data = type.convert(execution.executionData, dict[str, Any])
@@ -115,8 +117,15 @@ class ExecutionResult(BaseModel):
output_data[data.name].append(type.convert(data.data, Type[Any]))
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
if graph_execution:
user_id = graph_execution.userId
elif not user_id:
raise ValueError(
"AgentGraphExecution must be included or user_id passed in"
)
return ExecutionResult(
user_id=user_id,
graph_id=graph_execution.agentGraphId if graph_execution else "",
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
graph_exec_id=execution.agentGraphExecutionId,
@@ -175,7 +184,7 @@ async def create_graph_execution(
)
return result.id, [
ExecutionResult.from_db(execution)
ExecutionResult.from_db(execution, result.userId)
for execution in result.AgentNodeExecutions or []
]

View File

@@ -217,7 +217,8 @@ class GraphExecution(GraphExecutionMeta):
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
node_executions = [
ExecutionResult.from_db(ne) for ne in _graph_exec.AgentNodeExecutions
ExecutionResult.from_db(ne, _graph_exec.userId)
for ne in _graph_exec.AgentNodeExecutions
]
inputs = {

View File

@@ -20,23 +20,25 @@ class ConnectionManager:
for subscribers in self.subscriptions.values():
subscribers.discard(websocket)
async def subscribe(self, graph_id: str, graph_version: int, websocket: WebSocket):
key = f"{graph_id}_{graph_version}"
async def subscribe(
self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
):
key = f"{user_id}_{graph_id}_{graph_version}"
if key not in self.subscriptions:
self.subscriptions[key] = set()
self.subscriptions[key].add(websocket)
async def unsubscribe(
self, graph_id: str, graph_version: int, websocket: WebSocket
self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
):
key = f"{graph_id}_{graph_version}"
key = f"{user_id}_{graph_id}_{graph_version}"
if key in self.subscriptions:
self.subscriptions[key].discard(websocket)
if not self.subscriptions[key]:
del self.subscriptions[key]
async def send_execution_result(self, result: execution.ExecutionResult):
key = f"{result.graph_id}_{result.graph_version}"
key = f"{result.user_id}_{result.graph_id}_{result.graph_version}"
if key in self.subscriptions:
message = WsMessage(
method=Methods.EXECUTION_EVENT,

View File

@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
import uvicorn
from autogpt_libs.auth import parse_jwt_token
from autogpt_libs.utils.cache import thread_cached
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware
@@ -12,7 +13,7 @@ from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.model import ExecutionSubscription, Methods, WsMessage
from backend.util.service import AppProcess
from backend.util.service import AppProcess, get_service_client
from backend.util.settings import AppEnvironment, Config, Settings
logger = logging.getLogger(__name__)
@@ -39,6 +40,13 @@ def get_connection_manager():
return _connection_manager
@thread_cached
def get_db_client():
from backend.executor import DatabaseManager
return get_service_client(DatabaseManager)
async def event_broadcaster(manager: ConnectionManager):
try:
redis.connect()
@@ -74,7 +82,10 @@ async def authenticate_websocket(websocket: WebSocket) -> str:
async def handle_subscribe(
websocket: WebSocket, manager: ConnectionManager, message: WsMessage
connection_manager: ConnectionManager,
websocket: WebSocket,
user_id: str,
message: WsMessage,
):
if not message.data:
await websocket.send_text(
@@ -85,20 +96,47 @@ async def handle_subscribe(
).model_dump_json()
)
else:
ex_sub = ExecutionSubscription.model_validate(message.data)
await manager.subscribe(ex_sub.graph_id, ex_sub.graph_version, websocket)
logger.debug(f"New execution subscription for graph {ex_sub.graph_id}")
sub_req = ExecutionSubscription.model_validate(message.data)
# Verify that user has read access to graph
# if not get_db_client().get_graph(
# graph_id=sub_req.graph_id,
# version=sub_req.graph_version,
# user_id=user_id,
# ):
# await websocket.send_text(
# WsMessage(
# method=Methods.ERROR,
# success=False,
# error="Access denied",
# ).model_dump_json()
# )
# return
await connection_manager.subscribe(
user_id=user_id,
graph_id=sub_req.graph_id,
graph_version=sub_req.graph_version,
websocket=websocket,
)
logger.debug(
f"New execution subscription for user #{user_id} "
f"graph #{sub_req.graph_id}v{sub_req.graph_version}"
)
await websocket.send_text(
WsMessage(
method=Methods.SUBSCRIBE,
success=True,
channel=f"{ex_sub.graph_id}_{ex_sub.graph_version}",
channel=f"{user_id}_{sub_req.graph_id}_{sub_req.graph_version}",
).model_dump_json()
)
async def handle_unsubscribe(
websocket: WebSocket, manager: ConnectionManager, message: WsMessage
connection_manager: ConnectionManager,
websocket: WebSocket,
user_id: str,
message: WsMessage,
):
if not message.data:
await websocket.send_text(
@@ -109,14 +147,22 @@ async def handle_unsubscribe(
).model_dump_json()
)
else:
ex_sub = ExecutionSubscription.model_validate(message.data)
await manager.unsubscribe(ex_sub.graph_id, ex_sub.graph_version, websocket)
logger.debug(f"Removed execution subscription for graph {ex_sub.graph_id}")
unsub_req = ExecutionSubscription.model_validate(message.data)
await connection_manager.unsubscribe(
user_id=user_id,
graph_id=unsub_req.graph_id,
graph_version=unsub_req.graph_version,
websocket=websocket,
)
logger.debug(
f"Removed execution subscription for user #{user_id} "
f"graph #{unsub_req.graph_id}v{unsub_req.graph_version}"
)
await websocket.send_text(
WsMessage(
method=Methods.UNSUBSCRIBE,
success=True,
channel=f"{ex_sub.graph_id}_{ex_sub.graph_version}",
channel=f"{unsub_req.graph_id}_{unsub_req.graph_version}",
).model_dump_json()
)
@@ -145,13 +191,32 @@ async def websocket_router(
)
continue
if message.method == Methods.SUBSCRIBE:
await handle_subscribe(websocket, manager, message)
try:
if message.method == Methods.SUBSCRIBE:
await handle_subscribe(
connection_manager=manager,
websocket=websocket,
user_id=user_id,
message=message,
)
continue
elif message.method == Methods.UNSUBSCRIBE:
await handle_unsubscribe(websocket, manager, message)
elif message.method == Methods.UNSUBSCRIBE:
await handle_unsubscribe(
connection_manager=manager,
websocket=websocket,
user_id=user_id,
message=message,
)
continue
except Exception as e:
logger.error(
f"Error while handling '{message.method}' message "
f"for user #{user_id}: {e}"
)
continue
elif message.method == Methods.ERROR:
if message.method == Methods.ERROR:
logger.error(f"WebSocket Error message received: {message.data}")
else:

View File

@@ -46,17 +46,27 @@ def test_disconnect(
async def test_subscribe(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
await connection_manager.subscribe("test_graph", 1, mock_websocket)
assert mock_websocket in connection_manager.subscriptions["test_graph_1"]
await connection_manager.subscribe(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
assert mock_websocket in connection_manager.subscriptions["user-1_test_graph_1"]
@pytest.mark.asyncio
async def test_unsubscribe(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
connection_manager.subscriptions["test_graph_1"] = {mock_websocket}
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
await connection_manager.unsubscribe("test_graph", 1, mock_websocket)
await connection_manager.unsubscribe(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
assert "test_graph" not in connection_manager.subscriptions
@@ -65,8 +75,9 @@ async def test_unsubscribe(
async def test_send_execution_result(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
connection_manager.subscriptions["test_graph_1"] = {mock_websocket}
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
result: ExecutionResult = ExecutionResult(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
graph_exec_id="test_exec_id",
@@ -87,17 +98,45 @@ async def test_send_execution_result(
mock_websocket.send_text.assert_called_once_with(
WsMessage(
method=Methods.EXECUTION_EVENT,
channel="test_graph_1",
channel="user-1_test_graph_1",
data=result.model_dump(),
).model_dump_json()
)
@pytest.mark.asyncio
async def test_send_execution_result_user_mismatch(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
result: ExecutionResult = ExecutionResult(
user_id="user-2",
graph_id="test_graph",
graph_version=1,
graph_exec_id="test_exec_id",
node_exec_id="test_node_exec_id",
node_id="test_node_id",
block_id="test_block_id",
status=ExecutionStatus.COMPLETED,
input_data={"input1": "value1"},
output_data={"output1": ["result1"]},
add_time=datetime.now(tz=timezone.utc),
queue_time=None,
start_time=datetime.now(tz=timezone.utc),
end_time=datetime.now(tz=timezone.utc),
)
await connection_manager.send_execution_result(result)
mock_websocket.send_text.assert_not_called()
@pytest.mark.asyncio
async def test_send_execution_result_no_subscribers(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
result: ExecutionResult = ExecutionResult(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
graph_exec_id="test_exec_id",

View File

@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock
import pytest
from fastapi import WebSocket, WebSocketDisconnect
from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.ws_api import (
Methods,
@@ -41,7 +42,12 @@ async def test_websocket_router_subscribe(
)
mock_manager.connect.assert_called_once_with(mock_websocket)
mock_manager.subscribe.assert_called_once_with("test_graph", 1, mock_websocket)
mock_manager.subscribe.assert_called_once_with(
user_id=DEFAULT_USER_ID,
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -65,7 +71,12 @@ async def test_websocket_router_unsubscribe(
)
mock_manager.connect.assert_called_once_with(mock_websocket)
mock_manager.unsubscribe.assert_called_once_with("test_graph", 1, mock_websocket)
mock_manager.unsubscribe.assert_called_once_with(
user_id=DEFAULT_USER_ID,
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -101,10 +112,18 @@ async def test_handle_subscribe_success(
)
await handle_subscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)
mock_manager.subscribe.assert_called_once_with("test_graph", 1, mock_websocket)
mock_manager.subscribe.assert_called_once_with(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -117,7 +136,10 @@ async def test_handle_subscribe_missing_data(
message = WsMessage(method=Methods.SUBSCRIBE)
await handle_subscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)
mock_manager.subscribe.assert_not_called()
@@ -135,10 +157,18 @@ async def test_handle_unsubscribe_success(
)
await handle_unsubscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)
mock_manager.unsubscribe.assert_called_once_with("test_graph", 1, mock_websocket)
mock_manager.unsubscribe.assert_called_once_with(
user_id="user-1",
graph_id="test_graph",
graph_version=1,
websocket=mock_websocket,
)
mock_websocket.send_text.assert_called_once()
assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0]
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
@@ -151,7 +181,10 @@ async def test_handle_unsubscribe_missing_data(
message = WsMessage(method=Methods.UNSUBSCRIBE)
await handle_unsubscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
connection_manager=cast(ConnectionManager, mock_manager),
websocket=cast(WebSocket, mock_websocket),
user_id="user-1",
message=message,
)
mock_manager.unsubscribe.assert_not_called()