mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-08 13:55:06 -05:00
fix(mcp): Address PR review comments
- Fix get_missing_input/get_mismatch_error to validate tool_arguments dict instead of the entire BlockInput data (critical bug) - Add type check for non-dict JSON-RPC error field in client.py - Add try/catch for non-JSON responses in client.py - Add raise_for_status and error payload checks to OAuth token requests - Remove hardcoded placeholder skip-list from _extract_auth_token - Fix server start timeout check in integration tests - Remove unused MCPTool import, move execute_block_test to top-level - Update tests to match fixed validation behavior - Fix MCP_BLOCK_IMPLEMENTATION.md (remove duplicate section, local path) - Soften PKCE comment in oauth.py
This commit is contained in:
@@ -28,12 +28,12 @@ Result yielded as block output
|
||||
1. **Single block, not many blocks** — One `MCPBlock` handles all MCP servers/tools
|
||||
2. **Dynamic schema via AgentExecutorBlock pattern** — Override `get_input_schema()`,
|
||||
`get_input_defaults()`, `get_missing_input()` on the Input class
|
||||
3. **Auth via API key credentials** — Use existing `APIKeyCredentials` with `ProviderName.MCP`
|
||||
provider. The API key is sent as Bearer token in the HTTP Authorization header to the MCP
|
||||
server. This keeps it simple and uses existing infrastructure.
|
||||
3. **Auth via API key or OAuth2 credentials** — Use existing `APIKeyCredentials` or
|
||||
`OAuth2Credentials` with `ProviderName.MCP` provider. API keys are sent as Bearer tokens;
|
||||
OAuth2 uses the access token.
|
||||
4. **HTTP-based MCP client** — Use `aiohttp` (already a dependency) to implement MCP Streamable
|
||||
HTTP transport directly. No need for the `mcp` Python SDK — the protocol is simple JSON-RPC
|
||||
over HTTP.
|
||||
over HTTP. Handles both JSON and SSE response formats.
|
||||
5. **No new DB tables** — Everything fits in existing `AgentBlock` + `AgentNode` tables
|
||||
|
||||
## Implementation Files
|
||||
@@ -43,75 +43,23 @@ Result yielded as block output
|
||||
- `__init__.py`
|
||||
- `block.py` — MCPToolBlock implementation
|
||||
- `client.py` — MCP HTTP client (list_tools, call_tool)
|
||||
- `test_mcp.py` — Tests (34 tests)
|
||||
- `oauth.py` — MCP OAuth handler for dynamic endpoint discovery
|
||||
- `test_mcp.py` — Unit tests
|
||||
- `test_oauth.py` — OAuth handler tests
|
||||
- `test_integration.py` — Integration tests with local test server
|
||||
- `test_e2e.py` — E2E tests against real MCP servers
|
||||
|
||||
### Modified Files
|
||||
- `backend/integrations/providers.py` — Add `MCP = "mcp"` to ProviderName
|
||||
- `pyproject.toml` — No changes needed (using aiohttp which is already a dep)
|
||||
|
||||
## Detailed Design
|
||||
|
||||
### MCP Client (`client.py`)
|
||||
|
||||
Simple async HTTP client for MCP Streamable HTTP protocol:
|
||||
|
||||
```python
|
||||
class MCPClient:
|
||||
async def list_tools(server_url: str, headers: dict) -> list[MCPTool]
|
||||
async def call_tool(server_url: str, tool_name: str, arguments: dict, headers: dict) -> Any
|
||||
```
|
||||
|
||||
Uses JSON-RPC 2.0 over HTTP POST:
|
||||
- `tools/list` → `{"jsonrpc": "2.0", "method": "tools/list", "id": 1}`
|
||||
- `tools/call` → `{"jsonrpc": "2.0", "method": "tools/call", "params": {"name": "...", "arguments": {...}}, "id": 2}`
|
||||
|
||||
### MCPBlock (`block.py`)
|
||||
|
||||
Key fields:
|
||||
- `server_url: str` — MCP server endpoint URL
|
||||
- `credentials: MCPCredentialsInput` — API key for auth (optional)
|
||||
- `available_tools: dict` — Cached tools list from server (populated by frontend API call)
|
||||
- `selected_tool: str` — Which tool the user selected
|
||||
- `tool_input_schema: dict` — JSON schema of the selected tool's inputs
|
||||
- `tool_arguments: dict` — The actual tool arguments (dynamic, validated against tool_input_schema)
|
||||
|
||||
Dynamic schema pattern (like AgentExecutorBlock):
|
||||
```python
|
||||
@classmethod
|
||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||
return data.get("tool_input_schema", {})
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
return data.get("tool_arguments", {})
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
required = cls.get_input_schema(data).get("required", [])
|
||||
return set(required) - set(data)
|
||||
```
|
||||
|
||||
### Auth
|
||||
|
||||
Use existing `APIKeyCredentials` with provider `"mcp"`:
|
||||
- User creates an API key credential for their MCP server
|
||||
- Block sends it as `Authorization: Bearer <key>` header
|
||||
- Credentials are optional (some MCP servers don't need auth)
|
||||
|
||||
## Dev Loop
|
||||
|
||||
```bash
|
||||
cd /Users/majdyz/Code/AutoGPT2/autogpt_platform/backend
|
||||
poetry run pytest backend/blocks/test/test_mcp_block.py -xvs # Run MCP-specific tests
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs -k "MCP" # Run block test suite for MCP
|
||||
```
|
||||
|
||||
## Dev Loop
|
||||
|
||||
```bash
|
||||
cd /Users/majdyz/Code/AutoGPT2/autogpt_platform/backend
|
||||
poetry run pytest backend/blocks/mcp/test_mcp.py -xvs # Run MCP-specific tests (34 tests)
|
||||
poetry run pytest backend/blocks/test/test_block.py -xvs -k "MCP" # Run block test suite for MCP
|
||||
cd autogpt_platform/backend
|
||||
poetry run pytest backend/blocks/mcp/test_mcp.py -xvs # Unit tests
|
||||
poetry run pytest backend/blocks/mcp/test_oauth.py -xvs # OAuth tests
|
||||
poetry run pytest backend/blocks/mcp/test_integration.py -xvs # Integration tests
|
||||
poetry run pytest backend/blocks/mcp/ -xvs # All MCP tests
|
||||
```
|
||||
|
||||
## Status
|
||||
@@ -120,6 +68,9 @@ poetry run pytest backend/blocks/test/test_block.py -xvs -k "MCP" # Run block t
|
||||
- [x] Add ProviderName.MCP
|
||||
- [x] Implement MCP client (client.py)
|
||||
- [x] Implement MCPToolBlock (block.py)
|
||||
- [x] Write unit tests (34 tests — all passing)
|
||||
- [x] Add OAuth2 support (oauth.py)
|
||||
- [x] Write unit tests
|
||||
- [x] Write integration tests
|
||||
- [x] Write E2E tests
|
||||
- [x] Run tests & fix issues
|
||||
- [ ] Create PR
|
||||
- [x] Create PR
|
||||
|
||||
@@ -116,7 +116,8 @@ class MCPToolBlock(Block):
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
"""Check which required tool arguments are missing."""
|
||||
required_fields = cls.get_input_schema(data).get("required", [])
|
||||
return set(required_fields) - set(data)
|
||||
tool_arguments = data.get("tool_arguments", {})
|
||||
return set(required_fields) - set(tool_arguments)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
@@ -124,7 +125,8 @@ class MCPToolBlock(Block):
|
||||
tool_schema = cls.get_input_schema(data)
|
||||
if not tool_schema:
|
||||
return None
|
||||
return validate_with_jsonschema(tool_schema, data)
|
||||
tool_arguments = data.get("tool_arguments", {})
|
||||
return validate_with_jsonschema(tool_schema, tool_arguments)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
result: Any = SchemaField(
|
||||
@@ -228,8 +230,7 @@ class MCPToolBlock(Block):
|
||||
|
||||
if isinstance(credentials, APIKeyCredentials) and credentials.api_key:
|
||||
token_value = credentials.api_key.get_secret_value()
|
||||
# Skip placeholder/fake tokens
|
||||
if token_value and token_value not in ("", "FAKE_API_KEY", "placeholder"):
|
||||
if token_value:
|
||||
return token_value
|
||||
|
||||
return None
|
||||
|
||||
@@ -137,15 +137,22 @@ class MCPClient:
|
||||
if "text/event-stream" in content_type:
|
||||
body = self._parse_sse_response(response.text())
|
||||
else:
|
||||
body = response.json()
|
||||
try:
|
||||
body = response.json()
|
||||
except (ValueError, Exception) as e:
|
||||
raise MCPClientError(
|
||||
f"MCP server returned non-JSON response: {e}"
|
||||
) from e
|
||||
|
||||
# Handle JSON-RPC error
|
||||
if "error" in body:
|
||||
error = body["error"]
|
||||
raise MCPClientError(
|
||||
f"MCP server error [{error.get('code', '?')}]: "
|
||||
f"{error.get('message', 'Unknown error')}"
|
||||
)
|
||||
if isinstance(error, dict):
|
||||
raise MCPClientError(
|
||||
f"MCP server error [{error.get('code', '?')}]: "
|
||||
f"{error.get('message', 'Unknown error')}"
|
||||
)
|
||||
raise MCPClientError(f"MCP server error: {error}")
|
||||
|
||||
return body.get("result")
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class MCPOAuthHandler(BaseOAuthHandler):
|
||||
}
|
||||
if scopes:
|
||||
params["scope"] = " ".join(scopes)
|
||||
# PKCE is required by the MCP spec (S256 only)
|
||||
# PKCE (S256) — included when the caller provides a code_challenge
|
||||
if code_challenge:
|
||||
params["code_challenge"] = code_challenge
|
||||
params["code_challenge_method"] = "S256"
|
||||
@@ -97,13 +97,18 @@ class MCPOAuthHandler(BaseOAuthHandler):
|
||||
if self.resource_url:
|
||||
data["resource"] = self.resource_url
|
||||
|
||||
response = await Requests().post(
|
||||
response = await Requests(raise_for_status=True).post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise RuntimeError(
|
||||
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
|
||||
)
|
||||
|
||||
now = int(time.time())
|
||||
expires_in = tokens.get("expires_in")
|
||||
|
||||
@@ -141,13 +146,18 @@ class MCPOAuthHandler(BaseOAuthHandler):
|
||||
if self.resource_url:
|
||||
data["resource"] = self.resource_url
|
||||
|
||||
response = await Requests().post(
|
||||
response = await Requests(raise_for_status=True).post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise RuntimeError(
|
||||
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
|
||||
)
|
||||
|
||||
now = int(time.time())
|
||||
expires_in = tokens.get("expires_in")
|
||||
|
||||
|
||||
@@ -52,7 +52,8 @@ class _MCPTestServer:
|
||||
def start(self):
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
self._started.wait(timeout=5)
|
||||
if not self._started.wait(timeout=5):
|
||||
raise RuntimeError("MCP test server failed to start within 5 seconds")
|
||||
return self
|
||||
|
||||
def stop(self):
|
||||
|
||||
@@ -6,12 +6,12 @@ import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.mcp.block import MCPToolBlock, TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
|
||||
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError, MCPTool
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
# ── SSE parsing unit tests ───────────────────────────────────────────
|
||||
|
||||
@@ -81,7 +81,6 @@ class TestSSEParsing:
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == "second"
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
|
||||
# ── MCPClient unit tests ─────────────────────────────────────────────
|
||||
@@ -326,7 +325,7 @@ class TestMCPToolBlock:
|
||||
},
|
||||
"required": ["city", "units"],
|
||||
},
|
||||
"city": "London",
|
||||
"tool_arguments": {"city": "London"},
|
||||
}
|
||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
||||
assert missing == {"units"}
|
||||
@@ -338,7 +337,7 @@ class TestMCPToolBlock:
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
"city": "London",
|
||||
"tool_arguments": {"city": "London"},
|
||||
}
|
||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
||||
assert missing == set()
|
||||
@@ -575,8 +574,8 @@ class TestMCPToolBlock:
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_skips_placeholder_credentials(self):
|
||||
"""Ensure placeholder API keys are not sent to the MCP server."""
|
||||
async def test_run_sends_api_key_credentials(self):
|
||||
"""Ensure non-empty API keys are sent to the MCP server."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
@@ -584,11 +583,11 @@ class TestMCPToolBlock:
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
placeholder_creds = APIKeyCredentials(
|
||||
creds = APIKeyCredentials(
|
||||
id="test-id",
|
||||
provider="mcp",
|
||||
api_key=SecretStr("FAKE_API_KEY"),
|
||||
title="Placeholder",
|
||||
api_key=SecretStr("real-api-key"),
|
||||
title="Real",
|
||||
)
|
||||
|
||||
captured_tokens = []
|
||||
@@ -599,10 +598,10 @@ class TestMCPToolBlock:
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
async for _ in block.run(input_data, credentials=placeholder_creds):
|
||||
async for _ in block.run(input_data, credentials=creds):
|
||||
pass
|
||||
|
||||
assert captured_tokens == [None]
|
||||
assert captured_tokens == ["real-api-key"]
|
||||
|
||||
|
||||
# ── OAuth2 credential support tests ─────────────────────────────────
|
||||
@@ -628,14 +627,6 @@ class TestMCPOAuth2Support:
|
||||
token = MCPToolBlock._extract_auth_token(creds)
|
||||
assert token == "oauth2-access-token"
|
||||
|
||||
def test_extract_auth_token_placeholder_skipped(self):
|
||||
creds = APIKeyCredentials(
|
||||
id="test", provider="mcp",
|
||||
api_key=SecretStr("FAKE_API_KEY"), title="test",
|
||||
)
|
||||
token = MCPToolBlock._extract_auth_token(creds)
|
||||
assert token is None
|
||||
|
||||
def test_extract_auth_token_empty_skipped(self):
|
||||
creds = APIKeyCredentials(
|
||||
id="test", provider="mcp",
|
||||
|
||||
Reference in New Issue
Block a user