mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-09 14:25:25 -05:00
fix(backend/mcp): Auto-refresh expired OAuth tokens before MCP tool calls
_resolve_auth_token now checks token expiry and refreshes using MCPOAuthHandler with metadata (token_url, client_id, client_secret) stored during the OAuth callback flow.
This commit is contained in:
@@ -8,7 +8,11 @@ dropdown and the input/output schema adapts dynamically.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.data.block import (
|
||||
@@ -190,22 +194,63 @@ class MCPToolBlock(Block):
|
||||
return output_parts if output_parts else None
|
||||
|
||||
async def _resolve_auth_token(self, credential_id: str, user_id: str) -> str | None:
|
||||
"""Resolve a Bearer token from a stored credential ID."""
|
||||
"""Resolve a Bearer token from a stored credential ID, refreshing if needed."""
|
||||
if not credential_id:
|
||||
return None
|
||||
from backend.util.clients import get_integration_credentials_store
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
|
||||
store = get_integration_credentials_store()
|
||||
store = IntegrationCredentialsStore()
|
||||
creds = await store.get_creds_by_id(user_id, credential_id)
|
||||
if not creds:
|
||||
logger.warning(f"Credential {credential_id} not found")
|
||||
return None
|
||||
if isinstance(creds, OAuth2Credentials):
|
||||
# Refresh if token expires within 5 minutes
|
||||
if (
|
||||
creds.access_token_expires_at
|
||||
and creds.access_token_expires_at < int(time.time()) + 300
|
||||
):
|
||||
creds = await self._refresh_mcp_oauth(creds, user_id, store)
|
||||
return creds.access_token.get_secret_value()
|
||||
if hasattr(creds, "api_key") and creds.api_key:
|
||||
return creds.api_key.get_secret_value() or None
|
||||
return None
|
||||
|
||||
async def _refresh_mcp_oauth(
|
||||
self,
|
||||
creds: OAuth2Credentials,
|
||||
user_id: str,
|
||||
store: "IntegrationCredentialsStore",
|
||||
) -> OAuth2Credentials:
|
||||
"""Refresh MCP OAuth tokens using metadata stored during the OAuth callback."""
|
||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
|
||||
metadata = creds.metadata or {}
|
||||
token_url = metadata.get("mcp_token_url")
|
||||
if not token_url:
|
||||
logger.warning(
|
||||
f"Cannot refresh MCP credential {creds.id}: no token_url in metadata"
|
||||
)
|
||||
return creds
|
||||
|
||||
handler = MCPOAuthHandler(
|
||||
client_id=metadata.get("mcp_client_id", ""),
|
||||
client_secret=metadata.get("mcp_client_secret", ""),
|
||||
redirect_uri="", # Not needed for refresh
|
||||
authorize_url="", # Not needed for refresh
|
||||
token_url=token_url,
|
||||
resource_url=metadata.get("mcp_resource_url"),
|
||||
)
|
||||
|
||||
try:
|
||||
fresh = await handler.refresh_tokens(creds)
|
||||
await store.update_creds(user_id, fresh)
|
||||
logger.info(f"Refreshed MCP OAuth credential {creds.id}")
|
||||
return fresh
|
||||
except Exception:
|
||||
logger.exception(f"Failed to refresh MCP OAuth credential {creds.id}")
|
||||
return creds
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
|
||||
@@ -607,3 +607,98 @@ class TestMCPToolBlock:
|
||||
|
||||
assert captured_tokens == [None]
|
||||
assert outputs == [("result", "ok")]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_auth_token_refreshes_expired(self):
|
||||
"""Verify _resolve_auth_token refreshes expired MCP OAuth tokens."""
|
||||
import time
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
block = MCPToolBlock()
|
||||
expired_creds = OAuth2Credentials(
|
||||
id="cred-expired",
|
||||
provider="mcp",
|
||||
title="MCP: test",
|
||||
access_token=SecretStr("old-token"),
|
||||
refresh_token=SecretStr("refresh-tok"),
|
||||
access_token_expires_at=int(time.time()) - 60, # Already expired
|
||||
scopes=[],
|
||||
metadata={
|
||||
"mcp_token_url": "https://auth.example.com/token",
|
||||
"mcp_client_id": "client-id",
|
||||
"mcp_client_secret": "client-secret",
|
||||
"mcp_resource_url": "https://mcp.example.com",
|
||||
},
|
||||
)
|
||||
fresh_creds = OAuth2Credentials(
|
||||
id="cred-expired",
|
||||
provider="mcp",
|
||||
title="MCP: test",
|
||||
access_token=SecretStr("fresh-token"),
|
||||
refresh_token=SecretStr("refresh-tok"),
|
||||
access_token_expires_at=int(time.time()) + 3600,
|
||||
scopes=[],
|
||||
metadata=expired_creds.metadata,
|
||||
)
|
||||
|
||||
mock_store = AsyncMock()
|
||||
mock_store.get_creds_by_id = AsyncMock(return_value=expired_creds)
|
||||
mock_store.update_creds = AsyncMock()
|
||||
|
||||
mock_handler_instance = AsyncMock()
|
||||
mock_handler_instance.refresh_tokens = AsyncMock(return_value=fresh_creds)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.integrations.credentials_store.IntegrationCredentialsStore",
|
||||
return_value=mock_store,
|
||||
),
|
||||
patch(
|
||||
"backend.blocks.mcp.oauth.MCPOAuthHandler",
|
||||
return_value=mock_handler_instance,
|
||||
),
|
||||
):
|
||||
token = await block._resolve_auth_token("cred-expired", "user-1")
|
||||
|
||||
assert token == "fresh-token"
|
||||
mock_handler_instance.refresh_tokens.assert_awaited_once_with(expired_creds)
|
||||
mock_store.update_creds.assert_awaited_once_with("user-1", fresh_creds)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_auth_token_skips_refresh_if_valid(self):
|
||||
"""Verify _resolve_auth_token does NOT refresh tokens that are still valid."""
|
||||
import time
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
block = MCPToolBlock()
|
||||
valid_creds = OAuth2Credentials(
|
||||
id="cred-valid",
|
||||
provider="mcp",
|
||||
title="MCP: test",
|
||||
access_token=SecretStr("valid-token"),
|
||||
refresh_token=SecretStr("refresh-tok"),
|
||||
access_token_expires_at=int(time.time()) + 3600, # Still valid
|
||||
scopes=[],
|
||||
metadata={
|
||||
"mcp_token_url": "https://auth.example.com/token",
|
||||
},
|
||||
)
|
||||
|
||||
mock_store = AsyncMock()
|
||||
mock_store.get_creds_by_id = AsyncMock(return_value=valid_creds)
|
||||
|
||||
with patch(
|
||||
"backend.integrations.credentials_store.IntegrationCredentialsStore",
|
||||
return_value=mock_store,
|
||||
):
|
||||
token = await block._resolve_auth_token("cred-valid", "user-1")
|
||||
|
||||
assert token == "valid-token"
|
||||
# update_creds should NOT have been called (no refresh needed)
|
||||
mock_store.update_creds.assert_not_awaited()
|
||||
|
||||
Reference in New Issue
Block a user