fix(backend/copilot): preserve interrupted SDK partial work on final-failure exit

SECRT-2275 — when an SDK turn was interrupted (transient API errors with
exhausted retries, mid-stream LLM exceptions, or context-overflow with all
attempts exhausted) the retry loop's pre-decision rollback discarded the
assistant's partial work (text + tool calls + reasoning) that had been
incrementally appended to session.messages during the failed attempt.

Users described it as "the turn is gone": their UI streamed tokens live, then
a refresh showed an empty turn and the next message would prompt the model
to "continue" with no context, so it picked an unrelated old task.

Fix: capture the rolled-back partial in the retry-loop exception handlers and
re-attach it via a single helper on every final-failure branch (including
the events_yielded > 0 path that previously skipped the error marker entirely
and the non-context-non-transient + attempts-exhausted paths). Synthesize
"interrupted" tool_result rows for any orphan tool_use so the next turn's
LLM context stays API-valid. Successful retry breaks clear the captured
partial so attempt #1's rolled-back content doesn't leak into a successful
attempt #2's history.

Baseline path already preserves partial via its existing finally block; only
SDK was affected.
This commit is contained in:
majdyz
2026-04-25 08:16:37 +07:00
parent 2deac2073e
commit 7b66f255f5
2 changed files with 353 additions and 12 deletions

View File

@@ -0,0 +1,200 @@
"""Tests for partial-work preservation when an SDK turn is interrupted.
Covers the regression path SECRT-2275 surfaced: when the SDK retry loop
rolls back ``session.messages`` for a failed attempt (correct behavior so a
successful retry doesn't duplicate content) it MUST re-attach the rolled-back
work on final-failure exit. Otherwise the user's UI streamed tokens live but
a refresh shows an empty turn — described by users as "the turn is gone".
Tests target the helper functions directly (unit) plus the rollback-then-
restore contract (state-driven). Full end-to-end coverage of the retry loop
lives in retry_scenarios_test.py.
"""
from __future__ import annotations
from unittest.mock import MagicMock
from backend.copilot.constants import (
COPILOT_ERROR_PREFIX,
COPILOT_RETRYABLE_ERROR_PREFIX,
)
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
from .service import (
_flush_orphan_tool_uses_to_session,
_restore_partial_with_error_marker,
)
def _make_session(messages: list[ChatMessage] | None = None) -> ChatSession:
session = ChatSession.new(user_id="user-1", dry_run=False)
session.messages = list(messages or [])
return session
def _make_tool_output(tool_call_id: str, output) -> StreamToolOutputAvailable:
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName="t",
output=output,
)
def _adapter_with_unresolved(unresolved_responses: list[StreamToolOutputAvailable]):
"""Build a stub _RetryState whose adapter flushes the given responses."""
adapter = MagicMock()
adapter.has_unresolved_tool_calls = bool(unresolved_responses)
def _flush(responses: list) -> None:
responses.extend(unresolved_responses)
adapter.has_unresolved_tool_calls = False
adapter._flush_unresolved_tool_calls.side_effect = _flush
state = MagicMock()
state.adapter = adapter
return state
class TestRestorePartialWithErrorMarker:
def test_appends_partial_then_marker_when_partial_present(self):
session = _make_session([ChatMessage(role="user", content="hi")])
partial = [
ChatMessage(role="assistant", content="I was working on "),
ChatMessage(role="tool", content="result-1", tool_call_id="t1"),
]
_restore_partial_with_error_marker(
session,
state=None,
partial=partial,
display_msg="Boom",
retryable=False,
)
# Pre-existing user msg + 2 partial msgs + error marker
assert len(session.messages) == 4
assert session.messages[1].content == "I was working on "
assert session.messages[2].role == "tool"
assert session.messages[3].content.startswith(COPILOT_ERROR_PREFIX)
# Partial list is consumed (cleared) so a stray follow-up call won't
# double-attach the same content.
assert partial == []
def test_only_marker_when_partial_empty(self):
session = _make_session([ChatMessage(role="user", content="hi")])
_restore_partial_with_error_marker(
session,
state=None,
partial=[],
display_msg="Boom",
retryable=True,
)
assert len(session.messages) == 2
assert session.messages[-1].content.startswith(COPILOT_RETRYABLE_ERROR_PREFIX)
def test_noop_when_session_is_none(self):
# Signature accepts None — must not raise.
_restore_partial_with_error_marker(
None,
state=None,
partial=[ChatMessage(role="assistant", content="x")],
display_msg="Boom",
retryable=False,
)
def test_flushes_unresolved_tools_between_partial_and_marker(self):
session = _make_session([ChatMessage(role="user", content="hi")])
partial = [
ChatMessage(
role="assistant",
content="calling tool",
tool_calls=[
{
"id": "t1",
"type": "function",
"function": {"name": "lookup", "arguments": "{}"},
}
],
),
]
state = _adapter_with_unresolved([_make_tool_output("t1", "interrupted")])
_restore_partial_with_error_marker(
session,
state=state,
partial=partial,
display_msg="Boom",
retryable=False,
)
roles = [m.role for m in session.messages]
# user, assistant(partial), tool(synthetic), assistant(error marker)
assert roles == ["user", "assistant", "tool", "assistant"]
synthetic_tool = session.messages[2]
assert synthetic_tool.tool_call_id == "t1"
assert synthetic_tool.content == "interrupted"
class TestFlushOrphanToolUses:
def test_appends_synthetic_tool_results_for_unresolved(self):
session = _make_session()
state = _adapter_with_unresolved(
[
_make_tool_output("t1", "r1"),
_make_tool_output("t2", {"ok": False}),
]
)
_flush_orphan_tool_uses_to_session(session, state)
assert [m.tool_call_id for m in session.messages] == ["t1", "t2"]
# Dict outputs are JSON-encoded so they survive the str-only ChatMessage
# content field without losing structure for the next-turn LLM read.
assert session.messages[1].content == '{"ok": false}'
def test_noop_when_state_is_none(self):
session = _make_session()
_flush_orphan_tool_uses_to_session(session, None)
assert session.messages == []
def test_noop_when_no_unresolved(self):
session = _make_session()
adapter = MagicMock()
adapter.has_unresolved_tool_calls = False
state = MagicMock()
state.adapter = adapter
_flush_orphan_tool_uses_to_session(session, state)
adapter._flush_unresolved_tool_calls.assert_not_called()
class TestRetryRollbackContract:
"""Property-style: a rolled-back attempt must be recoverable on final exit.
Simulates the retry loop's rollback by mirroring the exact slicing and
captured-list shape used in stream_chat_completion_sdk so that any drift
in that contract is caught here without needing the full SDK fixture.
"""
def test_capture_slice_matches_rollback(self):
session = _make_session([ChatMessage(role="user", content="hi")])
pre_attempt_msg_count = len(session.messages)
# Simulate incremental SDK appends during the attempt.
session.messages.extend(
[
ChatMessage(role="assistant", content="part-1"),
ChatMessage(role="assistant", content="part-2"),
]
)
captured = list(session.messages[pre_attempt_msg_count:])
session.messages = session.messages[:pre_attempt_msg_count]
# Final-failure restore.
_restore_partial_with_error_marker(
session,
state=None,
partial=captured,
display_msg="Boom",
retryable=False,
)
contents = [m.content for m in session.messages]
assert contents == [
"hi",
"part-1",
"part-2",
f"{COPILOT_ERROR_PREFIX} Boom",
]

View File

@@ -557,6 +557,80 @@ def _append_error_marker(
)
def _flush_orphan_tool_uses_to_session(
session: "ChatSession | None",
state: "_RetryState | None",
) -> None:
"""Synthesize tool_result rows for tool_use blocks that never resolved.
Without this, partial assistant work re-attached after a final failure
would carry orphan tool_use blocks. The next turn's LLM call would error
with ``tool_use_id without tool_result`` — and any baseline replay would
surface the same broken history. Flushing produces interrupted-marker
tool_results that satisfy the API contract.
"""
if session is None or state is None:
return
if not state.adapter.has_unresolved_tool_calls:
return
safety: list[StreamBaseResponse] = []
state.adapter._flush_unresolved_tool_calls(safety) # noqa: SLF001
for resp in safety:
if isinstance(resp, StreamToolOutputAvailable):
content = (
resp.output
if isinstance(resp.output, str)
else json.dumps(resp.output, ensure_ascii=False)
)
session.messages.append(
ChatMessage(role="tool", content=content, tool_call_id=resp.toolCallId)
)
def _restore_partial_with_error_marker(
session: "ChatSession | None",
state: "_RetryState | None",
partial: list[ChatMessage],
display_msg: str,
*,
retryable: bool,
) -> None:
"""Re-attach a rolled-back attempt's partial work, then add the error marker.
Called when retries are exhausted or no retry is attempted. Without this,
the SDK retry loop's pre-decision rollback would discard everything the
assistant produced in the failed attempt (text, tool calls, reasoning),
leaving the user with a chat that looks like nothing happened — even
though the events were already streamed live to their UI.
"""
if session is None:
return
if partial:
session.messages.extend(partial)
partial.clear()
_flush_orphan_tool_uses_to_session(session, state)
_append_error_marker(session, display_msg, retryable=retryable)
def _rollback_attempt_capturing_partial(
session: "ChatSession",
transcript_builder: "TranscriptBuilder",
transcript_snap: object,
pre_attempt_msg_count: int,
) -> list[ChatMessage]:
"""Roll back an attempt's session.messages + transcript, capturing the partial.
The returned list holds the assistant work that was incrementally appended
during the failed attempt. The caller passes it to
``_restore_partial_with_error_marker`` on final-failure exit and discards
it on a successful retry.
"""
captured = list(session.messages[pre_attempt_msg_count:])
session.messages = session.messages[:pre_attempt_msg_count]
transcript_builder.restore(transcript_snap) # type: ignore[arg-type]
return captured
def _setup_langfuse_otel() -> None:
"""Configure OTEL tracing for the Claude Agent SDK → Langfuse.
@@ -3169,6 +3243,10 @@ async def stream_chat_completion_sdk(
turn_cost_usd: float | None = None
graphiti_enabled = False
pre_attempt_msg_count = 0
# Holds messages that the retry loop rolls back from session.messages on a
# failed attempt. On final-failure exit we re-attach these so the user sees
# what the assistant produced before the error rather than an empty chat.
last_attempt_partial: list[ChatMessage] = []
# Defaults ensure the finally block can always reference these safely even when
# an early return (e.g. sdk_cwd error) skips their normal assignment below.
sdk_model: str | None = None
@@ -3829,6 +3907,10 @@ async def stream_chat_completion_sdk(
"using fallback model for this request"
)
yield event
# Drop any stale partial captured from prior failed attempts
# so the outer cleanup paths don't re-attach pre-retry content
# the successful attempt already replaced.
last_attempt_partial.clear()
break # Stream completed — exit retry loop
except asyncio.CancelledError:
logger.warning(
@@ -3844,8 +3926,12 @@ async def stream_chat_completion_sdk(
# session messages and set the error flag — do NOT set
# stream_err so the post-loop code won't emit a
# duplicate StreamError.
session.messages = session.messages[:pre_attempt_msg_count]
state.transcript_builder.restore(transcript_snap)
last_attempt_partial = _rollback_attempt_capturing_partial(
session,
state.transcript_builder,
transcript_snap,
pre_attempt_msg_count,
)
# Check if this is a transient error we can retry with backoff.
# exc.code is the only reliable signal — str(exc) is always the
# static "Stream error handled — StreamError already yielded" message.
@@ -3879,12 +3965,13 @@ async def stream_chat_completion_sdk(
# attempt that no longer match session.messages. Skip upload
# so a future --resume doesn't replay rolled-back content.
skip_transcript_upload = True
# Re-append the error marker so it survives the rollback
# and is persisted by the finally block (see #2947655365).
# Use the specific error message from the attempt (e.g.
# circuit breaker msg) rather than always the generic one.
_append_error_marker(
# Re-attach the rolled-back partial work + add error marker.
# Without partial restoration the user's UI streamed tokens
# live but a refresh shows nothing happened.
_restore_partial_with_error_marker(
session,
state,
last_attempt_partial,
exc.error_msg or FRIENDLY_TRANSIENT_MSG,
retryable=True,
)
@@ -3917,8 +4004,12 @@ async def stream_chat_completion_sdk(
stream_err,
exc_info=True,
)
session.messages = session.messages[:pre_attempt_msg_count]
state.transcript_builder.restore(transcript_snap)
last_attempt_partial = _rollback_attempt_capturing_partial(
session,
state.transcript_builder,
transcript_snap,
pre_attempt_msg_count,
)
if events_yielded > 0:
# Events were already sent to the frontend and cannot be
# unsent. Retrying would produce duplicate/inconsistent
@@ -3929,6 +4020,19 @@ async def stream_chat_completion_sdk(
events_yielded,
)
skip_transcript_upload = True
# Restore the streamed partial + add error marker — without
# this the frontend would briefly show streamed text live
# then a refresh would show an empty turn.
safe_err = (
str(stream_err).replace("\n", " ").replace("\r", "")[:500]
)
_restore_partial_with_error_marker(
session,
state,
last_attempt_partial,
_friendly_error_text(safe_err),
retryable=False,
)
ended_with_stream_error = True
break
# Transient API errors (ECONNRESET, 429, 5xx) — retry
@@ -3956,8 +4060,12 @@ async def stream_chat_completion_sdk(
# at line ~2310.
transient_exhausted = True
skip_transcript_upload = True
_append_error_marker(
session, FRIENDLY_TRANSIENT_MSG, retryable=True
_restore_partial_with_error_marker(
session,
state,
last_attempt_partial,
FRIENDLY_TRANSIENT_MSG,
retryable=True,
)
ended_with_stream_error = True
break
@@ -3966,6 +4074,16 @@ async def stream_chat_completion_sdk(
# Non-context, non-transient errors (auth, fatal)
# should not trigger compaction — surface immediately.
skip_transcript_upload = True
safe_err = (
str(stream_err).replace("\n", " ").replace("\r", "")[:500]
)
_restore_partial_with_error_marker(
session,
state,
last_attempt_partial,
_friendly_error_text(safe_err),
retryable=False,
)
ended_with_stream_error = True
break
attempt += 1 # advance to next context-level attempt
@@ -3982,6 +4100,17 @@ async def stream_chat_completion_sdk(
_MAX_STREAM_ATTEMPTS,
stream_err,
)
# Restore the last attempt's partial work (rolled back by the
# exhausted context-level retry) so the user sees what was
# produced before the conversation hit the context ceiling.
_restore_partial_with_error_marker(
session,
state,
last_attempt_partial,
"Your conversation is too long. "
"Please start a new chat or clear some history.",
retryable=False,
)
if ended_with_stream_error and state is not None:
# Flush any unresolved tool calls so the frontend can close
@@ -4105,7 +4234,19 @@ async def stream_chat_completion_sdk(
# Skip if a marker was already appended inside the stream loop
# (ended_with_stream_error) to avoid duplicate stale markers.
if not ended_with_stream_error:
_append_error_marker(session, display_msg, retryable=is_transient)
# Restore any rolled-back partial work the retry loop captured —
# last_attempt_partial is empty on success / final-failure-handled
# paths (cleared on success break, consumed inside the retry loop
# by _restore_partial_with_error_marker), so this is a no-op for
# those cases and only kicks in when an unhandled exception bypassed
# the retry loop's own restoration.
_restore_partial_with_error_marker(
session,
state,
last_attempt_partial,
display_msg,
retryable=is_transient,
)
logger.debug(
"%s Appended error marker, will be persisted in finally",
log_prefix,