mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-06 04:45:10 -05:00
fix(backend): resolve chat session persistence race condition on sequence uniqueness
Compute message sequence numbers inside the database transaction rather than trusting a pre-queried count, preventing UniqueViolationError when multiple pods insert messages concurrently. Adds retry logic at both the db batch-insert layer and the model upsert layer for defense in depth.
This commit is contained in:
@@ -5,6 +5,7 @@ import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, cast
|
||||
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
@@ -135,6 +136,28 @@ async def add_chat_message(
|
||||
return message
|
||||
|
||||
|
||||
async def get_max_message_sequence(session_id: str, tx: Any = None) -> int:
    """Return the largest ``sequence`` stored for a chat session's messages.

    Args:
        session_id: ID of the chat session whose messages are inspected.
        tx: Optional transaction client. When truthy, the query is issued
            through it; otherwise the default client is used.

    Returns:
        The highest sequence number found, or -1 when the session has no
        messages (so that ``result + 1`` is 0 for the first message).
    """
    if tx:
        message_client = PrismaChatMessage.prisma(tx)
    else:
        message_client = PrismaChatMessage.prisma()

    # Order by sequence descending and fetch one row: that row, if any,
    # carries the maximum sequence for the session.
    newest = await message_client.find_many(
        where={"sessionId": session_id},
        order={"sequence": "desc"},
        take=1,
    )
    return newest[0].sequence if newest else -1
|
||||
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
@@ -143,54 +166,88 @@ async def add_chat_messages_batch(
|
||||
"""Add multiple messages to a chat session in a batch.
|
||||
|
||||
Uses a transaction for atomicity - if any message creation fails,
|
||||
the entire batch is rolled back.
|
||||
the entire batch is rolled back. Computes the actual start sequence
|
||||
inside the transaction to prevent race conditions in multi-pod deployments.
|
||||
|
||||
Retries once on UniqueViolationError (another pod may have inserted
|
||||
messages with the same sequence numbers concurrently).
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
created_messages = []
|
||||
max_attempts = 2
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
created_messages = []
|
||||
|
||||
async with transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
# set to None. We only include fields that have values, then cast.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
async with transaction() as tx:
|
||||
# Compute authoritative start sequence inside the transaction
|
||||
actual_max = await get_max_message_sequence(session_id, tx)
|
||||
actual_start = actual_max + 1
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
if actual_start != start_sequence:
|
||||
logger.warning(
|
||||
f"Sequence adjustment for session {session_id}: "
|
||||
f"caller provided start_sequence={start_sequence}, "
|
||||
f"but DB max sequence is {actual_max} "
|
||||
f"(using {actual_start} instead)"
|
||||
)
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using
|
||||
# ChatMessageCreateInput directly because Prisma's TypedDict
|
||||
# validation rejects optional fields set to None.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": actual_start + i,
|
||||
}
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage is updated separately via update_chat_session().
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return created_messages
|
||||
|
||||
except UniqueViolationError:
|
||||
if attempt < max_attempts:
|
||||
logger.warning(
|
||||
f"UniqueViolationError on attempt {attempt} for session "
|
||||
f"{session_id}, retrying with fresh sequence"
|
||||
)
|
||||
continue
|
||||
logger.error(
|
||||
f"UniqueViolationError persisted after {max_attempts} attempts "
|
||||
f"for session {session_id}"
|
||||
)
|
||||
created_messages.append(created)
|
||||
raise
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||
# separately via update_chat_session() after streaming completes.
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return created_messages
|
||||
# Unreachable, but satisfies type checker
|
||||
return []
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
|
||||
@@ -19,6 +19,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
|
||||
ChatCompletionMessageToolCallParam,
|
||||
Function,
|
||||
)
|
||||
from prisma.errors import UniqueViolationError
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from pydantic import BaseModel
|
||||
@@ -468,6 +469,37 @@ async def upsert_chat_session(
|
||||
# Save to database (primary storage)
|
||||
try:
|
||||
await _save_session_to_db(session, existing_message_count)
|
||||
except UniqueViolationError:
|
||||
# Another pod likely saved the same messages concurrently.
|
||||
# Re-query the message count and retry if unsaved messages remain.
|
||||
logger.warning(
|
||||
f"UniqueViolationError saving session {session.session_id}, "
|
||||
f"re-querying message count for retry"
|
||||
)
|
||||
try:
|
||||
fresh_count = await chat_db.get_chat_session_message_count(
|
||||
session.session_id
|
||||
)
|
||||
if fresh_count < len(session.messages):
|
||||
logger.info(
|
||||
f"Retrying save for session {session.session_id}: "
|
||||
f"fresh_count={fresh_count}, "
|
||||
f"total_messages={len(session.messages)}"
|
||||
)
|
||||
await _save_session_to_db(session, fresh_count)
|
||||
else:
|
||||
logger.info(
|
||||
f"All messages already saved for session "
|
||||
f"{session.session_id} by another process "
|
||||
f"(db_count={fresh_count}, "
|
||||
f"session_count={len(session.messages)})"
|
||||
)
|
||||
except Exception as retry_err:
|
||||
logger.error(
|
||||
f"Retry also failed for session {session.session_id}: "
|
||||
f"{retry_err}"
|
||||
)
|
||||
db_error = retry_err
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save session {session.session_id} to database: {e}"
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from prisma.errors import UniqueViolationError
|
||||
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
@@ -117,3 +120,92 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
|
||||
loaded.tool_calls is not None
|
||||
), f"Tool calls missing for {orig.role} message"
|
||||
assert len(orig.tool_calls) == len(loaded.tool_calls)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_upsert_handles_concurrent_saves(setup_test_user, test_user_id):
    """Check incremental persistence: save two messages, append two more, save again.

    After the second save the cache entry is removed so that the session
    has to be loaded from the database, and all four messages must come
    back in their original order.
    """
    from backend.data.redis_client import get_redis_async

    # First save: a fresh session holding the opening exchange.
    session = ChatSession.new(user_id=test_user_id)
    session.messages = [
        ChatMessage(content="First message", role="user"),
        ChatMessage(content="First reply", role="assistant"),
    ]
    session = await upsert_chat_session(session)

    # Second save: only the newly appended messages are unsaved.
    session.messages.extend(
        [
            ChatMessage(content="Second message", role="user"),
            ChatMessage(content="Second reply", role="assistant"),
        ]
    )
    session = await upsert_chat_session(session)

    # Drop the cached copy so the read below is served by the database.
    redis = await get_redis_async()
    await redis.delete(f"chat:session:{session.session_id}")

    loaded = await get_chat_session(
        session_id=session.session_id, user_id=session.user_id
    )
    assert loaded is not None, "Session not found after incremental save"
    assert len(loaded.messages) == 4, f"Expected 4 messages, got {len(loaded.messages)}"

    # Every message must round-trip with its content, in insertion order.
    assert [message.content for message in loaded.messages] == [
        "First message",
        "First reply",
        "Second message",
        "Second reply",
    ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_upsert_retries_on_unique_violation(setup_test_user, test_user_id):
    """Test that upsert_chat_session retries when UniqueViolationError is raised.

    The first call to ``chat_db.add_chat_messages_batch`` is forced to raise
    ``UniqueViolationError``; every later call is delegated to the real
    implementation. The test passes only if ``upsert_chat_session`` returns
    the session AND the batch insert was attempted a second time.
    """
    from . import db as chat_db

    # Create a session with one message and persist it.
    s = ChatSession.new(user_id=test_user_id)
    s.messages = [
        ChatMessage(content="Hello", role="user"),
    ]
    s = await upsert_chat_session(s)

    # Add a new message that has not been persisted yet.
    s.messages.append(ChatMessage(content="World", role="assistant"))

    # Fail the first batch insert with UniqueViolationError (as if another
    # pod had claimed the same sequence number), then fall through to the
    # real implementation on subsequent calls.
    original_batch = chat_db.add_chat_messages_batch
    call_count = 0

    async def mock_batch(session_id, messages, start_sequence):
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            raise UniqueViolationError(
                {
                    "error": "Unique constraint failed on the fields: (sessionId, sequence)"
                }
            )
        return await original_batch(session_id, messages, start_sequence)

    # get_chat_session_message_count is deliberately left unpatched: the
    # real count (one saved message vs. two in the session) is what makes
    # the retry branch run.
    with patch.object(chat_db, "add_chat_messages_batch", side_effect=mock_batch):
        s = await upsert_chat_session(s)

    # Verify the session completed without error
    assert s is not None
    assert len(s.messages) == 2

    # Without this check the test would also pass if the error were merely
    # swallowed and no retry was ever attempted.
    assert call_count == 2, (
        f"Expected one failed attempt plus one retry, got {call_count} "
        f"call(s) to add_chat_messages_batch"
    )
|
||||
|
||||
Reference in New Issue
Block a user