mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(backend): Introduce http client refresh on repeated error (#10481)
HTTP requests can fail when the DNS is messed up. Sometimes this kind of issue requires a client reset. ### Changes 🏗️ Introduce HTTP client refresh on repeated error ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [x] Manual run, added tests
This commit is contained in:
325
autogpt_platform/autogpt_libs/autogpt_libs/utils/cache_test.py
Normal file
325
autogpt_platform/autogpt_libs/autogpt_libs/utils/cache_test.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""Tests for the @thread_cached decorator.
|
||||
|
||||
This module tests the thread-local caching functionality including:
|
||||
- Basic caching for sync and async functions
|
||||
- Thread isolation (each thread has its own cache)
|
||||
- Cache clearing functionality
|
||||
- Exception handling (exceptions are not cached)
|
||||
- Argument handling (positional vs keyword arguments)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from autogpt_libs.utils.cache import clear_thread_cache, thread_cached
|
||||
|
||||
|
||||
class TestThreadCached:
    """Behavioral tests for the @thread_cached decorator (sync and async)."""

    def test_sync_function_caching(self):
        """Same call shape hits the cache; new args or arg forms recompute."""
        call_count = 0

        @thread_cached
        def expensive_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            return x + y

        # First call computes and caches.
        assert expensive_function(1, 2) == 3
        assert call_count == 1

        # Identical call is served from cache.
        assert expensive_function(1, 2) == 3
        assert call_count == 1

        # Keyword form produces a different cache key, so it recomputes.
        assert expensive_function(1, y=2) == 3
        assert call_count == 2

        assert expensive_function(2, 3) == 5
        assert call_count == 3

        # Relying on the default for y is yet another distinct key.
        assert expensive_function(1) == 1
        assert call_count == 4

    @pytest.mark.asyncio
    async def test_async_function_caching(self):
        """Async functions cache their awaited results the same way."""
        call_count = 0

        @thread_cached
        async def expensive_async_function(x: int, y: int = 0) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x + y

        assert await expensive_async_function(1, 2) == 3
        assert call_count == 1

        # Cache hit: underlying coroutine body does not run again.
        assert await expensive_async_function(1, 2) == 3
        assert call_count == 1

        # Keyword vs positional args form distinct cache keys.
        assert await expensive_async_function(1, y=2) == 3
        assert call_count == 2

        assert await expensive_async_function(2, 3) == 5
        assert call_count == 3

    def test_thread_isolation(self):
        """Each thread keeps its own cache; entries never leak across threads."""
        call_count = 0
        results = {}

        @thread_cached
        def thread_specific_function(x: int) -> str:
            nonlocal call_count
            call_count += 1
            return f"{threading.current_thread().name}-{x}"

        def worker(thread_id: int):
            # Two calls with the same arg (second should be cached within
            # this thread), then one with a different arg.
            result1 = thread_specific_function(1)
            result2 = thread_specific_function(1)
            result3 = thread_specific_function(2)
            results[thread_id] = (result1, result2, result3)

        with ThreadPoolExecutor(max_workers=3) as executor:
            futures = [executor.submit(worker, i) for i in range(3)]
            for future in futures:
                future.result()

        # At least two computations happened; the exact count depends on how
        # many distinct pool threads actually picked up workers.
        assert call_count >= 2

        for thread_id, (r1, r2, r3) in results.items():
            assert r1 == r2  # cached within the same thread
            assert r1 != r3  # different argument -> different value

    @pytest.mark.asyncio
    async def test_async_thread_isolation(self):
        """Async workers on one event-loop thread share that thread's cache."""
        call_count = 0
        results = {}

        @thread_cached
        async def async_thread_specific_function(x: int) -> str:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return f"{threading.current_thread().name}-{x}"

        async def async_worker(worker_id: int):
            result1 = await async_thread_specific_function(1)
            result2 = await async_thread_specific_function(1)
            result3 = await async_thread_specific_function(2)
            results[worker_id] = (result1, result2, result3)

        # All workers run on the same loop thread, so they share one cache.
        tasks = [async_worker(i) for i in range(3)]
        await asyncio.gather(*tasks)

        for worker_id, (r1, r2, r3) in results.items():
            assert r1 == r2
            assert r1 != r3

    def test_clear_cache_sync(self):
        """clear_thread_cache() forces the next call to recompute."""
        call_count = 0

        @thread_cached
        def clearable_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            return x * 2

        assert clearable_function(5) == 10
        assert call_count == 1

        assert clearable_function(5) == 10
        assert call_count == 1

        clear_thread_cache(clearable_function)

        # Cache was cleared, so the same argument recomputes.
        assert clearable_function(5) == 10
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_clear_cache_async(self):
        """clear_thread_cache() also works for async functions."""
        call_count = 0

        @thread_cached
        async def clearable_async_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            return x * 2

        assert await clearable_async_function(5) == 10
        assert call_count == 1

        assert await clearable_async_function(5) == 10
        assert call_count == 1

        clear_thread_cache(clearable_async_function)

        assert await clearable_async_function(5) == 10
        assert call_count == 2

    def test_simple_arguments(self):
        """Cache keys reflect the literal call shape, not bound parameters."""
        call_count = 0

        @thread_cached
        def simple_function(a: str, b: int, c: str = "default") -> str:
            nonlocal call_count
            call_count += 1
            return f"{a}-{b}-{c}"

        # First call with all positional args
        result1 = simple_function("test", 42, "custom")
        assert call_count == 1

        # Same args, all positional - should hit cache
        result2 = simple_function("test", 42, "custom")
        assert call_count == 1
        assert result1 == result2

        # Same values but last arg as keyword - creates different cache key
        result3 = simple_function("test", 42, c="custom")
        assert call_count == 2
        assert result1 == result3  # Same result, different cache entry

        # Different value - new cache entry
        result4 = simple_function("test", 43, "custom")
        assert call_count == 3
        assert result1 != result4

    def test_positional_vs_keyword_args(self):
        """Test that positional and keyword arguments create different cache entries."""
        call_count = 0

        @thread_cached
        def func(a: int, b: int = 10) -> str:
            nonlocal call_count
            call_count += 1
            return f"result-{a}-{b}"

        # All positional
        result1 = func(1, 2)
        assert call_count == 1
        assert result1 == "result-1-2"

        # Same values, but second arg as keyword
        result2 = func(1, b=2)
        assert call_count == 2  # Different cache key!
        assert result2 == "result-1-2"  # Same result

        # Verify both are cached separately
        func(1, 2)  # Uses first cache entry
        assert call_count == 2

        func(1, b=2)  # Uses second cache entry
        assert call_count == 2

    def test_exception_handling(self):
        """Exceptions propagate to the caller and are never cached."""
        call_count = 0

        @thread_cached
        def failing_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            if x < 0:
                raise ValueError("Negative value")
            return x * 2

        assert failing_function(5) == 10
        assert call_count == 1

        with pytest.raises(ValueError):
            failing_function(-1)
        assert call_count == 2

        # A failed call is not cached: the same bad input runs again.
        with pytest.raises(ValueError):
            failing_function(-1)
        assert call_count == 3

        # Successful results stay cached across intervening failures.
        assert failing_function(5) == 10
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_async_exception_handling(self):
        """Async exceptions also propagate and are not cached."""
        call_count = 0

        @thread_cached
        async def async_failing_function(x: int) -> int:
            nonlocal call_count
            call_count += 1
            await asyncio.sleep(0.01)
            if x < 0:
                raise ValueError("Negative value")
            return x * 2

        assert await async_failing_function(5) == 10
        assert call_count == 1

        with pytest.raises(ValueError):
            await async_failing_function(-1)
        assert call_count == 2

        with pytest.raises(ValueError):
            await async_failing_function(-1)
        assert call_count == 3

    def test_sync_caching_performance(self):
        """Second call returns near-instantly because the sleep is skipped."""

        @thread_cached
        def slow_function(x: int) -> int:
            print(f"slow_function called with x={x}")
            time.sleep(0.1)
            return x * 2

        start = time.time()
        result1 = slow_function(5)
        first_call_time = time.time() - start
        print(f"First call took {first_call_time:.4f} seconds")

        start = time.time()
        result2 = slow_function(5)
        second_call_time = time.time() - start
        print(f"Second call took {second_call_time:.4f} seconds")

        assert result1 == result2 == 10
        # NOTE(review): wall-clock thresholds can flake on a loaded CI host.
        assert first_call_time > 0.09
        assert second_call_time < 0.01

    @pytest.mark.asyncio
    async def test_async_caching_performance(self):
        """Cached async call skips the awaited sleep entirely."""

        @thread_cached
        async def slow_async_function(x: int) -> int:
            print(f"slow_async_function called with x={x}")
            await asyncio.sleep(0.1)
            return x * 2

        start = time.time()
        result1 = await slow_async_function(5)
        first_call_time = time.time() - start
        print(f"First async call took {first_call_time:.4f} seconds")

        start = time.time()
        result2 = await slow_async_function(5)
        second_call_time = time.time() - start
        print(f"Second async call took {second_call_time:.4f} seconds")

        assert result1 == result2 == 10
        # NOTE(review): wall-clock thresholds — same flakiness caveat as the
        # sync performance test.
        assert first_call_time > 0.09
        assert second_call_time < 0.01

    def test_with_mock_objects(self):
        """The cache prevents repeat calls into the wrapped Mock."""
        mock = Mock(return_value=42)

        @thread_cached
        def function_using_mock(x: int) -> int:
            return mock(x)

        assert function_using_mock(1) == 42
        assert mock.call_count == 1

        assert function_using_mock(1) == 42
        assert mock.call_count == 1  # cached, mock not invoked again

        assert function_using_mock(2) == 42
        assert mock.call_count == 2
|
||||
@@ -67,9 +67,10 @@ class LateExecutionMonitor:
|
||||
num_running = len(running_late_executions)
|
||||
num_users = len(set([r.user_id for r in all_late_executions]))
|
||||
|
||||
# Truncate to max 100 entries
|
||||
truncated_executions = all_late_executions[:100]
|
||||
was_truncated = num_total_late > 100
|
||||
# Truncate to max entries
|
||||
tuncate_size = 5
|
||||
truncated_executions = all_late_executions[:tuncate_size]
|
||||
was_truncated = num_total_late > tuncate_size
|
||||
|
||||
late_execution_details = [
|
||||
f"* `Execution ID: {exec.id}, Graph ID: {exec.graph_id}v{exec.graph_version}, User ID: {exec.user_id}, Status: {exec.status}, Created At: {exec.started_at.isoformat()}`"
|
||||
@@ -85,7 +86,7 @@ class LateExecutionMonitor:
|
||||
|
||||
if was_truncated:
|
||||
message_parts.append(
|
||||
f"\nShowing first 100 of {num_total_late} late executions:"
|
||||
f"\nShowing first {tuncate_size} of {num_total_late} late executions:"
|
||||
)
|
||||
else:
|
||||
message_parts.append("\nDetails:")
|
||||
|
||||
@@ -312,26 +312,72 @@ def get_service_client(
|
||||
host = service_type.get_host()
|
||||
port = service_type.get_port()
|
||||
self.base_url = f"http://{host}:{port}".rstrip("/")
|
||||
self._connection_failure_count = 0
|
||||
self._last_client_reset = 0
|
||||
|
||||
@cached_property
|
||||
def sync_client(self) -> httpx.Client:
|
||||
def _create_sync_client(self) -> httpx.Client:
|
||||
return httpx.Client(
|
||||
base_url=self.base_url,
|
||||
timeout=call_timeout,
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=200, # 10x default for async concurrent calls
|
||||
max_connections=500, # High limit for burst handling
|
||||
keepalive_expiry=30.0, # Keep connections alive longer
|
||||
),
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def async_client(self) -> httpx.AsyncClient:
|
||||
def _create_async_client(self) -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
timeout=call_timeout,
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=200, # 10x default for async concurrent calls
|
||||
max_connections=500, # High limit for burst handling
|
||||
keepalive_expiry=30.0, # Keep connections alive longer
|
||||
),
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def sync_client(self) -> httpx.Client:
|
||||
return self._create_sync_client()
|
||||
|
||||
@cached_property
|
||||
def async_client(self) -> httpx.AsyncClient:
|
||||
return self._create_async_client()
|
||||
|
||||
def _handle_connection_error(self, error: Exception) -> None:
|
||||
"""Handle connection errors and implement self-healing"""
|
||||
self._connection_failure_count += 1
|
||||
current_time = time.time()
|
||||
|
||||
# If we've had 3+ failures, and it's been more than 30 seconds since last reset
|
||||
if (
|
||||
self._connection_failure_count >= 3
|
||||
and current_time - self._last_client_reset > 30
|
||||
):
|
||||
|
||||
logger.warning(
|
||||
f"Connection failures detected ({self._connection_failure_count}), recreating HTTP clients"
|
||||
)
|
||||
|
||||
# Clear cached clients to force recreation on next access
|
||||
# Only recreate when there's actually a problem
|
||||
if hasattr(self, "sync_client"):
|
||||
delattr(self, "sync_client")
|
||||
if hasattr(self, "async_client"):
|
||||
delattr(self, "async_client")
|
||||
|
||||
# Reset counters
|
||||
self._connection_failure_count = 0
|
||||
self._last_client_reset = current_time
|
||||
|
||||
def _handle_call_method_response(
|
||||
self, *, response: httpx.Response, method_name: str
|
||||
) -> Any:
|
||||
try:
|
||||
response.raise_for_status()
|
||||
# Reset failure count on successful response
|
||||
self._connection_failure_count = 0
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP error in {method_name}: {e.response.text}")
|
||||
@@ -343,19 +389,27 @@ def get_service_client(
|
||||
|
||||
@_maybe_retry
|
||||
def _call_method_sync(self, method_name: str, **kwargs: Any) -> Any:
|
||||
return self._handle_call_method_response(
|
||||
method_name=method_name,
|
||||
response=self.sync_client.post(method_name, json=to_dict(kwargs)),
|
||||
)
|
||||
try:
|
||||
return self._handle_call_method_response(
|
||||
method_name=method_name,
|
||||
response=self.sync_client.post(method_name, json=to_dict(kwargs)),
|
||||
)
|
||||
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
|
||||
self._handle_connection_error(e)
|
||||
raise
|
||||
|
||||
@_maybe_retry
|
||||
async def _call_method_async(self, method_name: str, **kwargs: Any) -> Any:
|
||||
return self._handle_call_method_response(
|
||||
method_name=method_name,
|
||||
response=await self.async_client.post(
|
||||
method_name, json=to_dict(kwargs)
|
||||
),
|
||||
)
|
||||
try:
|
||||
return self._handle_call_method_response(
|
||||
method_name=method_name,
|
||||
response=await self.async_client.post(
|
||||
method_name, json=to_dict(kwargs)
|
||||
),
|
||||
)
|
||||
except (httpx.ConnectError, httpx.ConnectTimeout) as e:
|
||||
self._handle_connection_error(e)
|
||||
raise
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.sync_client.close()
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
import time
|
||||
from functools import cached_property
|
||||
from unittest.mock import Mock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from backend.util.service import (
|
||||
@@ -60,3 +65,251 @@ async def test_service_creation(server):
|
||||
assert client.fun_with_async(5, 3) == 8
|
||||
assert await client.add_async(5, 3) == 8
|
||||
assert await client.subtract_async(10, 4) == 6
|
||||
|
||||
|
||||
class TestDynamicClientConnectionHealing:
    """Test the DynamicClient connection healing logic"""

    def setup_method(self):
        """Setup for each test method"""
        # Create a mock service client type
        self.mock_service_type = Mock()
        self.mock_service_type.get_host.return_value = "localhost"
        self.mock_service_type.get_port.return_value = 8000

        self.mock_service_client_type = Mock()
        self.mock_service_client_type.get_service_type.return_value = (
            self.mock_service_type
        )

        # Create our test client with the real DynamicClient logic
        self.client = self._create_test_client()

    def _create_test_client(self):
        """Create a test client that mimics the real DynamicClient"""
        # NOTE(review): this re-implements the healing logic instead of
        # importing DynamicClient, so the two copies can drift — keep this in
        # sync with backend.util.service when the real logic changes.

        class TestClient:
            def __init__(self, service_client_type):
                service_type = service_client_type.get_service_type()
                host = service_type.get_host()
                port = service_type.get_port()
                self.base_url = f"http://{host}:{port}".rstrip("/")
                self._connection_failure_count = 0
                self._last_client_reset = 0

            def _create_sync_client(self) -> httpx.Client:
                # Mocked so tests can compare identity without real sockets.
                return Mock(spec=httpx.Client)

            def _create_async_client(self) -> httpx.AsyncClient:
                return Mock(spec=httpx.AsyncClient)

            @cached_property
            def sync_client(self) -> httpx.Client:
                return self._create_sync_client()

            @cached_property
            def async_client(self) -> httpx.AsyncClient:
                return self._create_async_client()

            def _handle_connection_error(self, error: Exception) -> None:
                """Handle connection errors and implement self-healing"""
                self._connection_failure_count += 1
                current_time = time.time()

                # If we've had 3+ failures and it's been more than 30 seconds since last reset
                if (
                    self._connection_failure_count >= 3
                    and current_time - self._last_client_reset > 30
                ):

                    # Clear cached clients to force recreation on next access
                    if hasattr(self, "sync_client"):
                        delattr(self, "sync_client")
                    if hasattr(self, "async_client"):
                        delattr(self, "async_client")

                    # Reset counters
                    self._connection_failure_count = 0
                    self._last_client_reset = current_time

        return TestClient(self.mock_service_client_type)

    def test_client_caching(self):
        """Test that clients are cached via @cached_property"""
        # Get clients multiple times
        sync1 = self.client.sync_client
        sync2 = self.client.sync_client
        async1 = self.client.async_client
        async2 = self.client.async_client

        # Should return same instances (cached)
        assert sync1 is sync2, "Sync clients should be cached"
        assert async1 is async2, "Async clients should be cached"

    def test_connection_error_counting(self):
        """Test that connection errors are counted correctly"""
        initial_count = self.client._connection_failure_count

        # Simulate connection errors
        self.client._handle_connection_error(Exception("Connection failed"))
        assert self.client._connection_failure_count == initial_count + 1

        self.client._handle_connection_error(Exception("Connection failed"))
        assert self.client._connection_failure_count == initial_count + 2

    def test_no_reset_before_threshold(self):
        """Test that clients are NOT reset before reaching failure threshold"""
        # Get initial clients
        sync_before = self.client.sync_client
        async_before = self.client.async_client

        # Simulate 2 failures (below threshold of 3)
        self.client._handle_connection_error(Exception("Connection failed"))
        self.client._handle_connection_error(Exception("Connection failed"))

        # Clients should still be the same (no reset)
        sync_after = self.client.sync_client
        async_after = self.client.async_client

        assert (
            sync_before is sync_after
        ), "Sync client should not be reset before threshold"
        assert (
            async_before is async_after
        ), "Async client should not be reset before threshold"
        assert self.client._connection_failure_count == 2

    def test_no_reset_within_time_window(self):
        """Test that clients are NOT reset if within the 30-second window"""
        # Get initial clients
        sync_before = self.client.sync_client
        async_before = self.client.async_client

        # Set last reset to recent time (within 30 seconds)
        self.client._last_client_reset = time.time() - 10  # 10 seconds ago

        # Simulate 3+ failures
        for _ in range(3):
            self.client._handle_connection_error(Exception("Connection failed"))

        # Clients should still be the same (no reset due to time window)
        sync_after = self.client.sync_client
        async_after = self.client.async_client

        assert (
            sync_before is sync_after
        ), "Sync client should not be reset within time window"
        assert (
            async_before is async_after
        ), "Async client should not be reset within time window"
        assert self.client._connection_failure_count == 3

    def test_reset_after_threshold_and_time(self):
        """Test that clients ARE reset after threshold failures and time window"""
        # Get initial clients
        sync_before = self.client.sync_client
        async_before = self.client.async_client

        # Set last reset to old time (beyond 30 seconds)
        self.client._last_client_reset = time.time() - 60  # 60 seconds ago

        # Simulate 3+ failures to trigger reset
        for _ in range(3):
            self.client._handle_connection_error(Exception("Connection failed"))

        # Clients should be different (reset occurred)
        sync_after = self.client.sync_client
        async_after = self.client.async_client

        assert (
            sync_before is not sync_after
        ), "Sync client should be reset after threshold"
        assert (
            async_before is not async_after
        ), "Async client should be reset after threshold"
        assert (
            self.client._connection_failure_count == 0
        ), "Failure count should be reset"

    def test_reset_counters_after_healing(self):
        """Test that counters are properly reset after healing"""
        # Set up for reset
        self.client._last_client_reset = time.time() - 60
        self.client._connection_failure_count = 5

        # Trigger reset
        self.client._handle_connection_error(Exception("Connection failed"))

        # Check counters are reset
        assert self.client._connection_failure_count == 0
        assert self.client._last_client_reset > time.time() - 5  # Recently reset
|
||||
|
||||
|
||||
class TestConnectionHealingIntegration:
    """Integration-style checks of the connection-healing decision logic."""

    def test_failure_count_reset_on_success(self):
        """A successful response zeroes out the accumulated failure count."""

        # Minimal stand-in for the success path of _handle_call_method_response.
        class SuccessTrackingClient:
            def __init__(self):
                self._connection_failure_count = 5

            def _handle_successful_response(self):
                # Mirrors the counter reset the real client performs on success.
                self._connection_failure_count = 0

        stub = SuccessTrackingClient()
        stub._handle_successful_response()
        assert stub._connection_failure_count == 0

    def test_thirty_second_window_timing(self):
        """The reset decision flips exactly past the 30-second boundary."""
        current_time = time.time()

        # Pairs of (timestamp of the last reset, whether a reset is expected).
        scenarios = [
            (current_time - 10, False),  # 10 seconds ago - should NOT reset
            (current_time - 29, False),  # 29 seconds ago - should NOT reset
            (current_time - 31, True),  # 31 seconds ago - should reset
            (current_time - 60, True),  # 60 seconds ago - should reset
        ]

        for last_reset_time, should_reset in scenarios:
            failure_count = 3  # exactly at the failure threshold
            outside_window = current_time - last_reset_time > 30
            decision = failure_count >= 3 and outside_window

            assert (
                decision == should_reset
            ), f"Time window logic failed for {current_time - last_reset_time} seconds ago"
|
||||
|
||||
|
||||
def test_cached_property_behavior():
    """Verify @cached_property caches on first access and rebuilds after delattr."""
    creation_count = 0

    class Holder:
        @cached_property
        def expensive_resource(self):
            # Count every real construction so caching is observable.
            nonlocal creation_count
            creation_count += 1
            return f"resource-{creation_count}"

    holder = Holder()

    # First access triggers creation; the second returns the cached object.
    first = holder.expensive_resource
    assert creation_count == 1

    again = holder.expensive_resource
    assert creation_count == 1  # no additional creation
    assert first is again

    # Dropping the cached attribute forces a fresh creation on next access.
    delattr(holder, "expensive_resource")
    rebuilt = holder.expensive_resource
    assert creation_count == 2  # new creation
    assert first != rebuilt
|
||||
|
||||
Reference in New Issue
Block a user