Compare commits

...

6 Commits

Author SHA1 Message Date
Zamil Majdy
5899897d8f refactor: simplify LaunchDarkly context and remove redundant code 2025-08-13 08:29:14 +00:00
Nicholas Tindle
0389c865aa feat(backend): in-progress launchdarkly role lookup 2025-08-11 21:01:52 -05:00
Zamil Majdy
b88576d313 refactor(feature-flag): eliminate code duplication in role handling
Extract role and segment logic into reusable _add_role_to_attributes helper function.
This removes the duplicate code between use_user_id_only and full context paths.

- Add _add_role_to_attributes() helper function
- Remove duplicated role/segment logic in is_feature_enabled()
- Cleaner, more maintainable code with single source of truth

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-11 12:29:10 +07:00
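
The helper this commit describes is not visible in the final compare view below (a later commit simplified the context logic again), so the following is a hypothetical sketch based only on the commit description — the signature and types are assumptions:

```python
from typing import Any, Optional


def _add_role_to_attributes(
    attrs: dict[str, Any], user_role: Optional[str]
) -> dict[str, Any]:
    """Hypothetical sketch: merge a user's role into the 'custom' attribute dict."""
    if user_role:
        # Create the nested 'custom' dict on first use, then record the role
        custom = attrs.setdefault("custom", {})
        if isinstance(custom, dict):
            custom["role"] = user_role
    return attrs
```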
Zamil Majdy
14634a6ce9 feat(platform): implement application-layer User model with proper type safety
## Core Changes

### Application Layer Separation
- Create application-layer User model with snake_case convention (created_at, email_verified, etc.)
- Add validation to prevent Prisma objects crossing service boundaries
- Replace all hasattr/getattr defensive coding with proper typing

### HTTP Client Improvements
- Prevent retry of HTTP 4xx client errors (404, 403, 401) which are permanent failures
- Add HTTPClientError and HTTPServerError exception categorization
- Comprehensive test coverage for retry behavior

### LaunchDarkly Integration Fixes
- Fix serialization issues by using proper snake_case application models
- Update feature flag client to use typed User model instead of Any
- Clean JSON parsing with proper imports (JSONDecodeError, json_loads)

### Type Safety Improvements
- Replace Any type annotations with proper PrismaUser typing
- Use AutoTopUpConfig class directly instead of generic dict
- Remove defensive hasattr() calls with direct attribute access
- Achieve zero format errors across entire codebase

## Files Modified
- backend/data/model.py: New application User model with from_db() converter
- backend/util/service.py: HTTP retry logic + Prisma validation
- backend/data/user.py: Updated to return application models
- autogpt_libs/feature_flag/client.py: Type-safe LaunchDarkly integration
- Multiple test files: Migrated to use application User model

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-11 12:26:30 +07:00
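
As a usage sketch of the boundary conversion this commit describes (mirroring the updated `get_user_by_id` in `backend/data/user.py` shown in the diffs below):

```python
from backend.data.db import prisma
from backend.data.model import User


async def get_user_by_id(user_id: str) -> User:
    prisma_user = await prisma.user.find_unique(where={"id": user_id})
    if not prisma_user:
        raise ValueError(f"User not found with ID: {user_id}")
    # Convert at the data layer so Prisma objects never cross service boundaries
    return User.from_db(prisma_user)
```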
Zamil Majdy
10a402a766 fix(service): prevent retry of HTTP 4xx client errors
## Problem
The service client retry logic was incorrectly retrying HTTP 4xx client errors
(404, 403, 401, etc.), which should never be retried since they represent
permanent client-side issues that won't be resolved by retrying.

## Solution
- **Added HTTPClientError and HTTPServerError exceptions** to categorize HTTP errors
- **Modified retry exclusions** to include HTTPClientError in the exclude_exceptions list
- **Enhanced error handling** to wrap 4xx errors in HTTPClientError (non-retryable) and 5xx errors in HTTPServerError (retryable)
- **Preserved mapped exceptions** - When the server returns a properly formatted RemoteCallError with a mapped exception type (ValueError, etc.), that original exception is re-raised regardless of HTTP status

## Changes Made
- **New Exception Classes**: Added HTTPClientError and HTTPServerError with status code tracking
- **Improved Error Categorization**: HTTP errors are now properly categorized by status code:
  - 4xx → HTTPClientError (excluded from retries)
  - 5xx → HTTPServerError (can be retried)
  - Mapped exceptions (ValueError, etc.) → Original exception type preserved
- **Clean Logic Flow**: Simplified exception handling logic to check for mapped exceptions first, then fall back to HTTP status categorization
- **Comprehensive Tests**: Added TestHTTPErrorRetryBehavior with coverage for various status codes and exception mapping

## Benefits
- **No more wasted retry attempts** on permanent client errors (404, 403, etc.)
- **Faster error handling** for client errors since they fail immediately
- **Preserved compatibility** with existing exception mapping system
- **Better resource utilization** by avoiding unnecessary retry delays
- **Cleaner logs** with fewer spurious retry warnings

## Files Modified
- `backend/util/service.py`: Core retry logic and error handling improvements
- `backend/util/service_test.py`: Comprehensive test coverage for HTTP error retry behavior

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-11 11:54:25 +07:00
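
The categorization boils down to a status-code check. A condensed sketch follows: the two exception classes match the diff to `backend/util/service.py` below, while the wrapping helper is illustrative — the real logic lives inline in the client's error handler and first re-raises mapped exception types:

```python
class HTTPClientError(Exception):
    """HTTP 4xx errors: permanent client-side failures, never retried."""

    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        super().__init__(f"HTTP {status_code}: {message}")


class HTTPServerError(Exception):
    """HTTP 5xx errors: transient server-side failures, eligible for retry."""

    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        super().__init__(f"HTTP {status_code}: {message}")


def wrap_http_error(status_code: int, message: str) -> Exception:
    # Illustrative helper name; the real handler checks RemoteCallError first
    if 400 <= status_code < 500:
        return HTTPClientError(status_code, message)
    if 500 <= status_code < 600:
        return HTTPServerError(status_code, message)
    return Exception(message)
```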
Zamil Majdy
004011726d feat(feature-flag): add LaunchDarkly segment support with 24h caching
## Summary
- Enhanced LaunchDarkly feature flag system to support context targeting with user segments
- Added 24-hour TTL caching for user context data to improve performance
- Implemented comprehensive user segmentation logic including role-based, domain-based, and account age segments

## Key Changes Made

### LaunchDarkly Client Enhancement
- **Segment Support**: Added full user context with segments for LaunchDarkly's contextTargets and segmentMatch rules
- **24-Hour Caching**: Implemented TTL cache with 86400 seconds (24 hours) to reduce database calls
- **Clean Architecture**: Uses existing `get_user_by_id` function following credentials_store.py pattern
- **Async-Only**: Removed sync support, made all functions async for consistency

### Segmentation Logic
- **Role-based**: user, admin, system segments
- **Domain-based**: employee segment for agpt.co emails, external for others
- **Account age**: new_user (<7 days), recent_user (7-30 days), established_user (>30 days)
- **Custom metadata**: Supports additional segments from user metadata

### Files Modified
- `autogpt_libs/feature_flag/client.py`: Main implementation with segment support and caching
- `autogpt_libs/utils/cache.py`: Added async TTL cache decorator with Protocol typing
- `autogpt_libs/utils/cache_test.py`: Comprehensive tests for TTL and process-level caching
- `backend/executor/database.py`: Exposed `get_user_by_id` for RPC support

## Technical Details
- Uses `@async_ttl_cache(maxsize=1000, ttl_seconds=86400)` for efficient caching
- Fallback to simple user_id context if database unavailable
- Proper exception handling without dangerous swallowing
- Strong typing with Protocol for cache functions
- Backwards compatible with `use_user_id_only` parameter

## Test Plan
- [x] All cache tests pass (TTL and process-level)
- [x] Feature flag evaluation works with segments
- [x] Database RPC calls work correctly
- [x] 24-hour cache reduces database load
- [x] Fallback to simple context when database unavailable

## Impact
- **Better targeting**: LaunchDarkly rules with segments now work correctly
- **Performance**: 24-hour cache dramatically reduces database calls
- **Maintainability**: Clean separation using existing database functions
- **Reliability**: Proper error handling and fallbacks

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-08-11 11:27:38 +07:00
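
A sketch of the caching pattern this commit describes — the decorator, import path, and TTL value match the diffs below, while the function body is an illustrative stand-in for the real database lookup in `_fetch_user_context_data`:

```python
from typing import Any

from autogpt_libs.utils.cache import async_ttl_cache


@async_ttl_cache(maxsize=1000, ttl_seconds=86400)  # 24-hour TTL per user
async def cached_user_context(user_id: str) -> dict[str, Any]:
    # Stand-in for the database fetch: one round-trip per user per day,
    # after which flag evaluations reuse the cached LaunchDarkly context.
    return {"kind": "user", "key": user_id, "anonymous": False}
```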
15 changed files with 1088 additions and 109 deletions

View File: autogpt_libs/feature_flag/client.py

@@ -1,10 +1,14 @@
import asyncio
import contextlib
import logging
from functools import wraps
-from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast
+from json import JSONDecodeError
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, TypeVar
+if TYPE_CHECKING:
+from backend.data.model import User
import ldclient
+from backend.util.json import loads as json_loads
from fastapi import HTTPException
from ldclient import Context, LDClient
from ldclient.config import Config
@@ -56,32 +60,202 @@ def shutdown_launchdarkly() -> None:
def create_context(
-user_id: str, additional_attributes: Optional[Dict[str, Any]] = None
+user_id: str, additional_attributes: Optional[dict[str, Any]] = None
) -> Context:
"""Create LaunchDarkly context with optional additional attributes."""
-builder = Context.builder(str(user_id)).kind("user")
+# Use the key from attributes if provided, otherwise use user_id
+context_key = user_id
+if additional_attributes and "key" in additional_attributes:
+context_key = additional_attributes["key"]
+builder = Context.builder(str(context_key)).kind("user")
if additional_attributes:
for key, value in additional_attributes.items():
-builder.set(key, value)
+# Skip kind and key as they're already set
+if key in ["kind", "key"]:
+continue
+elif key == "custom" and isinstance(value, dict):
+# Handle custom attributes object - these go as individual attributes
+for custom_key, custom_value in value.items():
+builder.set(custom_key, custom_value)
+else:
+builder.set(key, value)
return builder.build()
-def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bool:
+async def _fetch_user_context_data(user_id: str) -> dict[str, Any]:
"""
-Simple helper to check if a feature flag is enabled for a user.
+Fetch user data and build LaunchDarkly context.
+Args:
+user_id: The user ID to fetch data for
+Returns:
+Dictionary with user context data including role
"""
# Use the unified database access approach
from backend.util.clients import get_database_manager_async_client
db_client = get_database_manager_async_client()
user = await db_client.get_user_by_id(user_id)
# Build LaunchDarkly context from user data
return _build_launchdarkly_context(user)
def _build_launchdarkly_context(user: "User") -> dict[str, Any]:
"""
Build LaunchDarkly context data matching frontend format.
Returns a context like:
{
"kind": "user",
"key": "user-id",
"email": "user@example.com", # Optional
"anonymous": false,
"custom": {
"role": "admin" # Optional
}
}
Args:
user: User object from database
Returns:
Dictionary with user context data
"""
from autogpt_libs.auth.models import DEFAULT_USER_ID
# Build basic context - always include kind, key, and anonymous
context_data: dict[str, Any] = {
"kind": "user",
"key": user.id,
"anonymous": False,
}
# Add email if present
if user.email:
context_data["email"] = user.email
# Initialize custom attributes
custom: dict[str, Any] = {}
# Determine user role from metadata
role = None
# Check if user is default/system user
if user.id == DEFAULT_USER_ID:
role = "admin" # Default user has admin privileges when auth is disabled
elif user.metadata:
# Check for role in metadata
try:
# Handle both string (direct DB) and dict (RPC) formats
if isinstance(user.metadata, str):
metadata = json_loads(user.metadata)
elif isinstance(user.metadata, dict):
metadata = user.metadata
else:
metadata = {}
# Extract role from metadata if present
if metadata.get("role"):
role = metadata["role"]
except (JSONDecodeError, TypeError) as e:
logger.debug(f"Failed to parse user metadata for context: {e}")
# Add role to custom attributes if present
if role:
custom["role"] = role
# Only add custom object if it has content
if custom:
context_data["custom"] = custom
return context_data
async def is_feature_enabled(
flag_key: str,
user_id: str,
default: bool = False,
use_user_id_only: bool = False,
additional_attributes: Optional[dict[str, Any]] = None,
user_role: Optional[str] = None,
) -> bool:
"""
Check if a feature flag is enabled for a user with full LaunchDarkly context support.
Args:
flag_key: The LaunchDarkly feature flag key
user_id: The user ID to evaluate the flag for
default: Default value if LaunchDarkly is unavailable or flag evaluation fails
use_user_id_only: If True, only use user_id without fetching database context
additional_attributes: Additional attributes to include in the context
user_role: Optional user role (e.g., "admin", "user") to add to segments
Returns:
True if feature is enabled, False otherwise
"""
try:
client = get_client()
-context = create_context(str(user_id))
-return client.variation(flag_key, context, default)
if use_user_id_only:
# Simple context with just user ID (for backward compatibility)
attrs = additional_attributes or {}
if user_role:
# Add role to custom attributes for consistency
if "custom" not in attrs:
attrs["custom"] = {}
if isinstance(attrs["custom"], dict):
attrs["custom"]["role"] = user_role
context = create_context(str(user_id), attrs)
else:
# Full context with user segments and metadata from database
try:
user_data = await _fetch_user_context_data(user_id)
except ImportError as e:
# Database modules not available - fallback to simple context
logger.debug(f"Database modules not available: {e}")
user_data = {}
except Exception as e:
# Database error - log and fallback to simple context
logger.warning(f"Failed to fetch user context for {user_id}: {e}")
user_data = {}
# Merge additional attributes and role
attrs = additional_attributes or {}
# If user_role is provided, add it to custom attributes
if user_role:
if "custom" not in user_data:
user_data["custom"] = {}
user_data["custom"]["role"] = user_role
# Merge additional attributes with user data
# Handle custom attributes specially
if "custom" in attrs and isinstance(attrs["custom"], dict):
if "custom" not in user_data:
user_data["custom"] = {}
user_data["custom"].update(attrs["custom"])
# Remove custom from attrs to avoid duplication
attrs = {k: v for k, v in attrs.items() if k != "custom"}
# Merge remaining attributes
final_attrs = {**user_data, **attrs}
context = create_context(str(user_id), final_attrs)
# Evaluate the flag
result = client.variation(flag_key, context, default)
logger.debug(
f"Feature flag {flag_key} for user {user_id}: {result} "
f"(use_user_id_only: {use_user_id_only})"
)
return result
except Exception as e:
logger.debug(
@@ -93,16 +267,19 @@ def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bo
def feature_flag(
flag_key: str,
default: bool = False,
-) -> Callable[
-[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
-]:
+) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]:
"""
-Decorator for feature flag protected endpoints.
+Decorator for async feature flag protected endpoints.
Args:
flag_key: The LaunchDarkly feature flag key
default: Default value if flag evaluation fails
Returns:
Decorator that only works with async functions
"""
-def decorator(
-func: Callable[P, Union[T, Awaitable[T]]],
-) -> Callable[P, Union[T, Awaitable[T]]]:
+def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
@@ -116,73 +293,24 @@ def feature_flag(
)
is_enabled = default
else:
-context = create_context(str(user_id))
-is_enabled = get_client().variation(flag_key, context, default)
+# Use the unified function with full context support
+is_enabled = await is_feature_enabled(
+flag_key, str(user_id), default, use_user_id_only=False
+)
if not is_enabled:
raise HTTPException(status_code=404, detail="Feature not available")
-result = func(*args, **kwargs)
-if asyncio.iscoroutine(result):
-return await result
-return cast(T, result)
+return await func(*args, **kwargs)
except Exception as e:
logger.error(f"Error evaluating feature flag {flag_key}: {e}")
raise
-@wraps(func)
-def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
-try:
-user_id = kwargs.get("user_id")
-if not user_id:
-raise ValueError("user_id is required")
-if not get_client().is_initialized():
-logger.warning(
-f"LaunchDarkly not initialized, using default={default}"
-)
-is_enabled = default
-else:
-context = create_context(str(user_id))
-is_enabled = get_client().variation(flag_key, context, default)
-if not is_enabled:
-raise HTTPException(status_code=404, detail="Feature not available")
-return cast(T, func(*args, **kwargs))
-except Exception as e:
-logger.error(f"Error evaluating feature flag {flag_key}: {e}")
-raise
-return cast(
-Callable[P, Union[T, Awaitable[T]]],
-async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper,
-)
+return async_wrapper
return decorator
-def percentage_rollout(
-flag_key: str,
-default: bool = False,
-) -> Callable[
-[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
-]:
-"""Decorator for percentage-based rollouts."""
-return feature_flag(flag_key, default)
-def beta_feature(
-flag_key: Optional[str] = None,
-unauthorized_response: Any = {"message": "Not available in beta"},
-) -> Callable[
-[Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]]
-]:
-"""Decorator for beta features."""
-actual_key = f"beta-{flag_key}" if flag_key else "beta"
-return feature_flag(actual_key, False)
@contextlib.contextmanager
def mock_flag_variation(flag_key: str, return_value: Any):
"""Context manager for testing feature flags."""

View File: autogpt_libs/utils/cache.py

@@ -1,10 +1,26 @@
import inspect
import logging
import threading
from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload
import time
from functools import wraps
from typing import (
Any,
Awaitable,
Callable,
ParamSpec,
Protocol,
Tuple,
TypeVar,
cast,
overload,
runtime_checkable,
)
P = ParamSpec("P")
R = TypeVar("R")
logger = logging.getLogger(__name__)
@overload
def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ...
@@ -57,3 +73,153 @@ def thread_cached(
def clear_thread_cache(func: Callable) -> None:
if clear := getattr(func, "clear_cache", None):
clear()
@runtime_checkable
class AsyncCachedFunction(Protocol):
"""Protocol for async functions with cache management methods."""
def cache_clear(self) -> None:
"""Clear all cached entries."""
return None
def cache_info(self) -> dict[str, Any]:
"""Get cache statistics."""
return {}
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Call the cached function."""
return None
def async_ttl_cache(
maxsize: int = 128, ttl_seconds: int | None = None
) -> Callable[[Callable[..., Awaitable[Any]]], AsyncCachedFunction]:
"""
TTL (Time To Live) cache decorator for async functions.
Similar to functools.lru_cache but works with async functions and includes optional TTL.
Args:
maxsize: Maximum number of cached entries
ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache)
Returns:
Decorator function
Example:
# With TTL
@async_ttl_cache(maxsize=1000, ttl_seconds=300)
async def api_call(param: str) -> dict:
return {"result": param}
# Without TTL (permanent cache like lru_cache)
@async_ttl_cache(maxsize=1000)
async def expensive_computation(param: str) -> dict:
return {"result": param}
"""
def decorator(
async_func: Callable[..., Awaitable[Any]],
) -> AsyncCachedFunction:
# Cache storage - use union type to handle both cases
cache_storage: dict[Any, Any | Tuple[Any, float]] = {}
@wraps(async_func)
async def wrapper(*args, **kwargs):
# Create cache key from arguments
key = (args, tuple(sorted(kwargs.items())))
current_time = time.time()
# Check if we have a valid cached entry
if key in cache_storage:
if ttl_seconds is None:
# No TTL - return cached result directly
logger.debug(
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
)
return cache_storage[key]
else:
# With TTL - check expiration
cached_data = cache_storage[key]
if isinstance(cached_data, tuple):
result, timestamp = cached_data
if current_time - timestamp < ttl_seconds:
logger.debug(
f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}"
)
return result
else:
# Expired entry
del cache_storage[key]
logger.debug(
f"Cache entry expired for {async_func.__name__}"
)
# Cache miss or expired - fetch fresh data
logger.debug(
f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}"
)
result = await async_func(*args, **kwargs)
# Store in cache
if ttl_seconds is None:
cache_storage[key] = result
else:
cache_storage[key] = (result, current_time)
# Simple cleanup when cache gets too large
if len(cache_storage) > maxsize:
# Remove oldest entries (simple FIFO cleanup)
cutoff = maxsize // 2
oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else []
for old_key in oldest_keys:
cache_storage.pop(old_key, None)
logger.debug(
f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}"
)
return result
# Add cache management methods (similar to functools.lru_cache)
def cache_clear() -> None:
cache_storage.clear()
def cache_info() -> dict[str, Any]:
return {
"size": len(cache_storage),
"maxsize": maxsize,
"ttl_seconds": ttl_seconds,
}
# Attach methods to wrapper
setattr(wrapper, "cache_clear", cache_clear)
setattr(wrapper, "cache_info", cache_info)
return cast(AsyncCachedFunction, wrapper)
return decorator
def async_cache(
maxsize: int = 128,
) -> Callable[[Callable[..., Awaitable[Any]]], AsyncCachedFunction]:
"""
Process-level cache decorator for async functions (no TTL).
Similar to functools.lru_cache but works with async functions.
This is a convenience wrapper around async_ttl_cache with ttl_seconds=None.
Args:
maxsize: Maximum number of cached entries
Returns:
Decorator function
Example:
@async_cache(maxsize=1000)
async def expensive_computation(param: str) -> dict:
# Expensive computation here
return {"result": param}
"""
return async_ttl_cache(maxsize=maxsize, ttl_seconds=None)

View File: autogpt_libs/utils/cache_test.py

@@ -16,7 +16,12 @@ from unittest.mock import Mock
import pytest
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
from autogpt_libs.utils.cache import (
async_cache,
async_ttl_cache,
clear_thread_cache,
thread_cached,
)
class TestThreadCached:
@@ -323,3 +328,378 @@ class TestThreadCached:
assert function_using_mock(2) == 42
assert mock.call_count == 2
class TestAsyncTTLCache:
"""Tests for the @async_ttl_cache decorator."""
@pytest.mark.asyncio
async def test_basic_caching(self):
"""Test basic caching functionality."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def cached_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01) # Simulate async work
return x + y
# First call
result1 = await cached_function(1, 2)
assert result1 == 3
assert call_count == 1
# Second call with same args - should use cache
result2 = await cached_function(1, 2)
assert result2 == 3
assert call_count == 1 # No additional call
# Different args - should call function again
result3 = await cached_function(2, 3)
assert result3 == 5
assert call_count == 2
@pytest.mark.asyncio
async def test_ttl_expiration(self):
"""Test that cache entries expire after TTL."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
async def short_lived_cache(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 2
# First call
result1 = await short_lived_cache(5)
assert result1 == 10
assert call_count == 1
# Second call immediately - should use cache
result2 = await short_lived_cache(5)
assert result2 == 10
assert call_count == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# Third call after expiration - should call function again
result3 = await short_lived_cache(5)
assert result3 == 10
assert call_count == 2
@pytest.mark.asyncio
async def test_cache_info(self):
"""Test cache info functionality."""
@async_ttl_cache(maxsize=5, ttl_seconds=300)
async def info_test_function(x: int) -> int:
return x * 3
# Check initial cache info
info = info_test_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] == 5
assert info["ttl_seconds"] == 300
# Add an entry
await info_test_function(1)
info = info_test_function.cache_info()
assert info["size"] == 1
@pytest.mark.asyncio
async def test_cache_clear(self):
"""Test cache clearing functionality."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def clearable_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x * 4
# First call
result1 = await clearable_function(2)
assert result1 == 8
assert call_count == 1
# Second call - should use cache
result2 = await clearable_function(2)
assert result2 == 8
assert call_count == 1
# Clear cache
clearable_function.cache_clear()
# Third call after clear - should call function again
result3 = await clearable_function(2)
assert result3 == 8
assert call_count == 2
@pytest.mark.asyncio
async def test_maxsize_cleanup(self):
"""Test that cache cleans up when maxsize is exceeded."""
call_count = 0
@async_ttl_cache(maxsize=3, ttl_seconds=60)
async def size_limited_function(x: int) -> int:
nonlocal call_count
call_count += 1
return x**2
# Fill cache to maxsize
await size_limited_function(1) # call_count: 1
await size_limited_function(2) # call_count: 2
await size_limited_function(3) # call_count: 3
info = size_limited_function.cache_info()
assert info["size"] == 3
# Add one more entry - should trigger cleanup
await size_limited_function(4) # call_count: 4
# Cache size should be reduced (cleanup removes oldest entries)
info = size_limited_function.cache_info()
assert info["size"] <= 3 # Should be cleaned up
@pytest.mark.asyncio
async def test_argument_variations(self):
"""Test caching with different argument patterns."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str:
nonlocal call_count
call_count += 1
return f"{a}-{b}-{c}"
# Different ways to call with same logical arguments
result1 = await arg_test_function(1, "test", c=200)
assert call_count == 1
# Same arguments, same order - should use cache
result2 = await arg_test_function(1, "test", c=200)
assert call_count == 1
assert result1 == result2
# Different arguments - should call function
result3 = await arg_test_function(2, "test", c=200)
assert call_count == 2
assert result1 != result3
@pytest.mark.asyncio
async def test_exception_handling(self):
"""Test that exceptions are not cached."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def exception_function(x: int) -> int:
nonlocal call_count
call_count += 1
if x < 0:
raise ValueError("Negative value not allowed")
return x * 2
# Successful call - should be cached
result1 = await exception_function(5)
assert result1 == 10
assert call_count == 1
# Same successful call - should use cache
result2 = await exception_function(5)
assert result2 == 10
assert call_count == 1
# Exception call - should not be cached
with pytest.raises(ValueError):
await exception_function(-1)
assert call_count == 2
# Same exception call - should call again (not cached)
with pytest.raises(ValueError):
await exception_function(-1)
assert call_count == 3
@pytest.mark.asyncio
async def test_concurrent_calls(self):
"""Test caching behavior with concurrent calls."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=60)
async def concurrent_function(x: int) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.05) # Simulate work
return x * x
# Launch concurrent calls with same arguments
tasks = [concurrent_function(3) for _ in range(5)]
results = await asyncio.gather(*tasks)
# All results should be the same
assert all(result == 9 for result in results)
# Note: Due to race conditions, call_count might be up to 5 for concurrent calls
# This tests that the cache doesn't break under concurrent access
assert 1 <= call_count <= 5
class TestAsyncCache:
"""Tests for the @async_cache decorator (no TTL)."""
@pytest.mark.asyncio
async def test_basic_caching_no_ttl(self):
"""Test basic caching functionality without TTL."""
call_count = 0
@async_cache(maxsize=10)
async def cached_function(x: int, y: int = 0) -> int:
nonlocal call_count
call_count += 1
await asyncio.sleep(0.01) # Simulate async work
return x + y
# First call
result1 = await cached_function(1, 2)
assert result1 == 3
assert call_count == 1
# Second call with same args - should use cache
result2 = await cached_function(1, 2)
assert result2 == 3
assert call_count == 1 # No additional call
# Third call after some time - should still use cache (no TTL)
await asyncio.sleep(0.05)
result3 = await cached_function(1, 2)
assert result3 == 3
assert call_count == 1 # Still no additional call
# Different args - should call function again
result4 = await cached_function(2, 3)
assert result4 == 5
assert call_count == 2
@pytest.mark.asyncio
async def test_no_ttl_vs_ttl_behavior(self):
"""Test the difference between TTL and no-TTL caching."""
ttl_call_count = 0
no_ttl_call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL
async def ttl_function(x: int) -> int:
nonlocal ttl_call_count
ttl_call_count += 1
return x * 2
@async_cache(maxsize=10) # No TTL
async def no_ttl_function(x: int) -> int:
nonlocal no_ttl_call_count
no_ttl_call_count += 1
return x * 2
# First calls
await ttl_function(5)
await no_ttl_function(5)
assert ttl_call_count == 1
assert no_ttl_call_count == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# Second calls after TTL expiry
await ttl_function(5) # Should call function again (TTL expired)
await no_ttl_function(5) # Should use cache (no TTL)
assert ttl_call_count == 2 # TTL function called again
assert no_ttl_call_count == 1 # No-TTL function still cached
@pytest.mark.asyncio
async def test_async_cache_info(self):
"""Test cache info for no-TTL cache."""
@async_cache(maxsize=5)
async def info_test_function(x: int) -> int:
return x * 3
# Check initial cache info
info = info_test_function.cache_info()
assert info["size"] == 0
assert info["maxsize"] == 5
assert info["ttl_seconds"] is None # No TTL
# Add an entry
await info_test_function(1)
info = info_test_function.cache_info()
assert info["size"] == 1
class TestTTLOptional:
"""Tests for optional TTL functionality."""
@pytest.mark.asyncio
async def test_ttl_none_behavior(self):
"""Test that ttl_seconds=None works like no TTL."""
call_count = 0
@async_ttl_cache(maxsize=10, ttl_seconds=None)
async def no_ttl_via_none(x: int) -> int:
nonlocal call_count
call_count += 1
return x**2
# First call
result1 = await no_ttl_via_none(3)
assert result1 == 9
assert call_count == 1
# Wait (would expire if there was TTL)
await asyncio.sleep(0.1)
# Second call - should still use cache
result2 = await no_ttl_via_none(3)
assert result2 == 9
assert call_count == 1 # No additional call
# Check cache info
info = no_ttl_via_none.cache_info()
assert info["ttl_seconds"] is None
@pytest.mark.asyncio
async def test_cache_options_comparison(self):
"""Test different cache options work as expected."""
ttl_calls = 0
no_ttl_calls = 0
@async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL
async def ttl_function(x: int) -> int:
nonlocal ttl_calls
ttl_calls += 1
return x * 10
@async_cache(maxsize=10) # Process-level cache (no TTL)
async def process_function(x: int) -> int:
nonlocal no_ttl_calls
no_ttl_calls += 1
return x * 10
# Both should cache initially
await ttl_function(3)
await process_function(3)
assert ttl_calls == 1
assert no_ttl_calls == 1
# Immediate second calls - both should use cache
await ttl_function(3)
await process_function(3)
assert ttl_calls == 1
assert no_ttl_calls == 1
# Wait for TTL to expire
await asyncio.sleep(1.1)
# After TTL expiry
await ttl_function(3) # Should call function again
await process_function(3) # Should still use cache
assert ttl_calls == 2 # TTL cache expired, called again
assert no_ttl_calls == 1 # Process cache never expires

View File

@@ -1,9 +1,8 @@
import logging
import pytest
-from prisma.models import User
-from backend.data.model import ProviderName
+from backend.data.model import ProviderName, User
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user

View File

@@ -998,8 +998,8 @@ def get_block_costs() -> dict[str, list[BlockCost]]:
async def get_stripe_customer_id(user_id: str) -> str:
user = await get_user_by_id(user_id)
-if user.stripeCustomerId:
-return user.stripeCustomerId
+if user.stripe_customer_id:
+return user.stripe_customer_id
customer = stripe.Customer.create(
name=user.name or "",
@@ -1022,10 +1022,10 @@ async def set_auto_top_up(user_id: str, config: AutoTopUpConfig):
async def get_auto_top_up(user_id: str) -> AutoTopUpConfig:
user = await get_user_by_id(user_id)
-if not user.topUpConfig:
+if not user.top_up_config:
return AutoTopUpConfig(threshold=0, amount=0)
-return AutoTopUpConfig.model_validate(user.topUpConfig)
+return AutoTopUpConfig.model_validate(user.top_up_config)
async def admin_get_user_history(

View File: backend/data/model.py

@@ -5,6 +5,7 @@ import enum
import logging
from collections import defaultdict
from datetime import datetime, timezone
from json import JSONDecodeError
from typing import (
TYPE_CHECKING,
Annotated,
@@ -40,12 +41,120 @@ from pydantic_core import (
from typing_extensions import TypedDict
from backend.integrations.providers import ProviderName
from backend.util.json import loads as json_loads
from backend.util.settings import Secrets
# Type alias for any provider name (including custom ones)
AnyProviderName = str # Will be validated as ProviderName at runtime
class User(BaseModel):
"""Application-layer User model with snake_case convention."""
model_config = ConfigDict(
extra="forbid",
str_strip_whitespace=True,
)
id: str = Field(..., description="User ID")
email: str = Field(..., description="User email address")
email_verified: bool = Field(default=True, description="Whether email is verified")
name: Optional[str] = Field(None, description="User display name")
created_at: datetime = Field(..., description="When user was created")
updated_at: datetime = Field(..., description="When user was last updated")
metadata: dict[str, Any] = Field(
default_factory=dict, description="User metadata as dict"
)
integrations: str = Field(default="", description="Encrypted integrations data")
stripe_customer_id: Optional[str] = Field(None, description="Stripe customer ID")
top_up_config: Optional["AutoTopUpConfig"] = Field(
None, description="Top up configuration"
)
# Notification preferences
max_emails_per_day: int = Field(default=3, description="Maximum emails per day")
notify_on_agent_run: bool = Field(default=True, description="Notify on agent run")
notify_on_zero_balance: bool = Field(
default=True, description="Notify on zero balance"
)
notify_on_low_balance: bool = Field(
default=True, description="Notify on low balance"
)
notify_on_block_execution_failed: bool = Field(
default=True, description="Notify on block execution failure"
)
notify_on_continuous_agent_error: bool = Field(
default=True, description="Notify on continuous agent error"
)
notify_on_daily_summary: bool = Field(
default=True, description="Notify on daily summary"
)
notify_on_weekly_summary: bool = Field(
default=True, description="Notify on weekly summary"
)
notify_on_monthly_summary: bool = Field(
default=True, description="Notify on monthly summary"
)
@classmethod
def from_db(cls, prisma_user: "PrismaUser") -> "User":
"""Convert a database User object to application User model."""
# Handle metadata field - convert from JSON string or dict to dict
metadata = {}
if prisma_user.metadata:
if isinstance(prisma_user.metadata, str):
try:
metadata = json_loads(prisma_user.metadata)
except (JSONDecodeError, TypeError):
metadata = {}
elif isinstance(prisma_user.metadata, dict):
metadata = prisma_user.metadata
# Handle topUpConfig field
top_up_config = None
if prisma_user.topUpConfig:
if isinstance(prisma_user.topUpConfig, str):
try:
config_dict = json_loads(prisma_user.topUpConfig)
top_up_config = AutoTopUpConfig.model_validate(config_dict)
except (JSONDecodeError, TypeError, ValueError):
top_up_config = None
elif isinstance(prisma_user.topUpConfig, dict):
try:
top_up_config = AutoTopUpConfig.model_validate(
prisma_user.topUpConfig
)
except ValueError:
top_up_config = None
return cls(
id=prisma_user.id,
email=prisma_user.email,
email_verified=prisma_user.emailVerified or True,
name=prisma_user.name,
created_at=prisma_user.createdAt,
updated_at=prisma_user.updatedAt,
metadata=metadata,
integrations=prisma_user.integrations or "",
stripe_customer_id=prisma_user.stripeCustomerId,
top_up_config=top_up_config,
max_emails_per_day=prisma_user.maxEmailsPerDay or 3,
notify_on_agent_run=prisma_user.notifyOnAgentRun or True,
notify_on_zero_balance=prisma_user.notifyOnZeroBalance or True,
notify_on_low_balance=prisma_user.notifyOnLowBalance or True,
notify_on_block_execution_failed=prisma_user.notifyOnBlockExecutionFailed
or True,
notify_on_continuous_agent_error=prisma_user.notifyOnContinuousAgentError
or True,
notify_on_daily_summary=prisma_user.notifyOnDailySummary or True,
notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True,
notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True,
)
if TYPE_CHECKING:
from prisma.models import User as PrismaUser
from backend.data.block import BlockSchema
T = TypeVar("T")

View File: backend/data/user.py

@@ -9,11 +9,11 @@ from urllib.parse import quote_plus
from autogpt_libs.auth.models import DEFAULT_USER_ID
from fastapi import HTTPException
from prisma.enums import NotificationType
-from prisma.models import User
+from prisma.models import User as PrismaUser
from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput
from backend.data.db import prisma
-from backend.data.model import UserIntegrations, UserMetadata
+from backend.data.model import User, UserIntegrations, UserMetadata
from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO
from backend.server.v2.store.exceptions import DatabaseError
from backend.util.encryption import JSONCryptor
@@ -44,7 +44,7 @@ async def get_or_create_user(user_data: dict) -> User:
)
)
-return User.model_validate(user)
+return User.from_db(user)
except Exception as e:
raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e
@@ -53,7 +53,7 @@ async def get_user_by_id(user_id: str) -> User:
user = await prisma.user.find_unique(where={"id": user_id})
if not user:
raise ValueError(f"User not found with ID: {user_id}")
-return User.model_validate(user)
+return User.from_db(user)
async def get_user_email_by_id(user_id: str) -> Optional[str]:
@@ -67,7 +67,7 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]:
async def get_user_by_email(email: str) -> Optional[User]:
try:
user = await prisma.user.find_unique(where={"email": email})
-return User.model_validate(user) if user else None
+return User.from_db(user) if user else None
except Exception as e:
raise DatabaseError(f"Failed to get user by email {email}: {e}") from e
@@ -91,11 +91,11 @@ async def create_default_user() -> Optional[User]:
name="Default User",
)
)
-return User.model_validate(user)
+return User.from_db(user)
async def get_user_integrations(user_id: str) -> UserIntegrations:
-user = await User.prisma().find_unique_or_raise(
+user = await PrismaUser.prisma().find_unique_or_raise(
where={"id": user_id},
)
@@ -110,7 +110,7 @@ async def get_user_integrations(user_id: str) -> UserIntegrations:
async def update_user_integrations(user_id: str, data: UserIntegrations):
encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True))
-await User.prisma().update(
+await PrismaUser.prisma().update(
where={"id": user_id},
data={"integrations": encrypted_data},
)
@@ -118,7 +118,7 @@ async def update_user_integrations(user_id: str, data: UserIntegrations):
async def migrate_and_encrypt_user_integrations():
"""Migrate integration credentials and OAuth states from metadata to integrations column."""
-users = await User.prisma().find_many(
+users = await PrismaUser.prisma().find_many(
where={
"metadata": cast(
JsonFilter,
@@ -154,7 +154,7 @@ async def migrate_and_encrypt_user_integrations():
raw_metadata.pop("integration_oauth_states", None)
# Update metadata without integration data
-await User.prisma().update(
+await PrismaUser.prisma().update(
where={"id": user.id},
data={"metadata": SafeJson(raw_metadata)},
)
@@ -162,7 +162,7 @@ async def migrate_and_encrypt_user_integrations():
async def get_active_user_ids_in_timerange(start_time: str, end_time: str) -> list[str]:
try:
-users = await User.prisma().find_many(
+users = await PrismaUser.prisma().find_many(
where={
"AgentGraphExecutions": {
"some": {
@@ -192,7 +192,7 @@ async def get_active_users_ids() -> list[str]:
async def get_user_notification_preference(user_id: str) -> NotificationPreference:
try:
-user = await User.prisma().find_unique_or_raise(
+user = await PrismaUser.prisma().find_unique_or_raise(
where={"id": user_id},
)
@@ -269,7 +269,7 @@ async def update_user_notification_preference(
if data.daily_limit:
update_data["maxEmailsPerDay"] = data.daily_limit
-user = await User.prisma().update(
+user = await PrismaUser.prisma().update(
where={"id": user_id},
data=update_data,
)
@@ -307,7 +307,7 @@ async def update_user_notification_preference(
async def set_user_email_verification(user_id: str, verified: bool) -> None:
"""Set the email verification status for a user."""
try:
-await User.prisma().update(
+await PrismaUser.prisma().update(
where={"id": user_id},
data={"emailVerified": verified},
)
@@ -320,7 +320,7 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None:
async def get_user_email_verification(user_id: str) -> bool:
"""Get the email verification status for a user."""
try:
-user = await User.prisma().find_unique_or_raise(
+user = await PrismaUser.prisma().find_unique_or_raise(
where={"id": user_id},
)
return user.emailVerified

View File

@@ -102,8 +102,10 @@ async def generate_activity_status_for_execution(
Returns:
AI-generated activity status string, or None if feature is disabled
"""
-# Check LaunchDarkly feature flag for AI activity status generation
-if not is_feature_enabled(AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False):
+# Check LaunchDarkly feature flag for AI activity status generation with full context support
+if not await is_feature_enabled(
+AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False
+):
logger.debug("AI activity status generation is disabled via LaunchDarkly")
return None

View File: backend/executor/database.py

@@ -35,6 +35,7 @@ from backend.data.notifications import (
)
from backend.data.user import (
get_active_user_ids_in_timerange,
get_user_by_id,
get_user_email_by_id,
get_user_email_verification,
get_user_integrations,
@@ -129,6 +130,7 @@ class DatabaseManager(AppService):
# User Comms - async
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
get_user_by_id = _(get_user_by_id)
get_user_email_by_id = _(get_user_email_by_id)
get_user_email_verification = _(get_user_email_verification)
get_user_notification_preference = _(get_user_notification_preference)
@@ -201,6 +203,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# User Comms
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
get_user_by_id = d.get_user_by_id
get_user_email_by_id = d.get_user_email_by_id
get_user_email_verification = d.get_user_email_verification
get_user_notification_preference = d.get_user_notification_preference

View File

@@ -3,7 +3,6 @@ import logging
import autogpt_libs.auth.models
import fastapi.responses
import pytest
-from prisma.models import User
import backend.server.v2.library.model
import backend.server.v2.store.model
@@ -12,6 +11,7 @@ from backend.blocks.data_manipulation import FindInDictionaryBlock
from backend.blocks.io import AgentInputBlock
from backend.blocks.maths import CalculatorBlock, Operation
from backend.data import execution, graph
+from backend.data.model import User
from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer
from backend.usecases.sample import create_test_graph, create_test_user

View File

@@ -1,13 +1,12 @@
from pathlib import Path
-from prisma.models import User
from backend.blocks.basic import StoreValueBlock
from backend.blocks.block import BlockInstallationBlock
from backend.blocks.http import SendWebRequestBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.blocks.text import ExtractTextInformationBlock, FillTextTemplateBlock
from backend.data.graph import Graph, Link, Node, create_graph
+from backend.data.model import User
from backend.data.user import get_or_create_user
from backend.util.test import SpinTestServer, wait_execution

View File

@@ -1,9 +1,8 @@
-from prisma.models import User
from backend.blocks.llm import AIStructuredResponseGeneratorBlock
from backend.blocks.reddit import GetRedditPostsBlock, PostRedditCommentBlock
from backend.blocks.text import FillTextTemplateBlock, MatchTextPatternBlock
from backend.data.graph import Graph, Link, Node, create_graph
+from backend.data.model import User
from backend.data.user import get_or_create_user
from backend.util.test import SpinTestServer, wait_execution

View File

@@ -1,10 +1,9 @@
-from prisma.models import User
from backend.blocks.basic import StoreValueBlock
from backend.blocks.io import AgentInputBlock
from backend.blocks.text import FillTextTemplateBlock
from backend.data import graph
from backend.data.graph import create_graph
+from backend.data.model import User
from backend.data.user import get_or_create_user
from backend.util.test import SpinTestServer, wait_execution

View File: backend/util/service.py

@@ -45,6 +45,34 @@ api_comm_retry = config.pyro_client_comm_retry
api_comm_timeout = config.pyro_client_comm_timeout
api_call_timeout = config.rpc_client_call_timeout
def _validate_no_prisma_objects(obj: Any, path: str = "result") -> None:
"""
Recursively validate that no Prisma objects are being returned from service methods.
This enforces proper separation of layers - only application models should cross service boundaries.
"""
if obj is None:
return
# Check if it's a Prisma model object
if hasattr(obj, "__class__") and hasattr(obj.__class__, "__module__"):
module_name = obj.__class__.__module__
if module_name and "prisma.models" in module_name:
raise ValueError(
f"Prisma object {obj.__class__.__name__} found in {path}. "
"Service methods must return application models, not Prisma objects. "
f"Use {obj.__class__.__name__}.from_db() to convert to application model."
)
# Recursively check collections
if isinstance(obj, (list, tuple)):
for i, item in enumerate(obj):
_validate_no_prisma_objects(item, f"{path}[{i}]")
elif isinstance(obj, dict):
for key, value in obj.items():
_validate_no_prisma_objects(value, f"{path}['{key}']")
P = ParamSpec("P")
R = TypeVar("R")
EXPOSED_FLAG = "__exposed__"
@@ -111,6 +139,22 @@ class UnhealthyServiceError(ValueError):
return self.message
class HTTPClientError(Exception):
"""Exception for HTTP client errors (4xx status codes) that should not be retried."""
def __init__(self, status_code: int, message: str):
self.status_code = status_code
super().__init__(f"HTTP {status_code}: {message}")
class HTTPServerError(Exception):
"""Exception for HTTP server errors (5xx status codes) that can be retried."""
def __init__(self, status_code: int, message: str):
self.status_code = status_code
super().__init__(f"HTTP {status_code}: {message}")
EXCEPTION_MAPPING = {
e.__name__: e
for e in [
@@ -119,6 +163,8 @@ EXCEPTION_MAPPING = {
TimeoutError,
ConnectionError,
UnhealthyServiceError,
HTTPClientError,
HTTPServerError,
*[
ErrorType
for _, ErrorType in inspect.getmembers(exceptions)
@@ -191,17 +237,21 @@ class AppService(BaseAppService, ABC):
if asyncio.iscoroutinefunction(f):
async def async_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
-return await f(
+result = await f(
**{name: getattr(body, name) for name in type(body).model_fields}
)
+_validate_no_prisma_objects(result, f"{func.__name__} result")
+return result
return async_endpoint
else:
def sync_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable
-return f(
+result = f(
**{name: getattr(body, name) for name in type(body).model_fields}
)
+_validate_no_prisma_objects(result, f"{func.__name__} result")
+return result
return sync_endpoint
@@ -313,6 +363,7 @@ def get_service_client(
AttributeError, # Missing attributes
asyncio.CancelledError, # Task was cancelled
concurrent.futures.CancelledError, # Future was cancelled
HTTPClientError, # HTTP 4xx client errors - don't retry
),
)(fn)
@@ -390,11 +441,31 @@ def get_service_client(
self._connection_failure_count = 0
return response.json()
except httpx.HTTPStatusError as e:
-error = RemoteCallError.model_validate(e.response.json())
-# DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception
-raise EXCEPTION_MAPPING.get(error.type, Exception)(
-*(error.args or [str(e)])
-)
status_code = e.response.status_code
# Try to parse the error response as RemoteCallError for mapped exceptions
error_response = None
try:
error_response = RemoteCallError.model_validate(e.response.json())
except Exception:
pass
# If we successfully parsed a mapped exception type, re-raise it
if error_response and error_response.type in EXCEPTION_MAPPING:
exception_class = EXCEPTION_MAPPING[error_response.type]
args = error_response.args or [str(e)]
raise exception_class(*args)
# Otherwise categorize by HTTP status code
if 400 <= status_code < 500:
# Client errors (4xx) - wrap to prevent retries
raise HTTPClientError(status_code, str(e))
elif 500 <= status_code < 600:
# Server errors (5xx) - wrap but allow retries
raise HTTPServerError(status_code, str(e))
else:
# Other status codes (1xx, 2xx, 3xx) - re-raise original error
raise e
@_maybe_retry
def _call_method_sync(self, method_name: str, **kwargs: Any) -> Any:

View File: backend/util/service_test.py

@@ -8,6 +8,8 @@ import pytest
from backend.util.service import (
AppService,
AppServiceClient,
HTTPClientError,
HTTPServerError,
endpoint_to_async,
expose,
get_service_client,
@@ -366,3 +368,125 @@ def test_service_no_retry_when_disabled(server):
# This should fail immediately without retry
with pytest.raises(RuntimeError, match="Intended error for testing"):
client.always_failing_add(5, 3)
class TestHTTPErrorRetryBehavior:
"""Test that HTTP client errors (4xx) are not retried but server errors (5xx) can be."""
# Note: These tests access private methods for testing internal behavior
# Type ignore comments are used to suppress warnings about accessing private methods
def test_http_client_error_not_retried(self):
"""Test that 4xx errors are wrapped as HTTPClientError and not retried."""
# Create a mock response with 404 status
mock_response = Mock()
mock_response.status_code = 404
mock_response.json.return_value = {"message": "Not found"}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"404 Not Found", request=Mock(), response=mock_response
)
# Create client
client = get_service_client(ServiceTestClient)
dynamic_client = client
# Test the _handle_call_method_response directly
with pytest.raises(HTTPClientError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
response=mock_response, method_name="test_method"
)
assert exc_info.value.status_code == 404
assert "404" in str(exc_info.value)
def test_http_server_error_can_be_retried(self):
"""Test that 5xx errors are wrapped as HTTPServerError and can be retried."""
# Create a mock response with 500 status
mock_response = Mock()
mock_response.status_code = 500
mock_response.json.return_value = {"message": "Internal server error"}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"500 Internal Server Error", request=Mock(), response=mock_response
)
# Create client
client = get_service_client(ServiceTestClient)
dynamic_client = client
# Test the _handle_call_method_response directly
with pytest.raises(HTTPServerError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
response=mock_response, method_name="test_method"
)
assert exc_info.value.status_code == 500
assert "500" in str(exc_info.value)
def test_mapped_exception_preserves_original_type(self):
"""Test that mapped exceptions preserve their original type regardless of HTTP status."""
# Create a mock response with ValueError in the remote call error
mock_response = Mock()
mock_response.status_code = 400
mock_response.json.return_value = {
"type": "ValueError",
"args": ["Invalid parameter value"],
}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"400 Bad Request", request=Mock(), response=mock_response
)
# Create client
client = get_service_client(ServiceTestClient)
dynamic_client = client
# Test the _handle_call_method_response directly
with pytest.raises(ValueError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore[attr-defined]
response=mock_response, method_name="test_method"
)
assert "Invalid parameter value" in str(exc_info.value)
def test_client_error_status_codes_coverage(self):
"""Test that various 4xx status codes are all wrapped as HTTPClientError."""
client_error_codes = [400, 401, 403, 404, 405, 409, 422, 429]
for status_code in client_error_codes:
mock_response = Mock()
mock_response.status_code = status_code
mock_response.json.return_value = {"message": f"Error {status_code}"}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
f"{status_code} Error", request=Mock(), response=mock_response
)
client = get_service_client(ServiceTestClient)
dynamic_client = client
with pytest.raises(HTTPClientError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore
response=mock_response, method_name="test_method"
)
assert exc_info.value.status_code == status_code
def test_server_error_status_codes_coverage(self):
"""Test that various 5xx status codes are all wrapped as HTTPServerError."""
server_error_codes = [500, 501, 502, 503, 504, 505]
for status_code in server_error_codes:
mock_response = Mock()
mock_response.status_code = status_code
mock_response.json.return_value = {"message": f"Error {status_code}"}
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
f"{status_code} Error", request=Mock(), response=mock_response
)
client = get_service_client(ServiceTestClient)
dynamic_client = client
with pytest.raises(HTTPServerError) as exc_info:
dynamic_client._handle_call_method_response( # type: ignore
response=mock_response, method_name="test_method"
)
assert exc_info.value.status_code == status_code