mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(copilot): implement session locking to prevent concurrent streams
- Add stream_id (using task_id) to uniquely identify each stream - Acquire exclusive lock (Redis SET NX EX) when starting a stream - Release lock in finally block using Lua script (atomic compare-and-delete) - Return error if another stream is already active for the session - Lock TTL is 1 hour (matches stream_ttl) with automatic cleanup This prevents: - Message duplication from concurrent streams - Race conditions in message saves - Confusing UX with multiple AI responses - Frontend reconnecting while existing stream is active - Multiple browser tabs streaming to same session
This commit is contained in:
@@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from .. import stream_registry
|
||||
@@ -132,6 +133,60 @@ is delivered to the user via a background stream.
|
||||
All tasks must run in the foreground.
|
||||
"""
|
||||
|
||||
# Session streaming lock configuration
|
||||
STREAM_LOCK_PREFIX = "copilot:stream:lock:"
|
||||
STREAM_LOCK_TTL = 3600 # 1 hour - matches stream_ttl
|
||||
|
||||
|
||||
async def _acquire_stream_lock(session_id: str, stream_id: str) -> bool:
|
||||
"""Acquire an exclusive lock for streaming to this session.
|
||||
|
||||
Prevents multiple concurrent streams to the same session which can cause:
|
||||
- Message duplication
|
||||
- Race conditions in message saves
|
||||
- Confusing UX with multiple AI responses
|
||||
|
||||
Returns:
|
||||
True if lock was acquired, False if another stream is active.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
|
||||
# SET NX EX - atomic "set if not exists" with expiry
|
||||
result = await redis.set(lock_key, stream_id, ex=STREAM_LOCK_TTL, nx=True)
|
||||
return result is not None
|
||||
|
||||
|
||||
async def _release_stream_lock(session_id: str, stream_id: str) -> None:
|
||||
"""Release the stream lock if we still own it.
|
||||
|
||||
Only releases the lock if the stored stream_id matches ours (prevents
|
||||
releasing another stream's lock if we somehow timed out).
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
|
||||
|
||||
# Lua script for atomic compare-and-delete (only delete if value matches)
|
||||
script = """
|
||||
if redis.call("GET", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("DEL", KEYS[1])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
"""
|
||||
await redis.eval(script, 1, lock_key, stream_id) # type: ignore[misc]
|
||||
|
||||
|
||||
async def check_active_stream(session_id: str) -> str | None:
|
||||
"""Check if a stream is currently active for this session.
|
||||
|
||||
Returns:
|
||||
The active stream_id if one exists, None otherwise.
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
lock_key = f"{STREAM_LOCK_PREFIX}{session_id}"
|
||||
active_stream = await redis.get(lock_key)
|
||||
return active_stream.decode() if isinstance(active_stream, bytes) else active_stream
|
||||
|
||||
|
||||
def _build_long_running_callback(
|
||||
user_id: str | None,
|
||||
@@ -577,6 +632,23 @@ async def stream_chat_completion_sdk(
|
||||
system_prompt += _SDK_TOOL_SUPPLEMENT
|
||||
message_id = str(uuid.uuid4())
|
||||
task_id = str(uuid.uuid4())
|
||||
stream_id = task_id # Use task_id as unique stream identifier
|
||||
|
||||
# Acquire stream lock to prevent concurrent streams to the same session
|
||||
lock_acquired = await _acquire_stream_lock(session_id, stream_id)
|
||||
if not lock_acquired:
|
||||
# Another stream is active - check if it's still alive
|
||||
active_stream = await check_active_stream(session_id)
|
||||
logger.warning(
|
||||
f"[SDK] Session {session_id} already has an active stream: {active_stream}"
|
||||
)
|
||||
yield StreamError(
|
||||
errorText="Another stream is already active for this session. "
|
||||
"Please wait for it to complete or refresh the page.",
|
||||
code="stream_already_active",
|
||||
)
|
||||
yield StreamFinish()
|
||||
return
|
||||
|
||||
yield StreamStart(messageId=message_id, taskId=task_id)
|
||||
|
||||
@@ -1137,6 +1209,9 @@ async def stream_chat_completion_sdk(
|
||||
if sdk_cwd:
|
||||
_cleanup_sdk_tool_results(sdk_cwd)
|
||||
|
||||
# Release stream lock to allow new streams for this session
|
||||
await _release_stream_lock(session_id, stream_id)
|
||||
|
||||
|
||||
async def _try_upload_transcript(
|
||||
user_id: str,
|
||||
|
||||
Reference in New Issue
Block a user