mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
1 Commit
test/remot
...
openhands-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
81d4a80f7c |
@@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Awaitable
|
||||
|
||||
import pandas as pd
|
||||
from swebench.harness.grading import get_eval_report
|
||||
@@ -32,8 +34,9 @@ from openhands.core.config import (
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.events.action import CmdRunAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.events.observation import CmdOutputObservation, Observation
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
# TODO: migrate all swe-bench docker to ghcr.io/openhands
|
||||
@@ -71,7 +74,7 @@ def process_git_patch(patch):
|
||||
return patch
|
||||
|
||||
|
||||
def get_config(instance: pd.Series) -> AppConfig:
|
||||
def get_config(instance: pd.Series, metadata: EvalMetadata | None = None) -> AppConfig:
|
||||
# We use a different instance image for each instance of swe-bench eval
|
||||
base_container_image = get_instance_docker_image(instance['instance_id'])
|
||||
logger.info(
|
||||
@@ -91,7 +94,7 @@ def get_config(instance: pd.Series) -> AppConfig:
|
||||
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
|
||||
remote_runtime_init_timeout=3600,
|
||||
remote_runtime_resource_factor=get_instance_resource_factor(
|
||||
dataset_name=metadata.dataset,
|
||||
dataset_name=metadata.dataset if metadata else None,
|
||||
instance_id=instance['instance_id'],
|
||||
),
|
||||
),
|
||||
@@ -102,7 +105,7 @@ def get_config(instance: pd.Series) -> AppConfig:
|
||||
return config
|
||||
|
||||
|
||||
def process_instance(
|
||||
async def process_instance(
|
||||
instance: pd.Series,
|
||||
metadata: EvalMetadata,
|
||||
reset_logger: bool = True,
|
||||
@@ -132,7 +135,7 @@ def process_instance(
|
||||
else:
|
||||
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
|
||||
|
||||
config = get_config(instance)
|
||||
config = get_config(instance, metadata)
|
||||
instance_id = instance.instance_id
|
||||
model_patch = instance['model_patch']
|
||||
test_spec: TestSpec = instance['test_spec']
|
||||
@@ -167,28 +170,28 @@ def process_instance(
|
||||
)
|
||||
|
||||
try:
|
||||
runtime = create_runtime(config)
|
||||
call_async_from_sync(runtime.connect)
|
||||
runtime: Runtime = create_runtime(config)
|
||||
await runtime.connect()
|
||||
# Get patch and save it to /tmp/patch.diff
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Patch file
|
||||
patch_file_path = os.path.join(temp_dir, 'patch.diff')
|
||||
with open(patch_file_path, 'w') as f:
|
||||
f.write(model_patch)
|
||||
runtime.copy_to(patch_file_path, '/tmp')
|
||||
await runtime.copy_to(patch_file_path, '/tmp')
|
||||
# Eval script
|
||||
eval_script_path = os.path.join(temp_dir, 'eval.sh')
|
||||
with open(eval_script_path, 'w') as f:
|
||||
f.write(test_spec.eval_script)
|
||||
runtime.copy_to(eval_script_path, '/tmp')
|
||||
await runtime.copy_to(eval_script_path, '/tmp')
|
||||
|
||||
# Set +x
|
||||
action = CmdRunAction(command='chmod +x /tmp/eval.sh')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert obs.exit_code == 0
|
||||
chmod_obs: Observation = await runtime.run_action(action)
|
||||
logger.info(chmod_obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert isinstance(chmod_obs, CmdOutputObservation) and chmod_obs.exit_code == 0
|
||||
|
||||
# Apply patch
|
||||
exec_command = (
|
||||
@@ -200,9 +203,9 @@ def process_instance(
|
||||
)
|
||||
action = CmdRunAction(command=exec_command)
|
||||
action.set_hard_timeout(600)
|
||||
obs = runtime.run_action(action)
|
||||
assert isinstance(obs, CmdOutputObservation)
|
||||
apply_patch_output = obs.content
|
||||
patch_obs: Observation = await runtime.run_action(action)
|
||||
assert isinstance(patch_obs, CmdOutputObservation)
|
||||
apply_patch_output = patch_obs.content
|
||||
assert isinstance(apply_patch_output, str)
|
||||
instance['test_result']['apply_patch_output'] = apply_patch_output
|
||||
|
||||
@@ -222,10 +225,10 @@ def process_instance(
|
||||
log_file = '/tmp/eval_output.log'
|
||||
action = CmdRunAction(command=f'/tmp/eval.sh > {log_file} 2>&1 & echo $!')
|
||||
action.set_hard_timeout(300) # Short timeout just to get the process ID
|
||||
obs = runtime.run_action(action)
|
||||
eval_obs: Observation = await runtime.run_action(action)
|
||||
|
||||
if isinstance(obs, CmdOutputObservation) and obs.exit_code == 0:
|
||||
pid = obs.content.split()[-1].strip()
|
||||
if isinstance(eval_obs, CmdOutputObservation) and eval_obs.exit_code == 0:
|
||||
pid = eval_obs.content.split()[-1].strip()
|
||||
logger.info(
|
||||
f'[{instance_id}] Evaluation process started with PID: {pid}'
|
||||
)
|
||||
@@ -245,7 +248,7 @@ def process_instance(
|
||||
command=f'ps -p {pid} > /dev/null; echo $?'
|
||||
)
|
||||
check_action.set_hard_timeout(300)
|
||||
check_obs = runtime.run_action(check_action)
|
||||
check_obs: Observation = await runtime.run_action(check_action)
|
||||
if (
|
||||
isinstance(check_obs, CmdOutputObservation)
|
||||
and check_obs.content.split()[-1].strip() == '1'
|
||||
@@ -257,12 +260,12 @@ def process_instance(
|
||||
logger.info(
|
||||
f'[{instance_id}] [{seconds_elapsed:.0f}s] Evaluation still running, waiting...'
|
||||
)
|
||||
time.sleep(30) # Wait for 30 seconds before checking again
|
||||
await asyncio.sleep(30) # Wait for 30 seconds before checking again
|
||||
|
||||
# Read the log file
|
||||
cat_action = CmdRunAction(command=f'cat {log_file}')
|
||||
cat_action.set_hard_timeout(300)
|
||||
cat_obs = runtime.run_action(cat_action)
|
||||
cat_obs: Observation = await runtime.run_action(cat_action)
|
||||
|
||||
# Grade answer
|
||||
if isinstance(cat_obs, CmdOutputObservation) and cat_obs.exit_code == 0:
|
||||
@@ -305,7 +308,7 @@ def process_instance(
|
||||
instance['test_result']['report']['resolved'] = False
|
||||
instance['test_result']['report']['error_eval'] = True
|
||||
else:
|
||||
logger.info(f'[{instance_id}] Error when starting eval:\n{obs.content}')
|
||||
logger.info(f'[{instance_id}] Error when starting eval:\n{eval_obs.content}')
|
||||
instance['test_result']['report']['error_eval'] = True
|
||||
|
||||
return EvalOutput(
|
||||
@@ -323,7 +326,10 @@ def process_instance(
|
||||
logger,
|
||||
)
|
||||
finally:
|
||||
runtime.close()
|
||||
try:
|
||||
await runtime.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -18,7 +18,9 @@ DEFAULT_RUNTIME_RESOURCE_FACTOR = int(
|
||||
_global_resource_mapping: dict[str, dict[str, float]] = {}
|
||||
|
||||
|
||||
def get_resource_mapping(dataset_name: str) -> dict[str, float]:
|
||||
def get_resource_mapping(dataset_name: str | None) -> dict[str, float] | None:
|
||||
if dataset_name is None:
|
||||
return None
|
||||
if dataset_name not in _global_resource_mapping:
|
||||
file_path = os.path.join(CUR_DIR, f'{dataset_name}.json')
|
||||
if not os.path.exists(file_path):
|
||||
@@ -31,7 +33,7 @@ def get_resource_mapping(dataset_name: str) -> dict[str, float]:
|
||||
return _global_resource_mapping[dataset_name]
|
||||
|
||||
|
||||
def get_instance_resource_factor(dataset_name: str, instance_id: str) -> int:
|
||||
def get_instance_resource_factor(dataset_name: str | None, instance_id: str) -> int:
|
||||
resource_mapping = get_resource_mapping(dataset_name)
|
||||
if resource_mapping is None:
|
||||
return DEFAULT_RUNTIME_RESOURCE_FACTOR
|
||||
|
||||
@@ -133,7 +133,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
if self.config.sandbox.runtime_startup_env_vars:
|
||||
self.add_env_vars(self.config.sandbox.runtime_startup_env_vars)
|
||||
|
||||
def close(self) -> None:
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
def log(self, level: str, message: str) -> None:
|
||||
@@ -211,7 +211,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
source = event.source if event.source else EventSource.AGENT
|
||||
self.event_stream.add_event(observation, source) # type: ignore[arg-type]
|
||||
|
||||
def clone_repo(self, github_token: str, selected_repository: str) -> str:
|
||||
async def clone_repo(self, github_token: str, selected_repository: str) -> str:
|
||||
if not github_token or not selected_repository:
|
||||
raise ValueError(
|
||||
'github_token and selected_repository must be provided to clone a repository'
|
||||
@@ -227,7 +227,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
command=f'git clone {url} {dir_name} ; cd {dir_name} ; git checkout -b {branch_name}',
|
||||
)
|
||||
self.log('info', f'Cloning repo: {selected_repository}')
|
||||
self.run_action(action)
|
||||
await self.run_action(action)
|
||||
return dir_name
|
||||
|
||||
def get_microagents_from_selected_repo(
|
||||
@@ -310,7 +310,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
|
||||
return loaded_microagents
|
||||
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
async def run_action(self, action: Action) -> Observation:
|
||||
"""Run an action and return the resulting observation.
|
||||
If the action is not runnable in any runtime, a NullObservation is returned.
|
||||
If the action is not supported by the current runtime, an ErrorObservation is returned.
|
||||
@@ -347,8 +347,8 @@ class Runtime(FileEditRuntimeMixin):
|
||||
def __enter__(self) -> 'Runtime':
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
self.close()
|
||||
async def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
await self.close()
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> None:
|
||||
@@ -359,27 +359,27 @@ class Runtime(FileEditRuntimeMixin):
|
||||
# ====================================================================
|
||||
|
||||
@abstractmethod
|
||||
def run(self, action: CmdRunAction) -> Observation:
|
||||
async def run(self, action: CmdRunAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def run_ipython(self, action: IPythonRunCellAction) -> Observation:
|
||||
async def run_ipython(self, action: IPythonRunCellAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def read(self, action: FileReadAction) -> Observation:
|
||||
async def read(self, action: FileReadAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def write(self, action: FileWriteAction) -> Observation:
|
||||
async def write(self, action: FileWriteAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def browse(self, action: BrowseURLAction) -> Observation:
|
||||
async def browse(self, action: BrowseURLAction) -> Observation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
async def browse_interactive(self, action: BrowseInteractiveAction) -> Observation:
|
||||
pass
|
||||
|
||||
# ====================================================================
|
||||
@@ -387,11 +387,11 @@ class Runtime(FileEditRuntimeMixin):
|
||||
# ====================================================================
|
||||
|
||||
@abstractmethod
|
||||
def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
|
||||
async def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False):
|
||||
raise NotImplementedError('This method is not implemented in the base class.')
|
||||
|
||||
@abstractmethod
|
||||
def list_files(self, path: str | None = None) -> list[str]:
|
||||
async def list_files(self, path: str | None = None) -> list[str]:
|
||||
"""List files in the sandbox.
|
||||
|
||||
If path is None, list files in the sandbox's initial working directory (e.g., /workspace).
|
||||
@@ -399,7 +399,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
raise NotImplementedError('This method is not implemented in the base class.')
|
||||
|
||||
@abstractmethod
|
||||
def copy_from(self, path: str) -> Path:
|
||||
async def copy_from(self, path: str) -> Path:
|
||||
"""Zip all files in the sandbox and return a path in the local filesystem."""
|
||||
raise NotImplementedError('This method is not implemented in the base class.')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user