fix(mcp): Use manual credential resolution instead of CredentialsField

The block framework's CredentialsField requires credentials to always be
present, which doesn't work for public MCP servers. Replace it with a
plain credential_id field and manual resolution from the credential store,
allowing both authenticated and public MCP servers to work seamlessly.
This commit is contained in:
Zamil Majdy
2026-02-09 14:41:14 +04:00
parent 03487f7b4d
commit d62fde9445
5 changed files with 74 additions and 149 deletions

View File

@@ -8,9 +8,7 @@ dropdown and the input/output schema adapts dynamically.
import json
import logging
from typing import Any, Literal
from pydantic import SecretStr
from typing import Any
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.data.block import (
@@ -22,36 +20,11 @@ from backend.data.block import (
BlockSchemaOutput,
BlockType,
)
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
OAuth2Credentials,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.data.model import OAuth2Credentials, SchemaField
from backend.util.json import validate_with_jsonschema
logger = logging.getLogger(__name__)
MCPCredentials = APIKeyCredentials | OAuth2Credentials
MCPCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.MCP], Literal["api_key", "oauth2"]
]
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="mcp",
api_key=SecretStr("test-mcp-token"),
title="Mock MCP Credentials",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
class MCPToolBlock(Block):
"""
@@ -67,16 +40,15 @@ class MCPToolBlock(Block):
"""
class Input(BlockSchemaInput):
# -- Static fields (always shown) --
credentials: MCPCredentialsInput = CredentialsField(
description="Credentials for the MCP server. Use an API key for Bearer "
"token auth, or OAuth2 for servers that support it. For public "
"servers, create a credential with any placeholder value.",
)
server_url: str = SchemaField(
description="URL of the MCP server (Streamable HTTP endpoint)",
placeholder="https://mcp.example.com/mcp",
)
credential_id: str = SchemaField(
description="Credential ID from OAuth flow (empty for public servers)",
default="",
hidden=True,
)
available_tools: dict[str, Any] = SchemaField(
description="Available tools on the MCP server. "
"This is populated automatically when a server URL is provided.",
@@ -95,7 +67,6 @@ class MCPToolBlock(Block):
hidden=True,
)
# -- Dynamic field: actual arguments for the selected tool --
tool_arguments: dict[str, Any] = SchemaField(
description="Arguments to pass to the selected MCP tool. "
"The fields here are defined by the tool's input schema.",
@@ -143,7 +114,6 @@ class MCPToolBlock(Block):
block_type=BlockType.STANDARD,
test_input={
"server_url": "https://mcp.example.com/mcp",
"credentials": TEST_CREDENTIALS_INPUT,
"selected_tool": "get_weather",
"tool_input_schema": {
"type": "object",
@@ -164,7 +134,6 @@ class MCPToolBlock(Block):
"temperature": 20,
},
},
test_credentials=TEST_CREDENTIALS,
)
async def _call_mcp_tool(
@@ -220,24 +189,28 @@ class MCPToolBlock(Block):
return output_parts[0]
return output_parts if output_parts else None
@staticmethod
def _extract_auth_token(credentials: MCPCredentials) -> str | None:
"""Extract a Bearer token from either API key or OAuth2 credentials."""
if isinstance(credentials, OAuth2Credentials):
return credentials.access_token.get_secret_value()
if isinstance(credentials, APIKeyCredentials) and credentials.api_key:
token_value = credentials.api_key.get_secret_value()
if token_value:
return token_value
async def _resolve_auth_token(self, credential_id: str, user_id: str) -> str | None:
"""Resolve a Bearer token from a stored credential ID."""
if not credential_id:
return None
from backend.util.clients import get_integration_credentials_store
store = get_integration_credentials_store()
creds = await store.get_creds_by_id(user_id, credential_id)
if not creds:
logger.warning(f"Credential {credential_id} not found")
return None
if isinstance(creds, OAuth2Credentials):
return creds.access_token.get_secret_value()
if hasattr(creds, "api_key") and creds.api_key:
return creds.api_key.get_secret_value() or None
return None
async def run(
self,
input_data: Input,
*,
credentials: MCPCredentials,
user_id: str,
**kwargs,
) -> BlockOutput:
if not input_data.server_url:
@@ -248,7 +221,7 @@ class MCPToolBlock(Block):
yield "error", "No tool selected. Please select a tool from the dropdown."
return
auth_token = self._extract_auth_token(credentials)
auth_token = await self._resolve_auth_token(input_data.credential_id, user_id)
try:
result = await self._call_mcp_tool(

View File

@@ -6,15 +6,9 @@ import json
from unittest.mock import AsyncMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.mcp.block import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
MCPToolBlock,
)
from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
from backend.data.model import APIKeyCredentials, OAuth2Credentials
from backend.util.test import execute_block_test
# ── SSE parsing unit tests ───────────────────────────────────────────
@@ -273,6 +267,8 @@ class TestMCPClient:
# ── MCPToolBlock unit tests ──────────────────────────────────────────
MOCK_USER_ID = "test-user-123"
class TestMCPToolBlock:
"""Tests for the MCPToolBlock."""
@@ -289,7 +285,7 @@ class TestMCPToolBlock:
assert "server_url" in props
assert "selected_tool" in props
assert "tool_arguments" in props
assert "credentials" in props
assert "credential_id" in props
def test_output_schema(self):
block = MCPToolBlock()
@@ -356,10 +352,9 @@ class TestMCPToolBlock:
input_data = MCPToolBlock.Input(
server_url="",
selected_tool="test",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
outputs = []
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert outputs == [("error", "MCP server URL is required")]
@@ -369,10 +364,9 @@ class TestMCPToolBlock:
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
outputs = []
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert outputs == [
("error", "No tool selected. Please select a tool from the dropdown.")
@@ -389,7 +383,6 @@ class TestMCPToolBlock:
"properties": {"city": {"type": "string"}},
},
tool_arguments={"city": "London"},
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
async def mock_call(*args, **kwargs):
@@ -398,7 +391,7 @@ class TestMCPToolBlock:
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
@@ -411,7 +404,6 @@ class TestMCPToolBlock:
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="bad_tool",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
async def mock_call(*args, **kwargs):
@@ -420,7 +412,7 @@ class TestMCPToolBlock:
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert outputs[0][0] == "error"
@@ -566,20 +558,39 @@ class TestMCPToolBlock:
}
@pytest.mark.asyncio
async def test_run_sends_api_key_credentials(self):
"""Ensure non-empty API keys are sent to the MCP server."""
async def test_run_with_credential_id(self):
"""Verify the block resolves credential_id and passes auth token."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="test_tool",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
credential_id="cred-123",
)
creds = APIKeyCredentials(
id="test-id",
provider="mcp",
api_key=SecretStr("real-api-key"),
title="Real",
captured_tokens = []
async def mock_call(server_url, tool_name, arguments, auth_token=None):
captured_tokens.append(auth_token)
return "ok"
async def mock_resolve(self, cred_id, uid):
return "resolved-token"
block._call_mcp_tool = mock_call # type: ignore
with patch.object(MCPToolBlock, "_resolve_auth_token", mock_resolve):
async for _ in block.run(input_data, user_id=MOCK_USER_ID):
pass
assert captured_tokens == ["resolved-token"]
@pytest.mark.asyncio
async def test_run_without_credential_id(self):
"""Verify the block works without credentials (public server)."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="test_tool",
)
captured_tokens = []
@@ -590,78 +601,9 @@ class TestMCPToolBlock:
block._call_mcp_tool = mock_call # type: ignore
async for _ in block.run(input_data, credentials=creds):
pass
assert captured_tokens == ["real-api-key"]
# ── OAuth2 credential support tests ─────────────────────────────────
class TestMCPOAuth2Support:
"""Tests for OAuth2 credential support in MCPToolBlock."""
def test_extract_auth_token_from_api_key(self):
creds = APIKeyCredentials(
id="test",
provider="mcp",
api_key=SecretStr("my-api-key"),
title="test",
)
token = MCPToolBlock._extract_auth_token(creds)
assert token == "my-api-key"
def test_extract_auth_token_from_oauth2(self):
creds = OAuth2Credentials(
id="test",
provider="mcp",
access_token=SecretStr("oauth2-access-token"),
scopes=["read"],
title="test",
)
token = MCPToolBlock._extract_auth_token(creds)
assert token == "oauth2-access-token"
def test_extract_auth_token_empty_skipped(self):
creds = APIKeyCredentials(
id="test",
provider="mcp",
api_key=SecretStr(""),
title="test",
)
token = MCPToolBlock._extract_auth_token(creds)
assert token is None
@pytest.mark.asyncio
async def test_run_with_oauth2_credentials(self):
"""Verify the block can run with OAuth2 credentials."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="test_tool",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
oauth2_creds = OAuth2Credentials(
id="test-id",
provider="mcp",
access_token=SecretStr("real-oauth2-token"),
scopes=["read", "write"],
title="MCP OAuth",
)
captured_tokens = []
async def mock_call(server_url, tool_name, arguments, auth_token=None):
captured_tokens.append(auth_token)
return {"status": "ok"}
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, credentials=oauth2_creds):
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert captured_tokens == ["real-oauth2-token"]
assert outputs == [("result", {"status": "ok"})]
assert captured_tokens == [None]
assert outputs == [("result", "ok")]

View File

@@ -68,6 +68,7 @@ export const Block: BlockComponent = ({
selected_tool: result.selectedTool,
tool_input_schema: result.toolInputSchema,
available_tools: result.availableTools,
credential_id: result.credentialId ?? "",
});
setMcpDialogOpen(false);
},

View File

@@ -199,6 +199,7 @@ export function BlocksControl({
selected_tool: result.selectedTool,
tool_input_schema: result.toolInputSchema,
available_tools: result.availableTools,
credential_id: result.credentialId ?? "",
});
setMcpDialogOpen(false);
},

View File

@@ -25,6 +25,8 @@ export type MCPToolDialogResult = {
selectedTool: string;
toolInputSchema: Record<string, any>;
availableTools: Record<string, any>;
/** Credential ID from OAuth flow, null for public servers. */
credentialId: string | null;
};
interface MCPToolDialogProps {
@@ -56,6 +58,7 @@ export function MCPToolDialog({
const [showManualToken, setShowManualToken] = useState(false);
const [manualToken, setManualToken] = useState("");
const [selectedTool, setSelectedTool] = useState<MCPTool | null>(null);
const [credentialId, setCredentialId] = useState<string | null>(null);
const oauthLoadingRef = useRef(false);
const stateTokenRef = useRef<string | null>(null);
@@ -120,6 +123,7 @@ export function MCPToolDialog({
setAuthRequired(false);
setShowManualToken(false);
setSelectedTool(null);
setCredentialId(null);
stateTokenRef.current = null;
}, [cleanupOAuthListeners]);
@@ -186,8 +190,11 @@ export function MCPToolDialog({
// Exchange code for tokens (stored server-side)
setLoading(true);
try {
await api.mcpOAuthCallback(data.code!, stateTokenRef.current!);
// Retry discovery — backend auto-uses stored credential
const callbackResult = await api.mcpOAuthCallback(
data.code!,
stateTokenRef.current!,
);
setCredentialId(callbackResult.credential_id);
const result = await api.mcpDiscoverTools(serverUrl.trim());
localStorage.setItem(STORAGE_KEY, serverUrl.trim());
setTools(result.tools);
@@ -299,9 +306,10 @@ export function MCPToolDialog({
selectedTool: selectedTool.name,
toolInputSchema: selectedTool.input_schema,
availableTools,
credentialId,
});
reset();
}, [selectedTool, tools, serverUrl, onConfirm, reset]);
}, [selectedTool, tools, serverUrl, credentialId, onConfirm, reset]);
return (
<Dialog open={open} onOpenChange={(isOpen) => !isOpen && handleClose()}>