Add Redis caching

This commit is contained in:
Swifty
2025-09-25 16:45:40 +02:00
parent 1afebcf96b
commit ae6ad35bf6
7 changed files with 628 additions and 79 deletions

View File

@@ -1,6 +1,8 @@
import asyncio
import inspect
import logging
import os
import pickle
import threading
import time
from functools import wraps
@@ -20,6 +22,92 @@ R_co = TypeVar("R_co", covariant=True)
logger = logging.getLogger(__name__)
# Redis client providers (can be set externally)
_redis_client_provider = None
_async_redis_client_provider = None
def set_redis_client_provider(sync_provider=None, async_provider=None):
    """
    Register external Redis client providers.

    Lets the hosting backend inject its own Redis clients into the cache
    system instead of letting the cache build default connections itself.

    Args:
        sync_provider: A callable that returns a sync Redis client
        async_provider: A callable that returns an async Redis client
    """
    global _redis_client_provider, _async_redis_client_provider
    # Only overwrite a provider slot when a replacement was actually
    # supplied, so callers may update one provider without clearing the other.
    if sync_provider:
        _redis_client_provider = sync_provider
    if async_provider:
        _async_redis_client_provider = async_provider
def _get_redis_client():
    """Get Redis client from provider or create a default one."""
    # Prefer an externally injected provider when one has been registered.
    if _redis_client_provider:
        try:
            provided = _redis_client_provider()
        except Exception as e:
            logger.warning(f"Failed to get Redis client from provider: {e}")
        else:
            if provided:
                return provided
    # Fallback to creating our own client from environment configuration.
    try:
        from redis import Redis

        fallback = Redis(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            password=os.getenv("REDIS_PASSWORD", None),
            decode_responses=False,  # We'll use pickle for serialization
        )
        fallback.ping()  # Verify the connection before handing it out
    except Exception as e:
        logger.warning(f"Failed to connect to Redis: {e}")
        return None
    return fallback
async def _get_async_redis_client():
    """Get async Redis client from provider or create a default one.

    Mirrors _get_redis_client(): first ask the injected provider, then fall
    back to building a client from environment variables. Returns None when
    no working client can be obtained.
    """
    if _async_redis_client_provider:
        try:
            # Provider may be either an async or a sync callable; await only
            # when it is actually a coroutine function.
            if inspect.iscoroutinefunction(_async_redis_client_provider):
                client = await _async_redis_client_provider()
            else:
                client = _async_redis_client_provider()
            if client:
                return client
        except Exception as e:
            logger.warning(f"Failed to get async Redis client from provider: {e}")
    # Fallback to creating our own client
    try:
        from redis.asyncio import Redis as AsyncRedis

        client = AsyncRedis(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            password=os.getenv("REDIS_PASSWORD", None),
            decode_responses=False,  # We'll use pickle for serialization
        )
        # Validate the connection up front, matching the sync fallback —
        # otherwise a dead server would only surface on first cache access.
        await client.ping()
        return client
    except Exception as e:
        logger.warning(f"Failed to create async Redis client: {e}")
        return None
def _make_redis_key(func_name: str, key: tuple) -> str:
"""Create a Redis key from function name and cache key."""
# Convert the key to a string representation
key_str = str(key)
# Add a prefix to avoid collisions with other Redis usage
return f"cache:{func_name}:{key_str}"
def _make_hashable_key(
args: tuple[Any, ...], kwargs: dict[str, Any]
@@ -85,6 +173,7 @@ def cached(
*,
maxsize: int = 128,
ttl_seconds: int | None = None,
shared_cache: bool = False,
) -> Callable[[Callable], CachedFunction]:
"""
Thundering herd safe cache decorator for both sync and async functions.
@@ -92,10 +181,13 @@ def cached(
Uses double-checked locking to prevent multiple threads/coroutines from
executing the expensive operation simultaneously during cache misses.
When shared_cache=True, uses Redis for distributed caching across instances.
Args:
func: The function to cache (when used without parentheses)
maxsize: Maximum number of cached entries
maxsize: Maximum number of cached entries (ignored when using Redis)
ttl_seconds: Time to live in seconds. If None, entries never expire
shared_cache: If True, use Redis for distributed caching
Returns:
Decorated function or decorator
@@ -127,50 +219,100 @@ def cached(
key = _make_hashable_key(args, kwargs)
current_time = time.time()
# Fast path: check cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.debug(f"Cache hit for {target_func.__name__}")
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(f"Cache hit for {target_func.__name__}")
# Try Redis first if shared_cache is enabled
if shared_cache:
redis_client = await _get_async_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
# Check Redis cache
await redis_client.ping() # Ensure connection is alive
cached_value = await redis_client.get(redis_key)
if cached_value:
result = pickle.loads(cast(bytes, cached_value))
logger.info(
f"Redis cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
)
return result
# Slow path: acquire lock for cache miss/expiry
async with cache_lock:
# Double-check: another coroutine might have populated cache
except Exception as e:
logger.warning(f"Redis cache read failed: {e}")
# Fall through to execute function
else:
# Fast path: check local cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.info(
f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
)
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.info(
f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
)
return result
# Slow path: acquire lock for cache miss/expiry
async with cache_lock:
# Double-check: another coroutine might have populated cache
if shared_cache:
redis_client = await _get_async_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
cached_value = await redis_client.get(redis_key)
if cached_value:
return pickle.loads(cast(bytes, cached_value))
except Exception as e:
logger.warning(f"Redis cache read failed in lock: {e}")
else:
if key in cache_storage:
if ttl_seconds is None:
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
logger.info(f"Cache miss for {target_func.__name__}")
result = await target_func(*args, **kwargs)
# Store result
if ttl_seconds is None:
cache_storage[key] = result
if shared_cache:
redis_client = await _get_async_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
serialized = pickle.dumps(result)
await redis_client.set(
redis_key,
serialized,
ex=ttl_seconds if ttl_seconds else None,
)
except Exception as e:
logger.warning(f"Redis cache write failed: {e}")
else:
cache_storage[key] = (result, current_time)
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff]
if cutoff > 0
else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
return result
@@ -185,50 +327,99 @@ def cached(
key = _make_hashable_key(args, kwargs)
current_time = time.time()
# Fast path: check cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.debug(f"Cache hit for {target_func.__name__}")
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(f"Cache hit for {target_func.__name__}")
# Try Redis first if shared_cache is enabled
if shared_cache:
redis_client = _get_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
# Check Redis cache
cached_value = redis_client.get(redis_key)
if cached_value:
result = pickle.loads(cast(bytes, cached_value))
logger.info(
f"Redis cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
)
return result
# Slow path: acquire lock for cache miss/expiry
with cache_lock:
# Double-check: another thread might have populated cache
except Exception as e:
logger.warning(f"Redis cache read failed: {e}")
# Fall through to execute function
else:
# Fast path: check local cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.info(
f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
)
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.info(
f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
)
return result
# Slow path: acquire lock for cache miss/expiry
with cache_lock:
# Double-check: another thread might have populated cache
if shared_cache:
redis_client = _get_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
cached_value = redis_client.get(redis_key)
if cached_value:
return pickle.loads(cast(bytes, cached_value))
except Exception as e:
logger.warning(f"Redis cache read failed in lock: {e}")
else:
if key in cache_storage:
if ttl_seconds is None:
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
logger.info(f"Cache miss for {target_func.__name__}")
result = target_func(*args, **kwargs)
# Store result
if ttl_seconds is None:
cache_storage[key] = result
if shared_cache:
redis_client = _get_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
serialized = pickle.dumps(result)
redis_client.set(
redis_key,
serialized,
ex=ttl_seconds if ttl_seconds else None,
)
except Exception as e:
logger.warning(f"Redis cache write failed: {e}")
else:
cache_storage[key] = (result, current_time)
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff]
if cutoff > 0
else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
return result
@@ -236,22 +427,105 @@ def cached(
# Add cache management methods
def cache_clear() -> None:
cache_storage.clear()
"""Clear all cached entries."""
if shared_cache:
redis_client = (
_get_redis_client()
if not inspect.iscoroutinefunction(target_func)
else None
)
if redis_client:
try:
# Clear all cache entries for this function
pattern = f"cache:{target_func.__name__}:*"
for key in redis_client.scan_iter(match=pattern):
redis_client.delete(key)
except Exception as e:
logger.warning(f"Redis cache clear failed: {e}")
else:
cache_storage.clear()
def cache_info() -> dict[str, int | None]:
def cache_info() -> dict[str, Any]:
"""Get cache statistics."""
if shared_cache:
redis_client = (
_get_redis_client()
if not inspect.iscoroutinefunction(target_func)
else None
)
if redis_client:
try:
pattern = f"cache:{target_func.__name__}:*"
count = sum(1 for _ in redis_client.scan_iter(match=pattern))
return {
"size": count,
"maxsize": None, # Not applicable for Redis
"ttl_seconds": ttl_seconds,
"shared_cache": True,
}
except Exception as e:
logger.warning(f"Redis cache info failed: {e}")
return {
"size": 0,
"maxsize": None,
"ttl_seconds": ttl_seconds,
"shared_cache": True,
"error": str(e),
}
return {
"size": len(cache_storage),
"maxsize": maxsize,
"ttl_seconds": ttl_seconds,
"shared_cache": False,
}
def cache_delete(*args, **kwargs) -> bool:
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if key in cache_storage:
del cache_storage[key]
return True
return False
# Create appropriate cache_delete based on whether function is async
if inspect.iscoroutinefunction(target_func):
async def async_cache_delete(*args, **kwargs) -> bool:
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if shared_cache:
redis_client = await _get_async_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
result = await redis_client.delete(redis_key)
return cast(int, result) > 0
except Exception as e:
logger.warning(f"Redis cache delete failed: {e}")
return False
else:
if key in cache_storage:
del cache_storage[key]
return True
return False
cache_delete = async_cache_delete
else:
def sync_cache_delete(*args, **kwargs) -> bool:
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if shared_cache:
redis_client = _get_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
result = redis_client.delete(redis_key)
return cast(int, result) > 0
except Exception as e:
logger.warning(f"Redis cache delete failed: {e}")
return False
else:
if key in cache_storage:
del cache_storage[key]
return True
return False
cache_delete = sync_cache_delete
setattr(wrapper, "cache_clear", cache_clear)
setattr(wrapper, "cache_info", cache_info)

View File

@@ -0,0 +1,172 @@
"""
Test Redis cache functionality.
"""
import asyncio
import time
from unittest.mock import patch
import pytest
from autogpt_libs.utils.cache import cached
# Test with Redis cache enabled
@pytest.mark.asyncio
async def test_redis_cache_async():
    """Test async function with Redis cache."""
    call_count = 0

    @cached(ttl_seconds=60, shared_cache=True)
    async def expensive_async_operation(x: int, y: int) -> int:
        nonlocal call_count
        call_count += 1
        await asyncio.sleep(0.1)
        return x + y

    # First call should execute function
    result1 = await expensive_async_operation(5, 3)
    assert result1 == 8
    assert call_count == 1

    # Second call with same args should use cache
    result2 = await expensive_async_operation(5, 3)
    assert result2 == 8
    assert call_count == 1  # Should not increment

    # Different args should execute function again
    result3 = await expensive_async_operation(10, 5)
    assert result3 == 15
    assert call_count == 2

    # cache_delete is async for async targets, so it must be awaited;
    # without await it returns a coroutine and the isinstance check fails.
    deleted = await expensive_async_operation.cache_delete(5, 3)
    assert isinstance(deleted, bool)  # Depends on Redis availability

    # Test cache_clear
    expensive_async_operation.cache_clear()

    # Test cache_info
    info = expensive_async_operation.cache_info()
    assert "ttl_seconds" in info
    assert info["ttl_seconds"] == 60
    assert "shared_cache" in info
def test_redis_cache_sync():
    """Test sync function with Redis cache."""
    invocations = 0

    @cached(ttl_seconds=60, shared_cache=True)
    def expensive_sync_operation(x: int, y: int) -> int:
        nonlocal invocations
        invocations += 1
        time.sleep(0.1)
        return x * y

    # Initial call misses the cache and runs the function body.
    assert expensive_sync_operation(5, 3) == 15
    assert invocations == 1

    # Repeating the same arguments should be served from the cache.
    assert expensive_sync_operation(5, 3) == 15
    assert invocations == 1  # Should not increment

    # A new argument pair triggers another real execution.
    assert expensive_sync_operation(10, 5) == 50
    assert invocations == 2

    # Management helpers should work on a shared-cache function.
    expensive_sync_operation.cache_clear()
    info = expensive_sync_operation.cache_info()
    assert info["shared_cache"] is True
def test_fallback_to_local_cache_when_redis_unavailable():
    """Test that cache falls back to local when Redis is unavailable."""
    with patch("autogpt_libs.utils.cache._get_redis_client", return_value=None):
        invocations = 0

        @cached(ttl_seconds=60, shared_cache=True)
        def operation(x: int) -> int:
            nonlocal invocations
            invocations += 1
            return x * 2

        # The decorator should silently degrade to the in-process cache.
        assert operation(5) == 10
        assert invocations == 1

        # A repeated call is answered locally, not re-executed.
        assert operation(5) == 10
        assert invocations == 1  # Cached locally
@pytest.mark.asyncio
async def test_redis_cache_with_complex_types():
    """Test Redis cache with complex data types."""

    @cached(ttl_seconds=30, shared_cache=True)
    async def get_complex_data(user_id: str, filters: dict) -> dict:
        return {
            "user_id": user_id,
            "filters": filters,
            "data": [1, 2, 3, 4, 5],
            "nested": {"key1": "value1", "key2": ["a", "b", "c"]},
        }

    query = {"status": "active", "limit": 10}
    first = await get_complex_data("user123", query)
    second = await get_complex_data("user123", query)

    # Both calls must yield the same (pickle-round-trippable) payload.
    assert first == second
    assert first["user_id"] == "user123"
    assert first["filters"]["status"] == "active"
def test_local_cache_without_shared():
    """Test that shared_cache=False uses local cache only."""
    invocations = 0

    @cached(maxsize=10, ttl_seconds=30, shared_cache=False)
    def local_only_operation(x: int) -> int:
        nonlocal invocations
        invocations += 1
        return x**2

    assert local_only_operation(4) == 16
    assert invocations == 1

    # The second identical call is answered from the in-process cache.
    assert local_only_operation(4) == 16
    assert invocations == 1

    # cache_info should report local-only settings.
    info = local_only_operation.cache_info()
    assert info["shared_cache"] is False
    assert info["maxsize"] == 10
if __name__ == "__main__":
    # Run the synchronous checks directly (the async tests need pytest-asyncio).
    _SUITE = [
        ("Testing sync Redis cache...", test_redis_cache_sync, "✓ Sync Redis cache test passed"),
        ("Testing local cache...", test_local_cache_without_shared, "✓ Local cache test passed"),
        (
            "Testing fallback...",
            test_fallback_to_local_cache_when_redis_unavailable,
            "✓ Fallback test passed",
        ),
    ]
    for intro, run_test, done in _SUITE:
        print(intro)
        run_test()
        print(done)
    print("\nAll tests passed!")

View File

@@ -70,6 +70,11 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.db.connect()
# Initialize Redis clients for cache system
from backend.util.cache_redis_init import initialize_cache_redis
initialize_cache_redis()
# Ensure SDK auto-registration is patched before initializing blocks
from backend.sdk.registry import AutoRegistry

View File

@@ -45,7 +45,8 @@ def get_cached_blocks() -> Sequence[dict]:
# Cache user's graphs list for 15 minutes
@cached(maxsize=1000, ttl_seconds=900)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=900, shared_cache=True)
async def get_cached_graphs(
user_id: str,
page: int,
@@ -60,7 +61,8 @@ async def get_cached_graphs(
# Cache individual graph details for 30 minutes
@cached(maxsize=500, ttl_seconds=1800)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph(
graph_id: str,
version: int | None,
@@ -76,7 +78,8 @@ async def get_cached_graph(
# Cache graph versions for 30 minutes
@cached(maxsize=500, ttl_seconds=1800)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph_all_versions(
graph_id: str,
user_id: str,
@@ -92,7 +95,8 @@ async def get_cached_graph_all_versions(
# Cache graph executions for 10 seconds.
@cached(maxsize=1000, ttl_seconds=10)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=10, shared_cache=True)
async def get_cached_graph_executions(
graph_id: str,
user_id: str,
@@ -109,6 +113,7 @@ async def get_cached_graph_executions(
# Cache all user executions for 10 seconds.
# No shared_cache - cache_delete not used for this function
@cached(maxsize=500, ttl_seconds=10)
async def get_cached_graphs_executions(
user_id: str,
@@ -124,6 +129,7 @@ async def get_cached_graphs_executions(
# Cache individual execution details for 10 seconds.
# No shared_cache - cache_delete not used for this function
@cached(maxsize=1000, ttl_seconds=10)
async def get_cached_graph_execution(
graph_exec_id: str,
@@ -141,7 +147,8 @@ async def get_cached_graph_execution(
# Cache user timezone for 1 hour
@cached(maxsize=1000, ttl_seconds=3600)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def get_cached_user_timezone(user_id: str):
"""Cached helper to get user timezone."""
user = await user_db.get_user_by_id(user_id)
@@ -149,7 +156,8 @@ async def get_cached_user_timezone(user_id: str):
# Cache user preferences for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_user_preferences(user_id: str):
"""Cached helper to get user notification preferences."""
return await user_db.get_user_notification_preference(user_id)

View File

@@ -13,7 +13,8 @@ import backend.server.v2.library.db
# Cache library agents list for 10 minutes
@cached(maxsize=1000, ttl_seconds=600)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=600, shared_cache=True)
async def get_cached_library_agents(
user_id: str,
page: int = 1,
@@ -28,7 +29,8 @@ async def get_cached_library_agents(
# Cache user's favorite agents for 5 minutes - favorites change more frequently
@cached(maxsize=500, ttl_seconds=300)
# No shared_cache - cache_delete not used for this function (based on search results)
@cached(maxsize=500, ttl_seconds=300, shared_cache=False)
async def get_cached_library_agent_favorites(
user_id: str,
page: int = 1,
@@ -43,7 +45,8 @@ async def get_cached_library_agent_favorites(
# Cache individual library agent details for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent(
library_agent_id: str,
user_id: str,
@@ -56,7 +59,7 @@ async def get_cached_library_agent(
# Cache library agent by graph ID for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800)
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent_by_graph_id(
graph_id: str,
user_id: str,
@@ -69,7 +72,7 @@ async def get_cached_library_agent_by_graph_id(
# Cache library agent by store version ID for 1 hour - marketplace agents are more stable
@cached(maxsize=500, ttl_seconds=3600)
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def get_cached_library_agent_by_store_version(
store_listing_version_id: str,
user_id: str,
@@ -85,7 +88,8 @@ async def get_cached_library_agent_by_store_version(
# Cache library presets list for 30 minutes
@cached(maxsize=500, ttl_seconds=1800)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_presets(
user_id: str,
page: int = 1,
@@ -100,7 +104,8 @@ async def get_cached_library_presets(
# Cache individual preset details for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_preset(
preset_id: str,
user_id: str,

View File

@@ -11,7 +11,8 @@ import backend.server.v2.store.db
# Cache user profiles for 1 hour per user
@cached(maxsize=1000, ttl_seconds=3600)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def _get_cached_user_profile(user_id: str):
"""Cached helper to get user profile."""
return await backend.server.v2.store.db.get_user_profile(user_id)
@@ -19,6 +20,7 @@ async def _get_cached_user_profile(user_id: str):
# Cache store agents list for 15 minutes
# Different cache entries for different query combinations
# No shared_cache - cache_delete not used for this function
@cached(maxsize=5000, ttl_seconds=900)
async def _get_cached_store_agents(
featured: bool,
@@ -42,6 +44,7 @@ async def _get_cached_store_agents(
# Cache individual agent details for 15 minutes
# No shared_cache - cache_delete not used for this function
@cached(maxsize=200, ttl_seconds=900)
async def _get_cached_agent_details(username: str, agent_name: str):
"""Cached helper to get agent details."""
@@ -51,6 +54,7 @@ async def _get_cached_agent_details(username: str, agent_name: str):
# Cache agent graphs for 1 hour
# No shared_cache - cache_delete not used for this function
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_agent_graph(store_listing_version_id: str):
"""Cached helper to get agent graph."""
@@ -60,6 +64,7 @@ async def _get_cached_agent_graph(store_listing_version_id: str):
# Cache agent by version for 1 hour
# No shared_cache - cache_delete not used for this function
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_store_agent_by_version(store_listing_version_id: str):
"""Cached helper to get store agent by version ID."""
@@ -69,6 +74,7 @@ async def _get_cached_store_agent_by_version(store_listing_version_id: str):
# Cache creators list for 1 hour
# No shared_cache - cache_delete not used for this function
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_store_creators(
featured: bool,
@@ -88,6 +94,7 @@ async def _get_cached_store_creators(
# Cache individual creator details for 1 hour
# No shared_cache - cache_delete not used for this function
@cached(maxsize=100, ttl_seconds=3600)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
@@ -97,6 +104,7 @@ async def _get_cached_creator_details(username: str):
# Cache user's own agents for 5 mins (shorter TTL as this changes more frequently)
# No shared_cache - cache_delete not used for this function
@cached(maxsize=500, ttl_seconds=300)
async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
"""Cached helper to get user's agents."""
@@ -106,7 +114,8 @@ async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
# Cache user's submissions for 1 hour (shorter TTL as this changes frequently)
@cached(maxsize=500, ttl_seconds=3600)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
"""Cached helper to get user's submissions."""
return await backend.server.v2.store.db.get_store_submissions(

View File

@@ -0,0 +1,76 @@
"""
Initialize Redis clients for the cache system.
This module bridges the gap between backend's Redis client and autogpt_libs cache system.
"""
import logging
from autogpt_libs.utils.cache import set_redis_client_provider
logger = logging.getLogger(__name__)
def initialize_cache_redis():
    """
    Initialize the cache system with backend's Redis clients.

    Wires the autogpt_libs cache decorator up to the backend's Redis
    connection settings so shared caches on every instance hit the same
    server. Failures are logged and leave the cache on its own fallback.
    """
    try:
        from backend.data.redis_client import HOST, PASSWORD, PORT

        def _sync_cache_client():
            """Get sync Redis client configured for cache (binary mode)."""
            try:
                from redis import Redis

                conn = Redis(
                    host=HOST,
                    port=PORT,
                    password=PASSWORD,
                    decode_responses=False,  # Binary mode for pickle
                )
                conn.ping()  # Test the connection
                return conn
            except Exception as e:
                logger.warning(f"Failed to get Redis client for cache: {e}")
                return None

        async def _async_cache_client():
            """Get async Redis client configured for cache (binary mode)."""
            try:
                from redis.asyncio import Redis as AsyncRedis

                conn = AsyncRedis(
                    host=HOST,
                    port=PORT,
                    password=PASSWORD,
                    decode_responses=False,  # Binary mode for pickle
                )
                await conn.ping()  # Test the connection
                return conn
            except Exception as e:
                logger.warning(f"Failed to get async Redis client for cache: {e}")
                return None

        # Hand both providers to the cache system in one call.
        set_redis_client_provider(
            sync_provider=_sync_cache_client,
            async_provider=_async_cache_client,
        )
        logger.info("Cache system initialized with backend Redis clients")
    except ImportError as e:
        logger.warning(f"Could not import Redis clients, cache will use fallback: {e}")
    except Exception as e:
        logger.error(f"Failed to initialize cache Redis: {e}")


# Auto-initialize when module is imported
# NOTE(review): import-time side effect; the app lifespan also calls
# initialize_cache_redis() explicitly, so this may run twice — presumably
# harmless (providers are simply re-registered), but confirm and consider
# keeping only one call site.
initialize_cache_redis()