mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-13 16:25:05 -05:00
Compare commits
4 Commits
dev
...
refactor/r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5739de04f0 | ||
|
|
28ad3d0b01 | ||
|
|
361d6ff6fc | ||
|
|
0fe6cc8dc7 |
1229
.github/scripts/detect_overlaps.py
vendored
1229
.github/scripts/detect_overlaps.py
vendored
File diff suppressed because it is too large
Load Diff
4
.github/workflows/codeql.yml
vendored
4
.github/workflows/codeql.yml
vendored
@@ -62,7 +62,7 @@ jobs:
|
|||||||
|
|
||||||
# Initializes the CodeQL tools for scanning.
|
# Initializes the CodeQL tools for scanning.
|
||||||
- name: Initialize CodeQL
|
- name: Initialize CodeQL
|
||||||
uses: github/codeql-action/init@v4
|
uses: github/codeql-action/init@v3
|
||||||
with:
|
with:
|
||||||
languages: ${{ matrix.language }}
|
languages: ${{ matrix.language }}
|
||||||
build-mode: ${{ matrix.build-mode }}
|
build-mode: ${{ matrix.build-mode }}
|
||||||
@@ -93,6 +93,6 @@ jobs:
|
|||||||
exit 1
|
exit 1
|
||||||
|
|
||||||
- name: Perform CodeQL Analysis
|
- name: Perform CodeQL Analysis
|
||||||
uses: github/codeql-action/analyze@v4
|
uses: github/codeql-action/analyze@v3
|
||||||
with:
|
with:
|
||||||
category: "/language:${{matrix.language}}"
|
category: "/language:${{matrix.language}}"
|
||||||
|
|||||||
34
.github/workflows/docs-claude-review.yml
vendored
34
.github/workflows/docs-claude-review.yml
vendored
@@ -7,10 +7,6 @@ on:
|
|||||||
- "docs/integrations/**"
|
- "docs/integrations/**"
|
||||||
- "autogpt_platform/backend/backend/blocks/**"
|
- "autogpt_platform/backend/backend/blocks/**"
|
||||||
|
|
||||||
concurrency:
|
|
||||||
group: claude-docs-review-${{ github.event.pull_request.number }}
|
|
||||||
cancel-in-progress: true
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
claude-review:
|
claude-review:
|
||||||
# Only run for PRs from members/collaborators
|
# Only run for PRs from members/collaborators
|
||||||
@@ -95,35 +91,5 @@ jobs:
|
|||||||
3. Read corresponding documentation files to verify accuracy
|
3. Read corresponding documentation files to verify accuracy
|
||||||
4. Provide your feedback as a PR comment
|
4. Provide your feedback as a PR comment
|
||||||
|
|
||||||
## IMPORTANT: Comment Marker
|
|
||||||
Start your PR comment with exactly this HTML comment marker on its own line:
|
|
||||||
<!-- CLAUDE_DOCS_REVIEW -->
|
|
||||||
|
|
||||||
This marker is used to identify and replace your comment on subsequent runs.
|
|
||||||
|
|
||||||
Be constructive and specific. If everything looks good, say so!
|
Be constructive and specific. If everything looks good, say so!
|
||||||
If there are issues, explain what's wrong and suggest how to fix it.
|
If there are issues, explain what's wrong and suggest how to fix it.
|
||||||
|
|
||||||
- name: Delete old Claude review comments
|
|
||||||
env:
|
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
run: |
|
|
||||||
# Get all comment IDs with our marker, sorted by creation date (oldest first)
|
|
||||||
COMMENT_IDS=$(gh api \
|
|
||||||
repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments \
|
|
||||||
--jq '[.[] | select(.body | contains("<!-- CLAUDE_DOCS_REVIEW -->"))] | sort_by(.created_at) | .[].id')
|
|
||||||
|
|
||||||
# Count comments
|
|
||||||
COMMENT_COUNT=$(echo "$COMMENT_IDS" | grep -c . || true)
|
|
||||||
|
|
||||||
if [ "$COMMENT_COUNT" -gt 1 ]; then
|
|
||||||
# Delete all but the last (newest) comment
|
|
||||||
echo "$COMMENT_IDS" | head -n -1 | while read -r COMMENT_ID; do
|
|
||||||
if [ -n "$COMMENT_ID" ]; then
|
|
||||||
echo "Deleting old review comment: $COMMENT_ID"
|
|
||||||
gh api -X DELETE repos/${{ github.repository }}/issues/comments/$COMMENT_ID
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
else
|
|
||||||
echo "No old review comments to clean up"
|
|
||||||
fi
|
|
||||||
|
|||||||
39
.github/workflows/pr-overlap-check.yml
vendored
39
.github/workflows/pr-overlap-check.yml
vendored
@@ -1,39 +0,0 @@
|
|||||||
name: PR Overlap Detection
|
|
||||||
|
|
||||||
on:
|
|
||||||
pull_request:
|
|
||||||
types: [opened, synchronize, reopened]
|
|
||||||
branches:
|
|
||||||
- dev
|
|
||||||
- master
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
pull-requests: write
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
check-overlaps:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0 # Need full history for merge testing
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: '3.11'
|
|
||||||
|
|
||||||
- name: Configure git
|
|
||||||
run: |
|
|
||||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
|
||||||
git config user.name "github-actions[bot]"
|
|
||||||
|
|
||||||
- name: Run overlap detection
|
|
||||||
env:
|
|
||||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
# Always succeed - this check informs contributors, it shouldn't block merging
|
|
||||||
continue-on-error: true
|
|
||||||
run: |
|
|
||||||
python .github/scripts/detect_overlaps.py ${{ github.event.pull_request.number }}
|
|
||||||
@@ -66,19 +66,13 @@ ENV POETRY_HOME=/opt/poetry \
|
|||||||
DEBIAN_FRONTEND=noninteractive
|
DEBIAN_FRONTEND=noninteractive
|
||||||
ENV PATH=/opt/poetry/bin:$PATH
|
ENV PATH=/opt/poetry/bin:$PATH
|
||||||
|
|
||||||
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
|
# Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
|
||||||
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
|
|
||||||
# for the bash_exec MCP tool.
|
|
||||||
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
|
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
python3.13 \
|
python3.13 \
|
||||||
python3-pip \
|
python3-pip \
|
||||||
ffmpeg \
|
ffmpeg \
|
||||||
imagemagick \
|
imagemagick \
|
||||||
jq \
|
|
||||||
ripgrep \
|
|
||||||
tree \
|
|
||||||
bubblewrap \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
|
||||||
|
|||||||
@@ -27,11 +27,12 @@ class ChatConfig(BaseSettings):
|
|||||||
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
|
||||||
|
|
||||||
# Streaming Configuration
|
# Streaming Configuration
|
||||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
max_context_messages: int = Field(
|
||||||
max_retries: int = Field(
|
default=50, ge=1, le=200, description="Maximum context messages"
|
||||||
default=3,
|
|
||||||
description="Max retries for fallback path (SDK handles retries internally)",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||||
|
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||||
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
|
||||||
max_agent_schedules: int = Field(
|
max_agent_schedules: int = Field(
|
||||||
default=30, description="Maximum number of agent schedules"
|
default=30, description="Maximum number of agent schedules"
|
||||||
@@ -92,31 +93,6 @@ class ChatConfig(BaseSettings):
|
|||||||
description="Name of the prompt in Langfuse to fetch",
|
description="Name of the prompt in Langfuse to fetch",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Claude Agent SDK Configuration
|
|
||||||
use_claude_agent_sdk: bool = Field(
|
|
||||||
default=True,
|
|
||||||
description="Use Claude Agent SDK for chat completions",
|
|
||||||
)
|
|
||||||
claude_agent_model: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Model for the Claude Agent SDK path. If None, derives from "
|
|
||||||
"the `model` field by stripping the OpenRouter provider prefix.",
|
|
||||||
)
|
|
||||||
claude_agent_max_buffer_size: int = Field(
|
|
||||||
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
|
|
||||||
description="Max buffer size in bytes for Claude Agent SDK JSON message parsing. "
|
|
||||||
"Increase if tool outputs exceed the limit.",
|
|
||||||
)
|
|
||||||
claude_agent_max_subtasks: int = Field(
|
|
||||||
default=10,
|
|
||||||
description="Max number of sub-agent Tasks the SDK can spawn per session.",
|
|
||||||
)
|
|
||||||
claude_agent_use_resume: bool = Field(
|
|
||||||
default=True,
|
|
||||||
description="Use --resume for multi-turn conversations instead of "
|
|
||||||
"history compression. Falls back to compression when unavailable.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extended thinking configuration for Claude models
|
# Extended thinking configuration for Claude models
|
||||||
thinking_enabled: bool = Field(
|
thinking_enabled: bool = Field(
|
||||||
default=True,
|
default=True,
|
||||||
@@ -162,17 +138,6 @@ class ChatConfig(BaseSettings):
|
|||||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator("use_claude_agent_sdk", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def get_use_claude_agent_sdk(cls, v):
|
|
||||||
"""Get use_claude_agent_sdk from environment if not provided."""
|
|
||||||
# Check environment variable - default to True if not set
|
|
||||||
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
|
||||||
if env_val:
|
|
||||||
return env_val in ("true", "1", "yes", "on")
|
|
||||||
# Default to True (SDK enabled by default)
|
|
||||||
return True if v is None else v
|
|
||||||
|
|
||||||
# Prompt paths for different contexts
|
# Prompt paths for different contexts
|
||||||
PROMPT_PATHS: dict[str, str] = {
|
PROMPT_PATHS: dict[str, str] = {
|
||||||
"default": "prompts/chat_system.md",
|
"default": "prompts/chat_system.md",
|
||||||
|
|||||||
@@ -334,8 +334,9 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
|
|||||||
try:
|
try:
|
||||||
session = ChatSession.model_validate_json(raw_session)
|
session = ChatSession.model_validate_json(raw_session)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
|
f"Loading session {session_id} from cache: "
|
||||||
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
|
f"message_count={len(session.messages)}, "
|
||||||
|
f"roles={[m.role for m in session.messages]}"
|
||||||
)
|
)
|
||||||
return session
|
return session
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -377,9 +378,11 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
messages = prisma_session.Messages
|
messages = prisma_session.Messages
|
||||||
logger.debug(
|
logger.info(
|
||||||
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
|
f"Loading session {session_id} from DB: "
|
||||||
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
|
f"has_messages={messages is not None}, "
|
||||||
|
f"message_count={len(messages) if messages else 0}, "
|
||||||
|
f"roles={[m.role for m in messages] if messages else []}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChatSession.from_db(prisma_session, messages)
|
return ChatSession.from_db(prisma_session, messages)
|
||||||
@@ -430,9 +433,10 @@ async def _save_session_to_db(
|
|||||||
"function_call": msg.function_call,
|
"function_call": msg.function_call,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.info(
|
||||||
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
|
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||||
f"roles={[m['role'] for m in messages_data]}"
|
f"roles={[m['role'] for m in messages_data]}, "
|
||||||
|
f"start_sequence={existing_message_count}"
|
||||||
)
|
)
|
||||||
await chat_db.add_chat_messages_batch(
|
await chat_db.add_chat_messages_batch(
|
||||||
session_id=session.session_id,
|
session_id=session.session_id,
|
||||||
@@ -472,7 +476,7 @@ async def get_chat_session(
|
|||||||
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
logger.warning(f"Unexpected cache error for session {session_id}: {e}")
|
||||||
|
|
||||||
# Fall back to database
|
# Fall back to database
|
||||||
logger.debug(f"Session {session_id} not in cache, checking database")
|
logger.info(f"Session {session_id} not in cache, checking database")
|
||||||
session = await _get_session_from_db(session_id)
|
session = await _get_session_from_db(session_id)
|
||||||
|
|
||||||
if session is None:
|
if session is None:
|
||||||
@@ -489,6 +493,7 @@ async def get_chat_session(
|
|||||||
# Cache the session from DB
|
# Cache the session from DB
|
||||||
try:
|
try:
|
||||||
await _cache_session(session)
|
await _cache_session(session)
|
||||||
|
logger.info(f"Cached session {session_id} from database")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to cache session {session_id}: {e}")
|
logger.warning(f"Failed to cache session {session_id}: {e}")
|
||||||
|
|
||||||
@@ -553,40 +558,6 @@ async def upsert_chat_session(
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
|
|
||||||
"""Atomically append a message to a session and persist it.
|
|
||||||
|
|
||||||
Acquires the session lock, re-fetches the latest session state,
|
|
||||||
appends the message, and saves — preventing message loss when
|
|
||||||
concurrent requests modify the same session.
|
|
||||||
"""
|
|
||||||
lock = await _get_session_lock(session_id)
|
|
||||||
|
|
||||||
async with lock:
|
|
||||||
session = await get_chat_session(session_id)
|
|
||||||
if session is None:
|
|
||||||
raise ValueError(f"Session {session_id} not found")
|
|
||||||
|
|
||||||
session.messages.append(message)
|
|
||||||
existing_message_count = await chat_db.get_chat_session_message_count(
|
|
||||||
session_id
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await _save_session_to_db(session, existing_message_count)
|
|
||||||
except Exception as e:
|
|
||||||
raise DatabaseError(
|
|
||||||
f"Failed to persist message to session {session_id}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
try:
|
|
||||||
await _cache_session(session)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Cache write failed for session {session_id}: {e}")
|
|
||||||
|
|
||||||
return session
|
|
||||||
|
|
||||||
|
|
||||||
async def create_chat_session(user_id: str) -> ChatSession:
|
async def create_chat_session(user_id: str) -> ChatSession:
|
||||||
"""Create a new chat session and persist it.
|
"""Create a new chat session and persist it.
|
||||||
|
|
||||||
@@ -693,19 +664,13 @@ async def update_session_title(session_id: str, title: str) -> bool:
|
|||||||
logger.warning(f"Session {session_id} not found for title update")
|
logger.warning(f"Session {session_id} not found for title update")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Update title in cache if it exists (instead of invalidating).
|
# Invalidate cache so next fetch gets updated title
|
||||||
# This prevents race conditions where cache invalidation causes
|
|
||||||
# the frontend to see stale DB data while streaming is still in progress.
|
|
||||||
try:
|
try:
|
||||||
cached = await _get_session_from_cache(session_id)
|
redis_key = _get_session_cache_key(session_id)
|
||||||
if cached:
|
async_redis = await get_redis_async()
|
||||||
cached.title = title
|
await async_redis.delete(redis_key)
|
||||||
await _cache_session(cached)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Not critical - title will be correct on next full cache refresh
|
logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
|
||||||
logger.warning(
|
|
||||||
f"Failed to update title in cache for session {session_id}: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
"""Chat API routes for chat session management and streaming via SSE."""
|
"""Chat API routes for chat session management and streaming via SSE."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import uuid as uuid_module
|
import uuid as uuid_module
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
@@ -12,22 +11,13 @@ from fastapi.responses import StreamingResponse
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
|
||||||
|
|
||||||
from . import service as chat_service
|
from . import service as chat_service
|
||||||
from . import stream_registry
|
from . import stream_registry
|
||||||
from .completion_handler import process_operation_failure, process_operation_success
|
from .completion_handler import process_operation_failure, process_operation_success
|
||||||
from .config import ChatConfig
|
from .config import ChatConfig
|
||||||
from .model import (
|
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||||
ChatMessage,
|
from .response_model import StreamFinish, StreamHeartbeat
|
||||||
ChatSession,
|
|
||||||
append_and_save_message,
|
|
||||||
create_chat_session,
|
|
||||||
get_chat_session,
|
|
||||||
get_user_sessions,
|
|
||||||
)
|
|
||||||
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
|
|
||||||
from .sdk import service as sdk_service
|
|
||||||
from .tools.models import (
|
from .tools.models import (
|
||||||
AgentDetailsResponse,
|
AgentDetailsResponse,
|
||||||
AgentOutputResponse,
|
AgentOutputResponse,
|
||||||
@@ -51,7 +41,6 @@ from .tools.models import (
|
|||||||
SetupRequirementsResponse,
|
SetupRequirementsResponse,
|
||||||
UnderstandingUpdatedResponse,
|
UnderstandingUpdatedResponse,
|
||||||
)
|
)
|
||||||
from .tracking import track_user_message
|
|
||||||
|
|
||||||
config = ChatConfig()
|
config = ChatConfig()
|
||||||
|
|
||||||
@@ -243,10 +232,6 @@ async def get_session(
|
|||||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||||
session_id, user_id
|
session_id, user_id
|
||||||
)
|
)
|
||||||
logger.info(
|
|
||||||
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
|
|
||||||
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
|
|
||||||
)
|
|
||||||
if active_task:
|
if active_task:
|
||||||
# Filter out the in-progress assistant message from the session response.
|
# Filter out the in-progress assistant message from the session response.
|
||||||
# The client will receive the complete assistant response through the SSE
|
# The client will receive the complete assistant response through the SSE
|
||||||
@@ -316,6 +301,7 @@ async def stream_chat_post(
|
|||||||
f"user={user_id}, message_len={len(request.message)}",
|
f"user={user_id}, message_len={len(request.message)}",
|
||||||
extra={"json_fields": log_meta},
|
extra={"json_fields": log_meta},
|
||||||
)
|
)
|
||||||
|
|
||||||
session = await _validate_and_get_session(session_id, user_id)
|
session = await _validate_and_get_session(session_id, user_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
|
||||||
@@ -327,25 +313,6 @@ async def stream_chat_post(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Atomically append user message to session BEFORE creating task to avoid
|
|
||||||
# race condition where GET_SESSION sees task as "running" but message isn't
|
|
||||||
# saved yet. append_and_save_message re-fetches inside a lock to prevent
|
|
||||||
# message loss from concurrent requests.
|
|
||||||
if request.message:
|
|
||||||
message = ChatMessage(
|
|
||||||
role="user" if request.is_user_message else "assistant",
|
|
||||||
content=request.message,
|
|
||||||
)
|
|
||||||
if request.is_user_message:
|
|
||||||
track_user_message(
|
|
||||||
user_id=user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
message_length=len(request.message),
|
|
||||||
)
|
|
||||||
logger.info(f"[STREAM] Saving user message to session {session_id}")
|
|
||||||
session = await append_and_save_message(session_id, message)
|
|
||||||
logger.info(f"[STREAM] User message saved for session {session_id}")
|
|
||||||
|
|
||||||
# Create a task in the stream registry for reconnection support
|
# Create a task in the stream registry for reconnection support
|
||||||
task_id = str(uuid_module.uuid4())
|
task_id = str(uuid_module.uuid4())
|
||||||
operation_id = str(uuid_module.uuid4())
|
operation_id = str(uuid_module.uuid4())
|
||||||
@@ -382,47 +349,15 @@ async def stream_chat_post(
|
|||||||
first_chunk_time, ttfc = None, None
|
first_chunk_time, ttfc = None, None
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
try:
|
try:
|
||||||
# Emit a start event with task_id for reconnection
|
async for chunk in chat_service.stream_chat_completion(
|
||||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
|
||||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
|
|
||||||
extra={
|
|
||||||
"json_fields": {
|
|
||||||
**log_meta,
|
|
||||||
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
|
|
||||||
* 1000,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Choose service based on LaunchDarkly flag (falls back to config default)
|
|
||||||
use_sdk = await is_feature_enabled(
|
|
||||||
Flag.COPILOT_SDK,
|
|
||||||
user_id or "anonymous",
|
|
||||||
default=config.use_claude_agent_sdk,
|
|
||||||
)
|
|
||||||
stream_fn = (
|
|
||||||
sdk_service.stream_chat_completion_sdk
|
|
||||||
if use_sdk
|
|
||||||
else chat_service.stream_chat_completion
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
|
|
||||||
extra={"json_fields": log_meta},
|
|
||||||
)
|
|
||||||
# Pass message=None since we already added it to the session above
|
|
||||||
async for chunk in stream_fn(
|
|
||||||
session_id,
|
session_id,
|
||||||
None, # Message already in session
|
request.message,
|
||||||
is_user_message=request.is_user_message,
|
is_user_message=request.is_user_message,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session=session, # Pass session with message already added
|
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||||
context=request.context,
|
context=request.context,
|
||||||
|
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
|
||||||
):
|
):
|
||||||
# Skip duplicate StreamStart — we already published one above
|
|
||||||
if isinstance(chunk, StreamStart):
|
|
||||||
continue
|
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
if first_chunk_time is None:
|
if first_chunk_time is None:
|
||||||
first_chunk_time = time_module.perf_counter()
|
first_chunk_time = time_module.perf_counter()
|
||||||
@@ -470,17 +405,6 @@ async def stream_chat_post(
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# Publish a StreamError so the frontend can display an error message
|
|
||||||
try:
|
|
||||||
await stream_registry.publish_chunk(
|
|
||||||
task_id,
|
|
||||||
StreamError(
|
|
||||||
errorText="An error occurred. Please try again.",
|
|
||||||
code="stream_error",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass # Best-effort; mark_task_completed will publish StreamFinish
|
|
||||||
await stream_registry.mark_task_completed(task_id, "failed")
|
await stream_registry.mark_task_completed(task_id, "failed")
|
||||||
|
|
||||||
# Start the AI generation in a background task
|
# Start the AI generation in a background task
|
||||||
@@ -583,14 +507,8 @@ async def stream_chat_post(
|
|||||||
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# Surface error to frontend so it doesn't appear stuck
|
|
||||||
yield StreamError(
|
|
||||||
errorText="An error occurred. Please try again.",
|
|
||||||
code="stream_error",
|
|
||||||
).to_sse()
|
|
||||||
yield StreamFinish().to_sse()
|
|
||||||
finally:
|
finally:
|
||||||
# Unsubscribe when client disconnects or stream ends
|
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||||
if subscriber_queue is not None:
|
if subscriber_queue is not None:
|
||||||
try:
|
try:
|
||||||
await stream_registry.unsubscribe_from_task(
|
await stream_registry.unsubscribe_from_task(
|
||||||
@@ -834,6 +752,8 @@ async def stream_task(
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def event_generator() -> AsyncGenerator[str, None]:
|
async def event_generator() -> AsyncGenerator[str, None]:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@@ -1,14 +0,0 @@
|
|||||||
"""Claude Agent SDK integration for CoPilot.
|
|
||||||
|
|
||||||
This module provides the integration layer between the Claude Agent SDK
|
|
||||||
and the existing CoPilot tool system, enabling drop-in replacement of
|
|
||||||
the current LLM orchestration with the battle-tested Claude Agent SDK.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .service import stream_chat_completion_sdk
|
|
||||||
from .tool_adapter import create_copilot_mcp_server
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"stream_chat_completion_sdk",
|
|
||||||
"create_copilot_mcp_server",
|
|
||||||
]
|
|
||||||
@@ -1,203 +0,0 @@
|
|||||||
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
|
||||||
|
|
||||||
This module provides the adapter layer that converts streaming messages from
|
|
||||||
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
|
|
||||||
the frontend expects.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from claude_agent_sdk import (
|
|
||||||
AssistantMessage,
|
|
||||||
Message,
|
|
||||||
ResultMessage,
|
|
||||||
SystemMessage,
|
|
||||||
TextBlock,
|
|
||||||
ToolResultBlock,
|
|
||||||
ToolUseBlock,
|
|
||||||
UserMessage,
|
|
||||||
)
|
|
||||||
|
|
||||||
from backend.api.features.chat.response_model import (
|
|
||||||
StreamBaseResponse,
|
|
||||||
StreamError,
|
|
||||||
StreamFinish,
|
|
||||||
StreamFinishStep,
|
|
||||||
StreamStart,
|
|
||||||
StreamStartStep,
|
|
||||||
StreamTextDelta,
|
|
||||||
StreamTextEnd,
|
|
||||||
StreamTextStart,
|
|
||||||
StreamToolInputAvailable,
|
|
||||||
StreamToolInputStart,
|
|
||||||
StreamToolOutputAvailable,
|
|
||||||
)
|
|
||||||
from backend.api.features.chat.sdk.tool_adapter import (
|
|
||||||
MCP_TOOL_PREFIX,
|
|
||||||
pop_pending_tool_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class SDKResponseAdapter:
|
|
||||||
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
|
|
||||||
|
|
||||||
This class maintains state during a streaming session to properly track
|
|
||||||
text blocks, tool calls, and message lifecycle.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message_id: str | None = None):
|
|
||||||
self.message_id = message_id or str(uuid.uuid4())
|
|
||||||
self.text_block_id = str(uuid.uuid4())
|
|
||||||
self.has_started_text = False
|
|
||||||
self.has_ended_text = False
|
|
||||||
self.current_tool_calls: dict[str, dict[str, str]] = {}
|
|
||||||
self.task_id: str | None = None
|
|
||||||
self.step_open = False
|
|
||||||
|
|
||||||
def set_task_id(self, task_id: str) -> None:
|
|
||||||
"""Set the task ID for reconnection support."""
|
|
||||||
self.task_id = task_id
|
|
||||||
|
|
||||||
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
|
|
||||||
"""Convert a single SDK message to Vercel AI SDK format."""
|
|
||||||
responses: list[StreamBaseResponse] = []
|
|
||||||
|
|
||||||
if isinstance(sdk_message, SystemMessage):
|
|
||||||
if sdk_message.subtype == "init":
|
|
||||||
responses.append(
|
|
||||||
StreamStart(messageId=self.message_id, taskId=self.task_id)
|
|
||||||
)
|
|
||||||
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
|
|
||||||
responses.append(StreamStartStep())
|
|
||||||
self.step_open = True
|
|
||||||
|
|
||||||
elif isinstance(sdk_message, AssistantMessage):
|
|
||||||
# After tool results, the SDK sends a new AssistantMessage for the
|
|
||||||
# next LLM turn. Open a new step if the previous one was closed.
|
|
||||||
if not self.step_open:
|
|
||||||
responses.append(StreamStartStep())
|
|
||||||
self.step_open = True
|
|
||||||
|
|
||||||
for block in sdk_message.content:
|
|
||||||
if isinstance(block, TextBlock):
|
|
||||||
if block.text:
|
|
||||||
self._ensure_text_started(responses)
|
|
||||||
responses.append(
|
|
||||||
StreamTextDelta(id=self.text_block_id, delta=block.text)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif isinstance(block, ToolUseBlock):
|
|
||||||
self._end_text_if_open(responses)
|
|
||||||
|
|
||||||
# Strip MCP prefix so frontend sees "find_block"
|
|
||||||
# instead of "mcp__copilot__find_block".
|
|
||||||
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
|
|
||||||
|
|
||||||
responses.append(
|
|
||||||
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
|
|
||||||
)
|
|
||||||
responses.append(
|
|
||||||
StreamToolInputAvailable(
|
|
||||||
toolCallId=block.id,
|
|
||||||
toolName=tool_name,
|
|
||||||
input=block.input,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
self.current_tool_calls[block.id] = {"name": tool_name}
|
|
||||||
|
|
||||||
elif isinstance(sdk_message, UserMessage):
|
|
||||||
# UserMessage carries tool results back from tool execution.
|
|
||||||
content = sdk_message.content
|
|
||||||
blocks = content if isinstance(content, list) else []
|
|
||||||
for block in blocks:
|
|
||||||
if isinstance(block, ToolResultBlock) and block.tool_use_id:
|
|
||||||
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
|
|
||||||
tool_name = tool_info.get("name", "unknown")
|
|
||||||
|
|
||||||
# Prefer the stashed full output over the SDK's
|
|
||||||
# (potentially truncated) ToolResultBlock content.
|
|
||||||
# The SDK truncates large results, writing them to disk,
|
|
||||||
# which breaks frontend widget parsing.
|
|
||||||
output = pop_pending_tool_output(tool_name) or (
|
|
||||||
_extract_tool_output(block.content)
|
|
||||||
)
|
|
||||||
|
|
||||||
responses.append(
|
|
||||||
StreamToolOutputAvailable(
|
|
||||||
toolCallId=block.tool_use_id,
|
|
||||||
toolName=tool_name,
|
|
||||||
output=output,
|
|
||||||
success=not (block.is_error or False),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Close the current step after tool results — the next
|
|
||||||
# AssistantMessage will open a new step for the continuation.
|
|
||||||
if self.step_open:
|
|
||||||
responses.append(StreamFinishStep())
|
|
||||||
self.step_open = False
|
|
||||||
|
|
||||||
elif isinstance(sdk_message, ResultMessage):
|
|
||||||
self._end_text_if_open(responses)
|
|
||||||
# Close the step before finishing.
|
|
||||||
if self.step_open:
|
|
||||||
responses.append(StreamFinishStep())
|
|
||||||
self.step_open = False
|
|
||||||
|
|
||||||
if sdk_message.subtype == "success":
|
|
||||||
responses.append(StreamFinish())
|
|
||||||
elif sdk_message.subtype in ("error", "error_during_execution"):
|
|
||||||
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
|
|
||||||
responses.append(
|
|
||||||
StreamError(errorText=str(error_msg), code="sdk_error")
|
|
||||||
)
|
|
||||||
responses.append(StreamFinish())
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
|
|
||||||
)
|
|
||||||
responses.append(StreamFinish())
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
|
||||||
|
|
||||||
return responses
|
|
||||||
|
|
||||||
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
|
|
||||||
"""Start (or restart) a text block if needed."""
|
|
||||||
if not self.has_started_text or self.has_ended_text:
|
|
||||||
if self.has_ended_text:
|
|
||||||
self.text_block_id = str(uuid.uuid4())
|
|
||||||
self.has_ended_text = False
|
|
||||||
responses.append(StreamTextStart(id=self.text_block_id))
|
|
||||||
self.has_started_text = True
|
|
||||||
|
|
||||||
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
|
|
||||||
"""End the current text block if one is open."""
|
|
||||||
if self.has_started_text and not self.has_ended_text:
|
|
||||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
|
||||||
self.has_ended_text = True
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
|
|
||||||
"""Extract a string output from a ToolResultBlock's content field."""
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
|
|
||||||
if parts:
|
|
||||||
return "".join(parts)
|
|
||||||
try:
|
|
||||||
return json.dumps(content)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return str(content)
|
|
||||||
if content is None:
|
|
||||||
return ""
|
|
||||||
try:
|
|
||||||
return json.dumps(content)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return str(content)
|
|
||||||
@@ -1,366 +0,0 @@
|
|||||||
"""Unit tests for the SDK response adapter."""
|
|
||||||
|
|
||||||
from claude_agent_sdk import (
|
|
||||||
AssistantMessage,
|
|
||||||
ResultMessage,
|
|
||||||
SystemMessage,
|
|
||||||
TextBlock,
|
|
||||||
ToolResultBlock,
|
|
||||||
ToolUseBlock,
|
|
||||||
UserMessage,
|
|
||||||
)
|
|
||||||
|
|
||||||
from backend.api.features.chat.response_model import (
|
|
||||||
StreamBaseResponse,
|
|
||||||
StreamError,
|
|
||||||
StreamFinish,
|
|
||||||
StreamFinishStep,
|
|
||||||
StreamStart,
|
|
||||||
StreamStartStep,
|
|
||||||
StreamTextDelta,
|
|
||||||
StreamTextEnd,
|
|
||||||
StreamTextStart,
|
|
||||||
StreamToolInputAvailable,
|
|
||||||
StreamToolInputStart,
|
|
||||||
StreamToolOutputAvailable,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .response_adapter import SDKResponseAdapter
|
|
||||||
from .tool_adapter import MCP_TOOL_PREFIX
|
|
||||||
|
|
||||||
|
|
||||||
def _adapter() -> SDKResponseAdapter:
|
|
||||||
a = SDKResponseAdapter(message_id="msg-1")
|
|
||||||
a.set_task_id("task-1")
|
|
||||||
return a
|
|
||||||
|
|
||||||
|
|
||||||
# -- SystemMessage -----------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_system_init_emits_start_and_step():
|
|
||||||
adapter = _adapter()
|
|
||||||
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
|
|
||||||
assert len(results) == 2
|
|
||||||
assert isinstance(results[0], StreamStart)
|
|
||||||
assert results[0].messageId == "msg-1"
|
|
||||||
assert results[0].taskId == "task-1"
|
|
||||||
assert isinstance(results[1], StreamStartStep)
|
|
||||||
|
|
||||||
|
|
||||||
def test_system_non_init_emits_nothing():
|
|
||||||
adapter = _adapter()
|
|
||||||
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
|
|
||||||
assert results == []
|
|
||||||
|
|
||||||
|
|
||||||
# -- AssistantMessage with TextBlock -----------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_text_block_emits_step_start_and_delta():
|
|
||||||
adapter = _adapter()
|
|
||||||
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
|
|
||||||
results = adapter.convert_message(msg)
|
|
||||||
assert len(results) == 3
|
|
||||||
assert isinstance(results[0], StreamStartStep)
|
|
||||||
assert isinstance(results[1], StreamTextStart)
|
|
||||||
assert isinstance(results[2], StreamTextDelta)
|
|
||||||
assert results[2].delta == "hello"
|
|
||||||
|
|
||||||
|
|
||||||
def test_empty_text_block_emits_only_step():
|
|
||||||
adapter = _adapter()
|
|
||||||
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
|
|
||||||
results = adapter.convert_message(msg)
|
|
||||||
# Empty text skipped, but step still opens
|
|
||||||
assert len(results) == 1
|
|
||||||
assert isinstance(results[0], StreamStartStep)
|
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_text_deltas_reuse_block_id():
|
|
||||||
adapter = _adapter()
|
|
||||||
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
|
|
||||||
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
|
|
||||||
r1 = adapter.convert_message(msg1)
|
|
||||||
r2 = adapter.convert_message(msg2)
|
|
||||||
# First gets step+start+delta, second only delta (block & step already started)
|
|
||||||
assert len(r1) == 3
|
|
||||||
assert isinstance(r1[0], StreamStartStep)
|
|
||||||
assert isinstance(r1[1], StreamTextStart)
|
|
||||||
assert len(r2) == 1
|
|
||||||
assert isinstance(r2[0], StreamTextDelta)
|
|
||||||
assert r1[1].id == r2[0].id # same block ID
|
|
||||||
|
|
||||||
|
|
||||||
# -- AssistantMessage with ToolUseBlock --------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_tool_use_emits_input_start_and_available():
|
|
||||||
"""Tool names arrive with MCP prefix and should be stripped for the frontend."""
|
|
||||||
adapter = _adapter()
|
|
||||||
msg = AssistantMessage(
|
|
||||||
content=[
|
|
||||||
ToolUseBlock(
|
|
||||||
id="tool-1",
|
|
||||||
name=f"{MCP_TOOL_PREFIX}find_agent",
|
|
||||||
input={"q": "x"},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
model="test",
|
|
||||||
)
|
|
||||||
results = adapter.convert_message(msg)
|
|
||||||
assert len(results) == 3
|
|
||||||
assert isinstance(results[0], StreamStartStep)
|
|
||||||
assert isinstance(results[1], StreamToolInputStart)
|
|
||||||
assert results[1].toolCallId == "tool-1"
|
|
||||||
assert results[1].toolName == "find_agent" # prefix stripped
|
|
||||||
assert isinstance(results[2], StreamToolInputAvailable)
|
|
||||||
assert results[2].toolName == "find_agent" # prefix stripped
|
|
||||||
assert results[2].input == {"q": "x"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_text_then_tool_ends_text_block():
|
|
||||||
adapter = _adapter()
|
|
||||||
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
|
|
||||||
tool_msg = AssistantMessage(
|
|
||||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
|
||||||
model="test",
|
|
||||||
)
|
|
||||||
adapter.convert_message(text_msg) # opens step + text
|
|
||||||
results = adapter.convert_message(tool_msg)
|
|
||||||
# Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
|
|
||||||
assert len(results) == 3
|
|
||||||
assert isinstance(results[0], StreamTextEnd)
|
|
||||||
assert isinstance(results[1], StreamToolInputStart)
|
|
||||||
|
|
||||||
|
|
||||||
# -- UserMessage with ToolResultBlock ----------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_tool_result_emits_output_and_finish_step():
|
|
||||||
adapter = _adapter()
|
|
||||||
# First register the tool call (opens step) — SDK sends prefixed name
|
|
||||||
tool_msg = AssistantMessage(
|
|
||||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
|
|
||||||
model="test",
|
|
||||||
)
|
|
||||||
adapter.convert_message(tool_msg)
|
|
||||||
|
|
||||||
# Now send tool result
|
|
||||||
result_msg = UserMessage(
|
|
||||||
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
|
|
||||||
)
|
|
||||||
results = adapter.convert_message(result_msg)
|
|
||||||
assert len(results) == 2
|
|
||||||
assert isinstance(results[0], StreamToolOutputAvailable)
|
|
||||||
assert results[0].toolCallId == "t1"
|
|
||||||
assert results[0].toolName == "find_agent" # prefix stripped
|
|
||||||
assert results[0].output == "found 3 agents"
|
|
||||||
assert results[0].success is True
|
|
||||||
assert isinstance(results[1], StreamFinishStep)
|
|
||||||
|
|
||||||
|
|
||||||
def test_tool_result_error():
|
|
||||||
adapter = _adapter()
|
|
||||||
adapter.convert_message(
|
|
||||||
AssistantMessage(
|
|
||||||
content=[
|
|
||||||
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
|
|
||||||
],
|
|
||||||
model="test",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result_msg = UserMessage(
|
|
||||||
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
|
|
||||||
)
|
|
||||||
results = adapter.convert_message(result_msg)
|
|
||||||
assert isinstance(results[0], StreamToolOutputAvailable)
|
|
||||||
assert results[0].success is False
|
|
||||||
assert isinstance(results[1], StreamFinishStep)
|
|
||||||
|
|
||||||
|
|
||||||
def test_tool_result_list_content():
|
|
||||||
adapter = _adapter()
|
|
||||||
adapter.convert_message(
|
|
||||||
AssistantMessage(
|
|
||||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
|
||||||
model="test",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result_msg = UserMessage(
|
|
||||||
content=[
|
|
||||||
ToolResultBlock(
|
|
||||||
tool_use_id="t1",
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "line1"},
|
|
||||||
{"type": "text", "text": "line2"},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
results = adapter.convert_message(result_msg)
|
|
||||||
assert isinstance(results[0], StreamToolOutputAvailable)
|
|
||||||
assert results[0].output == "line1line2"
|
|
||||||
assert isinstance(results[1], StreamFinishStep)
|
|
||||||
|
|
||||||
|
|
||||||
def test_string_user_message_ignored():
|
|
||||||
"""A plain string UserMessage (not tool results) produces no output."""
|
|
||||||
adapter = _adapter()
|
|
||||||
results = adapter.convert_message(UserMessage(content="hello"))
|
|
||||||
assert results == []
|
|
||||||
|
|
||||||
|
|
||||||
# -- ResultMessage -----------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_result_success_emits_finish_step_and_finish():
|
|
||||||
adapter = _adapter()
|
|
||||||
# Start some text first (opens step)
|
|
||||||
adapter.convert_message(
|
|
||||||
AssistantMessage(content=[TextBlock(text="done")], model="test")
|
|
||||||
)
|
|
||||||
msg = ResultMessage(
|
|
||||||
subtype="success",
|
|
||||||
duration_ms=100,
|
|
||||||
duration_api_ms=50,
|
|
||||||
is_error=False,
|
|
||||||
num_turns=1,
|
|
||||||
session_id="s1",
|
|
||||||
)
|
|
||||||
results = adapter.convert_message(msg)
|
|
||||||
# TextEnd + FinishStep + StreamFinish
|
|
||||||
assert len(results) == 3
|
|
||||||
assert isinstance(results[0], StreamTextEnd)
|
|
||||||
assert isinstance(results[1], StreamFinishStep)
|
|
||||||
assert isinstance(results[2], StreamFinish)
|
|
||||||
|
|
||||||
|
|
||||||
def test_result_error_emits_error_and_finish():
|
|
||||||
adapter = _adapter()
|
|
||||||
msg = ResultMessage(
|
|
||||||
subtype="error",
|
|
||||||
duration_ms=100,
|
|
||||||
duration_api_ms=50,
|
|
||||||
is_error=True,
|
|
||||||
num_turns=0,
|
|
||||||
session_id="s1",
|
|
||||||
result="API rate limited",
|
|
||||||
)
|
|
||||||
results = adapter.convert_message(msg)
|
|
||||||
# No step was open, so no FinishStep — just Error + Finish
|
|
||||||
assert len(results) == 2
|
|
||||||
assert isinstance(results[0], StreamError)
|
|
||||||
assert "API rate limited" in results[0].errorText
|
|
||||||
assert isinstance(results[1], StreamFinish)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Text after tools (new block ID) ----------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_text_after_tool_gets_new_block_id():
|
|
||||||
adapter = _adapter()
|
|
||||||
# Text -> Tool -> ToolResult -> Text should get a new text block ID and step
|
|
||||||
adapter.convert_message(
|
|
||||||
AssistantMessage(content=[TextBlock(text="before")], model="test")
|
|
||||||
)
|
|
||||||
adapter.convert_message(
|
|
||||||
AssistantMessage(
|
|
||||||
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
|
|
||||||
model="test",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# Send tool result (closes step)
|
|
||||||
adapter.convert_message(
|
|
||||||
UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
|
|
||||||
)
|
|
||||||
results = adapter.convert_message(
|
|
||||||
AssistantMessage(content=[TextBlock(text="after")], model="test")
|
|
||||||
)
|
|
||||||
# Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
|
|
||||||
assert len(results) == 3
|
|
||||||
assert isinstance(results[0], StreamStartStep)
|
|
||||||
assert isinstance(results[1], StreamTextStart)
|
|
||||||
assert isinstance(results[2], StreamTextDelta)
|
|
||||||
assert results[2].delta == "after"
|
|
||||||
|
|
||||||
|
|
||||||
# -- Full conversation flow --------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_full_conversation_flow():
|
|
||||||
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
|
|
||||||
adapter = _adapter()
|
|
||||||
all_responses: list[StreamBaseResponse] = []
|
|
||||||
|
|
||||||
# 1. Init
|
|
||||||
all_responses.extend(
|
|
||||||
adapter.convert_message(SystemMessage(subtype="init", data={}))
|
|
||||||
)
|
|
||||||
# 2. Assistant text
|
|
||||||
all_responses.extend(
|
|
||||||
adapter.convert_message(
|
|
||||||
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# 3. Tool use
|
|
||||||
all_responses.extend(
|
|
||||||
adapter.convert_message(
|
|
||||||
AssistantMessage(
|
|
||||||
content=[
|
|
||||||
ToolUseBlock(
|
|
||||||
id="t1",
|
|
||||||
name=f"{MCP_TOOL_PREFIX}find_agent",
|
|
||||||
input={"query": "email"},
|
|
||||||
)
|
|
||||||
],
|
|
||||||
model="test",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# 4. Tool result
|
|
||||||
all_responses.extend(
|
|
||||||
adapter.convert_message(
|
|
||||||
UserMessage(
|
|
||||||
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# 5. More text
|
|
||||||
all_responses.extend(
|
|
||||||
adapter.convert_message(
|
|
||||||
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
|
|
||||||
)
|
|
||||||
)
|
|
||||||
# 6. Result
|
|
||||||
all_responses.extend(
|
|
||||||
adapter.convert_message(
|
|
||||||
ResultMessage(
|
|
||||||
subtype="success",
|
|
||||||
duration_ms=500,
|
|
||||||
duration_api_ms=400,
|
|
||||||
is_error=False,
|
|
||||||
num_turns=2,
|
|
||||||
session_id="s1",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
types = [type(r).__name__ for r in all_responses]
|
|
||||||
assert types == [
|
|
||||||
"StreamStart",
|
|
||||||
"StreamStartStep", # step 1: text + tool call
|
|
||||||
"StreamTextStart",
|
|
||||||
"StreamTextDelta", # "Let me search"
|
|
||||||
"StreamTextEnd", # closed before tool
|
|
||||||
"StreamToolInputStart",
|
|
||||||
"StreamToolInputAvailable",
|
|
||||||
"StreamToolOutputAvailable", # tool result
|
|
||||||
"StreamFinishStep", # step 1 closed after tool result
|
|
||||||
"StreamStartStep", # step 2: continuation text
|
|
||||||
"StreamTextStart", # new block after tool
|
|
||||||
"StreamTextDelta", # "I found 2"
|
|
||||||
"StreamTextEnd", # closed by result
|
|
||||||
"StreamFinishStep", # step 2 closed
|
|
||||||
"StreamFinish",
|
|
||||||
]
|
|
||||||
@@ -1,335 +0,0 @@
|
|||||||
"""Security hooks for Claude Agent SDK integration.
|
|
||||||
|
|
||||||
This module provides security hooks that validate tool calls before execution,
|
|
||||||
ensuring multi-user isolation and preventing unauthorized operations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Tools that are blocked entirely (CLI/system access).
|
|
||||||
# "Bash" (capital) is the SDK built-in — it's NOT in allowed_tools but blocked
|
|
||||||
# here as defence-in-depth. The agent uses mcp__copilot__bash_exec instead,
|
|
||||||
# which has kernel-level network isolation (unshare --net).
|
|
||||||
BLOCKED_TOOLS = {
|
|
||||||
"Bash",
|
|
||||||
"bash",
|
|
||||||
"shell",
|
|
||||||
"exec",
|
|
||||||
"terminal",
|
|
||||||
"command",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Tools allowed only when their path argument stays within the SDK workspace.
|
|
||||||
# The SDK uses these to handle oversized tool results (writes to tool-results/
|
|
||||||
# files, then reads them back) and for workspace file operations.
|
|
||||||
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}
|
|
||||||
|
|
||||||
# Dangerous patterns in tool inputs
|
|
||||||
DANGEROUS_PATTERNS = [
|
|
||||||
r"sudo",
|
|
||||||
r"rm\s+-rf",
|
|
||||||
r"dd\s+if=",
|
|
||||||
r"/etc/passwd",
|
|
||||||
r"/etc/shadow",
|
|
||||||
r"chmod\s+777",
|
|
||||||
r"curl\s+.*\|.*sh",
|
|
||||||
r"wget\s+.*\|.*sh",
|
|
||||||
r"eval\s*\(",
|
|
||||||
r"exec\s*\(",
|
|
||||||
r"__import__",
|
|
||||||
r"os\.system",
|
|
||||||
r"subprocess",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _deny(reason: str) -> dict[str, Any]:
|
|
||||||
"""Return a hook denial response."""
|
|
||||||
return {
|
|
||||||
"hookSpecificOutput": {
|
|
||||||
"hookEventName": "PreToolUse",
|
|
||||||
"permissionDecision": "deny",
|
|
||||||
"permissionDecisionReason": reason,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_workspace_path(
|
|
||||||
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Validate that a workspace-scoped tool only accesses allowed paths.
|
|
||||||
|
|
||||||
Allowed directories:
|
|
||||||
- The SDK working directory (``/tmp/copilot-<session>/``)
|
|
||||||
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
|
|
||||||
"""
|
|
||||||
path = tool_input.get("file_path") or tool_input.get("path") or ""
|
|
||||||
if not path:
|
|
||||||
# Glob/Grep without a path default to cwd which is already sandboxed
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
|
|
||||||
# naturally uses relative paths like "test.txt" instead of absolute ones).
|
|
||||||
# Tilde paths (~/) are home-dir references, not relative — expand first.
|
|
||||||
if path.startswith("~"):
|
|
||||||
resolved = os.path.realpath(os.path.expanduser(path))
|
|
||||||
elif not os.path.isabs(path) and sdk_cwd:
|
|
||||||
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
|
|
||||||
else:
|
|
||||||
resolved = os.path.realpath(path)
|
|
||||||
|
|
||||||
# Allow access within the SDK working directory
|
|
||||||
if sdk_cwd:
|
|
||||||
norm_cwd = os.path.realpath(sdk_cwd)
|
|
||||||
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
|
|
||||||
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
|
|
||||||
tool_results_seg = os.sep + "tool-results" + os.sep
|
|
||||||
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
|
|
||||||
)
|
|
||||||
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
|
|
||||||
return _deny(
|
|
||||||
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
|
|
||||||
f"directory.{workspace_hint} "
|
|
||||||
"This is enforced by the platform and cannot be bypassed."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_tool_access(
|
|
||||||
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Validate that a tool call is allowed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Empty dict to allow, or dict with hookSpecificOutput to deny
|
|
||||||
"""
|
|
||||||
# Block forbidden tools
|
|
||||||
if tool_name in BLOCKED_TOOLS:
|
|
||||||
logger.warning(f"Blocked tool access attempt: {tool_name}")
|
|
||||||
return _deny(
|
|
||||||
f"[SECURITY] Tool '{tool_name}' is blocked for security. "
|
|
||||||
"This is enforced by the platform and cannot be bypassed. "
|
|
||||||
"Use the CoPilot-specific MCP tools instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Workspace-scoped tools: allowed only within the SDK workspace directory
|
|
||||||
if tool_name in WORKSPACE_SCOPED_TOOLS:
|
|
||||||
return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
|
|
||||||
|
|
||||||
# Check for dangerous patterns in tool input
|
|
||||||
# Use json.dumps for predictable format (str() produces Python repr)
|
|
||||||
input_str = json.dumps(tool_input) if tool_input else ""
|
|
||||||
|
|
||||||
for pattern in DANGEROUS_PATTERNS:
|
|
||||||
if re.search(pattern, input_str, re.IGNORECASE):
|
|
||||||
logger.warning(
|
|
||||||
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
|
|
||||||
)
|
|
||||||
return _deny(
|
|
||||||
"[SECURITY] Input contains a blocked pattern. "
|
|
||||||
"This is enforced by the platform and cannot be bypassed."
|
|
||||||
)
|
|
||||||
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_user_isolation(
|
|
||||||
tool_name: str, tool_input: dict[str, Any], user_id: str | None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Validate that tool calls respect user isolation."""
|
|
||||||
# For workspace file tools, ensure path doesn't escape
|
|
||||||
if "workspace" in tool_name.lower():
|
|
||||||
path = tool_input.get("path", "") or tool_input.get("file_path", "")
|
|
||||||
if path:
|
|
||||||
# Check for path traversal
|
|
||||||
if ".." in path or path.startswith("/"):
|
|
||||||
logger.warning(
|
|
||||||
f"Blocked path traversal attempt: {path} by user {user_id}"
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"hookSpecificOutput": {
|
|
||||||
"hookEventName": "PreToolUse",
|
|
||||||
"permissionDecision": "deny",
|
|
||||||
"permissionDecisionReason": "Path traversal not allowed",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def create_security_hooks(
|
|
||||||
user_id: str | None,
|
|
||||||
sdk_cwd: str | None = None,
|
|
||||||
max_subtasks: int = 3,
|
|
||||||
on_stop: Callable[[str, str], None] | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Create the security hooks configuration for Claude Agent SDK.
|
|
||||||
|
|
||||||
Includes security validation and observability hooks:
|
|
||||||
- PreToolUse: Security validation before tool execution
|
|
||||||
- PostToolUse: Log successful tool executions
|
|
||||||
- PostToolUseFailure: Log and handle failed tool executions
|
|
||||||
- PreCompact: Log context compaction events (SDK handles compaction automatically)
|
|
||||||
- Stop: Capture transcript path for stateless resume (when *on_stop* is provided)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Current user ID for isolation validation
|
|
||||||
sdk_cwd: SDK working directory for workspace-scoped tool validation
|
|
||||||
max_subtasks: Maximum Task (sub-agent) spawns allowed per session
|
|
||||||
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
|
|
||||||
the SDK finishes processing — used to read the JSONL transcript
|
|
||||||
before the CLI process exits.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Hooks configuration dict for ClaudeAgentOptions
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from claude_agent_sdk import HookMatcher
|
|
||||||
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
|
|
||||||
|
|
||||||
# Per-session counter for Task sub-agent spawns
|
|
||||||
task_spawn_count = 0
|
|
||||||
|
|
||||||
async def pre_tool_use_hook(
|
|
||||||
input_data: HookInput,
|
|
||||||
tool_use_id: str | None,
|
|
||||||
context: HookContext,
|
|
||||||
) -> SyncHookJSONOutput:
|
|
||||||
"""Combined pre-tool-use validation hook."""
|
|
||||||
nonlocal task_spawn_count
|
|
||||||
_ = context # unused but required by signature
|
|
||||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
|
||||||
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
|
|
||||||
|
|
||||||
# Rate-limit Task (sub-agent) spawns per session
|
|
||||||
if tool_name == "Task":
|
|
||||||
task_spawn_count += 1
|
|
||||||
if task_spawn_count > max_subtasks:
|
|
||||||
logger.warning(
|
|
||||||
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
|
|
||||||
)
|
|
||||||
return cast(
|
|
||||||
SyncHookJSONOutput,
|
|
||||||
_deny(
|
|
||||||
f"Maximum {max_subtasks} sub-tasks per session. "
|
|
||||||
"Please continue in the main conversation."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Strip MCP prefix for consistent validation
|
|
||||||
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
|
|
||||||
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
|
|
||||||
|
|
||||||
# Only block non-CoPilot tools; our MCP-registered tools
|
|
||||||
# (including Read for oversized results) are already sandboxed.
|
|
||||||
if not is_copilot_tool:
|
|
||||||
result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
|
|
||||||
if result:
|
|
||||||
return cast(SyncHookJSONOutput, result)
|
|
||||||
|
|
||||||
# Validate user isolation
|
|
||||||
result = _validate_user_isolation(clean_name, tool_input, user_id)
|
|
||||||
if result:
|
|
||||||
return cast(SyncHookJSONOutput, result)
|
|
||||||
|
|
||||||
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
|
|
||||||
return cast(SyncHookJSONOutput, {})
|
|
||||||
|
|
||||||
async def post_tool_use_hook(
|
|
||||||
input_data: HookInput,
|
|
||||||
tool_use_id: str | None,
|
|
||||||
context: HookContext,
|
|
||||||
) -> SyncHookJSONOutput:
|
|
||||||
"""Log successful tool executions for observability."""
|
|
||||||
_ = context
|
|
||||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
|
||||||
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
|
|
||||||
return cast(SyncHookJSONOutput, {})
|
|
||||||
|
|
||||||
async def post_tool_failure_hook(
|
|
||||||
input_data: HookInput,
|
|
||||||
tool_use_id: str | None,
|
|
||||||
context: HookContext,
|
|
||||||
) -> SyncHookJSONOutput:
|
|
||||||
"""Log failed tool executions for debugging."""
|
|
||||||
_ = context
|
|
||||||
tool_name = cast(str, input_data.get("tool_name", ""))
|
|
||||||
error = input_data.get("error", "Unknown error")
|
|
||||||
logger.warning(
|
|
||||||
f"[SDK] Tool failed: {tool_name}, error={error}, "
|
|
||||||
f"user={user_id}, tool_use_id={tool_use_id}"
|
|
||||||
)
|
|
||||||
return cast(SyncHookJSONOutput, {})
|
|
||||||
|
|
||||||
async def pre_compact_hook(
|
|
||||||
input_data: HookInput,
|
|
||||||
tool_use_id: str | None,
|
|
||||||
context: HookContext,
|
|
||||||
) -> SyncHookJSONOutput:
|
|
||||||
"""Log when SDK triggers context compaction.
|
|
||||||
|
|
||||||
The SDK automatically compacts conversation history when it grows too large.
|
|
||||||
This hook provides visibility into when compaction happens.
|
|
||||||
"""
|
|
||||||
_ = context, tool_use_id
|
|
||||||
trigger = input_data.get("trigger", "auto")
|
|
||||||
logger.info(
|
|
||||||
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
|
|
||||||
)
|
|
||||||
return cast(SyncHookJSONOutput, {})
|
|
||||||
|
|
||||||
# --- Stop hook: capture transcript path for stateless resume ---
|
|
||||||
async def stop_hook(
|
|
||||||
input_data: HookInput,
|
|
||||||
tool_use_id: str | None,
|
|
||||||
context: HookContext,
|
|
||||||
) -> SyncHookJSONOutput:
|
|
||||||
"""Capture transcript path when SDK finishes processing.
|
|
||||||
|
|
||||||
The Stop hook fires while the CLI process is still alive, giving us
|
|
||||||
a reliable window to read the JSONL transcript before SIGTERM.
|
|
||||||
"""
|
|
||||||
_ = context, tool_use_id
|
|
||||||
transcript_path = cast(str, input_data.get("transcript_path", ""))
|
|
||||||
sdk_session_id = cast(str, input_data.get("session_id", ""))
|
|
||||||
|
|
||||||
if transcript_path and on_stop:
|
|
||||||
logger.info(
|
|
||||||
f"[SDK] Stop hook: transcript_path={transcript_path}, "
|
|
||||||
f"sdk_session_id={sdk_session_id[:12]}..."
|
|
||||||
)
|
|
||||||
on_stop(transcript_path, sdk_session_id)
|
|
||||||
|
|
||||||
return cast(SyncHookJSONOutput, {})
|
|
||||||
|
|
||||||
hooks: dict[str, Any] = {
|
|
||||||
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
|
|
||||||
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
|
|
||||||
"PostToolUseFailure": [
|
|
||||||
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
|
|
||||||
],
|
|
||||||
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
|
|
||||||
}
|
|
||||||
|
|
||||||
if on_stop is not None:
|
|
||||||
hooks["Stop"] = [HookMatcher(matcher=None, hooks=[stop_hook])]
|
|
||||||
|
|
||||||
return hooks
|
|
||||||
except ImportError:
|
|
||||||
# Fallback for when SDK isn't available - return empty hooks
|
|
||||||
logger.warning("claude-agent-sdk not available, security hooks disabled")
|
|
||||||
return {}
|
|
||||||
@@ -1,165 +0,0 @@
|
|||||||
"""Unit tests for SDK security hooks."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from .security_hooks import _validate_tool_access, _validate_user_isolation
|
|
||||||
|
|
||||||
SDK_CWD = "/tmp/copilot-abc123"
|
|
||||||
|
|
||||||
|
|
||||||
def _is_denied(result: dict) -> bool:
|
|
||||||
hook = result.get("hookSpecificOutput", {})
|
|
||||||
return hook.get("permissionDecision") == "deny"
|
|
||||||
|
|
||||||
|
|
||||||
# -- Blocked tools -----------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_blocked_tools_denied():
|
|
||||||
for tool in ("bash", "shell", "exec", "terminal", "command"):
|
|
||||||
result = _validate_tool_access(tool, {})
|
|
||||||
assert _is_denied(result), f"{tool} should be blocked"
|
|
||||||
|
|
||||||
|
|
||||||
def test_unknown_tool_allowed():
|
|
||||||
result = _validate_tool_access("SomeCustomTool", {})
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
# -- Workspace-scoped tools --------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_within_workspace_allowed():
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
|
|
||||||
)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_write_within_workspace_allowed():
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
|
|
||||||
)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_edit_within_workspace_allowed():
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
|
|
||||||
)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_glob_within_workspace_allowed():
|
|
||||||
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_grep_within_workspace_allowed():
|
|
||||||
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_outside_workspace_denied():
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_write_outside_workspace_denied():
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_traversal_attack_denied():
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Read",
|
|
||||||
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
|
|
||||||
sdk_cwd=SDK_CWD,
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_no_path_allowed():
|
|
||||||
"""Glob/Grep without a path argument defaults to cwd — should pass."""
|
|
||||||
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_no_cwd_denies_absolute():
|
|
||||||
"""If no sdk_cwd is set, absolute paths are denied."""
|
|
||||||
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Tool-results directory --------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_tool_results_allowed():
|
|
||||||
home = os.path.expanduser("~")
|
|
||||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
|
||||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_read_claude_projects_without_tool_results_denied():
|
|
||||||
home = os.path.expanduser("~")
|
|
||||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
|
||||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_bash_builtin_always_blocked():
|
|
||||||
"""SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead."""
|
|
||||||
result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
# -- Dangerous patterns ------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_dangerous_pattern_blocked():
|
|
||||||
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_subprocess_pattern_blocked():
|
|
||||||
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
# -- User isolation ----------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def test_workspace_path_traversal_blocked():
|
|
||||||
result = _validate_user_isolation(
|
|
||||||
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_workspace_absolute_path_blocked():
|
|
||||||
result = _validate_user_isolation(
|
|
||||||
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
def test_workspace_normal_path_allowed():
|
|
||||||
result = _validate_user_isolation(
|
|
||||||
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
|
|
||||||
)
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
|
|
||||||
def test_non_workspace_tool_passes_isolation():
|
|
||||||
result = _validate_user_isolation(
|
|
||||||
"find_agent", {"query": "email"}, user_id="user-1"
|
|
||||||
)
|
|
||||||
assert result == {}
|
|
||||||
@@ -1,751 +0,0 @@
|
|||||||
"""Claude Agent SDK service layer for CoPilot chat completions."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.util.exceptions import NotFoundError
|
|
||||||
|
|
||||||
from .. import stream_registry
|
|
||||||
from ..config import ChatConfig
|
|
||||||
from ..model import (
|
|
||||||
ChatMessage,
|
|
||||||
ChatSession,
|
|
||||||
get_chat_session,
|
|
||||||
update_session_title,
|
|
||||||
upsert_chat_session,
|
|
||||||
)
|
|
||||||
from ..response_model import (
|
|
||||||
StreamBaseResponse,
|
|
||||||
StreamError,
|
|
||||||
StreamFinish,
|
|
||||||
StreamStart,
|
|
||||||
StreamTextDelta,
|
|
||||||
StreamToolInputAvailable,
|
|
||||||
StreamToolOutputAvailable,
|
|
||||||
)
|
|
||||||
from ..service import (
|
|
||||||
_build_system_prompt,
|
|
||||||
_execute_long_running_tool_with_streaming,
|
|
||||||
_generate_session_title,
|
|
||||||
)
|
|
||||||
from ..tools.models import OperationPendingResponse, OperationStartedResponse
|
|
||||||
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
|
|
||||||
from ..tracking import track_user_message
|
|
||||||
from .response_adapter import SDKResponseAdapter
|
|
||||||
from .security_hooks import create_security_hooks
|
|
||||||
from .tool_adapter import (
|
|
||||||
COPILOT_TOOL_NAMES,
|
|
||||||
LongRunningCallback,
|
|
||||||
create_copilot_mcp_server,
|
|
||||||
set_execution_context,
|
|
||||||
)
|
|
||||||
from .transcript import (
|
|
||||||
download_transcript,
|
|
||||||
read_transcript_file,
|
|
||||||
upload_transcript,
|
|
||||||
validate_transcript,
|
|
||||||
write_transcript_to_tempfile,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
config = ChatConfig()
|
|
||||||
|
|
||||||
# Set to hold background tasks to prevent garbage collection
|
|
||||||
_background_tasks: set[asyncio.Task[Any]] = set()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CapturedTranscript:
|
|
||||||
"""Info captured by the SDK Stop hook for stateless --resume."""
|
|
||||||
|
|
||||||
path: str = ""
|
|
||||||
sdk_session_id: str = ""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def available(self) -> bool:
|
|
||||||
return bool(self.path)
|
|
||||||
|
|
||||||
|
|
||||||
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
|
|
||||||
|
|
||||||
# Appended to the system prompt to inform the agent about available tools.
|
|
||||||
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
|
|
||||||
# which has kernel-level network isolation (unshare --net).
|
|
||||||
_SDK_TOOL_SUPPLEMENT = """
|
|
||||||
|
|
||||||
## Tool notes
|
|
||||||
|
|
||||||
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
|
|
||||||
for shell commands — it runs in a network-isolated sandbox.
|
|
||||||
- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
|
|
||||||
same working directory. Files created by one are readable by the other.
|
|
||||||
These files are **ephemeral** — they exist only for the current session.
|
|
||||||
- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
|
|
||||||
for files that should persist across sessions (stored in cloud storage).
|
|
||||||
- Long-running tools (create_agent, edit_agent, etc.) are handled
|
|
||||||
asynchronously. You will receive an immediate response; the actual result
|
|
||||||
is delivered to the user via a background stream.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
|
|
||||||
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
|
|
||||||
|
|
||||||
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
|
|
||||||
existing background infrastructure: stream_registry (Redis Streams),
|
|
||||||
database persistence, and SSE reconnection. This means results survive
|
|
||||||
page refreshes / pod restarts, and the frontend shows the proper loading
|
|
||||||
widget with progress updates.
|
|
||||||
|
|
||||||
The returned callback matches the ``LongRunningCallback`` signature:
|
|
||||||
``(tool_name, args, session) -> MCP response dict``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def _callback(
|
|
||||||
tool_name: str, args: dict[str, Any], session: ChatSession
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
operation_id = str(uuid.uuid4())
|
|
||||||
task_id = str(uuid.uuid4())
|
|
||||||
tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
|
||||||
session_id = session.session_id
|
|
||||||
|
|
||||||
# --- Build user-friendly messages (matches non-SDK service) ---
|
|
||||||
if tool_name == "create_agent":
|
|
||||||
desc = args.get("description", "")
|
|
||||||
desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
|
|
||||||
pending_msg = (
|
|
||||||
f"Creating your agent: {desc_preview}"
|
|
||||||
if desc_preview
|
|
||||||
else "Creating agent... This may take a few minutes."
|
|
||||||
)
|
|
||||||
started_msg = (
|
|
||||||
"Agent creation started. You can close this tab - "
|
|
||||||
"check your library in a few minutes."
|
|
||||||
)
|
|
||||||
elif tool_name == "edit_agent":
|
|
||||||
changes = args.get("changes", "")
|
|
||||||
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
|
|
||||||
pending_msg = (
|
|
||||||
f"Editing agent: {changes_preview}"
|
|
||||||
if changes_preview
|
|
||||||
else "Editing agent... This may take a few minutes."
|
|
||||||
)
|
|
||||||
started_msg = (
|
|
||||||
"Agent edit started. You can close this tab - "
|
|
||||||
"check your library in a few minutes."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
pending_msg = f"Running {tool_name}... This may take a few minutes."
|
|
||||||
started_msg = (
|
|
||||||
f"{tool_name} started. You can close this tab - "
|
|
||||||
"check back in a few minutes."
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Register task in Redis for SSE reconnection ---
|
|
||||||
await stream_registry.create_task(
|
|
||||||
task_id=task_id,
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=user_id,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
tool_name=tool_name,
|
|
||||||
operation_id=operation_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Save OperationPendingResponse to chat history ---
|
|
||||||
pending_message = ChatMessage(
|
|
||||||
role="tool",
|
|
||||||
content=OperationPendingResponse(
|
|
||||||
message=pending_msg,
|
|
||||||
operation_id=operation_id,
|
|
||||||
tool_name=tool_name,
|
|
||||||
).model_dump_json(),
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
)
|
|
||||||
session.messages.append(pending_message)
|
|
||||||
await upsert_chat_session(session)
|
|
||||||
|
|
||||||
# --- Spawn background task (reuses non-SDK infrastructure) ---
|
|
||||||
bg_task = asyncio.create_task(
|
|
||||||
_execute_long_running_tool_with_streaming(
|
|
||||||
tool_name=tool_name,
|
|
||||||
parameters=args,
|
|
||||||
tool_call_id=tool_call_id,
|
|
||||||
operation_id=operation_id,
|
|
||||||
task_id=task_id,
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
_background_tasks.add(bg_task)
|
|
||||||
bg_task.add_done_callback(_background_tasks.discard)
|
|
||||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[SDK] Long-running tool {tool_name} delegated to background "
|
|
||||||
f"(operation_id={operation_id}, task_id={task_id})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Return OperationStartedResponse as MCP tool result ---
|
|
||||||
# This flows through SDK → response adapter → frontend, triggering
|
|
||||||
# the loading widget with SSE reconnection support.
|
|
||||||
started_json = OperationStartedResponse(
|
|
||||||
message=started_msg,
|
|
||||||
operation_id=operation_id,
|
|
||||||
tool_name=tool_name,
|
|
||||||
task_id=task_id,
|
|
||||||
).model_dump_json()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": started_json}],
|
|
||||||
"isError": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
return _callback
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_sdk_model() -> str | None:
|
|
||||||
"""Resolve the model name for the Claude Agent SDK CLI.
|
|
||||||
|
|
||||||
Uses ``config.claude_agent_model`` if set, otherwise derives from
|
|
||||||
``config.model`` by stripping the OpenRouter provider prefix (e.g.,
|
|
||||||
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
|
|
||||||
"""
|
|
||||||
if config.claude_agent_model:
|
|
||||||
return config.claude_agent_model
|
|
||||||
model = config.model
|
|
||||||
if "/" in model:
|
|
||||||
return model.split("/", 1)[1]
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def _build_sdk_env() -> dict[str, str]:
|
|
||||||
"""Build env vars for the SDK CLI process.
|
|
||||||
|
|
||||||
Routes API calls through OpenRouter (or a custom base_url) using
|
|
||||||
the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
|
|
||||||
This gives per-call token and cost tracking on the OpenRouter dashboard.
|
|
||||||
|
|
||||||
Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
|
|
||||||
token are both present — otherwise returns an empty dict so the SDK
|
|
||||||
falls back to its default credentials.
|
|
||||||
"""
|
|
||||||
env: dict[str, str] = {}
|
|
||||||
if config.api_key and config.base_url:
|
|
||||||
# Strip /v1 suffix — SDK expects the base URL without a version path
|
|
||||||
base = config.base_url.rstrip("/")
|
|
||||||
if base.endswith("/v1"):
|
|
||||||
base = base[:-3]
|
|
||||||
if not base or not base.startswith("http"):
|
|
||||||
# Invalid base_url — don't override SDK defaults
|
|
||||||
return env
|
|
||||||
env["ANTHROPIC_BASE_URL"] = base
|
|
||||||
env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
|
|
||||||
# Must be explicitly empty so the CLI uses AUTH_TOKEN instead
|
|
||||||
env["ANTHROPIC_API_KEY"] = ""
|
|
||||||
return env
|
|
||||||
|
|
||||||
|
|
||||||
def _make_sdk_cwd(session_id: str) -> str:
|
|
||||||
"""Create a safe, session-specific working directory path.
|
|
||||||
|
|
||||||
Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
|
|
||||||
(single source of truth for path sanitization) and adds a defence-in-depth
|
|
||||||
assertion.
|
|
||||||
"""
|
|
||||||
cwd = make_session_path(session_id)
|
|
||||||
# Defence-in-depth: normpath + startswith is a CodeQL-recognised sanitizer
|
|
||||||
cwd = os.path.normpath(cwd)
|
|
||||||
if not cwd.startswith(_SDK_CWD_PREFIX):
|
|
||||||
raise ValueError(f"SDK cwd escaped prefix: {cwd}")
|
|
||||||
return cwd
|
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_sdk_tool_results(cwd: str) -> None:
|
|
||||||
"""Remove SDK tool-result files for a specific session working directory.
|
|
||||||
|
|
||||||
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
|
|
||||||
We clean only the specific cwd's results to avoid race conditions between
|
|
||||||
concurrent sessions.
|
|
||||||
|
|
||||||
Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
|
|
||||||
"""
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
# Validate cwd is under the expected prefix
|
|
||||||
normalized = os.path.normpath(cwd)
|
|
||||||
if not normalized.startswith(_SDK_CWD_PREFIX):
|
|
||||||
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# SDK encodes the cwd path by replacing '/' with '-'
|
|
||||||
encoded_cwd = normalized.replace("/", "-")
|
|
||||||
|
|
||||||
# Construct the project directory path (known-safe home expansion)
|
|
||||||
claude_projects = os.path.expanduser("~/.claude/projects")
|
|
||||||
project_dir = os.path.join(claude_projects, encoded_cwd)
|
|
||||||
|
|
||||||
# Security check 3: Validate project_dir is under ~/.claude/projects
|
|
||||||
project_dir = os.path.normpath(project_dir)
|
|
||||||
if not project_dir.startswith(claude_projects):
|
|
||||||
logger.warning(
|
|
||||||
f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
results_dir = os.path.join(project_dir, "tool-results")
|
|
||||||
if os.path.isdir(results_dir):
|
|
||||||
for filename in os.listdir(results_dir):
|
|
||||||
file_path = os.path.join(results_dir, filename)
|
|
||||||
try:
|
|
||||||
if os.path.isfile(file_path):
|
|
||||||
os.remove(file_path)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Also clean up the temp cwd directory itself
|
|
||||||
try:
|
|
||||||
shutil.rmtree(normalized, ignore_errors=True)
|
|
||||||
except OSError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def _compress_conversation_history(
|
|
||||||
session: ChatSession,
|
|
||||||
) -> list[ChatMessage]:
|
|
||||||
"""Compress prior conversation messages if they exceed the token threshold.
|
|
||||||
|
|
||||||
Uses the shared compress_context() from prompt.py which supports:
|
|
||||||
- LLM summarization of old messages (keeps recent ones intact)
|
|
||||||
- Progressive content truncation as fallback
|
|
||||||
- Middle-out deletion as last resort
|
|
||||||
|
|
||||||
Returns the compressed prior messages (everything except the current message).
|
|
||||||
"""
|
|
||||||
prior = session.messages[:-1]
|
|
||||||
if len(prior) < 2:
|
|
||||||
return prior
|
|
||||||
|
|
||||||
from backend.util.prompt import compress_context
|
|
||||||
|
|
||||||
# Convert ChatMessages to dicts for compress_context
|
|
||||||
messages_dict = []
|
|
||||||
for msg in prior:
|
|
||||||
msg_dict: dict[str, Any] = {"role": msg.role}
|
|
||||||
if msg.content:
|
|
||||||
msg_dict["content"] = msg.content
|
|
||||||
if msg.tool_calls:
|
|
||||||
msg_dict["tool_calls"] = msg.tool_calls
|
|
||||||
if msg.tool_call_id:
|
|
||||||
msg_dict["tool_call_id"] = msg.tool_call_id
|
|
||||||
messages_dict.append(msg_dict)
|
|
||||||
|
|
||||||
try:
|
|
||||||
import openai
|
|
||||||
|
|
||||||
async with openai.AsyncOpenAI(
|
|
||||||
api_key=config.api_key, base_url=config.base_url, timeout=30.0
|
|
||||||
) as client:
|
|
||||||
result = await compress_context(
|
|
||||||
messages=messages_dict,
|
|
||||||
model=config.model,
|
|
||||||
client=client,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
|
|
||||||
# Fall back to truncation-only (no LLM summarization)
|
|
||||||
result = await compress_context(
|
|
||||||
messages=messages_dict,
|
|
||||||
model=config.model,
|
|
||||||
client=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.was_compacted:
|
|
||||||
logger.info(
|
|
||||||
f"[SDK] Context compacted: {result.original_token_count} -> "
|
|
||||||
f"{result.token_count} tokens "
|
|
||||||
f"({result.messages_summarized} summarized, "
|
|
||||||
f"{result.messages_dropped} dropped)"
|
|
||||||
)
|
|
||||||
# Convert compressed dicts back to ChatMessages
|
|
||||||
return [
|
|
||||||
ChatMessage(
|
|
||||||
role=m["role"],
|
|
||||||
content=m.get("content"),
|
|
||||||
tool_calls=m.get("tool_calls"),
|
|
||||||
tool_call_id=m.get("tool_call_id"),
|
|
||||||
)
|
|
||||||
for m in result.messages
|
|
||||||
]
|
|
||||||
|
|
||||||
return prior
|
|
||||||
|
|
||||||
|
|
||||||
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
|
|
||||||
"""Format conversation messages into a context prefix for the user message.
|
|
||||||
|
|
||||||
Returns a string like:
|
|
||||||
<conversation_history>
|
|
||||||
User: hello
|
|
||||||
You responded: Hi! How can I help?
|
|
||||||
</conversation_history>
|
|
||||||
|
|
||||||
Returns None if there are no messages to format.
|
|
||||||
"""
|
|
||||||
if not messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
lines: list[str] = []
|
|
||||||
for msg in messages:
|
|
||||||
if not msg.content:
|
|
||||||
continue
|
|
||||||
if msg.role == "user":
|
|
||||||
lines.append(f"User: {msg.content}")
|
|
||||||
elif msg.role == "assistant":
|
|
||||||
lines.append(f"You responded: {msg.content}")
|
|
||||||
# Skip tool messages — they're internal details
|
|
||||||
|
|
||||||
if not lines:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
|
|
||||||
|
|
||||||
|
|
||||||
async def stream_chat_completion_sdk(
|
|
||||||
session_id: str,
|
|
||||||
message: str | None = None,
|
|
||||||
tool_call_response: str | None = None, # noqa: ARG001
|
|
||||||
is_user_message: bool = True,
|
|
||||||
user_id: str | None = None,
|
|
||||||
retry_count: int = 0, # noqa: ARG001
|
|
||||||
session: ChatSession | None = None,
|
|
||||||
context: dict[str, str] | None = None, # noqa: ARG001
|
|
||||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
|
||||||
"""Stream chat completion using Claude Agent SDK.
|
|
||||||
|
|
||||||
Drop-in replacement for stream_chat_completion with improved reliability.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if session is None:
|
|
||||||
session = await get_chat_session(session_id, user_id)
|
|
||||||
|
|
||||||
if not session:
|
|
||||||
raise NotFoundError(
|
|
||||||
f"Session {session_id} not found. Please create a new session first."
|
|
||||||
)
|
|
||||||
|
|
||||||
if message:
|
|
||||||
session.messages.append(
|
|
||||||
ChatMessage(
|
|
||||||
role="user" if is_user_message else "assistant", content=message
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if is_user_message:
|
|
||||||
track_user_message(
|
|
||||||
user_id=user_id, session_id=session_id, message_length=len(message)
|
|
||||||
)
|
|
||||||
|
|
||||||
session = await upsert_chat_session(session)
|
|
||||||
|
|
||||||
# Generate title for new sessions (first user message)
|
|
||||||
if is_user_message and not session.title:
|
|
||||||
user_messages = [m for m in session.messages if m.role == "user"]
|
|
||||||
if len(user_messages) == 1:
|
|
||||||
first_message = user_messages[0].content or message or ""
|
|
||||||
if first_message:
|
|
||||||
task = asyncio.create_task(
|
|
||||||
_update_title_async(session_id, first_message, user_id)
|
|
||||||
)
|
|
||||||
_background_tasks.add(task)
|
|
||||||
task.add_done_callback(_background_tasks.discard)
|
|
||||||
|
|
||||||
# Build system prompt (reuses non-SDK path with Langfuse support)
|
|
||||||
has_history = len(session.messages) > 1
|
|
||||||
system_prompt, _ = await _build_system_prompt(
|
|
||||||
user_id, has_conversation_history=has_history
|
|
||||||
)
|
|
||||||
system_prompt += _SDK_TOOL_SUPPLEMENT
|
|
||||||
message_id = str(uuid.uuid4())
|
|
||||||
task_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
yield StreamStart(messageId=message_id, taskId=task_id)
|
|
||||||
|
|
||||||
stream_completed = False
|
|
||||||
# Initialise sdk_cwd before the try so the finally can reference it
|
|
||||||
# even if _make_sdk_cwd raises (in that case it stays as "").
|
|
||||||
sdk_cwd = ""
|
|
||||||
use_resume = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Use a session-specific temp dir to avoid cleanup race conditions
|
|
||||||
# between concurrent sessions.
|
|
||||||
sdk_cwd = _make_sdk_cwd(session_id)
|
|
||||||
os.makedirs(sdk_cwd, exist_ok=True)
|
|
||||||
|
|
||||||
set_execution_context(
|
|
||||||
user_id,
|
|
||||||
session,
|
|
||||||
long_running_callback=_build_long_running_callback(user_id),
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
|
||||||
|
|
||||||
# Fail fast when no API credentials are available at all
|
|
||||||
sdk_env = _build_sdk_env()
|
|
||||||
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
|
|
||||||
raise RuntimeError(
|
|
||||||
"No API key configured. Set OPEN_ROUTER_API_KEY "
|
|
||||||
"(or CHAT_API_KEY) for OpenRouter routing, "
|
|
||||||
"or ANTHROPIC_API_KEY for direct Anthropic access."
|
|
||||||
)
|
|
||||||
|
|
||||||
mcp_server = create_copilot_mcp_server()
|
|
||||||
|
|
||||||
sdk_model = _resolve_sdk_model()
|
|
||||||
|
|
||||||
# --- Transcript capture via Stop hook ---
|
|
||||||
captured_transcript = CapturedTranscript()
|
|
||||||
|
|
||||||
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
|
|
||||||
captured_transcript.path = transcript_path
|
|
||||||
captured_transcript.sdk_session_id = sdk_session_id
|
|
||||||
|
|
||||||
security_hooks = create_security_hooks(
|
|
||||||
user_id,
|
|
||||||
sdk_cwd=sdk_cwd,
|
|
||||||
max_subtasks=config.claude_agent_max_subtasks,
|
|
||||||
on_stop=_on_stop if config.claude_agent_use_resume else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Resume strategy: download transcript from bucket ---
|
|
||||||
resume_file: str | None = None
|
|
||||||
use_resume = False
|
|
||||||
|
|
||||||
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
|
|
||||||
transcript_content = await download_transcript(user_id, session_id)
|
|
||||||
if transcript_content and validate_transcript(transcript_content):
|
|
||||||
resume_file = write_transcript_to_tempfile(
|
|
||||||
transcript_content, session_id, sdk_cwd
|
|
||||||
)
|
|
||||||
if resume_file:
|
|
||||||
use_resume = True
|
|
||||||
logger.info(
|
|
||||||
f"[SDK] Using --resume with transcript "
|
|
||||||
f"({len(transcript_content)} bytes)"
|
|
||||||
)
|
|
||||||
|
|
||||||
sdk_options_kwargs: dict[str, Any] = {
|
|
||||||
"system_prompt": system_prompt,
|
|
||||||
"mcp_servers": {"copilot": mcp_server},
|
|
||||||
"allowed_tools": COPILOT_TOOL_NAMES,
|
|
||||||
"disallowed_tools": ["Bash"],
|
|
||||||
"hooks": security_hooks,
|
|
||||||
"cwd": sdk_cwd,
|
|
||||||
"max_buffer_size": config.claude_agent_max_buffer_size,
|
|
||||||
}
|
|
||||||
if sdk_env:
|
|
||||||
sdk_options_kwargs["model"] = sdk_model
|
|
||||||
sdk_options_kwargs["env"] = sdk_env
|
|
||||||
if use_resume and resume_file:
|
|
||||||
sdk_options_kwargs["resume"] = resume_file
|
|
||||||
|
|
||||||
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
|
|
||||||
|
|
||||||
adapter = SDKResponseAdapter(message_id=message_id)
|
|
||||||
adapter.set_task_id(task_id)
|
|
||||||
|
|
||||||
async with ClaudeSDKClient(options=options) as client:
|
|
||||||
current_message = message or ""
|
|
||||||
if not current_message and session.messages:
|
|
||||||
last_user = [m for m in session.messages if m.role == "user"]
|
|
||||||
if last_user:
|
|
||||||
current_message = last_user[-1].content or ""
|
|
||||||
|
|
||||||
if not current_message.strip():
|
|
||||||
yield StreamError(
|
|
||||||
errorText="Message cannot be empty.",
|
|
||||||
code="empty_prompt",
|
|
||||||
)
|
|
||||||
yield StreamFinish()
|
|
||||||
return
|
|
||||||
|
|
||||||
# Build query: with --resume the CLI already has full
|
|
||||||
# context, so we only send the new message. Without
|
|
||||||
# resume, compress history into a context prefix.
|
|
||||||
query_message = current_message
|
|
||||||
if not use_resume and len(session.messages) > 1:
|
|
||||||
logger.warning(
|
|
||||||
f"[SDK] Using compression fallback for session "
|
|
||||||
f"{session_id} ({len(session.messages)} messages) — "
|
|
||||||
f"no transcript available for --resume"
|
|
||||||
)
|
|
||||||
compressed = await _compress_conversation_history(session)
|
|
||||||
history_context = _format_conversation_context(compressed)
|
|
||||||
if history_context:
|
|
||||||
query_message = (
|
|
||||||
f"{history_context}\n\n"
|
|
||||||
f"Now, the user says:\n{current_message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[SDK] Sending query ({len(session.messages)} msgs in session)"
|
|
||||||
)
|
|
||||||
logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
|
|
||||||
await client.query(query_message, session_id=session_id)
|
|
||||||
|
|
||||||
assistant_response = ChatMessage(role="assistant", content="")
|
|
||||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
|
||||||
has_appended_assistant = False
|
|
||||||
has_tool_results = False
|
|
||||||
|
|
||||||
async for sdk_msg in client.receive_messages():
|
|
||||||
logger.debug(
|
|
||||||
f"[SDK] Received: {type(sdk_msg).__name__} "
|
|
||||||
f"{getattr(sdk_msg, 'subtype', '')}"
|
|
||||||
)
|
|
||||||
for response in adapter.convert_message(sdk_msg):
|
|
||||||
if isinstance(response, StreamStart):
|
|
||||||
continue
|
|
||||||
|
|
||||||
yield response
|
|
||||||
|
|
||||||
if isinstance(response, StreamTextDelta):
|
|
||||||
delta = response.delta or ""
|
|
||||||
# After tool results, start a new assistant
|
|
||||||
# message for the post-tool text.
|
|
||||||
if has_tool_results and has_appended_assistant:
|
|
||||||
assistant_response = ChatMessage(
|
|
||||||
role="assistant", content=delta
|
|
||||||
)
|
|
||||||
accumulated_tool_calls = []
|
|
||||||
has_appended_assistant = False
|
|
||||||
has_tool_results = False
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
has_appended_assistant = True
|
|
||||||
else:
|
|
||||||
assistant_response.content = (
|
|
||||||
assistant_response.content or ""
|
|
||||||
) + delta
|
|
||||||
if not has_appended_assistant:
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
has_appended_assistant = True
|
|
||||||
|
|
||||||
elif isinstance(response, StreamToolInputAvailable):
|
|
||||||
accumulated_tool_calls.append(
|
|
||||||
{
|
|
||||||
"id": response.toolCallId,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": response.toolName,
|
|
||||||
"arguments": json.dumps(response.input or {}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
assistant_response.tool_calls = accumulated_tool_calls
|
|
||||||
if not has_appended_assistant:
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
has_appended_assistant = True
|
|
||||||
|
|
||||||
elif isinstance(response, StreamToolOutputAvailable):
|
|
||||||
session.messages.append(
|
|
||||||
ChatMessage(
|
|
||||||
role="tool",
|
|
||||||
content=(
|
|
||||||
response.output
|
|
||||||
if isinstance(response.output, str)
|
|
||||||
else str(response.output)
|
|
||||||
),
|
|
||||||
tool_call_id=response.toolCallId,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
has_tool_results = True
|
|
||||||
|
|
||||||
elif isinstance(response, StreamFinish):
|
|
||||||
stream_completed = True
|
|
||||||
|
|
||||||
if stream_completed:
|
|
||||||
break
|
|
||||||
|
|
||||||
if (
|
|
||||||
assistant_response.content or assistant_response.tool_calls
|
|
||||||
) and not has_appended_assistant:
|
|
||||||
session.messages.append(assistant_response)
|
|
||||||
|
|
||||||
# --- Capture transcript while CLI is still alive ---
|
|
||||||
# Must happen INSIDE async with: close() sends SIGTERM
|
|
||||||
# which kills the CLI before it can flush the JSONL.
|
|
||||||
if (
|
|
||||||
config.claude_agent_use_resume
|
|
||||||
and user_id
|
|
||||||
and captured_transcript.available
|
|
||||||
):
|
|
||||||
# Give CLI time to flush JSONL writes before we read
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
raw_transcript = read_transcript_file(captured_transcript.path)
|
|
||||||
if raw_transcript:
|
|
||||||
task = asyncio.create_task(
|
|
||||||
_upload_transcript_bg(user_id, session_id, raw_transcript)
|
|
||||||
)
|
|
||||||
_background_tasks.add(task)
|
|
||||||
task.add_done_callback(_background_tasks.discard)
|
|
||||||
else:
|
|
||||||
logger.debug("[SDK] Stop hook fired but transcript not usable")
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
raise RuntimeError(
|
|
||||||
"claude-agent-sdk is not installed. "
|
|
||||||
"Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) "
|
|
||||||
"to use the OpenAI-compatible fallback."
|
|
||||||
)
|
|
||||||
|
|
||||||
await upsert_chat_session(session)
|
|
||||||
logger.debug(
|
|
||||||
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
|
|
||||||
)
|
|
||||||
if not stream_completed:
|
|
||||||
yield StreamFinish()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[SDK] Error: {e}", exc_info=True)
|
|
||||||
try:
|
|
||||||
await upsert_chat_session(session)
|
|
||||||
except Exception as save_err:
|
|
||||||
logger.error(f"[SDK] Failed to save session on error: {save_err}")
|
|
||||||
yield StreamError(
|
|
||||||
errorText="An error occurred. Please try again.",
|
|
||||||
code="sdk_error",
|
|
||||||
)
|
|
||||||
yield StreamFinish()
|
|
||||||
finally:
|
|
||||||
if sdk_cwd:
|
|
||||||
_cleanup_sdk_tool_results(sdk_cwd)
|
|
||||||
|
|
||||||
|
|
||||||
async def _upload_transcript_bg(
|
|
||||||
user_id: str, session_id: str, raw_content: str
|
|
||||||
) -> None:
|
|
||||||
"""Background task to strip progress entries and upload transcript."""
|
|
||||||
try:
|
|
||||||
await upload_transcript(user_id, session_id, raw_content)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def _update_title_async(
|
|
||||||
session_id: str, message: str, user_id: str | None = None
|
|
||||||
) -> None:
|
|
||||||
"""Background task to update session title."""
|
|
||||||
try:
|
|
||||||
title = await _generate_session_title(
|
|
||||||
message, user_id=user_id, session_id=session_id
|
|
||||||
)
|
|
||||||
if title:
|
|
||||||
await update_session_title(session_id, title)
|
|
||||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[SDK] Failed to update session title: {e}")
|
|
||||||
@@ -1,322 +0,0 @@
|
|||||||
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
|
|
||||||
|
|
||||||
This module provides the adapter layer that converts existing BaseTool implementations
|
|
||||||
into in-process MCP tools that can be used with the Claude Agent SDK.
|
|
||||||
|
|
||||||
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
|
|
||||||
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
|
|
||||||
via a callback provided by the service layer. This avoids wasteful SDK polling
|
|
||||||
and makes results survive page refreshes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import itertools
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from contextvars import ContextVar
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.api.features.chat.tools import TOOL_REGISTRY
|
|
||||||
from backend.api.features.chat.tools.base import BaseTool
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
|
|
||||||
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
|
|
||||||
# in the path — prevents reading settings, credentials, or other sensitive files.
|
|
||||||
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
|
|
||||||
|
|
||||||
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
|
|
||||||
MCP_SERVER_NAME = "copilot"
|
|
||||||
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
|
|
||||||
|
|
||||||
# Context variables to pass user/session info to tool execution
|
|
||||||
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
|
|
||||||
_current_session: ContextVar[ChatSession | None] = ContextVar(
|
|
||||||
"current_session", default=None
|
|
||||||
)
|
|
||||||
# Stash for MCP tool outputs before the SDK potentially truncates them.
|
|
||||||
# Keyed by tool_name → full output string. Consumed (popped) by the
|
|
||||||
# response adapter when it builds StreamToolOutputAvailable.
|
|
||||||
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
|
|
||||||
"pending_tool_outputs", default=None # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Callback type for delegating long-running tools to the non-SDK infrastructure.
|
|
||||||
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
|
|
||||||
LongRunningCallback = Callable[
|
|
||||||
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
|
|
||||||
]
|
|
||||||
|
|
||||||
# ContextVar so the service layer can inject the callback per-request.
|
|
||||||
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
|
|
||||||
"long_running_callback", default=None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def set_execution_context(
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
long_running_callback: LongRunningCallback | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Set the execution context for tool calls.
|
|
||||||
|
|
||||||
This must be called before streaming begins to ensure tools have access
|
|
||||||
to user_id and session information.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Current user's ID.
|
|
||||||
session: Current chat session.
|
|
||||||
long_running_callback: Optional callback to delegate long-running tools
|
|
||||||
to the non-SDK background infrastructure (stream_registry + Redis).
|
|
||||||
"""
|
|
||||||
_current_user_id.set(user_id)
|
|
||||||
_current_session.set(session)
|
|
||||||
_pending_tool_outputs.set({})
|
|
||||||
_long_running_callback.set(long_running_callback)
|
|
||||||
|
|
||||||
|
|
||||||
def get_execution_context() -> tuple[str | None, ChatSession | None]:
|
|
||||||
"""Get the current execution context."""
|
|
||||||
return (
|
|
||||||
_current_user_id.get(),
|
|
||||||
_current_session.get(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def pop_pending_tool_output(tool_name: str) -> str | None:
|
|
||||||
"""Pop and return the stashed full output for *tool_name*.
|
|
||||||
|
|
||||||
The SDK CLI may truncate large tool results (writing them to disk and
|
|
||||||
replacing the content with a file reference). This stash keeps the
|
|
||||||
original MCP output so the response adapter can forward it to the
|
|
||||||
frontend for proper widget rendering.
|
|
||||||
|
|
||||||
Returns ``None`` if nothing was stashed for *tool_name*.
|
|
||||||
"""
|
|
||||||
pending = _pending_tool_outputs.get(None)
|
|
||||||
if pending is None:
|
|
||||||
return None
|
|
||||||
return pending.pop(tool_name, None)
|
|
||||||
|
|
||||||
|
|
||||||
async def _execute_tool_sync(
|
|
||||||
base_tool: BaseTool,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
args: dict[str, Any],
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Execute a tool synchronously and return MCP-formatted response."""
|
|
||||||
effective_id = f"sdk-{uuid.uuid4().hex[:12]}"
|
|
||||||
result = await base_tool.execute(
|
|
||||||
user_id=user_id,
|
|
||||||
session=session,
|
|
||||||
tool_call_id=effective_id,
|
|
||||||
**args,
|
|
||||||
)
|
|
||||||
|
|
||||||
text = (
|
|
||||||
result.output if isinstance(result.output, str) else json.dumps(result.output)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Stash the full output before the SDK potentially truncates it.
|
|
||||||
pending = _pending_tool_outputs.get(None)
|
|
||||||
if pending is not None:
|
|
||||||
pending[base_tool.name] = text
|
|
||||||
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": text}],
|
|
||||||
"isError": not result.success,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _mcp_error(message: str) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": json.dumps({"error": message, "type": "error"})}
|
|
||||||
],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_tool_handler(base_tool: BaseTool):
|
|
||||||
"""Create an async handler function for a BaseTool.
|
|
||||||
|
|
||||||
This wraps the existing BaseTool._execute method to be compatible
|
|
||||||
with the Claude Agent SDK MCP tool format.
|
|
||||||
|
|
||||||
Long-running tools (``is_long_running=True``) are delegated to the
|
|
||||||
non-SDK background infrastructure via a callback set in the execution
|
|
||||||
context. The callback persists the operation in Redis (stream_registry)
|
|
||||||
so results survive page refreshes and pod restarts.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Execute the wrapped tool and return MCP-formatted response."""
|
|
||||||
user_id, session = get_execution_context()
|
|
||||||
|
|
||||||
if session is None:
|
|
||||||
return _mcp_error("No session context available")
|
|
||||||
|
|
||||||
# --- Long-running: delegate to non-SDK background infrastructure ---
|
|
||||||
if base_tool.is_long_running:
|
|
||||||
callback = _long_running_callback.get(None)
|
|
||||||
if callback:
|
|
||||||
try:
|
|
||||||
return await callback(base_tool.name, args, session)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Long-running callback failed for {base_tool.name}: {e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
|
|
||||||
# No callback — fall through to synchronous execution
|
|
||||||
logger.warning(
|
|
||||||
f"[SDK] No long-running callback for {base_tool.name}, "
|
|
||||||
f"executing synchronously (may block)"
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Normal (fast) tool: execute synchronously ---
|
|
||||||
try:
|
|
||||||
return await _execute_tool_sync(base_tool, user_id, session, args)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
|
|
||||||
return _mcp_error(f"Failed to execute {base_tool.name}: {e}")
|
|
||||||
|
|
||||||
return tool_handler
|
|
||||||
|
|
||||||
|
|
||||||
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
|
|
||||||
"""Build a JSON Schema input schema for a tool."""
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": base_tool.parameters.get("properties", {}),
|
|
||||||
"required": base_tool.parameters.get("required", []),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Read a file with optional offset/limit. Restricted to SDK working directory.
|
|
||||||
|
|
||||||
After reading, the file is deleted to prevent accumulation in long-running pods.
|
|
||||||
"""
|
|
||||||
file_path = args.get("file_path", "")
|
|
||||||
offset = args.get("offset", 0)
|
|
||||||
limit = args.get("limit", 2000)
|
|
||||||
|
|
||||||
# Security: only allow reads under ~/.claude/projects/**/tool-results/
|
|
||||||
real_path = os.path.realpath(file_path)
|
|
||||||
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(real_path) as f:
|
|
||||||
selected = list(itertools.islice(f, offset, offset + limit))
|
|
||||||
content = "".join(selected)
|
|
||||||
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
|
|
||||||
# don't delete here — the SDK may read in multiple chunks.
|
|
||||||
return {"content": [{"type": "text", "text": content}], "isError": False}
|
|
||||||
except FileNotFoundError:
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
_READ_TOOL_NAME = "Read"
|
|
||||||
_READ_TOOL_DESCRIPTION = (
|
|
||||||
"Read a file from the local filesystem. "
|
|
||||||
"Use offset and limit to read specific line ranges for large files."
|
|
||||||
)
|
|
||||||
_READ_TOOL_SCHEMA = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"file_path": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The absolute path to the file to read",
|
|
||||||
},
|
|
||||||
"offset": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Line number to start reading from (0-indexed). Default: 0",
|
|
||||||
},
|
|
||||||
"limit": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": "Number of lines to read. Default: 2000",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["file_path"],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Create the MCP server configuration
|
|
||||||
def create_copilot_mcp_server():
|
|
||||||
"""Create an in-process MCP server configuration for CoPilot tools.
|
|
||||||
|
|
||||||
This can be passed to ClaudeAgentOptions.mcp_servers.
|
|
||||||
|
|
||||||
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
|
|
||||||
package being available. This function returns the configuration that
|
|
||||||
can be used with the SDK.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
|
||||||
|
|
||||||
# Create decorated tool functions
|
|
||||||
sdk_tools = []
|
|
||||||
|
|
||||||
for tool_name, base_tool in TOOL_REGISTRY.items():
|
|
||||||
handler = create_tool_handler(base_tool)
|
|
||||||
decorated = tool(
|
|
||||||
tool_name,
|
|
||||||
base_tool.description,
|
|
||||||
_build_input_schema(base_tool),
|
|
||||||
)(handler)
|
|
||||||
sdk_tools.append(decorated)
|
|
||||||
|
|
||||||
# Add the Read tool so the SDK can read back oversized tool results
|
|
||||||
read_tool = tool(
|
|
||||||
_READ_TOOL_NAME,
|
|
||||||
_READ_TOOL_DESCRIPTION,
|
|
||||||
_READ_TOOL_SCHEMA,
|
|
||||||
)(_read_file_handler)
|
|
||||||
sdk_tools.append(read_tool)
|
|
||||||
|
|
||||||
server = create_sdk_mcp_server(
|
|
||||||
name=MCP_SERVER_NAME,
|
|
||||||
version="1.0.0",
|
|
||||||
tools=sdk_tools,
|
|
||||||
)
|
|
||||||
|
|
||||||
return server
|
|
||||||
|
|
||||||
except ImportError:
|
|
||||||
# Let ImportError propagate so service.py handles the fallback
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
# SDK built-in tools allowed within the workspace directory.
|
|
||||||
# Security hooks validate that file paths stay within sdk_cwd.
|
|
||||||
# Bash is NOT included — use the sandboxed MCP bash_exec tool instead,
|
|
||||||
# which provides kernel-level network isolation via unshare --net.
|
|
||||||
# Task allows spawning sub-agents (rate-limited by security hooks).
|
|
||||||
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task"]
|
|
||||||
|
|
||||||
# List of tool names for allowed_tools configuration
|
|
||||||
# Include MCP tools, the MCP Read tool for oversized results,
|
|
||||||
# and SDK built-in file tools for workspace operations.
|
|
||||||
COPILOT_TOOL_NAMES = [
|
|
||||||
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
|
|
||||||
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
|
|
||||||
*_SDK_BUILTIN_TOOLS,
|
|
||||||
]
|
|
||||||
@@ -1,356 +0,0 @@
|
|||||||
"""JSONL transcript management for stateless multi-turn resume.
|
|
||||||
|
|
||||||
The Claude Code CLI persists conversations as JSONL files (one JSON object per
|
|
||||||
line). When the SDK's ``Stop`` hook fires we read this file, strip bloat
|
|
||||||
(progress entries, metadata), and upload the result to bucket storage. On the
|
|
||||||
next turn we download the transcript, write it to a temp file, and pass
|
|
||||||
``--resume`` so the CLI can reconstruct the full conversation.
|
|
||||||
|
|
||||||
Storage is handled via ``WorkspaceStorageBackend`` (GCS in prod, local
|
|
||||||
filesystem for self-hosted) — no DB column needed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
|
||||||
_SAFE_ID_RE = re.compile(r"[^0-9a-fA-F-]")
|
|
||||||
|
|
||||||
# Entry types that can be safely removed from the transcript without breaking
|
|
||||||
# the parentUuid conversation tree that ``--resume`` relies on.
|
|
||||||
# - progress: UI progress ticks, no message content (avg 97KB for agent_progress)
|
|
||||||
# - file-history-snapshot: undo tracking metadata
|
|
||||||
# - queue-operation: internal queue bookkeeping
|
|
||||||
# - summary: session summaries
|
|
||||||
# - pr-link: PR link metadata
|
|
||||||
STRIPPABLE_TYPES = frozenset(
|
|
||||||
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Workspace storage constants — deterministic path from session_id.
|
|
||||||
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Progress stripping
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def strip_progress_entries(content: str) -> str:
|
|
||||||
"""Remove progress/metadata entries from a JSONL transcript.
|
|
||||||
|
|
||||||
Removes entries whose ``type`` is in ``STRIPPABLE_TYPES`` and reparents
|
|
||||||
any remaining child entries so the ``parentUuid`` chain stays intact.
|
|
||||||
Typically reduces transcript size by ~30%.
|
|
||||||
"""
|
|
||||||
lines = content.strip().split("\n")
|
|
||||||
|
|
||||||
entries: list[dict] = []
|
|
||||||
for line in lines:
|
|
||||||
try:
|
|
||||||
entries.append(json.loads(line))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# Keep unparseable lines as-is (safety)
|
|
||||||
entries.append({"_raw": line})
|
|
||||||
|
|
||||||
stripped_uuids: set[str] = set()
|
|
||||||
uuid_to_parent: dict[str, str] = {}
|
|
||||||
kept: list[dict] = []
|
|
||||||
|
|
||||||
for entry in entries:
|
|
||||||
if "_raw" in entry:
|
|
||||||
kept.append(entry)
|
|
||||||
continue
|
|
||||||
uid = entry.get("uuid", "")
|
|
||||||
parent = entry.get("parentUuid", "")
|
|
||||||
entry_type = entry.get("type", "")
|
|
||||||
|
|
||||||
if uid:
|
|
||||||
uuid_to_parent[uid] = parent
|
|
||||||
|
|
||||||
if entry_type in STRIPPABLE_TYPES:
|
|
||||||
if uid:
|
|
||||||
stripped_uuids.add(uid)
|
|
||||||
else:
|
|
||||||
kept.append(entry)
|
|
||||||
|
|
||||||
# Reparent: walk up chain through stripped entries to find surviving ancestor
|
|
||||||
for entry in kept:
|
|
||||||
if "_raw" in entry:
|
|
||||||
continue
|
|
||||||
parent = entry.get("parentUuid", "")
|
|
||||||
original_parent = parent
|
|
||||||
while parent in stripped_uuids:
|
|
||||||
parent = uuid_to_parent.get(parent, "")
|
|
||||||
if parent != original_parent:
|
|
||||||
entry["parentUuid"] = parent
|
|
||||||
|
|
||||||
result_lines: list[str] = []
|
|
||||||
for entry in kept:
|
|
||||||
if "_raw" in entry:
|
|
||||||
result_lines.append(entry["_raw"])
|
|
||||||
else:
|
|
||||||
result_lines.append(json.dumps(entry, separators=(",", ":")))
|
|
||||||
|
|
||||||
return "\n".join(result_lines) + "\n"
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Local file I/O (read from CLI's JSONL, write temp file for --resume)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def read_transcript_file(transcript_path: str) -> str | None:
|
|
||||||
"""Read a JSONL transcript file from disk.
|
|
||||||
|
|
||||||
Returns the raw JSONL content, or ``None`` if the file is missing, empty,
|
|
||||||
or only contains metadata (≤2 lines with no conversation messages).
|
|
||||||
"""
|
|
||||||
if not transcript_path or not os.path.isfile(transcript_path):
|
|
||||||
logger.debug(f"[Transcript] File not found: {transcript_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(transcript_path) as f:
|
|
||||||
content = f.read()
|
|
||||||
|
|
||||||
if not content.strip():
|
|
||||||
logger.debug(f"[Transcript] Empty file: {transcript_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
lines = content.strip().split("\n")
|
|
||||||
if len(lines) < 3:
|
|
||||||
# Raw files with ≤2 lines are metadata-only
|
|
||||||
# (queue-operation + file-history-snapshot, no conversation).
|
|
||||||
logger.debug(
|
|
||||||
f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Quick structural validation — parse first and last lines.
|
|
||||||
json.loads(lines[0])
|
|
||||||
json.loads(lines[-1])
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[Transcript] Read {len(lines)} lines, "
|
|
||||||
f"{len(content)} bytes from {transcript_path}"
|
|
||||||
)
|
|
||||||
return content
|
|
||||||
|
|
||||||
except (json.JSONDecodeError, OSError) as e:
|
|
||||||
logger.warning(f"[Transcript] Failed to read {transcript_path}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
|
|
||||||
"""Sanitize an ID for safe use in file paths.
|
|
||||||
|
|
||||||
Session/user IDs are expected to be UUIDs (hex + hyphens). Strip
|
|
||||||
everything else and truncate to *max_len* so the result cannot introduce
|
|
||||||
path separators or other special characters.
|
|
||||||
"""
|
|
||||||
cleaned = _SAFE_ID_RE.sub("", raw_id or "")[:max_len]
|
|
||||||
return cleaned or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
|
|
||||||
|
|
||||||
|
|
||||||
def write_transcript_to_tempfile(
|
|
||||||
transcript_content: str,
|
|
||||||
session_id: str,
|
|
||||||
cwd: str,
|
|
||||||
) -> str | None:
|
|
||||||
"""Write JSONL transcript to a temp file inside *cwd* for ``--resume``.
|
|
||||||
|
|
||||||
The file lives in the session working directory so it is cleaned up
|
|
||||||
automatically when the session ends.
|
|
||||||
|
|
||||||
Returns the absolute path to the file, or ``None`` on failure.
|
|
||||||
"""
|
|
||||||
# Validate cwd is under the expected sandbox prefix (CodeQL sanitizer).
|
|
||||||
real_cwd = os.path.realpath(cwd)
|
|
||||||
if not real_cwd.startswith(_SAFE_CWD_PREFIX):
|
|
||||||
logger.warning(f"[Transcript] cwd outside sandbox: {cwd}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
os.makedirs(real_cwd, exist_ok=True)
|
|
||||||
safe_id = _sanitize_id(session_id, max_len=8)
|
|
||||||
jsonl_path = os.path.realpath(
|
|
||||||
os.path.join(real_cwd, f"transcript-{safe_id}.jsonl")
|
|
||||||
)
|
|
||||||
if not jsonl_path.startswith(real_cwd):
|
|
||||||
logger.warning(f"[Transcript] Path escaped cwd: {jsonl_path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
with open(jsonl_path, "w") as f:
|
|
||||||
f.write(transcript_content)
|
|
||||||
|
|
||||||
logger.info(f"[Transcript] Wrote resume file: {jsonl_path}")
|
|
||||||
return jsonl_path
|
|
||||||
|
|
||||||
except OSError as e:
|
|
||||||
logger.warning(f"[Transcript] Failed to write resume file: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def validate_transcript(content: str | None) -> bool:
|
|
||||||
"""Check that a transcript has actual conversation messages.
|
|
||||||
|
|
||||||
A valid transcript for resume needs at least one user message and one
|
|
||||||
assistant message (not just queue-operation / file-history-snapshot
|
|
||||||
metadata).
|
|
||||||
"""
|
|
||||||
if not content or not content.strip():
|
|
||||||
return False
|
|
||||||
|
|
||||||
lines = content.strip().split("\n")
|
|
||||||
if len(lines) < 2:
|
|
||||||
return False
|
|
||||||
|
|
||||||
has_user = False
|
|
||||||
has_assistant = False
|
|
||||||
|
|
||||||
for line in lines:
|
|
||||||
try:
|
|
||||||
entry = json.loads(line)
|
|
||||||
msg_type = entry.get("type")
|
|
||||||
if msg_type == "user":
|
|
||||||
has_user = True
|
|
||||||
elif msg_type == "assistant":
|
|
||||||
has_assistant = True
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return has_user and has_assistant
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Bucket storage (GCS / local via WorkspaceStorageBackend)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
|
|
||||||
"""Return (workspace_id, file_id, filename) for a session's transcript.
|
|
||||||
|
|
||||||
Path structure: ``chat-transcripts/{user_id}/{session_id}.jsonl``
|
|
||||||
IDs are sanitized to hex+hyphen to prevent path traversal.
|
|
||||||
"""
|
|
||||||
return (
|
|
||||||
TRANSCRIPT_STORAGE_PREFIX,
|
|
||||||
_sanitize_id(user_id),
|
|
||||||
f"{_sanitize_id(session_id)}.jsonl",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
|
|
||||||
"""Build the full storage path string that ``retrieve()`` expects.
|
|
||||||
|
|
||||||
``store()`` returns a path like ``gcs://bucket/workspaces/...`` or
|
|
||||||
``local://workspace_id/file_id/filename``. Since we use deterministic
|
|
||||||
arguments we can reconstruct the same path for download/delete without
|
|
||||||
having stored the return value.
|
|
||||||
"""
|
|
||||||
from backend.util.workspace_storage import GCSWorkspaceStorage
|
|
||||||
|
|
||||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
|
||||||
|
|
||||||
if isinstance(backend, GCSWorkspaceStorage):
|
|
||||||
blob = f"workspaces/{wid}/{fid}/{fname}"
|
|
||||||
return f"gcs://{backend.bucket_name}/{blob}"
|
|
||||||
else:
|
|
||||||
# LocalWorkspaceStorage returns local://{relative_path}
|
|
||||||
return f"local://{wid}/{fid}/{fname}"
|
|
||||||
|
|
||||||
|
|
||||||
async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
|
|
||||||
"""Strip progress entries and upload transcript to bucket storage.
|
|
||||||
|
|
||||||
Safety: only overwrites when the new (stripped) transcript is larger than
|
|
||||||
what is already stored. Since JSONL is append-only, the latest transcript
|
|
||||||
is always the longest. This prevents a slow/stale background task from
|
|
||||||
clobbering a newer upload from a concurrent turn.
|
|
||||||
"""
|
|
||||||
from backend.util.workspace_storage import get_workspace_storage
|
|
||||||
|
|
||||||
stripped = strip_progress_entries(content)
|
|
||||||
if not validate_transcript(stripped):
|
|
||||||
logger.warning(
|
|
||||||
f"[Transcript] Skipping upload — stripped content is not a valid "
|
|
||||||
f"transcript for session {session_id}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
storage = await get_workspace_storage()
|
|
||||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
|
||||||
encoded = stripped.encode("utf-8")
|
|
||||||
new_size = len(encoded)
|
|
||||||
|
|
||||||
# Check existing transcript size to avoid overwriting newer with older
|
|
||||||
path = _build_storage_path(user_id, session_id, storage)
|
|
||||||
try:
|
|
||||||
existing = await storage.retrieve(path)
|
|
||||||
if len(existing) >= new_size:
|
|
||||||
logger.info(
|
|
||||||
f"[Transcript] Skipping upload — existing transcript "
|
|
||||||
f"({len(existing)}B) >= new ({new_size}B) for session "
|
|
||||||
f"{session_id}"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
except (FileNotFoundError, Exception):
|
|
||||||
pass # No existing transcript or retrieval error — proceed with upload
|
|
||||||
|
|
||||||
await storage.store(
|
|
||||||
workspace_id=wid,
|
|
||||||
file_id=fid,
|
|
||||||
filename=fname,
|
|
||||||
content=encoded,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"[Transcript] Uploaded {new_size} bytes "
|
|
||||||
f"(stripped from {len(content)}) for session {session_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def download_transcript(user_id: str, session_id: str) -> str | None:
|
|
||||||
"""Download transcript from bucket storage.
|
|
||||||
|
|
||||||
Returns the JSONL content string, or ``None`` if not found.
|
|
||||||
"""
|
|
||||||
from backend.util.workspace_storage import get_workspace_storage
|
|
||||||
|
|
||||||
storage = await get_workspace_storage()
|
|
||||||
path = _build_storage_path(user_id, session_id, storage)
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = await storage.retrieve(path)
|
|
||||||
content = data.decode("utf-8")
|
|
||||||
logger.info(
|
|
||||||
f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
|
|
||||||
)
|
|
||||||
return content
|
|
||||||
except FileNotFoundError:
|
|
||||||
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[Transcript] Failed to download transcript: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def delete_transcript(user_id: str, session_id: str) -> None:
|
|
||||||
"""Delete transcript from bucket storage (e.g. after resume failure)."""
|
|
||||||
from backend.util.workspace_storage import get_workspace_storage
|
|
||||||
|
|
||||||
storage = await get_workspace_storage()
|
|
||||||
path = _build_storage_path(user_id, session_id, storage)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await storage.delete(path)
|
|
||||||
logger.info(f"[Transcript] Deleted transcript for session {session_id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[Transcript] Failed to delete transcript: {e}")
|
|
||||||
@@ -245,16 +245,12 @@ async def _get_system_prompt_template(context: str) -> str:
|
|||||||
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
|
return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
|
||||||
|
|
||||||
|
|
||||||
async def _build_system_prompt(
|
async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
|
||||||
user_id: str | None, has_conversation_history: bool = False
|
|
||||||
) -> tuple[str, Any]:
|
|
||||||
"""Build the full system prompt including business understanding if available.
|
"""Build the full system prompt including business understanding if available.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: The user ID for fetching business understanding.
|
user_id: The user ID for fetching business understanding
|
||||||
has_conversation_history: Whether there's existing conversation history.
|
If "default" and this is the user's first session, will use "onboarding" instead.
|
||||||
If True, we don't tell the model to greet/introduce (since they're
|
|
||||||
already in a conversation).
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (compiled prompt string, business understanding object)
|
Tuple of (compiled prompt string, business understanding object)
|
||||||
@@ -270,8 +266,6 @@ async def _build_system_prompt(
|
|||||||
|
|
||||||
if understanding:
|
if understanding:
|
||||||
context = format_understanding_for_prompt(understanding)
|
context = format_understanding_for_prompt(understanding)
|
||||||
elif has_conversation_history:
|
|
||||||
context = "No prior understanding saved yet. Continue the existing conversation naturally."
|
|
||||||
else:
|
else:
|
||||||
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
|
||||||
|
|
||||||
@@ -380,6 +374,7 @@ async def stream_chat_completion(
|
|||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotFoundError: If session_id is invalid
|
NotFoundError: If session_id is invalid
|
||||||
|
ValueError: If max_context_messages is exceeded
|
||||||
|
|
||||||
"""
|
"""
|
||||||
completion_start = time.monotonic()
|
completion_start = time.monotonic()
|
||||||
@@ -464,9 +459,8 @@ async def stream_chat_completion(
|
|||||||
|
|
||||||
# Generate title for new sessions on first user message (non-blocking)
|
# Generate title for new sessions on first user message (non-blocking)
|
||||||
# Check: is_user_message, no title yet, and this is the first user message
|
# Check: is_user_message, no title yet, and this is the first user message
|
||||||
|
if is_user_message and message and not session.title:
|
||||||
user_messages = [m for m in session.messages if m.role == "user"]
|
user_messages = [m for m in session.messages if m.role == "user"]
|
||||||
first_user_msg = message or (user_messages[0].content if user_messages else None)
|
|
||||||
if is_user_message and first_user_msg and not session.title:
|
|
||||||
if len(user_messages) == 1:
|
if len(user_messages) == 1:
|
||||||
# First user message - generate title in background
|
# First user message - generate title in background
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -474,7 +468,7 @@ async def stream_chat_completion(
|
|||||||
# Capture only the values we need (not the session object) to avoid
|
# Capture only the values we need (not the session object) to avoid
|
||||||
# stale data issues when the main flow modifies the session
|
# stale data issues when the main flow modifies the session
|
||||||
captured_session_id = session_id
|
captured_session_id = session_id
|
||||||
captured_message = first_user_msg
|
captured_message = message
|
||||||
captured_user_id = user_id
|
captured_user_id = user_id
|
||||||
|
|
||||||
async def _update_title():
|
async def _update_title():
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
from os import getenv
|
from os import getenv
|
||||||
|
|
||||||
@@ -12,8 +11,6 @@ from .response_model import (
|
|||||||
StreamTextDelta,
|
StreamTextDelta,
|
||||||
StreamToolOutputAvailable,
|
StreamToolOutputAvailable,
|
||||||
)
|
)
|
||||||
from .sdk import service as sdk_service
|
|
||||||
from .sdk.transcript import download_transcript
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -83,96 +80,3 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
|
|||||||
session = await get_chat_session(session.session_id)
|
session = await get_chat_session(session.session_id)
|
||||||
assert session, "Session not found"
|
assert session, "Session not found"
|
||||||
assert session.usage, "Usage is empty"
|
assert session.usage, "Usage is empty"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
|
||||||
"""Test that the SDK --resume path captures and uses transcripts across turns.
|
|
||||||
|
|
||||||
Turn 1: Send a message containing a unique keyword.
|
|
||||||
Turn 2: Ask the model to recall that keyword — proving the transcript was
|
|
||||||
persisted and restored via --resume.
|
|
||||||
"""
|
|
||||||
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
|
|
||||||
if not api_key:
|
|
||||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
|
||||||
|
|
||||||
from .config import ChatConfig
|
|
||||||
|
|
||||||
cfg = ChatConfig()
|
|
||||||
if not cfg.claude_agent_use_resume:
|
|
||||||
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
|
|
||||||
|
|
||||||
session = await create_chat_session(test_user_id)
|
|
||||||
session = await upsert_chat_session(session)
|
|
||||||
|
|
||||||
# --- Turn 1: send a message with a unique keyword ---
|
|
||||||
keyword = "ZEPHYR42"
|
|
||||||
turn1_msg = (
|
|
||||||
f"Please remember this special keyword: {keyword}. "
|
|
||||||
"Just confirm you've noted it, keep your response brief."
|
|
||||||
)
|
|
||||||
turn1_text = ""
|
|
||||||
turn1_errors: list[str] = []
|
|
||||||
turn1_ended = False
|
|
||||||
|
|
||||||
async for chunk in sdk_service.stream_chat_completion_sdk(
|
|
||||||
session.session_id,
|
|
||||||
turn1_msg,
|
|
||||||
user_id=test_user_id,
|
|
||||||
):
|
|
||||||
if isinstance(chunk, StreamTextDelta):
|
|
||||||
turn1_text += chunk.delta
|
|
||||||
elif isinstance(chunk, StreamError):
|
|
||||||
turn1_errors.append(chunk.errorText)
|
|
||||||
elif isinstance(chunk, StreamFinish):
|
|
||||||
turn1_ended = True
|
|
||||||
|
|
||||||
assert turn1_ended, "Turn 1 did not finish"
|
|
||||||
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
|
|
||||||
assert turn1_text, "Turn 1 produced no text"
|
|
||||||
|
|
||||||
# Wait for background upload task to complete (retry up to 5s)
|
|
||||||
transcript = None
|
|
||||||
for _ in range(10):
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
transcript = await download_transcript(test_user_id, session.session_id)
|
|
||||||
if transcript:
|
|
||||||
break
|
|
||||||
assert transcript, (
|
|
||||||
"Transcript was not uploaded to bucket after turn 1 — "
|
|
||||||
"Stop hook may not have fired or transcript was too small"
|
|
||||||
)
|
|
||||||
logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")
|
|
||||||
|
|
||||||
# Reload session for turn 2
|
|
||||||
session = await get_chat_session(session.session_id, test_user_id)
|
|
||||||
assert session, "Session not found after turn 1"
|
|
||||||
|
|
||||||
# --- Turn 2: ask model to recall the keyword ---
|
|
||||||
turn2_msg = "What was the special keyword I asked you to remember?"
|
|
||||||
turn2_text = ""
|
|
||||||
turn2_errors: list[str] = []
|
|
||||||
turn2_ended = False
|
|
||||||
|
|
||||||
async for chunk in sdk_service.stream_chat_completion_sdk(
|
|
||||||
session.session_id,
|
|
||||||
turn2_msg,
|
|
||||||
user_id=test_user_id,
|
|
||||||
session=session,
|
|
||||||
):
|
|
||||||
if isinstance(chunk, StreamTextDelta):
|
|
||||||
turn2_text += chunk.delta
|
|
||||||
elif isinstance(chunk, StreamError):
|
|
||||||
turn2_errors.append(chunk.errorText)
|
|
||||||
elif isinstance(chunk, StreamFinish):
|
|
||||||
turn2_ended = True
|
|
||||||
|
|
||||||
assert turn2_ended, "Turn 2 did not finish"
|
|
||||||
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
|
|
||||||
assert turn2_text, "Turn 2 produced no text"
|
|
||||||
assert keyword in turn2_text, (
|
|
||||||
f"Model did not recall keyword '{keyword}' in turn 2. "
|
|
||||||
f"Response: {turn2_text[:200]}"
|
|
||||||
)
|
|
||||||
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")
|
|
||||||
|
|||||||
@@ -814,28 +814,6 @@ async def get_active_task_for_session(
|
|||||||
if task_user_id and user_id != task_user_id:
|
if task_user_id and user_id != task_user_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Auto-expire stale tasks that exceeded stream_timeout
|
|
||||||
created_at_str = meta.get("created_at", "")
|
|
||||||
if created_at_str:
|
|
||||||
try:
|
|
||||||
created_at = datetime.fromisoformat(created_at_str)
|
|
||||||
age_seconds = (
|
|
||||||
datetime.now(timezone.utc) - created_at
|
|
||||||
).total_seconds()
|
|
||||||
if age_seconds > config.stream_timeout:
|
|
||||||
logger.warning(
|
|
||||||
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
|
|
||||||
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
|
|
||||||
)
|
|
||||||
await mark_task_completed(task_id, "failed")
|
|
||||||
continue
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get the last message ID from Redis Stream
|
# Get the last message ID from Redis Stream
|
||||||
stream_key = _get_task_stream_key(task_id)
|
stream_key = _get_task_stream_key(task_id)
|
||||||
last_id = "0-0"
|
last_id = "0-0"
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ from backend.api.features.chat.tracking import track_tool_called
|
|||||||
from .add_understanding import AddUnderstandingTool
|
from .add_understanding import AddUnderstandingTool
|
||||||
from .agent_output import AgentOutputTool
|
from .agent_output import AgentOutputTool
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
from .bash_exec import BashExecTool
|
|
||||||
from .check_operation_status import CheckOperationStatusTool
|
|
||||||
from .create_agent import CreateAgentTool
|
from .create_agent import CreateAgentTool
|
||||||
from .customize_agent import CustomizeAgentTool
|
from .customize_agent import CustomizeAgentTool
|
||||||
from .edit_agent import EditAgentTool
|
from .edit_agent import EditAgentTool
|
||||||
@@ -22,7 +20,6 @@ from .get_doc_page import GetDocPageTool
|
|||||||
from .run_agent import RunAgentTool
|
from .run_agent import RunAgentTool
|
||||||
from .run_block import RunBlockTool
|
from .run_block import RunBlockTool
|
||||||
from .search_docs import SearchDocsTool
|
from .search_docs import SearchDocsTool
|
||||||
from .web_fetch import WebFetchTool
|
|
||||||
from .workspace_files import (
|
from .workspace_files import (
|
||||||
DeleteWorkspaceFileTool,
|
DeleteWorkspaceFileTool,
|
||||||
ListWorkspaceFilesTool,
|
ListWorkspaceFilesTool,
|
||||||
@@ -47,14 +44,8 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
|
|||||||
"run_agent": RunAgentTool(),
|
"run_agent": RunAgentTool(),
|
||||||
"run_block": RunBlockTool(),
|
"run_block": RunBlockTool(),
|
||||||
"view_agent_output": AgentOutputTool(),
|
"view_agent_output": AgentOutputTool(),
|
||||||
"check_operation_status": CheckOperationStatusTool(),
|
|
||||||
"search_docs": SearchDocsTool(),
|
"search_docs": SearchDocsTool(),
|
||||||
"get_doc_page": GetDocPageTool(),
|
"get_doc_page": GetDocPageTool(),
|
||||||
# Web fetch for safe URL retrieval
|
|
||||||
"web_fetch": WebFetchTool(),
|
|
||||||
# Sandboxed code execution (bubblewrap)
|
|
||||||
"bash_exec": BashExecTool(),
|
|
||||||
# Persistent workspace tools (cloud storage, survives across sessions)
|
|
||||||
# Feature request tools
|
# Feature request tools
|
||||||
"search_feature_requests": SearchFeatureRequestsTool(),
|
"search_feature_requests": SearchFeatureRequestsTool(),
|
||||||
"create_feature_request": CreateFeatureRequestTool(),
|
"create_feature_request": CreateFeatureRequestTool(),
|
||||||
|
|||||||
@@ -1,131 +0,0 @@
|
|||||||
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
|
|
||||||
|
|
||||||
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
|
|
||||||
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
|
|
||||||
read-only, writable workspace only, clean env, no network.
|
|
||||||
|
|
||||||
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
|
|
||||||
available (e.g. macOS development).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.api.features.chat.tools.base import BaseTool
|
|
||||||
from backend.api.features.chat.tools.models import (
|
|
||||||
BashExecResponse,
|
|
||||||
ErrorResponse,
|
|
||||||
ToolResponseBase,
|
|
||||||
)
|
|
||||||
from backend.api.features.chat.tools.sandbox import (
|
|
||||||
get_workspace_dir,
|
|
||||||
has_full_sandbox,
|
|
||||||
run_sandboxed,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class BashExecTool(BaseTool):
|
|
||||||
"""Execute Bash commands in a bubblewrap sandbox."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "bash_exec"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
if not has_full_sandbox():
|
|
||||||
return (
|
|
||||||
"Bash execution is DISABLED — bubblewrap sandbox is not "
|
|
||||||
"available on this platform. Do not call this tool."
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
"Execute a Bash command or script in a bubblewrap sandbox. "
|
|
||||||
"Full Bash scripting is supported (loops, conditionals, pipes, "
|
|
||||||
"functions, etc.). "
|
|
||||||
"The sandbox shares the same working directory as the SDK Read/Write "
|
|
||||||
"tools — files created by either are accessible to both. "
|
|
||||||
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
|
|
||||||
"visible read-only, the per-session workspace is the only writable "
|
|
||||||
"path, environment variables are wiped (no secrets), all network "
|
|
||||||
"access is blocked at the kernel level, and resource limits are "
|
|
||||||
"enforced (max 64 processes, 512MB memory, 50MB file size). "
|
|
||||||
"Application code, configs, and other directories are NOT accessible. "
|
|
||||||
"To fetch web content, use the web_fetch tool instead. "
|
|
||||||
"Execution is killed after the timeout (default 30s, max 120s). "
|
|
||||||
"Returns stdout and stderr. "
|
|
||||||
"Useful for file manipulation, data processing with Unix tools "
|
|
||||||
"(grep, awk, sed, jq, etc.), and running shell scripts."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"command": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Bash command or script to execute.",
|
|
||||||
},
|
|
||||||
"timeout": {
|
|
||||||
"type": "integer",
|
|
||||||
"description": (
|
|
||||||
"Max execution time in seconds (default 30, max 120)."
|
|
||||||
),
|
|
||||||
"default": 30,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["command"],
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
session_id = session.session_id if session else None
|
|
||||||
|
|
||||||
if not has_full_sandbox():
|
|
||||||
return ErrorResponse(
|
|
||||||
message="bash_exec requires bubblewrap sandbox (Linux only).",
|
|
||||||
error="sandbox_unavailable",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
command: str = (kwargs.get("command") or "").strip()
|
|
||||||
timeout: int = kwargs.get("timeout", 30)
|
|
||||||
|
|
||||||
if not command:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="No command provided.",
|
|
||||||
error="empty_command",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
workspace = get_workspace_dir(session_id or "default")
|
|
||||||
|
|
||||||
stdout, stderr, exit_code, timed_out = await run_sandboxed(
|
|
||||||
command=["bash", "-c", command],
|
|
||||||
cwd=workspace,
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
|
|
||||||
return BashExecResponse(
|
|
||||||
message=(
|
|
||||||
"Execution timed out"
|
|
||||||
if timed_out
|
|
||||||
else f"Command executed (exit {exit_code})"
|
|
||||||
),
|
|
||||||
stdout=stdout,
|
|
||||||
stderr=stderr,
|
|
||||||
exit_code=exit_code,
|
|
||||||
timed_out=timed_out,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
"""CheckOperationStatusTool — query the status of a long-running operation."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.api.features.chat.tools.base import BaseTool
|
|
||||||
from backend.api.features.chat.tools.models import (
|
|
||||||
ErrorResponse,
|
|
||||||
ResponseType,
|
|
||||||
ToolResponseBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class OperationStatusResponse(ToolResponseBase):
|
|
||||||
"""Response for check_operation_status tool."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.OPERATION_STATUS
|
|
||||||
task_id: str
|
|
||||||
operation_id: str
|
|
||||||
status: str # "running", "completed", "failed"
|
|
||||||
tool_name: str | None = None
|
|
||||||
message: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class CheckOperationStatusTool(BaseTool):
|
|
||||||
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
|
|
||||||
|
|
||||||
The CoPilot uses this tool to report back to the user whether an
|
|
||||||
operation that was started earlier has completed, failed, or is still
|
|
||||||
running.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "check_operation_status"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Check the current status of a long-running operation such as "
|
|
||||||
"create_agent or edit_agent. Accepts either an operation_id or "
|
|
||||||
"task_id from a previous operation_started response. "
|
|
||||||
"Returns the current status: running, completed, or failed."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"operation_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"The operation_id from an operation_started response."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"task_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": (
|
|
||||||
"The task_id from an operation_started response. "
|
|
||||||
"Used as fallback if operation_id is not provided."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
from backend.api.features.chat import stream_registry
|
|
||||||
|
|
||||||
operation_id = (kwargs.get("operation_id") or "").strip()
|
|
||||||
task_id = (kwargs.get("task_id") or "").strip()
|
|
||||||
|
|
||||||
if not operation_id and not task_id:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide an operation_id or task_id.",
|
|
||||||
error="missing_parameter",
|
|
||||||
)
|
|
||||||
|
|
||||||
task = None
|
|
||||||
if operation_id:
|
|
||||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
|
||||||
if task is None and task_id:
|
|
||||||
task = await stream_registry.get_task(task_id)
|
|
||||||
|
|
||||||
if task is None:
|
|
||||||
# Task not in Redis — it may have already expired (TTL).
|
|
||||||
# Check conversation history for the result instead.
|
|
||||||
return ErrorResponse(
|
|
||||||
message=(
|
|
||||||
"Operation not found — it may have already completed and "
|
|
||||||
"expired from the status tracker. Check the conversation "
|
|
||||||
"history for the result."
|
|
||||||
),
|
|
||||||
error="not_found",
|
|
||||||
)
|
|
||||||
|
|
||||||
status_messages = {
|
|
||||||
"running": (
|
|
||||||
f"The {task.tool_name or 'operation'} is still running. "
|
|
||||||
"Please wait for it to complete."
|
|
||||||
),
|
|
||||||
"completed": (
|
|
||||||
f"The {task.tool_name or 'operation'} has completed successfully."
|
|
||||||
),
|
|
||||||
"failed": f"The {task.tool_name or 'operation'} has failed.",
|
|
||||||
}
|
|
||||||
|
|
||||||
return OperationStatusResponse(
|
|
||||||
task_id=task.task_id,
|
|
||||||
operation_id=task.operation_id,
|
|
||||||
status=task.status,
|
|
||||||
tool_name=task.tool_name,
|
|
||||||
message=status_messages.get(task.status, f"Status: {task.status}"),
|
|
||||||
)
|
|
||||||
@@ -146,7 +146,6 @@ class FindBlockTool(BaseTool):
|
|||||||
id=block_id,
|
id=block_id,
|
||||||
name=block.name,
|
name=block.name,
|
||||||
description=block.description or "",
|
description=block.description or "",
|
||||||
categories=[c.value for c in block.categories],
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -41,12 +41,6 @@ class ResponseType(str, Enum):
|
|||||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||||
# Input validation
|
# Input validation
|
||||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||||
# Web fetch
|
|
||||||
WEB_FETCH = "web_fetch"
|
|
||||||
# Code execution
|
|
||||||
BASH_EXEC = "bash_exec"
|
|
||||||
# Operation status check
|
|
||||||
OPERATION_STATUS = "operation_status"
|
|
||||||
# Feature request types
|
# Feature request types
|
||||||
FEATURE_REQUEST_SEARCH = "feature_request_search"
|
FEATURE_REQUEST_SEARCH = "feature_request_search"
|
||||||
FEATURE_REQUEST_CREATED = "feature_request_created"
|
FEATURE_REQUEST_CREATED = "feature_request_created"
|
||||||
@@ -344,19 +338,6 @@ class BlockInfoSummary(BaseModel):
|
|||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
categories: list[str]
|
|
||||||
input_schema: dict[str, Any] = Field(
|
|
||||||
default_factory=dict,
|
|
||||||
description="Full JSON schema for block inputs",
|
|
||||||
)
|
|
||||||
output_schema: dict[str, Any] = Field(
|
|
||||||
default_factory=dict,
|
|
||||||
description="Full JSON schema for block outputs",
|
|
||||||
)
|
|
||||||
required_inputs: list[BlockInputFieldInfo] = Field(
|
|
||||||
default_factory=list,
|
|
||||||
description="List of input fields for this block",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BlockListResponse(ToolResponseBase):
|
class BlockListResponse(ToolResponseBase):
|
||||||
@@ -366,10 +347,6 @@ class BlockListResponse(ToolResponseBase):
|
|||||||
blocks: list[BlockInfoSummary]
|
blocks: list[BlockInfoSummary]
|
||||||
count: int
|
count: int
|
||||||
query: str
|
query: str
|
||||||
usage_hint: str = Field(
|
|
||||||
default="To execute a block, call run_block with block_id set to the block's "
|
|
||||||
"'id' field and input_data containing the fields listed in required_inputs."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BlockDetails(BaseModel):
|
class BlockDetails(BaseModel):
|
||||||
@@ -458,27 +435,6 @@ class AsyncProcessingResponse(ToolResponseBase):
|
|||||||
task_id: str | None = None
|
task_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class WebFetchResponse(ToolResponseBase):
|
|
||||||
"""Response for web_fetch tool."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.WEB_FETCH
|
|
||||||
url: str
|
|
||||||
status_code: int
|
|
||||||
content_type: str
|
|
||||||
content: str
|
|
||||||
truncated: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class BashExecResponse(ToolResponseBase):
|
|
||||||
"""Response for bash_exec tool."""
|
|
||||||
|
|
||||||
type: ResponseType = ResponseType.BASH_EXEC
|
|
||||||
stdout: str
|
|
||||||
stderr: str
|
|
||||||
exit_code: int
|
|
||||||
timed_out: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
# Feature request models
|
# Feature request models
|
||||||
class FeatureRequestInfo(BaseModel):
|
class FeatureRequestInfo(BaseModel):
|
||||||
"""Information about a feature request issue."""
|
"""Information about a feature request issue."""
|
||||||
|
|||||||
@@ -1,265 +0,0 @@
|
|||||||
"""Sandbox execution utilities for code execution tools.
|
|
||||||
|
|
||||||
Provides filesystem + network isolated command execution using **bubblewrap**
|
|
||||||
(``bwrap``): whitelist-only filesystem (only system dirs visible read-only),
|
|
||||||
writable workspace only, clean environment, network blocked.
|
|
||||||
|
|
||||||
Tools that call :func:`run_sandboxed` must first check :func:`has_full_sandbox`
|
|
||||||
and refuse to run if bubblewrap is not available.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_DEFAULT_TIMEOUT = 30
|
|
||||||
_MAX_TIMEOUT = 120
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Sandbox capability detection (cached at first call)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
_BWRAP_AVAILABLE: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def has_full_sandbox() -> bool:
|
|
||||||
"""Return True if bubblewrap is available (filesystem + network isolation).
|
|
||||||
|
|
||||||
On non-Linux platforms (macOS), always returns False.
|
|
||||||
"""
|
|
||||||
global _BWRAP_AVAILABLE
|
|
||||||
if _BWRAP_AVAILABLE is None:
|
|
||||||
_BWRAP_AVAILABLE = (
|
|
||||||
platform.system() == "Linux" and shutil.which("bwrap") is not None
|
|
||||||
)
|
|
||||||
return _BWRAP_AVAILABLE
|
|
||||||
|
|
||||||
|
|
||||||
WORKSPACE_PREFIX = "/tmp/copilot-"
|
|
||||||
|
|
||||||
|
|
||||||
def make_session_path(session_id: str) -> str:
|
|
||||||
"""Build a sanitized, session-specific path under :data:`WORKSPACE_PREFIX`.
|
|
||||||
|
|
||||||
Shared by both the SDK working-directory setup and the sandbox tools so
|
|
||||||
they always resolve to the same directory for a given session.
|
|
||||||
|
|
||||||
Steps:
|
|
||||||
1. Strip all characters except ``[A-Za-z0-9-]``.
|
|
||||||
2. Construct ``/tmp/copilot-<safe_id>``.
|
|
||||||
3. Validate via ``os.path.normpath`` + ``startswith`` (CodeQL-recognised
|
|
||||||
sanitizer) to prevent path traversal.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the resulting path escapes the prefix.
|
|
||||||
"""
|
|
||||||
import re
|
|
||||||
|
|
||||||
safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
|
|
||||||
if not safe_id:
|
|
||||||
safe_id = "default"
|
|
||||||
path = os.path.normpath(f"{WORKSPACE_PREFIX}{safe_id}")
|
|
||||||
if not path.startswith(WORKSPACE_PREFIX):
|
|
||||||
raise ValueError(f"Session path escaped prefix: {path}")
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
def get_workspace_dir(session_id: str) -> str:
|
|
||||||
"""Get or create the workspace directory for a session.
|
|
||||||
|
|
||||||
Uses :func:`make_session_path` — the same path the SDK uses — so that
|
|
||||||
bash_exec shares the workspace with the SDK file tools.
|
|
||||||
"""
|
|
||||||
workspace = make_session_path(session_id)
|
|
||||||
os.makedirs(workspace, exist_ok=True)
|
|
||||||
return workspace
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Bubblewrap command builder
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# System directories mounted read-only inside the sandbox.
|
|
||||||
# ONLY these are visible — /app, /root, /home, /opt, /var etc. are NOT accessible.
|
|
||||||
_SYSTEM_RO_BINDS = [
|
|
||||||
"/usr", # binaries, libraries, Python interpreter
|
|
||||||
"/etc", # system config: ld.so, locale, passwd, alternatives
|
|
||||||
]
|
|
||||||
|
|
||||||
# Compat paths: symlinks to /usr/* on modern Debian, real dirs on older systems.
|
|
||||||
# On Debian 13 these are symlinks (e.g. /bin -> usr/bin). bwrap --ro-bind
|
|
||||||
# can't create a symlink target, so we detect and use --symlink instead.
|
|
||||||
# /lib64 is critical: the ELF dynamic linker lives at /lib64/ld-linux-x86-64.so.2.
|
|
||||||
_COMPAT_PATHS = [
|
|
||||||
("/bin", "usr/bin"), # -> /usr/bin on Debian 13
|
|
||||||
("/sbin", "usr/sbin"), # -> /usr/sbin on Debian 13
|
|
||||||
("/lib", "usr/lib"), # -> /usr/lib on Debian 13
|
|
||||||
("/lib64", "usr/lib64"), # 64-bit libraries / ELF interpreter
|
|
||||||
]
|
|
||||||
|
|
||||||
# Resource limits to prevent fork bombs, memory exhaustion, and disk abuse.
|
|
||||||
# Applied via ulimit inside the sandbox before exec'ing the user command.
|
|
||||||
_RESOURCE_LIMITS = (
|
|
||||||
"ulimit -u 64" # max 64 processes (prevents fork bombs)
|
|
||||||
" -v 524288" # 512 MB virtual memory
|
|
||||||
" -f 51200" # 50 MB max file size (1024-byte blocks)
|
|
||||||
" -n 256" # 256 open file descriptors
|
|
||||||
" 2>/dev/null"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_bwrap_command(
|
|
||||||
command: list[str], cwd: str, env: dict[str, str]
|
|
||||||
) -> list[str]:
|
|
||||||
"""Build a bubblewrap command with strict filesystem + network isolation.
|
|
||||||
|
|
||||||
Security model:
|
|
||||||
- **Whitelist-only filesystem**: only system directories (``/usr``, ``/etc``,
|
|
||||||
``/bin``, ``/lib``) are mounted read-only. Application code (``/app``),
|
|
||||||
home directories, ``/var``, ``/opt``, etc. are NOT accessible at all.
|
|
||||||
- **Writable workspace only**: the per-session workspace is the sole
|
|
||||||
writable path.
|
|
||||||
- **Clean environment**: ``--clearenv`` wipes all inherited env vars.
|
|
||||||
Only the explicitly-passed safe env vars are set inside the sandbox.
|
|
||||||
- **Network isolation**: ``--unshare-net`` blocks all network access.
|
|
||||||
- **Resource limits**: ulimit caps on processes (64), memory (512MB),
|
|
||||||
file size (50MB), and open FDs (256) to prevent fork bombs and abuse.
|
|
||||||
- **New session**: prevents terminal control escape.
|
|
||||||
- **Die with parent**: prevents orphaned sandbox processes.
|
|
||||||
"""
|
|
||||||
cmd = [
|
|
||||||
"bwrap",
|
|
||||||
# Create a new user namespace so bwrap can set up sandboxing
|
|
||||||
# inside unprivileged Docker containers (no CAP_SYS_ADMIN needed).
|
|
||||||
"--unshare-user",
|
|
||||||
# Wipe all inherited environment variables (API keys, secrets, etc.)
|
|
||||||
"--clearenv",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Set only the safe env vars inside the sandbox
|
|
||||||
for key, value in env.items():
|
|
||||||
cmd.extend(["--setenv", key, value])
|
|
||||||
|
|
||||||
# System directories: read-only
|
|
||||||
for path in _SYSTEM_RO_BINDS:
|
|
||||||
cmd.extend(["--ro-bind", path, path])
|
|
||||||
|
|
||||||
# Compat paths: use --symlink when host path is a symlink (Debian 13),
|
|
||||||
# --ro-bind when it's a real directory (older distros).
|
|
||||||
for path, symlink_target in _COMPAT_PATHS:
|
|
||||||
if os.path.islink(path):
|
|
||||||
cmd.extend(["--symlink", symlink_target, path])
|
|
||||||
elif os.path.exists(path):
|
|
||||||
cmd.extend(["--ro-bind", path, path])
|
|
||||||
|
|
||||||
# Wrap the user command with resource limits:
|
|
||||||
# sh -c 'ulimit ...; exec "$@"' -- <original command>
|
|
||||||
# `exec "$@"` replaces the shell so there's no extra process overhead,
|
|
||||||
# and properly handles arguments with spaces.
|
|
||||||
limited_command = [
|
|
||||||
"sh",
|
|
||||||
"-c",
|
|
||||||
f'{_RESOURCE_LIMITS}; exec "$@"',
|
|
||||||
"--",
|
|
||||||
*command,
|
|
||||||
]
|
|
||||||
|
|
||||||
cmd.extend(
|
|
||||||
[
|
|
||||||
# Fresh virtual filesystems
|
|
||||||
"--dev",
|
|
||||||
"/dev",
|
|
||||||
"--proc",
|
|
||||||
"/proc",
|
|
||||||
"--tmpfs",
|
|
||||||
"/tmp",
|
|
||||||
# Workspace bind AFTER --tmpfs /tmp so it's visible through the tmpfs.
|
|
||||||
# (workspace lives under /tmp/copilot-<session>)
|
|
||||||
"--bind",
|
|
||||||
cwd,
|
|
||||||
cwd,
|
|
||||||
# Isolation
|
|
||||||
"--unshare-net",
|
|
||||||
"--die-with-parent",
|
|
||||||
"--new-session",
|
|
||||||
"--chdir",
|
|
||||||
cwd,
|
|
||||||
"--",
|
|
||||||
*limited_command,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return cmd
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Public API
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
async def run_sandboxed(
|
|
||||||
command: list[str],
|
|
||||||
cwd: str,
|
|
||||||
timeout: int = _DEFAULT_TIMEOUT,
|
|
||||||
env: dict[str, str] | None = None,
|
|
||||||
) -> tuple[str, str, int, bool]:
|
|
||||||
"""Run a command inside a bubblewrap sandbox.
|
|
||||||
|
|
||||||
Callers **must** check :func:`has_full_sandbox` before calling this
|
|
||||||
function. If bubblewrap is not available, this function raises
|
|
||||||
:class:`RuntimeError` rather than running unsandboxed.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(stdout, stderr, exit_code, timed_out)
|
|
||||||
"""
|
|
||||||
if not has_full_sandbox():
|
|
||||||
raise RuntimeError(
|
|
||||||
"run_sandboxed() requires bubblewrap but bwrap is not available. "
|
|
||||||
"Callers must check has_full_sandbox() before calling this function."
|
|
||||||
)
|
|
||||||
|
|
||||||
timeout = min(max(timeout, 1), _MAX_TIMEOUT)
|
|
||||||
|
|
||||||
safe_env = {
|
|
||||||
"PATH": "/usr/local/bin:/usr/bin:/bin",
|
|
||||||
"HOME": cwd,
|
|
||||||
"TMPDIR": cwd,
|
|
||||||
"LANG": "en_US.UTF-8",
|
|
||||||
"PYTHONDONTWRITEBYTECODE": "1",
|
|
||||||
"PYTHONIOENCODING": "utf-8",
|
|
||||||
}
|
|
||||||
if env:
|
|
||||||
safe_env.update(env)
|
|
||||||
|
|
||||||
full_command = _build_bwrap_command(command, cwd, safe_env)
|
|
||||||
|
|
||||||
try:
|
|
||||||
proc = await asyncio.create_subprocess_exec(
|
|
||||||
*full_command,
|
|
||||||
stdout=asyncio.subprocess.PIPE,
|
|
||||||
stderr=asyncio.subprocess.PIPE,
|
|
||||||
cwd=cwd,
|
|
||||||
env=safe_env,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
stdout_bytes, stderr_bytes = await asyncio.wait_for(
|
|
||||||
proc.communicate(), timeout=timeout
|
|
||||||
)
|
|
||||||
stdout = stdout_bytes.decode("utf-8", errors="replace")
|
|
||||||
stderr = stderr_bytes.decode("utf-8", errors="replace")
|
|
||||||
return stdout, stderr, proc.returncode or 0, False
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
proc.kill()
|
|
||||||
await proc.communicate()
|
|
||||||
return "", f"Execution timed out after {timeout}s", -1, True
|
|
||||||
|
|
||||||
except RuntimeError:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return "", f"Sandbox error: {e}", -1, False
|
|
||||||
@@ -15,7 +15,6 @@ from backend.data.model import (
|
|||||||
OAuth2Credentials,
|
OAuth2Credentials,
|
||||||
)
|
)
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -360,7 +359,7 @@ async def match_user_credentials_to_graph(
|
|||||||
_,
|
_,
|
||||||
_,
|
_,
|
||||||
) in aggregated_creds.items():
|
) in aggregated_creds.items():
|
||||||
# Find first matching credential by provider, type, scopes, and host/URL
|
# Find first matching credential by provider, type, and scopes
|
||||||
matching_cred = next(
|
matching_cred = next(
|
||||||
(
|
(
|
||||||
cred
|
cred
|
||||||
@@ -375,10 +374,6 @@ async def match_user_credentials_to_graph(
|
|||||||
cred.type != "host_scoped"
|
cred.type != "host_scoped"
|
||||||
or _credential_is_for_host(cred, credential_requirements)
|
or _credential_is_for_host(cred, credential_requirements)
|
||||||
)
|
)
|
||||||
and (
|
|
||||||
cred.provider != ProviderName.MCP
|
|
||||||
or _credential_is_for_mcp_server(cred, credential_requirements)
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -449,22 +444,6 @@ def _credential_is_for_host(
|
|||||||
return credential.matches_url(list(requirements.discriminator_values)[0])
|
return credential.matches_url(list(requirements.discriminator_values)[0])
|
||||||
|
|
||||||
|
|
||||||
def _credential_is_for_mcp_server(
|
|
||||||
credential: Credentials,
|
|
||||||
requirements: CredentialsFieldInfo,
|
|
||||||
) -> bool:
|
|
||||||
"""Check if an MCP OAuth credential matches the required server URL."""
|
|
||||||
if not requirements.discriminator_values:
|
|
||||||
return True
|
|
||||||
|
|
||||||
server_url = (
|
|
||||||
credential.metadata.get("mcp_server_url")
|
|
||||||
if isinstance(credential, OAuth2Credentials)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
return server_url in requirements.discriminator_values if server_url else False
|
|
||||||
|
|
||||||
|
|
||||||
async def check_user_has_required_credentials(
|
async def check_user_has_required_credentials(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
required_credentials: list[CredentialsMetaInput],
|
required_credentials: list[CredentialsMetaInput],
|
||||||
|
|||||||
@@ -1,151 +0,0 @@
|
|||||||
"""Web fetch tool — safely retrieve public web page content."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import html2text
|
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
|
||||||
from backend.api.features.chat.tools.base import BaseTool
|
|
||||||
from backend.api.features.chat.tools.models import (
|
|
||||||
ErrorResponse,
|
|
||||||
ToolResponseBase,
|
|
||||||
WebFetchResponse,
|
|
||||||
)
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Limits
|
|
||||||
_MAX_CONTENT_BYTES = 102_400 # 100 KB download cap
|
|
||||||
_REQUEST_TIMEOUT = aiohttp.ClientTimeout(total=15)
|
|
||||||
|
|
||||||
# Content types we'll read as text
|
|
||||||
_TEXT_CONTENT_TYPES = {
|
|
||||||
"text/html",
|
|
||||||
"text/plain",
|
|
||||||
"text/xml",
|
|
||||||
"text/csv",
|
|
||||||
"text/markdown",
|
|
||||||
"application/json",
|
|
||||||
"application/xml",
|
|
||||||
"application/xhtml+xml",
|
|
||||||
"application/rss+xml",
|
|
||||||
"application/atom+xml",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _is_text_content(content_type: str) -> bool:
|
|
||||||
base = content_type.split(";")[0].strip().lower()
|
|
||||||
return base in _TEXT_CONTENT_TYPES or base.startswith("text/")
|
|
||||||
|
|
||||||
|
|
||||||
def _html_to_text(html: str) -> str:
|
|
||||||
h = html2text.HTML2Text()
|
|
||||||
h.ignore_links = False
|
|
||||||
h.ignore_images = True
|
|
||||||
h.body_width = 0
|
|
||||||
return h.handle(html)
|
|
||||||
|
|
||||||
|
|
||||||
class WebFetchTool(BaseTool):
|
|
||||||
"""Safely fetch content from a public URL using SSRF-protected HTTP."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "web_fetch"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Fetch the content of a public web page by URL. "
|
|
||||||
"Returns readable text extracted from HTML by default. "
|
|
||||||
"Useful for reading documentation, articles, and API responses. "
|
|
||||||
"Only supports HTTP/HTTPS GET requests to public URLs "
|
|
||||||
"(private/internal network addresses are blocked)."
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def parameters(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"url": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The public HTTP/HTTPS URL to fetch.",
|
|
||||||
},
|
|
||||||
"extract_text": {
|
|
||||||
"type": "boolean",
|
|
||||||
"description": (
|
|
||||||
"If true (default), extract readable text from HTML. "
|
|
||||||
"If false, return raw content."
|
|
||||||
),
|
|
||||||
"default": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["url"],
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_auth(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _execute(
|
|
||||||
self,
|
|
||||||
user_id: str | None,
|
|
||||||
session: ChatSession,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ToolResponseBase:
|
|
||||||
url: str = (kwargs.get("url") or "").strip()
|
|
||||||
extract_text: bool = kwargs.get("extract_text", True)
|
|
||||||
session_id = session.session_id if session else None
|
|
||||||
|
|
||||||
if not url:
|
|
||||||
return ErrorResponse(
|
|
||||||
message="Please provide a URL to fetch.",
|
|
||||||
error="missing_url",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
client = Requests(raise_for_status=False, retry_max_attempts=1)
|
|
||||||
response = await client.get(url, timeout=_REQUEST_TIMEOUT)
|
|
||||||
except ValueError as e:
|
|
||||||
# validate_url raises ValueError for SSRF / blocked IPs
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"URL blocked: {e}",
|
|
||||||
error="url_blocked",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[web_fetch] Request failed for {url}: {e}")
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Failed to fetch URL: {e}",
|
|
||||||
error="fetch_failed",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
content_type = response.headers.get("content-type", "")
|
|
||||||
if not _is_text_content(content_type):
|
|
||||||
return ErrorResponse(
|
|
||||||
message=f"Non-text content type: {content_type.split(';')[0]}",
|
|
||||||
error="unsupported_content_type",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
raw = response.content[:_MAX_CONTENT_BYTES]
|
|
||||||
text = raw.decode("utf-8", errors="replace")
|
|
||||||
|
|
||||||
if extract_text and "html" in content_type.lower():
|
|
||||||
text = _html_to_text(text)
|
|
||||||
|
|
||||||
return WebFetchResponse(
|
|
||||||
message=f"Fetched {url}",
|
|
||||||
url=response.url,
|
|
||||||
status_code=response.status,
|
|
||||||
content_type=content_type.split(";")[0].strip(),
|
|
||||||
content=text,
|
|
||||||
truncated=False,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
@@ -88,9 +88,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
|||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"List files in the user's persistent workspace (cloud storage). "
|
"List files in the user's workspace. "
|
||||||
"These files survive across sessions. "
|
|
||||||
"For ephemeral session files, use the SDK Read/Glob tools instead. "
|
|
||||||
"Returns file names, paths, sizes, and metadata. "
|
"Returns file names, paths, sizes, and metadata. "
|
||||||
"Optionally filter by path prefix."
|
"Optionally filter by path prefix."
|
||||||
)
|
)
|
||||||
@@ -206,9 +204,7 @@ class ReadWorkspaceFileTool(BaseTool):
|
|||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Read a file from the user's persistent workspace (cloud storage). "
|
"Read a file from the user's workspace. "
|
||||||
"These files survive across sessions. "
|
|
||||||
"For ephemeral session files, use the SDK Read tool instead. "
|
|
||||||
"Specify either file_id or path to identify the file. "
|
"Specify either file_id or path to identify the file. "
|
||||||
"For small text files, returns content directly. "
|
"For small text files, returns content directly. "
|
||||||
"For large or binary files, returns metadata and a download URL. "
|
"For large or binary files, returns metadata and a download URL. "
|
||||||
@@ -382,9 +378,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
|||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Write or create a file in the user's persistent workspace (cloud storage). "
|
"Write or create a file in the user's workspace. "
|
||||||
"These files survive across sessions. "
|
|
||||||
"For ephemeral session files, use the SDK Write tool instead. "
|
|
||||||
"Provide the content as a base64-encoded string. "
|
"Provide the content as a base64-encoded string. "
|
||||||
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||||
"Files are saved to the current session's folder by default. "
|
"Files are saved to the current session's folder by default. "
|
||||||
@@ -529,7 +523,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
|||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Delete a file from the user's persistent workspace (cloud storage). "
|
"Delete a file from the user's workspace. "
|
||||||
"Specify either file_id or path to identify the file. "
|
"Specify either file_id or path to identify the file. "
|
||||||
"Paths are scoped to the current session by default. "
|
"Paths are scoped to the current session by default. "
|
||||||
"Use /sessions/<session_id>/... for cross-session access."
|
"Use /sessions/<session_id>/... for cross-session access."
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import TYPE_CHECKING, Annotated, Any, List, Literal
|
from typing import TYPE_CHECKING, Annotated, List, Literal
|
||||||
|
|
||||||
from autogpt_libs.auth import get_user_id
|
from autogpt_libs.auth import get_user_id
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
@@ -14,7 +14,7 @@ from fastapi import (
|
|||||||
Security,
|
Security,
|
||||||
status,
|
status,
|
||||||
)
|
)
|
||||||
from pydantic import BaseModel, Field, SecretStr, model_validator
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR, HTTP_502_BAD_GATEWAY
|
||||||
|
|
||||||
from backend.api.features.library.db import set_preset_webhook, update_preset
|
from backend.api.features.library.db import set_preset_webhook, update_preset
|
||||||
@@ -39,11 +39,7 @@ from backend.data.onboarding import OnboardingStep, complete_onboarding_step
|
|||||||
from backend.data.user import get_user_integrations
|
from backend.data.user import get_user_integrations
|
||||||
from backend.executor.utils import add_graph_execution
|
from backend.executor.utils import add_graph_execution
|
||||||
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
from backend.integrations.ayrshare import AyrshareClient, SocialPlatform
|
||||||
from backend.integrations.credentials_store import provider_matches
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.creds_manager import (
|
|
||||||
IntegrationCredentialsManager,
|
|
||||||
create_mcp_oauth_handler,
|
|
||||||
)
|
|
||||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.integrations.webhooks import get_webhook_manager
|
from backend.integrations.webhooks import get_webhook_manager
|
||||||
@@ -106,37 +102,9 @@ class CredentialsMetaResponse(BaseModel):
|
|||||||
scopes: list[str] | None
|
scopes: list[str] | None
|
||||||
username: str | None
|
username: str | None
|
||||||
host: str | None = Field(
|
host: str | None = Field(
|
||||||
default=None,
|
default=None, description="Host pattern for host-scoped credentials"
|
||||||
description="Host pattern for host-scoped or MCP server URL for MCP credentials",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def _normalize_provider(cls, data: Any) -> Any:
|
|
||||||
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug."""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
prov = data.get("provider", "")
|
|
||||||
if isinstance(prov, str) and prov.startswith("ProviderName."):
|
|
||||||
member = prov.removeprefix("ProviderName.")
|
|
||||||
try:
|
|
||||||
data = {**data, "provider": ProviderName[member].value}
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_host(cred: Credentials) -> str | None:
|
|
||||||
"""Extract host from credential: HostScoped host or MCP server URL."""
|
|
||||||
if isinstance(cred, HostScopedCredentials):
|
|
||||||
return cred.host
|
|
||||||
if isinstance(cred, OAuth2Credentials) and cred.provider in (
|
|
||||||
ProviderName.MCP,
|
|
||||||
ProviderName.MCP.value,
|
|
||||||
"ProviderName.MCP",
|
|
||||||
):
|
|
||||||
return (cred.metadata or {}).get("mcp_server_url")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
@router.post("/{provider}/callback", summary="Exchange OAuth code for tokens")
|
||||||
async def callback(
|
async def callback(
|
||||||
@@ -211,7 +179,9 @@ async def callback(
|
|||||||
title=credentials.title,
|
title=credentials.title,
|
||||||
scopes=credentials.scopes,
|
scopes=credentials.scopes,
|
||||||
username=credentials.username,
|
username=credentials.username,
|
||||||
host=(CredentialsMetaResponse.get_host(credentials)),
|
host=(
|
||||||
|
credentials.host if isinstance(credentials, HostScopedCredentials) else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -229,7 +199,7 @@ async def list_credentials(
|
|||||||
title=cred.title,
|
title=cred.title,
|
||||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||||
host=CredentialsMetaResponse.get_host(cred),
|
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||||
)
|
)
|
||||||
for cred in credentials
|
for cred in credentials
|
||||||
]
|
]
|
||||||
@@ -252,7 +222,7 @@ async def list_credentials_by_provider(
|
|||||||
title=cred.title,
|
title=cred.title,
|
||||||
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
|
||||||
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
username=cred.username if isinstance(cred, OAuth2Credentials) else None,
|
||||||
host=CredentialsMetaResponse.get_host(cred),
|
host=cred.host if isinstance(cred, HostScopedCredentials) else None,
|
||||||
)
|
)
|
||||||
for cred in credentials
|
for cred in credentials
|
||||||
]
|
]
|
||||||
@@ -352,10 +322,6 @@ async def delete_credentials(
|
|||||||
|
|
||||||
tokens_revoked = None
|
tokens_revoked = None
|
||||||
if isinstance(creds, OAuth2Credentials):
|
if isinstance(creds, OAuth2Credentials):
|
||||||
if provider_matches(provider.value, ProviderName.MCP.value):
|
|
||||||
# MCP uses dynamic per-server OAuth — create handler from metadata
|
|
||||||
handler = create_mcp_oauth_handler(creds)
|
|
||||||
else:
|
|
||||||
handler = _get_provider_oauth_handler(request, provider)
|
handler = _get_provider_oauth_handler(request, provider)
|
||||||
tokens_revoked = await handler.revoke_tokens(creds)
|
tokens_revoked = await handler.revoke_tokens(creds)
|
||||||
|
|
||||||
|
|||||||
@@ -1,404 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP (Model Context Protocol) API routes.
|
|
||||||
|
|
||||||
Provides endpoints for MCP tool discovery and OAuth authentication so the
|
|
||||||
frontend can list available tools on an MCP server before placing a block.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Annotated, Any
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import fastapi
|
|
||||||
from autogpt_libs.auth import get_user_id
|
|
||||||
from fastapi import Security
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from backend.api.features.integrations.router import CredentialsMetaResponse
|
|
||||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
|
||||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.request import HTTPClientError, Requests
|
|
||||||
from backend.util.settings import Settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
router = fastapi.APIRouter(tags=["mcp"])
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
|
||||||
|
|
||||||
|
|
||||||
# ====================== Tool Discovery ====================== #
|
|
||||||
|
|
||||||
|
|
||||||
class DiscoverToolsRequest(BaseModel):
|
|
||||||
"""Request to discover tools on an MCP server."""
|
|
||||||
|
|
||||||
server_url: str = Field(description="URL of the MCP server")
|
|
||||||
auth_token: str | None = Field(
|
|
||||||
default=None,
|
|
||||||
description="Optional Bearer token for authenticated MCP servers",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MCPToolResponse(BaseModel):
|
|
||||||
"""A single MCP tool returned by discovery."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
input_schema: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class DiscoverToolsResponse(BaseModel):
|
|
||||||
"""Response containing the list of tools available on an MCP server."""
|
|
||||||
|
|
||||||
tools: list[MCPToolResponse]
|
|
||||||
server_name: str | None = None
|
|
||||||
protocol_version: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/discover-tools",
|
|
||||||
summary="Discover available tools on an MCP server",
|
|
||||||
response_model=DiscoverToolsResponse,
|
|
||||||
)
|
|
||||||
async def discover_tools(
|
|
||||||
request: DiscoverToolsRequest,
|
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
|
||||||
) -> DiscoverToolsResponse:
|
|
||||||
"""
|
|
||||||
Connect to an MCP server and return its available tools.
|
|
||||||
|
|
||||||
If the user has a stored MCP credential for this server URL, it will be
|
|
||||||
used automatically — no need to pass an explicit auth token.
|
|
||||||
"""
|
|
||||||
auth_token = request.auth_token
|
|
||||||
|
|
||||||
# Auto-use stored MCP credential when no explicit token is provided.
|
|
||||||
if not auth_token:
|
|
||||||
mcp_creds = await creds_manager.store.get_creds_by_provider(
|
|
||||||
user_id, ProviderName.MCP.value
|
|
||||||
)
|
|
||||||
# Find the freshest credential for this server URL
|
|
||||||
best_cred: OAuth2Credentials | None = None
|
|
||||||
for cred in mcp_creds:
|
|
||||||
if (
|
|
||||||
isinstance(cred, OAuth2Credentials)
|
|
||||||
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
|
|
||||||
):
|
|
||||||
if best_cred is None or (
|
|
||||||
(cred.access_token_expires_at or 0)
|
|
||||||
> (best_cred.access_token_expires_at or 0)
|
|
||||||
):
|
|
||||||
best_cred = cred
|
|
||||||
if best_cred:
|
|
||||||
# Refresh the token if expired before using it
|
|
||||||
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
|
|
||||||
logger.info(
|
|
||||||
f"Using MCP credential {best_cred.id} for {request.server_url}, "
|
|
||||||
f"expires_at={best_cred.access_token_expires_at}"
|
|
||||||
)
|
|
||||||
auth_token = best_cred.access_token.get_secret_value()
|
|
||||||
|
|
||||||
client = MCPClient(request.server_url, auth_token=auth_token)
|
|
||||||
|
|
||||||
try:
|
|
||||||
init_result = await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
except HTTPClientError as e:
|
|
||||||
if e.status_code in (401, 403):
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail="This MCP server requires authentication. "
|
|
||||||
"Please provide a valid auth token.",
|
|
||||||
)
|
|
||||||
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
|
||||||
except MCPClientError as e:
|
|
||||||
raise fastapi.HTTPException(status_code=502, detail=str(e))
|
|
||||||
except Exception as e:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=502,
|
|
||||||
detail=f"Failed to connect to MCP server: {e}",
|
|
||||||
)
|
|
||||||
|
|
||||||
return DiscoverToolsResponse(
|
|
||||||
tools=[
|
|
||||||
MCPToolResponse(
|
|
||||||
name=t.name,
|
|
||||||
description=t.description,
|
|
||||||
input_schema=t.input_schema,
|
|
||||||
)
|
|
||||||
for t in tools
|
|
||||||
],
|
|
||||||
server_name=(
|
|
||||||
init_result.get("serverInfo", {}).get("name")
|
|
||||||
or urlparse(request.server_url).hostname
|
|
||||||
or "MCP"
|
|
||||||
),
|
|
||||||
protocol_version=init_result.get("protocolVersion"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ======================== OAuth Flow ======================== #
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthLoginRequest(BaseModel):
|
|
||||||
"""Request to start an OAuth flow for an MCP server."""
|
|
||||||
|
|
||||||
server_url: str = Field(description="URL of the MCP server that requires OAuth")
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthLoginResponse(BaseModel):
|
|
||||||
"""Response with the OAuth login URL for the user to authenticate."""
|
|
||||||
|
|
||||||
login_url: str
|
|
||||||
state_token: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/oauth/login",
|
|
||||||
summary="Initiate OAuth login for an MCP server",
|
|
||||||
)
|
|
||||||
async def mcp_oauth_login(
|
|
||||||
request: MCPOAuthLoginRequest,
|
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
|
||||||
) -> MCPOAuthLoginResponse:
|
|
||||||
"""
|
|
||||||
Discover OAuth metadata from the MCP server and return a login URL.
|
|
||||||
|
|
||||||
1. Discovers the protected-resource metadata (RFC 9728)
|
|
||||||
2. Fetches the authorization server metadata (RFC 8414)
|
|
||||||
3. Performs Dynamic Client Registration (RFC 7591) if available
|
|
||||||
4. Returns the authorization URL for the frontend to open in a popup
|
|
||||||
"""
|
|
||||||
client = MCPClient(request.server_url)
|
|
||||||
|
|
||||||
# Step 1: Discover protected-resource metadata (RFC 9728)
|
|
||||||
protected_resource = await client.discover_auth()
|
|
||||||
|
|
||||||
metadata: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
if protected_resource and protected_resource.get("authorization_servers"):
|
|
||||||
auth_server_url = protected_resource["authorization_servers"][0]
|
|
||||||
resource_url = protected_resource.get("resource", request.server_url)
|
|
||||||
|
|
||||||
# Step 2a: Discover auth-server metadata (RFC 8414)
|
|
||||||
metadata = await client.discover_auth_server_metadata(auth_server_url)
|
|
||||||
else:
|
|
||||||
# Fallback: Some MCP servers (e.g. Linear) are their own auth server
|
|
||||||
# and serve OAuth metadata directly without protected-resource metadata.
|
|
||||||
# Don't assume a resource_url — omitting it lets the auth server choose
|
|
||||||
# the correct audience for the token (RFC 8707 resource is optional).
|
|
||||||
resource_url = None
|
|
||||||
metadata = await client.discover_auth_server_metadata(request.server_url)
|
|
||||||
|
|
||||||
if (
|
|
||||||
not metadata
|
|
||||||
or "authorization_endpoint" not in metadata
|
|
||||||
or "token_endpoint" not in metadata
|
|
||||||
):
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="This MCP server does not advertise OAuth support. "
|
|
||||||
"You may need to provide an auth token manually.",
|
|
||||||
)
|
|
||||||
|
|
||||||
authorize_url = metadata["authorization_endpoint"]
|
|
||||||
token_url = metadata["token_endpoint"]
|
|
||||||
registration_endpoint = metadata.get("registration_endpoint")
|
|
||||||
revoke_url = metadata.get("revocation_endpoint")
|
|
||||||
|
|
||||||
# Step 3: Dynamic Client Registration (RFC 7591) if available
|
|
||||||
frontend_base_url = settings.config.frontend_base_url
|
|
||||||
if not frontend_base_url:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="Frontend base URL is not configured.",
|
|
||||||
)
|
|
||||||
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
|
||||||
|
|
||||||
client_id = ""
|
|
||||||
client_secret = ""
|
|
||||||
if registration_endpoint:
|
|
||||||
reg_result = await _register_mcp_client(
|
|
||||||
registration_endpoint, redirect_uri, request.server_url
|
|
||||||
)
|
|
||||||
if reg_result:
|
|
||||||
client_id = reg_result.get("client_id", "")
|
|
||||||
client_secret = reg_result.get("client_secret", "")
|
|
||||||
|
|
||||||
if not client_id:
|
|
||||||
client_id = "autogpt-platform"
|
|
||||||
|
|
||||||
# Step 4: Store state token with OAuth metadata for the callback
|
|
||||||
scopes = (protected_resource or {}).get("scopes_supported") or metadata.get(
|
|
||||||
"scopes_supported", []
|
|
||||||
)
|
|
||||||
state_token, code_challenge = await creds_manager.store.store_state_token(
|
|
||||||
user_id,
|
|
||||||
ProviderName.MCP.value,
|
|
||||||
scopes,
|
|
||||||
state_metadata={
|
|
||||||
"authorize_url": authorize_url,
|
|
||||||
"token_url": token_url,
|
|
||||||
"revoke_url": revoke_url,
|
|
||||||
"resource_url": resource_url,
|
|
||||||
"server_url": request.server_url,
|
|
||||||
"client_id": client_id,
|
|
||||||
"client_secret": client_secret,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 5: Build and return the login URL
|
|
||||||
handler = MCPOAuthHandler(
|
|
||||||
client_id=client_id,
|
|
||||||
client_secret=client_secret,
|
|
||||||
redirect_uri=redirect_uri,
|
|
||||||
authorize_url=authorize_url,
|
|
||||||
token_url=token_url,
|
|
||||||
resource_url=resource_url,
|
|
||||||
)
|
|
||||||
login_url = handler.get_login_url(
|
|
||||||
scopes, state_token, code_challenge=code_challenge
|
|
||||||
)
|
|
||||||
|
|
||||||
return MCPOAuthLoginResponse(login_url=login_url, state_token=state_token)
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthCallbackRequest(BaseModel):
|
|
||||||
"""Request to exchange an OAuth code for tokens."""
|
|
||||||
|
|
||||||
code: str = Field(description="Authorization code from OAuth callback")
|
|
||||||
state_token: str = Field(description="State token for CSRF verification")
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthCallbackResponse(BaseModel):
|
|
||||||
"""Response after successfully storing OAuth credentials."""
|
|
||||||
|
|
||||||
credential_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
summary="Exchange OAuth code for MCP tokens",
|
|
||||||
)
|
|
||||||
async def mcp_oauth_callback(
|
|
||||||
request: MCPOAuthCallbackRequest,
|
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
|
||||||
) -> CredentialsMetaResponse:
|
|
||||||
"""
|
|
||||||
Exchange the authorization code for tokens and store the credential.
|
|
||||||
|
|
||||||
The frontend calls this after receiving the OAuth code from the popup.
|
|
||||||
On success, subsequent ``/discover-tools`` calls for the same server URL
|
|
||||||
will automatically use the stored credential.
|
|
||||||
"""
|
|
||||||
valid_state = await creds_manager.store.verify_state_token(
|
|
||||||
user_id, request.state_token, ProviderName.MCP.value
|
|
||||||
)
|
|
||||||
if not valid_state:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Invalid or expired state token.",
|
|
||||||
)
|
|
||||||
|
|
||||||
meta = valid_state.state_metadata
|
|
||||||
frontend_base_url = settings.config.frontend_base_url
|
|
||||||
if not frontend_base_url:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=500,
|
|
||||||
detail="Frontend base URL is not configured.",
|
|
||||||
)
|
|
||||||
redirect_uri = f"{frontend_base_url}/auth/integrations/mcp_callback"
|
|
||||||
|
|
||||||
handler = MCPOAuthHandler(
|
|
||||||
client_id=meta["client_id"],
|
|
||||||
client_secret=meta.get("client_secret", ""),
|
|
||||||
redirect_uri=redirect_uri,
|
|
||||||
authorize_url=meta["authorize_url"],
|
|
||||||
token_url=meta["token_url"],
|
|
||||||
revoke_url=meta.get("revoke_url"),
|
|
||||||
resource_url=meta.get("resource_url"),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
credentials = await handler.exchange_code_for_tokens(
|
|
||||||
request.code, valid_state.scopes, valid_state.code_verifier
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
raise fastapi.HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail=f"OAuth token exchange failed: {e}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Enrich credential metadata for future lookup and token refresh
|
|
||||||
if credentials.metadata is None:
|
|
||||||
credentials.metadata = {}
|
|
||||||
credentials.metadata["mcp_server_url"] = meta["server_url"]
|
|
||||||
credentials.metadata["mcp_client_id"] = meta["client_id"]
|
|
||||||
credentials.metadata["mcp_client_secret"] = meta.get("client_secret", "")
|
|
||||||
credentials.metadata["mcp_token_url"] = meta["token_url"]
|
|
||||||
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
|
|
||||||
|
|
||||||
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
|
|
||||||
credentials.title = f"MCP: {hostname}"
|
|
||||||
|
|
||||||
# Remove old MCP credentials for the same server to prevent stale token buildup.
|
|
||||||
try:
|
|
||||||
old_creds = await creds_manager.store.get_creds_by_provider(
|
|
||||||
user_id, ProviderName.MCP.value
|
|
||||||
)
|
|
||||||
for old in old_creds:
|
|
||||||
if (
|
|
||||||
isinstance(old, OAuth2Credentials)
|
|
||||||
and (old.metadata or {}).get("mcp_server_url") == meta["server_url"]
|
|
||||||
):
|
|
||||||
await creds_manager.store.delete_creds_by_id(user_id, old.id)
|
|
||||||
logger.info(
|
|
||||||
f"Removed old MCP credential {old.id} for {meta['server_url']}"
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not clean up old MCP credentials", exc_info=True)
|
|
||||||
|
|
||||||
await creds_manager.create(user_id, credentials)
|
|
||||||
|
|
||||||
return CredentialsMetaResponse(
|
|
||||||
id=credentials.id,
|
|
||||||
provider=credentials.provider,
|
|
||||||
type=credentials.type,
|
|
||||||
title=credentials.title,
|
|
||||||
scopes=credentials.scopes,
|
|
||||||
username=credentials.username,
|
|
||||||
host=credentials.metadata.get("mcp_server_url"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ======================== Helpers ======================== #
|
|
||||||
|
|
||||||
|
|
||||||
async def _register_mcp_client(
|
|
||||||
registration_endpoint: str,
|
|
||||||
redirect_uri: str,
|
|
||||||
server_url: str,
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Attempt Dynamic Client Registration (RFC 7591) with an MCP auth server."""
|
|
||||||
try:
|
|
||||||
response = await Requests(raise_for_status=True).post(
|
|
||||||
registration_endpoint,
|
|
||||||
json={
|
|
||||||
"client_name": "AutoGPT Platform",
|
|
||||||
"redirect_uris": [redirect_uri],
|
|
||||||
"grant_types": ["authorization_code"],
|
|
||||||
"response_types": ["code"],
|
|
||||||
"token_endpoint_auth_method": "client_secret_post",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
data = response.json()
|
|
||||||
if isinstance(data, dict) and "client_id" in data:
|
|
||||||
return data
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
|
|
||||||
return None
|
|
||||||
@@ -1,436 +0,0 @@
|
|||||||
"""Tests for MCP API routes.
|
|
||||||
|
|
||||||
Uses httpx.AsyncClient with ASGITransport instead of fastapi.testclient.TestClient
|
|
||||||
to avoid creating blocking portals that can corrupt pytest-asyncio's session event loop.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import fastapi
|
|
||||||
import httpx
|
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
from autogpt_libs.auth import get_user_id
|
|
||||||
|
|
||||||
from backend.api.features.mcp.routes import router
|
|
||||||
from backend.blocks.mcp.client import MCPClientError, MCPTool
|
|
||||||
from backend.util.request import HTTPClientError
|
|
||||||
|
|
||||||
app = fastapi.FastAPI()
|
|
||||||
app.include_router(router)
|
|
||||||
app.dependency_overrides[get_user_id] = lambda: "test-user-id"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="module")
|
|
||||||
async def client():
|
|
||||||
transport = httpx.ASGITransport(app=app)
|
|
||||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as c:
|
|
||||||
yield c
|
|
||||||
|
|
||||||
|
|
||||||
class TestDiscoverTools:
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_success(self, client):
|
|
||||||
mock_tools = [
|
|
||||||
MCPTool(
|
|
||||||
name="get_weather",
|
|
||||||
description="Get weather for a city",
|
|
||||||
input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
),
|
|
||||||
MCPTool(
|
|
||||||
name="add_numbers",
|
|
||||||
description="Add two numbers",
|
|
||||||
input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"a": {"type": "number"},
|
|
||||||
"b": {"type": "number"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"serverInfo": {"name": "test-server"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
instance.list_tools = AsyncMock(return_value=mock_tools)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert len(data["tools"]) == 2
|
|
||||||
assert data["tools"][0]["name"] == "get_weather"
|
|
||||||
assert data["tools"][1]["name"] == "add_numbers"
|
|
||||||
assert data["server_name"] == "test-server"
|
|
||||||
assert data["protocol_version"] == "2025-03-26"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_with_auth_token(self, client):
|
|
||||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
|
||||||
)
|
|
||||||
instance.list_tools = AsyncMock(return_value=[])
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
"auth_token": "my-secret-token",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
MockClient.assert_called_once_with(
|
|
||||||
"https://mcp.example.com/mcp",
|
|
||||||
auth_token="my-secret-token",
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_auto_uses_stored_credential(self, client):
|
|
||||||
"""When no explicit token is given, stored MCP credentials are used."""
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
stored_cred = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
title="MCP: example.com",
|
|
||||||
access_token=SecretStr("stored-token-123"),
|
|
||||||
refresh_token=None,
|
|
||||||
access_token_expires_at=None,
|
|
||||||
refresh_token_expires_at=None,
|
|
||||||
scopes=[],
|
|
||||||
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
|
|
||||||
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
|
|
||||||
)
|
|
||||||
instance.list_tools = AsyncMock(return_value=[])
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
MockClient.assert_called_once_with(
|
|
||||||
"https://mcp.example.com/mcp",
|
|
||||||
auth_token="stored-token-123",
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_mcp_error(self, client):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
side_effect=MCPClientError("Connection refused")
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://bad-server.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 502
|
|
||||||
assert "Connection refused" in response.json()["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_generic_error(self, client):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://timeout.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 502
|
|
||||||
assert "Failed to connect" in response.json()["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_auth_required(self, client):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://auth-server.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 401
|
|
||||||
assert "requires authentication" in response.json()["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_forbidden(self, client):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
):
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.initialize = AsyncMock(
|
|
||||||
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/discover-tools",
|
|
||||||
json={"server_url": "https://auth-server.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 401
|
|
||||||
assert "requires authentication" in response.json()["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_tools_missing_url(self, client):
|
|
||||||
response = await client.post("/discover-tools", json={})
|
|
||||||
assert response.status_code == 422
|
|
||||||
|
|
||||||
|
|
||||||
class TestOAuthLogin:
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_login_success(self, client):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
patch(
|
|
||||||
"backend.api.features.mcp.routes._register_mcp_client"
|
|
||||||
) as mock_register,
|
|
||||||
):
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.discover_auth = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"authorization_servers": ["https://auth.sentry.io"],
|
|
||||||
"resource": "https://mcp.sentry.dev/mcp",
|
|
||||||
"scopes_supported": ["openid"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
instance.discover_auth_server_metadata = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"authorization_endpoint": "https://auth.sentry.io/authorize",
|
|
||||||
"token_endpoint": "https://auth.sentry.io/token",
|
|
||||||
"registration_endpoint": "https://auth.sentry.io/register",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
mock_register.return_value = {
|
|
||||||
"client_id": "registered-client-id",
|
|
||||||
"client_secret": "registered-secret",
|
|
||||||
}
|
|
||||||
mock_cm.store.store_state_token = AsyncMock(
|
|
||||||
return_value=("state-token-123", "code-challenge-abc")
|
|
||||||
)
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/login",
|
|
||||||
json={"server_url": "https://mcp.sentry.dev/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "login_url" in data
|
|
||||||
assert data["state_token"] == "state-token-123"
|
|
||||||
assert "auth.sentry.io/authorize" in data["login_url"]
|
|
||||||
assert "registered-client-id" in data["login_url"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_login_no_oauth_support(self, client):
|
|
||||||
with patch("backend.api.features.mcp.routes.MCPClient") as MockClient:
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.discover_auth = AsyncMock(return_value=None)
|
|
||||||
instance.discover_auth_server_metadata = AsyncMock(return_value=None)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/login",
|
|
||||||
json={"server_url": "https://simple-server.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "does not advertise OAuth" in response.json()["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_login_fallback_to_public_client(self, client):
|
|
||||||
"""When DCR is unavailable, falls back to default public client ID."""
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
):
|
|
||||||
instance = MockClient.return_value
|
|
||||||
instance.discover_auth = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"authorization_servers": ["https://auth.example.com"],
|
|
||||||
"resource": "https://mcp.example.com/mcp",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
instance.discover_auth_server_metadata = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
|
||||||
"token_endpoint": "https://auth.example.com/token",
|
|
||||||
# No registration_endpoint
|
|
||||||
}
|
|
||||||
)
|
|
||||||
mock_cm.store.store_state_token = AsyncMock(
|
|
||||||
return_value=("state-abc", "challenge-xyz")
|
|
||||||
)
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/login",
|
|
||||||
json={"server_url": "https://mcp.example.com/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "autogpt-platform" in data["login_url"]
|
|
||||||
|
|
||||||
|
|
||||||
class TestOAuthCallback:
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_callback_success(self, client):
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
mock_creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
title=None,
|
|
||||||
access_token=SecretStr("access-token-xyz"),
|
|
||||||
refresh_token=None,
|
|
||||||
access_token_expires_at=None,
|
|
||||||
refresh_token_expires_at=None,
|
|
||||||
scopes=[],
|
|
||||||
metadata={
|
|
||||||
"mcp_token_url": "https://auth.sentry.io/token",
|
|
||||||
"mcp_resource_url": "https://mcp.sentry.dev/mcp",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
|
||||||
):
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
|
|
||||||
# Mock state verification
|
|
||||||
mock_state = AsyncMock()
|
|
||||||
mock_state.state_metadata = {
|
|
||||||
"authorize_url": "https://auth.sentry.io/authorize",
|
|
||||||
"token_url": "https://auth.sentry.io/token",
|
|
||||||
"client_id": "test-client-id",
|
|
||||||
"client_secret": "test-secret",
|
|
||||||
"server_url": "https://mcp.sentry.dev/mcp",
|
|
||||||
}
|
|
||||||
mock_state.scopes = ["openid"]
|
|
||||||
mock_state.code_verifier = "verifier-123"
|
|
||||||
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
|
||||||
mock_cm.create = AsyncMock()
|
|
||||||
|
|
||||||
handler_instance = MockHandler.return_value
|
|
||||||
handler_instance.exchange_code_for_tokens = AsyncMock(
|
|
||||||
return_value=mock_creds
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock old credential cleanup
|
|
||||||
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
json={"code": "auth-code-abc", "state_token": "state-token-123"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
data = response.json()
|
|
||||||
assert "id" in data
|
|
||||||
assert data["provider"] == "mcp"
|
|
||||||
assert data["type"] == "oauth2"
|
|
||||||
mock_cm.create.assert_called_once()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_callback_invalid_state(self, client):
|
|
||||||
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
|
|
||||||
mock_cm.store.verify_state_token = AsyncMock(return_value=None)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
json={"code": "auth-code", "state_token": "bad-state"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "Invalid or expired" in response.json()["detail"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_oauth_callback_token_exchange_fails(self, client):
|
|
||||||
with (
|
|
||||||
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
|
|
||||||
patch("backend.api.features.mcp.routes.settings") as mock_settings,
|
|
||||||
patch("backend.api.features.mcp.routes.MCPOAuthHandler") as MockHandler,
|
|
||||||
):
|
|
||||||
mock_settings.config.frontend_base_url = "http://localhost:3000"
|
|
||||||
mock_state = AsyncMock()
|
|
||||||
mock_state.state_metadata = {
|
|
||||||
"authorize_url": "https://auth.example.com/authorize",
|
|
||||||
"token_url": "https://auth.example.com/token",
|
|
||||||
"client_id": "cid",
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
}
|
|
||||||
mock_state.scopes = []
|
|
||||||
mock_state.code_verifier = "v"
|
|
||||||
mock_cm.store.verify_state_token = AsyncMock(return_value=mock_state)
|
|
||||||
|
|
||||||
handler_instance = MockHandler.return_value
|
|
||||||
handler_instance.exchange_code_for_tokens = AsyncMock(
|
|
||||||
side_effect=RuntimeError("Token exchange failed")
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.post(
|
|
||||||
"/oauth/callback",
|
|
||||||
json={"code": "bad-code", "state_token": "state"},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response.status_code == 400
|
|
||||||
assert "token exchange failed" in response.json()["detail"].lower()
|
|
||||||
@@ -26,7 +26,6 @@ import backend.api.features.executions.review.routes
|
|||||||
import backend.api.features.library.db
|
import backend.api.features.library.db
|
||||||
import backend.api.features.library.model
|
import backend.api.features.library.model
|
||||||
import backend.api.features.library.routes
|
import backend.api.features.library.routes
|
||||||
import backend.api.features.mcp.routes as mcp_routes
|
|
||||||
import backend.api.features.oauth
|
import backend.api.features.oauth
|
||||||
import backend.api.features.otto.routes
|
import backend.api.features.otto.routes
|
||||||
import backend.api.features.postmark.postmark
|
import backend.api.features.postmark.postmark
|
||||||
@@ -344,11 +343,6 @@ app.include_router(
|
|||||||
tags=["workspace"],
|
tags=["workspace"],
|
||||||
prefix="/api/workspace",
|
prefix="/api/workspace",
|
||||||
)
|
)
|
||||||
app.include_router(
|
|
||||||
mcp_routes.router,
|
|
||||||
tags=["v2", "mcp"],
|
|
||||||
prefix="/api/mcp",
|
|
||||||
)
|
|
||||||
app.include_router(
|
app.include_router(
|
||||||
backend.api.features.oauth.router,
|
backend.api.features.oauth.router,
|
||||||
tags=["oauth"],
|
tags=["oauth"],
|
||||||
|
|||||||
@@ -64,7 +64,6 @@ class BlockType(Enum):
|
|||||||
AI = "AI"
|
AI = "AI"
|
||||||
AYRSHARE = "Ayrshare"
|
AYRSHARE = "Ayrshare"
|
||||||
HUMAN_IN_THE_LOOP = "Human In The Loop"
|
HUMAN_IN_THE_LOOP = "Human In The Loop"
|
||||||
MCP_TOOL = "MCP Tool"
|
|
||||||
|
|
||||||
|
|
||||||
class BlockCategory(Enum):
|
class BlockCategory(Enum):
|
||||||
|
|||||||
@@ -126,7 +126,6 @@ class PrintToConsoleBlock(Block):
|
|||||||
output_schema=PrintToConsoleBlock.Output,
|
output_schema=PrintToConsoleBlock.Output,
|
||||||
test_input={"text": "Hello, World!"},
|
test_input={"text": "Hello, World!"},
|
||||||
is_sensitive_action=True,
|
is_sensitive_action=True,
|
||||||
disabled=True, # Disabled per Nick Tindle's request (OPEN-3000)
|
|
||||||
test_output=[
|
test_output=[
|
||||||
("output", "Hello, World!"),
|
("output", "Hello, World!"),
|
||||||
("status", "printed"),
|
("status", "printed"),
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ from backend.blocks.jina._auth import (
|
|||||||
from backend.blocks.search import GetRequest
|
from backend.blocks.search import GetRequest
|
||||||
from backend.data.model import SchemaField
|
from backend.data.model import SchemaField
|
||||||
from backend.util.exceptions import BlockExecutionError
|
from backend.util.exceptions import BlockExecutionError
|
||||||
from backend.util.request import HTTPClientError, HTTPServerError, validate_url
|
|
||||||
|
|
||||||
|
|
||||||
class SearchTheWebBlock(Block, GetRequest):
|
class SearchTheWebBlock(Block, GetRequest):
|
||||||
@@ -111,12 +110,7 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
|
|||||||
self, input_data: Input, *, credentials: JinaCredentials, **kwargs
|
self, input_data: Input, *, credentials: JinaCredentials, **kwargs
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
if input_data.raw_content:
|
if input_data.raw_content:
|
||||||
try:
|
url = input_data.url
|
||||||
parsed_url, _, _ = await validate_url(input_data.url, [])
|
|
||||||
url = parsed_url.geturl()
|
|
||||||
except ValueError as e:
|
|
||||||
yield "error", f"Invalid URL: {e}"
|
|
||||||
return
|
|
||||||
headers = {}
|
headers = {}
|
||||||
else:
|
else:
|
||||||
url = f"https://r.jina.ai/{input_data.url}"
|
url = f"https://r.jina.ai/{input_data.url}"
|
||||||
@@ -125,20 +119,5 @@ class ExtractWebsiteContentBlock(Block, GetRequest):
|
|||||||
"Authorization": f"Bearer {credentials.api_key.get_secret_value()}",
|
"Authorization": f"Bearer {credentials.api_key.get_secret_value()}",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
|
||||||
content = await self.get_request(url, json=False, headers=headers)
|
content = await self.get_request(url, json=False, headers=headers)
|
||||||
except HTTPClientError as e:
|
|
||||||
yield "error", f"Client error ({e.status_code}) fetching {input_data.url}: {e}"
|
|
||||||
return
|
|
||||||
except HTTPServerError as e:
|
|
||||||
yield "error", f"Server error ({e.status_code}) fetching {input_data.url}: {e}"
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
yield "error", f"Failed to fetch {input_data.url}: {e}"
|
|
||||||
return
|
|
||||||
|
|
||||||
if not content:
|
|
||||||
yield "error", f"No content returned for {input_data.url}"
|
|
||||||
return
|
|
||||||
|
|
||||||
yield "content", content
|
yield "content", content
|
||||||
|
|||||||
@@ -1,300 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP (Model Context Protocol) Tool Block.
|
|
||||||
|
|
||||||
A single dynamic block that can connect to any MCP server, discover available tools,
|
|
||||||
and execute them. Works like AgentExecutorBlock — the user selects a tool from a
|
|
||||||
dropdown and the input/output schema adapts dynamically.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.blocks._base import (
|
|
||||||
Block,
|
|
||||||
BlockCategory,
|
|
||||||
BlockSchemaInput,
|
|
||||||
BlockSchemaOutput,
|
|
||||||
BlockType,
|
|
||||||
)
|
|
||||||
from backend.blocks.mcp.client import MCPClient, MCPClientError
|
|
||||||
from backend.data.block import BlockInput, BlockOutput
|
|
||||||
from backend.data.model import (
|
|
||||||
CredentialsField,
|
|
||||||
CredentialsMetaInput,
|
|
||||||
OAuth2Credentials,
|
|
||||||
SchemaField,
|
|
||||||
)
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.json import validate_with_jsonschema
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
TEST_CREDENTIALS = OAuth2Credentials(
|
|
||||||
id="test-mcp-cred",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("mock-mcp-token"),
|
|
||||||
refresh_token=SecretStr("mock-refresh"),
|
|
||||||
scopes=[],
|
|
||||||
title="Mock MCP credential",
|
|
||||||
)
|
|
||||||
TEST_CREDENTIALS_INPUT = {
|
|
||||||
"provider": TEST_CREDENTIALS.provider,
|
|
||||||
"id": TEST_CREDENTIALS.id,
|
|
||||||
"type": TEST_CREDENTIALS.type,
|
|
||||||
"title": TEST_CREDENTIALS.title,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
MCPCredentials = CredentialsMetaInput[Literal[ProviderName.MCP], Literal["oauth2"]]
|
|
||||||
|
|
||||||
|
|
||||||
class MCPToolBlock(Block):
|
|
||||||
"""
|
|
||||||
A block that connects to an MCP server, lets the user pick a tool,
|
|
||||||
and executes it with dynamic input/output schema.
|
|
||||||
|
|
||||||
The flow:
|
|
||||||
1. User provides an MCP server URL (and optional credentials)
|
|
||||||
2. Frontend calls the backend to get tool list from that URL
|
|
||||||
3. User selects a tool from a dropdown (available_tools)
|
|
||||||
4. The block's input schema updates to reflect the selected tool's parameters
|
|
||||||
5. On execution, the block calls the MCP server to run the tool
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Input(BlockSchemaInput):
|
|
||||||
server_url: str = SchemaField(
|
|
||||||
description="URL of the MCP server (Streamable HTTP endpoint)",
|
|
||||||
placeholder="https://mcp.example.com/mcp",
|
|
||||||
)
|
|
||||||
credentials: MCPCredentials = CredentialsField(
|
|
||||||
discriminator="server_url",
|
|
||||||
description="MCP server OAuth credentials",
|
|
||||||
default={},
|
|
||||||
)
|
|
||||||
selected_tool: str = SchemaField(
|
|
||||||
description="The MCP tool to execute",
|
|
||||||
placeholder="Select a tool",
|
|
||||||
default="",
|
|
||||||
)
|
|
||||||
tool_input_schema: dict[str, Any] = SchemaField(
|
|
||||||
description="JSON Schema for the selected tool's input parameters. "
|
|
||||||
"Populated automatically when a tool is selected.",
|
|
||||||
default={},
|
|
||||||
hidden=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_arguments: dict[str, Any] = SchemaField(
|
|
||||||
description="Arguments to pass to the selected MCP tool. "
|
|
||||||
"The fields here are defined by the tool's input schema.",
|
|
||||||
default={},
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_input_schema(cls, data: BlockInput) -> dict[str, Any]:
|
|
||||||
"""Return the tool's input schema so the builder UI renders dynamic fields."""
|
|
||||||
return data.get("tool_input_schema", {})
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_input_defaults(cls, data: BlockInput) -> BlockInput:
|
|
||||||
"""Return the current tool_arguments as defaults for the dynamic fields."""
|
|
||||||
return data.get("tool_arguments", {})
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_missing_input(cls, data: BlockInput) -> set[str]:
|
|
||||||
"""Check which required tool arguments are missing."""
|
|
||||||
required_fields = cls.get_input_schema(data).get("required", [])
|
|
||||||
tool_arguments = data.get("tool_arguments", {})
|
|
||||||
return set(required_fields) - set(tool_arguments)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_mismatch_error(cls, data: BlockInput) -> str | None:
|
|
||||||
"""Validate tool_arguments against the tool's input schema."""
|
|
||||||
tool_schema = cls.get_input_schema(data)
|
|
||||||
if not tool_schema:
|
|
||||||
return None
|
|
||||||
tool_arguments = data.get("tool_arguments", {})
|
|
||||||
return validate_with_jsonschema(tool_schema, tool_arguments)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
|
||||||
result: Any = SchemaField(description="The result returned by the MCP tool")
|
|
||||||
error: str = SchemaField(description="Error message if the tool call failed")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(
|
|
||||||
id="a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
|
|
||||||
description="Connect to any MCP server and execute its tools. "
|
|
||||||
"Provide a server URL, select a tool, and pass arguments dynamically.",
|
|
||||||
categories={BlockCategory.DEVELOPER_TOOLS},
|
|
||||||
input_schema=MCPToolBlock.Input,
|
|
||||||
output_schema=MCPToolBlock.Output,
|
|
||||||
block_type=BlockType.MCP_TOOL,
|
|
||||||
test_credentials=TEST_CREDENTIALS,
|
|
||||||
test_input={
|
|
||||||
"server_url": "https://mcp.example.com/mcp",
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT,
|
|
||||||
"selected_tool": "get_weather",
|
|
||||||
"tool_input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
"tool_arguments": {"city": "London"},
|
|
||||||
},
|
|
||||||
test_output=[
|
|
||||||
(
|
|
||||||
"result",
|
|
||||||
{"weather": "sunny", "temperature": 20},
|
|
||||||
),
|
|
||||||
],
|
|
||||||
test_mock={
|
|
||||||
"_call_mcp_tool": lambda *a, **kw: {
|
|
||||||
"weather": "sunny",
|
|
||||||
"temperature": 20,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _call_mcp_tool(
|
|
||||||
self,
|
|
||||||
server_url: str,
|
|
||||||
tool_name: str,
|
|
||||||
arguments: dict[str, Any],
|
|
||||||
auth_token: str | None = None,
|
|
||||||
) -> Any:
|
|
||||||
"""Call a tool on the MCP server. Extracted for easy mocking in tests."""
|
|
||||||
client = MCPClient(server_url, auth_token=auth_token)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool(tool_name, arguments)
|
|
||||||
|
|
||||||
if result.is_error:
|
|
||||||
error_text = ""
|
|
||||||
for item in result.content:
|
|
||||||
if item.get("type") == "text":
|
|
||||||
error_text += item.get("text", "")
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP tool '{tool_name}' returned an error: "
|
|
||||||
f"{error_text or 'Unknown error'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract text content from the result
|
|
||||||
output_parts = []
|
|
||||||
for item in result.content:
|
|
||||||
if item.get("type") == "text":
|
|
||||||
text = item.get("text", "")
|
|
||||||
# Try to parse as JSON for structured output
|
|
||||||
try:
|
|
||||||
output_parts.append(json.loads(text))
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
output_parts.append(text)
|
|
||||||
elif item.get("type") == "image":
|
|
||||||
output_parts.append(
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"data": item.get("data"),
|
|
||||||
"mimeType": item.get("mimeType"),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
elif item.get("type") == "resource":
|
|
||||||
output_parts.append(item.get("resource", {}))
|
|
||||||
|
|
||||||
# If single result, unwrap
|
|
||||||
if len(output_parts) == 1:
|
|
||||||
return output_parts[0]
|
|
||||||
return output_parts if output_parts else None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _auto_lookup_credential(
|
|
||||||
user_id: str, server_url: str
|
|
||||||
) -> "OAuth2Credentials | None":
|
|
||||||
"""Auto-lookup stored MCP credential for a server URL.
|
|
||||||
|
|
||||||
This is a fallback for nodes that don't have ``credentials`` explicitly
|
|
||||||
set (e.g. nodes created before the credential field was wired up).
|
|
||||||
"""
|
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
try:
|
|
||||||
mgr = IntegrationCredentialsManager()
|
|
||||||
mcp_creds = await mgr.store.get_creds_by_provider(
|
|
||||||
user_id, ProviderName.MCP.value
|
|
||||||
)
|
|
||||||
best: OAuth2Credentials | None = None
|
|
||||||
for cred in mcp_creds:
|
|
||||||
if (
|
|
||||||
isinstance(cred, OAuth2Credentials)
|
|
||||||
and (cred.metadata or {}).get("mcp_server_url") == server_url
|
|
||||||
):
|
|
||||||
if best is None or (
|
|
||||||
(cred.access_token_expires_at or 0)
|
|
||||||
> (best.access_token_expires_at or 0)
|
|
||||||
):
|
|
||||||
best = cred
|
|
||||||
if best:
|
|
||||||
best = await mgr.refresh_if_needed(user_id, best)
|
|
||||||
logger.info(
|
|
||||||
"Auto-resolved MCP credential %s for %s", best.id, server_url
|
|
||||||
)
|
|
||||||
return best
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def run(
|
|
||||||
self,
|
|
||||||
input_data: Input,
|
|
||||||
*,
|
|
||||||
user_id: str,
|
|
||||||
credentials: OAuth2Credentials | None = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> BlockOutput:
|
|
||||||
if not input_data.server_url:
|
|
||||||
yield "error", "MCP server URL is required"
|
|
||||||
return
|
|
||||||
|
|
||||||
if not input_data.selected_tool:
|
|
||||||
yield "error", "No tool selected. Please select a tool from the dropdown."
|
|
||||||
return
|
|
||||||
|
|
||||||
# Validate required tool arguments before calling the server.
|
|
||||||
# The executor-level validation is bypassed for MCP blocks because
|
|
||||||
# get_input_defaults() flattens tool_arguments, stripping tool_input_schema
|
|
||||||
# from the validation context.
|
|
||||||
required = set(input_data.tool_input_schema.get("required", []))
|
|
||||||
if required:
|
|
||||||
missing = required - set(input_data.tool_arguments.keys())
|
|
||||||
if missing:
|
|
||||||
yield "error", (
|
|
||||||
f"Missing required argument(s): {', '.join(sorted(missing))}. "
|
|
||||||
f"Please fill in all required fields marked with * in the block form."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# If no credentials were injected by the executor (e.g. legacy nodes
|
|
||||||
# that don't have the credentials field set), try to auto-lookup
|
|
||||||
# the stored MCP credential for this server URL.
|
|
||||||
if credentials is None:
|
|
||||||
credentials = await self._auto_lookup_credential(
|
|
||||||
user_id, input_data.server_url
|
|
||||||
)
|
|
||||||
|
|
||||||
auth_token = (
|
|
||||||
credentials.access_token.get_secret_value() if credentials else None
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await self._call_mcp_tool(
|
|
||||||
server_url=input_data.server_url,
|
|
||||||
tool_name=input_data.selected_tool,
|
|
||||||
arguments=input_data.tool_arguments,
|
|
||||||
auth_token=auth_token,
|
|
||||||
)
|
|
||||||
yield "result", result
|
|
||||||
except MCPClientError as e:
|
|
||||||
yield "error", str(e)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"MCP tool call failed: {e}")
|
|
||||||
yield "error", f"MCP tool call failed: {str(e)}"
|
|
||||||
@@ -1,323 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP (Model Context Protocol) HTTP client.
|
|
||||||
|
|
||||||
Implements the MCP Streamable HTTP transport for listing tools and calling tools
|
|
||||||
on remote MCP servers. Uses JSON-RPC 2.0 over HTTP POST.
|
|
||||||
|
|
||||||
Handles both JSON and SSE (text/event-stream) response formats per the MCP spec.
|
|
||||||
|
|
||||||
Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MCPTool:
|
|
||||||
"""Represents an MCP tool discovered from a server."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
input_schema: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MCPCallResult:
|
|
||||||
"""Result from calling an MCP tool."""
|
|
||||||
|
|
||||||
content: list[dict[str, Any]] = field(default_factory=list)
|
|
||||||
is_error: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class MCPClientError(Exception):
|
|
||||||
"""Raised when an MCP protocol error occurs."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MCPClient:
|
|
||||||
"""
|
|
||||||
Async HTTP client for the MCP Streamable HTTP transport.
|
|
||||||
|
|
||||||
Communicates with MCP servers using JSON-RPC 2.0 over HTTP POST.
|
|
||||||
Supports optional Bearer token authentication.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
server_url: str,
|
|
||||||
auth_token: str | None = None,
|
|
||||||
):
|
|
||||||
self.server_url = server_url.rstrip("/")
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self._request_id = 0
|
|
||||||
self._session_id: str | None = None
|
|
||||||
|
|
||||||
def _next_id(self) -> int:
|
|
||||||
self._request_id += 1
|
|
||||||
return self._request_id
|
|
||||||
|
|
||||||
def _build_headers(self) -> dict[str, str]:
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Accept": "application/json, text/event-stream",
|
|
||||||
}
|
|
||||||
if self.auth_token:
|
|
||||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
|
||||||
if self._session_id:
|
|
||||||
headers["Mcp-Session-Id"] = self._session_id
|
|
||||||
return headers
|
|
||||||
|
|
||||||
def _build_jsonrpc_request(
|
|
||||||
self, method: str, params: dict[str, Any] | None = None
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
req: dict[str, Any] = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"method": method,
|
|
||||||
"id": self._next_id(),
|
|
||||||
}
|
|
||||||
if params is not None:
|
|
||||||
req["params"] = params
|
|
||||||
return req
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_sse_response(text: str) -> dict[str, Any]:
|
|
||||||
"""Parse an SSE (text/event-stream) response body into JSON-RPC data.
|
|
||||||
|
|
||||||
MCP servers may return responses as SSE with format:
|
|
||||||
event: message
|
|
||||||
data: {"jsonrpc":"2.0","result":{...},"id":1}
|
|
||||||
|
|
||||||
We extract the last `data:` line that contains a JSON-RPC response
|
|
||||||
(i.e. has an "id" field), which is the reply to our request.
|
|
||||||
"""
|
|
||||||
last_data: dict[str, Any] | None = None
|
|
||||||
for line in text.splitlines():
|
|
||||||
stripped = line.strip()
|
|
||||||
if stripped.startswith("data:"):
|
|
||||||
payload = stripped[len("data:") :].strip()
|
|
||||||
if not payload:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
parsed = json.loads(payload)
|
|
||||||
# Only keep JSON-RPC responses (have "id"), skip notifications
|
|
||||||
if isinstance(parsed, dict) and "id" in parsed:
|
|
||||||
last_data = parsed
|
|
||||||
except (json.JSONDecodeError, ValueError):
|
|
||||||
continue
|
|
||||||
if last_data is None:
|
|
||||||
raise MCPClientError("No JSON-RPC response found in SSE stream")
|
|
||||||
return last_data
|
|
||||||
|
|
||||||
async def _send_request(
|
|
||||||
self, method: str, params: dict[str, Any] | None = None
|
|
||||||
) -> Any:
|
|
||||||
"""Send a JSON-RPC request to the MCP server and return the result.
|
|
||||||
|
|
||||||
Handles both ``application/json`` and ``text/event-stream`` responses
|
|
||||||
as required by the MCP Streamable HTTP transport specification.
|
|
||||||
"""
|
|
||||||
payload = self._build_jsonrpc_request(method, params)
|
|
||||||
headers = self._build_headers()
|
|
||||||
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=True,
|
|
||||||
extra_headers=headers,
|
|
||||||
)
|
|
||||||
response = await requests.post(self.server_url, json=payload)
|
|
||||||
|
|
||||||
# Capture session ID from response (MCP Streamable HTTP transport)
|
|
||||||
session_id = response.headers.get("Mcp-Session-Id")
|
|
||||||
if session_id:
|
|
||||||
self._session_id = session_id
|
|
||||||
|
|
||||||
content_type = response.headers.get("content-type", "")
|
|
||||||
if "text/event-stream" in content_type:
|
|
||||||
body = self._parse_sse_response(response.text())
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
body = response.json()
|
|
||||||
except Exception as e:
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP server returned non-JSON response: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
if not isinstance(body, dict):
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP server returned unexpected JSON type: {type(body).__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle JSON-RPC error
|
|
||||||
if "error" in body:
|
|
||||||
error = body["error"]
|
|
||||||
if isinstance(error, dict):
|
|
||||||
raise MCPClientError(
|
|
||||||
f"MCP server error [{error.get('code', '?')}]: "
|
|
||||||
f"{error.get('message', 'Unknown error')}"
|
|
||||||
)
|
|
||||||
raise MCPClientError(f"MCP server error: {error}")
|
|
||||||
|
|
||||||
return body.get("result")
|
|
||||||
|
|
||||||
async def _send_notification(self, method: str) -> None:
|
|
||||||
"""Send a JSON-RPC notification (no id, no response expected)."""
|
|
||||||
headers = self._build_headers()
|
|
||||||
notification = {"jsonrpc": "2.0", "method": method}
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=False,
|
|
||||||
extra_headers=headers,
|
|
||||||
)
|
|
||||||
await requests.post(self.server_url, json=notification)
|
|
||||||
|
|
||||||
async def discover_auth(self) -> dict[str, Any] | None:
|
|
||||||
"""Probe the MCP server's OAuth metadata (RFC 9728 / MCP spec).
|
|
||||||
|
|
||||||
Returns ``None`` if the server doesn't require auth, otherwise returns
|
|
||||||
a dict with:
|
|
||||||
- ``authorization_servers``: list of authorization server URLs
|
|
||||||
- ``resource``: the resource indicator URL (usually the MCP endpoint)
|
|
||||||
- ``scopes_supported``: optional list of supported scopes
|
|
||||||
|
|
||||||
The caller can then fetch the authorization server metadata to get
|
|
||||||
``authorization_endpoint``, ``token_endpoint``, etc.
|
|
||||||
"""
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
parsed = urlparse(self.server_url)
|
|
||||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
|
||||||
|
|
||||||
# Build candidates for protected-resource metadata (per RFC 9728)
|
|
||||||
path = parsed.path.rstrip("/")
|
|
||||||
candidates = []
|
|
||||||
if path and path != "/":
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-protected-resource{path}")
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-protected-resource")
|
|
||||||
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=False,
|
|
||||||
)
|
|
||||||
for url in candidates:
|
|
||||||
try:
|
|
||||||
resp = await requests.get(url)
|
|
||||||
if resp.status == 200:
|
|
||||||
data = resp.json()
|
|
||||||
if isinstance(data, dict) and "authorization_servers" in data:
|
|
||||||
return data
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def discover_auth_server_metadata(
|
|
||||||
self, auth_server_url: str
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""Fetch the OAuth Authorization Server Metadata (RFC 8414).
|
|
||||||
|
|
||||||
Given an authorization server URL, returns a dict with:
|
|
||||||
- ``authorization_endpoint``
|
|
||||||
- ``token_endpoint``
|
|
||||||
- ``registration_endpoint`` (for dynamic client registration)
|
|
||||||
- ``scopes_supported``
|
|
||||||
- ``code_challenge_methods_supported``
|
|
||||||
- etc.
|
|
||||||
"""
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
parsed = urlparse(auth_server_url)
|
|
||||||
base = f"{parsed.scheme}://{parsed.netloc}"
|
|
||||||
path = parsed.path.rstrip("/")
|
|
||||||
|
|
||||||
# Try standard metadata endpoints (RFC 8414 and OpenID Connect)
|
|
||||||
candidates = []
|
|
||||||
if path and path != "/":
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-authorization-server{path}")
|
|
||||||
candidates.append(f"{base}/.well-known/oauth-authorization-server")
|
|
||||||
candidates.append(f"{base}/.well-known/openid-configuration")
|
|
||||||
|
|
||||||
requests = Requests(
|
|
||||||
raise_for_status=False,
|
|
||||||
)
|
|
||||||
for url in candidates:
|
|
||||||
try:
|
|
||||||
resp = await requests.get(url)
|
|
||||||
if resp.status == 200:
|
|
||||||
data = resp.json()
|
|
||||||
if isinstance(data, dict) and "authorization_endpoint" in data:
|
|
||||||
return data
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def initialize(self) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Send the MCP initialize request.
|
|
||||||
|
|
||||||
This is required by the MCP protocol before any other requests.
|
|
||||||
Returns the server's capabilities.
|
|
||||||
"""
|
|
||||||
result = await self._send_request(
|
|
||||||
"initialize",
|
|
||||||
{
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"capabilities": {},
|
|
||||||
"clientInfo": {"name": "AutoGPT-Platform", "version": "1.0.0"},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
# Send initialized notification (no response expected)
|
|
||||||
await self._send_notification("notifications/initialized")
|
|
||||||
|
|
||||||
return result or {}
|
|
||||||
|
|
||||||
async def list_tools(self) -> list[MCPTool]:
|
|
||||||
"""
|
|
||||||
Discover available tools from the MCP server.
|
|
||||||
|
|
||||||
Returns a list of MCPTool objects with name, description, and input schema.
|
|
||||||
"""
|
|
||||||
result = await self._send_request("tools/list")
|
|
||||||
if not result or "tools" not in result:
|
|
||||||
return []
|
|
||||||
|
|
||||||
tools = []
|
|
||||||
for tool_data in result["tools"]:
|
|
||||||
tools.append(
|
|
||||||
MCPTool(
|
|
||||||
name=tool_data.get("name", ""),
|
|
||||||
description=tool_data.get("description", ""),
|
|
||||||
input_schema=tool_data.get("inputSchema", {}),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return tools
|
|
||||||
|
|
||||||
async def call_tool(
|
|
||||||
self, tool_name: str, arguments: dict[str, Any]
|
|
||||||
) -> MCPCallResult:
|
|
||||||
"""
|
|
||||||
Call a tool on the MCP server.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: The name of the tool to call.
|
|
||||||
arguments: The arguments to pass to the tool.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MCPCallResult with the tool's response content.
|
|
||||||
"""
|
|
||||||
result = await self._send_request(
|
|
||||||
"tools/call",
|
|
||||||
{"name": tool_name, "arguments": arguments},
|
|
||||||
)
|
|
||||||
if not result:
|
|
||||||
return MCPCallResult(is_error=True)
|
|
||||||
|
|
||||||
return MCPCallResult(
|
|
||||||
content=result.get("content", []),
|
|
||||||
is_error=result.get("isError", False),
|
|
||||||
)
|
|
||||||
@@ -1,204 +0,0 @@
|
|||||||
"""
|
|
||||||
MCP OAuth handler for MCP servers that use OAuth 2.1 authorization.
|
|
||||||
|
|
||||||
Unlike other OAuth handlers (GitHub, Google, etc.) where endpoints are fixed,
|
|
||||||
MCP servers have dynamic endpoints discovered via RFC 9728 / RFC 8414 metadata.
|
|
||||||
This handler accepts those endpoints at construction time.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import urllib.parse
|
|
||||||
from typing import ClassVar, Optional
|
|
||||||
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
from backend.integrations.oauth.base import BaseOAuthHandler
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class MCPOAuthHandler(BaseOAuthHandler):
|
|
||||||
"""
|
|
||||||
OAuth handler for MCP servers with dynamically-discovered endpoints.
|
|
||||||
|
|
||||||
Construction requires the authorization and token endpoint URLs,
|
|
||||||
which are obtained via MCP OAuth metadata discovery
|
|
||||||
(``MCPClient.discover_auth`` + ``discover_auth_server_metadata``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
PROVIDER_NAME: ClassVar[ProviderName | str] = ProviderName.MCP
|
|
||||||
DEFAULT_SCOPES: ClassVar[list[str]] = []
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
client_id: str,
|
|
||||||
client_secret: str,
|
|
||||||
redirect_uri: str,
|
|
||||||
*,
|
|
||||||
authorize_url: str,
|
|
||||||
token_url: str,
|
|
||||||
revoke_url: str | None = None,
|
|
||||||
resource_url: str | None = None,
|
|
||||||
):
|
|
||||||
self.client_id = client_id
|
|
||||||
self.client_secret = client_secret
|
|
||||||
self.redirect_uri = redirect_uri
|
|
||||||
self.authorize_url = authorize_url
|
|
||||||
self.token_url = token_url
|
|
||||||
self.revoke_url = revoke_url
|
|
||||||
self.resource_url = resource_url
|
|
||||||
|
|
||||||
def get_login_url(
|
|
||||||
self,
|
|
||||||
scopes: list[str],
|
|
||||||
state: str,
|
|
||||||
code_challenge: Optional[str],
|
|
||||||
) -> str:
|
|
||||||
scopes = self.handle_default_scopes(scopes)
|
|
||||||
|
|
||||||
params: dict[str, str] = {
|
|
||||||
"response_type": "code",
|
|
||||||
"client_id": self.client_id,
|
|
||||||
"redirect_uri": self.redirect_uri,
|
|
||||||
"state": state,
|
|
||||||
}
|
|
||||||
if scopes:
|
|
||||||
params["scope"] = " ".join(scopes)
|
|
||||||
# PKCE (S256) — included when the caller provides a code_challenge
|
|
||||||
if code_challenge:
|
|
||||||
params["code_challenge"] = code_challenge
|
|
||||||
params["code_challenge_method"] = "S256"
|
|
||||||
# MCP spec requires resource indicator (RFC 8707)
|
|
||||||
if self.resource_url:
|
|
||||||
params["resource"] = self.resource_url
|
|
||||||
|
|
||||||
return f"{self.authorize_url}?{urllib.parse.urlencode(params)}"
|
|
||||||
|
|
||||||
async def exchange_code_for_tokens(
|
|
||||||
self,
|
|
||||||
code: str,
|
|
||||||
scopes: list[str],
|
|
||||||
code_verifier: Optional[str],
|
|
||||||
) -> OAuth2Credentials:
|
|
||||||
data: dict[str, str] = {
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"code": code,
|
|
||||||
"redirect_uri": self.redirect_uri,
|
|
||||||
"client_id": self.client_id,
|
|
||||||
}
|
|
||||||
if self.client_secret:
|
|
||||||
data["client_secret"] = self.client_secret
|
|
||||||
if code_verifier:
|
|
||||||
data["code_verifier"] = code_verifier
|
|
||||||
if self.resource_url:
|
|
||||||
data["resource"] = self.resource_url
|
|
||||||
|
|
||||||
response = await Requests(raise_for_status=True).post(
|
|
||||||
self.token_url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
tokens = response.json()
|
|
||||||
|
|
||||||
if "error" in tokens:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Token exchange failed: {tokens.get('error_description', tokens['error'])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if "access_token" not in tokens:
|
|
||||||
raise RuntimeError("OAuth token response missing 'access_token' field")
|
|
||||||
|
|
||||||
now = int(time.time())
|
|
||||||
expires_in = tokens.get("expires_in")
|
|
||||||
|
|
||||||
return OAuth2Credentials(
|
|
||||||
provider=self.PROVIDER_NAME,
|
|
||||||
title=None,
|
|
||||||
access_token=SecretStr(tokens["access_token"]),
|
|
||||||
refresh_token=(
|
|
||||||
SecretStr(tokens["refresh_token"])
|
|
||||||
if tokens.get("refresh_token")
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
access_token_expires_at=now + expires_in if expires_in else None,
|
|
||||||
refresh_token_expires_at=None,
|
|
||||||
scopes=scopes,
|
|
||||||
metadata={
|
|
||||||
"mcp_token_url": self.token_url,
|
|
||||||
"mcp_resource_url": self.resource_url,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _refresh_tokens(
|
|
||||||
self, credentials: OAuth2Credentials
|
|
||||||
) -> OAuth2Credentials:
|
|
||||||
if not credentials.refresh_token:
|
|
||||||
raise ValueError("No refresh token available for MCP OAuth credentials")
|
|
||||||
|
|
||||||
data: dict[str, str] = {
|
|
||||||
"grant_type": "refresh_token",
|
|
||||||
"refresh_token": credentials.refresh_token.get_secret_value(),
|
|
||||||
"client_id": self.client_id,
|
|
||||||
}
|
|
||||||
if self.client_secret:
|
|
||||||
data["client_secret"] = self.client_secret
|
|
||||||
if self.resource_url:
|
|
||||||
data["resource"] = self.resource_url
|
|
||||||
|
|
||||||
response = await Requests(raise_for_status=True).post(
|
|
||||||
self.token_url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
tokens = response.json()
|
|
||||||
|
|
||||||
if "error" in tokens:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Token refresh failed: {tokens.get('error_description', tokens['error'])}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if "access_token" not in tokens:
|
|
||||||
raise RuntimeError("OAuth refresh response missing 'access_token' field")
|
|
||||||
|
|
||||||
now = int(time.time())
|
|
||||||
expires_in = tokens.get("expires_in")
|
|
||||||
|
|
||||||
return OAuth2Credentials(
|
|
||||||
id=credentials.id,
|
|
||||||
provider=self.PROVIDER_NAME,
|
|
||||||
title=credentials.title,
|
|
||||||
access_token=SecretStr(tokens["access_token"]),
|
|
||||||
refresh_token=(
|
|
||||||
SecretStr(tokens["refresh_token"])
|
|
||||||
if tokens.get("refresh_token")
|
|
||||||
else credentials.refresh_token
|
|
||||||
),
|
|
||||||
access_token_expires_at=now + expires_in if expires_in else None,
|
|
||||||
refresh_token_expires_at=credentials.refresh_token_expires_at,
|
|
||||||
scopes=credentials.scopes,
|
|
||||||
metadata=credentials.metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
|
|
||||||
if not self.revoke_url:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = {
|
|
||||||
"token": credentials.access_token.get_secret_value(),
|
|
||||||
"token_type_hint": "access_token",
|
|
||||||
"client_id": self.client_id,
|
|
||||||
}
|
|
||||||
await Requests().post(
|
|
||||||
self.revoke_url,
|
|
||||||
data=data,
|
|
||||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to revoke MCP OAuth tokens", exc_info=True)
|
|
||||||
return False
|
|
||||||
@@ -1,109 +0,0 @@
|
|||||||
"""
|
|
||||||
End-to-end tests against a real public MCP server.
|
|
||||||
|
|
||||||
These tests hit the OpenAI docs MCP server (https://developers.openai.com/mcp)
|
|
||||||
which is publicly accessible without authentication and returns SSE responses.
|
|
||||||
|
|
||||||
Mark: These are tagged with ``@pytest.mark.e2e`` so they can be run/skipped
|
|
||||||
independently of the rest of the test suite (they require network access).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.blocks.mcp.client import MCPClient
|
|
||||||
|
|
||||||
# Public MCP server that requires no authentication
|
|
||||||
OPENAI_DOCS_MCP_URL = "https://developers.openai.com/mcp"
|
|
||||||
|
|
||||||
# Skip all tests in this module unless RUN_E2E env var is set
|
|
||||||
pytestmark = pytest.mark.skipif(
|
|
||||||
not os.environ.get("RUN_E2E"), reason="set RUN_E2E=1 to run e2e tests"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRealMCPServer:
|
|
||||||
"""Tests against the live OpenAI docs MCP server."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_initialize(self):
|
|
||||||
"""Verify we can complete the MCP handshake with a real server."""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
assert "serverInfo" in result
|
|
||||||
assert result["serverInfo"]["name"] == "openai-docs-mcp"
|
|
||||||
assert "tools" in result.get("capabilities", {})
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_list_tools(self):
|
|
||||||
"""Verify we can discover tools from a real MCP server."""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert len(tools) >= 3 # server has at least 5 tools as of writing
|
|
||||||
|
|
||||||
tool_names = {t.name for t in tools}
|
|
||||||
# These tools are documented and should be stable
|
|
||||||
assert "search_openai_docs" in tool_names
|
|
||||||
assert "list_openai_docs" in tool_names
|
|
||||||
assert "fetch_openai_doc" in tool_names
|
|
||||||
|
|
||||||
# Verify schema structure
|
|
||||||
search_tool = next(t for t in tools if t.name == "search_openai_docs")
|
|
||||||
assert "query" in search_tool.input_schema.get("properties", {})
|
|
||||||
assert "query" in search_tool.input_schema.get("required", [])
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_list_api_endpoints(self):
|
|
||||||
"""Call the list_api_endpoints tool and verify we get real data."""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("list_api_endpoints", {})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) >= 1
|
|
||||||
assert result.content[0]["type"] == "text"
|
|
||||||
|
|
||||||
data = json.loads(result.content[0]["text"])
|
|
||||||
assert "paths" in data or "urls" in data
|
|
||||||
# The OpenAI API should have many endpoints
|
|
||||||
total = data.get("total", len(data.get("paths", [])))
|
|
||||||
assert total > 50
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_search(self):
|
|
||||||
"""Search for docs and verify we get results."""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool(
|
|
||||||
"search_openai_docs", {"query": "chat completions", "limit": 3}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) >= 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_sse_response_handling(self):
|
|
||||||
"""Verify the client correctly handles SSE responses from a real server.
|
|
||||||
|
|
||||||
This is the key test — our local test server returns JSON,
|
|
||||||
but real MCP servers typically return SSE. This proves the
|
|
||||||
SSE parsing works end-to-end.
|
|
||||||
"""
|
|
||||||
client = MCPClient(OPENAI_DOCS_MCP_URL)
|
|
||||||
# initialize() internally calls _send_request which must parse SSE
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
# If we got here without error, SSE parsing works
|
|
||||||
assert isinstance(result, dict)
|
|
||||||
assert "protocolVersion" in result
|
|
||||||
|
|
||||||
# Also verify list_tools works (another SSE response)
|
|
||||||
tools = await client.list_tools()
|
|
||||||
assert len(tools) > 0
|
|
||||||
assert all(hasattr(t, "name") for t in tools)
|
|
||||||
@@ -1,389 +0,0 @@
|
|||||||
"""
|
|
||||||
Integration tests for MCP client and MCPToolBlock against a real HTTP server.
|
|
||||||
|
|
||||||
These tests spin up a local MCP test server and run the full client/block flow
|
|
||||||
against it — no mocking, real HTTP requests.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import threading
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from aiohttp import web
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.blocks.mcp.block import MCPToolBlock
|
|
||||||
from backend.blocks.mcp.client import MCPClient
|
|
||||||
from backend.blocks.mcp.test_server import create_test_mcp_app
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
MOCK_USER_ID = "test-user-integration"
|
|
||||||
|
|
||||||
|
|
||||||
class _MCPTestServer:
|
|
||||||
"""
|
|
||||||
Run an MCP test server in a background thread with its own event loop.
|
|
||||||
This avoids event loop conflicts with pytest-asyncio.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, auth_token: str | None = None):
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self.url: str = ""
|
|
||||||
self._runner: web.AppRunner | None = None
|
|
||||||
self._loop: asyncio.AbstractEventLoop | None = None
|
|
||||||
self._thread: threading.Thread | None = None
|
|
||||||
self._started = threading.Event()
|
|
||||||
|
|
||||||
def _run(self):
|
|
||||||
self._loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(self._loop)
|
|
||||||
self._loop.run_until_complete(self._start())
|
|
||||||
self._started.set()
|
|
||||||
self._loop.run_forever()
|
|
||||||
|
|
||||||
async def _start(self):
|
|
||||||
app = create_test_mcp_app(auth_token=self.auth_token)
|
|
||||||
self._runner = web.AppRunner(app)
|
|
||||||
await self._runner.setup()
|
|
||||||
site = web.TCPSite(self._runner, "127.0.0.1", 0)
|
|
||||||
await site.start()
|
|
||||||
port = site._server.sockets[0].getsockname()[1] # type: ignore[union-attr]
|
|
||||||
self.url = f"http://127.0.0.1:{port}/mcp"
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
self._thread = threading.Thread(target=self._run, daemon=True)
|
|
||||||
self._thread.start()
|
|
||||||
if not self._started.wait(timeout=5):
|
|
||||||
raise RuntimeError("MCP test server failed to start within 5 seconds")
|
|
||||||
return self
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
if self._loop and self._runner:
|
|
||||||
asyncio.run_coroutine_threadsafe(self._runner.cleanup(), self._loop).result(
|
|
||||||
timeout=5
|
|
||||||
)
|
|
||||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
|
||||||
if self._thread:
|
|
||||||
self._thread.join(timeout=5)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def mcp_server():
|
|
||||||
"""Start a local MCP test server in a background thread."""
|
|
||||||
server = _MCPTestServer()
|
|
||||||
server.start()
|
|
||||||
yield server.url
|
|
||||||
server.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def mcp_server_with_auth():
|
|
||||||
"""Start a local MCP test server with auth in a background thread."""
|
|
||||||
server = _MCPTestServer(auth_token="test-secret-token")
|
|
||||||
server.start()
|
|
||||||
yield server.url, "test-secret-token"
|
|
||||||
server.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _allow_localhost():
|
|
||||||
"""
|
|
||||||
Allow 127.0.0.1 through SSRF protection for integration tests.
|
|
||||||
|
|
||||||
The Requests class blocks private IPs by default. We patch the Requests
|
|
||||||
constructor to always include 127.0.0.1 as a trusted origin so the local
|
|
||||||
test server is reachable.
|
|
||||||
"""
|
|
||||||
from backend.util.request import Requests
|
|
||||||
|
|
||||||
original_init = Requests.__init__
|
|
||||||
|
|
||||||
def patched_init(self, *args, **kwargs):
|
|
||||||
trusted = list(kwargs.get("trusted_origins") or [])
|
|
||||||
trusted.append("http://127.0.0.1")
|
|
||||||
kwargs["trusted_origins"] = trusted
|
|
||||||
original_init(self, *args, **kwargs)
|
|
||||||
|
|
||||||
with patch.object(Requests, "__init__", patched_init):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
def _make_client(url: str, auth_token: str | None = None) -> MCPClient:
|
|
||||||
"""Create an MCPClient for integration tests."""
|
|
||||||
return MCPClient(url, auth_token=auth_token)
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPClient integration tests ──────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClientIntegration:
|
|
||||||
"""Test MCPClient against a real local MCP server."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_initialize(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
assert result["serverInfo"]["name"] == "test-mcp-server"
|
|
||||||
assert "tools" in result["capabilities"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_list_tools(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
tool_names = {t.name for t in tools}
|
|
||||||
assert tool_names == {"get_weather", "add_numbers", "echo"}
|
|
||||||
|
|
||||||
# Check get_weather schema
|
|
||||||
weather = next(t for t in tools if t.name == "get_weather")
|
|
||||||
assert weather.description == "Get current weather for a city"
|
|
||||||
assert "city" in weather.input_schema["properties"]
|
|
||||||
assert weather.input_schema["required"] == ["city"]
|
|
||||||
|
|
||||||
# Check add_numbers schema
|
|
||||||
add = next(t for t in tools if t.name == "add_numbers")
|
|
||||||
assert "a" in add.input_schema["properties"]
|
|
||||||
assert "b" in add.input_schema["properties"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_get_weather(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("get_weather", {"city": "London"})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) == 1
|
|
||||||
assert result.content[0]["type"] == "text"
|
|
||||||
|
|
||||||
data = json.loads(result.content[0]["text"])
|
|
||||||
assert data["city"] == "London"
|
|
||||||
assert data["temperature"] == 22
|
|
||||||
assert data["condition"] == "sunny"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_add_numbers(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("add_numbers", {"a": 3, "b": 7})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
data = json.loads(result.content[0]["text"])
|
|
||||||
assert data["result"] == 10
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_echo(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("echo", {"message": "Hello MCP!"})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert result.content[0]["text"] == "Hello MCP!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_unknown_tool(self, mcp_server):
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
result = await client.call_tool("nonexistent_tool", {})
|
|
||||||
|
|
||||||
assert result.is_error
|
|
||||||
assert "Unknown tool" in result.content[0]["text"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_auth_success(self, mcp_server_with_auth):
|
|
||||||
url, token = mcp_server_with_auth
|
|
||||||
client = _make_client(url, auth_token=token)
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
|
|
||||||
tools = await client.list_tools()
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_auth_failure(self, mcp_server_with_auth):
|
|
||||||
url, _ = mcp_server_with_auth
|
|
||||||
client = _make_client(url, auth_token="wrong-token")
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await client.initialize()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_auth_missing(self, mcp_server_with_auth):
|
|
||||||
url, _ = mcp_server_with_auth
|
|
||||||
client = _make_client(url)
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await client.initialize()
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPToolBlock integration tests ───────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPToolBlockIntegration:
|
|
||||||
"""Test MCPToolBlock end-to-end against a real local MCP server."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_full_flow_get_weather(self, mcp_server):
|
|
||||||
"""Full flow: discover tools, select one, execute it."""
|
|
||||||
# Step 1: Discover tools (simulating what the frontend/API would do)
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
assert len(tools) == 3
|
|
||||||
|
|
||||||
# Step 2: User selects "get_weather" and we get its schema
|
|
||||||
weather_tool = next(t for t in tools if t.name == "get_weather")
|
|
||||||
|
|
||||||
# Step 3: Execute the block — no credentials (public server)
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="get_weather",
|
|
||||||
tool_input_schema=weather_tool.input_schema,
|
|
||||||
tool_arguments={"city": "Paris"},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
result = outputs[0][1]
|
|
||||||
assert result["city"] == "Paris"
|
|
||||||
assert result["temperature"] == 22
|
|
||||||
assert result["condition"] == "sunny"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_full_flow_add_numbers(self, mcp_server):
|
|
||||||
"""Full flow for add_numbers tool."""
|
|
||||||
client = _make_client(mcp_server)
|
|
||||||
await client.initialize()
|
|
||||||
tools = await client.list_tools()
|
|
||||||
add_tool = next(t for t in tools if t.name == "add_numbers")
|
|
||||||
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="add_numbers",
|
|
||||||
tool_input_schema=add_tool.input_schema,
|
|
||||||
tool_arguments={"a": 42, "b": 58},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1]["result"] == 100
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_full_flow_echo_plain_text(self, mcp_server):
|
|
||||||
"""Verify plain text (non-JSON) responses work."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="echo",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"message": {"type": "string"}},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
tool_arguments={"message": "Hello from AutoGPT!"},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == "Hello from AutoGPT!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_full_flow_unknown_tool_yields_error(self, mcp_server):
|
|
||||||
"""Calling an unknown tool should yield an error output."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="nonexistent_tool",
|
|
||||||
tool_arguments={},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "error"
|
|
||||||
assert "returned an error" in outputs[0][1]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_full_flow_with_auth(self, mcp_server_with_auth):
|
|
||||||
"""Full flow with authentication via credentials kwarg."""
|
|
||||||
url, token = mcp_server_with_auth
|
|
||||||
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=url,
|
|
||||||
selected_tool="echo",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"message": {"type": "string"}},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
tool_arguments={"message": "Authenticated!"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Pass credentials via the standard kwarg (as the executor would)
|
|
||||||
test_creds = OAuth2Credentials(
|
|
||||||
id="test-cred",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr(token),
|
|
||||||
refresh_token=SecretStr(""),
|
|
||||||
scopes=[],
|
|
||||||
title="Test MCP credential",
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(
|
|
||||||
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
|
||||||
):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == "Authenticated!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_no_credentials_runs_without_auth(self, mcp_server):
|
|
||||||
"""Block runs without auth when no credentials are provided."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url=mcp_server,
|
|
||||||
selected_tool="echo",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"message": {"type": "string"}},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
tool_arguments={"message": "No auth needed"},
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(
|
|
||||||
input_data, user_id=MOCK_USER_ID, credentials=None
|
|
||||||
):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == "No auth needed"
|
|
||||||
@@ -1,619 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for MCP client and MCPToolBlock.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.blocks.mcp.block import MCPToolBlock
|
|
||||||
from backend.blocks.mcp.client import MCPCallResult, MCPClient, MCPClientError
|
|
||||||
from backend.util.test import execute_block_test
|
|
||||||
|
|
||||||
# ── SSE parsing unit tests ───────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestSSEParsing:
|
|
||||||
"""Tests for SSE (text/event-stream) response parsing."""
|
|
||||||
|
|
||||||
def test_parse_sse_simple(self):
|
|
||||||
sse = (
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":{"tools":[]},"id":1}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == {"tools": []}
|
|
||||||
assert body["id"] == 1
|
|
||||||
|
|
||||||
def test_parse_sse_with_notifications(self):
|
|
||||||
"""SSE streams can contain notifications (no id) before the response."""
|
|
||||||
sse = (
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","method":"some/notification"}\n'
|
|
||||||
"\n"
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":{"ok":true},"id":2}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == {"ok": True}
|
|
||||||
assert body["id"] == 2
|
|
||||||
|
|
||||||
def test_parse_sse_error_response(self):
|
|
||||||
sse = (
|
|
||||||
"event: message\n"
|
|
||||||
'data: {"jsonrpc":"2.0","error":{"code":-32600,"message":"Bad Request"},"id":1}\n'
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert "error" in body
|
|
||||||
assert body["error"]["code"] == -32600
|
|
||||||
|
|
||||||
def test_parse_sse_no_data_raises(self):
|
|
||||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
|
||||||
MCPClient._parse_sse_response("event: message\n\n")
|
|
||||||
|
|
||||||
def test_parse_sse_empty_raises(self):
|
|
||||||
with pytest.raises(MCPClientError, match="No JSON-RPC response found"):
|
|
||||||
MCPClient._parse_sse_response("")
|
|
||||||
|
|
||||||
def test_parse_sse_ignores_non_data_lines(self):
|
|
||||||
sse = (
|
|
||||||
": comment line\n"
|
|
||||||
"event: message\n"
|
|
||||||
"id: 123\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":"ok","id":1}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == "ok"
|
|
||||||
|
|
||||||
def test_parse_sse_uses_last_response(self):
|
|
||||||
"""If multiple responses exist, use the last one."""
|
|
||||||
sse = (
|
|
||||||
'data: {"jsonrpc":"2.0","result":"first","id":1}\n'
|
|
||||||
"\n"
|
|
||||||
'data: {"jsonrpc":"2.0","result":"second","id":2}\n'
|
|
||||||
"\n"
|
|
||||||
)
|
|
||||||
body = MCPClient._parse_sse_response(sse)
|
|
||||||
assert body["result"] == "second"
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPClient unit tests ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClient:
|
|
||||||
"""Tests for the MCP HTTP client."""
|
|
||||||
|
|
||||||
def test_build_headers_without_auth(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
headers = client._build_headers()
|
|
||||||
assert "Authorization" not in headers
|
|
||||||
assert headers["Content-Type"] == "application/json"
|
|
||||||
|
|
||||||
def test_build_headers_with_auth(self):
|
|
||||||
client = MCPClient("https://mcp.example.com", auth_token="my-token")
|
|
||||||
headers = client._build_headers()
|
|
||||||
assert headers["Authorization"] == "Bearer my-token"
|
|
||||||
|
|
||||||
def test_build_jsonrpc_request(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
req = client._build_jsonrpc_request("tools/list")
|
|
||||||
assert req["jsonrpc"] == "2.0"
|
|
||||||
assert req["method"] == "tools/list"
|
|
||||||
assert "id" in req
|
|
||||||
assert "params" not in req
|
|
||||||
|
|
||||||
def test_build_jsonrpc_request_with_params(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
req = client._build_jsonrpc_request(
|
|
||||||
"tools/call", {"name": "test", "arguments": {"x": 1}}
|
|
||||||
)
|
|
||||||
assert req["params"] == {"name": "test", "arguments": {"x": 1}}
|
|
||||||
|
|
||||||
def test_request_id_increments(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
req1 = client._build_jsonrpc_request("tools/list")
|
|
||||||
req2 = client._build_jsonrpc_request("tools/list")
|
|
||||||
assert req2["id"] > req1["id"]
|
|
||||||
|
|
||||||
def test_server_url_trailing_slash_stripped(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp/")
|
|
||||||
assert client.server_url == "https://mcp.example.com/mcp"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_send_request_success(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_response = AsyncMock()
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"result": {"tools": []},
|
|
||||||
"id": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
|
||||||
result = await client._send_request("tools/list")
|
|
||||||
assert result == {"tools": []}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_send_request_error(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
async def mock_send(*args, **kwargs):
|
|
||||||
raise MCPClientError("MCP server error [-32600]: Invalid Request")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", side_effect=mock_send):
|
|
||||||
with pytest.raises(MCPClientError, match="Invalid Request"):
|
|
||||||
await client._send_request("tools/list")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_list_tools(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"tools": [
|
|
||||||
{
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get current weather for a city",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "search",
|
|
||||||
"description": "Search the web",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"query": {"type": "string"}},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=mock_result):
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert len(tools) == 2
|
|
||||||
assert tools[0].name == "get_weather"
|
|
||||||
assert tools[0].description == "Get current weather for a city"
|
|
||||||
assert tools[0].input_schema["properties"]["city"]["type"] == "string"
|
|
||||||
assert tools[1].name == "search"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_list_tools_empty(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value={"tools": []}):
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert tools == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_list_tools_none_result(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=None):
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert tools == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_success(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"content": [
|
|
||||||
{"type": "text", "text": json.dumps({"temp": 20, "city": "London"})}
|
|
||||||
],
|
|
||||||
"isError": False,
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=mock_result):
|
|
||||||
result = await client.call_tool("get_weather", {"city": "London"})
|
|
||||||
|
|
||||||
assert not result.is_error
|
|
||||||
assert len(result.content) == 1
|
|
||||||
assert result.content[0]["type"] == "text"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_error(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"content": [{"type": "text", "text": "City not found"}],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=mock_result):
|
|
||||||
result = await client.call_tool("get_weather", {"city": "???"})
|
|
||||||
|
|
||||||
assert result.is_error
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_tool_none_result(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
with patch.object(client, "_send_request", return_value=None):
|
|
||||||
result = await client.call_tool("get_weather", {"city": "London"})
|
|
||||||
|
|
||||||
assert result.is_error
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_initialize(self):
|
|
||||||
client = MCPClient("https://mcp.example.com")
|
|
||||||
|
|
||||||
mock_result = {
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"capabilities": {"tools": {}},
|
|
||||||
"serverInfo": {"name": "test-server", "version": "1.0.0"},
|
|
||||||
}
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(client, "_send_request", return_value=mock_result) as mock_req,
|
|
||||||
patch.object(client, "_send_notification") as mock_notif,
|
|
||||||
):
|
|
||||||
result = await client.initialize()
|
|
||||||
|
|
||||||
mock_req.assert_called_once()
|
|
||||||
mock_notif.assert_called_once_with("notifications/initialized")
|
|
||||||
assert result["protocolVersion"] == "2025-03-26"
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPToolBlock unit tests ──────────────────────────────────────────
|
|
||||||
|
|
||||||
MOCK_USER_ID = "test-user-123"
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPToolBlock:
|
|
||||||
"""Tests for the MCPToolBlock."""
|
|
||||||
|
|
||||||
def test_block_instantiation(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
assert block.id == "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4"
|
|
||||||
assert block.name == "MCPToolBlock"
|
|
||||||
|
|
||||||
def test_input_schema_has_required_fields(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
schema = block.input_schema.jsonschema()
|
|
||||||
props = schema.get("properties", {})
|
|
||||||
assert "server_url" in props
|
|
||||||
assert "selected_tool" in props
|
|
||||||
assert "tool_arguments" in props
|
|
||||||
assert "credentials" in props
|
|
||||||
|
|
||||||
def test_output_schema(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
schema = block.output_schema.jsonschema()
|
|
||||||
props = schema.get("properties", {})
|
|
||||||
assert "result" in props
|
|
||||||
assert "error" in props
|
|
||||||
|
|
||||||
def test_get_input_schema_with_tool_schema(self):
|
|
||||||
tool_schema = {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"query": {"type": "string"}},
|
|
||||||
"required": ["query"],
|
|
||||||
}
|
|
||||||
data = {"tool_input_schema": tool_schema}
|
|
||||||
result = MCPToolBlock.Input.get_input_schema(data)
|
|
||||||
assert result == tool_schema
|
|
||||||
|
|
||||||
def test_get_input_schema_without_tool_schema(self):
|
|
||||||
result = MCPToolBlock.Input.get_input_schema({})
|
|
||||||
assert result == {}
|
|
||||||
|
|
||||||
def test_get_input_defaults(self):
|
|
||||||
data = {"tool_arguments": {"city": "London"}}
|
|
||||||
result = MCPToolBlock.Input.get_input_defaults(data)
|
|
||||||
assert result == {"city": "London"}
|
|
||||||
|
|
||||||
def test_get_missing_input(self):
|
|
||||||
data = {
|
|
||||||
"tool_input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {"type": "string"},
|
|
||||||
"units": {"type": "string"},
|
|
||||||
},
|
|
||||||
"required": ["city", "units"],
|
|
||||||
},
|
|
||||||
"tool_arguments": {"city": "London"},
|
|
||||||
}
|
|
||||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
|
||||||
assert missing == {"units"}
|
|
||||||
|
|
||||||
def test_get_missing_input_all_present(self):
|
|
||||||
data = {
|
|
||||||
"tool_input_schema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
"tool_arguments": {"city": "London"},
|
|
||||||
}
|
|
||||||
missing = MCPToolBlock.Input.get_missing_input(data)
|
|
||||||
assert missing == set()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_with_mock(self):
|
|
||||||
"""Test the block using the built-in test infrastructure."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
await execute_block_test(block)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_missing_server_url(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="",
|
|
||||||
selected_tool="test",
|
|
||||||
)
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
assert outputs == [("error", "MCP server URL is required")]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_missing_tool(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="",
|
|
||||||
)
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
assert outputs == [
|
|
||||||
("error", "No tool selected. Please select a tool from the dropdown.")
|
|
||||||
]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_success(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="get_weather",
|
|
||||||
tool_input_schema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {"city": {"type": "string"}},
|
|
||||||
},
|
|
||||||
tool_arguments={"city": "London"},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_call(*args, **kwargs):
|
|
||||||
return {"temp": 20, "city": "London"}
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert len(outputs) == 1
|
|
||||||
assert outputs[0][0] == "result"
|
|
||||||
assert outputs[0][1] == {"temp": 20, "city": "London"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_mcp_error(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="bad_tool",
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_call(*args, **kwargs):
|
|
||||||
raise MCPClientError("Tool not found")
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert outputs[0][0] == "error"
|
|
||||||
assert "Tool not found" in outputs[0][1]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_mcp_tool_parses_json_text(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": '{"temp": 20}'},
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == {"temp": 20}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_mcp_tool_plain_text(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "Hello, world!"},
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == "Hello, world!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_mcp_tool_multiple_content(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{"type": "text", "text": "Part 1"},
|
|
||||||
{"type": "text", "text": '{"part": 2}'},
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == ["Part 1", {"part": 2}]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_mcp_tool_error_result(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[{"type": "text", "text": "Something went wrong"}],
|
|
||||||
is_error=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
with pytest.raises(MCPClientError, match="returned an error"):
|
|
||||||
await block._call_mcp_tool("https://mcp.example.com", "test_tool", {})
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_call_mcp_tool_image_content(self):
|
|
||||||
block = MCPToolBlock()
|
|
||||||
|
|
||||||
mock_result = MCPCallResult(
|
|
||||||
content=[
|
|
||||||
{
|
|
||||||
"type": "image",
|
|
||||||
"data": "base64data==",
|
|
||||||
"mimeType": "image/png",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
is_error=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_init(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def mock_call(self, name, args):
|
|
||||||
return mock_result
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(MCPClient, "initialize", mock_init),
|
|
||||||
patch.object(MCPClient, "call_tool", mock_call),
|
|
||||||
):
|
|
||||||
result = await block._call_mcp_tool(
|
|
||||||
"https://mcp.example.com", "test_tool", {}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == {
|
|
||||||
"type": "image",
|
|
||||||
"data": "base64data==",
|
|
||||||
"mimeType": "image/png",
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_with_credentials(self):
|
|
||||||
"""Verify the block uses OAuth2Credentials and passes auth token."""
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="test_tool",
|
|
||||||
)
|
|
||||||
|
|
||||||
captured_tokens: list[str | None] = []
|
|
||||||
|
|
||||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
|
||||||
captured_tokens.append(auth_token)
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
test_creds = OAuth2Credentials(
|
|
||||||
id="cred-123",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("resolved-token"),
|
|
||||||
refresh_token=SecretStr(""),
|
|
||||||
scopes=[],
|
|
||||||
title="Test MCP credential",
|
|
||||||
)
|
|
||||||
|
|
||||||
async for _ in block.run(
|
|
||||||
input_data, user_id=MOCK_USER_ID, credentials=test_creds
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
assert captured_tokens == ["resolved-token"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_run_without_credentials(self):
|
|
||||||
"""Verify the block works without credentials (public server)."""
|
|
||||||
block = MCPToolBlock()
|
|
||||||
input_data = MCPToolBlock.Input(
|
|
||||||
server_url="https://mcp.example.com/mcp",
|
|
||||||
selected_tool="test_tool",
|
|
||||||
)
|
|
||||||
|
|
||||||
captured_tokens: list[str | None] = []
|
|
||||||
|
|
||||||
async def mock_call(server_url, tool_name, arguments, auth_token=None):
|
|
||||||
captured_tokens.append(auth_token)
|
|
||||||
return "ok"
|
|
||||||
|
|
||||||
block._call_mcp_tool = mock_call # type: ignore
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
async for name, data in block.run(input_data, user_id=MOCK_USER_ID):
|
|
||||||
outputs.append((name, data))
|
|
||||||
|
|
||||||
assert captured_tokens == [None]
|
|
||||||
assert outputs == [("result", "ok")]
|
|
||||||
@@ -1,242 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for MCP OAuth handler.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import SecretStr
|
|
||||||
|
|
||||||
from backend.blocks.mcp.client import MCPClient
|
|
||||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
|
||||||
from backend.data.model import OAuth2Credentials
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_response(json_data: dict, status: int = 200) -> MagicMock:
|
|
||||||
"""Create a mock Response with synchronous json() (matching Requests.Response)."""
|
|
||||||
resp = MagicMock()
|
|
||||||
resp.status = status
|
|
||||||
resp.ok = 200 <= status < 300
|
|
||||||
resp.json.return_value = json_data
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPOAuthHandler:
|
|
||||||
"""Tests for the MCPOAuthHandler."""
|
|
||||||
|
|
||||||
def _make_handler(self, **overrides) -> MCPOAuthHandler:
|
|
||||||
defaults = {
|
|
||||||
"client_id": "test-client-id",
|
|
||||||
"client_secret": "test-client-secret",
|
|
||||||
"redirect_uri": "https://app.example.com/callback",
|
|
||||||
"authorize_url": "https://auth.example.com/authorize",
|
|
||||||
"token_url": "https://auth.example.com/token",
|
|
||||||
}
|
|
||||||
defaults.update(overrides)
|
|
||||||
return MCPOAuthHandler(**defaults)
|
|
||||||
|
|
||||||
def test_get_login_url_basic(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
url = handler.get_login_url(
|
|
||||||
scopes=["read", "write"],
|
|
||||||
state="random-state-token",
|
|
||||||
code_challenge="S256-challenge-value",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "https://auth.example.com/authorize?" in url
|
|
||||||
assert "response_type=code" in url
|
|
||||||
assert "client_id=test-client-id" in url
|
|
||||||
assert "state=random-state-token" in url
|
|
||||||
assert "code_challenge=S256-challenge-value" in url
|
|
||||||
assert "code_challenge_method=S256" in url
|
|
||||||
assert "scope=read+write" in url
|
|
||||||
|
|
||||||
def test_get_login_url_with_resource(self):
|
|
||||||
handler = self._make_handler(resource_url="https://mcp.example.com/mcp")
|
|
||||||
url = handler.get_login_url(
|
|
||||||
scopes=[], state="state", code_challenge="challenge"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "resource=https" in url
|
|
||||||
|
|
||||||
def test_get_login_url_without_pkce(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
url = handler.get_login_url(scopes=["read"], state="state", code_challenge=None)
|
|
||||||
|
|
||||||
assert "code_challenge" not in url
|
|
||||||
assert "code_challenge_method" not in url
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_exchange_code_for_tokens(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
|
|
||||||
resp = _mock_response(
|
|
||||||
{
|
|
||||||
"access_token": "new-access-token",
|
|
||||||
"refresh_token": "new-refresh-token",
|
|
||||||
"expires_in": 3600,
|
|
||||||
"token_type": "Bearer",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.post = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
creds = await handler.exchange_code_for_tokens(
|
|
||||||
code="auth-code",
|
|
||||||
scopes=["read"],
|
|
||||||
code_verifier="pkce-verifier",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(creds, OAuth2Credentials)
|
|
||||||
assert creds.access_token.get_secret_value() == "new-access-token"
|
|
||||||
assert creds.refresh_token is not None
|
|
||||||
assert creds.refresh_token.get_secret_value() == "new-refresh-token"
|
|
||||||
assert creds.scopes == ["read"]
|
|
||||||
assert creds.access_token_expires_at is not None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_refresh_tokens(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
|
|
||||||
existing_creds = OAuth2Credentials(
|
|
||||||
id="existing-id",
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("old-token"),
|
|
||||||
refresh_token=SecretStr("old-refresh"),
|
|
||||||
scopes=["read"],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = _mock_response(
|
|
||||||
{
|
|
||||||
"access_token": "refreshed-token",
|
|
||||||
"refresh_token": "new-refresh",
|
|
||||||
"expires_in": 3600,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.post = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
refreshed = await handler._refresh_tokens(existing_creds)
|
|
||||||
|
|
||||||
assert refreshed.id == "existing-id"
|
|
||||||
assert refreshed.access_token.get_secret_value() == "refreshed-token"
|
|
||||||
assert refreshed.refresh_token is not None
|
|
||||||
assert refreshed.refresh_token.get_secret_value() == "new-refresh"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_refresh_tokens_no_refresh_token(self):
|
|
||||||
handler = self._make_handler()
|
|
||||||
|
|
||||||
creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("token"),
|
|
||||||
scopes=["read"],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="No refresh token"):
|
|
||||||
await handler._refresh_tokens(creds)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_revoke_tokens_no_url(self):
|
|
||||||
handler = self._make_handler(revoke_url=None)
|
|
||||||
|
|
||||||
creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("token"),
|
|
||||||
scopes=[],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await handler.revoke_tokens(creds)
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_revoke_tokens_with_url(self):
|
|
||||||
handler = self._make_handler(revoke_url="https://auth.example.com/revoke")
|
|
||||||
|
|
||||||
creds = OAuth2Credentials(
|
|
||||||
provider="mcp",
|
|
||||||
access_token=SecretStr("token"),
|
|
||||||
scopes=[],
|
|
||||||
title="test",
|
|
||||||
)
|
|
||||||
|
|
||||||
resp = _mock_response({}, status=200)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.oauth.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.post = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await handler.revoke_tokens(creds)
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClientDiscovery:
|
|
||||||
"""Tests for MCPClient OAuth metadata discovery."""
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_auth_found(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp")
|
|
||||||
|
|
||||||
metadata = {
|
|
||||||
"authorization_servers": ["https://auth.example.com"],
|
|
||||||
"resource": "https://mcp.example.com/mcp",
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = _mock_response(metadata, status=200)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.get = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await client.discover_auth()
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result["authorization_servers"] == ["https://auth.example.com"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_auth_not_found(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp")
|
|
||||||
|
|
||||||
resp = _mock_response({}, status=404)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.get = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await client.discover_auth()
|
|
||||||
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio(loop_scope="session")
|
|
||||||
async def test_discover_auth_server_metadata(self):
|
|
||||||
client = MCPClient("https://mcp.example.com/mcp")
|
|
||||||
|
|
||||||
server_metadata = {
|
|
||||||
"issuer": "https://auth.example.com",
|
|
||||||
"authorization_endpoint": "https://auth.example.com/authorize",
|
|
||||||
"token_endpoint": "https://auth.example.com/token",
|
|
||||||
"registration_endpoint": "https://auth.example.com/register",
|
|
||||||
"code_challenge_methods_supported": ["S256"],
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = _mock_response(server_metadata, status=200)
|
|
||||||
|
|
||||||
with patch("backend.blocks.mcp.client.Requests") as MockRequests:
|
|
||||||
instance = MockRequests.return_value
|
|
||||||
instance.get = AsyncMock(return_value=resp)
|
|
||||||
|
|
||||||
result = await client.discover_auth_server_metadata(
|
|
||||||
"https://auth.example.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result is not None
|
|
||||||
assert result["authorization_endpoint"] == "https://auth.example.com/authorize"
|
|
||||||
assert result["token_endpoint"] == "https://auth.example.com/token"
|
|
||||||
@@ -1,162 +0,0 @@
|
|||||||
"""
|
|
||||||
Minimal MCP server for integration testing.
|
|
||||||
|
|
||||||
Implements the MCP Streamable HTTP transport (JSON-RPC 2.0 over HTTP POST)
|
|
||||||
with a few sample tools. Runs on localhost with a random available port.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from aiohttp import web
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Sample tools this test server exposes
|
|
||||||
TEST_TOOLS = [
|
|
||||||
{
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get current weather for a city",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "City name",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "add_numbers",
|
|
||||||
"description": "Add two numbers together",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"a": {"type": "number", "description": "First number"},
|
|
||||||
"b": {"type": "number", "description": "Second number"},
|
|
||||||
},
|
|
||||||
"required": ["a", "b"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "echo",
|
|
||||||
"description": "Echo back the input message",
|
|
||||||
"inputSchema": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"message": {"type": "string", "description": "Message to echo"},
|
|
||||||
},
|
|
||||||
"required": ["message"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_initialize(params: dict) -> dict:
|
|
||||||
return {
|
|
||||||
"protocolVersion": "2025-03-26",
|
|
||||||
"capabilities": {"tools": {"listChanged": False}},
|
|
||||||
"serverInfo": {"name": "test-mcp-server", "version": "1.0.0"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_tools_list(params: dict) -> dict:
|
|
||||||
return {"tools": TEST_TOOLS}
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_tools_call(params: dict) -> dict:
|
|
||||||
tool_name = params.get("name", "")
|
|
||||||
arguments = params.get("arguments", {})
|
|
||||||
|
|
||||||
if tool_name == "get_weather":
|
|
||||||
city = arguments.get("city", "Unknown")
|
|
||||||
return {
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": json.dumps(
|
|
||||||
{"city": city, "temperature": 22, "condition": "sunny"}
|
|
||||||
),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
elif tool_name == "add_numbers":
|
|
||||||
a = arguments.get("a", 0)
|
|
||||||
b = arguments.get("b", 0)
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": json.dumps({"result": a + b})}],
|
|
||||||
}
|
|
||||||
|
|
||||||
elif tool_name == "echo":
|
|
||||||
message = arguments.get("message", "")
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": message}],
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"content": [{"type": "text", "text": f"Unknown tool: {tool_name}"}],
|
|
||||||
"isError": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
HANDLERS = {
|
|
||||||
"initialize": _handle_initialize,
|
|
||||||
"tools/list": _handle_tools_list,
|
|
||||||
"tools/call": _handle_tools_call,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_mcp_request(request: web.Request) -> web.Response:
|
|
||||||
"""Handle incoming MCP JSON-RPC 2.0 requests."""
|
|
||||||
# Check auth if configured
|
|
||||||
expected_token = request.app.get("auth_token")
|
|
||||||
if expected_token:
|
|
||||||
auth_header = request.headers.get("Authorization", "")
|
|
||||||
if auth_header != f"Bearer {expected_token}":
|
|
||||||
return web.json_response(
|
|
||||||
{
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"error": {"code": -32001, "message": "Unauthorized"},
|
|
||||||
"id": None,
|
|
||||||
},
|
|
||||||
status=401,
|
|
||||||
)
|
|
||||||
|
|
||||||
body = await request.json()
|
|
||||||
|
|
||||||
# Handle notifications (no id field) — just acknowledge
|
|
||||||
if "id" not in body:
|
|
||||||
return web.Response(status=202)
|
|
||||||
|
|
||||||
method = body.get("method", "")
|
|
||||||
params = body.get("params", {})
|
|
||||||
request_id = body.get("id")
|
|
||||||
|
|
||||||
handler = HANDLERS.get(method)
|
|
||||||
if not handler:
|
|
||||||
return web.json_response(
|
|
||||||
{
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"error": {
|
|
||||||
"code": -32601,
|
|
||||||
"message": f"Method not found: {method}",
|
|
||||||
},
|
|
||||||
"id": request_id,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
result = handler(params)
|
|
||||||
return web.json_response({"jsonrpc": "2.0", "result": result, "id": request_id})
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_mcp_app(auth_token: str | None = None) -> web.Application:
|
|
||||||
"""Create an aiohttp app that acts as an MCP server."""
|
|
||||||
app = web.Application()
|
|
||||||
app.router.add_post("/mcp", handle_mcp_request)
|
|
||||||
if auth_token:
|
|
||||||
app["auth_token"] = auth_token
|
|
||||||
return app
|
|
||||||
@@ -33,7 +33,6 @@ from backend.util import type as type_utils
|
|||||||
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
|
from backend.util.exceptions import GraphNotAccessibleError, GraphNotInLibraryError
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
from backend.util.models import Pagination
|
from backend.util.models import Pagination
|
||||||
from backend.util.request import parse_url
|
|
||||||
|
|
||||||
from .block import BlockInput
|
from .block import BlockInput
|
||||||
from .db import BaseDbModel
|
from .db import BaseDbModel
|
||||||
@@ -450,9 +449,6 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
continue
|
continue
|
||||||
if ProviderName.HTTP in field.provider:
|
if ProviderName.HTTP in field.provider:
|
||||||
continue
|
continue
|
||||||
# MCP credentials are intentionally split by server URL
|
|
||||||
if ProviderName.MCP in field.provider:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# If this happens, that means a block implementation probably needs
|
# If this happens, that means a block implementation probably needs
|
||||||
# to be updated.
|
# to be updated.
|
||||||
@@ -509,18 +505,6 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
"required": ["id", "provider", "type"],
|
"required": ["id", "provider", "type"],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add a descriptive display title when URL-based discriminator values
|
|
||||||
# are present (e.g. "mcp.sentry.dev" instead of just "Mcp")
|
|
||||||
if (
|
|
||||||
field_info.discriminator
|
|
||||||
and not field_info.discriminator_mapping
|
|
||||||
and field_info.discriminator_values
|
|
||||||
):
|
|
||||||
hostnames = sorted(
|
|
||||||
parse_url(str(v)).netloc for v in field_info.discriminator_values
|
|
||||||
)
|
|
||||||
field_schema["display_name"] = ", ".join(hostnames)
|
|
||||||
|
|
||||||
# Add other (optional) field info items
|
# Add other (optional) field info items
|
||||||
field_schema.update(
|
field_schema.update(
|
||||||
field_info.model_dump(
|
field_info.model_dump(
|
||||||
@@ -565,17 +549,8 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
|
|
||||||
for graph in [self] + self.sub_graphs:
|
for graph in [self] + self.sub_graphs:
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
# A node's credentials are optional if either:
|
# Track if this node requires credentials (credentials_optional=False means required)
|
||||||
# 1. The node metadata says so (credentials_optional=True), or
|
node_required_map[node.id] = not node.credentials_optional
|
||||||
# 2. All credential fields on the block have defaults (not required by schema)
|
|
||||||
block_required = node.block.input_schema.get_required_fields()
|
|
||||||
creds_required_by_schema = any(
|
|
||||||
fname in block_required
|
|
||||||
for fname in node.block.input_schema.get_credentials_fields()
|
|
||||||
)
|
|
||||||
node_required_map[node.id] = (
|
|
||||||
not node.credentials_optional and creds_required_by_schema
|
|
||||||
)
|
|
||||||
|
|
||||||
for (
|
for (
|
||||||
field_name,
|
field_name,
|
||||||
@@ -801,19 +776,6 @@ class GraphModel(Graph, GraphMeta):
|
|||||||
"'credentials' and `*_credentials` are reserved"
|
"'credentials' and `*_credentials` are reserved"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check custom block-level validation (e.g., MCP dynamic tool arguments).
|
|
||||||
# Blocks can override get_missing_input to report additional missing fields
|
|
||||||
# beyond the standard top-level required fields.
|
|
||||||
if for_run:
|
|
||||||
credential_fields = InputSchema.get_credentials_fields()
|
|
||||||
custom_missing = InputSchema.get_missing_input(node.input_default)
|
|
||||||
for field_name in custom_missing:
|
|
||||||
if (
|
|
||||||
field_name not in provided_inputs
|
|
||||||
and field_name not in credential_fields
|
|
||||||
):
|
|
||||||
node_errors[node.id][field_name] = "This field is required"
|
|
||||||
|
|
||||||
# Get input schema properties and check dependencies
|
# Get input schema properties and check dependencies
|
||||||
input_fields = InputSchema.model_fields
|
input_fields = InputSchema.model_fields
|
||||||
|
|
||||||
|
|||||||
@@ -462,120 +462,3 @@ def test_node_credentials_optional_with_other_metadata():
|
|||||||
assert node.credentials_optional is True
|
assert node.credentials_optional is True
|
||||||
assert node.metadata["position"] == {"x": 100, "y": 200}
|
assert node.metadata["position"] == {"x": 100, "y": 200}
|
||||||
assert node.metadata["customized_name"] == "My Custom Node"
|
assert node.metadata["customized_name"] == "My Custom Node"
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Tests for MCP Credential Deduplication
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
|
|
||||||
def test_mcp_credential_combine_different_servers():
|
|
||||||
"""Two MCP credential fields with different server URLs should produce
|
|
||||||
separate entries when combined (not merged into one)."""
|
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsType
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
|
|
||||||
|
|
||||||
field_sentry = CredentialsFieldInfo(
|
|
||||||
credentials_provider=frozenset([ProviderName.MCP]),
|
|
||||||
credentials_types=oauth2_types,
|
|
||||||
credentials_scopes=None,
|
|
||||||
discriminator="server_url",
|
|
||||||
discriminator_values={"https://mcp.sentry.dev/mcp"},
|
|
||||||
)
|
|
||||||
field_linear = CredentialsFieldInfo(
|
|
||||||
credentials_provider=frozenset([ProviderName.MCP]),
|
|
||||||
credentials_types=oauth2_types,
|
|
||||||
credentials_scopes=None,
|
|
||||||
discriminator="server_url",
|
|
||||||
discriminator_values={"https://mcp.linear.app/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
combined = CredentialsFieldInfo.combine(
|
|
||||||
(field_sentry, ("node-sentry", "credentials")),
|
|
||||||
(field_linear, ("node-linear", "credentials")),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should produce 2 separate credential entries
|
|
||||||
assert len(combined) == 2, (
|
|
||||||
f"Expected 2 credential entries for 2 MCP blocks with different servers, "
|
|
||||||
f"got {len(combined)}: {list(combined.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Each entry should contain the server hostname in its key
|
|
||||||
keys = list(combined.keys())
|
|
||||||
assert any(
|
|
||||||
"mcp.sentry.dev" in k for k in keys
|
|
||||||
), f"Expected 'mcp.sentry.dev' in one key, got {keys}"
|
|
||||||
assert any(
|
|
||||||
"mcp.linear.app" in k for k in keys
|
|
||||||
), f"Expected 'mcp.linear.app' in one key, got {keys}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_mcp_credential_combine_same_server():
|
|
||||||
"""Two MCP credential fields with the same server URL should be combined
|
|
||||||
into one credential entry."""
|
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsType
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
|
|
||||||
|
|
||||||
field_a = CredentialsFieldInfo(
|
|
||||||
credentials_provider=frozenset([ProviderName.MCP]),
|
|
||||||
credentials_types=oauth2_types,
|
|
||||||
credentials_scopes=None,
|
|
||||||
discriminator="server_url",
|
|
||||||
discriminator_values={"https://mcp.sentry.dev/mcp"},
|
|
||||||
)
|
|
||||||
field_b = CredentialsFieldInfo(
|
|
||||||
credentials_provider=frozenset([ProviderName.MCP]),
|
|
||||||
credentials_types=oauth2_types,
|
|
||||||
credentials_scopes=None,
|
|
||||||
discriminator="server_url",
|
|
||||||
discriminator_values={"https://mcp.sentry.dev/mcp"},
|
|
||||||
)
|
|
||||||
|
|
||||||
combined = CredentialsFieldInfo.combine(
|
|
||||||
(field_a, ("node-a", "credentials")),
|
|
||||||
(field_b, ("node-b", "credentials")),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should produce 1 credential entry (same server URL)
|
|
||||||
assert len(combined) == 1, (
|
|
||||||
f"Expected 1 credential entry for 2 MCP blocks with same server, "
|
|
||||||
f"got {len(combined)}: {list(combined.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_mcp_credential_combine_no_discriminator_values():
|
|
||||||
"""MCP credential fields without discriminator_values should be merged
|
|
||||||
into a single entry (backwards compat for blocks without server_url set)."""
|
|
||||||
from backend.data.model import CredentialsFieldInfo, CredentialsType
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
oauth2_types: frozenset[CredentialsType] = frozenset(["oauth2"])
|
|
||||||
|
|
||||||
field_a = CredentialsFieldInfo(
|
|
||||||
credentials_provider=frozenset([ProviderName.MCP]),
|
|
||||||
credentials_types=oauth2_types,
|
|
||||||
credentials_scopes=None,
|
|
||||||
discriminator="server_url",
|
|
||||||
)
|
|
||||||
field_b = CredentialsFieldInfo(
|
|
||||||
credentials_provider=frozenset([ProviderName.MCP]),
|
|
||||||
credentials_types=oauth2_types,
|
|
||||||
credentials_scopes=None,
|
|
||||||
discriminator="server_url",
|
|
||||||
)
|
|
||||||
|
|
||||||
combined = CredentialsFieldInfo.combine(
|
|
||||||
(field_a, ("node-a", "credentials")),
|
|
||||||
(field_b, ("node-b", "credentials")),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should produce 1 entry (no URL differentiation)
|
|
||||||
assert len(combined) == 1, (
|
|
||||||
f"Expected 1 credential entry for MCP blocks without discriminator_values, "
|
|
||||||
f"got {len(combined)}: {list(combined.keys())}"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from pydantic import (
|
|||||||
GetCoreSchemaHandler,
|
GetCoreSchemaHandler,
|
||||||
SecretStr,
|
SecretStr,
|
||||||
field_serializer,
|
field_serializer,
|
||||||
model_validator,
|
|
||||||
)
|
)
|
||||||
from pydantic_core import (
|
from pydantic_core import (
|
||||||
CoreSchema,
|
CoreSchema,
|
||||||
@@ -503,25 +502,6 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
provider: CP
|
provider: CP
|
||||||
type: CT
|
type: CT
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def _normalize_legacy_provider(cls, data: Any) -> Any:
|
|
||||||
"""Fix ``ProviderName.X`` format from Python 3.13 ``str(Enum)`` bug.
|
|
||||||
|
|
||||||
Python 3.13 changed ``str(StrEnum)`` to return ``"ClassName.MEMBER"``
|
|
||||||
instead of the plain value. Old stored credential references may have
|
|
||||||
``provider: "ProviderName.MCP"`` instead of ``"mcp"``.
|
|
||||||
"""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
prov = data.get("provider", "")
|
|
||||||
if isinstance(prov, str) and prov.startswith("ProviderName."):
|
|
||||||
member = prov.removeprefix("ProviderName.")
|
|
||||||
try:
|
|
||||||
data = {**data, "provider": ProviderName[member].value}
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return data
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
|
def allowed_providers(cls) -> tuple[ProviderName, ...] | None:
|
||||||
return get_args(cls.model_fields["provider"].annotation)
|
return get_args(cls.model_fields["provider"].annotation)
|
||||||
@@ -626,18 +606,11 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
] = defaultdict(list)
|
] = defaultdict(list)
|
||||||
|
|
||||||
for field, key in fields:
|
for field, key in fields:
|
||||||
if (
|
if field.provider == frozenset([ProviderName.HTTP]):
|
||||||
field.discriminator
|
# HTTP host-scoped credentials can have different hosts that reqires different credential sets.
|
||||||
and not field.discriminator_mapping
|
# Group by host extracted from the URL
|
||||||
and field.discriminator_values
|
|
||||||
):
|
|
||||||
# URL-based discrimination (e.g. HTTP host-scoped, MCP server URL):
|
|
||||||
# Each unique host gets its own credential entry.
|
|
||||||
provider_prefix = next(iter(field.provider))
|
|
||||||
# Use .value for enum types to get the plain string (e.g. "mcp" not "ProviderName.MCP")
|
|
||||||
prefix_str = getattr(provider_prefix, "value", str(provider_prefix))
|
|
||||||
providers = frozenset(
|
providers = frozenset(
|
||||||
[cast(CP, prefix_str)]
|
[cast(CP, "http")]
|
||||||
+ [
|
+ [
|
||||||
cast(CP, parse_url(str(value)).netloc)
|
cast(CP, parse_url(str(value)).netloc)
|
||||||
for value in field.discriminator_values
|
for value in field.discriminator_values
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from backend.blocks import get_block
|
|||||||
from backend.blocks._base import BlockSchema
|
from backend.blocks._base import BlockSchema
|
||||||
from backend.blocks.agent import AgentExecutorBlock
|
from backend.blocks.agent import AgentExecutorBlock
|
||||||
from backend.blocks.io import AgentOutputBlock
|
from backend.blocks.io import AgentOutputBlock
|
||||||
from backend.blocks.mcp.block import MCPToolBlock
|
|
||||||
from backend.data import redis_client as redis
|
from backend.data import redis_client as redis
|
||||||
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
|
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
|
||||||
from backend.data.credit import UsageTransactionMetadata
|
from backend.data.credit import UsageTransactionMetadata
|
||||||
@@ -229,18 +228,6 @@ async def execute_node(
|
|||||||
_input_data.nodes_input_masks = nodes_input_masks
|
_input_data.nodes_input_masks = nodes_input_masks
|
||||||
_input_data.user_id = user_id
|
_input_data.user_id = user_id
|
||||||
input_data = _input_data.model_dump()
|
input_data = _input_data.model_dump()
|
||||||
elif isinstance(node_block, MCPToolBlock):
|
|
||||||
_mcp_data = MCPToolBlock.Input(**node.input_default)
|
|
||||||
# Dynamic tool fields are flattened to top-level by validate_exec
|
|
||||||
# (via get_input_defaults). Collect them back into tool_arguments.
|
|
||||||
tool_schema = _mcp_data.tool_input_schema
|
|
||||||
tool_props = set(tool_schema.get("properties", {}).keys())
|
|
||||||
merged_args = {**_mcp_data.tool_arguments}
|
|
||||||
for key in tool_props:
|
|
||||||
if key in input_data:
|
|
||||||
merged_args[key] = input_data[key]
|
|
||||||
_mcp_data.tool_arguments = merged_args
|
|
||||||
input_data = _mcp_data.model_dump()
|
|
||||||
data.inputs = input_data
|
data.inputs = input_data
|
||||||
|
|
||||||
# Execute the node
|
# Execute the node
|
||||||
@@ -277,34 +264,8 @@ async def execute_node(
|
|||||||
|
|
||||||
# Handle regular credentials fields
|
# Handle regular credentials fields
|
||||||
for field_name, input_type in input_model.get_credentials_fields().items():
|
for field_name, input_type in input_model.get_credentials_fields().items():
|
||||||
field_value = input_data.get(field_name)
|
credentials_meta = input_type(**input_data[field_name])
|
||||||
if not field_value or (
|
credentials, lock = await creds_manager.acquire(user_id, credentials_meta.id)
|
||||||
isinstance(field_value, dict) and not field_value.get("id")
|
|
||||||
):
|
|
||||||
# No credentials configured — nullify so JSON schema validation
|
|
||||||
# doesn't choke on the empty default `{}`.
|
|
||||||
input_data[field_name] = None
|
|
||||||
continue # Block runs without credentials
|
|
||||||
|
|
||||||
credentials_meta = input_type(**field_value)
|
|
||||||
# Write normalized values back so JSON schema validation also passes
|
|
||||||
# (model_validator may have fixed legacy formats like "ProviderName.MCP")
|
|
||||||
input_data[field_name] = credentials_meta.model_dump(mode="json")
|
|
||||||
try:
|
|
||||||
credentials, lock = await creds_manager.acquire(
|
|
||||||
user_id, credentials_meta.id
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
# Credential was deleted or doesn't exist.
|
|
||||||
# If the field has a default, run without credentials.
|
|
||||||
if input_model.model_fields[field_name].default is not None:
|
|
||||||
log_metadata.warning(
|
|
||||||
f"Credentials #{credentials_meta.id} not found, "
|
|
||||||
"running without (field has default)"
|
|
||||||
)
|
|
||||||
input_data[field_name] = None
|
|
||||||
continue
|
|
||||||
raise
|
|
||||||
creds_locks.append(lock)
|
creds_locks.append(lock)
|
||||||
extra_exec_kwargs[field_name] = credentials
|
extra_exec_kwargs[field_name] = credentials
|
||||||
|
|
||||||
|
|||||||
@@ -260,13 +260,7 @@ async def _validate_node_input_credentials(
|
|||||||
# Track if any credential field is missing for this node
|
# Track if any credential field is missing for this node
|
||||||
has_missing_credentials = False
|
has_missing_credentials = False
|
||||||
|
|
||||||
# A credential field is optional if the node metadata says so, or if
|
|
||||||
# the block schema declares a default for the field.
|
|
||||||
required_fields = block.input_schema.get_required_fields()
|
|
||||||
is_creds_optional = node.credentials_optional
|
|
||||||
|
|
||||||
for field_name, credentials_meta_type in credentials_fields.items():
|
for field_name, credentials_meta_type in credentials_fields.items():
|
||||||
field_is_optional = is_creds_optional or field_name not in required_fields
|
|
||||||
try:
|
try:
|
||||||
# Check nodes_input_masks first, then input_default
|
# Check nodes_input_masks first, then input_default
|
||||||
field_value = None
|
field_value = None
|
||||||
@@ -279,7 +273,7 @@ async def _validate_node_input_credentials(
|
|||||||
elif field_name in node.input_default:
|
elif field_name in node.input_default:
|
||||||
# For optional credentials, don't use input_default - treat as missing
|
# For optional credentials, don't use input_default - treat as missing
|
||||||
# This prevents stale credential IDs from failing validation
|
# This prevents stale credential IDs from failing validation
|
||||||
if field_is_optional:
|
if node.credentials_optional:
|
||||||
field_value = None
|
field_value = None
|
||||||
else:
|
else:
|
||||||
field_value = node.input_default[field_name]
|
field_value = node.input_default[field_name]
|
||||||
@@ -289,8 +283,8 @@ async def _validate_node_input_credentials(
|
|||||||
isinstance(field_value, dict) and not field_value.get("id")
|
isinstance(field_value, dict) and not field_value.get("id")
|
||||||
):
|
):
|
||||||
has_missing_credentials = True
|
has_missing_credentials = True
|
||||||
# If credential field is optional, skip instead of error
|
# If node has credentials_optional flag, mark for skipping instead of error
|
||||||
if field_is_optional:
|
if node.credentials_optional:
|
||||||
continue # Don't add error, will be marked for skip after loop
|
continue # Don't add error, will be marked for skip after loop
|
||||||
else:
|
else:
|
||||||
credential_errors[node.id][
|
credential_errors[node.id][
|
||||||
@@ -340,16 +334,16 @@ async def _validate_node_input_credentials(
|
|||||||
] = "Invalid credentials: type/provider mismatch"
|
] = "Invalid credentials: type/provider mismatch"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# If node has optional credentials and any are missing, allow running without.
|
# If node has optional credentials and any are missing, mark for skipping
|
||||||
# The executor will pass credentials=None to the block's run().
|
# But only if there are no other errors for this node
|
||||||
if (
|
if (
|
||||||
has_missing_credentials
|
has_missing_credentials
|
||||||
and is_creds_optional
|
and node.credentials_optional
|
||||||
and node.id not in credential_errors
|
and node.id not in credential_errors
|
||||||
):
|
):
|
||||||
|
nodes_to_skip.add(node.id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Node #{node.id}: optional credentials not configured, "
|
f"Node #{node.id} will be skipped: optional credentials not configured"
|
||||||
"running without"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return credential_errors, nodes_to_skip
|
return credential_errors, nodes_to_skip
|
||||||
|
|||||||
@@ -495,7 +495,6 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
|||||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||||
"credentials": mock_credentials_field_type
|
"credentials": mock_credentials_field_type
|
||||||
}
|
}
|
||||||
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
|
|
||||||
mock_node.block = mock_block
|
mock_node.block = mock_block
|
||||||
|
|
||||||
# Create mock graph
|
# Create mock graph
|
||||||
@@ -509,8 +508,8 @@ async def test_validate_node_input_credentials_returns_nodes_to_skip(
|
|||||||
nodes_input_masks=None,
|
nodes_input_masks=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Node should NOT be in nodes_to_skip (runs without credentials) and not in errors
|
# Node should be in nodes_to_skip, not in errors
|
||||||
assert mock_node.id not in nodes_to_skip
|
assert mock_node.id in nodes_to_skip
|
||||||
assert mock_node.id not in errors
|
assert mock_node.id not in errors
|
||||||
|
|
||||||
|
|
||||||
@@ -536,7 +535,6 @@ async def test_validate_node_input_credentials_required_missing_creds_error(
|
|||||||
mock_block.input_schema.get_credentials_fields.return_value = {
|
mock_block.input_schema.get_credentials_fields.return_value = {
|
||||||
"credentials": mock_credentials_field_type
|
"credentials": mock_credentials_field_type
|
||||||
}
|
}
|
||||||
mock_block.input_schema.get_required_fields.return_value = {"credentials"}
|
|
||||||
mock_node.block = mock_block
|
mock_node.block = mock_block
|
||||||
|
|
||||||
# Create mock graph
|
# Create mock graph
|
||||||
|
|||||||
@@ -22,27 +22,6 @@ from backend.util.settings import Settings
|
|||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
|
|
||||||
def provider_matches(stored: str, expected: str) -> bool:
|
|
||||||
"""Compare provider strings, handling Python 3.13 ``str(StrEnum)`` bug.
|
|
||||||
|
|
||||||
On Python 3.13, ``str(ProviderName.MCP)`` returns ``"ProviderName.MCP"``
|
|
||||||
instead of ``"mcp"``. OAuth states persisted with the buggy format need
|
|
||||||
to match when ``expected`` is the canonical value (e.g. ``"mcp"``).
|
|
||||||
"""
|
|
||||||
if stored == expected:
|
|
||||||
return True
|
|
||||||
if stored.startswith("ProviderName."):
|
|
||||||
member = stored.removeprefix("ProviderName.")
|
|
||||||
from backend.integrations.providers import ProviderName
|
|
||||||
|
|
||||||
try:
|
|
||||||
return ProviderName[member].value == expected
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
|
# This is an overrride since ollama doesn't actually require an API key, but the creddential system enforces one be attached
|
||||||
ollama_credentials = APIKeyCredentials(
|
ollama_credentials = APIKeyCredentials(
|
||||||
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
|
id="744fdc56-071a-4761-b5a5-0af0ce10a2b5",
|
||||||
@@ -410,7 +389,7 @@ class IntegrationCredentialsStore:
|
|||||||
self, user_id: str, provider: str
|
self, user_id: str, provider: str
|
||||||
) -> list[Credentials]:
|
) -> list[Credentials]:
|
||||||
credentials = await self.get_all_creds(user_id)
|
credentials = await self.get_all_creds(user_id)
|
||||||
return [c for c in credentials if provider_matches(c.provider, provider)]
|
return [c for c in credentials if c.provider == provider]
|
||||||
|
|
||||||
async def get_authorized_providers(self, user_id: str) -> list[str]:
|
async def get_authorized_providers(self, user_id: str) -> list[str]:
|
||||||
credentials = await self.get_all_creds(user_id)
|
credentials = await self.get_all_creds(user_id)
|
||||||
@@ -506,6 +485,17 @@ class IntegrationCredentialsStore:
|
|||||||
async with self.edit_user_integrations(user_id) as user_integrations:
|
async with self.edit_user_integrations(user_id) as user_integrations:
|
||||||
user_integrations.oauth_states.append(state)
|
user_integrations.oauth_states.append(state)
|
||||||
|
|
||||||
|
async with await self.locked_user_integrations(user_id):
|
||||||
|
|
||||||
|
user_integrations = await self._get_user_integrations(user_id)
|
||||||
|
oauth_states = user_integrations.oauth_states
|
||||||
|
oauth_states.append(state)
|
||||||
|
user_integrations.oauth_states = oauth_states
|
||||||
|
|
||||||
|
await self.db_manager.update_user_integrations(
|
||||||
|
user_id=user_id, data=user_integrations
|
||||||
|
)
|
||||||
|
|
||||||
return token, code_challenge
|
return token, code_challenge
|
||||||
|
|
||||||
def _generate_code_challenge(self) -> tuple[str, str]:
|
def _generate_code_challenge(self) -> tuple[str, str]:
|
||||||
@@ -531,7 +521,7 @@ class IntegrationCredentialsStore:
|
|||||||
state
|
state
|
||||||
for state in oauth_states
|
for state in oauth_states
|
||||||
if secrets.compare_digest(state.token, token)
|
if secrets.compare_digest(state.token, token)
|
||||||
and provider_matches(state.provider, provider)
|
and state.provider == provider
|
||||||
and state.expires_at > now.timestamp()
|
and state.expires_at > now.timestamp()
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
|
|||||||
@@ -9,10 +9,7 @@ from redis.asyncio.lock import Lock as AsyncRedisLock
|
|||||||
|
|
||||||
from backend.data.model import Credentials, OAuth2Credentials
|
from backend.data.model import Credentials, OAuth2Credentials
|
||||||
from backend.data.redis_client import get_redis_async
|
from backend.data.redis_client import get_redis_async
|
||||||
from backend.integrations.credentials_store import (
|
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||||
IntegrationCredentialsStore,
|
|
||||||
provider_matches,
|
|
||||||
)
|
|
||||||
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
from backend.integrations.oauth import CREDENTIALS_BY_PROVIDER, HANDLERS_BY_NAME
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.exceptions import MissingConfigError
|
from backend.util.exceptions import MissingConfigError
|
||||||
@@ -140,9 +137,6 @@ class IntegrationCredentialsManager:
|
|||||||
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
|
self, user_id: str, credentials: OAuth2Credentials, lock: bool = True
|
||||||
) -> OAuth2Credentials:
|
) -> OAuth2Credentials:
|
||||||
async with self._locked(user_id, credentials.id, "refresh"):
|
async with self._locked(user_id, credentials.id, "refresh"):
|
||||||
if provider_matches(credentials.provider, ProviderName.MCP.value):
|
|
||||||
oauth_handler = create_mcp_oauth_handler(credentials)
|
|
||||||
else:
|
|
||||||
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
|
oauth_handler = await _get_provider_oauth_handler(credentials.provider)
|
||||||
if oauth_handler.needs_refresh(credentials):
|
if oauth_handler.needs_refresh(credentials):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -242,31 +236,3 @@ async def _get_provider_oauth_handler(provider_name_str: str) -> "BaseOAuthHandl
|
|||||||
client_secret=client_secret,
|
client_secret=client_secret,
|
||||||
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
redirect_uri=f"{frontend_base_url}/auth/integrations/oauth_callback",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_mcp_oauth_handler(
|
|
||||||
credentials: OAuth2Credentials,
|
|
||||||
) -> "BaseOAuthHandler":
|
|
||||||
"""Create an MCPOAuthHandler from credential metadata for token refresh.
|
|
||||||
|
|
||||||
MCP OAuth handlers have dynamic endpoints discovered per-server, so they
|
|
||||||
can't be registered as singletons in HANDLERS_BY_NAME. Instead, the handler
|
|
||||||
is reconstructed from metadata stored on the credential during initial auth.
|
|
||||||
"""
|
|
||||||
from backend.blocks.mcp.oauth import MCPOAuthHandler
|
|
||||||
|
|
||||||
meta = credentials.metadata or {}
|
|
||||||
token_url = meta.get("mcp_token_url", "")
|
|
||||||
if not token_url:
|
|
||||||
raise ValueError(
|
|
||||||
f"MCP credential {credentials.id} is missing 'mcp_token_url' metadata; "
|
|
||||||
"cannot refresh tokens"
|
|
||||||
)
|
|
||||||
return MCPOAuthHandler(
|
|
||||||
client_id=meta.get("mcp_client_id", ""),
|
|
||||||
client_secret=meta.get("mcp_client_secret", ""),
|
|
||||||
redirect_uri="", # Not needed for token refresh
|
|
||||||
authorize_url="", # Not needed for token refresh
|
|
||||||
token_url=token_url,
|
|
||||||
resource_url=meta.get("mcp_resource_url"),
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ class ProviderName(str, Enum):
|
|||||||
IDEOGRAM = "ideogram"
|
IDEOGRAM = "ideogram"
|
||||||
JINA = "jina"
|
JINA = "jina"
|
||||||
LLAMA_API = "llama_api"
|
LLAMA_API = "llama_api"
|
||||||
MCP = "mcp"
|
|
||||||
MEDIUM = "medium"
|
MEDIUM = "medium"
|
||||||
MEM0 = "mem0"
|
MEM0 = "mem0"
|
||||||
NOTION = "notion"
|
NOTION = "notion"
|
||||||
|
|||||||
@@ -51,21 +51,6 @@ async def _on_graph_activate(graph: "BaseGraph | GraphModel", user_id: str):
|
|||||||
if (
|
if (
|
||||||
creds_meta := new_node.input_default.get(creds_field_name)
|
creds_meta := new_node.input_default.get(creds_field_name)
|
||||||
) and not await get_credentials(creds_meta["id"]):
|
) and not await get_credentials(creds_meta["id"]):
|
||||||
# If the credential field is optional (has a default in the
|
|
||||||
# schema, or node metadata marks it optional), clear the stale
|
|
||||||
# reference instead of blocking the save.
|
|
||||||
creds_field_optional = (
|
|
||||||
new_node.credentials_optional
|
|
||||||
or creds_field_name not in block_input_schema.get_required_fields()
|
|
||||||
)
|
|
||||||
if creds_field_optional:
|
|
||||||
new_node.input_default[creds_field_name] = {}
|
|
||||||
logger.warning(
|
|
||||||
f"Node #{new_node.id}: cleared stale optional "
|
|
||||||
f"credentials #{creds_meta['id']} for "
|
|
||||||
f"'{creds_field_name}'"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
f"Node #{new_node.id} input '{creds_field_name}' updated with "
|
||||||
f"non-existent credentials #{creds_meta['id']}"
|
f"non-existent credentials #{creds_meta['id']}"
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ class Flag(str, Enum):
|
|||||||
AGENT_ACTIVITY = "agent-activity"
|
AGENT_ACTIVITY = "agent-activity"
|
||||||
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
|
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
|
||||||
CHAT = "chat"
|
CHAT = "chat"
|
||||||
COPILOT_SDK = "copilot-sdk"
|
|
||||||
|
|
||||||
|
|
||||||
def is_configured() -> bool:
|
def is_configured() -> bool:
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ class HostResolver(abc.AbstractResolver):
|
|||||||
def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
|
def __init__(self, ssl_hostname: str, ip_addresses: list[str]):
|
||||||
self.ssl_hostname = ssl_hostname
|
self.ssl_hostname = ssl_hostname
|
||||||
self.ip_addresses = ip_addresses
|
self.ip_addresses = ip_addresses
|
||||||
self._default = aiohttp.ThreadedResolver()
|
self._default = aiohttp.AsyncResolver()
|
||||||
|
|
||||||
async def resolve(self, host, port=0, family=socket.AF_INET):
|
async def resolve(self, host, port=0, family=socket.AF_INET):
|
||||||
if host == self.ssl_hostname:
|
if host == self.ssl_hostname:
|
||||||
@@ -467,7 +467,7 @@ class Requests:
|
|||||||
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
|
resolver = HostResolver(ssl_hostname=hostname, ip_addresses=ip_addresses)
|
||||||
ssl_context = ssl.create_default_context()
|
ssl_context = ssl.create_default_context()
|
||||||
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
|
connector = aiohttp.TCPConnector(resolver=resolver, ssl=ssl_context)
|
||||||
session_kwargs: dict = {}
|
session_kwargs = {}
|
||||||
if connector:
|
if connector:
|
||||||
session_kwargs["connector"] = connector
|
session_kwargs["connector"] = connector
|
||||||
|
|
||||||
|
|||||||
94
autogpt_platform/backend/poetry.lock
generated
94
autogpt_platform/backend/poetry.lock
generated
@@ -897,29 +897,6 @@ files = [
|
|||||||
{file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
|
{file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "claude-agent-sdk"
|
|
||||||
version = "0.1.35"
|
|
||||||
description = "Python SDK for Claude Code"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.10"
|
|
||||||
groups = ["main"]
|
|
||||||
files = [
|
|
||||||
{file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"},
|
|
||||||
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"},
|
|
||||||
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"},
|
|
||||||
{file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"},
|
|
||||||
{file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
anyio = ">=4.0.0"
|
|
||||||
mcp = ">=0.1.0"
|
|
||||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cleo"
|
name = "cleo"
|
||||||
version = "2.1.0"
|
version = "2.1.0"
|
||||||
@@ -2616,18 +2593,6 @@ http2 = ["h2 (>=3,<5)"]
|
|||||||
socks = ["socksio (==1.*)"]
|
socks = ["socksio (==1.*)"]
|
||||||
zstd = ["zstandard (>=0.18.0)"]
|
zstd = ["zstandard (>=0.18.0)"]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "httpx-sse"
|
|
||||||
version = "0.4.3"
|
|
||||||
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.9"
|
|
||||||
groups = ["main"]
|
|
||||||
files = [
|
|
||||||
{file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"},
|
|
||||||
{file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "huggingface-hub"
|
name = "huggingface-hub"
|
||||||
version = "1.4.1"
|
version = "1.4.1"
|
||||||
@@ -3345,39 +3310,6 @@ files = [
|
|||||||
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
|
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "mcp"
|
|
||||||
version = "1.26.0"
|
|
||||||
description = "Model Context Protocol SDK"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.10"
|
|
||||||
groups = ["main"]
|
|
||||||
files = [
|
|
||||||
{file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"},
|
|
||||||
{file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
anyio = ">=4.5"
|
|
||||||
httpx = ">=0.27.1"
|
|
||||||
httpx-sse = ">=0.4"
|
|
||||||
jsonschema = ">=4.20.0"
|
|
||||||
pydantic = ">=2.11.0,<3.0.0"
|
|
||||||
pydantic-settings = ">=2.5.2"
|
|
||||||
pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
|
|
||||||
python-multipart = ">=0.0.9"
|
|
||||||
pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""}
|
|
||||||
sse-starlette = ">=1.6.1"
|
|
||||||
starlette = ">=0.27"
|
|
||||||
typing-extensions = ">=4.9.0"
|
|
||||||
typing-inspection = ">=0.4.1"
|
|
||||||
uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""}
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"]
|
|
||||||
rich = ["rich (>=13.9.4)"]
|
|
||||||
ws = ["websockets (>=15.0.1)"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mdurl"
|
name = "mdurl"
|
||||||
version = "0.1.2"
|
version = "0.1.2"
|
||||||
@@ -6062,7 +5994,7 @@ description = "Python for Window Extensions"
|
|||||||
optional = false
|
optional = false
|
||||||
python-versions = "*"
|
python-versions = "*"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
|
markers = "platform_system == \"Windows\""
|
||||||
files = [
|
files = [
|
||||||
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
|
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
|
||||||
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
|
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
|
||||||
@@ -7042,28 +6974,6 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
|
|||||||
pymysql = ["pymysql"]
|
pymysql = ["pymysql"]
|
||||||
sqlcipher = ["sqlcipher3_binary"]
|
sqlcipher = ["sqlcipher3_binary"]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "sse-starlette"
|
|
||||||
version = "3.2.0"
|
|
||||||
description = "SSE plugin for Starlette"
|
|
||||||
optional = false
|
|
||||||
python-versions = ">=3.9"
|
|
||||||
groups = ["main"]
|
|
||||||
files = [
|
|
||||||
{file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"},
|
|
||||||
{file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"},
|
|
||||||
]
|
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
anyio = ">=4.7.0"
|
|
||||||
starlette = ">=0.49.1"
|
|
||||||
|
|
||||||
[package.extras]
|
|
||||||
daphne = ["daphne (>=4.2.0)"]
|
|
||||||
examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"]
|
|
||||||
granian = ["granian (>=2.3.1)"]
|
|
||||||
uvicorn = ["uvicorn (>=0.34.0)"]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "stagehand"
|
name = "stagehand"
|
||||||
version = "0.5.9"
|
version = "0.5.9"
|
||||||
@@ -8530,4 +8440,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.10,<3.14"
|
python-versions = ">=3.10,<3.14"
|
||||||
content-hash = "55e095de555482f0fe47de7695f390fe93e7bcf739b31c391b2e5e3c3d938ae3"
|
content-hash = "fa9c5deadf593e815dd2190f58e22152373900603f5f244b9616cd721de84d2f"
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ anthropic = "^0.79.0"
|
|||||||
apscheduler = "^3.11.1"
|
apscheduler = "^3.11.1"
|
||||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||||
claude-agent-sdk = "^0.1.0"
|
|
||||||
click = "^8.2.0"
|
click = "^8.2.0"
|
||||||
cryptography = "^46.0"
|
cryptography = "^46.0"
|
||||||
discord-py = "^2.5.2"
|
discord-py = "^2.5.2"
|
||||||
|
|||||||
@@ -1,66 +0,0 @@
|
|||||||
from typing import cast
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from backend.blocks.jina._auth import (
|
|
||||||
TEST_CREDENTIALS,
|
|
||||||
TEST_CREDENTIALS_INPUT,
|
|
||||||
JinaCredentialsInput,
|
|
||||||
)
|
|
||||||
from backend.blocks.jina.search import ExtractWebsiteContentBlock
|
|
||||||
from backend.util.request import HTTPClientError
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_website_content_returns_content(monkeypatch):
|
|
||||||
block = ExtractWebsiteContentBlock()
|
|
||||||
input_data = block.Input(
|
|
||||||
url="https://example.com",
|
|
||||||
credentials=cast(JinaCredentialsInput, TEST_CREDENTIALS_INPUT),
|
|
||||||
raw_content=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def fake_get_request(url, json=False, headers=None):
|
|
||||||
assert url == "https://example.com"
|
|
||||||
assert headers == {}
|
|
||||||
return "page content"
|
|
||||||
|
|
||||||
monkeypatch.setattr(block, "get_request", fake_get_request)
|
|
||||||
|
|
||||||
results = [
|
|
||||||
output
|
|
||||||
async for output in block.run(
|
|
||||||
input_data=input_data, credentials=TEST_CREDENTIALS
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
assert ("content", "page content") in results
|
|
||||||
assert all(key != "error" for key, _ in results)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extract_website_content_handles_http_error(monkeypatch):
|
|
||||||
block = ExtractWebsiteContentBlock()
|
|
||||||
input_data = block.Input(
|
|
||||||
url="https://example.com",
|
|
||||||
credentials=cast(JinaCredentialsInput, TEST_CREDENTIALS_INPUT),
|
|
||||||
raw_content=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def fake_get_request(url, json=False, headers=None):
|
|
||||||
raise HTTPClientError("HTTP 400 Error: Bad Request", 400)
|
|
||||||
|
|
||||||
monkeypatch.setattr(block, "get_request", fake_get_request)
|
|
||||||
|
|
||||||
results = [
|
|
||||||
output
|
|
||||||
async for output in block.run(
|
|
||||||
input_data=input_data, credentials=TEST_CREDENTIALS
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
assert ("content", "page content") not in results
|
|
||||||
error_messages = [value for key, value in results if key == "error"]
|
|
||||||
assert error_messages
|
|
||||||
assert "Client error (400)" in error_messages[0]
|
|
||||||
assert "https://example.com" in error_messages[0]
|
|
||||||
@@ -1,133 +0,0 @@
|
|||||||
"""Tests for SDK security hooks — workspace paths, tool access, and deny messages.
|
|
||||||
|
|
||||||
These are pure unit tests with no external dependencies (no SDK, no DB, no server).
|
|
||||||
They validate that the security hooks correctly block unauthorized paths,
|
|
||||||
tool access, and dangerous input patterns.
|
|
||||||
|
|
||||||
Note: Bash command validation was removed — the SDK built-in Bash tool is not in
|
|
||||||
allowed_tools, and the bash_exec MCP tool has kernel-level network isolation
|
|
||||||
(unshare --net) making command-level parsing unnecessary.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from backend.api.features.chat.sdk.security_hooks import (
|
|
||||||
_validate_tool_access,
|
|
||||||
_validate_workspace_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
SDK_CWD = "/tmp/copilot-test-session"
|
|
||||||
|
|
||||||
|
|
||||||
def _is_denied(result: dict) -> bool:
|
|
||||||
hook = result.get("hookSpecificOutput", {})
|
|
||||||
return hook.get("permissionDecision") == "deny"
|
|
||||||
|
|
||||||
|
|
||||||
def _reason(result: dict) -> str:
|
|
||||||
return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "")
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Workspace path validation (Read, Write, Edit, etc.)
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
|
|
||||||
class TestWorkspacePathValidation:
|
|
||||||
def test_path_in_workspace(self):
|
|
||||||
result = _validate_workspace_path(
|
|
||||||
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
|
|
||||||
)
|
|
||||||
assert not _is_denied(result)
|
|
||||||
|
|
||||||
def test_path_outside_workspace(self):
|
|
||||||
result = _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
def test_tool_results_allowed(self):
|
|
||||||
result = _validate_workspace_path(
|
|
||||||
"Read",
|
|
||||||
{"file_path": "~/.claude/projects/abc/tool-results/out.txt"},
|
|
||||||
SDK_CWD,
|
|
||||||
)
|
|
||||||
assert not _is_denied(result)
|
|
||||||
|
|
||||||
def test_claude_settings_blocked(self):
|
|
||||||
result = _validate_workspace_path(
|
|
||||||
"Read", {"file_path": "~/.claude/settings.json"}, SDK_CWD
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
def test_claude_projects_without_tool_results(self):
|
|
||||||
result = _validate_workspace_path(
|
|
||||||
"Read", {"file_path": "~/.claude/projects/abc/credentials.json"}, SDK_CWD
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
def test_no_path_allowed(self):
|
|
||||||
"""Glob/Grep without path defaults to cwd — should be allowed."""
|
|
||||||
result = _validate_workspace_path("Grep", {"pattern": "foo"}, SDK_CWD)
|
|
||||||
assert not _is_denied(result)
|
|
||||||
|
|
||||||
def test_path_traversal_with_dotdot(self):
|
|
||||||
result = _validate_workspace_path(
|
|
||||||
"Read", {"file_path": f"{SDK_CWD}/../../../etc/passwd"}, SDK_CWD
|
|
||||||
)
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Tool access validation
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
|
|
||||||
class TestToolAccessValidation:
|
|
||||||
def test_blocked_tools(self):
|
|
||||||
for tool in ("bash", "shell", "exec", "terminal", "command"):
|
|
||||||
result = _validate_tool_access(tool, {})
|
|
||||||
assert _is_denied(result), f"Tool '{tool}' should be blocked"
|
|
||||||
|
|
||||||
def test_bash_builtin_blocked(self):
|
|
||||||
"""SDK built-in Bash (capital) is blocked as defence-in-depth."""
|
|
||||||
result = _validate_tool_access("Bash", {"command": "echo hello"}, SDK_CWD)
|
|
||||||
assert _is_denied(result)
|
|
||||||
assert "Bash" in _reason(result)
|
|
||||||
|
|
||||||
def test_workspace_tools_delegate(self):
|
|
||||||
result = _validate_tool_access(
|
|
||||||
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
|
|
||||||
)
|
|
||||||
assert not _is_denied(result)
|
|
||||||
|
|
||||||
def test_dangerous_pattern_blocked(self):
|
|
||||||
result = _validate_tool_access("SomeUnknownTool", {"data": "sudo rm -rf /"})
|
|
||||||
assert _is_denied(result)
|
|
||||||
|
|
||||||
def test_safe_unknown_tool_allowed(self):
|
|
||||||
result = _validate_tool_access("SomeSafeTool", {"data": "hello world"})
|
|
||||||
assert not _is_denied(result)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Deny message quality (ntindle feedback)
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
|
|
||||||
class TestDenyMessageClarity:
|
|
||||||
"""Deny messages must include [SECURITY] and 'cannot be bypassed'
|
|
||||||
so the model knows the restriction is enforced, not a suggestion."""
|
|
||||||
|
|
||||||
def test_blocked_tool_message(self):
|
|
||||||
reason = _reason(_validate_tool_access("bash", {}))
|
|
||||||
assert "[SECURITY]" in reason
|
|
||||||
assert "cannot be bypassed" in reason
|
|
||||||
|
|
||||||
def test_bash_builtin_blocked_message(self):
|
|
||||||
reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
|
|
||||||
assert "[SECURITY]" in reason
|
|
||||||
assert "cannot be bypassed" in reason
|
|
||||||
|
|
||||||
def test_workspace_path_message(self):
|
|
||||||
reason = _reason(
|
|
||||||
_validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
|
|
||||||
)
|
|
||||||
assert "[SECURITY]" in reason
|
|
||||||
assert "cannot be bypassed" in reason
|
|
||||||
@@ -1,255 +0,0 @@
|
|||||||
"""Unit tests for JSONL transcript management utilities."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
|
|
||||||
from backend.api.features.chat.sdk.transcript import (
|
|
||||||
STRIPPABLE_TYPES,
|
|
||||||
read_transcript_file,
|
|
||||||
strip_progress_entries,
|
|
||||||
validate_transcript,
|
|
||||||
write_transcript_to_tempfile,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_jsonl(*entries: dict) -> str:
|
|
||||||
return "\n".join(json.dumps(e) for e in entries) + "\n"
|
|
||||||
|
|
||||||
|
|
||||||
# --- Fixtures ---
|
|
||||||
|
|
||||||
|
|
||||||
METADATA_LINE = {"type": "queue-operation", "subtype": "create"}
|
|
||||||
FILE_HISTORY = {"type": "file-history-snapshot", "files": []}
|
|
||||||
USER_MSG = {"type": "user", "uuid": "u1", "message": {"role": "user", "content": "hi"}}
|
|
||||||
ASST_MSG = {
|
|
||||||
"type": "assistant",
|
|
||||||
"uuid": "a1",
|
|
||||||
"parentUuid": "u1",
|
|
||||||
"message": {"role": "assistant", "content": "hello"},
|
|
||||||
}
|
|
||||||
PROGRESS_ENTRY = {
|
|
||||||
"type": "progress",
|
|
||||||
"uuid": "p1",
|
|
||||||
"parentUuid": "u1",
|
|
||||||
"data": {"type": "bash_progress", "stdout": "running..."},
|
|
||||||
}
|
|
||||||
|
|
||||||
VALID_TRANSCRIPT = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG, ASST_MSG)
|
|
||||||
|
|
||||||
|
|
||||||
# --- read_transcript_file ---
|
|
||||||
|
|
||||||
|
|
||||||
class TestReadTranscriptFile:
|
|
||||||
def test_returns_content_for_valid_file(self, tmp_path):
|
|
||||||
path = tmp_path / "session.jsonl"
|
|
||||||
path.write_text(VALID_TRANSCRIPT)
|
|
||||||
result = read_transcript_file(str(path))
|
|
||||||
assert result is not None
|
|
||||||
assert "user" in result
|
|
||||||
|
|
||||||
def test_returns_none_for_missing_file(self):
|
|
||||||
assert read_transcript_file("/nonexistent/path.jsonl") is None
|
|
||||||
|
|
||||||
def test_returns_none_for_empty_path(self):
|
|
||||||
assert read_transcript_file("") is None
|
|
||||||
|
|
||||||
def test_returns_none_for_empty_file(self, tmp_path):
|
|
||||||
path = tmp_path / "empty.jsonl"
|
|
||||||
path.write_text("")
|
|
||||||
assert read_transcript_file(str(path)) is None
|
|
||||||
|
|
||||||
def test_returns_none_for_metadata_only(self, tmp_path):
|
|
||||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
|
|
||||||
path = tmp_path / "meta.jsonl"
|
|
||||||
path.write_text(content)
|
|
||||||
assert read_transcript_file(str(path)) is None
|
|
||||||
|
|
||||||
def test_returns_none_for_invalid_json(self, tmp_path):
|
|
||||||
path = tmp_path / "bad.jsonl"
|
|
||||||
path.write_text("not json\n{}\n{}\n")
|
|
||||||
assert read_transcript_file(str(path)) is None
|
|
||||||
|
|
||||||
def test_no_size_limit(self, tmp_path):
|
|
||||||
"""Large files are accepted — bucket storage has no size limit."""
|
|
||||||
big_content = {"type": "user", "uuid": "u9", "data": "x" * 1_000_000}
|
|
||||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, big_content, ASST_MSG)
|
|
||||||
path = tmp_path / "big.jsonl"
|
|
||||||
path.write_text(content)
|
|
||||||
result = read_transcript_file(str(path))
|
|
||||||
assert result is not None
|
|
||||||
|
|
||||||
|
|
||||||
# --- write_transcript_to_tempfile ---
|
|
||||||
|
|
||||||
|
|
||||||
class TestWriteTranscriptToTempfile:
|
|
||||||
"""Tests use /tmp/copilot-* paths to satisfy the sandbox prefix check."""
|
|
||||||
|
|
||||||
def test_writes_file_and_returns_path(self):
|
|
||||||
cwd = "/tmp/copilot-test-write"
|
|
||||||
try:
|
|
||||||
result = write_transcript_to_tempfile(
|
|
||||||
VALID_TRANSCRIPT, "sess-1234-abcd", cwd
|
|
||||||
)
|
|
||||||
assert result is not None
|
|
||||||
assert os.path.isfile(result)
|
|
||||||
assert result.endswith(".jsonl")
|
|
||||||
with open(result) as f:
|
|
||||||
assert f.read() == VALID_TRANSCRIPT
|
|
||||||
finally:
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
shutil.rmtree(cwd, ignore_errors=True)
|
|
||||||
|
|
||||||
def test_creates_parent_directory(self):
|
|
||||||
cwd = "/tmp/copilot-test-mkdir"
|
|
||||||
try:
|
|
||||||
result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
|
|
||||||
assert result is not None
|
|
||||||
assert os.path.isdir(cwd)
|
|
||||||
finally:
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
shutil.rmtree(cwd, ignore_errors=True)
|
|
||||||
|
|
||||||
def test_uses_session_id_prefix(self):
|
|
||||||
cwd = "/tmp/copilot-test-prefix"
|
|
||||||
try:
|
|
||||||
result = write_transcript_to_tempfile(
|
|
||||||
VALID_TRANSCRIPT, "abcdef12-rest", cwd
|
|
||||||
)
|
|
||||||
assert result is not None
|
|
||||||
assert "abcdef12" in os.path.basename(result)
|
|
||||||
finally:
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
shutil.rmtree(cwd, ignore_errors=True)
|
|
||||||
|
|
||||||
def test_rejects_cwd_outside_sandbox(self, tmp_path):
|
|
||||||
cwd = str(tmp_path / "not-copilot")
|
|
||||||
result = write_transcript_to_tempfile(VALID_TRANSCRIPT, "sess-1234", cwd)
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
# --- validate_transcript ---
|
|
||||||
|
|
||||||
|
|
||||||
class TestValidateTranscript:
|
|
||||||
def test_valid_transcript(self):
|
|
||||||
assert validate_transcript(VALID_TRANSCRIPT) is True
|
|
||||||
|
|
||||||
def test_none_content(self):
|
|
||||||
assert validate_transcript(None) is False
|
|
||||||
|
|
||||||
def test_empty_content(self):
|
|
||||||
assert validate_transcript("") is False
|
|
||||||
|
|
||||||
def test_metadata_only(self):
|
|
||||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY)
|
|
||||||
assert validate_transcript(content) is False
|
|
||||||
|
|
||||||
def test_user_only_no_assistant(self):
|
|
||||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, USER_MSG)
|
|
||||||
assert validate_transcript(content) is False
|
|
||||||
|
|
||||||
def test_assistant_only_no_user(self):
|
|
||||||
content = _make_jsonl(METADATA_LINE, FILE_HISTORY, ASST_MSG)
|
|
||||||
assert validate_transcript(content) is False
|
|
||||||
|
|
||||||
def test_invalid_json_returns_false(self):
|
|
||||||
assert validate_transcript("not json\n{}\n{}\n") is False
|
|
||||||
|
|
||||||
|
|
||||||
# --- strip_progress_entries ---
|
|
||||||
|
|
||||||
|
|
||||||
class TestStripProgressEntries:
|
|
||||||
def test_strips_all_strippable_types(self):
|
|
||||||
"""All STRIPPABLE_TYPES are removed from the output."""
|
|
||||||
entries = [
|
|
||||||
USER_MSG,
|
|
||||||
{"type": "progress", "uuid": "p1", "parentUuid": "u1"},
|
|
||||||
{"type": "file-history-snapshot", "files": []},
|
|
||||||
{"type": "queue-operation", "subtype": "create"},
|
|
||||||
{"type": "summary", "text": "..."},
|
|
||||||
{"type": "pr-link", "url": "..."},
|
|
||||||
ASST_MSG,
|
|
||||||
]
|
|
||||||
result = strip_progress_entries(_make_jsonl(*entries))
|
|
||||||
result_types = {json.loads(line)["type"] for line in result.strip().split("\n")}
|
|
||||||
assert result_types == {"user", "assistant"}
|
|
||||||
for stype in STRIPPABLE_TYPES:
|
|
||||||
assert stype not in result_types
|
|
||||||
|
|
||||||
def test_reparents_children_of_stripped_entries(self):
|
|
||||||
"""An assistant message whose parent is a progress entry gets reparented."""
|
|
||||||
progress = {
|
|
||||||
"type": "progress",
|
|
||||||
"uuid": "p1",
|
|
||||||
"parentUuid": "u1",
|
|
||||||
"data": {"type": "bash_progress"},
|
|
||||||
}
|
|
||||||
asst = {
|
|
||||||
"type": "assistant",
|
|
||||||
"uuid": "a1",
|
|
||||||
"parentUuid": "p1", # Points to progress
|
|
||||||
"message": {"role": "assistant", "content": "done"},
|
|
||||||
}
|
|
||||||
content = _make_jsonl(USER_MSG, progress, asst)
|
|
||||||
result = strip_progress_entries(content)
|
|
||||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
|
||||||
|
|
||||||
asst_entry = next(e for e in lines if e["type"] == "assistant")
|
|
||||||
# Should be reparented to u1 (the user message)
|
|
||||||
assert asst_entry["parentUuid"] == "u1"
|
|
||||||
|
|
||||||
def test_reparents_through_chain(self):
|
|
||||||
"""Reparenting walks through multiple stripped entries."""
|
|
||||||
p1 = {"type": "progress", "uuid": "p1", "parentUuid": "u1"}
|
|
||||||
p2 = {"type": "progress", "uuid": "p2", "parentUuid": "p1"}
|
|
||||||
p3 = {"type": "progress", "uuid": "p3", "parentUuid": "p2"}
|
|
||||||
asst = {
|
|
||||||
"type": "assistant",
|
|
||||||
"uuid": "a1",
|
|
||||||
"parentUuid": "p3", # 3 levels deep
|
|
||||||
"message": {"role": "assistant", "content": "done"},
|
|
||||||
}
|
|
||||||
content = _make_jsonl(USER_MSG, p1, p2, p3, asst)
|
|
||||||
result = strip_progress_entries(content)
|
|
||||||
lines = [json.loads(line) for line in result.strip().split("\n")]
|
|
||||||
|
|
||||||
asst_entry = next(e for e in lines if e["type"] == "assistant")
|
|
||||||
assert asst_entry["parentUuid"] == "u1"
|
|
||||||
|
|
||||||
def test_preserves_non_strippable_entries(self):
|
|
||||||
"""User, assistant, and system entries are preserved."""
|
|
||||||
system = {"type": "system", "uuid": "s1", "message": "prompt"}
|
|
||||||
content = _make_jsonl(system, USER_MSG, ASST_MSG)
|
|
||||||
result = strip_progress_entries(content)
|
|
||||||
result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
|
|
||||||
assert result_types == ["system", "user", "assistant"]
|
|
||||||
|
|
||||||
def test_empty_input(self):
|
|
||||||
result = strip_progress_entries("")
|
|
||||||
# Should return just a newline (empty content stripped)
|
|
||||||
assert result.strip() == ""
|
|
||||||
|
|
||||||
def test_no_strippable_entries(self):
|
|
||||||
"""When there's nothing to strip, output matches input structure."""
|
|
||||||
content = _make_jsonl(USER_MSG, ASST_MSG)
|
|
||||||
result = strip_progress_entries(content)
|
|
||||||
result_lines = result.strip().split("\n")
|
|
||||||
assert len(result_lines) == 2
|
|
||||||
|
|
||||||
def test_handles_entries_without_uuid(self):
|
|
||||||
"""Entries without uuid field are handled gracefully."""
|
|
||||||
no_uuid = {"type": "queue-operation", "subtype": "create"}
|
|
||||||
content = _make_jsonl(no_uuid, USER_MSG, ASST_MSG)
|
|
||||||
result = strip_progress_entries(content)
|
|
||||||
result_types = [json.loads(line)["type"] for line in result.strip().split("\n")]
|
|
||||||
# queue-operation is strippable
|
|
||||||
assert "queue-operation" not in result_types
|
|
||||||
assert "user" in result_types
|
|
||||||
assert "assistant" in result_types
|
|
||||||
@@ -1,96 +0,0 @@
|
|||||||
import { NextResponse } from "next/server";
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Safely encode a value as JSON for embedding in a script tag.
|
|
||||||
* Escapes characters that could break out of the script context to prevent XSS.
|
|
||||||
*/
|
|
||||||
function safeJsonStringify(value: unknown): string {
|
|
||||||
return JSON.stringify(value)
|
|
||||||
.replace(/</g, "\\u003c")
|
|
||||||
.replace(/>/g, "\\u003e")
|
|
||||||
.replace(/&/g, "\\u0026");
|
|
||||||
}
|
|
||||||
|
|
||||||
// MCP-specific OAuth callback route.
|
|
||||||
//
|
|
||||||
// Unlike the generic oauth_callback which relies on window.opener.postMessage,
|
|
||||||
// this route uses BroadcastChannel as the PRIMARY communication method.
|
|
||||||
// This is critical because cross-origin OAuth flows (e.g. Sentry → localhost)
|
|
||||||
// often lose window.opener due to COOP (Cross-Origin-Opener-Policy) headers.
|
|
||||||
//
|
|
||||||
// BroadcastChannel works across all same-origin tabs/popups regardless of opener.
|
|
||||||
export async function GET(request: Request) {
|
|
||||||
const { searchParams } = new URL(request.url);
|
|
||||||
const code = searchParams.get("code");
|
|
||||||
const state = searchParams.get("state");
|
|
||||||
|
|
||||||
const success = Boolean(code && state);
|
|
||||||
const message = success
|
|
||||||
? { success: true, code, state }
|
|
||||||
: {
|
|
||||||
success: false,
|
|
||||||
message: `Missing parameters: ${searchParams.toString()}`,
|
|
||||||
};
|
|
||||||
|
|
||||||
return new NextResponse(
|
|
||||||
`<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head><title>MCP Sign-in</title></head>
|
|
||||||
<body style="font-family: system-ui, -apple-system, sans-serif; display: flex; align-items: center; justify-content: center; min-height: 100vh; margin: 0; background: #f9fafb;">
|
|
||||||
<div style="text-align: center; max-width: 400px; padding: 2rem;">
|
|
||||||
<div id="spinner" style="margin: 0 auto 1rem; width: 32px; height: 32px; border: 3px solid #e5e7eb; border-top-color: #3b82f6; border-radius: 50%; animation: spin 0.8s linear infinite;"></div>
|
|
||||||
<p id="status" style="color: #374151; font-size: 16px;">Completing sign-in...</p>
|
|
||||||
</div>
|
|
||||||
<style>@keyframes spin { to { transform: rotate(360deg); } }</style>
|
|
||||||
<script>
|
|
||||||
(function() {
|
|
||||||
var msg = ${safeJsonStringify(message)};
|
|
||||||
var sent = false;
|
|
||||||
|
|
||||||
// Method 1: BroadcastChannel (reliable across tabs/popups, no opener needed)
|
|
||||||
try {
|
|
||||||
var bc = new BroadcastChannel("mcp_oauth");
|
|
||||||
bc.postMessage({ type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message });
|
|
||||||
bc.close();
|
|
||||||
sent = true;
|
|
||||||
} catch(e) { /* BroadcastChannel not supported */ }
|
|
||||||
|
|
||||||
// Method 2: window.opener.postMessage (fallback for same-origin popups)
|
|
||||||
try {
|
|
||||||
if (window.opener && !window.opener.closed) {
|
|
||||||
window.opener.postMessage(
|
|
||||||
{ message_type: "mcp_oauth_result", success: msg.success, code: msg.code, state: msg.state, message: msg.message },
|
|
||||||
window.location.origin
|
|
||||||
);
|
|
||||||
sent = true;
|
|
||||||
}
|
|
||||||
} catch(e) { /* opener not available (COOP) */ }
|
|
||||||
|
|
||||||
// Method 3: localStorage (most reliable cross-tab fallback)
|
|
||||||
try {
|
|
||||||
localStorage.setItem("mcp_oauth_result", JSON.stringify(msg));
|
|
||||||
sent = true;
|
|
||||||
} catch(e) { /* localStorage not available */ }
|
|
||||||
|
|
||||||
var statusEl = document.getElementById("status");
|
|
||||||
var spinnerEl = document.getElementById("spinner");
|
|
||||||
spinnerEl.style.display = "none";
|
|
||||||
|
|
||||||
if (msg.success && sent) {
|
|
||||||
statusEl.textContent = "Sign-in complete! This window will close.";
|
|
||||||
statusEl.style.color = "#059669";
|
|
||||||
setTimeout(function() { window.close(); }, 1500);
|
|
||||||
} else if (msg.success) {
|
|
||||||
statusEl.textContent = "Sign-in successful! You can close this tab and return to the builder.";
|
|
||||||
statusEl.style.color = "#059669";
|
|
||||||
} else {
|
|
||||||
statusEl.textContent = "Sign-in failed: " + (msg.message || "Unknown error");
|
|
||||||
statusEl.style.color = "#dc2626";
|
|
||||||
}
|
|
||||||
})();
|
|
||||||
</script>
|
|
||||||
</body>
|
|
||||||
</html>`,
|
|
||||||
{ headers: { "Content-Type": "text/html" } },
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -4,7 +4,7 @@ import {
|
|||||||
} from "@/app/api/__generated__/endpoints/graphs/graphs";
|
} from "@/app/api/__generated__/endpoints/graphs/graphs";
|
||||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
|
import { parseAsInteger, parseAsString, useQueryStates } from "nuqs";
|
||||||
import { GraphExecutionMeta } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/use-agent-runs";
|
import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
|
||||||
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
import { useGraphStore } from "@/app/(platform)/build/stores/graphStore";
|
||||||
import { useShallow } from "zustand/react/shallow";
|
import { useShallow } from "zustand/react/shallow";
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
|
|||||||
@@ -47,10 +47,7 @@ export type CustomNode = XYNode<CustomNodeData, "custom">;
|
|||||||
|
|
||||||
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
||||||
({ data, id: nodeId, selected }) => {
|
({ data, id: nodeId, selected }) => {
|
||||||
const { inputSchema, outputSchema, isMCPWithTool } = useCustomNode({
|
const { inputSchema, outputSchema } = useCustomNode({ data, nodeId });
|
||||||
data,
|
|
||||||
nodeId,
|
|
||||||
});
|
|
||||||
|
|
||||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||||
|
|
||||||
@@ -101,7 +98,6 @@ export const CustomNode: React.FC<NodeProps<CustomNode>> = React.memo(
|
|||||||
jsonSchema={preprocessInputSchema(inputSchema)}
|
jsonSchema={preprocessInputSchema(inputSchema)}
|
||||||
nodeId={nodeId}
|
nodeId={nodeId}
|
||||||
uiType={data.uiType}
|
uiType={data.uiType}
|
||||||
isMCPWithTool={isMCPWithTool}
|
|
||||||
className={cn(
|
className={cn(
|
||||||
"bg-white px-4",
|
"bg-white px-4",
|
||||||
isWebhook && "pointer-events-none opacity-50",
|
isWebhook && "pointer-events-none opacity-50",
|
||||||
|
|||||||
@@ -20,8 +20,10 @@ type Props = {
|
|||||||
|
|
||||||
export const NodeHeader = ({ data, nodeId }: Props) => {
|
export const NodeHeader = ({ data, nodeId }: Props) => {
|
||||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||||
|
const title =
|
||||||
const title = (data.metadata?.customized_name as string) || data.title;
|
(data.metadata?.customized_name as string) ||
|
||||||
|
data.hardcodedValues?.agent_name ||
|
||||||
|
data.title;
|
||||||
|
|
||||||
const [isEditingTitle, setIsEditingTitle] = useState(false);
|
const [isEditingTitle, setIsEditingTitle] = useState(false);
|
||||||
const [editedTitle, setEditedTitle] = useState(title);
|
const [editedTitle, setEditedTitle] = useState(title);
|
||||||
|
|||||||
@@ -3,34 +3,6 @@ import { CustomNodeData } from "./CustomNode";
|
|||||||
import { BlockUIType } from "../../../types";
|
import { BlockUIType } from "../../../types";
|
||||||
import { useMemo } from "react";
|
import { useMemo } from "react";
|
||||||
import { mergeSchemaForResolution } from "./helpers";
|
import { mergeSchemaForResolution } from "./helpers";
|
||||||
/**
|
|
||||||
* Build a dynamic input schema for MCP blocks.
|
|
||||||
*
|
|
||||||
* When a tool has been selected (tool_input_schema is populated), the block
|
|
||||||
* renders the selected tool's input parameters *plus* the credentials field
|
|
||||||
* so users can select/change the OAuth credential used for execution.
|
|
||||||
*
|
|
||||||
* Static fields like server_url, selected_tool, available_tools, and
|
|
||||||
* tool_arguments are hidden because they're pre-configured from the dialog.
|
|
||||||
*/
|
|
||||||
function buildMCPInputSchema(
|
|
||||||
toolInputSchema: Record<string, any>,
|
|
||||||
blockInputSchema: Record<string, any>,
|
|
||||||
): Record<string, any> {
|
|
||||||
// Extract the credentials field from the block's original input schema
|
|
||||||
const credentialsSchema =
|
|
||||||
blockInputSchema?.properties?.credentials ?? undefined;
|
|
||||||
|
|
||||||
return {
|
|
||||||
type: "object",
|
|
||||||
properties: {
|
|
||||||
// Credentials field first so the dropdown appears at the top
|
|
||||||
...(credentialsSchema ? { credentials: credentialsSchema } : {}),
|
|
||||||
...(toolInputSchema.properties ?? {}),
|
|
||||||
},
|
|
||||||
required: [...(toolInputSchema.required ?? [])],
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
export const useCustomNode = ({
|
export const useCustomNode = ({
|
||||||
data,
|
data,
|
||||||
@@ -47,17 +19,9 @@ export const useCustomNode = ({
|
|||||||
);
|
);
|
||||||
|
|
||||||
const isAgent = data.uiType === BlockUIType.AGENT;
|
const isAgent = data.uiType === BlockUIType.AGENT;
|
||||||
const isMCPWithTool =
|
|
||||||
data.uiType === BlockUIType.MCP_TOOL &&
|
|
||||||
!!data.hardcodedValues?.tool_input_schema?.properties;
|
|
||||||
|
|
||||||
const currentInputSchema = isAgent
|
const currentInputSchema = isAgent
|
||||||
? (data.hardcodedValues.input_schema ?? {})
|
? (data.hardcodedValues.input_schema ?? {})
|
||||||
: isMCPWithTool
|
|
||||||
? buildMCPInputSchema(
|
|
||||||
data.hardcodedValues.tool_input_schema,
|
|
||||||
data.inputSchema,
|
|
||||||
)
|
|
||||||
: data.inputSchema;
|
: data.inputSchema;
|
||||||
const currentOutputSchema = isAgent
|
const currentOutputSchema = isAgent
|
||||||
? (data.hardcodedValues.output_schema ?? {})
|
? (data.hardcodedValues.output_schema ?? {})
|
||||||
@@ -90,6 +54,5 @@ export const useCustomNode = ({
|
|||||||
return {
|
return {
|
||||||
inputSchema,
|
inputSchema,
|
||||||
outputSchema,
|
outputSchema,
|
||||||
isMCPWithTool,
|
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -9,72 +9,39 @@ interface FormCreatorProps {
|
|||||||
jsonSchema: RJSFSchema;
|
jsonSchema: RJSFSchema;
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
uiType: BlockUIType;
|
uiType: BlockUIType;
|
||||||
/** When true the block is an MCP Tool with a selected tool. */
|
|
||||||
isMCPWithTool?: boolean;
|
|
||||||
showHandles?: boolean;
|
showHandles?: boolean;
|
||||||
className?: string;
|
className?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const FormCreator: React.FC<FormCreatorProps> = React.memo(
|
export const FormCreator: React.FC<FormCreatorProps> = React.memo(
|
||||||
({
|
({ jsonSchema, nodeId, uiType, showHandles = true, className }) => {
|
||||||
jsonSchema,
|
|
||||||
nodeId,
|
|
||||||
uiType,
|
|
||||||
isMCPWithTool = false,
|
|
||||||
showHandles = true,
|
|
||||||
className,
|
|
||||||
}) => {
|
|
||||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
||||||
|
|
||||||
const getHardCodedValues = useNodeStore(
|
const getHardCodedValues = useNodeStore(
|
||||||
(state) => state.getHardCodedValues,
|
(state) => state.getHardCodedValues,
|
||||||
);
|
);
|
||||||
|
|
||||||
const isAgent = uiType === BlockUIType.AGENT;
|
|
||||||
|
|
||||||
const handleChange = ({ formData }: any) => {
|
const handleChange = ({ formData }: any) => {
|
||||||
if ("credentials" in formData && !formData.credentials?.id) {
|
if ("credentials" in formData && !formData.credentials?.id) {
|
||||||
delete formData.credentials;
|
delete formData.credentials;
|
||||||
}
|
}
|
||||||
|
|
||||||
let updatedValues;
|
const updatedValues =
|
||||||
if (isAgent) {
|
uiType === BlockUIType.AGENT
|
||||||
updatedValues = {
|
? {
|
||||||
...getHardCodedValues(nodeId),
|
...getHardCodedValues(nodeId),
|
||||||
inputs: formData,
|
inputs: formData,
|
||||||
};
|
|
||||||
} else if (isMCPWithTool) {
|
|
||||||
// Separate credentials from tool arguments — credentials are stored
|
|
||||||
// at the top level of hardcodedValues, not inside tool_arguments.
|
|
||||||
const { credentials, ...toolArgs } = formData;
|
|
||||||
updatedValues = {
|
|
||||||
...getHardCodedValues(nodeId),
|
|
||||||
tool_arguments: toolArgs,
|
|
||||||
...(credentials?.id ? { credentials } : {}),
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
updatedValues = formData;
|
|
||||||
}
|
}
|
||||||
|
: formData;
|
||||||
|
|
||||||
updateNodeData(nodeId, { hardcodedValues: updatedValues });
|
updateNodeData(nodeId, { hardcodedValues: updatedValues });
|
||||||
};
|
};
|
||||||
|
|
||||||
const hardcodedValues = getHardCodedValues(nodeId);
|
const hardcodedValues = getHardCodedValues(nodeId);
|
||||||
|
const initialValues =
|
||||||
let initialValues;
|
uiType === BlockUIType.AGENT
|
||||||
if (isAgent) {
|
? (hardcodedValues.inputs ?? {})
|
||||||
initialValues = hardcodedValues.inputs ?? {};
|
: hardcodedValues;
|
||||||
} else if (isMCPWithTool) {
|
|
||||||
// Merge tool arguments with credentials for the form
|
|
||||||
initialValues = {
|
|
||||||
...(hardcodedValues.tool_arguments ?? {}),
|
|
||||||
...(hardcodedValues.credentials?.id
|
|
||||||
? { credentials: hardcodedValues.credentials }
|
|
||||||
: {}),
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
initialValues = hardcodedValues;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
|
|||||||
@@ -1,558 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import React, {
|
|
||||||
useState,
|
|
||||||
useCallback,
|
|
||||||
useRef,
|
|
||||||
useEffect,
|
|
||||||
useContext,
|
|
||||||
} from "react";
|
|
||||||
import {
|
|
||||||
Dialog,
|
|
||||||
DialogContent,
|
|
||||||
DialogDescription,
|
|
||||||
DialogFooter,
|
|
||||||
DialogHeader,
|
|
||||||
DialogTitle,
|
|
||||||
} from "@/components/__legacy__/ui/dialog";
|
|
||||||
import { Button } from "@/components/__legacy__/ui/button";
|
|
||||||
import { Input } from "@/components/__legacy__/ui/input";
|
|
||||||
import { Label } from "@/components/__legacy__/ui/label";
|
|
||||||
import { LoadingSpinner } from "@/components/__legacy__/ui/loading";
|
|
||||||
import { Badge } from "@/components/__legacy__/ui/badge";
|
|
||||||
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
|
|
||||||
import type { CredentialsMetaInput } from "@/lib/autogpt-server-api";
|
|
||||||
import type { MCPToolResponse } from "@/app/api/__generated__/models/mCPToolResponse";
|
|
||||||
import {
|
|
||||||
postV2DiscoverAvailableToolsOnAnMcpServer,
|
|
||||||
postV2InitiateOauthLoginForAnMcpServer,
|
|
||||||
postV2ExchangeOauthCodeForMcpTokens,
|
|
||||||
} from "@/app/api/__generated__/endpoints/mcp/mcp";
|
|
||||||
import { CaretDown } from "@phosphor-icons/react";
|
|
||||||
import { openOAuthPopup } from "@/lib/oauth-popup";
|
|
||||||
import { CredentialsProvidersContext } from "@/providers/agent-credentials/credentials-provider";
|
|
||||||
|
|
||||||
export type MCPToolDialogResult = {
|
|
||||||
serverUrl: string;
|
|
||||||
serverName: string | null;
|
|
||||||
selectedTool: string;
|
|
||||||
toolInputSchema: Record<string, any>;
|
|
||||||
availableTools: Record<string, any>;
|
|
||||||
/** Credentials meta from OAuth flow, null for public servers. */
|
|
||||||
credentials: CredentialsMetaInput | null;
|
|
||||||
};
|
|
||||||
|
|
||||||
interface MCPToolDialogProps {
|
|
||||||
open: boolean;
|
|
||||||
onClose: () => void;
|
|
||||||
onConfirm: (result: MCPToolDialogResult) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
type DialogStep = "url" | "tool";
|
|
||||||
|
|
||||||
export function MCPToolDialog({
|
|
||||||
open,
|
|
||||||
onClose,
|
|
||||||
onConfirm,
|
|
||||||
}: MCPToolDialogProps) {
|
|
||||||
const allProviders = useContext(CredentialsProvidersContext);
|
|
||||||
|
|
||||||
const [step, setStep] = useState<DialogStep>("url");
|
|
||||||
const [serverUrl, setServerUrl] = useState("");
|
|
||||||
const [tools, setTools] = useState<MCPToolResponse[]>([]);
|
|
||||||
const [serverName, setServerName] = useState<string | null>(null);
|
|
||||||
const [loading, setLoading] = useState(false);
|
|
||||||
const [error, setError] = useState<string | null>(null);
|
|
||||||
const [authRequired, setAuthRequired] = useState(false);
|
|
||||||
const [oauthLoading, setOauthLoading] = useState(false);
|
|
||||||
const [showManualToken, setShowManualToken] = useState(false);
|
|
||||||
const [manualToken, setManualToken] = useState("");
|
|
||||||
const [selectedTool, setSelectedTool] = useState<MCPToolResponse | null>(
|
|
||||||
null,
|
|
||||||
);
|
|
||||||
const [credentials, setCredentials] = useState<CredentialsMetaInput | null>(
|
|
||||||
null,
|
|
||||||
);
|
|
||||||
|
|
||||||
const startOAuthRef = useRef(false);
|
|
||||||
const oauthAbortRef = useRef<((reason?: string) => void) | null>(null);
|
|
||||||
|
|
||||||
// Clean up on unmount
|
|
||||||
useEffect(() => {
|
|
||||||
return () => {
|
|
||||||
oauthAbortRef.current?.();
|
|
||||||
};
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const reset = useCallback(() => {
|
|
||||||
oauthAbortRef.current?.();
|
|
||||||
oauthAbortRef.current = null;
|
|
||||||
setStep("url");
|
|
||||||
setServerUrl("");
|
|
||||||
setManualToken("");
|
|
||||||
setTools([]);
|
|
||||||
setServerName(null);
|
|
||||||
setLoading(false);
|
|
||||||
setError(null);
|
|
||||||
setAuthRequired(false);
|
|
||||||
setOauthLoading(false);
|
|
||||||
setShowManualToken(false);
|
|
||||||
setSelectedTool(null);
|
|
||||||
setCredentials(null);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const handleClose = useCallback(() => {
|
|
||||||
reset();
|
|
||||||
onClose();
|
|
||||||
}, [reset, onClose]);
|
|
||||||
|
|
||||||
const discoverTools = useCallback(async (url: string, authToken?: string) => {
|
|
||||||
setLoading(true);
|
|
||||||
setError(null);
|
|
||||||
try {
|
|
||||||
const response = await postV2DiscoverAvailableToolsOnAnMcpServer({
|
|
||||||
server_url: url,
|
|
||||||
auth_token: authToken || null,
|
|
||||||
});
|
|
||||||
if (response.status !== 200) throw response.data;
|
|
||||||
setTools(response.data.tools);
|
|
||||||
setServerName(response.data.server_name ?? null);
|
|
||||||
setAuthRequired(false);
|
|
||||||
setShowManualToken(false);
|
|
||||||
setStep("tool");
|
|
||||||
} catch (e: any) {
|
|
||||||
if (e?.status === 401 || e?.status === 403) {
|
|
||||||
setAuthRequired(true);
|
|
||||||
setError(null);
|
|
||||||
// Automatically start OAuth sign-in instead of requiring a second click
|
|
||||||
setLoading(false);
|
|
||||||
startOAuthRef.current = true;
|
|
||||||
return;
|
|
||||||
} else {
|
|
||||||
const message =
|
|
||||||
e?.message || e?.detail || "Failed to connect to MCP server";
|
|
||||||
setError(
|
|
||||||
typeof message === "string" ? message : JSON.stringify(message),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
setLoading(false);
|
|
||||||
}
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const handleDiscoverTools = useCallback(() => {
|
|
||||||
if (!serverUrl.trim()) return;
|
|
||||||
discoverTools(serverUrl.trim(), manualToken.trim() || undefined);
|
|
||||||
}, [serverUrl, manualToken, discoverTools]);
|
|
||||||
|
|
||||||
const handleOAuthSignIn = useCallback(async () => {
|
|
||||||
if (!serverUrl.trim()) return;
|
|
||||||
setError(null);
|
|
||||||
|
|
||||||
// Abort any previous OAuth flow
|
|
||||||
oauthAbortRef.current?.();
|
|
||||||
|
|
||||||
setOauthLoading(true);
|
|
||||||
|
|
||||||
try {
|
|
||||||
const loginResponse = await postV2InitiateOauthLoginForAnMcpServer({
|
|
||||||
server_url: serverUrl.trim(),
|
|
||||||
});
|
|
||||||
if (loginResponse.status !== 200) throw loginResponse.data;
|
|
||||||
const { login_url, state_token } = loginResponse.data;
|
|
||||||
|
|
||||||
const { promise, cleanup } = openOAuthPopup(login_url, {
|
|
||||||
stateToken: state_token,
|
|
||||||
useCrossOriginListeners: true,
|
|
||||||
});
|
|
||||||
oauthAbortRef.current = cleanup.abort;
|
|
||||||
|
|
||||||
const result = await promise;
|
|
||||||
|
|
||||||
// Exchange code for tokens via the credentials provider (updates cache)
|
|
||||||
setLoading(true);
|
|
||||||
setOauthLoading(false);
|
|
||||||
|
|
||||||
const mcpProvider = allProviders?.["mcp"];
|
|
||||||
let callbackResult;
|
|
||||||
if (mcpProvider) {
|
|
||||||
callbackResult = await mcpProvider.mcpOAuthCallback(
|
|
||||||
result.code,
|
|
||||||
state_token,
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
const cbResponse = await postV2ExchangeOauthCodeForMcpTokens({
|
|
||||||
code: result.code,
|
|
||||||
state_token,
|
|
||||||
});
|
|
||||||
if (cbResponse.status !== 200) throw cbResponse.data;
|
|
||||||
callbackResult = cbResponse.data;
|
|
||||||
}
|
|
||||||
|
|
||||||
setCredentials({
|
|
||||||
id: callbackResult.id,
|
|
||||||
provider: callbackResult.provider,
|
|
||||||
type: callbackResult.type,
|
|
||||||
title: callbackResult.title,
|
|
||||||
});
|
|
||||||
setAuthRequired(false);
|
|
||||||
|
|
||||||
// Discover tools now that we're authenticated
|
|
||||||
const toolsResponse = await postV2DiscoverAvailableToolsOnAnMcpServer({
|
|
||||||
server_url: serverUrl.trim(),
|
|
||||||
});
|
|
||||||
if (toolsResponse.status !== 200) throw toolsResponse.data;
|
|
||||||
setTools(toolsResponse.data.tools);
|
|
||||||
setServerName(toolsResponse.data.server_name ?? null);
|
|
||||||
setStep("tool");
|
|
||||||
} catch (e: any) {
|
|
||||||
// If server doesn't support OAuth → show manual token entry
|
|
||||||
if (e?.status === 400) {
|
|
||||||
setShowManualToken(true);
|
|
||||||
setError(
|
|
||||||
"This server does not support OAuth sign-in. Please enter a token manually.",
|
|
||||||
);
|
|
||||||
} else if (e?.message === "OAuth flow timed out") {
|
|
||||||
setError("OAuth sign-in timed out. Please try again.");
|
|
||||||
} else {
|
|
||||||
const status = e?.status;
|
|
||||||
let message: string;
|
|
||||||
if (status === 401 || status === 403) {
|
|
||||||
message =
|
|
||||||
"Authentication succeeded but the server still rejected the request. " +
|
|
||||||
"The token audience may not match. Please try again.";
|
|
||||||
} else {
|
|
||||||
message = e?.message || e?.detail || "Failed to complete sign-in";
|
|
||||||
}
|
|
||||||
setError(
|
|
||||||
typeof message === "string" ? message : JSON.stringify(message),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
setOauthLoading(false);
|
|
||||||
setLoading(false);
|
|
||||||
oauthAbortRef.current = null;
|
|
||||||
}
|
|
||||||
}, [serverUrl, allProviders]);
|
|
||||||
|
|
||||||
// Auto-start OAuth sign-in when server returns 401/403
|
|
||||||
useEffect(() => {
|
|
||||||
if (authRequired && startOAuthRef.current) {
|
|
||||||
startOAuthRef.current = false;
|
|
||||||
handleOAuthSignIn();
|
|
||||||
}
|
|
||||||
}, [authRequired, handleOAuthSignIn]);
|
|
||||||
|
|
||||||
const handleConfirm = useCallback(() => {
|
|
||||||
if (!selectedTool) return;
|
|
||||||
|
|
||||||
const availableTools: Record<string, any> = {};
|
|
||||||
for (const t of tools) {
|
|
||||||
availableTools[t.name] = {
|
|
||||||
description: t.description,
|
|
||||||
input_schema: t.input_schema,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
onConfirm({
|
|
||||||
serverUrl: serverUrl.trim(),
|
|
||||||
serverName,
|
|
||||||
selectedTool: selectedTool.name,
|
|
||||||
toolInputSchema: selectedTool.input_schema,
|
|
||||||
availableTools,
|
|
||||||
credentials,
|
|
||||||
});
|
|
||||||
reset();
|
|
||||||
}, [
|
|
||||||
selectedTool,
|
|
||||||
tools,
|
|
||||||
serverUrl,
|
|
||||||
serverName,
|
|
||||||
credentials,
|
|
||||||
onConfirm,
|
|
||||||
reset,
|
|
||||||
]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Dialog open={open} onOpenChange={(isOpen) => !isOpen && handleClose()}>
|
|
||||||
<DialogContent className="max-w-lg">
|
|
||||||
<DialogHeader>
|
|
||||||
<DialogTitle>
|
|
||||||
{step === "url"
|
|
||||||
? "Connect to MCP Server"
|
|
||||||
: `Select a Tool${serverName ? ` — ${serverName}` : ""}`}
|
|
||||||
</DialogTitle>
|
|
||||||
<DialogDescription>
|
|
||||||
{step === "url"
|
|
||||||
? "Enter the URL of an MCP server to discover its available tools."
|
|
||||||
: `Found ${tools.length} tool${tools.length !== 1 ? "s" : ""}. Select one to add to your agent.`}
|
|
||||||
</DialogDescription>
|
|
||||||
</DialogHeader>
|
|
||||||
|
|
||||||
{step === "url" && (
|
|
||||||
<div className="flex flex-col gap-4 py-2">
|
|
||||||
<div className="flex flex-col gap-2">
|
|
||||||
<Label htmlFor="mcp-server-url">Server URL</Label>
|
|
||||||
<Input
|
|
||||||
id="mcp-server-url"
|
|
||||||
type="url"
|
|
||||||
placeholder="https://mcp.example.com/mcp"
|
|
||||||
value={serverUrl}
|
|
||||||
onChange={(e) => setServerUrl(e.target.value)}
|
|
||||||
onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()}
|
|
||||||
autoFocus
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Auth required: show manual token option */}
|
|
||||||
{authRequired && !showManualToken && (
|
|
||||||
<button
|
|
||||||
onClick={() => setShowManualToken(true)}
|
|
||||||
className="text-xs text-gray-500 underline hover:text-gray-700 dark:text-gray-400 dark:hover:text-gray-300"
|
|
||||||
>
|
|
||||||
or enter a token manually
|
|
||||||
</button>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* Manual token entry — only visible when expanded */}
|
|
||||||
{showManualToken && (
|
|
||||||
<div className="flex flex-col gap-2">
|
|
||||||
<Label htmlFor="mcp-auth-token" className="text-sm">
|
|
||||||
Bearer Token
|
|
||||||
</Label>
|
|
||||||
<Input
|
|
||||||
id="mcp-auth-token"
|
|
||||||
type="password"
|
|
||||||
placeholder="Paste your auth token here"
|
|
||||||
value={manualToken}
|
|
||||||
onChange={(e) => setManualToken(e.target.value)}
|
|
||||||
onKeyDown={(e) => e.key === "Enter" && handleDiscoverTools()}
|
|
||||||
autoFocus
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{error && <p className="text-sm text-red-500">{error}</p>}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{step === "tool" && (
|
|
||||||
<ScrollArea className="max-h-[50vh] py-2">
|
|
||||||
<div className="flex flex-col gap-2 pr-3">
|
|
||||||
{tools.map((tool) => (
|
|
||||||
<MCPToolCard
|
|
||||||
key={tool.name}
|
|
||||||
tool={tool}
|
|
||||||
selected={selectedTool?.name === tool.name}
|
|
||||||
onSelect={() => setSelectedTool(tool)}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
</ScrollArea>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<DialogFooter>
|
|
||||||
{step === "tool" && (
|
|
||||||
<Button
|
|
||||||
variant="outline"
|
|
||||||
onClick={() => {
|
|
||||||
setStep("url");
|
|
||||||
setSelectedTool(null);
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
Back
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
<Button variant="outline" onClick={handleClose}>
|
|
||||||
Cancel
|
|
||||||
</Button>
|
|
||||||
{step === "url" && (
|
|
||||||
<Button
|
|
||||||
onClick={
|
|
||||||
authRequired && !showManualToken
|
|
||||||
? handleOAuthSignIn
|
|
||||||
: handleDiscoverTools
|
|
||||||
}
|
|
||||||
disabled={!serverUrl.trim() || loading || oauthLoading}
|
|
||||||
>
|
|
||||||
{loading || oauthLoading ? (
|
|
||||||
<span className="flex items-center gap-2">
|
|
||||||
<LoadingSpinner className="size-4" />
|
|
||||||
{oauthLoading ? "Waiting for sign-in..." : "Connecting..."}
|
|
||||||
</span>
|
|
||||||
) : authRequired && !showManualToken ? (
|
|
||||||
"Sign in & Connect"
|
|
||||||
) : (
|
|
||||||
"Discover Tools"
|
|
||||||
)}
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
{step === "tool" && (
|
|
||||||
<Button onClick={handleConfirm} disabled={!selectedTool}>
|
|
||||||
Add Block
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</DialogFooter>
|
|
||||||
</DialogContent>
|
|
||||||
</Dialog>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// --------------- Tool Card Component --------------- //
|
|
||||||
|
|
||||||
/** Truncate a description to a reasonable length for the collapsed view. */
|
|
||||||
function truncateDescription(text: string, maxLen = 120): string {
|
|
||||||
if (text.length <= maxLen) return text;
|
|
||||||
return text.slice(0, maxLen).trimEnd() + "…";
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Pretty-print a JSON Schema type for a parameter. */
|
|
||||||
function schemaTypeLabel(schema: Record<string, any>): string {
|
|
||||||
if (schema.type) return schema.type;
|
|
||||||
if (schema.anyOf)
|
|
||||||
return schema.anyOf.map((s: any) => s.type ?? "any").join(" | ");
|
|
||||||
if (schema.oneOf)
|
|
||||||
return schema.oneOf.map((s: any) => s.type ?? "any").join(" | ");
|
|
||||||
return "any";
|
|
||||||
}
|
|
||||||
|
|
||||||
function MCPToolCard({
|
|
||||||
tool,
|
|
||||||
selected,
|
|
||||||
onSelect,
|
|
||||||
}: {
|
|
||||||
tool: MCPToolResponse;
|
|
||||||
selected: boolean;
|
|
||||||
onSelect: () => void;
|
|
||||||
}) {
|
|
||||||
const [expanded, setExpanded] = useState(false);
|
|
||||||
const schema = tool.input_schema as Record<string, any>;
|
|
||||||
const properties = schema?.properties ?? {};
|
|
||||||
const required = new Set<string>(schema?.required ?? []);
|
|
||||||
const paramNames = Object.keys(properties);
|
|
||||||
|
|
||||||
// Strip XML-like tags from description for cleaner display.
|
|
||||||
// Loop to handle nested tags like <scr<script>ipt> (CodeQL fix).
|
|
||||||
let cleanDescription = tool.description ?? "";
|
|
||||||
let prev = "";
|
|
||||||
while (prev !== cleanDescription) {
|
|
||||||
prev = cleanDescription;
|
|
||||||
cleanDescription = cleanDescription.replace(/<[^>]*>/g, "");
|
|
||||||
}
|
|
||||||
cleanDescription = cleanDescription.trim();
|
|
||||||
|
|
||||||
return (
|
|
||||||
<button
|
|
||||||
onClick={onSelect}
|
|
||||||
className={`group flex flex-col rounded-lg border text-left transition-colors ${
|
|
||||||
selected
|
|
||||||
? "border-blue-500 bg-blue-50 dark:border-blue-400 dark:bg-blue-950"
|
|
||||||
: "border-gray-200 hover:border-gray-300 hover:bg-gray-50 dark:border-slate-700 dark:hover:border-slate-600 dark:hover:bg-slate-800"
|
|
||||||
}`}
|
|
||||||
>
|
|
||||||
{/* Header */}
|
|
||||||
<div className="flex items-center gap-2 px-3 pb-1 pt-3">
|
|
||||||
<span className="flex-1 text-sm font-semibold dark:text-white">
|
|
||||||
{tool.name}
|
|
||||||
</span>
|
|
||||||
{paramNames.length > 0 && (
|
|
||||||
<Badge variant="secondary" className="text-[10px]">
|
|
||||||
{paramNames.length} param{paramNames.length !== 1 ? "s" : ""}
|
|
||||||
</Badge>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Description (collapsed: truncated) */}
|
|
||||||
{cleanDescription && (
|
|
||||||
<p className="px-3 pb-1 text-xs leading-relaxed text-gray-500 dark:text-gray-400">
|
|
||||||
{expanded ? cleanDescription : truncateDescription(cleanDescription)}
|
|
||||||
</p>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* Parameter badges (collapsed view) */}
|
|
||||||
{!expanded && paramNames.length > 0 && (
|
|
||||||
<div className="flex flex-wrap gap-1 px-3 pb-2">
|
|
||||||
{paramNames.slice(0, 6).map((name) => (
|
|
||||||
<Badge
|
|
||||||
key={name}
|
|
||||||
variant="outline"
|
|
||||||
className="text-[10px] font-normal"
|
|
||||||
>
|
|
||||||
{name}
|
|
||||||
{required.has(name) && (
|
|
||||||
<span className="ml-0.5 text-red-400">*</span>
|
|
||||||
)}
|
|
||||||
</Badge>
|
|
||||||
))}
|
|
||||||
{paramNames.length > 6 && (
|
|
||||||
<Badge variant="outline" className="text-[10px] font-normal">
|
|
||||||
+{paramNames.length - 6} more
|
|
||||||
</Badge>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* Expanded: full parameter details */}
|
|
||||||
{expanded && paramNames.length > 0 && (
|
|
||||||
<div className="mx-3 mb-2 rounded border border-gray-100 bg-gray-50/50 dark:border-slate-700 dark:bg-slate-800/50">
|
|
||||||
<table className="w-full text-xs">
|
|
||||||
<thead>
|
|
||||||
<tr className="border-b border-gray-100 dark:border-slate-700">
|
|
||||||
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
|
|
||||||
Parameter
|
|
||||||
</th>
|
|
||||||
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
|
|
||||||
Type
|
|
||||||
</th>
|
|
||||||
<th className="px-2 py-1 text-left font-medium text-gray-500 dark:text-gray-400">
|
|
||||||
Description
|
|
||||||
</th>
|
|
||||||
</tr>
|
|
||||||
</thead>
|
|
||||||
<tbody>
|
|
||||||
{paramNames.map((name) => {
|
|
||||||
const prop = properties[name] ?? {};
|
|
||||||
return (
|
|
||||||
<tr
|
|
||||||
key={name}
|
|
||||||
className="border-b border-gray-50 last:border-0 dark:border-slate-700/50"
|
|
||||||
>
|
|
||||||
<td className="px-2 py-1 font-mono text-[11px] text-gray-700 dark:text-gray-300">
|
|
||||||
{name}
|
|
||||||
{required.has(name) && (
|
|
||||||
<span className="ml-0.5 text-red-400">*</span>
|
|
||||||
)}
|
|
||||||
</td>
|
|
||||||
<td className="px-2 py-1 text-gray-500 dark:text-gray-400">
|
|
||||||
{schemaTypeLabel(prop)}
|
|
||||||
</td>
|
|
||||||
<td className="max-w-[200px] truncate px-2 py-1 text-gray-500 dark:text-gray-400">
|
|
||||||
{prop.description ?? "—"}
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
);
|
|
||||||
})}
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* Toggle details */}
|
|
||||||
{(paramNames.length > 0 || cleanDescription.length > 120) && (
|
|
||||||
<button
|
|
||||||
type="button"
|
|
||||||
onClick={(e) => {
|
|
||||||
e.stopPropagation();
|
|
||||||
setExpanded((prev) => !prev);
|
|
||||||
}}
|
|
||||||
className="flex w-full items-center justify-center gap-1 border-t border-gray-100 py-1.5 text-[10px] text-gray-400 hover:text-gray-600 dark:border-slate-700 dark:text-gray-500 dark:hover:text-gray-300"
|
|
||||||
>
|
|
||||||
{expanded ? "Hide details" : "Show details"}
|
|
||||||
<CaretDown
|
|
||||||
className={`h-3 w-3 transition-transform ${expanded ? "rotate-180" : ""}`}
|
|
||||||
/>
|
|
||||||
</button>
|
|
||||||
)}
|
|
||||||
</button>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import { Button } from "@/components/__legacy__/ui/button";
|
import { Button } from "@/components/__legacy__/ui/button";
|
||||||
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
import { Skeleton } from "@/components/__legacy__/ui/skeleton";
|
||||||
import { beautifyString, cn } from "@/lib/utils";
|
import { beautifyString, cn } from "@/lib/utils";
|
||||||
import React, { ButtonHTMLAttributes, useCallback, useState } from "react";
|
import React, { ButtonHTMLAttributes } from "react";
|
||||||
import { highlightText } from "./helpers";
|
import { highlightText } from "./helpers";
|
||||||
import { PlusIcon } from "@phosphor-icons/react";
|
import { PlusIcon } from "@phosphor-icons/react";
|
||||||
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
import { BlockInfo } from "@/app/api/__generated__/models/blockInfo";
|
||||||
@@ -9,12 +9,6 @@ import { useControlPanelStore } from "../../../stores/controlPanelStore";
|
|||||||
import { blockDragPreviewStyle } from "./style";
|
import { blockDragPreviewStyle } from "./style";
|
||||||
import { useReactFlow } from "@xyflow/react";
|
import { useReactFlow } from "@xyflow/react";
|
||||||
import { useNodeStore } from "../../../stores/nodeStore";
|
import { useNodeStore } from "../../../stores/nodeStore";
|
||||||
import { BlockUIType, SpecialBlockID } from "@/lib/autogpt-server-api";
|
|
||||||
import {
|
|
||||||
MCPToolDialog,
|
|
||||||
type MCPToolDialogResult,
|
|
||||||
} from "@/app/(platform)/build/components/MCPToolDialog";
|
|
||||||
|
|
||||||
interface Props extends ButtonHTMLAttributes<HTMLButtonElement> {
|
interface Props extends ButtonHTMLAttributes<HTMLButtonElement> {
|
||||||
title?: string;
|
title?: string;
|
||||||
description?: string;
|
description?: string;
|
||||||
@@ -39,13 +33,9 @@ export const Block: BlockComponent = ({
|
|||||||
);
|
);
|
||||||
const { setViewport } = useReactFlow();
|
const { setViewport } = useReactFlow();
|
||||||
const { addBlock } = useNodeStore();
|
const { addBlock } = useNodeStore();
|
||||||
const [mcpDialogOpen, setMcpDialogOpen] = useState(false);
|
|
||||||
|
|
||||||
const isMCPBlock = blockData.uiType === BlockUIType.MCP_TOOL;
|
const handleClick = () => {
|
||||||
|
const customNode = addBlock(blockData);
|
||||||
const addBlockAndCenter = useCallback(
|
|
||||||
(block: BlockInfo, hardcodedValues?: Record<string, any>) => {
|
|
||||||
const customNode = addBlock(block, hardcodedValues);
|
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
setViewport(
|
setViewport(
|
||||||
{
|
{
|
||||||
@@ -56,69 +46,9 @@ export const Block: BlockComponent = ({
|
|||||||
{ duration: 500 },
|
{ duration: 500 },
|
||||||
);
|
);
|
||||||
}, 50);
|
}, 50);
|
||||||
return customNode;
|
|
||||||
},
|
|
||||||
[addBlock, setViewport],
|
|
||||||
);
|
|
||||||
|
|
||||||
const updateNodeData = useNodeStore((state) => state.updateNodeData);
|
|
||||||
|
|
||||||
const handleMCPToolConfirm = useCallback(
|
|
||||||
(result: MCPToolDialogResult) => {
|
|
||||||
// Derive a display label: prefer server name, fall back to URL hostname.
|
|
||||||
let serverLabel = result.serverName;
|
|
||||||
if (!serverLabel) {
|
|
||||||
try {
|
|
||||||
serverLabel = new URL(result.serverUrl).hostname;
|
|
||||||
} catch {
|
|
||||||
serverLabel = "MCP";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const customNode = addBlockAndCenter(blockData, {
|
|
||||||
server_url: result.serverUrl,
|
|
||||||
server_name: serverLabel,
|
|
||||||
selected_tool: result.selectedTool,
|
|
||||||
tool_input_schema: result.toolInputSchema,
|
|
||||||
available_tools: result.availableTools,
|
|
||||||
credentials: result.credentials ?? undefined,
|
|
||||||
});
|
|
||||||
if (customNode) {
|
|
||||||
const title = result.selectedTool
|
|
||||||
? `${serverLabel}: ${beautifyString(result.selectedTool)}`
|
|
||||||
: undefined;
|
|
||||||
updateNodeData(customNode.id, {
|
|
||||||
metadata: {
|
|
||||||
...customNode.data.metadata,
|
|
||||||
credentials_optional: true,
|
|
||||||
...(title && { customized_name: title }),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
setMcpDialogOpen(false);
|
|
||||||
},
|
|
||||||
[addBlockAndCenter, blockData, updateNodeData],
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleClick = () => {
|
|
||||||
if (isMCPBlock) {
|
|
||||||
setMcpDialogOpen(true);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const customNode = addBlockAndCenter(blockData);
|
|
||||||
// Set customized_name for agent blocks so the agent's name persists
|
|
||||||
if (customNode && blockData.id === SpecialBlockID.AGENT) {
|
|
||||||
updateNodeData(customNode.id, {
|
|
||||||
metadata: {
|
|
||||||
...customNode.data.metadata,
|
|
||||||
customized_name: blockData.name,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleDragStart = (e: React.DragEvent<HTMLButtonElement>) => {
|
const handleDragStart = (e: React.DragEvent<HTMLButtonElement>) => {
|
||||||
if (isMCPBlock) return;
|
|
||||||
e.dataTransfer.effectAllowed = "copy";
|
e.dataTransfer.effectAllowed = "copy";
|
||||||
e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData));
|
e.dataTransfer.setData("application/reactflow", JSON.stringify(blockData));
|
||||||
|
|
||||||
@@ -141,14 +71,12 @@ export const Block: BlockComponent = ({
|
|||||||
: undefined;
|
: undefined;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
|
||||||
<Button
|
<Button
|
||||||
draggable={!isMCPBlock}
|
draggable={true}
|
||||||
data-id={blockDataId}
|
data-id={blockDataId}
|
||||||
className={cn(
|
className={cn(
|
||||||
"group flex h-16 w-full min-w-[7.5rem] items-center justify-start space-x-3 whitespace-normal rounded-[0.75rem] bg-zinc-50 px-[0.875rem] py-[0.625rem] text-start shadow-none",
|
"group flex h-16 w-full min-w-[7.5rem] items-center justify-start space-x-3 whitespace-normal rounded-[0.75rem] bg-zinc-50 px-[0.875rem] py-[0.625rem] text-start shadow-none",
|
||||||
"hover:cursor-default hover:bg-zinc-100 focus:ring-0 active:bg-zinc-100 active:ring-1 active:ring-zinc-300 disabled:cursor-not-allowed",
|
"hover:cursor-default hover:bg-zinc-100 focus:ring-0 active:bg-zinc-100 active:ring-1 active:ring-zinc-300 disabled:cursor-not-allowed",
|
||||||
isMCPBlock && "hover:cursor-pointer",
|
|
||||||
className,
|
className,
|
||||||
)}
|
)}
|
||||||
onDragStart={handleDragStart}
|
onDragStart={handleDragStart}
|
||||||
@@ -183,14 +111,6 @@ export const Block: BlockComponent = ({
|
|||||||
<PlusIcon className="h-5 w-5 text-zinc-50" />
|
<PlusIcon className="h-5 w-5 text-zinc-50" />
|
||||||
</div>
|
</div>
|
||||||
</Button>
|
</Button>
|
||||||
{isMCPBlock && (
|
|
||||||
<MCPToolDialog
|
|
||||||
open={mcpDialogOpen}
|
|
||||||
onClose={() => setMcpDialogOpen(false)}
|
|
||||||
onConfirm={handleMCPToolConfirm}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { useCallback } from "react";
|
import { useCallback } from "react";
|
||||||
|
|
||||||
import { AgentRunDraftView } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view";
|
import { AgentRunDraftView } from "@/app/(platform)/build/components/legacy-builder/agent-run-draft-view";
|
||||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||||
import type {
|
import type {
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import {
|
|||||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
import { getGetV2ListMySubmissionsQueryKey } from "@/app/api/__generated__/endpoints/store/store";
|
import { getGetV2ListMySubmissionsQueryKey } from "@/app/api/__generated__/endpoints/store/store";
|
||||||
import { CronExpressionDialog } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/cron-scheduler-dialog";
|
import { CronExpressionDialog } from "@/components/contextual/CronScheduler/cron-scheduler-dialog";
|
||||||
import { humanizeCronExpression } from "@/lib/cron-expression-utils";
|
import { humanizeCronExpression } from "@/lib/cron-expression-utils";
|
||||||
import { CalendarClockIcon } from "lucide-react";
|
import { CalendarClockIcon } from "lucide-react";
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import {
|
|||||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||||
|
|
||||||
import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentInputs/RunAgentInputs";
|
import { RunAgentInputs } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentInputs/RunAgentInputs";
|
||||||
import { ScheduleTaskDialog } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/cron-scheduler-dialog";
|
import { ScheduleTaskDialog } from "@/components/contextual/CronScheduler/cron-scheduler-dialog";
|
||||||
import ActionButtonGroup from "@/components/__legacy__/action-button-group";
|
import ActionButtonGroup from "@/components/__legacy__/action-button-group";
|
||||||
import type { ButtonAction } from "@/components/__legacy__/types";
|
import type { ButtonAction } from "@/components/__legacy__/types";
|
||||||
import {
|
import {
|
||||||
@@ -53,7 +53,10 @@ import { ClockIcon, CopyIcon, InfoIcon } from "@phosphor-icons/react";
|
|||||||
import { CalendarClockIcon, Trash2Icon } from "lucide-react";
|
import { CalendarClockIcon, Trash2Icon } from "lucide-react";
|
||||||
|
|
||||||
import { analytics } from "@/services/analytics";
|
import { analytics } from "@/services/analytics";
|
||||||
import { AgentStatus, AgentStatusChip } from "./agent-status-chip";
|
import {
|
||||||
|
AgentStatus,
|
||||||
|
AgentStatusChip,
|
||||||
|
} from "@/app/(platform)/build/components/legacy-builder/agent-status-chip";
|
||||||
|
|
||||||
export function AgentRunDraftView({
|
export function AgentRunDraftView({
|
||||||
graph,
|
graph,
|
||||||
@@ -9,5 +9,4 @@ export enum BlockUIType {
|
|||||||
AGENT = "Agent",
|
AGENT = "Agent",
|
||||||
AI = "AI",
|
AI = "AI",
|
||||||
AYRSHARE = "Ayrshare",
|
AYRSHARE = "Ayrshare",
|
||||||
MCP_TOOL = "MCP Tool",
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
|
|||||||
import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
|
import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
|
||||||
import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
|
import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
|
||||||
import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
|
import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
|
||||||
import { GenericTool } from "../../tools/GenericTool/GenericTool";
|
|
||||||
import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
|
import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -274,16 +273,6 @@ export const ChatMessagesContainer = ({
|
|||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
default:
|
default:
|
||||||
// Render a generic tool indicator for SDK built-in
|
|
||||||
// tools (Read, Glob, Grep, etc.) or any unrecognized tool
|
|
||||||
if (part.type.startsWith("tool-")) {
|
|
||||||
return (
|
|
||||||
<GenericTool
|
|
||||||
key={`${message.id}-${i}`}
|
|
||||||
part={part as ToolUIPart}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
})}
|
})}
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { ToolUIPart } from "ai";
|
|
||||||
import { GearIcon } from "@phosphor-icons/react";
|
|
||||||
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
|
|
||||||
|
|
||||||
interface Props {
|
|
||||||
part: ToolUIPart;
|
|
||||||
}
|
|
||||||
|
|
||||||
function extractToolName(part: ToolUIPart): string {
|
|
||||||
// ToolUIPart.type is "tool-{name}", extract the name portion.
|
|
||||||
return part.type.replace(/^tool-/, "");
|
|
||||||
}
|
|
||||||
|
|
||||||
function formatToolName(name: string): string {
|
|
||||||
// "search_docs" → "Search docs", "Read" → "Read"
|
|
||||||
return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
|
|
||||||
}
|
|
||||||
|
|
||||||
function getAnimationText(part: ToolUIPart): string {
|
|
||||||
const label = formatToolName(extractToolName(part));
|
|
||||||
|
|
||||||
switch (part.state) {
|
|
||||||
case "input-streaming":
|
|
||||||
case "input-available":
|
|
||||||
return `Running ${label}…`;
|
|
||||||
case "output-available":
|
|
||||||
return `${label} completed`;
|
|
||||||
case "output-error":
|
|
||||||
return `${label} failed`;
|
|
||||||
default:
|
|
||||||
return `Running ${label}…`;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export function GenericTool({ part }: Props) {
|
|
||||||
const isStreaming =
|
|
||||||
part.state === "input-streaming" || part.state === "input-available";
|
|
||||||
const isError = part.state === "output-error";
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="py-2">
|
|
||||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
|
||||||
<GearIcon
|
|
||||||
size={14}
|
|
||||||
weight="regular"
|
|
||||||
className={
|
|
||||||
isError
|
|
||||||
? "text-red-500"
|
|
||||||
: isStreaming
|
|
||||||
? "animate-spin text-neutral-500"
|
|
||||||
: "text-neutral-400"
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
<MorphingTextAnimation
|
|
||||||
text={getAnimationText(part)}
|
|
||||||
className={isError ? "text-red-500" : undefined}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,631 +0,0 @@
|
|||||||
"use client";
|
|
||||||
import { useParams, useRouter } from "next/navigation";
|
|
||||||
import { useQueryState } from "nuqs";
|
|
||||||
import React, {
|
|
||||||
useCallback,
|
|
||||||
useEffect,
|
|
||||||
useMemo,
|
|
||||||
useRef,
|
|
||||||
useState,
|
|
||||||
} from "react";
|
|
||||||
|
|
||||||
import {
|
|
||||||
Graph,
|
|
||||||
GraphExecution,
|
|
||||||
GraphExecutionID,
|
|
||||||
GraphExecutionMeta,
|
|
||||||
GraphID,
|
|
||||||
LibraryAgent,
|
|
||||||
LibraryAgentID,
|
|
||||||
LibraryAgentPreset,
|
|
||||||
LibraryAgentPresetID,
|
|
||||||
Schedule,
|
|
||||||
ScheduleID,
|
|
||||||
} from "@/lib/autogpt-server-api";
|
|
||||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
|
||||||
import { exportAsJSONFile } from "@/lib/utils";
|
|
||||||
|
|
||||||
import DeleteConfirmDialog from "@/components/__legacy__/delete-confirm-dialog";
|
|
||||||
import type { ButtonAction } from "@/components/__legacy__/types";
|
|
||||||
import { Button } from "@/components/__legacy__/ui/button";
|
|
||||||
import {
|
|
||||||
Dialog,
|
|
||||||
DialogContent,
|
|
||||||
DialogDescription,
|
|
||||||
DialogFooter,
|
|
||||||
DialogHeader,
|
|
||||||
DialogTitle,
|
|
||||||
} from "@/components/__legacy__/ui/dialog";
|
|
||||||
import LoadingBox, { LoadingSpinner } from "@/components/__legacy__/ui/loading";
|
|
||||||
import {
|
|
||||||
useToast,
|
|
||||||
useToastOnFail,
|
|
||||||
} from "@/components/molecules/Toast/use-toast";
|
|
||||||
import { AgentRunDetailsView } from "./components/agent-run-details-view";
|
|
||||||
import { AgentRunDraftView } from "./components/agent-run-draft-view";
|
|
||||||
import { CreatePresetDialog } from "./components/create-preset-dialog";
|
|
||||||
import { useAgentRunsInfinite } from "./use-agent-runs";
|
|
||||||
import { AgentRunsSelectorList } from "./components/agent-runs-selector-list";
|
|
||||||
import { AgentScheduleDetailsView } from "./components/agent-schedule-details-view";
|
|
||||||
|
|
||||||
export function OldAgentLibraryView() {
|
|
||||||
const { id: agentID }: { id: LibraryAgentID } = useParams();
|
|
||||||
const [executionId, setExecutionId] = useQueryState("executionId");
|
|
||||||
const toastOnFail = useToastOnFail();
|
|
||||||
const { toast } = useToast();
|
|
||||||
const router = useRouter();
|
|
||||||
const api = useBackendAPI();
|
|
||||||
|
|
||||||
// ============================ STATE =============================
|
|
||||||
|
|
||||||
const [graph, setGraph] = useState<Graph | null>(null); // Graph version corresponding to LibraryAgent
|
|
||||||
const [agent, setAgent] = useState<LibraryAgent | null>(null);
|
|
||||||
const agentRunsQuery = useAgentRunsInfinite(graph?.id); // only runs once graph.id is known
|
|
||||||
const agentRuns = agentRunsQuery.agentRuns;
|
|
||||||
const [agentPresets, setAgentPresets] = useState<LibraryAgentPreset[]>([]);
|
|
||||||
const [schedules, setSchedules] = useState<Schedule[]>([]);
|
|
||||||
const [selectedView, selectView] = useState<
|
|
||||||
| { type: "run"; id?: GraphExecutionID }
|
|
||||||
| { type: "preset"; id: LibraryAgentPresetID }
|
|
||||||
| { type: "schedule"; id: ScheduleID }
|
|
||||||
>({ type: "run" });
|
|
||||||
const [selectedRun, setSelectedRun] = useState<
|
|
||||||
GraphExecution | GraphExecutionMeta | null
|
|
||||||
>(null);
|
|
||||||
const selectedSchedule =
|
|
||||||
selectedView.type == "schedule"
|
|
||||||
? schedules.find((s) => s.id == selectedView.id)
|
|
||||||
: null;
|
|
||||||
const [isFirstLoad, setIsFirstLoad] = useState<boolean>(true);
|
|
||||||
const [agentDeleteDialogOpen, setAgentDeleteDialogOpen] =
|
|
||||||
useState<boolean>(false);
|
|
||||||
const [confirmingDeleteAgentRun, setConfirmingDeleteAgentRun] =
|
|
||||||
useState<GraphExecutionMeta | null>(null);
|
|
||||||
const [confirmingDeleteAgentPreset, setConfirmingDeleteAgentPreset] =
|
|
||||||
useState<LibraryAgentPresetID | null>(null);
|
|
||||||
const [copyAgentDialogOpen, setCopyAgentDialogOpen] = useState(false);
|
|
||||||
const [creatingPresetFromExecutionID, setCreatingPresetFromExecutionID] =
|
|
||||||
useState<GraphExecutionID | null>(null);
|
|
||||||
|
|
||||||
// Set page title with agent name
|
|
||||||
useEffect(() => {
|
|
||||||
if (agent) {
|
|
||||||
document.title = `${agent.name} - Library - AutoGPT Platform`;
|
|
||||||
}
|
|
||||||
}, [agent]);
|
|
||||||
|
|
||||||
const openRunDraftView = useCallback(() => {
|
|
||||||
selectView({ type: "run" });
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const selectRun = useCallback((id: GraphExecutionID) => {
|
|
||||||
selectView({ type: "run", id });
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const selectPreset = useCallback((id: LibraryAgentPresetID) => {
|
|
||||||
selectView({ type: "preset", id });
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const selectSchedule = useCallback((id: ScheduleID) => {
|
|
||||||
selectView({ type: "schedule", id });
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const graphVersions = useRef<Record<number, Graph>>({});
|
|
||||||
const loadingGraphVersions = useRef<Record<number, Promise<Graph>>>({});
|
|
||||||
const getGraphVersion = useCallback(
|
|
||||||
async (graphID: GraphID, version: number) => {
|
|
||||||
if (version in graphVersions.current)
|
|
||||||
return graphVersions.current[version];
|
|
||||||
if (version in loadingGraphVersions.current)
|
|
||||||
return loadingGraphVersions.current[version];
|
|
||||||
|
|
||||||
const pendingGraph = api.getGraph(graphID, version).then((graph) => {
|
|
||||||
graphVersions.current[version] = graph;
|
|
||||||
return graph;
|
|
||||||
});
|
|
||||||
// Cache promise as well to avoid duplicate requests
|
|
||||||
loadingGraphVersions.current[version] = pendingGraph;
|
|
||||||
return pendingGraph;
|
|
||||||
},
|
|
||||||
[api, graphVersions, loadingGraphVersions],
|
|
||||||
);
|
|
||||||
|
|
||||||
const lastRefresh = useRef<number>(0);
|
|
||||||
const refreshPageData = useCallback(() => {
|
|
||||||
if (Date.now() - lastRefresh.current < 2e3) return; // 2 second debounce
|
|
||||||
lastRefresh.current = Date.now();
|
|
||||||
|
|
||||||
api.getLibraryAgent(agentID).then((agent) => {
|
|
||||||
setAgent(agent);
|
|
||||||
|
|
||||||
getGraphVersion(agent.graph_id, agent.graph_version).then(
|
|
||||||
(_graph) =>
|
|
||||||
(graph && graph.version == _graph.version) || setGraph(_graph),
|
|
||||||
);
|
|
||||||
Promise.all([
|
|
||||||
agentRunsQuery.refetchRuns(),
|
|
||||||
api.listLibraryAgentPresets({
|
|
||||||
graph_id: agent.graph_id,
|
|
||||||
page_size: 100,
|
|
||||||
}),
|
|
||||||
]).then(([runsQueryResult, presets]) => {
|
|
||||||
setAgentPresets(presets.presets);
|
|
||||||
|
|
||||||
const newestAgentRunsResponse = runsQueryResult.data?.pages[0];
|
|
||||||
if (!newestAgentRunsResponse || newestAgentRunsResponse.status != 200)
|
|
||||||
return;
|
|
||||||
const newestAgentRuns = newestAgentRunsResponse.data.executions;
|
|
||||||
// Preload the corresponding graph versions for the latest 10 runs
|
|
||||||
new Set(
|
|
||||||
newestAgentRuns.slice(0, 10).map((run) => run.graph_version),
|
|
||||||
).forEach((version) => getGraphVersion(agent.graph_id, version));
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}, [api, agentID, getGraphVersion, graph]);
|
|
||||||
|
|
||||||
// On first load: select the latest run
|
|
||||||
useEffect(() => {
|
|
||||||
// Only for first load or first execution
|
|
||||||
if (selectedView.id || !isFirstLoad) return;
|
|
||||||
if (agentRuns.length == 0 && agentPresets.length == 0) return;
|
|
||||||
|
|
||||||
setIsFirstLoad(false);
|
|
||||||
if (agentRuns.length > 0) {
|
|
||||||
// select latest run
|
|
||||||
const latestRun = agentRuns.reduce((latest, current) => {
|
|
||||||
if (!latest.started_at && !current.started_at) return latest;
|
|
||||||
if (!latest.started_at) return current;
|
|
||||||
if (!current.started_at) return latest;
|
|
||||||
return latest.started_at > current.started_at ? latest : current;
|
|
||||||
}, agentRuns[0]);
|
|
||||||
selectRun(latestRun.id as GraphExecutionID);
|
|
||||||
} else {
|
|
||||||
// select top preset
|
|
||||||
const latestPreset = agentPresets.toSorted(
|
|
||||||
(a, b) => b.updated_at.getTime() - a.updated_at.getTime(),
|
|
||||||
)[0];
|
|
||||||
selectPreset(latestPreset.id);
|
|
||||||
}
|
|
||||||
}, [
|
|
||||||
isFirstLoad,
|
|
||||||
selectedView.id,
|
|
||||||
agentRuns,
|
|
||||||
agentPresets,
|
|
||||||
selectRun,
|
|
||||||
selectPreset,
|
|
||||||
]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (executionId) {
|
|
||||||
selectRun(executionId as GraphExecutionID);
|
|
||||||
setExecutionId(null);
|
|
||||||
}
|
|
||||||
}, [executionId, selectRun, setExecutionId]);
|
|
||||||
|
|
||||||
// Initial load
|
|
||||||
useEffect(() => {
|
|
||||||
refreshPageData();
|
|
||||||
|
|
||||||
// Show a toast when the WebSocket connection disconnects
|
|
||||||
let connectionToast: ReturnType<typeof toast> | null = null;
|
|
||||||
const cancelDisconnectHandler = api.onWebSocketDisconnect(() => {
|
|
||||||
connectionToast ??= toast({
|
|
||||||
title: "Connection to server was lost",
|
|
||||||
variant: "destructive",
|
|
||||||
description: (
|
|
||||||
<div className="flex items-center">
|
|
||||||
Trying to reconnect...
|
|
||||||
<LoadingSpinner className="ml-1.5 size-3.5" />
|
|
||||||
</div>
|
|
||||||
),
|
|
||||||
duration: Infinity,
|
|
||||||
dismissable: true,
|
|
||||||
});
|
|
||||||
});
|
|
||||||
const cancelConnectHandler = api.onWebSocketConnect(() => {
|
|
||||||
if (connectionToast)
|
|
||||||
connectionToast.update({
|
|
||||||
id: connectionToast.id,
|
|
||||||
title: "✅ Connection re-established",
|
|
||||||
variant: "default",
|
|
||||||
description: (
|
|
||||||
<div className="flex items-center">
|
|
||||||
Refreshing data...
|
|
||||||
<LoadingSpinner className="ml-1.5 size-3.5" />
|
|
||||||
</div>
|
|
||||||
),
|
|
||||||
duration: 2000,
|
|
||||||
dismissable: true,
|
|
||||||
});
|
|
||||||
connectionToast = null;
|
|
||||||
});
|
|
||||||
return () => {
|
|
||||||
cancelDisconnectHandler();
|
|
||||||
cancelConnectHandler();
|
|
||||||
};
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
// Subscribe to WebSocket updates for agent runs
|
|
||||||
useEffect(() => {
|
|
||||||
if (!agent?.graph_id) return;
|
|
||||||
|
|
||||||
return api.onWebSocketConnect(() => {
|
|
||||||
refreshPageData(); // Sync up on (re)connect
|
|
||||||
|
|
||||||
// Subscribe to all executions for this agent
|
|
||||||
api.subscribeToGraphExecutions(agent.graph_id);
|
|
||||||
});
|
|
||||||
}, [api, agent?.graph_id, refreshPageData]);
|
|
||||||
|
|
||||||
// Handle execution updates
|
|
||||||
useEffect(() => {
|
|
||||||
const detachExecUpdateHandler = api.onWebSocketMessage(
|
|
||||||
"graph_execution_event",
|
|
||||||
(data) => {
|
|
||||||
if (data.graph_id != agent?.graph_id) return;
|
|
||||||
|
|
||||||
agentRunsQuery.upsertAgentRun(data);
|
|
||||||
if (data.id === selectedView.id) {
|
|
||||||
// Update currently viewed run
|
|
||||||
setSelectedRun(data);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
return () => {
|
|
||||||
detachExecUpdateHandler();
|
|
||||||
};
|
|
||||||
}, [api, agent?.graph_id, selectedView.id]);
|
|
||||||
|
|
||||||
// Pre-load selectedRun based on selectedView
|
|
||||||
useEffect(() => {
|
|
||||||
if (selectedView.type != "run" || !selectedView.id) return;
|
|
||||||
|
|
||||||
const newSelectedRun = agentRuns.find((run) => run.id == selectedView.id);
|
|
||||||
if (selectedView.id !== selectedRun?.id) {
|
|
||||||
// Pull partial data from "cache" while waiting for the rest to load
|
|
||||||
setSelectedRun((newSelectedRun as GraphExecutionMeta) ?? null);
|
|
||||||
}
|
|
||||||
}, [api, selectedView, agentRuns, selectedRun?.id]);
|
|
||||||
|
|
||||||
// Load selectedRun based on selectedView; refresh on agent refresh
|
|
||||||
useEffect(() => {
|
|
||||||
if (selectedView.type != "run" || !selectedView.id || !agent) return;
|
|
||||||
|
|
||||||
api
|
|
||||||
.getGraphExecutionInfo(agent.graph_id, selectedView.id)
|
|
||||||
.then(async (run) => {
|
|
||||||
// Ensure corresponding graph version is available before rendering I/O
|
|
||||||
await getGraphVersion(run.graph_id, run.graph_version);
|
|
||||||
setSelectedRun(run);
|
|
||||||
});
|
|
||||||
}, [api, selectedView, agent, getGraphVersion]);
|
|
||||||
|
|
||||||
const fetchSchedules = useCallback(async () => {
|
|
||||||
if (!agent) return;
|
|
||||||
|
|
||||||
setSchedules(await api.listGraphExecutionSchedules(agent.graph_id));
|
|
||||||
}, [api, agent?.graph_id]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
fetchSchedules();
|
|
||||||
}, [fetchSchedules]);
|
|
||||||
|
|
||||||
// =========================== ACTIONS ============================
|
|
||||||
|
|
||||||
const deleteRun = useCallback(
|
|
||||||
async (run: GraphExecutionMeta) => {
|
|
||||||
if (run.status == "RUNNING" || run.status == "QUEUED") {
|
|
||||||
await api.stopGraphExecution(run.graph_id, run.id);
|
|
||||||
}
|
|
||||||
await api.deleteGraphExecution(run.id);
|
|
||||||
|
|
||||||
setConfirmingDeleteAgentRun(null);
|
|
||||||
if (selectedView.type == "run" && selectedView.id == run.id) {
|
|
||||||
openRunDraftView();
|
|
||||||
}
|
|
||||||
agentRunsQuery.removeAgentRun(run.id);
|
|
||||||
},
|
|
||||||
[api, selectedView, openRunDraftView],
|
|
||||||
);
|
|
||||||
|
|
||||||
const deletePreset = useCallback(
|
|
||||||
async (presetID: LibraryAgentPresetID) => {
|
|
||||||
await api.deleteLibraryAgentPreset(presetID);
|
|
||||||
|
|
||||||
setConfirmingDeleteAgentPreset(null);
|
|
||||||
if (selectedView.type == "preset" && selectedView.id == presetID) {
|
|
||||||
openRunDraftView();
|
|
||||||
}
|
|
||||||
setAgentPresets((presets) => presets.filter((p) => p.id !== presetID));
|
|
||||||
},
|
|
||||||
[api, selectedView, openRunDraftView],
|
|
||||||
);
|
|
||||||
|
|
||||||
const deleteSchedule = useCallback(
|
|
||||||
async (scheduleID: ScheduleID) => {
|
|
||||||
const removedSchedule =
|
|
||||||
await api.deleteGraphExecutionSchedule(scheduleID);
|
|
||||||
|
|
||||||
setSchedules((schedules) => {
|
|
||||||
const newSchedules = schedules.filter(
|
|
||||||
(s) => s.id !== removedSchedule.id,
|
|
||||||
);
|
|
||||||
if (
|
|
||||||
selectedView.type == "schedule" &&
|
|
||||||
selectedView.id == removedSchedule.id
|
|
||||||
) {
|
|
||||||
if (newSchedules.length > 0) {
|
|
||||||
// Select next schedule if available
|
|
||||||
selectSchedule(newSchedules[0].id);
|
|
||||||
} else {
|
|
||||||
// Reset to draft view if current schedule was deleted
|
|
||||||
openRunDraftView();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return newSchedules;
|
|
||||||
});
|
|
||||||
openRunDraftView();
|
|
||||||
},
|
|
||||||
[schedules, api],
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleCreatePresetFromRun = useCallback(
|
|
||||||
async (name: string, description: string) => {
|
|
||||||
if (!creatingPresetFromExecutionID) return;
|
|
||||||
|
|
||||||
await api
|
|
||||||
.createLibraryAgentPreset({
|
|
||||||
name,
|
|
||||||
description,
|
|
||||||
graph_execution_id: creatingPresetFromExecutionID,
|
|
||||||
})
|
|
||||||
.then((preset) => {
|
|
||||||
setAgentPresets((prev) => [...prev, preset]);
|
|
||||||
selectPreset(preset.id);
|
|
||||||
setCreatingPresetFromExecutionID(null);
|
|
||||||
})
|
|
||||||
.catch(toastOnFail("create a preset"));
|
|
||||||
},
|
|
||||||
[api, creatingPresetFromExecutionID, selectPreset, toast],
|
|
||||||
);
|
|
||||||
|
|
||||||
const downloadGraph = useCallback(
|
|
||||||
async () =>
|
|
||||||
agent &&
|
|
||||||
// Export sanitized graph from backend
|
|
||||||
api
|
|
||||||
.getGraph(agent.graph_id, agent.graph_version, true)
|
|
||||||
.then((graph) =>
|
|
||||||
exportAsJSONFile(graph, `${graph.name}_v${graph.version}.json`),
|
|
||||||
),
|
|
||||||
[api, agent],
|
|
||||||
);
|
|
||||||
|
|
||||||
const copyAgent = useCallback(async () => {
|
|
||||||
setCopyAgentDialogOpen(false);
|
|
||||||
api
|
|
||||||
.forkLibraryAgent(agentID)
|
|
||||||
.then((newAgent) => {
|
|
||||||
router.push(`/library/agents/${newAgent.id}`);
|
|
||||||
})
|
|
||||||
.catch((error) => {
|
|
||||||
console.error("Error copying agent:", error);
|
|
||||||
toast({
|
|
||||||
title: "Error copying agent",
|
|
||||||
description: `An error occurred while copying the agent: ${error.message}`,
|
|
||||||
variant: "destructive",
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}, [agentID, api, router, toast]);
|
|
||||||
|
|
||||||
const agentActions: ButtonAction[] = useMemo(
|
|
||||||
() => [
|
|
||||||
{
|
|
||||||
label: "Customize agent",
|
|
||||||
href: `/build?flowID=${agent?.graph_id}&flowVersion=${agent?.graph_version}`,
|
|
||||||
disabled: !agent?.can_access_graph,
|
|
||||||
},
|
|
||||||
{ label: "Export agent to file", callback: downloadGraph },
|
|
||||||
...(!agent?.can_access_graph
|
|
||||||
? [
|
|
||||||
{
|
|
||||||
label: "Edit a copy",
|
|
||||||
callback: () => setCopyAgentDialogOpen(true),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
: []),
|
|
||||||
{
|
|
||||||
label: "Delete agent",
|
|
||||||
callback: () => setAgentDeleteDialogOpen(true),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
[agent, downloadGraph],
|
|
||||||
);
|
|
||||||
|
|
||||||
const runGraph =
|
|
||||||
graphVersions.current[selectedRun?.graph_version ?? 0] ?? graph;
|
|
||||||
|
|
||||||
const onCreateSchedule = useCallback(
|
|
||||||
(schedule: Schedule) => {
|
|
||||||
setSchedules((prev) => [...prev, schedule]);
|
|
||||||
selectSchedule(schedule.id);
|
|
||||||
},
|
|
||||||
[selectView],
|
|
||||||
);
|
|
||||||
|
|
||||||
const onCreatePreset = useCallback(
|
|
||||||
(preset: LibraryAgentPreset) => {
|
|
||||||
setAgentPresets((prev) => [...prev, preset]);
|
|
||||||
selectPreset(preset.id);
|
|
||||||
},
|
|
||||||
[selectPreset],
|
|
||||||
);
|
|
||||||
|
|
||||||
const onUpdatePreset = useCallback(
|
|
||||||
(updated: LibraryAgentPreset) => {
|
|
||||||
setAgentPresets((prev) =>
|
|
||||||
prev.map((p) => (p.id === updated.id ? updated : p)),
|
|
||||||
);
|
|
||||||
selectPreset(updated.id);
|
|
||||||
},
|
|
||||||
[selectPreset],
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!agent || !graph) {
|
|
||||||
return <LoadingBox className="h-[90vh]" />;
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="container justify-stretch p-0 pt-16 lg:flex">
|
|
||||||
{/* Sidebar w/ list of runs */}
|
|
||||||
{/* TODO: render this below header in sm and md layouts */}
|
|
||||||
<AgentRunsSelectorList
|
|
||||||
className="agpt-div w-full border-b pb-2 lg:w-auto lg:border-b-0 lg:border-r lg:pb-0"
|
|
||||||
agent={agent}
|
|
||||||
agentRunsQuery={agentRunsQuery}
|
|
||||||
agentPresets={agentPresets}
|
|
||||||
schedules={schedules}
|
|
||||||
selectedView={selectedView}
|
|
||||||
onSelectRun={selectRun}
|
|
||||||
onSelectPreset={selectPreset}
|
|
||||||
onSelectSchedule={selectSchedule}
|
|
||||||
onSelectDraftNewRun={openRunDraftView}
|
|
||||||
doDeleteRun={setConfirmingDeleteAgentRun}
|
|
||||||
doDeletePreset={setConfirmingDeleteAgentPreset}
|
|
||||||
doDeleteSchedule={deleteSchedule}
|
|
||||||
doCreatePresetFromRun={setCreatingPresetFromExecutionID}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<div className="flex-1">
|
|
||||||
{/* Header */}
|
|
||||||
<div className="agpt-div w-full border-b">
|
|
||||||
<h1
|
|
||||||
data-testid="agent-title"
|
|
||||||
className="font-poppins text-3xl font-medium"
|
|
||||||
>
|
|
||||||
{
|
|
||||||
agent.name /* TODO: use dynamic/custom run title - https://github.com/Significant-Gravitas/AutoGPT/issues/9184 */
|
|
||||||
}
|
|
||||||
</h1>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Run / Schedule views */}
|
|
||||||
{(selectedView.type == "run" && selectedView.id ? (
|
|
||||||
selectedRun && runGraph ? (
|
|
||||||
<AgentRunDetailsView
|
|
||||||
agent={agent}
|
|
||||||
graph={runGraph}
|
|
||||||
run={selectedRun}
|
|
||||||
agentActions={agentActions}
|
|
||||||
onRun={selectRun}
|
|
||||||
doDeleteRun={() => setConfirmingDeleteAgentRun(selectedRun)}
|
|
||||||
doCreatePresetFromRun={() =>
|
|
||||||
setCreatingPresetFromExecutionID(selectedRun.id)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
) : null
|
|
||||||
) : selectedView.type == "run" ? (
|
|
||||||
/* Draft new runs / Create new presets */
|
|
||||||
<AgentRunDraftView
|
|
||||||
graph={graph}
|
|
||||||
onRun={selectRun}
|
|
||||||
onCreateSchedule={onCreateSchedule}
|
|
||||||
onCreatePreset={onCreatePreset}
|
|
||||||
agentActions={agentActions}
|
|
||||||
recommendedScheduleCron={agent?.recommended_schedule_cron || null}
|
|
||||||
/>
|
|
||||||
) : selectedView.type == "preset" ? (
|
|
||||||
/* Edit & update presets */
|
|
||||||
<AgentRunDraftView
|
|
||||||
graph={graph}
|
|
||||||
agentPreset={
|
|
||||||
agentPresets.find((preset) => preset.id == selectedView.id)!
|
|
||||||
}
|
|
||||||
onRun={selectRun}
|
|
||||||
recommendedScheduleCron={agent?.recommended_schedule_cron || null}
|
|
||||||
onCreateSchedule={onCreateSchedule}
|
|
||||||
onUpdatePreset={onUpdatePreset}
|
|
||||||
doDeletePreset={setConfirmingDeleteAgentPreset}
|
|
||||||
agentActions={agentActions}
|
|
||||||
/>
|
|
||||||
) : selectedView.type == "schedule" ? (
|
|
||||||
selectedSchedule &&
|
|
||||||
graph && (
|
|
||||||
<AgentScheduleDetailsView
|
|
||||||
graph={graph}
|
|
||||||
schedule={selectedSchedule}
|
|
||||||
// agent={agent}
|
|
||||||
agentActions={agentActions}
|
|
||||||
onForcedRun={selectRun}
|
|
||||||
doDeleteSchedule={deleteSchedule}
|
|
||||||
/>
|
|
||||||
)
|
|
||||||
) : null) || <LoadingBox className="h-[70vh]" />}
|
|
||||||
|
|
||||||
<DeleteConfirmDialog
|
|
||||||
entityType="agent"
|
|
||||||
open={agentDeleteDialogOpen}
|
|
||||||
onOpenChange={setAgentDeleteDialogOpen}
|
|
||||||
onDoDelete={() =>
|
|
||||||
agent &&
|
|
||||||
api.deleteLibraryAgent(agent.id).then(() => router.push("/library"))
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
|
|
||||||
<DeleteConfirmDialog
|
|
||||||
entityType="agent run"
|
|
||||||
open={!!confirmingDeleteAgentRun}
|
|
||||||
onOpenChange={(open) => !open && setConfirmingDeleteAgentRun(null)}
|
|
||||||
onDoDelete={() =>
|
|
||||||
confirmingDeleteAgentRun && deleteRun(confirmingDeleteAgentRun)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
<DeleteConfirmDialog
|
|
||||||
entityType={agent.has_external_trigger ? "trigger" : "agent preset"}
|
|
||||||
open={!!confirmingDeleteAgentPreset}
|
|
||||||
onOpenChange={(open) => !open && setConfirmingDeleteAgentPreset(null)}
|
|
||||||
onDoDelete={() =>
|
|
||||||
confirmingDeleteAgentPreset &&
|
|
||||||
deletePreset(confirmingDeleteAgentPreset)
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
{/* Copy agent confirmation dialog */}
|
|
||||||
<Dialog
|
|
||||||
onOpenChange={setCopyAgentDialogOpen}
|
|
||||||
open={copyAgentDialogOpen}
|
|
||||||
>
|
|
||||||
<DialogContent>
|
|
||||||
<DialogHeader>
|
|
||||||
<DialogTitle>You're making an editable copy</DialogTitle>
|
|
||||||
<DialogDescription className="pt-2">
|
|
||||||
The original Marketplace agent stays the same and cannot be
|
|
||||||
edited. We'll save a new version of this agent to your
|
|
||||||
Library. From there, you can customize it however you'd
|
|
||||||
like by clicking "Customize agent" — this will open
|
|
||||||
the builder where you can see and modify the inner workings.
|
|
||||||
</DialogDescription>
|
|
||||||
</DialogHeader>
|
|
||||||
<DialogFooter className="justify-end">
|
|
||||||
<Button
|
|
||||||
type="button"
|
|
||||||
variant="outline"
|
|
||||||
onClick={() => setCopyAgentDialogOpen(false)}
|
|
||||||
>
|
|
||||||
Cancel
|
|
||||||
</Button>
|
|
||||||
<Button type="button" onClick={copyAgent}>
|
|
||||||
Continue
|
|
||||||
</Button>
|
|
||||||
</DialogFooter>
|
|
||||||
</DialogContent>
|
|
||||||
</Dialog>
|
|
||||||
<CreatePresetDialog
|
|
||||||
open={!!creatingPresetFromExecutionID}
|
|
||||||
onOpenChange={() => setCreatingPresetFromExecutionID(null)}
|
|
||||||
onConfirm={handleCreatePresetFromRun}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,445 +0,0 @@
|
|||||||
"use client";
|
|
||||||
import { format, formatDistanceToNow, formatDistanceStrict } from "date-fns";
|
|
||||||
import React, { useCallback, useMemo, useEffect } from "react";
|
|
||||||
|
|
||||||
import {
|
|
||||||
Graph,
|
|
||||||
GraphExecution,
|
|
||||||
GraphExecutionID,
|
|
||||||
GraphExecutionMeta,
|
|
||||||
LibraryAgent,
|
|
||||||
} from "@/lib/autogpt-server-api";
|
|
||||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
|
||||||
|
|
||||||
import ActionButtonGroup from "@/components/__legacy__/action-button-group";
|
|
||||||
import type { ButtonAction } from "@/components/__legacy__/types";
|
|
||||||
import {
|
|
||||||
Card,
|
|
||||||
CardContent,
|
|
||||||
CardHeader,
|
|
||||||
CardTitle,
|
|
||||||
} from "@/components/__legacy__/ui/card";
|
|
||||||
import {
|
|
||||||
IconRefresh,
|
|
||||||
IconSquare,
|
|
||||||
IconCircleAlert,
|
|
||||||
} from "@/components/__legacy__/ui/icons";
|
|
||||||
import { Input } from "@/components/__legacy__/ui/input";
|
|
||||||
import LoadingBox from "@/components/__legacy__/ui/loading";
|
|
||||||
import {
|
|
||||||
Tooltip,
|
|
||||||
TooltipContent,
|
|
||||||
TooltipProvider,
|
|
||||||
TooltipTrigger,
|
|
||||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
|
||||||
import { useToastOnFail } from "@/components/molecules/Toast/use-toast";
|
|
||||||
|
|
||||||
import { AgentRunStatus, agentRunStatusMap } from "./agent-run-status-chip";
|
|
||||||
import useCredits from "@/hooks/useCredits";
|
|
||||||
import { AgentRunOutputView } from "./agent-run-output-view";
|
|
||||||
import { analytics } from "@/services/analytics";
|
|
||||||
import { PendingReviewsList } from "@/components/organisms/PendingReviewsList/PendingReviewsList";
|
|
||||||
import { usePendingReviewsForExecution } from "@/hooks/usePendingReviews";
|
|
||||||
|
|
||||||
export function AgentRunDetailsView({
|
|
||||||
agent,
|
|
||||||
graph,
|
|
||||||
run,
|
|
||||||
agentActions,
|
|
||||||
onRun,
|
|
||||||
doDeleteRun,
|
|
||||||
doCreatePresetFromRun,
|
|
||||||
}: {
|
|
||||||
agent: LibraryAgent;
|
|
||||||
graph: Graph;
|
|
||||||
run: GraphExecution | GraphExecutionMeta;
|
|
||||||
agentActions: ButtonAction[];
|
|
||||||
onRun: (runID: GraphExecutionID) => void;
|
|
||||||
doDeleteRun: () => void;
|
|
||||||
doCreatePresetFromRun: () => void;
|
|
||||||
}): React.ReactNode {
|
|
||||||
const api = useBackendAPI();
|
|
||||||
const { formatCredits } = useCredits();
|
|
||||||
|
|
||||||
const runStatus: AgentRunStatus = useMemo(
|
|
||||||
() => agentRunStatusMap[run.status],
|
|
||||||
[run],
|
|
||||||
);
|
|
||||||
|
|
||||||
const {
|
|
||||||
pendingReviews,
|
|
||||||
isLoading: reviewsLoading,
|
|
||||||
refetch: refetchReviews,
|
|
||||||
} = usePendingReviewsForExecution(run.id);
|
|
||||||
|
|
||||||
const toastOnFail = useToastOnFail();
|
|
||||||
|
|
||||||
// Refetch pending reviews when execution status changes to REVIEW
|
|
||||||
useEffect(() => {
|
|
||||||
if (runStatus === "review" && run.id) {
|
|
||||||
refetchReviews();
|
|
||||||
}
|
|
||||||
}, [runStatus, run.id, refetchReviews]);
|
|
||||||
|
|
||||||
const infoStats: { label: string; value: React.ReactNode }[] = useMemo(() => {
|
|
||||||
if (!run) return [];
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
label: "Status",
|
|
||||||
value: runStatus.charAt(0).toUpperCase() + runStatus.slice(1),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
label: "Started",
|
|
||||||
value: run.started_at
|
|
||||||
? `${formatDistanceToNow(run.started_at, { addSuffix: true })}, ${format(run.started_at, "HH:mm")}`
|
|
||||||
: "—",
|
|
||||||
},
|
|
||||||
...(run.stats
|
|
||||||
? [
|
|
||||||
{
|
|
||||||
label: "Duration",
|
|
||||||
value: formatDistanceStrict(0, run.stats.duration * 1000),
|
|
||||||
},
|
|
||||||
{ label: "Steps", value: run.stats.node_exec_count },
|
|
||||||
{ label: "Cost", value: formatCredits(run.stats.cost) },
|
|
||||||
]
|
|
||||||
: []),
|
|
||||||
];
|
|
||||||
}, [run, runStatus, formatCredits]);
|
|
||||||
|
|
||||||
const agentRunInputs:
|
|
||||||
| Record<
|
|
||||||
string,
|
|
||||||
{
|
|
||||||
title?: string;
|
|
||||||
/* type: BlockIOSubType; */
|
|
||||||
value: string | number | undefined;
|
|
||||||
}
|
|
||||||
>
|
|
||||||
| undefined = useMemo(() => {
|
|
||||||
if (!run.inputs) return undefined;
|
|
||||||
// TODO: show (link to) preset - https://github.com/Significant-Gravitas/AutoGPT/issues/9168
|
|
||||||
|
|
||||||
// Add type info from agent input schema
|
|
||||||
return Object.fromEntries(
|
|
||||||
Object.entries(run.inputs).map(([k, v]) => [
|
|
||||||
k,
|
|
||||||
{
|
|
||||||
title: graph.input_schema.properties[k]?.title,
|
|
||||||
// type: graph.input_schema.properties[k].type, // TODO: implement typed graph inputs
|
|
||||||
value: typeof v == "object" ? JSON.stringify(v, undefined, 2) : v,
|
|
||||||
},
|
|
||||||
]),
|
|
||||||
);
|
|
||||||
}, [graph, run]);
|
|
||||||
|
|
||||||
const runAgain = useCallback(() => {
|
|
||||||
if (
|
|
||||||
!run.inputs ||
|
|
||||||
!(graph.credentials_input_schema?.required ?? []).every(
|
|
||||||
(k) => k in (run.credential_inputs ?? {}),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return;
|
|
||||||
|
|
||||||
if (run.preset_id) {
|
|
||||||
return api
|
|
||||||
.executeLibraryAgentPreset(
|
|
||||||
run.preset_id,
|
|
||||||
run.inputs!,
|
|
||||||
run.credential_inputs!,
|
|
||||||
)
|
|
||||||
.then(({ id }) => {
|
|
||||||
analytics.sendDatafastEvent("run_agent", {
|
|
||||||
name: graph.name,
|
|
||||||
id: graph.id,
|
|
||||||
});
|
|
||||||
onRun(id);
|
|
||||||
})
|
|
||||||
.catch(toastOnFail("execute agent preset"));
|
|
||||||
}
|
|
||||||
|
|
||||||
return api
|
|
||||||
.executeGraph(
|
|
||||||
graph.id,
|
|
||||||
graph.version,
|
|
||||||
run.inputs!,
|
|
||||||
run.credential_inputs!,
|
|
||||||
"library",
|
|
||||||
)
|
|
||||||
.then(({ id }) => {
|
|
||||||
analytics.sendDatafastEvent("run_agent", {
|
|
||||||
name: graph.name,
|
|
||||||
id: graph.id,
|
|
||||||
});
|
|
||||||
onRun(id);
|
|
||||||
})
|
|
||||||
.catch(toastOnFail("execute agent"));
|
|
||||||
}, [api, graph, run, onRun, toastOnFail]);
|
|
||||||
|
|
||||||
const stopRun = useCallback(
|
|
||||||
() => api.stopGraphExecution(graph.id, run.id),
|
|
||||||
[api, graph.id, run.id],
|
|
||||||
);
|
|
||||||
|
|
||||||
const agentRunOutputs:
|
|
||||||
| Record<
|
|
||||||
string,
|
|
||||||
{
|
|
||||||
title?: string;
|
|
||||||
/* type: BlockIOSubType; */
|
|
||||||
values: Array<React.ReactNode>;
|
|
||||||
}
|
|
||||||
>
|
|
||||||
| null
|
|
||||||
| undefined = useMemo(() => {
|
|
||||||
if (!("outputs" in run)) return undefined;
|
|
||||||
if (!["running", "success", "failed", "stopped"].includes(runStatus))
|
|
||||||
return null;
|
|
||||||
|
|
||||||
// Add type info from agent input schema
|
|
||||||
return Object.fromEntries(
|
|
||||||
Object.entries(run.outputs).map(([k, vv]) => [
|
|
||||||
k,
|
|
||||||
{
|
|
||||||
title: graph.output_schema.properties[k].title,
|
|
||||||
/* type: agent.output_schema.properties[k].type */
|
|
||||||
values: vv.map((v) =>
|
|
||||||
typeof v == "object" ? JSON.stringify(v, undefined, 2) : v,
|
|
||||||
),
|
|
||||||
},
|
|
||||||
]),
|
|
||||||
);
|
|
||||||
}, [graph, run, runStatus]);
|
|
||||||
|
|
||||||
const runActions: ButtonAction[] = useMemo(
|
|
||||||
() => [
|
|
||||||
...(["running", "queued"].includes(runStatus)
|
|
||||||
? ([
|
|
||||||
{
|
|
||||||
label: (
|
|
||||||
<>
|
|
||||||
<IconSquare className="mr-2 size-4" />
|
|
||||||
Stop run
|
|
||||||
</>
|
|
||||||
),
|
|
||||||
variant: "secondary",
|
|
||||||
callback: stopRun,
|
|
||||||
},
|
|
||||||
] satisfies ButtonAction[])
|
|
||||||
: []),
|
|
||||||
...(["success", "failed", "stopped"].includes(runStatus) &&
|
|
||||||
!graph.has_external_trigger &&
|
|
||||||
(graph.credentials_input_schema?.required ?? []).every(
|
|
||||||
(k) => k in (run.credential_inputs ?? {}),
|
|
||||||
)
|
|
||||||
? [
|
|
||||||
{
|
|
||||||
label: (
|
|
||||||
<>
|
|
||||||
<IconRefresh className="mr-2 size-4" />
|
|
||||||
Run again
|
|
||||||
</>
|
|
||||||
),
|
|
||||||
callback: runAgain,
|
|
||||||
dataTestId: "run-again-button",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
: []),
|
|
||||||
...(agent.can_access_graph
|
|
||||||
? [
|
|
||||||
{
|
|
||||||
label: "Open run in builder",
|
|
||||||
href: `/build?flowID=${run.graph_id}&flowVersion=${run.graph_version}&flowExecutionID=${run.id}`,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
: []),
|
|
||||||
{ label: "Create preset from run", callback: doCreatePresetFromRun },
|
|
||||||
{ label: "Delete run", variant: "secondary", callback: doDeleteRun },
|
|
||||||
],
|
|
||||||
[
|
|
||||||
runStatus,
|
|
||||||
runAgain,
|
|
||||||
stopRun,
|
|
||||||
doDeleteRun,
|
|
||||||
doCreatePresetFromRun,
|
|
||||||
graph.has_external_trigger,
|
|
||||||
graph.credentials_input_schema?.required,
|
|
||||||
agent.can_access_graph,
|
|
||||||
run.graph_id,
|
|
||||||
run.graph_version,
|
|
||||||
run.id,
|
|
||||||
],
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="agpt-div flex gap-6">
|
|
||||||
<div className="flex flex-1 flex-col gap-4">
|
|
||||||
<Card className="agpt-box">
|
|
||||||
<CardHeader>
|
|
||||||
<CardTitle className="font-poppins text-lg">Info</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
|
|
||||||
<CardContent>
|
|
||||||
<div className="flex justify-stretch gap-4">
|
|
||||||
{infoStats.map(({ label, value }) => (
|
|
||||||
<div key={label} className="flex-1">
|
|
||||||
<p className="text-sm font-medium text-black">{label}</p>
|
|
||||||
<p className="text-sm text-neutral-600">{value}</p>
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
{run.status === "FAILED" && (
|
|
||||||
<div className="mt-4 rounded-md border border-red-200 bg-red-50 p-3 dark:border-red-800 dark:bg-red-900/20">
|
|
||||||
<p className="text-sm text-red-800 dark:text-red-200">
|
|
||||||
<strong>Error:</strong>{" "}
|
|
||||||
{run.stats?.error ||
|
|
||||||
"The execution failed due to an internal error. You can re-run the agent to retry."}
|
|
||||||
</p>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
|
|
||||||
{/* Smart Agent Execution Summary */}
|
|
||||||
{run.stats?.activity_status && (
|
|
||||||
<Card className="agpt-box">
|
|
||||||
<CardHeader>
|
|
||||||
<CardTitle className="flex items-center gap-2 font-poppins text-lg">
|
|
||||||
Task Summary
|
|
||||||
<TooltipProvider>
|
|
||||||
<Tooltip>
|
|
||||||
<TooltipTrigger asChild>
|
|
||||||
<IconCircleAlert className="size-4 cursor-help text-neutral-500 hover:text-neutral-700" />
|
|
||||||
</TooltipTrigger>
|
|
||||||
<TooltipContent>
|
|
||||||
<p className="max-w-xs">
|
|
||||||
This AI-generated summary describes how the agent
|
|
||||||
handled your task. It’s an experimental feature and may
|
|
||||||
occasionally be inaccurate.
|
|
||||||
</p>
|
|
||||||
</TooltipContent>
|
|
||||||
</Tooltip>
|
|
||||||
</TooltipProvider>
|
|
||||||
</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent className="space-y-4">
|
|
||||||
<p className="text-sm leading-relaxed text-neutral-700">
|
|
||||||
{run.stats.activity_status}
|
|
||||||
</p>
|
|
||||||
|
|
||||||
{/* Correctness Score */}
|
|
||||||
{typeof run.stats.correctness_score === "number" && (
|
|
||||||
<div className="flex items-center gap-3 rounded-lg bg-neutral-50 p-3">
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<span className="text-sm font-medium text-neutral-600">
|
|
||||||
Success Estimate:
|
|
||||||
</span>
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<div className="relative h-2 w-16 overflow-hidden rounded-full bg-neutral-200">
|
|
||||||
<div
|
|
||||||
className={`h-full transition-all ${
|
|
||||||
run.stats.correctness_score >= 0.8
|
|
||||||
? "bg-green-500"
|
|
||||||
: run.stats.correctness_score >= 0.6
|
|
||||||
? "bg-yellow-500"
|
|
||||||
: run.stats.correctness_score >= 0.4
|
|
||||||
? "bg-orange-500"
|
|
||||||
: "bg-red-500"
|
|
||||||
}`}
|
|
||||||
style={{
|
|
||||||
width: `${Math.round(run.stats.correctness_score * 100)}%`,
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<span className="text-sm font-medium">
|
|
||||||
{Math.round(run.stats.correctness_score * 100)}%
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<TooltipProvider>
|
|
||||||
<Tooltip>
|
|
||||||
<TooltipTrigger asChild>
|
|
||||||
<IconCircleAlert className="size-4 cursor-help text-neutral-400 hover:text-neutral-600" />
|
|
||||||
</TooltipTrigger>
|
|
||||||
<TooltipContent>
|
|
||||||
<p className="max-w-xs">
|
|
||||||
AI-generated estimate of how well this execution
|
|
||||||
achieved its intended purpose. This score indicates
|
|
||||||
{run.stats.correctness_score >= 0.8
|
|
||||||
? " the agent was highly successful."
|
|
||||||
: run.stats.correctness_score >= 0.6
|
|
||||||
? " the agent was mostly successful with minor issues."
|
|
||||||
: run.stats.correctness_score >= 0.4
|
|
||||||
? " the agent was partially successful with some gaps."
|
|
||||||
: " the agent had limited success with significant issues."}
|
|
||||||
</p>
|
|
||||||
</TooltipContent>
|
|
||||||
</Tooltip>
|
|
||||||
</TooltipProvider>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{agentRunOutputs !== null && (
|
|
||||||
<AgentRunOutputView agentRunOutputs={agentRunOutputs} />
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* Pending Reviews Section */}
|
|
||||||
{runStatus === "review" && (
|
|
||||||
<Card className="agpt-box">
|
|
||||||
<CardHeader>
|
|
||||||
<CardTitle className="font-poppins text-lg">
|
|
||||||
Pending Reviews ({pendingReviews.length})
|
|
||||||
</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent>
|
|
||||||
{reviewsLoading ? (
|
|
||||||
<LoadingBox spinnerSize={12} className="h-24" />
|
|
||||||
) : pendingReviews.length > 0 ? (
|
|
||||||
<PendingReviewsList
|
|
||||||
reviews={pendingReviews}
|
|
||||||
onReviewComplete={refetchReviews}
|
|
||||||
emptyMessage="No pending reviews for this execution"
|
|
||||||
/>
|
|
||||||
) : (
|
|
||||||
<div className="py-4 text-neutral-600">
|
|
||||||
No pending reviews for this execution
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<Card className="agpt-box">
|
|
||||||
<CardHeader>
|
|
||||||
<CardTitle className="font-poppins text-lg">Input</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent className="flex flex-col gap-4">
|
|
||||||
{agentRunInputs !== undefined ? (
|
|
||||||
Object.entries(agentRunInputs).map(([key, { title, value }]) => (
|
|
||||||
<div key={key} className="flex flex-col gap-1.5">
|
|
||||||
<label className="text-sm font-medium">{title || key}</label>
|
|
||||||
<Input value={value} className="rounded-full" disabled />
|
|
||||||
</div>
|
|
||||||
))
|
|
||||||
) : (
|
|
||||||
<LoadingBox spinnerSize={12} className="h-24" />
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Run / Agent Actions */}
|
|
||||||
<aside className="w-48 xl:w-56">
|
|
||||||
<div className="flex flex-col gap-8">
|
|
||||||
<ActionButtonGroup title="Run actions" actions={runActions} />
|
|
||||||
|
|
||||||
<ActionButtonGroup title="Agent actions" actions={agentActions} />
|
|
||||||
</div>
|
|
||||||
</aside>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,178 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
|
||||||
import React, { useMemo } from "react";
|
|
||||||
|
|
||||||
import {
|
|
||||||
Card,
|
|
||||||
CardContent,
|
|
||||||
CardHeader,
|
|
||||||
CardTitle,
|
|
||||||
} from "@/components/__legacy__/ui/card";
|
|
||||||
|
|
||||||
import LoadingBox from "@/components/__legacy__/ui/loading";
|
|
||||||
import type { OutputMetadata } from "../../../../../../../../components/contextual/OutputRenderers";
|
|
||||||
import {
|
|
||||||
globalRegistry,
|
|
||||||
OutputActions,
|
|
||||||
OutputItem,
|
|
||||||
} from "../../../../../../../../components/contextual/OutputRenderers";
|
|
||||||
|
|
||||||
export function AgentRunOutputView({
|
|
||||||
agentRunOutputs,
|
|
||||||
}: {
|
|
||||||
agentRunOutputs:
|
|
||||||
| Record<
|
|
||||||
string,
|
|
||||||
{
|
|
||||||
title?: string;
|
|
||||||
/* type: BlockIOSubType; */
|
|
||||||
values: Array<React.ReactNode>;
|
|
||||||
}
|
|
||||||
>
|
|
||||||
| undefined;
|
|
||||||
}) {
|
|
||||||
const enableEnhancedOutputHandling = useGetFlag(
|
|
||||||
Flag.ENABLE_ENHANCED_OUTPUT_HANDLING,
|
|
||||||
);
|
|
||||||
|
|
||||||
// Prepare items for the renderer system
|
|
||||||
const outputItems = useMemo(() => {
|
|
||||||
if (!agentRunOutputs) return [];
|
|
||||||
|
|
||||||
const items: Array<{
|
|
||||||
key: string;
|
|
||||||
label: string;
|
|
||||||
value: unknown;
|
|
||||||
metadata?: OutputMetadata;
|
|
||||||
renderer: any;
|
|
||||||
}> = [];
|
|
||||||
|
|
||||||
Object.entries(agentRunOutputs).forEach(([key, { title, values }]) => {
|
|
||||||
values.forEach((value, index) => {
|
|
||||||
// Enhanced metadata extraction
|
|
||||||
const metadata: OutputMetadata = {};
|
|
||||||
|
|
||||||
// Type guard to safely access properties
|
|
||||||
if (
|
|
||||||
typeof value === "object" &&
|
|
||||||
value !== null &&
|
|
||||||
!React.isValidElement(value)
|
|
||||||
) {
|
|
||||||
const objValue = value as any;
|
|
||||||
if (objValue.type) metadata.type = objValue.type;
|
|
||||||
if (objValue.mimeType) metadata.mimeType = objValue.mimeType;
|
|
||||||
if (objValue.filename) metadata.filename = objValue.filename;
|
|
||||||
}
|
|
||||||
|
|
||||||
const renderer = globalRegistry.getRenderer(value, metadata);
|
|
||||||
if (renderer) {
|
|
||||||
items.push({
|
|
||||||
key: `${key}-${index}`,
|
|
||||||
label: index === 0 ? title || key : "",
|
|
||||||
value,
|
|
||||||
metadata,
|
|
||||||
renderer,
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
const textRenderer = globalRegistry
|
|
||||||
.getAllRenderers()
|
|
||||||
.find((r) => r.name === "TextRenderer");
|
|
||||||
if (textRenderer) {
|
|
||||||
items.push({
|
|
||||||
key: `${key}-${index}`,
|
|
||||||
label: index === 0 ? title || key : "",
|
|
||||||
value: JSON.stringify(value, null, 2),
|
|
||||||
metadata,
|
|
||||||
renderer: textRenderer,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
return items;
|
|
||||||
}, [agentRunOutputs]);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<>
|
|
||||||
{enableEnhancedOutputHandling ? (
|
|
||||||
<Card className="agpt-box" style={{ maxWidth: "950px" }}>
|
|
||||||
<CardHeader>
|
|
||||||
<div className="flex items-center justify-between">
|
|
||||||
<CardTitle className="font-poppins text-lg">Output</CardTitle>
|
|
||||||
{outputItems.length > 0 && (
|
|
||||||
<OutputActions
|
|
||||||
items={outputItems.map((item) => ({
|
|
||||||
value: item.value,
|
|
||||||
metadata: item.metadata,
|
|
||||||
renderer: item.renderer,
|
|
||||||
}))}
|
|
||||||
/>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</CardHeader>
|
|
||||||
|
|
||||||
<CardContent
|
|
||||||
className="flex flex-col gap-4"
|
|
||||||
style={{ maxWidth: "660px" }}
|
|
||||||
>
|
|
||||||
{agentRunOutputs !== undefined ? (
|
|
||||||
outputItems.length > 0 ? (
|
|
||||||
outputItems.map((item) => (
|
|
||||||
<OutputItem
|
|
||||||
key={item.key}
|
|
||||||
value={item.value}
|
|
||||||
metadata={item.metadata}
|
|
||||||
renderer={item.renderer}
|
|
||||||
label={item.label}
|
|
||||||
/>
|
|
||||||
))
|
|
||||||
) : (
|
|
||||||
<p className="text-sm text-muted-foreground">
|
|
||||||
No outputs to display
|
|
||||||
</p>
|
|
||||||
)
|
|
||||||
) : (
|
|
||||||
<LoadingBox spinnerSize={12} className="h-24" />
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
) : (
|
|
||||||
<Card className="agpt-box" style={{ maxWidth: "950px" }}>
|
|
||||||
<CardHeader>
|
|
||||||
<CardTitle className="font-poppins text-lg">Output</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
|
|
||||||
<CardContent
|
|
||||||
className="flex flex-col gap-4"
|
|
||||||
style={{ maxWidth: "660px" }}
|
|
||||||
>
|
|
||||||
{agentRunOutputs !== undefined ? (
|
|
||||||
Object.entries(agentRunOutputs).map(
|
|
||||||
([key, { title, values }]) => (
|
|
||||||
<div key={key} className="flex flex-col gap-1.5">
|
|
||||||
<label className="text-sm font-medium">
|
|
||||||
{title || key}
|
|
||||||
</label>
|
|
||||||
{values.map((value, i) => (
|
|
||||||
<p
|
|
||||||
className="resize-none overflow-x-auto whitespace-pre-wrap break-words border-none text-sm text-neutral-700 disabled:cursor-not-allowed"
|
|
||||||
key={i}
|
|
||||||
>
|
|
||||||
{value}
|
|
||||||
</p>
|
|
||||||
))}
|
|
||||||
{/* TODO: pretty type-dependent rendering */}
|
|
||||||
</div>
|
|
||||||
),
|
|
||||||
)
|
|
||||||
) : (
|
|
||||||
<LoadingBox spinnerSize={12} className="h-24" />
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
)}
|
|
||||||
</>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
import React from "react";
|
|
||||||
|
|
||||||
import { Badge } from "@/components/__legacy__/ui/badge";
|
|
||||||
|
|
||||||
import { GraphExecutionMeta } from "@/lib/autogpt-server-api/types";
|
|
||||||
|
|
||||||
export type AgentRunStatus =
|
|
||||||
| "success"
|
|
||||||
| "failed"
|
|
||||||
| "queued"
|
|
||||||
| "running"
|
|
||||||
| "stopped"
|
|
||||||
| "scheduled"
|
|
||||||
| "draft"
|
|
||||||
| "review";
|
|
||||||
|
|
||||||
export const agentRunStatusMap: Record<
|
|
||||||
GraphExecutionMeta["status"],
|
|
||||||
AgentRunStatus
|
|
||||||
> = {
|
|
||||||
INCOMPLETE: "draft",
|
|
||||||
COMPLETED: "success",
|
|
||||||
FAILED: "failed",
|
|
||||||
QUEUED: "queued",
|
|
||||||
RUNNING: "running",
|
|
||||||
TERMINATED: "stopped",
|
|
||||||
REVIEW: "review",
|
|
||||||
};
|
|
||||||
|
|
||||||
const statusData: Record<
|
|
||||||
AgentRunStatus,
|
|
||||||
{ label: string; variant: keyof typeof statusStyles }
|
|
||||||
> = {
|
|
||||||
success: { label: "Success", variant: "success" },
|
|
||||||
running: { label: "Running", variant: "info" },
|
|
||||||
failed: { label: "Failed", variant: "destructive" },
|
|
||||||
queued: { label: "Queued", variant: "warning" },
|
|
||||||
draft: { label: "Draft", variant: "secondary" },
|
|
||||||
stopped: { label: "Stopped", variant: "secondary" },
|
|
||||||
scheduled: { label: "Scheduled", variant: "secondary" },
|
|
||||||
review: { label: "In Review", variant: "warning" },
|
|
||||||
};
|
|
||||||
|
|
||||||
const statusStyles = {
|
|
||||||
success:
|
|
||||||
"bg-green-100 text-green-800 hover:bg-green-100 hover:text-green-800",
|
|
||||||
destructive: "bg-red-100 text-red-800 hover:bg-red-100 hover:text-red-800",
|
|
||||||
warning:
|
|
||||||
"bg-yellow-100 text-yellow-800 hover:bg-yellow-100 hover:text-yellow-800",
|
|
||||||
info: "bg-blue-100 text-blue-800 hover:bg-blue-100 hover:text-blue-800",
|
|
||||||
secondary:
|
|
||||||
"bg-slate-100 text-slate-800 hover:bg-slate-100 hover:text-slate-800",
|
|
||||||
};
|
|
||||||
|
|
||||||
export function AgentRunStatusChip({
|
|
||||||
status,
|
|
||||||
}: {
|
|
||||||
status: AgentRunStatus;
|
|
||||||
}): React.ReactElement {
|
|
||||||
return (
|
|
||||||
<Badge
|
|
||||||
variant="secondary"
|
|
||||||
className={`text-xs font-medium ${statusStyles[statusData[status]?.variant]} rounded-[45px] px-[9px] py-[3px]`}
|
|
||||||
>
|
|
||||||
{statusData[status]?.label}
|
|
||||||
</Badge>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,130 +0,0 @@
|
|||||||
import React from "react";
|
|
||||||
import { formatDistanceToNow, isPast } from "date-fns";
|
|
||||||
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
|
|
||||||
import { Link2Icon, Link2OffIcon, MoreVertical } from "lucide-react";
|
|
||||||
import { Card, CardContent } from "@/components/__legacy__/ui/card";
|
|
||||||
import { Button } from "@/components/__legacy__/ui/button";
|
|
||||||
import {
|
|
||||||
DropdownMenu,
|
|
||||||
DropdownMenuContent,
|
|
||||||
DropdownMenuItem,
|
|
||||||
DropdownMenuTrigger,
|
|
||||||
} from "@/components/__legacy__/ui/dropdown-menu";
|
|
||||||
|
|
||||||
import { AgentStatus, AgentStatusChip } from "./agent-status-chip";
|
|
||||||
import { AgentRunStatus, AgentRunStatusChip } from "./agent-run-status-chip";
|
|
||||||
import { PushPinSimpleIcon } from "@phosphor-icons/react";
|
|
||||||
|
|
||||||
export type AgentRunSummaryProps = (
|
|
||||||
| {
|
|
||||||
type: "run";
|
|
||||||
status: AgentRunStatus;
|
|
||||||
}
|
|
||||||
| {
|
|
||||||
type: "preset";
|
|
||||||
status?: undefined;
|
|
||||||
}
|
|
||||||
| {
|
|
||||||
type: "preset.triggered";
|
|
||||||
status: AgentStatus;
|
|
||||||
}
|
|
||||||
| {
|
|
||||||
type: "schedule";
|
|
||||||
status: "scheduled";
|
|
||||||
}
|
|
||||||
) & {
|
|
||||||
title: string;
|
|
||||||
timestamp?: number | Date;
|
|
||||||
selected?: boolean;
|
|
||||||
onClick?: () => void;
|
|
||||||
// onRename: () => void;
|
|
||||||
onDelete: () => void;
|
|
||||||
onPinAsPreset?: () => void;
|
|
||||||
className?: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export function AgentRunSummaryCard({
|
|
||||||
type,
|
|
||||||
status,
|
|
||||||
title,
|
|
||||||
timestamp,
|
|
||||||
selected = false,
|
|
||||||
onClick,
|
|
||||||
// onRename,
|
|
||||||
onDelete,
|
|
||||||
onPinAsPreset,
|
|
||||||
className,
|
|
||||||
}: AgentRunSummaryProps): React.ReactElement {
|
|
||||||
return (
|
|
||||||
<Card
|
|
||||||
className={cn(
|
|
||||||
"agpt-rounded-card cursor-pointer border-zinc-300",
|
|
||||||
selected ? "agpt-card-selected" : "",
|
|
||||||
className,
|
|
||||||
)}
|
|
||||||
onClick={onClick}
|
|
||||||
>
|
|
||||||
<CardContent className="relative p-2.5 lg:p-4">
|
|
||||||
{(type == "run" || type == "schedule") && (
|
|
||||||
<AgentRunStatusChip status={status} />
|
|
||||||
)}
|
|
||||||
{type == "preset" && (
|
|
||||||
<div className="flex items-center text-sm font-medium text-neutral-700">
|
|
||||||
<PushPinSimpleIcon className="mr-1 size-4 text-foreground" /> Preset
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
{type == "preset.triggered" && (
|
|
||||||
<div className="flex items-center justify-between">
|
|
||||||
<AgentStatusChip status={status} />
|
|
||||||
|
|
||||||
<div className="flex items-center text-sm font-medium text-neutral-700">
|
|
||||||
{status == "inactive" ? (
|
|
||||||
<Link2OffIcon className="mr-1 size-4 text-foreground" />
|
|
||||||
) : (
|
|
||||||
<Link2Icon className="mr-1 size-4 text-foreground" />
|
|
||||||
)}{" "}
|
|
||||||
Trigger
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
<div className="mt-5 flex items-center justify-between">
|
|
||||||
<h3 className="truncate pr-2 text-base font-medium text-neutral-900">
|
|
||||||
{title}
|
|
||||||
</h3>
|
|
||||||
|
|
||||||
<DropdownMenu>
|
|
||||||
<DropdownMenuTrigger asChild>
|
|
||||||
<Button variant="ghost" className="h-5 w-5 p-0">
|
|
||||||
<MoreVertical className="h-5 w-5" />
|
|
||||||
</Button>
|
|
||||||
</DropdownMenuTrigger>
|
|
||||||
<DropdownMenuContent>
|
|
||||||
{onPinAsPreset && (
|
|
||||||
<DropdownMenuItem onClick={onPinAsPreset}>
|
|
||||||
Pin as a preset
|
|
||||||
</DropdownMenuItem>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* <DropdownMenuItem onClick={onRename}>Rename</DropdownMenuItem> */}
|
|
||||||
|
|
||||||
<DropdownMenuItem onClick={onDelete}>Delete</DropdownMenuItem>
|
|
||||||
</DropdownMenuContent>
|
|
||||||
</DropdownMenu>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{timestamp && (
|
|
||||||
<p
|
|
||||||
className="mt-1 text-sm font-normal text-neutral-500"
|
|
||||||
title={new Date(timestamp).toString()}
|
|
||||||
>
|
|
||||||
{isPast(timestamp) ? "Ran" : "Runs in"}{" "}
|
|
||||||
{formatDistanceToNow(timestamp, { addSuffix: true })}
|
|
||||||
</p>
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,237 +0,0 @@
|
|||||||
"use client";
|
|
||||||
import { Plus } from "lucide-react";
|
|
||||||
import React, { useEffect, useState } from "react";
|
|
||||||
|
|
||||||
import {
|
|
||||||
GraphExecutionID,
|
|
||||||
GraphExecutionMeta,
|
|
||||||
LibraryAgent,
|
|
||||||
LibraryAgentPreset,
|
|
||||||
LibraryAgentPresetID,
|
|
||||||
Schedule,
|
|
||||||
ScheduleID,
|
|
||||||
} from "@/lib/autogpt-server-api";
|
|
||||||
import { cn } from "@/lib/utils";
|
|
||||||
|
|
||||||
import { Badge } from "@/components/__legacy__/ui/badge";
|
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
|
||||||
import LoadingBox, { LoadingSpinner } from "@/components/__legacy__/ui/loading";
|
|
||||||
import { Separator } from "@/components/__legacy__/ui/separator";
|
|
||||||
import { ScrollArea } from "@/components/__legacy__/ui/scroll-area";
|
|
||||||
import { InfiniteScroll } from "@/components/contextual/InfiniteScroll/InfiniteScroll";
|
|
||||||
import { AgentRunsQuery } from "../use-agent-runs";
|
|
||||||
import { agentRunStatusMap } from "./agent-run-status-chip";
|
|
||||||
import { AgentRunSummaryCard } from "./agent-run-summary-card";
|
|
||||||
|
|
||||||
interface AgentRunsSelectorListProps {
|
|
||||||
agent: LibraryAgent;
|
|
||||||
agentRunsQuery: AgentRunsQuery;
|
|
||||||
agentPresets: LibraryAgentPreset[];
|
|
||||||
schedules: Schedule[];
|
|
||||||
selectedView: { type: "run" | "preset" | "schedule"; id?: string };
|
|
||||||
allowDraftNewRun?: boolean;
|
|
||||||
onSelectRun: (id: GraphExecutionID) => void;
|
|
||||||
onSelectPreset: (preset: LibraryAgentPresetID) => void;
|
|
||||||
onSelectSchedule: (id: ScheduleID) => void;
|
|
||||||
onSelectDraftNewRun: () => void;
|
|
||||||
doDeleteRun: (id: GraphExecutionMeta) => void;
|
|
||||||
doDeletePreset: (id: LibraryAgentPresetID) => void;
|
|
||||||
doDeleteSchedule: (id: ScheduleID) => void;
|
|
||||||
doCreatePresetFromRun?: (id: GraphExecutionID) => void;
|
|
||||||
className?: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function AgentRunsSelectorList({
|
|
||||||
agent,
|
|
||||||
agentRunsQuery: {
|
|
||||||
agentRuns,
|
|
||||||
agentRunCount,
|
|
||||||
agentRunsLoading,
|
|
||||||
hasMoreRuns,
|
|
||||||
fetchMoreRuns,
|
|
||||||
isFetchingMoreRuns,
|
|
||||||
},
|
|
||||||
agentPresets,
|
|
||||||
schedules,
|
|
||||||
selectedView,
|
|
||||||
allowDraftNewRun = true,
|
|
||||||
onSelectRun,
|
|
||||||
onSelectPreset,
|
|
||||||
onSelectSchedule,
|
|
||||||
onSelectDraftNewRun,
|
|
||||||
doDeleteRun,
|
|
||||||
doDeletePreset,
|
|
||||||
doDeleteSchedule,
|
|
||||||
doCreatePresetFromRun,
|
|
||||||
className,
|
|
||||||
}: AgentRunsSelectorListProps): React.ReactElement {
|
|
||||||
const [activeListTab, setActiveListTab] = useState<"runs" | "scheduled">(
|
|
||||||
"runs",
|
|
||||||
);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (selectedView.type === "schedule") {
|
|
||||||
setActiveListTab("scheduled");
|
|
||||||
} else {
|
|
||||||
setActiveListTab("runs");
|
|
||||||
}
|
|
||||||
}, [selectedView]);
|
|
||||||
|
|
||||||
const listItemClasses = "h-28 w-72 lg:w-full lg:h-32";
|
|
||||||
|
|
||||||
return (
|
|
||||||
<aside className={cn("flex flex-col gap-4", className)}>
|
|
||||||
{allowDraftNewRun ? (
|
|
||||||
<Button
|
|
||||||
className={"mb-4 hidden lg:flex"}
|
|
||||||
onClick={onSelectDraftNewRun}
|
|
||||||
leftIcon={<Plus className="h-6 w-6" />}
|
|
||||||
>
|
|
||||||
New {agent.has_external_trigger ? "trigger" : "run"}
|
|
||||||
</Button>
|
|
||||||
) : null}
|
|
||||||
|
|
||||||
<div className="flex gap-2">
|
|
||||||
<Badge
|
|
||||||
variant={activeListTab === "runs" ? "secondary" : "outline"}
|
|
||||||
className="cursor-pointer gap-2 rounded-full text-base"
|
|
||||||
onClick={() => setActiveListTab("runs")}
|
|
||||||
>
|
|
||||||
<span>Runs</span>
|
|
||||||
<span className="text-neutral-600">
|
|
||||||
{agentRunCount ?? <LoadingSpinner className="size-4" />}
|
|
||||||
</span>
|
|
||||||
</Badge>
|
|
||||||
|
|
||||||
<Badge
|
|
||||||
variant={activeListTab === "scheduled" ? "secondary" : "outline"}
|
|
||||||
className="cursor-pointer gap-2 rounded-full text-base"
|
|
||||||
onClick={() => setActiveListTab("scheduled")}
|
|
||||||
>
|
|
||||||
<span>Scheduled</span>
|
|
||||||
<span className="text-neutral-600">{schedules.length}</span>
|
|
||||||
</Badge>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Runs / Schedules list */}
|
|
||||||
{agentRunsLoading && activeListTab === "runs" ? (
|
|
||||||
<LoadingBox className="h-28 w-full lg:h-[calc(100vh-300px)] lg:w-72 xl:w-80" />
|
|
||||||
) : (
|
|
||||||
<ScrollArea
|
|
||||||
className="w-full lg:h-[calc(100vh-300px)] lg:w-72 xl:w-80"
|
|
||||||
orientation={window.innerWidth >= 1024 ? "vertical" : "horizontal"}
|
|
||||||
>
|
|
||||||
<InfiniteScroll
|
|
||||||
direction={window.innerWidth >= 1024 ? "vertical" : "horizontal"}
|
|
||||||
hasNextPage={hasMoreRuns}
|
|
||||||
fetchNextPage={fetchMoreRuns}
|
|
||||||
isFetchingNextPage={isFetchingMoreRuns}
|
|
||||||
>
|
|
||||||
<div className="flex items-center gap-2 lg:flex-col">
|
|
||||||
{/* New Run button - only in small layouts */}
|
|
||||||
{allowDraftNewRun && (
|
|
||||||
<Button
|
|
||||||
size="large"
|
|
||||||
className={
|
|
||||||
"flex h-12 w-40 items-center gap-2 py-6 lg:hidden " +
|
|
||||||
(selectedView.type == "run" && !selectedView.id
|
|
||||||
? "agpt-card-selected text-accent"
|
|
||||||
: "")
|
|
||||||
}
|
|
||||||
onClick={onSelectDraftNewRun}
|
|
||||||
leftIcon={<Plus className="h-6 w-6" />}
|
|
||||||
>
|
|
||||||
New {agent.has_external_trigger ? "trigger" : "run"}
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{activeListTab === "runs" ? (
|
|
||||||
<>
|
|
||||||
{agentPresets
|
|
||||||
.filter((preset) => preset.webhook) // Triggers
|
|
||||||
.toSorted(
|
|
||||||
(a, b) => b.updated_at.getTime() - a.updated_at.getTime(),
|
|
||||||
)
|
|
||||||
.map((preset) => (
|
|
||||||
<AgentRunSummaryCard
|
|
||||||
className={cn(listItemClasses, "lg:h-auto")}
|
|
||||||
key={preset.id}
|
|
||||||
type="preset.triggered"
|
|
||||||
status={preset.is_active ? "active" : "inactive"}
|
|
||||||
title={preset.name}
|
|
||||||
// timestamp={preset.last_run_time} // TODO: implement this
|
|
||||||
selected={selectedView.id === preset.id}
|
|
||||||
onClick={() => onSelectPreset(preset.id)}
|
|
||||||
onDelete={() => doDeletePreset(preset.id)}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
{agentPresets
|
|
||||||
.filter((preset) => !preset.webhook) // Presets
|
|
||||||
.toSorted(
|
|
||||||
(a, b) => b.updated_at.getTime() - a.updated_at.getTime(),
|
|
||||||
)
|
|
||||||
.map((preset) => (
|
|
||||||
<AgentRunSummaryCard
|
|
||||||
className={cn(listItemClasses, "lg:h-auto")}
|
|
||||||
key={preset.id}
|
|
||||||
type="preset"
|
|
||||||
title={preset.name}
|
|
||||||
// timestamp={preset.last_run_time} // TODO: implement this
|
|
||||||
selected={selectedView.id === preset.id}
|
|
||||||
onClick={() => onSelectPreset(preset.id)}
|
|
||||||
onDelete={() => doDeletePreset(preset.id)}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
{agentPresets.length > 0 && <Separator className="my-1" />}
|
|
||||||
{agentRuns
|
|
||||||
.toSorted((a, b) => {
|
|
||||||
const aTime = a.started_at?.getTime() ?? 0;
|
|
||||||
const bTime = b.started_at?.getTime() ?? 0;
|
|
||||||
return bTime - aTime;
|
|
||||||
})
|
|
||||||
.map((run) => (
|
|
||||||
<AgentRunSummaryCard
|
|
||||||
className={listItemClasses}
|
|
||||||
key={run.id}
|
|
||||||
type="run"
|
|
||||||
status={agentRunStatusMap[run.status]}
|
|
||||||
title={
|
|
||||||
(run.preset_id
|
|
||||||
? agentPresets.find((p) => p.id == run.preset_id)
|
|
||||||
?.name
|
|
||||||
: null) ?? agent.name
|
|
||||||
}
|
|
||||||
timestamp={run.started_at ?? undefined}
|
|
||||||
selected={selectedView.id === run.id}
|
|
||||||
onClick={() => onSelectRun(run.id)}
|
|
||||||
onDelete={() => doDeleteRun(run as GraphExecutionMeta)}
|
|
||||||
onPinAsPreset={
|
|
||||||
doCreatePresetFromRun
|
|
||||||
? () => doCreatePresetFromRun(run.id)
|
|
||||||
: undefined
|
|
||||||
}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
</>
|
|
||||||
) : (
|
|
||||||
schedules.map((schedule) => (
|
|
||||||
<AgentRunSummaryCard
|
|
||||||
className={listItemClasses}
|
|
||||||
key={schedule.id}
|
|
||||||
type="schedule"
|
|
||||||
status="scheduled" // TODO: implement active/inactive status for schedules
|
|
||||||
title={schedule.name}
|
|
||||||
timestamp={schedule.next_run_time}
|
|
||||||
selected={selectedView.id === schedule.id}
|
|
||||||
onClick={() => onSelectSchedule(schedule.id)}
|
|
||||||
onDelete={() => doDeleteSchedule(schedule.id)}
|
|
||||||
/>
|
|
||||||
))
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</InfiniteScroll>
|
|
||||||
</ScrollArea>
|
|
||||||
)}
|
|
||||||
</aside>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,180 +0,0 @@
|
|||||||
"use client";
|
|
||||||
import React, { useCallback, useMemo } from "react";
|
|
||||||
|
|
||||||
import {
|
|
||||||
Graph,
|
|
||||||
GraphExecutionID,
|
|
||||||
Schedule,
|
|
||||||
ScheduleID,
|
|
||||||
} from "@/lib/autogpt-server-api";
|
|
||||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
|
||||||
|
|
||||||
import ActionButtonGroup from "@/components/__legacy__/action-button-group";
|
|
||||||
import type { ButtonAction } from "@/components/__legacy__/types";
|
|
||||||
import {
|
|
||||||
Card,
|
|
||||||
CardContent,
|
|
||||||
CardHeader,
|
|
||||||
CardTitle,
|
|
||||||
} from "@/components/__legacy__/ui/card";
|
|
||||||
import { IconCross } from "@/components/__legacy__/ui/icons";
|
|
||||||
import { Input } from "@/components/__legacy__/ui/input";
|
|
||||||
import LoadingBox from "@/components/__legacy__/ui/loading";
|
|
||||||
import { useToastOnFail } from "@/components/molecules/Toast/use-toast";
|
|
||||||
import { humanizeCronExpression } from "@/lib/cron-expression-utils";
|
|
||||||
import { formatScheduleTime } from "@/lib/timezone-utils";
|
|
||||||
import { useUserTimezone } from "@/lib/hooks/useUserTimezone";
|
|
||||||
import { PlayIcon } from "lucide-react";
|
|
||||||
|
|
||||||
import { AgentRunStatus } from "./agent-run-status-chip";
|
|
||||||
|
|
||||||
export function AgentScheduleDetailsView({
|
|
||||||
graph,
|
|
||||||
schedule,
|
|
||||||
agentActions,
|
|
||||||
onForcedRun,
|
|
||||||
doDeleteSchedule,
|
|
||||||
}: {
|
|
||||||
graph: Graph;
|
|
||||||
schedule: Schedule;
|
|
||||||
agentActions: ButtonAction[];
|
|
||||||
onForcedRun: (runID: GraphExecutionID) => void;
|
|
||||||
doDeleteSchedule: (scheduleID: ScheduleID) => void;
|
|
||||||
}): React.ReactNode {
|
|
||||||
const api = useBackendAPI();
|
|
||||||
|
|
||||||
const selectedRunStatus: AgentRunStatus = "scheduled";
|
|
||||||
|
|
||||||
const toastOnFail = useToastOnFail();
|
|
||||||
|
|
||||||
// Get user's timezone for displaying schedule times
|
|
||||||
const userTimezone = useUserTimezone();
|
|
||||||
|
|
||||||
const infoStats: { label: string; value: React.ReactNode }[] = useMemo(() => {
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
label: "Status",
|
|
||||||
value:
|
|
||||||
selectedRunStatus.charAt(0).toUpperCase() +
|
|
||||||
selectedRunStatus.slice(1),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
label: "Schedule",
|
|
||||||
value: humanizeCronExpression(schedule.cron),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
label: "Next run",
|
|
||||||
value: formatScheduleTime(schedule.next_run_time, userTimezone),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
}, [schedule, selectedRunStatus, userTimezone]);
|
|
||||||
|
|
||||||
const agentRunInputs: Record<
|
|
||||||
string,
|
|
||||||
{ title?: string; /* type: BlockIOSubType; */ value: any }
|
|
||||||
> = useMemo(() => {
|
|
||||||
// TODO: show (link to) preset - https://github.com/Significant-Gravitas/AutoGPT/issues/9168
|
|
||||||
|
|
||||||
// Add type info from agent input schema
|
|
||||||
return Object.fromEntries(
|
|
||||||
Object.entries(schedule.input_data).map(([k, v]) => [
|
|
||||||
k,
|
|
||||||
{
|
|
||||||
title: graph.input_schema.properties[k].title,
|
|
||||||
/* TODO: type: agent.input_schema.properties[k].type */
|
|
||||||
value: v,
|
|
||||||
},
|
|
||||||
]),
|
|
||||||
);
|
|
||||||
}, [graph, schedule]);
|
|
||||||
|
|
||||||
const runNow = useCallback(
|
|
||||||
() =>
|
|
||||||
api
|
|
||||||
.executeGraph(
|
|
||||||
graph.id,
|
|
||||||
graph.version,
|
|
||||||
schedule.input_data,
|
|
||||||
schedule.input_credentials,
|
|
||||||
"library",
|
|
||||||
)
|
|
||||||
.then((run) => onForcedRun(run.id))
|
|
||||||
.catch(toastOnFail("execute agent")),
|
|
||||||
[api, graph, schedule, onForcedRun, toastOnFail],
|
|
||||||
);
|
|
||||||
|
|
||||||
const runActions: ButtonAction[] = useMemo(
|
|
||||||
() => [
|
|
||||||
{
|
|
||||||
label: (
|
|
||||||
<>
|
|
||||||
<PlayIcon className="mr-2 size-4" />
|
|
||||||
Run now
|
|
||||||
</>
|
|
||||||
),
|
|
||||||
callback: runNow,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
label: (
|
|
||||||
<>
|
|
||||||
<IconCross className="mr-2 size-4 px-0.5" />
|
|
||||||
Delete schedule
|
|
||||||
</>
|
|
||||||
),
|
|
||||||
callback: () => doDeleteSchedule(schedule.id),
|
|
||||||
variant: "destructive",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
[runNow],
|
|
||||||
);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="agpt-div flex gap-6">
|
|
||||||
<div className="flex flex-1 flex-col gap-4">
|
|
||||||
<Card className="agpt-box">
|
|
||||||
<CardHeader>
|
|
||||||
<CardTitle className="font-poppins text-lg">Info</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
|
|
||||||
<CardContent>
|
|
||||||
<div className="flex justify-stretch gap-4">
|
|
||||||
{infoStats.map(({ label, value }) => (
|
|
||||||
<div key={label} className="flex-1">
|
|
||||||
<p className="text-sm font-medium text-black">{label}</p>
|
|
||||||
<p className="text-sm text-neutral-600">{value}</p>
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
|
|
||||||
<Card className="agpt-box">
|
|
||||||
<CardHeader>
|
|
||||||
<CardTitle className="font-poppins text-lg">Input</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent className="flex flex-col gap-4">
|
|
||||||
{agentRunInputs !== undefined ? (
|
|
||||||
Object.entries(agentRunInputs).map(([key, { title, value }]) => (
|
|
||||||
<div key={key} className="flex flex-col gap-1.5">
|
|
||||||
<label className="text-sm font-medium">{title || key}</label>
|
|
||||||
<Input value={value} className="rounded-full" disabled />
|
|
||||||
</div>
|
|
||||||
))
|
|
||||||
) : (
|
|
||||||
<LoadingBox spinnerSize={12} className="h-24" />
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Run / Agent Actions */}
|
|
||||||
<aside className="w-48 xl:w-56">
|
|
||||||
<div className="flex flex-col gap-8">
|
|
||||||
<ActionButtonGroup title="Run actions" actions={runActions} />
|
|
||||||
|
|
||||||
<ActionButtonGroup title="Agent actions" actions={agentActions} />
|
|
||||||
</div>
|
|
||||||
</aside>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import React, { useState } from "react";
|
|
||||||
import { Button } from "@/components/__legacy__/ui/button";
|
|
||||||
import {
|
|
||||||
Dialog,
|
|
||||||
DialogContent,
|
|
||||||
DialogDescription,
|
|
||||||
DialogFooter,
|
|
||||||
DialogHeader,
|
|
||||||
DialogTitle,
|
|
||||||
} from "@/components/__legacy__/ui/dialog";
|
|
||||||
import { Input } from "@/components/__legacy__/ui/input";
|
|
||||||
import { Textarea } from "@/components/__legacy__/ui/textarea";
|
|
||||||
|
|
||||||
interface CreatePresetDialogProps {
|
|
||||||
open: boolean;
|
|
||||||
onOpenChange: (open: boolean) => void;
|
|
||||||
onConfirm: (name: string, description: string) => Promise<void> | void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function CreatePresetDialog({
|
|
||||||
open,
|
|
||||||
onOpenChange,
|
|
||||||
onConfirm,
|
|
||||||
}: CreatePresetDialogProps) {
|
|
||||||
const [name, setName] = useState("");
|
|
||||||
const [description, setDescription] = useState("");
|
|
||||||
|
|
||||||
const handleSubmit = async () => {
|
|
||||||
if (name.trim()) {
|
|
||||||
await onConfirm(name.trim(), description.trim());
|
|
||||||
setName("");
|
|
||||||
setDescription("");
|
|
||||||
onOpenChange(false);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleCancel = () => {
|
|
||||||
setName("");
|
|
||||||
setDescription("");
|
|
||||||
onOpenChange(false);
|
|
||||||
};
|
|
||||||
|
|
||||||
const handleKeyDown = (e: React.KeyboardEvent) => {
|
|
||||||
if (e.key === "Enter" && (e.metaKey || e.ctrlKey)) {
|
|
||||||
e.preventDefault();
|
|
||||||
handleSubmit();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
|
||||||
<DialogContent className="sm:max-w-[425px]">
|
|
||||||
<DialogHeader>
|
|
||||||
<DialogTitle>Create Preset</DialogTitle>
|
|
||||||
<DialogDescription>
|
|
||||||
Give your preset a name and description to help identify it later.
|
|
||||||
</DialogDescription>
|
|
||||||
</DialogHeader>
|
|
||||||
<div className="grid gap-4 py-4">
|
|
||||||
<div className="grid gap-2">
|
|
||||||
<label htmlFor="preset-name" className="text-sm font-medium">
|
|
||||||
Name *
|
|
||||||
</label>
|
|
||||||
<Input
|
|
||||||
id="preset-name"
|
|
||||||
placeholder="Enter preset name"
|
|
||||||
value={name}
|
|
||||||
onChange={(e) => setName(e.target.value)}
|
|
||||||
onKeyDown={handleKeyDown}
|
|
||||||
autoFocus
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className="grid gap-2">
|
|
||||||
<label htmlFor="preset-description" className="text-sm font-medium">
|
|
||||||
Description
|
|
||||||
</label>
|
|
||||||
<Textarea
|
|
||||||
id="preset-description"
|
|
||||||
placeholder="Optional description"
|
|
||||||
value={description}
|
|
||||||
onChange={(e) => setDescription(e.target.value)}
|
|
||||||
onKeyDown={handleKeyDown}
|
|
||||||
rows={3}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<DialogFooter>
|
|
||||||
<Button variant="outline" onClick={handleCancel}>
|
|
||||||
Cancel
|
|
||||||
</Button>
|
|
||||||
<Button onClick={handleSubmit} disabled={!name.trim()}>
|
|
||||||
Create Preset
|
|
||||||
</Button>
|
|
||||||
</DialogFooter>
|
|
||||||
</DialogContent>
|
|
||||||
</Dialog>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -1,210 +0,0 @@
|
|||||||
import {
|
|
||||||
GraphExecutionMeta as LegacyGraphExecutionMeta,
|
|
||||||
GraphID,
|
|
||||||
GraphExecutionID,
|
|
||||||
} from "@/lib/autogpt-server-api";
|
|
||||||
import { getQueryClient } from "@/lib/react-query/queryClient";
|
|
||||||
import {
|
|
||||||
getPaginatedTotalCount,
|
|
||||||
getPaginationNextPageNumber,
|
|
||||||
unpaginate,
|
|
||||||
} from "@/app/api/helpers";
|
|
||||||
import {
|
|
||||||
getV1ListGraphExecutionsResponse,
|
|
||||||
getV1ListGraphExecutionsResponse200,
|
|
||||||
useGetV1ListGraphExecutionsInfinite,
|
|
||||||
} from "@/app/api/__generated__/endpoints/graphs/graphs";
|
|
||||||
import { GraphExecutionsPaginated } from "@/app/api/__generated__/models/graphExecutionsPaginated";
|
|
||||||
import { GraphExecutionMeta as RawGraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
|
|
||||||
|
|
||||||
export type GraphExecutionMeta = Omit<
|
|
||||||
RawGraphExecutionMeta,
|
|
||||||
"id" | "user_id" | "graph_id" | "preset_id" | "stats"
|
|
||||||
> &
|
|
||||||
Pick<
|
|
||||||
LegacyGraphExecutionMeta,
|
|
||||||
"id" | "user_id" | "graph_id" | "preset_id" | "stats"
|
|
||||||
>;
|
|
||||||
|
|
||||||
/** Hook to fetch runs for a specific graph, with support for infinite scroll.
|
|
||||||
*
|
|
||||||
* @param graphID - The ID of the graph to fetch agent runs for. This parameter is
|
|
||||||
* optional in the sense that the hook doesn't run unless it is passed.
|
|
||||||
* This way, it can be used in components where the graph ID is not
|
|
||||||
* immediately available.
|
|
||||||
*/
|
|
||||||
export const useAgentRunsInfinite = (graphID?: GraphID) => {
|
|
||||||
const queryClient = getQueryClient();
|
|
||||||
const {
|
|
||||||
data: queryResults,
|
|
||||||
refetch: refetchRuns,
|
|
||||||
isPending: agentRunsLoading,
|
|
||||||
isRefetching: agentRunsReloading,
|
|
||||||
hasNextPage: hasMoreRuns,
|
|
||||||
fetchNextPage: fetchMoreRuns,
|
|
||||||
isFetchingNextPage: isFetchingMoreRuns,
|
|
||||||
queryKey,
|
|
||||||
} = useGetV1ListGraphExecutionsInfinite(
|
|
||||||
graphID!,
|
|
||||||
{ page: 1, page_size: 20 },
|
|
||||||
{
|
|
||||||
query: {
|
|
||||||
getNextPageParam: getPaginationNextPageNumber,
|
|
||||||
|
|
||||||
// Prevent query from running if graphID is not available (yet)
|
|
||||||
...(!graphID
|
|
||||||
? {
|
|
||||||
enabled: false,
|
|
||||||
queryFn: () =>
|
|
||||||
// Fake empty response if graphID is not available (yet)
|
|
||||||
Promise.resolve({
|
|
||||||
status: 200,
|
|
||||||
data: {
|
|
||||||
executions: [],
|
|
||||||
pagination: {
|
|
||||||
current_page: 1,
|
|
||||||
page_size: 20,
|
|
||||||
total_items: 0,
|
|
||||||
total_pages: 0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
headers: new Headers(),
|
|
||||||
} satisfies getV1ListGraphExecutionsResponse),
|
|
||||||
}
|
|
||||||
: {}),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
queryClient,
|
|
||||||
);
|
|
||||||
|
|
||||||
const agentRuns = queryResults ? unpaginate(queryResults, "executions") : [];
|
|
||||||
const agentRunCount = getPaginatedTotalCount(queryResults);
|
|
||||||
|
|
||||||
const upsertAgentRun = (newAgentRun: GraphExecutionMeta) => {
|
|
||||||
queryClient.setQueryData(
|
|
||||||
queryKey,
|
|
||||||
(currentQueryData: typeof queryResults) => {
|
|
||||||
if (!currentQueryData?.pages || agentRunCount === undefined)
|
|
||||||
return currentQueryData;
|
|
||||||
|
|
||||||
const exists = currentQueryData.pages.some((page) => {
|
|
||||||
if (page.status !== 200) return false;
|
|
||||||
|
|
||||||
const response = page.data;
|
|
||||||
return response.executions.some((run) => run.id === newAgentRun.id);
|
|
||||||
});
|
|
||||||
if (exists) {
|
|
||||||
// If the run already exists, we update it
|
|
||||||
return {
|
|
||||||
...currentQueryData,
|
|
||||||
pages: currentQueryData.pages.map((page) => {
|
|
||||||
if (page.status !== 200) return page;
|
|
||||||
const response = page.data;
|
|
||||||
const executions = response.executions;
|
|
||||||
|
|
||||||
const index = executions.findIndex(
|
|
||||||
(run) => run.id === newAgentRun.id,
|
|
||||||
);
|
|
||||||
if (index === -1) return page;
|
|
||||||
|
|
||||||
const newExecutions = [...executions];
|
|
||||||
newExecutions[index] = newAgentRun;
|
|
||||||
|
|
||||||
return {
|
|
||||||
...page,
|
|
||||||
data: {
|
|
||||||
...response,
|
|
||||||
executions: newExecutions,
|
|
||||||
},
|
|
||||||
} satisfies getV1ListGraphExecutionsResponse;
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the run does not exist, we add it to the first page
|
|
||||||
const page = currentQueryData
|
|
||||||
.pages[0] as getV1ListGraphExecutionsResponse200 & {
|
|
||||||
headers: Headers;
|
|
||||||
};
|
|
||||||
const updatedExecutions = [newAgentRun, ...page.data.executions];
|
|
||||||
const updatedPage = {
|
|
||||||
...page,
|
|
||||||
data: {
|
|
||||||
...page.data,
|
|
||||||
executions: updatedExecutions,
|
|
||||||
},
|
|
||||||
} satisfies getV1ListGraphExecutionsResponse;
|
|
||||||
const updatedPages = [updatedPage, ...currentQueryData.pages.slice(1)];
|
|
||||||
return {
|
|
||||||
...currentQueryData,
|
|
||||||
pages: updatedPages.map(
|
|
||||||
// Increment the total runs count in the pagination info of all pages
|
|
||||||
(page) =>
|
|
||||||
page.status === 200
|
|
||||||
? {
|
|
||||||
...page,
|
|
||||||
data: {
|
|
||||||
...page.data,
|
|
||||||
pagination: {
|
|
||||||
...page.data.pagination,
|
|
||||||
total_items: agentRunCount + 1,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
: page,
|
|
||||||
),
|
|
||||||
};
|
|
||||||
},
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
const removeAgentRun = (runID: GraphExecutionID) => {
|
|
||||||
queryClient.setQueryData(
|
|
||||||
[queryKey, { page: 1, page_size: 20 }],
|
|
||||||
(currentQueryData: typeof queryResults) => {
|
|
||||||
if (!currentQueryData?.pages) return currentQueryData;
|
|
||||||
|
|
||||||
let found = false;
|
|
||||||
return {
|
|
||||||
...currentQueryData,
|
|
||||||
pages: currentQueryData.pages.map((page) => {
|
|
||||||
const response = page.data as GraphExecutionsPaginated;
|
|
||||||
const filteredExecutions = response.executions.filter(
|
|
||||||
(run) => run.id !== runID,
|
|
||||||
);
|
|
||||||
if (filteredExecutions.length < response.executions.length) {
|
|
||||||
found = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
...page,
|
|
||||||
data: {
|
|
||||||
...response,
|
|
||||||
executions: filteredExecutions,
|
|
||||||
pagination: {
|
|
||||||
...response.pagination,
|
|
||||||
total_items:
|
|
||||||
response.pagination.total_items - (found ? 1 : 0),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
},
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
return {
|
|
||||||
agentRuns: agentRuns as GraphExecutionMeta[],
|
|
||||||
refetchRuns,
|
|
||||||
agentRunCount,
|
|
||||||
agentRunsLoading: agentRunsLoading || agentRunsReloading,
|
|
||||||
hasMoreRuns,
|
|
||||||
fetchMoreRuns,
|
|
||||||
isFetchingMoreRuns,
|
|
||||||
upsertAgentRun,
|
|
||||||
removeAgentRun,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
|
|
||||||
export type AgentRunsQuery = ReturnType<typeof useAgentRunsInfinite>;
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { OldAgentLibraryView } from "../../agents/[id]/components/OldAgentLibraryView/OldAgentLibraryView";
|
|
||||||
|
|
||||||
export default function OldAgentLibraryPage() {
|
|
||||||
return <OldAgentLibraryView />;
|
|
||||||
}
|
|
||||||
@@ -4269,128 +4269,6 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"/api/mcp/discover-tools": {
|
|
||||||
"post": {
|
|
||||||
"tags": ["v2", "mcp", "mcp"],
|
|
||||||
"summary": "Discover available tools on an MCP server",
|
|
||||||
"description": "Connect to an MCP server and return its available tools.\n\nIf the user has a stored MCP credential for this server URL, it will be\nused automatically — no need to pass an explicit auth token.",
|
|
||||||
"operationId": "postV2Discover available tools on an mcp server",
|
|
||||||
"requestBody": {
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": { "$ref": "#/components/schemas/DiscoverToolsRequest" }
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": true
|
|
||||||
},
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "Successful Response",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/DiscoverToolsResponse"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"401": {
|
|
||||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
|
||||||
},
|
|
||||||
"422": {
|
|
||||||
"description": "Validation Error",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"security": [{ "HTTPBearerJWT": [] }]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/api/mcp/oauth/callback": {
|
|
||||||
"post": {
|
|
||||||
"tags": ["v2", "mcp", "mcp"],
|
|
||||||
"summary": "Exchange OAuth code for MCP tokens",
|
|
||||||
"description": "Exchange the authorization code for tokens and store the credential.\n\nThe frontend calls this after receiving the OAuth code from the popup.\nOn success, subsequent ``/discover-tools`` calls for the same server URL\nwill automatically use the stored credential.",
|
|
||||||
"operationId": "postV2Exchange oauth code for mcp tokens",
|
|
||||||
"requestBody": {
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/MCPOAuthCallbackRequest"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": true
|
|
||||||
},
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "Successful Response",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/CredentialsMetaResponse"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"401": {
|
|
||||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
|
||||||
},
|
|
||||||
"422": {
|
|
||||||
"description": "Validation Error",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"security": [{ "HTTPBearerJWT": [] }]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/api/mcp/oauth/login": {
|
|
||||||
"post": {
|
|
||||||
"tags": ["v2", "mcp", "mcp"],
|
|
||||||
"summary": "Initiate OAuth login for an MCP server",
|
|
||||||
"description": "Discover OAuth metadata from the MCP server and return a login URL.\n\n1. Discovers the protected-resource metadata (RFC 9728)\n2. Fetches the authorization server metadata (RFC 8414)\n3. Performs Dynamic Client Registration (RFC 7591) if available\n4. Returns the authorization URL for the frontend to open in a popup",
|
|
||||||
"operationId": "postV2Initiate oauth login for an mcp server",
|
|
||||||
"requestBody": {
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": { "$ref": "#/components/schemas/MCPOAuthLoginRequest" }
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": true
|
|
||||||
},
|
|
||||||
"responses": {
|
|
||||||
"200": {
|
|
||||||
"description": "Successful Response",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": {
|
|
||||||
"$ref": "#/components/schemas/MCPOAuthLoginResponse"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"401": {
|
|
||||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
|
||||||
},
|
|
||||||
"422": {
|
|
||||||
"description": "Validation Error",
|
|
||||||
"content": {
|
|
||||||
"application/json": {
|
|
||||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"security": [{ "HTTPBearerJWT": [] }]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"/api/oauth/app/{client_id}": {
|
"/api/oauth/app/{client_id}": {
|
||||||
"get": {
|
"get": {
|
||||||
"tags": ["oauth"],
|
"tags": ["oauth"],
|
||||||
@@ -7188,57 +7066,13 @@
|
|||||||
"properties": {
|
"properties": {
|
||||||
"id": { "type": "string", "title": "Id" },
|
"id": { "type": "string", "title": "Id" },
|
||||||
"name": { "type": "string", "title": "Name" },
|
"name": { "type": "string", "title": "Name" },
|
||||||
"description": { "type": "string", "title": "Description" },
|
"description": { "type": "string", "title": "Description" }
|
||||||
"categories": {
|
|
||||||
"items": { "type": "string" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Categories"
|
|
||||||
},
|
|
||||||
"input_schema": {
|
|
||||||
"additionalProperties": true,
|
|
||||||
"type": "object",
|
|
||||||
"title": "Input Schema",
|
|
||||||
"description": "Full JSON schema for block inputs"
|
|
||||||
},
|
|
||||||
"output_schema": {
|
|
||||||
"additionalProperties": true,
|
|
||||||
"type": "object",
|
|
||||||
"title": "Output Schema",
|
|
||||||
"description": "Full JSON schema for block outputs"
|
|
||||||
},
|
|
||||||
"required_inputs": {
|
|
||||||
"items": { "$ref": "#/components/schemas/BlockInputFieldInfo" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Required Inputs",
|
|
||||||
"description": "List of input fields for this block"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["id", "name", "description", "categories"],
|
"required": ["id", "name", "description"],
|
||||||
"title": "BlockInfoSummary",
|
"title": "BlockInfoSummary",
|
||||||
"description": "Summary of a block for search results."
|
"description": "Summary of a block for search results."
|
||||||
},
|
},
|
||||||
"BlockInputFieldInfo": {
|
|
||||||
"properties": {
|
|
||||||
"name": { "type": "string", "title": "Name" },
|
|
||||||
"type": { "type": "string", "title": "Type" },
|
|
||||||
"description": {
|
|
||||||
"type": "string",
|
|
||||||
"title": "Description",
|
|
||||||
"default": ""
|
|
||||||
},
|
|
||||||
"required": {
|
|
||||||
"type": "boolean",
|
|
||||||
"title": "Required",
|
|
||||||
"default": false
|
|
||||||
},
|
|
||||||
"default": { "anyOf": [{}, { "type": "null" }], "title": "Default" }
|
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"required": ["name", "type"],
|
|
||||||
"title": "BlockInputFieldInfo",
|
|
||||||
"description": "Information about a block input field."
|
|
||||||
},
|
|
||||||
"BlockListResponse": {
|
"BlockListResponse": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"type": {
|
"type": {
|
||||||
@@ -7256,12 +7090,7 @@
|
|||||||
"title": "Blocks"
|
"title": "Blocks"
|
||||||
},
|
},
|
||||||
"count": { "type": "integer", "title": "Count" },
|
"count": { "type": "integer", "title": "Count" },
|
||||||
"query": { "type": "string", "title": "Query" },
|
"query": { "type": "string", "title": "Query" }
|
||||||
"usage_hint": {
|
|
||||||
"type": "string",
|
|
||||||
"title": "Usage Hint",
|
|
||||||
"default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the fields listed in required_inputs."
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["message", "blocks", "count", "query"],
|
"required": ["message", "blocks", "count", "query"],
|
||||||
@@ -7813,7 +7642,7 @@
|
|||||||
"host": {
|
"host": {
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||||
"title": "Host",
|
"title": "Host",
|
||||||
"description": "Host pattern for host-scoped or MCP server URL for MCP credentials"
|
"description": "Host pattern for host-scoped credentials"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -7833,45 +7662,6 @@
|
|||||||
"required": ["version_counts"],
|
"required": ["version_counts"],
|
||||||
"title": "DeleteGraphResponse"
|
"title": "DeleteGraphResponse"
|
||||||
},
|
},
|
||||||
"DiscoverToolsRequest": {
|
|
||||||
"properties": {
|
|
||||||
"server_url": {
|
|
||||||
"type": "string",
|
|
||||||
"title": "Server Url",
|
|
||||||
"description": "URL of the MCP server"
|
|
||||||
},
|
|
||||||
"auth_token": {
|
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
|
||||||
"title": "Auth Token",
|
|
||||||
"description": "Optional Bearer token for authenticated MCP servers"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"required": ["server_url"],
|
|
||||||
"title": "DiscoverToolsRequest",
|
|
||||||
"description": "Request to discover tools on an MCP server."
|
|
||||||
},
|
|
||||||
"DiscoverToolsResponse": {
|
|
||||||
"properties": {
|
|
||||||
"tools": {
|
|
||||||
"items": { "$ref": "#/components/schemas/MCPToolResponse" },
|
|
||||||
"type": "array",
|
|
||||||
"title": "Tools"
|
|
||||||
},
|
|
||||||
"server_name": {
|
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
|
||||||
"title": "Server Name"
|
|
||||||
},
|
|
||||||
"protocol_version": {
|
|
||||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
|
||||||
"title": "Protocol Version"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"required": ["tools"],
|
|
||||||
"title": "DiscoverToolsResponse",
|
|
||||||
"description": "Response containing the list of tools available on an MCP server."
|
|
||||||
},
|
|
||||||
"DocPageResponse": {
|
"DocPageResponse": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"type": {
|
"type": {
|
||||||
@@ -9448,62 +9238,6 @@
|
|||||||
"required": ["login_url", "state_token"],
|
"required": ["login_url", "state_token"],
|
||||||
"title": "LoginResponse"
|
"title": "LoginResponse"
|
||||||
},
|
},
|
||||||
"MCPOAuthCallbackRequest": {
|
|
||||||
"properties": {
|
|
||||||
"code": {
|
|
||||||
"type": "string",
|
|
||||||
"title": "Code",
|
|
||||||
"description": "Authorization code from OAuth callback"
|
|
||||||
},
|
|
||||||
"state_token": {
|
|
||||||
"type": "string",
|
|
||||||
"title": "State Token",
|
|
||||||
"description": "State token for CSRF verification"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"required": ["code", "state_token"],
|
|
||||||
"title": "MCPOAuthCallbackRequest",
|
|
||||||
"description": "Request to exchange an OAuth code for tokens."
|
|
||||||
},
|
|
||||||
"MCPOAuthLoginRequest": {
|
|
||||||
"properties": {
|
|
||||||
"server_url": {
|
|
||||||
"type": "string",
|
|
||||||
"title": "Server Url",
|
|
||||||
"description": "URL of the MCP server that requires OAuth"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"required": ["server_url"],
|
|
||||||
"title": "MCPOAuthLoginRequest",
|
|
||||||
"description": "Request to start an OAuth flow for an MCP server."
|
|
||||||
},
|
|
||||||
"MCPOAuthLoginResponse": {
|
|
||||||
"properties": {
|
|
||||||
"login_url": { "type": "string", "title": "Login Url" },
|
|
||||||
"state_token": { "type": "string", "title": "State Token" }
|
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"required": ["login_url", "state_token"],
|
|
||||||
"title": "MCPOAuthLoginResponse",
|
|
||||||
"description": "Response with the OAuth login URL for the user to authenticate."
|
|
||||||
},
|
|
||||||
"MCPToolResponse": {
|
|
||||||
"properties": {
|
|
||||||
"name": { "type": "string", "title": "Name" },
|
|
||||||
"description": { "type": "string", "title": "Description" },
|
|
||||||
"input_schema": {
|
|
||||||
"additionalProperties": true,
|
|
||||||
"type": "object",
|
|
||||||
"title": "Input Schema"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"type": "object",
|
|
||||||
"required": ["name", "description", "input_schema"],
|
|
||||||
"title": "MCPToolResponse",
|
|
||||||
"description": "A single MCP tool returned by discovery."
|
|
||||||
},
|
|
||||||
"MarketplaceListing": {
|
"MarketplaceListing": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"id": { "type": "string", "title": "Id" },
|
"id": { "type": "string", "title": "Id" },
|
||||||
@@ -10762,9 +10496,6 @@
|
|||||||
"operation_pending",
|
"operation_pending",
|
||||||
"operation_in_progress",
|
"operation_in_progress",
|
||||||
"input_validation_error",
|
"input_validation_error",
|
||||||
"web_fetch",
|
|
||||||
"bash_exec",
|
|
||||||
"operation_status",
|
|
||||||
"feature_request_search",
|
"feature_request_search",
|
||||||
"feature_request_created"
|
"feature_request_created"
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -38,8 +38,13 @@ export function CredentialsGroupedView({
|
|||||||
const allProviders = useContext(CredentialsProvidersContext);
|
const allProviders = useContext(CredentialsProvidersContext);
|
||||||
|
|
||||||
const { userCredentialFields, systemCredentialFields } = useMemo(
|
const { userCredentialFields, systemCredentialFields } = useMemo(
|
||||||
() => splitCredentialFieldsBySystem(credentialFields, allProviders),
|
() =>
|
||||||
[credentialFields, allProviders],
|
splitCredentialFieldsBySystem(
|
||||||
|
credentialFields,
|
||||||
|
allProviders,
|
||||||
|
inputCredentials,
|
||||||
|
),
|
||||||
|
[credentialFields, allProviders, inputCredentials],
|
||||||
);
|
);
|
||||||
|
|
||||||
const hasSystemCredentials = systemCredentialFields.length > 0;
|
const hasSystemCredentials = systemCredentialFields.length > 0;
|
||||||
@@ -81,13 +86,11 @@ export function CredentialsGroupedView({
|
|||||||
const providerNames = schema.credentials_provider || [];
|
const providerNames = schema.credentials_provider || [];
|
||||||
const credentialTypes = schema.credentials_types || [];
|
const credentialTypes = schema.credentials_types || [];
|
||||||
const requiredScopes = schema.credentials_scopes;
|
const requiredScopes = schema.credentials_scopes;
|
||||||
const discriminatorValues = schema.discriminator_values;
|
|
||||||
const savedCredential = findSavedCredentialByProviderAndType(
|
const savedCredential = findSavedCredentialByProviderAndType(
|
||||||
providerNames,
|
providerNames,
|
||||||
credentialTypes,
|
credentialTypes,
|
||||||
requiredScopes,
|
requiredScopes,
|
||||||
allProviders,
|
allProviders,
|
||||||
discriminatorValues,
|
|
||||||
);
|
);
|
||||||
|
|
||||||
if (savedCredential) {
|
if (savedCredential) {
|
||||||
|
|||||||
@@ -23,35 +23,10 @@ function hasRequiredScopes(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Check if a credential matches the discriminator values (e.g. MCP server URL). */
|
|
||||||
function matchesDiscriminatorValues(
|
|
||||||
credential: { host?: string | null; provider: string; type: string },
|
|
||||||
discriminatorValues?: string[],
|
|
||||||
) {
|
|
||||||
// MCP OAuth2 credentials must match by server URL
|
|
||||||
if (credential.type === "oauth2" && credential.provider === "mcp") {
|
|
||||||
if (!discriminatorValues || discriminatorValues.length === 0) return false;
|
|
||||||
return (
|
|
||||||
credential.host != null && discriminatorValues.includes(credential.host)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
// Host-scoped credentials match by host
|
|
||||||
if (credential.type === "host_scoped" && credential.host) {
|
|
||||||
if (!discriminatorValues || discriminatorValues.length === 0) return true;
|
|
||||||
return discriminatorValues.some((v) => {
|
|
||||||
try {
|
|
||||||
return new URL(v).hostname === credential.host;
|
|
||||||
} catch {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function splitCredentialFieldsBySystem(
|
export function splitCredentialFieldsBySystem(
|
||||||
credentialFields: CredentialField[],
|
credentialFields: CredentialField[],
|
||||||
allProviders: CredentialsProvidersContextType | null,
|
allProviders: CredentialsProvidersContextType | null,
|
||||||
|
inputCredentials?: Record<string, unknown>,
|
||||||
) {
|
) {
|
||||||
if (!allProviders || credentialFields.length === 0) {
|
if (!allProviders || credentialFields.length === 0) {
|
||||||
return {
|
return {
|
||||||
@@ -77,9 +52,17 @@ export function splitCredentialFieldsBySystem(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const sortByUnsetFirst = (a: CredentialField, b: CredentialField) => {
|
||||||
|
const aIsSet = Boolean(inputCredentials?.[a[0]]);
|
||||||
|
const bIsSet = Boolean(inputCredentials?.[b[0]]);
|
||||||
|
|
||||||
|
if (aIsSet === bIsSet) return 0;
|
||||||
|
return aIsSet ? 1 : -1;
|
||||||
|
};
|
||||||
|
|
||||||
return {
|
return {
|
||||||
userCredentialFields: userFields,
|
userCredentialFields: userFields.sort(sortByUnsetFirst),
|
||||||
systemCredentialFields: systemFields,
|
systemCredentialFields: systemFields.sort(sortByUnsetFirst),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,7 +160,6 @@ export function findSavedCredentialByProviderAndType(
|
|||||||
credentialTypes: string[],
|
credentialTypes: string[],
|
||||||
requiredScopes: string[] | undefined,
|
requiredScopes: string[] | undefined,
|
||||||
allProviders: CredentialsProvidersContextType | null,
|
allProviders: CredentialsProvidersContextType | null,
|
||||||
discriminatorValues?: string[],
|
|
||||||
): SavedCredential | undefined {
|
): SavedCredential | undefined {
|
||||||
for (const providerName of providerNames) {
|
for (const providerName of providerNames) {
|
||||||
const providerData = allProviders?.[providerName];
|
const providerData = allProviders?.[providerName];
|
||||||
@@ -194,14 +176,9 @@ export function findSavedCredentialByProviderAndType(
|
|||||||
credentialTypes.length === 0 ||
|
credentialTypes.length === 0 ||
|
||||||
credentialTypes.includes(credential.type);
|
credentialTypes.includes(credential.type);
|
||||||
const scopesMatch = hasRequiredScopes(credential, requiredScopes);
|
const scopesMatch = hasRequiredScopes(credential, requiredScopes);
|
||||||
const hostMatches = matchesDiscriminatorValues(
|
|
||||||
credential,
|
|
||||||
discriminatorValues,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!typeMatches) continue;
|
if (!typeMatches) continue;
|
||||||
if (!scopesMatch) continue;
|
if (!scopesMatch) continue;
|
||||||
if (!hostMatches) continue;
|
|
||||||
|
|
||||||
matchingCredentials.push(credential as SavedCredential);
|
matchingCredentials.push(credential as SavedCredential);
|
||||||
}
|
}
|
||||||
@@ -213,14 +190,9 @@ export function findSavedCredentialByProviderAndType(
|
|||||||
credentialTypes.length === 0 ||
|
credentialTypes.length === 0 ||
|
||||||
credentialTypes.includes(credential.type);
|
credentialTypes.includes(credential.type);
|
||||||
const scopesMatch = hasRequiredScopes(credential, requiredScopes);
|
const scopesMatch = hasRequiredScopes(credential, requiredScopes);
|
||||||
const hostMatches = matchesDiscriminatorValues(
|
|
||||||
credential,
|
|
||||||
discriminatorValues,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!typeMatches) continue;
|
if (!typeMatches) continue;
|
||||||
if (!scopesMatch) continue;
|
if (!scopesMatch) continue;
|
||||||
if (!hostMatches) continue;
|
|
||||||
|
|
||||||
matchingCredentials.push(credential as SavedCredential);
|
matchingCredentials.push(credential as SavedCredential);
|
||||||
}
|
}
|
||||||
@@ -242,7 +214,6 @@ export function findSavedUserCredentialByProviderAndType(
|
|||||||
credentialTypes: string[],
|
credentialTypes: string[],
|
||||||
requiredScopes: string[] | undefined,
|
requiredScopes: string[] | undefined,
|
||||||
allProviders: CredentialsProvidersContextType | null,
|
allProviders: CredentialsProvidersContextType | null,
|
||||||
discriminatorValues?: string[],
|
|
||||||
): SavedCredential | undefined {
|
): SavedCredential | undefined {
|
||||||
for (const providerName of providerNames) {
|
for (const providerName of providerNames) {
|
||||||
const providerData = allProviders?.[providerName];
|
const providerData = allProviders?.[providerName];
|
||||||
@@ -259,14 +230,9 @@ export function findSavedUserCredentialByProviderAndType(
|
|||||||
credentialTypes.length === 0 ||
|
credentialTypes.length === 0 ||
|
||||||
credentialTypes.includes(credential.type);
|
credentialTypes.includes(credential.type);
|
||||||
const scopesMatch = hasRequiredScopes(credential, requiredScopes);
|
const scopesMatch = hasRequiredScopes(credential, requiredScopes);
|
||||||
const hostMatches = matchesDiscriminatorValues(
|
|
||||||
credential,
|
|
||||||
discriminatorValues,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!typeMatches) continue;
|
if (!typeMatches) continue;
|
||||||
if (!scopesMatch) continue;
|
if (!scopesMatch) continue;
|
||||||
if (!hostMatches) continue;
|
|
||||||
|
|
||||||
matchingCredentials.push(credential as SavedCredential);
|
matchingCredentials.push(credential as SavedCredential);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,14 +5,14 @@ import {
|
|||||||
BlockIOCredentialsSubSchema,
|
BlockIOCredentialsSubSchema,
|
||||||
CredentialsMetaInput,
|
CredentialsMetaInput,
|
||||||
} from "@/lib/autogpt-server-api/types";
|
} from "@/lib/autogpt-server-api/types";
|
||||||
import { postV2InitiateOauthLoginForAnMcpServer } from "@/app/api/__generated__/endpoints/mcp/mcp";
|
|
||||||
import { openOAuthPopup } from "@/lib/oauth-popup";
|
|
||||||
import { useQueryClient } from "@tanstack/react-query";
|
import { useQueryClient } from "@tanstack/react-query";
|
||||||
import { useEffect, useRef, useState } from "react";
|
import { useEffect, useRef, useState } from "react";
|
||||||
import {
|
import {
|
||||||
filterSystemCredentials,
|
filterSystemCredentials,
|
||||||
getActionButtonText,
|
getActionButtonText,
|
||||||
getSystemCredentials,
|
getSystemCredentials,
|
||||||
|
OAUTH_TIMEOUT_MS,
|
||||||
|
OAuthPopupResultMessage,
|
||||||
} from "./helpers";
|
} from "./helpers";
|
||||||
|
|
||||||
export type CredentialsInputState = ReturnType<typeof useCredentialsInput>;
|
export type CredentialsInputState = ReturnType<typeof useCredentialsInput>;
|
||||||
@@ -57,14 +57,6 @@ export function useCredentialsInput({
|
|||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
const credentials = useCredentials(schema, siblingInputs);
|
const credentials = useCredentials(schema, siblingInputs);
|
||||||
const hasAttemptedAutoSelect = useRef(false);
|
const hasAttemptedAutoSelect = useRef(false);
|
||||||
const oauthAbortRef = useRef<((reason?: string) => void) | null>(null);
|
|
||||||
|
|
||||||
// Clean up on unmount
|
|
||||||
useEffect(() => {
|
|
||||||
return () => {
|
|
||||||
oauthAbortRef.current?.();
|
|
||||||
};
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const deleteCredentialsMutation = useDeleteV1DeleteCredentials({
|
const deleteCredentialsMutation = useDeleteV1DeleteCredentials({
|
||||||
mutation: {
|
mutation: {
|
||||||
@@ -89,14 +81,11 @@ export function useCredentialsInput({
|
|||||||
}
|
}
|
||||||
}, [credentials, onLoaded]);
|
}, [credentials, onLoaded]);
|
||||||
|
|
||||||
// Unselect credential if not available in the loaded credential list.
|
// Unselect credential if not available
|
||||||
// Skip when no credentials have been loaded yet (empty list could mean
|
|
||||||
// the provider data hasn't finished loading, not that the credential is invalid).
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (readOnly) return;
|
if (readOnly) return;
|
||||||
if (!credentials || !("savedCredentials" in credentials)) return;
|
if (!credentials || !("savedCredentials" in credentials)) return;
|
||||||
const availableCreds = credentials.savedCredentials;
|
const availableCreds = credentials.savedCredentials;
|
||||||
if (availableCreds.length === 0) return;
|
|
||||||
if (
|
if (
|
||||||
selectedCredential &&
|
selectedCredential &&
|
||||||
!availableCreds.some((c) => c.id === selectedCredential.id)
|
!availableCreds.some((c) => c.id === selectedCredential.id)
|
||||||
@@ -121,9 +110,7 @@ export function useCredentialsInput({
|
|||||||
if (hasAttemptedAutoSelect.current) return;
|
if (hasAttemptedAutoSelect.current) return;
|
||||||
hasAttemptedAutoSelect.current = true;
|
hasAttemptedAutoSelect.current = true;
|
||||||
|
|
||||||
// Auto-select if exactly one credential matches.
|
if (isOptional) return;
|
||||||
// For optional fields with multiple options, let the user choose.
|
|
||||||
if (isOptional && savedCreds.length > 1) return;
|
|
||||||
|
|
||||||
const cred = savedCreds[0];
|
const cred = savedCreds[0];
|
||||||
onSelectCredential({
|
onSelectCredential({
|
||||||
@@ -161,9 +148,7 @@ export function useCredentialsInput({
|
|||||||
supportsHostScoped,
|
supportsHostScoped,
|
||||||
savedCredentials,
|
savedCredentials,
|
||||||
oAuthCallback,
|
oAuthCallback,
|
||||||
mcpOAuthCallback,
|
|
||||||
isSystemProvider,
|
isSystemProvider,
|
||||||
discriminatorValue,
|
|
||||||
} = credentials;
|
} = credentials;
|
||||||
|
|
||||||
// Split credentials into user and system
|
// Split credentials into user and system
|
||||||
@@ -172,66 +157,72 @@ export function useCredentialsInput({
|
|||||||
|
|
||||||
async function handleOAuthLogin() {
|
async function handleOAuthLogin() {
|
||||||
setOAuthError(null);
|
setOAuthError(null);
|
||||||
|
const { login_url, state_token } = await api.oAuthLogin(
|
||||||
// Abort any previous OAuth flow
|
|
||||||
oauthAbortRef.current?.();
|
|
||||||
|
|
||||||
// MCP uses dynamic OAuth discovery per server URL
|
|
||||||
const isMCP = provider === "mcp" && !!discriminatorValue;
|
|
||||||
|
|
||||||
try {
|
|
||||||
let login_url: string;
|
|
||||||
let state_token: string;
|
|
||||||
|
|
||||||
if (isMCP) {
|
|
||||||
const mcpLoginResponse = await postV2InitiateOauthLoginForAnMcpServer({
|
|
||||||
server_url: discriminatorValue!,
|
|
||||||
});
|
|
||||||
if (mcpLoginResponse.status !== 200) throw mcpLoginResponse.data;
|
|
||||||
({ login_url, state_token } = mcpLoginResponse.data);
|
|
||||||
} else {
|
|
||||||
({ login_url, state_token } = await api.oAuthLogin(
|
|
||||||
provider,
|
provider,
|
||||||
schema.credentials_scopes,
|
schema.credentials_scopes,
|
||||||
));
|
);
|
||||||
|
setOAuth2FlowInProgress(true);
|
||||||
|
const popup = window.open(login_url, "_blank", "popup=true");
|
||||||
|
|
||||||
|
if (!popup) {
|
||||||
|
throw new Error(
|
||||||
|
"Failed to open popup window. Please allow popups for this site.",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
setOAuth2FlowInProgress(true);
|
|
||||||
|
|
||||||
const { promise, cleanup } = openOAuthPopup(login_url, {
|
|
||||||
stateToken: state_token,
|
|
||||||
useCrossOriginListeners: isMCP,
|
|
||||||
// Standard OAuth uses "oauth_popup_result", MCP uses "mcp_oauth_result"
|
|
||||||
acceptMessageTypes: isMCP
|
|
||||||
? ["mcp_oauth_result"]
|
|
||||||
: ["oauth_popup_result"],
|
|
||||||
});
|
|
||||||
|
|
||||||
oauthAbortRef.current = cleanup.abort;
|
|
||||||
// Expose abort signal for the waiting modal's cancel button
|
|
||||||
const controller = new AbortController();
|
const controller = new AbortController();
|
||||||
cleanup.signal.addEventListener("abort", () =>
|
|
||||||
controller.abort("completed"),
|
|
||||||
);
|
|
||||||
setOAuthPopupController(controller);
|
setOAuthPopupController(controller);
|
||||||
|
controller.signal.onabort = () => {
|
||||||
|
console.debug("OAuth flow aborted");
|
||||||
|
setOAuth2FlowInProgress(false);
|
||||||
|
popup.close();
|
||||||
|
};
|
||||||
|
|
||||||
const result = await promise;
|
const handleMessage = async (e: MessageEvent<OAuthPopupResultMessage>) => {
|
||||||
|
console.debug("Message received:", e.data);
|
||||||
|
if (
|
||||||
|
typeof e.data != "object" ||
|
||||||
|
!("message_type" in e.data) ||
|
||||||
|
e.data.message_type !== "oauth_popup_result"
|
||||||
|
) {
|
||||||
|
console.debug("Ignoring irrelevant message");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Exchange code for tokens via the provider (updates credential cache)
|
if (!e.data.success) {
|
||||||
const credentialResult = isMCP
|
console.error("OAuth flow failed:", e.data.message);
|
||||||
? await mcpOAuthCallback(result.code, state_token)
|
setOAuthError(`OAuth flow failed: ${e.data.message}`);
|
||||||
: await oAuthCallback(result.code, result.state);
|
setOAuth2FlowInProgress(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// Check if the credential's scopes match the required scopes (skip for MCP)
|
if (e.data.state !== state_token) {
|
||||||
if (!isMCP) {
|
console.error("Invalid state token received");
|
||||||
|
setOAuthError("Invalid state token received");
|
||||||
|
setOAuth2FlowInProgress(false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
console.debug("Processing OAuth callback");
|
||||||
|
const credentials = await oAuthCallback(e.data.code, e.data.state);
|
||||||
|
console.debug("OAuth callback processed successfully");
|
||||||
|
|
||||||
|
// Check if the credential's scopes match the required scopes
|
||||||
const requiredScopes = schema.credentials_scopes;
|
const requiredScopes = schema.credentials_scopes;
|
||||||
if (requiredScopes && requiredScopes.length > 0) {
|
if (requiredScopes && requiredScopes.length > 0) {
|
||||||
const grantedScopes = new Set(credentialResult.scopes || []);
|
const grantedScopes = new Set(credentials.scopes || []);
|
||||||
const hasAllRequiredScopes = new Set(requiredScopes).isSubsetOf(
|
const hasAllRequiredScopes = new Set(requiredScopes).isSubsetOf(
|
||||||
grantedScopes,
|
grantedScopes,
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!hasAllRequiredScopes) {
|
if (!hasAllRequiredScopes) {
|
||||||
|
console.error(
|
||||||
|
`Newly created OAuth credential for ${providerName} has insufficient scopes. Required:`,
|
||||||
|
requiredScopes,
|
||||||
|
"Granted:",
|
||||||
|
credentials.scopes,
|
||||||
|
);
|
||||||
setOAuthError(
|
setOAuthError(
|
||||||
"Connection failed: the granted permissions don't match what's required. " +
|
"Connection failed: the granted permissions don't match what's required. " +
|
||||||
"Please contact the application administrator.",
|
"Please contact the application administrator.",
|
||||||
@@ -239,28 +230,38 @@ export function useCredentialsInput({
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
onSelectCredential({
|
onSelectCredential({
|
||||||
id: credentialResult.id,
|
id: credentials.id,
|
||||||
type: "oauth2",
|
type: "oauth2",
|
||||||
title: credentialResult.title,
|
title: credentials.title,
|
||||||
provider,
|
provider,
|
||||||
});
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
if (error instanceof Error && error.message === "OAuth flow timed out") {
|
console.error("Error in OAuth callback:", error);
|
||||||
setOAuthError("OAuth flow timed out");
|
|
||||||
} else {
|
|
||||||
setOAuthError(
|
setOAuthError(
|
||||||
`OAuth error: ${
|
`Error in OAuth callback: ${
|
||||||
error instanceof Error ? error.message : String(error)
|
error instanceof Error ? error.message : String(error)
|
||||||
}`,
|
}`,
|
||||||
);
|
);
|
||||||
}
|
|
||||||
} finally {
|
} finally {
|
||||||
|
console.debug("Finalizing OAuth flow");
|
||||||
setOAuth2FlowInProgress(false);
|
setOAuth2FlowInProgress(false);
|
||||||
oauthAbortRef.current = null;
|
controller.abort("success");
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
console.debug("Adding message event listener");
|
||||||
|
window.addEventListener("message", handleMessage, {
|
||||||
|
signal: controller.signal,
|
||||||
|
});
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
console.debug("OAuth flow timed out");
|
||||||
|
controller.abort("timeout");
|
||||||
|
setOAuth2FlowInProgress(false);
|
||||||
|
setOAuthError("OAuth flow timed out");
|
||||||
|
}, OAUTH_TIMEOUT_MS);
|
||||||
}
|
}
|
||||||
|
|
||||||
function handleActionButtonClick() {
|
function handleActionButtonClick() {
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import { useEffect, useState } from "react";
|
|||||||
import { Input } from "@/components/__legacy__/ui/input";
|
import { Input } from "@/components/__legacy__/ui/input";
|
||||||
import { Button } from "@/components/__legacy__/ui/button";
|
import { Button } from "@/components/__legacy__/ui/button";
|
||||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||||
import { CronScheduler } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/cron-scheduler";
|
import { CronScheduler } from "@/components/contextual/CronScheduler/cron-scheduler";
|
||||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||||
import { getTimezoneDisplayName } from "@/lib/timezone-utils";
|
import { getTimezoneDisplayName } from "@/lib/timezone-utils";
|
||||||
import { useUserTimezone } from "@/lib/hooks/useUserTimezone";
|
import { useUserTimezone } from "@/lib/hooks/useUserTimezone";
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import { CronExpressionDialog } from "@/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/cron-scheduler-dialog";
|
import { CronExpressionDialog } from "@/components/contextual/CronScheduler/cron-scheduler-dialog";
|
||||||
import { Form, FormField } from "@/components/__legacy__/ui/form";
|
import { Form, FormField } from "@/components/__legacy__/ui/form";
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
import { Input } from "@/components/atoms/Input/Input";
|
import { Input } from "@/components/atoms/Input/Input";
|
||||||
|
|||||||
@@ -100,11 +100,6 @@ export default function useCredentials(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter MCP OAuth2 credentials by server URL matching
|
|
||||||
if (c.type === "oauth2" && c.provider === "mcp") {
|
|
||||||
return discriminatorValue != null && c.host === discriminatorValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter by OAuth credentials that have sufficient scopes for this block
|
// Filter by OAuth credentials that have sufficient scopes for this block
|
||||||
if (c.type === "oauth2") {
|
if (c.type === "oauth2") {
|
||||||
const requiredScopes = credsInputSchema.credentials_scopes;
|
const requiredScopes = credsInputSchema.credentials_scopes;
|
||||||
|
|||||||
@@ -749,12 +749,10 @@ export enum BlockUIType {
|
|||||||
AGENT = "Agent",
|
AGENT = "Agent",
|
||||||
AI = "AI",
|
AI = "AI",
|
||||||
AYRSHARE = "Ayrshare",
|
AYRSHARE = "Ayrshare",
|
||||||
MCP_TOOL = "MCP Tool",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export enum SpecialBlockID {
|
export enum SpecialBlockID {
|
||||||
AGENT = "e189baac-8c20-45a1-94a7-55177ea42565",
|
AGENT = "e189baac-8c20-45a1-94a7-55177ea42565",
|
||||||
MCP_TOOL = "a0a4b1c2-d3e4-4f56-a7b8-c9d0e1f2a3b4",
|
|
||||||
SMART_DECISION = "3b191d9f-356f-482d-8238-ba04b6d18381",
|
SMART_DECISION = "3b191d9f-356f-482d-8238-ba04b6d18381",
|
||||||
OUTPUT = "363ae599-353e-4804-937e-b2ee3cef3da4",
|
OUTPUT = "363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,177 +0,0 @@
|
|||||||
/**
|
|
||||||
* Shared utility for OAuth popup flows with cross-origin support.
|
|
||||||
*
|
|
||||||
* Handles BroadcastChannel, postMessage, and localStorage polling
|
|
||||||
* to reliably receive OAuth callback results even when COOP headers
|
|
||||||
* sever the window.opener relationship.
|
|
||||||
*/
|
|
||||||
|
|
||||||
const DEFAULT_TIMEOUT_MS = 5 * 60 * 1000; // 5 minutes
|
|
||||||
|
|
||||||
export type OAuthPopupResult = {
|
|
||||||
code: string;
|
|
||||||
state: string;
|
|
||||||
};
|
|
||||||
|
|
||||||
export type OAuthPopupOptions = {
|
|
||||||
/** State token to validate against incoming messages */
|
|
||||||
stateToken: string;
|
|
||||||
/**
|
|
||||||
* Use BroadcastChannel + localStorage polling for cross-origin OAuth (MCP).
|
|
||||||
* Standard OAuth only uses postMessage via window.opener.
|
|
||||||
*/
|
|
||||||
useCrossOriginListeners?: boolean;
|
|
||||||
/** BroadcastChannel name (default: "mcp_oauth") */
|
|
||||||
broadcastChannelName?: string;
|
|
||||||
/** localStorage key for cross-origin fallback (default: "mcp_oauth_result") */
|
|
||||||
localStorageKey?: string;
|
|
||||||
/** Message types to accept (default: ["oauth_popup_result", "mcp_oauth_result"]) */
|
|
||||||
acceptMessageTypes?: string[];
|
|
||||||
/** Timeout in ms (default: 5 minutes) */
|
|
||||||
timeout?: number;
|
|
||||||
};
|
|
||||||
|
|
||||||
type Cleanup = {
|
|
||||||
/** Abort the OAuth flow and close the popup */
|
|
||||||
abort: (reason?: string) => void;
|
|
||||||
/** The AbortController signal */
|
|
||||||
signal: AbortSignal;
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Opens an OAuth popup and sets up listeners for the callback result.
|
|
||||||
*
|
|
||||||
* Opens a blank popup synchronously (to avoid popup blockers), then navigates
|
|
||||||
* it to the login URL. Returns a promise that resolves with the OAuth code/state.
|
|
||||||
*
|
|
||||||
* @param loginUrl - The OAuth authorization URL to navigate to
|
|
||||||
* @param options - Configuration for message handling
|
|
||||||
* @returns Object with `promise` (resolves with OAuth result) and `abort` (cancels flow)
|
|
||||||
*/
|
|
||||||
export function openOAuthPopup(
|
|
||||||
loginUrl: string,
|
|
||||||
options: OAuthPopupOptions,
|
|
||||||
): { promise: Promise<OAuthPopupResult>; cleanup: Cleanup } {
|
|
||||||
const {
|
|
||||||
stateToken,
|
|
||||||
useCrossOriginListeners = false,
|
|
||||||
broadcastChannelName = "mcp_oauth",
|
|
||||||
localStorageKey = "mcp_oauth_result",
|
|
||||||
acceptMessageTypes = ["oauth_popup_result", "mcp_oauth_result"],
|
|
||||||
timeout = DEFAULT_TIMEOUT_MS,
|
|
||||||
} = options;
|
|
||||||
|
|
||||||
const controller = new AbortController();
|
|
||||||
|
|
||||||
// Open popup synchronously (before any async work) to avoid browser popup blockers
|
|
||||||
const width = 500;
|
|
||||||
const height = 700;
|
|
||||||
const left = window.screenX + (window.outerWidth - width) / 2;
|
|
||||||
const top = window.screenY + (window.outerHeight - height) / 2;
|
|
||||||
const popup = window.open(
|
|
||||||
"about:blank",
|
|
||||||
"_blank",
|
|
||||||
`width=${width},height=${height},left=${left},top=${top},popup=true,scrollbars=yes`,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (popup && !popup.closed) {
|
|
||||||
popup.location.href = loginUrl;
|
|
||||||
} else {
|
|
||||||
// Popup was blocked — open in new tab as fallback
|
|
||||||
window.open(loginUrl, "_blank");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close popup on abort
|
|
||||||
controller.signal.addEventListener("abort", () => {
|
|
||||||
if (popup && !popup.closed) popup.close();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Clear any stale localStorage entry
|
|
||||||
if (useCrossOriginListeners) {
|
|
||||||
try {
|
|
||||||
localStorage.removeItem(localStorageKey);
|
|
||||||
} catch {}
|
|
||||||
}
|
|
||||||
|
|
||||||
const promise = new Promise<OAuthPopupResult>((resolve, reject) => {
|
|
||||||
let handled = false;
|
|
||||||
|
|
||||||
const handleResult = (data: any) => {
|
|
||||||
if (handled) return; // Prevent double-handling
|
|
||||||
|
|
||||||
// Validate message type
|
|
||||||
const messageType = data?.message_type ?? data?.type;
|
|
||||||
if (!messageType || !acceptMessageTypes.includes(messageType)) return;
|
|
||||||
|
|
||||||
// Validate state token
|
|
||||||
if (data.state !== stateToken) {
|
|
||||||
// State mismatch — this message is for a different listener. Ignore silently.
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
handled = true;
|
|
||||||
|
|
||||||
if (!data.success) {
|
|
||||||
reject(new Error(data.message || "OAuth authentication failed"));
|
|
||||||
} else {
|
|
||||||
resolve({ code: data.code, state: data.state });
|
|
||||||
}
|
|
||||||
|
|
||||||
controller.abort("completed");
|
|
||||||
};
|
|
||||||
|
|
||||||
// Listener: postMessage (works for same-origin popups)
|
|
||||||
window.addEventListener(
|
|
||||||
"message",
|
|
||||||
(event: MessageEvent) => {
|
|
||||||
if (typeof event.data === "object") {
|
|
||||||
handleResult(event.data);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{ signal: controller.signal },
|
|
||||||
);
|
|
||||||
|
|
||||||
// Cross-origin listeners for MCP OAuth
|
|
||||||
if (useCrossOriginListeners) {
|
|
||||||
// Listener: BroadcastChannel (works across tabs/popups without opener)
|
|
||||||
try {
|
|
||||||
const bc = new BroadcastChannel(broadcastChannelName);
|
|
||||||
bc.onmessage = (event) => handleResult(event.data);
|
|
||||||
controller.signal.addEventListener("abort", () => bc.close());
|
|
||||||
} catch {}
|
|
||||||
|
|
||||||
// Listener: localStorage polling (most reliable cross-tab fallback)
|
|
||||||
const pollInterval = setInterval(() => {
|
|
||||||
try {
|
|
||||||
const stored = localStorage.getItem(localStorageKey);
|
|
||||||
if (stored) {
|
|
||||||
const data = JSON.parse(stored);
|
|
||||||
localStorage.removeItem(localStorageKey);
|
|
||||||
handleResult(data);
|
|
||||||
}
|
|
||||||
} catch {}
|
|
||||||
}, 500);
|
|
||||||
controller.signal.addEventListener("abort", () =>
|
|
||||||
clearInterval(pollInterval),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Timeout
|
|
||||||
const timeoutId = setTimeout(() => {
|
|
||||||
if (!handled) {
|
|
||||||
handled = true;
|
|
||||||
reject(new Error("OAuth flow timed out"));
|
|
||||||
controller.abort("timeout");
|
|
||||||
}
|
|
||||||
}, timeout);
|
|
||||||
controller.signal.addEventListener("abort", () => clearTimeout(timeoutId));
|
|
||||||
});
|
|
||||||
|
|
||||||
return {
|
|
||||||
promise,
|
|
||||||
cleanup: {
|
|
||||||
abort: (reason?: string) => controller.abort(reason || "canceled"),
|
|
||||||
signal: controller.signal,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -18,6 +18,6 @@ export const config = {
|
|||||||
* Note: /auth/authorize and /auth/integrations/* ARE protected and need
|
* Note: /auth/authorize and /auth/integrations/* ARE protected and need
|
||||||
* middleware to run for authentication checks.
|
* middleware to run for authentication checks.
|
||||||
*/
|
*/
|
||||||
"/((?!_next/static|_next/image|favicon.ico|auth/callback|auth/integrations/mcp_callback|.*\\.(?:svg|png|jpg|jpeg|gif|webp)$).*)",
|
"/((?!_next/static|_next/image|favicon.ico|auth/callback|.*\\.(?:svg|png|jpg|jpeg|gif|webp)$).*)",
|
||||||
],
|
],
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import {
|
|||||||
HostScopedCredentials,
|
HostScopedCredentials,
|
||||||
UserPasswordCredentials,
|
UserPasswordCredentials,
|
||||||
} from "@/lib/autogpt-server-api";
|
} from "@/lib/autogpt-server-api";
|
||||||
import { postV2ExchangeOauthCodeForMcpTokens } from "@/app/api/__generated__/endpoints/mcp/mcp";
|
|
||||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||||
import { toDisplayName } from "@/providers/agent-credentials/helper";
|
import { toDisplayName } from "@/providers/agent-credentials/helper";
|
||||||
@@ -39,11 +38,6 @@ export type CredentialsProviderData = {
|
|||||||
code: string,
|
code: string,
|
||||||
state_token: string,
|
state_token: string,
|
||||||
) => Promise<CredentialsMetaResponse>;
|
) => Promise<CredentialsMetaResponse>;
|
||||||
/** MCP-specific OAuth callback that uses dynamic per-server OAuth discovery. */
|
|
||||||
mcpOAuthCallback: (
|
|
||||||
code: string,
|
|
||||||
state_token: string,
|
|
||||||
) => Promise<CredentialsMetaResponse>;
|
|
||||||
createAPIKeyCredentials: (
|
createAPIKeyCredentials: (
|
||||||
credentials: APIKeyCredentialsCreatable,
|
credentials: APIKeyCredentialsCreatable,
|
||||||
) => Promise<CredentialsMetaResponse>;
|
) => Promise<CredentialsMetaResponse>;
|
||||||
@@ -126,35 +120,6 @@ export default function CredentialsProvider({
|
|||||||
[api, addCredentials, onFailToast],
|
[api, addCredentials, onFailToast],
|
||||||
);
|
);
|
||||||
|
|
||||||
/** Exchanges an MCP OAuth code for tokens and adds the result to the internal credentials store. */
|
|
||||||
const mcpOAuthCallback = useCallback(
|
|
||||||
async (
|
|
||||||
code: string,
|
|
||||||
state_token: string,
|
|
||||||
): Promise<CredentialsMetaResponse> => {
|
|
||||||
try {
|
|
||||||
const response = await postV2ExchangeOauthCodeForMcpTokens({
|
|
||||||
code,
|
|
||||||
state_token,
|
|
||||||
});
|
|
||||||
if (response.status !== 200) throw response.data;
|
|
||||||
const credsMeta: CredentialsMetaResponse = {
|
|
||||||
...response.data,
|
|
||||||
title: response.data.title ?? undefined,
|
|
||||||
scopes: response.data.scopes ?? undefined,
|
|
||||||
username: response.data.username ?? undefined,
|
|
||||||
host: response.data.host ?? undefined,
|
|
||||||
};
|
|
||||||
addCredentials("mcp", credsMeta);
|
|
||||||
return credsMeta;
|
|
||||||
} catch (error) {
|
|
||||||
onFailToast("complete MCP OAuth authentication")(error);
|
|
||||||
throw error;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
[addCredentials, onFailToast],
|
|
||||||
);
|
|
||||||
|
|
||||||
/** Wraps `BackendAPI.createAPIKeyCredentials`, and adds the result to the internal credentials store. */
|
/** Wraps `BackendAPI.createAPIKeyCredentials`, and adds the result to the internal credentials store. */
|
||||||
const createAPIKeyCredentials = useCallback(
|
const createAPIKeyCredentials = useCallback(
|
||||||
async (
|
async (
|
||||||
@@ -293,7 +258,6 @@ export default function CredentialsProvider({
|
|||||||
isSystemProvider: systemProviders.has(provider),
|
isSystemProvider: systemProviders.has(provider),
|
||||||
oAuthCallback: (code: string, state_token: string) =>
|
oAuthCallback: (code: string, state_token: string) =>
|
||||||
oAuthCallback(provider, code, state_token),
|
oAuthCallback(provider, code, state_token),
|
||||||
mcpOAuthCallback,
|
|
||||||
createAPIKeyCredentials: (
|
createAPIKeyCredentials: (
|
||||||
credentials: APIKeyCredentialsCreatable,
|
credentials: APIKeyCredentialsCreatable,
|
||||||
) => createAPIKeyCredentials(provider, credentials),
|
) => createAPIKeyCredentials(provider, credentials),
|
||||||
@@ -322,7 +286,6 @@ export default function CredentialsProvider({
|
|||||||
createHostScopedCredentials,
|
createHostScopedCredentials,
|
||||||
deleteCredentials,
|
deleteCredentials,
|
||||||
oAuthCallback,
|
oAuthCallback,
|
||||||
mcpOAuthCallback,
|
|
||||||
onFailToast,
|
onFailToast,
|
||||||
]);
|
]);
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user