mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
added redis caching
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import threading
|
||||
import time
|
||||
from functools import wraps
|
||||
@@ -20,6 +22,92 @@ R_co = TypeVar("R_co", covariant=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis client providers (can be set externally)
|
||||
_redis_client_provider = None
|
||||
_async_redis_client_provider = None
|
||||
|
||||
|
||||
def set_redis_client_provider(sync_provider=None, async_provider=None):
|
||||
"""
|
||||
Set external Redis client providers.
|
||||
|
||||
This allows the backend to inject its Redis clients into the cache system.
|
||||
|
||||
Args:
|
||||
sync_provider: A callable that returns a sync Redis client
|
||||
async_provider: A callable that returns an async Redis client
|
||||
"""
|
||||
global _redis_client_provider, _async_redis_client_provider
|
||||
if sync_provider:
|
||||
_redis_client_provider = sync_provider
|
||||
if async_provider:
|
||||
_async_redis_client_provider = async_provider
|
||||
|
||||
|
||||
def _get_redis_client():
|
||||
"""Get Redis client from provider or create a default one."""
|
||||
if _redis_client_provider:
|
||||
try:
|
||||
client = _redis_client_provider()
|
||||
if client:
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get Redis client from provider: {e}")
|
||||
|
||||
# Fallback to creating our own client
|
||||
try:
|
||||
from redis import Redis
|
||||
|
||||
client = Redis(
|
||||
host=os.getenv("REDIS_HOST", "localhost"),
|
||||
port=int(os.getenv("REDIS_PORT", "6379")),
|
||||
password=os.getenv("REDIS_PASSWORD", None),
|
||||
decode_responses=False, # We'll use pickle for serialization
|
||||
)
|
||||
client.ping()
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to connect to Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _get_async_redis_client():
|
||||
"""Get async Redis client from provider or create a default one."""
|
||||
if _async_redis_client_provider:
|
||||
try:
|
||||
# Provider is an async function, we need to await it
|
||||
if inspect.iscoroutinefunction(_async_redis_client_provider):
|
||||
client = await _async_redis_client_provider()
|
||||
else:
|
||||
client = _async_redis_client_provider()
|
||||
if client:
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get async Redis client from provider: {e}")
|
||||
|
||||
# Fallback to creating our own client
|
||||
try:
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
client = AsyncRedis(
|
||||
host=os.getenv("REDIS_HOST", "localhost"),
|
||||
port=int(os.getenv("REDIS_PORT", "6379")),
|
||||
password=os.getenv("REDIS_PASSWORD", None),
|
||||
decode_responses=False, # We'll use pickle for serialization
|
||||
)
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create async Redis client: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _make_redis_key(func_name: str, key: tuple) -> str:
|
||||
"""Create a Redis key from function name and cache key."""
|
||||
# Convert the key to a string representation
|
||||
key_str = str(key)
|
||||
# Add a prefix to avoid collisions with other Redis usage
|
||||
return f"cache:{func_name}:{key_str}"
|
||||
|
||||
|
||||
def _make_hashable_key(
|
||||
args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
@@ -85,6 +173,7 @@ def cached(
|
||||
*,
|
||||
maxsize: int = 128,
|
||||
ttl_seconds: int | None = None,
|
||||
shared_cache: bool = False,
|
||||
) -> Callable[[Callable], CachedFunction]:
|
||||
"""
|
||||
Thundering herd safe cache decorator for both sync and async functions.
|
||||
@@ -92,10 +181,13 @@ def cached(
|
||||
Uses double-checked locking to prevent multiple threads/coroutines from
|
||||
executing the expensive operation simultaneously during cache misses.
|
||||
|
||||
When shared_cache=True, uses Redis for distributed caching across instances.
|
||||
|
||||
Args:
|
||||
func: The function to cache (when used without parentheses)
|
||||
maxsize: Maximum number of cached entries
|
||||
maxsize: Maximum number of cached entries (ignored when using Redis)
|
||||
ttl_seconds: Time to live in seconds. If None, entries never expire
|
||||
shared_cache: If True, use Redis for distributed caching
|
||||
|
||||
Returns:
|
||||
Decorated function or decorator
|
||||
@@ -127,50 +219,100 @@ def cached(
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
current_time = time.time()
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
# Try Redis first if shared_cache is enabled
|
||||
if shared_cache:
|
||||
redis_client = await _get_async_redis_client()
|
||||
if redis_client:
|
||||
redis_key = _make_redis_key(target_func.__name__, key)
|
||||
try:
|
||||
# Check Redis cache
|
||||
await redis_client.ping() # Ensure connection is alive
|
||||
cached_value = await redis_client.get(redis_key)
|
||||
if cached_value:
|
||||
result = pickle.loads(cast(bytes, cached_value))
|
||||
logger.info(
|
||||
f"Redis cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
|
||||
)
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
async with cache_lock:
|
||||
# Double-check: another coroutine might have populated cache
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache read failed: {e}")
|
||||
# Fall through to execute function
|
||||
else:
|
||||
# Fast path: check local cache without lock
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
logger.info(
|
||||
f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
|
||||
)
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.info(
|
||||
f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
|
||||
)
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
async with cache_lock:
|
||||
# Double-check: another coroutine might have populated cache
|
||||
if shared_cache:
|
||||
redis_client = await _get_async_redis_client()
|
||||
if redis_client:
|
||||
redis_key = _make_redis_key(target_func.__name__, key)
|
||||
try:
|
||||
cached_value = await redis_client.get(redis_key)
|
||||
if cached_value:
|
||||
return pickle.loads(cast(bytes, cached_value))
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache read failed in lock: {e}")
|
||||
else:
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
logger.info(f"Cache miss for {target_func.__name__}")
|
||||
result = await target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
if shared_cache:
|
||||
redis_client = await _get_async_redis_client()
|
||||
if redis_client:
|
||||
redis_key = _make_redis_key(target_func.__name__, key)
|
||||
try:
|
||||
serialized = pickle.dumps(result)
|
||||
await redis_client.set(
|
||||
redis_key,
|
||||
serialized,
|
||||
ex=ttl_seconds if ttl_seconds else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache write failed: {e}")
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = (
|
||||
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
)
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = (
|
||||
list(cache_storage.keys())[:-cutoff]
|
||||
if cutoff > 0
|
||||
else []
|
||||
)
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
|
||||
return result
|
||||
|
||||
@@ -185,50 +327,99 @@ def cached(
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
current_time = time.time()
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(f"Cache hit for {target_func.__name__}")
|
||||
# Try Redis first if shared_cache is enabled
|
||||
if shared_cache:
|
||||
redis_client = _get_redis_client()
|
||||
if redis_client:
|
||||
redis_key = _make_redis_key(target_func.__name__, key)
|
||||
try:
|
||||
# Check Redis cache
|
||||
cached_value = redis_client.get(redis_key)
|
||||
if cached_value:
|
||||
result = pickle.loads(cast(bytes, cached_value))
|
||||
logger.info(
|
||||
f"Redis cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
|
||||
)
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
with cache_lock:
|
||||
# Double-check: another thread might have populated cache
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache read failed: {e}")
|
||||
# Fall through to execute function
|
||||
else:
|
||||
# Fast path: check local cache without lock
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
logger.info(
|
||||
f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
|
||||
)
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.info(
|
||||
f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
|
||||
)
|
||||
return result
|
||||
|
||||
# Slow path: acquire lock for cache miss/expiry
|
||||
with cache_lock:
|
||||
# Double-check: another thread might have populated cache
|
||||
if shared_cache:
|
||||
redis_client = _get_redis_client()
|
||||
if redis_client:
|
||||
redis_key = _make_redis_key(target_func.__name__, key)
|
||||
try:
|
||||
cached_value = redis_client.get(redis_key)
|
||||
if cached_value:
|
||||
return pickle.loads(cast(bytes, cached_value))
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache read failed in lock: {e}")
|
||||
else:
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
return cache_storage[key]
|
||||
else:
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
logger.info(f"Cache miss for {target_func.__name__}")
|
||||
result = target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
if shared_cache:
|
||||
redis_client = _get_redis_client()
|
||||
if redis_client:
|
||||
redis_key = _make_redis_key(target_func.__name__, key)
|
||||
try:
|
||||
serialized = pickle.dumps(result)
|
||||
redis_client.set(
|
||||
redis_key,
|
||||
serialized,
|
||||
ex=ttl_seconds if ttl_seconds else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache write failed: {e}")
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = (
|
||||
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
)
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
# Cleanup if needed
|
||||
if len(cache_storage) > maxsize:
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = (
|
||||
list(cache_storage.keys())[:-cutoff]
|
||||
if cutoff > 0
|
||||
else []
|
||||
)
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
|
||||
return result
|
||||
|
||||
@@ -236,22 +427,105 @@ def cached(
|
||||
|
||||
# Add cache management methods
|
||||
def cache_clear() -> None:
|
||||
cache_storage.clear()
|
||||
"""Clear all cached entries."""
|
||||
if shared_cache:
|
||||
redis_client = (
|
||||
_get_redis_client()
|
||||
if not inspect.iscoroutinefunction(target_func)
|
||||
else None
|
||||
)
|
||||
if redis_client:
|
||||
try:
|
||||
# Clear all cache entries for this function
|
||||
pattern = f"cache:{target_func.__name__}:*"
|
||||
for key in redis_client.scan_iter(match=pattern):
|
||||
redis_client.delete(key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache clear failed: {e}")
|
||||
else:
|
||||
cache_storage.clear()
|
||||
|
||||
def cache_info() -> dict[str, int | None]:
|
||||
def cache_info() -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
if shared_cache:
|
||||
redis_client = (
|
||||
_get_redis_client()
|
||||
if not inspect.iscoroutinefunction(target_func)
|
||||
else None
|
||||
)
|
||||
if redis_client:
|
||||
try:
|
||||
pattern = f"cache:{target_func.__name__}:*"
|
||||
count = sum(1 for _ in redis_client.scan_iter(match=pattern))
|
||||
return {
|
||||
"size": count,
|
||||
"maxsize": None, # Not applicable for Redis
|
||||
"ttl_seconds": ttl_seconds,
|
||||
"shared_cache": True,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache info failed: {e}")
|
||||
return {
|
||||
"size": 0,
|
||||
"maxsize": None,
|
||||
"ttl_seconds": ttl_seconds,
|
||||
"shared_cache": True,
|
||||
"error": str(e),
|
||||
}
|
||||
return {
|
||||
"size": len(cache_storage),
|
||||
"maxsize": maxsize,
|
||||
"ttl_seconds": ttl_seconds,
|
||||
"shared_cache": False,
|
||||
}
|
||||
|
||||
def cache_delete(*args, **kwargs) -> bool:
|
||||
"""Delete a specific cache entry. Returns True if entry existed."""
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if key in cache_storage:
|
||||
del cache_storage[key]
|
||||
return True
|
||||
return False
|
||||
# Create appropriate cache_delete based on whether function is async
|
||||
if inspect.iscoroutinefunction(target_func):
|
||||
|
||||
async def async_cache_delete(*args, **kwargs) -> bool:
|
||||
"""Delete a specific cache entry. Returns True if entry existed."""
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
|
||||
if shared_cache:
|
||||
redis_client = await _get_async_redis_client()
|
||||
if redis_client:
|
||||
redis_key = _make_redis_key(target_func.__name__, key)
|
||||
try:
|
||||
result = await redis_client.delete(redis_key)
|
||||
return cast(int, result) > 0
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache delete failed: {e}")
|
||||
return False
|
||||
else:
|
||||
if key in cache_storage:
|
||||
del cache_storage[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
cache_delete = async_cache_delete
|
||||
else:
|
||||
|
||||
def sync_cache_delete(*args, **kwargs) -> bool:
|
||||
"""Delete a specific cache entry. Returns True if entry existed."""
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
|
||||
if shared_cache:
|
||||
redis_client = _get_redis_client()
|
||||
if redis_client:
|
||||
redis_key = _make_redis_key(target_func.__name__, key)
|
||||
try:
|
||||
result = redis_client.delete(redis_key)
|
||||
return cast(int, result) > 0
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis cache delete failed: {e}")
|
||||
return False
|
||||
else:
|
||||
if key in cache_storage:
|
||||
del cache_storage[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
cache_delete = sync_cache_delete
|
||||
|
||||
setattr(wrapper, "cache_clear", cache_clear)
|
||||
setattr(wrapper, "cache_info", cache_info)
|
||||
|
||||
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Test Redis cache functionality.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_libs.utils.cache import cached
|
||||
|
||||
|
||||
# Test with Redis cache enabled
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_cache_async():
|
||||
"""Test async function with Redis cache."""
|
||||
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=60, shared_cache=True)
|
||||
async def expensive_async_operation(x: int, y: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.1)
|
||||
return x + y
|
||||
|
||||
# First call should execute function
|
||||
result1 = await expensive_async_operation(5, 3)
|
||||
assert result1 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args should use cache
|
||||
result2 = await expensive_async_operation(5, 3)
|
||||
assert result2 == 8
|
||||
assert call_count == 1 # Should not increment
|
||||
|
||||
# Different args should execute function again
|
||||
result3 = await expensive_async_operation(10, 5)
|
||||
assert result3 == 15
|
||||
assert call_count == 2
|
||||
|
||||
# Test cache_delete
|
||||
deleted = expensive_async_operation.cache_delete(5, 3)
|
||||
assert isinstance(deleted, bool) # Depends on Redis availability
|
||||
|
||||
# Test cache_clear
|
||||
expensive_async_operation.cache_clear()
|
||||
|
||||
# Test cache_info
|
||||
info = expensive_async_operation.cache_info()
|
||||
assert "ttl_seconds" in info
|
||||
assert info["ttl_seconds"] == 60
|
||||
assert "shared_cache" in info
|
||||
|
||||
|
||||
def test_redis_cache_sync():
|
||||
"""Test sync function with Redis cache."""
|
||||
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=60, shared_cache=True)
|
||||
def expensive_sync_operation(x: int, y: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
time.sleep(0.1)
|
||||
return x * y
|
||||
|
||||
# First call should execute function
|
||||
result1 = expensive_sync_operation(5, 3)
|
||||
assert result1 == 15
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args should use cache
|
||||
result2 = expensive_sync_operation(5, 3)
|
||||
assert result2 == 15
|
||||
assert call_count == 1 # Should not increment
|
||||
|
||||
# Different args should execute function again
|
||||
result3 = expensive_sync_operation(10, 5)
|
||||
assert result3 == 50
|
||||
assert call_count == 2
|
||||
|
||||
# Test cache management functions
|
||||
expensive_sync_operation.cache_clear()
|
||||
info = expensive_sync_operation.cache_info()
|
||||
assert info["shared_cache"] is True
|
||||
|
||||
|
||||
def test_fallback_to_local_cache_when_redis_unavailable():
|
||||
"""Test that cache falls back to local when Redis is unavailable."""
|
||||
|
||||
with patch("autogpt_libs.utils.cache._get_redis_client", return_value=None):
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=60, shared_cache=True)
|
||||
def operation(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 2
|
||||
|
||||
# Should still work with local cache
|
||||
result1 = operation(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
result2 = operation(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1 # Cached locally
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_redis_cache_with_complex_types():
|
||||
"""Test Redis cache with complex data types."""
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
async def get_complex_data(user_id: str, filters: dict) -> dict:
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"filters": filters,
|
||||
"data": [1, 2, 3, 4, 5],
|
||||
"nested": {"key1": "value1", "key2": ["a", "b", "c"]},
|
||||
}
|
||||
|
||||
result1 = await get_complex_data("user123", {"status": "active", "limit": 10})
|
||||
result2 = await get_complex_data("user123", {"status": "active", "limit": 10})
|
||||
|
||||
assert result1 == result2
|
||||
assert result1["user_id"] == "user123"
|
||||
assert result1["filters"]["status"] == "active"
|
||||
|
||||
|
||||
def test_local_cache_without_shared():
|
||||
"""Test that shared_cache=False uses local cache only."""
|
||||
|
||||
call_count = 0
|
||||
|
||||
@cached(maxsize=10, ttl_seconds=30, shared_cache=False)
|
||||
def local_only_operation(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
result1 = local_only_operation(4)
|
||||
assert result1 == 16
|
||||
assert call_count == 1
|
||||
|
||||
result2 = local_only_operation(4)
|
||||
assert result2 == 16
|
||||
assert call_count == 1
|
||||
|
||||
# Check cache info shows it's not shared
|
||||
info = local_only_operation.cache_info()
|
||||
assert info["shared_cache"] is False
|
||||
assert info["maxsize"] == 10
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run basic tests
|
||||
print("Testing sync Redis cache...")
|
||||
test_redis_cache_sync()
|
||||
print("✓ Sync Redis cache test passed")
|
||||
|
||||
print("Testing local cache...")
|
||||
test_local_cache_without_shared()
|
||||
print("✓ Local cache test passed")
|
||||
|
||||
print("Testing fallback...")
|
||||
test_fallback_to_local_cache_when_redis_unavailable()
|
||||
print("✓ Fallback test passed")
|
||||
|
||||
print("\nAll tests passed!")
|
||||
@@ -70,6 +70,11 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
|
||||
await backend.data.db.connect()
|
||||
|
||||
# Initialize Redis clients for cache system
|
||||
from backend.util.cache_redis_init import initialize_cache_redis
|
||||
|
||||
initialize_cache_redis()
|
||||
|
||||
# Ensure SDK auto-registration is patched before initializing blocks
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
|
||||
|
||||
@@ -45,7 +45,8 @@ def get_cached_blocks() -> Sequence[dict]:
|
||||
|
||||
|
||||
# Cache user's graphs list for 15 minutes
|
||||
@cached(maxsize=1000, ttl_seconds=900)
|
||||
# Uses shared_cache since cache_delete is called on this function
|
||||
@cached(maxsize=1000, ttl_seconds=900, shared_cache=True)
|
||||
async def get_cached_graphs(
|
||||
user_id: str,
|
||||
page: int,
|
||||
@@ -60,7 +61,8 @@ async def get_cached_graphs(
|
||||
|
||||
|
||||
# Cache individual graph details for 30 minutes
|
||||
@cached(maxsize=500, ttl_seconds=1800)
|
||||
# Uses shared_cache since cache_delete is called on this function
|
||||
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
|
||||
async def get_cached_graph(
|
||||
graph_id: str,
|
||||
version: int | None,
|
||||
@@ -76,7 +78,8 @@ async def get_cached_graph(
|
||||
|
||||
|
||||
# Cache graph versions for 30 minutes
|
||||
@cached(maxsize=500, ttl_seconds=1800)
|
||||
# Uses shared_cache since cache_delete is called on this function
|
||||
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
|
||||
async def get_cached_graph_all_versions(
|
||||
graph_id: str,
|
||||
user_id: str,
|
||||
@@ -92,7 +95,8 @@ async def get_cached_graph_all_versions(
|
||||
|
||||
|
||||
# Cache graph executions for 10 seconds.
|
||||
@cached(maxsize=1000, ttl_seconds=10)
|
||||
# Uses shared_cache since cache_delete is called on this function
|
||||
@cached(maxsize=1000, ttl_seconds=10, shared_cache=True)
|
||||
async def get_cached_graph_executions(
|
||||
graph_id: str,
|
||||
user_id: str,
|
||||
@@ -109,6 +113,7 @@ async def get_cached_graph_executions(
|
||||
|
||||
|
||||
# Cache all user executions for 10 seconds.
|
||||
# No shared_cache - cache_delete not used for this function
|
||||
@cached(maxsize=500, ttl_seconds=10)
|
||||
async def get_cached_graphs_executions(
|
||||
user_id: str,
|
||||
@@ -124,6 +129,7 @@ async def get_cached_graphs_executions(
|
||||
|
||||
|
||||
# Cache individual execution details for 10 seconds.
|
||||
# No shared_cache - cache_delete not used for this function
|
||||
@cached(maxsize=1000, ttl_seconds=10)
|
||||
async def get_cached_graph_execution(
|
||||
graph_exec_id: str,
|
||||
@@ -141,7 +147,8 @@ async def get_cached_graph_execution(
|
||||
|
||||
|
||||
# Cache user timezone for 1 hour
|
||||
@cached(maxsize=1000, ttl_seconds=3600)
|
||||
# Uses shared_cache since cache_delete is called on this function
|
||||
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
|
||||
async def get_cached_user_timezone(user_id: str):
|
||||
"""Cached helper to get user timezone."""
|
||||
user = await user_db.get_user_by_id(user_id)
|
||||
@@ -149,7 +156,8 @@ async def get_cached_user_timezone(user_id: str):
|
||||
|
||||
|
||||
# Cache user preferences for 30 minutes
|
||||
@cached(maxsize=1000, ttl_seconds=1800)
|
||||
# Uses shared_cache since cache_delete is called on this function
|
||||
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
|
||||
async def get_cached_user_preferences(user_id: str):
|
||||
"""Cached helper to get user notification preferences."""
|
||||
return await user_db.get_user_notification_preference(user_id)
|
||||
|
||||
@@ -13,7 +13,8 @@ import backend.server.v2.library.db
|
||||
|
||||
|
||||
# Cache library agents list for 10 minutes
|
||||
@cached(maxsize=1000, ttl_seconds=600)
|
||||
# Uses shared_cache since cache_delete is called on this function
|
||||
@cached(maxsize=1000, ttl_seconds=600, shared_cache=True)
|
||||
async def get_cached_library_agents(
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
@@ -28,7 +29,8 @@ async def get_cached_library_agents(
|
||||
|
||||
|
||||
# Cache user's favorite agents for 5 minutes - favorites change more frequently
|
||||
@cached(maxsize=500, ttl_seconds=300)
|
||||
# No shared_cache - cache_delete not used for this function (based on search results)
|
||||
@cached(maxsize=500, ttl_seconds=300, shared_cache=False)
|
||||
async def get_cached_library_agent_favorites(
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
@@ -43,7 +45,8 @@ async def get_cached_library_agent_favorites(
|
||||
|
||||
|
||||
# Cache individual library agent details for 30 minutes
|
||||
@cached(maxsize=1000, ttl_seconds=1800)
|
||||
# Uses shared_cache since cache_delete is called on this function
|
||||
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
|
||||
async def get_cached_library_agent(
|
||||
library_agent_id: str,
|
||||
user_id: str,
|
||||
@@ -56,7 +59,7 @@ async def get_cached_library_agent(
|
||||
|
||||
|
||||
# Cache library agent by graph ID for 30 minutes
|
||||
@cached(maxsize=1000, ttl_seconds=1800)
|
||||
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
|
||||
async def get_cached_library_agent_by_graph_id(
|
||||
graph_id: str,
|
||||
user_id: str,
|
||||
@@ -69,7 +72,7 @@ async def get_cached_library_agent_by_graph_id(
|
||||
|
||||
|
||||
# Cache library agent by store version ID for 1 hour - marketplace agents are more stable
|
||||
@cached(maxsize=500, ttl_seconds=3600)
|
||||
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
|
||||
async def get_cached_library_agent_by_store_version(
|
||||
store_listing_version_id: str,
|
||||
user_id: str,
|
||||
@@ -85,7 +88,8 @@ async def get_cached_library_agent_by_store_version(
|
||||
|
||||
|
||||
# Cache library presets list for 30 minutes
|
||||
@cached(maxsize=500, ttl_seconds=1800)
|
||||
# Uses shared_cache since cache_delete is called on this function
|
||||
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
|
||||
async def get_cached_library_presets(
|
||||
user_id: str,
|
||||
page: int = 1,
|
||||
@@ -100,7 +104,8 @@ async def get_cached_library_presets(
|
||||
|
||||
|
||||
# Cache individual preset details for 30 minutes
|
||||
@cached(maxsize=1000, ttl_seconds=1800)
|
||||
# Uses shared_cache since cache_delete is called on this function
|
||||
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
|
||||
async def get_cached_library_preset(
|
||||
preset_id: str,
|
||||
user_id: str,
|
||||
|
||||
@@ -11,7 +11,8 @@ import backend.server.v2.store.db
|
||||
|
||||
|
||||
# Cache user profiles for 1 hour per user
|
||||
@cached(maxsize=1000, ttl_seconds=3600)
|
||||
# Uses shared_cache since cache_delete is called on this function
|
||||
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
|
||||
async def _get_cached_user_profile(user_id: str):
|
||||
"""Cached helper to get user profile."""
|
||||
return await backend.server.v2.store.db.get_user_profile(user_id)
|
||||
@@ -19,6 +20,7 @@ async def _get_cached_user_profile(user_id: str):
|
||||
|
||||
# Cache store agents list for 15 minutes
|
||||
# Different cache entries for different query combinations
|
||||
# No shared_cache - cache_delete not used for this function
|
||||
@cached(maxsize=5000, ttl_seconds=900)
|
||||
async def _get_cached_store_agents(
|
||||
featured: bool,
|
||||
@@ -42,6 +44,7 @@ async def _get_cached_store_agents(
|
||||
|
||||
|
||||
# Cache individual agent details for 15 minutes
|
||||
# No shared_cache - cache_delete not used for this function
|
||||
@cached(maxsize=200, ttl_seconds=900)
|
||||
async def _get_cached_agent_details(username: str, agent_name: str):
|
||||
"""Cached helper to get agent details."""
|
||||
@@ -51,6 +54,7 @@ async def _get_cached_agent_details(username: str, agent_name: str):
|
||||
|
||||
|
||||
# Cache agent graphs for 1 hour
|
||||
# No shared_cache - cache_delete not used for this function
|
||||
@cached(maxsize=200, ttl_seconds=3600)
|
||||
async def _get_cached_agent_graph(store_listing_version_id: str):
|
||||
"""Cached helper to get agent graph."""
|
||||
@@ -60,6 +64,7 @@ async def _get_cached_agent_graph(store_listing_version_id: str):
|
||||
|
||||
|
||||
# Cache agent by version for 1 hour
|
||||
# No shared_cache - cache_delete not used for this function
|
||||
@cached(maxsize=200, ttl_seconds=3600)
|
||||
async def _get_cached_store_agent_by_version(store_listing_version_id: str):
|
||||
"""Cached helper to get store agent by version ID."""
|
||||
@@ -69,6 +74,7 @@ async def _get_cached_store_agent_by_version(store_listing_version_id: str):
|
||||
|
||||
|
||||
# Cache creators list for 1 hour
|
||||
# No shared_cache - cache_delete not used for this function
|
||||
@cached(maxsize=200, ttl_seconds=3600)
|
||||
async def _get_cached_store_creators(
|
||||
featured: bool,
|
||||
@@ -88,6 +94,7 @@ async def _get_cached_store_creators(
|
||||
|
||||
|
||||
# Cache individual creator details for 1 hour
|
||||
# No shared_cache - cache_delete not used for this function
|
||||
@cached(maxsize=100, ttl_seconds=3600)
|
||||
async def _get_cached_creator_details(username: str):
|
||||
"""Cached helper to get creator details."""
|
||||
@@ -97,6 +104,7 @@ async def _get_cached_creator_details(username: str):
|
||||
|
||||
|
||||
# Cache user's own agents for 5 mins (shorter TTL as this changes more frequently)
|
||||
# No shared_cache - cache_delete not used for this function
|
||||
@cached(maxsize=500, ttl_seconds=300)
|
||||
async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
|
||||
"""Cached helper to get user's agents."""
|
||||
@@ -106,7 +114,8 @@ async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
|
||||
|
||||
|
||||
# Cache user's submissions for 1 hour (shorter TTL as this changes frequently)
|
||||
@cached(maxsize=500, ttl_seconds=3600)
|
||||
# Uses shared_cache since cache_delete is called on this function
|
||||
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
|
||||
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
|
||||
"""Cached helper to get user's submissions."""
|
||||
return await backend.server.v2.store.db.get_store_submissions(
|
||||
|
||||
76
autogpt_platform/backend/backend/util/cache_redis_init.py
Normal file
76
autogpt_platform/backend/backend/util/cache_redis_init.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Initialize Redis clients for the cache system.
|
||||
|
||||
This module bridges the gap between backend's Redis client and autogpt_libs cache system.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from autogpt_libs.utils.cache import set_redis_client_provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def initialize_cache_redis():
|
||||
"""
|
||||
Initialize the cache system with backend's Redis clients.
|
||||
|
||||
This function sets up the cache system to use the backend's
|
||||
existing Redis connection instead of creating its own.
|
||||
"""
|
||||
try:
|
||||
from backend.data.redis_client import HOST, PASSWORD, PORT
|
||||
|
||||
# Create provider functions that create new binary-mode clients
|
||||
def get_sync_redis_for_cache():
|
||||
"""Get sync Redis client configured for cache (binary mode)."""
|
||||
try:
|
||||
from redis import Redis
|
||||
|
||||
client = Redis(
|
||||
host=HOST,
|
||||
port=PORT,
|
||||
password=PASSWORD,
|
||||
decode_responses=False, # Binary mode for pickle
|
||||
)
|
||||
# Test the connection
|
||||
client.ping()
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get Redis client for cache: {e}")
|
||||
return None
|
||||
|
||||
async def get_async_redis_for_cache():
|
||||
"""Get async Redis client configured for cache (binary mode)."""
|
||||
try:
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
client = AsyncRedis(
|
||||
host=HOST,
|
||||
port=PORT,
|
||||
password=PASSWORD,
|
||||
decode_responses=False, # Binary mode for pickle
|
||||
)
|
||||
# Test the connection
|
||||
await client.ping()
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get async Redis client for cache: {e}")
|
||||
return None
|
||||
|
||||
# Set the providers in the cache system
|
||||
set_redis_client_provider(
|
||||
sync_provider=get_sync_redis_for_cache,
|
||||
async_provider=get_async_redis_for_cache,
|
||||
)
|
||||
|
||||
logger.info("Cache system initialized with backend Redis clients")
|
||||
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import Redis clients, cache will use fallback: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize cache Redis: {e}")
|
||||
|
||||
|
||||
# Auto-initialize when module is imported
|
||||
initialize_cache_redis()
|
||||
Reference in New Issue
Block a user