mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-10 07:18:10 -05:00
V1 CORS Fix (#11586)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -47,6 +47,7 @@ from openhands.app_server.utils.sql_utils import Base, UtcDateTime
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
WEBHOOK_CALLBACK_VARIABLE = 'OH_WEBHOOKS_0_BASE_URL'
|
||||
ALLOW_CORS_ORIGINS_VARIABLE = 'OH_ALLOW_CORS_ORIGINS_0'
|
||||
polling_task: asyncio.Task | None = None
|
||||
POD_STATUS_MAPPING = {
|
||||
'ready': SandboxStatus.RUNNING,
|
||||
@@ -128,22 +129,10 @@ class RemoteSandboxService(SandboxService):
|
||||
f'Error getting runtime: {stored.id}', stack_info=True
|
||||
)
|
||||
|
||||
status = self._get_sandbox_status_from_runtime(runtime)
|
||||
|
||||
# Get session_api_key and exposed urls
|
||||
if runtime:
|
||||
# Translate status
|
||||
status = None
|
||||
pod_status = runtime['pod_status'].lower()
|
||||
if pod_status:
|
||||
status = POD_STATUS_MAPPING.get(pod_status, None)
|
||||
|
||||
# If we failed to get the status from the pod status, fall back to status
|
||||
if status is None:
|
||||
runtime_status = runtime.get('status')
|
||||
if runtime_status:
|
||||
status = STATUS_MAPPING.get(runtime_status.lower(), None)
|
||||
|
||||
if status is None:
|
||||
status = SandboxStatus.MISSING
|
||||
|
||||
session_api_key = runtime['session_api_key']
|
||||
if status == SandboxStatus.RUNNING:
|
||||
exposed_urls = []
|
||||
@@ -165,7 +154,6 @@ class RemoteSandboxService(SandboxService):
|
||||
exposed_urls = None
|
||||
else:
|
||||
session_api_key = None
|
||||
status = SandboxStatus.MISSING
|
||||
exposed_urls = None
|
||||
|
||||
sandbox_spec_id = stored.sandbox_spec_id
|
||||
@@ -179,6 +167,32 @@ class RemoteSandboxService(SandboxService):
|
||||
created_at=stored.created_at,
|
||||
)
|
||||
|
||||
def _get_sandbox_status_from_runtime(
|
||||
self, runtime: dict[str, Any] | None
|
||||
) -> SandboxStatus:
|
||||
"""Derive a SandboxStatus from the runtime info. The legacy logic for getting
|
||||
the status of a runtime is inconsistent. It is divided between a "status" which
|
||||
cannot be trusted (It sometimes returns "running" for cases when the pod is
|
||||
still starting) and a "pod_status" which is not returned for list
|
||||
operations."""
|
||||
if not runtime:
|
||||
return SandboxStatus.MISSING
|
||||
|
||||
status = None
|
||||
pod_status = runtime['pod_status'].lower()
|
||||
if pod_status:
|
||||
status = POD_STATUS_MAPPING.get(pod_status, None)
|
||||
|
||||
# If we failed to get the status from the pod status, fall back to status
|
||||
if status is None:
|
||||
runtime_status = runtime.get('status')
|
||||
if runtime_status:
|
||||
status = STATUS_MAPPING.get(runtime_status.lower(), None)
|
||||
|
||||
if status is None:
|
||||
return SandboxStatus.MISSING
|
||||
return status
|
||||
|
||||
async def _secure_select(self):
|
||||
query = select(StoredRemoteSandbox)
|
||||
user_id = await self.user_context.get_user_id()
|
||||
@@ -213,6 +227,9 @@ class RemoteSandboxService(SandboxService):
|
||||
environment[WEBHOOK_CALLBACK_VARIABLE] = (
|
||||
f'{self.web_url}/api/v1/webhooks/{sandbox_id}'
|
||||
)
|
||||
# We specify CORS settings only if there is a public facing url - otherwise
|
||||
# we are probably in local development and the only url in use is localhost
|
||||
environment[ALLOW_CORS_ORIGINS_VARIABLE] = self.web_url
|
||||
|
||||
return environment
|
||||
|
||||
@@ -614,6 +631,7 @@ class RemoteSandboxServiceInjector(SandboxServiceInjector):
|
||||
)
|
||||
|
||||
# If no public facing web url is defined, poll for changes as callbacks will be unavailable.
|
||||
# This is primarily used for local development rather than production
|
||||
config = get_global_config()
|
||||
web_url = config.web_url
|
||||
if web_url is None:
|
||||
|
||||
@@ -66,6 +66,7 @@ class SandboxService(ABC):
|
||||
|
||||
async def pause_old_sandboxes(self, max_num_sandboxes: int) -> list[str]:
|
||||
"""Stop the oldest sandboxes if there are more than max_num_sandboxes running.
|
||||
In a multi user environment, this will pause sandboxes only for the current user.
|
||||
|
||||
Args:
|
||||
max_num_sandboxes: Maximum number of sandboxes to keep running
|
||||
|
||||
947
tests/unit/app_server/test_remote_sandbox_service.py
Normal file
947
tests/unit/app_server/test_remote_sandbox_service.py
Normal file
@@ -0,0 +1,947 @@
|
||||
"""Tests for RemoteSandboxService.
|
||||
|
||||
This module tests the RemoteSandboxService implementation, focusing on:
|
||||
- Remote runtime API communication and error handling
|
||||
- Sandbox lifecycle management (start, pause, resume, delete)
|
||||
- Status mapping from remote runtime to internal sandbox statuses
|
||||
- Environment variable injection for CORS and webhooks
|
||||
- Data transformation from remote runtime to SandboxInfo objects
|
||||
- User-scoped sandbox operations and security
|
||||
- Pagination and search functionality
|
||||
- Error handling for HTTP failures and edge cases
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from openhands.app_server.errors import SandboxError
|
||||
from openhands.app_server.sandbox.remote_sandbox_service import (
|
||||
ALLOW_CORS_ORIGINS_VARIABLE,
|
||||
POD_STATUS_MAPPING,
|
||||
STATUS_MAPPING,
|
||||
WEBHOOK_CALLBACK_VARIABLE,
|
||||
RemoteSandboxService,
|
||||
StoredRemoteSandbox,
|
||||
)
|
||||
from openhands.app_server.sandbox.sandbox_models import (
|
||||
AGENT_SERVER,
|
||||
VSCODE,
|
||||
WORKER_1,
|
||||
WORKER_2,
|
||||
SandboxInfo,
|
||||
SandboxStatus,
|
||||
)
|
||||
from openhands.app_server.sandbox.sandbox_spec_models import SandboxSpecInfo
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_sandbox_spec_service():
|
||||
"""Mock SandboxSpecService for testing."""
|
||||
mock_service = AsyncMock()
|
||||
mock_spec = SandboxSpecInfo(
|
||||
id='test-image:latest',
|
||||
command=['/usr/local/bin/openhands-agent-server', '--port', '60000'],
|
||||
initial_env={'TEST_VAR': 'test_value'},
|
||||
working_dir='/workspace/project',
|
||||
)
|
||||
mock_service.get_default_sandbox_spec.return_value = mock_spec
|
||||
mock_service.get_sandbox_spec.return_value = mock_spec
|
||||
return mock_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_context():
|
||||
"""Mock UserContext for testing."""
|
||||
mock_context = AsyncMock(spec=UserContext)
|
||||
mock_context.get_user_id.return_value = 'test-user-123'
|
||||
return mock_context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_httpx_client():
|
||||
"""Mock httpx.AsyncClient for testing."""
|
||||
return AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session for testing."""
|
||||
return AsyncMock(spec=AsyncSession)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def remote_sandbox_service(
|
||||
mock_sandbox_spec_service, mock_user_context, mock_httpx_client, mock_db_session
|
||||
):
|
||||
"""Create RemoteSandboxService instance with mocked dependencies."""
|
||||
return RemoteSandboxService(
|
||||
sandbox_spec_service=mock_sandbox_spec_service,
|
||||
api_url='https://api.example.com',
|
||||
api_key='test-api-key',
|
||||
web_url='https://web.example.com',
|
||||
resource_factor=1,
|
||||
runtime_class='gvisor',
|
||||
start_sandbox_timeout=120,
|
||||
max_num_sandboxes=10,
|
||||
user_context=mock_user_context,
|
||||
httpx_client=mock_httpx_client,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
|
||||
def create_runtime_data(
|
||||
session_id: str = 'test-sandbox-123',
|
||||
status: str = 'running',
|
||||
pod_status: str = 'ready',
|
||||
url: str = 'https://sandbox.example.com',
|
||||
session_api_key: str = 'test-session-key',
|
||||
runtime_id: str = 'runtime-456',
|
||||
) -> dict[str, Any]:
|
||||
"""Helper function to create runtime data for testing."""
|
||||
return {
|
||||
'session_id': session_id,
|
||||
'status': status,
|
||||
'pod_status': pod_status,
|
||||
'url': url,
|
||||
'session_api_key': session_api_key,
|
||||
'runtime_id': runtime_id,
|
||||
}
|
||||
|
||||
|
||||
def create_stored_sandbox(
|
||||
sandbox_id: str = 'test-sandbox-123',
|
||||
user_id: str = 'test-user-123',
|
||||
spec_id: str = 'test-image:latest',
|
||||
created_at: datetime | None = None,
|
||||
) -> StoredRemoteSandbox:
|
||||
"""Helper function to create StoredRemoteSandbox for testing."""
|
||||
if created_at is None:
|
||||
created_at = datetime.now(timezone.utc)
|
||||
|
||||
return StoredRemoteSandbox(
|
||||
id=sandbox_id,
|
||||
created_by_user_id=user_id,
|
||||
sandbox_spec_id=spec_id,
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
|
||||
class TestRemoteSandboxService:
|
||||
"""Test cases for RemoteSandboxService core functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_runtime_api_request_success(self, remote_sandbox_service):
|
||||
"""Test successful API request to remote runtime."""
|
||||
# Setup
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {'result': 'success'}
|
||||
remote_sandbox_service.httpx_client.request.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
response = await remote_sandbox_service._send_runtime_api_request(
|
||||
'GET', '/test-endpoint', json={'test': 'data'}
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert response == mock_response
|
||||
remote_sandbox_service.httpx_client.request.assert_called_once_with(
|
||||
'GET',
|
||||
'https://api.example.com/test-endpoint',
|
||||
headers={'X-API-Key': 'test-api-key'},
|
||||
json={'test': 'data'},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_runtime_api_request_timeout(self, remote_sandbox_service):
|
||||
"""Test API request timeout handling."""
|
||||
# Setup
|
||||
remote_sandbox_service.httpx_client.request.side_effect = (
|
||||
httpx.TimeoutException('Request timeout')
|
||||
)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
await remote_sandbox_service._send_runtime_api_request('GET', '/test')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_runtime_api_request_http_error(self, remote_sandbox_service):
|
||||
"""Test API request HTTP error handling."""
|
||||
# Setup
|
||||
remote_sandbox_service.httpx_client.request.side_effect = httpx.HTTPError(
|
||||
'HTTP error'
|
||||
)
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(httpx.HTTPError):
|
||||
await remote_sandbox_service._send_runtime_api_request('GET', '/test')
|
||||
|
||||
|
||||
class TestStatusMapping:
|
||||
"""Test cases for status mapping functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sandbox_status_from_runtime_with_pod_status(
|
||||
self, remote_sandbox_service
|
||||
):
|
||||
"""Test status mapping using pod_status."""
|
||||
runtime_data = create_runtime_data(pod_status='ready')
|
||||
|
||||
status = remote_sandbox_service._get_sandbox_status_from_runtime(runtime_data)
|
||||
|
||||
assert status == SandboxStatus.RUNNING
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sandbox_status_from_runtime_fallback_to_status(
|
||||
self, remote_sandbox_service
|
||||
):
|
||||
"""Test status mapping fallback to status field."""
|
||||
runtime_data = create_runtime_data(
|
||||
pod_status='unknown_pod_status', status='running'
|
||||
)
|
||||
|
||||
status = remote_sandbox_service._get_sandbox_status_from_runtime(runtime_data)
|
||||
|
||||
assert status == SandboxStatus.RUNNING
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sandbox_status_from_runtime_no_runtime(
|
||||
self, remote_sandbox_service
|
||||
):
|
||||
"""Test status mapping with no runtime data."""
|
||||
status = remote_sandbox_service._get_sandbox_status_from_runtime(None)
|
||||
|
||||
assert status == SandboxStatus.MISSING
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sandbox_status_from_runtime_unknown_status(
|
||||
self, remote_sandbox_service
|
||||
):
|
||||
"""Test status mapping with unknown status values."""
|
||||
runtime_data = create_runtime_data(
|
||||
pod_status='unknown_pod', status='unknown_status'
|
||||
)
|
||||
|
||||
status = remote_sandbox_service._get_sandbox_status_from_runtime(runtime_data)
|
||||
|
||||
assert status == SandboxStatus.MISSING
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pod_status_mapping_coverage(self, remote_sandbox_service):
|
||||
"""Test all pod status mappings are handled correctly."""
|
||||
test_cases = [
|
||||
('ready', SandboxStatus.RUNNING),
|
||||
('pending', SandboxStatus.STARTING),
|
||||
('running', SandboxStatus.STARTING),
|
||||
('failed', SandboxStatus.ERROR),
|
||||
('unknown', SandboxStatus.ERROR),
|
||||
('crashloopbackoff', SandboxStatus.ERROR),
|
||||
]
|
||||
|
||||
for pod_status, expected_status in test_cases:
|
||||
runtime_data = create_runtime_data(pod_status=pod_status)
|
||||
status = remote_sandbox_service._get_sandbox_status_from_runtime(
|
||||
runtime_data
|
||||
)
|
||||
assert status == expected_status, f'Failed for pod_status: {pod_status}'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_mapping_coverage(self, remote_sandbox_service):
|
||||
"""Test all status mappings are handled correctly."""
|
||||
test_cases = [
|
||||
('running', SandboxStatus.RUNNING),
|
||||
('paused', SandboxStatus.PAUSED),
|
||||
('stopped', SandboxStatus.MISSING),
|
||||
('starting', SandboxStatus.STARTING),
|
||||
('error', SandboxStatus.ERROR),
|
||||
]
|
||||
|
||||
for status, expected_status in test_cases:
|
||||
# Use empty pod_status to force fallback to status field
|
||||
runtime_data = create_runtime_data(pod_status='', status=status)
|
||||
result = remote_sandbox_service._get_sandbox_status_from_runtime(
|
||||
runtime_data
|
||||
)
|
||||
assert result == expected_status, f'Failed for status: {status}'
|
||||
|
||||
|
||||
class TestEnvironmentInitialization:
|
||||
"""Test cases for environment variable initialization."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_environment_with_web_url(self, remote_sandbox_service):
|
||||
"""Test environment initialization with web_url set."""
|
||||
# Setup
|
||||
sandbox_spec = SandboxSpecInfo(
|
||||
id='test-image',
|
||||
command=['test'],
|
||||
initial_env={'EXISTING_VAR': 'existing_value'},
|
||||
working_dir='/workspace',
|
||||
)
|
||||
sandbox_id = 'test-sandbox-123'
|
||||
|
||||
# Execute
|
||||
environment = await remote_sandbox_service._init_environment(
|
||||
sandbox_spec, sandbox_id
|
||||
)
|
||||
|
||||
# Verify
|
||||
expected_webhook_url = (
|
||||
'https://web.example.com/api/v1/webhooks/test-sandbox-123'
|
||||
)
|
||||
assert environment['EXISTING_VAR'] == 'existing_value'
|
||||
assert environment[WEBHOOK_CALLBACK_VARIABLE] == expected_webhook_url
|
||||
assert environment[ALLOW_CORS_ORIGINS_VARIABLE] == 'https://web.example.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_environment_without_web_url(self, remote_sandbox_service):
|
||||
"""Test environment initialization without web_url."""
|
||||
# Setup
|
||||
remote_sandbox_service.web_url = None
|
||||
sandbox_spec = SandboxSpecInfo(
|
||||
id='test-image',
|
||||
command=['test'],
|
||||
initial_env={'EXISTING_VAR': 'existing_value'},
|
||||
working_dir='/workspace',
|
||||
)
|
||||
sandbox_id = 'test-sandbox-123'
|
||||
|
||||
# Execute
|
||||
environment = await remote_sandbox_service._init_environment(
|
||||
sandbox_spec, sandbox_id
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert environment['EXISTING_VAR'] == 'existing_value'
|
||||
assert WEBHOOK_CALLBACK_VARIABLE not in environment
|
||||
assert ALLOW_CORS_ORIGINS_VARIABLE not in environment
|
||||
|
||||
|
||||
class TestSandboxInfoConversion:
|
||||
"""Test cases for converting stored sandbox and runtime data to SandboxInfo."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_sandbox_info_with_running_runtime(self, remote_sandbox_service):
|
||||
"""Test conversion to SandboxInfo with running runtime."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
runtime_data = create_runtime_data(status='running', pod_status='ready')
|
||||
|
||||
# Execute
|
||||
sandbox_info = await remote_sandbox_service._to_sandbox_info(
|
||||
stored_sandbox, runtime_data
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert sandbox_info.id == 'test-sandbox-123'
|
||||
assert sandbox_info.created_by_user_id == 'test-user-123'
|
||||
assert sandbox_info.sandbox_spec_id == 'test-image:latest'
|
||||
assert sandbox_info.status == SandboxStatus.RUNNING
|
||||
assert sandbox_info.session_api_key == 'test-session-key'
|
||||
assert len(sandbox_info.exposed_urls) == 4
|
||||
|
||||
# Check exposed URLs
|
||||
url_names = [url.name for url in sandbox_info.exposed_urls]
|
||||
assert AGENT_SERVER in url_names
|
||||
assert VSCODE in url_names
|
||||
assert WORKER_1 in url_names
|
||||
assert WORKER_2 in url_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_sandbox_info_with_starting_runtime(self, remote_sandbox_service):
|
||||
"""Test conversion to SandboxInfo with starting runtime."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
runtime_data = create_runtime_data(status='running', pod_status='pending')
|
||||
|
||||
# Execute
|
||||
sandbox_info = await remote_sandbox_service._to_sandbox_info(
|
||||
stored_sandbox, runtime_data
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert sandbox_info.status == SandboxStatus.STARTING
|
||||
assert sandbox_info.session_api_key == 'test-session-key'
|
||||
assert sandbox_info.exposed_urls is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_sandbox_info_without_runtime(self, remote_sandbox_service):
|
||||
"""Test conversion to SandboxInfo without runtime data."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
remote_sandbox_service._get_runtime = AsyncMock(
|
||||
side_effect=Exception('Runtime not found')
|
||||
)
|
||||
|
||||
# Execute
|
||||
sandbox_info = await remote_sandbox_service._to_sandbox_info(stored_sandbox)
|
||||
|
||||
# Verify
|
||||
assert sandbox_info.status == SandboxStatus.MISSING
|
||||
assert sandbox_info.session_api_key is None
|
||||
assert sandbox_info.exposed_urls is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_sandbox_info_loads_runtime_when_none_provided(
|
||||
self, remote_sandbox_service
|
||||
):
|
||||
"""Test that runtime data is loaded when not provided."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
runtime_data = create_runtime_data()
|
||||
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
|
||||
|
||||
# Execute
|
||||
sandbox_info = await remote_sandbox_service._to_sandbox_info(stored_sandbox)
|
||||
|
||||
# Verify
|
||||
remote_sandbox_service._get_runtime.assert_called_once_with('test-sandbox-123')
|
||||
assert sandbox_info.status == SandboxStatus.RUNNING
|
||||
|
||||
|
||||
class TestSandboxLifecycle:
|
||||
"""Test cases for sandbox lifecycle operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_sandbox_success(
|
||||
self, remote_sandbox_service, mock_sandbox_spec_service
|
||||
):
|
||||
"""Test successful sandbox start."""
|
||||
# Setup
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = create_runtime_data()
|
||||
remote_sandbox_service.httpx_client.request.return_value = mock_response
|
||||
remote_sandbox_service.pause_old_sandboxes = AsyncMock(return_value=[])
|
||||
|
||||
# Mock database operations
|
||||
remote_sandbox_service.db_session.add = MagicMock()
|
||||
remote_sandbox_service.db_session.commit = AsyncMock()
|
||||
|
||||
# Execute
|
||||
with patch('base62.encodebytes', return_value='test-sandbox-123'):
|
||||
sandbox_info = await remote_sandbox_service.start_sandbox()
|
||||
|
||||
# Verify
|
||||
assert sandbox_info.id == 'test-sandbox-123'
|
||||
assert (
|
||||
sandbox_info.status == SandboxStatus.STARTING
|
||||
) # pod_status is 'pending' by default
|
||||
remote_sandbox_service.pause_old_sandboxes.assert_called_once_with(
|
||||
9
|
||||
) # max_num_sandboxes - 1
|
||||
remote_sandbox_service.db_session.add.assert_called_once()
|
||||
remote_sandbox_service.db_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_sandbox_with_specific_spec(
|
||||
self, remote_sandbox_service, mock_sandbox_spec_service
|
||||
):
|
||||
"""Test starting sandbox with specific sandbox spec."""
|
||||
# Setup
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = create_runtime_data()
|
||||
remote_sandbox_service.httpx_client.request.return_value = mock_response
|
||||
remote_sandbox_service.pause_old_sandboxes = AsyncMock(return_value=[])
|
||||
remote_sandbox_service.db_session.add = MagicMock()
|
||||
remote_sandbox_service.db_session.commit = AsyncMock()
|
||||
|
||||
# Execute
|
||||
with patch('base62.encodebytes', return_value='test-sandbox-123'):
|
||||
await remote_sandbox_service.start_sandbox('custom-spec-id')
|
||||
|
||||
# Verify
|
||||
mock_sandbox_spec_service.get_sandbox_spec.assert_called_once_with(
|
||||
'custom-spec-id'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_sandbox_spec_not_found(
|
||||
self, remote_sandbox_service, mock_sandbox_spec_service
|
||||
):
|
||||
"""Test starting sandbox with non-existent spec."""
|
||||
# Setup
|
||||
mock_sandbox_spec_service.get_sandbox_spec.return_value = None
|
||||
remote_sandbox_service.pause_old_sandboxes = AsyncMock(return_value=[])
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match='Sandbox Spec not found'):
|
||||
await remote_sandbox_service.start_sandbox('non-existent-spec')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_sandbox_http_error(self, remote_sandbox_service):
|
||||
"""Test sandbox start with HTTP error."""
|
||||
# Setup
|
||||
remote_sandbox_service.httpx_client.request.side_effect = httpx.HTTPError(
|
||||
'API Error'
|
||||
)
|
||||
remote_sandbox_service.pause_old_sandboxes = AsyncMock(return_value=[])
|
||||
remote_sandbox_service.db_session.add = MagicMock()
|
||||
remote_sandbox_service.db_session.commit = AsyncMock()
|
||||
|
||||
# Execute & Verify
|
||||
with patch('base62.encodebytes', return_value='test-sandbox-123'):
|
||||
with pytest.raises(SandboxError, match='Failed to start sandbox'):
|
||||
await remote_sandbox_service.start_sandbox()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_sandbox_with_sysbox_runtime(self, remote_sandbox_service):
|
||||
"""Test sandbox start with sysbox runtime class."""
|
||||
# Setup
|
||||
remote_sandbox_service.runtime_class = 'sysbox'
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = create_runtime_data()
|
||||
remote_sandbox_service.httpx_client.request.return_value = mock_response
|
||||
remote_sandbox_service.pause_old_sandboxes = AsyncMock(return_value=[])
|
||||
remote_sandbox_service.db_session.add = MagicMock()
|
||||
remote_sandbox_service.db_session.commit = AsyncMock()
|
||||
|
||||
# Execute
|
||||
with patch('base62.encodebytes', return_value='test-sandbox-123'):
|
||||
await remote_sandbox_service.start_sandbox()
|
||||
|
||||
# Verify runtime_class is included in request
|
||||
call_args = remote_sandbox_service.httpx_client.request.call_args
|
||||
request_data = call_args[1]['json']
|
||||
assert request_data['runtime_class'] == 'sysbox-runc'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_sandbox_success(self, remote_sandbox_service):
|
||||
"""Test successful sandbox resume."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
runtime_data = create_runtime_data()
|
||||
|
||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(
|
||||
return_value=stored_sandbox
|
||||
)
|
||||
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
|
||||
remote_sandbox_service.pause_old_sandboxes = AsyncMock(return_value=[])
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
remote_sandbox_service.httpx_client.request.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.resume_sandbox('test-sandbox-123')
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
remote_sandbox_service.pause_old_sandboxes.assert_called_once_with(9)
|
||||
remote_sandbox_service.httpx_client.request.assert_called_once_with(
|
||||
'POST',
|
||||
'https://api.example.com/resume',
|
||||
headers={'X-API-Key': 'test-api-key'},
|
||||
json={'runtime_id': 'runtime-456'},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_sandbox_not_found(self, remote_sandbox_service):
|
||||
"""Test resuming non-existent sandbox."""
|
||||
# Setup
|
||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(return_value=None)
|
||||
remote_sandbox_service.pause_old_sandboxes = AsyncMock(return_value=[])
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.resume_sandbox('non-existent')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_sandbox_runtime_not_found(self, remote_sandbox_service):
|
||||
"""Test resuming sandbox when runtime returns 404."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
runtime_data = create_runtime_data()
|
||||
|
||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(
|
||||
return_value=stored_sandbox
|
||||
)
|
||||
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
|
||||
remote_sandbox_service.pause_old_sandboxes = AsyncMock(return_value=[])
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
remote_sandbox_service.httpx_client.request.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.resume_sandbox('test-sandbox-123')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_sandbox_success(self, remote_sandbox_service):
|
||||
"""Test successful sandbox pause."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
runtime_data = create_runtime_data()
|
||||
|
||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(
|
||||
return_value=stored_sandbox
|
||||
)
|
||||
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
remote_sandbox_service.httpx_client.request.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.pause_sandbox('test-sandbox-123')
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
remote_sandbox_service.httpx_client.request.assert_called_once_with(
|
||||
'POST',
|
||||
'https://api.example.com/pause',
|
||||
headers={'X-API-Key': 'test-api-key'},
|
||||
json={'runtime_id': 'runtime-456'},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_sandbox_success(self, remote_sandbox_service):
|
||||
"""Test successful sandbox deletion."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
runtime_data = create_runtime_data()
|
||||
|
||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(
|
||||
return_value=stored_sandbox
|
||||
)
|
||||
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
|
||||
remote_sandbox_service.db_session.delete = AsyncMock()
|
||||
remote_sandbox_service.db_session.commit = AsyncMock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
remote_sandbox_service.httpx_client.request.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.delete_sandbox('test-sandbox-123')
|
||||
|
||||
# Verify
|
||||
assert result is True
|
||||
remote_sandbox_service.db_session.delete.assert_called_once_with(stored_sandbox)
|
||||
remote_sandbox_service.db_session.commit.assert_called_once()
|
||||
remote_sandbox_service.httpx_client.request.assert_called_once_with(
|
||||
'POST',
|
||||
'https://api.example.com/stop',
|
||||
headers={'X-API-Key': 'test-api-key'},
|
||||
json={'runtime_id': 'runtime-456'},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_sandbox_runtime_not_found_ignored(
|
||||
self, remote_sandbox_service
|
||||
):
|
||||
"""Test sandbox deletion when runtime returns 404 (should be ignored)."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
runtime_data = create_runtime_data()
|
||||
|
||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(
|
||||
return_value=stored_sandbox
|
||||
)
|
||||
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
|
||||
remote_sandbox_service.db_session.delete = AsyncMock()
|
||||
remote_sandbox_service.db_session.commit = AsyncMock()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 404
|
||||
remote_sandbox_service.httpx_client.request.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.delete_sandbox('test-sandbox-123')
|
||||
|
||||
# Verify
|
||||
assert result is True # 404 should be ignored for delete operations
|
||||
|
||||
|
||||
class TestSandboxSearch:
|
||||
"""Test cases for sandbox search and retrieval."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_sandboxes_basic(self, remote_sandbox_service):
|
||||
"""Test basic sandbox search functionality."""
|
||||
# Setup
|
||||
stored_sandboxes = [
|
||||
create_stored_sandbox('sb1'),
|
||||
create_stored_sandbox('sb2'),
|
||||
]
|
||||
|
||||
mock_scalars = MagicMock()
|
||||
mock_scalars.all.return_value = stored_sandboxes
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value = mock_scalars
|
||||
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
|
||||
remote_sandbox_service._to_sandbox_info = AsyncMock(
|
||||
side_effect=lambda stored: SandboxInfo(
|
||||
id=stored.id,
|
||||
created_by_user_id=stored.created_by_user_id,
|
||||
sandbox_spec_id=stored.sandbox_spec_id,
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key='test-key',
|
||||
created_at=stored.created_at,
|
||||
)
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.search_sandboxes()
|
||||
|
||||
# Verify
|
||||
assert len(result.items) == 2
|
||||
assert result.next_page_id is None
|
||||
assert result.items[0].id == 'sb1'
|
||||
assert result.items[1].id == 'sb2'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_sandboxes_with_pagination(self, remote_sandbox_service):
|
||||
"""Test sandbox search with pagination."""
|
||||
# Setup - return limit + 1 items to trigger pagination
|
||||
stored_sandboxes = [
|
||||
create_stored_sandbox(f'sb{i}') for i in range(6)
|
||||
] # limit=5, so 6 items
|
||||
|
||||
mock_scalars = MagicMock()
|
||||
mock_scalars.all.return_value = stored_sandboxes
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value = mock_scalars
|
||||
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
|
||||
remote_sandbox_service._to_sandbox_info = AsyncMock(
|
||||
side_effect=lambda stored: SandboxInfo(
|
||||
id=stored.id,
|
||||
created_by_user_id=stored.created_by_user_id,
|
||||
sandbox_spec_id=stored.sandbox_spec_id,
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key='test-key',
|
||||
created_at=stored.created_at,
|
||||
)
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.search_sandboxes(limit=5)
|
||||
|
||||
# Verify
|
||||
assert len(result.items) == 5 # Should be limited to 5
|
||||
assert result.next_page_id == '5' # Next page offset
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_sandboxes_with_page_id(self, remote_sandbox_service):
|
||||
"""Test sandbox search with page_id offset."""
|
||||
# Setup
|
||||
stored_sandboxes = [create_stored_sandbox('sb1')]
|
||||
|
||||
mock_scalars = MagicMock()
|
||||
mock_scalars.all.return_value = stored_sandboxes
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value = mock_scalars
|
||||
remote_sandbox_service.db_session.execute = AsyncMock(return_value=mock_result)
|
||||
remote_sandbox_service._to_sandbox_info = AsyncMock(
|
||||
side_effect=lambda stored: SandboxInfo(
|
||||
id=stored.id,
|
||||
created_by_user_id=stored.created_by_user_id,
|
||||
sandbox_spec_id=stored.sandbox_spec_id,
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key='test-key',
|
||||
created_at=stored.created_at,
|
||||
)
|
||||
)
|
||||
|
||||
# Execute
|
||||
await remote_sandbox_service.search_sandboxes(page_id='10', limit=5)
|
||||
|
||||
# Verify that offset was applied to the query
|
||||
# Note: We can't easily verify the exact SQL query, but we can verify the method was called
|
||||
remote_sandbox_service.db_session.execute.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sandbox_exists(self, remote_sandbox_service):
|
||||
"""Test getting an existing sandbox."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(
|
||||
return_value=stored_sandbox
|
||||
)
|
||||
remote_sandbox_service._to_sandbox_info = AsyncMock(
|
||||
return_value=SandboxInfo(
|
||||
id='test-sandbox-123',
|
||||
created_by_user_id='test-user-123',
|
||||
sandbox_spec_id='test-image:latest',
|
||||
status=SandboxStatus.RUNNING,
|
||||
session_api_key='test-key',
|
||||
created_at=stored_sandbox.created_at,
|
||||
)
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.get_sandbox('test-sandbox-123')
|
||||
|
||||
# Verify
|
||||
assert result is not None
|
||||
assert result.id == 'test-sandbox-123'
|
||||
remote_sandbox_service._get_stored_sandbox.assert_called_once_with(
|
||||
'test-sandbox-123'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sandbox_not_exists(self, remote_sandbox_service):
|
||||
"""Test getting a non-existent sandbox."""
|
||||
# Setup
|
||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(return_value=None)
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.get_sandbox('non-existent')
|
||||
|
||||
# Verify
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUserSecurity:
|
||||
"""Test cases for user-scoped operations and security."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_with_user_id(self, remote_sandbox_service):
|
||||
"""Test that _secure_select filters by user ID."""
|
||||
# Setup
|
||||
remote_sandbox_service.user_context.get_user_id.return_value = 'test-user-123'
|
||||
|
||||
# Execute
|
||||
await remote_sandbox_service._secure_select()
|
||||
|
||||
# Verify
|
||||
# Note: We can't easily test the exact SQL query structure, but we can verify
|
||||
# that get_user_id was called, which means user filtering should be applied
|
||||
remote_sandbox_service.user_context.get_user_id.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_without_user_id(self, remote_sandbox_service):
|
||||
"""Test that _secure_select works when user ID is None."""
|
||||
# Setup
|
||||
remote_sandbox_service.user_context.get_user_id.return_value = None
|
||||
|
||||
# Execute
|
||||
await remote_sandbox_service._secure_select()
|
||||
|
||||
# Verify
|
||||
remote_sandbox_service.user_context.get_user_id.assert_called_once()
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test cases for error handling scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_sandbox_http_error(self, remote_sandbox_service):
|
||||
"""Test resume sandbox with HTTP error."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
runtime_data = create_runtime_data()
|
||||
|
||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(
|
||||
return_value=stored_sandbox
|
||||
)
|
||||
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
|
||||
remote_sandbox_service.pause_old_sandboxes = AsyncMock(return_value=[])
|
||||
remote_sandbox_service.httpx_client.request.side_effect = httpx.HTTPError(
|
||||
'API Error'
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.resume_sandbox('test-sandbox-123')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_sandbox_http_error(self, remote_sandbox_service):
|
||||
"""Test pause sandbox with HTTP error."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
runtime_data = create_runtime_data()
|
||||
|
||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(
|
||||
return_value=stored_sandbox
|
||||
)
|
||||
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
|
||||
remote_sandbox_service.httpx_client.request.side_effect = httpx.HTTPError(
|
||||
'API Error'
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.pause_sandbox('test-sandbox-123')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_sandbox_http_error(self, remote_sandbox_service):
|
||||
"""Test delete sandbox with HTTP error."""
|
||||
# Setup
|
||||
stored_sandbox = create_stored_sandbox()
|
||||
runtime_data = create_runtime_data()
|
||||
|
||||
remote_sandbox_service._get_stored_sandbox = AsyncMock(
|
||||
return_value=stored_sandbox
|
||||
)
|
||||
remote_sandbox_service._get_runtime = AsyncMock(return_value=runtime_data)
|
||||
remote_sandbox_service.db_session.delete = AsyncMock()
|
||||
remote_sandbox_service.db_session.commit = AsyncMock()
|
||||
remote_sandbox_service.httpx_client.request.side_effect = httpx.HTTPError(
|
||||
'API Error'
|
||||
)
|
||||
|
||||
# Execute
|
||||
result = await remote_sandbox_service.delete_sandbox('test-sandbox-123')
|
||||
|
||||
# Verify
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestUtilityFunctions:
|
||||
"""Test cases for utility functions."""
|
||||
|
||||
def test_build_service_url(self):
|
||||
"""Test _build_service_url function."""
|
||||
from openhands.app_server.sandbox.remote_sandbox_service import (
|
||||
_build_service_url,
|
||||
)
|
||||
|
||||
# Test HTTPS URL
|
||||
result = _build_service_url('https://sandbox.example.com/path', 'vscode')
|
||||
assert result == 'https://vscode-sandbox.example.com/path'
|
||||
|
||||
# Test HTTP URL
|
||||
result = _build_service_url('http://localhost:8000', 'work-1')
|
||||
assert result == 'http://work-1-localhost:8000'
|
||||
|
||||
|
||||
class TestConstants:
|
||||
"""Test cases for constants and mappings."""
|
||||
|
||||
def test_pod_status_mapping_completeness(self):
|
||||
"""Test that POD_STATUS_MAPPING covers expected statuses."""
|
||||
expected_statuses = [
|
||||
'ready',
|
||||
'pending',
|
||||
'running',
|
||||
'failed',
|
||||
'unknown',
|
||||
'crashloopbackoff',
|
||||
]
|
||||
for status in expected_statuses:
|
||||
assert status in POD_STATUS_MAPPING, f'Missing pod status: {status}'
|
||||
|
||||
def test_status_mapping_completeness(self):
|
||||
"""Test that STATUS_MAPPING covers expected statuses."""
|
||||
expected_statuses = ['running', 'paused', 'stopped', 'starting', 'error']
|
||||
for status in expected_statuses:
|
||||
assert status in STATUS_MAPPING, f'Missing status: {status}'
|
||||
|
||||
def test_environment_variable_constants(self):
|
||||
"""Test that environment variable constants are defined."""
|
||||
assert WEBHOOK_CALLBACK_VARIABLE == 'OH_WEBHOOKS_0_BASE_URL'
|
||||
assert ALLOW_CORS_ORIGINS_VARIABLE == 'OH_ALLOW_CORS_ORIGINS_0'
|
||||
Reference in New Issue
Block a user