diff --git a/autogpt_platform/backend/backend/blocks/mcp/block.py b/autogpt_platform/backend/backend/blocks/mcp/block.py index e7ac707758..e11483dc67 100644 --- a/autogpt_platform/backend/backend/blocks/mcp/block.py +++ b/autogpt_platform/backend/backend/blocks/mcp/block.py @@ -26,6 +26,7 @@ from backend.data.model import ( APIKeyCredentials, CredentialsField, CredentialsMetaInput, + OAuth2Credentials, SchemaField, ) from backend.integrations.providers import ProviderName @@ -33,8 +34,9 @@ from backend.util.json import validate_with_jsonschema logger = logging.getLogger(__name__) +MCPCredentials = APIKeyCredentials | OAuth2Credentials MCPCredentialsInput = CredentialsMetaInput[ - Literal[ProviderName.MCP], Literal["api_key"] + Literal[ProviderName.MCP], Literal["api_key", "oauth2"] ] TEST_CREDENTIALS = APIKeyCredentials( @@ -67,8 +69,9 @@ class MCPToolBlock(Block): class Input(BlockSchemaInput): # -- Static fields (always shown) -- credentials: MCPCredentialsInput = CredentialsField( - description="API key / Bearer token for the MCP server (optional for " - "public servers — create a credential with any placeholder value).", + description="Credentials for the MCP server. Use an API key for Bearer " + "token auth, or OAuth2 for servers that support it. For public " + "servers, create a credential with any placeholder value.", ) server_url: str = SchemaField( description="URL of the MCP server (Streamable HTTP endpoint)", @@ -217,11 +220,25 @@ class MCPToolBlock(Block): return output_parts[0] return output_parts if output_parts else None + @staticmethod + def _extract_auth_token(credentials: MCPCredentials) -> str | None: + """Extract a Bearer token from either API key or OAuth2 credentials.""" + if isinstance(credentials, OAuth2Credentials): + return credentials.access_token.get_secret_value() + + if isinstance(credentials, APIKeyCredentials) and credentials.api_key: + token_value = credentials.api_key.get_secret_value() + # Skip placeholder/fake tokens + if token_value and token_value not in ("", "FAKE_API_KEY", "placeholder"): + return token_value + + return None + async def run( self, input_data: Input, *, - credentials: APIKeyCredentials, + credentials: MCPCredentials, **kwargs, ) -> BlockOutput: if not input_data.server_url: @@ -232,12 +249,7 @@ class MCPToolBlock(Block): yield "error", "No tool selected. Please select a tool from the dropdown." return - auth_token: str | None = None - if credentials and credentials.api_key: - token_value = credentials.api_key.get_secret_value() - # Skip placeholder/fake tokens - if token_value and token_value not in ("", "FAKE_API_KEY", "placeholder"): - auth_token = token_value + auth_token = self._extract_auth_token(credentials) try: result = await self._call_mcp_tool( diff --git a/autogpt_platform/backend/backend/blocks/mcp/client.py b/autogpt_platform/backend/backend/blocks/mcp/client.py index 9953fe013e..c92b19b386 100644 --- a/autogpt_platform/backend/backend/blocks/mcp/client.py +++ b/autogpt_platform/backend/backend/blocks/mcp/client.py @@ -4,9 +4,12 @@ MCP (Model Context Protocol) HTTP client. Implements the MCP Streamable HTTP transport for listing tools and calling tools on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST. -Reference: https://modelcontextprotocol.io/docs/concepts/transports +Handles both JSON and SSE (text/event-stream) response formats per the MCP spec. + +Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports """ +import json import logging from dataclasses import dataclass, field from typing import Any @@ -83,10 +86,43 @@ class MCPClient: req["params"] = params return req + @staticmethod + def _parse_sse_response(text: str) -> dict[str, Any]: + """Parse an SSE (text/event-stream) response body into JSON-RPC data. + + MCP servers may return responses as SSE with format: + event: message + data: {"jsonrpc":"2.0","result":{...},"id":1} + + We extract the last `data:` line that contains a JSON-RPC response + (i.e. has an "id" field), which is the reply to our request. + """ + last_data: dict[str, Any] | None = None + for line in text.splitlines(): + stripped = line.strip() + if stripped.startswith("data:"): + payload = stripped[len("data:") :].strip() + if not payload: + continue + try: + parsed = json.loads(payload) + # Only keep JSON-RPC responses (have "id"), skip notifications + if isinstance(parsed, dict) and "id" in parsed: + last_data = parsed + except (json.JSONDecodeError, ValueError): + continue + if last_data is None: + raise MCPClientError("No JSON-RPC response found in SSE stream") + return last_data + async def _send_request( self, method: str, params: dict[str, Any] | None = None ) -> Any: - """Send a JSON-RPC request to the MCP server and return the result.""" + """Send a JSON-RPC request to the MCP server and return the result. + + Handles both ``application/json`` and ``text/event-stream`` responses + as required by the MCP Streamable HTTP transport specification. + """ payload = self._build_jsonrpc_request(method, params) headers = self._build_headers() @@ -96,7 +132,12 @@ class MCPClient: trusted_origins=self.trusted_origins, ) response = await requests.post(self.server_url, json=payload) - body = response.json() + + content_type = response.headers.get("content-type", "") + if "text/event-stream" in content_type: + body = self._parse_sse_response(response.text()) + else: + body = response.json() # Handle JSON-RPC error if "error" in body: @@ -119,6 +160,90 @@ class MCPClient: ) await requests.post(self.server_url, json=notification) + async def discover_auth(self) -> dict[str, Any] | None: + """Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec). + + Returns ``None`` if the server doesn't require auth, otherwise returns + a dict with: + - ``authorization_servers``: list of authorization server URLs + - ``resource``: the resource indicator URL (usually the MCP endpoint) + - ``scopes_supported``: optional list of supported scopes + + The caller can then fetch the authorization server metadata to get + ``authorization_endpoint``, ``token_endpoint``, etc. + """ + from urllib.parse import urlparse + + parsed = urlparse(self.server_url) + base = f"{parsed.scheme}://{parsed.netloc}" + + # Build candidates for protected-resource metadata (per RFC 9728) + path = parsed.path.rstrip("/") + candidates = [] + if path and path != "/": + candidates.append(f"{base}/.well-known/oauth-protected-resource{path}") + candidates.append(f"{base}/.well-known/oauth-protected-resource") + + requests = Requests( + raise_for_status=False, + trusted_origins=self.trusted_origins, + ) + for url in candidates: + try: + resp = await requests.get(url) + if resp.status == 200: + data = resp.json() + if isinstance(data, dict) and "authorization_servers" in data: + return data + except Exception: + continue + + return None + + async def discover_auth_server_metadata( + self, auth_server_url: str + ) -> dict[str, Any] | None: + """Fetch the OAuth Authorization Server Metadata (RFC 8414). + + Given an authorization server URL, returns a dict with: + - ``authorization_endpoint`` + - ``token_endpoint`` + - ``registration_endpoint`` (for dynamic client registration) + - ``scopes_supported`` + - ``code_challenge_methods_supported`` + - etc. + """ + from urllib.parse import urlparse + + parsed = urlparse(auth_server_url) + base = f"{parsed.scheme}://{parsed.netloc}" + path = parsed.path.rstrip("/") + + # Try standard metadata endpoints (RFC 8414 and OpenID Connect) + candidates = [] + if path and path != "/": + candidates.append( + f"{base}/.well-known/oauth-authorization-server{path}" + ) + candidates.append(f"{base}/.well-known/oauth-authorization-server") + candidates.append(f"{base}/.well-known/openid-configuration") + + requests = Requests( + raise_for_status=False, + trusted_origins=self.trusted_origins, + ) + for url in candidates: + try: + resp = await requests.get(url) + if resp.status == 200: + data = resp.json() + if isinstance(data, dict) and "authorization_endpoint" in data: + return data + except Exception: + continue + + return None + async def initialize(self) -> dict[str, Any]: """ Send the MCP initialize request. diff --git a/autogpt_platform/backend/backend/blocks/mcp/oauth.py b/autogpt_platform/backend/backend/blocks/mcp/oauth.py new file mode 100644 index 0000000000..ecf680393b --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/oauth.py @@ -0,0 +1,188 @@ +""" +MCP OAuth handler for MCP servers that use OAuth 2.1 authorization. + +Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed, +MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata. +This handler accepts those endpoints at construction time. +""" + +import logging +import time +import urllib.parse +from typing import ClassVar, Optional + +from pydantic import SecretStr + +from backend.data.model import OAuth2Credentials +from backend.integrations.providers import ProviderName +from backend.integrations.oauth.base import BaseOAuthHandler +from backend.util.request import Requests + +logger = logging.getLogger(__name__) + + +class MCPOAuthHandler(BaseOAuthHandler): + """ + OAuth handler for MCP servers with dynamically-discovered endpoints. + + Construction requires the authorization and token endpoint URLs, + which are obtained via MCP OAuth metadata discovery + (``MCPClient.discover_auth`` + ``discover_auth_server_metadata``). + """ + + PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP + DEFAULT_SCOPES: ClassVar[list[str]] = [] + + def __init__( + self, + client_id: str, + client_secret: str, + redirect_uri: str, + *, + authorize_url: str, + token_url: str, + revoke_url: str | None = None, + resource_url: str | None = None, + ): + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + self.authorize_url = authorize_url + self.token_url = token_url + self.revoke_url = revoke_url + self.resource_url = resource_url + + def get_login_url( + self, + scopes: list[str], + state: str, + code_challenge: Optional[str], + ) -> str: + scopes = self.handle_default_scopes(scopes) + + params: dict[str, str] = { + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "state": state, + } + if scopes: + params["scope"] = " ".join(scopes) + # PKCE is required by the MCP spec (S256 only) + if code_challenge: + params["code_challenge"] = code_challenge + params["code_challenge_method"] = "S256" + # MCP spec requires resource indicator (RFC 8707) + if self.resource_url: + params["resource"] = self.resource_url + + return f"{self.authorize_url}?{urllib.parse.urlencode(params)}" + + async def exchange_code_for_tokens( + self, + code: str, + scopes: list[str], + code_verifier: Optional[str], + ) -> OAuth2Credentials: + data: dict[str, str] = { + "grant_type": "authorization_code", + "code": code, + "redirect_uri": self.redirect_uri, + "client_id": self.client_id, + } + if self.client_secret: + data["client_secret"] = self.client_secret + if code_verifier: + data["code_verifier"] = code_verifier + if self.resource_url: + data["resource"] = self.resource_url + + response = await Requests().post( + self.token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + tokens = response.json() + + now = int(time.time()) + expires_in = tokens.get("expires_in") + + return OAuth2Credentials( + provider=str(self.PROVIDER_NAME), + title=None, + access_token=SecretStr(tokens["access_token"]), + refresh_token=( + SecretStr(tokens["refresh_token"]) + if tokens.get("refresh_token") + else None + ), + access_token_expires_at=now + expires_in if expires_in else None, + refresh_token_expires_at=None, + scopes=scopes, + metadata={ + "mcp_token_url": self.token_url, + "mcp_resource_url": self.resource_url, + }, + ) + + async def _refresh_tokens( + self, credentials: OAuth2Credentials + ) -> OAuth2Credentials: + if not credentials.refresh_token: + raise ValueError("No refresh token available for MCP OAuth credentials") + + data: dict[str, str] = { + "grant_type": "refresh_token", + "refresh_token": credentials.refresh_token.get_secret_value(), + "client_id": self.client_id, + } + if self.client_secret: + data["client_secret"] = self.client_secret + if self.resource_url: + data["resource"] = self.resource_url + + response = await Requests().post( + self.token_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + tokens = response.json() + + now = int(time.time()) + expires_in = tokens.get("expires_in") + + return OAuth2Credentials( + id=credentials.id, + provider=str(self.PROVIDER_NAME), + title=credentials.title, + access_token=SecretStr(tokens["access_token"]), + refresh_token=( + SecretStr(str(tokens["refresh_token"])) + if tokens.get("refresh_token") + else credentials.refresh_token + ), + access_token_expires_at=now + expires_in if expires_in else None, + refresh_token_expires_at=credentials.refresh_token_expires_at, + scopes=credentials.scopes, + metadata=credentials.metadata, + ) + + async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool: + if not self.revoke_url: + return False + + try: + data = { + "token": credentials.access_token.get_secret_value(), + "token_type_hint": "access_token", + "client_id": self.client_id, + } + await Requests().post( + self.revoke_url, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + return True + except Exception: + logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True) + return False diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py b/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py new file mode 100644 index 0000000000..6c85292929 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_e2e.py @@ -0,0 +1,104 @@ +""" +End-to-end tests against a real public MCP server. + +These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp) +which is publicly accessible without authentication and returns SSE responses. + +Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped +independently of the rest of the test suite (they require network access). +""" + +import json + +import pytest + +from backend.blocks.mcp.client import MCPClient + +# Public MCP server that requires no authentication +OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp" + + +@pytest.mark.e2e +class TestRealMCPServer: + """Tests against the live OpenAI docs MCP server.""" + + @pytest.mark.asyncio + async def test_initialize(self): + """Verify we can complete the MCP handshake with a real server.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + result = await client.initialize() + + assert result["protocolVersion"] == "2025-03-26" + assert "serverInfo" in result + assert result["serverInfo"]["name"] == "openai-docs-mcp" + assert "tools" in result.get("capabilities", {}) + + @pytest.mark.asyncio + async def test_list_tools(self): + """Verify we can discover tools from a real MCP server.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + tools = await client.list_tools() + + assert len(tools) >= 3 # server has at least 5 tools as of writing + + tool_names = {t.name for t in tools} + # These tools are documented and should be stable + assert "search_openai_docs" in tool_names + assert "list_openai_docs" in tool_names + assert "fetch_openai_doc" in tool_names + + # Verify schema structure + search_tool = next(t for t in tools if t.name == "search_openai_docs") + assert "query" in search_tool.input_schema.get("properties", {}) + assert "query" in search_tool.input_schema.get("required", []) + + @pytest.mark.asyncio + async def test_call_tool_list_api_endpoints(self): + """Call the list_api_endpoints tool and verify we get real data.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + result = await client.call_tool("list_api_endpoints", {}) + + assert not result.is_error + assert len(result.content) >= 1 + assert result.content[0]["type"] == "text" + + data = json.loads(result.content[0]["text"]) + assert "paths" in data or "urls" in data + # The OpenAI API should have many endpoints + total = data.get("total", len(data.get("paths", []))) + assert total > 50 + + @pytest.mark.asyncio + async def test_call_tool_search(self): + """Search for docs and verify we get results.""" + client = MCPClient(OPENAI_DOCS_MCP_URL) + await client.initialize() + result = await client.call_tool( + "search_openai_docs", {"query": "chat completions", "limit": 3} + ) + + assert not result.is_error + assert len(result.content) >= 1 + + @pytest.mark.asyncio + async def test_sse_response_handling(self): + """Verify the client correctly handles SSE responses from a real server. + + This is the key test — our local test server returns JSON, + but real MCP servers typically return SSE. This proves the + SSE parsing works end-to-end. + """ + client = MCPClient(OPENAI_DOCS_MCP_URL) + # initialize() internally calls _send_request which must parse SSE + result = await client.initialize() + + # If we got here without error, SSE parsing works + assert isinstance(result, dict) + assert "protocolVersion" in result + + # Also verify list_tools works (another SSE response) + tools = await client.list_tools() + assert len(tools) > 0 + assert all(hasattr(t, "name") for t in tools) diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py b/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py index 4d92559460..521b35c5f8 100644 --- a/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py +++ b/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py @@ -7,8 +7,80 @@ from unittest.mock import AsyncMock, patch import pytest +from pydantic import SecretStr + from backend.blocks.mcp.block import MCPToolBlock, TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT +from backend.data.model import APIKeyCredentials, OAuth2Credentials from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError, MCPTool + +# ── SSE parsing unit tests ─────────────────────────────────────────── + + +class TestSSEParsing: + """Tests for SSE (text/event-stream) response parsing.""" + + def test_parse_sse_simple(self): + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == {"tools": []} + assert body["id"] == 1 + + def test_parse_sse_with_notifications(self): + """SSE streams can contain notifications (no id) before the response.""" + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","method":"some/notification"}\n' + "\n" + "event: message\n" + 'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == {"ok": True} + assert body["id"] == 2 + + def test_parse_sse_error_response(self): + sse = ( + "event: message\n" + 'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n' + ) + body = MCPClient._parse_sse_response(sse) + assert "error" in body + assert body["error"]["code"] == -32600 + + def test_parse_sse_no_data_raises(self): + with pytest.raises(MCPClientError, match="No JSON-RPC response found"): + MCPClient._parse_sse_response("event: message\n\n") + + def test_parse_sse_empty_raises(self): + with pytest.raises(MCPClientError, match="No JSON-RPC response found"): + MCPClient._parse_sse_response("") + + def test_parse_sse_ignores_non_data_lines(self): + sse = ( + ": comment line\n" + "event: message\n" + "id: 123\n" + 'data: {"jsonrpc":"2.0","result":"ok","id":1}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == "ok" + + def test_parse_sse_uses_last_response(self): + """If multiple responses exist, use the last one.""" + sse = ( + 'data: {"jsonrpc":"2.0","result":"first","id":1}\n' + "\n" + 'data: {"jsonrpc":"2.0","result":"second","id":2}\n' + "\n" + ) + body = MCPClient._parse_sse_response(sse) + assert body["result"] == "second" from backend.util.test import execute_block_test @@ -505,9 +577,6 @@ class TestMCPToolBlock: @pytest.mark.asyncio async def test_run_skips_placeholder_credentials(self): """Ensure placeholder API keys are not sent to the MCP server.""" - from backend.data.model import APIKeyCredentials - from pydantic import SecretStr - block = MCPToolBlock() input_data = MCPToolBlock.Input( server_url="https://mcp.example.com/mcp", @@ -534,3 +603,74 @@ class TestMCPToolBlock: pass assert captured_tokens == [None] + + +# ── OAuth2 credential support tests ───────────────────────────────── + + +class TestMCPOAuth2Support: + """Tests for OAuth2 credential support in MCPToolBlock.""" + + def test_extract_auth_token_from_api_key(self): + creds = APIKeyCredentials( + id="test", provider="mcp", + api_key=SecretStr("my-api-key"), title="test", + ) + token = MCPToolBlock._extract_auth_token(creds) + assert token == "my-api-key" + + def test_extract_auth_token_from_oauth2(self): + creds = OAuth2Credentials( + id="test", provider="mcp", + access_token=SecretStr("oauth2-access-token"), + scopes=["read"], title="test", + ) + token = MCPToolBlock._extract_auth_token(creds) + assert token == "oauth2-access-token" + + def test_extract_auth_token_placeholder_skipped(self): + creds = APIKeyCredentials( + id="test", provider="mcp", + api_key=SecretStr("FAKE_API_KEY"), title="test", + ) + token = MCPToolBlock._extract_auth_token(creds) + assert token is None + + def test_extract_auth_token_empty_skipped(self): + creds = APIKeyCredentials( + id="test", provider="mcp", + api_key=SecretStr(""), title="test", + ) + token = MCPToolBlock._extract_auth_token(creds) + assert token is None + + @pytest.mark.asyncio + async def test_run_with_oauth2_credentials(self): + """Verify the block can run with OAuth2 credentials.""" + block = MCPToolBlock() + input_data = MCPToolBlock.Input( + server_url="https://mcp.example.com/mcp", + selected_tool="test_tool", + credentials=TEST_CREDENTIALS_INPUT, # type: ignore + ) + + oauth2_creds = OAuth2Credentials( + id="test-id", provider="mcp", + access_token=SecretStr("real-oauth2-token"), + scopes=["read", "write"], title="MCP OAuth", + ) + + captured_tokens = [] + + async def mock_call(server_url, tool_name, arguments, auth_token=None): + captured_tokens.append(auth_token) + return {"status": "ok"} + + block._call_mcp_tool = mock_call # type: ignore + + outputs = [] + async for name, data in block.run(input_data, credentials=oauth2_creds): + outputs.append((name, data)) + + assert captured_tokens == ["real-oauth2-token"] + assert outputs == [("result", {"status": "ok"})] diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py b/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py new file mode 100644 index 0000000000..6134def1c6 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/mcp/test_oauth.py @@ -0,0 +1,252 @@ +""" +Tests for MCP OAuth handler. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import SecretStr + +from backend.blocks.mcp.client import MCPClient +from backend.blocks.mcp.oauth import MCPOAuthHandler +from backend.data.model import OAuth2Credentials + + +def _mock_response(json_data: dict, status: int = 200) -> MagicMock: + """Create a mock Response with synchronous json() (matching Requests.Response).""" + resp = MagicMock() + resp.status = status + resp.ok = 200 <= status < 300 + resp.json.return_value = json_data + return resp + + +class TestMCPOAuthHandler: + """Tests for the MCPOAuthHandler.""" + + def _make_handler(self, **overrides) -> MCPOAuthHandler: + defaults = { + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "redirect_uri": "https://app.example.com/callback", + "authorize_url": "https://auth.example.com/authorize", + "token_url": "https://auth.example.com/token", + } + defaults.update(overrides) + return MCPOAuthHandler(**defaults) + + def test_get_login_url_basic(self): + handler = self._make_handler() + url = handler.get_login_url( + scopes=["read", "write"], + state="random-state-token", + code_challenge="S256-challenge-value", + ) + + assert "https://auth.example.com/authorize?" in url + assert "response_type=code" in url + assert "client_id=test-client-id" in url + assert "state=random-state-token" in url + assert "code_challenge=S256-challenge-value" in url + assert "code_challenge_method=S256" in url + assert "scope=read+write" in url + + def test_get_login_url_with_resource(self): + handler = self._make_handler( + resource_url="https://mcp.example.com/mcp" + ) + url = handler.get_login_url( + scopes=[], state="state", code_challenge="challenge" + ) + + assert "resource=https" in url + + def test_get_login_url_without_pkce(self): + handler = self._make_handler() + url = handler.get_login_url( + scopes=["read"], state="state", code_challenge=None + ) + + assert "code_challenge" not in url + assert "code_challenge_method" not in url + + @pytest.mark.asyncio + async def test_exchange_code_for_tokens(self): + handler = self._make_handler() + + resp = _mock_response({ + "access_token": "new-access-token", + "refresh_token": "new-refresh-token", + "expires_in": 3600, + "token_type": "Bearer", + }) + + with patch( + "backend.blocks.mcp.oauth.Requests" + ) as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + creds = await handler.exchange_code_for_tokens( + code="auth-code", + scopes=["read"], + code_verifier="pkce-verifier", + ) + + assert isinstance(creds, OAuth2Credentials) + assert creds.access_token.get_secret_value() == "new-access-token" + assert creds.refresh_token.get_secret_value() == "new-refresh-token" + assert creds.scopes == ["read"] + assert creds.access_token_expires_at is not None + + @pytest.mark.asyncio + async def test_refresh_tokens(self): + handler = self._make_handler() + + existing_creds = OAuth2Credentials( + id="existing-id", + provider="mcp", + access_token=SecretStr("old-token"), + refresh_token=SecretStr("old-refresh"), + scopes=["read"], + title="test", + ) + + resp = _mock_response({ + "access_token": "refreshed-token", + "refresh_token": "new-refresh", + "expires_in": 3600, + }) + + with patch( + "backend.blocks.mcp.oauth.Requests" + ) as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + refreshed = await handler._refresh_tokens(existing_creds) + + assert refreshed.id == "existing-id" + assert refreshed.access_token.get_secret_value() == "refreshed-token" + assert refreshed.refresh_token.get_secret_value() == "new-refresh" + + @pytest.mark.asyncio + async def test_refresh_tokens_no_refresh_token(self): + handler = self._make_handler() + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=["read"], + title="test", + ) + + with pytest.raises(ValueError, match="No refresh token"): + await handler._refresh_tokens(creds) + + @pytest.mark.asyncio + async def test_revoke_tokens_no_url(self): + handler = self._make_handler(revoke_url=None) + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=[], title="test", + ) + + result = await handler.revoke_tokens(creds) + assert result is False + + @pytest.mark.asyncio + async def test_revoke_tokens_with_url(self): + handler = self._make_handler( + revoke_url="https://auth.example.com/revoke" + ) + + creds = OAuth2Credentials( + provider="mcp", + access_token=SecretStr("token"), + scopes=[], title="test", + ) + + resp = _mock_response({}, status=200) + + with patch( + "backend.blocks.mcp.oauth.Requests" + ) as MockRequests: + instance = MockRequests.return_value + instance.post = AsyncMock(return_value=resp) + + result = await handler.revoke_tokens(creds) + + assert result is True + + +class TestMCPClientDiscovery: + """Tests for MCPClient OAuth metadata discovery.""" + + @pytest.mark.asyncio + async def test_discover_auth_found(self): + client = MCPClient("https://mcp.example.com/mcp") + + metadata = { + "authorization_servers": ["https://auth.example.com"], + "resource": "https://mcp.example.com/mcp", + } + + resp = _mock_response(metadata, status=200) + + with patch( + "backend.blocks.mcp.client.Requests" + ) as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth() + + assert result is not None + assert result["authorization_servers"] == ["https://auth.example.com"] + + @pytest.mark.asyncio + async def test_discover_auth_not_found(self): + client = MCPClient("https://mcp.example.com/mcp") + + resp = _mock_response({}, status=404) + + with patch( + "backend.blocks.mcp.client.Requests" + ) as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth() + + assert result is None + + @pytest.mark.asyncio + async def test_discover_auth_server_metadata(self): + client = MCPClient("https://mcp.example.com/mcp") + + server_metadata = { + "issuer": "https://auth.example.com", + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "registration_endpoint": "https://auth.example.com/register", + "code_challenge_methods_supported": ["S256"], + } + + resp = _mock_response(server_metadata, status=200) + + with patch( + "backend.blocks.mcp.client.Requests" + ) as MockRequests: + instance = MockRequests.return_value + instance.get = AsyncMock(return_value=resp) + + result = await client.discover_auth_server_metadata( + "https://auth.example.com" + ) + + assert result is not None + assert result["authorization_endpoint"] == "https://auth.example.com/authorize" + assert result["token_endpoint"] == "https://auth.example.com/token"