fix(copilot): flush unresolved tool calls on stream error to prevent stale frontend UI

When a mid-stream exception interrupts _run_stream_attempt, the post-stream
cleanup (tool call flushing, text-end closing) was bypassed. This left the
frontend with stale UI elements (e.g. active spinners) for tools that started
but never received completion events.

Add error-path cleanup after the retry loop: call _end_text_if_open and
_flush_unresolved_tool_calls on the adapter before yielding the StreamError,
so the frontend receives proper closure events for any in-flight tool calls.
This commit is contained in:
Zamil Majdy
2026-03-14 22:50:31 +07:00
parent fc844fde1f
commit 045096d863
5 changed files with 207 additions and 74 deletions

View File

@@ -115,7 +115,7 @@ class ChatConfig(BaseSettings):
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=900, # 15 min safety net — raised from 5 min to accommodate compaction retries
default=600, # 10 min safety net — allows headroom for compaction retries
description="E2B sandbox running-time timeout (seconds). "
"E2B timeout is wall-clock (not idle). Explicit per-turn pause is the primary "
"mechanism; this is the safety net.",

View File

@@ -10,6 +10,7 @@ import pytest
from backend.util import json
from .conftest import build_test_transcript as _build_transcript
from .service import _is_prompt_too_long
from .transcript import (
_flatten_assistant_content,
_flatten_tool_result_content,
@@ -55,7 +56,7 @@ class TestFlattenAssistantContent:
]
result = _flatten_assistant_content(blocks)
assert "See this image:" in result
assert "[image]" in result
assert "[__image__]" in result
def test_empty(self):
assert _flatten_assistant_content([]) == ""
@@ -116,7 +117,7 @@ class TestFlattenToolResultContent:
def test_unknown_block_type_preserved_as_placeholder(self):
blocks = [{"type": "image", "source": {"type": "base64", "data": "..."}}]
result = _flatten_tool_result_content(blocks)
assert "[image]" in result
assert "[__image__]" in result
# ---------------------------------------------------------------------------
@@ -374,7 +375,7 @@ class TestCompactTranscript:
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
)(),
):
result = await compact_transcript(transcript)
result = await compact_transcript(transcript, model="test-model")
assert result is None
@pytest.mark.asyncio
@@ -413,7 +414,7 @@ class TestCompactTranscript:
return_value=mock_result,
),
):
result = await compact_transcript(transcript)
result = await compact_transcript(transcript, model="test-model")
assert result is None
@pytest.mark.asyncio
@@ -456,7 +457,7 @@ class TestCompactTranscript:
return_value=mock_result,
),
):
result = await compact_transcript(transcript)
result = await compact_transcript(transcript, model="test-model")
assert result is not None
assert validate_transcript(result)
msgs = _transcript_to_messages(result)
@@ -485,5 +486,67 @@ class TestCompactTranscript:
side_effect=RuntimeError("LLM unavailable"),
),
):
result = await compact_transcript(transcript)
result = await compact_transcript(transcript, model="test-model")
assert result is None
# ---------------------------------------------------------------------------
# _is_prompt_too_long
# ---------------------------------------------------------------------------
class TestIsPromptTooLong:
    """Unit tests for _is_prompt_too_long pattern matching."""

    def test_prompt_is_too_long(self):
        """The canonical 'prompt is too long' phrasing matches."""
        exc = RuntimeError("prompt is too long for model context")
        assert _is_prompt_too_long(exc) is True

    def test_request_too_large(self):
        """'request too large' provider wording matches."""
        exc = Exception("request too large: 250000 tokens")
        assert _is_prompt_too_long(exc) is True

    def test_maximum_context_length(self):
        """'maximum context length' wording matches."""
        exc = ValueError("maximum context length exceeded")
        assert _is_prompt_too_long(exc) is True

    def test_context_length_exceeded(self):
        """Machine-readable 'context_length_exceeded' code matches."""
        exc = Exception("context_length_exceeded")
        assert _is_prompt_too_long(exc) is True

    def test_input_tokens_exceed(self):
        """'input tokens exceed' wording matches."""
        exc = Exception("input tokens exceed the max_tokens limit")
        assert _is_prompt_too_long(exc) is True

    def test_input_is_too_long(self):
        """'input is too long' wording matches."""
        exc = Exception("input is too long for the model")
        assert _is_prompt_too_long(exc) is True

    def test_content_length_exceeds(self):
        """'content length exceeds' wording matches."""
        exc = Exception("content length exceeds maximum")
        assert _is_prompt_too_long(exc) is True

    def test_unrelated_error_returns_false(self):
        """A generic transport failure is not a context-size error."""
        exc = RuntimeError("network timeout")
        assert _is_prompt_too_long(exc) is False

    def test_auth_error_returns_false(self):
        """Authentication failures must not trigger compaction."""
        exc = Exception("authentication failed: invalid API key")
        assert _is_prompt_too_long(exc) is False

    def test_chained_exception_detected(self):
        """Prompt-too-long error wrapped in another exception is detected."""
        wrapped = RuntimeError("prompt is too long")
        wrapper = Exception("SDK error")
        wrapper.__cause__ = wrapped
        assert _is_prompt_too_long(wrapper) is True

    def test_case_insensitive(self):
        """Matching ignores case."""
        exc = Exception("PROMPT IS TOO LONG")
        assert _is_prompt_too_long(exc) is True

    def test_old_max_tokens_exceeded_not_matched(self):
        """The old broad 'max_tokens_exceeded' pattern was removed.
        Only 'input tokens exceed' should match now."""
        exc = Exception("max_tokens_exceeded")
        assert _is_prompt_too_long(exc) is False

View File

@@ -118,7 +118,7 @@ class TestScenarioCompactAndRetry:
return_value=mock_result,
),
):
result = await compact_transcript(original)
result = await compact_transcript(original, model="test-model")
assert result is not None
assert result != original # Must be different
@@ -172,7 +172,7 @@ class TestScenarioCompactFailsFallback:
side_effect=RuntimeError("LLM unavailable"),
),
):
result = await compact_transcript(transcript)
result = await compact_transcript(transcript, model="test-model")
assert result is None
def test_fresh_builder_after_transcript_drop(self):
@@ -263,7 +263,7 @@ class TestScenarioDoubleFailDBFallback:
return_value=mock_result,
),
):
result = await compact_transcript(transcript)
result = await compact_transcript(transcript, model="test-model")
# Compaction succeeded — caller would use this for attempt 2
assert result is not None
@@ -331,7 +331,7 @@ class TestScenarioCompactionIdentical:
return_value=mock_result,
),
):
result = await compact_transcript(transcript)
result = await compact_transcript(transcript, model="test-model")
# Returns None — signals caller to fall through to DB fallback
assert result is None
@@ -587,7 +587,7 @@ class TestRetryEdgeCases:
"Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"}
)(),
):
result = await compact_transcript(transcript)
result = await compact_transcript(transcript, model="test-model")
assert result is None
@pytest.mark.asyncio
@@ -618,7 +618,7 @@ class TestRetryEdgeCases:
return_value=mock_result,
),
):
result = await compact_transcript(transcript)
result = await compact_transcript(transcript, model="test-model")
assert result is not None
assert result != transcript
@@ -726,7 +726,7 @@ class TestRetryStateReset:
side_effect=RuntimeError("boom"),
),
):
compacted = await compact_transcript(transcript)
compacted = await compact_transcript(transcript, model="test-model")
# compact_transcript returns None on failure
assert compacted is None

View File

@@ -12,6 +12,7 @@ import subprocess
import sys
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any, NamedTuple, cast
import openai
@@ -102,7 +103,7 @@ _PROMPT_TOO_LONG_PATTERNS: tuple[str, ...] = (
"request too large",
"maximum context length",
"context_length_exceeded",
"max_tokens_exceeded",
"input tokens exceed",
"input is too long",
"content length exceeds",
)
@@ -133,6 +134,24 @@ class ReducedContext(NamedTuple):
tried_compaction: bool
@dataclass
class _RetryState:
    """Mutable state threaded through ``_run_stream_attempt`` by the retry loop.

    Collects every value the retry loop reassigns between attempts so the
    inner stream generator reads them from one shared object instead of
    relying on closure-variable reassignment.
    """

    # SDK options — rebuilt by the retry loop before each attempt.
    options: ClaudeAgentOptions
    # Query text sent to the SDK (possibly rebuilt after compaction).
    query_message: str
    # Whether transcript compaction ran while building the query.
    was_compacted: bool
    # Whether this attempt resumes a prior CLI session file.
    use_resume: bool
    # Path of the session file to resume from, when resuming.
    resume_file: str | None
    # Count of transcript messages already accounted for this attempt.
    transcript_msg_count: int
    # SDK-message → stream-event adapter (fresh instance per attempt).
    adapter: SDKResponseAdapter
    # Builder accumulating the durable JSONL transcript.
    transcript_builder: TranscriptBuilder
async def _reduce_context(
transcript_content: str,
tried_compaction: bool,
@@ -871,6 +890,7 @@ async def stream_chat_completion_sdk(
_otel_ctx: Any = None
transcript_caused_error = False
transcript_content: str = ""
state: _RetryState | None = None
# Make sure there is no more code between the lock acquisition and try-block.
try:
@@ -1091,7 +1111,9 @@ async def stream_chat_completion_sdk(
_tried_compaction = False
async def _run_stream_attempt() -> AsyncGenerator[StreamBaseResponse, None]:
async def _run_stream_attempt(
state: _RetryState,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Run one SDK streaming attempt.
Opens a ``ClaudeSDKClient``, sends the query, iterates SDK
@@ -1102,12 +1124,11 @@ async def stream_chat_completion_sdk(
Yields stream events. On stream error the exception propagates
to the caller so the retry loop can rollback and retry.
Outer-scope variable contract (closure):
Reassigned between retries by the retry loop:
``options``, ``query_message``, ``was_compacted``,
``transcript_builder``, ``adapter``, ``use_resume``,
``resume_file``, ``transcript_msg_count``
Read-only (unchanged across retries):
Args:
state: Mutable retry state — holds values that the retry loop
modifies between attempts (options, query, adapter, etc.).
Read-only outer-scope variables (unchanged across retries):
``session``, ``session_id``, ``sdk_cwd``, ``log_prefix``,
``compaction``, ``attachments``, ``current_message``,
``file_ids``, ``lock``, ``message_id``, ``e2b_sandbox``
@@ -1118,27 +1139,27 @@ async def stream_chat_completion_sdk(
has_tool_results = False
stream_completed = False
async with ClaudeSDKClient(options=options) as client:
async with ClaudeSDKClient(options=state.options) as client:
logger.info(
"%s Sending query — resume=%s, total_msgs=%d, "
"query_len=%d, attached_files=%d, image_blocks=%d",
log_prefix,
use_resume,
state.use_resume,
len(session.messages),
len(query_message),
len(state.query_message),
len(file_ids) if file_ids else 0,
len(attachments.image_blocks),
)
compaction.reset_for_query()
if was_compacted:
if state.was_compacted:
for ev in compaction.emit_pre_query(session):
yield ev
if attachments.image_blocks:
content_blocks: list[dict[str, Any]] = [
*attachments.image_blocks,
{"type": "text", "text": query_message},
{"type": "text", "text": state.query_message},
]
user_msg = {
"type": "user",
@@ -1150,15 +1171,15 @@ async def stream_chat_completion_sdk(
await client._transport.write( # noqa: SLF001
json.dumps(user_msg) + "\n"
)
transcript_builder.append_user(
state.transcript_builder.append_user(
content=[
*attachments.image_blocks,
{"type": "text", "text": current_message},
]
)
else:
await client.query(query_message, session_id=session_id)
transcript_builder.append_user(content=current_message)
await client.query(state.query_message, session_id=session_id)
state.transcript_builder.append_user(content=current_message)
async for sdk_msg in _iter_sdk_messages(client):
# Heartbeat sentinel — refresh lock and keep SSE alive
@@ -1174,10 +1195,10 @@ async def stream_chat_completion_sdk(
log_prefix,
type(sdk_msg).__name__,
getattr(sdk_msg, "subtype", ""),
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
len(adapter.current_tool_calls),
len(adapter.resolved_tool_calls),
len(state.adapter.current_tool_calls)
- len(state.adapter.resolved_tool_calls),
len(state.adapter.current_tool_calls),
len(state.adapter.resolved_tool_calls),
)
# Log AssistantMessage API errors (e.g. invalid_request)
@@ -1209,7 +1230,7 @@ async def stream_chat_completion_sdk(
sdk_msg, AssistantMessage
) and all(isinstance(b, ToolUseBlock) for b in sdk_msg.content)
if (
adapter.has_unresolved_tool_calls
state.adapter.has_unresolved_tool_calls
and isinstance(sdk_msg, (AssistantMessage, ResultMessage))
and not is_parallel_continuation
):
@@ -1220,8 +1241,8 @@ async def stream_chat_completion_sdk(
"%s Timed out waiting for PostToolUse "
"hook stash (%d unresolved tool calls)",
log_prefix,
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
len(state.adapter.current_tool_calls)
- len(state.adapter.resolved_tool_calls),
)
# Log ResultMessage details for debugging
@@ -1231,10 +1252,10 @@ async def stream_chat_completion_sdk(
"(unresolved=%d, current=%d, resolved=%d)",
log_prefix,
sdk_msg.subtype,
len(adapter.current_tool_calls)
- len(adapter.resolved_tool_calls),
len(adapter.current_tool_calls),
len(adapter.resolved_tool_calls),
len(state.adapter.current_tool_calls)
- len(state.adapter.resolved_tool_calls),
len(state.adapter.current_tool_calls),
len(state.adapter.resolved_tool_calls),
)
if sdk_msg.subtype in (
"error",
@@ -1258,13 +1279,13 @@ async def stream_chat_completion_sdk(
compact_result.transcript_path,
)
if compacted is not None:
transcript_builder.replace_entries(
state.transcript_builder.replace_entries(
compacted, log_prefix=log_prefix
)
entries_replaced = True
# --- Dispatch adapter responses ---
for response in adapter.convert_message(sdk_msg):
for response in state.adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
@@ -1346,7 +1367,7 @@ async def stream_chat_completion_sdk(
)
)
if not entries_replaced:
transcript_builder.append_tool_result(
state.transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=content,
)
@@ -1362,7 +1383,7 @@ async def stream_chat_completion_sdk(
# Skip if replace_entries just ran — the CLI session
# file already contains this message.
if isinstance(sdk_msg, AssistantMessage) and not entries_replaced:
transcript_builder.append_assistant(
state.transcript_builder.append_assistant(
content_blocks=_format_sdk_content_blocks(sdk_msg.content),
model=sdk_msg.model,
)
@@ -1371,14 +1392,15 @@ async def stream_chat_completion_sdk(
break
# --- Post-stream processing (only on success) ---
if adapter.has_unresolved_tool_calls:
if state.adapter.has_unresolved_tool_calls:
logger.warning(
"%s %d unresolved tool(s) after stream — flushing",
log_prefix,
len(adapter.current_tool_calls) - len(adapter.resolved_tool_calls),
len(state.adapter.current_tool_calls)
- len(state.adapter.resolved_tool_calls),
)
safety_responses: list[StreamBaseResponse] = []
adapter._flush_unresolved_tool_calls(safety_responses)
state.adapter._flush_unresolved_tool_calls(safety_responses)
for response in safety_responses:
if isinstance(
response,
@@ -1391,7 +1413,7 @@ async def stream_chat_completion_sdk(
getattr(response, "toolName", "N/A"),
)
if isinstance(response, StreamToolOutputAvailable):
transcript_builder.append_tool_result(
state.transcript_builder.append_tool_result(
tool_use_id=response.toolCallId,
content=(
response.output
@@ -1407,7 +1429,7 @@ async def stream_chat_completion_sdk(
log_prefix,
)
closing_responses: list[StreamBaseResponse] = []
adapter._end_text_if_open(closing_responses)
state.adapter._end_text_if_open(closing_responses)
for r in closing_responses:
yield r
session.messages.append(
@@ -1429,6 +1451,17 @@ async def stream_chat_completion_sdk(
attempts_exhausted = False
stream_err: Exception | None = None
state = _RetryState(
options=options,
query_message=query_message,
was_compacted=was_compacted,
use_resume=use_resume,
resume_file=resume_file,
transcript_msg_count=transcript_msg_count,
adapter=adapter,
transcript_builder=transcript_builder,
)
for _attempt in range(_MAX_STREAM_ATTEMPTS):
if _attempt > 0:
logger.info(
@@ -1445,11 +1478,11 @@ async def stream_chat_completion_sdk(
sdk_cwd,
log_prefix,
)
transcript_builder = ctx.builder
use_resume = ctx.use_resume
resume_file = ctx.resume_file
state.transcript_builder = ctx.builder
state.use_resume = ctx.use_resume
state.resume_file = ctx.resume_file
_tried_compaction = ctx.tried_compaction
transcript_msg_count = 0
state.transcript_msg_count = 0
if ctx.transcript_lost:
transcript_caused_error = True
@@ -1459,24 +1492,26 @@ async def stream_chat_completion_sdk(
sdk_options_kwargs_retry["resume"] = ctx.resume_file
elif "resume" in sdk_options_kwargs_retry:
del sdk_options_kwargs_retry["resume"]
options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs
query_message, was_compacted = await _build_query_message(
state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs
state.query_message, state.was_compacted = await _build_query_message(
current_message,
session,
use_resume,
transcript_msg_count,
state.use_resume,
state.transcript_msg_count,
session_id,
)
if attachments.hint:
query_message = f"{query_message}\n\n{attachments.hint}"
adapter = SDKResponseAdapter(
state.query_message = f"{state.query_message}\n\n{attachments.hint}"
state.adapter = SDKResponseAdapter(
message_id=message_id, session_id=session_id
)
_pre_attempt_msg_count = len(session.messages)
_events_yielded = 0
try:
async for event in _run_stream_attempt():
async for event in _run_stream_attempt(state):
_events_yielded += 1
yield event
break # Stream completed — exit retry loop
except asyncio.CancelledError:
@@ -1491,15 +1526,28 @@ async def stream_chat_completion_sdk(
stream_err = e
is_context_error = _is_prompt_too_long(e)
logger.warning(
"%s Stream error (attempt %d/%d, context_error=%s): %s",
"%s Stream error (attempt %d/%d, context_error=%s, "
"events_yielded=%d): %s",
log_prefix,
_attempt + 1,
_MAX_STREAM_ATTEMPTS,
is_context_error,
_events_yielded,
stream_err,
exc_info=True,
)
session.messages = session.messages[:_pre_attempt_msg_count]
if _events_yielded > 0:
# Events were already sent to the frontend and cannot be
# unsent. Retrying would produce duplicate/inconsistent
# output, so treat this as a final error.
logger.warning(
"%s Not retrying — %d events already yielded",
log_prefix,
_events_yielded,
)
ended_with_stream_error = True
break
if not is_context_error:
# Non-context errors (network, auth, rate-limit) should
# not trigger compaction — surface the error immediately.
@@ -1519,12 +1567,30 @@ async def stream_chat_completion_sdk(
stream_err,
)
if ended_with_stream_error and state is not None:
# Flush any unresolved tool calls so the frontend can close
# stale UI elements (e.g. spinners) that were started before
# the exception interrupted the stream.
error_flush: list[StreamBaseResponse] = []
state.adapter._end_text_if_open(error_flush)
if state.adapter.has_unresolved_tool_calls:
logger.warning(
"%s Flushing %d unresolved tool(s) after stream error",
log_prefix,
len(state.adapter.current_tool_calls)
- len(state.adapter.resolved_tool_calls),
)
state.adapter._flush_unresolved_tool_calls(error_flush)
for response in error_flush:
yield response
if ended_with_stream_error and stream_err is not None:
# Use distinct error codes: "all_attempts_exhausted" when all
# retries were consumed vs "sdk_stream_error" for non-context
# errors that broke the loop immediately (network, auth, etc.).
safe_err = str(stream_err).replace("\n", " ").replace("\r", "")[:500]
yield StreamError(
errorText=f"SDK stream error: {stream_err}",
errorText=f"SDK stream error: {safe_err}",
code=(
"all_attempts_exhausted"
if attempts_exhausted
@@ -1633,10 +1699,15 @@ async def stream_chat_completion_sdk(
"prompt-too-long error",
log_prefix,
)
elif config.claude_agent_use_resume and user_id and session is not None:
elif (
config.claude_agent_use_resume
and user_id
and session is not None
and state is not None
):
try:
transcript_upload_content = transcript_builder.to_jsonl()
entry_count = transcript_builder.entry_count
transcript_upload_content = state.transcript_builder.to_jsonl()
entry_count = state.transcript_builder.entry_count
if not transcript_upload_content:
logger.warning(

View File

@@ -22,7 +22,6 @@ from dataclasses import dataclass
from pathlib import Path
from uuid import uuid4
from backend.copilot.config import ChatConfig
from backend.util import json
from backend.util.clients import get_openai_client
from backend.util.prompt import CompressResult, compress_context
@@ -611,8 +610,9 @@ def _flatten_assistant_content(blocks: list) -> str:
elif btype == "tool_use":
parts.append(f"[tool_use: {block.get('name', '?')}]")
else:
# Preserve non-text blocks (e.g. image) as placeholders
parts.append(f"[{btype}]")
# Preserve non-text blocks (e.g. image) as placeholders.
# Use __prefix__ to distinguish from literal user text.
parts.append(f"[__{btype}__]")
elif isinstance(block, str):
parts.append(block)
return "\n".join(parts) if parts else ""
@@ -646,9 +646,10 @@ def _flatten_tool_result_content(blocks: list) -> str:
elif isinstance(block, dict) and block.get("type") == "text":
str_parts.append(str(block.get("text", "")))
elif isinstance(block, dict):
# Preserve non-text/non-tool_result blocks (e.g. image) as placeholders
# Preserve non-text/non-tool_result blocks (e.g. image) as placeholders.
# Use __prefix__ to distinguish from literal user text.
btype = block.get("type", "unknown")
str_parts.append(f"[{btype}]")
str_parts.append(f"[__{btype}__]")
elif isinstance(block, str):
str_parts.append(block)
return "\n".join(str_parts) if str_parts else ""
@@ -759,7 +760,7 @@ async def _run_compression(
async def compact_transcript(
content: str,
*,
model: str = "",
model: str,
log_prefix: str = "[Transcript]",
) -> str | None:
"""Compact an oversized JSONL transcript using LLM summarization.
@@ -782,8 +783,6 @@ async def compact_transcript(
Returns the compacted JSONL string, or ``None`` on failure.
"""
if not model:
model = ChatConfig().model
messages = _transcript_to_messages(content)
if len(messages) < 2:
logger.warning("%s Too few messages to compact (%d)", log_prefix, len(messages))