mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
refactor & improve retry for the reliability of RemoteRuntime & evaluation (#3846)
This commit is contained in:
@@ -176,6 +176,7 @@ def initialize_runtime(
|
||||
|
||||
# inject the instance info
|
||||
action = CmdRunAction(command='mkdir -p /swe_util/eval_data/instances')
|
||||
action.timeout = 600
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
@@ -233,6 +234,7 @@ def initialize_runtime(
|
||||
), f'Failed to source /swe_util/swe_entry.sh: {obs.content}'
|
||||
|
||||
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
|
||||
action.timeout = 600
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
@@ -77,6 +78,12 @@ class EvalOutput(BaseModel):
|
||||
return json.dumps(dumped_dict)
|
||||
|
||||
|
||||
class EvalError(BaseModel):
|
||||
instance_id: str
|
||||
error: str
|
||||
stacktrace: str
|
||||
|
||||
|
||||
def codeact_user_response(
|
||||
state: State,
|
||||
encapsulate_solution: bool = False,
|
||||
@@ -227,6 +234,20 @@ def prepare_dataset(
|
||||
return pd.DataFrame(new_dataset)
|
||||
|
||||
|
||||
def process_instance(
|
||||
instance, metadata, use_multiprocessing, process_instance_func
|
||||
) -> EvalOutput | EvalError:
|
||||
try:
|
||||
return process_instance_func(instance, metadata, use_multiprocessing)
|
||||
except Exception as e:
|
||||
logger.error(f'Error processing instance [{instance.instance_id}]: {e}')
|
||||
return EvalError(
|
||||
instance_id=instance.instance_id,
|
||||
error=str(e),
|
||||
stacktrace=traceback.format_exc(),
|
||||
)
|
||||
|
||||
|
||||
def run_evaluation(
|
||||
dataset: pd.DataFrame,
|
||||
metadata: EvalMetadata,
|
||||
@@ -241,42 +262,65 @@ def run_evaluation(
|
||||
f'Evaluation started with Agent {metadata.agent_class}:\n'
|
||||
f'model {metadata.llm_config.model}, max iterations {metadata.max_iterations}.\n'
|
||||
)
|
||||
pbar = tqdm(total=len(dataset))
|
||||
|
||||
instance_queue = mp.Queue()
|
||||
for _, instance in dataset.iterrows():
|
||||
instance_queue.put(instance)
|
||||
|
||||
total_instances = instance_queue.qsize()
|
||||
pbar = tqdm(total=total_instances, desc='Instances processed')
|
||||
output_fp = open(output_file, 'a')
|
||||
|
||||
def update_progress(future):
|
||||
pbar.update(1)
|
||||
output: EvalOutput = future.result() if use_multiprocessing else future
|
||||
|
||||
pbar.set_description(f'Instance {output.instance_id}')
|
||||
pbar.set_postfix_str(f'Test Result: {output.test_result}')
|
||||
logger.info(
|
||||
f'Finished evaluation for instance {output.instance_id}: {str(output.test_result)[:300]}...\n'
|
||||
)
|
||||
output_fp.write(json.dumps(output.model_dump()) + '\n')
|
||||
output_fp.flush()
|
||||
def update_progress(result: EvalOutput | EvalError, instance: pd.Series):
|
||||
if isinstance(result, EvalOutput):
|
||||
pbar.update(1)
|
||||
pbar.set_description(f'Instance {result.instance_id}')
|
||||
pbar.set_postfix_str(f'Test Result: {result.test_result}')
|
||||
logger.info(
|
||||
f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
|
||||
)
|
||||
output_fp.write(json.dumps(result.model_dump()) + '\n')
|
||||
output_fp.flush()
|
||||
else:
|
||||
logger.error(
|
||||
f'Retrying instance [{instance.instance_id}] due to error: {result.error}. Stacktrace:\n{result.stacktrace}'
|
||||
+ '\n'
|
||||
+ '-' * 10
|
||||
+ '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
|
||||
+ '-' * 10
|
||||
+ '\n'
|
||||
)
|
||||
instance_queue.put(instance)
|
||||
pbar.total += 1
|
||||
pbar.refresh()
|
||||
|
||||
try:
|
||||
if use_multiprocessing:
|
||||
with ProcessPoolExecutor(num_workers) as executor:
|
||||
futures = []
|
||||
for _, instance in dataset.iterrows():
|
||||
future = executor.submit(
|
||||
process_instance_func,
|
||||
instance,
|
||||
metadata,
|
||||
bool(num_workers > 1),
|
||||
)
|
||||
future.add_done_callback(update_progress)
|
||||
futures.append(future)
|
||||
for future in futures:
|
||||
future.result()
|
||||
# Use plain for loop for single process for easier debugging
|
||||
while not instance_queue.empty():
|
||||
futures = []
|
||||
for _ in range(min(num_workers, instance_queue.qsize())):
|
||||
instance = instance_queue.get()
|
||||
future = executor.submit(
|
||||
process_instance,
|
||||
instance,
|
||||
metadata,
|
||||
True,
|
||||
process_instance_func,
|
||||
)
|
||||
future.add_done_callback(
|
||||
lambda f, inst=instance: update_progress(f.result(), inst)
|
||||
)
|
||||
futures.append(future)
|
||||
for future in futures:
|
||||
future.result()
|
||||
else:
|
||||
assert num_workers == 1
|
||||
for _, instance in dataset.iterrows():
|
||||
output = process_instance_func(instance, metadata, False)
|
||||
update_progress(output)
|
||||
while not instance_queue.empty():
|
||||
instance = instance_queue.get()
|
||||
result = process_instance(
|
||||
instance, metadata, False, process_instance_func
|
||||
)
|
||||
update_progress(result, instance)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print('\nKeyboardInterrupt received. Cleaning up...\n')
|
||||
|
||||
@@ -7,6 +7,7 @@ import requests
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.runtime.builder import RuntimeBuilder
|
||||
from openhands.runtime.utils.request import send_request
|
||||
|
||||
|
||||
class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
@@ -15,6 +16,8 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
def __init__(self, api_url: str, api_key: str):
|
||||
self.api_url = api_url
|
||||
self.api_key = api_key
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({'X-API-Key': self.api_key})
|
||||
|
||||
def build(self, path: str, tags: list[str]) -> str:
|
||||
"""Builds a Docker image using the Runtime API's /build endpoint."""
|
||||
@@ -38,8 +41,9 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
files.append(('tags', (None, tag)))
|
||||
|
||||
# Send the POST request to /build
|
||||
headers = {'X-API-Key': self.api_key}
|
||||
response = requests.post(f'{self.api_url}/build', files=files, headers=headers)
|
||||
response = send_request(
|
||||
self.session, 'POST', f'{self.api_url}/build', files=files
|
||||
)
|
||||
|
||||
if response.status_code != 202:
|
||||
logger.error(f'Build initiation failed: {response.text}')
|
||||
@@ -57,10 +61,11 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
logger.error('Build timed out after 30 minutes')
|
||||
raise RuntimeError('Build timed out after 30 minutes')
|
||||
|
||||
status_response = requests.get(
|
||||
status_response = send_request(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.api_url}/build_status',
|
||||
params={'build_id': build_id},
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if status_response.status_code != 200:
|
||||
@@ -90,14 +95,14 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
# Wait before polling again
|
||||
time.sleep(5)
|
||||
time.sleep(30)
|
||||
|
||||
def image_exists(self, image_name: str) -> bool:
|
||||
"""Checks if an image exists in the remote registry using the /image_exists endpoint."""
|
||||
params = {'image': image_name}
|
||||
session = requests.Session()
|
||||
session.headers.update({'X-API-Key': self.api_key})
|
||||
response = session.get(f'{self.api_url}/image_exists', params=params)
|
||||
response = send_request(
|
||||
self.session, 'GET', f'{self.api_url}/image_exists', params=params
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f'Failed to check image existence: {response.text}')
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import os
|
||||
import ssl
|
||||
import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Any, Type
|
||||
from zipfile import ZipFile
|
||||
|
||||
import requests
|
||||
from requests.exceptions import HTTPError, RequestException, Timeout
|
||||
from requests.exceptions import Timeout
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
@@ -37,15 +35,13 @@ from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.runtime.builder.remote import RemoteRuntimeBuilder
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.runtime import Runtime
|
||||
from openhands.runtime.utils.request import (
|
||||
DEFAULT_RETRY_EXCEPTIONS,
|
||||
is_404_error,
|
||||
send_request,
|
||||
)
|
||||
from openhands.runtime.utils.runtime_build import build_runtime_image
|
||||
|
||||
DEFAULT_RETRY_EXCEPTIONS = [
|
||||
ssl.SSLCertVerificationError,
|
||||
RequestException,
|
||||
HTTPError,
|
||||
Timeout,
|
||||
]
|
||||
|
||||
|
||||
class RemoteRuntime(Runtime):
|
||||
"""This runtime will connect to a remote od-runtime-client."""
|
||||
@@ -99,7 +95,7 @@ class RemoteRuntime(Runtime):
|
||||
self.container_image: str = self.config.sandbox.base_container_image
|
||||
self.container_name = 'od-remote-runtime-' + self.instance_id
|
||||
logger.debug(f'RemoteRuntime `{sid}` config:\n{self.config}')
|
||||
response = self._send_request('GET', f'{self.api_url}/registry_prefix')
|
||||
response = send_request(self.session, 'GET', f'{self.api_url}/registry_prefix')
|
||||
response_json = response.json()
|
||||
registry_prefix = response_json['registry_prefix']
|
||||
os.environ['OD_RUNTIME_RUNTIME_IMAGE_REPO'] = (
|
||||
@@ -122,7 +118,8 @@ class RemoteRuntime(Runtime):
|
||||
)
|
||||
|
||||
# Use the /image_exists endpoint to check if the image exists
|
||||
response = self._send_request(
|
||||
response = send_request(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.api_url}/image_exists',
|
||||
params={'image': self.container_image},
|
||||
@@ -157,8 +154,8 @@ class RemoteRuntime(Runtime):
|
||||
}
|
||||
|
||||
# Start the sandbox using the /start endpoint
|
||||
response = self._send_request(
|
||||
'POST', f'{self.api_url}/start', json=start_request
|
||||
response = send_request(
|
||||
self.session, 'POST', f'{self.api_url}/start', json=start_request
|
||||
)
|
||||
if response.status_code != 201:
|
||||
raise RuntimeError(f'Failed to start sandbox: {response.text}')
|
||||
@@ -184,29 +181,6 @@ class RemoteRuntime(Runtime):
|
||||
self.runtime_url is not None
|
||||
), 'Runtime URL is not set. This should never happen.'
|
||||
|
||||
def _send_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
retry_exceptions: list[Type[Exception]] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
if retry_exceptions is None:
|
||||
retry_exceptions = DEFAULT_RETRY_EXCEPTIONS
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(30),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_if_exception_type(tuple(retry_exceptions)),
|
||||
reraise=True,
|
||||
)
|
||||
def _send_request_with_retry():
|
||||
response = self.session.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
return _send_request_with_retry()
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(10),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
@@ -215,7 +189,15 @@ class RemoteRuntime(Runtime):
|
||||
)
|
||||
def _wait_until_alive(self):
|
||||
logger.info('Waiting for sandbox to be alive...')
|
||||
response = self._send_request('GET', f'{self.runtime_url}/alive')
|
||||
response = send_request(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.runtime_url}/alive',
|
||||
# Retry 404 errors for the /alive endpoint
|
||||
# because the runtime might just be starting up
|
||||
# and have not registered the endpoint yet
|
||||
retry_fns=[is_404_error],
|
||||
)
|
||||
if response.status_code != 200:
|
||||
msg = f'Runtime is not alive yet (id={self.runtime_id}). Status: {response.status_code}.'
|
||||
logger.warning(msg)
|
||||
@@ -228,8 +210,11 @@ class RemoteRuntime(Runtime):
|
||||
def close(self):
|
||||
if self.runtime_id:
|
||||
try:
|
||||
response = self._send_request(
|
||||
'POST', f'{self.api_url}/stop', json={'runtime_id': self.runtime_id}
|
||||
response = send_request(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.api_url}/stop',
|
||||
json={'runtime_id': self.runtime_id},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
logger.error(f'Failed to stop sandbox: {response.text}')
|
||||
@@ -262,7 +247,8 @@ class RemoteRuntime(Runtime):
|
||||
logger.info('Executing action')
|
||||
request_body = {'action': event_to_dict(action)}
|
||||
logger.debug(f'Request body: {request_body}')
|
||||
response = self._send_request(
|
||||
response = send_request(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.runtime_url}/execute_action',
|
||||
json=request_body,
|
||||
@@ -270,6 +256,10 @@ class RemoteRuntime(Runtime):
|
||||
retry_exceptions=list(
|
||||
filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
|
||||
),
|
||||
# Retry 404 errors for the /execute_action endpoint
|
||||
# because the runtime might just be starting up
|
||||
# and have not registered the endpoint yet
|
||||
retry_fns=[is_404_error],
|
||||
)
|
||||
if response.status_code == 200:
|
||||
output = response.json()
|
||||
@@ -335,7 +325,8 @@ class RemoteRuntime(Runtime):
|
||||
|
||||
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
|
||||
|
||||
response = self._send_request(
|
||||
response = send_request(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.runtime_url}/upload_file',
|
||||
files=upload_data,
|
||||
@@ -368,7 +359,8 @@ class RemoteRuntime(Runtime):
|
||||
if path is not None:
|
||||
data['path'] = path
|
||||
|
||||
response = self._send_request(
|
||||
response = send_request(
|
||||
self.session,
|
||||
'POST',
|
||||
f'{self.runtime_url}/list_files',
|
||||
json=data,
|
||||
|
||||
62
openhands/runtime/utils/request.py
Normal file
62
openhands/runtime/utils/request.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Any, Callable, Type
|
||||
|
||||
import requests
|
||||
from requests.exceptions import ConnectionError, Timeout
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
|
||||
def is_server_error(exception):
|
||||
return (
|
||||
isinstance(exception, requests.HTTPError)
|
||||
and exception.response.status_code >= 500
|
||||
)
|
||||
|
||||
|
||||
def is_404_error(exception):
|
||||
return (
|
||||
isinstance(exception, requests.HTTPError)
|
||||
and exception.response.status_code == 404
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_RETRY_EXCEPTIONS = [
|
||||
ConnectionError,
|
||||
Timeout,
|
||||
]
|
||||
|
||||
|
||||
def send_request(
|
||||
session: requests.Session,
|
||||
method: str,
|
||||
url: str,
|
||||
retry_exceptions: list[Type[Exception]] | None = None,
|
||||
retry_fns: list[Callable[[Exception], bool]] | None = None,
|
||||
n_attempts: int = 15,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
exceptions_to_catch = retry_exceptions or DEFAULT_RETRY_EXCEPTIONS
|
||||
retry_condition = retry_if_exception_type(
|
||||
tuple(exceptions_to_catch)
|
||||
) | retry_if_exception(is_server_error)
|
||||
if retry_fns is not None:
|
||||
for fn in retry_fns:
|
||||
retry_condition |= retry_if_exception(fn)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(n_attempts),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=60),
|
||||
retry=retry_condition,
|
||||
reraise=True,
|
||||
)
|
||||
def _send_request_with_retry():
|
||||
response = session.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
return _send_request_with_retry()
|
||||
Reference in New Issue
Block a user