refactor(copilot): replace upsert with collision detection for concurrent message saves

- Use create() with MAX(sequence) retry instead of upsert()
- Query DB only on collision (not every save) for better performance
- Remove Layer 2 DB queries from incremental saves in streaming loop
- Add get_max_sequence() helper using raw SQL for robustness
- Collision detection retries up to 3 times on unique constraint errors

This approach:
- Optimizes common case (no collision) - no extra DB queries
- Handles concurrent writes via automatic retry with correct sequence
- Uses MAX(sequence) instead of COUNT for more robust offset calculation
This commit is contained in:
Zamil Majdy
2026-02-20 18:09:34 +07:00
parent 6acefee6f3
commit af491b5511
2 changed files with 100 additions and 100 deletions

View File

@@ -9,7 +9,6 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
ChatMessageCreateInput,
ChatMessageUpdateInput,
ChatSessionCreateInput,
ChatSessionUpdateInput,
ChatSessionWhereInput,
@@ -136,86 +135,86 @@ async def add_chat_messages_batch(
) -> list[ChatMessage]:
"""Add multiple messages to a chat session in a batch.
Uses a transaction for atomicity. Each message is upserted by the
(sessionId, sequence) composite key so that concurrent writers
(e.g., the streaming loop racing with a long-running tool callback)
can overlap without triggering a unique-constraint violation;
the later write wins and overwrites the existing row. Non-duplicate
failures still roll back the entire batch.
Uses collision detection with retry: tries to create messages starting
at start_sequence. If a unique constraint violation occurs (e.g., the
streaming loop and long-running callback race), queries MAX(sequence)
and retries with the correct next sequence number. This avoids
unnecessary upserts and DB queries in the common case (no collision).
"""
if not messages:
return []
created_messages = []
async with db.transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
# Use upsert to handle concurrent writers (e.g. incremental
# streaming saves racing with long-running tool callbacks) that
# may produce duplicate (sessionId, sequence) pairs.
# Explicitly construct update_data (exclude Session relation and sequence)
update_data: dict[str, Any] = {"role": msg["role"]}
if msg.get("content") is not None:
update_data["content"] = msg["content"]
if msg.get("name") is not None:
update_data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
update_data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
update_data["refusal"] = msg["refusal"]
if msg.get("tool_calls") is not None:
update_data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
update_data["functionCall"] = SafeJson(msg["function_call"])
created = await PrismaChatMessage.prisma(tx).upsert(
where={
"sessionId_sequence": {
"sessionId": session_id,
max_retries = 3
for attempt in range(max_retries):
try:
created_messages = []
async with db.transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
},
data={
"create": cast(ChatMessageCreateInput, data),
"update": cast(ChatMessageUpdateInput, update_data),
},
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
return [ChatMessage.from_db(m) for m in created_messages]
except Exception as e:
# Check if it's a unique constraint violation
error_msg = str(e).lower()
is_unique_constraint = (
"unique constraint" in error_msg or "duplicate key" in error_msg
)
created_messages.append(created)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
if is_unique_constraint and attempt < max_retries - 1:
# Collision detected - query MAX(sequence) and retry with correct offset
logger.info(
f"Collision detected for session {session_id} at sequence "
f"{start_sequence}, querying DB for latest sequence"
)
max_seq = await get_max_sequence(session_id)
start_sequence = max_seq + 1
logger.info(
f"Retrying batch insert with start_sequence={start_sequence}"
)
continue
else:
# Not a collision or max retries exceeded - propagate error
raise
return [ChatMessage.from_db(m) for m in created_messages]
# Should never reach here due to raise in exception handler
raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
async def get_user_chat_sessions(
@@ -275,6 +274,25 @@ async def get_chat_session_message_count(session_id: str) -> int:
return count
async def get_max_sequence(session_id: str) -> int:
    """Return the highest message sequence number for a chat session.

    Used by the collision-detection retry path in ``add_chat_messages_batch``:
    when concurrent writers race on the (sessionId, sequence) unique key, the
    caller re-queries this value and retries with ``max + 1``.

    Args:
        session_id: The chat session whose messages are inspected.

    Returns:
        The maximum ``sequence`` value among the session's messages, or -1 if
        the session has no messages.
    """
    # Raw SQL rather than an ORM aggregate for robustness; COALESCE ensures a
    # well-defined -1 even when the session has zero rows, so the caller can
    # always compute the next sequence as max + 1.
    result = await db.prisma.query_raw(
        """
        SELECT COALESCE(MAX(sequence), -1) as max_seq
        FROM "ChatMessage"
        WHERE "sessionId" = $1
        """,
        session_id,
    )
    # `not result` already covers both None and an empty list; the original
    # extra `len(result) == 0` check was redundant.
    if not result:
        return -1
    return int(result[0]["max_seq"])
async def update_tool_message_content(
session_id: str,
tool_call_id: str,

View File

@@ -11,7 +11,6 @@ from typing import Any
from backend.util.exceptions import NotFoundError
from .. import db as chat_db
from .. import stream_registry
from ..config import ChatConfig
from ..model import (
@@ -217,10 +216,9 @@ def _build_long_running_callback(
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
# Layer 2: Query DB for latest count before save (defense against stale counter)
db_count = await chat_db.get_chat_session_message_count(session_id)
await upsert_chat_session(session, existing_message_count=db_count)
# Layer 3: Update shared counter so streaming loop stays in sync
# Collision detection happens in add_chat_messages_batch (db.py)
await upsert_chat_session(session)
# Update shared counter so streaming loop stays in sync
if saved_msg_count_ref is not None:
saved_msg_count_ref[0] = len(session.messages)
@@ -913,19 +911,11 @@ async def stream_chat_completion_sdk(
has_appended_assistant = True
# Save before tool execution starts so the
# pending tool call is visible on refresh /
# other devices.
# other devices. Collision detection happens
# in add_chat_messages_batch (db.py).
try:
# Layer 2: Query DB for latest count (defense against stale counter)
db_count = (
await chat_db.get_chat_session_message_count(
session_id
)
)
await upsert_chat_session(
session,
existing_message_count=db_count,
)
# Layer 3: Update shared ref so callback stays in sync
await upsert_chat_session(session)
# Update shared ref so callback stays in sync
saved_msg_count_ref[0] = len(session.messages)
except Exception as save_err:
logger.warning(
@@ -949,18 +939,10 @@ async def stream_chat_completion_sdk(
has_tool_results = True
# Save after tool completes so the result is
# visible on refresh / other devices.
# Collision detection happens in add_chat_messages_batch (db.py).
try:
# Layer 2: Query DB for latest count (defense against stale counter)
db_count = (
await chat_db.get_chat_session_message_count(
session_id
)
)
await upsert_chat_session(
session,
existing_message_count=db_count,
)
# Layer 3: Update shared ref so callback stays in sync
await upsert_chat_session(session)
# Update shared ref so callback stays in sync
saved_msg_count_ref[0] = len(session.messages)
except Exception as save_err:
logger.warning(