Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-01-12 16:48:06 -05:00

Comparing commits: dev...swiftyos/r (12 commits)

Commits:
ae6ad35bf6
1afebcf96b
bc5eb8a8a5
872ef5fdfb
12382e7990
d68a3a1b53
863e213af3
c61af53a74
eb94503de8
ee4feff8c2
9147c2d6c8
a3af430c69
autogpt_libs/utils/cache.py:

@@ -1,6 +1,8 @@
import asyncio
import inspect
import logging
import os
import pickle
import threading
import time
from functools import wraps
@@ -20,6 +22,92 @@ R_co = TypeVar("R_co", covariant=True)

logger = logging.getLogger(__name__)

# Redis client providers (can be set externally)
_redis_client_provider = None
_async_redis_client_provider = None


def set_redis_client_provider(sync_provider=None, async_provider=None):
    """
    Set external Redis client providers.

    This allows the backend to inject its Redis clients into the cache system.

    Args:
        sync_provider: A callable that returns a sync Redis client
        async_provider: A callable that returns an async Redis client
    """
    global _redis_client_provider, _async_redis_client_provider
    if sync_provider:
        _redis_client_provider = sync_provider
    if async_provider:
        _async_redis_client_provider = async_provider
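For orientation, a minimal sketch of how an application might inject its own clients through this hook. The client construction details below are illustrative, not taken from this diff:

from redis import Redis
from redis.asyncio import Redis as AsyncRedis

from autogpt_libs.utils.cache import set_redis_client_provider

# Application-owned clients; decode_responses=False matches the cache's
# pickle-based serialization.
_sync_client = Redis(host="localhost", port=6379, decode_responses=False)
_async_client = AsyncRedis(host="localhost", port=6379, decode_responses=False)

set_redis_client_provider(
    sync_provider=lambda: _sync_client,
    async_provider=lambda: _async_client,
)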
def _get_redis_client():
    """Get Redis client from provider or create a default one."""
    if _redis_client_provider:
        try:
            client = _redis_client_provider()
            if client:
                return client
        except Exception as e:
            logger.warning(f"Failed to get Redis client from provider: {e}")

    # Fallback to creating our own client
    try:
        from redis import Redis

        client = Redis(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            password=os.getenv("REDIS_PASSWORD", None),
            decode_responses=False,  # We'll use pickle for serialization
        )
        client.ping()
        return client
    except Exception as e:
        logger.warning(f"Failed to connect to Redis: {e}")
        return None


async def _get_async_redis_client():
    """Get async Redis client from provider or create a default one."""
    if _async_redis_client_provider:
        try:
            # Provider is an async function, we need to await it
            if inspect.iscoroutinefunction(_async_redis_client_provider):
                client = await _async_redis_client_provider()
            else:
                client = _async_redis_client_provider()
            if client:
                return client
        except Exception as e:
            logger.warning(f"Failed to get async Redis client from provider: {e}")

    # Fallback to creating our own client
    try:
        from redis.asyncio import Redis as AsyncRedis

        client = AsyncRedis(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            password=os.getenv("REDIS_PASSWORD", None),
            decode_responses=False,  # We'll use pickle for serialization
        )
        return client
    except Exception as e:
        logger.warning(f"Failed to create async Redis client: {e}")
        return None


def _make_redis_key(func_name: str, key: tuple) -> str:
    """Create a Redis key from function name and cache key."""
    # Convert the key to a string representation
    key_str = str(key)
    # Add a prefix to avoid collisions with other Redis usage
    return f"cache:{func_name}:{key_str}"
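To make the key format concrete: assuming _make_hashable_key flattens a call's arguments into a tuple such as (42, ('mode', 'full')) for get_user(42, mode="full") (an assumption, since that helper's body is cut off by the hunk below), the resulting Redis key would be:

_make_redis_key("get_user", (42, ("mode", "full")))
# -> "cache:get_user:(42, ('mode', 'full'))"

The fixed "cache:" prefix namespaces these entries away from any other Redis usage in the same deployment.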
def _make_hashable_key(
    args: tuple[Any, ...], kwargs: dict[str, Any]
@@ -85,6 +173,7 @@ def cached(
    *,
    maxsize: int = 128,
    ttl_seconds: int | None = None,
    shared_cache: bool = False,
) -> Callable[[Callable], CachedFunction]:
    """
    Thundering herd safe cache decorator for both sync and async functions.

@@ -92,10 +181,13 @@ def cached(
    Uses double-checked locking to prevent multiple threads/coroutines from
    executing the expensive operation simultaneously during cache misses.

    When shared_cache=True, uses Redis for distributed caching across instances.

    Args:
        func: The function to cache (when used without parentheses)
        maxsize: Maximum number of cached entries
        maxsize: Maximum number of cached entries (ignored when using Redis)
        ttl_seconds: Time to live in seconds. If None, entries never expire
        shared_cache: If True, use Redis for distributed caching

    Returns:
        Decorated function or decorator
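A minimal usage sketch of the decorator as documented above (the decorated functions are illustrative):

from autogpt_libs.utils.cache import cached

# Purely local, per-process cache: up to 128 entries, 60-second TTL.
@cached(maxsize=128, ttl_seconds=60)
def load_settings(env: str) -> dict:
    ...

# Redis-backed cache shared across instances; maxsize is ignored here.
@cached(ttl_seconds=300, shared_cache=True)
async def load_profile(user_id: str) -> dict:
    ...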
@@ -127,50 +219,100 @@ def cached(
key = _make_hashable_key(args, kwargs)
current_time = time.time()

# Fast path: check cache without lock
if key in cache_storage:
    if ttl_seconds is None:
        logger.debug(f"Cache hit for {target_func.__name__}")
        return cache_storage[key]
    else:
        cached_data = cache_storage[key]
        if isinstance(cached_data, tuple):
            result, timestamp = cached_data
            if current_time - timestamp < ttl_seconds:
                logger.debug(f"Cache hit for {target_func.__name__}")
# Try Redis first if shared_cache is enabled
if shared_cache:
    redis_client = await _get_async_redis_client()
    if redis_client:
        redis_key = _make_redis_key(target_func.__name__, key)
        try:
            # Check Redis cache
            await redis_client.ping()  # Ensure connection is alive
            cached_value = await redis_client.get(redis_key)
            if cached_value:
                result = pickle.loads(cast(bytes, cached_value))
                logger.info(
                    f"Redis cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
                )
                return result

# Slow path: acquire lock for cache miss/expiry
async with cache_lock:
    # Double-check: another coroutine might have populated cache
        except Exception as e:
            logger.warning(f"Redis cache read failed: {e}")
            # Fall through to execute function
else:
    # Fast path: check local cache without lock
    if key in cache_storage:
        if ttl_seconds is None:
            logger.info(
                f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
            )
            return cache_storage[key]
        else:
            cached_data = cache_storage[key]
            if isinstance(cached_data, tuple):
                result, timestamp = cached_data
                if current_time - timestamp < ttl_seconds:
                    logger.info(
                        f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
                    )
                    return result

# Slow path: acquire lock for cache miss/expiry
async with cache_lock:
    # Double-check: another coroutine might have populated cache
    if shared_cache:
        redis_client = await _get_async_redis_client()
        if redis_client:
            redis_key = _make_redis_key(target_func.__name__, key)
            try:
                cached_value = await redis_client.get(redis_key)
                if cached_value:
                    return pickle.loads(cast(bytes, cached_value))
            except Exception as e:
                logger.warning(f"Redis cache read failed in lock: {e}")
    else:
        if key in cache_storage:
            if ttl_seconds is None:
                return cache_storage[key]
            else:
                cached_data = cache_storage[key]
                if isinstance(cached_data, tuple):
                    result, timestamp = cached_data
                    if current_time - timestamp < ttl_seconds:
                        return result

    # Cache miss - execute function
    logger.debug(f"Cache miss for {target_func.__name__}")
    logger.info(f"Cache miss for {target_func.__name__}")
    result = await target_func(*args, **kwargs)

    # Store result
    if ttl_seconds is None:
        cache_storage[key] = result
    if shared_cache:
        redis_client = await _get_async_redis_client()
        if redis_client:
            redis_key = _make_redis_key(target_func.__name__, key)
            try:
                serialized = pickle.dumps(result)
                await redis_client.set(
                    redis_key,
                    serialized,
                    ex=ttl_seconds if ttl_seconds else None,
                )
            except Exception as e:
                logger.warning(f"Redis cache write failed: {e}")
    else:
        cache_storage[key] = (result, current_time)
        if ttl_seconds is None:
            cache_storage[key] = result
        else:
            cache_storage[key] = (result, current_time)

        # Cleanup if needed
        if len(cache_storage) > maxsize:
            cutoff = maxsize // 2
            oldest_keys = (
                list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
            )
            for old_key in oldest_keys:
                cache_storage.pop(old_key, None)
        # Cleanup if needed
        if len(cache_storage) > maxsize:
            cutoff = maxsize // 2
            oldest_keys = (
                list(cache_storage.keys())[:-cutoff]
                if cutoff > 0
                else []
            )
            for old_key in oldest_keys:
                cache_storage.pop(old_key, None)

    return result
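The payoff of the lock plus double-check above is that concurrent callers collapse into one execution. An illustrative, self-contained demonstration (not part of the diff):

import asyncio

from autogpt_libs.utils.cache import cached


@cached(ttl_seconds=60)
async def slow_lookup(key: str) -> str:
    await asyncio.sleep(1)  # stands in for an expensive call
    return key.upper()


async def main():
    # Ten concurrent awaits: the first caller runs the body while holding the
    # lock; the rest wait, then hit the cache in the double-check.
    results = await asyncio.gather(*(slow_lookup("a") for _ in range(10)))
    assert results == ["A"] * 10  # one execution, ten identical results


asyncio.run(main())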
@@ -185,50 +327,99 @@ def cached(
key = _make_hashable_key(args, kwargs)
current_time = time.time()

# Fast path: check cache without lock
if key in cache_storage:
    if ttl_seconds is None:
        logger.debug(f"Cache hit for {target_func.__name__}")
        return cache_storage[key]
    else:
        cached_data = cache_storage[key]
        if isinstance(cached_data, tuple):
            result, timestamp = cached_data
            if current_time - timestamp < ttl_seconds:
                logger.debug(f"Cache hit for {target_func.__name__}")
# Try Redis first if shared_cache is enabled
if shared_cache:
    redis_client = _get_redis_client()
    if redis_client:
        redis_key = _make_redis_key(target_func.__name__, key)
        try:
            # Check Redis cache
            cached_value = redis_client.get(redis_key)
            if cached_value:
                result = pickle.loads(cast(bytes, cached_value))
                logger.info(
                    f"Redis cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
                )
                return result

# Slow path: acquire lock for cache miss/expiry
with cache_lock:
    # Double-check: another thread might have populated cache
        except Exception as e:
            logger.warning(f"Redis cache read failed: {e}")
            # Fall through to execute function
else:
    # Fast path: check local cache without lock
    if key in cache_storage:
        if ttl_seconds is None:
            logger.info(
                f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
            )
            return cache_storage[key]
        else:
            cached_data = cache_storage[key]
            if isinstance(cached_data, tuple):
                result, timestamp = cached_data
                if current_time - timestamp < ttl_seconds:
                    logger.info(
                        f"Cache hit for {target_func.__name__}, args: {args}, kwargs: {kwargs}"
                    )
                    return result

# Slow path: acquire lock for cache miss/expiry
with cache_lock:
    # Double-check: another thread might have populated cache
    if shared_cache:
        redis_client = _get_redis_client()
        if redis_client:
            redis_key = _make_redis_key(target_func.__name__, key)
            try:
                cached_value = redis_client.get(redis_key)
                if cached_value:
                    return pickle.loads(cast(bytes, cached_value))
            except Exception as e:
                logger.warning(f"Redis cache read failed in lock: {e}")
    else:
        if key in cache_storage:
            if ttl_seconds is None:
                return cache_storage[key]
            else:
                cached_data = cache_storage[key]
                if isinstance(cached_data, tuple):
                    result, timestamp = cached_data
                    if current_time - timestamp < ttl_seconds:
                        return result

    # Cache miss - execute function
    logger.debug(f"Cache miss for {target_func.__name__}")
    logger.info(f"Cache miss for {target_func.__name__}")
    result = target_func(*args, **kwargs)

    # Store result
    if ttl_seconds is None:
        cache_storage[key] = result
    if shared_cache:
        redis_client = _get_redis_client()
        if redis_client:
            redis_key = _make_redis_key(target_func.__name__, key)
            try:
                serialized = pickle.dumps(result)
                redis_client.set(
                    redis_key,
                    serialized,
                    ex=ttl_seconds if ttl_seconds else None,
                )
            except Exception as e:
                logger.warning(f"Redis cache write failed: {e}")
    else:
        cache_storage[key] = (result, current_time)
        if ttl_seconds is None:
            cache_storage[key] = result
        else:
            cache_storage[key] = (result, current_time)

        # Cleanup if needed
        if len(cache_storage) > maxsize:
            cutoff = maxsize // 2
            oldest_keys = (
                list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
            )
            for old_key in oldest_keys:
                cache_storage.pop(old_key, None)
        # Cleanup if needed
        if len(cache_storage) > maxsize:
            cutoff = maxsize // 2
            oldest_keys = (
                list(cache_storage.keys())[:-cutoff]
                if cutoff > 0
                else []
            )
            for old_key in oldest_keys:
                cache_storage.pop(old_key, None)

    return result
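As the store logic in both wrappers shows, the local cache uses two entry formats (a compact restatement; values illustrative):

cache_storage[key] = result                 # ttl_seconds is None: bare value, never expires
cache_storage[key] = (result, time.time())  # with TTL: (value, insertion timestamp)

This is why the read paths check isinstance(cached_data, tuple) before unpacking.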
@@ -236,22 +427,105 @@ def cached(

# Add cache management methods
def cache_clear() -> None:
    cache_storage.clear()
    """Clear all cached entries."""
    if shared_cache:
        redis_client = (
            _get_redis_client()
            if not inspect.iscoroutinefunction(target_func)
            else None
        )
        if redis_client:
            try:
                # Clear all cache entries for this function
                pattern = f"cache:{target_func.__name__}:*"
                for key in redis_client.scan_iter(match=pattern):
                    redis_client.delete(key)
            except Exception as e:
                logger.warning(f"Redis cache clear failed: {e}")
    else:
        cache_storage.clear()

def cache_info() -> dict[str, int | None]:
def cache_info() -> dict[str, Any]:
    """Get cache statistics."""
    if shared_cache:
        redis_client = (
            _get_redis_client()
            if not inspect.iscoroutinefunction(target_func)
            else None
        )
        if redis_client:
            try:
                pattern = f"cache:{target_func.__name__}:*"
                count = sum(1 for _ in redis_client.scan_iter(match=pattern))
                return {
                    "size": count,
                    "maxsize": None,  # Not applicable for Redis
                    "ttl_seconds": ttl_seconds,
                    "shared_cache": True,
                }
            except Exception as e:
                logger.warning(f"Redis cache info failed: {e}")
                return {
                    "size": 0,
                    "maxsize": None,
                    "ttl_seconds": ttl_seconds,
                    "shared_cache": True,
                    "error": str(e),
                }
    return {
        "size": len(cache_storage),
        "maxsize": maxsize,
        "ttl_seconds": ttl_seconds,
        "shared_cache": False,
    }

def cache_delete(*args, **kwargs) -> bool:
    """Delete a specific cache entry. Returns True if entry existed."""
    key = _make_hashable_key(args, kwargs)
    if key in cache_storage:
        del cache_storage[key]
        return True
    return False
# Create appropriate cache_delete based on whether function is async
if inspect.iscoroutinefunction(target_func):

    async def async_cache_delete(*args, **kwargs) -> bool:
        """Delete a specific cache entry. Returns True if entry existed."""
        key = _make_hashable_key(args, kwargs)

        if shared_cache:
            redis_client = await _get_async_redis_client()
            if redis_client:
                redis_key = _make_redis_key(target_func.__name__, key)
                try:
                    result = await redis_client.delete(redis_key)
                    return cast(int, result) > 0
                except Exception as e:
                    logger.warning(f"Redis cache delete failed: {e}")
            return False
        else:
            if key in cache_storage:
                del cache_storage[key]
                return True
            return False

    cache_delete = async_cache_delete
else:

    def sync_cache_delete(*args, **kwargs) -> bool:
        """Delete a specific cache entry. Returns True if entry existed."""
        key = _make_hashable_key(args, kwargs)

        if shared_cache:
            redis_client = _get_redis_client()
            if redis_client:
                redis_key = _make_redis_key(target_func.__name__, key)
                try:
                    result = redis_client.delete(redis_key)
                    return cast(int, result) > 0
                except Exception as e:
                    logger.warning(f"Redis cache delete failed: {e}")
            return False
        else:
            if key in cache_storage:
                del cache_storage[key]
                return True
            return False

    cache_delete = sync_cache_delete

setattr(wrapper, "cache_clear", cache_clear)
setattr(wrapper, "cache_info", cache_info)
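A short sketch of the attached management methods in use, reusing load_profile from the earlier sketch. Note that cache_delete is itself a coroutine function when the decorated function is async, so it must be awaited:

async def example() -> None:
    info = load_profile.cache_info()
    # e.g. {"size": 3, "maxsize": None, "ttl_seconds": 300, "shared_cache": True}

    removed = await load_profile.cache_delete("user-123")  # drop one entry
    load_profile.cache_clear()  # drop every entry for this function
    print(info, removed)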
@@ -0,0 +1,172 @@
"""
Test Redis cache functionality.
"""

import asyncio
import time
from unittest.mock import patch

import pytest

from autogpt_libs.utils.cache import cached


# Test with Redis cache enabled
@pytest.mark.asyncio
async def test_redis_cache_async():
    """Test async function with Redis cache."""

    call_count = 0

    @cached(ttl_seconds=60, shared_cache=True)
    async def expensive_async_operation(x: int, y: int) -> int:
        nonlocal call_count
        call_count += 1
        await asyncio.sleep(0.1)
        return x + y

    # First call should execute function
    result1 = await expensive_async_operation(5, 3)
    assert result1 == 8
    assert call_count == 1

    # Second call with same args should use cache
    result2 = await expensive_async_operation(5, 3)
    assert result2 == 8
    assert call_count == 1  # Should not increment

    # Different args should execute function again
    result3 = await expensive_async_operation(10, 5)
    assert result3 == 15
    assert call_count == 2

    # Test cache_delete (the async variant, so it must be awaited)
    deleted = await expensive_async_operation.cache_delete(5, 3)
    assert isinstance(deleted, bool)  # Depends on Redis availability

    # Test cache_clear
    expensive_async_operation.cache_clear()

    # Test cache_info
    info = expensive_async_operation.cache_info()
    assert "ttl_seconds" in info
    assert info["ttl_seconds"] == 60
    assert "shared_cache" in info


def test_redis_cache_sync():
    """Test sync function with Redis cache."""

    call_count = 0

    @cached(ttl_seconds=60, shared_cache=True)
    def expensive_sync_operation(x: int, y: int) -> int:
        nonlocal call_count
        call_count += 1
        time.sleep(0.1)
        return x * y

    # First call should execute function
    result1 = expensive_sync_operation(5, 3)
    assert result1 == 15
    assert call_count == 1

    # Second call with same args should use cache
    result2 = expensive_sync_operation(5, 3)
    assert result2 == 15
    assert call_count == 1  # Should not increment

    # Different args should execute function again
    result3 = expensive_sync_operation(10, 5)
    assert result3 == 50
    assert call_count == 2

    # Test cache management functions
    expensive_sync_operation.cache_clear()
    info = expensive_sync_operation.cache_info()
    assert info["shared_cache"] is True


def test_fallback_to_local_cache_when_redis_unavailable():
    """Test that cache falls back to local when Redis is unavailable."""

    with patch("autogpt_libs.utils.cache._get_redis_client", return_value=None):
        call_count = 0

        @cached(ttl_seconds=60, shared_cache=True)
        def operation(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 2

        # Should still work with local cache
        result1 = operation(5)
        assert result1 == 10
        assert call_count == 1

        result2 = operation(5)
        assert result2 == 10
        assert call_count == 1  # Cached locally


@pytest.mark.asyncio
async def test_redis_cache_with_complex_types():
    """Test Redis cache with complex data types."""

    @cached(ttl_seconds=30, shared_cache=True)
    async def get_complex_data(user_id: str, filters: dict) -> dict:
        return {
            "user_id": user_id,
            "filters": filters,
            "data": [1, 2, 3, 4, 5],
            "nested": {"key1": "value1", "key2": ["a", "b", "c"]},
        }

    result1 = await get_complex_data("user123", {"status": "active", "limit": 10})
    result2 = await get_complex_data("user123", {"status": "active", "limit": 10})

    assert result1 == result2
    assert result1["user_id"] == "user123"
    assert result1["filters"]["status"] == "active"


def test_local_cache_without_shared():
    """Test that shared_cache=False uses local cache only."""

    call_count = 0

    @cached(maxsize=10, ttl_seconds=30, shared_cache=False)
    def local_only_operation(x: int) -> int:
        nonlocal call_count
        call_count += 1
        return x**2

    result1 = local_only_operation(4)
    assert result1 == 16
    assert call_count == 1

    result2 = local_only_operation(4)
    assert result2 == 16
    assert call_count == 1

    # Check cache info shows it's not shared
    info = local_only_operation.cache_info()
    assert info["shared_cache"] is False
    assert info["maxsize"] == 10


if __name__ == "__main__":
    # Run basic tests
    print("Testing sync Redis cache...")
    test_redis_cache_sync()
    print("✓ Sync Redis cache test passed")

    print("Testing local cache...")
    test_local_cache_without_shared()
    print("✓ Local cache test passed")

    print("Testing fallback...")
    test_fallback_to_local_cache_when_redis_unavailable()
    print("✓ Fallback test passed")

    print("\nAll tests passed!")
@@ -70,6 +70,11 @@ async def lifespan_context(app: fastapi.FastAPI):

    await backend.data.db.connect()

    # Initialize Redis clients for cache system
    from backend.util.cache_redis_init import initialize_cache_redis

    initialize_cache_redis()

    # Ensure SDK auto-registration is patched before initializing blocks
    from backend.sdk.registry import AutoRegistry
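The body of backend/util/cache_redis_init.py is not shown in this compare view. Given the provider hook added to autogpt_libs above, a plausible sketch is that it forwards the backend's own Redis accessors; the backend module and accessor names below are assumptions, not from the diff:

from autogpt_libs.utils.cache import set_redis_client_provider


def initialize_cache_redis() -> None:
    # Assumed: the backend exposes sync/async Redis accessors somewhere
    # like backend.data.redis; the exact names here are hypothetical.
    from backend.data import redis as backend_redis

    set_redis_client_provider(
        sync_provider=backend_redis.get_redis,
        async_provider=backend_redis.get_redis_async,
    )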
autogpt_platform/backend/backend/server/routers/cache.py (new file, 163 lines)

@@ -0,0 +1,163 @@
"""
Cache functions for main V1 API endpoints.

This module contains all caching decorators and helpers for the V1 API,
separated from the main routes for better organization and maintainability.
"""

from typing import Sequence

from autogpt_libs.utils.cache import cached

from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data import user as user_db
from backend.data.block import get_blocks

# ===== Block Caches =====


# Cache block definitions with costs - they rarely change
@cached(maxsize=1, ttl_seconds=3600)
def get_cached_blocks() -> Sequence[dict]:
    """
    Get cached blocks with thundering herd protection.

    Uses the cached decorator to prevent multiple concurrent requests
    from all executing the expensive block loading operation.
    """
    from backend.data.credit import get_block_cost

    block_classes = get_blocks()
    result = []

    for block_class in block_classes.values():
        block_instance = block_class()
        if not block_instance.disabled:
            # Get costs for this specific block class without creating another instance
            costs = get_block_cost(block_instance)
            result.append({**block_instance.to_dict(), "costs": costs})

    return result
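Because the block list is identical for every user, a single cached entry (maxsize=1) refreshed hourly is enough, and a route can simply return the helper's result. A minimal sketch of the intended call site; the actual wiring appears in the v1.py hunks later in this diff:

@v1_router.get("/blocks")
async def get_graph_blocks() -> Sequence[dict]:
    return get_cached_blocks()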
# ===== Graph Caches =====


# Cache user's graphs list for 15 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=900, shared_cache=True)
async def get_cached_graphs(
    user_id: str,
    page: int,
    page_size: int,
):
    """Cached helper to get user's graphs."""
    return await graph_db.list_graphs_paginated(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache individual graph details for 30 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph(
    graph_id: str,
    version: int | None,
    user_id: str,
):
    """Cached helper to get graph details."""
    return await graph_db.get_graph(
        graph_id=graph_id,
        version=version,
        user_id=user_id,
        include_subgraphs=True,  # needed to construct full credentials input schema
    )


# Cache graph versions for 30 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_graph_all_versions(
    graph_id: str,
    user_id: str,
) -> Sequence[graph_db.GraphModel]:
    """Cached helper to get all versions of a graph."""
    return await graph_db.get_graph_all_versions(
        graph_id=graph_id,
        user_id=user_id,
    )


# ===== Execution Caches =====


# Cache graph executions for 10 seconds.
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=10, shared_cache=True)
async def get_cached_graph_executions(
    graph_id: str,
    user_id: str,
    page: int,
    page_size: int,
):
    """Cached helper to get graph executions."""
    return await execution_db.get_graph_executions_paginated(
        graph_id=graph_id,
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache all user executions for 10 seconds.
# No shared_cache - cache_delete is not used for this function
@cached(maxsize=500, ttl_seconds=10)
async def get_cached_graphs_executions(
    user_id: str,
    page: int,
    page_size: int,
):
    """Cached helper to get all of a user's graph executions."""
    return await execution_db.get_graph_executions_paginated(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache individual execution details for 10 seconds.
# No shared_cache - cache_delete is not used for this function
@cached(maxsize=1000, ttl_seconds=10)
async def get_cached_graph_execution(
    graph_exec_id: str,
    user_id: str,
):
    """Cached helper to get graph execution details."""
    return await execution_db.get_graph_execution(
        user_id=user_id,
        execution_id=graph_exec_id,
        include_node_executions=False,
    )


# ===== User Preference Caches =====


# Cache user timezone for 1 hour
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def get_cached_user_timezone(user_id: str):
    """Cached helper to get user timezone."""
    user = await user_db.get_user_by_id(user_id)
    return {"timezone": user.timezone if user else "UTC"}


# Cache user preferences for 30 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_user_preferences(user_id: str):
    """Cached helper to get user notification preferences."""
    return await user_db.get_user_notification_preference(user_id)
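The pattern paired with these helpers throughout the route hunks below, restated compactly with names from this module: read through the cached helper, then delete the matching entry after any write. Since these helpers are async, their cache_delete is the async variant and must be awaited:

# Read path (route handler):
paginated = await get_cached_graphs(user_id=user_id, page=1, page_size=250)

# Write path (after a create/update/delete of a graph), dropping the stale entry:
await get_cached_graphs.cache_delete(user_id=user_id, page=1, page_size=250)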
autogpt_platform/backend/backend/server/routers/cache_test.py (new file, 346 lines)

@@ -0,0 +1,346 @@
"""
Tests for cache invalidation in V1 API routes.

This module tests that caches are properly invalidated when data is modified
through POST, PUT, PATCH, and DELETE operations.
"""

import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

import backend.server.routers.cache as cache
from backend.data import graph as graph_db


@pytest.fixture
def mock_user_id():
    """Generate a mock user ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def mock_graph_id():
    """Generate a mock graph ID for testing."""
    return str(uuid.uuid4())


class TestGraphCacheInvalidation:
    """Test cache invalidation for graph operations."""

    @pytest.mark.asyncio
    async def test_create_graph_clears_list_cache(self, mock_user_id):
        """Test that creating a graph clears the graphs list cache."""
        # Setup
        cache.get_cached_graphs.cache_clear()

        # Pre-populate cache
        with patch.object(
            graph_db, "list_graphs_paginated", new_callable=AsyncMock
        ) as mock_list:
            mock_list.return_value = MagicMock(graphs=[])

            # First call should hit the database
            await cache.get_cached_graphs(mock_user_id, 1, 250)
            assert mock_list.call_count == 1

            # Second call should use cache
            await cache.get_cached_graphs(mock_user_id, 1, 250)
            assert mock_list.call_count == 1  # Still 1, used cache

            # Simulate cache invalidation (what happens in create_new_graph)
            await cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)

            # Next call should hit database again
            await cache.get_cached_graphs(mock_user_id, 1, 250)
            assert mock_list.call_count == 2  # Incremented, cache was cleared

    @pytest.mark.asyncio
    async def test_delete_graph_clears_multiple_caches(
        self, mock_user_id, mock_graph_id
    ):
        """Test that deleting a graph clears all related caches."""
        # Clear all caches first
        cache.get_cached_graphs.cache_clear()
        cache.get_cached_graph.cache_clear()
        cache.get_cached_graph_all_versions.cache_clear()
        cache.get_cached_graph_executions.cache_clear()

        # Setup mocks
        with patch.object(
            graph_db, "list_graphs_paginated", new_callable=AsyncMock
        ) as mock_list, patch.object(
            graph_db, "get_graph", new_callable=AsyncMock
        ) as mock_get, patch.object(
            graph_db, "get_graph_all_versions", new_callable=AsyncMock
        ) as mock_versions:

            mock_list.return_value = MagicMock(graphs=[])
            mock_get.return_value = MagicMock(id=mock_graph_id)
            mock_versions.return_value = []

            # Pre-populate all caches (use consistent argument style)
            await cache.get_cached_graphs(mock_user_id, 1, 250)
            await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
            await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)

            initial_calls = {
                "list": mock_list.call_count,
                "get": mock_get.call_count,
                "versions": mock_versions.call_count,
            }

            # Use cached values (no additional DB calls)
            await cache.get_cached_graphs(mock_user_id, 1, 250)
            await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
            await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)

            # Verify cache was used
            assert mock_list.call_count == initial_calls["list"]
            assert mock_get.call_count == initial_calls["get"]
            assert mock_versions.call_count == initial_calls["versions"]

            # Simulate delete_graph cache invalidation
            # Use positional arguments for cache_delete to match how we called the functions
            result1 = await cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)
            result2 = await cache.get_cached_graph.cache_delete(
                mock_graph_id, None, mock_user_id
            )
            result3 = await cache.get_cached_graph_all_versions.cache_delete(
                mock_graph_id, mock_user_id
            )

            # Verify that the cache entries were actually deleted
            assert result1, "Failed to delete graphs cache entry"
            assert result2, "Failed to delete graph cache entry"
            assert result3, "Failed to delete graph versions cache entry"

            # Next calls should hit database
            await cache.get_cached_graphs(mock_user_id, 1, 250)
            await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
            await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)

            # Verify database was called again
            assert mock_list.call_count == initial_calls["list"] + 1
            assert mock_get.call_count == initial_calls["get"] + 1
            assert mock_versions.call_count == initial_calls["versions"] + 1

    @pytest.mark.asyncio
    async def test_update_graph_clears_caches(self, mock_user_id, mock_graph_id):
        """Test that updating a graph clears the appropriate caches."""
        # Clear caches
        cache.get_cached_graph.cache_clear()
        cache.get_cached_graph_all_versions.cache_clear()
        cache.get_cached_graphs.cache_clear()

        with patch.object(
            graph_db, "get_graph", new_callable=AsyncMock
        ) as mock_get, patch.object(
            graph_db, "get_graph_all_versions", new_callable=AsyncMock
        ) as mock_versions, patch.object(
            graph_db, "list_graphs_paginated", new_callable=AsyncMock
        ) as mock_list:

            mock_get.return_value = MagicMock(id=mock_graph_id, version=1)
            mock_versions.return_value = [MagicMock(version=1)]
            mock_list.return_value = MagicMock(graphs=[])

            # Populate caches
            await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
            await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
            await cache.get_cached_graphs(mock_user_id, 1, 250)

            initial_calls = {
                "get": mock_get.call_count,
                "versions": mock_versions.call_count,
                "list": mock_list.call_count,
            }

            # Verify cache is being used
            await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
            await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
            await cache.get_cached_graphs(mock_user_id, 1, 250)

            assert mock_get.call_count == initial_calls["get"]
            assert mock_versions.call_count == initial_calls["versions"]
            assert mock_list.call_count == initial_calls["list"]

            # Simulate update_graph cache invalidation
            await cache.get_cached_graph.cache_delete(mock_graph_id, None, mock_user_id)
            await cache.get_cached_graph_all_versions.cache_delete(
                mock_graph_id, mock_user_id
            )
            await cache.get_cached_graphs.cache_delete(mock_user_id, 1, 250)

            # Next calls should hit database
            await cache.get_cached_graph(mock_graph_id, None, mock_user_id)
            await cache.get_cached_graph_all_versions(mock_graph_id, mock_user_id)
            await cache.get_cached_graphs(mock_user_id, 1, 250)

            assert mock_get.call_count == initial_calls["get"] + 1
            assert mock_versions.call_count == initial_calls["versions"] + 1
            assert mock_list.call_count == initial_calls["list"] + 1


class TestUserPreferencesCacheInvalidation:
    """Test cache invalidation for user preferences operations."""

    @pytest.mark.asyncio
    async def test_update_preferences_clears_cache(self, mock_user_id):
        """Test that updating preferences clears the preferences cache."""
        # Clear cache
        cache.get_cached_user_preferences.cache_clear()

        with patch.object(
            cache.user_db, "get_user_notification_preference", new_callable=AsyncMock
        ) as mock_get_prefs:
            mock_prefs = MagicMock(email_notifications=True, push_notifications=False)
            mock_get_prefs.return_value = mock_prefs

            # First call hits database
            result1 = await cache.get_cached_user_preferences(mock_user_id)
            assert mock_get_prefs.call_count == 1
            assert result1 == mock_prefs

            # Second call uses cache
            result2 = await cache.get_cached_user_preferences(mock_user_id)
            assert mock_get_prefs.call_count == 1  # Still 1
            assert result2 == mock_prefs

            # Simulate update_preferences cache invalidation
            await cache.get_cached_user_preferences.cache_delete(mock_user_id)

            # Change the mock return value to simulate updated preferences
            mock_prefs_updated = MagicMock(
                email_notifications=False, push_notifications=True
            )
            mock_get_prefs.return_value = mock_prefs_updated

            # Next call should hit database and get new value
            result3 = await cache.get_cached_user_preferences(mock_user_id)
            assert mock_get_prefs.call_count == 2
            assert result3 == mock_prefs_updated

    @pytest.mark.asyncio
    async def test_timezone_cache_operations(self, mock_user_id):
        """Test timezone cache and its operations."""
        # Clear cache
        cache.get_cached_user_timezone.cache_clear()

        with patch.object(
            cache.user_db, "get_user_by_id", new_callable=AsyncMock
        ) as mock_get_user:
            mock_user = MagicMock(timezone="America/New_York")
            mock_get_user.return_value = mock_user

            # First call hits database
            result1 = await cache.get_cached_user_timezone(mock_user_id)
            assert mock_get_user.call_count == 1
            assert result1["timezone"] == "America/New_York"

            # Second call uses cache
            result2 = await cache.get_cached_user_timezone(mock_user_id)
            assert mock_get_user.call_count == 1  # Still 1
            assert result2["timezone"] == "America/New_York"

            # Clear cache manually (simulating what would happen after update)
            await cache.get_cached_user_timezone.cache_delete(mock_user_id)

            # Change timezone
            mock_user.timezone = "Europe/London"

            # Next call should hit database
            result3 = await cache.get_cached_user_timezone(mock_user_id)
            assert mock_get_user.call_count == 2
            assert result3["timezone"] == "Europe/London"


class TestExecutionCacheInvalidation:
    """Test cache invalidation for execution operations."""

    @pytest.mark.asyncio
    async def test_execution_cache_cleared_on_graph_delete(
        self, mock_user_id, mock_graph_id
    ):
        """Test that execution caches are cleared when a graph is deleted."""
        # Clear cache
        cache.get_cached_graph_executions.cache_clear()

        with patch.object(
            cache.execution_db, "get_graph_executions_paginated", new_callable=AsyncMock
        ) as mock_exec:
            mock_exec.return_value = MagicMock(executions=[])

            # Populate cache for multiple pages
            for page in range(1, 4):
                await cache.get_cached_graph_executions(
                    mock_graph_id, mock_user_id, page, 25
                )

            initial_calls = mock_exec.call_count

            # Verify cache is used
            for page in range(1, 4):
                await cache.get_cached_graph_executions(
                    mock_graph_id, mock_user_id, page, 25
                )

            assert mock_exec.call_count == initial_calls  # No new calls

            # Simulate graph deletion clearing execution caches
            for page in range(1, 10):  # Clear more pages as done in delete_graph
                await cache.get_cached_graph_executions.cache_delete(
                    mock_graph_id, mock_user_id, page, 25
                )

            # Next calls should hit database
            for page in range(1, 4):
                await cache.get_cached_graph_executions(
                    mock_graph_id, mock_user_id, page, 25
                )

            assert mock_exec.call_count == initial_calls + 3  # 3 new calls


class TestCacheInfo:
    """Test cache information and metrics."""

    def test_cache_info_returns_correct_metrics(self):
        """Test that cache_info returns correct metrics."""
        # Clear all caches
        cache.get_cached_graphs.cache_clear()
        cache.get_cached_graph.cache_clear()

        # Get initial info
        info_graphs = cache.get_cached_graphs.cache_info()
        info_graph = cache.get_cached_graph.cache_info()

        assert info_graphs["size"] == 0
        assert info_graph["size"] == 0

        # Note: We can't directly test cache population without a real async context,
        # but we can verify the cache_info structure
        assert "size" in info_graphs
        assert "maxsize" in info_graphs
        assert "ttl_seconds" in info_graphs

    def test_cache_clear_removes_all_entries(self):
        """Test that cache_clear removes all entries."""
        # This test verifies the cache_clear method exists and can be called
        cache.get_cached_graphs.cache_clear()
        cache.get_cached_graph.cache_clear()
        cache.get_cached_graph_all_versions.cache_clear()
        cache.get_cached_graph_executions.cache_clear()
        cache.get_cached_graphs_executions.cache_clear()
        cache.get_cached_user_preferences.cache_clear()
        cache.get_cached_user_timezone.cache_clear()

        # After clear, all caches should be empty
        assert cache.get_cached_graphs.cache_info()["size"] == 0
        assert cache.get_cached_graph.cache_info()["size"] == 0
        assert cache.get_cached_graph_all_versions.cache_info()["size"] == 0
        assert cache.get_cached_graph_executions.cache_info()["size"] == 0
        assert cache.get_cached_graphs_executions.cache_info()["size"] == 0
        assert cache.get_cached_user_preferences.cache_info()["size"] == 0
        assert cache.get_cached_user_timezone.cache_info()["size"] == 0
autogpt_platform/backend/backend/server/routers/v1.py:

@@ -11,7 +11,6 @@ import pydantic
import stripe
from autogpt_libs.auth import get_user_id, requires_user
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from autogpt_libs.utils.cache import cached
from fastapi import (
    APIRouter,
    Body,

@@ -29,11 +28,12 @@ from typing_extensions import Optional, TypedDict

import backend.server.integrations.router
import backend.server.routers.analytics
import backend.server.routers.cache as cache
import backend.server.v2.library.db as library_db
from backend.data import api_key as api_key_db
from backend.data import execution as execution_db
from backend.data import graph as graph_db
from backend.data.block import BlockInput, CompletedBlockOutput, get_block, get_blocks
from backend.data.block import BlockInput, CompletedBlockOutput, get_block
from backend.data.credit import (
    AutoTopUpConfig,
    RefundRequest,

@@ -55,7 +55,6 @@ from backend.data.onboarding import (
from backend.data.user import (
    get_or_create_user,
    get_user_by_id,
    get_user_notification_preference,
    update_user_email,
    update_user_notification_preference,
    update_user_timezone,

@@ -82,6 +81,7 @@ from backend.server.model import (
    UpdateTimezoneRequest,
    UploadFileResponse,
)
from backend.server.v2.library import cache as library_cache
from backend.util.clients import get_scheduler_client
from backend.util.cloud_storage import get_cloud_storage_handler
from backend.util.exceptions import GraphValidationError, NotFoundError

@@ -165,7 +165,9 @@ async def get_user_timezone_route(
) -> TimezoneResponse:
    """Get user timezone setting."""
    user = await get_or_create_user(user_data)
    return TimezoneResponse(timezone=user.timezone)
    # Use cached timezone for subsequent calls
    result = await cache.get_cached_user_timezone(user.id)
    return TimezoneResponse(timezone=result["timezone"])


@v1_router.post(

@@ -179,6 +181,7 @@ async def update_user_timezone_route(
) -> TimezoneResponse:
    """Update user timezone. The timezone should be a valid IANA timezone identifier."""
    user = await update_user_timezone(user_id, str(request.timezone))
    await cache.get_cached_user_timezone.cache_delete(user_id)
    return TimezoneResponse(timezone=user.timezone)


@@ -191,7 +194,7 @@ async def update_user_timezone_route(
async def get_preferences(
    user_id: Annotated[str, Security(get_user_id)],
) -> NotificationPreference:
    preferences = await get_user_notification_preference(user_id)
    preferences = await cache.get_cached_user_preferences(user_id)
    return preferences


@@ -206,6 +209,10 @@ async def update_preferences(
    preferences: NotificationPreferenceDTO = Body(...),
) -> NotificationPreference:
    output = await update_user_notification_preference(user_id, preferences)

    # Clear preferences cache after update
    await cache.get_cached_user_preferences.cache_delete(user_id)

    return output


@@ -263,29 +270,6 @@ async def is_onboarding_enabled():
########################################################


@cached()
def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
    """
    Get cached blocks with thundering herd protection.

    Uses the sync_cache decorator to prevent multiple concurrent requests
    from all executing the expensive block loading operation.
    """
    from backend.data.credit import get_block_cost

    block_classes = get_blocks()
    result = []

    for block_class in block_classes.values():
        block_instance = block_class()
        if not block_instance.disabled:
            # Get costs for this specific block class without creating another instance
            costs = get_block_cost(block_instance)
            result.append({**block_instance.to_dict(), "costs": costs})

    return result


@v1_router.get(
    path="/blocks",
    summary="List available blocks",

@@ -293,7 +277,7 @@ def _get_cached_blocks() -> Sequence[dict[Any, Any]]:
    dependencies=[Security(requires_user)],
)
async def get_graph_blocks() -> Sequence[dict[Any, Any]]:
    return _get_cached_blocks()
    return cache.get_cached_blocks()


@v1_router.post(

@@ -633,11 +617,10 @@ class DeleteGraphResponse(TypedDict):
async def list_graphs(
    user_id: Annotated[str, Security(get_user_id)],
) -> Sequence[graph_db.GraphMeta]:
    paginated_result = await graph_db.list_graphs_paginated(
    paginated_result = await cache.get_cached_graphs(
        user_id=user_id,
        page=1,
        page_size=250,
        filter_by="active",
    )
    return paginated_result.graphs

@@ -660,13 +643,26 @@ async def get_graph(
    version: int | None = None,
    for_export: bool = False,
) -> graph_db.GraphModel:
    graph = await graph_db.get_graph(
        graph_id,
        version,
        user_id=user_id,
        for_export=for_export,
        include_subgraphs=True,  # needed to construct full credentials input schema
    )
    # Use cache for non-export requests
    if not for_export:
        graph = await cache.get_cached_graph(
            graph_id=graph_id,
            version=version,
            user_id=user_id,
        )
        # If graph not found, clear cache entry as permissions may have changed
        if not graph:
            await cache.get_cached_graph.cache_delete(
                graph_id=graph_id, version=version, user_id=user_id
            )
    else:
        graph = await graph_db.get_graph(
            graph_id,
            version,
            user_id=user_id,
            for_export=for_export,
            include_subgraphs=True,  # needed to construct full credentials input schema
        )
    if not graph:
        raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
    return graph

@@ -681,7 +677,7 @@ async def get_graph(
async def get_graph_all_versions(
    graph_id: str, user_id: Annotated[str, Security(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
    graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
    graphs = await cache.get_cached_graph_all_versions(graph_id, user_id=user_id)
    if not graphs:
        raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
    return graphs

@@ -705,6 +701,14 @@ async def create_new_graph(
    # as the graph is already valid and no sub-graphs are returned back.
    await graph_db.create_graph(graph, user_id=user_id)
    await library_db.create_library_agent(graph, user_id=user_id)

    # Clear graphs list cache after creating new graph
    await cache.get_cached_graphs.cache_delete(user_id=user_id, page=1, page_size=250)
    for page in range(1, 20):
        await library_cache.get_cached_library_agents.cache_delete(
            user_id=user_id, page=page, page_size=8
        )

    return await on_graph_activate(graph, user_id=user_id)


@@ -720,7 +724,18 @@ async def delete_graph(
    if active_version := await graph_db.get_graph(graph_id, user_id=user_id):
        await on_graph_deactivate(active_version, user_id=user_id)

    return {"version_counts": await graph_db.delete_graph(graph_id, user_id=user_id)}
    result = DeleteGraphResponse(
        version_counts=await graph_db.delete_graph(graph_id, user_id=user_id)
    )

    # Clear caches after deleting graph
    await cache.get_cached_graphs.cache_delete(user_id=user_id, page=1, page_size=250)
    await cache.get_cached_graph.cache_delete(
        graph_id=graph_id, version=None, user_id=user_id
    )
    await cache.get_cached_graph_all_versions.cache_delete(graph_id, user_id=user_id)

    return result


@v1_router.put(

@@ -776,6 +791,14 @@ async def update_graph(
        include_subgraphs=True,
    )
    assert new_graph_version_with_subgraphs  # make type checker happy

    # Clear caches after updating graph
    await cache.get_cached_graph.cache_delete(
        graph_id=graph_id, version=None, user_id=user_id
    )
    await cache.get_cached_graph_all_versions.cache_delete(graph_id, user_id=user_id)
    await cache.get_cached_graphs.cache_delete(user_id=user_id, page=1, page_size=250)

    return new_graph_version_with_subgraphs


@@ -853,6 +876,12 @@ async def execute_graph(
        # Record successful graph execution
        record_graph_execution(graph_id=graph_id, status="success", user_id=user_id)
        record_graph_operation(operation="execute", status="success")

        for page in range(1, 10):
            await cache.get_cached_graph_executions.cache_delete(
                graph_id=graph_id, user_id=user_id, page=page, page_size=20
            )

        return result
    except GraphValidationError as e:
        # Record failed graph execution

@@ -928,7 +957,7 @@ async def _stop_graph_run(
async def list_graphs_executions(
    user_id: Annotated[str, Security(get_user_id)],
) -> list[execution_db.GraphExecutionMeta]:
    paginated_result = await execution_db.get_graph_executions_paginated(
    paginated_result = await cache.get_cached_graphs_executions(
        user_id=user_id,
        page=1,
        page_size=250,

@@ -950,7 +979,7 @@ async def list_graph_executions(
        25, ge=1, le=100, description="Number of executions per page"
    ),
) -> execution_db.GraphExecutionsPaginated:
    return await execution_db.get_graph_executions_paginated(
    return await cache.get_cached_graph_executions(
        graph_id=graph_id,
        user_id=user_id,
        page=page,

@@ -102,13 +102,13 @@ def test_get_graph_blocks(
    mock_block.id = "test-block"
    mock_block.disabled = False

    # Mock get_blocks
    # Mock get_blocks in the cache module where it's actually used
    mocker.patch(
        "backend.server.routers.v1.get_blocks",
        "backend.server.routers.cache.get_blocks",
        return_value={"test-block": lambda: mock_block},
    )

    # Mock block costs
    # Mock block costs where it's imported inside the function
    mocker.patch(
        "backend.data.credit.get_block_cost",
        return_value=[{"cost": 10, "type": "credit"}],
autogpt_platform/backend/backend/server/v2/library/cache.py (new file, 117 lines)

@@ -0,0 +1,117 @@
"""
Cache functions for Library API endpoints.

This module contains all caching decorators and helpers for the Library API,
separated from the main routes for better organization and maintainability.
"""

from autogpt_libs.utils.cache import cached

import backend.server.v2.library.db


# ===== Library Agent Caches =====


# Cache library agents list for 10 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=600, shared_cache=True)
async def get_cached_library_agents(
    user_id: str,
    page: int = 1,
    page_size: int = 20,
):
    """Cached helper to get library agents list."""
    return await backend.server.v2.library.db.list_library_agents(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache user's favorite agents for 5 minutes - favorites change more frequently
# No shared_cache - cache_delete is not used for this function
@cached(maxsize=500, ttl_seconds=300, shared_cache=False)
async def get_cached_library_agent_favorites(
    user_id: str,
    page: int = 1,
    page_size: int = 20,
):
    """Cached helper to get user's favorite library agents."""
    return await backend.server.v2.library.db.list_favorite_library_agents(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache individual library agent details for 30 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent(
    library_agent_id: str,
    user_id: str,
):
    """Cached helper to get library agent details."""
    return await backend.server.v2.library.db.get_library_agent(
        id=library_agent_id,
        user_id=user_id,
    )


# Cache library agent by graph ID for 30 minutes
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_agent_by_graph_id(
    graph_id: str,
    user_id: str,
):
    """Cached helper to get library agent by graph ID."""
    return await backend.server.v2.library.db.get_library_agent_by_graph_id(
        graph_id=graph_id,
        user_id=user_id,
    )


# Cache library agent by store version ID for 1 hour - marketplace agents are more stable
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def get_cached_library_agent_by_store_version(
    store_listing_version_id: str,
    user_id: str,
):
    """Cached helper to get library agent by store version ID."""
    return await backend.server.v2.library.db.get_library_agent_by_store_version_id(
        store_listing_version_id=store_listing_version_id,
        user_id=user_id,
    )


# ===== Library Preset Caches =====


# Cache library presets list for 30 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=500, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_presets(
    user_id: str,
    page: int = 1,
    page_size: int = 20,
):
    """Cached helper to get library presets list."""
    return await backend.server.v2.library.db.list_presets(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


# Cache individual preset details for 30 minutes
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=1800, shared_cache=True)
async def get_cached_library_preset(
    preset_id: str,
    user_id: str,
):
    """Cached helper to get library preset details."""
    return await backend.server.v2.library.db.get_preset(
        preset_id=preset_id,
        user_id=user_id,
    )
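For orientation, this is the pattern the route changes further below follow: read through the cache, and after any write, delete the entries whose arguments match. A minimal sketch; rename_library_agent and update_agent_name are hypothetical, and the page sweep mirrors the tests below:

import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as library_db


async def rename_library_agent(library_agent_id: str, user_id: str, name: str):
    # Read path: the first call populates the cache; later calls are served
    # from it (per-process, or via Redis when shared_cache=True) until the
    # TTL expires.
    agent = await library_cache.get_cached_library_agent(library_agent_id, user_id)

    await library_db.update_agent_name(library_agent_id, user_id, name)  # hypothetical write

    # Write path: drop the now-stale entries so the next read hits the database.
    library_cache.get_cached_library_agent.cache_delete(library_agent_id, user_id)
    for page in range(1, 5):
        library_cache.get_cached_library_agents.cache_delete(user_id, page, 20)
    return agent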
272
autogpt_platform/backend/backend/server/v2/library/cache_test.py
Normal file
@@ -0,0 +1,272 @@
"""
Tests for cache invalidation in Library API routes.

This module tests that library caches are properly invalidated when data is modified.
"""

import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as library_db


@pytest.fixture
def mock_user_id():
    """Generate a mock user ID for testing."""
    return str(uuid.uuid4())


@pytest.fixture
def mock_library_agent_id():
    """Generate a mock library agent ID for testing."""
    return str(uuid.uuid4())


class TestLibraryAgentCacheInvalidation:
    """Test cache invalidation for library agent operations."""

    @pytest.mark.asyncio
    async def test_add_agent_clears_list_cache(self, mock_user_id):
        """Test that adding an agent clears the library agents list cache."""
        # Clear cache
        library_cache.get_cached_library_agents.cache_clear()

        with patch.object(
            library_db, "list_library_agents", new_callable=AsyncMock
        ) as mock_list:
            mock_response = MagicMock(agents=[], total_count=0, page=1, page_size=20)
            mock_list.return_value = mock_response

            # First call hits database
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
            assert mock_list.call_count == 1

            # Second call uses cache
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
            assert mock_list.call_count == 1  # Still 1, cache used

            # Simulate adding an agent (cache invalidation)
            for page in range(1, 5):
                library_cache.get_cached_library_agents.cache_delete(
                    mock_user_id, page, 15
                )
                library_cache.get_cached_library_agents.cache_delete(
                    mock_user_id, page, 20
                )

            # Next call should hit database
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)
            assert mock_list.call_count == 2

    @pytest.mark.asyncio
    async def test_delete_agent_clears_multiple_caches(
        self, mock_user_id, mock_library_agent_id
    ):
        """Test that deleting an agent clears both specific and list caches."""
        # Clear caches
        library_cache.get_cached_library_agent.cache_clear()
        library_cache.get_cached_library_agents.cache_clear()

        with patch.object(
            library_db, "get_library_agent", new_callable=AsyncMock
        ) as mock_get, patch.object(
            library_db, "list_library_agents", new_callable=AsyncMock
        ) as mock_list:
            mock_agent = MagicMock(id=mock_library_agent_id, name="Test Agent")
            mock_get.return_value = mock_agent
            mock_list.return_value = MagicMock(agents=[mock_agent])

            # Populate caches
            await library_cache.get_cached_library_agent(
                mock_library_agent_id, mock_user_id
            )
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)

            initial_calls = {
                "get": mock_get.call_count,
                "list": mock_list.call_count,
            }

            # Verify cache is used
            await library_cache.get_cached_library_agent(
                mock_library_agent_id, mock_user_id
            )
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)

            assert mock_get.call_count == initial_calls["get"]
            assert mock_list.call_count == initial_calls["list"]

            # Simulate delete_library_agent cache invalidation
            library_cache.get_cached_library_agent.cache_delete(
                mock_library_agent_id, mock_user_id
            )
            for page in range(1, 5):
                library_cache.get_cached_library_agents.cache_delete(
                    mock_user_id, page, 15
                )
                library_cache.get_cached_library_agents.cache_delete(
                    mock_user_id, page, 20
                )

            # Next calls should hit database
            await library_cache.get_cached_library_agent(
                mock_library_agent_id, mock_user_id
            )
            await library_cache.get_cached_library_agents(mock_user_id, 1, 20)

            assert mock_get.call_count == initial_calls["get"] + 1
            assert mock_list.call_count == initial_calls["list"] + 1

    @pytest.mark.asyncio
    async def test_favorites_cache_operations(self, mock_user_id):
        """Test that favorites cache works independently."""
        # Clear cache
        library_cache.get_cached_library_agent_favorites.cache_clear()

        with patch.object(
            library_db, "list_favorite_library_agents", new_callable=AsyncMock
        ) as mock_favs:
            mock_response = MagicMock(agents=[], total_count=0, page=1, page_size=20)
            mock_favs.return_value = mock_response

            # First call hits database
            await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
            assert mock_favs.call_count == 1

            # Second call uses cache
            await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
            assert mock_favs.call_count == 1  # Cache used

            # Clear cache
            library_cache.get_cached_library_agent_favorites.cache_delete(
                mock_user_id, 1, 20
            )

            # Next call hits database
            await library_cache.get_cached_library_agent_favorites(mock_user_id, 1, 20)
            assert mock_favs.call_count == 2


class TestLibraryPresetCacheInvalidation:
    """Test cache invalidation for library preset operations."""

    @pytest.mark.asyncio
    async def test_preset_cache_operations(self, mock_user_id):
        """Test preset cache and invalidation."""
        # Clear cache
        library_cache.get_cached_library_presets.cache_clear()
        library_cache.get_cached_library_preset.cache_clear()

        preset_id = str(uuid.uuid4())

        with patch.object(
            library_db, "list_presets", new_callable=AsyncMock
        ) as mock_list, patch.object(
            library_db, "get_preset", new_callable=AsyncMock
        ) as mock_get:
            mock_preset = MagicMock(id=preset_id, name="Test Preset")
            mock_list.return_value = MagicMock(presets=[mock_preset])
            mock_get.return_value = mock_preset

            # Populate caches
            await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
            await library_cache.get_cached_library_preset(preset_id, mock_user_id)

            initial_calls = {
                "list": mock_list.call_count,
                "get": mock_get.call_count,
            }

            # Verify cache is used
            await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
            await library_cache.get_cached_library_preset(preset_id, mock_user_id)

            assert mock_list.call_count == initial_calls["list"]
            assert mock_get.call_count == initial_calls["get"]

            # Clear specific preset cache
            library_cache.get_cached_library_preset.cache_delete(
                preset_id, mock_user_id
            )

            # Clear list cache
            library_cache.get_cached_library_presets.cache_delete(mock_user_id, 1, 20)

            # Next calls should hit database
            await library_cache.get_cached_library_presets(mock_user_id, 1, 20)
            await library_cache.get_cached_library_preset(preset_id, mock_user_id)

            assert mock_list.call_count == initial_calls["list"] + 1
            assert mock_get.call_count == initial_calls["get"] + 1


class TestLibraryCacheMetrics:
    """Test library cache metrics and management."""

    def test_cache_info_structure(self):
        """Test that cache_info returns expected structure."""
        info = library_cache.get_cached_library_agents.cache_info()

        assert "size" in info
        assert "maxsize" in info
        assert "ttl_seconds" in info
        assert info["maxsize"] == 1000  # As defined in cache.py
        assert info["ttl_seconds"] == 600  # 10 minutes

    def test_all_library_caches_can_be_cleared(self):
        """Test that all library caches can be cleared."""
        # Clear all library caches
        library_cache.get_cached_library_agents.cache_clear()
        library_cache.get_cached_library_agent_favorites.cache_clear()
        library_cache.get_cached_library_agent.cache_clear()
        library_cache.get_cached_library_agent_by_graph_id.cache_clear()
        library_cache.get_cached_library_agent_by_store_version.cache_clear()
        library_cache.get_cached_library_presets.cache_clear()
        library_cache.get_cached_library_preset.cache_clear()

        # Verify all are empty
        assert library_cache.get_cached_library_agents.cache_info()["size"] == 0
        assert (
            library_cache.get_cached_library_agent_favorites.cache_info()["size"] == 0
        )
        assert library_cache.get_cached_library_agent.cache_info()["size"] == 0
        assert (
            library_cache.get_cached_library_agent_by_graph_id.cache_info()["size"] == 0
        )
        assert (
            library_cache.get_cached_library_agent_by_store_version.cache_info()["size"]
            == 0
        )
        assert library_cache.get_cached_library_presets.cache_info()["size"] == 0
        assert library_cache.get_cached_library_preset.cache_info()["size"] == 0

    def test_cache_ttl_values(self):
        """Test that cache TTL values are set correctly."""
        # Library agents - 10 minutes
        assert (
            library_cache.get_cached_library_agents.cache_info()["ttl_seconds"] == 600
        )

        # Favorites - 5 minutes (more dynamic)
        assert (
            library_cache.get_cached_library_agent_favorites.cache_info()["ttl_seconds"]
            == 300
        )

        # Individual agent - 30 minutes
        assert (
            library_cache.get_cached_library_agent.cache_info()["ttl_seconds"] == 1800
        )

        # Presets - 30 minutes
        assert (
            library_cache.get_cached_library_presets.cache_info()["ttl_seconds"] == 1800
        )
        assert (
            library_cache.get_cached_library_preset.cache_info()["ttl_seconds"] == 1800
        )
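These tests drive the cache through its public surface (cache_clear, cache_delete, cache_info) with the DB layer mocked out. A typical invocation, assuming the backend's usual pytest setup with pytest-asyncio:

pytest autogpt_platform/backend/backend/server/v2/library/cache_test.py -v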
@@ -5,6 +5,7 @@ import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, HTTPException, Query, Security, status
from fastapi.responses import Response

import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as library_db
import backend.server.v2.library.model as library_model
import backend.server.v2.store.exceptions as store_exceptions
@@ -64,13 +65,22 @@ async def list_library_agents(
        HTTPException: If a server/database error occurs.
    """
    try:
        return await library_db.list_library_agents(
            user_id=user_id,
            search_term=search_term,
            sort_by=sort_by,
            page=page,
            page_size=page_size,
        )
        # Use cache for default queries (no search term, default sort)
        if search_term is None and sort_by == library_model.LibraryAgentSort.UPDATED_AT:
            return await library_cache.get_cached_library_agents(
                user_id=user_id,
                page=page,
                page_size=page_size,
            )
        else:
            # Direct DB query for searches and custom sorts
            return await library_db.list_library_agents(
                user_id=user_id,
                search_term=search_term,
                sort_by=sort_by,
                page=page,
                page_size=page_size,
            )
    except Exception as e:
        logger.error(f"Could not list library agents for user #{user_id}: {e}")
        raise HTTPException(
@@ -114,7 +124,7 @@ async def list_favorite_library_agents(
        HTTPException: If a server/database error occurs.
    """
    try:
        return await library_db.list_favorite_library_agents(
        return await library_cache.get_cached_library_agent_favorites(
            user_id=user_id,
            page=page,
            page_size=page_size,
@@ -132,7 +142,9 @@ async def get_library_agent(
    library_agent_id: str,
    user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryAgent:
    return await library_db.get_library_agent(id=library_agent_id, user_id=user_id)
    return await library_cache.get_cached_library_agent(
        library_agent_id=library_agent_id, user_id=user_id
    )


@router.get("/by-graph/{graph_id}")
@@ -210,11 +222,19 @@ async def add_marketplace_agent_to_library(
        HTTPException(500): If a server/database error occurs.
    """
    try:
        return await library_db.add_store_agent_to_library(
        result = await library_db.add_store_agent_to_library(
            store_listing_version_id=store_listing_version_id,
            user_id=user_id,
        )

        # Clear library caches after adding new agent
        for page in range(1, 20):
            library_cache.get_cached_library_agents.cache_delete(
                user_id=user_id, page=page, page_size=8
            )

        return result

    except store_exceptions.AgentNotFoundError as e:
        logger.warning(
            f"Could not find store listing version {store_listing_version_id} "
@@ -320,6 +340,16 @@ async def delete_library_agent(
        await library_db.delete_library_agent(
            library_agent_id=library_agent_id, user_id=user_id
        )

        # Clear caches after deleting agent
        library_cache.get_cached_library_agent.cache_delete(
            library_agent_id=library_agent_id, user_id=user_id
        )
        for page in range(1, 20):
            library_cache.get_cached_library_agents.cache_delete(
                user_id=user_id, page=page, page_size=8
            )

        return Response(status_code=status.HTTP_204_NO_CONTENT)
    except NotFoundError as e:
        raise HTTPException(
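One detail worth calling out in the hunks above: the invalidation sweeps call cache_delete with page_size=8, while the cached helper defaults to page_size=20 and the tests sweep sizes 15 and 20. Assuming the @cached key is derived from the call's arguments, a delete only removes the entry whose arguments match exactly, so entries cached under other page sizes survive until their TTL expires. A small sketch of the pitfall (hypothetical user ID):

import asyncio

import backend.server.v2.library.cache as library_cache


async def _demo() -> None:
    user_id = "user-123"  # hypothetical

    # Cached under the key for (user_id, page=1, page_size=20).
    await library_cache.get_cached_library_agents(user_id, 1, 20)

    # Removes only a page_size=8 entry; the page_size=20 entry above
    # stays warm until its 10-minute TTL runs out.
    library_cache.get_cached_library_agents.cache_delete(
        user_id=user_id, page=1, page_size=8
    )


if __name__ == "__main__":
    asyncio.run(_demo())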
@@ -4,6 +4,8 @@ from typing import Any, Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Body, HTTPException, Query, Security, status

import backend.server.routers.cache as cache
import backend.server.v2.library.cache as library_cache
import backend.server.v2.library.db as db
import backend.server.v2.library.model as models
from backend.data.execution import GraphExecutionMeta
@@ -51,12 +53,21 @@ async def list_presets(
        models.LibraryAgentPresetResponse: A response containing the list of presets.
    """
    try:
        return await db.list_presets(
            user_id=user_id,
            graph_id=graph_id,
            page=page,
            page_size=page_size,
        )
        # Use cache only for default queries (no filter)
        if graph_id is None:
            return await library_cache.get_cached_library_presets(
                user_id=user_id,
                page=page,
                page_size=page_size,
            )
        else:
            # Direct DB query for filtered requests
            return await db.list_presets(
                user_id=user_id,
                graph_id=graph_id,
                page=page,
                page_size=page_size,
            )
    except Exception as e:
        logger.exception("Failed to list presets for user %s: %s", user_id, e)
        raise HTTPException(
@@ -87,7 +98,7 @@ async def get_preset(
        HTTPException: If the preset is not found or an error occurs.
    """
    try:
        preset = await db.get_preset(user_id, preset_id)
        preset = await library_cache.get_cached_library_preset(preset_id, user_id)
    except Exception as e:
        logger.exception(
            "Error retrieving preset %s for user %s: %s", preset_id, user_id, e
@@ -131,9 +142,20 @@ async def create_preset(
    """
    try:
        if isinstance(preset, models.LibraryAgentPresetCreatable):
            return await db.create_preset(user_id, preset)
            result = await db.create_preset(user_id, preset)
        else:
            return await db.create_preset_from_graph_execution(user_id, preset)
            result = await db.create_preset_from_graph_execution(user_id, preset)

        # Clear presets list cache after creating new preset
        for page in range(1, 5):
            library_cache.get_cached_library_presets.cache_delete(
                user_id=user_id, page=page, page_size=10
            )
            library_cache.get_cached_library_presets.cache_delete(
                user_id=user_id, page=page, page_size=20
            )

        return result
    except NotFoundError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
    except Exception as e:
@@ -200,6 +222,16 @@ async def setup_trigger(
            is_active=True,
        ),
    )

    # Clear presets list cache after creating new preset
    for page in range(1, 5):
        library_cache.get_cached_library_presets.cache_delete(
            user_id=user_id, page=page, page_size=10
        )
        library_cache.get_cached_library_presets.cache_delete(
            user_id=user_id, page=page, page_size=20
        )

    return new_preset


@@ -278,6 +310,18 @@ async def update_preset(
            description=preset.description,
            is_active=preset.is_active,
        )

        # Clear caches after updating preset
        library_cache.get_cached_library_preset.cache_delete(
            preset_id=preset_id, user_id=user_id
        )
        for page in range(1, 5):
            library_cache.get_cached_library_presets.cache_delete(
                user_id=user_id, page=page, page_size=10
            )
            library_cache.get_cached_library_presets.cache_delete(
                user_id=user_id, page=page, page_size=20
            )
    except Exception as e:
        logger.exception("Preset update failed for user %s: %s", user_id, e)
        raise HTTPException(
@@ -351,6 +395,18 @@ async def delete_preset(

    try:
        await db.delete_preset(user_id, preset_id)

        # Clear caches after deleting preset
        library_cache.get_cached_library_preset.cache_delete(
            preset_id=preset_id, user_id=user_id
        )
        for page in range(1, 5):
            library_cache.get_cached_library_presets.cache_delete(
                user_id=user_id, page=page, page_size=10
            )
            library_cache.get_cached_library_presets.cache_delete(
                user_id=user_id, page=page, page_size=20
            )
    except Exception as e:
        logger.exception(
            "Error deleting preset %s for user %s: %s", preset_id, user_id, e
@@ -401,6 +457,14 @@ async def execute_preset(
    merged_node_input = preset.inputs | inputs
    merged_credential_inputs = preset.credentials | credential_inputs

    for page in range(1, 10):
        cache.get_cached_graph_executions.cache_delete(
            graph_id=preset.graph_id, user_id=user_id, page=page, page_size=20
        )
        cache.get_cached_graph_executions.cache_delete(
            user_id=user_id, page=page, page_size=20
        )

    return await add_graph_execution(
        user_id=user_id,
        graph_id=preset.graph_id,
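The merge at the top of the last hunk uses Python's dict union operator (PEP 584), where keys from the right-hand operand win, so caller-supplied inputs and credentials override the preset's stored values. For example:

preset_inputs = {"query": "weekly report", "limit": 10}
request_inputs = {"query": "daily report"}

merged = preset_inputs | request_inputs
# {'query': 'daily report', 'limit': 10} - the request overrides the preset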
@@ -179,14 +179,15 @@ async def test_get_favorite_library_agents_success(
def test_get_favorite_library_agents_error(
    mocker: pytest_mock.MockFixture, test_user_id: str
):
    mock_db_call = mocker.patch(
        "backend.server.v2.library.db.list_favorite_library_agents"
    # Mock the cache function instead of the DB directly since routes now use cache
    mock_cache_call = mocker.patch(
        "backend.server.v2.library.routes.agents.library_cache.get_cached_library_agent_favorites"
    )
    mock_db_call.side_effect = Exception("Test error")
    mock_cache_call.side_effect = Exception("Test error")

    response = client.get("/agents/favorites")
    assert response.status_code == 500
    mock_db_call.assert_called_once_with(
    mock_cache_call.assert_called_once_with(
        user_id=test_user_id,
        page=1,
        page_size=15,
125
autogpt_platform/backend/backend/server/v2/store/cache.py
Normal file
@@ -0,0 +1,125 @@
"""
Cache functions for Store API endpoints.

This module contains all caching decorators and helpers for the Store API,
separated from the main routes for better organization and maintainability.
"""

from autogpt_libs.utils.cache import cached

import backend.server.v2.store.db


# Cache user profiles for 1 hour per user
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)
async def _get_cached_user_profile(user_id: str):
    """Cached helper to get user profile."""
    return await backend.server.v2.store.db.get_user_profile(user_id)


# Cache store agents list for 15 minutes
# Different cache entries for different query combinations
# No shared_cache - cache_delete is not used for this function
@cached(maxsize=5000, ttl_seconds=900)
async def _get_cached_store_agents(
    featured: bool,
    creator: str | None,
    sorted_by: str | None,
    search_query: str | None,
    category: str | None,
    page: int,
    page_size: int,
):
    """Cached helper to get store agents."""
    return await backend.server.v2.store.db.get_store_agents(
        featured=featured,
        creators=[creator] if creator else None,
        sorted_by=sorted_by,
        search_query=search_query,
        category=category,
        page=page,
        page_size=page_size,
    )


# Cache individual agent details for 15 minutes
# No shared_cache - cache_delete is not used for this function
@cached(maxsize=200, ttl_seconds=900)
async def _get_cached_agent_details(username: str, agent_name: str):
    """Cached helper to get agent details."""
    return await backend.server.v2.store.db.get_store_agent_details(
        username=username, agent_name=agent_name
    )


# Cache agent graphs for 1 hour
# No shared_cache - cache_delete is not used for this function
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_agent_graph(store_listing_version_id: str):
    """Cached helper to get agent graph."""
    return await backend.server.v2.store.db.get_available_graph(
        store_listing_version_id
    )


# Cache agent by version for 1 hour
# No shared_cache - cache_delete is not used for this function
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_store_agent_by_version(store_listing_version_id: str):
    """Cached helper to get store agent by version ID."""
    return await backend.server.v2.store.db.get_store_agent_by_version_id(
        store_listing_version_id
    )


# Cache creators list for 1 hour
# No shared_cache - cache_delete is not used for this function
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_store_creators(
    featured: bool,
    search_query: str | None,
    sorted_by: str | None,
    page: int,
    page_size: int,
):
    """Cached helper to get store creators."""
    return await backend.server.v2.store.db.get_store_creators(
        featured=featured,
        search_query=search_query,
        sorted_by=sorted_by,
        page=page,
        page_size=page_size,
    )


# Cache individual creator details for 1 hour
# No shared_cache - cache_delete is not used for this function
@cached(maxsize=100, ttl_seconds=3600)
async def _get_cached_creator_details(username: str):
    """Cached helper to get creator details."""
    return await backend.server.v2.store.db.get_store_creator_details(
        username=username.lower()
    )


# Cache user's own agents for 5 minutes (shorter TTL as this changes more frequently)
# No shared_cache - cache_delete is not used for this function
@cached(maxsize=500, ttl_seconds=300)
async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
    """Cached helper to get user's agents."""
    return await backend.server.v2.store.db.get_my_agents(
        user_id, page=page, page_size=page_size
    )


# Cache user's submissions for 1 hour
# Uses shared_cache since cache_delete is called on this function
@cached(maxsize=500, ttl_seconds=3600, shared_cache=True)
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
    """Cached helper to get user's submissions."""
    return await backend.server.v2.store.db.get_store_submissions(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )
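Since most of these store caches never see cache_delete, stale entries only age out via TTL. A quick runtime sanity check of the settings, assuming cache_info returns the same size/maxsize/ttl_seconds mapping the library tests assert on:

from backend.server.v2.store.cache import _get_cached_store_agents

info = _get_cached_store_agents.cache_info()
print(info["size"], info["maxsize"], info["ttl_seconds"])  # expected: 0 5000 900 on a cold cache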
@@ -6,7 +6,6 @@ import urllib.parse
import autogpt_libs.auth
import fastapi
import fastapi.responses
from autogpt_libs.utils.cache import cached

import backend.data.graph
import backend.server.v2.store.db
@@ -15,123 +14,22 @@ import backend.server.v2.store.image_gen
import backend.server.v2.store.media
import backend.server.v2.store.model
import backend.util.json
from backend.server.v2.store.cache import (
    _get_cached_agent_details,
    _get_cached_agent_graph,
    _get_cached_creator_details,
    _get_cached_my_agents,
    _get_cached_store_agent_by_version,
    _get_cached_store_agents,
    _get_cached_store_creators,
    _get_cached_submissions,
    _get_cached_user_profile,
)

logger = logging.getLogger(__name__)

router = fastapi.APIRouter()


##############################################
############### Caches #######################
##############################################


# Cache user profiles for 1 hour per user
@cached(maxsize=1000, ttl_seconds=3600)
async def _get_cached_user_profile(user_id: str):
    """Cached helper to get user profile."""
    return await backend.server.v2.store.db.get_user_profile(user_id)


# Cache store agents list for 15 minutes
# Different cache entries for different query combinations
@cached(maxsize=5000, ttl_seconds=900)
async def _get_cached_store_agents(
    featured: bool,
    creator: str | None,
    sorted_by: str | None,
    search_query: str | None,
    category: str | None,
    page: int,
    page_size: int,
):
    """Cached helper to get store agents."""
    return await backend.server.v2.store.db.get_store_agents(
        featured=featured,
        creators=[creator] if creator else None,
        sorted_by=sorted_by,
        search_query=search_query,
        category=category,
        page=page,
        page_size=page_size,
    )


# Cache individual agent details for 15 minutes
@cached(maxsize=200, ttl_seconds=900)
async def _get_cached_agent_details(username: str, agent_name: str):
    """Cached helper to get agent details."""
    return await backend.server.v2.store.db.get_store_agent_details(
        username=username, agent_name=agent_name
    )


# Cache agent graphs for 1 hour
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_agent_graph(store_listing_version_id: str):
    """Cached helper to get agent graph."""
    return await backend.server.v2.store.db.get_available_graph(
        store_listing_version_id
    )


# Cache agent by version for 1 hour
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_store_agent_by_version(store_listing_version_id: str):
    """Cached helper to get store agent by version ID."""
    return await backend.server.v2.store.db.get_store_agent_by_version_id(
        store_listing_version_id
    )


# Cache creators list for 1 hour
@cached(maxsize=200, ttl_seconds=3600)
async def _get_cached_store_creators(
    featured: bool,
    search_query: str | None,
    sorted_by: str | None,
    page: int,
    page_size: int,
):
    """Cached helper to get store creators."""
    return await backend.server.v2.store.db.get_store_creators(
        featured=featured,
        search_query=search_query,
        sorted_by=sorted_by,
        page=page,
        page_size=page_size,
    )


# Cache individual creator details for 1 hour
@cached(maxsize=100, ttl_seconds=3600)
async def _get_cached_creator_details(username: str):
    """Cached helper to get creator details."""
    return await backend.server.v2.store.db.get_store_creator_details(
        username=username.lower()
    )


# Cache user's own agents for 5 mins (shorter TTL as this changes more frequently)
@cached(maxsize=500, ttl_seconds=300)
async def _get_cached_my_agents(user_id: str, page: int, page_size: int):
    """Cached helper to get user's agents."""
    return await backend.server.v2.store.db.get_my_agents(
        user_id, page=page, page_size=page_size
    )


# Cache user's submissions for 1 hour (shorter TTL as this changes frequently)
@cached(maxsize=500, ttl_seconds=3600)
async def _get_cached_submissions(user_id: str, page: int, page_size: int):
    """Cached helper to get user's submissions."""
    return await backend.server.v2.store.db.get_store_submissions(
        user_id=user_id,
        page=page,
        page_size=page_size,
    )


##############################################
############### Profile Endpoints ############
##############################################
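Worth noting: this removal is not a pure cut-and-paste. In the new cache.py earlier in this diff, _get_cached_user_profile and _get_cached_submissions gain shared_cache=True, which the inline definitions deleted here never set:

@cached(maxsize=1000, ttl_seconds=3600)                     # old, inline in routes.py
@cached(maxsize=1000, ttl_seconds=3600, shared_cache=True)  # new, in cache.py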
76
autogpt_platform/backend/backend/util/cache_redis_init.py
Normal file
@@ -0,0 +1,76 @@
"""
Initialize Redis clients for the cache system.

This module bridges the gap between the backend's Redis client and the autogpt_libs cache system.
"""

import logging

from autogpt_libs.utils.cache import set_redis_client_provider

logger = logging.getLogger(__name__)


def initialize_cache_redis():
    """
    Initialize the cache system with the backend's Redis clients.

    This function sets up the cache system to use the backend's
    existing Redis connection instead of creating its own.
    """
    try:
        from backend.data.redis_client import HOST, PASSWORD, PORT

        # Create provider functions that create new binary-mode clients
        def get_sync_redis_for_cache():
            """Get a sync Redis client configured for the cache (binary mode)."""
            try:
                from redis import Redis

                client = Redis(
                    host=HOST,
                    port=PORT,
                    password=PASSWORD,
                    decode_responses=False,  # Binary mode for pickle
                )
                # Test the connection
                client.ping()
                return client
            except Exception as e:
                logger.warning(f"Failed to get Redis client for cache: {e}")
                return None

        async def get_async_redis_for_cache():
            """Get an async Redis client configured for the cache (binary mode)."""
            try:
                from redis.asyncio import Redis as AsyncRedis

                client = AsyncRedis(
                    host=HOST,
                    port=PORT,
                    password=PASSWORD,
                    decode_responses=False,  # Binary mode for pickle
                )
                # Test the connection
                await client.ping()
                return client
            except Exception as e:
                logger.warning(f"Failed to get async Redis client for cache: {e}")
                return None

        # Set the providers in the cache system
        set_redis_client_provider(
            sync_provider=get_sync_redis_for_cache,
            async_provider=get_async_redis_for_cache,
        )

        logger.info("Cache system initialized with backend Redis clients")

    except ImportError as e:
        logger.warning(f"Could not import Redis clients, cache will use fallback: {e}")
    except Exception as e:
        logger.error(f"Failed to initialize cache Redis: {e}")


# Auto-initialize when the module is imported
initialize_cache_redis()
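Because the module calls initialize_cache_redis() at import time, wiring it up should only require importing it early in application startup. A minimal sketch (the exact import site is an assumption):

# app startup (hypothetical location) - the import alone wires the backend's
# Redis connection settings into the autogpt_libs cache system.
import backend.util.cache_redis_init  # noqa: F401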