Compare commits

..

3 Commits

Author SHA1 Message Date
Lluis Agusti
4fa9c6a797 chore: improvements 2026-02-13 23:33:09 +08:00
Lluis Agusti
256d59303a Merge remote-tracking branch 'origin/dev' into lluis/improve-create-edit-ux 2026-02-13 22:37:38 +08:00
Lluis Agusti
e0aa565192 fix: improve create agent ux 2026-02-13 21:41:35 +08:00
101 changed files with 1082 additions and 10849 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -62,7 +62,7 @@ jobs:
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v4
uses: github/codeql-action/init@v3
with:
languages: ${{ matrix.language }}
build-mode: ${{ matrix.build-mode }}
@@ -93,6 +93,6 @@ jobs:
exit 1
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v4
uses: github/codeql-action/analyze@v3
with:
category: "/language:${{matrix.language}}"

View File

@@ -7,10 +7,6 @@ on:
- "docs/integrations/**"
- "autogpt_platform/backend/backend/blocks/**"
concurrency:
group: claude-docs-review-${{ github.event.pull_request.number }}
cancel-in-progress: true
jobs:
claude-review:
# Only run for PRs from members/collaborators
@@ -95,35 +91,5 @@ jobs:
3. Read corresponding documentation files to verify accuracy
4. Provide your feedback as a PR comment
## IMPORTANT: Comment Marker
Start your PR comment with exactly this HTML comment marker on its own line:
<!-- CLAUDE_DOCS_REVIEW -->
This marker is used to identify and replace your comment on subsequent runs.
Be constructive and specific. If everything looks good, say so!
If there are issues, explain what's wrong and suggest how to fix it.
- name: Delete old Claude review comments
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Get all comment IDs with our marker, sorted by creation date (oldest first)
COMMENT_IDS=$(gh api \
repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \
--jq '[.[] | select(.body | contains("<!-- CLAUDE_DOCS_REVIEW -->"))] | sort_by(.created_at) | .[].id')
# Count comments
COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true)
if [ "$COMMENT_COUNT" -gt 1 ]; then
# Delete all but the last (newest) comment
echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do
if [ -n "$COMMENT_ID" ]; then
echo "Deleting old review comment: $COMMENT_ID"
gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID
fi
done
else
echo "No old review comments to clean up"
fi

View File

@@ -1,39 +0,0 @@
name: PR Overlap Detection
on:
pull_request:
types: [opened, synchronize, reopened]
branches:
- dev
- master
permissions:
contents: read
pull-requests: write
jobs:
check-overlaps:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0 # Need full history for merge testing
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Configure git
run: |
git config user.email "github-actions[bot]@users.noreply.github.com"
git config user.name "github-actions[bot]"
- name: Run overlap detection
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# Always succeed - this check informs contributors, it shouldn't block merging
continue-on-error: true
run: |
python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }}

View File

@@ -66,19 +66,13 @@ ENV POETRY_HOME=/opt/poetry \
DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool.
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
python3-pip \
ffmpeg \
imagemagick \
jq \
ripgrep \
tree \
bubblewrap \
&& rm -rf /var/lib/apt/lists/*
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*

View File

@@ -27,11 +27,12 @@ class ChatConfig(BaseSettings):
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
max_context_messages: int = Field(
default=50, ge=1, le=200, description="Maximum context messages"
)
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules"
@@ -92,31 +93,6 @@ class ChatConfig(BaseSettings):
description="Name of the prompt in Langfuse to fetch",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
description="Use Claude Agent SDK for chat completions",
)
claude_agent_model: str | None = Field(
default=None,
description="Model for the Claude Agent SDK path. If None, derives from "
"the `model` field by stripping the OpenRouter provider prefix.",
)
claude_agent_max_buffer_size: int = Field(
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. "
"Increase if tool outputs exceed the limit.",
)
claude_agent_max_subtasks: int = Field(
default=10,
description="Max number of sub-agent Tasks the SDK can spawn per session.",
)
claude_agent_use_resume: bool = Field(
default=True,
description="Use --resume for multi-turn conversations instead of "
"history compression. Falls back to compression when unavailable.",
)
# Extended thinking configuration for Claude models
thinking_enabled: bool = Field(
default=True,
@@ -162,17 +138,6 @@ class ChatConfig(BaseSettings):
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):
"""Get use_claude_agent_sdk from environment if not provided."""
# Check environment variable - default to True if not set
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
# Default to True (SDK enabled by default)
return True if v is None else v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -334,8 +334,9 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
@@ -377,9 +378,11 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
return None
messages = prisma_session.Messages
logger.debug(
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
logger.info(
f"Loading session {session_id} from DB: "
f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
)
return ChatSession.from_db(prisma_session, messages)
@@ -430,9 +433,10 @@ async def _save_session_to_db(
"function_call": msg.function_call,
}
)
logger.debug(
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
f"roles={[m['role'] for m in messages_data]}"
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
@@ -472,7 +476,7 @@ async def get_chat_session(
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
# Fall back to database
logger.debug(f"Session {session_id} not in cache, checking database")
logger.info(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id)
if session is None:
@@ -489,6 +493,7 @@ async def get_chat_session(
# Cache the session from DB
try:
await _cache_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
@@ -553,40 +558,6 @@ async def upsert_chat_session(
return session
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves — preventing message loss when
concurrent requests modify the same session.
"""
lock = await _get_session_lock(session_id)
async with lock:
session = await get_chat_session(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db.get_chat_session_message_count(
session_id
)
try:
await _save_session_to_db(session, existing_message_count)
except Exception as e:
raise DatabaseError(
f"Failed to persist message to session {session_id}"
) from e
try:
await _cache_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
return session
async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it.
@@ -693,19 +664,13 @@ async def update_session_title(session_id: str, title: str) -> bool:
logger.warning(f"Session {session_id} not found for title update")
return False
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
# Invalidate cache so next fetch gets updated title
try:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await _cache_session(cached)
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
except Exception as e:
# Not critical - title will be correct on next full cache refresh
logger.warning(
f"Failed to update title in cache for session {session_id}: {e}"
)
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
return True
except Exception as e:

View File

@@ -1,6 +1,5 @@
"""Chat API routes for chat session management and streaming via SSE."""
import asyncio
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
@@ -12,22 +11,13 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.util.exceptions import NotFoundError
from backend.util.feature_flag import Flag, is_feature_enabled
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import (
ChatMessage,
ChatSession,
append_and_save_message,
create_chat_session,
get_chat_session,
get_user_sessions,
)
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
from .sdk import service as sdk_service
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
from .response_model import StreamFinish, StreamHeartbeat
from .tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
@@ -51,7 +41,6 @@ from .tools.models import (
SetupRequirementsResponse,
UnderstandingUpdatedResponse,
)
from .tracking import track_user_message
config = ChatConfig()
@@ -243,10 +232,6 @@ async def get_session(
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
@@ -316,9 +301,10 @@ async def stream_chat_post(
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
session = await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
@@ -327,25 +313,6 @@ async def stream_chat_post(
},
)
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
session = await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
@@ -361,7 +328,7 @@ async def stream_chat_post(
operation_id=operation_id,
)
logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
@@ -382,47 +349,15 @@ async def stream_chat_post(
first_chunk_time, ttfc = None, None
chunk_count = 0
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
logger.info(
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
* 1000,
}
},
)
# Choose service based on LaunchDarkly flag (falls back to config default)
use_sdk = await is_feature_enabled(
Flag.COPILOT_SDK,
user_id or "anonymous",
default=config.use_claude_agent_sdk,
)
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else chat_service.stream_chat_completion
)
logger.info(
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
extra={"json_fields": log_meta},
)
# Pass message=None since we already added it to the session above
async for chunk in stream_fn(
async for chunk in chat_service.stream_chat_completion(
session_id,
None, # Message already in session
request.message,
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass session with message already added
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context,
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
):
# Skip duplicate StreamStart — we already published one above
if isinstance(chunk, StreamStart):
continue
chunk_count += 1
if first_chunk_time is None:
first_chunk_time = time_module.perf_counter()
@@ -443,7 +378,7 @@ async def stream_chat_post(
gen_end_time = time_module.perf_counter()
total_time = (gen_end_time - gen_start_time) * 1000
logger.info(
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
f"task={task_id}, session={session_id}, "
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
extra={
@@ -470,17 +405,6 @@ async def stream_chat_post(
}
},
)
# Publish a StreamError so the frontend can display an error message
try:
await stream_registry.publish_chunk(
task_id,
StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
),
)
except Exception:
pass # Best-effort; mark_task_completed will publish StreamFinish
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
@@ -583,14 +507,8 @@ async def stream_chat_post(
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
},
)
# Surface error to frontend so it doesn't appear stuck
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
finally:
# Unsubscribe when client disconnects or stream ends
# Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_task(
@@ -834,6 +752,8 @@ async def stream_task(
)
async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:

View File

@@ -1,14 +0,0 @@
"""Claude Agent SDK integration for CoPilot.
This module provides the integration layer between the Claude Agent SDK
and the existing CoPilot tool system, enabling drop-in replacement of
the current LLM orchestration with the battle-tested Claude Agent SDK.
"""
from .service import stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server
__all__ = [
"stream_chat_completion_sdk",
"create_copilot_mcp_server",
]

View File

@@ -1,203 +0,0 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""
import json
import logging
import uuid
from claude_agent_sdk import (
AssistantMessage,
Message,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.api.features.chat.sdk.tool_adapter import (
MCP_TOOL_PREFIX,
pop_pending_tool_output,
)
logger = logging.getLogger(__name__)
class SDKResponseAdapter:
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This class maintains state during a streaming session to properly track
text blocks, tool calls, and message lifecycle.
"""
def __init__(self, message_id: str | None = None):
self.message_id = message_id or str(uuid.uuid4())
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.task_id: str | None = None
self.step_open = False
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format."""
responses: list[StreamBaseResponse] = []
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(messageId=self.message_id, taskId=self.task_id)
)
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())
self.step_open = True
elif isinstance(sdk_message, AssistantMessage):
# After tool results, the SDK sends a new AssistantMessage for the
# next LLM turn. Open a new step if the previous one was closed.
if not self.step_open:
responses.append(StreamStartStep())
self.step_open = True
for block in sdk_message.content:
if isinstance(block, TextBlock):
if block.text:
self._ensure_text_started(responses)
responses.append(
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
elif isinstance(block, ToolUseBlock):
self._end_text_if_open(responses)
# Strip MCP prefix so frontend sees "find_block"
# instead of "mcp__copilot__find_block".
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
responses.append(
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
)
responses.append(
StreamToolInputAvailable(
toolCallId=block.id,
toolName=tool_name,
input=block.input,
)
)
self.current_tool_calls[block.id] = {"name": tool_name}
elif isinstance(sdk_message, UserMessage):
# UserMessage carries tool results back from tool execution.
content = sdk_message.content
blocks = content if isinstance(content, list) else []
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
# Prefer the stashed full output over the SDK's
# (potentially truncated) ToolResultBlock content.
# The SDK truncates large results, writing them to disk,
# which breaks frontend widget parsing.
output = pop_pending_tool_output(tool_name) or (
_extract_tool_output(block.content)
)
responses.append(
StreamToolOutputAvailable(
toolCallId=block.tool_use_id,
toolName=tool_name,
output=output,
success=not (block.is_error or False),
)
)
# Close the current step after tool results — the next
# AssistantMessage will open a new step for the continuation.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
elif isinstance(sdk_message, ResultMessage):
self._end_text_if_open(responses)
# Close the step before finishing.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
if sdk_message.subtype == "success":
responses.append(StreamFinish())
elif sdk_message.subtype in ("error", "error_during_execution"):
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
responses.append(
StreamError(errorText=str(error_msg), code="sdk_error")
)
responses.append(StreamFinish())
else:
logger.warning(
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
)
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
return responses
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
"""Start (or restart) a text block if needed."""
if not self.has_started_text or self.has_ended_text:
if self.has_ended_text:
self.text_block_id = str(uuid.uuid4())
self.has_ended_text = False
responses.append(StreamTextStart(id=self.text_block_id))
self.has_started_text = True
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
"""End the current text block if one is open."""
if self.has_started_text and not self.has_ended_text:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract a string output from a ToolResultBlock's content field."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
if parts:
return "".join(parts)
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)
if content is None:
return ""
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)

View File

@@ -1,366 +0,0 @@
"""Unit tests for the SDK response adapter."""
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
def _adapter() -> SDKResponseAdapter:
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a
# -- SystemMessage -----------------------------------------------------------
def test_system_init_emits_start_and_step():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
assert len(results) == 2
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].taskId == "task-1"
assert isinstance(results[1], StreamStartStep)
def test_system_non_init_emits_nothing():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
assert results == []
# -- AssistantMessage with TextBlock -----------------------------------------
def test_text_block_emits_step_start_and_delta():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
results = adapter.convert_message(msg)
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "hello"
def test_empty_text_block_emits_only_step():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
results = adapter.convert_message(msg)
# Empty text skipped, but step still opens
assert len(results) == 1
assert isinstance(results[0], StreamStartStep)
def test_multiple_text_deltas_reuse_block_id():
adapter = _adapter()
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
r1 = adapter.convert_message(msg1)
r2 = adapter.convert_message(msg2)
# First gets step+start+delta, second only delta (block & step already started)
assert len(r1) == 3
assert isinstance(r1[0], StreamStartStep)
assert isinstance(r1[1], StreamTextStart)
assert len(r2) == 1
assert isinstance(r2[0], StreamTextDelta)
assert r1[1].id == r2[0].id # same block ID
# -- AssistantMessage with ToolUseBlock --------------------------------------
def test_tool_use_emits_input_start_and_available():
"""Tool names arrive with MCP prefix and should be stripped for the frontend."""
adapter = _adapter()
msg = AssistantMessage(
content=[
ToolUseBlock(
id="tool-1",
name=f"{MCP_TOOL_PREFIX}find_agent",
input={"q": "x"},
)
],
model="test",
)
results = adapter.convert_message(msg)
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamToolInputStart)
assert results[1].toolCallId == "tool-1"
assert results[1].toolName == "find_agent" # prefix stripped
assert isinstance(results[2], StreamToolInputAvailable)
assert results[2].toolName == "find_agent" # prefix stripped
assert results[2].input == {"q": "x"}
def test_text_then_tool_ends_text_block():
adapter = _adapter()
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
adapter.convert_message(text_msg) # opens step + text
results = adapter.convert_message(tool_msg)
# Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamToolInputStart)
# -- UserMessage with ToolResultBlock ----------------------------------------
def test_tool_result_emits_output_and_finish_step():
adapter = _adapter()
# First register the tool call (opens step) — SDK sends prefixed name
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
model="test",
)
adapter.convert_message(tool_msg)
# Now send tool result
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
)
results = adapter.convert_message(result_msg)
assert len(results) == 2
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].toolCallId == "t1"
assert results[0].toolName == "find_agent" # prefix stripped
assert results[0].output == "found 3 agents"
assert results[0].success is True
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_error():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
],
model="test",
)
)
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].success is False
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_list_content():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
)
result_msg = UserMessage(
content=[
ToolResultBlock(
tool_use_id="t1",
content=[
{"type": "text", "text": "line1"},
{"type": "text", "text": "line2"},
],
)
]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].output == "line1line2"
assert isinstance(results[1], StreamFinishStep)
def test_string_user_message_ignored():
"""A plain string UserMessage (not tool results) produces no output."""
adapter = _adapter()
results = adapter.convert_message(UserMessage(content="hello"))
assert results == []
# -- ResultMessage -----------------------------------------------------------
def test_result_success_emits_finish_step_and_finish():
adapter = _adapter()
# Start some text first (opens step)
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="done")], model="test")
)
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
results = adapter.convert_message(msg)
# TextEnd + FinishStep + StreamFinish
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamFinishStep)
assert isinstance(results[2], StreamFinish)
def test_result_error_emits_error_and_finish():
adapter = _adapter()
msg = ResultMessage(
subtype="error",
duration_ms=100,
duration_api_ms=50,
is_error=True,
num_turns=0,
session_id="s1",
result="API rate limited",
)
results = adapter.convert_message(msg)
# No step was open, so no FinishStep — just Error + Finish
assert len(results) == 2
assert isinstance(results[0], StreamError)
assert "API rate limited" in results[0].errorText
assert isinstance(results[1], StreamFinish)
# -- Text after tools (new block ID) ----------------------------------------
def test_text_after_tool_gets_new_block_id():
adapter = _adapter()
# Text -> Tool -> ToolResult -> Text should get a new text block ID and step
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="before")], model="test")
)
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
)
# Send tool result (closes step)
adapter.convert_message(
UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
)
results = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="after")], model="test")
)
# Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "after"
# -- Full conversation flow --------------------------------------------------
def test_full_conversation_flow():
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Assistant text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
)
)
# 3. Tool use
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(
id="t1",
name=f"{MCP_TOOL_PREFIX}find_agent",
input={"query": "email"},
)
],
model="test",
)
)
)
# 4. Tool result
all_responses.extend(
adapter.convert_message(
UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
)
)
)
# 5. More text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
)
)
# 6. Result
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=500,
duration_api_ms=400,
is_error=False,
num_turns=2,
session_id="s1",
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep", # step 1: text + tool call
"StreamTextStart",
"StreamTextDelta", # "Let me search"
"StreamTextEnd", # closed before tool
"StreamToolInputStart",
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # tool result
"StreamFinishStep", # step 1 closed after tool result
"StreamStartStep", # step 2: continuation text
"StreamTextStart", # new block after tool
"StreamTextDelta", # "I found 2"
"StreamTextEnd", # closed by result
"StreamFinishStep", # step 2 closed
"StreamFinish",
]

View File

@@ -1,305 +0,0 @@
"""Security hooks for Claude Agent SDK integration.
This module provides security hooks that validate tool calls before execution,
ensuring multi-user isolation and preventing unauthorized operations.
"""
import json
import logging
import os
import re
from collections.abc import Callable
from typing import Any, cast
from backend.api.features.chat.sdk.tool_adapter import (
BLOCKED_TOOLS,
DANGEROUS_PATTERNS,
MCP_TOOL_PREFIX,
WORKSPACE_SCOPED_TOOLS,
)
logger = logging.getLogger(__name__)
def _deny(reason: str) -> dict[str, Any]:
"""Return a hook denial response."""
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": reason,
}
}
def _validate_workspace_path(
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
"""Validate that a workspace-scoped tool only accesses allowed paths.
Allowed directories:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
# Glob/Grep without a path default to cwd which is already sandboxed
return {}
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
# naturally uses relative paths like "test.txt" instead of absolute ones).
# Tilde paths (~/) are home-dir references, not relative — expand first.
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
return {}
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
tool_results_seg = os.sep + "tool-results" + os.sep
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
return {}
logger.warning(
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
)
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
f"directory.{workspace_hint} "
"This is enforced by the platform and cannot be bypassed."
)
def _validate_tool_access(
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
) -> dict[str, Any]:
"""Validate that a tool call is allowed.
Returns:
Empty dict to allow, or dict with hookSpecificOutput to deny
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning(f"Blocked tool access attempt: {tool_name}")
return _deny(
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
"This is enforced by the platform and cannot be bypassed. "
"Use the CoPilot-specific MCP tools instead."
)
# Workspace-scoped tools: allowed only within the SDK workspace directory
if tool_name in WORKSPACE_SCOPED_TOOLS:
return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
# Check for dangerous patterns in tool input
# Use json.dumps for predictable format (str() produces Python repr)
input_str = json.dumps(tool_input) if tool_input else ""
for pattern in DANGEROUS_PATTERNS:
if re.search(pattern, input_str, re.IGNORECASE):
logger.warning(
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
)
return _deny(
"[SECURITY] Input contains a blocked pattern. "
"This is enforced by the platform and cannot be bypassed."
)
return {}
def _validate_user_isolation(
tool_name: str, tool_input: dict[str, Any], user_id: str | None
) -> dict[str, Any]:
"""Validate that tool calls respect user isolation."""
# For workspace file tools, ensure path doesn't escape
if "workspace" in tool_name.lower():
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path:
# Check for path traversal
if ".." in path or path.startswith("/"):
logger.warning(
f"Blocked path traversal attempt: {path} by user {user_id}"
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Path traversal not allowed",
}
}
return {}
def create_security_hooks(
user_id: str | None,
sdk_cwd: str | None = None,
max_subtasks: int = 3,
on_stop: Callable[[str, str], None] | None = None,
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
Includes security validation and observability hooks:
- PreToolUse: Security validation before tool execution
- PostToolUse: Log successful tool executions
- PostToolUseFailure: Log and handle failed tool executions
- PreCompact: Log context compaction events (SDK handles compaction automatically)
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum Task (sub-agent) spawns allowed per session
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
the SDK finishes processing — used to read the JSONL transcript
before the CLI process exits.
Returns:
Hooks configuration dict for ClaudeAgentOptions
"""
try:
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
# Per-session counter for Task sub-agent spawns
task_spawn_count = 0
async def pre_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Combined pre-tool-use validation hook."""
nonlocal task_spawn_count
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Rate-limit Task (sub-agent) spawns per session
if tool_name == "Task":
task_spawn_count += 1
if task_spawn_count > max_subtasks:
logger.warning(
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
)
return cast(
SyncHookJSONOutput,
_deny(
f"Maximum {max_subtasks} sub-tasks per session. "
"Please continue in the main conversation."
),
)
# Strip MCP prefix for consistent validation
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
# Only block non-CoPilot tools; our MCP-registered tools
# (including Read for oversized results) are already sandboxed.
if not is_copilot_tool:
result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
if result:
return cast(SyncHookJSONOutput, result)
# Validate user isolation
result = _validate_user_isolation(clean_name, tool_input, user_id)
if result:
return cast(SyncHookJSONOutput, result)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log successful tool executions for observability."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_failure_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log failed tool executions for debugging."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
error = input_data.get("error", "Unknown error")
logger.warning(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
return cast(SyncHookJSONOutput, {})
async def pre_compact_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log when SDK triggers context compaction.
The SDK automatically compacts conversation history when it grows too large.
This hook provides visibility into when compaction happens.
"""
_ = context, tool_use_id
trigger = input_data.get("trigger", "auto")
logger.info(
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
)
return cast(SyncHookJSONOutput, {})
# --- Stop hook: capture transcript path for stateless resume ---
async def stop_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Capture transcript path when SDK finishes processing.
The Stop hook fires while the CLI process is still alive, giving us
a reliable window to read the JSONL transcript before SIGTERM.
"""
_ = context, tool_use_id
transcript_path = cast(str, input_data.get("transcript_path", ""))
sdk_session_id = cast(str, input_data.get("session_id", ""))
if transcript_path and on_stop:
logger.info(
f"[SDK] Stop hook: transcript_path={transcript_path}, "
f"sdk_session_id={sdk_session_id[:12]}..."
)
on_stop(transcript_path, sdk_session_id)
return cast(SyncHookJSONOutput, {})
hooks: dict[str, Any] = {
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
"PostToolUseFailure": [
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
],
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
}
if on_stop is not None:
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
return hooks
except ImportError:
# Fallback for when SDK isn't available - return empty hooks
logger.warning("claude-agent-sdk not available, security hooks disabled")
return {}

View File

@@ -1,165 +0,0 @@
"""Unit tests for SDK security hooks."""
import os
from .security_hooks import _validate_tool_access, _validate_user_isolation
SDK_CWD = "/tmp/copilot-abc123"
def _is_denied(result: dict) -> bool:
hook = result.get("hookSpecificOutput", {})
return hook.get("permissionDecision") == "deny"
# -- Blocked tools -----------------------------------------------------------
def test_blocked_tools_denied():
for tool in ("bash", "shell", "exec", "terminal", "command"):
result = _validate_tool_access(tool, {})
assert _is_denied(result), f"{tool} should be blocked"
def test_unknown_tool_allowed():
result = _validate_tool_access("SomeCustomTool", {})
assert result == {}
# -- Workspace-scoped tools --------------------------------------------------
def test_read_within_workspace_allowed():
result = _validate_tool_access(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_write_within_workspace_allowed():
result = _validate_tool_access(
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_edit_within_workspace_allowed():
result = _validate_tool_access(
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_glob_within_workspace_allowed():
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_grep_within_workspace_allowed():
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_outside_workspace_denied():
result = _validate_tool_access(
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_write_outside_workspace_denied():
result = _validate_tool_access(
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_traversal_attack_denied():
result = _validate_tool_access(
"Read",
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
sdk_cwd=SDK_CWD,
)
assert _is_denied(result)
def test_no_path_allowed():
"""Glob/Grep without a path argument defaults to cwd — should pass."""
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_no_cwd_denies_absolute():
"""If no sdk_cwd is set, absolute paths are denied."""
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
assert _is_denied(result)
# -- Tool-results directory --------------------------------------------------
def test_read_tool_results_allowed():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_claude_projects_without_tool_results_denied():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
def test_bash_builtin_always_blocked():
"""SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead."""
result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Dangerous patterns ------------------------------------------------------
def test_dangerous_pattern_blocked():
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
assert _is_denied(result)
def test_subprocess_pattern_blocked():
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
assert _is_denied(result)
# -- User isolation ----------------------------------------------------------
def test_workspace_path_traversal_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_absolute_path_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_normal_path_allowed():
result = _validate_user_isolation(
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
)
assert result == {}
def test_non_workspace_tool_passes_isolation():
result = _validate_user_isolation(
"find_agent", {"query": "email"}, user_id="user-1"
)
assert result == {}

View File

@@ -1,752 +0,0 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""
import asyncio
import json
import logging
import os
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from backend.util.exceptions import NotFoundError
from .. import stream_registry
from ..config import ChatConfig
from ..model import (
ChatMessage,
ChatSession,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from ..service import (
_build_system_prompt,
_execute_long_running_tool_with_streaming,
_generate_session_title,
)
from ..tools.models import OperationPendingResponse, OperationStartedResponse
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
COPILOT_TOOL_NAMES,
SDK_DISALLOWED_TOOLS,
LongRunningCallback,
create_copilot_mcp_server,
set_execution_context,
)
from .transcript import (
download_transcript,
read_transcript_file,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
)
logger = logging.getLogger(__name__)
config = ChatConfig()
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
@dataclass
class CapturedTranscript:
"""Info captured by the SDK Stop hook for stateless --resume."""
path: str = ""
sdk_session_id: str = ""
@property
def available(self) -> bool:
return bool(self.path)
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
# Appended to the system prompt to inform the agent about available tools.
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
# which has kernel-level network isolation (unshare --net).
_SDK_TOOL_SUPPLEMENT = """
## Tool notes
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs in a network-isolated sandbox.
- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
same working directory. Files created by one are readable by the other.
These files are **ephemeral** — they exist only for the current session.
- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
for files that should persist across sessions (stored in cloud storage).
- Long-running tools (create_agent, edit_agent, etc.) are handled
asynchronously. You will receive an immediate response; the actual result
is delivered to the user via a background stream.
"""
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
existing background infrastructure: stream_registry (Redis Streams),
database persistence, and SSE reconnection. This means results survive
page refreshes / pod restarts, and the frontend shows the proper loading
widget with progress updates.
The returned callback matches the ``LongRunningCallback`` signature:
``(tool_name, args, session) -> MCP response dict``.
"""
async def _callback(
tool_name: str, args: dict[str, Any], session: ChatSession
) -> dict[str, Any]:
operation_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
session_id = session.session_id
# --- Build user-friendly messages (matches non-SDK service) ---
if tool_name == "create_agent":
desc = args.get("description", "")
desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
pending_msg = (
f"Creating your agent: {desc_preview}"
if desc_preview
else "Creating agent... This may take a few minutes."
)
started_msg = (
"Agent creation started. You can close this tab - "
"check your library in a few minutes."
)
elif tool_name == "edit_agent":
changes = args.get("changes", "")
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
pending_msg = (
f"Editing agent: {changes_preview}"
if changes_preview
else "Editing agent... This may take a few minutes."
)
started_msg = (
"Agent edit started. You can close this tab - "
"check your library in a few minutes."
)
else:
pending_msg = f"Running {tool_name}... This may take a few minutes."
started_msg = (
f"{tool_name} started. You can close this tab - "
"check back in a few minutes."
)
# --- Register task in Redis for SSE reconnection ---
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# --- Save OperationPendingResponse to chat history ---
pending_message = ChatMessage(
role="tool",
content=OperationPendingResponse(
message=pending_msg,
operation_id=operation_id,
tool_name=tool_name,
).model_dump_json(),
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
await upsert_chat_session(session)
# --- Spawn background task (reuses non-SDK infrastructure) ---
bg_task = asyncio.create_task(
_execute_long_running_tool_with_streaming(
tool_name=tool_name,
parameters=args,
tool_call_id=tool_call_id,
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
user_id=user_id,
)
)
_background_tasks.add(bg_task)
bg_task.add_done_callback(_background_tasks.discard)
await stream_registry.set_task_asyncio_task(task_id, bg_task)
logger.info(
f"[SDK] Long-running tool {tool_name} delegated to background "
f"(operation_id={operation_id}, task_id={task_id})"
)
# --- Return OperationStartedResponse as MCP tool result ---
# This flows through SDK → response adapter → frontend, triggering
# the loading widget with SSE reconnection support.
started_json = OperationStartedResponse(
message=started_msg,
operation_id=operation_id,
tool_name=tool_name,
task_id=task_id,
).model_dump_json()
return {
"content": [{"type": "text", "text": started_json}],
"isError": False,
}
return _callback
def _resolve_sdk_model() -> str | None:
"""Resolve the model name for the Claude Agent SDK CLI.
Uses ``config.claude_agent_model`` if set, otherwise derives from
``config.model`` by stripping the OpenRouter provider prefix (e.g.,
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
"""
if config.claude_agent_model:
return config.claude_agent_model
model = config.model
if "/" in model:
return model.split("/", 1)[1]
return model
def _build_sdk_env() -> dict[str, str]:
"""Build env vars for the SDK CLI process.
Routes API calls through OpenRouter (or a custom base_url) using
the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
This gives per-call token and cost tracking on the OpenRouter dashboard.
Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
token are both present — otherwise returns an empty dict so the SDK
falls back to its default credentials.
"""
env: dict[str, str] = {}
if config.api_key and config.base_url:
# Strip /v1 suffix — SDK expects the base URL without a version path
base = config.base_url.rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
if not base or not base.startswith("http"):
# Invalid base_url — don't override SDK defaults
return env
env["ANTHROPIC_BASE_URL"] = base
env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
# Must be explicitly empty so the CLI uses AUTH_TOKEN instead
env["ANTHROPIC_API_KEY"] = ""
return env
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
(single source of truth for path sanitization) and adds a defence-in-depth
assertion.
"""
cwd = make_session_path(session_id)
# Defence-in-depth: normpath + startswith is a CodeQL-recognised sanitizer
cwd = os.path.normpath(cwd)
if not cwd.startswith(_SDK_CWD_PREFIX):
raise ValueError(f"SDK cwd escaped prefix: {cwd}")
return cwd
def _cleanup_sdk_tool_results(cwd: str) -> None:
"""Remove SDK tool-result files for a specific session working directory.
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
We clean only the specific cwd's results to avoid race conditions between
concurrent sessions.
Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
"""
import shutil
# Validate cwd is under the expected prefix
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
return
# SDK encodes the cwd path by replacing '/' with '-'
encoded_cwd = normalized.replace("/", "-")
# Construct the project directory path (known-safe home expansion)
claude_projects = os.path.expanduser("~/.claude/projects")
project_dir = os.path.join(claude_projects, encoded_cwd)
# Security check 3: Validate project_dir is under ~/.claude/projects
project_dir = os.path.normpath(project_dir)
if not project_dir.startswith(claude_projects):
logger.warning(
f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
)
return
results_dir = os.path.join(project_dir, "tool-results")
if os.path.isdir(results_dir):
for filename in os.listdir(results_dir):
file_path = os.path.join(results_dir, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
except OSError:
pass
# Also clean up the temp cwd directory itself
try:
shutil.rmtree(normalized, ignore_errors=True)
except OSError:
pass
async def _compress_conversation_history(
session: ChatSession,
) -> list[ChatMessage]:
"""Compress prior conversation messages if they exceed the token threshold.
Uses the shared compress_context() from prompt.py which supports:
- LLM summarization of old messages (keeps recent ones intact)
- Progressive content truncation as fallback
- Middle-out deletion as last resort
Returns the compressed prior messages (everything except the current message).
"""
prior = session.messages[:-1]
if len(prior) < 2:
return prior
from backend.util.prompt import compress_context
# Convert ChatMessages to dicts for compress_context
messages_dict = []
for msg in prior:
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
if msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
try:
import openai
async with openai.AsyncOpenAI(
api_key=config.api_key, base_url=config.base_url, timeout=30.0
) as client:
result = await compress_context(
messages=messages_dict,
model=config.model,
client=client,
)
except Exception as e:
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
# Fall back to truncation-only (no LLM summarization)
result = await compress_context(
messages=messages_dict,
model=config.model,
client=None,
)
if result.was_compacted:
logger.info(
f"[SDK] Context compacted: {result.original_token_count} -> "
f"{result.token_count} tokens "
f"({result.messages_summarized} summarized, "
f"{result.messages_dropped} dropped)"
)
# Convert compressed dicts back to ChatMessages
return [
ChatMessage(
role=m["role"],
content=m.get("content"),
tool_calls=m.get("tool_calls"),
tool_call_id=m.get("tool_call_id"),
)
for m in result.messages
]
return prior
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
"""Format conversation messages into a context prefix for the user message.
Returns a string like:
<conversation_history>
User: hello
You responded: Hi! How can I help?
</conversation_history>
Returns None if there are no messages to format.
"""
if not messages:
return None
lines: list[str] = []
for msg in messages:
if not msg.content:
continue
if msg.role == "user":
lines.append(f"User: {msg.content}")
elif msg.role == "assistant":
lines.append(f"You responded: {msg.content}")
# Skip tool messages — they're internal details
if not lines:
return None
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None, # noqa: ARG001
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0, # noqa: ARG001
session: ChatSession | None = None,
context: dict[str, str] | None = None, # noqa: ARG001
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream chat completion using Claude Agent SDK.
Drop-in replacement for stream_chat_completion with improved reliability.
"""
if session is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(
f"Session {session_id} not found. Please create a new session first."
)
if message:
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message
)
)
if is_user_message:
track_user_message(
user_id=user_id, session_id=session_id, message_length=len(message)
)
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Build system prompt (reuses non-SDK path with Langfuse support)
has_history = len(session.messages) > 1
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=has_history
)
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
yield StreamStart(messageId=message_id, taskId=task_id)
stream_completed = False
# Initialise sdk_cwd before the try so the finally can reference it
# even if _make_sdk_cwd raises (in that case it stays as "").
sdk_cwd = ""
use_resume = False
try:
# Use a session-specific temp dir to avoid cleanup race conditions
# between concurrent sessions.
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
set_execution_context(
user_id,
session,
long_running_callback=_build_long_running_callback(user_id),
)
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
# Fail fast when no API credentials are available at all
sdk_env = _build_sdk_env()
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
raise RuntimeError(
"No API key configured. Set OPEN_ROUTER_API_KEY "
"(or CHAT_API_KEY) for OpenRouter routing, "
"or ANTHROPIC_API_KEY for direct Anthropic access."
)
mcp_server = create_copilot_mcp_server()
sdk_model = _resolve_sdk_model()
# --- Transcript capture via Stop hook ---
captured_transcript = CapturedTranscript()
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
captured_transcript.path = transcript_path
captured_transcript.sdk_session_id = sdk_session_id
security_hooks = create_security_hooks(
user_id,
sdk_cwd=sdk_cwd,
max_subtasks=config.claude_agent_max_subtasks,
on_stop=_on_stop if config.claude_agent_use_resume else None,
)
# --- Resume strategy: download transcript from bucket ---
resume_file: str | None = None
use_resume = False
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
transcript_content = await download_transcript(user_id, session_id)
if transcript_content and validate_transcript(transcript_content):
resume_file = write_transcript_to_tempfile(
transcript_content, session_id, sdk_cwd
)
if resume_file:
use_resume = True
logger.info(
f"[SDK] Using --resume with transcript "
f"({len(transcript_content)} bytes)"
)
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"mcp_servers": {"copilot": mcp_server},
"allowed_tools": COPILOT_TOOL_NAMES,
"disallowed_tools": SDK_DISALLOWED_TOOLS,
"hooks": security_hooks,
"cwd": sdk_cwd,
"max_buffer_size": config.claude_agent_max_buffer_size,
}
if sdk_env:
sdk_options_kwargs["model"] = sdk_model
sdk_options_kwargs["env"] = sdk_env
if use_resume and resume_file:
sdk_options_kwargs["resume"] = resume_file
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
adapter = SDKResponseAdapter(message_id=message_id)
adapter.set_task_id(task_id)
async with ClaudeSDKClient(options=options) as client:
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
# Build query: with --resume the CLI already has full
# context, so we only send the new message. Without
# resume, compress history into a context prefix.
query_message = current_message
if not use_resume and len(session.messages) > 1:
logger.warning(
f"[SDK] Using compression fallback for session "
f"{session_id} ({len(session.messages)} messages) — "
f"no transcript available for --resume"
)
compressed = await _compress_conversation_history(session)
history_context = _format_conversation_context(compressed)
if history_context:
query_message = (
f"{history_context}\n\n"
f"Now, the user says:\n{current_message}"
)
logger.info(
f"[SDK] Sending query ({len(session.messages)} msgs in session)"
)
logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
await client.query(query_message, session_id=session_id)
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
async for sdk_msg in client.receive_messages():
logger.debug(
f"[SDK] Received: {type(sdk_msg).__name__} "
f"{getattr(sdk_msg, 'subtype', '')}"
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
if stream_completed:
break
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
# --- Capture transcript while CLI is still alive ---
# Must happen INSIDE async with: close() sends SIGTERM
# which kills the CLI before it can flush the JSONL.
if (
config.claude_agent_use_resume
and user_id
and captured_transcript.available
):
# Give CLI time to flush JSONL writes before we read
await asyncio.sleep(0.5)
raw_transcript = read_transcript_file(captured_transcript.path)
if raw_transcript:
task = asyncio.create_task(
_upload_transcript_bg(user_id, session_id, raw_transcript)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
else:
logger.debug("[SDK] Stop hook fired but transcript not usable")
except ImportError:
raise RuntimeError(
"claude-agent-sdk is not installed. "
"Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) "
"to use the OpenAI-compatible fallback."
)
await upsert_chat_session(session)
logger.debug(
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
)
if not stream_completed:
yield StreamFinish()
except Exception as e:
logger.error(f"[SDK] Error: {e}", exc_info=True)
try:
await upsert_chat_session(session)
except Exception as save_err:
logger.error(f"[SDK] Failed to save session on error: {save_err}")
yield StreamError(
errorText="An error occurred. Please try again.",
code="sdk_error",
)
yield StreamFinish()
finally:
if sdk_cwd:
_cleanup_sdk_tool_results(sdk_cwd)
async def _upload_transcript_bg(
user_id: str, session_id: str, raw_content: str
) -> None:
"""Background task to strip progress entries and upload transcript."""
try:
await upload_transcript(user_id, session_id, raw_content)
except Exception as e:
logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")
async def _update_title_async(
session_id: str, message: str, user_id: str | None = None
) -> None:
"""Background task to update session title."""
try:
title = await _generate_session_title(
message, user_id=user_id, session_id=session_id
)
if title:
await update_session_title(session_id, title)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -1,363 +0,0 @@
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
via a callback provided by the service layer. This avoids wasteful SDK polling
and makes results survive page refreshes.
"""
import itertools
import json
import logging
import os
import uuid
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.api.features.chat.tools.base import BaseTool
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
# in the path — prevents reading settings, credentials, or other sensitive files.
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
"pending_tool_outputs", default=None # type: ignore[arg-type]
)
# Callback type for delegating long-running tools to the non-SDK infrastructure.
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
LongRunningCallback = Callable[
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
]
# ContextVar so the service layer can inject the callback per-request.
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
"long_running_callback", default=None
)
def set_execution_context(
user_id: str | None,
session: ChatSession,
long_running_callback: LongRunningCallback | None = None,
) -> None:
"""Set the execution context for tool calls.
This must be called before streaming begins to ensure tools have access
to user_id and session information.
Args:
user_id: Current user's ID.
session: Current chat session.
long_running_callback: Optional callback to delegate long-running tools
to the non-SDK background infrastructure (stream_registry + Redis).
"""
_current_user_id.set(user_id)
_current_session.set(session)
_pending_tool_outputs.set({})
_long_running_callback.set(long_running_callback)
def get_execution_context() -> tuple[str | None, ChatSession | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
)
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the stashed full output for *tool_name*.
The SDK CLI may truncate large tool results (writing them to disk and
replacing the content with a file reference). This stash keeps the
original MCP output so the response adapter can forward it to the
frontend for proper widget rendering.
Returns ``None`` if nothing was stashed for *tool_name*.
"""
pending = _pending_tool_outputs.get(None)
if pending is None:
return None
return pending.pop(tool_name, None)
async def _execute_tool_sync(
base_tool: BaseTool,
user_id: str | None,
session: ChatSession,
args: dict[str, Any],
) -> dict[str, Any]:
"""Execute a tool synchronously and return MCP-formatted response."""
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
session=session,
tool_call_id=effective_id,
**args,
)
text = (
result.output if isinstance(result.output, str) else json.dumps(result.output)
)
# Stash the full output before the SDK potentially truncates it.
pending = _pending_tool_outputs.get(None)
if pending is not None:
pending[base_tool.name] = text
return {
"content": [{"type": "text", "text": text}],
"isError": not result.success,
}
def _mcp_error(message: str) -> dict[str, Any]:
return {
"content": [
{"type": "text", "text": json.dumps({"error": message, "type": "error"})}
],
"isError": True,
}
def create_tool_handler(base_tool: BaseTool):
"""Create an async handler function for a BaseTool.
This wraps the existing BaseTool._execute method to be compatible
with the Claude Agent SDK MCP tool format.
Long-running tools (``is_long_running=True``) are delegated to the
non-SDK background infrastructure via a callback set in the execution
context. The callback persists the operation in Redis (stream_registry)
so results survive page refreshes and pod restarts.
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Execute the wrapped tool and return MCP-formatted response."""
user_id, session = get_execution_context()
if session is None:
return _mcp_error("No session context available")
# --- Long-running: delegate to non-SDK background infrastructure ---
if base_tool.is_long_running:
callback = _long_running_callback.get(None)
if callback:
try:
return await callback(base_tool.name, args, session)
except Exception as e:
logger.error(
f"Long-running callback failed for {base_tool.name}: {e}",
exc_info=True,
)
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
# No callback — fall through to synchronous execution
logger.warning(
f"[SDK] No long-running callback for {base_tool.name}, "
f"executing synchronously (may block)"
)
# --- Normal (fast) tool: execute synchronously ---
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
return tool_handler
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
"""Build a JSON Schema input schema for a tool."""
return {
"type": "object",
"properties": base_tool.parameters.get("properties", {}),
"required": base_tool.parameters.get("required", []),
}
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a file with optional offset/limit. Restricted to SDK working directory.
After reading, the file is deleted to prevent accumulation in long-running pods.
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
# Security: only allow reads under ~/.claude/projects/**/tool-results/
real_path = os.path.realpath(file_path)
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
try:
with open(real_path) as f:
selected = list(itertools.islice(f, offset, offset + limit))
content = "".join(selected)
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.
return {"content": [{"type": "text", "text": content}], "isError": False}
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
"isError": True,
}
except Exception as e:
return {
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
"isError": True,
}
_READ_TOOL_NAME = "Read"
_READ_TOOL_DESCRIPTION = (
"Read a file from the local filesystem. "
"Use offset and limit to read specific line ranges for large files."
)
_READ_TOOL_SCHEMA = {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The absolute path to the file to read",
},
"offset": {
"type": "integer",
"description": "Line number to start reading from (0-indexed). Default: 0",
},
"limit": {
"type": "integer",
"description": "Number of lines to read. Default: 2000",
},
},
"required": ["file_path"],
}
# Create the MCP server configuration
def create_copilot_mcp_server():
"""Create an in-process MCP server configuration for CoPilot tools.
This can be passed to ClaudeAgentOptions.mcp_servers.
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
package being available. This function returns the configuration that
can be used with the SDK.
"""
try:
from claude_agent_sdk import create_sdk_mcp_server, tool
# Create decorated tool functions
sdk_tools = []
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(handler)
sdk_tools.append(decorated)
# Add the Read tool so the SDK can read back oversized tool results
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_read_file_handler)
sdk_tools.append(read_tool)
server = create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
return server
except ImportError:
# Let ImportError propagate so service.py handles the fallback
raise
# SDK built-in tools allowed within the workspace directory.
# Security hooks validate that file paths stay within sdk_cwd.
# Bash is NOT included — use the sandboxed MCP bash_exec tool instead,
# which provides kernel-level network isolation via unshare --net.
# Task allows spawning sub-agents (rate-limited by security hooks).
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task", "WebSearch"]
# SDK built-in tools that must be explicitly blocked.
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
# network isolation (unshare --net) instead.
# WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.).
# Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead.
SDK_DISALLOWED_TOOLS = ["Bash", "WebFetch"]
# Tools that are blocked entirely in security hooks (defence-in-depth).
# Includes SDK_DISALLOWED_TOOLS plus common aliases/synonyms.
BLOCKED_TOOLS = {
*SDK_DISALLOWED_TOOLS,
"bash",
"shell",
"exec",
"terminal",
"command",
}
# Tools allowed only when their path argument stays within the SDK workspace.
# The SDK uses these to handle oversized tool results (writes to tool-results/
# files, then reads them back) and for workspace file operations.
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}
# Dangerous patterns in tool inputs
DANGEROUS_PATTERNS = [
r"sudo",
r"rm\s+-rf",
r"dd\s+if=",
r"/etc/passwd",
r"/etc/shadow",
r"chmod\s+777",
r"curl\s+.*\|.*sh",
r"wget\s+.*\|.*sh",
r"eval\s*\(",
r"exec\s*\(",
r"__import__",
r"os\.system",
r"subprocess",
]
# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*_SDK_BUILTIN_TOOLS,
]

View File

@@ -1,356 +0,0 @@
"""JSONL transcript management for stateless multi-turn resume.
The Claude Code CLI persists conversations as JSONL files (one JSON object per
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
(progress entries, metadata), and upload the result to bucket storage. On the
next turn we download the transcript, write it to a temp file, and pass
``--resume`` so the CLI can reconstruct the full conversation.
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
filesystem for self-hosted) — no DB column needed.
"""
import json
import logging
import os
import re
logger = logging.getLogger(__name__)
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
_SAFE_ID_RE = re.compile(r"[^0-9a-fA-F-]")
# Entry types that can be safely removed from the transcript without breaking
# the parentUuid conversation tree that ``--resume`` relies on.
# - progress: UI progress ticks, no message content (avg 97KB for agent_progress)
# - file-history-snapshot: undo tracking metadata
# - queue-operation: internal queue bookkeeping
# - summary: session summaries
# - pr-link: PR link metadata
STRIPPABLE_TYPES = frozenset(
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)
# Workspace storage constants — deterministic path from session_id.
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
# ---------------------------------------------------------------------------
# Progress stripping
# ---------------------------------------------------------------------------
def strip_progress_entries(content: str) -> str:
"""Remove progress/metadata entries from a JSONL transcript.
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
any remaining child entries so the ``parentUuid`` chain stays intact.
Typically reduces transcript size by ~30%.
"""
lines = content.strip().split("\n")
entries: list[dict] = []
for line in lines:
try:
entries.append(json.loads(line))
except json.JSONDecodeError:
# Keep unparseable lines as-is (safety)
entries.append({"_raw": line})
stripped_uuids: set[str] = set()
uuid_to_parent: dict[str, str] = {}
kept: list[dict] = []
for entry in entries:
if "_raw" in entry:
kept.append(entry)
continue
uid = entry.get("uuid", "")
parent = entry.get("parentUuid", "")
entry_type = entry.get("type", "")
if uid:
uuid_to_parent[uid] = parent
if entry_type in STRIPPABLE_TYPES:
if uid:
stripped_uuids.add(uid)
else:
kept.append(entry)
# Reparent: walk up chain through stripped entries to find surviving ancestor
for entry in kept:
if "_raw" in entry:
continue
parent = entry.get("parentUuid", "")
original_parent = parent
while parent in stripped_uuids:
parent = uuid_to_parent.get(parent, "")
if parent != original_parent:
entry["parentUuid"] = parent
result_lines: list[str] = []
for entry in kept:
if "_raw" in entry:
result_lines.append(entry["_raw"])
else:
result_lines.append(json.dumps(entry, separators=(",", ":")))
return "\n".join(result_lines) + "\n"
# ---------------------------------------------------------------------------
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
# ---------------------------------------------------------------------------
def read_transcript_file(transcript_path: str) -> str | None:
"""Read a JSONL transcript file from disk.
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
or only contains metadata (≤2 lines with no conversation messages).
"""
if not transcript_path or not os.path.isfile(transcript_path):
logger.debug(f"[Transcript] File not found: {transcript_path}")
return None
try:
with open(transcript_path) as f:
content = f.read()
if not content.strip():
logger.debug(f"[Transcript] Empty file: {transcript_path}")
return None
lines = content.strip().split("\n")
if len(lines) < 3:
# Raw files with ≤2 lines are metadata-only
# (queue-operation + file-history-snapshot, no conversation).
logger.debug(
f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
)
return None
# Quick structural validation — parse first and last lines.
json.loads(lines[0])
json.loads(lines[-1])
logger.info(
f"[Transcript] Read {len(lines)} lines, "
f"{len(content)} bytes from {transcript_path}"
)
return content
except (json.JSONDecodeError, OSError) as e:
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
return None
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
"""Sanitize an ID for safe use in file paths.
Session/user IDs are expected to be UUIDs (hex + hyphens). Strip
everything else and truncate to *max_len* so the result cannot introduce
path separators or other special characters.
"""
cleaned = _SAFE_ID_RE.sub("", raw_id or "")[:max_len]
return cleaned or "unknown"
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def write_transcript_to_tempfile(
transcript_content: str,
session_id: str,
cwd: str,
) -> str | None:
"""Write JSONL transcript to a temp file inside *cwd* for ``--resume``.
The file lives in the session working directory so it is cleaned up
automatically when the session ends.
Returns the absolute path to the file, or ``None`` on failure.
"""
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
real_cwd = os.path.realpath(cwd)
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
return None
try:
os.makedirs(real_cwd, exist_ok=True)
safe_id = _sanitize_id(session_id, max_len=8)
jsonl_path = os.path.realpath(
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
)
if not jsonl_path.startswith(real_cwd):
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
return None
with open(jsonl_path, "w") as f:
f.write(transcript_content)
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
return jsonl_path
except OSError as e:
logger.warning(f"[Transcript] Failed to write resume file: {e}")
return None
def validate_transcript(content: str | None) -> bool:
"""Check that a transcript has actual conversation messages.
A valid transcript for resume needs at least one user message and one
assistant message (not just queue-operation / file-history-snapshot
metadata).
"""
if not content or not content.strip():
return False
lines = content.strip().split("\n")
if len(lines) < 2:
return False
has_user = False
has_assistant = False
for line in lines:
try:
entry = json.loads(line)
msg_type = entry.get("type")
if msg_type == "user":
has_user = True
elif msg_type == "assistant":
has_assistant = True
except json.JSONDecodeError:
return False
return has_user and has_assistant
# ---------------------------------------------------------------------------
# Bucket storage (GCS / local via WorkspaceStorageBackend)
# ---------------------------------------------------------------------------
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for a session's transcript.
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
IDs are sanitized to hex+hyphen to prevent path traversal.
"""
return (
TRANSCRIPT_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.jsonl",
)
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path string that ``retrieve()`` expects.
``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
``local://workspace_id/file_id/filename``. Since we use deterministic
arguments we can reconstruct the same path for download/delete without
having stored the return value.
"""
from backend.util.workspace_storage import GCSWorkspaceStorage
wid, fid, fname = _storage_path_parts(user_id, session_id)
if isinstance(backend, GCSWorkspaceStorage):
blob = f"workspaces/{wid}/{fid}/{fname}"
return f"gcs://{backend.bucket_name}/{blob}"
else:
# LocalWorkspaceStorage returns local://{relative_path}
return f"local://{wid}/{fid}/{fname}"
async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
"""Strip progress entries and upload transcript to bucket storage.
Safety: only overwrites when the new (stripped) transcript is larger than
what is already stored. Since JSONL is append-only, the latest transcript
is always the longest. This prevents a slow/stale background task from
clobbering a newer upload from a concurrent turn.
"""
from backend.util.workspace_storage import get_workspace_storage
stripped = strip_progress_entries(content)
if not validate_transcript(stripped):
logger.warning(
f"[Transcript] Skipping upload — stripped content is not a valid "
f"transcript for session {session_id}"
)
return
storage = await get_workspace_storage()
wid, fid, fname = _storage_path_parts(user_id, session_id)
encoded = stripped.encode("utf-8")
new_size = len(encoded)
# Check existing transcript size to avoid overwriting newer with older
path = _build_storage_path(user_id, session_id, storage)
try:
existing = await storage.retrieve(path)
if len(existing) >= new_size:
logger.info(
f"[Transcript] Skipping upload — existing transcript "
f"({len(existing)}B) >= new ({new_size}B) for session "
f"{session_id}"
)
return
except (FileNotFoundError, Exception):
pass # No existing transcript or retrieval error — proceed with upload
await storage.store(
workspace_id=wid,
file_id=fid,
filename=fname,
content=encoded,
)
logger.info(
f"[Transcript] Uploaded {new_size} bytes "
f"(stripped from {len(content)}) for session {session_id}"
)
async def download_transcript(user_id: str, session_id: str) -> str | None:
"""Download transcript from bucket storage.
Returns the JSONL content string, or ``None`` if not found.
"""
from backend.util.workspace_storage import get_workspace_storage
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
try:
data = await storage.retrieve(path)
content = data.decode("utf-8")
logger.info(
f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
)
return content
except FileNotFoundError:
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
return None
except Exception as e:
logger.warning(f"[Transcript] Failed to download transcript: {e}")
return None
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript from bucket storage (e.g. after resume failure)."""
from backend.util.workspace_storage import get_workspace_storage
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
try:
await storage.delete(path)
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
except Exception as e:
logger.warning(f"[Transcript] Failed to delete transcript: {e}")

View File

@@ -245,16 +245,12 @@ async def _get_system_prompt_template(context: str) -> str:
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
async def _build_system_prompt(
user_id: str | None, has_conversation_history: bool = False
) -> tuple[str, Any]:
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
"""Build the full system prompt including business understanding if available.
Args:
user_id: The user ID for fetching business understanding.
has_conversation_history: Whether there's existing conversation history.
If True, we don't tell the model to greet/introduce (since they're
already in a conversation).
user_id: The user ID for fetching business understanding
If "default" and this is the user's first session, will use "onboarding" instead.
Returns:
Tuple of (compiled prompt string, business understanding object)
@@ -270,8 +266,6 @@ async def _build_system_prompt(
if understanding:
context = format_understanding_for_prompt(understanding)
elif has_conversation_history:
context = "No prior understanding saved yet. Continue the existing conversation naturally."
else:
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
@@ -380,6 +374,7 @@ async def stream_chat_completion(
Raises:
NotFoundError: If session_id is invalid
ValueError: If max_context_messages is exceeded
"""
completion_start = time.monotonic()
@@ -464,9 +459,8 @@ async def stream_chat_completion(
# Generate title for new sessions on first user message (non-blocking)
# Check: is_user_message, no title yet, and this is the first user message
user_messages = [m for m in session.messages if m.role == "user"]
first_user_msg = message or (user_messages[0].content if user_messages else None)
if is_user_message and first_user_msg and not session.title:
if is_user_message and message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
# First user message - generate title in background
import asyncio
@@ -474,7 +468,7 @@ async def stream_chat_completion(
# Capture only the values we need (not the session object) to avoid
# stale data issues when the main flow modifies the session
captured_session_id = session_id
captured_message = first_user_msg
captured_message = message
captured_user_id = user_id
async def _update_title():
@@ -1243,7 +1237,7 @@ async def _stream_chat_chunks(
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
logger.info(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; "
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
)

View File

@@ -1,4 +1,3 @@
import asyncio
import logging
from os import getenv
@@ -12,8 +11,6 @@ from .response_model import (
StreamTextDelta,
StreamToolOutputAvailable,
)
from .sdk import service as sdk_service
from .sdk.transcript import download_transcript
logger = logging.getLogger(__name__)
@@ -83,96 +80,3 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
session = await get_chat_session(session.session_id)
assert session, "Session not found"
assert session.usage, "Usage is empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
"""Test that the SDK --resume path captures and uses transcripts across turns.
Turn 1: Send a message containing a unique keyword.
Turn 2: Ask the model to recall that keyword — proving the transcript was
persisted and restored via --resume.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
from .config import ChatConfig
cfg = ChatConfig()
if not cfg.claude_agent_use_resume:
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---
keyword = "ZEPHYR42"
turn1_msg = (
f"Please remember this special keyword: {keyword}. "
"Just confirm you've noted it, keep your response brief."
)
turn1_text = ""
turn1_errors: list[str] = []
turn1_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
turn1_msg,
user_id=test_user_id,
):
if isinstance(chunk, StreamTextDelta):
turn1_text += chunk.delta
elif isinstance(chunk, StreamError):
turn1_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn1_ended = True
assert turn1_ended, "Turn 1 did not finish"
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
assert turn1_text, "Turn 1 produced no text"
# Wait for background upload task to complete (retry up to 5s)
transcript = None
for _ in range(10):
await asyncio.sleep(0.5)
transcript = await download_transcript(test_user_id, session.session_id)
if transcript:
break
assert transcript, (
"Transcript was not uploaded to bucket after turn 1 — "
"Stop hook may not have fired or transcript was too small"
)
logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")
# Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id)
assert session, "Session not found after turn 1"
# --- Turn 2: ask model to recall the keyword ---
turn2_msg = "What was the special keyword I asked you to remember?"
turn2_text = ""
turn2_errors: list[str] = []
turn2_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
turn2_msg,
user_id=test_user_id,
session=session,
):
if isinstance(chunk, StreamTextDelta):
turn2_text += chunk.delta
elif isinstance(chunk, StreamError):
turn2_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn2_ended = True
assert turn2_ended, "Turn 2 did not finish"
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
assert turn2_text, "Turn 2 produced no text"
assert keyword in turn2_text, (
f"Model did not recall keyword '{keyword}' in turn 2. "
f"Response: {turn2_text[:200]}"
)
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")

View File

@@ -814,28 +814,6 @@ async def get_active_task_for_session(
if task_user_id and user_id != task_user_id:
continue
# Auto-expire stale tasks that exceeded stream_timeout
created_at_str = meta.get("created_at", "")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
age_seconds = (
datetime.now(timezone.utc) - created_at
).total_seconds()
if age_seconds > config.stream_timeout:
logger.warning(
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
)
await mark_task_completed(task_id, "failed")
continue
except (ValueError, TypeError):
pass
logger.info(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)
# Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id)
last_id = "0-0"

View File

@@ -9,8 +9,6 @@ from backend.api.features.chat.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .check_operation_status import CheckOperationStatusTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
@@ -22,7 +20,6 @@ from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .search_docs import SearchDocsTool
from .web_fetch import WebFetchTool
from .workspace_files import (
DeleteWorkspaceFileTool,
ListWorkspaceFilesTool,
@@ -47,14 +44,8 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"view_agent_output": AgentOutputTool(),
"check_operation_status": CheckOperationStatusTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
# Feature request tools
"search_feature_requests": SearchFeatureRequestsTool(),
"create_feature_request": CreateFeatureRequestTool(),

View File

@@ -1,131 +0,0 @@
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
read-only, writable workspace only, clean env, no network.
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
available (e.g. macOS development).
"""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
BashExecResponse,
ErrorResponse,
ToolResponseBase,
)
from backend.api.features.chat.tools.sandbox import (
get_workspace_dir,
has_full_sandbox,
run_sandboxed,
)
logger = logging.getLogger(__name__)
class BashExecTool(BaseTool):
"""Execute Bash commands in a bubblewrap sandbox."""
@property
def name(self) -> str:
return "bash_exec"
@property
def description(self) -> str:
if not has_full_sandbox():
return (
"Bash execution is DISABLED — bubblewrap sandbox is not "
"available on this platform. Do not call this tool."
)
return (
"Execute a Bash command or script in a bubblewrap sandbox. "
"Full Bash scripting is supported (loops, conditionals, pipes, "
"functions, etc.). "
"The sandbox shares the same working directory as the SDK Read/Write "
"tools — files created by either are accessible to both. "
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
"visible read-only, the per-session workspace is the only writable "
"path, environment variables are wiped (no secrets), all network "
"access is blocked at the kernel level, and resource limits are "
"enforced (max 64 processes, 512MB memory, 50MB file size). "
"Application code, configs, and other directories are NOT accessible. "
"To fetch web content, use the web_fetch tool instead. "
"Execution is killed after the timeout (default 30s, max 120s). "
"Returns stdout and stderr. "
"Useful for file manipulation, data processing with Unix tools "
"(grep, awk, sed, jq, etc.), and running shell scripts."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "Bash command or script to execute.",
},
"timeout": {
"type": "integer",
"description": (
"Max execution time in seconds (default 30, max 120)."
),
"default": 30,
},
},
"required": ["command"],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
session_id = session.session_id if session else None
if not has_full_sandbox():
return ErrorResponse(
message="bash_exec requires bubblewrap sandbox (Linux only).",
error="sandbox_unavailable",
session_id=session_id,
)
command: str = (kwargs.get("command") or "").strip()
timeout: int = kwargs.get("timeout", 30)
if not command:
return ErrorResponse(
message="No command provided.",
error="empty_command",
session_id=session_id,
)
workspace = get_workspace_dir(session_id or "default")
stdout, stderr, exit_code, timed_out = await run_sandboxed(
command=["bash", "-c", command],
cwd=workspace,
timeout=timeout,
)
return BashExecResponse(
message=(
"Execution timed out"
if timed_out
else f"Command executed (exit {exit_code})"
),
stdout=stdout,
stderr=stderr,
exit_code=exit_code,
timed_out=timed_out,
session_id=session_id,
)

View File

@@ -1,127 +0,0 @@
"""CheckOperationStatusTool — query the status of a long-running operation."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
ErrorResponse,
ResponseType,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
class OperationStatusResponse(ToolResponseBase):
"""Response for check_operation_status tool."""
type: ResponseType = ResponseType.OPERATION_STATUS
task_id: str
operation_id: str
status: str # "running", "completed", "failed"
tool_name: str | None = None
message: str = ""
class CheckOperationStatusTool(BaseTool):
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
The CoPilot uses this tool to report back to the user whether an
operation that was started earlier has completed, failed, or is still
running.
"""
@property
def name(self) -> str:
return "check_operation_status"
@property
def description(self) -> str:
return (
"Check the current status of a long-running operation such as "
"create_agent or edit_agent. Accepts either an operation_id or "
"task_id from a previous operation_started response. "
"Returns the current status: running, completed, or failed."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"operation_id": {
"type": "string",
"description": (
"The operation_id from an operation_started response."
),
},
"task_id": {
"type": "string",
"description": (
"The task_id from an operation_started response. "
"Used as fallback if operation_id is not provided."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
from backend.api.features.chat import stream_registry
operation_id = (kwargs.get("operation_id") or "").strip()
task_id = (kwargs.get("task_id") or "").strip()
if not operation_id and not task_id:
return ErrorResponse(
message="Please provide an operation_id or task_id.",
error="missing_parameter",
)
task = None
if operation_id:
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None and task_id:
task = await stream_registry.get_task(task_id)
if task is None:
# Task not in Redis — it may have already expired (TTL).
# Check conversation history for the result instead.
return ErrorResponse(
message=(
"Operation not found — it may have already completed and "
"expired from the status tracker. Check the conversation "
"history for the result."
),
error="not_found",
)
status_messages = {
"running": (
f"The {task.tool_name or 'operation'} is still running. "
"Please wait for it to complete."
),
"completed": (
f"The {task.tool_name or 'operation'} has completed successfully."
),
"failed": f"The {task.tool_name or 'operation'} has failed.",
}
return OperationStatusResponse(
task_id=task.task_id,
operation_id=task.operation_id,
status=task.status,
tool_name=task.tool_name,
message=status_messages.get(task.status, f"Status: {task.status}"),
)

View File

@@ -146,7 +146,6 @@ class FindBlockTool(BaseTool):
id=block_id,
name=block.name,
description=block.description or "",
categories=[c.value for c in block.categories],
)
)

View File

@@ -41,12 +41,6 @@ class ResponseType(str, Enum):
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
# Web fetch
WEB_FETCH = "web_fetch"
# Code execution
BASH_EXEC = "bash_exec"
# Operation status check
OPERATION_STATUS = "operation_status"
# Feature request types
FEATURE_REQUEST_SEARCH = "feature_request_search"
FEATURE_REQUEST_CREATED = "feature_request_created"
@@ -344,19 +338,6 @@ class BlockInfoSummary(BaseModel):
id: str
name: str
description: str
categories: list[str]
input_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block inputs",
)
output_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block outputs",
)
required_inputs: list[BlockInputFieldInfo] = Field(
default_factory=list,
description="List of input fields for this block",
)
class BlockListResponse(ToolResponseBase):
@@ -366,10 +347,6 @@ class BlockListResponse(ToolResponseBase):
blocks: list[BlockInfoSummary]
count: int
query: str
usage_hint: str = Field(
default="To execute a block, call run_block with block_id set to the block's "
"'id' field and input_data containing the fields listed in required_inputs."
)
class BlockDetails(BaseModel):
@@ -458,27 +435,6 @@ class AsyncProcessingResponse(ToolResponseBase):
task_id: str | None = None
class WebFetchResponse(ToolResponseBase):
"""Response for web_fetch tool."""
type: ResponseType = ResponseType.WEB_FETCH
url: str
status_code: int
content_type: str
content: str
truncated: bool = False
class BashExecResponse(ToolResponseBase):
"""Response for bash_exec tool."""
type: ResponseType = ResponseType.BASH_EXEC
stdout: str
stderr: str
exit_code: int
timed_out: bool = False
# Feature request models
class FeatureRequestInfo(BaseModel):
"""Information about a feature request issue."""

View File

@@ -1,265 +0,0 @@
"""Sandbox execution utilities for code execution tools.
Provides filesystem + network isolated command execution using **bubblewrap**
(``bwrap``): whitelist-only filesystem (only system dirs visible read-only),
writable workspace only, clean environment, network blocked.
Tools that call :func:`run_sandboxed` must first check :func:`has_full_sandbox`
and refuse to run if bubblewrap is not available.
"""
import asyncio
import logging
import os
import platform
import shutil
logger = logging.getLogger(__name__)
_DEFAULT_TIMEOUT = 30
_MAX_TIMEOUT = 120
# ---------------------------------------------------------------------------
# Sandbox capability detection (cached at first call)
# ---------------------------------------------------------------------------
_BWRAP_AVAILABLE: bool | None = None
def has_full_sandbox() -> bool:
"""Return True if bubblewrap is available (filesystem + network isolation).
On non-Linux platforms (macOS), always returns False.
"""
global _BWRAP_AVAILABLE
if _BWRAP_AVAILABLE is None:
_BWRAP_AVAILABLE = (
platform.system() == "Linux" and shutil.which("bwrap") is not None
)
return _BWRAP_AVAILABLE
WORKSPACE_PREFIX = "/tmp/copilot-"
def make_session_path(session_id: str) -> str:
"""Build a sanitized, session-specific path under :data:`WORKSPACE_PREFIX`.
Shared by both the SDK working-directory setup and the sandbox tools so
they always resolve to the same directory for a given session.
Steps:
1. Strip all characters except ``[A-Za-z0-9-]``.
2. Construct ``/tmp/copilot-<safe_id>``.
3. Validate via ``os.path.normpath`` + ``startswith`` (CodeQL-recognised
sanitizer) to prevent path traversal.
Raises:
ValueError: If the resulting path escapes the prefix.
"""
import re
safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
if not safe_id:
safe_id = "default"
path = os.path.normpath(f"{WORKSPACE_PREFIX}{safe_id}")
if not path.startswith(WORKSPACE_PREFIX):
raise ValueError(f"Session path escaped prefix: {path}")
return path
def get_workspace_dir(session_id: str) -> str:
"""Get or create the workspace directory for a session.
Uses :func:`make_session_path` — the same path the SDK uses — so that
bash_exec shares the workspace with the SDK file tools.
"""
workspace = make_session_path(session_id)
os.makedirs(workspace, exist_ok=True)
return workspace
# ---------------------------------------------------------------------------
# Bubblewrap command builder
# ---------------------------------------------------------------------------
# System directories mounted read-only inside the sandbox.
# ONLY these are visible — /app, /root, /home, /opt, /var etc. are NOT accessible.
_SYSTEM_RO_BINDS = [
"/usr", # binaries, libraries, Python interpreter
"/etc", # system config: ld.so, locale, passwd, alternatives
]
# Compat paths: symlinks to /usr/* on modern Debian, real dirs on older systems.
# On Debian 13 these are symlinks (e.g. /bin -> usr/bin). bwrap --ro-bind
# can't create a symlink target, so we detect and use --symlink instead.
# /lib64 is critical: the ELF dynamic linker lives at /lib64/ld-linux-x86-64.so.2.
_COMPAT_PATHS = [
("/bin", "usr/bin"), # -> /usr/bin on Debian 13
("/sbin", "usr/sbin"), # -> /usr/sbin on Debian 13
("/lib", "usr/lib"), # -> /usr/lib on Debian 13
("/lib64", "usr/lib64"), # 64-bit libraries / ELF interpreter
]
# Resource limits to prevent fork bombs, memory exhaustion, and disk abuse.
# Applied via ulimit inside the sandbox before exec'ing the user command.
_RESOURCE_LIMITS = (
"ulimit -u 64" # max 64 processes (prevents fork bombs)
" -v 524288" # 512 MB virtual memory
" -f 51200" # 50 MB max file size (1024-byte blocks)
" -n 256" # 256 open file descriptors
" 2>/dev/null"
)
def _build_bwrap_command(
command: list[str], cwd: str, env: dict[str, str]
) -> list[str]:
"""Build a bubblewrap command with strict filesystem + network isolation.
Security model:
- **Whitelist-only filesystem**: only system directories (``/usr``, ``/etc``,
``/bin``, ``/lib``) are mounted read-only. Application code (``/app``),
home directories, ``/var``, ``/opt``, etc. are NOT accessible at all.
- **Writable workspace only**: the per-session workspace is the sole
writable path.
- **Clean environment**: ``--clearenv`` wipes all inherited env vars.
Only the explicitly-passed safe env vars are set inside the sandbox.
- **Network isolation**: ``--unshare-net`` blocks all network access.
- **Resource limits**: ulimit caps on processes (64), memory (512MB),
file size (50MB), and open FDs (256) to prevent fork bombs and abuse.
- **New session**: prevents terminal control escape.
- **Die with parent**: prevents orphaned sandbox processes.
"""
cmd = [
"bwrap",
# Create a new user namespace so bwrap can set up sandboxing
# inside unprivileged Docker containers (no CAP_SYS_ADMIN needed).
"--unshare-user",
# Wipe all inherited environment variables (API keys, secrets, etc.)
"--clearenv",
]
# Set only the safe env vars inside the sandbox
for key, value in env.items():
cmd.extend(["--setenv", key, value])
# System directories: read-only
for path in _SYSTEM_RO_BINDS:
cmd.extend(["--ro-bind", path, path])
# Compat paths: use --symlink when host path is a symlink (Debian 13),
# --ro-bind when it's a real directory (older distros).
for path, symlink_target in _COMPAT_PATHS:
if os.path.islink(path):
cmd.extend(["--symlink", symlink_target, path])
elif os.path.exists(path):
cmd.extend(["--ro-bind", path, path])
# Wrap the user command with resource limits:
# sh -c 'ulimit ...; exec "$@"' -- <original command>
# `exec "$@"` replaces the shell so there's no extra process overhead,
# and properly handles arguments with spaces.
limited_command = [
"sh",
"-c",
f'{_RESOURCE_LIMITS}; exec "$@"',
"--",
*command,
]
cmd.extend(
[
# Fresh virtual filesystems
"--dev",
"/dev",
"--proc",
"/proc",
"--tmpfs",
"/tmp",
# Workspace bind AFTER --tmpfs /tmp so it's visible through the tmpfs.
# (workspace lives under /tmp/copilot-<session>)
"--bind",
cwd,
cwd,
# Isolation
"--unshare-net",
"--die-with-parent",
"--new-session",
"--chdir",
cwd,
"--",
*limited_command,
]
)
return cmd
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def run_sandboxed(
command: list[str],
cwd: str,
timeout: int = _DEFAULT_TIMEOUT,
env: dict[str, str] | None = None,
) -> tuple[str, str, int, bool]:
"""Run a command inside a bubblewrap sandbox.
Callers **must** check :func:`has_full_sandbox` before calling this
function. If bubblewrap is not available, this function raises
:class:`RuntimeError` rather than running unsandboxed.
Returns:
(stdout, stderr, exit_code, timed_out)
"""
if not has_full_sandbox():
raise RuntimeError(
"run_sandboxed() requires bubblewrap but bwrap is not available. "
"Callers must check has_full_sandbox() before calling this function."
)
timeout = min(max(timeout, 1), _MAX_TIMEOUT)
safe_env = {
"PATH": "/usr/local/bin:/usr/bin:/bin",
"HOME": cwd,
"TMPDIR": cwd,
"LANG": "en_US.UTF-8",
"PYTHONDONTWRITEBYTECODE": "1",
"PYTHONIOENCODING": "utf-8",
}
if env:
safe_env.update(env)
full_command = _build_bwrap_command(command, cwd, safe_env)
try:
proc = await asyncio.create_subprocess_exec(
*full_command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
env=safe_env,
)
try:
stdout_bytes, stderr_bytes = await asyncio.wait_for(
proc.communicate(), timeout=timeout
)
stdout = stdout_bytes.decode("utf-8", errors="replace")
stderr = stderr_bytes.decode("utf-8", errors="replace")
return stdout, stderr, proc.returncode or 0, False
except asyncio.TimeoutError:
proc.kill()
await proc.communicate()
return "", f"Execution timed out after {timeout}s", -1, True
except RuntimeError:
raise
except Exception as e:
return "", f"Sandbox error: {e}", -1, False

View File

@@ -15,7 +15,6 @@ from backend.data.model import (
OAuth2Credentials,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__)
@@ -360,7 +359,7 @@ async def match_user_credentials_to_graph(
_,
_,
) in aggregated_creds.items():
# Find first matching credential by provider, type, scopes, and host/URL
# Find first matching credential by provider, type, and scopes
matching_cred = next(
(
cred
@@ -375,10 +374,6 @@ async def match_user_credentials_to_graph(
cred.type != "host_scoped"
or _credential_is_for_host(cred, credential_requirements)
)
and (
cred.provider != ProviderName.MCP
or _credential_is_for_mcp_server(cred, credential_requirements)
)
),
None,
)
@@ -449,22 +444,6 @@ def _credential_is_for_host(
return credential.matches_url(list(requirements.discriminator_values)[0])
def _credential_is_for_mcp_server(
credential: Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""Check if an MCP OAuth credential matches the required server URL."""
if not requirements.discriminator_values:
return True
server_url = (
credential.metadata.get("mcp_server_url")
if isinstance(credential, OAuth2Credentials)
else None
)
return server_url in requirements.discriminator_values if server_url else False
async def check_user_has_required_credentials(
user_id: str,
required_credentials: list[CredentialsMetaInput],

View File

@@ -1,151 +0,0 @@
"""Web fetch tool — safely retrieve public web page content."""
import logging
from typing import Any
import aiohttp
import html2text
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
ErrorResponse,
ToolResponseBase,
WebFetchResponse,
)
from backend.util.request import Requests
logger = logging.getLogger(__name__)
# Limits
_MAX_CONTENT_BYTES = 102_400 # 100 KB download cap
_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=15)
# Content types we'll read as text
_TEXT_CONTENT_TYPES = {
"text/html",
"text/plain",
"text/xml",
"text/csv",
"text/markdown",
"application/json",
"application/xml",
"application/xhtml+xml",
"application/rss+xml",
"application/atom+xml",
}
def _is_text_content(content_type: str) -> bool:
base = content_type.split(";")[0].strip().lower()
return base in _TEXT_CONTENT_TYPES or base.startswith("text/")
def _html_to_text(html: str) -> str:
h = html2text.HTML2Text()
h.ignore_links = False
h.ignore_images = True
h.body_width = 0
return h.handle(html)
class WebFetchTool(BaseTool):
"""Safely fetch content from a public URL using SSRF-protected HTTP."""
@property
def name(self) -> str:
return "web_fetch"
@property
def description(self) -> str:
return (
"Fetch the content of a public web page by URL. "
"Returns readable text extracted from HTML by default. "
"Useful for reading documentation, articles, and API responses. "
"Only supports HTTP/HTTPS GET requests to public URLs "
"(private/internal network addresses are blocked)."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The public HTTP/HTTPS URL to fetch.",
},
"extract_text": {
"type": "boolean",
"description": (
"If true (default), extract readable text from HTML. "
"If false, return raw content."
),
"default": True,
},
},
"required": ["url"],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
url: str = (kwargs.get("url") or "").strip()
extract_text: bool = kwargs.get("extract_text", True)
session_id = session.session_id if session else None
if not url:
return ErrorResponse(
message="Please provide a URL to fetch.",
error="missing_url",
session_id=session_id,
)
try:
client = Requests(raise_for_status=False, retry_max_attempts=1)
response = await client.get(url, timeout=_REQUEST_TIMEOUT)
except ValueError as e:
# validate_url raises ValueError for SSRF / blocked IPs
return ErrorResponse(
message=f"URL blocked: {e}",
error="url_blocked",
session_id=session_id,
)
except Exception as e:
logger.warning(f"[web_fetch] Request failed for {url}: {e}")
return ErrorResponse(
message=f"Failed to fetch URL: {e}",
error="fetch_failed",
session_id=session_id,
)
content_type = response.headers.get("content-type", "")
if not _is_text_content(content_type):
return ErrorResponse(
message=f"Non-text content type: {content_type.split(';')[0]}",
error="unsupported_content_type",
session_id=session_id,
)
raw = response.content[:_MAX_CONTENT_BYTES]
text = raw.decode("utf-8", errors="replace")
if extract_text and "html" in content_type.lower():
text = _html_to_text(text)
return WebFetchResponse(
message=f"Fetched {url}",
url=response.url,
status_code=response.status,
content_type=content_type.split(";")[0].strip(),
content=text,
truncated=False,
session_id=session_id,
)

View File

@@ -88,9 +88,7 @@ class ListWorkspaceFilesTool(BaseTool):
@property
def description(self) -> str:
return (
"List files in the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read/Glob tools instead. "
"List files in the user's workspace. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
@@ -206,9 +204,7 @@ class ReadWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Read a file from the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read tool instead. "
"Read a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
@@ -382,9 +378,7 @@ class WriteWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Write or create a file in the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Write tool instead. "
"Write or create a file in the user's workspace. "
"Provide the content as a base64-encoded string. "
f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. "
@@ -529,7 +523,7 @@ class DeleteWorkspaceFileTool(BaseTool):
@property
def description(self) -> str:
return (
"Delete a file from the user's persistent workspace (cloud storage). "
"Delete a file from the user's workspace. "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."

View File

@@ -1,7 +1,7 @@
import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Annotated, Any, List, Literal
from typing import TYPE_CHECKING, Annotated, List, Literal
from autogpt_libs.auth import get_user_id
from fastapi import (
@@ -14,7 +14,7 @@ from fastapi import (
Security,
status,
)
from pydantic import BaseModel, Field, SecretStr, model_validator
from pydantic import BaseModel, Field, SecretStr
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
from backend.api.features.library.db import set_preset_webhook, update_preset
@@ -39,11 +39,7 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
from backend.data.user import get_user_integrations
from backend.executor.utils import add_graph_execution
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
from backend.integrations.credentials_store import provider_matches
from backend.integrations.creds_manager import (
IntegrationCredentialsManager,
create_mcp_oauth_handler,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import get_webhook_manager
@@ -106,37 +102,9 @@ class CredentialsMetaResponse(BaseModel):
scopes: list[str] | None
username: str | None
host: str | None = Field(
default=None,
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
default=None, description="Host pattern for host-scoped credentials"
)
@model_validator(mode="before")
@classmethod
def _normalize_provider(cls, data: Any) -> Any:
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
if isinstance(data, dict):
prov = data.get("provider", "")
if isinstance(prov, str) and prov.startswith("ProviderName."):
member = prov.removeprefix("ProviderName.")
try:
data = {**data, "provider": ProviderName[member].value}
except KeyError:
pass
return data
@staticmethod
def get_host(cred: Credentials) -> str | None:
"""Extract host from credential: HostScoped host or MCP server URL."""
if isinstance(cred, HostScopedCredentials):
return cred.host
if isinstance(cred, OAuth2Credentials) and cred.provider in (
ProviderName.MCP,
ProviderName.MCP.value,
"ProviderName.MCP",
):
return (cred.metadata or {}).get("mcp_server_url")
return None
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
async def callback(
@@ -211,7 +179,9 @@ async def callback(
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=(CredentialsMetaResponse.get_host(credentials)),
host=(
credentials.host if isinstance(credentials, HostScopedCredentials) else None
),
)
@@ -229,7 +199,7 @@ async def list_credentials(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -252,7 +222,7 @@ async def list_credentials_by_provider(
title=cred.title,
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
host=CredentialsMetaResponse.get_host(cred),
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
)
for cred in credentials
]
@@ -352,11 +322,7 @@ async def delete_credentials(
tokens_revoked = None
if isinstance(creds, OAuth2Credentials):
if provider_matches(provider.value, ProviderName.MCP.value):
# MCP uses dynamic per-server OAuth — create handler from metadata
handler = create_mcp_oauth_handler(creds)
else:
handler = _get_provider_oauth_handler(request, provider)
handler = _get_provider_oauth_handler(request, provider)
tokens_revoked = await handler.revoke_tokens(creds)
return CredentialsDeletionResponse(revoked=tokens_revoked)

View File

@@ -1,404 +0,0 @@
"""
MCP (Model Context Protocol) API routes.
Provides endpoints for MCP tool discovery and OAuth authentication so the
frontend can list available tools on an MCP server before placing a block.
"""
import logging
from typing import Annotated, Any
from urllib.parse import urlparse
import fastapi
from autogpt_libs.auth import get_user_id
from fastapi import Security
from pydantic import BaseModel, Field
from backend.api.features.integrations.router import CredentialsMetaResponse
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
settings = Settings()
router = fastapi.APIRouter(tags=["mcp"])
creds_manager = IntegrationCredentialsManager()
# ====================== Tool Discovery ====================== #
class DiscoverToolsRequest(BaseModel):
"""Request to discover tools on an MCP server."""
server_url: str = Field(description="URL of the MCP server")
auth_token: str | None = Field(
default=None,
description="Optional Bearer token for authenticated MCP servers",
)
class MCPToolResponse(BaseModel):
"""A single MCP tool returned by discovery."""
name: str
description: str
input_schema: dict[str, Any]
class DiscoverToolsResponse(BaseModel):
"""Response containing the list of tools available on an MCP server."""
tools: list[MCPToolResponse]
server_name: str | None = None
protocol_version: str | None = None
@router.post(
"/discover-tools",
summary="Discover available tools on an MCP server",
response_model=DiscoverToolsResponse,
)
async def discover_tools(
request: DiscoverToolsRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> DiscoverToolsResponse:
"""
Connect to an MCP server and return its available tools.
If the user has a stored MCP credential for this server URL, it will be
used automatically — no need to pass an explicit auth token.
"""
auth_token = request.auth_token
# Auto-use stored MCP credential when no explicit token is provided.
if not auth_token:
mcp_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
# Find the freshest credential for this server URL
best_cred: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
):
if best_cred is None or (
(cred.access_token_expires_at or 0)
> (best_cred.access_token_expires_at or 0)
):
best_cred = cred
if best_cred:
# Refresh the token if expired before using it
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
logger.info(
f"Using MCP credential {best_cred.id} for {request.server_url}, "
f"expires_at={best_cred.access_token_expires_at}"
)
auth_token = best_cred.access_token.get_secret_value()
client = MCPClient(request.server_url, auth_token=auth_token)
try:
init_result = await client.initialize()
tools = await client.list_tools()
except HTTPClientError as e:
if e.status_code in (401, 403):
raise fastapi.HTTPException(
status_code=401,
detail="This MCP server requires authentication. "
"Please provide a valid auth token.",
)
raise fastapi.HTTPException(status_code=502, detail=str(e))
except MCPClientError as e:
raise fastapi.HTTPException(status_code=502, detail=str(e))
except Exception as e:
raise fastapi.HTTPException(
status_code=502,
detail=f"Failed to connect to MCP server: {e}",
)
return DiscoverToolsResponse(
tools=[
MCPToolResponse(
name=t.name,
description=t.description,
input_schema=t.input_schema,
)
for t in tools
],
server_name=(
init_result.get("serverInfo", {}).get("name")
or urlparse(request.server_url).hostname
or "MCP"
),
protocol_version=init_result.get("protocolVersion"),
)
# ======================== OAuth Flow ======================== #
class MCPOAuthLoginRequest(BaseModel):
"""Request to start an OAuth flow for an MCP server."""
server_url: str = Field(description="URL of the MCP server that requires OAuth")
class MCPOAuthLoginResponse(BaseModel):
"""Response with the OAuth login URL for the user to authenticate."""
login_url: str
state_token: str
@router.post(
"/oauth/login",
summary="Initiate OAuth login for an MCP server",
)
async def mcp_oauth_login(
request: MCPOAuthLoginRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> MCPOAuthLoginResponse:
"""
Discover OAuth metadata from the MCP server and return a login URL.
1. Discovers the protected-resource metadata (RFC 9728)
2. Fetches the authorization server metadata (RFC 8414)
3. Performs Dynamic Client Registration (RFC 7591) if available
4. Returns the authorization URL for the frontend to open in a popup
"""
client = MCPClient(request.server_url)
# Step 1: Discover protected-resource metadata (RFC 9728)
protected_resource = await client.discover_auth()
metadata: dict[str, Any] | None = None
if protected_resource and protected_resource.get("authorization_servers"):
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", request.server_url)
# Step 2a: Discover auth-server metadata (RFC 8414)
metadata = await client.discover_auth_server_metadata(auth_server_url)
else:
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
# and serve OAuth metadata directly without protected-resource metadata.
# Don't assume a resource_url — omitting it lets the auth server choose
# the correct audience for the token (RFC 8707 resource is optional).
resource_url = None
metadata = await client.discover_auth_server_metadata(request.server_url)
if (
not metadata
or "authorization_endpoint" not in metadata
or "token_endpoint" not in metadata
):
raise fastapi.HTTPException(
status_code=400,
detail="This MCP server does not advertise OAuth support. "
"You may need to provide an auth token manually.",
)
authorize_url = metadata["authorization_endpoint"]
token_url = metadata["token_endpoint"]
registration_endpoint = metadata.get("registration_endpoint")
revoke_url = metadata.get("revocation_endpoint")
# Step 3: Dynamic Client Registration (RFC 7591) if available
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
client_id = ""
client_secret = ""
if registration_endpoint:
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, request.server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
if not client_id:
client_id = "autogpt-platform"
# Step 4: Store state token with OAuth metadata for the callback
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
"scopes_supported", []
)
state_token, code_challenge = await creds_manager.store.store_state_token(
user_id,
ProviderName.MCP.value,
scopes,
state_metadata={
"authorize_url": authorize_url,
"token_url": token_url,
"revoke_url": revoke_url,
"resource_url": resource_url,
"server_url": request.server_url,
"client_id": client_id,
"client_secret": client_secret,
},
)
# Step 5: Build and return the login URL
handler = MCPOAuthHandler(
client_id=client_id,
client_secret=client_secret,
redirect_uri=redirect_uri,
authorize_url=authorize_url,
token_url=token_url,
resource_url=resource_url,
)
login_url = handler.get_login_url(
scopes, state_token, code_challenge=code_challenge
)
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
class MCPOAuthCallbackRequest(BaseModel):
"""Request to exchange an OAuth code for tokens."""
code: str = Field(description="Authorization code from OAuth callback")
state_token: str = Field(description="State token for CSRF verification")
class MCPOAuthCallbackResponse(BaseModel):
"""Response after successfully storing OAuth credentials."""
credential_id: str
@router.post(
"/oauth/callback",
summary="Exchange OAuth code for MCP tokens",
)
async def mcp_oauth_callback(
request: MCPOAuthCallbackRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
"""
Exchange the authorization code for tokens and store the credential.
The frontend calls this after receiving the OAuth code from the popup.
On success, subsequent ``/discover-tools`` calls for the same server URL
will automatically use the stored credential.
"""
valid_state = await creds_manager.store.verify_state_token(
user_id, request.state_token, ProviderName.MCP.value
)
if not valid_state:
raise fastapi.HTTPException(
status_code=400,
detail="Invalid or expired state token.",
)
meta = valid_state.state_metadata
frontend_base_url = settings.config.frontend_base_url
if not frontend_base_url:
raise fastapi.HTTPException(
status_code=500,
detail="Frontend base URL is not configured.",
)
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
handler = MCPOAuthHandler(
client_id=meta["client_id"],
client_secret=meta.get("client_secret", ""),
redirect_uri=redirect_uri,
authorize_url=meta["authorize_url"],
token_url=meta["token_url"],
revoke_url=meta.get("revoke_url"),
resource_url=meta.get("resource_url"),
)
try:
credentials = await handler.exchange_code_for_tokens(
request.code, valid_state.scopes, valid_state.code_verifier
)
except Exception as e:
raise fastapi.HTTPException(
status_code=400,
detail=f"OAuth token exchange failed: {e}",
)
# Enrich credential metadata for future lookup and token refresh
if credentials.metadata is None:
credentials.metadata = {}
credentials.metadata["mcp_server_url"] = meta["server_url"]
credentials.metadata["mcp_client_id"] = meta["client_id"]
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
credentials.metadata["mcp_token_url"] = meta["token_url"]
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
credentials.title = f"MCP: {hostname}"
# Remove old MCP credentials for the same server to prevent stale token buildup.
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
for old in old_creds:
if (
isinstance(old, OAuth2Credentials)
and (old.metadata or {}).get("mcp_server_url") == meta["server_url"]
):
await creds_manager.store.delete_creds_by_id(user_id, old.id)
logger.info(
f"Removed old MCP credential {old.id} for {meta['server_url']}"
)
except Exception:
logger.debug("Could not clean up old MCP credentials", exc_info=True)
await creds_manager.create(user_id, credentials)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=credentials.metadata.get("mcp_server_url"),
)
# ======================== Helpers ======================== #
async def _register_mcp_client(
registration_endpoint: str,
redirect_uri: str,
server_url: str,
) -> dict[str, Any] | None:
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
try:
response = await Requests(raise_for_status=True).post(
registration_endpoint,
json={
"client_name": "AutoGPT Platform",
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
},
)
data = response.json()
if isinstance(data, dict) and "client_id" in data:
return data
return None
except Exception as e:
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
return None

View File

@@ -1,436 +0,0 @@
"""Tests for MCP API routes.
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
"""
from unittest.mock import AsyncMock, patch
import fastapi
import httpx
import pytest
import pytest_asyncio
from autogpt_libs.auth import get_user_id
from backend.api.features.mcp.routes import router
from backend.blocks.mcp.client import MCPClientError, MCPTool
from backend.util.request import HTTPClientError
app = fastapi.FastAPI()
app.include_router(router)
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
@pytest_asyncio.fixture(scope="module")
async def client():
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
yield c
class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_success(self, client):
mock_tools = [
MCPTool(
name="get_weather",
description="Get weather for a city",
input_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
),
MCPTool(
name="add_numbers",
description="Add two numbers",
input_schema={
"type": "object",
"properties": {
"a": {"type": "number"},
"b": {"type": "number"},
},
},
),
]
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={
"protocolVersion": "2025-03-26",
"serverInfo": {"name": "test-server"},
}
)
instance.list_tools = AsyncMock(return_value=mock_tools)
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert len(data["tools"]) == 2
assert data["tools"][0]["name"] == "get_weather"
assert data["tools"][1]["name"] == "add_numbers"
assert data["server_name"] == "test-server"
assert data["protocol_version"] == "2025-03-26"
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_with_auth_token(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={
"server_url": "https://mcp.example.com/mcp",
"auth_token": "my-secret-token",
},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="my-secret-token",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auto_uses_stored_credential(self, client):
"""When no explicit token is given, stored MCP credentials are used."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
stored_cred = OAuth2Credentials(
provider="mcp",
title="MCP: example.com",
access_token=SecretStr("stored-token-123"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
)
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
)
instance.list_tools = AsyncMock(return_value=[])
response = await client.post(
"/discover-tools",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
MockClient.assert_called_once_with(
"https://mcp.example.com/mcp",
auth_token="stored-token-123",
)
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_mcp_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=MCPClientError("Connection refused")
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://bad-server.example.com/mcp"},
)
assert response.status_code == 502
assert "Connection refused" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_generic_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
response = await client.post(
"/discover-tools",
json={"server_url": "https://timeout.example.com/mcp"},
)
assert response.status_code == 502
assert "Failed to connect" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auth_required(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_forbidden(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
)
response = await client.post(
"/discover-tools",
json={"server_url": "https://auth-server.example.com/mcp"},
)
assert response.status_code == 401
assert "requires authentication" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_missing_url(self, client):
response = await client.post("/discover-tools", json={})
assert response.status_code == 422
class TestOAuthLogin:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_success(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch(
"backend.api.features.mcp.routes._register_mcp_client"
) as mock_register,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.sentry.io"],
"resource": "https://mcp.sentry.dev/mcp",
"scopes_supported": ["openid"],
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.sentry.io/authorize",
"token_endpoint": "https://auth.sentry.io/token",
"registration_endpoint": "https://auth.sentry.io/register",
}
)
mock_register.return_value = {
"client_id": "registered-client-id",
"client_secret": "registered-secret",
}
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-token-123", "code-challenge-abc")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.sentry.dev/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "login_url" in data
assert data["state_token"] == "state-token-123"
assert "auth.sentry.io/authorize" in data["login_url"]
assert "registered-client-id" in data["login_url"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_no_oauth_support(self, client):
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
instance = MockClient.return_value
instance.discover_auth = AsyncMock(return_value=None)
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
response = await client.post(
"/oauth/login",
json={"server_url": "https://simple-server.example.com/mcp"},
)
assert response.status_code == 400
assert "does not advertise OAuth" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_fallback_to_public_client(self, client):
"""When DCR is unavailable, falls back to default public client ID."""
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
):
instance = MockClient.return_value
instance.discover_auth = AsyncMock(
return_value={
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
)
instance.discover_auth_server_metadata = AsyncMock(
return_value={
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
# No registration_endpoint
}
)
mock_cm.store.store_state_token = AsyncMock(
return_value=("state-abc", "challenge-xyz")
)
mock_settings.config.frontend_base_url = "http://localhost:3000"
response = await client.post(
"/oauth/login",
json={"server_url": "https://mcp.example.com/mcp"},
)
assert response.status_code == 200
data = response.json()
assert "autogpt-platform" in data["login_url"]
class TestOAuthCallback:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_success(self, client):
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
mock_creds = OAuth2Credentials(
provider="mcp",
title=None,
access_token=SecretStr("access-token-xyz"),
refresh_token=None,
access_token_expires_at=None,
refresh_token_expires_at=None,
scopes=[],
metadata={
"mcp_token_url": "https://auth.sentry.io/token",
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
},
)
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
# Mock state verification
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.sentry.io/authorize",
"token_url": "https://auth.sentry.io/token",
"client_id": "test-client-id",
"client_secret": "test-secret",
"server_url": "https://mcp.sentry.dev/mcp",
}
mock_state.scopes = ["openid"]
mock_state.code_verifier = "verifier-123"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
mock_cm.create = AsyncMock()
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
return_value=mock_creds
)
# Mock old credential cleanup
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
response = await client.post(
"/oauth/callback",
json={"code": "auth-code-abc", "state_token": "state-token-123"},
)
assert response.status_code == 200
data = response.json()
assert "id" in data
assert data["provider"] == "mcp"
assert data["type"] == "oauth2"
mock_cm.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_invalid_state(self, client):
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
response = await client.post(
"/oauth/callback",
json={"code": "auth-code", "state_token": "bad-state"},
)
assert response.status_code == 400
assert "Invalid or expired" in response.json()["detail"]
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_token_exchange_fails(self, client):
with (
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
patch("backend.api.features.mcp.routes.settings") as mock_settings,
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
):
mock_settings.config.frontend_base_url = "http://localhost:3000"
mock_state = AsyncMock()
mock_state.state_metadata = {
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
"client_id": "cid",
"server_url": "https://mcp.example.com/mcp",
}
mock_state.scopes = []
mock_state.code_verifier = "v"
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
handler_instance = MockHandler.return_value
handler_instance.exchange_code_for_tokens = AsyncMock(
side_effect=RuntimeError("Token exchange failed")
)
response = await client.post(
"/oauth/callback",
json={"code": "bad-code", "state_token": "state"},
)
assert response.status_code == 400
assert "token exchange failed" in response.json()["detail"].lower()

View File

@@ -26,7 +26,6 @@ import backend.api.features.executions.review.routes
import backend.api.features.library.db
import backend.api.features.library.model
import backend.api.features.library.routes
import backend.api.features.mcp.routes as mcp_routes
import backend.api.features.oauth
import backend.api.features.otto.routes
import backend.api.features.postmark.postmark
@@ -344,11 +343,6 @@ app.include_router(
tags=["workspace"],
prefix="/api/workspace",
)
app.include_router(
mcp_routes.router,
tags=["v2", "mcp"],
prefix="/api/mcp",
)
app.include_router(
backend.api.features.oauth.router,
tags=["oauth"],

View File

@@ -64,7 +64,6 @@ class BlockType(Enum):
AI = "AI"
AYRSHARE = "Ayrshare"
HUMAN_IN_THE_LOOP = "Human In The Loop"
MCP_TOOL = "MCP Tool"
class BlockCategory(Enum):

View File

@@ -126,7 +126,6 @@ class PrintToConsoleBlock(Block):
output_schema=PrintToConsoleBlock.Output,
test_input={"text": "Hello, World!"},
is_sensitive_action=True,
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
test_output=[
("output", "Hello, World!"),
("status", "printed"),

View File

@@ -1,300 +0,0 @@
"""
MCP (Model Context Protocol) Tool Block.
A single dynamic block that can connect to any MCP server, discover available tools,
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
dropdown and the input/output schema adapts dynamically.
"""
import json
import logging
from typing import Any, Literal
from pydantic import SecretStr
from backend.blocks._base import (
Block,
BlockCategory,
BlockSchemaInput,
BlockSchemaOutput,
BlockType,
)
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.data.block import BlockInput, BlockOutput
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
OAuth2Credentials,
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.json import validate_with_jsonschema
logger = logging.getLogger(__name__)
TEST_CREDENTIALS = OAuth2Credentials(
id="test-mcp-cred",
provider="mcp",
access_token=SecretStr("mock-mcp-token"),
refresh_token=SecretStr("mock-refresh"),
scopes=[],
title="Mock MCP credential",
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}
MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]]
class MCPToolBlock(Block):
"""
A block that connects to an MCP server, lets the user pick a tool,
and executes it with dynamic input/output schema.
The flow:
1. User provides an MCP server URL (and optional credentials)
2. Frontend calls the backend to get tool list from that URL
3. User selects a tool from a dropdown (available_tools)
4. The block's input schema updates to reflect the selected tool's parameters
5. On execution, the block calls the MCP server to run the tool
"""
class Input(BlockSchemaInput):
server_url: str = SchemaField(
description="URL of the MCP server (Streamable HTTP endpoint)",
placeholder="https://mcp.example.com/mcp",
)
credentials: MCPCredentials = CredentialsField(
discriminator="server_url",
description="MCP server OAuth credentials",
default={},
)
selected_tool: str = SchemaField(
description="The MCP tool to execute",
placeholder="Select a tool",
default="",
)
tool_input_schema: dict[str, Any] = SchemaField(
description="JSON Schema for the selected tool's input parameters. "
"Populated automatically when a tool is selected.",
default={},
hidden=True,
)
tool_arguments: dict[str, Any] = SchemaField(
description="Arguments to pass to the selected MCP tool. "
"The fields here are defined by the tool's input schema.",
default={},
)
@classmethod
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
"""Return the tool's input schema so the builder UI renders dynamic fields."""
return data.get("tool_input_schema", {})
@classmethod
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
"""Return the current tool_arguments as defaults for the dynamic fields."""
return data.get("tool_arguments", {})
@classmethod
def get_missing_input(cls, data: BlockInput) -> set[str]:
"""Check which required tool arguments are missing."""
required_fields = cls.get_input_schema(data).get("required", [])
tool_arguments = data.get("tool_arguments", {})
return set(required_fields) - set(tool_arguments)
@classmethod
def get_mismatch_error(cls, data: BlockInput) -> str | None:
"""Validate tool_arguments against the tool's input schema."""
tool_schema = cls.get_input_schema(data)
if not tool_schema:
return None
tool_arguments = data.get("tool_arguments", {})
return validate_with_jsonschema(tool_schema, tool_arguments)
class Output(BlockSchemaOutput):
result: Any = SchemaField(description="The result returned by the MCP tool")
error: str = SchemaField(description="Error message if the tool call failed")
def __init__(self):
super().__init__(
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
description="Connect to any MCP server and execute its tools. "
"Provide a server URL, select a tool, and pass arguments dynamically.",
categories={BlockCategory.DEVELOPER_TOOLS},
input_schema=MCPToolBlock.Input,
output_schema=MCPToolBlock.Output,
block_type=BlockType.MCP_TOOL,
test_credentials=TEST_CREDENTIALS,
test_input={
"server_url": "https://mcp.example.com/mcp",
"credentials": TEST_CREDENTIALS_INPUT,
"selected_tool": "get_weather",
"tool_input_schema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"tool_arguments": {"city": "London"},
},
test_output=[
(
"result",
{"weather": "sunny", "temperature": 20},
),
],
test_mock={
"_call_mcp_tool": lambda *a, **kw: {
"weather": "sunny",
"temperature": 20,
},
},
)
async def _call_mcp_tool(
self,
server_url: str,
tool_name: str,
arguments: dict[str, Any],
auth_token: str | None = None,
) -> Any:
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
client = MCPClient(server_url, auth_token=auth_token)
await client.initialize()
result = await client.call_tool(tool_name, arguments)
if result.is_error:
error_text = ""
for item in result.content:
if item.get("type") == "text":
error_text += item.get("text", "")
raise MCPClientError(
f"MCP tool '{tool_name}' returned an error: "
f"{error_text or 'Unknown error'}"
)
# Extract text content from the result
output_parts = []
for item in result.content:
if item.get("type") == "text":
text = item.get("text", "")
# Try to parse as JSON for structured output
try:
output_parts.append(json.loads(text))
except (json.JSONDecodeError, ValueError):
output_parts.append(text)
elif item.get("type") == "image":
output_parts.append(
{
"type": "image",
"data": item.get("data"),
"mimeType": item.get("mimeType"),
}
)
elif item.get("type") == "resource":
output_parts.append(item.get("resource", {}))
# If single result, unwrap
if len(output_parts) == 1:
return output_parts[0]
return output_parts if output_parts else None
@staticmethod
async def _auto_lookup_credential(
user_id: str, server_url: str
) -> "OAuth2Credentials | None":
"""Auto-lookup stored MCP credential for a server URL.
This is a fallback for nodes that don't have ``credentials`` explicitly
set (e.g. nodes created before the credential field was wired up).
"""
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
try:
mgr = IntegrationCredentialsManager()
mcp_creds = await mgr.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
best: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == server_url
):
if best is None or (
(cred.access_token_expires_at or 0)
> (best.access_token_expires_at or 0)
):
best = cred
if best:
best = await mgr.refresh_if_needed(user_id, best)
logger.info(
"Auto-resolved MCP credential %s for %s", best.id, server_url
)
return best
except Exception:
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
return None
async def run(
self,
input_data: Input,
*,
user_id: str,
credentials: OAuth2Credentials | None = None,
**kwargs,
) -> BlockOutput:
if not input_data.server_url:
yield "error", "MCP server URL is required"
return
if not input_data.selected_tool:
yield "error", "No tool selected. Please select a tool from the dropdown."
return
# Validate required tool arguments before calling the server.
# The executor-level validation is bypassed for MCP blocks because
# get_input_defaults() flattens tool_arguments, stripping tool_input_schema
# from the validation context.
required = set(input_data.tool_input_schema.get("required", []))
if required:
missing = required - set(input_data.tool_arguments.keys())
if missing:
yield "error", (
f"Missing required argument(s): {', '.join(sorted(missing))}. "
f"Please fill in all required fields marked with * in the block form."
)
return
# If no credentials were injected by the executor (e.g. legacy nodes
# that don't have the credentials field set), try to auto-lookup
# the stored MCP credential for this server URL.
if credentials is None:
credentials = await self._auto_lookup_credential(
user_id, input_data.server_url
)
auth_token = (
credentials.access_token.get_secret_value() if credentials else None
)
try:
result = await self._call_mcp_tool(
server_url=input_data.server_url,
tool_name=input_data.selected_tool,
arguments=input_data.tool_arguments,
auth_token=auth_token,
)
yield "result", result
except MCPClientError as e:
yield "error", str(e)
except Exception as e:
logger.exception(f"MCP tool call failed: {e}")
yield "error", f"MCP tool call failed: {str(e)}"

View File

@@ -1,323 +0,0 @@
"""
MCP (Model Context Protocol) HTTP client.
Implements the MCP Streamable HTTP transport for listing tools and calling tools
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
"""
import json
import logging
from dataclasses import dataclass, field
from typing import Any
from backend.util.request import Requests
logger = logging.getLogger(__name__)
@dataclass
class MCPTool:
"""Represents an MCP tool discovered from a server."""
name: str
description: str
input_schema: dict[str, Any]
@dataclass
class MCPCallResult:
"""Result from calling an MCP tool."""
content: list[dict[str, Any]] = field(default_factory=list)
is_error: bool = False
class MCPClientError(Exception):
"""Raised when an MCP protocol error occurs."""
pass
class MCPClient:
"""
Async HTTP client for the MCP Streamable HTTP transport.
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
Supports optional Bearer token authentication.
"""
def __init__(
self,
server_url: str,
auth_token: str | None = None,
):
self.server_url = server_url.rstrip("/")
self.auth_token = auth_token
self._request_id = 0
self._session_id: str | None = None
def _next_id(self) -> int:
self._request_id += 1
return self._request_id
def _build_headers(self) -> dict[str, str]:
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream",
}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
if self._session_id:
headers["Mcp-Session-Id"] = self._session_id
return headers
def _build_jsonrpc_request(
self, method: str, params: dict[str, Any] | None = None
) -> dict[str, Any]:
req: dict[str, Any] = {
"jsonrpc": "2.0",
"method": method,
"id": self._next_id(),
}
if params is not None:
req["params"] = params
return req
@staticmethod
def _parse_sse_response(text: str) -> dict[str, Any]:
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
MCP servers may return responses as SSE with format:
event: message
data: {"jsonrpc":"2.0","result":{...},"id":1}
We extract the last `data:` line that contains a JSON-RPC response
(i.e. has an "id" field), which is the reply to our request.
"""
last_data: dict[str, Any] | None = None
for line in text.splitlines():
stripped = line.strip()
if stripped.startswith("data:"):
payload = stripped[len("data:") :].strip()
if not payload:
continue
try:
parsed = json.loads(payload)
# Only keep JSON-RPC responses (have "id"), skip notifications
if isinstance(parsed, dict) and "id" in parsed:
last_data = parsed
except (json.JSONDecodeError, ValueError):
continue
if last_data is None:
raise MCPClientError("No JSON-RPC response found in SSE stream")
return last_data
async def _send_request(
self, method: str, params: dict[str, Any] | None = None
) -> Any:
"""Send a JSON-RPC request to the MCP server and return the result.
Handles both ``application/json`` and ``text/event-stream`` responses
as required by the MCP Streamable HTTP transport specification.
"""
payload = self._build_jsonrpc_request(method, params)
headers = self._build_headers()
requests = Requests(
raise_for_status=True,
extra_headers=headers,
)
response = await requests.post(self.server_url, json=payload)
# Capture session ID from response (MCP Streamable HTTP transport)
session_id = response.headers.get("Mcp-Session-Id")
if session_id:
self._session_id = session_id
content_type = response.headers.get("content-type", "")
if "text/event-stream" in content_type:
body = self._parse_sse_response(response.text())
else:
try:
body = response.json()
except Exception as e:
raise MCPClientError(
f"MCP server returned non-JSON response: {e}"
) from e
if not isinstance(body, dict):
raise MCPClientError(
f"MCP server returned unexpected JSON type: {type(body).__name__}"
)
# Handle JSON-RPC error
if "error" in body:
error = body["error"]
if isinstance(error, dict):
raise MCPClientError(
f"MCP server error [{error.get('code', '?')}]: "
f"{error.get('message', 'Unknown error')}"
)
raise MCPClientError(f"MCP server error: {error}")
return body.get("result")
async def _send_notification(self, method: str) -> None:
"""Send a JSON-RPC notification (no id, no response expected)."""
headers = self._build_headers()
notification = {"jsonrpc": "2.0", "method": method}
requests = Requests(
raise_for_status=False,
extra_headers=headers,
)
await requests.post(self.server_url, json=notification)
async def discover_auth(self) -> dict[str, Any] | None:
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
Returns ``None`` if the server doesn't require auth, otherwise returns
a dict with:
- ``authorization_servers``: list of authorization server URLs
- ``resource``: the resource indicator URL (usually the MCP endpoint)
- ``scopes_supported``: optional list of supported scopes
The caller can then fetch the authorization server metadata to get
``authorization_endpoint``, ``token_endpoint``, etc.
"""
from urllib.parse import urlparse
parsed = urlparse(self.server_url)
base = f"{parsed.scheme}://{parsed.netloc}"
# Build candidates for protected-resource metadata (per RFC 9728)
path = parsed.path.rstrip("/")
candidates = []
if path and path != "/":
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
candidates.append(f"{base}/.well-known/oauth-protected-resource")
requests = Requests(
raise_for_status=False,
)
for url in candidates:
try:
resp = await requests.get(url)
if resp.status == 200:
data = resp.json()
if isinstance(data, dict) and "authorization_servers" in data:
return data
except Exception:
continue
return None
async def discover_auth_server_metadata(
self, auth_server_url: str
) -> dict[str, Any] | None:
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
Given an authorization server URL, returns a dict with:
- ``authorization_endpoint``
- ``token_endpoint``
- ``registration_endpoint`` (for dynamic client registration)
- ``scopes_supported``
- ``code_challenge_methods_supported``
- etc.
"""
from urllib.parse import urlparse
parsed = urlparse(auth_server_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/")
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
candidates = []
if path and path != "/":
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
candidates.append(f"{base}/.well-known/oauth-authorization-server")
candidates.append(f"{base}/.well-known/openid-configuration")
requests = Requests(
raise_for_status=False,
)
for url in candidates:
try:
resp = await requests.get(url)
if resp.status == 200:
data = resp.json()
if isinstance(data, dict) and "authorization_endpoint" in data:
return data
except Exception:
continue
return None
async def initialize(self) -> dict[str, Any]:
"""
Send the MCP initialize request.
This is required by the MCP protocol before any other requests.
Returns the server's capabilities.
"""
result = await self._send_request(
"initialize",
{
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
},
)
# Send initialized notification (no response expected)
await self._send_notification("notifications/initialized")
return result or {}
async def list_tools(self) -> list[MCPTool]:
"""
Discover available tools from the MCP server.
Returns a list of MCPTool objects with name, description, and input schema.
"""
result = await self._send_request("tools/list")
if not result or "tools" not in result:
return []
tools = []
for tool_data in result["tools"]:
tools.append(
MCPTool(
name=tool_data.get("name", ""),
description=tool_data.get("description", ""),
input_schema=tool_data.get("inputSchema", {}),
)
)
return tools
async def call_tool(
self, tool_name: str, arguments: dict[str, Any]
) -> MCPCallResult:
"""
Call a tool on the MCP server.
Args:
tool_name: The name of the tool to call.
arguments: The arguments to pass to the tool.
Returns:
MCPCallResult with the tool's response content.
"""
result = await self._send_request(
"tools/call",
{"name": tool_name, "arguments": arguments},
)
if not result:
return MCPCallResult(is_error=True)
return MCPCallResult(
content=result.get("content", []),
is_error=result.get("isError", False),
)

View File

@@ -1,204 +0,0 @@
"""
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
This handler accepts those endpoints at construction time.
"""
import logging
import time
import urllib.parse
from typing import ClassVar, Optional
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
from backend.integrations.oauth.base import BaseOAuthHandler
from backend.integrations.providers import ProviderName
from backend.util.request import Requests
logger = logging.getLogger(__name__)
class MCPOAuthHandler(BaseOAuthHandler):
"""
OAuth handler for MCP servers with dynamically-discovered endpoints.
Construction requires the authorization and token endpoint URLs,
which are obtained via MCP OAuth metadata discovery
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
"""
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
DEFAULT_SCOPES: ClassVar[list[str]] = []
def __init__(
self,
client_id: str,
client_secret: str,
redirect_uri: str,
*,
authorize_url: str,
token_url: str,
revoke_url: str | None = None,
resource_url: str | None = None,
):
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
self.authorize_url = authorize_url
self.token_url = token_url
self.revoke_url = revoke_url
self.resource_url = resource_url
def get_login_url(
self,
scopes: list[str],
state: str,
code_challenge: Optional[str],
) -> str:
scopes = self.handle_default_scopes(scopes)
params: dict[str, str] = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"state": state,
}
if scopes:
params["scope"] = " ".join(scopes)
# PKCE (S256) — included when the caller provides a code_challenge
if code_challenge:
params["code_challenge"] = code_challenge
params["code_challenge_method"] = "S256"
# MCP spec requires resource indicator (RFC 8707)
if self.resource_url:
params["resource"] = self.resource_url
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
async def exchange_code_for_tokens(
self,
code: str,
scopes: list[str],
code_verifier: Optional[str],
) -> OAuth2Credentials:
data: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": self.redirect_uri,
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
if code_verifier:
data["code_verifier"] = code_verifier
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests(raise_for_status=True).post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
if "error" in tokens:
raise RuntimeError(
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
)
if "access_token" not in tokens:
raise RuntimeError("OAuth token response missing 'access_token' field")
now = int(time.time())
expires_in = tokens.get("expires_in")
return OAuth2Credentials(
provider=self.PROVIDER_NAME,
title=None,
access_token=SecretStr(tokens["access_token"]),
refresh_token=(
SecretStr(tokens["refresh_token"])
if tokens.get("refresh_token")
else None
),
access_token_expires_at=now + expires_in if expires_in else None,
refresh_token_expires_at=None,
scopes=scopes,
metadata={
"mcp_token_url": self.token_url,
"mcp_resource_url": self.resource_url,
},
)
async def _refresh_tokens(
self, credentials: OAuth2Credentials
) -> OAuth2Credentials:
if not credentials.refresh_token:
raise ValueError("No refresh token available for MCP OAuth credentials")
data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": credentials.refresh_token.get_secret_value(),
"client_id": self.client_id,
}
if self.client_secret:
data["client_secret"] = self.client_secret
if self.resource_url:
data["resource"] = self.resource_url
response = await Requests(raise_for_status=True).post(
self.token_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
tokens = response.json()
if "error" in tokens:
raise RuntimeError(
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
)
if "access_token" not in tokens:
raise RuntimeError("OAuth refresh response missing 'access_token' field")
now = int(time.time())
expires_in = tokens.get("expires_in")
return OAuth2Credentials(
id=credentials.id,
provider=self.PROVIDER_NAME,
title=credentials.title,
access_token=SecretStr(tokens["access_token"]),
refresh_token=(
SecretStr(tokens["refresh_token"])
if tokens.get("refresh_token")
else credentials.refresh_token
),
access_token_expires_at=now + expires_in if expires_in else None,
refresh_token_expires_at=credentials.refresh_token_expires_at,
scopes=credentials.scopes,
metadata=credentials.metadata,
)
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
if not self.revoke_url:
return False
try:
data = {
"token": credentials.access_token.get_secret_value(),
"token_type_hint": "access_token",
"client_id": self.client_id,
}
await Requests().post(
self.revoke_url,
data=data,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
return True
except Exception:
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
return False

View File

@@ -1,109 +0,0 @@
"""
End-to-end tests against a real public MCP server.
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
which is publicly accessible without authentication and returns SSE responses.
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
independently of the rest of the test suite (they require network access).
"""
import json
import os
import pytest
from backend.blocks.mcp.client import MCPClient
# Public MCP server that requires no authentication
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
# Skip all tests in this module unless RUN_E2E env var is set
pytestmark = pytest.mark.skipif(
not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests"
)
class TestRealMCPServer:
"""Tests against the live OpenAI docs MCP server."""
@pytest.mark.asyncio(loop_scope="session")
async def test_initialize(self):
"""Verify we can complete the MCP handshake with a real server."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
assert "serverInfo" in result
assert result["serverInfo"]["name"] == "openai-docs-mcp"
assert "tools" in result.get("capabilities", {})
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools(self):
"""Verify we can discover tools from a real MCP server."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
tools = await client.list_tools()
assert len(tools) >= 3 # server has at least 5 tools as of writing
tool_names = {t.name for t in tools}
# These tools are documented and should be stable
assert "search_openai_docs" in tool_names
assert "list_openai_docs" in tool_names
assert "fetch_openai_doc" in tool_names
# Verify schema structure
search_tool = next(t for t in tools if t.name == "search_openai_docs")
assert "query" in search_tool.input_schema.get("properties", {})
assert "query" in search_tool.input_schema.get("required", [])
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_list_api_endpoints(self):
"""Call the list_api_endpoints tool and verify we get real data."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
result = await client.call_tool("list_api_endpoints", {})
assert not result.is_error
assert len(result.content) >= 1
assert result.content[0]["type"] == "text"
data = json.loads(result.content[0]["text"])
assert "paths" in data or "urls" in data
# The OpenAI API should have many endpoints
total = data.get("total", len(data.get("paths", [])))
assert total > 50
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_search(self):
"""Search for docs and verify we get results."""
client = MCPClient(OPENAI_DOCS_MCP_URL)
await client.initialize()
result = await client.call_tool(
"search_openai_docs", {"query": "chat completions", "limit": 3}
)
assert not result.is_error
assert len(result.content) >= 1
@pytest.mark.asyncio(loop_scope="session")
async def test_sse_response_handling(self):
"""Verify the client correctly handles SSE responses from a real server.
This is the key test — our local test server returns JSON,
but real MCP servers typically return SSE. This proves the
SSE parsing works end-to-end.
"""
client = MCPClient(OPENAI_DOCS_MCP_URL)
# initialize() internally calls _send_request which must parse SSE
result = await client.initialize()
# If we got here without error, SSE parsing works
assert isinstance(result, dict)
assert "protocolVersion" in result
# Also verify list_tools works (another SSE response)
tools = await client.list_tools()
assert len(tools) > 0
assert all(hasattr(t, "name") for t in tools)

View File

@@ -1,389 +0,0 @@
"""
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
These tests spin up a local MCP test server and run the full client/block flow
against it — no mocking, real HTTP requests.
"""
import asyncio
import json
import threading
from unittest.mock import patch
import pytest
from aiohttp import web
from pydantic import SecretStr
from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.test_server import create_test_mcp_app
from backend.data.model import OAuth2Credentials
MOCK_USER_ID = "test-user-integration"
class _MCPTestServer:
"""
Run an MCP test server in a background thread with its own event loop.
This avoids event loop conflicts with pytest-asyncio.
"""
def __init__(self, auth_token: str | None = None):
self.auth_token = auth_token
self.url: str = ""
self._runner: web.AppRunner | None = None
self._loop: asyncio.AbstractEventLoop | None = None
self._thread: threading.Thread | None = None
self._started = threading.Event()
def _run(self):
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._loop.run_until_complete(self._start())
self._started.set()
self._loop.run_forever()
async def _start(self):
app = create_test_mcp_app(auth_token=self.auth_token)
self._runner = web.AppRunner(app)
await self._runner.setup()
site = web.TCPSite(self._runner, "127.0.0.1", 0)
await site.start()
port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
self.url = f"http://127.0.0.1:{port}/mcp"
def start(self):
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
if not self._started.wait(timeout=5):
raise RuntimeError("MCP test server failed to start within 5 seconds")
return self
def stop(self):
if self._loop and self._runner:
asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
timeout=5
)
self._loop.call_soon_threadsafe(self._loop.stop)
if self._thread:
self._thread.join(timeout=5)
@pytest.fixture(scope="module")
def mcp_server():
"""Start a local MCP test server in a background thread."""
server = _MCPTestServer()
server.start()
yield server.url
server.stop()
@pytest.fixture(scope="module")
def mcp_server_with_auth():
"""Start a local MCP test server with auth in a background thread."""
server = _MCPTestServer(auth_token="test-secret-token")
server.start()
yield server.url, "test-secret-token"
server.stop()
@pytest.fixture(autouse=True)
def _allow_localhost():
"""
Allow 127.0.0.1 through SSRF protection for integration tests.
The Requests class blocks private IPs by default. We patch the Requests
constructor to always include 127.0.0.1 as a trusted origin so the local
test server is reachable.
"""
from backend.util.request import Requests
original_init = Requests.__init__
def patched_init(self, *args, **kwargs):
trusted = list(kwargs.get("trusted_origins") or [])
trusted.append("http://127.0.0.1")
kwargs["trusted_origins"] = trusted
original_init(self, *args, **kwargs)
with patch.object(Requests, "__init__", patched_init):
yield
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
"""Create an MCPClient for integration tests."""
return MCPClient(url, auth_token=auth_token)
# ── MCPClient integration tests ──────────────────────────────────────
class TestMCPClientIntegration:
"""Test MCPClient against a real local MCP server."""
@pytest.mark.asyncio(loop_scope="session")
async def test_initialize(self, mcp_server):
client = _make_client(mcp_server)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
assert result["serverInfo"]["name"] == "test-mcp-server"
assert "tools" in result["capabilities"]
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
tools = await client.list_tools()
assert len(tools) == 3
tool_names = {t.name for t in tools}
assert tool_names == {"get_weather", "add_numbers", "echo"}
# Check get_weather schema
weather = next(t for t in tools if t.name == "get_weather")
assert weather.description == "Get current weather for a city"
assert "city" in weather.input_schema["properties"]
assert weather.input_schema["required"] == ["city"]
# Check add_numbers schema
add = next(t for t in tools if t.name == "add_numbers")
assert "a" in add.input_schema["properties"]
assert "b" in add.input_schema["properties"]
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_get_weather(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("get_weather", {"city": "London"})
assert not result.is_error
assert len(result.content) == 1
assert result.content[0]["type"] == "text"
data = json.loads(result.content[0]["text"])
assert data["city"] == "London"
assert data["temperature"] == 22
assert data["condition"] == "sunny"
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_add_numbers(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
assert not result.is_error
data = json.loads(result.content[0]["text"])
assert data["result"] == 10
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_echo(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("echo", {"message": "Hello MCP!"})
assert not result.is_error
assert result.content[0]["text"] == "Hello MCP!"
@pytest.mark.asyncio(loop_scope="session")
async def test_call_unknown_tool(self, mcp_server):
client = _make_client(mcp_server)
await client.initialize()
result = await client.call_tool("nonexistent_tool", {})
assert result.is_error
assert "Unknown tool" in result.content[0]["text"]
@pytest.mark.asyncio(loop_scope="session")
async def test_auth_success(self, mcp_server_with_auth):
url, token = mcp_server_with_auth
client = _make_client(url, auth_token=token)
result = await client.initialize()
assert result["protocolVersion"] == "2025-03-26"
tools = await client.list_tools()
assert len(tools) == 3
@pytest.mark.asyncio(loop_scope="session")
async def test_auth_failure(self, mcp_server_with_auth):
url, _ = mcp_server_with_auth
client = _make_client(url, auth_token="wrong-token")
with pytest.raises(Exception):
await client.initialize()
@pytest.mark.asyncio(loop_scope="session")
async def test_auth_missing(self, mcp_server_with_auth):
url, _ = mcp_server_with_auth
client = _make_client(url)
with pytest.raises(Exception):
await client.initialize()
# ── MCPToolBlock integration tests ───────────────────────────────────
class TestMCPToolBlockIntegration:
"""Test MCPToolBlock end-to-end against a real local MCP server."""
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_get_weather(self, mcp_server):
"""Full flow: discover tools, select one, execute it."""
# Step 1: Discover tools (simulating what the frontend/API would do)
client = _make_client(mcp_server)
await client.initialize()
tools = await client.list_tools()
assert len(tools) == 3
# Step 2: User selects "get_weather" and we get its schema
weather_tool = next(t for t in tools if t.name == "get_weather")
# Step 3: Execute the block — no credentials (public server)
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="get_weather",
tool_input_schema=weather_tool.input_schema,
tool_arguments={"city": "Paris"},
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
result = outputs[0][1]
assert result["city"] == "Paris"
assert result["temperature"] == 22
assert result["condition"] == "sunny"
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_add_numbers(self, mcp_server):
"""Full flow for add_numbers tool."""
client = _make_client(mcp_server)
await client.initialize()
tools = await client.list_tools()
add_tool = next(t for t in tools if t.name == "add_numbers")
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="add_numbers",
tool_input_schema=add_tool.input_schema,
tool_arguments={"a": 42, "b": 58},
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1]["result"] == 100
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_echo_plain_text(self, mcp_server):
"""Verify plain text (non-JSON) responses work."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="echo",
tool_input_schema={
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
},
tool_arguments={"message": "Hello from AutoGPT!"},
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == "Hello from AutoGPT!"
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
"""Calling an unknown tool should yield an error output."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="nonexistent_tool",
tool_arguments={},
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "error"
assert "returned an error" in outputs[0][1]
@pytest.mark.asyncio(loop_scope="session")
async def test_full_flow_with_auth(self, mcp_server_with_auth):
"""Full flow with authentication via credentials kwarg."""
url, token = mcp_server_with_auth
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=url,
selected_tool="echo",
tool_input_schema={
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
},
tool_arguments={"message": "Authenticated!"},
)
# Pass credentials via the standard kwarg (as the executor would)
test_creds = OAuth2Credentials(
id="test-cred",
provider="mcp",
access_token=SecretStr(token),
refresh_token=SecretStr(""),
scopes=[],
title="Test MCP credential",
)
outputs = []
async for name, data in block.run(
input_data, user_id=MOCK_USER_ID, credentials=test_creds
):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == "Authenticated!"
@pytest.mark.asyncio(loop_scope="session")
async def test_no_credentials_runs_without_auth(self, mcp_server):
"""Block runs without auth when no credentials are provided."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url=mcp_server,
selected_tool="echo",
tool_input_schema={
"type": "object",
"properties": {"message": {"type": "string"}},
"required": ["message"],
},
tool_arguments={"message": "No auth needed"},
)
outputs = []
async for name, data in block.run(
input_data, user_id=MOCK_USER_ID, credentials=None
):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == "No auth needed"

View File

@@ -1,619 +0,0 @@
"""
Tests for MCP client and MCPToolBlock.
"""
import json
from unittest.mock import AsyncMock, patch
import pytest
from backend.blocks.mcp.block import MCPToolBlock
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
from backend.util.test import execute_block_test
# ── SSE parsing unit tests ───────────────────────────────────────────
class TestSSEParsing:
"""Tests for SSE (text/event-stream) response parsing."""
def test_parse_sse_simple(self):
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == {"tools": []}
assert body["id"] == 1
def test_parse_sse_with_notifications(self):
"""SSE streams can contain notifications (no id) before the response."""
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
"\n"
"event: message\n"
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == {"ok": True}
assert body["id"] == 2
def test_parse_sse_error_response(self):
sse = (
"event: message\n"
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
)
body = MCPClient._parse_sse_response(sse)
assert "error" in body
assert body["error"]["code"] == -32600
def test_parse_sse_no_data_raises(self):
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
MCPClient._parse_sse_response("event: message\n\n")
def test_parse_sse_empty_raises(self):
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
MCPClient._parse_sse_response("")
def test_parse_sse_ignores_non_data_lines(self):
sse = (
": comment line\n"
"event: message\n"
"id: 123\n"
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == "ok"
def test_parse_sse_uses_last_response(self):
"""If multiple responses exist, use the last one."""
sse = (
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
"\n"
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
"\n"
)
body = MCPClient._parse_sse_response(sse)
assert body["result"] == "second"
# ── MCPClient unit tests ─────────────────────────────────────────────
class TestMCPClient:
"""Tests for the MCP HTTP client."""
def test_build_headers_without_auth(self):
client = MCPClient("https://mcp.example.com")
headers = client._build_headers()
assert "Authorization" not in headers
assert headers["Content-Type"] == "application/json"
def test_build_headers_with_auth(self):
client = MCPClient("https://mcp.example.com", auth_token="my-token")
headers = client._build_headers()
assert headers["Authorization"] == "Bearer my-token"
def test_build_jsonrpc_request(self):
client = MCPClient("https://mcp.example.com")
req = client._build_jsonrpc_request("tools/list")
assert req["jsonrpc"] == "2.0"
assert req["method"] == "tools/list"
assert "id" in req
assert "params" not in req
def test_build_jsonrpc_request_with_params(self):
client = MCPClient("https://mcp.example.com")
req = client._build_jsonrpc_request(
"tools/call", {"name": "test", "arguments": {"x": 1}}
)
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
def test_request_id_increments(self):
client = MCPClient("https://mcp.example.com")
req1 = client._build_jsonrpc_request("tools/list")
req2 = client._build_jsonrpc_request("tools/list")
assert req2["id"] > req1["id"]
def test_server_url_trailing_slash_stripped(self):
client = MCPClient("https://mcp.example.com/mcp/")
assert client.server_url == "https://mcp.example.com/mcp"
@pytest.mark.asyncio(loop_scope="session")
async def test_send_request_success(self):
client = MCPClient("https://mcp.example.com")
mock_response = AsyncMock()
mock_response.json.return_value = {
"jsonrpc": "2.0",
"result": {"tools": []},
"id": 1,
}
with patch.object(client, "_send_request", return_value={"tools": []}):
result = await client._send_request("tools/list")
assert result == {"tools": []}
@pytest.mark.asyncio(loop_scope="session")
async def test_send_request_error(self):
client = MCPClient("https://mcp.example.com")
async def mock_send(*args, **kwargs):
raise MCPClientError("MCP server error [-32600]: Invalid Request")
with patch.object(client, "_send_request", side_effect=mock_send):
with pytest.raises(MCPClientError, match="Invalid Request"):
await client._send_request("tools/list")
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"tools": [
{
"name": "get_weather",
"description": "Get current weather for a city",
"inputSchema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
},
{
"name": "search",
"description": "Search the web",
"inputSchema": {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
},
},
]
}
with patch.object(client, "_send_request", return_value=mock_result):
tools = await client.list_tools()
assert len(tools) == 2
assert tools[0].name == "get_weather"
assert tools[0].description == "Get current weather for a city"
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
assert tools[1].name == "search"
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools_empty(self):
client = MCPClient("https://mcp.example.com")
with patch.object(client, "_send_request", return_value={"tools": []}):
tools = await client.list_tools()
assert tools == []
@pytest.mark.asyncio(loop_scope="session")
async def test_list_tools_none_result(self):
client = MCPClient("https://mcp.example.com")
with patch.object(client, "_send_request", return_value=None):
tools = await client.list_tools()
assert tools == []
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_success(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"content": [
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
],
"isError": False,
}
with patch.object(client, "_send_request", return_value=mock_result):
result = await client.call_tool("get_weather", {"city": "London"})
assert not result.is_error
assert len(result.content) == 1
assert result.content[0]["type"] == "text"
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_error(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"content": [{"type": "text", "text": "City not found"}],
"isError": True,
}
with patch.object(client, "_send_request", return_value=mock_result):
result = await client.call_tool("get_weather", {"city": "???"})
assert result.is_error
@pytest.mark.asyncio(loop_scope="session")
async def test_call_tool_none_result(self):
client = MCPClient("https://mcp.example.com")
with patch.object(client, "_send_request", return_value=None):
result = await client.call_tool("get_weather", {"city": "London"})
assert result.is_error
@pytest.mark.asyncio(loop_scope="session")
async def test_initialize(self):
client = MCPClient("https://mcp.example.com")
mock_result = {
"protocolVersion": "2025-03-26",
"capabilities": {"tools": {}},
"serverInfo": {"name": "test-server", "version": "1.0.0"},
}
with (
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
patch.object(client, "_send_notification") as mock_notif,
):
result = await client.initialize()
mock_req.assert_called_once()
mock_notif.assert_called_once_with("notifications/initialized")
assert result["protocolVersion"] == "2025-03-26"
# ── MCPToolBlock unit tests ──────────────────────────────────────────
MOCK_USER_ID = "test-user-123"
class TestMCPToolBlock:
"""Tests for the MCPToolBlock."""
def test_block_instantiation(self):
block = MCPToolBlock()
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
assert block.name == "MCPToolBlock"
def test_input_schema_has_required_fields(self):
block = MCPToolBlock()
schema = block.input_schema.jsonschema()
props = schema.get("properties", {})
assert "server_url" in props
assert "selected_tool" in props
assert "tool_arguments" in props
assert "credentials" in props
def test_output_schema(self):
block = MCPToolBlock()
schema = block.output_schema.jsonschema()
props = schema.get("properties", {})
assert "result" in props
assert "error" in props
def test_get_input_schema_with_tool_schema(self):
tool_schema = {
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
}
data = {"tool_input_schema": tool_schema}
result = MCPToolBlock.Input.get_input_schema(data)
assert result == tool_schema
def test_get_input_schema_without_tool_schema(self):
result = MCPToolBlock.Input.get_input_schema({})
assert result == {}
def test_get_input_defaults(self):
data = {"tool_arguments": {"city": "London"}}
result = MCPToolBlock.Input.get_input_defaults(data)
assert result == {"city": "London"}
def test_get_missing_input(self):
data = {
"tool_input_schema": {
"type": "object",
"properties": {
"city": {"type": "string"},
"units": {"type": "string"},
},
"required": ["city", "units"],
},
"tool_arguments": {"city": "London"},
}
missing = MCPToolBlock.Input.get_missing_input(data)
assert missing == {"units"}
def test_get_missing_input_all_present(self):
data = {
"tool_input_schema": {
"type": "object",
"properties": {"city": {"type": "string"}},
"required": ["city"],
},
"tool_arguments": {"city": "London"},
}
missing = MCPToolBlock.Input.get_missing_input(data)
assert missing == set()
@pytest.mark.asyncio(loop_scope="session")
async def test_run_with_mock(self):
"""Test the block using the built-in test infrastructure."""
block = MCPToolBlock()
await execute_block_test(block)
@pytest.mark.asyncio(loop_scope="session")
async def test_run_missing_server_url(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="",
selected_tool="test",
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert outputs == [("error", "MCP server URL is required")]
@pytest.mark.asyncio(loop_scope="session")
async def test_run_missing_tool(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="",
)
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert outputs == [
("error", "No tool selected. Please select a tool from the dropdown.")
]
@pytest.mark.asyncio(loop_scope="session")
async def test_run_success(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="get_weather",
tool_input_schema={
"type": "object",
"properties": {"city": {"type": "string"}},
},
tool_arguments={"city": "London"},
)
async def mock_call(*args, **kwargs):
return {"temp": 20, "city": "London"}
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert len(outputs) == 1
assert outputs[0][0] == "result"
assert outputs[0][1] == {"temp": 20, "city": "London"}
@pytest.mark.asyncio(loop_scope="session")
async def test_run_mcp_error(self):
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="bad_tool",
)
async def mock_call(*args, **kwargs):
raise MCPClientError("Tool not found")
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert outputs[0][0] == "error"
assert "Tool not found" in outputs[0][1]
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_parses_json_text(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{"type": "text", "text": '{"temp": 20}'},
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == {"temp": 20}
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_plain_text(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{"type": "text", "text": "Hello, world!"},
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == "Hello, world!"
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_multiple_content(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{"type": "text", "text": "Part 1"},
{"type": "text", "text": '{"part": 2}'},
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == ["Part 1", {"part": 2}]
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_error_result(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[{"type": "text", "text": "Something went wrong"}],
is_error=True,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
with pytest.raises(MCPClientError, match="returned an error"):
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
@pytest.mark.asyncio(loop_scope="session")
async def test_call_mcp_tool_image_content(self):
block = MCPToolBlock()
mock_result = MCPCallResult(
content=[
{
"type": "image",
"data": "base64data==",
"mimeType": "image/png",
}
],
is_error=False,
)
async def mock_init(self):
return {}
async def mock_call(self, name, args):
return mock_result
with (
patch.object(MCPClient, "initialize", mock_init),
patch.object(MCPClient, "call_tool", mock_call),
):
result = await block._call_mcp_tool(
"https://mcp.example.com", "test_tool", {}
)
assert result == {
"type": "image",
"data": "base64data==",
"mimeType": "image/png",
}
@pytest.mark.asyncio(loop_scope="session")
async def test_run_with_credentials(self):
"""Verify the block uses OAuth2Credentials and passes auth token."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="test_tool",
)
captured_tokens: list[str | None] = []
async def mock_call(server_url, tool_name, arguments, auth_token=None):
captured_tokens.append(auth_token)
return "ok"
block._call_mcp_tool = mock_call # type: ignore
test_creds = OAuth2Credentials(
id="cred-123",
provider="mcp",
access_token=SecretStr("resolved-token"),
refresh_token=SecretStr(""),
scopes=[],
title="Test MCP credential",
)
async for _ in block.run(
input_data, user_id=MOCK_USER_ID, credentials=test_creds
):
pass
assert captured_tokens == ["resolved-token"]
@pytest.mark.asyncio(loop_scope="session")
async def test_run_without_credentials(self):
"""Verify the block works without credentials (public server)."""
block = MCPToolBlock()
input_data = MCPToolBlock.Input(
server_url="https://mcp.example.com/mcp",
selected_tool="test_tool",
)
captured_tokens: list[str | None] = []
async def mock_call(server_url, tool_name, arguments, auth_token=None):
captured_tokens.append(auth_token)
return "ok"
block._call_mcp_tool = mock_call # type: ignore
outputs = []
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
outputs.append((name, data))
assert captured_tokens == [None]
assert outputs == [("result", "ok")]

View File

@@ -1,242 +0,0 @@
"""
Tests for MCP OAuth handler.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from backend.blocks.mcp.client import MCPClient
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
resp = MagicMock()
resp.status = status
resp.ok = 200 <= status < 300
resp.json.return_value = json_data
return resp
class TestMCPOAuthHandler:
"""Tests for the MCPOAuthHandler."""
def _make_handler(self, **overrides) -> MCPOAuthHandler:
defaults = {
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"redirect_uri": "https://app.example.com/callback",
"authorize_url": "https://auth.example.com/authorize",
"token_url": "https://auth.example.com/token",
}
defaults.update(overrides)
return MCPOAuthHandler(**defaults)
def test_get_login_url_basic(self):
handler = self._make_handler()
url = handler.get_login_url(
scopes=["read", "write"],
state="random-state-token",
code_challenge="S256-challenge-value",
)
assert "https://auth.example.com/authorize?" in url
assert "response_type=code" in url
assert "client_id=test-client-id" in url
assert "state=random-state-token" in url
assert "code_challenge=S256-challenge-value" in url
assert "code_challenge_method=S256" in url
assert "scope=read+write" in url
def test_get_login_url_with_resource(self):
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
url = handler.get_login_url(
scopes=[], state="state", code_challenge="challenge"
)
assert "resource=https" in url
def test_get_login_url_without_pkce(self):
handler = self._make_handler()
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
assert "code_challenge" not in url
assert "code_challenge_method" not in url
@pytest.mark.asyncio(loop_scope="session")
async def test_exchange_code_for_tokens(self):
handler = self._make_handler()
resp = _mock_response(
{
"access_token": "new-access-token",
"refresh_token": "new-refresh-token",
"expires_in": 3600,
"token_type": "Bearer",
}
)
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
creds = await handler.exchange_code_for_tokens(
code="auth-code",
scopes=["read"],
code_verifier="pkce-verifier",
)
assert isinstance(creds, OAuth2Credentials)
assert creds.access_token.get_secret_value() == "new-access-token"
assert creds.refresh_token is not None
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
assert creds.scopes == ["read"]
assert creds.access_token_expires_at is not None
@pytest.mark.asyncio(loop_scope="session")
async def test_refresh_tokens(self):
handler = self._make_handler()
existing_creds = OAuth2Credentials(
id="existing-id",
provider="mcp",
access_token=SecretStr("old-token"),
refresh_token=SecretStr("old-refresh"),
scopes=["read"],
title="test",
)
resp = _mock_response(
{
"access_token": "refreshed-token",
"refresh_token": "new-refresh",
"expires_in": 3600,
}
)
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
refreshed = await handler._refresh_tokens(existing_creds)
assert refreshed.id == "existing-id"
assert refreshed.access_token.get_secret_value() == "refreshed-token"
assert refreshed.refresh_token is not None
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
@pytest.mark.asyncio(loop_scope="session")
async def test_refresh_tokens_no_refresh_token(self):
handler = self._make_handler()
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=["read"],
title="test",
)
with pytest.raises(ValueError, match="No refresh token"):
await handler._refresh_tokens(creds)
@pytest.mark.asyncio(loop_scope="session")
async def test_revoke_tokens_no_url(self):
handler = self._make_handler(revoke_url=None)
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=[],
title="test",
)
result = await handler.revoke_tokens(creds)
assert result is False
@pytest.mark.asyncio(loop_scope="session")
async def test_revoke_tokens_with_url(self):
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
creds = OAuth2Credentials(
provider="mcp",
access_token=SecretStr("token"),
scopes=[],
title="test",
)
resp = _mock_response({}, status=200)
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
instance = MockRequests.return_value
instance.post = AsyncMock(return_value=resp)
result = await handler.revoke_tokens(creds)
assert result is True
class TestMCPClientDiscovery:
"""Tests for MCPClient OAuth metadata discovery."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_auth_found(self):
client = MCPClient("https://mcp.example.com/mcp")
metadata = {
"authorization_servers": ["https://auth.example.com"],
"resource": "https://mcp.example.com/mcp",
}
resp = _mock_response(metadata, status=200)
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth()
assert result is not None
assert result["authorization_servers"] == ["https://auth.example.com"]
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_auth_not_found(self):
client = MCPClient("https://mcp.example.com/mcp")
resp = _mock_response({}, status=404)
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth()
assert result is None
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_auth_server_metadata(self):
client = MCPClient("https://mcp.example.com/mcp")
server_metadata = {
"issuer": "https://auth.example.com",
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"registration_endpoint": "https://auth.example.com/register",
"code_challenge_methods_supported": ["S256"],
}
resp = _mock_response(server_metadata, status=200)
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
instance = MockRequests.return_value
instance.get = AsyncMock(return_value=resp)
result = await client.discover_auth_server_metadata(
"https://auth.example.com"
)
assert result is not None
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
assert result["token_endpoint"] == "https://auth.example.com/token"

View File

@@ -1,162 +0,0 @@
"""
Minimal MCP server for integration testing.
Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
with a few sample tools. Runs on localhost with a random available port.
"""
import json
import logging
from aiohttp import web
logger = logging.getLogger(__name__)
# Sample tools this test server exposes
TEST_TOOLS = [
{
"name": "get_weather",
"description": "Get current weather for a city",
"inputSchema": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "City name",
},
},
"required": ["city"],
},
},
{
"name": "add_numbers",
"description": "Add two numbers together",
"inputSchema": {
"type": "object",
"properties": {
"a": {"type": "number", "description": "First number"},
"b": {"type": "number", "description": "Second number"},
},
"required": ["a", "b"],
},
},
{
"name": "echo",
"description": "Echo back the input message",
"inputSchema": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "Message to echo"},
},
"required": ["message"],
},
},
]
def _handle_initialize(params: dict) -> dict:
return {
"protocolVersion": "2025-03-26",
"capabilities": {"tools": {"listChanged": False}},
"serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
}
def _handle_tools_list(params: dict) -> dict:
return {"tools": TEST_TOOLS}
def _handle_tools_call(params: dict) -> dict:
tool_name = params.get("name", "")
arguments = params.get("arguments", {})
if tool_name == "get_weather":
city = arguments.get("city", "Unknown")
return {
"content": [
{
"type": "text",
"text": json.dumps(
{"city": city, "temperature": 22, "condition": "sunny"}
),
}
],
}
elif tool_name == "add_numbers":
a = arguments.get("a", 0)
b = arguments.get("b", 0)
return {
"content": [{"type": "text", "text": json.dumps({"result": a + b})}],
}
elif tool_name == "echo":
message = arguments.get("message", "")
return {
"content": [{"type": "text", "text": message}],
}
else:
return {
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
"isError": True,
}
HANDLERS = {
"initialize": _handle_initialize,
"tools/list": _handle_tools_list,
"tools/call": _handle_tools_call,
}
async def handle_mcp_request(request: web.Request) -> web.Response:
"""Handle incoming MCP JSON-RPC 2.0 requests."""
# Check auth if configured
expected_token = request.app.get("auth_token")
if expected_token:
auth_header = request.headers.get("Authorization", "")
if auth_header != f"Bearer {expected_token}":
return web.json_response(
{
"jsonrpc": "2.0",
"error": {"code": -32001, "message": "Unauthorized"},
"id": None,
},
status=401,
)
body = await request.json()
# Handle notifications (no id field) — just acknowledge
if "id" not in body:
return web.Response(status=202)
method = body.get("method", "")
params = body.get("params", {})
request_id = body.get("id")
handler = HANDLERS.get(method)
if not handler:
return web.json_response(
{
"jsonrpc": "2.0",
"error": {
"code": -32601,
"message": f"Method not found: {method}",
},
"id": request_id,
}
)
result = handler(params)
return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})
def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
"""Create an aiohttp app that acts as an MCP server."""
app = web.Application()
app.router.add_post("/mcp", handle_mcp_request)
if auth_token:
app["auth_token"] = auth_token
return app

View File

@@ -33,7 +33,6 @@ from backend.util import type as type_utils
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
from backend.util.json import SafeJson
from backend.util.models import Pagination
from backend.util.request import parse_url
from .block import BlockInput
from .db import BaseDbModel
@@ -450,9 +449,6 @@ class GraphModel(Graph, GraphMeta):
continue
if ProviderName.HTTP in field.provider:
continue
# MCP credentials are intentionally split by server URL
if ProviderName.MCP in field.provider:
continue
# If this happens, that means a block implementation probably needs
# to be updated.
@@ -509,18 +505,6 @@ class GraphModel(Graph, GraphMeta):
"required": ["id", "provider", "type"],
}
# Add a descriptive display title when URL-based discriminator values
# are present (e.g. "mcp.sentry.dev" instead of just "Mcp")
if (
field_info.discriminator
and not field_info.discriminator_mapping
and field_info.discriminator_values
):
hostnames = sorted(
parse_url(str(v)).netloc for v in field_info.discriminator_values
)
field_schema["display_name"] = ", ".join(hostnames)
# Add other (optional) field info items
field_schema.update(
field_info.model_dump(
@@ -565,17 +549,8 @@ class GraphModel(Graph, GraphMeta):
for graph in [self] + self.sub_graphs:
for node in graph.nodes:
# A node's credentials are optional if either:
# 1. The node metadata says so (credentials_optional=True), or
# 2. All credential fields on the block have defaults (not required by schema)
block_required = node.block.input_schema.get_required_fields()
creds_required_by_schema = any(
fname in block_required
for fname in node.block.input_schema.get_credentials_fields()
)
node_required_map[node.id] = (
not node.credentials_optional and creds_required_by_schema
)
# Track if this node requires credentials (credentials_optional=False means required)
node_required_map[node.id] = not node.credentials_optional
for (
field_name,
@@ -801,19 +776,6 @@ class GraphModel(Graph, GraphMeta):
"'credentials' and `*_credentials` are reserved"
)
# Check custom block-level validation (e.g., MCP dynamic tool arguments).
# Blocks can override get_missing_input to report additional missing fields
# beyond the standard top-level required fields.
if for_run:
credential_fields = InputSchema.get_credentials_fields()
custom_missing = InputSchema.get_missing_input(node.input_default)
for field_name in custom_missing:
if (
field_name not in provided_inputs
and field_name not in credential_fields
):
node_errors[node.id][field_name] = "This field is required"
# Get input schema properties and check dependencies
input_fields = InputSchema.model_fields

View File

@@ -462,120 +462,3 @@ def test_node_credentials_optional_with_other_metadata():
assert node.credentials_optional is True
assert node.metadata["position"] == {"x": 100, "y": 200}
assert node.metadata["customized_name"] == "My Custom Node"
# ============================================================================
# Tests for MCP Credential Deduplication
# ============================================================================
def test_mcp_credential_combine_different_servers():
"""Two MCP credential fields with different server URLs should produce
separate entries when combined (not merged into one)."""
from backend.data.model import CredentialsFieldInfo, CredentialsType
from backend.integrations.providers import ProviderName
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
field_sentry = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
discriminator_values={"https://mcp.sentry.dev/mcp"},
)
field_linear = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
discriminator_values={"https://mcp.linear.app/mcp"},
)
combined = CredentialsFieldInfo.combine(
(field_sentry, ("node-sentry", "credentials")),
(field_linear, ("node-linear", "credentials")),
)
# Should produce 2 separate credential entries
assert len(combined) == 2, (
f"Expected 2 credential entries for 2 MCP blocks with different servers, "
f"got {len(combined)}: {list(combined.keys())}"
)
# Each entry should contain the server hostname in its key
keys = list(combined.keys())
assert any(
"mcp.sentry.dev" in k for k in keys
), f"Expected 'mcp.sentry.dev' in one key, got {keys}"
assert any(
"mcp.linear.app" in k for k in keys
), f"Expected 'mcp.linear.app' in one key, got {keys}"
def test_mcp_credential_combine_same_server():
"""Two MCP credential fields with the same server URL should be combined
into one credential entry."""
from backend.data.model import CredentialsFieldInfo, CredentialsType
from backend.integrations.providers import ProviderName
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
field_a = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
discriminator_values={"https://mcp.sentry.dev/mcp"},
)
field_b = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
discriminator_values={"https://mcp.sentry.dev/mcp"},
)
combined = CredentialsFieldInfo.combine(
(field_a, ("node-a", "credentials")),
(field_b, ("node-b", "credentials")),
)
# Should produce 1 credential entry (same server URL)
assert len(combined) == 1, (
f"Expected 1 credential entry for 2 MCP blocks with same server, "
f"got {len(combined)}: {list(combined.keys())}"
)
def test_mcp_credential_combine_no_discriminator_values():
"""MCP credential fields without discriminator_values should be merged
into a single entry (backwards compat for blocks without server_url set)."""
from backend.data.model import CredentialsFieldInfo, CredentialsType
from backend.integrations.providers import ProviderName
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
field_a = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
)
field_b = CredentialsFieldInfo(
credentials_provider=frozenset([ProviderName.MCP]),
credentials_types=oauth2_types,
credentials_scopes=None,
discriminator="server_url",
)
combined = CredentialsFieldInfo.combine(
(field_a, ("node-a", "credentials")),
(field_b, ("node-b", "credentials")),
)
# Should produce 1 entry (no URL differentiation)
assert len(combined) == 1, (
f"Expected 1 credential entry for MCP blocks without discriminator_values, "
f"got {len(combined)}: {list(combined.keys())}"
)

View File

@@ -29,7 +29,6 @@ from pydantic import (
GetCoreSchemaHandler,
SecretStr,
field_serializer,
model_validator,
)
from pydantic_core import (
CoreSchema,
@@ -503,25 +502,6 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
provider: CP
type: CT
@model_validator(mode="before")
@classmethod
def _normalize_legacy_provider(cls, data: Any) -> Any:
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug.
Python 3.13 changed ``str(StrEnum)`` to return ``"ClassName.MEMBER"``
instead of the plain value. Old stored credential references may have
``provider: "ProviderName.MCP"`` instead of ``"mcp"``.
"""
if isinstance(data, dict):
prov = data.get("provider", "")
if isinstance(prov, str) and prov.startswith("ProviderName."):
member = prov.removeprefix("ProviderName.")
try:
data = {**data, "provider": ProviderName[member].value}
except KeyError:
pass
return data
@classmethod
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
return get_args(cls.model_fields["provider"].annotation)
@@ -626,18 +606,11 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
] = defaultdict(list)
for field, key in fields:
if (
field.discriminator
and not field.discriminator_mapping
and field.discriminator_values
):
# URL-based discrimination (e.g. HTTP host-scoped, MCP server URL):
# Each unique host gets its own credential entry.
provider_prefix = next(iter(field.provider))
# Use .value for enum types to get the plain string (e.g. "mcp" not "ProviderName.MCP")
prefix_str = getattr(provider_prefix, "value", str(provider_prefix))
if field.provider == frozenset([ProviderName.HTTP]):
# HTTP host-scoped credentials can have different hosts that reqires different credential sets.
# Group by host extracted from the URL
providers = frozenset(
[cast(CP, prefix_str)]
[cast(CP, "http")]
+ [
cast(CP, parse_url(str(value)).netloc)
for value in field.discriminator_values

View File

@@ -20,7 +20,6 @@ from backend.blocks import get_block
from backend.blocks._base import BlockSchema
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.blocks.mcp.block import MCPToolBlock
from backend.data import redis_client as redis
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
from backend.data.credit import UsageTransactionMetadata
@@ -229,18 +228,6 @@ async def execute_node(
_input_data.nodes_input_masks = nodes_input_masks
_input_data.user_id = user_id
input_data = _input_data.model_dump()
elif isinstance(node_block, MCPToolBlock):
_mcp_data = MCPToolBlock.Input(**node.input_default)
# Dynamic tool fields are flattened to top-level by validate_exec
# (via get_input_defaults). Collect them back into tool_arguments.
tool_schema = _mcp_data.tool_input_schema
tool_props = set(tool_schema.get("properties", {}).keys())
merged_args = {**_mcp_data.tool_arguments}
for key in tool_props:
if key in input_data:
merged_args[key] = input_data[key]
_mcp_data.tool_arguments = merged_args
input_data = _mcp_data.model_dump()
data.inputs = input_data
# Execute the node
@@ -277,34 +264,8 @@ async def execute_node(
# Handle regular credentials fields
for field_name, input_type in input_model.get_credentials_fields().items():
field_value = input_data.get(field_name)
if not field_value or (
isinstance(field_value, dict) and not field_value.get("id")
):
# No credentials configured — nullify so JSON schema validation
# doesn't choke on the empty default `{}`.
input_data[field_name] = None
continue # Block runs without credentials
credentials_meta = input_type(**field_value)
# Write normalized values back so JSON schema validation also passes
# (model_validator may have fixed legacy formats like "ProviderName.MCP")
input_data[field_name] = credentials_meta.model_dump(mode="json")
try:
credentials, lock = await creds_manager.acquire(
user_id, credentials_meta.id
)
except ValueError:
# Credential was deleted or doesn't exist.
# If the field has a default, run without credentials.
if input_model.model_fields[field_name].default is not None:
log_metadata.warning(
f"Credentials #{credentials_meta.id} not found, "
"running without (field has default)"
)
input_data[field_name] = None
continue
raise
credentials_meta = input_type(**input_data[field_name])
credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
creds_locks.append(lock)
extra_exec_kwargs[field_name] = credentials

View File

@@ -260,13 +260,7 @@ async def _validate_node_input_credentials(
# Track if any credential field is missing for this node
has_missing_credentials = False
# A credential field is optional if the node metadata says so, or if
# the block schema declares a default for the field.
required_fields = block.input_schema.get_required_fields()
is_creds_optional = node.credentials_optional
for field_name, credentials_meta_type in credentials_fields.items():
field_is_optional = is_creds_optional or field_name not in required_fields
try:
# Check nodes_input_masks first, then input_default
field_value = None
@@ -279,7 +273,7 @@ async def _validate_node_input_credentials(
elif field_name in node.input_default:
# For optional credentials, don't use input_default - treat as missing
# This prevents stale credential IDs from failing validation
if field_is_optional:
if node.credentials_optional:
field_value = None
else:
field_value = node.input_default[field_name]
@@ -289,8 +283,8 @@ async def _validate_node_input_credentials(
isinstance(field_value, dict) and not field_value.get("id")
):
has_missing_credentials = True
# If credential field is optional, skip instead of error
if field_is_optional:
# If node has credentials_optional flag, mark for skipping instead of error
if node.credentials_optional:
continue # Don't add error, will be marked for skip after loop
else:
credential_errors[node.id][
@@ -340,16 +334,16 @@ async def _validate_node_input_credentials(
] = "Invalid credentials: type/provider mismatch"
continue
# If node has optional credentials and any are missing, allow running without.
# The executor will pass credentials=None to the block's run().
# If node has optional credentials and any are missing, mark for skipping
# But only if there are no other errors for this node
if (
has_missing_credentials
and is_creds_optional
and node.credentials_optional
and node.id not in credential_errors
):
nodes_to_skip.add(node.id)
logger.info(
f"Node #{node.id}: optional credentials not configured, "
"running without"
f"Node #{node.id} will be skipped: optional credentials not configured"
)
return credential_errors, nodes_to_skip

View File

@@ -495,7 +495,6 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
mock_node.block = mock_block
# Create mock graph
@@ -509,8 +508,8 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
nodes_input_masks=None,
)
# Node should NOT be in nodes_to_skip (runs without credentials) and not in errors
assert mock_node.id not in nodes_to_skip
# Node should be in nodes_to_skip, not in errors
assert mock_node.id in nodes_to_skip
assert mock_node.id not in errors
@@ -536,7 +535,6 @@ async def test_validate_node_input_credentials_required_missing_creds_error(
mock_block.input_schema.get_credentials_fields.return_value = {
"credentials": mock_credentials_field_type
}
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
mock_node.block = mock_block
# Create mock graph

View File

@@ -22,27 +22,6 @@ from backend.util.settings import Settings
settings = Settings()
def provider_matches(stored: str, expected: str) -> bool:
"""Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug.
On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"``
instead of ``"mcp"``. OAuth states persisted with the buggy format need
to match when ``expected`` is the canonical value (e.g. ``"mcp"``).
"""
if stored == expected:
return True
if stored.startswith("ProviderName."):
member = stored.removeprefix("ProviderName.")
from backend.integrations.providers import ProviderName
try:
return ProviderName[member].value == expected
except KeyError:
pass
return False
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
ollama_credentials = APIKeyCredentials(
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
@@ -410,7 +389,7 @@ class IntegrationCredentialsStore:
self, user_id: str, provider: str
) -> list[Credentials]:
credentials = await self.get_all_creds(user_id)
return [c for c in credentials if provider_matches(c.provider, provider)]
return [c for c in credentials if c.provider == provider]
async def get_authorized_providers(self, user_id: str) -> list[str]:
credentials = await self.get_all_creds(user_id)
@@ -506,6 +485,17 @@ class IntegrationCredentialsStore:
async with self.edit_user_integrations(user_id) as user_integrations:
user_integrations.oauth_states.append(state)
async with await self.locked_user_integrations(user_id):
user_integrations = await self._get_user_integrations(user_id)
oauth_states = user_integrations.oauth_states
oauth_states.append(state)
user_integrations.oauth_states = oauth_states
await self.db_manager.update_user_integrations(
user_id=user_id, data=user_integrations
)
return token, code_challenge
def _generate_code_challenge(self) -> tuple[str, str]:
@@ -531,7 +521,7 @@ class IntegrationCredentialsStore:
state
for state in oauth_states
if secrets.compare_digest(state.token, token)
and provider_matches(state.provider, provider)
and state.provider == provider
and state.expires_at > now.timestamp()
),
None,

View File

@@ -9,10 +9,7 @@ from redis.asyncio.lock import Lock as AsyncRedisLock
from backend.data.model import Credentials, OAuth2Credentials
from backend.data.redis_client import get_redis_async
from backend.integrations.credentials_store import (
IntegrationCredentialsStore,
provider_matches,
)
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
from backend.integrations.providers import ProviderName
from backend.util.exceptions import MissingConfigError
@@ -140,10 +137,7 @@ class IntegrationCredentialsManager:
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
) -> OAuth2Credentials:
async with self._locked(user_id, credentials.id, "refresh"):
if provider_matches(credentials.provider, ProviderName.MCP.value):
oauth_handler = create_mcp_oauth_handler(credentials)
else:
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
if oauth_handler.needs_refresh(credentials):
logger.debug(
f"Refreshing '{credentials.provider}' "
@@ -242,31 +236,3 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl
client_secret=client_secret,
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
)
def create_mcp_oauth_handler(
credentials: OAuth2Credentials,
) -> "BaseOAuthHandler":
"""Create an MCPOAuthHandler from credential metadata for token refresh.
MCP OAuth handlers have dynamic endpoints discovered per-server, so they
can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler
is reconstructed from metadata stored on the credential during initial auth.
"""
from backend.blocks.mcp.oauth import MCPOAuthHandler
meta = credentials.metadata or {}
token_url = meta.get("mcp_token_url", "")
if not token_url:
raise ValueError(
f"MCP credential {credentials.id} is missing 'mcp_token_url' metadata; "
"cannot refresh tokens"
)
return MCPOAuthHandler(
client_id=meta.get("mcp_client_id", ""),
client_secret=meta.get("mcp_client_secret", ""),
redirect_uri="", # Not needed for token refresh
authorize_url="", # Not needed for token refresh
token_url=token_url,
resource_url=meta.get("mcp_resource_url"),
)

View File

@@ -30,7 +30,6 @@ class ProviderName(str, Enum):
IDEOGRAM = "ideogram"
JINA = "jina"
LLAMA_API = "llama_api"
MCP = "mcp"
MEDIUM = "medium"
MEM0 = "mem0"
NOTION = "notion"

View File

@@ -51,21 +51,6 @@ async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
if (
creds_meta := new_node.input_default.get(creds_field_name)
) and not await get_credentials(creds_meta["id"]):
# If the credential field is optional (has a default in the
# schema, or node metadata marks it optional), clear the stale
# reference instead of blocking the save.
creds_field_optional = (
new_node.credentials_optional
or creds_field_name not in block_input_schema.get_required_fields()
)
if creds_field_optional:
new_node.input_default[creds_field_name] = {}
logger.warning(
f"Node #{new_node.id}: cleared stale optional "
f"credentials #{creds_meta['id']} for "
f"'{creds_field_name}'"
)
continue
raise ValueError(
f"Node #{new_node.id} input '{creds_field_name}' updated with "
f"non-existent credentials #{creds_meta['id']}"

View File

@@ -38,7 +38,6 @@ class Flag(str, Enum):
AGENT_ACTIVITY = "agent-activity"
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
CHAT = "chat"
COPILOT_SDK = "copilot-sdk"
def is_configured() -> bool:

View File

@@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver):
def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
self.ssl_hostname = ssl_hostname
self.ip_addresses = ip_addresses
self._default = aiohttp.ThreadedResolver()
self._default = aiohttp.AsyncResolver()
async def resolve(self, host, port=0, family=socket.AF_INET):
if host == self.ssl_hostname:
@@ -467,7 +467,7 @@ class Requests:
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
ssl_context = ssl.create_default_context()
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
session_kwargs: dict = {}
session_kwargs = {}
if connector:
session_kwargs["connector"] = connector

View File

@@ -897,29 +897,6 @@ files = [
{file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
]
[[package]]
name = "claude-agent-sdk"
version = "0.1.35"
description = "Python SDK for Claude Code"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"},
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"},
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"},
{file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"},
{file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"},
]
[package.dependencies]
anyio = ">=4.0.0"
mcp = ">=0.1.0"
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
[package.extras]
dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"]
[[package]]
name = "cleo"
version = "2.1.0"
@@ -2616,18 +2593,6 @@ http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "httpx-sse"
version = "0.4.3"
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"},
{file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"},
]
[[package]]
name = "huggingface-hub"
version = "1.4.1"
@@ -3345,39 +3310,6 @@ files = [
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
]
[[package]]
name = "mcp"
version = "1.26.0"
description = "Model Context Protocol SDK"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"},
{file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"},
]
[package.dependencies]
anyio = ">=4.5"
httpx = ">=0.27.1"
httpx-sse = ">=0.4"
jsonschema = ">=4.20.0"
pydantic = ">=2.11.0,<3.0.0"
pydantic-settings = ">=2.5.2"
pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
python-multipart = ">=0.0.9"
pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""}
sse-starlette = ">=1.6.1"
starlette = ">=0.27"
typing-extensions = ">=4.9.0"
typing-inspection = ">=0.4.1"
uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""}
[package.extras]
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"]
rich = ["rich (>=13.9.4)"]
ws = ["websockets (>=15.0.1)"]
[[package]]
name = "mdurl"
version = "0.1.2"
@@ -6062,7 +5994,7 @@ description = "Python for Window Extensions"
optional = false
python-versions = "*"
groups = ["main"]
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
markers = "platform_system == \"Windows\""
files = [
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
@@ -7042,28 +6974,6 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
pymysql = ["pymysql"]
sqlcipher = ["sqlcipher3_binary"]
[[package]]
name = "sse-starlette"
version = "3.2.0"
description = "SSE plugin for Starlette"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"},
{file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"},
]
[package.dependencies]
anyio = ">=4.7.0"
starlette = ">=0.49.1"
[package.extras]
daphne = ["daphne (>=4.2.0)"]
examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"]
granian = ["granian (>=2.3.1)"]
uvicorn = ["uvicorn (>=0.34.0)"]
[[package]]
name = "stagehand"
version = "0.5.9"
@@ -8530,4 +8440,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3"
content-hash = "fa9c5deadf593e815dd2190f58e22152373900603f5f244b9616cd721de84d2f"

View File

@@ -16,7 +16,6 @@ anthropic = "^0.79.0"
apscheduler = "^3.11.1"
autogpt-libs = { path = "../autogpt_libs", develop = true }
bleach = { extras = ["css"], version = "^6.2.0" }
claude-agent-sdk = "^0.1.0"
click = "^8.2.0"
cryptography = "^46.0"
discord-py = "^2.5.2"

View File

@@ -1,133 +0,0 @@
"""Tests for SDK security hooks — workspace paths, tool access, and deny messages.
These are pure unit tests with no external dependencies (no SDK, no DB, no server).
They validate that the security hooks correctly block unauthorized paths,
tool access, and dangerous input patterns.
Note: Bash command validation was removed — the SDK built-in Bash tool is not in
allowed_tools, and the bash_exec MCP tool has kernel-level network isolation
(unshare --net) making command-level parsing unnecessary.
"""
from backend.api.features.chat.sdk.security_hooks import (
_validate_tool_access,
_validate_workspace_path,
)
SDK_CWD = "/tmp/copilot-test-session"
def _is_denied(result: dict) -> bool:
hook = result.get("hookSpecificOutput", {})
return hook.get("permissionDecision") == "deny"
def _reason(result: dict) -> str:
return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "")
# ============================================================
# Workspace path validation (Read, Write, Edit, etc.)
# ============================================================
class TestWorkspacePathValidation:
def test_path_in_workspace(self):
result = _validate_workspace_path(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
)
assert not _is_denied(result)
def test_path_outside_workspace(self):
result = _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
assert _is_denied(result)
def test_tool_results_allowed(self):
result = _validate_workspace_path(
"Read",
{"file_path": "~/.claude/projects/abc/tool-results/out.txt"},
SDK_CWD,
)
assert not _is_denied(result)
def test_claude_settings_blocked(self):
result = _validate_workspace_path(
"Read", {"file_path": "~/.claude/settings.json"}, SDK_CWD
)
assert _is_denied(result)
def test_claude_projects_without_tool_results(self):
result = _validate_workspace_path(
"Read", {"file_path": "~/.claude/projects/abc/credentials.json"}, SDK_CWD
)
assert _is_denied(result)
def test_no_path_allowed(self):
"""Glob/Grep without path defaults to cwd — should be allowed."""
result = _validate_workspace_path("Grep", {"pattern": "foo"}, SDK_CWD)
assert not _is_denied(result)
def test_path_traversal_with_dotdot(self):
result = _validate_workspace_path(
"Read", {"file_path": f"{SDK_CWD}/../../../etc/passwd"}, SDK_CWD
)
assert _is_denied(result)
# ============================================================
# Tool access validation
# ============================================================
class TestToolAccessValidation:
def test_blocked_tools(self):
for tool in ("bash", "shell", "exec", "terminal", "command"):
result = _validate_tool_access(tool, {})
assert _is_denied(result), f"Tool '{tool}' should be blocked"
def test_bash_builtin_blocked(self):
"""SDK built-in Bash (capital) is blocked as defence-in-depth."""
result = _validate_tool_access("Bash", {"command": "echo hello"}, SDK_CWD)
assert _is_denied(result)
assert "Bash" in _reason(result)
def test_workspace_tools_delegate(self):
result = _validate_tool_access(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
)
assert not _is_denied(result)
def test_dangerous_pattern_blocked(self):
result = _validate_tool_access("SomeUnknownTool", {"data": "sudo rm -rf /"})
assert _is_denied(result)
def test_safe_unknown_tool_allowed(self):
result = _validate_tool_access("SomeSafeTool", {"data": "hello world"})
assert not _is_denied(result)
# ============================================================
# Deny message quality (ntindle feedback)
# ============================================================
class TestDenyMessageClarity:
"""Deny messages must include [SECURITY] and 'cannot be bypassed'
so the model knows the restriction is enforced, not a suggestion."""
def test_blocked_tool_message(self):
reason = _reason(_validate_tool_access("bash", {}))
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason
def test_bash_builtin_blocked_message(self):
reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason
def test_workspace_path_message(self):
reason = _reason(
_validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
)
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason

View File

@@ -1,255 +0,0 @@
"""Unit tests for JSONL transcript management utilities."""
import json
import os
from backend.api.features.chat.sdk.transcript import (
STRIPPABLE_TYPES,
read_transcript_file,
strip_progress_entries,
validate_transcript,
write_transcript_to_tempfile,
)
def _make_jsonl(*entries: dict) -> str:
return "\n".join(json.dumps(e) for e in entries) + "\n"
# --- Fixtures ---
METADATA_LINE = {"type": "queue-operation", "subtype": "create"}
FILE_HISTORY = {"type": "file-history-snapshot", "files": []}
USER_MSG = {"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
ASST_MSG = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {"role": "assistant", "content": "hello"},
}
PROGRESS_ENTRY = {
"type": "progress",
"uuid": "p1",
"parentUuid": "u1",
"data": {"type": "bash_progress", "stdout": "running..."},
}
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
# --- read_transcript_file ---
class TestReadTranscriptFile:
def test_returns_content_for_valid_file(self, tmp_path):
path = tmp_path / "session.jsonl"
path.write_text(VALID_TRANSCRIPT)
result = read_transcript_file(str(path))
assert result is not None
assert "user" in result
def test_returns_none_for_missing_file(self):
assert read_transcript_file("/nonexistent/path.jsonl") is None
def test_returns_none_for_empty_path(self):
assert read_transcript_file("") is None
def test_returns_none_for_empty_file(self, tmp_path):
path = tmp_path / "empty.jsonl"
path.write_text("")
assert read_transcript_file(str(path)) is None
def test_returns_none_for_metadata_only(self, tmp_path):
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
path = tmp_path / "meta.jsonl"
path.write_text(content)
assert read_transcript_file(str(path)) is None
def test_returns_none_for_invalid_json(self, tmp_path):
path = tmp_path / "bad.jsonl"
path.write_text("not json\n{}\n{}\n")
assert read_transcript_file(str(path)) is None
def test_no_size_limit(self, tmp_path):
"""Large files are accepted — bucket storage has no size limit."""
big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
path = tmp_path / "big.jsonl"
path.write_text(content)
result = read_transcript_file(str(path))
assert result is not None
# --- write_transcript_to_tempfile ---
class TestWriteTranscriptToTempfile:
"""Tests use /tmp/copilot-* paths to satisfy the sandbox prefix check."""
def test_writes_file_and_returns_path(self):
cwd = "/tmp/copilot-test-write"
try:
result = write_transcript_to_tempfile(
VALID_TRANSCRIPT, "sess-1234-abcd", cwd
)
assert result is not None
assert os.path.isfile(result)
assert result.endswith(".jsonl")
with open(result) as f:
assert f.read() == VALID_TRANSCRIPT
finally:
import shutil
shutil.rmtree(cwd, ignore_errors=True)
def test_creates_parent_directory(self):
cwd = "/tmp/copilot-test-mkdir"
try:
result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
assert result is not None
assert os.path.isdir(cwd)
finally:
import shutil
shutil.rmtree(cwd, ignore_errors=True)
def test_uses_session_id_prefix(self):
cwd = "/tmp/copilot-test-prefix"
try:
result = write_transcript_to_tempfile(
VALID_TRANSCRIPT, "abcdef12-rest", cwd
)
assert result is not None
assert "abcdef12" in os.path.basename(result)
finally:
import shutil
shutil.rmtree(cwd, ignore_errors=True)
def test_rejects_cwd_outside_sandbox(self, tmp_path):
cwd = str(tmp_path / "not-copilot")
result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
assert result is None
# --- validate_transcript ---
class TestValidateTranscript:
def test_valid_transcript(self):
assert validate_transcript(VALID_TRANSCRIPT) is True
def test_none_content(self):
assert validate_transcript(None) is False
def test_empty_content(self):
assert validate_transcript("") is False
def test_metadata_only(self):
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
assert validate_transcript(content) is False
def test_user_only_no_assistant(self):
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG)
assert validate_transcript(content) is False
def test_assistant_only_no_user(self):
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
assert validate_transcript(content) is False
def test_invalid_json_returns_false(self):
assert validate_transcript("not json\n{}\n{}\n") is False
# --- strip_progress_entries ---
class TestStripProgressEntries:
def test_strips_all_strippable_types(self):
"""All STRIPPABLE_TYPES are removed from the output."""
entries = [
USER_MSG,
{"type": "progress", "uuid": "p1", "parentUuid": "u1"},
{"type": "file-history-snapshot", "files": []},
{"type": "queue-operation", "subtype": "create"},
{"type": "summary", "text": "..."},
{"type": "pr-link", "url": "..."},
ASST_MSG,
]
result = strip_progress_entries(_make_jsonl(*entries))
result_types = {json.loads(line)["type"] for line in result.strip().split("\n")}
assert result_types == {"user", "assistant"}
for stype in STRIPPABLE_TYPES:
assert stype not in result_types
def test_reparents_children_of_stripped_entries(self):
"""An assistant message whose parent is a progress entry gets reparented."""
progress = {
"type": "progress",
"uuid": "p1",
"parentUuid": "u1",
"data": {"type": "bash_progress"},
}
asst = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "p1", # Points to progress
"message": {"role": "assistant", "content": "done"},
}
content = _make_jsonl(USER_MSG, progress, asst)
result = strip_progress_entries(content)
lines = [json.loads(line) for line in result.strip().split("\n")]
asst_entry = next(e for e in lines if e["type"] == "assistant")
# Should be reparented to u1 (the user message)
assert asst_entry["parentUuid"] == "u1"
def test_reparents_through_chain(self):
"""Reparenting walks through multiple stripped entries."""
p1 = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
p2 = {"type": "progress", "uuid": "p2", "parentUuid": "p1"}
p3 = {"type": "progress", "uuid": "p3", "parentUuid": "p2"}
asst = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "p3", # 3 levels deep
"message": {"role": "assistant", "content": "done"},
}
content = _make_jsonl(USER_MSG, p1, p2, p3, asst)
result = strip_progress_entries(content)
lines = [json.loads(line) for line in result.strip().split("\n")]
asst_entry = next(e for e in lines if e["type"] == "assistant")
assert asst_entry["parentUuid"] == "u1"
def test_preserves_non_strippable_entries(self):
"""User, assistant, and system entries are preserved."""
system = {"type": "system", "uuid": "s1", "message": "prompt"}
content = _make_jsonl(system, USER_MSG, ASST_MSG)
result = strip_progress_entries(content)
result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
assert result_types == ["system", "user", "assistant"]
def test_empty_input(self):
result = strip_progress_entries("")
# Should return just a newline (empty content stripped)
assert result.strip() == ""
def test_no_strippable_entries(self):
"""When there's nothing to strip, output matches input structure."""
content = _make_jsonl(USER_MSG, ASST_MSG)
result = strip_progress_entries(content)
result_lines = result.strip().split("\n")
assert len(result_lines) == 2
def test_handles_entries_without_uuid(self):
"""Entries without uuid field are handled gracefully."""
no_uuid = {"type": "queue-operation", "subtype": "create"}
content = _make_jsonl(no_uuid, USER_MSG, ASST_MSG)
result = strip_progress_entries(content)
result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
# queue-operation is strippable
assert "queue-operation" not in result_types
assert "user" in result_types
assert "assistant" in result_types

View File

@@ -1,96 +0,0 @@
import { NextResponse } from "next/server";
/**
* Safely encode a value as JSON for embedding in a script tag.
* Escapes characters that could break out of the script context to prevent XSS.
*/
function safeJsonStringify(value: unknown): string {
return JSON.stringify(value)
.replace(/</g, "\\u003c")
.replace(/>/g, "\\u003e")
.replace(/&/g, "\\u0026");
}
// MCP-specific OAuth callback route.
//
// Unlike the generic oauth_callback which relies on window.opener.postMessage,
// this route uses BroadcastChannel as the PRIMARY communication method.
// This is critical because cross-origin OAuth flows (e.g. Sentry → localhost)
// often lose window.opener due to COOP (Cross-Origin-Opener-Policy) headers.
//
// BroadcastChannel works across all same-origin tabs/popups regardless of opener.
export async function GET(request: Request) {
const { searchParams } = new URL(request.url);
const code = searchParams.get("code");
const state = searchParams.get("state");
const success = Boolean(code && state);
const message = success
? { success: true, code, state }
: {
success: false,
message: `Missing parameters: ${searchParams.toString()}`,
};
return new NextResponse(
`<!DOCTYPE html>
<html>
<head><title>MCP Sign-in</title></head>
<body style="font-family: system-ui, -apple-system, sans-serif; display: flex; align-items: center; justify-content: center; min-height: 100vh; margin: 0; background: #f9fafb;">
<div style="text-align: center; max-width: 400px; padding: 2rem;">
<div id="spinner" style="margin: 0 auto 1rem; width: 32px; height: 32px; border: 3px solid #e5e7eb; border-top-color: #3b82f6; border-radius: 50%; animation: spin 0.8s linear infinite;"></div>
<p id="status" style="color: #374151; font-size: 16px;">Completing sign-in...</p>
</div>
<style>@keyframes spin { to { transform: rotate(360deg); } }</style>
<script>
(function() {
var msg = ${safeJsonStringify(message)};
var sent = false;
// Method 1: BroadcastChannel (reliable across tabs/popups, no opener needed)
try {
var bc = new BroadcastChannel("mcp_oauth");
bc.postMessage({ type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message });
bc.close();
sent = true;
} catch(e) { /* BroadcastChannel not supported */ }
// Method 2: window.opener.postMessage (fallback for same-origin popups)
try {
if (window.opener && !window.opener.closed) {
window.opener.postMessage(
{ message_type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message },
window.location.origin
);
sent = true;
}
} catch(e) { /* opener not available (COOP) */ }
// Method 3: localStorage (most reliable cross-tab fallback)
try {
localStorage.setItem("mcp_oauth_result", JSON.stringify(msg));
sent = true;
} catch(e) { /* localStorage not available */ }
var statusEl = document.getElementById("status");
var spinnerEl = document.getElementById("spinner");
spinnerEl.style.display = "none";
if (msg.success && sent) {
statusEl.textContent = "Sign-in complete! This window will close.";
statusEl.style.color = "#059669";
setTimeout(function() { window.close(); }, 1500);
} else if (msg.success) {
statusEl.textContent = "Sign-in successful! You can close this tab and return to the builder.";
statusEl.style.color = "#059669";
} else {
statusEl.textContent = "Sign-in failed: " + (msg.message || "Unknown error");
statusEl.style.color = "#dc2626";
}
})();
</script>
</body>
</html>`,
{ headers: { "Content-Type": "text/html" } },
);
}

View File

@@ -47,10 +47,7 @@ export type CustomNode = XYNode<CustomNodeData, "custom">;
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
({ data, id: nodeId, selected }) => {
const { inputSchema, outputSchema, isMCPWithTool } = useCustomNode({
data,
nodeId,
});
const { inputSchema, outputSchema } = useCustomNode({ data, nodeId });
const isAgent = data.uiType === BlockUIType.AGENT;
@@ -101,7 +98,6 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
jsonSchema={preprocessInputSchema(inputSchema)}
nodeId={nodeId}
uiType={data.uiType}
isMCPWithTool={isMCPWithTool}
className={cn(
"bg-white px-4",
isWebhook && "pointer-events-none opacity-50",

View File

@@ -20,8 +20,10 @@ type Props = {
export const NodeHeader = ({ data, nodeId }: Props) => {
const updateNodeData = useNodeStore((state) => state.updateNodeData);
const title = (data.metadata?.customized_name as string) || data.title;
const title =
(data.metadata?.customized_name as string) ||
data.hardcodedValues?.agent_name ||
data.title;
const [isEditingTitle, setIsEditingTitle] = useState(false);
const [editedTitle, setEditedTitle] = useState(title);

View File

@@ -3,34 +3,6 @@ import { CustomNodeData } from "./CustomNode";
import { BlockUIType } from "../../../types";
import { useMemo } from "react";
import { mergeSchemaForResolution } from "./helpers";
/**
* Build a dynamic input schema for MCP blocks.
*
* When a tool has been selected (tool_input_schema is populated), the block
* renders the selected tool's input parameters *plus* the credentials field
* so users can select/change the OAuth credential used for execution.
*
* Static fields like server_url, selected_tool, available_tools, and
* tool_arguments are hidden because they're pre-configured from the dialog.
*/
function buildMCPInputSchema(
toolInputSchema: Record<string, any>,
blockInputSchema: Record<string, any>,
): Record<string, any> {
// Extract the credentials field from the block's original input schema
const credentialsSchema =
blockInputSchema?.properties?.credentials ?? undefined;
return {
type: "object",
properties: {
// Credentials field first so the dropdown appears at the top
...(credentialsSchema ? { credentials: credentialsSchema } : {}),
...(toolInputSchema.properties ?? {}),
},
required: [...(toolInputSchema.required ?? [])],
};
}
export const useCustomNode = ({
data,
@@ -47,18 +19,10 @@ export const useCustomNode = ({
);
const isAgent = data.uiType === BlockUIType.AGENT;
const isMCPWithTool =
data.uiType === BlockUIType.MCP_TOOL &&
!!data.hardcodedValues?.tool_input_schema?.properties;
const currentInputSchema = isAgent
? (data.hardcodedValues.input_schema ?? {})
: isMCPWithTool
? buildMCPInputSchema(
data.hardcodedValues.tool_input_schema,
data.inputSchema,
)
: data.inputSchema;
: data.inputSchema;
const currentOutputSchema = isAgent
? (data.hardcodedValues.output_schema ?? {})
: data.outputSchema;
@@ -90,6 +54,5 @@ export const useCustomNode = ({
return {
inputSchema,
outputSchema,
isMCPWithTool,
};
};

View File

@@ -9,72 +9,39 @@ interface FormCreatorProps {
jsonSchema: RJSFSchema;
nodeId: string;
uiType: BlockUIType;
/** When true the block is an MCP Tool with a selected tool. */
isMCPWithTool?: boolean;
showHandles?: boolean;
className?: string;
}
export const FormCreator: React.FC<FormCreatorProps> = React.memo(
({
jsonSchema,
nodeId,
uiType,
isMCPWithTool = false,
showHandles = true,
className,
}) => {
({ jsonSchema, nodeId, uiType, showHandles = true, className }) => {
const updateNodeData = useNodeStore((state) => state.updateNodeData);
const getHardCodedValues = useNodeStore(
(state) => state.getHardCodedValues,
);
const isAgent = uiType === BlockUIType.AGENT;
const handleChange = ({ formData }: any) => {
if ("credentials" in formData && !formData.credentials?.id) {
delete formData.credentials;
}
let updatedValues;
if (isAgent) {
updatedValues = {
...getHardCodedValues(nodeId),
inputs: formData,
};
} else if (isMCPWithTool) {
// Separate credentials from tool arguments — credentials are stored
// at the top level of hardcodedValues, not inside tool_arguments.
const { credentials, ...toolArgs } = formData;
updatedValues = {
...getHardCodedValues(nodeId),
tool_arguments: toolArgs,
...(credentials?.id ? { credentials } : {}),
};
} else {
updatedValues = formData;
}
const updatedValues =
uiType === BlockUIType.AGENT
? {
...getHardCodedValues(nodeId),
inputs: formData,
}
: formData;
updateNodeData(nodeId, { hardcodedValues: updatedValues });
};
const hardcodedValues = getHardCodedValues(nodeId);
let initialValues;
if (isAgent) {
initialValues = hardcodedValues.inputs ?? {};
} else if (isMCPWithTool) {
// Merge tool arguments with credentials for the form
initialValues = {
...(hardcodedValues.tool_arguments ?? {}),
...(hardcodedValues.credentials?.id
? { credentials: hardcodedValues.credentials }
: {}),
};
} else {
initialValues = hardcodedValues;
}
const initialValues =
uiType === BlockUIType.AGENT
? (hardcodedValues.inputs ?? {})
: hardcodedValues;
return (
<div

View File

@@ -1,558 +0,0 @@
"use client";
import React, {
useState,
useCallback,
useRef,
useEffect,
useContext,
} from "react";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
} from "@/components/__legacy__/ui/dialog";
import { Button } from "@/components/__legacy__/ui/button";
import { Input } from "@/components/__legacy__/ui/input";
import { Label } from "@/components/__legacy__/ui/label";
import { LoadingSpinner } from "@/components/__legacy__/ui/loading";
import { Badge } from "@/components/__legacy__/ui/badge";
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api";
import type { MCPToolResponse } from "@/app/api/__generated__/models/mCPToolResponse";
import {
postV2DiscoverAvailableToolsOnAnMcpServer,
postV2InitiateOauthLoginForAnMcpServer,
postV2ExchangeOauthCodeForMcpTokens,
} from "@/app/api/__generated__/endpoints/mcp/mcp";
import { CaretDown } from "@phosphor-icons/react";
import { openOAuthPopup } from "@/lib/oauth-popup";
import { CredentialsProvidersContext } from "@/providers/agent-credentials/credentials-provider";
export type MCPToolDialogResult = {
serverUrl: string;
serverName: string | null;
selectedTool: string;
toolInputSchema: Record<string, any>;
availableTools: Record<string, any>;
/** Credentials meta from OAuth flow, null for public servers. */
credentials: CredentialsMetaInput | null;
};
interface MCPToolDialogProps {
open: boolean;
onClose: () => void;
onConfirm: (result: MCPToolDialogResult) => void;
}
type DialogStep = "url" | "tool";
export function MCPToolDialog({
open,
onClose,
onConfirm,
}: MCPToolDialogProps) {
const allProviders = useContext(CredentialsProvidersContext);
const [step, setStep] = useState<DialogStep>("url");
const [serverUrl, setServerUrl] = useState("");
const [tools, setTools] = useState<MCPToolResponse[]>([]);
const [serverName, setServerName] = useState<string | null>(null);
const [loading, setLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
const [authRequired, setAuthRequired] = useState(false);
const [oauthLoading, setOauthLoading] = useState(false);
const [showManualToken, setShowManualToken] = useState(false);
const [manualToken, setManualToken] = useState("");
const [selectedTool, setSelectedTool] = useState<MCPToolResponse | null>(
null,
);
const [credentials, setCredentials] = useState<CredentialsMetaInput | null>(
null,
);
const startOAuthRef = useRef(false);
const oauthAbortRef = useRef<((reason?: string) => void) | null>(null);
// Clean up on unmount
useEffect(() => {
return () => {
oauthAbortRef.current?.();
};
}, []);
const reset = useCallback(() => {
oauthAbortRef.current?.();
oauthAbortRef.current = null;
setStep("url");
setServerUrl("");
setManualToken("");
setTools([]);
setServerName(null);
setLoading(false);
setError(null);
setAuthRequired(false);
setOauthLoading(false);
setShowManualToken(false);
setSelectedTool(null);
setCredentials(null);
}, []);
const handleClose = useCallback(() => {
reset();
onClose();
}, [reset, onClose]);
const discoverTools = useCallback(async (url: string, authToken?: string) => {
setLoading(true);
setError(null);
try {
const response = await postV2DiscoverAvailableToolsOnAnMcpServer({
server_url: url,
auth_token: authToken || null,
});
if (response.status !== 200) throw response.data;
setTools(response.data.tools);
setServerName(response.data.server_name ?? null);
setAuthRequired(false);
setShowManualToken(false);
setStep("tool");
} catch (e: any) {
if (e?.status === 401 || e?.status === 403) {
setAuthRequired(true);
setError(null);
// Automatically start OAuth sign-in instead of requiring a second click
setLoading(false);
startOAuthRef.current = true;
return;
} else {
const message =
e?.message || e?.detail || "Failed to connect to MCP server";
setError(
typeof message === "string" ? message : JSON.stringify(message),
);
}
} finally {
setLoading(false);
}
}, []);
const handleDiscoverTools = useCallback(() => {
if (!serverUrl.trim()) return;
discoverTools(serverUrl.trim(), manualToken.trim() || undefined);
}, [serverUrl, manualToken, discoverTools]);
const handleOAuthSignIn = useCallback(async () => {
if (!serverUrl.trim()) return;
setError(null);
// Abort any previous OAuth flow
oauthAbortRef.current?.();
setOauthLoading(true);
try {
const loginResponse = await postV2InitiateOauthLoginForAnMcpServer({
server_url: serverUrl.trim(),
});
if (loginResponse.status !== 200) throw loginResponse.data;
const { login_url, state_token } = loginResponse.data;
const { promise, cleanup } = openOAuthPopup(login_url, {
stateToken: state_token,
useCrossOriginListeners: true,
});
oauthAbortRef.current = cleanup.abort;
const result = await promise;
// Exchange code for tokens via the credentials provider (updates cache)
setLoading(true);
setOauthLoading(false);
const mcpProvider = allProviders?.["mcp"];
let callbackResult;
if (mcpProvider) {
callbackResult = await mcpProvider.mcpOAuthCallback(
result.code,
state_token,
);
} else {
const cbResponse = await postV2ExchangeOauthCodeForMcpTokens({
code: result.code,
state_token,
});
if (cbResponse.status !== 200) throw cbResponse.data;
callbackResult = cbResponse.data;
}
setCredentials({
id: callbackResult.id,
provider: callbackResult.provider,
type: callbackResult.type,
title: callbackResult.title,
});
setAuthRequired(false);
// Discover tools now that we're authenticated
const toolsResponse = await postV2DiscoverAvailableToolsOnAnMcpServer({
server_url: serverUrl.trim(),
});
if (toolsResponse.status !== 200) throw toolsResponse.data;
setTools(toolsResponse.data.tools);
setServerName(toolsResponse.data.server_name ?? null);
setStep("tool");
} catch (e: any) {
// If server doesn't support OAuth → show manual token entry
if (e?.status === 400) {
setShowManualToken(true);
setError(
"This server does not support OAuth sign-in. Please enter a token manually.",
);
} else if (e?.message === "OAuth flow timed out") {
setError("OAuth sign-in timed out. Please try again.");
} else {
const status = e?.status;
let message: string;
if (status === 401 || status === 403) {
message =
"Authentication succeeded but the server still rejected the request. " +
"The token audience may not match. Please try again.";
} else {
message = e?.message || e?.detail || "Failed to complete sign-in";
}
setError(
typeof message === "string" ? message : JSON.stringify(message),
);
}
} finally {
setOauthLoading(false);
setLoading(false);
oauthAbortRef.current = null;
}
}, [serverUrl, allProviders]);
// Auto-start OAuth sign-in when server returns 401/403
useEffect(() => {
if (authRequired && startOAuthRef.current) {
startOAuthRef.current = false;
handleOAuthSignIn();
}
}, [authRequired, handleOAuthSignIn]);
const handleConfirm = useCallback(() => {
if (!selectedTool) return;
const availableTools: Record<string, any> = {};
for (const t of tools) {
availableTools[t.name] = {
description: t.description,
input_schema: t.input_schema,
};
}
onConfirm({
serverUrl: serverUrl.trim(),
serverName,
selectedTool: selectedTool.name,
toolInputSchema: selectedTool.input_schema,
availableTools,
credentials,
});
reset();
}, [
selectedTool,
tools,
serverUrl,
serverName,
credentials,
onConfirm,
reset,
]);
return (
<Dialog open={open} onOpenChange={(isOpen) => !isOpen && handleClose()}>
<DialogContent className="max-w-lg">
<DialogHeader>
<DialogTitle>
{step === "url"
? "Connect to MCP Server"
: `Select a Tool${serverName ? `${serverName}` : ""}`}
</DialogTitle>
<DialogDescription>
{step === "url"
? "Enter the URL of an MCP server to discover its available tools."
: `Found ${tools.length} tool${tools.length !== 1 ? "s" : ""}. Select one to add to your agent.`}
</DialogDescription>
</DialogHeader>
{step === "url" && (
<div className="flex flex-col gap-4 py-2">
<div className="flex flex-col gap-2">
<Label htmlFor="mcp-server-url">Server URL</Label>
<Input
id="mcp-server-url"
type="url"
placeholder="https://mcp.example.com/mcp"
value={serverUrl}
onChange={(e) => setServerUrl(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()}
autoFocus
/>
</div>
{/* Auth required: show manual token option */}
{authRequired && !showManualToken && (
<button
onClick={() => setShowManualToken(true)}
className="text-xs text-gray-500 underline hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-300"
>
or enter a token manually
</button>
)}
{/* Manual token entry — only visible when expanded */}
{showManualToken && (
<div className="flex flex-col gap-2">
<Label htmlFor="mcp-auth-token" className="text-sm">
Bearer Token
</Label>
<Input
id="mcp-auth-token"
type="password"
placeholder="Paste your auth token here"
value={manualToken}
onChange={(e) => setManualToken(e.target.value)}
onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()}
autoFocus
/>
</div>
)}
{error && <p className="text-sm text-red-500">{error}</p>}
</div>
)}
{step === "tool" && (
<ScrollArea className="max-h-[50vh] py-2">
<div className="flex flex-col gap-2 pr-3">
{tools.map((tool) => (
<MCPToolCard
key={tool.name}
tool={tool}
selected={selectedTool?.name === tool.name}
onSelect={() => setSelectedTool(tool)}
/>
))}
</div>
</ScrollArea>
)}
<DialogFooter>
{step === "tool" && (
<Button
variant="outline"
onClick={() => {
setStep("url");
setSelectedTool(null);
}}
>
Back
</Button>
)}
<Button variant="outline" onClick={handleClose}>
Cancel
</Button>
{step === "url" && (
<Button
onClick={
authRequired && !showManualToken
? handleOAuthSignIn
: handleDiscoverTools
}
disabled={!serverUrl.trim() || loading || oauthLoading}
>
{loading || oauthLoading ? (
<span className="flex items-center gap-2">
<LoadingSpinner className="size-4" />
{oauthLoading ? "Waiting for sign-in..." : "Connecting..."}
</span>
) : authRequired && !showManualToken ? (
"Sign in & Connect"
) : (
"Discover Tools"
)}
</Button>
)}
{step === "tool" && (
<Button onClick={handleConfirm} disabled={!selectedTool}>
Add Block
</Button>
)}
</DialogFooter>
</DialogContent>
</Dialog>
);
}
// --------------- Tool Card Component --------------- //
/** Truncate a description to a reasonable length for the collapsed view. */
function truncateDescription(text: string, maxLen = 120): string {
if (text.length <= maxLen) return text;
return text.slice(0, maxLen).trimEnd() + "…";
}
/** Pretty-print a JSON Schema type for a parameter. */
function schemaTypeLabel(schema: Record<string, any>): string {
if (schema.type) return schema.type;
if (schema.anyOf)
return schema.anyOf.map((s: any) => s.type ?? "any").join(" | ");
if (schema.oneOf)
return schema.oneOf.map((s: any) => s.type ?? "any").join(" | ");
return "any";
}
function MCPToolCard({
tool,
selected,
onSelect,
}: {
tool: MCPToolResponse;
selected: boolean;
onSelect: () => void;
}) {
const [expanded, setExpanded] = useState(false);
const schema = tool.input_schema as Record<string, any>;
const properties = schema?.properties ?? {};
const required = new Set<string>(schema?.required ?? []);
const paramNames = Object.keys(properties);
// Strip XML-like tags from description for cleaner display.
// Loop to handle nested tags like <scr<script>ipt> (CodeQL fix).
let cleanDescription = tool.description ?? "";
let prev = "";
while (prev !== cleanDescription) {
prev = cleanDescription;
cleanDescription = cleanDescription.replace(/<[^>]*>/g, "");
}
cleanDescription = cleanDescription.trim();
return (
<button
onClick={onSelect}
className={`group flex flex-col rounded-lg border text-left transition-colors ${
selected
? "border-blue-500 bg-blue-50 dark:border-blue-400 dark:bg-blue-950"
: "border-gray-200 hover:border-gray-300 hover:bg-gray-50 dark:border-slate-700 dark:hover:border-slate-600 dark:hover:bg-slate-800"
}`}
>
{/* Header */}
<div className="flex items-center gap-2 px-3 pb-1 pt-3">
<span className="flex-1 text-sm font-semibold dark:text-white">
{tool.name}
</span>
{paramNames.length > 0 && (
<Badge variant="secondary" className="text-[10px]">
{paramNames.length} param{paramNames.length !== 1 ? "s" : ""}
</Badge>
)}
</div>
{/* Description (collapsed: truncated) */}
{cleanDescription && (
<p className="px-3 pb-1 text-xs leading-relaxed text-gray-500 dark:text-gray-400">
{expanded ? cleanDescription : truncateDescription(cleanDescription)}
</p>
)}
{/* Parameter badges (collapsed view) */}
{!expanded && paramNames.length > 0 && (
<div className="flex flex-wrap gap-1 px-3 pb-2">
{paramNames.slice(0, 6).map((name) => (
<Badge
key={name}
variant="outline"
className="text-[10px] font-normal"
>
{name}
{required.has(name) && (
<span className="ml-0.5 text-red-400">*</span>
)}
</Badge>
))}
{paramNames.length > 6 && (
<Badge variant="outline" className="text-[10px] font-normal">
+{paramNames.length - 6} more
</Badge>
)}
</div>
)}
{/* Expanded: full parameter details */}
{expanded && paramNames.length > 0 && (
<div className="mx-3 mb-2 rounded border border-gray-100 bg-gray-50/50 dark:border-slate-700 dark:bg-slate-800/50">
<table className="w-full text-xs">
<thead>
<tr className="border-b border-gray-100 dark:border-slate-700">
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
Parameter
</th>
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
Type
</th>
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
Description
</th>
</tr>
</thead>
<tbody>
{paramNames.map((name) => {
const prop = properties[name] ?? {};
return (
<tr
key={name}
className="border-b border-gray-50 last:border-0 dark:border-slate-700/50"
>
<td className="px-2 py-1 font-mono text-[11px] text-gray-700 dark:text-gray-300">
{name}
{required.has(name) && (
<span className="ml-0.5 text-red-400">*</span>
)}
</td>
<td className="px-2 py-1 text-gray-500 dark:text-gray-400">
{schemaTypeLabel(prop)}
</td>
<td className="max-w-[200px] truncate px-2 py-1 text-gray-500 dark:text-gray-400">
{prop.description ?? "—"}
</td>
</tr>
);
})}
</tbody>
</table>
</div>
)}
{/* Toggle details */}
{(paramNames.length > 0 || cleanDescription.length > 120) && (
<button
type="button"
onClick={(e) => {
e.stopPropagation();
setExpanded((prev) => !prev);
}}
className="flex w-full items-center justify-center gap-1 border-t border-gray-100 py-1.5 text-[10px] text-gray-400 hover:text-gray-600 dark:border-slate-700 dark:text-gray-500 dark:hover:text-gray-300"
>
{expanded ? "Hide details" : "Show details"}
<CaretDown
className={`h-3 w-3 transition-transform ${expanded ? "rotate-180" : ""}`}
/>
</button>
)}
</button>
);
}

View File

@@ -1,7 +1,7 @@
import { Button } from "@/components/__legacy__/ui/button";
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
import { beautifyString, cn } from "@/lib/utils";
import React, { ButtonHTMLAttributes, useCallback, useState } from "react";
import React, { ButtonHTMLAttributes } from "react";
import { highlightText } from "./helpers";
import { PlusIcon } from "@phosphor-icons/react";
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
@@ -9,12 +9,6 @@ import { useControlPanelStore } from "../../../stores/controlPanelStore";
import { blockDragPreviewStyle } from "./style";
import { useReactFlow } from "@xyflow/react";
import { useNodeStore } from "../../../stores/nodeStore";
import { BlockUIType, SpecialBlockID } from "@/lib/autogpt-server-api";
import {
MCPToolDialog,
type MCPToolDialogResult,
} from "@/app/(platform)/build/components/MCPToolDialog";
interface Props extends ButtonHTMLAttributes<HTMLButtonElement> {
title?: string;
description?: string;
@@ -39,86 +33,22 @@ export const Block: BlockComponent = ({
);
const { setViewport } = useReactFlow();
const { addBlock } = useNodeStore();
const [mcpDialogOpen, setMcpDialogOpen] = useState(false);
const isMCPBlock = blockData.uiType === BlockUIType.MCP_TOOL;
const addBlockAndCenter = useCallback(
(block: BlockInfo, hardcodedValues?: Record<string, any>) => {
const customNode = addBlock(block, hardcodedValues);
setTimeout(() => {
setViewport(
{
x: -customNode.position.x * 0.8 + window.innerWidth / 2,
y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2,
zoom: 0.8,
},
{ duration: 500 },
);
}, 50);
return customNode;
},
[addBlock, setViewport],
);
const updateNodeData = useNodeStore((state) => state.updateNodeData);
const handleMCPToolConfirm = useCallback(
(result: MCPToolDialogResult) => {
// Derive a display label: prefer server name, fall back to URL hostname.
let serverLabel = result.serverName;
if (!serverLabel) {
try {
serverLabel = new URL(result.serverUrl).hostname;
} catch {
serverLabel = "MCP";
}
}
const customNode = addBlockAndCenter(blockData, {
server_url: result.serverUrl,
server_name: serverLabel,
selected_tool: result.selectedTool,
tool_input_schema: result.toolInputSchema,
available_tools: result.availableTools,
credentials: result.credentials ?? undefined,
});
if (customNode) {
const title = result.selectedTool
? `${serverLabel}: ${beautifyString(result.selectedTool)}`
: undefined;
updateNodeData(customNode.id, {
metadata: {
...customNode.data.metadata,
credentials_optional: true,
...(title && { customized_name: title }),
},
});
}
setMcpDialogOpen(false);
},
[addBlockAndCenter, blockData, updateNodeData],
);
const handleClick = () => {
if (isMCPBlock) {
setMcpDialogOpen(true);
return;
}
const customNode = addBlockAndCenter(blockData);
// Set customized_name for agent blocks so the agent's name persists
if (customNode && blockData.id === SpecialBlockID.AGENT) {
updateNodeData(customNode.id, {
metadata: {
...customNode.data.metadata,
customized_name: blockData.name,
const customNode = addBlock(blockData);
setTimeout(() => {
setViewport(
{
x: -customNode.position.x * 0.8 + window.innerWidth / 2,
y: -customNode.position.y * 0.8 + (window.innerHeight - 400) / 2,
zoom: 0.8,
},
});
}
{ duration: 500 },
);
}, 50);
};
const handleDragStart = (e: React.DragEvent<HTMLButtonElement>) => {
if (isMCPBlock) return;
e.dataTransfer.effectAllowed = "copy";
e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData));
@@ -141,56 +71,46 @@ export const Block: BlockComponent = ({
: undefined;
return (
<>
<Button
draggable={!isMCPBlock}
data-id={blockDataId}
className={cn(
"group flex h-16 w-full min-w-[7.5rem] items-center justify-start space-x-3 whitespace-normal rounded-[0.75rem] bg-zinc-50 px-[0.875rem] py-[0.625rem] text-start shadow-none",
"hover:cursor-default hover:bg-zinc-100 focus:ring-0 active:bg-zinc-100 active:ring-1 active:ring-zinc-300 disabled:cursor-not-allowed",
isMCPBlock && "hover:cursor-pointer",
className,
)}
onDragStart={handleDragStart}
onClick={handleClick}
{...rest}
>
<div className="flex flex-1 flex-col items-start gap-0.5">
{title && (
<span
className={cn(
"line-clamp-1 font-sans text-sm font-medium leading-[1.375rem] text-zinc-800 group-disabled:text-zinc-400",
)}
>
{highlightText(beautifyString(title), highlightedText)}
</span>
)}
{description && (
<span
className={cn(
"line-clamp-1 font-sans text-xs font-normal leading-5 text-zinc-500 group-disabled:text-zinc-400",
)}
>
{highlightText(description, highlightedText)}
</span>
)}
</div>
<div
className={cn(
"flex h-7 w-7 items-center justify-center rounded-[0.5rem] bg-zinc-700 group-disabled:bg-zinc-400",
)}
>
<PlusIcon className="h-5 w-5 text-zinc-50" />
</div>
</Button>
{isMCPBlock && (
<MCPToolDialog
open={mcpDialogOpen}
onClose={() => setMcpDialogOpen(false)}
onConfirm={handleMCPToolConfirm}
/>
<Button
draggable={true}
data-id={blockDataId}
className={cn(
"group flex h-16 w-full min-w-[7.5rem] items-center justify-start space-x-3 whitespace-normal rounded-[0.75rem] bg-zinc-50 px-[0.875rem] py-[0.625rem] text-start shadow-none",
"hover:cursor-default hover:bg-zinc-100 focus:ring-0 active:bg-zinc-100 active:ring-1 active:ring-zinc-300 disabled:cursor-not-allowed",
className,
)}
</>
onDragStart={handleDragStart}
onClick={handleClick}
{...rest}
>
<div className="flex flex-1 flex-col items-start gap-0.5">
{title && (
<span
className={cn(
"line-clamp-1 font-sans text-sm font-medium leading-[1.375rem] text-zinc-800 group-disabled:text-zinc-400",
)}
>
{highlightText(beautifyString(title), highlightedText)}
</span>
)}
{description && (
<span
className={cn(
"line-clamp-1 font-sans text-xs font-normal leading-5 text-zinc-500 group-disabled:text-zinc-400",
)}
>
{highlightText(description, highlightedText)}
</span>
)}
</div>
<div
className={cn(
"flex h-7 w-7 items-center justify-center rounded-[0.5rem] bg-zinc-700 group-disabled:bg-zinc-400",
)}
>
<PlusIcon className="h-5 w-5 text-zinc-50" />
</div>
</Button>
);
};

View File

@@ -9,5 +9,4 @@ export enum BlockUIType {
AGENT = "Agent",
AI = "AI",
AYRSHARE = "Ayrshare",
MCP_TOOL = "MCP Tool",
}

View File

@@ -24,7 +24,6 @@ import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
import { GenericTool } from "../../tools/GenericTool/GenericTool";
import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
// ---------------------------------------------------------------------------
@@ -274,16 +273,6 @@ export const ChatMessagesContainer = ({
/>
);
default:
// Render a generic tool indicator for SDK built-in
// tools (Read, Glob, Grep, etc.) or any unrecognized tool
if (part.type.startsWith("tool-")) {
return (
<GenericTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
}
return null;
}
})}

View File

@@ -4,7 +4,6 @@ import { Button } from "@/components/atoms/Button/Button";
import { Text } from "@/components/atoms/Text/Text";
import {
BookOpenIcon,
CheckFatIcon,
PencilSimpleIcon,
WarningDiamondIcon,
} from "@phosphor-icons/react";
@@ -24,6 +23,7 @@ import {
ClarificationQuestionsCard,
ClarifyingQuestion,
} from "./components/ClarificationQuestionsCard";
import sparklesImg from "./components/MiniGame/assets/sparkles.png";
import { MiniGame } from "./components/MiniGame/MiniGame";
import {
AccordionIcon,
@@ -83,7 +83,8 @@ function getAccordionMeta(output: CreateAgentToolOutput) {
) {
return {
icon,
title: "Creating agent, this may take a few minutes. Sit back and relax.",
title:
"Creating agent, this may take a few minutes. Play while you wait.",
expanded: true,
};
}
@@ -167,16 +168,22 @@ export function CreateAgentTool({ part }: Props) {
{isAgentSavedOutput(output) && (
<div className="rounded-xl border border-border/60 bg-card p-4 shadow-sm">
<div className="flex items-baseline gap-2">
<CheckFatIcon
size={18}
weight="regular"
className="relative top-1 text-green-500"
<img
src={sparklesImg.src}
alt="sparkles"
width={24}
height={24}
className="relative top-1"
/>
<Text
variant="body-medium"
className="text-blacks mb-2 text-[16px]"
>
{output.message}
Agent{" "}
<span className="text-[rgb(124,58,237)]">
{output.agent_name}
</span>{" "}
has been saved to your library!
</Text>
</div>
<div className="mt-3 flex flex-wrap gap-4">

View File

@@ -2,20 +2,78 @@
import { useMiniGame } from "./useMiniGame";
function Key({ children }: { children: React.ReactNode }) {
return <strong>[{children}]</strong>;
}
export function MiniGame() {
const { canvasRef } = useMiniGame();
const {
canvasRef,
activeMode,
showOverlay,
score,
highScore,
onContinue,
} = useMiniGame();
const isRunActive =
activeMode === "run" || activeMode === "idle" || activeMode === "over";
const isBossActive =
activeMode === "boss" ||
activeMode === "boss-intro" ||
activeMode === "boss-defeated";
let overlayText: string | undefined;
let buttonLabel = "Continue";
if (activeMode === "idle") {
buttonLabel = "Start";
} else if (activeMode === "boss-intro") {
overlayText = "Face the bandit!";
} else if (activeMode === "boss-defeated") {
overlayText = "Great job, keep on going";
} else if (activeMode === "over") {
overlayText = `Score: ${score} / Record: ${highScore}`;
buttonLabel = "Retry";
}
return (
<div
className="w-full overflow-hidden rounded-md bg-background text-foreground"
style={{ border: "1px solid #d17fff" }}
>
<canvas
ref={canvasRef}
tabIndex={0}
className="block w-full outline-none"
style={{ imageRendering: "pixelated" }}
/>
<div className="flex flex-col gap-2">
<p className="text-sm font-medium text-purple-500">
{isBossActive ? (
<>
Duel mode: <Key></Key> to move · <Key>Z</Key> to attack ·{" "}
<Key>X</Key> to block · <Key>Space</Key> to jump
</>
) : (
<>
Run mode: <Key>Space</Key> to jump
</>
)}
</p>
<div
className="relative w-full overflow-hidden rounded-md bg-background text-foreground"
style={{ border: "1px solid #d17fff" }}
>
<canvas
ref={canvasRef}
tabIndex={0}
className="block w-full outline-none"
/>
{showOverlay && (
<div className="absolute inset-0 flex flex-col items-center justify-center gap-3 bg-black/40">
{overlayText && (
<p className="text-lg font-bold text-white">{overlayText}</p>
)}
<button
type="button"
onClick={onContinue}
className="rounded-md bg-white px-4 py-2 text-sm font-semibold text-zinc-800 shadow-md transition-colors hover:bg-zinc-100"
>
{buttonLabel}
</button>
</div>
)}
</div>
</div>
);
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

View File

@@ -136,7 +136,7 @@ export function getAnimationText(part: {
if (isOperationPendingOutput(output)) return "Agent creation in progress";
if (isOperationInProgressOutput(output))
return "Agent creation already in progress";
if (isAgentSavedOutput(output)) return `Saved "${output.agent_name}"`;
if (isAgentSavedOutput(output)) return `Saved ${output.agent_name}`;
if (isAgentPreviewOutput(output)) return `Preview "${output.agent_name}"`;
if (isClarificationNeededOutput(output)) return "Needs clarification";
return "Error creating agent";

View File

@@ -5,7 +5,6 @@ import type { ToolUIPart } from "ai";
import { useCopilotChatActions } from "../../components/CopilotChatActionsProvider/useCopilotChatActions";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
import { ProgressBar } from "../../components/ProgressBar/ProgressBar";
import {
ContentCardDescription,
ContentCodeBlock,
@@ -15,7 +14,7 @@ import {
ContentMessage,
} from "../../components/ToolAccordion/AccordionContent";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import { useAsymptoticProgress } from "../../hooks/useAsymptoticProgress";
import { MiniGame } from "../CreateAgent/components/MiniGame/MiniGame";
import {
ClarificationQuestionsCard,
ClarifyingQuestion,
@@ -80,7 +79,12 @@ function getAccordionMeta(output: EditAgentToolOutput): {
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output)
) {
return { icon: <OrbitLoader size={32} />, title: "Editing agent" };
return {
icon: <OrbitLoader size={32} />,
title:
"Editing agent, this may take a few minutes. Play while you wait.",
expanded: true,
};
}
return {
icon: (
@@ -105,7 +109,6 @@ export function EditAgentTool({ part }: Props) {
(isOperationStartedOutput(output) ||
isOperationPendingOutput(output) ||
isOperationInProgressOutput(output));
const progress = useAsymptoticProgress(isOperating);
const hasExpandableContent =
part.state === "output-available" &&
!!output &&
@@ -149,9 +152,9 @@ export function EditAgentTool({ part }: Props) {
<ToolAccordion {...getAccordionMeta(output)}>
{isOperating && (
<ContentGrid>
<ProgressBar value={progress} className="max-w-[280px]" />
<MiniGame />
<ContentHint>
This could take a few minutes, grab a coffee
This could take a few minutes play while you wait!
</ContentHint>
</ContentGrid>
)}

View File

@@ -1,63 +0,0 @@
"use client";
import { ToolUIPart } from "ai";
import { GearIcon } from "@phosphor-icons/react";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
interface Props {
part: ToolUIPart;
}
function extractToolName(part: ToolUIPart): string {
// ToolUIPart.type is "tool-{name}", extract the name portion.
return part.type.replace(/^tool-/, "");
}
function formatToolName(name: string): string {
// "search_docs" → "Search docs", "Read" → "Read"
return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
}
function getAnimationText(part: ToolUIPart): string {
const label = formatToolName(extractToolName(part));
switch (part.state) {
case "input-streaming":
case "input-available":
return `Running ${label}`;
case "output-available":
return `${label} completed`;
case "output-error":
return `${label} failed`;
default:
return `Running ${label}`;
}
}
export function GenericTool({ part }: Props) {
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const isError = part.state === "output-error";
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<GearIcon
size={14}
weight="regular"
className={
isError
? "text-red-500"
: isStreaming
? "animate-spin text-neutral-500"
: "text-neutral-400"
}
/>
<MorphingTextAnimation
text={getAnimationText(part)}
className={isError ? "text-red-500" : undefined}
/>
</div>
</div>
);
}

View File

@@ -2,8 +2,14 @@
import type { ToolUIPart } from "ai";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import { OrbitLoader } from "../../components/OrbitLoader/OrbitLoader";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import { ContentMessage } from "../../components/ToolAccordion/AccordionContent";
import {
ContentGrid,
ContentHint,
ContentMessage,
} from "../../components/ToolAccordion/AccordionContent";
import { MiniGame } from "../CreateAgent/components/MiniGame/MiniGame";
import {
getAccordionMeta,
getAnimationText,
@@ -60,6 +66,21 @@ export function RunAgentTool({ part }: Props) {
/>
</div>
{isStreaming && !output && (
<ToolAccordion
icon={<OrbitLoader size={32} />}
title="Running agent, this may take a few minutes. Play while you wait."
expanded={true}
>
<ContentGrid>
<MiniGame />
<ContentHint>
This could take a few minutes play while you wait!
</ContentHint>
</ContentGrid>
</ToolAccordion>
)}
{hasExpandableContent && output && (
<ToolAccordion {...getAccordionMeta(output)}>
{isRunAgentExecutionStartedOutput(output) && (

View File

@@ -4269,128 +4269,6 @@
}
}
},
"/api/mcp/discover-tools": {
"post": {
"tags": ["v2", "mcp", "mcp"],
"summary": "Discover available tools on an MCP server",
"description": "Connect to an MCP server and return its available tools.\n\nIf the user has a stored MCP credential for this server URL, it will be\nused automatically — no need to pass an explicit auth token.",
"operationId": "postV2Discover available tools on an mcp server",
"requestBody": {
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/DiscoverToolsRequest" }
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/DiscoverToolsResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/mcp/oauth/callback": {
"post": {
"tags": ["v2", "mcp", "mcp"],
"summary": "Exchange OAuth code for MCP tokens",
"description": "Exchange the authorization code for tokens and store the credential.\n\nThe frontend calls this after receiving the OAuth code from the popup.\nOn success, subsequent ``/discover-tools`` calls for the same server URL\nwill automatically use the stored credential.",
"operationId": "postV2Exchange oauth code for mcp tokens",
"requestBody": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/MCPOAuthCallbackRequest"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/CredentialsMetaResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/mcp/oauth/login": {
"post": {
"tags": ["v2", "mcp", "mcp"],
"summary": "Initiate OAuth login for an MCP server",
"description": "Discover OAuth metadata from the MCP server and return a login URL.\n\n1. Discovers the protected-resource metadata (RFC 9728)\n2. Fetches the authorization server metadata (RFC 8414)\n3. Performs Dynamic Client Registration (RFC 7591) if available\n4. Returns the authorization URL for the frontend to open in a popup",
"operationId": "postV2Initiate oauth login for an mcp server",
"requestBody": {
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/MCPOAuthLoginRequest" }
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/MCPOAuthLoginResponse"
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
},
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/oauth/app/{client_id}": {
"get": {
"tags": ["oauth"],
@@ -7188,57 +7066,13 @@
"properties": {
"id": { "type": "string", "title": "Id" },
"name": { "type": "string", "title": "Name" },
"description": { "type": "string", "title": "Description" },
"categories": {
"items": { "type": "string" },
"type": "array",
"title": "Categories"
},
"input_schema": {
"additionalProperties": true,
"type": "object",
"title": "Input Schema",
"description": "Full JSON schema for block inputs"
},
"output_schema": {
"additionalProperties": true,
"type": "object",
"title": "Output Schema",
"description": "Full JSON schema for block outputs"
},
"required_inputs": {
"items": { "$ref": "#/components/schemas/BlockInputFieldInfo" },
"type": "array",
"title": "Required Inputs",
"description": "List of input fields for this block"
}
"description": { "type": "string", "title": "Description" }
},
"type": "object",
"required": ["id", "name", "description", "categories"],
"required": ["id", "name", "description"],
"title": "BlockInfoSummary",
"description": "Summary of a block for search results."
},
"BlockInputFieldInfo": {
"properties": {
"name": { "type": "string", "title": "Name" },
"type": { "type": "string", "title": "Type" },
"description": {
"type": "string",
"title": "Description",
"default": ""
},
"required": {
"type": "boolean",
"title": "Required",
"default": false
},
"default": { "anyOf": [{}, { "type": "null" }], "title": "Default" }
},
"type": "object",
"required": ["name", "type"],
"title": "BlockInputFieldInfo",
"description": "Information about a block input field."
},
"BlockListResponse": {
"properties": {
"type": {
@@ -7256,12 +7090,7 @@
"title": "Blocks"
},
"count": { "type": "integer", "title": "Count" },
"query": { "type": "string", "title": "Query" },
"usage_hint": {
"type": "string",
"title": "Usage Hint",
"default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the fields listed in required_inputs."
}
"query": { "type": "string", "title": "Query" }
},
"type": "object",
"required": ["message", "blocks", "count", "query"],
@@ -7813,7 +7642,7 @@
"host": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Host",
"description": "Host pattern for host-scoped or MCP server URL for MCP credentials"
"description": "Host pattern for host-scoped credentials"
}
},
"type": "object",
@@ -7833,45 +7662,6 @@
"required": ["version_counts"],
"title": "DeleteGraphResponse"
},
"DiscoverToolsRequest": {
"properties": {
"server_url": {
"type": "string",
"title": "Server Url",
"description": "URL of the MCP server"
},
"auth_token": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Auth Token",
"description": "Optional Bearer token for authenticated MCP servers"
}
},
"type": "object",
"required": ["server_url"],
"title": "DiscoverToolsRequest",
"description": "Request to discover tools on an MCP server."
},
"DiscoverToolsResponse": {
"properties": {
"tools": {
"items": { "$ref": "#/components/schemas/MCPToolResponse" },
"type": "array",
"title": "Tools"
},
"server_name": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Server Name"
},
"protocol_version": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Protocol Version"
}
},
"type": "object",
"required": ["tools"],
"title": "DiscoverToolsResponse",
"description": "Response containing the list of tools available on an MCP server."
},
"DocPageResponse": {
"properties": {
"type": {
@@ -9448,62 +9238,6 @@
"required": ["login_url", "state_token"],
"title": "LoginResponse"
},
"MCPOAuthCallbackRequest": {
"properties": {
"code": {
"type": "string",
"title": "Code",
"description": "Authorization code from OAuth callback"
},
"state_token": {
"type": "string",
"title": "State Token",
"description": "State token for CSRF verification"
}
},
"type": "object",
"required": ["code", "state_token"],
"title": "MCPOAuthCallbackRequest",
"description": "Request to exchange an OAuth code for tokens."
},
"MCPOAuthLoginRequest": {
"properties": {
"server_url": {
"type": "string",
"title": "Server Url",
"description": "URL of the MCP server that requires OAuth"
}
},
"type": "object",
"required": ["server_url"],
"title": "MCPOAuthLoginRequest",
"description": "Request to start an OAuth flow for an MCP server."
},
"MCPOAuthLoginResponse": {
"properties": {
"login_url": { "type": "string", "title": "Login Url" },
"state_token": { "type": "string", "title": "State Token" }
},
"type": "object",
"required": ["login_url", "state_token"],
"title": "MCPOAuthLoginResponse",
"description": "Response with the OAuth login URL for the user to authenticate."
},
"MCPToolResponse": {
"properties": {
"name": { "type": "string", "title": "Name" },
"description": { "type": "string", "title": "Description" },
"input_schema": {
"additionalProperties": true,
"type": "object",
"title": "Input Schema"
}
},
"type": "object",
"required": ["name", "description", "input_schema"],
"title": "MCPToolResponse",
"description": "A single MCP tool returned by discovery."
},
"MarketplaceListing": {
"properties": {
"id": { "type": "string", "title": "Id" },
@@ -10762,9 +10496,6 @@
"operation_pending",
"operation_in_progress",
"input_validation_error",
"web_fetch",
"bash_exec",
"operation_status",
"feature_request_search",
"feature_request_created"
],

View File

@@ -38,8 +38,13 @@ export function CredentialsGroupedView({
const allProviders = useContext(CredentialsProvidersContext);
const { userCredentialFields, systemCredentialFields } = useMemo(
() => splitCredentialFieldsBySystem(credentialFields, allProviders),
[credentialFields, allProviders],
() =>
splitCredentialFieldsBySystem(
credentialFields,
allProviders,
inputCredentials,
),
[credentialFields, allProviders, inputCredentials],
);
const hasSystemCredentials = systemCredentialFields.length > 0;
@@ -81,13 +86,11 @@ export function CredentialsGroupedView({
const providerNames = schema.credentials_provider || [];
const credentialTypes = schema.credentials_types || [];
const requiredScopes = schema.credentials_scopes;
const discriminatorValues = schema.discriminator_values;
const savedCredential = findSavedCredentialByProviderAndType(
providerNames,
credentialTypes,
requiredScopes,
allProviders,
discriminatorValues,
);
if (savedCredential) {

View File

@@ -23,35 +23,10 @@ function hasRequiredScopes(
return true;
}
/** Check if a credential matches the discriminator values (e.g. MCP server URL). */
function matchesDiscriminatorValues(
credential: { host?: string | null; provider: string; type: string },
discriminatorValues?: string[],
) {
// MCP OAuth2 credentials must match by server URL
if (credential.type === "oauth2" && credential.provider === "mcp") {
if (!discriminatorValues || discriminatorValues.length === 0) return false;
return (
credential.host != null && discriminatorValues.includes(credential.host)
);
}
// Host-scoped credentials match by host
if (credential.type === "host_scoped" && credential.host) {
if (!discriminatorValues || discriminatorValues.length === 0) return true;
return discriminatorValues.some((v) => {
try {
return new URL(v).hostname === credential.host;
} catch {
return false;
}
});
}
return true;
}
export function splitCredentialFieldsBySystem(
credentialFields: CredentialField[],
allProviders: CredentialsProvidersContextType | null,
inputCredentials?: Record<string, unknown>,
) {
if (!allProviders || credentialFields.length === 0) {
return {
@@ -77,9 +52,17 @@ export function splitCredentialFieldsBySystem(
}
}
const sortByUnsetFirst = (a: CredentialField, b: CredentialField) => {
const aIsSet = Boolean(inputCredentials?.[a[0]]);
const bIsSet = Boolean(inputCredentials?.[b[0]]);
if (aIsSet === bIsSet) return 0;
return aIsSet ? 1 : -1;
};
return {
userCredentialFields: userFields,
systemCredentialFields: systemFields,
userCredentialFields: userFields.sort(sortByUnsetFirst),
systemCredentialFields: systemFields.sort(sortByUnsetFirst),
};
}
@@ -177,7 +160,6 @@ export function findSavedCredentialByProviderAndType(
credentialTypes: string[],
requiredScopes: string[] | undefined,
allProviders: CredentialsProvidersContextType | null,
discriminatorValues?: string[],
): SavedCredential | undefined {
for (const providerName of providerNames) {
const providerData = allProviders?.[providerName];
@@ -194,14 +176,9 @@ export function findSavedCredentialByProviderAndType(
credentialTypes.length === 0 ||
credentialTypes.includes(credential.type);
const scopesMatch = hasRequiredScopes(credential, requiredScopes);
const hostMatches = matchesDiscriminatorValues(
credential,
discriminatorValues,
);
if (!typeMatches) continue;
if (!scopesMatch) continue;
if (!hostMatches) continue;
matchingCredentials.push(credential as SavedCredential);
}
@@ -213,14 +190,9 @@ export function findSavedCredentialByProviderAndType(
credentialTypes.length === 0 ||
credentialTypes.includes(credential.type);
const scopesMatch = hasRequiredScopes(credential, requiredScopes);
const hostMatches = matchesDiscriminatorValues(
credential,
discriminatorValues,
);
if (!typeMatches) continue;
if (!scopesMatch) continue;
if (!hostMatches) continue;
matchingCredentials.push(credential as SavedCredential);
}
@@ -242,7 +214,6 @@ export function findSavedUserCredentialByProviderAndType(
credentialTypes: string[],
requiredScopes: string[] | undefined,
allProviders: CredentialsProvidersContextType | null,
discriminatorValues?: string[],
): SavedCredential | undefined {
for (const providerName of providerNames) {
const providerData = allProviders?.[providerName];
@@ -259,14 +230,9 @@ export function findSavedUserCredentialByProviderAndType(
credentialTypes.length === 0 ||
credentialTypes.includes(credential.type);
const scopesMatch = hasRequiredScopes(credential, requiredScopes);
const hostMatches = matchesDiscriminatorValues(
credential,
discriminatorValues,
);
if (!typeMatches) continue;
if (!scopesMatch) continue;
if (!hostMatches) continue;
matchingCredentials.push(credential as SavedCredential);
}

View File

@@ -5,14 +5,14 @@ import {
BlockIOCredentialsSubSchema,
CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";
import { postV2InitiateOauthLoginForAnMcpServer } from "@/app/api/__generated__/endpoints/mcp/mcp";
import { openOAuthPopup } from "@/lib/oauth-popup";
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useRef, useState } from "react";
import {
filterSystemCredentials,
getActionButtonText,
getSystemCredentials,
OAUTH_TIMEOUT_MS,
OAuthPopupResultMessage,
} from "./helpers";
export type CredentialsInputState = ReturnType<typeof useCredentialsInput>;
@@ -57,14 +57,6 @@ export function useCredentialsInput({
const queryClient = useQueryClient();
const credentials = useCredentials(schema, siblingInputs);
const hasAttemptedAutoSelect = useRef(false);
const oauthAbortRef = useRef<((reason?: string) => void) | null>(null);
// Clean up on unmount
useEffect(() => {
return () => {
oauthAbortRef.current?.();
};
}, []);
const deleteCredentialsMutation = useDeleteV1DeleteCredentials({
mutation: {
@@ -89,14 +81,11 @@ export function useCredentialsInput({
}
}, [credentials, onLoaded]);
// Unselect credential if not available in the loaded credential list.
// Skip when no credentials have been loaded yet (empty list could mean
// the provider data hasn't finished loading, not that the credential is invalid).
// Unselect credential if not available
useEffect(() => {
if (readOnly) return;
if (!credentials || !("savedCredentials" in credentials)) return;
const availableCreds = credentials.savedCredentials;
if (availableCreds.length === 0) return;
if (
selectedCredential &&
!availableCreds.some((c) => c.id === selectedCredential.id)
@@ -121,9 +110,7 @@ export function useCredentialsInput({
if (hasAttemptedAutoSelect.current) return;
hasAttemptedAutoSelect.current = true;
// Auto-select if exactly one credential matches.
// For optional fields with multiple options, let the user choose.
if (isOptional && savedCreds.length > 1) return;
if (isOptional) return;
const cred = savedCreds[0];
onSelectCredential({
@@ -161,9 +148,7 @@ export function useCredentialsInput({
supportsHostScoped,
savedCredentials,
oAuthCallback,
mcpOAuthCallback,
isSystemProvider,
discriminatorValue,
} = credentials;
// Split credentials into user and system
@@ -172,66 +157,72 @@ export function useCredentialsInput({
async function handleOAuthLogin() {
setOAuthError(null);
const { login_url, state_token } = await api.oAuthLogin(
provider,
schema.credentials_scopes,
);
setOAuth2FlowInProgress(true);
const popup = window.open(login_url, "_blank", "popup=true");
// Abort any previous OAuth flow
oauthAbortRef.current?.();
if (!popup) {
throw new Error(
"Failed to open popup window. Please allow popups for this site.",
);
}
// MCP uses dynamic OAuth discovery per server URL
const isMCP = provider === "mcp" && !!discriminatorValue;
const controller = new AbortController();
setOAuthPopupController(controller);
controller.signal.onabort = () => {
console.debug("OAuth flow aborted");
setOAuth2FlowInProgress(false);
popup.close();
};
try {
let login_url: string;
let state_token: string;
if (isMCP) {
const mcpLoginResponse = await postV2InitiateOauthLoginForAnMcpServer({
server_url: discriminatorValue!,
});
if (mcpLoginResponse.status !== 200) throw mcpLoginResponse.data;
({ login_url, state_token } = mcpLoginResponse.data);
} else {
({ login_url, state_token } = await api.oAuthLogin(
provider,
schema.credentials_scopes,
));
const handleMessage = async (e: MessageEvent<OAuthPopupResultMessage>) => {
console.debug("Message received:", e.data);
if (
typeof e.data != "object" ||
!("message_type" in e.data) ||
e.data.message_type !== "oauth_popup_result"
) {
console.debug("Ignoring irrelevant message");
return;
}
setOAuth2FlowInProgress(true);
if (!e.data.success) {
console.error("OAuth flow failed:", e.data.message);
setOAuthError(`OAuth flow failed: ${e.data.message}`);
setOAuth2FlowInProgress(false);
return;
}
const { promise, cleanup } = openOAuthPopup(login_url, {
stateToken: state_token,
useCrossOriginListeners: isMCP,
// Standard OAuth uses "oauth_popup_result", MCP uses "mcp_oauth_result"
acceptMessageTypes: isMCP
? ["mcp_oauth_result"]
: ["oauth_popup_result"],
});
if (e.data.state !== state_token) {
console.error("Invalid state token received");
setOAuthError("Invalid state token received");
setOAuth2FlowInProgress(false);
return;
}
oauthAbortRef.current = cleanup.abort;
// Expose abort signal for the waiting modal's cancel button
const controller = new AbortController();
cleanup.signal.addEventListener("abort", () =>
controller.abort("completed"),
);
setOAuthPopupController(controller);
try {
console.debug("Processing OAuth callback");
const credentials = await oAuthCallback(e.data.code, e.data.state);
console.debug("OAuth callback processed successfully");
const result = await promise;
// Exchange code for tokens via the provider (updates credential cache)
const credentialResult = isMCP
? await mcpOAuthCallback(result.code, state_token)
: await oAuthCallback(result.code, result.state);
// Check if the credential's scopes match the required scopes (skip for MCP)
if (!isMCP) {
// Check if the credential's scopes match the required scopes
const requiredScopes = schema.credentials_scopes;
if (requiredScopes && requiredScopes.length > 0) {
const grantedScopes = new Set(credentialResult.scopes || []);
const grantedScopes = new Set(credentials.scopes || []);
const hasAllRequiredScopes = new Set(requiredScopes).isSubsetOf(
grantedScopes,
);
if (!hasAllRequiredScopes) {
console.error(
`Newly created OAuth credential for ${providerName} has insufficient scopes. Required:`,
requiredScopes,
"Granted:",
credentials.scopes,
);
setOAuthError(
"Connection failed: the granted permissions don't match what's required. " +
"Please contact the application administrator.",
@@ -239,28 +230,38 @@ export function useCredentialsInput({
return;
}
}
}
onSelectCredential({
id: credentialResult.id,
type: "oauth2",
title: credentialResult.title,
provider,
});
} catch (error) {
if (error instanceof Error && error.message === "OAuth flow timed out") {
setOAuthError("OAuth flow timed out");
} else {
onSelectCredential({
id: credentials.id,
type: "oauth2",
title: credentials.title,
provider,
});
} catch (error) {
console.error("Error in OAuth callback:", error);
setOAuthError(
`OAuth error: ${
`Error in OAuth callback: ${
error instanceof Error ? error.message : String(error)
}`,
);
} finally {
console.debug("Finalizing OAuth flow");
setOAuth2FlowInProgress(false);
controller.abort("success");
}
} finally {
};
console.debug("Adding message event listener");
window.addEventListener("message", handleMessage, {
signal: controller.signal,
});
setTimeout(() => {
console.debug("OAuth flow timed out");
controller.abort("timeout");
setOAuth2FlowInProgress(false);
oauthAbortRef.current = null;
}
setOAuthError("OAuth flow timed out");
}, OAUTH_TIMEOUT_MS);
}
function handleActionButtonClick() {

View File

@@ -100,11 +100,6 @@ export default function useCredentials(
return false;
}
// Filter MCP OAuth2 credentials by server URL matching
if (c.type === "oauth2" && c.provider === "mcp") {
return discriminatorValue != null && c.host === discriminatorValue;
}
// Filter by OAuth credentials that have sufficient scopes for this block
if (c.type === "oauth2") {
const requiredScopes = credsInputSchema.credentials_scopes;

View File

@@ -749,12 +749,10 @@ export enum BlockUIType {
AGENT = "Agent",
AI = "AI",
AYRSHARE = "Ayrshare",
MCP_TOOL = "MCP Tool",
}
export enum SpecialBlockID {
AGENT = "e189baac-8c20-45a1-94a7-55177ea42565",
MCP_TOOL = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
SMART_DECISION = "3b191d9f-356f-482d-8238-ba04b6d18381",
OUTPUT = "363ae599-353e-4804-937e-b2ee3cef3da4",
}

View File

@@ -1,177 +0,0 @@
/**
* Shared utility for OAuth popup flows with cross-origin support.
*
* Handles BroadcastChannel, postMessage, and localStorage polling
* to reliably receive OAuth callback results even when COOP headers
* sever the window.opener relationship.
*/
const DEFAULT_TIMEOUT_MS = 5 * 60 * 1000; // 5 minutes
export type OAuthPopupResult = {
code: string;
state: string;
};
export type OAuthPopupOptions = {
/** State token to validate against incoming messages */
stateToken: string;
/**
* Use BroadcastChannel + localStorage polling for cross-origin OAuth (MCP).
* Standard OAuth only uses postMessage via window.opener.
*/
useCrossOriginListeners?: boolean;
/** BroadcastChannel name (default: "mcp_oauth") */
broadcastChannelName?: string;
/** localStorage key for cross-origin fallback (default: "mcp_oauth_result") */
localStorageKey?: string;
/** Message types to accept (default: ["oauth_popup_result", "mcp_oauth_result"]) */
acceptMessageTypes?: string[];
/** Timeout in ms (default: 5 minutes) */
timeout?: number;
};
type Cleanup = {
/** Abort the OAuth flow and close the popup */
abort: (reason?: string) => void;
/** The AbortController signal */
signal: AbortSignal;
};
/**
* Opens an OAuth popup and sets up listeners for the callback result.
*
* Opens a blank popup synchronously (to avoid popup blockers), then navigates
* it to the login URL. Returns a promise that resolves with the OAuth code/state.
*
* @param loginUrl - The OAuth authorization URL to navigate to
* @param options - Configuration for message handling
* @returns Object with `promise` (resolves with OAuth result) and `abort` (cancels flow)
*/
export function openOAuthPopup(
loginUrl: string,
options: OAuthPopupOptions,
): { promise: Promise<OAuthPopupResult>; cleanup: Cleanup } {
const {
stateToken,
useCrossOriginListeners = false,
broadcastChannelName = "mcp_oauth",
localStorageKey = "mcp_oauth_result",
acceptMessageTypes = ["oauth_popup_result", "mcp_oauth_result"],
timeout = DEFAULT_TIMEOUT_MS,
} = options;
const controller = new AbortController();
// Open popup synchronously (before any async work) to avoid browser popup blockers
const width = 500;
const height = 700;
const left = window.screenX + (window.outerWidth - width) / 2;
const top = window.screenY + (window.outerHeight - height) / 2;
const popup = window.open(
"about:blank",
"_blank",
`width=${width},height=${height},left=${left},top=${top},popup=true,scrollbars=yes`,
);
if (popup && !popup.closed) {
popup.location.href = loginUrl;
} else {
// Popup was blocked — open in new tab as fallback
window.open(loginUrl, "_blank");
}
// Close popup on abort
controller.signal.addEventListener("abort", () => {
if (popup && !popup.closed) popup.close();
});
// Clear any stale localStorage entry
if (useCrossOriginListeners) {
try {
localStorage.removeItem(localStorageKey);
} catch {}
}
const promise = new Promise<OAuthPopupResult>((resolve, reject) => {
let handled = false;
const handleResult = (data: any) => {
if (handled) return; // Prevent double-handling
// Validate message type
const messageType = data?.message_type ?? data?.type;
if (!messageType || !acceptMessageTypes.includes(messageType)) return;
// Validate state token
if (data.state !== stateToken) {
// State mismatch — this message is for a different listener. Ignore silently.
return;
}
handled = true;
if (!data.success) {
reject(new Error(data.message || "OAuth authentication failed"));
} else {
resolve({ code: data.code, state: data.state });
}
controller.abort("completed");
};
// Listener: postMessage (works for same-origin popups)
window.addEventListener(
"message",
(event: MessageEvent) => {
if (typeof event.data === "object") {
handleResult(event.data);
}
},
{ signal: controller.signal },
);
// Cross-origin listeners for MCP OAuth
if (useCrossOriginListeners) {
// Listener: BroadcastChannel (works across tabs/popups without opener)
try {
const bc = new BroadcastChannel(broadcastChannelName);
bc.onmessage = (event) => handleResult(event.data);
controller.signal.addEventListener("abort", () => bc.close());
} catch {}
// Listener: localStorage polling (most reliable cross-tab fallback)
const pollInterval = setInterval(() => {
try {
const stored = localStorage.getItem(localStorageKey);
if (stored) {
const data = JSON.parse(stored);
localStorage.removeItem(localStorageKey);
handleResult(data);
}
} catch {}
}, 500);
controller.signal.addEventListener("abort", () =>
clearInterval(pollInterval),
);
}
// Timeout
const timeoutId = setTimeout(() => {
if (!handled) {
handled = true;
reject(new Error("OAuth flow timed out"));
controller.abort("timeout");
}
}, timeout);
controller.signal.addEventListener("abort", () => clearTimeout(timeoutId));
});
return {
promise,
cleanup: {
abort: (reason?: string) => controller.abort(reason || "canceled"),
signal: controller.signal,
},
};
}

View File

@@ -18,6 +18,6 @@ export const config = {
* Note: /auth/authorize and /auth/integrations/* ARE protected and need
* middleware to run for authentication checks.
*/
"/((?!_next/static|_next/image|favicon.ico|auth/callback|auth/integrations/mcp_callback|.*\\.(?:svg|png|jpg|jpeg|gif|webp)$).*)",
"/((?!_next/static|_next/image|favicon.ico|auth/callback|.*\\.(?:svg|png|jpg|jpeg|gif|webp)$).*)",
],
};

View File

@@ -8,7 +8,6 @@ import {
HostScopedCredentials,
UserPasswordCredentials,
} from "@/lib/autogpt-server-api";
import { postV2ExchangeOauthCodeForMcpTokens } from "@/app/api/__generated__/endpoints/mcp/mcp";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { toDisplayName } from "@/providers/agent-credentials/helper";
@@ -39,11 +38,6 @@ export type CredentialsProviderData = {
code: string,
state_token: string,
) => Promise<CredentialsMetaResponse>;
/** MCP-specific OAuth callback that uses dynamic per-server OAuth discovery. */
mcpOAuthCallback: (
code: string,
state_token: string,
) => Promise<CredentialsMetaResponse>;
createAPIKeyCredentials: (
credentials: APIKeyCredentialsCreatable,
) => Promise<CredentialsMetaResponse>;
@@ -126,35 +120,6 @@ export default function CredentialsProvider({
[api, addCredentials, onFailToast],
);
/** Exchanges an MCP OAuth code for tokens and adds the result to the internal credentials store. */
const mcpOAuthCallback = useCallback(
async (
code: string,
state_token: string,
): Promise<CredentialsMetaResponse> => {
try {
const response = await postV2ExchangeOauthCodeForMcpTokens({
code,
state_token,
});
if (response.status !== 200) throw response.data;
const credsMeta: CredentialsMetaResponse = {
...response.data,
title: response.data.title ?? undefined,
scopes: response.data.scopes ?? undefined,
username: response.data.username ?? undefined,
host: response.data.host ?? undefined,
};
addCredentials("mcp", credsMeta);
return credsMeta;
} catch (error) {
onFailToast("complete MCP OAuth authentication")(error);
throw error;
}
},
[addCredentials, onFailToast],
);
/** Wraps `BackendAPI.createAPIKeyCredentials`, and adds the result to the internal credentials store. */
const createAPIKeyCredentials = useCallback(
async (
@@ -293,7 +258,6 @@ export default function CredentialsProvider({
isSystemProvider: systemProviders.has(provider),
oAuthCallback: (code: string, state_token: string) =>
oAuthCallback(provider, code, state_token),
mcpOAuthCallback,
createAPIKeyCredentials: (
credentials: APIKeyCredentialsCreatable,
) => createAPIKeyCredentials(provider, credentials),
@@ -322,7 +286,6 @@ export default function CredentialsProvider({
createHostScopedCredentials,
deleteCredentials,
oAuthCallback,
mcpOAuthCallback,
onFailToast,
]);

View File

@@ -528,9 +528,6 @@ export class BuildPage extends BasePage {
async getBlocksToSkip(): Promise<string[]> {
return [
(await this.getGithubTriggerBlockDetails()).map((b) => b.id),
// MCP Tool block requires an interactive dialog (server URL + OAuth) before
// it can be placed, so it can't be tested via the standard "add block" flow.
"a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
].flat();
}

View File

@@ -467,7 +467,6 @@ Below is a comprehensive list of all available blocks, categorized by their prim
| [Github Update Comment](block-integrations/github/issues.md#github-update-comment) | A block that updates an existing comment on a GitHub issue or pull request |
| [Github Update File](block-integrations/github/repo.md#github-update-file) | This block updates an existing file in a GitHub repository |
| [Instantiate Code Sandbox](block-integrations/misc.md#instantiate-code-sandbox) | Instantiate a sandbox environment with internet access in which you can execute code with the Execute Code Step block |
| [MCP Tool](block-integrations/mcp/block.md#mcp-tool) | Connect to any MCP server and execute its tools |
| [Slant3D Order Webhook](block-integrations/slant3d/webhook.md#slant3d-order-webhook) | This block triggers on Slant3D order status updates and outputs the event details, including tracking information when orders are shipped |
## Media Generation

View File

@@ -84,7 +84,6 @@
* [Linear Projects](block-integrations/linear/projects.md)
* [LLM](block-integrations/llm.md)
* [Logic](block-integrations/logic.md)
* [Mcp Block](block-integrations/mcp/block.md)
* [Misc](block-integrations/misc.md)
* [Notion Create Page](block-integrations/notion/create_page.md)
* [Notion Read Database](block-integrations/notion/read_database.md)

Some files were not shown because too many files have changed in this diff Show More