mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
## Summary <img width="1000" alt="image" src="https://github.com/user-attachments/assets/18e8ef34-d222-453c-8b0a-1b25ef8cf806" /> <img width="250" alt="image" src="https://github.com/user-attachments/assets/ba97556c-09c5-4f76-9f4e-49a2e8e57468" /> <img width="250" alt="image" src="https://github.com/user-attachments/assets/68f7804a-fe74-442d-9849-39a229c052cf" /> <img width="250" alt="image" src="https://github.com/user-attachments/assets/700690ba-f9fe-4726-8871-3bfbab586001" /> Full-stack MCP (Model Context Protocol) tool block integration that allows users to connect to any MCP server, discover available tools, authenticate via OAuth, and execute tools — all through the standard AutoGPT credential system. ### Backend - **MCPToolBlock** (`blocks/mcp/block.py`): New block using `CredentialsMetaInput` pattern with optional credentials (`default={}`), supporting both authenticated (OAuth) and public MCP servers. Includes auto-lookup fallback for backward compatibility. - **MCP Client** (`blocks/mcp/client.py`): HTTP transport with JSON-RPC 2.0, tool discovery, tool execution with robust error handling (type-checked error fields, non-JSON response handling) - **MCP OAuth Handler** (`blocks/mcp/oauth.py`): RFC 8414 discovery, dynamic per-server OAuth with PKCE, token storage and refresh via `raise_for_status=True` - **MCP API Routes** (`api/features/mcp/routes.py`): `discover-tools`, `oauth/login`, `oauth/callback` endpoints with credential cleanup, defensive OAuth metadata validation - **Credential system integration**: - `CredentialsMetaInput` model_validator normalizes legacy `"ProviderName.MCP"` format from Python 3.13's `str(StrEnum)` change - `CredentialsFieldInfo.combine()` supports URL-based credential discrimination (each MCP server gets its own credential entry) - `aggregate_credentials_inputs` checks block schema defaults for credential optionality - Executor normalizes credential data for both Pydantic and JSON schema validation paths - Chat credential 
matching handles MCP server URL filtering - `provider_matches()` helper used consistently for Python 3.13 StrEnum compatibility - **Pre-run validation**: `_validate_graph_get_errors` now calls `get_missing_input()` for custom block-level validation (MCP tool arguments) - **Security**: HTML tag stripping loop to prevent XSS bypass, SSRF protection (removed trusted_origins) ### Frontend - **MCPToolDialog** (`MCPToolDialog.tsx`): Full tool discovery UI — enter server URL, authenticate if needed, browse tools, select tool and configure - **OAuth popup** (`oauth-popup.ts`): Shared utility supporting cross-origin MCP OAuth flows with BroadcastChannel + localStorage fallback - **Credential integration**: MCP-specific OAuth flow in `useCredentialsInput`, server URL filtering in `useCredentials`, MCP callback page - **CredentialsSelect**: Auto-selects first available credential instead of defaulting to "None", credentials listed before "None" in dropdown - **Node rendering**: Dynamic tool input schema rendering on MCP nodes, proper handling in both legacy and new flow editors - **Block title persistence**: `customized_name` set at block creation for both MCP and Agent blocks — no fallback logic needed, titles survive save/load reliably - **Stable credential ordering**: Removed `sortByUnsetFirst` that caused credential inputs to jump when selected ### Tests (~2060 lines) - Unit tests: block, client, tool execution - Integration tests: mock MCP server with auth - OAuth flow tests - API endpoint tests - Credential combining/optionality tests - E2e tests (skipped in CI, run manually) ## Key Design Decisions 1. **Optional credentials via `default={}`**: MCP servers can be public (no auth) or private (OAuth). The `credentials` field has `default={}` making it optional at the schema level, so public servers work without prompting for credentials. 2. 
**URL-based credential discrimination**: Each MCP server URL gets its own credential entry in the "Run agent" form (via `discriminator="server_url"`), so agents using multiple MCP servers prompt for each independently. 3. **Model-level normalization**: Python 3.13 changed `str(StrEnum)` to return `"ClassName.MEMBER"`. Rather than scattering fixes across the codebase, a Pydantic `model_validator(mode="before")` on `CredentialsMetaInput` handles normalization centrally, and `provider_matches()` handles lookups. 4. **Credential auto-select**: `CredentialsSelect` component defaults to the first available credential and notifies the parent state, ensuring credentials are pre-filled in the "Run agent" dialog without requiring manual selection. 5. **customized_name for block titles**: Both MCP and Agent blocks set `customized_name` in metadata at creation time. This eliminates convoluted runtime fallback logic (`agent_name`, hostname extraction) — the title is persisted once and read directly. ## Test plan - [x] Unit/integration tests pass (68 MCP + 11 graph = 79 tests) - [x] Manual: MCP block with public server (DeepWiki) — no credentials needed, tools discovered and executable - [x] Manual: MCP block with OAuth server (Linear, Sentry) — OAuth flow prompts correctly - [x] Manual: "Run agent" form shows correct credential requirements per MCP server - [x] Manual: Credential auto-selects when exactly one matches, pre-selects first when multiple exist - [x] Manual: Credential ordering stays stable when selecting/deselecting - [x] Manual: MCP block title persists after save and refresh - [x] Manual: Agent block title persists after save and refresh (via customized_name) - [ ] Manual: Shared agent with MCP block prompts new user for credentials --------- Co-authored-by: Otto <otto@agpt.co> Co-authored-by: Ubbe <hi@ubbe.dev>
273 lines
11 KiB
Python
273 lines
11 KiB
Python
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime
|
|
from typing import TYPE_CHECKING, Any, Callable, Coroutine
|
|
|
|
from autogpt_libs.utils.synchronize import AsyncRedisKeyedMutex
|
|
from redis.asyncio.lock import Lock as AsyncRedisLock
|
|
|
|
from backend.data.model import Credentials, OAuth2Credentials
|
|
from backend.data.redis_client import get_redis_async
|
|
from backend.integrations.credentials_store import (
|
|
IntegrationCredentialsStore,
|
|
provider_matches,
|
|
)
|
|
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
|
from backend.integrations.providers import ProviderName
|
|
from backend.util.exceptions import MissingConfigError
|
|
from backend.util.settings import Settings
|
|
|
|
if TYPE_CHECKING:
|
|
from backend.integrations.oauth import BaseOAuthHandler
|
|
|
|
# Module-level logger for credential lifecycle / lock diagnostics.
logger = logging.getLogger(__name__)
# Loaded once at import time; OAuth client IDs/secrets and base URLs are read
# from here by _get_provider_oauth_handler below.
settings = Settings()
|
|
|
|
|
|
class IntegrationCredentialsManager:
    """
    Handles the lifecycle of integration credentials.
    - Automatically refreshes requested credentials if needed.
    - Uses locking mechanisms to ensure system-wide consistency and
      prevent invalidation of in-use tokens.

    ### ⚠️ Gotcha
    With `acquire(..)`, credentials can only be in use in one place at a time (e.g. one
    block execution).

    ### Locking mechanism
    - Because *getting* credentials can result in a refresh (= *invalidation* +
      *replacement*) of the stored credentials, *getting* is an operation that
      potentially requires read/write access.
    - Checking whether a token has to be refreshed is subject to an additional `refresh`
      scoped lock to prevent unnecessary sequential refreshes when multiple executions
      try to access the same credentials simultaneously.
    - We MUST lock credentials while in use to prevent them from being invalidated while
      they are in use, e.g. because they are being refreshed by a different part
      of the system.
    - The `!time_sensitive` lock in `acquire(..)` is part of a two-tier locking
      mechanism in which *updating* gets priority over *getting* credentials.
      This is to prevent a long queue of waiting *get* requests from blocking essential
      credential refreshes or user-initiated updates.

    It is possible to implement a reader/writer locking system where either multiple
    readers or a single writer can have simultaneous access, but this would add a lot of
    complexity to the mechanism. I don't expect the current ("simple") mechanism to
    cause so much latency that it's worth implementing.
    """

    def __init__(self):
        # Persistent credential storage; it maintains its own lock registry as
        # well (both are drained in release_all_locks()).
        self.store = IntegrationCredentialsStore()
        # Redis-backed keyed mutex registry, created lazily in locks().
        self._locks: AsyncRedisKeyedMutex | None = None

    async def locks(self) -> AsyncRedisKeyedMutex:
        """Return the shared keyed-mutex registry, creating it on first use.

        NOTE(review): two concurrent first calls could each construct an
        AsyncRedisKeyedMutex and the later assignment wins — presumably
        harmless as both wrap the same shared Redis client; confirm.
        """
        if self._locks:
            return self._locks

        self._locks = AsyncRedisKeyedMutex(await get_redis_async())
        return self._locks

    async def create(self, user_id: str, credentials: Credentials) -> None:
        """Persist a new set of credentials for the given user."""
        return await self.store.add_creds(user_id, credentials)

    async def exists(self, user_id: str, credentials_id: str) -> bool:
        """Return whether credentials with this ID exist for the user."""
        return (await self.store.get_creds_by_id(user_id, credentials_id)) is not None

    async def get(
        self, user_id: str, credentials_id: str, lock: bool = True
    ) -> Credentials | None:
        """Fetch credentials by ID, refreshing expiring OAuth2 tokens on the way.

        Args:
            user_id: Owner of the credentials.
            credentials_id: ID of the credentials to fetch.
            lock: Passed through to `refresh_if_needed(..)`; if True, a refresh
                waits for the credentials' base (in-use) lock before replacing
                the stored tokens.

        Returns:
            The (possibly refreshed) credentials, or None if not found.
        """
        credentials = await self.store.get_creds_by_id(user_id, credentials_id)
        if not credentials:
            return None

        # Refresh OAuth credentials if needed
        if credentials.type == "oauth2" and credentials.access_token_expires_at:
            logger.debug(
                f"Credentials #{credentials.id} expire at "
                f"{datetime.fromtimestamp(credentials.access_token_expires_at)}; "
                f"current time is {datetime.now()}"
            )
            credentials = await self.refresh_if_needed(user_id, credentials, lock)
        else:
            logger.debug(f"Credentials #{credentials.id} never expire")

        return credentials

    async def acquire(
        self, user_id: str, credentials_id: str
    ) -> tuple[Credentials, AsyncRedisLock]:
        """
        ⚠️ WARNING: this locks credentials system-wide and blocks both acquiring
        and updating them elsewhere until the lock is released.
        See the class docstring for more info.
        """
        # Use a low-priority (!time_sensitive) locking queue on top of the general lock
        # to allow priority access for refreshing/updating the tokens.
        async with self._locked(user_id, credentials_id, "!time_sensitive"):
            lock = await self._acquire_lock(user_id, credentials_id)
        # lock=False: the base lock was just acquired above, so a refresh inside
        # get(..) must not try to take it again.
        credentials = await self.get(user_id, credentials_id, lock=False)
        if not credentials:
            raise ValueError(
                f"Credentials #{credentials_id} for user #{user_id} not found"
            )
        return credentials, lock

    def cached_getter(
        self, user_id: str
    ) -> Callable[[str], "Coroutine[Any, Any, Credentials | None]"]:
        """Return an async credentials getter that caches the user's full list.

        The returned coroutine function loads all of the user's credentials on
        its first invocation and serves later lookups from that in-memory list,
        still routing expiring OAuth2 credentials through refresh_if_needed(..).
        """
        all_credentials = None

        async def get_credentials(creds_id: str) -> "Credentials | None":
            nonlocal all_credentials
            if not all_credentials:
                # Fetch credentials on first necessity
                all_credentials = await self.store.get_all_creds(user_id)

            credential = next((c for c in all_credentials if c.id == creds_id), None)
            if not credential:
                return None
            if credential.type != "oauth2" or not credential.access_token_expires_at:
                # Credential doesn't expire
                return credential

            # Credential is OAuth2 credential and has expiration timestamp
            return await self.refresh_if_needed(user_id, credential)

        return get_credentials

    async def refresh_if_needed(
        self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
    ) -> OAuth2Credentials:
        """Refresh the given OAuth2 credentials if their handler deems it needed.

        Runs under a `refresh`-scoped lock so that concurrent callers don't
        trigger duplicate refreshes of the same credentials.

        Args:
            user_id: Owner of the credentials.
            credentials: The OAuth2 credentials to (maybe) refresh.
            lock: If True, additionally wait for the credentials' base (in-use)
                lock before replacing the stored tokens.

        Returns:
            The fresh credentials if a refresh happened, otherwise the input.
        """
        async with self._locked(user_id, credentials.id, "refresh"):
            # MCP handlers are reconstructed per-credential from stored metadata
            # (see create_mcp_oauth_handler); other providers use the singleton
            # handlers registered in HANDLERS_BY_NAME.
            if provider_matches(credentials.provider, ProviderName.MCP.value):
                oauth_handler = create_mcp_oauth_handler(credentials)
            else:
                oauth_handler = await _get_provider_oauth_handler(credentials.provider)
            if oauth_handler.needs_refresh(credentials):
                logger.debug(
                    f"Refreshing '{credentials.provider}' "
                    f"credentials #{credentials.id}"
                )
                _lock = None
                if lock:
                    # Wait until the credentials are no longer in use anywhere
                    _lock = await self._acquire_lock(user_id, credentials.id)

                fresh_credentials = await oauth_handler.refresh_tokens(credentials)
                await self.store.update_creds(user_id, fresh_credentials)
                # Only release a lock we still hold and own; Redis locks can
                # expire or change hands while the refresh is in flight.
                if _lock and (await _lock.locked()) and (await _lock.owned()):
                    try:
                        await _lock.release()
                    except Exception as e:
                        logger.warning(f"Failed to release OAuth refresh lock: {e}")

                credentials = fresh_credentials
        return credentials

    async def update(self, user_id: str, updated: Credentials) -> None:
        """Replace stored credentials, holding their base lock while writing."""
        async with self._locked(user_id, updated.id):
            await self.store.update_creds(user_id, updated)

    async def delete(self, user_id: str, credentials_id: str) -> None:
        """Delete stored credentials, holding their base lock while deleting."""
        async with self._locked(user_id, credentials_id):
            await self.store.delete_creds_by_id(user_id, credentials_id)

    # -- Locking utilities -- #

    async def _acquire_lock(
        self, user_id: str, credentials_id: str, *args: str
    ) -> AsyncRedisLock:
        """Acquire (and return, still held) the keyed mutex for these credentials.

        Extra `args` become additional key segments, producing the scoped locks
        used above (e.g. "refresh", "!time_sensitive").
        """
        key = (
            f"user:{user_id}",
            f"credentials:{credentials_id}",
            *args,
        )
        locks = await self.locks()
        return await locks.acquire(key)

    @asynccontextmanager
    async def _locked(self, user_id: str, credentials_id: str, *args: str):
        """Context manager: hold the (optionally scoped) credentials lock.

        The release is best-effort: the lock is only released if still held and
        owned, and release failures are logged instead of raised so the caller's
        own exception (if any) isn't masked.
        """
        lock = await self._acquire_lock(user_id, credentials_id, *args)
        try:
            yield
        finally:
            if (await lock.locked()) and (await lock.owned()):
                try:
                    await lock.release()
                except Exception as e:
                    logger.warning(f"Failed to release credentials lock: {e}")

    async def release_all_locks(self):
        """Call this on process termination to ensure all locks are released"""
        await (await self.locks()).release_all_locks()
        await (await self.store.locks()).release_all_locks()
|
|
|
|
|
|
async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandler":
    """Instantiate the OAuth handler registered for the given provider.

    Resolves the provider's client ID/secret either from application secrets
    or from environment variables, depending on how the provider is registered.

    Args:
        provider_name_str: Provider name (must be a valid `ProviderName` value).

    Returns:
        A configured `BaseOAuthHandler` for the provider.

    Raises:
        KeyError: If no handler is registered for the provider.
        MissingConfigError: If the client ID or secret is not configured.
    """
    provider_name = ProviderName(provider_name_str)
    if provider_name not in HANDLERS_BY_NAME:
        raise KeyError(f"Unknown provider '{provider_name}'")

    registry_entry = CREDENTIALS_BY_PROVIDER[provider_name]
    if registry_entry.use_secrets:
        # Client credentials live in the application secrets store.
        client_id = getattr(settings.secrets, f"{provider_name.value}_client_id")
        client_secret = getattr(
            settings.secrets, f"{provider_name.value}_client_secret"
        )
    else:
        # This is safe to do as we check that the env vars exist in the registry
        client_id = None
        if registry_entry.client_id_env_var:
            client_id = os.getenv(registry_entry.client_id_env_var)
        client_secret = None
        if registry_entry.client_secret_env_var:
            client_secret = os.getenv(registry_entry.client_secret_env_var)

    if not client_id or not client_secret:
        raise MissingConfigError(
            f"Integration with provider '{provider_name}' is not configured",
        )

    base_url = settings.config.frontend_base_url or settings.config.platform_base_url
    return HANDLERS_BY_NAME[provider_name](
        client_id=client_id,
        client_secret=client_secret,
        redirect_uri=f"{base_url}/auth/integrations/oauth_callback",
    )
|
|
|
|
|
|
def create_mcp_oauth_handler(
    credentials: OAuth2Credentials,
) -> "BaseOAuthHandler":
    """Create an MCPOAuthHandler from credential metadata for token refresh.

    MCP OAuth handlers have dynamic endpoints discovered per-server, so they
    can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler
    is reconstructed from metadata stored on the credential during initial auth.
    """
    # Imported lazily to avoid a circular import with the MCP block package.
    from backend.blocks.mcp.oauth import MCPOAuthHandler

    metadata = credentials.metadata or {}

    # The token endpoint is the only metadata that is strictly required to
    # perform a refresh; bail out early with a descriptive error if absent.
    refresh_endpoint = metadata.get("mcp_token_url", "")
    if not refresh_endpoint:
        raise ValueError(
            f"MCP credential {credentials.id} is missing 'mcp_token_url' metadata; "
            "cannot refresh tokens"
        )

    return MCPOAuthHandler(
        client_id=metadata.get("mcp_client_id", ""),
        client_secret=metadata.get("mcp_client_secret", ""),
        redirect_uri="",  # Not needed for token refresh
        authorize_url="",  # Not needed for token refresh
        token_url=refresh_endpoint,
        resource_url=metadata.get("mcp_resource_url"),
    )
|