refactor(backend/copilot): expose public adapter flush + consolidate final-failure emit

CodeRabbit flagged that _flush_orphan_tool_uses_to_session (called from
_InterruptedAttempt.finalize) used state.adapter._flush_unresolved_tool_calls
with a # noqa: SLF001 suppressor. The private call mutates resolved_tool_calls
and flips has_unresolved_tool_calls to False, which caused the downstream
error-cleanup block at lines 4185-4200 to skip its own flush — UI spinners
on the client stayed open until page refresh because no cleanup events were
yielded after the early flush swallowed the unresolved state.

Changes:
- Rename _flush_unresolved_tool_calls → flush_unresolved_tool_calls (public)
  in response_adapter.py; update 3 internal call sites + 2 service.py sites.
  Drops the # noqa: SLF001 suppressor (no longer a private-access violation).
- _flush_orphan_tool_uses_to_session and _InterruptedAttempt.finalize now
  return the list[StreamBaseResponse] produced by the flush so the caller
  yields them to the client instead of re-flushing.
- Replace the four scattered post-loop error blocks (partial restore +
  redundant flush + stream_err yield + handled_error yield) with one
  consolidated block that: (a) calls _classify_final_failure → _FinalFailure,
  (b) yields finalize()'s events + _end_text_if_open, (c) yields one
  StreamError (unless handled_error.already_yielded=True). Fixes the
  double-flush skip-cleanup bug and eliminates duplicated error-text/code
  strings between history marker and SSE yield.
- _classify_final_failure now returns _FinalFailure(display_msg, code,
  retryable) instead of a (msg, retryable) tuple — single source of truth
  for in-history marker + SSE event so they can't drift.

Tests: +5 _classify_final_failure contract tests, +2 return-value assertions
on finalize/orphan-flush. All 1022 SDK tests pass (was 1012).
This commit is contained in:
majdyz
2026-04-25 09:51:44 +07:00
parent 2e7c5fbecd
commit 6576bf561e
3 changed files with 191 additions and 91 deletions

View File

@@ -22,6 +22,8 @@ from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
from .service import (
_classify_final_failure,
_FinalFailure,
_flush_orphan_tool_uses_to_session,
_HandledErrorInfo,
_InterruptedAttempt,
@@ -49,7 +51,7 @@ def _adapter_with_unresolved(responses: list[StreamToolOutputAvailable]):
out.extend(responses)
adapter.has_unresolved_tool_calls = False
adapter._flush_unresolved_tool_calls.side_effect = _flush
adapter.flush_unresolved_tool_calls.side_effect = _flush
state = MagicMock()
state.adapter = adapter
return state
@@ -147,8 +149,8 @@ class TestInterruptedAttemptFinalize:
attempt = _InterruptedAttempt(
partial=[ChatMessage(role="assistant", content="x")]
)
attempt.finalize(None, state=None, display_msg="Boom", retryable=False)
# no raise = pass
events = attempt.finalize(None, state=None, display_msg="Boom", retryable=False)
assert events == []
def test_flushes_unresolved_tools_between_partial_and_marker(self):
session = _make_session([ChatMessage(role="user", content="hi")])
@@ -167,12 +169,20 @@ class TestInterruptedAttemptFinalize:
),
]
)
state = _adapter_with_unresolved([_tool_output("t1", "interrupted")])
attempt.finalize(session, state=state, display_msg="Boom", retryable=False)
flushed = [_tool_output("t1", "interrupted")]
state = _adapter_with_unresolved(flushed)
events = attempt.finalize(
session, state=state, display_msg="Boom", retryable=False
)
roles = [m.role for m in session.messages]
assert roles == ["user", "assistant", "tool", "assistant"]
assert session.messages[2].tool_call_id == "t1"
assert session.messages[2].content == "interrupted"
# The same events that were persisted to history are returned to the
# caller so the caller can yield them to the client — without this
# the frontend's spinner widgets stay open until refresh because the
# adapter's has_unresolved_tool_calls flag is already flipped to False.
assert events == flushed
def test_clear_drops_both_partial_and_handled_error(self):
attempt = _InterruptedAttempt(
@@ -189,27 +199,101 @@ class TestInterruptedAttemptFinalize:
class TestFlushOrphanToolUses:
def test_appends_synthetic_tool_results_for_unresolved(self):
session = _make_session()
state = _adapter_with_unresolved(
[_tool_output("t1", "r1"), _tool_output("t2", {"ok": False})]
)
_flush_orphan_tool_uses_to_session(session, state)
flushed = [_tool_output("t1", "r1"), _tool_output("t2", {"ok": False})]
state = _adapter_with_unresolved(flushed)
events = _flush_orphan_tool_uses_to_session(session, state)
assert [m.tool_call_id for m in session.messages] == ["t1", "t2"]
# Dict outputs are JSON-encoded so structure survives the str-only
# ChatMessage content field for the next-turn LLM read.
assert session.messages[1].content == '{"ok": false}'
assert events == flushed
def test_noop_when_state_is_none(self):
session = _make_session()
_flush_orphan_tool_uses_to_session(session, None)
events = _flush_orphan_tool_uses_to_session(session, None)
assert session.messages == []
assert events == []
def test_noop_when_no_unresolved(self):
adapter = MagicMock()
adapter.has_unresolved_tool_calls = False
state = MagicMock()
state.adapter = adapter
_flush_orphan_tool_uses_to_session(_make_session(), state)
adapter._flush_unresolved_tool_calls.assert_not_called()
events = _flush_orphan_tool_uses_to_session(_make_session(), state)
adapter.flush_unresolved_tool_calls.assert_not_called()
assert events == []
class TestClassifyFinalFailure:
    """Contract tests for ``_classify_final_failure``.

    The history marker (written via ``finalize``) and the SSE ``StreamError``
    yield are meant to share one source of truth for the display message and
    stream code; if they drifted, the chat bubble and the SSE event could
    show different copy for the same failure.
    """

    @staticmethod
    def _classify(
        interrupted=None,
        *,
        attempts_exhausted=False,
        transient_exhausted=False,
        stream_err=None,
    ):
        """Run the classifier, defaulting every input to its "no failure" value."""
        attempt = _InterruptedAttempt() if interrupted is None else interrupted
        return _classify_final_failure(
            attempt,
            attempts_exhausted=attempts_exhausted,
            transient_exhausted=transient_exhausted,
            stream_err=stream_err,
        )

    def test_handled_error_wins(self):
        # A recorded handled error takes precedence over a stream_err that is
        # passed alongside it.
        handled = _HandledErrorInfo(
            error_msg="circuit tripped",
            code="circuit_breaker",
            retryable=False,
            already_yielded=True,
        )
        expected = _FinalFailure(
            display_msg="circuit tripped",
            code="circuit_breaker",
            retryable=False,
        )

        outcome = self._classify(
            _InterruptedAttempt(handled_error=handled),
            stream_err=RuntimeError("ignored"),
        )

        assert outcome == expected

    def test_attempts_exhausted(self):
        outcome = self._classify(
            attempts_exhausted=True,
            stream_err=RuntimeError("x"),
        )

        assert outcome is not None
        assert outcome.code == "all_attempts_exhausted"
        assert outcome.retryable is False

    def test_transient_exhausted(self):
        outcome = self._classify(
            transient_exhausted=True,
            stream_err=RuntimeError("x"),
        )

        assert outcome is not None
        assert outcome.code == "transient_api_error"
        assert outcome.retryable is True

    def test_stream_err_fallback(self):
        # Neither exhaustion flag is set, so only the raw stream error remains.
        outcome = self._classify(stream_err=RuntimeError("some sdk error"))

        assert outcome is not None
        assert outcome.code == "sdk_stream_error"
        assert outcome.retryable is False

    def test_returns_none_when_no_failure_recorded(self):
        # No handled error, no exhaustion flag, no stream error.
        assert self._classify() is None
class TestRetryRollbackContract:

View File

@@ -166,7 +166,7 @@ class SDKResponseAdapter:
# are still executing concurrently and haven't finished yet.
is_tool_only = all(isinstance(b, ToolUseBlock) for b in sdk_message.content)
if not is_tool_only:
self._flush_unresolved_tool_calls(responses)
self.flush_unresolved_tool_calls(responses)
# After tool results, the SDK sends a new AssistantMessage for the
# next LLM turn. Open a new step if the previous one was closed.
@@ -375,7 +375,7 @@ class SDKResponseAdapter:
self.step_open = False
elif isinstance(sdk_message, ResultMessage):
self._flush_unresolved_tool_calls(responses)
self.flush_unresolved_tool_calls(responses)
# Thinking-only final turn guard: when the model's last LLM
# call after a tool result produced only a ``ThinkingBlock``
# (no ``TextBlock``, no ``ToolUseBlock``) the UI has nothing
@@ -703,7 +703,7 @@ class SDKResponseAdapter:
self._pending_thinking_delta = ""
self._pending_thinking_index = None
def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
def flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
"""Emit outputs for tool calls that didn't receive a UserMessage result.
SDK built-in tools (WebSearch, Read, etc.) may be executed by the CLI
@@ -711,6 +711,12 @@ class SDKResponseAdapter:
``ToolResultBlock`` content. The ``PostToolUse`` hook stashes their
output, which we pop and emit here before the next ``AssistantMessage``
starts.
Callers that need to both record synthetic tool_results in history AND
yield the same events to the client should call this exactly once and
share the resulting list — the method mutates ``resolved_tool_calls``,
so a second call returns nothing and ``has_unresolved_tool_calls``
flips to ``False`` after the first invocation.
"""
unresolved = [
(tid, info.get("name", "unknown"))

View File

@@ -615,39 +615,48 @@ class _InterruptedAttempt:
display_msg: str,
*,
retryable: bool,
) -> None:
) -> list[StreamBaseResponse]:
"""Re-attach partial + synthetic tool_result rows + error marker.
Called exactly once after the retry loop on final-failure exit.
Idempotent on empty state, so it's safe to call on paths where no
rollback happened.
Returns the ``StreamBaseResponse`` events produced by the safety
flush so the caller can yield them to the client (the flush mutates
adapter state, so a second flush elsewhere would return nothing and
stale UI elements like spinners would stay open).
"""
if session is None:
return
return []
if self.partial:
session.messages.extend(self.partial)
self.partial = []
_flush_orphan_tool_uses_to_session(session, state)
events = _flush_orphan_tool_uses_to_session(session, state)
_append_error_marker(session, display_msg, retryable=retryable)
return events
def _flush_orphan_tool_uses_to_session(
session: "ChatSession | None",
state: "_RetryState | None",
) -> None:
) -> list[StreamBaseResponse]:
"""Synthesize ``tool_result`` rows for ``tool_use`` blocks that never resolved.
Re-attached partial work may carry orphan ``tool_use`` blocks; without
matching ``tool_result`` rows the next turn's LLM call would error with
``tool_use_id without tool_result``. The adapter's safety-flush produces
interrupted-marker results that satisfy the API contract.
Returns the flushed events so callers can yield them to the client
alongside persisting the synthetic rows in session history.
"""
if session is None or state is None:
return
return []
if not state.adapter.has_unresolved_tool_calls:
return
return []
safety: list[StreamBaseResponse] = []
state.adapter._flush_unresolved_tool_calls(safety) # noqa: SLF001
state.adapter.flush_unresolved_tool_calls(safety)
for resp in safety:
if isinstance(resp, StreamToolOutputAvailable):
content = (
@@ -658,6 +667,20 @@ def _flush_orphan_tool_uses_to_session(
session.messages.append(
ChatMessage(role="tool", content=content, tool_call_id=resp.toolCallId)
)
return safety
@dataclass(frozen=True)
class _FinalFailure:
    """Immutable description of a final-failure exit.

    Bundles the display message, the stream code and the retryable flag.
    The in-history error marker (written through
    ``_InterruptedAttempt.finalize``) and the client-facing ``StreamError``
    SSE yield both read from one instance, which keeps the two in sync.
    """

    # Failure text; used for the history marker and the StreamError errorText.
    display_msg: str
    # Stream code sent with the StreamError, e.g. "sdk_stream_error".
    code: str
    # Retryable flag passed along when the history marker is written.
    retryable: bool
def _classify_final_failure(
@@ -665,25 +688,41 @@ def _classify_final_failure(
attempts_exhausted: bool,
transient_exhausted: bool,
stream_err: BaseException | None,
) -> tuple[str | None, bool]:
"""Pick the display message + retryable flag for the post-loop restore.
) -> _FinalFailure | None:
"""Pick the display message, stream code, and retryable flag for the exit.
Mirrors the error-code selection used for the client-facing ``StreamError``
yield so the in-history marker and the SSE event stay consistent.
Returns ``None`` when no failure was recorded (success path) — the caller
should skip both the history marker and the SSE yield in that case.
"""
if interrupted.handled_error is not None:
return interrupted.handled_error.error_msg, interrupted.handled_error.retryable
return _FinalFailure(
display_msg=interrupted.handled_error.error_msg,
code=interrupted.handled_error.code,
retryable=interrupted.handled_error.retryable,
)
if attempts_exhausted:
return (
"Your conversation is too long. "
"Please start a new chat or clear some history."
), False
return _FinalFailure(
display_msg=(
"Your conversation is too long. "
"Please start a new chat or clear some history."
),
code="all_attempts_exhausted",
retryable=False,
)
if transient_exhausted:
return FRIENDLY_TRANSIENT_MSG, True
return _FinalFailure(
display_msg=FRIENDLY_TRANSIENT_MSG,
code="transient_api_error",
retryable=True,
)
if stream_err is not None:
safe_err = str(stream_err).replace("\n", " ").replace("\r", "")[:500]
return _friendly_error_text(safe_err), False
return None, False
return _FinalFailure(
display_msg=_friendly_error_text(safe_err),
code="sdk_stream_error",
retryable=False,
)
return None
def _setup_langfuse_otel() -> None:
@@ -2868,7 +2907,7 @@ async def _run_stream_attempt(
- len(state.adapter.resolved_tool_calls),
)
safety_responses: list[StreamBaseResponse] = []
state.adapter._flush_unresolved_tool_calls(safety_responses)
state.adapter.flush_unresolved_tool_calls(safety_responses)
for response in safety_responses:
if isinstance(
response,
@@ -4132,67 +4171,38 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
_MAX_STREAM_ATTEMPTS,
stream_err,
)
# Restore the rolled-back partial + add error marker exactly once per
# failure mode. See _InterruptedAttempt for the carry-state contract.
# Consolidated final-failure handling. _classify_final_failure picks
# the display message + stream code + retryable flag, finalize() adds
# the history marker and produces the safety-flush events that close
# stale UI widgets on the client, and the StreamError yield below
# surfaces the same message over SSE. The _HandledStreamError path
# sets ``already_yielded=True`` for non-transient errors (circuit
# breaker, idle timeout) whose inner handler already yielded — skip
# the re-yield in that case.
if ended_with_stream_error:
final_msg, final_retryable = _classify_final_failure(
failure = _classify_final_failure(
interrupted, attempts_exhausted, transient_exhausted, stream_err
)
if final_msg is not None:
interrupted.finalize(
session, state, final_msg, retryable=final_retryable
if failure is not None:
cleanup_events: list[StreamBaseResponse] = []
if state is not None:
state.adapter._end_text_if_open(cleanup_events)
cleanup_events.extend(
interrupted.finalize(
session,
state,
failure.display_msg,
retryable=failure.retryable,
)
)
if ended_with_stream_error and state is not None:
# Flush any unresolved tool calls so the frontend can close
# stale UI elements (e.g. spinners) that were started before
# the exception interrupted the stream.
error_flush: list[StreamBaseResponse] = []
state.adapter._end_text_if_open(error_flush)
if state.adapter.has_unresolved_tool_calls:
logger.warning(
"%s Flushing %d unresolved tool(s) after stream error",
log_prefix,
len(state.adapter.current_tool_calls)
- len(state.adapter.resolved_tool_calls),
for response in cleanup_events:
yield response
already_yielded = (
interrupted.handled_error is not None
and interrupted.handled_error.already_yielded
)
state.adapter._flush_unresolved_tool_calls(error_flush)
for response in error_flush:
yield response
if ended_with_stream_error and stream_err is not None:
# Use distinct error codes depending on how the loop ended:
# • "all_attempts_exhausted" — context compaction ran out of room
# • "transient_api_error" — 429/5xx/ECONNRESET retries exhausted
# • "sdk_stream_error" — non-context, non-transient fatal error
safe_err = str(stream_err).replace("\n", " ").replace("\r", "")[:500]
if attempts_exhausted:
error_text = (
"Your conversation is too long. "
"Please start a new chat or clear some history."
)
error_code = "all_attempts_exhausted"
elif transient_exhausted:
error_text = FRIENDLY_TRANSIENT_MSG
error_code = "transient_api_error"
else:
error_text = _friendly_error_text(safe_err)
error_code = "sdk_stream_error"
yield StreamError(errorText=error_text, code=error_code)
# _HandledStreamError exits the retry loop with stream_err unset, so
# the previous block doesn't fire. Emit the client-facing StreamError
# here when the inner handler chose not to (transient errors suppress
# the early flash so the client only sees the final error after all
# retries are exhausted).
if (
interrupted.handled_error is not None
and not interrupted.handled_error.already_yielded
):
yield StreamError(
errorText=interrupted.handled_error.error_msg,
code=interrupted.handled_error.code,
)
if not already_yielded:
yield StreamError(errorText=failure.display_msg, code=failure.code)
# Copy token usage from retry state to outer-scope accumulators
# so the finally block can persist them.