diff --git a/evaluation/benchmarks/swe_bench/loc_eval/README.md b/evaluation/benchmarks/swe_bench/loc_eval/README.md new file mode 100644 index 0000000000..3eb3a50fbe --- /dev/null +++ b/evaluation/benchmarks/swe_bench/loc_eval/README.md @@ -0,0 +1,45 @@ +# **Localization Evaluation for SWE-Bench** + +This folder implements localization evaluation at both file and function levels to complement the assessment of agent inference on [SWE-Bench](https://www.swebench.com/). + +## **1. Environment Setup** +- Python env: [Install Python environment](../../../README.md#development-environment) +- LLM config: [Configure your LLM](../../../README.md#configure-openhands-and-your-llm) + +## **2. Inference & Evaluation** +- Inference and evaluation follow the original `run_infer.sh` and `run_eval.sh` implementations + - You may refer to the instructions in [README.md](../README.md) for running inference and evaluation on SWE-Bench + +## **3. Localization Evaluation** +- Localization evaluation computes two-level localization accuracy, while also considering task success as an additional metric for overall evaluation: + - **File Localization Accuracy:** Accuracy of correctly localizing the target file + - **Function Localization Accuracy:** Accuracy of correctly localizing the target function + - **Resolve Rate** (auto-skipped if evaluation outputs are missing): Fraction of tasks that are successfully resolved + - **File Localization Efficiency:** Average number of iterations taken to successfully localize the target file + - **Function Localization Efficiency:** Average number of iterations taken to successfully localize the target function + - **Task Success Efficiency:** Average number of iterations taken to resolve the task + - **Resource Efficiency:** The API expenditure of the agent running inference on SWE-Bench instances + +- Run localization evaluation + - Format: + ```bash + ./evaluation/benchmarks/swe_bench/scripts/eval_localization.sh [infer-dir] [split] [dataset] [max-infer-turn] [align-with-max] + ``` + - `infer-dir`: inference directory containing inference outputs + - `split`: SWE-Bench dataset split to use + - `dataset`: SWE-Bench dataset name + - `max-infer-turn`: the maximum number of iterations allowed for the agent during inference + - `align-with-max`: whether to align failure indices (e.g., incorrect localization, unresolved tasks) with `max-infer-turn` + + - Example: + ```bash + # Example + ./evaluation/benchmarks/swe_bench/scripts/eval_localization.sh \ --infer-dir ./evaluation/evaluation_outputs/outputs/princeton-nlp__SWE-bench_Verified-test/CodeActAgent/gpt_4o_100_N \ --split test \ --dataset princeton-nlp/SWE-bench_Verified \ --max-infer-turn 100 \ --align-with-max true + ``` + +- Localization evaluation results will be automatically saved to `[infer-dir]/loc_eval` diff --git a/evaluation/benchmarks/swe_bench/loc_eval/__init__.py b/evaluation/benchmarks/swe_bench/loc_eval/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/evaluation/benchmarks/swe_bench/loc_eval/loc_evaluator.py b/evaluation/benchmarks/swe_bench/loc_eval/loc_evaluator.py new file mode 100644 index 0000000000..403ae446cb --- /dev/null +++ b/evaluation/benchmarks/swe_bench/loc_eval/loc_evaluator.py @@ -0,0 +1,1006 @@ +import argparse +import ast +import json +import os +import re + +import pandas as pd +from datasets import load_dataset +from tqdm import tqdm + +from evaluation.benchmarks.swe_bench.loc_eval.loc_utils import LocMeta +from evaluation.benchmarks.swe_bench.run_infer import filter_dataset
+from evaluation.utils.shared import prepare_dataset +from openhands.core.logger import openhands_logger as logger + + +class LocEvaluator: + def __init__(self, args): + """ + Localization evaluation. + + Args: + args: all main arguments + """ + # Config + self.args = args + self.eval_dir = args.eval_dir + self.eval_task_success = self._check_if_to_eval_success() + self.sandbox_root = '/workspace' + self.agent_turn_num = -1 + self.max_agent_turn = args.max_infer_turn + self.align_failed_with_max_iter = args.align_with_max + + # Data + self.instance = None + self.trajectory = None + + # Localization + self.localizer = LocMeta(args.dataset, args.split) + self.gold_loc = {'file': [], 'function': []} + self.agent_loc = { + 'gold loc': {'file': [], 'function': []}, + 'agent loc': {'file': [], 'function': []}, + 'turn index': {'file': [], 'function': []}, + 'loc progress': {'file': [], 'function': []}, + } + + # Task success tracking + self.task_resolved = False + + # Cost + self.cost_summary = {'total_cost': 0.0, 'avg_cost': 0.0, 'details': {}} + + # Save + self.save_dir = os.path.join(args.save_dir, 'loc_eval_results') + self._init_dir(self.save_dir) + self.all_eval_results = {} + self.overall_eval = {} + + def _init_config(self): + # Data + self.instance = None + self.gold_loc = {'file': [], 'function': []} + self.trajectory = None + self.agent_turn_num = -1 + + # Localization + self.agent_loc = { + 'gold loc': {'file': [], 'function': []}, + 'agent loc': {'file': [], 'function': []}, + 'turn index': {'file': [], 'function': []}, + 'loc progress': {'file': [], 'function': []}, + } + + # Task success tracking + self.task_resolved = False + + def _init_dir(self, directory_path): + """ + Check if a directory exists and create it if it doesn't. + + Args: + directory_path (str): Path to the directory to check/create + + Returns: + bool: True if directory already existed, False if it was created + """ + if os.path.exists(directory_path): + if not os.path.isdir(directory_path): + raise NotADirectoryError( + f'Path exists but is not a directory: {directory_path}' + ) + return True + else: + os.makedirs(directory_path) + return False + + def _check_if_to_eval_success(self): + """Check if post-evaluation outputs exist""" + if not os.path.isdir(self.eval_dir): + return False + else: + return True + + def _compute_avg_over_all(self): + """Compute average loc evaluations over all instances""" + macro_la_file, micro_la_file = 0, 0 + macro_la_func, micro_la_func = 0, 0 + resolve_rate = 0 + macro_avg_file_idx, macro_avg_func_idx = 0, 0 + micro_avg_file_idx, micro_avg_func_idx = 0, 0 + avg_resolve_idx = 0 + total_instance_num = len(self.all_eval_results) + + for instance_id in self.all_eval_results: + curr_eval_result = self.all_eval_results[instance_id]['final_eval'] + + # File + macro_la_file += curr_eval_result['localization']['loc_acc (%)'][ + 'la_file (%)' + ]['la_file_macro'] + micro_la_file += curr_eval_result['localization']['loc_acc (%)'][ + 'la_file (%)' + ]['la_file_micro'] + macro_avg_file_idx += curr_eval_result['localization']['turn_idx']['file'][ + 'macro' + ] + micro_avg_file_idx += curr_eval_result['localization']['turn_idx']['file'][ + 'micro' + ] + + # Function + macro_la_func += curr_eval_result['localization']['loc_acc (%)'][ + 'la_func (%)' + ]['la_func_macro'] + micro_la_func += curr_eval_result['localization']['loc_acc (%)'][ + 'la_func (%)' + ]['la_func_micro'] + macro_avg_func_idx += curr_eval_result['localization']['turn_idx'][ + 'function' + ]['macro'] + micro_avg_func_idx += 
curr_eval_result['localization']['turn_idx'][ + 'function' + ]['micro'] + + if self.eval_task_success: + if curr_eval_result['task_success']['resolved']: + resolve_rate += 1 + avg_resolve_idx += curr_eval_result['task_success']['resolve_index'] + else: + avg_resolve_idx += self.max_agent_turn + + # Average + macro_la_file = macro_la_file / total_instance_num + micro_la_file = micro_la_file / total_instance_num + macro_la_func = macro_la_func / total_instance_num + micro_la_func = micro_la_func / total_instance_num + macro_avg_file_idx = macro_avg_file_idx / total_instance_num + micro_avg_file_idx = micro_avg_file_idx / total_instance_num + macro_avg_func_idx = macro_avg_func_idx / total_instance_num + micro_avg_func_idx = micro_avg_func_idx / total_instance_num + + if self.eval_task_success: + resolve_rate = resolve_rate / total_instance_num * 100 + avg_resolve_idx = avg_resolve_idx / total_instance_num + + # Cost metric + total_cost, avg_cost = 0.0, 0.0 + for instance_key in self.cost_summary['details']: + total_cost += self.cost_summary['details'][instance_key] + avg_cost = total_cost / len(self.cost_summary['details']) + self.cost_summary['total_cost'] = total_cost + self.cost_summary['avg_cost'] = avg_cost + + self.overall_eval = { + 'la_file (%)': {'macro': macro_la_file, 'micro': micro_la_file}, + 'la_func (%)': {'macro': macro_la_func, 'micro': micro_la_func}, + 'resolve_rate (%)': resolve_rate if self.eval_task_success else None, + 'loc_file_idx (turn idx)': { + 'macro': macro_avg_file_idx, + 'micro': micro_avg_file_idx, + }, + 'loc_func_idx (turn idx)': { + 'macro': macro_avg_func_idx, + 'micro': micro_avg_func_idx, + }, + 'resolve_idx (turn idx)': avg_resolve_idx + if self.eval_task_success + else None, + 'max_turn_limit': self.max_agent_turn, + 'total_instance_num': total_instance_num, + 'cost_summary': self.cost_summary, + } + self._write_to_json(self.overall_eval, 'overall_eval.json') + + def _save_to_eval_dicts(self, agent_trajectory: dict): + # Current instancec + self._write_to_json( + agent_trajectory, f'loc__instance_{self.instance.instance_id}.json' + ) + + # All instances + self.all_eval_results[self.instance.instance_id] = agent_trajectory + self._write_to_json(self.all_eval_results, 'all_loc_evals.json') + + # Overall scores + self._compute_avg_over_all() + + def _write_to_json(self, data, file_name): + """ + Writes the current object data to a JSON file. + + Returns: + bool: True if writing was successful, False otherwise. + """ + try: + output_dir = os.path.join(self.save_dir, 'loc_acc') + os.makedirs(output_dir, exist_ok=True) + filepath = os.path.join(output_dir, file_name) + with open(filepath, 'w') as f: + json.dump(data, f, indent=4) + return True + except Exception as e: + logger.error(f'Error writing to JSON: {str(e)}') + return False + + def read_from_json(self, file_path): + """ + Reads data from a JSON file and loads it into the current object. + + Returns: + dict: The loaded JSON data, or an empty dict if the file doesn't exist + or an error occurs. + """ + try: + with open(file_path, 'r') as file: + data = json.load(file) + return data + except FileNotFoundError: + logger.warning( + f"Warning: File '{file_path}' not found. Returning an empty dictionary..." + ) + return {} + except json.JSONDecodeError: + logger.error( + f"Error: File '{file_path}' contains invalid JSON. Returning an empty dictionary..." + ) + return {} + except Exception as e: + logger.error( + f'Error reading from JSON: {str(e)}\nReturning an empty dictionary...' 
+ ) + return {} + + def read_from_jsonl(self, file_path): + """ + Reads data from a JSON file and loads it into the current object. + + Returns: + dict: The loaded JSON data, or an empty dict if the file doesn't exist + or an error occurs. + """ + try: + with open(file_path, 'r') as file: + data = json.load(file) + return data + except FileNotFoundError: + logger.warning( + f"Warning: File '{file_path}' not found. Returning an empty dictionary..." + ) + return {} + except json.JSONDecodeError: + logger.error( + f"Error: File '{file_path}' contains invalid JSON. Returning an empty dictionary..." + ) + return {} + except Exception as e: + logger.error( + f'Error reading from JSON: {str(e)}\nReturning an empty dictionary...' + ) + return {} + + def _parse_agent_turn_num(self): + """Get the max agent turn for current instance""" + history_idx = 1 + self.agent_turn_num = 0 + while history_idx < len(self.trajectory) - 1: + if ( + (self.trajectory[history_idx]['source'] == 'agent') + and ('action' in self.trajectory[history_idx].keys()) + and (self.trajectory[history_idx]['action'] != 'system') + ): + self.agent_turn_num += 1 + history_idx += 1 + + def _parse_string_to_dict(self, dict_string) -> dict: + """ + Convert a string representation of a dictionary to an actual dictionary. + + Args: + dict_string (str): String representation of a dictionary + + Returns: + dict or None: The parsed dictionary if successful, None if failed + """ + if not isinstance(dict_string, str): + return None + + dict_string = dict_string.strip() + + # (1) Try JSON parsing + try: + return json.loads(dict_string) + except (json.JSONDecodeError, ValueError): + pass + + # (1) Try ast parsing + try: + result = ast.literal_eval(dict_string) + if isinstance(result, dict): + return result + else: + return None + except (ValueError, SyntaxError): + pass + + # If both methods fail, return None + return None + + def _parse_value_from_args(self, argument_str: str, key: str) -> str: + """ + Parse a specific key's value from argument string. 
+ + Args: + argument_str (str): The argument string containing key-value pairs + key (str): The key to extract (e.g., "path", "new_str", "old_str") + + Returns: + str: The extracted value, or empty string if not found + """ + if not isinstance(argument_str, str) or not isinstance(key, str): + return '' + + try: + json_pattern = rf'"{re.escape(key)}"\s*:\s*"((?:[^"\\]|\\.)*)"`' + match = re.search(json_pattern, argument_str, re.DOTALL) + if match: + value = match.group(1) + value = ( + value.replace('\\"', '"') + .replace('\\n', '\n') + .replace('\\t', '\t') + .replace('\\\\', '\\') + ) + return value + + python_pattern = rf"'{re.escape(key)}'\s*:\s*'((?:[^'\\]|\\.)*)'" + match = re.search(python_pattern, argument_str, re.DOTALL) + if match: + value = match.group(1) + value = ( + value.replace("\\'", "'") + .replace('\\n', '\n') + .replace('\\t', '\t') + .replace('\\\\', '\\') + ) + return value + + if key in argument_str: + parts = argument_str.split(f'"{key}"', 1) + if len(parts) == 1: + parts = argument_str.split(f"'{key}'", 1) + + if len(parts) > 1: + remainder = parts[1].strip() + + for quote_char in ['"', "'"]: + pattern = f'\\s*:\\s*{quote_char}((?:[^{quote_char}\\\\]|\\\\.)*)(?:{quote_char}|$)' + match = re.search(pattern, remainder, re.DOTALL) + if match: + value = match.group(1) + if quote_char == '"': + value = ( + value.replace('\\"', '"') + .replace('\\n', '\n') + .replace('\\t', '\t') + .replace('\\\\', '\\') + ) + else: + value = ( + value.replace("\\'", "'") + .replace('\\n', '\n') + .replace('\\t', '\t') + .replace('\\\\', '\\') + ) + return value + + if key == 'path': + path_pattern = r'/[^\s,}"\']*' + match = re.search(path_pattern, remainder) + if match: + return match.group(0) + + return '' + + except Exception: + return '' + + def _parse_path_from_args(self, argument_str: str) -> str: + """ + Parse path from argument string. + + Args: + argument_str (str): The argument string containing path information + + Returns: + str: The extracted file path, or empty string if not found + """ + return self._parse_value_from_args(argument_str, 'path') + + def _parse_func_names_from_str(self, code_patch) -> list: + """ + Parse function names from the new_str code patch. 
+ + Args: + code_patch: Either a string (argument string) or already extracted new_str code + + Returns: + list: List of function names found in the code patch + """ + if not code_patch: + return [] + + try: + # Look for "def function_name(" patterns + # This pattern matches: + # - "def" followed by whitespace + # - function name (letters, numbers, underscores, also handle special methods like __len__) + # - opening parenthesis + func_pattern = r'\bdef\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(' + + matches = re.findall(func_pattern, code_patch) + + # Remove duplicates while preserving order + seen = set() + unique_funcs = [] + for func_name in matches: + if func_name not in seen: + seen.add(func_name) + unique_funcs.append(func_name) + + return unique_funcs + + except Exception: + return [] + + def _parse_loc_from_history(self, action_history: dict) -> list: + """Parse function name and file path""" + if not action_history: + logger.error('No action history provided.') + raise + + curr_turn_agent_loc = {} + + if action_history['action'] != 'edit': + return curr_turn_agent_loc + + agent_msg_list = action_history['tool_call_metadata']['model_response'][ + 'choices' + ] + agent_edit = { + 'create': ['file_text'], + 'str_replace': ['old_str', 'new_str'], + } + + for cho in agent_msg_list: + for func_dict in cho['message']['tool_calls']: + edit_args = func_dict['function']['arguments'] + edit_dict = self._parse_string_to_dict(edit_args) + + if edit_dict: + curr_command = edit_dict['command'] + agent_acts = agent_edit[curr_command] + + file_path = edit_dict.get('path', None) + func_names = [] + + for act in agent_acts: + code_patch = edit_dict.get(act, None) + func_names.extend(self._parse_func_names_from_str(code_patch)) + func_names = list(set(func_names)) + + else: + for new_act in list(agent_edit.values()): + if new_act in edit_args: + agent_acts = new_act + break + + file_path = self._parse_path_from_args(edit_args) + func_names = [] + + for act in agent_acts: + code_patch = edit_args.split(agent_acts)[-1].strip() + func_names.extend(self._parse_func_names_from_str(code_patch)) + func_names = list(set(func_names)) + + if file_path and len(file_path) > 0: + if func_names: + if file_path in curr_turn_agent_loc: + curr_turn_agent_loc[file_path].extend(func_names) + else: + curr_turn_agent_loc[file_path] = func_names + else: + curr_turn_agent_loc[file_path] = [] + + return curr_turn_agent_loc + + def _add_task_success_metric(self) -> bool: + """Task success evaluation result""" + self.task_resolved = False + report_pth = os.path.join( + self.eval_dir, self.instance.instance_id, 'report.json' + ) + eval_report = self.read_from_json(report_pth) + if self.instance.instance_id in eval_report.keys(): + self.task_resolved = eval_report[self.instance.instance_id]['resolved'] + + if self.task_resolved: + return { + 'resolved': self.task_resolved, + 'resolve_index': self.agent_turn_num, + } + + if self.align_failed_with_max_iter: + return { + 'resolved': self.task_resolved, + 'resolve_index': self.max_agent_turn, + } + else: + return { + 'resolved': self.task_resolved, + 'resolve_index': self.agent_turn_num, + } + + def eval_agent_trajectory(self): + """Evaluate agent's localization at current state""" + if not self.trajectory: + logger.warning( + f'Inference trajectory for current instance (instance ID: {self.instance.instance_id}) is None, skipping localization evaluation for current instance...' 
+ ) + return + + # Process history + agent_trajectory = {'final_eval': {}, 'trajectory': {}} + turn_idx = 0 + history_idx = 1 + + while history_idx < len(self.trajectory) - 2: + history_idx += 1 + action_history = self.trajectory[history_idx] + observ_history = self.trajectory[history_idx + 1] + + # Pass non-agent histories + if (action_history['source'] != 'agent') or ( + 'action' not in action_history.keys() + ): + continue + + # Parse action + turn_idx += 1 + curr_turn_agent_loc = self._parse_loc_from_history(action_history) + + agent_trajectory['trajectory'][f'turn {turn_idx}'] = { + 'loc_eval': None, + 'loc': curr_turn_agent_loc, + 'action': { + 'action': action_history['action'], + 'message': action_history['message'], + }, + 'observation': None, + } + + if 'observation' in observ_history.keys(): + agent_trajectory['trajectory'][f'turn {turn_idx}']['observation'] = { + 'observation': observ_history['observation'], + 'message': observ_history['message'], + } + + # Loc eval + if len(curr_turn_agent_loc) > 0: + for file_key in curr_turn_agent_loc: + for func_name in curr_turn_agent_loc[file_key]: + # File loc + if file_key in self.gold_loc['file']: + if file_key not in self.agent_loc['agent loc']['file']: + self.agent_loc['agent loc']['file'].append(file_key) + self.agent_loc['turn index']['file'][ + self.gold_loc['file'].index(file_key) + ] = turn_idx + self.agent_loc['loc progress']['file'][ + self.gold_loc['file'].index(file_key) + ] = True + + # Function loc + new_agent_loc = {'file': file_key, 'function': func_name} + if new_agent_loc in self.gold_loc['function']: + if ( + new_agent_loc + not in self.agent_loc['agent loc']['function'] + ): + self.agent_loc['agent loc']['function'].append( + new_agent_loc + ) + self.agent_loc['turn index']['function'][ + self.gold_loc['function'].index(new_agent_loc) + ] = turn_idx + self.agent_loc['loc progress']['function'][ + self.gold_loc['function'].index(new_agent_loc) + ] = True + + agent_trajectory['trajectory'][f'turn {turn_idx}']['loc_eval'] = ( + self.agent_loc + ) + + # Task success + agent_trajectory['final_eval'] = { + 'total turn': self.agent_turn_num, + 'max turn': self.max_agent_turn, + 'localization': { + 'loc_acc (%)': { + 'la_file (%)': { + 'la_file_micro': sum(self.agent_loc['loc progress']['file']) + / len(self.agent_loc['loc progress']['file']) + * 100, + 'la_file_macro': 100.0 + if sum(self.agent_loc['loc progress']['file']) > 0 + else 0.0, + }, + 'la_func (%)': { + 'la_func_micro': sum(self.agent_loc['loc progress']['function']) + / len(self.agent_loc['loc progress']['function']) + * 100, + 'la_func_macro': 100.0 + if sum(self.agent_loc['loc progress']['function']) > 0 + else 0.0, + }, + }, + 'turn_idx': { + 'file': { + 'micro': max(self.agent_loc['turn index']['file']), + 'macro': min(self.agent_loc['turn index']['file']), + }, + 'function': { + 'micro': max(self.agent_loc['turn index']['function']), + 'macro': min(self.agent_loc['turn index']['function']), + }, + }, + 'details': { + 'loc_file': self.agent_loc['loc progress']['file'], + 'loc_func': self.agent_loc['loc progress']['function'], + }, + }, + 'task_success': None, + } + + # Task success + if self.eval_task_success: + agent_trajectory['final_eval']['task_success'] = ( + self._add_task_success_metric() + ) + + # Align loc with success + if self.task_resolved: + if agent_trajectory['final_eval']['localization']['loc_acc (%)'] != { + 'la_file (%)': {'la_file_micro': 100.0, 'la_file_macro': 100.0}, + 'la_func (%)': {'la_func_micro': 100.0, 'la_func_macro': 100.0}, + 
}: + agent_trajectory['final_eval']['localization']['loc_acc (%)'] = { + 'la_file (%)': {'la_file_micro': 100.0, 'la_file_macro': 100.0}, + 'la_func (%)': {'la_func_micro': 100.0, 'la_func_macro': 100.0}, + } + agent_trajectory['final_eval']['localization']['details'] = { + 'loc_file': [ + True for i in range(len(self.agent_loc['loc progress']['file'])) + ], + 'loc_func': [ + True + for i in range(len(self.agent_loc['loc progress']['function'])) + ], + } + + if self.align_failed_with_max_iter: + for level1 in agent_trajectory['final_eval']['localization'][ + 'turn_idx' + ]: + for level2 in agent_trajectory['final_eval']['localization'][ + 'turn_idx' + ][level1]: + if ( + agent_trajectory['final_eval']['localization']['turn_idx'][ + level1 + ][level2] + > self.agent_turn_num + ): + agent_trajectory['final_eval']['localization']['turn_idx'][ + level1 + ][level2] = self.agent_turn_num + + # Save + self._save_to_eval_dicts(agent_trajectory) + + def _get_instance_gt_loc(self): + """Get ground-truth localization for current instance""" + gt_localization = self.localizer.parse_instance_loc(self.instance) + + # Convert to dict + gt_loc_dict = gt_localization['patch'].to_dict() + assert gt_loc_dict['instance_id'] == self.instance.instance_id + self.gold_loc = { + 'gt_loc_dict': gt_loc_dict['functions'], + 'file': [], + 'function': [], + } + + for file_key in gt_loc_dict['functions']: + if len(gt_loc_dict['functions'][file_key]) == 0: + continue + + # File + if file_key not in self.gold_loc['file']: + self.gold_loc['file'].append(f'{self.sandbox_root}/{file_key}') + + # Function + for func_name in gt_loc_dict['functions'][file_key]: + new_gt = { + 'file': f'{self.sandbox_root}/{file_key}', + 'function': func_name, + } + self.gold_loc['function'].append(new_gt) + + # Init agent loc accordingly + init_turn = ( + self.max_agent_turn + if self.align_failed_with_max_iter + else self.agent_turn_num + ) + self.agent_loc['gold loc'] = { + 'file': self.gold_loc['file'], + 'function': self.gold_loc['function'], + } + self.agent_loc['turn index']['file'] = [ + init_turn for i in range(len(self.gold_loc['file'])) + ] + self.agent_loc['turn index']['function'] = [ + init_turn for i in range(len(self.gold_loc['function'])) + ] + self.agent_loc['loc progress']['file'] = [ + False for i in range(len(self.gold_loc['file'])) + ] + self.agent_loc['loc progress']['function'] = [ + False for i in range(len(self.gold_loc['function'])) + ] + + def instance_loc_eval( + self, + instance: pd.Series = None, + repo_root: str = None, + trajectory: list = None, + infer_cost: dict = None, + ): + if instance is None: + logger.error( + 'No instance provided. Skipping current localization evaluation...' + ) + if trajectory is None: + logger.error( + f'No inference trajectory provided for current instance with ID: {instance.instance_id}' + ) + if infer_cost is None: + logger.error( + f'No inference accumulated cost for current instance with ID: {instance.instance_id}' + ) + + # Init + self._init_config() + self.cost_summary['details'][instance.instance_id] = infer_cost + + # Update current instance + self.instance = instance + self.trajectory = trajectory + self.sandbox_root = repo_root + + # Max turn + self._parse_agent_turn_num() + + # GT loc + self._get_instance_gt_loc() + + # Loc evaluation + self.eval_agent_trajectory() + + +def swe_data_loader(args): + """ + Loading SWE-Bench data. + + Args: + args: Main arguments. 
+ """ + dataset = load_dataset(args.dataset, split=args.split) + swe_bench_tests = filter_dataset(dataset.to_pandas(), 'instance_id') + logger.info( + f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks' + ) + if 'SWE-Gym' in args.dataset: + with open( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'split', + 'swegym_verified_instances.json', + ), + 'r', + ) as f: + swegym_verified_instances = json.load(f) + swe_bench_tests = swe_bench_tests[ + swe_bench_tests['instance_id'].isin(swegym_verified_instances) + ] + logger.info( + f'{len(swe_bench_tests)} tasks left after filtering for SWE-Gym verified instances' + ) + + instances = prepare_dataset(swe_bench_tests, args.swe_output_file, -1) + return instances + + +def infer_data_loader(args): + """ + Load instance IDs. + + Args: + args: Main arguments. + + Returns: + list: A list of instance IDs (strings) extracted from JSON filenames + in the histories directory. + + Raises: + FileNotFoundError: If the histories directory doesn't exist. + AttributeError: If args doesn't have a 'infer_dir' attribute. + """ + infer_output_filepath = os.path.join(args.infer_dir, 'output.jsonl') + + infer_outputs = [] + with open(infer_output_filepath, 'r') as file: + for line_num, line in enumerate(file, 1): + line = line.strip() + if line: + try: + json_obj = json.loads(line) + infer_outputs.append(json_obj) + except json.JSONDecodeError as e: + logger.error( + f"Error parsing JSON on line {line_num} in '{infer_output_filepath}': {str(e)}" + ) + continue + + return infer_outputs + + +def infer_cost_calculator(args): + """ + Calculate total and average costs from metric JSON files with detailed output. + + Args: + args: Main arguments. + + Returns: + dict: A dictionary containing: + - 'total_cost': Sum of all accumulated costs + - 'average_cost': Average cost per JSON file + - 'file_count': Number of JSON files processed + - 'individual_costs': List of individual costs (optional) + """ + metrics_dir = os.path.join(args.infer_dir, 'metrics') + + if not os.path.exists(metrics_dir): + raise FileNotFoundError(f'Metrics directory not found: {metrics_dir}') + + individual_costs = [] + + for filename in os.listdir(metrics_dir): + if filename.endswith('.json'): + file_path = os.path.join(metrics_dir, filename) + + try: + with open(file_path, 'r', encoding='utf-8') as file: + metric_data = json.load(file) + + if 'accumulated_cost' not in metric_data: + raise KeyError(f"'accumulated_cost' not found in {filename}") + + cost = float(metric_data['accumulated_cost']) + individual_costs.append(cost) + + except (json.JSONDecodeError, ValueError, TypeError, IOError) as e: + logger.warning(f'Warning: Error processing {filename}: {e}') + continue + + if not individual_costs: + raise ValueError('No valid JSON files found in the metrics directory') + + total_cost = sum(individual_costs) + average_cost = total_cost / len(individual_costs) + + return { + 'total_cost': total_cost, + 'average_cost': average_cost, + 'file_count': len(individual_costs), + 'individual_costs': individual_costs, + } + + +if __name__ == '__main__': + """Main function for localization evaluation""" + parser = argparse.ArgumentParser( + description='Localization evaluation on SWE-Bench.' 
+ ) + + parser.add_argument( + '--infer-dir', + type=str, + default=None, + help='Directory containing model inference outputs', + ) + parser.add_argument( + '--dataset', type=str, default=None, help='SWE-Bench dataset version' + ) + parser.add_argument( + '--split', type=str, default=None, help='SWE-Bench dataset split selection' + ) + parser.add_argument( + '--max-infer-turn', + type=int, + default=None, + help='Max number of turns allowed for coding agent.', + ) + parser.add_argument( + '--align-with-max', + type=str, + choices=['true', 'false'], + default='true', + help='Whether to align failed instances with max iteration count (true/false)', + ) + + args = parser.parse_args() + + # Convert args.align_with_max str to bool + args.align_with_max = args.align_with_max.lower() == 'true' + + # Eval infer and loc + args.save_dir = f'{args.infer_dir}/loc_eval' + os.makedirs(args.save_dir, exist_ok=True) + args.eval_dir = f'{args.infer_dir}/eval_outputs' + if not os.path.isdir(args.eval_dir): + args.eval_dir = None + + # SWE-Bench + args.swe_output_file = os.path.join(args.save_dir, 'swe_dataset.json') + + # Load swebench data + swe_instances = swe_data_loader(args) + + # Load inference data + infer_outputs = infer_data_loader(args) + + # Loc eval + processed_instances = [] + loc_eval_results = {} + loc_evaluator = LocEvaluator(args) + + for infer_idx, infer_instance in tqdm( + enumerate(infer_outputs), total=len(infer_outputs), desc='Processing instances' + ): + instance_id = infer_instance['instance_id'] + swe_instance = swe_instances.query(f"instance_id == '{instance_id}'").iloc[0] + assert instance_id == swe_instance.instance_id + + processed_instances.append(instance_id) + upload_instruction = infer_instance['instruction'] + repo_root = ( + upload_instruction.split('<uploaded_files>')[1] + .split('</uploaded_files>')[0] + .strip() + ) + curr_trajectory = infer_instance['history'] + curr_cost = infer_instance['metrics']['accumulated_cost'] + loc_evaluator.instance_loc_eval( + swe_instance, repo_root, curr_trajectory, curr_cost + ) + + logger.info( + f'\n[Inference Data Summary]' + f'\n{" " * 4} - Total cost: $ {loc_evaluator.cost_summary["total_cost"]}' + f'\n{" " * 4} - Average cost: $ {loc_evaluator.cost_summary["avg_cost"]}' + f'\n{" " * 4} - Number of Instances: {len(processed_instances)}' + ) diff --git a/evaluation/benchmarks/swe_bench/loc_eval/loc_utils.py b/evaluation/benchmarks/swe_bench/loc_eval/loc_utils.py new file mode 100644 index 0000000000..e290354d9d --- /dev/null +++ b/evaluation/benchmarks/swe_bench/loc_eval/loc_utils.py @@ -0,0 +1,1110 @@ +import ast +import logging +import re +from dataclasses import dataclass +from typing import Any, Union + +import pandas as pd +from datasets import load_dataset + +from openhands.runtime.base import Runtime + + +@dataclass +class LocalizationInfo: + """Container for ground-truth localization information""" + + instance_id: str # SWE-Bench instance identifier + files: list[str] # List of modified files + file_line_ranges: dict[ + str, list[tuple[int, int]] + ] # File -> [(start_line, end_line), ...] + functions: dict[str, list[str]] # File -> [function_names, ...] + classes: dict[str, list[str]] # File -> [class_names, ...]
+ line_to_function: dict[str, dict[int, str]] # File -> {line_num: function_name} + line_to_class: dict[str, dict[int, str]] # File -> {line_num: class_name} + total_lines_changed: int + total_files_changed: int + hunks_per_file: dict[str, int] # File -> number of hunks + + def to_dict(self) -> dict[str, Any]: + """ + Convert LocalizationInfo to a dictionary for JSON serialization. + + Returns: + Dictionary representation of the localization information + """ + return { + 'instance_id': self.instance_id, + 'files': self.files, + 'file_line_ranges': { + file: [[start, end] for start, end in ranges] + for file, ranges in self.file_line_ranges.items() + }, + 'functions': self.functions, + 'classes': self.classes, + 'line_to_function': { + file: {str(line): func for line, func in mapping.items()} + for file, mapping in self.line_to_function.items() + }, + 'line_to_class': { + file: {str(line): cls for line, cls in mapping.items()} + for file, mapping in self.line_to_class.items() + }, + 'total_lines_changed': self.total_lines_changed, + 'total_files_changed': self.total_files_changed, + 'hunks_per_file': self.hunks_per_file, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> 'LocalizationInfo': + """ + Create LocalizationInfo from a dictionary (for loading from JSON). + + Args: + data: Dictionary containing localization information + + Returns: + LocalizationInfo object + """ + return cls( + instance_id=data['instance_id'], + files=data['files'], + file_line_ranges={ + file: [(start, end) for start, end in ranges] + for file, ranges in data['file_line_ranges'].items() + }, + functions=data['functions'], + classes=data['classes'], + line_to_function={ + file: {int(line): func for line, func in mapping.items()} + for file, mapping in data['line_to_function'].items() + }, + line_to_class={ + file: {int(line): cls for line, cls in mapping.items()} + for file, mapping in data['line_to_class'].items() + }, + total_lines_changed=data['total_lines_changed'], + total_files_changed=data['total_files_changed'], + hunks_per_file=data['hunks_per_file'], + ) + + +class LocMeta: + """ + SWE-Bench dataset loader and ground-truth localization parser. + + This class handles loading SWE-Bench datasets and extracting ground-truth + localization information from patches for code localization evaluation. + Works with both standalone Docker containers and OpenHands runtime. + """ + + def __init__( + self, + dataset_name: str = 'princeton-nlp/SWE-bench_Verified', + split: str = 'test', + ): + """ + Initialize LocMeta with a SWE-Bench dataset. + + Args: + dataset_name: HuggingFace dataset name (e.g., "princeton-nlp/SWE-bench_Verified") + """ + self.dataset_name = dataset_name + self.dataset = None + self.split = split + self.df = None + self.instance_lookup = {} + + # Set up logging + logging.basicConfig(level=logging.INFO) + self.logger = logging.getLogger(__name__) + + # Initialize dataset + self._init_swe_dataset() + + def _init_swe_dataset(self) -> None: + """ + Load and initialize the SWE-Bench dataset from HuggingFace. + Converts to pandas DataFrame for easy manipulation. 
+ """ + try: + self.logger.info(f'Loading dataset: {self.dataset_name}') + + # Load dataset from HuggingFace + self.dataset = load_dataset(self.dataset_name, split=self.split) + + # Convert to pandas DataFrame + self.df = pd.DataFrame(self.dataset) + + # Create lookup dictionary for fast instance access + self.instance_lookup = { + row['instance_id']: idx for idx, row in self.df.iterrows() + } + + self.logger.info(f'Successfully loaded {len(self.df)} instances') + self.logger.info(f'Available columns: {list(self.df.columns)}') + + except Exception as e: + self.logger.error(f'Failed to load dataset {self.dataset_name}: {e}') + raise + + def get_instance_by_id(self, instance_id: str) -> pd.Series: + """ + Retrieve a specific instance by its ID. + + Args: + instance_id: The instance identifier + + Returns: + pandas Series containing the instance data + + Raises: + KeyError: If instance_id is not found + """ + if instance_id not in self.instance_lookup: + raise KeyError(f"Instance ID '{instance_id}' not found in dataset") + + idx = self.instance_lookup[instance_id] + return self.df.iloc[idx] + + def parse_instance_loc(self, instance: Union[pd.Series, str]) -> LocalizationInfo: + """ + Parse ground-truth localization information from a SWE-Bench instance. + + Args: + instance: Either a pandas Series with instance data or an instance_id string + + Returns: + LocalizationInfo object containing extracted localization data + """ + # Handle different input types + if isinstance(instance, str): + # instance is actually an instance_id + actual_instance_id = instance + instance = self.get_instance_by_id(actual_instance_id) + else: + # instance is a pandas Series + actual_instance_id = instance.get('instance_id', 'unknown') + + self.logger.info(f'Parsing localization for instance: {actual_instance_id}') + + # Extract patch content + patch_content = instance.get('patch', '') + if not patch_content: + self.logger.warning( + f'No patch content found for instance {actual_instance_id}' + ) + patch_loc_info = self._empty_localization_info(actual_instance_id) + else: + patch_loc_info = self._parse_patch_localization( + patch_content, actual_instance_id + ) + + # Extract test patch content + patch_content = instance.get('test_patch', '') + if not patch_content: + self.logger.warning( + f'No test patch content found for instance {actual_instance_id}' + ) + test_patch_loc_info = self._empty_localization_info(actual_instance_id) + else: + test_patch_loc_info = self._parse_patch_localization( + patch_content, actual_instance_id + ) + + return {'patch': patch_loc_info, 'test_patch': test_patch_loc_info} + + def _parse_file_patch_lines( + self, file_patch: str + ) -> tuple[list[tuple[int, int]], int, int]: + """ + Parse line ranges and count changes from a single file patch. 
+ + Args: + file_patch: Patch content for a single file + + Returns: + Tuple of (line_ranges, total_lines_changed, num_hunks) + """ + line_ranges = [] + lines_changed = 0 + num_hunks = 0 + + lines = file_patch.split('\n') + + for line in lines: + # Match hunk headers: @@ -start,count +start,count @@ + hunk_match = re.match( + r'@@\s+-(\d+)(?:,(\d+))?\s+\+(\d+)(?:,(\d+))?\s+@@', line + ) + if hunk_match: + num_hunks += 1 + new_start = int(hunk_match.group(3)) + new_count = int(hunk_match.group(4)) if hunk_match.group(4) else 1 + + # For localization purposes, we consider the entire hunk range as potentially affected + if new_count > 0: + line_ranges.append((new_start, new_start + new_count - 1)) + lines_changed += new_count + + return line_ranges, lines_changed, num_hunks + + def _parse_code_structures_from_patch( + self, file_patch: str, file_path: str + ) -> tuple[list[str], list[str]]: + """ + Extract function and class names from patch context (fallback method). + + Args: + file_patch: Patch content for a single file + file_path: Path to the file being patched + + Returns: + Tuple of (function_names, class_names) + """ + functions = set() + classes = set() + + # Only attempt Python AST parsing for Python files + if not file_path.endswith('.py'): + return list(functions), list(classes) + + lines = file_patch.split('\n') + + for line in lines: + # Check for function names in hunk headers + # Format: @@ -start,count +start,count @@ [optional context like "def function_name"] + hunk_match = re.match(r'@@.*?@@\s*(.*)', line) + if hunk_match: + context = hunk_match.group(1).strip() + if context: + # Look for function definition in context + func_match = re.search(r'def\s+([a-zA-Z_][a-zA-Z0-9_]*)', context) + if func_match: + functions.add(func_match.group(1)) + + # Look for class definition in context + class_match = re.search( + r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)', context + ) + if class_match: + classes.add(class_match.group(1)) + + # Look for function and class definitions in the patch content + stripped_line = line.lstrip('+-@ ') + + # Match function definitions + func_match = re.match(r'def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', stripped_line) + if func_match: + functions.add(func_match.group(1)) + + # Match class definitions + class_match = re.match( + r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*[\(:]', stripped_line + ) + if class_match: + classes.add(class_match.group(1)) + + return list(functions), list(classes) + + def _parse_patch_localization( + self, patch_content: str, instance_id: str + ) -> LocalizationInfo: + """ + Parse localization information from a git patch (improved method). 
+ + Args: + patch_content: The git patch content + instance_id: Instance ID for logging + + Returns: + LocalizationInfo object with extracted data + """ + files = [] + file_line_ranges = {} + functions = {} + classes = {} + line_to_function = {} + line_to_class = {} + hunks_per_file = {} + total_lines_changed = 0 + + # Split patch into individual file patches + file_patches = self._split_patch_by_files(patch_content) + + for file_path, file_patch in file_patches.items(): + files.append(file_path) + + # Parse line ranges and count changes + line_ranges, lines_changed, num_hunks = self._parse_file_patch_lines( + file_patch + ) + file_line_ranges[file_path] = line_ranges + total_lines_changed += lines_changed + hunks_per_file[file_path] = num_hunks + + # Extract function and class names from patch context and content + file_functions, file_classes = self._extract_code_structures_from_patch( + file_patch, file_path + ) + + functions[file_path] = file_functions + classes[file_path] = file_classes + + # Create basic line-to-function/class mapping + line_func_map = {} + line_class_map = {} + + # Get all affected lines + affected_lines = [] + for start, end in line_ranges: + affected_lines.extend(range(start, end + 1)) + + # Simple mapping - this is the best we can do without the actual source code + # In a more sophisticated implementation, you'd want to parse the actual source files + if file_functions and affected_lines: + # Map to the first function found (could be improved with better heuristics) + for line_num in affected_lines: + if file_functions: + line_func_map[line_num] = file_functions[0] + if file_classes: + line_class_map[line_num] = file_classes[0] + + line_to_function[file_path] = line_func_map + line_to_class[file_path] = line_class_map + + return LocalizationInfo( + instance_id=instance_id, + files=files, + file_line_ranges=file_line_ranges, + functions=functions, + classes=classes, + line_to_function=line_to_function, + line_to_class=line_to_class, + total_lines_changed=total_lines_changed, + total_files_changed=len(files), + hunks_per_file=hunks_per_file, + ) + + def _extract_code_structures_from_patch( + self, file_patch: str, file_path: str + ) -> tuple[list[str], list[str]]: + """ + Extract function and class names from patch context and content. 
+ + Args: + file_patch: Patch content for a single file + file_path: Path to the file being patched + + Returns: + Tuple of (function_names, class_names) + """ + functions = set() + classes = set() + + # Process Python and Cython files + if not (file_path.endswith('.py') or file_path.endswith('.pyx')): + return list(functions), list(classes) + + lines = file_patch.split('\n') + + # Debug: Print some patch content for analysis + self.logger.info(f'Analyzing patch for {file_path}') + self.logger.info(f'Patch has {len(lines)} lines') + + for line in lines: + # Check for function names in hunk headers with context + # Format: @@ -start,count +start,count @@ [optional context like "def function_name"] + hunk_match = re.match(r'@@.*?@@\s*(.*)', line) + if hunk_match: + context = hunk_match.group(1).strip() + self.logger.info(f"Found hunk context: '{context}'") + if context: + # Look for function definition in context + func_match = re.search( + r'(?:def|async\s+def|cdef\s+\w*\s+|cpdef\s+\w*\s+)\s*([a-zA-Z_][a-zA-Z0-9_]*)', + context, + ) + if func_match: + func_name = func_match.group(1) + functions.add(func_name) + self.logger.info(f'Found function in hunk context: {func_name}') + + # Look for class definition in context + class_match = re.search( + r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)', context + ) + if class_match: + class_name = class_match.group(1) + classes.add(class_name) + self.logger.info(f'Found class in hunk context: {class_name}') + + # Look for function and class definitions in the patch content + # Check both added and removed lines, and context lines + if line.startswith(('+', '-', ' ')): + stripped_line = line[1:].strip() # Remove +/- prefix and whitespace + + # Match function definitions (including async and cdef for Cython) + func_match = re.match( + r'(?:async\s+|cdef\s+)?def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', + stripped_line, + ) + if func_match: + func_name = func_match.group(1) + functions.add(func_name) + self.logger.info(f'Found function in patch content: {func_name}') + + # Match Cython cdef functions + cdef_func_match = re.match( + r'cdef\s+[^(]*\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', stripped_line + ) + if cdef_func_match: + func_name = cdef_func_match.group(1) + functions.add(func_name) + self.logger.info( + f'Found cdef function in patch content: {func_name}' + ) + + # Match class definitions + class_match = re.match( + r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*[\(:]', stripped_line + ) + if class_match: + class_name = class_match.group(1) + classes.add(class_name) + self.logger.info(f'Found class in patch content: {class_name}') + + # Also check lines without prefixes (context lines in some patch formats) + elif line.strip() and not line.startswith( + ('@@', 'diff', '---', '+++', 'index') + ): + stripped_line = line.strip() + + # Match function definitions + func_match = re.match( + r'(?:async\s+|cdef\s+)?def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', + stripped_line, + ) + if func_match: + func_name = func_match.group(1) + functions.add(func_name) + self.logger.info(f'Found function in context line: {func_name}') + + # Match Cython cdef functions + cdef_func_match = re.match( + r'cdef\s+[^(]*\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', stripped_line + ) + if cdef_func_match: + func_name = cdef_func_match.group(1) + functions.add(func_name) + self.logger.info( + f'Found cdef function in context line: {func_name}' + ) + + # Match class definitions + class_match = re.match( + r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*[\(:]', stripped_line + ) + if class_match: + class_name = class_match.group(1) + 
classes.add(class_name) + self.logger.info(f'Found class in context line: {class_name}') + + self.logger.info( + f'Final results for {file_path}: functions={list(functions)}, classes={list(classes)}' + ) + return list(functions), list(classes) + + def _parse_patch_localization_with_runtime( + self, patch_content: str, instance_id: str, runtime: Runtime + ) -> LocalizationInfo: + """ + Parse localization information from a git patch using OpenHands runtime. + This is the superior method when runtime is available. + + Args: + patch_content: The git patch content + instance_id: Instance ID for logging + runtime: OpenHands runtime object + + Returns: + LocalizationInfo object with extracted data + """ + files = [] + file_line_ranges = {} + functions = {} + classes = {} + line_to_function = {} + line_to_class = {} + hunks_per_file = {} + total_lines_changed = 0 + + # Split patch into individual file patches + file_patches = self._split_patch_by_files(patch_content) + + for file_path, file_patch in file_patches.items(): + files.append(file_path) + + # Parse line ranges and count changes + line_ranges, lines_changed, num_hunks = self._parse_file_patch_lines( + file_patch + ) + file_line_ranges[file_path] = line_ranges + total_lines_changed += lines_changed + hunks_per_file[file_path] = num_hunks + + # Get all affected line numbers + affected_lines = [] + for start, end in line_ranges: + affected_lines.extend(range(start, end + 1)) + + # Analyze source code using OpenHands runtime for accurate function/class mapping + if affected_lines and ( + file_path.endswith('.py') or file_path.endswith('.pyx') + ): + file_functions, file_classes, line_func_map, line_class_map = ( + self._analyze_source_code_with_runtime( + runtime, file_path, affected_lines + ) + ) + else: + # Fallback to patch-based extraction for non-Python/Cython files or when no lines affected + file_functions, file_classes = self._extract_code_structures_from_patch( + file_patch, file_path + ) + line_func_map, line_class_map = {}, {} + + functions[file_path] = file_functions + classes[file_path] = file_classes + line_to_function[file_path] = line_func_map + line_to_class[file_path] = line_class_map + + return LocalizationInfo( + instance_id=instance_id, + files=files, + file_line_ranges=file_line_ranges, + functions=functions, + classes=classes, + line_to_function=line_to_function, + line_to_class=line_to_class, + total_lines_changed=total_lines_changed, + total_files_changed=len(files), + hunks_per_file=hunks_per_file, + ) + + def parse_instance_loc_with_runtime( + self, instance: Union[pd.Series, str], runtime: Runtime = None + ) -> LocalizationInfo: + """ + Parse ground-truth localization information using OpenHands runtime. 
+ + Args: + instance: Either a pandas Series with instance data or an instance_id string + runtime: OpenHands runtime object + + Returns: + LocalizationInfo object containing extracted localization data + """ + # Handle different input types + if isinstance(instance, str): + # instance is actually an instance_id + actual_instance_id = instance + instance = self.get_instance_by_id(actual_instance_id) + else: + # instance is a pandas Series + actual_instance_id = instance.get('instance_id', 'unknown') + + self.logger.info( + f'Parsing localization with runtime for instance: {actual_instance_id}' + ) + + # Extract patch content + patch_content = instance.get('patch', '') + if not patch_content: + self.logger.warning( + f'No patch content found for instance {actual_instance_id}' + ) + return self._empty_localization_info(actual_instance_id) + + return self._parse_patch_localization_with_runtime( + patch_content, actual_instance_id, runtime + ) + + def _analyze_source_code_with_runtime( + self, runtime: Runtime, file_path: str, affected_lines: list[int] + ) -> tuple[list[str], list[str], dict[int, str], dict[int, str]]: + """ + Analyze source code using OpenHands runtime to find functions and classes. + + Args: + runtime: OpenHands runtime object + file_path: Path to the file being analyzed + affected_lines: List of line numbers that were changed + + Returns: + Tuple of (functions, classes, line_to_function_map, line_to_class_map) + """ + try: + # Check if file exists and is a Python/Cython file + if not (file_path.endswith('.py') or file_path.endswith('.pyx')): + self.logger.info(f'Skipping non-Python/Cython file: {file_path}') + return [], [], {}, {} + + # Read the file content using runtime + from openhands.events.action import CmdRunAction + + # First check if file exists + check_action = CmdRunAction( + command=f'test -f "{file_path}" && echo "EXISTS" || echo "NOT_EXISTS"' + ) + obs = runtime.run_action(check_action) + + if 'NOT_EXISTS' in obs.content: + self.logger.warning(f'File not found: {file_path}') + return [], [], {}, {} + + # Read file content + read_action = CmdRunAction(command=f'cat "{file_path}"') + obs = runtime.run_action(read_action) + + if obs.exit_code != 0: + self.logger.warning(f'Failed to read file {file_path}: {obs.content}') + return [], [], {}, {} + + file_content = obs.content + + # Parse the content + if file_path.endswith('.py'): + return self._parse_python_content_with_line_mapping( + file_content, affected_lines + ) + elif file_path.endswith('.pyx'): + return self._parse_cython_content_with_line_mapping( + file_content, affected_lines + ) + else: + return [], [], {}, {} + + except Exception as e: + self.logger.warning( + f'Failed to analyze source code with runtime for {file_path}: {e}' + ) + return [], [], {}, {} + + def _parse_cython_content_with_line_mapping( + self, content: str, affected_lines: list[int] + ) -> tuple[list[str], list[str], dict[int, str], dict[int, str]]: + """ + Parse Cython content to extract functions and classes with line mapping. + Since Cython files can't be parsed with Python's AST, we use regex-based parsing. 
+ + Args: + content: Cython source code content + affected_lines: List of line numbers that were changed + + Returns: + Tuple of (functions, classes, line_to_function_map, line_to_class_map) + """ + try: + functions = set() + classes = set() + line_to_function = {} + line_to_class = {} + + lines = content.split('\n') + current_function = None + current_class = None + + for i, line in enumerate(lines, 1): + stripped_line = line.strip() + + # Match class definitions + class_match = re.match( + r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*[\(:]', stripped_line + ) + if class_match: + current_class = class_match.group(1) + classes.add(current_class) + continue + + # Match function definitions (def, cdef, cpdef) + func_match = re.match( + r'(?:async\s+|c?p?def\s+(?:[^(]*\s+)?)?def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', + stripped_line, + ) + if not func_match: + # Try matching cdef functions with return types + func_match = re.match( + r'cdef\s+[^(]*\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', stripped_line + ) + if not func_match: + # Try matching cpdef functions + func_match = re.match( + r'cpdef\s+[^(]*\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', stripped_line + ) + + if func_match: + current_function = func_match.group(1) + functions.add(current_function) + continue + + # Check if we're leaving a function/class (basic heuristic based on indentation) + if ( + current_function + and line + and not line[0].isspace() + and not line.startswith('#') + ): + # We've left the function + current_function = None + + if ( + current_class + and line + and not line[0].isspace() + and not line.startswith('#') + and not stripped_line.startswith('def ') + and not stripped_line.startswith('cdef ') + and not stripped_line.startswith('cpdef ') + ): + # We've left the class + current_class = None + + # Map affected lines to functions and classes using a simple heuristic + # This is imperfect but better than nothing for Cython files + lines = content.split('\n') + for line_num in affected_lines: + if line_num <= len(lines): + # Find the nearest function/class definition above this line + nearest_function = None + nearest_class = None + + for i in range(line_num - 1, -1, -1): + if i < len(lines): + line = lines[i].strip() + + # Check for function definition + func_match = re.match( + r'(?:async\s+|c?p?def\s+(?:[^(]*\s+)?)?def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', + line, + ) + if not func_match: + func_match = re.match( + r'cdef\s+[^(]*\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', + line, + ) + if not func_match: + func_match = re.match( + r'cpdef\s+[^(]*\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(', + line, + ) + + if func_match and not nearest_function: + nearest_function = func_match.group(1) + + # Check for class definition + class_match = re.match( + r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*[\(:]', line + ) + if class_match and not nearest_class: + nearest_class = class_match.group(1) + + # Stop if we found both or hit the beginning + if (nearest_function and nearest_class) or i == 0: + break + + if nearest_function: + line_to_function[line_num] = nearest_function + if nearest_class: + line_to_class[line_num] = nearest_class + + return list(functions), list(classes), line_to_function, line_to_class + + except Exception as e: + self.logger.warning(f'Failed to parse Cython content: {e}') + return [], [], {}, {} + + def _parse_python_content_with_line_mapping( + self, content: str, affected_lines: list[int] + ) -> tuple[list[str], list[str], dict[int, str], dict[int, str]]: + """ + Parse Python content to extract functions and classes with accurate line mapping. 
+ + Args: + content: Python source code content + affected_lines: List of line numbers that were changed + + Returns: + Tuple of (functions, classes, line_to_function_map, line_to_class_map) + """ + try: + tree = ast.parse(content) + + functions = set() + classes = set() + line_to_function = {} + line_to_class = {} + + # Create a mapping of line numbers to AST nodes + line_to_node = {} + + class NodeVisitor(ast.NodeVisitor): + def __init__(self): + self.current_class = None + self.class_stack = [] + + def visit_ClassDef(self, node): + self.class_stack.append(node.name) + old_class = self.current_class + self.current_class = node.name + classes.add(node.name) + + # Mark lines in this class + start_line = node.lineno + end_line = getattr(node, 'end_lineno', node.lineno) + if end_line is None: + # Estimate end line by finding the next class/function or end of file + end_line = start_line + 100 # Conservative estimate + + for line_num in range(start_line, end_line + 1): + line_to_node[line_num] = ('class', node.name) + + self.generic_visit(node) + self.current_class = old_class + self.class_stack.pop() + + def visit_FunctionDef(self, node): + functions.add(node.name) + + # Mark lines in this function + start_line = node.lineno + end_line = getattr(node, 'end_lineno', node.lineno) + if end_line is None: + # Estimate end line based on the next sibling or parent end + end_line = start_line + 50 # Conservative estimate + + for line_num in range(start_line, end_line + 1): + line_to_node[line_num] = ('function', node.name) + + self.generic_visit(node) + + def visit_AsyncFunctionDef(self, node): + # Handle async functions the same way + self.visit_FunctionDef(node) + + visitor = NodeVisitor() + visitor.visit(tree) + + # Map affected lines to functions and classes + for line_num in affected_lines: + if line_num in line_to_node: + node_type, node_name = line_to_node[line_num] + if node_type == 'function': + line_to_function[line_num] = node_name + elif node_type == 'class': + line_to_class[line_num] = node_name + + return list(functions), list(classes), line_to_function, line_to_class + + except Exception as e: + self.logger.warning(f'Failed to parse Python content: {e}') + return [], [], {}, {} + + def _parse_python_content( + self, content: str, affected_lines: list[int] + ) -> tuple[list[str], list[str], dict[int, str], dict[int, str]]: + """ + Parse Python content to extract functions and classes. 
+ + Args: + content: Python source code content + affected_lines: List of line numbers that were changed + + Returns: + Tuple of (functions, classes, line_to_function_map, line_to_class_map) + """ + try: + tree = ast.parse(content) + + functions = set() + classes = set() + line_to_function = {} + line_to_class = {} + + class Analyzer(ast.NodeVisitor): + def __init__(self): + self.current_class = None + self.function_stack = [] + self.class_stack = [] + + def visit_ClassDef(self, node): + self.class_stack.append(node.name) + old_class = self.current_class + self.current_class = node.name + classes.add(node.name) + + # Mark lines in this class + end_line = getattr(node, 'end_lineno', node.lineno) + if end_line is None: + end_line = node.lineno + + for line_num in range(node.lineno, end_line + 1): + if line_num in affected_lines: + line_to_class[line_num] = node.name + + self.generic_visit(node) + self.current_class = old_class + self.class_stack.pop() + + def visit_FunctionDef(self, node): + self.function_stack.append(node.name) + functions.add(node.name) + + # Mark lines in this function + end_line = getattr(node, 'end_lineno', node.lineno) + if end_line is None: + end_line = node.lineno + + for line_num in range(node.lineno, end_line + 1): + if line_num in affected_lines: + line_to_function[line_num] = node.name + if self.current_class: + line_to_class[line_num] = self.current_class + + self.generic_visit(node) + self.function_stack.pop() + + def visit_AsyncFunctionDef(self, node): + # Handle async functions the same way + self.visit_FunctionDef(node) + + analyzer = Analyzer() + analyzer.visit(tree) + + return list(functions), list(classes), line_to_function, line_to_class + + except Exception as e: + self.logger.warning(f'Failed to parse Python content: {e}') + return [], [], {}, {} + + def _split_patch_by_files(self, patch_content: str) -> dict[str, str]: + """ + Split a multi-file patch into individual file patches. + + Args: + patch_content: Complete patch content + + Returns: + Dictionary mapping file paths to their patch content + """ + file_patches = {} + current_file = None + current_patch_lines = [] + + lines = patch_content.split('\n') + + for line in lines: + # Check for file header patterns + if line.startswith('diff --git'): + # Save previous file if exists + if current_file and current_patch_lines: + file_patches[current_file] = '\n'.join(current_patch_lines) + + # Extract file path from diff line + # Format: diff --git a/path/to/file.py b/path/to/file.py + match = re.search(r'diff --git a/(.*?) 
b/(.*?)(?:\s|$)', line) + if match: + current_file = match.group(1) # Use the 'a/' path + current_patch_lines = [line] + else: + current_file = None + current_patch_lines = [] + + elif line.startswith('---') or line.startswith('+++'): + # Alternative file path extraction + if not current_file: + match = re.search(r'[+-]{3}\s+(?:a/|b/)?(.+?)(?:\s|$)', line) + if match and not match.group(1).startswith('/dev/null'): + current_file = match.group(1) + if not current_patch_lines: + current_patch_lines = [line] + else: + current_patch_lines.append(line) + else: + if current_patch_lines: + current_patch_lines.append(line) + else: + current_patch_lines.append(line) + + elif current_file: + current_patch_lines.append(line) + + # Save the last file + if current_file and current_patch_lines: + file_patches[current_file] = '\n'.join(current_patch_lines) + + return file_patches + + def _empty_localization_info( + self, instance_id: str = 'unknown' + ) -> LocalizationInfo: + """ + Return an empty LocalizationInfo object. + + Args: + instance_id: Instance identifier + + Returns: + Empty LocalizationInfo instance + """ + return LocalizationInfo( + instance_id=instance_id, + files=[], + file_line_ranges={}, + functions={}, + classes={}, + line_to_function={}, + line_to_class={}, + total_lines_changed=0, + total_files_changed=0, + hunks_per_file={}, + ) + + def get_dataset_statistics(self) -> dict[str, Any]: + """ + Get statistics about the loaded dataset. + + Returns: + Dictionary containing dataset statistics + """ + if self.df is None: + return {} + + stats = { + 'total_instances': len(self.df), + 'repositories': self.df['repo'].nunique() + if 'repo' in self.df.columns + else 0, + 'avg_patch_length': self.df['patch'].str.len().mean() + if 'patch' in self.df.columns + else 0, + 'columns': list(self.df.columns), + } + + return stats + + def get_instances_by_repo(self, repo_name: str) -> pd.DataFrame: + """ + Get all instances for a specific repository. 
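+
+        Example (``loader`` is a hypothetical instance of this class):
+        ``loader.get_instances_by_repo('django/django')`` returns only the rows
+        whose ``repo`` column equals ``'django/django'``.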
+ + Args: + repo_name: Repository name (e.g., "django/django") + + Returns: + DataFrame containing instances for the specified repository + """ + if 'repo' not in self.df.columns: + raise ValueError('Repository information not available in dataset') + + return self.df[self.df['repo'] == repo_name].copy() diff --git a/evaluation/benchmarks/swe_bench/scripts/eval_localization.sh b/evaluation/benchmarks/swe_bench/scripts/eval_localization.sh new file mode 100755 index 0000000000..daa43e9658 --- /dev/null +++ b/evaluation/benchmarks/swe_bench/scripts/eval_localization.sh @@ -0,0 +1,227 @@ +#!/usr/bin/env bash +set -eo pipefail +source "evaluation/utils/version_control.sh" + +# Function to display usage information +usage() { + echo "Usage: $0 [OPTIONS]" + echo "Options:" + echo " --infer-dir DIR Directory containing model inference outputs" + echo " --split SPLIT SWE-Bench dataset split selection" + echo " --dataset DATASET Dataset name" + echo " --max-infer-turn NUM Max number of turns for coding agent" + echo " --align-with-max BOOL Align failed instance indices with max iteration (true/false)" + echo " -h, --help Display this help message" + echo "" + echo "Example:" + echo " $0 --infer-dir ./inference_outputs --split test --align-with-max false" +} + +# Check if no arguments were provided +if [ $# -eq 0 ]; then + usage + exit 1 +fi + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --infer-dir) + INFER_DIR="$2" + shift 2 + ;; + --split) + SPLIT="$2" + shift 2 + ;; + --dataset) + DATASET="$2" + shift 2 + ;; + --max-infer-turn) + MAX_TURN="$2" + shift 2 + ;; + --align-with-max) + ALIGN_WITH_MAX="$2" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + *) + echo "Unknown option: $1" + usage + exit 1 + ;; + esac +done + +# Check for required arguments (only INFER_DIR is required) +if [ -z "$INFER_DIR" ]; then + echo "Error: Missing required arguments (--infer-dir is required)" + usage + exit 1 +fi + +# Set defaults for optional arguments if not provided +if [ -z "$SPLIT" ]; then + SPLIT="test" + echo "Split not specified, using default: $SPLIT" +fi + +if [ -z "$DATASET" ]; then + DATASET="princeton-nlp/SWE-bench_Verified" + echo "Dataset not specified, using default: $DATASET" +fi + +if [ -z "$MAX_TURN" ]; then + MAX_TURN=20 + echo "Max inference turn not specified, using default: $MAX_TURN" +fi + +if [ -z "$ALIGN_WITH_MAX" ]; then + ALIGN_WITH_MAX="true" + echo "Align with max not specified, using default: $ALIGN_WITH_MAX" +fi + +# Validate align-with-max value +if [ "$ALIGN_WITH_MAX" != "true" ] && [ "$ALIGN_WITH_MAX" != "false" ]; then + print_error "Invalid value for --align-with-max: $ALIGN_WITH_MAX. Must be 'true' or 'false'" + exit 1 +fi + +# Color codes for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' # No Color + +# Function to print colored output +print_status() { + echo -e "${GREEN}[INFO]${NC} $1" +} + +print_error() { + echo -e "${RED}[ERROR]${NC} $1" +} + +print_warning() { + echo -e "${YELLOW}[WARNING]${NC} $1" +} + +print_header() { + echo -e "${BLUE}[TASK]${NC} $1" +} + +# Check if Python is available +print_header "Checking Python installation..." +if ! command -v python3 &> /dev/null; then + if ! 
command -v python &> /dev/null; then
+        print_error "Python is not installed or not in PATH"
+        exit 1
+    else
+        PYTHON_CMD="python"
+        print_status "Using python command"
+    fi
+else
+    PYTHON_CMD="python3"
+    print_status "Using python3 command"
+fi
+
+# Check if the Python script exists
+SCRIPT_NAME="./evaluation/benchmarks/swe_bench/loc_eval/loc_evaluator.py"
+if [ ! -f "$SCRIPT_NAME" ]; then
+    print_error "Python script '$SCRIPT_NAME' not found"
+    print_warning "Make sure you run this script from the repository root so the relative path resolves"
+    exit 1
+fi
+
+# Check if required directories exist
+print_header "Validating directories..."
+if [ ! -d "$INFER_DIR" ]; then
+    print_error "Inference directory not found: $INFER_DIR"
+    exit 1
+fi
+
+# Evaluation outputs
+EVAL_DIR="$INFER_DIR/eval_outputs"
+
+# Display configuration
+print_header "Starting Localization Evaluation with the following configuration:"
+echo " Inference Directory: $INFER_DIR"
+if [ -d "$EVAL_DIR" ]; then
+    echo " Evaluation Directory: $EVAL_DIR"
+else
+    echo " Evaluation Directory: None (evaluation outputs do not exist)"
+fi
+echo " Output Directory: $INFER_DIR/loc_eval"
+echo " Split: $SPLIT"
+echo " Dataset: $DATASET"
+echo " Max Turns: $MAX_TURN"
+echo " Align with Max: $ALIGN_WITH_MAX"
+echo " Python Command: $PYTHON_CMD"
+echo ""
+
+# Check Python dependencies (optional check)
+print_header "Checking Python dependencies..."
+$PYTHON_CMD -c "
+import sys
+required_modules = ['pandas', 'datasets', 'tqdm', 'json', 'os', 'argparse', 'collections']
+missing_modules = []
+
+for module in required_modules:
+    try:
+        __import__(module)
+    except ImportError:
+        missing_modules.append(module)
+
+if missing_modules:
+    print(f'Missing required modules: {missing_modules}')
+    sys.exit(1)
+else:
+    print('All basic dependencies are available')
+" || {
+    print_error "Some Python dependencies are missing"
+    print_warning "Please install required packages: pip install pandas datasets tqdm"
+    exit 1
+}
+
+# Create the log directory if it doesn't exist
+mkdir -p "$INFER_DIR/loc_eval"
+
+# Set up logging
+LOG_FILE="$INFER_DIR/loc_eval/loc_evaluation_$(date +%Y%m%d_%H%M%S).log"
+print_status "Logging output to: $LOG_FILE"
+
+# Build the command
+CMD_ARGS="\"$SCRIPT_NAME\" \
+    --infer-dir \"$INFER_DIR\" \
+    --split \"$SPLIT\" \
+    --dataset \"$DATASET\" \
+    --max-infer-turn \"$MAX_TURN\" \
+    --align-with-max \"$ALIGN_WITH_MAX\""
+
+# Run the Python script (disable errexit around the pipeline so a failed run
+# still reaches the status check below instead of aborting under 'set -eo pipefail')
+print_header "Running localization evaluation..."
+set +e
+eval "$PYTHON_CMD $CMD_ARGS" 2>&1 | tee "$LOG_FILE"
+EVAL_STATUS=${PIPESTATUS[0]}
+set -e
+
+# Check if the script ran successfully
+if [ "$EVAL_STATUS" -eq 0 ]; then
+    print_status "Localization evaluation completed successfully!"
+    print_status "Results saved to: $INFER_DIR/loc_eval"
+    print_status "Log file: $LOG_FILE"
+
+    # Display summary if results exist
+    if [ -f "$INFER_DIR/loc_eval/loc_eval_results/loc_acc/overall_eval.json" ]; then
+        print_header "Evaluation Summary:"
+        cat "$INFER_DIR/loc_eval/loc_eval_results/loc_acc/overall_eval.json"
+        echo
+    fi
+else
+    print_error "Localization evaluation failed!"
+    print_warning "Check the log file for details: $LOG_FILE"
+    exit 1
+fi
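+
+# Exit explicitly with success once evaluation and the summary display have completed
+exit 0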