diff --git a/autogpt_platform/autogpt_libs/autogpt_libs/utils/cache_test.py b/autogpt_platform/autogpt_libs/autogpt_libs/utils/cache_test.py
new file mode 100644
index 0000000000..91f9b0b824
--- /dev/null
+++ b/autogpt_platform/autogpt_libs/autogpt_libs/utils/cache_test.py
@@ -0,0 +1,325 @@
+"""Tests for the @thread_cached decorator.
+
+This module tests the thread-local caching functionality including:
+- Basic caching for sync and async functions
+- Thread isolation (each thread has its own cache)
+- Cache clearing functionality
+- Exception handling (exceptions are not cached)
+- Argument handling (positional vs keyword arguments)
+"""
+
+import asyncio
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor
+from unittest.mock import Mock
+
+import pytest
+
+from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
+
+
+class TestThreadCached:
+    def test_sync_function_caching(self):
+        call_count = 0
+
+        @thread_cached
+        def expensive_function(x: int, y: int = 0) -> int:
+            nonlocal call_count
+            call_count += 1
+            return x + y
+
+        assert expensive_function(1, 2) == 3
+        assert call_count == 1
+
+        assert expensive_function(1, 2) == 3
+        assert call_count == 1
+
+        assert expensive_function(1, y=2) == 3
+        assert call_count == 2
+
+        assert expensive_function(2, 3) == 5
+        assert call_count == 3
+
+        assert expensive_function(1) == 1
+        assert call_count == 4
+
+    @pytest.mark.asyncio
+    async def test_async_function_caching(self):
+        call_count = 0
+
+        @thread_cached
+        async def expensive_async_function(x: int, y: int = 0) -> int:
+            nonlocal call_count
+            call_count += 1
+            await asyncio.sleep(0.01)
+            return x + y
+
+        assert await expensive_async_function(1, 2) == 3
+        assert call_count == 1
+
+        assert await expensive_async_function(1, 2) == 3
+        assert call_count == 1
+
+        assert await expensive_async_function(1, y=2) == 3
+        assert call_count == 2
+
+        assert await expensive_async_function(2, 3) == 5
+        assert call_count == 3
+
+    def test_thread_isolation(self):
+        call_count = 0
+        results = {}
+
+        @thread_cached
+        def thread_specific_function(x: int) -> str:
+            nonlocal call_count
+            call_count += 1
+            return f"{threading.current_thread().name}-{x}"
+
+        def worker(thread_id: int):
+            result1 = thread_specific_function(1)
+            result2 = thread_specific_function(1)
+            result3 = thread_specific_function(2)
+            results[thread_id] = (result1, result2, result3)
+
+        with ThreadPoolExecutor(max_workers=3) as executor:
+            futures = [executor.submit(worker, i) for i in range(3)]
+            for future in futures:
+                future.result()
+
+        assert call_count >= 2
+
+        for thread_id, (r1, r2, r3) in results.items():
+            assert r1 == r2
+            assert r1 != r3
+
+    @pytest.mark.asyncio
+    async def test_async_thread_isolation(self):
+        call_count = 0
+        results = {}
+
+        @thread_cached
+        async def async_thread_specific_function(x: int) -> str:
+            nonlocal call_count
+            call_count += 1
+            await asyncio.sleep(0.01)
+            return f"{threading.current_thread().name}-{x}"
+
+        async def async_worker(worker_id: int):
+            result1 = await async_thread_specific_function(1)
+            result2 = await async_thread_specific_function(1)
+            result3 = await async_thread_specific_function(2)
+            results[worker_id] = (result1, result2, result3)
+
+        tasks = [async_worker(i) for i in range(3)]
+        await asyncio.gather(*tasks)
+
+        for worker_id, (r1, r2, r3) in results.items():
+            assert r1 == r2
+            assert r1 != r3
+    def test_clear_cache_sync(self):
+        call_count = 0
+
+        @thread_cached
+        def clearable_function(x: int) -> int:
+            nonlocal call_count
+            call_count += 1
+            return x * 2
+
+        assert clearable_function(5) == 10
+        assert call_count == 1
+
+        assert clearable_function(5) == 10
+        assert call_count == 1
+
+        clear_thread_cache(clearable_function)
+
+        assert clearable_function(5) == 10
+        assert call_count == 2
+
+    @pytest.mark.asyncio
+    async def test_clear_cache_async(self):
+        call_count = 0
+
+        @thread_cached
+        async def clearable_async_function(x: int) -> int:
+            nonlocal call_count
+            call_count += 1
+            await asyncio.sleep(0.01)
+            return x * 2
+
+        assert await clearable_async_function(5) == 10
+        assert call_count == 1
+
+        assert await clearable_async_function(5) == 10
+        assert call_count == 1
+
+        clear_thread_cache(clearable_async_function)
+
+        assert await clearable_async_function(5) == 10
+        assert call_count == 2
+
+    def test_simple_arguments(self):
+        call_count = 0
+
+        @thread_cached
+        def simple_function(a: str, b: int, c: str = "default") -> str:
+            nonlocal call_count
+            call_count += 1
+            return f"{a}-{b}-{c}"
+
+        # First call with all positional args
+        result1 = simple_function("test", 42, "custom")
+        assert call_count == 1
+
+        # Same args, all positional - should hit cache
+        result2 = simple_function("test", 42, "custom")
+        assert call_count == 1
+        assert result1 == result2
+
+        # Same values but last arg as keyword - creates different cache key
+        result3 = simple_function("test", 42, c="custom")
+        assert call_count == 2
+        assert result1 == result3  # Same result, different cache entry
+
+        # Different value - new cache entry
+        result4 = simple_function("test", 43, "custom")
+        assert call_count == 3
+        assert result1 != result4
+ assert result2 == "result-1-2" # Same result + + # Verify both are cached separately + func(1, 2) # Uses first cache entry + assert call_count == 2 + + func(1, b=2) # Uses second cache entry + assert call_count == 2 + + def test_exception_handling(self): + call_count = 0 + + @thread_cached + def failing_function(x: int) -> int: + nonlocal call_count + call_count += 1 + if x < 0: + raise ValueError("Negative value") + return x * 2 + + assert failing_function(5) == 10 + assert call_count == 1 + + with pytest.raises(ValueError): + failing_function(-1) + assert call_count == 2 + + with pytest.raises(ValueError): + failing_function(-1) + assert call_count == 3 + + assert failing_function(5) == 10 + assert call_count == 3 + + @pytest.mark.asyncio + async def test_async_exception_handling(self): + call_count = 0 + + @thread_cached + async def async_failing_function(x: int) -> int: + nonlocal call_count + call_count += 1 + await asyncio.sleep(0.01) + if x < 0: + raise ValueError("Negative value") + return x * 2 + + assert await async_failing_function(5) == 10 + assert call_count == 1 + + with pytest.raises(ValueError): + await async_failing_function(-1) + assert call_count == 2 + + with pytest.raises(ValueError): + await async_failing_function(-1) + assert call_count == 3 + + def test_sync_caching_performance(self): + @thread_cached + def slow_function(x: int) -> int: + print(f"slow_function called with x={x}") + time.sleep(0.1) + return x * 2 + + start = time.time() + result1 = slow_function(5) + first_call_time = time.time() - start + print(f"First call took {first_call_time:.4f} seconds") + + start = time.time() + result2 = slow_function(5) + second_call_time = time.time() - start + print(f"Second call took {second_call_time:.4f} seconds") + + assert result1 == result2 == 10 + assert first_call_time > 0.09 + assert second_call_time < 0.01 + + @pytest.mark.asyncio + async def test_async_caching_performance(self): + @thread_cached + async def slow_async_function(x: int) -> int: + print(f"slow_async_function called with x={x}") + await asyncio.sleep(0.1) + return x * 2 + + start = time.time() + result1 = await slow_async_function(5) + first_call_time = time.time() - start + print(f"First async call took {first_call_time:.4f} seconds") + + start = time.time() + result2 = await slow_async_function(5) + second_call_time = time.time() - start + print(f"Second async call took {second_call_time:.4f} seconds") + + assert result1 == result2 == 10 + assert first_call_time > 0.09 + assert second_call_time < 0.01 + + def test_with_mock_objects(self): + mock = Mock(return_value=42) + + @thread_cached + def function_using_mock(x: int) -> int: + return mock(x) + + assert function_using_mock(1) == 42 + assert mock.call_count == 1 + + assert function_using_mock(1) == 42 + assert mock.call_count == 1 + + assert function_using_mock(2) == 42 + assert mock.call_count == 2 diff --git a/autogpt_platform/backend/backend/monitoring/late_execution_monitor.py b/autogpt_platform/backend/backend/monitoring/late_execution_monitor.py index fe0ccdae91..7cace69908 100644 --- a/autogpt_platform/backend/backend/monitoring/late_execution_monitor.py +++ b/autogpt_platform/backend/backend/monitoring/late_execution_monitor.py @@ -67,9 +67,10 @@ class LateExecutionMonitor: num_running = len(running_late_executions) num_users = len(set([r.user_id for r in all_late_executions])) - # Truncate to max 100 entries - truncated_executions = all_late_executions[:100] - was_truncated = num_total_late > 100 + # Truncate to max entries + 
diff --git a/autogpt_platform/backend/backend/monitoring/late_execution_monitor.py b/autogpt_platform/backend/backend/monitoring/late_execution_monitor.py
index fe0ccdae91..7cace69908 100644
--- a/autogpt_platform/backend/backend/monitoring/late_execution_monitor.py
+++ b/autogpt_platform/backend/backend/monitoring/late_execution_monitor.py
@@ -67,9 +67,10 @@ class LateExecutionMonitor:
         num_running = len(running_late_executions)
         num_users = len(set([r.user_id for r in all_late_executions]))

-        # Truncate to max 100 entries
-        truncated_executions = all_late_executions[:100]
-        was_truncated = num_total_late > 100
+        # Truncate to a maximum number of entries
+        truncate_size = 5
+        truncated_executions = all_late_executions[:truncate_size]
+        was_truncated = num_total_late > truncate_size

         late_execution_details = [
             f"* `Execution ID: {exec.id}, Graph ID: {exec.graph_id}v{exec.graph_version}, User ID: {exec.user_id}, Status: {exec.status}, Created At: {exec.started_at.isoformat()}`"
@@ -85,7 +86,7 @@ class LateExecutionMonitor:

         if was_truncated:
             message_parts.append(
-                f"\nShowing first 100 of {num_total_late} late executions:"
+                f"\nShowing first {truncate_size} of {num_total_late} late executions:"
             )
         else:
             message_parts.append("\nDetails:")
diff --git a/autogpt_platform/backend/backend/util/service.py b/autogpt_platform/backend/backend/util/service.py
index 3be0577550..8573f03bcb 100644
--- a/autogpt_platform/backend/backend/util/service.py
+++ b/autogpt_platform/backend/backend/util/service.py
@@ -312,26 +312,72 @@ def get_service_client(
             host = service_type.get_host()
             port = service_type.get_port()
             self.base_url = f"http://{host}:{port}".rstrip("/")
+            self._connection_failure_count = 0
+            self._last_client_reset = 0

-        @cached_property
-        def sync_client(self) -> httpx.Client:
+        def _create_sync_client(self) -> httpx.Client:
             return httpx.Client(
                 base_url=self.base_url,
                 timeout=call_timeout,
+                limits=httpx.Limits(
+                    max_keepalive_connections=200,  # 10x default for async concurrent calls
+                    max_connections=500,  # High limit for burst handling
+                    keepalive_expiry=30.0,  # Keep connections alive longer
+                ),
             )

-        @cached_property
-        def async_client(self) -> httpx.AsyncClient:
+        def _create_async_client(self) -> httpx.AsyncClient:
             return httpx.AsyncClient(
                 base_url=self.base_url,
                 timeout=call_timeout,
+                limits=httpx.Limits(
+                    max_keepalive_connections=200,  # 10x default for async concurrent calls
+                    max_connections=500,  # High limit for burst handling
+                    keepalive_expiry=30.0,  # Keep connections alive longer
+                ),
             )

+        @cached_property
+        def sync_client(self) -> httpx.Client:
+            return self._create_sync_client()
+
+        @cached_property
+        def async_client(self) -> httpx.AsyncClient:
+            return self._create_async_client()
+
+        def _handle_connection_error(self, error: Exception) -> None:
+            """Handle connection errors and implement self-healing"""
+            self._connection_failure_count += 1
+            current_time = time.time()
+
+            # If we've had 3+ failures and it's been more than 30 seconds since the last reset
+            if (
+                self._connection_failure_count >= 3
+                and current_time - self._last_client_reset > 30
+            ):
+
+                logger.warning(
+                    f"Connection failures detected ({self._connection_failure_count}), recreating HTTP clients"
+                )
+
+                # Clear cached clients to force recreation on next access
+                # Only recreate when there's actually a problem
+                if hasattr(self, "sync_client"):
+                    delattr(self, "sync_client")
+                if hasattr(self, "async_client"):
+                    delattr(self, "async_client")
+
+                # Reset counters
+                self._connection_failure_count = 0
+                self._last_client_reset = current_time
+
         def _handle_call_method_response(
             self, *, response: httpx.Response, method_name: str
         ) -> Any:
             try:
                 response.raise_for_status()
+                # Reset failure count on successful response
+                self._connection_failure_count = 0
                 return response.json()
             except httpx.HTTPStatusError as e:
                 logger.error(f"HTTP error in {method_name}: {e.response.text}")
@@ -343,19 +389,27 @@ def get_service_client(

         @_maybe_retry
         def _call_method_sync(self, method_name: str, **kwargs: Any) -> Any:
-            return self._handle_call_method_response(
-                method_name=method_name,
-                response=self.sync_client.post(method_name, json=to_dict(kwargs)),
-            )
+            try:
+                return self._handle_call_method_response(
+                    method_name=method_name,
+                    response=self.sync_client.post(method_name, json=to_dict(kwargs)),
+                )
+            except (httpx.ConnectError, httpx.ConnectTimeout) as e:
+                self._handle_connection_error(e)
+                raise

         @_maybe_retry
         async def _call_method_async(self, method_name: str, **kwargs: Any) -> Any:
-            return self._handle_call_method_response(
-                method_name=method_name,
-                response=await self.async_client.post(
-                    method_name, json=to_dict(kwargs)
-                ),
-            )
+            try:
+                return self._handle_call_method_response(
+                    method_name=method_name,
+                    response=await self.async_client.post(
+                        method_name, json=to_dict(kwargs)
+                    ),
+                )
+            except (httpx.ConnectError, httpx.ConnectTimeout) as e:
+                self._handle_connection_error(e)
+                raise

         async def aclose(self) -> None:
             self.sync_client.close()
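
Reviewer note: to make the reset condition above easier to reason about, here is the same threshold logic pulled out as a pure function (a sketch for review, not code from this PR), with the boundary cases the tests below exercise. Two observations: because _last_client_reset starts at 0, the very first burst of three failures triggers a reset immediately; and the new code calls time.time(), which assumes service.py already imports time, since this diff does not add the import.

    import time

    def should_reset(failure_count: int, last_reset: float, now: float) -> bool:
        # Mirrors _handle_connection_error: 3+ failures AND >30s since the last reset.
        return failure_count >= 3 and now - last_reset > 30

    now = time.time()
    assert not should_reset(2, last_reset=now - 60, now=now)  # below failure threshold
    assert not should_reset(3, last_reset=now - 10, now=now)  # within the 30 s window
    assert should_reset(3, last_reset=now - 31, now=now)      # both conditions met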
diff --git a/autogpt_platform/backend/backend/util/service_test.py b/autogpt_platform/backend/backend/util/service_test.py
index 3d8eb29419..8b9e295bb9 100644
--- a/autogpt_platform/backend/backend/util/service_test.py
+++ b/autogpt_platform/backend/backend/util/service_test.py
@@ -1,3 +1,8 @@
+import time
+from functools import cached_property
+from unittest.mock import Mock
+
+import httpx
 import pytest

 from backend.util.service import (
@@ -60,3 +65,251 @@ async def test_service_creation(server):
     assert client.fun_with_async(5, 3) == 8
     assert await client.add_async(5, 3) == 8
     assert await client.subtract_async(10, 4) == 6
+
+
+class TestDynamicClientConnectionHealing:
+    """Test the DynamicClient connection healing logic"""
+
+    def setup_method(self):
+        """Setup for each test method"""
+        # Create a mock service client type
+        self.mock_service_type = Mock()
+        self.mock_service_type.get_host.return_value = "localhost"
+        self.mock_service_type.get_port.return_value = 8000
+
+        self.mock_service_client_type = Mock()
+        self.mock_service_client_type.get_service_type.return_value = (
+            self.mock_service_type
+        )
+
+        # Create our test client with the real DynamicClient logic
+        self.client = self._create_test_client()
+
+    def _create_test_client(self):
+        """Create a test client that mimics the real DynamicClient"""
+
+        class TestClient:
+            def __init__(self, service_client_type):
+                service_type = service_client_type.get_service_type()
+                host = service_type.get_host()
+                port = service_type.get_port()
+                self.base_url = f"http://{host}:{port}".rstrip("/")
+                self._connection_failure_count = 0
+                self._last_client_reset = 0
+
+            def _create_sync_client(self) -> httpx.Client:
+                return Mock(spec=httpx.Client)
+
+            def _create_async_client(self) -> httpx.AsyncClient:
+                return Mock(spec=httpx.AsyncClient)
+
+            @cached_property
+            def sync_client(self) -> httpx.Client:
+                return self._create_sync_client()
+
+            @cached_property
+            def async_client(self) -> httpx.AsyncClient:
+                return self._create_async_client()
+
+            def _handle_connection_error(self, error: Exception) -> None:
+                """Handle connection errors and implement self-healing"""
+                self._connection_failure_count += 1
+                current_time = time.time()
+
+                # If we've had 3+ failures and it's been more than 30 seconds since the last reset
+                if (
+                    self._connection_failure_count >= 3
+                    and current_time - self._last_client_reset > 30
+                ):
+
+                    # Clear cached clients to force recreation on next access
+                    if hasattr(self, "sync_client"):
+                        delattr(self, "sync_client")
+                    if hasattr(self, "async_client"):
+                        delattr(self, "async_client")
+
+                    # Reset counters
+                    self._connection_failure_count = 0
+                    self._last_client_reset = current_time
+
+        return TestClient(self.mock_service_client_type)
+
+    def test_client_caching(self):
+        """Test that clients are cached via @cached_property"""
+        # Get clients multiple times
+        sync1 = self.client.sync_client
+        sync2 = self.client.sync_client
+        async1 = self.client.async_client
+        async2 = self.client.async_client
+
+        # Should return same instances (cached)
+        assert sync1 is sync2, "Sync clients should be cached"
+        assert async1 is async2, "Async clients should be cached"
+
+    def test_connection_error_counting(self):
+        """Test that connection errors are counted correctly"""
+        initial_count = self.client._connection_failure_count
+
+        # Simulate connection errors
+        self.client._handle_connection_error(Exception("Connection failed"))
+        assert self.client._connection_failure_count == initial_count + 1
+
+        self.client._handle_connection_error(Exception("Connection failed"))
+        assert self.client._connection_failure_count == initial_count + 2
+
+    def test_no_reset_before_threshold(self):
+        """Test that clients are NOT reset before reaching failure threshold"""
+        # Get initial clients
+        sync_before = self.client.sync_client
+        async_before = self.client.async_client
+
+        # Simulate 2 failures (below threshold of 3)
+        self.client._handle_connection_error(Exception("Connection failed"))
+        self.client._handle_connection_error(Exception("Connection failed"))
+
+        # Clients should still be the same (no reset)
+        sync_after = self.client.sync_client
+        async_after = self.client.async_client
+
+        assert (
+            sync_before is sync_after
+        ), "Sync client should not be reset before threshold"
+        assert (
+            async_before is async_after
+        ), "Async client should not be reset before threshold"
+        assert self.client._connection_failure_count == 2
+
+    def test_no_reset_within_time_window(self):
+        """Test that clients are NOT reset if within the 30-second window"""
+        # Get initial clients
+        sync_before = self.client.sync_client
+        async_before = self.client.async_client
+
+        # Set last reset to recent time (within 30 seconds)
+        self.client._last_client_reset = time.time() - 10  # 10 seconds ago
+
+        # Simulate 3+ failures
+        for _ in range(3):
+            self.client._handle_connection_error(Exception("Connection failed"))
+
+        # Clients should still be the same (no reset due to time window)
+        sync_after = self.client.sync_client
+        async_after = self.client.async_client
+
+        assert (
+            sync_before is sync_after
+        ), "Sync client should not be reset within time window"
+        assert (
+            async_before is async_after
+        ), "Async client should not be reset within time window"
+        assert self.client._connection_failure_count == 3
+
+    def test_reset_after_threshold_and_time(self):
+        """Test that clients ARE reset after threshold failures and time window"""
+        # Get initial clients
+        sync_before = self.client.sync_client
+        async_before = self.client.async_client
+
+        # Set last reset to old time (beyond 30 seconds)
+        self.client._last_client_reset = time.time() - 60  # 60 seconds ago
+
+        # Simulate 3+ failures to trigger reset
+        for _ in range(3):
+            self.client._handle_connection_error(Exception("Connection failed"))
+
+        # Clients should be different (reset occurred)
+        sync_after = self.client.sync_client
+        async_after = self.client.async_client
+
+        assert (
+            sync_before is not sync_after
+        ), "Sync client should be reset after threshold"
+        assert (
+            async_before is not async_after
+        ), "Async client should be reset after threshold"
+        assert (
+            self.client._connection_failure_count == 0
+        ), "Failure count should be reset"
reset after healing""" + # Set up for reset + self.client._last_client_reset = time.time() - 60 + self.client._connection_failure_count = 5 + + # Trigger reset + self.client._handle_connection_error(Exception("Connection failed")) + + # Check counters are reset + assert self.client._connection_failure_count == 0 + assert self.client._last_client_reset > time.time() - 5 # Recently reset + + +class TestConnectionHealingIntegration: + """Integration tests for the complete connection healing workflow""" + + def test_failure_count_reset_on_success(self): + """Test that failure count would be reset on successful requests""" + + # This simulates what happens in _handle_call_method_response + class ClientWithSuccessHandling: + def __init__(self): + self._connection_failure_count = 5 + + def _handle_successful_response(self): + # This is what happens in the real _handle_call_method_response + self._connection_failure_count = 0 + + client = ClientWithSuccessHandling() + client._handle_successful_response() + assert client._connection_failure_count == 0 + + def test_thirty_second_window_timing(self): + """Test that the 30-second window works as expected""" + current_time = time.time() + + # Test cases for the timing logic + test_cases = [ + (current_time - 10, False), # 10 seconds ago - should NOT reset + (current_time - 29, False), # 29 seconds ago - should NOT reset + (current_time - 31, True), # 31 seconds ago - should reset + (current_time - 60, True), # 60 seconds ago - should reset + ] + + for last_reset_time, should_reset in test_cases: + failure_count = 3 # At threshold + time_condition = current_time - last_reset_time > 30 + should_trigger_reset = failure_count >= 3 and time_condition + + assert ( + should_trigger_reset == should_reset + ), f"Time window logic failed for {current_time - last_reset_time} seconds ago" + + +def test_cached_property_behavior(): + """Test that @cached_property works as expected for our use case""" + creation_count = 0 + + class TestCachedProperty: + @cached_property + def expensive_resource(self): + nonlocal creation_count + creation_count += 1 + return f"resource-{creation_count}" + + obj = TestCachedProperty() + + # First access should create + resource1 = obj.expensive_resource + assert creation_count == 1 + + # Second access should return cached + resource2 = obj.expensive_resource + assert creation_count == 1 # No additional creation + assert resource1 is resource2 + + # Deleting the cached property should allow recreation + delattr(obj, "expensive_resource") + resource3 = obj.expensive_resource + assert creation_count == 2 # New creation + assert resource1 != resource3