fix(copilot): don't flush parallel tool calls prematurely

The SDK sends parallel tool calls as separate AssistantMessages each
containing only ToolUseBlocks.  The flush logic treated each new
AssistantMessage as a new turn and prematurely emitted empty output for
prior tools, causing spinners to disappear and the stream to appear
stuck.

Skip flush and wait_for_stash when the incoming AssistantMessage is a
parallel continuation (contains only ToolUseBlocks).  Also prevent
duplicate StreamToolOutputAvailable for already-resolved tool calls.
This commit is contained in:
Zamil Majdy
2026-02-20 11:43:44 +07:00
parent 3a38b5e9bd
commit a408b45542
3 changed files with 132 additions and 5 deletions

View File

@@ -83,7 +83,12 @@ class SDKResponseAdapter:
elif isinstance(sdk_message, AssistantMessage):
# Flush any SDK built-in tool calls that didn't get a UserMessage
# result (e.g. WebSearch, Read handled internally by the CLI).
self._flush_unresolved_tool_calls(responses)
# BUT skip flush when this AssistantMessage is a parallel tool
# continuation (contains only ToolUseBlocks) — the prior tools
# are still executing concurrently and haven't finished yet.
is_tool_only = all(isinstance(b, ToolUseBlock) for b in sdk_message.content)
if not is_tool_only:
self._flush_unresolved_tool_calls(responses)
# After tool results, the SDK sends a new AssistantMessage for the
# next LLM turn. Open a new step if the previous one was closed.
@@ -126,6 +131,11 @@ class SDKResponseAdapter:
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
# Skip if already resolved (e.g. by flush) — the real
# result supersedes the empty flush, but re-emitting
# would confuse the frontend's state machine.
if block.tool_use_id in self.resolved_tool_calls:
continue
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
@@ -150,7 +160,11 @@ class SDKResponseAdapter:
# Handle SDK built-in tool results carried via parent_tool_use_id
# instead of (or in addition to) ToolResultBlock content.
parent_id = sdk_message.parent_tool_use_id
if parent_id and parent_id not in resolved_in_blocks:
if (
parent_id
and parent_id not in resolved_in_blocks
and parent_id not in self.resolved_tool_calls
):
tool_info = self.current_tool_calls.get(parent_id, {})
tool_name = tool_info.get("name", "unknown")

View File

@@ -580,3 +580,101 @@ async def test_wait_for_stash_already_stashed():
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
# -- Parallel tool call tests --
def test_parallel_tool_calls_not_flushed_prematurely():
"""Parallel tool calls should NOT be flushed when the next AssistantMessage
only contains ToolUseBlocks (parallel continuation)."""
adapter = SDKResponseAdapter()
# Init
adapter.convert_message(SystemMessage(subtype="init", data={}))
# First AssistantMessage: tool call #1
msg1 = AssistantMessage(
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
model="test",
)
r1 = adapter.convert_message(msg1)
assert any(isinstance(r, StreamToolInputAvailable) for r in r1)
assert adapter.has_unresolved_tool_calls
# Second AssistantMessage: tool call #2 (parallel continuation)
msg2 = AssistantMessage(
content=[ToolUseBlock(id="t2", name="WebSearch", input={"q": "bar"})],
model="test",
)
r2 = adapter.convert_message(msg2)
# No flush should have happened — t1 should NOT have StreamToolOutputAvailable
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
assert len(output_events) == 0, (
f"Tool-only AssistantMessage should not flush prior tools, "
f"but got {len(output_events)} output events"
)
# Both t1 and t2 should still be unresolved
assert "t1" not in adapter.resolved_tool_calls
assert "t2" not in adapter.resolved_tool_calls
def test_text_assistant_message_flushes_prior_tools():
"""An AssistantMessage with text (new turn) should flush unresolved tools."""
adapter = SDKResponseAdapter()
# Init
adapter.convert_message(SystemMessage(subtype="init", data={}))
# Tool call
msg1 = AssistantMessage(
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
model="test",
)
adapter.convert_message(msg1)
assert adapter.has_unresolved_tool_calls
# Text AssistantMessage (new turn after tools completed)
msg2 = AssistantMessage(
content=[TextBlock(text="Here are the results")],
model="test",
)
r2 = adapter.convert_message(msg2)
# Flush SHOULD have happened — t1 gets empty output
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
assert len(output_events) == 1
assert output_events[0].toolCallId == "t1"
assert "t1" in adapter.resolved_tool_calls
def test_already_resolved_tool_skipped_in_user_message():
"""A tool result in UserMessage should be skipped if already resolved by flush."""
adapter = SDKResponseAdapter()
adapter.convert_message(SystemMessage(subtype="init", data={}))
# Tool call + flush via text message
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name="WebSearch", input={})],
model="test",
)
)
adapter.convert_message(
AssistantMessage(
content=[TextBlock(text="Done")],
model="test",
)
)
assert "t1" in adapter.resolved_tool_calls
# Now UserMessage arrives with the real result — should be skipped
user_msg = UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="real")])
r = adapter.convert_message(user_msg)
output_events = [r_ for r_ in r if isinstance(r_, StreamToolOutputAvailable)]
assert (
len(output_events) == 0
), "Already-resolved tool should not emit duplicate output"

View File

@@ -745,10 +745,25 @@ async def stream_chat_completion_sdk(
# awaits an asyncio.Event signaled by stash_pending_tool_output(),
# completing as soon as the hook finishes (typically <1ms).
# The sleep(0) after lets any remaining concurrent hooks complete.
from claude_agent_sdk import AssistantMessage, ResultMessage
#
# Skip for parallel tool continuations: when the SDK sends
# parallel tool calls as separate AssistantMessages (each
# containing only ToolUseBlocks), we must NOT wait/flush
# — the prior tools are still executing concurrently.
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
ToolUseBlock,
)
if adapter.has_unresolved_tool_calls and isinstance(
sdk_msg, (AssistantMessage, ResultMessage)
is_parallel_continuation = isinstance(
sdk_msg, AssistantMessage
) and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
if (
adapter.has_unresolved_tool_calls
and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
and not is_parallel_continuation
):
if await wait_for_stash(timeout=0.5):
await asyncio.sleep(0)