diff --git a/autogpt_platform/backend/backend/copilot/sdk/interrupted_partial_test.py b/autogpt_platform/backend/backend/copilot/sdk/interrupted_partial_test.py
new file mode 100644
index 0000000000..19d9f67a3b
--- /dev/null
+++ b/autogpt_platform/backend/backend/copilot/sdk/interrupted_partial_test.py
@@ -0,0 +1,216 @@
+"""Tests for partial-work preservation when an SDK turn is interrupted.
+
+Covers the regression path SECRT-2275 surfaced: when the SDK retry loop
+rolls back ``session.messages`` for a failed attempt (correct behavior so a
+successful retry doesn't duplicate content) it MUST re-attach the rolled-back
+work on final-failure exit. Otherwise the user's UI streamed tokens live but
+a refresh shows an empty turn — described by users as "the turn is gone".
+
+Tests target the helper functions directly (unit) plus the rollback-then-
+restore contract (state-driven). Full end-to-end coverage of the retry loop
+lives in retry_scenarios_test.py.
+"""
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+from backend.copilot.constants import (
+    COPILOT_ERROR_PREFIX,
+    COPILOT_RETRYABLE_ERROR_PREFIX,
+)
+from backend.copilot.model import ChatMessage, ChatSession
+from backend.copilot.response_model import StreamToolOutputAvailable
+
+from .service import (
+    _flush_orphan_tool_uses_to_session,
+    _restore_partial_with_error_marker,
+    _rollback_attempt_capturing_partial,
+)
+
+
+def _make_session(messages: list[ChatMessage] | None = None) -> ChatSession:
+    session = ChatSession.new(user_id="user-1", dry_run=False)
+    session.messages = list(messages or [])
+    return session
+
+
+def _make_tool_output(tool_call_id: str, output) -> StreamToolOutputAvailable:
+    return StreamToolOutputAvailable(
+        toolCallId=tool_call_id,
+        toolName="t",
+        output=output,
+    )
+
+
+def _adapter_with_unresolved(unresolved_responses: list[StreamToolOutputAvailable]):
+    """Build a stub _RetryState whose adapter flushes the given responses."""
+    adapter = MagicMock()
+    adapter.has_unresolved_tool_calls = bool(unresolved_responses)
+
+    def _flush(responses: list) -> None:
+        responses.extend(unresolved_responses)
+        adapter.has_unresolved_tool_calls = False
+
+    adapter._flush_unresolved_tool_calls.side_effect = _flush
+    state = MagicMock()
+    state.adapter = adapter
+    return state
+
+
+class TestRestorePartialWithErrorMarker:
+    def test_appends_partial_then_marker_when_partial_present(self):
+        session = _make_session([ChatMessage(role="user", content="hi")])
+        partial = [
+            ChatMessage(role="assistant", content="I was working on "),
+            ChatMessage(role="tool", content="result-1", tool_call_id="t1"),
+        ]
+        _restore_partial_with_error_marker(
+            session,
+            state=None,
+            partial=partial,
+            display_msg="Boom",
+            retryable=False,
+        )
+        # Pre-existing user msg + 2 partial msgs + error marker
+        assert len(session.messages) == 4
+        assert session.messages[1].content == "I was working on "
+        assert session.messages[2].role == "tool"
+        assert session.messages[3].content.startswith(COPILOT_ERROR_PREFIX)
+        # Partial list is consumed (cleared) so a stray follow-up call won't
+        # double-attach the same content.
+        assert partial == []
+
+    def test_only_marker_when_partial_empty(self):
+        session = _make_session([ChatMessage(role="user", content="hi")])
+        _restore_partial_with_error_marker(
+            session,
+            state=None,
+            partial=[],
+            display_msg="Boom",
+            retryable=True,
+        )
+        assert len(session.messages) == 2
+        assert session.messages[-1].content.startswith(COPILOT_RETRYABLE_ERROR_PREFIX)
+
+    def test_noop_when_session_is_none(self):
+        # Signature accepts None — must not raise.
+        _restore_partial_with_error_marker(
+            None,
+            state=None,
+            partial=[ChatMessage(role="assistant", content="x")],
+            display_msg="Boom",
+            retryable=False,
+        )
+
+    def test_flushes_unresolved_tools_between_partial_and_marker(self):
+        session = _make_session([ChatMessage(role="user", content="hi")])
+        partial = [
+            ChatMessage(
+                role="assistant",
+                content="calling tool",
+                tool_calls=[
+                    {
+                        "id": "t1",
+                        "type": "function",
+                        "function": {"name": "lookup", "arguments": "{}"},
+                    }
+                ],
+            ),
+        ]
+        state = _adapter_with_unresolved([_make_tool_output("t1", "interrupted")])
+        _restore_partial_with_error_marker(
+            session,
+            state=state,
+            partial=partial,
+            display_msg="Boom",
+            retryable=False,
+        )
+        roles = [m.role for m in session.messages]
+        # user, assistant(partial), tool(synthetic), assistant(error marker)
+        assert roles == ["user", "assistant", "tool", "assistant"]
+        synthetic_tool = session.messages[2]
+        assert synthetic_tool.tool_call_id == "t1"
+        assert synthetic_tool.content == "interrupted"
+
+
+class TestFlushOrphanToolUses:
+    def test_appends_synthetic_tool_results_for_unresolved(self):
+        session = _make_session()
+        state = _adapter_with_unresolved(
+            [
+                _make_tool_output("t1", "r1"),
+                _make_tool_output("t2", {"ok": False}),
+            ]
+        )
+        _flush_orphan_tool_uses_to_session(session, state)
+        assert [m.tool_call_id for m in session.messages] == ["t1", "t2"]
+        # Dict outputs are JSON-encoded so they survive the str-only ChatMessage
+        # content field without losing structure for the next-turn LLM read.
+        assert session.messages[1].content == '{"ok": false}'
+
+    def test_non_json_output_degrades_to_text_instead_of_raising(self):
+        session = _make_session()
+        state = _adapter_with_unresolved(
+            [_make_tool_output("t1", {"err": ValueError("boom")})]
+        )
+        _flush_orphan_tool_uses_to_session(session, state)
+        # The helper runs on the failure path; raising here would drop the
+        # error marker, so un-encodable values are stringified instead.
+        assert session.messages[0].content == '{"err": "boom"}'
+
+    def test_noop_when_state_is_none(self):
+        session = _make_session()
+        _flush_orphan_tool_uses_to_session(session, None)
+        assert session.messages == []
+
+    def test_noop_when_no_unresolved(self):
+        session = _make_session()
+        adapter = MagicMock()
+        adapter.has_unresolved_tool_calls = False
+        state = MagicMock()
+        state.adapter = adapter
+        _flush_orphan_tool_uses_to_session(session, state)
+        adapter._flush_unresolved_tool_calls.assert_not_called()
+
+
+class TestRetryRollbackContract:
+    """Property-style: a rolled-back attempt must be recoverable on final exit.
+
+    Drives the same ``_rollback_attempt_capturing_partial`` helper that
+    stream_chat_completion_sdk calls, so any drift in the rollback/restore
+    contract is caught here without needing the full SDK fixture.
+    """
+
+    def test_capture_slice_matches_rollback(self):
+        session = _make_session([ChatMessage(role="user", content="hi")])
+        pre_attempt_msg_count = len(session.messages)
+        # Simulate incremental SDK appends during the attempt.
+        session.messages.extend(
+            [
+                ChatMessage(role="assistant", content="part-1"),
+                ChatMessage(role="assistant", content="part-2"),
+            ]
+        )
+        transcript_builder = MagicMock()
+        captured = _rollback_attempt_capturing_partial(
+            session, transcript_builder, "snap", pre_attempt_msg_count
+        )
+        # Rollback truncates the session and restores the transcript snapshot.
+        assert [m.content for m in session.messages] == ["hi"]
+        transcript_builder.restore.assert_called_once_with("snap")
+        # Final-failure restore.
+        _restore_partial_with_error_marker(
+            session,
+            state=None,
+            partial=captured,
+            display_msg="Boom",
+            retryable=False,
+        )
+        contents = [m.content for m in session.messages]
+        assert contents == [
+            "hi",
+            "part-1",
+            "part-2",
+            f"{COPILOT_ERROR_PREFIX} Boom",
+        ]
diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py
index f57e6ea791..69dc4b0382 100644
--- a/autogpt_platform/backend/backend/copilot/sdk/service.py
+++ b/autogpt_platform/backend/backend/copilot/sdk/service.py
@@ -557,6 +557,115 @@ def _append_error_marker(
     )
 
 
+def _flush_orphan_tool_uses_to_session(
+    session: "ChatSession | None",
+    state: "_RetryState | None",
+) -> None:
+    """Synthesize tool_result rows for tool_use blocks that never resolved.
+
+    Without this, partial assistant work re-attached after a final failure
+    would carry orphan tool_use blocks. The next turn's LLM call would error
+    with ``tool_use_id without tool_result`` — and any baseline replay would
+    surface the same broken history. Flushing produces interrupted-marker
+    tool_results that satisfy the API contract.
+
+    Args:
+        session: Chat session to append the synthetic ``tool`` messages to.
+            ``None`` makes the call a no-op.
+        state: Retry state whose adapter tracks unresolved tool calls.
+            ``None`` makes the call a no-op.
+
+    NOTE(review): the adapter flush presumably clears its unresolved set (the
+    unit-test stub models it that way), so the post-loop flush that lets the
+    frontend close stale tool UI may find nothing left to emit — confirm.
+    """
+    if session is None or state is None:
+        return
+    if not state.adapter.has_unresolved_tool_calls:
+        return
+    safety: list[StreamBaseResponse] = []
+    state.adapter._flush_unresolved_tool_calls(safety)  # noqa: SLF001
+    for resp in safety:
+        if not isinstance(resp, StreamToolOutputAvailable):
+            continue
+        if isinstance(resp.output, str):
+            content = resp.output
+        else:
+            # This runs on the failure path, before the error marker is
+            # appended: an un-encodable tool output must degrade to text
+            # instead of raising and aborting the restore.
+            try:
+                content = json.dumps(resp.output, ensure_ascii=False, default=str)
+            except (TypeError, ValueError):
+                content = str(resp.output)
+        session.messages.append(
+            ChatMessage(role="tool", content=content, tool_call_id=resp.toolCallId)
+        )
+
+
+def _restore_partial_with_error_marker(
+    session: "ChatSession | None",
+    state: "_RetryState | None",
+    partial: list[ChatMessage],
+    display_msg: str,
+    *,
+    retryable: bool,
+) -> None:
+    """Re-attach a rolled-back attempt's partial work, then add the error marker.
+
+    Called when retries are exhausted or no retry is attempted. Without this,
+    the SDK retry loop's pre-decision rollback would discard everything the
+    assistant produced in the failed attempt (text, tool calls, reasoning),
+    leaving the user with a chat that looks like nothing happened — even
+    though the events were already streamed live to their UI.
+
+    Args:
+        session: Chat session to restore into. ``None`` makes the call a
+            no-op (``partial`` is left untouched).
+        state: Retry state used to flush unresolved tool calls; may be ``None``.
+        partial: Messages captured by the rollback. Consumed: the list is
+            cleared in place so a repeated call cannot attach them twice.
+        display_msg: Message passed to ``_append_error_marker``.
+        retryable: Forwarded to ``_append_error_marker``.
+    """
+    if session is None:
+        return
+    if partial:
+        session.messages.extend(partial)
+        partial.clear()
+    _flush_orphan_tool_uses_to_session(session, state)
+    _append_error_marker(session, display_msg, retryable=retryable)
+
+
+def _rollback_attempt_capturing_partial(
+    session: "ChatSession",
+    transcript_builder: "TranscriptBuilder",
+    transcript_snap: object,
+    pre_attempt_msg_count: int,
+) -> list[ChatMessage]:
+    """Roll back an attempt's session.messages + transcript, capturing the partial.
+
+    The returned list holds the assistant work that was incrementally appended
+    during the failed attempt. The caller passes it to
+    ``_restore_partial_with_error_marker`` on final-failure exit and discards
+    it on a successful retry.
+
+    Args:
+        session: Chat session whose ``messages`` are truncated.
+        transcript_builder: Builder restored to ``transcript_snap``.
+        transcript_snap: Snapshot handed to ``transcript_builder.restore``.
+        pre_attempt_msg_count: Length of ``session.messages`` before the
+            attempt started; everything from this index on is captured.
+
+    Returns:
+        A new list with the messages removed from ``session.messages``.
+    """
+    captured = list(session.messages[pre_attempt_msg_count:])
+    session.messages = session.messages[:pre_attempt_msg_count]
+    transcript_builder.restore(transcript_snap)  # type: ignore[arg-type]
+    return captured
+
+
 def _setup_langfuse_otel() -> None:
     """Configure OTEL tracing for the Claude Agent SDK → Langfuse.
 
@@ -3169,6 +3278,10 @@ async def stream_chat_completion_sdk(
     turn_cost_usd: float | None = None
     graphiti_enabled = False
     pre_attempt_msg_count = 0
+    # Holds messages that the retry loop rolls back from session.messages on a
+    # failed attempt. On final-failure exit we re-attach these so the user sees
+    # what the assistant produced before the error rather than an empty chat.
+    last_attempt_partial: list[ChatMessage] = []
     # Defaults ensure the finally block can always reference these safely even when
     # an early return (e.g. sdk_cwd error) skips their normal assignment below.
     sdk_model: str | None = None
@@ -3829,6 +3942,10 @@ async def stream_chat_completion_sdk(
                             "using fallback model for this request"
                         )
                     yield event
+                # Drop any stale partial captured from prior failed attempts
+                # so the outer cleanup paths don't re-attach pre-retry content
+                # the successful attempt already replaced.
+                last_attempt_partial.clear()
                 break  # Stream completed — exit retry loop
             except asyncio.CancelledError:
                 logger.warning(
@@ -3844,8 +3961,12 @@ async def stream_chat_completion_sdk(
                 # session messages and set the error flag — do NOT set
                 # stream_err so the post-loop code won't emit a
                 # duplicate StreamError.
-                session.messages = session.messages[:pre_attempt_msg_count]
-                state.transcript_builder.restore(transcript_snap)
+                last_attempt_partial = _rollback_attempt_capturing_partial(
+                    session,
+                    state.transcript_builder,
+                    transcript_snap,
+                    pre_attempt_msg_count,
+                )
                 # Check if this is a transient error we can retry with backoff.
                 # exc.code is the only reliable signal — str(exc) is always the
                 # static "Stream error handled — StreamError already yielded" message.
@@ -3879,12 +4000,13 @@ async def stream_chat_completion_sdk(
                 # attempt that no longer match session.messages. Skip upload
                 # so a future --resume doesn't replay rolled-back content.
                 skip_transcript_upload = True
-                # Re-append the error marker so it survives the rollback
-                # and is persisted by the finally block (see #2947655365).
-                # Use the specific error message from the attempt (e.g.
-                # circuit breaker msg) rather than always the generic one.
-                _append_error_marker(
+                # Re-attach the rolled-back partial work + add error marker.
+                # Without partial restoration the user's UI streamed tokens
+                # live but a refresh shows nothing happened.
+                _restore_partial_with_error_marker(
                     session,
+                    state,
+                    last_attempt_partial,
                     exc.error_msg or FRIENDLY_TRANSIENT_MSG,
                     retryable=True,
                 )
@@ -3917,8 +4039,12 @@ async def stream_chat_completion_sdk(
                     stream_err,
                     exc_info=True,
                 )
-                session.messages = session.messages[:pre_attempt_msg_count]
-                state.transcript_builder.restore(transcript_snap)
+                last_attempt_partial = _rollback_attempt_capturing_partial(
+                    session,
+                    state.transcript_builder,
+                    transcript_snap,
+                    pre_attempt_msg_count,
+                )
                 if events_yielded > 0:
                     # Events were already sent to the frontend and cannot be
                     # unsent. Retrying would produce duplicate/inconsistent
@@ -3929,6 +4055,19 @@ async def stream_chat_completion_sdk(
                         events_yielded,
                     )
                     skip_transcript_upload = True
+                    # Restore the streamed partial + add error marker — without
+                    # this the frontend would briefly show streamed text live
+                    # then a refresh would show an empty turn.
+                    safe_err = (
+                        str(stream_err).replace("\n", " ").replace("\r", "")[:500]
+                    )
+                    _restore_partial_with_error_marker(
+                        session,
+                        state,
+                        last_attempt_partial,
+                        _friendly_error_text(safe_err),
+                        retryable=False,
+                    )
                     ended_with_stream_error = True
                     break
                 # Transient API errors (ECONNRESET, 429, 5xx) — retry
@@ -3956,8 +4095,12 @@ async def stream_chat_completion_sdk(
                     # at line ~2310.
                     transient_exhausted = True
                     skip_transcript_upload = True
-                    _append_error_marker(
-                        session, FRIENDLY_TRANSIENT_MSG, retryable=True
-                    )
+                    _restore_partial_with_error_marker(
+                        session,
+                        state,
+                        last_attempt_partial,
+                        FRIENDLY_TRANSIENT_MSG,
+                        retryable=True,
+                    )
                     ended_with_stream_error = True
                     break
@@ -3966,6 +4109,16 @@ async def stream_chat_completion_sdk(
                     # Non-context, non-transient errors (auth, fatal)
                     # should not trigger compaction — surface immediately.
                     skip_transcript_upload = True
+                    safe_err = (
+                        str(stream_err).replace("\n", " ").replace("\r", "")[:500]
+                    )
+                    _restore_partial_with_error_marker(
+                        session,
+                        state,
+                        last_attempt_partial,
+                        _friendly_error_text(safe_err),
+                        retryable=False,
+                    )
                     ended_with_stream_error = True
                     break
                 attempt += 1  # advance to next context-level attempt
@@ -3982,6 +4135,17 @@ async def stream_chat_completion_sdk(
                 _MAX_STREAM_ATTEMPTS,
                 stream_err,
             )
+            # Restore the last attempt's partial work (rolled back by the
+            # exhausted context-level retry) so the user sees what was
+            # produced before the conversation hit the context ceiling.
+            _restore_partial_with_error_marker(
+                session,
+                state,
+                last_attempt_partial,
+                "Your conversation is too long. "
+                "Please start a new chat or clear some history.",
+                retryable=False,
+            )
 
         if ended_with_stream_error and state is not None:
             # Flush any unresolved tool calls so the frontend can close
@@ -4105,7 +4269,19 @@ async def stream_chat_completion_sdk(
         # Skip if a marker was already appended inside the stream loop
         # (ended_with_stream_error) to avoid duplicate stale markers.
         if not ended_with_stream_error:
-            _append_error_marker(session, display_msg, retryable=is_transient)
+            # Restore any rolled-back partial work the retry loop captured —
+            # last_attempt_partial is empty on success / final-failure-handled
+            # paths (cleared on success break, consumed inside the retry loop
+            # by _restore_partial_with_error_marker), so this is a no-op for
+            # those cases and only kicks in when an unhandled exception bypassed
+            # the retry loop's own restoration.
+            _restore_partial_with_error_marker(
+                session,
+                state,
+                last_attempt_partial,
+                display_msg,
+                retryable=is_transient,
+            )
             logger.debug(
                 "%s Appended error marker, will be persisted in finally",
                 log_prefix,