mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-07 22:14:03 -05:00
71 lines
2.5 KiB
Python
71 lines
2.5 KiB
Python
import json
|
|
from pathlib import Path
|
|
from typing import cast
|
|
|
|
from datasets import Dataset, load_dataset
|
|
|
|
from evaluation.benchmarks.testgeneval.constants import (
|
|
KEY_INSTANCE_ID,
|
|
TestGenEvalInstance,
|
|
)
|
|
|
|
|
|
def get_test_directives(instance: TestGenEvalInstance) -> list:
    """Build the test directives (test targets) for a task instance.

    Args:
        instance (dict): task instance; its 'repo' entry is always read and
            its 'test_file' entry is read for every repo except
            'swe-bench/humaneval'.

    Returns:
        directives (list): single-element list holding the test directive:
            'test.py' for 'swe-bench/humaneval', a dotted module path for
            'django/django', and '/testbed/<test_file>' otherwise.
    """
    repo = instance['repo']

    # Seq2seq code repos always run the same fixed test file.
    if repo == 'swe-bench/humaneval':
        return ['test.py']

    test_file = instance['test_file']

    # Every repo other than Django addresses the test file by its path
    # under /testbed.
    if repo != 'django/django':
        return [f'/testbed/{test_file}']

    # Django tests are referenced by module: drop the '.py' extension and the
    # leading 'tests/' directory, then turn path separators into dots.
    module = test_file.removesuffix('.py').removeprefix('tests/')
    return [module.replace('/', '.')]
|
|
|
|
|
|
def load_testgeneval_dataset(
    name='kjain14/testgeneval', split='test', ids=None
) -> list[TestGenEvalInstance]:
    """Load the TestGenEval dataset from Hugging Face Datasets or a local file.

    Args:
        name: Hugging Face dataset name, one of the short aliases
            ('testgeneval', 'testgeneval-lite', 'testgenevallite', 'lite'),
            or a path ending in '.json' / '.jsonl' to read a local file.
        split: Split forwarded to ``load_dataset``; unused for local files.
        ids: Optional iterable of instance IDs to keep. A falsy value
            (``None`` or an empty collection) keeps every instance.

    Returns:
        The loaded instances, restricted to ``ids`` when it is given.

    Raises:
        ValueError: If ``ids`` contains an ID that is not in the dataset.
        json.JSONDecodeError: If a local file is not valid JSON / JSON Lines.
    """
    wanted = set(ids) if ids else None

    if name.endswith(('.json', '.jsonl')):
        # Load from local .json/.jsonl file
        dataset = _read_local_dataset(name)
        id_key = KEY_INSTANCE_ID
    else:
        # Load from Hugging Face Datasets, expanding the short aliases first
        alias = name.lower()
        if alias in {'testgeneval'}:
            name = 'kjain14/testgeneval'
        elif alias in {'testgeneval-lite', 'testgenevallite', 'lite'}:
            name = 'kjain14/testgenevallite'
        dataset = cast(Dataset, load_dataset(name, split=split))
        id_key = 'id'

    if wanted:
        # Only walk the dataset to collect its IDs when a filter was requested.
        dataset_ids = {instance[id_key] for instance in dataset}
        missing = wanted - dataset_ids
        if missing:
            # Sorted so the error message is deterministic.
            raise ValueError(
                'Some instance IDs not found in dataset!'
                f'\nMissing IDs:\n{" ".join(sorted(missing))}'
            )
        # NOTE(review): local files are validated against KEY_INSTANCE_ID above
        # but filtered on the literal 'id' key here (kept as in the original);
        # confirm both keys are meant to carry the same values.
        dataset = [instance for instance in dataset if instance['id'] in wanted]

    return [cast(TestGenEvalInstance, instance) for instance in dataset]


def _read_local_dataset(path: str) -> list:
    """Parse a local '.json' or '.jsonl' dataset file into its records."""
    # JSON interchange text is UTF-8; do not depend on the platform default.
    text = Path(path).read_text(encoding='utf-8')
    is_jsonl = path.endswith('.jsonl')
    try:
        # Whole-document parse first, so files that already loaded correctly
        # (including a JSON array saved under a '.jsonl' name) keep working.
        data = json.loads(text)
    except json.JSONDecodeError:
        if not is_jsonl:
            raise
        # Genuine JSON Lines: one JSON value per non-blank line.
        return [json.loads(line) for line in text.splitlines() if line.strip()]
    if is_jsonl and isinstance(data, dict):
        # A one-record JSON Lines file parses as a single object.
        return [data]
    return data
|