mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-09 22:35:54 -05:00
refactor(mcp): Share OAuth popup logic and fix credential persistence
- Extract shared OAuth popup utility (oauth-popup.ts) used by both MCPToolDialog and useCredentialsInput, eliminating ~200 lines of duplicated BroadcastChannel/postMessage/localStorage listener code - Add mcpOAuthCallback to credentials provider so MCP credentials are added to the in-memory cache after OAuth (fixes credentials not appearing in the credential picker after OAuth via MCPToolDialog) - Fix oauth_test.py async fixtures missing loop_scope="session" - Add MCP token refresh handler in creds_manager for dynamic endpoints - Fix enum string representation in CredentialsFieldInfo.combine()
This commit is contained in:
@@ -68,7 +68,7 @@ async def test_user(server, test_user_id: str):
|
||||
await PrismaUser.prisma().delete(where={"id": test_user_id})
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest_asyncio.fixture(loop_scope="session")
|
||||
async def test_oauth_app(test_user: str):
|
||||
"""Create a test OAuth application in the database."""
|
||||
app_id = str(uuid.uuid4())
|
||||
@@ -123,7 +123,7 @@ def pkce_credentials() -> tuple[str, str]:
|
||||
return generate_pkce()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest_asyncio.fixture(loop_scope="session")
|
||||
async def client(server, test_user: str) -> AsyncGenerator[httpx.AsyncClient, None]:
|
||||
"""
|
||||
Create an async HTTP client that talks directly to the FastAPI app.
|
||||
@@ -288,7 +288,7 @@ async def test_authorize_invalid_client_returns_error(
|
||||
assert query_params["error"][0] == "invalid_client"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest_asyncio.fixture(loop_scope="session")
|
||||
async def inactive_oauth_app(test_user: str):
|
||||
"""Create an inactive test OAuth application in the database."""
|
||||
app_id = str(uuid.uuid4())
|
||||
@@ -1005,7 +1005,7 @@ async def test_token_refresh_revoked(
|
||||
assert "revoked" in response.json()["detail"].lower()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@pytest_asyncio.fixture(loop_scope="session")
|
||||
async def other_oauth_app(test_user: str):
|
||||
"""Create a second OAuth application for cross-app tests."""
|
||||
app_id = str(uuid.uuid4())
|
||||
|
||||
@@ -463,6 +463,9 @@ class GraphModel(Graph, GraphMeta):
|
||||
continue
|
||||
if ProviderName.HTTP in field.provider:
|
||||
continue
|
||||
# MCP credentials are intentionally split by server URL
|
||||
if ProviderName.MCP in field.provider:
|
||||
continue
|
||||
|
||||
# If this happens, that means a block implementation probably needs
|
||||
# to be updated.
|
||||
@@ -520,7 +523,7 @@ class GraphModel(Graph, GraphMeta):
|
||||
}
|
||||
|
||||
# Add a descriptive display title when URL-based discriminator values
|
||||
# are present (e.g. "MCP: mcp.sentry.dev" instead of just "Mcp")
|
||||
# are present (e.g. "mcp.sentry.dev" instead of just "Mcp")
|
||||
if (
|
||||
field_info.discriminator
|
||||
and not field_info.discriminator_mapping
|
||||
@@ -529,10 +532,7 @@ class GraphModel(Graph, GraphMeta):
|
||||
hostnames = sorted(
|
||||
parse_url(str(v)).netloc for v in field_info.discriminator_values
|
||||
)
|
||||
base_name = (
|
||||
next(iter(field_info.provider), "").replace("_", " ").upper()
|
||||
)
|
||||
field_schema["display_name"] = f"{base_name}: {', '.join(hostnames)}"
|
||||
field_schema["display_name"] = ", ".join(hostnames)
|
||||
|
||||
# Add other (optional) field info items
|
||||
field_schema.update(
|
||||
|
||||
@@ -611,8 +611,10 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
# URL-based discrimination (e.g. HTTP host-scoped, MCP server URL):
|
||||
# Each unique host gets its own credential entry.
|
||||
provider_prefix = next(iter(field.provider))
|
||||
# Use .value for enum types to get the plain string (e.g. "mcp" not "ProviderName.MCP")
|
||||
prefix_str = getattr(provider_prefix, "value", str(provider_prefix))
|
||||
providers = frozenset(
|
||||
[cast(CP, str(provider_prefix))]
|
||||
[cast(CP, prefix_str)]
|
||||
+ [
|
||||
cast(CP, parse_url(str(value)).netloc)
|
||||
for value in field.discriminator_values
|
||||
|
||||
@@ -137,7 +137,10 @@ class IntegrationCredentialsManager:
|
||||
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
|
||||
) -> OAuth2Credentials:
|
||||
async with self._locked(user_id, credentials.id, "refresh"):
|
||||
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
|
||||
if credentials.provider == str(ProviderName.MCP):
|
||||
oauth_handler = _create_mcp_oauth_handler(credentials)
|
||||
else:
|
||||
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
|
||||
if oauth_handler.needs_refresh(credentials):
|
||||
logger.debug(
|
||||
f"Refreshing '{credentials.provider}' "
|
||||
@@ -236,3 +239,25 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl
|
||||
client_secret=client_secret,
|
||||
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
||||
)
|
||||
|
||||
|
||||
def _create_mcp_oauth_handler(
|
||||
credentials: OAuth2Credentials,
|
||||
) -> "BaseOAuthHandler":
|
||||
"""Create an MCPOAuthHandler from credential metadata for token refresh.
|
||||
|
||||
MCP OAuth handlers have dynamic endpoints discovered per-server, so they
|
||||
can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler
|
||||
is reconstructed from metadata stored on the credential during initial auth.
|
||||
"""
|
||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
|
||||
meta = credentials.metadata or {}
|
||||
return MCPOAuthHandler(
|
||||
client_id=meta.get("mcp_client_id", ""),
|
||||
client_secret=meta.get("mcp_client_secret", ""),
|
||||
redirect_uri="", # Not needed for token refresh
|
||||
authorize_url="", # Not needed for token refresh
|
||||
token_url=meta.get("mcp_token_url", ""),
|
||||
resource_url=meta.get("mcp_resource_url"),
|
||||
)
|
||||
|
||||
@@ -47,13 +47,16 @@ export async function GET(request: Request) {
|
||||
var msg = ${safeJsonStringify(message)};
|
||||
var sent = false;
|
||||
|
||||
console.log("[MCP Callback] Script running, success:", msg.success, "code:", !!msg.code, "state:", !!msg.state);
|
||||
|
||||
// Method 1: BroadcastChannel (reliable across tabs/popups, no opener needed)
|
||||
try {
|
||||
var bc = new BroadcastChannel("mcp_oauth");
|
||||
bc.postMessage({ type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message });
|
||||
bc.close();
|
||||
sent = true;
|
||||
} catch(e) { console.warn("BroadcastChannel failed:", e); }
|
||||
console.log("[MCP Callback] BroadcastChannel message sent");
|
||||
} catch(e) { console.warn("[MCP Callback] BroadcastChannel failed:", e); }
|
||||
|
||||
// Method 2: window.opener.postMessage (fallback for same-origin popups)
|
||||
try {
|
||||
@@ -63,14 +66,18 @@ export async function GET(request: Request) {
|
||||
window.location.origin
|
||||
);
|
||||
sent = true;
|
||||
console.log("[MCP Callback] postMessage sent to opener");
|
||||
} else {
|
||||
console.log("[MCP Callback] window.opener not available (COOP or popup blocked)");
|
||||
}
|
||||
} catch(e) { console.warn("postMessage failed:", e); }
|
||||
} catch(e) { console.warn("[MCP Callback] postMessage failed:", e); }
|
||||
|
||||
// Method 3: localStorage (most reliable cross-tab fallback)
|
||||
try {
|
||||
localStorage.setItem("mcp_oauth_result", JSON.stringify(msg));
|
||||
sent = true;
|
||||
} catch(e) { console.warn("localStorage failed:", e); }
|
||||
console.log("[MCP Callback] localStorage set");
|
||||
} catch(e) { console.warn("[MCP Callback] localStorage failed:", e); }
|
||||
|
||||
var statusEl = document.getElementById("status");
|
||||
var spinnerEl = document.getElementById("spinner");
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState, useCallback, useRef, useEffect } from "react";
|
||||
import React, {
|
||||
useState,
|
||||
useCallback,
|
||||
useRef,
|
||||
useEffect,
|
||||
useContext,
|
||||
} from "react";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
@@ -18,6 +24,8 @@ import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
import type { CredentialsMetaInput, MCPTool } from "@/lib/autogpt-server-api";
|
||||
import { CaretDown } from "@phosphor-icons/react";
|
||||
import { openOAuthPopup } from "@/lib/oauth-popup";
|
||||
import { CredentialsProvidersContext } from "@/providers/agent-credentials/credentials-provider";
|
||||
|
||||
export type MCPToolDialogResult = {
|
||||
serverUrl: string;
|
||||
@@ -37,14 +45,13 @@ interface MCPToolDialogProps {
|
||||
|
||||
type DialogStep = "url" | "tool";
|
||||
|
||||
const OAUTH_TIMEOUT_MS = 5 * 60 * 1000; // 5 minutes
|
||||
|
||||
export function MCPToolDialog({
|
||||
open,
|
||||
onClose,
|
||||
onConfirm,
|
||||
}: MCPToolDialogProps) {
|
||||
const api = useBackendAPI();
|
||||
const allProviders = useContext(CredentialsProvidersContext);
|
||||
|
||||
const [step, setStep] = useState<DialogStep>("url");
|
||||
const [serverUrl, setServerUrl] = useState("");
|
||||
@@ -61,74 +68,19 @@ export function MCPToolDialog({
|
||||
null,
|
||||
);
|
||||
|
||||
const oauthLoadingRef = useRef(false);
|
||||
const stateTokenRef = useRef<string | null>(null);
|
||||
const broadcastChannelRef = useRef<BroadcastChannel | null>(null);
|
||||
const messageHandlerRef = useRef<((event: MessageEvent) => void) | null>(
|
||||
null,
|
||||
);
|
||||
const storageHandlerRef = useRef<((event: StorageEvent) => void) | null>(
|
||||
null,
|
||||
);
|
||||
const popupCheckRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
const storagePollRef = useRef<ReturnType<typeof setInterval> | null>(null);
|
||||
const oauthHandledRef = useRef(false);
|
||||
// (no auto-prefill — dialog starts fresh each time)
|
||||
const startOAuthRef = useRef(false);
|
||||
const oauthAbortRef = useRef<((reason?: string) => void) | null>(null);
|
||||
|
||||
// Clean up listeners on unmount
|
||||
// Clean up on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (messageHandlerRef.current) {
|
||||
window.removeEventListener("message", messageHandlerRef.current);
|
||||
}
|
||||
if (storageHandlerRef.current) {
|
||||
window.removeEventListener("storage", storageHandlerRef.current);
|
||||
}
|
||||
if (broadcastChannelRef.current) {
|
||||
broadcastChannelRef.current.close();
|
||||
}
|
||||
if (popupCheckRef.current) {
|
||||
clearInterval(popupCheckRef.current);
|
||||
}
|
||||
if (storagePollRef.current) {
|
||||
clearInterval(storagePollRef.current);
|
||||
}
|
||||
oauthAbortRef.current?.();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const cleanupOAuthListeners = useCallback(() => {
|
||||
if (messageHandlerRef.current) {
|
||||
window.removeEventListener("message", messageHandlerRef.current);
|
||||
messageHandlerRef.current = null;
|
||||
}
|
||||
if (storageHandlerRef.current) {
|
||||
window.removeEventListener("storage", storageHandlerRef.current);
|
||||
storageHandlerRef.current = null;
|
||||
}
|
||||
if (broadcastChannelRef.current) {
|
||||
broadcastChannelRef.current.close();
|
||||
broadcastChannelRef.current = null;
|
||||
}
|
||||
if (popupCheckRef.current) {
|
||||
clearInterval(popupCheckRef.current);
|
||||
popupCheckRef.current = null;
|
||||
}
|
||||
if (storagePollRef.current) {
|
||||
clearInterval(storagePollRef.current);
|
||||
storagePollRef.current = null;
|
||||
}
|
||||
// Clean up any stale localStorage entry
|
||||
try {
|
||||
localStorage.removeItem("mcp_oauth_result");
|
||||
} catch {}
|
||||
setOauthLoading(false);
|
||||
oauthLoadingRef.current = false;
|
||||
// NOTE: do NOT reset oauthHandledRef here — it guards against double-handling
|
||||
// and must only be reset when starting a new OAuth flow.
|
||||
}, []);
|
||||
|
||||
const reset = useCallback(() => {
|
||||
cleanupOAuthListeners();
|
||||
oauthAbortRef.current?.();
|
||||
oauthAbortRef.current = null;
|
||||
setStep("url");
|
||||
setServerUrl("");
|
||||
setManualToken("");
|
||||
@@ -137,11 +89,11 @@ export function MCPToolDialog({
|
||||
setLoading(false);
|
||||
setError(null);
|
||||
setAuthRequired(false);
|
||||
setOauthLoading(false);
|
||||
setShowManualToken(false);
|
||||
setSelectedTool(null);
|
||||
setCredentials(null);
|
||||
stateTokenRef.current = null;
|
||||
}, [cleanupOAuthListeners]);
|
||||
}, []);
|
||||
|
||||
const handleClose = useCallback(() => {
|
||||
reset();
|
||||
@@ -163,6 +115,10 @@ export function MCPToolDialog({
|
||||
if (e?.status === 401 || e?.status === 403) {
|
||||
setAuthRequired(true);
|
||||
setError(null);
|
||||
// Automatically start OAuth sign-in instead of requiring a second click
|
||||
setLoading(false);
|
||||
startOAuthRef.current = true;
|
||||
return;
|
||||
} else {
|
||||
const message =
|
||||
e?.message || e?.detail || "Failed to connect to MCP server";
|
||||
@@ -182,44 +138,60 @@ export function MCPToolDialog({
|
||||
discoverTools(serverUrl.trim(), manualToken.trim() || undefined);
|
||||
}, [serverUrl, manualToken, discoverTools]);
|
||||
|
||||
const handleOAuthResult = useCallback(
|
||||
async (data: {
|
||||
success: boolean;
|
||||
code?: string;
|
||||
state?: string;
|
||||
message?: string;
|
||||
}) => {
|
||||
// Prevent double-handling (BroadcastChannel + postMessage may both fire)
|
||||
if (oauthHandledRef.current) return;
|
||||
oauthHandledRef.current = true;
|
||||
const handleOAuthSignIn = useCallback(async () => {
|
||||
if (!serverUrl.trim()) return;
|
||||
setError(null);
|
||||
|
||||
if (!data.success) {
|
||||
setError(data.message || "OAuth authentication failed.");
|
||||
cleanupOAuthListeners();
|
||||
return;
|
||||
}
|
||||
// Abort any previous OAuth flow
|
||||
oauthAbortRef.current?.();
|
||||
|
||||
cleanupOAuthListeners();
|
||||
setOauthLoading(true);
|
||||
|
||||
try {
|
||||
const { login_url, state_token } = await api.mcpOAuthLogin(
|
||||
serverUrl.trim(),
|
||||
);
|
||||
|
||||
const { promise, cleanup } = openOAuthPopup(login_url, {
|
||||
stateToken: state_token,
|
||||
useCrossOriginListeners: true,
|
||||
});
|
||||
oauthAbortRef.current = cleanup.abort;
|
||||
|
||||
const result = await promise;
|
||||
|
||||
// Exchange code for tokens via the credentials provider (updates cache)
|
||||
setLoading(true);
|
||||
setOauthLoading(false);
|
||||
|
||||
const mcpProvider = allProviders?.["mcp"];
|
||||
const callbackResult = mcpProvider
|
||||
? await mcpProvider.mcpOAuthCallback(result.code, state_token)
|
||||
: await api.mcpOAuthCallback(result.code, state_token);
|
||||
|
||||
setCredentials({
|
||||
id: callbackResult.id,
|
||||
provider: callbackResult.provider,
|
||||
type: callbackResult.type,
|
||||
title: callbackResult.title,
|
||||
});
|
||||
setAuthRequired(false);
|
||||
|
||||
// Exchange code for tokens (stored server-side)
|
||||
setLoading(true);
|
||||
try {
|
||||
const callbackResult = await api.mcpOAuthCallback(
|
||||
data.code!,
|
||||
stateTokenRef.current!,
|
||||
// Discover tools now that we're authenticated
|
||||
const toolsResult = await api.mcpDiscoverTools(serverUrl.trim());
|
||||
setTools(toolsResult.tools);
|
||||
setServerName(toolsResult.server_name);
|
||||
setStep("tool");
|
||||
} catch (e: any) {
|
||||
// If server doesn't support OAuth → show manual token entry
|
||||
if (e?.status === 400) {
|
||||
setShowManualToken(true);
|
||||
setError(
|
||||
"This server does not support OAuth sign-in. Please enter a token manually.",
|
||||
);
|
||||
setCredentials({
|
||||
id: callbackResult.id,
|
||||
provider: callbackResult.provider,
|
||||
type: callbackResult.type,
|
||||
title: callbackResult.title,
|
||||
});
|
||||
const result = await api.mcpDiscoverTools(serverUrl.trim());
|
||||
setTools(result.tools);
|
||||
setServerName(result.server_name);
|
||||
setStep("tool");
|
||||
} catch (e: any) {
|
||||
} else if (e?.message === "OAuth flow timed out") {
|
||||
setError("OAuth sign-in timed out. Please try again.");
|
||||
} else {
|
||||
const status = e?.status;
|
||||
let message: string;
|
||||
if (status === 401 || status === 403) {
|
||||
@@ -232,156 +204,21 @@ export function MCPToolDialog({
|
||||
setError(
|
||||
typeof message === "string" ? message : JSON.stringify(message),
|
||||
);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
},
|
||||
[api, serverUrl, cleanupOAuthListeners],
|
||||
);
|
||||
|
||||
const handleOAuthSignIn = useCallback(async () => {
|
||||
if (!serverUrl.trim()) return;
|
||||
setError(null);
|
||||
oauthHandledRef.current = false;
|
||||
|
||||
// Open popup SYNCHRONOUSLY (before async call) to avoid browser popup blockers
|
||||
const width = 500;
|
||||
const height = 700;
|
||||
const left = window.screenX + (window.outerWidth - width) / 2;
|
||||
const top = window.screenY + (window.outerHeight - height) / 2;
|
||||
const popup = window.open(
|
||||
"about:blank",
|
||||
"mcp_oauth",
|
||||
`width=${width},height=${height},left=${left},top=${top},scrollbars=yes`,
|
||||
);
|
||||
|
||||
setOauthLoading(true);
|
||||
oauthLoadingRef.current = true;
|
||||
|
||||
try {
|
||||
const { login_url, state_token } = await api.mcpOAuthLogin(
|
||||
serverUrl.trim(),
|
||||
);
|
||||
stateTokenRef.current = state_token;
|
||||
|
||||
if (popup && !popup.closed) {
|
||||
popup.location.href = login_url;
|
||||
} else {
|
||||
// Popup was blocked — open in new tab as fallback
|
||||
window.open(login_url, "_blank");
|
||||
}
|
||||
|
||||
// Clear any stale localStorage entry before starting
|
||||
try {
|
||||
localStorage.removeItem("mcp_oauth_result");
|
||||
} catch {}
|
||||
|
||||
// Listener 1: BroadcastChannel (works even when window.opener is null)
|
||||
try {
|
||||
const bc = new BroadcastChannel("mcp_oauth");
|
||||
bc.onmessage = (event) => {
|
||||
if (event.data?.type === "mcp_oauth_result") {
|
||||
handleOAuthResult(event.data);
|
||||
}
|
||||
};
|
||||
broadcastChannelRef.current = bc;
|
||||
} catch (e) {
|
||||
console.warn("BroadcastChannel not available:", e);
|
||||
}
|
||||
|
||||
// Listener 2: window.postMessage (fallback)
|
||||
const handleMessage = (event: MessageEvent) => {
|
||||
if (event.origin !== window.location.origin) return;
|
||||
if (event.data?.message_type === "mcp_oauth_result") {
|
||||
handleOAuthResult(event.data);
|
||||
}
|
||||
};
|
||||
messageHandlerRef.current = handleMessage;
|
||||
window.addEventListener("message", handleMessage);
|
||||
|
||||
// Listener 3: localStorage (most reliable cross-tab fallback)
|
||||
const handleStorage = (event: StorageEvent) => {
|
||||
if (event.key === "mcp_oauth_result" && event.newValue) {
|
||||
try {
|
||||
const data = JSON.parse(event.newValue);
|
||||
localStorage.removeItem("mcp_oauth_result");
|
||||
handleOAuthResult(data);
|
||||
} catch {}
|
||||
}
|
||||
};
|
||||
storageHandlerRef.current = handleStorage;
|
||||
window.addEventListener("storage", handleStorage);
|
||||
|
||||
// Fallback 1: Poll localStorage periodically.
|
||||
// StorageEvent only fires in OTHER windows, and BroadcastChannel can fail
|
||||
// in some cross-origin popup scenarios. Direct polling is the most reliable.
|
||||
storagePollRef.current = setInterval(() => {
|
||||
if (!oauthLoadingRef.current || oauthHandledRef.current) {
|
||||
if (storagePollRef.current) clearInterval(storagePollRef.current);
|
||||
return;
|
||||
}
|
||||
try {
|
||||
const stored = localStorage.getItem("mcp_oauth_result");
|
||||
if (stored) {
|
||||
const data = JSON.parse(stored);
|
||||
localStorage.removeItem("mcp_oauth_result");
|
||||
handleOAuthResult(data);
|
||||
}
|
||||
} catch {}
|
||||
}, 500);
|
||||
|
||||
// Fallback 2: detect popup close (gives up if popup closed without result)
|
||||
const popupRef = popup;
|
||||
popupCheckRef.current = setInterval(() => {
|
||||
if (!oauthLoadingRef.current || oauthHandledRef.current) {
|
||||
if (popupCheckRef.current) clearInterval(popupCheckRef.current);
|
||||
return;
|
||||
}
|
||||
if (popupRef && popupRef.closed) {
|
||||
// Grace period: wait one more poll cycle for localStorage to be set
|
||||
setTimeout(() => {
|
||||
if (oauthHandledRef.current) return;
|
||||
try {
|
||||
const stored = localStorage.getItem("mcp_oauth_result");
|
||||
if (stored) {
|
||||
const data = JSON.parse(stored);
|
||||
localStorage.removeItem("mcp_oauth_result");
|
||||
handleOAuthResult(data);
|
||||
return;
|
||||
}
|
||||
} catch {}
|
||||
// Popup closed without result — give up
|
||||
if (popupCheckRef.current) clearInterval(popupCheckRef.current);
|
||||
}, 1000);
|
||||
if (popupCheckRef.current) clearInterval(popupCheckRef.current);
|
||||
}
|
||||
}, 500);
|
||||
|
||||
// Timeout
|
||||
setTimeout(() => {
|
||||
if (oauthLoadingRef.current) {
|
||||
cleanupOAuthListeners();
|
||||
setError("OAuth sign-in timed out. Please try again.");
|
||||
}
|
||||
}, OAUTH_TIMEOUT_MS);
|
||||
} catch (e: any) {
|
||||
if (popup && !popup.closed) popup.close();
|
||||
|
||||
// If server doesn't support OAuth → show manual token entry
|
||||
if (e?.status === 400) {
|
||||
setShowManualToken(true);
|
||||
setError(
|
||||
"This server does not support OAuth sign-in. Please enter a token manually.",
|
||||
);
|
||||
} else {
|
||||
const message = e?.message || "Failed to initiate sign-in";
|
||||
setError(
|
||||
typeof message === "string" ? message : JSON.stringify(message),
|
||||
);
|
||||
}
|
||||
cleanupOAuthListeners();
|
||||
} finally {
|
||||
setOauthLoading(false);
|
||||
setLoading(false);
|
||||
oauthAbortRef.current = null;
|
||||
}
|
||||
}, [api, serverUrl, handleOAuthResult, cleanupOAuthListeners]);
|
||||
}, [api, serverUrl, allProviders]);
|
||||
|
||||
// Auto-start OAuth sign-in when server returns 401/403
|
||||
useEffect(() => {
|
||||
if (authRequired && startOAuthRef.current) {
|
||||
startOAuthRef.current = false;
|
||||
handleOAuthSignIn();
|
||||
}
|
||||
}, [authRequired, handleOAuthSignIn]);
|
||||
|
||||
const handleConfirm = useCallback(() => {
|
||||
if (!selectedTool) return;
|
||||
@@ -403,7 +240,15 @@ export function MCPToolDialog({
|
||||
credentials,
|
||||
});
|
||||
reset();
|
||||
}, [selectedTool, tools, serverUrl, credentials, onConfirm, reset]);
|
||||
}, [
|
||||
selectedTool,
|
||||
tools,
|
||||
serverUrl,
|
||||
serverName,
|
||||
credentials,
|
||||
onConfirm,
|
||||
reset,
|
||||
]);
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={(isOpen) => !isOpen && handleClose()}>
|
||||
|
||||
@@ -86,8 +86,7 @@ export function CredentialsInput({
|
||||
handleCredentialSelect,
|
||||
} = hookData;
|
||||
|
||||
const displayName =
|
||||
(schema as any).display_name || toDisplayName(provider);
|
||||
const displayName = (schema as any).display_name || toDisplayName(provider);
|
||||
const selectedCredentialIsSystem =
|
||||
selectedCredential && isSystemCredential(selectedCredential);
|
||||
|
||||
|
||||
@@ -5,14 +5,13 @@ import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
import { openOAuthPopup } from "@/lib/oauth-popup";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import {
|
||||
filterSystemCredentials,
|
||||
getActionButtonText,
|
||||
getSystemCredentials,
|
||||
OAUTH_TIMEOUT_MS,
|
||||
OAuthPopupResultMessage,
|
||||
} from "./helpers";
|
||||
|
||||
export type CredentialsInputState = ReturnType<typeof useCredentialsInput>;
|
||||
@@ -57,6 +56,14 @@ export function useCredentialsInput({
|
||||
const queryClient = useQueryClient();
|
||||
const credentials = useCredentials(schema, siblingInputs);
|
||||
const hasAttemptedAutoSelect = useRef(false);
|
||||
const oauthAbortRef = useRef<((reason?: string) => void) | null>(null);
|
||||
|
||||
// Clean up on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
oauthAbortRef.current?.();
|
||||
};
|
||||
}, []);
|
||||
|
||||
const deleteCredentialsMutation = useDeleteV1DeleteCredentials({
|
||||
mutation: {
|
||||
@@ -148,6 +155,7 @@ export function useCredentialsInput({
|
||||
supportsHostScoped,
|
||||
savedCredentials,
|
||||
oAuthCallback,
|
||||
mcpOAuthCallback,
|
||||
isSystemProvider,
|
||||
discriminatorValue,
|
||||
} = credentials;
|
||||
@@ -159,128 +167,92 @@ export function useCredentialsInput({
|
||||
async function handleOAuthLogin() {
|
||||
setOAuthError(null);
|
||||
|
||||
// Abort any previous OAuth flow
|
||||
oauthAbortRef.current?.();
|
||||
|
||||
// MCP uses dynamic OAuth discovery per server URL
|
||||
const isMCP = provider === "mcp" && !!discriminatorValue;
|
||||
let login_url: string;
|
||||
let state_token: string;
|
||||
|
||||
if (isMCP) {
|
||||
({ login_url, state_token } = await api.mcpOAuthLogin(
|
||||
discriminatorValue!,
|
||||
));
|
||||
} else {
|
||||
({ login_url, state_token } = await api.oAuthLogin(
|
||||
provider,
|
||||
schema.credentials_scopes,
|
||||
));
|
||||
}
|
||||
try {
|
||||
let login_url: string;
|
||||
let state_token: string;
|
||||
|
||||
setOAuth2FlowInProgress(true);
|
||||
const popup = window.open(login_url, "_blank", "popup=true");
|
||||
if (isMCP) {
|
||||
({ login_url, state_token } = await api.mcpOAuthLogin(
|
||||
discriminatorValue!,
|
||||
));
|
||||
} else {
|
||||
({ login_url, state_token } = await api.oAuthLogin(
|
||||
provider,
|
||||
schema.credentials_scopes,
|
||||
));
|
||||
}
|
||||
|
||||
if (!popup) {
|
||||
throw new Error(
|
||||
"Failed to open popup window. Please allow popups for this site.",
|
||||
setOAuth2FlowInProgress(true);
|
||||
|
||||
const { promise, cleanup } = openOAuthPopup(login_url, {
|
||||
stateToken: state_token,
|
||||
useCrossOriginListeners: isMCP,
|
||||
// Standard OAuth uses "oauth_popup_result", MCP uses "mcp_oauth_result"
|
||||
acceptMessageTypes: isMCP
|
||||
? ["mcp_oauth_result"]
|
||||
: ["oauth_popup_result"],
|
||||
});
|
||||
|
||||
oauthAbortRef.current = cleanup.abort;
|
||||
// Expose abort signal for the waiting modal's cancel button
|
||||
const controller = new AbortController();
|
||||
cleanup.signal.addEventListener("abort", () =>
|
||||
controller.abort("completed"),
|
||||
);
|
||||
}
|
||||
setOAuthPopupController(controller);
|
||||
|
||||
const controller = new AbortController();
|
||||
setOAuthPopupController(controller);
|
||||
controller.signal.onabort = () => {
|
||||
console.debug("OAuth flow aborted");
|
||||
setOAuth2FlowInProgress(false);
|
||||
popup.close();
|
||||
};
|
||||
const result = await promise;
|
||||
|
||||
const handleMessage = async (e: MessageEvent<OAuthPopupResultMessage>) => {
|
||||
console.debug("Message received:", e.data);
|
||||
if (
|
||||
typeof e.data != "object" ||
|
||||
!("message_type" in e.data) ||
|
||||
e.data.message_type !== "oauth_popup_result"
|
||||
) {
|
||||
console.debug("Ignoring irrelevant message");
|
||||
return;
|
||||
}
|
||||
// Exchange code for tokens via the provider (updates credential cache)
|
||||
const credentialResult = isMCP
|
||||
? await mcpOAuthCallback(result.code, state_token)
|
||||
: await oAuthCallback(result.code, result.state);
|
||||
|
||||
if (!e.data.success) {
|
||||
console.error("OAuth flow failed:", e.data.message);
|
||||
setOAuthError(`OAuth flow failed: ${e.data.message}`);
|
||||
setOAuth2FlowInProgress(false);
|
||||
return;
|
||||
}
|
||||
// Check if the credential's scopes match the required scopes (skip for MCP)
|
||||
if (!isMCP) {
|
||||
const requiredScopes = schema.credentials_scopes;
|
||||
if (requiredScopes && requiredScopes.length > 0) {
|
||||
const grantedScopes = new Set(credentialResult.scopes || []);
|
||||
const hasAllRequiredScopes = new Set(requiredScopes).isSubsetOf(
|
||||
grantedScopes,
|
||||
);
|
||||
|
||||
if (e.data.state !== state_token) {
|
||||
console.error("Invalid state token received");
|
||||
setOAuthError("Invalid state token received");
|
||||
setOAuth2FlowInProgress(false);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
console.debug("Processing OAuth callback");
|
||||
// MCP uses its own callback endpoint
|
||||
const credentials = isMCP
|
||||
? await api.mcpOAuthCallback(e.data.code, e.data.state)
|
||||
: await oAuthCallback(e.data.code, e.data.state);
|
||||
console.debug("OAuth callback processed successfully");
|
||||
|
||||
// Check if the credential's scopes match the required scopes (skip for MCP)
|
||||
if (!isMCP) {
|
||||
const requiredScopes = schema.credentials_scopes;
|
||||
if (requiredScopes && requiredScopes.length > 0) {
|
||||
const grantedScopes = new Set(credentials.scopes || []);
|
||||
const hasAllRequiredScopes = new Set(requiredScopes).isSubsetOf(
|
||||
grantedScopes,
|
||||
if (!hasAllRequiredScopes) {
|
||||
setOAuthError(
|
||||
"Connection failed: the granted permissions don't match what's required. " +
|
||||
"Please contact the application administrator.",
|
||||
);
|
||||
|
||||
if (!hasAllRequiredScopes) {
|
||||
console.error(
|
||||
`Newly created OAuth credential for ${providerName} has insufficient scopes. Required:`,
|
||||
requiredScopes,
|
||||
"Granted:",
|
||||
credentials.scopes,
|
||||
);
|
||||
setOAuthError(
|
||||
"Connection failed: the granted permissions don't match what's required. " +
|
||||
"Please contact the application administrator.",
|
||||
);
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onSelectCredential({
|
||||
id: credentials.id,
|
||||
type: "oauth2",
|
||||
title: credentials.title,
|
||||
provider,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error in OAuth callback:", error);
|
||||
onSelectCredential({
|
||||
id: credentialResult.id,
|
||||
type: "oauth2",
|
||||
title: credentialResult.title,
|
||||
provider,
|
||||
});
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.message === "OAuth flow timed out") {
|
||||
setOAuthError("OAuth flow timed out");
|
||||
} else {
|
||||
setOAuthError(
|
||||
`Error in OAuth callback: ${
|
||||
`OAuth error: ${
|
||||
error instanceof Error ? error.message : String(error)
|
||||
}`,
|
||||
);
|
||||
} finally {
|
||||
console.debug("Finalizing OAuth flow");
|
||||
setOAuth2FlowInProgress(false);
|
||||
controller.abort("success");
|
||||
}
|
||||
};
|
||||
|
||||
console.debug("Adding message event listener");
|
||||
window.addEventListener("message", handleMessage, {
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
setTimeout(() => {
|
||||
console.debug("OAuth flow timed out");
|
||||
controller.abort("timeout");
|
||||
} finally {
|
||||
setOAuth2FlowInProgress(false);
|
||||
setOAuthError("OAuth flow timed out");
|
||||
}, OAUTH_TIMEOUT_MS);
|
||||
oauthAbortRef.current = null;
|
||||
}
|
||||
}
|
||||
|
||||
function handleActionButtonClick() {
|
||||
|
||||
177
autogpt_platform/frontend/src/lib/oauth-popup.ts
Normal file
177
autogpt_platform/frontend/src/lib/oauth-popup.ts
Normal file
@@ -0,0 +1,177 @@
|
||||
/**
|
||||
* Shared utility for OAuth popup flows with cross-origin support.
|
||||
*
|
||||
* Handles BroadcastChannel, postMessage, and localStorage polling
|
||||
* to reliably receive OAuth callback results even when COOP headers
|
||||
* sever the window.opener relationship.
|
||||
*/
|
||||
|
||||
const DEFAULT_TIMEOUT_MS = 5 * 60 * 1000; // 5 minutes
|
||||
|
||||
export type OAuthPopupResult = {
|
||||
code: string;
|
||||
state: string;
|
||||
};
|
||||
|
||||
export type OAuthPopupOptions = {
|
||||
/** State token to validate against incoming messages */
|
||||
stateToken: string;
|
||||
/**
|
||||
* Use BroadcastChannel + localStorage polling for cross-origin OAuth (MCP).
|
||||
* Standard OAuth only uses postMessage via window.opener.
|
||||
*/
|
||||
useCrossOriginListeners?: boolean;
|
||||
/** BroadcastChannel name (default: "mcp_oauth") */
|
||||
broadcastChannelName?: string;
|
||||
/** localStorage key for cross-origin fallback (default: "mcp_oauth_result") */
|
||||
localStorageKey?: string;
|
||||
/** Message types to accept (default: ["oauth_popup_result", "mcp_oauth_result"]) */
|
||||
acceptMessageTypes?: string[];
|
||||
/** Timeout in ms (default: 5 minutes) */
|
||||
timeout?: number;
|
||||
};
|
||||
|
||||
type Cleanup = {
|
||||
/** Abort the OAuth flow and close the popup */
|
||||
abort: (reason?: string) => void;
|
||||
/** The AbortController signal */
|
||||
signal: AbortSignal;
|
||||
};
|
||||
|
||||
/**
|
||||
* Opens an OAuth popup and sets up listeners for the callback result.
|
||||
*
|
||||
* Opens a blank popup synchronously (to avoid popup blockers), then navigates
|
||||
* it to the login URL. Returns a promise that resolves with the OAuth code/state.
|
||||
*
|
||||
* @param loginUrl - The OAuth authorization URL to navigate to
|
||||
* @param options - Configuration for message handling
|
||||
* @returns Object with `promise` (resolves with OAuth result) and `abort` (cancels flow)
|
||||
*/
|
||||
export function openOAuthPopup(
|
||||
loginUrl: string,
|
||||
options: OAuthPopupOptions,
|
||||
): { promise: Promise<OAuthPopupResult>; cleanup: Cleanup } {
|
||||
const {
|
||||
stateToken,
|
||||
useCrossOriginListeners = false,
|
||||
broadcastChannelName = "mcp_oauth",
|
||||
localStorageKey = "mcp_oauth_result",
|
||||
acceptMessageTypes = ["oauth_popup_result", "mcp_oauth_result"],
|
||||
timeout = DEFAULT_TIMEOUT_MS,
|
||||
} = options;
|
||||
|
||||
const controller = new AbortController();
|
||||
|
||||
// Open popup synchronously (before any async work) to avoid browser popup blockers
|
||||
const width = 500;
|
||||
const height = 700;
|
||||
const left = window.screenX + (window.outerWidth - width) / 2;
|
||||
const top = window.screenY + (window.outerHeight - height) / 2;
|
||||
const popup = window.open(
|
||||
"about:blank",
|
||||
"_blank",
|
||||
`width=${width},height=${height},left=${left},top=${top},popup=true,scrollbars=yes`,
|
||||
);
|
||||
|
||||
if (popup && !popup.closed) {
|
||||
popup.location.href = loginUrl;
|
||||
} else {
|
||||
// Popup was blocked — open in new tab as fallback
|
||||
window.open(loginUrl, "_blank");
|
||||
}
|
||||
|
||||
// Close popup on abort
|
||||
controller.signal.addEventListener("abort", () => {
|
||||
if (popup && !popup.closed) popup.close();
|
||||
});
|
||||
|
||||
// Clear any stale localStorage entry
|
||||
if (useCrossOriginListeners) {
|
||||
try {
|
||||
localStorage.removeItem(localStorageKey);
|
||||
} catch {}
|
||||
}
|
||||
|
||||
const promise = new Promise<OAuthPopupResult>((resolve, reject) => {
|
||||
let handled = false;
|
||||
|
||||
const handleResult = (data: any) => {
|
||||
if (handled) return; // Prevent double-handling
|
||||
|
||||
// Validate message type
|
||||
const messageType = data?.message_type ?? data?.type;
|
||||
if (!messageType || !acceptMessageTypes.includes(messageType)) return;
|
||||
|
||||
// Validate state token
|
||||
if (data.state !== stateToken) {
|
||||
// State mismatch — this message is for a different listener. Ignore silently.
|
||||
return;
|
||||
}
|
||||
|
||||
handled = true;
|
||||
|
||||
if (!data.success) {
|
||||
reject(new Error(data.message || "OAuth authentication failed"));
|
||||
} else {
|
||||
resolve({ code: data.code, state: data.state });
|
||||
}
|
||||
|
||||
controller.abort("completed");
|
||||
};
|
||||
|
||||
// Listener: postMessage (works for same-origin popups)
|
||||
window.addEventListener(
|
||||
"message",
|
||||
(event: MessageEvent) => {
|
||||
if (typeof event.data === "object") {
|
||||
handleResult(event.data);
|
||||
}
|
||||
},
|
||||
{ signal: controller.signal },
|
||||
);
|
||||
|
||||
// Cross-origin listeners for MCP OAuth
|
||||
if (useCrossOriginListeners) {
|
||||
// Listener: BroadcastChannel (works across tabs/popups without opener)
|
||||
try {
|
||||
const bc = new BroadcastChannel(broadcastChannelName);
|
||||
bc.onmessage = (event) => handleResult(event.data);
|
||||
controller.signal.addEventListener("abort", () => bc.close());
|
||||
} catch {}
|
||||
|
||||
// Listener: localStorage polling (most reliable cross-tab fallback)
|
||||
const pollInterval = setInterval(() => {
|
||||
try {
|
||||
const stored = localStorage.getItem(localStorageKey);
|
||||
if (stored) {
|
||||
const data = JSON.parse(stored);
|
||||
localStorage.removeItem(localStorageKey);
|
||||
handleResult(data);
|
||||
}
|
||||
} catch {}
|
||||
}, 500);
|
||||
controller.signal.addEventListener("abort", () =>
|
||||
clearInterval(pollInterval),
|
||||
);
|
||||
}
|
||||
|
||||
// Timeout
|
||||
const timeoutId = setTimeout(() => {
|
||||
if (!handled) {
|
||||
handled = true;
|
||||
reject(new Error("OAuth flow timed out"));
|
||||
controller.abort("timeout");
|
||||
}
|
||||
}, timeout);
|
||||
controller.signal.addEventListener("abort", () => clearTimeout(timeoutId));
|
||||
});
|
||||
|
||||
return {
|
||||
promise,
|
||||
cleanup: {
|
||||
abort: (reason?: string) => controller.abort(reason || "canceled"),
|
||||
signal: controller.signal,
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -38,6 +38,11 @@ export type CredentialsProviderData = {
|
||||
code: string,
|
||||
state_token: string,
|
||||
) => Promise<CredentialsMetaResponse>;
|
||||
/** MCP-specific OAuth callback that uses dynamic per-server OAuth discovery. */
|
||||
mcpOAuthCallback: (
|
||||
code: string,
|
||||
state_token: string,
|
||||
) => Promise<CredentialsMetaResponse>;
|
||||
createAPIKeyCredentials: (
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
) => Promise<CredentialsMetaResponse>;
|
||||
@@ -120,6 +125,24 @@ export default function CredentialsProvider({
|
||||
[api, addCredentials, onFailToast],
|
||||
);
|
||||
|
||||
/** Wraps `BackendAPI.mcpOAuthCallback`, and adds the result to the internal credentials store. */
|
||||
const mcpOAuthCallback = useCallback(
|
||||
async (
|
||||
code: string,
|
||||
state_token: string,
|
||||
): Promise<CredentialsMetaResponse> => {
|
||||
try {
|
||||
const credsMeta = await api.mcpOAuthCallback(code, state_token);
|
||||
addCredentials("mcp", credsMeta);
|
||||
return credsMeta;
|
||||
} catch (error) {
|
||||
onFailToast("complete MCP OAuth authentication")(error);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
[api, addCredentials, onFailToast],
|
||||
);
|
||||
|
||||
/** Wraps `BackendAPI.createAPIKeyCredentials`, and adds the result to the internal credentials store. */
|
||||
const createAPIKeyCredentials = useCallback(
|
||||
async (
|
||||
@@ -258,6 +281,7 @@ export default function CredentialsProvider({
|
||||
isSystemProvider: systemProviders.has(provider),
|
||||
oAuthCallback: (code: string, state_token: string) =>
|
||||
oAuthCallback(provider, code, state_token),
|
||||
mcpOAuthCallback,
|
||||
createAPIKeyCredentials: (
|
||||
credentials: APIKeyCredentialsCreatable,
|
||||
) => createAPIKeyCredentials(provider, credentials),
|
||||
@@ -286,6 +310,7 @@ export default function CredentialsProvider({
|
||||
createHostScopedCredentials,
|
||||
deleteCredentials,
|
||||
oAuthCallback,
|
||||
mcpOAuthCallback,
|
||||
onFailToast,
|
||||
]);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user