Merge commit from fork

* fix(backend): add HMAC signing to Redis cache to prevent pickle deserialization attacks

Add HMAC-SHA256 integrity verification to all values stored in the shared
Redis cache. This prevents cache poisoning attacks where an attacker with
Redis access injects malicious pickled payloads that execute arbitrary code
on deserialization.

Changes:
- Sign pickled values with HMAC-SHA256 before storing in Redis
- Verify HMAC signature before deserializing cached values
- Reject tampered or unsigned (legacy) cache entries gracefully
  (treated as cache misses, logged as warnings)
- Derive HMAC key from encryption_key (a backend-only secret Redis never sees)
- Add tests for HMAC round-trip, tamper detection, and legacy rejection

Fixes GHSA-rfg2-37xq-w4m9

* improve log message

---------

Co-authored-by: Reinier van der Leer <pwuts@agpt.co>
This commit is contained in:
Otto
2026-03-17 14:52:37 +00:00
committed by GitHub
parent ae660ea04f
commit a6124b06d5
2 changed files with 184 additions and 33 deletions

View File

@@ -10,6 +10,8 @@ Provides decorators for caching function results with support for:
"""
import asyncio
import hashlib
import hmac
import inspect
import logging
import pickle
@@ -32,6 +34,9 @@ T = TypeVar("T")
logger = logging.getLogger(__name__)
settings = Settings()
# Length of the HMAC-SHA256 signature prefix on cached values.
_HMAC_SIG_LEN = 32
# RECOMMENDED REDIS CONFIGURATION FOR PRODUCTION:
# Configure Redis with the following settings for optimal caching performance:
# maxmemory-policy allkeys-lru # Evict least recently used keys when memory limit reached
@@ -176,35 +181,45 @@ def cached(
"""
def decorator(target_func: Callable[P, R]) -> CachedFunction[P, R]:
func_name = target_func.__name__
cache_storage: dict[tuple, CachedValue] = {}
_event_loop_locks: dict[Any, asyncio.Lock] = {}
def _get_from_redis(redis_key: str) -> Any | None:
"""Get value from Redis, optionally refreshing TTL."""
"""Get value from Redis, optionally refreshing TTL.
Values are expected to carry an HMAC-SHA256 prefix for integrity
verification. Unsigned (legacy) or tampered entries are silently
discarded and treated as cache misses, so the caller recomputes and
re-stores them with a valid signature.
"""
try:
if refresh_ttl_on_get:
# Use GETEX to get value and refresh expiry atomically
cached_bytes = _get_redis().getex(redis_key, ex=ttl_seconds)
else:
cached_bytes = _get_redis().get(redis_key)
if cached_bytes and isinstance(cached_bytes, bytes):
return pickle.loads(cached_bytes)
payload = _verify_and_strip(cached_bytes)
if payload is None:
logger.warning(
"[SECURITY] Cache HMAC verification failed "
f"for {func_name}, discarding entry: "
"possible tampering or legacy unsigned value"
)
return None
return pickle.loads(payload)
except Exception as e:
logger.error(
f"Redis error during cache check for {target_func.__name__}: {e}"
)
logger.error(f"Redis error during cache check for {func_name}: {e}")
return None
def _set_to_redis(redis_key: str, value: Any) -> None:
"""Set value in Redis with TTL."""
"""Set HMAC-signed pickled value in Redis with TTL."""
try:
pickled_value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
_get_redis().setex(redis_key, ttl_seconds, pickled_value)
pickled = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
_get_redis().setex(redis_key, ttl_seconds, _sign_payload(pickled))
except Exception as e:
logger.error(
f"Redis error storing cache for {target_func.__name__}: {e}"
)
logger.error(f"Redis error storing cache for {func_name}: {e}")
def _get_from_memory(key: tuple) -> Any | None:
"""Get value from in-memory cache, checking TTL."""
@@ -212,7 +227,7 @@ def cached(
cached_data = cache_storage[key]
if time.time() - cached_data.timestamp < ttl_seconds:
logger.debug(
f"Cache hit for {target_func.__name__} args: {key[0]} kwargs: {key[1]}"
f"Cache hit for {func_name} args: {key[0]} kwargs: {key[1]}"
)
return cached_data.result
return None
@@ -244,9 +259,7 @@ def cached(
@wraps(target_func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs):
key = _make_hashable_key(args, kwargs)
redis_key = (
_make_redis_key(key, target_func.__name__) if shared_cache else ""
)
redis_key = _make_redis_key(key, func_name) if shared_cache else ""
# Fast path: check cache without lock
if shared_cache:
@@ -271,7 +284,7 @@ def cached(
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
logger.debug(f"Cache miss for {func_name}")
result = await target_func(*args, **kwargs)
# Store result
@@ -291,9 +304,7 @@ def cached(
@wraps(target_func)
def sync_wrapper(*args: P.args, **kwargs: P.kwargs):
key = _make_hashable_key(args, kwargs)
redis_key = (
_make_redis_key(key, target_func.__name__) if shared_cache else ""
)
redis_key = _make_redis_key(key, func_name) if shared_cache else ""
# Fast path: check cache without lock
if shared_cache:
@@ -318,7 +329,7 @@ def cached(
return result
# Cache miss - execute function
logger.debug(f"Cache miss for {target_func.__name__}")
logger.debug(f"Cache miss for {func_name}")
result = target_func(*args, **kwargs)
# Store result
@@ -337,16 +348,10 @@ def cached(
if shared_cache:
if pattern:
# Clear entries matching pattern
keys = list(
_get_redis().scan_iter(
f"cache:{target_func.__name__}:{pattern}"
)
)
keys = list(_get_redis().scan_iter(f"cache:{func_name}:{pattern}"))
else:
# Clear all cache keys
keys = list(
_get_redis().scan_iter(f"cache:{target_func.__name__}:*")
)
keys = list(_get_redis().scan_iter(f"cache:{func_name}:*"))
if keys:
pipeline = _get_redis().pipeline()
@@ -364,9 +369,7 @@ def cached(
def cache_info() -> dict[str, int | None]:
if shared_cache:
cache_keys = list(
_get_redis().scan_iter(f"cache:{target_func.__name__}:*")
)
cache_keys = list(_get_redis().scan_iter(f"cache:{func_name}:*"))
return {
"size": len(cache_keys),
"maxsize": None, # Redis manages its own size
@@ -383,7 +386,7 @@ def cached(
"""Delete a specific cache entry. Returns True if entry existed."""
key = _make_hashable_key(args, kwargs)
if shared_cache:
redis_key = _make_redis_key(key, target_func.__name__)
redis_key = _make_redis_key(key, func_name)
deleted_count = cast(int, _get_redis().delete(redis_key))
return deleted_count > 0
else:
@@ -401,6 +404,52 @@ def cached(
return decorator
def _sign_payload(data: bytes) -> bytes:
    """Prefix *data* with its 32-byte HMAC-SHA256 signature.

    The stored layout is ``signature || data``; the reading side strips
    and checks the prefix with ``_verify_and_strip`` before unpickling.
    """
    digest = hmac.new(_get_hmac_key(), data, hashlib.sha256).digest()
    return b"".join((digest, data))
def _verify_and_strip(blob: bytes) -> bytes | None:
    """Check the HMAC-SHA256 prefix on *blob* and return the payload.

    Returns ``None`` when the blob is too short to carry a signature or
    when the signature does not match, so the caller treats the entry as
    a cache miss.

    During deployment, the cache may still contain unsigned (legacy)
    entries. These fail verification and return ``None``; the value is
    then recomputed and stored with a valid HMAC signature. The
    transition is therefore fully automatic: no cache flush is required,
    and all entries self-heal on next access within their TTL window
    (max 1 hour for the longest-lived entries).
    """
    if len(blob) <= _HMAC_SIG_LEN:
        return None
    signature = blob[:_HMAC_SIG_LEN]
    payload = blob[_HMAC_SIG_LEN:]
    computed = hmac.new(_get_hmac_key(), payload, hashlib.sha256).digest()
    # compare_digest gives a constant-time comparison, resisting timing attacks.
    return payload if hmac.compare_digest(signature, computed) else None
@cache
def _get_hmac_key() -> bytes:
    """Return the stable 32-byte key used to HMAC-sign cached Redis values.

    The key is derived from `encryption_key` — a backend-only secret that
    Redis never sees — so an attacker who compromises Redis alone cannot
    forge valid signatures for cache entries.

    If no secret is configured (development/test environments), a loud
    warning is logged and a hardcoded default is used so the decorator
    never crashes.
    """
    secret = settings.secrets.encryption_key
    if secret:
        return hashlib.sha256(secret.encode()).digest()
    logger.warning(
        "[SECURITY] No encryption_key configured: cache HMAC signing will use a "
        "weak default key. Set ENCRYPTION_KEY for production deployments."
    )
    return hashlib.sha256("autogpt-cache-default-hmac-key".encode()).digest()
def thread_cached(func):
"""
Thread-local cache decorator for both sync and async functions.

View File

@@ -1121,3 +1121,105 @@ class TestSharedCache:
# Cleanup
shared_perf_function.cache_clear()
local_perf_function.cache_clear()
class TestCacheHMAC:
    """Tests for HMAC integrity verification on Redis-backed cache."""

    def test_hmac_signed_roundtrip(self):
        """Values written to Redis can be read back via HMAC verification."""
        calls = 0

        @cached(ttl_seconds=30, shared_cache=True)
        def hmac_roundtrip_fn(x: int) -> dict:
            nonlocal calls
            calls += 1
            return {"value": x, "nested": [1, 2, 3]}

        hmac_roundtrip_fn.cache_clear()
        expected = {"value": 42, "nested": [1, 2, 3]}

        # First call computes the value and stores a signed entry.
        assert hmac_roundtrip_fn(42) == expected
        assert calls == 1

        # Second call is served from cache (HMAC verification passes).
        assert hmac_roundtrip_fn(42) == expected
        assert calls == 1

        hmac_roundtrip_fn.cache_clear()

    def test_tampered_cache_entry_rejected(self):
        """A tampered Redis entry is rejected and treated as a cache miss."""
        from backend.util.cache import _get_redis

        calls = 0

        @cached(ttl_seconds=30, shared_cache=True)
        def tamper_test_fn(x: int) -> int:
            nonlocal calls
            calls += 1
            return x * 2

        tamper_test_fn.cache_clear()

        # Populate the cache.
        assert tamper_test_fn(7) == 14
        assert calls == 1

        # Corrupt the stored blob for every matching Redis key.
        redis = _get_redis()
        cache_keys = list(redis.scan_iter("cache:tamper_test_fn:*"))
        assert len(cache_keys) >= 1, "Expected at least one cache key"
        for cache_key in cache_keys:
            stored = redis.get(cache_key)
            assert stored is not None
            # Flip a byte in the signature portion to simulate tampering.
            corrupted = bytes([stored[0] ^ 0xFF]) + stored[1:]
            redis.set(cache_key, corrupted)

        # The tampered entry must be discarded and the value recomputed.
        assert tamper_test_fn(7) == 14
        assert calls == 2  # Had to recompute

        tamper_test_fn.cache_clear()

    def test_unsigned_legacy_entry_rejected(self):
        """A raw pickled value (no HMAC prefix) is rejected as a cache miss."""
        import pickle as _pickle

        from backend.util.cache import _get_redis

        calls = 0

        @cached(ttl_seconds=30, shared_cache=True)
        def legacy_test_fn(x: int) -> int:
            nonlocal calls
            calls += 1
            return x + 100

        legacy_test_fn.cache_clear()

        # Populate first so the exact key format need not be reconstructed.
        assert legacy_test_fn(5) == 105
        assert calls == 1

        redis = _get_redis()
        cache_keys = list(redis.scan_iter("cache:legacy_test_fn:*"))
        assert len(cache_keys) >= 1
        # Overwrite each entry with a raw unsigned pickle (legacy format).
        for cache_key in cache_keys:
            redis.set(cache_key, _pickle.dumps(999))

        # The unsigned value must be rejected and the result recomputed.
        assert legacy_test_fn(5) == 105
        assert calls == 2

        legacy_test_fn.cache_clear()