mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
refactor(backend/copilot): expose public adapter flush + consolidate final-failure emit
CodeRabbit flagged that _flush_orphan_tool_uses_to_session (called from _InterruptedAttempt.finalize) used state.adapter._flush_unresolved_tool_calls with a # noqa: SLF001 suppressor. The private call mutates resolved_tool_calls and flips has_unresolved_tool_calls to False, which caused the downstream error-cleanup block at lines 4185-4200 to skip its own flush — UI spinners on the client stayed open until page refresh because no cleanup events were yielded after the early flush swallowed the unresolved state.

Changes:

- Rename _flush_unresolved_tool_calls → flush_unresolved_tool_calls (public) in response_adapter.py; update 3 internal call sites + 2 service.py sites. Drops the # noqa: SLF001 suppressor (no longer a private-access violation).
- _flush_orphan_tool_uses_to_session and _InterruptedAttempt.finalize now return the list[StreamBaseResponse] produced by the flush so the caller yields them to the client instead of re-flushing.
- Replace the three scattered post-loop error blocks (partial restore + redundant flush + stream_err yield + handled_error yield) with one consolidated block that:
  (a) calls _classify_final_failure → _FinalFailure,
  (b) yields finalize()'s events + _end_text_if_open,
  (c) yields one StreamError (unless handled_error.already_yielded=True).
  Fixes the double-flush skip-cleanup bug and eliminates duplicated error-text/code strings between history marker and SSE yield.
- _classify_final_failure now returns _FinalFailure(display_msg, code, retryable) instead of a (msg, retryable) tuple — single source of truth for in-history marker + SSE event so they can't drift.

Tests: +5 _classify_final_failure contract tests, +2 return-value assertions on finalize/orphan-flush. All 1022 SDK tests pass (was 1012).
This commit is contained in:
@@ -22,6 +22,8 @@ from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
|
||||
from .service import (
|
||||
_classify_final_failure,
|
||||
_FinalFailure,
|
||||
_flush_orphan_tool_uses_to_session,
|
||||
_HandledErrorInfo,
|
||||
_InterruptedAttempt,
|
||||
@@ -49,7 +51,7 @@ def _adapter_with_unresolved(responses: list[StreamToolOutputAvailable]):
|
||||
out.extend(responses)
|
||||
adapter.has_unresolved_tool_calls = False
|
||||
|
||||
adapter._flush_unresolved_tool_calls.side_effect = _flush
|
||||
adapter.flush_unresolved_tool_calls.side_effect = _flush
|
||||
state = MagicMock()
|
||||
state.adapter = adapter
|
||||
return state
|
||||
@@ -147,8 +149,8 @@ class TestInterruptedAttemptFinalize:
|
||||
attempt = _InterruptedAttempt(
|
||||
partial=[ChatMessage(role="assistant", content="x")]
|
||||
)
|
||||
attempt.finalize(None, state=None, display_msg="Boom", retryable=False)
|
||||
# no raise = pass
|
||||
events = attempt.finalize(None, state=None, display_msg="Boom", retryable=False)
|
||||
assert events == []
|
||||
|
||||
def test_flushes_unresolved_tools_between_partial_and_marker(self):
|
||||
session = _make_session([ChatMessage(role="user", content="hi")])
|
||||
@@ -167,12 +169,20 @@ class TestInterruptedAttemptFinalize:
|
||||
),
|
||||
]
|
||||
)
|
||||
state = _adapter_with_unresolved([_tool_output("t1", "interrupted")])
|
||||
attempt.finalize(session, state=state, display_msg="Boom", retryable=False)
|
||||
flushed = [_tool_output("t1", "interrupted")]
|
||||
state = _adapter_with_unresolved(flushed)
|
||||
events = attempt.finalize(
|
||||
session, state=state, display_msg="Boom", retryable=False
|
||||
)
|
||||
roles = [m.role for m in session.messages]
|
||||
assert roles == ["user", "assistant", "tool", "assistant"]
|
||||
assert session.messages[2].tool_call_id == "t1"
|
||||
assert session.messages[2].content == "interrupted"
|
||||
# The same events that were persisted to history are returned to the
|
||||
# caller so the caller can yield them to the client — without this
|
||||
# the frontend's spinner widgets stay open until refresh because the
|
||||
# adapter's has_unresolved_tool_calls flag is already flipped to False.
|
||||
assert events == flushed
|
||||
|
||||
def test_clear_drops_both_partial_and_handled_error(self):
|
||||
attempt = _InterruptedAttempt(
|
||||
@@ -189,27 +199,101 @@ class TestInterruptedAttemptFinalize:
|
||||
class TestFlushOrphanToolUses:
|
||||
def test_appends_synthetic_tool_results_for_unresolved(self):
|
||||
session = _make_session()
|
||||
state = _adapter_with_unresolved(
|
||||
[_tool_output("t1", "r1"), _tool_output("t2", {"ok": False})]
|
||||
)
|
||||
_flush_orphan_tool_uses_to_session(session, state)
|
||||
flushed = [_tool_output("t1", "r1"), _tool_output("t2", {"ok": False})]
|
||||
state = _adapter_with_unresolved(flushed)
|
||||
events = _flush_orphan_tool_uses_to_session(session, state)
|
||||
assert [m.tool_call_id for m in session.messages] == ["t1", "t2"]
|
||||
# Dict outputs are JSON-encoded so structure survives the str-only
|
||||
# ChatMessage content field for the next-turn LLM read.
|
||||
assert session.messages[1].content == '{"ok": false}'
|
||||
assert events == flushed
|
||||
|
||||
def test_noop_when_state_is_none(self):
|
||||
session = _make_session()
|
||||
_flush_orphan_tool_uses_to_session(session, None)
|
||||
events = _flush_orphan_tool_uses_to_session(session, None)
|
||||
assert session.messages == []
|
||||
assert events == []
|
||||
|
||||
def test_noop_when_no_unresolved(self):
|
||||
adapter = MagicMock()
|
||||
adapter.has_unresolved_tool_calls = False
|
||||
state = MagicMock()
|
||||
state.adapter = adapter
|
||||
_flush_orphan_tool_uses_to_session(_make_session(), state)
|
||||
adapter._flush_unresolved_tool_calls.assert_not_called()
|
||||
events = _flush_orphan_tool_uses_to_session(_make_session(), state)
|
||||
adapter.flush_unresolved_tool_calls.assert_not_called()
|
||||
assert events == []
|
||||
|
||||
|
||||
class TestClassifyFinalFailure:
|
||||
"""Ensures the history marker (via finalize) and the SSE StreamError yield
|
||||
share one source of truth for display message + stream code — any drift
|
||||
would let the chat bubble and the SSE event show different copy for the
|
||||
same failure."""
|
||||
|
||||
def test_handled_error_wins(self):
|
||||
interrupted = _InterruptedAttempt(
|
||||
handled_error=_HandledErrorInfo(
|
||||
error_msg="circuit tripped",
|
||||
code="circuit_breaker",
|
||||
retryable=False,
|
||||
already_yielded=True,
|
||||
)
|
||||
)
|
||||
result = _classify_final_failure(
|
||||
interrupted,
|
||||
attempts_exhausted=False,
|
||||
transient_exhausted=False,
|
||||
stream_err=RuntimeError("ignored"),
|
||||
)
|
||||
assert result == _FinalFailure(
|
||||
display_msg="circuit tripped",
|
||||
code="circuit_breaker",
|
||||
retryable=False,
|
||||
)
|
||||
|
||||
def test_attempts_exhausted(self):
|
||||
result = _classify_final_failure(
|
||||
_InterruptedAttempt(),
|
||||
attempts_exhausted=True,
|
||||
transient_exhausted=False,
|
||||
stream_err=RuntimeError("x"),
|
||||
)
|
||||
assert result is not None
|
||||
assert result.code == "all_attempts_exhausted"
|
||||
assert result.retryable is False
|
||||
|
||||
def test_transient_exhausted(self):
|
||||
result = _classify_final_failure(
|
||||
_InterruptedAttempt(),
|
||||
attempts_exhausted=False,
|
||||
transient_exhausted=True,
|
||||
stream_err=RuntimeError("x"),
|
||||
)
|
||||
assert result is not None
|
||||
assert result.code == "transient_api_error"
|
||||
assert result.retryable is True
|
||||
|
||||
def test_stream_err_fallback(self):
|
||||
result = _classify_final_failure(
|
||||
_InterruptedAttempt(),
|
||||
attempts_exhausted=False,
|
||||
transient_exhausted=False,
|
||||
stream_err=RuntimeError("some sdk error"),
|
||||
)
|
||||
assert result is not None
|
||||
assert result.code == "sdk_stream_error"
|
||||
assert result.retryable is False
|
||||
|
||||
def test_returns_none_when_no_failure_recorded(self):
|
||||
assert (
|
||||
_classify_final_failure(
|
||||
_InterruptedAttempt(),
|
||||
attempts_exhausted=False,
|
||||
transient_exhausted=False,
|
||||
stream_err=None,
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
class TestRetryRollbackContract:
|
||||
|
||||
@@ -166,7 +166,7 @@ class SDKResponseAdapter:
|
||||
# are still executing concurrently and haven't finished yet.
|
||||
is_tool_only = all(isinstance(b, ToolUseBlock) for b in sdk_message.content)
|
||||
if not is_tool_only:
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
self.flush_unresolved_tool_calls(responses)
|
||||
|
||||
# After tool results, the SDK sends a new AssistantMessage for the
|
||||
# next LLM turn. Open a new step if the previous one was closed.
|
||||
@@ -375,7 +375,7 @@ class SDKResponseAdapter:
|
||||
self.step_open = False
|
||||
|
||||
elif isinstance(sdk_message, ResultMessage):
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
self.flush_unresolved_tool_calls(responses)
|
||||
# Thinking-only final turn guard: when the model's last LLM
|
||||
# call after a tool result produced only a ``ThinkingBlock``
|
||||
# (no ``TextBlock``, no ``ToolUseBlock``) the UI has nothing
|
||||
@@ -703,7 +703,7 @@ class SDKResponseAdapter:
|
||||
self._pending_thinking_delta = ""
|
||||
self._pending_thinking_index = None
|
||||
|
||||
def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
|
||||
def flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""Emit outputs for tool calls that didn't receive a UserMessage result.
|
||||
|
||||
SDK built-in tools (WebSearch, Read, etc.) may be executed by the CLI
|
||||
@@ -711,6 +711,12 @@ class SDKResponseAdapter:
|
||||
``ToolResultBlock`` content. The ``PostToolUse`` hook stashes their
|
||||
output, which we pop and emit here before the next ``AssistantMessage``
|
||||
starts.
|
||||
|
||||
Callers that need to both record synthetic tool_results in history AND
|
||||
yield the same events to the client should call this exactly once and
|
||||
share the resulting list — the method mutates ``resolved_tool_calls``,
|
||||
so a second call returns nothing and ``has_unresolved_tool_calls``
|
||||
flips to ``False`` after the first invocation.
|
||||
"""
|
||||
unresolved = [
|
||||
(tid, info.get("name", "unknown"))
|
||||
|
||||
@@ -615,39 +615,48 @@ class _InterruptedAttempt:
|
||||
display_msg: str,
|
||||
*,
|
||||
retryable: bool,
|
||||
) -> None:
|
||||
) -> list[StreamBaseResponse]:
|
||||
"""Re-attach partial + synthetic tool_result rows + error marker.
|
||||
|
||||
Called exactly once after the retry loop on final-failure exit.
|
||||
Idempotent on empty state, so it's safe to call on paths where no
|
||||
rollback happened.
|
||||
|
||||
Returns the ``StreamBaseResponse`` events produced by the safety
|
||||
flush so the caller can yield them to the client (the flush mutates
|
||||
adapter state, so a second flush elsewhere would return nothing and
|
||||
stale UI elements like spinners would stay open).
|
||||
"""
|
||||
if session is None:
|
||||
return
|
||||
return []
|
||||
if self.partial:
|
||||
session.messages.extend(self.partial)
|
||||
self.partial = []
|
||||
_flush_orphan_tool_uses_to_session(session, state)
|
||||
events = _flush_orphan_tool_uses_to_session(session, state)
|
||||
_append_error_marker(session, display_msg, retryable=retryable)
|
||||
return events
|
||||
|
||||
|
||||
def _flush_orphan_tool_uses_to_session(
|
||||
session: "ChatSession | None",
|
||||
state: "_RetryState | None",
|
||||
) -> None:
|
||||
) -> list[StreamBaseResponse]:
|
||||
"""Synthesize ``tool_result`` rows for ``tool_use`` blocks that never resolved.
|
||||
|
||||
Re-attached partial work may carry orphan ``tool_use`` blocks; without
|
||||
matching ``tool_result`` rows the next turn's LLM call would error with
|
||||
``tool_use_id without tool_result``. The adapter's safety-flush produces
|
||||
interrupted-marker results that satisfy the API contract.
|
||||
|
||||
Returns the flushed events so callers can yield them to the client
|
||||
alongside persisting the synthetic rows in session history.
|
||||
"""
|
||||
if session is None or state is None:
|
||||
return
|
||||
return []
|
||||
if not state.adapter.has_unresolved_tool_calls:
|
||||
return
|
||||
return []
|
||||
safety: list[StreamBaseResponse] = []
|
||||
state.adapter._flush_unresolved_tool_calls(safety) # noqa: SLF001
|
||||
state.adapter.flush_unresolved_tool_calls(safety)
|
||||
for resp in safety:
|
||||
if isinstance(resp, StreamToolOutputAvailable):
|
||||
content = (
|
||||
@@ -658,6 +667,20 @@ def _flush_orphan_tool_uses_to_session(
|
||||
session.messages.append(
|
||||
ChatMessage(role="tool", content=content, tool_call_id=resp.toolCallId)
|
||||
)
|
||||
return safety
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _FinalFailure:
|
||||
"""Display message + stream code + retryable flag for a final-failure exit.
|
||||
|
||||
Shared by the in-history error marker (via ``_InterruptedAttempt.finalize``)
|
||||
and the client-facing ``StreamError`` SSE yield so the two stay in sync.
|
||||
"""
|
||||
|
||||
display_msg: str
|
||||
code: str
|
||||
retryable: bool
|
||||
|
||||
|
||||
def _classify_final_failure(
|
||||
@@ -665,25 +688,41 @@ def _classify_final_failure(
|
||||
attempts_exhausted: bool,
|
||||
transient_exhausted: bool,
|
||||
stream_err: BaseException | None,
|
||||
) -> tuple[str | None, bool]:
|
||||
"""Pick the display message + retryable flag for the post-loop restore.
|
||||
) -> _FinalFailure | None:
|
||||
"""Pick the display message, stream code, and retryable flag for the exit.
|
||||
|
||||
Mirrors the error-code selection used for the client-facing ``StreamError``
|
||||
yield so the in-history marker and the SSE event stay consistent.
|
||||
Returns ``None`` when no failure was recorded (success path) — the caller
|
||||
should skip both the history marker and the SSE yield in that case.
|
||||
"""
|
||||
if interrupted.handled_error is not None:
|
||||
return interrupted.handled_error.error_msg, interrupted.handled_error.retryable
|
||||
return _FinalFailure(
|
||||
display_msg=interrupted.handled_error.error_msg,
|
||||
code=interrupted.handled_error.code,
|
||||
retryable=interrupted.handled_error.retryable,
|
||||
)
|
||||
if attempts_exhausted:
|
||||
return (
|
||||
"Your conversation is too long. "
|
||||
"Please start a new chat or clear some history."
|
||||
), False
|
||||
return _FinalFailure(
|
||||
display_msg=(
|
||||
"Your conversation is too long. "
|
||||
"Please start a new chat or clear some history."
|
||||
),
|
||||
code="all_attempts_exhausted",
|
||||
retryable=False,
|
||||
)
|
||||
if transient_exhausted:
|
||||
return FRIENDLY_TRANSIENT_MSG, True
|
||||
return _FinalFailure(
|
||||
display_msg=FRIENDLY_TRANSIENT_MSG,
|
||||
code="transient_api_error",
|
||||
retryable=True,
|
||||
)
|
||||
if stream_err is not None:
|
||||
safe_err = str(stream_err).replace("\n", " ").replace("\r", "")[:500]
|
||||
return _friendly_error_text(safe_err), False
|
||||
return None, False
|
||||
return _FinalFailure(
|
||||
display_msg=_friendly_error_text(safe_err),
|
||||
code="sdk_stream_error",
|
||||
retryable=False,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _setup_langfuse_otel() -> None:
|
||||
@@ -2868,7 +2907,7 @@ async def _run_stream_attempt(
|
||||
- len(state.adapter.resolved_tool_calls),
|
||||
)
|
||||
safety_responses: list[StreamBaseResponse] = []
|
||||
state.adapter._flush_unresolved_tool_calls(safety_responses)
|
||||
state.adapter.flush_unresolved_tool_calls(safety_responses)
|
||||
for response in safety_responses:
|
||||
if isinstance(
|
||||
response,
|
||||
@@ -4132,67 +4171,38 @@ async def stream_chat_completion_sdk( # pyright: ignore[reportGeneralTypeIssues
|
||||
_MAX_STREAM_ATTEMPTS,
|
||||
stream_err,
|
||||
)
|
||||
# Restore the rolled-back partial + add error marker exactly once per
|
||||
# failure mode. See _InterruptedAttempt for the carry-state contract.
|
||||
# Consolidated final-failure handling. _classify_final_failure picks
|
||||
# the display message + stream code + retryable flag, finalize() adds
|
||||
# the history marker and produces the safety-flush events that close
|
||||
# stale UI widgets on the client, and the StreamError yield below
|
||||
# surfaces the same message over SSE. The _HandledStreamError path
|
||||
# sets ``already_yielded=True`` for non-transient errors (circuit
|
||||
# breaker, idle timeout) whose inner handler already yielded — skip
|
||||
# the re-yield in that case.
|
||||
if ended_with_stream_error:
|
||||
final_msg, final_retryable = _classify_final_failure(
|
||||
failure = _classify_final_failure(
|
||||
interrupted, attempts_exhausted, transient_exhausted, stream_err
|
||||
)
|
||||
if final_msg is not None:
|
||||
interrupted.finalize(
|
||||
session, state, final_msg, retryable=final_retryable
|
||||
if failure is not None:
|
||||
cleanup_events: list[StreamBaseResponse] = []
|
||||
if state is not None:
|
||||
state.adapter._end_text_if_open(cleanup_events)
|
||||
cleanup_events.extend(
|
||||
interrupted.finalize(
|
||||
session,
|
||||
state,
|
||||
failure.display_msg,
|
||||
retryable=failure.retryable,
|
||||
)
|
||||
)
|
||||
|
||||
if ended_with_stream_error and state is not None:
|
||||
# Flush any unresolved tool calls so the frontend can close
|
||||
# stale UI elements (e.g. spinners) that were started before
|
||||
# the exception interrupted the stream.
|
||||
error_flush: list[StreamBaseResponse] = []
|
||||
state.adapter._end_text_if_open(error_flush)
|
||||
if state.adapter.has_unresolved_tool_calls:
|
||||
logger.warning(
|
||||
"%s Flushing %d unresolved tool(s) after stream error",
|
||||
log_prefix,
|
||||
len(state.adapter.current_tool_calls)
|
||||
- len(state.adapter.resolved_tool_calls),
|
||||
for response in cleanup_events:
|
||||
yield response
|
||||
already_yielded = (
|
||||
interrupted.handled_error is not None
|
||||
and interrupted.handled_error.already_yielded
|
||||
)
|
||||
state.adapter._flush_unresolved_tool_calls(error_flush)
|
||||
for response in error_flush:
|
||||
yield response
|
||||
|
||||
if ended_with_stream_error and stream_err is not None:
|
||||
# Use distinct error codes depending on how the loop ended:
|
||||
# • "all_attempts_exhausted" — context compaction ran out of room
|
||||
# • "transient_api_error" — 429/5xx/ECONNRESET retries exhausted
|
||||
# • "sdk_stream_error" — non-context, non-transient fatal error
|
||||
safe_err = str(stream_err).replace("\n", " ").replace("\r", "")[:500]
|
||||
if attempts_exhausted:
|
||||
error_text = (
|
||||
"Your conversation is too long. "
|
||||
"Please start a new chat or clear some history."
|
||||
)
|
||||
error_code = "all_attempts_exhausted"
|
||||
elif transient_exhausted:
|
||||
error_text = FRIENDLY_TRANSIENT_MSG
|
||||
error_code = "transient_api_error"
|
||||
else:
|
||||
error_text = _friendly_error_text(safe_err)
|
||||
error_code = "sdk_stream_error"
|
||||
yield StreamError(errorText=error_text, code=error_code)
|
||||
|
||||
# _HandledStreamError exits the retry loop with stream_err unset, so
|
||||
# the previous block doesn't fire. Emit the client-facing StreamError
|
||||
# here when the inner handler chose not to (transient errors suppress
|
||||
# the early flash so the client only sees the final error after all
|
||||
# retries are exhausted).
|
||||
if (
|
||||
interrupted.handled_error is not None
|
||||
and not interrupted.handled_error.already_yielded
|
||||
):
|
||||
yield StreamError(
|
||||
errorText=interrupted.handled_error.error_msg,
|
||||
code=interrupted.handled_error.code,
|
||||
)
|
||||
if not already_yielded:
|
||||
yield StreamError(errorText=failure.display_msg, code=failure.code)
|
||||
|
||||
# Copy token usage from retry state to outer-scope accumulators
|
||||
# so the finally block can persist them.
|
||||
|
||||
Reference in New Issue
Block a user