mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Merge commit from fork
* fix(backend): add HMAC signing to Redis cache to prevent pickle deserialization attacks Add HMAC-SHA256 integrity verification to all values stored in the shared Redis cache. This prevents cache poisoning attacks where an attacker with Redis access injects malicious pickled payloads that execute arbitrary code on deserialization. Changes: - Sign pickled values with HMAC-SHA256 before storing in Redis - Verify HMAC signature before deserializing cached values - Reject tampered or unsigned (legacy) cache entries gracefully (treated as cache misses, logged as warnings) - Derive HMAC key from redis_password or unsubscribe_secret_key - Add tests for HMAC round-trip, tamper detection, and legacy rejection Fixes GHSA-rfg2-37xq-w4m9 * improve log message --------- Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
This commit is contained in:
@@ -10,6 +10,8 @@ Provides decorators for caching function results with support for:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import inspect
|
||||
import logging
|
||||
import pickle
|
||||
@@ -32,6 +34,9 @@ T = TypeVar("T")
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
|
||||
# Length of the HMAC-SHA256 signature prefix on cached values.
|
||||
_HMAC_SIG_LEN = 32
|
||||
|
||||
# RECOMMENDED REDIS CONFIGURATION FOR PRODUCTION:
|
||||
# Configure Redis with the following settings for optimal caching performance:
|
||||
# maxmemory-policy allkeys-lru # Evict least recently used keys when memory limit reached
|
||||
@@ -176,35 +181,45 @@ def cached(
|
||||
"""
|
||||
|
||||
def decorator(target_func: Callable[P, R]) -> CachedFunction[P, R]:
|
||||
func_name = target_func.__name__
|
||||
cache_storage: dict[tuple, CachedValue] = {}
|
||||
_event_loop_locks: dict[Any, asyncio.Lock] = {}
|
||||
|
||||
def _get_from_redis(redis_key: str) -> Any | None:
|
||||
"""Get value from Redis, optionally refreshing TTL."""
|
||||
"""Get value from Redis, optionally refreshing TTL.
|
||||
|
||||
Values are expected to carry an HMAC-SHA256 prefix for integrity
|
||||
verification. Unsigned (legacy) or tampered entries are silently
|
||||
discarded and treated as cache misses, so the caller recomputes and
|
||||
re-stores them with a valid signature.
|
||||
"""
|
||||
try:
|
||||
if refresh_ttl_on_get:
|
||||
# Use GETEX to get value and refresh expiry atomically
|
||||
cached_bytes = _get_redis().getex(redis_key, ex=ttl_seconds)
|
||||
else:
|
||||
cached_bytes = _get_redis().get(redis_key)
|
||||
|
||||
if cached_bytes and isinstance(cached_bytes, bytes):
|
||||
return pickle.loads(cached_bytes)
|
||||
payload = _verify_and_strip(cached_bytes)
|
||||
if payload is None:
|
||||
logger.warning(
|
||||
"[SECURITY] Cache HMAC verification failed "
|
||||
f"for {func_name}, discarding entry: "
|
||||
"possible tampering or legacy unsigned value"
|
||||
)
|
||||
return None
|
||||
return pickle.loads(payload)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Redis error during cache check for {target_func.__name__}: {e}"
|
||||
)
|
||||
logger.error(f"Redis error during cache check for {func_name}: {e}")
|
||||
return None
|
||||
|
||||
def _set_to_redis(redis_key: str, value: Any) -> None:
|
||||
"""Set value in Redis with TTL."""
|
||||
"""Set HMAC-signed pickled value in Redis with TTL."""
|
||||
try:
|
||||
pickled_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
_get_redis().setex(redis_key, ttl_seconds, pickled_value)
|
||||
pickled = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
_get_redis().setex(redis_key, ttl_seconds, _sign_payload(pickled))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Redis error storing cache for {target_func.__name__}: {e}"
|
||||
)
|
||||
logger.error(f"Redis error storing cache for {func_name}: {e}")
|
||||
|
||||
def _get_from_memory(key: tuple) -> Any | None:
|
||||
"""Get value from in-memory cache, checking TTL."""
|
||||
@@ -212,7 +227,7 @@ def cached(
|
||||
cached_data = cache_storage[key]
|
||||
if time.time() - cached_data.timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {target_func.__name__} args: {key[0]} kwargs: {key[1]}"
|
||||
f"Cache hit for {func_name} args: {key[0]} kwargs: {key[1]}"
|
||||
)
|
||||
return cached_data.result
|
||||
return None
|
||||
@@ -244,9 +259,7 @@ def cached(
|
||||
@wraps(target_func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
redis_key = (
|
||||
_make_redis_key(key, target_func.__name__) if shared_cache else ""
|
||||
)
|
||||
redis_key = _make_redis_key(key, func_name) if shared_cache else ""
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
@@ -271,7 +284,7 @@ def cached(
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
logger.debug(f"Cache miss for {func_name}")
|
||||
result = await target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
@@ -291,9 +304,7 @@ def cached(
|
||||
@wraps(target_func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
redis_key = (
|
||||
_make_redis_key(key, target_func.__name__) if shared_cache else ""
|
||||
)
|
||||
redis_key = _make_redis_key(key, func_name) if shared_cache else ""
|
||||
|
||||
# Fast path: check cache without lock
|
||||
if shared_cache:
|
||||
@@ -318,7 +329,7 @@ def cached(
|
||||
return result
|
||||
|
||||
# Cache miss - execute function
|
||||
logger.debug(f"Cache miss for {target_func.__name__}")
|
||||
logger.debug(f"Cache miss for {func_name}")
|
||||
result = target_func(*args, **kwargs)
|
||||
|
||||
# Store result
|
||||
@@ -337,16 +348,10 @@ def cached(
|
||||
if shared_cache:
|
||||
if pattern:
|
||||
# Clear entries matching pattern
|
||||
keys = list(
|
||||
_get_redis().scan_iter(
|
||||
f"cache:{target_func.__name__}:{pattern}"
|
||||
)
|
||||
)
|
||||
keys = list(_get_redis().scan_iter(f"cache:{func_name}:{pattern}"))
|
||||
else:
|
||||
# Clear all cache keys
|
||||
keys = list(
|
||||
_get_redis().scan_iter(f"cache:{target_func.__name__}:*")
|
||||
)
|
||||
keys = list(_get_redis().scan_iter(f"cache:{func_name}:*"))
|
||||
|
||||
if keys:
|
||||
pipeline = _get_redis().pipeline()
|
||||
@@ -364,9 +369,7 @@ def cached(
|
||||
|
||||
def cache_info() -> dict[str, int | None]:
|
||||
if shared_cache:
|
||||
cache_keys = list(
|
||||
_get_redis().scan_iter(f"cache:{target_func.__name__}:*")
|
||||
)
|
||||
cache_keys = list(_get_redis().scan_iter(f"cache:{func_name}:*"))
|
||||
return {
|
||||
"size": len(cache_keys),
|
||||
"maxsize": None, # Redis manages its own size
|
||||
@@ -383,7 +386,7 @@ def cached(
|
||||
"""Delete a specific cache entry. Returns True if entry existed."""
|
||||
key = _make_hashable_key(args, kwargs)
|
||||
if shared_cache:
|
||||
redis_key = _make_redis_key(key, target_func.__name__)
|
||||
redis_key = _make_redis_key(key, func_name)
|
||||
deleted_count = cast(int, _get_redis().delete(redis_key))
|
||||
return deleted_count > 0
|
||||
else:
|
||||
@@ -401,6 +404,52 @@ def cached(
|
||||
return decorator
|
||||
|
||||
|
||||
def _sign_payload(data: bytes) -> bytes:
|
||||
"""Return *signature + data* (32-byte HMAC-SHA256 prefix)."""
|
||||
sig = hmac.new(_get_hmac_key(), data, hashlib.sha256).digest()
|
||||
return sig + data
|
||||
|
||||
|
||||
def _verify_and_strip(blob: bytes) -> bytes | None:
|
||||
"""Verify the HMAC prefix and return the payload, or `None`.
|
||||
|
||||
During deployment, the cache may still contain unsigned (legacy) entries.
|
||||
These will fail verification and return `None`, causing callers to treat
|
||||
them as cache misses. The value is then recomputed and stored with a valid
|
||||
HMAC signature. This means the transition is fully automatic: no cache
|
||||
flush is required, and all entries self-heal on next access within their
|
||||
TTL window (max 1 hour for the longest-lived entries).
|
||||
"""
|
||||
if len(blob) <= _HMAC_SIG_LEN:
|
||||
return None
|
||||
sig, data = blob[:_HMAC_SIG_LEN], blob[_HMAC_SIG_LEN:]
|
||||
expected = hmac.new(_get_hmac_key(), data, hashlib.sha256).digest()
|
||||
if hmac.compare_digest(sig, expected):
|
||||
return data
|
||||
return None
|
||||
|
||||
|
||||
@cache
|
||||
def _get_hmac_key() -> bytes:
|
||||
"""Derive a stable HMAC key for signing cached values in Redis.
|
||||
|
||||
Uses `encryption_key` — a backend-only secret that Redis never sees.
|
||||
This ensures that even if an attacker compromises Redis, they cannot forge
|
||||
valid HMAC signatures for cache entries.
|
||||
|
||||
Falls back to a hardcoded default with a loud warning so the decorator
|
||||
never crashes in development/test environments without secrets configured.
|
||||
"""
|
||||
secret = settings.secrets.encryption_key
|
||||
if not secret:
|
||||
logger.warning(
|
||||
"[SECURITY] No encryption_key configured: cache HMAC signing will use a "
|
||||
"weak default key. Set ENCRYPTION_KEY for production deployments."
|
||||
)
|
||||
secret = "autogpt-cache-default-hmac-key"
|
||||
return hashlib.sha256(secret.encode()).digest()
|
||||
|
||||
|
||||
def thread_cached(func):
|
||||
"""
|
||||
Thread-local cache decorator for both sync and async functions.
|
||||
|
||||
@@ -1121,3 +1121,105 @@ class TestSharedCache:
|
||||
# Cleanup
|
||||
shared_perf_function.cache_clear()
|
||||
local_perf_function.cache_clear()
|
||||
|
||||
|
||||
class TestCacheHMAC:
|
||||
"""Tests for HMAC integrity verification on Redis-backed cache."""
|
||||
|
||||
def test_hmac_signed_roundtrip(self):
|
||||
"""Values written to Redis can be read back via HMAC verification."""
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def hmac_roundtrip_fn(x: int) -> dict:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return {"value": x, "nested": [1, 2, 3]}
|
||||
|
||||
hmac_roundtrip_fn.cache_clear()
|
||||
|
||||
result1 = hmac_roundtrip_fn(42)
|
||||
assert result1 == {"value": 42, "nested": [1, 2, 3]}
|
||||
assert call_count == 1
|
||||
|
||||
# Second call should hit cache (HMAC verification passes)
|
||||
result2 = hmac_roundtrip_fn(42)
|
||||
assert result2 == {"value": 42, "nested": [1, 2, 3]}
|
||||
assert call_count == 1
|
||||
|
||||
hmac_roundtrip_fn.cache_clear()
|
||||
|
||||
def test_tampered_cache_entry_rejected(self):
|
||||
"""A tampered Redis entry is rejected and treated as a cache miss."""
|
||||
from backend.util.cache import _get_redis
|
||||
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def tamper_test_fn(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 2
|
||||
|
||||
tamper_test_fn.cache_clear()
|
||||
|
||||
# Populate the cache
|
||||
result = tamper_test_fn(7)
|
||||
assert result == 14
|
||||
assert call_count == 1
|
||||
|
||||
# Find and tamper with the Redis key
|
||||
redis = _get_redis()
|
||||
keys = list(redis.scan_iter("cache:tamper_test_fn:*"))
|
||||
assert len(keys) >= 1, "Expected at least one cache key"
|
||||
|
||||
for key in keys:
|
||||
raw = redis.get(key)
|
||||
assert raw is not None
|
||||
# Flip a byte in the signature portion to simulate tampering
|
||||
tampered = bytes([raw[0] ^ 0xFF]) + raw[1:]
|
||||
redis.set(key, tampered)
|
||||
|
||||
# Next call should detect tampering and recompute
|
||||
result2 = tamper_test_fn(7)
|
||||
assert result2 == 14
|
||||
assert call_count == 2 # Had to recompute
|
||||
|
||||
tamper_test_fn.cache_clear()
|
||||
|
||||
def test_unsigned_legacy_entry_rejected(self):
|
||||
"""A raw pickled value (no HMAC prefix) is rejected as a cache miss."""
|
||||
import pickle as _pickle
|
||||
|
||||
from backend.util.cache import _get_redis
|
||||
|
||||
call_count = 0
|
||||
|
||||
@cached(ttl_seconds=30, shared_cache=True)
|
||||
def legacy_test_fn(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x + 100
|
||||
|
||||
legacy_test_fn.cache_clear()
|
||||
|
||||
# Manually write an unsigned (legacy) pickled value directly to Redis
|
||||
redis = _get_redis()
|
||||
# We need to figure out the cache key format; populate first then overwrite
|
||||
result = legacy_test_fn(5)
|
||||
assert result == 105
|
||||
assert call_count == 1
|
||||
|
||||
keys = list(redis.scan_iter("cache:legacy_test_fn:*"))
|
||||
assert len(keys) >= 1
|
||||
|
||||
# Overwrite with raw unsigned pickle (simulating a legacy entry)
|
||||
for key in keys:
|
||||
redis.set(key, _pickle.dumps(999))
|
||||
|
||||
# Next call should reject the unsigned value and recompute
|
||||
result2 = legacy_test_fn(5)
|
||||
assert result2 == 105
|
||||
assert call_count == 2
|
||||
|
||||
legacy_test_fn.cache_clear()
|
||||
|
||||
Reference in New Issue
Block a user