mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
fix: eliminate N+1 performance bug in RemoteSandboxService with batch endpoint (#12105)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -477,7 +477,11 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
|||||||
if sandbox.status in (None, SandboxStatus.ERROR):
|
if sandbox.status in (None, SandboxStatus.ERROR):
|
||||||
raise SandboxError(f'Sandbox status: {sandbox.status}')
|
raise SandboxError(f'Sandbox status: {sandbox.status}')
|
||||||
if sandbox.status == SandboxStatus.RUNNING:
|
if sandbox.status == SandboxStatus.RUNNING:
|
||||||
return
|
# There are still bugs in the remote runtime - they report running while still just
|
||||||
|
# starting resulting in a race condition. Manually check that it is actually
|
||||||
|
# running.
|
||||||
|
if await self._check_agent_server_alive(sandbox):
|
||||||
|
return
|
||||||
if sandbox.status != SandboxStatus.STARTING:
|
if sandbox.status != SandboxStatus.STARTING:
|
||||||
raise SandboxError(f'Sandbox not startable: {sandbox.id}')
|
raise SandboxError(f'Sandbox not startable: {sandbox.id}')
|
||||||
|
|
||||||
@@ -490,9 +494,19 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
|||||||
if sandbox.status not in (SandboxStatus.STARTING, SandboxStatus.RUNNING):
|
if sandbox.status not in (SandboxStatus.STARTING, SandboxStatus.RUNNING):
|
||||||
raise SandboxError(f'Sandbox not startable: {sandbox.id}')
|
raise SandboxError(f'Sandbox not startable: {sandbox.id}')
|
||||||
if sandbox_info.status == SandboxStatus.RUNNING:
|
if sandbox_info.status == SandboxStatus.RUNNING:
|
||||||
return
|
# There are still bugs in the remote runtime - they report running while still just
|
||||||
|
# starting resulting in a race condition. Manually check that it is actually
|
||||||
|
# running.
|
||||||
|
if await self._check_agent_server_alive(sandbox_info):
|
||||||
|
return
|
||||||
raise SandboxError(f'Sandbox failed to start: {sandbox.id}')
|
raise SandboxError(f'Sandbox failed to start: {sandbox.id}')
|
||||||
|
|
||||||
|
async def _check_agent_server_alive(self, sandbox_info: SandboxInfo) -> bool:
|
||||||
|
agent_server_url = self._get_agent_server_url(sandbox_info)
|
||||||
|
url = f'{agent_server_url.rstrip("/")}/alive'
|
||||||
|
response = await self.httpx_client.get(url)
|
||||||
|
return response.is_success
|
||||||
|
|
||||||
def _get_agent_server_url(self, sandbox: SandboxInfo) -> str:
|
def _get_agent_server_url(self, sandbox: SandboxInfo) -> str:
|
||||||
"""Get agent server url for running sandbox."""
|
"""Get agent server url for running sandbox."""
|
||||||
exposed_urls = sandbox.exposed_urls
|
exposed_urls = sandbox.exposed_urls
|
||||||
|
|||||||
@@ -122,18 +122,9 @@ class RemoteSandboxService(SandboxService):
|
|||||||
_logger.error(f'HTTP error for URL {url}: {e}')
|
_logger.error(f'HTTP error for URL {url}: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _to_sandbox_info(
|
def _to_sandbox_info(
|
||||||
self, stored: StoredRemoteSandbox, runtime: dict[str, Any] | None = None
|
self, stored: StoredRemoteSandbox, runtime: dict[str, Any] | None = None
|
||||||
) -> SandboxInfo:
|
):
|
||||||
# If we did not get passsed runtime data, load some
|
|
||||||
if runtime is None:
|
|
||||||
try:
|
|
||||||
runtime = await self._get_runtime(stored.id)
|
|
||||||
except Exception:
|
|
||||||
_logger.exception(
|
|
||||||
f'Error getting runtime: {stored.id}', stack_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
status = self._get_sandbox_status_from_runtime(runtime)
|
status = self._get_sandbox_status_from_runtime(runtime)
|
||||||
|
|
||||||
# Get session_api_key and exposed urls
|
# Get session_api_key and exposed urls
|
||||||
@@ -233,6 +224,41 @@ class RemoteSandboxService(SandboxService):
|
|||||||
runtime_data = response.json()
|
runtime_data = response.json()
|
||||||
return runtime_data
|
return runtime_data
|
||||||
|
|
||||||
|
async def _get_runtimes_batch(
|
||||||
|
self, sandbox_ids: list[str]
|
||||||
|
) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Get multiple runtimes in a single batch request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sandbox_ids: List of sandbox IDs to fetch
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping sandbox_id to runtime data
|
||||||
|
"""
|
||||||
|
if not sandbox_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# Build query parameters for the batch endpoint
|
||||||
|
params = [('ids', sandbox_id) for sandbox_id in sandbox_ids]
|
||||||
|
|
||||||
|
response = await self._send_runtime_api_request(
|
||||||
|
'GET',
|
||||||
|
'/sessions/batch',
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
batch_data = response.json()
|
||||||
|
|
||||||
|
# The batch endpoint should return a list of runtimes
|
||||||
|
# Convert to a dictionary keyed by session_id for easy lookup
|
||||||
|
runtimes_by_id = {}
|
||||||
|
if batch_data and 'runtimes' in batch_data:
|
||||||
|
for runtime in batch_data['runtimes']:
|
||||||
|
if 'session_id' in runtime:
|
||||||
|
runtimes_by_id[runtime['session_id']] = runtime
|
||||||
|
|
||||||
|
return runtimes_by_id
|
||||||
|
|
||||||
async def _init_environment(
|
async def _init_environment(
|
||||||
self, sandbox_spec: SandboxSpecInfo, sandbox_id: str
|
self, sandbox_spec: SandboxSpecInfo, sandbox_id: str
|
||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
@@ -283,13 +309,15 @@ class RemoteSandboxService(SandboxService):
|
|||||||
if has_more:
|
if has_more:
|
||||||
next_page_id = str(offset + limit)
|
next_page_id = str(offset + limit)
|
||||||
|
|
||||||
# Convert stored callbacks to domain models
|
# Batch fetch runtime data for all sandboxes
|
||||||
items = await asyncio.gather(
|
sandbox_ids = [stored_sandbox.id for stored_sandbox in stored_sandboxes]
|
||||||
*[
|
runtimes_by_id = await self._get_runtimes_batch(sandbox_ids)
|
||||||
self._to_sandbox_info(stored_sandbox)
|
|
||||||
for stored_sandbox in stored_sandboxes
|
# Convert stored sandboxes to domain models with runtime data
|
||||||
]
|
items = [
|
||||||
)
|
self._to_sandbox_info(stored_sandbox, runtimes_by_id.get(stored_sandbox.id))
|
||||||
|
for stored_sandbox in stored_sandboxes
|
||||||
|
]
|
||||||
|
|
||||||
return SandboxPage(items=items, next_page_id=next_page_id)
|
return SandboxPage(items=items, next_page_id=next_page_id)
|
||||||
|
|
||||||
@@ -298,7 +326,16 @@ class RemoteSandboxService(SandboxService):
|
|||||||
stored_sandbox = await self._get_stored_sandbox(sandbox_id)
|
stored_sandbox = await self._get_stored_sandbox(sandbox_id)
|
||||||
if stored_sandbox is None:
|
if stored_sandbox is None:
|
||||||
return None
|
return None
|
||||||
return await self._to_sandbox_info(stored_sandbox)
|
|
||||||
|
runtime = None
|
||||||
|
try:
|
||||||
|
runtime = await self._get_runtime(stored_sandbox.id)
|
||||||
|
except Exception:
|
||||||
|
_logger.exception(
|
||||||
|
f'Error getting runtime: {stored_sandbox.id}', stack_info=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._to_sandbox_info(stored_sandbox, runtime)
|
||||||
|
|
||||||
async def get_sandbox_by_session_api_key(
|
async def get_sandbox_by_session_api_key(
|
||||||
self, session_api_key: str
|
self, session_api_key: str
|
||||||
@@ -323,7 +360,7 @@ class RemoteSandboxService(SandboxService):
|
|||||||
sandbox = result.first()
|
sandbox = result.first()
|
||||||
if sandbox is None:
|
if sandbox is None:
|
||||||
raise ValueError('sandbox_not_found')
|
raise ValueError('sandbox_not_found')
|
||||||
return await self._to_sandbox_info(sandbox, runtime)
|
return self._to_sandbox_info(sandbox, runtime)
|
||||||
except Exception:
|
except Exception:
|
||||||
_logger.exception(
|
_logger.exception(
|
||||||
'Error getting sandbox from session_api_key', stack_info=True
|
'Error getting sandbox from session_api_key', stack_info=True
|
||||||
@@ -339,7 +376,7 @@ class RemoteSandboxService(SandboxService):
|
|||||||
try:
|
try:
|
||||||
runtime = await self._get_runtime(stored_sandbox.id)
|
runtime = await self._get_runtime(stored_sandbox.id)
|
||||||
if runtime and runtime.get('session_api_key') == session_api_key:
|
if runtime and runtime.get('session_api_key') == session_api_key:
|
||||||
return await self._to_sandbox_info(stored_sandbox, runtime)
|
return self._to_sandbox_info(stored_sandbox, runtime)
|
||||||
except Exception:
|
except Exception:
|
||||||
# Continue checking other sandboxes if one fails
|
# Continue checking other sandboxes if one fails
|
||||||
continue
|
continue
|
||||||
@@ -412,7 +449,7 @@ class RemoteSandboxService(SandboxService):
|
|||||||
# Hack - result doesn't contain this
|
# Hack - result doesn't contain this
|
||||||
runtime_data['pod_status'] = 'pending'
|
runtime_data['pod_status'] = 'pending'
|
||||||
|
|
||||||
return await self._to_sandbox_info(stored_sandbox, runtime_data)
|
return self._to_sandbox_info(stored_sandbox, runtime_data)
|
||||||
|
|
||||||
except httpx.HTTPError as e:
|
except httpx.HTTPError as e:
|
||||||
_logger.error(f'Failed to start sandbox: {e}')
|
_logger.error(f'Failed to start sandbox: {e}')
|
||||||
@@ -480,6 +517,55 @@ class RemoteSandboxService(SandboxService):
|
|||||||
_logger.error(f'Error deleting sandbox {sandbox_id}: {e}')
|
_logger.error(f'Error deleting sandbox {sandbox_id}: {e}')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def pause_old_sandboxes(self, max_num_sandboxes: int) -> list[str]:
|
||||||
|
"""Pause the oldest sandboxes if there are more than max_num_sandboxes running.
|
||||||
|
In a multi user environment, this will pause sandboxes only for the current user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_num_sandboxes: Maximum number of sandboxes to keep running
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of sandbox IDs that were paused
|
||||||
|
"""
|
||||||
|
if max_num_sandboxes <= 0:
|
||||||
|
raise ValueError('max_num_sandboxes must be greater than 0')
|
||||||
|
|
||||||
|
response = await self._send_runtime_api_request(
|
||||||
|
'GET',
|
||||||
|
'/list',
|
||||||
|
)
|
||||||
|
content = response.json()
|
||||||
|
running_session_ids = [
|
||||||
|
runtime.get('session_id') for runtime in content['runtimes']
|
||||||
|
]
|
||||||
|
|
||||||
|
query = await self._secure_select()
|
||||||
|
query = query.filter(StoredRemoteSandbox.id.in_(running_session_ids)).order_by(
|
||||||
|
StoredRemoteSandbox.created_at.desc()
|
||||||
|
)
|
||||||
|
running_sandboxes = list(await self.db_session.execute(query))
|
||||||
|
|
||||||
|
# If we're within the limit, no cleanup needed
|
||||||
|
if len(running_sandboxes) <= max_num_sandboxes:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Determine how many to pause
|
||||||
|
num_to_pause = len(running_sandboxes) - max_num_sandboxes
|
||||||
|
sandboxes_to_pause = running_sandboxes[:num_to_pause]
|
||||||
|
|
||||||
|
# Stop the oldest sandboxes
|
||||||
|
paused_sandbox_ids = []
|
||||||
|
for sandbox in sandboxes_to_pause:
|
||||||
|
try:
|
||||||
|
success = await self.pause_sandbox(sandbox.id)
|
||||||
|
if success:
|
||||||
|
paused_sandbox_ids.append(sandbox.id)
|
||||||
|
except Exception:
|
||||||
|
# Continue trying to pause other sandboxes even if one fails
|
||||||
|
pass
|
||||||
|
|
||||||
|
return paused_sandbox_ids
|
||||||
|
|
||||||
|
|
||||||
def _build_service_url(url: str, service_name: str):
|
def _build_service_url(url: str, service_name: str):
|
||||||
scheme, host_and_path = url.split('://')
|
scheme, host_and_path = url.split('://')
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ class SandboxService(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
async def pause_old_sandboxes(self, max_num_sandboxes: int) -> list[str]:
|
async def pause_old_sandboxes(self, max_num_sandboxes: int) -> list[str]:
|
||||||
"""Stop the oldest sandboxes if there are more than max_num_sandboxes running.
|
"""Pause the oldest sandboxes if there are more than max_num_sandboxes running.
|
||||||
In a multi user environment, this will pause sandboxes only for the current user.
|
In a multi user environment, this will pause sandboxes only for the current user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -331,7 +331,7 @@ class TestSandboxInfoConversion:
|
|||||||
runtime_data = create_runtime_data(status='running', pod_status='ready')
|
runtime_data = create_runtime_data(status='running', pod_status='ready')
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
sandbox_info = await remote_sandbox_service._to_sandbox_info(
|
sandbox_info = remote_sandbox_service._to_sandbox_info(
|
||||||
stored_sandbox, runtime_data
|
stored_sandbox, runtime_data
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -358,7 +358,7 @@ class TestSandboxInfoConversion:
|
|||||||
runtime_data = create_runtime_data(status='running', pod_status='pending')
|
runtime_data = create_runtime_data(status='running', pod_status='pending')
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
sandbox_info = await remote_sandbox_service._to_sandbox_info(
|
sandbox_info = remote_sandbox_service._to_sandbox_info(
|
||||||
stored_sandbox, runtime_data
|
stored_sandbox, runtime_data
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -367,23 +367,6 @@ class TestSandboxInfoConversion:
|
|||||||
assert sandbox_info.session_api_key == 'test-session-key'
|
assert sandbox_info.session_api_key == 'test-session-key'
|
||||||
assert sandbox_info.exposed_urls is None
|
assert sandbox_info.exposed_urls is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_to_sandbox_info_without_runtime(self, remote_sandbox_service):
|
|
||||||
"""Test conversion to SandboxInfo without runtime data."""
|
|
||||||
# Setup
|
|
||||||
stored_sandbox = create_stored_sandbox()
|
|
||||||
remote_sandbox_service._get_runtime = AsyncMock(
|
|
||||||
side_effect=Exception('Runtime not found')
|
|
||||||
)
|
|
||||||
|
|
||||||
# Execute
|
|
||||||
sandbox_info = await remote_sandbox_service._to_sandbox_info(stored_sandbox)
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
assert sandbox_info.status == SandboxStatus.MISSING
|
|
||||||
assert sandbox_info.session_api_key is None
|
|
||||||
assert sandbox_info.exposed_urls is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_to_sandbox_info_loads_runtime_when_none_provided(
|
async def test_to_sandbox_info_loads_runtime_when_none_provided(
|
||||||
self, remote_sandbox_service
|
self, remote_sandbox_service
|
||||||
@@ -391,15 +374,12 @@ class TestSandboxInfoConversion:
|
|||||||
"""Test that runtime data is loaded when not provided."""
|
"""Test that runtime data is loaded when not provided."""
|
||||||
# Setup
|
# Setup
|
||||||
stored_sandbox = create_stored_sandbox()
|
stored_sandbox = create_stored_sandbox()
|
||||||
runtime_data = create_runtime_data()
|
|
||||||
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
|
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
sandbox_info = await remote_sandbox_service._to_sandbox_info(stored_sandbox)
|
sandbox_info = remote_sandbox_service._to_sandbox_info(stored_sandbox, None)
|
||||||
|
|
||||||
# Verify
|
# Verify
|
||||||
remote_sandbox_service._get_runtime.assert_called_once_with('test-sandbox-123')
|
assert sandbox_info.status == SandboxStatus.MISSING
|
||||||
assert sandbox_info.status == SandboxStatus.RUNNING
|
|
||||||
|
|
||||||
|
|
||||||
class TestSandboxLifecycle:
|
class TestSandboxLifecycle:
|
||||||
@@ -677,15 +657,18 @@ class TestSandboxSearch:
|
|||||||
mock_result = MagicMock()
|
mock_result = MagicMock()
|
||||||
mock_result.scalars.return_value = mock_scalars
|
mock_result.scalars.return_value = mock_scalars
|
||||||
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
|
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
remote_sandbox_service._to_sandbox_info = AsyncMock(
|
|
||||||
side_effect=lambda stored: SandboxInfo(
|
# Mock the batch endpoint response
|
||||||
id=stored.id,
|
mock_batch_response = MagicMock()
|
||||||
created_by_user_id=stored.created_by_user_id,
|
mock_batch_response.raise_for_status.return_value = None
|
||||||
sandbox_spec_id=stored.sandbox_spec_id,
|
mock_batch_response.json.return_value = {
|
||||||
status=SandboxStatus.RUNNING,
|
'runtimes': [
|
||||||
session_api_key='test-key',
|
create_runtime_data('sb1'),
|
||||||
created_at=stored.created_at,
|
create_runtime_data('sb2'),
|
||||||
)
|
]
|
||||||
|
}
|
||||||
|
remote_sandbox_service.httpx_client.request = AsyncMock(
|
||||||
|
return_value=mock_batch_response
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
@@ -697,6 +680,14 @@ class TestSandboxSearch:
|
|||||||
assert result.items[0].id == 'sb1'
|
assert result.items[0].id == 'sb1'
|
||||||
assert result.items[1].id == 'sb2'
|
assert result.items[1].id == 'sb2'
|
||||||
|
|
||||||
|
# Verify that the batch endpoint was called
|
||||||
|
remote_sandbox_service.httpx_client.request.assert_called_once_with(
|
||||||
|
'GET',
|
||||||
|
'https://api.example.com/sessions/batch',
|
||||||
|
headers={'X-API-Key': 'test-api-key'},
|
||||||
|
params=[('ids', 'sb1'), ('ids', 'sb2')],
|
||||||
|
)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_sandboxes_with_pagination(self, remote_sandbox_service):
|
async def test_search_sandboxes_with_pagination(self, remote_sandbox_service):
|
||||||
"""Test sandbox search with pagination."""
|
"""Test sandbox search with pagination."""
|
||||||
@@ -710,15 +701,15 @@ class TestSandboxSearch:
|
|||||||
mock_result = MagicMock()
|
mock_result = MagicMock()
|
||||||
mock_result.scalars.return_value = mock_scalars
|
mock_result.scalars.return_value = mock_scalars
|
||||||
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
|
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
remote_sandbox_service._to_sandbox_info = AsyncMock(
|
|
||||||
side_effect=lambda stored: SandboxInfo(
|
# Mock the batch endpoint response
|
||||||
id=stored.id,
|
mock_batch_response = MagicMock()
|
||||||
created_by_user_id=stored.created_by_user_id,
|
mock_batch_response.raise_for_status.return_value = None
|
||||||
sandbox_spec_id=stored.sandbox_spec_id,
|
mock_batch_response.json.return_value = {
|
||||||
status=SandboxStatus.RUNNING,
|
'runtimes': [create_runtime_data(f'sb{i}') for i in range(6)]
|
||||||
session_api_key='test-key',
|
}
|
||||||
created_at=stored.created_at,
|
remote_sandbox_service.httpx_client.request = AsyncMock(
|
||||||
)
|
return_value=mock_batch_response
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
@@ -739,15 +730,15 @@ class TestSandboxSearch:
|
|||||||
mock_result = MagicMock()
|
mock_result = MagicMock()
|
||||||
mock_result.scalars.return_value = mock_scalars
|
mock_result.scalars.return_value = mock_scalars
|
||||||
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
|
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
|
||||||
remote_sandbox_service._to_sandbox_info = AsyncMock(
|
|
||||||
side_effect=lambda stored: SandboxInfo(
|
# Mock the batch endpoint response
|
||||||
id=stored.id,
|
mock_batch_response = MagicMock()
|
||||||
created_by_user_id=stored.created_by_user_id,
|
mock_batch_response.raise_for_status.return_value = None
|
||||||
sandbox_spec_id=stored.sandbox_spec_id,
|
mock_batch_response.json.return_value = {
|
||||||
status=SandboxStatus.RUNNING,
|
'runtimes': [create_runtime_data('sb1')]
|
||||||
session_api_key='test-key',
|
}
|
||||||
created_at=stored.created_at,
|
remote_sandbox_service.httpx_client.request = AsyncMock(
|
||||||
)
|
return_value=mock_batch_response
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute
|
# Execute
|
||||||
@@ -757,6 +748,80 @@ class TestSandboxSearch:
|
|||||||
# Note: We can't easily verify the exact SQL query, but we can verify the method was called
|
# Note: We can't easily verify the exact SQL query, but we can verify the method was called
|
||||||
remote_sandbox_service.db_session.execute.assert_called_once()
|
remote_sandbox_service.db_session.execute.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_runtimes_batch_success(self, remote_sandbox_service):
|
||||||
|
"""Test successful batch runtime retrieval."""
|
||||||
|
# Setup
|
||||||
|
sandbox_ids = ['sb1', 'sb2', 'sb3']
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
'runtimes': [
|
||||||
|
create_runtime_data('sb1'),
|
||||||
|
create_runtime_data('sb2'),
|
||||||
|
create_runtime_data('sb3'),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
remote_sandbox_service.httpx_client.request = AsyncMock(
|
||||||
|
return_value=mock_response
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
result = await remote_sandbox_service._get_runtimes_batch(sandbox_ids)
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
assert len(result) == 3
|
||||||
|
assert 'sb1' in result
|
||||||
|
assert 'sb2' in result
|
||||||
|
assert 'sb3' in result
|
||||||
|
assert result['sb1']['session_id'] == 'sb1'
|
||||||
|
|
||||||
|
# Verify the correct API call was made
|
||||||
|
remote_sandbox_service.httpx_client.request.assert_called_once_with(
|
||||||
|
'GET',
|
||||||
|
'https://api.example.com/sessions/batch',
|
||||||
|
headers={'X-API-Key': 'test-api-key'},
|
||||||
|
params=[('ids', 'sb1'), ('ids', 'sb2'), ('ids', 'sb3')],
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_runtimes_batch_empty_list(self, remote_sandbox_service):
|
||||||
|
"""Test batch runtime retrieval with empty sandbox list."""
|
||||||
|
# Execute
|
||||||
|
result = await remote_sandbox_service._get_runtimes_batch([])
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
assert result == {}
|
||||||
|
# Verify no API call was made
|
||||||
|
remote_sandbox_service.httpx_client.request.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_runtimes_batch_partial_results(self, remote_sandbox_service):
|
||||||
|
"""Test batch runtime retrieval with partial results (some sandboxes not found)."""
|
||||||
|
# Setup
|
||||||
|
sandbox_ids = ['sb1', 'sb2', 'sb3']
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.raise_for_status.return_value = None
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
'runtimes': [
|
||||||
|
create_runtime_data('sb1'),
|
||||||
|
create_runtime_data('sb3'),
|
||||||
|
# sb2 is missing from the response
|
||||||
|
]
|
||||||
|
}
|
||||||
|
remote_sandbox_service.httpx_client.request = AsyncMock(
|
||||||
|
return_value=mock_response
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
result = await remote_sandbox_service._get_runtimes_batch(sandbox_ids)
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
assert len(result) == 2
|
||||||
|
assert 'sb1' in result
|
||||||
|
assert 'sb2' not in result # Missing from response
|
||||||
|
assert 'sb3' in result
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_sandbox_exists(self, remote_sandbox_service):
|
async def test_get_sandbox_exists(self, remote_sandbox_service):
|
||||||
"""Test getting an existing sandbox."""
|
"""Test getting an existing sandbox."""
|
||||||
@@ -765,7 +830,7 @@ class TestSandboxSearch:
|
|||||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(
|
remote_sandbox_service._get_stored_sandbox = AsyncMock(
|
||||||
return_value=stored_sandbox
|
return_value=stored_sandbox
|
||||||
)
|
)
|
||||||
remote_sandbox_service._to_sandbox_info = AsyncMock(
|
remote_sandbox_service._to_sandbox_info = MagicMock(
|
||||||
return_value=SandboxInfo(
|
return_value=SandboxInfo(
|
||||||
id='test-sandbox-123',
|
id='test-sandbox-123',
|
||||||
created_by_user_id='test-user-123',
|
created_by_user_id='test-user-123',
|
||||||
|
|||||||
Reference in New Issue
Block a user