feat(backend): Introduce http client refresh on repeated error (#10481)

HTTP requests can fail when DNS resolution is broken. Sometimes this kind of
issue can only be resolved by resetting the HTTP client.

### Changes 🏗️

Introduce HTTP client refresh on repeated error

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Manual run, added tests
This commit is contained in:
Zamil Majdy
2025-07-30 09:21:28 +08:00
committed by GitHub
parent 83f96b75c7
commit b9c7642cfc
4 changed files with 651 additions and 18 deletions

View File

@@ -0,0 +1,325 @@
"""Tests for the @thread_cached decorator.
This module tests the thread-local caching functionality including:
- Basic caching for sync and async functions
- Thread isolation (each thread has its own cache)
- Cache clearing functionality
- Exception handling (exceptions are not cached)
- Argument handling (positional vs keyword arguments)
"""
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock
import pytest
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
class TestThreadCached:
    """Behavioral tests for the @thread_cached decorator on sync and async callables."""

    def test_sync_function_caching(self):
        # Identical positional calls should be served from the cache.
        call_count = 0

        @thread_cached
        def expensive_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            return x + y

        assert expensive_function(1, 2) == 3
        assert call_count == 1
        assert expensive_function(1, 2) == 3
        assert call_count == 1
        # Passing y as a keyword produces a different cache key than positional.
        assert expensive_function(1, y=2) == 3
        assert call_count == 2
        assert expensive_function(2, 3) == 5
        assert call_count == 3
        assert expensive_function(1) == 1
        assert call_count == 4

    @pytest.mark.asyncio
    async def test_async_function_caching(self):
        # Same caching contract for coroutine functions.
        call_count = 0

        @thread_cached
        async def expensive_async_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x + y

        assert await expensive_async_function(1, 2) == 3
        assert call_count == 1
        assert await expensive_async_function(1, 2) == 3
        assert call_count == 1
        assert await expensive_async_function(1, y=2) == 3
        assert call_count == 2
        assert await expensive_async_function(2, 3) == 5
        assert call_count == 3

    def test_thread_isolation(self):
        # Each thread has its own cache: results repeat within a thread but
        # may be recomputed across threads.
        call_count = 0
        results = {}

        @thread_cached
        def thread_specific_function(x: int) -> str:
            nonlocal call_count
            call_count += 1
            return f"{threading.current_thread().name}-{x}"

        def worker(thread_id: int):
            result1 = thread_specific_function(1)
            result2 = thread_specific_function(1)
            result3 = thread_specific_function(2)
            results[thread_id] = (result1, result2, result3)

        with ThreadPoolExecutor(max_workers=3) as executor:
            futures = [executor.submit(worker, i) for i in range(3)]
            for future in futures:
                future.result()

        # The pool may reuse a thread, so only a lower bound on the number of
        # distinct computations is asserted.
        assert call_count >= 2
        for thread_id, (r1, r2, r3) in results.items():
            assert r1 == r2
            assert r1 != r3

    @pytest.mark.asyncio
    async def test_async_thread_isolation(self):
        # Coroutines gathered on the same event loop share one thread-local cache.
        call_count = 0
        results = {}

        @thread_cached
        async def async_thread_specific_function(x: int) -> str:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return f"{threading.current_thread().name}-{x}"

        async def async_worker(worker_id: int):
            result1 = await async_thread_specific_function(1)
            result2 = await async_thread_specific_function(1)
            result3 = await async_thread_specific_function(2)
            results[worker_id] = (result1, result2, result3)

        tasks = [async_worker(i) for i in range(3)]
        await asyncio.gather(*tasks)

        for worker_id, (r1, r2, r3) in results.items():
            assert r1 == r2
            assert r1 != r3

    def test_clear_cache_sync(self):
        # clear_thread_cache forces the next call to recompute.
        call_count = 0

        @thread_cached
        def clearable_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 2

        assert clearable_function(5) == 10
        assert call_count == 1
        assert clearable_function(5) == 10
        assert call_count == 1

        clear_thread_cache(clearable_function)

        assert clearable_function(5) == 10
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_clear_cache_async(self):
        # Same clearing contract for coroutine functions.
        call_count = 0

        @thread_cached
        async def clearable_async_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x * 2

        assert await clearable_async_function(5) == 10
        assert call_count == 1
        assert await clearable_async_function(5) == 10
        assert call_count == 1

        clear_thread_cache(clearable_async_function)

        assert await clearable_async_function(5) == 10
        assert call_count == 2

    def test_simple_arguments(self):
        call_count = 0

        @thread_cached
        def simple_function(a: str, b: int, c: str = "default") -> str:
            nonlocal call_count
            call_count += 1
            return f"{a}-{b}-{c}"

        # First call with all positional args
        result1 = simple_function("test", 42, "custom")
        assert call_count == 1
        # Same args, all positional - should hit cache
        result2 = simple_function("test", 42, "custom")
        assert call_count == 1
        assert result1 == result2
        # Same values but last arg as keyword - creates different cache key
        result3 = simple_function("test", 42, c="custom")
        assert call_count == 2
        assert result1 == result3  # Same result, different cache entry
        # Different value - new cache entry
        result4 = simple_function("test", 43, "custom")
        assert call_count == 3
        assert result1 != result4

    def test_positional_vs_keyword_args(self):
        """Test that positional and keyword arguments create different cache entries."""
        call_count = 0

        @thread_cached
        def func(a: int, b: int = 10) -> str:
            nonlocal call_count
            call_count += 1
            return f"result-{a}-{b}"

        # All positional
        result1 = func(1, 2)
        assert call_count == 1
        assert result1 == "result-1-2"
        # Same values, but second arg as keyword
        result2 = func(1, b=2)
        assert call_count == 2  # Different cache key!
        assert result2 == "result-1-2"  # Same result
        # Verify both are cached separately
        func(1, 2)  # Uses first cache entry
        assert call_count == 2
        func(1, b=2)  # Uses second cache entry
        assert call_count == 2

    def test_exception_handling(self):
        # Exceptions are not cached: each failing call re-executes.
        call_count = 0

        @thread_cached
        def failing_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            if x < 0:
                raise ValueError("Negative value")
            return x * 2

        assert failing_function(5) == 10
        assert call_count == 1
        with pytest.raises(ValueError):
            failing_function(-1)
        assert call_count == 2
        with pytest.raises(ValueError):
            failing_function(-1)
        assert call_count == 3
        # Successful results remain cached even after failures.
        assert failing_function(5) == 10
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_async_exception_handling(self):
        call_count = 0

        @thread_cached
        async def async_failing_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            if x < 0:
                raise ValueError("Negative value")
            return x * 2

        assert await async_failing_function(5) == 10
        assert call_count == 1
        with pytest.raises(ValueError):
            await async_failing_function(-1)
        assert call_count == 2
        with pytest.raises(ValueError):
            await async_failing_function(-1)
        assert call_count == 3

    def test_sync_caching_performance(self):
        # A cache hit must skip the 0.1s sleep entirely.
        @thread_cached
        def slow_function(x: int) -> int:
            print(f"slow_function called with x={x}")
            time.sleep(0.1)
            return x * 2

        start = time.time()
        result1 = slow_function(5)
        first_call_time = time.time() - start
        print(f"First call took {first_call_time:.4f} seconds")
        start = time.time()
        result2 = slow_function(5)
        second_call_time = time.time() - start
        print(f"Second call took {second_call_time:.4f} seconds")
        assert result1 == result2 == 10
        assert first_call_time > 0.09
        assert second_call_time < 0.01

    @pytest.mark.asyncio
    async def test_async_caching_performance(self):
        @thread_cached
        async def slow_async_function(x: int) -> int:
            print(f"slow_async_function called with x={x}")
            await asyncio.sleep(0.1)
            return x * 2

        start = time.time()
        result1 = await slow_async_function(5)
        first_call_time = time.time() - start
        print(f"First async call took {first_call_time:.4f} seconds")
        start = time.time()
        result2 = await slow_async_function(5)
        second_call_time = time.time() - start
        print(f"Second async call took {second_call_time:.4f} seconds")
        assert result1 == result2 == 10
        assert first_call_time > 0.09
        assert second_call_time < 0.01

    def test_with_mock_objects(self):
        # The wrapped callable is only invoked on cache misses.
        mock = Mock(return_value=42)

        @thread_cached
        def function_using_mock(x: int) -> int:
            return mock(x)

        assert function_using_mock(1) == 42
        assert mock.call_count == 1
        assert function_using_mock(1) == 42
        assert mock.call_count == 1
        assert function_using_mock(2) == 42
        assert mock.call_count == 2

View File

@@ -67,9 +67,10 @@ class LateExecutionMonitor:
num_running = len(running_late_executions)
num_users = len(set([r.user_id for r in all_late_executions]))
# Truncate to max 100 entries
truncated_executions = all_late_executions[:100]
was_truncated = num_total_late > 100
# Truncate to max entries
truncate_size = 5
truncated_executions = all_late_executions[:truncate_size]
was_truncated = num_total_late > truncate_size
late_execution_details = [
f"* `Execution ID: {exec.id}, Graph ID: {exec.graph_id}v{exec.graph_version}, User ID: {exec.user_id}, Status: {exec.status}, Created At: {exec.started_at.isoformat()}`"
@@ -85,7 +86,7 @@ class LateExecutionMonitor:
if was_truncated:
message_parts.append(
f"\nShowing first 100 of {num_total_late} late executions:"
f"\nShowing first {truncate_size} of {num_total_late} late executions:"
)
)
else:
message_parts.append("\nDetails:")

View File

@@ -312,26 +312,72 @@ def get_service_client(
host = service_type.get_host()
port = service_type.get_port()
self.base_url = f"http://{host}:{port}".rstrip("/")
self._connection_failure_count = 0
self._last_client_reset = 0
@cached_property
def sync_client(self) -> httpx.Client:
def _create_sync_client(self) -> httpx.Client:
    """Build a fresh synchronous HTTP client bound to this service's base URL."""
    # Generous pool limits: 10x the default keepalive count for concurrent
    # calls, a high connection cap for bursts, and long-lived keepalives.
    pool_limits = httpx.Limits(
        max_keepalive_connections=200,
        max_connections=500,
        keepalive_expiry=30.0,
    )
    return httpx.Client(
        base_url=self.base_url,
        timeout=call_timeout,
        limits=pool_limits,
    )
@cached_property
def async_client(self) -> httpx.AsyncClient:
def _create_async_client(self) -> httpx.AsyncClient:
    """Build a fresh asynchronous HTTP client bound to this service's base URL."""
    # Mirrors _create_sync_client: oversized pool for concurrent async calls.
    pool_limits = httpx.Limits(
        max_keepalive_connections=200,
        max_connections=500,
        keepalive_expiry=30.0,
    )
    return httpx.AsyncClient(
        base_url=self.base_url,
        timeout=call_timeout,
        limits=pool_limits,
    )
@cached_property
def sync_client(self) -> httpx.Client:
    # Cached per instance; the healing logic evicts this entry so the next
    # access builds a fresh client.
    return self._create_sync_client()
@cached_property
def async_client(self) -> httpx.AsyncClient:
    # Cached per instance; the healing logic evicts this entry so the next
    # access builds a fresh client.
    return self._create_async_client()
def _handle_connection_error(self, error: Exception) -> None:
    """Handle connection errors and implement self-healing.

    Counts consecutive connection failures; once 3 or more have accumulated
    and at least 30 seconds have passed since the last reset, the cached
    httpx clients are dropped so the next call builds fresh ones (recovers
    from e.g. stale DNS resolution).

    Args:
        error: The connection error that triggered this call (currently
            only counted, not inspected).
    """
    self._connection_failure_count += 1
    current_time = time.time()
    # If we've had 3+ failures, and it's been more than 30 seconds since last reset
    if (
        self._connection_failure_count >= 3
        and current_time - self._last_client_reset > 30
    ):
        logger.warning(
            f"Connection failures detected ({self._connection_failure_count}), recreating HTTP clients"
        )
        # Evict the cached_property entries straight from the instance dict
        # to force recreation on next access. Using hasattr()/delattr() here
        # would be wrong: hasattr() invokes the descriptor, instantiating a
        # client only to immediately discard it.
        self.__dict__.pop("sync_client", None)
        self.__dict__.pop("async_client", None)
        # NOTE(review): evicted clients are not closed here and are left to
        # GC — confirm this is acceptable (AsyncClient prefers explicit aclose).
        # Reset counters
        self._connection_failure_count = 0
        self._last_client_reset = current_time
def _handle_call_method_response(
self, *, response: httpx.Response, method_name: str
) -> Any:
try:
response.raise_for_status()
# Reset failure count on successful response
self._connection_failure_count = 0
return response.json()
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error in {method_name}: {e.response.text}")
@@ -343,19 +389,27 @@ def get_service_client(
@_maybe_retry
def _call_method_sync(self, method_name: str, **kwargs: Any) -> Any:
    """POST *kwargs* to the remote service method and return the parsed result.

    Connection-level failures (connect error/timeout) feed the self-healing
    failure counter before being re-raised, so @_maybe_retry can still retry.
    """
    # Defect fixed: the old pre-try body was left duplicated above the new
    # try/except version (diff residue), making the block invalid; only the
    # guarded version is kept.
    try:
        return self._handle_call_method_response(
            method_name=method_name,
            response=self.sync_client.post(method_name, json=to_dict(kwargs)),
        )
    except (httpx.ConnectError, httpx.ConnectTimeout) as e:
        self._handle_connection_error(e)
        raise
@_maybe_retry
async def _call_method_async(self, method_name: str, **kwargs: Any) -> Any:
    """Async counterpart of _call_method_sync: POST *kwargs* and parse the result.

    Connection-level failures (connect error/timeout) feed the self-healing
    failure counter before being re-raised, so @_maybe_retry can still retry.
    """
    # Defect fixed: the old pre-try body was left duplicated above the new
    # try/except version (diff residue), making the block invalid; only the
    # guarded version is kept.
    try:
        return self._handle_call_method_response(
            method_name=method_name,
            response=await self.async_client.post(
                method_name, json=to_dict(kwargs)
            ),
        )
    except (httpx.ConnectError, httpx.ConnectTimeout) as e:
        self._handle_connection_error(e)
        raise
async def aclose(self) -> None:
self.sync_client.close()

View File

@@ -1,3 +1,8 @@
import time
from functools import cached_property
from unittest.mock import Mock
import httpx
import pytest
from backend.util.service import (
@@ -60,3 +65,251 @@ async def test_service_creation(server):
assert client.fun_with_async(5, 3) == 8
assert await client.add_async(5, 3) == 8
assert await client.subtract_async(10, 4) == 6
class TestDynamicClientConnectionHealing:
    """Test the DynamicClient connection healing logic"""

    def setup_method(self):
        """Setup for each test method"""
        # Create a mock service client type
        self.mock_service_type = Mock()
        self.mock_service_type.get_host.return_value = "localhost"
        self.mock_service_type.get_port.return_value = 8000
        self.mock_service_client_type = Mock()
        self.mock_service_client_type.get_service_type.return_value = (
            self.mock_service_type
        )
        # Create our test client with the real DynamicClient logic
        self.client = self._create_test_client()

    def _create_test_client(self):
        """Create a test client that mimics the real DynamicClient"""

        class TestClient:
            # Copy of the production healing logic, with the httpx clients
            # replaced by Mocks so no sockets are opened.
            def __init__(self, service_client_type):
                service_type = service_client_type.get_service_type()
                host = service_type.get_host()
                port = service_type.get_port()
                self.base_url = f"http://{host}:{port}".rstrip("/")
                self._connection_failure_count = 0
                self._last_client_reset = 0

            def _create_sync_client(self) -> httpx.Client:
                return Mock(spec=httpx.Client)

            def _create_async_client(self) -> httpx.AsyncClient:
                return Mock(spec=httpx.AsyncClient)

            @cached_property
            def sync_client(self) -> httpx.Client:
                return self._create_sync_client()

            @cached_property
            def async_client(self) -> httpx.AsyncClient:
                return self._create_async_client()

            def _handle_connection_error(self, error: Exception) -> None:
                """Handle connection errors and implement self-healing"""
                self._connection_failure_count += 1
                current_time = time.time()
                # If we've had 3+ failures and it's been more than 30 seconds since last reset
                if (
                    self._connection_failure_count >= 3
                    and current_time - self._last_client_reset > 30
                ):
                    # Clear cached clients to force recreation on next access
                    if hasattr(self, "sync_client"):
                        delattr(self, "sync_client")
                    if hasattr(self, "async_client"):
                        delattr(self, "async_client")
                    # Reset counters
                    self._connection_failure_count = 0
                    self._last_client_reset = current_time

        return TestClient(self.mock_service_client_type)

    def test_client_caching(self):
        """Test that clients are cached via @cached_property"""
        # Get clients multiple times
        sync1 = self.client.sync_client
        sync2 = self.client.sync_client
        async1 = self.client.async_client
        async2 = self.client.async_client
        # Should return same instances (cached)
        assert sync1 is sync2, "Sync clients should be cached"
        assert async1 is async2, "Async clients should be cached"

    def test_connection_error_counting(self):
        """Test that connection errors are counted correctly"""
        initial_count = self.client._connection_failure_count
        # Simulate connection errors
        self.client._handle_connection_error(Exception("Connection failed"))
        assert self.client._connection_failure_count == initial_count + 1
        self.client._handle_connection_error(Exception("Connection failed"))
        assert self.client._connection_failure_count == initial_count + 2

    def test_no_reset_before_threshold(self):
        """Test that clients are NOT reset before reaching failure threshold"""
        # Get initial clients
        sync_before = self.client.sync_client
        async_before = self.client.async_client
        # Simulate 2 failures (below threshold of 3)
        self.client._handle_connection_error(Exception("Connection failed"))
        self.client._handle_connection_error(Exception("Connection failed"))
        # Clients should still be the same (no reset)
        sync_after = self.client.sync_client
        async_after = self.client.async_client
        assert (
            sync_before is sync_after
        ), "Sync client should not be reset before threshold"
        assert (
            async_before is async_after
        ), "Async client should not be reset before threshold"
        assert self.client._connection_failure_count == 2

    def test_no_reset_within_time_window(self):
        """Test that clients are NOT reset if within the 30-second window"""
        # Get initial clients
        sync_before = self.client.sync_client
        async_before = self.client.async_client
        # Set last reset to recent time (within 30 seconds)
        self.client._last_client_reset = time.time() - 10  # 10 seconds ago
        # Simulate 3+ failures
        for _ in range(3):
            self.client._handle_connection_error(Exception("Connection failed"))
        # Clients should still be the same (no reset due to time window)
        sync_after = self.client.sync_client
        async_after = self.client.async_client
        assert (
            sync_before is sync_after
        ), "Sync client should not be reset within time window"
        assert (
            async_before is async_after
        ), "Async client should not be reset within time window"
        assert self.client._connection_failure_count == 3

    def test_reset_after_threshold_and_time(self):
        """Test that clients ARE reset after threshold failures and time window"""
        # Get initial clients
        sync_before = self.client.sync_client
        async_before = self.client.async_client
        # Set last reset to old time (beyond 30 seconds)
        self.client._last_client_reset = time.time() - 60  # 60 seconds ago
        # Simulate 3+ failures to trigger reset
        for _ in range(3):
            self.client._handle_connection_error(Exception("Connection failed"))
        # Clients should be different (reset occurred)
        sync_after = self.client.sync_client
        async_after = self.client.async_client
        assert (
            sync_before is not sync_after
        ), "Sync client should be reset after threshold"
        assert (
            async_before is not async_after
        ), "Async client should be reset after threshold"
        assert (
            self.client._connection_failure_count == 0
        ), "Failure count should be reset"

    def test_reset_counters_after_healing(self):
        """Test that counters are properly reset after healing"""
        # Set up for reset
        self.client._last_client_reset = time.time() - 60
        self.client._connection_failure_count = 5
        # Trigger reset
        self.client._handle_connection_error(Exception("Connection failed"))
        # Check counters are reset
        assert self.client._connection_failure_count == 0
        assert self.client._last_client_reset > time.time() - 5  # Recently reset
class TestConnectionHealingIntegration:
    """Integration tests for the complete connection healing workflow"""

    def test_failure_count_reset_on_success(self):
        """A successful response must zero the accumulated failure counter."""

        class ClientWithSuccessHandling:
            def __init__(self):
                self._connection_failure_count = 5

            def _handle_successful_response(self):
                # Mirrors the reset performed in _handle_call_method_response.
                self._connection_failure_count = 0

        healing_client = ClientWithSuccessHandling()
        healing_client._handle_successful_response()
        assert healing_client._connection_failure_count == 0

    def test_thirty_second_window_timing(self):
        """The reset fires only once the 30-second window has elapsed."""
        current_time = time.time()
        # (last reset timestamp, whether a reset should fire) pairs.
        scenarios = (
            (current_time - 10, False),  # 10 seconds ago - inside the window
            (current_time - 29, False),  # 29 seconds ago - still inside
            (current_time - 31, True),  # 31 seconds ago - window elapsed
            (current_time - 60, True),  # 60 seconds ago - well past the window
        )
        for last_reset_time, expected in scenarios:
            failure_count = 3  # exactly at the failure threshold
            window_elapsed = current_time - last_reset_time > 30
            fired = failure_count >= 3 and window_elapsed
            assert (
                fired == expected
            ), f"Time window logic failed for {current_time - last_reset_time} seconds ago"
def test_cached_property_behavior():
    """Test that @cached_property works as expected for our use case"""
    creation_count = 0

    class ResourceHolder:
        @cached_property
        def expensive_resource(self):
            nonlocal creation_count
            creation_count += 1
            return f"resource-{creation_count}"

    holder = ResourceHolder()

    # First access computes and caches the value.
    first = holder.expensive_resource
    assert creation_count == 1
    # Subsequent access returns the cached object without recomputation.
    assert holder.expensive_resource is first
    assert creation_count == 1
    # Dropping the cached attribute forces a rebuild on the next access —
    # this is exactly how the client self-healing logic recreates clients.
    delattr(holder, "expensive_resource")
    rebuilt = holder.expensive_resource
    assert creation_count == 2
    assert rebuilt != first