fix(mcp): Address PR review comments

- Fix get_missing_input/get_mismatch_error to validate tool_arguments
  dict instead of the entire BlockInput data (critical bug)
- Add type check for non-dict JSON-RPC error field in client.py
- Add try/catch for non-JSON responses in client.py
- Add raise_for_status and error payload checks to OAuth token requests
- Remove hardcoded placeholder skip-list from _extract_auth_token
- Fix server start timeout check in integration tests
- Remove unused MCPTool import, move execute_block_test to top-level
- Update tests to match fixed validation behavior
- Fix MCP_BLOCK_IMPLEMENTATION.md (remove duplicate section, local path)
- Soften PKCE comment in oauth.py
This commit is contained in:
Zamil Majdy
2026-02-08 19:34:28 +04:00
parent 7db3f12876
commit 19b3373052
6 changed files with 62 additions and 101 deletions

View File

@@ -28,12 +28,12 @@ Result yielded as block output
1. **Single block, not many blocks** — One `MCPBlock` handles all MCP servers/tools
2. **Dynamic schema via AgentExecutorBlock pattern** — Override `get_input_schema()`,
`get_input_defaults()`, `get_missing_input()` on the Input class
3. **Auth via API key credentials** — Use existing `APIKeyCredentials` with `ProviderName.MCP`
provider. The API key is sent as Bearer token in the HTTP Authorization header to the MCP
server. This keeps it simple and uses existing infrastructure.
3. **Auth via API key or OAuth2 credentials** — Use existing `APIKeyCredentials` or
`OAuth2Credentials` with `ProviderName.MCP` provider. API keys are sent as Bearer tokens;
OAuth2 uses the access token.
4. **HTTP-based MCP client** — Use `aiohttp` (already a dependency) to implement MCP Streamable
HTTP transport directly. No need for the `mcp` Python SDK — the protocol is simple JSON-RPC
over HTTP.
over HTTP. Handles both JSON and SSE response formats.
5. **No new DB tables** — Everything fits in existing `AgentBlock` + `AgentNode` tables
## Implementation Files
@@ -43,75 +43,23 @@ Result yielded as block output
- `__init__.py`
- `block.py` — MCPToolBlock implementation
- `client.py` — MCP HTTP client (list_tools, call_tool)
- `test_mcp.py` — Tests (34 tests)
- `oauth.py` — MCP OAuth handler for dynamic endpoint discovery
- `test_mcp.py` — Unit tests
- `test_oauth.py` — OAuth handler tests
- `test_integration.py` — Integration tests with local test server
- `test_e2e.py` — E2E tests against real MCP servers
### Modified Files
- `backend/integrations/providers.py` — Add `MCP = "mcp"` to ProviderName
- `pyproject.toml` — No changes needed (using aiohttp which is already a dep)
## Detailed Design
### MCP Client (`client.py`)
Simple async HTTP client for MCP Streamable HTTP protocol:
```python
class MCPClient:
async def list_tools(server_url: str, headers: dict) -> list[MCPTool]
async def call_tool(server_url: str, tool_name: str, arguments: dict, headers: dict) -> Any
```
Uses JSON-RPC 2.0 over HTTP POST:
- `tools/list``{"jsonrpc": "2.0", "method": "tools/list", "id": 1}`
- `tools/call``{"jsonrpc": "2.0", "method": "tools/call", "params": {"name": "...", "arguments": {...}}, "id": 2}`
### MCPBlock (`block.py`)
Key fields:
- `server_url: str` — MCP server endpoint URL
- `credentials: MCPCredentialsInput` — API key for auth (optional)
- `available_tools: dict` — Cached tools list from server (populated by frontend API call)
- `selected_tool: str` — Which tool the user selected
- `tool_input_schema: dict` — JSON schema of the selected tool's inputs
- `tool_arguments: dict` — The actual tool arguments (dynamic, validated against tool_input_schema)
Dynamic schema pattern (like AgentExecutorBlock):
```python
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
return data.get("tool_input_schema", {})
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
return data.get("tool_arguments", {})
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
required = cls.get_input_schema(data).get("required", [])
return set(required) - set(data)
```
### Auth
Use existing `APIKeyCredentials` with provider `"mcp"`:
- User creates an API key credential for their MCP server
- Block sends it as `Authorization: Bearer <key>` header
- Credentials are optional (some MCP servers don't need auth)
## Dev Loop
```bash
cd /Users/majdyz/Code/AutoGPT2/autogpt_platform/backend
poetry run pytest backend/blocks/test/test_mcp_block.py -xvs # Run MCP-specific tests
poetry run pytest backend/blocks/test/test_block.py -xvs -k "MCP" # Run block test suite for MCP
```
## Dev Loop
```bash
cd /Users/majdyz/Code/AutoGPT2/autogpt_platform/backend
poetry run pytest backend/blocks/mcp/test_mcp.py -xvs # Run MCP-specific tests (34 tests)
poetry run pytest backend/blocks/test/test_block.py -xvs -k "MCP" # Run block test suite for MCP
cd autogpt_platform/backend
poetry run pytest backend/blocks/mcp/test_mcp.py -xvs # Unit tests
poetry run pytest backend/blocks/mcp/test_oauth.py -xvs # OAuth tests
poetry run pytest backend/blocks/mcp/test_integration.py -xvs # Integration tests
poetry run pytest backend/blocks/mcp/ -xvs # All MCP tests
```
## Status
@@ -120,6 +68,9 @@ poetry run pytest backend/blocks/test/test_block.py -xvs -k "MCP" # Run block t
- [x] Add ProviderName.MCP
- [x] Implement MCP client (client.py)
- [x] Implement MCPToolBlock (block.py)
- [x] Write unit tests (34 tests — all passing)
- [x] Add OAuth2 support (oauth.py)
- [x] Write unit tests
- [x] Write integration tests
- [x] Write E2E tests
- [x] Run tests & fix issues
- [ ] Create PR
- [x] Create PR

View File

@@ -116,7 +116,8 @@ class MCPToolBlock(Block):
def get_missing_input(cls, data: BlockInput) -> set[str]:
"""Check which required tool arguments are missing."""
required_fields = cls.get_input_schema(data).get("required", [])
return set(required_fields) - set(data)
tool_arguments = data.get("tool_arguments", {})
return set(required_fields) - set(tool_arguments)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
@@ -124,7 +125,8 @@ class MCPToolBlock(Block):
tool_schema = cls.get_input_schema(data)
if not tool_schema:
return None
return validate_with_jsonschema(tool_schema, data)
tool_arguments = data.get("tool_arguments", {})
return validate_with_jsonschema(tool_schema, tool_arguments)
class Output(BlockSchemaOutput):
result: Any = SchemaField(
@@ -228,8 +230,7 @@ class MCPToolBlock(Block):
if isinstance(credentials, APIKeyCredentials) and credentials.api_key:
token_value = credentials.api_key.get_secret_value()
# Skip placeholder/fake tokens
if token_value and token_value not in ("", "FAKE_API_KEY", "placeholder"):
if token_value:
return token_value
return None

View File

@@ -137,15 +137,22 @@ class MCPClient:
if "text/event-stream" in content_type:
body = self._parse_sse_response(response.text())
else:
body = response.json()
try:
body = response.json()
except (ValueError, Exception) as e:
raise MCPClientError(
f"MCP server returned non-JSON response: {e}"
) from e
# Handle JSON-RPC error
if "error" in body:
error = body["error"]
raise MCPClientError(
f"MCP server error [{error.get('code', '?')}]: "
f"{error.get('message', 'Unknown error')}"
)
if isinstance(error, dict):
raise MCPClientError(
f"MCP server error [{error.get('code', '?')}]: "
f"{error.get('message', 'Unknown error')}"
)
raise MCPClientError(f"MCP server error: {error}")
return body.get("result")

View File

@@ -68,7 +68,7 @@ class MCPOAuthHandler(BaseOAuthHandler):
}
if scopes:
params["scope"] = " ".join(scopes)
# PKCE is required by the MCP spec (S256 only)
# PKCE (S256) — included when the caller provides a code_challenge
if code_challenge:
params["code_challenge"] = code_challenge
params["code_challenge_method"] = "S256"
@@ -97,13 +97,18 @@ class MCPOAuthHandler(BaseOAuthHandler):
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests().post(
response = await Requests(raise_for_status=True).post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
if "error" in tokens:
raise RuntimeError(
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
)
now = int(time.time())
expires_in = tokens.get("expires_in")
@@ -141,13 +146,18 @@ class MCPOAuthHandler(BaseOAuthHandler):
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests().post(
response = await Requests(raise_for_status=True).post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
if "error" in tokens:
raise RuntimeError(
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
)
now = int(time.time())
expires_in = tokens.get("expires_in")

View File

@@ -52,7 +52,8 @@ class _MCPTestServer:
def start(self):
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
self._started.wait(timeout=5)
if not self._started.wait(timeout=5):
raise RuntimeError("MCP test server failed to start within 5 seconds")
return self
def stop(self):

View File

@@ -6,12 +6,12 @@ import json
from unittest.mock import AsyncMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.mcp.block import MCPToolBlock, TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
from backend.data.model import APIKeyCredentials, OAuth2Credentials
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError, MCPTool
from backend.util.test import execute_block_test
# ── SSE parsing unit tests ───────────────────────────────────────────
@@ -81,7 +81,6 @@ class TestSSEParsing:
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == "second"
from backend.util.test import execute_block_test
# ── MCPClient unit tests ─────────────────────────────────────────────
@@ -326,7 +325,7 @@ class TestMCPToolBlock:
},
"required": ["city", "units"],
},
"city": "London",
"tool_arguments": {"city": "London"},
}
missing = MCPToolBlock.Input.get_missing_input(data)
assert missing == {"units"}
@@ -338,7 +337,7 @@ class TestMCPToolBlock:
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"city": "London",
"tool_arguments": {"city": "London"},
}
missing = MCPToolBlock.Input.get_missing_input(data)
assert missing == set()
@@ -575,8 +574,8 @@ class TestMCPToolBlock:
}
@pytest.mark.asyncio
async def test_run_skips_placeholder_credentials(self):
"""Ensure placeholder API keys are not sent to the MCP server."""
async def test_run_sends_api_key_credentials(self):
"""Ensure non-empty API keys are sent to the MCP server."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
@@ -584,11 +583,11 @@ class TestMCPToolBlock:
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
placeholder_creds = APIKeyCredentials(
creds = APIKeyCredentials(
id="test-id",
provider="mcp",
api_key=SecretStr("FAKE_API_KEY"),
title="Placeholder",
api_key=SecretStr("real-api-key"),
title="Real",
)
captured_tokens = []
@@ -599,10 +598,10 @@ class TestMCPToolBlock:
block._call_mcp_tool = mock_call # type: ignore
async for _ in block.run(input_data, credentials=placeholder_creds):
async for _ in block.run(input_data, credentials=creds):
pass
assert captured_tokens == [None]
assert captured_tokens == ["real-api-key"]
# ── OAuth2 credential support tests ─────────────────────────────────
@@ -628,14 +627,6 @@ class TestMCPOAuth2Support:
token = MCPToolBlock._extract_auth_token(creds)
assert token == "oauth2-access-token"
def test_extract_auth_token_placeholder_skipped(self):
creds = APIKeyCredentials(
id="test", provider="mcp",
api_key=SecretStr("FAKE_API_KEY"), title="test",
)
token = MCPToolBlock._extract_auth_token(creds)
assert token is None
def test_extract_auth_token_empty_skipped(self):
creds = APIKeyCredentials(
id="test", provider="mcp",