mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-17 10:12:02 -05:00
We'll soon need a more feature-complete external API. To make way for this, I'm moving some files around so that:
- We can more easily create new versions of our external API
- The file structure of our internal API is more homogeneous

These changes are quite opinionated, but IMO they're better in any case than the chaotic structure we have now.

### Changes 🏗️
- Move `backend/server` -> `backend/api`
- Move `backend/server/routers` + `backend/server/v2` -> `backend/api/features`
- Change absolute sibling imports to relative imports
- Move `backend/server/v2/AutoMod` -> `backend/executor/automod`
- Combine `backend/server/routers/analytics_*test.py` -> `backend/api/features/analytics_test.py`
- Sort OpenAPI spec file

### Checklist 📋
#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI tests
  - [x] Clicking around in the app -> no obvious breakage
297 lines
9.8 KiB
Python
297 lines
9.8 KiB
Python
import json
|
|
from typing import cast
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
from pytest_snapshot.plugin import Snapshot
|
|
|
|
from backend.api.conn_manager import ConnectionManager
|
|
from backend.api.test_helpers import override_config
|
|
from backend.api.ws_api import AppEnvironment, WebsocketServer, WSMessage, WSMethod
|
|
from backend.api.ws_api import app as websocket_app
|
|
from backend.api.ws_api import (
|
|
handle_subscribe,
|
|
handle_unsubscribe,
|
|
settings,
|
|
websocket_router,
|
|
)
|
|
from backend.data.user import DEFAULT_USER_ID
|
|
|
|
|
|
@pytest.fixture
def mock_websocket() -> AsyncMock:
    """Provide an async WebSocket double restricted to the real ``WebSocket`` API."""
    ws_double = AsyncMock(spec=WebSocket)
    # Authentication reads query params; an empty mapping stands in for a
    # connection that supplied none.
    ws_double.query_params = {}
    return ws_double
|
|
|
|
|
|
@pytest.fixture
def mock_manager() -> AsyncMock:
    """Provide an async ``ConnectionManager`` double.

    Because the mock is spec'd, touching an attribute the real class lacks
    raises ``AttributeError`` instead of silently creating a child mock.
    """
    manager_double = AsyncMock(spec=ConnectionManager)
    return manager_double
|
|
|
|
|
|
def test_websocket_server_uses_cors_helper(mocker) -> None:
    """``WebsocketServer.run`` builds its CORS middleware from ``build_cors_params``.

    The helper's output must be passed straight through to ``CORSMiddleware``,
    together with the fixed credentials/methods/headers settings.
    """
    origins = ["https://app.example.com"]
    cors_params = {"allow_origins": origins, "allow_origin_regex": None}

    # Patch uvicorn so run() does not start a real server.
    mocker.patch("backend.api.ws_api.uvicorn.run")
    cors_middleware = mocker.patch(
        "backend.api.ws_api.CORSMiddleware", return_value=object()
    )
    build_cors = mocker.patch(
        "backend.api.ws_api.build_cors_params", return_value=cors_params
    )

    with override_config(settings, "backend_cors_allow_origins", origins):
        with override_config(settings, "app_env", AppEnvironment.LOCAL):
            WebsocketServer().run()

    # The helper receives the configured origins and the active environment...
    build_cors.assert_called_once_with(origins, AppEnvironment.LOCAL)
    # ...and its result is what the middleware is built from.
    cors_middleware.assert_called_once_with(
        app=websocket_app,
        allow_origins=origins,
        allow_origin_regex=cors_params["allow_origin_regex"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
|
|
|
|
|
|
def test_websocket_server_blocks_localhost_in_production(mocker) -> None:
    """A localhost CORS origin is rejected with ``ValueError`` in production."""
    mocker.patch("backend.api.ws_api.uvicorn.run")

    localhost_origins = ["http://localhost:3000"]
    with override_config(settings, "backend_cors_allow_origins", localhost_origins):
        with override_config(settings, "app_env", AppEnvironment.PRODUCTION):
            with pytest.raises(ValueError):
                WebsocketServer().run()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_websocket_router_subscribe(
    mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
) -> None:
    """A subscribe request through the router is forwarded, acked and snapshotted.

    The socket is connected for the authenticated user, the subscription is
    delegated to the manager, a success reply is sent, and the socket is
    disconnected once the client goes away.
    """
    # Bypass real authentication: every connection belongs to DEFAULT_USER_ID.
    mocker.patch(
        "backend.api.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
    )

    exec_id = "test-graph-exec-1"
    subscribe_request = WSMessage(
        method=WSMethod.SUBSCRIBE_GRAPH_EXEC,
        data={"graph_exec_id": exec_id},
    ).model_dump_json()
    # One subscribe request, then a disconnect so the router exits.
    mock_websocket.receive_text.side_effect = [
        subscribe_request,
        WebSocketDisconnect(),
    ]
    mock_manager.subscribe_graph_exec.return_value = (
        f"{DEFAULT_USER_ID}|graph_exec#{exec_id}"
    )

    await websocket_router(
        cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
    )

    mock_manager.connect_socket.assert_called_once_with(
        mock_websocket, user_id=DEFAULT_USER_ID
    )
    mock_manager.subscribe_graph_exec.assert_called_once_with(
        user_id=DEFAULT_USER_ID, graph_exec_id=exec_id, websocket=mock_websocket
    )
    mock_websocket.send_text.assert_called_once()
    reply = mock_websocket.send_text.call_args.args[0]
    assert '"method":"subscribe_graph_execution"' in reply
    assert '"success":true' in reply

    # Re-serialise with sorted keys so the stored snapshot is deterministic.
    snapshot.snapshot_dir = "snapshots"
    normalized_reply = json.dumps(json.loads(reply), indent=2, sort_keys=True)
    snapshot.assert_match(normalized_reply, "sub")

    mock_manager.disconnect_socket.assert_called_once_with(
        mock_websocket, user_id=DEFAULT_USER_ID
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_websocket_router_unsubscribe(
    mock_websocket: AsyncMock, mock_manager: AsyncMock, snapshot: Snapshot, mocker
) -> None:
    """An unsubscribe request through the router is forwarded, acked and snapshotted.

    The socket is connected for the authenticated user, the unsubscription is
    delegated to the manager, a success reply is sent, and the socket is
    disconnected once the client goes away.
    """
    # Bypass real authentication: every connection belongs to DEFAULT_USER_ID.
    mocker.patch(
        "backend.api.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
    )

    exec_id = "test-graph-exec-1"
    unsubscribe_request = WSMessage(
        method=WSMethod.UNSUBSCRIBE,
        data={"graph_exec_id": exec_id},
    ).model_dump_json()
    # One unsubscribe request, then a disconnect so the router exits.
    mock_websocket.receive_text.side_effect = [
        unsubscribe_request,
        WebSocketDisconnect(),
    ]
    mock_manager.unsubscribe_graph_exec.return_value = (
        f"{DEFAULT_USER_ID}|graph_exec#{exec_id}"
    )

    await websocket_router(
        cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
    )

    mock_manager.connect_socket.assert_called_once_with(
        mock_websocket, user_id=DEFAULT_USER_ID
    )
    mock_manager.unsubscribe_graph_exec.assert_called_once_with(
        user_id=DEFAULT_USER_ID, graph_exec_id=exec_id, websocket=mock_websocket
    )
    mock_websocket.send_text.assert_called_once()
    reply = mock_websocket.send_text.call_args.args[0]
    assert '"method":"unsubscribe"' in reply
    assert '"success":true' in reply

    # Re-serialise with sorted keys so the stored snapshot is deterministic.
    snapshot.snapshot_dir = "snapshots"
    normalized_reply = json.dumps(json.loads(reply), indent=2, sort_keys=True)
    snapshot.assert_match(normalized_reply, "unsub")

    mock_manager.disconnect_socket.assert_called_once_with(
        mock_websocket, user_id=DEFAULT_USER_ID
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_websocket_router_invalid_method(
    mock_websocket: AsyncMock, mock_manager: AsyncMock, mocker
) -> None:
    """A message with a method clients may not send yields an error reply.

    ``GRAPH_EXECUTION_EVENT`` is sent to the router as if a client had sent it.
    The router must answer with an error, must not touch any subscription
    state, and must still connect/disconnect the socket normally.
    """
    # Bypass real authentication: every connection belongs to DEFAULT_USER_ID.
    mocker.patch(
        "backend.api.ws_api.authenticate_websocket", return_value=DEFAULT_USER_ID
    )

    # One invalid request, then a disconnect so the router exits.
    mock_websocket.receive_text.side_effect = [
        WSMessage(method=WSMethod.GRAPH_EXECUTION_EVENT).model_dump_json(),
        WebSocketDisconnect(),
    ]

    await websocket_router(
        cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
    )

    mock_manager.connect_socket.assert_called_once_with(
        mock_websocket, user_id=DEFAULT_USER_ID
    )
    # An invalid method must not be treated as a (un)subscribe request. Without
    # these checks, a router that subscribed AND sent an error would still pass.
    mock_manager.subscribe_graph_exec.assert_not_called()
    mock_manager.unsubscribe_graph_exec.assert_not_called()
    mock_websocket.send_text.assert_called_once()
    assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
    assert '"success":false' in mock_websocket.send_text.call_args[0][0]
    mock_manager.disconnect_socket.assert_called_once_with(
        mock_websocket, user_id=DEFAULT_USER_ID
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_subscribe_success(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """``handle_subscribe`` delegates to the manager and acks with success."""
    user_id = "user-1"
    exec_id = "test-graph-exec-id"
    request = WSMessage(
        method=WSMethod.SUBSCRIBE_GRAPH_EXEC, data={"graph_exec_id": exec_id}
    )
    mock_manager.subscribe_graph_exec.return_value = f"{user_id}|graph_exec#{exec_id}"

    await handle_subscribe(
        connection_manager=cast(ConnectionManager, mock_manager),
        websocket=cast(WebSocket, mock_websocket),
        user_id=user_id,
        message=request,
    )

    mock_manager.subscribe_graph_exec.assert_called_once_with(
        user_id=user_id, graph_exec_id=exec_id, websocket=mock_websocket
    )
    mock_websocket.send_text.assert_called_once()
    reply = mock_websocket.send_text.call_args.args[0]
    assert '"method":"subscribe_graph_execution"' in reply
    assert '"success":true' in reply
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_subscribe_missing_data(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """A subscribe request with no payload is rejected without touching the manager."""
    payloadless_request = WSMessage(method=WSMethod.SUBSCRIBE_GRAPH_EXEC)

    await handle_subscribe(
        connection_manager=cast(ConnectionManager, mock_manager),
        websocket=cast(WebSocket, mock_websocket),
        user_id="user-1",
        message=payloadless_request,
    )

    mock_manager.subscribe_graph_exec.assert_not_called()
    mock_websocket.send_text.assert_called_once()
    reply = mock_websocket.send_text.call_args.args[0]
    assert '"method":"error"' in reply
    assert '"success":false' in reply
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_unsubscribe_success(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """``handle_unsubscribe`` delegates to the manager and acks with success."""
    user_id = "user-1"
    exec_id = "test-graph-exec-id"
    request = WSMessage(method=WSMethod.UNSUBSCRIBE, data={"graph_exec_id": exec_id})
    mock_manager.unsubscribe_graph_exec.return_value = (
        f"{user_id}|graph_exec#{exec_id}"
    )

    await handle_unsubscribe(
        connection_manager=cast(ConnectionManager, mock_manager),
        websocket=cast(WebSocket, mock_websocket),
        user_id=user_id,
        message=request,
    )

    mock_manager.unsubscribe_graph_exec.assert_called_once_with(
        user_id=user_id, graph_exec_id=exec_id, websocket=mock_websocket
    )
    mock_websocket.send_text.assert_called_once()
    reply = mock_websocket.send_text.call_args.args[0]
    assert '"method":"unsubscribe"' in reply
    assert '"success":true' in reply
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_unsubscribe_missing_data(
    mock_websocket: AsyncMock, mock_manager: AsyncMock
) -> None:
    """An unsubscribe request with no payload is rejected without touching the manager."""
    message = WSMessage(method=WSMethod.UNSUBSCRIBE)

    await handle_unsubscribe(
        connection_manager=cast(ConnectionManager, mock_manager),
        websocket=cast(WebSocket, mock_websocket),
        user_id="user-1",
        message=message,
    )

    # Assert on the manager method the handler actually calls on success.
    # A private helper such as `_unsubscribe` is never invoked on a spec'd
    # mock, so asserting on it could never fail.
    mock_manager.unsubscribe_graph_exec.assert_not_called()
    mock_websocket.send_text.assert_called_once()
    assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
    assert '"success":false' in mock_websocket.send_text.call_args[0][0]
|