mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-21 04:57:58 -05:00
Compare commits
6 Commits
testing-cl
...
feat/launc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5899897d8f | ||
|
|
0389c865aa | ||
|
|
b88576d313 | ||
|
|
14634a6ce9 | ||
|
|
10a402a766 | ||
|
|
004011726d |
@@ -1,10 +1,14 @@
|
|||||||
import asyncio
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
|
from json import JSONDecodeError
|
||||||
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, TypeVar
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from backend.data.model import User
|
||||||
|
|
||||||
import ldclient
|
import ldclient
|
||||||
|
from backend.util.json import loads as json_loads
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from ldclient import Context, LDClient
|
from ldclient import Context, LDClient
|
||||||
from ldclient.config import Config
|
from ldclient.config import Config
|
||||||
@@ -56,32 +60,202 @@ def shutdown_launchdarkly() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def create_context(
|
def create_context(
|
||||||
user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
|
user_id: str, additional_attributes: Optional[dict[str, Any]] = None
|
||||||
) -> Context:
|
) -> Context:
|
||||||
"""Create LaunchDarkly context with optional additional attributes."""
|
"""Create LaunchDarkly context with optional additional attributes."""
|
||||||
builder = Context.builder(str(user_id)).kind("user")
|
# Use the key from attributes if provided, otherwise use user_id
|
||||||
|
context_key = user_id
|
||||||
|
if additional_attributes and "key" in additional_attributes:
|
||||||
|
context_key = additional_attributes["key"]
|
||||||
|
|
||||||
|
builder = Context.builder(str(context_key)).kind("user")
|
||||||
|
|
||||||
if additional_attributes:
|
if additional_attributes:
|
||||||
for key, value in additional_attributes.items():
|
for key, value in additional_attributes.items():
|
||||||
builder.set(key, value)
|
# Skip kind and key as they're already set
|
||||||
|
if key in ["kind", "key"]:
|
||||||
|
continue
|
||||||
|
elif key == "custom" and isinstance(value, dict):
|
||||||
|
# Handle custom attributes object - these go as individual attributes
|
||||||
|
for custom_key, custom_value in value.items():
|
||||||
|
builder.set(custom_key, custom_value)
|
||||||
|
else:
|
||||||
|
builder.set(key, value)
|
||||||
return builder.build()
|
return builder.build()
|
||||||
|
|
||||||
|
|
||||||
def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bool:
|
async def _fetch_user_context_data(user_id: str) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Simple helper to check if a feature flag is enabled for a user.
|
Fetch user data and build LaunchDarkly context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID to fetch data for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with user context data including role
|
||||||
|
"""
|
||||||
|
# Use the unified database access approach
|
||||||
|
from backend.util.clients import get_database_manager_async_client
|
||||||
|
|
||||||
|
db_client = get_database_manager_async_client()
|
||||||
|
user = await db_client.get_user_by_id(user_id)
|
||||||
|
|
||||||
|
# Build LaunchDarkly context from user data
|
||||||
|
return _build_launchdarkly_context(user)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_launchdarkly_context(user: "User") -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Build LaunchDarkly context data matching frontend format.
|
||||||
|
|
||||||
|
Returns a context like:
|
||||||
|
{
|
||||||
|
"kind": "user",
|
||||||
|
"key": "user-id",
|
||||||
|
"email": "user@example.com", # Optional
|
||||||
|
"anonymous": false,
|
||||||
|
"custom": {
|
||||||
|
"role": "admin" # Optional
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: User object from database
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with user context data
|
||||||
|
"""
|
||||||
|
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
||||||
|
|
||||||
|
# Build basic context - always include kind, key, and anonymous
|
||||||
|
context_data: dict[str, Any] = {
|
||||||
|
"kind": "user",
|
||||||
|
"key": user.id,
|
||||||
|
"anonymous": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add email if present
|
||||||
|
if user.email:
|
||||||
|
context_data["email"] = user.email
|
||||||
|
|
||||||
|
# Initialize custom attributes
|
||||||
|
custom: dict[str, Any] = {}
|
||||||
|
|
||||||
|
# Determine user role from metadata
|
||||||
|
role = None
|
||||||
|
|
||||||
|
# Check if user is default/system user
|
||||||
|
if user.id == DEFAULT_USER_ID:
|
||||||
|
role = "admin" # Default user has admin privileges when auth is disabled
|
||||||
|
elif user.metadata:
|
||||||
|
# Check for role in metadata
|
||||||
|
try:
|
||||||
|
# Handle both string (direct DB) and dict (RPC) formats
|
||||||
|
if isinstance(user.metadata, str):
|
||||||
|
metadata = json_loads(user.metadata)
|
||||||
|
elif isinstance(user.metadata, dict):
|
||||||
|
metadata = user.metadata
|
||||||
|
else:
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
# Extract role from metadata if present
|
||||||
|
if metadata.get("role"):
|
||||||
|
role = metadata["role"]
|
||||||
|
|
||||||
|
except (JSONDecodeError, TypeError) as e:
|
||||||
|
logger.debug(f"Failed to parse user metadata for context: {e}")
|
||||||
|
|
||||||
|
# Add role to custom attributes if present
|
||||||
|
if role:
|
||||||
|
custom["role"] = role
|
||||||
|
|
||||||
|
# Only add custom object if it has content
|
||||||
|
if custom:
|
||||||
|
context_data["custom"] = custom
|
||||||
|
|
||||||
|
return context_data
|
||||||
|
|
||||||
|
|
||||||
|
async def is_feature_enabled(
|
||||||
|
flag_key: str,
|
||||||
|
user_id: str,
|
||||||
|
default: bool = False,
|
||||||
|
use_user_id_only: bool = False,
|
||||||
|
additional_attributes: Optional[dict[str, Any]] = None,
|
||||||
|
user_role: Optional[str] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a feature flag is enabled for a user with full LaunchDarkly context support.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
flag_key: The LaunchDarkly feature flag key
|
flag_key: The LaunchDarkly feature flag key
|
||||||
user_id: The user ID to evaluate the flag for
|
user_id: The user ID to evaluate the flag for
|
||||||
default: Default value if LaunchDarkly is unavailable or flag evaluation fails
|
default: Default value if LaunchDarkly is unavailable or flag evaluation fails
|
||||||
|
use_user_id_only: If True, only use user_id without fetching database context
|
||||||
|
additional_attributes: Additional attributes to include in the context
|
||||||
|
user_role: Optional user role (e.g., "admin", "user") to add to segments
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if feature is enabled, False otherwise
|
True if feature is enabled, False otherwise
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
client = get_client()
|
client = get_client()
|
||||||
context = create_context(str(user_id))
|
|
||||||
return client.variation(flag_key, context, default)
|
if use_user_id_only:
|
||||||
|
# Simple context with just user ID (for backward compatibility)
|
||||||
|
attrs = additional_attributes or {}
|
||||||
|
if user_role:
|
||||||
|
# Add role to custom attributes for consistency
|
||||||
|
if "custom" not in attrs:
|
||||||
|
attrs["custom"] = {}
|
||||||
|
if isinstance(attrs["custom"], dict):
|
||||||
|
attrs["custom"]["role"] = user_role
|
||||||
|
context = create_context(str(user_id), attrs)
|
||||||
|
else:
|
||||||
|
# Full context with user segments and metadata from database
|
||||||
|
try:
|
||||||
|
user_data = await _fetch_user_context_data(user_id)
|
||||||
|
except ImportError as e:
|
||||||
|
# Database modules not available - fallback to simple context
|
||||||
|
logger.debug(f"Database modules not available: {e}")
|
||||||
|
user_data = {}
|
||||||
|
except Exception as e:
|
||||||
|
# Database error - log and fallback to simple context
|
||||||
|
logger.warning(f"Failed to fetch user context for {user_id}: {e}")
|
||||||
|
user_data = {}
|
||||||
|
|
||||||
|
# Merge additional attributes and role
|
||||||
|
attrs = additional_attributes or {}
|
||||||
|
|
||||||
|
# If user_role is provided, add it to custom attributes
|
||||||
|
if user_role:
|
||||||
|
if "custom" not in user_data:
|
||||||
|
user_data["custom"] = {}
|
||||||
|
user_data["custom"]["role"] = user_role
|
||||||
|
|
||||||
|
# Merge additional attributes with user data
|
||||||
|
# Handle custom attributes specially
|
||||||
|
if "custom" in attrs and isinstance(attrs["custom"], dict):
|
||||||
|
if "custom" not in user_data:
|
||||||
|
user_data["custom"] = {}
|
||||||
|
user_data["custom"].update(attrs["custom"])
|
||||||
|
# Remove custom from attrs to avoid duplication
|
||||||
|
attrs = {k: v for k, v in attrs.items() if k != "custom"}
|
||||||
|
|
||||||
|
# Merge remaining attributes
|
||||||
|
final_attrs = {**user_data, **attrs}
|
||||||
|
|
||||||
|
context = create_context(str(user_id), final_attrs)
|
||||||
|
|
||||||
|
# Evaluate the flag
|
||||||
|
result = client.variation(flag_key, context, default)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Feature flag {flag_key} for user {user_id}: {result} "
|
||||||
|
f"(use_user_id_only: {use_user_id_only})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -93,16 +267,19 @@ def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bo
|
|||||||
def feature_flag(
|
def feature_flag(
|
||||||
flag_key: str,
|
flag_key: str,
|
||||||
default: bool = False,
|
default: bool = False,
|
||||||
) -> Callable[
|
) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
|
||||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
|
||||||
]:
|
|
||||||
"""
|
"""
|
||||||
Decorator for feature flag protected endpoints.
|
Decorator for async feature flag protected endpoints.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
flag_key: The LaunchDarkly feature flag key
|
||||||
|
default: Default value if flag evaluation fails
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorator that only works with async functions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(
|
def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
|
||||||
func: Callable[P, Union[T, Awaitable[T]]],
|
|
||||||
) -> Callable[P, Union[T, Awaitable[T]]]:
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
||||||
try:
|
try:
|
||||||
@@ -116,73 +293,24 @@ def feature_flag(
|
|||||||
)
|
)
|
||||||
is_enabled = default
|
is_enabled = default
|
||||||
else:
|
else:
|
||||||
context = create_context(str(user_id))
|
# Use the unified function with full context support
|
||||||
is_enabled = get_client().variation(flag_key, context, default)
|
is_enabled = await is_feature_enabled(
|
||||||
|
flag_key, str(user_id), default, use_user_id_only=False
|
||||||
if not is_enabled:
|
|
||||||
raise HTTPException(status_code=404, detail="Feature not available")
|
|
||||||
|
|
||||||
result = func(*args, **kwargs)
|
|
||||||
if asyncio.iscoroutine(result):
|
|
||||||
return await result
|
|
||||||
return cast(T, result)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
@wraps(func)
|
|
||||||
def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
||||||
try:
|
|
||||||
user_id = kwargs.get("user_id")
|
|
||||||
if not user_id:
|
|
||||||
raise ValueError("user_id is required")
|
|
||||||
|
|
||||||
if not get_client().is_initialized():
|
|
||||||
logger.warning(
|
|
||||||
f"LaunchDarkly not initialized, using default={default}"
|
|
||||||
)
|
)
|
||||||
is_enabled = default
|
|
||||||
else:
|
|
||||||
context = create_context(str(user_id))
|
|
||||||
is_enabled = get_client().variation(flag_key, context, default)
|
|
||||||
|
|
||||||
if not is_enabled:
|
if not is_enabled:
|
||||||
raise HTTPException(status_code=404, detail="Feature not available")
|
raise HTTPException(status_code=404, detail="Feature not available")
|
||||||
|
|
||||||
return cast(T, func(*args, **kwargs))
|
return await func(*args, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return cast(
|
return async_wrapper
|
||||||
Callable[P, Union[T, Awaitable[T]]],
|
|
||||||
async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
|
|
||||||
)
|
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def percentage_rollout(
|
|
||||||
flag_key: str,
|
|
||||||
default: bool = False,
|
|
||||||
) -> Callable[
|
|
||||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
|
||||||
]:
|
|
||||||
"""Decorator for percentage-based rollouts."""
|
|
||||||
return feature_flag(flag_key, default)
|
|
||||||
|
|
||||||
|
|
||||||
def beta_feature(
|
|
||||||
flag_key: Optional[str] = None,
|
|
||||||
unauthorized_response: Any = {"message": "Not available in beta"},
|
|
||||||
) -> Callable[
|
|
||||||
[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
|
|
||||||
]:
|
|
||||||
"""Decorator for beta features."""
|
|
||||||
actual_key = f"beta-{flag_key}" if flag_key else "beta"
|
|
||||||
return feature_flag(actual_key, False)
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def mock_flag_variation(flag_key: str, return_value: Any):
|
def mock_flag_variation(flag_key: str, return_value: Any):
|
||||||
"""Context manager for testing feature flags."""
|
"""Context manager for testing feature flags."""
|
||||||
|
|||||||
@@ -1,10 +1,26 @@
|
|||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
|
import time
|
||||||
|
from functools import wraps
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Awaitable,
|
||||||
|
Callable,
|
||||||
|
ParamSpec,
|
||||||
|
Protocol,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
cast,
|
||||||
|
overload,
|
||||||
|
runtime_checkable,
|
||||||
|
)
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
|
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
|
||||||
@@ -57,3 +73,153 @@ def thread_cached(
|
|||||||
def clear_thread_cache(func: Callable) -> None:
|
def clear_thread_cache(func: Callable) -> None:
|
||||||
if clear := getattr(func, "clear_cache", None):
|
if clear := getattr(func, "clear_cache", None):
|
||||||
clear()
|
clear()
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class AsyncCachedFunction(Protocol):
|
||||||
|
"""Protocol for async functions with cache management methods."""
|
||||||
|
|
||||||
|
def cache_clear(self) -> None:
|
||||||
|
"""Clear all cached entries."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def cache_info(self) -> dict[str, Any]:
|
||||||
|
"""Get cache statistics."""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
|
"""Call the cached function."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def async_ttl_cache(
|
||||||
|
maxsize: int = 128, ttl_seconds: int | None = None
|
||||||
|
) -> Callable[[Callable[..., Awaitable[Any]]], AsyncCachedFunction]:
|
||||||
|
"""
|
||||||
|
TTL (Time To Live) cache decorator for async functions.
|
||||||
|
|
||||||
|
Similar to functools.lru_cache but works with async functions and includes optional TTL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
maxsize: Maximum number of cached entries
|
||||||
|
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorator function
|
||||||
|
|
||||||
|
Example:
|
||||||
|
# With TTL
|
||||||
|
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
|
||||||
|
async def api_call(param: str) -> dict:
|
||||||
|
return {"result": param}
|
||||||
|
|
||||||
|
# Without TTL (permanent cache like lru_cache)
|
||||||
|
@async_ttl_cache(maxsize=1000)
|
||||||
|
async def expensive_computation(param: str) -> dict:
|
||||||
|
return {"result": param}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(
|
||||||
|
async_func: Callable[..., Awaitable[Any]],
|
||||||
|
) -> AsyncCachedFunction:
|
||||||
|
# Cache storage - use union type to handle both cases
|
||||||
|
cache_storage: dict[Any, Any | Tuple[Any, float]] = {}
|
||||||
|
|
||||||
|
@wraps(async_func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
# Create cache key from arguments
|
||||||
|
key = (args, tuple(sorted(kwargs.items())))
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# Check if we have a valid cached entry
|
||||||
|
if key in cache_storage:
|
||||||
|
if ttl_seconds is None:
|
||||||
|
# No TTL - return cached result directly
|
||||||
|
logger.debug(
|
||||||
|
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||||
|
)
|
||||||
|
return cache_storage[key]
|
||||||
|
else:
|
||||||
|
# With TTL - check expiration
|
||||||
|
cached_data = cache_storage[key]
|
||||||
|
if isinstance(cached_data, tuple):
|
||||||
|
result, timestamp = cached_data
|
||||||
|
if current_time - timestamp < ttl_seconds:
|
||||||
|
logger.debug(
|
||||||
|
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
# Expired entry
|
||||||
|
del cache_storage[key]
|
||||||
|
logger.debug(
|
||||||
|
f"Cache entry expired for {async_func.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache miss or expired - fetch fresh data
|
||||||
|
logger.debug(
|
||||||
|
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
|
||||||
|
)
|
||||||
|
result = await async_func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Store in cache
|
||||||
|
if ttl_seconds is None:
|
||||||
|
cache_storage[key] = result
|
||||||
|
else:
|
||||||
|
cache_storage[key] = (result, current_time)
|
||||||
|
|
||||||
|
# Simple cleanup when cache gets too large
|
||||||
|
if len(cache_storage) > maxsize:
|
||||||
|
# Remove oldest entries (simple FIFO cleanup)
|
||||||
|
cutoff = maxsize // 2
|
||||||
|
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
|
||||||
|
for old_key in oldest_keys:
|
||||||
|
cache_storage.pop(old_key, None)
|
||||||
|
logger.debug(
|
||||||
|
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Add cache management methods (similar to functools.lru_cache)
|
||||||
|
def cache_clear() -> None:
|
||||||
|
cache_storage.clear()
|
||||||
|
|
||||||
|
def cache_info() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"size": len(cache_storage),
|
||||||
|
"maxsize": maxsize,
|
||||||
|
"ttl_seconds": ttl_seconds,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Attach methods to wrapper
|
||||||
|
setattr(wrapper, "cache_clear", cache_clear)
|
||||||
|
setattr(wrapper, "cache_info", cache_info)
|
||||||
|
|
||||||
|
return cast(AsyncCachedFunction, wrapper)
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def async_cache(
|
||||||
|
maxsize: int = 128,
|
||||||
|
) -> Callable[[Callable[..., Awaitable[Any]]], AsyncCachedFunction]:
|
||||||
|
"""
|
||||||
|
Process-level cache decorator for async functions (no TTL).
|
||||||
|
|
||||||
|
Similar to functools.lru_cache but works with async functions.
|
||||||
|
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
maxsize: Maximum number of cached entries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorator function
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@async_cache(maxsize=1000)
|
||||||
|
async def expensive_computation(param: str) -> dict:
|
||||||
|
# Expensive computation here
|
||||||
|
return {"result": param}
|
||||||
|
"""
|
||||||
|
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)
|
||||||
|
|||||||
@@ -16,7 +16,12 @@ from unittest.mock import Mock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
|
from autogpt_libs.utils.cache import (
|
||||||
|
async_cache,
|
||||||
|
async_ttl_cache,
|
||||||
|
clear_thread_cache,
|
||||||
|
thread_cached,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestThreadCached:
|
class TestThreadCached:
|
||||||
@@ -323,3 +328,378 @@ class TestThreadCached:
|
|||||||
|
|
||||||
assert function_using_mock(2) == 42
|
assert function_using_mock(2) == 42
|
||||||
assert mock.call_count == 2
|
assert mock.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncTTLCache:
|
||||||
|
"""Tests for the @async_ttl_cache decorator."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_basic_caching(self):
|
||||||
|
"""Test basic caching functionality."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||||
|
async def cached_function(x: int, y: int = 0) -> int:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
await asyncio.sleep(0.01) # Simulate async work
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
# First call
|
||||||
|
result1 = await cached_function(1, 2)
|
||||||
|
assert result1 == 3
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Second call with same args - should use cache
|
||||||
|
result2 = await cached_function(1, 2)
|
||||||
|
assert result2 == 3
|
||||||
|
assert call_count == 1 # No additional call
|
||||||
|
|
||||||
|
# Different args - should call function again
|
||||||
|
result3 = await cached_function(2, 3)
|
||||||
|
assert result3 == 5
|
||||||
|
assert call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ttl_expiration(self):
|
||||||
|
"""Test that cache entries expire after TTL."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||||
|
async def short_lived_cache(x: int) -> int:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
# First call
|
||||||
|
result1 = await short_lived_cache(5)
|
||||||
|
assert result1 == 10
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Second call immediately - should use cache
|
||||||
|
result2 = await short_lived_cache(5)
|
||||||
|
assert result2 == 10
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Wait for TTL to expire
|
||||||
|
await asyncio.sleep(1.1)
|
||||||
|
|
||||||
|
# Third call after expiration - should call function again
|
||||||
|
result3 = await short_lived_cache(5)
|
||||||
|
assert result3 == 10
|
||||||
|
assert call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_info(self):
|
||||||
|
"""Test cache info functionality."""
|
||||||
|
|
||||||
|
@async_ttl_cache(maxsize=5, ttl_seconds=300)
|
||||||
|
async def info_test_function(x: int) -> int:
|
||||||
|
return x * 3
|
||||||
|
|
||||||
|
# Check initial cache info
|
||||||
|
info = info_test_function.cache_info()
|
||||||
|
assert info["size"] == 0
|
||||||
|
assert info["maxsize"] == 5
|
||||||
|
assert info["ttl_seconds"] == 300
|
||||||
|
|
||||||
|
# Add an entry
|
||||||
|
await info_test_function(1)
|
||||||
|
info = info_test_function.cache_info()
|
||||||
|
assert info["size"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_clear(self):
|
||||||
|
"""Test cache clearing functionality."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||||
|
async def clearable_function(x: int) -> int:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return x * 4
|
||||||
|
|
||||||
|
# First call
|
||||||
|
result1 = await clearable_function(2)
|
||||||
|
assert result1 == 8
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Second call - should use cache
|
||||||
|
result2 = await clearable_function(2)
|
||||||
|
assert result2 == 8
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Clear cache
|
||||||
|
clearable_function.cache_clear()
|
||||||
|
|
||||||
|
# Third call after clear - should call function again
|
||||||
|
result3 = await clearable_function(2)
|
||||||
|
assert result3 == 8
|
||||||
|
assert call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_maxsize_cleanup(self):
|
||||||
|
"""Test that cache cleans up when maxsize is exceeded."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@async_ttl_cache(maxsize=3, ttl_seconds=60)
|
||||||
|
async def size_limited_function(x: int) -> int:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return x**2
|
||||||
|
|
||||||
|
# Fill cache to maxsize
|
||||||
|
await size_limited_function(1) # call_count: 1
|
||||||
|
await size_limited_function(2) # call_count: 2
|
||||||
|
await size_limited_function(3) # call_count: 3
|
||||||
|
|
||||||
|
info = size_limited_function.cache_info()
|
||||||
|
assert info["size"] == 3
|
||||||
|
|
||||||
|
# Add one more entry - should trigger cleanup
|
||||||
|
await size_limited_function(4) # call_count: 4
|
||||||
|
|
||||||
|
# Cache size should be reduced (cleanup removes oldest entries)
|
||||||
|
info = size_limited_function.cache_info()
|
||||||
|
assert info["size"] <= 3 # Should be cleaned up
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_argument_variations(self):
|
||||||
|
"""Test caching with different argument patterns."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||||
|
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return f"{a}-{b}-{c}"
|
||||||
|
|
||||||
|
# Different ways to call with same logical arguments
|
||||||
|
result1 = await arg_test_function(1, "test", c=200)
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Same arguments, same order - should use cache
|
||||||
|
result2 = await arg_test_function(1, "test", c=200)
|
||||||
|
assert call_count == 1
|
||||||
|
assert result1 == result2
|
||||||
|
|
||||||
|
# Different arguments - should call function
|
||||||
|
result3 = await arg_test_function(2, "test", c=200)
|
||||||
|
assert call_count == 2
|
||||||
|
assert result1 != result3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exception_handling(self):
|
||||||
|
"""Test that exceptions are not cached."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||||
|
async def exception_function(x: int) -> int:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if x < 0:
|
||||||
|
raise ValueError("Negative value not allowed")
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
# Successful call - should be cached
|
||||||
|
result1 = await exception_function(5)
|
||||||
|
assert result1 == 10
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Same successful call - should use cache
|
||||||
|
result2 = await exception_function(5)
|
||||||
|
assert result2 == 10
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Exception call - should not be cached
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await exception_function(-1)
|
||||||
|
assert call_count == 2
|
||||||
|
|
||||||
|
# Same exception call - should call again (not cached)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await exception_function(-1)
|
||||||
|
assert call_count == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_calls(self):
|
||||||
|
"""Test caching behavior with concurrent calls."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@async_ttl_cache(maxsize=10, ttl_seconds=60)
|
||||||
|
async def concurrent_function(x: int) -> int:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
await asyncio.sleep(0.05) # Simulate work
|
||||||
|
return x * x
|
||||||
|
|
||||||
|
# Launch concurrent calls with same arguments
|
||||||
|
tasks = [concurrent_function(3) for _ in range(5)]
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# All results should be the same
|
||||||
|
assert all(result == 9 for result in results)
|
||||||
|
|
||||||
|
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
|
||||||
|
# This tests that the cache doesn't break under concurrent access
|
||||||
|
assert 1 <= call_count <= 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncCache:
|
||||||
|
"""Tests for the @async_cache decorator (no TTL)."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_basic_caching_no_ttl(self):
|
||||||
|
"""Test basic caching functionality without TTL."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@async_cache(maxsize=10)
|
||||||
|
async def cached_function(x: int, y: int = 0) -> int:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
await asyncio.sleep(0.01) # Simulate async work
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
# First call
|
||||||
|
result1 = await cached_function(1, 2)
|
||||||
|
assert result1 == 3
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Second call with same args - should use cache
|
||||||
|
result2 = await cached_function(1, 2)
|
||||||
|
assert result2 == 3
|
||||||
|
assert call_count == 1 # No additional call
|
||||||
|
|
||||||
|
# Third call after some time - should still use cache (no TTL)
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
result3 = await cached_function(1, 2)
|
||||||
|
assert result3 == 3
|
||||||
|
assert call_count == 1 # Still no additional call
|
||||||
|
|
||||||
|
# Different args - should call function again
|
||||||
|
result4 = await cached_function(2, 3)
|
||||||
|
assert result4 == 5
|
||||||
|
assert call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_ttl_vs_ttl_behavior(self):
|
||||||
|
"""Test the difference between TTL and no-TTL caching."""
|
||||||
|
ttl_call_count = 0
|
||||||
|
no_ttl_call_count = 0
|
||||||
|
|
||||||
|
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
|
||||||
|
async def ttl_function(x: int) -> int:
|
||||||
|
nonlocal ttl_call_count
|
||||||
|
ttl_call_count += 1
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
@async_cache(maxsize=10) # No TTL
|
||||||
|
async def no_ttl_function(x: int) -> int:
|
||||||
|
nonlocal no_ttl_call_count
|
||||||
|
no_ttl_call_count += 1
|
||||||
|
return x * 2
|
||||||
|
|
||||||
|
# First calls
|
||||||
|
await ttl_function(5)
|
||||||
|
await no_ttl_function(5)
|
||||||
|
assert ttl_call_count == 1
|
||||||
|
assert no_ttl_call_count == 1
|
||||||
|
|
||||||
|
# Wait for TTL to expire
|
||||||
|
await asyncio.sleep(1.1)
|
||||||
|
|
||||||
|
# Second calls after TTL expiry
|
||||||
|
await ttl_function(5) # Should call function again (TTL expired)
|
||||||
|
await no_ttl_function(5) # Should use cache (no TTL)
|
||||||
|
assert ttl_call_count == 2 # TTL function called again
|
||||||
|
assert no_ttl_call_count == 1 # No-TTL function still cached
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_cache_info(self):
|
||||||
|
"""Test cache info for no-TTL cache."""
|
||||||
|
|
||||||
|
@async_cache(maxsize=5)
|
||||||
|
async def info_test_function(x: int) -> int:
|
||||||
|
return x * 3
|
||||||
|
|
||||||
|
# Check initial cache info
|
||||||
|
info = info_test_function.cache_info()
|
||||||
|
assert info["size"] == 0
|
||||||
|
assert info["maxsize"] == 5
|
||||||
|
assert info["ttl_seconds"] is None # No TTL
|
||||||
|
|
||||||
|
# Add an entry
|
||||||
|
await info_test_function(1)
|
||||||
|
info = info_test_function.cache_info()
|
||||||
|
assert info["size"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestTTLOptional:
|
||||||
|
"""Tests for optional TTL functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ttl_none_behavior(self):
|
||||||
|
"""Test that ttl_seconds=None works like no TTL."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@async_ttl_cache(maxsize=10, ttl_seconds=None)
|
||||||
|
async def no_ttl_via_none(x: int) -> int:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
return x**2
|
||||||
|
|
||||||
|
# First call
|
||||||
|
result1 = await no_ttl_via_none(3)
|
||||||
|
assert result1 == 9
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
# Wait (would expire if there was TTL)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# Second call - should still use cache
|
||||||
|
result2 = await no_ttl_via_none(3)
|
||||||
|
assert result2 == 9
|
||||||
|
assert call_count == 1 # No additional call
|
||||||
|
|
||||||
|
# Check cache info
|
||||||
|
info = no_ttl_via_none.cache_info()
|
||||||
|
assert info["ttl_seconds"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_options_comparison(self):
|
||||||
|
"""Test different cache options work as expected."""
|
||||||
|
ttl_calls = 0
|
||||||
|
no_ttl_calls = 0
|
||||||
|
|
||||||
|
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
|
||||||
|
async def ttl_function(x: int) -> int:
|
||||||
|
nonlocal ttl_calls
|
||||||
|
ttl_calls += 1
|
||||||
|
return x * 10
|
||||||
|
|
||||||
|
@async_cache(maxsize=10) # Process-level cache (no TTL)
|
||||||
|
async def process_function(x: int) -> int:
|
||||||
|
nonlocal no_ttl_calls
|
||||||
|
no_ttl_calls += 1
|
||||||
|
return x * 10
|
||||||
|
|
||||||
|
# Both should cache initially
|
||||||
|
await ttl_function(3)
|
||||||
|
await process_function(3)
|
||||||
|
assert ttl_calls == 1
|
||||||
|
assert no_ttl_calls == 1
|
||||||
|
|
||||||
|
# Immediate second calls - both should use cache
|
||||||
|
await ttl_function(3)
|
||||||
|
await process_function(3)
|
||||||
|
assert ttl_calls == 1
|
||||||
|
assert no_ttl_calls == 1
|
||||||
|
|
||||||
|
# Wait for TTL to expire
|
||||||
|
await asyncio.sleep(1.1)
|
||||||
|
|
||||||
|
# After TTL expiry
|
||||||
|
await ttl_function(3) # Should call function again
|
||||||
|
await process_function(3) # Should still use cache
|
||||||
|
assert ttl_calls == 2 # TTL cache expired, called again
|
||||||
|
assert no_ttl_calls == 1 # Process cache never expires
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from prisma.models import User
|
|
||||||
|
|
||||||
from backend.data.model import ProviderName
|
from backend.data.model import ProviderName, User
|
||||||
from backend.server.model import CreateGraph
|
from backend.server.model import CreateGraph
|
||||||
from backend.server.rest_api import AgentServer
|
from backend.server.rest_api import AgentServer
|
||||||
from backend.usecases.sample import create_test_graph, create_test_user
|
from backend.usecases.sample import create_test_graph, create_test_user
|
||||||
|
|||||||
@@ -998,8 +998,8 @@ def get_block_costs() -> dict[str, list[BlockCost]]:
|
|||||||
async def get_stripe_customer_id(user_id: str) -> str:
|
async def get_stripe_customer_id(user_id: str) -> str:
|
||||||
user = await get_user_by_id(user_id)
|
user = await get_user_by_id(user_id)
|
||||||
|
|
||||||
if user.stripeCustomerId:
|
if user.stripe_customer_id:
|
||||||
return user.stripeCustomerId
|
return user.stripe_customer_id
|
||||||
|
|
||||||
customer = stripe.Customer.create(
|
customer = stripe.Customer.create(
|
||||||
name=user.name or "",
|
name=user.name or "",
|
||||||
@@ -1022,10 +1022,10 @@ async def set_auto_top_up(user_id: str, config: AutoTopUpConfig):
|
|||||||
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
|
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
|
||||||
user = await get_user_by_id(user_id)
|
user = await get_user_by_id(user_id)
|
||||||
|
|
||||||
if not user.topUpConfig:
|
if not user.top_up_config:
|
||||||
return AutoTopUpConfig(threshold=0, amount=0)
|
return AutoTopUpConfig(threshold=0, amount=0)
|
||||||
|
|
||||||
return AutoTopUpConfig.model_validate(user.topUpConfig)
|
return AutoTopUpConfig.model_validate(user.top_up_config)
|
||||||
|
|
||||||
|
|
||||||
async def admin_get_user_history(
|
async def admin_get_user_history(
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import enum
|
|||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from json import JSONDecodeError
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Annotated,
|
Annotated,
|
||||||
@@ -40,12 +41,120 @@ from pydantic_core import (
|
|||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
|
from backend.util.json import loads as json_loads
|
||||||
from backend.util.settings import Secrets
|
from backend.util.settings import Secrets
|
||||||
|
|
||||||
# Type alias for any provider name (including custom ones)
|
# Type alias for any provider name (including custom ones)
|
||||||
AnyProviderName = str # Will be validated as ProviderName at runtime
|
AnyProviderName = str # Will be validated as ProviderName at runtime
|
||||||
|
|
||||||
|
|
||||||
|
class User(BaseModel):
|
||||||
|
"""Application-layer User model with snake_case convention."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(
|
||||||
|
extra="forbid",
|
||||||
|
str_strip_whitespace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
id: str = Field(..., description="User ID")
|
||||||
|
email: str = Field(..., description="User email address")
|
||||||
|
email_verified: bool = Field(default=True, description="Whether email is verified")
|
||||||
|
name: Optional[str] = Field(None, description="User display name")
|
||||||
|
created_at: datetime = Field(..., description="When user was created")
|
||||||
|
updated_at: datetime = Field(..., description="When user was last updated")
|
||||||
|
metadata: dict[str, Any] = Field(
|
||||||
|
default_factory=dict, description="User metadata as dict"
|
||||||
|
)
|
||||||
|
integrations: str = Field(default="", description="Encrypted integrations data")
|
||||||
|
stripe_customer_id: Optional[str] = Field(None, description="Stripe customer ID")
|
||||||
|
top_up_config: Optional["AutoTopUpConfig"] = Field(
|
||||||
|
None, description="Top up configuration"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Notification preferences
|
||||||
|
max_emails_per_day: int = Field(default=3, description="Maximum emails per day")
|
||||||
|
notify_on_agent_run: bool = Field(default=True, description="Notify on agent run")
|
||||||
|
notify_on_zero_balance: bool = Field(
|
||||||
|
default=True, description="Notify on zero balance"
|
||||||
|
)
|
||||||
|
notify_on_low_balance: bool = Field(
|
||||||
|
default=True, description="Notify on low balance"
|
||||||
|
)
|
||||||
|
notify_on_block_execution_failed: bool = Field(
|
||||||
|
default=True, description="Notify on block execution failure"
|
||||||
|
)
|
||||||
|
notify_on_continuous_agent_error: bool = Field(
|
||||||
|
default=True, description="Notify on continuous agent error"
|
||||||
|
)
|
||||||
|
notify_on_daily_summary: bool = Field(
|
||||||
|
default=True, description="Notify on daily summary"
|
||||||
|
)
|
||||||
|
notify_on_weekly_summary: bool = Field(
|
||||||
|
default=True, description="Notify on weekly summary"
|
||||||
|
)
|
||||||
|
notify_on_monthly_summary: bool = Field(
|
||||||
|
default=True, description="Notify on monthly summary"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_db(cls, prisma_user: "PrismaUser") -> "User":
|
||||||
|
"""Convert a database User object to application User model."""
|
||||||
|
# Handle metadata field - convert from JSON string or dict to dict
|
||||||
|
metadata = {}
|
||||||
|
if prisma_user.metadata:
|
||||||
|
if isinstance(prisma_user.metadata, str):
|
||||||
|
try:
|
||||||
|
metadata = json_loads(prisma_user.metadata)
|
||||||
|
except (JSONDecodeError, TypeError):
|
||||||
|
metadata = {}
|
||||||
|
elif isinstance(prisma_user.metadata, dict):
|
||||||
|
metadata = prisma_user.metadata
|
||||||
|
|
||||||
|
# Handle topUpConfig field
|
||||||
|
top_up_config = None
|
||||||
|
if prisma_user.topUpConfig:
|
||||||
|
if isinstance(prisma_user.topUpConfig, str):
|
||||||
|
try:
|
||||||
|
config_dict = json_loads(prisma_user.topUpConfig)
|
||||||
|
top_up_config = AutoTopUpConfig.model_validate(config_dict)
|
||||||
|
except (JSONDecodeError, TypeError, ValueError):
|
||||||
|
top_up_config = None
|
||||||
|
elif isinstance(prisma_user.topUpConfig, dict):
|
||||||
|
try:
|
||||||
|
top_up_config = AutoTopUpConfig.model_validate(
|
||||||
|
prisma_user.topUpConfig
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
top_up_config = None
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
id=prisma_user.id,
|
||||||
|
email=prisma_user.email,
|
||||||
|
email_verified=prisma_user.emailVerified or True,
|
||||||
|
name=prisma_user.name,
|
||||||
|
created_at=prisma_user.createdAt,
|
||||||
|
updated_at=prisma_user.updatedAt,
|
||||||
|
metadata=metadata,
|
||||||
|
integrations=prisma_user.integrations or "",
|
||||||
|
stripe_customer_id=prisma_user.stripeCustomerId,
|
||||||
|
top_up_config=top_up_config,
|
||||||
|
max_emails_per_day=prisma_user.maxEmailsPerDay or 3,
|
||||||
|
notify_on_agent_run=prisma_user.notifyOnAgentRun or True,
|
||||||
|
notify_on_zero_balance=prisma_user.notifyOnZeroBalance or True,
|
||||||
|
notify_on_low_balance=prisma_user.notifyOnLowBalance or True,
|
||||||
|
notify_on_block_execution_failed=prisma_user.notifyOnBlockExecutionFailed
|
||||||
|
or True,
|
||||||
|
notify_on_continuous_agent_error=prisma_user.notifyOnContinuousAgentError
|
||||||
|
or True,
|
||||||
|
notify_on_daily_summary=prisma_user.notifyOnDailySummary or True,
|
||||||
|
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
|
||||||
|
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from prisma.models import User as PrismaUser
|
||||||
|
|
||||||
from backend.data.block import BlockSchema
|
from backend.data.block import BlockSchema
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|||||||
@@ -9,11 +9,11 @@ from urllib.parse import quote_plus
|
|||||||
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
from autogpt_libs.auth.models import DEFAULT_USER_ID
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from prisma.enums import NotificationType
|
from prisma.enums import NotificationType
|
||||||
from prisma.models import User
|
from prisma.models import User as PrismaUser
|
||||||
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
|
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
|
||||||
|
|
||||||
from backend.data.db import prisma
|
from backend.data.db import prisma
|
||||||
from backend.data.model import UserIntegrations, UserMetadata
|
from backend.data.model import User, UserIntegrations, UserMetadata
|
||||||
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
|
||||||
from backend.server.v2.store.exceptions import DatabaseError
|
from backend.server.v2.store.exceptions import DatabaseError
|
||||||
from backend.util.encryption import JSONCryptor
|
from backend.util.encryption import JSONCryptor
|
||||||
@@ -44,7 +44,7 @@ async def get_or_create_user(user_data: dict) -> User:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return User.model_validate(user)
|
return User.from_db(user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e
|
raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e
|
||||||
|
|
||||||
@@ -53,7 +53,7 @@ async def get_user_by_id(user_id: str) -> User:
|
|||||||
user = await prisma.user.find_unique(where={"id": user_id})
|
user = await prisma.user.find_unique(where={"id": user_id})
|
||||||
if not user:
|
if not user:
|
||||||
raise ValueError(f"User not found with ID: {user_id}")
|
raise ValueError(f"User not found with ID: {user_id}")
|
||||||
return User.model_validate(user)
|
return User.from_db(user)
|
||||||
|
|
||||||
|
|
||||||
async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
||||||
@@ -67,7 +67,7 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
|
|||||||
async def get_user_by_email(email: str) -> Optional[User]:
|
async def get_user_by_email(email: str) -> Optional[User]:
|
||||||
try:
|
try:
|
||||||
user = await prisma.user.find_unique(where={"email": email})
|
user = await prisma.user.find_unique(where={"email": email})
|
||||||
return User.model_validate(user) if user else None
|
return User.from_db(user) if user else None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise DatabaseError(f"Failed to get user by email {email}: {e}") from e
|
raise DatabaseError(f"Failed to get user by email {email}: {e}") from e
|
||||||
|
|
||||||
@@ -91,11 +91,11 @@ async def create_default_user() -> Optional[User]:
|
|||||||
name="Default User",
|
name="Default User",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return User.model_validate(user)
|
return User.from_db(user)
|
||||||
|
|
||||||
|
|
||||||
async def get_user_integrations(user_id: str) -> UserIntegrations:
|
async def get_user_integrations(user_id: str) -> UserIntegrations:
|
||||||
user = await User.prisma().find_unique_or_raise(
|
user = await PrismaUser.prisma().find_unique_or_raise(
|
||||||
where={"id": user_id},
|
where={"id": user_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -110,7 +110,7 @@ async def get_user_integrations(user_id: str) -> UserIntegrations:
|
|||||||
|
|
||||||
async def update_user_integrations(user_id: str, data: UserIntegrations):
|
async def update_user_integrations(user_id: str, data: UserIntegrations):
|
||||||
encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True))
|
encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True))
|
||||||
await User.prisma().update(
|
await PrismaUser.prisma().update(
|
||||||
where={"id": user_id},
|
where={"id": user_id},
|
||||||
data={"integrations": encrypted_data},
|
data={"integrations": encrypted_data},
|
||||||
)
|
)
|
||||||
@@ -118,7 +118,7 @@ async def update_user_integrations(user_id: str, data: UserIntegrations):
|
|||||||
|
|
||||||
async def migrate_and_encrypt_user_integrations():
|
async def migrate_and_encrypt_user_integrations():
|
||||||
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
|
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
|
||||||
users = await User.prisma().find_many(
|
users = await PrismaUser.prisma().find_many(
|
||||||
where={
|
where={
|
||||||
"metadata": cast(
|
"metadata": cast(
|
||||||
JsonFilter,
|
JsonFilter,
|
||||||
@@ -154,7 +154,7 @@ async def migrate_and_encrypt_user_integrations():
|
|||||||
raw_metadata.pop("integration_oauth_states", None)
|
raw_metadata.pop("integration_oauth_states", None)
|
||||||
|
|
||||||
# Update metadata without integration data
|
# Update metadata without integration data
|
||||||
await User.prisma().update(
|
await PrismaUser.prisma().update(
|
||||||
where={"id": user.id},
|
where={"id": user.id},
|
||||||
data={"metadata": SafeJson(raw_metadata)},
|
data={"metadata": SafeJson(raw_metadata)},
|
||||||
)
|
)
|
||||||
@@ -162,7 +162,7 @@ async def migrate_and_encrypt_user_integrations():
|
|||||||
|
|
||||||
async def get_active_user_ids_in_timerange(start_time: str, end_time: str) -> list[str]:
|
async def get_active_user_ids_in_timerange(start_time: str, end_time: str) -> list[str]:
|
||||||
try:
|
try:
|
||||||
users = await User.prisma().find_many(
|
users = await PrismaUser.prisma().find_many(
|
||||||
where={
|
where={
|
||||||
"AgentGraphExecutions": {
|
"AgentGraphExecutions": {
|
||||||
"some": {
|
"some": {
|
||||||
@@ -192,7 +192,7 @@ async def get_active_users_ids() -> list[str]:
|
|||||||
|
|
||||||
async def get_user_notification_preference(user_id: str) -> NotificationPreference:
|
async def get_user_notification_preference(user_id: str) -> NotificationPreference:
|
||||||
try:
|
try:
|
||||||
user = await User.prisma().find_unique_or_raise(
|
user = await PrismaUser.prisma().find_unique_or_raise(
|
||||||
where={"id": user_id},
|
where={"id": user_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -269,7 +269,7 @@ async def update_user_notification_preference(
|
|||||||
if data.daily_limit:
|
if data.daily_limit:
|
||||||
update_data["maxEmailsPerDay"] = data.daily_limit
|
update_data["maxEmailsPerDay"] = data.daily_limit
|
||||||
|
|
||||||
user = await User.prisma().update(
|
user = await PrismaUser.prisma().update(
|
||||||
where={"id": user_id},
|
where={"id": user_id},
|
||||||
data=update_data,
|
data=update_data,
|
||||||
)
|
)
|
||||||
@@ -307,7 +307,7 @@ async def update_user_notification_preference(
|
|||||||
async def set_user_email_verification(user_id: str, verified: bool) -> None:
|
async def set_user_email_verification(user_id: str, verified: bool) -> None:
|
||||||
"""Set the email verification status for a user."""
|
"""Set the email verification status for a user."""
|
||||||
try:
|
try:
|
||||||
await User.prisma().update(
|
await PrismaUser.prisma().update(
|
||||||
where={"id": user_id},
|
where={"id": user_id},
|
||||||
data={"emailVerified": verified},
|
data={"emailVerified": verified},
|
||||||
)
|
)
|
||||||
@@ -320,7 +320,7 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
|
|||||||
async def get_user_email_verification(user_id: str) -> bool:
|
async def get_user_email_verification(user_id: str) -> bool:
|
||||||
"""Get the email verification status for a user."""
|
"""Get the email verification status for a user."""
|
||||||
try:
|
try:
|
||||||
user = await User.prisma().find_unique_or_raise(
|
user = await PrismaUser.prisma().find_unique_or_raise(
|
||||||
where={"id": user_id},
|
where={"id": user_id},
|
||||||
)
|
)
|
||||||
return user.emailVerified
|
return user.emailVerified
|
||||||
|
|||||||
@@ -102,8 +102,10 @@ async def generate_activity_status_for_execution(
|
|||||||
Returns:
|
Returns:
|
||||||
AI-generated activity status string, or None if feature is disabled
|
AI-generated activity status string, or None if feature is disabled
|
||||||
"""
|
"""
|
||||||
# Check LaunchDarkly feature flag for AI activity status generation
|
# Check LaunchDarkly feature flag for AI activity status generation with full context support
|
||||||
if not is_feature_enabled(AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False):
|
if not await is_feature_enabled(
|
||||||
|
AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False
|
||||||
|
):
|
||||||
logger.debug("AI activity status generation is disabled via LaunchDarkly")
|
logger.debug("AI activity status generation is disabled via LaunchDarkly")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from backend.data.notifications import (
|
|||||||
)
|
)
|
||||||
from backend.data.user import (
|
from backend.data.user import (
|
||||||
get_active_user_ids_in_timerange,
|
get_active_user_ids_in_timerange,
|
||||||
|
get_user_by_id,
|
||||||
get_user_email_by_id,
|
get_user_email_by_id,
|
||||||
get_user_email_verification,
|
get_user_email_verification,
|
||||||
get_user_integrations,
|
get_user_integrations,
|
||||||
@@ -129,6 +130,7 @@ class DatabaseManager(AppService):
|
|||||||
|
|
||||||
# User Comms - async
|
# User Comms - async
|
||||||
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
|
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
|
||||||
|
get_user_by_id = _(get_user_by_id)
|
||||||
get_user_email_by_id = _(get_user_email_by_id)
|
get_user_email_by_id = _(get_user_email_by_id)
|
||||||
get_user_email_verification = _(get_user_email_verification)
|
get_user_email_verification = _(get_user_email_verification)
|
||||||
get_user_notification_preference = _(get_user_notification_preference)
|
get_user_notification_preference = _(get_user_notification_preference)
|
||||||
@@ -201,6 +203,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
|||||||
|
|
||||||
# User Comms
|
# User Comms
|
||||||
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
|
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
|
||||||
|
get_user_by_id = d.get_user_by_id
|
||||||
get_user_email_by_id = d.get_user_email_by_id
|
get_user_email_by_id = d.get_user_email_by_id
|
||||||
get_user_email_verification = d.get_user_email_verification
|
get_user_email_verification = d.get_user_email_verification
|
||||||
get_user_notification_preference = d.get_user_notification_preference
|
get_user_notification_preference = d.get_user_notification_preference
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import logging
|
|||||||
import autogpt_libs.auth.models
|
import autogpt_libs.auth.models
|
||||||
import fastapi.responses
|
import fastapi.responses
|
||||||
import pytest
|
import pytest
|
||||||
from prisma.models import User
|
|
||||||
|
|
||||||
import backend.server.v2.library.model
|
import backend.server.v2.library.model
|
||||||
import backend.server.v2.store.model
|
import backend.server.v2.store.model
|
||||||
@@ -12,6 +11,7 @@ from backend.blocks.data_manipulation import FindInDictionaryBlock
|
|||||||
from backend.blocks.io import AgentInputBlock
|
from backend.blocks.io import AgentInputBlock
|
||||||
from backend.blocks.maths import CalculatorBlock, Operation
|
from backend.blocks.maths import CalculatorBlock, Operation
|
||||||
from backend.data import execution, graph
|
from backend.data import execution, graph
|
||||||
|
from backend.data.model import User
|
||||||
from backend.server.model import CreateGraph
|
from backend.server.model import CreateGraph
|
||||||
from backend.server.rest_api import AgentServer
|
from backend.server.rest_api import AgentServer
|
||||||
from backend.usecases.sample import create_test_graph, create_test_user
|
from backend.usecases.sample import create_test_graph, create_test_user
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from prisma.models import User
|
|
||||||
|
|
||||||
from backend.blocks.basic import StoreValueBlock
|
from backend.blocks.basic import StoreValueBlock
|
||||||
from backend.blocks.block import BlockInstallationBlock
|
from backend.blocks.block import BlockInstallationBlock
|
||||||
from backend.blocks.http import SendWebRequestBlock
|
from backend.blocks.http import SendWebRequestBlock
|
||||||
from backend.blocks.llm import AITextGeneratorBlock
|
from backend.blocks.llm import AITextGeneratorBlock
|
||||||
from backend.blocks.text import ExtractTextInformationBlock, FillTextTemplateBlock
|
from backend.blocks.text import ExtractTextInformationBlock, FillTextTemplateBlock
|
||||||
from backend.data.graph import Graph, Link, Node, create_graph
|
from backend.data.graph import Graph, Link, Node, create_graph
|
||||||
|
from backend.data.model import User
|
||||||
from backend.data.user import get_or_create_user
|
from backend.data.user import get_or_create_user
|
||||||
from backend.util.test import SpinTestServer, wait_execution
|
from backend.util.test import SpinTestServer, wait_execution
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
from prisma.models import User
|
|
||||||
|
|
||||||
from backend.blocks.llm import AIStructuredResponseGeneratorBlock
|
from backend.blocks.llm import AIStructuredResponseGeneratorBlock
|
||||||
from backend.blocks.reddit import GetRedditPostsBlock, PostRedditCommentBlock
|
from backend.blocks.reddit import GetRedditPostsBlock, PostRedditCommentBlock
|
||||||
from backend.blocks.text import FillTextTemplateBlock, MatchTextPatternBlock
|
from backend.blocks.text import FillTextTemplateBlock, MatchTextPatternBlock
|
||||||
from backend.data.graph import Graph, Link, Node, create_graph
|
from backend.data.graph import Graph, Link, Node, create_graph
|
||||||
|
from backend.data.model import User
|
||||||
from backend.data.user import get_or_create_user
|
from backend.data.user import get_or_create_user
|
||||||
from backend.util.test import SpinTestServer, wait_execution
|
from backend.util.test import SpinTestServer, wait_execution
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
from prisma.models import User
|
|
||||||
|
|
||||||
from backend.blocks.basic import StoreValueBlock
|
from backend.blocks.basic import StoreValueBlock
|
||||||
from backend.blocks.io import AgentInputBlock
|
from backend.blocks.io import AgentInputBlock
|
||||||
from backend.blocks.text import FillTextTemplateBlock
|
from backend.blocks.text import FillTextTemplateBlock
|
||||||
from backend.data import graph
|
from backend.data import graph
|
||||||
from backend.data.graph import create_graph
|
from backend.data.graph import create_graph
|
||||||
|
from backend.data.model import User
|
||||||
from backend.data.user import get_or_create_user
|
from backend.data.user import get_or_create_user
|
||||||
from backend.util.test import SpinTestServer, wait_execution
|
from backend.util.test import SpinTestServer, wait_execution
|
||||||
|
|
||||||
|
|||||||
@@ -45,6 +45,34 @@ api_comm_retry = config.pyro_client_comm_retry
|
|||||||
api_comm_timeout = config.pyro_client_comm_timeout
|
api_comm_timeout = config.pyro_client_comm_timeout
|
||||||
api_call_timeout = config.rpc_client_call_timeout
|
api_call_timeout = config.rpc_client_call_timeout
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_no_prisma_objects(obj: Any, path: str = "result") -> None:
|
||||||
|
"""
|
||||||
|
Recursively validate that no Prisma objects are being returned from service methods.
|
||||||
|
This enforces proper separation of layers - only application models should cross service boundaries.
|
||||||
|
"""
|
||||||
|
if obj is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if it's a Prisma model object
|
||||||
|
if hasattr(obj, "__class__") and hasattr(obj.__class__, "__module__"):
|
||||||
|
module_name = obj.__class__.__module__
|
||||||
|
if module_name and "prisma.models" in module_name:
|
||||||
|
raise ValueError(
|
||||||
|
f"Prisma object {obj.__class__.__name__} found in {path}. "
|
||||||
|
"Service methods must return application models, not Prisma objects. "
|
||||||
|
f"Use {obj.__class__.__name__}.from_db() to convert to application model."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recursively check collections
|
||||||
|
if isinstance(obj, (list, tuple)):
|
||||||
|
for i, item in enumerate(obj):
|
||||||
|
_validate_no_prisma_objects(item, f"{path}[{i}]")
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
for key, value in obj.items():
|
||||||
|
_validate_no_prisma_objects(value, f"{path}['{key}']")
|
||||||
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
EXPOSED_FLAG = "__exposed__"
|
EXPOSED_FLAG = "__exposed__"
|
||||||
@@ -111,6 +139,22 @@ class UnhealthyServiceError(ValueError):
|
|||||||
return self.message
|
return self.message
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPClientError(Exception):
|
||||||
|
"""Exception for HTTP client errors (4xx status codes) that should not be retried."""
|
||||||
|
|
||||||
|
def __init__(self, status_code: int, message: str):
|
||||||
|
self.status_code = status_code
|
||||||
|
super().__init__(f"HTTP {status_code}: {message}")
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPServerError(Exception):
|
||||||
|
"""Exception for HTTP server errors (5xx status codes) that can be retried."""
|
||||||
|
|
||||||
|
def __init__(self, status_code: int, message: str):
|
||||||
|
self.status_code = status_code
|
||||||
|
super().__init__(f"HTTP {status_code}: {message}")
|
||||||
|
|
||||||
|
|
||||||
EXCEPTION_MAPPING = {
|
EXCEPTION_MAPPING = {
|
||||||
e.__name__: e
|
e.__name__: e
|
||||||
for e in [
|
for e in [
|
||||||
@@ -119,6 +163,8 @@ EXCEPTION_MAPPING = {
|
|||||||
TimeoutError,
|
TimeoutError,
|
||||||
ConnectionError,
|
ConnectionError,
|
||||||
UnhealthyServiceError,
|
UnhealthyServiceError,
|
||||||
|
HTTPClientError,
|
||||||
|
HTTPServerError,
|
||||||
*[
|
*[
|
||||||
ErrorType
|
ErrorType
|
||||||
for _, ErrorType in inspect.getmembers(exceptions)
|
for _, ErrorType in inspect.getmembers(exceptions)
|
||||||
@@ -191,17 +237,21 @@ class AppService(BaseAppService, ABC):
|
|||||||
if asyncio.iscoroutinefunction(f):
|
if asyncio.iscoroutinefunction(f):
|
||||||
|
|
||||||
async def async_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
|
async def async_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
|
||||||
return await f(
|
result = await f(
|
||||||
**{name: getattr(body, name) for name in type(body).model_fields}
|
**{name: getattr(body, name) for name in type(body).model_fields}
|
||||||
)
|
)
|
||||||
|
_validate_no_prisma_objects(result, f"{func.__name__} result")
|
||||||
|
return result
|
||||||
|
|
||||||
return async_endpoint
|
return async_endpoint
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def sync_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
|
def sync_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
|
||||||
return f(
|
result = f(
|
||||||
**{name: getattr(body, name) for name in type(body).model_fields}
|
**{name: getattr(body, name) for name in type(body).model_fields}
|
||||||
)
|
)
|
||||||
|
_validate_no_prisma_objects(result, f"{func.__name__} result")
|
||||||
|
return result
|
||||||
|
|
||||||
return sync_endpoint
|
return sync_endpoint
|
||||||
|
|
||||||
@@ -313,6 +363,7 @@ def get_service_client(
|
|||||||
AttributeError, # Missing attributes
|
AttributeError, # Missing attributes
|
||||||
asyncio.CancelledError, # Task was cancelled
|
asyncio.CancelledError, # Task was cancelled
|
||||||
concurrent.futures.CancelledError, # Future was cancelled
|
concurrent.futures.CancelledError, # Future was cancelled
|
||||||
|
HTTPClientError, # HTTP 4xx client errors - don't retry
|
||||||
),
|
),
|
||||||
)(fn)
|
)(fn)
|
||||||
|
|
||||||
@@ -390,11 +441,31 @@ def get_service_client(
|
|||||||
self._connection_failure_count = 0
|
self._connection_failure_count = 0
|
||||||
return response.json()
|
return response.json()
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
error = RemoteCallError.model_validate(e.response.json())
|
status_code = e.response.status_code
|
||||||
# DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception
|
|
||||||
raise EXCEPTION_MAPPING.get(error.type, Exception)(
|
# Try to parse the error response as RemoteCallError for mapped exceptions
|
||||||
*(error.args or [str(e)])
|
error_response = None
|
||||||
)
|
try:
|
||||||
|
error_response = RemoteCallError.model_validate(e.response.json())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# If we successfully parsed a mapped exception type, re-raise it
|
||||||
|
if error_response and error_response.type in EXCEPTION_MAPPING:
|
||||||
|
exception_class = EXCEPTION_MAPPING[error_response.type]
|
||||||
|
args = error_response.args or [str(e)]
|
||||||
|
raise exception_class(*args)
|
||||||
|
|
||||||
|
# Otherwise categorize by HTTP status code
|
||||||
|
if 400 <= status_code < 500:
|
||||||
|
# Client errors (4xx) - wrap to prevent retries
|
||||||
|
raise HTTPClientError(status_code, str(e))
|
||||||
|
elif 500 <= status_code < 600:
|
||||||
|
# Server errors (5xx) - wrap but allow retries
|
||||||
|
raise HTTPServerError(status_code, str(e))
|
||||||
|
else:
|
||||||
|
# Other status codes (1xx, 2xx, 3xx) - re-raise original error
|
||||||
|
raise e
|
||||||
|
|
||||||
@_maybe_retry
|
@_maybe_retry
|
||||||
def _call_method_sync(self, method_name: str, **kwargs: Any) -> Any:
|
def _call_method_sync(self, method_name: str, **kwargs: Any) -> Any:
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import pytest
|
|||||||
from backend.util.service import (
|
from backend.util.service import (
|
||||||
AppService,
|
AppService,
|
||||||
AppServiceClient,
|
AppServiceClient,
|
||||||
|
HTTPClientError,
|
||||||
|
HTTPServerError,
|
||||||
endpoint_to_async,
|
endpoint_to_async,
|
||||||
expose,
|
expose,
|
||||||
get_service_client,
|
get_service_client,
|
||||||
@@ -366,3 +368,125 @@ def test_service_no_retry_when_disabled(server):
|
|||||||
# This should fail immediately without retry
|
# This should fail immediately without retry
|
||||||
with pytest.raises(RuntimeError, match="Intended error for testing"):
|
with pytest.raises(RuntimeError, match="Intended error for testing"):
|
||||||
client.always_failing_add(5, 3)
|
client.always_failing_add(5, 3)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHTTPErrorRetryBehavior:
|
||||||
|
"""Test that HTTP client errors (4xx) are not retried but server errors (5xx) can be."""
|
||||||
|
|
||||||
|
# Note: These tests access private methods for testing internal behavior
|
||||||
|
# Type ignore comments are used to suppress warnings about accessing private methods
|
||||||
|
|
||||||
|
def test_http_client_error_not_retried(self):
|
||||||
|
"""Test that 4xx errors are wrapped as HTTPClientError and not retried."""
|
||||||
|
# Create a mock response with 404 status
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = 404
|
||||||
|
mock_response.json.return_value = {"message": "Not found"}
|
||||||
|
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||||
|
"404 Not Found", request=Mock(), response=mock_response
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create client
|
||||||
|
client = get_service_client(ServiceTestClient)
|
||||||
|
dynamic_client = client
|
||||||
|
|
||||||
|
# Test the _handle_call_method_response directly
|
||||||
|
with pytest.raises(HTTPClientError) as exc_info:
|
||||||
|
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
|
||||||
|
response=mock_response, method_name="test_method"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 404
|
||||||
|
assert "404" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_http_server_error_can_be_retried(self):
|
||||||
|
"""Test that 5xx errors are wrapped as HTTPServerError and can be retried."""
|
||||||
|
# Create a mock response with 500 status
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = 500
|
||||||
|
mock_response.json.return_value = {"message": "Internal server error"}
|
||||||
|
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||||
|
"500 Internal Server Error", request=Mock(), response=mock_response
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create client
|
||||||
|
client = get_service_client(ServiceTestClient)
|
||||||
|
dynamic_client = client
|
||||||
|
|
||||||
|
# Test the _handle_call_method_response directly
|
||||||
|
with pytest.raises(HTTPServerError) as exc_info:
|
||||||
|
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
|
||||||
|
response=mock_response, method_name="test_method"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == 500
|
||||||
|
assert "500" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_mapped_exception_preserves_original_type(self):
|
||||||
|
"""Test that mapped exceptions preserve their original type regardless of HTTP status."""
|
||||||
|
# Create a mock response with ValueError in the remote call error
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = 400
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"type": "ValueError",
|
||||||
|
"args": ["Invalid parameter value"],
|
||||||
|
}
|
||||||
|
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||||
|
"400 Bad Request", request=Mock(), response=mock_response
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create client
|
||||||
|
client = get_service_client(ServiceTestClient)
|
||||||
|
dynamic_client = client
|
||||||
|
|
||||||
|
# Test the _handle_call_method_response directly
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
|
||||||
|
response=mock_response, method_name="test_method"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Invalid parameter value" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_client_error_status_codes_coverage(self):
|
||||||
|
"""Test that various 4xx status codes are all wrapped as HTTPClientError."""
|
||||||
|
client_error_codes = [400, 401, 403, 404, 405, 409, 422, 429]
|
||||||
|
|
||||||
|
for status_code in client_error_codes:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = status_code
|
||||||
|
mock_response.json.return_value = {"message": f"Error {status_code}"}
|
||||||
|
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||||
|
f"{status_code} Error", request=Mock(), response=mock_response
|
||||||
|
)
|
||||||
|
|
||||||
|
client = get_service_client(ServiceTestClient)
|
||||||
|
dynamic_client = client
|
||||||
|
|
||||||
|
with pytest.raises(HTTPClientError) as exc_info:
|
||||||
|
dynamic_client._handle_call_method_response( # type: ignore
|
||||||
|
response=mock_response, method_name="test_method"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status_code
|
||||||
|
|
||||||
|
def test_server_error_status_codes_coverage(self):
|
||||||
|
"""Test that various 5xx status codes are all wrapped as HTTPServerError."""
|
||||||
|
server_error_codes = [500, 501, 502, 503, 504, 505]
|
||||||
|
|
||||||
|
for status_code in server_error_codes:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.status_code = status_code
|
||||||
|
mock_response.json.return_value = {"message": f"Error {status_code}"}
|
||||||
|
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||||
|
f"{status_code} Error", request=Mock(), response=mock_response
|
||||||
|
)
|
||||||
|
|
||||||
|
client = get_service_client(ServiceTestClient)
|
||||||
|
dynamic_client = client
|
||||||
|
|
||||||
|
with pytest.raises(HTTPServerError) as exc_info:
|
||||||
|
dynamic_client._handle_call_method_response( # type: ignore
|
||||||
|
response=mock_response, method_name="test_method"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc_info.value.status_code == status_code
|
||||||
|
|||||||
Reference in New Issue
Block a user