refactor(backend/chat): Use proper SDK types and in-memory conversation history

Replace duck typing (class name checks, getattr) with isinstance() using
SDK-exported dataclasses. Replace file-based --resume with AsyncIterable
message injection for conversation history, eliminating disk I/O. Add 15
unit tests for the response adapter.
This commit is contained in:
Zamil Majdy
2026-02-10 18:17:00 +04:00
parent 0f2d1a6553
commit a31fc8b162
3 changed files with 567 additions and 370 deletions

View File

@@ -8,7 +8,17 @@ the frontend expects.
import json
import logging
import uuid
from typing import Any, AsyncGenerator
from claude_agent_sdk import (
AssistantMessage,
Message,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
@@ -36,234 +46,106 @@ class SDKResponseAdapter:
"""
def __init__(self, message_id: str | None = None):
"""Initialize the adapter.
Args:
message_id: Optional message ID. If not provided, one will be generated.
"""
self.message_id = message_id or str(uuid.uuid4())
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, Any]] = {}
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.task_id: str | None = None
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
def convert_message(self, sdk_message: Any) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format.
Args:
sdk_message: A message from the Claude Agent SDK.
Returns:
List of StreamBaseResponse objects (may be empty or multiple).
"""
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format."""
responses: list[StreamBaseResponse] = []
# Handle different SDK message types - use class name since SDK uses dataclasses
class_name = type(sdk_message).__name__
msg_subtype = getattr(sdk_message, "subtype", None)
if class_name == "SystemMessage":
if msg_subtype == "init":
# Session initialization - emit start
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(
messageId=self.message_id,
taskId=self.task_id,
)
StreamStart(messageId=self.message_id, taskId=self.task_id)
)
elif class_name == "AssistantMessage":
# Assistant message with content blocks
content = getattr(sdk_message, "content", [])
for block in content:
# Check block type by class name (SDK uses dataclasses) or dict type
block_class = type(block).__name__
block_type = block.get("type") if isinstance(block, dict) else None
if block_class == "TextBlock" or block_type == "text":
# Text content
text = getattr(block, "text", None) or (
block.get("text") if isinstance(block, dict) else ""
)
if text:
# Start text block if needed (or restart after tool calls)
if not self.has_started_text or self.has_ended_text:
# Generate new text block ID for text after tools
if self.has_ended_text:
self.text_block_id = str(uuid.uuid4())
self.has_ended_text = False
responses.append(StreamTextStart(id=self.text_block_id))
self.has_started_text = True
# Emit text delta
elif isinstance(sdk_message, AssistantMessage):
for block in sdk_message.content:
if isinstance(block, TextBlock):
if block.text:
self._ensure_text_started(responses)
responses.append(
StreamTextDelta(
id=self.text_block_id,
delta=text,
)
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
elif block_class == "ToolUseBlock" or block_type == "tool_use":
# Tool call
tool_id_raw = getattr(block, "id", None) or (
block.get("id") if isinstance(block, dict) else None
)
tool_id: str = (
str(tool_id_raw) if tool_id_raw else str(uuid.uuid4())
)
elif isinstance(block, ToolUseBlock):
self._end_text_if_open(responses)
tool_name_raw = getattr(block, "name", None) or (
block.get("name") if isinstance(block, dict) else None
)
tool_name: str = str(tool_name_raw) if tool_name_raw else "unknown"
tool_input = getattr(block, "input", None) or (
block.get("input") if isinstance(block, dict) else {}
)
# End text block if we were streaming text
if self.has_started_text and not self.has_ended_text:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
# Emit tool input start
responses.append(
StreamToolInputStart(
toolCallId=tool_id,
toolName=tool_name,
)
StreamToolInputStart(toolCallId=block.id, toolName=block.name)
)
# Emit tool input available with full input
responses.append(
StreamToolInputAvailable(
toolCallId=tool_id,
toolCallId=block.id,
toolName=block.name,
input=block.input,
)
)
self.current_tool_calls[block.id] = {"name": block.name}
elif isinstance(sdk_message, UserMessage):
# UserMessage carries tool results back from tool execution
content = sdk_message.content
blocks = content if isinstance(content, list) else []
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
output = _extract_tool_output(block.content)
responses.append(
StreamToolOutputAvailable(
toolCallId=block.tool_use_id,
toolName=tool_name,
input=tool_input if isinstance(tool_input, dict) else {},
output=output,
success=not (block.is_error or False),
)
)
# Track the tool call
self.current_tool_calls[tool_id] = {
"name": tool_name,
"input": tool_input,
}
elif class_name in ("ToolResultMessage", "UserMessage"):
# Tool result - check for tool_result content
content = getattr(sdk_message, "content", [])
for block in content:
block_class = type(block).__name__
block_type = block.get("type") if isinstance(block, dict) else None
if block_class == "ToolResultBlock" or block_type == "tool_result":
tool_use_id = getattr(block, "tool_use_id", None) or (
block.get("tool_use_id") if isinstance(block, dict) else None
)
result_content = getattr(block, "content", None) or (
block.get("content") if isinstance(block, dict) else ""
)
is_error = getattr(block, "is_error", False) or (
block.get("is_error", False)
if isinstance(block, dict)
else False
)
if tool_use_id:
tool_info = self.current_tool_calls.get(tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
# Format the output
if isinstance(result_content, list):
# Extract text from content blocks
output_text = ""
for item in result_content:
if (
isinstance(item, dict)
and item.get("type") == "text"
):
output_text += item.get("text", "")
elif hasattr(item, "text"):
output_text += getattr(item, "text", "")
if output_text:
output = output_text
else:
try:
output = json.dumps(result_content)
except (TypeError, ValueError):
output = str(result_content)
elif isinstance(result_content, str):
output = result_content
else:
try:
output = json.dumps(result_content)
except (TypeError, ValueError):
output = str(result_content)
responses.append(
StreamToolOutputAvailable(
toolCallId=tool_use_id,
toolName=tool_name,
output=output,
success=not is_error,
)
)
elif class_name == "ResultMessage":
# Final result
if msg_subtype == "success":
# End text block if still open
if self.has_started_text and not self.has_ended_text:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
# Emit finish
elif isinstance(sdk_message, ResultMessage):
if sdk_message.subtype == "success":
self._end_text_if_open(responses)
responses.append(StreamFinish())
elif msg_subtype in ("error", "error_during_execution"):
error_msg = getattr(sdk_message, "error", "Unknown error")
elif sdk_message.subtype in ("error", "error_during_execution"):
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
responses.append(
StreamError(
errorText=str(error_msg),
code="sdk_error",
)
StreamError(errorText=str(error_msg), code="sdk_error")
)
responses.append(StreamFinish())
elif class_name == "ErrorMessage":
# Error message
error_msg = getattr(sdk_message, "message", None) or getattr(
sdk_message, "error", "Unknown error"
)
responses.append(
StreamError(
errorText=str(error_msg),
code="sdk_error",
)
)
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {class_name}")
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
return responses
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
"""Start (or restart) a text block if needed."""
if not self.has_started_text or self.has_ended_text:
if self.has_ended_text:
self.text_block_id = str(uuid.uuid4())
self.has_ended_text = False
responses.append(StreamTextStart(id=self.text_block_id))
self.has_started_text = True
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
"""End the current text block if one is open."""
if self.has_started_text and not self.has_ended_text:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
def create_heartbeat(self, tool_call_id: str | None = None) -> StreamHeartbeat:
"""Create a heartbeat response."""
return StreamHeartbeat(toolCallId=tool_call_id)
def create_usage(
self,
prompt_tokens: int,
completion_tokens: int,
) -> StreamUsage:
def create_usage(self, prompt_tokens: int, completion_tokens: int) -> StreamUsage:
"""Create a usage statistics response."""
return StreamUsage(
promptTokens=prompt_tokens,
@@ -272,49 +154,21 @@ class SDKResponseAdapter:
)
async def adapt_sdk_stream(
sdk_stream: AsyncGenerator[Any, None],
message_id: str | None = None,
task_id: str | None = None,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Adapt a Claude Agent SDK stream to Vercel AI SDK format.
Args:
sdk_stream: The async generator from the Claude Agent SDK.
message_id: Optional message ID for the response.
task_id: Optional task ID for reconnection support.
Yields:
StreamBaseResponse objects in Vercel AI SDK format.
"""
adapter = SDKResponseAdapter(message_id=message_id)
if task_id:
adapter.set_task_id(task_id)
# Emit start immediately
yield StreamStart(messageId=adapter.message_id, taskId=task_id)
finished = False
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract a string output from a ToolResultBlock's content field."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
if parts:
return "".join(parts)
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)
if content is None:
return ""
try:
async for sdk_message in sdk_stream:
responses = adapter.convert_message(sdk_message)
for response in responses:
# Skip duplicate start messages
if isinstance(response, StreamStart):
continue
if isinstance(response, StreamFinish):
finished = True
yield response
except Exception as e:
logger.error(f"Error in SDK stream: {e}", exc_info=True)
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
)
yield StreamFinish()
return
# Ensure terminal StreamFinish if SDK stream ended without one
if not finished:
yield StreamFinish()
return json.dumps(content)
except (TypeError, ValueError):
return str(content)

View File

@@ -0,0 +1,324 @@
"""Unit tests for the SDK response adapter."""
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from .response_adapter import SDKResponseAdapter
def _adapter() -> SDKResponseAdapter:
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a
# -- SystemMessage -----------------------------------------------------------
def test_system_init_emits_start():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
assert len(results) == 1
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].taskId == "task-1"
def test_system_non_init_emits_nothing():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
assert results == []
# -- AssistantMessage with TextBlock -----------------------------------------
def test_text_block_emits_start_and_delta():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
results = adapter.convert_message(msg)
assert len(results) == 2
assert isinstance(results[0], StreamTextStart)
assert isinstance(results[1], StreamTextDelta)
assert results[1].delta == "hello"
def test_empty_text_block_is_skipped():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
results = adapter.convert_message(msg)
assert results == []
def test_multiple_text_deltas_reuse_block_id():
adapter = _adapter()
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
r1 = adapter.convert_message(msg1)
r2 = adapter.convert_message(msg2)
# First gets start+delta, second only delta (block already started)
assert len(r1) == 2
assert len(r2) == 1
assert isinstance(r2[0], StreamTextDelta)
assert isinstance(r1[0], StreamTextStart)
assert r1[0].id == r2[0].id # same block ID
# -- AssistantMessage with ToolUseBlock --------------------------------------
def test_tool_use_emits_input_start_and_available():
adapter = _adapter()
msg = AssistantMessage(
content=[ToolUseBlock(id="tool-1", name="find_agent", input={"q": "x"})],
model="test",
)
results = adapter.convert_message(msg)
assert len(results) == 2
assert isinstance(results[0], StreamToolInputStart)
assert results[0].toolCallId == "tool-1"
assert results[0].toolName == "find_agent"
assert isinstance(results[1], StreamToolInputAvailable)
assert results[1].input == {"q": "x"}
def test_text_then_tool_ends_text_block():
adapter = _adapter()
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name="tool", input={})], model="test"
)
adapter.convert_message(text_msg)
results = adapter.convert_message(tool_msg)
# Should have: TextEnd, ToolInputStart, ToolInputAvailable
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamToolInputStart)
# -- UserMessage with ToolResultBlock ----------------------------------------
def test_tool_result_emits_output():
adapter = _adapter()
# First register the tool call
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name="find_agent", input={})], model="test"
)
adapter.convert_message(tool_msg)
# Now send tool result
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
)
results = adapter.convert_message(result_msg)
assert len(results) == 1
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].toolCallId == "t1"
assert results[0].toolName == "find_agent"
assert results[0].output == "found 3 agents"
assert results[0].success is True
def test_tool_result_error():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name="run_agent", input={})], model="test"
)
)
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].success is False
def test_tool_result_list_content():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name="tool", input={})], model="test"
)
)
result_msg = UserMessage(
content=[
ToolResultBlock(
tool_use_id="t1",
content=[
{"type": "text", "text": "line1"},
{"type": "text", "text": "line2"},
],
)
]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].output == "line1line2"
def test_string_user_message_ignored():
"""A plain string UserMessage (not tool results) produces no output."""
adapter = _adapter()
results = adapter.convert_message(UserMessage(content="hello"))
assert results == []
# -- ResultMessage -----------------------------------------------------------
def test_result_success_emits_finish():
adapter = _adapter()
# Start some text first
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="done")], model="test")
)
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
results = adapter.convert_message(msg)
# TextEnd + StreamFinish
assert len(results) == 2
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamFinish)
def test_result_error_emits_error_and_finish():
adapter = _adapter()
msg = ResultMessage(
subtype="error",
duration_ms=100,
duration_api_ms=50,
is_error=True,
num_turns=0,
session_id="s1",
result="API rate limited",
)
results = adapter.convert_message(msg)
assert len(results) == 2
assert isinstance(results[0], StreamError)
assert "API rate limited" in results[0].errorText
assert isinstance(results[1], StreamFinish)
# -- Text after tools (new block ID) ----------------------------------------
def test_text_after_tool_gets_new_block_id():
adapter = _adapter()
# Text -> Tool -> Text should get a new text block ID
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="before")], model="test")
)
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name="tool", input={})], model="test"
)
)
results = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="after")], model="test")
)
# Should get StreamTextStart (new block) + StreamTextDelta
assert len(results) == 2
assert isinstance(results[0], StreamTextStart)
assert isinstance(results[1], StreamTextDelta)
assert results[1].delta == "after"
# -- Full conversation flow --------------------------------------------------
def test_full_conversation_flow():
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Assistant text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
)
)
# 3. Tool use
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="t1", name="find_agent", input={"query": "email"})
],
model="test",
)
)
)
# 4. Tool result
all_responses.extend(
adapter.convert_message(
UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
)
)
)
# 5. More text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
)
)
# 6. Result
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=500,
duration_api_ms=400,
is_error=False,
num_turns=2,
session_id="s1",
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamTextStart",
"StreamTextDelta", # "Let me search"
"StreamTextEnd", # closed before tool
"StreamToolInputStart",
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # tool result
"StreamTextStart", # new block after tool
"StreamTextDelta", # "I found 2"
"StreamTextEnd", # closed by result
"StreamFinish",
]

View File

@@ -6,7 +6,7 @@ import json
import logging
import os
import uuid
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
import openai
@@ -63,6 +63,58 @@ def _cleanup_sdk_tool_results() -> None:
pass
def _build_conversation_messages(
session: ChatSession,
) -> AsyncIterator[dict[str, Any]]:
"""Build an async iterator of SDK-compatible message dicts from session history.
Yields structured user/assistant turns that the SDK writes directly to the
CLI's stdin. This gives the model native conversation context (enabling
turn-level compaction for long conversations) without any file I/O.
Only prior messages are yielded; the current (last) user message is
appended at the end so the SDK processes it as the new query.
"""
async def _iter() -> AsyncIterator[dict[str, Any]]:
# Yield all messages except the last (current user message)
for msg in session.messages[:-1]:
if msg.role == "user":
yield {
"type": "user",
"message": {
"role": "user",
"content": msg.content or "",
},
"session_id": session.session_id,
}
elif msg.role == "assistant" and msg.content:
yield {
"type": "assistant",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": msg.content}],
},
"session_id": session.session_id,
}
# Skip tool messages — the assistant's text already captures the
# key information and tool IDs won't match across sessions.
# Yield the current user message last
current = session.messages[-1] if session.messages else None
if current and current.role == "user":
yield {
"type": "user",
"message": {
"role": "user",
"content": current.content or "",
},
"session_id": session.session_id,
}
return _iter()
DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.
Here is everything you know about the current user from previous interactions:
@@ -154,42 +206,6 @@ async def _build_system_prompt(
return DEFAULT_SYSTEM_PROMPT.replace("{users_information}", context), understanding
def _format_conversation_history(session: ChatSession) -> str:
"""Format conversation history as a prompt context.
Passes full history to the SDK — the SDK handles context compaction
automatically when the context window approaches its limit.
"""
if not session.messages:
return ""
# Get all messages except the last user message (which will be the prompt)
messages = session.messages[:-1] if session.messages else []
if not messages:
return ""
history_parts = ["<conversation_history>"]
for msg in messages:
if msg.role == "user":
history_parts.append(f"User: {msg.content or ''}")
elif msg.role == "assistant":
# Only include text content, skip tool call metadata
# (tool calls are noise for history context)
if msg.content:
history_parts.append(f"Assistant: {msg.content}")
# Skip tool result messages — they're not useful for conversation context
history_parts.append("</conversation_history>")
history_parts.append("")
history_parts.append(
"Continue this conversation. Respond to the user's latest message:"
)
history_parts.append("")
return "\n".join(history_parts)
async def _generate_session_title(
message: str,
user_id: str | None = None,
@@ -310,127 +326,130 @@ async def stream_chat_completion_sdk(
mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type]
allowed_tools=COPILOT_TOOL_NAMES,
hooks=create_security_hooks(user_id), # type: ignore[arg-type]
continue_conversation=True, # Enable conversation continuation
continue_conversation=True,
)
adapter = SDKResponseAdapter(message_id=message_id)
adapter.set_task_id(task_id)
async with ClaudeSDKClient(options=options) as client:
# Build prompt with conversation history for context
# The SDK doesn't support replaying full conversation history,
# so we include it as context in the prompt
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
try:
async with ClaudeSDKClient(options=options) as client:
# Determine the current user message
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
# Include conversation history if there are prior messages
if len(session.messages) > 1:
history_context = _format_conversation_history(session)
prompt = f"{history_context}{current_message}"
else:
prompt = current_message
# Guard against empty messages
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
logger.info(
f"[SDK] Prompt built: {len(prompt)} chars, "
f"{len(session.messages)} messages in session"
)
# For multi-turn conversations, pass structured history
# as an AsyncIterable so the CLI sees native turns and
# can do turn-level compaction. For first messages, just
# send the string directly.
if len(session.messages) > 1:
history_iter = _build_conversation_messages(session)
await client.query(history_iter, session_id=session_id)
logger.info(
f"[SDK] Structured history: "
f"{len(session.messages) - 1} prior messages"
)
else:
await client.query(current_message, session_id=session_id)
logger.info("[SDK] New conversation")
# Guard against empty prompts
if not prompt.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
# Track assistant response to save to session
# We may need multiple assistant messages if text comes after tool results
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False # Track if we've received tool results
await client.query(prompt, session_id=session_id)
# Receive messages from the SDK
async for sdk_msg in client.receive_messages():
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
# Track assistant response to save to session
# We may need multiple assistant messages if text comes after tool results
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False # Track if we've received tool results
# Accumulate text deltas into assistant response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, create new assistant message for post-tool text
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = [] # Reset for new message
session.messages.append(assistant_response)
has_tool_results = False
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
# Receive messages from the SDK
async for sdk_msg in client.receive_messages():
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
# Accumulate text deltas into assistant response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, create new assistant message for post-tool text
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
# Track tool calls on the assistant message
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(
response.input or {}
),
},
}
)
accumulated_tool_calls = [] # Reset for new message
session.messages.append(assistant_response)
has_tool_results = False
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
# Update assistant message with tool calls
assistant_response.tool_calls = accumulated_tool_calls
# Append assistant message if not already (tool-only response)
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
# Track tool calls on the assistant message
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
# Update assistant message with tool calls
assistant_response.tool_calls = accumulated_tool_calls
# Append assistant message if not already (tool-only response)
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
)
has_tool_results = True
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
elif isinstance(response, StreamFinish):
stream_completed = True
# Break out of the message loop if we received finish signal
if stream_completed:
break
# Break out of the message loop if we received finish signal
if stream_completed:
break
# Ensure assistant response is saved even if no text deltas
# (e.g., only tool calls were made)
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
# Ensure assistant response is saved even if no text deltas
# (e.g., only tool calls were made)
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
# Clean up SDK tool-result files to avoid accumulation
_cleanup_sdk_tool_results()
finally:
# Always clean up SDK tool-result files, even on error
_cleanup_sdk_tool_results()
except ImportError:
logger.warning(