feat(backend/chat): Add StreamStartStep/StreamFinishStep to SDK adapter

The non-SDK path emits step boundaries (StreamStartStep/StreamFinishStep)
around each LLM turn and tool cycle. The SDK adapter did not emit them,
so the frontend had no visual step framing for tool calls.

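Now the SDK adapter emits:
- StreamStartStep right after StreamStart on init, and at the start of an
  AssistantMessage (a new LLM turn) whenever no step is currently open
- StreamFinishStep after a UserMessage carrying tool results, and on a
  ResultMessage with a step still open, before the final StreamFinish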
This commit is contained in:
Zamil Majdy
2026-02-11 20:18:27 +04:00
parent 82c483d6c8
commit b14b3803ad
2 changed files with 84 additions and 40 deletions

View File

@@ -24,7 +24,9 @@ from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -50,6 +52,7 @@ class SDKResponseAdapter:
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.task_id: str | None = None
self.step_open = False
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
@@ -64,8 +67,17 @@ class SDKResponseAdapter:
responses.append(
StreamStart(messageId=self.message_id, taskId=self.task_id)
)
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())
self.step_open = True
elif isinstance(sdk_message, AssistantMessage):
# After tool results, the SDK sends a new AssistantMessage for the
# next LLM turn. Open a new step if the previous one was closed.
if not self.step_open:
responses.append(StreamStartStep())
self.step_open = True
for block in sdk_message.content:
if isinstance(block, TextBlock):
if block.text:
@@ -90,7 +102,7 @@ class SDKResponseAdapter:
self.current_tool_calls[block.id] = {"name": block.name}
elif isinstance(sdk_message, UserMessage):
# UserMessage carries tool results back from tool execution
# UserMessage carries tool results back from tool execution.
content = sdk_message.content
blocks = content if isinstance(content, list) else []
for block in blocks:
@@ -107,11 +119,21 @@ class SDKResponseAdapter:
)
)
elif isinstance(sdk_message, ResultMessage):
if sdk_message.subtype == "success":
self._end_text_if_open(responses)
responses.append(StreamFinish())
# Close the current step after tool results — the next
# AssistantMessage will open a new step for the continuation.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
elif isinstance(sdk_message, ResultMessage):
self._end_text_if_open(responses)
# Close the step before finishing.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
if sdk_message.subtype == "success":
responses.append(StreamFinish())
elif sdk_message.subtype in ("error", "error_during_execution"):
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
responses.append(

View File

@@ -14,7 +14,9 @@ from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
@@ -35,13 +37,14 @@ def _adapter() -> SDKResponseAdapter:
# -- SystemMessage -----------------------------------------------------------
def test_system_init_emits_start():
def test_system_init_emits_start_and_step():
adapter = _adapter()
results = adapter.convert_message(SystemMessage(subtype="init", data={}))
assert len(results) == 1
assert len(results) == 2
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].taskId == "task-1"
assert isinstance(results[1], StreamStartStep)
def test_system_non_init_emits_nothing():
@@ -53,21 +56,24 @@ def test_system_non_init_emits_nothing():
# -- AssistantMessage with TextBlock -----------------------------------------
def test_text_block_emits_start_and_delta():
def test_text_block_emits_step_start_and_delta():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="hello")], model="test")
results = adapter.convert_message(msg)
assert len(results) == 2
assert isinstance(results[0], StreamTextStart)
assert isinstance(results[1], StreamTextDelta)
assert results[1].delta == "hello"
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "hello"
def test_empty_text_block_is_skipped():
def test_empty_text_block_emits_only_step():
adapter = _adapter()
msg = AssistantMessage(content=[TextBlock(text="")], model="test")
results = adapter.convert_message(msg)
assert results == []
# Empty text skipped, but step still opens
assert len(results) == 1
assert isinstance(results[0], StreamStartStep)
def test_multiple_text_deltas_reuse_block_id():
@@ -76,12 +82,13 @@ def test_multiple_text_deltas_reuse_block_id():
msg2 = AssistantMessage(content=[TextBlock(text="b")], model="test")
r1 = adapter.convert_message(msg1)
r2 = adapter.convert_message(msg2)
# First gets start+delta, second only delta (block already started)
assert len(r1) == 2
# First gets step+start+delta, second only delta (block & step already started)
assert len(r1) == 3
assert isinstance(r1[0], StreamStartStep)
assert isinstance(r1[1], StreamTextStart)
assert len(r2) == 1
assert isinstance(r2[0], StreamTextDelta)
assert isinstance(r1[0], StreamTextStart)
assert r1[0].id == r2[0].id # same block ID
assert r1[1].id == r2[0].id # same block ID
# -- AssistantMessage with ToolUseBlock --------------------------------------
@@ -94,12 +101,13 @@ def test_tool_use_emits_input_start_and_available():
model="test",
)
results = adapter.convert_message(msg)
assert len(results) == 2
assert isinstance(results[0], StreamToolInputStart)
assert results[0].toolCallId == "tool-1"
assert results[0].toolName == "find_agent"
assert isinstance(results[1], StreamToolInputAvailable)
assert results[1].input == {"q": "x"}
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamToolInputStart)
assert results[1].toolCallId == "tool-1"
assert results[1].toolName == "find_agent"
assert isinstance(results[2], StreamToolInputAvailable)
assert results[2].input == {"q": "x"}
def test_text_then_tool_ends_text_block():
@@ -108,9 +116,9 @@ def test_text_then_tool_ends_text_block():
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name="tool", input={})], model="test"
)
adapter.convert_message(text_msg)
adapter.convert_message(text_msg) # opens step + text
results = adapter.convert_message(tool_msg)
# Should have: TextEnd, ToolInputStart, ToolInputAvailable
# Step already open, so: TextEnd, ToolInputStart, ToolInputAvailable
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamToolInputStart)
@@ -119,9 +127,9 @@ def test_text_then_tool_ends_text_block():
# -- UserMessage with ToolResultBlock ----------------------------------------
def test_tool_result_emits_output():
def test_tool_result_emits_output_and_finish_step():
adapter = _adapter()
# First register the tool call
# First register the tool call (opens step)
tool_msg = AssistantMessage(
content=[ToolUseBlock(id="t1", name="find_agent", input={})], model="test"
)
@@ -132,12 +140,13 @@ def test_tool_result_emits_output():
content=[ToolResultBlock(tool_use_id="t1", content="found 3 agents")]
)
results = adapter.convert_message(result_msg)
assert len(results) == 1
assert len(results) == 2
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].toolCallId == "t1"
assert results[0].toolName == "find_agent"
assert results[0].output == "found 3 agents"
assert results[0].success is True
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_error():
@@ -153,6 +162,7 @@ def test_tool_result_error():
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].success is False
assert isinstance(results[1], StreamFinishStep)
def test_tool_result_list_content():
@@ -176,6 +186,7 @@ def test_tool_result_list_content():
results = adapter.convert_message(result_msg)
assert isinstance(results[0], StreamToolOutputAvailable)
assert results[0].output == "line1line2"
assert isinstance(results[1], StreamFinishStep)
def test_string_user_message_ignored():
@@ -188,9 +199,9 @@ def test_string_user_message_ignored():
# -- ResultMessage -----------------------------------------------------------
def test_result_success_emits_finish():
def test_result_success_emits_finish_step_and_finish():
adapter = _adapter()
# Start some text first
# Start some text first (opens step)
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="done")], model="test")
)
@@ -203,10 +214,11 @@ def test_result_success_emits_finish():
session_id="s1",
)
results = adapter.convert_message(msg)
# TextEnd + StreamFinish
assert len(results) == 2
# TextEnd + FinishStep + StreamFinish
assert len(results) == 3
assert isinstance(results[0], StreamTextEnd)
assert isinstance(results[1], StreamFinish)
assert isinstance(results[1], StreamFinishStep)
assert isinstance(results[2], StreamFinish)
def test_result_error_emits_error_and_finish():
@@ -221,6 +233,7 @@ def test_result_error_emits_error_and_finish():
result="API rate limited",
)
results = adapter.convert_message(msg)
# No step was open, so no FinishStep — just Error + Finish
assert len(results) == 2
assert isinstance(results[0], StreamError)
assert "API rate limited" in results[0].errorText
@@ -232,7 +245,7 @@ def test_result_error_emits_error_and_finish():
def test_text_after_tool_gets_new_block_id():
adapter = _adapter()
# Text -> Tool -> Text should get a new text block ID
# Text -> Tool -> ToolResult -> Text should get a new text block ID and step
adapter.convert_message(
AssistantMessage(content=[TextBlock(text="before")], model="test")
)
@@ -241,14 +254,19 @@ def test_text_after_tool_gets_new_block_id():
content=[ToolUseBlock(id="t1", name="tool", input={})], model="test"
)
)
# Send tool result (closes step)
adapter.convert_message(
UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="ok")])
)
results = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="after")], model="test")
)
# Should get StreamTextStart (new block) + StreamTextDelta
assert len(results) == 2
assert isinstance(results[0], StreamTextStart)
assert isinstance(results[1], StreamTextDelta)
assert results[1].delta == "after"
# Should get StreamStartStep (new step) + StreamTextStart (new block) + StreamTextDelta
assert len(results) == 3
assert isinstance(results[0], StreamStartStep)
assert isinstance(results[1], StreamTextStart)
assert isinstance(results[2], StreamTextDelta)
assert results[2].delta == "after"
# -- Full conversation flow --------------------------------------------------
@@ -311,14 +329,18 @@ def test_full_conversation_flow():
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep", # step 1: text + tool call
"StreamTextStart",
"StreamTextDelta", # "Let me search"
"StreamTextEnd", # closed before tool
"StreamToolInputStart",
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # tool result
"StreamFinishStep", # step 1 closed after tool result
"StreamStartStep", # step 2: continuation text
"StreamTextStart", # new block after tool
"StreamTextDelta", # "I found 2"
"StreamTextEnd", # closed by result
"StreamFinishStep", # step 2 closed
"StreamFinish",
]