From 6172f7b1f5bddf1a9e8cebd5f92d01ccda22c7d5 Mon Sep 17 00:00:00 2001 From: Swifty Date: Fri, 6 Feb 2026 09:55:15 +0100 Subject: [PATCH] fix(backend): resolve chat session persistence race condition on sequence uniqueness Compute message sequence numbers inside the database transaction rather than trusting a pre-queried count, preventing UniqueViolationError when multiple pods insert messages concurrently. Adds retry logic at both the db batch-insert layer and the model upsert layer for defense in depth. --- .../backend/backend/api/features/chat/db.py | 133 +++++++++++++----- .../backend/api/features/chat/model.py | 32 +++++ .../backend/api/features/chat/model_test.py | 92 ++++++++++++ 3 files changed, 219 insertions(+), 38 deletions(-) diff --git a/autogpt_platform/backend/backend/api/features/chat/db.py b/autogpt_platform/backend/backend/api/features/chat/db.py index d34b4e5b07..a2cdcd050c 100644 --- a/autogpt_platform/backend/backend/api/features/chat/db.py +++ b/autogpt_platform/backend/backend/api/features/chat/db.py @@ -5,6 +5,7 @@ import logging from datetime import UTC, datetime from typing import Any, cast +from prisma.errors import UniqueViolationError from prisma.models import ChatMessage as PrismaChatMessage from prisma.models import ChatSession as PrismaChatSession from prisma.types import ( @@ -135,6 +136,28 @@ async def add_chat_message( return message +async def get_max_message_sequence(session_id: str, tx: Any = None) -> int: + """Get the highest sequence number for a session's messages. + + Args: + session_id: The chat session ID. + tx: Optional transaction client for running inside a transaction. + + Returns: + The max sequence number, or -1 if no messages exist + (so that max + 1 = 0 for the first message). + """ + client = PrismaChatMessage.prisma(tx) if tx else PrismaChatMessage.prisma() + results = await client.find_many( + where={"sessionId": session_id}, + order={"sequence": "desc"}, + take=1, + ) + if results: + return results[0].sequence + return -1 + + async def add_chat_messages_batch( session_id: str, messages: list[dict[str, Any]], @@ -143,54 +166,88 @@ async def add_chat_messages_batch( """Add multiple messages to a chat session in a batch. Uses a transaction for atomicity - if any message creation fails, - the entire batch is rolled back. + the entire batch is rolled back. Computes the actual start sequence + inside the transaction to prevent race conditions in multi-pod deployments. + + Retries once on UniqueViolationError (another pod may have inserted + messages with the same sequence numbers concurrently). """ if not messages: return [] - created_messages = [] + max_attempts = 2 + for attempt in range(1, max_attempts + 1): + try: + created_messages = [] - async with transaction() as tx: - for i, msg in enumerate(messages): - # Build input dict dynamically rather than using ChatMessageCreateInput - # directly because Prisma's TypedDict validation rejects optional fields - # set to None. We only include fields that have values, then cast. 
- data: dict[str, Any] = { - "Session": {"connect": {"id": session_id}}, - "role": msg["role"], - "sequence": start_sequence + i, - } + async with transaction() as tx: + # Compute authoritative start sequence inside the transaction + actual_max = await get_max_message_sequence(session_id, tx) + actual_start = actual_max + 1 - # Add optional string fields - if msg.get("content") is not None: - data["content"] = msg["content"] - if msg.get("name") is not None: - data["name"] = msg["name"] - if msg.get("tool_call_id") is not None: - data["toolCallId"] = msg["tool_call_id"] - if msg.get("refusal") is not None: - data["refusal"] = msg["refusal"] + if actual_start != start_sequence: + logger.warning( + f"Sequence adjustment for session {session_id}: " + f"caller provided start_sequence={start_sequence}, " + f"but DB max sequence is {actual_max} " + f"(using {actual_start} instead)" + ) - # Add optional JSON fields only when they have values - if msg.get("tool_calls") is not None: - data["toolCalls"] = SafeJson(msg["tool_calls"]) - if msg.get("function_call") is not None: - data["functionCall"] = SafeJson(msg["function_call"]) + for i, msg in enumerate(messages): + # Build input dict dynamically rather than using + # ChatMessageCreateInput directly because Prisma's TypedDict + # validation rejects optional fields set to None. + data: dict[str, Any] = { + "Session": {"connect": {"id": session_id}}, + "role": msg["role"], + "sequence": actual_start + i, + } - created = await PrismaChatMessage.prisma(tx).create( - data=cast(ChatMessageCreateInput, data) + # Add optional string fields + if msg.get("content") is not None: + data["content"] = msg["content"] + if msg.get("name") is not None: + data["name"] = msg["name"] + if msg.get("tool_call_id") is not None: + data["toolCallId"] = msg["tool_call_id"] + if msg.get("refusal") is not None: + data["refusal"] = msg["refusal"] + + # Add optional JSON fields only when they have values + if msg.get("tool_calls") is not None: + data["toolCalls"] = SafeJson(msg["tool_calls"]) + if msg.get("function_call") is not None: + data["functionCall"] = SafeJson(msg["function_call"]) + + created = await PrismaChatMessage.prisma(tx).create( + data=cast(ChatMessageCreateInput, data) + ) + created_messages.append(created) + + # Update session's updatedAt timestamp within the same transaction. + # Note: Token usage is updated separately via update_chat_session(). + await PrismaChatSession.prisma(tx).update( + where={"id": session_id}, + data={"updatedAt": datetime.now(UTC)}, + ) + + return created_messages + + except UniqueViolationError: + if attempt < max_attempts: + logger.warning( + f"UniqueViolationError on attempt {attempt} for session " + f"{session_id}, retrying with fresh sequence" + ) + continue + logger.error( + f"UniqueViolationError persisted after {max_attempts} attempts " + f"for session {session_id}" ) - created_messages.append(created) + raise - # Update session's updatedAt timestamp within the same transaction. - # Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated - # separately via update_chat_session() after streaming completes. 
- await PrismaChatSession.prisma(tx).update( - where={"id": session_id}, - data={"updatedAt": datetime.now(UTC)}, - ) - - return created_messages + # Unreachable, but satisfies type checker + return [] async def get_user_chat_sessions( diff --git a/autogpt_platform/backend/backend/api/features/chat/model.py b/autogpt_platform/backend/backend/api/features/chat/model.py index 7318ef88d7..39d616414c 100644 --- a/autogpt_platform/backend/backend/api/features/chat/model.py +++ b/autogpt_platform/backend/backend/api/features/chat/model.py @@ -19,6 +19,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import ( ChatCompletionMessageToolCallParam, Function, ) +from prisma.errors import UniqueViolationError from prisma.models import ChatMessage as PrismaChatMessage from prisma.models import ChatSession as PrismaChatSession from pydantic import BaseModel @@ -468,6 +469,37 @@ async def upsert_chat_session( # Save to database (primary storage) try: await _save_session_to_db(session, existing_message_count) + except UniqueViolationError: + # Another pod likely saved the same messages concurrently. + # Re-query the message count and retry if unsaved messages remain. + logger.warning( + f"UniqueViolationError saving session {session.session_id}, " + f"re-querying message count for retry" + ) + try: + fresh_count = await chat_db.get_chat_session_message_count( + session.session_id + ) + if fresh_count < len(session.messages): + logger.info( + f"Retrying save for session {session.session_id}: " + f"fresh_count={fresh_count}, " + f"total_messages={len(session.messages)}" + ) + await _save_session_to_db(session, fresh_count) + else: + logger.info( + f"All messages already saved for session " + f"{session.session_id} by another process " + f"(db_count={fresh_count}, " + f"session_count={len(session.messages)})" + ) + except Exception as retry_err: + logger.error( + f"Retry also failed for session {session.session_id}: " + f"{retry_err}" + ) + db_error = retry_err except Exception as e: logger.error( f"Failed to save session {session.session_id} to database: {e}" diff --git a/autogpt_platform/backend/backend/api/features/chat/model_test.py b/autogpt_platform/backend/backend/api/features/chat/model_test.py index c230b00f9c..581ea32763 100644 --- a/autogpt_platform/backend/backend/api/features/chat/model_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/model_test.py @@ -1,4 +1,7 @@ +from unittest.mock import patch + import pytest +from prisma.errors import UniqueViolationError from .model import ( ChatMessage, @@ -117,3 +120,92 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id): loaded.tool_calls is not None ), f"Tool calls missing for {orig.role} message" assert len(orig.tool_calls) == len(loaded.tool_calls) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_upsert_handles_concurrent_saves(setup_test_user, test_user_id): + """Test that incremental saves work: save initial messages, add more, save again.""" + from backend.data.redis_client import get_redis_async + + # Create session with initial messages + s = ChatSession.new(user_id=test_user_id) + s.messages = [ + ChatMessage(content="First message", role="user"), + ChatMessage(content="First reply", role="assistant"), + ] + s = await upsert_chat_session(s) + + # Add more messages and save again (incremental) + s.messages.append(ChatMessage(content="Second message", role="user")) + s.messages.append(ChatMessage(content="Second reply", role="assistant")) + s = await upsert_chat_session(s) + + # 
Clear cache and verify all messages round-trip from DB
+    redis_key = f"chat:session:{s.session_id}"
+    async_redis = await get_redis_async()
+    await async_redis.delete(redis_key)
+
+    loaded = await get_chat_session(session_id=s.session_id, user_id=s.user_id)
+    assert loaded is not None, "Session not found after incremental save"
+    assert len(loaded.messages) == 4, f"Expected 4 messages, got {len(loaded.messages)}"
+
+    # Verify content of all messages
+    assert loaded.messages[0].content == "First message"
+    assert loaded.messages[1].content == "First reply"
+    assert loaded.messages[2].content == "Second message"
+    assert loaded.messages[3].content == "Second reply"
+
+
+@pytest.mark.asyncio(loop_scope="session")
+async def test_upsert_retries_on_unique_violation(setup_test_user, test_user_id):
+    """Test that upsert_chat_session retries when UniqueViolationError is raised."""
+    from . import db as chat_db
+
+    # Create a session with initial messages
+    s = ChatSession.new(user_id=test_user_id)
+    s.messages = [
+        ChatMessage(content="Hello", role="user"),
+    ]
+    s = await upsert_chat_session(s)
+
+    # Add a new message
+    s.messages.append(ChatMessage(content="World", role="assistant"))
+
+    # Mock add_chat_messages_batch to raise UniqueViolationError on first call,
+    # then succeed on second call (simulating another pod saving concurrently).
+    original_batch = chat_db.add_chat_messages_batch
+    call_count = 0
+
+    async def mock_batch(session_id, messages, start_sequence):
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            raise UniqueViolationError(
+                {
+                    "error": "Unique constraint failed on the fields: (sessionId, sequence)"
+                }
+            )
+        return await original_batch(session_id, messages, start_sequence)
+
+    with patch.object(chat_db, "add_chat_messages_batch", side_effect=mock_batch):
+        # Wrap get_chat_session_message_count so the retry's re-query goes through
+        # the real count (only the first message is saved before the simulated failure).
+        original_count = chat_db.get_chat_session_message_count
+
+        count_call = 0
+
+        async def mock_count(session_id):
+            nonlocal count_call
+            count_call += 1
+            # First call is the initial count check in upsert_chat_session
+            # Second call is the retry after UniqueViolationError
+            return await original_count(session_id)
+
+        with patch.object(
+            chat_db, "get_chat_session_message_count", side_effect=mock_count
+        ):
+            s = await upsert_chat_session(s)
+
+    # Verify the session completed without error
+    assert s is not None
+    assert len(s.messages) == 2
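
Reviewer note: the sketch below is not part of the patch. It is a minimal illustration of the
multi-writer race the reworked add_chat_messages_batch absorbs, assuming the module is importable
as backend.api.features.chat.db and that an existing session_id is available. The helper name
simulate_race is illustrative only.

import asyncio

from backend.api.features.chat import db as chat_db


async def simulate_race(session_id: str) -> None:
    # Both writers pre-compute the same, possibly stale start_sequence, which is
    # the pattern callers used before this patch. The (sessionId, sequence) unique
    # index makes the later insert fail; with this change the helper re-derives the
    # sequence inside its own transaction and retries once, so both batches should
    # land with non-overlapping sequence numbers.
    stale_start = await chat_db.get_max_message_sequence(session_id) + 1

    await asyncio.gather(
        chat_db.add_chat_messages_batch(
            session_id, [{"role": "user", "content": "from pod A"}], stale_start
        ),
        chat_db.add_chat_messages_batch(
            session_id, [{"role": "assistant", "content": "from pod B"}], stale_start
        ),
    )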