fix(backend/mcp): Auto-refresh expired OAuth tokens before MCP tool calls

_resolve_auth_token now checks token expiry and refreshes using
MCPOAuthHandler with metadata (token_url, client_id, client_secret)
stored during the OAuth callback flow.
This commit is contained in:
Zamil Majdy
2026-02-09 17:37:24 +04:00
parent 54375065d5
commit 7decc20a32
2 changed files with 144 additions and 4 deletions

View File

@@ -8,7 +8,11 @@ dropdown and the input/output schema adapts dynamically.
import json
import logging
from typing import Any
import time
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.data.block import (
@@ -190,22 +194,63 @@ class MCPToolBlock(Block):
return output_parts if output_parts else None
async def _resolve_auth_token(self, credential_id: str, user_id: str) -> str | None:
"""Resolve a Bearer token from a stored credential ID."""
"""Resolve a Bearer token from a stored credential ID, refreshing if needed."""
if not credential_id:
return None
from backend.util.clients import get_integration_credentials_store
from backend.integrations.credentials_store import IntegrationCredentialsStore
store = get_integration_credentials_store()
store = IntegrationCredentialsStore()
creds = await store.get_creds_by_id(user_id, credential_id)
if not creds:
logger.warning(f"Credential {credential_id} not found")
return None
if isinstance(creds, OAuth2Credentials):
# Refresh if token expires within 5 minutes
if (
creds.access_token_expires_at
and creds.access_token_expires_at < int(time.time()) + 300
):
creds = await self._refresh_mcp_oauth(creds, user_id, store)
return creds.access_token.get_secret_value()
if hasattr(creds, "api_key") and creds.api_key:
return creds.api_key.get_secret_value() or None
return None
async def _refresh_mcp_oauth(
self,
creds: OAuth2Credentials,
user_id: str,
store: "IntegrationCredentialsStore",
) -> OAuth2Credentials:
"""Refresh MCP OAuth tokens using metadata stored during the OAuth callback."""
from backend.blocks.mcp.oauth import MCPOAuthHandler
metadata = creds.metadata or {}
token_url = metadata.get("mcp_token_url")
if not token_url:
logger.warning(
f"Cannot refresh MCP credential {creds.id}: no token_url in metadata"
)
return creds
handler = MCPOAuthHandler(
client_id=metadata.get("mcp_client_id", ""),
client_secret=metadata.get("mcp_client_secret", ""),
redirect_uri="", # Not needed for refresh
authorize_url="", # Not needed for refresh
token_url=token_url,
resource_url=metadata.get("mcp_resource_url"),
)
try:
fresh = await handler.refresh_tokens(creds)
await store.update_creds(user_id, fresh)
logger.info(f"Refreshed MCP OAuth credential {creds.id}")
return fresh
except Exception:
logger.exception(f"Failed to refresh MCP OAuth credential {creds.id}")
return creds
async def run(
self,
input_data: Input,

View File

@@ -607,3 +607,98 @@ class TestMCPToolBlock:
assert captured_tokens == [None]
assert outputs == [("result", "ok")]
@pytest.mark.asyncio
async def test_resolve_auth_token_refreshes_expired(self):
"""Verify _resolve_auth_token refreshes expired MCP OAuth tokens."""
import time
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
block = MCPToolBlock()
expired_creds = OAuth2Credentials(
id="cred-expired",
provider="mcp",
title="MCP: test",
access_token=SecretStr("old-token"),
refresh_token=SecretStr("refresh-tok"),
access_token_expires_at=int(time.time()) - 60, # Already expired
scopes=[],
metadata={
"mcp_token_url": "https://auth.example.com/token",
"mcp_client_id": "client-id",
"mcp_client_secret": "client-secret",
"mcp_resource_url": "https://mcp.example.com",
},
)
fresh_creds = OAuth2Credentials(
id="cred-expired",
provider="mcp",
title="MCP: test",
access_token=SecretStr("fresh-token"),
refresh_token=SecretStr("refresh-tok"),
access_token_expires_at=int(time.time()) + 3600,
scopes=[],
metadata=expired_creds.metadata,
)
mock_store = AsyncMock()
mock_store.get_creds_by_id = AsyncMock(return_value=expired_creds)
mock_store.update_creds = AsyncMock()
mock_handler_instance = AsyncMock()
mock_handler_instance.refresh_tokens = AsyncMock(return_value=fresh_creds)
with (
patch(
"backend.integrations.credentials_store.IntegrationCredentialsStore",
return_value=mock_store,
),
patch(
"backend.blocks.mcp.oauth.MCPOAuthHandler",
return_value=mock_handler_instance,
),
):
token = await block._resolve_auth_token("cred-expired", "user-1")
assert token == "fresh-token"
mock_handler_instance.refresh_tokens.assert_awaited_once_with(expired_creds)
mock_store.update_creds.assert_awaited_once_with("user-1", fresh_creds)
@pytest.mark.asyncio
async def test_resolve_auth_token_skips_refresh_if_valid(self):
"""Verify _resolve_auth_token does NOT refresh tokens that are still valid."""
import time
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
block = MCPToolBlock()
valid_creds = OAuth2Credentials(
id="cred-valid",
provider="mcp",
title="MCP: test",
access_token=SecretStr("valid-token"),
refresh_token=SecretStr("refresh-tok"),
access_token_expires_at=int(time.time()) + 3600, # Still valid
scopes=[],
metadata={
"mcp_token_url": "https://auth.example.com/token",
},
)
mock_store = AsyncMock()
mock_store.get_creds_by_id = AsyncMock(return_value=valid_creds)
with patch(
"backend.integrations.credentials_store.IntegrationCredentialsStore",
return_value=mock_store,
):
token = await block._resolve_auth_token("cred-valid", "user-1")
assert token == "valid-token"
# update_creds should NOT have been called (no refresh needed)
mock_store.update_creds.assert_not_awaited()