Iterative evaluation with rule-based critic (#7293)

Author: Xingyao Wang
Date: 2025-03-17 14:37:35 -04:00
Committed by: GitHub
Parent: a4b836b5f9
Commit: 9b9e728cf6
6 changed files with 205 additions and 20 deletions


@@ -18,6 +18,20 @@ Please follow instruction [here](../../README.md#setup) to setup your local deve
 ## Run Inference (Rollout) on SWE-Bench Instances: Generate Patch from Problem Statement
 
+> [!NOTE]
+> **Iterative Evaluation Protocol**
+>
+> We use an iterative approach to get more stable and reproducible results:
+> - For each instance, we attempt to generate a solution up to 3 times
+> - Each attempt continues until either:
+>   1. The agent successfully produces a patch with `AgentFinishAction`, or
+>   2. The attempt reaches the maximum iteration limit
+> - If an attempt fails, we retry with a fresh attempt (up to the 3-attempt maximum)
+> - If your LLM config has temperature=0, we automatically use temperature=0.1 for the 2nd and 3rd attempts
+>
+> To enable this iterative protocol, set `export ITERATIVE_EVAL_MODE=true`
+
 ### Running Locally with Docker
 
 Make sure your Docker daemon is running, and you have ample disk space (at least 200-500GB, depending on the SWE-Bench set you are running on) for the instance-level docker image.
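To make the protocol above concrete, here is a minimal sketch of the retry loop it describes. It is an illustration only, not part of this commit: `run_single_attempt` and `agent_finished` are hypothetical stand-ins for the rollout and critic logic added to the inference script later in this commit.

```python
# Minimal sketch of the iterative protocol described in the note above.
# `run_single_attempt` and `agent_finished` are hypothetical stand-ins.
MAX_ATTEMPTS = 3


def solve_instance(instance, llm_config):
    history = []
    for attempt in range(1, MAX_ATTEMPTS + 1):
        # A deterministic config (temperature=0) is bumped to 0.1 on retries so
        # the 2nd/3rd attempts do not simply repeat the same failed trajectory.
        if attempt > 1 and llm_config.temperature == 0:
            llm_config.temperature = 0.1
        history = run_single_attempt(instance, llm_config)  # hypothetical helper
        if agent_finished(history):  # hypothetical critic check
            break
    return history  # the last attempt is kept even if it never finished
```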
@@ -45,7 +59,7 @@ to `CodeActAgent`.
 default, the script evaluates the entire SWE-bench_Lite test set (300 issues). Note:
 in order to use `eval_limit`, you must also set `agent`.
 - `max_iter`, e.g. `20`, is the maximum number of iterations for the agent to run. By
-default, it is set to 30.
+default, it is set to 60.
 - `num_workers`, e.g. `3`, is the number of parallel workers to run the evaluation. By
 default, it is set to 1.
 - `dataset`, a huggingface dataset name, e.g. `princeton-nlp/SWE-bench`, `princeton-nlp/SWE-bench_Lite`, or `princeton-nlp/SWE-bench_Verified`, specifies which dataset to evaluate on.
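For reference, the `dataset` values above are plain Hugging Face dataset names, so they can be inspected outside the evaluation harness. A quick sketch, assuming the `datasets` package is installed (not part of this commit):

```python
# Inspect the SWE-bench_Lite test split that the script evaluates by default.
from datasets import load_dataset

swe_bench_lite = load_dataset('princeton-nlp/SWE-bench_Lite', split='test')
print(len(swe_bench_lite))  # 300 instances, matching the README above
print(swe_bench_lite[0]['instance_id'])  # e.g. an ID like 'astropy__astropy-12907'
```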


@@ -37,9 +37,10 @@ from openhands.core.config import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
+from openhands.critic import AgentFinishedCritic
 from openhands.events.action import CmdRunAction, MessageAction
 from openhands.events.observation import CmdOutputObservation, ErrorObservation
-from openhands.events.serialization.event import event_to_dict
+from openhands.events.serialization.event import event_from_dict, event_to_dict
 from openhands.runtime.base import Runtime
 from openhands.utils.async_utils import call_async_from_sync
 from openhands.utils.shutdown_listener import sleep_if_should_continue
@@ -122,7 +123,9 @@ You SHOULD NEVER attempt to browse the web.
 # TODO: migrate all swe-bench docker to ghcr.io/openhands
-DEFAULT_DOCKER_IMAGE_PREFIX = os.environ.get('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/xingyaoww/')
+DEFAULT_DOCKER_IMAGE_PREFIX = os.environ.get(
+    'EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/xingyaoww/'
+)
 logger.info(f'Default docker image prefix: {DEFAULT_DOCKER_IMAGE_PREFIX}')
@@ -637,20 +640,132 @@ if __name__ == '__main__':
     output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
     print(f'### OUTPUT FILE: {output_file} ###')
-    instances = prepare_dataset(swe_bench_tests, output_file, args.eval_n_limit)
-    if len(instances) > 0 and not isinstance(
-        instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
-    ):
-        for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
-            instances[col] = instances[col].apply(lambda x: str(x))
-
-    run_evaluation(
-        instances,
-        metadata,
-        output_file,
-        args.eval_num_workers,
-        process_instance,
-        timeout_seconds=8 * 60 * 60,  # 8 hours PER instance should be more than enough
-        max_retries=5,
-    )
+
+    # Run evaluation in iterative mode:
+    # if a rollout fails to produce an AgentFinishAction, we try again until it
+    # succeeds OR a total of 3 attempts have been made.
+    ITERATIVE_EVAL_MODE = (
+        os.environ.get('ITERATIVE_EVAL_MODE', 'false').lower() == 'true'
+    )
+    ITERATIVE_EVAL_MODE_MAX_ATTEMPTS = int(
+        os.environ.get('ITERATIVE_EVAL_MODE_MAX_ATTEMPTS', '3')
+    )
+
+    if not ITERATIVE_EVAL_MODE:
+        # Load the dataset
+        instances = prepare_dataset(swe_bench_tests, output_file, args.eval_n_limit)
+        if len(instances) > 0 and not isinstance(
+            instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
+        ):
+            for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
+                instances[col] = instances[col].apply(lambda x: str(x))
+
+        run_evaluation(
+            instances,
+            metadata,
+            output_file,
+            args.eval_num_workers,
+            process_instance,
+            timeout_seconds=8 * 60 * 60,  # 8 hours PER instance should be more than enough
+            max_retries=5,
+        )
+    else:
+        critic = AgentFinishedCritic()
+
+        def get_cur_output_file_path(attempt: int) -> str:
+            return (
+                f'{output_file.removesuffix(".jsonl")}.critic_attempt_{attempt}.jsonl'
+            )
+
+        eval_ids = None
+        for attempt in range(1, ITERATIVE_EVAL_MODE_MAX_ATTEMPTS + 1):
+            cur_output_file = get_cur_output_file_path(attempt)
+            logger.info(
+                f'Running evaluation with critic {critic.__class__.__name__} for attempt {attempt} of {ITERATIVE_EVAL_MODE_MAX_ATTEMPTS}.'
+            )
+
+            # For deterministic eval configs (temperature=0), bump the temperature
+            # to 0.1 on retry attempts so we hopefully get slightly different results.
+            if attempt > 1 and metadata.llm_config.temperature == 0:
+                logger.info(
+                    f'Detected temperature is 0 for (>1) attempt {attempt}. Setting temperature to 0.1...'
+                )
+                metadata.llm_config.temperature = 0.1
+
+            # Load instances. On the first attempt we evaluate all instances;
+            # on subsequent attempts we only evaluate the instances that the critic
+            # marked as failed in the previous attempt.
+            instances = prepare_dataset(
+                swe_bench_tests, cur_output_file, args.eval_n_limit, eval_ids=eval_ids
+            )
+            if len(instances) > 0 and not isinstance(
+                instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
+            ):
+                for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
+                    instances[col] = instances[col].apply(lambda x: str(x))
+
+            # Run evaluation, saving results to cur_output_file
+            logger.info(
+                f'Evaluating {len(instances)} instances for attempt {attempt}...'
+            )
+            run_evaluation(
+                instances,
+                metadata,
+                cur_output_file,
+                args.eval_num_workers,
+                process_instance,
+                timeout_seconds=8 * 60 * 60,  # 8 hours PER instance should be more than enough
+                max_retries=5,
+            )
+
+            # When eval is done, update eval_ids to the instances that failed the current attempt
+            instances_failed = []
+            logger.info(
+                f'Use critic {critic.__class__.__name__} to check {len(instances)} instances for attempt {attempt}...'
+            )
+            with open(cur_output_file, 'r') as f:
+                for line in f:
+                    instance = json.loads(line)
+                    history = [event_from_dict(event) for event in instance['history']]
+                    critic_result = critic.evaluate(history)
+                    if not critic_result.success:
+                        instances_failed.append(instance['instance_id'])
+            logger.info(
+                f'{len(instances_failed)} instances failed the current attempt {attempt}: {instances_failed}'
+            )
+            eval_ids = instances_failed
+
+            # If no instances failed, stop early
+            if len(instances_failed) == 0:
+                break
+
+        # Aggregate the results from all attempts into the original output file,
+        # keeping the latest attempt for each instance.
+        logger.info(
+            'Aggregating results from all attempts into the original output file...'
+        )
+        fout = open(output_file, 'w')
+        added_instance_ids = set()
+        for attempt in reversed(range(1, ITERATIVE_EVAL_MODE_MAX_ATTEMPTS + 1)):
+            cur_output_file = get_cur_output_file_path(attempt)
+            if not os.path.exists(cur_output_file):
+                logger.warning(
+                    f'Intermediate output file {cur_output_file} does not exist. Skipping...'
+                )
+                continue
+            with open(cur_output_file, 'r') as f:
+                for line in f:
+                    instance = json.loads(line)
+                    if instance['instance_id'] not in added_instance_ids:
+                        fout.write(line)
+                        added_instance_ids.add(instance['instance_id'])
+            logger.info(
+                f'Aggregated instances from {cur_output_file}. Total instances added so far: {len(added_instance_ids)}'
+            )
+        fout.close()
+        logger.info(
+            f'Done! Total {len(added_instance_ids)} instances added to {output_file}'
+        )


@@ -25,8 +25,8 @@ if [ -z "$AGENT" ]; then
 fi
 
 if [ -z "$MAX_ITER" ]; then
-  echo "MAX_ITER not specified, use default 100"
-  MAX_ITER=100
+  echo "MAX_ITER not specified, use default 60"
+  MAX_ITER=60
 fi
 
 if [ -z "$RUN_WITH_BROWSING" ]; then


@@ -0,0 +1,4 @@
from .base import BaseCritic, CriticResult
from .finish_critic import AgentFinishedCritic

__all__ = ['CriticResult', 'BaseCritic', 'AgentFinishedCritic']

openhands/critic/base.py (new file, 31 lines)

@@ -0,0 +1,31 @@
import abc

from pydantic import BaseModel

from openhands.events import Event


class CriticResult(BaseModel):
    """
    A critic result is a score and a message.
    """

    score: float
    message: str

    @property
    def success(self) -> bool:
        """
        Whether the agent is successful.
        """
        return self.score >= 0.5


class BaseCritic(abc.ABC):
    """
    A critic is a function that takes in a list of events and returns a score about the quality of those events.
    """

    @abc.abstractmethod
    def evaluate(self, events: list[Event]) -> CriticResult:
        pass
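`BaseCritic` and `CriticResult` above are the extension point for richer critics than the finish-check shipped in this commit. A hypothetical subclass (not part of the commit) that enforces a step budget shows the intended shape:

```python
# Hypothetical example subclass, not part of this commit.
from openhands.critic.base import BaseCritic, CriticResult
from openhands.events import Event


class StepBudgetCritic(BaseCritic):
    """Scores 1.0 when a trajectory stays within `max_steps` events, 0.0 otherwise."""

    def __init__(self, max_steps: int = 60):
        self.max_steps = max_steps

    def evaluate(self, events: list[Event]) -> CriticResult:
        if len(events) <= self.max_steps:
            return CriticResult(score=1.0, message='Within step budget.')
        return CriticResult(
            score=0.0, message=f'Used {len(events)} events (> {self.max_steps}).'
        )
```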


@@ -0,0 +1,21 @@
from openhands.critic.base import BaseCritic, CriticResult
from openhands.events import Event
from openhands.events.action import Action, AgentFinishAction


class AgentFinishedCritic(BaseCritic):
    """A simple rule-based critic that checks whether the last action is an AgentFinishAction.

    If not, it returns a score of 0 and a message indicating that the agent did not finish.
    """

    def __init__(self):
        pass

    def evaluate(self, events: list[Event]) -> CriticResult:
        # Find the most recent Action in the history (observations may come after it).
        last_action = next((h for h in reversed(events) if isinstance(h, Action)), None)
        if isinstance(last_action, AgentFinishAction):
            return CriticResult(score=1, message='Agent finished.')
        else:
            return CriticResult(score=0, message='Agent did not finish.')
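A quick usage sketch of `AgentFinishedCritic` in isolation. It assumes `MessageAction` and `AgentFinishAction` can be constructed with the minimal arguments shown, as they are elsewhere in OpenHands; in the evaluation script above, the event list instead comes from each instance's deserialized history.

```python
from openhands.critic import AgentFinishedCritic
from openhands.events.action import AgentFinishAction, MessageAction

critic = AgentFinishedCritic()

# A rollout that ended with an explicit finish action -> success.
finished = [MessageAction(content='Fix the failing test.'), AgentFinishAction()]
print(critic.evaluate(finished).success)  # True

# A rollout that stalled without finishing -> retried in iterative eval mode.
stalled = [MessageAction(content='Fix the failing test.')]
print(critic.evaluate(stalled).success)  # False
```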