Compare commits

...

12 Commits

Author SHA1 Message Date
Swifty
ae6ad35bf6 added redis caching 2025-09-25 16:45:40 +02:00
Swifty
1afebcf96b fix test 2025-09-25 15:34:47 +02:00
Swifty
bc5eb8a8a5 updated caching invalidation rules 2025-09-25 15:26:51 +02:00
Swifty
872ef5fdfb fmt 2025-09-25 13:31:18 +02:00
Swifty
12382e7990 fix duplicated caching code 2025-09-25 13:31:13 +02:00
Swifty
d68a3a1b53 fixed caching when saving an agent 2025-09-25 12:57:52 +02:00
Swifty
863e213af3 fixed caching when adding agent from library 2025-09-25 12:57:40 +02:00
Swifty
c61af53a74 Merge remote-tracking branch 'origin/dev' into swiftyos/caching-pt2 2025-09-25 12:04:07 +02:00
Swifty
eb94503de8 fix(backend): prevent caching of None graph results to handle dynamic permissions
When a graph is not found/accessible, we now clear the cache entry rather than
caching the None result. This prevents issues with store listing permissions
where a graph becomes accessible after approval but the cache still returns
the old 'not found' result.
2025-09-24 18:08:13 +02:00
Swifty
ee4feff8c2 fix(backend): include subgraphs in cached graph retrieval
The cached graph function was missing include_subgraphs=True parameter which
is needed to construct full credentials input schema. This was causing
test_access_store_listing_graph to fail.
2025-09-24 16:51:20 +02:00
Swifty
9147c2d6c8 fix(backend): update library favorites test to mock cache function instead of db
The test was failing because routes now use cached functions. Updated the mock
to patch the cache function which is what the route actually calls.
2025-09-24 16:38:07 +02:00
Swifty
a3af430c69 feat(backend): implement comprehensive caching layer for all GET endpoints (Part 2)
- Created separate cache.py modules for better code organization
  - backend/server/routers/cache.py for V1 API endpoints
  - backend/server/v2/library/cache.py for library endpoints
  - backend/server/v2/store/cache.py (refactored from routes)

- Added caching to all major GET endpoints:
  - Graphs list/details with 15-30 min TTL
  - Graph executions with 5 min TTL
  - User preferences/timezone with 30-60 min TTL
  - Library agents/favorites/presets with 10-30 min TTL
  - Store listings/profiles with 5-60 min TTL

- Implemented intelligent cache invalidation:
  - Clears relevant caches on CREATE/UPDATE/DELETE operations
  - Uses positional arguments for cache_delete to match function calls
  - Selective caching only for default queries (bypasses cache for filtered/searched results)

- Added comprehensive test coverage:
  - 20 cache-specific tests all passing
  - Validates cache hit/miss behavior
  - Verifies invalidation on mutations

- Performance improvements:
  - Reduces database load for frequently accessed data
  - Built-in thundering herd protection via @cached decorator
  - Configurable TTLs based on data volatility
2025-09-24 16:20:19 +02:00
15 changed files with 1817 additions and 245 deletions

View File

@@ -1,6 +1,8 @@
import asyncio
import inspect
import logging
import os
import pickle
import threading
import time
from functools import wraps
@@ -20,6 +22,92 @@ R_co = TypeVar("R_co", covariant=True)
logger = logging.getLogger(__name__)
# Redis client providers (can be set externally)
_redis_client_provider = None
_async_redis_client_provider = None
def set_redis_client_provider(sync_provider=None, async_provider=None):
"""
Set external Redis client providers.
This allows the backend to inject its Redis clients into the cache system.
Args:
sync_provider: A callable that returns a sync Redis client
async_provider: A callable that returns an async Redis client
"""
global _redis_client_provider, _async_redis_client_provider
if sync_provider:
_redis_client_provider = sync_provider
if async_provider:
_async_redis_client_provider = async_provider
def _get_redis_client():
"""Get Redis client from provider or create a default one."""
if _redis_client_provider:
try:
client = _redis_client_provider()
if client:
return client
except Exception as e:
logger.warning(f"Failed to get Redis client from provider: {e}")
# Fallback to creating our own client
try:
from redis import Redis
client = Redis(
host=os.getenv("REDIS_HOST", "localhost"),
port=int(os.getenv("REDIS_PORT", "6379")),
password=os.getenv("REDIS_PASSWORD", None),
decode_responses=False, # We'll use pickle for serialization
)
client.ping()
return client
except Exception as e:
logger.warning(f"Failed to connect to Redis: {e}")
return None
async def _get_async_redis_client():
"""Get async Redis client from provider or create a default one."""
if _async_redis_client_provider:
try:
# Provider is an async function, we need to await it
if inspect.iscoroutinefunction(_async_redis_client_provider):
client = await _async_redis_client_provider()
else:
client = _async_redis_client_provider()
if client:
return client
except Exception as e:
logger.warning(f"Failed to get async Redis client from provider: {e}")
# Fallback to creating our own client
try:
from redis.asyncio import Redis as AsyncRedis
client = AsyncRedis(
host=os.getenv("REDIS_HOST", "localhost"),
port=int(os.getenv("REDIS_PORT", "6379")),
password=os.getenv("REDIS_PASSWORD", None),
decode_responses=False, # We'll use pickle for serialization
)
return client
except Exception as e:
logger.warning(f"Failed to create async Redis client: {e}")
return None
def _make_redis_key(func_name: str, key: tuple) -> str:
"""Create a Redis key from function name and cache key."""
# Convert the key to a string representation
key_str = str(key)
# Add a prefix to avoid collisions with other Redis usage
return f"cache:{func_name}:{key_str}"
def _make_hashable_key(
args: tuple[Any, ...], kwargs: dict[str, Any]
@@ -85,6 +173,7 @@ def cached(
*,
maxsize: int = 128,
ttl_seconds: int | None = None,
shared_cache: bool = False,
) -> Callable[[Callable], CachedFunction]:
"""
Thundering herd safe cache decorator for both sync and async functions.
@@ -92,10 +181,13 @@ def cached(
Uses double-checked locking to prevent multiple threads/coroutines from
executing the expensive operation simultaneously during cache misses.
When shared_cache=True, uses Redis for distributed caching across instances.
Args:
func: The function to cache (when used without parentheses)
maxsize: Maximum number of cached entries
maxsize: Maximum number of cached entries (ignored when using Redis)
ttl_seconds: Time to live in seconds. If None, entries never expire
shared_cache: If True, use Redis for distributed caching
Returns:
Decorated function or decorator
@@ -127,50 +219,100 @@ def cached(
key = _make_hashable_key(args, kwargs)
current_time = time.time()
# Fast path: check cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.debug(f"Cache hit for {target_func.__name__}")
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(f"Cache hit for {target_func.__name__}")
# Try Redis first if shared_cache is enabled
if shared_cache:
redis_client = await _get_async_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
# Check Redis cache
await redis_client.ping() # Ensure connection is alive
cached_value = await redis_client.get(redis_key)
if cached_value:
result = pickle.loads(cast(bytes, cached_value))
logger.info(
f"Redis cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
)
return result
# Slow path: acquire lock for cache miss/expiry
async with cache_lock:
# Double-check: another coroutine might have populated cache
except Exception as e:
logger.warning(f"Redis cache read failed: {e}")
# Fall through to execute function
else:
# Fast path: check local cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.info(
f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
)
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.info(
f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
)
return result
# Slow path: acquire lock for cache miss/expiry
async with cache_lock:
# Double-check: another coroutine might have populated cache
if shared_cache:
redis_client = await _get_async_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
cached_value = await redis_client.get(redis_key)
if cached_value:
return pickle.loads(cast(bytes, cached_value))
except Exception as e:
logger.warning(f"Redis cache read failed in lock: {e}")
else:
if key in cache_storage:
if ttl_seconds is None:
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
logger.info(f"Cache miss for {target_func.__name__}")
result = await target_func(*args, **kwargs)
# Store result
if ttl_seconds is None:
cache_storage[key] = result
if shared_cache:
redis_client = await _get_async_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
serialized = pickle.dumps(result)
await redis_client.set(
redis_key,
serialized,
ex=ttl_seconds if ttl_seconds else None,
)
except Exception as e:
logger.warning(f"Redis cache write failed: {e}")
else:
cache_storage[key] = (result, current_time)
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff]
if cutoff > 0
else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
return result
@@ -185,50 +327,99 @@ def cached(
key = _make_hashable_key(args, kwargs)
current_time = time.time()
# Fast path: check cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.debug(f"Cache hit for {target_func.__name__}")
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(f"Cache hit for {target_func.__name__}")
# Try Redis first if shared_cache is enabled
if shared_cache:
redis_client = _get_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
# Check Redis cache
cached_value = redis_client.get(redis_key)
if cached_value:
result = pickle.loads(cast(bytes, cached_value))
logger.info(
f"Redis cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
)
return result
# Slow path: acquire lock for cache miss/expiry
with cache_lock:
# Double-check: another thread might have populated cache
except Exception as e:
logger.warning(f"Redis cache read failed: {e}")
# Fall through to execute function
else:
# Fast path: check local cache without lock
if key in cache_storage:
if ttl_seconds is None:
logger.info(
f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
)
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.info(
f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
)
return result
# Slow path: acquire lock for cache miss/expiry
with cache_lock:
# Double-check: another thread might have populated cache
if shared_cache:
redis_client = _get_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
cached_value = redis_client.get(redis_key)
if cached_value:
return pickle.loads(cast(bytes, cached_value))
except Exception as e:
logger.warning(f"Redis cache read failed in lock: {e}")
else:
if key in cache_storage:
if ttl_seconds is None:
return cache_storage[key]
else:
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
logger.info(f"Cache miss for {target_func.__name__}")
result = target_func(*args, **kwargs)
# Store result
if ttl_seconds is None:
cache_storage[key] = result
if shared_cache:
redis_client = _get_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
serialized = pickle.dumps(result)
redis_client.set(
redis_key,
serialized,
ex=ttl_seconds if ttl_seconds else None,
)
except Exception as e:
logger.warning(f"Redis cache write failed: {e}")
else:
cache_storage[key] = (result, current_time)
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
# Cleanup if needed
if len(cache_storage) > maxsize:
cutoff = maxsize // 2
oldest_keys = (
list(cache_storage.keys())[:-cutoff]
if cutoff > 0
else []
)
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
return result
@@ -236,22 +427,105 @@ def cached(
# Add cache management methods
def cache_clear() -> None:
cache_storage.clear()
"""Clear all cached entries."""
if shared_cache:
redis_client = (
_get_redis_client()
if not inspect.iscoroutinefunction(target_func)
else None
)
if redis_client:
try:
# Clear all cache entries for this function
pattern = f"cache:{target_func.__name__}:*"
for key in redis_client.scan_iter(match=pattern):
redis_client.delete(key)
except Exception as e:
logger.warning(f"Redis cache clear failed: {e}")
else:
cache_storage.clear()
def cache_info() -> dict[str, int | None]:
def cache_info() -> dict[str, Any]:
"""Get cache statistics."""
if shared_cache:
redis_client = (
_get_redis_client()
if not inspect.iscoroutinefunction(target_func)
else None
)
if redis_client:
try:
pattern = f"cache:{target_func.__name__}:*"
count = sum(1 for _ in redis_client.scan_iter(match=pattern))
return {
"size": count,
"maxsize": None, # Not applicable for Redis
"ttl_seconds": ttl_seconds,
"shared_cache": True,
}
except Exception as e:
logger.warning(f"Redis cache info failed: {e}")
return {
"size": 0,
"maxsize": None,
"ttl_seconds": ttl_seconds,
"shared_cache": True,
"error": str(e),
}
return {
"size": len(cache_storage),
"maxsize": maxsize,
"ttl_seconds": ttl_seconds,
"shared_cache": False,
}
def cache_delete(*args, **kwargs) -> bool:
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if key in cache_storage:
del cache_storage[key]
return True
return False
# Create appropriate cache_delete based on whether function is async
if inspect.iscoroutinefunction(target_func):
async def async_cache_delete(*args, **kwargs) -> bool:
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if shared_cache:
redis_client = await _get_async_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
result = await redis_client.delete(redis_key)
return cast(int, result) > 0
except Exception as e:
logger.warning(f"Redis cache delete failed: {e}")
return False
else:
if key in cache_storage:
del cache_storage[key]
return True
return False
cache_delete = async_cache_delete
else:
def sync_cache_delete(*args, **kwargs) -> bool:
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if shared_cache:
redis_client = _get_redis_client()
if redis_client:
redis_key = _make_redis_key(target_func.__name__, key)
try:
result = redis_client.delete(redis_key)
return cast(int, result) > 0
except Exception as e:
logger.warning(f"Redis cache delete failed: {e}")
return False
else:
if key in cache_storage:
del cache_storage[key]
return True
return False
cache_delete = sync_cache_delete
setattr(wrapper, "cache_clear", cache_clear)
setattr(wrapper, "cache_info", cache_info)

View File

@@ -0,0 +1,172 @@
"""
Test Redis cache functionality.
"""
import asyncio
import time
from unittest.mock import patch
import pytest
from autogpt_libs.utils.cache import cached
# Test with Redis cache enabled
@pytest.mark.asyncio
async def test_redis_cache_async():
"""Test async function with Redis cache."""
call_count = 0
@cached(ttl_seconds=60, shared_cache=True)
async def expensive_async_operation(x: int, y: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.1)
return x + y
# First call should execute function
result1 = await expensive_async_operation(5, 3)
assert result1 == 8
assert call_count == 1
# Second call with same args should use cache
result2 = await expensive_async_operation(5, 3)
assert result2 == 8
assert call_count == 1 # Should not increment
# Different args should execute function again
result3 = await expensive_async_operation(10, 5)
assert result3 == 15
assert call_count == 2
# Test cache_delete
deleted = expensive_async_operation.cache_delete(5, 3)
assert isinstance(deleted, bool) # Depends on Redis availability
# Test cache_clear
expensive_async_operation.cache_clear()
# Test cache_info
info = expensive_async_operation.cache_info()
assert "ttl_seconds" in info
assert info["ttl_seconds"] == 60
assert "shared_cache" in info
def test_redis_cache_sync():
"""Test sync function with Redis cache."""
call_count = 0
@cached(ttl_seconds=60, shared_cache=True)
def expensive_sync_operation(x: int, y: int) -> int:
nonlocal call_count
call_count += 1
time.sleep(0.1)
return x * y
# First call should execute function
result1 = expensive_sync_operation(5, 3)
assert result1 == 15
assert call_count == 1
# Second call with same args should use cache
result2 = expensive_sync_operation(5, 3)
assert result2 == 15
assert call_count == 1 # Should not increment
# Different args should execute function again
result3 = expensive_sync_operation(10, 5)
assert result3 == 50
assert call_count == 2
# Test cache management functions
expensive_sync_operation.cache_clear()
info = expensive_sync_operation.cache_info()
assert info["shared_cache"] is True
def test_fallback_to_local_cache_when_redis_unavailable():
"""Test that cache falls back to local when Redis is unavailable."""
with patch("autogpt_libs.utils.cache._get_redis_client", return_value=None):
call_count = 0
@cached(ttl_seconds=60, shared_cache=True)
def operation(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 2
# Should still work with local cache
result1 = operation(5)
assert result1 == 10
assert call_count == 1
result2 = operation(5)
assert result2 == 10
assert call_count == 1 # Cached locally
@pytest.mark.asyncio
async def test_redis_cache_with_complex_types():
"""Test Redis cache with complex data types."""
@cached(ttl_seconds=30, shared_cache=True)
async def get_complex_data(user_id: str, filters: dict) -> dict:
return {
"user_id": user_id,
"filters": filters,
"data": [1, 2, 3, 4, 5],
"nested": {"key1": "value1", "key2": ["a", "b", "c"]},
}
result1 = await get_complex_data("user123", {"status": "active", "limit": 10})
result2 = await get_complex_data("user123", {"status": "active", "limit": 10})
assert result1 == result2
assert result1["user_id"] == "user123"
assert result1["filters"]["status"] == "active"
def test_local_cache_without_shared():
"""Test that shared_cache=False uses local cache only."""
call_count = 0
@cached(maxsize=10, ttl_seconds=30, shared_cache=False)
def local_only_operation(x: int) -> int:
nonlocal call_count
call_count += 1
return x**2
result1 = local_only_operation(4)
assert result1 == 16
assert call_count == 1
result2 = local_only_operation(4)
assert result2 == 16
assert call_count == 1
# Check cache info shows it's not shared
info = local_only_operation.cache_info()
assert info["shared_cache"] is False
assert info["maxsize"] == 10
if __name__ == "__main__":
# Run basic tests
print("Testing sync Redis cache...")
test_redis_cache_sync()
print("✓ Sync Redis cache test passed")
print("Testing local cache...")
test_local_cache_without_shared()
print("✓ Local cache test passed")
print("Testing fallback...")
test_fallback_to_local_cache_when_redis_unavailable()
print("✓ Fallback test passed")
print("\nAll tests passed!")

View File

@@ -70,6 +70,11 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.db.connect()
# Initialize Redis clients for cache system
from backend.util.cache_redis_init import initialize_cache_redis
initialize_cache_redis()
# Ensure SDK auto-registration is patched before initializing blocks
from backend.sdk.registry import AutoRegistry

View File

@@ -0,0 +1,163 @@
"""
Cache functions for main V1 API endpoints.
This module contains all caching decorators and helpers for the V1 API,
separated from the main routes for better organization and maintainability.
"""
from typing import Sequence
from autogpt_libs.utils.cache import cached
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
from backend.data.block import get_blocks
# ===== Block Caches =====
# Cache block definitions with costs - they rarely change
@cached(maxsize=1, ttl_seconds=3600)
def get_cached_blocks() -> Sequence[dict]:
"""
Get cached blocks with thundering herd protection.
Uses cached decorator to prevent multiple concurrent requests
from all executing the expensive block loading operation.
"""
from backend.data.credit import get_block_cost
block_classes = get_blocks()
result = []
for block_class in block_classes.values():
block_instance = block_class()
if not block_instance.disabled:
# Get costs for this specific block class without creating another instance
costs = get_block_cost(block_instance)
result.append({**block_instance.to_dict(), "costs": costs})
return result
# ===== Graph Caches =====
# Cache user's graphs list for 15 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=900, shared_cache=True)
async def get_cached_graphs(
user_id: str,
page: int,
page_size: int,
):
"""Cached helper to get user's graphs."""
return await graph_db.list_graphs_paginated(
user_id=user_id,
page=page,
page_size=page_size,
)
# Cache individual graph details for 30 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph(
graph_id: str,
version: int | None,
user_id: str,
):
"""Cached helper to get graph details."""
return await graph_db.get_graph(
graph_id=graph_id,
version=version,
user_id=user_id,
include_subgraphs=True, # needed to construct full credentials input schema
)
# Cache graph versions for 30 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph_all_versions(
graph_id: str,
user_id: str,
) -> Sequence[graph_db.GraphModel]:
"""Cached helper to get all versions of a graph."""
return await graph_db.get_graph_all_versions(
graph_id=graph_id,
user_id=user_id,
)
# ===== Execution Caches =====
# Cache graph executions for 10 seconds.
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=10, shared_cache=True)
async def get_cached_graph_executions(
graph_id: str,
user_id: str,
page: int,
page_size: int,
):
"""Cached helper to get graph executions."""
return await execution_db.get_graph_executions_paginated(
graph_id=graph_id,
user_id=user_id,
page=page,
page_size=page_size,
)
# Cache all user executions for 10 seconds.
# No shared_cache - cache_delete not used for this function
@cached(maxsize=500, ttl_seconds=10)
async def get_cached_graphs_executions(
user_id: str,
page: int,
page_size: int,
):
"""Cached helper to get all user's graph executions."""
return await execution_db.get_graph_executions_paginated(
user_id=user_id,
page=page,
page_size=page_size,
)
# Cache individual execution details for 10 seconds.
# No shared_cache - cache_delete not used for this function
@cached(maxsize=1000, ttl_seconds=10)
async def get_cached_graph_execution(
graph_exec_id: str,
user_id: str,
):
"""Cached helper to get graph execution details."""
return await execution_db.get_graph_execution(
user_id=user_id,
execution_id=graph_exec_id,
include_node_executions=False,
)
# ===== User Preference Caches =====
# Cache user timezone for 1 hour
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def get_cached_user_timezone(user_id: str):
"""Cached helper to get user timezone."""
user = await user_db.get_user_by_id(user_id)
return {"timezone": user.timezone if user else "UTC"}
# Cache user preferences for 30 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_user_preferences(user_id: str):
"""Cached helper to get user notification preferences."""
return await user_db.get_user_notification_preference(user_id)

View File

@@ -0,0 +1,346 @@
"""
Tests for cache invalidation in V1 API routes.
This module tests that caches are properly invalidated when data is modified
through POST, PUT, PATCH, and DELETE operations.
"""
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import backend.server.routers.cache as cache
from backend.data import graph as graph_db
@pytest.fixture
def mock_user_id():
"""Generate a mock user ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def mock_graph_id():
"""Generate a mock graph ID for testing."""
return str(uuid.uuid4())
class TestGraphCacheInvalidation:
"""Test cache invalidation for graph operations."""
@pytest.mark.asyncio
async def test_create_graph_clears_list_cache(self, mock_user_id):
"""Test that creating a graph clears the graphs list cache."""
# Setup
cache.get_cached_graphs.cache_clear()
# Pre-populate cache
with patch.object(
graph_db, "list_graphs_paginated", new_callable=AsyncMock
) as mock_list:
mock_list.return_value = MagicMock(graphs=[])
# First call should hit the database
await cache.get_cached_graphs(mock_user_id, 1, 250)
assert mock_list.call_count == 1
# Second call should use cache
await cache.get_cached_graphs(mock_user_id, 1, 250)
assert mock_list.call_count == 1 # Still 1, used cache
# Simulate cache invalidation (what happens in create_new_graph)
cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)
# Next call should hit database again
await cache.get_cached_graphs(mock_user_id, 1, 250)
assert mock_list.call_count == 2 # Incremented, cache was cleared
@pytest.mark.asyncio
async def test_delete_graph_clears_multiple_caches(
self, mock_user_id, mock_graph_id
):
"""Test that deleting a graph clears all related caches."""
# Clear all caches first
cache.get_cached_graphs.cache_clear()
cache.get_cached_graph.cache_clear()
cache.get_cached_graph_all_versions.cache_clear()
cache.get_cached_graph_executions.cache_clear()
# Setup mocks
with patch.object(
graph_db, "list_graphs_paginated", new_callable=AsyncMock
) as mock_list, patch.object(
graph_db, "get_graph", new_callable=AsyncMock
) as mock_get, patch.object(
graph_db, "get_graph_all_versions", new_callable=AsyncMock
) as mock_versions:
mock_list.return_value = MagicMock(graphs=[])
mock_get.return_value = MagicMock(id=mock_graph_id)
mock_versions.return_value = []
# Pre-populate all caches (use consistent argument style)
await cache.get_cached_graphs(mock_user_id, 1, 250)
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
initial_calls = {
"list": mock_list.call_count,
"get": mock_get.call_count,
"versions": mock_versions.call_count,
}
# Use cached values (no additional DB calls)
await cache.get_cached_graphs(mock_user_id, 1, 250)
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
# Verify cache was used
assert mock_list.call_count == initial_calls["list"]
assert mock_get.call_count == initial_calls["get"]
assert mock_versions.call_count == initial_calls["versions"]
# Simulate delete_graph cache invalidation
# Use positional arguments for cache_delete to match how we called the functions
result1 = cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)
result2 = cache.get_cached_graph.cache_delete(
mock_graph_id, None, mock_user_id
)
result3 = cache.get_cached_graph_all_versions.cache_delete(
mock_graph_id, mock_user_id
)
# Verify that the cache entries were actually deleted
assert result1, "Failed to delete graphs cache entry"
assert result2, "Failed to delete graph cache entry"
assert result3, "Failed to delete graph versions cache entry"
# Next calls should hit database
await cache.get_cached_graphs(mock_user_id, 1, 250)
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
# Verify database was called again
assert mock_list.call_count == initial_calls["list"] + 1
assert mock_get.call_count == initial_calls["get"] + 1
assert mock_versions.call_count == initial_calls["versions"] + 1
@pytest.mark.asyncio
async def test_update_graph_clears_caches(self, mock_user_id, mock_graph_id):
"""Test that updating a graph clears the appropriate caches."""
# Clear caches
cache.get_cached_graph.cache_clear()
cache.get_cached_graph_all_versions.cache_clear()
cache.get_cached_graphs.cache_clear()
with patch.object(
graph_db, "get_graph", new_callable=AsyncMock
) as mock_get, patch.object(
graph_db, "get_graph_all_versions", new_callable=AsyncMock
) as mock_versions, patch.object(
graph_db, "list_graphs_paginated", new_callable=AsyncMock
) as mock_list:
mock_get.return_value = MagicMock(id=mock_graph_id, version=1)
mock_versions.return_value = [MagicMock(version=1)]
mock_list.return_value = MagicMock(graphs=[])
# Populate caches
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
await cache.get_cached_graphs(mock_user_id, 1, 250)
initial_calls = {
"get": mock_get.call_count,
"versions": mock_versions.call_count,
"list": mock_list.call_count,
}
# Verify cache is being used
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
await cache.get_cached_graphs(mock_user_id, 1, 250)
assert mock_get.call_count == initial_calls["get"]
assert mock_versions.call_count == initial_calls["versions"]
assert mock_list.call_count == initial_calls["list"]
# Simulate update_graph cache invalidation
cache.get_cached_graph.cache_delete(mock_graph_id, None, mock_user_id)
cache.get_cached_graph_all_versions.cache_delete(
mock_graph_id, mock_user_id
)
cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)
# Next calls should hit database
await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
await cache.get_cached_graphs(mock_user_id, 1, 250)
assert mock_get.call_count == initial_calls["get"] + 1
assert mock_versions.call_count == initial_calls["versions"] + 1
assert mock_list.call_count == initial_calls["list"] + 1
class TestUserPreferencesCacheInvalidation:
"""Test cache invalidation for user preferences operations."""
@pytest.mark.asyncio
async def test_update_preferences_clears_cache(self, mock_user_id):
"""Test that updating preferences clears the preferences cache."""
# Clear cache
cache.get_cached_user_preferences.cache_clear()
with patch.object(
cache.user_db, "get_user_notification_preference", new_callable=AsyncMock
) as mock_get_prefs:
mock_prefs = MagicMock(email_notifications=True, push_notifications=False)
mock_get_prefs.return_value = mock_prefs
# First call hits database
result1 = await cache.get_cached_user_preferences(mock_user_id)
assert mock_get_prefs.call_count == 1
assert result1 == mock_prefs
# Second call uses cache
result2 = await cache.get_cached_user_preferences(mock_user_id)
assert mock_get_prefs.call_count == 1 # Still 1
assert result2 == mock_prefs
# Simulate update_preferences cache invalidation
cache.get_cached_user_preferences.cache_delete(mock_user_id)
# Change the mock return value to simulate updated preferences
mock_prefs_updated = MagicMock(
email_notifications=False, push_notifications=True
)
mock_get_prefs.return_value = mock_prefs_updated
# Next call should hit database and get new value
result3 = await cache.get_cached_user_preferences(mock_user_id)
assert mock_get_prefs.call_count == 2
assert result3 == mock_prefs_updated
@pytest.mark.asyncio
async def test_timezone_cache_operations(self, mock_user_id):
"""Test timezone cache and its operations."""
# Clear cache
cache.get_cached_user_timezone.cache_clear()
with patch.object(
cache.user_db, "get_user_by_id", new_callable=AsyncMock
) as mock_get_user:
mock_user = MagicMock(timezone="America/New_York")
mock_get_user.return_value = mock_user
# First call hits database
result1 = await cache.get_cached_user_timezone(mock_user_id)
assert mock_get_user.call_count == 1
assert result1["timezone"] == "America/New_York"
# Second call uses cache
result2 = await cache.get_cached_user_timezone(mock_user_id)
assert mock_get_user.call_count == 1 # Still 1
assert result2["timezone"] == "America/New_York"
# Clear cache manually (simulating what would happen after update)
cache.get_cached_user_timezone.cache_delete(mock_user_id)
# Change timezone
mock_user.timezone = "Europe/London"
# Next call should hit database
result3 = await cache.get_cached_user_timezone(mock_user_id)
assert mock_get_user.call_count == 2
assert result3["timezone"] == "Europe/London"
class TestExecutionCacheInvalidation:
"""Test cache invalidation for execution operations."""
@pytest.mark.asyncio
async def test_execution_cache_cleared_on_graph_delete(
self, mock_user_id, mock_graph_id
):
"""Test that execution caches are cleared when a graph is deleted."""
# Clear cache
cache.get_cached_graph_executions.cache_clear()
with patch.object(
cache.execution_db, "get_graph_executions_paginated", new_callable=AsyncMock
) as mock_exec:
mock_exec.return_value = MagicMock(executions=[])
# Populate cache for multiple pages
for page in range(1, 4):
await cache.get_cached_graph_executions(
mock_graph_id, mock_user_id, page, 25
)
initial_calls = mock_exec.call_count
# Verify cache is used
for page in range(1, 4):
await cache.get_cached_graph_executions(
mock_graph_id, mock_user_id, page, 25
)
assert mock_exec.call_count == initial_calls # No new calls
# Simulate graph deletion clearing execution caches
for page in range(1, 10): # Clear more pages as done in delete_graph
cache.get_cached_graph_executions.cache_delete(
mock_graph_id, mock_user_id, page, 25
)
# Next calls should hit database
for page in range(1, 4):
await cache.get_cached_graph_executions(
mock_graph_id, mock_user_id, page, 25
)
assert mock_exec.call_count == initial_calls + 3 # 3 new calls
class TestCacheInfo:
"""Test cache information and metrics."""
def test_cache_info_returns_correct_metrics(self):
"""Test that cache_info returns correct metrics."""
# Clear all caches
cache.get_cached_graphs.cache_clear()
cache.get_cached_graph.cache_clear()
# Get initial info
info_graphs = cache.get_cached_graphs.cache_info()
info_graph = cache.get_cached_graph.cache_info()
assert info_graphs["size"] == 0
assert info_graph["size"] == 0
# Note: We can't directly test cache population without real async context,
# but we can verify the cache_info structure
assert "size" in info_graphs
assert "maxsize" in info_graphs
assert "ttl_seconds" in info_graphs
def test_cache_clear_removes_all_entries(self):
"""Test that cache_clear removes all entries."""
# This test verifies the cache_clear method exists and can be called
cache.get_cached_graphs.cache_clear()
cache.get_cached_graph.cache_clear()
cache.get_cached_graph_all_versions.cache_clear()
cache.get_cached_graph_executions.cache_clear()
cache.get_cached_graphs_executions.cache_clear()
cache.get_cached_user_preferences.cache_clear()
cache.get_cached_user_timezone.cache_clear()
# After clear, all caches should be empty
assert cache.get_cached_graphs.cache_info()["size"] == 0
assert cache.get_cached_graph.cache_info()["size"] == 0
assert cache.get_cached_graph_all_versions.cache_info()["size"] == 0
assert cache.get_cached_graph_executions.cache_info()["size"] == 0
assert cache.get_cached_graphs_executions.cache_info()["size"] == 0
assert cache.get_cached_user_preferences.cache_info()["size"] == 0
assert cache.get_cached_user_timezone.cache_info()["size"] == 0

View File

@@ -11,7 +11,6 @@ import pydantic
import stripe
from autogpt_libs.auth import get_user_id, requires_user
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from autogpt_libs.utils.cache import cached
from fastapi import (
APIRouter,
Body,
@@ -29,11 +28,12 @@ from typing_extensions import Optional, TypedDict
import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.routers.cache as cache
import backend.server.v2.library.db as library_db
from backend.data import api_key as api_key_db
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.block import BlockInput, CompletedBlockOutput, get_block
from backend.data.credit import (
AutoTopUpConfig,
RefundRequest,
@@ -55,7 +55,6 @@ from backend.data.onboarding import (
from backend.data.user import (
get_or_create_user,
get_user_by_id,
get_user_notification_preference,
update_user_email,
update_user_notification_preference,
update_user_timezone,
@@ -82,6 +81,7 @@ from backend.server.model import (
UpdateTimezoneRequest,
UploadFileResponse,
)
from backend.server.v2.library import cache as library_cache
from backend.util.clients import get_scheduler_client
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.exceptions import GraphValidationError, NotFoundError
@@ -165,7 +165,9 @@ async def get_user_timezone_route(
) -> TimezoneResponse:
"""Get user timezone setting."""
user = await get_or_create_user(user_data)
return TimezoneResponse(timezone=user.timezone)
# Use cached timezone for subsequent calls
result = await cache.get_cached_user_timezone(user.id)
return TimezoneResponse(timezone=result["timezone"])
@v1_router.post(
@@ -179,6 +181,7 @@ async def update_user_timezone_route(
) -> TimezoneResponse:
"""Update user timezone. The timezone should be a valid IANA timezone identifier."""
user = await update_user_timezone(user_id, str(request.timezone))
cache.get_cached_user_timezone.cache_delete(user_id)
return TimezoneResponse(timezone=user.timezone)
@@ -191,7 +194,7 @@ async def update_user_timezone_route(
async def get_preferences(
user_id: Annotated[str, Security(get_user_id)],
) -> NotificationPreference:
preferences = await get_user_notification_preference(user_id)
preferences = await cache.get_cached_user_preferences(user_id)
return preferences
@@ -206,6 +209,10 @@ async def update_preferences(
preferences: NotificationPreferenceDTO = Body(...),
) -> NotificationPreference:
output = await update_user_notification_preference(user_id, preferences)
# Clear preferences cache after update
cache.get_cached_user_preferences.cache_delete(user_id)
return output
@@ -263,29 +270,6 @@ async def is_onboarding_enabled():
########################################################
@cached()
def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
"""
Get cached blocks with thundering herd protection.
Uses sync_cache decorator to prevent multiple concurrent requests
from all executing the expensive block loading operation.
"""
from backend.data.credit import get_block_cost
block_classes = get_blocks()
result = []
for block_class in block_classes.values():
block_instance = block_class()
if not block_instance.disabled:
# Get costs for this specific block class without creating another instance
costs = get_block_cost(block_instance)
result.append({**block_instance.to_dict(), "costs": costs})
return result
@v1_router.get(
path="/blocks",
summary="List available blocks",
@@ -293,7 +277,7 @@ def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
dependencies=[Security(requires_user)],
)
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
return _get_cached_blocks()
return cache.get_cached_blocks()
@v1_router.post(
@@ -633,11 +617,10 @@ class DeleteGraphResponse(TypedDict):
async def list_graphs(
user_id: Annotated[str, Security(get_user_id)],
) -> Sequence[graph_db.GraphMeta]:
paginated_result = await graph_db.list_graphs_paginated(
paginated_result = await cache.get_cached_graphs(
user_id=user_id,
page=1,
page_size=250,
filter_by="active",
)
return paginated_result.graphs
@@ -660,13 +643,26 @@ async def get_graph(
version: int | None = None,
for_export: bool = False,
) -> graph_db.GraphModel:
graph = await graph_db.get_graph(
graph_id,
version,
user_id=user_id,
for_export=for_export,
include_subgraphs=True, # needed to construct full credentials input schema
)
# Use cache for non-export requests
if not for_export:
graph = await cache.get_cached_graph(
graph_id=graph_id,
version=version,
user_id=user_id,
)
# If graph not found, clear cache entry as permissions may have changed
if not graph:
cache.get_cached_graph.cache_delete(
graph_id=graph_id, version=version, user_id=user_id
)
else:
graph = await graph_db.get_graph(
graph_id,
version,
user_id=user_id,
for_export=for_export,
include_subgraphs=True, # needed to construct full credentials input schema
)
if not graph:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graph
@@ -681,7 +677,7 @@ async def get_graph(
async def get_graph_all_versions(
graph_id: str, user_id: Annotated[str, Security(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
graphs = await cache.get_cached_graph_all_versions(graph_id, user_id=user_id)
if not graphs:
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
return graphs
@@ -705,6 +701,14 @@ async def create_new_graph(
# as the graph already valid and no sub-graphs are returned back.
await graph_db.create_graph(graph, user_id=user_id)
await library_db.create_library_agent(graph, user_id=user_id)
# Clear graphs list cache after creating new graph
cache.get_cached_graphs.cache_delete(user_id=user_id, page=1, page_size=250)
for page in range(1, 20):
library_cache.get_cached_library_agents.cache_delete(
user_id=user_id, page=page, page_size=8
)
return await on_graph_activate(graph, user_id=user_id)
@@ -720,7 +724,18 @@ async def delete_graph(
if active_version := await graph_db.get_graph(graph_id, user_id=user_id):
await on_graph_deactivate(active_version, user_id=user_id)
return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}
result = DeleteGraphResponse(
version_counts=await graph_db.delete_graph(graph_id, user_id=user_id)
)
# Clear caches after deleting graph
cache.get_cached_graphs.cache_delete(user_id=user_id, page=1, page_size=250)
cache.get_cached_graph.cache_delete(
graph_id=graph_id, version=None, user_id=user_id
)
cache.get_cached_graph_all_versions.cache_delete(graph_id, user_id=user_id)
return result
@v1_router.put(
@@ -776,6 +791,14 @@ async def update_graph(
include_subgraphs=True,
)
assert new_graph_version_with_subgraphs # make type checker happy
# Clear caches after updating graph
cache.get_cached_graph.cache_delete(
graph_id=graph_id, version=None, user_id=user_id
)
cache.get_cached_graph_all_versions.cache_delete(graph_id, user_id=user_id)
cache.get_cached_graphs.cache_delete(user_id=user_id, page=1, page_size=250)
return new_graph_version_with_subgraphs
@@ -853,6 +876,12 @@ async def execute_graph(
# Record successful graph execution
record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
record_graph_operation(operation="execute", status="success")
for page in range(1, 10):
cache.get_cached_graph_executions.cache_delete(
graph_id=graph_id, user_id=user_id, page=page, page_size=20
)
return result
except GraphValidationError as e:
# Record failed graph execution
@@ -928,7 +957,7 @@ async def _stop_graph_run(
async def list_graphs_executions(
user_id: Annotated[str, Security(get_user_id)],
) -> list[execution_db.GraphExecutionMeta]:
paginated_result = await execution_db.get_graph_executions_paginated(
paginated_result = await cache.get_cached_graphs_executions(
user_id=user_id,
page=1,
page_size=250,
@@ -950,7 +979,7 @@ async def list_graph_executions(
25, ge=1, le=100, description="Number of executions per page"
),
) -> execution_db.GraphExecutionsPaginated:
return await execution_db.get_graph_executions_paginated(
return await cache.get_cached_graph_executions(
graph_id=graph_id,
user_id=user_id,
page=page,

View File

@@ -102,13 +102,13 @@ def test_get_graph_blocks(
mock_block.id = "test-block"
mock_block.disabled = False
# Mock get_blocks
# Mock get_blocks in cache module where it's actually used
mocker.patch(
"backend.server.routers.v1.get_blocks",
"backend.server.routers.cache.get_blocks",
return_value={"test-block": lambda: mock_block},
)
# Mock block costs
# Mock block costs where it's imported inside the function
mocker.patch(
"backend.data.credit.get_block_cost",
return_value=[{"cost": 10, "type": "credit"}],

View File

@@ -0,0 +1,117 @@
"""
Cache functions for Library API endpoints.
This module contains all caching decorators and helpers for the Library API,
separated from the main routes for better organization and maintainability.
"""
from autogpt_libs.utils.cache import cached
import backend.server.v2.library.db
# ===== Library Agent Caches =====
# Cache library agents list for 10 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=600, shared_cache=True)
async def get_cached_library_agents(
user_id: str,
page: int = 1,
page_size: int = 20,
):
"""Cached helper to get library agents list."""
return await backend.server.v2.library.db.list_library_agents(
user_id=user_id,
page=page,
page_size=page_size,
)
# Cache user's favorite agents for 5 minutes - favorites change more frequently
# No shared_cache - cache_delete not used for this function (based on search results)
@cached(maxsize=500, ttl_seconds=300, shared_cache=False)
async def get_cached_library_agent_favorites(
user_id: str,
page: int = 1,
page_size: int = 20,
):
"""Cached helper to get user's favorite library agents."""
return await backend.server.v2.library.db.list_favorite_library_agents(
user_id=user_id,
page=page,
page_size=page_size,
)
# Cache individual library agent details for 30 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent(
library_agent_id: str,
user_id: str,
):
"""Cached helper to get library agent details."""
return await backend.server.v2.library.db.get_library_agent(
id=library_agent_id,
user_id=user_id,
)
# Cache library agent by graph ID for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent_by_graph_id(
graph_id: str,
user_id: str,
):
"""Cached helper to get library agent by graph ID."""
return await backend.server.v2.library.db.get_library_agent_by_graph_id(
graph_id=graph_id,
user_id=user_id,
)
# Cache library agent by store version ID for 1 hour - marketplace agents are more stable
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def get_cached_library_agent_by_store_version(
store_listing_version_id: str,
user_id: str,
):
"""Cached helper to get library agent by store version ID."""
return await backend.server.v2.library.db.get_library_agent_by_store_version_id(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)
# ===== Library Preset Caches =====
# Cache library presets list for 30 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_presets(
user_id: str,
page: int = 1,
page_size: int = 20,
):
"""Cached helper to get library presets list."""
return await backend.server.v2.library.db.list_presets(
user_id=user_id,
page=page,
page_size=page_size,
)
# Cache individual preset details for 30 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_preset(
preset_id: str,
user_id: str,
):
"""Cached helper to get library preset details."""
return await backend.server.v2.library.db.get_preset(
preset_id=preset_id,
user_id=user_id,
)

View File

@@ -0,0 +1,272 @@
"""
Tests for cache invalidation in Library API routes.
This module tests that library caches are properly invalidated when data is modified.
"""
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as library_db
@pytest.fixture
def mock_user_id():
"""Generate a mock user ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def mock_library_agent_id():
"""Generate a mock library agent ID for testing."""
return str(uuid.uuid4())
class TestLibraryAgentCacheInvalidation:
"""Test cache invalidation for library agent operations."""
@pytest.mark.asyncio
async def test_add_agent_clears_list_cache(self, mock_user_id):
"""Test that adding an agent clears the library agents list cache."""
# Clear cache
library_cache.get_cached_library_agents.cache_clear()
with patch.object(
library_db, "list_library_agents", new_callable=AsyncMock
) as mock_list:
mock_response = MagicMock(agents=[], total_count=0, page=1, page_size=20)
mock_list.return_value = mock_response
# First call hits database
await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
assert mock_list.call_count == 1
# Second call uses cache
await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
assert mock_list.call_count == 1 # Still 1, cache used
# Simulate adding an agent (cache invalidation)
for page in range(1, 5):
library_cache.get_cached_library_agents.cache_delete(
mock_user_id, page, 15
)
library_cache.get_cached_library_agents.cache_delete(
mock_user_id, page, 20
)
# Next call should hit database
await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
assert mock_list.call_count == 2
@pytest.mark.asyncio
async def test_delete_agent_clears_multiple_caches(
self, mock_user_id, mock_library_agent_id
):
"""Test that deleting an agent clears both specific and list caches."""
# Clear caches
library_cache.get_cached_library_agent.cache_clear()
library_cache.get_cached_library_agents.cache_clear()
with patch.object(
library_db, "get_library_agent", new_callable=AsyncMock
) as mock_get, patch.object(
library_db, "list_library_agents", new_callable=AsyncMock
) as mock_list:
mock_agent = MagicMock(id=mock_library_agent_id, name="Test Agent")
mock_get.return_value = mock_agent
mock_list.return_value = MagicMock(agents=[mock_agent])
# Populate caches
await library_cache.get_cached_library_agent(
mock_library_agent_id, mock_user_id
)
await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
initial_calls = {
"get": mock_get.call_count,
"list": mock_list.call_count,
}
# Verify cache is used
await library_cache.get_cached_library_agent(
mock_library_agent_id, mock_user_id
)
await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
assert mock_get.call_count == initial_calls["get"]
assert mock_list.call_count == initial_calls["list"]
# Simulate delete_library_agent cache invalidation
library_cache.get_cached_library_agent.cache_delete(
mock_library_agent_id, mock_user_id
)
for page in range(1, 5):
library_cache.get_cached_library_agents.cache_delete(
mock_user_id, page, 15
)
library_cache.get_cached_library_agents.cache_delete(
mock_user_id, page, 20
)
# Next calls should hit database
await library_cache.get_cached_library_agent(
mock_library_agent_id, mock_user_id
)
await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
assert mock_get.call_count == initial_calls["get"] + 1
assert mock_list.call_count == initial_calls["list"] + 1
@pytest.mark.asyncio
async def test_favorites_cache_operations(self, mock_user_id):
"""Test that favorites cache works independently."""
# Clear cache
library_cache.get_cached_library_agent_favorites.cache_clear()
with patch.object(
library_db, "list_favorite_library_agents", new_callable=AsyncMock
) as mock_favs:
mock_response = MagicMock(agents=[], total_count=0, page=1, page_size=20)
mock_favs.return_value = mock_response
# First call hits database
await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
assert mock_favs.call_count == 1
# Second call uses cache
await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
assert mock_favs.call_count == 1 # Cache used
# Clear cache
library_cache.get_cached_library_agent_favorites.cache_delete(
mock_user_id, 1, 20
)
# Next call hits database
await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
assert mock_favs.call_count == 2
class TestLibraryPresetCacheInvalidation:
"""Test cache invalidation for library preset operations."""
@pytest.mark.asyncio
async def test_preset_cache_operations(self, mock_user_id):
"""Test preset cache and invalidation."""
# Clear cache
library_cache.get_cached_library_presets.cache_clear()
library_cache.get_cached_library_preset.cache_clear()
preset_id = str(uuid.uuid4())
with patch.object(
library_db, "list_presets", new_callable=AsyncMock
) as mock_list, patch.object(
library_db, "get_preset", new_callable=AsyncMock
) as mock_get:
mock_preset = MagicMock(id=preset_id, name="Test Preset")
mock_list.return_value = MagicMock(presets=[mock_preset])
mock_get.return_value = mock_preset
# Populate caches
await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
await library_cache.get_cached_library_preset(preset_id, mock_user_id)
initial_calls = {
"list": mock_list.call_count,
"get": mock_get.call_count,
}
# Verify cache is used
await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
await library_cache.get_cached_library_preset(preset_id, mock_user_id)
assert mock_list.call_count == initial_calls["list"]
assert mock_get.call_count == initial_calls["get"]
# Clear specific preset cache
library_cache.get_cached_library_preset.cache_delete(
preset_id, mock_user_id
)
# Clear list cache
library_cache.get_cached_library_presets.cache_delete(mock_user_id, 1, 20)
# Next calls should hit database
await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
await library_cache.get_cached_library_preset(preset_id, mock_user_id)
assert mock_list.call_count == initial_calls["list"] + 1
assert mock_get.call_count == initial_calls["get"] + 1
class TestLibraryCacheMetrics:
"""Test library cache metrics and management."""
def test_cache_info_structure(self):
"""Test that cache_info returns expected structure."""
info = library_cache.get_cached_library_agents.cache_info()
assert "size" in info
assert "maxsize" in info
assert "ttl_seconds" in info
assert info["maxsize"] == 1000 # As defined in cache.py
assert info["ttl_seconds"] == 600 # 10 minutes
def test_all_library_caches_can_be_cleared(self):
"""Test that all library caches can be cleared."""
# Clear all library caches
library_cache.get_cached_library_agents.cache_clear()
library_cache.get_cached_library_agent_favorites.cache_clear()
library_cache.get_cached_library_agent.cache_clear()
library_cache.get_cached_library_agent_by_graph_id.cache_clear()
library_cache.get_cached_library_agent_by_store_version.cache_clear()
library_cache.get_cached_library_presets.cache_clear()
library_cache.get_cached_library_preset.cache_clear()
# Verify all are empty
assert library_cache.get_cached_library_agents.cache_info()["size"] == 0
assert (
library_cache.get_cached_library_agent_favorites.cache_info()["size"] == 0
)
assert library_cache.get_cached_library_agent.cache_info()["size"] == 0
assert (
library_cache.get_cached_library_agent_by_graph_id.cache_info()["size"] == 0
)
assert (
library_cache.get_cached_library_agent_by_store_version.cache_info()["size"]
== 0
)
assert library_cache.get_cached_library_presets.cache_info()["size"] == 0
assert library_cache.get_cached_library_preset.cache_info()["size"] == 0
def test_cache_ttl_values(self):
"""Test that cache TTL values are set correctly."""
# Library agents - 10 minutes
assert (
library_cache.get_cached_library_agents.cache_info()["ttl_seconds"] == 600
)
# Favorites - 5 minutes (more dynamic)
assert (
library_cache.get_cached_library_agent_favorites.cache_info()["ttl_seconds"]
== 300
)
# Individual agent - 30 minutes
assert (
library_cache.get_cached_library_agent.cache_info()["ttl_seconds"] == 1800
)
# Presets - 30 minutes
assert (
library_cache.get_cached_library_presets.cache_info()["ttl_seconds"] == 1800
)
assert (
library_cache.get_cached_library_preset.cache_info()["ttl_seconds"] == 1800
)

View File

@@ -5,6 +5,7 @@ import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
from fastapi.responses import Response
import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
@@ -64,13 +65,22 @@ async def list_library_agents(
HTTPException: If a server/database error occurs.
"""
try:
return await library_db.list_library_agents(
user_id=user_id,
search_term=search_term,
sort_by=sort_by,
page=page,
page_size=page_size,
)
# Use cache for default queries (no search term, default sort)
if search_term is None and sort_by == library_model.LibraryAgentSort.UPDATED_AT:
return await library_cache.get_cached_library_agents(
user_id=user_id,
page=page,
page_size=page_size,
)
else:
# Direct DB query for searches and custom sorts
return await library_db.list_library_agents(
user_id=user_id,
search_term=search_term,
sort_by=sort_by,
page=page,
page_size=page_size,
)
except Exception as e:
logger.error(f"Could not list library agents for user #{user_id}: {e}")
raise HTTPException(
@@ -114,7 +124,7 @@ async def list_favorite_library_agents(
HTTPException: If a server/database error occurs.
"""
try:
return await library_db.list_favorite_library_agents(
return await library_cache.get_cached_library_agent_favorites(
user_id=user_id,
page=page,
page_size=page_size,
@@ -132,7 +142,9 @@ async def get_library_agent(
library_agent_id: str,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryAgent:
return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)
return await library_cache.get_cached_library_agent(
library_agent_id=library_agent_id, user_id=user_id
)
@router.get("/by-graph/{graph_id}")
@@ -210,11 +222,19 @@ async def add_marketplace_agent_to_library(
HTTPException(500): If a server/database error occurs.
"""
try:
return await library_db.add_store_agent_to_library(
result = await library_db.add_store_agent_to_library(
store_listing_version_id=store_listing_version_id,
user_id=user_id,
)
# Clear library caches after adding new agent
for page in range(1, 20):
library_cache.get_cached_library_agents.cache_delete(
user_id=user_id, page=page, page_size=8
)
return result
except store_exceptions.AgentNotFoundError as e:
logger.warning(
f"Could not find store listing version {store_listing_version_id} "
@@ -320,6 +340,16 @@ async def delete_library_agent(
await library_db.delete_library_agent(
library_agent_id=library_agent_id, user_id=user_id
)
# Clear caches after deleting agent
library_cache.get_cached_library_agent.cache_delete(
library_agent_id=library_agent_id, user_id=user_id
)
for page in range(1, 20):
library_cache.get_cached_library_agents.cache_delete(
user_id=user_id, page=page, page_size=8
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
except NotFoundError as e:
raise HTTPException(

View File

@@ -4,6 +4,8 @@ from typing import Any, Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
import backend.server.routers.cache as cache
import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
from backend.data.execution import GraphExecutionMeta
@@ -51,12 +53,21 @@ async def list_presets(
models.LibraryAgentPresetResponse: A response containing the list of presets.
"""
try:
return await db.list_presets(
user_id=user_id,
graph_id=graph_id,
page=page,
page_size=page_size,
)
# Use cache only for default queries (no filter)
if graph_id is None:
return await library_cache.get_cached_library_presets(
user_id=user_id,
page=page,
page_size=page_size,
)
else:
# Direct DB query for filtered requests
return await db.list_presets(
user_id=user_id,
graph_id=graph_id,
page=page,
page_size=page_size,
)
except Exception as e:
logger.exception("Failed to list presets for user %s: %s", user_id, e)
raise HTTPException(
@@ -87,7 +98,7 @@ async def get_preset(
HTTPException: If the preset is not found or an error occurs.
"""
try:
preset = await db.get_preset(user_id, preset_id)
preset = await library_cache.get_cached_library_preset(preset_id, user_id)
except Exception as e:
logger.exception(
"Error retrieving preset %s for user %s: %s", preset_id, user_id, e
@@ -131,9 +142,20 @@ async def create_preset(
"""
try:
if isinstance(preset, models.LibraryAgentPresetCreatable):
return await db.create_preset(user_id, preset)
result = await db.create_preset(user_id, preset)
else:
return await db.create_preset_from_graph_execution(user_id, preset)
result = await db.create_preset_from_graph_execution(user_id, preset)
# Clear presets list cache after creating new preset
for page in range(1, 5):
library_cache.get_cached_library_presets.cache_delete(
user_id=user_id, page=page, page_size=10
)
library_cache.get_cached_library_presets.cache_delete(
user_id=user_id, page=page, page_size=20
)
return result
except NotFoundError as e:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
except Exception as e:
@@ -200,6 +222,16 @@ async def setup_trigger(
is_active=True,
),
)
# Clear presets list cache after creating new preset
for page in range(1, 5):
library_cache.get_cached_library_presets.cache_delete(
user_id=user_id, page=page, page_size=10
)
library_cache.get_cached_library_presets.cache_delete(
user_id=user_id, page=page, page_size=20
)
return new_preset
@@ -278,6 +310,18 @@ async def update_preset(
description=preset.description,
is_active=preset.is_active,
)
# Clear caches after updating preset
library_cache.get_cached_library_preset.cache_delete(
preset_id=preset_id, user_id=user_id
)
for page in range(1, 5):
library_cache.get_cached_library_presets.cache_delete(
user_id=user_id, page=page, page_size=10
)
library_cache.get_cached_library_presets.cache_delete(
user_id=user_id, page=page, page_size=20
)
except Exception as e:
logger.exception("Preset update failed for user %s: %s", user_id, e)
raise HTTPException(
@@ -351,6 +395,18 @@ async def delete_preset(
try:
await db.delete_preset(user_id, preset_id)
# Clear caches after deleting preset
library_cache.get_cached_library_preset.cache_delete(
preset_id=preset_id, user_id=user_id
)
for page in range(1, 5):
library_cache.get_cached_library_presets.cache_delete(
user_id=user_id, page=page, page_size=10
)
library_cache.get_cached_library_presets.cache_delete(
user_id=user_id, page=page, page_size=20
)
except Exception as e:
logger.exception(
"Error deleting preset %s for user %s: %s", preset_id, user_id, e
@@ -401,6 +457,14 @@ async def execute_preset(
merged_node_input = preset.inputs | inputs
merged_credential_inputs = preset.credentials | credential_inputs
for page in range(1, 10):
cache.get_cached_graph_executions.cache_delete(
graph_id=preset.graph_id, user_id=user_id, page=page, page_size=20
)
cache.get_cached_graph_executions.cache_delete(
user_id=user_id, page=page, page_size=20
)
return await add_graph_execution(
user_id=user_id,
graph_id=preset.graph_id,

View File

@@ -179,14 +179,15 @@ async def test_get_favorite_library_agents_success(
def test_get_favorite_library_agents_error(
mocker: pytest_mock.MockFixture, test_user_id: str
):
mock_db_call = mocker.patch(
"backend.server.v2.library.db.list_favorite_library_agents"
# Mock the cache function instead of the DB directly since routes now use cache
mock_cache_call = mocker.patch(
"backend.server.v2.library.routes.agents.library_cache.get_cached_library_agent_favorites"
)
mock_db_call.side_effect = Exception("Test error")
mock_cache_call.side_effect = Exception("Test error")
response = client.get("/agents/favorites")
assert response.status_code == 500
mock_db_call.assert_called_once_with(
mock_cache_call.assert_called_once_with(
user_id=test_user_id,
page=1,
page_size=15,

View File

@@ -0,0 +1,125 @@
"""
Cache functions for Store API endpoints.
This module contains all caching decorators and helpers for the Store API,
separated from the main routes for better organization and maintainability.
"""
from autogpt_libs.utils.cache import cached
import backend.server.v2.store.db
# Cache user profiles for 1 hour per user
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def _get_cached_user_profile(user_id: str):
"""Cached helper to get user profile."""
return await backend.server.v2.store.db.get_user_profile(user_id)
# Cache store agents list for 15 minutes
# Different cache entries for different query combinations
# No shared_cache - cache_delete not used for this function
@cached(maxsize=5000, ttl_seconds=900)
async def _get_cached_store_agents(
featured: bool,
creator: str | None,
sorted_by: str | None,
search_query: str | None,
category: str | None,
page: int,
page_size: int,
):
"""Cached helper to get store agents."""
return await backend.server.v2.store.db.get_store_agents(
featured=featured,
creators=[creator] if creator else None,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
# Cache individual agent details for 15 minutes
# No shared_cache - cache_delete not used for this function
@cached(maxsize=200, ttl_seconds=900)
async def _get_cached_agent_details(username: str, agent_name: str):
"""Cached helper to get agent details."""
return await backend.server.v2.store.db.get_store_agent_details(
username=username, agent_name=agent_name
)
# Cache agent graphs for 1 hour
# No shared_cache - cache_delete not used for this function
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_agent_graph(store_listing_version_id: str):
"""Cached helper to get agent graph."""
return await backend.server.v2.store.db.get_available_graph(
store_listing_version_id
)
# Cache agent by version for 1 hour
# No shared_cache - cache_delete not used for this function
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_store_agent_by_version(store_listing_version_id: str):
"""Cached helper to get store agent by version ID."""
return await backend.server.v2.store.db.get_store_agent_by_version_id(
store_listing_version_id
)
# Cache creators list for 1 hour
# No shared_cache - cache_delete not used for this function
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_store_creators(
featured: bool,
search_query: str | None,
sorted_by: str | None,
page: int,
page_size: int,
):
"""Cached helper to get store creators."""
return await backend.server.v2.store.db.get_store_creators(
featured=featured,
search_query=search_query,
sorted_by=sorted_by,
page=page,
page_size=page_size,
)
# Cache individual creator details for 1 hour
# No shared_cache - cache_delete not used for this function
@cached(maxsize=100, ttl_seconds=3600)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
return await backend.server.v2.store.db.get_store_creator_details(
username=username.lower()
)
# Cache user's own agents for 5 mins (shorter TTL as this changes more frequently)
# No shared_cache - cache_delete not used for this function
@cached(maxsize=500, ttl_seconds=300)
async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
"""Cached helper to get user's agents."""
return await backend.server.v2.store.db.get_my_agents(
user_id, page=page, page_size=page_size
)
# Cache user's submissions for 1 hour (shorter TTL as this changes frequently)
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
"""Cached helper to get user's submissions."""
return await backend.server.v2.store.db.get_store_submissions(
user_id=user_id,
page=page,
page_size=page_size,
)

View File

@@ -6,7 +6,6 @@ import urllib.parse
import autogpt_libs.auth
import fastapi
import fastapi.responses
from autogpt_libs.utils.cache import cached
import backend.data.graph
import backend.server.v2.store.db
@@ -15,123 +14,22 @@ import backend.server.v2.store.image_gen
import backend.server.v2.store.media
import backend.server.v2.store.model
import backend.util.json
from backend.server.v2.store.cache import (
_get_cached_agent_details,
_get_cached_agent_graph,
_get_cached_creator_details,
_get_cached_my_agents,
_get_cached_store_agent_by_version,
_get_cached_store_agents,
_get_cached_store_creators,
_get_cached_submissions,
_get_cached_user_profile,
)
logger = logging.getLogger(__name__)
router = fastapi.APIRouter()
##############################################
############### Caches #######################
##############################################
# Cache user profiles for 1 hour per user
@cached(maxsize=1000, ttl_seconds=3600)
async def _get_cached_user_profile(user_id: str):
"""Cached helper to get user profile."""
return await backend.server.v2.store.db.get_user_profile(user_id)
# Cache store agents list for 15 minutes
# Different cache entries for different query combinations
@cached(maxsize=5000, ttl_seconds=900)
async def _get_cached_store_agents(
featured: bool,
creator: str | None,
sorted_by: str | None,
search_query: str | None,
category: str | None,
page: int,
page_size: int,
):
"""Cached helper to get store agents."""
return await backend.server.v2.store.db.get_store_agents(
featured=featured,
creators=[creator] if creator else None,
sorted_by=sorted_by,
search_query=search_query,
category=category,
page=page,
page_size=page_size,
)
# Cache individual agent details for 15 minutes
@cached(maxsize=200, ttl_seconds=900)
async def _get_cached_agent_details(username: str, agent_name: str):
"""Cached helper to get agent details."""
return await backend.server.v2.store.db.get_store_agent_details(
username=username, agent_name=agent_name
)
# Cache agent graphs for 1 hour
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_agent_graph(store_listing_version_id: str):
"""Cached helper to get agent graph."""
return await backend.server.v2.store.db.get_available_graph(
store_listing_version_id
)
# Cache agent by version for 1 hour
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_store_agent_by_version(store_listing_version_id: str):
"""Cached helper to get store agent by version ID."""
return await backend.server.v2.store.db.get_store_agent_by_version_id(
store_listing_version_id
)
# Cache creators list for 1 hour
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_store_creators(
featured: bool,
search_query: str | None,
sorted_by: str | None,
page: int,
page_size: int,
):
"""Cached helper to get store creators."""
return await backend.server.v2.store.db.get_store_creators(
featured=featured,
search_query=search_query,
sorted_by=sorted_by,
page=page,
page_size=page_size,
)
# Cache individual creator details for 1 hour
@cached(maxsize=100, ttl_seconds=3600)
async def _get_cached_creator_details(username: str):
"""Cached helper to get creator details."""
return await backend.server.v2.store.db.get_store_creator_details(
username=username.lower()
)
# Cache user's own agents for 5 mins (shorter TTL as this changes more frequently)
@cached(maxsize=500, ttl_seconds=300)
async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
"""Cached helper to get user's agents."""
return await backend.server.v2.store.db.get_my_agents(
user_id, page=page, page_size=page_size
)
# Cache user's submissions for 1 hour (shorter TTL as this changes frequently)
@cached(maxsize=500, ttl_seconds=3600)
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
"""Cached helper to get user's submissions."""
return await backend.server.v2.store.db.get_store_submissions(
user_id=user_id,
page=page,
page_size=page_size,
)
##############################################
############### Profile Endpoints ############
##############################################

View File

@@ -0,0 +1,76 @@
"""
Initialize Redis clients for the cache system.
This module bridges the gap between backend's Redis client and autogpt_libs cache system.
"""
import logging
from autogpt_libs.utils.cache import set_redis_client_provider
logger = logging.getLogger(__name__)
def initialize_cache_redis():
"""
Initialize the cache system with backend's Redis clients.
This function sets up the cache system to use the backend's
existing Redis connection instead of creating its own.
"""
try:
from backend.data.redis_client import HOST, PASSWORD, PORT
# Create provider functions that create new binary-mode clients
def get_sync_redis_for_cache():
"""Get sync Redis client configured for cache (binary mode)."""
try:
from redis import Redis
client = Redis(
host=HOST,
port=PORT,
password=PASSWORD,
decode_responses=False, # Binary mode for pickle
)
# Test the connection
client.ping()
return client
except Exception as e:
logger.warning(f"Failed to get Redis client for cache: {e}")
return None
async def get_async_redis_for_cache():
"""Get async Redis client configured for cache (binary mode)."""
try:
from redis.asyncio import Redis as AsyncRedis
client = AsyncRedis(
host=HOST,
port=PORT,
password=PASSWORD,
decode_responses=False, # Binary mode for pickle
)
# Test the connection
await client.ping()
return client
except Exception as e:
logger.warning(f"Failed to get async Redis client for cache: {e}")
return None
# Set the providers in the cache system
set_redis_client_provider(
sync_provider=get_sync_redis_for_cache,
async_provider=get_async_redis_for_cache,
)
logger.info("Cache system initialized with backend Redis clients")
except ImportError as e:
logger.warning(f"Could not import Redis clients, cache will use fallback: {e}")
except Exception as e:
logger.error(f"Failed to initialize cache Redis: {e}")
# Auto-initialize when module is imported
initialize_cache_redis()