Mirror of https://github.com/All-Hands-AI/OpenHands.git
Iterative evaluation with rule-based critic (#7293)
@@ -18,6 +18,20 @@ Please follow instruction [here](../../README.md#setup) to setup your local deve
 
 ## Run Inference (Rollout) on SWE-Bench Instances: Generate Patch from Problem Statement
 
+> [!NOTE]
+> **Iterative Evaluation Protocol**
+>
+> We have an iterative approach for more stable and reproducible results:
+> - For each instance, we attempt to generate a solution up to 3 times
+> - Each attempt continues until either:
+>   1. The agent successfully produces a patch with `AgentFinishAction`, or
+>   2. The attempt reaches the maximum iteration limit
+> - If an attempt fails, we retry with a fresh attempt (up to the 3-attempt maximum)
+> - If your LLM config has temperature=0, we will automatically use temperature=0.1 for the 2nd and 3rd attempts
+>
+> To enable this iterative protocol, set `export ITERATIVE_EVAL_MODE=true`
+
 ### Running Locally with Docker
 
 Make sure your Docker daemon is running, and you have ample disk space (at least 200-500GB, depends on the SWE-Bench set you are running on) for the instance-level docker image.
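The protocol in the note above reduces to a small retry loop around each rollout. The following is a minimal, self-contained Python sketch of that control flow, not part of the commit: `run_attempt` and `finished` are illustrative stand-ins, and the real implementation is in the run_infer.py changes further down.

from typing import Any, Callable

MAX_ATTEMPTS = 3  # mirrors ITERATIVE_EVAL_MODE_MAX_ATTEMPTS


def solve_with_retries(
    run_attempt: Callable[[float], Any],  # runs one rollout at the given temperature
    finished: Callable[[Any], bool],  # did the rollout end with AgentFinishAction?
    base_temperature: float = 0.0,
) -> Any:
    """Retry a rollout up to MAX_ATTEMPTS times, nudging temperature after the first try."""
    result = None
    for attempt in range(1, MAX_ATTEMPTS + 1):
        temperature = base_temperature
        if attempt > 1 and base_temperature == 0:
            temperature = 0.1  # later attempts get slight randomness to avoid repeating a failure
        result = run_attempt(temperature)
        if finished(result):
            break  # keep the first attempt that produced AgentFinishAction
    return result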
@@ -45,7 +59,7 @@ to `CodeActAgent`.
 default, the script evaluates the entire SWE-bench_Lite test set (300 issues). Note:
 in order to use `eval_limit`, you must also set `agent`.
 - `max_iter`, e.g. `20`, is the maximum number of iterations for the agent to run. By
-default, it is set to 30.
+default, it is set to 60.
 - `num_workers`, e.g. `3`, is the number of parallel workers to run the evaluation. By
 default, it is set to 1.
 - `dataset`, a huggingface dataset name. e.g. `princeton-nlp/SWE-bench`, `princeton-nlp/SWE-bench_Lite`, or `princeton-nlp/SWE-bench_Verified`, specifies which dataset to evaluate on.
@@ -37,9 +37,10 @@ from openhands.core.config import (
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
+from openhands.critic import AgentFinishedCritic
 from openhands.events.action import CmdRunAction, MessageAction
 from openhands.events.observation import CmdOutputObservation, ErrorObservation
-from openhands.events.serialization.event import event_to_dict
+from openhands.events.serialization.event import event_from_dict, event_to_dict
 from openhands.runtime.base import Runtime
 from openhands.utils.async_utils import call_async_from_sync
 from openhands.utils.shutdown_listener import sleep_if_should_continue
@@ -122,7 +123,9 @@ You SHOULD NEVER attempt to browse the web.
 
 
 # TODO: migrate all swe-bench docker to ghcr.io/openhands
-DEFAULT_DOCKER_IMAGE_PREFIX = os.environ.get('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/xingyaoww/')
+DEFAULT_DOCKER_IMAGE_PREFIX = os.environ.get(
+    'EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/xingyaoww/'
+)
 logger.info(f'Default docker image prefix: {DEFAULT_DOCKER_IMAGE_PREFIX}')
 
 
@@ -637,20 +640,132 @@ if __name__ == '__main__':
 
     output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
     print(f'### OUTPUT FILE: {output_file} ###')
-    instances = prepare_dataset(swe_bench_tests, output_file, args.eval_n_limit)
-
-    if len(instances) > 0 and not isinstance(
-        instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
-    ):
-        for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
-            instances[col] = instances[col].apply(lambda x: str(x))
-
-    run_evaluation(
-        instances,
-        metadata,
-        output_file,
-        args.eval_num_workers,
-        process_instance,
-        timeout_seconds=8 * 60 * 60,  # 8 hour PER instance should be more than enough
-        max_retries=5,
-    )
+    # Run evaluation in iterative mode:
+    # If a rollout fails to output AgentFinishAction, we will try again until it succeeds OR total 3 attempts have been made.
+    ITERATIVE_EVAL_MODE = (
+        os.environ.get('ITERATIVE_EVAL_MODE', 'false').lower() == 'true'
+    )
+    ITERATIVE_EVAL_MODE_MAX_ATTEMPTS = int(
+        os.environ.get('ITERATIVE_EVAL_MODE_MAX_ATTEMPTS', '3')
+    )
+
+    if not ITERATIVE_EVAL_MODE:
+        # load the dataset
+        instances = prepare_dataset(swe_bench_tests, output_file, args.eval_n_limit)
+        if len(instances) > 0 and not isinstance(
+            instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
+        ):
+            for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
+                instances[col] = instances[col].apply(lambda x: str(x))
+
+        run_evaluation(
+            instances,
+            metadata,
+            output_file,
+            args.eval_num_workers,
+            process_instance,
+            timeout_seconds=8
+            * 60
+            * 60,  # 8 hour PER instance should be more than enough
+            max_retries=5,
+        )
+    else:
+        critic = AgentFinishedCritic()
+
+        def get_cur_output_file_path(attempt: int) -> str:
+            return (
+                f'{output_file.removesuffix(".jsonl")}.critic_attempt_{attempt}.jsonl'
+            )
+
+        eval_ids = None
+        for attempt in range(1, ITERATIVE_EVAL_MODE_MAX_ATTEMPTS + 1):
+            cur_output_file = get_cur_output_file_path(attempt)
+            logger.info(
+                f'Running evaluation with critic {critic.__class__.__name__} for attempt {attempt} of {ITERATIVE_EVAL_MODE_MAX_ATTEMPTS}.'
+            )
+
+            # For deterministic eval, we set temperature to 0.1 for (>1) attempt
+            # so hopefully we get slightly different results
+            if attempt > 1 and metadata.llm_config.temperature == 0:
+                logger.info(
+                    f'Detected temperature is 0 for (>1) attempt {attempt}. Setting temperature to 0.1...'
+                )
+                metadata.llm_config.temperature = 0.1
+
+            # Load instances - at first attempt, we evaluate all instances
+            # On subsequent attempts, we only evaluate the instances that failed the previous attempt determined by critic
+            instances = prepare_dataset(
+                swe_bench_tests, cur_output_file, args.eval_n_limit, eval_ids=eval_ids
+            )
+            if len(instances) > 0 and not isinstance(
+                instances['PASS_TO_PASS'][instances['PASS_TO_PASS'].index[0]], str
+            ):
+                for col in ['PASS_TO_PASS', 'FAIL_TO_PASS']:
+                    instances[col] = instances[col].apply(lambda x: str(x))
+
+            # Run evaluation - but save them to cur_output_file
+            logger.info(
+                f'Evaluating {len(instances)} instances for attempt {attempt}...'
+            )
+            run_evaluation(
+                instances,
+                metadata,
+                cur_output_file,
+                args.eval_num_workers,
+                process_instance,
+                timeout_seconds=8
+                * 60
+                * 60,  # 8 hour PER instance should be more than enough
+                max_retries=5,
+            )
+
+            # When eval is done, we update eval_ids to the instances that failed the current attempt
+            instances_failed = []
+            logger.info(
+                f'Use critic {critic.__class__.__name__} to check {len(instances)} instances for attempt {attempt}...'
+            )
+            with open(cur_output_file, 'r') as f:
+                for line in f:
+                    instance = json.loads(line)
+                    history = [event_from_dict(event) for event in instance['history']]
+                    critic_result = critic.evaluate(history)
+                    if not critic_result.success:
+                        instances_failed.append(instance['instance_id'])
+            logger.info(
+                f'{len(instances_failed)} instances failed the current attempt {attempt}: {instances_failed}'
+            )
+            eval_ids = instances_failed
+
+            # If no instances failed, we break
+            if len(instances_failed) == 0:
+                break
+
+        # Then we should aggregate the results from all attempts into the original output file
+        # and remove the intermediate files
+        logger.info(
+            'Aggregating results from all attempts into the original output file...'
+        )
+        fout = open(output_file, 'w')
+        added_instance_ids = set()
+        for attempt in reversed(range(1, ITERATIVE_EVAL_MODE_MAX_ATTEMPTS + 1)):
+            cur_output_file = get_cur_output_file_path(attempt)
+            if not os.path.exists(cur_output_file):
+                logger.warning(
+                    f'Intermediate output file {cur_output_file} does not exist. Skipping...'
+                )
+                continue
+
+            with open(cur_output_file, 'r') as f:
+                for line in f:
+                    instance = json.loads(line)
+                    if instance['instance_id'] not in added_instance_ids:
+                        fout.write(line)
+                        added_instance_ids.add(instance['instance_id'])
+            logger.info(
+                f'Aggregated instances from {cur_output_file}. Total instances added so far: {len(added_instance_ids)}'
+            )
+        fout.close()
+        logger.info(
+            f'Done! Total {len(added_instance_ids)} instances added to {output_file}'
+        )
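One subtle point in the iterative branch above: aggregation walks attempts from last to first and keeps the first record seen per instance_id, so a retried instance always lands in output.jsonl with its most recent attempt. A tiny self-contained illustration of that dedup rule, with made-up data, not part of the commit:

# Made-up attempt outputs: instance 'b' failed attempt 1 and was retried in attempt 2.
attempts = {
    1: [{'instance_id': 'a', 'attempt': 1}, {'instance_id': 'b', 'attempt': 1}],
    2: [{'instance_id': 'b', 'attempt': 2}],
}

final: dict[str, dict] = {}
for attempt in sorted(attempts, reverse=True):  # latest attempt first
    for record in attempts[attempt]:
        final.setdefault(record['instance_id'], record)  # first record seen wins

print(final)
# {'b': {'instance_id': 'b', 'attempt': 2}, 'a': {'instance_id': 'a', 'attempt': 1}}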
@@ -25,8 +25,8 @@ if [ -z "$AGENT" ]; then
 fi
 
 if [ -z "$MAX_ITER" ]; then
-  echo "MAX_ITER not specified, use default 100"
-  MAX_ITER=100
+  echo "MAX_ITER not specified, use default 60"
+  MAX_ITER=60
 fi
 
 if [ -z "$RUN_WITH_BROWSING" ]; then
openhands/critic/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
+from .base import BaseCritic, CriticResult
+from .finish_critic import AgentFinishedCritic
+
+__all__ = ['CriticResult', 'BaseCritic', 'AgentFinishedCritic']
openhands/critic/base.py (new file, 31 lines)
@@ -0,0 +1,31 @@
+import abc
+
+from pydantic import BaseModel
+
+from openhands.events import Event
+
+
+class CriticResult(BaseModel):
+    """
+    A critic result is a score and a message.
+    """
+
+    score: float
+    message: str
+
+    @property
+    def success(self) -> bool:
+        """
+        Whether the agent is successful.
+        """
+        return self.score >= 0.5
+
+
+class BaseCritic(abc.ABC):
+    """
+    A critic is a function that takes in a list of events and returns a score about the quality of those events.
+    """
+
+    @abc.abstractmethod
+    def evaluate(self, events: list[Event]) -> CriticResult:
+        pass
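The BaseCritic/CriticResult interface above is generic; AgentFinishedCritic below is the only critic this commit adds. As an illustration of how another rule-based critic could plug into the same interface, here is a hypothetical example (not part of the commit; the NoErrorCritic class and its scoring rule are made up):

from openhands.critic.base import BaseCritic, CriticResult
from openhands.events import Event
from openhands.events.observation import ErrorObservation


class NoErrorCritic(BaseCritic):
    """Hypothetical critic: succeeds when at most half of the events are ErrorObservations."""

    def evaluate(self, events: list[Event]) -> CriticResult:
        if not events:
            return CriticResult(score=0, message='Empty trajectory.')
        errors = sum(isinstance(e, ErrorObservation) for e in events)
        score = 1 - errors / len(events)  # success property checks score >= 0.5
        return CriticResult(score=score, message=f'{errors}/{len(events)} events were errors.')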
openhands/critic/finish_critic.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+from openhands.critic.base import BaseCritic, CriticResult
+from openhands.events import Event
+from openhands.events.action import Action, AgentFinishAction
+
+
+class AgentFinishedCritic(BaseCritic):
+    """This is a simple rule-based critic that checks if the last event is an AgentFinishAction.
+
+    If not, it will return a score of 0 and a message indicating that the agent did not finish.
+    """
+
+    def __init__(self):
+        pass
+
+    def evaluate(self, events: list[Event]) -> CriticResult:
+        last_action = next((h for h in reversed(events) if isinstance(h, Action)), None)
+
+        if isinstance(last_action, AgentFinishAction):
+            return CriticResult(score=1, message='Agent finished.')
+        else:
+            return CriticResult(score=0, message='Agent did not finish.')
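A minimal usage sketch for the new critic, not part of the commit. It assumes an environment where the openhands package is importable and that MessageAction and AgentFinishAction can be constructed as shown:

# Hypothetical usage of AgentFinishedCritic on two toy event histories.
from openhands.critic import AgentFinishedCritic
from openhands.events.action import AgentFinishAction, MessageAction

critic = AgentFinishedCritic()

finished_history = [MessageAction(content='Fix the bug.'), AgentFinishAction()]
unfinished_history = [MessageAction(content='Fix the bug.')]

print(critic.evaluate(finished_history).success)    # True  (score 1)
print(critic.evaluate(unfinished_history).success)  # False (score 0)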