From 7decc20a32f8a4ddf317a25dcfe4679d395134de Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Mon, 9 Feb 2026 17:37:24 +0400 Subject: [PATCH] fix(backend/mcp): Auto-refresh expired OAuth tokens before MCP tool calls _resolve_auth_token now checks token expiry and refreshes using MCPOAuthHandler with metadata (token_url, client_id, client_secret) stored during the OAuth callback flow. --- .../backend/backend/blocks/mcp/block.py | 53 ++++++++++- .../backend/backend/blocks/mcp/test_mcp.py | 95 +++++++++++++++++++ 2 files changed, 144 insertions(+), 4 deletions(-) diff --git a/autogpt_platform/backend/backend/blocks/mcp/block.py b/autogpt_platform/backend/backend/blocks/mcp/block.py index 2bb229e2e5..8a5856ed17 100644 --- a/autogpt_platform/backend/backend/blocks/mcp/block.py +++ b/autogpt_platform/backend/backend/blocks/mcp/block.py @@ -8,7 +8,11 @@ dropdown and the input/output schema adapts dynamically. import json import logging -from typing import Any +import time +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from backend.integrations.credentials_store import IntegrationCredentialsStore from backend.blocks.mcp.client import MCPClient, MCPClientError from backend.data.block import ( @@ -190,22 +194,63 @@ class MCPToolBlock(Block): return output_parts if output_parts else None async def _resolve_auth_token(self, credential_id: str, user_id: str) -> str | None: - """Resolve a Bearer token from a stored credential ID.""" + """Resolve a Bearer token from a stored credential ID, refreshing if needed.""" if not credential_id: return None - from backend.util.clients import get_integration_credentials_store + from backend.integrations.credentials_store import IntegrationCredentialsStore - store = get_integration_credentials_store() + store = IntegrationCredentialsStore() creds = await store.get_creds_by_id(user_id, credential_id) if not creds: logger.warning(f"Credential {credential_id} not found") return None if isinstance(creds, OAuth2Credentials): + # Refresh if token expires within 5 minutes + if ( + creds.access_token_expires_at + and creds.access_token_expires_at < int(time.time()) + 300 + ): + creds = await self._refresh_mcp_oauth(creds, user_id, store) return creds.access_token.get_secret_value() if hasattr(creds, "api_key") and creds.api_key: return creds.api_key.get_secret_value() or None return None + async def _refresh_mcp_oauth( + self, + creds: OAuth2Credentials, + user_id: str, + store: "IntegrationCredentialsStore", + ) -> OAuth2Credentials: + """Refresh MCP OAuth tokens using metadata stored during the OAuth callback.""" + from backend.blocks.mcp.oauth import MCPOAuthHandler + + metadata = creds.metadata or {} + token_url = metadata.get("mcp_token_url") + if not token_url: + logger.warning( + f"Cannot refresh MCP credential {creds.id}: no token_url in metadata" + ) + return creds + + handler = MCPOAuthHandler( + client_id=metadata.get("mcp_client_id", ""), + client_secret=metadata.get("mcp_client_secret", ""), + redirect_uri="", # Not needed for refresh + authorize_url="", # Not needed for refresh + token_url=token_url, + resource_url=metadata.get("mcp_resource_url"), + ) + + try: + fresh = await handler.refresh_tokens(creds) + await store.update_creds(user_id, fresh) + logger.info(f"Refreshed MCP OAuth credential {creds.id}") + return fresh + except Exception: + logger.exception(f"Failed to refresh MCP OAuth credential {creds.id}") + return creds + async def run( self, input_data: Input, diff --git a/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py b/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py index 0fd8a686b3..5921a5fff1 100644 --- a/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py +++ b/autogpt_platform/backend/backend/blocks/mcp/test_mcp.py @@ -607,3 +607,98 @@ class TestMCPToolBlock: assert captured_tokens == [None] assert outputs == [("result", "ok")] + + @pytest.mark.asyncio + async def test_resolve_auth_token_refreshes_expired(self): + """Verify _resolve_auth_token refreshes expired MCP OAuth tokens.""" + import time + + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + block = MCPToolBlock() + expired_creds = OAuth2Credentials( + id="cred-expired", + provider="mcp", + title="MCP: test", + access_token=SecretStr("old-token"), + refresh_token=SecretStr("refresh-tok"), + access_token_expires_at=int(time.time()) - 60, # Already expired + scopes=[], + metadata={ + "mcp_token_url": "https://auth.example.com/token", + "mcp_client_id": "client-id", + "mcp_client_secret": "client-secret", + "mcp_resource_url": "https://mcp.example.com", + }, + ) + fresh_creds = OAuth2Credentials( + id="cred-expired", + provider="mcp", + title="MCP: test", + access_token=SecretStr("fresh-token"), + refresh_token=SecretStr("refresh-tok"), + access_token_expires_at=int(time.time()) + 3600, + scopes=[], + metadata=expired_creds.metadata, + ) + + mock_store = AsyncMock() + mock_store.get_creds_by_id = AsyncMock(return_value=expired_creds) + mock_store.update_creds = AsyncMock() + + mock_handler_instance = AsyncMock() + mock_handler_instance.refresh_tokens = AsyncMock(return_value=fresh_creds) + + with ( + patch( + "backend.integrations.credentials_store.IntegrationCredentialsStore", + return_value=mock_store, + ), + patch( + "backend.blocks.mcp.oauth.MCPOAuthHandler", + return_value=mock_handler_instance, + ), + ): + token = await block._resolve_auth_token("cred-expired", "user-1") + + assert token == "fresh-token" + mock_handler_instance.refresh_tokens.assert_awaited_once_with(expired_creds) + mock_store.update_creds.assert_awaited_once_with("user-1", fresh_creds) + + @pytest.mark.asyncio + async def test_resolve_auth_token_skips_refresh_if_valid(self): + """Verify _resolve_auth_token does NOT refresh tokens that are still valid.""" + import time + + from pydantic import SecretStr + + from backend.data.model import OAuth2Credentials + + block = MCPToolBlock() + valid_creds = OAuth2Credentials( + id="cred-valid", + provider="mcp", + title="MCP: test", + access_token=SecretStr("valid-token"), + refresh_token=SecretStr("refresh-tok"), + access_token_expires_at=int(time.time()) + 3600, # Still valid + scopes=[], + metadata={ + "mcp_token_url": "https://auth.example.com/token", + }, + ) + + mock_store = AsyncMock() + mock_store.get_creds_by_id = AsyncMock(return_value=valid_creds) + + with patch( + "backend.integrations.credentials_store.IntegrationCredentialsStore", + return_value=mock_store, + ): + token = await block._resolve_auth_token("cred-valid", "user-1") + + assert token == "valid-token" + # update_creds should NOT have been called (no refresh needed) + mock_store.update_creds.assert_not_awaited()