diff --git a/autogpt_platform/autogpt_libs/autogpt_libs/feature_flag/client.py b/autogpt_platform/autogpt_libs/autogpt_libs/feature_flag/client.py index 1bd7584942..188ee4198a 100644 --- a/autogpt_platform/autogpt_libs/autogpt_libs/feature_flag/client.py +++ b/autogpt_platform/autogpt_libs/autogpt_libs/feature_flag/client.py @@ -1,8 +1,7 @@ -import asyncio import contextlib import logging from functools import wraps -from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, cast +from typing import Any, Awaitable, Callable, TypeVar import ldclient from fastapi import HTTPException @@ -10,6 +9,8 @@ from ldclient import Context, LDClient from ldclient.config import Config from typing_extensions import ParamSpec +from autogpt_libs.utils.cache import async_ttl_cache + from .config import SETTINGS logger = logging.getLogger(__name__) @@ -55,20 +56,46 @@ def shutdown_launchdarkly() -> None: logger.info("LaunchDarkly client closed successfully") -def create_context( - user_id: str, additional_attributes: Optional[Dict[str, Any]] = None -) -> Context: - """Create LaunchDarkly context with optional additional attributes.""" - builder = Context.builder(str(user_id)).kind("user") - if additional_attributes: - for key, value in additional_attributes.items(): - builder.set(key, value) +@async_ttl_cache(maxsize=1000, ttl_seconds=86400) # 1000 entries, 24 hours TTL +async def _fetch_user_context_data(user_id: str) -> Context: + """ + Fetch user context for LaunchDarkly from Supabase. + + Args: + user_id: The user ID to fetch data for + + Returns: + LaunchDarkly Context object + """ + builder = Context.builder(user_id).kind("user").anonymous(True) + + try: + from backend.util.clients import get_supabase + + # If we have user data, update context + response = get_supabase().auth.admin.get_user_by_id(user_id) + if response and response.user: + user = response.user + builder.anonymous(False) + if user.role: + builder.set("role", user.role) + if user.email: + builder.set("email", user.email) + builder.set("email_domain", user.email.split("@")[-1]) + + except Exception as e: + logger.warning(f"Failed to fetch user context for {user_id}: {e}") + return builder.build() -def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bool: +async def is_feature_enabled( + flag_key: str, + user_id: str, + default: bool = False, +) -> bool: """ - Simple helper to check if a feature flag is enabled for a user. + Check if a feature flag is enabled for a user. Args: flag_key: The LaunchDarkly feature flag key @@ -80,11 +107,18 @@ def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bo """ try: client = get_client() - context = create_context(str(user_id)) - return client.variation(flag_key, context, default) + + # Get user context from Supabase + context = await _fetch_user_context_data(user_id) + + # Evaluate flag + result = client.variation(flag_key, context, default) + + logger.debug(f"Feature flag {flag_key} for user {user_id}: {result}") + return result except Exception as e: - logger.debug( + logger.warning( f"LaunchDarkly flag evaluation failed for {flag_key}: {e}, using default={default}" ) return default @@ -93,16 +127,19 @@ def is_feature_enabled(flag_key: str, user_id: str, default: bool = False) -> bo def feature_flag( flag_key: str, default: bool = False, -) -> Callable[ - [Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]] -]: +) -> Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]: """ - Decorator for feature flag protected endpoints. + Decorator for async feature flag protected endpoints. + + Args: + flag_key: The LaunchDarkly feature flag key + default: Default value if flag evaluation fails + + Returns: + Decorator that only works with async functions """ - def decorator( - func: Callable[P, Union[T, Awaitable[T]]], - ) -> Callable[P, Union[T, Awaitable[T]]]: + def decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]: @wraps(func) async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: try: @@ -116,73 +153,24 @@ def feature_flag( ) is_enabled = default else: - context = create_context(str(user_id)) - is_enabled = get_client().variation(flag_key, context, default) - - if not is_enabled: - raise HTTPException(status_code=404, detail="Feature not available") - - result = func(*args, **kwargs) - if asyncio.iscoroutine(result): - return await result - return cast(T, result) - except Exception as e: - logger.error(f"Error evaluating feature flag {flag_key}: {e}") - raise - - @wraps(func) - def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - try: - user_id = kwargs.get("user_id") - if not user_id: - raise ValueError("user_id is required") - - if not get_client().is_initialized(): - logger.warning( - f"LaunchDarkly not initialized, using default={default}" + # Use the simplified function + is_enabled = await is_feature_enabled( + flag_key, str(user_id), default ) - is_enabled = default - else: - context = create_context(str(user_id)) - is_enabled = get_client().variation(flag_key, context, default) if not is_enabled: raise HTTPException(status_code=404, detail="Feature not available") - return cast(T, func(*args, **kwargs)) + return await func(*args, **kwargs) except Exception as e: logger.error(f"Error evaluating feature flag {flag_key}: {e}") raise - return cast( - Callable[P, Union[T, Awaitable[T]]], - async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper, - ) + return async_wrapper return decorator -def percentage_rollout( - flag_key: str, - default: bool = False, -) -> Callable[ - [Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]] -]: - """Decorator for percentage-based rollouts.""" - return feature_flag(flag_key, default) - - -def beta_feature( - flag_key: Optional[str] = None, - unauthorized_response: Any = {"message": "Not available in beta"}, -) -> Callable[ - [Callable[P, Union[T, Awaitable[T]]]], Callable[P, Union[T, Awaitable[T]]] -]: - """Decorator for beta features.""" - actual_key = f"beta-{flag_key}" if flag_key else "beta" - return feature_flag(actual_key, False) - - @contextlib.contextmanager def mock_flag_variation(flag_key: str, return_value: Any): """Context manager for testing feature flags.""" diff --git a/autogpt_platform/autogpt_libs/autogpt_libs/utils/cache.py b/autogpt_platform/autogpt_libs/autogpt_libs/utils/cache.py index 69858570ab..23328e46a3 100644 --- a/autogpt_platform/autogpt_libs/autogpt_libs/utils/cache.py +++ b/autogpt_platform/autogpt_libs/autogpt_libs/utils/cache.py @@ -1,17 +1,34 @@ import inspect +import logging import threading -from typing import Awaitable, Callable, ParamSpec, TypeVar, cast, overload +import time +from functools import wraps +from typing import ( + Awaitable, + Callable, + ParamSpec, + Protocol, + Tuple, + TypeVar, + cast, + overload, + runtime_checkable, +) P = ParamSpec("P") R = TypeVar("R") - -@overload -def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: ... +logger = logging.getLogger(__name__) @overload -def thread_cached(func: Callable[P, R]) -> Callable[P, R]: ... +def thread_cached(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + pass + + +@overload +def thread_cached(func: Callable[P, R]) -> Callable[P, R]: + pass def thread_cached( @@ -57,3 +74,193 @@ def thread_cached( def clear_thread_cache(func: Callable) -> None: if clear := getattr(func, "clear_cache", None): clear() + + +FuncT = TypeVar("FuncT") + + +R_co = TypeVar("R_co", covariant=True) + + +@runtime_checkable +class AsyncCachedFunction(Protocol[P, R_co]): + """Protocol for async functions with cache management methods.""" + + def cache_clear(self) -> None: + """Clear all cached entries.""" + return None + + def cache_info(self) -> dict[str, int | None]: + """Get cache statistics.""" + return {} + + async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R_co: + """Call the cached function.""" + return None # type: ignore + + +def async_ttl_cache( + maxsize: int = 128, ttl_seconds: int | None = None +) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]: + """ + TTL (Time To Live) cache decorator for async functions. + + Similar to functools.lru_cache but works with async functions and includes optional TTL. + + Args: + maxsize: Maximum number of cached entries + ttl_seconds: Time to live in seconds. If None, entries never expire (like lru_cache) + + Returns: + Decorator function + + Example: + # With TTL + @async_ttl_cache(maxsize=1000, ttl_seconds=300) + async def api_call(param: str) -> dict: + return {"result": param} + + # Without TTL (permanent cache like lru_cache) + @async_ttl_cache(maxsize=1000) + async def expensive_computation(param: str) -> dict: + return {"result": param} + """ + + def decorator( + async_func: Callable[P, Awaitable[R]], + ) -> AsyncCachedFunction[P, R]: + # Cache storage - use union type to handle both cases + cache_storage: dict[tuple, R | Tuple[R, float]] = {} + + @wraps(async_func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + # Create cache key from arguments + key = (args, tuple(sorted(kwargs.items()))) + current_time = time.time() + + # Check if we have a valid cached entry + if key in cache_storage: + if ttl_seconds is None: + # No TTL - return cached result directly + logger.debug( + f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}" + ) + return cast(R, cache_storage[key]) + else: + # With TTL - check expiration + cached_data = cache_storage[key] + if isinstance(cached_data, tuple): + result, timestamp = cached_data + if current_time - timestamp < ttl_seconds: + logger.debug( + f"Cache hit for {async_func.__name__} with key: {str(key)[:50]}" + ) + return cast(R, result) + else: + # Expired entry + del cache_storage[key] + logger.debug( + f"Cache entry expired for {async_func.__name__}" + ) + + # Cache miss or expired - fetch fresh data + logger.debug( + f"Cache miss for {async_func.__name__} with key: {str(key)[:50]}" + ) + result = await async_func(*args, **kwargs) + + # Store in cache + if ttl_seconds is None: + cache_storage[key] = result + else: + cache_storage[key] = (result, current_time) + + # Simple cleanup when cache gets too large + if len(cache_storage) > maxsize: + # Remove oldest entries (simple FIFO cleanup) + cutoff = maxsize // 2 + oldest_keys = list(cache_storage.keys())[:-cutoff] if cutoff > 0 else [] + for old_key in oldest_keys: + cache_storage.pop(old_key, None) + logger.debug( + f"Cache cleanup: removed {len(oldest_keys)} entries for {async_func.__name__}" + ) + + return result + + # Add cache management methods (similar to functools.lru_cache) + def cache_clear() -> None: + cache_storage.clear() + + def cache_info() -> dict[str, int | None]: + return { + "size": len(cache_storage), + "maxsize": maxsize, + "ttl_seconds": ttl_seconds, + } + + # Attach methods to wrapper + setattr(wrapper, "cache_clear", cache_clear) + setattr(wrapper, "cache_info", cache_info) + + return cast(AsyncCachedFunction[P, R], wrapper) + + return decorator + + +@overload +def async_cache( + func: Callable[P, Awaitable[R]], +) -> AsyncCachedFunction[P, R]: + pass + + +@overload +def async_cache( + func: None = None, + *, + maxsize: int = 128, +) -> Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]]: + pass + + +def async_cache( + func: Callable[P, Awaitable[R]] | None = None, + *, + maxsize: int = 128, +) -> ( + AsyncCachedFunction[P, R] + | Callable[[Callable[P, Awaitable[R]]], AsyncCachedFunction[P, R]] +): + """ + Process-level cache decorator for async functions (no TTL). + + Similar to functools.lru_cache but works with async functions. + This is a convenience wrapper around async_ttl_cache with ttl_seconds=None. + + Args: + func: The async function to cache (when used without parentheses) + maxsize: Maximum number of cached entries + + Returns: + Decorated function or decorator + + Example: + # Without parentheses (uses default maxsize=128) + @async_cache + async def get_data(param: str) -> dict: + return {"result": param} + + # With parentheses and custom maxsize + @async_cache(maxsize=1000) + async def expensive_computation(param: str) -> dict: + # Expensive computation here + return {"result": param} + """ + if func is None: + # Called with parentheses @async_cache() or @async_cache(maxsize=...) + return async_ttl_cache(maxsize=maxsize, ttl_seconds=None) + else: + # Called without parentheses @async_cache + decorator = async_ttl_cache(maxsize=maxsize, ttl_seconds=None) + return decorator(func) diff --git a/autogpt_platform/autogpt_libs/autogpt_libs/utils/cache_test.py b/autogpt_platform/autogpt_libs/autogpt_libs/utils/cache_test.py index 91f9b0b824..e6ca3ecdfd 100644 --- a/autogpt_platform/autogpt_libs/autogpt_libs/utils/cache_test.py +++ b/autogpt_platform/autogpt_libs/autogpt_libs/utils/cache_test.py @@ -16,7 +16,12 @@ from unittest.mock import Mock import pytest -from autogpt_libs.utils.cache import clear_thread_cache, thread_cached +from autogpt_libs.utils.cache import ( + async_cache, + async_ttl_cache, + clear_thread_cache, + thread_cached, +) class TestThreadCached: @@ -323,3 +328,378 @@ class TestThreadCached: assert function_using_mock(2) == 42 assert mock.call_count == 2 + + +class TestAsyncTTLCache: + """Tests for the @async_ttl_cache decorator.""" + + @pytest.mark.asyncio + async def test_basic_caching(self): + """Test basic caching functionality.""" + call_count = 0 + + @async_ttl_cache(maxsize=10, ttl_seconds=60) + async def cached_function(x: int, y: int = 0) -> int: + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.01) # Simulate async work + return x + y + + # First call + result1 = await cached_function(1, 2) + assert result1 == 3 + assert call_count == 1 + + # Second call with same args - should use cache + result2 = await cached_function(1, 2) + assert result2 == 3 + assert call_count == 1 # No additional call + + # Different args - should call function again + result3 = await cached_function(2, 3) + assert result3 == 5 + assert call_count == 2 + + @pytest.mark.asyncio + async def test_ttl_expiration(self): + """Test that cache entries expire after TTL.""" + call_count = 0 + + @async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL + async def short_lived_cache(x: int) -> int: + nonlocal call_count + call_count += 1 + return x * 2 + + # First call + result1 = await short_lived_cache(5) + assert result1 == 10 + assert call_count == 1 + + # Second call immediately - should use cache + result2 = await short_lived_cache(5) + assert result2 == 10 + assert call_count == 1 + + # Wait for TTL to expire + await asyncio.sleep(1.1) + + # Third call after expiration - should call function again + result3 = await short_lived_cache(5) + assert result3 == 10 + assert call_count == 2 + + @pytest.mark.asyncio + async def test_cache_info(self): + """Test cache info functionality.""" + + @async_ttl_cache(maxsize=5, ttl_seconds=300) + async def info_test_function(x: int) -> int: + return x * 3 + + # Check initial cache info + info = info_test_function.cache_info() + assert info["size"] == 0 + assert info["maxsize"] == 5 + assert info["ttl_seconds"] == 300 + + # Add an entry + await info_test_function(1) + info = info_test_function.cache_info() + assert info["size"] == 1 + + @pytest.mark.asyncio + async def test_cache_clear(self): + """Test cache clearing functionality.""" + call_count = 0 + + @async_ttl_cache(maxsize=10, ttl_seconds=60) + async def clearable_function(x: int) -> int: + nonlocal call_count + call_count += 1 + return x * 4 + + # First call + result1 = await clearable_function(2) + assert result1 == 8 + assert call_count == 1 + + # Second call - should use cache + result2 = await clearable_function(2) + assert result2 == 8 + assert call_count == 1 + + # Clear cache + clearable_function.cache_clear() + + # Third call after clear - should call function again + result3 = await clearable_function(2) + assert result3 == 8 + assert call_count == 2 + + @pytest.mark.asyncio + async def test_maxsize_cleanup(self): + """Test that cache cleans up when maxsize is exceeded.""" + call_count = 0 + + @async_ttl_cache(maxsize=3, ttl_seconds=60) + async def size_limited_function(x: int) -> int: + nonlocal call_count + call_count += 1 + return x**2 + + # Fill cache to maxsize + await size_limited_function(1) # call_count: 1 + await size_limited_function(2) # call_count: 2 + await size_limited_function(3) # call_count: 3 + + info = size_limited_function.cache_info() + assert info["size"] == 3 + + # Add one more entry - should trigger cleanup + await size_limited_function(4) # call_count: 4 + + # Cache size should be reduced (cleanup removes oldest entries) + info = size_limited_function.cache_info() + assert info["size"] is not None and info["size"] <= 3 # Should be cleaned up + + @pytest.mark.asyncio + async def test_argument_variations(self): + """Test caching with different argument patterns.""" + call_count = 0 + + @async_ttl_cache(maxsize=10, ttl_seconds=60) + async def arg_test_function(a: int, b: str = "default", *, c: int = 100) -> str: + nonlocal call_count + call_count += 1 + return f"{a}-{b}-{c}" + + # Different ways to call with same logical arguments + result1 = await arg_test_function(1, "test", c=200) + assert call_count == 1 + + # Same arguments, same order - should use cache + result2 = await arg_test_function(1, "test", c=200) + assert call_count == 1 + assert result1 == result2 + + # Different arguments - should call function + result3 = await arg_test_function(2, "test", c=200) + assert call_count == 2 + assert result1 != result3 + + @pytest.mark.asyncio + async def test_exception_handling(self): + """Test that exceptions are not cached.""" + call_count = 0 + + @async_ttl_cache(maxsize=10, ttl_seconds=60) + async def exception_function(x: int) -> int: + nonlocal call_count + call_count += 1 + if x < 0: + raise ValueError("Negative value not allowed") + return x * 2 + + # Successful call - should be cached + result1 = await exception_function(5) + assert result1 == 10 + assert call_count == 1 + + # Same successful call - should use cache + result2 = await exception_function(5) + assert result2 == 10 + assert call_count == 1 + + # Exception call - should not be cached + with pytest.raises(ValueError): + await exception_function(-1) + assert call_count == 2 + + # Same exception call - should call again (not cached) + with pytest.raises(ValueError): + await exception_function(-1) + assert call_count == 3 + + @pytest.mark.asyncio + async def test_concurrent_calls(self): + """Test caching behavior with concurrent calls.""" + call_count = 0 + + @async_ttl_cache(maxsize=10, ttl_seconds=60) + async def concurrent_function(x: int) -> int: + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.05) # Simulate work + return x * x + + # Launch concurrent calls with same arguments + tasks = [concurrent_function(3) for _ in range(5)] + results = await asyncio.gather(*tasks) + + # All results should be the same + assert all(result == 9 for result in results) + + # Note: Due to race conditions, call_count might be up to 5 for concurrent calls + # This tests that the cache doesn't break under concurrent access + assert 1 <= call_count <= 5 + + +class TestAsyncCache: + """Tests for the @async_cache decorator (no TTL).""" + + @pytest.mark.asyncio + async def test_basic_caching_no_ttl(self): + """Test basic caching functionality without TTL.""" + call_count = 0 + + @async_cache(maxsize=10) + async def cached_function(x: int, y: int = 0) -> int: + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.01) # Simulate async work + return x + y + + # First call + result1 = await cached_function(1, 2) + assert result1 == 3 + assert call_count == 1 + + # Second call with same args - should use cache + result2 = await cached_function(1, 2) + assert result2 == 3 + assert call_count == 1 # No additional call + + # Third call after some time - should still use cache (no TTL) + await asyncio.sleep(0.05) + result3 = await cached_function(1, 2) + assert result3 == 3 + assert call_count == 1 # Still no additional call + + # Different args - should call function again + result4 = await cached_function(2, 3) + assert result4 == 5 + assert call_count == 2 + + @pytest.mark.asyncio + async def test_no_ttl_vs_ttl_behavior(self): + """Test the difference between TTL and no-TTL caching.""" + ttl_call_count = 0 + no_ttl_call_count = 0 + + @async_ttl_cache(maxsize=10, ttl_seconds=1) # Short TTL + async def ttl_function(x: int) -> int: + nonlocal ttl_call_count + ttl_call_count += 1 + return x * 2 + + @async_cache(maxsize=10) # No TTL + async def no_ttl_function(x: int) -> int: + nonlocal no_ttl_call_count + no_ttl_call_count += 1 + return x * 2 + + # First calls + await ttl_function(5) + await no_ttl_function(5) + assert ttl_call_count == 1 + assert no_ttl_call_count == 1 + + # Wait for TTL to expire + await asyncio.sleep(1.1) + + # Second calls after TTL expiry + await ttl_function(5) # Should call function again (TTL expired) + await no_ttl_function(5) # Should use cache (no TTL) + assert ttl_call_count == 2 # TTL function called again + assert no_ttl_call_count == 1 # No-TTL function still cached + + @pytest.mark.asyncio + async def test_async_cache_info(self): + """Test cache info for no-TTL cache.""" + + @async_cache(maxsize=5) + async def info_test_function(x: int) -> int: + return x * 3 + + # Check initial cache info + info = info_test_function.cache_info() + assert info["size"] == 0 + assert info["maxsize"] == 5 + assert info["ttl_seconds"] is None # No TTL + + # Add an entry + await info_test_function(1) + info = info_test_function.cache_info() + assert info["size"] == 1 + + +class TestTTLOptional: + """Tests for optional TTL functionality.""" + + @pytest.mark.asyncio + async def test_ttl_none_behavior(self): + """Test that ttl_seconds=None works like no TTL.""" + call_count = 0 + + @async_ttl_cache(maxsize=10, ttl_seconds=None) + async def no_ttl_via_none(x: int) -> int: + nonlocal call_count + call_count += 1 + return x**2 + + # First call + result1 = await no_ttl_via_none(3) + assert result1 == 9 + assert call_count == 1 + + # Wait (would expire if there was TTL) + await asyncio.sleep(0.1) + + # Second call - should still use cache + result2 = await no_ttl_via_none(3) + assert result2 == 9 + assert call_count == 1 # No additional call + + # Check cache info + info = no_ttl_via_none.cache_info() + assert info["ttl_seconds"] is None + + @pytest.mark.asyncio + async def test_cache_options_comparison(self): + """Test different cache options work as expected.""" + ttl_calls = 0 + no_ttl_calls = 0 + + @async_ttl_cache(maxsize=10, ttl_seconds=1) # With TTL + async def ttl_function(x: int) -> int: + nonlocal ttl_calls + ttl_calls += 1 + return x * 10 + + @async_cache(maxsize=10) # Process-level cache (no TTL) + async def process_function(x: int) -> int: + nonlocal no_ttl_calls + no_ttl_calls += 1 + return x * 10 + + # Both should cache initially + await ttl_function(3) + await process_function(3) + assert ttl_calls == 1 + assert no_ttl_calls == 1 + + # Immediate second calls - both should use cache + await ttl_function(3) + await process_function(3) + assert ttl_calls == 1 + assert no_ttl_calls == 1 + + # Wait for TTL to expire + await asyncio.sleep(1.1) + + # After TTL expiry + await ttl_function(3) # Should call function again + await process_function(3) # Should still use cache + assert ttl_calls == 2 # TTL cache expired, called again + assert no_ttl_calls == 1 # Process cache never expires diff --git a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker.py b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker.py index 5c96309592..c09eac09e1 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker.py +++ b/autogpt_platform/backend/backend/blocks/test/test_smart_decision_maker.py @@ -1,9 +1,8 @@ import logging import pytest -from prisma.models import User -from backend.data.model import ProviderName +from backend.data.model import ProviderName, User from backend.server.model import CreateGraph from backend.server.rest_api import AgentServer from backend.usecases.sample import create_test_graph, create_test_user diff --git a/autogpt_platform/backend/backend/data/credit.py b/autogpt_platform/backend/backend/data/credit.py index af0d25e8bb..5bddbd42ad 100644 --- a/autogpt_platform/backend/backend/data/credit.py +++ b/autogpt_platform/backend/backend/data/credit.py @@ -998,8 +998,8 @@ def get_block_costs() -> dict[str, list[BlockCost]]: async def get_stripe_customer_id(user_id: str) -> str: user = await get_user_by_id(user_id) - if user.stripeCustomerId: - return user.stripeCustomerId + if user.stripe_customer_id: + return user.stripe_customer_id customer = stripe.Customer.create( name=user.name or "", @@ -1022,10 +1022,10 @@ async def set_auto_top_up(user_id: str, config: AutoTopUpConfig): async def get_auto_top_up(user_id: str) -> AutoTopUpConfig: user = await get_user_by_id(user_id) - if not user.topUpConfig: + if not user.top_up_config: return AutoTopUpConfig(threshold=0, amount=0) - return AutoTopUpConfig.model_validate(user.topUpConfig) + return AutoTopUpConfig.model_validate(user.top_up_config) async def admin_get_user_history( diff --git a/autogpt_platform/backend/backend/data/model.py b/autogpt_platform/backend/backend/data/model.py index 66854dc73e..fa9e55ce0b 100644 --- a/autogpt_platform/backend/backend/data/model.py +++ b/autogpt_platform/backend/backend/data/model.py @@ -5,6 +5,7 @@ import enum import logging from collections import defaultdict from datetime import datetime, timezone +from json import JSONDecodeError from typing import ( TYPE_CHECKING, Annotated, @@ -40,12 +41,120 @@ from pydantic_core import ( from typing_extensions import TypedDict from backend.integrations.providers import ProviderName +from backend.util.json import loads as json_loads from backend.util.settings import Secrets # Type alias for any provider name (including custom ones) AnyProviderName = str # Will be validated as ProviderName at runtime + +class User(BaseModel): + """Application-layer User model with snake_case convention.""" + + model_config = ConfigDict( + extra="forbid", + str_strip_whitespace=True, + ) + + id: str = Field(..., description="User ID") + email: str = Field(..., description="User email address") + email_verified: bool = Field(default=True, description="Whether email is verified") + name: Optional[str] = Field(None, description="User display name") + created_at: datetime = Field(..., description="When user was created") + updated_at: datetime = Field(..., description="When user was last updated") + metadata: dict[str, Any] = Field( + default_factory=dict, description="User metadata as dict" + ) + integrations: str = Field(default="", description="Encrypted integrations data") + stripe_customer_id: Optional[str] = Field(None, description="Stripe customer ID") + top_up_config: Optional["AutoTopUpConfig"] = Field( + None, description="Top up configuration" + ) + + # Notification preferences + max_emails_per_day: int = Field(default=3, description="Maximum emails per day") + notify_on_agent_run: bool = Field(default=True, description="Notify on agent run") + notify_on_zero_balance: bool = Field( + default=True, description="Notify on zero balance" + ) + notify_on_low_balance: bool = Field( + default=True, description="Notify on low balance" + ) + notify_on_block_execution_failed: bool = Field( + default=True, description="Notify on block execution failure" + ) + notify_on_continuous_agent_error: bool = Field( + default=True, description="Notify on continuous agent error" + ) + notify_on_daily_summary: bool = Field( + default=True, description="Notify on daily summary" + ) + notify_on_weekly_summary: bool = Field( + default=True, description="Notify on weekly summary" + ) + notify_on_monthly_summary: bool = Field( + default=True, description="Notify on monthly summary" + ) + + @classmethod + def from_db(cls, prisma_user: "PrismaUser") -> "User": + """Convert a database User object to application User model.""" + # Handle metadata field - convert from JSON string or dict to dict + metadata = {} + if prisma_user.metadata: + if isinstance(prisma_user.metadata, str): + try: + metadata = json_loads(prisma_user.metadata) + except (JSONDecodeError, TypeError): + metadata = {} + elif isinstance(prisma_user.metadata, dict): + metadata = prisma_user.metadata + + # Handle topUpConfig field + top_up_config = None + if prisma_user.topUpConfig: + if isinstance(prisma_user.topUpConfig, str): + try: + config_dict = json_loads(prisma_user.topUpConfig) + top_up_config = AutoTopUpConfig.model_validate(config_dict) + except (JSONDecodeError, TypeError, ValueError): + top_up_config = None + elif isinstance(prisma_user.topUpConfig, dict): + try: + top_up_config = AutoTopUpConfig.model_validate( + prisma_user.topUpConfig + ) + except ValueError: + top_up_config = None + + return cls( + id=prisma_user.id, + email=prisma_user.email, + email_verified=prisma_user.emailVerified or True, + name=prisma_user.name, + created_at=prisma_user.createdAt, + updated_at=prisma_user.updatedAt, + metadata=metadata, + integrations=prisma_user.integrations or "", + stripe_customer_id=prisma_user.stripeCustomerId, + top_up_config=top_up_config, + max_emails_per_day=prisma_user.maxEmailsPerDay or 3, + notify_on_agent_run=prisma_user.notifyOnAgentRun or True, + notify_on_zero_balance=prisma_user.notifyOnZeroBalance or True, + notify_on_low_balance=prisma_user.notifyOnLowBalance or True, + notify_on_block_execution_failed=prisma_user.notifyOnBlockExecutionFailed + or True, + notify_on_continuous_agent_error=prisma_user.notifyOnContinuousAgentError + or True, + notify_on_daily_summary=prisma_user.notifyOnDailySummary or True, + notify_on_weekly_summary=prisma_user.notifyOnWeeklySummary or True, + notify_on_monthly_summary=prisma_user.notifyOnMonthlySummary or True, + ) + + if TYPE_CHECKING: + from prisma.models import User as PrismaUser + from backend.data.block import BlockSchema T = TypeVar("T") diff --git a/autogpt_platform/backend/backend/data/user.py b/autogpt_platform/backend/backend/data/user.py index 8b57a1d6bd..5b37d54152 100644 --- a/autogpt_platform/backend/backend/data/user.py +++ b/autogpt_platform/backend/backend/data/user.py @@ -9,11 +9,11 @@ from urllib.parse import quote_plus from autogpt_libs.auth.models import DEFAULT_USER_ID from fastapi import HTTPException from prisma.enums import NotificationType -from prisma.models import User +from prisma.models import User as PrismaUser from prisma.types import JsonFilter, UserCreateInput, UserUpdateInput from backend.data.db import prisma -from backend.data.model import UserIntegrations, UserMetadata +from backend.data.model import User, UserIntegrations, UserMetadata from backend.data.notifications import NotificationPreference, NotificationPreferenceDTO from backend.server.v2.store.exceptions import DatabaseError from backend.util.encryption import JSONCryptor @@ -44,7 +44,7 @@ async def get_or_create_user(user_data: dict) -> User: ) ) - return User.model_validate(user) + return User.from_db(user) except Exception as e: raise DatabaseError(f"Failed to get or create user {user_data}: {e}") from e @@ -53,7 +53,7 @@ async def get_user_by_id(user_id: str) -> User: user = await prisma.user.find_unique(where={"id": user_id}) if not user: raise ValueError(f"User not found with ID: {user_id}") - return User.model_validate(user) + return User.from_db(user) async def get_user_email_by_id(user_id: str) -> Optional[str]: @@ -67,7 +67,7 @@ async def get_user_email_by_id(user_id: str) -> Optional[str]: async def get_user_by_email(email: str) -> Optional[User]: try: user = await prisma.user.find_unique(where={"email": email}) - return User.model_validate(user) if user else None + return User.from_db(user) if user else None except Exception as e: raise DatabaseError(f"Failed to get user by email {email}: {e}") from e @@ -91,11 +91,11 @@ async def create_default_user() -> Optional[User]: name="Default User", ) ) - return User.model_validate(user) + return User.from_db(user) async def get_user_integrations(user_id: str) -> UserIntegrations: - user = await User.prisma().find_unique_or_raise( + user = await PrismaUser.prisma().find_unique_or_raise( where={"id": user_id}, ) @@ -110,7 +110,7 @@ async def get_user_integrations(user_id: str) -> UserIntegrations: async def update_user_integrations(user_id: str, data: UserIntegrations): encrypted_data = JSONCryptor().encrypt(data.model_dump(exclude_none=True)) - await User.prisma().update( + await PrismaUser.prisma().update( where={"id": user_id}, data={"integrations": encrypted_data}, ) @@ -118,7 +118,7 @@ async def update_user_integrations(user_id: str, data: UserIntegrations): async def migrate_and_encrypt_user_integrations(): """Migrate integration credentials and OAuth states from metadata to integrations column.""" - users = await User.prisma().find_many( + users = await PrismaUser.prisma().find_many( where={ "metadata": cast( JsonFilter, @@ -154,7 +154,7 @@ async def migrate_and_encrypt_user_integrations(): raw_metadata.pop("integration_oauth_states", None) # Update metadata without integration data - await User.prisma().update( + await PrismaUser.prisma().update( where={"id": user.id}, data={"metadata": SafeJson(raw_metadata)}, ) @@ -162,7 +162,7 @@ async def migrate_and_encrypt_user_integrations(): async def get_active_user_ids_in_timerange(start_time: str, end_time: str) -> list[str]: try: - users = await User.prisma().find_many( + users = await PrismaUser.prisma().find_many( where={ "AgentGraphExecutions": { "some": { @@ -192,7 +192,7 @@ async def get_active_users_ids() -> list[str]: async def get_user_notification_preference(user_id: str) -> NotificationPreference: try: - user = await User.prisma().find_unique_or_raise( + user = await PrismaUser.prisma().find_unique_or_raise( where={"id": user_id}, ) @@ -269,7 +269,7 @@ async def update_user_notification_preference( if data.daily_limit: update_data["maxEmailsPerDay"] = data.daily_limit - user = await User.prisma().update( + user = await PrismaUser.prisma().update( where={"id": user_id}, data=update_data, ) @@ -307,7 +307,7 @@ async def update_user_notification_preference( async def set_user_email_verification(user_id: str, verified: bool) -> None: """Set the email verification status for a user.""" try: - await User.prisma().update( + await PrismaUser.prisma().update( where={"id": user_id}, data={"emailVerified": verified}, ) @@ -320,7 +320,7 @@ async def set_user_email_verification(user_id: str, verified: bool) -> None: async def get_user_email_verification(user_id: str) -> bool: """Get the email verification status for a user.""" try: - user = await User.prisma().find_unique_or_raise( + user = await PrismaUser.prisma().find_unique_or_raise( where={"id": user_id}, ) return user.emailVerified diff --git a/autogpt_platform/backend/backend/executor/activity_status_generator.py b/autogpt_platform/backend/backend/executor/activity_status_generator.py index 3ad13d29d3..eea0faa449 100644 --- a/autogpt_platform/backend/backend/executor/activity_status_generator.py +++ b/autogpt_platform/backend/backend/executor/activity_status_generator.py @@ -102,8 +102,10 @@ async def generate_activity_status_for_execution( Returns: AI-generated activity status string, or None if feature is disabled """ - # Check LaunchDarkly feature flag for AI activity status generation - if not is_feature_enabled(AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False): + # Check LaunchDarkly feature flag for AI activity status generation with full context support + if not await is_feature_enabled( + AI_ACTIVITY_STATUS_FLAG_KEY, user_id, default=False + ): logger.debug("AI activity status generation is disabled via LaunchDarkly") return None diff --git a/autogpt_platform/backend/backend/executor/manager_test.py b/autogpt_platform/backend/backend/executor/manager_test.py index 03d3e89011..c565eedfbf 100644 --- a/autogpt_platform/backend/backend/executor/manager_test.py +++ b/autogpt_platform/backend/backend/executor/manager_test.py @@ -3,7 +3,6 @@ import logging import autogpt_libs.auth.models import fastapi.responses import pytest -from prisma.models import User import backend.server.v2.library.model import backend.server.v2.store.model @@ -12,6 +11,7 @@ from backend.blocks.data_manipulation import FindInDictionaryBlock from backend.blocks.io import AgentInputBlock from backend.blocks.maths import CalculatorBlock, Operation from backend.data import execution, graph +from backend.data.model import User from backend.server.model import CreateGraph from backend.server.rest_api import AgentServer from backend.usecases.sample import create_test_graph, create_test_user diff --git a/autogpt_platform/backend/backend/server/integrations/utils.py b/autogpt_platform/backend/backend/server/integrations/utils.py deleted file mode 100644 index 0fa1052e5b..0000000000 --- a/autogpt_platform/backend/backend/server/integrations/utils.py +++ /dev/null @@ -1,11 +0,0 @@ -from supabase import Client, create_client - -from backend.util.settings import Settings - -settings = Settings() - - -def get_supabase() -> Client: - return create_client( - settings.secrets.supabase_url, settings.secrets.supabase_service_role_key - ) diff --git a/autogpt_platform/backend/backend/usecases/block_autogen.py b/autogpt_platform/backend/backend/usecases/block_autogen.py index 2cae49bf1a..7ccd766c5e 100644 --- a/autogpt_platform/backend/backend/usecases/block_autogen.py +++ b/autogpt_platform/backend/backend/usecases/block_autogen.py @@ -1,13 +1,12 @@ from pathlib import Path -from prisma.models import User - from backend.blocks.basic import StoreValueBlock from backend.blocks.block import BlockInstallationBlock from backend.blocks.http import SendWebRequestBlock from backend.blocks.llm import AITextGeneratorBlock from backend.blocks.text import ExtractTextInformationBlock, FillTextTemplateBlock from backend.data.graph import Graph, Link, Node, create_graph +from backend.data.model import User from backend.data.user import get_or_create_user from backend.util.test import SpinTestServer, wait_execution diff --git a/autogpt_platform/backend/backend/usecases/reddit_marketing.py b/autogpt_platform/backend/backend/usecases/reddit_marketing.py index ce702ef590..1ec381d618 100644 --- a/autogpt_platform/backend/backend/usecases/reddit_marketing.py +++ b/autogpt_platform/backend/backend/usecases/reddit_marketing.py @@ -1,9 +1,8 @@ -from prisma.models import User - from backend.blocks.llm import AIStructuredResponseGeneratorBlock from backend.blocks.reddit import GetRedditPostsBlock, PostRedditCommentBlock from backend.blocks.text import FillTextTemplateBlock, MatchTextPatternBlock from backend.data.graph import Graph, Link, Node, create_graph +from backend.data.model import User from backend.data.user import get_or_create_user from backend.util.test import SpinTestServer, wait_execution diff --git a/autogpt_platform/backend/backend/usecases/sample.py b/autogpt_platform/backend/backend/usecases/sample.py index 61b5424e7e..234a39c976 100644 --- a/autogpt_platform/backend/backend/usecases/sample.py +++ b/autogpt_platform/backend/backend/usecases/sample.py @@ -1,10 +1,9 @@ -from prisma.models import User - from backend.blocks.basic import StoreValueBlock from backend.blocks.io import AgentInputBlock from backend.blocks.text import FillTextTemplateBlock from backend.data import graph from backend.data.graph import create_graph +from backend.data.model import User from backend.data.user import get_or_create_user from backend.util.test import SpinTestServer, wait_execution diff --git a/autogpt_platform/backend/backend/util/clients.py b/autogpt_platform/backend/backend/util/clients.py index b30615c283..cdc66a807d 100644 --- a/autogpt_platform/backend/backend/util/clients.py +++ b/autogpt_platform/backend/backend/util/clients.py @@ -2,11 +2,18 @@ Centralized service client helpers with thread caching. """ +from functools import cache from typing import TYPE_CHECKING -from autogpt_libs.utils.cache import thread_cached +from autogpt_libs.utils.cache import async_cache, thread_cached + +from backend.util.settings import Settings + +settings = Settings() if TYPE_CHECKING: + from supabase import AClient, Client + from backend.data.execution import ( AsyncRedisExecutionEventBus, RedisExecutionEventBus, @@ -109,6 +116,29 @@ def get_integration_credentials_store() -> "IntegrationCredentialsStore": return IntegrationCredentialsStore() +# ============ Supabase Clients ============ # + + +@cache +def get_supabase() -> "Client": + """Get a process-cached synchronous Supabase client instance.""" + from supabase import create_client + + return create_client( + settings.secrets.supabase_url, settings.secrets.supabase_service_role_key + ) + + +@async_cache +async def get_async_supabase() -> "AClient": + """Get a process-cached asynchronous Supabase client instance.""" + from supabase import create_async_client + + return await create_async_client( + settings.secrets.supabase_url, settings.secrets.supabase_service_role_key + ) + + # ============ Notification Queue Helpers ============ # diff --git a/autogpt_platform/backend/backend/util/service.py b/autogpt_platform/backend/backend/util/service.py index 0346d35944..028f0e5fea 100644 --- a/autogpt_platform/backend/backend/util/service.py +++ b/autogpt_platform/backend/backend/util/service.py @@ -45,6 +45,34 @@ api_comm_retry = config.pyro_client_comm_retry api_comm_timeout = config.pyro_client_comm_timeout api_call_timeout = config.rpc_client_call_timeout + +def _validate_no_prisma_objects(obj: Any, path: str = "result") -> None: + """ + Recursively validate that no Prisma objects are being returned from service methods. + This enforces proper separation of layers - only application models should cross service boundaries. + """ + if obj is None: + return + + # Check if it's a Prisma model object + if hasattr(obj, "__class__") and hasattr(obj.__class__, "__module__"): + module_name = obj.__class__.__module__ + if module_name and "prisma.models" in module_name: + raise ValueError( + f"Prisma object {obj.__class__.__name__} found in {path}. " + "Service methods must return application models, not Prisma objects. " + f"Use {obj.__class__.__name__}.from_db() to convert to application model." + ) + + # Recursively check collections + if isinstance(obj, (list, tuple)): + for i, item in enumerate(obj): + _validate_no_prisma_objects(item, f"{path}[{i}]") + elif isinstance(obj, dict): + for key, value in obj.items(): + _validate_no_prisma_objects(value, f"{path}['{key}']") + + P = ParamSpec("P") R = TypeVar("R") EXPOSED_FLAG = "__exposed__" @@ -111,6 +139,22 @@ class UnhealthyServiceError(ValueError): return self.message +class HTTPClientError(Exception): + """Exception for HTTP client errors (4xx status codes) that should not be retried.""" + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + super().__init__(f"HTTP {status_code}: {message}") + + +class HTTPServerError(Exception): + """Exception for HTTP server errors (5xx status codes) that can be retried.""" + + def __init__(self, status_code: int, message: str): + self.status_code = status_code + super().__init__(f"HTTP {status_code}: {message}") + + EXCEPTION_MAPPING = { e.__name__: e for e in [ @@ -119,6 +163,8 @@ EXCEPTION_MAPPING = { TimeoutError, ConnectionError, UnhealthyServiceError, + HTTPClientError, + HTTPServerError, *[ ErrorType for _, ErrorType in inspect.getmembers(exceptions) @@ -191,17 +237,21 @@ class AppService(BaseAppService, ABC): if asyncio.iscoroutinefunction(f): async def async_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable - return await f( + result = await f( **{name: getattr(body, name) for name in type(body).model_fields} ) + _validate_no_prisma_objects(result, f"{func.__name__} result") + return result return async_endpoint else: def sync_endpoint(body: RequestBodyModel): # type: ignore #RequestBodyModel being variable - return f( + result = f( **{name: getattr(body, name) for name in type(body).model_fields} ) + _validate_no_prisma_objects(result, f"{func.__name__} result") + return result return sync_endpoint @@ -313,6 +363,7 @@ def get_service_client( AttributeError, # Missing attributes asyncio.CancelledError, # Task was cancelled concurrent.futures.CancelledError, # Future was cancelled + HTTPClientError, # HTTP 4xx client errors - don't retry ), )(fn) @@ -390,11 +441,31 @@ def get_service_client( self._connection_failure_count = 0 return response.json() except httpx.HTTPStatusError as e: - error = RemoteCallError.model_validate(e.response.json()) - # DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception - raise EXCEPTION_MAPPING.get(error.type, Exception)( - *(error.args or [str(e)]) - ) + status_code = e.response.status_code + + # Try to parse the error response as RemoteCallError for mapped exceptions + error_response = None + try: + error_response = RemoteCallError.model_validate(e.response.json()) + except Exception: + pass + + # If we successfully parsed a mapped exception type, re-raise it + if error_response and error_response.type in EXCEPTION_MAPPING: + exception_class = EXCEPTION_MAPPING[error_response.type] + args = error_response.args or [str(e)] + raise exception_class(*args) + + # Otherwise categorize by HTTP status code + if 400 <= status_code < 500: + # Client errors (4xx) - wrap to prevent retries + raise HTTPClientError(status_code, str(e)) + elif 500 <= status_code < 600: + # Server errors (5xx) - wrap but allow retries + raise HTTPServerError(status_code, str(e)) + else: + # Other status codes (1xx, 2xx, 3xx) - re-raise original error + raise e @_maybe_retry def _call_method_sync(self, method_name: str, **kwargs: Any) -> Any: diff --git a/autogpt_platform/backend/backend/util/service_test.py b/autogpt_platform/backend/backend/util/service_test.py index 3a594fed25..1683c64220 100644 --- a/autogpt_platform/backend/backend/util/service_test.py +++ b/autogpt_platform/backend/backend/util/service_test.py @@ -8,6 +8,8 @@ import pytest from backend.util.service import ( AppService, AppServiceClient, + HTTPClientError, + HTTPServerError, endpoint_to_async, expose, get_service_client, @@ -366,3 +368,125 @@ def test_service_no_retry_when_disabled(server): # This should fail immediately without retry with pytest.raises(RuntimeError, match="Intended error for testing"): client.always_failing_add(5, 3) + + +class TestHTTPErrorRetryBehavior: + """Test that HTTP client errors (4xx) are not retried but server errors (5xx) can be.""" + + # Note: These tests access private methods for testing internal behavior + # Type ignore comments are used to suppress warnings about accessing private methods + + def test_http_client_error_not_retried(self): + """Test that 4xx errors are wrapped as HTTPClientError and not retried.""" + # Create a mock response with 404 status + mock_response = Mock() + mock_response.status_code = 404 + mock_response.json.return_value = {"message": "Not found"} + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "404 Not Found", request=Mock(), response=mock_response + ) + + # Create client + client = get_service_client(ServiceTestClient) + dynamic_client = client + + # Test the _handle_call_method_response directly + with pytest.raises(HTTPClientError) as exc_info: + dynamic_client._handle_call_method_response( # type: ignore[attr-defined] + response=mock_response, method_name="test_method" + ) + + assert exc_info.value.status_code == 404 + assert "404" in str(exc_info.value) + + def test_http_server_error_can_be_retried(self): + """Test that 5xx errors are wrapped as HTTPServerError and can be retried.""" + # Create a mock response with 500 status + mock_response = Mock() + mock_response.status_code = 500 + mock_response.json.return_value = {"message": "Internal server error"} + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "500 Internal Server Error", request=Mock(), response=mock_response + ) + + # Create client + client = get_service_client(ServiceTestClient) + dynamic_client = client + + # Test the _handle_call_method_response directly + with pytest.raises(HTTPServerError) as exc_info: + dynamic_client._handle_call_method_response( # type: ignore[attr-defined] + response=mock_response, method_name="test_method" + ) + + assert exc_info.value.status_code == 500 + assert "500" in str(exc_info.value) + + def test_mapped_exception_preserves_original_type(self): + """Test that mapped exceptions preserve their original type regardless of HTTP status.""" + # Create a mock response with ValueError in the remote call error + mock_response = Mock() + mock_response.status_code = 400 + mock_response.json.return_value = { + "type": "ValueError", + "args": ["Invalid parameter value"], + } + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "400 Bad Request", request=Mock(), response=mock_response + ) + + # Create client + client = get_service_client(ServiceTestClient) + dynamic_client = client + + # Test the _handle_call_method_response directly + with pytest.raises(ValueError) as exc_info: + dynamic_client._handle_call_method_response( # type: ignore[attr-defined] + response=mock_response, method_name="test_method" + ) + + assert "Invalid parameter value" in str(exc_info.value) + + def test_client_error_status_codes_coverage(self): + """Test that various 4xx status codes are all wrapped as HTTPClientError.""" + client_error_codes = [400, 401, 403, 404, 405, 409, 422, 429] + + for status_code in client_error_codes: + mock_response = Mock() + mock_response.status_code = status_code + mock_response.json.return_value = {"message": f"Error {status_code}"} + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + f"{status_code} Error", request=Mock(), response=mock_response + ) + + client = get_service_client(ServiceTestClient) + dynamic_client = client + + with pytest.raises(HTTPClientError) as exc_info: + dynamic_client._handle_call_method_response( # type: ignore + response=mock_response, method_name="test_method" + ) + + assert exc_info.value.status_code == status_code + + def test_server_error_status_codes_coverage(self): + """Test that various 5xx status codes are all wrapped as HTTPServerError.""" + server_error_codes = [500, 501, 502, 503, 504, 505] + + for status_code in server_error_codes: + mock_response = Mock() + mock_response.status_code = status_code + mock_response.json.return_value = {"message": f"Error {status_code}"} + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + f"{status_code} Error", request=Mock(), response=mock_response + ) + + client = get_service_client(ServiceTestClient) + dynamic_client = client + + with pytest.raises(HTTPServerError) as exc_info: + dynamic_client._handle_call_method_response( # type: ignore + response=mock_response, method_name="test_method" + ) + + assert exc_info.value.status_code == status_code diff --git a/autogpt_platform/backend/test/e2e_test_data.py b/autogpt_platform/backend/test/e2e_test_data.py index 8b90e89de7..19fba43a2d 100644 --- a/autogpt_platform/backend/test/e2e_test_data.py +++ b/autogpt_platform/backend/test/e2e_test_data.py @@ -30,10 +30,10 @@ from backend.data.graph import Graph, Link, Node, create_graph # Import API functions from the backend from backend.data.user import get_or_create_user -from backend.server.integrations.utils import get_supabase from backend.server.v2.library.db import create_library_agent, create_preset from backend.server.v2.library.model import LibraryAgentPresetCreatable from backend.server.v2.store.db import create_store_submission, review_store_submission +from backend.util.clients import get_supabase faker = Faker()