Compare commits

...

12 Commits

Author SHA1 Message Date
Zamil Majdy
bac7b9efb9 fix(copilot): update shared counter after collision detection
When collision detection in add_chat_messages_batch retries with a higher
sequence number, the actual persisted message count may differ from
len(session.messages). This commit ensures the shared counter
(saved_msg_count_ref) used by the streaming loop and long-running callback
stays synchronized with the actual DB state.

Changes:
- Modified add_chat_messages_batch to return tuple[list[ChatMessage], int]
  where the int is the final message count after collision resolution
- Updated _save_session_to_db and upsert_chat_session to propagate the
  final count up the call chain
- Updated all callers in sdk/service.py to use the returned count instead
  of len(session.messages) when updating saved_msg_count_ref
- Updated all other callers in service.py and tests to handle tuple return
2026-02-20 18:58:02 +07:00
Zamil Majdy
6e1941d7ae feat(copilot): implement session locking to prevent concurrent streams
- Add stream_id (using task_id) to uniquely identify each stream
- Acquire exclusive lock (Redis SET NX EX) when starting a stream
- Release lock in finally block using Lua script (atomic compare-and-delete)
- Return error if another stream is already active for the session
- Lock TTL is 1 hour (matches stream_ttl) with automatic cleanup

This prevents:
- Message duplication from concurrent streams
- Race conditions in message saves
- Confusing UX with multiple AI responses
- Frontend reconnecting while existing stream is active
- Multiple browser tabs streaming to same session
2026-02-20 18:28:35 +07:00
Zamil Majdy
129b992059 feat(copilot): increase long-running operation TTL to 1 hour
- Increase long_running_operation_ttl from 600s (10min) to 3600s (1hour)
- Match stream_ttl duration for consistency
- Add clarifying description about deduplication lock purpose

Some operations (like complex agent runs) can take longer than 10 minutes.
The stream_registry heartbeat (publish_chunk) already keeps operations alive,
so this TTL is just a safety net for deduplication.
2026-02-20 18:22:26 +07:00
Zamil Majdy
1b82a55eca chore: remove obsolete plan file
Plan was completed and changes are now in the PR. No need to keep the plan file.
2026-02-20 18:21:00 +07:00
Zamil Majdy
9d4697e859 refactor(copilot): replace COUNT with MAX for sequence tracking
- Rename get_max_sequence() to get_next_sequence() returning MAX+1
- Replace all get_chat_session_message_count() calls with get_next_sequence()
- Remove old get_chat_session_message_count() function
- Update db_manager.py to export get_next_sequence

Using MAX(sequence)+1 is more robust than COUNT(*) because:
- Immune to deleted messages
- Handles gaps in sequence numbers correctly
- Simpler collision detection logic
2026-02-20 18:20:29 +07:00
Zamil Majdy
366547e448 refactor(copilot): remove confusing 'Layer' comments from code
- Remove '(Layer 3: defense-in-depth)' annotations
- Replace with clearer explanations of what the code does
- Makes the code easier to understand without implementation history
2026-02-20 18:18:25 +07:00
Zamil Majdy
af491b5511 refactor(copilot): replace upsert with collision detection for concurrent message saves
- Use create() with MAX(sequence) retry instead of upsert()
- Query DB only on collision (not every save) for better performance
- Remove Layer 2 DB queries from incremental saves in streaming loop
- Add get_max_sequence() helper using raw SQL for robustness
- Collision detection retries up to 3 times on unique constraint errors

This approach:
- Optimizes common case (no collision) - no extra DB queries
- Handles concurrent writes via automatic retry with correct sequence
- Uses MAX(sequence) instead of COUNT for more robust offset calculation
2026-02-20 18:09:34 +07:00
Zamil Majdy
6acefee6f3 fix(copilot): defense-in-depth for concurrent message saves (all 3 layers)
Implements three complementary layers to prevent unique constraint violations
on (sessionId, sequence) caused by concurrent writers during SDK streaming:

**Layer 1: Upsert (already in PR)**
- add_chat_messages_batch uses upsert() instead of create()
- Explicitly constructs update_data excluding Session and sequence
- Final safety net: duplicate sequences update instead of crash

**Layer 2: Query DB Before Each Save (NEW)**
- Query get_chat_session_message_count() before each save
- DB is source of truth, prevents using stale in-memory counter
- Applied to: long-running callback + 2 incremental saves
- Trade-off: Extra COUNT query (~1-2ms), but prevents race

**Layer 3: Shared Counter (NEW)**
- saved_msg_count_ref as mutable list[int] shared between:
  - Streaming loop (incremental saves)
  - Long-running callback (_build_long_running_callback)
- Both writers update it after successful save
- Keeps in-memory tracking accurate for performance

**Why all three:**
- Layer 2 alone: adds DB queries (performance cost)
- Layer 3 alone: doesn't handle external writers
- Layer 1 alone: may silently overwrite data
- Together: correctness + performance + safety net

Files:
- backend/copilot/db.py - Layer 1 (upsert with explicit update_data)
- backend/copilot/sdk/service.py - Layers 2 & 3

Fixes race where long-running tools (create_agent, edit_agent) would
append messages behind streaming loop's back, causing stale counter.

Addresses PR review comments and Discord analysis.
2026-02-20 18:02:00 +07:00
Zamil Majdy
eb4650fbb8 fix(copilot): explicitly construct update_data for better type safety
Instead of filtering from data dict, explicitly build update_data with
only the fields that should be updated. This is safer and makes it
obvious what fields are being updated in the upsert operation.

Addresses PR review comment about exhaustive field construction.
2026-02-20 17:53:52 +07:00
Zamil Majdy
8bdf83128e fix(copilot): address CodeRabbit review - add type safety and exclude sequence from update
- Add ChatMessageUpdateInput import for type-safe update payload
- Exclude both 'Session' and 'sequence' from update_data (sequence is part of composite key)
- Cast update_data to ChatMessageUpdateInput for type checking
- Update docstring to document upsert semantics and idempotency
2026-02-20 17:49:11 +07:00
Zamil Majdy
a1d5b99226 Merge branch 'dev' into otto/fix-chat-messages-batch-upsert 2026-02-20 17:48:02 +07:00
Otto
0450ea5313 fix(copilot): use upsert in add_chat_messages_batch to handle duplicate sequences
Concurrent writers (incremental streaming saves and long-running tool
callbacks) can race to persist messages with the same (sessionId, sequence)
pair, causing unique constraint violations.

Replace prisma create() with upsert() so duplicate sequences update the
existing row instead of failing. This is safe because later writes always
contain the most complete data (e.g. accumulated assistant text).
2026-02-20 09:59:56 +00:00
8 changed files with 247 additions and 91 deletions

View File

@@ -38,8 +38,10 @@ class ChatConfig(BaseSettings):
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=600,
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
default=3600,
description="TTL in seconds for long-running operation deduplication lock "
"(1 hour, matches stream_ttl). Prevents duplicate operations if pod dies. "
"For longer operations, the stream_registry heartbeat keeps them alive.",
)
# Stream registry configuration for SSE reconnection

View File

@@ -132,58 +132,97 @@ async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
start_sequence: int,
) -> list[ChatMessage]:
) -> tuple[list[ChatMessage], int]:
"""Add multiple messages to a chat session in a batch.
Uses a transaction for atomicity - if any message creation fails,
the entire batch is rolled back.
Uses collision detection with retry: tries to create messages starting
at start_sequence. If a unique constraint violation occurs (e.g., the
streaming loop and long-running callback race), queries MAX(sequence)
and retries with the correct next sequence number. This avoids
unnecessary upserts and DB queries in the common case (no collision).
Returns:
Tuple of (messages, final_message_count) where final_message_count
is the total number of messages in the session after insertion.
This allows callers to update their counters even when collision
detection adjusts start_sequence.
"""
if not messages:
return []
# No messages to add - return current count
return [], start_sequence
created_messages = []
max_retries = 3
for attempt in range(max_retries):
try:
created_messages = []
async with db.transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
async with db.transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
# Return messages and final message count (for shared counter sync)
final_count = start_sequence + len(messages)
return [ChatMessage.from_db(m) for m in created_messages], final_count
except Exception as e:
# Check if it's a unique constraint violation
error_msg = str(e).lower()
is_unique_constraint = (
"unique constraint" in error_msg or "duplicate key" in error_msg
)
created_messages.append(created)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
if is_unique_constraint and attempt < max_retries - 1:
# Collision detected - query MAX(sequence)+1 and retry with correct offset
logger.info(
f"Collision detected for session {session_id} at sequence "
f"{start_sequence}, querying DB for latest sequence"
)
start_sequence = await get_next_sequence(session_id)
logger.info(
f"Retrying batch insert with start_sequence={start_sequence}"
)
continue
else:
# Not a collision or max retries exceeded - propagate error
raise
return [ChatMessage.from_db(m) for m in created_messages]
# Should never reach here due to raise in exception handler
raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
async def get_user_chat_sessions(
@@ -237,10 +276,23 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
return False
async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count
async def get_next_sequence(session_id: str) -> int:
"""Get the next sequence number for a new message in this session.
Uses MAX(sequence) + 1 for robustness. Returns 0 if no messages exist.
More robust than COUNT(*) because it's immune to deleted messages.
"""
result = await db.prisma.query_raw(
"""
SELECT COALESCE(MAX(sequence) + 1, 0) as next_seq
FROM "ChatMessage"
WHERE "sessionId" = $1
""",
session_id,
)
if not result or len(result) == 0:
return 0
return int(result[0]["next_seq"])
async def update_tool_message_content(

View File

@@ -436,7 +436,7 @@ async def upsert_chat_session(
session: ChatSession,
*,
existing_message_count: int | None = None,
) -> ChatSession:
) -> tuple[ChatSession, int]:
"""Update a chat session in both cache and database.
Uses session-level locking to prevent race conditions when concurrent
@@ -449,6 +449,10 @@ async def upsert_chat_session(
accurately. Useful for incremental saves in a streaming loop
where the caller already knows how many messages are persisted.
Returns:
Tuple of (session, final_message_count) where final_message_count is
the actual persisted message count after collision detection adjustments.
Raises:
DatabaseError: If the database write fails. The cache is still updated
as a best-effort optimization, but the error is propagated to ensure
@@ -461,15 +465,16 @@ async def upsert_chat_session(
async with lock:
# Get existing message count from DB for incremental saves
if existing_message_count is None:
existing_message_count = await chat_db().get_chat_session_message_count(
existing_message_count = await chat_db().get_next_sequence(
session.session_id
)
db_error: Exception | None = None
final_count = existing_message_count
# Save to database (primary storage)
try:
await _save_session_to_db(
final_count = await _save_session_to_db(
session,
existing_message_count,
skip_existence_check=existing_message_count > 0,
@@ -500,7 +505,7 @@ async def upsert_chat_session(
f"Failed to persist chat session {session.session_id} to database"
) from db_error
return session
return session, final_count
async def _save_session_to_db(
@@ -508,13 +513,16 @@ async def _save_session_to_db(
existing_message_count: int,
*,
skip_existence_check: bool = False,
) -> None:
) -> int:
"""Save or update a chat session in the database.
Args:
skip_existence_check: When True, skip the ``get_chat_session`` query
and assume the session row already exists. Saves one DB round trip
for incremental saves during streaming.
Returns:
Final message count after save (accounting for collision detection).
"""
db = chat_db()
@@ -546,6 +554,7 @@ async def _save_session_to_db(
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
final_count = existing_message_count
if new_messages:
messages_data = []
for msg in new_messages:
@@ -565,12 +574,14 @@ async def _save_session_to_db(
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await db.add_chat_messages_batch(
_, final_count = await db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
return final_count
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
@@ -587,9 +598,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db().get_chat_session_message_count(
session_id
)
existing_message_count = await chat_db().get_next_sequence(session_id)
try:
await _save_session_to_db(session, existing_message_count)

View File

@@ -60,7 +60,7 @@ async def test_chatsession_redis_storage(setup_test_user, test_user_id):
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
s, _ = await upsert_chat_session(s)
s2 = await get_chat_session(
session_id=s.session_id,
@@ -77,7 +77,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
s = ChatSession.new(user_id=test_user_id)
s.messages = messages
s = await upsert_chat_session(s)
s, _ = await upsert_chat_session(s)
s2 = await get_chat_session(s.session_id, "different_user_id")
@@ -94,7 +94,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
s = await upsert_chat_session(s)
s, _ = await upsert_chat_session(s)
# Clear the Redis cache to force DB load
redis_key = f"chat:session:{s.session_id}"

View File

@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from backend.data.redis_client import get_redis_async
from backend.util.exceptions import NotFoundError
from .. import stream_registry
@@ -132,8 +133,65 @@ is delivered to the user via a background stream.
All tasks must run in the foreground.
"""
# Session streaming lock configuration
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
STREAM_LOCK_TTL = 3600 # 1 hour - matches stream_ttl
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
async def _acquire_stream_lock(session_id: str, stream_id: str) -> bool:
"""Acquire an exclusive lock for streaming to this session.
Prevents multiple concurrent streams to the same session which can cause:
- Message duplication
- Race conditions in message saves
- Confusing UX with multiple AI responses
Returns:
True if lock was acquired, False if another stream is active.
"""
redis = await get_redis_async()
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
# SET NX EX - atomic "set if not exists" with expiry
result = await redis.set(lock_key, stream_id, ex=STREAM_LOCK_TTL, nx=True)
return result is not None
async def _release_stream_lock(session_id: str, stream_id: str) -> None:
"""Release the stream lock if we still own it.
Only releases the lock if the stored stream_id matches ours (prevents
releasing another stream's lock if we somehow timed out).
"""
redis = await get_redis_async()
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
# Lua script for atomic compare-and-delete (only delete if value matches)
script = """
if redis.call("GET", KEYS[1]) == ARGV[1] then
return redis.call("DEL", KEYS[1])
else
return 0
end
"""
await redis.eval(script, 1, lock_key, stream_id) # type: ignore[misc]
async def check_active_stream(session_id: str) -> str | None:
"""Check if a stream is currently active for this session.
Returns:
The active stream_id if one exists, None otherwise.
"""
redis = await get_redis_async()
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
active_stream = await redis.get(lock_key)
return active_stream.decode() if isinstance(active_stream, bytes) else active_stream
def _build_long_running_callback(
user_id: str | None,
saved_msg_count_ref: list[int] | None = None,
) -> LongRunningCallback:
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
@@ -142,6 +200,12 @@ def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
page refreshes / pod restarts, and the frontend shows the proper loading
widget with progress updates.
Args:
user_id: User ID for the session
saved_msg_count_ref: Mutable reference [count] shared with streaming loop
for coordinating message saves. When provided, the callback will update
it after appending messages to prevent counter drift.
The returned callback matches the ``LongRunningCallback`` signature:
``(tool_name, args, session) -> MCP response dict``.
"""
@@ -207,7 +271,11 @@ def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
await upsert_chat_session(session)
# Collision detection happens in add_chat_messages_batch (db.py)
_, final_count = await upsert_chat_session(session)
# Update shared counter so streaming loop stays in sync
if saved_msg_count_ref is not None:
saved_msg_count_ref[0] = final_count
# --- Spawn background task (reuses non-SDK infrastructure) ---
bg_task = asyncio.create_task(
@@ -542,7 +610,7 @@ async def stream_chat_completion_sdk(
user_id=user_id, session_id=session_id, message_length=len(message)
)
session = await upsert_chat_session(session)
session, _ = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and not session.title:
@@ -564,6 +632,23 @@ async def stream_chat_completion_sdk(
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
stream_id = task_id # Use task_id as unique stream identifier
# Acquire stream lock to prevent concurrent streams to the same session
lock_acquired = await _acquire_stream_lock(session_id, stream_id)
if not lock_acquired:
# Another stream is active - check if it's still alive
active_stream = await check_active_stream(session_id)
logger.warning(
f"[SDK] Session {session_id} already has an active stream: {active_stream}"
)
yield StreamError(
errorText="Another stream is already active for this session. "
"Please wait for it to complete or refresh the page.",
code="stream_already_active",
)
yield StreamFinish()
return
yield StreamStart(messageId=message_id, taskId=task_id)
@@ -581,10 +666,16 @@ async def stream_chat_completion_sdk(
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
# Initialize saved message counter as mutable list so long-running
# callback and streaming loop can coordinate
saved_msg_count_ref: list[int] = [len(session.messages)]
set_execution_context(
user_id,
session,
long_running_callback=_build_long_running_callback(user_id),
long_running_callback=_build_long_running_callback(
user_id, saved_msg_count_ref
),
)
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
@@ -715,9 +806,8 @@ async def stream_chat_completion_sdk(
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
# Track persisted message count to skip DB count queries
# on incremental saves. Initial save happened at line 545.
saved_msg_count = len(session.messages)
# Track persisted message count. Uses shared ref so long-running
# callback can update it for coordination
# Use an explicit async iterator with non-cancelling heartbeats.
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
@@ -893,13 +983,12 @@ async def stream_chat_completion_sdk(
has_appended_assistant = True
# Save before tool execution starts so the
# pending tool call is visible on refresh /
# other devices.
# other devices. Collision detection happens
# in add_chat_messages_batch (db.py).
try:
await upsert_chat_session(
session,
existing_message_count=saved_msg_count,
)
saved_msg_count = len(session.messages)
_, final_count = await upsert_chat_session(session)
# Update shared ref so callback stays in sync
saved_msg_count_ref[0] = final_count
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
@@ -922,12 +1011,11 @@ async def stream_chat_completion_sdk(
has_tool_results = True
# Save after tool completes so the result is
# visible on refresh / other devices.
# Collision detection happens in add_chat_messages_batch (db.py).
try:
await upsert_chat_session(
session,
existing_message_count=saved_msg_count,
)
saved_msg_count = len(session.messages)
_, final_count = await upsert_chat_session(session)
# Update shared ref so callback stays in sync
saved_msg_count_ref[0] = final_count
except Exception as save_err:
logger.warning(
"[SDK] [%s] Incremental save " "failed: %s",
@@ -1059,11 +1147,12 @@ async def stream_chat_completion_sdk(
"to use the OpenAI-compatible fallback."
)
await asyncio.shield(upsert_chat_session(session))
_, final_count = await asyncio.shield(upsert_chat_session(session))
logger.info(
"[SDK] [%s] Session saved with %d messages",
"[SDK] [%s] Session saved with %d messages (DB count: %d)",
session_id[:12],
len(session.messages),
final_count,
)
if not stream_completed:
yield StreamFinish()
@@ -1121,6 +1210,9 @@ async def stream_chat_completion_sdk(
if sdk_cwd:
_cleanup_sdk_tool_results(sdk_cwd)
# Release stream lock to allow new streams for this session
await _release_stream_lock(session_id, stream_id)
async def _try_upload_transcript(
user_id: str,

View File

@@ -352,7 +352,8 @@ async def assign_user_to_session(
if not session:
raise NotFoundError(f"Session {session_id} not found")
session.user_id = user_id
return await upsert_chat_session(session)
session, _ = await upsert_chat_session(session)
return session
async def stream_chat_completion(
@@ -463,7 +464,7 @@ async def stream_chat_completion(
)
upsert_start = time.monotonic()
session = await upsert_chat_session(session)
session, _ = await upsert_chat_session(session)
upsert_time = (time.monotonic() - upsert_start) * 1000
logger.info(
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
@@ -689,7 +690,7 @@ async def stream_chat_completion(
f"tool_responses={len(tool_response_messages)}"
)
if messages_to_save_early or has_appended_streaming_message:
await upsert_chat_session(session)
_ = await upsert_chat_session(session)
has_saved_assistant_message = True
has_yielded_end = True
@@ -728,7 +729,7 @@ async def stream_chat_completion(
if tool_response_messages:
session.messages.extend(tool_response_messages)
try:
await upsert_chat_session(session)
_ = await upsert_chat_session(session)
except Exception as e:
logger.warning(
f"Failed to save interrupted session {session.session_id}: {e}"
@@ -769,7 +770,7 @@ async def stream_chat_completion(
if messages_to_save:
session.messages.extend(messages_to_save)
if messages_to_save or has_appended_streaming_message:
await upsert_chat_session(session)
_ = await upsert_chat_session(session)
if not has_yielded_error:
error_message = str(e)
@@ -853,7 +854,7 @@ async def stream_chat_completion(
not has_long_running_tool_call
and (messages_to_save or has_appended_streaming_message)
):
await upsert_chat_session(session)
_ = await upsert_chat_session(session)
else:
logger.info(
"Assistant message already saved when StreamFinish was received, "
@@ -1525,7 +1526,7 @@ async def _yield_tool_call(
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
await upsert_chat_session(session)
_ = await upsert_chat_session(session)
await _with_optional_lock(session_lock, _save_pending)
logger.info(
@@ -2019,7 +2020,7 @@ async def _generate_llm_continuation(
fresh_session.messages.append(assistant_message)
# Save to database (not cache) to persist the response
await upsert_chat_session(fresh_session)
_ = await upsert_chat_session(fresh_session)
# Invalidate cache so next poll/refresh gets fresh data
await invalidate_session_cache(session_id)
@@ -2225,7 +2226,7 @@ async def _generate_llm_continuation_with_streaming(
fresh_session.messages.append(assistant_message)
# Save to database (not cache) to persist the response
await upsert_chat_session(fresh_session)
_ = await upsert_chat_session(fresh_session)
# Invalidate cache so next poll/refresh gets fresh data
await invalidate_session_cache(session_id)

View File

@@ -58,7 +58,7 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
session, _ = await upsert_chat_session(session)
has_errors = False
has_ended = False
@@ -104,7 +104,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
session, _ = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---
keyword = "ZEPHYR42"

View File

@@ -303,7 +303,7 @@ class DatabaseManager(AppService):
get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
get_user_session_count = _(chat_db.get_user_session_count)
delete_chat_session = _(chat_db.delete_chat_session)
get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
get_next_sequence = _(chat_db.get_next_sequence)
update_tool_message_content = _(chat_db.update_tool_message_content)
@@ -473,5 +473,5 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_user_chat_sessions = d.get_user_chat_sessions
get_user_session_count = d.get_user_session_count
delete_chat_session = d.delete_chat_session
get_chat_session_message_count = d.get_chat_session_message_count
get_next_sequence = d.get_next_sequence
update_tool_message_content = d.update_tool_message_content