mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Requested by @majdyz. Concurrent writers (incremental streaming saves from PR #12173 and long-running tool callbacks) can race to persist messages with the same `(sessionId, sequence)` pair, causing unique constraint violations on `ChatMessage`. **Root cause:** The streaming loop tracks `saved_msg_count` in-memory, but the long-running tool callback (`_build_long_running_callback`) also appends messages and calls `upsert_chat_session` independently — without coordinating sequence numbers. When the streaming loop does its next incremental save with the stale `saved_msg_count`, it tries to insert at a sequence that already exists. **Fix:** Multi-layered defense-in-depth approach: 1. **Collision detection with retry** (db.py): `add_chat_messages_batch` uses `create_many()` in a transaction. On `UniqueViolationError`, queries `MAX(sequence)+1` from DB and retries with the correct offset (max 5 attempts). 2. **Robust sequence tracking** (db.py): `get_next_sequence()` uses indexed `find_first` with `order={"sequence": "desc"}` for O(1) MAX lookup, immune to deleted messages. 3. **Session-based counter** (model.py): Added `saved_message_count` field to `ChatSession`. `upsert_chat_session` returns the session with updated count, eliminating tuple returns throughout the codebase. 4. **MessageCounter dataclass** (sdk/service.py): Replaced list[int] mutable reference pattern with a clean `MessageCounter` dataclass for shared state between streaming loop and long-running callbacks. 5. **Session locking** (sdk/service.py): Prevent concurrent streams on the same session using Redis `SET NX EX` distributed locks with TTL refresh on heartbeats (config.stream_ttl = 3600s). 6. **Atomic operations** (db.py): Single timestamp for all messages and session update in batch operations for consistency. Parallel queries with `asyncio.gather` for lower latency.
7. **Config-based TTL** (sdk/service.py, config.py): Consolidated all TTL constants to use `config.stream_ttl` (3600s) with lock refresh on heartbeats. ### Key implementation details - **create_many**: Uses `sessionId` directly (not nested `Session.connect`) as `create_many` doesn't support nested creates - **Type narrowing**: Added explicit `assert session is not None` statements for pyright type checking in async contexts - **Parallel operations**: Use `asyncio.gather` for independent DB operations (create_many + session update) - **Single timestamp**: All messages in a batch share the same `createdAt` timestamp for atomicity ### Changes - `backend/copilot/db.py`: Collision detection with `create_many` + retry, indexed sequence lookup, single timestamp, parallel queries - `backend/copilot/model.py`: Added `saved_message_count` field, simplified return types - `backend/copilot/sdk/service.py`: MessageCounter dataclass, session locking with refresh, config-based TTL, type narrowing - `backend/copilot/service.py`: Updated all callers to handle new return types - `backend/copilot/config.py`: Increased long_running_operation_ttl to 3600s with clarified docstring - `backend/copilot/*_test.py`: Tests updated for new signatures --------- Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
252 lines
8.5 KiB
Python
252 lines
8.5 KiB
Python
"""Redis-based distributed locking for cluster coordination."""
|
|
|
|
import asyncio
|
|
import logging
|
|
import threading
|
|
import time
|
|
from typing import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from redis import Redis
|
|
from redis.asyncio import Redis as AsyncRedis
|
|
|
|
# Module-level logger used by ClusterLock and AsyncClusterLock to report
# Redis failures.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ClusterLock:
    """Simple Redis-based distributed lock for preventing duplicate execution.

    The lock is a single Redis key whose value is the holder's ``owner_id``.
    The key expires after ``timeout`` seconds unless refreshed.

    Ownership-sensitive mutations (TTL extension and release) run as Lua
    scripts. This makes the "is this still my lock?" check and the
    EXPIRE/DEL a single server-side step. With separate commands, a holder
    whose key had expired could extend or delete a lock that another owner
    acquired in between.

    ``_last_refresh`` is the time this instance last acquired or extended the
    lock; ``0`` means "not held by this instance". It is guarded by
    ``_refresh_lock`` (a ``threading.Lock``).
    """

    # Extend the TTL only if the key still stores our owner id.
    # KEYS[1] = lock key, ARGV[1] = owner id, ARGV[2] = TTL in seconds.
    # Returns EXPIRE's result (1 when the TTL was set), otherwise 0.
    _REFRESH_SCRIPT = (
        'if redis.call("get", KEYS[1]) == ARGV[1] then '
        'return redis.call("expire", KEYS[1], ARGV[2]) '
        "end "
        "return 0"
    )

    # Delete the key only if it still stores our owner id.
    # KEYS[1] = lock key, ARGV[1] = owner id.
    # Returns DEL's result (1 when deleted), otherwise 0.
    _RELEASE_SCRIPT = (
        'if redis.call("get", KEYS[1]) == ARGV[1] then '
        'return redis.call("del", KEYS[1]) '
        "end "
        "return 0"
    )

    def __init__(self, redis: "Redis", key: str, owner_id: str, timeout: int = 300):
        self.redis = redis
        self.key = key
        self.owner_id = owner_id
        self.timeout = timeout
        self._last_refresh = 0.0
        self._refresh_lock = threading.Lock()
        # Registered once per instance; reused by try_acquire/refresh/release.
        self._refresh_script = redis.register_script(self._REFRESH_SCRIPT)
        self._release_script = redis.register_script(self._RELEASE_SCRIPT)

    @staticmethod
    def _decode_owner(value: object) -> str | None:
        """Return a stored lock value as ``str``, or None if empty/missing."""
        if not value:
            return None
        return value.decode("utf-8") if isinstance(value, bytes) else str(value)

    def _set_last_refresh(self, value: float) -> None:
        """Record the last acquire/refresh time (``0.0`` means not held)."""
        with self._refresh_lock:
            self._last_refresh = value

    def _extend_if_owner(self) -> bool:
        """Atomically reset the TTL if we still own the key.

        Returns:
            True if the key held our owner id and its TTL was reset.
        """
        result = self._refresh_script(
            keys=[self.key], args=[self.owner_id, self.timeout]
        )
        return bool(result)

    def try_acquire(self) -> str | None:
        """Try to acquire the lock.

        Suppose the key already holds this ``owner_id`` (for example, another
        instance created for the same owner). Then its TTL is extended and
        this instance adopts the lock. As a result, ``release()`` and
        ``refresh()`` on this instance behave as for a freshly acquired lock.

        Returns:
            - owner_id (self.owner_id) if successfully acquired
            - different owner_id if someone else holds the lock
            - None if Redis is unavailable or other error
        """
        try:
            if self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout):
                self._set_last_refresh(time.time())
                return self.owner_id  # Successfully acquired

            # The key exists. If it is ours, extend it and adopt it, so that
            # returning owner_id ("acquired") also leaves this instance in the
            # held state.
            if self._extend_if_owner():
                self._set_last_refresh(time.time())
                return self.owner_id

            # Held by someone else. None means the key vanished between the
            # commands above (race condition) or Redis returned no value.
            return self._decode_owner(self.redis.get(self.key))

        except Exception as e:
            logger.error(f"ClusterLock.try_acquire failed for key {self.key}: {e}")
            return None

    def refresh(self) -> bool:
        """Refresh lock TTL if we still own it.

        Rate limited to at most once every timeout/10 seconds (minimum 1 second).
        During rate limiting, still verifies lock ownership but skips TTL
        extension. Outside the rate-limit window, the ownership check and the
        TTL extension are a single atomic script.
        Setting _last_refresh to 0 bypasses rate limiting for testing.

        Thread-safe: uses _refresh_lock to protect _last_refresh access.

        Returns:
            True if the lock is still held by this owner. Otherwise False, and
            this instance is marked as not holding the lock.
        """
        refresh_interval = max(self.timeout // 10, 1)
        current_time = time.time()

        # _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
        with self._refresh_lock:
            last_refresh = self._last_refresh
        is_rate_limited = (
            last_refresh > 0 and (current_time - last_refresh) < refresh_interval
        )

        try:
            if is_rate_limited:
                # Only verify ownership; don't update TTL or timestamp.
                if self._decode_owner(self.redis.get(self.key)) == self.owner_id:
                    return True
            elif self._extend_if_owner():
                self._set_last_refresh(current_time)
                return True
        except Exception as e:
            logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")

        # Key missing, owned by someone else, or Redis error: lock is lost.
        self._set_last_refresh(0.0)
        return False

    def release(self):
        """Release the lock if this instance holds it.

        Best-effort: Redis errors are logged, not raised. The key is deleted
        only if it still stores our owner id, so a lock that expired and was
        taken over by another owner is left intact.
        """
        with self._refresh_lock:
            if self._last_refresh == 0:
                return

        try:
            self._release_script(keys=[self.key], args=[self.owner_id])
        except Exception as e:
            logger.warning(f"ClusterLock.release failed for key {self.key}: {e}")

        self._set_last_refresh(0.0)
|
|
|
|
|
|
class AsyncClusterLock:
    """Async Redis-based distributed lock for preventing duplicate execution.

    The lock is a single Redis key whose value is the holder's ``owner_id``.
    The key expires after ``timeout`` seconds unless refreshed.

    Ownership-sensitive mutations (TTL extension and release) run as Lua
    scripts. This makes the "is this still my lock?" check and the
    EXPIRE/DEL a single server-side step. With separate commands, a holder
    whose key had expired could extend or delete a lock that another owner
    acquired in between.

    ``_last_refresh`` is the time this instance last acquired or extended the
    lock; ``0`` means "not held by this instance". It is guarded by
    ``_refresh_lock`` (an ``asyncio.Lock``).
    """

    # Extend the TTL only if the key still stores our owner id.
    # KEYS[1] = lock key, ARGV[1] = owner id, ARGV[2] = TTL in seconds.
    # Returns EXPIRE's result (1 when the TTL was set), otherwise 0.
    _REFRESH_SCRIPT = (
        'if redis.call("get", KEYS[1]) == ARGV[1] then '
        'return redis.call("expire", KEYS[1], ARGV[2]) '
        "end "
        "return 0"
    )

    # Delete the key only if it still stores our owner id.
    # KEYS[1] = lock key, ARGV[1] = owner id.
    # Returns DEL's result (1 when deleted), otherwise 0.
    _RELEASE_SCRIPT = (
        'if redis.call("get", KEYS[1]) == ARGV[1] then '
        'return redis.call("del", KEYS[1]) '
        "end "
        "return 0"
    )

    def __init__(
        self, redis: "AsyncRedis", key: str, owner_id: str, timeout: int = 300
    ):
        self.redis = redis
        self.key = key
        self.owner_id = owner_id
        self.timeout = timeout
        self._last_refresh = 0.0
        self._refresh_lock = asyncio.Lock()
        # Registered once per instance; reused by try_acquire/refresh/release.
        self._refresh_script = redis.register_script(self._REFRESH_SCRIPT)
        self._release_script = redis.register_script(self._RELEASE_SCRIPT)

    @staticmethod
    def _decode_owner(value: object) -> str | None:
        """Return a stored lock value as ``str``, or None if empty/missing."""
        if not value:
            return None
        return value.decode("utf-8") if isinstance(value, bytes) else str(value)

    async def _set_last_refresh(self, value: float) -> None:
        """Record the last acquire/refresh time (``0.0`` means not held)."""
        async with self._refresh_lock:
            self._last_refresh = value

    async def _extend_if_owner(self) -> bool:
        """Atomically reset the TTL if we still own the key.

        Returns:
            True if the key held our owner id and its TTL was reset.
        """
        result = await self._refresh_script(
            keys=[self.key], args=[self.owner_id, self.timeout]
        )
        return bool(result)

    async def try_acquire(self) -> str | None:
        """Try to acquire the lock.

        Suppose the key already holds this ``owner_id`` (for example, another
        instance created for the same owner). Then its TTL is extended and
        this instance adopts the lock. As a result, ``release()`` and
        ``refresh()`` on this instance behave as for a freshly acquired lock.

        Returns:
            - owner_id (self.owner_id) if successfully acquired
            - different owner_id if someone else holds the lock
            - None if Redis is unavailable or other error
        """
        try:
            acquired = await self.redis.set(
                self.key, self.owner_id, nx=True, ex=self.timeout
            )
            if acquired:
                await self._set_last_refresh(time.time())
                return self.owner_id  # Successfully acquired

            # The key exists. If it is ours, extend it and adopt it, so that
            # returning owner_id ("acquired") also leaves this instance in the
            # held state.
            if await self._extend_if_owner():
                await self._set_last_refresh(time.time())
                return self.owner_id

            # Held by someone else. None means the key vanished between the
            # commands above (race condition) or Redis returned no value.
            return self._decode_owner(await self.redis.get(self.key))

        except Exception as e:
            logger.error(f"AsyncClusterLock.try_acquire failed for key {self.key}: {e}")
            return None

    async def refresh(self) -> bool:
        """Refresh lock TTL if we still own it.

        Rate limited to at most once every timeout/10 seconds (minimum 1 second).
        During rate limiting, still verifies lock ownership but skips TTL
        extension. Outside the rate-limit window, the ownership check and the
        TTL extension are a single atomic script.
        Setting _last_refresh to 0 bypasses rate limiting for testing.

        Async-safe: uses asyncio.Lock to protect _last_refresh access.

        Returns:
            True if the lock is still held by this owner. Otherwise False, and
            this instance is marked as not holding the lock.
        """
        refresh_interval = max(self.timeout // 10, 1)
        current_time = time.time()

        # _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
        async with self._refresh_lock:
            last_refresh = self._last_refresh
        is_rate_limited = (
            last_refresh > 0 and (current_time - last_refresh) < refresh_interval
        )

        try:
            if is_rate_limited:
                # Only verify ownership; don't update TTL or timestamp.
                stored = self._decode_owner(await self.redis.get(self.key))
                if stored == self.owner_id:
                    return True
            elif await self._extend_if_owner():
                await self._set_last_refresh(current_time)
                return True
        except Exception as e:
            logger.error(f"AsyncClusterLock.refresh failed for key {self.key}: {e}")

        # Key missing, owned by someone else, or Redis error: lock is lost.
        await self._set_last_refresh(0.0)
        return False

    async def release(self):
        """Release the lock if this instance holds it.

        Best-effort: Redis errors are logged, not raised. The key is deleted
        only if it still stores our owner id, so a lock that expired and was
        taken over by another owner is left intact.
        """
        async with self._refresh_lock:
            if self._last_refresh == 0:
                return

        try:
            await self._release_script(keys=[self.key], args=[self.owner_id])
        except Exception as e:
            logger.warning(
                f"AsyncClusterLock.release failed for key {self.key}: {e}"
            )

        await self._set_last_refresh(0.0)
|