mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-10 07:18:10 -05:00
eval: improve swebench infer error handling and retry (#4205)
This commit is contained in:
@@ -13,6 +13,7 @@ from evaluation.swe_bench.prompt import CODEACT_SWE_PROMPT
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
assert_and_raise,
|
||||
codeact_user_response,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
@@ -163,14 +164,16 @@ def initialize_runtime(
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0, f'Failed to export SWE_INSTANCE_ID: {obs.content}'
|
||||
)
|
||||
|
||||
action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """)
|
||||
action.timeout = 600
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {obs.content}')
|
||||
|
||||
if USE_INSTANCE_IMAGE:
|
||||
# inject the init script
|
||||
@@ -182,9 +185,10 @@ def initialize_runtime(
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert (
|
||||
obs.exit_code == 0
|
||||
), f'Failed to create /swe_util/eval_data/instances: {obs.content}'
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to create /swe_util/eval_data/instances: {obs.content}',
|
||||
)
|
||||
|
||||
swe_instance_json_name = 'swe-bench-instance.json'
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
@@ -210,44 +214,53 @@ def initialize_runtime(
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to cat ~/.bashrc: {obs.content}')
|
||||
|
||||
action = CmdRunAction(command='source ~/.bashrc')
|
||||
action.timeout = 600
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0, f'Failed to source ~/.bashrc: {obs.content}'
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
|
||||
action.timeout = 3600
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to source /swe_util/instance_swe_entry.sh: {obs.content}',
|
||||
)
|
||||
else:
|
||||
action = CmdRunAction(command='source /swe_util/swe_entry.sh')
|
||||
action.timeout = 1800
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert (
|
||||
obs.exit_code == 0
|
||||
), f'Failed to source /swe_util/swe_entry.sh: {obs.content}'
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to source /swe_util/swe_entry.sh: {obs.content}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
|
||||
action.timeout = 600
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to cd to /workspace/{workspace_dir_name}: {obs.content}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='git reset --hard')
|
||||
action.timeout = 600
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to git reset --hard: {obs.content}')
|
||||
|
||||
action = CmdRunAction(
|
||||
command='for remote_name in $(git remote); do git remote remove "${remote_name}"; done'
|
||||
@@ -256,7 +269,7 @@ def initialize_runtime(
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to remove git remotes: {obs.content}')
|
||||
|
||||
logger.info('-' * 30)
|
||||
logger.info('END Runtime Initialization Fn')
|
||||
@@ -284,21 +297,27 @@ def complete_runtime(
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to cd to /workspace/{workspace_dir_name}: {obs.content}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='git config --global core.pager ""')
|
||||
action.timeout = 600
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to git config --global core.pager "": {obs.content}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command='git add -A')
|
||||
action.timeout = 600
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to git add -A: {obs.content}')
|
||||
|
||||
n_retries = 0
|
||||
git_patch = None
|
||||
@@ -323,7 +342,7 @@ def complete_runtime(
|
||||
logger.error(f'Error occurred: {obs.content}. Retrying...')
|
||||
sleep_if_should_continue(10)
|
||||
else:
|
||||
raise ValueError(f'Unexpected observation type: {type(obs)}')
|
||||
assert_and_raise(False, f'Unexpected observation type: {type(obs)}')
|
||||
|
||||
logger.info('-' * 30)
|
||||
logger.info('END Runtime Completion Fn')
|
||||
@@ -346,31 +365,32 @@ def process_instance(
|
||||
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
|
||||
|
||||
runtime = create_runtime(config, sid=instance.instance_id)
|
||||
initialize_runtime(runtime, instance)
|
||||
try:
|
||||
initialize_runtime(runtime, instance)
|
||||
|
||||
instruction = get_instruction(instance, metadata)
|
||||
instruction = get_instruction(instance, metadata)
|
||||
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
],
|
||||
# Here's how you can run the agent (similar to the `main` function) and get the final task state
|
||||
state: State | None = asyncio.run(
|
||||
run_controller(
|
||||
config=config,
|
||||
task_str=instruction,
|
||||
runtime=runtime,
|
||||
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
|
||||
metadata.agent_class
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# ======= THIS IS SWE-Bench specific =======
|
||||
# Get git patch
|
||||
return_val = complete_runtime(runtime, instance)
|
||||
git_patch = return_val['git_patch']
|
||||
logger.info(
|
||||
f'Got git diff for instance {instance.instance_id}:\n--------\n{git_patch}\n--------'
|
||||
)
|
||||
|
||||
runtime.close()
|
||||
# ======= THIS IS SWE-Bench specific =======
|
||||
# Get git patch
|
||||
return_val = complete_runtime(runtime, instance)
|
||||
git_patch = return_val['git_patch']
|
||||
logger.info(
|
||||
f'Got git diff for instance {instance.instance_id}:\n--------\n{git_patch}\n--------'
|
||||
)
|
||||
finally:
|
||||
runtime.close()
|
||||
# ==========================================
|
||||
|
||||
# ======= Attempt to evaluate the agent's edits =======
|
||||
|
||||
@@ -86,6 +86,10 @@ class EvalOutput(BaseModel):
|
||||
return json.dumps(dumped_dict)
|
||||
|
||||
|
||||
class EvalException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def codeact_user_response(
|
||||
state: State,
|
||||
encapsulate_solution: bool = False,
|
||||
@@ -252,6 +256,15 @@ def update_progress(
|
||||
output_fp.flush()
|
||||
|
||||
|
||||
def assert_and_raise(condition: bool, msg: str):
|
||||
"""Raise an EvalException if the condition is not met.
|
||||
|
||||
This will be used in conjunction with _process_instance_wrapper to handle retries. An EvalException should trigger a retry.
|
||||
"""
|
||||
if not condition:
|
||||
raise EvalException(msg)
|
||||
|
||||
|
||||
def _process_instance_wrapper(
|
||||
process_instance_func: Callable[[pd.Series, EvalMetadata, bool], EvalOutput],
|
||||
instance: pd.Series,
|
||||
|
||||
Reference in New Issue
Block a user