added shared caching

This commit is contained in:
Swifty
2025-10-01 16:17:26 +02:00
parent e13861ad33
commit b3fe2b84ce
8 changed files with 831 additions and 86 deletions

View File

@@ -1,6 +1,8 @@
import asyncio
import hashlib
import inspect
import logging
import os
import pickle
import threading
import time
from functools import wraps
@@ -14,6 +16,11 @@ from typing import (
runtime_checkable,
)
from dotenv import load_dotenv
from redis import ConnectionPool, Redis
from autogpt_libs.utils.retry import conn_retry
P = ParamSpec("P")
R = TypeVar("R")
R_co = TypeVar("R_co", covariant=True)
@@ -21,6 +28,40 @@ R_co = TypeVar("R_co", covariant=True)
logger = logging.getLogger(__name__)
load_dotenv()
HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", None)
logger = logging.getLogger(__name__)
@conn_retry("Redis", "Acquiring connection")
def connect() -> Redis:
    """Build a pooled Redis client and verify connectivity with a ping.

    Wrapped in ``conn_retry`` so transient startup failures are retried
    with exponential backoff before giving up.
    """
    # Pool tuned for concurrent shared-cache access; binary responses are
    # required so pickled payloads round-trip unchanged.
    pool_options = dict(
        host=HOST,
        port=PORT,
        password=PASSWORD,
        decode_responses=False,  # Store binary data for pickle
        max_connections=50,  # Allow up to 50 concurrent connections
        socket_keepalive=True,  # Keep connections alive
        socket_connect_timeout=5,
        retry_on_timeout=True,
    )
    client = Redis(connection_pool=ConnectionPool(**pool_options))
    # Raises on an unreachable server, which triggers the conn_retry logic.
    client.ping()
    return client
redis_client = connect()
def _make_redis_key(key: tuple[Any, ...]) -> str:
"""Convert a hashable key tuple to a Redis key string."""
return f"cache:{hash(key)}"
def _make_hashable_key(
args: tuple[Any, ...], kwargs: dict[str, Any]
) -> tuple[Any, ...]:
@@ -85,6 +126,7 @@ def cached(
*,
maxsize: int = 128,
ttl_seconds: int | None = None,
shared_cache: bool = False,
) -> Callable[[Callable], CachedFunction]:
"""
Thundering herd safe cache decorator for both sync and async functions.
@@ -115,7 +157,6 @@ def cached(
"""
def decorator(target_func):
# Cache storage and per-event-loop locks
cache_storage = {}
_event_loop_locks = {} # Maps event loop to its asyncio.Lock
@@ -138,48 +179,98 @@ def cached(
key = _make_hashable_key(args, kwargs)
current_time = time.time()
# Fast path: check cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.debug(
f"Cache hit for {target_func.__name__} args: {args} kwargs: {kwargs}"
)
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(
f"Cache hit for {target_func.__name__} args: {args} kwargs: {kwargs}"
)
return result
# Compute redis_key once if using shared cache
redis_key = _make_redis_key(key)
# Slow path: acquire lock for cache miss/expiry
async with _get_cache_lock():
# Double-check: another coroutine might have populated cache
# Fast path: check cache without lock
if shared_cache:
try:
# Use GET directly instead of EXISTS + GET (saves 1 round trip)
cached_bytes = redis_client.get(redis_key)
if cached_bytes is not None and isinstance(cached_bytes, bytes):
logger.debug(
f"Cache hit for {target_func.__name__} args: {args} kwargs: {kwargs}"
)
return pickle.loads(cached_bytes)
except Exception as e:
logger.error(
f"Redis error during cache check for {target_func.__name__}: {e}"
)
# Fall through to execute function
return await target_func(*args, **kwargs)
else:
if key in cache_storage:
if ttl_seconds is None:
logger.debug(
f"Cache hit for {target_func.__name__} args: {args} kwargs: {kwargs}"
)
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(
f"Cache hit for {target_func.__name__} args: {args} kwargs: {kwargs}"
)
return result
# Slow path: acquire lock for cache miss/expiry
async with _get_cache_lock():
# Double-check: another coroutine might have populated cache
if shared_cache:
try:
# Use GET directly (saves 1 round trip)
cached_bytes = redis_client.get(redis_key)
if cached_bytes is not None and isinstance(
cached_bytes, bytes
):
return pickle.loads(cached_bytes)
except Exception as e:
logger.error(
f"Redis error during double-check for {target_func.__name__}: {e}"
)
# Continue to execute function
else:
if key in cache_storage:
if ttl_seconds is None:
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
result = await target_func(*args, **kwargs)
# Store result
if ttl_seconds is None:
cache_storage[key] = result
if shared_cache:
try:
pickled_result = pickle.dumps(
result, protocol=pickle.HIGHEST_PROTOCOL
)
if ttl_seconds is None:
redis_client.set(redis_key, pickled_result)
else:
redis_client.setex(
redis_key, ttl_seconds, pickled_result
)
except Exception as e:
logger.error(
f"Redis error storing cache for {target_func.__name__}: {e}"
)
# Continue without caching
else:
cache_storage[key] = (result, current_time)
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Cleanup if needed
if len(cache_storage) > maxsize:
# Cleanup if needed (only for local cache)
if not shared_cache and len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
@@ -200,48 +291,98 @@ def cached(
key = _make_hashable_key(args, kwargs)
current_time = time.time()
# Fast path: check cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.debug(
f"Cache hit for {target_func.__name__} args: {args} kwargs: {kwargs}"
)
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(
f"Cache hit for {target_func.__name__} args: {args} kwargs: {kwargs}"
)
return result
# Compute redis_key once if using shared cache
redis_key = _make_redis_key(key)
# Slow path: acquire lock for cache miss/expiry
with cache_lock:
# Double-check: another thread might have populated cache
# Fast path: check cache without lock
if shared_cache:
try:
# Use GET directly instead of EXISTS + GET (saves 1 round trip)
cached_bytes = redis_client.get(redis_key)
if cached_bytes is not None and isinstance(cached_bytes, bytes):
logger.debug(
f"Cache hit for {target_func.__name__} args: {args} kwargs: {kwargs}"
)
return pickle.loads(cached_bytes)
except Exception as e:
logger.error(
f"Redis error during cache check for {target_func.__name__}: {e}"
)
# Fall through to execute function
return target_func(*args, **kwargs)
else:
if key in cache_storage:
if ttl_seconds is None:
logger.debug(
f"Cache hit for {target_func.__name__} args: {args} kwargs: {kwargs}"
)
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(
f"Cache hit for {target_func.__name__} args: {args} kwargs: {kwargs}"
)
return result
# Slow path: acquire lock for cache miss/expiry
with cache_lock:
# Double-check: another thread might have populated cache
if shared_cache:
try:
# Use GET directly (saves 1 round trip)
cached_bytes = redis_client.get(redis_key)
if cached_bytes is not None and isinstance(
cached_bytes, bytes
):
return pickle.loads(cached_bytes)
except Exception as e:
logger.error(
f"Redis error during double-check for {target_func.__name__}: {e}"
)
# Continue to execute function
else:
if key in cache_storage:
if ttl_seconds is None:
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
result = target_func(*args, **kwargs)
# Store result
if ttl_seconds is None:
cache_storage[key] = result
if shared_cache:
try:
pickled_result = pickle.dumps(
result, protocol=pickle.HIGHEST_PROTOCOL
)
if ttl_seconds is None:
redis_client.set(redis_key, pickled_result)
else:
redis_client.setex(
redis_key, ttl_seconds, pickled_result
)
except Exception as e:
logger.error(
f"Redis error storing cache for {target_func.__name__}: {e}"
)
# Continue without caching
else:
cache_storage[key] = (result, current_time)
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Cleanup if needed
if len(cache_storage) > maxsize:
# Cleanup if needed (only for local cache)
if not shared_cache and len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
@@ -255,22 +396,47 @@ def cached(
# Add cache management methods
def cache_clear() -> None:
cache_storage.clear()
if shared_cache:
# Clear only cache keys (prefixed with "cache:") using pipeline for efficiency
keys = list(redis_client.scan_iter("cache:*", count=100))
if keys:
pipeline = redis_client.pipeline()
for key in keys:
pipeline.delete(key)
pipeline.execute()
else:
cache_storage.clear()
def cache_info() -> dict[str, int | None]:
return {
"size": len(cache_storage),
"maxsize": maxsize,
"ttl_seconds": ttl_seconds,
}
if shared_cache:
# Count only cache keys
cache_keys = list(redis_client.scan_iter("cache:*"))
return {
"size": len(cache_keys),
"maxsize": maxsize,
"ttl_seconds": ttl_seconds,
}
else:
return {
"size": len(cache_storage),
"maxsize": maxsize,
"ttl_seconds": ttl_seconds,
}
def cache_delete(*args, **kwargs) -> bool:
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if key in cache_storage:
del cache_storage[key]
return True
return False
if shared_cache:
redis_key = _make_redis_key(key)
if redis_client.exists(redis_key):
redis_client.delete(redis_key)
return True
return False
else:
if key in cache_storage:
del cache_storage[key]
return True
return False
setattr(wrapper, "cache_clear", cache_clear)
setattr(wrapper, "cache_info", cache_info)

View File

@@ -12,7 +12,7 @@ import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest
@@ -674,3 +674,324 @@ class TestCache:
# Try to delete non-existent entry
was_deleted = async_deletable_function.cache_delete(99)
assert was_deleted is False
class TestSharedCache:
    """Tests for shared_cache functionality using Redis."""

    @pytest.fixture(autouse=True)
    def setup_redis_mock(self):
        """Mock Redis client for testing.

        Patches the module-level ``redis_client`` with a Mock backed by a
        plain dict (``self.redis_storage``) so every test runs against an
        isolated in-memory "Redis" with no real server.
        """
        with patch("autogpt_libs.utils.cache.redis_client") as mock_redis:
            # Configure mock to behave like Redis
            self.mock_redis = mock_redis
            self.redis_storage = {}

            def mock_get(key):
                # Mirrors Redis GET: None when the key is absent.
                return self.redis_storage.get(key)

            def mock_set(key, value):
                self.redis_storage[key] = value
                return True

            def mock_setex(key, ttl, value):
                # TTL is accepted but ignored — expiry is not simulated.
                self.redis_storage[key] = value
                return True

            def mock_exists(key):
                return 1 if key in self.redis_storage else 0

            def mock_delete(key):
                if key in self.redis_storage:
                    del self.redis_storage[key]
                    return 1
                return 0

            def mock_scan_iter(pattern, count=None):
                # Pattern is a string like "cache:*", keys in storage are strings
                prefix = pattern.rstrip("*")
                return [
                    k
                    for k in self.redis_storage.keys()
                    if isinstance(k, str) and k.startswith(prefix)
                ]

            def mock_pipeline():
                # Pipeline mock buffers deletes until execute(), like real Redis.
                pipe = Mock()
                deleted_keys = []

                def pipe_delete(key):
                    deleted_keys.append(key)
                    return pipe

                def pipe_execute():
                    # Actually delete the keys when pipeline executes
                    for key in deleted_keys:
                        self.redis_storage.pop(key, None)
                    deleted_keys.clear()
                    return []

                pipe.delete = Mock(side_effect=pipe_delete)
                pipe.execute = Mock(side_effect=pipe_execute)
                return pipe

            mock_redis.get = Mock(side_effect=mock_get)
            mock_redis.set = Mock(side_effect=mock_set)
            mock_redis.setex = Mock(side_effect=mock_setex)
            mock_redis.exists = Mock(side_effect=mock_exists)
            mock_redis.delete = Mock(side_effect=mock_delete)
            mock_redis.scan_iter = Mock(side_effect=mock_scan_iter)
            mock_redis.pipeline = Mock(side_effect=mock_pipeline)

            yield mock_redis
            # Cleanup
            self.redis_storage.clear()

    def test_sync_shared_cache_basic(self):
        """Test basic shared cache functionality with sync function."""
        call_count = 0

        @cached(shared_cache=True)
        def shared_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 10

        # First call - should miss cache
        result1 = shared_function(5)
        assert result1 == 50
        assert call_count == 1
        assert self.mock_redis.get.called
        assert self.mock_redis.set.called

        # Second call - should hit cache
        result2 = shared_function(5)
        assert result2 == 50
        assert call_count == 1  # Function not called again

    @pytest.mark.asyncio
    async def test_async_shared_cache_basic(self):
        """Test basic shared cache functionality with async function."""
        call_count = 0

        @cached(shared_cache=True)
        async def async_shared_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x * 20

        # First call - should miss cache
        result1 = await async_shared_function(3)
        assert result1 == 60
        assert call_count == 1
        assert self.mock_redis.get.called
        assert self.mock_redis.set.called

        # Second call - should hit cache
        result2 = await async_shared_function(3)
        assert result2 == 60
        assert call_count == 1  # Function not called again

    def test_sync_shared_cache_with_ttl(self):
        """Test shared cache with TTL using sync function."""
        call_count = 0

        @cached(shared_cache=True, ttl_seconds=60)
        def shared_ttl_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 30

        # First call
        result1 = shared_ttl_function(2)
        assert result1 == 60
        assert call_count == 1
        # TTL paths must use SETEX rather than plain SET.
        assert self.mock_redis.setex.called

        # Second call - should use cache
        result2 = shared_ttl_function(2)
        assert result2 == 60
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_async_shared_cache_with_ttl(self):
        """Test shared cache with TTL using async function."""
        call_count = 0

        @cached(shared_cache=True, ttl_seconds=120)
        async def async_shared_ttl_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x * 40

        # First call
        result1 = await async_shared_ttl_function(4)
        assert result1 == 160
        assert call_count == 1
        assert self.mock_redis.setex.called

        # Second call - should use cache
        result2 = await async_shared_ttl_function(4)
        assert result2 == 160
        assert call_count == 1

    def test_shared_cache_clear(self):
        """Test clearing shared cache."""
        call_count = 0

        @cached(shared_cache=True)
        def clearable_shared_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 50

        # First call
        result1 = clearable_shared_function(1)
        assert result1 == 50
        assert call_count == 1

        # Second call - should use cache
        result2 = clearable_shared_function(1)
        assert result2 == 50
        assert call_count == 1

        # Clear cache
        clearable_shared_function.cache_clear()
        # Clearing shared cache should go through a pipeline of DELETEs.
        assert self.mock_redis.pipeline.called

        # Third call - should execute function again
        result3 = clearable_shared_function(1)
        assert result3 == 50
        assert call_count == 2

    def test_shared_cache_delete(self):
        """Test deleting specific shared cache entry."""
        call_count = 0

        @cached(shared_cache=True)
        def deletable_shared_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 60

        # First call for x=1
        result1 = deletable_shared_function(1)
        assert result1 == 60
        assert call_count == 1

        # First call for x=2
        result2 = deletable_shared_function(2)
        assert result2 == 120
        assert call_count == 2

        # Delete entry for x=1
        was_deleted = deletable_shared_function.cache_delete(1)
        assert was_deleted is True

        # Call with x=1 should execute function again
        result3 = deletable_shared_function(1)
        assert result3 == 60
        assert call_count == 3

        # Call with x=2 should still use cache
        result4 = deletable_shared_function(2)
        assert result4 == 120
        assert call_count == 3

    def test_shared_cache_error_handling(self):
        """Test that Redis errors are handled gracefully."""
        call_count = 0

        @cached(shared_cache=True)
        def error_prone_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 70

        # Simulate Redis error
        self.mock_redis.get.side_effect = Exception("Redis connection error")

        # Function should still work
        result = error_prone_function(1)
        assert result == 70
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_async_shared_cache_error_handling(self):
        """Test that Redis errors are handled gracefully in async functions."""
        call_count = 0

        @cached(shared_cache=True)
        async def async_error_prone_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x * 80

        # Simulate Redis error
        self.mock_redis.get.side_effect = Exception("Redis connection error")

        # Function should still work
        result = await async_error_prone_function(1)
        assert result == 80
        assert call_count == 1

    def test_shared_cache_with_complex_types(self):
        """Test shared cache with complex return types (lists, dicts)."""
        call_count = 0

        @cached(shared_cache=True)
        def complex_return_function(x: int) -> dict:
            nonlocal call_count
            call_count += 1
            return {"value": x, "squared": x * x, "list": [1, 2, 3]}

        # First call
        result1 = complex_return_function(5)
        assert result1 == {"value": 5, "squared": 25, "list": [1, 2, 3]}
        assert call_count == 1

        # Second call - should use cache (value survives a pickle round trip)
        result2 = complex_return_function(5)
        assert result2 == {"value": 5, "squared": 25, "list": [1, 2, 3]}
        assert call_count == 1

    @pytest.mark.asyncio
    async def test_async_thundering_herd_shared_cache(self):
        """Test thundering herd protection with shared cache."""
        call_count = 0

        @cached(shared_cache=True)
        async def slow_shared_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.1)
            return x * x

        # Launch concurrent coroutines
        tasks = [slow_shared_function(9) for _ in range(5)]
        results = await asyncio.gather(*tasks)

        # All results should be the same
        assert all(result == 81 for result in results)
        # Only one coroutine should have executed the function
        assert call_count == 1

    def test_shared_cache_info(self):
        """Test cache_info with shared cache."""

        @cached(shared_cache=True, maxsize=100, ttl_seconds=300)
        def info_function(x: int) -> int:
            return x * 90

        # Call the function to populate cache
        info_function(1)

        # Get cache info
        info = info_function.cache_info()
        assert "size" in info
        assert info["maxsize"] == 100
        assert info["ttl_seconds"] == 300

View File

@@ -0,0 +1,241 @@
# Duplicate of backend/backend/util/retry.py
import asyncio
import logging
import os
import threading
import time
from functools import wraps
from uuid import uuid4
from tenacity import (
retry,
retry_if_not_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)
logger = logging.getLogger(__name__)
# Alert threshold for excessive retries
EXCESSIVE_RETRY_THRESHOLD = 50
def _send_critical_retry_alert(
    func_name: str, attempt_number: int, exception: Exception, context: str = ""
):
    """Send alert when a function is approaching the retry failure threshold."""
    try:
        # Imported lazily to avoid circular imports.
        from backend.util.clients import get_notification_manager_client

        prefix = f"{context}: " if context else ""
        alert_msg = (
            f"🚨 CRITICAL: Operation Approaching Failure Threshold: {prefix}'{func_name}'\n\n"
            f"Current attempt: {attempt_number}/{EXCESSIVE_RETRY_THRESHOLD}\n"
            f"Error: {type(exception).__name__}: {exception}\n\n"
            f"This operation is about to fail permanently. Investigate immediately."
        )
        get_notification_manager_client().discord_system_alert(alert_msg)
        logger.critical(
            f"CRITICAL ALERT SENT: Operation {func_name} at attempt {attempt_number}"
        )
    except Exception as alert_error:
        # Alerting must never break the caller's retry flow.
        logger.error(f"Failed to send critical retry alert: {alert_error}")
def _create_retry_callback(context: str = ""):
    """Create a retry callback with optional context."""

    def callback(retry_state):
        attempt_number = retry_state.attempt_number
        exception = retry_state.outcome.exception()
        func_name = getattr(retry_state.fn, "__name__", "unknown")
        prefix = f"{context}: " if context else ""

        # next_action is None only once tenacity has given up entirely.
        is_final_failure = (
            retry_state.outcome.failed and retry_state.next_action is None
        )
        if is_final_failure:
            # Final failure - just log the error (alert was already sent at excessive threshold)
            logger.error(
                f"{prefix}Giving up after {attempt_number} attempts for '{func_name}': "
                f"{type(exception).__name__}: {exception}"
            )
        elif attempt_number == EXCESSIVE_RETRY_THRESHOLD:
            # Fire the critical alert exactly once, at the threshold attempt.
            _send_critical_retry_alert(func_name, attempt_number, exception, context)
        else:
            logger.warning(
                f"{prefix}Retry attempt {attempt_number} for '{func_name}': "
                f"{type(exception).__name__}: {exception}"
            )

    return callback
def create_retry_decorator(
    max_attempts: int = 5,
    exclude_exceptions: tuple[type[BaseException], ...] = (),
    max_wait: float = 30.0,
    context: str = "",
    reraise: bool = True,
):
    """
    Create a preconfigured retry decorator with sensible defaults.

    Uses exponential backoff with jitter by default.

    Args:
        max_attempts: Maximum number of attempts (default: 5)
        exclude_exceptions: Tuple of exception types to not retry on
        max_wait: Maximum wait time in seconds (default: 30)
        context: Optional context string for log messages
        reraise: Whether to reraise the final exception (default: True)

    Returns:
        Configured retry decorator
    """
    # Shared configuration for both variants; only the retry predicate differs.
    retry_kwargs = {
        "stop": stop_after_attempt(max_attempts),
        "wait": wait_exponential_jitter(max=max_wait),
        "before_sleep": _create_retry_callback(context),
        "reraise": reraise,
    }
    if exclude_exceptions:
        retry_kwargs["retry"] = retry_if_not_exception_type(exclude_exceptions)
    return retry(**retry_kwargs)
def _log_prefix(resource_name: str, conn_id: str):
"""
Returns a prefix string for logging purposes.
This needs to be called on the fly to get the current process ID & service name,
not the parent process ID & service name.
"""
return f"[PID-{os.getpid()}|THREAD-{threading.get_native_id()}|{resource_name}-{conn_id}]"
def conn_retry(
    resource_name: str,
    action_name: str,
    max_retry: int = 5,
    max_wait: float = 30,
):
    """Decorator factory that retries a connection-acquisition function.

    Wraps the target (sync or async) with tenacity retries using
    exponential backoff with jitter, logging start/success/failure with a
    process/thread/connection-id prefix.

    Args:
        resource_name: Resource label used in log prefixes (e.g. "Redis").
        action_name: Human-readable action description for log lines.
        max_retry: Number of retries after the initial attempt.
        max_wait: Maximum backoff wait in seconds.
    """
    # NOTE(review): conn_id is generated once at decoration time, so every
    # call through the same decorated function shares one id — confirm
    # that is intended rather than a per-call id.
    conn_id = str(uuid4())

    def on_retry(retry_state):
        # tenacity before_sleep/final callback: warn on retries, error on
        # the terminal failure (next_action is None once tenacity gives up).
        prefix = _log_prefix(resource_name, conn_id)
        exception = retry_state.outcome.exception()
        if retry_state.outcome.failed and retry_state.next_action is None:
            logger.error(f"{prefix} {action_name} failed after retries: {exception}")
        else:
            logger.warning(
                f"{prefix} {action_name} failed: {exception}. Retrying now..."
            )

    def decorator(func):
        is_coroutine = asyncio.iscoroutinefunction(func)
        # Use static retry configuration
        retry_decorator = retry(
            stop=stop_after_attempt(max_retry + 1),  # +1 for the initial attempt
            wait=wait_exponential_jitter(max=max_wait),
            before_sleep=on_retry,
            reraise=True,
        )
        wrapped_func = retry_decorator(func)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            # Logs lifecycle around the tenacity-wrapped call; the final
            # exception is re-raised to the caller (reraise=True above).
            prefix = _log_prefix(resource_name, conn_id)
            logger.info(f"{prefix} {action_name} started...")
            try:
                result = wrapped_func(*args, **kwargs)
                logger.info(f"{prefix} {action_name} completed successfully.")
                return result
            except Exception as e:
                logger.error(f"{prefix} {action_name} failed after retries: {e}")
                raise

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            # Async twin of sync_wrapper with identical logging semantics.
            prefix = _log_prefix(resource_name, conn_id)
            logger.info(f"{prefix} {action_name} started...")
            try:
                result = await wrapped_func(*args, **kwargs)
                logger.info(f"{prefix} {action_name} completed successfully.")
                return result
            except Exception as e:
                logger.error(f"{prefix} {action_name} failed after retries: {e}")
                raise

        # Pick the wrapper matching the target's sync/async nature.
        return async_wrapper if is_coroutine else sync_wrapper

    return decorator
# Preconfigured retry decorator for general functions
func_retry = create_retry_decorator(max_attempts=5)
def continuous_retry(*, retry_delay: float = 1.0):
    """Decorator factory that retries the wrapped callable forever.

    The target (sync or async) is re-invoked until it returns, sleeping
    ``retry_delay`` seconds between attempts. Failures are logged as
    warnings, escalating to ``logger.exception`` (with traceback) on every
    10th consecutive failure.

    Args:
        retry_delay: Seconds to wait between attempts.

    Returns:
        A decorator producing a sync or async wrapper matching the target.
    """

    def decorator(func):
        is_coroutine = asyncio.iscoroutinefunction(func)

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            counter = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except Exception as exc:
                    counter += 1
                    # Full traceback every 10th failure; lighter warning otherwise.
                    if counter % 10 == 0:
                        log = logger.exception
                    else:
                        log = logger.warning
                    log(
                        "%s failed for the %s times, error: [%s] — retrying in %.2fs",
                        func.__name__,
                        counter,
                        str(exc) or type(exc).__name__,
                        retry_delay,
                    )
                    time.sleep(retry_delay)

        @wraps(func)
        async def async_wrapper(*args, **kwargs):
            # BUG FIX: counter was previously initialized *inside* the while
            # loop, resetting it on every attempt, so the escalation to
            # logger.exception never triggered. Initialize once, matching
            # sync_wrapper.
            counter = 0
            while True:
                try:
                    return await func(*args, **kwargs)
                except Exception as exc:
                    counter += 1
                    if counter % 10 == 0:
                        log = logger.exception
                    else:
                        log = logger.warning
                    log(
                        "%s failed for the %s times, error: [%s] — retrying in %.2fs",
                        func.__name__,
                        counter,
                        str(exc) or type(exc).__name__,
                        retry_delay,
                    )
                    await asyncio.sleep(retry_delay)

        return async_wrapper if is_coroutine else sync_wrapper

    return decorator

View File

@@ -1719,6 +1719,22 @@ files = [
httpx = {version = ">=0.26,<0.29", extras = ["http2"]}
strenum = ">=0.4.15,<0.5.0"
[[package]]
name = "tenacity"
version = "9.1.2"
description = "Retry code until it succeeds"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138"},
{file = "tenacity-9.1.2.tar.gz", hash = "sha256:1169d376c297e7de388d18b4481760d478b0e99a777cad3a9c86e556f4b697cb"},
]
[package.extras]
doc = ["reno", "sphinx"]
test = ["pytest", "tornado (>=4.5)", "typeguard"]
[[package]]
name = "tomli"
version = "2.2.1"
@@ -1929,4 +1945,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<4.0"
content-hash = "0c40b63c3c921846cf05ccfb4e685d4959854b29c2c302245f9832e20aac6954"
content-hash = "5ec9e6cd2ef7524a356586354755215699e7b37b9bbdfbabc9c73b43085915f4"

View File

@@ -19,6 +19,7 @@ pydantic-settings = "^2.10.1"
pyjwt = { version = "^2.10.1", extras = ["crypto"] }
redis = "^6.2.0"
supabase = "^2.16.0"
tenacity = "^9.1.2"
uvicorn = "^0.35.0"
[tool.poetry.group.dev.dependencies]

View File

@@ -18,7 +18,7 @@ from backend.data.block import get_blocks
# Cache block definitions with costs - they rarely change
@cached(maxsize=1, ttl_seconds=3600)
@cached(maxsize=1, ttl_seconds=3600, shared_cache=True)
def get_cached_blocks() -> Sequence[dict]:
"""
Get cached blocks with thundering herd protection.
@@ -45,7 +45,7 @@ def get_cached_blocks() -> Sequence[dict]:
# Cache user's graphs list for 15 minutes
@cached(maxsize=1000, ttl_seconds=900)
@cached(maxsize=1000, ttl_seconds=900, shared_cache=True)
async def get_cached_graphs(
user_id: str,
page: int,
@@ -60,7 +60,7 @@ async def get_cached_graphs(
# Cache individual graph details for 30 minutes
@cached(maxsize=500, ttl_seconds=1800)
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph(
graph_id: str,
version: int | None,
@@ -76,7 +76,7 @@ async def get_cached_graph(
# Cache graph versions for 30 minutes
@cached(maxsize=500, ttl_seconds=1800)
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph_all_versions(
graph_id: str,
user_id: str,
@@ -92,7 +92,7 @@ async def get_cached_graph_all_versions(
# Cache graph executions for 10 seconds.
@cached(maxsize=1000, ttl_seconds=10)
@cached(maxsize=1000, ttl_seconds=10, shared_cache=True)
async def get_cached_graph_executions(
graph_id: str,
user_id: str,
@@ -109,7 +109,7 @@ async def get_cached_graph_executions(
# Cache all user executions for 10 seconds.
@cached(maxsize=500, ttl_seconds=10)
@cached(maxsize=500, ttl_seconds=10, shared_cache=True)
async def get_cached_graphs_executions(
user_id: str,
page: int,
@@ -124,7 +124,7 @@ async def get_cached_graphs_executions(
# Cache individual execution details for 10 seconds.
@cached(maxsize=1000, ttl_seconds=10)
@cached(maxsize=1000, ttl_seconds=10, shared_cache=True)
async def get_cached_graph_execution(
graph_exec_id: str,
user_id: str,
@@ -141,7 +141,7 @@ async def get_cached_graph_execution(
# Cache user timezone for 1 hour
@cached(maxsize=1000, ttl_seconds=3600)
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def get_cached_user_timezone(user_id: str):
"""Cached helper to get user timezone."""
user = await user_db.get_user_by_id(user_id)
@@ -149,7 +149,7 @@ async def get_cached_user_timezone(user_id: str):
# Cache user preferences for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800)
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_user_preferences(user_id: str):
"""Cached helper to get user notification preferences."""
return await user_db.get_user_notification_preference(user_id)

View File

@@ -13,7 +13,7 @@ import backend.server.v2.library.db
# Cache library agents list for 10 minutes
@cached(maxsize=1000, ttl_seconds=600)
@cached(maxsize=1000, ttl_seconds=600, shared_cache=True)
async def get_cached_library_agents(
user_id: str,
page: int = 1,
@@ -28,7 +28,7 @@ async def get_cached_library_agents(
# Cache user's favorite agents for 5 minutes - favorites change more frequently
@cached(maxsize=500, ttl_seconds=300)
@cached(maxsize=500, ttl_seconds=300, shared_cache=True)
async def get_cached_library_agent_favorites(
user_id: str,
page: int = 1,
@@ -43,7 +43,7 @@ async def get_cached_library_agent_favorites(
# Cache individual library agent details for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800)
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent(
library_agent_id: str,
user_id: str,
@@ -56,7 +56,7 @@ async def get_cached_library_agent(
# Cache library agent by graph ID for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800)
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent_by_graph_id(
graph_id: str,
user_id: str,
@@ -69,7 +69,7 @@ async def get_cached_library_agent_by_graph_id(
# Cache library agent by store version ID for 1 hour - marketplace agents are more stable
@cached(maxsize=500, ttl_seconds=3600)
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def get_cached_library_agent_by_store_version(
store_listing_version_id: str,
user_id: str,
@@ -85,7 +85,7 @@ async def get_cached_library_agent_by_store_version(
# Cache library presets list for 30 minutes
@cached(maxsize=500, ttl_seconds=1800)
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_presets(
user_id: str,
page: int = 1,
@@ -100,7 +100,7 @@ async def get_cached_library_presets(
# Cache individual preset details for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800)
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_preset(
preset_id: str,
user_id: str,

View File

@@ -11,7 +11,7 @@ import backend.server.v2.store.db
# Cache user profiles for 1 hour per user
@cached(maxsize=1000, ttl_seconds=3600)
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def _get_cached_user_profile(user_id: str):
"""Cached helper to get user profile."""
return await backend.server.v2.store.db.get_user_profile(user_id)
@@ -19,7 +19,7 @@ async def _get_cached_user_profile(user_id: str):
# Cache store agents list for 15 minutes
# Different cache entries for different query combinations
@cached(maxsize=5000, ttl_seconds=900)
@cached(maxsize=5000, ttl_seconds=900, shared_cache=True)
async def _get_cached_store_agents(
featured: bool,
creator: str | None,
@@ -42,7 +42,7 @@ async def _get_cached_store_agents(
# Cache individual agent details for 15 minutes
@cached(maxsize=200, ttl_seconds=900)
@cached(maxsize=200, ttl_seconds=900, shared_cache=True)
async def _get_cached_agent_details(username: str, agent_name: str):
"""Cached helper to get agent details."""
return await backend.server.v2.store.db.get_store_agent_details(
@@ -51,7 +51,7 @@ async def _get_cached_agent_details(username: str, agent_name: str):
# Cache agent graphs for 1 hour
@cached(maxsize=200, ttl_seconds=3600)
@cached(maxsize=200, ttl_seconds=3600, shared_cache=True)
async def _get_cached_agent_graph(store_listing_version_id: str):
"""Cached helper to get agent graph."""
return await backend.server.v2.store.db.get_available_graph(
@@ -60,7 +60,7 @@ async def _get_cached_agent_graph(store_listing_version_id: str):
# Cache agent by version for 1 hour
@cached(maxsize=200, ttl_seconds=3600)
@cached(maxsize=200, ttl_seconds=3600, shared_cache=True)
async def _get_cached_store_agent_by_version(store_listing_version_id: str):
"""Cached helper to get store agent by version ID."""
return await backend.server.v2.store.db.get_store_agent_by_version_id(
@@ -69,7 +69,7 @@ async def _get_cached_store_agent_by_version(store_listing_version_id: str):
# Cache creators list for 1 hour
@cached(maxsize=200, ttl_seconds=3600)
@cached(maxsize=200, ttl_seconds=3600, shared_cache=True)
async def _get_cached_store_creators(
featured: bool,
search_query: str | None,
@@ -88,7 +88,7 @@ async def _get_cached_store_creators(
# Cache individual creator details for 1 hour
@cached(maxsize=100, ttl_seconds=3600)
@cached(maxsize=100, ttl_seconds=3600, shared_cache=True)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
return await backend.server.v2.store.db.get_store_creator_details(
@@ -97,7 +97,7 @@ async def _get_cached_creator_details(username: str):
# Cache user's own agents for 5 mins (shorter TTL as this changes more frequently)
@cached(maxsize=500, ttl_seconds=300)
@cached(maxsize=500, ttl_seconds=300, shared_cache=True)
async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
"""Cached helper to get user's agents."""
return await backend.server.v2.store.db.get_my_agents(
@@ -106,7 +106,7 @@ async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
# Cache user's submissions for 1 hour
@cached(maxsize=500, ttl_seconds=3600)
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
"""Cached helper to get user's submissions."""
return await backend.server.v2.store.db.get_store_submissions(