mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
refactor(copilot): replace upsert with collision detection for concurrent message saves
- Use create() with MAX(sequence) retry instead of upsert() - Query DB only on collision (not every save) for better performance - Remove Layer 2 DB queries from incremental saves in streaming loop - Add get_max_sequence() helper using raw SQL for robustness - Collision detection retries up to 3 times on unique constraint errors This approach: - Optimizes common case (no collision) - no extra DB queries - Handles concurrent writes via automatic retry with correct sequence - Uses MAX(sequence) instead of COUNT for more robust offset calculation
This commit is contained in:
@@ -9,7 +9,6 @@ from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatMessageUpdateInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
ChatSessionWhereInput,
|
||||
@@ -136,86 +135,86 @@ async def add_chat_messages_batch(
|
||||
) -> list[ChatMessage]:
|
||||
"""Add multiple messages to a chat session in a batch.
|
||||
|
||||
Uses a transaction for atomicity. Each message is upserted by the
|
||||
(sessionId, sequence) composite key so that concurrent writers
|
||||
(e.g., the streaming loop racing with a long-running tool callback)
|
||||
can overlap without triggering a unique-constraint violation;
|
||||
the later write wins and overwrites the existing row. Non-duplicate
|
||||
failures still roll back the entire batch.
|
||||
Uses collision detection with retry: tries to create messages starting
|
||||
at start_sequence. If a unique constraint violation occurs (e.g., the
|
||||
streaming loop and long-running callback race), queries MAX(sequence)
|
||||
and retries with the correct next sequence number. This avoids
|
||||
unnecessary upserts and DB queries in the common case (no collision).
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
created_messages = []
|
||||
|
||||
async with db.transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
# set to None. We only include fields that have values, then cast.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
# Use upsert to handle concurrent writers (e.g. incremental
|
||||
# streaming saves racing with long-running tool callbacks) that
|
||||
# may produce duplicate (sessionId, sequence) pairs.
|
||||
# Explicitly construct update_data (exclude Session relation and sequence)
|
||||
update_data: dict[str, Any] = {"role": msg["role"]}
|
||||
if msg.get("content") is not None:
|
||||
update_data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
update_data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
update_data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
update_data["refusal"] = msg["refusal"]
|
||||
if msg.get("tool_calls") is not None:
|
||||
update_data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
update_data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).upsert(
|
||||
where={
|
||||
"sessionId_sequence": {
|
||||
"sessionId": session_id,
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
created_messages = []
|
||||
async with db.transaction() as tx:
|
||||
for i, msg in enumerate(messages):
|
||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||
# directly because Prisma's TypedDict validation rejects optional fields
|
||||
# set to None. We only include fields that have values, then cast.
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
},
|
||||
data={
|
||||
"create": cast(ChatMessageCreateInput, data),
|
||||
"update": cast(ChatMessageUpdateInput, update_data),
|
||||
},
|
||||
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma(tx).create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||
# separately via update_chat_session() after streaming completes.
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return [ChatMessage.from_db(m) for m in created_messages]
|
||||
|
||||
except Exception as e:
|
||||
# Check if it's a unique constraint violation
|
||||
error_msg = str(e).lower()
|
||||
is_unique_constraint = (
|
||||
"unique constraint" in error_msg or "duplicate key" in error_msg
|
||||
)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp within the same transaction.
|
||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||
# separately via update_chat_session() after streaming completes.
|
||||
await PrismaChatSession.prisma(tx).update(
|
||||
where={"id": session_id},
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
if is_unique_constraint and attempt < max_retries - 1:
|
||||
# Collision detected - query MAX(sequence) and retry with correct offset
|
||||
logger.info(
|
||||
f"Collision detected for session {session_id} at sequence "
|
||||
f"{start_sequence}, querying DB for latest sequence"
|
||||
)
|
||||
max_seq = await get_max_sequence(session_id)
|
||||
start_sequence = max_seq + 1
|
||||
logger.info(
|
||||
f"Retrying batch insert with start_sequence={start_sequence}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Not a collision or max retries exceeded - propagate error
|
||||
raise
|
||||
|
||||
return [ChatMessage.from_db(m) for m in created_messages]
|
||||
# Should never reach here due to raise in exception handler
|
||||
raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
|
||||
|
||||
|
||||
async def get_user_chat_sessions(
|
||||
@@ -275,6 +274,25 @@ async def get_chat_session_message_count(session_id: str) -> int:
|
||||
return count
|
||||
|
||||
|
||||
async def get_max_sequence(session_id: str) -> int:
|
||||
"""Get the maximum sequence number for a session.
|
||||
|
||||
Returns the highest sequence number, or -1 if no messages exist.
|
||||
This is used for collision detection when concurrent writers race.
|
||||
"""
|
||||
result = await db.prisma.query_raw(
|
||||
"""
|
||||
SELECT COALESCE(MAX(sequence), -1) as max_seq
|
||||
FROM "ChatMessage"
|
||||
WHERE "sessionId" = $1
|
||||
""",
|
||||
session_id,
|
||||
)
|
||||
if not result or len(result) == 0:
|
||||
return -1
|
||||
return int(result[0]["max_seq"])
|
||||
|
||||
|
||||
async def update_tool_message_content(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
|
||||
@@ -11,7 +11,6 @@ from typing import Any
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .. import db as chat_db
|
||||
from .. import stream_registry
|
||||
from ..config import ChatConfig
|
||||
from ..model import (
|
||||
@@ -217,10 +216,9 @@ def _build_long_running_callback(
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
session.messages.append(pending_message)
|
||||
# Layer 2: Query DB for latest count before save (defense against stale counter)
|
||||
db_count = await chat_db.get_chat_session_message_count(session_id)
|
||||
await upsert_chat_session(session, existing_message_count=db_count)
|
||||
# Layer 3: Update shared counter so streaming loop stays in sync
|
||||
# Collision detection happens in add_chat_messages_batch (db.py)
|
||||
await upsert_chat_session(session)
|
||||
# Update shared counter so streaming loop stays in sync
|
||||
if saved_msg_count_ref is not None:
|
||||
saved_msg_count_ref[0] = len(session.messages)
|
||||
|
||||
@@ -913,19 +911,11 @@ async def stream_chat_completion_sdk(
|
||||
has_appended_assistant = True
|
||||
# Save before tool execution starts so the
|
||||
# pending tool call is visible on refresh /
|
||||
# other devices.
|
||||
# other devices. Collision detection happens
|
||||
# in add_chat_messages_batch (db.py).
|
||||
try:
|
||||
# Layer 2: Query DB for latest count (defense against stale counter)
|
||||
db_count = (
|
||||
await chat_db.get_chat_session_message_count(
|
||||
session_id
|
||||
)
|
||||
)
|
||||
await upsert_chat_session(
|
||||
session,
|
||||
existing_message_count=db_count,
|
||||
)
|
||||
# Layer 3: Update shared ref so callback stays in sync
|
||||
await upsert_chat_session(session)
|
||||
# Update shared ref so callback stays in sync
|
||||
saved_msg_count_ref[0] = len(session.messages)
|
||||
except Exception as save_err:
|
||||
logger.warning(
|
||||
@@ -949,18 +939,10 @@ async def stream_chat_completion_sdk(
|
||||
has_tool_results = True
|
||||
# Save after tool completes so the result is
|
||||
# visible on refresh / other devices.
|
||||
# Collision detection happens in add_chat_messages_batch (db.py).
|
||||
try:
|
||||
# Layer 2: Query DB for latest count (defense against stale counter)
|
||||
db_count = (
|
||||
await chat_db.get_chat_session_message_count(
|
||||
session_id
|
||||
)
|
||||
)
|
||||
await upsert_chat_session(
|
||||
session,
|
||||
existing_message_count=db_count,
|
||||
)
|
||||
# Layer 3: Update shared ref so callback stays in sync
|
||||
await upsert_chat_session(session)
|
||||
# Update shared ref so callback stays in sync
|
||||
saved_msg_count_ref[0] = len(session.messages)
|
||||
except Exception as save_err:
|
||||
logger.warning(
|
||||
|
||||
Reference in New Issue
Block a user