feat(backend/blocks/mcp): Add SSE support, OAuth auth, and e2e tests

- Handle text/event-stream (SSE) responses from real MCP servers
  (MCPClient._parse_sse_response) alongside plain JSON responses
- Add e2e tests against OpenAI docs MCP server (developers.openai.com/mcp)
  verifying SSE parsing, tool discovery, and tool execution work with a
  real production MCP server
- Support both api_key and oauth2 credential types on MCPToolBlock
  (MCPCredentials union type, _extract_auth_token helper)
- Add MCPOAuthHandler implementing BaseOAuthHandler with dynamic
  endpoints (authorize_url, token_url) for MCP OAuth 2.1 with PKCE
- Add OAuth metadata discovery to MCPClient (discover_auth,
  discover_auth_server_metadata) per RFC 9728 / RFC 8414
- 76 total tests: 46 unit, 11 OAuth, 14 integration, 5 e2e
This commit is contained in:
Zamil Majdy
2026-02-08 16:32:50 +04:00
parent e9b996abb0
commit 7db3f12876
6 changed files with 837 additions and 16 deletions

View File

@@ -26,6 +26,7 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
OAuth2Credentials,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -33,8 +34,9 @@ from backend.util.json import validate_with_jsonschema
logger = logging.getLogger(__name__)
MCPCredentials = APIKeyCredentials | OAuth2Credentials
MCPCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.MCP], Literal["api_key"]
Literal[ProviderName.MCP], Literal["api_key", "oauth2"]
]
TEST_CREDENTIALS = APIKeyCredentials(
@@ -67,8 +69,9 @@ class MCPToolBlock(Block):
class Input(BlockSchemaInput):
# -- Static fields (always shown) --
credentials: MCPCredentialsInput = CredentialsField(
description="API key / Bearer token for the MCP server (optional for "
"public servers — create a credential with any placeholder value).",
description="Credentials for the MCP server. Use an API key for Bearer "
"token auth, or OAuth2 for servers that support it. For public "
"servers, create a credential with any placeholder value.",
)
server_url: str = SchemaField(
description="URL of the MCP server (Streamable HTTP endpoint)",
@@ -217,11 +220,25 @@ class MCPToolBlock(Block):
return output_parts[0]
return output_parts if output_parts else None
@staticmethod
def _extract_auth_token(credentials: MCPCredentials) -> str | None:
"""Extract a Bearer token from either API key or OAuth2 credentials."""
if isinstance(credentials, OAuth2Credentials):
return credentials.access_token.get_secret_value()
if isinstance(credentials, APIKeyCredentials) and credentials.api_key:
token_value = credentials.api_key.get_secret_value()
# Skip placeholder/fake tokens
if token_value and token_value not in ("", "FAKE_API_KEY", "placeholder"):
return token_value
return None
async def run(
self,
input_data: Input,
*,
credentials: APIKeyCredentials,
credentials: MCPCredentials,
**kwargs,
) -> BlockOutput:
if not input_data.server_url:
@@ -232,12 +249,7 @@ class MCPToolBlock(Block):
yield "error", "No tool selected. Please select a tool from the dropdown."
return
auth_token: str | None = None
if credentials and credentials.api_key:
token_value = credentials.api_key.get_secret_value()
# Skip placeholder/fake tokens
if token_value and token_value not in ("", "FAKE_API_KEY", "placeholder"):
auth_token = token_value
auth_token = self._extract_auth_token(credentials)
try:
result = await self._call_mcp_tool(

View File

@@ -4,9 +4,12 @@ MCP (Model Context Protocol) HTTP client.
Implements the MCP Streamable HTTP transport for listing tools and calling tools
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
Reference: https://modelcontextprotocol.io/docs/concepts/transports
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any
@@ -83,10 +86,43 @@ class MCPClient:
req["params"] = params
return req
@staticmethod
def _parse_sse_response(text: str) -> dict[str, Any]:
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
MCP servers may return responses as SSE with format:
event: message
data: {"jsonrpc":"2.0","result":{...},"id":1}
We extract the last `data:` line that contains a JSON-RPC response
(i.e. has an "id" field), which is the reply to our request.
"""
last_data: dict[str, Any] | None = None
for line in text.splitlines():
stripped = line.strip()
if stripped.startswith("data:"):
payload = stripped[len("data:") :].strip()
if not payload:
continue
try:
parsed = json.loads(payload)
# Only keep JSON-RPC responses (have "id"), skip notifications
if isinstance(parsed, dict) and "id" in parsed:
last_data = parsed
except (json.JSONDecodeError, ValueError):
continue
if last_data is None:
raise MCPClientError("No JSON-RPC response found in SSE stream")
return last_data
async def _send_request(
self, method: str, params: dict[str, Any] | None = None
) -> Any:
"""Send a JSON-RPC request to the MCP server and return the result."""
"""Send a JSON-RPC request to the MCP server and return the result.
Handles both ``application/json`` and ``text/event-stream`` responses
as required by the MCP Streamable HTTP transport specification.
"""
payload = self._build_jsonrpc_request(method, params)
headers = self._build_headers()
@@ -96,7 +132,12 @@ class MCPClient:
trusted_origins=self.trusted_origins,
)
response = await requests.post(self.server_url, json=payload)
body = response.json()
content_type = response.headers.get("content-type", "")
if "text/event-stream" in content_type:
body = self._parse_sse_response(response.text())
else:
body = response.json()
# Handle JSON-RPC error
if "error" in body:
@@ -119,6 +160,90 @@ class MCPClient:
)
await requests.post(self.server_url, json=notification)
async def discover_auth(self) -> dict[str, Any] | None:
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
Returns ``None`` if the server doesn't require auth, otherwise returns
a dict with:
- ``authorization_servers``: list of authorization server URLs
- ``resource``: the resource indicator URL (usually the MCP endpoint)
- ``scopes_supported``: optional list of supported scopes
The caller can then fetch the authorization server metadata to get
``authorization_endpoint``, ``token_endpoint``, etc.
"""
from urllib.parse import urlparse
parsed = urlparse(self.server_url)
base = f"{parsed.scheme}://{parsed.netloc}"
# Build candidates for protected-resource metadata (per RFC 9728)
path = parsed.path.rstrip("/")
candidates = []
if path and path != "/":
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
candidates.append(f"{base}/.well-known/oauth-protected-resource")
requests = Requests(
raise_for_status=False,
trusted_origins=self.trusted_origins,
)
for url in candidates:
try:
resp = await requests.get(url)
if resp.status == 200:
data = resp.json()
if isinstance(data, dict) and "authorization_servers" in data:
return data
except Exception:
continue
return None
async def discover_auth_server_metadata(
self, auth_server_url: str
) -> dict[str, Any] | None:
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
Given an authorization server URL, returns a dict with:
- ``authorization_endpoint``
- ``token_endpoint``
- ``registration_endpoint`` (for dynamic client registration)
- ``scopes_supported``
- ``code_challenge_methods_supported``
- etc.
"""
from urllib.parse import urlparse
parsed = urlparse(auth_server_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/")
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
candidates = []
if path and path != "/":
candidates.append(
f"{base}/.well-known/oauth-authorization-server{path}"
)
candidates.append(f"{base}/.well-known/oauth-authorization-server")
candidates.append(f"{base}/.well-known/openid-configuration")
requests = Requests(
raise_for_status=False,
trusted_origins=self.trusted_origins,
)
for url in candidates:
try:
resp = await requests.get(url)
if resp.status == 200:
data = resp.json()
if isinstance(data, dict) and "authorization_endpoint" in data:
return data
except Exception:
continue
return None
async def initialize(self) -> dict[str, Any]:
"""
Send the MCP initialize request.

View File

@@ -0,0 +1,188 @@
"""
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
This handler accepts those endpoints at construction time.
"""
import logging
import time
import urllib.parse
from typing import ClassVar, Optional
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.providers import ProviderName
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.util.request import Requests
logger = logging.getLogger(__name__)
class MCPOAuthHandler(BaseOAuthHandler):
"""
OAuth handler for MCP servers with dynamically-discovered endpoints.
Construction requires the authorization and token endpoint URLs,
which are obtained via MCP OAuth metadata discovery
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
"""
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
DEFAULT_SCOPES: ClassVar[list[str]] = []
def __init__(
self,
client_id: str,
client_secret: str,
redirect_uri: str,
*,
authorize_url: str,
token_url: str,
revoke_url: str | None = None,
resource_url: str | None = None,
):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.authorize_url = authorize_url
self.token_url = token_url
self.revoke_url = revoke_url
self.resource_url = resource_url
def get_login_url(
self,
scopes: list[str],
state: str,
code_challenge: Optional[str],
) -> str:
scopes = self.handle_default_scopes(scopes)
params: dict[str, str] = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"state": state,
}
if scopes:
params["scope"] = " ".join(scopes)
# PKCE is required by the MCP spec (S256 only)
if code_challenge:
params["code_challenge"] = code_challenge
params["code_challenge_method"] = "S256"
# MCP spec requires resource indicator (RFC 8707)
if self.resource_url:
params["resource"] = self.resource_url
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
async def exchange_code_for_tokens(
self,
code: str,
scopes: list[str],
code_verifier: Optional[str],
) -> OAuth2Credentials:
data: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
if code_verifier:
data["code_verifier"] = code_verifier
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests().post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
now = int(time.time())
expires_in = tokens.get("expires_in")
return OAuth2Credentials(
provider=str(self.PROVIDER_NAME),
title=None,
access_token=SecretStr(tokens["access_token"]),
refresh_token=(
SecretStr(tokens["refresh_token"])
if tokens.get("refresh_token")
else None
),
access_token_expires_at=now + expires_in if expires_in else None,
refresh_token_expires_at=None,
scopes=scopes,
metadata={
"mcp_token_url": self.token_url,
"mcp_resource_url": self.resource_url,
},
)
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
if not credentials.refresh_token:
raise ValueError("No refresh token available for MCP OAuth credentials")
data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests().post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
now = int(time.time())
expires_in = tokens.get("expires_in")
return OAuth2Credentials(
id=credentials.id,
provider=str(self.PROVIDER_NAME),
title=credentials.title,
access_token=SecretStr(tokens["access_token"]),
refresh_token=(
SecretStr(str(tokens["refresh_token"]))
if tokens.get("refresh_token")
else credentials.refresh_token
),
access_token_expires_at=now + expires_in if expires_in else None,
refresh_token_expires_at=credentials.refresh_token_expires_at,
scopes=credentials.scopes,
metadata=credentials.metadata,
)
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
if not self.revoke_url:
return False
try:
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
"client_id": self.client_id,
}
await Requests().post(
self.revoke_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
return True
except Exception:
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
return False

View File

@@ -0,0 +1,104 @@
"""
End-to-end tests against a real public MCP server.
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
which is publicly accessible without authentication and returns SSE responses.
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
independently of the rest of the test suite (they require network access).
"""
import json
import pytest
from backend.blocks.mcp.client import MCPClient
# Public MCP server that requires no authentication
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
@pytest.mark.e2e
class TestRealMCPServer:
"""Tests against the live OpenAI docs MCP server."""
@pytest.mark.asyncio
async def test_initialize(self):
"""Verify we can complete the MCP handshake with a real server."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
assert "serverInfo" in result
assert result["serverInfo"]["name"] == "openai-docs-mcp"
assert "tools" in result.get("capabilities", {})
@pytest.mark.asyncio
async def test_list_tools(self):
"""Verify we can discover tools from a real MCP server."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
tools = await client.list_tools()
assert len(tools) >= 3 # server has at least 5 tools as of writing
tool_names = {t.name for t in tools}
# These tools are documented and should be stable
assert "search_openai_docs" in tool_names
assert "list_openai_docs" in tool_names
assert "fetch_openai_doc" in tool_names
# Verify schema structure
search_tool = next(t for t in tools if t.name == "search_openai_docs")
assert "query" in search_tool.input_schema.get("properties", {})
assert "query" in search_tool.input_schema.get("required", [])
@pytest.mark.asyncio
async def test_call_tool_list_api_endpoints(self):
"""Call the list_api_endpoints tool and verify we get real data."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
result = await client.call_tool("list_api_endpoints", {})
assert not result.is_error
assert len(result.content) >= 1
assert result.content[0]["type"] == "text"
data = json.loads(result.content[0]["text"])
assert "paths" in data or "urls" in data
# The OpenAI API should have many endpoints
total = data.get("total", len(data.get("paths", [])))
assert total > 50
@pytest.mark.asyncio
async def test_call_tool_search(self):
"""Search for docs and verify we get results."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
result = await client.call_tool(
"search_openai_docs", {"query": "chat completions", "limit": 3}
)
assert not result.is_error
assert len(result.content) >= 1
@pytest.mark.asyncio
async def test_sse_response_handling(self):
"""Verify the client correctly handles SSE responses from a real server.
This is the key test — our local test server returns JSON,
but real MCP servers typically return SSE. This proves the
SSE parsing works end-to-end.
"""
client = MCPClient(OPENAI_DOCS_MCP_URL)
# initialize() internally calls _send_request which must parse SSE
result = await client.initialize()
# If we got here without error, SSE parsing works
assert isinstance(result, dict)
assert "protocolVersion" in result
# Also verify list_tools works (another SSE response)
tools = await client.list_tools()
assert len(tools) > 0
assert all(hasattr(t, "name") for t in tools)

View File

@@ -7,8 +7,80 @@ from unittest.mock import AsyncMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.mcp.block import MCPToolBlock, TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.data.model import APIKeyCredentials, OAuth2Credentials
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError, MCPTool
# ── SSE parsing unit tests ───────────────────────────────────────────
class TestSSEParsing:
"""Tests for SSE (text/event-stream) response parsing."""
def test_parse_sse_simple(self):
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == {"tools": []}
assert body["id"] == 1
def test_parse_sse_with_notifications(self):
"""SSE streams can contain notifications (no id) before the response."""
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
"\n"
"event: message\n"
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == {"ok": True}
assert body["id"] == 2
def test_parse_sse_error_response(self):
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
)
body = MCPClient._parse_sse_response(sse)
assert "error" in body
assert body["error"]["code"] == -32600
def test_parse_sse_no_data_raises(self):
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
MCPClient._parse_sse_response("event: message\n\n")
def test_parse_sse_empty_raises(self):
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
MCPClient._parse_sse_response("")
def test_parse_sse_ignores_non_data_lines(self):
sse = (
": comment line\n"
"event: message\n"
"id: 123\n"
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == "ok"
def test_parse_sse_uses_last_response(self):
"""If multiple responses exist, use the last one."""
sse = (
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
"\n"
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == "second"
from backend.util.test import execute_block_test
@@ -505,9 +577,6 @@ class TestMCPToolBlock:
@pytest.mark.asyncio
async def test_run_skips_placeholder_credentials(self):
"""Ensure placeholder API keys are not sent to the MCP server."""
from backend.data.model import APIKeyCredentials
from pydantic import SecretStr
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
@@ -534,3 +603,74 @@ class TestMCPToolBlock:
pass
assert captured_tokens == [None]
# ── OAuth2 credential support tests ─────────────────────────────────
class TestMCPOAuth2Support:
"""Tests for OAuth2 credential support in MCPToolBlock."""
def test_extract_auth_token_from_api_key(self):
creds = APIKeyCredentials(
id="test", provider="mcp",
api_key=SecretStr("my-api-key"), title="test",
)
token = MCPToolBlock._extract_auth_token(creds)
assert token == "my-api-key"
def test_extract_auth_token_from_oauth2(self):
creds = OAuth2Credentials(
id="test", provider="mcp",
access_token=SecretStr("oauth2-access-token"),
scopes=["read"], title="test",
)
token = MCPToolBlock._extract_auth_token(creds)
assert token == "oauth2-access-token"
def test_extract_auth_token_placeholder_skipped(self):
creds = APIKeyCredentials(
id="test", provider="mcp",
api_key=SecretStr("FAKE_API_KEY"), title="test",
)
token = MCPToolBlock._extract_auth_token(creds)
assert token is None
def test_extract_auth_token_empty_skipped(self):
creds = APIKeyCredentials(
id="test", provider="mcp",
api_key=SecretStr(""), title="test",
)
token = MCPToolBlock._extract_auth_token(creds)
assert token is None
@pytest.mark.asyncio
async def test_run_with_oauth2_credentials(self):
"""Verify the block can run with OAuth2 credentials."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="test_tool",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
oauth2_creds = OAuth2Credentials(
id="test-id", provider="mcp",
access_token=SecretStr("real-oauth2-token"),
scopes=["read", "write"], title="MCP OAuth",
)
captured_tokens = []
async def mock_call(server_url, tool_name, arguments, auth_token=None):
captured_tokens.append(auth_token)
return {"status": "ok"}
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, credentials=oauth2_creds):
outputs.append((name, data))
assert captured_tokens == ["real-oauth2-token"]
assert outputs == [("result", {"status": "ok"})]

View File

@@ -0,0 +1,252 @@
"""
Tests for MCP OAuth handler.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
resp = MagicMock()
resp.status = status
resp.ok = 200 <= status < 300
resp.json.return_value = json_data
return resp
class TestMCPOAuthHandler:
"""Tests for the MCPOAuthHandler."""
def _make_handler(self, **overrides) -> MCPOAuthHandler:
defaults = {
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"redirect_uri": "https://app.example.com/callback",
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
}
defaults.update(overrides)
return MCPOAuthHandler(**defaults)
def test_get_login_url_basic(self):
handler = self._make_handler()
url = handler.get_login_url(
scopes=["read", "write"],
state="random-state-token",
code_challenge="S256-challenge-value",
)
assert "https://auth.example.com/authorize?" in url
assert "response_type=code" in url
assert "client_id=test-client-id" in url
assert "state=random-state-token" in url
assert "code_challenge=S256-challenge-value" in url
assert "code_challenge_method=S256" in url
assert "scope=read+write" in url
def test_get_login_url_with_resource(self):
handler = self._make_handler(
resource_url="https://mcp.example.com/mcp"
)
url = handler.get_login_url(
scopes=[], state="state", code_challenge="challenge"
)
assert "resource=https" in url
def test_get_login_url_without_pkce(self):
handler = self._make_handler()
url = handler.get_login_url(
scopes=["read"], state="state", code_challenge=None
)
assert "code_challenge" not in url
assert "code_challenge_method" not in url
@pytest.mark.asyncio
async def test_exchange_code_for_tokens(self):
handler = self._make_handler()
resp = _mock_response({
"access_token": "new-access-token",
"refresh_token": "new-refresh-token",
"expires_in": 3600,
"token_type": "Bearer",
})
with patch(
"backend.blocks.mcp.oauth.Requests"
) as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
creds = await handler.exchange_code_for_tokens(
code="auth-code",
scopes=["read"],
code_verifier="pkce-verifier",
)
assert isinstance(creds, OAuth2Credentials)
assert creds.access_token.get_secret_value() == "new-access-token"
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
assert creds.scopes == ["read"]
assert creds.access_token_expires_at is not None
@pytest.mark.asyncio
async def test_refresh_tokens(self):
handler = self._make_handler()
existing_creds = OAuth2Credentials(
id="existing-id",
provider="mcp",
access_token=SecretStr("old-token"),
refresh_token=SecretStr("old-refresh"),
scopes=["read"],
title="test",
)
resp = _mock_response({
"access_token": "refreshed-token",
"refresh_token": "new-refresh",
"expires_in": 3600,
})
with patch(
"backend.blocks.mcp.oauth.Requests"
) as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
refreshed = await handler._refresh_tokens(existing_creds)
assert refreshed.id == "existing-id"
assert refreshed.access_token.get_secret_value() == "refreshed-token"
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
@pytest.mark.asyncio
async def test_refresh_tokens_no_refresh_token(self):
handler = self._make_handler()
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=["read"],
title="test",
)
with pytest.raises(ValueError, match="No refresh token"):
await handler._refresh_tokens(creds)
@pytest.mark.asyncio
async def test_revoke_tokens_no_url(self):
handler = self._make_handler(revoke_url=None)
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=[], title="test",
)
result = await handler.revoke_tokens(creds)
assert result is False
@pytest.mark.asyncio
async def test_revoke_tokens_with_url(self):
handler = self._make_handler(
revoke_url="https://auth.example.com/revoke"
)
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=[], title="test",
)
resp = _mock_response({}, status=200)
with patch(
"backend.blocks.mcp.oauth.Requests"
) as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
result = await handler.revoke_tokens(creds)
assert result is True
class TestMCPClientDiscovery:
"""Tests for MCPClient OAuth metadata discovery."""
@pytest.mark.asyncio
async def test_discover_auth_found(self):
client = MCPClient("https://mcp.example.com/mcp")
metadata = {
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
resp = _mock_response(metadata, status=200)
with patch(
"backend.blocks.mcp.client.Requests"
) as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth()
assert result is not None
assert result["authorization_servers"] == ["https://auth.example.com"]
@pytest.mark.asyncio
async def test_discover_auth_not_found(self):
client = MCPClient("https://mcp.example.com/mcp")
resp = _mock_response({}, status=404)
with patch(
"backend.blocks.mcp.client.Requests"
) as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth()
assert result is None
@pytest.mark.asyncio
async def test_discover_auth_server_metadata(self):
client = MCPClient("https://mcp.example.com/mcp")
server_metadata = {
"issuer": "https://auth.example.com",
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"registration_endpoint": "https://auth.example.com/register",
"code_challenge_methods_supported": ["S256"],
}
resp = _mock_response(server_metadata, status=200)
with patch(
"backend.blocks.mcp.client.Requests"
) as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth_server_metadata(
"https://auth.example.com"
)
assert result is not None
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
assert result["token_endpoint"] == "https://auth.example.com/token"