Compare commits

...

1 Commits

Author SHA1 Message Date
openhands
81d4a80f7c Fix issue #6336: [Bug]: swe_bench/eval_infer.py type checking issues 2025-01-17 22:07:22 +00:00
3 changed files with 48 additions and 40 deletions

View File

@@ -1,8 +1,10 @@
import asyncio
import json
import os
import tempfile
import time
from functools import partial
from typing import Awaitable
import pandas as pd
from swebench.harness.grading import get_eval_report
@@ -32,8 +34,9 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime
from openhands.runtime.base import Runtime
from openhands.events.action import CmdRunAction
from openhands.events.observation import CmdOutputObservation
from openhands.events.observation import CmdOutputObservation, Observation
from openhands.utils.async_utils import call_async_from_sync
# TODO: migrate all swe-bench docker to ghcr.io/openhands
@@ -71,7 +74,7 @@ def process_git_patch(patch):
return patch
def get_config(instance: pd.Series) -> AppConfig:
def get_config(instance: pd.Series, metadata: EvalMetadata | None = None) -> AppConfig:
# We use a different instance image for the each instance of swe-bench eval
base_container_image = get_instance_docker_image(instance['instance_id'])
logger.info(
@@ -91,7 +94,7 @@ def get_config(instance: pd.Series) -> AppConfig:
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
remote_runtime_init_timeout=3600,
remote_runtime_resource_factor=get_instance_resource_factor(
dataset_name=metadata.dataset,
dataset_name=metadata.dataset if metadata else None,
instance_id=instance['instance_id'],
),
),
@@ -102,7 +105,7 @@ def get_config(instance: pd.Series) -> AppConfig:
return config
def process_instance(
async def process_instance(
instance: pd.Series,
metadata: EvalMetadata,
reset_logger: bool = True,
@@ -132,7 +135,7 @@ def process_instance(
else:
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
config = get_config(instance)
config = get_config(instance, metadata)
instance_id = instance.instance_id
model_patch = instance['model_patch']
test_spec: TestSpec = instance['test_spec']
@@ -167,28 +170,28 @@ def process_instance(
)
try:
runtime = create_runtime(config)
call_async_from_sync(runtime.connect)
runtime: Runtime = create_runtime(config)
await runtime.connect()
# Get patch and save it to /tmp/patch.diff
with tempfile.TemporaryDirectory() as temp_dir:
# Patch file
patch_file_path = os.path.join(temp_dir, 'patch.diff')
with open(patch_file_path, 'w') as f:
f.write(model_patch)
runtime.copy_to(patch_file_path, '/tmp')
await runtime.copy_to(patch_file_path, '/tmp')
# Eval script
eval_script_path = os.path.join(temp_dir, 'eval.sh')
with open(eval_script_path, 'w') as f:
f.write(test_spec.eval_script)
runtime.copy_to(eval_script_path, '/tmp')
await runtime.copy_to(eval_script_path, '/tmp')
# Set +x
action = CmdRunAction(command='chmod +x /tmp/eval.sh')
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
chmod_obs: Observation = await runtime.run_action(action)
logger.info(chmod_obs, extra={'msg_type': 'OBSERVATION'})
assert isinstance(chmod_obs, CmdOutputObservation) and chmod_obs.exit_code == 0
# Apply patch
exec_command = (
@@ -200,9 +203,9 @@ def process_instance(
)
action = CmdRunAction(command=exec_command)
action.set_hard_timeout(600)
obs = runtime.run_action(action)
assert isinstance(obs, CmdOutputObservation)
apply_patch_output = obs.content
patch_obs: Observation = await runtime.run_action(action)
assert isinstance(patch_obs, CmdOutputObservation)
apply_patch_output = patch_obs.content
assert isinstance(apply_patch_output, str)
instance['test_result']['apply_patch_output'] = apply_patch_output
@@ -222,10 +225,10 @@ def process_instance(
log_file = '/tmp/eval_output.log'
action = CmdRunAction(command=f'/tmp/eval.sh > {log_file} 2>&1 & echo $!')
action.set_hard_timeout(300) # Short timeout just to get the process ID
obs = runtime.run_action(action)
eval_obs: Observation = await runtime.run_action(action)
if isinstance(obs, CmdOutputObservation) and obs.exit_code == 0:
pid = obs.content.split()[-1].strip()
if isinstance(eval_obs, CmdOutputObservation) and eval_obs.exit_code == 0:
pid = eval_obs.content.split()[-1].strip()
logger.info(
f'[{instance_id}] Evaluation process started with PID: {pid}'
)
@@ -245,7 +248,7 @@ def process_instance(
command=f'ps -p {pid} > /dev/null; echo $?'
)
check_action.set_hard_timeout(300)
check_obs = runtime.run_action(check_action)
check_obs: Observation = await runtime.run_action(check_action)
if (
isinstance(check_obs, CmdOutputObservation)
and check_obs.content.split()[-1].strip() == '1'
@@ -257,12 +260,12 @@ def process_instance(
logger.info(
f'[{instance_id}] [{seconds_elapsed:.0f}s] Evaluation still running, waiting...'
)
time.sleep(30) # Wait for 30 seconds before checking again
await asyncio.sleep(30) # Wait for 30 seconds before checking again
# Read the log file
cat_action = CmdRunAction(command=f'cat {log_file}')
cat_action.set_hard_timeout(300)
cat_obs = runtime.run_action(cat_action)
cat_obs: Observation = await runtime.run_action(cat_action)
# Grade answer
if isinstance(cat_obs, CmdOutputObservation) and cat_obs.exit_code == 0:
@@ -305,7 +308,7 @@ def process_instance(
instance['test_result']['report']['resolved'] = False
instance['test_result']['report']['error_eval'] = True
else:
logger.info(f'[{instance_id}] Error when starting eval:\n{obs.content}')
logger.info(f'[{instance_id}] Error when starting eval:\n{eval_obs.content}')
instance['test_result']['report']['error_eval'] = True
return EvalOutput(
@@ -323,7 +326,10 @@ def process_instance(
logger,
)
finally:
runtime.close()
try:
await runtime.close()
except Exception:
pass
if __name__ == '__main__':

View File

@@ -18,7 +18,9 @@ DEFAULT_RUNTIME_RESOURCE_FACTOR = int(
_global_resource_mapping: dict[str, dict[str, float]] = {}
def get_resource_mapping(dataset_name: str) -> dict[str, float]:
def get_resource_mapping(dataset_name: str | None) -> dict[str, float] | None:
if dataset_name is None:
return None
if dataset_name not in _global_resource_mapping:
file_path = os.path.join(CUR_DIR, f'{dataset_name}.json')
if not os.path.exists(file_path):
@@ -31,7 +33,7 @@ def get_resource_mapping(dataset_name: str) -> dict[str, float]:
return _global_resource_mapping[dataset_name]
def get_instance_resource_factor(dataset_name: str, instance_id: str) -> int:
def get_instance_resource_factor(dataset_name: str | None, instance_id: str) -> int:
resource_mapping = get_resource_mapping(dataset_name)
if resource_mapping is None:
return DEFAULT_RUNTIME_RESOURCE_FACTOR

View File

@@ -133,7 +133,7 @@ class Runtime(FileEditRuntimeMixin):
if self.config.sandbox.runtime_startup_env_vars:
self.add_env_vars(self.config.sandbox.runtime_startup_env_vars)
def close(self) -> None:
async def close(self) -> None:
pass
def log(self, level: str, message: str) -> None:
@@ -211,7 +211,7 @@ class Runtime(FileEditRuntimeMixin):
source = event.source if event.source else EventSource.AGENT
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
def clone_repo(self, github_token: str, selected_repository: str) -> str:
async def clone_repo(self, github_token: str, selected_repository: str) -> str:
if not github_token or not selected_repository:
raise ValueError(
'github_token and selected_repository must be provided to clone a repository'
@@ -227,7 +227,7 @@ class Runtime(FileEditRuntimeMixin):
command=f'git clone {url} {dir_name} ; cd {dir_name} ; git checkout -b {branch_name}',
)
self.log('info', f'Cloning repo: {selected_repository}')
self.run_action(action)
await self.run_action(action)
return dir_name
def get_microagents_from_selected_repo(
@@ -310,7 +310,7 @@ class Runtime(FileEditRuntimeMixin):
return loaded_microagents
def run_action(self, action: Action) -> Observation:
async def run_action(self, action: Action) -> Observation:
"""Run an action and return the resulting observation.
If the action is not runnable in any runtime, a NullObservation is returned.
If the action is not supported by the current runtime, an ErrorObservation is returned.
@@ -347,8 +347,8 @@ class Runtime(FileEditRuntimeMixin):
def __enter__(self) -> 'Runtime':
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
self.close()
async def __exit__(self, exc_type, exc_value, traceback) -> None:
await self.close()
@abstractmethod
async def connect(self) -> None:
@@ -359,27 +359,27 @@ class Runtime(FileEditRuntimeMixin):
# ====================================================================
@abstractmethod
def run(self, action: CmdRunAction) -> Observation:
async def run(self, action: CmdRunAction) -> Observation:
pass
@abstractmethod
def run_ipython(self, action: IPythonRunCellAction) -> Observation:
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
pass
@abstractmethod
def read(self, action: FileReadAction) -> Observation:
async def read(self, action: FileReadAction) -> Observation:
pass
@abstractmethod
def write(self, action: FileWriteAction) -> Observation:
async def write(self, action: FileWriteAction) -> Observation:
pass
@abstractmethod
def browse(self, action: BrowseURLAction) -> Observation:
async def browse(self, action: BrowseURLAction) -> Observation:
pass
@abstractmethod
def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
pass
# ====================================================================
@@ -387,11 +387,11 @@ class Runtime(FileEditRuntimeMixin):
# ====================================================================
@abstractmethod
def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
async def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
raise NotImplementedError('This method is not implemented in the base class.')
@abstractmethod
def list_files(self, path: str | None = None) -> list[str]:
async def list_files(self, path: str | None = None) -> list[str]:
"""List files in the sandbox.
If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
@@ -399,7 +399,7 @@ class Runtime(FileEditRuntimeMixin):
raise NotImplementedError('This method is not implemented in the base class.')
@abstractmethod
def copy_from(self, path: str) -> Path:
async def copy_from(self, path: str) -> Path:
"""Zip all files in the sandbox and return a path in the local filesystem."""
raise NotImplementedError('This method is not implemented in the base class.')