Files
AutoGPT/autogpt_platform/backend/backend/api/ws_api_test.py
Reinier van der Leer de78d062a9 refactor(backend/api): Clean up API file structure (#11629)
We will soon need a more feature-complete external API. To make way
for this, I'm moving some files around so that:
- We can more easily create new versions of our external API
- The file structure of our internal API is more homogeneous

These changes are fairly opinionated, but in my opinion they are, in any
case, an improvement over the chaotic structure we have now.

### Changes 🏗️

- Move `backend/server` -> `backend/api`
- Move `backend/server/routers` + `backend/server/v2` ->
`backend/api/features`
  - Change absolute sibling imports to relative imports
- Move `backend/server/v2/AutoMod` -> `backend/executor/automod`
- Combine `backend/server/routers/analytics_*test.py` ->
`backend/api/features/analytics_test.py`
- Sort OpenAPI spec file

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI tests
  - [x] Clicking around in the app -> no obvious breakage
2025-12-20 20:33:10 +00:00

297 lines
9.8 KiB
Python

import json
from typing import cast
from unittest.mock import AsyncMock
import pytest
from fastapi import WebSocket, WebSocketDisconnect
from pytest_snapshot.plugin import Snapshot
from backend.api.conn_manager import ConnectionManager
from backend.api.test_helpers import override_config
from backend.api.ws_api import AppEnvironment, WebsocketServer, WSMessage, WSMethod
from backend.api.ws_api import app as websocket_app
from backend.api.ws_api import (
handle_subscribe,
handle_unsubscribe,
settings,
websocket_router,
)
from backend.data.user import DEFAULT_USER_ID
@pytest.fixture
def mock_websocket() -> AsyncMock:
    """Provide an ``AsyncMock`` standing in for a FastAPI ``WebSocket``.

    The mock is spec'd against ``WebSocket``. ``query_params`` is given an
    empty dict so that authentication code reading it finds the attribute.
    """
    ws = AsyncMock(spec=WebSocket)
    ws.query_params = {}
    return ws
@pytest.fixture
def mock_manager() -> AsyncMock:
    """Provide an ``AsyncMock`` restricted to the ``ConnectionManager`` interface."""
    manager = AsyncMock(spec=ConnectionManager)
    return manager
def test_websocket_server_uses_cors_helper(mocker) -> None:
    """``WebsocketServer.run`` must derive its CORS setup from ``build_cors_params``.

    ``uvicorn.run`` is patched out so no server starts. The test then checks
    two things:

    - the helper received the configured origins and app environment;
    - ``CORSMiddleware`` was constructed from the helper's returned values.
    """
    origins = ["https://app.example.com"]
    cors_params = {"allow_origins": origins, "allow_origin_regex": None}

    mocker.patch("backend.api.ws_api.uvicorn.run")
    cors_middleware = mocker.patch(
        "backend.api.ws_api.CORSMiddleware", return_value=object()
    )
    build_cors = mocker.patch(
        "backend.api.ws_api.build_cors_params", return_value=cors_params
    )

    with override_config(settings, "backend_cors_allow_origins", origins):
        with override_config(settings, "app_env", AppEnvironment.LOCAL):
            WebsocketServer().run()

    build_cors.assert_called_once_with(origins, AppEnvironment.LOCAL)

    expected_middleware_kwargs = dict(
        app=websocket_app,
        allow_origins=cors_params["allow_origins"],
        allow_origin_regex=cors_params["allow_origin_regex"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    cors_middleware.assert_called_once_with(**expected_middleware_kwargs)
def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
    """Configuring a localhost CORS origin in production makes ``run`` raise ``ValueError``."""
    mocker.patch("backend.api.ws_api.uvicorn.run")
    localhost_origins = ["http://localhost:3000"]

    with override_config(settings, "backend_cors_allow_origins", localhost_origins):
        with override_config(settings, "app_env", AppEnvironment.PRODUCTION):
            with pytest.raises(ValueError):
                WebsocketServer().run()
@pytest.mark.asyncio
async def test_websocket_router_subscribe(
    mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
) -> None:
    """A SUBSCRIBE_GRAPH_EXEC message is routed to ``subscribe_graph_exec``.

    The socket yields one subscribe message and then raises
    ``WebSocketDisconnect``. The test verifies three things:

    - the manager's connect, subscribe and disconnect calls;
    - that exactly one success reply was sent;
    - a snapshot of that reply.
    """
    exec_id = "test-graph-exec-1"

    # Bypass real authentication so the router sees DEFAULT_USER_ID.
    mocker.patch(
        "backend.api.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
    )
    subscribe_msg = WSMessage(
        method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
        data={"graph_exec_id": exec_id},
    )
    # First receive returns the message; the second raises the exception,
    # simulating the client disconnecting.
    mock_websocket.receive_text.side_effect = [
        subscribe_msg.model_dump_json(),
        WebSocketDisconnect(),
    ]
    mock_manager.subscribe_graph_exec.return_value = (
        f"{DEFAULT_USER_ID}|graph_exec#test-graph-exec-1"
    )

    await websocket_router(
        cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
    )

    mock_manager.connect_socket.assert_called_once_with(
        mock_websocket, user_id=DEFAULT_USER_ID
    )
    mock_manager.subscribe_graph_exec.assert_called_once_with(
        user_id=DEFAULT_USER_ID,
        graph_exec_id=exec_id,
        websocket=mock_websocket,
    )

    mock_websocket.send_text.assert_called_once()
    reply = mock_websocket.send_text.call_args[0][0]
    assert '"method":"subscribe_graph_execution"' in reply
    assert '"success":true' in reply

    # Snapshot a normalized (indented, key-sorted) form of the reply.
    snapshot.snapshot_dir = "snapshots"
    normalized = json.dumps(json.loads(reply), indent=2, sort_keys=True)
    snapshot.assert_match(normalized, "sub")

    mock_manager.disconnect_socket.assert_called_once_with(
        mock_websocket, user_id=DEFAULT_USER_ID
    )
@pytest.mark.asyncio
async def test_websocket_router_unsubscribe(
    mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
) -> None:
    """An UNSUBSCRIBE message is routed to ``unsubscribe_graph_exec``.

    The socket yields one unsubscribe message and then raises
    ``WebSocketDisconnect``. The test verifies three things:

    - the manager's connect, unsubscribe and disconnect calls;
    - that exactly one success reply was sent;
    - a snapshot of that reply.
    """
    exec_id = "test-graph-exec-1"

    # Bypass real authentication so the router sees DEFAULT_USER_ID.
    mocker.patch(
        "backend.api.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
    )
    unsubscribe_msg = WSMessage(
        method=WSMethod.UNSUBSCRIBE,
        data={"graph_exec_id": exec_id},
    )
    # First receive returns the message; the second raises the exception,
    # simulating the client disconnecting.
    mock_websocket.receive_text.side_effect = [
        unsubscribe_msg.model_dump_json(),
        WebSocketDisconnect(),
    ]
    mock_manager.unsubscribe_graph_exec.return_value = (
        f"{DEFAULT_USER_ID}|graph_exec#test-graph-exec-1"
    )

    await websocket_router(
        cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
    )

    mock_manager.connect_socket.assert_called_once_with(
        mock_websocket, user_id=DEFAULT_USER_ID
    )
    mock_manager.unsubscribe_graph_exec.assert_called_once_with(
        user_id=DEFAULT_USER_ID,
        graph_exec_id=exec_id,
        websocket=mock_websocket,
    )

    mock_websocket.send_text.assert_called_once()
    reply = mock_websocket.send_text.call_args[0][0]
    assert '"method":"unsubscribe"' in reply
    assert '"success":true' in reply

    # Snapshot a normalized (indented, key-sorted) form of the reply.
    snapshot.snapshot_dir = "snapshots"
    normalized = json.dumps(json.loads(reply), indent=2, sort_keys=True)
    snapshot.assert_match(normalized, "unsub")

    mock_manager.disconnect_socket.assert_called_once_with(
        mock_websocket, user_id=DEFAULT_USER_ID
    )
@pytest.mark.asyncio
async def test_websocket_router_invalid_method(
    mock_websocket: AsyncMock, mock_manager: AsyncMock, mocker
) -> None:
    """A client message using GRAPH_EXECUTION_EVENT gets an error reply.

    GRAPH_EXECUTION_EVENT is presumably a server-to-client method and is not
    accepted as a request. The router must send one error response and
    still connect and disconnect the socket exactly once.
    """
    # Bypass real authentication so the router sees DEFAULT_USER_ID.
    mocker.patch(
        "backend.api.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
    )
    bad_request = WSMessage(method=WSMethod.GRAPH_EXECUTION_EVENT)
    mock_websocket.receive_text.side_effect = [
        bad_request.model_dump_json(),
        WebSocketDisconnect(),
    ]

    await websocket_router(
        cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
    )

    mock_manager.connect_socket.assert_called_once_with(
        mock_websocket, user_id=DEFAULT_USER_ID
    )

    mock_websocket.send_text.assert_called_once()
    reply = mock_websocket.send_text.call_args[0][0]
    assert '"method":"error"' in reply
    assert '"success":false' in reply

    mock_manager.disconnect_socket.assert_called_once_with(
        mock_websocket, user_id=DEFAULT_USER_ID
    )
@pytest.mark.asyncio
async def test_handle_subscribe_success(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """``handle_subscribe`` forwards a well-formed request and acknowledges it.

    Exactly one success reply tagged ``subscribe_graph_execution`` must be sent.
    """
    user = "user-1"
    exec_id = "test-graph-exec-id"
    request = WSMessage(
        method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
        data={"graph_exec_id": exec_id},
    )
    mock_manager.subscribe_graph_exec.return_value = (
        "user-1|graph_exec#test-graph-exec-id"
    )

    await handle_subscribe(
        connection_manager=cast(ConnectionManager, mock_manager),
        websocket=cast(WebSocket, mock_websocket),
        user_id=user,
        message=request,
    )

    mock_manager.subscribe_graph_exec.assert_called_once_with(
        user_id=user,
        graph_exec_id=exec_id,
        websocket=mock_websocket,
    )

    mock_websocket.send_text.assert_called_once()
    reply = mock_websocket.send_text.call_args[0][0]
    assert '"method":"subscribe_graph_execution"' in reply
    assert '"success":true' in reply
@pytest.mark.asyncio
async def test_handle_subscribe_missing_data(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """A SUBSCRIBE_GRAPH_EXEC message with no ``data`` is rejected.

    The manager must not be asked to subscribe, and one error reply is sent.
    """
    await handle_subscribe(
        connection_manager=cast(ConnectionManager, mock_manager),
        websocket=cast(WebSocket, mock_websocket),
        user_id="user-1",
        message=WSMessage(method=WSMethod.SUBSCRIBE_GRAPH_EXEC),
    )

    mock_manager.subscribe_graph_exec.assert_not_called()

    mock_websocket.send_text.assert_called_once()
    reply = mock_websocket.send_text.call_args[0][0]
    assert '"method":"error"' in reply
    assert '"success":false' in reply
@pytest.mark.asyncio
async def test_handle_unsubscribe_success(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """``handle_unsubscribe`` forwards a well-formed request and acknowledges it.

    Exactly one success reply tagged ``unsubscribe`` must be sent.
    """
    user = "user-1"
    exec_id = "test-graph-exec-id"
    request = WSMessage(
        method=WSMethod.UNSUBSCRIBE,
        data={"graph_exec_id": exec_id},
    )
    mock_manager.unsubscribe_graph_exec.return_value = (
        "user-1|graph_exec#test-graph-exec-id"
    )

    await handle_unsubscribe(
        connection_manager=cast(ConnectionManager, mock_manager),
        websocket=cast(WebSocket, mock_websocket),
        user_id=user,
        message=request,
    )

    mock_manager.unsubscribe_graph_exec.assert_called_once_with(
        user_id=user,
        graph_exec_id=exec_id,
        websocket=mock_websocket,
    )

    mock_websocket.send_text.assert_called_once()
    reply = mock_websocket.send_text.call_args[0][0]
    assert '"method":"unsubscribe"' in reply
    assert '"success":true' in reply
@pytest.mark.asyncio
async def test_handle_unsubscribe_missing_data(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """An UNSUBSCRIBE message with no ``data`` is rejected.

    The manager's ``unsubscribe_graph_exec`` must not be invoked, and exactly
    one error reply (``"success":false``) must be sent back.
    """
    message = WSMessage(method=WSMethod.UNSUBSCRIBE)

    await handle_unsubscribe(
        connection_manager=cast(ConnectionManager, mock_manager),
        websocket=cast(WebSocket, mock_websocket),
        user_id="user-1",
        message=message,
    )

    # Assert on the public method that handle_unsubscribe calls for a valid
    # request, mirroring test_handle_subscribe_missing_data. The previous
    # check on the private ``_unsubscribe`` would not catch a regression
    # where ``unsubscribe_graph_exec`` is called for a message with no data.
    mock_manager.unsubscribe_graph_exec.assert_not_called()

    mock_websocket.send_text.assert_called_once()
    sent = mock_websocket.send_text.call_args[0][0]
    assert '"method":"error"' in sent
    assert '"success":false' in sent