Files
AutoGPT/autogpt_platform/backend/test/server/test_ws_api.py
Reinier van der Leer 9a661b5101 fix(backend/ws): Add user_id to websocket event subscription key (#9660)
- Add `user_id` to WS subscription key
- Add error catching to WS message handler
2025-03-20 17:54:04 +01:00

194 lines
6.1 KiB
Python

import json
from typing import cast
from unittest.mock import AsyncMock

import pytest
from fastapi import WebSocket, WebSocketDisconnect

from backend.data.user import DEFAULT_USER_ID
from backend.server.conn_manager import ConnectionManager
from backend.server.ws_api import (
    Methods,
    WsMessage,
    handle_subscribe,
    handle_unsubscribe,
    websocket_router,
)
@pytest.fixture
def mock_websocket() -> AsyncMock:
    """Supply a stand-in ``WebSocket`` for the handlers under test.

    ``spec=WebSocket`` limits the double to attributes that exist on the real
    class. Coroutine methods such as ``send_text``/``receive_text`` become
    ``AsyncMock`` children, so they can be awaited and inspected.
    """
    socket_double = AsyncMock(spec=WebSocket)
    return socket_double
@pytest.fixture
def mock_manager() -> AsyncMock:
    """Supply a stand-in ``ConnectionManager`` that records every call.

    ``spec=ConnectionManager`` makes async methods ``AsyncMock`` children and
    sync methods ``MagicMock`` children, mirroring the real class's shape.
    """
    manager_double = AsyncMock(spec=ConnectionManager)
    return manager_double
@pytest.mark.asyncio
async def test_websocket_router_subscribe(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """Check that a SUBSCRIBE message sent through the router subscribes the socket.

    The socket yields one SUBSCRIBE message and then raises
    ``WebSocketDisconnect``. The router should connect, subscribe, send exactly
    one successful acknowledgement, and disconnect.
    """
    mock_websocket.receive_text.side_effect = [
        WsMessage(
            method=Methods.SUBSCRIBE,
            data={"graph_id": "test_graph", "graph_version": 1},
        ).model_dump_json(),
        # The second receive raises; this is expected to end the router's
        # receive loop (hence the disconnect assertion at the end).
        WebSocketDisconnect(),
    ]

    await websocket_router(
        cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
    )

    mock_manager.connect.assert_called_once_with(mock_websocket)
    # No user is passed to the router, so it is expected to resolve
    # DEFAULT_USER_ID (presumably auth is disabled in tests - confirm).
    mock_manager.subscribe.assert_called_once_with(
        user_id=DEFAULT_USER_ID,
        graph_id="test_graph",
        graph_version=1,
        websocket=mock_websocket,
    )
    # send_text is a coroutine method: require that the reply was awaited,
    # not merely that a coroutine object was created.
    mock_websocket.send_text.assert_awaited_once()
    # Parse the reply instead of substring-matching, so the check does not
    # depend on the serializer's whitespace/separator choices.
    reply = json.loads(mock_websocket.send_text.call_args.args[0])
    assert reply["method"] == "subscribe"
    assert reply["success"] is True
    mock_manager.disconnect.assert_called_once_with(mock_websocket)
@pytest.mark.asyncio
async def test_websocket_router_unsubscribe(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """Check that an UNSUBSCRIBE message sent through the router unsubscribes the socket.

    The socket yields one UNSUBSCRIBE message and then raises
    ``WebSocketDisconnect``. The router should connect, unsubscribe, send
    exactly one successful acknowledgement, and disconnect.
    """
    mock_websocket.receive_text.side_effect = [
        WsMessage(
            method=Methods.UNSUBSCRIBE,
            data={"graph_id": "test_graph", "graph_version": 1},
        ).model_dump_json(),
        # The second receive raises; this is expected to end the router's
        # receive loop (hence the disconnect assertion at the end).
        WebSocketDisconnect(),
    ]

    await websocket_router(
        cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
    )

    mock_manager.connect.assert_called_once_with(mock_websocket)
    # No user is passed to the router, so it is expected to resolve
    # DEFAULT_USER_ID (presumably auth is disabled in tests - confirm).
    mock_manager.unsubscribe.assert_called_once_with(
        user_id=DEFAULT_USER_ID,
        graph_id="test_graph",
        graph_version=1,
        websocket=mock_websocket,
    )
    # send_text is a coroutine method: require that the reply was awaited.
    mock_websocket.send_text.assert_awaited_once()
    # Parse the reply rather than depending on exact JSON formatting.
    reply = json.loads(mock_websocket.send_text.call_args.args[0])
    assert reply["method"] == "unsubscribe"
    assert reply["success"] is True
    mock_manager.disconnect.assert_called_once_with(mock_websocket)
@pytest.mark.asyncio
async def test_websocket_router_invalid_method(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """Check that a client-sent EXECUTION_EVENT is rejected with an error reply.

    The router should answer with a single unsuccessful ``error`` message. It
    should neither subscribe nor unsubscribe, and it should still disconnect
    cleanly.
    """
    mock_websocket.receive_text.side_effect = [
        WsMessage(method=Methods.EXECUTION_EVENT).model_dump_json(),
        # The second receive raises; this is expected to end the router's
        # receive loop (hence the disconnect assertion at the end).
        WebSocketDisconnect(),
    ]

    await websocket_router(
        cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
    )

    mock_manager.connect.assert_called_once_with(mock_websocket)
    # An invalid method must not change any subscriptions.
    mock_manager.subscribe.assert_not_called()
    mock_manager.unsubscribe.assert_not_called()
    # send_text is a coroutine method: require that the reply was awaited.
    mock_websocket.send_text.assert_awaited_once()
    # Parse the reply rather than depending on exact JSON formatting.
    reply = json.loads(mock_websocket.send_text.call_args.args[0])
    assert reply["method"] == "error"
    assert reply["success"] is False
    mock_manager.disconnect.assert_called_once_with(mock_websocket)
@pytest.mark.asyncio
async def test_handle_subscribe_success(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """Check that ``handle_subscribe`` forwards the explicit user_id and graph info.

    It should also acknowledge the subscription with one successful reply.
    """
    message = WsMessage(
        method=Methods.SUBSCRIBE, data={"graph_id": "test_graph", "graph_version": 1}
    )

    await handle_subscribe(
        connection_manager=cast(ConnectionManager, mock_manager),
        websocket=cast(WebSocket, mock_websocket),
        user_id="user-1",
        message=message,
    )

    mock_manager.subscribe.assert_called_once_with(
        user_id="user-1",
        graph_id="test_graph",
        graph_version=1,
        websocket=mock_websocket,
    )
    # send_text is a coroutine method: require that the reply was awaited.
    mock_websocket.send_text.assert_awaited_once()
    # Parse the reply rather than depending on exact JSON formatting.
    reply = json.loads(mock_websocket.send_text.call_args.args[0])
    assert reply["method"] == "subscribe"
    assert reply["success"] is True
@pytest.mark.asyncio
async def test_handle_subscribe_missing_data(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """Check that a SUBSCRIBE message with no data is rejected.

    ``handle_subscribe`` must not reach the manager and should send one
    unsuccessful ``error`` reply.
    """
    message = WsMessage(method=Methods.SUBSCRIBE)

    await handle_subscribe(
        connection_manager=cast(ConnectionManager, mock_manager),
        websocket=cast(WebSocket, mock_websocket),
        user_id="user-1",
        message=message,
    )

    mock_manager.subscribe.assert_not_called()
    # send_text is a coroutine method: require that the reply was awaited.
    mock_websocket.send_text.assert_awaited_once()
    # Parse the reply rather than depending on exact JSON formatting.
    reply = json.loads(mock_websocket.send_text.call_args.args[0])
    assert reply["method"] == "error"
    assert reply["success"] is False
@pytest.mark.asyncio
async def test_handle_unsubscribe_success(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """Check that ``handle_unsubscribe`` forwards the explicit user_id and graph info.

    It should also acknowledge the unsubscription with one successful reply.
    """
    message = WsMessage(
        method=Methods.UNSUBSCRIBE, data={"graph_id": "test_graph", "graph_version": 1}
    )

    await handle_unsubscribe(
        connection_manager=cast(ConnectionManager, mock_manager),
        websocket=cast(WebSocket, mock_websocket),
        user_id="user-1",
        message=message,
    )

    mock_manager.unsubscribe.assert_called_once_with(
        user_id="user-1",
        graph_id="test_graph",
        graph_version=1,
        websocket=mock_websocket,
    )
    # send_text is a coroutine method: require that the reply was awaited.
    mock_websocket.send_text.assert_awaited_once()
    # Parse the reply rather than depending on exact JSON formatting.
    reply = json.loads(mock_websocket.send_text.call_args.args[0])
    assert reply["method"] == "unsubscribe"
    assert reply["success"] is True
@pytest.mark.asyncio
async def test_handle_unsubscribe_missing_data(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """Check that an UNSUBSCRIBE message with no data is rejected.

    ``handle_unsubscribe`` must not reach the manager and should send one
    unsuccessful ``error`` reply.
    """
    message = WsMessage(method=Methods.UNSUBSCRIBE)

    await handle_unsubscribe(
        connection_manager=cast(ConnectionManager, mock_manager),
        websocket=cast(WebSocket, mock_websocket),
        user_id="user-1",
        message=message,
    )

    mock_manager.unsubscribe.assert_not_called()
    # send_text is a coroutine method: require that the reply was awaited.
    mock_websocket.send_text.assert_awaited_once()
    # Parse the reply rather than depending on exact JSON formatting.
    reply = json.loads(mock_websocket.send_text.call_args.args[0])
    assert reply["method"] == "error"
    assert reply["success"] is False