From 45fb4fb9bc70c1cc904f674420950be286a70db0 Mon Sep 17 00:00:00 2001 From: Robert Brennan Date: Wed, 9 Oct 2024 12:37:52 -0400 Subject: [PATCH] allow reconnecting to a runtime (#4223) --- evaluation/swe_bench/run_infer.py | 1 + frontend/src/components/AgentStatusBar.tsx | 12 +- openhands/core/config/sandbox_config.py | 1 + openhands/runtime/remote/runtime.py | 213 ++++++++++++--------- openhands/server/session/agent_session.py | 2 +- 5 files changed, 137 insertions(+), 92 deletions(-) diff --git a/evaluation/swe_bench/run_infer.py b/evaluation/swe_bench/run_infer.py index 4d65500cb4..835f80ef55 100644 --- a/evaluation/swe_bench/run_infer.py +++ b/evaluation/swe_bench/run_infer.py @@ -133,6 +133,7 @@ def get_config( timeout=300, api_key=os.environ.get('ALLHANDS_API_KEY', None), remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'), + keep_remote_runtime_alive=False, ), # do not mount workspace workspace_base=None, diff --git a/frontend/src/components/AgentStatusBar.tsx b/frontend/src/components/AgentStatusBar.tsx index 516b87dca1..c337a838f3 100644 --- a/frontend/src/components/AgentStatusBar.tsx +++ b/frontend/src/components/AgentStatusBar.tsx @@ -94,12 +94,14 @@ function AgentStatusBar() { const [statusMessage, setStatusMessage] = React.useState(""); React.useEffect(() => { - const trimmedCustomMessage = curStatusMessage.status.trim(); - if (trimmedCustomMessage) { - setStatusMessage(t(trimmedCustomMessage)); - } else { - setStatusMessage(AgentStatusMap[curAgentState].message); + if (curAgentState === AgentState.LOADING) { + const trimmedCustomMessage = curStatusMessage.status.trim(); + if (trimmedCustomMessage) { + setStatusMessage(t(trimmedCustomMessage)); + return; + } } + setStatusMessage(AgentStatusMap[curAgentState].message); }, [curAgentState, curStatusMessage.status]); return ( diff --git a/openhands/core/config/sandbox_config.py b/openhands/core/config/sandbox_config.py index 774ff078c9..43f5773052 100644 --- a/openhands/core/config/sandbox_config.py +++ b/openhands/core/config/sandbox_config.py @@ -34,6 +34,7 @@ class SandboxConfig: remote_runtime_api_url: str = 'http://localhost:8000' local_runtime_url: str = 'http://localhost' + keep_remote_runtime_alive: bool = True api_key: str | None = None base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime runtime_container_image: str | None = None diff --git a/openhands/runtime/remote/runtime.py b/openhands/runtime/remote/runtime.py index 072968e41c..5e2c35af8d 100644 --- a/openhands/runtime/remote/runtime.py +++ b/openhands/runtime/remote/runtime.py @@ -1,7 +1,6 @@ import os import tempfile import threading -import uuid from typing import Callable, Optional from zipfile import ZipFile @@ -11,7 +10,7 @@ from tenacity import ( retry, retry_if_exception_type, stop_after_attempt, - wait_exponential, + wait_fixed, ) from openhands.core.config import AppConfig @@ -60,14 +59,13 @@ class RemoteRuntime(Runtime): status_message_callback: Optional[Callable] = None, ): self.config = config + self.status_message_callback = status_message_callback if self.config.sandbox.api_key is None: raise ValueError( 'API key is required to use the remote runtime. ' 'Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY).' ) - self.status_message_callback = status_message_callback - self.send_status_message('STATUS$STARTING_RUNTIME') self.session = requests.Session() self.session.headers.update({'X-API-Key': self.config.sandbox.api_key}) self.action_semaphore = threading.Semaphore(1) @@ -83,61 +81,116 @@ class RemoteRuntime(Runtime): self.runtime_id: str | None = None self.runtime_url: str | None = None - self.instance_id = ( - sid + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4()) + self.instance_id = sid + + self._start_or_attach_to_runtime(plugins) + + # Initialize the eventstream and env vars + super().__init__( + config, event_stream, sid, plugins, env_vars, status_message_callback ) - self.container_name = 'oh-remote-runtime-' + self.instance_id - if self.config.sandbox.runtime_container_image is not None: - logger.info( - f'Running remote runtime with image: {self.config.sandbox.runtime_container_image}' - ) - self.container_image = self.config.sandbox.runtime_container_image + self._wait_until_alive() + self.setup_initial_env() + + def _start_or_attach_to_runtime(self, plugins: list[PluginRequirement] | None): + existing_runtime = self._check_existing_runtime() + if existing_runtime: + logger.info(f'Using existing runtime with ID: {self.runtime_id}') else: - logger.info( - f'Building remote runtime with base image: {self.config.sandbox.base_container_image}' - ) - logger.debug(f'RemoteRuntime `{sid}` config:\n{self.config}') - response = send_request_with_retry( - self.session, - 'GET', - f'{self.config.sandbox.remote_runtime_api_url}/registry_prefix', - timeout=30, - ) - response_json = response.json() - registry_prefix = response_json['registry_prefix'] - os.environ['OH_RUNTIME_RUNTIME_IMAGE_REPO'] = ( - registry_prefix.rstrip('/') + '/runtime' - ) - logger.info( - f'Runtime image repo: {os.environ["OH_RUNTIME_RUNTIME_IMAGE_REPO"]}' - ) - - if self.config.sandbox.runtime_extra_deps: - logger.info( - f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}' - ) - - # Build the container image self.send_status_message('STATUS$STARTING_CONTAINER') - self.container_image = build_runtime_image( - self.config.sandbox.base_container_image, - self.runtime_builder, - extra_deps=self.config.sandbox.runtime_extra_deps, - force_rebuild=self.config.sandbox.force_rebuild_runtime, - ) + if self.config.sandbox.runtime_container_image is None: + logger.info( + f'Building remote runtime with base image: {self.config.sandbox.base_container_image}' + ) + self._build_runtime() + else: + logger.info( + f'Running remote runtime with image: {self.config.sandbox.runtime_container_image}' + ) + self._start_runtime(plugins) + assert ( + self.runtime_id is not None + ), 'Runtime ID is not set. This should never happen.' + assert ( + self.runtime_url is not None + ), 'Runtime URL is not set. This should never happen.' + self.send_status_message('STATUS$WAITING_FOR_CLIENT') + self._wait_until_alive() + def _check_existing_runtime(self) -> bool: + try: response = send_request_with_retry( self.session, 'GET', - f'{self.config.sandbox.remote_runtime_api_url}/image_exists', - params={'image': self.container_image}, - timeout=30, + f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.instance_id}', + timeout=5, ) - if response.status_code != 200 or not response.json()['exists']: - raise RuntimeError( - f'Container image {self.container_image} does not exist' - ) + except Exception as e: + logger.error(f'Error while looking for remote runtime: {e}') + return False + if response.status_code == 200: + data = response.json() + status = data.get('status') + if status == 'running': + self._parse_runtime_response(response) + return True + elif status == 'stopped': + logger.info('Found existing remote runtime, but it is stopped') + return False + elif status == 'paused': + logger.info('Found existing remote runtime, but it is paused') + self._parse_runtime_response(response) + self._resume_runtime() + return True + else: + logger.error(f'Invalid response from runtime API: {data}') + return False + else: + logger.info('Could not find existing remote runtime') + return False + + def _build_runtime(self): + logger.debug(f'RemoteRuntime `{self.instance_id}` config:\n{self.config}') + response = send_request_with_retry( + self.session, + 'GET', + f'{self.config.sandbox.remote_runtime_api_url}/registry_prefix', + timeout=30, + ) + response_json = response.json() + registry_prefix = response_json['registry_prefix'] + os.environ['OH_RUNTIME_RUNTIME_IMAGE_REPO'] = ( + registry_prefix.rstrip('/') + '/runtime' + ) + logger.info( + f'Runtime image repo: {os.environ["OH_RUNTIME_RUNTIME_IMAGE_REPO"]}' + ) + + if self.config.sandbox.runtime_extra_deps: + logger.info( + f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}' + ) + + # Build the container image + self.container_image = build_runtime_image( + self.config.sandbox.base_container_image, + self.runtime_builder, + extra_deps=self.config.sandbox.runtime_extra_deps, + force_rebuild=self.config.sandbox.force_rebuild_runtime, + ) + + response = send_request_with_retry( + self.session, + 'GET', + f'{self.config.sandbox.remote_runtime_api_url}/image_exists', + params={'image': self.container_image}, + timeout=30, + ) + if response.status_code != 200 or not response.json()['exists']: + raise RuntimeError(f'Container image {self.container_image} does not exist') + + def _start_runtime(self, plugins: list[PluginRequirement] | None): # Prepare the request body for the /start endpoint plugin_arg = '' if plugins is not None and len(plugins) > 0: @@ -160,11 +213,10 @@ class RemoteRuntime(Runtime): f'{browsergym_arg}' ), 'working_dir': '/openhands/code/', - 'name': self.container_name, 'environment': {'DEBUG': 'true'} if self.config.debug else {}, + 'runtime_id': self.instance_id, } - self.send_status_message('STATUS$WAITING_FOR_CLIENT') # Start the sandbox using the /start endpoint response = send_request_with_retry( self.session, @@ -175,45 +227,35 @@ class RemoteRuntime(Runtime): ) if response.status_code != 201: raise RuntimeError(f'Failed to start sandbox: {response.text}') - start_response = response.json() - self.runtime_id = start_response['runtime_id'] - self.runtime_url = start_response['url'] - + self._parse_runtime_response(response) logger.info( f'Sandbox started. Runtime ID: {self.runtime_id}, URL: {self.runtime_url}' ) + def _resume_runtime(self): + response = send_request_with_retry( + self.session, + 'POST', + f'{self.config.sandbox.remote_runtime_api_url}/resume', + json={'runtime_id': self.runtime_id}, + timeout=30, + ) + if response.status_code != 200: + raise RuntimeError(f'Failed to resume sandbox: {response.text}') + logger.info(f'Sandbox resumed. Runtime ID: {self.runtime_id}') + + def _parse_runtime_response(self, response: requests.Response): + start_response = response.json() + self.runtime_id = start_response['runtime_id'] + self.runtime_url = start_response['url'] if 'session_api_key' in start_response: self.session.headers.update( {'X-Session-API-Key': start_response['session_api_key']} ) - # Initialize the eventstream and env vars - super().__init__( - config, event_stream, sid, plugins, env_vars, status_message_callback - ) - - logger.info( - f'Runtime initialized with plugins: {[plugin.name for plugin in self.plugins]}' - ) - logger.info(f'Runtime initialized with env vars: {env_vars}') - assert ( - self.runtime_id is not None - ), 'Runtime ID is not set. This should never happen.' - assert ( - self.runtime_url is not None - ), 'Runtime URL is not set. This should never happen.' - - self._wait_until_alive() - - self.send_status_message(' ') - - self._wait_until_alive() - self.setup_initial_env() - @retry( - stop=stop_after_attempt(10) | stop_if_should_exit(), - wait=wait_exponential(multiplier=1, min=4, max=60), + stop=stop_after_attempt(60) | stop_if_should_exit(), + wait=wait_fixed(2), retry=retry_if_exception_type(RuntimeError), reraise=True, ) @@ -236,6 +278,9 @@ class RemoteRuntime(Runtime): raise RuntimeError(msg) def close(self, timeout: int = 10): + if self.config.sandbox.keep_remote_runtime_alive: + self.session.close() + return if self.runtime_id: try: response = send_request_with_retry( @@ -268,8 +313,6 @@ class RemoteRuntime(Runtime): f'Action {action_type} is not supported in the current runtime.' ) - self._wait_until_alive() - assert action.timeout is not None try: @@ -331,7 +374,6 @@ class RemoteRuntime(Runtime): if not os.path.exists(host_src): raise FileNotFoundError(f'Source file {host_src} does not exist') - self._wait_until_alive() try: if recursive: with tempfile.NamedTemporaryFile( @@ -383,7 +425,6 @@ class RemoteRuntime(Runtime): logger.info(f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}') def list_files(self, path: str | None = None) -> list[str]: - self._wait_until_alive() try: data = {} if path is not None: @@ -397,7 +438,7 @@ class RemoteRuntime(Runtime): retry_exceptions=list( filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS) ), - timeout=30, # The runtime sbould already be running here + timeout=30, ) if response.status_code == 200: response_json = response.json() diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 6eb2faa854..f172021a37 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -180,7 +180,7 @@ class AgentSession: status_message_callback=status_message_callback, ) except Exception as e: - logger.error(f'Runtime initialization failed: {e}') + logger.error(f'Runtime initialization failed: {e}', exc_info=True) raise if self.runtime is not None: