From 5efb80d47b476892bd22fcd3d9aacc3cfebcf38f Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Fri, 6 Feb 2026 13:25:10 +0400 Subject: [PATCH] fix(backend/chat): Address PR review comments for Claude SDK integration - Add StreamFinish after ErrorMessage in response adapter - Fix str.replace to removeprefix in security hooks - Apply max_context_messages limit as safety guard in history formatting - Add empty prompt guard before sending to SDK - Sanitize error messages to avoid exposing internal details - Fix fire-and-forget asyncio.create_task by storing task reference - Fix tool_calls population on assistant messages - Rewrite Anthropic fallback to persist messages and merge consecutive roles - Only use ANTHROPIC_API_KEY for fallback (not OpenRouter keys) - Fix IndexError when tool result content list is empty --- .../features/chat/sdk/anthropic_fallback.py | 157 +++++++++++++++--- .../api/features/chat/sdk/response_adapter.py | 1 + .../api/features/chat/sdk/security_hooks.py | 4 +- .../backend/api/features/chat/sdk/service.py | 60 ++++++- 4 files changed, 189 insertions(+), 33 deletions(-) diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py b/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py index a9977f12f4..af5c78c21a 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/anthropic_fallback.py @@ -11,8 +11,7 @@ import uuid from collections.abc import AsyncGenerator from typing import Any, cast -from ..config import ChatConfig -from ..model import ChatSession +from ..model import ChatMessage, ChatSession from ..response_model import ( StreamBaseResponse, StreamError, @@ -28,7 +27,6 @@ from ..response_model import ( from .tool_adapter import get_tool_definitions, get_tool_handlers logger = logging.getLogger(__name__) -config = ChatConfig() async def stream_with_anthropic( @@ -36,13 +34,19 @@ async def stream_with_anthropic( system_prompt: str, text_block_id: str, ) -> AsyncGenerator[StreamBaseResponse, None]: - """Stream using Anthropic SDK directly with tool calling support.""" + """Stream using Anthropic SDK directly with tool calling support. + + This function accumulates messages into the session for persistence. + The caller should NOT yield an additional StreamFinish - this function handles it. + """ import anthropic - api_key = os.getenv("ANTHROPIC_API_KEY") or config.api_key + # Only use ANTHROPIC_API_KEY - don't fall back to OpenRouter keys + api_key = os.getenv("ANTHROPIC_API_KEY") if not api_key: yield StreamError( - errorText="ANTHROPIC_API_KEY not configured", code="config_error" + errorText="ANTHROPIC_API_KEY not configured for fallback", + code="config_error", ) yield StreamFinish() return @@ -69,6 +73,8 @@ async def stream_with_anthropic( has_started_text = False max_iterations = 10 + accumulated_text = "" + accumulated_tool_calls: list[dict[str, Any]] = [] for _ in range(max_iterations): try: @@ -94,6 +100,7 @@ async def stream_with_anthropic( elif event.type == "content_block_delta": delta = event.delta if hasattr(delta, "type") and delta.type == "text_delta": + accumulated_text += delta.text yield StreamTextDelta(id=text_block_id, delta=delta.text) final_message = await stream.get_final_message() @@ -122,6 +129,22 @@ async def stream_with_anthropic( } ) + # Track tool call for session persistence + accumulated_tool_calls.append( + { + "id": block.id, + "type": "function", + "function": { + "name": block.name, + "arguments": json.dumps( + block.input + if isinstance(block.input, dict) + else {} + ), + }, + } + ) + yield StreamToolInputAvailable( toolCallId=block.id, toolName=block.name, @@ -141,6 +164,15 @@ async def stream_with_anthropic( success=not is_error, ) + # Save tool result to session + session.messages.append( + ChatMessage( + role="tool", + content=output, + tool_call_id=block.id, + ) + ) + tool_results.append( { "type": "tool_result", @@ -150,6 +182,22 @@ async def stream_with_anthropic( } ) + # Save assistant message with tool calls to session + session.messages.append( + ChatMessage( + role="assistant", + content=accumulated_text or None, + tool_calls=( + accumulated_tool_calls + if accumulated_tool_calls + else None + ), + ) + ) + # Reset for next iteration + accumulated_text = "" + accumulated_tool_calls = [] + anthropic_messages.append( {"role": "assistant", "content": assistant_content} ) @@ -160,6 +208,12 @@ async def stream_with_anthropic( if has_started_text: yield StreamTextEnd(id=text_block_id) + # Save final assistant response to session + if accumulated_text: + session.messages.append( + ChatMessage(role="assistant", content=accumulated_text) + ) + yield StreamUsage( promptTokens=final_message.usage.input_tokens, completionTokens=final_message.usage.output_tokens, @@ -171,7 +225,10 @@ async def stream_with_anthropic( except Exception as e: logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True) - yield StreamError(errorText=f"Error: {str(e)}", code="anthropic_error") + yield StreamError( + errorText="An error occurred. Please try again.", + code="anthropic_error", + ) yield StreamFinish() return @@ -180,11 +237,15 @@ async def stream_with_anthropic( def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]: - """Convert session messages to Anthropic format.""" - messages = [] + """Convert session messages to Anthropic format. + + Handles merging consecutive same-role messages (Anthropic requires alternating roles). + """ + messages: list[dict[str, Any]] = [] + for msg in session.messages: if msg.role == "user": - messages.append({"role": "user", "content": msg.content or ""}) + new_msg = {"role": "user", "content": msg.content or ""} elif msg.role == "assistant": content: list[dict[str, Any]] = [] if msg.content: @@ -207,21 +268,61 @@ def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]: } ) if content: - messages.append({"role": "assistant", "content": content}) + new_msg = {"role": "assistant", "content": content} + else: + continue # Skip empty assistant messages elif msg.role == "tool": - messages.append( - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": msg.tool_call_id or "", - "content": msg.content or "", - } - ], - } - ) - return messages + new_msg = { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": msg.tool_call_id or "", + "content": msg.content or "", + } + ], + } + else: + continue + + messages.append(new_msg) + + # Merge consecutive same-role messages (Anthropic requires alternating roles) + return _merge_consecutive_roles(messages) + + +def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Merge consecutive messages with the same role. + + Anthropic API requires alternating user/assistant roles. + """ + if not messages: + return [] + + merged: list[dict[str, Any]] = [] + for msg in messages: + if merged and merged[-1]["role"] == msg["role"]: + # Merge with previous message + prev_content = merged[-1]["content"] + new_content = msg["content"] + + # Normalize both to list-of-blocks form + if isinstance(prev_content, str): + prev_content = [{"type": "text", "text": prev_content}] + if isinstance(new_content, str): + new_content = [{"type": "text", "text": new_content}] + + # Ensure both are lists + if not isinstance(prev_content, list): + prev_content = [prev_content] + if not isinstance(new_content, list): + new_content = [new_content] + + merged[-1]["content"] = prev_content + new_content + else: + merged.append(msg) + + return merged async def _execute_tool( @@ -234,7 +335,13 @@ async def _execute_tool( try: result = await handler(tool_input) - output = result.get("content", [{}])[0].get("text", "") + # Safely extract output - handle empty or missing content + content = result.get("content") or [] + if content and isinstance(content, list) and len(content) > 0: + first_item = content[0] + output = first_item.get("text", "") if isinstance(first_item, dict) else "" + else: + output = "" is_error = result.get("isError", False) return output, is_error except Exception as e: diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py index 9396aa4f90..5c740d6090 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py @@ -239,6 +239,7 @@ class SDKResponseAdapter: code="sdk_error", ) ) + responses.append(StreamFinish()) return responses diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py index bf5a909e09..8eac335867 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py @@ -237,9 +237,7 @@ def create_strict_security_hooks( tool_input = cast(dict[str, Any], input_data.get("tool_input", {})) # Remove MCP prefix if present - clean_name = tool_name - if tool_name.startswith("mcp__copilot__"): - clean_name = tool_name.replace("mcp__copilot__", "") + clean_name = tool_name.removeprefix("mcp__copilot__") if clean_name not in allowed_set: logger.warning(f"Blocked non-whitelisted tool: {tool_name}") diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py index cbc15f1d6c..a723589e9f 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py @@ -1,6 +1,7 @@ """Claude Agent SDK service layer for CoPilot chat completions.""" import asyncio +import json import logging import uuid from collections.abc import AsyncGenerator @@ -28,6 +29,7 @@ from ..response_model import ( StreamFinish, StreamStart, StreamTextDelta, + StreamToolInputAvailable, StreamToolOutputAvailable, ) from ..tracking import track_user_message @@ -43,6 +45,9 @@ from .tool_adapter import ( logger = logging.getLogger(__name__) config = ChatConfig() +# Set to hold background tasks to prevent garbage collection +_background_tasks: set[asyncio.Task[Any]] = set() + DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations. Here is everything you know about the current user from previous interactions: @@ -137,8 +142,8 @@ async def _build_system_prompt( def _format_conversation_history(session: ChatSession) -> str: """Format conversation history as a prompt context. - The SDK handles context compaction automatically, so we pass full history - without manual truncation. The SDK will intelligently summarize if needed. + The SDK handles context compaction automatically, but we apply + max_context_messages as a safety guard to limit initial prompt size. """ if not session.messages: return "" @@ -148,6 +153,12 @@ def _format_conversation_history(session: ChatSession) -> str: if not messages: return "" + # Apply max_context_messages limit as a safety guard + # (SDK handles compaction, but this prevents excessively large initial prompts) + max_messages = config.max_context_messages + if len(messages) > max_messages: + messages = messages[-max_messages:] + history_parts = [""] for msg in messages: @@ -261,9 +272,12 @@ async def stream_chat_completion_sdk( if len(user_messages) == 1: first_message = user_messages[0].content or message or "" if first_message: - asyncio.create_task( + task = asyncio.create_task( _update_title_async(session_id, first_message, user_id) ) + # Store reference to prevent garbage collection + _background_tasks.add(task) + task.add_done_callback(_background_tasks.discard) # Check if there's conversation history (more than just the current message) has_history = len(session.messages) > 1 @@ -316,11 +330,21 @@ async def stream_chat_completion_sdk( else: prompt = current_message + # Guard against empty prompts + if not prompt.strip(): + yield StreamError( + errorText="Message cannot be empty.", + code="empty_prompt", + ) + yield StreamFinish() + return + await client.query(prompt, session_id=session_id) # Track assistant response to save to session # We may need multiple assistant messages if text comes after tool results assistant_response = ChatMessage(role="assistant", content="") + accumulated_tool_calls: list[dict[str, Any]] = [] has_appended_assistant = False has_tool_results = False # Track if we've received tool results @@ -340,6 +364,7 @@ async def stream_chat_completion_sdk( assistant_response = ChatMessage( role="assistant", content=delta ) + accumulated_tool_calls = [] # Reset for new message session.messages.append(assistant_response) has_tool_results = False else: @@ -350,6 +375,25 @@ async def stream_chat_completion_sdk( session.messages.append(assistant_response) has_appended_assistant = True + # Track tool calls on the assistant message + elif isinstance(response, StreamToolInputAvailable): + accumulated_tool_calls.append( + { + "id": response.toolCallId, + "type": "function", + "function": { + "name": response.toolName, + "arguments": json.dumps(response.input or {}), + }, + } + ) + # Update assistant message with tool calls + assistant_response.tool_calls = accumulated_tool_calls + # Append assistant message if not already (tool-only response) + if not has_appended_assistant: + session.messages.append(assistant_response) + has_appended_assistant = True + elif isinstance(response, StreamToolOutputAvailable): session.messages.append( ChatMessage( @@ -373,7 +417,9 @@ async def stream_chat_completion_sdk( # Ensure assistant response is saved even if no text deltas # (e.g., only tool calls were made) - if assistant_response.content and not has_appended_assistant: + if ( + assistant_response.content or assistant_response.tool_calls + ) and not has_appended_assistant: session.messages.append(assistant_response) except ImportError: @@ -402,7 +448,11 @@ async def stream_chat_completion_sdk( await upsert_chat_session(session) except Exception as save_err: logger.error(f"[SDK] Failed to save session on error: {save_err}") - yield StreamError(errorText=f"An error occurred: {str(e)}", code="sdk_error") + # Sanitize error message to avoid exposing internal details + yield StreamError( + errorText="An error occurred. Please try again.", + code="sdk_error", + ) yield StreamFinish()