Compare commits

...

1 Commits

Author SHA1 Message Date
Reinier van der Leer
f19148777f fix(backend/chat): Use distributed locks for chat session mutations 2026-01-26 19:14:07 +01:00

View File

@@ -1,9 +1,8 @@
import asyncio
import logging
import uuid
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from typing import Any
from weakref import WeakValueDictionary
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
@@ -52,28 +51,36 @@ def _get_session_cache_key(session_id: str) -> str:
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
# Session-level locks to prevent race conditions during concurrent upserts.
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
# NOTE(review): these are in-process asyncio locks, so they only serialize coroutines
# running in the same event loop/process — confirm whether they are still needed
# alongside the Redis-based lock below.
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
# Guards lookups/insertions/removals on _session_locks so the get-or-create
# sequence is not interleaved between coroutines.
_session_locks_mutex = asyncio.Lock()
# Prefix prepended to a session ID to build the Redis key used for that
# session's distributed lock (see _get_session_lock_key).
CHAT_SESSION_LOCK_PREFIX = "chat:session_lock:"
# Passed as the `timeout` argument when the Redis lock is created in _session_lock.
CHAT_SESSION_LOCK_TIMEOUT = 60 # seconds
async def _get_session_lock(session_id: str) -> asyncio.Lock:
    """Get or create a lock for a specific session to prevent concurrent upserts.

    Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
    when no coroutine holds a reference to them, preventing memory leaks from
    unbounded growth of session locks.

    Args:
        session_id: ID of the chat session whose in-process lock is wanted.

    Returns:
        The asyncio.Lock registered for ``session_id`` (created on first use).
    """
    # The mutex makes the get-or-create sequence atomic across coroutines.
    async with _session_locks_mutex:
        lock = _session_locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            _session_locks[session_id] = lock
        return lock


@asynccontextmanager
async def _session_lock(session_id: str):
    """Distributed lock for a chat session using Redis.

    Provides system-wide locking across horizontally scaled backend instances
    to prevent race conditions during concurrent session mutations.

    The lock is created with ``timeout=CHAT_SESSION_LOCK_TIMEOUT``; per the
    redis-py lock API this is the maximum lifetime of the lock, so a body that
    runs longer than that may lose the lock before it finishes.

    Args:
        session_id: ID of the chat session to lock.

    Yields:
        None. The body of the ``async with`` runs while the lock is held.

    Raises:
        Exception: Whatever ``get_redis_async()`` or ``lock.acquire()`` raise
            is propagated. Errors while *releasing* the lock are logged and
            swallowed so they never replace an exception raised by the body.
    """
    async_redis = await get_redis_async()
    lock_key = _get_session_lock_key(session_id)
    lock = async_redis.lock(lock_key, timeout=CHAT_SESSION_LOCK_TIMEOUT)

    # Acquire *outside* the try block: if acquisition fails there is nothing
    # to release, so the cleanup path (and its Redis round-trips) is skipped.
    await lock.acquire()
    try:
        yield
    finally:
        # The whole release path is best-effort. The ownership check itself
        # talks to Redis and can fail; it must not mask an exception coming
        # from the protected body.
        try:
            # owned() is only true when the key exists AND holds our token,
            # so a separate locked() check would be a redundant round-trip.
            if await lock.owned():
                await lock.release()
            else:
                # The lock expired (or was taken over) while the body was
                # still running: mutual exclusion may have been violated.
                logger.warning(
                    f"Lock for chat session #{session_id} was no longer held "
                    "at release time; it may have expired during the operation"
                )
        except Exception as e:
            logger.warning(
                f"Failed to release lock for chat session #{session_id}: {e}"
            )
def _get_session_lock_key(session_id: str) -> str:
    """Build the Redis key under which a chat session's lock is stored.

    The key is CHAT_SESSION_LOCK_PREFIX immediately followed by ``session_id``.
    """
    lock_key = f"{CHAT_SESSION_LOCK_PREFIX}{session_id}"
    return lock_key
class ChatMessage(BaseModel):
@@ -439,10 +446,8 @@ async def upsert_chat_session(
callers are aware of the persistence failure.
RedisError: If the cache write fails (after successful DB write).
"""
# Acquire session-specific lock to prevent concurrent upserts
lock = await _get_session_lock(session.session_id)
async with lock:
# Acquire distributed session-specific lock to prevent concurrent upserts
async with _session_lock(session.session_id):
# Get existing message count from DB for incremental saves
existing_message_count = await chat_db.get_chat_session_message_count(
session.session_id
@@ -553,7 +558,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
if not deleted:
return False
# Only invalidate cache and clean up lock after DB confirms deletion
# Invalidate cache after DB confirms deletion
try:
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
@@ -561,10 +566,6 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
except Exception as e:
logger.warning(f"Failed to delete session {session_id} from cache: {e}")
# Clean up session lock (belt-and-suspenders with WeakValueDictionary)
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
return True