mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-08 13:55:06 -05:00
Compare commits
8 Commits
ntindle/go
...
feat/mcp-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7aab2eb1d5 | ||
|
|
5ab28ccda2 | ||
|
|
4fe0f05980 | ||
|
|
19b3373052 | ||
|
|
7db3f12876 | ||
|
|
e9b996abb0 | ||
|
|
9b972389a0 | ||
|
|
cd64562e1b |
1328
autogpt_platform/autogpt_libs/poetry.lock
generated
1328
autogpt_platform/autogpt_libs/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -11,15 +11,15 @@ python = ">=3.10,<4.0"
|
|||||||
colorama = "^0.4.6"
|
colorama = "^0.4.6"
|
||||||
cryptography = "^45.0"
|
cryptography = "^45.0"
|
||||||
expiringdict = "^1.2.2"
|
expiringdict = "^1.2.2"
|
||||||
fastapi = "^0.116.1"
|
fastapi = "^0.128.0"
|
||||||
google-cloud-logging = "^3.12.1"
|
google-cloud-logging = "^3.13.0"
|
||||||
launchdarkly-server-sdk = "^9.12.0"
|
launchdarkly-server-sdk = "^9.14.1"
|
||||||
pydantic = "^2.11.7"
|
pydantic = "^2.12.5"
|
||||||
pydantic-settings = "^2.10.1"
|
pydantic-settings = "^2.12.0"
|
||||||
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
|
pyjwt = { version = "^2.11.0", extras = ["crypto"] }
|
||||||
redis = "^6.2.0"
|
redis = "^6.2.0"
|
||||||
supabase = "^2.16.0"
|
supabase = "^2.27.2"
|
||||||
uvicorn = "^0.35.0"
|
uvicorn = "^0.40.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pyright = "^1.1.404"
|
pyright = "^1.1.404"
|
||||||
|
|||||||
76
autogpt_platform/backend/MCP_BLOCK_IMPLEMENTATION.md
Normal file
76
autogpt_platform/backend/MCP_BLOCK_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
# MCP Block Implementation Plan
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Create a single **MCPBlock** that dynamically integrates with any MCP (Model Context Protocol)
|
||||||
|
server. Users provide a server URL, the block discovers available tools, presents them as a
|
||||||
|
dropdown, and dynamically adjusts input/output schema based on the selected tool — exactly like
|
||||||
|
`AgentExecutorBlock` handles dynamic schemas.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
User provides MCP server URL + credentials
|
||||||
|
↓
|
||||||
|
MCPBlock fetches tools via MCP protocol (tools/list)
|
||||||
|
↓
|
||||||
|
User selects tool from dropdown (stored in constantInput)
|
||||||
|
↓
|
||||||
|
Input schema dynamically updates based on selected tool's inputSchema
|
||||||
|
↓
|
||||||
|
On execution: MCPBlock calls the tool via MCP protocol (tools/call)
|
||||||
|
↓
|
||||||
|
Result yielded as block output
|
||||||
|
```
|
||||||
|
|
||||||
|
## Design Decisions
|
||||||
|
|
||||||
|
1. **Single block, not many blocks** — One `MCPBlock` handles all MCP servers/tools
|
||||||
|
2. **Dynamic schema via AgentExecutorBlock pattern** — Override `get_input_schema()`,
|
||||||
|
`get_input_defaults()`, `get_missing_input()` on the Input class
|
||||||
|
3. **Auth via API key or OAuth2 credentials** — Use existing `APIKeyCredentials` or
|
||||||
|
`OAuth2Credentials` with `ProviderName.MCP` provider. API keys are sent as Bearer tokens;
|
||||||
|
OAuth2 uses the access token.
|
||||||
|
4. **HTTP-based MCP client** — Use `aiohttp` (already a dependency) to implement MCP Streamable
|
||||||
|
HTTP transport directly. No need for the `mcp` Python SDK — the protocol is simple JSON-RPC
|
||||||
|
over HTTP. Handles both JSON and SSE response formats.
|
||||||
|
5. **No new DB tables** — Everything fits in existing `AgentBlock` + `AgentNode` tables
|
||||||
|
|
||||||
|
## Implementation Files
|
||||||
|
|
||||||
|
### New Files
|
||||||
|
- `backend/blocks/mcp/` — MCP block package
|
||||||
|
- `__init__.py`
|
||||||
|
- `block.py` — MCPToolBlock implementation
|
||||||
|
- `client.py` — MCP HTTP client (list_tools, call_tool)
|
||||||
|
- `oauth.py` — MCP OAuth handler for dynamic endpoint discovery
|
||||||
|
- `test_mcp.py` — Unit tests
|
||||||
|
- `test_oauth.py` — OAuth handler tests
|
||||||
|
- `test_integration.py` — Integration tests with local test server
|
||||||
|
- `test_e2e.py` — E2E tests against real MCP servers
|
||||||
|
|
||||||
|
### Modified Files
|
||||||
|
- `backend/integrations/providers.py` — Add `MCP = "mcp"` to ProviderName
|
||||||
|
|
||||||
|
## Dev Loop
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd autogpt_platform/backend
|
||||||
|
poetry run pytest backend/blocks/mcp/test_mcp.py -xvs # Unit tests
|
||||||
|
poetry run pytest backend/blocks/mcp/test_oauth.py -xvs # OAuth tests
|
||||||
|
poetry run pytest backend/blocks/mcp/test_integration.py -xvs # Integration tests
|
||||||
|
poetry run pytest backend/blocks/mcp/ -xvs # All MCP tests
|
||||||
|
```
|
||||||
|
|
||||||
|
## Status
|
||||||
|
|
||||||
|
- [x] Research & Design
|
||||||
|
- [x] Add ProviderName.MCP
|
||||||
|
- [x] Implement MCP client (client.py)
|
||||||
|
- [x] Implement MCPToolBlock (block.py)
|
||||||
|
- [x] Add OAuth2 support (oauth.py)
|
||||||
|
- [x] Write unit tests
|
||||||
|
- [x] Write integration tests
|
||||||
|
- [x] Write E2E tests
|
||||||
|
- [x] Run tests & fix issues
|
||||||
|
- [x] Create PR
|
||||||
@@ -117,7 +117,7 @@ def build_missing_credentials_from_graph(
|
|||||||
preserving all supported credential types for each field.
|
preserving all supported credential types for each field.
|
||||||
"""
|
"""
|
||||||
matched_keys = set(matched_credentials.keys()) if matched_credentials else set()
|
matched_keys = set(matched_credentials.keys()) if matched_credentials else set()
|
||||||
aggregated_fields = graph.regular_credentials_inputs
|
aggregated_fields = graph.aggregate_credentials_inputs()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
field_key: _serialize_missing_credential(field_key, field_info)
|
field_key: _serialize_missing_credential(field_key, field_info)
|
||||||
@@ -244,7 +244,7 @@ async def match_user_credentials_to_graph(
|
|||||||
missing_creds: list[str] = []
|
missing_creds: list[str] = []
|
||||||
|
|
||||||
# Get aggregated credentials requirements from the graph
|
# Get aggregated credentials requirements from the graph
|
||||||
aggregated_creds = graph.regular_credentials_inputs
|
aggregated_creds = graph.aggregate_credentials_inputs()
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Matching credentials for graph {graph.id}: {len(aggregated_creds)} required"
|
f"Matching credentials for graph {graph.id}: {len(aggregated_creds)} required"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,78 +0,0 @@
|
|||||||
"""Tests for chat tools utility functions."""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.data.model import CredentialsFieldInfo
|
|
||||||
|
|
||||||
|
|
||||||
def _make_regular_field() -> CredentialsFieldInfo:
|
|
||||||
return CredentialsFieldInfo.model_validate(
|
|
||||||
{
|
|
||||||
"credentials_provider": ["github"],
|
|
||||||
"credentials_types": ["api_key"],
|
|
||||||
"is_auto_credential": False,
|
|
||||||
},
|
|
||||||
by_alias=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_build_missing_credentials_excludes_auto_creds():
|
|
||||||
"""
|
|
||||||
build_missing_credentials_from_graph() should use regular_credentials_inputs
|
|
||||||
and thus exclude auto_credentials from the "missing" set.
|
|
||||||
"""
|
|
||||||
from backend.api.features.chat.tools.utils import (
|
|
||||||
build_missing_credentials_from_graph,
|
|
||||||
)
|
|
||||||
|
|
||||||
regular_field = _make_regular_field()
|
|
||||||
|
|
||||||
mock_graph = MagicMock()
|
|
||||||
# regular_credentials_inputs should only return the non-auto field
|
|
||||||
mock_graph.regular_credentials_inputs = {
|
|
||||||
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
|
|
||||||
}
|
|
||||||
|
|
||||||
result = build_missing_credentials_from_graph(mock_graph, matched_credentials=None)
|
|
||||||
|
|
||||||
# Should include the regular credential
|
|
||||||
assert "github_api_key" in result
|
|
||||||
# Should NOT include the auto_credential (not in regular_credentials_inputs)
|
|
||||||
assert "google_oauth2" not in result
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_match_user_credentials_excludes_auto_creds():
|
|
||||||
"""
|
|
||||||
match_user_credentials_to_graph() should use regular_credentials_inputs
|
|
||||||
and thus exclude auto_credentials from matching.
|
|
||||||
"""
|
|
||||||
from backend.api.features.chat.tools.utils import match_user_credentials_to_graph
|
|
||||||
|
|
||||||
regular_field = _make_regular_field()
|
|
||||||
|
|
||||||
mock_graph = MagicMock()
|
|
||||||
mock_graph.id = "test-graph"
|
|
||||||
# regular_credentials_inputs returns only non-auto fields
|
|
||||||
mock_graph.regular_credentials_inputs = {
|
|
||||||
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Mock the credentials manager to return no credentials
|
|
||||||
with patch(
|
|
||||||
"backend.api.features.chat.tools.utils.IntegrationCredentialsManager"
|
|
||||||
) as MockCredsMgr:
|
|
||||||
mock_store = AsyncMock()
|
|
||||||
mock_store.get_all_creds.return_value = []
|
|
||||||
MockCredsMgr.return_value.store = mock_store
|
|
||||||
|
|
||||||
matched, missing = await match_user_credentials_to_graph(
|
|
||||||
user_id="test-user", graph=mock_graph
|
|
||||||
)
|
|
||||||
|
|
||||||
# No credentials available, so github should be missing
|
|
||||||
assert len(matched) == 0
|
|
||||||
assert len(missing) == 1
|
|
||||||
assert "github_api_key" in missing[0]
|
|
||||||
@@ -1103,7 +1103,7 @@ async def create_preset_from_graph_execution(
|
|||||||
raise NotFoundError(
|
raise NotFoundError(
|
||||||
f"Graph #{graph_execution.graph_id} not found or accessible"
|
f"Graph #{graph_execution.graph_id} not found or accessible"
|
||||||
)
|
)
|
||||||
elif len(graph.regular_credentials_inputs) > 0:
|
elif len(graph.aggregate_credentials_inputs()) > 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Graph execution #{graph_exec_id} can't be turned into a preset "
|
f"Graph execution #{graph_exec_id} can't be turned into a preset "
|
||||||
"because it was run before this feature existed "
|
"because it was run before this feature existed "
|
||||||
|
|||||||
@@ -478,7 +478,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
webset = aexa.websets.get(id=input_data.external_id)
|
webset = await aexa.websets.get(id=input_data.external_id)
|
||||||
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
webset_result = Webset.model_validate(webset.model_dump(by_alias=True))
|
||||||
|
|
||||||
yield "webset", webset_result
|
yield "webset", webset_result
|
||||||
@@ -494,7 +494,7 @@ class ExaCreateOrFindWebsetBlock(Block):
|
|||||||
count=input_data.search_count,
|
count=input_data.search_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
webset = aexa.websets.create(
|
webset = await aexa.websets.create(
|
||||||
params=CreateWebsetParameters(
|
params=CreateWebsetParameters(
|
||||||
search=search_params,
|
search=search_params,
|
||||||
external_id=input_data.external_id,
|
external_id=input_data.external_id,
|
||||||
@@ -554,7 +554,7 @@ class ExaUpdateWebsetBlock(Block):
|
|||||||
if input_data.metadata is not None:
|
if input_data.metadata is not None:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_webset = aexa.websets.update(id=input_data.webset_id, params=payload)
|
sdk_webset = await aexa.websets.update(id=input_data.webset_id, params=payload)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
sdk_webset.status.value
|
sdk_webset.status.value
|
||||||
@@ -617,7 +617,7 @@ class ExaListWebsetsBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = aexa.websets.list(
|
response = await aexa.websets.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
@@ -678,7 +678,7 @@ class ExaGetWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_webset = aexa.websets.get(id=input_data.webset_id)
|
sdk_webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
sdk_webset.status.value
|
sdk_webset.status.value
|
||||||
@@ -748,7 +748,7 @@ class ExaDeleteWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_webset = aexa.websets.delete(id=input_data.webset_id)
|
deleted_webset = await aexa.websets.delete(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
deleted_webset.status.value
|
deleted_webset.status.value
|
||||||
@@ -798,7 +798,7 @@ class ExaCancelWebsetBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_webset = aexa.websets.cancel(id=input_data.webset_id)
|
canceled_webset = await aexa.websets.cancel(id=input_data.webset_id)
|
||||||
|
|
||||||
status_str = (
|
status_str = (
|
||||||
canceled_webset.status.value
|
canceled_webset.status.value
|
||||||
@@ -968,7 +968,7 @@ class ExaPreviewWebsetBlock(Block):
|
|||||||
entity["description"] = input_data.entity_description
|
entity["description"] = input_data.entity_description
|
||||||
payload["entity"] = entity
|
payload["entity"] = entity
|
||||||
|
|
||||||
sdk_preview = aexa.websets.preview(params=payload)
|
sdk_preview = await aexa.websets.preview(params=payload)
|
||||||
|
|
||||||
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
preview = PreviewWebsetModel.from_sdk(sdk_preview)
|
||||||
|
|
||||||
@@ -1051,7 +1051,7 @@ class ExaWebsetStatusBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
@@ -1185,7 +1185,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Extract basic info
|
# Extract basic info
|
||||||
webset_id = webset.id
|
webset_id = webset.id
|
||||||
@@ -1211,7 +1211,7 @@ class ExaWebsetSummaryBlock(Block):
|
|||||||
total_items = 0
|
total_items = 0
|
||||||
|
|
||||||
if input_data.include_sample_items and input_data.sample_size > 0:
|
if input_data.include_sample_items and input_data.sample_size > 0:
|
||||||
items_response = aexa.websets.items.list(
|
items_response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
sample_items_data = [
|
sample_items_data = [
|
||||||
@@ -1362,7 +1362,7 @@ class ExaWebsetReadyCheckBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset details
|
# Get webset details
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
status = (
|
status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
|
|||||||
@@ -202,7 +202,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_enrichment = aexa.websets.enrichments.create(
|
sdk_enrichment = await aexa.websets.enrichments.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -223,7 +223,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
items_enriched = 0
|
items_enriched = 0
|
||||||
|
|
||||||
while time.time() - poll_start < input_data.polling_timeout:
|
while time.time() - poll_start < input_data.polling_timeout:
|
||||||
current_enrich = aexa.websets.enrichments.get(
|
current_enrich = await aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=enrichment_id
|
webset_id=input_data.webset_id, id=enrichment_id
|
||||||
)
|
)
|
||||||
current_status = (
|
current_status = (
|
||||||
@@ -234,7 +234,7 @@ class ExaCreateEnrichmentBlock(Block):
|
|||||||
|
|
||||||
if current_status in ["completed", "failed", "cancelled"]:
|
if current_status in ["completed", "failed", "cancelled"]:
|
||||||
# Estimate items from webset searches
|
# Estimate items from webset searches
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
if webset.searches:
|
if webset.searches:
|
||||||
for search in webset.searches:
|
for search in webset.searches:
|
||||||
if search.progress:
|
if search.progress:
|
||||||
@@ -329,7 +329,7 @@ class ExaGetEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_enrichment = aexa.websets.enrichments.get(
|
sdk_enrichment = await aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -474,7 +474,7 @@ class ExaDeleteEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_enrichment = aexa.websets.enrichments.delete(
|
deleted_enrichment = await aexa.websets.enrichments.delete(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -525,13 +525,13 @@ class ExaCancelEnrichmentBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_enrichment = aexa.websets.enrichments.cancel(
|
canceled_enrichment = await aexa.websets.enrichments.cancel(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to estimate how many items were enriched before cancellation
|
# Try to estimate how many items were enriched before cancellation
|
||||||
items_enriched = 0
|
items_enriched = 0
|
||||||
items_response = aexa.websets.items.list(
|
items_response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=100
|
webset_id=input_data.webset_id, limit=100
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
def _create_test_mock():
|
def _create_test_mock():
|
||||||
"""Create test mocks for the AsyncExa SDK."""
|
"""Create test mocks for the AsyncExa SDK."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
# Create mock SDK import object
|
# Create mock SDK import object
|
||||||
mock_import = MagicMock()
|
mock_import = MagicMock()
|
||||||
@@ -247,7 +247,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(
|
||||||
imports=MagicMock(create=lambda *args, **kwargs: mock_import)
|
imports=MagicMock(create=AsyncMock(return_value=mock_import))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -294,7 +294,7 @@ class ExaCreateImportBlock(Block):
|
|||||||
if input_data.metadata:
|
if input_data.metadata:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_import = aexa.websets.imports.create(
|
sdk_import = await aexa.websets.imports.create(
|
||||||
params=payload, csv_data=input_data.csv_data
|
params=payload, csv_data=input_data.csv_data
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -360,7 +360,7 @@ class ExaGetImportBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_import = aexa.websets.imports.get(import_id=input_data.import_id)
|
sdk_import = await aexa.websets.imports.get(import_id=input_data.import_id)
|
||||||
|
|
||||||
import_obj = ImportModel.from_sdk(sdk_import)
|
import_obj = ImportModel.from_sdk(sdk_import)
|
||||||
|
|
||||||
@@ -426,7 +426,7 @@ class ExaListImportsBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = aexa.websets.imports.list(
|
response = await aexa.websets.imports.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
@@ -474,7 +474,9 @@ class ExaDeleteImportBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_import = aexa.websets.imports.delete(import_id=input_data.import_id)
|
deleted_import = await aexa.websets.imports.delete(
|
||||||
|
import_id=input_data.import_id
|
||||||
|
)
|
||||||
|
|
||||||
yield "import_id", deleted_import.id
|
yield "import_id", deleted_import.id
|
||||||
yield "success", "true"
|
yield "success", "true"
|
||||||
@@ -573,14 +575,14 @@ class ExaExportWebsetBlock(Block):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create mock iterator
|
# Create async iterator for list_all
|
||||||
mock_items = [mock_item1, mock_item2]
|
async def async_item_iterator(*args, **kwargs):
|
||||||
|
for item in [mock_item1, mock_item2]:
|
||||||
|
yield item
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(items=MagicMock(list_all=async_item_iterator))
|
||||||
items=MagicMock(list_all=lambda *args, **kwargs: iter(mock_items))
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -602,7 +604,7 @@ class ExaExportWebsetBlock(Block):
|
|||||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||||
)
|
)
|
||||||
|
|
||||||
for sdk_item in item_iterator:
|
async for sdk_item in item_iterator:
|
||||||
if len(all_items) >= input_data.max_items:
|
if len(all_items) >= input_data.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ class ExaGetWebsetItemBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_item = aexa.websets.items.get(
|
sdk_item = await aexa.websets.items.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.item_id
|
webset_id=input_data.webset_id, id=input_data.item_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -269,7 +269,7 @@ class ExaListWebsetItemsBlock(Block):
|
|||||||
response = None
|
response = None
|
||||||
|
|
||||||
while time.time() - start_time < input_data.wait_timeout:
|
while time.time() - start_time < input_data.wait_timeout:
|
||||||
response = aexa.websets.items.list(
|
response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
@@ -282,13 +282,13 @@ class ExaListWebsetItemsBlock(Block):
|
|||||||
interval = min(interval * 1.2, 10)
|
interval = min(interval * 1.2, 10)
|
||||||
|
|
||||||
if not response:
|
if not response:
|
||||||
response = aexa.websets.items.list(
|
response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
response = aexa.websets.items.list(
|
response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
@@ -340,7 +340,7 @@ class ExaDeleteWebsetItemBlock(Block):
|
|||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_item = aexa.websets.items.delete(
|
deleted_item = await aexa.websets.items.delete(
|
||||||
webset_id=input_data.webset_id, id=input_data.item_id
|
webset_id=input_data.webset_id, id=input_data.item_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -408,7 +408,7 @@ class ExaBulkWebsetItemsBlock(Block):
|
|||||||
webset_id=input_data.webset_id, limit=input_data.max_items
|
webset_id=input_data.webset_id, limit=input_data.max_items
|
||||||
)
|
)
|
||||||
|
|
||||||
for sdk_item in item_iterator:
|
async for sdk_item in item_iterator:
|
||||||
if len(all_items) >= input_data.max_items:
|
if len(all_items) >= input_data.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -475,7 +475,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
entity_type = "unknown"
|
entity_type = "unknown"
|
||||||
if webset.searches:
|
if webset.searches:
|
||||||
@@ -495,7 +495,7 @@ class ExaWebsetItemsSummaryBlock(Block):
|
|||||||
# Get sample items if requested
|
# Get sample items if requested
|
||||||
sample_items: List[WebsetItemModel] = []
|
sample_items: List[WebsetItemModel] = []
|
||||||
if input_data.sample_size > 0:
|
if input_data.sample_size > 0:
|
||||||
items_response = aexa.websets.items.list(
|
items_response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id, limit=input_data.sample_size
|
webset_id=input_data.webset_id, limit=input_data.sample_size
|
||||||
)
|
)
|
||||||
# Convert to our stable models
|
# Convert to our stable models
|
||||||
@@ -569,7 +569,7 @@ class ExaGetNewItemsBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get items starting from cursor
|
# Get items starting from cursor
|
||||||
response = aexa.websets.items.list(
|
response = await aexa.websets.items.list(
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
cursor=input_data.since_cursor,
|
cursor=input_data.since_cursor,
|
||||||
limit=input_data.max_items,
|
limit=input_data.max_items,
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
def _create_test_mock():
|
def _create_test_mock():
|
||||||
"""Create test mocks for the AsyncExa SDK."""
|
"""Create test mocks for the AsyncExa SDK."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
# Create mock SDK monitor object
|
# Create mock SDK monitor object
|
||||||
mock_monitor = MagicMock()
|
mock_monitor = MagicMock()
|
||||||
@@ -263,7 +263,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
return {
|
return {
|
||||||
"_get_client": lambda *args, **kwargs: MagicMock(
|
"_get_client": lambda *args, **kwargs: MagicMock(
|
||||||
websets=MagicMock(
|
websets=MagicMock(
|
||||||
monitors=MagicMock(create=lambda *args, **kwargs: mock_monitor)
|
monitors=MagicMock(create=AsyncMock(return_value=mock_monitor))
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -320,7 +320,7 @@ class ExaCreateMonitorBlock(Block):
|
|||||||
if input_data.metadata:
|
if input_data.metadata:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_monitor = aexa.websets.monitors.create(params=payload)
|
sdk_monitor = await aexa.websets.monitors.create(params=payload)
|
||||||
|
|
||||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||||
|
|
||||||
@@ -384,7 +384,7 @@ class ExaGetMonitorBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_monitor = aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
sdk_monitor = await aexa.websets.monitors.get(monitor_id=input_data.monitor_id)
|
||||||
|
|
||||||
monitor = MonitorModel.from_sdk(sdk_monitor)
|
monitor = MonitorModel.from_sdk(sdk_monitor)
|
||||||
|
|
||||||
@@ -476,7 +476,7 @@ class ExaUpdateMonitorBlock(Block):
|
|||||||
if input_data.metadata is not None:
|
if input_data.metadata is not None:
|
||||||
payload["metadata"] = input_data.metadata
|
payload["metadata"] = input_data.metadata
|
||||||
|
|
||||||
sdk_monitor = aexa.websets.monitors.update(
|
sdk_monitor = await aexa.websets.monitors.update(
|
||||||
monitor_id=input_data.monitor_id, params=payload
|
monitor_id=input_data.monitor_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -522,7 +522,9 @@ class ExaDeleteMonitorBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
deleted_monitor = aexa.websets.monitors.delete(monitor_id=input_data.monitor_id)
|
deleted_monitor = await aexa.websets.monitors.delete(
|
||||||
|
monitor_id=input_data.monitor_id
|
||||||
|
)
|
||||||
|
|
||||||
yield "monitor_id", deleted_monitor.id
|
yield "monitor_id", deleted_monitor.id
|
||||||
yield "success", "true"
|
yield "success", "true"
|
||||||
@@ -579,7 +581,7 @@ class ExaListMonitorsBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
response = aexa.websets.monitors.list(
|
response = await aexa.websets.monitors.list(
|
||||||
cursor=input_data.cursor,
|
cursor=input_data.cursor,
|
||||||
limit=input_data.limit,
|
limit=input_data.limit,
|
||||||
webset_id=input_data.webset_id,
|
webset_id=input_data.webset_id,
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
WebsetTargetStatus.IDLE,
|
WebsetTargetStatus.IDLE,
|
||||||
WebsetTargetStatus.ANY_COMPLETE,
|
WebsetTargetStatus.ANY_COMPLETE,
|
||||||
]:
|
]:
|
||||||
final_webset = aexa.websets.wait_until_idle(
|
final_webset = await aexa.websets.wait_until_idle(
|
||||||
id=input_data.webset_id,
|
id=input_data.webset_id,
|
||||||
timeout=input_data.timeout,
|
timeout=input_data.timeout,
|
||||||
poll_interval=input_data.check_interval,
|
poll_interval=input_data.check_interval,
|
||||||
@@ -164,7 +164,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
interval = input_data.check_interval
|
interval = input_data.check_interval
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current webset status
|
# Get current webset status
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
current_status = (
|
current_status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
if hasattr(webset.status, "value")
|
if hasattr(webset.status, "value")
|
||||||
@@ -209,7 +209,7 @@ class ExaWaitForWebsetBlock(Block):
|
|||||||
|
|
||||||
# Timeout reached
|
# Timeout reached
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
final_status = (
|
final_status = (
|
||||||
webset.status.value
|
webset.status.value
|
||||||
if hasattr(webset.status, "value")
|
if hasattr(webset.status, "value")
|
||||||
@@ -345,7 +345,7 @@ class ExaWaitForSearchBlock(Block):
|
|||||||
try:
|
try:
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current search status using SDK
|
# Get current search status using SDK
|
||||||
search = aexa.websets.searches.get(
|
search = await aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -401,7 +401,7 @@ class ExaWaitForSearchBlock(Block):
|
|||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# Get last known status
|
# Get last known status
|
||||||
search = aexa.websets.searches.get(
|
search = await aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
final_status = (
|
final_status = (
|
||||||
@@ -503,7 +503,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
try:
|
try:
|
||||||
while time.time() - start_time < input_data.timeout:
|
while time.time() - start_time < input_data.timeout:
|
||||||
# Get current enrichment status using SDK
|
# Get current enrichment status using SDK
|
||||||
enrichment = aexa.websets.enrichments.get(
|
enrichment = await aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -548,7 +548,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# Get last known status
|
# Get last known status
|
||||||
enrichment = aexa.websets.enrichments.get(
|
enrichment = await aexa.websets.enrichments.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
webset_id=input_data.webset_id, id=input_data.enrichment_id
|
||||||
)
|
)
|
||||||
final_status = (
|
final_status = (
|
||||||
@@ -575,7 +575,7 @@ class ExaWaitForEnrichmentBlock(Block):
|
|||||||
) -> tuple[list[SampleEnrichmentModel], int]:
|
) -> tuple[list[SampleEnrichmentModel], int]:
|
||||||
"""Get sample enriched data and count."""
|
"""Get sample enriched data and count."""
|
||||||
# Get a few items to see enrichment results using SDK
|
# Get a few items to see enrichment results using SDK
|
||||||
response = aexa.websets.items.list(webset_id=webset_id, limit=5)
|
response = await aexa.websets.items.list(webset_id=webset_id, limit=5)
|
||||||
|
|
||||||
sample_data: list[SampleEnrichmentModel] = []
|
sample_data: list[SampleEnrichmentModel] = []
|
||||||
enriched_count = 0
|
enriched_count = 0
|
||||||
|
|||||||
@@ -317,7 +317,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
|||||||
|
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_search = aexa.websets.searches.create(
|
sdk_search = await aexa.websets.searches.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -350,7 +350,7 @@ class ExaCreateWebsetSearchBlock(Block):
|
|||||||
poll_start = time.time()
|
poll_start = time.time()
|
||||||
|
|
||||||
while time.time() - poll_start < input_data.polling_timeout:
|
while time.time() - poll_start < input_data.polling_timeout:
|
||||||
current_search = aexa.websets.searches.get(
|
current_search = await aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=search_id
|
webset_id=input_data.webset_id, id=search_id
|
||||||
)
|
)
|
||||||
current_status = (
|
current_status = (
|
||||||
@@ -442,7 +442,7 @@ class ExaGetWebsetSearchBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
sdk_search = aexa.websets.searches.get(
|
sdk_search = await aexa.websets.searches.get(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -523,7 +523,7 @@ class ExaCancelWebsetSearchBlock(Block):
|
|||||||
# Use AsyncExa SDK
|
# Use AsyncExa SDK
|
||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
canceled_search = aexa.websets.searches.cancel(
|
canceled_search = await aexa.websets.searches.cancel(
|
||||||
webset_id=input_data.webset_id, id=input_data.search_id
|
webset_id=input_data.webset_id, id=input_data.search_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -604,7 +604,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
|||||||
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
aexa = AsyncExa(api_key=credentials.api_key.get_secret_value())
|
||||||
|
|
||||||
# Get webset to check existing searches
|
# Get webset to check existing searches
|
||||||
webset = aexa.websets.get(id=input_data.webset_id)
|
webset = await aexa.websets.get(id=input_data.webset_id)
|
||||||
|
|
||||||
# Look for existing search with same query
|
# Look for existing search with same query
|
||||||
existing_search = None
|
existing_search = None
|
||||||
@@ -636,7 +636,7 @@ class ExaFindOrCreateSearchBlock(Block):
|
|||||||
if input_data.entity_type != SearchEntityType.AUTO:
|
if input_data.entity_type != SearchEntityType.AUTO:
|
||||||
payload["entity"] = {"type": input_data.entity_type.value}
|
payload["entity"] = {"type": input_data.entity_type.value}
|
||||||
|
|
||||||
sdk_search = aexa.websets.searches.create(
|
sdk_search = await aexa.websets.searches.create(
|
||||||
webset_id=input_data.webset_id, params=payload
|
webset_id=input_data.webset_id, params=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -596,10 +596,10 @@ def extract_openai_tool_calls(response) -> list[ToolContentBlock] | None:
|
|||||||
|
|
||||||
def get_parallel_tool_calls_param(
|
def get_parallel_tool_calls_param(
|
||||||
llm_model: LlmModel, parallel_tool_calls: bool | None
|
llm_model: LlmModel, parallel_tool_calls: bool | None
|
||||||
):
|
) -> bool | openai.Omit:
|
||||||
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
"""Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
|
||||||
if llm_model.startswith("o") or parallel_tool_calls is None:
|
if llm_model.startswith("o") or parallel_tool_calls is None:
|
||||||
return openai.NOT_GIVEN
|
return openai.omit
|
||||||
return parallel_tool_calls
|
return parallel_tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
265
autogpt_platform/backend/backend/blocks/mcp/block.py
Normal file
265
autogpt_platform/backend/backend/blocks/mcp/block.py
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
"""
|
||||||
|
MCP (Model Context Protocol) Tool Block.
|
||||||
|
|
||||||
|
A single dynamic block that can connect to any MCP server, discover available tools,
|
||||||
|
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
|
||||||
|
dropdown and the input/output schema adapts dynamically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||||
|
from backend.data.block import (
|
||||||
|
Block,
|
||||||
|
BlockCategory,
|
||||||
|
BlockInput,
|
||||||
|
BlockOutput,
|
||||||
|
BlockSchemaInput,
|
||||||
|
BlockSchemaOutput,
|
||||||
|
BlockType,
|
||||||
|
)
|
||||||
|
from backend.data.model import (
|
||||||
|
APIKeyCredentials,
|
||||||
|
CredentialsField,
|
||||||
|
CredentialsMetaInput,
|
||||||
|
OAuth2Credentials,
|
||||||
|
SchemaField,
|
||||||
|
)
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.json import validate_with_jsonschema
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MCPCredentials = APIKeyCredentials | OAuth2Credentials
|
||||||
|
MCPCredentialsInput = CredentialsMetaInput[
|
||||||
|
Literal[ProviderName.MCP], Literal["api_key", "oauth2"]
|
||||||
|
]
|
||||||
|
|
||||||
|
TEST_CREDENTIALS = APIKeyCredentials(
|
||||||
|
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||||
|
provider="mcp",
|
||||||
|
api_key=SecretStr("test-mcp-token"),
|
||||||
|
title="Mock MCP Credentials",
|
||||||
|
)
|
||||||
|
TEST_CREDENTIALS_INPUT = {
|
||||||
|
"provider": TEST_CREDENTIALS.provider,
|
||||||
|
"id": TEST_CREDENTIALS.id,
|
||||||
|
"type": TEST_CREDENTIALS.type,
|
||||||
|
"title": TEST_CREDENTIALS.title,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolBlock(Block):
|
||||||
|
"""
|
||||||
|
A block that connects to an MCP server, lets the user pick a tool,
|
||||||
|
and executes it with dynamic input/output schema.
|
||||||
|
|
||||||
|
The flow:
|
||||||
|
1. User provides an MCP server URL (and optional credentials)
|
||||||
|
2. Frontend calls the backend to get tool list from that URL
|
||||||
|
3. User selects a tool from a dropdown (available_tools)
|
||||||
|
4. The block's input schema updates to reflect the selected tool's parameters
|
||||||
|
5. On execution, the block calls the MCP server to run the tool
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Input(BlockSchemaInput):
|
||||||
|
# -- Static fields (always shown) --
|
||||||
|
credentials: MCPCredentialsInput = CredentialsField(
|
||||||
|
description="Credentials for the MCP server. Use an API key for Bearer "
|
||||||
|
"token auth, or OAuth2 for servers that support it. For public "
|
||||||
|
"servers, create a credential with any placeholder value.",
|
||||||
|
)
|
||||||
|
server_url: str = SchemaField(
|
||||||
|
description="URL of the MCP server (Streamable HTTP endpoint)",
|
||||||
|
placeholder="https://mcp.example.com/mcp",
|
||||||
|
)
|
||||||
|
available_tools: dict[str, Any] = SchemaField(
|
||||||
|
description="Available tools on the MCP server. "
|
||||||
|
"This is populated automatically when a server URL is provided.",
|
||||||
|
default={},
|
||||||
|
hidden=True,
|
||||||
|
)
|
||||||
|
selected_tool: str = SchemaField(
|
||||||
|
description="The MCP tool to execute",
|
||||||
|
placeholder="Select a tool",
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
tool_input_schema: dict[str, Any] = SchemaField(
|
||||||
|
description="JSON Schema for the selected tool's input parameters. "
|
||||||
|
"Populated automatically when a tool is selected.",
|
||||||
|
default={},
|
||||||
|
hidden=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- Dynamic field: actual arguments for the selected tool --
|
||||||
|
tool_arguments: dict[str, Any] = SchemaField(
|
||||||
|
description="Arguments to pass to the selected MCP tool. "
|
||||||
|
"The fields here are defined by the tool's input schema.",
|
||||||
|
default={},
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||||
|
"""Return the tool's input schema so the builder UI renders dynamic fields."""
|
||||||
|
return data.get("tool_input_schema", {})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||||
|
"""Return the current tool_arguments as defaults for the dynamic fields."""
|
||||||
|
return data.get("tool_arguments", {})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||||
|
"""Check which required tool arguments are missing."""
|
||||||
|
required_fields = cls.get_input_schema(data).get("required", [])
|
||||||
|
tool_arguments = data.get("tool_arguments", {})
|
||||||
|
return set(required_fields) - set(tool_arguments)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||||
|
"""Validate tool_arguments against the tool's input schema."""
|
||||||
|
tool_schema = cls.get_input_schema(data)
|
||||||
|
if not tool_schema:
|
||||||
|
return None
|
||||||
|
tool_arguments = data.get("tool_arguments", {})
|
||||||
|
return validate_with_jsonschema(tool_schema, tool_arguments)
|
||||||
|
|
||||||
|
class Output(BlockSchemaOutput):
|
||||||
|
result: Any = SchemaField(description="The result returned by the MCP tool")
|
||||||
|
error: str = SchemaField(description="Error message if the tool call failed")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
|
||||||
|
description="Connect to any MCP server and execute its tools. "
|
||||||
|
"Provide a server URL, select a tool, and pass arguments dynamically.",
|
||||||
|
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||||
|
input_schema=MCPToolBlock.Input,
|
||||||
|
output_schema=MCPToolBlock.Output,
|
||||||
|
block_type=BlockType.STANDARD,
|
||||||
|
test_input={
|
||||||
|
"server_url": "https://mcp.example.com/mcp",
|
||||||
|
"credentials": TEST_CREDENTIALS_INPUT,
|
||||||
|
"selected_tool": "get_weather",
|
||||||
|
"tool_input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
"tool_arguments": {"city": "London"},
|
||||||
|
},
|
||||||
|
test_output=[
|
||||||
|
(
|
||||||
|
"result",
|
||||||
|
{"weather": "sunny", "temperature": 20},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
test_mock={
|
||||||
|
"_call_mcp_tool": lambda *a, **kw: {
|
||||||
|
"weather": "sunny",
|
||||||
|
"temperature": 20,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
test_credentials=TEST_CREDENTIALS,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _call_mcp_tool(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any],
|
||||||
|
auth_token: str | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
|
||||||
|
# Trust the user-configured server URL to allow internal/localhost servers
|
||||||
|
client = MCPClient(
|
||||||
|
server_url,
|
||||||
|
auth_token=auth_token,
|
||||||
|
trusted_origins=[server_url],
|
||||||
|
)
|
||||||
|
await client.initialize()
|
||||||
|
result = await client.call_tool(tool_name, arguments)
|
||||||
|
|
||||||
|
if result.is_error:
|
||||||
|
error_text = ""
|
||||||
|
for item in result.content:
|
||||||
|
if item.get("type") == "text":
|
||||||
|
error_text += item.get("text", "")
|
||||||
|
raise MCPClientError(
|
||||||
|
f"MCP tool '{tool_name}' returned an error: "
|
||||||
|
f"{error_text or 'Unknown error'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract text content from the result
|
||||||
|
output_parts = []
|
||||||
|
for item in result.content:
|
||||||
|
if item.get("type") == "text":
|
||||||
|
text = item.get("text", "")
|
||||||
|
# Try to parse as JSON for structured output
|
||||||
|
try:
|
||||||
|
output_parts.append(json.loads(text))
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
output_parts.append(text)
|
||||||
|
elif item.get("type") == "image":
|
||||||
|
output_parts.append(
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"data": item.get("data"),
|
||||||
|
"mimeType": item.get("mimeType"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
elif item.get("type") == "resource":
|
||||||
|
output_parts.append(item.get("resource", {}))
|
||||||
|
|
||||||
|
# If single result, unwrap
|
||||||
|
if len(output_parts) == 1:
|
||||||
|
return output_parts[0]
|
||||||
|
return output_parts if output_parts else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_auth_token(credentials: MCPCredentials) -> str | None:
|
||||||
|
"""Extract a Bearer token from either API key or OAuth2 credentials."""
|
||||||
|
if isinstance(credentials, OAuth2Credentials):
|
||||||
|
return credentials.access_token.get_secret_value()
|
||||||
|
|
||||||
|
if isinstance(credentials, APIKeyCredentials) and credentials.api_key:
|
||||||
|
token_value = credentials.api_key.get_secret_value()
|
||||||
|
if token_value:
|
||||||
|
return token_value
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def run(
|
||||||
|
self,
|
||||||
|
input_data: Input,
|
||||||
|
*,
|
||||||
|
credentials: MCPCredentials,
|
||||||
|
**kwargs,
|
||||||
|
) -> BlockOutput:
|
||||||
|
if not input_data.server_url:
|
||||||
|
yield "error", "MCP server URL is required"
|
||||||
|
return
|
||||||
|
|
||||||
|
if not input_data.selected_tool:
|
||||||
|
yield "error", "No tool selected. Please select a tool from the dropdown."
|
||||||
|
return
|
||||||
|
|
||||||
|
auth_token = self._extract_auth_token(credentials)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self._call_mcp_tool(
|
||||||
|
server_url=input_data.server_url,
|
||||||
|
tool_name=input_data.selected_tool,
|
||||||
|
arguments=input_data.tool_arguments,
|
||||||
|
auth_token=auth_token,
|
||||||
|
)
|
||||||
|
yield "result", result
|
||||||
|
except MCPClientError as e:
|
||||||
|
yield "error", str(e)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"MCP tool call failed: {e}")
|
||||||
|
yield "error", f"MCP tool call failed: {str(e)}"
|
||||||
316
autogpt_platform/backend/backend/blocks/mcp/client.py
Normal file
316
autogpt_platform/backend/backend/blocks/mcp/client.py
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
"""
|
||||||
|
MCP (Model Context Protocol) HTTP client.
|
||||||
|
|
||||||
|
Implements the MCP Streamable HTTP transport for listing tools and calling tools
|
||||||
|
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
|
||||||
|
|
||||||
|
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
|
||||||
|
|
||||||
|
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MCPTool:
|
||||||
|
"""Represents an MCP tool discovered from a server."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
input_schema: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MCPCallResult:
|
||||||
|
"""Result from calling an MCP tool."""
|
||||||
|
|
||||||
|
content: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
is_error: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClientError(Exception):
|
||||||
|
"""Raised when an MCP protocol error occurs."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClient:
|
||||||
|
"""
|
||||||
|
Async HTTP client for the MCP Streamable HTTP transport.
|
||||||
|
|
||||||
|
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
|
||||||
|
Supports optional Bearer token authentication.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
auth_token: str | None = None,
|
||||||
|
trusted_origins: list[str] | None = None,
|
||||||
|
):
|
||||||
|
self.server_url = server_url.rstrip("/")
|
||||||
|
self.auth_token = auth_token
|
||||||
|
self.trusted_origins = trusted_origins or []
|
||||||
|
self._request_id = 0
|
||||||
|
|
||||||
|
def _next_id(self) -> int:
|
||||||
|
self._request_id += 1
|
||||||
|
return self._request_id
|
||||||
|
|
||||||
|
def _build_headers(self) -> dict[str, str]:
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
}
|
||||||
|
if self.auth_token:
|
||||||
|
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def _build_jsonrpc_request(
|
||||||
|
self, method: str, params: dict[str, Any] | None = None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
req: dict[str, Any] = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": method,
|
||||||
|
"id": self._next_id(),
|
||||||
|
}
|
||||||
|
if params is not None:
|
||||||
|
req["params"] = params
|
||||||
|
return req
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_sse_response(text: str) -> dict[str, Any]:
|
||||||
|
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
|
||||||
|
|
||||||
|
MCP servers may return responses as SSE with format:
|
||||||
|
event: message
|
||||||
|
data: {"jsonrpc":"2.0","result":{...},"id":1}
|
||||||
|
|
||||||
|
We extract the last `data:` line that contains a JSON-RPC response
|
||||||
|
(i.e. has an "id" field), which is the reply to our request.
|
||||||
|
"""
|
||||||
|
last_data: dict[str, Any] | None = None
|
||||||
|
for line in text.splitlines():
|
||||||
|
stripped = line.strip()
|
||||||
|
if stripped.startswith("data:"):
|
||||||
|
payload = stripped[len("data:") :].strip()
|
||||||
|
if not payload:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
parsed = json.loads(payload)
|
||||||
|
# Only keep JSON-RPC responses (have "id"), skip notifications
|
||||||
|
if isinstance(parsed, dict) and "id" in parsed:
|
||||||
|
last_data = parsed
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
continue
|
||||||
|
if last_data is None:
|
||||||
|
raise MCPClientError("No JSON-RPC response found in SSE stream")
|
||||||
|
return last_data
|
||||||
|
|
||||||
|
async def _send_request(
|
||||||
|
self, method: str, params: dict[str, Any] | None = None
|
||||||
|
) -> Any:
|
||||||
|
"""Send a JSON-RPC request to the MCP server and return the result.
|
||||||
|
|
||||||
|
Handles both ``application/json`` and ``text/event-stream`` responses
|
||||||
|
as required by the MCP Streamable HTTP transport specification.
|
||||||
|
"""
|
||||||
|
payload = self._build_jsonrpc_request(method, params)
|
||||||
|
headers = self._build_headers()
|
||||||
|
|
||||||
|
requests = Requests(
|
||||||
|
raise_for_status=True,
|
||||||
|
extra_headers=headers,
|
||||||
|
trusted_origins=self.trusted_origins,
|
||||||
|
)
|
||||||
|
response = await requests.post(self.server_url, json=payload)
|
||||||
|
|
||||||
|
content_type = response.headers.get("content-type", "")
|
||||||
|
if "text/event-stream" in content_type:
|
||||||
|
body = self._parse_sse_response(response.text())
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
body = response.json()
|
||||||
|
except (ValueError, Exception) as e:
|
||||||
|
raise MCPClientError(
|
||||||
|
f"MCP server returned non-JSON response: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Handle JSON-RPC error
|
||||||
|
if "error" in body:
|
||||||
|
error = body["error"]
|
||||||
|
if isinstance(error, dict):
|
||||||
|
raise MCPClientError(
|
||||||
|
f"MCP server error [{error.get('code', '?')}]: "
|
||||||
|
f"{error.get('message', 'Unknown error')}"
|
||||||
|
)
|
||||||
|
raise MCPClientError(f"MCP server error: {error}")
|
||||||
|
|
||||||
|
return body.get("result")
|
||||||
|
|
||||||
|
async def _send_notification(self, method: str) -> None:
|
||||||
|
"""Send a JSON-RPC notification (no id, no response expected)."""
|
||||||
|
headers = self._build_headers()
|
||||||
|
notification = {"jsonrpc": "2.0", "method": method}
|
||||||
|
requests = Requests(
|
||||||
|
raise_for_status=False,
|
||||||
|
extra_headers=headers,
|
||||||
|
trusted_origins=self.trusted_origins,
|
||||||
|
)
|
||||||
|
await requests.post(self.server_url, json=notification)
|
||||||
|
|
||||||
|
async def discover_auth(self) -> dict[str, Any] | None:
|
||||||
|
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
|
||||||
|
|
||||||
|
Returns ``None`` if the server doesn't require auth, otherwise returns
|
||||||
|
a dict with:
|
||||||
|
- ``authorization_servers``: list of authorization server URLs
|
||||||
|
- ``resource``: the resource indicator URL (usually the MCP endpoint)
|
||||||
|
- ``scopes_supported``: optional list of supported scopes
|
||||||
|
|
||||||
|
The caller can then fetch the authorization server metadata to get
|
||||||
|
``authorization_endpoint``, ``token_endpoint``, etc.
|
||||||
|
"""
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
parsed = urlparse(self.server_url)
|
||||||
|
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||||
|
|
||||||
|
# Build candidates for protected-resource metadata (per RFC 9728)
|
||||||
|
path = parsed.path.rstrip("/")
|
||||||
|
candidates = []
|
||||||
|
if path and path != "/":
|
||||||
|
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
|
||||||
|
candidates.append(f"{base}/.well-known/oauth-protected-resource")
|
||||||
|
|
||||||
|
requests = Requests(
|
||||||
|
raise_for_status=False,
|
||||||
|
trusted_origins=self.trusted_origins,
|
||||||
|
)
|
||||||
|
for url in candidates:
|
||||||
|
try:
|
||||||
|
resp = await requests.get(url)
|
||||||
|
if resp.status == 200:
|
||||||
|
data = resp.json()
|
||||||
|
if isinstance(data, dict) and "authorization_servers" in data:
|
||||||
|
return data
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def discover_auth_server_metadata(
|
||||||
|
self, auth_server_url: str
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
|
||||||
|
|
||||||
|
Given an authorization server URL, returns a dict with:
|
||||||
|
- ``authorization_endpoint``
|
||||||
|
- ``token_endpoint``
|
||||||
|
- ``registration_endpoint`` (for dynamic client registration)
|
||||||
|
- ``scopes_supported``
|
||||||
|
- ``code_challenge_methods_supported``
|
||||||
|
- etc.
|
||||||
|
"""
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
parsed = urlparse(auth_server_url)
|
||||||
|
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||||
|
path = parsed.path.rstrip("/")
|
||||||
|
|
||||||
|
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
|
||||||
|
candidates = []
|
||||||
|
if path and path != "/":
|
||||||
|
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
|
||||||
|
candidates.append(f"{base}/.well-known/oauth-authorization-server")
|
||||||
|
candidates.append(f"{base}/.well-known/openid-configuration")
|
||||||
|
|
||||||
|
requests = Requests(
|
||||||
|
raise_for_status=False,
|
||||||
|
trusted_origins=self.trusted_origins,
|
||||||
|
)
|
||||||
|
for url in candidates:
|
||||||
|
try:
|
||||||
|
resp = await requests.get(url)
|
||||||
|
if resp.status == 200:
|
||||||
|
data = resp.json()
|
||||||
|
if isinstance(data, dict) and "authorization_endpoint" in data:
|
||||||
|
return data
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def initialize(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Send the MCP initialize request.
|
||||||
|
|
||||||
|
This is required by the MCP protocol before any other requests.
|
||||||
|
Returns the server's capabilities.
|
||||||
|
"""
|
||||||
|
result = await self._send_request(
|
||||||
|
"initialize",
|
||||||
|
{
|
||||||
|
"protocolVersion": "2025-03-26",
|
||||||
|
"capabilities": {},
|
||||||
|
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Send initialized notification (no response expected)
|
||||||
|
await self._send_notification("notifications/initialized")
|
||||||
|
|
||||||
|
return result or {}
|
||||||
|
|
||||||
|
async def list_tools(self) -> list[MCPTool]:
|
||||||
|
"""
|
||||||
|
Discover available tools from the MCP server.
|
||||||
|
|
||||||
|
Returns a list of MCPTool objects with name, description, and input schema.
|
||||||
|
"""
|
||||||
|
result = await self._send_request("tools/list")
|
||||||
|
if not result or "tools" not in result:
|
||||||
|
return []
|
||||||
|
|
||||||
|
tools = []
|
||||||
|
for tool_data in result["tools"]:
|
||||||
|
tools.append(
|
||||||
|
MCPTool(
|
||||||
|
name=tool_data.get("name", ""),
|
||||||
|
description=tool_data.get("description", ""),
|
||||||
|
input_schema=tool_data.get("inputSchema", {}),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
async def call_tool(
|
||||||
|
self, tool_name: str, arguments: dict[str, Any]
|
||||||
|
) -> MCPCallResult:
|
||||||
|
"""
|
||||||
|
Call a tool on the MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: The name of the tool to call.
|
||||||
|
arguments: The arguments to pass to the tool.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MCPCallResult with the tool's response content.
|
||||||
|
"""
|
||||||
|
result = await self._send_request(
|
||||||
|
"tools/call",
|
||||||
|
{"name": tool_name, "arguments": arguments},
|
||||||
|
)
|
||||||
|
if not result:
|
||||||
|
return MCPCallResult(is_error=True)
|
||||||
|
|
||||||
|
return MCPCallResult(
|
||||||
|
content=result.get("content", []),
|
||||||
|
is_error=result.get("isError", False),
|
||||||
|
)
|
||||||
21
autogpt_platform/backend/backend/blocks/mcp/conftest.py
Normal file
21
autogpt_platform/backend/backend/blocks/mcp/conftest.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""
|
||||||
|
Conftest for MCP block tests.
|
||||||
|
|
||||||
|
Override the session-scoped server and graph_cleanup fixtures from
|
||||||
|
backend/conftest.py so that MCP integration tests don't spin up the
|
||||||
|
full SpinTestServer infrastructure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def server():
|
||||||
|
"""No-op override — MCP tests don't need the full platform server."""
|
||||||
|
yield None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def graph_cleanup(server):
|
||||||
|
"""No-op override — MCP tests don't create graphs."""
|
||||||
|
yield
|
||||||
198
autogpt_platform/backend/backend/blocks/mcp/oauth.py
Normal file
198
autogpt_platform/backend/backend/blocks/mcp/oauth.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
"""
|
||||||
|
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
|
||||||
|
|
||||||
|
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
|
||||||
|
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
|
||||||
|
This handler accepts those endpoints at construction time.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import urllib.parse
|
||||||
|
from typing import ClassVar, Optional
|
||||||
|
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.data.model import OAuth2Credentials
|
||||||
|
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||||
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.request import Requests
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthHandler(BaseOAuthHandler):
|
||||||
|
"""
|
||||||
|
OAuth handler for MCP servers with dynamically-discovered endpoints.
|
||||||
|
|
||||||
|
Construction requires the authorization and token endpoint URLs,
|
||||||
|
which are obtained via MCP OAuth metadata discovery
|
||||||
|
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
|
||||||
|
DEFAULT_SCOPES: ClassVar[list[str]] = []
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
client_id: str,
|
||||||
|
client_secret: str,
|
||||||
|
redirect_uri: str,
|
||||||
|
*,
|
||||||
|
authorize_url: str,
|
||||||
|
token_url: str,
|
||||||
|
revoke_url: str | None = None,
|
||||||
|
resource_url: str | None = None,
|
||||||
|
):
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_secret = client_secret
|
||||||
|
self.redirect_uri = redirect_uri
|
||||||
|
self.authorize_url = authorize_url
|
||||||
|
self.token_url = token_url
|
||||||
|
self.revoke_url = revoke_url
|
||||||
|
self.resource_url = resource_url
|
||||||
|
|
||||||
|
def get_login_url(
|
||||||
|
self,
|
||||||
|
scopes: list[str],
|
||||||
|
state: str,
|
||||||
|
code_challenge: Optional[str],
|
||||||
|
) -> str:
|
||||||
|
scopes = self.handle_default_scopes(scopes)
|
||||||
|
|
||||||
|
params: dict[str, str] = {
|
||||||
|
"response_type": "code",
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
"state": state,
|
||||||
|
}
|
||||||
|
if scopes:
|
||||||
|
params["scope"] = " ".join(scopes)
|
||||||
|
# PKCE (S256) — included when the caller provides a code_challenge
|
||||||
|
if code_challenge:
|
||||||
|
params["code_challenge"] = code_challenge
|
||||||
|
params["code_challenge_method"] = "S256"
|
||||||
|
# MCP spec requires resource indicator (RFC 8707)
|
||||||
|
if self.resource_url:
|
||||||
|
params["resource"] = self.resource_url
|
||||||
|
|
||||||
|
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
|
||||||
|
|
||||||
|
async def exchange_code_for_tokens(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
scopes: list[str],
|
||||||
|
code_verifier: Optional[str],
|
||||||
|
) -> OAuth2Credentials:
|
||||||
|
data: dict[str, str] = {
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
"client_id": self.client_id,
|
||||||
|
}
|
||||||
|
if self.client_secret:
|
||||||
|
data["client_secret"] = self.client_secret
|
||||||
|
if code_verifier:
|
||||||
|
data["code_verifier"] = code_verifier
|
||||||
|
if self.resource_url:
|
||||||
|
data["resource"] = self.resource_url
|
||||||
|
|
||||||
|
response = await Requests(raise_for_status=True).post(
|
||||||
|
self.token_url,
|
||||||
|
data=data,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
)
|
||||||
|
tokens = response.json()
|
||||||
|
|
||||||
|
if "error" in tokens:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
|
||||||
|
)
|
||||||
|
|
||||||
|
now = int(time.time())
|
||||||
|
expires_in = tokens.get("expires_in")
|
||||||
|
|
||||||
|
return OAuth2Credentials(
|
||||||
|
provider=str(self.PROVIDER_NAME),
|
||||||
|
title=None,
|
||||||
|
access_token=SecretStr(tokens["access_token"]),
|
||||||
|
refresh_token=(
|
||||||
|
SecretStr(tokens["refresh_token"])
|
||||||
|
if tokens.get("refresh_token")
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
access_token_expires_at=now + expires_in if expires_in else None,
|
||||||
|
refresh_token_expires_at=None,
|
||||||
|
scopes=scopes,
|
||||||
|
metadata={
|
||||||
|
"mcp_token_url": self.token_url,
|
||||||
|
"mcp_resource_url": self.resource_url,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _refresh_tokens(
|
||||||
|
self, credentials: OAuth2Credentials
|
||||||
|
) -> OAuth2Credentials:
|
||||||
|
if not credentials.refresh_token:
|
||||||
|
raise ValueError("No refresh token available for MCP OAuth credentials")
|
||||||
|
|
||||||
|
data: dict[str, str] = {
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||||
|
"client_id": self.client_id,
|
||||||
|
}
|
||||||
|
if self.client_secret:
|
||||||
|
data["client_secret"] = self.client_secret
|
||||||
|
if self.resource_url:
|
||||||
|
data["resource"] = self.resource_url
|
||||||
|
|
||||||
|
response = await Requests(raise_for_status=True).post(
|
||||||
|
self.token_url,
|
||||||
|
data=data,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
)
|
||||||
|
tokens = response.json()
|
||||||
|
|
||||||
|
if "error" in tokens:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
|
||||||
|
)
|
||||||
|
|
||||||
|
now = int(time.time())
|
||||||
|
expires_in = tokens.get("expires_in")
|
||||||
|
|
||||||
|
return OAuth2Credentials(
|
||||||
|
id=credentials.id,
|
||||||
|
provider=str(self.PROVIDER_NAME),
|
||||||
|
title=credentials.title,
|
||||||
|
access_token=SecretStr(tokens["access_token"]),
|
||||||
|
refresh_token=(
|
||||||
|
SecretStr(str(tokens["refresh_token"]))
|
||||||
|
if tokens.get("refresh_token")
|
||||||
|
else credentials.refresh_token
|
||||||
|
),
|
||||||
|
access_token_expires_at=now + expires_in if expires_in else None,
|
||||||
|
refresh_token_expires_at=credentials.refresh_token_expires_at,
|
||||||
|
scopes=credentials.scopes,
|
||||||
|
metadata=credentials.metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||||
|
if not self.revoke_url:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = {
|
||||||
|
"token": credentials.access_token.get_secret_value(),
|
||||||
|
"token_type_hint": "access_token",
|
||||||
|
"client_id": self.client_id,
|
||||||
|
}
|
||||||
|
await Requests().post(
|
||||||
|
self.revoke_url,
|
||||||
|
data=data,
|
||||||
|
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
|
||||||
|
return False
|
||||||
104
autogpt_platform/backend/backend/blocks/mcp/test_e2e.py
Normal file
104
autogpt_platform/backend/backend/blocks/mcp/test_e2e.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""
|
||||||
|
End-to-end tests against a real public MCP server.
|
||||||
|
|
||||||
|
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
|
||||||
|
which is publicly accessible without authentication and returns SSE responses.
|
||||||
|
|
||||||
|
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
|
||||||
|
independently of the rest of the test suite (they require network access).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.blocks.mcp.client import MCPClient
|
||||||
|
|
||||||
|
# Public MCP server that requires no authentication
|
||||||
|
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.e2e
|
||||||
|
class TestRealMCPServer:
|
||||||
|
"""Tests against the live OpenAI docs MCP server."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize(self):
|
||||||
|
"""Verify we can complete the MCP handshake with a real server."""
|
||||||
|
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||||
|
result = await client.initialize()
|
||||||
|
|
||||||
|
assert result["protocolVersion"] == "2025-03-26"
|
||||||
|
assert "serverInfo" in result
|
||||||
|
assert result["serverInfo"]["name"] == "openai-docs-mcp"
|
||||||
|
assert "tools" in result.get("capabilities", {})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_tools(self):
|
||||||
|
"""Verify we can discover tools from a real MCP server."""
|
||||||
|
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||||
|
await client.initialize()
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
|
assert len(tools) >= 3 # server has at least 5 tools as of writing
|
||||||
|
|
||||||
|
tool_names = {t.name for t in tools}
|
||||||
|
# These tools are documented and should be stable
|
||||||
|
assert "search_openai_docs" in tool_names
|
||||||
|
assert "list_openai_docs" in tool_names
|
||||||
|
assert "fetch_openai_doc" in tool_names
|
||||||
|
|
||||||
|
# Verify schema structure
|
||||||
|
search_tool = next(t for t in tools if t.name == "search_openai_docs")
|
||||||
|
assert "query" in search_tool.input_schema.get("properties", {})
|
||||||
|
assert "query" in search_tool.input_schema.get("required", [])
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_tool_list_api_endpoints(self):
|
||||||
|
"""Call the list_api_endpoints tool and verify we get real data."""
|
||||||
|
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||||
|
await client.initialize()
|
||||||
|
result = await client.call_tool("list_api_endpoints", {})
|
||||||
|
|
||||||
|
assert not result.is_error
|
||||||
|
assert len(result.content) >= 1
|
||||||
|
assert result.content[0]["type"] == "text"
|
||||||
|
|
||||||
|
data = json.loads(result.content[0]["text"])
|
||||||
|
assert "paths" in data or "urls" in data
|
||||||
|
# The OpenAI API should have many endpoints
|
||||||
|
total = data.get("total", len(data.get("paths", [])))
|
||||||
|
assert total > 50
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_tool_search(self):
|
||||||
|
"""Search for docs and verify we get results."""
|
||||||
|
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||||
|
await client.initialize()
|
||||||
|
result = await client.call_tool(
|
||||||
|
"search_openai_docs", {"query": "chat completions", "limit": 3}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not result.is_error
|
||||||
|
assert len(result.content) >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sse_response_handling(self):
|
||||||
|
"""Verify the client correctly handles SSE responses from a real server.
|
||||||
|
|
||||||
|
This is the key test — our local test server returns JSON,
|
||||||
|
but real MCP servers typically return SSE. This proves the
|
||||||
|
SSE parsing works end-to-end.
|
||||||
|
"""
|
||||||
|
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||||
|
# initialize() internally calls _send_request which must parse SSE
|
||||||
|
result = await client.initialize()
|
||||||
|
|
||||||
|
# If we got here without error, SSE parsing works
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert "protocolVersion" in result
|
||||||
|
|
||||||
|
# Also verify list_tools works (another SSE response)
|
||||||
|
tools = await client.list_tools()
|
||||||
|
assert len(tools) > 0
|
||||||
|
assert all(hasattr(t, "name") for t in tools)
|
||||||
367
autogpt_platform/backend/backend/blocks/mcp/test_integration.py
Normal file
367
autogpt_platform/backend/backend/blocks/mcp/test_integration.py
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
|
||||||
|
|
||||||
|
These tests spin up a local MCP test server and run the full client/block flow
|
||||||
|
against it — no mocking, real HTTP requests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from aiohttp import web
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.blocks.mcp.block import MCPToolBlock
|
||||||
|
from backend.blocks.mcp.client import MCPClient
|
||||||
|
from backend.blocks.mcp.test_server import create_test_mcp_app
|
||||||
|
from backend.data.model import APIKeyCredentials
|
||||||
|
|
||||||
|
|
||||||
|
class _MCPTestServer:
|
||||||
|
"""
|
||||||
|
Run an MCP test server in a background thread with its own event loop.
|
||||||
|
This avoids event loop conflicts with pytest-asyncio.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, auth_token: str | None = None):
|
||||||
|
self.auth_token = auth_token
|
||||||
|
self.url: str = ""
|
||||||
|
self._runner: web.AppRunner | None = None
|
||||||
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
self._thread: threading.Thread | None = None
|
||||||
|
self._started = threading.Event()
|
||||||
|
|
||||||
|
def _run(self):
|
||||||
|
self._loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self._loop)
|
||||||
|
self._loop.run_until_complete(self._start())
|
||||||
|
self._started.set()
|
||||||
|
self._loop.run_forever()
|
||||||
|
|
||||||
|
async def _start(self):
|
||||||
|
app = create_test_mcp_app(auth_token=self.auth_token)
|
||||||
|
self._runner = web.AppRunner(app)
|
||||||
|
await self._runner.setup()
|
||||||
|
site = web.TCPSite(self._runner, "127.0.0.1", 0)
|
||||||
|
await site.start()
|
||||||
|
port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
|
||||||
|
self.url = f"http://127.0.0.1:{port}/mcp"
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
if not self._started.wait(timeout=5):
|
||||||
|
raise RuntimeError("MCP test server failed to start within 5 seconds")
|
||||||
|
return self
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
if self._loop and self._runner:
|
||||||
|
asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
|
||||||
|
timeout=5
|
||||||
|
)
|
||||||
|
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||||
|
if self._thread:
|
||||||
|
self._thread.join(timeout=5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def mcp_server():
|
||||||
|
"""Start a local MCP test server in a background thread."""
|
||||||
|
server = _MCPTestServer()
|
||||||
|
server.start()
|
||||||
|
yield server.url
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def mcp_server_with_auth():
|
||||||
|
"""Start a local MCP test server with auth in a background thread."""
|
||||||
|
server = _MCPTestServer(auth_token="test-secret-token")
|
||||||
|
server.start()
|
||||||
|
yield server.url, "test-secret-token"
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
|
||||||
|
"""Create an MCPClient with localhost trusted for integration tests."""
|
||||||
|
return MCPClient(url, auth_token=auth_token, trusted_origins=[url])
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fake_creds(api_key: str = "FAKE_API_KEY") -> APIKeyCredentials:
|
||||||
|
return APIKeyCredentials(
|
||||||
|
id="test-integration",
|
||||||
|
provider="mcp",
|
||||||
|
api_key=SecretStr(api_key),
|
||||||
|
title="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── MCPClient integration tests ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPClientIntegration:
|
||||||
|
"""Test MCPClient against a real local MCP server."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize(self, mcp_server):
|
||||||
|
client = _make_client(mcp_server)
|
||||||
|
result = await client.initialize()
|
||||||
|
|
||||||
|
assert result["protocolVersion"] == "2025-03-26"
|
||||||
|
assert result["serverInfo"]["name"] == "test-mcp-server"
|
||||||
|
assert "tools" in result["capabilities"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_tools(self, mcp_server):
|
||||||
|
client = _make_client(mcp_server)
|
||||||
|
await client.initialize()
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
|
assert len(tools) == 3
|
||||||
|
|
||||||
|
tool_names = {t.name for t in tools}
|
||||||
|
assert tool_names == {"get_weather", "add_numbers", "echo"}
|
||||||
|
|
||||||
|
# Check get_weather schema
|
||||||
|
weather = next(t for t in tools if t.name == "get_weather")
|
||||||
|
assert weather.description == "Get current weather for a city"
|
||||||
|
assert "city" in weather.input_schema["properties"]
|
||||||
|
assert weather.input_schema["required"] == ["city"]
|
||||||
|
|
||||||
|
# Check add_numbers schema
|
||||||
|
add = next(t for t in tools if t.name == "add_numbers")
|
||||||
|
assert "a" in add.input_schema["properties"]
|
||||||
|
assert "b" in add.input_schema["properties"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_tool_get_weather(self, mcp_server):
|
||||||
|
client = _make_client(mcp_server)
|
||||||
|
await client.initialize()
|
||||||
|
result = await client.call_tool("get_weather", {"city": "London"})
|
||||||
|
|
||||||
|
assert not result.is_error
|
||||||
|
assert len(result.content) == 1
|
||||||
|
assert result.content[0]["type"] == "text"
|
||||||
|
|
||||||
|
data = json.loads(result.content[0]["text"])
|
||||||
|
assert data["city"] == "London"
|
||||||
|
assert data["temperature"] == 22
|
||||||
|
assert data["condition"] == "sunny"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_tool_add_numbers(self, mcp_server):
|
||||||
|
client = _make_client(mcp_server)
|
||||||
|
await client.initialize()
|
||||||
|
result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
|
||||||
|
|
||||||
|
assert not result.is_error
|
||||||
|
data = json.loads(result.content[0]["text"])
|
||||||
|
assert data["result"] == 10
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_tool_echo(self, mcp_server):
|
||||||
|
client = _make_client(mcp_server)
|
||||||
|
await client.initialize()
|
||||||
|
result = await client.call_tool("echo", {"message": "Hello MCP!"})
|
||||||
|
|
||||||
|
assert not result.is_error
|
||||||
|
assert result.content[0]["text"] == "Hello MCP!"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_unknown_tool(self, mcp_server):
|
||||||
|
client = _make_client(mcp_server)
|
||||||
|
await client.initialize()
|
||||||
|
result = await client.call_tool("nonexistent_tool", {})
|
||||||
|
|
||||||
|
assert result.is_error
|
||||||
|
assert "Unknown tool" in result.content[0]["text"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_success(self, mcp_server_with_auth):
|
||||||
|
url, token = mcp_server_with_auth
|
||||||
|
client = _make_client(url, auth_token=token)
|
||||||
|
result = await client.initialize()
|
||||||
|
|
||||||
|
assert result["protocolVersion"] == "2025-03-26"
|
||||||
|
|
||||||
|
tools = await client.list_tools()
|
||||||
|
assert len(tools) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_failure(self, mcp_server_with_auth):
|
||||||
|
url, _ = mcp_server_with_auth
|
||||||
|
client = _make_client(url, auth_token="wrong-token")
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
await client.initialize()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth_missing(self, mcp_server_with_auth):
|
||||||
|
url, _ = mcp_server_with_auth
|
||||||
|
client = _make_client(url)
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
await client.initialize()
|
||||||
|
|
||||||
|
|
||||||
|
# ── MCPToolBlock integration tests ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolBlockIntegration:
|
||||||
|
"""Test MCPToolBlock end-to-end against a real local MCP server."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_flow_get_weather(self, mcp_server):
|
||||||
|
"""Full flow: discover tools, select one, execute it."""
|
||||||
|
# Step 1: Discover tools (simulating what the frontend/API would do)
|
||||||
|
client = _make_client(mcp_server)
|
||||||
|
await client.initialize()
|
||||||
|
tools = await client.list_tools()
|
||||||
|
assert len(tools) == 3
|
||||||
|
|
||||||
|
# Step 2: User selects "get_weather" and we get its schema
|
||||||
|
weather_tool = next(t for t in tools if t.name == "get_weather")
|
||||||
|
|
||||||
|
# Step 3: Execute the block with the selected tool
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url=mcp_server,
|
||||||
|
selected_tool="get_weather",
|
||||||
|
tool_input_schema=weather_tool.input_schema,
|
||||||
|
tool_arguments={"city": "Paris"},
|
||||||
|
credentials={ # type: ignore
|
||||||
|
"provider": "mcp",
|
||||||
|
"id": "test",
|
||||||
|
"type": "api_key",
|
||||||
|
"title": "test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, credentials=_make_fake_creds()):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert len(outputs) == 1
|
||||||
|
assert outputs[0][0] == "result"
|
||||||
|
result = outputs[0][1]
|
||||||
|
assert result["city"] == "Paris"
|
||||||
|
assert result["temperature"] == 22
|
||||||
|
assert result["condition"] == "sunny"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_flow_add_numbers(self, mcp_server):
|
||||||
|
"""Full flow for add_numbers tool."""
|
||||||
|
client = _make_client(mcp_server)
|
||||||
|
await client.initialize()
|
||||||
|
tools = await client.list_tools()
|
||||||
|
add_tool = next(t for t in tools if t.name == "add_numbers")
|
||||||
|
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url=mcp_server,
|
||||||
|
selected_tool="add_numbers",
|
||||||
|
tool_input_schema=add_tool.input_schema,
|
||||||
|
tool_arguments={"a": 42, "b": 58},
|
||||||
|
credentials={ # type: ignore
|
||||||
|
"provider": "mcp",
|
||||||
|
"id": "test",
|
||||||
|
"type": "api_key",
|
||||||
|
"title": "test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, credentials=_make_fake_creds()):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert len(outputs) == 1
|
||||||
|
assert outputs[0][0] == "result"
|
||||||
|
assert outputs[0][1]["result"] == 100
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_flow_echo_plain_text(self, mcp_server):
|
||||||
|
"""Verify plain text (non-JSON) responses work."""
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url=mcp_server,
|
||||||
|
selected_tool="echo",
|
||||||
|
tool_input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"message": {"type": "string"}},
|
||||||
|
"required": ["message"],
|
||||||
|
},
|
||||||
|
tool_arguments={"message": "Hello from AutoGPT!"},
|
||||||
|
credentials={ # type: ignore
|
||||||
|
"provider": "mcp",
|
||||||
|
"id": "test",
|
||||||
|
"type": "api_key",
|
||||||
|
"title": "test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, credentials=_make_fake_creds()):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert len(outputs) == 1
|
||||||
|
assert outputs[0][0] == "result"
|
||||||
|
assert outputs[0][1] == "Hello from AutoGPT!"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
|
||||||
|
"""Calling an unknown tool should yield an error output."""
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url=mcp_server,
|
||||||
|
selected_tool="nonexistent_tool",
|
||||||
|
tool_arguments={},
|
||||||
|
credentials={ # type: ignore
|
||||||
|
"provider": "mcp",
|
||||||
|
"id": "test",
|
||||||
|
"type": "api_key",
|
||||||
|
"title": "test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, credentials=_make_fake_creds()):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert len(outputs) == 1
|
||||||
|
assert outputs[0][0] == "error"
|
||||||
|
assert "returned an error" in outputs[0][1]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_flow_with_auth(self, mcp_server_with_auth):
|
||||||
|
"""Full flow with authentication."""
|
||||||
|
url, token = mcp_server_with_auth
|
||||||
|
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url=url,
|
||||||
|
selected_tool="echo",
|
||||||
|
tool_input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"message": {"type": "string"}},
|
||||||
|
"required": ["message"],
|
||||||
|
},
|
||||||
|
tool_arguments={"message": "Authenticated!"},
|
||||||
|
credentials={ # type: ignore
|
||||||
|
"provider": "mcp",
|
||||||
|
"id": "test",
|
||||||
|
"type": "api_key",
|
||||||
|
"title": "test",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(
|
||||||
|
input_data, credentials=_make_fake_creds(api_key=token)
|
||||||
|
):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert len(outputs) == 1
|
||||||
|
assert outputs[0][0] == "result"
|
||||||
|
assert outputs[0][1] == "Authenticated!"
|
||||||
667
autogpt_platform/backend/backend/blocks/mcp/test_mcp.py
Normal file
667
autogpt_platform/backend/backend/blocks/mcp/test_mcp.py
Normal file
@@ -0,0 +1,667 @@
|
|||||||
|
"""
|
||||||
|
Tests for MCP client and MCPToolBlock.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.blocks.mcp.block import (
|
||||||
|
TEST_CREDENTIALS,
|
||||||
|
TEST_CREDENTIALS_INPUT,
|
||||||
|
MCPToolBlock,
|
||||||
|
)
|
||||||
|
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
|
||||||
|
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||||
|
from backend.util.test import execute_block_test
|
||||||
|
|
||||||
|
# ── SSE parsing unit tests ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSEParsing:
|
||||||
|
"""Tests for SSE (text/event-stream) response parsing."""
|
||||||
|
|
||||||
|
def test_parse_sse_simple(self):
|
||||||
|
sse = (
|
||||||
|
"event: message\n"
|
||||||
|
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
body = MCPClient._parse_sse_response(sse)
|
||||||
|
assert body["result"] == {"tools": []}
|
||||||
|
assert body["id"] == 1
|
||||||
|
|
||||||
|
def test_parse_sse_with_notifications(self):
|
||||||
|
"""SSE streams can contain notifications (no id) before the response."""
|
||||||
|
sse = (
|
||||||
|
"event: message\n"
|
||||||
|
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
|
||||||
|
"\n"
|
||||||
|
"event: message\n"
|
||||||
|
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
body = MCPClient._parse_sse_response(sse)
|
||||||
|
assert body["result"] == {"ok": True}
|
||||||
|
assert body["id"] == 2
|
||||||
|
|
||||||
|
def test_parse_sse_error_response(self):
|
||||||
|
sse = (
|
||||||
|
"event: message\n"
|
||||||
|
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
|
||||||
|
)
|
||||||
|
body = MCPClient._parse_sse_response(sse)
|
||||||
|
assert "error" in body
|
||||||
|
assert body["error"]["code"] == -32600
|
||||||
|
|
||||||
|
def test_parse_sse_no_data_raises(self):
|
||||||
|
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
||||||
|
MCPClient._parse_sse_response("event: message\n\n")
|
||||||
|
|
||||||
|
def test_parse_sse_empty_raises(self):
|
||||||
|
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
||||||
|
MCPClient._parse_sse_response("")
|
||||||
|
|
||||||
|
def test_parse_sse_ignores_non_data_lines(self):
|
||||||
|
sse = (
|
||||||
|
": comment line\n"
|
||||||
|
"event: message\n"
|
||||||
|
"id: 123\n"
|
||||||
|
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
body = MCPClient._parse_sse_response(sse)
|
||||||
|
assert body["result"] == "ok"
|
||||||
|
|
||||||
|
def test_parse_sse_uses_last_response(self):
|
||||||
|
"""If multiple responses exist, use the last one."""
|
||||||
|
sse = (
|
||||||
|
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
|
||||||
|
"\n"
|
||||||
|
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
body = MCPClient._parse_sse_response(sse)
|
||||||
|
assert body["result"] == "second"
|
||||||
|
|
||||||
|
|
||||||
|
# ── MCPClient unit tests ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPClient:
|
||||||
|
"""Tests for the MCP HTTP client."""
|
||||||
|
|
||||||
|
def test_build_headers_without_auth(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
headers = client._build_headers()
|
||||||
|
assert "Authorization" not in headers
|
||||||
|
assert headers["Content-Type"] == "application/json"
|
||||||
|
|
||||||
|
def test_build_headers_with_auth(self):
|
||||||
|
client = MCPClient("https://mcp.example.com", auth_token="my-token")
|
||||||
|
headers = client._build_headers()
|
||||||
|
assert headers["Authorization"] == "Bearer my-token"
|
||||||
|
|
||||||
|
def test_build_jsonrpc_request(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
req = client._build_jsonrpc_request("tools/list")
|
||||||
|
assert req["jsonrpc"] == "2.0"
|
||||||
|
assert req["method"] == "tools/list"
|
||||||
|
assert "id" in req
|
||||||
|
assert "params" not in req
|
||||||
|
|
||||||
|
def test_build_jsonrpc_request_with_params(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
req = client._build_jsonrpc_request(
|
||||||
|
"tools/call", {"name": "test", "arguments": {"x": 1}}
|
||||||
|
)
|
||||||
|
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
|
||||||
|
|
||||||
|
def test_request_id_increments(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
req1 = client._build_jsonrpc_request("tools/list")
|
||||||
|
req2 = client._build_jsonrpc_request("tools/list")
|
||||||
|
assert req2["id"] > req1["id"]
|
||||||
|
|
||||||
|
def test_server_url_trailing_slash_stripped(self):
|
||||||
|
client = MCPClient("https://mcp.example.com/mcp/")
|
||||||
|
assert client.server_url == "https://mcp.example.com/mcp"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_request_success(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"result": {"tools": []},
|
||||||
|
"id": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value={"tools": []}):
|
||||||
|
result = await client._send_request("tools/list")
|
||||||
|
assert result == {"tools": []}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_request_error(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
async def mock_send(*args, **kwargs):
|
||||||
|
raise MCPClientError("MCP server error [-32600]: Invalid Request")
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", side_effect=mock_send):
|
||||||
|
with pytest.raises(MCPClientError, match="Invalid Request"):
|
||||||
|
await client._send_request("tools/list")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_tools(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
mock_result = {
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get current weather for a city",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "search",
|
||||||
|
"description": "Search the web",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"query": {"type": "string"}},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value=mock_result):
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
|
assert len(tools) == 2
|
||||||
|
assert tools[0].name == "get_weather"
|
||||||
|
assert tools[0].description == "Get current weather for a city"
|
||||||
|
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
|
||||||
|
assert tools[1].name == "search"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_tools_empty(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value={"tools": []}):
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
|
assert tools == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_tools_none_result(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value=None):
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
|
assert tools == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_tool_success(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
mock_result = {
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
|
||||||
|
],
|
||||||
|
"isError": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value=mock_result):
|
||||||
|
result = await client.call_tool("get_weather", {"city": "London"})
|
||||||
|
|
||||||
|
assert not result.is_error
|
||||||
|
assert len(result.content) == 1
|
||||||
|
assert result.content[0]["type"] == "text"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_tool_error(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
mock_result = {
|
||||||
|
"content": [{"type": "text", "text": "City not found"}],
|
||||||
|
"isError": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value=mock_result):
|
||||||
|
result = await client.call_tool("get_weather", {"city": "???"})
|
||||||
|
|
||||||
|
assert result.is_error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_tool_none_result(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
with patch.object(client, "_send_request", return_value=None):
|
||||||
|
result = await client.call_tool("get_weather", {"city": "London"})
|
||||||
|
|
||||||
|
assert result.is_error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_initialize(self):
|
||||||
|
client = MCPClient("https://mcp.example.com")
|
||||||
|
|
||||||
|
mock_result = {
|
||||||
|
"protocolVersion": "2025-03-26",
|
||||||
|
"capabilities": {"tools": {}},
|
||||||
|
"serverInfo": {"name": "test-server", "version": "1.0.0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
|
||||||
|
patch.object(client, "_send_notification") as mock_notif,
|
||||||
|
):
|
||||||
|
result = await client.initialize()
|
||||||
|
|
||||||
|
mock_req.assert_called_once()
|
||||||
|
mock_notif.assert_called_once_with("notifications/initialized")
|
||||||
|
assert result["protocolVersion"] == "2025-03-26"
|
||||||
|
|
||||||
|
|
||||||
|
# ── MCPToolBlock unit tests ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolBlock:
|
||||||
|
"""Tests for the MCPToolBlock."""
|
||||||
|
|
||||||
|
def test_block_instantiation(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
||||||
|
assert block.name == "MCPToolBlock"
|
||||||
|
|
||||||
|
def test_input_schema_has_required_fields(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
schema = block.input_schema.jsonschema()
|
||||||
|
props = schema.get("properties", {})
|
||||||
|
assert "server_url" in props
|
||||||
|
assert "selected_tool" in props
|
||||||
|
assert "tool_arguments" in props
|
||||||
|
assert "credentials" in props
|
||||||
|
|
||||||
|
def test_output_schema(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
schema = block.output_schema.jsonschema()
|
||||||
|
props = schema.get("properties", {})
|
||||||
|
assert "result" in props
|
||||||
|
assert "error" in props
|
||||||
|
|
||||||
|
def test_get_input_schema_with_tool_schema(self):
|
||||||
|
tool_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"query": {"type": "string"}},
|
||||||
|
"required": ["query"],
|
||||||
|
}
|
||||||
|
data = {"tool_input_schema": tool_schema}
|
||||||
|
result = MCPToolBlock.Input.get_input_schema(data)
|
||||||
|
assert result == tool_schema
|
||||||
|
|
||||||
|
def test_get_input_schema_without_tool_schema(self):
|
||||||
|
result = MCPToolBlock.Input.get_input_schema({})
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_get_input_defaults(self):
|
||||||
|
data = {"tool_arguments": {"city": "London"}}
|
||||||
|
result = MCPToolBlock.Input.get_input_defaults(data)
|
||||||
|
assert result == {"city": "London"}
|
||||||
|
|
||||||
|
def test_get_missing_input(self):
|
||||||
|
data = {
|
||||||
|
"tool_input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {"type": "string"},
|
||||||
|
"units": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["city", "units"],
|
||||||
|
},
|
||||||
|
"tool_arguments": {"city": "London"},
|
||||||
|
}
|
||||||
|
missing = MCPToolBlock.Input.get_missing_input(data)
|
||||||
|
assert missing == {"units"}
|
||||||
|
|
||||||
|
def test_get_missing_input_all_present(self):
|
||||||
|
data = {
|
||||||
|
"tool_input_schema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
"tool_arguments": {"city": "London"},
|
||||||
|
}
|
||||||
|
missing = MCPToolBlock.Input.get_missing_input(data)
|
||||||
|
assert missing == set()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_with_mock(self):
|
||||||
|
"""Test the block using the built-in test infrastructure."""
|
||||||
|
block = MCPToolBlock()
|
||||||
|
await execute_block_test(block)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_missing_server_url(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="",
|
||||||
|
selected_tool="test",
|
||||||
|
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||||
|
)
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
|
||||||
|
outputs.append((name, data))
|
||||||
|
assert outputs == [("error", "MCP server URL is required")]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_missing_tool(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="",
|
||||||
|
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||||
|
)
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
|
||||||
|
outputs.append((name, data))
|
||||||
|
assert outputs == [
|
||||||
|
("error", "No tool selected. Please select a tool from the dropdown.")
|
||||||
|
]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_success(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="get_weather",
|
||||||
|
tool_input_schema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"city": {"type": "string"}},
|
||||||
|
},
|
||||||
|
tool_arguments={"city": "London"},
|
||||||
|
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_call(*args, **kwargs):
|
||||||
|
return {"temp": 20, "city": "London"}
|
||||||
|
|
||||||
|
block._call_mcp_tool = mock_call # type: ignore
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert len(outputs) == 1
|
||||||
|
assert outputs[0][0] == "result"
|
||||||
|
assert outputs[0][1] == {"temp": 20, "city": "London"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_mcp_error(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="bad_tool",
|
||||||
|
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_call(*args, **kwargs):
|
||||||
|
raise MCPClientError("Tool not found")
|
||||||
|
|
||||||
|
block._call_mcp_tool = mock_call # type: ignore
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert outputs[0][0] == "error"
|
||||||
|
assert "Tool not found" in outputs[0][1]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_tool_parses_json_text(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": '{"temp": 20}'},
|
||||||
|
],
|
||||||
|
is_error=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
result = await block._call_mcp_tool(
|
||||||
|
"https://mcp.example.com", "test_tool", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"temp": 20}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_tool_plain_text(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "Hello, world!"},
|
||||||
|
],
|
||||||
|
is_error=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
result = await block._call_mcp_tool(
|
||||||
|
"https://mcp.example.com", "test_tool", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "Hello, world!"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_tool_multiple_content(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[
|
||||||
|
{"type": "text", "text": "Part 1"},
|
||||||
|
{"type": "text", "text": '{"part": 2}'},
|
||||||
|
],
|
||||||
|
is_error=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
result = await block._call_mcp_tool(
|
||||||
|
"https://mcp.example.com", "test_tool", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == ["Part 1", {"part": 2}]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_tool_error_result(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[{"type": "text", "text": "Something went wrong"}],
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
with pytest.raises(MCPClientError, match="returned an error"):
|
||||||
|
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_mcp_tool_image_content(self):
|
||||||
|
block = MCPToolBlock()
|
||||||
|
|
||||||
|
mock_result = MCPCallResult(
|
||||||
|
content=[
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"data": "base64data==",
|
||||||
|
"mimeType": "image/png",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
is_error=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_init(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mock_call(self, name, args):
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(MCPClient, "initialize", mock_init),
|
||||||
|
patch.object(MCPClient, "call_tool", mock_call),
|
||||||
|
):
|
||||||
|
result = await block._call_mcp_tool(
|
||||||
|
"https://mcp.example.com", "test_tool", {}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
"type": "image",
|
||||||
|
"data": "base64data==",
|
||||||
|
"mimeType": "image/png",
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_sends_api_key_credentials(self):
|
||||||
|
"""Ensure non-empty API keys are sent to the MCP server."""
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="test_tool",
|
||||||
|
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
creds = APIKeyCredentials(
|
||||||
|
id="test-id",
|
||||||
|
provider="mcp",
|
||||||
|
api_key=SecretStr("real-api-key"),
|
||||||
|
title="Real",
|
||||||
|
)
|
||||||
|
|
||||||
|
captured_tokens = []
|
||||||
|
|
||||||
|
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||||
|
captured_tokens.append(auth_token)
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
block._call_mcp_tool = mock_call # type: ignore
|
||||||
|
|
||||||
|
async for _ in block.run(input_data, credentials=creds):
|
||||||
|
pass
|
||||||
|
|
||||||
|
assert captured_tokens == ["real-api-key"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth2 credential support tests ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPOAuth2Support:
|
||||||
|
"""Tests for OAuth2 credential support in MCPToolBlock."""
|
||||||
|
|
||||||
|
def test_extract_auth_token_from_api_key(self):
|
||||||
|
creds = APIKeyCredentials(
|
||||||
|
id="test",
|
||||||
|
provider="mcp",
|
||||||
|
api_key=SecretStr("my-api-key"),
|
||||||
|
title="test",
|
||||||
|
)
|
||||||
|
token = MCPToolBlock._extract_auth_token(creds)
|
||||||
|
assert token == "my-api-key"
|
||||||
|
|
||||||
|
def test_extract_auth_token_from_oauth2(self):
|
||||||
|
creds = OAuth2Credentials(
|
||||||
|
id="test",
|
||||||
|
provider="mcp",
|
||||||
|
access_token=SecretStr("oauth2-access-token"),
|
||||||
|
scopes=["read"],
|
||||||
|
title="test",
|
||||||
|
)
|
||||||
|
token = MCPToolBlock._extract_auth_token(creds)
|
||||||
|
assert token == "oauth2-access-token"
|
||||||
|
|
||||||
|
def test_extract_auth_token_empty_skipped(self):
|
||||||
|
creds = APIKeyCredentials(
|
||||||
|
id="test",
|
||||||
|
provider="mcp",
|
||||||
|
api_key=SecretStr(""),
|
||||||
|
title="test",
|
||||||
|
)
|
||||||
|
token = MCPToolBlock._extract_auth_token(creds)
|
||||||
|
assert token is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_with_oauth2_credentials(self):
|
||||||
|
"""Verify the block can run with OAuth2 credentials."""
|
||||||
|
block = MCPToolBlock()
|
||||||
|
input_data = MCPToolBlock.Input(
|
||||||
|
server_url="https://mcp.example.com/mcp",
|
||||||
|
selected_tool="test_tool",
|
||||||
|
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||||
|
)
|
||||||
|
|
||||||
|
oauth2_creds = OAuth2Credentials(
|
||||||
|
id="test-id",
|
||||||
|
provider="mcp",
|
||||||
|
access_token=SecretStr("real-oauth2-token"),
|
||||||
|
scopes=["read", "write"],
|
||||||
|
title="MCP OAuth",
|
||||||
|
)
|
||||||
|
|
||||||
|
captured_tokens = []
|
||||||
|
|
||||||
|
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||||
|
captured_tokens.append(auth_token)
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
block._call_mcp_tool = mock_call # type: ignore
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
async for name, data in block.run(input_data, credentials=oauth2_creds):
|
||||||
|
outputs.append((name, data))
|
||||||
|
|
||||||
|
assert captured_tokens == ["real-oauth2-token"]
|
||||||
|
assert outputs == [("result", {"status": "ok"})]
|
||||||
242
autogpt_platform/backend/backend/blocks/mcp/test_oauth.py
Normal file
242
autogpt_platform/backend/backend/blocks/mcp/test_oauth.py
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
"""
|
||||||
|
Tests for MCP OAuth handler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from backend.blocks.mcp.client import MCPClient
|
||||||
|
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||||
|
from backend.data.model import OAuth2Credentials
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
|
||||||
|
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.status = status
|
||||||
|
resp.ok = 200 <= status < 300
|
||||||
|
resp.json.return_value = json_data
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPOAuthHandler:
|
||||||
|
"""Tests for the MCPOAuthHandler."""
|
||||||
|
|
||||||
|
def _make_handler(self, **overrides) -> MCPOAuthHandler:
|
||||||
|
defaults = {
|
||||||
|
"client_id": "test-client-id",
|
||||||
|
"client_secret": "test-client-secret",
|
||||||
|
"redirect_uri": "https://app.example.com/callback",
|
||||||
|
"authorize_url": "https://auth.example.com/authorize",
|
||||||
|
"token_url": "https://auth.example.com/token",
|
||||||
|
}
|
||||||
|
defaults.update(overrides)
|
||||||
|
return MCPOAuthHandler(**defaults)
|
||||||
|
|
||||||
|
def test_get_login_url_basic(self):
|
||||||
|
handler = self._make_handler()
|
||||||
|
url = handler.get_login_url(
|
||||||
|
scopes=["read", "write"],
|
||||||
|
state="random-state-token",
|
||||||
|
code_challenge="S256-challenge-value",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "https://auth.example.com/authorize?" in url
|
||||||
|
assert "response_type=code" in url
|
||||||
|
assert "client_id=test-client-id" in url
|
||||||
|
assert "state=random-state-token" in url
|
||||||
|
assert "code_challenge=S256-challenge-value" in url
|
||||||
|
assert "code_challenge_method=S256" in url
|
||||||
|
assert "scope=read+write" in url
|
||||||
|
|
||||||
|
def test_get_login_url_with_resource(self):
|
||||||
|
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
|
||||||
|
url = handler.get_login_url(
|
||||||
|
scopes=[], state="state", code_challenge="challenge"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "resource=https" in url
|
||||||
|
|
||||||
|
def test_get_login_url_without_pkce(self):
|
||||||
|
handler = self._make_handler()
|
||||||
|
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
|
||||||
|
|
||||||
|
assert "code_challenge" not in url
|
||||||
|
assert "code_challenge_method" not in url
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exchange_code_for_tokens(self):
|
||||||
|
handler = self._make_handler()
|
||||||
|
|
||||||
|
resp = _mock_response(
|
||||||
|
{
|
||||||
|
"access_token": "new-access-token",
|
||||||
|
"refresh_token": "new-refresh-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||||
|
instance = MockRequests.return_value
|
||||||
|
instance.post = AsyncMock(return_value=resp)
|
||||||
|
|
||||||
|
creds = await handler.exchange_code_for_tokens(
|
||||||
|
code="auth-code",
|
||||||
|
scopes=["read"],
|
||||||
|
code_verifier="pkce-verifier",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(creds, OAuth2Credentials)
|
||||||
|
assert creds.access_token.get_secret_value() == "new-access-token"
|
||||||
|
assert creds.refresh_token is not None
|
||||||
|
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
|
||||||
|
assert creds.scopes == ["read"]
|
||||||
|
assert creds.access_token_expires_at is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_tokens(self):
|
||||||
|
handler = self._make_handler()
|
||||||
|
|
||||||
|
existing_creds = OAuth2Credentials(
|
||||||
|
id="existing-id",
|
||||||
|
provider="mcp",
|
||||||
|
access_token=SecretStr("old-token"),
|
||||||
|
refresh_token=SecretStr("old-refresh"),
|
||||||
|
scopes=["read"],
|
||||||
|
title="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = _mock_response(
|
||||||
|
{
|
||||||
|
"access_token": "refreshed-token",
|
||||||
|
"refresh_token": "new-refresh",
|
||||||
|
"expires_in": 3600,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||||
|
instance = MockRequests.return_value
|
||||||
|
instance.post = AsyncMock(return_value=resp)
|
||||||
|
|
||||||
|
refreshed = await handler._refresh_tokens(existing_creds)
|
||||||
|
|
||||||
|
assert refreshed.id == "existing-id"
|
||||||
|
assert refreshed.access_token.get_secret_value() == "refreshed-token"
|
||||||
|
assert refreshed.refresh_token is not None
|
||||||
|
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_tokens_no_refresh_token(self):
|
||||||
|
handler = self._make_handler()
|
||||||
|
|
||||||
|
creds = OAuth2Credentials(
|
||||||
|
provider="mcp",
|
||||||
|
access_token=SecretStr("token"),
|
||||||
|
scopes=["read"],
|
||||||
|
title="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No refresh token"):
|
||||||
|
await handler._refresh_tokens(creds)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_tokens_no_url(self):
|
||||||
|
handler = self._make_handler(revoke_url=None)
|
||||||
|
|
||||||
|
creds = OAuth2Credentials(
|
||||||
|
provider="mcp",
|
||||||
|
access_token=SecretStr("token"),
|
||||||
|
scopes=[],
|
||||||
|
title="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await handler.revoke_tokens(creds)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_tokens_with_url(self):
|
||||||
|
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
|
||||||
|
|
||||||
|
creds = OAuth2Credentials(
|
||||||
|
provider="mcp",
|
||||||
|
access_token=SecretStr("token"),
|
||||||
|
scopes=[],
|
||||||
|
title="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = _mock_response({}, status=200)
|
||||||
|
|
||||||
|
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||||
|
instance = MockRequests.return_value
|
||||||
|
instance.post = AsyncMock(return_value=resp)
|
||||||
|
|
||||||
|
result = await handler.revoke_tokens(creds)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPClientDiscovery:
|
||||||
|
"""Tests for MCPClient OAuth metadata discovery."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discover_auth_found(self):
|
||||||
|
client = MCPClient("https://mcp.example.com/mcp")
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"authorization_servers": ["https://auth.example.com"],
|
||||||
|
"resource": "https://mcp.example.com/mcp",
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = _mock_response(metadata, status=200)
|
||||||
|
|
||||||
|
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||||
|
instance = MockRequests.return_value
|
||||||
|
instance.get = AsyncMock(return_value=resp)
|
||||||
|
|
||||||
|
result = await client.discover_auth()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["authorization_servers"] == ["https://auth.example.com"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discover_auth_not_found(self):
|
||||||
|
client = MCPClient("https://mcp.example.com/mcp")
|
||||||
|
|
||||||
|
resp = _mock_response({}, status=404)
|
||||||
|
|
||||||
|
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||||
|
instance = MockRequests.return_value
|
||||||
|
instance.get = AsyncMock(return_value=resp)
|
||||||
|
|
||||||
|
result = await client.discover_auth()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discover_auth_server_metadata(self):
|
||||||
|
client = MCPClient("https://mcp.example.com/mcp")
|
||||||
|
|
||||||
|
server_metadata = {
|
||||||
|
"issuer": "https://auth.example.com",
|
||||||
|
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||||
|
"token_endpoint": "https://auth.example.com/token",
|
||||||
|
"registration_endpoint": "https://auth.example.com/register",
|
||||||
|
"code_challenge_methods_supported": ["S256"],
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = _mock_response(server_metadata, status=200)
|
||||||
|
|
||||||
|
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||||
|
instance = MockRequests.return_value
|
||||||
|
instance.get = AsyncMock(return_value=resp)
|
||||||
|
|
||||||
|
result = await client.discover_auth_server_metadata(
|
||||||
|
"https://auth.example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
|
||||||
|
assert result["token_endpoint"] == "https://auth.example.com/token"
|
||||||
162
autogpt_platform/backend/backend/blocks/mcp/test_server.py
Normal file
162
autogpt_platform/backend/backend/blocks/mcp/test_server.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""
|
||||||
|
Minimal MCP server for integration testing.
|
||||||
|
|
||||||
|
Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
|
||||||
|
with a few sample tools. Runs on localhost with a random available port.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Sample tools this test server exposes
|
||||||
|
TEST_TOOLS = [
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get current weather for a city",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"city": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "City name",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["city"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "add_numbers",
|
||||||
|
"description": "Add two numbers together",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"a": {"type": "number", "description": "First number"},
|
||||||
|
"b": {"type": "number", "description": "Second number"},
|
||||||
|
},
|
||||||
|
"required": ["a", "b"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "echo",
|
||||||
|
"description": "Echo back the input message",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"message": {"type": "string", "description": "Message to echo"},
|
||||||
|
},
|
||||||
|
"required": ["message"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_initialize(params: dict) -> dict:
|
||||||
|
return {
|
||||||
|
"protocolVersion": "2025-03-26",
|
||||||
|
"capabilities": {"tools": {"listChanged": False}},
|
||||||
|
"serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_tools_list(params: dict) -> dict:
|
||||||
|
return {"tools": TEST_TOOLS}
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_tools_call(params: dict) -> dict:
|
||||||
|
tool_name = params.get("name", "")
|
||||||
|
arguments = params.get("arguments", {})
|
||||||
|
|
||||||
|
if tool_name == "get_weather":
|
||||||
|
city = arguments.get("city", "Unknown")
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": json.dumps(
|
||||||
|
{"city": city, "temperature": 22, "condition": "sunny"}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
elif tool_name == "add_numbers":
|
||||||
|
a = arguments.get("a", 0)
|
||||||
|
b = arguments.get("b", 0)
|
||||||
|
return {
|
||||||
|
"content": [{"type": "text", "text": json.dumps({"result": a + b})}],
|
||||||
|
}
|
||||||
|
|
||||||
|
elif tool_name == "echo":
|
||||||
|
message = arguments.get("message", "")
|
||||||
|
return {
|
||||||
|
"content": [{"type": "text", "text": message}],
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
|
||||||
|
"isError": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
HANDLERS = {
|
||||||
|
"initialize": _handle_initialize,
|
||||||
|
"tools/list": _handle_tools_list,
|
||||||
|
"tools/call": _handle_tools_call,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_mcp_request(request: web.Request) -> web.Response:
|
||||||
|
"""Handle incoming MCP JSON-RPC 2.0 requests."""
|
||||||
|
# Check auth if configured
|
||||||
|
expected_token = request.app.get("auth_token")
|
||||||
|
if expected_token:
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if auth_header != f"Bearer {expected_token}":
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"error": {"code": -32001, "message": "Unauthorized"},
|
||||||
|
"id": None,
|
||||||
|
},
|
||||||
|
status=401,
|
||||||
|
)
|
||||||
|
|
||||||
|
body = await request.json()
|
||||||
|
|
||||||
|
# Handle notifications (no id field) — just acknowledge
|
||||||
|
if "id" not in body:
|
||||||
|
return web.Response(status=202)
|
||||||
|
|
||||||
|
method = body.get("method", "")
|
||||||
|
params = body.get("params", {})
|
||||||
|
request_id = body.get("id")
|
||||||
|
|
||||||
|
handler = HANDLERS.get(method)
|
||||||
|
if not handler:
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"error": {
|
||||||
|
"code": -32601,
|
||||||
|
"message": f"Method not found: {method}",
|
||||||
|
},
|
||||||
|
"id": request_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = handler(params)
|
||||||
|
return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
|
||||||
|
"""Create an aiohttp app that acts as an MCP server."""
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_post("/mcp", handle_mcp_request)
|
||||||
|
if auth_token:
|
||||||
|
app["auth_token"] = auth_token
|
||||||
|
return app
|
||||||
@@ -319,8 +319,6 @@ class BlockSchema(BaseModel):
|
|||||||
"credentials_provider": [config.get("provider", "google")],
|
"credentials_provider": [config.get("provider", "google")],
|
||||||
"credentials_types": [config.get("type", "oauth2")],
|
"credentials_types": [config.get("type", "oauth2")],
|
||||||
"credentials_scopes": config.get("scopes"),
|
"credentials_scopes": config.get("scopes"),
|
||||||
"is_auto_credential": True,
|
|
||||||
"input_field_name": info["field_name"],
|
|
||||||
}
|
}
|
||||||
result[kwarg_name] = CredentialsFieldInfo.model_validate(
|
result[kwarg_name] = CredentialsFieldInfo.model_validate(
|
||||||
auto_schema, by_alias=True
|
auto_schema, by_alias=True
|
||||||
|
|||||||
@@ -447,7 +447,8 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def credentials_input_schema(self) -> dict[str, Any]:
|
def credentials_input_schema(self) -> dict[str, Any]:
|
||||||
graph_credentials_inputs = self.regular_credentials_inputs
|
graph_credentials_inputs = self.aggregate_credentials_inputs()
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
f"Combined credentials input fields for graph #{self.id} ({self.name}): "
|
||||||
f"{graph_credentials_inputs}"
|
f"{graph_credentials_inputs}"
|
||||||
@@ -603,28 +604,6 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
for key, (field_info, node_field_pairs) in combined.items()
|
for key, (field_info, node_field_pairs) in combined.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
|
||||||
def regular_credentials_inputs(
|
|
||||||
self,
|
|
||||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
|
|
||||||
"""Credentials that need explicit user mapping (CredentialsMetaInput fields)."""
|
|
||||||
return {
|
|
||||||
k: v
|
|
||||||
for k, v in self.aggregate_credentials_inputs().items()
|
|
||||||
if not v[0].is_auto_credential
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def auto_credentials_inputs(
|
|
||||||
self,
|
|
||||||
) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]:
|
|
||||||
"""Credentials embedded in file fields (_credentials_id), resolved at execution time."""
|
|
||||||
return {
|
|
||||||
k: v
|
|
||||||
for k, v in self.aggregate_credentials_inputs().items()
|
|
||||||
if v[0].is_auto_credential
|
|
||||||
}
|
|
||||||
|
|
||||||
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
def reassign_ids(self, user_id: str, reassign_graph_id: bool = False):
|
||||||
"""
|
"""
|
||||||
Reassigns all IDs in the graph to new UUIDs.
|
Reassigns all IDs in the graph to new UUIDs.
|
||||||
@@ -675,16 +654,6 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
) and graph_id in graph_id_map:
|
) and graph_id in graph_id_map:
|
||||||
node.input_default["graph_id"] = graph_id_map[graph_id]
|
node.input_default["graph_id"] = graph_id_map[graph_id]
|
||||||
|
|
||||||
# Clear auto-credentials references (e.g., _credentials_id in
|
|
||||||
# GoogleDriveFile fields) so the new user must re-authenticate
|
|
||||||
# with their own account
|
|
||||||
for node in graph.nodes:
|
|
||||||
if not node.input_default:
|
|
||||||
continue
|
|
||||||
for key, value in node.input_default.items():
|
|
||||||
if isinstance(value, dict) and "_credentials_id" in value:
|
|
||||||
del value["_credentials_id"]
|
|
||||||
|
|
||||||
def validate_graph(
|
def validate_graph(
|
||||||
self,
|
self,
|
||||||
for_run: bool = False,
|
for_run: bool = False,
|
||||||
|
|||||||
@@ -463,328 +463,3 @@ def test_node_credentials_optional_with_other_metadata():
|
|||||||
assert node.credentials_optional is True
|
assert node.credentials_optional is True
|
||||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||||
assert node.metadata["customized_name"] == "My Custom Node"
|
assert node.metadata["customized_name"] == "My Custom Node"
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Tests for _reassign_ids credential clearing (Fix 3: SECRT-1772)
|
|
||||||
def test_combine_preserves_is_auto_credential_flag():
|
|
||||||
"""
|
|
||||||
CredentialsFieldInfo.combine() must propagate is_auto_credential and
|
|
||||||
input_field_name to the combined result. Regression test for reviewer
|
|
||||||
finding that combine() dropped these fields.
|
|
||||||
"""
|
|
||||||
from backend.data.model import CredentialsFieldInfo
|
|
||||||
|
|
||||||
auto_field = CredentialsFieldInfo.model_validate(
|
|
||||||
{
|
|
||||||
"credentials_provider": ["google"],
|
|
||||||
"credentials_types": ["oauth2"],
|
|
||||||
"credentials_scopes": ["drive.readonly"],
|
|
||||||
"is_auto_credential": True,
|
|
||||||
"input_field_name": "spreadsheet",
|
|
||||||
},
|
|
||||||
by_alias=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# combine() takes *args of (field_info, key) tuples
|
|
||||||
combined = CredentialsFieldInfo.combine(
|
|
||||||
(auto_field, ("node-1", "credentials")),
|
|
||||||
(auto_field, ("node-2", "credentials")),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(combined) == 1
|
|
||||||
group_key = next(iter(combined))
|
|
||||||
combined_info, combined_keys = combined[group_key]
|
|
||||||
|
|
||||||
assert combined_info.is_auto_credential is True
|
|
||||||
assert combined_info.input_field_name == "spreadsheet"
|
|
||||||
assert combined_keys == {("node-1", "credentials"), ("node-2", "credentials")}
|
|
||||||
|
|
||||||
|
|
||||||
def test_combine_preserves_regular_credential_defaults():
|
|
||||||
"""Regular credentials should have is_auto_credential=False after combine()."""
|
|
||||||
from backend.data.model import CredentialsFieldInfo
|
|
||||||
|
|
||||||
regular_field = CredentialsFieldInfo.model_validate(
|
|
||||||
{
|
|
||||||
"credentials_provider": ["github"],
|
|
||||||
"credentials_types": ["api_key"],
|
|
||||||
"is_auto_credential": False,
|
|
||||||
},
|
|
||||||
by_alias=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
combined = CredentialsFieldInfo.combine(
|
|
||||||
(regular_field, ("node-1", "credentials")),
|
|
||||||
)
|
|
||||||
|
|
||||||
group_key = next(iter(combined))
|
|
||||||
combined_info, _ = combined[group_key]
|
|
||||||
|
|
||||||
assert combined_info.is_auto_credential is False
|
|
||||||
assert combined_info.input_field_name is None
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def test_reassign_ids_clears_credentials_id():
|
|
||||||
"""
|
|
||||||
[SECRT-1772] _reassign_ids should clear _credentials_id from
|
|
||||||
GoogleDriveFile-style input_default fields so forked agents
|
|
||||||
don't retain the original creator's credential references.
|
|
||||||
"""
|
|
||||||
from backend.data.graph import GraphModel
|
|
||||||
|
|
||||||
node = Node(
|
|
||||||
id="node-1",
|
|
||||||
block_id=StoreValueBlock().id,
|
|
||||||
input_default={
|
|
||||||
"spreadsheet": {
|
|
||||||
"_credentials_id": "original-cred-id",
|
|
||||||
"id": "file-123",
|
|
||||||
"name": "test.xlsx",
|
|
||||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
|
||||||
"url": "https://docs.google.com/spreadsheets/d/file-123",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
graph = Graph(
|
|
||||||
id="test-graph",
|
|
||||||
name="Test",
|
|
||||||
description="Test",
|
|
||||||
nodes=[node],
|
|
||||||
links=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
|
||||||
|
|
||||||
# _credentials_id key should be removed (not set to None) so that
|
|
||||||
# _acquire_auto_credentials correctly errors instead of treating it as chained data
|
|
||||||
assert "_credentials_id" not in graph.nodes[0].input_default["spreadsheet"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_reassign_ids_preserves_non_credential_fields():
|
|
||||||
"""
|
|
||||||
Regression guard: _reassign_ids should NOT modify non-credential fields
|
|
||||||
like name, mimeType, id, url.
|
|
||||||
"""
|
|
||||||
from backend.data.graph import GraphModel
|
|
||||||
|
|
||||||
node = Node(
|
|
||||||
id="node-1",
|
|
||||||
block_id=StoreValueBlock().id,
|
|
||||||
input_default={
|
|
||||||
"spreadsheet": {
|
|
||||||
"_credentials_id": "cred-abc",
|
|
||||||
"id": "file-123",
|
|
||||||
"name": "test.xlsx",
|
|
||||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
|
||||||
"url": "https://docs.google.com/spreadsheets/d/file-123",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
graph = Graph(
|
|
||||||
id="test-graph",
|
|
||||||
name="Test",
|
|
||||||
description="Test",
|
|
||||||
nodes=[node],
|
|
||||||
links=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
|
||||||
|
|
||||||
field = graph.nodes[0].input_default["spreadsheet"]
|
|
||||||
assert field["id"] == "file-123"
|
|
||||||
assert field["name"] == "test.xlsx"
|
|
||||||
assert field["mimeType"] == "application/vnd.google-apps.spreadsheet"
|
|
||||||
assert field["url"] == "https://docs.google.com/spreadsheets/d/file-123"
|
|
||||||
|
|
||||||
|
|
||||||
def test_reassign_ids_handles_no_credentials():
|
|
||||||
"""
|
|
||||||
Regression guard: _reassign_ids should not error when input_default
|
|
||||||
has no dict fields with _credentials_id.
|
|
||||||
"""
|
|
||||||
from backend.data.graph import GraphModel
|
|
||||||
|
|
||||||
node = Node(
|
|
||||||
id="node-1",
|
|
||||||
block_id=StoreValueBlock().id,
|
|
||||||
input_default={
|
|
||||||
"input": "some value",
|
|
||||||
"another_input": 42,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
graph = Graph(
|
|
||||||
id="test-graph",
|
|
||||||
name="Test",
|
|
||||||
description="Test",
|
|
||||||
nodes=[node],
|
|
||||||
links=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
|
||||||
|
|
||||||
# Should not error, fields unchanged
|
|
||||||
assert graph.nodes[0].input_default["input"] == "some value"
|
|
||||||
assert graph.nodes[0].input_default["another_input"] == 42
|
|
||||||
|
|
||||||
|
|
||||||
def test_reassign_ids_handles_multiple_credential_fields():
|
|
||||||
"""
|
|
||||||
[SECRT-1772] When a node has multiple dict fields with _credentials_id,
|
|
||||||
ALL of them should be cleared.
|
|
||||||
"""
|
|
||||||
from backend.data.graph import GraphModel
|
|
||||||
|
|
||||||
node = Node(
|
|
||||||
id="node-1",
|
|
||||||
block_id=StoreValueBlock().id,
|
|
||||||
input_default={
|
|
||||||
"spreadsheet": {
|
|
||||||
"_credentials_id": "cred-1",
|
|
||||||
"id": "file-1",
|
|
||||||
"name": "file1.xlsx",
|
|
||||||
},
|
|
||||||
"doc_file": {
|
|
||||||
"_credentials_id": "cred-2",
|
|
||||||
"id": "file-2",
|
|
||||||
"name": "file2.docx",
|
|
||||||
},
|
|
||||||
"plain_input": "not a dict",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
graph = Graph(
|
|
||||||
id="test-graph",
|
|
||||||
name="Test",
|
|
||||||
description="Test",
|
|
||||||
nodes=[node],
|
|
||||||
links=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
GraphModel._reassign_ids(graph, user_id="new-user", graph_id_map={})
|
|
||||||
|
|
||||||
assert "_credentials_id" not in graph.nodes[0].input_default["spreadsheet"]
|
|
||||||
assert "_credentials_id" not in graph.nodes[0].input_default["doc_file"]
|
|
||||||
assert graph.nodes[0].input_default["plain_input"] == "not a dict"
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Tests for discriminate() field propagation
|
|
||||||
def test_discriminate_preserves_is_auto_credential_flag():
|
|
||||||
"""
|
|
||||||
CredentialsFieldInfo.discriminate() must propagate is_auto_credential and
|
|
||||||
input_field_name to the discriminated result. Regression test for
|
|
||||||
discriminate() dropping these fields (same class of bug as combine()).
|
|
||||||
"""
|
|
||||||
from backend.data.model import CredentialsFieldInfo
|
|
||||||
|
|
||||||
auto_field = CredentialsFieldInfo.model_validate(
|
|
||||||
{
|
|
||||||
"credentials_provider": ["google", "openai"],
|
|
||||||
"credentials_types": ["oauth2"],
|
|
||||||
"credentials_scopes": ["drive.readonly"],
|
|
||||||
"is_auto_credential": True,
|
|
||||||
"input_field_name": "spreadsheet",
|
|
||||||
"discriminator": "model",
|
|
||||||
"discriminator_mapping": {"gpt-4": "openai", "gemini": "google"},
|
|
||||||
},
|
|
||||||
by_alias=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
discriminated = auto_field.discriminate("gemini")
|
|
||||||
|
|
||||||
assert discriminated.is_auto_credential is True
|
|
||||||
assert discriminated.input_field_name == "spreadsheet"
|
|
||||||
assert discriminated.provider == frozenset(["google"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_discriminate_preserves_regular_credential_defaults():
|
|
||||||
"""Regular credentials should have is_auto_credential=False after discriminate()."""
|
|
||||||
from backend.data.model import CredentialsFieldInfo
|
|
||||||
|
|
||||||
regular_field = CredentialsFieldInfo.model_validate(
|
|
||||||
{
|
|
||||||
"credentials_provider": ["google", "openai"],
|
|
||||||
"credentials_types": ["api_key"],
|
|
||||||
"is_auto_credential": False,
|
|
||||||
"discriminator": "model",
|
|
||||||
"discriminator_mapping": {"gpt-4": "openai", "gemini": "google"},
|
|
||||||
},
|
|
||||||
by_alias=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
discriminated = regular_field.discriminate("gpt-4")
|
|
||||||
|
|
||||||
assert discriminated.is_auto_credential is False
|
|
||||||
assert discriminated.input_field_name is None
|
|
||||||
assert discriminated.provider == frozenset(["openai"])
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Tests for credentials_input_schema excluding auto_credentials
|
|
||||||
def test_credentials_input_schema_excludes_auto_creds():
|
|
||||||
"""
|
|
||||||
GraphModel.credentials_input_schema should exclude auto_credentials
|
|
||||||
(is_auto_credential=True) from the schema. Auto_credentials are
|
|
||||||
transparently resolved at execution time via file picker data.
|
|
||||||
"""
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from unittest.mock import PropertyMock, patch
|
|
||||||
|
|
||||||
from backend.data.graph import GraphModel, NodeModel
|
|
||||||
from backend.data.model import CredentialsFieldInfo
|
|
||||||
|
|
||||||
regular_field_info = CredentialsFieldInfo.model_validate(
|
|
||||||
{
|
|
||||||
"credentials_provider": ["github"],
|
|
||||||
"credentials_types": ["api_key"],
|
|
||||||
"is_auto_credential": False,
|
|
||||||
},
|
|
||||||
by_alias=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
graph = GraphModel(
|
|
||||||
id="test-graph",
|
|
||||||
version=1,
|
|
||||||
name="Test",
|
|
||||||
description="Test",
|
|
||||||
user_id="test-user",
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
nodes=[
|
|
||||||
NodeModel(
|
|
||||||
id="node-1",
|
|
||||||
block_id=StoreValueBlock().id,
|
|
||||||
input_default={},
|
|
||||||
graph_id="test-graph",
|
|
||||||
graph_version=1,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
links=[],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock regular_credentials_inputs to return only the non-auto field (3-tuple)
|
|
||||||
regular_only = {
|
|
||||||
"github_credentials": (
|
|
||||||
regular_field_info,
|
|
||||||
{("node-1", "credentials")},
|
|
||||||
True,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(
|
|
||||||
type(graph),
|
|
||||||
"regular_credentials_inputs",
|
|
||||||
new_callable=PropertyMock,
|
|
||||||
return_value=regular_only,
|
|
||||||
):
|
|
||||||
schema = graph.credentials_input_schema
|
|
||||||
field_names = set(schema.get("properties", {}).keys())
|
|
||||||
# Should include regular credential but NOT auto_credential
|
|
||||||
assert "github_credentials" in field_names
|
|
||||||
assert "google_credentials" not in field_names
|
|
||||||
|
|||||||
@@ -571,8 +571,6 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
discriminator: Optional[str] = None
|
discriminator: Optional[str] = None
|
||||||
discriminator_mapping: Optional[dict[str, CP]] = None
|
discriminator_mapping: Optional[dict[str, CP]] = None
|
||||||
discriminator_values: set[Any] = Field(default_factory=set)
|
discriminator_values: set[Any] = Field(default_factory=set)
|
||||||
is_auto_credential: bool = False
|
|
||||||
input_field_name: Optional[str] = None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def combine(
|
def combine(
|
||||||
@@ -653,9 +651,6 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
+ "_credentials"
|
+ "_credentials"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Propagate is_auto_credential from the combined field.
|
|
||||||
# All fields in a group should share the same is_auto_credential
|
|
||||||
# value since auto and regular credentials serve different purposes.
|
|
||||||
result[group_key] = (
|
result[group_key] = (
|
||||||
CredentialsFieldInfo[CP, CT](
|
CredentialsFieldInfo[CP, CT](
|
||||||
credentials_provider=combined.provider,
|
credentials_provider=combined.provider,
|
||||||
@@ -664,8 +659,6 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
discriminator=combined.discriminator,
|
discriminator=combined.discriminator,
|
||||||
discriminator_mapping=combined.discriminator_mapping,
|
discriminator_mapping=combined.discriminator_mapping,
|
||||||
discriminator_values=set(all_discriminator_values),
|
discriminator_values=set(all_discriminator_values),
|
||||||
is_auto_credential=combined.is_auto_credential,
|
|
||||||
input_field_name=combined.input_field_name,
|
|
||||||
),
|
),
|
||||||
combined_keys,
|
combined_keys,
|
||||||
)
|
)
|
||||||
@@ -691,8 +684,6 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
discriminator=self.discriminator,
|
discriminator=self.discriminator,
|
||||||
discriminator_mapping=self.discriminator_mapping,
|
discriminator_mapping=self.discriminator_mapping,
|
||||||
discriminator_values=self.discriminator_values,
|
discriminator_values=self.discriminator_values,
|
||||||
is_auto_credential=self.is_auto_credential,
|
|
||||||
input_field_name=self.input_field_name,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -172,81 +172,6 @@ def execute_graph(
|
|||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
async def _acquire_auto_credentials(
|
|
||||||
input_model: type[BlockSchema],
|
|
||||||
input_data: dict[str, Any],
|
|
||||||
creds_manager: "IntegrationCredentialsManager",
|
|
||||||
user_id: str,
|
|
||||||
) -> tuple[dict[str, Any], list[AsyncRedisLock]]:
|
|
||||||
"""
|
|
||||||
Resolve auto_credentials from GoogleDriveFileField-style inputs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(extra_exec_kwargs, locks): kwargs to inject into block execution, and
|
|
||||||
credential locks to release after execution completes.
|
|
||||||
"""
|
|
||||||
extra_exec_kwargs: dict[str, Any] = {}
|
|
||||||
locks: list[AsyncRedisLock] = []
|
|
||||||
|
|
||||||
# NOTE: If a block ever has multiple auto-credential fields, a ValueError
|
|
||||||
# on a later field will strand locks acquired for earlier fields. They'll
|
|
||||||
# auto-expire via Redis TTL, but add a try/except to release partial locks
|
|
||||||
# if that becomes a real scenario.
|
|
||||||
for kwarg_name, info in input_model.get_auto_credentials_fields().items():
|
|
||||||
field_name = info["field_name"]
|
|
||||||
field_data = input_data.get(field_name)
|
|
||||||
|
|
||||||
if field_data and isinstance(field_data, dict):
|
|
||||||
# Check if _credentials_id key exists in the field data
|
|
||||||
if "_credentials_id" in field_data:
|
|
||||||
cred_id = field_data["_credentials_id"]
|
|
||||||
if cred_id:
|
|
||||||
# Credential ID provided - acquire credentials
|
|
||||||
provider = info.get("config", {}).get(
|
|
||||||
"provider", "external service"
|
|
||||||
)
|
|
||||||
file_name = field_data.get("name", "selected file")
|
|
||||||
try:
|
|
||||||
credentials, lock = await creds_manager.acquire(
|
|
||||||
user_id, cred_id
|
|
||||||
)
|
|
||||||
locks.append(lock)
|
|
||||||
extra_exec_kwargs[kwarg_name] = credentials
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError(
|
|
||||||
f"{provider.capitalize()} credentials for "
|
|
||||||
f"'{file_name}' in field '{field_name}' are not "
|
|
||||||
f"available in your account. "
|
|
||||||
f"This can happen if the agent was created by another "
|
|
||||||
f"user or the credentials were deleted. "
|
|
||||||
f"Please open the agent in the builder and re-select "
|
|
||||||
f"the file to authenticate with your own account."
|
|
||||||
)
|
|
||||||
# else: _credentials_id is explicitly None, skip (chained data)
|
|
||||||
else:
|
|
||||||
# _credentials_id key missing entirely - this is an error
|
|
||||||
provider = info.get("config", {}).get("provider", "external service")
|
|
||||||
file_name = field_data.get("name", "selected file")
|
|
||||||
raise ValueError(
|
|
||||||
f"Authentication missing for '{file_name}' in field "
|
|
||||||
f"'{field_name}'. Please re-select the file to authenticate "
|
|
||||||
f"with {provider.capitalize()}."
|
|
||||||
)
|
|
||||||
elif field_data is None and field_name not in input_data:
|
|
||||||
# Field not in input_data at all = connected from upstream block, skip
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# field_data is None/empty but key IS in input_data = user didn't select
|
|
||||||
provider = info.get("config", {}).get("provider", "external service")
|
|
||||||
raise ValueError(
|
|
||||||
f"No file selected for '{field_name}'. "
|
|
||||||
f"Please select a file to provide "
|
|
||||||
f"{provider.capitalize()} authentication."
|
|
||||||
)
|
|
||||||
|
|
||||||
return extra_exec_kwargs, locks
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_node(
|
async def execute_node(
|
||||||
node: Node,
|
node: Node,
|
||||||
data: NodeExecutionEntry,
|
data: NodeExecutionEntry,
|
||||||
@@ -346,14 +271,41 @@ async def execute_node(
|
|||||||
extra_exec_kwargs[field_name] = credentials
|
extra_exec_kwargs[field_name] = credentials
|
||||||
|
|
||||||
# Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
|
# Handle auto-generated credentials (e.g., from GoogleDriveFileInput)
|
||||||
auto_extra_kwargs, auto_locks = await _acquire_auto_credentials(
|
for kwarg_name, info in input_model.get_auto_credentials_fields().items():
|
||||||
input_model=input_model,
|
field_name = info["field_name"]
|
||||||
input_data=input_data,
|
field_data = input_data.get(field_name)
|
||||||
creds_manager=creds_manager,
|
if field_data and isinstance(field_data, dict):
|
||||||
user_id=user_id,
|
# Check if _credentials_id key exists in the field data
|
||||||
)
|
if "_credentials_id" in field_data:
|
||||||
extra_exec_kwargs.update(auto_extra_kwargs)
|
cred_id = field_data["_credentials_id"]
|
||||||
creds_locks.extend(auto_locks)
|
if cred_id:
|
||||||
|
# Credential ID provided - acquire credentials
|
||||||
|
provider = info.get("config", {}).get(
|
||||||
|
"provider", "external service"
|
||||||
|
)
|
||||||
|
file_name = field_data.get("name", "selected file")
|
||||||
|
try:
|
||||||
|
credentials, lock = await creds_manager.acquire(
|
||||||
|
user_id, cred_id
|
||||||
|
)
|
||||||
|
creds_locks.append(lock)
|
||||||
|
extra_exec_kwargs[kwarg_name] = credentials
|
||||||
|
except ValueError:
|
||||||
|
# Credential was deleted or doesn't exist
|
||||||
|
raise ValueError(
|
||||||
|
f"Authentication expired for '{file_name}' in field '{field_name}'. "
|
||||||
|
f"The saved {provider.capitalize()} credentials no longer exist. "
|
||||||
|
f"Please re-select the file to re-authenticate."
|
||||||
|
)
|
||||||
|
# else: _credentials_id is explicitly None, skip credentials (for chained data)
|
||||||
|
else:
|
||||||
|
# _credentials_id key missing entirely - this is an error
|
||||||
|
provider = info.get("config", {}).get("provider", "external service")
|
||||||
|
file_name = field_data.get("name", "selected file")
|
||||||
|
raise ValueError(
|
||||||
|
f"Authentication missing for '{file_name}' in field '{field_name}'. "
|
||||||
|
f"Please re-select the file to authenticate with {provider.capitalize()}."
|
||||||
|
)
|
||||||
|
|
||||||
output_size = 0
|
output_size = 0
|
||||||
|
|
||||||
|
|||||||
@@ -1,320 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for auto_credentials handling in execute_node().
|
|
||||||
|
|
||||||
These test the _acquire_auto_credentials() helper function extracted from
|
|
||||||
execute_node() (manager.py lines 273-308).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pytest_mock import MockerFixture
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def google_drive_file_data():
|
|
||||||
return {
|
|
||||||
"valid": {
|
|
||||||
"_credentials_id": "cred-id-123",
|
|
||||||
"id": "file-123",
|
|
||||||
"name": "test.xlsx",
|
|
||||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
|
||||||
},
|
|
||||||
"chained": {
|
|
||||||
"_credentials_id": None,
|
|
||||||
"id": "file-456",
|
|
||||||
"name": "chained.xlsx",
|
|
||||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
|
||||||
},
|
|
||||||
"missing_key": {
|
|
||||||
"id": "file-789",
|
|
||||||
"name": "bad.xlsx",
|
|
||||||
"mimeType": "application/vnd.google-apps.spreadsheet",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_input_model(mocker: MockerFixture):
|
|
||||||
"""Create a mock input model with get_auto_credentials_fields() returning one field."""
|
|
||||||
input_model = mocker.MagicMock()
|
|
||||||
input_model.get_auto_credentials_fields.return_value = {
|
|
||||||
"credentials": {
|
|
||||||
"field_name": "spreadsheet",
|
|
||||||
"config": {
|
|
||||||
"provider": "google",
|
|
||||||
"type": "oauth2",
|
|
||||||
"scopes": ["https://www.googleapis.com/auth/drive.readonly"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return input_model
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_creds_manager(mocker: MockerFixture):
|
|
||||||
manager = mocker.AsyncMock()
|
|
||||||
mock_lock = mocker.AsyncMock()
|
|
||||||
mock_creds = mocker.MagicMock()
|
|
||||||
mock_creds.id = "cred-id-123"
|
|
||||||
mock_creds.provider = "google"
|
|
||||||
manager.acquire.return_value = (mock_creds, mock_lock)
|
|
||||||
return manager, mock_creds, mock_lock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auto_credentials_happy_path(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
google_drive_file_data,
|
|
||||||
mock_input_model,
|
|
||||||
mock_creds_manager,
|
|
||||||
):
|
|
||||||
"""When field_data has a valid _credentials_id, credentials should be acquired."""
|
|
||||||
from backend.executor.manager import _acquire_auto_credentials
|
|
||||||
|
|
||||||
manager, mock_creds, mock_lock = mock_creds_manager
|
|
||||||
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
|
||||||
|
|
||||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
|
||||||
input_model=mock_input_model,
|
|
||||||
input_data=input_data,
|
|
||||||
creds_manager=manager,
|
|
||||||
user_id="user-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
manager.acquire.assert_called_once_with("user-1", "cred-id-123")
|
|
||||||
assert extra_kwargs["credentials"] == mock_creds
|
|
||||||
assert mock_lock in locks
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auto_credentials_field_none_static_raises(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
mock_input_model,
|
|
||||||
mock_creds_manager,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
[THE BUG FIX TEST — OPEN-2895]
|
|
||||||
When field_data is None and the key IS in input_data (user didn't select a file),
|
|
||||||
should raise ValueError instead of silently skipping.
|
|
||||||
"""
|
|
||||||
from backend.executor.manager import _acquire_auto_credentials
|
|
||||||
|
|
||||||
manager, _, _ = mock_creds_manager
|
|
||||||
# Key is present but value is None = user didn't select a file
|
|
||||||
input_data = {"spreadsheet": None}
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="No file selected"):
|
|
||||||
await _acquire_auto_credentials(
|
|
||||||
input_model=mock_input_model,
|
|
||||||
input_data=input_data,
|
|
||||||
creds_manager=manager,
|
|
||||||
user_id="user-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auto_credentials_field_absent_skips(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
mock_input_model,
|
|
||||||
mock_creds_manager,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
When the field key is NOT in input_data at all (upstream connection),
|
|
||||||
should skip without error.
|
|
||||||
"""
|
|
||||||
from backend.executor.manager import _acquire_auto_credentials
|
|
||||||
|
|
||||||
manager, _, _ = mock_creds_manager
|
|
||||||
# Key not present = connected from upstream block
|
|
||||||
input_data = {}
|
|
||||||
|
|
||||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
|
||||||
input_model=mock_input_model,
|
|
||||||
input_data=input_data,
|
|
||||||
creds_manager=manager,
|
|
||||||
user_id="user-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
manager.acquire.assert_not_called()
|
|
||||||
assert "credentials" not in extra_kwargs
|
|
||||||
assert locks == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auto_credentials_chained_cred_id_none(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
google_drive_file_data,
|
|
||||||
mock_input_model,
|
|
||||||
mock_creds_manager,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
When _credentials_id is explicitly None (chained data from upstream),
|
|
||||||
should skip credential acquisition.
|
|
||||||
"""
|
|
||||||
from backend.executor.manager import _acquire_auto_credentials
|
|
||||||
|
|
||||||
manager, _, _ = mock_creds_manager
|
|
||||||
input_data = {"spreadsheet": google_drive_file_data["chained"]}
|
|
||||||
|
|
||||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
|
||||||
input_model=mock_input_model,
|
|
||||||
input_data=input_data,
|
|
||||||
creds_manager=manager,
|
|
||||||
user_id="user-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
manager.acquire.assert_not_called()
|
|
||||||
assert "credentials" not in extra_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auto_credentials_missing_cred_id_key_raises(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
google_drive_file_data,
|
|
||||||
mock_input_model,
|
|
||||||
mock_creds_manager,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
When _credentials_id key is missing entirely from field_data dict,
|
|
||||||
should raise ValueError.
|
|
||||||
"""
|
|
||||||
from backend.executor.manager import _acquire_auto_credentials
|
|
||||||
|
|
||||||
manager, _, _ = mock_creds_manager
|
|
||||||
input_data = {"spreadsheet": google_drive_file_data["missing_key"]}
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Authentication missing"):
|
|
||||||
await _acquire_auto_credentials(
|
|
||||||
input_model=mock_input_model,
|
|
||||||
input_data=input_data,
|
|
||||||
creds_manager=manager,
|
|
||||||
user_id="user-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auto_credentials_ownership_mismatch_error(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
google_drive_file_data,
|
|
||||||
mock_input_model,
|
|
||||||
mock_creds_manager,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
[SECRT-1772] When acquire() raises ValueError (credential belongs to another user),
|
|
||||||
the error message should mention 'not available' (not 'expired').
|
|
||||||
"""
|
|
||||||
from backend.executor.manager import _acquire_auto_credentials
|
|
||||||
|
|
||||||
manager, _, _ = mock_creds_manager
|
|
||||||
manager.acquire.side_effect = ValueError(
|
|
||||||
"Credentials #cred-id-123 for user #user-2 not found"
|
|
||||||
)
|
|
||||||
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="not available in your account"):
|
|
||||||
await _acquire_auto_credentials(
|
|
||||||
input_model=mock_input_model,
|
|
||||||
input_data=input_data,
|
|
||||||
creds_manager=manager,
|
|
||||||
user_id="user-2",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auto_credentials_deleted_credential_error(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
google_drive_file_data,
|
|
||||||
mock_input_model,
|
|
||||||
mock_creds_manager,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
[SECRT-1772] When acquire() raises ValueError (credential was deleted),
|
|
||||||
the error message should mention 'not available' (not 'expired').
|
|
||||||
"""
|
|
||||||
from backend.executor.manager import _acquire_auto_credentials
|
|
||||||
|
|
||||||
manager, _, _ = mock_creds_manager
|
|
||||||
manager.acquire.side_effect = ValueError(
|
|
||||||
"Credentials #cred-id-123 for user #user-1 not found"
|
|
||||||
)
|
|
||||||
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="not available in your account"):
|
|
||||||
await _acquire_auto_credentials(
|
|
||||||
input_model=mock_input_model,
|
|
||||||
input_data=input_data,
|
|
||||||
creds_manager=manager,
|
|
||||||
user_id="user-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auto_credentials_lock_appended(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
google_drive_file_data,
|
|
||||||
mock_input_model,
|
|
||||||
mock_creds_manager,
|
|
||||||
):
|
|
||||||
"""Lock from acquire() should be included in returned locks list."""
|
|
||||||
from backend.executor.manager import _acquire_auto_credentials
|
|
||||||
|
|
||||||
manager, _, mock_lock = mock_creds_manager
|
|
||||||
input_data = {"spreadsheet": google_drive_file_data["valid"]}
|
|
||||||
|
|
||||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
|
||||||
input_model=mock_input_model,
|
|
||||||
input_data=input_data,
|
|
||||||
creds_manager=manager,
|
|
||||||
user_id="user-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(locks) == 1
|
|
||||||
assert locks[0] is mock_lock
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_auto_credentials_multiple_fields(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
mock_creds_manager,
|
|
||||||
):
|
|
||||||
"""When there are multiple auto_credentials fields, only valid ones should acquire."""
|
|
||||||
from backend.executor.manager import _acquire_auto_credentials
|
|
||||||
|
|
||||||
manager, mock_creds, mock_lock = mock_creds_manager
|
|
||||||
|
|
||||||
input_model = mocker.MagicMock()
|
|
||||||
input_model.get_auto_credentials_fields.return_value = {
|
|
||||||
"credentials": {
|
|
||||||
"field_name": "spreadsheet",
|
|
||||||
"config": {"provider": "google", "type": "oauth2"},
|
|
||||||
},
|
|
||||||
"credentials2": {
|
|
||||||
"field_name": "doc_file",
|
|
||||||
"config": {"provider": "google", "type": "oauth2"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
input_data = {
|
|
||||||
"spreadsheet": {
|
|
||||||
"_credentials_id": "cred-id-123",
|
|
||||||
"id": "file-1",
|
|
||||||
"name": "file1.xlsx",
|
|
||||||
},
|
|
||||||
"doc_file": {
|
|
||||||
"_credentials_id": None,
|
|
||||||
"id": "file-2",
|
|
||||||
"name": "chained.doc",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
extra_kwargs, locks = await _acquire_auto_credentials(
|
|
||||||
input_model=input_model,
|
|
||||||
input_data=input_data,
|
|
||||||
creds_manager=manager,
|
|
||||||
user_id="user-1",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only the first field should have acquired credentials
|
|
||||||
manager.acquire.assert_called_once_with("user-1", "cred-id-123")
|
|
||||||
assert "credentials" in extra_kwargs
|
|
||||||
assert "credentials2" not in extra_kwargs
|
|
||||||
assert len(locks) == 1
|
|
||||||
@@ -259,8 +259,7 @@ async def _validate_node_input_credentials(
|
|||||||
|
|
||||||
# Find any fields of type CredentialsMetaInput
|
# Find any fields of type CredentialsMetaInput
|
||||||
credentials_fields = block.input_schema.get_credentials_fields()
|
credentials_fields = block.input_schema.get_credentials_fields()
|
||||||
auto_credentials_fields = block.input_schema.get_auto_credentials_fields()
|
if not credentials_fields:
|
||||||
if not credentials_fields and not auto_credentials_fields:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Track if any credential field is missing for this node
|
# Track if any credential field is missing for this node
|
||||||
@@ -340,47 +339,6 @@ async def _validate_node_input_credentials(
|
|||||||
] = "Invalid credentials: type/provider mismatch"
|
] = "Invalid credentials: type/provider mismatch"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Validate auto-credentials (GoogleDriveFileField-based)
|
|
||||||
# These have _credentials_id embedded in the file field data
|
|
||||||
if auto_credentials_fields:
|
|
||||||
for _kwarg_name, info in auto_credentials_fields.items():
|
|
||||||
field_name = info["field_name"]
|
|
||||||
# Check input_default and nodes_input_masks for the field value
|
|
||||||
field_value = node.input_default.get(field_name)
|
|
||||||
if nodes_input_masks and node.id in nodes_input_masks:
|
|
||||||
field_value = nodes_input_masks[node.id].get(
|
|
||||||
field_name, field_value
|
|
||||||
)
|
|
||||||
|
|
||||||
if field_value and isinstance(field_value, dict):
|
|
||||||
if "_credentials_id" not in field_value:
|
|
||||||
# Key removed (e.g., on fork) — needs re-auth
|
|
||||||
has_missing_credentials = True
|
|
||||||
credential_errors[node.id][field_name] = (
|
|
||||||
"Authentication missing for the selected file. "
|
|
||||||
"Please re-select the file to authenticate with "
|
|
||||||
"your own account."
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
cred_id = field_value.get("_credentials_id")
|
|
||||||
if cred_id and isinstance(cred_id, str):
|
|
||||||
try:
|
|
||||||
creds_store = get_integration_credentials_store()
|
|
||||||
creds = await creds_store.get_creds_by_id(user_id, cred_id)
|
|
||||||
except Exception as e:
|
|
||||||
has_missing_credentials = True
|
|
||||||
credential_errors[node.id][
|
|
||||||
field_name
|
|
||||||
] = f"Credentials not available: {e}"
|
|
||||||
continue
|
|
||||||
if not creds:
|
|
||||||
has_missing_credentials = True
|
|
||||||
credential_errors[node.id][field_name] = (
|
|
||||||
"The saved credentials are not available "
|
|
||||||
"for your account. Please re-select the file to "
|
|
||||||
"authenticate with your own account."
|
|
||||||
)
|
|
||||||
|
|
||||||
# If node has optional credentials and any are missing, mark for skipping
|
# If node has optional credentials and any are missing, mark for skipping
|
||||||
# But only if there are no other errors for this node
|
# But only if there are no other errors for this node
|
||||||
if (
|
if (
|
||||||
@@ -412,9 +370,8 @@ def make_node_credentials_input_map(
|
|||||||
"""
|
"""
|
||||||
result: dict[str, dict[str, JsonValue]] = {}
|
result: dict[str, dict[str, JsonValue]] = {}
|
||||||
|
|
||||||
# Only map regular credentials (not auto_credentials, which are resolved
|
# Get aggregated credentials fields for the graph
|
||||||
# at execution time from _credentials_id in file field data)
|
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||||
graph_cred_inputs = graph.regular_credentials_inputs
|
|
||||||
|
|
||||||
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
|
for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items():
|
||||||
# Best-effort map: skip missing items
|
# Best-effort map: skip missing items
|
||||||
|
|||||||
@@ -907,335 +907,3 @@ async def test_stop_graph_execution_cascades_to_child_with_reviews(
|
|||||||
|
|
||||||
# Verify both parent and child status updates
|
# Verify both parent and child status updates
|
||||||
assert mock_execution_db.update_graph_execution_stats.call_count >= 1
|
assert mock_execution_db.update_graph_execution_stats.call_count >= 1
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Tests for auto_credentials validation in _validate_node_input_credentials
|
|
||||||
# (Fix 3: SECRT-1772 + Fix 4: Path 4)
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_validate_node_input_credentials_auto_creds_valid(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
[SECRT-1772] When a node has auto_credentials with a valid _credentials_id
|
|
||||||
that exists in the store, validation should pass without errors.
|
|
||||||
"""
|
|
||||||
from backend.executor.utils import _validate_node_input_credentials
|
|
||||||
|
|
||||||
mock_node = mocker.MagicMock()
|
|
||||||
mock_node.id = "node-with-auto-creds"
|
|
||||||
mock_node.credentials_optional = False
|
|
||||||
mock_node.input_default = {
|
|
||||||
"spreadsheet": {
|
|
||||||
"_credentials_id": "valid-cred-id",
|
|
||||||
"id": "file-123",
|
|
||||||
"name": "test.xlsx",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_block = mocker.MagicMock()
|
|
||||||
# No regular credentials fields
|
|
||||||
mock_block.input_schema.get_credentials_fields.return_value = {}
|
|
||||||
# Has auto_credentials fields
|
|
||||||
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
|
||||||
"credentials": {
|
|
||||||
"field_name": "spreadsheet",
|
|
||||||
"config": {"provider": "google", "type": "oauth2"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mock_node.block = mock_block
|
|
||||||
|
|
||||||
mock_graph = mocker.MagicMock()
|
|
||||||
mock_graph.nodes = [mock_node]
|
|
||||||
|
|
||||||
# Mock the credentials store to return valid credentials
|
|
||||||
mock_store = mocker.MagicMock()
|
|
||||||
mock_creds = mocker.MagicMock()
|
|
||||||
mock_creds.id = "valid-cred-id"
|
|
||||||
mock_store.get_creds_by_id = mocker.AsyncMock(return_value=mock_creds)
|
|
||||||
mocker.patch(
|
|
||||||
"backend.executor.utils.get_integration_credentials_store",
|
|
||||||
return_value=mock_store,
|
|
||||||
)
|
|
||||||
|
|
||||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
|
||||||
graph=mock_graph,
|
|
||||||
user_id="test-user",
|
|
||||||
nodes_input_masks=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mock_node.id not in errors
|
|
||||||
assert mock_node.id not in nodes_to_skip
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_validate_node_input_credentials_auto_creds_missing(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
[SECRT-1772] When a node has auto_credentials with a _credentials_id
|
|
||||||
that doesn't exist for the current user, validation should report an error.
|
|
||||||
"""
|
|
||||||
from backend.executor.utils import _validate_node_input_credentials
|
|
||||||
|
|
||||||
mock_node = mocker.MagicMock()
|
|
||||||
mock_node.id = "node-with-bad-auto-creds"
|
|
||||||
mock_node.credentials_optional = False
|
|
||||||
mock_node.input_default = {
|
|
||||||
"spreadsheet": {
|
|
||||||
"_credentials_id": "other-users-cred-id",
|
|
||||||
"id": "file-123",
|
|
||||||
"name": "test.xlsx",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_block = mocker.MagicMock()
|
|
||||||
mock_block.input_schema.get_credentials_fields.return_value = {}
|
|
||||||
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
|
||||||
"credentials": {
|
|
||||||
"field_name": "spreadsheet",
|
|
||||||
"config": {"provider": "google", "type": "oauth2"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mock_node.block = mock_block
|
|
||||||
|
|
||||||
mock_graph = mocker.MagicMock()
|
|
||||||
mock_graph.nodes = [mock_node]
|
|
||||||
|
|
||||||
# Mock the credentials store to return None (cred not found for this user)
|
|
||||||
mock_store = mocker.MagicMock()
|
|
||||||
mock_store.get_creds_by_id = mocker.AsyncMock(return_value=None)
|
|
||||||
mocker.patch(
|
|
||||||
"backend.executor.utils.get_integration_credentials_store",
|
|
||||||
return_value=mock_store,
|
|
||||||
)
|
|
||||||
|
|
||||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
|
||||||
graph=mock_graph,
|
|
||||||
user_id="different-user",
|
|
||||||
nodes_input_masks=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mock_node.id in errors
|
|
||||||
assert "spreadsheet" in errors[mock_node.id]
|
|
||||||
assert "not available" in errors[mock_node.id]["spreadsheet"].lower()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_validate_node_input_credentials_both_regular_and_auto(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
[SECRT-1772] A node that has BOTH regular credentials AND auto_credentials
|
|
||||||
should have both validated.
|
|
||||||
"""
|
|
||||||
from backend.executor.utils import _validate_node_input_credentials
|
|
||||||
|
|
||||||
mock_node = mocker.MagicMock()
|
|
||||||
mock_node.id = "node-with-both-creds"
|
|
||||||
mock_node.credentials_optional = False
|
|
||||||
mock_node.input_default = {
|
|
||||||
"credentials": {
|
|
||||||
"id": "regular-cred-id",
|
|
||||||
"provider": "github",
|
|
||||||
"type": "api_key",
|
|
||||||
},
|
|
||||||
"spreadsheet": {
|
|
||||||
"_credentials_id": "auto-cred-id",
|
|
||||||
"id": "file-123",
|
|
||||||
"name": "test.xlsx",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_credentials_field_type = mocker.MagicMock()
|
|
||||||
mock_credentials_meta = mocker.MagicMock()
|
|
||||||
mock_credentials_meta.id = "regular-cred-id"
|
|
||||||
mock_credentials_meta.provider = "github"
|
|
||||||
mock_credentials_meta.type = "api_key"
|
|
||||||
mock_credentials_field_type.model_validate.return_value = mock_credentials_meta
|
|
||||||
|
|
||||||
mock_block = mocker.MagicMock()
|
|
||||||
# Regular credentials field
|
|
||||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
|
||||||
"credentials": mock_credentials_field_type,
|
|
||||||
}
|
|
||||||
# Auto-credentials field
|
|
||||||
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
|
||||||
"auto_credentials": {
|
|
||||||
"field_name": "spreadsheet",
|
|
||||||
"config": {"provider": "google", "type": "oauth2"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mock_node.block = mock_block
|
|
||||||
|
|
||||||
mock_graph = mocker.MagicMock()
|
|
||||||
mock_graph.nodes = [mock_node]
|
|
||||||
|
|
||||||
# Mock the credentials store to return valid credentials for both
|
|
||||||
mock_store = mocker.MagicMock()
|
|
||||||
mock_regular_creds = mocker.MagicMock()
|
|
||||||
mock_regular_creds.id = "regular-cred-id"
|
|
||||||
mock_regular_creds.provider = "github"
|
|
||||||
mock_regular_creds.type = "api_key"
|
|
||||||
|
|
||||||
mock_auto_creds = mocker.MagicMock()
|
|
||||||
mock_auto_creds.id = "auto-cred-id"
|
|
||||||
|
|
||||||
def get_creds_side_effect(user_id, cred_id):
|
|
||||||
if cred_id == "regular-cred-id":
|
|
||||||
return mock_regular_creds
|
|
||||||
elif cred_id == "auto-cred-id":
|
|
||||||
return mock_auto_creds
|
|
||||||
return None
|
|
||||||
|
|
||||||
mock_store.get_creds_by_id = mocker.AsyncMock(side_effect=get_creds_side_effect)
|
|
||||||
mocker.patch(
|
|
||||||
"backend.executor.utils.get_integration_credentials_store",
|
|
||||||
return_value=mock_store,
|
|
||||||
)
|
|
||||||
|
|
||||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
|
||||||
graph=mock_graph,
|
|
||||||
user_id="test-user",
|
|
||||||
nodes_input_masks=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Both should validate successfully - no errors
|
|
||||||
assert mock_node.id not in errors
|
|
||||||
assert mock_node.id not in nodes_to_skip
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_validate_node_input_credentials_auto_creds_skipped_when_none(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
When a node has auto_credentials but the field value has _credentials_id=None
|
|
||||||
(e.g., from upstream connection), validation should skip it without error.
|
|
||||||
"""
|
|
||||||
from backend.executor.utils import _validate_node_input_credentials
|
|
||||||
|
|
||||||
mock_node = mocker.MagicMock()
|
|
||||||
mock_node.id = "node-with-chained-auto-creds"
|
|
||||||
mock_node.credentials_optional = False
|
|
||||||
mock_node.input_default = {
|
|
||||||
"spreadsheet": {
|
|
||||||
"_credentials_id": None,
|
|
||||||
"id": "file-123",
|
|
||||||
"name": "test.xlsx",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_block = mocker.MagicMock()
|
|
||||||
mock_block.input_schema.get_credentials_fields.return_value = {}
|
|
||||||
mock_block.input_schema.get_auto_credentials_fields.return_value = {
|
|
||||||
"credentials": {
|
|
||||||
"field_name": "spreadsheet",
|
|
||||||
"config": {"provider": "google", "type": "oauth2"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mock_node.block = mock_block
|
|
||||||
|
|
||||||
mock_graph = mocker.MagicMock()
|
|
||||||
mock_graph.nodes = [mock_node]
|
|
||||||
|
|
||||||
errors, nodes_to_skip = await _validate_node_input_credentials(
|
|
||||||
graph=mock_graph,
|
|
||||||
user_id="test-user",
|
|
||||||
nodes_input_masks=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# No error - chained data with None cred_id is valid
|
|
||||||
assert mock_node.id not in errors
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Tests for CredentialsFieldInfo auto_credential tag (Fix 4: Path 4)
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def test_credentials_field_info_auto_credential_tag():
|
|
||||||
"""
|
|
||||||
[Path 4] CredentialsFieldInfo should support is_auto_credential and
|
|
||||||
input_field_name fields for distinguishing auto from regular credentials.
|
|
||||||
"""
|
|
||||||
from backend.data.model import CredentialsFieldInfo
|
|
||||||
|
|
||||||
# Regular credential should have is_auto_credential=False by default
|
|
||||||
regular = CredentialsFieldInfo.model_validate(
|
|
||||||
{
|
|
||||||
"credentials_provider": ["github"],
|
|
||||||
"credentials_types": ["api_key"],
|
|
||||||
},
|
|
||||||
by_alias=True,
|
|
||||||
)
|
|
||||||
assert regular.is_auto_credential is False
|
|
||||||
assert regular.input_field_name is None
|
|
||||||
|
|
||||||
# Auto credential should have is_auto_credential=True
|
|
||||||
auto = CredentialsFieldInfo.model_validate(
|
|
||||||
{
|
|
||||||
"credentials_provider": ["google"],
|
|
||||||
"credentials_types": ["oauth2"],
|
|
||||||
"is_auto_credential": True,
|
|
||||||
"input_field_name": "spreadsheet",
|
|
||||||
},
|
|
||||||
by_alias=True,
|
|
||||||
)
|
|
||||||
assert auto.is_auto_credential is True
|
|
||||||
assert auto.input_field_name == "spreadsheet"
|
|
||||||
|
|
||||||
|
|
||||||
def test_make_node_credentials_input_map_excludes_auto_creds(
|
|
||||||
mocker: MockerFixture,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
[Path 4] make_node_credentials_input_map should only include regular credentials,
|
|
||||||
not auto_credentials (which are resolved at execution time).
|
|
||||||
"""
|
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
|
||||||
from backend.executor.utils import make_node_credentials_input_map
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
# Create a mock graph with aggregate_credentials_inputs that returns
|
|
||||||
# both regular and auto credentials
|
|
||||||
mock_graph = mocker.MagicMock()
|
|
||||||
|
|
||||||
regular_field_info = CredentialsFieldInfo.model_validate(
|
|
||||||
{
|
|
||||||
"credentials_provider": ["github"],
|
|
||||||
"credentials_types": ["api_key"],
|
|
||||||
"is_auto_credential": False,
|
|
||||||
},
|
|
||||||
by_alias=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock regular_credentials_inputs property (auto_credentials are excluded)
|
|
||||||
mock_graph.regular_credentials_inputs = {
|
|
||||||
"github_creds": (regular_field_info, {("node-1", "credentials")}, True),
|
|
||||||
}
|
|
||||||
|
|
||||||
graph_credentials_input = {
|
|
||||||
"github_creds": CredentialsMetaInput(
|
|
||||||
id="cred-123",
|
|
||||||
provider=ProviderName("github"),
|
|
||||||
type="api_key",
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
result = make_node_credentials_input_map(mock_graph, graph_credentials_input)
|
|
||||||
|
|
||||||
# Regular credentials should be mapped
|
|
||||||
assert "node-1" in result
|
|
||||||
assert "credentials" in result["node-1"]
|
|
||||||
|
|
||||||
# Auto credentials should NOT appear in the result
|
|
||||||
# (they would have been mapped to the kwarg_name "credentials" not "spreadsheet")
|
|
||||||
for node_id, fields in result.items():
|
|
||||||
for field_name, value in fields.items():
|
|
||||||
# Verify no auto-credential phantom entries
|
|
||||||
if isinstance(value, dict):
|
|
||||||
assert "_credentials_id" not in value
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class ProviderName(str, Enum):
|
|||||||
IDEOGRAM = "ideogram"
|
IDEOGRAM = "ideogram"
|
||||||
JINA = "jina"
|
JINA = "jina"
|
||||||
LLAMA_API = "llama_api"
|
LLAMA_API = "llama_api"
|
||||||
|
MCP = "mcp"
|
||||||
MEDIUM = "medium"
|
MEDIUM = "medium"
|
||||||
MEM0 = "mem0"
|
MEM0 = "mem0"
|
||||||
NOTION = "notion"
|
NOTION = "notion"
|
||||||
|
|||||||
6845
autogpt_platform/backend/poetry.lock
generated
6845
autogpt_platform/backend/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -21,7 +21,7 @@ cryptography = "^45.0"
|
|||||||
discord-py = "^2.5.2"
|
discord-py = "^2.5.2"
|
||||||
e2b-code-interpreter = "^1.5.2"
|
e2b-code-interpreter = "^1.5.2"
|
||||||
elevenlabs = "^1.50.0"
|
elevenlabs = "^1.50.0"
|
||||||
fastapi = "^0.116.1"
|
fastapi = "^0.128.0"
|
||||||
feedparser = "^6.0.11"
|
feedparser = "^6.0.11"
|
||||||
flake8 = "^7.3.0"
|
flake8 = "^7.3.0"
|
||||||
google-api-python-client = "^2.177.0"
|
google-api-python-client = "^2.177.0"
|
||||||
@@ -35,7 +35,7 @@ jinja2 = "^3.1.6"
|
|||||||
jsonref = "^1.1.0"
|
jsonref = "^1.1.0"
|
||||||
jsonschema = "^4.25.0"
|
jsonschema = "^4.25.0"
|
||||||
langfuse = "^3.11.0"
|
langfuse = "^3.11.0"
|
||||||
launchdarkly-server-sdk = "^9.12.0"
|
launchdarkly-server-sdk = "^9.14.1"
|
||||||
mem0ai = "^0.1.115"
|
mem0ai = "^0.1.115"
|
||||||
moviepy = "^2.1.2"
|
moviepy = "^2.1.2"
|
||||||
ollama = "^0.5.1"
|
ollama = "^0.5.1"
|
||||||
@@ -52,8 +52,8 @@ prometheus-client = "^0.22.1"
|
|||||||
prometheus-fastapi-instrumentator = "^7.0.0"
|
prometheus-fastapi-instrumentator = "^7.0.0"
|
||||||
psutil = "^7.0.0"
|
psutil = "^7.0.0"
|
||||||
psycopg2-binary = "^2.9.10"
|
psycopg2-binary = "^2.9.10"
|
||||||
pydantic = { extras = ["email"], version = "^2.11.7" }
|
pydantic = { extras = ["email"], version = "^2.12.5" }
|
||||||
pydantic-settings = "^2.10.1"
|
pydantic-settings = "^2.12.0"
|
||||||
pytest = "^8.4.1"
|
pytest = "^8.4.1"
|
||||||
pytest-asyncio = "^1.1.0"
|
pytest-asyncio = "^1.1.0"
|
||||||
python-dotenv = "^1.1.1"
|
python-dotenv = "^1.1.1"
|
||||||
@@ -65,11 +65,11 @@ sentry-sdk = {extras = ["anthropic", "fastapi", "launchdarkly", "openai", "sqlal
|
|||||||
sqlalchemy = "^2.0.40"
|
sqlalchemy = "^2.0.40"
|
||||||
strenum = "^0.4.9"
|
strenum = "^0.4.9"
|
||||||
stripe = "^11.5.0"
|
stripe = "^11.5.0"
|
||||||
supabase = "2.17.0"
|
supabase = "2.27.2"
|
||||||
tenacity = "^9.1.2"
|
tenacity = "^9.1.2"
|
||||||
todoist-api-python = "^2.1.7"
|
todoist-api-python = "^2.1.7"
|
||||||
tweepy = "^4.16.0"
|
tweepy = "^4.16.0"
|
||||||
uvicorn = { extras = ["standard"], version = "^0.35.0" }
|
uvicorn = { extras = ["standard"], version = "^0.40.0" }
|
||||||
websockets = "^15.0"
|
websockets = "^15.0"
|
||||||
youtube-transcript-api = "^1.2.1"
|
youtube-transcript-api = "^1.2.1"
|
||||||
yt-dlp = "2025.12.08"
|
yt-dlp = "2025.12.08"
|
||||||
|
|||||||
@@ -12307,7 +12307,9 @@
|
|||||||
"title": "Location"
|
"title": "Location"
|
||||||
},
|
},
|
||||||
"msg": { "type": "string", "title": "Message" },
|
"msg": { "type": "string", "title": "Message" },
|
||||||
"type": { "type": "string", "title": "Error Type" }
|
"type": { "type": "string", "title": "Error Type" },
|
||||||
|
"input": { "title": "Input" },
|
||||||
|
"ctx": { "type": "object", "title": "Context" }
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["loc", "msg", "type"],
|
"required": ["loc", "msg", "type"],
|
||||||
|
|||||||
@@ -4,9 +4,7 @@ import { loadScript } from "@/services/scripts/scripts";
|
|||||||
export async function loadGoogleAPIPicker(): Promise<void> {
|
export async function loadGoogleAPIPicker(): Promise<void> {
|
||||||
validateWindow();
|
validateWindow();
|
||||||
|
|
||||||
await loadScript("https://apis.google.com/js/api.js", {
|
await loadScript("https://apis.google.com/js/api.js");
|
||||||
referrerPolicy: "no-referrer-when-downgrade",
|
|
||||||
});
|
|
||||||
|
|
||||||
const googleAPI = window.gapi;
|
const googleAPI = window.gapi;
|
||||||
if (!googleAPI) {
|
if (!googleAPI) {
|
||||||
@@ -29,9 +27,7 @@ export async function loadGoogleIdentityServices(): Promise<void> {
|
|||||||
throw new Error("Google Identity Services cannot load on server");
|
throw new Error("Google Identity Services cannot load on server");
|
||||||
}
|
}
|
||||||
|
|
||||||
await loadScript("https://accounts.google.com/gsi/client", {
|
await loadScript("https://accounts.google.com/gsi/client");
|
||||||
referrerPolicy: "no-referrer-when-downgrade",
|
|
||||||
});
|
|
||||||
|
|
||||||
const google = window.google;
|
const google = window.google;
|
||||||
if (!google?.accounts?.oauth2) {
|
if (!google?.accounts?.oauth2) {
|
||||||
|
|||||||
@@ -467,6 +467,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
|||||||
| [Github Update Comment](block-integrations/github/issues.md#github-update-comment) | A block that updates an existing comment on a GitHub issue or pull request |
|
| [Github Update Comment](block-integrations/github/issues.md#github-update-comment) | A block that updates an existing comment on a GitHub issue or pull request |
|
||||||
| [Github Update File](block-integrations/github/repo.md#github-update-file) | This block updates an existing file in a GitHub repository |
|
| [Github Update File](block-integrations/github/repo.md#github-update-file) | This block updates an existing file in a GitHub repository |
|
||||||
| [Instantiate Code Sandbox](block-integrations/misc.md#instantiate-code-sandbox) | Instantiate a sandbox environment with internet access in which you can execute code with the Execute Code Step block |
|
| [Instantiate Code Sandbox](block-integrations/misc.md#instantiate-code-sandbox) | Instantiate a sandbox environment with internet access in which you can execute code with the Execute Code Step block |
|
||||||
|
| [MCP Tool](block-integrations/mcp/block.md#mcp-tool) | Connect to any MCP server and execute its tools |
|
||||||
| [Slant3D Order Webhook](block-integrations/slant3d/webhook.md#slant3d-order-webhook) | This block triggers on Slant3D order status updates and outputs the event details, including tracking information when orders are shipped |
|
| [Slant3D Order Webhook](block-integrations/slant3d/webhook.md#slant3d-order-webhook) | This block triggers on Slant3D order status updates and outputs the event details, including tracking information when orders are shipped |
|
||||||
|
|
||||||
## Media Generation
|
## Media Generation
|
||||||
|
|||||||
@@ -84,6 +84,7 @@
|
|||||||
* [Linear Projects](block-integrations/linear/projects.md)
|
* [Linear Projects](block-integrations/linear/projects.md)
|
||||||
* [LLM](block-integrations/llm.md)
|
* [LLM](block-integrations/llm.md)
|
||||||
* [Logic](block-integrations/logic.md)
|
* [Logic](block-integrations/logic.md)
|
||||||
|
* [Mcp Block](block-integrations/mcp/block.md)
|
||||||
* [Misc](block-integrations/misc.md)
|
* [Misc](block-integrations/misc.md)
|
||||||
* [Notion Create Page](block-integrations/notion/create_page.md)
|
* [Notion Create Page](block-integrations/notion/create_page.md)
|
||||||
* [Notion Read Database](block-integrations/notion/read_database.md)
|
* [Notion Read Database](block-integrations/notion/read_database.md)
|
||||||
|
|||||||
36
docs/integrations/block-integrations/mcp/block.md
Normal file
36
docs/integrations/block-integrations/mcp/block.md
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# Mcp Block
|
||||||
|
<!-- MANUAL: file_description -->
|
||||||
|
_Add a description of this category of blocks._
|
||||||
|
<!-- END MANUAL -->
|
||||||
|
|
||||||
|
## MCP Tool
|
||||||
|
|
||||||
|
### What it is
|
||||||
|
Connect to any MCP server and execute its tools. Provide a server URL, select a tool, and pass arguments dynamically.
|
||||||
|
|
||||||
|
### How it works
|
||||||
|
<!-- MANUAL: how_it_works -->
|
||||||
|
_Add technical explanation here._
|
||||||
|
<!-- END MANUAL -->
|
||||||
|
|
||||||
|
### Inputs
|
||||||
|
|
||||||
|
| Input | Description | Type | Required |
|
||||||
|
|-------|-------------|------|----------|
|
||||||
|
| server_url | URL of the MCP server (Streamable HTTP endpoint) | str | Yes |
|
||||||
|
| selected_tool | The MCP tool to execute | str | No |
|
||||||
|
| tool_arguments | Arguments to pass to the selected MCP tool. The fields here are defined by the tool's input schema. | Dict[str, Any] | No |
|
||||||
|
|
||||||
|
### Outputs
|
||||||
|
|
||||||
|
| Output | Description | Type |
|
||||||
|
|--------|-------------|------|
|
||||||
|
| error | Error message if the tool call failed | str |
|
||||||
|
| result | The result returned by the MCP tool | Result |
|
||||||
|
|
||||||
|
### Possible use case
|
||||||
|
<!-- MANUAL: use_case -->
|
||||||
|
_Add practical use case examples here._
|
||||||
|
<!-- END MANUAL -->
|
||||||
|
|
||||||
|
---
|
||||||
Reference in New Issue
Block a user