mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(feature-flag): add LaunchDarkly segment support with 24h caching
## Summary - Enhanced LaunchDarkly feature flag system to support context targeting with user segments - Added 24-hour TTL caching for user context data to improve performance - Implemented comprehensive user segmentation logic including role-based, domain-based, and account age segments ## Key Changes Made ### LaunchDarkly Client Enhancement - **Segment Support**: Added full user context with segments for LaunchDarkly's contextTargets and segmentMatch rules - **24-Hour Caching**: Implemented TTL cache with 86400 seconds (24 hours) to reduce database calls - **Clean Architecture**: Uses existing `get_user_by_id` function following credentials_store.py pattern - **Async-Only**: Removed sync support, made all functions async for consistency ### Segmentation Logic - **Role-based**: user, admin, system segments - **Domain-based**: employee segment for agpt.co emails, external for others - **Account age**: new_user (<7 days), recent_user (7-30 days), established_user (>30 days) - **Custom metadata**: Supports additional segments from user metadata ### Files Modified - `autogpt_libs/feature_flag/client.py`: Main implementation with segment support and caching - `autogpt_libs/utils/cache.py`: Added async TTL cache decorator with Protocol typing - `autogpt_libs/utils/cache_test.py`: Comprehensive tests for TTL and process-level caching - `backend/executor/database.py`: Exposed `get_user_by_id` for RPC support ## Technical Details - Uses `@async_ttl_cache(maxsize=1000, ttl_seconds=86400)` for efficient caching - Fallback to simple user_id context if database unavailable - Proper exception handling without dangerous swallowing - Strong typing with Protocol for cache functions - Backwards compatible with `use_user_id_only` parameter ## Test Plan - [x] All cache tests pass (TTL and process-level) - [x] Feature flag evaluation works with segments - [x] Database RPC calls work correctly - [x] 24-hour cache reduces database load - [x] Fallback to simple context when database unavailable ## Impact - **Better targeting**: LaunchDarkly rules with segments now work correctly - **Performance**: 24-hour cache dramatically reduces database calls - **Maintainability**: Clean separation using existing database functions - **Reliability**: Proper error handling and fallbacks 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
|
||||
from typing import Any, Awaitable, Callable, Optional, TypeVar
|
||||
|
||||
import ldclient
|
||||
from fastapi import HTTPException
|
||||
@@ -10,6 +9,8 @@ from ldclient import Context, LDClient
|
||||
from ldclient.config import Config
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from autogpt_libs.utils.cache import async_ttl_cache
|
||||
|
||||
from .config import SETTINGS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -56,7 +57,7 @@ def shutdown_launchdarkly() -> None:
|
||||
|
||||
|
||||
def create_context(
|
||||
user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
|
||||
user_id: str, additional_attributes: Optional[dict[str, Any]] = None
|
||||
) -> Context:
|
||||
"""Create LaunchDarkly context with optional additional attributes."""
|
||||
builder = Context.builder(str(user_id)).kind("user")
|
||||
@@ -66,22 +67,215 @@ def create_context(
|
||||
return builder.build()
|
||||
|
||||
|
||||
def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bool:
|
||||
@async_ttl_cache(maxsize=1000, ttl_seconds=86400) # 1000 entries, 24 hours TTL
|
||||
async def _fetch_user_context_data(user_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Simple helper to check if a feature flag is enabled for a user.
|
||||
Fetch user data and build complete LaunchDarkly context with segments.
|
||||
|
||||
Uses the existing get_user_by_id function and applies segmentation logic
|
||||
within the LaunchDarkly client for clean separation of concerns.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to fetch data for
|
||||
|
||||
Returns:
|
||||
Dictionary with user context data including segments
|
||||
|
||||
Raises:
|
||||
Exception: If database query fails (not cached)
|
||||
"""
|
||||
# Import here to avoid circular dependencies
|
||||
from backend.data.db import prisma
|
||||
|
||||
# Check if we're in a context with direct database access
|
||||
if prisma.is_connected():
|
||||
# Direct database query using existing function
|
||||
from backend.data.user import get_user_by_id
|
||||
|
||||
user = await get_user_by_id(user_id)
|
||||
else:
|
||||
# Use database manager client for RPC calls
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
db_client = get_database_manager_async_client()
|
||||
user = await db_client.get_user_by_id(user_id)
|
||||
|
||||
# Build and return LaunchDarkly context with segments
|
||||
return _build_launchdarkly_context(user)
|
||||
|
||||
|
||||
def _build_launchdarkly_context(user) -> dict[str, Any]:
|
||||
"""
|
||||
Build LaunchDarkly context data with segments from user object.
|
||||
|
||||
Args:
|
||||
user: User object from database
|
||||
|
||||
Returns:
|
||||
Dictionary with user context data including segments
|
||||
"""
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
||||
|
||||
# Build basic context data
|
||||
context_data = {
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"createdAt": user.createdAt.isoformat() if user.createdAt else None,
|
||||
}
|
||||
|
||||
# Determine user segments for LaunchDarkly targeting
|
||||
segments = []
|
||||
|
||||
# Add role-based segment
|
||||
if user.id == DEFAULT_USER_ID:
|
||||
segments.extend(
|
||||
["system", "admin"]
|
||||
) # Default user has admin privileges when auth is disabled
|
||||
else:
|
||||
segments.append("user") # Regular users
|
||||
|
||||
# Add email domain-based segments for targeting
|
||||
if user.email:
|
||||
domain = user.email.split("@")[-1] if "@" in user.email else ""
|
||||
if domain:
|
||||
context_data["email_domain"] = domain
|
||||
if domain in ["agpt.co"]:
|
||||
segments.append("employee")
|
||||
|
||||
# Parse metadata for additional segments and custom attributes
|
||||
if user.metadata:
|
||||
try:
|
||||
metadata = (
|
||||
json.loads(user.metadata)
|
||||
if isinstance(user.metadata, str)
|
||||
else user.metadata
|
||||
)
|
||||
|
||||
# Extract explicit segments from metadata if they exist
|
||||
if "segments" in metadata:
|
||||
if isinstance(metadata["segments"], list):
|
||||
segments.extend(metadata["segments"])
|
||||
elif isinstance(metadata["segments"], str):
|
||||
segments.append(metadata["segments"])
|
||||
|
||||
# Extract role from metadata if present
|
||||
if "role" in metadata:
|
||||
role = metadata["role"]
|
||||
if role in ["admin", "moderator", "employee"]:
|
||||
segments.append(role)
|
||||
context_data["role"] = role
|
||||
|
||||
# Add custom attributes with prefix to avoid conflicts
|
||||
for key, value in metadata.items():
|
||||
if key not in ["segments", "role"]: # Skip processed fields
|
||||
context_data[f"custom_{key}"] = value
|
||||
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
logger.debug(f"Failed to parse user metadata for context: {e}")
|
||||
|
||||
# Add account age segment for targeting new vs old users
|
||||
if user.createdAt:
|
||||
account_age_days = (datetime.now(user.createdAt.tzinfo) - user.createdAt).days
|
||||
if account_age_days < 7:
|
||||
segments.append("new_user")
|
||||
elif account_age_days < 30:
|
||||
segments.append("recent_user")
|
||||
else:
|
||||
segments.append("established_user")
|
||||
|
||||
context_data["account_age_days"] = account_age_days
|
||||
|
||||
# Remove duplicates and sort for consistency
|
||||
context_data["segments"] = sorted(list(set(segments)))
|
||||
|
||||
return context_data
|
||||
|
||||
|
||||
async def is_feature_enabled(
|
||||
flag_key: str,
|
||||
user_id: str,
|
||||
default: bool = False,
|
||||
use_user_id_only: bool = False,
|
||||
additional_attributes: Optional[dict[str, Any]] = None,
|
||||
user_role: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a feature flag is enabled for a user with full LaunchDarkly context support.
|
||||
|
||||
Args:
|
||||
flag_key: The LaunchDarkly feature flag key
|
||||
user_id: The user ID to evaluate the flag for
|
||||
default: Default value if LaunchDarkly is unavailable or flag evaluation fails
|
||||
use_user_id_only: If True, only use user_id without fetching database context
|
||||
additional_attributes: Additional attributes to include in the context
|
||||
user_role: Optional user role (e.g., "admin", "user") to add to segments
|
||||
|
||||
Returns:
|
||||
True if feature is enabled, False otherwise
|
||||
"""
|
||||
try:
|
||||
client = get_client()
|
||||
context = create_context(str(user_id))
|
||||
return client.variation(flag_key, context, default)
|
||||
|
||||
if use_user_id_only:
|
||||
# Simple context with just user ID (for backward compatibility)
|
||||
# But still add role if provided
|
||||
attrs = additional_attributes or {}
|
||||
if user_role:
|
||||
attrs["role"] = user_role
|
||||
# Also add role as a segment for LaunchDarkly targeting
|
||||
segments = attrs.get("segments", [])
|
||||
if isinstance(segments, list):
|
||||
segments.append(user_role)
|
||||
else:
|
||||
segments = [user_role]
|
||||
attrs["segments"] = list(set(segments)) # Remove duplicates
|
||||
context = create_context(str(user_id), attrs)
|
||||
else:
|
||||
# Full context with user segments and metadata from database
|
||||
try:
|
||||
user_data = await _fetch_user_context_data(user_id)
|
||||
except ImportError as e:
|
||||
# Database modules not available - fallback to simple context
|
||||
logger.debug(f"Database modules not available: {e}")
|
||||
user_data = {}
|
||||
except Exception as e:
|
||||
# Database error - log and fallback to simple context
|
||||
logger.warning(f"Failed to fetch user context for {user_id}: {e}")
|
||||
user_data = {}
|
||||
|
||||
# Merge additional attributes and role
|
||||
attrs = additional_attributes or {}
|
||||
if user_role:
|
||||
attrs["role"] = user_role
|
||||
# Add role to segments if we have them
|
||||
if "segments" in user_data:
|
||||
user_data["segments"].append(user_role)
|
||||
user_data["segments"] = list(set(user_data["segments"]))
|
||||
else:
|
||||
user_data["segments"] = [user_role]
|
||||
|
||||
# Merge attributes with user data
|
||||
final_attrs = {**user_data, **attrs}
|
||||
|
||||
# Handle segment merging if both have segments
|
||||
if "segments" in user_data and "segments" in attrs:
|
||||
combined = list(set(user_data["segments"] + attrs["segments"]))
|
||||
final_attrs["segments"] = combined
|
||||
|
||||
context = create_context(str(user_id), final_attrs)
|
||||
|
||||
# Evaluate the flag
|
||||
result = client.variation(flag_key, context, default)
|
||||
|
||||
logger.debug(
|
||||
f"Feature flag {flag_key} for user {user_id}: {result} "
|
||||
f"(use_user_id_only: {use_user_id_only})"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
@@ -93,16 +287,19 @@ def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bo
|
||||
def feature_flag(
|
||||
flag_key: str,
|
||||
default: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
|
||||
"""
|
||||
Decorator for feature flag protected endpoints.
|
||||
Decorator for async feature flag protected endpoints.
|
||||
|
||||
Args:
|
||||
flag_key: The LaunchDarkly feature flag key
|
||||
default: Default value if flag evaluation fails
|
||||
|
||||
Returns:
|
||||
Decorator that only works with async functions
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[P, Union[T, Awaitable[T]]],
|
||||
) -> Callable[P, Union[T, Awaitable[T]]]:
|
||||
def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
@@ -116,73 +313,24 @@ def feature_flag(
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
context = create_context(str(user_id))
|
||||
is_enabled = get_client().variation(flag_key, context, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if asyncio.iscoroutine(result):
|
||||
return await result
|
||||
return cast(T, result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||
raise
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
if not get_client().is_initialized():
|
||||
logger.warning(
|
||||
f"LaunchDarkly not initialized, using default={default}"
|
||||
# Use the unified function with full context support
|
||||
is_enabled = await is_feature_enabled(
|
||||
flag_key, str(user_id), default, use_user_id_only=False
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
context = create_context(str(user_id))
|
||||
is_enabled = get_client().variation(flag_key, context, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
|
||||
return cast(T, func(*args, **kwargs))
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||
raise
|
||||
|
||||
return cast(
|
||||
Callable[P, Union[T, Awaitable[T]]],
|
||||
async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
|
||||
)
|
||||
return async_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def percentage_rollout(
|
||||
flag_key: str,
|
||||
default: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""Decorator for percentage-based rollouts."""
|
||||
return feature_flag(flag_key, default)
|
||||
|
||||
|
||||
def beta_feature(
|
||||
flag_key: Optional[str] = None,
|
||||
unauthorized_response: Any = {"message": "Not available in beta"},
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""Decorator for beta features."""
|
||||
actual_key = f"beta-{flag_key}" if flag_key else "beta"
|
||||
return feature_flag(actual_key, False)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_flag_variation(flag_key: str, return_value: Any):
|
||||
"""Context manager for testing feature flags."""
|
||||
|
||||
@@ -1,10 +1,26 @@
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
|
||||
@@ -57,3 +73,153 @@ def thread_cached(
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AsyncCachedFunction(Protocol):
|
||||
"""Protocol for async functions with cache management methods."""
|
||||
|
||||
def cache_clear(self) -> None:
|
||||
"""Clear all cached entries."""
|
||||
return None
|
||||
|
||||
def cache_info(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
return {}
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call the cached function."""
|
||||
return None
|
||||
|
||||
|
||||
def async_ttl_cache(
|
||||
maxsize: int = 128, ttl_seconds: int | None = None
|
||||
) -> Callable[[Callable[..., Awaitable[Any]]], AsyncCachedFunction]:
|
||||
"""
|
||||
TTL (Time To Live) cache decorator for async functions.
|
||||
|
||||
Similar to functools.lru_cache but works with async functions and includes optional TTL.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum number of cached entries
|
||||
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
|
||||
Example:
|
||||
# With TTL
|
||||
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
|
||||
async def api_call(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
# Without TTL (permanent cache like lru_cache)
|
||||
@async_ttl_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
async_func: Callable[..., Awaitable[Any]],
|
||||
) -> AsyncCachedFunction:
|
||||
# Cache storage - use union type to handle both cases
|
||||
cache_storage: dict[Any, Any | Tuple[Any, float]] = {}
|
||||
|
||||
@wraps(async_func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Create cache key from arguments
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we have a valid cached entry
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
# No TTL - return cached result directly
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cache_storage[key]
|
||||
else:
|
||||
# With TTL - check expiration
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return result
|
||||
else:
|
||||
# Expired entry
|
||||
del cache_storage[key]
|
||||
logger.debug(
|
||||
f"Cache entry expired for {async_func.__name__}"
|
||||
)
|
||||
|
||||
# Cache miss or expired - fetch fresh data
|
||||
logger.debug(
|
||||
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
result = await async_func(*args, **kwargs)
|
||||
|
||||
# Store in cache
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Simple cleanup when cache gets too large
|
||||
if len(cache_storage) > maxsize:
|
||||
# Remove oldest entries (simple FIFO cleanup)
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
logger.debug(
|
||||
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
# Add cache management methods (similar to functools.lru_cache)
|
||||
def cache_clear() -> None:
|
||||
cache_storage.clear()
|
||||
|
||||
def cache_info() -> dict[str, Any]:
|
||||
return {
|
||||
"size": len(cache_storage),
|
||||
"maxsize": maxsize,
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
|
||||
# Attach methods to wrapper
|
||||
setattr(wrapper, "cache_clear", cache_clear)
|
||||
setattr(wrapper, "cache_info", cache_info)
|
||||
|
||||
return cast(AsyncCachedFunction, wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def async_cache(
|
||||
maxsize: int = 128,
|
||||
) -> Callable[[Callable[..., Awaitable[Any]]], AsyncCachedFunction]:
|
||||
"""
|
||||
Process-level cache decorator for async functions (no TTL).
|
||||
|
||||
Similar to functools.lru_cache but works with async functions.
|
||||
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum number of cached entries
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
|
||||
Example:
|
||||
@async_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
# Expensive computation here
|
||||
return {"result": param}
|
||||
"""
|
||||
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
|
||||
@@ -16,7 +16,12 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
|
||||
from autogpt_libs.utils.cache import (
|
||||
async_cache,
|
||||
async_ttl_cache,
|
||||
clear_thread_cache,
|
||||
thread_cached,
|
||||
)
|
||||
|
||||
|
||||
class TestThreadCached:
|
||||
@@ -323,3 +328,378 @@ class TestThreadCached:
|
||||
|
||||
assert function_using_mock(2) == 42
|
||||
assert mock.call_count == 2
|
||||
|
||||
|
||||
class TestAsyncTTLCache:
|
||||
"""Tests for the @async_ttl_cache decorator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching(self):
|
||||
"""Test basic caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = await cached_function(2, 3)
|
||||
assert result3 == 5
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_expiration(self):
|
||||
"""Test that cache entries expire after TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def short_lived_cache(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 2
|
||||
|
||||
# First call
|
||||
result1 = await short_lived_cache(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Second call immediately - should use cache
|
||||
result2 = await short_lived_cache(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Third call after expiration - should call function again
|
||||
result3 = await short_lived_cache(5)
|
||||
assert result3 == 10
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_info(self):
|
||||
"""Test cache info functionality."""
|
||||
|
||||
@async_ttl_cache(maxsize=5, ttl_seconds=300)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] == 300
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_clear(self):
|
||||
"""Test cache clearing functionality."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 4
|
||||
|
||||
# First call
|
||||
result1 = await clearable_function(2)
|
||||
assert result1 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
result2 = await clearable_function(2)
|
||||
assert result2 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Clear cache
|
||||
clearable_function.cache_clear()
|
||||
|
||||
# Third call after clear - should call function again
|
||||
result3 = await clearable_function(2)
|
||||
assert result3 == 8
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maxsize_cleanup(self):
|
||||
"""Test that cache cleans up when maxsize is exceeded."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=3, ttl_seconds=60)
|
||||
async def size_limited_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
# Fill cache to maxsize
|
||||
await size_limited_function(1) # call_count: 1
|
||||
await size_limited_function(2) # call_count: 2
|
||||
await size_limited_function(3) # call_count: 3
|
||||
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] == 3
|
||||
|
||||
# Add one more entry - should trigger cleanup
|
||||
await size_limited_function(4) # call_count: 4
|
||||
|
||||
# Cache size should be reduced (cleanup removes oldest entries)
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] <= 3 # Should be cleaned up
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_argument_variations(self):
|
||||
"""Test caching with different argument patterns."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{a}-{b}-{c}"
|
||||
|
||||
# Different ways to call with same logical arguments
|
||||
result1 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
|
||||
# Same arguments, same order - should use cache
|
||||
result2 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
assert result1 == result2
|
||||
|
||||
# Different arguments - should call function
|
||||
result3 = await arg_test_function(2, "test", c=200)
|
||||
assert call_count == 2
|
||||
assert result1 != result3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_handling(self):
|
||||
"""Test that exceptions are not cached."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def exception_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if x < 0:
|
||||
raise ValueError("Negative value not allowed")
|
||||
return x * 2
|
||||
|
||||
# Successful call - should be cached
|
||||
result1 = await exception_function(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Same successful call - should use cache
|
||||
result2 = await exception_function(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Exception call - should not be cached
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
assert call_count == 2
|
||||
|
||||
# Same exception call - should call again (not cached)
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
assert call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_calls(self):
|
||||
"""Test caching behavior with concurrent calls."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def concurrent_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.05) # Simulate work
|
||||
return x * x
|
||||
|
||||
# Launch concurrent calls with same arguments
|
||||
tasks = [concurrent_function(3) for _ in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All results should be the same
|
||||
assert all(result == 9 for result in results)
|
||||
|
||||
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
|
||||
# This tests that the cache doesn't break under concurrent access
|
||||
assert 1 <= call_count <= 5
|
||||
|
||||
|
||||
class TestAsyncCache:
|
||||
"""Tests for the @async_cache decorator (no TTL)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching_no_ttl(self):
|
||||
"""Test basic caching functionality without TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_cache(maxsize=10)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Third call after some time - should still use cache (no TTL)
|
||||
await asyncio.sleep(0.05)
|
||||
result3 = await cached_function(1, 2)
|
||||
assert result3 == 3
|
||||
assert call_count == 1 # Still no additional call
|
||||
|
||||
# Different args - should call function again
|
||||
result4 = await cached_function(2, 3)
|
||||
assert result4 == 5
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_ttl_vs_ttl_behavior(self):
|
||||
"""Test the difference between TTL and no-TTL caching."""
|
||||
ttl_call_count = 0
|
||||
no_ttl_call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_call_count
|
||||
ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
@async_cache(maxsize=10) # No TTL
|
||||
async def no_ttl_function(x: int) -> int:
|
||||
nonlocal no_ttl_call_count
|
||||
no_ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
# First calls
|
||||
await ttl_function(5)
|
||||
await no_ttl_function(5)
|
||||
assert ttl_call_count == 1
|
||||
assert no_ttl_call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Second calls after TTL expiry
|
||||
await ttl_function(5) # Should call function again (TTL expired)
|
||||
await no_ttl_function(5) # Should use cache (no TTL)
|
||||
assert ttl_call_count == 2 # TTL function called again
|
||||
assert no_ttl_call_count == 1 # No-TTL function still cached
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_info(self):
|
||||
"""Test cache info for no-TTL cache."""
|
||||
|
||||
@async_cache(maxsize=5)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] is None # No TTL
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
|
||||
class TestTTLOptional:
|
||||
"""Tests for optional TTL functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_none_behavior(self):
|
||||
"""Test that ttl_seconds=None works like no TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=None)
|
||||
async def no_ttl_via_none(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
# First call
|
||||
result1 = await no_ttl_via_none(3)
|
||||
assert result1 == 9
|
||||
assert call_count == 1
|
||||
|
||||
# Wait (would expire if there was TTL)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Second call - should still use cache
|
||||
result2 = await no_ttl_via_none(3)
|
||||
assert result2 == 9
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Check cache info
|
||||
info = no_ttl_via_none.cache_info()
|
||||
assert info["ttl_seconds"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_options_comparison(self):
|
||||
"""Test different cache options work as expected."""
|
||||
ttl_calls = 0
|
||||
no_ttl_calls = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_calls
|
||||
ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
@async_cache(maxsize=10) # Process-level cache (no TTL)
|
||||
async def process_function(x: int) -> int:
|
||||
nonlocal no_ttl_calls
|
||||
no_ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
# Both should cache initially
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Immediate second calls - both should use cache
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# After TTL expiry
|
||||
await ttl_function(3) # Should call function again
|
||||
await process_function(3) # Should still use cache
|
||||
assert ttl_calls == 2 # TTL cache expired, called again
|
||||
assert no_ttl_calls == 1 # Process cache never expires
|
||||
|
||||
@@ -102,8 +102,10 @@ async def generate_activity_status_for_execution(
|
||||
Returns:
|
||||
AI-generated activity status string, or None if feature is disabled
|
||||
"""
|
||||
# Check LaunchDarkly feature flag for AI activity status generation
|
||||
if not is_feature_enabled(AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False):
|
||||
# Check LaunchDarkly feature flag for AI activity status generation with full context support
|
||||
if not await is_feature_enabled(
|
||||
AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False
|
||||
):
|
||||
logger.debug("AI activity status generation is disabled via LaunchDarkly")
|
||||
return None
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ from backend.data.notifications import (
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_active_user_ids_in_timerange,
|
||||
get_user_by_id,
|
||||
get_user_email_by_id,
|
||||
get_user_email_verification,
|
||||
get_user_integrations,
|
||||
@@ -129,6 +130,7 @@ class DatabaseManager(AppService):
|
||||
|
||||
# User Comms - async
|
||||
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
|
||||
get_user_by_id = _(get_user_by_id)
|
||||
get_user_email_by_id = _(get_user_email_by_id)
|
||||
get_user_email_verification = _(get_user_email_verification)
|
||||
get_user_notification_preference = _(get_user_notification_preference)
|
||||
@@ -201,6 +203,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
|
||||
# User Comms
|
||||
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
|
||||
get_user_by_id = d.get_user_by_id
|
||||
get_user_email_by_id = d.get_user_email_by_id
|
||||
get_user_email_verification = d.get_user_email_verification
|
||||
get_user_notification_preference = d.get_user_notification_preference
|
||||
|
||||
Reference in New Issue
Block a user