mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-09 14:25:25 -05:00
fix(mcp): OAuth discovery fallback, session ID, credential lookup, and DNS reliability
- Support MCP servers that serve OAuth metadata directly without protected-resource metadata (e.g. Linear) by falling back to discover_auth_server_metadata on the server's own origin - Omit resource_url when no protected-resource metadata exists to avoid token audience mismatch errors (RFC 8707 resource is optional) - Add Mcp-Session-Id header tracking per MCP Streamable HTTP spec - Fall back to server_url credential lookup when credential_id is empty (pruneEmptyValues strips it from saved graphs) - Use ThreadedResolver instead of c-ares AsyncResolver to avoid DNS failures in forked subprocess environments - Simplify OAuth UX: single "Sign in & Connect" button on 401, remove sticky localStorage URL prefill - Clean up stale MCP credentials on re-authentication
This commit is contained in:
@@ -81,13 +81,24 @@ async def discover_tools(
|
||||
mcp_creds = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, str(ProviderName.MCP)
|
||||
)
|
||||
# Find the freshest credential for this server URL
|
||||
best_cred: OAuth2Credentials | None = None
|
||||
for cred in mcp_creds:
|
||||
if (
|
||||
isinstance(cred, OAuth2Credentials)
|
||||
and cred.metadata.get("mcp_server_url") == request.server_url
|
||||
):
|
||||
auth_token = cred.access_token.get_secret_value()
|
||||
break
|
||||
if best_cred is None or (
|
||||
(cred.access_token_expires_at or 0)
|
||||
> (best_cred.access_token_expires_at or 0)
|
||||
):
|
||||
best_cred = cred
|
||||
if best_cred:
|
||||
logger.info(
|
||||
f"Using MCP credential {best_cred.id} for {request.server_url}, "
|
||||
f"expires_at={best_cred.access_token_expires_at}"
|
||||
)
|
||||
auth_token = best_cred.access_token.get_secret_value()
|
||||
except Exception:
|
||||
logger.debug("Could not look up stored MCP credentials", exc_info=True)
|
||||
|
||||
@@ -115,6 +126,9 @@ async def discover_tools(
|
||||
)
|
||||
except HTTPClientError as e:
|
||||
if e.status_code in (401, 403):
|
||||
logger.warning(
|
||||
f"MCP server returned {e.status_code} for {request.server_url}: {e}"
|
||||
)
|
||||
raise fastapi.HTTPException(
|
||||
status_code=401,
|
||||
detail="This MCP server requires authentication. "
|
||||
@@ -174,31 +188,38 @@ async def mcp_oauth_login(
|
||||
detail=f"Failed to discover OAuth metadata: {e}",
|
||||
)
|
||||
|
||||
if not protected_resource or "authorization_servers" not in protected_resource:
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
if protected_resource and "authorization_servers" in protected_resource:
|
||||
auth_server_url = protected_resource["authorization_servers"][0]
|
||||
resource_url = protected_resource.get("resource", request.server_url)
|
||||
|
||||
# Step 2a: Discover auth-server metadata (RFC 8414)
|
||||
try:
|
||||
metadata = await client.discover_auth_server_metadata(auth_server_url)
|
||||
except Exception as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Failed to discover authorization server metadata: {e}",
|
||||
)
|
||||
else:
|
||||
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
|
||||
# and serve OAuth metadata directly without protected-resource metadata.
|
||||
# Don't assume a resource_url — omitting it lets the auth server choose
|
||||
# the correct audience for the token (RFC 8707 resource is optional).
|
||||
resource_url = None
|
||||
try:
|
||||
metadata = await client.discover_auth_server_metadata(request.server_url)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not metadata or "authorization_endpoint" not in metadata:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=400,
|
||||
detail="This MCP server does not advertise OAuth support. "
|
||||
"You may need to provide an auth token manually.",
|
||||
)
|
||||
|
||||
auth_server_url = protected_resource["authorization_servers"][0]
|
||||
resource_url = protected_resource.get("resource", request.server_url)
|
||||
|
||||
# Step 2: Discover auth-server metadata (RFC 8414)
|
||||
try:
|
||||
metadata = await client.discover_auth_server_metadata(auth_server_url)
|
||||
except Exception as e:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Failed to discover authorization server metadata: {e}",
|
||||
)
|
||||
|
||||
if not metadata or "authorization_endpoint" not in metadata:
|
||||
raise fastapi.HTTPException(
|
||||
status_code=502,
|
||||
detail="Authorization server metadata is missing required endpoints.",
|
||||
)
|
||||
|
||||
authorize_url = metadata["authorization_endpoint"]
|
||||
token_url = metadata["token_endpoint"]
|
||||
registration_endpoint = metadata.get("registration_endpoint")
|
||||
@@ -227,7 +248,10 @@ async def mcp_oauth_login(
|
||||
client_id = "autogpt-platform"
|
||||
|
||||
# Step 4: Store state token with OAuth metadata for the callback
|
||||
scopes = protected_resource.get("scopes_supported", [])
|
||||
scopes = (
|
||||
(protected_resource or {}).get("scopes_supported")
|
||||
or metadata.get("scopes_supported", [])
|
||||
)
|
||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
||||
user_id,
|
||||
str(ProviderName.MCP),
|
||||
@@ -327,10 +351,27 @@ async def mcp_oauth_callback(
|
||||
credentials.metadata["mcp_server_url"] = meta["server_url"]
|
||||
credentials.metadata["mcp_client_id"] = meta["client_id"]
|
||||
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
|
||||
credentials.metadata["mcp_token_url"] = meta["token_url"]
|
||||
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
|
||||
|
||||
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
|
||||
credentials.title = f"MCP: {hostname}"
|
||||
|
||||
# Remove old MCP credentials for the same server to prevent stale token buildup
|
||||
try:
|
||||
old_creds = await creds_manager.store.get_creds_by_provider(
|
||||
user_id, str(ProviderName.MCP)
|
||||
)
|
||||
for old in old_creds:
|
||||
if (
|
||||
isinstance(old, OAuth2Credentials)
|
||||
and old.metadata.get("mcp_server_url") == meta["server_url"]
|
||||
):
|
||||
await creds_manager.store.delete_creds_by_id(user_id, old.id)
|
||||
logger.info(f"Removed old MCP credential {old.id} for {meta['server_url']}")
|
||||
except Exception:
|
||||
logger.debug("Could not clean up old MCP credentials", exc_info=True)
|
||||
|
||||
await creds_manager.create(user_id, credentials)
|
||||
|
||||
return MCPOAuthCallbackResponse(credential_id=credentials.id)
|
||||
|
||||
@@ -193,17 +193,50 @@ class MCPToolBlock(Block):
|
||||
return output_parts[0]
|
||||
return output_parts if output_parts else None
|
||||
|
||||
async def _resolve_auth_token(self, credential_id: str, user_id: str) -> str | None:
|
||||
"""Resolve a Bearer token from a stored credential ID, refreshing if needed."""
|
||||
if not credential_id:
|
||||
return None
|
||||
async def _resolve_auth_token(
|
||||
self, credential_id: str, user_id: str, server_url: str = ""
|
||||
) -> str | None:
|
||||
"""Resolve a Bearer token from a stored credential ID, refreshing if needed.
|
||||
|
||||
Falls back to looking up credentials by server_url when credential_id
|
||||
is empty (e.g. when pruneEmptyValues strips it from the saved graph).
|
||||
"""
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
store = IntegrationCredentialsStore()
|
||||
creds = await store.get_creds_by_id(user_id, credential_id)
|
||||
creds = None
|
||||
|
||||
if credential_id:
|
||||
creds = await store.get_creds_by_id(user_id, credential_id)
|
||||
|
||||
# Fallback: look up by server_url (same approach as discover-tools)
|
||||
if not creds and server_url:
|
||||
logger.info(
|
||||
f"credential_id not available, looking up credential by server_url"
|
||||
)
|
||||
try:
|
||||
mcp_creds = await store.get_creds_by_provider(
|
||||
user_id, str(ProviderName.MCP)
|
||||
)
|
||||
best: OAuth2Credentials | None = None
|
||||
for c in mcp_creds:
|
||||
if (
|
||||
isinstance(c, OAuth2Credentials)
|
||||
and c.metadata.get("mcp_server_url") == server_url
|
||||
):
|
||||
if best is None or (
|
||||
(c.access_token_expires_at or 0)
|
||||
> (best.access_token_expires_at or 0)
|
||||
):
|
||||
best = c
|
||||
creds = best
|
||||
except Exception:
|
||||
logger.debug("Could not look up MCP credentials by server_url", exc_info=True)
|
||||
|
||||
if not creds:
|
||||
logger.warning(f"Credential {credential_id} not found")
|
||||
return None
|
||||
|
||||
if isinstance(creds, OAuth2Credentials):
|
||||
# Refresh if token expires within 5 minutes
|
||||
if (
|
||||
@@ -266,7 +299,9 @@ class MCPToolBlock(Block):
|
||||
yield "error", "No tool selected. Please select a tool from the dropdown."
|
||||
return
|
||||
|
||||
auth_token = await self._resolve_auth_token(input_data.credential_id, user_id)
|
||||
auth_token = await self._resolve_auth_token(
|
||||
input_data.credential_id, user_id, server_url=input_data.server_url
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self._call_mcp_tool(
|
||||
|
||||
@@ -60,6 +60,7 @@ class MCPClient:
|
||||
self.auth_token = auth_token
|
||||
self.trusted_origins = trusted_origins or []
|
||||
self._request_id = 0
|
||||
self._session_id: str | None = None
|
||||
|
||||
def _next_id(self) -> int:
|
||||
self._request_id += 1
|
||||
@@ -72,6 +73,8 @@ class MCPClient:
|
||||
}
|
||||
if self.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||
if self._session_id:
|
||||
headers["Mcp-Session-Id"] = self._session_id
|
||||
return headers
|
||||
|
||||
def _build_jsonrpc_request(
|
||||
@@ -133,6 +136,11 @@ class MCPClient:
|
||||
)
|
||||
response = await requests.post(self.server_url, json=payload)
|
||||
|
||||
# Capture session ID from response (MCP Streamable HTTP transport)
|
||||
session_id = response.headers.get("Mcp-Session-Id")
|
||||
if session_id:
|
||||
self._session_id = session_id
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "text/event-stream" in content_type:
|
||||
body = self._parse_sse_response(response.text())
|
||||
|
||||
@@ -573,7 +573,7 @@ class TestMCPToolBlock:
|
||||
captured_tokens.append(auth_token)
|
||||
return "ok"
|
||||
|
||||
async def mock_resolve(self, cred_id, uid):
|
||||
async def mock_resolve(self, cred_id, uid, server_url=""):
|
||||
return "resolved-token"
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
@@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver):
|
||||
def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
|
||||
self.ssl_hostname = ssl_hostname
|
||||
self.ip_addresses = ip_addresses
|
||||
self._default = aiohttp.AsyncResolver()
|
||||
self._default = aiohttp.ThreadedResolver()
|
||||
|
||||
async def resolve(self, host, port=0, family=socket.AF_INET):
|
||||
if host == self.ssl_hostname:
|
||||
@@ -467,7 +467,13 @@ class Requests:
|
||||
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
|
||||
ssl_context = ssl.create_default_context()
|
||||
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
|
||||
session_kwargs = {}
|
||||
else:
|
||||
# Use ThreadedResolver for trusted origins to avoid c-ares DNS issues
|
||||
# in subprocess environments (e.g. ExecutionManager on macOS).
|
||||
connector = aiohttp.TCPConnector(
|
||||
resolver=aiohttp.ThreadedResolver()
|
||||
)
|
||||
session_kwargs: dict = {}
|
||||
if connector:
|
||||
session_kwargs["connector"] = connector
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ interface MCPToolDialogProps {
|
||||
type DialogStep = "url" | "tool";
|
||||
|
||||
const OAUTH_TIMEOUT_MS = 5 * 60 * 1000; // 5 minutes
|
||||
const STORAGE_KEY = "mcp_last_server_url";
|
||||
|
||||
|
||||
export function MCPToolDialog({
|
||||
open,
|
||||
@@ -71,23 +71,7 @@ export function MCPToolDialog({
|
||||
);
|
||||
const popupCheckRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
const oauthHandledRef = useRef(false);
|
||||
const autoConnectAttemptedRef = useRef(false);
|
||||
|
||||
// Pre-fill last used server URL when dialog opens (without auto-connecting)
|
||||
useEffect(() => {
|
||||
if (!open) {
|
||||
autoConnectAttemptedRef.current = false;
|
||||
return;
|
||||
}
|
||||
|
||||
if (autoConnectAttemptedRef.current) return;
|
||||
autoConnectAttemptedRef.current = true;
|
||||
|
||||
const lastUrl = localStorage.getItem(STORAGE_KEY);
|
||||
if (lastUrl) {
|
||||
setServerUrl(lastUrl);
|
||||
}
|
||||
}, [open]);
|
||||
// (no auto-prefill — dialog starts fresh each time)
|
||||
|
||||
// Clean up listeners on unmount
|
||||
useEffect(() => {
|
||||
@@ -160,7 +144,6 @@ export function MCPToolDialog({
|
||||
setError(null);
|
||||
try {
|
||||
const result = await api.mcpDiscoverTools(url, authToken);
|
||||
localStorage.setItem(STORAGE_KEY, url);
|
||||
setTools(result.tools);
|
||||
setServerName(result.server_name);
|
||||
setAuthRequired(false);
|
||||
@@ -218,12 +201,19 @@ export function MCPToolDialog({
|
||||
);
|
||||
setCredentialId(callbackResult.credential_id);
|
||||
const result = await api.mcpDiscoverTools(serverUrl.trim());
|
||||
localStorage.setItem(STORAGE_KEY, serverUrl.trim());
|
||||
setTools(result.tools);
|
||||
setServerName(result.server_name);
|
||||
setStep("tool");
|
||||
} catch (e: any) {
|
||||
const message = e?.message || e?.detail || "Failed to complete sign-in";
|
||||
const status = e?.status;
|
||||
let message: string;
|
||||
if (status === 401 || status === 403) {
|
||||
message =
|
||||
"Authentication succeeded but the server still rejected the request. " +
|
||||
"The token audience may not match. Please try again.";
|
||||
} else {
|
||||
message = e?.message || e?.detail || "Failed to complete sign-in";
|
||||
}
|
||||
setError(
|
||||
typeof message === "string" ? message : JSON.stringify(message),
|
||||
);
|
||||
@@ -410,35 +400,14 @@ export function MCPToolDialog({
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Auth required: show sign-in panel */}
|
||||
{authRequired && (
|
||||
<div className="flex flex-col items-center gap-3 rounded-lg border border-gray-200 bg-gray-50 p-4 dark:border-slate-700 dark:bg-slate-800">
|
||||
<p className="text-sm font-medium text-gray-600 dark:text-gray-300">
|
||||
This server requires authentication
|
||||
</p>
|
||||
<Button
|
||||
onClick={handleOAuthSignIn}
|
||||
disabled={oauthLoading || loading}
|
||||
className="w-full"
|
||||
>
|
||||
{oauthLoading ? (
|
||||
<span className="flex items-center gap-2">
|
||||
<LoadingSpinner className="size-4" />
|
||||
Waiting for sign-in...
|
||||
</span>
|
||||
) : (
|
||||
"Sign in"
|
||||
)}
|
||||
</Button>
|
||||
{!showManualToken && (
|
||||
<button
|
||||
onClick={() => setShowManualToken(true)}
|
||||
className="text-xs text-gray-500 underline hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-300"
|
||||
>
|
||||
or enter a token manually
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
{/* Auth required: show manual token option */}
|
||||
{authRequired && !showManualToken && (
|
||||
<button
|
||||
onClick={() => setShowManualToken(true)}
|
||||
className="text-xs text-gray-500 underline hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-300"
|
||||
>
|
||||
or enter a token manually
|
||||
</button>
|
||||
)}
|
||||
|
||||
{/* Manual token entry — only visible when expanded */}
|
||||
@@ -495,14 +464,20 @@ export function MCPToolDialog({
|
||||
</Button>
|
||||
{step === "url" && (
|
||||
<Button
|
||||
onClick={handleDiscoverTools}
|
||||
onClick={
|
||||
authRequired && !showManualToken
|
||||
? handleOAuthSignIn
|
||||
: handleDiscoverTools
|
||||
}
|
||||
disabled={!serverUrl.trim() || loading || oauthLoading}
|
||||
>
|
||||
{loading ? (
|
||||
{loading || oauthLoading ? (
|
||||
<span className="flex items-center gap-2">
|
||||
<LoadingSpinner className="size-4" />
|
||||
Connecting...
|
||||
{oauthLoading ? "Waiting for sign-in..." : "Connecting..."}
|
||||
</span>
|
||||
) : authRequired && !showManualToken ? (
|
||||
"Sign in & Connect"
|
||||
) : (
|
||||
"Discover Tools"
|
||||
)}
|
||||
|
||||
Reference in New Issue
Block a user