Compare commits

...

2 Commits

Author SHA1 Message Date
Swifty
45a1c522a8 Merge branch 'dev' into swiftyos/secrt-1905-bug-chat-session-persistence-race-condition-unique 2026-02-06 10:19:29 +01:00
Swifty
6172f7b1f5 fix(backend): resolve chat session persistence race condition on sequence uniqueness
Compute message sequence numbers inside the database transaction rather
than trusting a pre-queried count, preventing UniqueViolationError when
multiple pods insert messages concurrently. Adds retry logic at both the
db batch-insert layer and the model upsert layer for defense in depth.
2026-02-06 09:55:15 +01:00
3 changed files with 219 additions and 38 deletions

View File

@@ -5,6 +5,7 @@ import logging
from datetime import UTC, datetime
from typing import Any, cast
from prisma.errors import UniqueViolationError
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
@@ -135,6 +136,28 @@ async def add_chat_message(
return message
async def get_max_message_sequence(session_id: str, tx: Any = None) -> int:
"""Get the highest sequence number for a session's messages.
Args:
session_id: The chat session ID.
tx: Optional transaction client for running inside a transaction.
Returns:
The max sequence number, or -1 if no messages exist
(so that max + 1 = 0 for the first message).
"""
client = PrismaChatMessage.prisma(tx) if tx else PrismaChatMessage.prisma()
results = await client.find_many(
where={"sessionId": session_id},
order={"sequence": "desc"},
take=1,
)
if results:
return results[0].sequence
return -1
async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
@@ -143,54 +166,88 @@ async def add_chat_messages_batch(
"""Add multiple messages to a chat session in a batch.
Uses a transaction for atomicity - if any message creation fails,
the entire batch is rolled back.
the entire batch is rolled back. Computes the actual start sequence
inside the transaction to prevent race conditions in multi-pod deployments.
Retries once on UniqueViolationError (another pod may have inserted
messages with the same sequence numbers concurrently).
"""
if not messages:
return []
created_messages = []
max_attempts = 2
for attempt in range(1, max_attempts + 1):
try:
created_messages = []
async with transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
async with transaction() as tx:
# Compute authoritative start sequence inside the transaction
actual_max = await get_max_message_sequence(session_id, tx)
actual_start = actual_max + 1
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
if actual_start != start_sequence:
logger.warning(
f"Sequence adjustment for session {session_id}: "
f"caller provided start_sequence={start_sequence}, "
f"but DB max sequence is {actual_max} "
f"(using {actual_start} instead)"
)
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using
# ChatMessageCreateInput directly because Prisma's TypedDict
# validation rejects optional fields set to None.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": actual_start + i,
}
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage is updated separately via update_chat_session().
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
return created_messages
except UniqueViolationError:
if attempt < max_attempts:
logger.warning(
f"UniqueViolationError on attempt {attempt} for session "
f"{session_id}, retrying with fresh sequence"
)
continue
logger.error(
f"UniqueViolationError persisted after {max_attempts} attempts "
f"for session {session_id}"
)
created_messages.append(created)
raise
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
return created_messages
# Unreachable, but satisfies type checker
return []
async def get_user_chat_sessions(

View File

@@ -19,6 +19,7 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
ChatCompletionMessageToolCallParam,
Function,
)
from prisma.errors import UniqueViolationError
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel
@@ -468,6 +469,37 @@ async def upsert_chat_session(
# Save to database (primary storage)
try:
await _save_session_to_db(session, existing_message_count)
except UniqueViolationError:
# Another pod likely saved the same messages concurrently.
# Re-query the message count and retry if unsaved messages remain.
logger.warning(
f"UniqueViolationError saving session {session.session_id}, "
f"re-querying message count for retry"
)
try:
fresh_count = await chat_db.get_chat_session_message_count(
session.session_id
)
if fresh_count < len(session.messages):
logger.info(
f"Retrying save for session {session.session_id}: "
f"fresh_count={fresh_count}, "
f"total_messages={len(session.messages)}"
)
await _save_session_to_db(session, fresh_count)
else:
logger.info(
f"All messages already saved for session "
f"{session.session_id} by another process "
f"(db_count={fresh_count}, "
f"session_count={len(session.messages)})"
)
except Exception as retry_err:
logger.error(
f"Retry also failed for session {session.session_id}: "
f"{retry_err}"
)
db_error = retry_err
except Exception as e:
logger.error(
f"Failed to save session {session.session_id} to database: {e}"

View File

@@ -1,4 +1,7 @@
from unittest.mock import patch
import pytest
from prisma.errors import UniqueViolationError
from .model import (
ChatMessage,
@@ -117,3 +120,92 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
loaded.tool_calls is not None
), f"Tool calls missing for {orig.role} message"
assert len(orig.tool_calls) == len(loaded.tool_calls)
@pytest.mark.asyncio(loop_scope="session")
async def test_upsert_handles_concurrent_saves(setup_test_user, test_user_id):
"""Test that incremental saves work: save initial messages, add more, save again."""
from backend.data.redis_client import get_redis_async
# Create session with initial messages
s = ChatSession.new(user_id=test_user_id)
s.messages = [
ChatMessage(content="First message", role="user"),
ChatMessage(content="First reply", role="assistant"),
]
s = await upsert_chat_session(s)
# Add more messages and save again (incremental)
s.messages.append(ChatMessage(content="Second message", role="user"))
s.messages.append(ChatMessage(content="Second reply", role="assistant"))
s = await upsert_chat_session(s)
# Clear cache and verify all messages round-trip from DB
redis_key = f"chat:session:{s.session_id}"
async_redis = await get_redis_async()
await async_redis.delete(redis_key)
loaded = await get_chat_session(session_id=s.session_id, user_id=s.user_id)
assert loaded is not None, "Session not found after incremental save"
assert len(loaded.messages) == 4, f"Expected 4 messages, got {len(loaded.messages)}"
# Verify content of all messages
assert loaded.messages[0].content == "First message"
assert loaded.messages[1].content == "First reply"
assert loaded.messages[2].content == "Second message"
assert loaded.messages[3].content == "Second reply"
@pytest.mark.asyncio(loop_scope="session")
async def test_upsert_retries_on_unique_violation(setup_test_user, test_user_id):
"""Test that upsert_chat_session retries when UniqueViolationError is raised."""
from . import db as chat_db
# Create a session with initial messages
s = ChatSession.new(user_id=test_user_id)
s.messages = [
ChatMessage(content="Hello", role="user"),
]
s = await upsert_chat_session(s)
# Add a new message
s.messages.append(ChatMessage(content="World", role="assistant"))
# Mock add_chat_messages_batch to raise UniqueViolationError on first call,
# then succeed on second call (simulating another pod saving concurrently).
original_batch = chat_db.add_chat_messages_batch
call_count = 0
async def mock_batch(session_id, messages, start_sequence):
nonlocal call_count
call_count += 1
if call_count == 1:
raise UniqueViolationError(
{
"error": "Unique constraint failed on the fields: (sessionId, sequence)"
}
)
return await original_batch(session_id, messages, start_sequence)
with patch.object(chat_db, "add_chat_messages_batch", side_effect=mock_batch):
# Also mock get_chat_session_message_count to return 1 on retry
# (simulating that the first message was saved by "another pod")
original_count = chat_db.get_chat_session_message_count
count_call = 0
async def mock_count(session_id):
nonlocal count_call
count_call += 1
# First call is the initial count check in upsert_chat_session
# Second call is the retry after UniqueViolationError
return await original_count(session_id)
with patch.object(
chat_db, "get_chat_session_message_count", side_effect=mock_count
):
s = await upsert_chat_session(s)
# Verify the session completed without error
assert s is not None
assert len(s.messages) == 2