refactor & improve retry for the reliability of RemoteRuntime & evaluation (#3846)

This commit is contained in:
Xingyao Wang
2024-09-13 06:37:07 -05:00
committed by GitHub
parent 7506b20087
commit 78c5f58adc
5 changed files with 185 additions and 80 deletions

View File

@@ -176,6 +176,7 @@ def initialize_runtime(
# inject the instance info
action = CmdRunAction(command='mkdir -p /swe_util/eval_data/instances')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
@@ -233,6 +234,7 @@ def initialize_runtime(
), f'Failed to source /swe_util/swe_entry.sh: {obs.content}'
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
action.timeout = 600
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})

View File

@@ -5,6 +5,7 @@ import os
import pathlib
import subprocess
import time
import traceback
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Awaitable, Callable
@@ -77,6 +78,12 @@ class EvalOutput(BaseModel):
return json.dumps(dumped_dict)
class EvalError(BaseModel):
    """Record of a failed attempt to process a single evaluation instance.

    Returned in place of an `EvalOutput` when processing an instance raises,
    so the caller can inspect the failure instead of receiving an exception.
    """

    # Identifier of the instance whose processing failed.
    instance_id: str
    # String form of the exception that was raised.
    error: str
    # Formatted traceback captured at the point of failure.
    stacktrace: str
def codeact_user_response(
state: State,
encapsulate_solution: bool = False,
@@ -227,6 +234,20 @@ def prepare_dataset(
return pd.DataFrame(new_dataset)
def process_instance(
    instance, metadata, use_multiprocessing, process_instance_func
) -> EvalOutput | EvalError:
    """Run `process_instance_func` on one instance, converting failures to data.

    Args:
        instance: The dataset row to evaluate; must expose `instance_id`.
        metadata: Evaluation metadata forwarded unchanged to the callee.
        use_multiprocessing: Flag forwarded unchanged to the callee.
        process_instance_func: Callable invoked as
            `process_instance_func(instance, metadata, use_multiprocessing)`.

    Returns:
        Whatever `process_instance_func` returns on success, or an `EvalError`
        carrying the instance id, the error message and the formatted
        traceback when the callee raises an `Exception`.
    """
    try:
        outcome = process_instance_func(instance, metadata, use_multiprocessing)
    except Exception as exc:
        # Log first, then hand the failure back as a value so the caller can
        # decide what to do with it instead of having the exception propagate.
        logger.error(f'Error processing instance [{instance.instance_id}]: {exc}')
        failure = EvalError(
            instance_id=instance.instance_id,
            error=str(exc),
            stacktrace=traceback.format_exc(),
        )
        return failure
    return outcome
def run_evaluation(
dataset: pd.DataFrame,
metadata: EvalMetadata,
@@ -241,42 +262,65 @@ def run_evaluation(
f'Evaluation started with Agent {metadata.agent_class}:\n'
f'model {metadata.llm_config.model}, max iterations {metadata.max_iterations}.\n'
)
pbar = tqdm(total=len(dataset))
instance_queue = mp.Queue()
for _, instance in dataset.iterrows():
instance_queue.put(instance)
total_instances = instance_queue.qsize()
pbar = tqdm(total=total_instances, desc='Instances processed')
output_fp = open(output_file, 'a')
def update_progress(future):
pbar.update(1)
output: EvalOutput = future.result() if use_multiprocessing else future
pbar.set_description(f'Instance {output.instance_id}')
pbar.set_postfix_str(f'Test Result: {output.test_result}')
logger.info(
f'Finished evaluation for instance {output.instance_id}: {str(output.test_result)[:300]}...\n'
)
output_fp.write(json.dumps(output.model_dump()) + '\n')
output_fp.flush()
def update_progress(result: EvalOutput | EvalError, instance: pd.Series):
if isinstance(result, EvalOutput):
pbar.update(1)
pbar.set_description(f'Instance {result.instance_id}')
pbar.set_postfix_str(f'Test Result: {result.test_result}')
logger.info(
f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
)
output_fp.write(json.dumps(result.model_dump()) + '\n')
output_fp.flush()
else:
logger.error(
f'Retrying instance [{instance.instance_id}] due to error: {result.error}. Stacktrace:\n{result.stacktrace}'
+ '\n'
+ '-' * 10
+ '[You may ignore this error if it is a transient issue - the instance will be automatically retried.]'
+ '-' * 10
+ '\n'
)
instance_queue.put(instance)
pbar.total += 1
pbar.refresh()
try:
if use_multiprocessing:
with ProcessPoolExecutor(num_workers) as executor:
futures = []
for _, instance in dataset.iterrows():
future = executor.submit(
process_instance_func,
instance,
metadata,
bool(num_workers > 1),
)
future.add_done_callback(update_progress)
futures.append(future)
for future in futures:
future.result()
# Use plain for loop for single process for easier debugging
while not instance_queue.empty():
futures = []
for _ in range(min(num_workers, instance_queue.qsize())):
instance = instance_queue.get()
future = executor.submit(
process_instance,
instance,
metadata,
True,
process_instance_func,
)
future.add_done_callback(
lambda f, inst=instance: update_progress(f.result(), inst)
)
futures.append(future)
for future in futures:
future.result()
else:
assert num_workers == 1
for _, instance in dataset.iterrows():
output = process_instance_func(instance, metadata, False)
update_progress(output)
while not instance_queue.empty():
instance = instance_queue.get()
result = process_instance(
instance, metadata, False, process_instance_func
)
update_progress(result, instance)
except KeyboardInterrupt:
print('\nKeyboardInterrupt received. Cleaning up...\n')

View File

@@ -7,6 +7,7 @@ import requests
from openhands.core.logger import openhands_logger as logger
from openhands.runtime.builder import RuntimeBuilder
from openhands.runtime.utils.request import send_request
class RemoteRuntimeBuilder(RuntimeBuilder):
@@ -15,6 +16,8 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
def __init__(self, api_url: str, api_key: str):
    """Store the Runtime API endpoint and credentials and open an HTTP session.

    Args:
        api_url: Base URL of the Runtime API.
        api_key: Key sent in the `X-API-Key` header of every session request.
    """
    self.api_url = api_url
    self.api_key = api_key
    # Configure the session before publishing it on the instance so that the
    # API key header is attached to every request made through it.
    http_session = requests.Session()
    http_session.headers.update({'X-API-Key': api_key})
    self.session = http_session
def build(self, path: str, tags: list[str]) -> str:
"""Builds a Docker image using the Runtime API's /build endpoint."""
@@ -38,8 +41,9 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
files.append(('tags', (None, tag)))
# Send the POST request to /build
headers = {'X-API-Key': self.api_key}
response = requests.post(f'{self.api_url}/build', files=files, headers=headers)
response = send_request(
self.session, 'POST', f'{self.api_url}/build', files=files
)
if response.status_code != 202:
logger.error(f'Build initiation failed: {response.text}')
@@ -57,10 +61,11 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
logger.error('Build timed out after 30 minutes')
raise RuntimeError('Build timed out after 30 minutes')
status_response = requests.get(
status_response = send_request(
self.session,
'GET',
f'{self.api_url}/build_status',
params={'build_id': build_id},
headers=headers,
)
if status_response.status_code != 200:
@@ -90,14 +95,14 @@ class RemoteRuntimeBuilder(RuntimeBuilder):
raise RuntimeError(error_message)
# Wait before polling again
time.sleep(5)
time.sleep(30)
def image_exists(self, image_name: str) -> bool:
"""Checks if an image exists in the remote registry using the /image_exists endpoint."""
params = {'image': image_name}
session = requests.Session()
session.headers.update({'X-API-Key': self.api_key})
response = session.get(f'{self.api_url}/image_exists', params=params)
response = send_request(
self.session, 'GET', f'{self.api_url}/image_exists', params=params
)
if response.status_code != 200:
logger.error(f'Failed to check image existence: {response.text}')

View File

@@ -1,13 +1,11 @@
import os
import ssl
import tempfile
import threading
import uuid
from typing import Any, Type
from zipfile import ZipFile
import requests
from requests.exceptions import HTTPError, RequestException, Timeout
from requests.exceptions import Timeout
from tenacity import (
retry,
retry_if_exception_type,
@@ -37,15 +35,13 @@ from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.builder.remote import RemoteRuntimeBuilder
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.runtime import Runtime
from openhands.runtime.utils.request import (
DEFAULT_RETRY_EXCEPTIONS,
is_404_error,
send_request,
)
from openhands.runtime.utils.runtime_build import build_runtime_image
DEFAULT_RETRY_EXCEPTIONS = [
ssl.SSLCertVerificationError,
RequestException,
HTTPError,
Timeout,
]
class RemoteRuntime(Runtime):
"""This runtime will connect to a remote od-runtime-client."""
@@ -99,7 +95,7 @@ class RemoteRuntime(Runtime):
self.container_image: str = self.config.sandbox.base_container_image
self.container_name = 'od-remote-runtime-' + self.instance_id
logger.debug(f'RemoteRuntime `{sid}` config:\n{self.config}')
response = self._send_request('GET', f'{self.api_url}/registry_prefix')
response = send_request(self.session, 'GET', f'{self.api_url}/registry_prefix')
response_json = response.json()
registry_prefix = response_json['registry_prefix']
os.environ['OD_RUNTIME_RUNTIME_IMAGE_REPO'] = (
@@ -122,7 +118,8 @@ class RemoteRuntime(Runtime):
)
# Use the /image_exists endpoint to check if the image exists
response = self._send_request(
response = send_request(
self.session,
'GET',
f'{self.api_url}/image_exists',
params={'image': self.container_image},
@@ -157,8 +154,8 @@ class RemoteRuntime(Runtime):
}
# Start the sandbox using the /start endpoint
response = self._send_request(
'POST', f'{self.api_url}/start', json=start_request
response = send_request(
self.session, 'POST', f'{self.api_url}/start', json=start_request
)
if response.status_code != 201:
raise RuntimeError(f'Failed to start sandbox: {response.text}')
@@ -184,29 +181,6 @@ class RemoteRuntime(Runtime):
self.runtime_url is not None
), 'Runtime URL is not set. This should never happen.'
def _send_request(
    self,
    method: str,
    url: str,
    retry_exceptions: list[Type[Exception]] | None = None,
    **kwargs: Any,
) -> requests.Response:
    """Issue an HTTP request through `self.session`, retrying on failure.

    Args:
        method: HTTP method name passed to `Session.request`.
        url: Target URL.
        retry_exceptions: Exception types that trigger a retry; when None,
            `DEFAULT_RETRY_EXCEPTIONS` is used.
        **kwargs: Extra keyword arguments forwarded to `Session.request`.

    Returns:
        The response of the first attempt that does not raise.

    Raises:
        The last exception raised once 30 attempts are exhausted, or
        immediately for any exception type that is not configured for retry.
    """
    exception_types = (
        DEFAULT_RETRY_EXCEPTIONS if retry_exceptions is None else retry_exceptions
    )

    def _attempt():
        # `raise_for_status` turns 4xx/5xx responses into `HTTPError` so that
        # they take part in the retry decision.
        response = self.session.request(method, url, **kwargs)
        response.raise_for_status()
        return response

    with_retry = retry(
        stop=stop_after_attempt(30),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        retry=retry_if_exception_type(tuple(exception_types)),
        reraise=True,
    )
    return with_retry(_attempt)()
@retry(
stop=stop_after_attempt(10),
wait=wait_exponential(multiplier=1, min=4, max=60),
@@ -215,7 +189,15 @@ class RemoteRuntime(Runtime):
)
def _wait_until_alive(self):
logger.info('Waiting for sandbox to be alive...')
response = self._send_request('GET', f'{self.runtime_url}/alive')
response = send_request(
self.session,
'GET',
f'{self.runtime_url}/alive',
# Retry 404 errors for the /alive endpoint
# because the runtime might just be starting up
# and have not registered the endpoint yet
retry_fns=[is_404_error],
)
if response.status_code != 200:
msg = f'Runtime is not alive yet (id={self.runtime_id}). Status: {response.status_code}.'
logger.warning(msg)
@@ -228,8 +210,11 @@ class RemoteRuntime(Runtime):
def close(self):
if self.runtime_id:
try:
response = self._send_request(
'POST', f'{self.api_url}/stop', json={'runtime_id': self.runtime_id}
response = send_request(
self.session,
'POST',
f'{self.api_url}/stop',
json={'runtime_id': self.runtime_id},
)
if response.status_code != 200:
logger.error(f'Failed to stop sandbox: {response.text}')
@@ -262,7 +247,8 @@ class RemoteRuntime(Runtime):
logger.info('Executing action')
request_body = {'action': event_to_dict(action)}
logger.debug(f'Request body: {request_body}')
response = self._send_request(
response = send_request(
self.session,
'POST',
f'{self.runtime_url}/execute_action',
json=request_body,
@@ -270,6 +256,10 @@ class RemoteRuntime(Runtime):
retry_exceptions=list(
filter(lambda e: e != TimeoutError, DEFAULT_RETRY_EXCEPTIONS)
),
# Retry 404 errors for the /execute_action endpoint
# because the runtime might just be starting up
# and have not registered the endpoint yet
retry_fns=[is_404_error],
)
if response.status_code == 200:
output = response.json()
@@ -335,7 +325,8 @@ class RemoteRuntime(Runtime):
params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()}
response = self._send_request(
response = send_request(
self.session,
'POST',
f'{self.runtime_url}/upload_file',
files=upload_data,
@@ -368,7 +359,8 @@ class RemoteRuntime(Runtime):
if path is not None:
data['path'] = path
response = self._send_request(
response = send_request(
self.session,
'POST',
f'{self.runtime_url}/list_files',
json=data,

View File

@@ -0,0 +1,62 @@
from typing import Any, Callable, Type
import requests
from requests.exceptions import ConnectionError, Timeout
from tenacity import (
retry,
retry_if_exception,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
def is_server_error(exception):
    """Return True when `exception` is an HTTP error with a 5xx status code.

    Intended for use as a retry predicate, so it must never raise itself.

    Args:
        exception: The exception raised by a request attempt.

    Returns:
        True if `exception` is a `requests.HTTPError` carrying a response whose
        status code is 500 or greater; False otherwise.
    """
    if not isinstance(exception, requests.HTTPError):
        return False
    # An `HTTPError` can be constructed without an attached response, in which
    # case `.response` is None. Treat that as "not a server error" rather than
    # letting an AttributeError escape from the predicate.
    response = exception.response
    return response is not None and response.status_code >= 500
def is_404_error(exception):
    """Return True when `exception` is an HTTP error with status code 404.

    Intended for use as a retry predicate, so it must never raise itself.

    Args:
        exception: The exception raised by a request attempt.

    Returns:
        True if `exception` is a `requests.HTTPError` carrying a response whose
        status code is exactly 404; False otherwise.
    """
    if not isinstance(exception, requests.HTTPError):
        return False
    # An `HTTPError` can be constructed without an attached response, in which
    # case `.response` is None. Treat that as "not a 404" rather than letting
    # an AttributeError escape from the predicate.
    response = exception.response
    return response is not None and response.status_code == 404
# Exception types retried by `send_request` when the caller passes no
# `retry_exceptions`. Both names are the `requests.exceptions` classes imported
# at the top of this module (this `ConnectionError` shadows the builtin).
DEFAULT_RETRY_EXCEPTIONS = [ConnectionError, Timeout]
def send_request(
    session: requests.Session,
    method: str,
    url: str,
    retry_exceptions: list[Type[Exception]] | None = None,
    retry_fns: list[Callable[[Exception], bool]] | None = None,
    n_attempts: int = 15,
    **kwargs: Any,
) -> requests.Response:
    """Send an HTTP request through `session`, retrying on transient failures.

    A failed attempt is retried when the raised exception is an instance of one
    of `retry_exceptions`, is an HTTP 5xx error (see `is_server_error`), or is
    accepted by any predicate in `retry_fns`. Attempts are separated by an
    exponential backoff bounded between 4 and 60 seconds.

    Args:
        session: The `requests.Session` used to issue the request.
        method: HTTP method name passed to `Session.request`.
        url: Target URL.
        retry_exceptions: Exception types that trigger a retry. When None,
            `DEFAULT_RETRY_EXCEPTIONS` is used; an explicit empty list disables
            retrying by exception type.
        retry_fns: Extra predicates; a retry happens if any returns True for
            the raised exception.
        n_attempts: Maximum number of attempts before giving up.
        **kwargs: Extra keyword arguments forwarded to `Session.request`.

    Returns:
        The response of the first attempt that does not raise.

    Raises:
        The last exception raised once `n_attempts` is exhausted, or
        immediately for any exception that matches no retry condition. This
        includes `requests.HTTPError` for 4xx/5xx responses, because every
        response is passed through `raise_for_status`.
    """
    # Compare against None explicitly: with a plain `or`, an explicit empty
    # list would silently fall back to the defaults instead of meaning "retry
    # on no exception type".
    exceptions_to_catch = (
        DEFAULT_RETRY_EXCEPTIONS if retry_exceptions is None else retry_exceptions
    )
    retry_condition = retry_if_exception_type(
        tuple(exceptions_to_catch)
    ) | retry_if_exception(is_server_error)
    for fn in retry_fns or ():
        retry_condition |= retry_if_exception(fn)

    @retry(
        stop=stop_after_attempt(n_attempts),
        wait=wait_exponential(multiplier=1, min=4, max=60),
        retry=retry_condition,
        reraise=True,  # surface the original exception, not tenacity's RetryError
    )
    def _send_request_with_retry():
        response = session.request(method, url, **kwargs)
        # Turn 4xx/5xx responses into HTTPError so they reach the predicates.
        response.raise_for_status()
        return response

    return _send_request_with_retry()