mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-11 16:18:07 -05:00
Compare commits
6 Commits
dev
...
feat/launc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5899897d8f | ||
|
|
0389c865aa | ||
|
|
b88576d313 | ||
|
|
14634a6ce9 | ||
|
|
10a402a766 | ||
|
|
004011726d |
@@ -1,10 +1,14 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
|
||||
from json import JSONDecodeError
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.model import User
|
||||
|
||||
import ldclient
|
||||
from backend.util.json import loads as json_loads
|
||||
from fastapi import HTTPException
|
||||
from ldclient import Context, LDClient
|
||||
from ldclient.config import Config
|
||||
@@ -56,32 +60,202 @@ def shutdown_launchdarkly() -> None:
|
||||
|
||||
|
||||
def create_context(
|
||||
user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
|
||||
user_id: str, additional_attributes: Optional[dict[str, Any]] = None
|
||||
) -> Context:
|
||||
"""Create LaunchDarkly context with optional additional attributes."""
|
||||
builder = Context.builder(str(user_id)).kind("user")
|
||||
# Use the key from attributes if provided, otherwise use user_id
|
||||
context_key = user_id
|
||||
if additional_attributes and "key" in additional_attributes:
|
||||
context_key = additional_attributes["key"]
|
||||
|
||||
builder = Context.builder(str(context_key)).kind("user")
|
||||
|
||||
if additional_attributes:
|
||||
for key, value in additional_attributes.items():
|
||||
builder.set(key, value)
|
||||
# Skip kind and key as they're already set
|
||||
if key in ["kind", "key"]:
|
||||
continue
|
||||
elif key == "custom" and isinstance(value, dict):
|
||||
# Handle custom attributes object - these go as individual attributes
|
||||
for custom_key, custom_value in value.items():
|
||||
builder.set(custom_key, custom_value)
|
||||
else:
|
||||
builder.set(key, value)
|
||||
return builder.build()
|
||||
|
||||
|
||||
def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bool:
|
||||
async def _fetch_user_context_data(user_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
Simple helper to check if a feature flag is enabled for a user.
|
||||
Fetch user data and build LaunchDarkly context.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to fetch data for
|
||||
|
||||
Returns:
|
||||
Dictionary with user context data including role
|
||||
"""
|
||||
# Use the unified database access approach
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
|
||||
db_client = get_database_manager_async_client()
|
||||
user = await db_client.get_user_by_id(user_id)
|
||||
|
||||
# Build LaunchDarkly context from user data
|
||||
return _build_launchdarkly_context(user)
|
||||
|
||||
|
||||
def _build_launchdarkly_context(user: "User") -> dict[str, Any]:
|
||||
"""
|
||||
Build LaunchDarkly context data matching frontend format.
|
||||
|
||||
Returns a context like:
|
||||
{
|
||||
"kind": "user",
|
||||
"key": "user-id",
|
||||
"email": "user@example.com", # Optional
|
||||
"anonymous": false,
|
||||
"custom": {
|
||||
"role": "admin" # Optional
|
||||
}
|
||||
}
|
||||
|
||||
Args:
|
||||
user: User object from database
|
||||
|
||||
Returns:
|
||||
Dictionary with user context data
|
||||
"""
|
||||
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
||||
|
||||
# Build basic context - always include kind, key, and anonymous
|
||||
context_data: dict[str, Any] = {
|
||||
"kind": "user",
|
||||
"key": user.id,
|
||||
"anonymous": False,
|
||||
}
|
||||
|
||||
# Add email if present
|
||||
if user.email:
|
||||
context_data["email"] = user.email
|
||||
|
||||
# Initialize custom attributes
|
||||
custom: dict[str, Any] = {}
|
||||
|
||||
# Determine user role from metadata
|
||||
role = None
|
||||
|
||||
# Check if user is default/system user
|
||||
if user.id == DEFAULT_USER_ID:
|
||||
role = "admin" # Default user has admin privileges when auth is disabled
|
||||
elif user.metadata:
|
||||
# Check for role in metadata
|
||||
try:
|
||||
# Handle both string (direct DB) and dict (RPC) formats
|
||||
if isinstance(user.metadata, str):
|
||||
metadata = json_loads(user.metadata)
|
||||
elif isinstance(user.metadata, dict):
|
||||
metadata = user.metadata
|
||||
else:
|
||||
metadata = {}
|
||||
|
||||
# Extract role from metadata if present
|
||||
if metadata.get("role"):
|
||||
role = metadata["role"]
|
||||
|
||||
except (JSONDecodeError, TypeError) as e:
|
||||
logger.debug(f"Failed to parse user metadata for context: {e}")
|
||||
|
||||
# Add role to custom attributes if present
|
||||
if role:
|
||||
custom["role"] = role
|
||||
|
||||
# Only add custom object if it has content
|
||||
if custom:
|
||||
context_data["custom"] = custom
|
||||
|
||||
return context_data
|
||||
|
||||
|
||||
async def is_feature_enabled(
|
||||
flag_key: str,
|
||||
user_id: str,
|
||||
default: bool = False,
|
||||
use_user_id_only: bool = False,
|
||||
additional_attributes: Optional[dict[str, Any]] = None,
|
||||
user_role: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a feature flag is enabled for a user with full LaunchDarkly context support.
|
||||
|
||||
Args:
|
||||
flag_key: The LaunchDarkly feature flag key
|
||||
user_id: The user ID to evaluate the flag for
|
||||
default: Default value if LaunchDarkly is unavailable or flag evaluation fails
|
||||
use_user_id_only: If True, only use user_id without fetching database context
|
||||
additional_attributes: Additional attributes to include in the context
|
||||
user_role: Optional user role (e.g., "admin", "user") to add to segments
|
||||
|
||||
Returns:
|
||||
True if feature is enabled, False otherwise
|
||||
"""
|
||||
try:
|
||||
client = get_client()
|
||||
context = create_context(str(user_id))
|
||||
return client.variation(flag_key, context, default)
|
||||
|
||||
if use_user_id_only:
|
||||
# Simple context with just user ID (for backward compatibility)
|
||||
attrs = additional_attributes or {}
|
||||
if user_role:
|
||||
# Add role to custom attributes for consistency
|
||||
if "custom" not in attrs:
|
||||
attrs["custom"] = {}
|
||||
if isinstance(attrs["custom"], dict):
|
||||
attrs["custom"]["role"] = user_role
|
||||
context = create_context(str(user_id), attrs)
|
||||
else:
|
||||
# Full context with user segments and metadata from database
|
||||
try:
|
||||
user_data = await _fetch_user_context_data(user_id)
|
||||
except ImportError as e:
|
||||
# Database modules not available - fallback to simple context
|
||||
logger.debug(f"Database modules not available: {e}")
|
||||
user_data = {}
|
||||
except Exception as e:
|
||||
# Database error - log and fallback to simple context
|
||||
logger.warning(f"Failed to fetch user context for {user_id}: {e}")
|
||||
user_data = {}
|
||||
|
||||
# Merge additional attributes and role
|
||||
attrs = additional_attributes or {}
|
||||
|
||||
# If user_role is provided, add it to custom attributes
|
||||
if user_role:
|
||||
if "custom" not in user_data:
|
||||
user_data["custom"] = {}
|
||||
user_data["custom"]["role"] = user_role
|
||||
|
||||
# Merge additional attributes with user data
|
||||
# Handle custom attributes specially
|
||||
if "custom" in attrs and isinstance(attrs["custom"], dict):
|
||||
if "custom" not in user_data:
|
||||
user_data["custom"] = {}
|
||||
user_data["custom"].update(attrs["custom"])
|
||||
# Remove custom from attrs to avoid duplication
|
||||
attrs = {k: v for k, v in attrs.items() if k != "custom"}
|
||||
|
||||
# Merge remaining attributes
|
||||
final_attrs = {**user_data, **attrs}
|
||||
|
||||
context = create_context(str(user_id), final_attrs)
|
||||
|
||||
# Evaluate the flag
|
||||
result = client.variation(flag_key, context, default)
|
||||
|
||||
logger.debug(
|
||||
f"Feature flag {flag_key} for user {user_id}: {result} "
|
||||
f"(use_user_id_only: {use_user_id_only})"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
@@ -93,16 +267,19 @@ def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bo
|
||||
def feature_flag(
|
||||
flag_key: str,
|
||||
default: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
|
||||
"""
|
||||
Decorator for feature flag protected endpoints.
|
||||
Decorator for async feature flag protected endpoints.
|
||||
|
||||
Args:
|
||||
flag_key: The LaunchDarkly feature flag key
|
||||
default: Default value if flag evaluation fails
|
||||
|
||||
Returns:
|
||||
Decorator that only works with async functions
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
func: Callable[P, Union[T, Awaitable[T]]],
|
||||
) -> Callable[P, Union[T, Awaitable[T]]]:
|
||||
def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
@@ -116,73 +293,24 @@ def feature_flag(
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
context = create_context(str(user_id))
|
||||
is_enabled = get_client().variation(flag_key, context, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
if asyncio.iscoroutine(result):
|
||||
return await result
|
||||
return cast(T, result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||
raise
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||
try:
|
||||
user_id = kwargs.get("user_id")
|
||||
if not user_id:
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
if not get_client().is_initialized():
|
||||
logger.warning(
|
||||
f"LaunchDarkly not initialized, using default={default}"
|
||||
# Use the unified function with full context support
|
||||
is_enabled = await is_feature_enabled(
|
||||
flag_key, str(user_id), default, use_user_id_only=False
|
||||
)
|
||||
is_enabled = default
|
||||
else:
|
||||
context = create_context(str(user_id))
|
||||
is_enabled = get_client().variation(flag_key, context, default)
|
||||
|
||||
if not is_enabled:
|
||||
raise HTTPException(status_code=404, detail="Feature not available")
|
||||
|
||||
return cast(T, func(*args, **kwargs))
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||
raise
|
||||
|
||||
return cast(
|
||||
Callable[P, Union[T, Awaitable[T]]],
|
||||
async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
|
||||
)
|
||||
return async_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def percentage_rollout(
|
||||
flag_key: str,
|
||||
default: bool = False,
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""Decorator for percentage-based rollouts."""
|
||||
return feature_flag(flag_key, default)
|
||||
|
||||
|
||||
def beta_feature(
|
||||
flag_key: Optional[str] = None,
|
||||
unauthorized_response: Any = {"message": "Not available in beta"},
|
||||
) -> Callable[
|
||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
||||
]:
|
||||
"""Decorator for beta features."""
|
||||
actual_key = f"beta-{flag_key}" if flag_key else "beta"
|
||||
return feature_flag(actual_key, False)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_flag_variation(flag_key: str, return_value: Any):
|
||||
"""Context manager for testing feature flags."""
|
||||
|
||||
@@ -1,10 +1,26 @@
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@overload
|
||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
|
||||
@@ -57,3 +73,153 @@ def thread_cached(
|
||||
def clear_thread_cache(func: Callable) -> None:
|
||||
if clear := getattr(func, "clear_cache", None):
|
||||
clear()
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class AsyncCachedFunction(Protocol):
|
||||
"""Protocol for async functions with cache management methods."""
|
||||
|
||||
def cache_clear(self) -> None:
|
||||
"""Clear all cached entries."""
|
||||
return None
|
||||
|
||||
def cache_info(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
return {}
|
||||
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call the cached function."""
|
||||
return None
|
||||
|
||||
|
||||
def async_ttl_cache(
|
||||
maxsize: int = 128, ttl_seconds: int | None = None
|
||||
) -> Callable[[Callable[..., Awaitable[Any]]], AsyncCachedFunction]:
|
||||
"""
|
||||
TTL (Time To Live) cache decorator for async functions.
|
||||
|
||||
Similar to functools.lru_cache but works with async functions and includes optional TTL.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum number of cached entries
|
||||
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
|
||||
Example:
|
||||
# With TTL
|
||||
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
|
||||
async def api_call(param: str) -> dict:
|
||||
return {"result": param}
|
||||
|
||||
# Without TTL (permanent cache like lru_cache)
|
||||
@async_ttl_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
return {"result": param}
|
||||
"""
|
||||
|
||||
def decorator(
|
||||
async_func: Callable[..., Awaitable[Any]],
|
||||
) -> AsyncCachedFunction:
|
||||
# Cache storage - use union type to handle both cases
|
||||
cache_storage: dict[Any, Any | Tuple[Any, float]] = {}
|
||||
|
||||
@wraps(async_func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Create cache key from arguments
|
||||
key = (args, tuple(sorted(kwargs.items())))
|
||||
current_time = time.time()
|
||||
|
||||
# Check if we have a valid cached entry
|
||||
if key in cache_storage:
|
||||
if ttl_seconds is None:
|
||||
# No TTL - return cached result directly
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return cache_storage[key]
|
||||
else:
|
||||
# With TTL - check expiration
|
||||
cached_data = cache_storage[key]
|
||||
if isinstance(cached_data, tuple):
|
||||
result, timestamp = cached_data
|
||||
if current_time - timestamp < ttl_seconds:
|
||||
logger.debug(
|
||||
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
return result
|
||||
else:
|
||||
# Expired entry
|
||||
del cache_storage[key]
|
||||
logger.debug(
|
||||
f"Cache entry expired for {async_func.__name__}"
|
||||
)
|
||||
|
||||
# Cache miss or expired - fetch fresh data
|
||||
logger.debug(
|
||||
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
|
||||
)
|
||||
result = await async_func(*args, **kwargs)
|
||||
|
||||
# Store in cache
|
||||
if ttl_seconds is None:
|
||||
cache_storage[key] = result
|
||||
else:
|
||||
cache_storage[key] = (result, current_time)
|
||||
|
||||
# Simple cleanup when cache gets too large
|
||||
if len(cache_storage) > maxsize:
|
||||
# Remove oldest entries (simple FIFO cleanup)
|
||||
cutoff = maxsize // 2
|
||||
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||
for old_key in oldest_keys:
|
||||
cache_storage.pop(old_key, None)
|
||||
logger.debug(
|
||||
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
# Add cache management methods (similar to functools.lru_cache)
|
||||
def cache_clear() -> None:
|
||||
cache_storage.clear()
|
||||
|
||||
def cache_info() -> dict[str, Any]:
|
||||
return {
|
||||
"size": len(cache_storage),
|
||||
"maxsize": maxsize,
|
||||
"ttl_seconds": ttl_seconds,
|
||||
}
|
||||
|
||||
# Attach methods to wrapper
|
||||
setattr(wrapper, "cache_clear", cache_clear)
|
||||
setattr(wrapper, "cache_info", cache_info)
|
||||
|
||||
return cast(AsyncCachedFunction, wrapper)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def async_cache(
|
||||
maxsize: int = 128,
|
||||
) -> Callable[[Callable[..., Awaitable[Any]]], AsyncCachedFunction]:
|
||||
"""
|
||||
Process-level cache decorator for async functions (no TTL).
|
||||
|
||||
Similar to functools.lru_cache but works with async functions.
|
||||
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
|
||||
|
||||
Args:
|
||||
maxsize: Maximum number of cached entries
|
||||
|
||||
Returns:
|
||||
Decorator function
|
||||
|
||||
Example:
|
||||
@async_cache(maxsize=1000)
|
||||
async def expensive_computation(param: str) -> dict:
|
||||
# Expensive computation here
|
||||
return {"result": param}
|
||||
"""
|
||||
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||
|
||||
@@ -16,7 +16,12 @@ from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
|
||||
from autogpt_libs.utils.cache import (
|
||||
async_cache,
|
||||
async_ttl_cache,
|
||||
clear_thread_cache,
|
||||
thread_cached,
|
||||
)
|
||||
|
||||
|
||||
class TestThreadCached:
|
||||
@@ -323,3 +328,378 @@ class TestThreadCached:
|
||||
|
||||
assert function_using_mock(2) == 42
|
||||
assert mock.call_count == 2
|
||||
|
||||
|
||||
class TestAsyncTTLCache:
|
||||
"""Tests for the @async_ttl_cache decorator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching(self):
|
||||
"""Test basic caching functionality."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Different args - should call function again
|
||||
result3 = await cached_function(2, 3)
|
||||
assert result3 == 5
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_expiration(self):
|
||||
"""Test that cache entries expire after TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def short_lived_cache(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 2
|
||||
|
||||
# First call
|
||||
result1 = await short_lived_cache(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Second call immediately - should use cache
|
||||
result2 = await short_lived_cache(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Third call after expiration - should call function again
|
||||
result3 = await short_lived_cache(5)
|
||||
assert result3 == 10
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_info(self):
|
||||
"""Test cache info functionality."""
|
||||
|
||||
@async_ttl_cache(maxsize=5, ttl_seconds=300)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] == 300
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_clear(self):
|
||||
"""Test cache clearing functionality."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def clearable_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x * 4
|
||||
|
||||
# First call
|
||||
result1 = await clearable_function(2)
|
||||
assert result1 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Second call - should use cache
|
||||
result2 = await clearable_function(2)
|
||||
assert result2 == 8
|
||||
assert call_count == 1
|
||||
|
||||
# Clear cache
|
||||
clearable_function.cache_clear()
|
||||
|
||||
# Third call after clear - should call function again
|
||||
result3 = await clearable_function(2)
|
||||
assert result3 == 8
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maxsize_cleanup(self):
|
||||
"""Test that cache cleans up when maxsize is exceeded."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=3, ttl_seconds=60)
|
||||
async def size_limited_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
# Fill cache to maxsize
|
||||
await size_limited_function(1) # call_count: 1
|
||||
await size_limited_function(2) # call_count: 2
|
||||
await size_limited_function(3) # call_count: 3
|
||||
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] == 3
|
||||
|
||||
# Add one more entry - should trigger cleanup
|
||||
await size_limited_function(4) # call_count: 4
|
||||
|
||||
# Cache size should be reduced (cleanup removes oldest entries)
|
||||
info = size_limited_function.cache_info()
|
||||
assert info["size"] <= 3 # Should be cleaned up
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_argument_variations(self):
|
||||
"""Test caching with different argument patterns."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"{a}-{b}-{c}"
|
||||
|
||||
# Different ways to call with same logical arguments
|
||||
result1 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
|
||||
# Same arguments, same order - should use cache
|
||||
result2 = await arg_test_function(1, "test", c=200)
|
||||
assert call_count == 1
|
||||
assert result1 == result2
|
||||
|
||||
# Different arguments - should call function
|
||||
result3 = await arg_test_function(2, "test", c=200)
|
||||
assert call_count == 2
|
||||
assert result1 != result3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_handling(self):
|
||||
"""Test that exceptions are not cached."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def exception_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if x < 0:
|
||||
raise ValueError("Negative value not allowed")
|
||||
return x * 2
|
||||
|
||||
# Successful call - should be cached
|
||||
result1 = await exception_function(5)
|
||||
assert result1 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Same successful call - should use cache
|
||||
result2 = await exception_function(5)
|
||||
assert result2 == 10
|
||||
assert call_count == 1
|
||||
|
||||
# Exception call - should not be cached
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
assert call_count == 2
|
||||
|
||||
# Same exception call - should call again (not cached)
|
||||
with pytest.raises(ValueError):
|
||||
await exception_function(-1)
|
||||
assert call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_calls(self):
|
||||
"""Test caching behavior with concurrent calls."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||
async def concurrent_function(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.05) # Simulate work
|
||||
return x * x
|
||||
|
||||
# Launch concurrent calls with same arguments
|
||||
tasks = [concurrent_function(3) for _ in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All results should be the same
|
||||
assert all(result == 9 for result in results)
|
||||
|
||||
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
|
||||
# This tests that the cache doesn't break under concurrent access
|
||||
assert 1 <= call_count <= 5
|
||||
|
||||
|
||||
class TestAsyncCache:
|
||||
"""Tests for the @async_cache decorator (no TTL)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_caching_no_ttl(self):
|
||||
"""Test basic caching functionality without TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_cache(maxsize=10)
|
||||
async def cached_function(x: int, y: int = 0) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
await asyncio.sleep(0.01) # Simulate async work
|
||||
return x + y
|
||||
|
||||
# First call
|
||||
result1 = await cached_function(1, 2)
|
||||
assert result1 == 3
|
||||
assert call_count == 1
|
||||
|
||||
# Second call with same args - should use cache
|
||||
result2 = await cached_function(1, 2)
|
||||
assert result2 == 3
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Third call after some time - should still use cache (no TTL)
|
||||
await asyncio.sleep(0.05)
|
||||
result3 = await cached_function(1, 2)
|
||||
assert result3 == 3
|
||||
assert call_count == 1 # Still no additional call
|
||||
|
||||
# Different args - should call function again
|
||||
result4 = await cached_function(2, 3)
|
||||
assert result4 == 5
|
||||
assert call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_ttl_vs_ttl_behavior(self):
|
||||
"""Test the difference between TTL and no-TTL caching."""
|
||||
ttl_call_count = 0
|
||||
no_ttl_call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_call_count
|
||||
ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
@async_cache(maxsize=10) # No TTL
|
||||
async def no_ttl_function(x: int) -> int:
|
||||
nonlocal no_ttl_call_count
|
||||
no_ttl_call_count += 1
|
||||
return x * 2
|
||||
|
||||
# First calls
|
||||
await ttl_function(5)
|
||||
await no_ttl_function(5)
|
||||
assert ttl_call_count == 1
|
||||
assert no_ttl_call_count == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# Second calls after TTL expiry
|
||||
await ttl_function(5) # Should call function again (TTL expired)
|
||||
await no_ttl_function(5) # Should use cache (no TTL)
|
||||
assert ttl_call_count == 2 # TTL function called again
|
||||
assert no_ttl_call_count == 1 # No-TTL function still cached
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_cache_info(self):
|
||||
"""Test cache info for no-TTL cache."""
|
||||
|
||||
@async_cache(maxsize=5)
|
||||
async def info_test_function(x: int) -> int:
|
||||
return x * 3
|
||||
|
||||
# Check initial cache info
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 0
|
||||
assert info["maxsize"] == 5
|
||||
assert info["ttl_seconds"] is None # No TTL
|
||||
|
||||
# Add an entry
|
||||
await info_test_function(1)
|
||||
info = info_test_function.cache_info()
|
||||
assert info["size"] == 1
|
||||
|
||||
|
||||
class TestTTLOptional:
|
||||
"""Tests for optional TTL functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ttl_none_behavior(self):
|
||||
"""Test that ttl_seconds=None works like no TTL."""
|
||||
call_count = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=None)
|
||||
async def no_ttl_via_none(x: int) -> int:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return x**2
|
||||
|
||||
# First call
|
||||
result1 = await no_ttl_via_none(3)
|
||||
assert result1 == 9
|
||||
assert call_count == 1
|
||||
|
||||
# Wait (would expire if there was TTL)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Second call - should still use cache
|
||||
result2 = await no_ttl_via_none(3)
|
||||
assert result2 == 9
|
||||
assert call_count == 1 # No additional call
|
||||
|
||||
# Check cache info
|
||||
info = no_ttl_via_none.cache_info()
|
||||
assert info["ttl_seconds"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_options_comparison(self):
|
||||
"""Test different cache options work as expected."""
|
||||
ttl_calls = 0
|
||||
no_ttl_calls = 0
|
||||
|
||||
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
|
||||
async def ttl_function(x: int) -> int:
|
||||
nonlocal ttl_calls
|
||||
ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
@async_cache(maxsize=10) # Process-level cache (no TTL)
|
||||
async def process_function(x: int) -> int:
|
||||
nonlocal no_ttl_calls
|
||||
no_ttl_calls += 1
|
||||
return x * 10
|
||||
|
||||
# Both should cache initially
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Immediate second calls - both should use cache
|
||||
await ttl_function(3)
|
||||
await process_function(3)
|
||||
assert ttl_calls == 1
|
||||
assert no_ttl_calls == 1
|
||||
|
||||
# Wait for TTL to expire
|
||||
await asyncio.sleep(1.1)
|
||||
|
||||
# After TTL expiry
|
||||
await ttl_function(3) # Should call function again
|
||||
await process_function(3) # Should still use cache
|
||||
assert ttl_calls == 2 # TTL cache expired, called again
|
||||
assert no_ttl_calls == 1 # Process cache never expires
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from prisma.models import User
|
||||
|
||||
from backend.data.model import ProviderName
|
||||
from backend.data.model import ProviderName, User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
|
||||
@@ -998,8 +998,8 @@ def get_block_costs() -> dict[str, list[BlockCost]]:
|
||||
async def get_stripe_customer_id(user_id: str) -> str:
|
||||
user = await get_user_by_id(user_id)
|
||||
|
||||
if user.stripeCustomerId:
|
||||
return user.stripeCustomerId
|
||||
if user.stripe_customer_id:
|
||||
return user.stripe_customer_id
|
||||
|
||||
customer = stripe.Customer.create(
|
||||
name=user.name or "",
|
||||
@@ -1022,10 +1022,10 @@ async def set_auto_top_up(user_id: str, config: AutoTopUpConfig):
|
||||
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
|
||||
user = await get_user_by_id(user_id)
|
||||
|
||||
if not user.topUpConfig:
|
||||
if not user.top_up_config:
|
||||
return AutoTopUpConfig(threshold=0, amount=0)
|
||||
|
||||
return AutoTopUpConfig.model_validate(user.topUpConfig)
|
||||
return AutoTopUpConfig.model_validate(user.top_up_config)
|
||||
|
||||
|
||||
async def admin_get_user_history(
|
||||
|
||||
@@ -5,6 +5,7 @@ import enum
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from json import JSONDecodeError
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
@@ -40,12 +41,120 @@ from pydantic_core import (
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.json import loads as json_loads
|
||||
from backend.util.settings import Secrets
|
||||
|
||||
# Type alias for any provider name (including custom ones)
|
||||
AnyProviderName = str # Will be validated as ProviderName at runtime
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""Application-layer User model with snake_case convention."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
str_strip_whitespace=True,
|
||||
)
|
||||
|
||||
id: str = Field(..., description="User ID")
|
||||
email: str = Field(..., description="User email address")
|
||||
email_verified: bool = Field(default=True, description="Whether email is verified")
|
||||
name: Optional[str] = Field(None, description="User display name")
|
||||
created_at: datetime = Field(..., description="When user was created")
|
||||
updated_at: datetime = Field(..., description="When user was last updated")
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict, description="User metadata as dict"
|
||||
)
|
||||
integrations: str = Field(default="", description="Encrypted integrations data")
|
||||
stripe_customer_id: Optional[str] = Field(None, description="Stripe customer ID")
|
||||
top_up_config: Optional["AutoTopUpConfig"] = Field(
|
||||
None, description="Top up configuration"
|
||||
)
|
||||
|
||||
# Notification preferences
|
||||
max_emails_per_day: int = Field(default=3, description="Maximum emails per day")
|
||||
notify_on_agent_run: bool = Field(default=True, description="Notify on agent run")
|
||||
notify_on_zero_balance: bool = Field(
|
||||
default=True, description="Notify on zero balance"
|
||||
)
|
||||
notify_on_low_balance: bool = Field(
|
||||
default=True, description="Notify on low balance"
|
||||
)
|
||||
notify_on_block_execution_failed: bool = Field(
|
||||
default=True, description="Notify on block execution failure"
|
||||
)
|
||||
notify_on_continuous_agent_error: bool = Field(
|
||||
default=True, description="Notify on continuous agent error"
|
||||
)
|
||||
notify_on_daily_summary: bool = Field(
|
||||
default=True, description="Notify on daily summary"
|
||||
)
|
||||
notify_on_weekly_summary: bool = Field(
|
||||
default=True, description="Notify on weekly summary"
|
||||
)
|
||||
notify_on_monthly_summary: bool = Field(
|
||||
default=True, description="Notify on monthly summary"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db(cls, prisma_user: "PrismaUser") -> "User":
|
||||
"""Convert a database User object to application User model."""
|
||||
# Handle metadata field - convert from JSON string or dict to dict
|
||||
metadata = {}
|
||||
if prisma_user.metadata:
|
||||
if isinstance(prisma_user.metadata, str):
|
||||
try:
|
||||
metadata = json_loads(prisma_user.metadata)
|
||||
except (JSONDecodeError, TypeError):
|
||||
metadata = {}
|
||||
elif isinstance(prisma_user.metadata, dict):
|
||||
metadata = prisma_user.metadata
|
||||
|
||||
# Handle topUpConfig field
|
||||
top_up_config = None
|
||||
if prisma_user.topUpConfig:
|
||||
if isinstance(prisma_user.topUpConfig, str):
|
||||
try:
|
||||
config_dict = json_loads(prisma_user.topUpConfig)
|
||||
top_up_config = AutoTopUpConfig.model_validate(config_dict)
|
||||
except (JSONDecodeError, TypeError, ValueError):
|
||||
top_up_config = None
|
||||
elif isinstance(prisma_user.topUpConfig, dict):
|
||||
try:
|
||||
top_up_config = AutoTopUpConfig.model_validate(
|
||||
prisma_user.topUpConfig
|
||||
)
|
||||
except ValueError:
|
||||
top_up_config = None
|
||||
|
||||
return cls(
|
||||
id=prisma_user.id,
|
||||
email=prisma_user.email,
|
||||
email_verified=prisma_user.emailVerified or True,
|
||||
name=prisma_user.name,
|
||||
created_at=prisma_user.createdAt,
|
||||
updated_at=prisma_user.updatedAt,
|
||||
metadata=metadata,
|
||||
integrations=prisma_user.integrations or "",
|
||||
stripe_customer_id=prisma_user.stripeCustomerId,
|
||||
top_up_config=top_up_config,
|
||||
max_emails_per_day=prisma_user.maxEmailsPerDay or 3,
|
||||
notify_on_agent_run=prisma_user.notifyOnAgentRun or True,
|
||||
notify_on_zero_balance=prisma_user.notifyOnZeroBalance or True,
|
||||
notify_on_low_balance=prisma_user.notifyOnLowBalance or True,
|
||||
notify_on_block_execution_failed=prisma_user.notifyOnBlockExecutionFailed
|
||||
or True,
|
||||
notify_on_continuous_agent_error=prisma_user.notifyOnContinuousAgentError
|
||||
or True,
|
||||
notify_on_daily_summary=prisma_user.notifyOnDailySummary or True,
|
||||
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
|
||||
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from prisma.models import User as PrismaUser
|
||||
|
||||
from backend.data.block import BlockSchema
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -9,11 +9,11 @@ from urllib.parse import quote_plus
|
||||
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
||||
from fastapi import HTTPException
|
||||
from prisma.enums import NotificationType
|
||||
from prisma.models import User
|
||||
from prisma.models import User as PrismaUser
|
||||
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
|
||||
|
||||
from backend.data.db import prisma
|
||||
from backend.data.model import UserIntegrations, UserMetadata
|
||||
from backend.data.model import User, UserIntegrations, UserMetadata
|
||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||
from backend.server.v2.store.exceptions import DatabaseError
|
||||
from backend.util.encryption import JSONCryptor
|
||||
@@ -44,7 +44,7 @@ async def get_or_create_user(user_data: dict) -> User:
|
||||
)
|
||||
)
|
||||
|
||||
return User.model_validate(user)
|
||||
return User.from_db(user)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e
|
||||
|
||||
@@ -53,7 +53,7 @@ async def get_user_by_id(user_id: str) -> User:
|
||||
user = await prisma.user.find_unique(where={"id": user_id})
|
||||
if not user:
|
||||
raise ValueError(f"User not found with ID: {user_id}")
|
||||
return User.model_validate(user)
|
||||
return User.from_db(user)
|
||||
|
||||
|
||||
async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
||||
@@ -67,7 +67,7 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
||||
async def get_user_by_email(email: str) -> Optional[User]:
|
||||
try:
|
||||
user = await prisma.user.find_unique(where={"email": email})
|
||||
return User.model_validate(user) if user else None
|
||||
return User.from_db(user) if user else None
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to get user by email {email}: {e}") from e
|
||||
|
||||
@@ -91,11 +91,11 @@ async def create_default_user() -> Optional[User]:
|
||||
name="Default User",
|
||||
)
|
||||
)
|
||||
return User.model_validate(user)
|
||||
return User.from_db(user)
|
||||
|
||||
|
||||
async def get_user_integrations(user_id: str) -> UserIntegrations:
|
||||
user = await User.prisma().find_unique_or_raise(
|
||||
user = await PrismaUser.prisma().find_unique_or_raise(
|
||||
where={"id": user_id},
|
||||
)
|
||||
|
||||
@@ -110,7 +110,7 @@ async def get_user_integrations(user_id: str) -> UserIntegrations:
|
||||
|
||||
async def update_user_integrations(user_id: str, data: UserIntegrations):
|
||||
encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True))
|
||||
await User.prisma().update(
|
||||
await PrismaUser.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"integrations": encrypted_data},
|
||||
)
|
||||
@@ -118,7 +118,7 @@ async def update_user_integrations(user_id: str, data: UserIntegrations):
|
||||
|
||||
async def migrate_and_encrypt_user_integrations():
|
||||
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
|
||||
users = await User.prisma().find_many(
|
||||
users = await PrismaUser.prisma().find_many(
|
||||
where={
|
||||
"metadata": cast(
|
||||
JsonFilter,
|
||||
@@ -154,7 +154,7 @@ async def migrate_and_encrypt_user_integrations():
|
||||
raw_metadata.pop("integration_oauth_states", None)
|
||||
|
||||
# Update metadata without integration data
|
||||
await User.prisma().update(
|
||||
await PrismaUser.prisma().update(
|
||||
where={"id": user.id},
|
||||
data={"metadata": SafeJson(raw_metadata)},
|
||||
)
|
||||
@@ -162,7 +162,7 @@ async def migrate_and_encrypt_user_integrations():
|
||||
|
||||
async def get_active_user_ids_in_timerange(start_time: str, end_time: str) -> list[str]:
|
||||
try:
|
||||
users = await User.prisma().find_many(
|
||||
users = await PrismaUser.prisma().find_many(
|
||||
where={
|
||||
"AgentGraphExecutions": {
|
||||
"some": {
|
||||
@@ -192,7 +192,7 @@ async def get_active_users_ids() -> list[str]:
|
||||
|
||||
async def get_user_notification_preference(user_id: str) -> NotificationPreference:
|
||||
try:
|
||||
user = await User.prisma().find_unique_or_raise(
|
||||
user = await PrismaUser.prisma().find_unique_or_raise(
|
||||
where={"id": user_id},
|
||||
)
|
||||
|
||||
@@ -269,7 +269,7 @@ async def update_user_notification_preference(
|
||||
if data.daily_limit:
|
||||
update_data["maxEmailsPerDay"] = data.daily_limit
|
||||
|
||||
user = await User.prisma().update(
|
||||
user = await PrismaUser.prisma().update(
|
||||
where={"id": user_id},
|
||||
data=update_data,
|
||||
)
|
||||
@@ -307,7 +307,7 @@ async def update_user_notification_preference(
|
||||
async def set_user_email_verification(user_id: str, verified: bool) -> None:
|
||||
"""Set the email verification status for a user."""
|
||||
try:
|
||||
await User.prisma().update(
|
||||
await PrismaUser.prisma().update(
|
||||
where={"id": user_id},
|
||||
data={"emailVerified": verified},
|
||||
)
|
||||
@@ -320,7 +320,7 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
|
||||
async def get_user_email_verification(user_id: str) -> bool:
|
||||
"""Get the email verification status for a user."""
|
||||
try:
|
||||
user = await User.prisma().find_unique_or_raise(
|
||||
user = await PrismaUser.prisma().find_unique_or_raise(
|
||||
where={"id": user_id},
|
||||
)
|
||||
return user.emailVerified
|
||||
|
||||
@@ -102,8 +102,10 @@ async def generate_activity_status_for_execution(
|
||||
Returns:
|
||||
AI-generated activity status string, or None if feature is disabled
|
||||
"""
|
||||
# Check LaunchDarkly feature flag for AI activity status generation
|
||||
if not is_feature_enabled(AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False):
|
||||
# Check LaunchDarkly feature flag for AI activity status generation with full context support
|
||||
if not await is_feature_enabled(
|
||||
AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False
|
||||
):
|
||||
logger.debug("AI activity status generation is disabled via LaunchDarkly")
|
||||
return None
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ from backend.data.notifications import (
|
||||
)
|
||||
from backend.data.user import (
|
||||
get_active_user_ids_in_timerange,
|
||||
get_user_by_id,
|
||||
get_user_email_by_id,
|
||||
get_user_email_verification,
|
||||
get_user_integrations,
|
||||
@@ -129,6 +130,7 @@ class DatabaseManager(AppService):
|
||||
|
||||
# User Comms - async
|
||||
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
|
||||
get_user_by_id = _(get_user_by_id)
|
||||
get_user_email_by_id = _(get_user_email_by_id)
|
||||
get_user_email_verification = _(get_user_email_verification)
|
||||
get_user_notification_preference = _(get_user_notification_preference)
|
||||
@@ -201,6 +203,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
|
||||
# User Comms
|
||||
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
|
||||
get_user_by_id = d.get_user_by_id
|
||||
get_user_email_by_id = d.get_user_email_by_id
|
||||
get_user_email_verification = d.get_user_email_verification
|
||||
get_user_notification_preference = d.get_user_notification_preference
|
||||
|
||||
@@ -3,7 +3,6 @@ import logging
|
||||
import autogpt_libs.auth.models
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
from prisma.models import User
|
||||
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.store.model
|
||||
@@ -12,6 +11,7 @@ from backend.blocks.data_manipulation import FindInDictionaryBlock
|
||||
from backend.blocks.io import AgentInputBlock
|
||||
from backend.blocks.maths import CalculatorBlock, Operation
|
||||
from backend.data import execution, graph
|
||||
from backend.data.model import User
|
||||
from backend.server.model import CreateGraph
|
||||
from backend.server.rest_api import AgentServer
|
||||
from backend.usecases.sample import create_test_graph, create_test_user
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from pathlib import Path
|
||||
|
||||
from prisma.models import User
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.block import BlockInstallationBlock
|
||||
from backend.blocks.http import SendWebRequestBlock
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.blocks.text import ExtractTextInformationBlock, FillTextTemplateBlock
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import User
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from prisma.models import User
|
||||
|
||||
from backend.blocks.llm import AIStructuredResponseGeneratorBlock
|
||||
from backend.blocks.reddit import GetRedditPostsBlock, PostRedditCommentBlock
|
||||
from backend.blocks.text import FillTextTemplateBlock, MatchTextPatternBlock
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import User
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from prisma.models import User
|
||||
|
||||
from backend.blocks.basic import StoreValueBlock
|
||||
from backend.blocks.io import AgentInputBlock
|
||||
from backend.blocks.text import FillTextTemplateBlock
|
||||
from backend.data import graph
|
||||
from backend.data.graph import create_graph
|
||||
from backend.data.model import User
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.util.test import SpinTestServer, wait_execution
|
||||
|
||||
|
||||
@@ -45,6 +45,34 @@ api_comm_retry = config.pyro_client_comm_retry
|
||||
api_comm_timeout = config.pyro_client_comm_timeout
|
||||
api_call_timeout = config.rpc_client_call_timeout
|
||||
|
||||
|
||||
def _validate_no_prisma_objects(obj: Any, path: str = "result") -> None:
|
||||
"""
|
||||
Recursively validate that no Prisma objects are being returned from service methods.
|
||||
This enforces proper separation of layers - only application models should cross service boundaries.
|
||||
"""
|
||||
if obj is None:
|
||||
return
|
||||
|
||||
# Check if it's a Prisma model object
|
||||
if hasattr(obj, "__class__") and hasattr(obj.__class__, "__module__"):
|
||||
module_name = obj.__class__.__module__
|
||||
if module_name and "prisma.models" in module_name:
|
||||
raise ValueError(
|
||||
f"Prisma object {obj.__class__.__name__} found in {path}. "
|
||||
"Service methods must return application models, not Prisma objects. "
|
||||
f"Use {obj.__class__.__name__}.from_db() to convert to application model."
|
||||
)
|
||||
|
||||
# Recursively check collections
|
||||
if isinstance(obj, (list, tuple)):
|
||||
for i, item in enumerate(obj):
|
||||
_validate_no_prisma_objects(item, f"{path}[{i}]")
|
||||
elif isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
_validate_no_prisma_objects(value, f"{path}['{key}']")
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
EXPOSED_FLAG = "__exposed__"
|
||||
@@ -111,6 +139,22 @@ class UnhealthyServiceError(ValueError):
|
||||
return self.message
|
||||
|
||||
|
||||
class HTTPClientError(Exception):
|
||||
"""Exception for HTTP client errors (4xx status codes) that should not be retried."""
|
||||
|
||||
def __init__(self, status_code: int, message: str):
|
||||
self.status_code = status_code
|
||||
super().__init__(f"HTTP {status_code}: {message}")
|
||||
|
||||
|
||||
class HTTPServerError(Exception):
|
||||
"""Exception for HTTP server errors (5xx status codes) that can be retried."""
|
||||
|
||||
def __init__(self, status_code: int, message: str):
|
||||
self.status_code = status_code
|
||||
super().__init__(f"HTTP {status_code}: {message}")
|
||||
|
||||
|
||||
EXCEPTION_MAPPING = {
|
||||
e.__name__: e
|
||||
for e in [
|
||||
@@ -119,6 +163,8 @@ EXCEPTION_MAPPING = {
|
||||
TimeoutError,
|
||||
ConnectionError,
|
||||
UnhealthyServiceError,
|
||||
HTTPClientError,
|
||||
HTTPServerError,
|
||||
*[
|
||||
ErrorType
|
||||
for _, ErrorType in inspect.getmembers(exceptions)
|
||||
@@ -191,17 +237,21 @@ class AppService(BaseAppService, ABC):
|
||||
if asyncio.iscoroutinefunction(f):
|
||||
|
||||
async def async_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
|
||||
return await f(
|
||||
result = await f(
|
||||
**{name: getattr(body, name) for name in type(body).model_fields}
|
||||
)
|
||||
_validate_no_prisma_objects(result, f"{func.__name__} result")
|
||||
return result
|
||||
|
||||
return async_endpoint
|
||||
else:
|
||||
|
||||
def sync_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
|
||||
return f(
|
||||
result = f(
|
||||
**{name: getattr(body, name) for name in type(body).model_fields}
|
||||
)
|
||||
_validate_no_prisma_objects(result, f"{func.__name__} result")
|
||||
return result
|
||||
|
||||
return sync_endpoint
|
||||
|
||||
@@ -313,6 +363,7 @@ def get_service_client(
|
||||
AttributeError, # Missing attributes
|
||||
asyncio.CancelledError, # Task was cancelled
|
||||
concurrent.futures.CancelledError, # Future was cancelled
|
||||
HTTPClientError, # HTTP 4xx client errors - don't retry
|
||||
),
|
||||
)(fn)
|
||||
|
||||
@@ -390,11 +441,31 @@ def get_service_client(
|
||||
self._connection_failure_count = 0
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error = RemoteCallError.model_validate(e.response.json())
|
||||
# DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception
|
||||
raise EXCEPTION_MAPPING.get(error.type, Exception)(
|
||||
*(error.args or [str(e)])
|
||||
)
|
||||
status_code = e.response.status_code
|
||||
|
||||
# Try to parse the error response as RemoteCallError for mapped exceptions
|
||||
error_response = None
|
||||
try:
|
||||
error_response = RemoteCallError.model_validate(e.response.json())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If we successfully parsed a mapped exception type, re-raise it
|
||||
if error_response and error_response.type in EXCEPTION_MAPPING:
|
||||
exception_class = EXCEPTION_MAPPING[error_response.type]
|
||||
args = error_response.args or [str(e)]
|
||||
raise exception_class(*args)
|
||||
|
||||
# Otherwise categorize by HTTP status code
|
||||
if 400 <= status_code < 500:
|
||||
# Client errors (4xx) - wrap to prevent retries
|
||||
raise HTTPClientError(status_code, str(e))
|
||||
elif 500 <= status_code < 600:
|
||||
# Server errors (5xx) - wrap but allow retries
|
||||
raise HTTPServerError(status_code, str(e))
|
||||
else:
|
||||
# Other status codes (1xx, 2xx, 3xx) - re-raise original error
|
||||
raise e
|
||||
|
||||
@_maybe_retry
|
||||
def _call_method_sync(self, method_name: str, **kwargs: Any) -> Any:
|
||||
|
||||
@@ -8,6 +8,8 @@ import pytest
|
||||
from backend.util.service import (
|
||||
AppService,
|
||||
AppServiceClient,
|
||||
HTTPClientError,
|
||||
HTTPServerError,
|
||||
endpoint_to_async,
|
||||
expose,
|
||||
get_service_client,
|
||||
@@ -366,3 +368,125 @@ def test_service_no_retry_when_disabled(server):
|
||||
# This should fail immediately without retry
|
||||
with pytest.raises(RuntimeError, match="Intended error for testing"):
|
||||
client.always_failing_add(5, 3)
|
||||
|
||||
|
||||
class TestHTTPErrorRetryBehavior:
|
||||
"""Test that HTTP client errors (4xx) are not retried but server errors (5xx) can be."""
|
||||
|
||||
# Note: These tests access private methods for testing internal behavior
|
||||
# Type ignore comments are used to suppress warnings about accessing private methods
|
||||
|
||||
def test_http_client_error_not_retried(self):
|
||||
"""Test that 4xx errors are wrapped as HTTPClientError and not retried."""
|
||||
# Create a mock response with 404 status
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 404
|
||||
mock_response.json.return_value = {"message": "Not found"}
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"404 Not Found", request=Mock(), response=mock_response
|
||||
)
|
||||
|
||||
# Create client
|
||||
client = get_service_client(ServiceTestClient)
|
||||
dynamic_client = client
|
||||
|
||||
# Test the _handle_call_method_response directly
|
||||
with pytest.raises(HTTPClientError) as exc_info:
|
||||
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
|
||||
response=mock_response, method_name="test_method"
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert "404" in str(exc_info.value)
|
||||
|
||||
def test_http_server_error_can_be_retried(self):
|
||||
"""Test that 5xx errors are wrapped as HTTPServerError and can be retried."""
|
||||
# Create a mock response with 500 status
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.json.return_value = {"message": "Internal server error"}
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"500 Internal Server Error", request=Mock(), response=mock_response
|
||||
)
|
||||
|
||||
# Create client
|
||||
client = get_service_client(ServiceTestClient)
|
||||
dynamic_client = client
|
||||
|
||||
# Test the _handle_call_method_response directly
|
||||
with pytest.raises(HTTPServerError) as exc_info:
|
||||
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
|
||||
response=mock_response, method_name="test_method"
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "500" in str(exc_info.value)
|
||||
|
||||
def test_mapped_exception_preserves_original_type(self):
|
||||
"""Test that mapped exceptions preserve their original type regardless of HTTP status."""
|
||||
# Create a mock response with ValueError in the remote call error
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 400
|
||||
mock_response.json.return_value = {
|
||||
"type": "ValueError",
|
||||
"args": ["Invalid parameter value"],
|
||||
}
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"400 Bad Request", request=Mock(), response=mock_response
|
||||
)
|
||||
|
||||
# Create client
|
||||
client = get_service_client(ServiceTestClient)
|
||||
dynamic_client = client
|
||||
|
||||
# Test the _handle_call_method_response directly
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
|
||||
response=mock_response, method_name="test_method"
|
||||
)
|
||||
|
||||
assert "Invalid parameter value" in str(exc_info.value)
|
||||
|
||||
def test_client_error_status_codes_coverage(self):
|
||||
"""Test that various 4xx status codes are all wrapped as HTTPClientError."""
|
||||
client_error_codes = [400, 401, 403, 404, 405, 409, 422, 429]
|
||||
|
||||
for status_code in client_error_codes:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.json.return_value = {"message": f"Error {status_code}"}
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
f"{status_code} Error", request=Mock(), response=mock_response
|
||||
)
|
||||
|
||||
client = get_service_client(ServiceTestClient)
|
||||
dynamic_client = client
|
||||
|
||||
with pytest.raises(HTTPClientError) as exc_info:
|
||||
dynamic_client._handle_call_method_response( # type: ignore
|
||||
response=mock_response, method_name="test_method"
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status_code
|
||||
|
||||
def test_server_error_status_codes_coverage(self):
|
||||
"""Test that various 5xx status codes are all wrapped as HTTPServerError."""
|
||||
server_error_codes = [500, 501, 502, 503, 504, 505]
|
||||
|
||||
for status_code in server_error_codes:
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = status_code
|
||||
mock_response.json.return_value = {"message": f"Error {status_code}"}
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
f"{status_code} Error", request=Mock(), response=mock_response
|
||||
)
|
||||
|
||||
client = get_service_client(ServiceTestClient)
|
||||
dynamic_client = client
|
||||
|
||||
with pytest.raises(HTTPServerError) as exc_info:
|
||||
dynamic_client._handle_call_method_response( # type: ignore
|
||||
response=mock_response, method_name="test_method"
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status_code
|
||||
|
||||
Reference in New Issue
Block a user