Files
AutoGPT/autogpt_platform/backend/backend/executor/cluster_lock.py
Otto e40c8c70ce fix(copilot): collision detection, session locking, and sync for concurrent message saves (#12177)
Requested by @majdyz

Concurrent writers (incremental streaming saves from PR #12173 and
long-running tool callbacks) can race to persist messages with the same
`(sessionId, sequence)` pair, causing unique constraint violations on
`ChatMessage`.

**Root cause:** The streaming loop tracks `saved_msg_count` in-memory,
but the long-running tool callback (`_build_long_running_callback`) also
appends messages and calls `upsert_chat_session` independently — without
coordinating sequence numbers. When the streaming loop does its next
incremental save with the stale `saved_msg_count`, it tries to insert at
a sequence that already exists.

**Fix:** Multi-layered defense-in-depth approach:

1. **Collision detection with retry** (db.py): `add_chat_messages_batch`
uses `create_many()` in a transaction. On `UniqueViolationError`,
queries `MAX(sequence)+1` from DB and retries with the correct offset
(max 5 attempts).

2. **Robust sequence tracking** (db.py): `get_next_sequence()` uses
indexed `find_first` with `order={"sequence": "desc"}` for O(1) MAX
lookup, immune to deleted messages.

3. **Session-based counter** (model.py): Added `saved_message_count`
field to `ChatSession`. `upsert_chat_session` returns the session with
updated count, eliminating tuple returns throughout the codebase.

4. **MessageCounter dataclass** (sdk/service.py): Replaced list[int]
mutable reference pattern with a clean `MessageCounter` dataclass for
shared state between streaming loop and long-running callbacks.

5. **Session locking** (sdk/service.py): Prevent concurrent streams on
the same session using Redis `SET NX EX` distributed locks with TTL
refresh on heartbeats (config.stream_ttl = 3600s).

6. **Atomic operations** (db.py): Single timestamp for all messages and
session update in batch operations for consistency. Parallel queries
with `asyncio.gather` for lower latency.

7. **Config-based TTL** (sdk/service.py, config.py): Consolidated all
TTL constants to use `config.stream_ttl` (3600s) with lock refresh on
heartbeats.

### Key implementation details

- **create_many**: Uses `sessionId` directly (not nested
`Session.connect`) as `create_many` doesn't support nested creates
- **Type narrowing**: Added explicit `assert session is not None`
statements for pyright type checking in async contexts
- **Parallel operations**: Use `asyncio.gather` for independent DB
operations (create_many + session update)
- **Single timestamp**: All messages in a batch share the same
`createdAt` timestamp, so the batch is recorded consistently

### Changes
- `backend/copilot/db.py`: Collision detection with `create_many` +
retry, indexed sequence lookup, single timestamp, parallel queries
- `backend/copilot/model.py`: Added `saved_message_count` field,
simplified return types
- `backend/copilot/sdk/service.py`: MessageCounter dataclass, session
locking with refresh, config-based TTL, type narrowing
- `backend/copilot/service.py`: Updated all callers to handle new return
types
- `backend/copilot/config.py`: Increased long_running_operation_ttl to
3600s with clarified docstring
- `backend/copilot/*_test.py`: Tests updated for new signatures

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-02-20 15:05:03 +00:00

252 lines
8.5 KiB
Python

"""Redis-based distributed locking for cluster coordination."""
import asyncio
import logging
import threading
import time
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
logger = logging.getLogger(__name__)
class ClusterLock:
"""Simple Redis-based distributed lock for preventing duplicate execution."""
def __init__(self, redis: "Redis", key: str, owner_id: str, timeout: int = 300):
self.redis = redis
self.key = key
self.owner_id = owner_id
self.timeout = timeout
self._last_refresh = 0.0
self._refresh_lock = threading.Lock()
def try_acquire(self) -> str | None:
"""Try to acquire the lock.
Returns:
- owner_id (self.owner_id) if successfully acquired
- different owner_id if someone else holds the lock
- None if Redis is unavailable or other error
"""
try:
success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
if success:
with self._refresh_lock:
self._last_refresh = time.time()
return self.owner_id # Successfully acquired
# Failed to acquire, get current owner
current_value = self.redis.get(self.key)
if current_value:
current_owner = (
current_value.decode("utf-8")
if isinstance(current_value, bytes)
else str(current_value)
)
return current_owner
# Key doesn't exist but we failed to set it - race condition or Redis issue
return None
except Exception as e:
logger.error(f"ClusterLock.try_acquire failed for key {self.key}: {e}")
return None
def refresh(self) -> bool:
"""Refresh lock TTL if we still own it.
Rate limited to at most once every timeout/10 seconds (minimum 1 second).
During rate limiting, still verifies lock existence but skips TTL extension.
Setting _last_refresh to 0 bypasses rate limiting for testing.
Thread-safe: uses _refresh_lock to protect _last_refresh access.
"""
# Calculate refresh interval: max(timeout // 10, 1)
refresh_interval = max(self.timeout // 10, 1)
current_time = time.time()
# Check if we're within the rate limit period (thread-safe read)
# _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
with self._refresh_lock:
last_refresh = self._last_refresh
is_rate_limited = (
last_refresh > 0 and (current_time - last_refresh) < refresh_interval
)
try:
# Always verify lock existence, even during rate limiting
current_value = self.redis.get(self.key)
if not current_value:
with self._refresh_lock:
self._last_refresh = 0
return False
stored_owner = (
current_value.decode("utf-8")
if isinstance(current_value, bytes)
else str(current_value)
)
if stored_owner != self.owner_id:
with self._refresh_lock:
self._last_refresh = 0
return False
# If rate limited, return True but don't update TTL or timestamp
if is_rate_limited:
return True
# Perform actual refresh
if self.redis.expire(self.key, self.timeout):
with self._refresh_lock:
self._last_refresh = current_time
return True
with self._refresh_lock:
self._last_refresh = 0
return False
except Exception as e:
logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
with self._refresh_lock:
self._last_refresh = 0
return False
def release(self):
"""Release the lock."""
with self._refresh_lock:
if self._last_refresh == 0:
return
try:
self.redis.delete(self.key)
except Exception:
pass
with self._refresh_lock:
self._last_refresh = 0.0
class AsyncClusterLock:
"""Async Redis-based distributed lock for preventing duplicate execution."""
def __init__(
self, redis: "AsyncRedis", key: str, owner_id: str, timeout: int = 300
):
self.redis = redis
self.key = key
self.owner_id = owner_id
self.timeout = timeout
self._last_refresh = 0.0
self._refresh_lock = asyncio.Lock()
async def try_acquire(self) -> str | None:
"""Try to acquire the lock.
Returns:
- owner_id (self.owner_id) if successfully acquired
- different owner_id if someone else holds the lock
- None if Redis is unavailable or other error
"""
try:
success = await self.redis.set(
self.key, self.owner_id, nx=True, ex=self.timeout
)
if success:
async with self._refresh_lock:
self._last_refresh = time.time()
return self.owner_id # Successfully acquired
# Failed to acquire, get current owner
current_value = await self.redis.get(self.key)
if current_value:
current_owner = (
current_value.decode("utf-8")
if isinstance(current_value, bytes)
else str(current_value)
)
return current_owner
# Key doesn't exist but we failed to set it - race condition or Redis issue
return None
except Exception as e:
logger.error(f"AsyncClusterLock.try_acquire failed for key {self.key}: {e}")
return None
async def refresh(self) -> bool:
"""Refresh lock TTL if we still own it.
Rate limited to at most once every timeout/10 seconds (minimum 1 second).
During rate limiting, still verifies lock existence but skips TTL extension.
Setting _last_refresh to 0 bypasses rate limiting for testing.
Async-safe: uses asyncio.Lock to protect _last_refresh access.
"""
# Calculate refresh interval: max(timeout // 10, 1)
refresh_interval = max(self.timeout // 10, 1)
current_time = time.time()
# Check if we're within the rate limit period (async-safe read)
# _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
async with self._refresh_lock:
last_refresh = self._last_refresh
is_rate_limited = (
last_refresh > 0 and (current_time - last_refresh) < refresh_interval
)
try:
# Always verify lock existence, even during rate limiting
current_value = await self.redis.get(self.key)
if not current_value:
async with self._refresh_lock:
self._last_refresh = 0
return False
stored_owner = (
current_value.decode("utf-8")
if isinstance(current_value, bytes)
else str(current_value)
)
if stored_owner != self.owner_id:
async with self._refresh_lock:
self._last_refresh = 0
return False
# If rate limited, return True but don't update TTL or timestamp
if is_rate_limited:
return True
# Perform actual refresh
if await self.redis.expire(self.key, self.timeout):
async with self._refresh_lock:
self._last_refresh = current_time
return True
async with self._refresh_lock:
self._last_refresh = 0
return False
except Exception as e:
logger.error(f"AsyncClusterLock.refresh failed for key {self.key}: {e}")
async with self._refresh_lock:
self._last_refresh = 0
return False
async def release(self):
"""Release the lock."""
async with self._refresh_lock:
if self._last_refresh == 0:
return
try:
await self.redis.delete(self.key)
except Exception:
pass
async with self._refresh_lock:
self._last_refresh = 0.0