mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Refactor of error handling (#4575)
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com> Co-authored-by: Xingyao Wang <xingyao@all-hands.dev> Co-authored-by: Xingyao Wang <xingyao6@illinois.edu>
This commit is contained in:
@@ -27,14 +27,14 @@ class E2BRuntime(Runtime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
sandbox: E2BSandbox | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
status_callback: Optional[Callable] = None,
|
||||
):
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
sid,
|
||||
plugins,
|
||||
status_message_callback=status_message_callback,
|
||||
status_callback=status_callback,
|
||||
)
|
||||
if sandbox is None:
|
||||
self.sandbox = E2BSandbox()
|
||||
|
||||
@@ -25,7 +25,7 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.observation import (
|
||||
FatalErrorObservation,
|
||||
ErrorObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
UserRejectObservation,
|
||||
@@ -36,8 +36,9 @@ from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.builder import DockerRuntimeBuilder
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils import find_available_tcp_port
|
||||
from openhands.runtime.utils.request import send_request_with_retry
|
||||
from openhands.runtime.utils.request import send_request
|
||||
from openhands.runtime.utils.runtime_build import build_runtime_image
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
@@ -123,7 +124,7 @@ class EventStreamRuntime(Runtime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_message_callback: Callable | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -132,7 +133,7 @@ class EventStreamRuntime(Runtime):
|
||||
sid,
|
||||
plugins,
|
||||
env_vars,
|
||||
status_message_callback,
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
)
|
||||
|
||||
@@ -143,7 +144,7 @@ class EventStreamRuntime(Runtime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_message_callback: Callable | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
):
|
||||
self.config = config
|
||||
@@ -151,7 +152,7 @@ class EventStreamRuntime(Runtime):
|
||||
self._container_port = 30001 # initial dummy value
|
||||
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
|
||||
self.session = requests.Session()
|
||||
self.status_message_callback = status_message_callback
|
||||
self.status_callback = status_callback
|
||||
|
||||
self.docker_client: docker.DockerClient = self._init_docker_client()
|
||||
self.base_container_image = self.config.sandbox.base_container_image
|
||||
@@ -181,7 +182,7 @@ class EventStreamRuntime(Runtime):
|
||||
sid,
|
||||
plugins,
|
||||
env_vars,
|
||||
status_message_callback,
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
)
|
||||
|
||||
@@ -205,21 +206,21 @@ class EventStreamRuntime(Runtime):
|
||||
self.log(
|
||||
'info', f'Starting runtime with image: {self.runtime_container_image}'
|
||||
)
|
||||
self._init_container()
|
||||
await call_sync_from_async(self._init_container)
|
||||
self.log('info', f'Container started: {self.container_name}')
|
||||
|
||||
else:
|
||||
self._attach_to_container()
|
||||
await call_sync_from_async(self._attach_to_container)
|
||||
|
||||
if not self.attach_to_existing:
|
||||
self.log('info', f'Waiting for client to become ready at {self.api_url}...')
|
||||
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
|
||||
self._wait_until_alive()
|
||||
await call_sync_from_async(self._wait_until_alive)
|
||||
if not self.attach_to_existing:
|
||||
self.log('info', 'Runtime is ready.')
|
||||
|
||||
if not self.attach_to_existing:
|
||||
self.setup_initial_env()
|
||||
await call_sync_from_async(self.setup_initial_env)
|
||||
|
||||
self.log(
|
||||
'debug',
|
||||
@@ -238,82 +239,74 @@ class EventStreamRuntime(Runtime):
|
||||
)
|
||||
raise ex
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(),
|
||||
wait=tenacity.wait_fixed(5),
|
||||
)
|
||||
def _init_container(self):
|
||||
try:
|
||||
self.log('debug', 'Preparing to start container...')
|
||||
self.send_status_message('STATUS$PREPARING_CONTAINER')
|
||||
plugin_arg = ''
|
||||
if self.plugins is not None and len(self.plugins) > 0:
|
||||
plugin_arg = (
|
||||
f'--plugins {" ".join([plugin.name for plugin in self.plugins])} '
|
||||
)
|
||||
|
||||
self._host_port = self._find_available_port()
|
||||
self._container_port = (
|
||||
self._host_port
|
||||
) # in future this might differ from host port
|
||||
self.api_url = (
|
||||
f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
|
||||
self.log('debug', 'Preparing to start container...')
|
||||
self.send_status_message('STATUS$PREPARING_CONTAINER')
|
||||
plugin_arg = ''
|
||||
if self.plugins is not None and len(self.plugins) > 0:
|
||||
plugin_arg = (
|
||||
f'--plugins {" ".join([plugin.name for plugin in self.plugins])} '
|
||||
)
|
||||
|
||||
use_host_network = self.config.sandbox.use_host_network
|
||||
network_mode: str | None = 'host' if use_host_network else None
|
||||
port_mapping: dict[str, list[dict[str, str]]] | None = (
|
||||
None
|
||||
if use_host_network
|
||||
else {
|
||||
f'{self._container_port}/tcp': [{'HostPort': str(self._host_port)}]
|
||||
}
|
||||
)
|
||||
self._host_port = self._find_available_port()
|
||||
self._container_port = (
|
||||
self._host_port
|
||||
) # in future this might differ from host port
|
||||
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
|
||||
|
||||
if use_host_network:
|
||||
self.log(
|
||||
'warn',
|
||||
'Using host network mode. If you are using MacOS, please make sure you have the latest version of Docker Desktop and enabled host network feature: https://docs.docker.com/network/drivers/host/#docker-desktop',
|
||||
)
|
||||
use_host_network = self.config.sandbox.use_host_network
|
||||
network_mode: str | None = 'host' if use_host_network else None
|
||||
port_mapping: dict[str, list[dict[str, str]]] | None = (
|
||||
None
|
||||
if use_host_network
|
||||
else {f'{self._container_port}/tcp': [{'HostPort': str(self._host_port)}]}
|
||||
)
|
||||
|
||||
# Combine environment variables
|
||||
environment = {
|
||||
'port': str(self._container_port),
|
||||
'PYTHONUNBUFFERED': 1,
|
||||
}
|
||||
if self.config.debug or DEBUG:
|
||||
environment['DEBUG'] = 'true'
|
||||
|
||||
self.log('debug', f'Workspace Base: {self.config.workspace_base}')
|
||||
if (
|
||||
self.config.workspace_mount_path is not None
|
||||
and self.config.workspace_mount_path_in_sandbox is not None
|
||||
):
|
||||
# e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}}
|
||||
volumes = {
|
||||
self.config.workspace_mount_path: {
|
||||
'bind': self.config.workspace_mount_path_in_sandbox,
|
||||
'mode': 'rw',
|
||||
}
|
||||
}
|
||||
logger.debug(f'Mount dir: {self.config.workspace_mount_path}')
|
||||
else:
|
||||
logger.debug(
|
||||
'Mount dir is not set, will not mount the workspace directory to the container'
|
||||
)
|
||||
volumes = None
|
||||
if use_host_network:
|
||||
self.log(
|
||||
'debug',
|
||||
f'Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}',
|
||||
'warn',
|
||||
'Using host network mode. If you are using MacOS, please make sure you have the latest version of Docker Desktop and enabled host network feature: https://docs.docker.com/network/drivers/host/#docker-desktop',
|
||||
)
|
||||
|
||||
if self.config.sandbox.browsergym_eval_env is not None:
|
||||
browsergym_arg = (
|
||||
f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
|
||||
)
|
||||
else:
|
||||
browsergym_arg = ''
|
||||
# Combine environment variables
|
||||
environment = {
|
||||
'port': str(self._container_port),
|
||||
'PYTHONUNBUFFERED': 1,
|
||||
}
|
||||
if self.config.debug or DEBUG:
|
||||
environment['DEBUG'] = 'true'
|
||||
|
||||
self.log('debug', f'Workspace Base: {self.config.workspace_base}')
|
||||
if (
|
||||
self.config.workspace_mount_path is not None
|
||||
and self.config.workspace_mount_path_in_sandbox is not None
|
||||
):
|
||||
# e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}}
|
||||
volumes = {
|
||||
self.config.workspace_mount_path: {
|
||||
'bind': self.config.workspace_mount_path_in_sandbox,
|
||||
'mode': 'rw',
|
||||
}
|
||||
}
|
||||
logger.debug(f'Mount dir: {self.config.workspace_mount_path}')
|
||||
else:
|
||||
logger.debug(
|
||||
'Mount dir is not set, will not mount the workspace directory to the container'
|
||||
)
|
||||
volumes = None
|
||||
self.log(
|
||||
'debug',
|
||||
f'Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}',
|
||||
)
|
||||
|
||||
if self.config.sandbox.browsergym_eval_env is not None:
|
||||
browsergym_arg = (
|
||||
f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}'
|
||||
)
|
||||
else:
|
||||
browsergym_arg = ''
|
||||
|
||||
try:
|
||||
self.container = self.docker_client.containers.run(
|
||||
self.runtime_container_image,
|
||||
command=(
|
||||
@@ -337,6 +330,21 @@ class EventStreamRuntime(Runtime):
|
||||
self.log_buffer = LogBuffer(self.container, self.log)
|
||||
self.log('debug', f'Container started. Server url: {self.api_url}')
|
||||
self.send_status_message('STATUS$CONTAINER_STARTED')
|
||||
except docker.errors.APIError as e:
|
||||
# check 409 error
|
||||
if '409' in str(e):
|
||||
self.log(
|
||||
'warning',
|
||||
f'Container {self.container_name} already exists. Removing...',
|
||||
)
|
||||
self._close_containers(rm_all_containers=True)
|
||||
return self._init_container()
|
||||
|
||||
else:
|
||||
self.log(
|
||||
'error',
|
||||
f'Error: Instance {self.container_name} FAILED to start container!\n',
|
||||
)
|
||||
except Exception as e:
|
||||
self.log(
|
||||
'error',
|
||||
@@ -384,27 +392,20 @@ class EventStreamRuntime(Runtime):
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_delay(120) | stop_if_should_exit(),
|
||||
wait=tenacity.wait_exponential(multiplier=2, min=1, max=20),
|
||||
reraise=(ConnectionRefusedError,),
|
||||
wait=tenacity.wait_fixed(2),
|
||||
)
|
||||
def _wait_until_alive(self):
|
||||
self._refresh_logs()
|
||||
if not self.log_buffer:
|
||||
raise RuntimeError('Runtime client is not ready.')
|
||||
|
||||
response = send_request_with_retry(
|
||||
send_request(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.api_url}/alive',
|
||||
retry_exceptions=[ConnectionRefusedError],
|
||||
timeout=300, # 5 minutes gives the container time to be alive 🧟♂️
|
||||
timeout=5,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
msg = f'Action execution API is not alive. Response: {response}'
|
||||
self.log('error', msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def close(self, rm_all_containers: bool = True):
|
||||
"""Closes the EventStreamRuntime and associated objects
|
||||
@@ -421,7 +422,9 @@ class EventStreamRuntime(Runtime):
|
||||
|
||||
if self.attach_to_existing:
|
||||
return
|
||||
self._close_containers(rm_all_containers)
|
||||
|
||||
def _close_containers(self, rm_all_containers: bool = True):
|
||||
try:
|
||||
containers = self.docker_client.containers.list(all=True)
|
||||
for container in containers:
|
||||
@@ -466,10 +469,11 @@ class EventStreamRuntime(Runtime):
|
||||
return NullObservation('')
|
||||
action_type = action.action # type: ignore[attr-defined]
|
||||
if action_type not in ACTION_TYPE_TO_CLASS:
|
||||
return FatalErrorObservation(f'Action {action_type} does not exist.')
|
||||
raise ValueError(f'Action {action_type} does not exist.')
|
||||
if not hasattr(self, action_type):
|
||||
return FatalErrorObservation(
|
||||
f'Action {action_type} is not supported in the current runtime.'
|
||||
return ErrorObservation(
|
||||
f'Action {action_type} is not supported in the current runtime.',
|
||||
error_id='AGENT_ERROR$BAD_ACTION',
|
||||
)
|
||||
if (
|
||||
getattr(action, 'confirmation_state', None)
|
||||
@@ -484,33 +488,21 @@ class EventStreamRuntime(Runtime):
|
||||
assert action.timeout is not None
|
||||
|
||||
try:
|
||||
response = send_request_with_retry(
|
||||
response = send_request(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.api_url}/execute_action',
|
||||
json={'action': event_to_dict(action)},
|
||||
timeout=action.timeout,
|
||||
# wait a few more seconds to get the timeout error from client side
|
||||
timeout=action.timeout + 5,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
output = response.json()
|
||||
obs = observation_from_dict(output)
|
||||
obs._cause = action.id # type: ignore[attr-defined]
|
||||
else:
|
||||
self.log('debug', f'action: {action}')
|
||||
self.log('debug', f'response: {response}')
|
||||
error_message = response.text
|
||||
self.log('error', f'Error from server: {error_message}')
|
||||
obs = FatalErrorObservation(
|
||||
f'Action execution failed: {error_message}'
|
||||
)
|
||||
output = response.json()
|
||||
obs = observation_from_dict(output)
|
||||
obs._cause = action.id # type: ignore[attr-defined]
|
||||
except requests.Timeout:
|
||||
self.log('error', 'No response received within the timeout period.')
|
||||
obs = FatalErrorObservation(
|
||||
f'Action execution timed out after {action.timeout} seconds.'
|
||||
raise RuntimeError(
|
||||
f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
|
||||
)
|
||||
except Exception as e:
|
||||
self.log('error', f'Error during action execution: {e}')
|
||||
obs = FatalErrorObservation(f'Action execution failed: {str(e)}')
|
||||
self._refresh_logs()
|
||||
return obs
|
||||
|
||||
@@ -567,7 +559,7 @@ class EventStreamRuntime(Runtime):
|
||||
|
||||
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
|
||||
|
||||
response = send_request_with_retry(
|
||||
send_request(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.api_url}/upload_file',
|
||||
@@ -575,11 +567,6 @@ class EventStreamRuntime(Runtime):
|
||||
params=params,
|
||||
timeout=300,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(f'Copy operation failed: {error_message}')
|
||||
|
||||
except requests.Timeout:
|
||||
raise TimeoutError('Copy operation timed out')
|
||||
@@ -604,31 +591,25 @@ class EventStreamRuntime(Runtime):
|
||||
if path is not None:
|
||||
data['path'] = path
|
||||
|
||||
response = send_request_with_retry(
|
||||
response = send_request(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.api_url}/list_files',
|
||||
json=data,
|
||||
timeout=30, # 30 seconds because the container should already be alive
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
response_json = response.json()
|
||||
assert isinstance(response_json, list)
|
||||
return response_json
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(f'List files operation failed: {error_message}')
|
||||
response_json = response.json()
|
||||
assert isinstance(response_json, list)
|
||||
return response_json
|
||||
except requests.Timeout:
|
||||
raise TimeoutError('List files operation timed out')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'List files operation failed: {str(e)}')
|
||||
|
||||
def copy_from(self, path: str) -> bytes:
|
||||
"""Zip all files in the sandbox and return as a stream of bytes."""
|
||||
self._refresh_logs()
|
||||
try:
|
||||
params = {'path': path}
|
||||
response = send_request_with_retry(
|
||||
response = send_request(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.api_url}/download_files',
|
||||
@@ -636,16 +617,10 @@ class EventStreamRuntime(Runtime):
|
||||
stream=True,
|
||||
timeout=30,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
data = response.content
|
||||
return data
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(f'Copy operation failed: {error_message}')
|
||||
data = response.content
|
||||
return data
|
||||
except requests.Timeout:
|
||||
raise TimeoutError('Copy operation timed out')
|
||||
except Exception as e:
|
||||
raise RuntimeError(f'Copy operation failed: {str(e)}')
|
||||
|
||||
def _is_port_in_use_docker(self, port):
|
||||
containers = self.docker_client.containers.list()
|
||||
@@ -663,8 +638,3 @@ class EventStreamRuntime(Runtime):
|
||||
return port
|
||||
# If no port is found after max_attempts, return the last tried port
|
||||
return port
|
||||
|
||||
def send_status_message(self, message: str):
|
||||
"""Sends a status message if the callback function was provided."""
|
||||
if self.status_message_callback:
|
||||
self.status_message_callback(message)
|
||||
|
||||
@@ -75,7 +75,7 @@ class ModalRuntime(EventStreamRuntime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_message_callback: Callable | None = None,
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
):
|
||||
assert config.modal_api_token_id, 'Modal API token id is required'
|
||||
@@ -102,7 +102,7 @@ class ModalRuntime(EventStreamRuntime):
|
||||
self.container_port = 3000
|
||||
|
||||
self.session = requests.Session()
|
||||
self.status_message_callback = status_message_callback
|
||||
self.status_callback = status_callback
|
||||
self.base_container_image_id = self.config.sandbox.base_container_image
|
||||
self.runtime_container_image_id = self.config.sandbox.runtime_container_image
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
@@ -122,7 +122,7 @@ class ModalRuntime(EventStreamRuntime):
|
||||
sid,
|
||||
plugins,
|
||||
env_vars,
|
||||
status_message_callback,
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from zipfile import ZipFile
|
||||
|
||||
import requests
|
||||
from requests.exceptions import Timeout
|
||||
import tenacity
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.events import EventStream
|
||||
@@ -21,22 +20,26 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.observation import (
|
||||
FatalErrorObservation,
|
||||
ErrorObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict, observation_from_dict
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.base import (
|
||||
Runtime,
|
||||
RuntimeDisconnectedError,
|
||||
RuntimeNotReadyError,
|
||||
)
|
||||
from openhands.runtime.builder.remote import RemoteRuntimeBuilder
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils.command import get_remote_startup_command
|
||||
from openhands.runtime.utils.request import (
|
||||
is_404_error,
|
||||
is_503_error,
|
||||
send_request_with_retry,
|
||||
send_request,
|
||||
)
|
||||
from openhands.runtime.utils.runtime_build import build_runtime_image
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
class RemoteRuntime(Runtime):
|
||||
@@ -51,31 +54,32 @@ class RemoteRuntime(Runtime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
status_callback: Optional[Callable] = None,
|
||||
attach_to_existing: bool = False,
|
||||
):
|
||||
# We need to set session and action_semaphore before the __init__ below, or we get odd errors
|
||||
self.session = requests.Session()
|
||||
self.action_semaphore = threading.Semaphore(1)
|
||||
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
sid,
|
||||
plugins,
|
||||
env_vars,
|
||||
status_message_callback,
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
)
|
||||
|
||||
if self.config.sandbox.api_key is None:
|
||||
raise ValueError(
|
||||
'API key is required to use the remote runtime. '
|
||||
'Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY).'
|
||||
)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({'X-API-Key': self.config.sandbox.api_key})
|
||||
self.action_semaphore = threading.Semaphore(1)
|
||||
|
||||
if self.config.workspace_base is not None:
|
||||
self.log(
|
||||
'warning',
|
||||
'debug',
|
||||
'Setting workspace_base is not supported in the remote runtime.',
|
||||
)
|
||||
|
||||
@@ -86,9 +90,13 @@ class RemoteRuntime(Runtime):
|
||||
self.runtime_url: str | None = None
|
||||
|
||||
async def connect(self):
|
||||
self._start_or_attach_to_runtime()
|
||||
self._wait_until_alive()
|
||||
self.setup_initial_env()
|
||||
await call_sync_from_async(self._start_or_attach_to_runtime)
|
||||
try:
|
||||
await call_sync_from_async(self._wait_until_alive)
|
||||
except RuntimeNotReadyError:
|
||||
self.log('error', 'Runtime failed to start, timed out before ready')
|
||||
raise
|
||||
await call_sync_from_async(self.setup_initial_env)
|
||||
|
||||
def _start_or_attach_to_runtime(self):
|
||||
existing_runtime = self._check_existing_runtime()
|
||||
@@ -127,44 +135,40 @@ class RemoteRuntime(Runtime):
|
||||
|
||||
def _check_existing_runtime(self) -> bool:
|
||||
try:
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'GET',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.sid}',
|
||||
timeout=5,
|
||||
)
|
||||
except Exception as e:
|
||||
except requests.HTTPError as e:
|
||||
if e.response.status_code == 404:
|
||||
return False
|
||||
self.log('debug', f'Error while looking for remote runtime: {e}')
|
||||
return False
|
||||
raise
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
status = data.get('status')
|
||||
if status == 'running':
|
||||
self._parse_runtime_response(response)
|
||||
return True
|
||||
elif status == 'stopped':
|
||||
self.log('debug', 'Found existing remote runtime, but it is stopped')
|
||||
return False
|
||||
elif status == 'paused':
|
||||
self.log('debug', 'Found existing remote runtime, but it is paused')
|
||||
self._parse_runtime_response(response)
|
||||
self._resume_runtime()
|
||||
return True
|
||||
else:
|
||||
self.log('error', f'Invalid response from runtime API: {data}')
|
||||
return False
|
||||
data = response.json()
|
||||
status = data.get('status')
|
||||
if status == 'running':
|
||||
self._parse_runtime_response(response)
|
||||
return True
|
||||
elif status == 'stopped':
|
||||
self.log('debug', 'Found existing remote runtime, but it is stopped')
|
||||
return False
|
||||
elif status == 'paused':
|
||||
self.log('debug', 'Found existing remote runtime, but it is paused')
|
||||
self._parse_runtime_response(response)
|
||||
self._resume_runtime()
|
||||
return True
|
||||
else:
|
||||
self.log('debug', 'Could not find existing remote runtime')
|
||||
self.log('error', f'Invalid response from runtime API: {data}')
|
||||
return False
|
||||
|
||||
def _build_runtime(self):
|
||||
self.log('debug', f'Building RemoteRuntime config:\n{self.config}')
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'GET',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/registry_prefix',
|
||||
timeout=30,
|
||||
timeout=10,
|
||||
)
|
||||
response_json = response.json()
|
||||
registry_prefix = response_json['registry_prefix']
|
||||
@@ -191,14 +195,13 @@ class RemoteRuntime(Runtime):
|
||||
force_rebuild=self.config.sandbox.force_rebuild_runtime,
|
||||
)
|
||||
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'GET',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/image_exists',
|
||||
params={'image': self.container_image},
|
||||
timeout=30,
|
||||
timeout=10,
|
||||
)
|
||||
if response.status_code != 200 or not response.json()['exists']:
|
||||
if not response.json()['exists']:
|
||||
raise RuntimeError(f'Container image {self.container_image} does not exist')
|
||||
|
||||
def _start_runtime(self):
|
||||
@@ -228,17 +231,11 @@ class RemoteRuntime(Runtime):
|
||||
}
|
||||
|
||||
# Start the sandbox using the /start endpoint
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/start',
|
||||
json=start_request,
|
||||
timeout=300,
|
||||
)
|
||||
if response.status_code != 201:
|
||||
raise RuntimeError(
|
||||
f'[Runtime (ID={self.runtime_id})] Failed to start runtime: {response.text}'
|
||||
)
|
||||
self._parse_runtime_response(response)
|
||||
self.log(
|
||||
'debug',
|
||||
@@ -246,17 +243,12 @@ class RemoteRuntime(Runtime):
|
||||
)
|
||||
|
||||
def _resume_runtime(self):
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
self._send_request(
|
||||
'POST',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/resume',
|
||||
json={'runtime_id': self.runtime_id},
|
||||
timeout=30,
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f'[Runtime (ID={self.runtime_id})] Failed to resume runtime: {response.text}'
|
||||
)
|
||||
self.log('debug', 'Runtime resumed.')
|
||||
|
||||
def _parse_runtime_response(self, response: requests.Response):
|
||||
@@ -268,72 +260,57 @@ class RemoteRuntime(Runtime):
|
||||
{'X-Session-API-Key': start_response['session_api_key']}
|
||||
)
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_delay(180) | stop_if_should_exit(),
|
||||
reraise=True,
|
||||
retry=tenacity.retry_if_exception_type(RuntimeNotReadyError),
|
||||
wait=tenacity.wait_fixed(2),
|
||||
)
|
||||
def _wait_until_alive(self):
|
||||
self.log('debug', f'Waiting for runtime to be alive at url: {self.runtime_url}')
|
||||
# send GET request to /runtime/<id>
|
||||
pod_running = False
|
||||
max_not_found_count = 12 # 2 minutes
|
||||
not_found_count = 0
|
||||
while not pod_running:
|
||||
runtime_info_response = send_request_with_retry(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.runtime_id}',
|
||||
timeout=5,
|
||||
)
|
||||
if runtime_info_response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f'Failed to get runtime status: {runtime_info_response.status_code}. Response: {runtime_info_response.text}'
|
||||
)
|
||||
runtime_data = runtime_info_response.json()
|
||||
assert runtime_data['runtime_id'] == self.runtime_id
|
||||
pod_status = runtime_data['pod_status']
|
||||
self.log(
|
||||
'debug',
|
||||
f'Waiting for runtime pod to be active. Current status: {pod_status}',
|
||||
)
|
||||
if pod_status == 'Ready':
|
||||
pod_running = True
|
||||
break
|
||||
elif pod_status == 'Not Found' and not_found_count < max_not_found_count:
|
||||
not_found_count += 1
|
||||
self.log(
|
||||
'debug',
|
||||
f'Runtime pod not found. Count: {not_found_count} / {max_not_found_count}',
|
||||
)
|
||||
elif pod_status in ('Failed', 'Unknown', 'Not Found'):
|
||||
# clean up the runtime
|
||||
self.close()
|
||||
raise RuntimeError(
|
||||
f'Runtime (ID={self.runtime_id}) failed to start. Current status: {pod_status}'
|
||||
)
|
||||
# Pending otherwise - add proper sleep
|
||||
time.sleep(10)
|
||||
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
runtime_info_response = self._send_request(
|
||||
'GET',
|
||||
f'{self.runtime_url}/alive',
|
||||
# Retry 404 & 503 errors for the /alive endpoint
|
||||
# because the runtime might just be starting up
|
||||
# and have not registered the endpoint yet
|
||||
retry_fns=[is_404_error, is_503_error],
|
||||
# leave enough time for the runtime to start up
|
||||
timeout=600,
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.runtime_id}',
|
||||
)
|
||||
if response.status_code != 200:
|
||||
msg = f'Runtime (ID={self.runtime_id}) is not alive yet. Status: {response.status_code}.'
|
||||
self.log('warning', msg)
|
||||
raise RuntimeError(msg)
|
||||
runtime_data = runtime_info_response.json()
|
||||
assert 'runtime_id' in runtime_data
|
||||
assert runtime_data['runtime_id'] == self.runtime_id
|
||||
assert 'pod_status' in runtime_data
|
||||
pod_status = runtime_data['pod_status']
|
||||
if pod_status == 'Ready':
|
||||
try:
|
||||
self._send_request(
|
||||
'GET',
|
||||
f'{self.runtime_url}/alive',
|
||||
) # will raise exception if we don't get 200 back.
|
||||
except requests.HTTPError as e:
|
||||
self.log(
|
||||
'warning', f"Runtime /alive failed, but pod says it's ready: {e}"
|
||||
)
|
||||
raise RuntimeNotReadyError(
|
||||
f'Runtime /alive failed to respond with 200: {e}'
|
||||
)
|
||||
return
|
||||
if pod_status in ('Failed', 'Unknown', 'Not Found'):
|
||||
# clean up the runtime
|
||||
self.close()
|
||||
raise RuntimeError(
|
||||
f'Runtime (ID={self.runtime_id}) failed to start. Current status: {pod_status}'
|
||||
)
|
||||
|
||||
self.log(
|
||||
'debug',
|
||||
f'Waiting for runtime pod to be active. Current status: {pod_status}',
|
||||
)
|
||||
raise RuntimeNotReadyError()
|
||||
|
||||
def close(self, timeout: int = 10):
|
||||
if self.config.sandbox.keep_remote_runtime_alive or self.attach_to_existing:
|
||||
self.session.close()
|
||||
return
|
||||
if self.runtime_id:
|
||||
if self.runtime_id and self.session:
|
||||
try:
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.config.sandbox.remote_runtime_api_url}/stop',
|
||||
json={'runtime_id': self.runtime_id},
|
||||
@@ -361,12 +338,11 @@ class RemoteRuntime(Runtime):
|
||||
return NullObservation('')
|
||||
action_type = action.action # type: ignore[attr-defined]
|
||||
if action_type not in ACTION_TYPE_TO_CLASS:
|
||||
return FatalErrorObservation(
|
||||
f'[Runtime (ID={self.runtime_id})] Action {action_type} does not exist.'
|
||||
)
|
||||
raise ValueError(f'Action {action_type} does not exist.')
|
||||
if not hasattr(self, action_type):
|
||||
return FatalErrorObservation(
|
||||
f'[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.'
|
||||
return ErrorObservation(
|
||||
f'[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.',
|
||||
error_id='AGENT_ERROR$BAD_ACTION',
|
||||
)
|
||||
|
||||
assert action.timeout is not None
|
||||
@@ -374,36 +350,37 @@ class RemoteRuntime(Runtime):
|
||||
try:
|
||||
request_body = {'action': event_to_dict(action)}
|
||||
self.log('debug', f'Request body: {request_body}')
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.runtime_url}/execute_action',
|
||||
json=request_body,
|
||||
timeout=action.timeout,
|
||||
# wait a few more seconds to get the timeout error from client side
|
||||
timeout=action.timeout + 5,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
output = response.json()
|
||||
obs = observation_from_dict(output)
|
||||
obs._cause = action.id # type: ignore[attr-defined]
|
||||
return obs
|
||||
else:
|
||||
error_message = response.text
|
||||
self.log('error', f'Error from server: {error_message}')
|
||||
obs = FatalErrorObservation(
|
||||
f'Action execution failed: {error_message}'
|
||||
)
|
||||
except Timeout:
|
||||
self.log('error', 'No response received within the timeout period.')
|
||||
obs = FatalErrorObservation(
|
||||
f'[Runtime (ID={self.runtime_id})] Action execution timed out'
|
||||
)
|
||||
except Exception as e:
|
||||
self.log('error', f'Error during action execution: {e}')
|
||||
obs = FatalErrorObservation(
|
||||
f'[Runtime (ID={self.runtime_id})] Action execution failed: {str(e)}'
|
||||
output = response.json()
|
||||
obs = observation_from_dict(output)
|
||||
obs._cause = action.id # type: ignore[attr-defined]
|
||||
except requests.Timeout:
|
||||
raise RuntimeError(
|
||||
f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s'
|
||||
)
|
||||
return obs
|
||||
|
||||
def _send_request(self, method, url, **kwargs):
|
||||
is_runtime_request = self.runtime_url and self.runtime_url in url
|
||||
try:
|
||||
return send_request(self.session, method, url, **kwargs)
|
||||
except requests.Timeout:
|
||||
self.log('error', 'No response received within the timeout period.')
|
||||
raise
|
||||
except requests.HTTPError as e:
|
||||
if is_runtime_request and e.response.status_code == 404:
|
||||
raise RuntimeDisconnectedError(
|
||||
f'404 error while connecting to {self.runtime_url}'
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
|
||||
def run(self, action: CmdRunAction) -> Observation:
|
||||
return self.run_action(action)
|
||||
|
||||
@@ -450,32 +427,16 @@ class RemoteRuntime(Runtime):
|
||||
|
||||
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
|
||||
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.runtime_url}/upload_file',
|
||||
files=upload_data,
|
||||
params=params,
|
||||
timeout=300,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
self.log(
|
||||
'debug',
|
||||
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
|
||||
)
|
||||
return
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(
|
||||
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {error_message}'
|
||||
)
|
||||
except TimeoutError:
|
||||
raise TimeoutError(
|
||||
f'[Runtime (ID={self.runtime_id})] Copy operation timed out'
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {str(e)}'
|
||||
self.log(
|
||||
'debug',
|
||||
f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}',
|
||||
)
|
||||
finally:
|
||||
if recursive:
|
||||
@@ -485,64 +446,27 @@ class RemoteRuntime(Runtime):
|
||||
)
|
||||
|
||||
def list_files(self, path: str | None = None) -> list[str]:
|
||||
try:
|
||||
data = {}
|
||||
if path is not None:
|
||||
data['path'] = path
|
||||
data = {}
|
||||
if path is not None:
|
||||
data['path'] = path
|
||||
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.runtime_url}/list_files',
|
||||
json=data,
|
||||
timeout=30,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
response_json = response.json()
|
||||
assert isinstance(response_json, list)
|
||||
return response_json
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(
|
||||
f'[Runtime (ID={self.runtime_id})] List files operation failed: {error_message}'
|
||||
)
|
||||
except TimeoutError:
|
||||
raise TimeoutError(
|
||||
f'[Runtime (ID={self.runtime_id})] List files operation timed out'
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f'[Runtime (ID={self.runtime_id})] List files operation failed: {str(e)}'
|
||||
)
|
||||
response = self._send_request(
|
||||
'POST',
|
||||
f'{self.runtime_url}/list_files',
|
||||
json=data,
|
||||
timeout=30,
|
||||
)
|
||||
response_json = response.json()
|
||||
assert isinstance(response_json, list)
|
||||
return response_json
|
||||
|
||||
def copy_from(self, path: str) -> bytes:
|
||||
"""Zip all files in the sandbox and return as a stream of bytes."""
|
||||
try:
|
||||
params = {'path': path}
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.runtime_url}/download_files',
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.content
|
||||
else:
|
||||
error_message = response.text
|
||||
raise Exception(
|
||||
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {error_message}'
|
||||
)
|
||||
except requests.Timeout:
|
||||
raise TimeoutError(
|
||||
f'[Runtime (ID={self.runtime_id})] Copy operation timed out'
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f'[Runtime (ID={self.runtime_id})] Copy operation failed: {str(e)}'
|
||||
)
|
||||
|
||||
def send_status_message(self, message: str):
|
||||
"""Sends a status message if the callback function was provided."""
|
||||
if self.status_message_callback:
|
||||
self.status_message_callback(message)
|
||||
params = {'path': path}
|
||||
response = self._send_request(
|
||||
'GET',
|
||||
f'{self.runtime_url}/download_files',
|
||||
params=params,
|
||||
timeout=30,
|
||||
)
|
||||
return response.content
|
||||
|
||||
Reference in New Issue
Block a user