Compare commits

..

9 Commits

Author SHA1 Message Date
Zamil Majdy
0bbe8a184d Merge dev and resolve poetry.lock conflict 2026-02-08 19:40:17 +04:00
Zamil Majdy
7592deed63 fix(backend/chat): Address remaining PR review comments
- Fix tool_call_id always being "sdk-call" by generating unique IDs per invocation
- Fix validation using original tool_name instead of clean_name in security hooks
- Fix duplicate StreamFinish in Anthropic fallback path
- Fix ImportError fallback returning plain dict instead of re-raising
- Extract _build_input_schema helper to deduplicate schema construction
- Add else branch for unhandled SDK message types for observability
- Truncate large tool results in conversation history to prevent context overflow
2026-02-08 19:39:10 +04:00
Zamil Majdy
b9c759ce4f fix(backend/chat): Address additional PR review comments
- Add terminal StreamFinish in adapt_sdk_stream if SDK ends without one
- Sanitize error message in adapt_sdk_stream exception handler
- Pass full JSON schema (type, properties, required) to tool decorator
2026-02-08 07:14:45 +04:00
Zamil Majdy
5efb80d47b fix(backend/chat): Address PR review comments for Claude SDK integration
- Add StreamFinish after ErrorMessage in response adapter
- Fix str.replace to removeprefix in security hooks
- Apply max_context_messages limit as safety guard in history formatting
- Add empty prompt guard before sending to SDK
- Sanitize error messages to avoid exposing internal details
- Fix fire-and-forget asyncio.create_task by storing task reference
- Fix tool_calls population on assistant messages
- Rewrite Anthropic fallback to persist messages and merge consecutive roles
- Only use ANTHROPIC_API_KEY for fallback (not OpenRouter keys)
- Fix IndexError when tool result content list is empty
2026-02-06 13:25:10 +04:00
Zamil Majdy
b49d8e2cba fix lock 2026-02-06 13:19:53 +04:00
Zamil Majdy
452544530d feat(chat/sdk): Enable native SDK context compaction
- Remove manual truncation in conversation history formatting
- SDK's automatic compaction handles context limits intelligently
- Add observability hooks:
  - PreCompact: Log when SDK triggers context compaction
  - PostToolUse: Log successful tool executions
  - PostToolUseFailure: Log and debug failed tool executions
- Update config: increase max_context_messages (SDK handles compaction)
2026-02-06 12:44:48 +04:00
Zamil Majdy
32ee7e6cf8 fix(chat): Remove aggressive stale task detection
The 60-second timeout was too aggressive and could incorrectly mark
legitimate long-running tool calls as stale. Relying on Redis TTL
(1 hour) for cleanup is sufficient and more reliable.
2026-02-06 11:45:54 +04:00
Zamil Majdy
670663c406 Merge dev and resolve poetry.lock conflict 2026-02-06 11:40:41 +04:00
Zamil Majdy
0dbe4cf51e feat(backend/chat): Add Claude Agent SDK integration for CoPilot
This PR adds Claude Agent SDK as the default backend for CoPilot chat completions,
replacing the direct OpenAI API integration.

Key changes:
- Add Claude Agent SDK service layer with MCP tool adapter
- Fix message persistence after tool calls (messages no longer disappear on refresh)
- Add OpenRouter tracing for session title generation
- Add security hooks for user context validation
- Add Anthropic fallback when SDK is not available
- Clean up excessive debug logging
2026-02-06 11:38:17 +04:00
27 changed files with 1855 additions and 2522 deletions

View File

@@ -1,76 +0,0 @@
# MCP Block Implementation Plan
## Overview
Create a single **MCPBlock** that dynamically integrates with any MCP (Model Context Protocol)
server. Users provide a server URL, the block discovers available tools, presents them as a
dropdown, and dynamically adjusts input/output schema based on the selected tool — exactly like
`AgentExecutorBlock` handles dynamic schemas.
## Architecture
```
User provides MCP server URL + credentials
MCPBlock fetches tools via MCP protocol (tools/list)
User selects tool from dropdown (stored in constantInput)
Input schema dynamically updates based on selected tool's inputSchema
On execution: MCPBlock calls the tool via MCP protocol (tools/call)
Result yielded as block output
```
## Design Decisions
1. **Single block, not many blocks** — One `MCPBlock` handles all MCP servers/tools
2. **Dynamic schema via AgentExecutorBlock pattern** — Override `get_input_schema()`,
`get_input_defaults()`, `get_missing_input()` on the Input class
3. **Auth via API key or OAuth2 credentials** — Use existing `APIKeyCredentials` or
`OAuth2Credentials` with `ProviderName.MCP` provider. API keys are sent as Bearer tokens;
OAuth2 uses the access token.
4. **HTTP-based MCP client** — Use `aiohttp` (already a dependency) to implement MCP Streamable
HTTP transport directly. No need for the `mcp` Python SDK — the protocol is simple JSON-RPC
over HTTP. Handles both JSON and SSE response formats.
5. **No new DB tables** — Everything fits in existing `AgentBlock` + `AgentNode` tables
## Implementation Files
### New Files
- `backend/blocks/mcp/` — MCP block package
- `__init__.py`
- `block.py` — MCPToolBlock implementation
- `client.py` — MCP HTTP client (list_tools, call_tool)
- `oauth.py` — MCP OAuth handler for dynamic endpoint discovery
- `test_mcp.py` — Unit tests
- `test_oauth.py` — OAuth handler tests
- `test_integration.py` — Integration tests with local test server
- `test_e2e.py` — E2E tests against real MCP servers
### Modified Files
- `backend/integrations/providers.py` — Add `MCP = "mcp"` to ProviderName
## Dev Loop
```bash
cd autogpt_platform/backend
poetry run pytest backend/blocks/mcp/test_mcp.py -xvs # Unit tests
poetry run pytest backend/blocks/mcp/test_oauth.py -xvs # OAuth tests
poetry run pytest backend/blocks/mcp/test_integration.py -xvs # Integration tests
poetry run pytest backend/blocks/mcp/ -xvs # All MCP tests
```
## Status
- [x] Research & Design
- [x] Add ProviderName.MCP
- [x] Implement MCP client (client.py)
- [x] Implement MCPToolBlock (block.py)
- [x] Add OAuth2 support (oauth.py)
- [x] Write unit tests
- [x] Write integration tests
- [x] Write E2E tests
- [x] Run tests & fix issues
- [x] Create PR

View File

@@ -27,12 +27,20 @@ class ChatConfig(BaseSettings):
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
# Note: When using Claude Agent SDK, context management is handled automatically
# via the SDK's built-in compaction. This is mainly used for the fallback path.
max_context_messages: int = Field(
default=50, ge=1, le=200, description="Maximum context messages"
default=100,
ge=1,
le=500,
description="Max context messages (SDK handles compaction automatically)",
)
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_retries: int = Field(
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
)
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules"
@@ -93,6 +101,12 @@ class ChatConfig(BaseSettings):
description="Name of the prompt in Langfuse to fetch",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
description="Use Claude Agent SDK for chat completions",
)
@field_validator("api_key", mode="before")
@classmethod
def get_api_key(cls, v):
@@ -132,6 +146,17 @@ class ChatConfig(BaseSettings):
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):
"""Get use_claude_agent_sdk from environment if not provided."""
# Check environment variable - default to True if not set
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
# Default to True (SDK enabled by default)
return True if v is None else v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -273,9 +273,8 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
)
return session
except Exception as e:
@@ -317,11 +316,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
return None
messages = prisma_session.Messages
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
logger.debug(
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
)
return ChatSession.from_db(prisma_session, messages)
@@ -372,10 +369,9 @@ async def _save_session_to_db(
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
logger.debug(
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
f"roles={[m['role'] for m in messages_data]}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
@@ -415,7 +411,7 @@ async def get_chat_session(
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
# Fall back to database
logger.info(f"Session {session_id} not in cache, checking database")
logger.debug(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id)
if session is None:
@@ -432,7 +428,6 @@ async def get_chat_session(
# Cache the session from DB
try:
await _cache_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
@@ -603,13 +598,19 @@ async def update_session_title(session_id: str, title: str) -> bool:
logger.warning(f"Session {session_id} not found for title update")
return False
# Invalidate cache so next fetch gets updated title
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
try:
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await _cache_session(cached)
except Exception as e:
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
# Not critical - title will be correct on next full cache refresh
logger.warning(
f"Failed to update title in cache for session {session_id}: {e}"
)
return True
except Exception as e:

View File

@@ -1,5 +1,6 @@
"""Chat API routes for chat session management and streaming via SSE."""
import asyncio
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
@@ -16,8 +17,17 @@ from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .model import (
ChatMessage,
ChatSession,
create_chat_session,
get_chat_session,
get_user_sessions,
upsert_chat_session,
)
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
from .sdk import service as sdk_service
from .tracking import track_user_message
config = ChatConfig()
@@ -209,6 +219,10 @@ async def get_session(
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
@@ -265,10 +279,30 @@ async def stream_chat_post(
containing the task_id for reconnection.
"""
import asyncio
session = await _validate_and_get_session(session_id, user_id)
# Add user message to session BEFORE creating task to avoid race condition
# where GET_SESSION sees the task as "running" but the message isn't saved yet
if request.message:
session.messages.append(
ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(
f"[STREAM] Saving user message to session {session_id}, "
f"msg_count={len(session.messages)}"
)
session = await upsert_chat_session(session)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
@@ -283,24 +317,38 @@ async def stream_chat_post(
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
chunk_count = 0
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
async for chunk in chat_service.stream_chat_completion(
# Choose service based on configuration
use_sdk = config.use_claude_agent_sdk
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else chat_service.stream_chat_completion
)
# Pass message=None since we already added it to the session above
async for chunk in stream_fn(
session_id,
request.message,
None, # Message already in session
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
session=session, # Pass session with message already added
context=request.context,
):
chunk_count += 1
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)
# Mark task as completed
await stream_registry.mark_task_completed(task_id, "completed")
logger.info(
f"[BG_TASK] AI generation completed for session {session_id}: {chunk_count} chunks, marking task {task_id} as completed"
)
# Mark task as completed (also publishes StreamFinish)
completed = await stream_registry.mark_task_completed(task_id, "completed")
logger.info(f"[BG_TASK] mark_task_completed returned: {completed}")
except Exception as e:
logger.error(
f"Error in background AI generation for session {session_id}: {e}"
@@ -315,7 +363,7 @@ async def stream_chat_post(
async def event_generator() -> AsyncGenerator[str, None]:
subscriber_queue = None
try:
# Subscribe to the task stream (this replays existing messages + live updates)
# Subscribe to the task stream (replays + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
@@ -323,6 +371,7 @@ async def stream_chat_post(
)
if subscriber_queue is None:
logger.warning(f"Failed to subscribe to task {task_id}")
yield StreamFinish().to_sse()
yield "data: [DONE]\n\n"
return
@@ -341,11 +390,11 @@ async def stream_chat_post(
yield StreamHeartbeat().to_sse()
except GeneratorExit:
pass # Client disconnected - background task continues
pass # Client disconnected - normal behavior
except Exception as e:
logger.error(f"Error in SSE stream for task {task_id}: {e}")
finally:
# Unsubscribe when client disconnects or stream ends to prevent resource leak
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
@@ -400,35 +449,21 @@ async def stream_chat_get(
session = await _validate_and_get_session(session_id, user_id)
async def event_generator() -> AsyncGenerator[str, None]:
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
# Choose service based on configuration
use_sdk = config.use_claude_agent_sdk
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else chat_service.stream_chat_completion
)
async for chunk in stream_fn(
session_id,
message,
is_user_message=is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
@@ -550,8 +585,6 @@ async def stream_task(
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:

View File

@@ -0,0 +1,14 @@
"""Claude Agent SDK integration for CoPilot.
This module provides the integration layer between the Claude Agent SDK
and the existing CoPilot tool system, enabling drop-in replacement of
the current LLM orchestration with the battle-tested Claude Agent SDK.
"""
from .service import stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server
__all__ = [
"stream_chat_completion_sdk",
"create_copilot_mcp_server",
]

View File

@@ -0,0 +1,348 @@
"""Anthropic SDK fallback implementation.
This module provides the fallback streaming implementation using the Anthropic SDK
directly when the Claude Agent SDK is not available.
"""
import json
import logging
import os
import uuid
from collections.abc import AsyncGenerator
from typing import Any, cast
from ..model import ChatMessage, ChatSession
from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from .tool_adapter import get_tool_definitions, get_tool_handlers
logger = logging.getLogger(__name__)
async def stream_with_anthropic(
session: ChatSession,
system_prompt: str,
text_block_id: str,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream using Anthropic SDK directly with tool calling support.
This function accumulates messages into the session for persistence.
The caller should NOT yield an additional StreamFinish - this function handles it.
"""
import anthropic
# Only use ANTHROPIC_API_KEY - don't fall back to OpenRouter keys
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
yield StreamError(
errorText="ANTHROPIC_API_KEY not configured for fallback",
code="config_error",
)
yield StreamFinish()
return
client = anthropic.AsyncAnthropic(api_key=api_key)
tool_definitions = get_tool_definitions()
tool_handlers = get_tool_handlers()
anthropic_tools = [
{
"name": t["name"],
"description": t["description"],
"input_schema": t["inputSchema"],
}
for t in tool_definitions
]
anthropic_messages = _convert_session_to_anthropic(session)
if not anthropic_messages or anthropic_messages[-1]["role"] != "user":
anthropic_messages.append(
{"role": "user", "content": "Continue with the task."}
)
has_started_text = False
max_iterations = 10
accumulated_text = ""
accumulated_tool_calls: list[dict[str, Any]] = []
for _ in range(max_iterations):
try:
async with client.messages.stream(
model="claude-sonnet-4-20250514",
max_tokens=4096,
system=system_prompt,
messages=cast(Any, anthropic_messages),
tools=cast(Any, anthropic_tools) if anthropic_tools else [],
) as stream:
async for event in stream:
if event.type == "content_block_start":
block = event.content_block
if hasattr(block, "type"):
if block.type == "text" and not has_started_text:
yield StreamTextStart(id=text_block_id)
has_started_text = True
elif block.type == "tool_use":
yield StreamToolInputStart(
toolCallId=block.id, toolName=block.name
)
elif event.type == "content_block_delta":
delta = event.delta
if hasattr(delta, "type") and delta.type == "text_delta":
accumulated_text += delta.text
yield StreamTextDelta(id=text_block_id, delta=delta.text)
final_message = await stream.get_final_message()
if final_message.stop_reason == "tool_use":
if has_started_text:
yield StreamTextEnd(id=text_block_id)
has_started_text = False
text_block_id = str(uuid.uuid4())
tool_results = []
assistant_content: list[dict[str, Any]] = []
for block in final_message.content:
if block.type == "text":
assistant_content.append(
{"type": "text", "text": block.text}
)
elif block.type == "tool_use":
assistant_content.append(
{
"type": "tool_use",
"id": block.id,
"name": block.name,
"input": block.input,
}
)
# Track tool call for session persistence
accumulated_tool_calls.append(
{
"id": block.id,
"type": "function",
"function": {
"name": block.name,
"arguments": json.dumps(
block.input
if isinstance(block.input, dict)
else {}
),
},
}
)
yield StreamToolInputAvailable(
toolCallId=block.id,
toolName=block.name,
input=(
block.input if isinstance(block.input, dict) else {}
),
)
output, is_error = await _execute_tool(
block.name, block.input, tool_handlers
)
yield StreamToolOutputAvailable(
toolCallId=block.id,
toolName=block.name,
output=output,
success=not is_error,
)
# Save tool result to session
session.messages.append(
ChatMessage(
role="tool",
content=output,
tool_call_id=block.id,
)
)
tool_results.append(
{
"type": "tool_result",
"tool_use_id": block.id,
"content": output,
"is_error": is_error,
}
)
# Save assistant message with tool calls to session
session.messages.append(
ChatMessage(
role="assistant",
content=accumulated_text or None,
tool_calls=(
accumulated_tool_calls
if accumulated_tool_calls
else None
),
)
)
# Reset for next iteration
accumulated_text = ""
accumulated_tool_calls = []
anthropic_messages.append(
{"role": "assistant", "content": assistant_content}
)
anthropic_messages.append({"role": "user", "content": tool_results})
continue
else:
if has_started_text:
yield StreamTextEnd(id=text_block_id)
# Save final assistant response to session
if accumulated_text:
session.messages.append(
ChatMessage(role="assistant", content=accumulated_text)
)
yield StreamUsage(
promptTokens=final_message.usage.input_tokens,
completionTokens=final_message.usage.output_tokens,
totalTokens=final_message.usage.input_tokens
+ final_message.usage.output_tokens,
)
yield StreamFinish()
return
except Exception as e:
logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True)
yield StreamError(
errorText="An error occurred. Please try again.",
code="anthropic_error",
)
yield StreamFinish()
return
yield StreamError(errorText="Max tool iterations reached", code="max_iterations")
yield StreamFinish()
def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
"""Convert session messages to Anthropic format.
Handles merging consecutive same-role messages (Anthropic requires alternating roles).
"""
messages: list[dict[str, Any]] = []
for msg in session.messages:
if msg.role == "user":
new_msg = {"role": "user", "content": msg.content or ""}
elif msg.role == "assistant":
content: list[dict[str, Any]] = []
if msg.content:
content.append({"type": "text", "text": msg.content})
if msg.tool_calls:
for tc in msg.tool_calls:
func = tc.get("function", {})
args = func.get("arguments", {})
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {}
content.append(
{
"type": "tool_use",
"id": tc.get("id", str(uuid.uuid4())),
"name": func.get("name", ""),
"input": args,
}
)
if content:
new_msg = {"role": "assistant", "content": content}
else:
continue # Skip empty assistant messages
elif msg.role == "tool":
new_msg = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": msg.tool_call_id or "",
"content": msg.content or "",
}
],
}
else:
continue
messages.append(new_msg)
# Merge consecutive same-role messages (Anthropic requires alternating roles)
return _merge_consecutive_roles(messages)
def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Merge consecutive messages with the same role.
Anthropic API requires alternating user/assistant roles.
"""
if not messages:
return []
merged: list[dict[str, Any]] = []
for msg in messages:
if merged and merged[-1]["role"] == msg["role"]:
# Merge with previous message
prev_content = merged[-1]["content"]
new_content = msg["content"]
# Normalize both to list-of-blocks form
if isinstance(prev_content, str):
prev_content = [{"type": "text", "text": prev_content}]
if isinstance(new_content, str):
new_content = [{"type": "text", "text": new_content}]
# Ensure both are lists
if not isinstance(prev_content, list):
prev_content = [prev_content]
if not isinstance(new_content, list):
new_content = [new_content]
merged[-1]["content"] = prev_content + new_content
else:
merged.append(msg)
return merged
async def _execute_tool(
tool_name: str, tool_input: Any, handlers: dict[str, Any]
) -> tuple[str, bool]:
"""Execute a tool and return (output, is_error)."""
handler = handlers.get(tool_name)
if not handler:
return f"Unknown tool: {tool_name}", True
try:
result = await handler(tool_input)
# Safely extract output - handle empty or missing content
content = result.get("content") or []
if content and isinstance(content, list) and len(content) > 0:
first_item = content[0]
output = first_item.get("text", "") if isinstance(first_item, dict) else ""
else:
output = ""
is_error = result.get("isError", False)
return output, is_error
except Exception as e:
return f"Error: {str(e)}", True

View File

@@ -0,0 +1,311 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""
import json
import logging
import uuid
from typing import Any, AsyncGenerator
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
logger = logging.getLogger(__name__)
class SDKResponseAdapter:
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This class maintains state during a streaming session to properly track
text blocks, tool calls, and message lifecycle.
"""
def __init__(self, message_id: str | None = None):
"""Initialize the adapter.
Args:
message_id: Optional message ID. If not provided, one will be generated.
"""
self.message_id = message_id or str(uuid.uuid4())
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, Any]] = {}
self.task_id: str | None = None
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
def convert_message(self, sdk_message: Any) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format.
Args:
sdk_message: A message from the Claude Agent SDK.
Returns:
List of StreamBaseResponse objects (may be empty or multiple).
"""
responses: list[StreamBaseResponse] = []
# Handle different SDK message types - use class name since SDK uses dataclasses
class_name = type(sdk_message).__name__
msg_subtype = getattr(sdk_message, "subtype", None)
if class_name == "SystemMessage":
if msg_subtype == "init":
# Session initialization - emit start
responses.append(
StreamStart(
messageId=self.message_id,
taskId=self.task_id,
)
)
elif class_name == "AssistantMessage":
# Assistant message with content blocks
content = getattr(sdk_message, "content", [])
for block in content:
# Check block type by class name (SDK uses dataclasses) or dict type
block_class = type(block).__name__
block_type = block.get("type") if isinstance(block, dict) else None
if block_class == "TextBlock" or block_type == "text":
# Text content
text = getattr(block, "text", None) or (
block.get("text") if isinstance(block, dict) else ""
)
if text:
# Start text block if needed (or restart after tool calls)
if not self.has_started_text or self.has_ended_text:
# Generate new text block ID for text after tools
if self.has_ended_text:
self.text_block_id = str(uuid.uuid4())
self.has_ended_text = False
responses.append(StreamTextStart(id=self.text_block_id))
self.has_started_text = True
# Emit text delta
responses.append(
StreamTextDelta(
id=self.text_block_id,
delta=text,
)
)
elif block_class == "ToolUseBlock" or block_type == "tool_use":
# Tool call
tool_id_raw = getattr(block, "id", None) or (
block.get("id") if isinstance(block, dict) else None
)
tool_id: str = (
str(tool_id_raw) if tool_id_raw else str(uuid.uuid4())
)
tool_name_raw = getattr(block, "name", None) or (
block.get("name") if isinstance(block, dict) else None
)
tool_name: str = str(tool_name_raw) if tool_name_raw else "unknown"
tool_input = getattr(block, "input", None) or (
block.get("input") if isinstance(block, dict) else {}
)
# End text block if we were streaming text
if self.has_started_text and not self.has_ended_text:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
# Emit tool input start
responses.append(
StreamToolInputStart(
toolCallId=tool_id,
toolName=tool_name,
)
)
# Emit tool input available with full input
responses.append(
StreamToolInputAvailable(
toolCallId=tool_id,
toolName=tool_name,
input=tool_input if isinstance(tool_input, dict) else {},
)
)
# Track the tool call
self.current_tool_calls[tool_id] = {
"name": tool_name,
"input": tool_input,
}
elif class_name in ("ToolResultMessage", "UserMessage"):
# Tool result - check for tool_result content
content = getattr(sdk_message, "content", [])
for block in content:
block_class = type(block).__name__
block_type = block.get("type") if isinstance(block, dict) else None
if block_class == "ToolResultBlock" or block_type == "tool_result":
tool_use_id = getattr(block, "tool_use_id", None) or (
block.get("tool_use_id") if isinstance(block, dict) else None
)
result_content = getattr(block, "content", None) or (
block.get("content") if isinstance(block, dict) else ""
)
is_error = getattr(block, "is_error", False) or (
block.get("is_error", False)
if isinstance(block, dict)
else False
)
if tool_use_id:
tool_info = self.current_tool_calls.get(tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
# Format the output
if isinstance(result_content, list):
# Extract text from content blocks
output_text = ""
for item in result_content:
if (
isinstance(item, dict)
and item.get("type") == "text"
):
output_text += item.get("text", "")
elif hasattr(item, "text"):
output_text += getattr(item, "text", "")
output = output_text
elif isinstance(result_content, str):
output = result_content
else:
output = json.dumps(result_content)
responses.append(
StreamToolOutputAvailable(
toolCallId=tool_use_id,
toolName=tool_name,
output=output,
success=not is_error,
)
)
elif class_name == "ResultMessage":
# Final result
if msg_subtype == "success":
# End text block if still open
if self.has_started_text and not self.has_ended_text:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
# Emit finish
responses.append(StreamFinish())
elif msg_subtype in ("error", "error_during_execution"):
error_msg = getattr(sdk_message, "error", "Unknown error")
responses.append(
StreamError(
errorText=str(error_msg),
code="sdk_error",
)
)
responses.append(StreamFinish())
elif class_name == "ErrorMessage":
# Error message
error_msg = getattr(sdk_message, "message", None) or getattr(
sdk_message, "error", "Unknown error"
)
responses.append(
StreamError(
errorText=str(error_msg),
code="sdk_error",
)
)
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {class_name}")
return responses
def create_heartbeat(self, tool_call_id: str | None = None) -> StreamHeartbeat:
"""Create a heartbeat response."""
return StreamHeartbeat(toolCallId=tool_call_id)
def create_usage(
self,
prompt_tokens: int,
completion_tokens: int,
) -> StreamUsage:
"""Create a usage statistics response."""
return StreamUsage(
promptTokens=prompt_tokens,
completionTokens=completion_tokens,
totalTokens=prompt_tokens + completion_tokens,
)
async def adapt_sdk_stream(
sdk_stream: AsyncGenerator[Any, None],
message_id: str | None = None,
task_id: str | None = None,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Adapt a Claude Agent SDK stream to Vercel AI SDK format.
Args:
sdk_stream: The async generator from the Claude Agent SDK.
message_id: Optional message ID for the response.
task_id: Optional task ID for reconnection support.
Yields:
StreamBaseResponse objects in Vercel AI SDK format.
"""
adapter = SDKResponseAdapter(message_id=message_id)
if task_id:
adapter.set_task_id(task_id)
# Emit start immediately
yield StreamStart(messageId=adapter.message_id, taskId=task_id)
finished = False
try:
async for sdk_message in sdk_stream:
responses = adapter.convert_message(sdk_message)
for response in responses:
# Skip duplicate start messages
if isinstance(response, StreamStart):
continue
if isinstance(response, StreamFinish):
finished = True
yield response
except Exception as e:
logger.error(f"Error in SDK stream: {e}", exc_info=True)
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
)
yield StreamFinish()
return
# Ensure terminal StreamFinish if SDK stream ended without one
if not finished:
yield StreamFinish()

View File

@@ -0,0 +1,278 @@
"""Security hooks for Claude Agent SDK integration.
This module provides security hooks that validate tool calls before execution,
ensuring multi-user isolation and preventing unauthorized operations.
"""
import logging
import re
from typing import Any, cast
logger = logging.getLogger(__name__)
# Tools that are blocked entirely (CLI/system access)
BLOCKED_TOOLS = {
"Bash",
"bash",
"shell",
"exec",
"terminal",
"command",
"Read", # Block raw file read - use workspace tools instead
"Write", # Block raw file write - use workspace tools instead
"Edit", # Block raw file edit - use workspace tools instead
"Glob", # Block raw file glob - use workspace tools instead
"Grep", # Block raw file grep - use workspace tools instead
}
# Dangerous patterns in tool inputs
DANGEROUS_PATTERNS = [
r"sudo",
r"rm\s+-rf",
r"dd\s+if=",
r"/etc/passwd",
r"/etc/shadow",
r"chmod\s+777",
r"curl\s+.*\|.*sh",
r"wget\s+.*\|.*sh",
r"eval\s*\(",
r"exec\s*\(",
r"__import__",
r"os\.system",
r"subprocess",
]
def _validate_tool_access(tool_name: str, tool_input: dict[str, Any]) -> dict[str, Any]:
"""Validate that a tool call is allowed.
Returns:
Empty dict to allow, or dict with hookSpecificOutput to deny
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning(f"Blocked tool access attempt: {tool_name}")
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": (
f"Tool '{tool_name}' is not available. "
"Use the CoPilot-specific tools instead."
),
}
}
# Check for dangerous patterns in tool input
input_str = str(tool_input)
for pattern in DANGEROUS_PATTERNS:
if re.search(pattern, input_str, re.IGNORECASE):
logger.warning(
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Input contains blocked pattern",
}
}
return {}
def _validate_user_isolation(
tool_name: str, tool_input: dict[str, Any], user_id: str | None
) -> dict[str, Any]:
"""Validate that tool calls respect user isolation."""
# For workspace file tools, ensure path doesn't escape
if "workspace" in tool_name.lower():
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path:
# Check for path traversal
if ".." in path or path.startswith("/"):
logger.warning(
f"Blocked path traversal attempt: {path} by user {user_id}"
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Path traversal not allowed",
}
}
return {}
def create_security_hooks(user_id: str | None) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
Includes security validation and observability hooks:
- PreToolUse: Security validation before tool execution
- PostToolUse: Log successful tool executions
- PostToolUseFailure: Log and handle failed tool executions
- PreCompact: Log context compaction events (SDK handles compaction automatically)
Args:
user_id: Current user ID for isolation validation
Returns:
Hooks configuration dict for ClaudeAgentOptions
"""
try:
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
async def pre_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Combined pre-tool-use validation hook."""
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Validate basic tool access
result = _validate_tool_access(tool_name, tool_input)
if result:
return cast(SyncHookJSONOutput, result)
# Validate user isolation
result = _validate_user_isolation(tool_name, tool_input, user_id)
if result:
return cast(SyncHookJSONOutput, result)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log successful tool executions for observability."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_failure_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log failed tool executions for debugging."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
error = input_data.get("error", "Unknown error")
logger.warning(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
return cast(SyncHookJSONOutput, {})
async def pre_compact_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log when SDK triggers context compaction.
The SDK automatically compacts conversation history when it grows too large.
This hook provides visibility into when compaction happens.
"""
_ = context, tool_use_id
trigger = input_data.get("trigger", "auto")
logger.info(
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
)
return cast(SyncHookJSONOutput, {})
return {
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
"PostToolUseFailure": [
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
],
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
}
except ImportError:
# Fallback for when SDK isn't available - return empty hooks
return {}
def create_strict_security_hooks(
user_id: str | None,
allowed_tools: list[str] | None = None,
) -> dict[str, Any]:
"""Create strict security hooks that only allow specific tools.
Args:
user_id: Current user ID
allowed_tools: List of allowed tool names (defaults to CoPilot tools)
Returns:
Hooks configuration dict
"""
try:
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
from .tool_adapter import RAW_TOOL_NAMES
tools_list = allowed_tools if allowed_tools is not None else RAW_TOOL_NAMES
allowed_set = set(tools_list)
async def strict_pre_tool_use(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Strict validation that only allows whitelisted tools."""
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Remove MCP prefix if present
clean_name = tool_name.removeprefix("mcp__copilot__")
if clean_name not in allowed_set:
logger.warning(f"Blocked non-whitelisted tool: {tool_name}")
return cast(
SyncHookJSONOutput,
{
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": (
f"Tool '{tool_name}' is not in the allowed list"
),
}
},
)
# Run standard validations using clean_name for consistent checks
result = _validate_tool_access(clean_name, tool_input)
if result:
return cast(SyncHookJSONOutput, result)
result = _validate_user_isolation(clean_name, tool_input, user_id)
if result:
return cast(SyncHookJSONOutput, result)
logger.debug(
f"[SDK Audit] Tool call: tool={tool_name}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
return cast(SyncHookJSONOutput, {})
return {
"PreToolUse": [
HookMatcher(matcher="*", hooks=[strict_pre_tool_use]),
],
}
except ImportError:
return {}

View File

@@ -0,0 +1,475 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""
import asyncio
import json
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any
import openai
from backend.data.understanding import (
format_understanding_for_prompt,
get_business_understanding,
)
from backend.util.exceptions import NotFoundError
from ..config import ChatConfig
from ..model import (
ChatMessage,
ChatSession,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from ..tracking import track_user_message
from .anthropic_fallback import stream_with_anthropic
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
COPILOT_TOOL_NAMES,
create_copilot_mcp_server,
set_execution_context,
)
logger = logging.getLogger(__name__)
config = ChatConfig()
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
Here is everything you know about the current user from previous interactions:
<users_information>
{users_information}
</users_information>
## YOUR CORE MANDATE
You are action-oriented. Your success is measured by:
- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"?
- **Demonstrable Proof**: Show working automations, not descriptions of what's possible
- **Time Saved**: Focus on tangible efficiency gains
- **Quality Output**: Deliver results that meet or exceed expectations
## YOUR WORKFLOW
Adapt flexibly to the conversation context. Not every interaction requires all stages:
1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations.
2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task.
3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.).
4. **Discover or Create Agents**:
- **Always check the user's library first** with `find_library_agent` (these may be customized to their needs)
- Search the marketplace with `find_agent` for pre-built automations
- Find reusable components with `find_block`
- Create custom solutions with `create_agent` if nothing suitable exists
- Modify existing library agents with `edit_agent`
5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.
6. **Show Results**: Display outputs using `agent_output`.
## BEHAVIORAL GUIDELINES
**Be Concise:**
- Target 2-5 short lines maximum
- Make every word count—no repetition or filler
- Use lightweight structure for scannability (bullets, numbered lists, short prompts)
- Avoid jargon (blocks, slugs, cron) unless the user asks
**Be Proactive:**
- Suggest next steps before being asked
- Anticipate needs based on conversation context and user information
- Look for opportunities to expand scope when relevant
- Reveal capabilities through action, not explanation
**Use Tools Effectively:**
- Select the right tool for each task
- **Always check `find_library_agent` before searching the marketplace**
- Use `add_understanding` to capture valuable business context
- When tool calls fail, try alternative approaches
## CRITICAL REMINDER
You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
async def _build_system_prompt(
user_id: str | None, has_conversation_history: bool = False
) -> tuple[str, Any]:
"""Build the system prompt with user's business understanding context.
Args:
user_id: The user ID to fetch understanding for.
has_conversation_history: Whether there's existing conversation history.
If True, we don't tell the model to greet/introduce (since they're
already in a conversation).
"""
understanding = None
if user_id:
try:
understanding = await get_business_understanding(user_id)
except Exception as e:
logger.warning(f"Failed to fetch business understanding: {e}")
if understanding:
context = format_understanding_for_prompt(understanding)
elif has_conversation_history:
# Don't tell model to greet if there's conversation history
context = "No prior understanding saved yet. Continue the existing conversation naturally."
else:
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
return DEFAULT_SYSTEM_PROMPT.format(users_information=context), understanding
def _format_conversation_history(session: ChatSession) -> str:
"""Format conversation history as a prompt context.
The SDK handles context compaction automatically, but we apply
max_context_messages as a safety guard to limit initial prompt size.
"""
if not session.messages:
return ""
# Get all messages except the last user message (which will be the prompt)
messages = session.messages[:-1] if session.messages else []
if not messages:
return ""
# Apply max_context_messages limit as a safety guard
# (SDK handles compaction, but this prevents excessively large initial prompts)
max_messages = config.max_context_messages
if len(messages) > max_messages:
messages = messages[-max_messages:]
history_parts = ["<conversation_history>"]
for msg in messages:
if msg.role == "user":
history_parts.append(f"User: {msg.content or ''}")
elif msg.role == "assistant":
# Pass full content - SDK handles compaction automatically
history_parts.append(f"Assistant: {msg.content or ''}")
if msg.tool_calls:
for tc in msg.tool_calls:
func = tc.get("function", {})
history_parts.append(
f" [Called tool: {func.get('name', 'unknown')}]"
)
elif msg.role == "tool":
# Truncate large tool results to avoid blowing context window
tool_content = msg.content or ""
if len(tool_content) > 500:
tool_content = tool_content[:500] + "... (truncated)"
history_parts.append(f" [Tool result: {tool_content}]")
history_parts.append("</conversation_history>")
history_parts.append("")
history_parts.append(
"Continue this conversation. Respond to the user's latest message:"
)
history_parts.append("")
return "\n".join(history_parts)
async def _generate_session_title(
message: str,
user_id: str | None = None,
session_id: str | None = None,
) -> str | None:
"""Generate a concise title for a chat session."""
from backend.util.settings import Settings
settings = Settings()
try:
# Build extra_body for OpenRouter tracing
extra_body: dict[str, Any] = {
"posthogProperties": {"environment": settings.config.app_env.value},
}
if user_id:
extra_body["user"] = user_id[:128]
extra_body["posthogDistinctId"] = user_id
if session_id:
extra_body["session_id"] = session_id[:128]
client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
response = await client.chat.completions.create(
model=config.title_model,
messages=[
{
"role": "system",
"content": "Generate a very short title (3-6 words) for a chat conversation based on the user's first message. Return ONLY the title, no quotes or punctuation.",
},
{"role": "user", "content": message[:500]},
],
max_tokens=20,
extra_body=extra_body,
)
title = response.choices[0].message.content
if title:
title = title.strip().strip("\"'")
return title[:47] + "..." if len(title) > 50 else title
return None
except Exception as e:
logger.warning(f"Failed to generate session title: {e}")
return None
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None, # noqa: ARG001
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0, # noqa: ARG001
session: ChatSession | None = None,
context: dict[str, str] | None = None, # noqa: ARG001
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream chat completion using Claude Agent SDK.
Drop-in replacement for stream_chat_completion with improved reliability.
"""
if session is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(
f"Session {session_id} not found. Please create a new session first."
)
if message:
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message
)
)
if is_user_message:
track_user_message(
user_id=user_id, session_id=session_id, message_length=len(message)
)
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
# Store reference to prevent garbage collection
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Check if there's conversation history (more than just the current message)
has_history = len(session.messages) > 1
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=has_history
)
set_execution_context(user_id, session, None)
message_id = str(uuid.uuid4())
text_block_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
yield StreamStart(messageId=message_id, taskId=task_id)
# Track whether the stream completed normally via ResultMessage
stream_completed = False
try:
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
# Create MCP server with CoPilot tools
mcp_server = create_copilot_mcp_server()
options = ClaudeAgentOptions(
system_prompt=system_prompt,
mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type]
allowed_tools=COPILOT_TOOL_NAMES,
hooks=create_security_hooks(user_id), # type: ignore[arg-type]
continue_conversation=True, # Enable conversation continuation
)
adapter = SDKResponseAdapter(message_id=message_id)
adapter.set_task_id(task_id)
async with ClaudeSDKClient(options=options) as client:
# Build prompt with conversation history for context
# The SDK doesn't support replaying full conversation history,
# so we include it as context in the prompt
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
# Include conversation history if there are prior messages
if len(session.messages) > 1:
history_context = _format_conversation_history(session)
prompt = f"{history_context}{current_message}"
else:
prompt = current_message
# Guard against empty prompts
if not prompt.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
await client.query(prompt, session_id=session_id)
# Track assistant response to save to session
# We may need multiple assistant messages if text comes after tool results
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False # Track if we've received tool results
# Receive messages from the SDK
async for sdk_msg in client.receive_messages():
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
# Accumulate text deltas into assistant response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, create new assistant message for post-tool text
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = [] # Reset for new message
session.messages.append(assistant_response)
has_tool_results = False
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
# Track tool calls on the assistant message
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
# Update assistant message with tool calls
assistant_response.tool_calls = accumulated_tool_calls
# Append assistant message if not already (tool-only response)
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
# Break out of the message loop if we received finish signal
if stream_completed:
break
# Ensure assistant response is saved even if no text deltas
# (e.g., only tool calls were made)
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
except ImportError:
logger.warning(
"[SDK] claude-agent-sdk not available, using Anthropic fallback"
)
async for response in stream_with_anthropic(
session, system_prompt, text_block_id
):
if isinstance(response, StreamFinish):
stream_completed = True
yield response
# Save the session with accumulated messages
await upsert_chat_session(session)
logger.debug(
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
)
# Yield StreamFinish to signal completion to the caller (routes.py)
# Only if one hasn't already been yielded by the stream
if not stream_completed:
yield StreamFinish()
except Exception as e:
logger.error(f"[SDK] Error: {e}", exc_info=True)
# Save session even on error to preserve any partial response
try:
await upsert_chat_session(session)
except Exception as save_err:
logger.error(f"[SDK] Failed to save session on error: {save_err}")
# Sanitize error message to avoid exposing internal details
yield StreamError(
errorText="An error occurred. Please try again.",
code="sdk_error",
)
yield StreamFinish()
async def _update_title_async(
session_id: str, message: str, user_id: str | None = None
) -> None:
"""Background task to update session title."""
try:
title = await _generate_session_title(
message, user_id=user_id, session_id=session_id
)
if title:
await update_session_title(session_id, title)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -0,0 +1,217 @@
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
"""
import json
import logging
import uuid
from contextvars import ContextVar
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.api.features.chat.tools.base import BaseTool
logger = logging.getLogger(__name__)
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
_current_tool_call_id: ContextVar[str | None] = ContextVar(
"current_tool_call_id", default=None
)
def set_execution_context(
user_id: str | None,
session: ChatSession,
tool_call_id: str | None = None,
) -> None:
"""Set the execution context for tool calls.
This must be called before streaming begins to ensure tools have access
to user_id and session information.
"""
_current_user_id.set(user_id)
_current_session.set(session)
_current_tool_call_id.set(tool_call_id)
def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
_current_tool_call_id.get(),
)
def create_tool_handler(base_tool: BaseTool):
"""Create an async handler function for a BaseTool.
This wraps the existing BaseTool._execute method to be compatible
with the Claude Agent SDK MCP tool format.
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Execute the wrapped tool and return MCP-formatted response."""
user_id, session, tool_call_id = get_execution_context()
if session is None:
return {
"content": [
{
"type": "text",
"text": json.dumps(
{
"error": "No session context available",
"type": "error",
}
),
}
],
"isError": True,
}
try:
# Call the existing tool's execute method
# Generate unique tool_call_id per invocation for proper correlation
effective_id = tool_call_id or f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
session=session,
tool_call_id=effective_id,
**args,
)
# The result is a StreamToolOutputAvailable, extract the output
return {
"content": [
{
"type": "text",
"text": (
result.output
if isinstance(result.output, str)
else json.dumps(result.output)
),
}
],
"isError": not result.success,
}
except Exception as e:
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
return {
"content": [
{
"type": "text",
"text": json.dumps(
{
"error": str(e),
"type": "error",
"message": f"Failed to execute {base_tool.name}",
}
),
}
],
"isError": True,
}
return tool_handler
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
"""Build a JSON Schema input schema for a tool."""
return {
"type": "object",
"properties": base_tool.parameters.get("properties", {}),
"required": base_tool.parameters.get("required", []),
}
def get_tool_definitions() -> list[dict[str, Any]]:
"""Get all tool definitions in MCP format.
Returns a list of tool definitions that can be used with
create_sdk_mcp_server or as raw tool definitions.
"""
tool_definitions = []
for tool_name, base_tool in TOOL_REGISTRY.items():
tool_def = {
"name": tool_name,
"description": base_tool.description,
"inputSchema": _build_input_schema(base_tool),
}
tool_definitions.append(tool_def)
return tool_definitions
def get_tool_handlers() -> dict[str, Any]:
"""Get all tool handlers mapped by name.
Returns a dictionary mapping tool names to their handler functions.
"""
handlers = {}
for tool_name, base_tool in TOOL_REGISTRY.items():
handlers[tool_name] = create_tool_handler(base_tool)
return handlers
# Create the MCP server configuration
def create_copilot_mcp_server():
"""Create an in-process MCP server configuration for CoPilot tools.
This can be passed to ClaudeAgentOptions.mcp_servers.
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
package being available. This function returns the configuration that
can be used with the SDK.
"""
try:
from claude_agent_sdk import create_sdk_mcp_server, tool
# Create decorated tool functions
sdk_tools = []
for tool_name, base_tool in TOOL_REGISTRY.items():
# Get the handler
handler = create_tool_handler(base_tool)
# Create the decorated tool
# The @tool decorator expects (name, description, schema)
# Pass full JSON schema with type, properties, and required
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(handler)
sdk_tools.append(decorated)
# Create the MCP server
server = create_sdk_mcp_server(
name="copilot",
version="1.0.0",
tools=sdk_tools,
)
return server
except ImportError:
# Let ImportError propagate so service.py handles the fallback
raise
# List of tool names for allowed_tools configuration
COPILOT_TOOL_NAMES = [f"mcp__copilot__{name}" for name in TOOL_REGISTRY.keys()]
# Also export the raw tool names for flexibility
RAW_TOOL_NAMES = list(TOOL_REGISTRY.keys())

View File

@@ -555,6 +555,10 @@ async def get_active_task_for_session(
if task_user_id and user_id != task_user_id:
continue
logger.info(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"

View File

@@ -1,265 +0,0 @@
"""
MCP (Model Context Protocol) Tool Block.
A single dynamic block that can connect to any MCP server, discover available tools,
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
dropdown and the input/output schema adapts dynamically.
"""
import json
import logging
from typing import Any, Literal
from pydantic import SecretStr
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.data.block import (
Block,
BlockCategory,
BlockInput,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
BlockType,
)
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
OAuth2Credentials,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.json import validate_with_jsonschema
logger = logging.getLogger(__name__)
MCPCredentials = APIKeyCredentials | OAuth2Credentials
MCPCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.MCP], Literal["api_key", "oauth2"]
]
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="mcp",
api_key=SecretStr("test-mcp-token"),
title="Mock MCP Credentials",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
class MCPToolBlock(Block):
"""
A block that connects to an MCP server, lets the user pick a tool,
and executes it with dynamic input/output schema.
The flow:
1. User provides an MCP server URL (and optional credentials)
2. Frontend calls the backend to get tool list from that URL
3. User selects a tool from a dropdown (available_tools)
4. The block's input schema updates to reflect the selected tool's parameters
5. On execution, the block calls the MCP server to run the tool
"""
class Input(BlockSchemaInput):
# -- Static fields (always shown) --
credentials: MCPCredentialsInput = CredentialsField(
description="Credentials for the MCP server. Use an API key for Bearer "
"token auth, or OAuth2 for servers that support it. For public "
"servers, create a credential with any placeholder value.",
)
server_url: str = SchemaField(
description="URL of the MCP server (Streamable HTTP endpoint)",
placeholder="https://mcp.example.com/mcp",
)
available_tools: dict[str, Any] = SchemaField(
description="Available tools on the MCP server. "
"This is populated automatically when a server URL is provided.",
default={},
hidden=True,
)
selected_tool: str = SchemaField(
description="The MCP tool to execute",
placeholder="Select a tool",
default="",
)
tool_input_schema: dict[str, Any] = SchemaField(
description="JSON Schema for the selected tool's input parameters. "
"Populated automatically when a tool is selected.",
default={},
hidden=True,
)
# -- Dynamic field: actual arguments for the selected tool --
tool_arguments: dict[str, Any] = SchemaField(
description="Arguments to pass to the selected MCP tool. "
"The fields here are defined by the tool's input schema.",
default={},
)
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
"""Return the tool's input schema so the builder UI renders dynamic fields."""
return data.get("tool_input_schema", {})
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
"""Return the current tool_arguments as defaults for the dynamic fields."""
return data.get("tool_arguments", {})
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
"""Check which required tool arguments are missing."""
required_fields = cls.get_input_schema(data).get("required", [])
tool_arguments = data.get("tool_arguments", {})
return set(required_fields) - set(tool_arguments)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
"""Validate tool_arguments against the tool's input schema."""
tool_schema = cls.get_input_schema(data)
if not tool_schema:
return None
tool_arguments = data.get("tool_arguments", {})
return validate_with_jsonschema(tool_schema, tool_arguments)
class Output(BlockSchemaOutput):
result: Any = SchemaField(description="The result returned by the MCP tool")
error: str = SchemaField(description="Error message if the tool call failed")
def __init__(self):
super().__init__(
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
description="Connect to any MCP server and execute its tools. "
"Provide a server URL, select a tool, and pass arguments dynamically.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=MCPToolBlock.Input,
output_schema=MCPToolBlock.Output,
block_type=BlockType.STANDARD,
test_input={
"server_url": "https://mcp.example.com/mcp",
"credentials": TEST_CREDENTIALS_INPUT,
"selected_tool": "get_weather",
"tool_input_schema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"tool_arguments": {"city": "London"},
},
test_output=[
(
"result",
{"weather": "sunny", "temperature": 20},
),
],
test_mock={
"_call_mcp_tool": lambda *a, **kw: {
"weather": "sunny",
"temperature": 20,
},
},
test_credentials=TEST_CREDENTIALS,
)
async def _call_mcp_tool(
self,
server_url: str,
tool_name: str,
arguments: dict[str, Any],
auth_token: str | None = None,
) -> Any:
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
# Trust the user-configured server URL to allow internal/localhost servers
client = MCPClient(
server_url,
auth_token=auth_token,
trusted_origins=[server_url],
)
await client.initialize()
result = await client.call_tool(tool_name, arguments)
if result.is_error:
error_text = ""
for item in result.content:
if item.get("type") == "text":
error_text += item.get("text", "")
raise MCPClientError(
f"MCP tool '{tool_name}' returned an error: "
f"{error_text or 'Unknown error'}"
)
# Extract text content from the result
output_parts = []
for item in result.content:
if item.get("type") == "text":
text = item.get("text", "")
# Try to parse as JSON for structured output
try:
output_parts.append(json.loads(text))
except (json.JSONDecodeError, ValueError):
output_parts.append(text)
elif item.get("type") == "image":
output_parts.append(
{
"type": "image",
"data": item.get("data"),
"mimeType": item.get("mimeType"),
}
)
elif item.get("type") == "resource":
output_parts.append(item.get("resource", {}))
# If single result, unwrap
if len(output_parts) == 1:
return output_parts[0]
return output_parts if output_parts else None
@staticmethod
def _extract_auth_token(credentials: MCPCredentials) -> str | None:
"""Extract a Bearer token from either API key or OAuth2 credentials."""
if isinstance(credentials, OAuth2Credentials):
return credentials.access_token.get_secret_value()
if isinstance(credentials, APIKeyCredentials) and credentials.api_key:
token_value = credentials.api_key.get_secret_value()
if token_value:
return token_value
return None
async def run(
self,
input_data: Input,
*,
credentials: MCPCredentials,
**kwargs,
) -> BlockOutput:
if not input_data.server_url:
yield "error", "MCP server URL is required"
return
if not input_data.selected_tool:
yield "error", "No tool selected. Please select a tool from the dropdown."
return
auth_token = self._extract_auth_token(credentials)
try:
result = await self._call_mcp_tool(
server_url=input_data.server_url,
tool_name=input_data.selected_tool,
arguments=input_data.tool_arguments,
auth_token=auth_token,
)
yield "result", result
except MCPClientError as e:
yield "error", str(e)
except Exception as e:
logger.exception(f"MCP tool call failed: {e}")
yield "error", f"MCP tool call failed: {str(e)}"

View File

@@ -1,316 +0,0 @@
"""
MCP (Model Context Protocol) HTTP client.
Implements the MCP Streamable HTTP transport for listing tools and calling tools
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any
from backend.util.request import Requests
logger = logging.getLogger(__name__)
@dataclass
class MCPTool:
"""Represents an MCP tool discovered from a server."""
name: str
description: str
input_schema: dict[str, Any]
@dataclass
class MCPCallResult:
"""Result from calling an MCP tool."""
content: list[dict[str, Any]] = field(default_factory=list)
is_error: bool = False
class MCPClientError(Exception):
"""Raised when an MCP protocol error occurs."""
pass
class MCPClient:
"""
Async HTTP client for the MCP Streamable HTTP transport.
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
Supports optional Bearer token authentication.
"""
def __init__(
self,
server_url: str,
auth_token: str | None = None,
trusted_origins: list[str] | None = None,
):
self.server_url = server_url.rstrip("/")
self.auth_token = auth_token
self.trusted_origins = trusted_origins or []
self._request_id = 0
def _next_id(self) -> int:
self._request_id += 1
return self._request_id
def _build_headers(self) -> dict[str, str]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
return headers
def _build_jsonrpc_request(
self, method: str, params: dict[str, Any] | None = None
) -> dict[str, Any]:
req: dict[str, Any] = {
"jsonrpc": "2.0",
"method": method,
"id": self._next_id(),
}
if params is not None:
req["params"] = params
return req
@staticmethod
def _parse_sse_response(text: str) -> dict[str, Any]:
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
MCP servers may return responses as SSE with format:
event: message
data: {"jsonrpc":"2.0","result":{...},"id":1}
We extract the last `data:` line that contains a JSON-RPC response
(i.e. has an "id" field), which is the reply to our request.
"""
last_data: dict[str, Any] | None = None
for line in text.splitlines():
stripped = line.strip()
if stripped.startswith("data:"):
payload = stripped[len("data:") :].strip()
if not payload:
continue
try:
parsed = json.loads(payload)
# Only keep JSON-RPC responses (have "id"), skip notifications
if isinstance(parsed, dict) and "id" in parsed:
last_data = parsed
except (json.JSONDecodeError, ValueError):
continue
if last_data is None:
raise MCPClientError("No JSON-RPC response found in SSE stream")
return last_data
async def _send_request(
self, method: str, params: dict[str, Any] | None = None
) -> Any:
"""Send a JSON-RPC request to the MCP server and return the result.
Handles both ``application/json`` and ``text/event-stream`` responses
as required by the MCP Streamable HTTP transport specification.
"""
payload = self._build_jsonrpc_request(method, params)
headers = self._build_headers()
requests = Requests(
raise_for_status=True,
extra_headers=headers,
trusted_origins=self.trusted_origins,
)
response = await requests.post(self.server_url, json=payload)
content_type = response.headers.get("content-type", "")
if "text/event-stream" in content_type:
body = self._parse_sse_response(response.text())
else:
try:
body = response.json()
except (ValueError, Exception) as e:
raise MCPClientError(
f"MCP server returned non-JSON response: {e}"
) from e
# Handle JSON-RPC error
if "error" in body:
error = body["error"]
if isinstance(error, dict):
raise MCPClientError(
f"MCP server error [{error.get('code', '?')}]: "
f"{error.get('message', 'Unknown error')}"
)
raise MCPClientError(f"MCP server error: {error}")
return body.get("result")
async def _send_notification(self, method: str) -> None:
"""Send a JSON-RPC notification (no id, no response expected)."""
headers = self._build_headers()
notification = {"jsonrpc": "2.0", "method": method}
requests = Requests(
raise_for_status=False,
extra_headers=headers,
trusted_origins=self.trusted_origins,
)
await requests.post(self.server_url, json=notification)
async def discover_auth(self) -> dict[str, Any] | None:
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
Returns ``None`` if the server doesn't require auth, otherwise returns
a dict with:
- ``authorization_servers``: list of authorization server URLs
- ``resource``: the resource indicator URL (usually the MCP endpoint)
- ``scopes_supported``: optional list of supported scopes
The caller can then fetch the authorization server metadata to get
``authorization_endpoint``, ``token_endpoint``, etc.
"""
from urllib.parse import urlparse
parsed = urlparse(self.server_url)
base = f"{parsed.scheme}://{parsed.netloc}"
# Build candidates for protected-resource metadata (per RFC 9728)
path = parsed.path.rstrip("/")
candidates = []
if path and path != "/":
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
candidates.append(f"{base}/.well-known/oauth-protected-resource")
requests = Requests(
raise_for_status=False,
trusted_origins=self.trusted_origins,
)
for url in candidates:
try:
resp = await requests.get(url)
if resp.status == 200:
data = resp.json()
if isinstance(data, dict) and "authorization_servers" in data:
return data
except Exception:
continue
return None
async def discover_auth_server_metadata(
self, auth_server_url: str
) -> dict[str, Any] | None:
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
Given an authorization server URL, returns a dict with:
- ``authorization_endpoint``
- ``token_endpoint``
- ``registration_endpoint`` (for dynamic client registration)
- ``scopes_supported``
- ``code_challenge_methods_supported``
- etc.
"""
from urllib.parse import urlparse
parsed = urlparse(auth_server_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/")
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
candidates = []
if path and path != "/":
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
candidates.append(f"{base}/.well-known/oauth-authorization-server")
candidates.append(f"{base}/.well-known/openid-configuration")
requests = Requests(
raise_for_status=False,
trusted_origins=self.trusted_origins,
)
for url in candidates:
try:
resp = await requests.get(url)
if resp.status == 200:
data = resp.json()
if isinstance(data, dict) and "authorization_endpoint" in data:
return data
except Exception:
continue
return None
async def initialize(self) -> dict[str, Any]:
"""
Send the MCP initialize request.
This is required by the MCP protocol before any other requests.
Returns the server's capabilities.
"""
result = await self._send_request(
"initialize",
{
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
},
)
# Send initialized notification (no response expected)
await self._send_notification("notifications/initialized")
return result or {}
async def list_tools(self) -> list[MCPTool]:
"""
Discover available tools from the MCP server.
Returns a list of MCPTool objects with name, description, and input schema.
"""
result = await self._send_request("tools/list")
if not result or "tools" not in result:
return []
tools = []
for tool_data in result["tools"]:
tools.append(
MCPTool(
name=tool_data.get("name", ""),
description=tool_data.get("description", ""),
input_schema=tool_data.get("inputSchema", {}),
)
)
return tools
async def call_tool(
self, tool_name: str, arguments: dict[str, Any]
) -> MCPCallResult:
"""
Call a tool on the MCP server.
Args:
tool_name: The name of the tool to call.
arguments: The arguments to pass to the tool.
Returns:
MCPCallResult with the tool's response content.
"""
result = await self._send_request(
"tools/call",
{"name": tool_name, "arguments": arguments},
)
if not result:
return MCPCallResult(is_error=True)
return MCPCallResult(
content=result.get("content", []),
is_error=result.get("isError", False),
)

View File

@@ -1,21 +0,0 @@
"""
Conftest for MCP block tests.
Override the session-scoped server and graph_cleanup fixtures from
backend/conftest.py so that MCP integration tests don't spin up the
full SpinTestServer infrastructure.
"""
import pytest
@pytest.fixture(scope="session")
def server():
"""No-op override — MCP tests don't need the full platform server."""
yield None
@pytest.fixture(scope="session", autouse=True)
def graph_cleanup(server):
"""No-op override — MCP tests don't create graphs."""
yield

View File

@@ -1,198 +0,0 @@
"""
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
This handler accepts those endpoints at construction time.
"""
import logging
import time
import urllib.parse
from typing import ClassVar, Optional
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
logger = logging.getLogger(__name__)
class MCPOAuthHandler(BaseOAuthHandler):
"""
OAuth handler for MCP servers with dynamically-discovered endpoints.
Construction requires the authorization and token endpoint URLs,
which are obtained via MCP OAuth metadata discovery
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
"""
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
DEFAULT_SCOPES: ClassVar[list[str]] = []
def __init__(
self,
client_id: str,
client_secret: str,
redirect_uri: str,
*,
authorize_url: str,
token_url: str,
revoke_url: str | None = None,
resource_url: str | None = None,
):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.authorize_url = authorize_url
self.token_url = token_url
self.revoke_url = revoke_url
self.resource_url = resource_url
def get_login_url(
self,
scopes: list[str],
state: str,
code_challenge: Optional[str],
) -> str:
scopes = self.handle_default_scopes(scopes)
params: dict[str, str] = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"state": state,
}
if scopes:
params["scope"] = " ".join(scopes)
# PKCE (S256) — included when the caller provides a code_challenge
if code_challenge:
params["code_challenge"] = code_challenge
params["code_challenge_method"] = "S256"
# MCP spec requires resource indicator (RFC 8707)
if self.resource_url:
params["resource"] = self.resource_url
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
async def exchange_code_for_tokens(
self,
code: str,
scopes: list[str],
code_verifier: Optional[str],
) -> OAuth2Credentials:
data: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
if code_verifier:
data["code_verifier"] = code_verifier
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests(raise_for_status=True).post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
if "error" in tokens:
raise RuntimeError(
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
)
now = int(time.time())
expires_in = tokens.get("expires_in")
return OAuth2Credentials(
provider=str(self.PROVIDER_NAME),
title=None,
access_token=SecretStr(tokens["access_token"]),
refresh_token=(
SecretStr(tokens["refresh_token"])
if tokens.get("refresh_token")
else None
),
access_token_expires_at=now + expires_in if expires_in else None,
refresh_token_expires_at=None,
scopes=scopes,
metadata={
"mcp_token_url": self.token_url,
"mcp_resource_url": self.resource_url,
},
)
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
if not credentials.refresh_token:
raise ValueError("No refresh token available for MCP OAuth credentials")
data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests(raise_for_status=True).post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
if "error" in tokens:
raise RuntimeError(
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
)
now = int(time.time())
expires_in = tokens.get("expires_in")
return OAuth2Credentials(
id=credentials.id,
provider=str(self.PROVIDER_NAME),
title=credentials.title,
access_token=SecretStr(tokens["access_token"]),
refresh_token=(
SecretStr(str(tokens["refresh_token"]))
if tokens.get("refresh_token")
else credentials.refresh_token
),
access_token_expires_at=now + expires_in if expires_in else None,
refresh_token_expires_at=credentials.refresh_token_expires_at,
scopes=credentials.scopes,
metadata=credentials.metadata,
)
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
if not self.revoke_url:
return False
try:
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
"client_id": self.client_id,
}
await Requests().post(
self.revoke_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
return True
except Exception:
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
return False

View File

@@ -1,104 +0,0 @@
"""
End-to-end tests against a real public MCP server.
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
which is publicly accessible without authentication and returns SSE responses.
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
independently of the rest of the test suite (they require network access).
"""
import json
import pytest
from backend.blocks.mcp.client import MCPClient
# Public MCP server that requires no authentication
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
@pytest.mark.e2e
class TestRealMCPServer:
"""Tests against the live OpenAI docs MCP server."""
@pytest.mark.asyncio
async def test_initialize(self):
"""Verify we can complete the MCP handshake with a real server."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
assert "serverInfo" in result
assert result["serverInfo"]["name"] == "openai-docs-mcp"
assert "tools" in result.get("capabilities", {})
@pytest.mark.asyncio
async def test_list_tools(self):
"""Verify we can discover tools from a real MCP server."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
tools = await client.list_tools()
assert len(tools) >= 3 # server has at least 5 tools as of writing
tool_names = {t.name for t in tools}
# These tools are documented and should be stable
assert "search_openai_docs" in tool_names
assert "list_openai_docs" in tool_names
assert "fetch_openai_doc" in tool_names
# Verify schema structure
search_tool = next(t for t in tools if t.name == "search_openai_docs")
assert "query" in search_tool.input_schema.get("properties", {})
assert "query" in search_tool.input_schema.get("required", [])
@pytest.mark.asyncio
async def test_call_tool_list_api_endpoints(self):
"""Call the list_api_endpoints tool and verify we get real data."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
result = await client.call_tool("list_api_endpoints", {})
assert not result.is_error
assert len(result.content) >= 1
assert result.content[0]["type"] == "text"
data = json.loads(result.content[0]["text"])
assert "paths" in data or "urls" in data
# The OpenAI API should have many endpoints
total = data.get("total", len(data.get("paths", [])))
assert total > 50
@pytest.mark.asyncio
async def test_call_tool_search(self):
"""Search for docs and verify we get results."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
result = await client.call_tool(
"search_openai_docs", {"query": "chat completions", "limit": 3}
)
assert not result.is_error
assert len(result.content) >= 1
@pytest.mark.asyncio
async def test_sse_response_handling(self):
"""Verify the client correctly handles SSE responses from a real server.
This is the key test — our local test server returns JSON,
but real MCP servers typically return SSE. This proves the
SSE parsing works end-to-end.
"""
client = MCPClient(OPENAI_DOCS_MCP_URL)
# initialize() internally calls _send_request which must parse SSE
result = await client.initialize()
# If we got here without error, SSE parsing works
assert isinstance(result, dict)
assert "protocolVersion" in result
# Also verify list_tools works (another SSE response)
tools = await client.list_tools()
assert len(tools) > 0
assert all(hasattr(t, "name") for t in tools)

View File

@@ -1,367 +0,0 @@
"""
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
These tests spin up a local MCP test server and run the full client/block flow
against it — no mocking, real HTTP requests.
"""
import asyncio
import json
import threading
import pytest
from aiohttp import web
from pydantic import SecretStr
from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.test_server import create_test_mcp_app
from backend.data.model import APIKeyCredentials
class _MCPTestServer:
"""
Run an MCP test server in a background thread with its own event loop.
This avoids event loop conflicts with pytest-asyncio.
"""
def __init__(self, auth_token: str | None = None):
self.auth_token = auth_token
self.url: str = ""
self._runner: web.AppRunner | None = None
self._loop: asyncio.AbstractEventLoop | None = None
self._thread: threading.Thread | None = None
self._started = threading.Event()
def _run(self):
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._loop.run_until_complete(self._start())
self._started.set()
self._loop.run_forever()
async def _start(self):
app = create_test_mcp_app(auth_token=self.auth_token)
self._runner = web.AppRunner(app)
await self._runner.setup()
site = web.TCPSite(self._runner, "127.0.0.1", 0)
await site.start()
port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
self.url = f"http://127.0.0.1:{port}/mcp"
def start(self):
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
if not self._started.wait(timeout=5):
raise RuntimeError("MCP test server failed to start within 5 seconds")
return self
def stop(self):
if self._loop and self._runner:
asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
timeout=5
)
self._loop.call_soon_threadsafe(self._loop.stop)
if self._thread:
self._thread.join(timeout=5)
@pytest.fixture(scope="module")
def mcp_server():
"""Start a local MCP test server in a background thread."""
server = _MCPTestServer()
server.start()
yield server.url
server.stop()
@pytest.fixture(scope="module")
def mcp_server_with_auth():
"""Start a local MCP test server with auth in a background thread."""
server = _MCPTestServer(auth_token="test-secret-token")
server.start()
yield server.url, "test-secret-token"
server.stop()
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
"""Create an MCPClient with localhost trusted for integration tests."""
return MCPClient(url, auth_token=auth_token, trusted_origins=[url])
def _make_fake_creds(api_key: str = "FAKE_API_KEY") -> APIKeyCredentials:
return APIKeyCredentials(
id="test-integration",
provider="mcp",
api_key=SecretStr(api_key),
title="test",
)
# ── MCPClient integration tests ──────────────────────────────────────
class TestMCPClientIntegration:
"""Test MCPClient against a real local MCP server."""
@pytest.mark.asyncio
async def test_initialize(self, mcp_server):
client = _make_client(mcp_server)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
assert result["serverInfo"]["name"] == "test-mcp-server"
assert "tools" in result["capabilities"]
@pytest.mark.asyncio
async def test_list_tools(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
tools = await client.list_tools()
assert len(tools) == 3
tool_names = {t.name for t in tools}
assert tool_names == {"get_weather", "add_numbers", "echo"}
# Check get_weather schema
weather = next(t for t in tools if t.name == "get_weather")
assert weather.description == "Get current weather for a city"
assert "city" in weather.input_schema["properties"]
assert weather.input_schema["required"] == ["city"]
# Check add_numbers schema
add = next(t for t in tools if t.name == "add_numbers")
assert "a" in add.input_schema["properties"]
assert "b" in add.input_schema["properties"]
@pytest.mark.asyncio
async def test_call_tool_get_weather(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("get_weather", {"city": "London"})
assert not result.is_error
assert len(result.content) == 1
assert result.content[0]["type"] == "text"
data = json.loads(result.content[0]["text"])
assert data["city"] == "London"
assert data["temperature"] == 22
assert data["condition"] == "sunny"
@pytest.mark.asyncio
async def test_call_tool_add_numbers(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
assert not result.is_error
data = json.loads(result.content[0]["text"])
assert data["result"] == 10
@pytest.mark.asyncio
async def test_call_tool_echo(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("echo", {"message": "Hello MCP!"})
assert not result.is_error
assert result.content[0]["text"] == "Hello MCP!"
@pytest.mark.asyncio
async def test_call_unknown_tool(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("nonexistent_tool", {})
assert result.is_error
assert "Unknown tool" in result.content[0]["text"]
@pytest.mark.asyncio
async def test_auth_success(self, mcp_server_with_auth):
url, token = mcp_server_with_auth
client = _make_client(url, auth_token=token)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
tools = await client.list_tools()
assert len(tools) == 3
@pytest.mark.asyncio
async def test_auth_failure(self, mcp_server_with_auth):
url, _ = mcp_server_with_auth
client = _make_client(url, auth_token="wrong-token")
with pytest.raises(Exception):
await client.initialize()
@pytest.mark.asyncio
async def test_auth_missing(self, mcp_server_with_auth):
url, _ = mcp_server_with_auth
client = _make_client(url)
with pytest.raises(Exception):
await client.initialize()
# ── MCPToolBlock integration tests ───────────────────────────────────
class TestMCPToolBlockIntegration:
"""Test MCPToolBlock end-to-end against a real local MCP server."""
@pytest.mark.asyncio
async def test_full_flow_get_weather(self, mcp_server):
"""Full flow: discover tools, select one, execute it."""
# Step 1: Discover tools (simulating what the frontend/API would do)
client = _make_client(mcp_server)
await client.initialize()
tools = await client.list_tools()
assert len(tools) == 3
# Step 2: User selects "get_weather" and we get its schema
weather_tool = next(t for t in tools if t.name == "get_weather")
# Step 3: Execute the block with the selected tool
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="get_weather",
tool_input_schema=weather_tool.input_schema,
tool_arguments={"city": "Paris"},
credentials={ # type: ignore
"provider": "mcp",
"id": "test",
"type": "api_key",
"title": "test",
},
)
outputs = []
async for name, data in block.run(input_data, credentials=_make_fake_creds()):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
result = outputs[0][1]
assert result["city"] == "Paris"
assert result["temperature"] == 22
assert result["condition"] == "sunny"
@pytest.mark.asyncio
async def test_full_flow_add_numbers(self, mcp_server):
"""Full flow for add_numbers tool."""
client = _make_client(mcp_server)
await client.initialize()
tools = await client.list_tools()
add_tool = next(t for t in tools if t.name == "add_numbers")
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="add_numbers",
tool_input_schema=add_tool.input_schema,
tool_arguments={"a": 42, "b": 58},
credentials={ # type: ignore
"provider": "mcp",
"id": "test",
"type": "api_key",
"title": "test",
},
)
outputs = []
async for name, data in block.run(input_data, credentials=_make_fake_creds()):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1]["result"] == 100
@pytest.mark.asyncio
async def test_full_flow_echo_plain_text(self, mcp_server):
"""Verify plain text (non-JSON) responses work."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="echo",
tool_input_schema={
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
},
tool_arguments={"message": "Hello from AutoGPT!"},
credentials={ # type: ignore
"provider": "mcp",
"id": "test",
"type": "api_key",
"title": "test",
},
)
outputs = []
async for name, data in block.run(input_data, credentials=_make_fake_creds()):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == "Hello from AutoGPT!"
@pytest.mark.asyncio
async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
"""Calling an unknown tool should yield an error output."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="nonexistent_tool",
tool_arguments={},
credentials={ # type: ignore
"provider": "mcp",
"id": "test",
"type": "api_key",
"title": "test",
},
)
outputs = []
async for name, data in block.run(input_data, credentials=_make_fake_creds()):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "error"
assert "returned an error" in outputs[0][1]
@pytest.mark.asyncio
async def test_full_flow_with_auth(self, mcp_server_with_auth):
"""Full flow with authentication."""
url, token = mcp_server_with_auth
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=url,
selected_tool="echo",
tool_input_schema={
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
},
tool_arguments={"message": "Authenticated!"},
credentials={ # type: ignore
"provider": "mcp",
"id": "test",
"type": "api_key",
"title": "test",
},
)
outputs = []
async for name, data in block.run(
input_data, credentials=_make_fake_creds(api_key=token)
):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == "Authenticated!"

View File

@@ -1,667 +0,0 @@
"""
Tests for MCP client and MCPToolBlock.
"""
import json
from unittest.mock import AsyncMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.mcp.block import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
MCPToolBlock,
)
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
from backend.data.model import APIKeyCredentials, OAuth2Credentials
from backend.util.test import execute_block_test
# ── SSE parsing unit tests ───────────────────────────────────────────
class TestSSEParsing:
"""Tests for SSE (text/event-stream) response parsing."""
def test_parse_sse_simple(self):
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == {"tools": []}
assert body["id"] == 1
def test_parse_sse_with_notifications(self):
"""SSE streams can contain notifications (no id) before the response."""
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
"\n"
"event: message\n"
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == {"ok": True}
assert body["id"] == 2
def test_parse_sse_error_response(self):
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
)
body = MCPClient._parse_sse_response(sse)
assert "error" in body
assert body["error"]["code"] == -32600
def test_parse_sse_no_data_raises(self):
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
MCPClient._parse_sse_response("event: message\n\n")
def test_parse_sse_empty_raises(self):
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
MCPClient._parse_sse_response("")
def test_parse_sse_ignores_non_data_lines(self):
sse = (
": comment line\n"
"event: message\n"
"id: 123\n"
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == "ok"
def test_parse_sse_uses_last_response(self):
"""If multiple responses exist, use the last one."""
sse = (
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
"\n"
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == "second"
# ── MCPClient unit tests ─────────────────────────────────────────────
class TestMCPClient:
"""Tests for the MCP HTTP client."""
def test_build_headers_without_auth(self):
client = MCPClient("https://mcp.example.com")
headers = client._build_headers()
assert "Authorization" not in headers
assert headers["Content-Type"] == "application/json"
def test_build_headers_with_auth(self):
client = MCPClient("https://mcp.example.com", auth_token="my-token")
headers = client._build_headers()
assert headers["Authorization"] == "Bearer my-token"
def test_build_jsonrpc_request(self):
client = MCPClient("https://mcp.example.com")
req = client._build_jsonrpc_request("tools/list")
assert req["jsonrpc"] == "2.0"
assert req["method"] == "tools/list"
assert "id" in req
assert "params" not in req
def test_build_jsonrpc_request_with_params(self):
client = MCPClient("https://mcp.example.com")
req = client._build_jsonrpc_request(
"tools/call", {"name": "test", "arguments": {"x": 1}}
)
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
def test_request_id_increments(self):
client = MCPClient("https://mcp.example.com")
req1 = client._build_jsonrpc_request("tools/list")
req2 = client._build_jsonrpc_request("tools/list")
assert req2["id"] > req1["id"]
def test_server_url_trailing_slash_stripped(self):
client = MCPClient("https://mcp.example.com/mcp/")
assert client.server_url == "https://mcp.example.com/mcp"
@pytest.mark.asyncio
async def test_send_request_success(self):
client = MCPClient("https://mcp.example.com")
mock_response = AsyncMock()
mock_response.json.return_value = {
"jsonrpc": "2.0",
"result": {"tools": []},
"id": 1,
}
with patch.object(client, "_send_request", return_value={"tools": []}):
result = await client._send_request("tools/list")
assert result == {"tools": []}
@pytest.mark.asyncio
async def test_send_request_error(self):
client = MCPClient("https://mcp.example.com")
async def mock_send(*args, **kwargs):
raise MCPClientError("MCP server error [-32600]: Invalid Request")
with patch.object(client, "_send_request", side_effect=mock_send):
with pytest.raises(MCPClientError, match="Invalid Request"):
await client._send_request("tools/list")
@pytest.mark.asyncio
async def test_list_tools(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"tools": [
{
"name": "get_weather",
"description": "Get current weather for a city",
"inputSchema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
},
{
"name": "search",
"description": "Search the web",
"inputSchema": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
},
},
]
}
with patch.object(client, "_send_request", return_value=mock_result):
tools = await client.list_tools()
assert len(tools) == 2
assert tools[0].name == "get_weather"
assert tools[0].description == "Get current weather for a city"
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
assert tools[1].name == "search"
@pytest.mark.asyncio
async def test_list_tools_empty(self):
client = MCPClient("https://mcp.example.com")
with patch.object(client, "_send_request", return_value={"tools": []}):
tools = await client.list_tools()
assert tools == []
@pytest.mark.asyncio
async def test_list_tools_none_result(self):
client = MCPClient("https://mcp.example.com")
with patch.object(client, "_send_request", return_value=None):
tools = await client.list_tools()
assert tools == []
@pytest.mark.asyncio
async def test_call_tool_success(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"content": [
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
],
"isError": False,
}
with patch.object(client, "_send_request", return_value=mock_result):
result = await client.call_tool("get_weather", {"city": "London"})
assert not result.is_error
assert len(result.content) == 1
assert result.content[0]["type"] == "text"
@pytest.mark.asyncio
async def test_call_tool_error(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"content": [{"type": "text", "text": "City not found"}],
"isError": True,
}
with patch.object(client, "_send_request", return_value=mock_result):
result = await client.call_tool("get_weather", {"city": "???"})
assert result.is_error
@pytest.mark.asyncio
async def test_call_tool_none_result(self):
client = MCPClient("https://mcp.example.com")
with patch.object(client, "_send_request", return_value=None):
result = await client.call_tool("get_weather", {"city": "London"})
assert result.is_error
@pytest.mark.asyncio
async def test_initialize(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"protocolVersion": "2025-03-26",
"capabilities": {"tools": {}},
"serverInfo": {"name": "test-server", "version": "1.0.0"},
}
with (
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
patch.object(client, "_send_notification") as mock_notif,
):
result = await client.initialize()
mock_req.assert_called_once()
mock_notif.assert_called_once_with("notifications/initialized")
assert result["protocolVersion"] == "2025-03-26"
# ── MCPToolBlock unit tests ──────────────────────────────────────────
class TestMCPToolBlock:
"""Tests for the MCPToolBlock."""
def test_block_instantiation(self):
block = MCPToolBlock()
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
assert block.name == "MCPToolBlock"
def test_input_schema_has_required_fields(self):
block = MCPToolBlock()
schema = block.input_schema.jsonschema()
props = schema.get("properties", {})
assert "server_url" in props
assert "selected_tool" in props
assert "tool_arguments" in props
assert "credentials" in props
def test_output_schema(self):
block = MCPToolBlock()
schema = block.output_schema.jsonschema()
props = schema.get("properties", {})
assert "result" in props
assert "error" in props
def test_get_input_schema_with_tool_schema(self):
tool_schema = {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
}
data = {"tool_input_schema": tool_schema}
result = MCPToolBlock.Input.get_input_schema(data)
assert result == tool_schema
def test_get_input_schema_without_tool_schema(self):
result = MCPToolBlock.Input.get_input_schema({})
assert result == {}
def test_get_input_defaults(self):
data = {"tool_arguments": {"city": "London"}}
result = MCPToolBlock.Input.get_input_defaults(data)
assert result == {"city": "London"}
def test_get_missing_input(self):
data = {
"tool_input_schema": {
"type": "object",
"properties": {
"city": {"type": "string"},
"units": {"type": "string"},
},
"required": ["city", "units"],
},
"tool_arguments": {"city": "London"},
}
missing = MCPToolBlock.Input.get_missing_input(data)
assert missing == {"units"}
def test_get_missing_input_all_present(self):
data = {
"tool_input_schema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"tool_arguments": {"city": "London"},
}
missing = MCPToolBlock.Input.get_missing_input(data)
assert missing == set()
@pytest.mark.asyncio
async def test_run_with_mock(self):
"""Test the block using the built-in test infrastructure."""
block = MCPToolBlock()
await execute_block_test(block)
@pytest.mark.asyncio
async def test_run_missing_server_url(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="",
selected_tool="test",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
outputs = []
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
outputs.append((name, data))
assert outputs == [("error", "MCP server URL is required")]
@pytest.mark.asyncio
async def test_run_missing_tool(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
outputs = []
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
outputs.append((name, data))
assert outputs == [
("error", "No tool selected. Please select a tool from the dropdown.")
]
@pytest.mark.asyncio
async def test_run_success(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="get_weather",
tool_input_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
},
tool_arguments={"city": "London"},
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
async def mock_call(*args, **kwargs):
return {"temp": 20, "city": "London"}
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == {"temp": 20, "city": "London"}
@pytest.mark.asyncio
async def test_run_mcp_error(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="bad_tool",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
async def mock_call(*args, **kwargs):
raise MCPClientError("Tool not found")
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, credentials=TEST_CREDENTIALS):
outputs.append((name, data))
assert outputs[0][0] == "error"
assert "Tool not found" in outputs[0][1]
@pytest.mark.asyncio
async def test_call_mcp_tool_parses_json_text(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{"type": "text", "text": '{"temp": 20}'},
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == {"temp": 20}
@pytest.mark.asyncio
async def test_call_mcp_tool_plain_text(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{"type": "text", "text": "Hello, world!"},
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == "Hello, world!"
@pytest.mark.asyncio
async def test_call_mcp_tool_multiple_content(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{"type": "text", "text": "Part 1"},
{"type": "text", "text": '{"part": 2}'},
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == ["Part 1", {"part": 2}]
@pytest.mark.asyncio
async def test_call_mcp_tool_error_result(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[{"type": "text", "text": "Something went wrong"}],
is_error=True,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
with pytest.raises(MCPClientError, match="returned an error"):
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
@pytest.mark.asyncio
async def test_call_mcp_tool_image_content(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{
"type": "image",
"data": "base64data==",
"mimeType": "image/png",
}
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == {
"type": "image",
"data": "base64data==",
"mimeType": "image/png",
}
@pytest.mark.asyncio
async def test_run_sends_api_key_credentials(self):
"""Ensure non-empty API keys are sent to the MCP server."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="test_tool",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
creds = APIKeyCredentials(
id="test-id",
provider="mcp",
api_key=SecretStr("real-api-key"),
title="Real",
)
captured_tokens = []
async def mock_call(server_url, tool_name, arguments, auth_token=None):
captured_tokens.append(auth_token)
return "ok"
block._call_mcp_tool = mock_call # type: ignore
async for _ in block.run(input_data, credentials=creds):
pass
assert captured_tokens == ["real-api-key"]
# ── OAuth2 credential support tests ─────────────────────────────────
class TestMCPOAuth2Support:
"""Tests for OAuth2 credential support in MCPToolBlock."""
def test_extract_auth_token_from_api_key(self):
creds = APIKeyCredentials(
id="test",
provider="mcp",
api_key=SecretStr("my-api-key"),
title="test",
)
token = MCPToolBlock._extract_auth_token(creds)
assert token == "my-api-key"
def test_extract_auth_token_from_oauth2(self):
creds = OAuth2Credentials(
id="test",
provider="mcp",
access_token=SecretStr("oauth2-access-token"),
scopes=["read"],
title="test",
)
token = MCPToolBlock._extract_auth_token(creds)
assert token == "oauth2-access-token"
def test_extract_auth_token_empty_skipped(self):
creds = APIKeyCredentials(
id="test",
provider="mcp",
api_key=SecretStr(""),
title="test",
)
token = MCPToolBlock._extract_auth_token(creds)
assert token is None
@pytest.mark.asyncio
async def test_run_with_oauth2_credentials(self):
"""Verify the block can run with OAuth2 credentials."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="test_tool",
credentials=TEST_CREDENTIALS_INPUT, # type: ignore
)
oauth2_creds = OAuth2Credentials(
id="test-id",
provider="mcp",
access_token=SecretStr("real-oauth2-token"),
scopes=["read", "write"],
title="MCP OAuth",
)
captured_tokens = []
async def mock_call(server_url, tool_name, arguments, auth_token=None):
captured_tokens.append(auth_token)
return {"status": "ok"}
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, credentials=oauth2_creds):
outputs.append((name, data))
assert captured_tokens == ["real-oauth2-token"]
assert outputs == [("result", {"status": "ok"})]

View File

@@ -1,242 +0,0 @@
"""
Tests for MCP OAuth handler.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
resp = MagicMock()
resp.status = status
resp.ok = 200 <= status < 300
resp.json.return_value = json_data
return resp
class TestMCPOAuthHandler:
"""Tests for the MCPOAuthHandler."""
def _make_handler(self, **overrides) -> MCPOAuthHandler:
defaults = {
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"redirect_uri": "https://app.example.com/callback",
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
}
defaults.update(overrides)
return MCPOAuthHandler(**defaults)
def test_get_login_url_basic(self):
handler = self._make_handler()
url = handler.get_login_url(
scopes=["read", "write"],
state="random-state-token",
code_challenge="S256-challenge-value",
)
assert "https://auth.example.com/authorize?" in url
assert "response_type=code" in url
assert "client_id=test-client-id" in url
assert "state=random-state-token" in url
assert "code_challenge=S256-challenge-value" in url
assert "code_challenge_method=S256" in url
assert "scope=read+write" in url
def test_get_login_url_with_resource(self):
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
url = handler.get_login_url(
scopes=[], state="state", code_challenge="challenge"
)
assert "resource=https" in url
def test_get_login_url_without_pkce(self):
handler = self._make_handler()
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
assert "code_challenge" not in url
assert "code_challenge_method" not in url
@pytest.mark.asyncio
async def test_exchange_code_for_tokens(self):
handler = self._make_handler()
resp = _mock_response(
{
"access_token": "new-access-token",
"refresh_token": "new-refresh-token",
"expires_in": 3600,
"token_type": "Bearer",
}
)
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
creds = await handler.exchange_code_for_tokens(
code="auth-code",
scopes=["read"],
code_verifier="pkce-verifier",
)
assert isinstance(creds, OAuth2Credentials)
assert creds.access_token.get_secret_value() == "new-access-token"
assert creds.refresh_token is not None
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
assert creds.scopes == ["read"]
assert creds.access_token_expires_at is not None
@pytest.mark.asyncio
async def test_refresh_tokens(self):
handler = self._make_handler()
existing_creds = OAuth2Credentials(
id="existing-id",
provider="mcp",
access_token=SecretStr("old-token"),
refresh_token=SecretStr("old-refresh"),
scopes=["read"],
title="test",
)
resp = _mock_response(
{
"access_token": "refreshed-token",
"refresh_token": "new-refresh",
"expires_in": 3600,
}
)
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
refreshed = await handler._refresh_tokens(existing_creds)
assert refreshed.id == "existing-id"
assert refreshed.access_token.get_secret_value() == "refreshed-token"
assert refreshed.refresh_token is not None
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
@pytest.mark.asyncio
async def test_refresh_tokens_no_refresh_token(self):
handler = self._make_handler()
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=["read"],
title="test",
)
with pytest.raises(ValueError, match="No refresh token"):
await handler._refresh_tokens(creds)
@pytest.mark.asyncio
async def test_revoke_tokens_no_url(self):
handler = self._make_handler(revoke_url=None)
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=[],
title="test",
)
result = await handler.revoke_tokens(creds)
assert result is False
@pytest.mark.asyncio
async def test_revoke_tokens_with_url(self):
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=[],
title="test",
)
resp = _mock_response({}, status=200)
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
result = await handler.revoke_tokens(creds)
assert result is True
class TestMCPClientDiscovery:
"""Tests for MCPClient OAuth metadata discovery."""
@pytest.mark.asyncio
async def test_discover_auth_found(self):
client = MCPClient("https://mcp.example.com/mcp")
metadata = {
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
resp = _mock_response(metadata, status=200)
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth()
assert result is not None
assert result["authorization_servers"] == ["https://auth.example.com"]
@pytest.mark.asyncio
async def test_discover_auth_not_found(self):
client = MCPClient("https://mcp.example.com/mcp")
resp = _mock_response({}, status=404)
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth()
assert result is None
@pytest.mark.asyncio
async def test_discover_auth_server_metadata(self):
client = MCPClient("https://mcp.example.com/mcp")
server_metadata = {
"issuer": "https://auth.example.com",
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"registration_endpoint": "https://auth.example.com/register",
"code_challenge_methods_supported": ["S256"],
}
resp = _mock_response(server_metadata, status=200)
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth_server_metadata(
"https://auth.example.com"
)
assert result is not None
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
assert result["token_endpoint"] == "https://auth.example.com/token"

View File

@@ -1,162 +0,0 @@
"""
Minimal MCP server for integration testing.
Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
with a few sample tools. Runs on localhost with a random available port.
"""
import json
import logging
from aiohttp import web
logger = logging.getLogger(__name__)
# Sample tools this test server exposes
TEST_TOOLS = [
{
"name": "get_weather",
"description": "Get current weather for a city",
"inputSchema": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City name",
},
},
"required": ["city"],
},
},
{
"name": "add_numbers",
"description": "Add two numbers together",
"inputSchema": {
"type": "object",
"properties": {
"a": {"type": "number", "description": "First number"},
"b": {"type": "number", "description": "Second number"},
},
"required": ["a", "b"],
},
},
{
"name": "echo",
"description": "Echo back the input message",
"inputSchema": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "Message to echo"},
},
"required": ["message"],
},
},
]
def _handle_initialize(params: dict) -> dict:
return {
"protocolVersion": "2025-03-26",
"capabilities": {"tools": {"listChanged": False}},
"serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
}
def _handle_tools_list(params: dict) -> dict:
return {"tools": TEST_TOOLS}
def _handle_tools_call(params: dict) -> dict:
tool_name = params.get("name", "")
arguments = params.get("arguments", {})
if tool_name == "get_weather":
city = arguments.get("city", "Unknown")
return {
"content": [
{
"type": "text",
"text": json.dumps(
{"city": city, "temperature": 22, "condition": "sunny"}
),
}
],
}
elif tool_name == "add_numbers":
a = arguments.get("a", 0)
b = arguments.get("b", 0)
return {
"content": [{"type": "text", "text": json.dumps({"result": a + b})}],
}
elif tool_name == "echo":
message = arguments.get("message", "")
return {
"content": [{"type": "text", "text": message}],
}
else:
return {
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
"isError": True,
}
HANDLERS = {
"initialize": _handle_initialize,
"tools/list": _handle_tools_list,
"tools/call": _handle_tools_call,
}
async def handle_mcp_request(request: web.Request) -> web.Response:
"""Handle incoming MCP JSON-RPC 2.0 requests."""
# Check auth if configured
expected_token = request.app.get("auth_token")
if expected_token:
auth_header = request.headers.get("Authorization", "")
if auth_header != f"Bearer {expected_token}":
return web.json_response(
{
"jsonrpc": "2.0",
"error": {"code": -32001, "message": "Unauthorized"},
"id": None,
},
status=401,
)
body = await request.json()
# Handle notifications (no id field) — just acknowledge
if "id" not in body:
return web.Response(status=202)
method = body.get("method", "")
params = body.get("params", {})
request_id = body.get("id")
handler = HANDLERS.get(method)
if not handler:
return web.json_response(
{
"jsonrpc": "2.0",
"error": {
"code": -32601,
"message": f"Method not found: {method}",
},
"id": request_id,
}
)
result = handler(params)
return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})
def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
"""Create an aiohttp app that acts as an MCP server."""
app = web.Application()
app.router.add_post("/mcp", handle_mcp_request)
if auth_token:
app["auth_token"] = auth_token
return app

View File

@@ -30,7 +30,6 @@ class ProviderName(str, Enum):
IDEOGRAM = "ideogram"
JINA = "jina"
LLAMA_API = "llama_api"
MCP = "mcp"
MEDIUM = "medium"
MEM0 = "mem0"
NOTION = "notion"

View File

@@ -896,6 +896,29 @@ files = [
{file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
]
[[package]]
name = "claude-agent-sdk"
version = "0.1.33"
description = "Python SDK for Claude Code"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "claude_agent_sdk-0.1.33-py3-none-macosx_11_0_arm64.whl", hash = "sha256:57886a2dd124e5b3c9e12ec3e4841742ab3444d1e428b45ceaec8841c96698fa"},
{file = "claude_agent_sdk-0.1.33-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:ea0f1e4fadeec766000122723c406a6f47c6210ea11bb5cc0c88af11ef7c940c"},
{file = "claude_agent_sdk-0.1.33-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:0ecd822c577b4ea2a52e51146a24dcea73eb69ff366bdb875785dadb116d593b"},
{file = "claude_agent_sdk-0.1.33-py3-none-win_amd64.whl", hash = "sha256:a9fbd09d8f947005e087340ecd0706ed35639c946b4bd49429d3132db4cb3751"},
{file = "claude_agent_sdk-0.1.33.tar.gz", hash = "sha256:134bf403bb7553d829dadec42c30ecef340f5d4ad1595c1bdef933a9ca3129cf"},
]
[package.dependencies]
anyio = ">=4.0.0"
mcp = ">=0.1.0"
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
[package.extras]
dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"]
[[package]]
name = "cleo"
version = "2.1.0"
@@ -2562,6 +2585,18 @@ http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "httpx-sse"
version = "0.4.3"
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"},
{file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"},
]
[[package]]
name = "huggingface-hub"
version = "1.4.1"
@@ -3279,6 +3314,39 @@ files = [
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
]
[[package]]
name = "mcp"
version = "1.26.0"
description = "Model Context Protocol SDK"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"},
{file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"},
]
[package.dependencies]
anyio = ">=4.5"
httpx = ">=0.27.1"
httpx-sse = ">=0.4"
jsonschema = ">=4.20.0"
pydantic = ">=2.11.0,<3.0.0"
pydantic-settings = ">=2.5.2"
pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
python-multipart = ">=0.0.9"
pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""}
sse-starlette = ">=1.6.1"
starlette = ">=0.27"
typing-extensions = ">=4.9.0"
typing-inspection = ">=0.4.1"
uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""}
[package.extras]
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"]
rich = ["rich (>=13.9.4)"]
ws = ["websockets (>=15.0.1)"]
[[package]]
name = "mdurl"
version = "0.1.2"
@@ -5961,7 +6029,7 @@ description = "Python for Window Extensions"
optional = false
python-versions = "*"
groups = ["main"]
markers = "platform_system == \"Windows\""
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
files = [
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
@@ -6006,13 +6074,6 @@ optional = false
python-versions = ">=3.8"
groups = ["main", "dev"]
files = [
{file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"},
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"},
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"},
{file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"},
{file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"},
{file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"},
{file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"},
{file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"},
{file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"},
{file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"},
@@ -6942,6 +7003,28 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
pymysql = ["pymysql"]
sqlcipher = ["sqlcipher3_binary"]
[[package]]
name = "sse-starlette"
version = "3.2.0"
description = "SSE plugin for Starlette"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"},
{file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"},
]
[package.dependencies]
anyio = ">=4.7.0"
starlette = ">=0.49.1"
[package.extras]
daphne = ["daphne (>=4.2.0)"]
examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"]
granian = ["granian (>=2.3.1)"]
uvicorn = ["uvicorn (>=0.34.0)"]
[[package]]
name = "stagehand"
version = "0.5.9"
@@ -8382,4 +8465,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "40b2c87c3c86bd10214bd30ad291cead75da5060ab894105025ee4c0a3b3828e"
content-hash = "2e2541233117d1f048be2d3c701fb8d5577b445002c0017362027f278a1a4d06"

View File

@@ -13,6 +13,7 @@ aio-pika = "^9.5.5"
aiohttp = "^3.10.0"
aiodns = "^3.5.0"
anthropic = "^0.59.0"
claude-agent-sdk = "^0.1.0"
apscheduler = "^3.11.1"
autogpt-libs = { path = "../autogpt_libs", develop = true }
bleach = { extras = ["css"], version = "^6.2.0" }

View File

@@ -467,7 +467,6 @@ Below is a comprehensive list of all available blocks, categorized by their prim
| [Github Update Comment](block-integrations/github/issues.md#github-update-comment) | A block that updates an existing comment on a GitHub issue or pull request |
| [Github Update File](block-integrations/github/repo.md#github-update-file) | This block updates an existing file in a GitHub repository |
| [Instantiate Code Sandbox](block-integrations/misc.md#instantiate-code-sandbox) | Instantiate a sandbox environment with internet access in which you can execute code with the Execute Code Step block |
| [MCP Tool](block-integrations/mcp/block.md#mcp-tool) | Connect to any MCP server and execute its tools |
| [Slant3D Order Webhook](block-integrations/slant3d/webhook.md#slant3d-order-webhook) | This block triggers on Slant3D order status updates and outputs the event details, including tracking information when orders are shipped |
## Media Generation

View File

@@ -84,7 +84,6 @@
* [Linear Projects](block-integrations/linear/projects.md)
* [LLM](block-integrations/llm.md)
* [Logic](block-integrations/logic.md)
* [Mcp Block](block-integrations/mcp/block.md)
* [Misc](block-integrations/misc.md)
* [Notion Create Page](block-integrations/notion/create_page.md)
* [Notion Read Database](block-integrations/notion/read_database.md)

View File

@@ -1,36 +0,0 @@
# Mcp Block
<!-- MANUAL: file_description -->
_Add a description of this category of blocks._
<!-- END MANUAL -->
## MCP Tool
### What it is
Connect to any MCP server and execute its tools. Provide a server URL, select a tool, and pass arguments dynamically.
### How it works
<!-- MANUAL: how_it_works -->
_Add technical explanation here._
<!-- END MANUAL -->
### Inputs
| Input | Description | Type | Required |
|-------|-------------|------|----------|
| server_url | URL of the MCP server (Streamable HTTP endpoint) | str | Yes |
| selected_tool | The MCP tool to execute | str | No |
| tool_arguments | Arguments to pass to the selected MCP tool. The fields here are defined by the tool's input schema. | Dict[str, Any] | No |
### Outputs
| Output | Description | Type |
|--------|-------------|------|
| error | Error message if the tool call failed | str |
| result | The result returned by the MCP tool | Result |
### Possible use case
<!-- MANUAL: use_case -->
_Add practical use case examples here._
<!-- END MANUAL -->
---