mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
refactor(backend/copilot): consolidate SDK partial-preserve state into _InterruptedAttempt
Previous revisions carried the failed-attempt state across three separate function-scope variables (last_attempt_partial, handled_error_info) + four module-level helpers (_rollback_attempt_capturing_partial, _restore_partial_with_error_marker, _flush_orphan_tool_uses_to_session, _append_error_marker). The retry loop mutated all three and the post-loop block reassembled the pieces by hand. Scattered and hard to follow. Collapse to one dataclass with capture / clear / finalize + one _classify_final_failure helper that picks the display message based on which failure flag the retry loop set (attempts_exhausted, transient_exhausted, stream_err, handled_error). Call sites: - success break: interrupted.clear() - _HandledStreamError: interrupted.capture(...); interrupted.handled_error = ... - Exception: interrupted.capture(...) - post-loop: final_msg, retryable = _classify_final_failure(interrupted, ...); interrupted.finalize(...) - outer except: interrupted.finalize(...) Behaviour is unchanged — same restore semantics, same StreamError sequencing, same transcript-upload skip, same orphan tool_use flush, same stale-marker stripping fromb1172e203/5406fe9b5. The retry-scenarios suite (48 integration tests) plus the rewritten interrupted_partial_test (14 unit tests) both pass; the full SDK test suite (1012 tests) is green.
This commit is contained in:
@@ -1,14 +1,13 @@
|
||||
"""Tests for partial-work preservation when an SDK turn is interrupted.
|
||||
|
||||
Covers the regression path SECRT-2275 surfaced: when the SDK retry loop
|
||||
rolls back ``session.messages`` for a failed attempt (correct behavior so a
|
||||
successful retry doesn't duplicate content) it MUST re-attach the rolled-back
|
||||
work on final-failure exit. Otherwise the user's UI streamed tokens live but
|
||||
Covers the regression SECRT-2275 surfaced: when the SDK retry loop rolls
|
||||
back ``session.messages`` for a failed attempt (correct so a successful
|
||||
retry doesn't duplicate content) it MUST re-attach the rolled-back work on
|
||||
final-failure exit. Without that, the user's UI streamed tokens live then
|
||||
a refresh shows an empty turn — described by users as "the turn is gone".
|
||||
|
||||
Tests target the helper functions directly (unit) plus the rollback-then-
|
||||
restore contract (state-driven). Full end-to-end coverage of the retry loop
|
||||
lives in retry_scenarios_test.py.
|
||||
Tests target the ``_InterruptedAttempt`` dataclass + the orphan-tool flush
|
||||
directly. Full retry-loop coverage lives in ``retry_scenarios_test.py``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -24,8 +23,8 @@ from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
|
||||
from .service import (
|
||||
_flush_orphan_tool_uses_to_session,
|
||||
_restore_partial_with_error_marker,
|
||||
_rollback_attempt_capturing_partial,
|
||||
_HandledErrorInfo,
|
||||
_InterruptedAttempt,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,21 +34,19 @@ def _make_session(messages: list[ChatMessage] | None = None) -> ChatSession:
|
||||
return session
|
||||
|
||||
|
||||
def _make_tool_output(tool_call_id: str, output) -> StreamToolOutputAvailable:
|
||||
def _tool_output(tool_call_id: str, output) -> StreamToolOutputAvailable:
|
||||
return StreamToolOutputAvailable(
|
||||
toolCallId=tool_call_id,
|
||||
toolName="t",
|
||||
output=output,
|
||||
toolCallId=tool_call_id, toolName="t", output=output
|
||||
)
|
||||
|
||||
|
||||
def _adapter_with_unresolved(unresolved_responses: list[StreamToolOutputAvailable]):
|
||||
"""Build a stub _RetryState whose adapter flushes the given responses."""
|
||||
def _adapter_with_unresolved(responses: list[StreamToolOutputAvailable]):
|
||||
"""Stub _RetryState whose adapter flushes the given responses."""
|
||||
adapter = MagicMock()
|
||||
adapter.has_unresolved_tool_calls = bool(unresolved_responses)
|
||||
adapter.has_unresolved_tool_calls = bool(responses)
|
||||
|
||||
def _flush(responses: list) -> None:
|
||||
responses.extend(unresolved_responses)
|
||||
def _flush(out: list) -> None:
|
||||
out.extend(responses)
|
||||
adapter.has_unresolved_tool_calls = False
|
||||
|
||||
adapter._flush_unresolved_tool_calls.side_effect = _flush
|
||||
@@ -58,145 +55,29 @@ def _adapter_with_unresolved(unresolved_responses: list[StreamToolOutputAvailabl
|
||||
return state
|
||||
|
||||
|
||||
class TestRestorePartialWithErrorMarker:
|
||||
def test_appends_partial_then_marker_when_partial_present(self):
|
||||
session = _make_session([ChatMessage(role="user", content="hi")])
|
||||
partial = [
|
||||
ChatMessage(role="assistant", content="I was working on "),
|
||||
ChatMessage(role="tool", content="result-1", tool_call_id="t1"),
|
||||
]
|
||||
_restore_partial_with_error_marker(
|
||||
session,
|
||||
state=None,
|
||||
partial=partial,
|
||||
display_msg="Boom",
|
||||
retryable=False,
|
||||
)
|
||||
# Pre-existing user msg + 2 partial msgs + error marker
|
||||
assert len(session.messages) == 4
|
||||
assert session.messages[1].content == "I was working on "
|
||||
assert session.messages[2].role == "tool"
|
||||
assert session.messages[3].content.startswith(COPILOT_ERROR_PREFIX)
|
||||
# Partial list is consumed (cleared) so a stray follow-up call won't
|
||||
# double-attach the same content.
|
||||
assert partial == []
|
||||
|
||||
def test_only_marker_when_partial_empty(self):
|
||||
session = _make_session([ChatMessage(role="user", content="hi")])
|
||||
_restore_partial_with_error_marker(
|
||||
session,
|
||||
state=None,
|
||||
partial=[],
|
||||
display_msg="Boom",
|
||||
retryable=True,
|
||||
)
|
||||
assert len(session.messages) == 2
|
||||
assert session.messages[-1].content.startswith(COPILOT_RETRYABLE_ERROR_PREFIX)
|
||||
|
||||
def test_noop_when_session_is_none(self):
|
||||
# Signature accepts None — must not raise.
|
||||
_restore_partial_with_error_marker(
|
||||
None,
|
||||
state=None,
|
||||
partial=[ChatMessage(role="assistant", content="x")],
|
||||
display_msg="Boom",
|
||||
retryable=False,
|
||||
)
|
||||
|
||||
def test_flushes_unresolved_tools_between_partial_and_marker(self):
|
||||
session = _make_session([ChatMessage(role="user", content="hi")])
|
||||
partial = [
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="calling tool",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "t1",
|
||||
"type": "function",
|
||||
"function": {"name": "lookup", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
state = _adapter_with_unresolved([_make_tool_output("t1", "interrupted")])
|
||||
_restore_partial_with_error_marker(
|
||||
session,
|
||||
state=state,
|
||||
partial=partial,
|
||||
display_msg="Boom",
|
||||
retryable=False,
|
||||
)
|
||||
roles = [m.role for m in session.messages]
|
||||
# user, assistant(partial), tool(synthetic), assistant(error marker)
|
||||
assert roles == ["user", "assistant", "tool", "assistant"]
|
||||
synthetic_tool = session.messages[2]
|
||||
assert synthetic_tool.tool_call_id == "t1"
|
||||
assert synthetic_tool.content == "interrupted"
|
||||
def _builder_stub() -> MagicMock:
|
||||
builder = MagicMock()
|
||||
builder.restore = MagicMock()
|
||||
return builder
|
||||
|
||||
|
||||
class TestFlushOrphanToolUses:
|
||||
def test_appends_synthetic_tool_results_for_unresolved(self):
|
||||
session = _make_session()
|
||||
state = _adapter_with_unresolved(
|
||||
[
|
||||
_make_tool_output("t1", "r1"),
|
||||
_make_tool_output("t2", {"ok": False}),
|
||||
]
|
||||
)
|
||||
_flush_orphan_tool_uses_to_session(session, state)
|
||||
assert [m.tool_call_id for m in session.messages] == ["t1", "t2"]
|
||||
# Dict outputs are JSON-encoded so they survive the str-only ChatMessage
|
||||
# content field without losing structure for the next-turn LLM read.
|
||||
assert session.messages[1].content == '{"ok": false}'
|
||||
|
||||
def test_noop_when_state_is_none(self):
|
||||
session = _make_session()
|
||||
_flush_orphan_tool_uses_to_session(session, None)
|
||||
assert session.messages == []
|
||||
|
||||
def test_noop_when_no_unresolved(self):
|
||||
session = _make_session()
|
||||
adapter = MagicMock()
|
||||
adapter.has_unresolved_tool_calls = False
|
||||
state = MagicMock()
|
||||
state.adapter = adapter
|
||||
_flush_orphan_tool_uses_to_session(session, state)
|
||||
adapter._flush_unresolved_tool_calls.assert_not_called()
|
||||
|
||||
|
||||
class TestRollbackCapturingPartial:
|
||||
"""Direct tests of `_rollback_attempt_capturing_partial`.
|
||||
|
||||
The retry loop relies on this helper not to leak error markers that
|
||||
`_run_stream_attempt` already appended to `session.messages` — otherwise
|
||||
the post-loop restore replays a stale marker before adding its own,
|
||||
leaving duplicate error bubbles.
|
||||
"""
|
||||
|
||||
def _builder_with_snap(self):
|
||||
builder = MagicMock()
|
||||
builder.restore = MagicMock()
|
||||
return builder
|
||||
|
||||
def test_returns_partial_when_no_marker_present(self):
|
||||
class TestInterruptedAttemptCapture:
|
||||
def test_keeps_partial_when_no_marker_present(self):
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(role="assistant", content="part-1"),
|
||||
]
|
||||
)
|
||||
builder = self._builder_with_snap()
|
||||
captured = _rollback_attempt_capturing_partial(
|
||||
session, builder, transcript_snap=object(), pre_attempt_msg_count=1
|
||||
)
|
||||
assert [m.content for m in captured] == ["part-1"]
|
||||
assert session.messages == [ChatMessage(role="user", content="hi")]
|
||||
attempt = _InterruptedAttempt()
|
||||
attempt.capture(session, _builder_stub(), object(), pre_attempt_msg_count=1)
|
||||
assert [m.content for m in attempt.partial] == ["part-1"]
|
||||
assert [m.content for m in session.messages] == ["hi"]
|
||||
|
||||
def test_strips_trailing_error_marker(self):
|
||||
# _run_stream_attempt appended a marker via _append_error_marker
|
||||
# (e.g. idle timeout, circuit breaker) before raising
|
||||
# _HandledStreamError. The rollback must NOT carry it forward, or
|
||||
# the post-loop restore will replay the stale marker + add its own.
|
||||
# _run_stream_attempt may append a marker (idle timeout, circuit
|
||||
# breaker) before raising _HandledStreamError. Carrying it forward
|
||||
# would let finalize() replay it and then add its own.
|
||||
marker = (
|
||||
f"{COPILOT_RETRYABLE_ERROR_PREFIX} The session has been idle "
|
||||
"for too long. Please try again."
|
||||
@@ -208,66 +89,136 @@ class TestRollbackCapturingPartial:
|
||||
ChatMessage(role="assistant", content=marker),
|
||||
]
|
||||
)
|
||||
captured = _rollback_attempt_capturing_partial(
|
||||
session,
|
||||
self._builder_with_snap(),
|
||||
transcript_snap=object(),
|
||||
pre_attempt_msg_count=1,
|
||||
)
|
||||
assert [m.content for m in captured] == ["part-1"]
|
||||
attempt = _InterruptedAttempt()
|
||||
attempt.capture(session, _builder_stub(), object(), pre_attempt_msg_count=1)
|
||||
assert [m.content for m in attempt.partial] == ["part-1"]
|
||||
|
||||
def test_strips_consecutive_error_markers(self):
|
||||
# Defensive: if more than one marker landed back-to-back (legacy
|
||||
# path or future regression), strip them all.
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(role="assistant", content="part-1"),
|
||||
ChatMessage(role="assistant", content=f"{COPILOT_ERROR_PREFIX} a"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_RETRYABLE_ERROR_PREFIX} b",
|
||||
role="assistant", content=f"{COPILOT_RETRYABLE_ERROR_PREFIX} b"
|
||||
),
|
||||
]
|
||||
)
|
||||
captured = _rollback_attempt_capturing_partial(
|
||||
session,
|
||||
self._builder_with_snap(),
|
||||
transcript_snap=object(),
|
||||
pre_attempt_msg_count=1,
|
||||
)
|
||||
assert [m.content for m in captured] == ["part-1"]
|
||||
attempt = _InterruptedAttempt()
|
||||
attempt.capture(session, _builder_stub(), object(), pre_attempt_msg_count=1)
|
||||
assert [m.content for m in attempt.partial] == ["part-1"]
|
||||
|
||||
def test_does_not_strip_non_marker_assistant(self):
|
||||
# Regular assistant text starting with similar-but-not-prefix
|
||||
# content must be preserved — only the canonical error markers
|
||||
# should be filtered.
|
||||
def test_preserves_non_marker_assistant(self):
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="hi"),
|
||||
ChatMessage(role="assistant", content="Important note"),
|
||||
]
|
||||
)
|
||||
captured = _rollback_attempt_capturing_partial(
|
||||
session,
|
||||
self._builder_with_snap(),
|
||||
transcript_snap=object(),
|
||||
pre_attempt_msg_count=1,
|
||||
attempt = _InterruptedAttempt()
|
||||
attempt.capture(session, _builder_stub(), object(), pre_attempt_msg_count=1)
|
||||
assert [m.content for m in attempt.partial] == ["Important note"]
|
||||
|
||||
|
||||
class TestInterruptedAttemptFinalize:
|
||||
def test_appends_partial_then_marker(self):
|
||||
session = _make_session([ChatMessage(role="user", content="hi")])
|
||||
attempt = _InterruptedAttempt(
|
||||
partial=[
|
||||
ChatMessage(role="assistant", content="working"),
|
||||
ChatMessage(role="tool", content="result", tool_call_id="t1"),
|
||||
]
|
||||
)
|
||||
assert [m.content for m in captured] == ["Important note"]
|
||||
attempt.finalize(session, state=None, display_msg="Boom", retryable=False)
|
||||
roles = [m.role for m in session.messages]
|
||||
assert roles == ["user", "assistant", "tool", "assistant"]
|
||||
assert session.messages[-1].content.startswith(COPILOT_ERROR_PREFIX)
|
||||
# partial consumed so a follow-up finalize() is a no-op for partial.
|
||||
assert attempt.partial == []
|
||||
|
||||
def test_only_marker_when_partial_empty(self):
|
||||
session = _make_session([ChatMessage(role="user", content="hi")])
|
||||
attempt = _InterruptedAttempt()
|
||||
attempt.finalize(session, state=None, display_msg="Boom", retryable=True)
|
||||
assert len(session.messages) == 2
|
||||
assert session.messages[-1].content.startswith(COPILOT_RETRYABLE_ERROR_PREFIX)
|
||||
|
||||
def test_noop_when_session_is_none(self):
|
||||
attempt = _InterruptedAttempt(
|
||||
partial=[ChatMessage(role="assistant", content="x")]
|
||||
)
|
||||
attempt.finalize(None, state=None, display_msg="Boom", retryable=False)
|
||||
# no raise = pass
|
||||
|
||||
def test_flushes_unresolved_tools_between_partial_and_marker(self):
|
||||
session = _make_session([ChatMessage(role="user", content="hi")])
|
||||
attempt = _InterruptedAttempt(
|
||||
partial=[
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="calling",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "t1",
|
||||
"type": "function",
|
||||
"function": {"name": "lookup", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
state = _adapter_with_unresolved([_tool_output("t1", "interrupted")])
|
||||
attempt.finalize(session, state=state, display_msg="Boom", retryable=False)
|
||||
roles = [m.role for m in session.messages]
|
||||
assert roles == ["user", "assistant", "tool", "assistant"]
|
||||
assert session.messages[2].tool_call_id == "t1"
|
||||
assert session.messages[2].content == "interrupted"
|
||||
|
||||
def test_clear_drops_both_partial_and_handled_error(self):
|
||||
attempt = _InterruptedAttempt(
|
||||
partial=[ChatMessage(role="assistant", content="x")],
|
||||
handled_error=_HandledErrorInfo(
|
||||
error_msg="m", code="c", retryable=True, already_yielded=False
|
||||
),
|
||||
)
|
||||
attempt.clear()
|
||||
assert attempt.partial == []
|
||||
assert attempt.handled_error is None
|
||||
|
||||
|
||||
class TestFlushOrphanToolUses:
|
||||
def test_appends_synthetic_tool_results_for_unresolved(self):
|
||||
session = _make_session()
|
||||
state = _adapter_with_unresolved(
|
||||
[_tool_output("t1", "r1"), _tool_output("t2", {"ok": False})]
|
||||
)
|
||||
_flush_orphan_tool_uses_to_session(session, state)
|
||||
assert [m.tool_call_id for m in session.messages] == ["t1", "t2"]
|
||||
# Dict outputs are JSON-encoded so structure survives the str-only
|
||||
# ChatMessage content field for the next-turn LLM read.
|
||||
assert session.messages[1].content == '{"ok": false}'
|
||||
|
||||
def test_noop_when_state_is_none(self):
|
||||
session = _make_session()
|
||||
_flush_orphan_tool_uses_to_session(session, None)
|
||||
assert session.messages == []
|
||||
|
||||
def test_noop_when_no_unresolved(self):
|
||||
adapter = MagicMock()
|
||||
adapter.has_unresolved_tool_calls = False
|
||||
state = MagicMock()
|
||||
state.adapter = adapter
|
||||
_flush_orphan_tool_uses_to_session(_make_session(), state)
|
||||
adapter._flush_unresolved_tool_calls.assert_not_called()
|
||||
|
||||
|
||||
class TestRetryRollbackContract:
|
||||
"""Property-style: a rolled-back attempt must be recoverable on final exit.
|
||||
"""End-to-end contract: capture on a rolled-back attempt + finalize yields
|
||||
the exact content the user saw streaming live, plus the error marker."""
|
||||
|
||||
Simulates the retry loop's rollback by mirroring the exact slicing and
|
||||
captured-list shape used in stream_chat_completion_sdk so that any drift
|
||||
in that contract is caught here without needing the full SDK fixture.
|
||||
"""
|
||||
|
||||
def test_capture_slice_matches_rollback(self):
|
||||
def test_capture_then_finalize_matches_streamed_sequence(self):
|
||||
session = _make_session([ChatMessage(role="user", content="hi")])
|
||||
pre_attempt_msg_count = len(session.messages)
|
||||
pre = len(session.messages)
|
||||
# Simulate incremental SDK appends during the attempt.
|
||||
session.messages.extend(
|
||||
[
|
||||
@@ -275,18 +226,11 @@ class TestRetryRollbackContract:
|
||||
ChatMessage(role="assistant", content="part-2"),
|
||||
]
|
||||
)
|
||||
captured = list(session.messages[pre_attempt_msg_count:])
|
||||
session.messages = session.messages[:pre_attempt_msg_count]
|
||||
# Final-failure restore.
|
||||
_restore_partial_with_error_marker(
|
||||
session,
|
||||
state=None,
|
||||
partial=captured,
|
||||
display_msg="Boom",
|
||||
retryable=False,
|
||||
)
|
||||
contents = [m.content for m in session.messages]
|
||||
assert contents == [
|
||||
attempt = _InterruptedAttempt()
|
||||
attempt.capture(session, _builder_stub(), object(), pre)
|
||||
# Final-failure path — no retry, no success clear().
|
||||
attempt.finalize(session, state=None, display_msg="Boom", retryable=False)
|
||||
assert [m.content for m in session.messages] == [
|
||||
"hi",
|
||||
"part-1",
|
||||
"part-2",
|
||||
|
||||
@@ -541,14 +541,7 @@ def _append_error_marker(
|
||||
*,
|
||||
retryable: bool = False,
|
||||
) -> None:
|
||||
"""Append a copilot error marker to *session* so it persists across refresh.
|
||||
|
||||
Args:
|
||||
session: The chat session to append to (no-op if `None`).
|
||||
display_msg: User-visible error text.
|
||||
retryable: If `True`, use the retryable prefix so the frontend
|
||||
shows a "Try Again" button.
|
||||
"""
|
||||
"""Append a copilot error marker to *session* so it persists across refresh."""
|
||||
if session is None:
|
||||
return
|
||||
prefix = COPILOT_RETRYABLE_ERROR_PREFIX if retryable else COPILOT_ERROR_PREFIX
|
||||
@@ -557,17 +550,97 @@ def _append_error_marker(
|
||||
)
|
||||
|
||||
|
||||
def _is_error_marker(msg: ChatMessage) -> bool:
|
||||
"""True if *msg* is an error marker emitted by ``_append_error_marker``."""
|
||||
if msg.role != "assistant" or not msg.content:
|
||||
return False
|
||||
return msg.content.startswith(COPILOT_ERROR_PREFIX) or msg.content.startswith(
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _InterruptedAttempt:
|
||||
"""Captured state of a failed SDK attempt, carried across the retry loop.
|
||||
|
||||
The SDK always rolls back ``session.messages`` before deciding whether
|
||||
to retry (so attempt #2 starts clean). That rollback would otherwise
|
||||
discard everything the assistant produced — the user sees tokens stream
|
||||
live, then a refresh shows nothing. This dataclass holds the rolled-back
|
||||
messages plus the ``_HandledStreamError`` info needed to emit a final
|
||||
``StreamError`` once the loop decides not to retry.
|
||||
|
||||
The retry loop calls ``capture()`` on every failed attempt, ``clear()``
|
||||
on a successful retry (so prior rolled-back content is not replayed),
|
||||
and ``finalize()`` exactly once after the loop on final failure.
|
||||
"""
|
||||
|
||||
partial: list[ChatMessage] = dataclass_field(default_factory=list)
|
||||
# Populated by the ``except _HandledStreamError`` branch so the post-loop
|
||||
# block can restore the partial and (when the inner handler didn't) emit
|
||||
# the client-facing StreamError. Transient errors deliberately suppress
|
||||
# the early StreamError flash and rely on this post-loop emit.
|
||||
handled_error: "_HandledErrorInfo | None" = None
|
||||
|
||||
def capture(
|
||||
self,
|
||||
session: ChatSession,
|
||||
transcript_builder: "TranscriptBuilder",
|
||||
transcript_snap: object,
|
||||
pre_attempt_msg_count: int,
|
||||
) -> None:
|
||||
"""Roll back ``session.messages`` + transcript, keeping the partial.
|
||||
|
||||
Trailing error markers appended inside ``_run_stream_attempt`` (idle
|
||||
timeout, circuit breaker) are stripped: re-attaching them would make
|
||||
the post-loop restore replay a stale marker before adding its own,
|
||||
leaving duplicate error bubbles.
|
||||
"""
|
||||
tail = list(session.messages[pre_attempt_msg_count:])
|
||||
while tail and _is_error_marker(tail[-1]):
|
||||
tail.pop()
|
||||
self.partial = tail
|
||||
session.messages = session.messages[:pre_attempt_msg_count]
|
||||
transcript_builder.restore(transcript_snap) # type: ignore[arg-type]
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Drop captured state — used on successful retry."""
|
||||
self.partial = []
|
||||
self.handled_error = None
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
session: ChatSession | None,
|
||||
state: "_RetryState | None",
|
||||
display_msg: str,
|
||||
*,
|
||||
retryable: bool,
|
||||
) -> None:
|
||||
"""Re-attach partial + synthetic tool_result rows + error marker.
|
||||
|
||||
Called exactly once after the retry loop on final-failure exit.
|
||||
Idempotent on empty state, so it's safe to call on paths where no
|
||||
rollback happened.
|
||||
"""
|
||||
if session is None:
|
||||
return
|
||||
if self.partial:
|
||||
session.messages.extend(self.partial)
|
||||
self.partial = []
|
||||
_flush_orphan_tool_uses_to_session(session, state)
|
||||
_append_error_marker(session, display_msg, retryable=retryable)
|
||||
|
||||
|
||||
def _flush_orphan_tool_uses_to_session(
|
||||
session: "ChatSession | None",
|
||||
state: "_RetryState | None",
|
||||
) -> None:
|
||||
"""Synthesize tool_result rows for tool_use blocks that never resolved.
|
||||
"""Synthesize ``tool_result`` rows for ``tool_use`` blocks that never resolved.
|
||||
|
||||
Without this, partial assistant work re-attached after a final failure
|
||||
would carry orphan tool_use blocks. The next turn's LLM call would error
|
||||
with ``tool_use_id without tool_result`` — and any baseline replay would
|
||||
surface the same broken history. Flushing produces interrupted-marker
|
||||
tool_results that satisfy the API contract.
|
||||
Re-attached partial work may carry orphan ``tool_use`` blocks; without
|
||||
matching ``tool_result`` rows the next turn's LLM call would error with
|
||||
``tool_use_id without tool_result``. The adapter's safety-flush produces
|
||||
interrupted-marker results that satisfy the API contract.
|
||||
"""
|
||||
if session is None or state is None:
|
||||
return
|
||||
@@ -587,65 +660,30 @@ def _flush_orphan_tool_uses_to_session(
|
||||
)
|
||||
|
||||
|
||||
def _restore_partial_with_error_marker(
|
||||
session: "ChatSession | None",
|
||||
state: "_RetryState | None",
|
||||
partial: list[ChatMessage],
|
||||
display_msg: str,
|
||||
*,
|
||||
retryable: bool,
|
||||
) -> None:
|
||||
"""Re-attach a rolled-back attempt's partial work, then add the error marker.
|
||||
def _classify_final_failure(
|
||||
interrupted: _InterruptedAttempt,
|
||||
attempts_exhausted: bool,
|
||||
transient_exhausted: bool,
|
||||
stream_err: BaseException | None,
|
||||
) -> tuple[str | None, bool]:
|
||||
"""Pick the display message + retryable flag for the post-loop restore.
|
||||
|
||||
Called when retries are exhausted or no retry is attempted. Without this,
|
||||
the SDK retry loop's pre-decision rollback would discard everything the
|
||||
assistant produced in the failed attempt (text, tool calls, reasoning),
|
||||
leaving the user with a chat that looks like nothing happened — even
|
||||
though the events were already streamed live to their UI.
|
||||
Mirrors the error-code selection used for the client-facing ``StreamError``
|
||||
yield so the in-history marker and the SSE event stay consistent.
|
||||
"""
|
||||
if session is None:
|
||||
return
|
||||
if partial:
|
||||
session.messages.extend(partial)
|
||||
partial.clear()
|
||||
_flush_orphan_tool_uses_to_session(session, state)
|
||||
_append_error_marker(session, display_msg, retryable=retryable)
|
||||
|
||||
|
||||
def _rollback_attempt_capturing_partial(
|
||||
session: "ChatSession",
|
||||
transcript_builder: "TranscriptBuilder",
|
||||
transcript_snap: object,
|
||||
pre_attempt_msg_count: int,
|
||||
) -> list[ChatMessage]:
|
||||
"""Roll back an attempt's session.messages + transcript, capturing the partial.
|
||||
|
||||
The returned list holds the assistant work that was incrementally appended
|
||||
during the failed attempt. The caller passes it to
|
||||
``_restore_partial_with_error_marker`` on final-failure exit and discards
|
||||
it on a successful retry.
|
||||
|
||||
Trailing error markers appended inside ``_run_stream_attempt`` (idle
|
||||
timeout, circuit breaker) are stripped: re-attaching them would let the
|
||||
post-loop restore replay a stale marker before adding its own, leaving
|
||||
duplicate error bubbles and pushing any synthetic ``tool_result`` after
|
||||
an assistant(error) turn that has no matching ``tool_use``.
|
||||
"""
|
||||
captured = list(session.messages[pre_attempt_msg_count:])
|
||||
while captured and _is_error_marker(captured[-1]):
|
||||
captured.pop()
|
||||
session.messages = session.messages[:pre_attempt_msg_count]
|
||||
transcript_builder.restore(transcript_snap) # type: ignore[arg-type]
|
||||
return captured
|
||||
|
||||
|
||||
def _is_error_marker(msg: ChatMessage) -> bool:
|
||||
"""True if *msg* is an error marker emitted by ``_append_error_marker``."""
|
||||
if msg.role != "assistant" or not msg.content:
|
||||
return False
|
||||
return msg.content.startswith(COPILOT_ERROR_PREFIX) or msg.content.startswith(
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX
|
||||
)
|
||||
if interrupted.handled_error is not None:
|
||||
return interrupted.handled_error.error_msg, interrupted.handled_error.retryable
|
||||
if attempts_exhausted:
|
||||
return (
|
||||
"Your conversation is too long. "
|
||||
"Please start a new chat or clear some history."
|
||||
), False
|
||||
if transient_exhausted:
|
||||
return FRIENDLY_TRANSIENT_MSG, True
|
||||
if stream_err is not None:
|
||||
safe_err = str(stream_err).replace("\n", " ").replace("\r", "")[:500]
|
||||
return _friendly_error_text(safe_err), False
|
||||
return None, False
|
||||
|
||||
|
||||
def _setup_langfuse_otel() -> None:
|
||||
@@ -3112,12 +3150,13 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
|
||||
request_arrival_at: float = 0.0,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
# Pyright's complexity heuristic bails on this function (~1500 LoC, retry
|
||||
# Pyright's complexity heuristic bails on this ~1500 LoC function (retry
|
||||
# loop with context-overflow fallback + transient backoff + partial-work
|
||||
# preservation). Splitting further would hurt readability — the branches
|
||||
# share state (session, adapter, transcript builder, token accumulators)
|
||||
# that's hard to pass cleanly through helpers. Suppress the bailout; real
|
||||
# type errors elsewhere in the file remain surfaced.
|
||||
# preservation). Splitting the retry loop further hurts readability —
|
||||
# branches share mutable state (session, adapter, transcript builder,
|
||||
# usage accumulators) that doesn't pass cleanly through helpers. The
|
||||
# suppression only silences the complexity bailout; real type errors in
|
||||
# the function body still surface.
|
||||
"""Stream chat completion using Claude Agent SDK.
|
||||
|
||||
Args:
|
||||
@@ -3281,14 +3320,11 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
|
||||
turn_cost_usd: float | None = None
|
||||
graphiti_enabled = False
|
||||
pre_attempt_msg_count = 0
|
||||
# Holds messages that the retry loop rolls back from session.messages on a
|
||||
# failed attempt. On final-failure exit we re-attach these so the user sees
|
||||
# what the assistant produced before the error rather than an empty chat.
|
||||
last_attempt_partial: list[ChatMessage] = []
|
||||
# Populated by the _HandledStreamError branch. Consumed once after the
|
||||
# retry loop to re-attach the partial and, if the inner handler hadn't
|
||||
# already emitted one, yield a single StreamError to the client.
|
||||
handled_error_info: _HandledErrorInfo | None = None
|
||||
# State of the latest failed attempt: rolled-back messages + any
|
||||
# _HandledStreamError info to emit on final-failure exit. The retry loop
|
||||
# mutates this via capture()/clear(); the post-loop block calls
|
||||
# finalize() once.
|
||||
interrupted = _InterruptedAttempt()
|
||||
# Defaults ensure the finally block can always reference these safely even when
|
||||
# an early return (e.g. sdk_cwd error) skips their normal assignment below.
|
||||
sdk_model: str | None = None
|
||||
@@ -3949,10 +3985,9 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
|
||||
"using fallback model for this request"
|
||||
)
|
||||
yield event
|
||||
# Drop any stale partial captured from prior failed attempts
|
||||
# so the outer cleanup paths don't re-attach pre-retry content
|
||||
# the successful attempt already replaced.
|
||||
last_attempt_partial.clear()
|
||||
# Discard any state captured from prior failed attempts so
|
||||
# outer cleanup paths don't replay pre-retry content.
|
||||
interrupted.clear()
|
||||
break # Stream completed — exit retry loop
|
||||
except asyncio.CancelledError:
|
||||
logger.warning(
|
||||
@@ -3968,7 +4003,7 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
|
||||
# session messages and set the error flag — do NOT set
|
||||
# stream_err so the post-loop code won't emit a
|
||||
# duplicate StreamError.
|
||||
last_attempt_partial = _rollback_attempt_capturing_partial(
|
||||
interrupted.capture(
|
||||
session,
|
||||
state.transcript_builder,
|
||||
transcript_snap,
|
||||
@@ -4007,7 +4042,7 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
|
||||
# attempt that no longer match session.messages. Skip upload
|
||||
# so a future --resume doesn't replay rolled-back content.
|
||||
skip_transcript_upload = True
|
||||
handled_error_info = _HandledErrorInfo(
|
||||
interrupted.handled_error = _HandledErrorInfo(
|
||||
error_msg=exc.error_msg or FRIENDLY_TRANSIENT_MSG,
|
||||
code=exc.code or "transient_api_error",
|
||||
retryable=exc.retryable,
|
||||
@@ -4031,7 +4066,7 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
|
||||
stream_err,
|
||||
exc_info=True,
|
||||
)
|
||||
last_attempt_partial = _rollback_attempt_capturing_partial(
|
||||
interrupted.capture(
|
||||
session,
|
||||
state.transcript_builder,
|
||||
transcript_snap,
|
||||
@@ -4097,38 +4132,15 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
|
||||
_MAX_STREAM_ATTEMPTS,
|
||||
stream_err,
|
||||
)
|
||||
# Restore the rolled-back partial work + add error marker exactly once
|
||||
# per failure mode. Earlier revisions did this inline in each retry-loop
|
||||
# branch; consolidating here keeps the retry loop itself simple enough
|
||||
# for pyright to type-check.
|
||||
# Restore the rolled-back partial + add error marker exactly once per
|
||||
# failure mode. See _InterruptedAttempt for the carry-state contract.
|
||||
if ended_with_stream_error:
|
||||
if handled_error_info is not None:
|
||||
final_msg = handled_error_info.error_msg
|
||||
final_retryable = handled_error_info.retryable
|
||||
elif attempts_exhausted:
|
||||
final_msg = (
|
||||
"Your conversation is too long. "
|
||||
"Please start a new chat or clear some history."
|
||||
)
|
||||
final_retryable = False
|
||||
elif transient_exhausted:
|
||||
final_msg = FRIENDLY_TRANSIENT_MSG
|
||||
final_retryable = True
|
||||
elif stream_err is not None:
|
||||
final_msg = _friendly_error_text(
|
||||
str(stream_err).replace("\n", " ").replace("\r", "")[:500]
|
||||
)
|
||||
final_retryable = False
|
||||
else:
|
||||
final_msg = None
|
||||
final_retryable = False
|
||||
final_msg, final_retryable = _classify_final_failure(
|
||||
interrupted, attempts_exhausted, transient_exhausted, stream_err
|
||||
)
|
||||
if final_msg is not None:
|
||||
_restore_partial_with_error_marker(
|
||||
session,
|
||||
state,
|
||||
last_attempt_partial,
|
||||
final_msg,
|
||||
retryable=final_retryable,
|
||||
interrupted.finalize(
|
||||
session, state, final_msg, retryable=final_retryable
|
||||
)
|
||||
|
||||
if ended_with_stream_error and state is not None:
|
||||
@@ -4173,10 +4185,13 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
|
||||
# here when the inner handler chose not to (transient errors suppress
|
||||
# the early flash so the client only sees the final error after all
|
||||
# retries are exhausted).
|
||||
if handled_error_info is not None and not handled_error_info.already_yielded:
|
||||
if (
|
||||
interrupted.handled_error is not None
|
||||
and not interrupted.handled_error.already_yielded
|
||||
):
|
||||
yield StreamError(
|
||||
errorText=handled_error_info.error_msg,
|
||||
code=handled_error_info.code,
|
||||
errorText=interrupted.handled_error.error_msg,
|
||||
code=interrupted.handled_error.code,
|
||||
)
|
||||
|
||||
# Copy token usage from retry state to outer-scope accumulators
|
||||
@@ -4259,24 +4274,13 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
|
||||
else:
|
||||
display_msg, code = error_msg, "sdk_error"
|
||||
|
||||
# Append error marker to session (non-invasive text parsing approach).
|
||||
# The finally block will persist the session with this error marker.
|
||||
# Skip if a marker was already appended inside the stream loop
|
||||
# (ended_with_stream_error) to avoid duplicate stale markers.
|
||||
# Append error marker + restore any rolled-back partial when the retry
|
||||
# loop didn't already finalize. ``interrupted`` is empty on success and
|
||||
# on paths where the retry loop's own post-loop finalize() already ran,
|
||||
# so this is a no-op for those and only kicks in for unhandled errors
|
||||
# that bypass the retry-loop handlers entirely.
|
||||
if not ended_with_stream_error:
|
||||
# Restore any rolled-back partial work the retry loop captured —
|
||||
# last_attempt_partial is empty on success / final-failure-handled
|
||||
# paths (cleared on success break, consumed inside the retry loop
|
||||
# by _restore_partial_with_error_marker), so this is a no-op for
|
||||
# those cases and only kicks in when an unhandled exception bypassed
|
||||
# the retry loop's own restoration.
|
||||
_restore_partial_with_error_marker(
|
||||
session,
|
||||
state,
|
||||
last_attempt_partial,
|
||||
display_msg,
|
||||
retryable=is_transient,
|
||||
)
|
||||
interrupted.finalize(session, state, display_msg, retryable=is_transient)
|
||||
logger.debug(
|
||||
"%s Appended error marker, will be persisted in finally",
|
||||
log_prefix,
|
||||
|
||||
Reference in New Issue
Block a user