From 3ca585b79f48034c406d4bde97e22ff066baf44c Mon Sep 17 00:00:00 2001 From: Yueqi Song <141804823+yueqis@users.noreply.github.com> Date: Thu, 15 May 2025 00:27:28 -0400 Subject: [PATCH] Update run_infer.py to incorporate selection of task based on repo (#8509) --- evaluation/benchmarks/swe_bench/run_infer.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/evaluation/benchmarks/swe_bench/run_infer.py b/evaluation/benchmarks/swe_bench/run_infer.py index ea84f2d191..887cf90f97 100644 --- a/evaluation/benchmarks/swe_bench/run_infer.py +++ b/evaluation/benchmarks/swe_bench/run_infer.py @@ -714,6 +714,19 @@ def filter_dataset(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame: subset = dataset[dataset[filter_column].isin(selected_ids)] logger.info(f'Retained {subset.shape[0]} tasks after filtering') return subset + if 'selected_repos' in data: + # repos for the swe-bench instances: + # ['astropy/astropy', 'django/django', 'matplotlib/matplotlib', 'mwaskom/seaborn', 'pallets/flask', 'psf/requests', 'pydata/xarray', 'pylint-dev/pylint', 'pytest-dev/pytest', 'scikit-learn/scikit-learn', 'sphinx-doc/sphinx', 'sympy/sympy'] + selected_repos = data['selected_repos'] + if isinstance(selected_repos, str): selected_repos = [selected_repos] + assert isinstance(selected_repos, list) + logger.info( + f'Filtering {selected_repos} tasks from "selected_repos"...' + ) + subset = dataset[dataset["repo"].isin(selected_repos)] + logger.info(f'Retained {subset.shape[0]} tasks after filtering') + return subset + skip_ids = os.environ.get('SKIP_IDS', '').split(',') if len(skip_ids) > 0: logger.info(f'Filtering {len(skip_ids)} tasks from "SKIP_IDS"...')