fix(backend/chat): Address PR review comments for Claude SDK integration

- Add StreamFinish after ErrorMessage in response adapter
- Fix str.replace to removeprefix in security hooks
- Apply max_context_messages limit as safety guard in history formatting
- Add empty prompt guard before sending to SDK
- Sanitize error messages to avoid exposing internal details
- Fix fire-and-forget asyncio.create_task by storing task reference
- Fix tool_calls population on assistant messages
- Rewrite Anthropic fallback to persist messages and merge consecutive roles
- Only use ANTHROPIC_API_KEY for fallback (not OpenRouter keys)
- Fix IndexError when tool result content list is empty
This commit is contained in:
Zamil Majdy
2026-02-06 13:25:10 +04:00
parent b49d8e2cba
commit 5efb80d47b
4 changed files with 189 additions and 33 deletions

View File

@@ -11,8 +11,7 @@ import uuid
from collections.abc import AsyncGenerator
from typing import Any, cast
from ..config import ChatConfig
from ..model import ChatSession
from ..model import ChatMessage, ChatSession
from ..response_model import (
StreamBaseResponse,
StreamError,
@@ -28,7 +27,6 @@ from ..response_model import (
from .tool_adapter import get_tool_definitions, get_tool_handlers
logger = logging.getLogger(__name__)
config = ChatConfig()
async def stream_with_anthropic(
@@ -36,13 +34,19 @@ async def stream_with_anthropic(
system_prompt: str,
text_block_id: str,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream using Anthropic SDK directly with tool calling support."""
"""Stream using Anthropic SDK directly with tool calling support.
This function accumulates messages into the session for persistence.
The caller should NOT yield an additional StreamFinish - this function handles it.
"""
import anthropic
api_key = os.getenv("ANTHROPIC_API_KEY") or config.api_key
# Only use ANTHROPIC_API_KEY - don't fall back to OpenRouter keys
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
yield StreamError(
errorText="ANTHROPIC_API_KEY not configured", code="config_error"
errorText="ANTHROPIC_API_KEY not configured for fallback",
code="config_error",
)
yield StreamFinish()
return
@@ -69,6 +73,8 @@ async def stream_with_anthropic(
has_started_text = False
max_iterations = 10
accumulated_text = ""
accumulated_tool_calls: list[dict[str, Any]] = []
for _ in range(max_iterations):
try:
@@ -94,6 +100,7 @@ async def stream_with_anthropic(
elif event.type == "content_block_delta":
delta = event.delta
if hasattr(delta, "type") and delta.type == "text_delta":
accumulated_text += delta.text
yield StreamTextDelta(id=text_block_id, delta=delta.text)
final_message = await stream.get_final_message()
@@ -122,6 +129,22 @@ async def stream_with_anthropic(
}
)
# Track tool call for session persistence
accumulated_tool_calls.append(
{
"id": block.id,
"type": "function",
"function": {
"name": block.name,
"arguments": json.dumps(
block.input
if isinstance(block.input, dict)
else {}
),
},
}
)
yield StreamToolInputAvailable(
toolCallId=block.id,
toolName=block.name,
@@ -141,6 +164,15 @@ async def stream_with_anthropic(
success=not is_error,
)
# Save tool result to session
session.messages.append(
ChatMessage(
role="tool",
content=output,
tool_call_id=block.id,
)
)
tool_results.append(
{
"type": "tool_result",
@@ -150,6 +182,22 @@ async def stream_with_anthropic(
}
)
# Save assistant message with tool calls to session
session.messages.append(
ChatMessage(
role="assistant",
content=accumulated_text or None,
tool_calls=(
accumulated_tool_calls
if accumulated_tool_calls
else None
),
)
)
# Reset for next iteration
accumulated_text = ""
accumulated_tool_calls = []
anthropic_messages.append(
{"role": "assistant", "content": assistant_content}
)
@@ -160,6 +208,12 @@ async def stream_with_anthropic(
if has_started_text:
yield StreamTextEnd(id=text_block_id)
# Save final assistant response to session
if accumulated_text:
session.messages.append(
ChatMessage(role="assistant", content=accumulated_text)
)
yield StreamUsage(
promptTokens=final_message.usage.input_tokens,
completionTokens=final_message.usage.output_tokens,
@@ -171,7 +225,10 @@ async def stream_with_anthropic(
except Exception as e:
logger.error(f"[Anthropic Fallback] Error: {e}", exc_info=True)
yield StreamError(errorText=f"Error: {str(e)}", code="anthropic_error")
yield StreamError(
errorText="An error occurred. Please try again.",
code="anthropic_error",
)
yield StreamFinish()
return
@@ -180,11 +237,15 @@ async def stream_with_anthropic(
def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
"""Convert session messages to Anthropic format."""
messages = []
"""Convert session messages to Anthropic format.
Handles merging consecutive same-role messages (Anthropic requires alternating roles).
"""
messages: list[dict[str, Any]] = []
for msg in session.messages:
if msg.role == "user":
messages.append({"role": "user", "content": msg.content or ""})
new_msg = {"role": "user", "content": msg.content or ""}
elif msg.role == "assistant":
content: list[dict[str, Any]] = []
if msg.content:
@@ -207,21 +268,61 @@ def _convert_session_to_anthropic(session: ChatSession) -> list[dict[str, Any]]:
}
)
if content:
messages.append({"role": "assistant", "content": content})
new_msg = {"role": "assistant", "content": content}
else:
continue # Skip empty assistant messages
elif msg.role == "tool":
messages.append(
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": msg.tool_call_id or "",
"content": msg.content or "",
}
],
}
)
return messages
new_msg = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": msg.tool_call_id or "",
"content": msg.content or "",
}
],
}
else:
continue
messages.append(new_msg)
# Merge consecutive same-role messages (Anthropic requires alternating roles)
return _merge_consecutive_roles(messages)
def _merge_consecutive_roles(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Merge consecutive messages that share the same role.

    Anthropic API requires alternating user/assistant roles, so any run of
    adjacent messages with an identical ``role`` is collapsed into a single
    message whose ``content`` is the concatenation of the run's content
    blocks (plain strings are wrapped as ``{"type": "text", ...}`` blocks).

    The input list and the message dicts inside it are never mutated: each
    message that starts a new run is shallow-copied before it can be merged
    into, and merged content is always a freshly built list.

    Args:
        messages: Messages as dicts with at least ``role`` and ``content``
            keys; ``content`` may be a string, a list of blocks, or a single
            block object.

    Returns:
        A new list in which no two adjacent messages share a role. Messages
        that did not need merging keep their original ``content`` form.
    """
    merged: list[dict[str, Any]] = []
    for msg in messages:
        if merged and merged[-1]["role"] == msg["role"]:
            # Same role as the previous message: fold into it. List
            # concatenation builds a new list, so caller-owned content
            # lists are left untouched.
            merged[-1]["content"] = _to_content_blocks(
                merged[-1]["content"]
            ) + _to_content_blocks(msg["content"])
        else:
            # Copy so a later merge rewrites our dict, not the caller's.
            merged.append(dict(msg))
    return merged


def _to_content_blocks(content: Any) -> list[Any]:
    """Normalize message content to list-of-blocks form."""
    if isinstance(content, str):
        return [{"type": "text", "text": content}]
    if isinstance(content, list):
        return content
    # Any other single block object is wrapped so it can be concatenated.
    return [content]
async def _execute_tool(
@@ -234,7 +335,13 @@ async def _execute_tool(
try:
result = await handler(tool_input)
output = result.get("content", [{}])[0].get("text", "")
# Safely extract output - handle empty or missing content
content = result.get("content") or []
if content and isinstance(content, list) and len(content) > 0:
first_item = content[0]
output = first_item.get("text", "") if isinstance(first_item, dict) else ""
else:
output = ""
is_error = result.get("isError", False)
return output, is_error
except Exception as e:

View File

@@ -239,6 +239,7 @@ class SDKResponseAdapter:
code="sdk_error",
)
)
responses.append(StreamFinish())
return responses

View File

@@ -237,9 +237,7 @@ def create_strict_security_hooks(
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Remove MCP prefix if present
clean_name = tool_name
if tool_name.startswith("mcp__copilot__"):
clean_name = tool_name.replace("mcp__copilot__", "")
clean_name = tool_name.removeprefix("mcp__copilot__")
if clean_name not in allowed_set:
logger.warning(f"Blocked non-whitelisted tool: {tool_name}")

View File

@@ -1,6 +1,7 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""
import asyncio
import json
import logging
import uuid
from collections.abc import AsyncGenerator
@@ -28,6 +29,7 @@ from ..response_model import (
StreamFinish,
StreamStart,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from ..tracking import track_user_message
@@ -43,6 +45,9 @@ from .tool_adapter import (
logger = logging.getLogger(__name__)
config = ChatConfig()
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
Here is everything you know about the current user from previous interactions:
@@ -137,8 +142,8 @@ async def _build_system_prompt(
def _format_conversation_history(session: ChatSession) -> str:
"""Format conversation history as a prompt context.
The SDK handles context compaction automatically, so we pass full history
without manual truncation. The SDK will intelligently summarize if needed.
The SDK handles context compaction automatically, but we apply
max_context_messages as a safety guard to limit initial prompt size.
"""
if not session.messages:
return ""
@@ -148,6 +153,12 @@ def _format_conversation_history(session: ChatSession) -> str:
if not messages:
return ""
# Apply max_context_messages limit as a safety guard
# (SDK handles compaction, but this prevents excessively large initial prompts)
max_messages = config.max_context_messages
if len(messages) > max_messages:
messages = messages[-max_messages:]
history_parts = ["<conversation_history>"]
for msg in messages:
@@ -261,9 +272,12 @@ async def stream_chat_completion_sdk(
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
asyncio.create_task(
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
# Store reference to prevent garbage collection
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Check if there's conversation history (more than just the current message)
has_history = len(session.messages) > 1
@@ -316,11 +330,21 @@ async def stream_chat_completion_sdk(
else:
prompt = current_message
# Guard against empty prompts
if not prompt.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
await client.query(prompt, session_id=session_id)
# Track assistant response to save to session
# We may need multiple assistant messages if text comes after tool results
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False # Track if we've received tool results
@@ -340,6 +364,7 @@ async def stream_chat_completion_sdk(
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = [] # Reset for new message
session.messages.append(assistant_response)
has_tool_results = False
else:
@@ -350,6 +375,25 @@ async def stream_chat_completion_sdk(
session.messages.append(assistant_response)
has_appended_assistant = True
# Track tool calls on the assistant message
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
# Update assistant message with tool calls
assistant_response.tool_calls = accumulated_tool_calls
# Append assistant message if not already (tool-only response)
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
@@ -373,7 +417,9 @@ async def stream_chat_completion_sdk(
# Ensure assistant response is saved even if no text deltas
# (e.g., only tool calls were made)
if assistant_response.content and not has_appended_assistant:
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
except ImportError:
@@ -402,7 +448,11 @@ async def stream_chat_completion_sdk(
await upsert_chat_session(session)
except Exception as save_err:
logger.error(f"[SDK] Failed to save session on error: {save_err}")
yield StreamError(errorText=f"An error occurred: {str(e)}", code="sdk_error")
# Sanitize error message to avoid exposing internal details
yield StreamError(
errorText="An error occurred. Please try again.",
code="sdk_error",
)
yield StreamFinish()