fix(mcp): OAuth discovery fallback, session ID, credential lookup, and DNS reliability

- Support MCP servers that serve OAuth metadata directly without
  protected-resource metadata (e.g. Linear) by falling back to
  discover_auth_server_metadata on the server's own origin
- Omit resource_url when no protected-resource metadata exists to
  avoid token audience mismatch errors (RFC 8707 resource is optional)
- Add Mcp-Session-Id header tracking per MCP Streamable HTTP spec
- Fall back to server_url credential lookup when credential_id is
  empty (pruneEmptyValues strips it from saved graphs)
- Use ThreadedResolver instead of c-ares AsyncResolver to avoid DNS
  failures in forked subprocess environments
- Simplify OAuth UX: single "Sign in & Connect" button on 401,
  remove sticky localStorage URL prefill
- Clean up stale MCP credentials on re-authentication
This commit is contained in:
Zamil Majdy
2026-02-09 18:51:53 +04:00
parent 6c2791b00b
commit 340520ba85
6 changed files with 150 additions and 85 deletions

View File

@@ -81,13 +81,24 @@ async def discover_tools(
mcp_creds = await creds_manager.store.get_creds_by_provider(
user_id, str(ProviderName.MCP)
)
# Find the freshest credential for this server URL
best_cred: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and cred.metadata.get("mcp_server_url") == request.server_url
):
auth_token = cred.access_token.get_secret_value()
break
if best_cred is None or (
(cred.access_token_expires_at or 0)
> (best_cred.access_token_expires_at or 0)
):
best_cred = cred
if best_cred:
logger.info(
f"Using MCP credential {best_cred.id} for {request.server_url}, "
f"expires_at={best_cred.access_token_expires_at}"
)
auth_token = best_cred.access_token.get_secret_value()
except Exception:
logger.debug("Could not look up stored MCP credentials", exc_info=True)
@@ -115,6 +126,9 @@ async def discover_tools(
)
except HTTPClientError as e:
if e.status_code in (401, 403):
logger.warning(
f"MCP server returned {e.status_code} for {request.server_url}: {e}"
)
raise fastapi.HTTPException(
status_code=401,
detail="This MCP server requires authentication. "
@@ -174,31 +188,38 @@ async def mcp_oauth_login(
detail=f"Failed to discover OAuth metadata: {e}",
)
if not protected_resource or "authorization_servers" not in protected_resource:
metadata: dict[str, Any] | None = None
if protected_resource and "authorization_servers" in protected_resource:
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", request.server_url)
# Step 2a: Discover auth-server metadata (RFC 8414)
try:
metadata = await client.discover_auth_server_metadata(auth_server_url)
except Exception as e:
raise fastapi.HTTPException(
status_code=502,
detail=f"Failed to discover authorization server metadata: {e}",
)
else:
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
# and serve OAuth metadata directly without protected-resource metadata.
# Don't assume a resource_url — omitting it lets the auth server choose
# the correct audience for the token (RFC 8707 resource is optional).
resource_url = None
try:
metadata = await client.discover_auth_server_metadata(request.server_url)
except Exception:
pass
if not metadata or "authorization_endpoint" not in metadata:
raise fastapi.HTTPException(
status_code=400,
detail="This MCP server does not advertise OAuth support. "
"You may need to provide an auth token manually.",
)
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", request.server_url)
# Step 2: Discover auth-server metadata (RFC 8414)
try:
metadata = await client.discover_auth_server_metadata(auth_server_url)
except Exception as e:
raise fastapi.HTTPException(
status_code=502,
detail=f"Failed to discover authorization server metadata: {e}",
)
if not metadata or "authorization_endpoint" not in metadata:
raise fastapi.HTTPException(
status_code=502,
detail="Authorization server metadata is missing required endpoints.",
)
authorize_url = metadata["authorization_endpoint"]
token_url = metadata["token_endpoint"]
registration_endpoint = metadata.get("registration_endpoint")
@@ -227,7 +248,10 @@ async def mcp_oauth_login(
client_id = "autogpt-platform"
# Step 4: Store state token with OAuth metadata for the callback
scopes = protected_resource.get("scopes_supported", [])
scopes = (
(protected_resource or {}).get("scopes_supported")
or metadata.get("scopes_supported", [])
)
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id,
str(ProviderName.MCP),
@@ -327,10 +351,27 @@ async def mcp_oauth_callback(
credentials.metadata["mcp_server_url"] = meta["server_url"]
credentials.metadata["mcp_client_id"] = meta["client_id"]
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
credentials.metadata["mcp_token_url"] = meta["token_url"]
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
credentials.title = f"MCP: {hostname}"
# Remove old MCP credentials for the same server to prevent stale token buildup
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, str(ProviderName.MCP)
)
for old in old_creds:
if (
isinstance(old, OAuth2Credentials)
and old.metadata.get("mcp_server_url") == meta["server_url"]
):
await creds_manager.store.delete_creds_by_id(user_id, old.id)
logger.info(f"Removed old MCP credential {old.id} for {meta['server_url']}")
except Exception:
logger.debug("Could not clean up old MCP credentials", exc_info=True)
await creds_manager.create(user_id, credentials)
return MCPOAuthCallbackResponse(credential_id=credentials.id)

View File

@@ -193,17 +193,50 @@ class MCPToolBlock(Block):
return output_parts[0]
return output_parts if output_parts else None
async def _resolve_auth_token(self, credential_id: str, user_id: str) -> str | None:
"""Resolve a Bearer token from a stored credential ID, refreshing if needed."""
if not credential_id:
return None
async def _resolve_auth_token(
self, credential_id: str, user_id: str, server_url: str = ""
) -> str | None:
"""Resolve a Bearer token from a stored credential ID, refreshing if needed.
Falls back to looking up credentials by server_url when credential_id
is empty (e.g. when pruneEmptyValues strips it from the saved graph).
"""
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.providers import ProviderName
store = IntegrationCredentialsStore()
creds = await store.get_creds_by_id(user_id, credential_id)
creds = None
if credential_id:
creds = await store.get_creds_by_id(user_id, credential_id)
# Fallback: look up by server_url (same approach as discover-tools)
if not creds and server_url:
logger.info(
f"credential_id not available, looking up credential by server_url"
)
try:
mcp_creds = await store.get_creds_by_provider(
user_id, str(ProviderName.MCP)
)
best: OAuth2Credentials | None = None
for c in mcp_creds:
if (
isinstance(c, OAuth2Credentials)
and c.metadata.get("mcp_server_url") == server_url
):
if best is None or (
(c.access_token_expires_at or 0)
> (best.access_token_expires_at or 0)
):
best = c
creds = best
except Exception:
logger.debug("Could not look up MCP credentials by server_url", exc_info=True)
if not creds:
logger.warning(f"Credential {credential_id} not found")
return None
if isinstance(creds, OAuth2Credentials):
# Refresh if token expires within 5 minutes
if (
@@ -266,7 +299,9 @@ class MCPToolBlock(Block):
yield "error", "No tool selected. Please select a tool from the dropdown."
return
auth_token = await self._resolve_auth_token(input_data.credential_id, user_id)
auth_token = await self._resolve_auth_token(
input_data.credential_id, user_id, server_url=input_data.server_url
)
try:
result = await self._call_mcp_tool(

View File

@@ -60,6 +60,7 @@ class MCPClient:
self.auth_token = auth_token
self.trusted_origins = trusted_origins or []
self._request_id = 0
self._session_id: str | None = None
def _next_id(self) -> int:
self._request_id += 1
@@ -72,6 +73,8 @@ class MCPClient:
}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
if self._session_id:
headers["Mcp-Session-Id"] = self._session_id
return headers
def _build_jsonrpc_request(
@@ -133,6 +136,11 @@ class MCPClient:
)
response = await requests.post(self.server_url, json=payload)
# Capture session ID from response (MCP Streamable HTTP transport)
session_id = response.headers.get("Mcp-Session-Id")
if session_id:
self._session_id = session_id
content_type = response.headers.get("content-type", "")
if "text/event-stream" in content_type:
body = self._parse_sse_response(response.text())

View File

@@ -573,7 +573,7 @@ class TestMCPToolBlock:
captured_tokens.append(auth_token)
return "ok"
async def mock_resolve(self, cred_id, uid):
async def mock_resolve(self, cred_id, uid, server_url=""):
return "resolved-token"
block._call_mcp_tool = mock_call # type: ignore

View File

@@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver):
def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
self.ssl_hostname = ssl_hostname
self.ip_addresses = ip_addresses
self._default = aiohttp.AsyncResolver()
self._default = aiohttp.ThreadedResolver()
async def resolve(self, host, port=0, family=socket.AF_INET):
if host == self.ssl_hostname:
@@ -467,7 +467,13 @@ class Requests:
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
ssl_context = ssl.create_default_context()
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
session_kwargs = {}
else:
# Use ThreadedResolver for trusted origins to avoid c-ares DNS issues
# in subprocess environments (e.g. ExecutionManager on macOS).
connector = aiohttp.TCPConnector(
resolver=aiohttp.ThreadedResolver()
)
session_kwargs: dict = {}
if connector:
session_kwargs["connector"] = connector

View File

@@ -38,7 +38,7 @@ interface MCPToolDialogProps {
type DialogStep = "url" | "tool";
const OAUTH_TIMEOUT_MS = 5 * 60 * 1000; // 5 minutes
const STORAGE_KEY = "mcp_last_server_url";
export function MCPToolDialog({
open,
@@ -71,23 +71,7 @@ export function MCPToolDialog({
);
const popupCheckRef = useRef<ReturnType<typeof setInterval> | null>(null);
const oauthHandledRef = useRef(false);
const autoConnectAttemptedRef = useRef(false);
// Pre-fill last used server URL when dialog opens (without auto-connecting)
useEffect(() => {
if (!open) {
autoConnectAttemptedRef.current = false;
return;
}
if (autoConnectAttemptedRef.current) return;
autoConnectAttemptedRef.current = true;
const lastUrl = localStorage.getItem(STORAGE_KEY);
if (lastUrl) {
setServerUrl(lastUrl);
}
}, [open]);
// (no auto-prefill — dialog starts fresh each time)
// Clean up listeners on unmount
useEffect(() => {
@@ -160,7 +144,6 @@ export function MCPToolDialog({
setError(null);
try {
const result = await api.mcpDiscoverTools(url, authToken);
localStorage.setItem(STORAGE_KEY, url);
setTools(result.tools);
setServerName(result.server_name);
setAuthRequired(false);
@@ -218,12 +201,19 @@ export function MCPToolDialog({
);
setCredentialId(callbackResult.credential_id);
const result = await api.mcpDiscoverTools(serverUrl.trim());
localStorage.setItem(STORAGE_KEY, serverUrl.trim());
setTools(result.tools);
setServerName(result.server_name);
setStep("tool");
} catch (e: any) {
const message = e?.message || e?.detail || "Failed to complete sign-in";
const status = e?.status;
let message: string;
if (status === 401 || status === 403) {
message =
"Authentication succeeded but the server still rejected the request. " +
"The token audience may not match. Please try again.";
} else {
message = e?.message || e?.detail || "Failed to complete sign-in";
}
setError(
typeof message === "string" ? message : JSON.stringify(message),
);
@@ -410,35 +400,14 @@ export function MCPToolDialog({
/>
</div>
{/* Auth required: show sign-in panel */}
{authRequired && (
<div className="flex flex-col items-center gap-3 rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-slate-700 dark:bg-slate-800">
<p className="text-sm font-medium text-gray-600 dark:text-gray-300">
This server requires authentication
</p>
<Button
onClick={handleOAuthSignIn}
disabled={oauthLoading || loading}
className="w-full"
>
{oauthLoading ? (
<span className="flex items-center gap-2">
<LoadingSpinner className="size-4" />
Waiting for sign-in...
</span>
) : (
"Sign in"
)}
</Button>
{!showManualToken && (
<button
onClick={() => setShowManualToken(true)}
className="text-xs text-gray-500 underline hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-300"
>
or enter a token manually
</button>
)}
</div>
{/* Auth required: show manual token option */}
{authRequired && !showManualToken && (
<button
onClick={() => setShowManualToken(true)}
className="text-xs text-gray-500 underline hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-300"
>
or enter a token manually
</button>
)}
{/* Manual token entry — only visible when expanded */}
@@ -495,14 +464,20 @@ export function MCPToolDialog({
</Button>
{step === "url" && (
<Button
onClick={handleDiscoverTools}
onClick={
authRequired && !showManualToken
? handleOAuthSignIn
: handleDiscoverTools
}
disabled={!serverUrl.trim() || loading || oauthLoading}
>
{loading ? (
{loading || oauthLoading ? (
<span className="flex items-center gap-2">
<LoadingSpinner className="size-4" />
Connecting...
{oauthLoading ? "Waiting for sign-in..." : "Connecting..."}
</span>
) : authRequired && !showManualToken ? (
"Sign in & Connect"
) : (
"Discover Tools"
)}