mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(backend/copilot): patch toolCalls DB row when flush races ahead of StreamToolInputAvailable
The intermediate flush introduced in #12604 is append-only: rows already in the DB are never re-saved. When a flush fires between StreamTextDelta and StreamToolInputAvailable, the assistant row is written with toolCalls=null and the tool calls that arrive later are only applied in-memory — the DB row is never updated. The frontend silently drops tool calls when the column is null, making them invisible in the UI. Fix: after dispatching each adapter_responses batch, if StreamToolInputAvailable was present AND acc.assistant_response.sequence is already set (flush happened), issue a targeted UPDATE via update_message_tool_calls() to patch toolCalls on the existing row. asyncio.shield() keeps the patch from being cancelled on GeneratorExit. Adds regression tests in session_persistence_test.py covering the text→flush→ tool-input sequence.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -195,3 +195,4 @@ test.db
|
||||
# Implementation plans (generated by AI agents)
|
||||
plans/
|
||||
.claude/worktrees/
|
||||
test-results/
|
||||
|
||||
@@ -548,6 +548,46 @@ async def update_message_content_by_sequence(
|
||||
return False
|
||||
|
||||
|
||||
async def update_message_tool_calls(
|
||||
session_id: str,
|
||||
sequence: int,
|
||||
tool_calls: list[dict],
|
||||
) -> bool:
|
||||
"""Patch the toolCalls column of an already-saved assistant message.
|
||||
|
||||
Called when StreamToolInputAvailable arrives after an intermediate flush
|
||||
saved the assistant message with tool_calls=None. The DB save is
|
||||
append-only (uses get_next_sequence), so the already-persisted row must
|
||||
be updated in-place to reflect the tool_calls that arrived later.
|
||||
|
||||
Args:
|
||||
session_id: The chat session ID.
|
||||
sequence: The 0-based sequence number of the assistant message to patch.
|
||||
tool_calls: The full list of tool call dicts to set on the row.
|
||||
|
||||
Returns:
|
||||
True if the row was found and updated, False otherwise.
|
||||
"""
|
||||
try:
|
||||
result = await PrismaChatMessage.prisma().update_many(
|
||||
where={"sessionId": session_id, "sequence": sequence},
|
||||
data={"toolCalls": SafeJson(tool_calls)},
|
||||
)
|
||||
if result == 0:
|
||||
logger.warning(
|
||||
f"update_message_tool_calls: no row found for session {session_id}, "
|
||||
f"sequence {sequence}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"update_message_tool_calls failed for session {session_id}, "
|
||||
f"sequence {sequence}: {e}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
|
||||
"""Set durationMs on the last assistant message in a session.
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ from backend.util.exceptions import NotFoundError
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ..config import ChatConfig, CopilotLlmModel, CopilotMode
|
||||
from ..db import update_message_tool_calls
|
||||
from ..constants import (
|
||||
COPILOT_ERROR_PREFIX,
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX,
|
||||
@@ -2262,6 +2263,30 @@ async def _run_stream_attempt(
|
||||
if dispatched is not None:
|
||||
yield dispatched
|
||||
|
||||
# If tool calls arrived this batch AND the assistant message was
|
||||
# already flushed to DB (sequence is set), patch the existing row
|
||||
# so tool_calls are not lost. The append-only save (start_sequence)
|
||||
# in _save_session_to_db never re-saves already-persisted rows, so
|
||||
# without this patch the assistant row keeps tool_calls=null.
|
||||
if acc.assistant_response.sequence is not None and any(
|
||||
isinstance(r, StreamToolInputAvailable) for r in adapter_responses
|
||||
):
|
||||
try:
|
||||
await asyncio.shield(
|
||||
update_message_tool_calls(
|
||||
ctx.session.session_id,
|
||||
acc.assistant_response.sequence,
|
||||
acc.accumulated_tool_calls,
|
||||
)
|
||||
)
|
||||
except Exception as patch_err:
|
||||
logger.warning(
|
||||
"%s tool_calls DB patch failed (sequence=%d): %s",
|
||||
ctx.log_prefix,
|
||||
acc.assistant_response.sequence,
|
||||
patch_err,
|
||||
)
|
||||
|
||||
# Append assistant entry AFTER convert_message so that
|
||||
# any stashed tool results from the previous turn are
|
||||
# recorded first, preserving the required API order:
|
||||
|
||||
@@ -20,7 +20,11 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.response_model import StreamStartStep, StreamTextDelta
|
||||
from backend.copilot.response_model import (
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamToolInputAvailable,
|
||||
)
|
||||
from backend.copilot.sdk.service import _dispatch_response, _StreamAccumulator
|
||||
|
||||
_NOW = datetime(2024, 1, 1, tzinfo=timezone.utc)
|
||||
@@ -215,3 +219,100 @@ class TestPreCreateAssistantMessage:
|
||||
_simulate_pre_create(acc, ctx)
|
||||
|
||||
assert len(ctx.session.messages) == 0
|
||||
|
||||
|
||||
class TestToolCallsLostAfterIntermediateFlush:
|
||||
"""Regression tests for the bug where tool_calls are lost when an
|
||||
intermediate flush saves the assistant message before StreamToolInputAvailable
|
||||
arrives.
|
||||
|
||||
Sequence that triggers the bug:
|
||||
1. StreamTextDelta → assistant message appended with tool_calls=None
|
||||
2. Intermediate flush fires (time/count threshold) → DB row written with tool_calls=null
|
||||
and acc.assistant_response.sequence is set (back-filled)
|
||||
3. StreamToolInputAvailable → acc.assistant_response.tool_calls mutated in-memory
|
||||
4. Final save: append-only — assistant row already in DB, tool_calls never updated
|
||||
|
||||
Fix: when StreamToolInputAvailable arrives and acc.assistant_response.sequence
|
||||
is not None, issue a DB UPDATE to patch toolCalls on the existing row.
|
||||
"""
|
||||
|
||||
def test_text_delta_then_tool_input_sets_tool_calls_on_message(self) -> None:
|
||||
"""After text arrives then tool input arrives, acc.assistant_response.tool_calls
|
||||
should be populated regardless of flush state."""
|
||||
session = _make_session()
|
||||
ctx = _make_ctx(session)
|
||||
state = _make_state()
|
||||
acc = _StreamAccumulator(
|
||||
assistant_response=ChatMessage(role="assistant", content=""),
|
||||
accumulated_tool_calls=[],
|
||||
)
|
||||
|
||||
# Step 1: text delta arrives, message appended
|
||||
_dispatch_response(
|
||||
StreamTextDelta(id="t1", delta="Let me run that for you."),
|
||||
acc,
|
||||
ctx,
|
||||
state,
|
||||
False,
|
||||
"[test]",
|
||||
)
|
||||
assert acc.has_appended_assistant
|
||||
assert session.messages[-1].tool_calls is None
|
||||
|
||||
# Step 2: simulate intermediate flush back-filling the sequence
|
||||
acc.assistant_response.sequence = 1 # back-filled by _save_session_to_db
|
||||
|
||||
# Step 3: tool input arrives
|
||||
_dispatch_response(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId="call_abc",
|
||||
toolName="bash_exec",
|
||||
input={"command": "ls"},
|
||||
),
|
||||
acc,
|
||||
ctx,
|
||||
state,
|
||||
False,
|
||||
"[test]",
|
||||
)
|
||||
|
||||
# tool_calls should be set in memory
|
||||
assert acc.assistant_response.tool_calls is not None
|
||||
assert len(acc.assistant_response.tool_calls) == 1
|
||||
assert acc.assistant_response.tool_calls[0]["id"] == "call_abc"
|
||||
|
||||
def test_sequence_set_when_flush_occurred_before_tool_input(self) -> None:
|
||||
"""When sequence is back-filled (flush happened) before tool calls arrive,
|
||||
it is detectable so the caller can issue a DB patch."""
|
||||
acc = _StreamAccumulator(
|
||||
assistant_response=ChatMessage(role="assistant", content="hello"),
|
||||
accumulated_tool_calls=[],
|
||||
has_appended_assistant=True,
|
||||
)
|
||||
# Simulate flush back-fill
|
||||
acc.assistant_response.sequence = 3
|
||||
|
||||
ctx = _make_ctx()
|
||||
state = _make_state()
|
||||
|
||||
_dispatch_response(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId="call_xyz",
|
||||
toolName="run_block",
|
||||
input={},
|
||||
),
|
||||
acc,
|
||||
ctx,
|
||||
state,
|
||||
False,
|
||||
"[test]",
|
||||
)
|
||||
|
||||
# Caller should detect this condition and issue a DB patch
|
||||
needs_db_patch = acc.assistant_response.sequence is not None and bool(
|
||||
acc.accumulated_tool_calls
|
||||
)
|
||||
assert (
|
||||
needs_db_patch
|
||||
), "Expected needs_db_patch=True when flush happened before tool calls arrived"
|
||||
|
||||
Reference in New Issue
Block a user