Files
AutoGPT/autogpt_platform/backend/test/server/test_con_manager.py
Reinier van der Leer 9a661b5101 fix(backend/ws): Add user_id to websocket event subscription key (#9660)
- Add `user_id` to WS subscription key
- Add error catching to WS message handler
2025-03-20 17:54:04 +01:00

158 lines
4.8 KiB
Python

from datetime import datetime, timezone
from unittest.mock import AsyncMock
import pytest
from fastapi import WebSocket
from backend.data.execution import ExecutionResult, ExecutionStatus
from backend.server.conn_manager import ConnectionManager
from backend.server.model import Methods, WsMessage
@pytest.fixture
def connection_manager() -> ConnectionManager:
    """Provide a brand-new ConnectionManager so tests never share state."""
    manager = ConnectionManager()
    return manager
@pytest.fixture
def mock_websocket() -> AsyncMock:
    """Provide an async WebSocket stand-in constrained to the ``WebSocket`` API.

    ``send_text`` is explicitly replaced with its own ``AsyncMock`` so tests can
    inspect exactly what the connection manager pushed to the client.
    """
    fake_socket: AsyncMock = AsyncMock(spec=WebSocket)
    fake_socket.send_text = AsyncMock()
    return fake_socket
@pytest.mark.asyncio
async def test_connect(
    connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
    """Connecting registers the socket as active and accepts the handshake.

    ``accept`` is a coroutine on the WebSocket, so the check uses
    ``assert_awaited_once``: ``assert_called_once`` would also pass if the
    manager created the coroutine but forgot to await it.
    """
    await connection_manager.connect(mock_websocket)

    assert mock_websocket in connection_manager.active_connections
    mock_websocket.accept.assert_awaited_once()
def test_disconnect(
    connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
    """Disconnecting drops the socket from active connections and subscriptions."""
    channel_key = "test_graph_1"
    manager = connection_manager
    manager.active_connections.add(mock_websocket)
    manager.subscriptions[channel_key] = {mock_websocket}

    manager.disconnect(mock_websocket)

    assert mock_websocket not in manager.active_connections
    # Indexing directly (not .get) means this also expects the now-empty
    # subscription entry to still exist after disconnect.
    assert mock_websocket not in manager.subscriptions[channel_key]
@pytest.mark.asyncio
async def test_subscribe(
    connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
    """Subscribing files the socket under a ``{user}_{graph}_{version}`` key."""
    expected_key = "user-1_test_graph_1"

    await connection_manager.subscribe(
        websocket=mock_websocket,
        user_id="user-1",
        graph_id="test_graph",
        graph_version=1,
    )

    subscribers = connection_manager.subscriptions[expected_key]
    assert mock_websocket in subscribers
@pytest.mark.asyncio
async def test_unsubscribe(
    connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
    """Unsubscribing the only socket on a channel removes that channel entry.

    The assertion targets the full ``{user}_{graph}_{version}`` key that was
    seeded. Checking a bare graph id such as ``"test_graph"`` would be
    vacuously true, because no such key is ever created.
    """
    channel_key = "user-1_test_graph_1"
    connection_manager.subscriptions[channel_key] = {mock_websocket}

    await connection_manager.unsubscribe(
        user_id="user-1",
        graph_id="test_graph",
        graph_version=1,
        websocket=mock_websocket,
    )

    assert channel_key not in connection_manager.subscriptions
@pytest.mark.asyncio
async def test_send_execution_result(
    connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
    """A result on a subscribed channel is pushed once as a serialized WsMessage.

    ``send_text`` is a coroutine on the WebSocket, so the check uses
    ``assert_awaited_once_with``. ``assert_called_once_with`` would also pass
    if the coroutine were created but never awaited, i.e. nothing was sent.
    """
    channel_key = "user-1_test_graph_1"
    connection_manager.subscriptions[channel_key] = {mock_websocket}
    # A single timestamp keeps add/start/end times mutually consistent.
    now = datetime.now(tz=timezone.utc)
    result: ExecutionResult = ExecutionResult(
        user_id="user-1",
        graph_id="test_graph",
        graph_version=1,
        graph_exec_id="test_exec_id",
        node_exec_id="test_node_exec_id",
        node_id="test_node_id",
        block_id="test_block_id",
        status=ExecutionStatus.COMPLETED,
        input_data={"input1": "value1"},
        output_data={"output1": ["result1"]},
        add_time=now,
        queue_time=None,
        start_time=now,
        end_time=now,
    )

    await connection_manager.send_execution_result(result)

    mock_websocket.send_text.assert_awaited_once_with(
        WsMessage(
            method=Methods.EXECUTION_EVENT,
            channel=channel_key,
            data=result.model_dump(),
        ).model_dump_json()
    )
@pytest.mark.asyncio
async def test_send_execution_result_user_mismatch(
    connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
    """A result owned by another user must not reach this user's subscriber.

    The socket listens on ``user-1``'s channel for the same graph and version,
    but the result belongs to ``user-2``, so nothing may be sent.
    """
    connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}

    foreign_result: ExecutionResult = ExecutionResult(
        user_id="user-2",
        graph_id="test_graph",
        graph_version=1,
        graph_exec_id="test_exec_id",
        node_exec_id="test_node_exec_id",
        node_id="test_node_id",
        block_id="test_block_id",
        status=ExecutionStatus.COMPLETED,
        input_data={"input1": "value1"},
        output_data={"output1": ["result1"]},
        queue_time=None,
        add_time=datetime.now(tz=timezone.utc),
        start_time=datetime.now(tz=timezone.utc),
        end_time=datetime.now(tz=timezone.utc),
    )
    await connection_manager.send_execution_result(foreign_result)

    mock_websocket.send_text.assert_not_called()
@pytest.mark.asyncio
async def test_send_execution_result_no_subscribers(
    connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
    """With no subscriptions registered, sending a result pushes nothing.

    Timestamps are timezone-aware UTC, matching the other execution-result
    tests in this module, rather than naive local times.
    """
    result: ExecutionResult = ExecutionResult(
        user_id="user-1",
        graph_id="test_graph",
        graph_version=1,
        graph_exec_id="test_exec_id",
        node_exec_id="test_node_exec_id",
        node_id="test_node_id",
        block_id="test_block_id",
        status=ExecutionStatus.COMPLETED,
        input_data={"input1": "value1"},
        output_data={"output1": ["result1"]},
        add_time=datetime.now(tz=timezone.utc),
        queue_time=None,
        start_time=datetime.now(tz=timezone.utc),
        end_time=datetime.now(tz=timezone.utc),
    )

    await connection_manager.send_execution_result(result)

    # The fixture socket was never subscribed, so it must not receive anything.
    mock_websocket.send_text.assert_not_called()