diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py
index 6d99aab583..df5ff7a60a 100644
--- a/autogpt_platform/backend/backend/api/features/chat/service.py
+++ b/autogpt_platform/backend/backend/api/features/chat/service.py
@@ -800,6 +800,114 @@ async def _summarize_messages(
     return summary or "No summary available."
 
 
+def _ensure_tool_pairs_intact(
+    recent_messages: list[dict],
+    all_messages: list[dict],
+    start_index: int,
+) -> list[dict]:
+    """
+    Ensure tool_call/tool_response pairs stay together after slicing.
+
+    When slicing messages for context compaction, a naive slice can separate
+    an assistant message containing tool_calls from its corresponding tool
+    response messages. This causes API validation errors (e.g., Anthropic's
+    "unexpected tool_use_id found in tool_result blocks").
+
+    This function checks for orphan tool responses in the slice and extends
+    backwards to include their corresponding assistant messages.
+
+    Args:
+        recent_messages: The sliced messages to validate
+        all_messages: The complete message list (for looking up missing assistants)
+        start_index: The index in all_messages where recent_messages begins
+
+    Returns:
+        A potentially extended list of messages with tool pairs intact
+    """
+    if not recent_messages:
+        return recent_messages
+
+    # Collect all tool_call_ids from assistant messages in the slice
+    available_tool_call_ids: set[str] = set()
+    for msg in recent_messages:
+        if msg.get("role") == "assistant" and msg.get("tool_calls"):
+            for tc in msg["tool_calls"]:
+                tc_id = tc.get("id")
+                if tc_id:
+                    available_tool_call_ids.add(tc_id)
+
+    # Find orphan tool responses (tool messages whose tool_call_id is missing)
+    orphan_tool_call_ids: set[str] = set()
+    for msg in recent_messages:
+        if msg.get("role") == "tool":
+            tc_id = msg.get("tool_call_id")
+            if tc_id and tc_id not in available_tool_call_ids:
+                orphan_tool_call_ids.add(tc_id)
+
+    if not orphan_tool_call_ids:
+        # No orphans, slice is valid
+        return recent_messages
+
+    # Find the assistant messages that contain the orphan tool_call_ids.
+    # Search backwards from start_index in all_messages.
+    messages_to_prepend: list[dict] = []
+    for i in range(start_index - 1, -1, -1):
+        msg = all_messages[i]
+        if msg.get("role") == "assistant" and msg.get("tool_calls"):
+            msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
+            if msg_tool_ids & orphan_tool_call_ids:
+                # This assistant message has tool_calls we need.
+                # Also collect its contiguous tool responses that follow it.
+                assistant_and_responses: list[dict] = [msg]
+
+                # Scan forward from this assistant to collect tool responses
+                for j in range(i + 1, start_index):
+                    following_msg = all_messages[j]
+                    if following_msg.get("role") == "tool":
+                        tool_id = following_msg.get("tool_call_id")
+                        if tool_id and tool_id in msg_tool_ids:
+                            assistant_and_responses.append(following_msg)
+                    else:
+                        # Stop at first non-tool message
+                        break
+
+                # Prepend the assistant and its tool responses (maintain order)
+                messages_to_prepend = assistant_and_responses + messages_to_prepend
+                # Mark these as found
+                orphan_tool_call_ids -= msg_tool_ids
+                # Also add this assistant's tool_call_ids to the available set
+                available_tool_call_ids |= msg_tool_ids
+
+                if not orphan_tool_call_ids:
+                    # Found all missing assistants
+                    break
+
+    if orphan_tool_call_ids:
+        # Some tool_call_ids couldn't be resolved - remove those tool responses.
+        # This shouldn't happen in normal operation but handles edge cases.
+        logger.warning(
+            f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
+            "Removing orphan tool responses."
+        )
+        recent_messages = [
+            msg
+            for msg in recent_messages
+            if not (
+                msg.get("role") == "tool"
+                and msg.get("tool_call_id") in orphan_tool_call_ids
+            )
+        ]
+
+    if messages_to_prepend:
+        logger.info(
+            f"Extended recent messages by {len(messages_to_prepend)} to preserve "
+            f"tool_call/tool_response pairs"
+        )
+        return messages_to_prepend + recent_messages
+
+    return recent_messages
+
+
 async def _stream_chat_chunks(
     session: ChatSession,
     tools: list[ChatCompletionToolParam],
@@ -891,7 +999,15 @@ async def _stream_chat_chunks(
         # Always attempt mitigation when over limit, even with few messages
         if messages:
             # Split messages based on whether system prompt exists
-            recent_messages = messages[-KEEP_RECENT:]
+            # Calculate start index for the slice
+            slice_start = max(0, len(messages_dict) - KEEP_RECENT)
+            recent_messages = messages_dict[-KEEP_RECENT:]
+
+            # Ensure tool_call/tool_response pairs stay together.
+            # This prevents API errors from orphan tool responses.
+            recent_messages = _ensure_tool_pairs_intact(
+                recent_messages, messages_dict, slice_start
+            )
 
             if has_system_prompt:
                 # Keep system prompt separate, summarize everything between system and recent
@@ -978,6 +1094,13 @@ async def _stream_chat_chunks(
                     if len(recent_messages) >= keep_count
                     else recent_messages
                 )
+                # Ensure tool pairs stay intact in the reduced slice
+                reduced_slice_start = max(
+                    0, len(recent_messages) - keep_count
+                )
+                reduced_recent = _ensure_tool_pairs_intact(
+                    reduced_recent, recent_messages, reduced_slice_start
+                )
                 if has_system_prompt:
                     messages = [
                         system_msg,
@@ -1036,7 +1159,10 @@ async def _stream_chat_chunks(
 
             # Create a base list excluding system prompt to avoid duplication
             # This is the pool of messages we'll slice from in the loop
-            base_msgs = messages[1:] if has_system_prompt else messages
+            # Use messages_dict for type consistency with _ensure_tool_pairs_intact
+            base_msgs = (
+                messages_dict[1:] if has_system_prompt else messages_dict
+            )
 
             # Try progressively smaller keep counts
             new_token_count = token_count  # Initialize with current count
@@ -1059,6 +1185,12 @@ async def _stream_chat_chunks(
 
                 # Slice from base_msgs to get recent messages (without system prompt)
                 recent_messages = base_msgs[-keep_count:]
+                # Ensure tool pairs stay intact in the reduced slice
+                reduced_slice_start = max(0, len(base_msgs) - keep_count)
+                recent_messages = _ensure_tool_pairs_intact(
+                    recent_messages, base_msgs, reduced_slice_start
+                )
+
                 if has_system_prompt:
                     messages = [system_msg] + recent_messages
                 else:
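
Reviewer note, not part of the diff: a minimal sketch of the orphan-pair scenario the new helper handles. The import path assumes autogpt_platform/backend is the package root (matching the file touched above); the message dicts are simplified to only the keys the helper reads, and the tool_call id is made up.

    # Sketch only: assumes the service module resolves from the repo's
    # backend package root.
    from backend.api.features.chat.service import _ensure_tool_pairs_intact

    all_messages = [
        {"role": "user", "content": "What's the weather in Berlin?"},
        # Assistant issues a tool call (simplified; a real entry also
        # carries "type" and "function" keys, which the helper ignores).
        {"role": "assistant", "tool_calls": [{"id": "call_1"}]},
        {"role": "tool", "tool_call_id": "call_1", "content": "18C, cloudy"},
        {"role": "assistant", "content": "It's 18C and cloudy in Berlin."},
    ]

    # A naive "keep the last 2 messages" slice orphans the tool response:
    # the assistant message that issued call_1 falls outside the slice.
    start = len(all_messages) - 2
    recent = all_messages[start:]

    fixed = _ensure_tool_pairs_intact(recent, all_messages, start)

    # The helper walks backwards from `start`, finds the assistant that
    # issued call_1, and prepends it so the pair reaches the API intact.
    assert fixed[0]["tool_calls"][0]["id"] == "call_1"
    assert fixed[1:] == recent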