mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-10 23:05:17 -05:00
fix(backend/mcp): Use httpx.AsyncClient in test_routes to prevent event loop corruption
Replace fastapi.testclient.TestClient with httpx.AsyncClient + ASGITransport. TestClient creates a new anyio blocking portal per request. When 11+ portals are created and destroyed in a session that also has pytest-asyncio session-scoped async fixtures, the session event loop gets corrupted, causing "RuntimeError: Event loop is closed" in subsequent async tests. AsyncClient with ASGITransport runs the ASGI app directly in the current event loop without creating blocking portals.
This commit is contained in:
@@ -1,9 +1,15 @@
|
||||
"""Tests for MCP API routes."""
|
||||
"""Tests for MCP API routes.
|
||||
|
||||
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
|
||||
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import fastapi
|
||||
import fastapi.testclient
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from autogpt_libs.auth import get_user_id
|
||||
|
||||
from backend.api.features.mcp.routes import router
|
||||
@@ -13,11 +19,18 @@ from backend.util.request import HTTPClientError
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(router)
|
||||
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
|
||||
client = fastapi.testclient.TestClient(app)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def client():
|
||||
transport = httpx.ASGITransport(app=app)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
|
||||
yield c
|
||||
|
||||
|
||||
class TestDiscoverTools:
|
||||
def test_discover_tools_success(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_success(self, client):
|
||||
mock_tools = [
|
||||
MCPTool(
|
||||
name="get_weather",
|
||||
@@ -51,7 +64,7 @@ class TestDiscoverTools:
|
||||
)
|
||||
instance.list_tools = AsyncMock(return_value=mock_tools)
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
@@ -64,7 +77,8 @@ class TestDiscoverTools:
|
||||
assert data["server_name"] == "test-server"
|
||||
assert data["protocol_version"] == "2025-03-26"
|
||||
|
||||
def test_discover_tools_with_auth_token(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_with_auth_token(self, client):
|
||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
@@ -72,7 +86,7 @@ class TestDiscoverTools:
|
||||
)
|
||||
instance.list_tools = AsyncMock(return_value=[])
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
@@ -86,7 +100,8 @@ class TestDiscoverTools:
|
||||
auth_token="my-secret-token",
|
||||
)
|
||||
|
||||
def test_discover_tools_auto_uses_stored_credential(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_auto_uses_stored_credential(self, client):
|
||||
"""When no explicit token is given, stored MCP credentials are used."""
|
||||
from pydantic import SecretStr
|
||||
|
||||
@@ -115,7 +130,7 @@ class TestDiscoverTools:
|
||||
)
|
||||
instance.list_tools = AsyncMock(return_value=[])
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
@@ -126,14 +141,15 @@ class TestDiscoverTools:
|
||||
auth_token="stored-token-123",
|
||||
)
|
||||
|
||||
def test_discover_tools_mcp_error(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_mcp_error(self, client):
|
||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=MCPClientError("Connection refused")
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://bad-server.example.com/mcp"},
|
||||
)
|
||||
@@ -141,12 +157,13 @@ class TestDiscoverTools:
|
||||
assert response.status_code == 502
|
||||
assert "Connection refused" in response.json()["detail"]
|
||||
|
||||
def test_discover_tools_generic_error(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_generic_error(self, client):
|
||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://timeout.example.com/mcp"},
|
||||
)
|
||||
@@ -154,14 +171,15 @@ class TestDiscoverTools:
|
||||
assert response.status_code == 502
|
||||
assert "Failed to connect" in response.json()["detail"]
|
||||
|
||||
def test_discover_tools_auth_required(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_auth_required(self, client):
|
||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://auth-server.example.com/mcp"},
|
||||
)
|
||||
@@ -169,14 +187,15 @@ class TestDiscoverTools:
|
||||
assert response.status_code == 401
|
||||
assert "requires authentication" in response.json()["detail"]
|
||||
|
||||
def test_discover_tools_forbidden(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_forbidden(self, client):
|
||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||
instance = MockClient.return_value
|
||||
instance.initialize = AsyncMock(
|
||||
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/discover-tools",
|
||||
json={"server_url": "https://auth-server.example.com/mcp"},
|
||||
)
|
||||
@@ -184,13 +203,15 @@ class TestDiscoverTools:
|
||||
assert response.status_code == 401
|
||||
assert "requires authentication" in response.json()["detail"]
|
||||
|
||||
def test_discover_tools_missing_url(self):
|
||||
response = client.post("/discover-tools", json={})
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_discover_tools_missing_url(self, client):
|
||||
response = await client.post("/discover-tools", json={})
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestOAuthLogin:
|
||||
def test_oauth_login_success(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_success(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
@@ -223,7 +244,7 @@ class TestOAuthLogin:
|
||||
)
|
||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/oauth/login",
|
||||
json={"server_url": "https://mcp.sentry.dev/mcp"},
|
||||
)
|
||||
@@ -235,13 +256,14 @@ class TestOAuthLogin:
|
||||
assert "auth.sentry.io/authorize" in data["login_url"]
|
||||
assert "registered-client-id" in data["login_url"]
|
||||
|
||||
def test_oauth_login_no_oauth_support(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_no_oauth_support(self, client):
|
||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
||||
instance = MockClient.return_value
|
||||
instance.discover_auth = AsyncMock(return_value=None)
|
||||
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/oauth/login",
|
||||
json={"server_url": "https://simple-server.example.com/mcp"},
|
||||
)
|
||||
@@ -249,7 +271,8 @@ class TestOAuthLogin:
|
||||
assert response.status_code == 400
|
||||
assert "does not advertise OAuth" in response.json()["detail"]
|
||||
|
||||
def test_oauth_login_fallback_to_public_client(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_login_fallback_to_public_client(self, client):
|
||||
"""When DCR is unavailable, falls back to default public client ID."""
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
||||
@@ -275,7 +298,7 @@ class TestOAuthLogin:
|
||||
)
|
||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/oauth/login",
|
||||
json={"server_url": "https://mcp.example.com/mcp"},
|
||||
)
|
||||
@@ -286,7 +309,8 @@ class TestOAuthLogin:
|
||||
|
||||
|
||||
class TestOAuthCallback:
|
||||
def test_oauth_callback_success(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_callback_success(self, client):
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
@@ -334,7 +358,7 @@ class TestOAuthCallback:
|
||||
# Mock old credential cleanup
|
||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/oauth/callback",
|
||||
json={"code": "auth-code-abc", "state_token": "state-token-123"},
|
||||
)
|
||||
@@ -346,11 +370,12 @@ class TestOAuthCallback:
|
||||
assert data["type"] == "oauth2"
|
||||
mock_cm.create.assert_called_once()
|
||||
|
||||
def test_oauth_callback_invalid_state(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_callback_invalid_state(self, client):
|
||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
||||
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/oauth/callback",
|
||||
json={"code": "auth-code", "state_token": "bad-state"},
|
||||
)
|
||||
@@ -358,7 +383,8 @@ class TestOAuthCallback:
|
||||
assert response.status_code == 400
|
||||
assert "Invalid or expired" in response.json()["detail"]
|
||||
|
||||
def test_oauth_callback_token_exchange_fails(self):
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_oauth_callback_token_exchange_fails(self, client):
|
||||
with (
|
||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
||||
@@ -381,7 +407,7 @@ class TestOAuthCallback:
|
||||
side_effect=RuntimeError("Token exchange failed")
|
||||
)
|
||||
|
||||
response = client.post(
|
||||
response = await client.post(
|
||||
"/oauth/callback",
|
||||
json={"code": "bad-code", "state_token": "state"},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user