diff --git a/autogpt_platform/backend/backend/api/features/integrations/router.py b/autogpt_platform/backend/backend/api/features/integrations/router.py index 00500dc8a8..4eacf83e71 100644 --- a/autogpt_platform/backend/backend/api/features/integrations/router.py +++ b/autogpt_platform/backend/backend/api/features/integrations/router.py @@ -1,7 +1,7 @@ import asyncio import logging from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Annotated, List, Literal +from typing import TYPE_CHECKING, Annotated, Any, List, Literal from autogpt_libs.auth import get_user_id from fastapi import ( @@ -14,7 +14,7 @@ from fastapi import ( Security, status, ) -from pydantic import BaseModel, Field, SecretStr +from pydantic import BaseModel, Field, SecretStr, model_validator from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY from backend.api.features.library.db import set_preset_webhook, update_preset @@ -39,7 +39,11 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step from backend.data.user import get_user_integrations from backend.executor.utils import add_graph_execution from backend.integrations.ayrshare import AyrshareClient, SocialPlatform -from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.integrations.credentials_store import provider_matches +from backend.integrations.creds_manager import ( + IntegrationCredentialsManager, + create_mcp_oauth_handler, +) from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.providers import ProviderName from backend.integrations.webhooks import get_webhook_manager @@ -102,9 +106,37 @@ class CredentialsMetaResponse(BaseModel): scopes: list[str] | None username: str | None host: str | None = Field( - default=None, description="Host pattern for host-scoped credentials" + default=None, + description="Host pattern for host-scoped or MCP server URL for MCP credentials", ) + @model_validator(mode="before") + @classmethod + def _normalize_provider(cls, data: Any) -> Any: + """Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug.""" + if isinstance(data, dict): + prov = data.get("provider", "") + if isinstance(prov, str) and prov.startswith("ProviderName."): + member = prov.removeprefix("ProviderName.") + try: + data = {**data, "provider": ProviderName[member].value} + except KeyError: + pass + return data + + @staticmethod + def get_host(cred: Credentials) -> str | None: + """Extract host from credential: HostScoped host or MCP server URL.""" + if isinstance(cred, HostScopedCredentials): + return cred.host + if isinstance(cred, OAuth2Credentials) and cred.provider in ( + ProviderName.MCP, + ProviderName.MCP.value, + "ProviderName.MCP", + ): + return (cred.metadata or {}).get("mcp_server_url") + return None + @router.post("/{provider}/callback", summary="Exchange OAuth code for tokens") async def callback( @@ -179,9 +211,7 @@ async def callback( title=credentials.title, scopes=credentials.scopes, username=credentials.username, - host=( - credentials.host if isinstance(credentials, HostScopedCredentials) else None - ), + host=(CredentialsMetaResponse.get_host(credentials)), ) @@ -199,7 +229,7 @@ async def list_credentials( title=cred.title, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None, - host=cred.host if isinstance(cred, HostScopedCredentials) else None, + host=CredentialsMetaResponse.get_host(cred), ) for cred in credentials ] @@ -222,7 +252,7 @@ async def list_credentials_by_provider( title=cred.title, scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None, username=cred.username if isinstance(cred, OAuth2Credentials) else None, - host=cred.host if isinstance(cred, HostScopedCredentials) else None, + host=CredentialsMetaResponse.get_host(cred), ) for cred in credentials ] @@ -322,7 +352,11 @@ async def delete_credentials( tokens_revoked = None if isinstance(creds, OAuth2Credentials): - handler = _get_provider_oauth_handler(request, provider) + if provider_matches(provider.value, ProviderName.MCP.value): + # MCP uses dynamic per-server OAuth — create handler from metadata + handler = create_mcp_oauth_handler(creds) + else: + handler = _get_provider_oauth_handler(request, provider) tokens_revoked = await handler.revoke_tokens(creds) return CredentialsDeletionResponse(revoked=tokens_revoked) diff --git a/autogpt_platform/backend/backend/api/features/mcp/__init__.py b/autogpt_platform/backend/backend/api/features/mcp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/autogpt_platform/backend/backend/api/features/mcp/routes.py b/autogpt_platform/backend/backend/api/features/mcp/routes.py new file mode 100644 index 0000000000..f8d311f372 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/mcp/routes.py @@ -0,0 +1,404 @@ +""" +MCP (Model Context Protocol) API routes. + +Provides endpoints for MCP tool discovery and OAuth authentication so the +frontend can list available tools on an MCP server before placing a block. +""" + +import logging +from typing import Annotated, Any +from urllib.parse import urlparse + +import fastapi +from autogpt_libs.auth import get_user_id +from fastapi import Security +from pydantic import BaseModel, Field + +from backend.api.features.integrations.router import CredentialsMetaResponse +from backend.blocks.mcp.client import MCPClient, MCPClientError +from backend.blocks.mcp.oauth import MCPOAuthHandler +from backend.data.model import OAuth2Credentials +from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.integrations.providers import ProviderName +from backend.util.request import HTTPClientError, Requests +from backend.util.settings import Settings + +logger = logging.getLogger(__name__) + +settings = Settings() +router = fastapi.APIRouter(tags=["mcp"]) +creds_manager = IntegrationCredentialsManager() + + +# ====================== Tool Discovery ====================== # + + +class DiscoverToolsRequest(BaseModel): + """Request to discover tools on an MCP server.""" + + server_url: str = Field(description="URL of the MCP server") + auth_token: str | None = Field( + default=None, + description="Optional Bearer token for authenticated MCP servers", + ) + + +class MCPToolResponse(BaseModel): + """A single MCP tool returned by discovery.""" + + name: str + description: str + input_schema: dict[str, Any] + + +class DiscoverToolsResponse(BaseModel): + """Response containing the list of tools available on an MCP server.""" + + tools: list[MCPToolResponse] + server_name: str | None = None + protocol_version: str | None = None + + +@router.post( + "/discover-tools", + summary="Discover available tools on an MCP server", + response_model=DiscoverToolsResponse, +) +async def discover_tools( + request: DiscoverToolsRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> DiscoverToolsResponse: + """ + Connect to an MCP server and return its available tools. + + If the user has a stored MCP credential for this server URL, it will be + used automatically — no need to pass an explicit auth token. + """ + auth_token = request.auth_token + + # Auto-use stored MCP credential when no explicit token is provided. + if not auth_token: + mcp_creds = await creds_manager.store.get_creds_by_provider( + user_id, ProviderName.MCP.value + ) + # Find the freshest credential for this server URL + best_cred: OAuth2Credentials | None = None + for cred in mcp_creds: + if ( + isinstance(cred, OAuth2Credentials) + and (cred.metadata or {}).get("mcp_server_url") == request.server_url + ): + if best_cred is None or ( + (cred.access_token_expires_at or 0) + > (best_cred.access_token_expires_at or 0) + ): + best_cred = cred + if best_cred: + # Refresh the token if expired before using it + best_cred = await creds_manager.refresh_if_needed(user_id, best_cred) + logger.info( + f"Using MCP credential {best_cred.id} for {request.server_url}, " + f"expires_at={best_cred.access_token_expires_at}" + ) + auth_token = best_cred.access_token.get_secret_value() + + client = MCPClient(request.server_url, auth_token=auth_token) + + try: + init_result = await client.initialize() + tools = await client.list_tools() + except HTTPClientError as e: + if e.status_code in (401, 403): + raise fastapi.HTTPException( + status_code=401, + detail="This MCP server requires authentication. " + "Please provide a valid auth token.", + ) + raise fastapi.HTTPException(status_code=502, detail=str(e)) + except MCPClientError as e: + raise fastapi.HTTPException(status_code=502, detail=str(e)) + except Exception as e: + raise fastapi.HTTPException( + status_code=502, + detail=f"Failed to connect to MCP server: {e}", + ) + + return DiscoverToolsResponse( + tools=[ + MCPToolResponse( + name=t.name, + description=t.description, + input_schema=t.input_schema, + ) + for t in tools + ], + server_name=( + init_result.get("serverInfo", {}).get("name") + or urlparse(request.server_url).hostname + or "MCP" + ), + protocol_version=init_result.get("protocolVersion"), + ) + + +# ======================== OAuth Flow ======================== # + + +class MCPOAuthLoginRequest(BaseModel): + """Request to start an OAuth flow for an MCP server.""" + + server_url: str = Field(description="URL of the MCP server that requires OAuth") + + +class MCPOAuthLoginResponse(BaseModel): + """Response with the OAuth login URL for the user to authenticate.""" + + login_url: str + state_token: str + + +@router.post( + "/oauth/login", + summary="Initiate OAuth login for an MCP server", +) +async def mcp_oauth_login( + request: MCPOAuthLoginRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> MCPOAuthLoginResponse: + """ + Discover OAuth metadata from the MCP server and return a login URL. + + 1. Discovers the protected-resource metadata (RFC 9728) + 2. Fetches the authorization server metadata (RFC 8414) + 3. Performs Dynamic Client Registration (RFC 7591) if available + 4. Returns the authorization URL for the frontend to open in a popup + """ + client = MCPClient(request.server_url) + + # Step 1: Discover protected-resource metadata (RFC 9728) + protected_resource = await client.discover_auth() + + metadata: dict[str, Any] | None = None + + if protected_resource and protected_resource.get("authorization_servers"): + auth_server_url = protected_resource["authorization_servers"][0] + resource_url = protected_resource.get("resource", request.server_url) + + # Step 2a: Discover auth-server metadata (RFC 8414) + metadata = await client.discover_auth_server_metadata(auth_server_url) + else: + # Fallback: Some MCP servers (e.g. Linear) are their own auth server + # and serve OAuth metadata directly without protected-resource metadata. + # Don't assume a resource_url — omitting it lets the auth server choose + # the correct audience for the token (RFC 8707 resource is optional). + resource_url = None + metadata = await client.discover_auth_server_metadata(request.server_url) + + if ( + not metadata + or "authorization_endpoint" not in metadata + or "token_endpoint" not in metadata + ): + raise fastapi.HTTPException( + status_code=400, + detail="This MCP server does not advertise OAuth support. " + "You may need to provide an auth token manually.", + ) + + authorize_url = metadata["authorization_endpoint"] + token_url = metadata["token_endpoint"] + registration_endpoint = metadata.get("registration_endpoint") + revoke_url = metadata.get("revocation_endpoint") + + # Step 3: Dynamic Client Registration (RFC 7591) if available + frontend_base_url = settings.config.frontend_base_url + if not frontend_base_url: + raise fastapi.HTTPException( + status_code=500, + detail="Frontend base URL is not configured.", + ) + redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback" + + client_id = "" + client_secret = "" + if registration_endpoint: + reg_result = await _register_mcp_client( + registration_endpoint, redirect_uri, request.server_url + ) + if reg_result: + client_id = reg_result.get("client_id", "") + client_secret = reg_result.get("client_secret", "") + + if not client_id: + client_id = "autogpt-platform" + + # Step 4: Store state token with OAuth metadata for the callback + scopes = (protected_resource or {}).get("scopes_supported") or metadata.get( + "scopes_supported", [] + ) + state_token, code_challenge = await creds_manager.store.store_state_token( + user_id, + ProviderName.MCP.value, + scopes, + state_metadata={ + "authorize_url": authorize_url, + "token_url": token_url, + "revoke_url": revoke_url, + "resource_url": resource_url, + "server_url": request.server_url, + "client_id": client_id, + "client_secret": client_secret, + }, + ) + + # Step 5: Build and return the login URL + handler = MCPOAuthHandler( + client_id=client_id, + client_secret=client_secret, + redirect_uri=redirect_uri, + authorize_url=authorize_url, + token_url=token_url, + resource_url=resource_url, + ) + login_url = handler.get_login_url( + scopes, state_token, code_challenge=code_challenge + ) + + return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token) + + +class MCPOAuthCallbackRequest(BaseModel): + """Request to exchange an OAuth code for tokens.""" + + code: str = Field(description="Authorization code from OAuth callback") + state_token: str = Field(description="State token for CSRF verification") + + +class MCPOAuthCallbackResponse(BaseModel): + """Response after successfully storing OAuth credentials.""" + + credential_id: str + + +@router.post( + "/oauth/callback", + summary="Exchange OAuth code for MCP tokens", +) +async def mcp_oauth_callback( + request: MCPOAuthCallbackRequest, + user_id: Annotated[str, Security(get_user_id)], +) -> CredentialsMetaResponse: + """ + Exchange the authorization code for tokens and store the credential. + + The frontend calls this after receiving the OAuth code from the popup. + On success, subsequent ``/discover-tools`` calls for the same server URL + will automatically use the stored credential. + """ + valid_state = await creds_manager.store.verify_state_token( + user_id, request.state_token, ProviderName.MCP.value + ) + if not valid_state: + raise fastapi.HTTPException( + status_code=400, + detail="Invalid or expired state token.", + ) + + meta = valid_state.state_metadata + frontend_base_url = settings.config.frontend_base_url + if not frontend_base_url: + raise fastapi.HTTPException( + status_code=500, + detail="Frontend base URL is not configured.", + ) + redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback" + + handler = MCPOAuthHandler( + client_id=meta["client_id"], + client_secret=meta.get("client_secret", ""), + redirect_uri=redirect_uri, + authorize_url=meta["authorize_url"], + token_url=meta["token_url"], + revoke_url=meta.get("revoke_url"), + resource_url=meta.get("resource_url"), + ) + + try: + credentials = await handler.exchange_code_for_tokens( + request.code, valid_state.scopes, valid_state.code_verifier + ) + except Exception as e: + raise fastapi.HTTPException( + status_code=400, + detail=f"OAuth token exchange failed: {e}", + ) + + # Enrich credential metadata for future lookup and token refresh + if credentials.metadata is None: + credentials.metadata = {} + credentials.metadata["mcp_server_url"] = meta["server_url"] + credentials.metadata["mcp_client_id"] = meta["client_id"] + credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "") + credentials.metadata["mcp_token_url"] = meta["token_url"] + credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "") + + hostname = urlparse(meta["server_url"]).hostname or meta["server_url"] + credentials.title = f"MCP: {hostname}" + + # Remove old MCP credentials for the same server to prevent stale token buildup. + try: + old_creds = await creds_manager.store.get_creds_by_provider( + user_id, ProviderName.MCP.value + ) + for old in old_creds: + if ( + isinstance(old, OAuth2Credentials) + and (old.metadata or {}).get("mcp_server_url") == meta["server_url"] + ): + await creds_manager.store.delete_creds_by_id(user_id, old.id) + logger.info( + f"Removed old MCP credential {old.id} for {meta['server_url']}" + ) + except Exception: + logger.debug("Could not clean up old MCP credentials", exc_info=True) + + await creds_manager.create(user_id, credentials) + + return CredentialsMetaResponse( + id=credentials.id, + provider=credentials.provider, + type=credentials.type, + title=credentials.title, + scopes=credentials.scopes, + username=credentials.username, + host=credentials.metadata.get("mcp_server_url"), + ) + + +# ======================== Helpers ======================== # + + +async def _register_mcp_client( + registration_endpoint: str, + redirect_uri: str, + server_url: str, +) -> dict[str, Any] | None: + """Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server.""" + try: + response = await Requests(raise_for_status=True).post( + registration_endpoint, + json={ + "client_name": "AutoGPT Platform", + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + }, + ) + data = response.json() + if isinstance(data, dict) and "client_id" in data: + return data + return None + except Exception as e: + logger.warning(f"Dynamic client registration failed for {server_url}: {e}") + return None diff --git a/autogpt_platform/backend/backend/api/features/mcp/test_routes.py b/autogpt_platform/backend/backend/api/features/mcp/test_routes.py new file mode 100644 index 0000000000..e86b9f4865 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/mcp/test_routes.py @@ -0,0 +1,436 @@ +"""Tests for MCP API routes. + +Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient +to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop. +""" + +from unittest.mock import AsyncMock, patch + +import fastapi +import httpx +import pytest +import pytest_asyncio +from autogpt_libs.auth import get_user_id + +from backend.api.features.mcp.routes import router +from backend.blocks.mcp.client import MCPClientError, MCPTool +from backend.util.request import HTTPClientError + +app = fastapi.FastAPI() +app.include_router(router) +app.dependency_overrides[get_user_id] = lambda: "test-user-id" + + +@pytest_asyncio.fixture(scope="module") +async def client(): + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as c: + yield c + + +class TestDiscoverTools: + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_success(self, client): + mock_tools = [ + MCPTool( + name="get_weather", + description="Get weather for a city", + input_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + ), + MCPTool( + name="add_numbers", + description="Add two numbers", + input_schema={ + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + }, + ), + ] + + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + return_value={ + "protocolVersion": "2025-03-26", + "serverInfo": {"name": "test-server"}, + } + ) + instance.list_tools = AsyncMock(return_value=mock_tools) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://mcp.example.com/mcp"}, + ) + + assert response.status_code == 200 + data = response.json() + assert len(data["tools"]) == 2 + assert data["tools"][0]["name"] == "get_weather" + assert data["tools"][1]["name"] == "add_numbers" + assert data["server_name"] == "test-server" + assert data["protocol_version"] == "2025-03-26" + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_with_auth_token(self, client): + with patch("backend.api.features.mcp.routes.MCPClient") as MockClient: + instance = MockClient.return_value + instance.initialize = AsyncMock( + return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"} + ) + instance.list_tools = AsyncMock(return_value=[]) + + response = await client.post( + "/discover-tools", + json={ + "server_url": "https://mcp.example.com/mcp", + "auth_token": "my-secret-token", + }, + ) + + assert response.status_code == 200 + MockClient.assert_called_once_with( + "https://mcp.example.com/mcp", + auth_token="my-secret-token", + ) + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_auto_uses_stored_credential(self, client): + """When no explicit token is given, stored MCP credentials are used.""" + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + stored_cred = OAuth2Credentials( + provider="mcp", + title="MCP: example.com", + access_token=SecretStr("stored-token-123"), + refresh_token=None, + access_token_expires_at=None, + refresh_token_expires_at=None, + scopes=[], + metadata={"mcp_server_url": "https://mcp.example.com/mcp"}, + ) + + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred]) + mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred) + instance = MockClient.return_value + instance.initialize = AsyncMock( + return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"} + ) + instance.list_tools = AsyncMock(return_value=[]) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://mcp.example.com/mcp"}, + ) + + assert response.status_code == 200 + MockClient.assert_called_once_with( + "https://mcp.example.com/mcp", + auth_token="stored-token-123", + ) + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_mcp_error(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + side_effect=MCPClientError("Connection refused") + ) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://bad-server.example.com/mcp"}, + ) + + assert response.status_code == 502 + assert "Connection refused" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_generic_error(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock(side_effect=Exception("Network timeout")) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://timeout.example.com/mcp"}, + ) + + assert response.status_code == 502 + assert "Failed to connect" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_auth_required(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401) + ) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://auth-server.example.com/mcp"}, + ) + + assert response.status_code == 401 + assert "requires authentication" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_forbidden(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + ): + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + instance = MockClient.return_value + instance.initialize = AsyncMock( + side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403) + ) + + response = await client.post( + "/discover-tools", + json={"server_url": "https://auth-server.example.com/mcp"}, + ) + + assert response.status_code == 401 + assert "requires authentication" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_tools_missing_url(self, client): + response = await client.post("/discover-tools", json={}) + assert response.status_code == 422 + + +class TestOAuthLogin: + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_login_success(self, client): + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + patch( + "backend.api.features.mcp.routes._register_mcp_client" + ) as mock_register, + ): + instance = MockClient.return_value + instance.discover_auth = AsyncMock( + return_value={ + "authorization_servers": ["https://auth.sentry.io"], + "resource": "https://mcp.sentry.dev/mcp", + "scopes_supported": ["openid"], + } + ) + instance.discover_auth_server_metadata = AsyncMock( + return_value={ + "authorization_endpoint": "https://auth.sentry.io/authorize", + "token_endpoint": "https://auth.sentry.io/token", + "registration_endpoint": "https://auth.sentry.io/register", + } + ) + mock_register.return_value = { + "client_id": "registered-client-id", + "client_secret": "registered-secret", + } + mock_cm.store.store_state_token = AsyncMock( + return_value=("state-token-123", "code-challenge-abc") + ) + mock_settings.config.frontend_base_url = "http://localhost:3000" + + response = await client.post( + "/oauth/login", + json={"server_url": "https://mcp.sentry.dev/mcp"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "login_url" in data + assert data["state_token"] == "state-token-123" + assert "auth.sentry.io/authorize" in data["login_url"] + assert "registered-client-id" in data["login_url"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_login_no_oauth_support(self, client): + with patch("backend.api.features.mcp.routes.MCPClient") as MockClient: + instance = MockClient.return_value + instance.discover_auth = AsyncMock(return_value=None) + instance.discover_auth_server_metadata = AsyncMock(return_value=None) + + response = await client.post( + "/oauth/login", + json={"server_url": "https://simple-server.example.com/mcp"}, + ) + + assert response.status_code == 400 + assert "does not advertise OAuth" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_login_fallback_to_public_client(self, client): + """When DCR is unavailable, falls back to default public client ID.""" + with ( + patch("backend.api.features.mcp.routes.MCPClient") as MockClient, + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + ): + instance = MockClient.return_value + instance.discover_auth = AsyncMock( + return_value={ + "authorization_servers": ["https://auth.example.com"], + "resource": "https://mcp.example.com/mcp", + } + ) + instance.discover_auth_server_metadata = AsyncMock( + return_value={ + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + # No registration_endpoint + } + ) + mock_cm.store.store_state_token = AsyncMock( + return_value=("state-abc", "challenge-xyz") + ) + mock_settings.config.frontend_base_url = "http://localhost:3000" + + response = await client.post( + "/oauth/login", + json={"server_url": "https://mcp.example.com/mcp"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "autogpt-platform" in data["login_url"] + + +class TestOAuthCallback: + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_callback_success(self, client): + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + mock_creds = OAuth2Credentials( + provider="mcp", + title=None, + access_token=SecretStr("access-token-xyz"), + refresh_token=None, + access_token_expires_at=None, + refresh_token_expires_at=None, + scopes=[], + metadata={ + "mcp_token_url": "https://auth.sentry.io/token", + "mcp_resource_url": "https://mcp.sentry.dev/mcp", + }, + ) + + with ( + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler, + ): + mock_settings.config.frontend_base_url = "http://localhost:3000" + + # Mock state verification + mock_state = AsyncMock() + mock_state.state_metadata = { + "authorize_url": "https://auth.sentry.io/authorize", + "token_url": "https://auth.sentry.io/token", + "client_id": "test-client-id", + "client_secret": "test-secret", + "server_url": "https://mcp.sentry.dev/mcp", + } + mock_state.scopes = ["openid"] + mock_state.code_verifier = "verifier-123" + mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state) + mock_cm.create = AsyncMock() + + handler_instance = MockHandler.return_value + handler_instance.exchange_code_for_tokens = AsyncMock( + return_value=mock_creds + ) + + # Mock old credential cleanup + mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[]) + + response = await client.post( + "/oauth/callback", + json={"code": "auth-code-abc", "state_token": "state-token-123"}, + ) + + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert data["provider"] == "mcp" + assert data["type"] == "oauth2" + mock_cm.create.assert_called_once() + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_callback_invalid_state(self, client): + with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm: + mock_cm.store.verify_state_token = AsyncMock(return_value=None) + + response = await client.post( + "/oauth/callback", + json={"code": "auth-code", "state_token": "bad-state"}, + ) + + assert response.status_code == 400 + assert "Invalid or expired" in response.json()["detail"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_oauth_callback_token_exchange_fails(self, client): + with ( + patch("backend.api.features.mcp.routes.creds_manager") as mock_cm, + patch("backend.api.features.mcp.routes.settings") as mock_settings, + patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler, + ): + mock_settings.config.frontend_base_url = "http://localhost:3000" + mock_state = AsyncMock() + mock_state.state_metadata = { + "authorize_url": "https://auth.example.com/authorize", + "token_url": "https://auth.example.com/token", + "client_id": "cid", + "server_url": "https://mcp.example.com/mcp", + } + mock_state.scopes = [] + mock_state.code_verifier = "v" + mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state) + + handler_instance = MockHandler.return_value + handler_instance.exchange_code_for_tokens = AsyncMock( + side_effect=RuntimeError("Token exchange failed") + ) + + response = await client.post( + "/oauth/callback", + json={"code": "bad-code", "state_token": "state"}, + ) + + assert response.status_code == 400 + assert "token exchange failed" in response.json()["detail"].lower() diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index 7220845679..f37f28dd7c 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -26,6 +26,7 @@ import backend.api.features.executions.review.routes import backend.api.features.library.db import backend.api.features.library.model import backend.api.features.library.routes +import backend.api.features.mcp.routes as mcp_routes import backend.api.features.oauth import backend.api.features.otto.routes import backend.api.features.postmark.postmark @@ -343,6 +344,11 @@ app.include_router( tags=["workspace"], prefix="/api/workspace", ) +app.include_router( + mcp_routes.router, + tags=["v2", "mcp"], + prefix="/api/mcp", +) app.include_router( backend.api.features.oauth.router, tags=["oauth"], diff --git a/autogpt_platform/backend/backend/blocks/_base.py b/autogpt_platform/backend/backend/blocks/_base.py index 0ba4daec40..632c5e43b9 100644 --- a/autogpt_platform/backend/backend/blocks/_base.py +++ b/autogpt_platform/backend/backend/blocks/_base.py @@ -64,6 +64,7 @@ class BlockType(Enum): AI = "AI" AYRSHARE = "Ayrshare" HUMAN_IN_THE_LOOP = "Human In The Loop" + MCP_TOOL = "MCP Tool" class BlockCategory(Enum): diff --git a/autogpt_platform/backend/backend/blocks/mcp/__init__.py b/autogpt_platform/backend/backend/blocks/mcp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/autogpt_platform/backend/backend/blocks/mcp/block.py b/autogpt_platform/backend/backend/blocks/mcp/block.py new file mode 100644 index 0000000000..9e3056d928 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/block.py @@ -0,0 +1,300 @@ +""" +MCP (Model Context Protocol) Tool Block. + +A single dynamic block that can connect to any MCP server, discover available tools, +and execute them. Works like AgentExecutorBlock — the user selects a tool from a +dropdown and the input/output schema adapts dynamically. +""" + +import json +import logging +from typing import Any, Literal + +from pydantic import SecretStr + +from backend.blocks._base import ( + Block, + BlockCategory, + BlockSchemaInput, + BlockSchemaOutput, + BlockType, +) +from backend.blocks.mcp.client import MCPClient, MCPClientError +from backend.data.block import BlockInput, BlockOutput +from backend.data.model import ( + CredentialsField, + CredentialsMetaInput, + OAuth2Credentials, + SchemaField, +) +from backend.integrations.providers import ProviderName +from backend.util.json import validate_with_jsonschema + +logger = logging.getLogger(__name__) + +TEST_CREDENTIALS = OAuth2Credentials( + id="test-mcp-cred", + provider="mcp", + access_token=SecretStr("mock-mcp-token"), + refresh_token=SecretStr("mock-refresh"), + scopes=[], + title="Mock MCP credential", +) +TEST_CREDENTIALS_INPUT = { + "provider": TEST_CREDENTIALS.provider, + "id": TEST_CREDENTIALS.id, + "type": TEST_CREDENTIALS.type, + "title": TEST_CREDENTIALS.title, +} + + +MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]] + + +class MCPToolBlock(Block): + """ + A block that connects to an MCP server, lets the user pick a tool, + and executes it with dynamic input/output schema. + + The flow: + 1. User provides an MCP server URL (and optional credentials) + 2. Frontend calls the backend to get tool list from that URL + 3. User selects a tool from a dropdown (available_tools) + 4. The block's input schema updates to reflect the selected tool's parameters + 5. On execution, the block calls the MCP server to run the tool + """ + + class Input(BlockSchemaInput): + server_url: str = SchemaField( + description="URL of the MCP server (Streamable HTTP endpoint)", + placeholder="https://mcp.example.com/mcp", + ) + credentials: MCPCredentials = CredentialsField( + discriminator="server_url", + description="MCP server OAuth credentials", + default={}, + ) + selected_tool: str = SchemaField( + description="The MCP tool to execute", + placeholder="Select a tool", + default="", + ) + tool_input_schema: dict[str, Any] = SchemaField( + description="JSON Schema for the selected tool's input parameters. " + "Populated automatically when a tool is selected.", + default={}, + hidden=True, + ) + + tool_arguments: dict[str, Any] = SchemaField( + description="Arguments to pass to the selected MCP tool. " + "The fields here are defined by the tool's input schema.", + default={}, + ) + + @classmethod + def get_input_schema(cls, data: BlockInput) -> dict[str, Any]: + """Return the tool's input schema so the builder UI renders dynamic fields.""" + return data.get("tool_input_schema", {}) + + @classmethod + def get_input_defaults(cls, data: BlockInput) -> BlockInput: + """Return the current tool_arguments as defaults for the dynamic fields.""" + return data.get("tool_arguments", {}) + + @classmethod + def get_missing_input(cls, data: BlockInput) -> set[str]: + """Check which required tool arguments are missing.""" + required_fields = cls.get_input_schema(data).get("required", []) + tool_arguments = data.get("tool_arguments", {}) + return set(required_fields) - set(tool_arguments) + + @classmethod + def get_mismatch_error(cls, data: BlockInput) -> str | None: + """Validate tool_arguments against the tool's input schema.""" + tool_schema = cls.get_input_schema(data) + if not tool_schema: + return None + tool_arguments = data.get("tool_arguments", {}) + return validate_with_jsonschema(tool_schema, tool_arguments) + + class Output(BlockSchemaOutput): + result: Any = SchemaField(description="The result returned by the MCP tool") + error: str = SchemaField(description="Error message if the tool call failed") + + def __init__(self): + super().__init__( + id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4", + description="Connect to any MCP server and execute its tools. " + "Provide a server URL, select a tool, and pass arguments dynamically.", + categories={BlockCategory.DEVELOPER_TOOLS}, + input_schema=MCPToolBlock.Input, + output_schema=MCPToolBlock.Output, + block_type=BlockType.MCP_TOOL, + test_credentials=TEST_CREDENTIALS, + test_input={ + "server_url": "https://mcp.example.com/mcp", + "credentials": TEST_CREDENTIALS_INPUT, + "selected_tool": "get_weather", + "tool_input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + "tool_arguments": {"city": "London"}, + }, + test_output=[ + ( + "result", + {"weather": "sunny", "temperature": 20}, + ), + ], + test_mock={ + "_call_mcp_tool": lambda *a, **kw: { + "weather": "sunny", + "temperature": 20, + }, + }, + ) + + async def _call_mcp_tool( + self, + server_url: str, + tool_name: str, + arguments: dict[str, Any], + auth_token: str | None = None, + ) -> Any: + """Call a tool on the MCP server. Extracted for easy mocking in tests.""" + client = MCPClient(server_url, auth_token=auth_token) + await client.initialize() + result = await client.call_tool(tool_name, arguments) + + if result.is_error: + error_text = "" + for item in result.content: + if item.get("type") == "text": + error_text += item.get("text", "") + raise MCPClientError( + f"MCP tool '{tool_name}' returned an error: " + f"{error_text or 'Unknown error'}" + ) + + # Extract text content from the result + output_parts = [] + for item in result.content: + if item.get("type") == "text": + text = item.get("text", "") + # Try to parse as JSON for structured output + try: + output_parts.append(json.loads(text)) + except (json.JSONDecodeError, ValueError): + output_parts.append(text) + elif item.get("type") == "image": + output_parts.append( + { + "type": "image", + "data": item.get("data"), + "mimeType": item.get("mimeType"), + } + ) + elif item.get("type") == "resource": + output_parts.append(item.get("resource", {})) + + # If single result, unwrap + if len(output_parts) == 1: + return output_parts[0] + return output_parts if output_parts else None + + @staticmethod + async def _auto_lookup_credential( + user_id: str, server_url: str + ) -> "OAuth2Credentials | None": + """Auto-lookup stored MCP credential for a server URL. + + This is a fallback for nodes that don't have ``credentials`` explicitly + set (e.g. nodes created before the credential field was wired up). + """ + from backend.integrations.creds_manager import IntegrationCredentialsManager + from backend.integrations.providers import ProviderName + + try: + mgr = IntegrationCredentialsManager() + mcp_creds = await mgr.store.get_creds_by_provider( + user_id, ProviderName.MCP.value + ) + best: OAuth2Credentials | None = None + for cred in mcp_creds: + if ( + isinstance(cred, OAuth2Credentials) + and (cred.metadata or {}).get("mcp_server_url") == server_url + ): + if best is None or ( + (cred.access_token_expires_at or 0) + > (best.access_token_expires_at or 0) + ): + best = cred + if best: + best = await mgr.refresh_if_needed(user_id, best) + logger.info( + "Auto-resolved MCP credential %s for %s", best.id, server_url + ) + return best + except Exception: + logger.warning("Auto-lookup MCP credential failed", exc_info=True) + return None + + async def run( + self, + input_data: Input, + *, + user_id: str, + credentials: OAuth2Credentials | None = None, + **kwargs, + ) -> BlockOutput: + if not input_data.server_url: + yield "error", "MCP server URL is required" + return + + if not input_data.selected_tool: + yield "error", "No tool selected. Please select a tool from the dropdown." + return + + # Validate required tool arguments before calling the server. + # The executor-level validation is bypassed for MCP blocks because + # get_input_defaults() flattens tool_arguments, stripping tool_input_schema + # from the validation context. + required = set(input_data.tool_input_schema.get("required", [])) + if required: + missing = required - set(input_data.tool_arguments.keys()) + if missing: + yield "error", ( + f"Missing required argument(s): {', '.join(sorted(missing))}. " + f"Please fill in all required fields marked with * in the block form." + ) + return + + # If no credentials were injected by the executor (e.g. legacy nodes + # that don't have the credentials field set), try to auto-lookup + # the stored MCP credential for this server URL. + if credentials is None: + credentials = await self._auto_lookup_credential( + user_id, input_data.server_url + ) + + auth_token = ( + credentials.access_token.get_secret_value() if credentials else None + ) + + try: + result = await self._call_mcp_tool( + server_url=input_data.server_url, + tool_name=input_data.selected_tool, + arguments=input_data.tool_arguments, + auth_token=auth_token, + ) + yield "result", result + except MCPClientError as e: + yield "error", str(e) + except Exception as e: + logger.exception(f"MCP tool call failed: {e}") + yield "error", f"MCP tool call failed: {str(e)}" diff --git a/autogpt_platform/backend/backend/blocks/mcp/client.py b/autogpt_platform/backend/backend/blocks/mcp/client.py new file mode 100644 index 0000000000..050349dbcc --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/client.py @@ -0,0 +1,323 @@ +""" +MCP (Model Context Protocol) HTTP client. + +Implements the MCP Streamable HTTP transport for listing tools and calling tools +on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST. + +Handles both JSON and SSE (text/event-stream) response formats per the MCP spec. + +Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports +""" + +import json +import logging +from dataclasses import dataclass, field +from typing import Any + +from backend.util.request import Requests + +logger = logging.getLogger(__name__) + + +@dataclass +class MCPTool: + """Represents an MCP tool discovered from a server.""" + + name: str + description: str + input_schema: dict[str, Any] + + +@dataclass +class MCPCallResult: + """Result from calling an MCP tool.""" + + content: list[dict[str, Any]] = field(default_factory=list) + is_error: bool = False + + +class MCPClientError(Exception): + """Raised when an MCP protocol error occurs.""" + + pass + + +class MCPClient: + """ + Async HTTP client for the MCP Streamable HTTP transport. + + Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST. + Supports optional Bearer token authentication. + """ + + def __init__( + self, + server_url: str, + auth_token: str | None = None, + ): + self.server_url = server_url.rstrip("/") + self.auth_token = auth_token + self._request_id = 0 + self._session_id: str | None = None + + def _next_id(self) -> int: + self._request_id += 1 + return self._request_id + + def _build_headers(self) -> dict[str, str]: + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + if self.auth_token: + headers["Authorization"] = f"Bearer {self.auth_token}" + if self._session_id: + headers["Mcp-Session-Id"] = self._session_id + return headers + + def _build_jsonrpc_request( + self, method: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: + req: dict[str, Any] = { + "jsonrpc": "2.0", + "method": method, + "id": self._next_id(), + } + if params is not None: + req["params"] = params + return req + + @staticmethod + def _parse_sse_response(text: str) -> dict[str, Any]: + """Parse an SSE (text/event-stream) response body into JSON-RPC data. + + MCP servers may return responses as SSE with format: + event: message + data: {"jsonrpc":"2.0","result":{...},"id":1} + + We extract the last `data:` line that contains a JSON-RPC response + (i.e. has an "id" field), which is the reply to our request. + """ + last_data: dict[str, Any] | None = None + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("data:"): + payload = stripped[len("data:") :].strip() + if not payload: + continue + try: + parsed = json.loads(payload) + # Only keep JSON-RPC responses (have "id"), skip notifications + if isinstance(parsed, dict) and "id" in parsed: + last_data = parsed + except (json.JSONDecodeError, ValueError): + continue + if last_data is None: + raise MCPClientError("No JSON-RPC response found in SSE stream") + return last_data + + async def _send_request( + self, method: str, params: dict[str, Any] | None = None + ) -> Any: + """Send a JSON-RPC request to the MCP server and return the result. + + Handles both ``application/json`` and ``text/event-stream`` responses + as required by the MCP Streamable HTTP transport specification. + """ + payload = self._build_jsonrpc_request(method, params) + headers = self._build_headers() + + requests = Requests( + raise_for_status=True, + extra_headers=headers, + ) + response = await requests.post(self.server_url, json=payload) + + # Capture session ID from response (MCP Streamable HTTP transport) + session_id = response.headers.get("Mcp-Session-Id") + if session_id: + self._session_id = session_id + + content_type = response.headers.get("content-type", "") + if "text/event-stream" in content_type: + body = self._parse_sse_response(response.text()) + else: + try: + body = response.json() + except Exception as e: + raise MCPClientError( + f"MCP server returned non-JSON response: {e}" + ) from e + + if not isinstance(body, dict): + raise MCPClientError( + f"MCP server returned unexpected JSON type: {type(body).__name__}" + ) + + # Handle JSON-RPC error + if "error" in body: + error = body["error"] + if isinstance(error, dict): + raise MCPClientError( + f"MCP server error [{error.get('code', '?')}]: " + f"{error.get('message', 'Unknown error')}" + ) + raise MCPClientError(f"MCP server error: {error}") + + return body.get("result") + + async def _send_notification(self, method: str) -> None: + """Send a JSON-RPC notification (no id, no response expected).""" + headers = self._build_headers() + notification = {"jsonrpc": "2.0", "method": method} + requests = Requests( + raise_for_status=False, + extra_headers=headers, + ) + await requests.post(self.server_url, json=notification) + + async def discover_auth(self) -> dict[str, Any] | None: + """Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec). + + Returns ``None`` if the server doesn't require auth, otherwise returns + a dict with: + - ``authorization_servers``: list of authorization server URLs + - ``resource``: the resource indicator URL (usually the MCP endpoint) + - ``scopes_supported``: optional list of supported scopes + + The caller can then fetch the authorization server metadata to get + ``authorization_endpoint``, ``token_endpoint``, etc. + """ + from urllib.parse import urlparse + + parsed = urlparse(self.server_url) + base = f"{parsed.scheme}://{parsed.netloc}" + + # Build candidates for protected-resource metadata (per RFC 9728) + path = parsed.path.rstrip("/") + candidates = [] + if path and path != "/": + candidates.append(f"{base}/.well-known/oauth-protected-resource{path}") + candidates.append(f"{base}/.well-known/oauth-protected-resource") + + requests = Requests( + raise_for_status=False, + ) + for url in candidates: + try: + resp = await requests.get(url) + if resp.status == 200: + data = resp.json() + if isinstance(data, dict) and "authorization_servers" in data: + return data + except Exception: + continue + + return None + + async def discover_auth_server_metadata( + self, auth_server_url: str + ) -> dict[str, Any] | None: + """Fetch the OAuth Authorization Server Metadata (RFC 8414). + + Given an authorization server URL, returns a dict with: + - ``authorization_endpoint`` + - ``token_endpoint`` + - ``registration_endpoint`` (for dynamic client registration) + - ``scopes_supported`` + - ``code_challenge_methods_supported`` + - etc. + """ + from urllib.parse import urlparse + + parsed = urlparse(auth_server_url) + base = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path.rstrip("/") + + # Try standard metadata endpoints (RFC 8414 and OpenID Connect) + candidates = [] + if path and path != "/": + candidates.append(f"{base}/.well-known/oauth-authorization-server{path}") + candidates.append(f"{base}/.well-known/oauth-authorization-server") + candidates.append(f"{base}/.well-known/openid-configuration") + + requests = Requests( + raise_for_status=False, + ) + for url in candidates: + try: + resp = await requests.get(url) + if resp.status == 200: + data = resp.json() + if isinstance(data, dict) and "authorization_endpoint" in data: + return data + except Exception: + continue + + return None + + async def initialize(self) -> dict[str, Any]: + """ + Send the MCP initialize request. + + This is required by the MCP protocol before any other requests. + Returns the server's capabilities. + """ + result = await self._send_request( + "initialize", + { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"}, + }, + ) + # Send initialized notification (no response expected) + await self._send_notification("notifications/initialized") + + return result or {} + + async def list_tools(self) -> list[MCPTool]: + """ + Discover available tools from the MCP server. + + Returns a list of MCPTool objects with name, description, and input schema. + """ + result = await self._send_request("tools/list") + if not result or "tools" not in result: + return [] + + tools = [] + for tool_data in result["tools"]: + tools.append( + MCPTool( + name=tool_data.get("name", ""), + description=tool_data.get("description", ""), + input_schema=tool_data.get("inputSchema", {}), + ) + ) + return tools + + async def call_tool( + self, tool_name: str, arguments: dict[str, Any] + ) -> MCPCallResult: + """ + Call a tool on the MCP server. + + Args: + tool_name: The name of the tool to call. + arguments: The arguments to pass to the tool. + + Returns: + MCPCallResult with the tool's response content. + """ + result = await self._send_request( + "tools/call", + {"name": tool_name, "arguments": arguments}, + ) + if not result: + return MCPCallResult(is_error=True) + + return MCPCallResult( + content=result.get("content", []), + is_error=result.get("isError", False), + ) diff --git a/autogpt_platform/backend/backend/blocks/mcp/oauth.py b/autogpt_platform/backend/backend/blocks/mcp/oauth.py new file mode 100644 index 0000000000..2228336cd3 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/oauth.py @@ -0,0 +1,204 @@ +""" +MCP OAuth handler for MCP servers that use OAuth 2.1 authorization. + +Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed, +MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata. +This handler accepts those endpoints at construction time. +""" + +import logging +import time +import urllib.parse +from typing import ClassVar, Optional + +from pydantic import SecretStr + +from backend.data.model import OAuth2Credentials +from backend.integrations.oauth.base import BaseOAuthHandler +from backend.integrations.providers import ProviderName +from backend.util.request import Requests + +logger = logging.getLogger(__name__) + + +class MCPOAuthHandler(BaseOAuthHandler): + """ + OAuth handler for MCP servers with dynamically-discovered endpoints. + + Construction requires the authorization and token endpoint URLs, + which are obtained via MCP OAuth metadata discovery + (``MCPClient.discover_auth`` + ``discover_auth_server_metadata``). + """ + + PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP + DEFAULT_SCOPES: ClassVar[list[str]] = [] + + def __init__( + self, + client_id: str, + client_secret: str, + redirect_uri: str, + *, + authorize_url: str, + token_url: str, + revoke_url: str | None = None, + resource_url: str | None = None, + ): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + self.authorize_url = authorize_url + self.token_url = token_url + self.revoke_url = revoke_url + self.resource_url = resource_url + + def get_login_url( + self, + scopes: list[str], + state: str, + code_challenge: Optional[str], + ) -> str: + scopes = self.handle_default_scopes(scopes) + + params: dict[str, str] = { + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "state": state, + } + if scopes: + params["scope"] = " ".join(scopes) + # PKCE (S256) — included when the caller provides a code_challenge + if code_challenge: + params["code_challenge"] = code_challenge + params["code_challenge_method"] = "S256" + # MCP spec requires resource indicator (RFC 8707) + if self.resource_url: + params["resource"] = self.resource_url + + return f"{self.authorize_url}?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_tokens( + self, + code: str, + scopes: list[str], + code_verifier: Optional[str], + ) -> OAuth2Credentials: + data: dict[str, str] = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self.redirect_uri, + "client_id": self.client_id, + } + if self.client_secret: + data["client_secret"] = self.client_secret + if code_verifier: + data["code_verifier"] = code_verifier + if self.resource_url: + data["resource"] = self.resource_url + + response = await Requests(raise_for_status=True).post( + self.token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + tokens = response.json() + + if "error" in tokens: + raise RuntimeError( + f"Token exchange failed: {tokens.get('error_description', tokens['error'])}" + ) + + if "access_token" not in tokens: + raise RuntimeError("OAuth token response missing 'access_token' field") + + now = int(time.time()) + expires_in = tokens.get("expires_in") + + return OAuth2Credentials( + provider=self.PROVIDER_NAME, + title=None, + access_token=SecretStr(tokens["access_token"]), + refresh_token=( + SecretStr(tokens["refresh_token"]) + if tokens.get("refresh_token") + else None + ), + access_token_expires_at=now + expires_in if expires_in else None, + refresh_token_expires_at=None, + scopes=scopes, + metadata={ + "mcp_token_url": self.token_url, + "mcp_resource_url": self.resource_url, + }, + ) + + async def _refresh_tokens( + self, credentials: OAuth2Credentials + ) -> OAuth2Credentials: + if not credentials.refresh_token: + raise ValueError("No refresh token available for MCP OAuth credentials") + + data: dict[str, str] = { + "grant_type": "refresh_token", + "refresh_token": credentials.refresh_token.get_secret_value(), + "client_id": self.client_id, + } + if self.client_secret: + data["client_secret"] = self.client_secret + if self.resource_url: + data["resource"] = self.resource_url + + response = await Requests(raise_for_status=True).post( + self.token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + tokens = response.json() + + if "error" in tokens: + raise RuntimeError( + f"Token refresh failed: {tokens.get('error_description', tokens['error'])}" + ) + + if "access_token" not in tokens: + raise RuntimeError("OAuth refresh response missing 'access_token' field") + + now = int(time.time()) + expires_in = tokens.get("expires_in") + + return OAuth2Credentials( + id=credentials.id, + provider=self.PROVIDER_NAME, + title=credentials.title, + access_token=SecretStr(tokens["access_token"]), + refresh_token=( + SecretStr(tokens["refresh_token"]) + if tokens.get("refresh_token") + else credentials.refresh_token + ), + access_token_expires_at=now + expires_in if expires_in else None, + refresh_token_expires_at=credentials.refresh_token_expires_at, + scopes=credentials.scopes, + metadata=credentials.metadata, + ) + + async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool: + if not self.revoke_url: + return False + + try: + data = { + "token": credentials.access_token.get_secret_value(), + "token_type_hint": "access_token", + "client_id": self.client_id, + } + await Requests().post( + self.revoke_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + return True + except Exception: + logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True) + return False diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py b/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py new file mode 100644 index 0000000000..7818fac9ce --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py @@ -0,0 +1,109 @@ +""" +End-to-end tests against a real public MCP server. + +These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp) +which is publicly accessible without authentication and returns SSE responses. + +Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped +independently of the rest of the test suite (they require network access). +""" + +import json +import os + +import pytest + +from backend.blocks.mcp.client import MCPClient + +# Public MCP server that requires no authentication +OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp" + +# Skip all tests in this module unless RUN_E2E env var is set +pytestmark = pytest.mark.skipif( + not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests" +) + + +class TestRealMCPServer: + """Tests against the live OpenAI docs MCP server.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_initialize(self): + """Verify we can complete the MCP handshake with a real server.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + assert "serverInfo" in result + assert result["serverInfo"]["name"] == "openai-docs-mcp" + assert "tools" in result.get("capabilities", {}) + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools(self): + """Verify we can discover tools from a real MCP server.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + tools = await client.list_tools() + + assert len(tools) >= 3 # server has at least 5 tools as of writing + + tool_names = {t.name for t in tools} + # These tools are documented and should be stable + assert "search_openai_docs" in tool_names + assert "list_openai_docs" in tool_names + assert "fetch_openai_doc" in tool_names + + # Verify schema structure + search_tool = next(t for t in tools if t.name == "search_openai_docs") + assert "query" in search_tool.input_schema.get("properties", {}) + assert "query" in search_tool.input_schema.get("required", []) + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_list_api_endpoints(self): + """Call the list_api_endpoints tool and verify we get real data.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + result = await client.call_tool("list_api_endpoints", {}) + + assert not result.is_error + assert len(result.content) >= 1 + assert result.content[0]["type"] == "text" + + data = json.loads(result.content[0]["text"]) + assert "paths" in data or "urls" in data + # The OpenAI API should have many endpoints + total = data.get("total", len(data.get("paths", []))) + assert total > 50 + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_search(self): + """Search for docs and verify we get results.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + result = await client.call_tool( + "search_openai_docs", {"query": "chat completions", "limit": 3} + ) + + assert not result.is_error + assert len(result.content) >= 1 + + @pytest.mark.asyncio(loop_scope="session") + async def test_sse_response_handling(self): + """Verify the client correctly handles SSE responses from a real server. + + This is the key test — our local test server returns JSON, + but real MCP servers typically return SSE. This proves the + SSE parsing works end-to-end. + """ + client = MCPClient(OPENAI_DOCS_MCP_URL) + # initialize() internally calls _send_request which must parse SSE + result = await client.initialize() + + # If we got here without error, SSE parsing works + assert isinstance(result, dict) + assert "protocolVersion" in result + + # Also verify list_tools works (another SSE response) + tools = await client.list_tools() + assert len(tools) > 0 + assert all(hasattr(t, "name") for t in tools) diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_integration.py b/autogpt_platform/backend/backend/blocks/mcp/test_integration.py new file mode 100644 index 0000000000..70658dbaaf --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_integration.py @@ -0,0 +1,389 @@ +""" +Integration tests for MCP client and MCPToolBlock against a real HTTP server. + +These tests spin up a local MCP test server and run the full client/block flow +against it — no mocking, real HTTP requests. +""" + +import asyncio +import json +import threading +from unittest.mock import patch + +import pytest +from aiohttp import web +from pydantic import SecretStr + +from backend.blocks.mcp.block import MCPToolBlock +from backend.blocks.mcp.client import MCPClient +from backend.blocks.mcp.test_server import create_test_mcp_app +from backend.data.model import OAuth2Credentials + +MOCK_USER_ID = "test-user-integration" + + +class _MCPTestServer: + """ + Run an MCP test server in a background thread with its own event loop. + This avoids event loop conflicts with pytest-asyncio. + """ + + def __init__(self, auth_token: str | None = None): + self.auth_token = auth_token + self.url: str = "" + self._runner: web.AppRunner | None = None + self._loop: asyncio.AbstractEventLoop | None = None + self._thread: threading.Thread | None = None + self._started = threading.Event() + + def _run(self): + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._loop.run_until_complete(self._start()) + self._started.set() + self._loop.run_forever() + + async def _start(self): + app = create_test_mcp_app(auth_token=self.auth_token) + self._runner = web.AppRunner(app) + await self._runner.setup() + site = web.TCPSite(self._runner, "127.0.0.1", 0) + await site.start() + port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr] + self.url = f"http://127.0.0.1:{port}/mcp" + + def start(self): + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + if not self._started.wait(timeout=5): + raise RuntimeError("MCP test server failed to start within 5 seconds") + return self + + def stop(self): + if self._loop and self._runner: + asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result( + timeout=5 + ) + self._loop.call_soon_threadsafe(self._loop.stop) + if self._thread: + self._thread.join(timeout=5) + + +@pytest.fixture(scope="module") +def mcp_server(): + """Start a local MCP test server in a background thread.""" + server = _MCPTestServer() + server.start() + yield server.url + server.stop() + + +@pytest.fixture(scope="module") +def mcp_server_with_auth(): + """Start a local MCP test server with auth in a background thread.""" + server = _MCPTestServer(auth_token="test-secret-token") + server.start() + yield server.url, "test-secret-token" + server.stop() + + +@pytest.fixture(autouse=True) +def _allow_localhost(): + """ + Allow 127.0.0.1 through SSRF protection for integration tests. + + The Requests class blocks private IPs by default. We patch the Requests + constructor to always include 127.0.0.1 as a trusted origin so the local + test server is reachable. + """ + from backend.util.request import Requests + + original_init = Requests.__init__ + + def patched_init(self, *args, **kwargs): + trusted = list(kwargs.get("trusted_origins") or []) + trusted.append("http://127.0.0.1") + kwargs["trusted_origins"] = trusted + original_init(self, *args, **kwargs) + + with patch.object(Requests, "__init__", patched_init): + yield + + +def _make_client(url: str, auth_token: str | None = None) -> MCPClient: + """Create an MCPClient for integration tests.""" + return MCPClient(url, auth_token=auth_token) + + +# ── MCPClient integration tests ────────────────────────────────────── + + +class TestMCPClientIntegration: + """Test MCPClient against a real local MCP server.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_initialize(self, mcp_server): + client = _make_client(mcp_server) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + assert result["serverInfo"]["name"] == "test-mcp-server" + assert "tools" in result["capabilities"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + tools = await client.list_tools() + + assert len(tools) == 3 + + tool_names = {t.name for t in tools} + assert tool_names == {"get_weather", "add_numbers", "echo"} + + # Check get_weather schema + weather = next(t for t in tools if t.name == "get_weather") + assert weather.description == "Get current weather for a city" + assert "city" in weather.input_schema["properties"] + assert weather.input_schema["required"] == ["city"] + + # Check add_numbers schema + add = next(t for t in tools if t.name == "add_numbers") + assert "a" in add.input_schema["properties"] + assert "b" in add.input_schema["properties"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_get_weather(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("get_weather", {"city": "London"}) + + assert not result.is_error + assert len(result.content) == 1 + assert result.content[0]["type"] == "text" + + data = json.loads(result.content[0]["text"]) + assert data["city"] == "London" + assert data["temperature"] == 22 + assert data["condition"] == "sunny" + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_add_numbers(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("add_numbers", {"a": 3, "b": 7}) + + assert not result.is_error + data = json.loads(result.content[0]["text"]) + assert data["result"] == 10 + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_echo(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("echo", {"message": "Hello MCP!"}) + + assert not result.is_error + assert result.content[0]["text"] == "Hello MCP!" + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_unknown_tool(self, mcp_server): + client = _make_client(mcp_server) + await client.initialize() + result = await client.call_tool("nonexistent_tool", {}) + + assert result.is_error + assert "Unknown tool" in result.content[0]["text"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_auth_success(self, mcp_server_with_auth): + url, token = mcp_server_with_auth + client = _make_client(url, auth_token=token) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + + tools = await client.list_tools() + assert len(tools) == 3 + + @pytest.mark.asyncio(loop_scope="session") + async def test_auth_failure(self, mcp_server_with_auth): + url, _ = mcp_server_with_auth + client = _make_client(url, auth_token="wrong-token") + + with pytest.raises(Exception): + await client.initialize() + + @pytest.mark.asyncio(loop_scope="session") + async def test_auth_missing(self, mcp_server_with_auth): + url, _ = mcp_server_with_auth + client = _make_client(url) + + with pytest.raises(Exception): + await client.initialize() + + +# ── MCPToolBlock integration tests ─────────────────────────────────── + + +class TestMCPToolBlockIntegration: + """Test MCPToolBlock end-to-end against a real local MCP server.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_get_weather(self, mcp_server): + """Full flow: discover tools, select one, execute it.""" + # Step 1: Discover tools (simulating what the frontend/API would do) + client = _make_client(mcp_server) + await client.initialize() + tools = await client.list_tools() + assert len(tools) == 3 + + # Step 2: User selects "get_weather" and we get its schema + weather_tool = next(t for t in tools if t.name == "get_weather") + + # Step 3: Execute the block — no credentials (public server) + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="get_weather", + tool_input_schema=weather_tool.input_schema, + tool_arguments={"city": "Paris"}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + result = outputs[0][1] + assert result["city"] == "Paris" + assert result["temperature"] == 22 + assert result["condition"] == "sunny" + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_add_numbers(self, mcp_server): + """Full flow for add_numbers tool.""" + client = _make_client(mcp_server) + await client.initialize() + tools = await client.list_tools() + add_tool = next(t for t in tools if t.name == "add_numbers") + + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="add_numbers", + tool_input_schema=add_tool.input_schema, + tool_arguments={"a": 42, "b": 58}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1]["result"] == 100 + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_echo_plain_text(self, mcp_server): + """Verify plain text (non-JSON) responses work.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="echo", + tool_input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + tool_arguments={"message": "Hello from AutoGPT!"}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == "Hello from AutoGPT!" + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_unknown_tool_yields_error(self, mcp_server): + """Calling an unknown tool should yield an error output.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="nonexistent_tool", + tool_arguments={}, + ) + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "error" + assert "returned an error" in outputs[0][1] + + @pytest.mark.asyncio(loop_scope="session") + async def test_full_flow_with_auth(self, mcp_server_with_auth): + """Full flow with authentication via credentials kwarg.""" + url, token = mcp_server_with_auth + + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=url, + selected_tool="echo", + tool_input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + tool_arguments={"message": "Authenticated!"}, + ) + + # Pass credentials via the standard kwarg (as the executor would) + test_creds = OAuth2Credentials( + id="test-cred", + provider="mcp", + access_token=SecretStr(token), + refresh_token=SecretStr(""), + scopes=[], + title="Test MCP credential", + ) + + outputs = [] + async for name, data in block.run( + input_data, user_id=MOCK_USER_ID, credentials=test_creds + ): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == "Authenticated!" + + @pytest.mark.asyncio(loop_scope="session") + async def test_no_credentials_runs_without_auth(self, mcp_server): + """Block runs without auth when no credentials are provided.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url=mcp_server, + selected_tool="echo", + tool_input_schema={ + "type": "object", + "properties": {"message": {"type": "string"}}, + "required": ["message"], + }, + tool_arguments={"message": "No auth needed"}, + ) + + outputs = [] + async for name, data in block.run( + input_data, user_id=MOCK_USER_ID, credentials=None + ): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == "No auth needed" diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py b/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py new file mode 100644 index 0000000000..8cb49b0fee --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py @@ -0,0 +1,619 @@ +""" +Tests for MCP client and MCPToolBlock. +""" + +import json +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.blocks.mcp.block import MCPToolBlock +from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError +from backend.util.test import execute_block_test + +# ── SSE parsing unit tests ─────────────────────────────────────────── + + +class TestSSEParsing: + """Tests for SSE (text/event-stream) response parsing.""" + + def test_parse_sse_simple(self): + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == {"tools": []} + assert body["id"] == 1 + + def test_parse_sse_with_notifications(self): + """SSE streams can contain notifications (no id) before the response.""" + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","method":"some/notification"}\n' + "\n" + "event: message\n" + 'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == {"ok": True} + assert body["id"] == 2 + + def test_parse_sse_error_response(self): + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n' + ) + body = MCPClient._parse_sse_response(sse) + assert "error" in body + assert body["error"]["code"] == -32600 + + def test_parse_sse_no_data_raises(self): + with pytest.raises(MCPClientError, match="No JSON-RPC response found"): + MCPClient._parse_sse_response("event: message\n\n") + + def test_parse_sse_empty_raises(self): + with pytest.raises(MCPClientError, match="No JSON-RPC response found"): + MCPClient._parse_sse_response("") + + def test_parse_sse_ignores_non_data_lines(self): + sse = ( + ": comment line\n" + "event: message\n" + "id: 123\n" + 'data: {"jsonrpc":"2.0","result":"ok","id":1}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == "ok" + + def test_parse_sse_uses_last_response(self): + """If multiple responses exist, use the last one.""" + sse = ( + 'data: {"jsonrpc":"2.0","result":"first","id":1}\n' + "\n" + 'data: {"jsonrpc":"2.0","result":"second","id":2}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == "second" + + +# ── MCPClient unit tests ───────────────────────────────────────────── + + +class TestMCPClient: + """Tests for the MCP HTTP client.""" + + def test_build_headers_without_auth(self): + client = MCPClient("https://mcp.example.com") + headers = client._build_headers() + assert "Authorization" not in headers + assert headers["Content-Type"] == "application/json" + + def test_build_headers_with_auth(self): + client = MCPClient("https://mcp.example.com", auth_token="my-token") + headers = client._build_headers() + assert headers["Authorization"] == "Bearer my-token" + + def test_build_jsonrpc_request(self): + client = MCPClient("https://mcp.example.com") + req = client._build_jsonrpc_request("tools/list") + assert req["jsonrpc"] == "2.0" + assert req["method"] == "tools/list" + assert "id" in req + assert "params" not in req + + def test_build_jsonrpc_request_with_params(self): + client = MCPClient("https://mcp.example.com") + req = client._build_jsonrpc_request( + "tools/call", {"name": "test", "arguments": {"x": 1}} + ) + assert req["params"] == {"name": "test", "arguments": {"x": 1}} + + def test_request_id_increments(self): + client = MCPClient("https://mcp.example.com") + req1 = client._build_jsonrpc_request("tools/list") + req2 = client._build_jsonrpc_request("tools/list") + assert req2["id"] > req1["id"] + + def test_server_url_trailing_slash_stripped(self): + client = MCPClient("https://mcp.example.com/mcp/") + assert client.server_url == "https://mcp.example.com/mcp" + + @pytest.mark.asyncio(loop_scope="session") + async def test_send_request_success(self): + client = MCPClient("https://mcp.example.com") + + mock_response = AsyncMock() + mock_response.json.return_value = { + "jsonrpc": "2.0", + "result": {"tools": []}, + "id": 1, + } + + with patch.object(client, "_send_request", return_value={"tools": []}): + result = await client._send_request("tools/list") + assert result == {"tools": []} + + @pytest.mark.asyncio(loop_scope="session") + async def test_send_request_error(self): + client = MCPClient("https://mcp.example.com") + + async def mock_send(*args, **kwargs): + raise MCPClientError("MCP server error [-32600]: Invalid Request") + + with patch.object(client, "_send_request", side_effect=mock_send): + with pytest.raises(MCPClientError, match="Invalid Request"): + await client._send_request("tools/list") + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "tools": [ + { + "name": "get_weather", + "description": "Get current weather for a city", + "inputSchema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + { + "name": "search", + "description": "Search the web", + "inputSchema": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, + ] + } + + with patch.object(client, "_send_request", return_value=mock_result): + tools = await client.list_tools() + + assert len(tools) == 2 + assert tools[0].name == "get_weather" + assert tools[0].description == "Get current weather for a city" + assert tools[0].input_schema["properties"]["city"]["type"] == "string" + assert tools[1].name == "search" + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools_empty(self): + client = MCPClient("https://mcp.example.com") + + with patch.object(client, "_send_request", return_value={"tools": []}): + tools = await client.list_tools() + + assert tools == [] + + @pytest.mark.asyncio(loop_scope="session") + async def test_list_tools_none_result(self): + client = MCPClient("https://mcp.example.com") + + with patch.object(client, "_send_request", return_value=None): + tools = await client.list_tools() + + assert tools == [] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_success(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "content": [ + {"type": "text", "text": json.dumps({"temp": 20, "city": "London"})} + ], + "isError": False, + } + + with patch.object(client, "_send_request", return_value=mock_result): + result = await client.call_tool("get_weather", {"city": "London"}) + + assert not result.is_error + assert len(result.content) == 1 + assert result.content[0]["type"] == "text" + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_error(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "content": [{"type": "text", "text": "City not found"}], + "isError": True, + } + + with patch.object(client, "_send_request", return_value=mock_result): + result = await client.call_tool("get_weather", {"city": "???"}) + + assert result.is_error + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_tool_none_result(self): + client = MCPClient("https://mcp.example.com") + + with patch.object(client, "_send_request", return_value=None): + result = await client.call_tool("get_weather", {"city": "London"}) + + assert result.is_error + + @pytest.mark.asyncio(loop_scope="session") + async def test_initialize(self): + client = MCPClient("https://mcp.example.com") + + mock_result = { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {}}, + "serverInfo": {"name": "test-server", "version": "1.0.0"}, + } + + with ( + patch.object(client, "_send_request", return_value=mock_result) as mock_req, + patch.object(client, "_send_notification") as mock_notif, + ): + result = await client.initialize() + + mock_req.assert_called_once() + mock_notif.assert_called_once_with("notifications/initialized") + assert result["protocolVersion"] == "2025-03-26" + + +# ── MCPToolBlock unit tests ────────────────────────────────────────── + +MOCK_USER_ID = "test-user-123" + + +class TestMCPToolBlock: + """Tests for the MCPToolBlock.""" + + def test_block_instantiation(self): + block = MCPToolBlock() + assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4" + assert block.name == "MCPToolBlock" + + def test_input_schema_has_required_fields(self): + block = MCPToolBlock() + schema = block.input_schema.jsonschema() + props = schema.get("properties", {}) + assert "server_url" in props + assert "selected_tool" in props + assert "tool_arguments" in props + assert "credentials" in props + + def test_output_schema(self): + block = MCPToolBlock() + schema = block.output_schema.jsonschema() + props = schema.get("properties", {}) + assert "result" in props + assert "error" in props + + def test_get_input_schema_with_tool_schema(self): + tool_schema = { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + } + data = {"tool_input_schema": tool_schema} + result = MCPToolBlock.Input.get_input_schema(data) + assert result == tool_schema + + def test_get_input_schema_without_tool_schema(self): + result = MCPToolBlock.Input.get_input_schema({}) + assert result == {} + + def test_get_input_defaults(self): + data = {"tool_arguments": {"city": "London"}} + result = MCPToolBlock.Input.get_input_defaults(data) + assert result == {"city": "London"} + + def test_get_missing_input(self): + data = { + "tool_input_schema": { + "type": "object", + "properties": { + "city": {"type": "string"}, + "units": {"type": "string"}, + }, + "required": ["city", "units"], + }, + "tool_arguments": {"city": "London"}, + } + missing = MCPToolBlock.Input.get_missing_input(data) + assert missing == {"units"} + + def test_get_missing_input_all_present(self): + data = { + "tool_input_schema": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + "tool_arguments": {"city": "London"}, + } + missing = MCPToolBlock.Input.get_missing_input(data) + assert missing == set() + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_with_mock(self): + """Test the block using the built-in test infrastructure.""" + block = MCPToolBlock() + await execute_block_test(block) + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_missing_server_url(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="", + selected_tool="test", + ) + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + assert outputs == [("error", "MCP server URL is required")] + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_missing_tool(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="", + ) + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + assert outputs == [ + ("error", "No tool selected. Please select a tool from the dropdown.") + ] + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_success(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="get_weather", + tool_input_schema={ + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + tool_arguments={"city": "London"}, + ) + + async def mock_call(*args, **kwargs): + return {"temp": 20, "city": "London"} + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert len(outputs) == 1 + assert outputs[0][0] == "result" + assert outputs[0][1] == {"temp": 20, "city": "London"} + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_mcp_error(self): + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="bad_tool", + ) + + async def mock_call(*args, **kwargs): + raise MCPClientError("Tool not found") + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert outputs[0][0] == "error" + assert "Tool not found" in outputs[0][1] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_parses_json_text(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + {"type": "text", "text": '{"temp": 20}'}, + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == {"temp": 20} + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_plain_text(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + {"type": "text", "text": "Hello, world!"}, + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == "Hello, world!" + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_multiple_content(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + {"type": "text", "text": "Part 1"}, + {"type": "text", "text": '{"part": 2}'}, + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == ["Part 1", {"part": 2}] + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_error_result(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[{"type": "text", "text": "Something went wrong"}], + is_error=True, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + with pytest.raises(MCPClientError, match="returned an error"): + await block._call_mcp_tool("https://mcp.example.com", "test_tool", {}) + + @pytest.mark.asyncio(loop_scope="session") + async def test_call_mcp_tool_image_content(self): + block = MCPToolBlock() + + mock_result = MCPCallResult( + content=[ + { + "type": "image", + "data": "base64data==", + "mimeType": "image/png", + } + ], + is_error=False, + ) + + async def mock_init(self): + return {} + + async def mock_call(self, name, args): + return mock_result + + with ( + patch.object(MCPClient, "initialize", mock_init), + patch.object(MCPClient, "call_tool", mock_call), + ): + result = await block._call_mcp_tool( + "https://mcp.example.com", "test_tool", {} + ) + + assert result == { + "type": "image", + "data": "base64data==", + "mimeType": "image/png", + } + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_with_credentials(self): + """Verify the block uses OAuth2Credentials and passes auth token.""" + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="test_tool", + ) + + captured_tokens: list[str | None] = [] + + async def mock_call(server_url, tool_name, arguments, auth_token=None): + captured_tokens.append(auth_token) + return "ok" + + block._call_mcp_tool = mock_call # type: ignore + + test_creds = OAuth2Credentials( + id="cred-123", + provider="mcp", + access_token=SecretStr("resolved-token"), + refresh_token=SecretStr(""), + scopes=[], + title="Test MCP credential", + ) + + async for _ in block.run( + input_data, user_id=MOCK_USER_ID, credentials=test_creds + ): + pass + + assert captured_tokens == ["resolved-token"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_run_without_credentials(self): + """Verify the block works without credentials (public server).""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="test_tool", + ) + + captured_tokens: list[str | None] = [] + + async def mock_call(server_url, tool_name, arguments, auth_token=None): + captured_tokens.append(auth_token) + return "ok" + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, user_id=MOCK_USER_ID): + outputs.append((name, data)) + + assert captured_tokens == [None] + assert outputs == [("result", "ok")] diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py b/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py new file mode 100644 index 0000000000..e9a42f68ea --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py @@ -0,0 +1,242 @@ +""" +Tests for MCP OAuth handler. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import SecretStr + +from backend.blocks.mcp.client import MCPClient +from backend.blocks.mcp.oauth import MCPOAuthHandler +from backend.data.model import OAuth2Credentials + + +def _mock_response(json_data: dict, status: int = 200) -> MagicMock: + """Create a mock Response with synchronous json() (matching Requests.Response).""" + resp = MagicMock() + resp.status = status + resp.ok = 200 <= status < 300 + resp.json.return_value = json_data + return resp + + +class TestMCPOAuthHandler: + """Tests for the MCPOAuthHandler.""" + + def _make_handler(self, **overrides) -> MCPOAuthHandler: + defaults = { + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "redirect_uri": "https://app.example.com/callback", + "authorize_url": "https://auth.example.com/authorize", + "token_url": "https://auth.example.com/token", + } + defaults.update(overrides) + return MCPOAuthHandler(**defaults) + + def test_get_login_url_basic(self): + handler = self._make_handler() + url = handler.get_login_url( + scopes=["read", "write"], + state="random-state-token", + code_challenge="S256-challenge-value", + ) + + assert "https://auth.example.com/authorize?" in url + assert "response_type=code" in url + assert "client_id=test-client-id" in url + assert "state=random-state-token" in url + assert "code_challenge=S256-challenge-value" in url + assert "code_challenge_method=S256" in url + assert "scope=read+write" in url + + def test_get_login_url_with_resource(self): + handler = self._make_handler(resource_url="https://mcp.example.com/mcp") + url = handler.get_login_url( + scopes=[], state="state", code_challenge="challenge" + ) + + assert "resource=https" in url + + def test_get_login_url_without_pkce(self): + handler = self._make_handler() + url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None) + + assert "code_challenge" not in url + assert "code_challenge_method" not in url + + @pytest.mark.asyncio(loop_scope="session") + async def test_exchange_code_for_tokens(self): + handler = self._make_handler() + + resp = _mock_response( + { + "access_token": "new-access-token", + "refresh_token": "new-refresh-token", + "expires_in": 3600, + "token_type": "Bearer", + } + ) + + with patch("backend.blocks.mcp.oauth.Requests") as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + creds = await handler.exchange_code_for_tokens( + code="auth-code", + scopes=["read"], + code_verifier="pkce-verifier", + ) + + assert isinstance(creds, OAuth2Credentials) + assert creds.access_token.get_secret_value() == "new-access-token" + assert creds.refresh_token is not None + assert creds.refresh_token.get_secret_value() == "new-refresh-token" + assert creds.scopes == ["read"] + assert creds.access_token_expires_at is not None + + @pytest.mark.asyncio(loop_scope="session") + async def test_refresh_tokens(self): + handler = self._make_handler() + + existing_creds = OAuth2Credentials( + id="existing-id", + provider="mcp", + access_token=SecretStr("old-token"), + refresh_token=SecretStr("old-refresh"), + scopes=["read"], + title="test", + ) + + resp = _mock_response( + { + "access_token": "refreshed-token", + "refresh_token": "new-refresh", + "expires_in": 3600, + } + ) + + with patch("backend.blocks.mcp.oauth.Requests") as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + refreshed = await handler._refresh_tokens(existing_creds) + + assert refreshed.id == "existing-id" + assert refreshed.access_token.get_secret_value() == "refreshed-token" + assert refreshed.refresh_token is not None + assert refreshed.refresh_token.get_secret_value() == "new-refresh" + + @pytest.mark.asyncio(loop_scope="session") + async def test_refresh_tokens_no_refresh_token(self): + handler = self._make_handler() + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=["read"], + title="test", + ) + + with pytest.raises(ValueError, match="No refresh token"): + await handler._refresh_tokens(creds) + + @pytest.mark.asyncio(loop_scope="session") + async def test_revoke_tokens_no_url(self): + handler = self._make_handler(revoke_url=None) + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=[], + title="test", + ) + + result = await handler.revoke_tokens(creds) + assert result is False + + @pytest.mark.asyncio(loop_scope="session") + async def test_revoke_tokens_with_url(self): + handler = self._make_handler(revoke_url="https://auth.example.com/revoke") + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=[], + title="test", + ) + + resp = _mock_response({}, status=200) + + with patch("backend.blocks.mcp.oauth.Requests") as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + result = await handler.revoke_tokens(creds) + + assert result is True + + +class TestMCPClientDiscovery: + """Tests for MCPClient OAuth metadata discovery.""" + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_auth_found(self): + client = MCPClient("https://mcp.example.com/mcp") + + metadata = { + "authorization_servers": ["https://auth.example.com"], + "resource": "https://mcp.example.com/mcp", + } + + resp = _mock_response(metadata, status=200) + + with patch("backend.blocks.mcp.client.Requests") as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth() + + assert result is not None + assert result["authorization_servers"] == ["https://auth.example.com"] + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_auth_not_found(self): + client = MCPClient("https://mcp.example.com/mcp") + + resp = _mock_response({}, status=404) + + with patch("backend.blocks.mcp.client.Requests") as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth() + + assert result is None + + @pytest.mark.asyncio(loop_scope="session") + async def test_discover_auth_server_metadata(self): + client = MCPClient("https://mcp.example.com/mcp") + + server_metadata = { + "issuer": "https://auth.example.com", + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "registration_endpoint": "https://auth.example.com/register", + "code_challenge_methods_supported": ["S256"], + } + + resp = _mock_response(server_metadata, status=200) + + with patch("backend.blocks.mcp.client.Requests") as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth_server_metadata( + "https://auth.example.com" + ) + + assert result is not None + assert result["authorization_endpoint"] == "https://auth.example.com/authorize" + assert result["token_endpoint"] == "https://auth.example.com/token" diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_server.py b/autogpt_platform/backend/backend/blocks/mcp/test_server.py new file mode 100644 index 0000000000..a6732932bc --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_server.py @@ -0,0 +1,162 @@ +""" +Minimal MCP server for integration testing. + +Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST) +with a few sample tools. Runs on localhost with a random available port. +""" + +import json +import logging + +from aiohttp import web + +logger = logging.getLogger(__name__) + +# Sample tools this test server exposes +TEST_TOOLS = [ + { + "name": "get_weather", + "description": "Get current weather for a city", + "inputSchema": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name", + }, + }, + "required": ["city"], + }, + }, + { + "name": "add_numbers", + "description": "Add two numbers together", + "inputSchema": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"}, + }, + "required": ["a", "b"], + }, + }, + { + "name": "echo", + "description": "Echo back the input message", + "inputSchema": { + "type": "object", + "properties": { + "message": {"type": "string", "description": "Message to echo"}, + }, + "required": ["message"], + }, + }, +] + + +def _handle_initialize(params: dict) -> dict: + return { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {"listChanged": False}}, + "serverInfo": {"name": "test-mcp-server", "version": "1.0.0"}, + } + + +def _handle_tools_list(params: dict) -> dict: + return {"tools": TEST_TOOLS} + + +def _handle_tools_call(params: dict) -> dict: + tool_name = params.get("name", "") + arguments = params.get("arguments", {}) + + if tool_name == "get_weather": + city = arguments.get("city", "Unknown") + return { + "content": [ + { + "type": "text", + "text": json.dumps( + {"city": city, "temperature": 22, "condition": "sunny"} + ), + } + ], + } + + elif tool_name == "add_numbers": + a = arguments.get("a", 0) + b = arguments.get("b", 0) + return { + "content": [{"type": "text", "text": json.dumps({"result": a + b})}], + } + + elif tool_name == "echo": + message = arguments.get("message", "") + return { + "content": [{"type": "text", "text": message}], + } + + else: + return { + "content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}], + "isError": True, + } + + +HANDLERS = { + "initialize": _handle_initialize, + "tools/list": _handle_tools_list, + "tools/call": _handle_tools_call, +} + + +async def handle_mcp_request(request: web.Request) -> web.Response: + """Handle incoming MCP JSON-RPC 2.0 requests.""" + # Check auth if configured + expected_token = request.app.get("auth_token") + if expected_token: + auth_header = request.headers.get("Authorization", "") + if auth_header != f"Bearer {expected_token}": + return web.json_response( + { + "jsonrpc": "2.0", + "error": {"code": -32001, "message": "Unauthorized"}, + "id": None, + }, + status=401, + ) + + body = await request.json() + + # Handle notifications (no id field) — just acknowledge + if "id" not in body: + return web.Response(status=202) + + method = body.get("method", "") + params = body.get("params", {}) + request_id = body.get("id") + + handler = HANDLERS.get(method) + if not handler: + return web.json_response( + { + "jsonrpc": "2.0", + "error": { + "code": -32601, + "message": f"Method not found: {method}", + }, + "id": request_id, + } + ) + + result = handler(params) + return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id}) + + +def create_test_mcp_app(auth_token: str | None = None) -> web.Application: + """Create an aiohttp app that acts as an MCP server.""" + app = web.Application() + app.router.add_post("/mcp", handle_mcp_request) + if auth_token: + app["auth_token"] = auth_token + return app diff --git a/autogpt_platform/backend/backend/copilot/executor/processor.py b/autogpt_platform/backend/backend/copilot/executor/processor.py index 5da9077341..eb941d5efd 100644 --- a/autogpt_platform/backend/backend/copilot/executor/processor.py +++ b/autogpt_platform/backend/backend/copilot/executor/processor.py @@ -11,9 +11,12 @@ import time from backend.copilot import service as copilot_service from backend.copilot import stream_registry +from backend.copilot.config import ChatConfig from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep +from backend.copilot.sdk import service as sdk_service from backend.executor.cluster_lock import ClusterLock from backend.util.decorator import error_logged +from backend.util.feature_flag import Flag, is_feature_enabled from backend.util.logging import TruncatedLogger, configure_logging from backend.util.process import set_service_name from backend.util.retry import func_retry @@ -177,14 +180,27 @@ class CoPilotProcessor: refresh_interval = 30.0 # Refresh lock every 30 seconds try: + # Choose service based on LaunchDarkly flag + config = ChatConfig() + use_sdk = await is_feature_enabled( + Flag.COPILOT_SDK, + entry.user_id or "anonymous", + default=config.use_claude_agent_sdk, + ) + stream_fn = ( + sdk_service.stream_chat_completion_sdk + if use_sdk + else copilot_service.stream_chat_completion + ) + log.info(f"Using {'SDK' if use_sdk else 'standard'} service") + # Stream chat completion and publish chunks to Redis - async for chunk in copilot_service.stream_chat_completion( + async for chunk in stream_fn( session_id=entry.session_id, message=entry.message if entry.message else None, is_user_message=entry.is_user_message, user_id=entry.user_id, context=entry.context, - _task_id=entry.task_id, ): # Check for cancellation if cancel.is_set(): diff --git a/autogpt_platform/backend/backend/copilot/tools/feature_requests.py b/autogpt_platform/backend/backend/copilot/tools/feature_requests.py index ebfc37f475..a4fb070eb6 100644 --- a/autogpt_platform/backend/backend/copilot/tools/feature_requests.py +++ b/autogpt_platform/backend/backend/copilot/tools/feature_requests.py @@ -7,8 +7,8 @@ from pydantic import SecretStr from backend.blocks.linear._api import LinearClient from backend.copilot.model import ChatSession +from backend.data.db_accessors import user_db from backend.data.model import APIKeyCredentials -from backend.data.user import get_user_email_by_id from backend.util.settings import Settings from .base import BaseTool @@ -333,7 +333,9 @@ class CreateFeatureRequestTool(BaseTool): # Resolve a human-readable name (email) for the Linear customer record. # Fall back to user_id if the lookup fails or returns None. try: - customer_display_name = await get_user_email_by_id(user_id) or user_id + customer_display_name = ( + await user_db().get_user_email_by_id(user_id) or user_id + ) except Exception: customer_display_name = user_id diff --git a/autogpt_platform/backend/backend/copilot/tools/utils.py b/autogpt_platform/backend/backend/copilot/tools/utils.py index b200016b02..60747566a6 100644 --- a/autogpt_platform/backend/backend/copilot/tools/utils.py +++ b/autogpt_platform/backend/backend/copilot/tools/utils.py @@ -14,6 +14,7 @@ from backend.data.model import ( OAuth2Credentials, ) from backend.integrations.creds_manager import IntegrationCredentialsManager +from backend.integrations.providers import ProviderName from backend.util.exceptions import NotFoundError logger = logging.getLogger(__name__) @@ -359,7 +360,7 @@ async def match_user_credentials_to_graph( _, _, ) in aggregated_creds.items(): - # Find first matching credential by provider, type, and scopes + # Find first matching credential by provider, type, scopes, and host/URL matching_cred = next( ( cred @@ -374,6 +375,10 @@ async def match_user_credentials_to_graph( cred.type != "host_scoped" or _credential_is_for_host(cred, credential_requirements) ) + and ( + cred.provider != ProviderName.MCP + or _credential_is_for_mcp_server(cred, credential_requirements) + ) ), None, ) @@ -444,6 +449,22 @@ def _credential_is_for_host( return credential.matches_url(list(requirements.discriminator_values)[0]) +def _credential_is_for_mcp_server( + credential: Credentials, + requirements: CredentialsFieldInfo, +) -> bool: + """Check if an MCP OAuth credential matches the required server URL.""" + if not requirements.discriminator_values: + return True + + server_url = ( + credential.metadata.get("mcp_server_url") + if isinstance(credential, OAuth2Credentials) + else None + ) + return server_url in requirements.discriminator_values if server_url else False + + async def check_user_has_required_credentials( user_id: str, required_credentials: list[CredentialsMetaInput], diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index f39a0144e7..94f99852e8 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -33,6 +33,7 @@ from backend.util import type as type_utils from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError from backend.util.json import SafeJson from backend.util.models import Pagination +from backend.util.request import parse_url from .block import BlockInput from .db import BaseDbModel @@ -449,6 +450,9 @@ class GraphModel(Graph, GraphMeta): continue if ProviderName.HTTP in field.provider: continue + # MCP credentials are intentionally split by server URL + if ProviderName.MCP in field.provider: + continue # If this happens, that means a block implementation probably needs # to be updated. @@ -505,6 +509,18 @@ class GraphModel(Graph, GraphMeta): "required": ["id", "provider", "type"], } + # Add a descriptive display title when URL-based discriminator values + # are present (e.g. "mcp.sentry.dev" instead of just "Mcp") + if ( + field_info.discriminator + and not field_info.discriminator_mapping + and field_info.discriminator_values + ): + hostnames = sorted( + parse_url(str(v)).netloc for v in field_info.discriminator_values + ) + field_schema["display_name"] = ", ".join(hostnames) + # Add other (optional) field info items field_schema.update( field_info.model_dump( @@ -549,8 +565,17 @@ class GraphModel(Graph, GraphMeta): for graph in [self] + self.sub_graphs: for node in graph.nodes: - # Track if this node requires credentials (credentials_optional=False means required) - node_required_map[node.id] = not node.credentials_optional + # A node's credentials are optional if either: + # 1. The node metadata says so (credentials_optional=True), or + # 2. All credential fields on the block have defaults (not required by schema) + block_required = node.block.input_schema.get_required_fields() + creds_required_by_schema = any( + fname in block_required + for fname in node.block.input_schema.get_credentials_fields() + ) + node_required_map[node.id] = ( + not node.credentials_optional and creds_required_by_schema + ) for ( field_name, @@ -776,6 +801,19 @@ class GraphModel(Graph, GraphMeta): "'credentials' and `*_credentials` are reserved" ) + # Check custom block-level validation (e.g., MCP dynamic tool arguments). + # Blocks can override get_missing_input to report additional missing fields + # beyond the standard top-level required fields. + if for_run: + credential_fields = InputSchema.get_credentials_fields() + custom_missing = InputSchema.get_missing_input(node.input_default) + for field_name in custom_missing: + if ( + field_name not in provided_inputs + and field_name not in credential_fields + ): + node_errors[node.id][field_name] = "This field is required" + # Get input schema properties and check dependencies input_fields = InputSchema.model_fields diff --git a/autogpt_platform/backend/backend/data/graph_test.py b/autogpt_platform/backend/backend/data/graph_test.py index 442c8ed4be..3cb6f24b87 100644 --- a/autogpt_platform/backend/backend/data/graph_test.py +++ b/autogpt_platform/backend/backend/data/graph_test.py @@ -462,3 +462,120 @@ def test_node_credentials_optional_with_other_metadata(): assert node.credentials_optional is True assert node.metadata["position"] == {"x": 100, "y": 200} assert node.metadata["customized_name"] == "My Custom Node" + + +# ============================================================================ +# Tests for MCP Credential Deduplication +# ============================================================================ + + +def test_mcp_credential_combine_different_servers(): + """Two MCP credential fields with different server URLs should produce + separate entries when combined (not merged into one).""" + from backend.data.model import CredentialsFieldInfo, CredentialsType + from backend.integrations.providers import ProviderName + + oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"]) + + field_sentry = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.sentry.dev/mcp"}, + ) + field_linear = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.linear.app/mcp"}, + ) + + combined = CredentialsFieldInfo.combine( + (field_sentry, ("node-sentry", "credentials")), + (field_linear, ("node-linear", "credentials")), + ) + + # Should produce 2 separate credential entries + assert len(combined) == 2, ( + f"Expected 2 credential entries for 2 MCP blocks with different servers, " + f"got {len(combined)}: {list(combined.keys())}" + ) + + # Each entry should contain the server hostname in its key + keys = list(combined.keys()) + assert any( + "mcp.sentry.dev" in k for k in keys + ), f"Expected 'mcp.sentry.dev' in one key, got {keys}" + assert any( + "mcp.linear.app" in k for k in keys + ), f"Expected 'mcp.linear.app' in one key, got {keys}" + + +def test_mcp_credential_combine_same_server(): + """Two MCP credential fields with the same server URL should be combined + into one credential entry.""" + from backend.data.model import CredentialsFieldInfo, CredentialsType + from backend.integrations.providers import ProviderName + + oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"]) + + field_a = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.sentry.dev/mcp"}, + ) + field_b = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + discriminator_values={"https://mcp.sentry.dev/mcp"}, + ) + + combined = CredentialsFieldInfo.combine( + (field_a, ("node-a", "credentials")), + (field_b, ("node-b", "credentials")), + ) + + # Should produce 1 credential entry (same server URL) + assert len(combined) == 1, ( + f"Expected 1 credential entry for 2 MCP blocks with same server, " + f"got {len(combined)}: {list(combined.keys())}" + ) + + +def test_mcp_credential_combine_no_discriminator_values(): + """MCP credential fields without discriminator_values should be merged + into a single entry (backwards compat for blocks without server_url set).""" + from backend.data.model import CredentialsFieldInfo, CredentialsType + from backend.integrations.providers import ProviderName + + oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"]) + + field_a = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + ) + field_b = CredentialsFieldInfo( + credentials_provider=frozenset([ProviderName.MCP]), + credentials_types=oauth2_types, + credentials_scopes=None, + discriminator="server_url", + ) + + combined = CredentialsFieldInfo.combine( + (field_a, ("node-a", "credentials")), + (field_b, ("node-b", "credentials")), + ) + + # Should produce 1 entry (no URL differentiation) + assert len(combined) == 1, ( + f"Expected 1 credential entry for MCP blocks without discriminator_values, " + f"got {len(combined)}: {list(combined.keys())}" + ) diff --git a/autogpt_platform/backend/backend/data/model.py b/autogpt_platform/backend/backend/data/model.py index e61f7efbd0..c9d8c5879f 100644 --- a/autogpt_platform/backend/backend/data/model.py +++ b/autogpt_platform/backend/backend/data/model.py @@ -29,6 +29,7 @@ from pydantic import ( GetCoreSchemaHandler, SecretStr, field_serializer, + model_validator, ) from pydantic_core import ( CoreSchema, @@ -502,6 +503,25 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]): provider: CP type: CT + @model_validator(mode="before") + @classmethod + def _normalize_legacy_provider(cls, data: Any) -> Any: + """Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug. + + Python 3.13 changed ``str(StrEnum)`` to return ``"ClassName.MEMBER"`` + instead of the plain value. Old stored credential references may have + ``provider: "ProviderName.MCP"`` instead of ``"mcp"``. + """ + if isinstance(data, dict): + prov = data.get("provider", "") + if isinstance(prov, str) and prov.startswith("ProviderName."): + member = prov.removeprefix("ProviderName.") + try: + data = {**data, "provider": ProviderName[member].value} + except KeyError: + pass + return data + @classmethod def allowed_providers(cls) -> tuple[ProviderName, ...] | None: return get_args(cls.model_fields["provider"].annotation) @@ -606,11 +626,18 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): ] = defaultdict(list) for field, key in fields: - if field.provider == frozenset([ProviderName.HTTP]): - # HTTP host-scoped credentials can have different hosts that reqires different credential sets. - # Group by host extracted from the URL + if ( + field.discriminator + and not field.discriminator_mapping + and field.discriminator_values + ): + # URL-based discrimination (e.g. HTTP host-scoped, MCP server URL): + # Each unique host gets its own credential entry. + provider_prefix = next(iter(field.provider)) + # Use .value for enum types to get the plain string (e.g. "mcp" not "ProviderName.MCP") + prefix_str = getattr(provider_prefix, "value", str(provider_prefix)) providers = frozenset( - [cast(CP, "http")] + [cast(CP, prefix_str)] + [ cast(CP, parse_url(str(value)).netloc) for value in field.discriminator_values diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index e6cf257c80..4444e15d22 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -20,6 +20,7 @@ from backend.blocks import get_block from backend.blocks._base import BlockSchema from backend.blocks.agent import AgentExecutorBlock from backend.blocks.io import AgentOutputBlock +from backend.blocks.mcp.block import MCPToolBlock from backend.data import redis_client as redis from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry from backend.data.credit import UsageTransactionMetadata @@ -231,6 +232,18 @@ async def execute_node( _input_data.nodes_input_masks = nodes_input_masks _input_data.user_id = user_id input_data = _input_data.model_dump() + elif isinstance(node_block, MCPToolBlock): + _mcp_data = MCPToolBlock.Input(**node.input_default) + # Dynamic tool fields are flattened to top-level by validate_exec + # (via get_input_defaults). Collect them back into tool_arguments. + tool_schema = _mcp_data.tool_input_schema + tool_props = set(tool_schema.get("properties", {}).keys()) + merged_args = {**_mcp_data.tool_arguments} + for key in tool_props: + if key in input_data: + merged_args[key] = input_data[key] + _mcp_data.tool_arguments = merged_args + input_data = _mcp_data.model_dump() data.inputs = input_data # Execute the node @@ -267,8 +280,34 @@ async def execute_node( # Handle regular credentials fields for field_name, input_type in input_model.get_credentials_fields().items(): - credentials_meta = input_type(**input_data[field_name]) - credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id) + field_value = input_data.get(field_name) + if not field_value or ( + isinstance(field_value, dict) and not field_value.get("id") + ): + # No credentials configured — nullify so JSON schema validation + # doesn't choke on the empty default `{}`. + input_data[field_name] = None + continue # Block runs without credentials + + credentials_meta = input_type(**field_value) + # Write normalized values back so JSON schema validation also passes + # (model_validator may have fixed legacy formats like "ProviderName.MCP") + input_data[field_name] = credentials_meta.model_dump(mode="json") + try: + credentials, lock = await creds_manager.acquire( + user_id, credentials_meta.id + ) + except ValueError: + # Credential was deleted or doesn't exist. + # If the field has a default, run without credentials. + if input_model.model_fields[field_name].default is not None: + log_metadata.warning( + f"Credentials #{credentials_meta.id} not found, " + "running without (field has default)" + ) + input_data[field_name] = None + continue + raise creds_locks.append(lock) extra_exec_kwargs[field_name] = credentials diff --git a/autogpt_platform/backend/backend/executor/utils.py b/autogpt_platform/backend/backend/executor/utils.py index bb5da1e527..2b9a454061 100644 --- a/autogpt_platform/backend/backend/executor/utils.py +++ b/autogpt_platform/backend/backend/executor/utils.py @@ -260,7 +260,13 @@ async def _validate_node_input_credentials( # Track if any credential field is missing for this node has_missing_credentials = False + # A credential field is optional if the node metadata says so, or if + # the block schema declares a default for the field. + required_fields = block.input_schema.get_required_fields() + is_creds_optional = node.credentials_optional + for field_name, credentials_meta_type in credentials_fields.items(): + field_is_optional = is_creds_optional or field_name not in required_fields try: # Check nodes_input_masks first, then input_default field_value = None @@ -273,7 +279,7 @@ async def _validate_node_input_credentials( elif field_name in node.input_default: # For optional credentials, don't use input_default - treat as missing # This prevents stale credential IDs from failing validation - if node.credentials_optional: + if field_is_optional: field_value = None else: field_value = node.input_default[field_name] @@ -283,8 +289,8 @@ async def _validate_node_input_credentials( isinstance(field_value, dict) and not field_value.get("id") ): has_missing_credentials = True - # If node has credentials_optional flag, mark for skipping instead of error - if node.credentials_optional: + # If credential field is optional, skip instead of error + if field_is_optional: continue # Don't add error, will be marked for skip after loop else: credential_errors[node.id][ @@ -334,16 +340,16 @@ async def _validate_node_input_credentials( ] = "Invalid credentials: type/provider mismatch" continue - # If node has optional credentials and any are missing, mark for skipping - # But only if there are no other errors for this node + # If node has optional credentials and any are missing, allow running without. + # The executor will pass credentials=None to the block's run(). if ( has_missing_credentials - and node.credentials_optional + and is_creds_optional and node.id not in credential_errors ): - nodes_to_skip.add(node.id) logger.info( - f"Node #{node.id} will be skipped: optional credentials not configured" + f"Node #{node.id}: optional credentials not configured, " + "running without" ) return credential_errors, nodes_to_skip diff --git a/autogpt_platform/backend/backend/executor/utils_test.py b/autogpt_platform/backend/backend/executor/utils_test.py index db33249583..069086a6fd 100644 --- a/autogpt_platform/backend/backend/executor/utils_test.py +++ b/autogpt_platform/backend/backend/executor/utils_test.py @@ -495,6 +495,7 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip( mock_block.input_schema.get_credentials_fields.return_value = { "credentials": mock_credentials_field_type } + mock_block.input_schema.get_required_fields.return_value = {"credentials"} mock_node.block = mock_block # Create mock graph @@ -508,8 +509,8 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip( nodes_input_masks=None, ) - # Node should be in nodes_to_skip, not in errors - assert mock_node.id in nodes_to_skip + # Node should NOT be in nodes_to_skip (runs without credentials) and not in errors + assert mock_node.id not in nodes_to_skip assert mock_node.id not in errors @@ -535,6 +536,7 @@ async def test_validate_node_input_credentials_required_missing_creds_error( mock_block.input_schema.get_credentials_fields.return_value = { "credentials": mock_credentials_field_type } + mock_block.input_schema.get_required_fields.return_value = {"credentials"} mock_node.block = mock_block # Create mock graph diff --git a/autogpt_platform/backend/backend/integrations/credentials_store.py b/autogpt_platform/backend/backend/integrations/credentials_store.py index 384405b0c7..3e79a6c047 100644 --- a/autogpt_platform/backend/backend/integrations/credentials_store.py +++ b/autogpt_platform/backend/backend/integrations/credentials_store.py @@ -22,6 +22,27 @@ from backend.util.settings import Settings settings = Settings() + +def provider_matches(stored: str, expected: str) -> bool: + """Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug. + + On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"`` + instead of ``"mcp"``. OAuth states persisted with the buggy format need + to match when ``expected`` is the canonical value (e.g. ``"mcp"``). + """ + if stored == expected: + return True + if stored.startswith("ProviderName."): + member = stored.removeprefix("ProviderName.") + from backend.integrations.providers import ProviderName + + try: + return ProviderName[member].value == expected + except KeyError: + pass + return False + + # This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached ollama_credentials = APIKeyCredentials( id="744fdc56-071a-4761-b5a5-0af0ce10a2b5", @@ -389,7 +410,7 @@ class IntegrationCredentialsStore: self, user_id: str, provider: str ) -> list[Credentials]: credentials = await self.get_all_creds(user_id) - return [c for c in credentials if c.provider == provider] + return [c for c in credentials if provider_matches(c.provider, provider)] async def get_authorized_providers(self, user_id: str) -> list[str]: credentials = await self.get_all_creds(user_id) @@ -485,17 +506,6 @@ class IntegrationCredentialsStore: async with self.edit_user_integrations(user_id) as user_integrations: user_integrations.oauth_states.append(state) - async with await self.locked_user_integrations(user_id): - - user_integrations = await self._get_user_integrations(user_id) - oauth_states = user_integrations.oauth_states - oauth_states.append(state) - user_integrations.oauth_states = oauth_states - - await self.db_manager.update_user_integrations( - user_id=user_id, data=user_integrations - ) - return token, code_challenge def _generate_code_challenge(self) -> tuple[str, str]: @@ -521,7 +531,7 @@ class IntegrationCredentialsStore: state for state in oauth_states if secrets.compare_digest(state.token, token) - and state.provider == provider + and provider_matches(state.provider, provider) and state.expires_at > now.timestamp() ), None, diff --git a/autogpt_platform/backend/backend/integrations/creds_manager.py b/autogpt_platform/backend/backend/integrations/creds_manager.py index f2b6a9da4f..5634dd73b6 100644 --- a/autogpt_platform/backend/backend/integrations/creds_manager.py +++ b/autogpt_platform/backend/backend/integrations/creds_manager.py @@ -9,7 +9,10 @@ from redis.asyncio.lock import Lock as AsyncRedisLock from backend.data.model import Credentials, OAuth2Credentials from backend.data.redis_client import get_redis_async -from backend.integrations.credentials_store import IntegrationCredentialsStore +from backend.integrations.credentials_store import ( + IntegrationCredentialsStore, + provider_matches, +) from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME from backend.integrations.providers import ProviderName from backend.util.exceptions import MissingConfigError @@ -137,7 +140,10 @@ class IntegrationCredentialsManager: self, user_id: str, credentials: OAuth2Credentials, lock: bool = True ) -> OAuth2Credentials: async with self._locked(user_id, credentials.id, "refresh"): - oauth_handler = await _get_provider_oauth_handler(credentials.provider) + if provider_matches(credentials.provider, ProviderName.MCP.value): + oauth_handler = create_mcp_oauth_handler(credentials) + else: + oauth_handler = await _get_provider_oauth_handler(credentials.provider) if oauth_handler.needs_refresh(credentials): logger.debug( f"Refreshing '{credentials.provider}' " @@ -236,3 +242,31 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl client_secret=client_secret, redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback", ) + + +def create_mcp_oauth_handler( + credentials: OAuth2Credentials, +) -> "BaseOAuthHandler": + """Create an MCPOAuthHandler from credential metadata for token refresh. + + MCP OAuth handlers have dynamic endpoints discovered per-server, so they + can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler + is reconstructed from metadata stored on the credential during initial auth. + """ + from backend.blocks.mcp.oauth import MCPOAuthHandler + + meta = credentials.metadata or {} + token_url = meta.get("mcp_token_url", "") + if not token_url: + raise ValueError( + f"MCP credential {credentials.id} is missing 'mcp_token_url' metadata; " + "cannot refresh tokens" + ) + return MCPOAuthHandler( + client_id=meta.get("mcp_client_id", ""), + client_secret=meta.get("mcp_client_secret", ""), + redirect_uri="", # Not needed for token refresh + authorize_url="", # Not needed for token refresh + token_url=token_url, + resource_url=meta.get("mcp_resource_url"), + ) diff --git a/autogpt_platform/backend/backend/integrations/providers.py b/autogpt_platform/backend/backend/integrations/providers.py index 8a0d6fd183..a462cd787f 100644 --- a/autogpt_platform/backend/backend/integrations/providers.py +++ b/autogpt_platform/backend/backend/integrations/providers.py @@ -30,6 +30,7 @@ class ProviderName(str, Enum): IDEOGRAM = "ideogram" JINA = "jina" LLAMA_API = "llama_api" + MCP = "mcp" MEDIUM = "medium" MEM0 = "mem0" NOTION = "notion" diff --git a/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py b/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py index 99eee404b9..8fdbe10383 100644 --- a/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py +++ b/autogpt_platform/backend/backend/integrations/webhooks/graph_lifecycle_hooks.py @@ -51,6 +51,21 @@ async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str): if ( creds_meta := new_node.input_default.get(creds_field_name) ) and not await get_credentials(creds_meta["id"]): + # If the credential field is optional (has a default in the + # schema, or node metadata marks it optional), clear the stale + # reference instead of blocking the save. + creds_field_optional = ( + new_node.credentials_optional + or creds_field_name not in block_input_schema.get_required_fields() + ) + if creds_field_optional: + new_node.input_default[creds_field_name] = {} + logger.warning( + f"Node #{new_node.id}: cleared stale optional " + f"credentials #{creds_meta['id']} for " + f"'{creds_field_name}'" + ) + continue raise ValueError( f"Node #{new_node.id} input '{creds_field_name}' updated with " f"non-existent credentials #{creds_meta['id']}" diff --git a/autogpt_platform/backend/backend/util/request.py b/autogpt_platform/backend/backend/util/request.py index 95e5ee32f7..9470909dfc 100644 --- a/autogpt_platform/backend/backend/util/request.py +++ b/autogpt_platform/backend/backend/util/request.py @@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver): def __init__(self, ssl_hostname: str, ip_addresses: list[str]): self.ssl_hostname = ssl_hostname self.ip_addresses = ip_addresses - self._default = aiohttp.AsyncResolver() + self._default = aiohttp.ThreadedResolver() async def resolve(self, host, port=0, family=socket.AF_INET): if host == self.ssl_hostname: @@ -467,7 +467,7 @@ class Requests: resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses) ssl_context = ssl.create_default_context() connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context) - session_kwargs = {} + session_kwargs: dict = {} if connector: session_kwargs["connector"] = connector diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts new file mode 100644 index 0000000000..326f42e049 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/mcp_callback/route.ts @@ -0,0 +1,96 @@ +import { NextResponse } from "next/server"; + +/** + * Safely encode a value as JSON for embedding in a script tag. + * Escapes characters that could break out of the script context to prevent XSS. + */ +function safeJsonStringify(value: unknown): string { + return JSON.stringify(value) + .replace(//g, "\\u003e") + .replace(/&/g, "\\u0026"); +} + +// MCP-specific OAuth callback route. +// +// Unlike the generic oauth_callback which relies on window.opener.postMessage, +// this route uses BroadcastChannel as the PRIMARY communication method. +// This is critical because cross-origin OAuth flows (e.g. Sentry → localhost) +// often lose window.opener due to COOP (Cross-Origin-Opener-Policy) headers. +// +// BroadcastChannel works across all same-origin tabs/popups regardless of opener. +export async function GET(request: Request) { + const { searchParams } = new URL(request.url); + const code = searchParams.get("code"); + const state = searchParams.get("state"); + + const success = Boolean(code && state); + const message = success + ? { success: true, code, state } + : { + success: false, + message: `Missing parameters: ${searchParams.toString()}`, + }; + + return new NextResponse( + ` + + MCP Sign-in + +
+
+

Completing sign-in...

+
+ + + +`, + { headers: { "Content-Type": "text/html" } }, + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx index d4aa26480d..62e796b748 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/CustomNode.tsx @@ -47,7 +47,10 @@ export type CustomNode = XYNode; export const CustomNode: React.FC> = React.memo( ({ data, id: nodeId, selected }) => { - const { inputSchema, outputSchema } = useCustomNode({ data, nodeId }); + const { inputSchema, outputSchema, isMCPWithTool } = useCustomNode({ + data, + nodeId, + }); const isAgent = data.uiType === BlockUIType.AGENT; @@ -98,6 +101,7 @@ export const CustomNode: React.FC> = React.memo( jsonSchema={preprocessInputSchema(inputSchema)} nodeId={nodeId} uiType={data.uiType} + isMCPWithTool={isMCPWithTool} className={cn( "bg-white px-4", isWebhook && "pointer-events-none opacity-50", diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx index c4659b8dcf..9a3add62b6 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/components/NodeHeader.tsx @@ -20,10 +20,8 @@ type Props = { export const NodeHeader = ({ data, nodeId }: Props) => { const updateNodeData = useNodeStore((state) => state.updateNodeData); - const title = - (data.metadata?.customized_name as string) || - data.hardcodedValues?.agent_name || - data.title; + + const title = (data.metadata?.customized_name as string) || data.title; const [isEditingTitle, setIsEditingTitle] = useState(false); const [editedTitle, setEditedTitle] = useState(title); diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx index e58d0ab12b..050515a02f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/CustomNode/useCustomNode.tsx @@ -3,6 +3,34 @@ import { CustomNodeData } from "./CustomNode"; import { BlockUIType } from "../../../types"; import { useMemo } from "react"; import { mergeSchemaForResolution } from "./helpers"; +/** + * Build a dynamic input schema for MCP blocks. + * + * When a tool has been selected (tool_input_schema is populated), the block + * renders the selected tool's input parameters *plus* the credentials field + * so users can select/change the OAuth credential used for execution. + * + * Static fields like server_url, selected_tool, available_tools, and + * tool_arguments are hidden because they're pre-configured from the dialog. + */ +function buildMCPInputSchema( + toolInputSchema: Record, + blockInputSchema: Record, +): Record { + // Extract the credentials field from the block's original input schema + const credentialsSchema = + blockInputSchema?.properties?.credentials ?? undefined; + + return { + type: "object", + properties: { + // Credentials field first so the dropdown appears at the top + ...(credentialsSchema ? { credentials: credentialsSchema } : {}), + ...(toolInputSchema.properties ?? {}), + }, + required: [...(toolInputSchema.required ?? [])], + }; +} export const useCustomNode = ({ data, @@ -19,10 +47,18 @@ export const useCustomNode = ({ ); const isAgent = data.uiType === BlockUIType.AGENT; + const isMCPWithTool = + data.uiType === BlockUIType.MCP_TOOL && + !!data.hardcodedValues?.tool_input_schema?.properties; const currentInputSchema = isAgent ? (data.hardcodedValues.input_schema ?? {}) - : data.inputSchema; + : isMCPWithTool + ? buildMCPInputSchema( + data.hardcodedValues.tool_input_schema, + data.inputSchema, + ) + : data.inputSchema; const currentOutputSchema = isAgent ? (data.hardcodedValues.output_schema ?? {}) : data.outputSchema; @@ -54,5 +90,6 @@ export const useCustomNode = ({ return { inputSchema, outputSchema, + isMCPWithTool, }; }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx index d6a3fabffa..77b21dda92 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/FlowEditor/nodes/FormCreator.tsx @@ -9,39 +9,72 @@ interface FormCreatorProps { jsonSchema: RJSFSchema; nodeId: string; uiType: BlockUIType; + /** When true the block is an MCP Tool with a selected tool. */ + isMCPWithTool?: boolean; showHandles?: boolean; className?: string; } export const FormCreator: React.FC = React.memo( - ({ jsonSchema, nodeId, uiType, showHandles = true, className }) => { + ({ + jsonSchema, + nodeId, + uiType, + isMCPWithTool = false, + showHandles = true, + className, + }) => { const updateNodeData = useNodeStore((state) => state.updateNodeData); const getHardCodedValues = useNodeStore( (state) => state.getHardCodedValues, ); + const isAgent = uiType === BlockUIType.AGENT; + const handleChange = ({ formData }: any) => { if ("credentials" in formData && !formData.credentials?.id) { delete formData.credentials; } - const updatedValues = - uiType === BlockUIType.AGENT - ? { - ...getHardCodedValues(nodeId), - inputs: formData, - } - : formData; + let updatedValues; + if (isAgent) { + updatedValues = { + ...getHardCodedValues(nodeId), + inputs: formData, + }; + } else if (isMCPWithTool) { + // Separate credentials from tool arguments — credentials are stored + // at the top level of hardcodedValues, not inside tool_arguments. + const { credentials, ...toolArgs } = formData; + updatedValues = { + ...getHardCodedValues(nodeId), + tool_arguments: toolArgs, + ...(credentials?.id ? { credentials } : {}), + }; + } else { + updatedValues = formData; + } updateNodeData(nodeId, { hardcodedValues: updatedValues }); }; const hardcodedValues = getHardCodedValues(nodeId); - const initialValues = - uiType === BlockUIType.AGENT - ? (hardcodedValues.inputs ?? {}) - : hardcodedValues; + + let initialValues; + if (isAgent) { + initialValues = hardcodedValues.inputs ?? {}; + } else if (isMCPWithTool) { + // Merge tool arguments with credentials for the form + initialValues = { + ...(hardcodedValues.tool_arguments ?? {}), + ...(hardcodedValues.credentials?.id + ? { credentials: hardcodedValues.credentials } + : {}), + }; + } else { + initialValues = hardcodedValues; + } return (
; + availableTools: Record; + /** Credentials meta from OAuth flow, null for public servers. */ + credentials: CredentialsMetaInput | null; +}; + +interface MCPToolDialogProps { + open: boolean; + onClose: () => void; + onConfirm: (result: MCPToolDialogResult) => void; +} + +type DialogStep = "url" | "tool"; + +export function MCPToolDialog({ + open, + onClose, + onConfirm, +}: MCPToolDialogProps) { + const allProviders = useContext(CredentialsProvidersContext); + + const [step, setStep] = useState("url"); + const [serverUrl, setServerUrl] = useState(""); + const [tools, setTools] = useState([]); + const [serverName, setServerName] = useState(null); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const [authRequired, setAuthRequired] = useState(false); + const [oauthLoading, setOauthLoading] = useState(false); + const [showManualToken, setShowManualToken] = useState(false); + const [manualToken, setManualToken] = useState(""); + const [selectedTool, setSelectedTool] = useState( + null, + ); + const [credentials, setCredentials] = useState( + null, + ); + + const startOAuthRef = useRef(false); + const oauthAbortRef = useRef<((reason?: string) => void) | null>(null); + + // Clean up on unmount + useEffect(() => { + return () => { + oauthAbortRef.current?.(); + }; + }, []); + + const reset = useCallback(() => { + oauthAbortRef.current?.(); + oauthAbortRef.current = null; + setStep("url"); + setServerUrl(""); + setManualToken(""); + setTools([]); + setServerName(null); + setLoading(false); + setError(null); + setAuthRequired(false); + setOauthLoading(false); + setShowManualToken(false); + setSelectedTool(null); + setCredentials(null); + }, []); + + const handleClose = useCallback(() => { + reset(); + onClose(); + }, [reset, onClose]); + + const discoverTools = useCallback(async (url: string, authToken?: string) => { + setLoading(true); + setError(null); + try { + const response = await postV2DiscoverAvailableToolsOnAnMcpServer({ + server_url: url, + auth_token: authToken || null, + }); + if (response.status !== 200) throw response.data; + setTools(response.data.tools); + setServerName(response.data.server_name ?? null); + setAuthRequired(false); + setShowManualToken(false); + setStep("tool"); + } catch (e: any) { + if (e?.status === 401 || e?.status === 403) { + setAuthRequired(true); + setError(null); + // Automatically start OAuth sign-in instead of requiring a second click + setLoading(false); + startOAuthRef.current = true; + return; + } else { + const message = + e?.message || e?.detail || "Failed to connect to MCP server"; + setError( + typeof message === "string" ? message : JSON.stringify(message), + ); + } + } finally { + setLoading(false); + } + }, []); + + const handleDiscoverTools = useCallback(() => { + if (!serverUrl.trim()) return; + discoverTools(serverUrl.trim(), manualToken.trim() || undefined); + }, [serverUrl, manualToken, discoverTools]); + + const handleOAuthSignIn = useCallback(async () => { + if (!serverUrl.trim()) return; + setError(null); + + // Abort any previous OAuth flow + oauthAbortRef.current?.(); + + setOauthLoading(true); + + try { + const loginResponse = await postV2InitiateOauthLoginForAnMcpServer({ + server_url: serverUrl.trim(), + }); + if (loginResponse.status !== 200) throw loginResponse.data; + const { login_url, state_token } = loginResponse.data; + + const { promise, cleanup } = openOAuthPopup(login_url, { + stateToken: state_token, + useCrossOriginListeners: true, + }); + oauthAbortRef.current = cleanup.abort; + + const result = await promise; + + // Exchange code for tokens via the credentials provider (updates cache) + setLoading(true); + setOauthLoading(false); + + const mcpProvider = allProviders?.["mcp"]; + let callbackResult; + if (mcpProvider) { + callbackResult = await mcpProvider.mcpOAuthCallback( + result.code, + state_token, + ); + } else { + const cbResponse = await postV2ExchangeOauthCodeForMcpTokens({ + code: result.code, + state_token, + }); + if (cbResponse.status !== 200) throw cbResponse.data; + callbackResult = cbResponse.data; + } + + setCredentials({ + id: callbackResult.id, + provider: callbackResult.provider, + type: callbackResult.type, + title: callbackResult.title, + }); + setAuthRequired(false); + + // Discover tools now that we're authenticated + const toolsResponse = await postV2DiscoverAvailableToolsOnAnMcpServer({ + server_url: serverUrl.trim(), + }); + if (toolsResponse.status !== 200) throw toolsResponse.data; + setTools(toolsResponse.data.tools); + setServerName(toolsResponse.data.server_name ?? null); + setStep("tool"); + } catch (e: any) { + // If server doesn't support OAuth → show manual token entry + if (e?.status === 400) { + setShowManualToken(true); + setError( + "This server does not support OAuth sign-in. Please enter a token manually.", + ); + } else if (e?.message === "OAuth flow timed out") { + setError("OAuth sign-in timed out. Please try again."); + } else { + const status = e?.status; + let message: string; + if (status === 401 || status === 403) { + message = + "Authentication succeeded but the server still rejected the request. " + + "The token audience may not match. Please try again."; + } else { + message = e?.message || e?.detail || "Failed to complete sign-in"; + } + setError( + typeof message === "string" ? message : JSON.stringify(message), + ); + } + } finally { + setOauthLoading(false); + setLoading(false); + oauthAbortRef.current = null; + } + }, [serverUrl, allProviders]); + + // Auto-start OAuth sign-in when server returns 401/403 + useEffect(() => { + if (authRequired && startOAuthRef.current) { + startOAuthRef.current = false; + handleOAuthSignIn(); + } + }, [authRequired, handleOAuthSignIn]); + + const handleConfirm = useCallback(() => { + if (!selectedTool) return; + + const availableTools: Record = {}; + for (const t of tools) { + availableTools[t.name] = { + description: t.description, + input_schema: t.input_schema, + }; + } + + onConfirm({ + serverUrl: serverUrl.trim(), + serverName, + selectedTool: selectedTool.name, + toolInputSchema: selectedTool.input_schema, + availableTools, + credentials, + }); + reset(); + }, [ + selectedTool, + tools, + serverUrl, + serverName, + credentials, + onConfirm, + reset, + ]); + + return ( + !isOpen && handleClose()}> + + + + {step === "url" + ? "Connect to MCP Server" + : `Select a Tool${serverName ? ` — ${serverName}` : ""}`} + + + {step === "url" + ? "Enter the URL of an MCP server to discover its available tools." + : `Found ${tools.length} tool${tools.length !== 1 ? "s" : ""}. Select one to add to your agent.`} + + + + {step === "url" && ( +
+
+ + setServerUrl(e.target.value)} + onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()} + autoFocus + /> +
+ + {/* Auth required: show manual token option */} + {authRequired && !showManualToken && ( + + )} + + {/* Manual token entry — only visible when expanded */} + {showManualToken && ( +
+ + setManualToken(e.target.value)} + onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()} + autoFocus + /> +
+ )} + + {error &&

{error}

} +
+ )} + + {step === "tool" && ( + +
+ {tools.map((tool) => ( + setSelectedTool(tool)} + /> + ))} +
+
+ )} + + + {step === "tool" && ( + + )} + + {step === "url" && ( + + )} + {step === "tool" && ( + + )} + +
+
+ ); +} + +// --------------- Tool Card Component --------------- // + +/** Truncate a description to a reasonable length for the collapsed view. */ +function truncateDescription(text: string, maxLen = 120): string { + if (text.length <= maxLen) return text; + return text.slice(0, maxLen).trimEnd() + "…"; +} + +/** Pretty-print a JSON Schema type for a parameter. */ +function schemaTypeLabel(schema: Record): string { + if (schema.type) return schema.type; + if (schema.anyOf) + return schema.anyOf.map((s: any) => s.type ?? "any").join(" | "); + if (schema.oneOf) + return schema.oneOf.map((s: any) => s.type ?? "any").join(" | "); + return "any"; +} + +function MCPToolCard({ + tool, + selected, + onSelect, +}: { + tool: MCPToolResponse; + selected: boolean; + onSelect: () => void; +}) { + const [expanded, setExpanded] = useState(false); + const schema = tool.input_schema as Record; + const properties = schema?.properties ?? {}; + const required = new Set(schema?.required ?? []); + const paramNames = Object.keys(properties); + + // Strip XML-like tags from description for cleaner display. + // Loop to handle nested tags like ipt> (CodeQL fix). + let cleanDescription = tool.description ?? ""; + let prev = ""; + while (prev !== cleanDescription) { + prev = cleanDescription; + cleanDescription = cleanDescription.replace(/<[^>]*>/g, ""); + } + cleanDescription = cleanDescription.trim(); + + return ( + + )} + + ); +} diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx index 10f4fc8a44..07c6795808 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/NewControlPanel/NewBlockMenu/Block.tsx @@ -1,7 +1,7 @@ import { Button } from "@/components/__legacy__/ui/button"; import { Skeleton } from "@/components/__legacy__/ui/skeleton"; import { beautifyString, cn } from "@/lib/utils"; -import React, { ButtonHTMLAttributes } from "react"; +import React, { ButtonHTMLAttributes, useCallback, useState } from "react"; import { highlightText } from "./helpers"; import { PlusIcon } from "@phosphor-icons/react"; import { BlockInfo } from "@/app/api/__generated__/models/blockInfo"; @@ -9,6 +9,12 @@ import { useControlPanelStore } from "../../../stores/controlPanelStore"; import { blockDragPreviewStyle } from "./style"; import { useReactFlow } from "@xyflow/react"; import { useNodeStore } from "../../../stores/nodeStore"; +import { BlockUIType, SpecialBlockID } from "@/lib/autogpt-server-api"; +import { + MCPToolDialog, + type MCPToolDialogResult, +} from "@/app/(platform)/build/components/MCPToolDialog"; + interface Props extends ButtonHTMLAttributes { title?: string; description?: string; @@ -33,22 +39,86 @@ export const Block: BlockComponent = ({ ); const { setViewport } = useReactFlow(); const { addBlock } = useNodeStore(); + const [mcpDialogOpen, setMcpDialogOpen] = useState(false); + + const isMCPBlock = blockData.uiType === BlockUIType.MCP_TOOL; + + const addBlockAndCenter = useCallback( + (block: BlockInfo, hardcodedValues?: Record) => { + const customNode = addBlock(block, hardcodedValues); + setTimeout(() => { + setViewport( + { + x: -customNode.position.x * 0.8 + window.innerWidth / 2, + y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2, + zoom: 0.8, + }, + { duration: 500 }, + ); + }, 50); + return customNode; + }, + [addBlock, setViewport], + ); + + const updateNodeData = useNodeStore((state) => state.updateNodeData); + + const handleMCPToolConfirm = useCallback( + (result: MCPToolDialogResult) => { + // Derive a display label: prefer server name, fall back to URL hostname. + let serverLabel = result.serverName; + if (!serverLabel) { + try { + serverLabel = new URL(result.serverUrl).hostname; + } catch { + serverLabel = "MCP"; + } + } + + const customNode = addBlockAndCenter(blockData, { + server_url: result.serverUrl, + server_name: serverLabel, + selected_tool: result.selectedTool, + tool_input_schema: result.toolInputSchema, + available_tools: result.availableTools, + credentials: result.credentials ?? undefined, + }); + if (customNode) { + const title = result.selectedTool + ? `${serverLabel}: ${beautifyString(result.selectedTool)}` + : undefined; + updateNodeData(customNode.id, { + metadata: { + ...customNode.data.metadata, + credentials_optional: true, + ...(title && { customized_name: title }), + }, + }); + } + setMcpDialogOpen(false); + }, + [addBlockAndCenter, blockData, updateNodeData], + ); const handleClick = () => { - const customNode = addBlock(blockData); - setTimeout(() => { - setViewport( - { - x: -customNode.position.x * 0.8 + window.innerWidth / 2, - y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2, - zoom: 0.8, + if (isMCPBlock) { + setMcpDialogOpen(true); + return; + } + const customNode = addBlockAndCenter(blockData); + // Set customized_name for agent blocks so the agent's name persists + if (customNode && blockData.id === SpecialBlockID.AGENT) { + updateNodeData(customNode.id, { + metadata: { + ...customNode.data.metadata, + customized_name: blockData.name, }, - { duration: 500 }, - ); - }, 50); + }); + } }; const handleDragStart = (e: React.DragEvent) => { + if (isMCPBlock) return; e.dataTransfer.effectAllowed = "copy"; e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData)); @@ -71,46 +141,56 @@ export const Block: BlockComponent = ({ : undefined; return ( -
- +
+ {title && ( + + {highlightText(beautifyString(title), highlightedText)} + + )} + {description && ( + + {highlightText(description, highlightedText)} + + )} +
+
+ +
+ + {isMCPBlock && ( + setMcpDialogOpen(false)} + onConfirm={handleMCPToolConfirm} + /> + )} + ); }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts b/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts index 2fde427330..0f5021351d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/types.ts @@ -9,4 +9,5 @@ export enum BlockUIType { AGENT = "Agent", AI = "AI", AYRSHARE = "Ayrshare", + MCP_TOOL = "MCP Tool", } diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 8e48931540..63a8a856b9 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -4269,6 +4269,128 @@ } } }, + "/api/mcp/discover-tools": { + "post": { + "tags": ["v2", "mcp", "mcp"], + "summary": "Discover available tools on an MCP server", + "description": "Connect to an MCP server and return its available tools.\n\nIf the user has a stored MCP credential for this server URL, it will be\nused automatically — no need to pass an explicit auth token.", + "operationId": "postV2Discover available tools on an mcp server", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/DiscoverToolsRequest" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DiscoverToolsResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/mcp/oauth/callback": { + "post": { + "tags": ["v2", "mcp", "mcp"], + "summary": "Exchange OAuth code for MCP tokens", + "description": "Exchange the authorization code for tokens and store the credential.\n\nThe frontend calls this after receiving the OAuth code from the popup.\nOn success, subsequent ``/discover-tools`` calls for the same server URL\nwill automatically use the stored credential.", + "operationId": "postV2Exchange oauth code for mcp tokens", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MCPOAuthCallbackRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/CredentialsMetaResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, + "/api/mcp/oauth/login": { + "post": { + "tags": ["v2", "mcp", "mcp"], + "summary": "Initiate OAuth login for an MCP server", + "description": "Discover OAuth metadata from the MCP server and return a login URL.\n\n1. Discovers the protected-resource metadata (RFC 9728)\n2. Fetches the authorization server metadata (RFC 8414)\n3. Performs Dynamic Client Registration (RFC 7591) if available\n4. Returns the authorization URL for the frontend to open in a popup", + "operationId": "postV2Initiate oauth login for an mcp server", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/MCPOAuthLoginRequest" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MCPOAuthLoginResponse" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + }, + "security": [{ "HTTPBearerJWT": [] }] + } + }, "/api/oauth/app/{client_id}": { "get": { "tags": ["oauth"], @@ -7691,7 +7813,7 @@ "host": { "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Host", - "description": "Host pattern for host-scoped credentials" + "description": "Host pattern for host-scoped or MCP server URL for MCP credentials" } }, "type": "object", @@ -7711,6 +7833,45 @@ "required": ["version_counts"], "title": "DeleteGraphResponse" }, + "DiscoverToolsRequest": { + "properties": { + "server_url": { + "type": "string", + "title": "Server Url", + "description": "URL of the MCP server" + }, + "auth_token": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Auth Token", + "description": "Optional Bearer token for authenticated MCP servers" + } + }, + "type": "object", + "required": ["server_url"], + "title": "DiscoverToolsRequest", + "description": "Request to discover tools on an MCP server." + }, + "DiscoverToolsResponse": { + "properties": { + "tools": { + "items": { "$ref": "#/components/schemas/MCPToolResponse" }, + "type": "array", + "title": "Tools" + }, + "server_name": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Server Name" + }, + "protocol_version": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Protocol Version" + } + }, + "type": "object", + "required": ["tools"], + "title": "DiscoverToolsResponse", + "description": "Response containing the list of tools available on an MCP server." + }, "DocPageResponse": { "properties": { "type": { @@ -9287,6 +9448,62 @@ "required": ["login_url", "state_token"], "title": "LoginResponse" }, + "MCPOAuthCallbackRequest": { + "properties": { + "code": { + "type": "string", + "title": "Code", + "description": "Authorization code from OAuth callback" + }, + "state_token": { + "type": "string", + "title": "State Token", + "description": "State token for CSRF verification" + } + }, + "type": "object", + "required": ["code", "state_token"], + "title": "MCPOAuthCallbackRequest", + "description": "Request to exchange an OAuth code for tokens." + }, + "MCPOAuthLoginRequest": { + "properties": { + "server_url": { + "type": "string", + "title": "Server Url", + "description": "URL of the MCP server that requires OAuth" + } + }, + "type": "object", + "required": ["server_url"], + "title": "MCPOAuthLoginRequest", + "description": "Request to start an OAuth flow for an MCP server." + }, + "MCPOAuthLoginResponse": { + "properties": { + "login_url": { "type": "string", "title": "Login Url" }, + "state_token": { "type": "string", "title": "State Token" } + }, + "type": "object", + "required": ["login_url", "state_token"], + "title": "MCPOAuthLoginResponse", + "description": "Response with the OAuth login URL for the user to authenticate." + }, + "MCPToolResponse": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "description": { "type": "string", "title": "Description" }, + "input_schema": { + "additionalProperties": true, + "type": "object", + "title": "Input Schema" + } + }, + "type": "object", + "required": ["name", "description", "input_schema"], + "title": "MCPToolResponse", + "description": "A single MCP tool returned by discovery." + }, "MarketplaceListing": { "properties": { "id": { "type": "string", "title": "Id" }, diff --git a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView.tsx b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView.tsx index 135a960431..22d0a318a9 100644 --- a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView.tsx +++ b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/CredentialsGroupedView.tsx @@ -38,13 +38,8 @@ export function CredentialsGroupedView({ const allProviders = useContext(CredentialsProvidersContext); const { userCredentialFields, systemCredentialFields } = useMemo( - () => - splitCredentialFieldsBySystem( - credentialFields, - allProviders, - inputCredentials, - ), - [credentialFields, allProviders, inputCredentials], + () => splitCredentialFieldsBySystem(credentialFields, allProviders), + [credentialFields, allProviders], ); const hasSystemCredentials = systemCredentialFields.length > 0; @@ -86,11 +81,13 @@ export function CredentialsGroupedView({ const providerNames = schema.credentials_provider || []; const credentialTypes = schema.credentials_types || []; const requiredScopes = schema.credentials_scopes; + const discriminatorValues = schema.discriminator_values; const savedCredential = findSavedCredentialByProviderAndType( providerNames, credentialTypes, requiredScopes, allProviders, + discriminatorValues, ); if (savedCredential) { diff --git a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers.ts b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers.ts index 5f439d3a32..2d8d001a72 100644 --- a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers.ts +++ b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/components/CredentialsGroupedView/helpers.ts @@ -23,10 +23,35 @@ function hasRequiredScopes( return true; } +/** Check if a credential matches the discriminator values (e.g. MCP server URL). */ +function matchesDiscriminatorValues( + credential: { host?: string | null; provider: string; type: string }, + discriminatorValues?: string[], +) { + // MCP OAuth2 credentials must match by server URL + if (credential.type === "oauth2" && credential.provider === "mcp") { + if (!discriminatorValues || discriminatorValues.length === 0) return false; + return ( + credential.host != null && discriminatorValues.includes(credential.host) + ); + } + // Host-scoped credentials match by host + if (credential.type === "host_scoped" && credential.host) { + if (!discriminatorValues || discriminatorValues.length === 0) return true; + return discriminatorValues.some((v) => { + try { + return new URL(v).hostname === credential.host; + } catch { + return false; + } + }); + } + return true; +} + export function splitCredentialFieldsBySystem( credentialFields: CredentialField[], allProviders: CredentialsProvidersContextType | null, - inputCredentials?: Record, ) { if (!allProviders || credentialFields.length === 0) { return { @@ -52,17 +77,9 @@ export function splitCredentialFieldsBySystem( } } - const sortByUnsetFirst = (a: CredentialField, b: CredentialField) => { - const aIsSet = Boolean(inputCredentials?.[a[0]]); - const bIsSet = Boolean(inputCredentials?.[b[0]]); - - if (aIsSet === bIsSet) return 0; - return aIsSet ? 1 : -1; - }; - return { - userCredentialFields: userFields.sort(sortByUnsetFirst), - systemCredentialFields: systemFields.sort(sortByUnsetFirst), + userCredentialFields: userFields, + systemCredentialFields: systemFields, }; } @@ -160,6 +177,7 @@ export function findSavedCredentialByProviderAndType( credentialTypes: string[], requiredScopes: string[] | undefined, allProviders: CredentialsProvidersContextType | null, + discriminatorValues?: string[], ): SavedCredential | undefined { for (const providerName of providerNames) { const providerData = allProviders?.[providerName]; @@ -176,9 +194,14 @@ export function findSavedCredentialByProviderAndType( credentialTypes.length === 0 || credentialTypes.includes(credential.type); const scopesMatch = hasRequiredScopes(credential, requiredScopes); + const hostMatches = matchesDiscriminatorValues( + credential, + discriminatorValues, + ); if (!typeMatches) continue; if (!scopesMatch) continue; + if (!hostMatches) continue; matchingCredentials.push(credential as SavedCredential); } @@ -190,9 +213,14 @@ export function findSavedCredentialByProviderAndType( credentialTypes.length === 0 || credentialTypes.includes(credential.type); const scopesMatch = hasRequiredScopes(credential, requiredScopes); + const hostMatches = matchesDiscriminatorValues( + credential, + discriminatorValues, + ); if (!typeMatches) continue; if (!scopesMatch) continue; + if (!hostMatches) continue; matchingCredentials.push(credential as SavedCredential); } @@ -214,6 +242,7 @@ export function findSavedUserCredentialByProviderAndType( credentialTypes: string[], requiredScopes: string[] | undefined, allProviders: CredentialsProvidersContextType | null, + discriminatorValues?: string[], ): SavedCredential | undefined { for (const providerName of providerNames) { const providerData = allProviders?.[providerName]; @@ -230,9 +259,14 @@ export function findSavedUserCredentialByProviderAndType( credentialTypes.length === 0 || credentialTypes.includes(credential.type); const scopesMatch = hasRequiredScopes(credential, requiredScopes); + const hostMatches = matchesDiscriminatorValues( + credential, + discriminatorValues, + ); if (!typeMatches) continue; if (!scopesMatch) continue; + if (!hostMatches) continue; matchingCredentials.push(credential as SavedCredential); } diff --git a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts index 509713ff1e..9ab2e08141 100644 --- a/autogpt_platform/frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts +++ b/autogpt_platform/frontend/src/components/contextual/CredentialsInput/useCredentialsInput.ts @@ -5,14 +5,14 @@ import { BlockIOCredentialsSubSchema, CredentialsMetaInput, } from "@/lib/autogpt-server-api/types"; +import { postV2InitiateOauthLoginForAnMcpServer } from "@/app/api/__generated__/endpoints/mcp/mcp"; +import { openOAuthPopup } from "@/lib/oauth-popup"; import { useQueryClient } from "@tanstack/react-query"; import { useEffect, useRef, useState } from "react"; import { filterSystemCredentials, getActionButtonText, getSystemCredentials, - OAUTH_TIMEOUT_MS, - OAuthPopupResultMessage, } from "./helpers"; export type CredentialsInputState = ReturnType; @@ -57,6 +57,14 @@ export function useCredentialsInput({ const queryClient = useQueryClient(); const credentials = useCredentials(schema, siblingInputs); const hasAttemptedAutoSelect = useRef(false); + const oauthAbortRef = useRef<((reason?: string) => void) | null>(null); + + // Clean up on unmount + useEffect(() => { + return () => { + oauthAbortRef.current?.(); + }; + }, []); const deleteCredentialsMutation = useDeleteV1DeleteCredentials({ mutation: { @@ -81,11 +89,14 @@ export function useCredentialsInput({ } }, [credentials, onLoaded]); - // Unselect credential if not available + // Unselect credential if not available in the loaded credential list. + // Skip when no credentials have been loaded yet (empty list could mean + // the provider data hasn't finished loading, not that the credential is invalid). useEffect(() => { if (readOnly) return; if (!credentials || !("savedCredentials" in credentials)) return; const availableCreds = credentials.savedCredentials; + if (availableCreds.length === 0) return; if ( selectedCredential && !availableCreds.some((c) => c.id === selectedCredential.id) @@ -110,7 +121,9 @@ export function useCredentialsInput({ if (hasAttemptedAutoSelect.current) return; hasAttemptedAutoSelect.current = true; - if (isOptional) return; + // Auto-select if exactly one credential matches. + // For optional fields with multiple options, let the user choose. + if (isOptional && savedCreds.length > 1) return; const cred = savedCreds[0]; onSelectCredential({ @@ -148,7 +161,9 @@ export function useCredentialsInput({ supportsHostScoped, savedCredentials, oAuthCallback, + mcpOAuthCallback, isSystemProvider, + discriminatorValue, } = credentials; // Split credentials into user and system @@ -157,72 +172,66 @@ export function useCredentialsInput({ async function handleOAuthLogin() { setOAuthError(null); - const { login_url, state_token } = await api.oAuthLogin( - provider, - schema.credentials_scopes, - ); - setOAuth2FlowInProgress(true); - const popup = window.open(login_url, "_blank", "popup=true"); - if (!popup) { - throw new Error( - "Failed to open popup window. Please allow popups for this site.", + // Abort any previous OAuth flow + oauthAbortRef.current?.(); + + // MCP uses dynamic OAuth discovery per server URL + const isMCP = provider === "mcp" && !!discriminatorValue; + + try { + let login_url: string; + let state_token: string; + + if (isMCP) { + const mcpLoginResponse = await postV2InitiateOauthLoginForAnMcpServer({ + server_url: discriminatorValue!, + }); + if (mcpLoginResponse.status !== 200) throw mcpLoginResponse.data; + ({ login_url, state_token } = mcpLoginResponse.data); + } else { + ({ login_url, state_token } = await api.oAuthLogin( + provider, + schema.credentials_scopes, + )); + } + + setOAuth2FlowInProgress(true); + + const { promise, cleanup } = openOAuthPopup(login_url, { + stateToken: state_token, + useCrossOriginListeners: isMCP, + // Standard OAuth uses "oauth_popup_result", MCP uses "mcp_oauth_result" + acceptMessageTypes: isMCP + ? ["mcp_oauth_result"] + : ["oauth_popup_result"], + }); + + oauthAbortRef.current = cleanup.abort; + // Expose abort signal for the waiting modal's cancel button + const controller = new AbortController(); + cleanup.signal.addEventListener("abort", () => + controller.abort("completed"), ); - } + setOAuthPopupController(controller); - const controller = new AbortController(); - setOAuthPopupController(controller); - controller.signal.onabort = () => { - console.debug("OAuth flow aborted"); - setOAuth2FlowInProgress(false); - popup.close(); - }; + const result = await promise; - const handleMessage = async (e: MessageEvent) => { - console.debug("Message received:", e.data); - if ( - typeof e.data != "object" || - !("message_type" in e.data) || - e.data.message_type !== "oauth_popup_result" - ) { - console.debug("Ignoring irrelevant message"); - return; - } + // Exchange code for tokens via the provider (updates credential cache) + const credentialResult = isMCP + ? await mcpOAuthCallback(result.code, state_token) + : await oAuthCallback(result.code, result.state); - if (!e.data.success) { - console.error("OAuth flow failed:", e.data.message); - setOAuthError(`OAuth flow failed: ${e.data.message}`); - setOAuth2FlowInProgress(false); - return; - } - - if (e.data.state !== state_token) { - console.error("Invalid state token received"); - setOAuthError("Invalid state token received"); - setOAuth2FlowInProgress(false); - return; - } - - try { - console.debug("Processing OAuth callback"); - const credentials = await oAuthCallback(e.data.code, e.data.state); - console.debug("OAuth callback processed successfully"); - - // Check if the credential's scopes match the required scopes + // Check if the credential's scopes match the required scopes (skip for MCP) + if (!isMCP) { const requiredScopes = schema.credentials_scopes; if (requiredScopes && requiredScopes.length > 0) { - const grantedScopes = new Set(credentials.scopes || []); + const grantedScopes = new Set(credentialResult.scopes || []); const hasAllRequiredScopes = new Set(requiredScopes).isSubsetOf( grantedScopes, ); if (!hasAllRequiredScopes) { - console.error( - `Newly created OAuth credential for ${providerName} has insufficient scopes. Required:`, - requiredScopes, - "Granted:", - credentials.scopes, - ); setOAuthError( "Connection failed: the granted permissions don't match what's required. " + "Please contact the application administrator.", @@ -230,38 +239,28 @@ export function useCredentialsInput({ return; } } + } - onSelectCredential({ - id: credentials.id, - type: "oauth2", - title: credentials.title, - provider, - }); - } catch (error) { - console.error("Error in OAuth callback:", error); + onSelectCredential({ + id: credentialResult.id, + type: "oauth2", + title: credentialResult.title, + provider, + }); + } catch (error) { + if (error instanceof Error && error.message === "OAuth flow timed out") { + setOAuthError("OAuth flow timed out"); + } else { setOAuthError( - `Error in OAuth callback: ${ + `OAuth error: ${ error instanceof Error ? error.message : String(error) }`, ); - } finally { - console.debug("Finalizing OAuth flow"); - setOAuth2FlowInProgress(false); - controller.abort("success"); } - }; - - console.debug("Adding message event listener"); - window.addEventListener("message", handleMessage, { - signal: controller.signal, - }); - - setTimeout(() => { - console.debug("OAuth flow timed out"); - controller.abort("timeout"); + } finally { setOAuth2FlowInProgress(false); - setOAuthError("OAuth flow timed out"); - }, OAUTH_TIMEOUT_MS); + oauthAbortRef.current = null; + } } function handleActionButtonClick() { diff --git a/autogpt_platform/frontend/src/hooks/useCredentials.ts b/autogpt_platform/frontend/src/hooks/useCredentials.ts index eda6ab0278..9a78e5b8f4 100644 --- a/autogpt_platform/frontend/src/hooks/useCredentials.ts +++ b/autogpt_platform/frontend/src/hooks/useCredentials.ts @@ -100,6 +100,11 @@ export default function useCredentials( return false; } + // Filter MCP OAuth2 credentials by server URL matching + if (c.type === "oauth2" && c.provider === "mcp") { + return discriminatorValue != null && c.host === discriminatorValue; + } + // Filter by OAuth credentials that have sufficient scopes for this block if (c.type === "oauth2") { const requiredScopes = credsInputSchema.credentials_scopes; diff --git a/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts b/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts index 65625f1cfb..ffc21269e6 100644 --- a/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts +++ b/autogpt_platform/frontend/src/lib/autogpt-server-api/types.ts @@ -749,10 +749,12 @@ export enum BlockUIType { AGENT = "Agent", AI = "AI", AYRSHARE = "Ayrshare", + MCP_TOOL = "MCP Tool", } export enum SpecialBlockID { AGENT = "e189baac-8c20-45a1-94a7-55177ea42565", + MCP_TOOL = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4", SMART_DECISION = "3b191d9f-356f-482d-8238-ba04b6d18381", OUTPUT = "363ae599-353e-4804-937e-b2ee3cef3da4", } diff --git a/autogpt_platform/frontend/src/lib/oauth-popup.ts b/autogpt_platform/frontend/src/lib/oauth-popup.ts new file mode 100644 index 0000000000..2927887751 --- /dev/null +++ b/autogpt_platform/frontend/src/lib/oauth-popup.ts @@ -0,0 +1,177 @@ +/** + * Shared utility for OAuth popup flows with cross-origin support. + * + * Handles BroadcastChannel, postMessage, and localStorage polling + * to reliably receive OAuth callback results even when COOP headers + * sever the window.opener relationship. + */ + +const DEFAULT_TIMEOUT_MS = 5 * 60 * 1000; // 5 minutes + +export type OAuthPopupResult = { + code: string; + state: string; +}; + +export type OAuthPopupOptions = { + /** State token to validate against incoming messages */ + stateToken: string; + /** + * Use BroadcastChannel + localStorage polling for cross-origin OAuth (MCP). + * Standard OAuth only uses postMessage via window.opener. + */ + useCrossOriginListeners?: boolean; + /** BroadcastChannel name (default: "mcp_oauth") */ + broadcastChannelName?: string; + /** localStorage key for cross-origin fallback (default: "mcp_oauth_result") */ + localStorageKey?: string; + /** Message types to accept (default: ["oauth_popup_result", "mcp_oauth_result"]) */ + acceptMessageTypes?: string[]; + /** Timeout in ms (default: 5 minutes) */ + timeout?: number; +}; + +type Cleanup = { + /** Abort the OAuth flow and close the popup */ + abort: (reason?: string) => void; + /** The AbortController signal */ + signal: AbortSignal; +}; + +/** + * Opens an OAuth popup and sets up listeners for the callback result. + * + * Opens a blank popup synchronously (to avoid popup blockers), then navigates + * it to the login URL. Returns a promise that resolves with the OAuth code/state. + * + * @param loginUrl - The OAuth authorization URL to navigate to + * @param options - Configuration for message handling + * @returns Object with `promise` (resolves with OAuth result) and `abort` (cancels flow) + */ +export function openOAuthPopup( + loginUrl: string, + options: OAuthPopupOptions, +): { promise: Promise; cleanup: Cleanup } { + const { + stateToken, + useCrossOriginListeners = false, + broadcastChannelName = "mcp_oauth", + localStorageKey = "mcp_oauth_result", + acceptMessageTypes = ["oauth_popup_result", "mcp_oauth_result"], + timeout = DEFAULT_TIMEOUT_MS, + } = options; + + const controller = new AbortController(); + + // Open popup synchronously (before any async work) to avoid browser popup blockers + const width = 500; + const height = 700; + const left = window.screenX + (window.outerWidth - width) / 2; + const top = window.screenY + (window.outerHeight - height) / 2; + const popup = window.open( + "about:blank", + "_blank", + `width=${width},height=${height},left=${left},top=${top},popup=true,scrollbars=yes`, + ); + + if (popup && !popup.closed) { + popup.location.href = loginUrl; + } else { + // Popup was blocked — open in new tab as fallback + window.open(loginUrl, "_blank"); + } + + // Close popup on abort + controller.signal.addEventListener("abort", () => { + if (popup && !popup.closed) popup.close(); + }); + + // Clear any stale localStorage entry + if (useCrossOriginListeners) { + try { + localStorage.removeItem(localStorageKey); + } catch {} + } + + const promise = new Promise((resolve, reject) => { + let handled = false; + + const handleResult = (data: any) => { + if (handled) return; // Prevent double-handling + + // Validate message type + const messageType = data?.message_type ?? data?.type; + if (!messageType || !acceptMessageTypes.includes(messageType)) return; + + // Validate state token + if (data.state !== stateToken) { + // State mismatch — this message is for a different listener. Ignore silently. + return; + } + + handled = true; + + if (!data.success) { + reject(new Error(data.message || "OAuth authentication failed")); + } else { + resolve({ code: data.code, state: data.state }); + } + + controller.abort("completed"); + }; + + // Listener: postMessage (works for same-origin popups) + window.addEventListener( + "message", + (event: MessageEvent) => { + if (typeof event.data === "object") { + handleResult(event.data); + } + }, + { signal: controller.signal }, + ); + + // Cross-origin listeners for MCP OAuth + if (useCrossOriginListeners) { + // Listener: BroadcastChannel (works across tabs/popups without opener) + try { + const bc = new BroadcastChannel(broadcastChannelName); + bc.onmessage = (event) => handleResult(event.data); + controller.signal.addEventListener("abort", () => bc.close()); + } catch {} + + // Listener: localStorage polling (most reliable cross-tab fallback) + const pollInterval = setInterval(() => { + try { + const stored = localStorage.getItem(localStorageKey); + if (stored) { + const data = JSON.parse(stored); + localStorage.removeItem(localStorageKey); + handleResult(data); + } + } catch {} + }, 500); + controller.signal.addEventListener("abort", () => + clearInterval(pollInterval), + ); + } + + // Timeout + const timeoutId = setTimeout(() => { + if (!handled) { + handled = true; + reject(new Error("OAuth flow timed out")); + controller.abort("timeout"); + } + }, timeout); + controller.signal.addEventListener("abort", () => clearTimeout(timeoutId)); + }); + + return { + promise, + cleanup: { + abort: (reason?: string) => controller.abort(reason || "canceled"), + signal: controller.signal, + }, + }; +} diff --git a/autogpt_platform/frontend/src/middleware.ts b/autogpt_platform/frontend/src/middleware.ts index af1c823295..8cec8a2645 100644 --- a/autogpt_platform/frontend/src/middleware.ts +++ b/autogpt_platform/frontend/src/middleware.ts @@ -18,6 +18,6 @@ export const config = { * Note: /auth/authorize and /auth/integrations/* ARE protected and need * middleware to run for authentication checks. */ - "/((?!_next/static|_next/image|favicon.ico|auth/callback|.*\\.(?:svg|png|jpg|jpeg|gif|webp)$).*)", + "/((?!_next/static|_next/image|favicon.ico|auth/callback|auth/integrations/mcp_callback|.*\\.(?:svg|png|jpg|jpeg|gif|webp)$).*)", ], }; diff --git a/autogpt_platform/frontend/src/providers/agent-credentials/credentials-provider.tsx b/autogpt_platform/frontend/src/providers/agent-credentials/credentials-provider.tsx index e47cc65e13..a426d8f667 100644 --- a/autogpt_platform/frontend/src/providers/agent-credentials/credentials-provider.tsx +++ b/autogpt_platform/frontend/src/providers/agent-credentials/credentials-provider.tsx @@ -8,6 +8,7 @@ import { HostScopedCredentials, UserPasswordCredentials, } from "@/lib/autogpt-server-api"; +import { postV2ExchangeOauthCodeForMcpTokens } from "@/app/api/__generated__/endpoints/mcp/mcp"; import { useBackendAPI } from "@/lib/autogpt-server-api/context"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { toDisplayName } from "@/providers/agent-credentials/helper"; @@ -38,6 +39,11 @@ export type CredentialsProviderData = { code: string, state_token: string, ) => Promise; + /** MCP-specific OAuth callback that uses dynamic per-server OAuth discovery. */ + mcpOAuthCallback: ( + code: string, + state_token: string, + ) => Promise; createAPIKeyCredentials: ( credentials: APIKeyCredentialsCreatable, ) => Promise; @@ -120,6 +126,35 @@ export default function CredentialsProvider({ [api, addCredentials, onFailToast], ); + /** Exchanges an MCP OAuth code for tokens and adds the result to the internal credentials store. */ + const mcpOAuthCallback = useCallback( + async ( + code: string, + state_token: string, + ): Promise => { + try { + const response = await postV2ExchangeOauthCodeForMcpTokens({ + code, + state_token, + }); + if (response.status !== 200) throw response.data; + const credsMeta: CredentialsMetaResponse = { + ...response.data, + title: response.data.title ?? undefined, + scopes: response.data.scopes ?? undefined, + username: response.data.username ?? undefined, + host: response.data.host ?? undefined, + }; + addCredentials("mcp", credsMeta); + return credsMeta; + } catch (error) { + onFailToast("complete MCP OAuth authentication")(error); + throw error; + } + }, + [addCredentials, onFailToast], + ); + /** Wraps `BackendAPI.createAPIKeyCredentials`, and adds the result to the internal credentials store. */ const createAPIKeyCredentials = useCallback( async ( @@ -258,6 +293,7 @@ export default function CredentialsProvider({ isSystemProvider: systemProviders.has(provider), oAuthCallback: (code: string, state_token: string) => oAuthCallback(provider, code, state_token), + mcpOAuthCallback, createAPIKeyCredentials: ( credentials: APIKeyCredentialsCreatable, ) => createAPIKeyCredentials(provider, credentials), @@ -286,6 +322,7 @@ export default function CredentialsProvider({ createHostScopedCredentials, deleteCredentials, oAuthCallback, + mcpOAuthCallback, onFailToast, ]); diff --git a/autogpt_platform/frontend/src/tests/pages/build.page.ts b/autogpt_platform/frontend/src/tests/pages/build.page.ts index 9370288f8e..3bb9552b82 100644 --- a/autogpt_platform/frontend/src/tests/pages/build.page.ts +++ b/autogpt_platform/frontend/src/tests/pages/build.page.ts @@ -528,6 +528,9 @@ export class BuildPage extends BasePage { async getBlocksToSkip(): Promise { return [ (await this.getGithubTriggerBlockDetails()).map((b) => b.id), + // MCP Tool block requires an interactive dialog (server URL + OAuth) before + // it can be placed, so it can't be tested via the standard "add block" flow. + "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4", ].flat(); } diff --git a/docs/integrations/README.md b/docs/integrations/README.md index a471ef3533..c216aa4836 100644 --- a/docs/integrations/README.md +++ b/docs/integrations/README.md @@ -467,6 +467,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim | [Github Update Comment](block-integrations/github/issues.md#github-update-comment) | A block that updates an existing comment on a GitHub issue or pull request | | [Github Update File](block-integrations/github/repo.md#github-update-file) | This block updates an existing file in a GitHub repository | | [Instantiate Code Sandbox](block-integrations/misc.md#instantiate-code-sandbox) | Instantiate a sandbox environment with internet access in which you can execute code with the Execute Code Step block | +| [MCP Tool](block-integrations/mcp/block.md#mcp-tool) | Connect to any MCP server and execute its tools | | [Slant3D Order Webhook](block-integrations/slant3d/webhook.md#slant3d-order-webhook) | This block triggers on Slant3D order status updates and outputs the event details, including tracking information when orders are shipped | ## Media Generation diff --git a/docs/integrations/SUMMARY.md b/docs/integrations/SUMMARY.md index f481ae2e0a..3ad4bf2c6d 100644 --- a/docs/integrations/SUMMARY.md +++ b/docs/integrations/SUMMARY.md @@ -84,6 +84,7 @@ * [Linear Projects](block-integrations/linear/projects.md) * [LLM](block-integrations/llm.md) * [Logic](block-integrations/logic.md) +* [Mcp Block](block-integrations/mcp/block.md) * [Misc](block-integrations/misc.md) * [Notion Create Page](block-integrations/notion/create_page.md) * [Notion Read Database](block-integrations/notion/read_database.md) diff --git a/docs/integrations/block-integrations/mcp/block.md b/docs/integrations/block-integrations/mcp/block.md new file mode 100644 index 0000000000..6858e42e94 --- /dev/null +++ b/docs/integrations/block-integrations/mcp/block.md @@ -0,0 +1,40 @@ +# Mcp Block + +Blocks for connecting to and executing tools on MCP (Model Context Protocol) servers. + + +## MCP Tool + +### What it is +Connect to any MCP server and execute its tools. Provide a server URL, select a tool, and pass arguments dynamically. + +### How it works + +The block uses JSON-RPC 2.0 over HTTP to communicate with MCP servers. When configuring, it sends an `initialize` request followed by `tools/list` to discover available tools and their input schemas. On execution, it calls `tools/call` with the selected tool name and arguments, then extracts text, image, or resource content from the response. + +Authentication is handled via OAuth 2.0 when the server requires it. The block supports optional credentials — public servers work without authentication, while protected servers trigger a standard OAuth flow with PKCE. Tokens are automatically refreshed when they expire. + + +### Inputs + +| Input | Description | Type | Required | +|-------|-------------|------|----------| +| server_url | URL of the MCP server (Streamable HTTP endpoint) | str | Yes | +| selected_tool | The MCP tool to execute | str | No | +| tool_arguments | Arguments to pass to the selected MCP tool. The fields here are defined by the tool's input schema. | Dict[str, Any] | No | + +### Outputs + +| Output | Description | Type | +|--------|-------------|------| +| error | Error message if the tool call failed | str | +| result | The result returned by the MCP tool | Result | + +### Possible use case + +- **Connecting to third-party APIs**: Use an MCP server like Sentry or Linear to query issues, create tickets, or manage projects without building custom integrations. +- **AI-powered tool execution**: Chain MCP tool calls with AI blocks to let agents dynamically discover and use external tools based on task requirements. +- **Data retrieval from knowledge bases**: Connect to MCP servers like DeepWiki to search documentation, retrieve code context, or query structured knowledge bases. + + +---