refactor(backend/copilot): consolidate SDK partial-preserve state into _InterruptedAttempt

Previous revisions carried the failed-attempt state across three separate
function-scope variables (last_attempt_partial, handled_error_info) + four
module-level helpers (_rollback_attempt_capturing_partial,
_restore_partial_with_error_marker, _flush_orphan_tool_uses_to_session,
_append_error_marker). The retry loop mutated all three and the post-loop
block reassembled the pieces by hand. Scattered and hard to follow.

Collapse to one dataclass with capture / clear / finalize + one
_classify_final_failure helper that picks the display message based on
which failure flag the retry loop set (attempts_exhausted,
transient_exhausted, stream_err, handled_error). Call sites:

  - success break:          interrupted.clear()
  - _HandledStreamError:    interrupted.capture(...); interrupted.handled_error = ...
  - Exception:              interrupted.capture(...)
  - post-loop:              final_msg, retryable = _classify_final_failure(interrupted, ...); interrupted.finalize(...)
  - outer except:           interrupted.finalize(...)

Behaviour is unchanged — same restore semantics, same StreamError
sequencing, same transcript-upload skip, same orphan tool_use flush, same
stale-marker stripping from b1172e203 / 5406fe9b5. The retry-scenarios
suite (48 integration tests) plus the rewritten interrupted_partial_test
(14 unit tests) both pass; the full SDK test suite (1012 tests) is green.
This commit is contained in:
majdyz
2026-04-25 09:40:30 +07:00
parent 5406fe9b5b
commit 2e7c5fbecd
2 changed files with 284 additions and 336 deletions

View File

@@ -1,14 +1,13 @@
"""Tests for partial-work preservation when an SDK turn is interrupted.
Covers the regression path SECRT-2275 surfaced: when the SDK retry loop
rolls back ``session.messages`` for a failed attempt (correct behavior so a
successful retry doesn't duplicate content) it MUST re-attach the rolled-back
work on final-failure exit. Otherwise the user's UI streamed tokens live but
Covers the regression SECRT-2275 surfaced: when the SDK retry loop rolls
back ``session.messages`` for a failed attempt (correct so a successful
retry doesn't duplicate content) it MUST re-attach the rolled-back work on
final-failure exit. Without that, the user's UI streamed tokens live then
a refresh shows an empty turn — described by users as "the turn is gone".
Tests target the helper functions directly (unit) plus the rollback-then-
restore contract (state-driven). Full end-to-end coverage of the retry loop
lives in retry_scenarios_test.py.
Tests target the ``_InterruptedAttempt`` dataclass + the orphan-tool flush
directly. Full retry-loop coverage lives in ``retry_scenarios_test.py``.
"""
from __future__ import annotations
@@ -24,8 +23,8 @@ from backend.copilot.response_model import StreamToolOutputAvailable
from .service import (
_flush_orphan_tool_uses_to_session,
_restore_partial_with_error_marker,
_rollback_attempt_capturing_partial,
_HandledErrorInfo,
_InterruptedAttempt,
)
@@ -35,21 +34,19 @@ def _make_session(messages: list[ChatMessage] | None = None) -> ChatSession:
return session
def _make_tool_output(tool_call_id: str, output) -> StreamToolOutputAvailable:
def _tool_output(tool_call_id: str, output) -> StreamToolOutputAvailable:
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName="t",
output=output,
toolCallId=tool_call_id, toolName="t", output=output
)
def _adapter_with_unresolved(unresolved_responses: list[StreamToolOutputAvailable]):
"""Build a stub _RetryState whose adapter flushes the given responses."""
def _adapter_with_unresolved(responses: list[StreamToolOutputAvailable]):
"""Stub _RetryState whose adapter flushes the given responses."""
adapter = MagicMock()
adapter.has_unresolved_tool_calls = bool(unresolved_responses)
adapter.has_unresolved_tool_calls = bool(responses)
def _flush(responses: list) -> None:
responses.extend(unresolved_responses)
def _flush(out: list) -> None:
out.extend(responses)
adapter.has_unresolved_tool_calls = False
adapter._flush_unresolved_tool_calls.side_effect = _flush
@@ -58,145 +55,29 @@ def _adapter_with_unresolved(unresolved_responses: list[StreamToolOutputAvailabl
return state
class TestRestorePartialWithErrorMarker:
def test_appends_partial_then_marker_when_partial_present(self):
session = _make_session([ChatMessage(role="user", content="hi")])
partial = [
ChatMessage(role="assistant", content="I was working on "),
ChatMessage(role="tool", content="result-1", tool_call_id="t1"),
]
_restore_partial_with_error_marker(
session,
state=None,
partial=partial,
display_msg="Boom",
retryable=False,
)
# Pre-existing user msg + 2 partial msgs + error marker
assert len(session.messages) == 4
assert session.messages[1].content == "I was working on "
assert session.messages[2].role == "tool"
assert session.messages[3].content.startswith(COPILOT_ERROR_PREFIX)
# Partial list is consumed (cleared) so a stray follow-up call won't
# double-attach the same content.
assert partial == []
def test_only_marker_when_partial_empty(self):
session = _make_session([ChatMessage(role="user", content="hi")])
_restore_partial_with_error_marker(
session,
state=None,
partial=[],
display_msg="Boom",
retryable=True,
)
assert len(session.messages) == 2
assert session.messages[-1].content.startswith(COPILOT_RETRYABLE_ERROR_PREFIX)
def test_noop_when_session_is_none(self):
# Signature accepts None — must not raise.
_restore_partial_with_error_marker(
None,
state=None,
partial=[ChatMessage(role="assistant", content="x")],
display_msg="Boom",
retryable=False,
)
def test_flushes_unresolved_tools_between_partial_and_marker(self):
session = _make_session([ChatMessage(role="user", content="hi")])
partial = [
ChatMessage(
role="assistant",
content="calling tool",
tool_calls=[
{
"id": "t1",
"type": "function",
"function": {"name": "lookup", "arguments": "{}"},
}
],
),
]
state = _adapter_with_unresolved([_make_tool_output("t1", "interrupted")])
_restore_partial_with_error_marker(
session,
state=state,
partial=partial,
display_msg="Boom",
retryable=False,
)
roles = [m.role for m in session.messages]
# user, assistant(partial), tool(synthetic), assistant(error marker)
assert roles == ["user", "assistant", "tool", "assistant"]
synthetic_tool = session.messages[2]
assert synthetic_tool.tool_call_id == "t1"
assert synthetic_tool.content == "interrupted"
def _builder_stub() -> MagicMock:
builder = MagicMock()
builder.restore = MagicMock()
return builder
class TestFlushOrphanToolUses:
def test_appends_synthetic_tool_results_for_unresolved(self):
session = _make_session()
state = _adapter_with_unresolved(
[
_make_tool_output("t1", "r1"),
_make_tool_output("t2", {"ok": False}),
]
)
_flush_orphan_tool_uses_to_session(session, state)
assert [m.tool_call_id for m in session.messages] == ["t1", "t2"]
# Dict outputs are JSON-encoded so they survive the str-only ChatMessage
# content field without losing structure for the next-turn LLM read.
assert session.messages[1].content == '{"ok": false}'
def test_noop_when_state_is_none(self):
session = _make_session()
_flush_orphan_tool_uses_to_session(session, None)
assert session.messages == []
def test_noop_when_no_unresolved(self):
session = _make_session()
adapter = MagicMock()
adapter.has_unresolved_tool_calls = False
state = MagicMock()
state.adapter = adapter
_flush_orphan_tool_uses_to_session(session, state)
adapter._flush_unresolved_tool_calls.assert_not_called()
class TestRollbackCapturingPartial:
"""Direct tests of `_rollback_attempt_capturing_partial`.
The retry loop relies on this helper not to leak error markers that
`_run_stream_attempt` already appended to `session.messages` — otherwise
the post-loop restore replays a stale marker before adding its own,
leaving duplicate error bubbles.
"""
def _builder_with_snap(self):
builder = MagicMock()
builder.restore = MagicMock()
return builder
def test_returns_partial_when_no_marker_present(self):
class TestInterruptedAttemptCapture:
def test_keeps_partial_when_no_marker_present(self):
session = _make_session(
[
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="part-1"),
]
)
builder = self._builder_with_snap()
captured = _rollback_attempt_capturing_partial(
session, builder, transcript_snap=object(), pre_attempt_msg_count=1
)
assert [m.content for m in captured] == ["part-1"]
assert session.messages == [ChatMessage(role="user", content="hi")]
attempt = _InterruptedAttempt()
attempt.capture(session, _builder_stub(), object(), pre_attempt_msg_count=1)
assert [m.content for m in attempt.partial] == ["part-1"]
assert [m.content for m in session.messages] == ["hi"]
def test_strips_trailing_error_marker(self):
# _run_stream_attempt appended a marker via _append_error_marker
# (e.g. idle timeout, circuit breaker) before raising
# _HandledStreamError. The rollback must NOT carry it forward, or
# the post-loop restore will replay the stale marker + add its own.
# _run_stream_attempt may append a marker (idle timeout, circuit
# breaker) before raising _HandledStreamError. Carrying it forward
# would let finalize() replay it and then add its own.
marker = (
f"{COPILOT_RETRYABLE_ERROR_PREFIX} The session has been idle "
"for too long. Please try again."
@@ -208,66 +89,136 @@ class TestRollbackCapturingPartial:
ChatMessage(role="assistant", content=marker),
]
)
captured = _rollback_attempt_capturing_partial(
session,
self._builder_with_snap(),
transcript_snap=object(),
pre_attempt_msg_count=1,
)
assert [m.content for m in captured] == ["part-1"]
attempt = _InterruptedAttempt()
attempt.capture(session, _builder_stub(), object(), pre_attempt_msg_count=1)
assert [m.content for m in attempt.partial] == ["part-1"]
def test_strips_consecutive_error_markers(self):
# Defensive: if more than one marker landed back-to-back (legacy
# path or future regression), strip them all.
session = _make_session(
[
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="part-1"),
ChatMessage(role="assistant", content=f"{COPILOT_ERROR_PREFIX} a"),
ChatMessage(
role="assistant",
content=f"{COPILOT_RETRYABLE_ERROR_PREFIX} b",
role="assistant", content=f"{COPILOT_RETRYABLE_ERROR_PREFIX} b"
),
]
)
captured = _rollback_attempt_capturing_partial(
session,
self._builder_with_snap(),
transcript_snap=object(),
pre_attempt_msg_count=1,
)
assert [m.content for m in captured] == ["part-1"]
attempt = _InterruptedAttempt()
attempt.capture(session, _builder_stub(), object(), pre_attempt_msg_count=1)
assert [m.content for m in attempt.partial] == ["part-1"]
def test_does_not_strip_non_marker_assistant(self):
# Regular assistant text starting with similar-but-not-prefix
# content must be preserved — only the canonical error markers
# should be filtered.
def test_preserves_non_marker_assistant(self):
session = _make_session(
[
ChatMessage(role="user", content="hi"),
ChatMessage(role="assistant", content="Important note"),
]
)
captured = _rollback_attempt_capturing_partial(
session,
self._builder_with_snap(),
transcript_snap=object(),
pre_attempt_msg_count=1,
attempt = _InterruptedAttempt()
attempt.capture(session, _builder_stub(), object(), pre_attempt_msg_count=1)
assert [m.content for m in attempt.partial] == ["Important note"]
class TestInterruptedAttemptFinalize:
def test_appends_partial_then_marker(self):
session = _make_session([ChatMessage(role="user", content="hi")])
attempt = _InterruptedAttempt(
partial=[
ChatMessage(role="assistant", content="working"),
ChatMessage(role="tool", content="result", tool_call_id="t1"),
]
)
assert [m.content for m in captured] == ["Important note"]
attempt.finalize(session, state=None, display_msg="Boom", retryable=False)
roles = [m.role for m in session.messages]
assert roles == ["user", "assistant", "tool", "assistant"]
assert session.messages[-1].content.startswith(COPILOT_ERROR_PREFIX)
# partial consumed so a follow-up finalize() is a no-op for partial.
assert attempt.partial == []
def test_only_marker_when_partial_empty(self):
session = _make_session([ChatMessage(role="user", content="hi")])
attempt = _InterruptedAttempt()
attempt.finalize(session, state=None, display_msg="Boom", retryable=True)
assert len(session.messages) == 2
assert session.messages[-1].content.startswith(COPILOT_RETRYABLE_ERROR_PREFIX)
def test_noop_when_session_is_none(self):
attempt = _InterruptedAttempt(
partial=[ChatMessage(role="assistant", content="x")]
)
attempt.finalize(None, state=None, display_msg="Boom", retryable=False)
# no raise = pass
def test_flushes_unresolved_tools_between_partial_and_marker(self):
session = _make_session([ChatMessage(role="user", content="hi")])
attempt = _InterruptedAttempt(
partial=[
ChatMessage(
role="assistant",
content="calling",
tool_calls=[
{
"id": "t1",
"type": "function",
"function": {"name": "lookup", "arguments": "{}"},
}
],
),
]
)
state = _adapter_with_unresolved([_tool_output("t1", "interrupted")])
attempt.finalize(session, state=state, display_msg="Boom", retryable=False)
roles = [m.role for m in session.messages]
assert roles == ["user", "assistant", "tool", "assistant"]
assert session.messages[2].tool_call_id == "t1"
assert session.messages[2].content == "interrupted"
def test_clear_drops_both_partial_and_handled_error(self):
attempt = _InterruptedAttempt(
partial=[ChatMessage(role="assistant", content="x")],
handled_error=_HandledErrorInfo(
error_msg="m", code="c", retryable=True, already_yielded=False
),
)
attempt.clear()
assert attempt.partial == []
assert attempt.handled_error is None
class TestFlushOrphanToolUses:
def test_appends_synthetic_tool_results_for_unresolved(self):
session = _make_session()
state = _adapter_with_unresolved(
[_tool_output("t1", "r1"), _tool_output("t2", {"ok": False})]
)
_flush_orphan_tool_uses_to_session(session, state)
assert [m.tool_call_id for m in session.messages] == ["t1", "t2"]
# Dict outputs are JSON-encoded so structure survives the str-only
# ChatMessage content field for the next-turn LLM read.
assert session.messages[1].content == '{"ok": false}'
def test_noop_when_state_is_none(self):
session = _make_session()
_flush_orphan_tool_uses_to_session(session, None)
assert session.messages == []
def test_noop_when_no_unresolved(self):
adapter = MagicMock()
adapter.has_unresolved_tool_calls = False
state = MagicMock()
state.adapter = adapter
_flush_orphan_tool_uses_to_session(_make_session(), state)
adapter._flush_unresolved_tool_calls.assert_not_called()
class TestRetryRollbackContract:
"""Property-style: a rolled-back attempt must be recoverable on final exit.
"""End-to-end contract: capture on a rolled-back attempt + finalize yields
the exact content the user saw streaming live, plus the error marker."""
Simulates the retry loop's rollback by mirroring the exact slicing and
captured-list shape used in stream_chat_completion_sdk so that any drift
in that contract is caught here without needing the full SDK fixture.
"""
def test_capture_slice_matches_rollback(self):
def test_capture_then_finalize_matches_streamed_sequence(self):
session = _make_session([ChatMessage(role="user", content="hi")])
pre_attempt_msg_count = len(session.messages)
pre = len(session.messages)
# Simulate incremental SDK appends during the attempt.
session.messages.extend(
[
@@ -275,18 +226,11 @@ class TestRetryRollbackContract:
ChatMessage(role="assistant", content="part-2"),
]
)
captured = list(session.messages[pre_attempt_msg_count:])
session.messages = session.messages[:pre_attempt_msg_count]
# Final-failure restore.
_restore_partial_with_error_marker(
session,
state=None,
partial=captured,
display_msg="Boom",
retryable=False,
)
contents = [m.content for m in session.messages]
assert contents == [
attempt = _InterruptedAttempt()
attempt.capture(session, _builder_stub(), object(), pre)
# Final-failure path — no retry, no success clear().
attempt.finalize(session, state=None, display_msg="Boom", retryable=False)
assert [m.content for m in session.messages] == [
"hi",
"part-1",
"part-2",

View File

@@ -541,14 +541,7 @@ def _append_error_marker(
*,
retryable: bool = False,
) -> None:
"""Append a copilot error marker to *session* so it persists across refresh.
Args:
session: The chat session to append to (no-op if `None`).
display_msg: User-visible error text.
retryable: If `True`, use the retryable prefix so the frontend
shows a "Try Again" button.
"""
"""Append a copilot error marker to *session* so it persists across refresh."""
if session is None:
return
prefix = COPILOT_RETRYABLE_ERROR_PREFIX if retryable else COPILOT_ERROR_PREFIX
@@ -557,17 +550,97 @@ def _append_error_marker(
)
def _is_error_marker(msg: ChatMessage) -> bool:
"""True if *msg* is an error marker emitted by ``_append_error_marker``."""
if msg.role != "assistant" or not msg.content:
return False
return msg.content.startswith(COPILOT_ERROR_PREFIX) or msg.content.startswith(
COPILOT_RETRYABLE_ERROR_PREFIX
)
@dataclass
class _InterruptedAttempt:
"""Captured state of a failed SDK attempt, carried across the retry loop.
The SDK always rolls back ``session.messages`` before deciding whether
to retry (so attempt #2 starts clean). That rollback would otherwise
discard everything the assistant produced — the user sees tokens stream
live, then a refresh shows nothing. This dataclass holds the rolled-back
messages plus the ``_HandledStreamError`` info needed to emit a final
``StreamError`` once the loop decides not to retry.
The retry loop calls ``capture()`` on every failed attempt, ``clear()``
on a successful retry (so prior rolled-back content is not replayed),
and ``finalize()`` exactly once after the loop on final failure.
"""
partial: list[ChatMessage] = dataclass_field(default_factory=list)
# Populated by the ``except _HandledStreamError`` branch so the post-loop
# block can restore the partial and (when the inner handler didn't) emit
# the client-facing StreamError. Transient errors deliberately suppress
# the early StreamError flash and rely on this post-loop emit.
handled_error: "_HandledErrorInfo | None" = None
def capture(
self,
session: ChatSession,
transcript_builder: "TranscriptBuilder",
transcript_snap: object,
pre_attempt_msg_count: int,
) -> None:
"""Roll back ``session.messages`` + transcript, keeping the partial.
Trailing error markers appended inside ``_run_stream_attempt`` (idle
timeout, circuit breaker) are stripped: re-attaching them would make
the post-loop restore replay a stale marker before adding its own,
leaving duplicate error bubbles.
"""
tail = list(session.messages[pre_attempt_msg_count:])
while tail and _is_error_marker(tail[-1]):
tail.pop()
self.partial = tail
session.messages = session.messages[:pre_attempt_msg_count]
transcript_builder.restore(transcript_snap) # type: ignore[arg-type]
def clear(self) -> None:
"""Drop captured state — used on successful retry."""
self.partial = []
self.handled_error = None
def finalize(
self,
session: ChatSession | None,
state: "_RetryState | None",
display_msg: str,
*,
retryable: bool,
) -> None:
"""Re-attach partial + synthetic tool_result rows + error marker.
Called exactly once after the retry loop on final-failure exit.
Idempotent on empty state, so it's safe to call on paths where no
rollback happened.
"""
if session is None:
return
if self.partial:
session.messages.extend(self.partial)
self.partial = []
_flush_orphan_tool_uses_to_session(session, state)
_append_error_marker(session, display_msg, retryable=retryable)
def _flush_orphan_tool_uses_to_session(
session: "ChatSession | None",
state: "_RetryState | None",
) -> None:
"""Synthesize tool_result rows for tool_use blocks that never resolved.
"""Synthesize ``tool_result`` rows for ``tool_use`` blocks that never resolved.
Without this, partial assistant work re-attached after a final failure
would carry orphan tool_use blocks. The next turn's LLM call would error
with ``tool_use_id without tool_result`` — and any baseline replay would
surface the same broken history. Flushing produces interrupted-marker
tool_results that satisfy the API contract.
Re-attached partial work may carry orphan ``tool_use`` blocks; without
matching ``tool_result`` rows the next turn's LLM call would error with
``tool_use_id without tool_result``. The adapter's safety-flush produces
interrupted-marker results that satisfy the API contract.
"""
if session is None or state is None:
return
@@ -587,65 +660,30 @@ def _flush_orphan_tool_uses_to_session(
)
def _restore_partial_with_error_marker(
session: "ChatSession | None",
state: "_RetryState | None",
partial: list[ChatMessage],
display_msg: str,
*,
retryable: bool,
) -> None:
"""Re-attach a rolled-back attempt's partial work, then add the error marker.
def _classify_final_failure(
interrupted: _InterruptedAttempt,
attempts_exhausted: bool,
transient_exhausted: bool,
stream_err: BaseException | None,
) -> tuple[str | None, bool]:
"""Pick the display message + retryable flag for the post-loop restore.
Called when retries are exhausted or no retry is attempted. Without this,
the SDK retry loop's pre-decision rollback would discard everything the
assistant produced in the failed attempt (text, tool calls, reasoning),
leaving the user with a chat that looks like nothing happened — even
though the events were already streamed live to their UI.
Mirrors the error-code selection used for the client-facing ``StreamError``
yield so the in-history marker and the SSE event stay consistent.
"""
if session is None:
return
if partial:
session.messages.extend(partial)
partial.clear()
_flush_orphan_tool_uses_to_session(session, state)
_append_error_marker(session, display_msg, retryable=retryable)
def _rollback_attempt_capturing_partial(
session: "ChatSession",
transcript_builder: "TranscriptBuilder",
transcript_snap: object,
pre_attempt_msg_count: int,
) -> list[ChatMessage]:
"""Roll back an attempt's session.messages + transcript, capturing the partial.
The returned list holds the assistant work that was incrementally appended
during the failed attempt. The caller passes it to
``_restore_partial_with_error_marker`` on final-failure exit and discards
it on a successful retry.
Trailing error markers appended inside ``_run_stream_attempt`` (idle
timeout, circuit breaker) are stripped: re-attaching them would let the
post-loop restore replay a stale marker before adding its own, leaving
duplicate error bubbles and pushing any synthetic ``tool_result`` after
an assistant(error) turn that has no matching ``tool_use``.
"""
captured = list(session.messages[pre_attempt_msg_count:])
while captured and _is_error_marker(captured[-1]):
captured.pop()
session.messages = session.messages[:pre_attempt_msg_count]
transcript_builder.restore(transcript_snap) # type: ignore[arg-type]
return captured
def _is_error_marker(msg: ChatMessage) -> bool:
"""True if *msg* is an error marker emitted by ``_append_error_marker``."""
if msg.role != "assistant" or not msg.content:
return False
return msg.content.startswith(COPILOT_ERROR_PREFIX) or msg.content.startswith(
COPILOT_RETRYABLE_ERROR_PREFIX
)
if interrupted.handled_error is not None:
return interrupted.handled_error.error_msg, interrupted.handled_error.retryable
if attempts_exhausted:
return (
"Your conversation is too long. "
"Please start a new chat or clear some history."
), False
if transient_exhausted:
return FRIENDLY_TRANSIENT_MSG, True
if stream_err is not None:
safe_err = str(stream_err).replace("\n", " ").replace("\r", "")[:500]
return _friendly_error_text(safe_err), False
return None, False
def _setup_langfuse_otel() -> None:
@@ -3112,12 +3150,13 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
request_arrival_at: float = 0.0,
**_kwargs: Any,
) -> AsyncGenerator[StreamBaseResponse, None]:
# Pyright's complexity heuristic bails on this function (~1500 LoC, retry
# Pyright's complexity heuristic bails on this ~1500 LoC function (retry
# loop with context-overflow fallback + transient backoff + partial-work
# preservation). Splitting further would hurt readability — the branches
# share state (session, adapter, transcript builder, token accumulators)
# that's hard to pass cleanly through helpers. Suppress the bailout; real
# type errors elsewhere in the file remain surfaced.
# preservation). Splitting the retry loop further hurts readability —
# branches share mutable state (session, adapter, transcript builder,
# usage accumulators) that doesn't pass cleanly through helpers. The
# suppression only silences the complexity bailout; real type errors in
# the function body still surface.
"""Stream chat completion using Claude Agent SDK.
Args:
@@ -3281,14 +3320,11 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
turn_cost_usd: float | None = None
graphiti_enabled = False
pre_attempt_msg_count = 0
# Holds messages that the retry loop rolls back from session.messages on a
# failed attempt. On final-failure exit we re-attach these so the user sees
# what the assistant produced before the error rather than an empty chat.
last_attempt_partial: list[ChatMessage] = []
# Populated by the _HandledStreamError branch. Consumed once after the
# retry loop to re-attach the partial and, if the inner handler hadn't
# already emitted one, yield a single StreamError to the client.
handled_error_info: _HandledErrorInfo | None = None
# State of the latest failed attempt: rolled-back messages + any
# _HandledStreamError info to emit on final-failure exit. The retry loop
# mutates this via capture()/clear(); the post-loop block calls
# finalize() once.
interrupted = _InterruptedAttempt()
# Defaults ensure the finally block can always reference these safely even when
# an early return (e.g. sdk_cwd error) skips their normal assignment below.
sdk_model: str | None = None
@@ -3949,10 +3985,9 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
"using fallback model for this request"
)
yield event
# Drop any stale partial captured from prior failed attempts
# so the outer cleanup paths don't re-attach pre-retry content
# the successful attempt already replaced.
last_attempt_partial.clear()
# Discard any state captured from prior failed attempts so
# outer cleanup paths don't replay pre-retry content.
interrupted.clear()
break # Stream completed — exit retry loop
except asyncio.CancelledError:
logger.warning(
@@ -3968,7 +4003,7 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
# session messages and set the error flag — do NOT set
# stream_err so the post-loop code won't emit a
# duplicate StreamError.
last_attempt_partial = _rollback_attempt_capturing_partial(
interrupted.capture(
session,
state.transcript_builder,
transcript_snap,
@@ -4007,7 +4042,7 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
# attempt that no longer match session.messages. Skip upload
# so a future --resume doesn't replay rolled-back content.
skip_transcript_upload = True
handled_error_info = _HandledErrorInfo(
interrupted.handled_error = _HandledErrorInfo(
error_msg=exc.error_msg or FRIENDLY_TRANSIENT_MSG,
code=exc.code or "transient_api_error",
retryable=exc.retryable,
@@ -4031,7 +4066,7 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
stream_err,
exc_info=True,
)
last_attempt_partial = _rollback_attempt_capturing_partial(
interrupted.capture(
session,
state.transcript_builder,
transcript_snap,
@@ -4097,38 +4132,15 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
_MAX_STREAM_ATTEMPTS,
stream_err,
)
# Restore the rolled-back partial work + add error marker exactly once
# per failure mode. Earlier revisions did this inline in each retry-loop
# branch; consolidating here keeps the retry loop itself simple enough
# for pyright to type-check.
# Restore the rolled-back partial + add error marker exactly once per
# failure mode. See _InterruptedAttempt for the carry-state contract.
if ended_with_stream_error:
if handled_error_info is not None:
final_msg = handled_error_info.error_msg
final_retryable = handled_error_info.retryable
elif attempts_exhausted:
final_msg = (
"Your conversation is too long. "
"Please start a new chat or clear some history."
)
final_retryable = False
elif transient_exhausted:
final_msg = FRIENDLY_TRANSIENT_MSG
final_retryable = True
elif stream_err is not None:
final_msg = _friendly_error_text(
str(stream_err).replace("\n", " ").replace("\r", "")[:500]
)
final_retryable = False
else:
final_msg = None
final_retryable = False
final_msg, final_retryable = _classify_final_failure(
interrupted, attempts_exhausted, transient_exhausted, stream_err
)
if final_msg is not None:
_restore_partial_with_error_marker(
session,
state,
last_attempt_partial,
final_msg,
retryable=final_retryable,
interrupted.finalize(
session, state, final_msg, retryable=final_retryable
)
if ended_with_stream_error and state is not None:
@@ -4173,10 +4185,13 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
# here when the inner handler chose not to (transient errors suppress
# the early flash so the client only sees the final error after all
# retries are exhausted).
if handled_error_info is not None and not handled_error_info.already_yielded:
if (
interrupted.handled_error is not None
and not interrupted.handled_error.already_yielded
):
yield StreamError(
errorText=handled_error_info.error_msg,
code=handled_error_info.code,
errorText=interrupted.handled_error.error_msg,
code=interrupted.handled_error.code,
)
# Copy token usage from retry state to outer-scope accumulators
@@ -4259,24 +4274,13 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
else:
display_msg, code = error_msg, "sdk_error"
# Append error marker to session (non-invasive text parsing approach).
# The finally block will persist the session with this error marker.
# Skip if a marker was already appended inside the stream loop
# (ended_with_stream_error) to avoid duplicate stale markers.
# Append error marker + restore any rolled-back partial when the retry
# loop didn't already finalize. ``interrupted`` is empty on success and
# on paths where the retry loop's own post-loop finalize() already ran,
# so this is a no-op for those and only kicks in for unhandled errors
# that bypass the retry-loop handlers entirely.
if not ended_with_stream_error:
# Restore any rolled-back partial work the retry loop captured —
# last_attempt_partial is empty on success / final-failure-handled
# paths (cleared on success break, consumed inside the retry loop
# by _restore_partial_with_error_marker), so this is a no-op for
# those cases and only kicks in when an unhandled exception bypassed
# the retry loop's own restoration.
_restore_partial_with_error_marker(
session,
state,
last_attempt_partial,
display_msg,
retryable=is_transient,
)
interrupted.finalize(session, state, display_msg, retryable=is_transient)
logger.debug(
"%s Appended error marker, will be persisted in finally",
log_prefix,