refactor(backend/chat): Use proper SDK types and in-memory conversation history
Replace duck typing (class name checks, getattr) with isinstance() using SDK-exported dataclasses. Replace file-based --resume with AsyncIterable message injection for conversation history, eliminating disk I/O. Add 15 unit tests for the response adapter.
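Review note: a minimal sketch (not the committed code) of the isinstance() dispatch this change moves to. The type names and attributes (AssistantMessage.content, TextBlock.text, ToolUseBlock.name, ResultMessage.subtype) are the claude_agent_sdk exports used in the diff below; describe_sdk_message() is a made-up helper for illustration only — the real logic lives in SDKResponseAdapter.convert_message.

# Hedged illustration only: isinstance() over SDK dataclasses instead of
# comparing type(...).__name__ strings or probing attributes with getattr().
from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock, ToolUseBlock


def describe_sdk_message(sdk_message: object) -> str:
    if isinstance(sdk_message, AssistantMessage):
        parts = []
        for block in sdk_message.content:
            if isinstance(block, TextBlock):
                parts.append(f"text[{len(block.text)} chars]")
            elif isinstance(block, ToolUseBlock):
                parts.append(f"tool_use[{block.name}]")
        return "assistant: " + ", ".join(parts)
    if isinstance(sdk_message, ResultMessage):
        return f"result: {sdk_message.subtype}"
    return type(sdk_message).__name__  # unhandled types fall through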
@@ -8,7 +8,17 @@ the frontend expects.
import json
import logging
import uuid
from typing import Any, AsyncGenerator

from claude_agent_sdk import (
AssistantMessage,
Message,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)

from backend.api.features.chat.response_model import (
StreamBaseResponse,
@@ -36,234 +46,106 @@ class SDKResponseAdapter:
"""

def __init__(self, message_id: str | None = None):
"""Initialize the adapter.

Args:
message_id: Optional message ID. If not provided, one will be generated.
"""
self.message_id = message_id or str(uuid.uuid4())
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, Any]] = {}
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.task_id: str | None = None

def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id

def convert_message(self, sdk_message: Any) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format.

Args:
sdk_message: A message from the Claude Agent SDK.

Returns:
List of StreamBaseResponse objects (may be empty or multiple).
"""
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format."""
responses: list[StreamBaseResponse] = []

# Handle different SDK message types - use class name since SDK uses dataclasses
class_name = type(sdk_message).__name__
msg_subtype = getattr(sdk_message, "subtype", None)

if class_name == "SystemMessage":
if msg_subtype == "init":
# Session initialization - emit start
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(
messageId=self.message_id,
taskId=self.task_id,
)
StreamStart(messageId=self.message_id, taskId=self.task_id)
)

elif class_name == "AssistantMessage":
# Assistant message with content blocks
content = getattr(sdk_message, "content", [])
for block in content:
# Check block type by class name (SDK uses dataclasses) or dict type
block_class = type(block).__name__
block_type = block.get("type") if isinstance(block, dict) else None

if block_class == "TextBlock" or block_type == "text":
# Text content
text = getattr(block, "text", None) or (
block.get("text") if isinstance(block, dict) else ""
)

if text:
# Start text block if needed (or restart after tool calls)
if not self.has_started_text or self.has_ended_text:
# Generate new text block ID for text after tools
if self.has_ended_text:
self.text_block_id = str(uuid.uuid4())
self.has_ended_text = False
responses.append(StreamTextStart(id=self.text_block_id))
self.has_started_text = True

# Emit text delta
elif isinstance(sdk_message, AssistantMessage):
for block in sdk_message.content:
if isinstance(block, TextBlock):
if block.text:
self._ensure_text_started(responses)
responses.append(
StreamTextDelta(
id=self.text_block_id,
delta=text,
)
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
elif block_class == "ToolUseBlock" or block_type == "tool_use":
|
||||
# Tool call
|
||||
tool_id_raw = getattr(block, "id", None) or (
|
||||
block.get("id") if isinstance(block, dict) else None
|
||||
)
|
||||
tool_id: str = (
|
||||
str(tool_id_raw) if tool_id_raw else str(uuid.uuid4())
|
||||
)
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
self._end_text_if_open(responses)
|
||||
|
||||
tool_name_raw = getattr(block, "name", None) or (
|
||||
block.get("name") if isinstance(block, dict) else None
|
||||
)
|
||||
tool_name: str = str(tool_name_raw) if tool_name_raw else "unknown"
|
||||
|
||||
tool_input = getattr(block, "input", None) or (
|
||||
block.get("input") if isinstance(block, dict) else {}
|
||||
)
|
||||
|
||||
# End text block if we were streaming text
|
||||
if self.has_started_text and not self.has_ended_text:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
# Emit tool input start
|
||||
responses.append(
|
||||
StreamToolInputStart(
|
||||
toolCallId=tool_id,
|
||||
toolName=tool_name,
|
||||
)
|
||||
StreamToolInputStart(toolCallId=block.id, toolName=block.name)
|
||||
)
|
||||
|
||||
# Emit tool input available with full input
|
||||
responses.append(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=tool_id,
|
||||
toolCallId=block.id,
|
||||
toolName=block.name,
|
||||
input=block.input,
|
||||
)
|
||||
)
|
||||
self.current_tool_calls[block.id] = {"name": block.name}
|
||||
|
||||
elif isinstance(sdk_message, UserMessage):
|
||||
# UserMessage carries tool results back from tool execution
|
||||
content = sdk_message.content
|
||||
blocks = content if isinstance(content, list) else []
|
||||
for block in blocks:
|
||||
if isinstance(block, ToolResultBlock) and block.tool_use_id:
|
||||
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
output = _extract_tool_output(block.content)
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=block.tool_use_id,
|
||||
toolName=tool_name,
|
||||
input=tool_input if isinstance(tool_input, dict) else {},
|
||||
output=output,
|
||||
success=not (block.is_error or False),
|
||||
)
|
||||
)
|
||||
|
||||
# Track the tool call
|
||||
self.current_tool_calls[tool_id] = {
|
||||
"name": tool_name,
|
||||
"input": tool_input,
|
||||
}
|
||||
|
||||
elif class_name in ("ToolResultMessage", "UserMessage"):
# Tool result - check for tool_result content
content = getattr(sdk_message, "content", [])

for block in content:
block_class = type(block).__name__
block_type = block.get("type") if isinstance(block, dict) else None

if block_class == "ToolResultBlock" or block_type == "tool_result":
tool_use_id = getattr(block, "tool_use_id", None) or (
block.get("tool_use_id") if isinstance(block, dict) else None
)
result_content = getattr(block, "content", None) or (
block.get("content") if isinstance(block, dict) else ""
)
is_error = getattr(block, "is_error", False) or (
block.get("is_error", False)
if isinstance(block, dict)
else False
)

if tool_use_id:
tool_info = self.current_tool_calls.get(tool_use_id, {})
tool_name = tool_info.get("name", "unknown")

# Format the output
if isinstance(result_content, list):
# Extract text from content blocks
output_text = ""
for item in result_content:
if (
isinstance(item, dict)
and item.get("type") == "text"
):
output_text += item.get("text", "")
elif hasattr(item, "text"):
output_text += getattr(item, "text", "")
if output_text:
output = output_text
else:
try:
output = json.dumps(result_content)
except (TypeError, ValueError):
output = str(result_content)
elif isinstance(result_content, str):
output = result_content
else:
try:
output = json.dumps(result_content)
except (TypeError, ValueError):
output = str(result_content)

responses.append(
StreamToolOutputAvailable(
toolCallId=tool_use_id,
toolName=tool_name,
output=output,
success=not is_error,
)
)
elif class_name == "ResultMessage":
|
||||
# Final result
|
||||
if msg_subtype == "success":
|
||||
# End text block if still open
|
||||
if self.has_started_text and not self.has_ended_text:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
# Emit finish
|
||||
elif isinstance(sdk_message, ResultMessage):
|
||||
if sdk_message.subtype == "success":
|
||||
self._end_text_if_open(responses)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
elif msg_subtype in ("error", "error_during_execution"):
|
||||
error_msg = getattr(sdk_message, "error", "Unknown error")
|
||||
elif sdk_message.subtype in ("error", "error_during_execution"):
|
||||
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
|
||||
responses.append(
|
||||
StreamError(
|
||||
errorText=str(error_msg),
|
||||
code="sdk_error",
|
||||
)
|
||||
StreamError(errorText=str(error_msg), code="sdk_error")
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
elif class_name == "ErrorMessage":
|
||||
# Error message
|
||||
error_msg = getattr(sdk_message, "message", None) or getattr(
|
||||
sdk_message, "error", "Unknown error"
|
||||
)
|
||||
responses.append(
|
||||
StreamError(
|
||||
errorText=str(error_msg),
|
||||
code="sdk_error",
|
||||
)
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug(f"Unhandled SDK message type: {class_name}")
|
||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
||||
|
||||
return responses
|
||||
|
||||
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""Start (or restart) a text block if needed."""
|
||||
if not self.has_started_text or self.has_ended_text:
|
||||
if self.has_ended_text:
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_ended_text = False
|
||||
responses.append(StreamTextStart(id=self.text_block_id))
|
||||
self.has_started_text = True
|
||||
|
||||
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""End the current text block if one is open."""
|
||||
if self.has_started_text and not self.has_ended_text:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
def create_heartbeat(self, tool_call_id: str | None = None) -> StreamHeartbeat:
|
||||
"""Create a heartbeat response."""
|
||||
return StreamHeartbeat(toolCallId=tool_call_id)
|
||||
|
||||
def create_usage(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
) -> StreamUsage:
|
||||
def create_usage(self, prompt_tokens: int, completion_tokens: int) -> StreamUsage:
|
||||
"""Create a usage statistics response."""
|
||||
return StreamUsage(
|
||||
promptTokens=prompt_tokens,
|
||||
@@ -272,49 +154,21 @@ class SDKResponseAdapter:
)


async def adapt_sdk_stream(
sdk_stream: AsyncGenerator[Any, None],
message_id: str | None = None,
task_id: str | None = None,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Adapt a Claude Agent SDK stream to Vercel AI SDK format.

Args:
sdk_stream: The async generator from the Claude Agent SDK.
message_id: Optional message ID for the response.
task_id: Optional task ID for reconnection support.

Yields:
StreamBaseResponse objects in Vercel AI SDK format.
"""
adapter = SDKResponseAdapter(message_id=message_id)
if task_id:
adapter.set_task_id(task_id)

# Emit start immediately
yield StreamStart(messageId=adapter.message_id, taskId=task_id)

finished = False
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract a string output from a ToolResultBlock's content field."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
if parts:
return "".join(parts)
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)
if content is None:
return ""
try:
async for sdk_message in sdk_stream:
responses = adapter.convert_message(sdk_message)
for response in responses:
# Skip duplicate start messages
if isinstance(response, StreamStart):
continue
if isinstance(response, StreamFinish):
finished = True
yield response

except Exception as e:
logger.error(f"Error in SDK stream: {e}", exc_info=True)
yield StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
)
yield StreamFinish()
return

# Ensure terminal StreamFinish if SDK stream ended without one
if not finished:
yield StreamFinish()
return json.dumps(content)
except (TypeError, ValueError):
return str(content)
@@ -0,0 +1,324 @@
"""Unit tests for the SDK response adapter."""

from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)

from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)

from .response_adapter import SDKResponseAdapter


def _adapter() -> SDKResponseAdapter:
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a


# -- SystemMessage -----------------------------------------------------------


def test_system_init_emits_start():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
assert len(results) == 1
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].taskId == "task-1"


def test_system_non_init_emits_nothing():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="other", data={}))
assert results == []


# -- AssistantMessage with TextBlock -----------------------------------------


def test_text_block_emits_start_and_delta():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
results = adapter.convert_message(msg)
assert len(results) == 2
assert isinstance(results[0], StreamTextStart)
assert isinstance(results[1], StreamTextDelta)
assert results[1].delta == "hello"


def test_empty_text_block_is_skipped():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
results = adapter.convert_message(msg)
assert results == []


def test_multiple_text_deltas_reuse_block_id():
adapter = _adapter()
msg1 = AssistantMessage(content=[TextBlock(text="a")], model="test")
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
r1 = adapter.convert_message(msg1)
r2 = adapter.convert_message(msg2)
# First gets start+delta, second only delta (block already started)
assert len(r1) == 2
assert len(r2) == 1
assert isinstance(r2[0], StreamTextDelta)
assert isinstance(r1[0], StreamTextStart)
assert r1[0].id == r2[0].id # same block ID
# -- AssistantMessage with ToolUseBlock --------------------------------------


def test_tool_use_emits_input_start_and_available():
adapter = _adapter()
msg = AssistantMessage(
content=[ToolUseBlock(id="tool-1", name="find_agent", input={"q": "x"})],
model="test",
)
results = adapter.convert_message(msg)
assert len(results) == 2
assert isinstance(results[0], StreamToolInputStart)
assert results[0].toolCallId == "tool-1"
assert results[0].toolName == "find_agent"
assert isinstance(results[1], StreamToolInputAvailable)
assert results[1].input == {"q": "x"}


def test_text_then_tool_ends_text_block():
adapter = _adapter()
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name="tool", input={})], model="test"
)
adapter.convert_message(text_msg)
results = adapter.convert_message(tool_msg)
# Should have: TextEnd, ToolInputStart, ToolInputAvailable
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamToolInputStart)


# -- UserMessage with ToolResultBlock ----------------------------------------


def test_tool_result_emits_output():
adapter = _adapter()
# First register the tool call
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name="find_agent", input={})], model="test"
)
adapter.convert_message(tool_msg)

# Now send tool result
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
)
results = adapter.convert_message(result_msg)
assert len(results) == 1
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].toolCallId == "t1"
assert results[0].toolName == "find_agent"
assert results[0].output == "found 3 agents"
assert results[0].success is True


def test_tool_result_error():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name="run_agent", input={})], model="test"
)
)
result_msg = UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="timeout", is_error=True)]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].success is False


def test_tool_result_list_content():
adapter = _adapter()
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name="tool", input={})], model="test"
)
)
result_msg = UserMessage(
content=[
ToolResultBlock(
tool_use_id="t1",
content=[
{"type": "text", "text": "line1"},
{"type": "text", "text": "line2"},
],
)
]
)
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].output == "line1line2"


def test_string_user_message_ignored():
"""A plain string UserMessage (not tool results) produces no output."""
adapter = _adapter()
results = adapter.convert_message(UserMessage(content="hello"))
assert results == []
# -- ResultMessage -----------------------------------------------------------


def test_result_success_emits_finish():
adapter = _adapter()
# Start some text first
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="done")], model="test")
)
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
results = adapter.convert_message(msg)
# TextEnd + StreamFinish
assert len(results) == 2
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamFinish)


def test_result_error_emits_error_and_finish():
adapter = _adapter()
msg = ResultMessage(
subtype="error",
duration_ms=100,
duration_api_ms=50,
is_error=True,
num_turns=0,
session_id="s1",
result="API rate limited",
)
results = adapter.convert_message(msg)
assert len(results) == 2
assert isinstance(results[0], StreamError)
assert "API rate limited" in results[0].errorText
assert isinstance(results[1], StreamFinish)


# -- Text after tools (new block ID) ----------------------------------------


def test_text_after_tool_gets_new_block_id():
adapter = _adapter()
# Text -> Tool -> Text should get a new text block ID
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="before")], model="test")
)
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name="tool", input={})], model="test"
)
)
results = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="after")], model="test")
)
# Should get StreamTextStart (new block) + StreamTextDelta
assert len(results) == 2
assert isinstance(results[0], StreamTextStart)
assert isinstance(results[1], StreamTextDelta)
assert results[1].delta == "after"
# -- Full conversation flow --------------------------------------------------


def test_full_conversation_flow():
"""Simulate a complete conversation: init -> text -> tool -> result -> text -> finish."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []

# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Assistant text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="Let me search")], model="test")
)
)
# 3. Tool use
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="t1", name="find_agent", input={"query": "email"})
],
model="test",
)
)
)
# 4. Tool result
all_responses.extend(
adapter.convert_message(
UserMessage(
content=[ToolResultBlock(tool_use_id="t1", content="Found 2 agents")]
)
)
)
# 5. More text
all_responses.extend(
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="I found 2")], model="test")
)
)
# 6. Result
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=500,
duration_api_ms=400,
is_error=False,
num_turns=2,
session_id="s1",
)
)
)

types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamTextStart",
"StreamTextDelta", # "Let me search"
"StreamTextEnd", # closed before tool
"StreamToolInputStart",
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # tool result
"StreamTextStart", # new block after tool
"StreamTextDelta", # "I found 2"
"StreamTextEnd", # closed by result
"StreamFinish",
]
@@ -6,7 +6,7 @@ import json
import logging
import os
import uuid
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any

import openai
@@ -63,6 +63,58 @@ def _cleanup_sdk_tool_results() -> None:
pass


def _build_conversation_messages(
session: ChatSession,
) -> AsyncIterator[dict[str, Any]]:
"""Build an async iterator of SDK-compatible message dicts from session history.

Yields structured user/assistant turns that the SDK writes directly to the
CLI's stdin. This gives the model native conversation context (enabling
turn-level compaction for long conversations) without any file I/O.

Only prior messages are yielded; the current (last) user message is
appended at the end so the SDK processes it as the new query.
"""

async def _iter() -> AsyncIterator[dict[str, Any]]:
# Yield all messages except the last (current user message)
for msg in session.messages[:-1]:
if msg.role == "user":
yield {
"type": "user",
"message": {
"role": "user",
"content": msg.content or "",
},
"session_id": session.session_id,
}
elif msg.role == "assistant" and msg.content:
yield {
"type": "assistant",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": msg.content}],
},
"session_id": session.session_id,
}
# Skip tool messages — the assistant's text already captures the
# key information and tool IDs won't match across sessions.

# Yield the current user message last
current = session.messages[-1] if session.messages else None
if current and current.role == "user":
yield {
"type": "user",
"message": {
"role": "user",
"content": current.content or "",
},
"session_id": session.session_id,
}

return _iter()


DEFAULT_SYSTEM_PROMPT = """You are **Otto**, an AI Co-Pilot for AutoGPT and a Forward-Deployed Automation Engineer serving small business owners. Your mission is to help users automate business tasks with AI by delivering tangible value through working automations—not through documentation or lengthy explanations.

Here is everything you know about the current user from previous interactions:
@@ -154,42 +206,6 @@ async def _build_system_prompt(
return DEFAULT_SYSTEM_PROMPT.replace("{users_information}", context), understanding


def _format_conversation_history(session: ChatSession) -> str:
"""Format conversation history as a prompt context.

Passes full history to the SDK — the SDK handles context compaction
automatically when the context window approaches its limit.
"""
if not session.messages:
return ""

# Get all messages except the last user message (which will be the prompt)
messages = session.messages[:-1] if session.messages else []
if not messages:
return ""

history_parts = ["<conversation_history>"]

for msg in messages:
if msg.role == "user":
history_parts.append(f"User: {msg.content or ''}")
elif msg.role == "assistant":
# Only include text content, skip tool call metadata
# (tool calls are noise for history context)
if msg.content:
history_parts.append(f"Assistant: {msg.content}")
# Skip tool result messages — they're not useful for conversation context

history_parts.append("</conversation_history>")
history_parts.append("")
history_parts.append(
"Continue this conversation. Respond to the user's latest message:"
)
history_parts.append("")

return "\n".join(history_parts)


async def _generate_session_title(
message: str,
user_id: str | None = None,
@@ -310,127 +326,130 @@ async def stream_chat_completion_sdk(
mcp_servers={"copilot": mcp_server}, # type: ignore[arg-type]
allowed_tools=COPILOT_TOOL_NAMES,
hooks=create_security_hooks(user_id), # type: ignore[arg-type]
continue_conversation=True, # Enable conversation continuation
continue_conversation=True,
)

adapter = SDKResponseAdapter(message_id=message_id)
adapter.set_task_id(task_id)

async with ClaudeSDKClient(options=options) as client:
# Build prompt with conversation history for context
# The SDK doesn't support replaying full conversation history,
# so we include it as context in the prompt
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
try:
async with ClaudeSDKClient(options=options) as client:
# Determine the current user message
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""

# Include conversation history if there are prior messages
if len(session.messages) > 1:
history_context = _format_conversation_history(session)
prompt = f"{history_context}{current_message}"
else:
prompt = current_message
# Guard against empty messages
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return

logger.info(
f"[SDK] Prompt built: {len(prompt)} chars, "
f"{len(session.messages)} messages in session"
)
# For multi-turn conversations, pass structured history
# as an AsyncIterable so the CLI sees native turns and
# can do turn-level compaction. For first messages, just
# send the string directly.
if len(session.messages) > 1:
history_iter = _build_conversation_messages(session)
await client.query(history_iter, session_id=session_id)
logger.info(
f"[SDK] Structured history: "
f"{len(session.messages) - 1} prior messages"
)
else:
await client.query(current_message, session_id=session_id)
logger.info("[SDK] New conversation")

# Guard against empty prompts
if not prompt.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
# Track assistant response to save to session
# We may need multiple assistant messages if text comes after tool results
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False # Track if we've received tool results

await client.query(prompt, session_id=session_id)
# Receive messages from the SDK
async for sdk_msg in client.receive_messages():
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
# Track assistant response to save to session
# We may need multiple assistant messages if text comes after tool results
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False # Track if we've received tool results
# Accumulate text deltas into assistant response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, create new assistant message for post-tool text
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = [] # Reset for new message
session.messages.append(assistant_response)
has_tool_results = False
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True

# Receive messages from the SDK
async for sdk_msg in client.receive_messages():
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response

# Accumulate text deltas into assistant response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, create new assistant message for post-tool text
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
# Track tool calls on the assistant message
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(
response.input or {}
),
},
}
)
accumulated_tool_calls = [] # Reset for new message
session.messages.append(assistant_response)
has_tool_results = False
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
# Update assistant message with tool calls
assistant_response.tool_calls = accumulated_tool_calls
# Append assistant message if not already (tool-only response)
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True

# Track tool calls on the assistant message
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
# Update assistant message with tool calls
assistant_response.tool_calls = accumulated_tool_calls
# Append assistant message if not already (tool-only response)
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
)
has_tool_results = True
has_tool_results = True

elif isinstance(response, StreamFinish):
stream_completed = True
elif isinstance(response, StreamFinish):
stream_completed = True

# Break out of the message loop if we received finish signal
if stream_completed:
break
# Break out of the message loop if we received finish signal
if stream_completed:
break

# Ensure assistant response is saved even if no text deltas
# (e.g., only tool calls were made)
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
# Ensure assistant response is saved even if no text deltas
# (e.g., only tool calls were made)
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)

# Clean up SDK tool-result files to avoid accumulation
_cleanup_sdk_tool_results()
finally:
# Always clean up SDK tool-result files, even on error
_cleanup_sdk_tool_results()

except ImportError:
logger.warning(