mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-09 06:15:41 -05:00
fix(mcp): Use manual credential resolution instead of CredentialsField
The block framework's CredentialsField requires credentials to always be present, which doesn't work for public MCP servers. Replace it with a plain credential_id field and manual resolution from the credential store, allowing both authenticated and public MCP servers to work seamlessly.
This commit is contained in:
@@ -8,9 +8,7 @@ dropdown and the input/output schema adapts dynamically.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
from typing import Any
|
||||
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.data.block import (
|
||||
@@ -22,36 +20,11 @@ from backend.data.block import (
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.data.model import OAuth2Credentials, SchemaField
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MCPCredentials = APIKeyCredentials | OAuth2Credentials
|
||||
MCPCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.MCP], Literal["api_key", "oauth2"]
|
||||
]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="mcp",
|
||||
api_key=SecretStr("test-mcp-token"),
|
||||
title="Mock MCP Credentials",
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class MCPToolBlock(Block):
|
||||
"""
|
||||
@@ -67,16 +40,15 @@ class MCPToolBlock(Block):
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
# -- Static fields (always shown) --
|
||||
credentials: MCPCredentialsInput = CredentialsField(
|
||||
description="Credentials for the MCP server. Use an API key for Bearer "
|
||||
"token auth, or OAuth2 for servers that support it. For public "
|
||||
"servers, create a credential with any placeholder value.",
|
||||
)
|
||||
server_url: str = SchemaField(
|
||||
description="URL of the MCP server (Streamable HTTP endpoint)",
|
||||
placeholder="https://mcp.example.com/mcp",
|
||||
)
|
||||
credential_id: str = SchemaField(
|
||||
description="Credential ID from OAuth flow (empty for public servers)",
|
||||
default="",
|
||||
hidden=True,
|
||||
)
|
||||
available_tools: dict[str, Any] = SchemaField(
|
||||
description="Available tools on the MCP server. "
|
||||
"This is populated automatically when a server URL is provided.",
|
||||
@@ -95,7 +67,6 @@ class MCPToolBlock(Block):
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
# -- Dynamic field: actual arguments for the selected tool --
|
||||
tool_arguments: dict[str, Any] = SchemaField(
|
||||
description="Arguments to pass to the selected MCP tool. "
|
||||
"The fields here are defined by the tool's input schema.",
|
||||
@@ -143,7 +114,6 @@ class MCPToolBlock(Block):
|
||||
block_type=BlockType.STANDARD,
|
||||
test_input={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"selected_tool": "get_weather",
|
||||
"tool_input_schema": {
|
||||
"type": "object",
|
||||
@@ -164,7 +134,6 @@ class MCPToolBlock(Block):
|
||||
"temperature": 20,
|
||||
},
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def _call_mcp_tool(
|
||||
@@ -220,24 +189,28 @@ class MCPToolBlock(Block):
|
||||
return output_parts[0]
|
||||
return output_parts if output_parts else None
|
||||
|
||||
@staticmethod
|
||||
def _extract_auth_token(credentials: MCPCredentials) -> str | None:
|
||||
"""Extract a Bearer token from either API key or OAuth2 credentials."""
|
||||
if isinstance(credentials, OAuth2Credentials):
|
||||
return credentials.access_token.get_secret_value()
|
||||
|
||||
if isinstance(credentials, APIKeyCredentials) and credentials.api_key:
|
||||
token_value = credentials.api_key.get_secret_value()
|
||||
if token_value:
|
||||
return token_value
|
||||
async def _resolve_auth_token(self, credential_id: str, user_id: str) -> str | None:
|
||||
"""Resolve a Bearer token from a stored credential ID."""
|
||||
if not credential_id:
|
||||
return None
|
||||
from backend.util.clients import get_integration_credentials_store
|
||||
|
||||
store = get_integration_credentials_store()
|
||||
creds = await store.get_creds_by_id(user_id, credential_id)
|
||||
if not creds:
|
||||
logger.warning(f"Credential {credential_id} not found")
|
||||
return None
|
||||
if isinstance(creds, OAuth2Credentials):
|
||||
return creds.access_token.get_secret_value()
|
||||
if hasattr(creds, "api_key") and creds.api_key:
|
||||
return creds.api_key.get_secret_value() or None
|
||||
return None
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: MCPCredentials,
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
if not input_data.server_url:
|
||||
@@ -248,7 +221,7 @@ class MCPToolBlock(Block):
|
||||
yield "error", "No tool selected. Please select a tool from the dropdown."
|
||||
return
|
||||
|
||||
auth_token = self._extract_auth_token(credentials)
|
||||
auth_token = await self._resolve_auth_token(input_data.credential_id, user_id)
|
||||
|
||||
try:
|
||||
result = await self._call_mcp_tool(
|
||||
|
||||
@@ -6,15 +6,9 @@ import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.mcp.block import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
MCPToolBlock,
|
||||
)
|
||||
from backend.blocks.mcp.block import MCPToolBlock
|
||||
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
# ── SSE parsing unit tests ───────────────────────────────────────────
|
||||
@@ -273,6 +267,8 @@ class TestMCPClient:
|
||||
|
||||
# ── MCPToolBlock unit tests ──────────────────────────────────────────
|
||||
|
||||
MOCK_USER_ID = "test-user-123"
|
||||
|
||||
|
||||
class TestMCPToolBlock:
|
||||
"""Tests for the MCPToolBlock."""
|
||||
@@ -289,7 +285,7 @@ class TestMCPToolBlock:
|
||||
assert "server_url" in props
|
||||
assert "selected_tool" in props
|
||||
assert "tool_arguments" in props
|
||||
assert "credentials" in props
|
||||
assert "credential_id" in props
|
||||
|
||||
def test_output_schema(self):
|
||||
block = MCPToolBlock()
|
||||
@@ -356,10 +352,9 @@ class TestMCPToolBlock:
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="",
|
||||
selected_tool="test",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
assert outputs == [("error", "MCP server URL is required")]
|
||||
|
||||
@@ -369,10 +364,9 @@ class TestMCPToolBlock:
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
assert outputs == [
|
||||
("error", "No tool selected. Please select a tool from the dropdown.")
|
||||
@@ -389,7 +383,6 @@ class TestMCPToolBlock:
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
tool_arguments={"city": "London"},
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
async def mock_call(*args, **kwargs):
|
||||
@@ -398,7 +391,7 @@ class TestMCPToolBlock:
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
@@ -411,7 +404,6 @@ class TestMCPToolBlock:
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="bad_tool",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
async def mock_call(*args, **kwargs):
|
||||
@@ -420,7 +412,7 @@ class TestMCPToolBlock:
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert outputs[0][0] == "error"
|
||||
@@ -566,20 +558,39 @@ class TestMCPToolBlock:
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_sends_api_key_credentials(self):
|
||||
"""Ensure non-empty API keys are sent to the MCP server."""
|
||||
async def test_run_with_credential_id(self):
|
||||
"""Verify the block resolves credential_id and passes auth token."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="test_tool",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
credential_id="cred-123",
|
||||
)
|
||||
|
||||
creds = APIKeyCredentials(
|
||||
id="test-id",
|
||||
provider="mcp",
|
||||
api_key=SecretStr("real-api-key"),
|
||||
title="Real",
|
||||
captured_tokens = []
|
||||
|
||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||
captured_tokens.append(auth_token)
|
||||
return "ok"
|
||||
|
||||
async def mock_resolve(self, cred_id, uid):
|
||||
return "resolved-token"
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
with patch.object(MCPToolBlock, "_resolve_auth_token", mock_resolve):
|
||||
async for _ in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
pass
|
||||
|
||||
assert captured_tokens == ["resolved-token"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_without_credential_id(self):
|
||||
"""Verify the block works without credentials (public server)."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="test_tool",
|
||||
)
|
||||
|
||||
captured_tokens = []
|
||||
@@ -590,78 +601,9 @@ class TestMCPToolBlock:
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
async for _ in block.run(input_data, credentials=creds):
|
||||
pass
|
||||
|
||||
assert captured_tokens == ["real-api-key"]
|
||||
|
||||
|
||||
# ── OAuth2 credential support tests ─────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPOAuth2Support:
|
||||
"""Tests for OAuth2 credential support in MCPToolBlock."""
|
||||
|
||||
def test_extract_auth_token_from_api_key(self):
|
||||
creds = APIKeyCredentials(
|
||||
id="test",
|
||||
provider="mcp",
|
||||
api_key=SecretStr("my-api-key"),
|
||||
title="test",
|
||||
)
|
||||
token = MCPToolBlock._extract_auth_token(creds)
|
||||
assert token == "my-api-key"
|
||||
|
||||
def test_extract_auth_token_from_oauth2(self):
|
||||
creds = OAuth2Credentials(
|
||||
id="test",
|
||||
provider="mcp",
|
||||
access_token=SecretStr("oauth2-access-token"),
|
||||
scopes=["read"],
|
||||
title="test",
|
||||
)
|
||||
token = MCPToolBlock._extract_auth_token(creds)
|
||||
assert token == "oauth2-access-token"
|
||||
|
||||
def test_extract_auth_token_empty_skipped(self):
|
||||
creds = APIKeyCredentials(
|
||||
id="test",
|
||||
provider="mcp",
|
||||
api_key=SecretStr(""),
|
||||
title="test",
|
||||
)
|
||||
token = MCPToolBlock._extract_auth_token(creds)
|
||||
assert token is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_oauth2_credentials(self):
|
||||
"""Verify the block can run with OAuth2 credentials."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="test_tool",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
oauth2_creds = OAuth2Credentials(
|
||||
id="test-id",
|
||||
provider="mcp",
|
||||
access_token=SecretStr("real-oauth2-token"),
|
||||
scopes=["read", "write"],
|
||||
title="MCP OAuth",
|
||||
)
|
||||
|
||||
captured_tokens = []
|
||||
|
||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||
captured_tokens.append(auth_token)
|
||||
return {"status": "ok"}
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=oauth2_creds):
|
||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert captured_tokens == ["real-oauth2-token"]
|
||||
assert outputs == [("result", {"status": "ok"})]
|
||||
assert captured_tokens == [None]
|
||||
assert outputs == [("result", "ok")]
|
||||
|
||||
@@ -68,6 +68,7 @@ export const Block: BlockComponent = ({
|
||||
selected_tool: result.selectedTool,
|
||||
tool_input_schema: result.toolInputSchema,
|
||||
available_tools: result.availableTools,
|
||||
credential_id: result.credentialId ?? "",
|
||||
});
|
||||
setMcpDialogOpen(false);
|
||||
},
|
||||
|
||||
@@ -199,6 +199,7 @@ export function BlocksControl({
|
||||
selected_tool: result.selectedTool,
|
||||
tool_input_schema: result.toolInputSchema,
|
||||
available_tools: result.availableTools,
|
||||
credential_id: result.credentialId ?? "",
|
||||
});
|
||||
setMcpDialogOpen(false);
|
||||
},
|
||||
|
||||
@@ -25,6 +25,8 @@ export type MCPToolDialogResult = {
|
||||
selectedTool: string;
|
||||
toolInputSchema: Record<string, any>;
|
||||
availableTools: Record<string, any>;
|
||||
/** Credential ID from OAuth flow, null for public servers. */
|
||||
credentialId: string | null;
|
||||
};
|
||||
|
||||
interface MCPToolDialogProps {
|
||||
@@ -56,6 +58,7 @@ export function MCPToolDialog({
|
||||
const [showManualToken, setShowManualToken] = useState(false);
|
||||
const [manualToken, setManualToken] = useState("");
|
||||
const [selectedTool, setSelectedTool] = useState<MCPTool | null>(null);
|
||||
const [credentialId, setCredentialId] = useState<string | null>(null);
|
||||
|
||||
const oauthLoadingRef = useRef(false);
|
||||
const stateTokenRef = useRef<string | null>(null);
|
||||
@@ -120,6 +123,7 @@ export function MCPToolDialog({
|
||||
setAuthRequired(false);
|
||||
setShowManualToken(false);
|
||||
setSelectedTool(null);
|
||||
setCredentialId(null);
|
||||
stateTokenRef.current = null;
|
||||
}, [cleanupOAuthListeners]);
|
||||
|
||||
@@ -186,8 +190,11 @@ export function MCPToolDialog({
|
||||
// Exchange code for tokens (stored server-side)
|
||||
setLoading(true);
|
||||
try {
|
||||
await api.mcpOAuthCallback(data.code!, stateTokenRef.current!);
|
||||
// Retry discovery — backend auto-uses stored credential
|
||||
const callbackResult = await api.mcpOAuthCallback(
|
||||
data.code!,
|
||||
stateTokenRef.current!,
|
||||
);
|
||||
setCredentialId(callbackResult.credential_id);
|
||||
const result = await api.mcpDiscoverTools(serverUrl.trim());
|
||||
localStorage.setItem(STORAGE_KEY, serverUrl.trim());
|
||||
setTools(result.tools);
|
||||
@@ -299,9 +306,10 @@ export function MCPToolDialog({
|
||||
selectedTool: selectedTool.name,
|
||||
toolInputSchema: selectedTool.input_schema,
|
||||
availableTools,
|
||||
credentialId,
|
||||
});
|
||||
reset();
|
||||
}, [selectedTool, tools, serverUrl, onConfirm, reset]);
|
||||
}, [selectedTool, tools, serverUrl, credentialId, onConfirm, reset]);
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={(isOpen) => !isOpen && handleClose()}>
|
||||
|
||||
Reference in New Issue
Block a user