mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-03-17 03:00:27 -04:00
Compare commits
12 Commits
fix/transc
...
pr-12177
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bac7b9efb9 | ||
|
|
6e1941d7ae | ||
|
|
129b992059 | ||
|
|
1b82a55eca | ||
|
|
9d4697e859 | ||
|
|
366547e448 | ||
|
|
af491b5511 | ||
|
|
6acefee6f3 | ||
|
|
eb4650fbb8 | ||
|
|
8bdf83128e | ||
|
|
a1d5b99226 | ||
|
|
0450ea5313 |
@@ -38,8 +38,10 @@ class ChatConfig(BaseSettings):
|
|||||||
|
|
||||||
# Long-running operation configuration
|
# Long-running operation configuration
|
||||||
long_running_operation_ttl: int = Field(
|
long_running_operation_ttl: int = Field(
|
||||||
default=600,
|
default=3600,
|
||||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
description="TTL in seconds for long-running operation deduplication lock "
|
||||||
|
"(1 hour, matches stream_ttl). Prevents duplicate operations if pod dies. "
|
||||||
|
"For longer operations, the stream_registry heartbeat keeps them alive.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stream registry configuration for SSE reconnection
|
# Stream registry configuration for SSE reconnection
|
||||||
|
|||||||
@@ -132,58 +132,97 @@ async def add_chat_messages_batch(
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
start_sequence: int,
|
start_sequence: int,
|
||||||
) -> list[ChatMessage]:
|
) -> tuple[list[ChatMessage], int]:
|
||||||
"""Add multiple messages to a chat session in a batch.
|
"""Add multiple messages to a chat session in a batch.
|
||||||
|
|
||||||
Uses a transaction for atomicity - if any message creation fails,
|
Uses collision detection with retry: tries to create messages starting
|
||||||
the entire batch is rolled back.
|
at start_sequence. If a unique constraint violation occurs (e.g., the
|
||||||
|
streaming loop and long-running callback race), queries MAX(sequence)
|
||||||
|
and retries with the correct next sequence number. This avoids
|
||||||
|
unnecessary upserts and DB queries in the common case (no collision).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (messages, final_message_count) where final_message_count
|
||||||
|
is the total number of messages in the session after insertion.
|
||||||
|
This allows callers to update their counters even when collision
|
||||||
|
detection adjusts start_sequence.
|
||||||
"""
|
"""
|
||||||
if not messages:
|
if not messages:
|
||||||
return []
|
# No messages to add - return current count
|
||||||
|
return [], start_sequence
|
||||||
|
|
||||||
created_messages = []
|
max_retries = 3
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
created_messages = []
|
||||||
|
async with db.transaction() as tx:
|
||||||
|
for i, msg in enumerate(messages):
|
||||||
|
# Build input dict dynamically rather than using ChatMessageCreateInput
|
||||||
|
# directly because Prisma's TypedDict validation rejects optional fields
|
||||||
|
# set to None. We only include fields that have values, then cast.
|
||||||
|
data: dict[str, Any] = {
|
||||||
|
"Session": {"connect": {"id": session_id}},
|
||||||
|
"role": msg["role"],
|
||||||
|
"sequence": start_sequence + i,
|
||||||
|
}
|
||||||
|
|
||||||
async with db.transaction() as tx:
|
# Add optional string fields
|
||||||
for i, msg in enumerate(messages):
|
if msg.get("content") is not None:
|
||||||
# Build input dict dynamically rather than using ChatMessageCreateInput
|
data["content"] = msg["content"]
|
||||||
# directly because Prisma's TypedDict validation rejects optional fields
|
if msg.get("name") is not None:
|
||||||
# set to None. We only include fields that have values, then cast.
|
data["name"] = msg["name"]
|
||||||
data: dict[str, Any] = {
|
if msg.get("tool_call_id") is not None:
|
||||||
"Session": {"connect": {"id": session_id}},
|
data["toolCallId"] = msg["tool_call_id"]
|
||||||
"role": msg["role"],
|
if msg.get("refusal") is not None:
|
||||||
"sequence": start_sequence + i,
|
data["refusal"] = msg["refusal"]
|
||||||
}
|
|
||||||
|
|
||||||
# Add optional string fields
|
# Add optional JSON fields only when they have values
|
||||||
if msg.get("content") is not None:
|
if msg.get("tool_calls") is not None:
|
||||||
data["content"] = msg["content"]
|
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||||
if msg.get("name") is not None:
|
if msg.get("function_call") is not None:
|
||||||
data["name"] = msg["name"]
|
data["functionCall"] = SafeJson(msg["function_call"])
|
||||||
if msg.get("tool_call_id") is not None:
|
|
||||||
data["toolCallId"] = msg["tool_call_id"]
|
|
||||||
if msg.get("refusal") is not None:
|
|
||||||
data["refusal"] = msg["refusal"]
|
|
||||||
|
|
||||||
# Add optional JSON fields only when they have values
|
created = await PrismaChatMessage.prisma(tx).create(
|
||||||
if msg.get("tool_calls") is not None:
|
data=cast(ChatMessageCreateInput, data)
|
||||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
)
|
||||||
if msg.get("function_call") is not None:
|
created_messages.append(created)
|
||||||
data["functionCall"] = SafeJson(msg["function_call"])
|
|
||||||
|
|
||||||
created = await PrismaChatMessage.prisma(tx).create(
|
# Update session's updatedAt timestamp within the same transaction.
|
||||||
data=cast(ChatMessageCreateInput, data)
|
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
||||||
|
# separately via update_chat_session() after streaming completes.
|
||||||
|
await PrismaChatSession.prisma(tx).update(
|
||||||
|
where={"id": session_id},
|
||||||
|
data={"updatedAt": datetime.now(UTC)},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return messages and final message count (for shared counter sync)
|
||||||
|
final_count = start_sequence + len(messages)
|
||||||
|
return [ChatMessage.from_db(m) for m in created_messages], final_count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Check if it's a unique constraint violation
|
||||||
|
error_msg = str(e).lower()
|
||||||
|
is_unique_constraint = (
|
||||||
|
"unique constraint" in error_msg or "duplicate key" in error_msg
|
||||||
)
|
)
|
||||||
created_messages.append(created)
|
|
||||||
|
|
||||||
# Update session's updatedAt timestamp within the same transaction.
|
if is_unique_constraint and attempt < max_retries - 1:
|
||||||
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
|
# Collision detected - query MAX(sequence)+1 and retry with correct offset
|
||||||
# separately via update_chat_session() after streaming completes.
|
logger.info(
|
||||||
await PrismaChatSession.prisma(tx).update(
|
f"Collision detected for session {session_id} at sequence "
|
||||||
where={"id": session_id},
|
f"{start_sequence}, querying DB for latest sequence"
|
||||||
data={"updatedAt": datetime.now(UTC)},
|
)
|
||||||
)
|
start_sequence = await get_next_sequence(session_id)
|
||||||
|
logger.info(
|
||||||
|
f"Retrying batch insert with start_sequence={start_sequence}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Not a collision or max retries exceeded - propagate error
|
||||||
|
raise
|
||||||
|
|
||||||
return [ChatMessage.from_db(m) for m in created_messages]
|
# Should never reach here due to raise in exception handler
|
||||||
|
raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
|
||||||
|
|
||||||
|
|
||||||
async def get_user_chat_sessions(
|
async def get_user_chat_sessions(
|
||||||
@@ -237,10 +276,23 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def get_chat_session_message_count(session_id: str) -> int:
|
async def get_next_sequence(session_id: str) -> int:
|
||||||
"""Get the number of messages in a chat session."""
|
"""Get the next sequence number for a new message in this session.
|
||||||
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
|
|
||||||
return count
|
Uses MAX(sequence) + 1 for robustness. Returns 0 if no messages exist.
|
||||||
|
More robust than COUNT(*) because it's immune to deleted messages.
|
||||||
|
"""
|
||||||
|
result = await db.prisma.query_raw(
|
||||||
|
"""
|
||||||
|
SELECT COALESCE(MAX(sequence) + 1, 0) as next_seq
|
||||||
|
FROM "ChatMessage"
|
||||||
|
WHERE "sessionId" = $1
|
||||||
|
""",
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
|
if not result or len(result) == 0:
|
||||||
|
return 0
|
||||||
|
return int(result[0]["next_seq"])
|
||||||
|
|
||||||
|
|
||||||
async def update_tool_message_content(
|
async def update_tool_message_content(
|
||||||
|
|||||||
@@ -436,7 +436,7 @@ async def upsert_chat_session(
|
|||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
*,
|
*,
|
||||||
existing_message_count: int | None = None,
|
existing_message_count: int | None = None,
|
||||||
) -> ChatSession:
|
) -> tuple[ChatSession, int]:
|
||||||
"""Update a chat session in both cache and database.
|
"""Update a chat session in both cache and database.
|
||||||
|
|
||||||
Uses session-level locking to prevent race conditions when concurrent
|
Uses session-level locking to prevent race conditions when concurrent
|
||||||
@@ -449,6 +449,10 @@ async def upsert_chat_session(
|
|||||||
accurately. Useful for incremental saves in a streaming loop
|
accurately. Useful for incremental saves in a streaming loop
|
||||||
where the caller already knows how many messages are persisted.
|
where the caller already knows how many messages are persisted.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (session, final_message_count) where final_message_count is
|
||||||
|
the actual persisted message count after collision detection adjustments.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
DatabaseError: If the database write fails. The cache is still updated
|
DatabaseError: If the database write fails. The cache is still updated
|
||||||
as a best-effort optimization, but the error is propagated to ensure
|
as a best-effort optimization, but the error is propagated to ensure
|
||||||
@@ -461,15 +465,16 @@ async def upsert_chat_session(
|
|||||||
async with lock:
|
async with lock:
|
||||||
# Get existing message count from DB for incremental saves
|
# Get existing message count from DB for incremental saves
|
||||||
if existing_message_count is None:
|
if existing_message_count is None:
|
||||||
existing_message_count = await chat_db().get_chat_session_message_count(
|
existing_message_count = await chat_db().get_next_sequence(
|
||||||
session.session_id
|
session.session_id
|
||||||
)
|
)
|
||||||
|
|
||||||
db_error: Exception | None = None
|
db_error: Exception | None = None
|
||||||
|
final_count = existing_message_count
|
||||||
|
|
||||||
# Save to database (primary storage)
|
# Save to database (primary storage)
|
||||||
try:
|
try:
|
||||||
await _save_session_to_db(
|
final_count = await _save_session_to_db(
|
||||||
session,
|
session,
|
||||||
existing_message_count,
|
existing_message_count,
|
||||||
skip_existence_check=existing_message_count > 0,
|
skip_existence_check=existing_message_count > 0,
|
||||||
@@ -500,7 +505,7 @@ async def upsert_chat_session(
|
|||||||
f"Failed to persist chat session {session.session_id} to database"
|
f"Failed to persist chat session {session.session_id} to database"
|
||||||
) from db_error
|
) from db_error
|
||||||
|
|
||||||
return session
|
return session, final_count
|
||||||
|
|
||||||
|
|
||||||
async def _save_session_to_db(
|
async def _save_session_to_db(
|
||||||
@@ -508,13 +513,16 @@ async def _save_session_to_db(
|
|||||||
existing_message_count: int,
|
existing_message_count: int,
|
||||||
*,
|
*,
|
||||||
skip_existence_check: bool = False,
|
skip_existence_check: bool = False,
|
||||||
) -> None:
|
) -> int:
|
||||||
"""Save or update a chat session in the database.
|
"""Save or update a chat session in the database.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
skip_existence_check: When True, skip the ``get_chat_session`` query
|
skip_existence_check: When True, skip the ``get_chat_session`` query
|
||||||
and assume the session row already exists. Saves one DB round trip
|
and assume the session row already exists. Saves one DB round trip
|
||||||
for incremental saves during streaming.
|
for incremental saves during streaming.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Final message count after save (accounting for collision detection).
|
||||||
"""
|
"""
|
||||||
db = chat_db()
|
db = chat_db()
|
||||||
|
|
||||||
@@ -546,6 +554,7 @@ async def _save_session_to_db(
|
|||||||
|
|
||||||
# Add new messages (only those after existing count)
|
# Add new messages (only those after existing count)
|
||||||
new_messages = session.messages[existing_message_count:]
|
new_messages = session.messages[existing_message_count:]
|
||||||
|
final_count = existing_message_count
|
||||||
if new_messages:
|
if new_messages:
|
||||||
messages_data = []
|
messages_data = []
|
||||||
for msg in new_messages:
|
for msg in new_messages:
|
||||||
@@ -565,12 +574,14 @@ async def _save_session_to_db(
|
|||||||
f"roles={[m['role'] for m in messages_data]}, "
|
f"roles={[m['role'] for m in messages_data]}, "
|
||||||
f"start_sequence={existing_message_count}"
|
f"start_sequence={existing_message_count}"
|
||||||
)
|
)
|
||||||
await db.add_chat_messages_batch(
|
_, final_count = await db.add_chat_messages_batch(
|
||||||
session_id=session.session_id,
|
session_id=session.session_id,
|
||||||
messages=messages_data,
|
messages=messages_data,
|
||||||
start_sequence=existing_message_count,
|
start_sequence=existing_message_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return final_count
|
||||||
|
|
||||||
|
|
||||||
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
|
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
|
||||||
"""Atomically append a message to a session and persist it.
|
"""Atomically append a message to a session and persist it.
|
||||||
@@ -587,9 +598,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
|
|||||||
raise ValueError(f"Session {session_id} not found")
|
raise ValueError(f"Session {session_id} not found")
|
||||||
|
|
||||||
session.messages.append(message)
|
session.messages.append(message)
|
||||||
existing_message_count = await chat_db().get_chat_session_message_count(
|
existing_message_count = await chat_db().get_next_sequence(session_id)
|
||||||
session_id
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await _save_session_to_db(session, existing_message_count)
|
await _save_session_to_db(session, existing_message_count)
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ async def test_chatsession_redis_storage(setup_test_user, test_user_id):
|
|||||||
s = ChatSession.new(user_id=test_user_id)
|
s = ChatSession.new(user_id=test_user_id)
|
||||||
s.messages = messages
|
s.messages = messages
|
||||||
|
|
||||||
s = await upsert_chat_session(s)
|
s, _ = await upsert_chat_session(s)
|
||||||
|
|
||||||
s2 = await get_chat_session(
|
s2 = await get_chat_session(
|
||||||
session_id=s.session_id,
|
session_id=s.session_id,
|
||||||
@@ -77,7 +77,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(
|
|||||||
|
|
||||||
s = ChatSession.new(user_id=test_user_id)
|
s = ChatSession.new(user_id=test_user_id)
|
||||||
s.messages = messages
|
s.messages = messages
|
||||||
s = await upsert_chat_session(s)
|
s, _ = await upsert_chat_session(s)
|
||||||
|
|
||||||
s2 = await get_chat_session(s.session_id, "different_user_id")
|
s2 = await get_chat_session(s.session_id, "different_user_id")
|
||||||
|
|
||||||
@@ -94,7 +94,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
|
|||||||
s.messages = messages # Contains user, assistant, and tool messages
|
s.messages = messages # Contains user, assistant, and tool messages
|
||||||
assert s.session_id is not None, "Session id is not set"
|
assert s.session_id is not None, "Session id is not set"
|
||||||
# Upsert to save to both cache and DB
|
# Upsert to save to both cache and DB
|
||||||
s = await upsert_chat_session(s)
|
s, _ = await upsert_chat_session(s)
|
||||||
|
|
||||||
# Clear the Redis cache to force DB load
|
# Clear the Redis cache to force DB load
|
||||||
redis_key = f"chat:session:{s.session_id}"
|
redis_key = f"chat:session:{s.session_id}"
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.data.redis_client import get_redis_async
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
from .. import stream_registry
|
from .. import stream_registry
|
||||||
@@ -132,8 +133,65 @@ is delivered to the user via a background stream.
|
|||||||
All tasks must run in the foreground.
|
All tasks must run in the foreground.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Session streaming lock configuration
|
||||||
|
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
|
||||||
|
STREAM_LOCK_TTL = 3600 # 1 hour - matches stream_ttl
|
||||||
|
|
||||||
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
|
|
||||||
|
async def _acquire_stream_lock(session_id: str, stream_id: str) -> bool:
|
||||||
|
"""Acquire an exclusive lock for streaming to this session.
|
||||||
|
|
||||||
|
Prevents multiple concurrent streams to the same session which can cause:
|
||||||
|
- Message duplication
|
||||||
|
- Race conditions in message saves
|
||||||
|
- Confusing UX with multiple AI responses
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if lock was acquired, False if another stream is active.
|
||||||
|
"""
|
||||||
|
redis = await get_redis_async()
|
||||||
|
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
|
||||||
|
# SET NX EX - atomic "set if not exists" with expiry
|
||||||
|
result = await redis.set(lock_key, stream_id, ex=STREAM_LOCK_TTL, nx=True)
|
||||||
|
return result is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def _release_stream_lock(session_id: str, stream_id: str) -> None:
|
||||||
|
"""Release the stream lock if we still own it.
|
||||||
|
|
||||||
|
Only releases the lock if the stored stream_id matches ours (prevents
|
||||||
|
releasing another stream's lock if we somehow timed out).
|
||||||
|
"""
|
||||||
|
redis = await get_redis_async()
|
||||||
|
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
|
||||||
|
|
||||||
|
# Lua script for atomic compare-and-delete (only delete if value matches)
|
||||||
|
script = """
|
||||||
|
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||||
|
return redis.call("DEL", KEYS[1])
|
||||||
|
else
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
"""
|
||||||
|
await redis.eval(script, 1, lock_key, stream_id) # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
async def check_active_stream(session_id: str) -> str | None:
|
||||||
|
"""Check if a stream is currently active for this session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The active stream_id if one exists, None otherwise.
|
||||||
|
"""
|
||||||
|
redis = await get_redis_async()
|
||||||
|
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
|
||||||
|
active_stream = await redis.get(lock_key)
|
||||||
|
return active_stream.decode() if isinstance(active_stream, bytes) else active_stream
|
||||||
|
|
||||||
|
|
||||||
|
def _build_long_running_callback(
|
||||||
|
user_id: str | None,
|
||||||
|
saved_msg_count_ref: list[int] | None = None,
|
||||||
|
) -> LongRunningCallback:
|
||||||
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
|
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
|
||||||
|
|
||||||
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
|
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
|
||||||
@@ -142,6 +200,12 @@ def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
|
|||||||
page refreshes / pod restarts, and the frontend shows the proper loading
|
page refreshes / pod restarts, and the frontend shows the proper loading
|
||||||
widget with progress updates.
|
widget with progress updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID for the session
|
||||||
|
saved_msg_count_ref: Mutable reference [count] shared with streaming loop
|
||||||
|
for coordinating message saves. When provided, the callback will update
|
||||||
|
it after appending messages to prevent counter drift.
|
||||||
|
|
||||||
The returned callback matches the ``LongRunningCallback`` signature:
|
The returned callback matches the ``LongRunningCallback`` signature:
|
||||||
``(tool_name, args, session) -> MCP response dict``.
|
``(tool_name, args, session) -> MCP response dict``.
|
||||||
"""
|
"""
|
||||||
@@ -207,7 +271,11 @@ def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
|
|||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
)
|
)
|
||||||
session.messages.append(pending_message)
|
session.messages.append(pending_message)
|
||||||
await upsert_chat_session(session)
|
# Collision detection happens in add_chat_messages_batch (db.py)
|
||||||
|
_, final_count = await upsert_chat_session(session)
|
||||||
|
# Update shared counter so streaming loop stays in sync
|
||||||
|
if saved_msg_count_ref is not None:
|
||||||
|
saved_msg_count_ref[0] = final_count
|
||||||
|
|
||||||
# --- Spawn background task (reuses non-SDK infrastructure) ---
|
# --- Spawn background task (reuses non-SDK infrastructure) ---
|
||||||
bg_task = asyncio.create_task(
|
bg_task = asyncio.create_task(
|
||||||
@@ -542,7 +610,7 @@ async def stream_chat_completion_sdk(
|
|||||||
user_id=user_id, session_id=session_id, message_length=len(message)
|
user_id=user_id, session_id=session_id, message_length=len(message)
|
||||||
)
|
)
|
||||||
|
|
||||||
session = await upsert_chat_session(session)
|
session, _ = await upsert_chat_session(session)
|
||||||
|
|
||||||
# Generate title for new sessions (first user message)
|
# Generate title for new sessions (first user message)
|
||||||
if is_user_message and not session.title:
|
if is_user_message and not session.title:
|
||||||
@@ -564,6 +632,23 @@ async def stream_chat_completion_sdk(
|
|||||||
system_prompt += _SDK_TOOL_SUPPLEMENT
|
system_prompt += _SDK_TOOL_SUPPLEMENT
|
||||||
message_id = str(uuid.uuid4())
|
message_id = str(uuid.uuid4())
|
||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
|
stream_id = task_id # Use task_id as unique stream identifier
|
||||||
|
|
||||||
|
# Acquire stream lock to prevent concurrent streams to the same session
|
||||||
|
lock_acquired = await _acquire_stream_lock(session_id, stream_id)
|
||||||
|
if not lock_acquired:
|
||||||
|
# Another stream is active - check if it's still alive
|
||||||
|
active_stream = await check_active_stream(session_id)
|
||||||
|
logger.warning(
|
||||||
|
f"[SDK] Session {session_id} already has an active stream: {active_stream}"
|
||||||
|
)
|
||||||
|
yield StreamError(
|
||||||
|
errorText="Another stream is already active for this session. "
|
||||||
|
"Please wait for it to complete or refresh the page.",
|
||||||
|
code="stream_already_active",
|
||||||
|
)
|
||||||
|
yield StreamFinish()
|
||||||
|
return
|
||||||
|
|
||||||
yield StreamStart(messageId=message_id, taskId=task_id)
|
yield StreamStart(messageId=message_id, taskId=task_id)
|
||||||
|
|
||||||
@@ -581,10 +666,16 @@ async def stream_chat_completion_sdk(
|
|||||||
sdk_cwd = _make_sdk_cwd(session_id)
|
sdk_cwd = _make_sdk_cwd(session_id)
|
||||||
os.makedirs(sdk_cwd, exist_ok=True)
|
os.makedirs(sdk_cwd, exist_ok=True)
|
||||||
|
|
||||||
|
# Initialize saved message counter as mutable list so long-running
|
||||||
|
# callback and streaming loop can coordinate
|
||||||
|
saved_msg_count_ref: list[int] = [len(session.messages)]
|
||||||
|
|
||||||
set_execution_context(
|
set_execution_context(
|
||||||
user_id,
|
user_id,
|
||||||
session,
|
session,
|
||||||
long_running_callback=_build_long_running_callback(user_id),
|
long_running_callback=_build_long_running_callback(
|
||||||
|
user_id, saved_msg_count_ref
|
||||||
|
),
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||||
@@ -715,9 +806,8 @@ async def stream_chat_completion_sdk(
|
|||||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||||
has_appended_assistant = False
|
has_appended_assistant = False
|
||||||
has_tool_results = False
|
has_tool_results = False
|
||||||
# Track persisted message count to skip DB count queries
|
# Track persisted message count. Uses shared ref so long-running
|
||||||
# on incremental saves. Initial save happened at line 545.
|
# callback can update it for coordination
|
||||||
saved_msg_count = len(session.messages)
|
|
||||||
|
|
||||||
# Use an explicit async iterator with non-cancelling heartbeats.
|
# Use an explicit async iterator with non-cancelling heartbeats.
|
||||||
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
|
# CRITICAL: we must NOT cancel __anext__() mid-flight — doing so
|
||||||
@@ -893,13 +983,12 @@ async def stream_chat_completion_sdk(
|
|||||||
has_appended_assistant = True
|
has_appended_assistant = True
|
||||||
# Save before tool execution starts so the
|
# Save before tool execution starts so the
|
||||||
# pending tool call is visible on refresh /
|
# pending tool call is visible on refresh /
|
||||||
# other devices.
|
# other devices. Collision detection happens
|
||||||
|
# in add_chat_messages_batch (db.py).
|
||||||
try:
|
try:
|
||||||
await upsert_chat_session(
|
_, final_count = await upsert_chat_session(session)
|
||||||
session,
|
# Update shared ref so callback stays in sync
|
||||||
existing_message_count=saved_msg_count,
|
saved_msg_count_ref[0] = final_count
|
||||||
)
|
|
||||||
saved_msg_count = len(session.messages)
|
|
||||||
except Exception as save_err:
|
except Exception as save_err:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"[SDK] [%s] Incremental save " "failed: %s",
|
"[SDK] [%s] Incremental save " "failed: %s",
|
||||||
@@ -922,12 +1011,11 @@ async def stream_chat_completion_sdk(
|
|||||||
has_tool_results = True
|
has_tool_results = True
|
||||||
# Save after tool completes so the result is
|
# Save after tool completes so the result is
|
||||||
# visible on refresh / other devices.
|
# visible on refresh / other devices.
|
||||||
|
# Collision detection happens in add_chat_messages_batch (db.py).
|
||||||
try:
|
try:
|
||||||
await upsert_chat_session(
|
_, final_count = await upsert_chat_session(session)
|
||||||
session,
|
# Update shared ref so callback stays in sync
|
||||||
existing_message_count=saved_msg_count,
|
saved_msg_count_ref[0] = final_count
|
||||||
)
|
|
||||||
saved_msg_count = len(session.messages)
|
|
||||||
except Exception as save_err:
|
except Exception as save_err:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"[SDK] [%s] Incremental save " "failed: %s",
|
"[SDK] [%s] Incremental save " "failed: %s",
|
||||||
@@ -1059,11 +1147,12 @@ async def stream_chat_completion_sdk(
|
|||||||
"to use the OpenAI-compatible fallback."
|
"to use the OpenAI-compatible fallback."
|
||||||
)
|
)
|
||||||
|
|
||||||
await asyncio.shield(upsert_chat_session(session))
|
_, final_count = await asyncio.shield(upsert_chat_session(session))
|
||||||
logger.info(
|
logger.info(
|
||||||
"[SDK] [%s] Session saved with %d messages",
|
"[SDK] [%s] Session saved with %d messages (DB count: %d)",
|
||||||
session_id[:12],
|
session_id[:12],
|
||||||
len(session.messages),
|
len(session.messages),
|
||||||
|
final_count,
|
||||||
)
|
)
|
||||||
if not stream_completed:
|
if not stream_completed:
|
||||||
yield StreamFinish()
|
yield StreamFinish()
|
||||||
@@ -1121,6 +1210,9 @@ async def stream_chat_completion_sdk(
|
|||||||
if sdk_cwd:
|
if sdk_cwd:
|
||||||
_cleanup_sdk_tool_results(sdk_cwd)
|
_cleanup_sdk_tool_results(sdk_cwd)
|
||||||
|
|
||||||
|
# Release stream lock to allow new streams for this session
|
||||||
|
await _release_stream_lock(session_id, stream_id)
|
||||||
|
|
||||||
|
|
||||||
async def _try_upload_transcript(
|
async def _try_upload_transcript(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|||||||
@@ -352,7 +352,8 @@ async def assign_user_to_session(
|
|||||||
if not session:
|
if not session:
|
||||||
raise NotFoundError(f"Session {session_id} not found")
|
raise NotFoundError(f"Session {session_id} not found")
|
||||||
session.user_id = user_id
|
session.user_id = user_id
|
||||||
return await upsert_chat_session(session)
|
session, _ = await upsert_chat_session(session)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
async def stream_chat_completion(
|
async def stream_chat_completion(
|
||||||
@@ -463,7 +464,7 @@ async def stream_chat_completion(
|
|||||||
)
|
)
|
||||||
|
|
||||||
upsert_start = time.monotonic()
|
upsert_start = time.monotonic()
|
||||||
session = await upsert_chat_session(session)
|
session, _ = await upsert_chat_session(session)
|
||||||
upsert_time = (time.monotonic() - upsert_start) * 1000
|
upsert_time = (time.monotonic() - upsert_start) * 1000
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
|
||||||
@@ -689,7 +690,7 @@ async def stream_chat_completion(
|
|||||||
f"tool_responses={len(tool_response_messages)}"
|
f"tool_responses={len(tool_response_messages)}"
|
||||||
)
|
)
|
||||||
if messages_to_save_early or has_appended_streaming_message:
|
if messages_to_save_early or has_appended_streaming_message:
|
||||||
await upsert_chat_session(session)
|
_ = await upsert_chat_session(session)
|
||||||
has_saved_assistant_message = True
|
has_saved_assistant_message = True
|
||||||
|
|
||||||
has_yielded_end = True
|
has_yielded_end = True
|
||||||
@@ -728,7 +729,7 @@ async def stream_chat_completion(
|
|||||||
if tool_response_messages:
|
if tool_response_messages:
|
||||||
session.messages.extend(tool_response_messages)
|
session.messages.extend(tool_response_messages)
|
||||||
try:
|
try:
|
||||||
await upsert_chat_session(session)
|
_ = await upsert_chat_session(session)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Failed to save interrupted session {session.session_id}: {e}"
|
f"Failed to save interrupted session {session.session_id}: {e}"
|
||||||
@@ -769,7 +770,7 @@ async def stream_chat_completion(
|
|||||||
if messages_to_save:
|
if messages_to_save:
|
||||||
session.messages.extend(messages_to_save)
|
session.messages.extend(messages_to_save)
|
||||||
if messages_to_save or has_appended_streaming_message:
|
if messages_to_save or has_appended_streaming_message:
|
||||||
await upsert_chat_session(session)
|
_ = await upsert_chat_session(session)
|
||||||
|
|
||||||
if not has_yielded_error:
|
if not has_yielded_error:
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
@@ -853,7 +854,7 @@ async def stream_chat_completion(
|
|||||||
not has_long_running_tool_call
|
not has_long_running_tool_call
|
||||||
and (messages_to_save or has_appended_streaming_message)
|
and (messages_to_save or has_appended_streaming_message)
|
||||||
):
|
):
|
||||||
await upsert_chat_session(session)
|
_ = await upsert_chat_session(session)
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Assistant message already saved when StreamFinish was received, "
|
"Assistant message already saved when StreamFinish was received, "
|
||||||
@@ -1525,7 +1526,7 @@ async def _yield_tool_call(
|
|||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
)
|
)
|
||||||
session.messages.append(pending_message)
|
session.messages.append(pending_message)
|
||||||
await upsert_chat_session(session)
|
_ = await upsert_chat_session(session)
|
||||||
|
|
||||||
await _with_optional_lock(session_lock, _save_pending)
|
await _with_optional_lock(session_lock, _save_pending)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -2019,7 +2020,7 @@ async def _generate_llm_continuation(
|
|||||||
fresh_session.messages.append(assistant_message)
|
fresh_session.messages.append(assistant_message)
|
||||||
|
|
||||||
# Save to database (not cache) to persist the response
|
# Save to database (not cache) to persist the response
|
||||||
await upsert_chat_session(fresh_session)
|
_ = await upsert_chat_session(fresh_session)
|
||||||
|
|
||||||
# Invalidate cache so next poll/refresh gets fresh data
|
# Invalidate cache so next poll/refresh gets fresh data
|
||||||
await invalidate_session_cache(session_id)
|
await invalidate_session_cache(session_id)
|
||||||
@@ -2225,7 +2226,7 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
fresh_session.messages.append(assistant_message)
|
fresh_session.messages.append(assistant_message)
|
||||||
|
|
||||||
# Save to database (not cache) to persist the response
|
# Save to database (not cache) to persist the response
|
||||||
await upsert_chat_session(fresh_session)
|
_ = await upsert_chat_session(fresh_session)
|
||||||
|
|
||||||
# Invalidate cache so next poll/refresh gets fresh data
|
# Invalidate cache so next poll/refresh gets fresh data
|
||||||
await invalidate_session_cache(session_id)
|
await invalidate_session_cache(session_id)
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
|
|||||||
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
|
||||||
|
|
||||||
session = await create_chat_session(test_user_id)
|
session = await create_chat_session(test_user_id)
|
||||||
session = await upsert_chat_session(session)
|
session, _ = await upsert_chat_session(session)
|
||||||
|
|
||||||
has_errors = False
|
has_errors = False
|
||||||
has_ended = False
|
has_ended = False
|
||||||
@@ -104,7 +104,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
|
|||||||
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
|
return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")
|
||||||
|
|
||||||
session = await create_chat_session(test_user_id)
|
session = await create_chat_session(test_user_id)
|
||||||
session = await upsert_chat_session(session)
|
session, _ = await upsert_chat_session(session)
|
||||||
|
|
||||||
# --- Turn 1: send a message with a unique keyword ---
|
# --- Turn 1: send a message with a unique keyword ---
|
||||||
keyword = "ZEPHYR42"
|
keyword = "ZEPHYR42"
|
||||||
|
|||||||
@@ -303,7 +303,7 @@ class DatabaseManager(AppService):
|
|||||||
get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
|
get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
|
||||||
get_user_session_count = _(chat_db.get_user_session_count)
|
get_user_session_count = _(chat_db.get_user_session_count)
|
||||||
delete_chat_session = _(chat_db.delete_chat_session)
|
delete_chat_session = _(chat_db.delete_chat_session)
|
||||||
get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
|
get_next_sequence = _(chat_db.get_next_sequence)
|
||||||
update_tool_message_content = _(chat_db.update_tool_message_content)
|
update_tool_message_content = _(chat_db.update_tool_message_content)
|
||||||
|
|
||||||
|
|
||||||
@@ -473,5 +473,5 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
|||||||
get_user_chat_sessions = d.get_user_chat_sessions
|
get_user_chat_sessions = d.get_user_chat_sessions
|
||||||
get_user_session_count = d.get_user_session_count
|
get_user_session_count = d.get_user_session_count
|
||||||
delete_chat_session = d.delete_chat_session
|
delete_chat_session = d.delete_chat_session
|
||||||
get_chat_session_message_count = d.get_chat_session_message_count
|
get_next_sequence = d.get_next_sequence
|
||||||
update_tool_message_content = d.update_tool_message_content
|
update_tool_message_content = d.update_tool_message_content
|
||||||
|
|||||||
Reference in New Issue
Block a user