mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-08 13:55:06 -05:00
Compare commits
9 Commits
feat/mcp-b
...
feat/copit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0bbe8a184d | ||
|
|
7592deed63 | ||
|
|
b9c759ce4f | ||
|
|
5efb80d47b | ||
|
|
b49d8e2cba | ||
|
|
452544530d | ||
|
|
32ee7e6cf8 | ||
|
|
670663c406 | ||
|
|
0dbe4cf51e |
@@ -1,76 +0,0 @@
|
||||
# MCP Block Implementation Plan
|
||||
|
||||
## Overview
|
||||
|
||||
Create a single **MCPBlock** that dynamically integrates with any MCP (Model Context Protocol)
|
||||
server. Users provide a server URL, the block discovers available tools, presents them as a
|
||||
dropdown, and dynamically adjusts input/output schema based on the selected tool — exactly like
|
||||
`AgentExecutorBlock` handles dynamic schemas.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
User provides MCP server URL + credentials
|
||||
↓
|
||||
MCPBlock fetches tools via MCP protocol (tools/list)
|
||||
↓
|
||||
User selects tool from dropdown (stored in constantInput)
|
||||
↓
|
||||
Input schema dynamically updates based on selected tool's inputSchema
|
||||
↓
|
||||
On execution: MCPBlock calls the tool via MCP protocol (tools/call)
|
||||
↓
|
||||
Result yielded as block output
|
||||
```
|
||||
|
||||
## Design Decisions
|
||||
|
||||
1. **Single block, not many blocks** — One `MCPBlock` handles all MCP servers/tools
|
||||
2. **Dynamic schema via AgentExecutorBlock pattern** — Override `get_input_schema()`,
|
||||
`get_input_defaults()`, `get_missing_input()` on the Input class
|
||||
3. **Auth via API key or OAuth2 credentials** — Use existing `APIKeyCredentials` or
|
||||
`OAuth2Credentials` with `ProviderName.MCP` provider. API keys are sent as Bearer tokens;
|
||||
OAuth2 uses the access token.
|
||||
4. **HTTP-based MCP client** — Use `aiohttp` (already a dependency) to implement MCP Streamable
|
||||
HTTP transport directly. No need for the `mcp` Python SDK — the protocol is simple JSON-RPC
|
||||
over HTTP. Handles both JSON and SSE response formats.
|
||||
5. **No new DB tables** — Everything fits in existing `AgentBlock` + `AgentNode` tables
|
||||
|
||||
## Implementation Files
|
||||
|
||||
### New Files
|
||||
- `backend/blocks/mcp/` — MCP block package
|
||||
- `__init__.py`
|
||||
- `block.py` — MCPToolBlock implementation
|
||||
- `client.py` — MCP HTTP client (list_tools, call_tool)
|
||||
- `oauth.py` — MCP OAuth handler for dynamic endpoint discovery
|
||||
- `test_mcp.py` — Unit tests
|
||||
- `test_oauth.py` — OAuth handler tests
|
||||
- `test_integration.py` — Integration tests with local test server
|
||||
- `test_e2e.py` — E2E tests against real MCP servers
|
||||
|
||||
### Modified Files
|
||||
- `backend/integrations/providers.py` — Add `MCP = "mcp"` to ProviderName
|
||||
|
||||
## Dev Loop
|
||||
|
||||
```bash
|
||||
cd autogpt_platform/backend
|
||||
poetry run pytest backend/blocks/mcp/test_mcp.py -xvs # Unit tests
|
||||
poetry run pytest backend/blocks/mcp/test_oauth.py -xvs # OAuth tests
|
||||
poetry run pytest backend/blocks/mcp/test_integration.py -xvs # Integration tests
|
||||
poetry run pytest backend/blocks/mcp/ -xvs # All MCP tests
|
||||
```
|
||||
|
||||
## Status
|
||||
|
||||
- [x] Research & Design
|
||||
- [x] Add ProviderName.MCP
|
||||
- [x] Implement MCP client (client.py)
|
||||
- [x] Implement MCPToolBlock (block.py)
|
||||
- [x] Add OAuth2 support (oauth.py)
|
||||
- [x] Write unit tests
|
||||
- [x] Write integration tests
|
||||
- [x] Write E2E tests
|
||||
- [x] Run tests & fix issues
|
||||
- [x] Create PR
|
||||
@@ -27,12 +27,20 @@ class ChatConfig(BaseSettings):
|
||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||
|
||||
# Streaming Configuration
|
||||
# Note: When using Claude Agent SDK, context management is handled automatically
|
||||
# via the SDK's built-in compaction. This is mainly used for the fallback path.
|
||||
max_context_messages: int = Field(
|
||||
default=50, ge=1, le=200, description="Maximum context messages"
|
||||
default=100,
|
||||
ge=1,
|
||||
le=500,
|
||||
description="Max context messages (SDK handles compaction automatically)",
|
||||
)
|
||||
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Max retries for fallback path (SDK handles retries internally)",
|
||||
)
|
||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||
max_agent_schedules: int = Field(
|
||||
default=30, description="Maximum number of agent schedules"
|
||||
@@ -93,6 +101,12 @@ class ChatConfig(BaseSettings):
|
||||
description="Name of the prompt in Langfuse to fetch",
|
||||
)
|
||||
|
||||
# Claude Agent SDK Configuration
|
||||
use_claude_agent_sdk: bool = Field(
|
||||
default=True,
|
||||
description="Use Claude Agent SDK for chat completions",
|
||||
)
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
@@ -132,6 +146,17 @@ class ChatConfig(BaseSettings):
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_agent_sdk(cls, v):
|
||||
"""Get use_claude_agent_sdk from environment if not provided."""
|
||||
# Check environment variable - default to True if not set
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
@@ -273,9 +273,8 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
||||
try:
|
||||
session = ChatSession.model_validate_json(raw_session)
|
||||
logger.info(
|
||||
f"Loading session {session_id} from cache: "
|
||||
f"message_count={len(session.messages)}, "
|
||||
f"roles={[m.role for m in session.messages]}"
|
||||
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
|
||||
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
|
||||
)
|
||||
return session
|
||||
except Exception as e:
|
||||
@@ -317,11 +316,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
||||
return None
|
||||
|
||||
messages = prisma_session.Messages
|
||||
logger.info(
|
||||
f"Loading session {session_id} from DB: "
|
||||
f"has_messages={messages is not None}, "
|
||||
f"message_count={len(messages) if messages else 0}, "
|
||||
f"roles={[m.role for m in messages] if messages else []}"
|
||||
logger.debug(
|
||||
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
|
||||
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
|
||||
)
|
||||
|
||||
return ChatSession.from_db(prisma_session, messages)
|
||||
@@ -372,10 +369,9 @@ async def _save_session_to_db(
|
||||
"function_call": msg.function_call,
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
logger.debug(
|
||||
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
|
||||
f"roles={[m['role'] for m in messages_data]}"
|
||||
)
|
||||
await chat_db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
@@ -415,7 +411,7 @@ async def get_chat_session(
|
||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||
|
||||
# Fall back to database
|
||||
logger.info(f"Session {session_id} not in cache, checking database")
|
||||
logger.debug(f"Session {session_id} not in cache, checking database")
|
||||
session = await _get_session_from_db(session_id)
|
||||
|
||||
if session is None:
|
||||
@@ -432,7 +428,6 @@ async def get_chat_session(
|
||||
# Cache the session from DB
|
||||
try:
|
||||
await _cache_session(session)
|
||||
logger.info(f"Cached session {session_id} from database")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||
|
||||
@@ -603,13 +598,19 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
||||
logger.warning(f"Session {session_id} not found for title update")
|
||||
return False
|
||||
|
||||
# Invalidate cache so next fetch gets updated title
|
||||
# Update title in cache if it exists (instead of invalidating).
|
||||
# This prevents race conditions where cache invalidation causes
|
||||
# the frontend to see stale DB data while streaming is still in progress.
|
||||
try:
|
||||
redis_key = _get_session_cache_key(session_id)
|
||||
async_redis = await get_redis_async()
|
||||
await async_redis.delete(redis_key)
|
||||
cached = await _get_session_from_cache(session_id)
|
||||
if cached:
|
||||
cached.title = title
|
||||
await _cache_session(cached)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||
# Not critical - title will be correct on next full cache refresh
|
||||
logger.warning(
|
||||
f"Failed to update title in cache for session {session_id}: {e}"
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -16,8 +17,17 @@ from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
create_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
||||
from .sdk import service as sdk_service
|
||||
from .tracking import track_user_message
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -209,6 +219,10 @@ async def get_session(
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
logger.info(
|
||||
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
|
||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
||||
)
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
@@ -265,10 +279,30 @@ async def stream_chat_post(
|
||||
containing the task_id for reconnection.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
# Add user message to session BEFORE creating task to avoid race condition
|
||||
# where GET_SESSION sees the task as "running" but the message isn't saved yet
|
||||
if request.message:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="user" if request.is_user_message else "assistant",
|
||||
content=request.message,
|
||||
)
|
||||
)
|
||||
if request.is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
message_length=len(request.message),
|
||||
)
|
||||
logger.info(
|
||||
f"[STREAM] Saving user message to session {session_id}, "
|
||||
f"msg_count={len(session.messages)}"
|
||||
)
|
||||
session = await upsert_chat_session(session)
|
||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
@@ -283,24 +317,38 @@ async def stream_chat_post(
|
||||
|
||||
# Background task that runs the AI generation independently of SSE connection
|
||||
async def run_ai_generation():
|
||||
chunk_count = 0
|
||||
try:
|
||||
# Emit a start event with task_id for reconnection
|
||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
# Choose service based on configuration
|
||||
use_sdk = config.use_claude_agent_sdk
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else chat_service.stream_chat_completion
|
||||
)
|
||||
# Pass message=None since we already added it to the session above
|
||||
async for chunk in stream_fn(
|
||||
session_id,
|
||||
request.message,
|
||||
None, # Message already in session
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
session=session, # Pass session with message already added
|
||||
context=request.context,
|
||||
):
|
||||
chunk_count += 1
|
||||
# Write to Redis (subscribers will receive via XREAD)
|
||||
await stream_registry.publish_chunk(task_id, chunk)
|
||||
|
||||
# Mark task as completed
|
||||
await stream_registry.mark_task_completed(task_id, "completed")
|
||||
logger.info(
|
||||
f"[BG_TASK] AI generation completed for session {session_id}: {chunk_count} chunks, marking task {task_id} as completed"
|
||||
)
|
||||
# Mark task as completed (also publishes StreamFinish)
|
||||
completed = await stream_registry.mark_task_completed(task_id, "completed")
|
||||
logger.info(f"[BG_TASK] mark_task_completed returned: {completed}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in background AI generation for session {session_id}: {e}"
|
||||
@@ -315,7 +363,7 @@ async def stream_chat_post(
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
subscriber_queue = None
|
||||
try:
|
||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||
# Subscribe to the task stream (replays + live updates)
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
@@ -323,6 +371,7 @@ async def stream_chat_post(
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
logger.warning(f"Failed to subscribe to task {task_id}")
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
@@ -341,11 +390,11 @@ async def stream_chat_post(
|
||||
yield StreamHeartbeat().to_sse()
|
||||
|
||||
except GeneratorExit:
|
||||
pass # Client disconnected - background task continues
|
||||
pass # Client disconnected - normal behavior
|
||||
except Exception as e:
|
||||
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
@@ -400,35 +449,21 @@ async def stream_chat_get(
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
# Choose service based on configuration
|
||||
use_sdk = config.use_claude_agent_sdk
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else chat_service.stream_chat_completion
|
||||
)
|
||||
async for chunk in stream_fn(
|
||||
session_id,
|
||||
message,
|
||||
is_user_message=is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
@@ -550,8 +585,6 @@ async def stream_task(
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import asyncio
|
||||
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
"""Claude Agent SDK integration for CoPilot.
|
||||
|
||||
This module provides the integration layer between the Claude Agent SDK
|
||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
||||
"""
|
||||
|
||||
from .service import stream_chat_completion_sdk
|
||||
from .tool_adapter import create_copilot_mcp_server
|
||||
|
||||
__all__ = [
|
||||
"stream_chat_completion_sdk",
|
||||
"create_copilot_mcp_server",
|
||||
]
|
||||
@@ -0,0 +1,348 @@
|
||||
"""Anthropic SDK fallback implementation.
|
||||
|
||||
This module provides the fallback streaming implementation using the Anthropic SDK
|
||||
directly when the Claude Agent SDK is not available.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
|
||||
from ..model import ChatMessage, ChatSession
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
from .tool_adapter import get_tool_definitions, get_tool_handlers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def stream_with_anthropic(
|
||||
session: ChatSession,
|
||||
system_prompt: str,
|
||||
text_block_id: str,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream using Anthropic SDK directly with tool calling support.
|
||||
|
||||
This function accumulates messages into the session for persistence.
|
||||
The caller should NOT yield an additional StreamFinish - this function handles it.
|
||||
"""
|
||||
import anthropic
|
||||
|
||||
# Only use ANTHROPIC_API_KEY - don't fall back to OpenRouter keys
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
yield StreamError(
|
||||
errorText="ANTHROPIC_API_KEY not configured for fallback",
|
||||
code="config_error",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
client = anthropic.AsyncAnthropic(api_key=api_key)
|
||||
tool_definitions = get_tool_definitions()
|
||||
tool_handlers = get_tool_handlers()
|
||||
|
||||
anthropic_tools = [
|
||||
{
|
||||
"name": t["name"],
|
||||
"description": t["description"],
|
||||
"input_schema": t["inputSchema"],
|
||||
}
|
||||
for t in tool_definitions
|
||||
]
|
||||
|
||||
anthropic_messages = _convert_session_to_anthropic(session)
|
||||
|
||||
if not anthropic_messages or anthropic_messages[-1]["role"] != "user":
|
||||
anthropic_messages.append(
|
||||
{"role": "user", "content": "Continue with the task."}
|
||||
)
|
||||
|
||||
has_started_text = False
|
||||
max_iterations = 10
|
||||
accumulated_text = ""
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
|
||||
for _ in range(max_iterations):
|
||||
try:
|
||||
async with client.messages.stream(
|
||||
model="claude-sonnet-4-20250514",
|
||||
max_tokens=4096,
|
||||
system=system_prompt,
|
||||
messages=cast(Any, anthropic_messages),
|
||||
tools=cast(Any, anthropic_tools) if anthropic_tools else [],
|
||||
) as stream:
|
||||
async for event in stream:
|
||||
if event.type == "content_block_start":
|
||||
block = event.content_block
|
||||
if hasattr(block, "type"):
|
||||
if block.type == "text" and not has_started_text:
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
has_started_text = True
|
||||
elif block.type == "tool_use":
|
||||
yield StreamToolInputStart(
|
||||
toolCallId=block.id, toolName=block.name
|
||||
)
|
||||
|
||||
elif event.type == "content_block_delta":
|
||||
delta = event.delta
|
||||
if hasattr(delta, "type") and delta.type == "text_delta":
|
||||
accumulated_text += delta.text
|
||||
yield StreamTextDelta(id=text_block_id, delta=delta.text)
|
||||
|
||||
final_message = await stream.get_final_message()
|
||||
|
||||
if final_message.stop_reason == "tool_use":
|
||||
if has_started_text:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
has_started_text = False
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
tool_results = []
|
||||
assistant_content: list[dict[str, Any]] = []
|
||||
|
||||
for block in final_message.content:
|
||||
if block.type == "text":
|
||||
assistant_content.append(
|
||||
{"type": "text", "text": block.text}
|
||||
)
|
||||
elif block.type == "tool_use":
|
||||
assistant_content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"input": block.input,
|
||||
}
|
||||
)
|
||||
|
||||
# Track tool call for session persistence
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": block.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.name,
|
||||
"arguments": json.dumps(
|
||||
block.input
|
||||
if isinstance(block.input, dict)
|
||||
else {}
|
||||
),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
yield StreamToolInputAvailable(
|
||||
toolCallId=block.id,
|
||||
toolName=block.name,
|
||||
input=(
|
||||
block.input if isinstance(block.input, dict) else {}
|
||||
),
|
||||
)
|
||||
|
||||
output, is_error = await _execute_tool(
|
||||
block.name, block.input, tool_handlers
|
||||
)
|
||||
|
||||
yield StreamToolOutputAvailable(
|
||||
toolCallId=block.id,
|
||||
toolName=block.name,
|
||||
output=output,
|
||||
success=not is_error,
|
||||
)
|
||||
|
||||
# Save tool result to session
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=output,
|
||||
tool_call_id=block.id,
|
||||
)
|
||||
)
|
||||
|
||||
tool_results.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": block.id,
|
||||
"content": output,
|
||||
"is_error": is_error,
|
||||
}
|
||||
)
|
||||
|
||||
# Save assistant message with tool calls to session
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=accumulated_text or None,
|
||||
tool_calls=(
|
||||
accumulated_tool_calls
|
||||
if accumulated_tool_calls
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
# Reset for next iteration
|
||||
accumulated_text = ""
|
||||
accumulated_tool_calls = []
|
||||
|
||||
anthropic_messages.append(
|
||||
{"role": "assistant", "content": assistant_content}
|
||||
)
|
||||
anthropic_messages.append({"role": "user", "content": tool_results})
|
||||
continue
|
||||
|
||||
else:
|
||||
if has_started_text:
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
|
||||
# Save final assistant response to session
|
||||
if accumulated_text:
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=accumulated_text)
|
||||
)
|
||||
|
||||
yield StreamUsage(
|
||||
promptTokens=final_message.usage.input_tokens,
|
||||
completionTokens=final_message.usage.output_tokens,
|
||||
totalTokens=final_message.usage.input_tokens
|
||||
+ final_message.usage.output_tokens,
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True)
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="anthropic_error",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
yield StreamError(errorText="Max tool iterations reached", code="max_iterations")
|
||||
yield StreamFinish()
|
||||
|
||||
|
||||
def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
|
||||
"""Convert session messages to Anthropic format.
|
||||
|
||||
Handles merging consecutive same-role messages (Anthropic requires alternating roles).
|
||||
"""
|
||||
messages: list[dict[str, Any]] = []
|
||||
|
||||
for msg in session.messages:
|
||||
if msg.role == "user":
|
||||
new_msg = {"role": "user", "content": msg.content or ""}
|
||||
elif msg.role == "assistant":
|
||||
content: list[dict[str, Any]] = []
|
||||
if msg.content:
|
||||
content.append({"type": "text", "text": msg.content})
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", {})
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tc.get("id", str(uuid.uuid4())),
|
||||
"name": func.get("name", ""),
|
||||
"input": args,
|
||||
}
|
||||
)
|
||||
if content:
|
||||
new_msg = {"role": "assistant", "content": content}
|
||||
else:
|
||||
continue # Skip empty assistant messages
|
||||
elif msg.role == "tool":
|
||||
new_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.tool_call_id or "",
|
||||
"content": msg.content or "",
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
continue
|
||||
|
||||
messages.append(new_msg)
|
||||
|
||||
# Merge consecutive same-role messages (Anthropic requires alternating roles)
|
||||
return _merge_consecutive_roles(messages)
|
||||
|
||||
|
||||
def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Merge consecutive messages with the same role.
|
||||
|
||||
Anthropic API requires alternating user/assistant roles.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
merged: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
if merged and merged[-1]["role"] == msg["role"]:
|
||||
# Merge with previous message
|
||||
prev_content = merged[-1]["content"]
|
||||
new_content = msg["content"]
|
||||
|
||||
# Normalize both to list-of-blocks form
|
||||
if isinstance(prev_content, str):
|
||||
prev_content = [{"type": "text", "text": prev_content}]
|
||||
if isinstance(new_content, str):
|
||||
new_content = [{"type": "text", "text": new_content}]
|
||||
|
||||
# Ensure both are lists
|
||||
if not isinstance(prev_content, list):
|
||||
prev_content = [prev_content]
|
||||
if not isinstance(new_content, list):
|
||||
new_content = [new_content]
|
||||
|
||||
merged[-1]["content"] = prev_content + new_content
|
||||
else:
|
||||
merged.append(msg)
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
async def _execute_tool(
|
||||
tool_name: str, tool_input: Any, handlers: dict[str, Any]
|
||||
) -> tuple[str, bool]:
|
||||
"""Execute a tool and return (output, is_error)."""
|
||||
handler = handlers.get(tool_name)
|
||||
if not handler:
|
||||
return f"Unknown tool: {tool_name}", True
|
||||
|
||||
try:
|
||||
result = await handler(tool_input)
|
||||
# Safely extract output - handle empty or missing content
|
||||
content = result.get("content") or []
|
||||
if content and isinstance(content, list) and len(content) > 0:
|
||||
first_item = content[0]
|
||||
output = first_item.get("text", "") if isinstance(first_item, dict) else ""
|
||||
else:
|
||||
output = ""
|
||||
is_error = result.get("isError", False)
|
||||
return output, is_error
|
||||
except Exception as e:
|
||||
return f"Error: {str(e)}", True
|
||||
@@ -0,0 +1,311 @@
|
||||
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||
|
||||
This module provides the adapter layer that converts streaming messages from
|
||||
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
|
||||
the frontend expects.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from backend.api.features.chat.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SDKResponseAdapter:
|
||||
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
||||
|
||||
This class maintains state during a streaming session to properly track
|
||||
text blocks, tool calls, and message lifecycle.
|
||||
"""
|
||||
|
||||
def __init__(self, message_id: str | None = None):
|
||||
"""Initialize the adapter.
|
||||
|
||||
Args:
|
||||
message_id: Optional message ID. If not provided, one will be generated.
|
||||
"""
|
||||
self.message_id = message_id or str(uuid.uuid4())
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_started_text = False
|
||||
self.has_ended_text = False
|
||||
self.current_tool_calls: dict[str, dict[str, Any]] = {}
|
||||
self.task_id: str | None = None
|
||||
|
||||
def set_task_id(self, task_id: str) -> None:
|
||||
"""Set the task ID for reconnection support."""
|
||||
self.task_id = task_id
|
||||
|
||||
def convert_message(self, sdk_message: Any) -> list[StreamBaseResponse]:
|
||||
"""Convert a single SDK message to Vercel AI SDK format.
|
||||
|
||||
Args:
|
||||
sdk_message: A message from the Claude Agent SDK.
|
||||
|
||||
Returns:
|
||||
List of StreamBaseResponse objects (may be empty or multiple).
|
||||
"""
|
||||
responses: list[StreamBaseResponse] = []
|
||||
|
||||
# Handle different SDK message types - use class name since SDK uses dataclasses
|
||||
class_name = type(sdk_message).__name__
|
||||
msg_subtype = getattr(sdk_message, "subtype", None)
|
||||
|
||||
if class_name == "SystemMessage":
|
||||
if msg_subtype == "init":
|
||||
# Session initialization - emit start
|
||||
responses.append(
|
||||
StreamStart(
|
||||
messageId=self.message_id,
|
||||
taskId=self.task_id,
|
||||
)
|
||||
)
|
||||
|
||||
elif class_name == "AssistantMessage":
|
||||
# Assistant message with content blocks
|
||||
content = getattr(sdk_message, "content", [])
|
||||
for block in content:
|
||||
# Check block type by class name (SDK uses dataclasses) or dict type
|
||||
block_class = type(block).__name__
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
|
||||
if block_class == "TextBlock" or block_type == "text":
|
||||
# Text content
|
||||
text = getattr(block, "text", None) or (
|
||||
block.get("text") if isinstance(block, dict) else ""
|
||||
)
|
||||
|
||||
if text:
|
||||
# Start text block if needed (or restart after tool calls)
|
||||
if not self.has_started_text or self.has_ended_text:
|
||||
# Generate new text block ID for text after tools
|
||||
if self.has_ended_text:
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_ended_text = False
|
||||
responses.append(StreamTextStart(id=self.text_block_id))
|
||||
self.has_started_text = True
|
||||
|
||||
# Emit text delta
|
||||
responses.append(
|
||||
StreamTextDelta(
|
||||
id=self.text_block_id,
|
||||
delta=text,
|
||||
)
|
||||
)
|
||||
|
||||
elif block_class == "ToolUseBlock" or block_type == "tool_use":
|
||||
# Tool call
|
||||
tool_id_raw = getattr(block, "id", None) or (
|
||||
block.get("id") if isinstance(block, dict) else None
|
||||
)
|
||||
tool_id: str = (
|
||||
str(tool_id_raw) if tool_id_raw else str(uuid.uuid4())
|
||||
)
|
||||
|
||||
tool_name_raw = getattr(block, "name", None) or (
|
||||
block.get("name") if isinstance(block, dict) else None
|
||||
)
|
||||
tool_name: str = str(tool_name_raw) if tool_name_raw else "unknown"
|
||||
|
||||
tool_input = getattr(block, "input", None) or (
|
||||
block.get("input") if isinstance(block, dict) else {}
|
||||
)
|
||||
|
||||
# End text block if we were streaming text
|
||||
if self.has_started_text and not self.has_ended_text:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
# Emit tool input start
|
||||
responses.append(
|
||||
StreamToolInputStart(
|
||||
toolCallId=tool_id,
|
||||
toolName=tool_name,
|
||||
)
|
||||
)
|
||||
|
||||
# Emit tool input available with full input
|
||||
responses.append(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=tool_id,
|
||||
toolName=tool_name,
|
||||
input=tool_input if isinstance(tool_input, dict) else {},
|
||||
)
|
||||
)
|
||||
|
||||
# Track the tool call
|
||||
self.current_tool_calls[tool_id] = {
|
||||
"name": tool_name,
|
||||
"input": tool_input,
|
||||
}
|
||||
|
||||
elif class_name in ("ToolResultMessage", "UserMessage"):
|
||||
# Tool result - check for tool_result content
|
||||
content = getattr(sdk_message, "content", [])
|
||||
|
||||
for block in content:
|
||||
block_class = type(block).__name__
|
||||
block_type = block.get("type") if isinstance(block, dict) else None
|
||||
|
||||
if block_class == "ToolResultBlock" or block_type == "tool_result":
|
||||
tool_use_id = getattr(block, "tool_use_id", None) or (
|
||||
block.get("tool_use_id") if isinstance(block, dict) else None
|
||||
)
|
||||
result_content = getattr(block, "content", None) or (
|
||||
block.get("content") if isinstance(block, dict) else ""
|
||||
)
|
||||
is_error = getattr(block, "is_error", False) or (
|
||||
block.get("is_error", False)
|
||||
if isinstance(block, dict)
|
||||
else False
|
||||
)
|
||||
|
||||
if tool_use_id:
|
||||
tool_info = self.current_tool_calls.get(tool_use_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
# Format the output
|
||||
if isinstance(result_content, list):
|
||||
# Extract text from content blocks
|
||||
output_text = ""
|
||||
for item in result_content:
|
||||
if (
|
||||
isinstance(item, dict)
|
||||
and item.get("type") == "text"
|
||||
):
|
||||
output_text += item.get("text", "")
|
||||
elif hasattr(item, "text"):
|
||||
output_text += getattr(item, "text", "")
|
||||
output = output_text
|
||||
elif isinstance(result_content, str):
|
||||
output = result_content
|
||||
else:
|
||||
output = json.dumps(result_content)
|
||||
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_use_id,
|
||||
toolName=tool_name,
|
||||
output=output,
|
||||
success=not is_error,
|
||||
)
|
||||
)
|
||||
|
||||
elif class_name == "ResultMessage":
|
||||
# Final result
|
||||
if msg_subtype == "success":
|
||||
# End text block if still open
|
||||
if self.has_started_text and not self.has_ended_text:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
# Emit finish
|
||||
responses.append(StreamFinish())
|
||||
|
||||
elif msg_subtype in ("error", "error_during_execution"):
|
||||
error_msg = getattr(sdk_message, "error", "Unknown error")
|
||||
responses.append(
|
||||
StreamError(
|
||||
errorText=str(error_msg),
|
||||
code="sdk_error",
|
||||
)
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
elif class_name == "ErrorMessage":
|
||||
# Error message
|
||||
error_msg = getattr(sdk_message, "message", None) or getattr(
|
||||
sdk_message, "error", "Unknown error"
|
||||
)
|
||||
responses.append(
|
||||
StreamError(
|
||||
errorText=str(error_msg),
|
||||
code="sdk_error",
|
||||
)
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug(f"Unhandled SDK message type: {class_name}")
|
||||
|
||||
return responses
|
||||
|
||||
def create_heartbeat(self, tool_call_id: str | None = None) -> StreamHeartbeat:
|
||||
"""Create a heartbeat response."""
|
||||
return StreamHeartbeat(toolCallId=tool_call_id)
|
||||
|
||||
def create_usage(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
) -> StreamUsage:
|
||||
"""Create a usage statistics response."""
|
||||
return StreamUsage(
|
||||
promptTokens=prompt_tokens,
|
||||
completionTokens=completion_tokens,
|
||||
totalTokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
|
||||
|
||||
async def adapt_sdk_stream(
|
||||
sdk_stream: AsyncGenerator[Any, None],
|
||||
message_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Adapt a Claude Agent SDK stream to Vercel AI SDK format.
|
||||
|
||||
Args:
|
||||
sdk_stream: The async generator from the Claude Agent SDK.
|
||||
message_id: Optional message ID for the response.
|
||||
task_id: Optional task ID for reconnection support.
|
||||
|
||||
Yields:
|
||||
StreamBaseResponse objects in Vercel AI SDK format.
|
||||
"""
|
||||
adapter = SDKResponseAdapter(message_id=message_id)
|
||||
if task_id:
|
||||
adapter.set_task_id(task_id)
|
||||
|
||||
# Emit start immediately
|
||||
yield StreamStart(messageId=adapter.message_id, taskId=task_id)
|
||||
|
||||
finished = False
|
||||
try:
|
||||
async for sdk_message in sdk_stream:
|
||||
responses = adapter.convert_message(sdk_message)
|
||||
for response in responses:
|
||||
# Skip duplicate start messages
|
||||
if isinstance(response, StreamStart):
|
||||
continue
|
||||
if isinstance(response, StreamFinish):
|
||||
finished = True
|
||||
yield response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in SDK stream: {e}", exc_info=True)
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="stream_error",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
# Ensure terminal StreamFinish if SDK stream ended without one
|
||||
if not finished:
|
||||
yield StreamFinish()
|
||||
@@ -0,0 +1,278 @@
|
||||
"""Security hooks for Claude Agent SDK integration.
|
||||
|
||||
This module provides security hooks that validate tool calls before execution,
|
||||
ensuring multi-user isolation and preventing unauthorized operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that are blocked entirely (CLI/system access)
|
||||
BLOCKED_TOOLS = {
|
||||
"Bash",
|
||||
"bash",
|
||||
"shell",
|
||||
"exec",
|
||||
"terminal",
|
||||
"command",
|
||||
"Read", # Block raw file read - use workspace tools instead
|
||||
"Write", # Block raw file write - use workspace tools instead
|
||||
"Edit", # Block raw file edit - use workspace tools instead
|
||||
"Glob", # Block raw file glob - use workspace tools instead
|
||||
"Grep", # Block raw file grep - use workspace tools instead
|
||||
}
|
||||
|
||||
# Dangerous patterns in tool inputs
|
||||
DANGEROUS_PATTERNS = [
|
||||
r"sudo",
|
||||
r"rm\s+-rf",
|
||||
r"dd\s+if=",
|
||||
r"/etc/passwd",
|
||||
r"/etc/shadow",
|
||||
r"chmod\s+777",
|
||||
r"curl\s+.*\|.*sh",
|
||||
r"wget\s+.*\|.*sh",
|
||||
r"eval\s*\(",
|
||||
r"exec\s*\(",
|
||||
r"__import__",
|
||||
r"os\.system",
|
||||
r"subprocess",
|
||||
]
|
||||
|
||||
|
||||
def _validate_tool_access(tool_name: str, tool_input: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Validate that a tool call is allowed.
|
||||
|
||||
Returns:
|
||||
Empty dict to allow, or dict with hookSpecificOutput to deny
|
||||
"""
|
||||
# Block forbidden tools
|
||||
if tool_name in BLOCKED_TOOLS:
|
||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": (
|
||||
f"Tool '{tool_name}' is not available. "
|
||||
"Use the CoPilot-specific tools instead."
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
# Check for dangerous patterns in tool input
|
||||
input_str = str(tool_input)
|
||||
|
||||
for pattern in DANGEROUS_PATTERNS:
|
||||
if re.search(pattern, input_str, re.IGNORECASE):
|
||||
logger.warning(
|
||||
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
|
||||
)
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": "Input contains blocked pattern",
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def _validate_user_isolation(
|
||||
tool_name: str, tool_input: dict[str, Any], user_id: str | None
|
||||
) -> dict[str, Any]:
|
||||
"""Validate that tool calls respect user isolation."""
|
||||
# For workspace file tools, ensure path doesn't escape
|
||||
if "workspace" in tool_name.lower():
|
||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
||||
if path:
|
||||
# Check for path traversal
|
||||
if ".." in path or path.startswith("/"):
|
||||
logger.warning(
|
||||
f"Blocked path traversal attempt: {path} by user {user_id}"
|
||||
)
|
||||
return {
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": "Path traversal not allowed",
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def create_security_hooks(user_id: str | None) -> dict[str, Any]:
|
||||
"""Create the security hooks configuration for Claude Agent SDK.
|
||||
|
||||
Includes security validation and observability hooks:
|
||||
- PreToolUse: Security validation before tool execution
|
||||
- PostToolUse: Log successful tool executions
|
||||
- PostToolUseFailure: Log and handle failed tool executions
|
||||
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
||||
|
||||
Args:
|
||||
user_id: Current user ID for isolation validation
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict for ClaudeAgentOptions
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import HookMatcher
|
||||
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
||||
|
||||
async def pre_tool_use_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Combined pre-tool-use validation hook."""
|
||||
_ = context # unused but required by signature
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
||||
|
||||
# Validate basic tool access
|
||||
result = _validate_tool_access(tool_name, tool_input)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
# Validate user isolation
|
||||
result = _validate_user_isolation(tool_name, tool_input, user_id)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def post_tool_use_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log successful tool executions for observability."""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def post_tool_failure_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log failed tool executions for debugging."""
|
||||
_ = context
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
error = input_data.get("error", "Unknown error")
|
||||
logger.warning(
|
||||
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
async def pre_compact_hook(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Log when SDK triggers context compaction.
|
||||
|
||||
The SDK automatically compacts conversation history when it grows too large.
|
||||
This hook provides visibility into when compaction happens.
|
||||
"""
|
||||
_ = context, tool_use_id
|
||||
trigger = input_data.get("trigger", "auto")
|
||||
logger.info(
|
||||
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
return {
|
||||
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
|
||||
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
|
||||
"PostToolUseFailure": [
|
||||
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
|
||||
],
|
||||
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
||||
}
|
||||
except ImportError:
|
||||
# Fallback for when SDK isn't available - return empty hooks
|
||||
return {}
|
||||
|
||||
|
||||
def create_strict_security_hooks(
|
||||
user_id: str | None,
|
||||
allowed_tools: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create strict security hooks that only allow specific tools.
|
||||
|
||||
Args:
|
||||
user_id: Current user ID
|
||||
allowed_tools: List of allowed tool names (defaults to CoPilot tools)
|
||||
|
||||
Returns:
|
||||
Hooks configuration dict
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import HookMatcher
|
||||
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
||||
|
||||
from .tool_adapter import RAW_TOOL_NAMES
|
||||
|
||||
tools_list = allowed_tools if allowed_tools is not None else RAW_TOOL_NAMES
|
||||
allowed_set = set(tools_list)
|
||||
|
||||
async def strict_pre_tool_use(
|
||||
input_data: HookInput,
|
||||
tool_use_id: str | None,
|
||||
context: HookContext,
|
||||
) -> SyncHookJSONOutput:
|
||||
"""Strict validation that only allows whitelisted tools."""
|
||||
_ = context # unused but required by signature
|
||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
||||
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
||||
|
||||
# Remove MCP prefix if present
|
||||
clean_name = tool_name.removeprefix("mcp__copilot__")
|
||||
|
||||
if clean_name not in allowed_set:
|
||||
logger.warning(f"Blocked non-whitelisted tool: {tool_name}")
|
||||
return cast(
|
||||
SyncHookJSONOutput,
|
||||
{
|
||||
"hookSpecificOutput": {
|
||||
"hookEventName": "PreToolUse",
|
||||
"permissionDecision": "deny",
|
||||
"permissionDecisionReason": (
|
||||
f"Tool '{tool_name}' is not in the allowed list"
|
||||
),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Run standard validations using clean_name for consistent checks
|
||||
result = _validate_tool_access(clean_name, tool_input)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
result = _validate_user_isolation(clean_name, tool_input, user_id)
|
||||
if result:
|
||||
return cast(SyncHookJSONOutput, result)
|
||||
|
||||
logger.debug(
|
||||
f"[SDK Audit] Tool call: tool={tool_name}, "
|
||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
||||
)
|
||||
return cast(SyncHookJSONOutput, {})
|
||||
|
||||
return {
|
||||
"PreToolUse": [
|
||||
HookMatcher(matcher="*", hooks=[strict_pre_tool_use]),
|
||||
],
|
||||
}
|
||||
except ImportError:
|
||||
return {}
|
||||
@@ -0,0 +1,475 @@
|
||||
"""Claude Agent SDK service layer for CoPilot chat completions."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
|
||||
from backend.data.understanding import (
|
||||
format_understanding_for_prompt,
|
||||
get_business_understanding,
|
||||
)
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from ..config import ChatConfig
|
||||
from ..model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
get_chat_session,
|
||||
update_session_title,
|
||||
upsert_chat_session,
|
||||
)
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
)
|
||||
from ..tracking import track_user_message
|
||||
from .anthropic_fallback import stream_with_anthropic
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .security_hooks import create_security_hooks
|
||||
from .tool_adapter import (
|
||||
COPILOT_TOOL_NAMES,
|
||||
create_copilot_mcp_server,
|
||||
set_execution_context,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Set to hold background tasks to prevent garbage collection
|
||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
|
||||
|
||||
Here is everything you know about the current user from previous interactions:
|
||||
|
||||
<users_information>
|
||||
{users_information}
|
||||
</users_information>
|
||||
|
||||
## YOUR CORE MANDATE
|
||||
|
||||
You are action-oriented. Your success is measured by:
|
||||
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
|
||||
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
|
||||
- **Time Saved**: Focus on tangible efficiency gains
|
||||
- **Quality Output**: Deliver results that meet or exceed expectations
|
||||
|
||||
## YOUR WORKFLOW
|
||||
|
||||
Adapt flexibly to the conversation context. Not every interaction requires all stages:
|
||||
|
||||
1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.
|
||||
|
||||
2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.
|
||||
|
||||
3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).
|
||||
|
||||
4. **Discover or Create Agents**:
|
||||
- **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
|
||||
- Search the marketplace with `find_agent` for pre-built automations
|
||||
- Find reusable components with `find_block`
|
||||
- Create custom solutions with `create_agent` if nothing suitable exists
|
||||
- Modify existing library agents with `edit_agent`
|
||||
|
||||
5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.
|
||||
|
||||
6. **Show Results**: Display outputs using `agent_output`.
|
||||
|
||||
## BEHAVIORAL GUIDELINES
|
||||
|
||||
**Be Concise:**
|
||||
- Target 2-5 short lines maximum
|
||||
- Make every word count—no repetition or filler
|
||||
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
|
||||
- Avoid jargon (blocks, slugs, cron) unless the user asks
|
||||
|
||||
**Be Proactive:**
|
||||
- Suggest next steps before being asked
|
||||
- Anticipate needs based on conversation context and user information
|
||||
- Look for opportunities to expand scope when relevant
|
||||
- Reveal capabilities through action, not explanation
|
||||
|
||||
**Use Tools Effectively:**
|
||||
- Select the right tool for each task
|
||||
- **Always check `find_library_agent` before searching the marketplace**
|
||||
- Use `add_understanding` to capture valuable business context
|
||||
- When tool calls fail, try alternative approaches
|
||||
|
||||
## CRITICAL REMINDER
|
||||
|
||||
You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
|
||||
|
||||
|
||||
async def _build_system_prompt(
|
||||
user_id: str | None, has_conversation_history: bool = False
|
||||
) -> tuple[str, Any]:
|
||||
"""Build the system prompt with user's business understanding context.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to fetch understanding for.
|
||||
has_conversation_history: Whether there's existing conversation history.
|
||||
If True, we don't tell the model to greet/introduce (since they're
|
||||
already in a conversation).
|
||||
"""
|
||||
understanding = None
|
||||
if user_id:
|
||||
try:
|
||||
understanding = await get_business_understanding(user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch business understanding: {e}")
|
||||
|
||||
if understanding:
|
||||
context = format_understanding_for_prompt(understanding)
|
||||
elif has_conversation_history:
|
||||
# Don't tell model to greet if there's conversation history
|
||||
context = "No prior understanding saved yet. Continue the existing conversation naturally."
|
||||
else:
|
||||
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
||||
|
||||
return DEFAULT_SYSTEM_PROMPT.format(users_information=context), understanding
|
||||
|
||||
|
||||
def _format_conversation_history(session: ChatSession) -> str:
|
||||
"""Format conversation history as a prompt context.
|
||||
|
||||
The SDK handles context compaction automatically, but we apply
|
||||
max_context_messages as a safety guard to limit initial prompt size.
|
||||
"""
|
||||
if not session.messages:
|
||||
return ""
|
||||
|
||||
# Get all messages except the last user message (which will be the prompt)
|
||||
messages = session.messages[:-1] if session.messages else []
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
# Apply max_context_messages limit as a safety guard
|
||||
# (SDK handles compaction, but this prevents excessively large initial prompts)
|
||||
max_messages = config.max_context_messages
|
||||
if len(messages) > max_messages:
|
||||
messages = messages[-max_messages:]
|
||||
|
||||
history_parts = ["<conversation_history>"]
|
||||
|
||||
for msg in messages:
|
||||
if msg.role == "user":
|
||||
history_parts.append(f"User: {msg.content or ''}")
|
||||
elif msg.role == "assistant":
|
||||
# Pass full content - SDK handles compaction automatically
|
||||
history_parts.append(f"Assistant: {msg.content or ''}")
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
func = tc.get("function", {})
|
||||
history_parts.append(
|
||||
f" [Called tool: {func.get('name', 'unknown')}]"
|
||||
)
|
||||
elif msg.role == "tool":
|
||||
# Truncate large tool results to avoid blowing context window
|
||||
tool_content = msg.content or ""
|
||||
if len(tool_content) > 500:
|
||||
tool_content = tool_content[:500] + "... (truncated)"
|
||||
history_parts.append(f" [Tool result: {tool_content}]")
|
||||
|
||||
history_parts.append("</conversation_history>")
|
||||
history_parts.append("")
|
||||
history_parts.append(
|
||||
"Continue this conversation. Respond to the user's latest message:"
|
||||
)
|
||||
history_parts.append("")
|
||||
|
||||
return "\n".join(history_parts)
|
||||
|
||||
|
||||
async def _generate_session_title(
|
||||
message: str,
|
||||
user_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> str | None:
|
||||
"""Generate a concise title for a chat session."""
|
||||
from backend.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
try:
|
||||
# Build extra_body for OpenRouter tracing
|
||||
extra_body: dict[str, Any] = {
|
||||
"posthogProperties": {"environment": settings.config.app_env.value},
|
||||
}
|
||||
if user_id:
|
||||
extra_body["user"] = user_id[:128]
|
||||
extra_body["posthogDistinctId"] = user_id
|
||||
if session_id:
|
||||
extra_body["session_id"] = session_id[:128]
|
||||
|
||||
client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
response = await client.chat.completions.create(
|
||||
model=config.title_model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Generate a very short title (3-6 words) for a chat conversation based on the user's first message. Return ONLY the title, no quotes or punctuation.",
|
||||
},
|
||||
{"role": "user", "content": message[:500]},
|
||||
],
|
||||
max_tokens=20,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
title = response.choices[0].message.content
|
||||
if title:
|
||||
title = title.strip().strip("\"'")
|
||||
return title[:47] + "..." if len(title) > 50 else title
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate session title: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
tool_call_response: str | None = None, # noqa: ARG001
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
retry_count: int = 0, # noqa: ARG001
|
||||
session: ChatSession | None = None,
|
||||
context: dict[str, str] | None = None, # noqa: ARG001
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Stream chat completion using Claude Agent SDK.
|
||||
|
||||
Drop-in replacement for stream_chat_completion with improved reliability.
|
||||
"""
|
||||
|
||||
if session is None:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
|
||||
if not session:
|
||||
raise NotFoundError(
|
||||
f"Session {session_id} not found. Please create a new session first."
|
||||
)
|
||||
|
||||
if message:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="user" if is_user_message else "assistant", content=message
|
||||
)
|
||||
)
|
||||
if is_user_message:
|
||||
track_user_message(
|
||||
user_id=user_id, session_id=session_id, message_length=len(message)
|
||||
)
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions (first user message)
|
||||
if is_user_message and not session.title:
|
||||
user_messages = [m for m in session.messages if m.role == "user"]
|
||||
if len(user_messages) == 1:
|
||||
first_message = user_messages[0].content or message or ""
|
||||
if first_message:
|
||||
task = asyncio.create_task(
|
||||
_update_title_async(session_id, first_message, user_id)
|
||||
)
|
||||
# Store reference to prevent garbage collection
|
||||
_background_tasks.add(task)
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
# Check if there's conversation history (more than just the current message)
|
||||
has_history = len(session.messages) > 1
|
||||
system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=has_history
|
||||
)
|
||||
set_execution_context(user_id, session, None)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
text_block_id = str(uuid.uuid4())
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
yield StreamStart(messageId=message_id, taskId=task_id)
|
||||
|
||||
# Track whether the stream completed normally via ResultMessage
|
||||
stream_completed = False
|
||||
|
||||
try:
|
||||
try:
|
||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||
|
||||
# Create MCP server with CoPilot tools
|
||||
mcp_server = create_copilot_mcp_server()
|
||||
|
||||
options = ClaudeAgentOptions(
|
||||
system_prompt=system_prompt,
|
||||
mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type]
|
||||
allowed_tools=COPILOT_TOOL_NAMES,
|
||||
hooks=create_security_hooks(user_id), # type: ignore[arg-type]
|
||||
continue_conversation=True, # Enable conversation continuation
|
||||
)
|
||||
|
||||
adapter = SDKResponseAdapter(message_id=message_id)
|
||||
adapter.set_task_id(task_id)
|
||||
|
||||
async with ClaudeSDKClient(options=options) as client:
|
||||
# Build prompt with conversation history for context
|
||||
# The SDK doesn't support replaying full conversation history,
|
||||
# so we include it as context in the prompt
|
||||
current_message = message or ""
|
||||
if not current_message and session.messages:
|
||||
last_user = [m for m in session.messages if m.role == "user"]
|
||||
if last_user:
|
||||
current_message = last_user[-1].content or ""
|
||||
|
||||
# Include conversation history if there are prior messages
|
||||
if len(session.messages) > 1:
|
||||
history_context = _format_conversation_history(session)
|
||||
prompt = f"{history_context}{current_message}"
|
||||
else:
|
||||
prompt = current_message
|
||||
|
||||
# Guard against empty prompts
|
||||
if not prompt.strip():
|
||||
yield StreamError(
|
||||
errorText="Message cannot be empty.",
|
||||
code="empty_prompt",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
await client.query(prompt, session_id=session_id)
|
||||
|
||||
# Track assistant response to save to session
|
||||
# We may need multiple assistant messages if text comes after tool results
|
||||
assistant_response = ChatMessage(role="assistant", content="")
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
has_appended_assistant = False
|
||||
has_tool_results = False # Track if we've received tool results
|
||||
|
||||
# Receive messages from the SDK
|
||||
async for sdk_msg in client.receive_messages():
|
||||
for response in adapter.convert_message(sdk_msg):
|
||||
if isinstance(response, StreamStart):
|
||||
continue
|
||||
yield response
|
||||
|
||||
# Accumulate text deltas into assistant response
|
||||
if isinstance(response, StreamTextDelta):
|
||||
delta = response.delta or ""
|
||||
# After tool results, create new assistant message for post-tool text
|
||||
if has_tool_results and has_appended_assistant:
|
||||
assistant_response = ChatMessage(
|
||||
role="assistant", content=delta
|
||||
)
|
||||
accumulated_tool_calls = [] # Reset for new message
|
||||
session.messages.append(assistant_response)
|
||||
has_tool_results = False
|
||||
else:
|
||||
assistant_response.content = (
|
||||
assistant_response.content or ""
|
||||
) + delta
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
# Track tool calls on the assistant message
|
||||
elif isinstance(response, StreamToolInputAvailable):
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": response.toolCallId,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": response.toolName,
|
||||
"arguments": json.dumps(response.input or {}),
|
||||
},
|
||||
}
|
||||
)
|
||||
# Update assistant message with tool calls
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
# Append assistant message if not already (tool-only response)
|
||||
if not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
has_appended_assistant = True
|
||||
|
||||
elif isinstance(response, StreamToolOutputAvailable):
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=(
|
||||
response.output
|
||||
if isinstance(response.output, str)
|
||||
else str(response.output)
|
||||
),
|
||||
tool_call_id=response.toolCallId,
|
||||
)
|
||||
)
|
||||
has_tool_results = True
|
||||
|
||||
elif isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
|
||||
# Break out of the message loop if we received finish signal
|
||||
if stream_completed:
|
||||
break
|
||||
|
||||
# Ensure assistant response is saved even if no text deltas
|
||||
# (e.g., only tool calls were made)
|
||||
if (
|
||||
assistant_response.content or assistant_response.tool_calls
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"[SDK] claude-agent-sdk not available, using Anthropic fallback"
|
||||
)
|
||||
async for response in stream_with_anthropic(
|
||||
session, system_prompt, text_block_id
|
||||
):
|
||||
if isinstance(response, StreamFinish):
|
||||
stream_completed = True
|
||||
yield response
|
||||
|
||||
# Save the session with accumulated messages
|
||||
await upsert_chat_session(session)
|
||||
logger.debug(
|
||||
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
|
||||
)
|
||||
# Yield StreamFinish to signal completion to the caller (routes.py)
|
||||
# Only if one hasn't already been yielded by the stream
|
||||
if not stream_completed:
|
||||
yield StreamFinish()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SDK] Error: {e}", exc_info=True)
|
||||
# Save session even on error to preserve any partial response
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception as save_err:
|
||||
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
||||
# Sanitize error message to avoid exposing internal details
|
||||
yield StreamError(
|
||||
errorText="An error occurred. Please try again.",
|
||||
code="sdk_error",
|
||||
)
|
||||
yield StreamFinish()
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None = None
|
||||
) -> None:
|
||||
"""Background task to update session title."""
|
||||
try:
|
||||
title = await _generate_session_title(
|
||||
message, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if title:
|
||||
await update_session_title(session_id, title)
|
||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[SDK] Failed to update session title: {e}")
|
||||
@@ -0,0 +1,217 @@
|
||||
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
|
||||
|
||||
This module provides the adapter layer that converts existing BaseTool implementations
|
||||
into in-process MCP tools that can be used with the Claude Agent SDK.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tools import TOOL_REGISTRY
|
||||
from backend.api.features.chat.tools.base import BaseTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variables to pass user/session info to tool execution
|
||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
||||
"current_session", default=None
|
||||
)
|
||||
_current_tool_call_id: ContextVar[str | None] = ContextVar(
|
||||
"current_tool_call_id", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_execution_context(
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
tool_call_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set the execution context for tool calls.
|
||||
|
||||
This must be called before streaming begins to ensure tools have access
|
||||
to user_id and session information.
|
||||
"""
|
||||
_current_user_id.set(user_id)
|
||||
_current_session.set(session)
|
||||
_current_tool_call_id.set(tool_call_id)
|
||||
|
||||
|
||||
def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]:
|
||||
"""Get the current execution context."""
|
||||
return (
|
||||
_current_user_id.get(),
|
||||
_current_session.get(),
|
||||
_current_tool_call_id.get(),
|
||||
)
|
||||
|
||||
|
||||
def create_tool_handler(base_tool: BaseTool):
|
||||
"""Create an async handler function for a BaseTool.
|
||||
|
||||
This wraps the existing BaseTool._execute method to be compatible
|
||||
with the Claude Agent SDK MCP tool format.
|
||||
"""
|
||||
|
||||
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the wrapped tool and return MCP-formatted response."""
|
||||
user_id, session, tool_call_id = get_execution_context()
|
||||
|
||||
if session is None:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(
|
||||
{
|
||||
"error": "No session context available",
|
||||
"type": "error",
|
||||
}
|
||||
),
|
||||
}
|
||||
],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
try:
|
||||
# Call the existing tool's execute method
|
||||
# Generate unique tool_call_id per invocation for proper correlation
|
||||
effective_id = tool_call_id or f"sdk-{uuid.uuid4().hex[:12]}"
|
||||
result = await base_tool.execute(
|
||||
user_id=user_id,
|
||||
session=session,
|
||||
tool_call_id=effective_id,
|
||||
**args,
|
||||
)
|
||||
|
||||
# The result is a StreamToolOutputAvailable, extract the output
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
result.output
|
||||
if isinstance(result.output, str)
|
||||
else json.dumps(result.output)
|
||||
),
|
||||
}
|
||||
],
|
||||
"isError": not result.success,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(
|
||||
{
|
||||
"error": str(e),
|
||||
"type": "error",
|
||||
"message": f"Failed to execute {base_tool.name}",
|
||||
}
|
||||
),
|
||||
}
|
||||
],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
return tool_handler
|
||||
|
||||
|
||||
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
||||
"""Build a JSON Schema input schema for a tool."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": base_tool.parameters.get("properties", {}),
|
||||
"required": base_tool.parameters.get("required", []),
|
||||
}
|
||||
|
||||
|
||||
def get_tool_definitions() -> list[dict[str, Any]]:
|
||||
"""Get all tool definitions in MCP format.
|
||||
|
||||
Returns a list of tool definitions that can be used with
|
||||
create_sdk_mcp_server or as raw tool definitions.
|
||||
"""
|
||||
tool_definitions = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
tool_def = {
|
||||
"name": tool_name,
|
||||
"description": base_tool.description,
|
||||
"inputSchema": _build_input_schema(base_tool),
|
||||
}
|
||||
tool_definitions.append(tool_def)
|
||||
|
||||
return tool_definitions
|
||||
|
||||
|
||||
def get_tool_handlers() -> dict[str, Any]:
|
||||
"""Get all tool handlers mapped by name.
|
||||
|
||||
Returns a dictionary mapping tool names to their handler functions.
|
||||
"""
|
||||
handlers = {}
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
handlers[tool_name] = create_tool_handler(base_tool)
|
||||
|
||||
return handlers
|
||||
|
||||
|
||||
# Create the MCP server configuration
|
||||
def create_copilot_mcp_server():
|
||||
"""Create an in-process MCP server configuration for CoPilot tools.
|
||||
|
||||
This can be passed to ClaudeAgentOptions.mcp_servers.
|
||||
|
||||
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
|
||||
package being available. This function returns the configuration that
|
||||
can be used with the SDK.
|
||||
"""
|
||||
try:
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
# Create decorated tool functions
|
||||
sdk_tools = []
|
||||
|
||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
||||
# Get the handler
|
||||
handler = create_tool_handler(base_tool)
|
||||
|
||||
# Create the decorated tool
|
||||
# The @tool decorator expects (name, description, schema)
|
||||
# Pass full JSON schema with type, properties, and required
|
||||
decorated = tool(
|
||||
tool_name,
|
||||
base_tool.description,
|
||||
_build_input_schema(base_tool),
|
||||
)(handler)
|
||||
|
||||
sdk_tools.append(decorated)
|
||||
|
||||
# Create the MCP server
|
||||
server = create_sdk_mcp_server(
|
||||
name="copilot",
|
||||
version="1.0.0",
|
||||
tools=sdk_tools,
|
||||
)
|
||||
|
||||
return server
|
||||
|
||||
except ImportError:
|
||||
# Let ImportError propagate so service.py handles the fallback
|
||||
raise
|
||||
|
||||
|
||||
# List of tool names for allowed_tools configuration
|
||||
COPILOT_TOOL_NAMES = [f"mcp__copilot__{name}" for name in TOOL_REGISTRY.keys()]
|
||||
|
||||
# Also export the raw tool names for flexibility
|
||||
RAW_TOOL_NAMES = list(TOOL_REGISTRY.keys())
|
||||
@@ -555,6 +555,10 @@ async def get_active_task_for_session(
|
||||
if task_user_id and user_id != task_user_id:
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
|
||||
)
|
||||
|
||||
# Get the last message ID from Redis Stream
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
last_id = "0-0"
|
||||
|
||||
@@ -1,265 +0,0 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) Tool Block.
|
||||
|
||||
A single dynamic block that can connect to any MCP server, discover available tools,
|
||||
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
|
||||
dropdown and the input/output schema adapts dynamically.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockInput,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
BlockType,
|
||||
)
|
||||
from backend.data.model import (
|
||||
APIKeyCredentials,
|
||||
CredentialsField,
|
||||
CredentialsMetaInput,
|
||||
OAuth2Credentials,
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.json import validate_with_jsonschema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MCPCredentials = APIKeyCredentials | OAuth2Credentials
|
||||
MCPCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.MCP], Literal["api_key", "oauth2"]
|
||||
]
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="mcp",
|
||||
api_key=SecretStr("test-mcp-token"),
|
||||
title="Mock MCP Credentials",
|
||||
)
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
|
||||
class MCPToolBlock(Block):
|
||||
"""
|
||||
A block that connects to an MCP server, lets the user pick a tool,
|
||||
and executes it with dynamic input/output schema.
|
||||
|
||||
The flow:
|
||||
1. User provides an MCP server URL (and optional credentials)
|
||||
2. Frontend calls the backend to get tool list from that URL
|
||||
3. User selects a tool from a dropdown (available_tools)
|
||||
4. The block's input schema updates to reflect the selected tool's parameters
|
||||
5. On execution, the block calls the MCP server to run the tool
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
# -- Static fields (always shown) --
|
||||
credentials: MCPCredentialsInput = CredentialsField(
|
||||
description="Credentials for the MCP server. Use an API key for Bearer "
|
||||
"token auth, or OAuth2 for servers that support it. For public "
|
||||
"servers, create a credential with any placeholder value.",
|
||||
)
|
||||
server_url: str = SchemaField(
|
||||
description="URL of the MCP server (Streamable HTTP endpoint)",
|
||||
placeholder="https://mcp.example.com/mcp",
|
||||
)
|
||||
available_tools: dict[str, Any] = SchemaField(
|
||||
description="Available tools on the MCP server. "
|
||||
"This is populated automatically when a server URL is provided.",
|
||||
default={},
|
||||
hidden=True,
|
||||
)
|
||||
selected_tool: str = SchemaField(
|
||||
description="The MCP tool to execute",
|
||||
placeholder="Select a tool",
|
||||
default="",
|
||||
)
|
||||
tool_input_schema: dict[str, Any] = SchemaField(
|
||||
description="JSON Schema for the selected tool's input parameters. "
|
||||
"Populated automatically when a tool is selected.",
|
||||
default={},
|
||||
hidden=True,
|
||||
)
|
||||
|
||||
# -- Dynamic field: actual arguments for the selected tool --
|
||||
tool_arguments: dict[str, Any] = SchemaField(
|
||||
description="Arguments to pass to the selected MCP tool. "
|
||||
"The fields here are defined by the tool's input schema.",
|
||||
default={},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
||||
"""Return the tool's input schema so the builder UI renders dynamic fields."""
|
||||
return data.get("tool_input_schema", {})
|
||||
|
||||
@classmethod
|
||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
||||
"""Return the current tool_arguments as defaults for the dynamic fields."""
|
||||
return data.get("tool_arguments", {})
|
||||
|
||||
@classmethod
|
||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
||||
"""Check which required tool arguments are missing."""
|
||||
required_fields = cls.get_input_schema(data).get("required", [])
|
||||
tool_arguments = data.get("tool_arguments", {})
|
||||
return set(required_fields) - set(tool_arguments)
|
||||
|
||||
@classmethod
|
||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
||||
"""Validate tool_arguments against the tool's input schema."""
|
||||
tool_schema = cls.get_input_schema(data)
|
||||
if not tool_schema:
|
||||
return None
|
||||
tool_arguments = data.get("tool_arguments", {})
|
||||
return validate_with_jsonschema(tool_schema, tool_arguments)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
result: Any = SchemaField(description="The result returned by the MCP tool")
|
||||
error: str = SchemaField(description="Error message if the tool call failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
|
||||
description="Connect to any MCP server and execute its tools. "
|
||||
"Provide a server URL, select a tool, and pass arguments dynamically.",
|
||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
||||
input_schema=MCPToolBlock.Input,
|
||||
output_schema=MCPToolBlock.Output,
|
||||
block_type=BlockType.STANDARD,
|
||||
test_input={
|
||||
"server_url": "https://mcp.example.com/mcp",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"selected_tool": "get_weather",
|
||||
"tool_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
"tool_arguments": {"city": "London"},
|
||||
},
|
||||
test_output=[
|
||||
(
|
||||
"result",
|
||||
{"weather": "sunny", "temperature": 20},
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"_call_mcp_tool": lambda *a, **kw: {
|
||||
"weather": "sunny",
|
||||
"temperature": 20,
|
||||
},
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
)
|
||||
|
||||
async def _call_mcp_tool(
|
||||
self,
|
||||
server_url: str,
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
auth_token: str | None = None,
|
||||
) -> Any:
|
||||
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
|
||||
# Trust the user-configured server URL to allow internal/localhost servers
|
||||
client = MCPClient(
|
||||
server_url,
|
||||
auth_token=auth_token,
|
||||
trusted_origins=[server_url],
|
||||
)
|
||||
await client.initialize()
|
||||
result = await client.call_tool(tool_name, arguments)
|
||||
|
||||
if result.is_error:
|
||||
error_text = ""
|
||||
for item in result.content:
|
||||
if item.get("type") == "text":
|
||||
error_text += item.get("text", "")
|
||||
raise MCPClientError(
|
||||
f"MCP tool '{tool_name}' returned an error: "
|
||||
f"{error_text or 'Unknown error'}"
|
||||
)
|
||||
|
||||
# Extract text content from the result
|
||||
output_parts = []
|
||||
for item in result.content:
|
||||
if item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
# Try to parse as JSON for structured output
|
||||
try:
|
||||
output_parts.append(json.loads(text))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
output_parts.append(text)
|
||||
elif item.get("type") == "image":
|
||||
output_parts.append(
|
||||
{
|
||||
"type": "image",
|
||||
"data": item.get("data"),
|
||||
"mimeType": item.get("mimeType"),
|
||||
}
|
||||
)
|
||||
elif item.get("type") == "resource":
|
||||
output_parts.append(item.get("resource", {}))
|
||||
|
||||
# If single result, unwrap
|
||||
if len(output_parts) == 1:
|
||||
return output_parts[0]
|
||||
return output_parts if output_parts else None
|
||||
|
||||
@staticmethod
|
||||
def _extract_auth_token(credentials: MCPCredentials) -> str | None:
|
||||
"""Extract a Bearer token from either API key or OAuth2 credentials."""
|
||||
if isinstance(credentials, OAuth2Credentials):
|
||||
return credentials.access_token.get_secret_value()
|
||||
|
||||
if isinstance(credentials, APIKeyCredentials) and credentials.api_key:
|
||||
token_value = credentials.api_key.get_secret_value()
|
||||
if token_value:
|
||||
return token_value
|
||||
|
||||
return None
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: MCPCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
if not input_data.server_url:
|
||||
yield "error", "MCP server URL is required"
|
||||
return
|
||||
|
||||
if not input_data.selected_tool:
|
||||
yield "error", "No tool selected. Please select a tool from the dropdown."
|
||||
return
|
||||
|
||||
auth_token = self._extract_auth_token(credentials)
|
||||
|
||||
try:
|
||||
result = await self._call_mcp_tool(
|
||||
server_url=input_data.server_url,
|
||||
tool_name=input_data.selected_tool,
|
||||
arguments=input_data.tool_arguments,
|
||||
auth_token=auth_token,
|
||||
)
|
||||
yield "result", result
|
||||
except MCPClientError as e:
|
||||
yield "error", str(e)
|
||||
except Exception as e:
|
||||
logger.exception(f"MCP tool call failed: {e}")
|
||||
yield "error", f"MCP tool call failed: {str(e)}"
|
||||
@@ -1,316 +0,0 @@
|
||||
"""
|
||||
MCP (Model Context Protocol) HTTP client.
|
||||
|
||||
Implements the MCP Streamable HTTP transport for listing tools and calling tools
|
||||
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
|
||||
|
||||
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
|
||||
|
||||
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPTool:
|
||||
"""Represents an MCP tool discovered from a server."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPCallResult:
|
||||
"""Result from calling an MCP tool."""
|
||||
|
||||
content: list[dict[str, Any]] = field(default_factory=list)
|
||||
is_error: bool = False
|
||||
|
||||
|
||||
class MCPClientError(Exception):
|
||||
"""Raised when an MCP protocol error occurs."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""
|
||||
Async HTTP client for the MCP Streamable HTTP transport.
|
||||
|
||||
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
|
||||
Supports optional Bearer token authentication.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
auth_token: str | None = None,
|
||||
trusted_origins: list[str] | None = None,
|
||||
):
|
||||
self.server_url = server_url.rstrip("/")
|
||||
self.auth_token = auth_token
|
||||
self.trusted_origins = trusted_origins or []
|
||||
self._request_id = 0
|
||||
|
||||
def _next_id(self) -> int:
|
||||
self._request_id += 1
|
||||
return self._request_id
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
}
|
||||
if self.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
||||
return headers
|
||||
|
||||
def _build_jsonrpc_request(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> dict[str, Any]:
|
||||
req: dict[str, Any] = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"id": self._next_id(),
|
||||
}
|
||||
if params is not None:
|
||||
req["params"] = params
|
||||
return req
|
||||
|
||||
@staticmethod
|
||||
def _parse_sse_response(text: str) -> dict[str, Any]:
|
||||
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
|
||||
|
||||
MCP servers may return responses as SSE with format:
|
||||
event: message
|
||||
data: {"jsonrpc":"2.0","result":{...},"id":1}
|
||||
|
||||
We extract the last `data:` line that contains a JSON-RPC response
|
||||
(i.e. has an "id" field), which is the reply to our request.
|
||||
"""
|
||||
last_data: dict[str, Any] | None = None
|
||||
for line in text.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("data:"):
|
||||
payload = stripped[len("data:") :].strip()
|
||||
if not payload:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(payload)
|
||||
# Only keep JSON-RPC responses (have "id"), skip notifications
|
||||
if isinstance(parsed, dict) and "id" in parsed:
|
||||
last_data = parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
if last_data is None:
|
||||
raise MCPClientError("No JSON-RPC response found in SSE stream")
|
||||
return last_data
|
||||
|
||||
async def _send_request(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""Send a JSON-RPC request to the MCP server and return the result.
|
||||
|
||||
Handles both ``application/json`` and ``text/event-stream`` responses
|
||||
as required by the MCP Streamable HTTP transport specification.
|
||||
"""
|
||||
payload = self._build_jsonrpc_request(method, params)
|
||||
headers = self._build_headers()
|
||||
|
||||
requests = Requests(
|
||||
raise_for_status=True,
|
||||
extra_headers=headers,
|
||||
trusted_origins=self.trusted_origins,
|
||||
)
|
||||
response = await requests.post(self.server_url, json=payload)
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "text/event-stream" in content_type:
|
||||
body = self._parse_sse_response(response.text())
|
||||
else:
|
||||
try:
|
||||
body = response.json()
|
||||
except (ValueError, Exception) as e:
|
||||
raise MCPClientError(
|
||||
f"MCP server returned non-JSON response: {e}"
|
||||
) from e
|
||||
|
||||
# Handle JSON-RPC error
|
||||
if "error" in body:
|
||||
error = body["error"]
|
||||
if isinstance(error, dict):
|
||||
raise MCPClientError(
|
||||
f"MCP server error [{error.get('code', '?')}]: "
|
||||
f"{error.get('message', 'Unknown error')}"
|
||||
)
|
||||
raise MCPClientError(f"MCP server error: {error}")
|
||||
|
||||
return body.get("result")
|
||||
|
||||
async def _send_notification(self, method: str) -> None:
|
||||
"""Send a JSON-RPC notification (no id, no response expected)."""
|
||||
headers = self._build_headers()
|
||||
notification = {"jsonrpc": "2.0", "method": method}
|
||||
requests = Requests(
|
||||
raise_for_status=False,
|
||||
extra_headers=headers,
|
||||
trusted_origins=self.trusted_origins,
|
||||
)
|
||||
await requests.post(self.server_url, json=notification)
|
||||
|
||||
async def discover_auth(self) -> dict[str, Any] | None:
|
||||
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
|
||||
|
||||
Returns ``None`` if the server doesn't require auth, otherwise returns
|
||||
a dict with:
|
||||
- ``authorization_servers``: list of authorization server URLs
|
||||
- ``resource``: the resource indicator URL (usually the MCP endpoint)
|
||||
- ``scopes_supported``: optional list of supported scopes
|
||||
|
||||
The caller can then fetch the authorization server metadata to get
|
||||
``authorization_endpoint``, ``token_endpoint``, etc.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(self.server_url)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# Build candidates for protected-resource metadata (per RFC 9728)
|
||||
path = parsed.path.rstrip("/")
|
||||
candidates = []
|
||||
if path and path != "/":
|
||||
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
|
||||
candidates.append(f"{base}/.well-known/oauth-protected-resource")
|
||||
|
||||
requests = Requests(
|
||||
raise_for_status=False,
|
||||
trusted_origins=self.trusted_origins,
|
||||
)
|
||||
for url in candidates:
|
||||
try:
|
||||
resp = await requests.get(url)
|
||||
if resp.status == 200:
|
||||
data = resp.json()
|
||||
if isinstance(data, dict) and "authorization_servers" in data:
|
||||
return data
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
async def discover_auth_server_metadata(
|
||||
self, auth_server_url: str
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
|
||||
|
||||
Given an authorization server URL, returns a dict with:
|
||||
- ``authorization_endpoint``
|
||||
- ``token_endpoint``
|
||||
- ``registration_endpoint`` (for dynamic client registration)
|
||||
- ``scopes_supported``
|
||||
- ``code_challenge_methods_supported``
|
||||
- etc.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(auth_server_url)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||
path = parsed.path.rstrip("/")
|
||||
|
||||
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
|
||||
candidates = []
|
||||
if path and path != "/":
|
||||
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
|
||||
candidates.append(f"{base}/.well-known/oauth-authorization-server")
|
||||
candidates.append(f"{base}/.well-known/openid-configuration")
|
||||
|
||||
requests = Requests(
|
||||
raise_for_status=False,
|
||||
trusted_origins=self.trusted_origins,
|
||||
)
|
||||
for url in candidates:
|
||||
try:
|
||||
resp = await requests.get(url)
|
||||
if resp.status == 200:
|
||||
data = resp.json()
|
||||
if isinstance(data, dict) and "authorization_endpoint" in data:
|
||||
return data
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
async def initialize(self) -> dict[str, Any]:
|
||||
"""
|
||||
Send the MCP initialize request.
|
||||
|
||||
This is required by the MCP protocol before any other requests.
|
||||
Returns the server's capabilities.
|
||||
"""
|
||||
result = await self._send_request(
|
||||
"initialize",
|
||||
{
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
|
||||
},
|
||||
)
|
||||
# Send initialized notification (no response expected)
|
||||
await self._send_notification("notifications/initialized")
|
||||
|
||||
return result or {}
|
||||
|
||||
async def list_tools(self) -> list[MCPTool]:
|
||||
"""
|
||||
Discover available tools from the MCP server.
|
||||
|
||||
Returns a list of MCPTool objects with name, description, and input schema.
|
||||
"""
|
||||
result = await self._send_request("tools/list")
|
||||
if not result or "tools" not in result:
|
||||
return []
|
||||
|
||||
tools = []
|
||||
for tool_data in result["tools"]:
|
||||
tools.append(
|
||||
MCPTool(
|
||||
name=tool_data.get("name", ""),
|
||||
description=tool_data.get("description", ""),
|
||||
input_schema=tool_data.get("inputSchema", {}),
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
||||
async def call_tool(
|
||||
self, tool_name: str, arguments: dict[str, Any]
|
||||
) -> MCPCallResult:
|
||||
"""
|
||||
Call a tool on the MCP server.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the tool to call.
|
||||
arguments: The arguments to pass to the tool.
|
||||
|
||||
Returns:
|
||||
MCPCallResult with the tool's response content.
|
||||
"""
|
||||
result = await self._send_request(
|
||||
"tools/call",
|
||||
{"name": tool_name, "arguments": arguments},
|
||||
)
|
||||
if not result:
|
||||
return MCPCallResult(is_error=True)
|
||||
|
||||
return MCPCallResult(
|
||||
content=result.get("content", []),
|
||||
is_error=result.get("isError", False),
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
"""
|
||||
Conftest for MCP block tests.
|
||||
|
||||
Override the session-scoped server and graph_cleanup fixtures from
|
||||
backend/conftest.py so that MCP integration tests don't spin up the
|
||||
full SpinTestServer infrastructure.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server():
|
||||
"""No-op override — MCP tests don't need the full platform server."""
|
||||
yield None
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def graph_cleanup(server):
|
||||
"""No-op override — MCP tests don't create graphs."""
|
||||
yield
|
||||
@@ -1,198 +0,0 @@
|
||||
"""
|
||||
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
|
||||
|
||||
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
|
||||
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
|
||||
This handler accepts those endpoints at construction time.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import Requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPOAuthHandler(BaseOAuthHandler):
|
||||
"""
|
||||
OAuth handler for MCP servers with dynamically-discovered endpoints.
|
||||
|
||||
Construction requires the authorization and token endpoint URLs,
|
||||
which are obtained via MCP OAuth metadata discovery
|
||||
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
|
||||
"""
|
||||
|
||||
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
|
||||
DEFAULT_SCOPES: ClassVar[list[str]] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
redirect_uri: str,
|
||||
*,
|
||||
authorize_url: str,
|
||||
token_url: str,
|
||||
revoke_url: str | None = None,
|
||||
resource_url: str | None = None,
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
self.authorize_url = authorize_url
|
||||
self.token_url = token_url
|
||||
self.revoke_url = revoke_url
|
||||
self.resource_url = resource_url
|
||||
|
||||
def get_login_url(
|
||||
self,
|
||||
scopes: list[str],
|
||||
state: str,
|
||||
code_challenge: Optional[str],
|
||||
) -> str:
|
||||
scopes = self.handle_default_scopes(scopes)
|
||||
|
||||
params: dict[str, str] = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"state": state,
|
||||
}
|
||||
if scopes:
|
||||
params["scope"] = " ".join(scopes)
|
||||
# PKCE (S256) — included when the caller provides a code_challenge
|
||||
if code_challenge:
|
||||
params["code_challenge"] = code_challenge
|
||||
params["code_challenge_method"] = "S256"
|
||||
# MCP spec requires resource indicator (RFC 8707)
|
||||
if self.resource_url:
|
||||
params["resource"] = self.resource_url
|
||||
|
||||
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
async def exchange_code_for_tokens(
|
||||
self,
|
||||
code: str,
|
||||
scopes: list[str],
|
||||
code_verifier: Optional[str],
|
||||
) -> OAuth2Credentials:
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"client_id": self.client_id,
|
||||
}
|
||||
if self.client_secret:
|
||||
data["client_secret"] = self.client_secret
|
||||
if code_verifier:
|
||||
data["code_verifier"] = code_verifier
|
||||
if self.resource_url:
|
||||
data["resource"] = self.resource_url
|
||||
|
||||
response = await Requests(raise_for_status=True).post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise RuntimeError(
|
||||
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
|
||||
)
|
||||
|
||||
now = int(time.time())
|
||||
expires_in = tokens.get("expires_in")
|
||||
|
||||
return OAuth2Credentials(
|
||||
provider=str(self.PROVIDER_NAME),
|
||||
title=None,
|
||||
access_token=SecretStr(tokens["access_token"]),
|
||||
refresh_token=(
|
||||
SecretStr(tokens["refresh_token"])
|
||||
if tokens.get("refresh_token")
|
||||
else None
|
||||
),
|
||||
access_token_expires_at=now + expires_in if expires_in else None,
|
||||
refresh_token_expires_at=None,
|
||||
scopes=scopes,
|
||||
metadata={
|
||||
"mcp_token_url": self.token_url,
|
||||
"mcp_resource_url": self.resource_url,
|
||||
},
|
||||
)
|
||||
|
||||
async def _refresh_tokens(
|
||||
self, credentials: OAuth2Credentials
|
||||
) -> OAuth2Credentials:
|
||||
if not credentials.refresh_token:
|
||||
raise ValueError("No refresh token available for MCP OAuth credentials")
|
||||
|
||||
data: dict[str, str] = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
||||
"client_id": self.client_id,
|
||||
}
|
||||
if self.client_secret:
|
||||
data["client_secret"] = self.client_secret
|
||||
if self.resource_url:
|
||||
data["resource"] = self.resource_url
|
||||
|
||||
response = await Requests(raise_for_status=True).post(
|
||||
self.token_url,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
tokens = response.json()
|
||||
|
||||
if "error" in tokens:
|
||||
raise RuntimeError(
|
||||
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
|
||||
)
|
||||
|
||||
now = int(time.time())
|
||||
expires_in = tokens.get("expires_in")
|
||||
|
||||
return OAuth2Credentials(
|
||||
id=credentials.id,
|
||||
provider=str(self.PROVIDER_NAME),
|
||||
title=credentials.title,
|
||||
access_token=SecretStr(tokens["access_token"]),
|
||||
refresh_token=(
|
||||
SecretStr(str(tokens["refresh_token"]))
|
||||
if tokens.get("refresh_token")
|
||||
else credentials.refresh_token
|
||||
),
|
||||
access_token_expires_at=now + expires_in if expires_in else None,
|
||||
refresh_token_expires_at=credentials.refresh_token_expires_at,
|
||||
scopes=credentials.scopes,
|
||||
metadata=credentials.metadata,
|
||||
)
|
||||
|
||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
||||
if not self.revoke_url:
|
||||
return False
|
||||
|
||||
try:
|
||||
data = {
|
||||
"token": credentials.access_token.get_secret_value(),
|
||||
"token_type_hint": "access_token",
|
||||
"client_id": self.client_id,
|
||||
}
|
||||
await Requests().post(
|
||||
self.revoke_url,
|
||||
data=data,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
return True
|
||||
except Exception:
|
||||
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
|
||||
return False
|
||||
@@ -1,104 +0,0 @@
|
||||
"""
|
||||
End-to-end tests against a real public MCP server.
|
||||
|
||||
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
|
||||
which is publicly accessible without authentication and returns SSE responses.
|
||||
|
||||
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
|
||||
independently of the rest of the test suite (they require network access).
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.mcp.client import MCPClient
|
||||
|
||||
# Public MCP server that requires no authentication
|
||||
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
|
||||
|
||||
|
||||
@pytest.mark.e2e
|
||||
class TestRealMCPServer:
|
||||
"""Tests against the live OpenAI docs MCP server."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize(self):
|
||||
"""Verify we can complete the MCP handshake with a real server."""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
result = await client.initialize()
|
||||
|
||||
assert result["protocolVersion"] == "2025-03-26"
|
||||
assert "serverInfo" in result
|
||||
assert result["serverInfo"]["name"] == "openai-docs-mcp"
|
||||
assert "tools" in result.get("capabilities", {})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools(self):
|
||||
"""Verify we can discover tools from a real MCP server."""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) >= 3 # server has at least 5 tools as of writing
|
||||
|
||||
tool_names = {t.name for t in tools}
|
||||
# These tools are documented and should be stable
|
||||
assert "search_openai_docs" in tool_names
|
||||
assert "list_openai_docs" in tool_names
|
||||
assert "fetch_openai_doc" in tool_names
|
||||
|
||||
# Verify schema structure
|
||||
search_tool = next(t for t in tools if t.name == "search_openai_docs")
|
||||
assert "query" in search_tool.input_schema.get("properties", {})
|
||||
assert "query" in search_tool.input_schema.get("required", [])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_list_api_endpoints(self):
|
||||
"""Call the list_api_endpoints tool and verify we get real data."""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("list_api_endpoints", {})
|
||||
|
||||
assert not result.is_error
|
||||
assert len(result.content) >= 1
|
||||
assert result.content[0]["type"] == "text"
|
||||
|
||||
data = json.loads(result.content[0]["text"])
|
||||
assert "paths" in data or "urls" in data
|
||||
# The OpenAI API should have many endpoints
|
||||
total = data.get("total", len(data.get("paths", [])))
|
||||
assert total > 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_search(self):
|
||||
"""Search for docs and verify we get results."""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
await client.initialize()
|
||||
result = await client.call_tool(
|
||||
"search_openai_docs", {"query": "chat completions", "limit": 3}
|
||||
)
|
||||
|
||||
assert not result.is_error
|
||||
assert len(result.content) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_response_handling(self):
|
||||
"""Verify the client correctly handles SSE responses from a real server.
|
||||
|
||||
This is the key test — our local test server returns JSON,
|
||||
but real MCP servers typically return SSE. This proves the
|
||||
SSE parsing works end-to-end.
|
||||
"""
|
||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
||||
# initialize() internally calls _send_request which must parse SSE
|
||||
result = await client.initialize()
|
||||
|
||||
# If we got here without error, SSE parsing works
|
||||
assert isinstance(result, dict)
|
||||
assert "protocolVersion" in result
|
||||
|
||||
# Also verify list_tools works (another SSE response)
|
||||
tools = await client.list_tools()
|
||||
assert len(tools) > 0
|
||||
assert all(hasattr(t, "name") for t in tools)
|
||||
@@ -1,367 +0,0 @@
|
||||
"""
|
||||
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
|
||||
|
||||
These tests spin up a local MCP test server and run the full client/block flow
|
||||
against it — no mocking, real HTTP requests.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.mcp.block import MCPToolBlock
|
||||
from backend.blocks.mcp.client import MCPClient
|
||||
from backend.blocks.mcp.test_server import create_test_mcp_app
|
||||
from backend.data.model import APIKeyCredentials
|
||||
|
||||
|
||||
class _MCPTestServer:
|
||||
"""
|
||||
Run an MCP test server in a background thread with its own event loop.
|
||||
This avoids event loop conflicts with pytest-asyncio.
|
||||
"""
|
||||
|
||||
def __init__(self, auth_token: str | None = None):
|
||||
self.auth_token = auth_token
|
||||
self.url: str = ""
|
||||
self._runner: web.AppRunner | None = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
self._started = threading.Event()
|
||||
|
||||
def _run(self):
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_until_complete(self._start())
|
||||
self._started.set()
|
||||
self._loop.run_forever()
|
||||
|
||||
async def _start(self):
|
||||
app = create_test_mcp_app(auth_token=self.auth_token)
|
||||
self._runner = web.AppRunner(app)
|
||||
await self._runner.setup()
|
||||
site = web.TCPSite(self._runner, "127.0.0.1", 0)
|
||||
await site.start()
|
||||
port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
|
||||
self.url = f"http://127.0.0.1:{port}/mcp"
|
||||
|
||||
def start(self):
|
||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
||||
self._thread.start()
|
||||
if not self._started.wait(timeout=5):
|
||||
raise RuntimeError("MCP test server failed to start within 5 seconds")
|
||||
return self
|
||||
|
||||
def stop(self):
|
||||
if self._loop and self._runner:
|
||||
asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
|
||||
timeout=5
|
||||
)
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mcp_server():
|
||||
"""Start a local MCP test server in a background thread."""
|
||||
server = _MCPTestServer()
|
||||
server.start()
|
||||
yield server.url
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mcp_server_with_auth():
|
||||
"""Start a local MCP test server with auth in a background thread."""
|
||||
server = _MCPTestServer(auth_token="test-secret-token")
|
||||
server.start()
|
||||
yield server.url, "test-secret-token"
|
||||
server.stop()
|
||||
|
||||
|
||||
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
|
||||
"""Create an MCPClient with localhost trusted for integration tests."""
|
||||
return MCPClient(url, auth_token=auth_token, trusted_origins=[url])
|
||||
|
||||
|
||||
def _make_fake_creds(api_key: str = "FAKE_API_KEY") -> APIKeyCredentials:
|
||||
return APIKeyCredentials(
|
||||
id="test-integration",
|
||||
provider="mcp",
|
||||
api_key=SecretStr(api_key),
|
||||
title="test",
|
||||
)
|
||||
|
||||
|
||||
# ── MCPClient integration tests ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPClientIntegration:
|
||||
"""Test MCPClient against a real local MCP server."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
result = await client.initialize()
|
||||
|
||||
assert result["protocolVersion"] == "2025-03-26"
|
||||
assert result["serverInfo"]["name"] == "test-mcp-server"
|
||||
assert "tools" in result["capabilities"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) == 3
|
||||
|
||||
tool_names = {t.name for t in tools}
|
||||
assert tool_names == {"get_weather", "add_numbers", "echo"}
|
||||
|
||||
# Check get_weather schema
|
||||
weather = next(t for t in tools if t.name == "get_weather")
|
||||
assert weather.description == "Get current weather for a city"
|
||||
assert "city" in weather.input_schema["properties"]
|
||||
assert weather.input_schema["required"] == ["city"]
|
||||
|
||||
# Check add_numbers schema
|
||||
add = next(t for t in tools if t.name == "add_numbers")
|
||||
assert "a" in add.input_schema["properties"]
|
||||
assert "b" in add.input_schema["properties"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_weather(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("get_weather", {"city": "London"})
|
||||
|
||||
assert not result.is_error
|
||||
assert len(result.content) == 1
|
||||
assert result.content[0]["type"] == "text"
|
||||
|
||||
data = json.loads(result.content[0]["text"])
|
||||
assert data["city"] == "London"
|
||||
assert data["temperature"] == 22
|
||||
assert data["condition"] == "sunny"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_add_numbers(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
|
||||
|
||||
assert not result.is_error
|
||||
data = json.loads(result.content[0]["text"])
|
||||
assert data["result"] == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_echo(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("echo", {"message": "Hello MCP!"})
|
||||
|
||||
assert not result.is_error
|
||||
assert result.content[0]["text"] == "Hello MCP!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_unknown_tool(self, mcp_server):
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
result = await client.call_tool("nonexistent_tool", {})
|
||||
|
||||
assert result.is_error
|
||||
assert "Unknown tool" in result.content[0]["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_success(self, mcp_server_with_auth):
|
||||
url, token = mcp_server_with_auth
|
||||
client = _make_client(url, auth_token=token)
|
||||
result = await client.initialize()
|
||||
|
||||
assert result["protocolVersion"] == "2025-03-26"
|
||||
|
||||
tools = await client.list_tools()
|
||||
assert len(tools) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_failure(self, mcp_server_with_auth):
|
||||
url, _ = mcp_server_with_auth
|
||||
client = _make_client(url, auth_token="wrong-token")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await client.initialize()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_missing(self, mcp_server_with_auth):
|
||||
url, _ = mcp_server_with_auth
|
||||
client = _make_client(url)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await client.initialize()
|
||||
|
||||
|
||||
# ── MCPToolBlock integration tests ───────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPToolBlockIntegration:
|
||||
"""Test MCPToolBlock end-to-end against a real local MCP server."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_flow_get_weather(self, mcp_server):
|
||||
"""Full flow: discover tools, select one, execute it."""
|
||||
# Step 1: Discover tools (simulating what the frontend/API would do)
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
assert len(tools) == 3
|
||||
|
||||
# Step 2: User selects "get_weather" and we get its schema
|
||||
weather_tool = next(t for t in tools if t.name == "get_weather")
|
||||
|
||||
# Step 3: Execute the block with the selected tool
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="get_weather",
|
||||
tool_input_schema=weather_tool.input_schema,
|
||||
tool_arguments={"city": "Paris"},
|
||||
credentials={ # type: ignore
|
||||
"provider": "mcp",
|
||||
"id": "test",
|
||||
"type": "api_key",
|
||||
"title": "test",
|
||||
},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=_make_fake_creds()):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
result = outputs[0][1]
|
||||
assert result["city"] == "Paris"
|
||||
assert result["temperature"] == 22
|
||||
assert result["condition"] == "sunny"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_flow_add_numbers(self, mcp_server):
|
||||
"""Full flow for add_numbers tool."""
|
||||
client = _make_client(mcp_server)
|
||||
await client.initialize()
|
||||
tools = await client.list_tools()
|
||||
add_tool = next(t for t in tools if t.name == "add_numbers")
|
||||
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="add_numbers",
|
||||
tool_input_schema=add_tool.input_schema,
|
||||
tool_arguments={"a": 42, "b": 58},
|
||||
credentials={ # type: ignore
|
||||
"provider": "mcp",
|
||||
"id": "test",
|
||||
"type": "api_key",
|
||||
"title": "test",
|
||||
},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=_make_fake_creds()):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1]["result"] == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_flow_echo_plain_text(self, mcp_server):
|
||||
"""Verify plain text (non-JSON) responses work."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="echo",
|
||||
tool_input_schema={
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string"}},
|
||||
"required": ["message"],
|
||||
},
|
||||
tool_arguments={"message": "Hello from AutoGPT!"},
|
||||
credentials={ # type: ignore
|
||||
"provider": "mcp",
|
||||
"id": "test",
|
||||
"type": "api_key",
|
||||
"title": "test",
|
||||
},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=_make_fake_creds()):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1] == "Hello from AutoGPT!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
|
||||
"""Calling an unknown tool should yield an error output."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=mcp_server,
|
||||
selected_tool="nonexistent_tool",
|
||||
tool_arguments={},
|
||||
credentials={ # type: ignore
|
||||
"provider": "mcp",
|
||||
"id": "test",
|
||||
"type": "api_key",
|
||||
"title": "test",
|
||||
},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=_make_fake_creds()):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "error"
|
||||
assert "returned an error" in outputs[0][1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_flow_with_auth(self, mcp_server_with_auth):
|
||||
"""Full flow with authentication."""
|
||||
url, token = mcp_server_with_auth
|
||||
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url=url,
|
||||
selected_tool="echo",
|
||||
tool_input_schema={
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string"}},
|
||||
"required": ["message"],
|
||||
},
|
||||
tool_arguments={"message": "Authenticated!"},
|
||||
credentials={ # type: ignore
|
||||
"provider": "mcp",
|
||||
"id": "test",
|
||||
"type": "api_key",
|
||||
"title": "test",
|
||||
},
|
||||
)
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(
|
||||
input_data, credentials=_make_fake_creds(api_key=token)
|
||||
):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1] == "Authenticated!"
|
||||
@@ -1,667 +0,0 @@
|
||||
"""
|
||||
Tests for MCP client and MCPToolBlock.
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.mcp.block import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
MCPToolBlock,
|
||||
)
|
||||
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
|
||||
from backend.data.model import APIKeyCredentials, OAuth2Credentials
|
||||
from backend.util.test import execute_block_test
|
||||
|
||||
# ── SSE parsing unit tests ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSSEParsing:
|
||||
"""Tests for SSE (text/event-stream) response parsing."""
|
||||
|
||||
def test_parse_sse_simple(self):
|
||||
sse = (
|
||||
"event: message\n"
|
||||
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
|
||||
"\n"
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == {"tools": []}
|
||||
assert body["id"] == 1
|
||||
|
||||
def test_parse_sse_with_notifications(self):
|
||||
"""SSE streams can contain notifications (no id) before the response."""
|
||||
sse = (
|
||||
"event: message\n"
|
||||
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
|
||||
"\n"
|
||||
"event: message\n"
|
||||
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
|
||||
"\n"
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == {"ok": True}
|
||||
assert body["id"] == 2
|
||||
|
||||
def test_parse_sse_error_response(self):
|
||||
sse = (
|
||||
"event: message\n"
|
||||
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert "error" in body
|
||||
assert body["error"]["code"] == -32600
|
||||
|
||||
def test_parse_sse_no_data_raises(self):
|
||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
||||
MCPClient._parse_sse_response("event: message\n\n")
|
||||
|
||||
def test_parse_sse_empty_raises(self):
|
||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
||||
MCPClient._parse_sse_response("")
|
||||
|
||||
def test_parse_sse_ignores_non_data_lines(self):
|
||||
sse = (
|
||||
": comment line\n"
|
||||
"event: message\n"
|
||||
"id: 123\n"
|
||||
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
|
||||
"\n"
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == "ok"
|
||||
|
||||
def test_parse_sse_uses_last_response(self):
|
||||
"""If multiple responses exist, use the last one."""
|
||||
sse = (
|
||||
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
|
||||
"\n"
|
||||
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
|
||||
"\n"
|
||||
)
|
||||
body = MCPClient._parse_sse_response(sse)
|
||||
assert body["result"] == "second"
|
||||
|
||||
|
||||
# ── MCPClient unit tests ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPClient:
|
||||
"""Tests for the MCP HTTP client."""
|
||||
|
||||
def test_build_headers_without_auth(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
headers = client._build_headers()
|
||||
assert "Authorization" not in headers
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
def test_build_headers_with_auth(self):
|
||||
client = MCPClient("https://mcp.example.com", auth_token="my-token")
|
||||
headers = client._build_headers()
|
||||
assert headers["Authorization"] == "Bearer my-token"
|
||||
|
||||
def test_build_jsonrpc_request(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
req = client._build_jsonrpc_request("tools/list")
|
||||
assert req["jsonrpc"] == "2.0"
|
||||
assert req["method"] == "tools/list"
|
||||
assert "id" in req
|
||||
assert "params" not in req
|
||||
|
||||
def test_build_jsonrpc_request_with_params(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
req = client._build_jsonrpc_request(
|
||||
"tools/call", {"name": "test", "arguments": {"x": 1}}
|
||||
)
|
||||
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
|
||||
|
||||
def test_request_id_increments(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
req1 = client._build_jsonrpc_request("tools/list")
|
||||
req2 = client._build_jsonrpc_request("tools/list")
|
||||
assert req2["id"] > req1["id"]
|
||||
|
||||
def test_server_url_trailing_slash_stripped(self):
|
||||
client = MCPClient("https://mcp.example.com/mcp/")
|
||||
assert client.server_url == "https://mcp.example.com/mcp"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_request_success(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.json.return_value = {
|
||||
"jsonrpc": "2.0",
|
||||
"result": {"tools": []},
|
||||
"id": 1,
|
||||
}
|
||||
|
||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
||||
result = await client._send_request("tools/list")
|
||||
assert result == {"tools": []}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_request_error(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
async def mock_send(*args, **kwargs):
|
||||
raise MCPClientError("MCP server error [-32600]: Invalid Request")
|
||||
|
||||
with patch.object(client, "_send_request", side_effect=mock_send):
|
||||
with pytest.raises(MCPClientError, match="Invalid Request"):
|
||||
await client._send_request("tools/list")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_result = {
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather for a city",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "search",
|
||||
"description": "Search the web",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch.object(client, "_send_request", return_value=mock_result):
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) == 2
|
||||
assert tools[0].name == "get_weather"
|
||||
assert tools[0].description == "Get current weather for a city"
|
||||
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
|
||||
assert tools[1].name == "search"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_empty(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert tools == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_none_result(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
with patch.object(client, "_send_request", return_value=None):
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert tools == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_success(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_result = {
|
||||
"content": [
|
||||
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
|
||||
],
|
||||
"isError": False,
|
||||
}
|
||||
|
||||
with patch.object(client, "_send_request", return_value=mock_result):
|
||||
result = await client.call_tool("get_weather", {"city": "London"})
|
||||
|
||||
assert not result.is_error
|
||||
assert len(result.content) == 1
|
||||
assert result.content[0]["type"] == "text"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_error(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_result = {
|
||||
"content": [{"type": "text", "text": "City not found"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
with patch.object(client, "_send_request", return_value=mock_result):
|
||||
result = await client.call_tool("get_weather", {"city": "???"})
|
||||
|
||||
assert result.is_error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_none_result(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
with patch.object(client, "_send_request", return_value=None):
|
||||
result = await client.call_tool("get_weather", {"city": "London"})
|
||||
|
||||
assert result.is_error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize(self):
|
||||
client = MCPClient("https://mcp.example.com")
|
||||
|
||||
mock_result = {
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {"tools": {}},
|
||||
"serverInfo": {"name": "test-server", "version": "1.0.0"},
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
|
||||
patch.object(client, "_send_notification") as mock_notif,
|
||||
):
|
||||
result = await client.initialize()
|
||||
|
||||
mock_req.assert_called_once()
|
||||
mock_notif.assert_called_once_with("notifications/initialized")
|
||||
assert result["protocolVersion"] == "2025-03-26"
|
||||
|
||||
|
||||
# ── MCPToolBlock unit tests ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPToolBlock:
|
||||
"""Tests for the MCPToolBlock."""
|
||||
|
||||
def test_block_instantiation(self):
|
||||
block = MCPToolBlock()
|
||||
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
||||
assert block.name == "MCPToolBlock"
|
||||
|
||||
def test_input_schema_has_required_fields(self):
|
||||
block = MCPToolBlock()
|
||||
schema = block.input_schema.jsonschema()
|
||||
props = schema.get("properties", {})
|
||||
assert "server_url" in props
|
||||
assert "selected_tool" in props
|
||||
assert "tool_arguments" in props
|
||||
assert "credentials" in props
|
||||
|
||||
def test_output_schema(self):
|
||||
block = MCPToolBlock()
|
||||
schema = block.output_schema.jsonschema()
|
||||
props = schema.get("properties", {})
|
||||
assert "result" in props
|
||||
assert "error" in props
|
||||
|
||||
def test_get_input_schema_with_tool_schema(self):
|
||||
tool_schema = {
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
}
|
||||
data = {"tool_input_schema": tool_schema}
|
||||
result = MCPToolBlock.Input.get_input_schema(data)
|
||||
assert result == tool_schema
|
||||
|
||||
def test_get_input_schema_without_tool_schema(self):
|
||||
result = MCPToolBlock.Input.get_input_schema({})
|
||||
assert result == {}
|
||||
|
||||
def test_get_input_defaults(self):
|
||||
data = {"tool_arguments": {"city": "London"}}
|
||||
result = MCPToolBlock.Input.get_input_defaults(data)
|
||||
assert result == {"city": "London"}
|
||||
|
||||
def test_get_missing_input(self):
|
||||
data = {
|
||||
"tool_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"units": {"type": "string"},
|
||||
},
|
||||
"required": ["city", "units"],
|
||||
},
|
||||
"tool_arguments": {"city": "London"},
|
||||
}
|
||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
||||
assert missing == {"units"}
|
||||
|
||||
def test_get_missing_input_all_present(self):
|
||||
data = {
|
||||
"tool_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
"required": ["city"],
|
||||
},
|
||||
"tool_arguments": {"city": "London"},
|
||||
}
|
||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
||||
assert missing == set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_mock(self):
|
||||
"""Test the block using the built-in test infrastructure."""
|
||||
block = MCPToolBlock()
|
||||
await execute_block_test(block)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_missing_server_url(self):
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="",
|
||||
selected_tool="test",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
|
||||
outputs.append((name, data))
|
||||
assert outputs == [("error", "MCP server URL is required")]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_missing_tool(self):
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
|
||||
outputs.append((name, data))
|
||||
assert outputs == [
|
||||
("error", "No tool selected. Please select a tool from the dropdown.")
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_success(self):
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="get_weather",
|
||||
tool_input_schema={
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
tool_arguments={"city": "London"},
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
async def mock_call(*args, **kwargs):
|
||||
return {"temp": 20, "city": "London"}
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert len(outputs) == 1
|
||||
assert outputs[0][0] == "result"
|
||||
assert outputs[0][1] == {"temp": 20, "city": "London"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_mcp_error(self):
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="bad_tool",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
async def mock_call(*args, **kwargs):
|
||||
raise MCPClientError("Tool not found")
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert outputs[0][0] == "error"
|
||||
assert "Tool not found" in outputs[0][1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_mcp_tool_parses_json_text(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[
|
||||
{"type": "text", "text": '{"temp": 20}'},
|
||||
],
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
result = await block._call_mcp_tool(
|
||||
"https://mcp.example.com", "test_tool", {}
|
||||
)
|
||||
|
||||
assert result == {"temp": 20}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_mcp_tool_plain_text(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[
|
||||
{"type": "text", "text": "Hello, world!"},
|
||||
],
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
result = await block._call_mcp_tool(
|
||||
"https://mcp.example.com", "test_tool", {}
|
||||
)
|
||||
|
||||
assert result == "Hello, world!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_mcp_tool_multiple_content(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[
|
||||
{"type": "text", "text": "Part 1"},
|
||||
{"type": "text", "text": '{"part": 2}'},
|
||||
],
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
result = await block._call_mcp_tool(
|
||||
"https://mcp.example.com", "test_tool", {}
|
||||
)
|
||||
|
||||
assert result == ["Part 1", {"part": 2}]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_mcp_tool_error_result(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[{"type": "text", "text": "Something went wrong"}],
|
||||
is_error=True,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
with pytest.raises(MCPClientError, match="returned an error"):
|
||||
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_mcp_tool_image_content(self):
|
||||
block = MCPToolBlock()
|
||||
|
||||
mock_result = MCPCallResult(
|
||||
content=[
|
||||
{
|
||||
"type": "image",
|
||||
"data": "base64data==",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
],
|
||||
is_error=False,
|
||||
)
|
||||
|
||||
async def mock_init(self):
|
||||
return {}
|
||||
|
||||
async def mock_call(self, name, args):
|
||||
return mock_result
|
||||
|
||||
with (
|
||||
patch.object(MCPClient, "initialize", mock_init),
|
||||
patch.object(MCPClient, "call_tool", mock_call),
|
||||
):
|
||||
result = await block._call_mcp_tool(
|
||||
"https://mcp.example.com", "test_tool", {}
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"type": "image",
|
||||
"data": "base64data==",
|
||||
"mimeType": "image/png",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_sends_api_key_credentials(self):
|
||||
"""Ensure non-empty API keys are sent to the MCP server."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="test_tool",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
creds = APIKeyCredentials(
|
||||
id="test-id",
|
||||
provider="mcp",
|
||||
api_key=SecretStr("real-api-key"),
|
||||
title="Real",
|
||||
)
|
||||
|
||||
captured_tokens = []
|
||||
|
||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||
captured_tokens.append(auth_token)
|
||||
return "ok"
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
async for _ in block.run(input_data, credentials=creds):
|
||||
pass
|
||||
|
||||
assert captured_tokens == ["real-api-key"]
|
||||
|
||||
|
||||
# ── OAuth2 credential support tests ─────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPOAuth2Support:
|
||||
"""Tests for OAuth2 credential support in MCPToolBlock."""
|
||||
|
||||
def test_extract_auth_token_from_api_key(self):
|
||||
creds = APIKeyCredentials(
|
||||
id="test",
|
||||
provider="mcp",
|
||||
api_key=SecretStr("my-api-key"),
|
||||
title="test",
|
||||
)
|
||||
token = MCPToolBlock._extract_auth_token(creds)
|
||||
assert token == "my-api-key"
|
||||
|
||||
def test_extract_auth_token_from_oauth2(self):
|
||||
creds = OAuth2Credentials(
|
||||
id="test",
|
||||
provider="mcp",
|
||||
access_token=SecretStr("oauth2-access-token"),
|
||||
scopes=["read"],
|
||||
title="test",
|
||||
)
|
||||
token = MCPToolBlock._extract_auth_token(creds)
|
||||
assert token == "oauth2-access-token"
|
||||
|
||||
def test_extract_auth_token_empty_skipped(self):
|
||||
creds = APIKeyCredentials(
|
||||
id="test",
|
||||
provider="mcp",
|
||||
api_key=SecretStr(""),
|
||||
title="test",
|
||||
)
|
||||
token = MCPToolBlock._extract_auth_token(creds)
|
||||
assert token is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_oauth2_credentials(self):
|
||||
"""Verify the block can run with OAuth2 credentials."""
|
||||
block = MCPToolBlock()
|
||||
input_data = MCPToolBlock.Input(
|
||||
server_url="https://mcp.example.com/mcp",
|
||||
selected_tool="test_tool",
|
||||
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
)
|
||||
|
||||
oauth2_creds = OAuth2Credentials(
|
||||
id="test-id",
|
||||
provider="mcp",
|
||||
access_token=SecretStr("real-oauth2-token"),
|
||||
scopes=["read", "write"],
|
||||
title="MCP OAuth",
|
||||
)
|
||||
|
||||
captured_tokens = []
|
||||
|
||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
||||
captured_tokens.append(auth_token)
|
||||
return {"status": "ok"}
|
||||
|
||||
block._call_mcp_tool = mock_call # type: ignore
|
||||
|
||||
outputs = []
|
||||
async for name, data in block.run(input_data, credentials=oauth2_creds):
|
||||
outputs.append((name, data))
|
||||
|
||||
assert captured_tokens == ["real-oauth2-token"]
|
||||
assert outputs == [("result", {"status": "ok"})]
|
||||
@@ -1,242 +0,0 @@
|
||||
"""
|
||||
Tests for MCP OAuth handler.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.blocks.mcp.client import MCPClient
|
||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
||||
from backend.data.model import OAuth2Credentials
|
||||
|
||||
|
||||
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
|
||||
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
|
||||
resp = MagicMock()
|
||||
resp.status = status
|
||||
resp.ok = 200 <= status < 300
|
||||
resp.json.return_value = json_data
|
||||
return resp
|
||||
|
||||
|
||||
class TestMCPOAuthHandler:
|
||||
"""Tests for the MCPOAuthHandler."""
|
||||
|
||||
def _make_handler(self, **overrides) -> MCPOAuthHandler:
|
||||
defaults = {
|
||||
"client_id": "test-client-id",
|
||||
"client_secret": "test-client-secret",
|
||||
"redirect_uri": "https://app.example.com/callback",
|
||||
"authorize_url": "https://auth.example.com/authorize",
|
||||
"token_url": "https://auth.example.com/token",
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return MCPOAuthHandler(**defaults)
|
||||
|
||||
def test_get_login_url_basic(self):
|
||||
handler = self._make_handler()
|
||||
url = handler.get_login_url(
|
||||
scopes=["read", "write"],
|
||||
state="random-state-token",
|
||||
code_challenge="S256-challenge-value",
|
||||
)
|
||||
|
||||
assert "https://auth.example.com/authorize?" in url
|
||||
assert "response_type=code" in url
|
||||
assert "client_id=test-client-id" in url
|
||||
assert "state=random-state-token" in url
|
||||
assert "code_challenge=S256-challenge-value" in url
|
||||
assert "code_challenge_method=S256" in url
|
||||
assert "scope=read+write" in url
|
||||
|
||||
def test_get_login_url_with_resource(self):
|
||||
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
|
||||
url = handler.get_login_url(
|
||||
scopes=[], state="state", code_challenge="challenge"
|
||||
)
|
||||
|
||||
assert "resource=https" in url
|
||||
|
||||
def test_get_login_url_without_pkce(self):
|
||||
handler = self._make_handler()
|
||||
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
|
||||
|
||||
assert "code_challenge" not in url
|
||||
assert "code_challenge_method" not in url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exchange_code_for_tokens(self):
|
||||
handler = self._make_handler()
|
||||
|
||||
resp = _mock_response(
|
||||
{
|
||||
"access_token": "new-access-token",
|
||||
"refresh_token": "new-refresh-token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
}
|
||||
)
|
||||
|
||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.post = AsyncMock(return_value=resp)
|
||||
|
||||
creds = await handler.exchange_code_for_tokens(
|
||||
code="auth-code",
|
||||
scopes=["read"],
|
||||
code_verifier="pkce-verifier",
|
||||
)
|
||||
|
||||
assert isinstance(creds, OAuth2Credentials)
|
||||
assert creds.access_token.get_secret_value() == "new-access-token"
|
||||
assert creds.refresh_token is not None
|
||||
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
|
||||
assert creds.scopes == ["read"]
|
||||
assert creds.access_token_expires_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens(self):
|
||||
handler = self._make_handler()
|
||||
|
||||
existing_creds = OAuth2Credentials(
|
||||
id="existing-id",
|
||||
provider="mcp",
|
||||
access_token=SecretStr("old-token"),
|
||||
refresh_token=SecretStr("old-refresh"),
|
||||
scopes=["read"],
|
||||
title="test",
|
||||
)
|
||||
|
||||
resp = _mock_response(
|
||||
{
|
||||
"access_token": "refreshed-token",
|
||||
"refresh_token": "new-refresh",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
)
|
||||
|
||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.post = AsyncMock(return_value=resp)
|
||||
|
||||
refreshed = await handler._refresh_tokens(existing_creds)
|
||||
|
||||
assert refreshed.id == "existing-id"
|
||||
assert refreshed.access_token.get_secret_value() == "refreshed-token"
|
||||
assert refreshed.refresh_token is not None
|
||||
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_tokens_no_refresh_token(self):
|
||||
handler = self._make_handler()
|
||||
|
||||
creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=["read"],
|
||||
title="test",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="No refresh token"):
|
||||
await handler._refresh_tokens(creds)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_tokens_no_url(self):
|
||||
handler = self._make_handler(revoke_url=None)
|
||||
|
||||
creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
title="test",
|
||||
)
|
||||
|
||||
result = await handler.revoke_tokens(creds)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_revoke_tokens_with_url(self):
|
||||
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
|
||||
|
||||
creds = OAuth2Credentials(
|
||||
provider="mcp",
|
||||
access_token=SecretStr("token"),
|
||||
scopes=[],
|
||||
title="test",
|
||||
)
|
||||
|
||||
resp = _mock_response({}, status=200)
|
||||
|
||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.post = AsyncMock(return_value=resp)
|
||||
|
||||
result = await handler.revoke_tokens(creds)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestMCPClientDiscovery:
|
||||
"""Tests for MCPClient OAuth metadata discovery."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_auth_found(self):
|
||||
client = MCPClient("https://mcp.example.com/mcp")
|
||||
|
||||
metadata = {
|
||||
"authorization_servers": ["https://auth.example.com"],
|
||||
"resource": "https://mcp.example.com/mcp",
|
||||
}
|
||||
|
||||
resp = _mock_response(metadata, status=200)
|
||||
|
||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.get = AsyncMock(return_value=resp)
|
||||
|
||||
result = await client.discover_auth()
|
||||
|
||||
assert result is not None
|
||||
assert result["authorization_servers"] == ["https://auth.example.com"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_auth_not_found(self):
|
||||
client = MCPClient("https://mcp.example.com/mcp")
|
||||
|
||||
resp = _mock_response({}, status=404)
|
||||
|
||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.get = AsyncMock(return_value=resp)
|
||||
|
||||
result = await client.discover_auth()
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_auth_server_metadata(self):
|
||||
client = MCPClient("https://mcp.example.com/mcp")
|
||||
|
||||
server_metadata = {
|
||||
"issuer": "https://auth.example.com",
|
||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
||||
"token_endpoint": "https://auth.example.com/token",
|
||||
"registration_endpoint": "https://auth.example.com/register",
|
||||
"code_challenge_methods_supported": ["S256"],
|
||||
}
|
||||
|
||||
resp = _mock_response(server_metadata, status=200)
|
||||
|
||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
||||
instance = MockRequests.return_value
|
||||
instance.get = AsyncMock(return_value=resp)
|
||||
|
||||
result = await client.discover_auth_server_metadata(
|
||||
"https://auth.example.com"
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
|
||||
assert result["token_endpoint"] == "https://auth.example.com/token"
|
||||
@@ -1,162 +0,0 @@
|
||||
"""
|
||||
Minimal MCP server for integration testing.
|
||||
|
||||
Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
|
||||
with a few sample tools. Runs on localhost with a random available port.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sample tools this test server exposes
|
||||
TEST_TOOLS = [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get current weather for a city",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "add_numbers",
|
||||
"description": "Add two numbers together",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number", "description": "First number"},
|
||||
"b": {"type": "number", "description": "Second number"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "echo",
|
||||
"description": "Echo back the input message",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {"type": "string", "description": "Message to echo"},
|
||||
},
|
||||
"required": ["message"],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _handle_initialize(params: dict) -> dict:
|
||||
return {
|
||||
"protocolVersion": "2025-03-26",
|
||||
"capabilities": {"tools": {"listChanged": False}},
|
||||
"serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
|
||||
}
|
||||
|
||||
|
||||
def _handle_tools_list(params: dict) -> dict:
|
||||
return {"tools": TEST_TOOLS}
|
||||
|
||||
|
||||
def _handle_tools_call(params: dict) -> dict:
|
||||
tool_name = params.get("name", "")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if tool_name == "get_weather":
|
||||
city = arguments.get("city", "Unknown")
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(
|
||||
{"city": city, "temperature": 22, "condition": "sunny"}
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
elif tool_name == "add_numbers":
|
||||
a = arguments.get("a", 0)
|
||||
b = arguments.get("b", 0)
|
||||
return {
|
||||
"content": [{"type": "text", "text": json.dumps({"result": a + b})}],
|
||||
}
|
||||
|
||||
elif tool_name == "echo":
|
||||
message = arguments.get("message", "")
|
||||
return {
|
||||
"content": [{"type": "text", "text": message}],
|
||||
}
|
||||
|
||||
else:
|
||||
return {
|
||||
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
|
||||
"isError": True,
|
||||
}
|
||||
|
||||
|
||||
HANDLERS = {
|
||||
"initialize": _handle_initialize,
|
||||
"tools/list": _handle_tools_list,
|
||||
"tools/call": _handle_tools_call,
|
||||
}
|
||||
|
||||
|
||||
async def handle_mcp_request(request: web.Request) -> web.Response:
|
||||
"""Handle incoming MCP JSON-RPC 2.0 requests."""
|
||||
# Check auth if configured
|
||||
expected_token = request.app.get("auth_token")
|
||||
if expected_token:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header != f"Bearer {expected_token}":
|
||||
return web.json_response(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32001, "message": "Unauthorized"},
|
||||
"id": None,
|
||||
},
|
||||
status=401,
|
||||
)
|
||||
|
||||
body = await request.json()
|
||||
|
||||
# Handle notifications (no id field) — just acknowledge
|
||||
if "id" not in body:
|
||||
return web.Response(status=202)
|
||||
|
||||
method = body.get("method", "")
|
||||
params = body.get("params", {})
|
||||
request_id = body.get("id")
|
||||
|
||||
handler = HANDLERS.get(method)
|
||||
if not handler:
|
||||
return web.json_response(
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": f"Method not found: {method}",
|
||||
},
|
||||
"id": request_id,
|
||||
}
|
||||
)
|
||||
|
||||
result = handler(params)
|
||||
return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})
|
||||
|
||||
|
||||
def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
|
||||
"""Create an aiohttp app that acts as an MCP server."""
|
||||
app = web.Application()
|
||||
app.router.add_post("/mcp", handle_mcp_request)
|
||||
if auth_token:
|
||||
app["auth_token"] = auth_token
|
||||
return app
|
||||
@@ -30,7 +30,6 @@ class ProviderName(str, Enum):
|
||||
IDEOGRAM = "ideogram"
|
||||
JINA = "jina"
|
||||
LLAMA_API = "llama_api"
|
||||
MCP = "mcp"
|
||||
MEDIUM = "medium"
|
||||
MEM0 = "mem0"
|
||||
NOTION = "notion"
|
||||
|
||||
101
autogpt_platform/backend/poetry.lock
generated
101
autogpt_platform/backend/poetry.lock
generated
@@ -896,6 +896,29 @@ files = [
|
||||
{file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "claude-agent-sdk"
|
||||
version = "0.1.33"
|
||||
description = "Python SDK for Claude Code"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "claude_agent_sdk-0.1.33-py3-none-macosx_11_0_arm64.whl", hash = "sha256:57886a2dd124e5b3c9e12ec3e4841742ab3444d1e428b45ceaec8841c96698fa"},
|
||||
{file = "claude_agent_sdk-0.1.33-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:ea0f1e4fadeec766000122723c406a6f47c6210ea11bb5cc0c88af11ef7c940c"},
|
||||
{file = "claude_agent_sdk-0.1.33-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:0ecd822c577b4ea2a52e51146a24dcea73eb69ff366bdb875785dadb116d593b"},
|
||||
{file = "claude_agent_sdk-0.1.33-py3-none-win_amd64.whl", hash = "sha256:a9fbd09d8f947005e087340ecd0706ed35639c946b4bd49429d3132db4cb3751"},
|
||||
{file = "claude_agent_sdk-0.1.33.tar.gz", hash = "sha256:134bf403bb7553d829dadec42c30ecef340f5d4ad1595c1bdef933a9ca3129cf"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=4.0.0"
|
||||
mcp = ">=0.1.0"
|
||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "cleo"
|
||||
version = "2.1.0"
|
||||
@@ -2562,6 +2585,18 @@ http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
zstd = ["zstandard (>=0.18.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx-sse"
|
||||
version = "0.4.3"
|
||||
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"},
|
||||
{file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "1.4.1"
|
||||
@@ -3279,6 +3314,39 @@ files = [
|
||||
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mcp"
|
||||
version = "1.26.0"
|
||||
description = "Model Context Protocol SDK"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"},
|
||||
{file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=4.5"
|
||||
httpx = ">=0.27.1"
|
||||
httpx-sse = ">=0.4"
|
||||
jsonschema = ">=4.20.0"
|
||||
pydantic = ">=2.11.0,<3.0.0"
|
||||
pydantic-settings = ">=2.5.2"
|
||||
pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
|
||||
python-multipart = ">=0.0.9"
|
||||
pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""}
|
||||
sse-starlette = ">=1.6.1"
|
||||
starlette = ">=0.27"
|
||||
typing-extensions = ">=4.9.0"
|
||||
typing-inspection = ">=0.4.1"
|
||||
uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""}
|
||||
|
||||
[package.extras]
|
||||
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"]
|
||||
rich = ["rich (>=13.9.4)"]
|
||||
ws = ["websockets (>=15.0.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "mdurl"
|
||||
version = "0.1.2"
|
||||
@@ -5961,7 +6029,7 @@ description = "Python for Window Extensions"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
markers = "platform_system == \"Windows\""
|
||||
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
|
||||
files = [
|
||||
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
|
||||
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
|
||||
@@ -6006,13 +6074,6 @@ optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"},
|
||||
{file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"},
|
||||
{file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"},
|
||||
{file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"},
|
||||
{file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"},
|
||||
@@ -6942,6 +7003,28 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
|
||||
pymysql = ["pymysql"]
|
||||
sqlcipher = ["sqlcipher3_binary"]
|
||||
|
||||
[[package]]
|
||||
name = "sse-starlette"
|
||||
version = "3.2.0"
|
||||
description = "SSE plugin for Starlette"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"},
|
||||
{file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
anyio = ">=4.7.0"
|
||||
starlette = ">=0.49.1"
|
||||
|
||||
[package.extras]
|
||||
daphne = ["daphne (>=4.2.0)"]
|
||||
examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"]
|
||||
granian = ["granian (>=2.3.1)"]
|
||||
uvicorn = ["uvicorn (>=0.34.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "stagehand"
|
||||
version = "0.5.9"
|
||||
@@ -8382,4 +8465,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "40b2c87c3c86bd10214bd30ad291cead75da5060ab894105025ee4c0a3b3828e"
|
||||
content-hash = "2e2541233117d1f048be2d3c701fb8d5577b445002c0017362027f278a1a4d06"
|
||||
|
||||
@@ -13,6 +13,7 @@ aio-pika = "^9.5.5"
|
||||
aiohttp = "^3.10.0"
|
||||
aiodns = "^3.5.0"
|
||||
anthropic = "^0.59.0"
|
||||
claude-agent-sdk = "^0.1.0"
|
||||
apscheduler = "^3.11.1"
|
||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||
|
||||
@@ -467,7 +467,6 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
||||
| [Github Update Comment](block-integrations/github/issues.md#github-update-comment) | A block that updates an existing comment on a GitHub issue or pull request |
|
||||
| [Github Update File](block-integrations/github/repo.md#github-update-file) | This block updates an existing file in a GitHub repository |
|
||||
| [Instantiate Code Sandbox](block-integrations/misc.md#instantiate-code-sandbox) | Instantiate a sandbox environment with internet access in which you can execute code with the Execute Code Step block |
|
||||
| [MCP Tool](block-integrations/mcp/block.md#mcp-tool) | Connect to any MCP server and execute its tools |
|
||||
| [Slant3D Order Webhook](block-integrations/slant3d/webhook.md#slant3d-order-webhook) | This block triggers on Slant3D order status updates and outputs the event details, including tracking information when orders are shipped |
|
||||
|
||||
## Media Generation
|
||||
|
||||
@@ -84,7 +84,6 @@
|
||||
* [Linear Projects](block-integrations/linear/projects.md)
|
||||
* [LLM](block-integrations/llm.md)
|
||||
* [Logic](block-integrations/logic.md)
|
||||
* [Mcp Block](block-integrations/mcp/block.md)
|
||||
* [Misc](block-integrations/misc.md)
|
||||
* [Notion Create Page](block-integrations/notion/create_page.md)
|
||||
* [Notion Read Database](block-integrations/notion/read_database.md)
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
# Mcp Block
|
||||
<!-- MANUAL: file_description -->
|
||||
_Add a description of this category of blocks._
|
||||
<!-- END MANUAL -->
|
||||
|
||||
## MCP Tool
|
||||
|
||||
### What it is
|
||||
Connect to any MCP server and execute its tools. Provide a server URL, select a tool, and pass arguments dynamically.
|
||||
|
||||
### How it works
|
||||
<!-- MANUAL: how_it_works -->
|
||||
_Add technical explanation here._
|
||||
<!-- END MANUAL -->
|
||||
|
||||
### Inputs
|
||||
|
||||
| Input | Description | Type | Required |
|
||||
|-------|-------------|------|----------|
|
||||
| server_url | URL of the MCP server (Streamable HTTP endpoint) | str | Yes |
|
||||
| selected_tool | The MCP tool to execute | str | No |
|
||||
| tool_arguments | Arguments to pass to the selected MCP tool. The fields here are defined by the tool's input schema. | Dict[str, Any] | No |
|
||||
|
||||
### Outputs
|
||||
|
||||
| Output | Description | Type |
|
||||
|--------|-------------|------|
|
||||
| error | Error message if the tool call failed | str |
|
||||
| result | The result returned by the MCP tool | Result |
|
||||
|
||||
### Possible use case
|
||||
<!-- MANUAL: use_case -->
|
||||
_Add practical use case examples here._
|
||||
<!-- END MANUAL -->
|
||||
|
||||
---
|
||||
Reference in New Issue
Block a user