diff --git a/.gitignore b/.gitignore index 97d6b18a76..53df57dc70 100644 --- a/.gitignore +++ b/.gitignore @@ -195,3 +195,4 @@ test.db # Implementation plans (generated by AI agents) plans/ .claude/worktrees/ +test-results/ diff --git a/autogpt_platform/backend/backend/copilot/db.py b/autogpt_platform/backend/backend/copilot/db.py index b85e08606c..f646c7d473 100644 --- a/autogpt_platform/backend/backend/copilot/db.py +++ b/autogpt_platform/backend/backend/copilot/db.py @@ -548,6 +548,46 @@ async def update_message_content_by_sequence( return False +async def update_message_tool_calls( + session_id: str, + sequence: int, + tool_calls: list[dict], +) -> bool: + """Patch the toolCalls column of an already-saved assistant message. + + Called when StreamToolInputAvailable arrives after an intermediate flush + saved the assistant message with tool_calls=None. The DB save is + append-only (uses get_next_sequence), so the already-persisted row must + be updated in-place to reflect the tool_calls that arrived later. + + Args: + session_id: The chat session ID. + sequence: The 0-based sequence number of the assistant message to patch. + tool_calls: The full list of tool call dicts to set on the row. + + Returns: + True if the row was found and updated, False otherwise. + """ + try: + result = await PrismaChatMessage.prisma().update_many( + where={"sessionId": session_id, "sequence": sequence}, + data={"toolCalls": SafeJson(tool_calls)}, + ) + if result == 0: + logger.warning( + f"update_message_tool_calls: no row found for session {session_id}, " + f"sequence {sequence}" + ) + return False + return True + except Exception as e: + logger.error( + f"update_message_tool_calls failed for session {session_id}, " + f"sequence {sequence}: {e}" + ) + return False + + async def set_turn_duration(session_id: str, duration_ms: int) -> None: """Set durationMs on the last assistant message in a session. diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index 9cef40ba7a..5603fdf5ad 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -44,6 +44,7 @@ from backend.util.exceptions import NotFoundError from backend.util.settings import Settings from ..config import ChatConfig, CopilotLlmModel, CopilotMode +from ..db import update_message_tool_calls from ..constants import ( COPILOT_ERROR_PREFIX, COPILOT_RETRYABLE_ERROR_PREFIX, @@ -2262,6 +2263,30 @@ async def _run_stream_attempt( if dispatched is not None: yield dispatched + # If tool calls arrived this batch AND the assistant message was + # already flushed to DB (sequence is set), patch the existing row + # so tool_calls are not lost. The append-only save (start_sequence) + # in _save_session_to_db never re-saves already-persisted rows, so + # without this patch the assistant row keeps tool_calls=null. + if acc.assistant_response.sequence is not None and any( + isinstance(r, StreamToolInputAvailable) for r in adapter_responses + ): + try: + await asyncio.shield( + update_message_tool_calls( + ctx.session.session_id, + acc.assistant_response.sequence, + acc.accumulated_tool_calls, + ) + ) + except Exception as patch_err: + logger.warning( + "%s tool_calls DB patch failed (sequence=%d): %s", + ctx.log_prefix, + acc.assistant_response.sequence, + patch_err, + ) + # Append assistant entry AFTER convert_message so that # any stashed tool results from the previous turn are # recorded first, preserving the required API order: diff --git a/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py b/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py index ea7b128927..8d2ec5bc24 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/session_persistence_test.py @@ -20,7 +20,11 @@ from datetime import datetime, timezone from unittest.mock import MagicMock from backend.copilot.model import ChatMessage, ChatSession -from backend.copilot.response_model import StreamStartStep, StreamTextDelta +from backend.copilot.response_model import ( + StreamStartStep, + StreamTextDelta, + StreamToolInputAvailable, +) from backend.copilot.sdk.service import _dispatch_response, _StreamAccumulator _NOW = datetime(2024, 1, 1, tzinfo=timezone.utc) @@ -215,3 +219,100 @@ class TestPreCreateAssistantMessage: _simulate_pre_create(acc, ctx) assert len(ctx.session.messages) == 0 + + +class TestToolCallsLostAfterIntermediateFlush: + """Regression tests for the bug where tool_calls are lost when an + intermediate flush saves the assistant message before StreamToolInputAvailable + arrives. + + Sequence that triggers the bug: + 1. StreamTextDelta → assistant message appended with tool_calls=None + 2. Intermediate flush fires (time/count threshold) → DB row written with tool_calls=null + and acc.assistant_response.sequence is set (back-filled) + 3. StreamToolInputAvailable → acc.assistant_response.tool_calls mutated in-memory + 4. Final save: append-only — assistant row already in DB, tool_calls never updated + + Fix: when StreamToolInputAvailable arrives and acc.assistant_response.sequence + is not None, issue a DB UPDATE to patch toolCalls on the existing row. + """ + + def test_text_delta_then_tool_input_sets_tool_calls_on_message(self) -> None: + """After text arrives then tool input arrives, acc.assistant_response.tool_calls + should be populated regardless of flush state.""" + session = _make_session() + ctx = _make_ctx(session) + state = _make_state() + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content=""), + accumulated_tool_calls=[], + ) + + # Step 1: text delta arrives, message appended + _dispatch_response( + StreamTextDelta(id="t1", delta="Let me run that for you."), + acc, + ctx, + state, + False, + "[test]", + ) + assert acc.has_appended_assistant + assert session.messages[-1].tool_calls is None + + # Step 2: simulate intermediate flush back-filling the sequence + acc.assistant_response.sequence = 1 # back-filled by _save_session_to_db + + # Step 3: tool input arrives + _dispatch_response( + StreamToolInputAvailable( + toolCallId="call_abc", + toolName="bash_exec", + input={"command": "ls"}, + ), + acc, + ctx, + state, + False, + "[test]", + ) + + # tool_calls should be set in memory + assert acc.assistant_response.tool_calls is not None + assert len(acc.assistant_response.tool_calls) == 1 + assert acc.assistant_response.tool_calls[0]["id"] == "call_abc" + + def test_sequence_set_when_flush_occurred_before_tool_input(self) -> None: + """When sequence is back-filled (flush happened) before tool calls arrive, + it is detectable so the caller can issue a DB patch.""" + acc = _StreamAccumulator( + assistant_response=ChatMessage(role="assistant", content="hello"), + accumulated_tool_calls=[], + has_appended_assistant=True, + ) + # Simulate flush back-fill + acc.assistant_response.sequence = 3 + + ctx = _make_ctx() + state = _make_state() + + _dispatch_response( + StreamToolInputAvailable( + toolCallId="call_xyz", + toolName="run_block", + input={}, + ), + acc, + ctx, + state, + False, + "[test]", + ) + + # Caller should detect this condition and issue a DB patch + needs_db_patch = acc.assistant_response.sequence is not None and bool( + acc.accumulated_tool_calls + ) + assert ( + needs_db_patch + ), "Expected needs_db_patch=True when flush happened before tool calls arrived"