diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/api/features/chat/config.py index 2e8dbf5413..3901dbd04b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/config.py +++ b/autogpt_platform/backend/backend/api/features/chat/config.py @@ -93,6 +93,12 @@ class ChatConfig(BaseSettings): description="Name of the prompt in Langfuse to fetch", ) + # Claude Agent SDK Configuration + use_claude_agent_sdk: bool = Field( + default=True, + description="Use Claude Agent SDK for chat completions", + ) + @field_validator("api_key", mode="before") @classmethod def get_api_key(cls, v): @@ -132,6 +138,17 @@ class ChatConfig(BaseSettings): v = os.getenv("CHAT_INTERNAL_API_KEY") return v + @field_validator("use_claude_agent_sdk", mode="before") + @classmethod + def get_use_claude_agent_sdk(cls, v): + """Get use_claude_agent_sdk from environment if not provided.""" + # Check environment variable - default to True if not set + env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower() + if env_val: + return env_val in ("true", "1", "yes", "on") + # Default to True (SDK enabled by default) + return True if v is None else v + # Prompt paths for different contexts PROMPT_PATHS: dict[str, str] = { "default": "prompts/chat_system.md", diff --git a/autogpt_platform/backend/backend/api/features/chat/model.py b/autogpt_platform/backend/backend/api/features/chat/model.py index 7318ef88d7..d54dc35519 100644 --- a/autogpt_platform/backend/backend/api/features/chat/model.py +++ b/autogpt_platform/backend/backend/api/features/chat/model.py @@ -273,9 +273,8 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None: try: session = ChatSession.model_validate_json(raw_session) logger.info( - f"Loading session {session_id} from cache: " - f"message_count={len(session.messages)}, " - f"roles={[m.role for m in session.messages]}" + f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, " + 
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles ) return session except Exception as e: @@ -317,11 +316,9 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None: return None messages = prisma_session.Messages - logger.info( - f"Loading session {session_id} from DB: " - f"has_messages={messages is not None}, " - f"message_count={len(messages) if messages else 0}, " - f"roles={[m.role for m in messages] if messages else []}" + logger.debug( + f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, " + f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles ) return ChatSession.from_db(prisma_session, messages) @@ -372,10 +369,9 @@ async def _save_session_to_db( "function_call": msg.function_call, } ) - logger.info( - f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: " - f"roles={[m['role'] for m in messages_data]}, " - f"start_sequence={existing_message_count}" + logger.debug( + f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, " + f"roles={[m['role'] for m in messages_data]}" ) await chat_db.add_chat_messages_batch( session_id=session.session_id, @@ -415,7 +411,7 @@ async def get_chat_session( logger.warning(f"Unexpected cache error for session {session_id}: {e}") # Fall back to database - logger.info(f"Session {session_id} not in cache, checking database") + logger.debug(f"Session {session_id} not in cache, checking database") session = await _get_session_from_db(session_id) if session is None: @@ -432,7 +428,6 @@ async def get_chat_session( # Cache the session from DB try: await _cache_session(session) - logger.info(f"Cached session {session_id} from database") except Exception as e: logger.warning(f"Failed to cache session {session_id}: {e}") @@ -603,13 +598,19 @@ async def update_session_title(session_id: str, title: str) -> bool: logger.warning(f"Session {session_id} not found for title update") return False - # 
Invalidate cache so next fetch gets updated title + # Update title in cache if it exists (instead of invalidating). + # This prevents race conditions where cache invalidation causes + # the frontend to see stale DB data while streaming is still in progress. try: - redis_key = _get_session_cache_key(session_id) - async_redis = await get_redis_async() - await async_redis.delete(redis_key) + cached = await _get_session_from_cache(session_id) + if cached: + cached.title = title + await _cache_session(cached) except Exception as e: - logger.warning(f"Failed to invalidate cache for session {session_id}: {e}") + # Not critical - title will be correct on next full cache refresh + logger.warning( + f"Failed to update title in cache for session {session_id}: {e}" + ) return True except Exception as e: diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index 3e731d86ac..640dbdb9cf 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -1,5 +1,6 @@ """Chat API routes for chat session management and streaming via SSE.""" +import asyncio import logging import uuid as uuid_module from collections.abc import AsyncGenerator @@ -16,8 +17,17 @@ from . import service as chat_service from . 
import stream_registry from .completion_handler import process_operation_failure, process_operation_success from .config import ChatConfig -from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions +from .model import ( + ChatMessage, + ChatSession, + create_chat_session, + get_chat_session, + get_user_sessions, + upsert_chat_session, +) from .response_model import StreamFinish, StreamHeartbeat, StreamStart +from .sdk import service as sdk_service +from .tracking import track_user_message config = ChatConfig() @@ -209,6 +219,10 @@ async def get_session( active_task, last_message_id = await stream_registry.get_active_task_for_session( session_id, user_id ) + logger.info( + f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, " + f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}" + ) if active_task: # Filter out the in-progress assistant message from the session response. # The client will receive the complete assistant response through the SSE @@ -265,10 +279,30 @@ async def stream_chat_post( containing the task_id for reconnection. 
""" - import asyncio - session = await _validate_and_get_session(session_id, user_id) + # Add user message to session BEFORE creating task to avoid race condition + # where GET_SESSION sees the task as "running" but the message isn't saved yet + if request.message: + session.messages.append( + ChatMessage( + role="user" if request.is_user_message else "assistant", + content=request.message, + ) + ) + if request.is_user_message: + track_user_message( + user_id=user_id, + session_id=session_id, + message_length=len(request.message), + ) + logger.info( + f"[STREAM] Saving user message to session {session_id}, " + f"msg_count={len(session.messages)}" + ) + session = await upsert_chat_session(session) + logger.info(f"[STREAM] User message saved for session {session_id}") + # Create a task in the stream registry for reconnection support task_id = str(uuid_module.uuid4()) operation_id = str(uuid_module.uuid4()) @@ -283,24 +317,38 @@ async def stream_chat_post( # Background task that runs the AI generation independently of SSE connection async def run_ai_generation(): + chunk_count = 0 try: # Emit a start event with task_id for reconnection start_chunk = StreamStart(messageId=task_id, taskId=task_id) await stream_registry.publish_chunk(task_id, start_chunk) - async for chunk in chat_service.stream_chat_completion( + # Choose service based on configuration + use_sdk = config.use_claude_agent_sdk + stream_fn = ( + sdk_service.stream_chat_completion_sdk + if use_sdk + else chat_service.stream_chat_completion + ) + # Pass message=None since we already added it to the session above + async for chunk in stream_fn( session_id, - request.message, + None, # Message already in session is_user_message=request.is_user_message, user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch + session=session, # Pass session with message already added context=request.context, ): + chunk_count += 1 # Write to Redis (subscribers will receive via XREAD) await 
stream_registry.publish_chunk(task_id, chunk) - # Mark task as completed - await stream_registry.mark_task_completed(task_id, "completed") + logger.info( + f"[BG_TASK] AI generation completed for session {session_id}: {chunk_count} chunks, marking task {task_id} as completed" + ) + # Mark task as completed (also publishes StreamFinish) + completed = await stream_registry.mark_task_completed(task_id, "completed") + logger.info(f"[BG_TASK] mark_task_completed returned: {completed}") except Exception as e: logger.error( f"Error in background AI generation for session {session_id}: {e}" @@ -315,7 +363,7 @@ async def stream_chat_post( async def event_generator() -> AsyncGenerator[str, None]: subscriber_queue = None try: - # Subscribe to the task stream (this replays existing messages + live updates) + # Subscribe to the task stream (replays + live updates) subscriber_queue = await stream_registry.subscribe_to_task( task_id=task_id, user_id=user_id, @@ -323,6 +371,7 @@ async def stream_chat_post( ) if subscriber_queue is None: + logger.warning(f"Failed to subscribe to task {task_id}") yield StreamFinish().to_sse() yield "data: [DONE]\n\n" return @@ -341,11 +390,11 @@ async def stream_chat_post( yield StreamHeartbeat().to_sse() except GeneratorExit: - pass # Client disconnected - background task continues + pass # Client disconnected - normal behavior except Exception as e: logger.error(f"Error in SSE stream for task {task_id}: {e}") finally: - # Unsubscribe when client disconnects or stream ends to prevent resource leak + # Unsubscribe when client disconnects or stream ends if subscriber_queue is not None: try: await stream_registry.unsubscribe_from_task( @@ -400,35 +449,21 @@ async def stream_chat_get( session = await _validate_and_get_session(session_id, user_id) async def event_generator() -> AsyncGenerator[str, None]: - chunk_count = 0 - first_chunk_type: str | None = None - async for chunk in chat_service.stream_chat_completion( + # Choose service based on 
configuration + use_sdk = config.use_claude_agent_sdk + stream_fn = ( + sdk_service.stream_chat_completion_sdk + if use_sdk + else chat_service.stream_chat_completion + ) + async for chunk in stream_fn( session_id, message, is_user_message=is_user_message, user_id=user_id, session=session, # Pass pre-fetched session to avoid double-fetch ): - if chunk_count < 3: - logger.info( - "Chat stream chunk", - extra={ - "session_id": session_id, - "chunk_type": str(chunk.type), - }, - ) - if not first_chunk_type: - first_chunk_type = str(chunk.type) - chunk_count += 1 yield chunk.to_sse() - logger.info( - "Chat stream completed", - extra={ - "session_id": session_id, - "chunk_count": chunk_count, - "first_chunk_type": first_chunk_type, - }, - ) # AI SDK protocol termination yield "data: [DONE]\n\n" @@ -550,8 +585,6 @@ async def stream_task( ) async def event_generator() -> AsyncGenerator[str, None]: - import asyncio - heartbeat_interval = 15.0 # Send heartbeat every 15 seconds try: while True: diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py b/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py new file mode 100644 index 0000000000..7d9d6371e9 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py @@ -0,0 +1,14 @@ +"""Claude Agent SDK integration for CoPilot. + +This module provides the integration layer between the Claude Agent SDK +and the existing CoPilot tool system, enabling drop-in replacement of +the current LLM orchestration with the battle-tested Claude Agent SDK. 
+""" + +from .service import stream_chat_completion_sdk +from .tool_adapter import create_copilot_mcp_server + +__all__ = [ + "stream_chat_completion_sdk", + "create_copilot_mcp_server", +] diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py b/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py new file mode 100644 index 0000000000..a9977f12f4 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py @@ -0,0 +1,241 @@ +"""Anthropic SDK fallback implementation. + +This module provides the fallback streaming implementation using the Anthropic SDK +directly when the Claude Agent SDK is not available. +""" + +import json +import logging +import os +import uuid +from collections.abc import AsyncGenerator +from typing import Any, cast + +from ..config import ChatConfig +from ..model import ChatSession +from ..response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, + StreamUsage, +) +from .tool_adapter import get_tool_definitions, get_tool_handlers + +logger = logging.getLogger(__name__) +config = ChatConfig() + + +async def stream_with_anthropic( + session: ChatSession, + system_prompt: str, + text_block_id: str, +) -> AsyncGenerator[StreamBaseResponse, None]: + """Stream using Anthropic SDK directly with tool calling support.""" + import anthropic + + api_key = os.getenv("ANTHROPIC_API_KEY") or config.api_key + if not api_key: + yield StreamError( + errorText="ANTHROPIC_API_KEY not configured", code="config_error" + ) + yield StreamFinish() + return + + client = anthropic.AsyncAnthropic(api_key=api_key) + tool_definitions = get_tool_definitions() + tool_handlers = get_tool_handlers() + + anthropic_tools = [ + { + "name": t["name"], + "description": t["description"], + "input_schema": t["inputSchema"], + } + for t in 
tool_definitions + ] + + anthropic_messages = _convert_session_to_anthropic(session) + + if not anthropic_messages or anthropic_messages[-1]["role"] != "user": + anthropic_messages.append( + {"role": "user", "content": "Continue with the task."} + ) + + has_started_text = False + max_iterations = 10 + + for _ in range(max_iterations): + try: + async with client.messages.stream( + model="claude-sonnet-4-20250514", + max_tokens=4096, + system=system_prompt, + messages=cast(Any, anthropic_messages), + tools=cast(Any, anthropic_tools) if anthropic_tools else [], + ) as stream: + async for event in stream: + if event.type == "content_block_start": + block = event.content_block + if hasattr(block, "type"): + if block.type == "text" and not has_started_text: + yield StreamTextStart(id=text_block_id) + has_started_text = True + elif block.type == "tool_use": + yield StreamToolInputStart( + toolCallId=block.id, toolName=block.name + ) + + elif event.type == "content_block_delta": + delta = event.delta + if hasattr(delta, "type") and delta.type == "text_delta": + yield StreamTextDelta(id=text_block_id, delta=delta.text) + + final_message = await stream.get_final_message() + + if final_message.stop_reason == "tool_use": + if has_started_text: + yield StreamTextEnd(id=text_block_id) + has_started_text = False + text_block_id = str(uuid.uuid4()) + + tool_results = [] + assistant_content: list[dict[str, Any]] = [] + + for block in final_message.content: + if block.type == "text": + assistant_content.append( + {"type": "text", "text": block.text} + ) + elif block.type == "tool_use": + assistant_content.append( + { + "type": "tool_use", + "id": block.id, + "name": block.name, + "input": block.input, + } + ) + + yield StreamToolInputAvailable( + toolCallId=block.id, + toolName=block.name, + input=( + block.input if isinstance(block.input, dict) else {} + ), + ) + + output, is_error = await _execute_tool( + block.name, block.input, tool_handlers + ) + + yield 
StreamToolOutputAvailable( + toolCallId=block.id, + toolName=block.name, + output=output, + success=not is_error, + ) + + tool_results.append( + { + "type": "tool_result", + "tool_use_id": block.id, + "content": output, + "is_error": is_error, + } + ) + + anthropic_messages.append( + {"role": "assistant", "content": assistant_content} + ) + anthropic_messages.append({"role": "user", "content": tool_results}) + continue + + else: + if has_started_text: + yield StreamTextEnd(id=text_block_id) + + yield StreamUsage( + promptTokens=final_message.usage.input_tokens, + completionTokens=final_message.usage.output_tokens, + totalTokens=final_message.usage.input_tokens + + final_message.usage.output_tokens, + ) + yield StreamFinish() + return + + except Exception as e: + logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True) + yield StreamError(errorText=f"Error: {str(e)}", code="anthropic_error") + yield StreamFinish() + return + + yield StreamError(errorText="Max tool iterations reached", code="max_iterations") + yield StreamFinish() + + +def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]: + """Convert session messages to Anthropic format.""" + messages = [] + for msg in session.messages: + if msg.role == "user": + messages.append({"role": "user", "content": msg.content or ""}) + elif msg.role == "assistant": + content: list[dict[str, Any]] = [] + if msg.content: + content.append({"type": "text", "text": msg.content}) + if msg.tool_calls: + for tc in msg.tool_calls: + func = tc.get("function", {}) + args = func.get("arguments", {}) + if isinstance(args, str): + try: + args = json.loads(args) + except json.JSONDecodeError: + args = {} + content.append( + { + "type": "tool_use", + "id": tc.get("id", str(uuid.uuid4())), + "name": func.get("name", ""), + "input": args, + } + ) + if content: + messages.append({"role": "assistant", "content": content}) + elif msg.role == "tool": + messages.append( + { + "role": "user", + "content": [ + { + 
"type": "tool_result", + "tool_use_id": msg.tool_call_id or "", + "content": msg.content or "", + } + ], + } + ) + return messages + + +async def _execute_tool( + tool_name: str, tool_input: Any, handlers: dict[str, Any] +) -> tuple[str, bool]: + """Execute a tool and return (output, is_error).""" + handler = handlers.get(tool_name) + if not handler: + return f"Unknown tool: {tool_name}", True + + try: + result = await handler(tool_input) + output = result.get("content", [{}])[0].get("text", "") + is_error = result.get("isError", False) + return output, is_error + except Exception as e: + return f"Error: {str(e)}", True diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py new file mode 100644 index 0000000000..9396aa4f90 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py @@ -0,0 +1,299 @@ +"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format. + +This module provides the adapter layer that converts streaming messages from +the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that +the frontend expects. +""" + +import json +import logging +import uuid +from typing import Any, AsyncGenerator + +from backend.api.features.chat.response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamHeartbeat, + StreamStart, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, + StreamUsage, +) + +logger = logging.getLogger(__name__) + + +class SDKResponseAdapter: + """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format. + + This class maintains state during a streaming session to properly track + text blocks, tool calls, and message lifecycle. + """ + + def __init__(self, message_id: str | None = None): + """Initialize the adapter. 
+ + Args: + message_id: Optional message ID. If not provided, one will be generated. + """ + self.message_id = message_id or str(uuid.uuid4()) + self.text_block_id = str(uuid.uuid4()) + self.has_started_text = False + self.has_ended_text = False + self.current_tool_calls: dict[str, dict[str, Any]] = {} + self.task_id: str | None = None + + def set_task_id(self, task_id: str) -> None: + """Set the task ID for reconnection support.""" + self.task_id = task_id + + def convert_message(self, sdk_message: Any) -> list[StreamBaseResponse]: + """Convert a single SDK message to Vercel AI SDK format. + + Args: + sdk_message: A message from the Claude Agent SDK. + + Returns: + List of StreamBaseResponse objects (may be empty or multiple). + """ + responses: list[StreamBaseResponse] = [] + + # Handle different SDK message types - use class name since SDK uses dataclasses + class_name = type(sdk_message).__name__ + msg_subtype = getattr(sdk_message, "subtype", None) + + if class_name == "SystemMessage": + if msg_subtype == "init": + # Session initialization - emit start + responses.append( + StreamStart( + messageId=self.message_id, + taskId=self.task_id, + ) + ) + + elif class_name == "AssistantMessage": + # Assistant message with content blocks + content = getattr(sdk_message, "content", []) + for block in content: + # Check block type by class name (SDK uses dataclasses) or dict type + block_class = type(block).__name__ + block_type = block.get("type") if isinstance(block, dict) else None + + if block_class == "TextBlock" or block_type == "text": + # Text content + text = getattr(block, "text", None) or ( + block.get("text") if isinstance(block, dict) else "" + ) + + if text: + # Start text block if needed (or restart after tool calls) + if not self.has_started_text or self.has_ended_text: + # Generate new text block ID for text after tools + if self.has_ended_text: + self.text_block_id = str(uuid.uuid4()) + self.has_ended_text = False + 
responses.append(StreamTextStart(id=self.text_block_id)) + self.has_started_text = True + + # Emit text delta + responses.append( + StreamTextDelta( + id=self.text_block_id, + delta=text, + ) + ) + + elif block_class == "ToolUseBlock" or block_type == "tool_use": + # Tool call + tool_id_raw = getattr(block, "id", None) or ( + block.get("id") if isinstance(block, dict) else None + ) + tool_id: str = ( + str(tool_id_raw) if tool_id_raw else str(uuid.uuid4()) + ) + + tool_name_raw = getattr(block, "name", None) or ( + block.get("name") if isinstance(block, dict) else None + ) + tool_name: str = str(tool_name_raw) if tool_name_raw else "unknown" + + tool_input = getattr(block, "input", None) or ( + block.get("input") if isinstance(block, dict) else {} + ) + + # End text block if we were streaming text + if self.has_started_text and not self.has_ended_text: + responses.append(StreamTextEnd(id=self.text_block_id)) + self.has_ended_text = True + + # Emit tool input start + responses.append( + StreamToolInputStart( + toolCallId=tool_id, + toolName=tool_name, + ) + ) + + # Emit tool input available with full input + responses.append( + StreamToolInputAvailable( + toolCallId=tool_id, + toolName=tool_name, + input=tool_input if isinstance(tool_input, dict) else {}, + ) + ) + + # Track the tool call + self.current_tool_calls[tool_id] = { + "name": tool_name, + "input": tool_input, + } + + elif class_name in ("ToolResultMessage", "UserMessage"): + # Tool result - check for tool_result content + content = getattr(sdk_message, "content", []) + + for block in content: + block_class = type(block).__name__ + block_type = block.get("type") if isinstance(block, dict) else None + + if block_class == "ToolResultBlock" or block_type == "tool_result": + tool_use_id = getattr(block, "tool_use_id", None) or ( + block.get("tool_use_id") if isinstance(block, dict) else None + ) + result_content = getattr(block, "content", None) or ( + block.get("content") if isinstance(block, dict) else "" + 
) + is_error = getattr(block, "is_error", False) or ( + block.get("is_error", False) + if isinstance(block, dict) + else False + ) + + if tool_use_id: + tool_info = self.current_tool_calls.get(tool_use_id, {}) + tool_name = tool_info.get("name", "unknown") + + # Format the output + if isinstance(result_content, list): + # Extract text from content blocks + output_text = "" + for item in result_content: + if ( + isinstance(item, dict) + and item.get("type") == "text" + ): + output_text += item.get("text", "") + elif hasattr(item, "text"): + output_text += getattr(item, "text", "") + output = output_text + elif isinstance(result_content, str): + output = result_content + else: + output = json.dumps(result_content) + + responses.append( + StreamToolOutputAvailable( + toolCallId=tool_use_id, + toolName=tool_name, + output=output, + success=not is_error, + ) + ) + + elif class_name == "ResultMessage": + # Final result + if msg_subtype == "success": + # End text block if still open + if self.has_started_text and not self.has_ended_text: + responses.append(StreamTextEnd(id=self.text_block_id)) + self.has_ended_text = True + + # Emit finish + responses.append(StreamFinish()) + + elif msg_subtype in ("error", "error_during_execution"): + error_msg = getattr(sdk_message, "error", "Unknown error") + responses.append( + StreamError( + errorText=str(error_msg), + code="sdk_error", + ) + ) + responses.append(StreamFinish()) + + elif class_name == "ErrorMessage": + # Error message + error_msg = getattr(sdk_message, "message", None) or getattr( + sdk_message, "error", "Unknown error" + ) + responses.append( + StreamError( + errorText=str(error_msg), + code="sdk_error", + ) + ) + + return responses + + def create_heartbeat(self, tool_call_id: str | None = None) -> StreamHeartbeat: + """Create a heartbeat response.""" + return StreamHeartbeat(toolCallId=tool_call_id) + + def create_usage( + self, + prompt_tokens: int, + completion_tokens: int, + ) -> StreamUsage: + """Create a usage 
statistics response.""" + return StreamUsage( + promptTokens=prompt_tokens, + completionTokens=completion_tokens, + totalTokens=prompt_tokens + completion_tokens, + ) + + +async def adapt_sdk_stream( + sdk_stream: AsyncGenerator[Any, None], + message_id: str | None = None, + task_id: str | None = None, +) -> AsyncGenerator[StreamBaseResponse, None]: + """Adapt a Claude Agent SDK stream to Vercel AI SDK format. + + Args: + sdk_stream: The async generator from the Claude Agent SDK. + message_id: Optional message ID for the response. + task_id: Optional task ID for reconnection support. + + Yields: + StreamBaseResponse objects in Vercel AI SDK format. + """ + adapter = SDKResponseAdapter(message_id=message_id) + if task_id: + adapter.set_task_id(task_id) + + # Emit start immediately + yield StreamStart(messageId=adapter.message_id, taskId=task_id) + + try: + async for sdk_message in sdk_stream: + responses = adapter.convert_message(sdk_message) + for response in responses: + # Skip duplicate start messages + if isinstance(response, StreamStart): + continue + yield response + + except Exception as e: + logger.error(f"Error in SDK stream: {e}", exc_info=True) + yield StreamError( + errorText=f"Stream error: {str(e)}", + code="stream_error", + ) + yield StreamFinish() diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py new file mode 100644 index 0000000000..c07d3db534 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py @@ -0,0 +1,237 @@ +"""Security hooks for Claude Agent SDK integration. + +This module provides security hooks that validate tool calls before execution, +ensuring multi-user isolation and preventing unauthorized operations. 
+""" + +import logging +import re +from typing import Any, cast + +logger = logging.getLogger(__name__) + +# Tools that are blocked entirely (CLI/system access) +BLOCKED_TOOLS = { + "Bash", + "bash", + "shell", + "exec", + "terminal", + "command", + "Read", # Block raw file read - use workspace tools instead + "Write", # Block raw file write - use workspace tools instead + "Edit", # Block raw file edit - use workspace tools instead + "Glob", # Block raw file glob - use workspace tools instead + "Grep", # Block raw file grep - use workspace tools instead +} + +# Dangerous patterns in tool inputs +DANGEROUS_PATTERNS = [ + r"sudo", + r"rm\s+-rf", + r"dd\s+if=", + r"/etc/passwd", + r"/etc/shadow", + r"chmod\s+777", + r"curl\s+.*\|.*sh", + r"wget\s+.*\|.*sh", + r"eval\s*\(", + r"exec\s*\(", + r"__import__", + r"os\.system", + r"subprocess", +] + + +def _validate_tool_access(tool_name: str, tool_input: dict[str, Any]) -> dict[str, Any]: + """Validate that a tool call is allowed. + + Returns: + Empty dict to allow, or dict with hookSpecificOutput to deny + """ + # Block forbidden tools + if tool_name in BLOCKED_TOOLS: + logger.warning(f"Blocked tool access attempt: {tool_name}") + return { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": ( + f"Tool '{tool_name}' is not available. " + "Use the CoPilot-specific tools instead." 
+ ), + } + } + + # Check for dangerous patterns in tool input + input_str = str(tool_input) + + for pattern in DANGEROUS_PATTERNS: + if re.search(pattern, input_str, re.IGNORECASE): + logger.warning( + f"Blocked dangerous pattern in tool input: {pattern} in {tool_name}" + ) + return { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": "Input contains blocked pattern", + } + } + + return {} + + +def _validate_user_isolation( + tool_name: str, tool_input: dict[str, Any], user_id: str | None +) -> dict[str, Any]: + """Validate that tool calls respect user isolation.""" + # For workspace file tools, ensure path doesn't escape + if "workspace" in tool_name.lower(): + path = tool_input.get("path", "") or tool_input.get("file_path", "") + if path: + # Check for path traversal + if ".." in path or path.startswith("/"): + logger.warning( + f"Blocked path traversal attempt: {path} by user {user_id}" + ) + return { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": "Path traversal not allowed", + } + } + + return {} + + +def create_security_hooks(user_id: str | None) -> dict[str, Any]: + """Create the security hooks configuration for Claude Agent SDK. 
+ + Args: + user_id: Current user ID for isolation validation + + Returns: + Hooks configuration dict for ClaudeAgentOptions + """ + try: + from claude_agent_sdk import HookMatcher + from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput + + async def pre_tool_use_hook( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Combined pre-tool-use validation hook.""" + _ = context # unused but required by signature + # Extract tool info from the typed input + tool_name = cast(str, input_data.get("tool_name", "")) + tool_input = cast(dict[str, Any], input_data.get("tool_input", {})) + + # Validate basic tool access + result = _validate_tool_access(tool_name, tool_input) + if result: + return cast(SyncHookJSONOutput, result) + + # Validate user isolation + result = _validate_user_isolation(tool_name, tool_input, user_id) + if result: + return cast(SyncHookJSONOutput, result) + + # Log the usage + logger.debug( + f"[SDK Audit] Tool call: tool={tool_name}, " + f"user={user_id}, tool_use_id={tool_use_id}" + ) + + return cast(SyncHookJSONOutput, {}) + + return { + "PreToolUse": [ + HookMatcher( + matcher="*", + hooks=[pre_tool_use_hook], + ), + ], + } + except ImportError: + # Fallback for when SDK isn't available - return empty hooks + return {} + + +def create_strict_security_hooks( + user_id: str | None, + allowed_tools: list[str] | None = None, +) -> dict[str, Any]: + """Create strict security hooks that only allow specific tools. 
+ + Args: + user_id: Current user ID + allowed_tools: List of allowed tool names (defaults to CoPilot tools) + + Returns: + Hooks configuration dict + """ + try: + from claude_agent_sdk import HookMatcher + from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput + + from .tool_adapter import RAW_TOOL_NAMES + + tools_list = allowed_tools if allowed_tools is not None else RAW_TOOL_NAMES + allowed_set = set(tools_list) + + async def strict_pre_tool_use( + input_data: HookInput, + tool_use_id: str | None, + context: HookContext, + ) -> SyncHookJSONOutput: + """Strict validation that only allows whitelisted tools.""" + _ = context # unused but required by signature + tool_name = cast(str, input_data.get("tool_name", "")) + tool_input = cast(dict[str, Any], input_data.get("tool_input", {})) + + # Remove MCP prefix if present + clean_name = tool_name + if tool_name.startswith("mcp__copilot__"): + clean_name = tool_name.replace("mcp__copilot__", "") + + if clean_name not in allowed_set: + logger.warning(f"Blocked non-whitelisted tool: {tool_name}") + return cast( + SyncHookJSONOutput, + { + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": "deny", + "permissionDecisionReason": ( + f"Tool '{tool_name}' is not in the allowed list" + ), + } + }, + ) + + # Run standard validations + result = _validate_tool_access(tool_name, tool_input) + if result: + return cast(SyncHookJSONOutput, result) + + result = _validate_user_isolation(tool_name, tool_input, user_id) + if result: + return cast(SyncHookJSONOutput, result) + + logger.debug( + f"[SDK Audit] Tool call: tool={tool_name}, " + f"user={user_id}, tool_use_id={tool_use_id}" + ) + return cast(SyncHookJSONOutput, {}) + + return { + "PreToolUse": [ + HookMatcher(matcher="*", hooks=[strict_pre_tool_use]), + ], + } + except ImportError: + return {} diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py 
b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py new file mode 100644 index 0000000000..ce6c6e3bd4 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py @@ -0,0 +1,428 @@ +"""Claude Agent SDK service layer for CoPilot chat completions.""" + +import asyncio +import logging +import uuid +from collections.abc import AsyncGenerator +from typing import Any + +import openai + +from backend.data.understanding import ( + format_understanding_for_prompt, + get_business_understanding, +) +from backend.util.exceptions import NotFoundError + +from ..config import ChatConfig +from ..model import ( + ChatMessage, + ChatSession, + get_chat_session, + update_session_title, + upsert_chat_session, +) +from ..response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamStart, + StreamTextDelta, + StreamToolOutputAvailable, +) +from ..tracking import track_user_message +from .anthropic_fallback import stream_with_anthropic +from .response_adapter import SDKResponseAdapter +from .security_hooks import create_security_hooks +from .tool_adapter import ( + COPILOT_TOOL_NAMES, + create_copilot_mcp_server, + set_execution_context, +) + +logger = logging.getLogger(__name__) +config = ChatConfig() + +DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations. + +Here is everything you know about the current user from previous interactions: + + +{users_information} + + +## YOUR CORE MANDATE + +You are action-oriented. Your success is measured by: +- **Value Delivery**: Does the user think "wow, that was amazing" or "what was the point"? 
+- **Demonstrable Proof**: Show working automations, not descriptions of what's possible +- **Time Saved**: Focus on tangible efficiency gains +- **Quality Output**: Deliver results that meet or exceed expectations + +## YOUR WORKFLOW + +Adapt flexibly to the conversation context. Not every interaction requires all stages: + +1. **Explore & Understand**: Learn about the user's business, tasks, and goals. Use `add_understanding` to capture important context that will improve future conversations. + +2. **Assess Automation Potential**: Help the user understand whether and how AI can automate their task. + +3. **Prepare for AI**: Provide brief, actionable guidance on prerequisites (data, access, etc.). + +4. **Discover or Create Agents**: + - **Always check the user's library first** with `find_library_agent` (these may be customized to their needs) + - Search the marketplace with `find_agent` for pre-built automations + - Find reusable components with `find_block` + - Create custom solutions with `create_agent` if nothing suitable exists + - Modify existing library agents with `edit_agent` + +5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`. + +6. **Show Results**: Display outputs using `agent_output`. 
+ +## BEHAVIORAL GUIDELINES + +**Be Concise:** +- Target 2-5 short lines maximum +- Make every word count—no repetition or filler +- Use lightweight structure for scannability (bullets, numbered lists, short prompts) +- Avoid jargon (blocks, slugs, cron) unless the user asks + +**Be Proactive:** +- Suggest next steps before being asked +- Anticipate needs based on conversation context and user information +- Look for opportunities to expand scope when relevant +- Reveal capabilities through action, not explanation + +**Use Tools Effectively:** +- Select the right tool for each task +- **Always check `find_library_agent` before searching the marketplace** +- Use `add_understanding` to capture valuable business context +- When tool calls fail, try alternative approaches + +## CRITICAL REMINDER + +You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation.""" + + +async def _build_system_prompt( + user_id: str | None, has_conversation_history: bool = False +) -> tuple[str, Any]: + """Build the system prompt with user's business understanding context. + + Args: + user_id: The user ID to fetch understanding for. + has_conversation_history: Whether there's existing conversation history. + If True, we don't tell the model to greet/introduce (since they're + already in a conversation). + """ + understanding = None + if user_id: + try: + understanding = await get_business_understanding(user_id) + except Exception as e: + logger.warning(f"Failed to fetch business understanding: {e}") + + if understanding: + context = format_understanding_for_prompt(understanding) + elif has_conversation_history: + # Don't tell model to greet if there's conversation history + context = "No prior understanding saved yet. Continue the existing conversation naturally." + else: + context = "This is the first time you are meeting the user. 
Greet them and introduce them to the platform" + + return DEFAULT_SYSTEM_PROMPT.format(users_information=context), understanding + + +def _format_conversation_history(session: ChatSession) -> str: + """Format conversation history as a prompt context. + + The Claude Agent SDK doesn't support replaying full conversation history, + so we include it as context in the prompt. + """ + if not session.messages: + return "" + + # Get all messages except the last user message (which will be the prompt) + messages = session.messages[:-1] if session.messages else [] + if not messages: + return "" + + history_parts = [] + history_parts.append("") + + for msg in messages: + if msg.role == "user": + history_parts.append(f"User: {msg.content or ''}") + elif msg.role == "assistant": + content = msg.content or "" + # Truncate long assistant responses + if len(content) > 500: + content = content[:500] + "..." + history_parts.append(f"Assistant: {content}") + # Include tool calls summary if any + if msg.tool_calls: + for tc in msg.tool_calls: + func = tc.get("function", {}) + tool_name = func.get("name", "unknown") + history_parts.append(f" [Called tool: {tool_name}]") + elif msg.role == "tool": + # Summarize tool results + result = msg.content or "" + if len(result) > 200: + result = result[:200] + "..." + history_parts.append(f" [Tool result: {result}]") + + history_parts.append("") + history_parts.append("") + history_parts.append( + "Continue this conversation. 
Respond to the user's latest message:" + ) + history_parts.append("") + + return "\n".join(history_parts) + + +async def _generate_session_title( + message: str, + user_id: str | None = None, + session_id: str | None = None, +) -> str | None: + """Generate a concise title for a chat session.""" + from backend.util.settings import Settings + + settings = Settings() + try: + # Build extra_body for OpenRouter tracing + extra_body: dict[str, Any] = { + "posthogProperties": {"environment": settings.config.app_env.value}, + } + if user_id: + extra_body["user"] = user_id[:128] + extra_body["posthogDistinctId"] = user_id + if session_id: + extra_body["session_id"] = session_id[:128] + + client = openai.AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) + response = await client.chat.completions.create( + model=config.title_model, + messages=[ + { + "role": "system", + "content": "Generate a very short title (3-6 words) for a chat conversation based on the user's first message. Return ONLY the title, no quotes or punctuation.", + }, + {"role": "user", "content": message[:500]}, + ], + max_tokens=20, + extra_body=extra_body, + ) + title = response.choices[0].message.content + if title: + title = title.strip().strip("\"'") + return title[:47] + "..." if len(title) > 50 else title + return None + except Exception as e: + logger.warning(f"Failed to generate session title: {e}") + return None + + +async def stream_chat_completion_sdk( + session_id: str, + message: str | None = None, + tool_call_response: str | None = None, # noqa: ARG001 + is_user_message: bool = True, + user_id: str | None = None, + retry_count: int = 0, # noqa: ARG001 + session: ChatSession | None = None, + context: dict[str, str] | None = None, # noqa: ARG001 +) -> AsyncGenerator[StreamBaseResponse, None]: + """Stream chat completion using Claude Agent SDK. + + Drop-in replacement for stream_chat_completion with improved reliability. 
+ """ + + if session is None: + session = await get_chat_session(session_id, user_id) + + if not session: + raise NotFoundError( + f"Session {session_id} not found. Please create a new session first." + ) + + if message: + session.messages.append( + ChatMessage( + role="user" if is_user_message else "assistant", content=message + ) + ) + if is_user_message: + track_user_message( + user_id=user_id, session_id=session_id, message_length=len(message) + ) + + session = await upsert_chat_session(session) + + # Generate title for new sessions (first user message) + if is_user_message and not session.title: + user_messages = [m for m in session.messages if m.role == "user"] + if len(user_messages) == 1: + first_message = user_messages[0].content or message or "" + if first_message: + asyncio.create_task( + _update_title_async(session_id, first_message, user_id) + ) + + # Check if there's conversation history (more than just the current message) + has_history = len(session.messages) > 1 + system_prompt, _ = await _build_system_prompt( + user_id, has_conversation_history=has_history + ) + set_execution_context(user_id, session, None) + + message_id = str(uuid.uuid4()) + text_block_id = str(uuid.uuid4()) + task_id = str(uuid.uuid4()) + + yield StreamStart(messageId=message_id, taskId=task_id) + + # Track whether the stream completed normally via ResultMessage + stream_completed = False + + try: + try: + from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient + + # Create MCP server with CoPilot tools + mcp_server = create_copilot_mcp_server() + + options = ClaudeAgentOptions( + system_prompt=system_prompt, + mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type] + allowed_tools=COPILOT_TOOL_NAMES, + hooks=create_security_hooks(user_id), # type: ignore[arg-type] + continue_conversation=True, # Enable conversation continuation + ) + + adapter = SDKResponseAdapter(message_id=message_id) + adapter.set_task_id(task_id) + + async with 
ClaudeSDKClient(options=options) as client: + # Build prompt with conversation history for context + # The SDK doesn't support replaying full conversation history, + # so we include it as context in the prompt + current_message = message or "" + if not current_message and session.messages: + last_user = [m for m in session.messages if m.role == "user"] + if last_user: + current_message = last_user[-1].content or "" + + # Include conversation history if there are prior messages + if len(session.messages) > 1: + history_context = _format_conversation_history(session) + prompt = f"{history_context}{current_message}" + else: + prompt = current_message + + await client.query(prompt, session_id=session_id) + + # Track assistant response to save to session + # We may need multiple assistant messages if text comes after tool results + assistant_response = ChatMessage(role="assistant", content="") + has_appended_assistant = False + has_tool_results = False # Track if we've received tool results + + # Receive messages from the SDK + async for sdk_msg in client.receive_messages(): + + for response in adapter.convert_message(sdk_msg): + if isinstance(response, StreamStart): + continue + yield response + + # Accumulate text deltas into assistant response + if isinstance(response, StreamTextDelta): + delta = response.delta or "" + # After tool results, create new assistant message for post-tool text + if has_tool_results and has_appended_assistant: + assistant_response = ChatMessage( + role="assistant", content=delta + ) + session.messages.append(assistant_response) + has_tool_results = False + else: + assistant_response.content = ( + assistant_response.content or "" + ) + delta + if not has_appended_assistant: + session.messages.append(assistant_response) + has_appended_assistant = True + + elif isinstance(response, StreamToolOutputAvailable): + session.messages.append( + ChatMessage( + role="tool", + content=( + response.output + if isinstance(response.output, str) + else 
str(response.output) + ), + tool_call_id=response.toolCallId, + ) + ) + has_tool_results = True + + elif isinstance(response, StreamFinish): + stream_completed = True + + # Break out of the message loop if we received finish signal + if stream_completed: + break + + # Ensure assistant response is saved even if no text deltas + # (e.g., only tool calls were made) + if assistant_response.content and not has_appended_assistant: + session.messages.append(assistant_response) + + except ImportError: + logger.warning( + "[SDK] claude-agent-sdk not available, using Anthropic fallback" + ) + async for response in stream_with_anthropic( + session, system_prompt, text_block_id + ): + yield response + + # Save the session with accumulated messages + await upsert_chat_session(session) + logger.debug( + f"[SDK] Session {session_id} saved with {len(session.messages)} messages" + ) + # Always yield StreamFinish to signal completion to the caller + # The adapter yields StreamFinish for the SSE stream, but we need to + # yield it here so the background task in routes.py knows to call mark_task_completed + yield StreamFinish() + + except Exception as e: + logger.error(f"[SDK] Error: {e}", exc_info=True) + # Save session even on error to preserve any partial response + try: + await upsert_chat_session(session) + except Exception as save_err: + logger.error(f"[SDK] Failed to save session on error: {save_err}") + yield StreamError(errorText=f"An error occurred: {str(e)}", code="sdk_error") + yield StreamFinish() + + +async def _update_title_async( + session_id: str, message: str, user_id: str | None = None +) -> None: + """Background task to update session title.""" + try: + title = await _generate_session_title( + message, user_id=user_id, session_id=session_id + ) + if title: + await update_session_title(session_id, title) + logger.debug(f"[SDK] Generated title for {session_id}: {title}") + except Exception as e: + logger.warning(f"[SDK] Failed to update session title: {e}") diff 
--git a/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py new file mode 100644 index 0000000000..39d9e27561 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py @@ -0,0 +1,213 @@ +"""Tool adapter for wrapping existing CoPilot tools as Claude Agent SDK MCP tools. + +This module provides the adapter layer that converts existing BaseTool implementations +into in-process MCP tools that can be used with the Claude Agent SDK. +""" + +import json +import logging +from contextvars import ContextVar +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.chat.tools import TOOL_REGISTRY +from backend.api.features.chat.tools.base import BaseTool + +logger = logging.getLogger(__name__) + +# Context variables to pass user/session info to tool execution +_current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default=None) +_current_session: ContextVar[ChatSession | None] = ContextVar( + "current_session", default=None +) +_current_tool_call_id: ContextVar[str | None] = ContextVar( + "current_tool_call_id", default=None +) + + +def set_execution_context( + user_id: str | None, + session: ChatSession, + tool_call_id: str | None = None, +) -> None: + """Set the execution context for tool calls. + + This must be called before streaming begins to ensure tools have access + to user_id and session information. + """ + _current_user_id.set(user_id) + _current_session.set(session) + _current_tool_call_id.set(tool_call_id) + + +def get_execution_context() -> tuple[str | None, ChatSession | None, str | None]: + """Get the current execution context.""" + return ( + _current_user_id.get(), + _current_session.get(), + _current_tool_call_id.get(), + ) + + +def create_tool_handler(base_tool: BaseTool): + """Create an async handler function for a BaseTool. 
+ + This wraps the existing BaseTool._execute method to be compatible + with the Claude Agent SDK MCP tool format. + """ + + async def tool_handler(args: dict[str, Any]) -> dict[str, Any]: + """Execute the wrapped tool and return MCP-formatted response.""" + user_id, session, tool_call_id = get_execution_context() + + if session is None: + return { + "content": [ + { + "type": "text", + "text": json.dumps( + { + "error": "No session context available", + "type": "error", + } + ), + } + ], + "isError": True, + } + + try: + # Call the existing tool's execute method + result = await base_tool.execute( + user_id=user_id, + session=session, + tool_call_id=tool_call_id or "sdk-call", + **args, + ) + + # The result is a StreamToolOutputAvailable, extract the output + return { + "content": [ + { + "type": "text", + "text": ( + result.output + if isinstance(result.output, str) + else json.dumps(result.output) + ), + } + ], + "isError": not result.success, + } + + except Exception as e: + logger.error(f"Error executing tool {base_tool.name}: {e}", exc_info=True) + return { + "content": [ + { + "type": "text", + "text": json.dumps( + { + "error": str(e), + "type": "error", + "message": f"Failed to execute {base_tool.name}", + } + ), + } + ], + "isError": True, + } + + return tool_handler + + +def get_tool_definitions() -> list[dict[str, Any]]: + """Get all tool definitions in MCP format. + + Returns a list of tool definitions that can be used with + create_sdk_mcp_server or as raw tool definitions. + """ + tool_definitions = [] + + for tool_name, base_tool in TOOL_REGISTRY.items(): + tool_def = { + "name": tool_name, + "description": base_tool.description, + "inputSchema": { + "type": "object", + "properties": base_tool.parameters.get("properties", {}), + "required": base_tool.parameters.get("required", []), + }, + } + tool_definitions.append(tool_def) + + return tool_definitions + + +def get_tool_handlers() -> dict[str, Any]: + """Get all tool handlers mapped by name. 
+ + Returns a dictionary mapping tool names to their handler functions. + """ + handlers = {} + + for tool_name, base_tool in TOOL_REGISTRY.items(): + handlers[tool_name] = create_tool_handler(base_tool) + + return handlers + + +# Create the MCP server configuration +def create_copilot_mcp_server(): + """Create an in-process MCP server configuration for CoPilot tools. + + This can be passed to ClaudeAgentOptions.mcp_servers. + + Note: The actual SDK MCP server creation depends on the claude-agent-sdk + package being available. This function returns the configuration that + can be used with the SDK. + """ + try: + from claude_agent_sdk import create_sdk_mcp_server, tool + + # Create decorated tool functions + sdk_tools = [] + + for tool_name, base_tool in TOOL_REGISTRY.items(): + # Get the handler + handler = create_tool_handler(base_tool) + + # Create the decorated tool + # The @tool decorator expects (name, description, schema) + decorated = tool( + tool_name, + base_tool.description, + base_tool.parameters.get("properties", {}), + )(handler) + + sdk_tools.append(decorated) + + # Create the MCP server + server = create_sdk_mcp_server( + name="copilot", + version="1.0.0", + tools=sdk_tools, + ) + + return server + + except ImportError: + logger.warning( + "claude-agent-sdk not available, returning tool definitions only" + ) + return { + "tools": get_tool_definitions(), + "handlers": get_tool_handlers(), + } + + +# List of tool names for allowed_tools configuration +COPILOT_TOOL_NAMES = [f"mcp__copilot__{name}" for name in TOOL_REGISTRY.keys()] + +# Also export the raw tool names for flexibility +RAW_TOOL_NAMES = list(TOOL_REGISTRY.keys()) diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py index 88a5023e2b..35b7681482 100644 --- a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py +++ 
b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -555,6 +555,31 @@ async def get_active_task_for_session( if task_user_id and user_id != task_user_id: continue + # Skip stale tasks (running for more than 1 minute is suspicious) + created_at_str = meta.get("created_at", "") + if created_at_str: + try: + created_at = datetime.fromisoformat(created_at_str) + age_seconds = ( + datetime.now(timezone.utc) - created_at + ).total_seconds() + if ( + age_seconds > 60 + ): # 1 minute - tasks orphaned by server restart + logger.warning( + f"[TASK_LOOKUP] Skipping stale task {task_id[:8]}... " + f"(age={age_seconds:.0f}s)" + ) + # Mark stale task as failed to clean it up + await mark_task_completed(task_id, "failed") + continue + except (ValueError, TypeError): + pass # If we can't parse the date, continue with the task + + logger.info( + f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..." + ) + # Get the last message ID from Redis Stream stream_key = _get_task_stream_key(task_id) last_id = "0-0" diff --git a/autogpt_platform/backend/poetry.lock b/autogpt_platform/backend/poetry.lock index 91ac358ade..72b3fafa09 100644 --- a/autogpt_platform/backend/poetry.lock +++ b/autogpt_platform/backend/poetry.lock @@ -825,6 +825,29 @@ files = [ {file = "charset_normalizer-3.4.2.tar.gz", hash = "sha256:5baececa9ecba31eff645232d59845c07aa030f0c81ee70184a90d35099a0e63"}, ] +[[package]] +name = "claude-agent-sdk" +version = "0.1.29" +description = "Python SDK for Claude Code" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "claude_agent_sdk-0.1.29-py3-none-macosx_11_0_arm64.whl", hash = "sha256:811de31c92bd90250ebbfd79758c538766c672abde244ae0f7dec2d02ed5a1f7"}, + {file = "claude_agent_sdk-0.1.29-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:6279360d251ce8b8e9d922b03e3492c88736648e7f5e7c9f301fde0eef37928f"}, + {file = "claude_agent_sdk-0.1.29-py3-none-manylinux_2_17_x86_64.whl", hash = 
"sha256:4d1f01fe5f7252126f35808e2887a40125b784ac0dbf73b9509a4065a4766149"}, + {file = "claude_agent_sdk-0.1.29-py3-none-win_amd64.whl", hash = "sha256:67fb58a72f0dd54d079c538078130cc8c888bc60652d3d396768ffaee6716467"}, + {file = "claude_agent_sdk-0.1.29.tar.gz", hash = "sha256:ece32436a81fc015ca325d4121edeb5627ae9af15b5079f7b42d5eda9dcdb7a3"}, +] + +[package.dependencies] +anyio = ">=4.0.0" +mcp = ">=0.1.0" +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +dev = ["anyio[trio] (>=4.0.0)", "mypy (>=1.0.0)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.20.0)", "pytest-cov (>=4.0.0)", "ruff (>=0.1.0)"] + [[package]] name = "cleo" version = "2.1.0" @@ -2320,6 +2343,18 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "httpx-sse" +version = "0.4.3" +description = "Consume Server-Sent Event (SSE) messages with HTTPX." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc"}, + {file = "httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d"}, +] + [[package]] name = "huggingface-hub" version = "0.34.4" @@ -2981,6 +3016,39 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] +[[package]] +name = "mcp" +version = "1.26.0" +description = "Model Context Protocol SDK" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "mcp-1.26.0-py3-none-any.whl", hash = "sha256:904a21c33c25aa98ddbeb47273033c435e595bbacfdb177f4bd87f6dceebe1ca"}, + {file = "mcp-1.26.0.tar.gz", hash = "sha256:db6e2ef491eecc1a0d93711a76f28dec2e05999f93afd48795da1c1137142c66"}, +] + +[package.dependencies] +anyio = ">=4.5" +httpx = ">=0.27.1" +httpx-sse = ">=0.4" +jsonschema = ">=4.20.0" +pydantic = 
">=2.11.0,<3.0.0" +pydantic-settings = ">=2.5.2" +pyjwt = {version = ">=2.10.1", extras = ["crypto"]} +python-multipart = ">=0.0.9" +pywin32 = {version = ">=310", markers = "sys_platform == \"win32\""} +sse-starlette = ">=1.6.1" +starlette = ">=0.27" +typing-extensions = ">=4.9.0" +typing-inspection = ">=0.4.1" +uvicorn = {version = ">=0.31.1", markers = "sys_platform != \"emscripten\""} + +[package.extras] +cli = ["python-dotenv (>=1.0.0)", "typer (>=0.16.0)"] +rich = ["rich (>=13.9.4)"] +ws = ["websockets (>=15.0.1)"] + [[package]] name = "mdurl" version = "0.1.2" @@ -4605,7 +4673,6 @@ files = [ {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, - {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, @@ -5210,7 +5277,7 @@ description = "Python for Window Extensions" optional = false python-versions = "*" groups = ["main"] -markers = "platform_system == \"Windows\"" +markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = 
"pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"}, {file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"}, @@ -6195,6 +6262,27 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sse-starlette" +version = "3.0.3" +description = "SSE plugin for Starlette" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "sse_starlette-3.0.3-py3-none-any.whl", hash = "sha256:af5bf5a6f3933df1d9c7f8539633dc8444ca6a97ab2e2a7cd3b6e431ac03a431"}, + {file = "sse_starlette-3.0.3.tar.gz", hash = "sha256:88cfb08747e16200ea990c8ca876b03910a23b547ab3bd764c0d8eb81019b971"}, +] + +[package.dependencies] +anyio = ">=4.7.0" + +[package.extras] +daphne = ["daphne (>=4.2.0)"] +examples = ["aiosqlite (>=0.21.0)", "fastapi (>=0.115.12)", "sqlalchemy[asyncio] (>=2.0.41)", "starlette (>=0.49.1)", "uvicorn (>=0.34.0)"] +granian = ["granian (>=2.3.1)"] +uvicorn = ["uvicorn (>=0.34.0)"] + [[package]] name = "stagehand" version = "0.5.1" @@ -7512,4 +7600,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "ee5742dc1a9df50dfc06d4b26a1682cbb2b25cab6b79ce5625ec272f93e4f4bf" +content-hash = "84170f26db93731b93b7646fb29ec6b64b4312337641b65cb36b21dbe3f14d8c" diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index fe263e47c0..0f25a23525 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -13,6 +13,7 @@ aio-pika = "^9.5.5" aiohttp = "^3.10.0" aiodns = "^3.5.0" anthropic = "^0.59.0" +claude-agent-sdk = "^0.1.0" apscheduler = "^3.11.1" autogpt-libs = { path = "../autogpt_libs", develop = true } bleach = { extras = ["css"], version = "^6.2.0" }