mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(copilot): update shared counter after collision detection
When collision detection in add_chat_messages_batch retries with a higher sequence number, the actual persisted message count may differ from len(session.messages). This commit ensures the shared counter (saved_msg_count_ref) used by the streaming loop and long-running callback stays synchronized with the actual DB state. Changes: - Modified add_chat_messages_batch to return tuple[list[ChatMessage], int] where the int is the final message count after collision resolution - Updated _save_session_to_db and upsert_chat_session to propagate the final count up the call chain - Updated all callers in sdk/service.py to use the returned count instead of len(session.messages) when updating saved_msg_count_ref - Updated all other callers in service.py and tests to handle tuple return
This commit is contained in:
@@ -132,7 +132,7 @@ async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[ChatMessage]:
|
||||
) -> tuple[list[ChatMessage], int]:
|
||||
"""Add multiple messages to a chat session in a batch.
|
||||
|
||||
Uses collision detection with retry: tries to create messages starting
|
||||
@@ -140,9 +140,16 @@ async def add_chat_messages_batch(
|
||||
streaming loop and long-running callback race), queries MAX(sequence)
|
||||
and retries with the correct next sequence number. This avoids
|
||||
unnecessary upserts and DB queries in the common case (no collision).
|
||||
|
||||
Returns:
|
||||
Tuple of (messages, final_message_count) where final_message_count
|
||||
is the total number of messages in the session after insertion.
|
||||
This allows callers to update their counters even when collision
|
||||
detection adjusts start_sequence.
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
# No messages to add - return current count
|
||||
return [], start_sequence
|
||||
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
@@ -188,7 +195,9 @@ async def add_chat_messages_batch(
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return [ChatMessage.from_db(m) for m in created_messages]
|
||||
# Return messages and final message count (for shared counter sync)
|
||||
final_count = start_sequence + len(messages)
|
||||
return [ChatMessage.from_db(m) for m in created_messages], final_count
|
||||
|
||||
except Exception as e:
|
||||
# Check if it's a unique constraint violation
|
||||
|
||||
@@ -436,7 +436,7 @@ async def upsert_chat_session(
|
||||
session: ChatSession,
|
||||
*,
|
||||
existing_message_count: int | None = None,
|
||||
) -> ChatSession:
|
||||
) -> tuple[ChatSession, int]:
|
||||
"""Update a chat session in both cache and database.
|
||||
|
||||
Uses session-level locking to prevent race conditions when concurrent
|
||||
@@ -449,6 +449,10 @@ async def upsert_chat_session(
|
||||
accurately. Useful for incremental saves in a streaming loop
|
||||
where the caller already knows how many messages are persisted.
|
||||
|
||||
Returns:
|
||||
Tuple of (session, final_message_count) where final_message_count is
|
||||
the actual persisted message count after collision detection adjustments.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If the database write fails. The cache is still updated
|
||||
as a best-effort optimization, but the error is propagated to ensure
|
||||
@@ -466,10 +470,11 @@ async def upsert_chat_session(
|
||||
)
|
||||
|
||||
db_error: Exception | None = None
|
||||
final_count = existing_message_count
|
||||
|
||||
# Save to database (primary storage)
|
||||
try:
|
||||
await _save_session_to_db(
|
||||
final_count = await _save_session_to_db(
|
||||
session,
|
||||
existing_message_count,
|
||||
skip_existence_check=existing_message_count > 0,
|
||||
@@ -500,7 +505,7 @@ async def upsert_chat_session(
|
||||
f"Failed to persist chat session {session.session_id} to database"
|
||||
) from db_error
|
||||
|
||||
return session
|
||||
return session, final_count
|
||||
|
||||
|
||||
async def _save_session_to_db(
|
||||
@@ -508,13 +513,16 @@ async def _save_session_to_db(
|
||||
existing_message_count: int,
|
||||
*,
|
||||
skip_existence_check: bool = False,
|
||||
) -> None:
|
||||
) -> int:
|
||||
"""Save or update a chat session in the database.
|
||||
|
||||
Args:
|
||||
skip_existence_check: When True, skip the ``get_chat_session`` query
|
||||
and assume the session row already exists. Saves one DB round trip
|
||||
for incremental saves during streaming.
|
||||
|
||||
Returns:
|
||||
Final message count after save (accounting for collision detection).
|
||||
"""
|
||||
db = chat_db()
|
||||
|
||||
@@ -546,6 +554,7 @@ async def _save_session_to_db(
|
||||
|
||||
# Add new messages (only those after existing count)
|
||||
new_messages = session.messages[existing_message_count:]
|
||||
final_count = existing_message_count
|
||||
if new_messages:
|
||||
messages_data = []
|
||||
for msg in new_messages:
|
||||
@@ -565,12 +574,14 @@ async def _save_session_to_db(
|
||||
f"roles={[m['role'] for m in messages_data]}, "
|
||||
f"start_sequence={existing_message_count}"
|
||||
)
|
||||
await db.add_chat_messages_batch(
|
||||
_, final_count = await db.add_chat_messages_batch(
|
||||
session_id=session.session_id,
|
||||
messages=messages_data,
|
||||
start_sequence=existing_message_count,
|
||||
)
|
||||
|
||||
return final_count
|
||||
|
||||
|
||||
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
|
||||
"""Atomically append a message to a session and persist it.
|
||||
|
||||
@@ -60,7 +60,7 @@ async def test_chatsession_redis_storage(setup_test_user, test_user_id):
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages
|
||||
|
||||
s = await upsert_chat_session(s)
|
||||
s, _ = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(
|
||||
session_id=s.session_id,
|
||||
@@ -77,7 +77,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
|
||||
|
||||
s = ChatSession.new(user_id=test_user_id)
|
||||
s.messages = messages
|
||||
s = await upsert_chat_session(s)
|
||||
s, _ = await upsert_chat_session(s)
|
||||
|
||||
s2 = await get_chat_session(s.session_id, "different_user_id")
|
||||
|
||||
@@ -94,7 +94,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
|
||||
s.messages = messages # Contains user, assistant, and tool messages
|
||||
assert s.session_id is not None, "Session id is not set"
|
||||
# Upsert to save to both cache and DB
|
||||
s = await upsert_chat_session(s)
|
||||
s, _ = await upsert_chat_session(s)
|
||||
|
||||
# Clear the Redis cache to force DB load
|
||||
redis_key = f"chat:session:{s.session_id}"
|
||||
|
||||
@@ -272,10 +272,10 @@ def _build_long_running_callback(
|
||||
)
|
||||
session.messages.append(pending_message)
|
||||
# Collision detection happens in add_chat_messages_batch (db.py)
|
||||
await upsert_chat_session(session)
|
||||
_, final_count = await upsert_chat_session(session)
|
||||
# Update shared counter so streaming loop stays in sync
|
||||
if saved_msg_count_ref is not None:
|
||||
saved_msg_count_ref[0] = len(session.messages)
|
||||
saved_msg_count_ref[0] = final_count
|
||||
|
||||
# --- Spawn background task (reuses non-SDK infrastructure) ---
|
||||
bg_task = asyncio.create_task(
|
||||
@@ -610,7 +610,7 @@ async def stream_chat_completion_sdk(
|
||||
user_id=user_id, session_id=session_id, message_length=len(message)
|
||||
)
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
session, _ = await upsert_chat_session(session)
|
||||
|
||||
# Generate title for new sessions (first user message)
|
||||
if is_user_message and not session.title:
|
||||
@@ -986,9 +986,9 @@ async def stream_chat_completion_sdk(
|
||||
# other devices. Collision detection happens
|
||||
# in add_chat_messages_batch (db.py).
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
_, final_count = await upsert_chat_session(session)
|
||||
# Update shared ref so callback stays in sync
|
||||
saved_msg_count_ref[0] = len(session.messages)
|
||||
saved_msg_count_ref[0] = final_count
|
||||
except Exception as save_err:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Incremental save " "failed: %s",
|
||||
@@ -1013,9 +1013,9 @@ async def stream_chat_completion_sdk(
|
||||
# visible on refresh / other devices.
|
||||
# Collision detection happens in add_chat_messages_batch (db.py).
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
_, final_count = await upsert_chat_session(session)
|
||||
# Update shared ref so callback stays in sync
|
||||
saved_msg_count_ref[0] = len(session.messages)
|
||||
saved_msg_count_ref[0] = final_count
|
||||
except Exception as save_err:
|
||||
logger.warning(
|
||||
"[SDK] [%s] Incremental save " "failed: %s",
|
||||
@@ -1147,11 +1147,12 @@ async def stream_chat_completion_sdk(
|
||||
"to use the OpenAI-compatible fallback."
|
||||
)
|
||||
|
||||
await asyncio.shield(upsert_chat_session(session))
|
||||
_, final_count = await asyncio.shield(upsert_chat_session(session))
|
||||
logger.info(
|
||||
"[SDK] [%s] Session saved with %d messages",
|
||||
"[SDK] [%s] Session saved with %d messages (DB count: %d)",
|
||||
session_id[:12],
|
||||
len(session.messages),
|
||||
final_count,
|
||||
)
|
||||
if not stream_completed:
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -352,7 +352,8 @@ async def assign_user_to_session(
|
||||
if not session:
|
||||
raise NotFoundError(f"Session {session_id} not found")
|
||||
session.user_id = user_id
|
||||
return await upsert_chat_session(session)
|
||||
session, _ = await upsert_chat_session(session)
|
||||
return session
|
||||
|
||||
|
||||
async def stream_chat_completion(
|
||||
@@ -463,7 +464,7 @@ async def stream_chat_completion(
|
||||
)
|
||||
|
||||
upsert_start = time.monotonic()
|
||||
session = await upsert_chat_session(session)
|
||||
session, _ = await upsert_chat_session(session)
|
||||
upsert_time = (time.monotonic() - upsert_start) * 1000
|
||||
logger.info(
|
||||
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
||||
@@ -689,7 +690,7 @@ async def stream_chat_completion(
|
||||
f"tool_responses={len(tool_response_messages)}"
|
||||
)
|
||||
if messages_to_save_early or has_appended_streaming_message:
|
||||
await upsert_chat_session(session)
|
||||
_ = await upsert_chat_session(session)
|
||||
has_saved_assistant_message = True
|
||||
|
||||
has_yielded_end = True
|
||||
@@ -728,7 +729,7 @@ async def stream_chat_completion(
|
||||
if tool_response_messages:
|
||||
session.messages.extend(tool_response_messages)
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
_ = await upsert_chat_session(session)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to save interrupted session {session.session_id}: {e}"
|
||||
@@ -769,7 +770,7 @@ async def stream_chat_completion(
|
||||
if messages_to_save:
|
||||
session.messages.extend(messages_to_save)
|
||||
if messages_to_save or has_appended_streaming_message:
|
||||
await upsert_chat_session(session)
|
||||
_ = await upsert_chat_session(session)
|
||||
|
||||
if not has_yielded_error:
|
||||
error_message = str(e)
|
||||
@@ -853,7 +854,7 @@ async def stream_chat_completion(
|
||||
not has_long_running_tool_call
|
||||
and (messages_to_save or has_appended_streaming_message)
|
||||
):
|
||||
await upsert_chat_session(session)
|
||||
_ = await upsert_chat_session(session)
|
||||
else:
|
||||
logger.info(
|
||||
"Assistant message already saved when StreamFinish was received, "
|
||||
@@ -1525,7 +1526,7 @@ async def _yield_tool_call(
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
session.messages.append(pending_message)
|
||||
await upsert_chat_session(session)
|
||||
_ = await upsert_chat_session(session)
|
||||
|
||||
await _with_optional_lock(session_lock, _save_pending)
|
||||
logger.info(
|
||||
@@ -2019,7 +2020,7 @@ async def _generate_llm_continuation(
|
||||
fresh_session.messages.append(assistant_message)
|
||||
|
||||
# Save to database (not cache) to persist the response
|
||||
await upsert_chat_session(fresh_session)
|
||||
_ = await upsert_chat_session(fresh_session)
|
||||
|
||||
# Invalidate cache so next poll/refresh gets fresh data
|
||||
await invalidate_session_cache(session_id)
|
||||
@@ -2225,7 +2226,7 @@ async def _generate_llm_continuation_with_streaming(
|
||||
fresh_session.messages.append(assistant_message)
|
||||
|
||||
# Save to database (not cache) to persist the response
|
||||
await upsert_chat_session(fresh_session)
|
||||
_ = await upsert_chat_session(fresh_session)
|
||||
|
||||
# Invalidate cache so next poll/refresh gets fresh data
|
||||
await invalidate_session_cache(session_id)
|
||||
|
||||
@@ -58,7 +58,7 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
|
||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
session, _ = await upsert_chat_session(session)
|
||||
|
||||
has_errors = False
|
||||
has_ended = False
|
||||
@@ -104,7 +104,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
||||
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
|
||||
|
||||
session = await create_chat_session(test_user_id)
|
||||
session = await upsert_chat_session(session)
|
||||
session, _ = await upsert_chat_session(session)
|
||||
|
||||
# --- Turn 1: send a message with a unique keyword ---
|
||||
keyword = "ZEPHYR42"
|
||||
|
||||
Reference in New Issue
Block a user