diff --git a/autogpt_platform/backend/backend/copilot/sdk/interrupted_partial_test.py b/autogpt_platform/backend/backend/copilot/sdk/interrupted_partial_test.py index 85b8c2c373..a3dbd964f6 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/interrupted_partial_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/interrupted_partial_test.py @@ -1,14 +1,13 @@ """Tests for partial-work preservation when an SDK turn is interrupted. -Covers the regression path SECRT-2275 surfaced: when the SDK retry loop -rolls back ``session.messages`` for a failed attempt (correct behavior so a -successful retry doesn't duplicate content) it MUST re-attach the rolled-back -work on final-failure exit. Otherwise the user's UI streamed tokens live but +Covers the regression SECRT-2275 surfaced: when the SDK retry loop rolls +back ``session.messages`` for a failed attempt (correct so a successful +retry doesn't duplicate content) it MUST re-attach the rolled-back work on +final-failure exit. Without that, the user's UI streamed tokens live then a refresh shows an empty turn — described by users as "the turn is gone". -Tests target the helper functions directly (unit) plus the rollback-then- -restore contract (state-driven). Full end-to-end coverage of the retry loop -lives in retry_scenarios_test.py. +Tests target the ``_InterruptedAttempt`` dataclass + the orphan-tool flush +directly. Full retry-loop coverage lives in ``retry_scenarios_test.py``. """ from __future__ import annotations @@ -24,8 +23,8 @@ from backend.copilot.response_model import StreamToolOutputAvailable from .service import ( _flush_orphan_tool_uses_to_session, - _restore_partial_with_error_marker, - _rollback_attempt_capturing_partial, + _HandledErrorInfo, + _InterruptedAttempt, ) @@ -35,21 +34,19 @@ def _make_session(messages: list[ChatMessage] | None = None) -> ChatSession: return session -def _make_tool_output(tool_call_id: str, output) -> StreamToolOutputAvailable: +def _tool_output(tool_call_id: str, output) -> StreamToolOutputAvailable: return StreamToolOutputAvailable( - toolCallId=tool_call_id, - toolName="t", - output=output, + toolCallId=tool_call_id, toolName="t", output=output ) -def _adapter_with_unresolved(unresolved_responses: list[StreamToolOutputAvailable]): - """Build a stub _RetryState whose adapter flushes the given responses.""" +def _adapter_with_unresolved(responses: list[StreamToolOutputAvailable]): + """Stub _RetryState whose adapter flushes the given responses.""" adapter = MagicMock() - adapter.has_unresolved_tool_calls = bool(unresolved_responses) + adapter.has_unresolved_tool_calls = bool(responses) - def _flush(responses: list) -> None: - responses.extend(unresolved_responses) + def _flush(out: list) -> None: + out.extend(responses) adapter.has_unresolved_tool_calls = False adapter._flush_unresolved_tool_calls.side_effect = _flush @@ -58,145 +55,29 @@ def _adapter_with_unresolved(unresolved_responses: list[StreamToolOutputAvailabl return state -class TestRestorePartialWithErrorMarker: - def test_appends_partial_then_marker_when_partial_present(self): - session = _make_session([ChatMessage(role="user", content="hi")]) - partial = [ - ChatMessage(role="assistant", content="I was working on "), - ChatMessage(role="tool", content="result-1", tool_call_id="t1"), - ] - _restore_partial_with_error_marker( - session, - state=None, - partial=partial, - display_msg="Boom", - retryable=False, - ) - # Pre-existing user msg + 2 partial msgs + error marker - assert len(session.messages) == 4 - assert session.messages[1].content == "I was working on " - assert session.messages[2].role == "tool" - assert session.messages[3].content.startswith(COPILOT_ERROR_PREFIX) - # Partial list is consumed (cleared) so a stray follow-up call won't - # double-attach the same content. - assert partial == [] - - def test_only_marker_when_partial_empty(self): - session = _make_session([ChatMessage(role="user", content="hi")]) - _restore_partial_with_error_marker( - session, - state=None, - partial=[], - display_msg="Boom", - retryable=True, - ) - assert len(session.messages) == 2 - assert session.messages[-1].content.startswith(COPILOT_RETRYABLE_ERROR_PREFIX) - - def test_noop_when_session_is_none(self): - # Signature accepts None — must not raise. - _restore_partial_with_error_marker( - None, - state=None, - partial=[ChatMessage(role="assistant", content="x")], - display_msg="Boom", - retryable=False, - ) - - def test_flushes_unresolved_tools_between_partial_and_marker(self): - session = _make_session([ChatMessage(role="user", content="hi")]) - partial = [ - ChatMessage( - role="assistant", - content="calling tool", - tool_calls=[ - { - "id": "t1", - "type": "function", - "function": {"name": "lookup", "arguments": "{}"}, - } - ], - ), - ] - state = _adapter_with_unresolved([_make_tool_output("t1", "interrupted")]) - _restore_partial_with_error_marker( - session, - state=state, - partial=partial, - display_msg="Boom", - retryable=False, - ) - roles = [m.role for m in session.messages] - # user, assistant(partial), tool(synthetic), assistant(error marker) - assert roles == ["user", "assistant", "tool", "assistant"] - synthetic_tool = session.messages[2] - assert synthetic_tool.tool_call_id == "t1" - assert synthetic_tool.content == "interrupted" +def _builder_stub() -> MagicMock: + builder = MagicMock() + builder.restore = MagicMock() + return builder -class TestFlushOrphanToolUses: - def test_appends_synthetic_tool_results_for_unresolved(self): - session = _make_session() - state = _adapter_with_unresolved( - [ - _make_tool_output("t1", "r1"), - _make_tool_output("t2", {"ok": False}), - ] - ) - _flush_orphan_tool_uses_to_session(session, state) - assert [m.tool_call_id for m in session.messages] == ["t1", "t2"] - # Dict outputs are JSON-encoded so they survive the str-only ChatMessage - # content field without losing structure for the next-turn LLM read. - assert session.messages[1].content == '{"ok": false}' - - def test_noop_when_state_is_none(self): - session = _make_session() - _flush_orphan_tool_uses_to_session(session, None) - assert session.messages == [] - - def test_noop_when_no_unresolved(self): - session = _make_session() - adapter = MagicMock() - adapter.has_unresolved_tool_calls = False - state = MagicMock() - state.adapter = adapter - _flush_orphan_tool_uses_to_session(session, state) - adapter._flush_unresolved_tool_calls.assert_not_called() - - -class TestRollbackCapturingPartial: - """Direct tests of `_rollback_attempt_capturing_partial`. - - The retry loop relies on this helper not to leak error markers that - `_run_stream_attempt` already appended to `session.messages` — otherwise - the post-loop restore replays a stale marker before adding its own, - leaving duplicate error bubbles. - """ - - def _builder_with_snap(self): - builder = MagicMock() - builder.restore = MagicMock() - return builder - - def test_returns_partial_when_no_marker_present(self): +class TestInterruptedAttemptCapture: + def test_keeps_partial_when_no_marker_present(self): session = _make_session( [ ChatMessage(role="user", content="hi"), ChatMessage(role="assistant", content="part-1"), ] ) - builder = self._builder_with_snap() - captured = _rollback_attempt_capturing_partial( - session, builder, transcript_snap=object(), pre_attempt_msg_count=1 - ) - assert [m.content for m in captured] == ["part-1"] - assert session.messages == [ChatMessage(role="user", content="hi")] + attempt = _InterruptedAttempt() + attempt.capture(session, _builder_stub(), object(), pre_attempt_msg_count=1) + assert [m.content for m in attempt.partial] == ["part-1"] + assert [m.content for m in session.messages] == ["hi"] def test_strips_trailing_error_marker(self): - # _run_stream_attempt appended a marker via _append_error_marker - # (e.g. idle timeout, circuit breaker) before raising - # _HandledStreamError. The rollback must NOT carry it forward, or - # the post-loop restore will replay the stale marker + add its own. + # _run_stream_attempt may append a marker (idle timeout, circuit + # breaker) before raising _HandledStreamError. Carrying it forward + # would let finalize() replay it and then add its own. marker = ( f"{COPILOT_RETRYABLE_ERROR_PREFIX} The session has been idle " "for too long. Please try again." @@ -208,66 +89,136 @@ class TestRollbackCapturingPartial: ChatMessage(role="assistant", content=marker), ] ) - captured = _rollback_attempt_capturing_partial( - session, - self._builder_with_snap(), - transcript_snap=object(), - pre_attempt_msg_count=1, - ) - assert [m.content for m in captured] == ["part-1"] + attempt = _InterruptedAttempt() + attempt.capture(session, _builder_stub(), object(), pre_attempt_msg_count=1) + assert [m.content for m in attempt.partial] == ["part-1"] def test_strips_consecutive_error_markers(self): - # Defensive: if more than one marker landed back-to-back (legacy - # path or future regression), strip them all. session = _make_session( [ ChatMessage(role="user", content="hi"), ChatMessage(role="assistant", content="part-1"), ChatMessage(role="assistant", content=f"{COPILOT_ERROR_PREFIX} a"), ChatMessage( - role="assistant", - content=f"{COPILOT_RETRYABLE_ERROR_PREFIX} b", + role="assistant", content=f"{COPILOT_RETRYABLE_ERROR_PREFIX} b" ), ] ) - captured = _rollback_attempt_capturing_partial( - session, - self._builder_with_snap(), - transcript_snap=object(), - pre_attempt_msg_count=1, - ) - assert [m.content for m in captured] == ["part-1"] + attempt = _InterruptedAttempt() + attempt.capture(session, _builder_stub(), object(), pre_attempt_msg_count=1) + assert [m.content for m in attempt.partial] == ["part-1"] - def test_does_not_strip_non_marker_assistant(self): - # Regular assistant text starting with similar-but-not-prefix - # content must be preserved — only the canonical error markers - # should be filtered. + def test_preserves_non_marker_assistant(self): session = _make_session( [ ChatMessage(role="user", content="hi"), ChatMessage(role="assistant", content="Important note"), ] ) - captured = _rollback_attempt_capturing_partial( - session, - self._builder_with_snap(), - transcript_snap=object(), - pre_attempt_msg_count=1, + attempt = _InterruptedAttempt() + attempt.capture(session, _builder_stub(), object(), pre_attempt_msg_count=1) + assert [m.content for m in attempt.partial] == ["Important note"] + + +class TestInterruptedAttemptFinalize: + def test_appends_partial_then_marker(self): + session = _make_session([ChatMessage(role="user", content="hi")]) + attempt = _InterruptedAttempt( + partial=[ + ChatMessage(role="assistant", content="working"), + ChatMessage(role="tool", content="result", tool_call_id="t1"), + ] ) - assert [m.content for m in captured] == ["Important note"] + attempt.finalize(session, state=None, display_msg="Boom", retryable=False) + roles = [m.role for m in session.messages] + assert roles == ["user", "assistant", "tool", "assistant"] + assert session.messages[-1].content.startswith(COPILOT_ERROR_PREFIX) + # partial consumed so a follow-up finalize() is a no-op for partial. + assert attempt.partial == [] + + def test_only_marker_when_partial_empty(self): + session = _make_session([ChatMessage(role="user", content="hi")]) + attempt = _InterruptedAttempt() + attempt.finalize(session, state=None, display_msg="Boom", retryable=True) + assert len(session.messages) == 2 + assert session.messages[-1].content.startswith(COPILOT_RETRYABLE_ERROR_PREFIX) + + def test_noop_when_session_is_none(self): + attempt = _InterruptedAttempt( + partial=[ChatMessage(role="assistant", content="x")] + ) + attempt.finalize(None, state=None, display_msg="Boom", retryable=False) + # no raise = pass + + def test_flushes_unresolved_tools_between_partial_and_marker(self): + session = _make_session([ChatMessage(role="user", content="hi")]) + attempt = _InterruptedAttempt( + partial=[ + ChatMessage( + role="assistant", + content="calling", + tool_calls=[ + { + "id": "t1", + "type": "function", + "function": {"name": "lookup", "arguments": "{}"}, + } + ], + ), + ] + ) + state = _adapter_with_unresolved([_tool_output("t1", "interrupted")]) + attempt.finalize(session, state=state, display_msg="Boom", retryable=False) + roles = [m.role for m in session.messages] + assert roles == ["user", "assistant", "tool", "assistant"] + assert session.messages[2].tool_call_id == "t1" + assert session.messages[2].content == "interrupted" + + def test_clear_drops_both_partial_and_handled_error(self): + attempt = _InterruptedAttempt( + partial=[ChatMessage(role="assistant", content="x")], + handled_error=_HandledErrorInfo( + error_msg="m", code="c", retryable=True, already_yielded=False + ), + ) + attempt.clear() + assert attempt.partial == [] + assert attempt.handled_error is None + + +class TestFlushOrphanToolUses: + def test_appends_synthetic_tool_results_for_unresolved(self): + session = _make_session() + state = _adapter_with_unresolved( + [_tool_output("t1", "r1"), _tool_output("t2", {"ok": False})] + ) + _flush_orphan_tool_uses_to_session(session, state) + assert [m.tool_call_id for m in session.messages] == ["t1", "t2"] + # Dict outputs are JSON-encoded so structure survives the str-only + # ChatMessage content field for the next-turn LLM read. + assert session.messages[1].content == '{"ok": false}' + + def test_noop_when_state_is_none(self): + session = _make_session() + _flush_orphan_tool_uses_to_session(session, None) + assert session.messages == [] + + def test_noop_when_no_unresolved(self): + adapter = MagicMock() + adapter.has_unresolved_tool_calls = False + state = MagicMock() + state.adapter = adapter + _flush_orphan_tool_uses_to_session(_make_session(), state) + adapter._flush_unresolved_tool_calls.assert_not_called() class TestRetryRollbackContract: - """Property-style: a rolled-back attempt must be recoverable on final exit. + """End-to-end contract: capture on a rolled-back attempt + finalize yields + the exact content the user saw streaming live, plus the error marker.""" - Simulates the retry loop's rollback by mirroring the exact slicing and - captured-list shape used in stream_chat_completion_sdk so that any drift - in that contract is caught here without needing the full SDK fixture. - """ - - def test_capture_slice_matches_rollback(self): + def test_capture_then_finalize_matches_streamed_sequence(self): session = _make_session([ChatMessage(role="user", content="hi")]) - pre_attempt_msg_count = len(session.messages) + pre = len(session.messages) # Simulate incremental SDK appends during the attempt. session.messages.extend( [ @@ -275,18 +226,11 @@ class TestRetryRollbackContract: ChatMessage(role="assistant", content="part-2"), ] ) - captured = list(session.messages[pre_attempt_msg_count:]) - session.messages = session.messages[:pre_attempt_msg_count] - # Final-failure restore. - _restore_partial_with_error_marker( - session, - state=None, - partial=captured, - display_msg="Boom", - retryable=False, - ) - contents = [m.content for m in session.messages] - assert contents == [ + attempt = _InterruptedAttempt() + attempt.capture(session, _builder_stub(), object(), pre) + # Final-failure path — no retry, no success clear(). + attempt.finalize(session, state=None, display_msg="Boom", retryable=False) + assert [m.content for m in session.messages] == [ "hi", "part-1", "part-2", diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index 613433993a..7704ece99e 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -541,14 +541,7 @@ def _append_error_marker( *, retryable: bool = False, ) -> None: - """Append a copilot error marker to *session* so it persists across refresh. - - Args: - session: The chat session to append to (no-op if `None`). - display_msg: User-visible error text. - retryable: If `True`, use the retryable prefix so the frontend - shows a "Try Again" button. - """ + """Append a copilot error marker to *session* so it persists across refresh.""" if session is None: return prefix = COPILOT_RETRYABLE_ERROR_PREFIX if retryable else COPILOT_ERROR_PREFIX @@ -557,17 +550,97 @@ def _append_error_marker( ) +def _is_error_marker(msg: ChatMessage) -> bool: + """True if *msg* is an error marker emitted by ``_append_error_marker``.""" + if msg.role != "assistant" or not msg.content: + return False + return msg.content.startswith(COPILOT_ERROR_PREFIX) or msg.content.startswith( + COPILOT_RETRYABLE_ERROR_PREFIX + ) + + +@dataclass +class _InterruptedAttempt: + """Captured state of a failed SDK attempt, carried across the retry loop. + + The SDK always rolls back ``session.messages`` before deciding whether + to retry (so attempt #2 starts clean). That rollback would otherwise + discard everything the assistant produced — the user sees tokens stream + live, then a refresh shows nothing. This dataclass holds the rolled-back + messages plus the ``_HandledStreamError`` info needed to emit a final + ``StreamError`` once the loop decides not to retry. + + The retry loop calls ``capture()`` on every failed attempt, ``clear()`` + on a successful retry (so prior rolled-back content is not replayed), + and ``finalize()`` exactly once after the loop on final failure. + """ + + partial: list[ChatMessage] = dataclass_field(default_factory=list) + # Populated by the ``except _HandledStreamError`` branch so the post-loop + # block can restore the partial and (when the inner handler didn't) emit + # the client-facing StreamError. Transient errors deliberately suppress + # the early StreamError flash and rely on this post-loop emit. + handled_error: "_HandledErrorInfo | None" = None + + def capture( + self, + session: ChatSession, + transcript_builder: "TranscriptBuilder", + transcript_snap: object, + pre_attempt_msg_count: int, + ) -> None: + """Roll back ``session.messages`` + transcript, keeping the partial. + + Trailing error markers appended inside ``_run_stream_attempt`` (idle + timeout, circuit breaker) are stripped: re-attaching them would make + the post-loop restore replay a stale marker before adding its own, + leaving duplicate error bubbles. + """ + tail = list(session.messages[pre_attempt_msg_count:]) + while tail and _is_error_marker(tail[-1]): + tail.pop() + self.partial = tail + session.messages = session.messages[:pre_attempt_msg_count] + transcript_builder.restore(transcript_snap) # type: ignore[arg-type] + + def clear(self) -> None: + """Drop captured state — used on successful retry.""" + self.partial = [] + self.handled_error = None + + def finalize( + self, + session: ChatSession | None, + state: "_RetryState | None", + display_msg: str, + *, + retryable: bool, + ) -> None: + """Re-attach partial + synthetic tool_result rows + error marker. + + Called exactly once after the retry loop on final-failure exit. + Idempotent on empty state, so it's safe to call on paths where no + rollback happened. + """ + if session is None: + return + if self.partial: + session.messages.extend(self.partial) + self.partial = [] + _flush_orphan_tool_uses_to_session(session, state) + _append_error_marker(session, display_msg, retryable=retryable) + + def _flush_orphan_tool_uses_to_session( session: "ChatSession | None", state: "_RetryState | None", ) -> None: - """Synthesize tool_result rows for tool_use blocks that never resolved. + """Synthesize ``tool_result`` rows for ``tool_use`` blocks that never resolved. - Without this, partial assistant work re-attached after a final failure - would carry orphan tool_use blocks. The next turn's LLM call would error - with ``tool_use_id without tool_result`` — and any baseline replay would - surface the same broken history. Flushing produces interrupted-marker - tool_results that satisfy the API contract. + Re-attached partial work may carry orphan ``tool_use`` blocks; without + matching ``tool_result`` rows the next turn's LLM call would error with + ``tool_use_id without tool_result``. The adapter's safety-flush produces + interrupted-marker results that satisfy the API contract. """ if session is None or state is None: return @@ -587,65 +660,30 @@ def _flush_orphan_tool_uses_to_session( ) -def _restore_partial_with_error_marker( - session: "ChatSession | None", - state: "_RetryState | None", - partial: list[ChatMessage], - display_msg: str, - *, - retryable: bool, -) -> None: - """Re-attach a rolled-back attempt's partial work, then add the error marker. +def _classify_final_failure( + interrupted: _InterruptedAttempt, + attempts_exhausted: bool, + transient_exhausted: bool, + stream_err: BaseException | None, +) -> tuple[str | None, bool]: + """Pick the display message + retryable flag for the post-loop restore. - Called when retries are exhausted or no retry is attempted. Without this, - the SDK retry loop's pre-decision rollback would discard everything the - assistant produced in the failed attempt (text, tool calls, reasoning), - leaving the user with a chat that looks like nothing happened — even - though the events were already streamed live to their UI. + Mirrors the error-code selection used for the client-facing ``StreamError`` + yield so the in-history marker and the SSE event stay consistent. """ - if session is None: - return - if partial: - session.messages.extend(partial) - partial.clear() - _flush_orphan_tool_uses_to_session(session, state) - _append_error_marker(session, display_msg, retryable=retryable) - - -def _rollback_attempt_capturing_partial( - session: "ChatSession", - transcript_builder: "TranscriptBuilder", - transcript_snap: object, - pre_attempt_msg_count: int, -) -> list[ChatMessage]: - """Roll back an attempt's session.messages + transcript, capturing the partial. - - The returned list holds the assistant work that was incrementally appended - during the failed attempt. The caller passes it to - ``_restore_partial_with_error_marker`` on final-failure exit and discards - it on a successful retry. - - Trailing error markers appended inside ``_run_stream_attempt`` (idle - timeout, circuit breaker) are stripped: re-attaching them would let the - post-loop restore replay a stale marker before adding its own, leaving - duplicate error bubbles and pushing any synthetic ``tool_result`` after - an assistant(error) turn that has no matching ``tool_use``. - """ - captured = list(session.messages[pre_attempt_msg_count:]) - while captured and _is_error_marker(captured[-1]): - captured.pop() - session.messages = session.messages[:pre_attempt_msg_count] - transcript_builder.restore(transcript_snap) # type: ignore[arg-type] - return captured - - -def _is_error_marker(msg: ChatMessage) -> bool: - """True if *msg* is an error marker emitted by ``_append_error_marker``.""" - if msg.role != "assistant" or not msg.content: - return False - return msg.content.startswith(COPILOT_ERROR_PREFIX) or msg.content.startswith( - COPILOT_RETRYABLE_ERROR_PREFIX - ) + if interrupted.handled_error is not None: + return interrupted.handled_error.error_msg, interrupted.handled_error.retryable + if attempts_exhausted: + return ( + "Your conversation is too long. " + "Please start a new chat or clear some history." + ), False + if transient_exhausted: + return FRIENDLY_TRANSIENT_MSG, True + if stream_err is not None: + safe_err = str(stream_err).replace("\n", " ").replace("\r", "")[:500] + return _friendly_error_text(safe_err), False + return None, False def _setup_langfuse_otel() -> None: @@ -3112,12 +3150,13 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues request_arrival_at: float = 0.0, **_kwargs: Any, ) -> AsyncGenerator[StreamBaseResponse, None]: - # Pyright's complexity heuristic bails on this function (~1500 LoC, retry + # Pyright's complexity heuristic bails on this ~1500 LoC function (retry # loop with context-overflow fallback + transient backoff + partial-work - # preservation). Splitting further would hurt readability — the branches - # share state (session, adapter, transcript builder, token accumulators) - # that's hard to pass cleanly through helpers. Suppress the bailout; real - # type errors elsewhere in the file remain surfaced. + # preservation). Splitting the retry loop further hurts readability — + # branches share mutable state (session, adapter, transcript builder, + # usage accumulators) that doesn't pass cleanly through helpers. The + # suppression only silences the complexity bailout; real type errors in + # the function body still surface. """Stream chat completion using Claude Agent SDK. Args: @@ -3281,14 +3320,11 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues turn_cost_usd: float | None = None graphiti_enabled = False pre_attempt_msg_count = 0 - # Holds messages that the retry loop rolls back from session.messages on a - # failed attempt. On final-failure exit we re-attach these so the user sees - # what the assistant produced before the error rather than an empty chat. - last_attempt_partial: list[ChatMessage] = [] - # Populated by the _HandledStreamError branch. Consumed once after the - # retry loop to re-attach the partial and, if the inner handler hadn't - # already emitted one, yield a single StreamError to the client. - handled_error_info: _HandledErrorInfo | None = None + # State of the latest failed attempt: rolled-back messages + any + # _HandledStreamError info to emit on final-failure exit. The retry loop + # mutates this via capture()/clear(); the post-loop block calls + # finalize() once. + interrupted = _InterruptedAttempt() # Defaults ensure the finally block can always reference these safely even when # an early return (e.g. sdk_cwd error) skips their normal assignment below. sdk_model: str | None = None @@ -3949,10 +3985,9 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues "using fallback model for this request" ) yield event - # Drop any stale partial captured from prior failed attempts - # so the outer cleanup paths don't re-attach pre-retry content - # the successful attempt already replaced. - last_attempt_partial.clear() + # Discard any state captured from prior failed attempts so + # outer cleanup paths don't replay pre-retry content. + interrupted.clear() break # Stream completed — exit retry loop except asyncio.CancelledError: logger.warning( @@ -3968,7 +4003,7 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues # session messages and set the error flag — do NOT set # stream_err so the post-loop code won't emit a # duplicate StreamError. - last_attempt_partial = _rollback_attempt_capturing_partial( + interrupted.capture( session, state.transcript_builder, transcript_snap, @@ -4007,7 +4042,7 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues # attempt that no longer match session.messages. Skip upload # so a future --resume doesn't replay rolled-back content. skip_transcript_upload = True - handled_error_info = _HandledErrorInfo( + interrupted.handled_error = _HandledErrorInfo( error_msg=exc.error_msg or FRIENDLY_TRANSIENT_MSG, code=exc.code or "transient_api_error", retryable=exc.retryable, @@ -4031,7 +4066,7 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues stream_err, exc_info=True, ) - last_attempt_partial = _rollback_attempt_capturing_partial( + interrupted.capture( session, state.transcript_builder, transcript_snap, @@ -4097,38 +4132,15 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues _MAX_STREAM_ATTEMPTS, stream_err, ) - # Restore the rolled-back partial work + add error marker exactly once - # per failure mode. Earlier revisions did this inline in each retry-loop - # branch; consolidating here keeps the retry loop itself simple enough - # for pyright to type-check. + # Restore the rolled-back partial + add error marker exactly once per + # failure mode. See _InterruptedAttempt for the carry-state contract. if ended_with_stream_error: - if handled_error_info is not None: - final_msg = handled_error_info.error_msg - final_retryable = handled_error_info.retryable - elif attempts_exhausted: - final_msg = ( - "Your conversation is too long. " - "Please start a new chat or clear some history." - ) - final_retryable = False - elif transient_exhausted: - final_msg = FRIENDLY_TRANSIENT_MSG - final_retryable = True - elif stream_err is not None: - final_msg = _friendly_error_text( - str(stream_err).replace("\n", " ").replace("\r", "")[:500] - ) - final_retryable = False - else: - final_msg = None - final_retryable = False + final_msg, final_retryable = _classify_final_failure( + interrupted, attempts_exhausted, transient_exhausted, stream_err + ) if final_msg is not None: - _restore_partial_with_error_marker( - session, - state, - last_attempt_partial, - final_msg, - retryable=final_retryable, + interrupted.finalize( + session, state, final_msg, retryable=final_retryable ) if ended_with_stream_error and state is not None: @@ -4173,10 +4185,13 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues # here when the inner handler chose not to (transient errors suppress # the early flash so the client only sees the final error after all # retries are exhausted). - if handled_error_info is not None and not handled_error_info.already_yielded: + if ( + interrupted.handled_error is not None + and not interrupted.handled_error.already_yielded + ): yield StreamError( - errorText=handled_error_info.error_msg, - code=handled_error_info.code, + errorText=interrupted.handled_error.error_msg, + code=interrupted.handled_error.code, ) # Copy token usage from retry state to outer-scope accumulators @@ -4259,24 +4274,13 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues else: display_msg, code = error_msg, "sdk_error" - # Append error marker to session (non-invasive text parsing approach). - # The finally block will persist the session with this error marker. - # Skip if a marker was already appended inside the stream loop - # (ended_with_stream_error) to avoid duplicate stale markers. + # Append error marker + restore any rolled-back partial when the retry + # loop didn't already finalize. ``interrupted`` is empty on success and + # on paths where the retry loop's own post-loop finalize() already ran, + # so this is a no-op for those and only kicks in for unhandled errors + # that bypass the retry-loop handlers entirely. if not ended_with_stream_error: - # Restore any rolled-back partial work the retry loop captured — - # last_attempt_partial is empty on success / final-failure-handled - # paths (cleared on success break, consumed inside the retry loop - # by _restore_partial_with_error_marker), so this is a no-op for - # those cases and only kicks in when an unhandled exception bypassed - # the retry loop's own restoration. - _restore_partial_with_error_marker( - session, - state, - last_attempt_partial, - display_msg, - retryable=is_transient, - ) + interrupted.finalize(session, state, display_msg, retryable=is_transient) logger.debug( "%s Appended error marker, will be persisted in finally", log_prefix,