diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py index 57126c5563..60a6a3cc1c 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py @@ -8,7 +8,17 @@ the frontend expects. import json import logging import uuid -from typing import Any, AsyncGenerator + +from claude_agent_sdk import ( + AssistantMessage, + Message, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, +) from backend.api.features.chat.response_model import ( StreamBaseResponse, @@ -36,234 +46,106 @@ class SDKResponseAdapter: """ def __init__(self, message_id: str | None = None): - """Initialize the adapter. - - Args: - message_id: Optional message ID. If not provided, one will be generated. - """ self.message_id = message_id or str(uuid.uuid4()) self.text_block_id = str(uuid.uuid4()) self.has_started_text = False self.has_ended_text = False - self.current_tool_calls: dict[str, dict[str, Any]] = {} + self.current_tool_calls: dict[str, dict[str, str]] = {} self.task_id: str | None = None def set_task_id(self, task_id: str) -> None: """Set the task ID for reconnection support.""" self.task_id = task_id - def convert_message(self, sdk_message: Any) -> list[StreamBaseResponse]: - """Convert a single SDK message to Vercel AI SDK format. - - Args: - sdk_message: A message from the Claude Agent SDK. - - Returns: - List of StreamBaseResponse objects (may be empty or multiple). - """ + def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]: + """Convert a single SDK message to Vercel AI SDK format.""" responses: list[StreamBaseResponse] = [] - # Handle different SDK message types - use class name since SDK uses dataclasses - class_name = type(sdk_message).__name__ - msg_subtype = getattr(sdk_message, "subtype", None) - - if class_name == "SystemMessage": - if msg_subtype == "init": - # Session initialization - emit start + if isinstance(sdk_message, SystemMessage): + if sdk_message.subtype == "init": responses.append( - StreamStart( - messageId=self.message_id, - taskId=self.task_id, - ) + StreamStart(messageId=self.message_id, taskId=self.task_id) ) - elif class_name == "AssistantMessage": - # Assistant message with content blocks - content = getattr(sdk_message, "content", []) - for block in content: - # Check block type by class name (SDK uses dataclasses) or dict type - block_class = type(block).__name__ - block_type = block.get("type") if isinstance(block, dict) else None - - if block_class == "TextBlock" or block_type == "text": - # Text content - text = getattr(block, "text", None) or ( - block.get("text") if isinstance(block, dict) else "" - ) - - if text: - # Start text block if needed (or restart after tool calls) - if not self.has_started_text or self.has_ended_text: - # Generate new text block ID for text after tools - if self.has_ended_text: - self.text_block_id = str(uuid.uuid4()) - self.has_ended_text = False - responses.append(StreamTextStart(id=self.text_block_id)) - self.has_started_text = True - - # Emit text delta + elif isinstance(sdk_message, AssistantMessage): + for block in sdk_message.content: + if isinstance(block, TextBlock): + if block.text: + self._ensure_text_started(responses) responses.append( - StreamTextDelta( - id=self.text_block_id, - delta=text, - ) + StreamTextDelta(id=self.text_block_id, delta=block.text) ) - elif block_class == "ToolUseBlock" or block_type == "tool_use": - # Tool call - tool_id_raw = getattr(block, "id", None) or ( - block.get("id") if isinstance(block, dict) else None - ) - tool_id: str = ( - str(tool_id_raw) if tool_id_raw else str(uuid.uuid4()) - ) + elif isinstance(block, ToolUseBlock): + self._end_text_if_open(responses) - tool_name_raw = getattr(block, "name", None) or ( - block.get("name") if isinstance(block, dict) else None - ) - tool_name: str = str(tool_name_raw) if tool_name_raw else "unknown" - - tool_input = getattr(block, "input", None) or ( - block.get("input") if isinstance(block, dict) else {} - ) - - # End text block if we were streaming text - if self.has_started_text and not self.has_ended_text: - responses.append(StreamTextEnd(id=self.text_block_id)) - self.has_ended_text = True - - # Emit tool input start responses.append( - StreamToolInputStart( - toolCallId=tool_id, - toolName=tool_name, - ) + StreamToolInputStart(toolCallId=block.id, toolName=block.name) ) - - # Emit tool input available with full input responses.append( StreamToolInputAvailable( - toolCallId=tool_id, + toolCallId=block.id, + toolName=block.name, + input=block.input, + ) + ) + self.current_tool_calls[block.id] = {"name": block.name} + + elif isinstance(sdk_message, UserMessage): + # UserMessage carries tool results back from tool execution + content = sdk_message.content + blocks = content if isinstance(content, list) else [] + for block in blocks: + if isinstance(block, ToolResultBlock) and block.tool_use_id: + tool_info = self.current_tool_calls.get(block.tool_use_id, {}) + tool_name = tool_info.get("name", "unknown") + output = _extract_tool_output(block.content) + responses.append( + StreamToolOutputAvailable( + toolCallId=block.tool_use_id, toolName=tool_name, - input=tool_input if isinstance(tool_input, dict) else {}, + output=output, + success=not (block.is_error or False), ) ) - # Track the tool call - self.current_tool_calls[tool_id] = { - "name": tool_name, - "input": tool_input, - } - - elif class_name in ("ToolResultMessage", "UserMessage"): - # Tool result - check for tool_result content - content = getattr(sdk_message, "content", []) - - for block in content: - block_class = type(block).__name__ - block_type = block.get("type") if isinstance(block, dict) else None - - if block_class == "ToolResultBlock" or block_type == "tool_result": - tool_use_id = getattr(block, "tool_use_id", None) or ( - block.get("tool_use_id") if isinstance(block, dict) else None - ) - result_content = getattr(block, "content", None) or ( - block.get("content") if isinstance(block, dict) else "" - ) - is_error = getattr(block, "is_error", False) or ( - block.get("is_error", False) - if isinstance(block, dict) - else False - ) - - if tool_use_id: - tool_info = self.current_tool_calls.get(tool_use_id, {}) - tool_name = tool_info.get("name", "unknown") - - # Format the output - if isinstance(result_content, list): - # Extract text from content blocks - output_text = "" - for item in result_content: - if ( - isinstance(item, dict) - and item.get("type") == "text" - ): - output_text += item.get("text", "") - elif hasattr(item, "text"): - output_text += getattr(item, "text", "") - if output_text: - output = output_text - else: - try: - output = json.dumps(result_content) - except (TypeError, ValueError): - output = str(result_content) - elif isinstance(result_content, str): - output = result_content - else: - try: - output = json.dumps(result_content) - except (TypeError, ValueError): - output = str(result_content) - - responses.append( - StreamToolOutputAvailable( - toolCallId=tool_use_id, - toolName=tool_name, - output=output, - success=not is_error, - ) - ) - - elif class_name == "ResultMessage": - # Final result - if msg_subtype == "success": - # End text block if still open - if self.has_started_text and not self.has_ended_text: - responses.append(StreamTextEnd(id=self.text_block_id)) - self.has_ended_text = True - - # Emit finish + elif isinstance(sdk_message, ResultMessage): + if sdk_message.subtype == "success": + self._end_text_if_open(responses) responses.append(StreamFinish()) - elif msg_subtype in ("error", "error_during_execution"): - error_msg = getattr(sdk_message, "error", "Unknown error") + elif sdk_message.subtype in ("error", "error_during_execution"): + error_msg = getattr(sdk_message, "result", None) or "Unknown error" responses.append( - StreamError( - errorText=str(error_msg), - code="sdk_error", - ) + StreamError(errorText=str(error_msg), code="sdk_error") ) responses.append(StreamFinish()) - elif class_name == "ErrorMessage": - # Error message - error_msg = getattr(sdk_message, "message", None) or getattr( - sdk_message, "error", "Unknown error" - ) - responses.append( - StreamError( - errorText=str(error_msg), - code="sdk_error", - ) - ) - responses.append(StreamFinish()) - else: - logger.debug(f"Unhandled SDK message type: {class_name}") + logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}") return responses + def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None: + """Start (or restart) a text block if needed.""" + if not self.has_started_text or self.has_ended_text: + if self.has_ended_text: + self.text_block_id = str(uuid.uuid4()) + self.has_ended_text = False + responses.append(StreamTextStart(id=self.text_block_id)) + self.has_started_text = True + + def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None: + """End the current text block if one is open.""" + if self.has_started_text and not self.has_ended_text: + responses.append(StreamTextEnd(id=self.text_block_id)) + self.has_ended_text = True + def create_heartbeat(self, tool_call_id: str | None = None) -> StreamHeartbeat: """Create a heartbeat response.""" return StreamHeartbeat(toolCallId=tool_call_id) - def create_usage( - self, - prompt_tokens: int, - completion_tokens: int, - ) -> StreamUsage: + def create_usage(self, prompt_tokens: int, completion_tokens: int) -> StreamUsage: """Create a usage statistics response.""" return StreamUsage( promptTokens=prompt_tokens, @@ -272,49 +154,21 @@ class SDKResponseAdapter: ) -async def adapt_sdk_stream( - sdk_stream: AsyncGenerator[Any, None], - message_id: str | None = None, - task_id: str | None = None, -) -> AsyncGenerator[StreamBaseResponse, None]: - """Adapt a Claude Agent SDK stream to Vercel AI SDK format. - - Args: - sdk_stream: The async generator from the Claude Agent SDK. - message_id: Optional message ID for the response. - task_id: Optional task ID for reconnection support. - - Yields: - StreamBaseResponse objects in Vercel AI SDK format. - """ - adapter = SDKResponseAdapter(message_id=message_id) - if task_id: - adapter.set_task_id(task_id) - - # Emit start immediately - yield StreamStart(messageId=adapter.message_id, taskId=task_id) - - finished = False +def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str: + """Extract a string output from a ToolResultBlock's content field.""" + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [item.get("text", "") for item in content if item.get("type") == "text"] + if parts: + return "".join(parts) + try: + return json.dumps(content) + except (TypeError, ValueError): + return str(content) + if content is None: + return "" try: - async for sdk_message in sdk_stream: - responses = adapter.convert_message(sdk_message) - for response in responses: - # Skip duplicate start messages - if isinstance(response, StreamStart): - continue - if isinstance(response, StreamFinish): - finished = True - yield response - - except Exception as e: - logger.error(f"Error in SDK stream: {e}", exc_info=True) - yield StreamError( - errorText="An error occurred. Please try again.", - code="stream_error", - ) - yield StreamFinish() - return - - # Ensure terminal StreamFinish if SDK stream ended without one - if not finished: - yield StreamFinish() + return json.dumps(content) + except (TypeError, ValueError): + return str(content) diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py new file mode 100644 index 0000000000..098836acad --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py @@ -0,0 +1,324 @@ +"""Unit tests for the SDK response adapter.""" + +from claude_agent_sdk import ( + AssistantMessage, + ResultMessage, + SystemMessage, + TextBlock, + ToolResultBlock, + ToolUseBlock, + UserMessage, +) + +from backend.api.features.chat.response_model import ( + StreamBaseResponse, + StreamError, + StreamFinish, + StreamStart, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, +) + +from .response_adapter import SDKResponseAdapter + + +def _adapter() -> SDKResponseAdapter: + a = SDKResponseAdapter(message_id="msg-1") + a.set_task_id("task-1") + return a + + +# -- SystemMessage ----------------------------------------------------------- + + +def test_system_init_emits_start(): + adapter = _adapter() + results = adapter.convert_message(SystemMessage(subtype="init", data={})) + assert len(results) == 1 + assert isinstance(results[0], StreamStart) + assert results[0].messageId == "msg-1" + assert results[0].taskId == "task-1" + + +def test_system_non_init_emits_nothing(): + adapter = _adapter() + results = adapter.convert_message(SystemMessage(subtype="other", data={})) + assert results == [] + + +# -- AssistantMessage with TextBlock ----------------------------------------- + + +def test_text_block_emits_start_and_delta(): + adapter = _adapter() + msg = AssistantMessage(content=[TextBlock(text="hello")], model="test") + results = adapter.convert_message(msg) + assert len(results) == 2 + assert isinstance(results[0], StreamTextStart) + assert isinstance(results[1], StreamTextDelta) + assert results[1].delta == "hello" + + +def test_empty_text_block_is_skipped(): + adapter = _adapter() + msg = AssistantMessage(content=[TextBlock(text="")], model="test") + results = adapter.convert_message(msg) + assert results == [] + + +def test_multiple_text_deltas_reuse_block_id(): + adapter = _adapter() + msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test") + msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test") + r1 = adapter.convert_message(msg1) + r2 = adapter.convert_message(msg2) + # First gets start+delta, second only delta (block already started) + assert len(r1) == 2 + assert len(r2) == 1 + assert isinstance(r2[0], StreamTextDelta) + assert isinstance(r1[0], StreamTextStart) + assert r1[0].id == r2[0].id # same block ID + + +# -- AssistantMessage with ToolUseBlock -------------------------------------- + + +def test_tool_use_emits_input_start_and_available(): + adapter = _adapter() + msg = AssistantMessage( + content=[ToolUseBlock(id="tool-1", name="find_agent", input={"q": "x"})], + model="test", + ) + results = adapter.convert_message(msg) + assert len(results) == 2 + assert isinstance(results[0], StreamToolInputStart) + assert results[0].toolCallId == "tool-1" + assert results[0].toolName == "find_agent" + assert isinstance(results[1], StreamToolInputAvailable) + assert results[1].input == {"q": "x"} + + +def test_text_then_tool_ends_text_block(): + adapter = _adapter() + text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test") + tool_msg = AssistantMessage( + content=[ToolUseBlock(id="t1", name="tool", input={})], model="test" + ) + adapter.convert_message(text_msg) + results = adapter.convert_message(tool_msg) + # Should have: TextEnd, ToolInputStart, ToolInputAvailable + assert len(results) == 3 + assert isinstance(results[0], StreamTextEnd) + assert isinstance(results[1], StreamToolInputStart) + + +# -- UserMessage with ToolResultBlock ---------------------------------------- + + +def test_tool_result_emits_output(): + adapter = _adapter() + # First register the tool call + tool_msg = AssistantMessage( + content=[ToolUseBlock(id="t1", name="find_agent", input={})], model="test" + ) + adapter.convert_message(tool_msg) + + # Now send tool result + result_msg = UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")] + ) + results = adapter.convert_message(result_msg) + assert len(results) == 1 + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].toolCallId == "t1" + assert results[0].toolName == "find_agent" + assert results[0].output == "found 3 agents" + assert results[0].success is True + + +def test_tool_result_error(): + adapter = _adapter() + adapter.convert_message( + AssistantMessage( + content=[ToolUseBlock(id="t1", name="run_agent", input={})], model="test" + ) + ) + result_msg = UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)] + ) + results = adapter.convert_message(result_msg) + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].success is False + + +def test_tool_result_list_content(): + adapter = _adapter() + adapter.convert_message( + AssistantMessage( + content=[ToolUseBlock(id="t1", name="tool", input={})], model="test" + ) + ) + result_msg = UserMessage( + content=[ + ToolResultBlock( + tool_use_id="t1", + content=[ + {"type": "text", "text": "line1"}, + {"type": "text", "text": "line2"}, + ], + ) + ] + ) + results = adapter.convert_message(result_msg) + assert isinstance(results[0], StreamToolOutputAvailable) + assert results[0].output == "line1line2" + + +def test_string_user_message_ignored(): + """A plain string UserMessage (not tool results) produces no output.""" + adapter = _adapter() + results = adapter.convert_message(UserMessage(content="hello")) + assert results == [] + + +# -- ResultMessage ----------------------------------------------------------- + + +def test_result_success_emits_finish(): + adapter = _adapter() + # Start some text first + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="done")], model="test") + ) + msg = ResultMessage( + subtype="success", + duration_ms=100, + duration_api_ms=50, + is_error=False, + num_turns=1, + session_id="s1", + ) + results = adapter.convert_message(msg) + # TextEnd + StreamFinish + assert len(results) == 2 + assert isinstance(results[0], StreamTextEnd) + assert isinstance(results[1], StreamFinish) + + +def test_result_error_emits_error_and_finish(): + adapter = _adapter() + msg = ResultMessage( + subtype="error", + duration_ms=100, + duration_api_ms=50, + is_error=True, + num_turns=0, + session_id="s1", + result="API rate limited", + ) + results = adapter.convert_message(msg) + assert len(results) == 2 + assert isinstance(results[0], StreamError) + assert "API rate limited" in results[0].errorText + assert isinstance(results[1], StreamFinish) + + +# -- Text after tools (new block ID) ---------------------------------------- + + +def test_text_after_tool_gets_new_block_id(): + adapter = _adapter() + # Text -> Tool -> Text should get a new text block ID + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="before")], model="test") + ) + adapter.convert_message( + AssistantMessage( + content=[ToolUseBlock(id="t1", name="tool", input={})], model="test" + ) + ) + results = adapter.convert_message( + AssistantMessage(content=[TextBlock(text="after")], model="test") + ) + # Should get StreamTextStart (new block) + StreamTextDelta + assert len(results) == 2 + assert isinstance(results[0], StreamTextStart) + assert isinstance(results[1], StreamTextDelta) + assert results[1].delta == "after" + + +# -- Full conversation flow -------------------------------------------------- + + +def test_full_conversation_flow(): + """Simulate a complete conversation: init -> text -> tool -> result -> text -> finish.""" + adapter = _adapter() + all_responses: list[StreamBaseResponse] = [] + + # 1. Init + all_responses.extend( + adapter.convert_message(SystemMessage(subtype="init", data={})) + ) + # 2. Assistant text + all_responses.extend( + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="Let me search")], model="test") + ) + ) + # 3. Tool use + all_responses.extend( + adapter.convert_message( + AssistantMessage( + content=[ + ToolUseBlock(id="t1", name="find_agent", input={"query": "email"}) + ], + model="test", + ) + ) + ) + # 4. Tool result + all_responses.extend( + adapter.convert_message( + UserMessage( + content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")] + ) + ) + ) + # 5. More text + all_responses.extend( + adapter.convert_message( + AssistantMessage(content=[TextBlock(text="I found 2")], model="test") + ) + ) + # 6. Result + all_responses.extend( + adapter.convert_message( + ResultMessage( + subtype="success", + duration_ms=500, + duration_api_ms=400, + is_error=False, + num_turns=2, + session_id="s1", + ) + ) + ) + + types = [type(r).__name__ for r in all_responses] + assert types == [ + "StreamStart", + "StreamTextStart", + "StreamTextDelta", # "Let me search" + "StreamTextEnd", # closed before tool + "StreamToolInputStart", + "StreamToolInputAvailable", + "StreamToolOutputAvailable", # tool result + "StreamTextStart", # new block after tool + "StreamTextDelta", # "I found 2" + "StreamTextEnd", # closed by result + "StreamFinish", + ] diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py index 5d3c6c494f..0702b409b2 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/sdk/service.py @@ -6,7 +6,7 @@ import json import logging import os import uuid -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, AsyncIterator from typing import Any import openai @@ -63,6 +63,58 @@ def _cleanup_sdk_tool_results() -> None: pass +def _build_conversation_messages( + session: ChatSession, +) -> AsyncIterator[dict[str, Any]]: + """Build an async iterator of SDK-compatible message dicts from session history. + + Yields structured user/assistant turns that the SDK writes directly to the + CLI's stdin. This gives the model native conversation context (enabling + turn-level compaction for long conversations) without any file I/O. + + Only prior messages are yielded; the current (last) user message is + appended at the end so the SDK processes it as the new query. + """ + + async def _iter() -> AsyncIterator[dict[str, Any]]: + # Yield all messages except the last (current user message) + for msg in session.messages[:-1]: + if msg.role == "user": + yield { + "type": "user", + "message": { + "role": "user", + "content": msg.content or "", + }, + "session_id": session.session_id, + } + elif msg.role == "assistant" and msg.content: + yield { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": msg.content}], + }, + "session_id": session.session_id, + } + # Skip tool messages — the assistant's text already captures the + # key information and tool IDs won't match across sessions. + + # Yield the current user message last + current = session.messages[-1] if session.messages else None + if current and current.role == "user": + yield { + "type": "user", + "message": { + "role": "user", + "content": current.content or "", + }, + "session_id": session.session_id, + } + + return _iter() + + DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations. Here is everything you know about the current user from previous interactions: @@ -154,42 +206,6 @@ async def _build_system_prompt( return DEFAULT_SYSTEM_PROMPT.replace("{users_information}", context), understanding -def _format_conversation_history(session: ChatSession) -> str: - """Format conversation history as a prompt context. - - Passes full history to the SDK — the SDK handles context compaction - automatically when the context window approaches its limit. - """ - if not session.messages: - return "" - - # Get all messages except the last user message (which will be the prompt) - messages = session.messages[:-1] if session.messages else [] - if not messages: - return "" - - history_parts = [""] - - for msg in messages: - if msg.role == "user": - history_parts.append(f"User: {msg.content or ''}") - elif msg.role == "assistant": - # Only include text content, skip tool call metadata - # (tool calls are noise for history context) - if msg.content: - history_parts.append(f"Assistant: {msg.content}") - # Skip tool result messages — they're not useful for conversation context - - history_parts.append("") - history_parts.append("") - history_parts.append( - "Continue this conversation. Respond to the user's latest message:" - ) - history_parts.append("") - - return "\n".join(history_parts) - - async def _generate_session_title( message: str, user_id: str | None = None, @@ -310,127 +326,130 @@ async def stream_chat_completion_sdk( mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type] allowed_tools=COPILOT_TOOL_NAMES, hooks=create_security_hooks(user_id), # type: ignore[arg-type] - continue_conversation=True, # Enable conversation continuation + continue_conversation=True, ) adapter = SDKResponseAdapter(message_id=message_id) adapter.set_task_id(task_id) - async with ClaudeSDKClient(options=options) as client: - # Build prompt with conversation history for context - # The SDK doesn't support replaying full conversation history, - # so we include it as context in the prompt - current_message = message or "" - if not current_message and session.messages: - last_user = [m for m in session.messages if m.role == "user"] - if last_user: - current_message = last_user[-1].content or "" + try: + async with ClaudeSDKClient(options=options) as client: + # Determine the current user message + current_message = message or "" + if not current_message and session.messages: + last_user = [m for m in session.messages if m.role == "user"] + if last_user: + current_message = last_user[-1].content or "" - # Include conversation history if there are prior messages - if len(session.messages) > 1: - history_context = _format_conversation_history(session) - prompt = f"{history_context}{current_message}" - else: - prompt = current_message + # Guard against empty messages + if not current_message.strip(): + yield StreamError( + errorText="Message cannot be empty.", + code="empty_prompt", + ) + yield StreamFinish() + return - logger.info( - f"[SDK] Prompt built: {len(prompt)} chars, " - f"{len(session.messages)} messages in session" - ) + # For multi-turn conversations, pass structured history + # as an AsyncIterable so the CLI sees native turns and + # can do turn-level compaction. For first messages, just + # send the string directly. + if len(session.messages) > 1: + history_iter = _build_conversation_messages(session) + await client.query(history_iter, session_id=session_id) + logger.info( + f"[SDK] Structured history: " + f"{len(session.messages) - 1} prior messages" + ) + else: + await client.query(current_message, session_id=session_id) + logger.info("[SDK] New conversation") - # Guard against empty prompts - if not prompt.strip(): - yield StreamError( - errorText="Message cannot be empty.", - code="empty_prompt", - ) - yield StreamFinish() - return + # Track assistant response to save to session + # We may need multiple assistant messages if text comes after tool results + assistant_response = ChatMessage(role="assistant", content="") + accumulated_tool_calls: list[dict[str, Any]] = [] + has_appended_assistant = False + has_tool_results = False # Track if we've received tool results - await client.query(prompt, session_id=session_id) + # Receive messages from the SDK + async for sdk_msg in client.receive_messages(): + for response in adapter.convert_message(sdk_msg): + if isinstance(response, StreamStart): + continue + yield response - # Track assistant response to save to session - # We may need multiple assistant messages if text comes after tool results - assistant_response = ChatMessage(role="assistant", content="") - accumulated_tool_calls: list[dict[str, Any]] = [] - has_appended_assistant = False - has_tool_results = False # Track if we've received tool results + # Accumulate text deltas into assistant response + if isinstance(response, StreamTextDelta): + delta = response.delta or "" + # After tool results, create new assistant message for post-tool text + if has_tool_results and has_appended_assistant: + assistant_response = ChatMessage( + role="assistant", content=delta + ) + accumulated_tool_calls = [] # Reset for new message + session.messages.append(assistant_response) + has_tool_results = False + else: + assistant_response.content = ( + assistant_response.content or "" + ) + delta + if not has_appended_assistant: + session.messages.append(assistant_response) + has_appended_assistant = True - # Receive messages from the SDK - async for sdk_msg in client.receive_messages(): - for response in adapter.convert_message(sdk_msg): - if isinstance(response, StreamStart): - continue - yield response - - # Accumulate text deltas into assistant response - if isinstance(response, StreamTextDelta): - delta = response.delta or "" - # After tool results, create new assistant message for post-tool text - if has_tool_results and has_appended_assistant: - assistant_response = ChatMessage( - role="assistant", content=delta + # Track tool calls on the assistant message + elif isinstance(response, StreamToolInputAvailable): + accumulated_tool_calls.append( + { + "id": response.toolCallId, + "type": "function", + "function": { + "name": response.toolName, + "arguments": json.dumps( + response.input or {} + ), + }, + } ) - accumulated_tool_calls = [] # Reset for new message - session.messages.append(assistant_response) - has_tool_results = False - else: - assistant_response.content = ( - assistant_response.content or "" - ) + delta + # Update assistant message with tool calls + assistant_response.tool_calls = accumulated_tool_calls + # Append assistant message if not already (tool-only response) if not has_appended_assistant: session.messages.append(assistant_response) has_appended_assistant = True - # Track tool calls on the assistant message - elif isinstance(response, StreamToolInputAvailable): - accumulated_tool_calls.append( - { - "id": response.toolCallId, - "type": "function", - "function": { - "name": response.toolName, - "arguments": json.dumps(response.input or {}), - }, - } - ) - # Update assistant message with tool calls - assistant_response.tool_calls = accumulated_tool_calls - # Append assistant message if not already (tool-only response) - if not has_appended_assistant: - session.messages.append(assistant_response) - has_appended_assistant = True - - elif isinstance(response, StreamToolOutputAvailable): - session.messages.append( - ChatMessage( - role="tool", - content=( - response.output - if isinstance(response.output, str) - else str(response.output) - ), - tool_call_id=response.toolCallId, + elif isinstance(response, StreamToolOutputAvailable): + session.messages.append( + ChatMessage( + role="tool", + content=( + response.output + if isinstance(response.output, str) + else str(response.output) + ), + tool_call_id=response.toolCallId, + ) ) - ) - has_tool_results = True + has_tool_results = True - elif isinstance(response, StreamFinish): - stream_completed = True + elif isinstance(response, StreamFinish): + stream_completed = True - # Break out of the message loop if we received finish signal - if stream_completed: - break + # Break out of the message loop if we received finish signal + if stream_completed: + break - # Ensure assistant response is saved even if no text deltas - # (e.g., only tool calls were made) - if ( - assistant_response.content or assistant_response.tool_calls - ) and not has_appended_assistant: - session.messages.append(assistant_response) + # Ensure assistant response is saved even if no text deltas + # (e.g., only tool calls were made) + if ( + assistant_response.content or assistant_response.tool_calls + ) and not has_appended_assistant: + session.messages.append(assistant_response) - # Clean up SDK tool-result files to avoid accumulation - _cleanup_sdk_tool_results() + finally: + # Always clean up SDK tool-result files, even on error + _cleanup_sdk_tool_results() except ImportError: logger.warning(