mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
61 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bc27eae841 | |||
| f0e5c81272 | |||
| a4b836b5f9 | |||
| a4d632498c | |||
| 4f017081fc | |||
| 51fb1fae88 | |||
| 106b230fea | |||
| 9b262dd057 | |||
| 8074b261d3 | |||
| 999a59f938 | |||
| fbba57d3b5 | |||
| 3f6c8a2338 | |||
| dd09d46ccb | |||
| 8897b45eeb | |||
| 30109e8f20 | |||
| 4984bf6ee7 | |||
| 92ddc1b46c | |||
| 367c8a9f83 | |||
| 09335d67be | |||
| eb36426d33 | |||
| ace9e6e724 | |||
| 1290a2599d | |||
| 513dd9791d | |||
| fd53378d06 | |||
| 4471002c79 | |||
| 326e75e829 | |||
| d8ad8babf6 | |||
| 8782e3ae65 | |||
| eef0ed3410 | |||
| e7a8daf3ec | |||
| 64abd4a95e | |||
| c7d575b4e1 | |||
| f781bc8343 | |||
| 9f9a65c787 | |||
| 3355baea4c | |||
| 8848e60c6d | |||
| d1e84093cc | |||
| 3f0f13d335 | |||
| 219a134bb0 | |||
| efb525a463 | |||
| 1ded123116 | |||
| 31b6967a87 | |||
| 90422e5bfd | |||
| b47da9e894 | |||
| 3401bd610d | |||
| 77a153e42f | |||
| fb9bc87e35 | |||
| b685c67263 | |||
| 2cd64bc636 | |||
| 7c81deb132 | |||
| b19f735808 | |||
| 3af6025303 | |||
| 585dba9917 | |||
| bd66d09a33 | |||
| 791b7f9f60 | |||
| 30197e616b | |||
| f7f25319e3 | |||
| 75fba59588 | |||
| 280baa24e2 | |||
| c6206f5cf2 | |||
| 7a4729c034 |
@@ -12,7 +12,7 @@
|
||||
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
|
||||
<br/>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
|
||||
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits"></a>
|
||||
<br/>
|
||||
@@ -96,7 +96,7 @@ troubleshooting resources, and advanced configuration options.
|
||||
OpenHands is a community-driven project, and we welcome contributions from everyone. We do most of our communication
|
||||
through Slack, so this is the best place to start, but we also are happy to have you contact us on Discord or Github:
|
||||
|
||||
- [Join our Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg) - Here we talk about research, architecture, and future development.
|
||||
- [Join our Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw) - Here we talk about research, architecture, and future development.
|
||||
- [Join our Discord server](https://discord.gg/ESHStjSjD4) - This is a community-run server for general discussion, questions, and feedback.
|
||||
- [Read or post Github Issues](https://github.com/All-Hands-AI/OpenHands/issues) - Check out the issues we're working on, or add your own ideas.
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ Explorez le code source d'OpenHands sur [GitHub](https://github.com/All-Hands-AI
|
||||
/>
|
||||
</a>
|
||||
<br></br>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg">
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw">
|
||||
<img
|
||||
src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge"
|
||||
alt="Join our Slack community"
|
||||
|
||||
@@ -42,7 +42,7 @@ OpenHands 是一个**自主 AI 软件工程师**,能够执行复杂的工程
|
||||
/>
|
||||
</a>
|
||||
<br></br>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg">
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw">
|
||||
<img
|
||||
src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge"
|
||||
alt="Join our Slack community"
|
||||
|
||||
@@ -8,7 +8,7 @@ function CustomFooter() {
|
||||
<footer className="custom-footer">
|
||||
<div className="footer-content">
|
||||
<div className="footer-icons">
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg" target="_blank" rel="noopener noreferrer">
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw" target="_blank" rel="noopener noreferrer">
|
||||
<FaSlack />
|
||||
</a>
|
||||
<a href="https://discord.gg/ESHStjSjD4" target="_blank" rel="noopener noreferrer">
|
||||
|
||||
@@ -46,7 +46,7 @@ export function HomepageHeader() {
|
||||
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" /></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License" /></a>
|
||||
<br/>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community" /></a>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community" /></a>
|
||||
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community" /></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits" /></a>
|
||||
<br/>
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
@@ -175,6 +176,11 @@ def process_instance(
|
||||
logger.warning(
|
||||
f'This is the {runtime_failure_count + 1}th attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
|
||||
)
|
||||
metadata = copy.deepcopy(metadata)
|
||||
metadata.details['runtime_failure_count'] = runtime_failure_count
|
||||
metadata.details['remote_runtime_resource_factor'] = (
|
||||
config.sandbox.remote_runtime_resource_factor
|
||||
)
|
||||
|
||||
try:
|
||||
runtime = create_runtime(config)
|
||||
@@ -296,14 +302,20 @@ def process_instance(
|
||||
with open(test_output_path, 'w') as f:
|
||||
f.write(test_output)
|
||||
try:
|
||||
extra_kwargs = {}
|
||||
if 'SWE-Gym' in metadata.dataset:
|
||||
# SWE-Gym uses a different version of the package, hence a different eval report argument
|
||||
extra_kwargs['log_path'] = test_output_path
|
||||
else:
|
||||
extra_kwargs['test_log_path'] = test_output_path
|
||||
_report = conditional_imports.get_eval_report(
|
||||
test_spec=test_spec,
|
||||
prediction={
|
||||
'model_patch': model_patch,
|
||||
'instance_id': instance_id,
|
||||
},
|
||||
test_log_path=test_output_path,
|
||||
include_tests_status=True,
|
||||
**extra_kwargs,
|
||||
)
|
||||
report = _report[instance_id]
|
||||
logger.info(
|
||||
@@ -463,6 +475,7 @@ if __name__ == '__main__':
|
||||
.decode('utf-8')
|
||||
.strip(), # Current commit
|
||||
dataset=args.dataset, # Dataset name from args
|
||||
details={},
|
||||
)
|
||||
|
||||
# The evaluation harness constrains the signature of `process_instance_func` but we need to
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -23,7 +23,7 @@ def get_resource_mapping(dataset_name: str) -> dict[str, float]:
|
||||
if dataset_name not in _global_resource_mapping:
|
||||
file_path = os.path.join(CUR_DIR, f'{dataset_name}.json')
|
||||
if not os.path.exists(file_path):
|
||||
logger.warning(f'Resource mapping for {dataset_name} not found.')
|
||||
logger.info(f'Resource mapping for {dataset_name} not found.')
|
||||
return None
|
||||
|
||||
with open(file_path, 'r') as f:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
@@ -149,7 +150,8 @@ def get_config(
|
||||
) -> AppConfig:
|
||||
# We use a different instance image for each instance of swe-bench eval
|
||||
use_official_image = bool(
|
||||
'verified' in metadata.dataset.lower() or 'lite' in metadata.dataset.lower()
|
||||
('verified' in metadata.dataset.lower() or 'lite' in metadata.dataset.lower())
|
||||
and 'swe-gym' not in metadata.dataset.lower()
|
||||
)
|
||||
base_container_image = get_instance_docker_image(
|
||||
instance['instance_id'], use_official_image
|
||||
@@ -475,6 +477,13 @@ def process_instance(
|
||||
logger.warning(
|
||||
f'This is the {runtime_failure_count + 1}th attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
|
||||
)
|
||||
|
||||
metadata = copy.deepcopy(metadata)
|
||||
metadata.details['runtime_failure_count'] = runtime_failure_count
|
||||
metadata.details['remote_runtime_resource_factor'] = (
|
||||
config.sandbox.remote_runtime_resource_factor
|
||||
)
|
||||
|
||||
runtime = create_runtime(config)
|
||||
call_async_from_sync(runtime.connect)
|
||||
|
||||
@@ -560,20 +569,6 @@ def filter_dataset(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
|
||||
return dataset
|
||||
|
||||
|
||||
# A list of instances that are known to be tricky to infer
|
||||
# (will cause runtime failure even with resource factor = 8)
|
||||
SWEGYM_EXCLUDE_IDS = [
|
||||
'dask__dask-10422',
|
||||
'pandas-dev__pandas-50548',
|
||||
'pandas-dev__pandas-53672',
|
||||
'pandas-dev__pandas-54174',
|
||||
'pandas-dev__pandas-55518',
|
||||
'pandas-dev__pandas-58383',
|
||||
'pydata__xarray-6721',
|
||||
'pytest-dev__pytest-10081',
|
||||
'pytest-dev__pytest-7236',
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = get_parser()
|
||||
parser.add_argument(
|
||||
@@ -598,11 +593,20 @@ if __name__ == '__main__':
|
||||
f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks'
|
||||
)
|
||||
if 'SWE-Gym' in args.dataset:
|
||||
swe_bench_tests = swe_bench_tests[
|
||||
~swe_bench_tests['instance_id'].isin(SWEGYM_EXCLUDE_IDS)
|
||||
]
|
||||
with open(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
'split',
|
||||
'swegym_verified_instances.json',
|
||||
),
|
||||
'r',
|
||||
) as f:
|
||||
swegym_verified_instances = json.load(f)
|
||||
swe_bench_tests = swe_bench_tests[
|
||||
swe_bench_tests['instance_id'].isin(swegym_verified_instances)
|
||||
]
|
||||
logger.info(
|
||||
f'{len(swe_bench_tests)} tasks left after excluding SWE-Gym excluded tasks'
|
||||
f'{len(swe_bench_tests)} tasks left after filtering for SWE-Gym verified instances'
|
||||
)
|
||||
|
||||
llm_config = None
|
||||
|
||||
@@ -9,7 +9,7 @@ parser.add_argument(
|
||||
'--dataset_name',
|
||||
type=str,
|
||||
help='Name of the dataset to download',
|
||||
default='princeton-nlp/SWE-bench_Lite',
|
||||
default='princeton-nlp/SWE-bench_Verified',
|
||||
)
|
||||
parser.add_argument('--split', type=str, help='Split to download', default='test')
|
||||
args = parser.parse_args()
|
||||
@@ -20,7 +20,12 @@ print(
|
||||
f'Downloading gold patches from {args.dataset_name} (split: {args.split}) to {output_filepath}'
|
||||
)
|
||||
patches = [
|
||||
{'instance_id': row['instance_id'], 'model_patch': row['patch']} for row in dataset
|
||||
{
|
||||
'instance_id': row['instance_id'],
|
||||
'model_patch': row['patch'],
|
||||
'model_name_or_path': 'gold',
|
||||
}
|
||||
for row in dataset
|
||||
]
|
||||
print(f'{len(patches)} gold patches loaded')
|
||||
pd.DataFrame(patches).to_json(output_filepath, lines=True, orient='records')
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,12 @@
|
||||
codamosa_ids = ['pydata__xarray-4750-16496', 'pydata__xarray-3239-16458', 'pydata__xarray-4966-16515', 'pydata__xarray-3302-16459', 'pydata__xarray-5126-16518', 'pydata__xarray-4994-16516', 'pydata__xarray-3905-16478', 'pydata__xarray-4182-16484', 'pydata__xarray-5131-16520', 'pydata__xarray-5662-16532', 'pydata__xarray-3364-16461', 'pydata__xarray-5731-16534', 'pydata__xarray-3239-16457', 'pydata__xarray-7203-16577', 'pydata__xarray-3156-16454', 'pydata__xarray-5126-16519', 'pydata__xarray-5365-16529', 'pydata__xarray-4629-16492', 'pydata__xarray-4248-16486', 'pydata__xarray-4339-16487', 'pydata__xarray-3151-16453', 'pydata__xarray-3114-16452', 'pydata__xarray-5033-16517', 'pydata__xarray-4802-16505', 'pydata__xarray-5455-16530', 'pydata__xarray-6400-16539', 'pydata__xarray-3239-16456', 'pydata__xarray-4419-16488']
|
||||
|
||||
pynguin_ids = ['pydata__xarray-6548-16541', 'pydata__xarray-7003-16557', 'pydata__xarray-3114-16452', 'pydata__xarray-4339-16487', 'pydata__xarray-6889-16549', 'pydata__xarray-3239-16458', 'pydata__xarray-3364-16461', 'pydata__xarray-3239-16457', 'pydata__xarray-5365-16529', 'pydata__xarray-5131-16520', 'pydata__xarray-7229-16578', 'pydata__xarray-6461-16540', 'pydata__xarray-4419-16488', 'pydata__xarray-7147-16571', 'pydata__xarray-3151-16453', 'pydata__xarray-4966-16515', 'pydata__xarray-4629-16492', 'pydata__xarray-3239-16456', 'pydata__xarray-7400-16582', 'pydata__xarray-4994-16516', 'pydata__xarray-3302-16459', 'pydata__xarray-6601-16544', 'pydata__xarray-6882-16548', 'pydata__xarray-6135-16535', 'pydata__xarray-7393-16581', 'pydata__xarray-5731-16534', 'pydata__xarray-7203-16577']
|
||||
|
||||
ids = ['pydata__xarray-3114-16452', 'pydata__xarray-3151-16453', 'pydata__xarray-3156-16454', 'pydata__xarray-3239-16456', 'pydata__xarray-3239-16457', 'pydata__xarray-3239-16458', 'pydata__xarray-3302-16459', 'pydata__xarray-3364-16461', 'pydata__xarray-3677-16471', 'pydata__xarray-3905-16478', 'pydata__xarray-4182-16484', 'pydata__xarray-4248-16486', 'pydata__xarray-4339-16487', 'pydata__xarray-4419-16488', 'pydata__xarray-4629-16492', 'pydata__xarray-4750-16496', 'pydata__xarray-4802-16505', 'pydata__xarray-4966-16515', 'pydata__xarray-4994-16516', 'pydata__xarray-5033-16517', 'pydata__xarray-5126-16518', 'pydata__xarray-5126-16519', 'pydata__xarray-5131-16520', 'pydata__xarray-5365-16529', 'pydata__xarray-5455-16530', 'pydata__xarray-5662-16532', 'pydata__xarray-5731-16534', 'pydata__xarray-6135-16535', 'pydata__xarray-6135-16536', 'pydata__xarray-6386-16537', 'pydata__xarray-6394-16538', 'pydata__xarray-6400-16539', 'pydata__xarray-6461-16540', 'pydata__xarray-6548-16541', 'pydata__xarray-6599-16543', 'pydata__xarray-6601-16544', 'pydata__xarray-6882-16548', 'pydata__xarray-6889-16549', 'pydata__xarray-7003-16557', 'pydata__xarray-7147-16571', 'pydata__xarray-7150-16572', 'pydata__xarray-7203-16577', 'pydata__xarray-7229-16578', 'pydata__xarray-7393-16581', 'pydata__xarray-7400-16582']
|
||||
|
||||
|
||||
Command eval (our approach):
|
||||
poetry run ./evaluation/benchmarks/testgeneval/scripts/eval_infer_remote.sh evaluation/evaluation_outputs/outputs/kjain14__testgeneval-test/CodeActAgent/gpt-4o_maxiter_25_N_v0.20.0-no-hint-run_1/output.jsonl 10 kjain14/testgeneval test true
|
||||
|
||||
Command run (our approach):
|
||||
./evaluation/benchmarks/testgeneval/scripts/run_infer.sh llm.eval_gpt HEAD CodeActAgent -1 25 10 kjain14/testgeneval test 1 ../TestGenEval/results/testgeneval/preds/gpt-4o-2024-08-06__testgeneval__0.2__test.jsonl
|
||||
@@ -0,0 +1,80 @@
|
||||
# TestGenEval Benchmark Evaluation
|
||||
|
||||
This folder contains the evaluation harness for TestGenEval, which builds on the original TestGenEval benchmark ([paper](https://arxiv.org/abs/2410.00752)). TestGenEval is designed to evaluate the ability of language models to generate unit tests for given Python functions.
|
||||
|
||||
## Setup Environment and LLM Configuration
|
||||
|
||||
1. Follow the instructions [here](../../README.md#setup) to set up your local development environment and configure your LLM.
|
||||
|
||||
2. Install the TestGenEval dependencies:
|
||||
```bash
|
||||
poetry install --with testgeneval
|
||||
```
|
||||
|
||||
## Run Inference
|
||||
|
||||
To generate tests using your model, run the following command:
|
||||
|
||||
```bash
|
||||
./evaluation/benchmarks/testgeneval/scripts/run_infer.sh [model_config] [git-version] [agent] [eval_limit] [max_iter] [num_workers] [dataset] [dataset_split]
|
||||
|
||||
# Example
|
||||
./evaluation/benchmarks/testgeneval/scripts/run_infer.sh llm.eval_gpt4_1106_preview HEAD CodeActAgent 100 30 1 kjain14/testgenevallite test
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- `model_config`: The config group name for your LLM settings (e.g., `eval_gpt4_1106_preview`)
|
||||
- `git-version`: The git commit hash or release tag of OpenHands to evaluate (e.g., `HEAD` or `0.6.2`)
|
||||
- `agent`: The name of the agent for benchmarks (default: `CodeActAgent`)
|
||||
- `eval_limit`: Limit the evaluation to the first N instances (optional)
|
||||
- `max_iter`: Maximum number of iterations for the agent to run (default: 30)
|
||||
- `num_workers`: Number of parallel workers for evaluation (default: 1)
|
||||
- `dataset`: HuggingFace dataset name (default: `kjain14/testgenevallite`)
|
||||
- `dataset_split`: Dataset split to use (default: `test`)
|
||||
|
||||
After running the inference, you will obtain an `output.jsonl` file (by default saved to `evaluation/evaluation_outputs`).
|
||||
|
||||
## Evaluate Generated Tests
|
||||
|
||||
To evaluate the generated tests, use the `eval_infer.sh` script:
|
||||
|
||||
```bash
|
||||
./evaluation/benchmarks/testgeneval/scripts/eval_infer.sh $YOUR_OUTPUT_JSONL [instance_id] [dataset_name] [split] [num_workers] [skip_mutation]
|
||||
|
||||
# Example
|
||||
./evaluation/benchmarks/testgeneval/scripts/eval_infer.sh evaluation/evaluation_outputs/outputs/testgeneval/CodeActAgent/gpt-4-1106-preview_maxiter_50_N_v1.0/output.jsonl
|
||||
```
|
||||
|
||||
Optional arguments:
|
||||
- `instance_id`: Evaluate a single instance (optional)
|
||||
- `dataset_name`: Name of the dataset to use (default: `kjain14/testgenevallite`)
|
||||
- `split`: Dataset split to use (default: `test`)
|
||||
- `num_workers`: Number of workers for running docker (default: 1)
|
||||
- `skip_mutation`: Skip mutation testing (enter `true` if desired)
|
||||
|
||||
The evaluation results will be saved to `evaluation/evaluation_outputs/outputs/testgeneval/CodeActAgent/gpt-4-1106-preview_maxiter_50_N_v1.0/` with `output.testgeneval.jsonl` containing the metrics.
|
||||
|
||||
## Metrics
|
||||
|
||||
The TestGenEval benchmark evaluates generated tests based on the following metrics:
|
||||
|
||||
1. Correctness: Measures if the generated tests are syntactically correct and run without errors.
|
||||
2. Coverage: Assesses the code coverage achieved by the generated tests.
|
||||
3. Mutation Score: Evaluates the effectiveness of the tests in detecting intentionally introduced bugs (mutations).
|
||||
4. Readability: Analyzes the readability of the generated tests using various metrics.
|
||||
|
||||
## Submit Your Evaluation Results
|
||||
|
||||
To contribute your evaluation results:
|
||||
|
||||
1. Fork [our HuggingFace evaluation outputs](https://huggingface.co/spaces/OpenHands/evaluation).
|
||||
2. Add your results to the forked repository.
|
||||
3. Submit a Pull Request with your evaluation results following the guide [here](https://huggingface.co/docs/hub/en/repositories-pull-requests-discussions#pull-requests-and-discussions).
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [TestGenEval Paper](https://arxiv.org/abs/2410.00752)
|
||||
- [OpenHands Documentation](https://github.com/All-Hands-AI/OpenHands)
|
||||
- [HuggingFace Datasets](https://huggingface.co/datasets)
|
||||
|
||||
For any questions or issues, please open an issue in the [OpenHands repository](https://github.com/All-Hands-AI/OpenHands/issues).
|
||||
@@ -0,0 +1,356 @@
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
|
||||
def total_byte_entropy_stats(python_code):
    """Compute the Shannon entropy, in bits, of the UTF-8 bytes of the code.

    Args:
        python_code: Source text to analyse.

    Returns:
        A dict with the single key ``'total_byte_entropy'``. Empty input
        yields an entropy of 0.
    """
    encoded = python_code.encode('utf-8')

    # Histogram of byte values, keyed in order of first appearance.
    frequencies = {}
    for value in encoded:
        frequencies[value] = frequencies.get(value, 0) + 1

    n_bytes = sum(frequencies.values())
    terms = [
        (count / n_bytes) * math.log2(count / n_bytes)
        for count in frequencies.values()
    ]
    return {'total_byte_entropy': -sum(terms)}
|
||||
|
||||
|
||||
def average_nulls_stats(tree, num_lines):
    """Count ``null_literal`` nodes in the syntax tree, overall and per line.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type``,
            ``children`` and ``start_point``.
        num_lines: Number of source lines, used as the averaging denominator.

    Returns:
        A dict with ``avg_nulls`` (0 when ``num_lines`` is not positive),
        ``total_nulls``, ``max_nulls`` (highest count on a single line) and
        ``has_nulls`` (1 or 0).
    """
    # NOTE(review): 'null_literal' looks like a Java-grammar node name; confirm
    # it is what the Python grammar in use emits for ``None``.
    hits_by_row = {}
    pending = [tree.root_node]
    while pending:
        current = pending.pop()
        if current.type == 'null_literal':
            row = current.start_point[0]  # start_point[0] is the line index
            hits_by_row[row] = hits_by_row.get(row, 0) + 1
        pending.extend(current.children)

    null_total = sum(hits_by_row.values())

    if num_lines > 0:
        mean_nulls = null_total / num_lines
    else:
        mean_nulls = 0

    if hits_by_row:
        busiest_row_count = max(hits_by_row.values())
    else:
        busiest_row_count = 0

    return {
        'avg_nulls': mean_nulls,
        'total_nulls': null_total,
        'max_nulls': busiest_row_count,
        'has_nulls': 1 if null_total > 0 else 0,
    }
|
||||
|
||||
|
||||
def arithmetic_operations_stats(tree, num_lines):
    """Count arithmetic operators (``+ - * / %``) found in expression nodes.

    Only direct ``operator`` children of ``binary_expression`` and
    ``update_expression`` nodes are inspected. A matching expression node is
    not descended into, so operators nested inside it are not counted.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type``,
            ``children`` and ``text`` (bytes).
        num_lines: Number of source lines, used as the averaging denominator.

    Returns:
        A dict with ``total_arithmetic_operations`` and
        ``avg_arithmetic_operations`` (0 when ``num_lines`` is not positive).
    """
    arithmetic_ops = ('+', '-', '*', '/', '%')
    expression_types = ('binary_expression', 'update_expression')
    total_ops = 0

    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if node.type in expression_types:
            for child in node.children:
                if child.type == 'operator' and child.text.decode('utf8') in arithmetic_ops:
                    total_ops += 1
        else:
            pending.extend(node.children)

    # Guard the division the same way average_nulls_stats does.
    avg_ops = total_ops / num_lines if num_lines > 0 else 0

    return {
        'total_arithmetic_operations': total_ops,
        'avg_arithmetic_operations': avg_ops,
    }
|
||||
|
||||
|
||||
def numbers_floats_stats(tree, num_lines):
    """Count numeric literal nodes and how many of them look like floats.

    A literal is treated as a float when its text contains ``.`` or the
    letter ``e`` (case-insensitive).

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type``,
            ``children`` and ``text`` (bytes).
        num_lines: Unused; accepted so all stat helpers share one signature.

    Returns:
        A dict with ``total_numbers`` and ``total_floats``.
    """
    numeric_types = ('integer_literal', 'decimal_literal')
    number_count = 0
    float_count = 0

    pending = [tree.root_node]
    while pending:
        current = pending.pop()
        if current.type in numeric_types:
            number_count += 1
            literal = current.text.decode('utf8')
            if '.' in literal or 'e' in literal.lower():
                float_count += 1
        pending.extend(current.children)

    return {'total_numbers': number_count, 'total_floats': float_count}
|
||||
|
||||
|
||||
def code_stats(python_code):
    """Compute line-length statistics for the stripped source text.

    Args:
        python_code: Source text; leading and trailing whitespace is
            stripped before it is split on newlines.

    Returns:
        A dict with ``total_line_length``, ``max_line_length`` and
        ``avg_characters`` (mean characters per line).
    """
    line_lengths = [len(row) for row in python_code.strip().split('\n')]
    combined_length = sum(line_lengths)
    return {
        'total_line_length': combined_length,
        'max_line_length': max(line_lengths),
        'avg_characters': combined_length / len(line_lengths),
    }
|
||||
|
||||
|
||||
def assertions_stats(tree, num_lines):
    """Count ``assert_statement`` nodes in the syntax tree.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type`` and
            ``children``.
        num_lines: Unused; accepted so all stat helpers share one signature.

    Returns:
        A dict with ``total_assertions`` and ``total_has_assertions`` (1 if
        at least one assertion was found, else 0).
    """
    assertion_count = 0
    pending = [tree.root_node]
    while pending:
        current = pending.pop()
        if current.type == 'assert_statement':
            assertion_count += 1
        pending.extend(current.children)

    has_any = 1 if assertion_count > 0 else 0
    return {
        'total_assertions': assertion_count,
        'total_has_assertions': has_any,
    }
|
||||
|
||||
|
||||
def class_instances_stats(tree, num_lines):
    """Count ``object_creation_expression`` nodes in the syntax tree.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type`` and
            ``children``.
        num_lines: Unused; accepted so all stat helpers share one signature.

    Returns:
        A dict with the single key ``total_class_instances``.
    """
    # NOTE(review): 'object_creation_expression' looks like a Java-grammar node
    # name; confirm the Python grammar in use ever produces it.
    creation_count = 0
    pending = [tree.root_node]
    while pending:
        current = pending.pop()
        if current.type == 'object_creation_expression':
            creation_count += 1
        pending.extend(current.children)

    return {'total_class_instances': creation_count}
|
||||
|
||||
|
||||
def has_execeptions(tree, num_lines):
    """Report whether the syntax tree contains any ``try_statement`` node.

    The misspelled function name is kept because ``compute_readability``
    calls it under this name.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type`` and
            ``children``.
        num_lines: Unused; accepted so all stat helpers share one signature.

    Returns:
        A dict with the single key ``total_has_exceptions`` (1 or 0).
    """
    try_count = 0
    pending = [tree.root_node]
    while pending:
        current = pending.pop()
        if current.type == 'try_statement':
            try_count += 1
        pending.extend(current.children)

    flag = 1 if try_count > 0 else 0
    return {'total_has_exceptions': flag}
|
||||
|
||||
|
||||
def distinct_methods_stats(tree, num_lines):
    """Count distinct method names and relate them to the tree's node count.

    For every ``method_declaration`` node, the text of its first
    ``identifier`` child is recorded as the method name.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type``,
            ``children`` and ``text`` (bytes).
        num_lines: Unused; accepted so all stat helpers share one signature.

    Returns:
        A dict with ``total_distinct_methods`` and ``total_method_ratio``
        (distinct names divided by the remaining node count, or 0 when no
        nodes remain).
    """
    # NOTE(review): 'method_declaration' looks like a Java-grammar node name;
    # confirm the Python grammar in use ever produces it.
    seen_names = set()
    visited = 0

    pending = [tree.root_node]
    while pending:
        current = pending.pop()
        visited += 1
        if current.type == 'method_declaration':
            first_identifier = next(
                (kid for kid in current.children if kid.type == 'identifier'),
                None,
            )
            if first_identifier is not None:
                seen_names.add(first_identifier.text.decode('utf8'))
        pending.extend(current.children)

    distinct = len(seen_names)
    if visited > distinct:
        ratio = distinct / (visited - distinct)
    else:
        ratio = 0

    return {
        'total_distinct_methods': distinct,
        'total_method_ratio': ratio,
    }
|
||||
|
||||
|
||||
def loops_stats(tree, num_lines):
    """Calculate the average number of loop statements per line.

    Nodes of type ``for_statement``, ``while_statement`` and ``do_statement``
    are counted as loops.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type`` and
            ``children``.
        num_lines: Number of source lines, used as the averaging denominator.

    Returns:
        A dict with the single key ``avg_loops`` (0 when ``num_lines`` is not
        positive).
    """
    loop_types = ('for_statement', 'while_statement', 'do_statement')
    total_loops = 0

    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if node.type in loop_types:
            total_loops += 1
        pending.extend(node.children)

    # Guard the division the same way average_nulls_stats does.
    avg_loops = total_loops / num_lines if num_lines > 0 else 0
    return {'avg_loops': avg_loops}
|
||||
|
||||
|
||||
def branches_stats(tree, num_lines):
    """Calculate the average number of branches (conditional statements) per line.

    Each ``if_statement`` or ``switch_statement`` node counts as one branch,
    regardless of how many arms it has.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type`` and
            ``children``.
        num_lines: Number of source lines, used as the averaging denominator.

    Returns:
        A dict with the single key ``avg_branches`` (0 when ``num_lines`` is
        not positive).
    """
    branch_types = ('if_statement', 'switch_statement')
    total_branches = 0

    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if node.type in branch_types:
            total_branches += 1
        pending.extend(node.children)

    # Guard the division the same way average_nulls_stats does.
    avg_branches = total_branches / num_lines if num_lines > 0 else 0
    return {'avg_branches': avg_branches}
|
||||
|
||||
|
||||
def string_stats(tree, num_lines):
    """Calculate the total string-literal length per line of code.

    For each ``string_literal`` node, the first and last characters of its
    text (the quotation marks) are dropped before the length is measured.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type``,
            ``children`` and ``text`` (bytes).
        num_lines: Number of source lines, used as the averaging denominator.

    Returns:
        A dict with the single key ``avg_str_length`` (0 when ``num_lines``
        is not positive).
    """
    total_length = 0

    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if node.type == 'string_literal':
            # Exclude the surrounding quotation marks from the length.
            total_length += len(node.text.decode('utf8')[1:-1])
        pending.extend(node.children)

    # Guard the division the same way average_nulls_stats does.
    avg_length = total_length / num_lines if num_lines > 0 else 0
    return {'avg_str_length': avg_length}
|
||||
|
||||
|
||||
def identifier_stats(tree, num_lines):
    """Collect counts and length statistics for ``identifier`` nodes.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type``,
            ``children`` and ``text`` (bytes).
        num_lines: Number of source lines, used as the denominator for
            ``avg_identifier_length``.

    Returns:
        A dict with ``total_identifiers``, ``total_identifier_length``,
        ``max_identifier_length``, ``avg_identifier_length`` (total identifier
        length per line, 0 when ``num_lines`` is not positive),
        ``total_unique_identifiers``, ``total_identifier_ratio`` (identifiers
        over all visited nodes) and ``total_nodes``.
    """
    identifier_counts = {}  # identifier text -> number of occurrences
    total_nodes = 0

    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        total_nodes += 1
        if node.type == 'identifier':
            identifier = node.text.decode('utf8')
            identifier_counts[identifier] = identifier_counts.get(identifier, 0) + 1
        pending.extend(node.children)

    total_identifiers = sum(identifier_counts.values())
    max_identifier_length = max((len(name) for name in identifier_counts), default=0)
    total_unique_identifiers = len(identifier_counts)
    total_identifier_length = sum(
        len(name) * count for name, count in identifier_counts.items()
    )

    # Guard the division the same way average_nulls_stats does.
    avg_identifier_length = (
        total_identifier_length / num_lines if num_lines > 0 else 0
    )

    # Identifier ratio: identifiers over every node visited.
    identifier_ratio = total_identifiers / total_nodes if total_nodes > 0 else 0

    return {
        'total_identifiers': total_identifiers,
        'total_identifier_length': total_identifier_length,
        'max_identifier_length': max_identifier_length,
        'avg_identifier_length': avg_identifier_length,
        'total_unique_identifiers': total_unique_identifiers,
        'total_identifier_ratio': identifier_ratio,
        'total_nodes': total_nodes,
    }
|
||||
|
||||
|
||||
def compute_regression(results):
    """Combine the collected metrics into a single score with a linear model.

    Each metric in ``results`` is multiplied by its fixed coefficient, the
    products are accumulated in table order, and the intercept is added last.

    Args:
        results: Mapping from metric name to numeric value; every metric
            listed in the coefficient table must be present.

    Returns:
        The weighted sum plus the intercept.

    Raises:
        KeyError: If ``results`` lacks one of the expected metrics (for a
            plain dict).
    """
    intercept = 5.7501
    coefficients = (
        ('total_line_length', -0.0001),
        ('max_line_length', -0.0021),
        ('total_identifiers', 0.0076),
        ('total_identifier_length', -0.0004),
        ('max_identifier_length', -0.0067),
        ('avg_identifier_length', -0.005),
        ('avg_arithmetic_operations', 0.0225),
        ('avg_branches', 0.9886),
        ('avg_loops', 0.1572),
        ('total_assertions', 0.0119),
        ('total_has_assertions', -0.0147),
        ('avg_characters', 0.1242),
        ('total_class_instances', -0.043),
        ('total_distinct_methods', -0.0127),
        ('avg_str_length', 0.0026),
        ('total_has_exceptions', 0.1206),
        ('total_unique_identifiers', -0.019),
        ('max_nulls', -0.0712),
        ('total_numbers', -0.0078),
        ('avg_nulls', 0.1444),
        ('total_identifier_ratio', 0.334),
        ('total_method_ratio', 0.0406),
        ('total_floats', -0.0174),
        ('total_byte_entropy', -0.3917),
    )

    # Explicit accumulation keeps the floating-point summation order fixed.
    score = 0
    for metric, weight in coefficients:
        score += weight * results[metric]

    score += intercept
    return score
|
||||
|
||||
|
||||
def compute_readability(python_code):
    """Compute the regression-based readability score of a Python snippet.

    A tree-sitter parser is set up with the Python grammar built from the
    ``tree-sitter-python`` directory next to this file (the compiled library
    is stored under ``build/my-languages.so``). Text-level and syntax-tree
    statistics are merged into one dictionary and passed to
    ``compute_regression``.

    Args:
        python_code: Source code to score, as a string.

    Returns:
        The value returned by ``compute_regression`` for the collected
        statistics.
    """
    parser = Parser()
    module_dir = Path(os.path.dirname(os.path.realpath(__file__)))
    grammar_library = Language.build_library(
        # Store the library in the `build` directory
        module_dir / "build" / "my-languages.so",
        # Include one or more languages
        [module_dir / "tree-sitter-python"],
    )
    parser.set_language(grammar_library.get_language('python'))

    # Statistics computed directly from the text.
    results = code_stats(python_code)
    num_lines = len(python_code.strip().split('\n'))
    results.update(total_byte_entropy_stats(python_code))

    tree = parser.parse(bytes(python_code, 'utf8'))

    # Statistics computed from the syntax tree; each collector takes the tree
    # and the line count and returns a dict that is merged into `results`.
    tree_collectors = (
        identifier_stats,
        loops_stats,
        branches_stats,
        distinct_methods_stats,
        has_execeptions,
        class_instances_stats,
        assertions_stats,
        numbers_floats_stats,
        average_nulls_stats,
        arithmetic_operations_stats,
        string_stats,
    )
    for collect in tree_collectors:
        results.update(collect(tree, num_lines))

    return compute_regression(results)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,633 @@
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
import pandas as pd
|
||||
from report_utils import (
|
||||
check_coverage,
|
||||
check_mutation,
|
||||
count_methods,
|
||||
get_lines_of_code,
|
||||
)
|
||||
|
||||
from evaluation.benchmarks.testgeneval.compute_readability import compute_readability
|
||||
from evaluation.benchmarks.testgeneval.constants import (
|
||||
COVERAGE_PREFIX,
|
||||
MUTATION_BUFFER,
|
||||
MUTATION_TEMPLATE,
|
||||
MUTATION_TIMEOUT,
|
||||
TESTS_SUFFIX,
|
||||
)
|
||||
from evaluation.benchmarks.testgeneval.metrics import (
|
||||
bleu,
|
||||
code_bleu,
|
||||
edit_sim,
|
||||
exact_match,
|
||||
rouge_l,
|
||||
)
|
||||
from evaluation.benchmarks.testgeneval.pygments_utils import tokenize_code
|
||||
from evaluation.benchmarks.testgeneval.run_infer import get_instance_docker_image
|
||||
from evaluation.benchmarks.testgeneval.test_filter import filter_tests
|
||||
from evaluation.benchmarks.testgeneval.test_spec import (
|
||||
TestGenEvalInstance,
|
||||
TestSpec,
|
||||
make_test_spec,
|
||||
)
|
||||
from evaluation.benchmarks.testgeneval.utils import load_testgeneval_dataset
|
||||
from evaluation.utils.shared import (
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
run_evaluation,
|
||||
)
|
||||
from openhands.core.config import AppConfig, SandboxConfig, get_parser
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime
|
||||
from openhands.events.action import CmdRunAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
# Registry/namespace prefix for evaluation images; can be overridden through
# the EVAL_DOCKER_IMAGE_PREFIX environment variable.
DOCKER_IMAGE_PREFIX = os.getenv('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/kdjain/')
logger.info(f'Using docker image prefix: {DOCKER_IMAGE_PREFIX}')
|
||||
|
||||
|
||||
def get_config(instance: pd.Series) -> AppConfig:
    """Build the application config used to evaluate one instance.

    The sandbox image is resolved from the instance's
    ``instance_id_swebench`` via ``get_instance_docker_image``. The runtime
    name, API key and remote runtime URL are read from the ``RUNTIME``,
    ``ALLHANDS_API_KEY`` and ``SANDBOX_REMOTE_RUNTIME_API_URL`` environment
    variables.

    Args:
        instance: Row providing ``instance_id_swebench``.

    Returns:
        An ``AppConfig`` with no workspace mount and a sandbox based on the
        resolved image.

    Raises:
        AssertionError: If no container image is resolved for the instance.
    """
    swebench_id = instance['instance_id_swebench']
    base_container_image = get_instance_docker_image(swebench_id)
    assert (
        base_container_image
    ), f'Invalid container image for instance {swebench_id}.'
    logger.info(f'Using instance container image: {base_container_image}.')

    runtime_name = os.environ.get('RUNTIME', 'eventstream')
    sandbox_config = SandboxConfig(
        base_container_image=base_container_image,
        use_host_network=False,
        timeout=1800,
        api_key=os.environ.get('ALLHANDS_API_KEY'),
        remote_runtime_api_url=os.environ.get(
            'SANDBOX_REMOTE_RUNTIME_API_URL', 'http://localhost:8000'
        ),
    )

    return AppConfig(
        run_as_openhands=False,
        runtime=runtime_name,
        sandbox=sandbox_config,
        workspace_base=None,
        workspace_mount_path=None,
    )
|
||||
|
||||
|
||||
def compute_lexical_metrics(pred_suite, gold_suite):
    """Compare a predicted test suite against the gold suite lexically.

    Size and readability figures are computed for each suite separately; the
    similarity metrics are computed on the token sequences returned by
    ``tokenize_code``.

    Args:
        pred_suite: Source text of the generated test suite.
        gold_suite: Source text of the reference test suite.

    Returns:
        A dict with the keys ``pred_loc``, ``gold_loc``, ``pred_readability``,
        ``gold_readability``, ``pred_methods``, ``gold_methods``,
        ``code_bleu``, ``bleu``, ``xmatch``, ``edit_sim``, ``rouge_f``,
        ``rouge_p`` and ``rouge_r``.
    """
    pred_loc = get_lines_of_code(pred_suite)
    gold_loc = get_lines_of_code(gold_suite)
    pred_methods = count_methods(pred_suite)
    gold_methods = count_methods(gold_suite)
    readability_pred = compute_readability(pred_suite)
    readability_gold = compute_readability(gold_suite)

    preds = tokenize_code(pred_suite)
    golds = tokenize_code(gold_suite)

    # ROUGE-L yields f/p/r in one result; compute it once instead of once per
    # key. Note the argument order (gold first) differs from the other metrics.
    rouge = rouge_l(golds, preds)

    return {
        'pred_loc': pred_loc,
        'gold_loc': gold_loc,
        'pred_readability': readability_pred,
        'gold_readability': readability_gold,
        'pred_methods': pred_methods,
        'gold_methods': gold_methods,
        'code_bleu': code_bleu(preds, golds, 'Python3'),
        'bleu': bleu(preds, golds),
        'xmatch': exact_match(preds, golds),
        'edit_sim': edit_sim(preds, golds),
        'rouge_f': rouge['f'],
        'rouge_p': rouge['p'],
        'rouge_r': rouge['r'],
    }
|
||||
|
||||
|
||||
def run_command(runtime, command, timeout=600):
    """Run a shell command in the runtime and require that it succeeds.

    Both the action and the resulting observation are logged.

    Args:
        runtime: Object whose ``run_action`` executes the action.
        command: Shell command line to execute.
        timeout: Hard timeout applied to the action.

    Returns:
        The observation returned by the runtime.

    Raises:
        AssertionError: If the observation's exit code is not 0.
    """
    cmd_action = CmdRunAction(command=command)
    cmd_action.set_hard_timeout(timeout)
    logger.info(cmd_action, extra={'msg_type': 'ACTION'})

    observation = runtime.run_action(cmd_action)
    logger.info(observation, extra={'msg_type': 'OBSERVATION'})

    assert observation.exit_code == 0
    return observation
|
||||
|
||||
|
||||
def _run_script_and_wait(
    runtime, instance, script, log_file, label, timeout, timeout_key
):
    """Start a script in the background, wait for it to exit, and read its log.

    The script is launched with its output redirected to ``log_file``. Its PID
    is then polled with ``ps`` every 30 seconds until the process is gone or
    ``timeout`` seconds have passed. On timeout the flag ``timeout_key`` is
    set in ``instance['test_result']['report']``; the process itself is not
    killed.

    Args:
        runtime: Object whose ``run_action`` executes actions.
        instance: Row providing ``instance_id`` and the ``test_result`` report.
        script: Path of the bash script inside the runtime.
        log_file: Path inside the runtime that receives the script output.
        label: Capitalised process name used in log and assertion messages.
        timeout: Maximum number of seconds to wait for the process.
        timeout_key: Report key set to ``True`` when the wait times out.

    Returns:
        Tuple of (observation of ``cat log_file``, elapsed seconds measured at
        the start of the last polling iteration).

    Raises:
        AssertionError: If launching the script or reading the log does not
            produce a ``CmdOutputObservation``.
    """
    launch_action = CmdRunAction(
        command=f'bash {script} > {log_file} 2>&1 & echo $!'
    )
    launch_action.set_hard_timeout(60)
    launch_obs = runtime.run_action(launch_action)

    assert isinstance(
        launch_obs, CmdOutputObservation
    ), f'Failed to start {label.lower()} script.'
    # `echo $!` prints the background PID as the last token of the output.
    pid = launch_obs.content.split()[-1].strip()
    logger.info(f'[{instance.instance_id}] {label} process started with PID: {pid}')

    start_time = time.time()
    while True:
        elapsed_time = time.time() - start_time
        if elapsed_time > timeout:
            logger.info(f'[{instance.instance_id}] {label} process timed out.')
            instance['test_result']['report'][timeout_key] = True
            break

        # `ps -p` exits with 1 once the PID no longer exists.
        check_action = CmdRunAction(command=f'ps -p {pid} > /dev/null; echo $?')
        check_obs = runtime.run_action(check_action)
        if isinstance(check_obs, CmdOutputObservation):
            tokens = check_obs.content.split()
            if tokens and tokens[-1].strip() == '1':
                logger.info(f'[{instance.instance_id}] {label} process completed.')
                break
        time.sleep(30)

    read_action = CmdRunAction(command=f'cat {log_file}')
    read_action.set_hard_timeout(300)
    log_obs = runtime.run_action(read_action)
    assert isinstance(
        log_obs, CmdOutputObservation
    ), f'Failed to retrieve {label.lower()} output.'
    return log_obs, elapsed_time


def run_tests(runtime, instance, test_script, log_file='/tmp/test_output.log'):
    """Run the test script in the runtime and return its captured output.

    Waits up to 1800 seconds; on timeout ``test_timeout`` is set in the
    instance report.

    Args:
        runtime: Object whose ``run_action`` executes actions.
        instance: Row providing ``instance_id`` and the ``test_result`` report.
        test_script: Path of the test script inside the runtime.
        log_file: Path inside the runtime that receives the script output.

    Returns:
        Tuple of (exit code of reading the log, log content, elapsed seconds).
    """
    test_obs, elapsed_time = _run_script_and_wait(
        runtime,
        instance,
        test_script,
        log_file,
        label='Test',
        timeout=1800,
        timeout_key='test_timeout',
    )
    return test_obs.exit_code, test_obs.content, elapsed_time


def run_mutation_testing(
    runtime, instance, mutation_script, log_file='/tmp/mutation_output.log'
):
    """Run the mutation script in the runtime and return its captured output.

    Waits up to 4000 seconds; on timeout ``mutation_timeout`` is set in the
    instance report.

    Args:
        runtime: Object whose ``run_action`` executes actions.
        instance: Row providing ``instance_id`` and the ``test_result`` report.
        mutation_script: Path of the mutation script inside the runtime.
        log_file: Path inside the runtime that receives the script output.

    Returns:
        Tuple of (exit code of reading the log, log content).
    """
    mutation_obs, _ = _run_script_and_wait(
        runtime,
        instance,
        mutation_script,
        log_file,
        label='Mutation',
        timeout=4000,
        timeout_key='mutation_timeout',
    )
    return mutation_obs.exit_code, mutation_obs.content
|
||||
|
||||
|
||||
def grade_test_output(
    test_suite: str, instance: pd.Series, test_output: str, test_spec: TestSpec, runtime
):
    """
    Grade a test run in up to two passes, short-circuiting where possible.

    1. Split the first-pass output into passing and failing tests.
    2. If nothing passes, report no coverage.
    3. If nothing fails, take coverage straight from the first-pass output.
    4. Otherwise copy the filtered suite (from `filter_tests`) into the
       testbed, rerun the test script and take coverage from that second run.

    Args:
        test_suite: Source of the generated test suite.
        instance: Row forwarded to `run_tests` for the second pass.
        test_output: Raw output of the first test run.
        test_spec: Provides `repo`, `code_file` and `test_file`.
        runtime: Runtime used to copy files and rerun the tests.

    Returns:
        Tuple `(coverage_success, coverage, unit_test_output, coverage_output,
        test_stats)`.
    """

    def _read_coverage(output):
        """Return (text after COVERAGE_PREFIX, coverage value parsed from it)."""
        report_text = output.split(COVERAGE_PREFIX)[1]
        _, value = check_coverage(report_text, test_spec.code_file)
        return report_text, value

    coverage_output = ''
    unit_test_output = (
        test_output.split(TESTS_SUFFIX)[0] if TESTS_SUFFIX in test_output else ''
    )

    if not unit_test_output:
        # No recognisable unit-test section: report an empty result.
        empty_stats = {
            'total_tests': 0,
            'passing_tests': 0,
            'failing_tests': 0,
            'any_pass': False,
            'all_pass': False,
            'passing_test_names': [],
            'failing_test_names': [],
        }
        return False, 0, '', '', empty_stats

    logger.info('Calling filter unit tests')
    filtered_content, passing_tests, failing_tests = filter_tests(
        test_suite, unit_test_output, test_spec.repo
    )

    num_passing = len(passing_tests)
    num_failing = len(failing_tests)
    total_tests = num_passing + num_failing
    test_stats = {
        'total_tests': total_tests,
        'passing_tests': num_passing,
        'failing_tests': num_failing,
        'any_pass': num_passing > 0,
        'all_pass': num_failing == 0 and total_tests > 0,
        'passing_test_names': passing_tests,
        'failing_test_names': failing_tests,
    }

    if not passing_tests:
        return False, 0, unit_test_output, coverage_output, test_stats

    cov_success, coverage = False, 0

    # Every test passed: the first run already holds the coverage report.
    if not failing_tests:
        if COVERAGE_PREFIX in test_output:
            coverage_output, coverage = _read_coverage(test_output)
            cov_success = True
        return cov_success, coverage, unit_test_output, coverage_output, test_stats

    # Second pass - run coverage on passing tests
    if filtered_content:
        with tempfile.TemporaryDirectory() as temp_dir:
            test_suite_path = os.path.join(temp_dir, 'test_suite.py')
            with open(test_suite_path, 'w') as f:
                f.write(filtered_content)
            runtime.copy_to(test_suite_path, '/tmp')

        run_command(runtime, f'cp /tmp/test_suite.py /testbed/{test_spec.test_file}')
        _, second_pass_output, _ = run_tests(runtime, instance, '/tmp/test.sh')

        coverage_output, unit_test_output = '', second_pass_output
        if COVERAGE_PREFIX in second_pass_output:
            unit_test_output = second_pass_output.split(TESTS_SUFFIX)[0]
            coverage_output, coverage = _read_coverage(second_pass_output)
            cov_success = True

    return cov_success, coverage, unit_test_output, coverage_output, test_stats
|
||||
|
||||
|
||||
def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
    log_dir: str | None = None,
) -> EvalOutput:
    """
    Evaluate agent performance on a TestGenEval problem instance.

    Note that this signature differs from the expected input to `run_evaluation`. Use
    `functools.partial` to provide optional arguments before passing to the evaluation harness.

    The function reads the module-level `args` namespace (`skip_lexical` and
    `skip_mutation`), which is only created when this file runs as a script.

    Args:
        instance: Row providing `instance_id`, `instance_id_swebench`, `test_suite`,
            `test_spec`, `instance` and a `test_result` mapping. The `test_result`
            mapping is updated in place with the `id`, `report` and `lexical` entries.
        metadata: Not used by this function; accepted to match the harness signature.
        reset_logger: Whether to reconfigure the logger for this instance.
        log_dir (str | None, default=None): Path to directory where log files will be written. Must
            be provided if `reset_logger` is set.

    Returns:
        An `EvalOutput` carrying the instance id and the updated `test_result`.

    Raises:
        AssertionError: if the `reset_logger` flag is set without a provided log directory.
        RuntimeError: if running or grading the tests fails; the original exception is
            chained as the cause.
    """
    if reset_logger:
        assert (
            log_dir is not None
        ), "Can't reset logger without a provided log directory."
        os.makedirs(log_dir, exist_ok=True)
        reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance.instance_id}.')

    config = get_config(instance)
    instance_id = instance.instance_id
    logger.info(f'Starting evaluation for instance {instance_id}.')

    # Defaults reported when a stage is skipped or never reached; -1 marks
    # "not computed" for numeric fields.
    instance['test_result']['id'] = instance_id
    instance['test_result']['report'] = {
        'test_output': '',
        'empty_generation': False,
        'error_eval': False,
        'all_tests_pass': False,
        'tests_pass': False,
        'test_timeout': False,
        'mutation_timeout': False,
        'coverage_success': False,
        'mutation_success': False,
        'coverage': 0,
        'mutation_score': 0,
        'mutation_error_interval': -1,
        'num_mutants': -1,
    }

    instance['test_result']['lexical'] = {
        'pred_loc': -1,
        'gold_loc': -1,
        'pred_readability': -1,
        'gold_readability': -1,
        'pred_methods': -1,
        'gold_methods': -1,
        'code_bleu': -1,
        'bleu': -1,
        'xmatch': -1,
        'edit_sim': -1,
        'rouge_f': -1,
        'rouge_p': -1,
        'rouge_r': -1,
    }

    # Nothing was generated: report that and skip the sandbox entirely.
    if instance['test_suite'] == '' or instance['test_suite'] is None:
        instance['test_result']['report']['empty_generation'] = True
        return EvalOutput(
            instance_id=instance.instance_id, test_result=instance['test_result']
        )

    if not args.skip_lexical:
        lexical_metrics = compute_lexical_metrics(
            instance['test_suite'], instance['instance']['test_src']
        )
        instance['test_result']['lexical'] = lexical_metrics

    test_suite = instance['test_suite']
    test_spec: TestSpec = instance['test_spec']
    runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)

    # Everything after connecting runs under try/finally so the runtime is
    # closed even when copying the scripts into the sandbox fails.
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            test_suite_path = os.path.join(temp_dir, 'test_suite.py')
            with open(test_suite_path, 'w') as f:
                f.write(test_suite)
            runtime.copy_to(test_suite_path, '/tmp')

            test_script_path = os.path.join(temp_dir, 'test.sh')
            with open(test_script_path, 'w') as f:
                f.write(test_spec.test_script)
            runtime.copy_to(test_script_path, '/tmp')

            mutation_script_path = os.path.join(temp_dir, 'mutation.sh')
            with open(mutation_script_path, 'w') as f:
                f.write(test_spec.mutation_script)
            runtime.copy_to(mutation_script_path, '/tmp')

        try:
            run_command(runtime, 'chmod +x /tmp/test.sh /tmp/mutation.sh')
            run_command(
                runtime, f'cp /tmp/test_suite.py /testbed/{test_spec.test_file}'
            )

            # First pass - run all tests
            _, test_output, test_time = run_tests(runtime, instance, '/tmp/test.sh')

            # Grade tests with two-pass approach
            coverage_success, coverage, unit_test_output, coverage_output, test_stats = (
                grade_test_output(test_suite, instance, test_output, test_spec, runtime)
            )

            # Update report with test statistics
            instance['test_result']['report'].update(
                {
                    'test_output': unit_test_output,
                    'tests_pass': test_stats['any_pass'],
                    'all_tests_pass': test_stats['all_pass'],
                    'coverage_success': coverage_success,
                    'coverage': coverage if coverage_success else 0,
                    'test_stats': test_stats,
                }
            )

            # Only run mutation testing if we have passing tests and coverage
            if (
                not args.skip_mutation
                and coverage_success
                and test_stats['any_pass']
                and coverage > 0
            ):
                # Per-mutant timeout derived from the first test run's duration.
                mutation_timeout = max(10, 1.5 * test_time)
                mutation_toml = MUTATION_TEMPLATE.format(
                    test_cmd=test_spec.test_cmd,
                    source_fp=test_spec.code_file,
                    timeout=mutation_timeout,
                )

                with tempfile.TemporaryDirectory() as temp_dir:
                    mutation_toml_path = os.path.join(temp_dir, 'mutation.toml')
                    with open(mutation_toml_path, 'w') as f:
                        f.write(mutation_toml)
                    runtime.copy_to(mutation_toml_path, '/tmp')

                run_command(runtime, 'cp /tmp/mutation.toml /testbed/mutation.toml')

                mutation_code, mutation_output = run_mutation_testing(
                    runtime, instance, '/tmp/mutation.sh'
                )
                if mutation_output and mutation_code == 0:
                    (
                        mutation_success,
                        num_mutants,
                        mutation_score,
                        mutation_confidence_interval,
                    ) = check_mutation(mutation_output)
                    report = instance['test_result']['report']
                    report['num_mutants'] = num_mutants
                    report['mutation_success'] = mutation_success
                    report['mutation_score'] = mutation_score
                    report['mutation_error_interval'] = mutation_confidence_interval

            return EvalOutput(
                instance_id=instance.instance_id, test_result=instance['test_result']
            )
        except Exception as e:
            logger.error(f'Error processing instance {instance.instance_id}: {e}')
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(
                instance.instance_id,
                'Unexpected output...',
                logger,
            ) from e
    finally:
        runtime.close()
|
||||
|
||||
|
||||
def count_and_log_fields(evaluated_predictions, fields, key):
    """
    Log, for each field, how many rows have a usable value and their mean.

    A value of -1 marks "not computed" and is excluded. When every row holds
    -1 for a field, a line stating that is logged instead. Nothing is
    returned; the results are only written to the log.

    :param evaluated_predictions: DataFrame containing evaluation results
    :param fields: List of field names to count
    :param key: Key to access the field values ('report' or 'lexical')
    """

    def extract(row, field_name):
        # Map the -1 sentinel to None so dropna() removes it.
        raw = row['test_result'][key][field_name]
        if raw == -1:
            return None
        return raw

    for field in fields:
        per_row = evaluated_predictions.apply(extract, axis=1, args=(field,))
        usable = per_row.dropna()

        if usable.empty:  # If all values are -1
            logger.info(f'# {field}: -1 (All values are -1)')
            continue

        count = usable.sum()  # Sum of valid values
        length = len(usable)  # Count of valid entries
        logger.info(f'# {field}: {length}. ({count / length:.2f})')
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: load predictions, normalise their columns, run the
    # evaluation harness over them and log aggregate statistics.
    parser = get_parser()
    parser.add_argument(
        '--input-file', type=str, required=True, help='Path to input predictions file'
    )
    parser.add_argument(
        '--dataset',
        type=str,
        default='kjain14/testgeneval',
        help='Dataset to evaluate on',
    )
    parser.add_argument(
        '--split', type=str, default='test', help='Split to evaluate on'
    )
    parser.add_argument(
        '--skip_mutation', action='store_true', help='Skip mutation testing'
    )
    parser.add_argument(
        '--skip_lexical', action='store_true', help='Skip lexical metrics'
    )
    parser.add_argument(
        '--mutation_timeout',
        type=int,
        default=MUTATION_TIMEOUT,
        help='Mutation timeout',
    )
    parser.add_argument(
        '--mutation_buffer',
        type=int,
        default=MUTATION_BUFFER,
        help='Mutation buffer',
    )
    # `args` is read as a module-level global by process_instance.
    args, _ = parser.parse_known_args()

    dataset: list[TestGenEvalInstance] = load_testgeneval_dataset(
        args.dataset, args.split
    )

    logger.info(
        f'Loaded dataset {args.dataset} with split {args.split} to run inference on.'
    )

    # Load predictions
    assert args.input_file.endswith('.jsonl'), 'Input file must be a jsonl file.'
    predictions = pd.read_json(args.input_file, lines=True)
    assert (
        'instance_id' in predictions.columns
    ), 'Input file must contain instance_id column.'

    # The test suite may be a top-level column or nested inside `test_result`;
    # reject the input only when it is available in neither place.
    if 'test_suite' not in predictions.columns and not (
        'test_result' in predictions.columns
        and 'test_suite' in predictions['test_result'].iloc[0]
    ):
        raise ValueError(
            'Input file must contain test_suite column OR test_result column with test_suite field.'
        )

    if 'instance_id_swebench' not in predictions.columns:
        predictions['instance_id_swebench'] = predictions['instance'].apply(
            lambda x: x['instance_id_swebench']
        )

    # Same rule for the instance id: top-level column or nested in `instance`.
    # (Currently unreachable because of the assert above; kept as a guard.)
    if 'instance_id' not in predictions.columns and not (
        'instance_id' in predictions['instance'].iloc[0]
    ):
        raise ValueError(
            'Input file must contain id column OR instance column with id field.'
        )

    if 'instance_id' not in predictions.columns:
        predictions['instance_id'] = predictions['instance'].apply(
            lambda x: x['instance_id']
        )

    if 'test_suite' not in predictions.columns:
        predictions['test_suite'] = predictions['test_result'].apply(
            lambda x: x['test_suite']
        )

    assert len(predictions['instance_id'].unique()) == len(
        predictions
    ), 'instance_id column must be unique.'

    assert {'instance_id_swebench', 'test_suite', 'instance_id'}.issubset(
        set(predictions.columns)
    ), 'Input file must contain id, instance_id and test_suite columns.'

    predictions['test_spec'] = predictions['instance'].apply(
        lambda x: make_test_spec(x, args.mutation_timeout, args.mutation_buffer)
    )

    output_file = args.input_file.replace('.jsonl', '.testgeneval.jsonl')
    instances = prepare_dataset(predictions, output_file, args.eval_n_limit)

    # If possible, load the relevant metadata to avoid issues with `run_evaluation`.
    metadata: EvalMetadata | None = None
    metadata_filepath = os.path.join(os.path.dirname(args.input_file), 'metadata.json')
    if os.path.exists(metadata_filepath):
        with open(metadata_filepath, 'r') as metadata_file:
            data = metadata_file.read()
            metadata = EvalMetadata.model_validate_json(data)

    # The evaluation harness constrains the signature of `process_instance_func` but we need to
    # pass extra information. Build a new function object to avoid issues with multiprocessing.
    process_instance_func = partial(
        process_instance, log_dir=output_file.replace('.jsonl', '.logs')
    )

    # NOTE(review): the metadata loaded above is not forwarded; `metadata=None`
    # is passed instead. Confirm whether that is intentional.
    run_evaluation(
        instances,
        metadata=None,
        output_file=output_file,
        num_workers=args.eval_num_workers,
        process_instance_func=process_instance_func,
    )

    # Load evaluated predictions & print number of resolved predictions
    evaluated_predictions = pd.read_json(output_file, lines=True)
    report_fields = [
        'coverage',
        'mutation_score',
        'tests_pass',
        'all_tests_pass',
        'empty_generation',
        'coverage_success',
        'test_timeout',
        'error_eval',
    ]
    lexical_fields = [
        'pred_loc',
        'gold_loc',
        'pred_methods',
        'gold_methods',
        'code_bleu',
        'bleu',
        'xmatch',
        'edit_sim',
        'rouge_f',
        'rouge_p',
        'rouge_r',
    ]

    # Log report and lexical fields
    count_and_log_fields(evaluated_predictions, report_fields, key='report')
    count_and_log_fields(evaluated_predictions, lexical_fields, key='lexical')
|
||||
@@ -0,0 +1,291 @@
|
||||
import re
|
||||
|
||||
from evaluation.benchmarks.testgeneval.constants import TestStatus
|
||||
|
||||
|
||||
def parse_log_pytest(log: str) -> dict[str, str]:
    """
    Parser for test logs generated with PyTest framework

    Only lines that begin with one of the `TestStatus` values are considered.
    The first whitespace-separated token is taken as the status and the second
    as the test name.

    Args:
        log (str): log content
    Returns:
        dict: test case to test status mapping
    """
    status_prefixes = tuple(status.value for status in TestStatus)
    failed_prefix = TestStatus.FAILED.value

    statuses = {}
    for line in log.split('\n'):
        if not line.startswith(status_prefixes):
            continue
        # FAILED lines carry a " - <reason>" tail; turn the separator into a
        # plain space so it does not show up as a token of its own.
        if line.startswith(failed_prefix):
            line = line.replace(' - ', ' ')
        tokens = line.split()
        if len(tokens) > 1:
            statuses[tokens[1]] = tokens[0]
    return statuses
|
||||
|
||||
|
||||
def parse_log_pytest_options(log: str) -> dict[str, str]:
    """
    Parser for test logs generated with PyTest framework with options

    Works like the plain PyTest parser, but when a test name has a bracketed
    option (``name[option]``) and the option starts with a single ``/`` and
    contains no ``*``, the option is shortened to ``/`` plus its last path
    component.

    Args:
        log (str): log content
    Returns:
        dict: test case to test status mapping
    """
    bracket_pattern = re.compile(r'(.*?)\[(.*)\]')
    status_prefixes = tuple(status.value for status in TestStatus)
    failed_prefix = TestStatus.FAILED.value

    statuses = {}
    for line in log.split('\n'):
        if not line.startswith(status_prefixes):
            continue
        # Additional parsing for FAILED status
        if line.startswith(failed_prefix):
            line = line.replace(' - ', ' ')
        tokens = line.split()
        if len(tokens) <= 1:
            continue

        status, test_name = tokens[0], tokens[1]
        bracketed = bracket_pattern.search(test_name)
        if bracketed:
            main, option = bracketed.groups()
            looks_like_path = (
                option.startswith('/')
                and not option.startswith('//')
                and '*' not in option
            )
            if looks_like_path:
                option = '/' + option.split('/')[-1]
            test_name = f'{main}[{option}]'
        statuses[test_name] = status
    return statuses
|
||||
|
||||
|
||||
def parse_log_django(log: str) -> dict[str, str]:
|
||||
"""
|
||||
Parser for test logs generated with Django tester framework
|
||||
|
||||
Args:
|
||||
log (str): log content
|
||||
Returns:
|
||||
dict: test case to test status mapping
|
||||
"""
|
||||
test_status_map = {}
|
||||
lines = log.split('\n')
|
||||
|
||||
prev_test = None
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
|
||||
# This isn't ideal but the test output spans multiple lines
|
||||
if '--version is equivalent to version' in line:
|
||||
test_status_map['--version is equivalent to version'] = (
|
||||
TestStatus.PASSED.value
|
||||
)
|
||||
|
||||
# Log it in case of error
|
||||
if ' ... ' in line:
|
||||
prev_test = line.split(' ... ')[0]
|
||||
|
||||
pass_suffixes = (' ... ok', ' ... OK', ' ... OK')
|
||||
for suffix in pass_suffixes:
|
||||
if line.endswith(suffix):
|
||||
# TODO: Temporary, exclusive fix for django__django-7188
|
||||
# The proper fix should involve somehow getting the test results to
|
||||
# print on a separate line, rather than the same line
|
||||
if line.strip().startswith(
|
||||
'Applying sites.0002_alter_domain_unique...test_no_migrations'
|
||||
):
|
||||
line = line.split('...', 1)[-1].strip()
|
||||
test = line.rsplit(suffix, 1)[0]
|
||||
test_status_map[test] = TestStatus.PASSED.value
|
||||
break
|
||||
if ' ... skipped' in line:
|
||||
test = line.split(' ... skipped')[0]
|
||||
test_status_map[test] = TestStatus.SKIPPED.value
|
||||
if line.endswith(' ... FAIL'):
|
||||
test = line.split(' ... FAIL')[0]
|
||||
test_status_map[test] = TestStatus.FAILED.value
|
||||
if line.startswith('FAIL:'):
|
||||
test = line.split()[1].strip()
|
||||
test_status_map[test] = TestStatus.FAILED.value
|
||||
if line.endswith(' ... ERROR'):
|
||||
test = line.split(' ... ERROR')[0]
|
||||
test_status_map[test] = TestStatus.ERROR.value
|
||||
if line.startswith('ERROR:'):
|
||||
test = line.split()[1].strip()
|
||||
test_status_map[test] = TestStatus.ERROR.value
|
||||
|
||||
if line.lstrip().startswith('ok') and prev_test is not None:
|
||||
# It means the test passed, but there's some additional output (including new lines)
|
||||
# between "..." and "ok" message
|
||||
test = prev_test
|
||||
test_status_map[test] = TestStatus.PASSED.value
|
||||
|
||||
# TODO: This is very brittle, we should do better
|
||||
# There's a bug in the django logger, such that sometimes a test output near the end gets
|
||||
# interrupted by a particular long multiline print statement.
|
||||
# We have observed this in one of 3 forms:
|
||||
# - "{test_name} ... Testing against Django installed in {*} silenced.\nok"
|
||||
# - "{test_name} ... Internal Server Error: \/(.*)\/\nok"
|
||||
# - "{test_name} ... System check identified no issues (0 silenced).\nok"
|
||||
patterns = [
|
||||
r'^(.*?)\s\.\.\.\sTesting\ against\ Django\ installed\ in\ ((?s:.*?))\ silenced\)\.\nok$',
|
||||
r'^(.*?)\s\.\.\.\sInternal\ Server\ Error:\ \/(.*)\/\nok$',
|
||||
r'^(.*?)\s\.\.\.\sSystem check identified no issues \(0 silenced\)\nok$',
|
||||
]
|
||||
for pattern in patterns:
|
||||
for match in re.finditer(pattern, log, re.MULTILINE):
|
||||
test_name = match.group(1)
|
||||
test_status_map[test_name] = TestStatus.PASSED.value
|
||||
return test_status_map
|
||||
|
||||
|
||||
def parse_log_pytest_v2(log: str) -> dict[str, str]:
|
||||
"""
|
||||
Parser for test logs generated with PyTest framework (Later Version)
|
||||
|
||||
Args:
|
||||
log (str): log content
|
||||
Returns:
|
||||
dict: test case to test status mapping
|
||||
"""
|
||||
test_status_map = {}
|
||||
escapes = ''.join([chr(char) for char in range(1, 32)])
|
||||
for line in log.split('\n'):
|
||||
line = re.sub(r'\[(\d+)m', '', line)
|
||||
translator = str.maketrans('', '', escapes)
|
||||
line = line.translate(translator)
|
||||
if any([line.startswith(x.value) for x in TestStatus]):
|
||||
if line.startswith(TestStatus.FAILED.value):
|
||||
line = line.replace(' - ', ' ')
|
||||
test_case = line.split()
|
||||
if len(test_case) >= 2:
|
||||
test_status_map[test_case[1]] = test_case[0]
|
||||
# Support older pytest versions by checking if the line ends with the test status
|
||||
elif any([line.endswith(x.value) for x in TestStatus]):
|
||||
test_case = line.split()
|
||||
if len(test_case) >= 2:
|
||||
test_status_map[test_case[0]] = test_case[1]
|
||||
return test_status_map
|
||||
|
||||
|
||||
def parse_log_seaborn(log: str) -> dict[str, str]:
|
||||
"""
|
||||
Parser for test logs generated with seaborn testing framework
|
||||
|
||||
Args:
|
||||
log (str): log content
|
||||
Returns:
|
||||
dict: test case to test status mapping
|
||||
"""
|
||||
test_status_map = {}
|
||||
for line in log.split('\n'):
|
||||
if line.startswith(TestStatus.FAILED.value):
|
||||
test_case = line.split()[1]
|
||||
test_status_map[test_case] = TestStatus.FAILED.value
|
||||
elif f' {TestStatus.PASSED.value} ' in line:
|
||||
parts = line.split()
|
||||
if parts[1] == TestStatus.PASSED.value:
|
||||
test_case = parts[0]
|
||||
test_status_map[test_case] = TestStatus.PASSED.value
|
||||
elif line.startswith(TestStatus.PASSED.value):
|
||||
parts = line.split()
|
||||
test_case = parts[1]
|
||||
test_status_map[test_case] = TestStatus.PASSED.value
|
||||
return test_status_map
|
||||
|
||||
|
||||
def parse_log_sympy(log: str) -> dict[str, str]:
    """
    Parser for test logs generated with Sympy framework

    Two passes are made over the log. First, every match of
    ``<underscores> <path>.py:<name> <underscores>`` is recorded as FAILED
    under the key ``<path>.py:<name>``. Second, each stripped line starting
    with ``test_`` is classified by its trailing marker (after dropping a
    trailing ``[FAIL]`` / ``[OK]`` tag): `` E`` -> ERROR, `` F`` -> FAILED,
    `` ok`` -> PASSED, keyed by the first whitespace token.

    Args:
        log (str): log content
    Returns:
        dict: test case to test status mapping
    """
    results = {}

    for _, module, name, _ in re.findall(r'(_*) (.*)\.py:(.*) (_*)', log):
        results[f'{module}.py:{name}'] = TestStatus.FAILED.value

    for raw in log.split('\n'):
        entry = raw.strip()
        if not entry.startswith('test_'):
            continue
        if entry.endswith(('[FAIL]', '[OK]')):
            entry = entry[: entry.rfind('[')].strip()
        # The three suffixes are mutually exclusive, so at most one applies.
        if entry.endswith(' E'):
            results[entry.split()[0]] = TestStatus.ERROR.value
        elif entry.endswith(' F'):
            results[entry.split()[0]] = TestStatus.FAILED.value
        elif entry.endswith(' ok'):
            results[entry.split()[0]] = TestStatus.PASSED.value
    return results
|
||||
|
||||
|
||||
def parse_log_matplotlib(log: str) -> dict[str, str]:
    """
    Parser for matplotlib test logs (PyTest output).

    Only lines that start with a status keyword are considered; they are
    read as ``STATUS test_name``. Before matching, ``MouseButton.LEFT`` and
    ``MouseButton.RIGHT`` are rewritten to ``1`` and ``3``.

    Args:
        log (str): log content
    Returns:
        dict: test case to test status mapping
    """
    statuses = tuple(status.value for status in TestStatus)

    results = {}
    for raw in log.split('\n'):
        # NOTE(review): presumably normalises enum reprs inside parametrised
        # test ids so names are stable across versions - confirm.
        entry = raw.replace('MouseButton.LEFT', '1').replace('MouseButton.RIGHT', '3')
        if not entry.startswith(statuses):
            continue
        if entry.startswith(TestStatus.FAILED.value):
            # Additional parsing for FAILED status: drop the " - " separator.
            entry = entry.replace(' - ', ' ')
        fields = entry.split()
        if len(fields) > 1:
            results[fields[1]] = fields[0]
    return results
|
||||
|
||||
|
||||
# Per-repository parser names. Each is an alias of one of the shared
# pytest-style parsers; the names are what MAP_REPO_TO_PARSER refers to.

# Aliases of parse_log_pytest
parse_log_astroid = parse_log_flask = parse_log_marshmallow = parse_log_pytest
parse_log_pvlib = parse_log_pyvista = parse_log_pytest
parse_log_sqlfluff = parse_log_xarray = parse_log_pytest

# Aliases of parse_log_pytest_options
parse_log_pydicom = parse_log_requests = parse_log_pylint = parse_log_pytest_options

# Aliases of parse_log_pytest_v2
parse_log_astropy = parse_log_scikit = parse_log_sphinx = parse_log_pytest_v2
|
||||
|
||||
|
||||
# Maps a GitHub "owner/repo" slug to the log-parsing function used for that
# repository's test output. Each value takes the raw log text and returns a
# {test case: status} dict.
MAP_REPO_TO_PARSER = {
    'astropy/astropy': parse_log_astropy,
    'django/django': parse_log_django,
    'marshmallow-code/marshmallow': parse_log_marshmallow,
    'matplotlib/matplotlib': parse_log_matplotlib,
    'mwaskom/seaborn': parse_log_seaborn,
    'pallets/flask': parse_log_flask,
    'psf/requests': parse_log_requests,
    'pvlib/pvlib-python': parse_log_pvlib,
    'pydata/xarray': parse_log_xarray,
    'pydicom/pydicom': parse_log_pydicom,
    'pylint-dev/astroid': parse_log_astroid,
    'pylint-dev/pylint': parse_log_pylint,
    'pytest-dev/pytest': parse_log_pytest,
    'pyvista/pyvista': parse_log_pyvista,
    'scikit-learn/scikit-learn': parse_log_scikit,
    'sqlfluff/sqlfluff': parse_log_sqlfluff,
    'sphinx-doc/sphinx': parse_log_sphinx,
    'sympy/sympy': parse_log_sympy,
}
|
||||
@@ -0,0 +1,311 @@
|
||||
import sys
|
||||
from typing import Callable, Dict, List, Optional, Sequence, TypeVar, Union
|
||||
|
||||
import nltk
|
||||
import numpy as np
|
||||
from fuzzywuzzy import fuzz
|
||||
from rouge import Rouge
|
||||
|
||||
|
||||
|
||||
# Increase the recursion depth so ROUGE can be calculated for long sentences.
# The limit is only ever raised: an interpreter already configured with a
# limit of 10_000 or more is left untouched.
if sys.getrecursionlimit() < 10_000:
    sys.setrecursionlimit(10_000)
|
||||
|
||||
def bleu(gold: List[str], pred: List[str]) -> float:
    """
    Sentence-level BLEU of ``pred`` against the single reference ``gold``,
    using smoothing method 2 with auto reweighting, in the range of 0~100.

    An empty ``gold`` or ``pred`` scores 0.0 without calling NLTK.

    :param gold: list of gold tokens
    :param pred: list of predicted tokens
    :return: BLEU score
    """
    if not len(pred) or not len(gold):
        return 0.0

    bleu_score = nltk.translate.bleu_score
    raw_score = bleu_score.sentence_bleu(
        [gold],
        pred,
        smoothing_function=bleu_score.SmoothingFunction().method2,
        auto_reweigh=True,
    )
    return 100.0 * raw_score
|
||||
|
||||
|
||||
def batch_bleu(golds: List[List[str]], preds: List[List[str]]) -> List[float]:
    """
    Score each (gold, pred) pair with :func:`bleu`.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :return: list of BLEU scores, one per pair, in input order
    :raises ValueError: if ``golds`` and ``preds`` differ in length
    """
    if len(golds) != len(preds):
        raise ValueError("golds and preds must have the same length")

    scores = []
    for gold, pred in zip(golds, preds):
        scores.append(bleu(gold, pred))
    return scores
|
||||
|
||||
|
||||
def corpus_bleu(golds: List[List[str]], preds: List[List[str]]) -> float:
    """
    Corpus-level BLEU over all (gold, pred) pairs, in the range of 0~100.

    Each prediction is scored against exactly one reference (its gold
    sentence), using smoothing method 2 with auto reweighting.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :return: corpus-level BLEU score
    :raises ValueError: if ``golds`` and ``preds`` differ in length
    """
    if len(golds) != len(preds):
        raise ValueError("golds and preds must have the same length")

    bleu_score = nltk.translate.bleu_score
    # NLTK expects a list of reference *lists*; wrap each gold on its own.
    references = [[gold] for gold in golds]
    raw_score = bleu_score.corpus_bleu(
        references,
        preds,
        smoothing_function=bleu_score.SmoothingFunction().method2,
        auto_reweigh=True,
    )
    return 100.0 * raw_score
|
||||
|
||||
|
||||
def edit_sim(
    gold: Union[str, List[str]], pred: Union[str, List[str]], sep: str = " "
) -> float:
    """
    Char-level edit similarity (``fuzz.ratio``), in the range of 0~100.

    Token lists are joined with ``sep`` before comparison. An empty ``gold``
    or ``pred`` scores 0.0.

    :param gold: gold sentence or list of gold tokens
    :param pred: predicted sentence or list of predicted tokens
    :param sep: separator between tokens
    :return: char-level edit similarity
    """
    if not len(pred) or not len(gold):
        return 0.0

    gold_text = sep.join(gold) if isinstance(gold, list) else gold
    pred_text = sep.join(pred) if isinstance(pred, list) else pred
    return fuzz.ratio(gold_text, pred_text)
|
||||
|
||||
|
||||
def batch_edit_sim(
    golds: List[Union[str, List[str]]],
    preds: List[Union[str, List[str]]],
    sep: str = " ",
) -> List[float]:
    """
    Score each (gold, pred) pair with :func:`edit_sim`.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :param sep: separator between tokens
    :return: list of char-level edit similarity, one per pair, in input order
    :raises ValueError: if ``golds`` and ``preds`` differ in length
    """
    if len(golds) != len(preds):
        raise ValueError("golds and preds must have the same length")

    similarities = []
    for gold, pred in zip(golds, preds):
        similarities.append(edit_sim(gold, pred, sep))
    return similarities
|
||||
|
||||
|
||||
T = TypeVar("T")


def exact_match(gold: T, pred: T) -> float:
    """
    Exact-match accuracy: 100.0 when ``gold == pred``, otherwise 0.0.

    An empty ``gold`` or ``pred`` always scores 0.0, even when both are empty.

    :param gold: gold sentence or list of gold tokens
    :param pred: predicted sentence or list of predicted tokens
    :return: exact match accuracy
    """
    if not len(pred) or not len(gold):
        return 0.0
    if gold == pred:
        return 100.0
    return 0.0
|
||||
|
||||
|
||||
def batch_exact_match(golds: List[T], preds: List[T]) -> List[float]:
    """
    Score each (gold, pred) pair with :func:`exact_match`.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :return: list of exact match accuracy, one per pair, in input order
    :raises ValueError: if ``golds`` and ``preds`` differ in length
    """
    if len(golds) != len(preds):
        raise ValueError("golds and preds must have the same length")

    matches = []
    for gold, pred in zip(golds, preds):
        matches.append(exact_match(gold, pred))
    return matches
|
||||
|
||||
|
||||
def rouge_l(
    gold: Union[str, List[str]], pred: Union[str, List[str]], sep: str = " "
) -> Dict[str, float]:
    """
    ROUGE-L precision, recall and F1 of ``pred`` against ``gold``, each in
    the range of 0~100.

    Token lists are joined with ``sep`` first. An empty input, or a
    ``ValueError`` raised while scoring, yields all-zero scores.

    :param gold: gold sentence or list of gold tokens
    :param pred: predicted sentence or list of predicted tokens
    :param sep: separator between tokens
    :return: {"p": precision, "r": recall, "f": F1}
    """
    if not len(pred) or not len(gold):
        return {"p": 0.0, "r": 0.0, "f": 0.0}

    reference = sep.join(gold) if isinstance(gold, list) else gold
    hypothesis = sep.join(pred) if isinstance(pred, list) else pred

    try:
        all_scores = Rouge().get_scores(hyps=hypothesis, refs=reference, avg=True)
        rouge_l_scores = all_scores["rouge-l"]
        return {key: rouge_l_scores[key] * 100.0 for key in ["p", "r", "f"]}
    except ValueError:
        return {"p": 0.0, "r": 0.0, "f": 0.0}
|
||||
|
||||
|
||||
def batch_rouge_l(
    golds: List[Union[str, List[str]]],
    preds: List[Union[str, List[str]]],
    sep: str = " ",
) -> Dict[str, List[float]]:
    """
    Score each (gold, pred) pair with :func:`rouge_l` and regroup by metric.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :param sep: separator between tokens
    :return: {"p": [...], "r": [...], "f": [...]}, each list holding one
        value per pair in input order
    :raises ValueError: if ``golds`` and ``preds`` differ in length
    """
    if len(golds) != len(preds):
        raise ValueError("golds and preds must have the same length")

    per_pair = [rouge_l(gold, pred, sep) for gold, pred in zip(golds, preds)]

    by_metric = {}
    for key in ["p", "r", "f"]:
        by_metric[key] = [pair_scores[key] for pair_scores in per_pair]
    return by_metric
|
||||
|
||||
|
||||
def accuracy(
    gold: List[str],
    pred: List[str],
    ignore: Optional[Sequence[str]] = None,
) -> float:
    """
    Token-level accuracy over aligned positions, in the range of 0~100.

    Tokens are compared position by position up to the length of the
    shorter sequence (the longer one is effectively truncated). Positions
    whose gold token is in ``ignore`` count toward neither the matches nor
    the total. Returns 0.0 when either input is empty or when every
    compared position is ignored.

    :param gold: list of gold tokens
    :param pred: list of predicted tokens
    :param ignore: list of (gold) tokens to ignore
    :return: accuracy
    """
    if not len(pred) or not len(gold):
        return 0.0

    skipped = [] if ignore is None else ignore
    counted = 0
    correct = 0
    for gold_token, pred_token in zip(gold, pred):
        if gold_token in skipped:
            continue
        counted += 1
        if gold_token == pred_token:
            correct += 1

    if counted == 0:
        return 0.0
    return 100.0 * correct / counted
|
||||
|
||||
|
||||
def batch_accuracy(
    golds: List[List[str]],
    preds: List[List[str]],
    ignore: Optional[Sequence[str]] = None,
) -> List[float]:
    """
    Score each (gold, pred) pair with :func:`accuracy`.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :param ignore: list of (gold) tokens to ignore
    :return: list of accuracy, one per pair, in input order
    :raises ValueError: if ``golds`` and ``preds`` differ in length
    """
    if len(golds) != len(preds):
        raise ValueError("golds and preds must have the same length")

    accuracies = []
    for gold, pred in zip(golds, preds):
        accuracies.append(accuracy(gold, pred, ignore))
    return accuracies
|
||||
|
||||
|
||||
def first_match_to_topk(
    first_match_list: List[int], k_values: List[int]
) -> Dict[int, List[float]]:
    """
    Convert first-match ranks (1-indexed) into per-example top-k hits.

    For each ``k``, an example scores 100.0 when its first-match rank is
    ``<= k`` and 0.0 otherwise.

    :param first_match_list: first match ranks (1-indexed), one per example
    :param k_values: k values to consider
    :return: a mapping from k to top-k accuracies (ranging from 0~100)
    """
    topk = {}
    for k in k_values:
        hits = []
        for rank in first_match_list:
            hits.append(100.0 if rank <= k else 0.0)
        topk[k] = hits
    return topk
|
||||
|
||||
|
||||
def pass_at_k(n: int, c: int, k: int) -> float:
    """
    Sample pass@k metric according to the Codex paper, but in the scale of 0~100.

    When both ``n >= k`` and ``n - c >= k`` the product form
    ``1 - prod_{i=n-c+1..n} (1 - k / i)`` is used. Otherwise the function
    falls back to ``1 - (1 - c/n) ** k``.

    :param n: total number of samples
    :param c: number of correct samples
    :param k: k in pass@$k$
    :return: pass@k in the range of 0~100
    """
    if n >= k and (n - c) >= k:
        all_fail_prob = np.prod(1.0 - k / np.arange(n - c + 1, n + 1)).item()
        return (1.0 - all_fail_prob) * 100

    # fallback to the (1 - (1-p)^k) formula
    success_rate = c / n
    return (1 - (1 - success_rate) ** k) * 100
|
||||
|
||||
|
||||
def self_bleu(samples: List[List[str]]) -> float:
    """
    Mean BLEU of each sample scored against all the other samples as
    references (smoothing method 2, auto reweighting), scaled by 100.

    An empty ``samples`` list returns 100.0.

    :param samples: the chosen m samples
    :return: self-BLEU
    """
    if len(samples) == 0:
        return 100.0

    bleu_score = nltk.translate.bleu_score
    scores = []
    for idx, hypothesis in enumerate(samples):
        # Leave-one-out: every other sample acts as a reference.
        references = [other for j, other in enumerate(samples) if j != idx]
        raw_score = bleu_score.sentence_bleu(
            references,
            hypothesis,
            smoothing_function=bleu_score.SmoothingFunction().method2,
            auto_reweigh=True,
        )
        scores.append(100.0 * raw_score)
    return np.mean(scores).item()
|
||||
|
||||
|
||||
def self_edit_distance(samples: List[Union[str, List[str]]], sep=" ") -> float:
    """
    Mean pairwise edit distance among the samples.

    Every ordered pair of distinct samples ``(i, j)`` contributes
    ``100 - fuzz.ratio(sample_i, sample_j)``; the result is the mean over
    all such pairs. Token-list samples are joined with ``sep`` first.

    With fewer than two samples there are no pairs to compare, so 0.0 is
    returned (previously a single sample produced NaN from the mean of an
    empty list).

    :param samples: the chosen m samples
    :param sep: the separator between tokens
    :return: self-edit-distance
    """
    if len(samples) < 2:
        return 0.0

    # Join each sample once up front instead of once per pair.
    texts = [
        sample if isinstance(sample, str) else sep.join(sample)
        for sample in samples
    ]

    scores = [
        100 - fuzz.ratio(text_i, text_j)
        for i, text_i in enumerate(texts)
        for j, text_j in enumerate(texts)
        if i != j
    ]
    return np.mean(scores).item()
|
||||
|
||||
|
||||
|
||||
# Registry of per-example quality metrics keyed by a short metric name.
# Every entry is called as fn(gold, pred) and returns a single float; the
# three rouge-* entries each pick one component out of rouge_l's result dict.
QUALITY_METRICS: Dict[str, Callable[[List[str], List[str]], float]] = {
    "bleu": bleu,
    "xmatch": exact_match,
    "edit-sim": edit_sim,
    "rouge-f": lambda g, p: rouge_l(g, p)["f"],
    "rouge-p": lambda g, p: rouge_l(g, p)["p"],
    "rouge-r": lambda g, p: rouge_l(g, p)["r"],
}
|
||||
@@ -0,0 +1,114 @@
|
||||
CODEACT_TESTGEN_PROMPT_OLD = """Your goal is to generate a high-quality test suite (at least 20+ passing tests) for the code file: {code_file}. Output the test suite at {test_file}\n'
|
||||
|
||||
[current directory: /workspace/{workspace_dir_name}]
|
||||
|
||||
IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP
|
||||
|
||||
IMPORTANT: Follow instructions, if you have < 80 tests you should generate more tests rather than trying to fix the ones you have.
|
||||
|
||||
IMPORTANT: Code file to test:
|
||||
```python
|
||||
{code_src}
|
||||
```
|
||||
|
||||
Here are additional imports that you may need:
|
||||
{imports}
|
||||
|
||||
Look at code dependencies (NOT {code_file} since you already have contents) and test files you need context for to write a complete test suite.
|
||||
|
||||
Aim for 20+ test functions with asserts. Do not hestitate to use the Python interpreter to understand the input output behavior of the code you are testing.
|
||||
|
||||
Output your test suite at {test_file}. Each unit test must be a function starting with test_. Include all your test imports and setup before your first test. Do not include a main method to run the tests. Make sure to make it as comprehensive as possible, try to execute all the methods you saw.
|
||||
|
||||
When you think you've successfully generated a test suite, run it on for the current project using {coverage_command}.
|
||||
|
||||
If you have few tests GENERATE MORE TESTS rather than trying to fix the ones you have (it is possible to filter out failing tests later).
|
||||
|
||||
Then run coverage report -m --include {code_file} to see how well your test suite covers the code under test.
|
||||
|
||||
When you are trying to improve coverage pick a part of the code that is not covered (indicated by lines on coverage report), examine the code and then
|
||||
try to generate a test for it. Feel free to use a code interpreter to understand the input output behavior. ONLY add tests
|
||||
not remove them.
|
||||
|
||||
If you are unable to see passing and failing tests, FIX YOUR IMPORTS to use the same style as other test files.
|
||||
|
||||
You should NOT modify any existing test case files. You SHOULD add new test in a NEW file to reproduce the issue.
|
||||
|
||||
You should NEVER use web browsing or any other web-based tools.
|
||||
|
||||
You should NEVER install new packages, use existing packages only.
|
||||
|
||||
You should ALWAYS use the default Python interpreter available in the <execute_bash> environment to run code related to the provided issue and/or repository.
|
||||
|
||||
You should ALWAYS use local imports DO NOT import the general library.
|
||||
|
||||
When you think you have a fully adequate test suite, please run the following command: <execute_bash> exit </execute_bash>.
|
||||
"""
|
||||
|
||||
CODEACT_TESTGEN_PROMPT = """
|
||||
Your goal is to generate a comprehensive, **broad-coverage** test suite for the code below, ensuring you test as many lines and branches as possible on the first attempt.
|
||||
|
||||
Place your test suite in a new file named {test_file}.
|
||||
|
||||
IMPORTANT REQUIREMENTS:
|
||||
1. **No external help or resources**—use only the snippet below.
|
||||
2. **Focus on breadth over depth**: cover all major functions, classes, and code paths early to minimize coverage iterations.
|
||||
3. Each test function must start with `test_` and use `assert` to verify behavior.
|
||||
4. Include only necessary imports (standard library or local).
|
||||
5. Do **not** modify existing test files—create a brand new one. No `main()` or other non-test code.
|
||||
6. Produce **at least 20 test functions**; if coverage is lacking, add more tests rather than removing or changing existing ones.
|
||||
7. Use the following commands to check coverage:
|
||||
<execute_bash> {coverage_command} </execute_bash>
|
||||
<execute_bash> coverage report -m --include {code_file} </execute_bash>
|
||||
If lines remain uncovered, add new tests targeting them specifically.
|
||||
8. When you're satisfied with coverage, finalize by running:
|
||||
<execute_bash> exit </execute_bash>
|
||||
|
||||
Below is the **complete code snippet** to test:
|
||||
|
||||
<START_OF_CODE>
|
||||
{code_src}
|
||||
<END_OF_CODE>
|
||||
|
||||
NOTE: if you are testing django, you must use from django.test import SimpleTestCase and class based tests (i.e. class TestSomething(SimpleTestCase)).
|
||||
NOTE: if there is an error executing tests you MUST fix it before exiting. DO NOT install new packages.
|
||||
NOTE: if outputting a revised test suite REPLACE {test_file} with the revised suite
|
||||
|
||||
**Output the final test suite** (20+ tests) for {test_file} in a single code block, no extra commentary. MAKE SURE you run the tests and ensure you can see which tests passed and failed BEFORE exiting.
|
||||
"""
|
||||
|
||||
CODEACT_TESTGEN_PROMPT_ITERATE = """
|
||||
Your goal is to improve the test suite at {test_file} to achieve **broad-coverage** of the code below.
|
||||
|
||||
First run the test suite.
|
||||
|
||||
If no tests run, then remove {test_file} and create {test_file} with a new suite.
|
||||
|
||||
Otherwise, improve it aiming to improve code coverage.
|
||||
|
||||
IMPORTANT REQUIREMENTS:
|
||||
1. Use the following commands to check coverage (RUN THIS FIRST):
|
||||
<execute_bash> {coverage_command} </execute_bash>
|
||||
<execute_bash> coverage report -m --include {code_file} </execute_bash>
|
||||
If lines remain uncovered, add new tests targeting them specifically.
|
||||
2. **No external help or resources**—use only the snippet below.
|
||||
3. **Focus on breadth over depth**: cover all major functions, classes, and code paths early to minimize coverage iterations.
|
||||
4. Each test function must use `assert` to verify behavior.
|
||||
5. Include only necessary imports (standard library or local).
|
||||
6. Do **not** modify other test files in the repository. No `main()` or other non-test code.
|
||||
7. Produce **at least 20 test functions**; if coverage is lacking, add more tests rather than removing or changing existing ones.
|
||||
8. When you're satisfied with coverage, finalize by running:
|
||||
<execute_bash> exit </execute_bash>
|
||||
|
||||
Below is the **complete code snippet** to test:
|
||||
|
||||
<START_OF_CODE>
|
||||
{code_src}
|
||||
<END_OF_CODE>
|
||||
|
||||
NOTE: if you are testing django, you must use from django.test import SimpleTestCase and class based tests (i.e. class TestSomething(SimpleTestCase)).
|
||||
NOTE: if there is an error executing tests you MUST fix it before exiting. DO NOT install new packages.
|
||||
NOTE: if outputting a revised test suite REPLACE {test_file} with the revised suite
|
||||
|
||||
**Output the final test suite** (20+ tests) for {test_file} in a single code block, no extra commentary. MAKE SURE you run the tests and ensure you can see which tests passed and failed BEFORE exiting.
|
||||
"""
|
||||
@@ -0,0 +1,31 @@
|
||||
import re
|
||||
from pygments.lexers.python import PythonLexer
|
||||
|
||||
def tokenize_code(code):
    """Lex ``code`` with Pygments' ``PythonLexer`` and return the token
    strings produced by ``process_pygments_tokens``."""
    raw_tokens = PythonLexer().get_tokens(code)
    return process_pygments_tokens(raw_tokens)
|
||||
|
||||
def process_pygments_tokens(tokens):
    """
    Reduce a stream of Pygments ``(token_type, text)`` pairs to a list of
    token strings.

    Two transformations are applied:

    1. Whitespace is dropped: ``Token.Text`` entries whose text starts with
       whitespace, and all ``Token.Text.Whitespace`` entries.
    2. Each consecutive triple ``'"'``, ``'STR'``, ``'"'`` is merged into
       the single token ``'"STR"'``.

    Args:
        tokens: iterable of ``(token_type, text)`` pairs.

    Returns:
        list: the filtered and merged token strings.
    """
    new_tokens = []
    for token in tokens:
        kind = str(token[0])
        is_blank_text = kind == "Token.Text" and re.match(r'\s+', token[1])
        if is_blank_text or kind == "Token.Text.Whitespace":
            continue
        new_tokens.append(token[1])

    new_tokens_final = []
    i = 0
    # A triple needs three tokens, so stop scanning two before the end.
    while i < len(new_tokens) - 2:
        if new_tokens[i] == '"' and new_tokens[i + 1] == 'STR' and new_tokens[i + 2] == '"':
            new_tokens_final.append("\"STR\"")
            i += 3
        else:
            new_tokens_final.append(new_tokens[i])
            i += 1

    # Copy whatever the scan did not consume. Resuming from ``i`` (rather
    # than always from the last two tokens) avoids re-emitting the pieces of
    # a triple that was merged at the very end of the stream.
    new_tokens_final.extend(new_tokens[i:])

    return new_tokens_final
|
||||
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
def check_coverage(coverage_output, code_file):
    """
    Look up one file's coverage percentage in a JSON coverage report.

    Args:
        coverage_output (str): JSON text with a top-level ``files`` mapping.
        code_file (str): key to look up under ``files``.

    Returns:
        tuple: ``(True, percent_covered)`` when ``code_file`` is present in
        the report, otherwise ``(False, 0)``.
    """
    per_file = json.loads(coverage_output)['files']
    if code_file not in per_file:
        return False, 0
    return True, per_file[code_file]['summary']['percent_covered']
|
||||
|
||||
|
||||
def check_mutation(mutation_output):
    """
    Extract the mutant count and mutation score from mutation-tool output.

    The output must contain ``total jobs: <int>`` and its final line must be
    three space-separated numbers ``low val high``. Leading and trailing
    whitespace on that final line is ignored (previously it passed the
    length check but then failed to unpack).
    NOTE(review): ``val`` is presumably the surviving-mutant percentage with
    ``low``/``high`` as its confidence bounds - confirm against the tool.

    Args:
        mutation_output (str): raw text emitted by the mutation run.

    Returns:
        tuple: ``(True, num_mutants, 100 - val, high - val)`` on success,
        otherwise ``(False, -1, 0, -1)``.
    """
    failure = (False, -1, 0, -1)

    if 'total jobs: ' not in mutation_output:
        return failure

    num_mutants = int(mutation_output.split('total jobs: ')[1].split('\n')[0])

    parts = mutation_output.split('\n')[-1].strip().split(' ')
    if len(parts) != 3:
        return failure

    # ``low`` is parsed for validation only; it is not part of the result.
    _low, val, high = (float(part) for part in parts)

    confidence_range = high - val
    mutation_score = 100 - val
    return True, num_mutants, mutation_score, confidence_range
|
||||
|
||||
|
||||
def count_methods(code_str):
    """
    Count Python function/method definitions in a string of code.

    The count is purely textual: every occurrence of ``def <name>(`` is
    counted, including any that appear inside comments or string literals.

    Args:
        code_str (str): A string containing code.

    Returns:
        int: The number of ``def`` headers found.
    """
    def_header = r'\bdef\b\s+\w+\s*\('
    return sum(1 for _ in re.finditer(def_header, code_str))
|
||||
|
||||
|
||||
def get_lines_of_code(code_str):
    """
    Count the lines in a string of code.

    Leading and trailing whitespace (including blank lines) is stripped
    first; interior blank lines are counted. An empty or all-whitespace
    string counts as one line.

    Args:
        code_str (str): A string containing code.

    Returns:
        int: The number of newline-separated lines.
    """
    trimmed = code_str.strip()
    return trimmed.count('\n') + 1
|
||||
@@ -0,0 +1,577 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import toml
|
||||
from datasets import load_dataset
|
||||
|
||||
import openhands.agenthub
|
||||
from evaluation.benchmarks.testgeneval.constants import MAP_REPO_VERSION_TO_SPECS
|
||||
from evaluation.benchmarks.testgeneval.prompt import (
|
||||
CODEACT_TESTGEN_PROMPT,
|
||||
CODEACT_TESTGEN_PROMPT_ITERATE,
|
||||
)
|
||||
from evaluation.benchmarks.testgeneval.utils import get_test_directives
|
||||
from evaluation.utils.shared import (
|
||||
EvalException,
|
||||
EvalMetadata,
|
||||
EvalOutput,
|
||||
assert_and_raise,
|
||||
codeact_user_response,
|
||||
get_metrics,
|
||||
is_fatal_evaluation_error,
|
||||
make_metadata,
|
||||
prepare_dataset,
|
||||
reset_logger_for_multiprocessing,
|
||||
run_evaluation,
|
||||
update_llm_config_for_completions_logging,
|
||||
)
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import (
|
||||
AgentConfig,
|
||||
AppConfig,
|
||||
SandboxConfig,
|
||||
get_llm_config_arg,
|
||||
get_parser,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.main import create_runtime, run_controller
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation, ErrorObservation
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
# True only when the RUN_WITH_BROWSING environment variable equals "true"
# (case-insensitive); defaults to False when the variable is unset.
RUN_WITH_BROWSING = os.environ.get('RUN_WITH_BROWSING', 'false').lower() == 'true'

# Agent class name -> fake-user-response function.
# NOTE(review): presumably looked up by agent class when running the
# controller; the lookup site is outside this section - confirm.
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
    'CodeActAgent': codeact_user_response,
}
|
||||
|
||||
|
||||
def _preprocess_instance(d):
    """Replace every ``np.ndarray`` value of ``d`` with a plain list
    (via ``tolist()``), in place, and return the same mapping."""
    array_keys = [key for key, value in d.items() if isinstance(value, np.ndarray)]
    for key in array_keys:
        d[key] = d[key].tolist()
    return d
|
||||
|
||||
|
||||
def _get_swebench_workspace_dir_name(instance: pd.Series) -> str:
    """Build the workspace directory name ``<repo>__<version>`` for an
    instance, with every ``/`` replaced by ``__``."""
    combined = '{}__{}'.format(instance.repo, instance.version)
    return combined.replace('/', '__')
|
||||
|
||||
|
||||
def get_instruction(instance: pd.Series, metadata: EvalMetadata):
    """
    Build the agent instruction text for one benchmark instance.

    The coverage command is the repo/version-specific ``test_cmd`` followed
    by the instance's test directives. The iterate prompt is used when the
    instance already carries a ``full_pred``; otherwise the from-scratch
    prompt is used. ``metadata`` is accepted but not read here.

    Args:
        instance (pd.Series): dataset row; reads ``repo``, ``version``,
            ``full_pred``, ``code_file``, ``test_file``, ``code_src`` and
            ``local_imports``.
        metadata (EvalMetadata): evaluation metadata (unused).

    Returns:
        str: the formatted instruction.
    """
    # Prepare instruction
    repo_specs = MAP_REPO_VERSION_TO_SPECS[instance['repo']][instance['version']]
    command_parts = [repo_specs['test_cmd'], *get_test_directives(instance)]
    coverage_command = ' '.join(command_parts)

    # Testing general agents
    if instance['full_pred'] is None:
        prompt_to_use = CODEACT_TESTGEN_PROMPT
    else:
        prompt_to_use = CODEACT_TESTGEN_PROMPT_ITERATE

    # Both files are addressed by their location under /testbed.
    format_args = {
        'code_file': os.path.join('/testbed', instance.code_file),
        'test_file': os.path.join('/testbed', instance.test_file),
        'coverage_command': coverage_command,
        'code_src': instance['code_src'],
        'imports': '\n'.join(instance.local_imports),
        'workspace_dir_name': _get_swebench_workspace_dir_name(instance),
    }
    instruction = prompt_to_use.format(**format_args)

    if RUN_WITH_BROWSING:
        browsing_notice = (
            '<IMPORTANT!>\n'
            'You SHOULD NEVER attempt to browse the web. '
            '</IMPORTANT!>\n'
        )
        instruction += browsing_notice

    return instruction
|
||||
|
||||
|
||||
# TODO: migrate all swe-bench docker to ghcr.io/openhands
# Registry/namespace prefix prepended to every per-instance image name.
# Override it with the EVAL_DOCKER_IMAGE_PREFIX environment variable.
DOCKER_IMAGE_PREFIX = os.environ.get('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/kdjain/')
# Logged once at import time so the prefix in effect is visible in run logs.
logger.info(f'Using docker image prefix: {DOCKER_IMAGE_PREFIX}')
|
||||
|
||||
|
||||
def get_instance_docker_image(instance_id: str) -> str:
    """
    Return the full docker image reference for an instance.

    The image name is ``sweb.eval.x86_64.<instance_id>`` with every ``__``
    replaced by ``_s_``, joined to ``DOCKER_IMAGE_PREFIX`` with exactly one
    ``/`` between them.
    """
    # '__' -> '_s_' to comply with docker image naming convention
    image_name = f'sweb.eval.x86_64.{instance_id}'.replace('__', '_s_')
    prefix = DOCKER_IMAGE_PREFIX.rstrip('/')
    return f'{prefix}/{image_name}'
|
||||
|
||||
|
||||
def get_config(
    instance: pd.Series,
    metadata: EvalMetadata,
) -> AppConfig:
    """
    Build the ``AppConfig`` used to run the agent on one instance.

    The sandbox uses the instance-specific container image (derived from
    ``instance['instance_id_swebench']``), no workspace is mounted, and the
    LLM config is wrapped for completions logging under the evaluation
    output directory, keyed by ``instance['id']``.

    Args:
        instance (pd.Series): dataset row for the instance.
        metadata (EvalMetadata): supplies agent class, iteration limit,
            LLM config, output directory and condenser config.

    Returns:
        AppConfig: the fully populated configuration.
    """
    # We use a different instance image for the each instance of TestGenEval
    base_container_image = get_instance_docker_image(instance['instance_id_swebench'])
    logger.info(
        f'Using instance container image: {base_container_image}. '
        'Please make sure this image exists. '
        'Submit an issue on https://github.com/All-Hands-AI/OpenHands if you run into any issues.'
    )

    sandbox_config = SandboxConfig(
        base_container_image=base_container_image,
        enable_auto_lint=True,
        use_host_network=False,
        # large enough timeout, since some testcases take very long to run
        timeout=300,
        # Add platform to the sandbox config to solve issue 4401
        platform='linux/amd64',
        api_key=os.environ.get('ALLHANDS_API_KEY', None),
        remote_runtime_api_url=os.environ.get(
            'SANDBOX_REMOTE_RUNTIME_API_URL', 'http://localhost:8000'
        ),
        keep_runtime_alive=False,
        remote_runtime_init_timeout=3600,
    )

    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        max_iterations=metadata.max_iterations,
        runtime=os.environ.get('RUNTIME', 'eventstream'),
        sandbox=sandbox_config,
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )

    llm_config = update_llm_config_for_completions_logging(
        metadata.llm_config, metadata.eval_output_dir, instance['id']
    )
    config.set_llm_config(llm_config)

    config.set_agent_config(
        AgentConfig(
            codeact_enable_jupyter=False,
            codeact_enable_browsing=RUN_WITH_BROWSING,
            codeact_enable_llm_editor=False,
            condenser=metadata.condenser_config,
        )
    )
    return config
|
||||
|
||||
|
||||
def initialize_runtime(
    runtime: Runtime,
    instance: pd.Series,  # dataset row describing the instance to set up
):
    """Initialize the runtime for the agent.

    This function is called before the runtime is used to run the agent.
    It exports instance-specific environment settings into ``~/.bashrc``,
    uploads the instance description (and, when ``instance['full_pred']`` is
    set, the starting test file, which is also committed in ``/testbed``),
    sources the SWE entry script and leaves the shell in a clean checkout of
    the workspace with all git remotes removed.

    Args:
        runtime: Runtime in which the shell commands are executed.
        instance: Dataset row. Note that ``instance['instance_id']`` is
            overwritten with ``instance['instance_id_swebench']``.

    Raises:
        Whatever ``assert_and_raise`` raises when a command exits non-zero.
    """

    def run_logged(command: str):
        # Run one shell command with a 600s hard timeout, logging both the
        # action and the resulting observation.
        action = CmdRunAction(command=command)
        action.set_hard_timeout(600)
        logger.info(action, extra={'msg_type': 'ACTION'})
        observation = runtime.run_action(action)
        logger.info(observation, extra={'msg_type': 'OBSERVATION'})
        return observation

    def run_checked(command: str, failure: str):
        # Run a command and fail with "<failure>: <observation>" on a
        # non-zero exit code.
        observation = run_logged(command)
        assert_and_raise(
            observation.exit_code == 0, f'{failure}: {str(observation)}'
        )
        return observation

    logger.info('-' * 30)
    logger.info('BEGIN Runtime Initialization Fn')
    logger.info('-' * 30)
    # Resolved before instance_id is overwritten below, as in the original order.
    workspace_dir_name = _get_swebench_workspace_dir_name(instance)

    instance['instance_id'] = instance['instance_id_swebench']

    # Set instance id
    run_checked(
        f"""echo 'export SWE_INSTANCE_ID={instance['instance_id_swebench']}' >> ~/.bashrc && echo 'export PIP_CACHE_DIR=~/.cache/pip' >> ~/.bashrc && echo "alias git='git --no-pager'" >> ~/.bashrc""",
        'Failed to export SWE_INSTANCE_ID',
    )
    run_checked(
        """export USER=$(whoami); echo USER=${USER} """, 'Failed to export USER'
    )

    # Directory of this script; the entry script is uploaded from it below.
    script_dir = os.path.dirname(__file__)

    # inject the instance info
    run_checked(
        'mkdir -p /swe_util/eval_data/instances',
        'Failed to create /swe_util/eval_data/instances',
    )

    swe_instance_json_name = 'swe-bench-instance.json'
    swe_prediction = 'test_suite.py'
    with tempfile.TemporaryDirectory() as temp_dir:
        # The file must carry this exact name, so it is written inside a
        # temporary directory rather than as an anonymous temp file.
        temp_file_path = os.path.join(temp_dir, swe_instance_json_name)
        instance_dict = instance if isinstance(instance, dict) else instance.to_dict()
        with open(temp_file_path, 'w') as f:
            json.dump([_preprocess_instance(instance_dict)], f)

        # Upload only after the file has been closed (and therefore flushed).
        runtime.copy_to(temp_file_path, '/swe_util/eval_data/instances/')

        if instance['full_pred'] is not None:
            temp_file_path_pred = os.path.join(temp_dir, swe_prediction)
            with open(temp_file_path_pred, 'w') as f:
                f.write(instance['full_pred'])

            runtime.copy_to(temp_file_path_pred, '/tmp')

            # Place the uploaded test file at its location in the repository.
            run_checked(
                f"cp /tmp/test_suite.py /testbed/{instance['test_file']}",
                'Failed to copy test file',
            )
            # Commit it so the later `git reset --hard` keeps the test file.
            run_checked(
                'git -C /testbed add . && git -C /testbed commit -m "Add test file"',
                'Failed to commit test file',
            )

    # inject the instance swe entry
    runtime.copy_to(
        str(os.path.join(script_dir, 'scripts/setup/instance_swe_entry.sh')),
        '/swe_util/',
    )
    run_checked('cat ~/.bashrc', 'Failed to cat ~/.bashrc')

    obs = run_logged('source ~/.bashrc')
    if isinstance(obs, ErrorObservation):
        logger.error(f'Failed to source ~/.bashrc: {str(obs)}')
    assert_and_raise(obs.exit_code == 0, f'Failed to source ~/.bashrc: {str(obs)}')

    run_checked(
        'source /swe_util/instance_swe_entry.sh',
        'Failed to source /swe_util/instance_swe_entry.sh',
    )
    run_checked(
        f'cd /workspace/{workspace_dir_name}',
        f'Failed to cd to /workspace/{workspace_dir_name}',
    )
    run_checked('git reset --hard', 'Failed to git reset --hard')
    run_checked(
        'for remote_name in $(git remote); do git remote remove "${remote_name}"; done',
        'Failed to remove git remotes',
    )

    logger.info('-' * 30)
    logger.info('END Runtime Initialization Fn')
    logger.info('-' * 30)
|
||||
|
||||
|
||||
def complete_runtime(
    runtime: Runtime,
    instance: pd.Series,  # used to resolve the workspace directory and the test file
) -> dict[str, Any]:
    """Collect the test suite from the sandbox.

    ``process_instance`` calls this after the agent run to read the contents
    of ``instance.test_file`` from the workspace. If anything goes wrong the
    traceback is printed and ``instance['full_pred']`` (or an empty string
    when that is ``None``) is returned instead.

    If you need to do something in the sandbox to get the correctness metric
    after the agent has run, modify this function.

    Returns:
        A dict with a single ``'test_suite'`` key holding the file contents.
    """

    def execute(command):
        # Run one command with a 600s hard timeout, logging action and result.
        act = CmdRunAction(command=command)
        act.set_hard_timeout(600)
        logger.info(act, extra={'msg_type': 'ACTION'})
        result = runtime.run_action(act)
        logger.info(result, extra={'msg_type': 'OBSERVATION'})
        return result

    separator = '-' * 30
    try:
        logger.info(separator)
        logger.info('BEGIN Runtime Completion Fn')
        logger.info(separator)
        workspace = _get_swebench_workspace_dir_name(instance)

        cd_obs = execute(f'cd /workspace/{workspace}')
        assert_and_raise(
            cd_obs.exit_code == 0,
            f'Failed to cd to /workspace/{workspace}: {str(cd_obs)}',
        )

        cat_obs = execute(f'cat {instance.test_file}')
        assert_and_raise(
            cat_obs.exit_code == 0,
            f'Failed to find file: {instance.test_file} in /workspace/{workspace}',
        )

        test_suite = cat_obs.content.strip()
    except Exception:
        # Best effort: report the failure and fall back to the starting test file.
        print('Skipping, exception in complete_runtime')
        print(traceback.format_exc())
        starting_suite = instance['full_pred']
        if starting_suite is None:
            test_suite = ''
        else:
            test_suite = starting_suite

    # NOTE: changes are not staged here (the `git add -A` step is disabled).

    logger.info(separator)
    logger.info('END Runtime Completion Fn')
    logger.info(separator)
    return {'test_suite': test_suite}
|
||||
|
||||
|
||||
def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
) -> EvalOutput:
    """Run the agent on one instance and package the result.

    Args:
        instance: Dataset row to evaluate.
        metadata: Evaluation metadata (agent class, LLM config, output dir, ...).
        reset_logger: When true, redirect logging to a per-instance file under
            ``<eval_output_dir>/infer_logs`` so workers can run in parallel.

    Returns:
        An ``EvalOutput`` carrying the collected test suite, elapsed time,
        event history and metrics.

    Raises:
        ValueError: If the controller returns no state.
        EvalException: If the run ended with a fatal evaluation error.
    """
    config = get_config(instance, metadata)
    start_time = time.time()  # Track start time

    # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance.id, log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance.id}.')

    runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)

    try:
        initialize_runtime(runtime, instance)

        instruction = get_instruction(instance, metadata)

        # Here's how you can run the agent (similar to the `main` function) and get the final task state
        state: State | None = asyncio.run(
            run_controller(
                config=config,
                initial_user_action=MessageAction(content=instruction),
                runtime=runtime,
                fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                    metadata.agent_class
                ],
            )
        )

        # Validate the state before it is dereferenced below; previously a
        # missing state surfaced as an AttributeError on `state.last_error`.
        if state is None:
            raise ValueError('State should not be None.')

        # if fatal error, throw EvalError to trigger re-run
        if is_fatal_evaluation_error(state.last_error):
            raise EvalException('Fatal error detected: ' + state.last_error)

        # ======= THIS IS SWE-Bench specific =======
        return_val = complete_runtime(runtime, instance)
        test_suite = return_val['test_suite']
        logger.info(
            f'Got test suite for instance {instance.instance_id}:\n--------\n{test_suite}\n--------'
        )
    finally:
        # Always release the runtime, whether or not the run succeeded.
        runtime.close()

    end_time = time.time()
    elapsed_time = end_time - start_time
    logger.info(
        f'Evaluation for instance {instance.instance_id} took {elapsed_time:.2f} seconds.'
    )

    # ==========================================

    # ======= Attempt to evaluate the agent's edits =======
    # we use eval_infer.sh to evaluate the agent's edits, not here
    # because the agent may alter the environment / testcases
    test_result = {
        'test_suite': test_suite,
        'elapsed_time': elapsed_time,
    }

    # If you are working on some simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
    # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
    histories = [event_to_dict(event) for event in state.history]
    metrics = get_metrics(state)

    # Save the output
    output = EvalOutput(
        instance_id=instance.id,
        instruction=instruction,
        instance=_preprocess_instance(instance.to_dict()),  # SWE Bench specific
        test_result=test_result,
        metadata=metadata,
        history=histories,
        metrics=metrics,
        error=state.last_error if state.last_error else None,
    )
    return output
|
||||
|
||||
|
||||
def prepare_dataset_pre(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
    """Optionally filter the dataset and normalise its id columns.

    If a ``config.toml`` next to this file defines ``selected_ids``, only rows
    whose ``filter_column`` value is in that list are kept. In either case the
    original ``instance_id`` is preserved as ``instance_id_swebench`` and
    ``instance_id`` is replaced by the ``id`` column.

    Args:
        dataset: Frame with at least ``id`` and ``instance_id`` columns.
        filter_column: Column matched against ``selected_ids``.

    Returns:
        The filtered copy when ``selected_ids`` is configured; otherwise the
        input frame itself, modified in place.
    """
    file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.toml')
    selected_ids = None
    if os.path.exists(file_path):
        with open(file_path, 'r') as file:
            data = toml.load(file)
        if 'selected_ids' in data:
            selected_ids = data['selected_ids']

    if selected_ids is not None:
        logger.info(f'Filtering {len(selected_ids)} tasks from "selected_ids"...')
        # Copy the selection: assigning columns on a filtered slice is
        # ambiguous in pandas (SettingWithCopyWarning) and may not stick.
        dataset = dataset[dataset[filter_column].isin(selected_ids)].copy()
        logger.info(f'Retained {dataset.shape[0]} tasks after filtering')

    dataset['instance_id_swebench'] = dataset['instance_id']
    dataset['instance_id'] = dataset['id']
    return dataset
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Entry point: parse CLI options, load the dataset and (optionally) the
    # zero-shot predictions, then run the evaluation over all instances.
    parser = get_parser()
    parser.add_argument(
        '--dataset',
        type=str,
        default='kjain/testgenevallite',
        help='data set to evaluate on, either full-test or lite-test',
    )
    parser.add_argument(
        '--split',
        type=str,
        default='test',
        help='split to evaluate on',
    )
    parser.add_argument(
        '--testfile_start',
        action='store_true',
        help='Whether to start from the 0 shot test file',
    )
    parser.add_argument(
        '--zero_shot_path',
        type=str,
        help='Path to the zero shot test file predictions',
    )
    args, _ = parser.parse_known_args()

    if args.testfile_start and not args.zero_shot_path:
        raise ValueError(
            'If you want to start from the 0 shot test file, you must provide the path to the zero shot test file predictions'
        )

    # Map each instance id to its first full zero-shot prediction (JSONL input).
    preds_map = {}
    if args.testfile_start:
        with open(args.zero_shot_path, 'r') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline) in the file.
                    continue
                pred = json.loads(line)
                preds_map[pred['id']] = pred['preds']['full'][0]

    # NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
    # so we don't need to manage file uploading to OpenHands's repo
    dataset = load_dataset(args.dataset, split=args.split)
    logger.info(f'Loaded dataset {args.dataset} with split {args.split}')
    testgeneval_filepairs = prepare_dataset_pre(dataset.to_pandas(), 'id')

    # Resolve the LLM config and validate it BEFORE touching its attributes;
    # otherwise a failed lookup would raise AttributeError instead of the
    # explicit error below. (Assumes the lookup returns None when the config
    # is not found, as the error message suggests.)
    llm_config = get_llm_config_arg(args.llm_config) if args.llm_config else None
    if llm_config is None:
        raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')

    llm_config.log_completions = True
    # modify_params must be False for evaluation purpose, for reproducibility and accuracy of results
    llm_config.modify_params = False

    details = {}
    _agent_cls = openhands.agenthub.Agent.get_cls(args.agent_cls)

    dataset_description = (
        args.dataset.replace('/', '__') + '-' + args.split.replace('/', '__')
    )
    metadata = make_metadata(
        llm_config,
        dataset_description,
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
        details=details,
    )

    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    instances = prepare_dataset(testgeneval_filepairs, output_file, args.eval_n_limit)

    if not instances.empty:
        # Attach the starting prediction per instance; ids without one get None.
        instances['full_pred'] = (
            instances['instance_id']
            .map(preds_map)
            .apply(lambda x: x if pd.notna(x) else None)
        )

    run_evaluation(
        instances, metadata, output_file, args.eval_num_workers, process_instance
    )
|
||||
@@ -0,0 +1,128 @@
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
# Function to run shell commands
|
||||
def run_command(command):
    """Run ``command`` through the shell, reporting a non-zero exit status.

    A failing command does not raise; an error line is printed instead and
    the function returns ``None`` either way.
    """
    completed = subprocess.run(command, shell=True)
    try:
        # Raises the same CalledProcessError that `check=True` would.
        completed.check_returncode()
    except subprocess.CalledProcessError as failure:
        print(f'An error occurred: {failure}')
|
||||
|
||||
|
||||
# Function to log in to Docker Hub
|
||||
def docker_login():
    """Announce the login and run ``docker login`` via ``run_command``."""
    banner = 'Logging into Docker Hub...'
    print(banner)
    login_command = 'docker login'
    run_command(login_command)
|
||||
|
||||
|
||||
# Function to generate Dockerfile content based on image type
|
||||
def generate_dockerfile_content(
    base_image, dependencies, datum, patch_path, test_patch_path
):
    """Return the Dockerfile text that derives a TestGenEval image.

    The generated Dockerfile starts from ``base_image``, installs
    ``dependencies`` with pip in the ``testbed`` conda environment, applies
    the patch at ``patch_path`` and then the one at ``test_patch_path``,
    configures a git identity, removes ``datum['test_file']`` and commits
    the result.

    Returns:
        The Dockerfile content as a string beginning with an empty line.
    """
    pip_packages = ' '.join(dependencies)
    test_file = datum['test_file']
    instructions = [
        # The leading empty entry reproduces the blank first line.
        '',
        f'FROM {base_image}',
        'SHELL ["/bin/bash", "-c"]',
        f'RUN source /opt/miniconda3/bin/activate && conda activate testbed && pip install {pip_packages}',
        f'COPY {patch_path} /app/patch.diff',
        'RUN git apply /app/patch.diff',
        'RUN rm /app/patch.diff',
        f'COPY {test_patch_path} /app/patch.diff',
        'RUN git apply /app/patch.diff',
        'RUN git config --global user.email ""',
        'RUN git config --global user.name "TestGenEval"',
        'RUN rm /app/patch.diff',
        f'RUN rm {test_file}',
        # Commit the patched tree so the image starts from a clean state.
        'RUN git add .',
        'RUN git commit -m "Testing fixes"',
    ]
    return '\n'.join(instructions)
|
||||
|
||||
|
||||
# Function to build, push, and clean up Docker images
|
||||
def build_and_push_image(dockerfile_content, image_name):
    """Build ``image_name`` from the given Dockerfile text, push it, clean up.

    The content is written to ``Dockerfile.temp`` in the current directory,
    the image is built, pushed and removed locally, and the temporary
    Dockerfile is deleted afterwards.
    """
    temp_dockerfile = 'Dockerfile.temp'
    with open(temp_dockerfile, 'w') as handle:
        handle.write(dockerfile_content)

    docker_steps = (
        f'docker build -f Dockerfile.temp -t {image_name} .',
        f'docker push {image_name}',
        f'docker rmi {image_name}',
    )
    for step in docker_steps:
        run_command(step)

    os.remove(temp_dockerfile)
|
||||
|
||||
|
||||
# Function to process images with .eval in the name
|
||||
def process_images(dataset, original_namespace, new_namespace, start_instance_id):
    """Rebuild and push a derived image for every instance in ``dataset``.

    For each processed instance the original image is pulled, the instance's
    ``patch`` and ``test_patch`` are written to local files, a derived image
    is built and pushed under ``new_namespace``, and local files and images
    are cleaned up.

    Args:
        dataset: Iterable of records with ``instance_id``, ``patch``,
            ``test_patch`` and ``test_file`` entries.
        original_namespace: Docker Hub namespace to pull the base images from.
        new_namespace: Docker Hub namespace to push the derived images to.
        start_instance_id: Instance id to start processing from (inclusive).
            An empty value processes the whole dataset.
    """
    dependencies = ['coverage', 'cosmic-ray']
    patch_file_path = 'patch.diff'
    test_patch_file_path = 'test_patch.diff'

    found_start = not start_instance_id
    for datum in dataset:
        if not found_start:
            if datum['instance_id'] != start_instance_id:
                continue
            # Processing starts FROM this instance, so fall through and
            # process it instead of only flipping the flag.
            found_start = True

        # Image names use "_s_" where the instance id has "__".
        image_tag = datum['instance_id'].replace('__', '_s_')
        full_image_name = f'{original_namespace}/sweb.eval.x86_64.{image_tag}:latest'
        new_image_name = f'{new_namespace}/sweb.eval.x86_64.{image_tag}:latest'

        print(f'Processing image: {full_image_name}')
        run_command(f'docker pull {full_image_name}')

        try:
            # Save the patches to regular files so the Dockerfile can COPY them.
            with open(patch_file_path, 'w') as patch_file, open(
                test_patch_file_path, 'w'
            ) as test_patch_file:
                patch_file.write(datum['patch'])
                test_patch_file.write(datum['test_patch'])

            dockerfile_content = generate_dockerfile_content(
                full_image_name,
                dependencies,
                datum,
                patch_file_path,
                test_patch_file_path,
            )
            build_and_push_image(dockerfile_content, new_image_name)
        finally:
            # Remove the patch files even if the build step raised.
            for leftover in (patch_file_path, test_patch_file_path):
                if os.path.exists(leftover):
                    os.remove(leftover)

        # Cleanup images
        run_command(f'docker rmi {full_image_name}')
        run_command('docker system prune -f')  # Clean up dangling resources
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: load the dataset split, log in to Docker Hub
    # and rebuild/push one image per instance.
    cli = argparse.ArgumentParser(
        description='Process Docker images with .eval in the name.'
    )
    cli.add_argument('--dataset', type=str, default='kjain14/testgeneval')
    cli.add_argument('--split', type=str, default='test')
    cli.add_argument(
        '--new_namespace',
        type=str,
        default='kdjain',
        help='The new Docker Hub namespace to push the images',
    )
    cli.add_argument(
        '--original_namespace',
        type=str,
        default='xingyaoww',
        help='The original Docker Hub namespace',
    )
    cli.add_argument(
        '--start_instance_id',
        type=str,
        default='',
        help='The instance_id to start processing from',
    )
    args = cli.parse_args()

    all_splits = load_dataset(args.dataset)
    dataset = all_splits[args.split]

    docker_login()
    process_images(
        dataset,
        args.original_namespace,
        args.new_namespace,
        args.start_instance_id,
    )
|
||||
+1274
File diff suppressed because it is too large
Load Diff
+196
@@ -0,0 +1,196 @@
|
||||
sweb.base.x86_64:latest
|
||||
sweb.env.x86_64.088a7e628bda9770f9757b:latest
|
||||
sweb.env.x86_64.0d80c7dec81ee2f2f513e2:latest
|
||||
sweb.env.x86_64.0f99bce2750f3109957bec:latest
|
||||
sweb.env.x86_64.1b3b218535da0abf4469cb:latest
|
||||
sweb.env.x86_64.1c1a6945f732f9391228c5:latest
|
||||
sweb.env.x86_64.1f92e6d7cef88badc4f744:latest
|
||||
sweb.env.x86_64.27dd9791e13f5c857a09f9:latest
|
||||
sweb.env.x86_64.297af196949a2a635bce66:latest
|
||||
sweb.env.x86_64.2baaea72acc974f6c02079:latest
|
||||
sweb.env.x86_64.2e50125951bc69cddd7421:latest
|
||||
sweb.env.x86_64.2f217c8b4490bfa0e2ba14:latest
|
||||
sweb.env.x86_64.31244378a92e3bcce809ac:latest
|
||||
sweb.env.x86_64.428468730904ff6b4232aa:latest
|
||||
sweb.env.x86_64.5d1fda9d55d65d8a4e5bdb:latest
|
||||
sweb.env.x86_64.6b007979cf533f0f3016e8:latest
|
||||
sweb.env.x86_64.7037e8c448a4b8ebfe9b13:latest
|
||||
sweb.env.x86_64.71498c7426dbf05599642f:latest
|
||||
sweb.env.x86_64.756beac07713d7e8dc1129:latest
|
||||
sweb.env.x86_64.78278ae2cf880e395f1337:latest
|
||||
sweb.env.x86_64.8f1f7b974f0c57c7aeba39:latest
|
||||
sweb.env.x86_64.934a137824256b612e9dc5:latest
|
||||
sweb.env.x86_64.a0efca7a0fe6719dbf65c2:latest
|
||||
sweb.env.x86_64.a18371b03f944585b4f08c:latest
|
||||
sweb.env.x86_64.a33dddf55cdff5d8e23374:latest
|
||||
sweb.env.x86_64.aa92880033da20ca313928:latest
|
||||
sweb.env.x86_64.b649f0ff62fad147f7f073:latest
|
||||
sweb.env.x86_64.b7ce4be3b3c35f68c61248:latest
|
||||
sweb.env.x86_64.c70909fdac4897d1c685df:latest
|
||||
sweb.env.x86_64.c795f4b88616b8462021ed:latest
|
||||
sweb.env.x86_64.cc47cc71483942d0c3a15e:latest
|
||||
sweb.env.x86_64.dc5ff4c0e3fe8db5afc4da:latest
|
||||
sweb.env.x86_64.e3afd7f04b325a4de4982d:latest
|
||||
sweb.env.x86_64.e5bb89bf78258a7d14c34b:latest
|
||||
sweb.env.x86_64.e83e37f52c09532c62acfb:latest
|
||||
sweb.env.x86_64.efa6065ed5bf204410fd53:latest
|
||||
sweb.eval.x86_64.django_s_django-17087:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-10508:latest
|
||||
sweb.eval.x86_64.django_s_django-14017:latest
|
||||
sweb.eval.x86_64.django_s_django-11422:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-14774:latest
|
||||
sweb.eval.x86_64.django_s_django-14915:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-22005:latest
|
||||
sweb.eval.x86_64.pytest-dev_s_pytest-5221:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-17022:latest
|
||||
sweb.eval.x86_64.django_s_django-15996:latest
|
||||
sweb.eval.x86_64.django_s_django-15252:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-21171:latest
|
||||
sweb.eval.x86_64.django_s_django-11797:latest
|
||||
sweb.eval.x86_64.django_s_django-16046:latest
|
||||
sweb.eval.x86_64.django_s_django-11583:latest
|
||||
sweb.eval.x86_64.django_s_django-15738:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-21612:latest
|
||||
sweb.eval.x86_64.astropy_s_astropy-12907:latest
|
||||
sweb.eval.x86_64.django_s_django-11620:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-16792:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13779:latest
|
||||
sweb.eval.x86_64.django_s_django-16041:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-13471:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-20442:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-20049:latest
|
||||
sweb.eval.x86_64.django_s_django-14411:latest
|
||||
sweb.eval.x86_64.django_s_django-13447:latest
|
||||
sweb.eval.x86_64.django_s_django-12856:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-10949:latest
|
||||
sweb.eval.x86_64.django_s_django-14787:latest
|
||||
sweb.eval.x86_64.django_s_django-11815:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13584:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-14087:latest
|
||||
sweb.eval.x86_64.django_s_django-15388:latest
|
||||
sweb.eval.x86_64.django_s_django-11179:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-24102:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-24213:latest
|
||||
sweb.eval.x86_64.django_s_django-15781:latest
|
||||
sweb.eval.x86_64.pytest-dev_s_pytest-8906:latest
|
||||
sweb.eval.x86_64.django_s_django-13710:latest
|
||||
sweb.eval.x86_64.django_s_django-13925:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-14092:latest
|
||||
sweb.eval.x86_64.pytest-dev_s_pytest-7373:latest
|
||||
sweb.eval.x86_64.matplotlib_s_matplotlib-25498:latest
|
||||
sweb.eval.x86_64.pytest-dev_s_pytest-5227:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-15678:latest
|
||||
sweb.eval.x86_64.django_s_django-13551:latest
|
||||
sweb.eval.x86_64.django_s_django-14155:latest
|
||||
sweb.eval.x86_64.django_s_django-13933:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-21055:latest
|
||||
sweb.eval.x86_64.django_s_django-13660:latest
|
||||
sweb.eval.x86_64.django_s_django-16527:latest
|
||||
sweb.eval.x86_64.pytest-dev_s_pytest-5692:latest
|
||||
sweb.eval.x86_64.mwaskom_s_seaborn-3010:latest
|
||||
sweb.eval.x86_64.django_s_django-12700:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-11400:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-23117:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-20639:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-23262:latest
|
||||
sweb.eval.x86_64.django_s_django-15498:latest
|
||||
sweb.eval.x86_64.django_s_django-12453:latest
|
||||
sweb.eval.x86_64.django_s_django-14999:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-13480:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-21847:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-15011:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-25570:latest
|
||||
sweb.eval.x86_64.sphinx-doc_s_sphinx-7975:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-14983:latest
|
||||
sweb.eval.x86_64.django_s_django-14534:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-14396:latest
|
||||
sweb.eval.x86_64.matplotlib_s_matplotlib-25442:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-15535:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-22714:latest
|
||||
sweb.eval.x86_64.django_s_django-15789:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-21627:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-24066:latest
|
||||
sweb.eval.x86_64.pylint-dev_s_pylint-7993:latest
|
||||
sweb.eval.x86_64.django_s_django-14752:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-18835:latest
|
||||
sweb.eval.x86_64.django_s_django-17051:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-12171:latest
|
||||
sweb.eval.x86_64.pydata_s_xarray-3364:latest
|
||||
sweb.eval.x86_64.mwaskom_s_seaborn-3190:latest
|
||||
sweb.eval.x86_64.pytest-dev_s_pytest-7168:latest
|
||||
sweb.eval.x86_64.django_s_django-12747:latest
|
||||
sweb.eval.x86_64.django_s_django-15695:latest
|
||||
sweb.eval.x86_64.matplotlib_s_matplotlib-22835:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-12481:latest
|
||||
sweb.eval.x86_64.django_s_django-15851:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-14024:latest
|
||||
sweb.eval.x86_64.django_s_django-14608:latest
|
||||
sweb.eval.x86_64.pytest-dev_s_pytest-9359:latest
|
||||
sweb.eval.x86_64.django_s_django-16873:latest
|
||||
sweb.eval.x86_64.matplotlib_s_matplotlib-25433:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-13031:latest
|
||||
sweb.eval.x86_64.pytest-dev_s_pytest-7432:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-25747:latest
|
||||
sweb.eval.x86_64.django_s_django-12286:latest
|
||||
sweb.eval.x86_64.django_s_django-11910:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-12471:latest
|
||||
sweb.eval.x86_64.pylint-dev_s_pylint-5859:latest
|
||||
sweb.eval.x86_64.django_s_django-11133:latest
|
||||
sweb.eval.x86_64.astropy_s_astropy-14365:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13496:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-19487:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-13895:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-15345:latest
|
||||
sweb.eval.x86_64.django_s_django-13590:latest
|
||||
sweb.eval.x86_64.django_s_django-13757:latest
|
||||
sweb.eval.x86_64.django_s_django-16379:latest
|
||||
sweb.eval.x86_64.django_s_django-13768:latest
|
||||
sweb.eval.x86_64.pytest-dev_s_pytest-8365:latest
|
||||
sweb.eval.x86_64.django_s_django-14580:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-20154:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-12419:latest
|
||||
sweb.eval.x86_64.django_s_django-12125:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-24152:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-15512:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-18621:latest
|
||||
sweb.eval.x86_64.pydata_s_xarray-4248:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-11040:latest
|
||||
sweb.eval.x86_64.django_s_django-11099:latest
|
||||
sweb.eval.x86_64.django_s_django-16816:latest
|
||||
sweb.eval.x86_64.django_s_django-13265:latest
|
||||
sweb.eval.x86_64.django_s_django-16139:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-10297:latest
|
||||
sweb.eval.x86_64.django_s_django-14016:latest
|
||||
sweb.eval.x86_64.pallets_s_flask-5063:latest
|
||||
sweb.eval.x86_64.astropy_s_astropy-7746:latest
|
||||
sweb.eval.x86_64.matplotlib_s_matplotlib-24265:latest
|
||||
sweb.eval.x86_64.django_s_django-13448:latest
|
||||
sweb.eval.x86_64.django_s_django-12908:latest
|
||||
sweb.eval.x86_64.sphinx-doc_s_sphinx-8627:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-14317:latest
|
||||
sweb.eval.x86_64.pytest-dev_s_pytest-6116:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-23191:latest
|
||||
sweb.eval.x86_64.pydata_s_xarray-5131:latest
|
||||
sweb.eval.x86_64.django_s_django-11019:latest
|
||||
sweb.eval.x86_64.matplotlib_s_matplotlib-23913:latest
|
||||
sweb.eval.x86_64.django_s_django-15790:latest
|
||||
sweb.eval.x86_64.django_s_django-12497:latest
|
||||
sweb.eval.x86_64.matplotlib_s_matplotlib-26020:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-25638:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-25500:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-19007:latest
|
||||
sweb.eval.x86_64.django_s_django-12308:latest
|
||||
sweb.eval.x86_64.pytest-dev_s_pytest-7220:latest
|
||||
sweb.eval.x86_64.django_s_django-11848:latest
|
||||
sweb.eval.x86_64.django_s_django-15347:latest
|
||||
sweb.eval.x86_64.pytest-dev_s_pytest-7490:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-18532:latest
|
||||
sweb.eval.x86_64.django_s_django-14997:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-24909:latest
|
||||
sweb.eval.x86_64.django_s_django-13220:latest
|
||||
sweb.eval.x86_64.sympy_s_sympy-21614:latest
|
||||
sweb.eval.x86_64.django_s_django-15902:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13497:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13439:latest
|
||||
sweb.eval.x86_64.scikit-learn_s_scikit-learn-14894:latest
|
||||
sweb.eval.x86_64.django_s_django-12983:latest
|
||||
@@ -0,0 +1,31 @@
|
||||
def print_diff_ignore_order(file1, file2):
    """Print the lines of ``file1`` that do not occur in ``file2``.

    Both files are compared as sets of lines, so order and duplicates are
    ignored. Line terminators are removed before comparing, so a last line
    without a trailing newline still matches the same text in the other file.

    Lines that occur only in ``file2`` are not listed, but they do prevent
    the "same content" message from being printed.

    Args:
        file1: Path of the file whose extra lines are reported.
        file2: Path of the file to compare against.
    """
    with open(file1, 'r') as f1, open(file2, 'r') as f2:
        file1_lines = {line.rstrip('\r\n') for line in f1}
        file2_lines = {line.rstrip('\r\n') for line in f2}

    only_in_file1 = file1_lines - file2_lines
    only_in_file2 = file2_lines - file1_lines

    if only_in_file1:
        print(f'Lines in {file1} but not in {file2}:')
        for line in sorted(only_in_file1):
            print(f'- {line.strip()}')

    # Reporting of lines only present in file2 is currently disabled:
    # if only_in_file2:
    #     print(f"Lines in {file2} but not in {file1}:")
    #     for line in sorted(only_in_file2):
    #         print(f"+ {line.strip()}")

    if not only_in_file1 and not only_in_file2:
        print('The files have the same content (ignoring line order).')
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Usage: each pair is (local image list, reference image list from the
    # swe_bench scripts). Replace the paths with your own files as needed.
    file_pairs = (
        (
            'all-swebench-lite-instance-images.txt',
            '../../swe_bench/scripts/docker/all-swebench-lite-instance-images.txt',
        ),
        (
            'all-swebench-full-instance-images.txt',
            '../../swe_bench/scripts/docker/all-swebench-full-instance-images.txt',
        ),
    )
    for local_list, reference_list in file_pairs:
        print_diff_ignore_order(local_list, reference_list)
|
||||
@@ -0,0 +1,48 @@
|
||||
#!/bin/bash
# Script will delete all repositories and tags in your Docker Hub account.
# Usage: <script> <username> <password>
set -e

# Set username and password from command-line arguments
UNAME=$1
UPASS=$2

if [[ -z "$UNAME" || -z "$UPASS" ]]; then
    echo "Usage: $0 <username> <password>" >&2
    exit 1
fi

# Get token to interact with Docker Hub.
# The JSON body is built with jq so quotes, spaces or backslashes in the
# credentials are escaped instead of breaking the payload.
LOGIN_PAYLOAD=$(jq -n --arg username "$UNAME" --arg password "$UPASS" '{username: $username, password: $password}')
# '.token // empty' prints nothing when the field is missing or null, so the
# emptiness check below actually detects a failed login (plain '.token' prints "null").
TOKEN=$(curl -s -H "Content-Type: application/json" -X POST -d "$LOGIN_PAYLOAD" https://hub.docker.com/v2/users/login/ | jq -r '.token // empty')

# Ensure token retrieval was successful
if [[ -z "$TOKEN" ]]; then
    echo "Failed to obtain authentication token. Please check your credentials."
    exit 1
fi

# Get list of repositories for that user account
echo "Listing repositories in Docker Hub account '${UNAME}':"
REPO_LIST=$(curl -s -H "Authorization: JWT ${TOKEN}" "https://hub.docker.com/v2/repositories/${UNAME}/?page_size=10000" | jq -r '.results // [] | .[] | .name')
if [[ -z "$REPO_LIST" ]]; then
    echo "No repositories found for user '${UNAME}' or failed to fetch repositories."
    exit 1
fi

# Loop through each repository and delete its tags and the repository itself
for rep in ${REPO_LIST}; do
    echo "Processing repository: ${UNAME}/${rep}"

    # Get all tags for the repository
    IMAGES=$(curl -s -H "Authorization: JWT ${TOKEN}" "https://hub.docker.com/v2/repositories/${UNAME}/${rep}/tags/?page_size=100")
    IMAGE_TAGS=$(echo "$IMAGES" | jq -r '.results // [] | .[] | .name')

    # Delete each tag
    for tag in ${IMAGE_TAGS}; do
        echo "Deleting tag: ${UNAME}/${rep}:${tag}"
        # -f makes curl return non-zero on HTTP errors; a failed tag delete is
        # reported but does not abort the run (set -e is active).
        curl -s -f -X DELETE -H "Authorization: JWT ${TOKEN}" "https://hub.docker.com/v2/repositories/${UNAME}/${rep}/tags/${tag}/" || {
            echo "Failed to delete tag '${UNAME}/${rep}:${tag}'."
        }
    done

    # Delete the repository itself
    echo "Deleting repository: ${UNAME}/${rep}"
    # Without -f curl exits 0 on HTTP errors and this failure branch never ran.
    curl -s -f -X DELETE -H "Authorization: JWT ${TOKEN}" "https://hub.docker.com/v2/repositories/${UNAME}/${rep}/" || {
        echo "Failed to delete repository '${UNAME}/${rep}'. Please check permissions or API limits."
    }
    sleep 1
done

echo "Script execution completed."
|
||||
@@ -0,0 +1,18 @@
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def dataset_to_txt(dataset, txt_file, split='test'):
    """Write one image name per instance of ``dataset[split]`` to ``txt_file``.

    Every ``instance_id`` has ``'__'`` replaced by ``'_s_'`` and is written on
    its own line as ``sweb.eval.x86_64.<id>:latest``.

    :param dataset: Mapping from split name to an iterable of rows with an
        ``instance_id`` key.
    :param txt_file: Path of the text file to (over)write.
    :param split: Name of the split to export.
    """
    with open(txt_file, 'w') as out:
        out.writelines(
            'sweb.eval.x86_64.' + row['instance_id'].replace('__', '_s_') + ':latest\n'
            for row in dataset[split]
        )
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Load the private dataset
    dataset = load_dataset('kjain14/testgeneval')

    dataset_lite = load_dataset('kjain14/testgenevallite')

    # dataset_to_txt() only takes (dataset, txt_file, split); the `lite=True`
    # keyword that used to be passed here raised TypeError before anything was written.
    dataset_to_txt(dataset_lite, 'all-swebench-lite-instance-images.txt')
    dataset_to_txt(dataset, 'all-swebench-full-instance-images.txt')
|
||||
@@ -0,0 +1,173 @@
|
||||
import argparse
|
||||
import copy
|
||||
import difflib
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
|
||||
|
||||
def insert_line_in_string(input_string, new_str, insert_line):
    """
    Insert text into a string after a given number of lines.

    Tabs in both strings are expanded to spaces first. The first
    ``insert_line`` lines of the original are kept, the lines of ``new_str``
    follow, then the remaining original lines; ``insert_line=0`` therefore
    puts the new text at the very top.

    :param input_string: The original string.
    :param new_str: The string to insert.
    :param insert_line: Number of original lines that precede the inserted text.
    :return: The modified string.
    """
    existing_lines = input_string.expandtabs().split('\n')
    added_lines = new_str.expandtabs().split('\n')

    head = existing_lines[:insert_line]
    tail = existing_lines[insert_line:]
    return '\n'.join([*head, *added_lines, *tail])
|
||||
|
||||
|
||||
def print_string_diff(original, modified):
    """
    Prints the differences between two strings as a unified diff.

    The strings are split into lines without their terminators and the diff
    lines are joined with newlines, so the ``---``/``+++``/``@@`` header lines
    each appear on their own line.

    :param original: The original string.
    :param modified: The modified string.
    """
    original_lines = original.splitlines()
    modified_lines = modified.splitlines()

    diff = difflib.unified_diff(
        original_lines,
        modified_lines,
        fromfile='original',
        tofile='modified',
        lineterm='',
    )

    # With lineterm='' no diff line carries a terminator, so join with '\n'
    # (the previous ''.join glued the header lines together).
    print('\n'.join(diff))
|
||||
|
||||
|
||||
def _parse_tool_arguments(raw_arguments):
    """Decode the ``arguments`` string of a tool call into a Python object."""
    # NOTE(review): the original called eval() on this string, which executes
    # whatever the stored response contains and fails on JSON literals such as
    # true/false/null. It is parsed as data instead.
    try:
        return json.loads(raw_arguments)
    except ValueError:
        import ast  # local import: only needed for Python-literal (non-JSON) payloads

        return ast.literal_eval(raw_arguments)


def _apply_tool_call(test_suite, tool_call_dict):
    """Return ``test_suite`` after applying one editor tool call to it."""
    command = tool_call_dict['command']
    if command == 'create':
        test_suite = tool_call_dict['file_text']
    if command != 'str_replace' and command != 'insert' and 'coverage' not in command:
        print(command)
    if command == 'insert':
        test_suite = insert_line_in_string(
            test_suite,
            tool_call_dict['new_str'],
            tool_call_dict['insert_line'],
        )
    if command == 'str_replace':
        # Only an unambiguous (exactly one occurrence) replacement is applied.
        if test_suite.count(tool_call_dict['old_str']) == 1:
            test_suite = test_suite.replace(
                tool_call_dict['old_str'],
                tool_call_dict['new_str'],
            )
    return test_suite


def parse_json_files(root_dir, output_dir, metadata_objs, preds_objs):
    """Replay stored tool calls per instance and write 25 JSONL snapshot files.

    Every subdirectory of ``root_dir`` is treated as one instance. Its ``.json``
    files are processed in sorted order; the ``create`` / ``insert`` /
    ``str_replace`` tool calls found in each file are applied to a running test
    suite, and the state after file ``i`` is stored in ``output_<i>.jsonl``.
    Slots after the last processed file (up to 23) repeat the last metadata
    copy, and ``output_24.jsonl`` always holds the original test suite.

    :param root_dir: Directory containing one subdirectory per instance.
    :param output_dir: Directory that receives ``output_0.jsonl`` .. ``output_24.jsonl``.
    :param metadata_objs: Mapping from subdirectory name to the metadata dict of
        that instance (must contain ``test_result.test_suite``).
    :param preds_objs: Mapping from subdirectory name to the starting test suite;
        instances that are absent start from an empty string.
    :raises KeyError: If a subdirectory has no entry in ``metadata_objs``.
    """
    final_output = {i: [] for i in range(25)}

    for subdir in sorted(os.listdir(root_dir)):  # Sorting ensures consistent order
        subdir_path = os.path.join(root_dir, subdir)
        if not os.path.isdir(subdir_path):
            # Stray files have no metadata entry; skip them before the lookup.
            continue

        # subdir_instance = subdir.rsplit('-', 1)[0]
        metadata = metadata_objs[subdir]
        orig_test_suite = metadata['test_result']['test_suite']
        print(f'Processing subdirectory: {subdir}')

        # Now loop through the JSON files in this subdirectory
        i = 0
        test_suite = preds_objs[subdir] if subdir in preds_objs else ''
        # Defined up front so the padding below never reuses the copy left over
        # from a previous subdirectory when this one contains no files.
        metadata_copy = copy.deepcopy(metadata)
        for file in sorted(os.listdir(subdir_path)):  # Sorting ensures consistent order
            metadata_copy = copy.deepcopy(metadata)
            if not file.endswith('.json'):  # Check for JSON files
                continue
            file_path = os.path.join(subdir_path, file)
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)  # Load JSON data
                try:
                    tool_calls = data['response']['choices'][0]['message']['tool_calls']
                    if tool_calls is not None:
                        for tool_call in tool_calls:
                            tool_call_dict = _parse_tool_arguments(
                                tool_call['function']['arguments']
                            )
                            if tool_call_dict is not None and tool_call_dict != {}:
                                test_suite = _apply_tool_call(test_suite, tool_call_dict)
                except Exception:
                    # A malformed response is logged and produces no snapshot.
                    print(traceback.format_exc())
                    continue

                metadata_copy['test_result']['test_suite'] = test_suite
                if i < 25:
                    final_output[i].append(metadata_copy)
                i += 1
            except Exception as e:
                print(traceback.format_exc())
                print(f' Error loading {file_path}: {e}')

        # Pad the unused snapshot slots with the last copy; slot 24 is reserved
        # for the original test suite.
        for j in range(i, 24):
            final_output[j].append(metadata_copy)
        metadata_orig = copy.deepcopy(metadata)
        metadata_orig['test_result']['test_suite'] = orig_test_suite
        final_output[24].append(metadata_orig)

    for i in range(25):
        output_file = os.path.join(output_dir, f'output_{i}.jsonl')
        with open(output_file, 'w') as f:
            for metadata in final_output[i]:
                f.write(json.dumps(metadata) + '\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parse JSON file')
    parser.add_argument('--root_dir', type=str, help='Root directory', required=True)
    parser.add_argument(
        '--output_dir', type=str, help='Output directory', required=True
    )
    parser.add_argument(
        '--starting_preds_file', type=str, help='Starting predictions', default=None
    )
    args = parser.parse_args()

    # Index the records of <output_dir>/output.jsonl by instance id.
    output_file = os.path.join(args.output_dir, 'output.jsonl')
    with open(output_file, 'r') as f:
        content = f.readlines()
    metadata_objs = {}
    for line in content:
        metadata = json.loads(line)
        metadata_objs[metadata['instance_id']] = metadata

    # Optional starting test suites: the first 'full' prediction of each record.
    starting_preds_file = args.starting_preds_file
    preds_objs = {}
    if starting_preds_file is not None:
        with open(starting_preds_file, 'r') as f:
            content = f.readlines()
        for line in content:
            pred = json.loads(line)
            preds_objs[pred['id']] = pred['preds']['full'][0]

    parse_json_files(args.root_dir, args.output_dir, metadata_objs, preds_objs)
|
||||
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
|
||||
import pandas as pd
|
||||
|
||||
parser = argparse.ArgumentParser(
    description='Compare two TestGenEval output JSONL files and print the resolved diff'
)
for positional in ('input_file_1', 'input_file_2'):
    parser.add_argument(positional, type=str)
args = parser.parse_args()

# Each input is a JSON-lines file with one record per line.
df1 = pd.read_json(args.input_file_1, orient='records', lines=True)
df2 = pd.read_json(args.input_file_2, orient='records', lines=True)

# Get the intersection of the ids; columns present in both files get the
# _x (file 1) and _y (file 2) suffixes.
df = df1.merge(df2, on='id', how='inner')
|
||||
|
||||
|
||||
def _get_coverage(report):
|
||||
if report is None:
|
||||
return False
|
||||
if isinstance(report, float):
|
||||
return False
|
||||
else:
|
||||
return report.get('test_pass', False)
|
||||
|
||||
|
||||
# Reduce both columns to the test_pass flag, then mark rows where the runs disagree.
for column in ('test_pass_x', 'test_pass_y'):
    df[column] = df[column].apply(_get_coverage)
df['diff'] = df.apply(lambda x: x['test_pass_x'] != x['test_pass_y'], axis=1)

df_diff = df[df['diff']].sort_values(
    by=['test_pass_x', 'test_pass_y'], ascending=[False, False]
)
# skip if any of the pass is nan, which means one of the eval is not finished yet
both_present = df_diff['test_pass_x'].notna() & df_diff['test_pass_y'].notna()
df_diff = df_diff[both_present]

rule = '-' * 100

print(f'X={args.input_file_1}')
print(f'Y={args.input_file_2}')
print(f'# diff={len(df_diff)}')
df_diff = df_diff[['id', 'test_pass_x', 'test_pass_y', 'report_x', 'report_y']]

# x pass but y not
print(rule)
x_only_mask = df_diff['test_pass_x'] & ~df_diff['test_pass_y']
df_diff_x_only = df_diff[x_only_mask].sort_values(by='id')
print(f'# x pass but y not={len(df_diff_x_only)}')
print(df_diff_x_only[['id', 'report_x', 'report_y']])

# y pass but x not
print(rule)
y_only_mask = ~df_diff['test_pass_x'] & df_diff['test_pass_y']
df_diff_y_only = df_diff[y_only_mask].sort_values(by='id')
print(f'# y pass but x not={len(df_diff_y_only)}')
print(df_diff_y_only[['id', 'report_x', 'report_y']])

# list the ids of both groups
print(rule)
print('Instances that x pass but y not:')
print(df_diff_x_only['id'].tolist())

print(rule)
print('Instances that y pass but x not:')
print(df_diff_y_only['id'].tolist())
|
||||
Executable
+28
@@ -0,0 +1,28 @@
|
||||
#!/bin/bash
# Package an evaluation output folder into <folder>.swebench_submission
# containing all_preds.jsonl, trajs/ and logs/.
# Usage: <script> <folder-path>

FOLDER_PATH=$1
if [ -z "$FOLDER_PATH" ]; then
    # Without this check an empty argument would create ".swebench_submission" in the cwd.
    echo "Error: FOLDER_PATH not specified." >&2
    exit 1
fi
NEW_FOLDER_PATH=${FOLDER_PATH}.swebench_submission
mkdir -p "$NEW_FOLDER_PATH"

# Build all_preds.jsonl
poetry run python evaluation/testgeneval/scripts/eval/convert_oh_output_to_swe_json.py "$FOLDER_PATH/output.jsonl"
mv "$FOLDER_PATH/output.swebench.jsonl" "$NEW_FOLDER_PATH/all_preds.jsonl"

# Build trajs/
mkdir -p "$NEW_FOLDER_PATH/trajs"
for instance_dir in "$FOLDER_PATH"/llm_completions/*/; do
    # When nothing matches, the glob stays a literal string; skip it.
    [ -d "$instance_dir" ] || continue
    instance_id=$(basename "$instance_dir")
    # Most recently modified completion file of this instance
    latest_json=$(ls -t "$instance_dir"/*.json 2>/dev/null | head -n1)
    if [ -n "$latest_json" ]; then
        jq -r '.messages' "$latest_json" > "$NEW_FOLDER_PATH/trajs/$instance_id.json"
    fi
done

# Build logs/
# check if $FOLDER_PATH/eval_outputs exists, if so copy over - else raise error
if [ -d "$FOLDER_PATH/eval_outputs" ]; then
    cp -r "$FOLDER_PATH/eval_outputs" "$NEW_FOLDER_PATH/logs"
else
    echo "Error: $FOLDER_PATH/eval_outputs does not exist. You should run the local docker eval_infer.sh first."
    exit 1
fi
|
||||
@@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Convert OpenHands output to a readable markdown format for visualization."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from evaluation.testgeneval.eval_infer import process_test_suite
|
||||
from openhands.events.serialization import event_from_dict
|
||||
|
||||
# Registers DataFrame.progress_apply, which the final statement of this script uses.
tqdm.pandas()

parser = argparse.ArgumentParser()
parser.add_argument('oh_output_file', type=str)
args = parser.parse_args()

# Markdown files go to a folder named after the input, with '.jsonl' replaced by '.viz'.
output_md_folder = args.oh_output_file.replace('.jsonl', '.viz')
print(f'Converting {args.oh_output_file} to markdown files in {output_md_folder}')

oh_format = pd.read_json(args.oh_output_file, orient='records', lines=True)

# model name is the folder name of oh_output_file
containing_folder = os.path.dirname(args.oh_output_file)
model_name = os.path.basename(containing_folder)
|
||||
|
||||
|
||||
def convert_history_to_str(history):
    """Render a trajectory as text, one ``## <n>| <EventClass>`` section per event.

    Each entry of ``history`` is either a single serialized event or a legacy
    ``[action, observation]`` pair; both members of a pair are rendered under
    the same number. Sections are separated by a rule of 100 dashes.
    """
    separator = '\n\n' + '-' * 100 + '\n'

    def render(serialized_event, number):
        event_obj = event_from_dict(serialized_event)
        return f'## {number}| {event_obj.__class__.__name__}\n\n' + str(event_obj)

    sections = []
    for i, event in enumerate(history):
        if isinstance(event, list):
            # "event" is a legacy pair of (action, observation)
            sections.append(render(event[0], i + 1))
            sections.append(render(event[1], i + 1))
        else:
            # "event" is a single event
            sections.append(render(event, i + 1))
    return separator.join(sections)
|
||||
|
||||
|
||||
def write_row_to_md_file(row):
    """Write one output row as ``<id>.md`` inside ``output_md_folder``.

    The file contains the coverage and mutation score (when a report is
    available), the metadata as JSON, the rendered history and the test suite.

    :param row: A row of the output table; needs ``id``, ``metadata``,
        ``history`` and a test suite either as ``test_suite`` or as
        ``test_result['test_suite']``.
    :raises ValueError: If the row carries no test suite.
    """
    if 'test_suite' in row:
        test_suite = row['test_suite']
    elif 'test_result' in row and 'test_suite' in row['test_result']:
        test_suite = row['test_result']['test_suite']
    else:
        raise ValueError(f'Row {row} does not have a test_suite')

    # The report cell may exist without being a dict (e.g. None or NaN for a
    # row that was never evaluated); treat that like a missing report instead
    # of failing on .get().
    report = row['report'] if 'report' in row else None
    if isinstance(report, dict):
        coverage = report.get('coverage', 0)
        mutation = report.get('mutation_score', 0)
    else:
        coverage = None
        mutation = None

    instance_id = row['id']  # not named `id`, to avoid shadowing the builtin
    filename = f'{instance_id}.md'
    os.makedirs(output_md_folder, exist_ok=True)
    filepath = os.path.join(output_md_folder, filename)

    # Explicit encoding: the content is arbitrary text and should not depend on
    # the platform default.
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(f'# {instance_id} (coverage: {coverage})\n')
        f.write(f'# {instance_id} (mutation score: {mutation})\n')

        # MetaData
        f.write('## MetaData\n')
        f.write('```json\n')
        f.write(json.dumps(row['metadata'], indent=2))
        f.write('\n```\n')

        # Trajectory
        f.write('## History\n')
        f.write(convert_history_to_str(row['history']))

        f.write('## Test Suite\n')
        f.write(f'{test_suite}\n')
|
||||
|
||||
|
||||
oh_format.progress_apply(write_row_to_md_file, axis=1)
|
||||
@@ -0,0 +1,35 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from evaluation.swe_bench.eval_infer import process_git_patch
|
||||
|
||||
parser = argparse.ArgumentParser()
parser.add_argument('oh_output_file', type=str)
args = parser.parse_args()

# The converted file sits next to the input: '.jsonl' becomes '.swebench.jsonl'.
output_filepath = args.oh_output_file.replace('.jsonl', '.swebench.jsonl')
print(f'Converting {args.oh_output_file} to {output_filepath}')

oh_format = pd.read_json(args.oh_output_file, orient='records', lines=True)

# model name is the folder name of oh_output_file
containing_folder = os.path.dirname(args.oh_output_file)
model_name = os.path.basename(containing_folder)
|
||||
|
||||
|
||||
def convert_row_to_swebench_format(row):
    """Build the prediction record for one output row.

    The patch is taken from ``row['git_patch']`` when present, otherwise from
    ``row['test_result']['git_patch']``, and passed through
    ``process_git_patch``.

    :raises ValueError: If the row has no git patch in either place.
    """
    has_top_level_patch = 'git_patch' in row
    if not has_top_level_patch and not (
        'test_result' in row and 'git_patch' in row['test_result']
    ):
        raise ValueError(f'Row {row} does not have a git_patch')

    model_patch = (
        row['git_patch'] if has_top_level_patch else row['test_result']['git_patch']
    )
    record = {}
    record['instance_id'] = row['instance_id']
    record['model_patch'] = process_git_patch(model_patch)
    record['model_name_or_path'] = model_name
    return record
|
||||
|
||||
|
||||
swebench_format = oh_format.apply(convert_row_to_swebench_format, axis=1)
|
||||
swebench_format.to_json(output_filepath, lines=True, orient='records')
|
||||
@@ -0,0 +1,27 @@
|
||||
import argparse
|
||||
|
||||
import pandas as pd
|
||||
from datasets import load_dataset
|
||||
|
||||
parser = argparse.ArgumentParser()
parser.add_argument('output_filepath', type=str, help='Path to save the output file')
parser.add_argument(
    '--dataset_name',
    type=str,
    help='Name of the dataset to download',
    default='kjain14/testgeneval',
)
parser.add_argument('--split', type=str, help='Split to download', default='test')
args = parser.parse_args()

dataset = load_dataset(args.dataset_name, split=args.split)
output_filepath = args.output_filepath
print(
    f'Downloading gold test suites from {args.dataset_name} (split: {args.split}) to {output_filepath}'
)

# Keep only the instance id and the test source of every row.
test_suites = []
for row in dataset:
    test_suites.append(
        {'instance_id': row['instance_id'], 'test_suite': row['test_src']}
    )
print(f'{len(test_suites)} test suites loaded')

# One JSON record per line
suites_frame = pd.DataFrame(test_suites)
suites_frame.to_json(output_filepath, lines=True, orient='records')
print(f'Test suites saved to {output_filepath}')
|
||||
@@ -0,0 +1,122 @@
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import json
|
||||
from collections import Counter
|
||||
|
||||
from openhands.events.serialization import event_from_dict
|
||||
from openhands.events.utils import get_pairs_from_events
|
||||
|
||||
ERROR_KEYWORDS = [
|
||||
'Agent encountered an error while processing the last action',
|
||||
'APIError',
|
||||
'Action execution failed',
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
    # Summarize an output JSONL file: average coverage / mutation score, empty
    # suites, error breakdown, turns and cost per instance.
    parser = argparse.ArgumentParser()
    parser.add_argument('output_file', type=str, help='The file to summarize')
    args = parser.parse_args()

    with open(args.output_file, 'r') as file:
        # Blank lines are not records and would make json.loads fail.
        lines = [line for line in file if line.strip()]

    num_lines = len(lines)
    if num_lines == 0:
        # Every statistic below divides by num_lines.
        raise SystemExit(f'No records found in {args.output_file}')

    num_error_lines = 0
    num_agent_stuck_in_loop = 0

    coverage = 0
    mutation_score = 0
    num_empty_suite = 0

    error_counter = Counter()

    main_agent_cost = []
    editor_cost = []
    num_turns = []

    for line in lines:
        _d = json.loads(line)

        # Cost
        costs = _d['metrics'].get('costs', [])
        _cur_main_agent_cost = 0
        _cur_editor_cost = 0
        for cost in costs:
            if isinstance(cost, (int, float)):
                # backward compatible: a bare number (an int was previously
                # mistaken for a dict entry and crashed on cost['model'])
                _cur_main_agent_cost += cost
            else:
                if 'draft_editor' in cost['model']:
                    _cur_editor_cost += cost['cost']
                else:
                    _cur_main_agent_cost += cost['cost']

        main_agent_cost.append(_cur_main_agent_cost)
        editor_cost.append(_cur_editor_cost)

        # Turn status
        history = _d.get('history', [])
        events = [event_from_dict(event) for event in history]
        pairs = get_pairs_from_events(events)
        num_turns.append(len(pairs))

        # Suite & resolve status
        suite = _d.get('test_result', {}).get('test_suite', '')
        if suite == '':
            # Records without a suite are counted but contribute no scores or errors.
            num_empty_suite += 1
            continue

        report = _d.get('report', {}) or {}
        coverage += report.get('coverage', 0)
        mutation_score += report.get('mutation_score', 0)

        # Error
        error = _d.get('error', None)

        if error is not None and isinstance(error, str):
            agent_stuck_in_loop = 'Agent got stuck in a loop' in error
            contains_error = bool(error) and not agent_stuck_in_loop
            if agent_stuck_in_loop:
                error_counter['Agent got stuck in a loop'] += 1
                num_agent_stuck_in_loop += 1
            elif contains_error:
                error_counter[error] += 1
            continue

        # No explicit error string: look for known error keywords in the raw line.
        for keyword in ERROR_KEYWORDS:
            if keyword in line:
                error_counter[keyword] += 1
                num_error_lines += 1
                break

    # print the error counter (with percentage)
    print(
        f'Average coverage for {num_lines} ({coverage / num_lines * 100:.2f}%)'
    )
    print(
        f'Average mutation score for {num_lines} ({mutation_score / num_lines * 100:.2f}%)'
    )

    print(
        f'Number of empty suite: {num_empty_suite} / {num_lines} ({num_empty_suite / num_lines * 100:.2f}%)'
    )
    print(
        f'Number of error lines: {num_error_lines} / {num_lines} ({num_error_lines / num_lines * 100:.2f}%)'
    )
    print(
        f'Number of agent stuck in loop: {num_agent_stuck_in_loop} / {num_lines} ({num_agent_stuck_in_loop / num_lines * 100:.2f}%)'
    )
    assert len(num_turns) == num_lines
    assert len(main_agent_cost) == num_lines
    assert len(editor_cost) == num_lines
    print('## Statistics')
    print(f'Avg. num of turns per instance: {sum(num_turns) / num_lines:.2f}')
    print(f'Avg. agent cost per instance: {sum(main_agent_cost) / num_lines:.2f} USD')
    print(f'Avg. editor cost per instance: {sum(editor_cost) / num_lines:.2f} USD')
    print(
        f'Avg. total cost per instance: {(sum(main_agent_cost) + sum(editor_cost)) / num_lines:.2f} USD'
    )

    print('## Detailed error breakdown:')
    for error, count in error_counter.items():
        print(f'{error}: {count} ({count / num_lines * 100:.2f}%)')
|
||||
+53
@@ -0,0 +1,53 @@
|
||||
#!/bin/bash
# Run the TestGenEval evaluation on an inference output file.
# Usage: <script> INPUT_FILE [NUM_WORKERS] [DATASET] [SPLIT] [SKIP_MUTATION]
# EVAL_LIMIT (environment) optionally limits the number of evaluated instances.
set -eo pipefail

INPUT_FILE=$1
NUM_WORKERS=$2
DATASET=$3
SPLIT=$4
SKIP_MUTATION=$5

if [ -z "$INPUT_FILE" ]; then
    echo "INPUT_FILE not specified (should be a path to a jsonl file)"
    exit 1
fi

if [ -z "$DATASET" ]; then
    echo "DATASET not specified, use default kjain14/testgenevallite"
    DATASET="kjain14/testgenevallite"
fi

if [ -z "$SPLIT" ]; then
    echo "SPLIT not specified, use default test"
    SPLIT="test"
fi

if [ -z "$NUM_WORKERS" ]; then
    echo "NUM_WORKERS not specified, use default 1"
    NUM_WORKERS=1
fi

echo "... Evaluating on $INPUT_FILE ..."

# The command is kept as an array so each argument reaches python verbatim;
# the former string + eval re-parsed values such as paths containing spaces.
COMMAND=(
    poetry run python evaluation/benchmarks/testgeneval/eval_infer.py
    --eval-num-workers "$NUM_WORKERS"
    --input-file "$INPUT_FILE"
    --dataset "$DATASET"
    --split "$SPLIT"
)

if [ "$SKIP_MUTATION" == "true" ]; then
    echo "Skipping mutation evaluation"
    COMMAND+=(--skip_mutation)
fi

if [ -n "$EVAL_LIMIT" ]; then
    echo "EVAL_LIMIT: $EVAL_LIMIT"
    COMMAND+=(--eval-n-limit "$EVAL_LIMIT")
fi

echo "${COMMAND[*]}"
# Run the command
"${COMMAND[@]}"

# update the output with evaluation results
# poetry run python evaluation/benchmarks/testgeneval/scripts/eval/update_output_with_eval.py $INPUT_FILE
|
||||
+122
@@ -0,0 +1,122 @@
|
||||
#!/bin/bash
# Run TestGenEval inference N_RUNS times.
# Usage: <script> MODEL_CONFIG COMMIT_HASH AGENT EVAL_LIMIT MAX_ITER NUM_WORKERS DATASET SPLIT N_RUNS [ZERO_SHOT_PATH]
# Environment: USE_INSTANCE_IMAGE, RUN_WITH_BROWSING, USE_HINT_TEXT, EXP_NAME.
set -eo pipefail

# Provides checkout_eval_branch, get_openhands_version and checkout_original_branch
source "evaluation/utils/version_control.sh"

MODEL_CONFIG=$1
COMMIT_HASH=$2
AGENT=$3
EVAL_LIMIT=$4
MAX_ITER=$5
NUM_WORKERS=$6
DATASET=$7
SPLIT=$8
N_RUNS=$9
ZERO_SHOT_PATH=${10} # New argument for zero-shot path

if [ -z "$NUM_WORKERS" ]; then
    NUM_WORKERS=1
    echo "Number of workers not specified, use default $NUM_WORKERS"
fi
checkout_eval_branch

if [ -z "$AGENT" ]; then
    echo "Agent not specified, use default CodeActAgent"
    AGENT="CodeActAgent"
fi

if [ -z "$MAX_ITER" ]; then
    echo "MAX_ITER not specified, use default 100"
    MAX_ITER=100
fi

if [ -z "$USE_INSTANCE_IMAGE" ]; then
    echo "USE_INSTANCE_IMAGE not specified, use default true"
    USE_INSTANCE_IMAGE=true
fi

if [ -z "$RUN_WITH_BROWSING" ]; then
    echo "RUN_WITH_BROWSING not specified, use default false"
    RUN_WITH_BROWSING=false
fi


if [ -z "$DATASET" ]; then
    echo "DATASET not specified, use default princeton-nlp/SWE-bench_Lite"
    DATASET="princeton-nlp/SWE-bench_Lite"
fi

if [ -z "$SPLIT" ]; then
    echo "SPLIT not specified, use default test"
    SPLIT="test"
fi

export USE_INSTANCE_IMAGE=$USE_INSTANCE_IMAGE
echo "USE_INSTANCE_IMAGE: $USE_INSTANCE_IMAGE"
export RUN_WITH_BROWSING=$RUN_WITH_BROWSING
echo "RUN_WITH_BROWSING: $RUN_WITH_BROWSING"

get_openhands_version

echo "AGENT: $AGENT"
echo "OPENHANDS_VERSION: $OPENHANDS_VERSION"
echo "MODEL_CONFIG: $MODEL_CONFIG"
echo "DATASET: $DATASET"
echo "SPLIT: $SPLIT"

# Default to NOT use Hint
if [ -z "$USE_HINT_TEXT" ]; then
    export USE_HINT_TEXT=false
fi
echo "USE_HINT_TEXT: $USE_HINT_TEXT"
EVAL_NOTE="$OPENHANDS_VERSION"
# if not using Hint, add -no-hint to the eval note
if [ "$USE_HINT_TEXT" = false ]; then
    EVAL_NOTE="$EVAL_NOTE-no-hint"
fi

if [ "$RUN_WITH_BROWSING" = true ]; then
    EVAL_NOTE="$EVAL_NOTE-with-browsing"
fi

if [ -n "$EXP_NAME" ]; then
    EVAL_NOTE="$EVAL_NOTE-$EXP_NAME"
fi

# Run one inference pass; $1 is the eval note for that pass.
function run_eval() {
    local eval_note=$1
    # Built as an array so each argument is passed verbatim; the former
    # string + eval re-parsed values containing spaces or shell characters.
    COMMAND=(
        poetry run python evaluation/benchmarks/testgeneval/run_infer.py
        --agent-cls "$AGENT"
        --llm-config "$MODEL_CONFIG"
        --max-iterations "$MAX_ITER"
        --eval-num-workers "$NUM_WORKERS"
        --eval-note "$eval_note"
        --dataset "$DATASET"
        --split "$SPLIT"
    )

    if [ -n "$EVAL_LIMIT" ]; then
        echo "EVAL_LIMIT: $EVAL_LIMIT"
        COMMAND+=(--eval-n-limit "$EVAL_LIMIT")
    fi

    if [ -n "$ZERO_SHOT_PATH" ]; then
        echo "ZERO_SHOT_PATH: $ZERO_SHOT_PATH"
        COMMAND+=(--testfile_start --zero_shot_path "$ZERO_SHOT_PATH")
    fi

    "${COMMAND[@]}"
}

unset SANDBOX_ENV_GITHUB_TOKEN # prevent the agent from using the github token to push
if [ -z "$N_RUNS" ]; then
    N_RUNS=1
    echo "N_RUNS not specified, use default $N_RUNS"
fi

for i in $(seq 1 $N_RUNS); do
    current_eval_note="$EVAL_NOTE-run_$i"
    echo "EVAL_NOTE: $current_eval_note"
    run_eval "$current_eval_note"
done

checkout_original_branch
|
||||
@@ -0,0 +1,40 @@
|
||||
#!/bin/bash

source ~/.bashrc
SWEUTIL_DIR=/swe_util

# FIXME: Cannot read SWE_INSTANCE_ID from the environment variable
# SWE_INSTANCE_ID=django__django-11099
if [[ -z "$SWE_INSTANCE_ID" ]]; then
    echo "Error: SWE_INSTANCE_ID is not set." >&2
    exit 1
fi

# Look up the record whose instance_id matches in swe-bench-instance.json
item=$(jq --arg INSTANCE_ID "$SWE_INSTANCE_ID" '.[] | select(.instance_id == $INSTANCE_ID)' $SWEUTIL_DIR/eval_data/instances/swe-bench-instance.json)

if [ -z "$item" ]; then
    echo "No item found for the provided instance ID."
    exit 1
fi

# Workspace name is "<repo>__<version>" with every "/" replaced by "__"
WORKSPACE_NAME=$(jq -r '(.repo | tostring) + "__" + (.version | tostring) | gsub("/"; "__")' <<< "$item")

echo "WORKSPACE_NAME: $WORKSPACE_NAME"

# Clear the workspace (create it when it does not exist yet)
if [ ! -d /workspace ]; then
    mkdir /workspace
else
    rm -rf /workspace/*
fi
# Link the repo into the workspace, replacing a leftover directory of the same name
if [ -d /workspace/$WORKSPACE_NAME ]; then
    rm -rf /workspace/$WORKSPACE_NAME
fi
mkdir -p /workspace
ln -s /testbed /workspace/$WORKSPACE_NAME

# Activate instance-specific environment
. /opt/miniconda3/etc/profile.d/conda.sh
conda activate testbed
|
||||
@@ -0,0 +1,27 @@
|
||||
#!/bin/bash

set -e
EVAL_WORKSPACE="evaluation/swe_bench/eval_workspace"
mkdir -p $EVAL_WORKSPACE

# 1. Prepare REPO: clone the "eval" branch into the workspace
echo "==== Prepare SWE-bench repo ===="
OH_SWE_BENCH_REPO_PATH="https://github.com/All-Hands-AI/SWE-bench.git"
OH_SWE_BENCH_REPO_BRANCH="eval"
git clone -b $OH_SWE_BENCH_REPO_BRANCH $OH_SWE_BENCH_REPO_PATH $EVAL_WORKSPACE/OH-SWE-bench

# 2. Prepare DATA: run prepare_data.sh inside the eval image
echo "==== Prepare SWE-bench data ===="
EVAL_IMAGE=ghcr.io/all-hands-ai/eval-swe-bench:builder_with_conda
# docker -v needs an absolute host path
EVAL_WORKSPACE=$(realpath $EVAL_WORKSPACE)
chmod +x $EVAL_WORKSPACE/OH-SWE-bench/swebench/harness/prepare_data.sh
# Drop data left over from an earlier run
if [[ -d $EVAL_WORKSPACE/eval_data ]]; then
    rm -r $EVAL_WORKSPACE/eval_data
fi
docker run \
    -v $EVAL_WORKSPACE:/workspace \
    -w /workspace \
    -u $(id -u):$(id -g) \
    -e HF_DATASETS_CACHE="/tmp" \
    --rm -it $EVAL_IMAGE \
    bash -c "cd OH-SWE-bench/swebench/harness && /swe_util/miniforge3/bin/conda run -n swe-bench-eval ./prepare_data.sh && mv eval_data /workspace/"
|
||||
@@ -0,0 +1,96 @@
|
||||
#!/bin/bash

set -e

# assert user name is `root`
if [ "$USER" != "root" ]; then
    echo "Error: This script is intended to be run by the 'root' user only." >&2
    exit 1
fi

source ~/.bashrc

SWEUTIL_DIR=/swe_util

# Create logs directory
LOG_DIR=/openhands/logs
mkdir -p $LOG_DIR && chmod 777 $LOG_DIR

# FIXME: Cannot read SWE_INSTANCE_ID from the environment variable
# SWE_INSTANCE_ID=django__django-11099
if [ -z "$SWE_INSTANCE_ID" ]; then
    echo "Error: SWE_INSTANCE_ID is not set." >&2
    exit 1
fi

# Read the swe-bench-test-lite.json file and extract the required item based on instance_id
item=$(jq --arg INSTANCE_ID "$SWE_INSTANCE_ID" '.[] | select(.instance_id == $INSTANCE_ID)' $SWEUTIL_DIR/eval_data/instances/swe-bench-test-lite.json)

if [[ -z "$item" ]]; then
    echo "No item found for the provided instance ID."
    exit 1
fi

# Environment name is "<repo>__<version>" with every "/" replaced by "__"
CONDA_ENV_NAME=$(echo "$item" | jq -r '.repo + "__" + .version | gsub("/"; "__")')

echo "CONDA_ENV_NAME: $CONDA_ENV_NAME"

SWE_TASK_DIR=/openhands/swe_tasks
mkdir -p $SWE_TASK_DIR
# Dump test_patch to $SWE_TASK_DIR/test.patch
echo "$item" | jq -r '.test_patch' > $SWE_TASK_DIR/test.patch
# Dump patch to $SWE_TASK_DIR/gold.patch
echo "$item" | jq -r '.patch' > $SWE_TASK_DIR/gold.patch
# Dump the item to $SWE_TASK_DIR/instance.json except for the "test_patch" and "patch" fields
echo "$item" | jq 'del(.test_patch, .patch)' > $SWE_TASK_DIR/instance.json

# Clear the workspace
rm -rf /workspace/*
# Copy repo to workspace
if [ -d /workspace/$CONDA_ENV_NAME ]; then
    rm -rf /workspace/$CONDA_ENV_NAME
fi
cp -r $SWEUTIL_DIR/eval_data/testbeds/$CONDA_ENV_NAME /workspace

# Reset swe-bench testbed and install the repo
. $SWEUTIL_DIR/miniforge3/etc/profile.d/conda.sh
conda config --set changeps1 False
conda config --append channels conda-forge
conda activate swe-bench-eval

mkdir -p $SWE_TASK_DIR/reset_testbed_temp
mkdir -p $SWE_TASK_DIR/reset_testbed_log_dir
SWE_BENCH_DIR=/swe_util/OH-SWE-bench
# Runs in a subshell so PYTHONPATH and the cd do not leak into this script
output=$(
    export PYTHONPATH=$SWE_BENCH_DIR && \
    cd $SWE_BENCH_DIR && \
    python swebench/harness/reset_swe_env.py \
        --swe_bench_tasks $SWEUTIL_DIR/eval_data/instances/swe-bench-test.json \
        --temp_dir $SWE_TASK_DIR/reset_testbed_temp \
        --testbed /workspace \
        --conda_path $SWEUTIL_DIR/miniforge3 \
        --instance_id $SWE_INSTANCE_ID \
        --log_dir $SWE_TASK_DIR/reset_testbed_log_dir \
        --timeout 900 \
        --verbose
)

# The reset script's output contains "repo_path: ..." and "test_cmd: ..." lines
REPO_PATH=$(echo "$output" | awk -F': ' '/repo_path:/ {print $2}')
TEST_CMD=$(echo "$output" | awk -F': ' '/test_cmd:/ {print $2}')
echo "Repo Path: $REPO_PATH"
echo "Test Command: $TEST_CMD"

# Persist the values for later shells. %q shell-escapes them, so a value
# containing quotes, "$" or backticks cannot corrupt ~/.bashrc or be expanded
# when it is sourced (the former echo "export X=\"$X\"" form could).
printf 'export SWE_BENCH_DIR=%q\n' "$SWE_BENCH_DIR" >> ~/.bashrc
printf 'export REPO_PATH=%q\n' "$REPO_PATH" >> ~/.bashrc
printf 'export TEST_CMD=%q\n' "$TEST_CMD" >> ~/.bashrc

if [[ "$REPO_PATH" == "None" ]]; then
    echo "Error: Failed to retrieve repository path. Tests may not have passed or output was not as expected." >&2
    exit 1
fi

# Activate instance-specific environment
. $SWEUTIL_DIR/miniforge3/etc/profile.d/conda.sh
conda activate $CONDA_ENV_NAME

set +e
|
||||
@@ -0,0 +1,327 @@
|
||||
import ast
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
from evaluation.benchmarks.testgeneval.constants import TestStatus
|
||||
from evaluation.benchmarks.testgeneval.log_parsers import (
|
||||
MAP_REPO_TO_PARSER,
|
||||
parse_log_pytest,
|
||||
)
|
||||
|
||||
|
||||
def indent_text(text, indent_level):
    """Prefix every non-blank line of ``text`` with ``indent_level`` spaces.

    Lines that are empty or contain only whitespace are left unchanged.
    """
    prefix = ' ' * indent_level
    indented = []
    for line in text.split('\n'):
        indented.append(prefix + line if line.strip() else line)
    return '\n'.join(indented)
|
||||
|
||||
|
||||
def extract_preamble_classes_and_functions(code):
|
||||
class_pattern = re.compile(
|
||||
r'(?P<decorators>(?:^@[^\r\n]*(?:\r?\n(?:[ \t]+[^\r\n]*|^\)[^\r\n]*)*)*\r?\n)*?)'
|
||||
r'^class\s+([\w]+)(?:\([^)]*\))?:', # the class line
|
||||
re.MULTILINE,
|
||||
)
|
||||
# Capture methods with or without decorators
|
||||
method_pattern = re.compile(r'(^(\s*@.*\s*)*^\s*def\s+[\w_]+\(.*\):)', re.MULTILINE)
|
||||
|
||||
# Capture functions with or without decorators
|
||||
function_pattern = re.compile(
|
||||
r'(?P<decorators>(?:^@[^\r\n]*(?:\r?\n(?:[ \t]+[^\r\n]*|^\)[^\r\n]*)*)*\r?\n)*?)'
|
||||
r'^def\s+([\w_]+)\(.*\):', # the function line
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
preamble = ''
|
||||
classes = []
|
||||
test_functions = []
|
||||
|
||||
current_position = 0
|
||||
|
||||
def extract_class_body(code: str, start_index: int) -> Tuple[str, int]:
|
||||
"""
|
||||
Extracts the body of a class from the given code starting from the specified index.
|
||||
Returns the class body and the end index of the class body.
|
||||
"""
|
||||
if not code or start_index < 0 or start_index >= len(code):
|
||||
raise ValueError('Invalid code or start index')
|
||||
|
||||
# Split the code into lines
|
||||
lines = code[start_index:].split('\n')
|
||||
class_body_lines = []
|
||||
|
||||
# Find the starting indentation level of the class definition
|
||||
class_start_line = lines[0]
|
||||
start_indent = len(class_start_line) - len(class_start_line.lstrip())
|
||||
|
||||
inside_multiline_comment = False
|
||||
end_index = start_index
|
||||
for i, line in enumerate(lines[1:], start=1):
|
||||
stripped_line = line.strip()
|
||||
current_indent = len(line) - len(line.lstrip())
|
||||
|
||||
# Handle multiline comments or docstrings
|
||||
if stripped_line.startswith('"""') or stripped_line.startswith("'''"):
|
||||
if inside_multiline_comment:
|
||||
inside_multiline_comment = False
|
||||
else:
|
||||
inside_multiline_comment = True
|
||||
|
||||
if not inside_multiline_comment:
|
||||
# Stop when we reach a line with less indentation than the class definition
|
||||
if current_indent <= start_indent and stripped_line:
|
||||
break
|
||||
|
||||
# Add lines that are part of the class body
|
||||
class_body_lines.append(line)
|
||||
# Update the end index to the current line end
|
||||
end_index = start_index + len('\n'.join(lines[: i + 1])) + 1
|
||||
|
||||
return code[start_index:end_index], end_index
|
||||
|
||||
while current_position < len(code):
|
||||
class_match = class_pattern.search(code, current_position)
|
||||
method_match = method_pattern.search(code, current_position)
|
||||
|
||||
if class_match and (
|
||||
not method_match or class_match.start() < method_match.start()
|
||||
):
|
||||
class_name = class_match.group(0)
|
||||
class_body, end_idx = extract_class_body(code, class_match.end())
|
||||
current_position = end_idx
|
||||
|
||||
methods = []
|
||||
class_prefix = class_name
|
||||
set_prefix = False
|
||||
for method_match in method_pattern.finditer(class_body):
|
||||
method_name = method_match.group()
|
||||
method_start = method_match.start()
|
||||
if not set_prefix:
|
||||
class_prefix = class_name + class_body[:method_start]
|
||||
set_prefix = True
|
||||
next_method = method_pattern.search(
|
||||
class_body, method_start + len(method_name)
|
||||
)
|
||||
method_body = (
|
||||
class_body[method_start : next_method.start()]
|
||||
if next_method
|
||||
else class_body[method_start:]
|
||||
)
|
||||
methods.append((method_name, method_body))
|
||||
|
||||
classes.append((class_prefix, methods, class_match.start()))
|
||||
|
||||
elif method_match:
|
||||
function_name = method_match.group(0)
|
||||
start_idx = method_match.start()
|
||||
|
||||
# Extract the current function's indentation level
|
||||
lines = code[start_idx:].split('\n')
|
||||
current_indent = len(lines[0]) - len(lines[0].lstrip())
|
||||
|
||||
next_function = function_pattern.search(
|
||||
code, start_idx + len(function_name)
|
||||
)
|
||||
while next_function and (
|
||||
class_match is None or next_function.start() < class_match.start()
|
||||
):
|
||||
# Calculate the indentation of the next function
|
||||
next_function_start = next_function.start()
|
||||
next_line = code[next_function_start:].split('\n', 1)[0]
|
||||
next_indent = len(next_line) - len(next_line.lstrip())
|
||||
|
||||
# Check if the next function is top-level
|
||||
if next_indent <= current_indent:
|
||||
break
|
||||
|
||||
# Continue searching for the next top-level function
|
||||
next_function = function_pattern.search(
|
||||
code, next_function.start() + len(next_function.group(0))
|
||||
)
|
||||
|
||||
if next_function:
|
||||
next_function_start = next_function.start()
|
||||
if class_match and next_function_start > class_match.start():
|
||||
next_function_start = class_match.start()
|
||||
function_body = code[start_idx:next_function_start]
|
||||
else:
|
||||
function_body = code[start_idx:]
|
||||
|
||||
test_functions.append((function_body, start_idx))
|
||||
current_position = start_idx + len(function_body)
|
||||
|
||||
else:
|
||||
break
|
||||
|
||||
if classes and test_functions:
|
||||
preamble = code[: min(classes[0][2], test_functions[0][1])]
|
||||
else:
|
||||
preamble = (
|
||||
code[: classes[0][2]]
|
||||
if classes
|
||||
else code[: test_functions[0][1]]
|
||||
if test_functions
|
||||
else code
|
||||
)
|
||||
|
||||
return preamble.strip(), classes, test_functions
|
||||
|
||||
|
||||
def filter_passing_tests(
    test_content: str, test_output: str, repo: str
) -> Tuple[str, List[str], List[str]]:
    """
    Filter tests based on their execution results (regex-based splitting).

    A test is dropped when its name is a substring of any failing test id, or
    any failing test id is a substring of its name.

    Args:
        test_content: Source of the test file.
        test_output: Raw log produced by running the tests.
        repo: Repository key used to pick the log parser; unknown repos fall
            back to the pytest parser.

    Returns:
        Tuple containing:
        - Modified test content with only passing tests ('' if nothing passed)
        - List of passing test names
        - List of failing test names
    """
    # Parse test results using appropriate parser
    parser = MAP_REPO_TO_PARSER.get(repo, parse_log_pytest)
    test_results = parser(test_output)

    passing_tests = []
    failing_tests = []
    for test_name, status in test_results.items():
        if status == TestStatus.PASSED.value:
            passing_tests.append(test_name)
        else:
            failing_tests.append(test_name)

    if not passing_tests:
        return '', passing_tests, failing_tests

    def defined_name(source: str) -> str:
        # Take the name from the first `def` that starts a line. Splitting on
        # '.' and '(' would return the decorator name for headers such as
        # `@pytest.mark.parametrize(...)\n def test_x(...)`.
        match = re.search(r'^[ \t]*def\s+(\w+)', source, re.MULTILINE)
        return match.group(1) if match else ''

    def matches_failure(name: str) -> bool:
        # An undetermined name is kept rather than treated as failing.
        if not name:
            return False
        return any(
            name in failing_test or failing_test in name
            for failing_test in failing_tests
        )

    # Extract test components
    preamble, classes, functions = extract_preamble_classes_and_functions(test_content)

    # Filter classes to only include passing methods; classes left without
    # any method are dropped entirely.
    filtered_classes = []
    for class_name, methods, start_idx in classes:
        non_fail_methods = [
            (method_name, method_body)
            for method_name, method_body in methods
            if not matches_failure(defined_name(method_name))
        ]
        if non_fail_methods:
            filtered_classes.append((class_name, non_fail_methods, start_idx))

    # Filter standalone functions
    filtered_functions = [
        (func_body, start_idx)
        for func_body, start_idx in functions
        if not matches_failure(defined_name(func_body))
    ]

    # Reconstruct test content with only passing tests
    content_parts = [preamble]

    for class_name, methods, _ in filtered_classes:
        class_content = class_name + '\n'
        for _, method_body in methods:
            class_content += method_body + '\n'
        content_parts.append(class_content)

    for func_body, _ in filtered_functions:
        content_parts.append(func_body)

    return '\n\n'.join(content_parts), passing_tests, failing_tests
|
||||
|
||||
|
||||
def filter_tests(
    test_content: str, test_output: str, repo: str
) -> Tuple[str, List[str], List[str]]:
    """
    Filter tests using AST parsing to remove failing test functions from the test file.
    Non-test functions (e.g. setup or helper methods) and classes (even if all test methods are failing)
    are preserved.

    A test counts as failing when its name (or ``Class.method`` name) is a
    substring of a failing test id, or a failing test id is a substring of it.
    Only functions/methods whose name starts with ``test`` are candidates for
    removal.

    If AST processing fails (for example, because the test file cannot be parsed),
    this function falls back on the existing regex-based filtering (filter_passing_tests).

    Returns:
        Tuple containing:
        - Modified test content (as a string) containing only passing tests.
        - List of passing test names.
        - List of failing test names.
    """
    try:
        # Attempt to parse the test file using the AST.
        tree = ast.parse(test_content)

        # Parse test results using the appropriate parser.
        parser = MAP_REPO_TO_PARSER.get(repo, parse_log_pytest)
        test_results = parser(test_output)
        passing_tests = [
            name
            for name, status in test_results.items()
            if status == TestStatus.PASSED.value
        ]
        failing_tests = [
            name
            for name, status in test_results.items()
            if status != TestStatus.PASSED.value
        ]

        # Helper function to decide if a test name should be considered failing.
        def is_failing(name: str) -> bool:
            return any(name in ft or ft in name for ft in failing_tests)

        function_types = (ast.FunctionDef, ast.AsyncFunctionDef)

        new_body = []
        for node in tree.body:
            # For top-level function definitions, only filter those that look like tests.
            if isinstance(node, function_types):
                if node.name.startswith('test') and is_failing(node.name):
                    continue
                new_body.append(node)
            # For classes, filter out failing test methods but preserve other methods (e.g. setup).
            elif isinstance(node, ast.ClassDef):
                new_class_body = []
                for subnode in node.body:
                    # Only test methods may be removed; helpers such as setUp
                    # must survive even if their name appears inside a failing id.
                    if isinstance(
                        subnode, function_types
                    ) and subnode.name.startswith('test'):
                        qualified_name = f'{node.name}.{subnode.name}'
                        if is_failing(subnode.name) or is_failing(qualified_name):
                            continue
                    new_class_body.append(subnode)
                # Always include the class even if nothing remains, as other code
                # may reference it; an empty body is replaced by `pass` so the
                # unparsed source stays valid.
                node.body = new_class_body or [ast.Pass()]
                new_body.append(node)
            else:
                new_body.append(node)

        tree.body = new_body

        # Reconstruct the source code from the filtered AST.
        # (Requires Python 3.9+ for ast.unparse; otherwise an exception will trigger the fallback.)
        new_test_content = ast.unparse(tree)
        return new_test_content, passing_tests, failing_tests

    except Exception:
        print('AST processing failed; falling back on regex-based filtering.')
        # If AST processing fails for any reason, fall back on the original regex-based filtering.
        return filter_passing_tests(test_content, test_output, repo)
|
||||
@@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from evaluation.benchmarks.testgeneval.constants import (
|
||||
COVERAGE_PREFIX,
|
||||
KEY_INSTANCE_ID,
|
||||
MAP_REPO_VERSION_TO_SPECS,
|
||||
TESTS_FAILED,
|
||||
TESTS_SUFFIX,
|
||||
UPDATE_TOX,
|
||||
TestGenEvalInstance,
|
||||
)
|
||||
from evaluation.benchmarks.testgeneval.utils import (
|
||||
get_test_directives,
|
||||
)
|
||||
|
||||
DIFF_MODIFIED_FILE_REGEX = r'--- a/(.*)'
|
||||
|
||||
|
||||
@dataclass
class TestSpec:
    """
    Test specification for a single benchmark instance: identifying metadata
    plus the shell command lists used to run its tests and mutation analysis.
    """

    instance_id: str
    id: str
    repo: str
    version: str
    test_cmd: str
    code_file: str
    test_file: str
    baseline_covs: dict
    local_imports: list[str]
    test_script_list: list[str]
    mutation_script_list: list[str]

    @property
    def test_script(self):
        """Bash script text built from ``test_script_list``, newline-terminated."""
        # No `-e` in the `set` line: don't exit early because we need to
        # revert tests at the end.
        header = ['#!/bin/bash', 'set -uo pipefail']
        script_lines = header + self.test_script_list
        return '\n'.join(script_lines) + '\n'

    @property
    def mutation_script(self):
        """Bash script text built from ``mutation_script_list``, newline-terminated."""
        # No `-e` in the `set` line: don't exit early because we need to
        # revert tests at the end.
        header = ['#!/bin/bash', 'set -uo pipefail']
        script_lines = header + self.mutation_script_list
        return '\n'.join(script_lines) + '\n'
|
||||
|
||||
|
||||
def make_test_setup(specs, env_name, repo_directory, includes_tox=False):
    """Build the shell commands that prepare the environment before a run.

    Args:
        specs: Repo/version spec mapping; optional keys ``eval_commands``
            (iterable of commands) and ``install`` (single command) are used.
        env_name: Conda environment to activate.
        repo_directory: Directory of the repository checkout.
        includes_tox: When true, tox-specific update/coverage steps are added.

    Returns:
        List of shell command strings, in execution order.
    """
    activate_conda = 'source /opt/miniconda3/bin/activate'
    activate_env = f'conda activate {env_name}'
    enter_repo = f'cd {repo_directory}'

    commands = [UPDATE_TOX] if includes_tox else []
    commands.extend([activate_conda, activate_env, enter_repo])

    if 'eval_commands' in specs:
        commands.extend(specs['eval_commands'])

    commands.append(
        f'git config --global --add safe.directory {repo_directory}'  # for nonroot user
    )
    commands.append(enter_repo)
    # This is just informational, so we have a record
    commands.extend(['git status', 'git show'])
    commands.extend([activate_conda, activate_env])

    if 'install' in specs:
        commands.append(specs['install'])

    if includes_tox:
        commands.append('add_coverage_tox "tox.ini"')

    commands.append('[ -f ".coveragerc" ] && rm ".coveragerc"')
    return commands
|
||||
|
||||
|
||||
def make_test_script_list(test_cmd, specs, env_name, repo_directory):
    """
    Build the command list that sets up the environment, runs the tests and
    dumps the coverage report as JSON.

    Tox handling in the setup is enabled when ``test_cmd`` contains 'tox'.
    """
    setup_commands = make_test_setup(
        specs, env_name, repo_directory, 'tox' in test_cmd
    )

    # On test failure, print the failure marker and suffix, then abort.
    run_tests = (
        f'{test_cmd} || {{ echo "{TESTS_FAILED}\n{TESTS_SUFFIX}\n" && exit 1; }}'
    )
    mark_tests_done = f'echo "{TESTS_SUFFIX}"\n'
    mark_coverage_start = f'echo "{COVERAGE_PREFIX}"\n'

    return [
        *setup_commands,
        run_tests,
        mark_tests_done,
        'coverage json -o coverage.json',
        mark_coverage_start,
        'cat coverage.json',
    ]
|
||||
|
||||
|
||||
def make_mutation_script_list(specs, env_name, repo_directory, mutation_timeout):
    """
    Build the command list that sets up the environment and runs cosmic-ray
    mutation testing, followed by its report and rate commands.

    ``mutation_timeout`` is the number of seconds passed to ``timeout`` for
    the ``cosmic-ray exec`` step.
    """
    setup_commands = make_test_setup(specs, env_name, repo_directory)
    mutation_commands = [
        'cosmic-ray init mutation.toml mutation.sqlite',
        f'timeout {mutation_timeout}s cosmic-ray exec mutation.toml mutation.sqlite',
        'cr-report mutation.sqlite',
        'cr-rate mutation.sqlite --estimate --confidence 95.0',
    ]
    return [*setup_commands, *mutation_commands]
|
||||
|
||||
|
||||
def make_test_spec(
    instance: TestGenEvalInstance, mutation_timeout: int, buffer: int
) -> TestSpec:
    """Create the ``TestSpec`` for a dataset instance.

    Args:
        instance: Dataset record; returned unchanged if it already is a
            ``TestSpec``.
        mutation_timeout: Overall time budget for the mutation run.
        buffer: Amount subtracted from ``mutation_timeout`` to obtain the
            timeout handed to the mutation script.

    Returns:
        The populated ``TestSpec``.
    """
    if isinstance(instance, TestSpec):
        return instance

    instance_id = instance[KEY_INSTANCE_ID]
    spec_id = instance['id']
    repo = instance['repo']
    version = instance['version']
    baseline_covs = instance['baseline_covs']
    code_file = instance['code_file']
    test_file = instance['test_file']
    local_imports = instance['local_imports']

    env_name = 'testbed'
    repo_directory = f'/{env_name}'
    specs = MAP_REPO_VERSION_TO_SPECS[repo][version]

    # Base test command for this repo/version followed by the test targets.
    command_parts = [specs['test_cmd']]
    command_parts.extend(get_test_directives(instance))
    test_cmd = ' '.join(command_parts)

    return TestSpec(
        instance_id=instance_id,
        id=spec_id,
        repo=repo,
        version=version,
        test_cmd=test_cmd,
        code_file=code_file,
        test_file=test_file,
        baseline_covs=baseline_covs,
        local_imports=local_imports,
        test_script_list=make_test_script_list(
            test_cmd, specs, env_name, repo_directory
        ),
        mutation_script_list=make_mutation_script_list(
            specs, env_name, repo_directory, mutation_timeout - buffer
        ),
    )
|
||||
@@ -0,0 +1,73 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
|
||||
from evaluation.benchmarks.testgeneval.constants import (
|
||||
KEY_INSTANCE_ID,
|
||||
TestGenEvalInstance,
|
||||
)
|
||||
|
||||
|
||||
def get_test_directives(instance: TestGenEvalInstance) -> list:
    """
    Get the test directives (arguments for the test command) of a task instance

    Args:
        instance (dict): task instance; reads ``repo`` and ``test_file``
    Returns:
        directives (list): List of test directives
    """
    repo = instance['repo']

    # For seq2seq code repos, testing command is fixed
    if repo == 'swe-bench/humaneval':
        return ['test.py']

    test_file = instance['test_file']

    # Default: run the test file by its path inside the testbed
    if repo != 'django/django':
        return [f'/testbed/{test_file}']

    # For Django tests, remove extension + "tests/" prefix and convert slashes
    # to dots (module referencing)
    module = test_file
    if module.endswith('.py'):
        module = module[: -len('.py')]
    if module.startswith('tests/'):
        module = module[len('tests/') :]
    return [module.replace('/', '.')]
|
||||
|
||||
|
||||
def load_testgeneval_dataset(
    name='kjain14/testgeneval', split='test', ids=None
) -> list[TestGenEvalInstance]:
    """
    Load the TestGenEval dataset from Hugging Face Datasets or a local .json/.jsonl file

    Args:
        name: Hugging Face dataset name, one of the short aliases handled
            below, or a path ending in ``.json`` / ``.jsonl``.
        split: Dataset split to load (Hugging Face only).
        ids: Optional iterable of instance ``id`` values to keep.

    Returns:
        List of dataset instances, restricted to ``ids`` when given.

    Raises:
        ValueError: If some of the requested ``ids`` are not in the dataset.
    """
    if ids:
        ids = set(ids)

    if name.endswith('.json') or name.endswith('.jsonl'):
        # Load from local .json/.jsonl file
        text = Path(name).read_text()
        try:
            dataset = json.loads(text)
        except json.JSONDecodeError:
            if not name.endswith('.jsonl'):
                raise
            # Real JSON Lines: one JSON document per non-empty line.
            dataset = [json.loads(line) for line in text.splitlines() if line.strip()]
        if name.endswith('.jsonl') and isinstance(dataset, dict):
            # A .jsonl file holding a single record parses as one object.
            dataset = [dataset]
    else:
        # Load from Hugging Face Datasets
        if name.lower() in {'testgeneval'}:
            name = 'kjain14/testgeneval'
        elif name.lower() in {'testgeneval-lite', 'testgenevallite', 'lite'}:
            name = 'kjain14/testgenevallite'
        dataset = cast(Dataset, load_dataset(name, split=split))

    if ids:
        # Check that all requested IDs are in the dataset. The same `id` key
        # is used for validation and filtering, for both local and hub data.
        dataset_ids = {instance['id'] for instance in dataset}
        missing_ids = ids - dataset_ids
        if missing_ids:
            raise ValueError(
                (
                    "Some instance IDs not found in dataset!"
                    f"\nMissing IDs:\n{' '.join(missing_ids)}"
                )
            )
        dataset = [instance for instance in dataset if instance['id'] in ids]
    return [cast(TestGenEvalInstance, instance) for instance in dataset]
|
||||
@@ -34,7 +34,6 @@ from openhands.utils.async_utils import call_async_from_sync
|
||||
|
||||
FAKE_RESPONSES = {
|
||||
'CodeActAgent': fake_user_response,
|
||||
'DelegatorAgent': fake_user_response,
|
||||
'VisualBrowsingAgent': fake_user_response,
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ load_dotenv()
|
||||
from openhands.agenthub import ( # noqa: E402
|
||||
browsing_agent,
|
||||
codeact_agent,
|
||||
delegator_agent,
|
||||
dummy_agent,
|
||||
visualbrowsing_agent,
|
||||
)
|
||||
@@ -15,7 +14,6 @@ from openhands.controller.agent import Agent # noqa: E402
|
||||
__all__ = [
|
||||
'Agent',
|
||||
'codeact_agent',
|
||||
'delegator_agent',
|
||||
'dummy_agent',
|
||||
'browsing_agent',
|
||||
'visualbrowsing_agent',
|
||||
|
||||
@@ -70,9 +70,10 @@ class CodeActAgent(Agent):
|
||||
codeact_enable_browsing=self.config.codeact_enable_browsing,
|
||||
codeact_enable_jupyter=self.config.codeact_enable_jupyter,
|
||||
codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
|
||||
llm=self.llm,
|
||||
)
|
||||
logger.debug(
|
||||
f'TOOLS loaded for CodeActAgent: {', '.join([tool.get('function').get('name') for tool in self.tools])}'
|
||||
f"TOOLS loaded for CodeActAgent: {', '.join([tool.get('function').get('name') for tool in self.tools])}"
|
||||
)
|
||||
self.prompt_manager = PromptManager(
|
||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||
|
||||
@@ -12,13 +12,13 @@ from litellm import (
|
||||
|
||||
from openhands.agenthub.codeact_agent.tools import (
|
||||
BrowserTool,
|
||||
CmdRunTool,
|
||||
FinishTool,
|
||||
IPythonTool,
|
||||
LLMBasedFileEditTool,
|
||||
StrReplaceEditorTool,
|
||||
ThinkTool,
|
||||
WebReadTool,
|
||||
create_cmd_run_tool,
|
||||
create_str_replace_editor_tool,
|
||||
)
|
||||
from openhands.core.exceptions import (
|
||||
FunctionCallNotExistsError,
|
||||
@@ -39,6 +39,7 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.event import FileEditSource, FileReadSource
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
from openhands.llm import LLM
|
||||
|
||||
|
||||
def combine_thought(action: Action, thought: str) -> Action:
|
||||
@@ -80,7 +81,7 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
|
||||
# CmdRunTool (Bash)
|
||||
# ================================================
|
||||
|
||||
if tool_call.function.name == CmdRunTool['function']['name']:
|
||||
if tool_call.function.name == create_cmd_run_tool()['function']['name']:
|
||||
if 'command' not in arguments:
|
||||
raise FunctionCallValidationError(
|
||||
f'Missing required argument "command" in tool call {tool_call.function.name}'
|
||||
@@ -131,7 +132,10 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
|
||||
start=arguments.get('start', 1),
|
||||
end=arguments.get('end', -1),
|
||||
)
|
||||
elif tool_call.function.name == StrReplaceEditorTool['function']['name']:
|
||||
elif (
|
||||
tool_call.function.name
|
||||
== create_str_replace_editor_tool()['function']['name']
|
||||
):
|
||||
if 'command' not in arguments:
|
||||
raise FunctionCallValidationError(
|
||||
f'Missing required argument "command" in tool call {tool_call.function.name}'
|
||||
@@ -219,8 +223,22 @@ def get_tools(
|
||||
codeact_enable_browsing: bool = False,
|
||||
codeact_enable_llm_editor: bool = False,
|
||||
codeact_enable_jupyter: bool = False,
|
||||
llm: LLM | None = None,
|
||||
) -> list[ChatCompletionToolParam]:
|
||||
tools = [CmdRunTool, ThinkTool, FinishTool]
|
||||
SIMPLIFIED_TOOL_DESCRIPTION_LLM_SUBSTRS = ['gpt-', 'o3', 'o1']
|
||||
|
||||
use_simplified_tool_desc = False
|
||||
if llm is not None:
|
||||
use_simplified_tool_desc = any(
|
||||
model_substr in llm.config.model
|
||||
for model_substr in SIMPLIFIED_TOOL_DESCRIPTION_LLM_SUBSTRS
|
||||
)
|
||||
|
||||
tools = [
|
||||
create_cmd_run_tool(use_simplified_description=use_simplified_tool_desc),
|
||||
ThinkTool,
|
||||
FinishTool,
|
||||
]
|
||||
if codeact_enable_browsing:
|
||||
tools.append(WebReadTool)
|
||||
tools.append(BrowserTool)
|
||||
@@ -229,5 +247,9 @@ def get_tools(
|
||||
if codeact_enable_llm_editor:
|
||||
tools.append(LLMBasedFileEditTool)
|
||||
else:
|
||||
tools.append(StrReplaceEditorTool)
|
||||
tools.append(
|
||||
create_str_replace_editor_tool(
|
||||
use_simplified_description=use_simplified_tool_desc
|
||||
)
|
||||
)
|
||||
return tools
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
from .bash import CmdRunTool
|
||||
from .bash import create_cmd_run_tool
|
||||
from .browser import BrowserTool
|
||||
from .finish import FinishTool
|
||||
from .ipython import IPythonTool
|
||||
from .llm_based_edit import LLMBasedFileEditTool
|
||||
from .str_replace_editor import StrReplaceEditorTool
|
||||
from .str_replace_editor import create_str_replace_editor_tool
|
||||
from .think import ThinkTool
|
||||
from .web_read import WebReadTool
|
||||
|
||||
__all__ = [
|
||||
'BrowserTool',
|
||||
'CmdRunTool',
|
||||
'create_cmd_run_tool',
|
||||
'FinishTool',
|
||||
'IPythonTool',
|
||||
'LLMBasedFileEditTool',
|
||||
'StrReplaceEditorTool',
|
||||
'create_str_replace_editor_tool',
|
||||
'WebReadTool',
|
||||
'ThinkTool',
|
||||
]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
|
||||
|
||||
_BASH_DESCRIPTION = """Execute a bash command in the terminal within a persistent shell session.
|
||||
_DETAILED_BASH_DESCRIPTION = """Execute a bash command in the terminal within a persistent shell session.
|
||||
|
||||
### Command Execution
|
||||
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, use `&&` or `;` to chain them together.
|
||||
@@ -22,25 +22,39 @@ _BASH_DESCRIPTION = """Execute a bash command in the terminal within a persisten
|
||||
* Output truncation: If the output exceeds a maximum length, it will be truncated before being returned.
|
||||
"""
|
||||
|
||||
CmdRunTool = ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='execute_bash',
|
||||
description=_BASH_DESCRIPTION,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'command': {
|
||||
'type': 'string',
|
||||
'description': 'The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process. Note: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together.',
|
||||
},
|
||||
'is_input': {
|
||||
'type': 'string',
|
||||
'description': 'If True, the command is an input to the running process. If False, the command is a bash command to be executed in the terminal. Default is False.',
|
||||
'enum': ['true', 'false'],
|
||||
_SIMPLIFIED_BASH_DESCRIPTION = """Execute a bash command in the terminal.
|
||||
* Long running commands: For commands that may run indefinitely, it should be run in the background and the output should be redirected to a file, e.g. command = `python3 app.py > server.log 2>&1 &`.
|
||||
* Interact with running process: If a bash command returns exit code `-1`, this means the process is not yet finished. By setting `is_input` to `true`, the assistant can interact with the running process and send empty `command` to retrieve any additional logs, or send additional text (set `command` to the text) to STDIN of the running process, or send command like `C-c` (Ctrl+C), `C-d` (Ctrl+D), `C-z` (Ctrl+Z) to interrupt the process.
|
||||
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together."""
|
||||
|
||||
|
||||
def create_cmd_run_tool(
|
||||
use_simplified_description: bool = False,
|
||||
) -> ChatCompletionToolParam:
|
||||
description = (
|
||||
_SIMPLIFIED_BASH_DESCRIPTION
|
||||
if use_simplified_description
|
||||
else _DETAILED_BASH_DESCRIPTION
|
||||
)
|
||||
return ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='execute_bash',
|
||||
description=description,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'command': {
|
||||
'type': 'string',
|
||||
'description': 'The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process. Note: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together.',
|
||||
},
|
||||
'is_input': {
|
||||
'type': 'string',
|
||||
'description': 'If True, the command is an input to the running process. If False, the command is a bash command to be executed in the terminal. Default is False.',
|
||||
'enum': ['true', 'false'],
|
||||
},
|
||||
},
|
||||
'required': ['command'],
|
||||
},
|
||||
'required': ['command'],
|
||||
},
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
|
||||
|
||||
_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
|
||||
_DETAILED_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
|
||||
* State is persistent across command calls and discussions with the user
|
||||
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
|
||||
* The `create` command cannot be used if the specified `path` already exists as a file
|
||||
@@ -31,46 +31,73 @@ CRITICAL REQUIREMENTS FOR USING THIS TOOL:
|
||||
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.
|
||||
"""
|
||||
|
||||
StrReplaceEditorTool = ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='str_replace_editor',
|
||||
description=_STR_REPLACE_EDITOR_DESCRIPTION,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'command': {
|
||||
'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.',
|
||||
'enum': ['view', 'create', 'str_replace', 'insert', 'undo_edit'],
|
||||
'type': 'string',
|
||||
},
|
||||
'path': {
|
||||
'description': 'Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.',
|
||||
'type': 'string',
|
||||
},
|
||||
'file_text': {
|
||||
'description': 'Required parameter of `create` command, with the content of the file to be created.',
|
||||
'type': 'string',
|
||||
},
|
||||
'old_str': {
|
||||
'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.',
|
||||
'type': 'string',
|
||||
},
|
||||
'new_str': {
|
||||
'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.',
|
||||
'type': 'string',
|
||||
},
|
||||
'insert_line': {
|
||||
'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.',
|
||||
'type': 'integer',
|
||||
},
|
||||
'view_range': {
|
||||
'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.',
|
||||
'items': {'type': 'integer'},
|
||||
'type': 'array',
|
||||
_SIMPLIFIED_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
|
||||
* State is persistent across command calls and discussions with the user
|
||||
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
|
||||
* The `create` command cannot be used if the specified `path` already exists as a file
|
||||
* If a `command` generates a long output, it will be truncated and marked with `<response clipped>`
|
||||
* The `undo_edit` command will revert the last edit made to the file at `path`
|
||||
Notes for using the `str_replace` command:
|
||||
* The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces!
|
||||
* If the `old_str` parameter is not unique in the file, the replacement will not be performed. Make sure to include enough context in `old_str` to make it unique
|
||||
* The `new_str` parameter should contain the edited lines that should replace the `old_str`
|
||||
"""
|
||||
|
||||
|
||||
def create_str_replace_editor_tool(
|
||||
use_simplified_description: bool = False,
|
||||
) -> ChatCompletionToolParam:
|
||||
description = (
|
||||
_SIMPLIFIED_STR_REPLACE_EDITOR_DESCRIPTION
|
||||
if use_simplified_description
|
||||
else _DETAILED_STR_REPLACE_EDITOR_DESCRIPTION
|
||||
)
|
||||
return ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='str_replace_editor',
|
||||
description=description,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'command': {
|
||||
'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.',
|
||||
'enum': [
|
||||
'view',
|
||||
'create',
|
||||
'str_replace',
|
||||
'insert',
|
||||
'undo_edit',
|
||||
],
|
||||
'type': 'string',
|
||||
},
|
||||
'path': {
|
||||
'description': 'Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.',
|
||||
'type': 'string',
|
||||
},
|
||||
'file_text': {
|
||||
'description': 'Required parameter of `create` command, with the content of the file to be created.',
|
||||
'type': 'string',
|
||||
},
|
||||
'old_str': {
|
||||
'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.',
|
||||
'type': 'string',
|
||||
},
|
||||
'new_str': {
|
||||
'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.',
|
||||
'type': 'string',
|
||||
},
|
||||
'insert_line': {
|
||||
'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.',
|
||||
'type': 'integer',
|
||||
},
|
||||
'view_range': {
|
||||
'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.',
|
||||
'items': {'type': 'integer'},
|
||||
'type': 'array',
|
||||
},
|
||||
},
|
||||
'required': ['command', 'path'],
|
||||
},
|
||||
'required': ['command', 'path'],
|
||||
},
|
||||
),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from openhands.agenthub.delegator_agent.agent import DelegatorAgent
|
||||
from openhands.controller.agent import Agent
|
||||
|
||||
Agent.register('DelegatorAgent', DelegatorAgent)
|
||||
@@ -1,87 +0,0 @@
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.events.action import Action, AgentDelegateAction, AgentFinishAction
|
||||
from openhands.events.observation import AgentDelegateObservation, Observation
|
||||
from openhands.llm.llm import LLM
|
||||
|
||||
|
||||
class DelegatorAgent(Agent):
|
||||
VERSION = '1.0'
|
||||
"""
|
||||
The Delegator Agent is responsible for delegating tasks to other agents based on the current task.
|
||||
"""
|
||||
|
||||
current_delegate: str = ''
|
||||
|
||||
def __init__(self, llm: LLM, config: AgentConfig):
|
||||
"""Initialize the Delegator Agent with an LLM
|
||||
|
||||
Parameters:
|
||||
- llm (LLM): The llm to be used by this agent
|
||||
"""
|
||||
super().__init__(llm, config)
|
||||
|
||||
def step(self, state: State) -> Action:
|
||||
"""Checks to see if current step is completed, returns AgentFinishAction if True.
|
||||
Otherwise, delegates the task to the next agent in the pipeline.
|
||||
|
||||
Parameters:
|
||||
- state (State): The current state given the previous actions and observations
|
||||
|
||||
Returns:
|
||||
- AgentFinishAction: If the last state was 'completed', 'verified', or 'abandoned'
|
||||
- AgentDelegateAction: The next agent to delegate the task to
|
||||
"""
|
||||
if self.current_delegate == '':
|
||||
self.current_delegate = 'study'
|
||||
task, _ = state.get_current_user_intent()
|
||||
return AgentDelegateAction(
|
||||
agent='StudyRepoForTaskAgent', inputs={'task': task}
|
||||
)
|
||||
|
||||
# last observation in history should be from the delegate
|
||||
last_observation = None
|
||||
for event in reversed(state.history):
|
||||
if isinstance(event, Observation):
|
||||
last_observation = event
|
||||
break
|
||||
|
||||
if not isinstance(last_observation, AgentDelegateObservation):
|
||||
raise Exception('Last observation is not an AgentDelegateObservation')
|
||||
|
||||
goal, _ = state.get_current_user_intent()
|
||||
if self.current_delegate == 'study':
|
||||
self.current_delegate = 'coder'
|
||||
return AgentDelegateAction(
|
||||
agent='CoderAgent',
|
||||
inputs={
|
||||
'task': goal,
|
||||
'summary': last_observation.outputs['summary'],
|
||||
},
|
||||
)
|
||||
elif self.current_delegate == 'coder':
|
||||
self.current_delegate = 'verifier'
|
||||
return AgentDelegateAction(
|
||||
agent='VerifierAgent',
|
||||
inputs={
|
||||
'task': goal,
|
||||
},
|
||||
)
|
||||
elif self.current_delegate == 'verifier':
|
||||
if (
|
||||
'completed' in last_observation.outputs
|
||||
and last_observation.outputs['completed']
|
||||
):
|
||||
return AgentFinishAction()
|
||||
else:
|
||||
self.current_delegate = 'coder'
|
||||
return AgentDelegateAction(
|
||||
agent='CoderAgent',
|
||||
inputs={
|
||||
'task': goal,
|
||||
'summary': last_observation.outputs['summary'],
|
||||
},
|
||||
)
|
||||
else:
|
||||
raise Exception('Invalid delegate state')
|
||||
@@ -202,6 +202,7 @@ Note:
|
||||
tabs = ''
|
||||
last_obs = None
|
||||
last_action = None
|
||||
set_of_marks = None # Initialize set_of_marks to None
|
||||
|
||||
if len(state.history) == 1:
|
||||
# for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
|
||||
@@ -217,6 +218,9 @@ Note:
|
||||
# agent has responded, task finished.
|
||||
return AgentFinishAction(outputs={'content': event.content})
|
||||
elif isinstance(event, Observation):
|
||||
# Only process BrowserOutputObservation and skip other observation types
|
||||
if not isinstance(event, BrowserOutputObservation):
|
||||
continue
|
||||
last_obs = event
|
||||
|
||||
if len(prev_actions) >= 1: # ignore noop()
|
||||
|
||||
@@ -15,6 +15,7 @@ class SandboxConfig(BaseModel):
|
||||
timeout: The timeout for the default sandbox action execution.
|
||||
remote_runtime_init_timeout: The timeout for the remote runtime to start.
|
||||
remote_runtime_api_timeout: The timeout for the remote runtime API requests.
|
||||
remote_runtime_enable_retries: Whether to enable retries (on recoverable errors like requests.ConnectionError) for the remote runtime API requests.
|
||||
enable_auto_lint: Whether to enable auto-lint.
|
||||
use_host_network: Whether to use the host network.
|
||||
runtime_binding_address: The binding address for the runtime ports. It specifies which network interface on the host machine Docker should bind the runtime ports to.
|
||||
@@ -53,7 +54,7 @@ class SandboxConfig(BaseModel):
|
||||
timeout: int = Field(default=120)
|
||||
remote_runtime_init_timeout: int = Field(default=180)
|
||||
remote_runtime_api_timeout: int = Field(default=10)
|
||||
remote_runtime_enable_retries: bool = Field(default=False)
|
||||
remote_runtime_enable_retries: bool = Field(default=True)
|
||||
remote_runtime_class: str | None = Field(
|
||||
default=None
|
||||
) # can be "None" (default to gvisor) or "sysbox" (support docker inside runtime + more stable)
|
||||
|
||||
@@ -240,7 +240,7 @@ class SensitiveDataFilter(logging.Filter):
|
||||
if (
|
||||
len(value) > 2
|
||||
and value != 'default'
|
||||
and any(s in key_upper for s in ('SECRET', 'KEY', 'CODE', 'TOKEN'))
|
||||
and any(s in key_upper for s in ('SECRET', '_KEY', '_CODE', '_TOKEN'))
|
||||
):
|
||||
sensitive_values.append(value)
|
||||
|
||||
|
||||
@@ -49,8 +49,8 @@ class ObservationTypeSchema(BaseModel):
|
||||
CONDENSE: str = Field(default='condense')
|
||||
"""Result of a condensation operation."""
|
||||
|
||||
MICROAGENT: str = Field(default='microagent')
|
||||
"""Result of a microagent retrieval operation."""
|
||||
RECALL: str = Field(default='recall')
|
||||
"""Result of a recall operation. This can be the workspace context, a microagent, or other types of information."""
|
||||
|
||||
|
||||
ObservationType = ObservationTypeSchema()
|
||||
|
||||
@@ -3,7 +3,7 @@ from openhands.events.observation.agent import (
|
||||
AgentCondensationObservation,
|
||||
AgentStateChangedObservation,
|
||||
AgentThinkObservation,
|
||||
MicroagentObservation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
@@ -42,6 +42,6 @@ __all__ = [
|
||||
'SuccessObservation',
|
||||
'UserRejectObservation',
|
||||
'AgentCondensationObservation',
|
||||
'MicroagentObservation',
|
||||
'RecallObservation',
|
||||
'RecallType',
|
||||
]
|
||||
|
||||
@@ -60,13 +60,13 @@ class MicroagentKnowledge:
|
||||
|
||||
|
||||
@dataclass
|
||||
class MicroagentObservation(Observation):
|
||||
class RecallObservation(Observation):
|
||||
"""The retrieval of content from a microagent or more microagents."""
|
||||
|
||||
recall_type: RecallType
|
||||
observation: str = ObservationType.MICROAGENT
|
||||
observation: str = ObservationType.RECALL
|
||||
|
||||
# environment
|
||||
# workspace context
|
||||
repo_name: str = ''
|
||||
repo_directory: str = ''
|
||||
repo_instructions: str = ''
|
||||
@@ -95,22 +95,36 @@ class MicroagentObservation(Observation):
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.__str__()
|
||||
return (
|
||||
'Added workspace context'
|
||||
if self.recall_type == RecallType.WORKSPACE_CONTEXT
|
||||
else 'Added microagent knowledge'
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
# Build a string representation of all fields
|
||||
fields = [
|
||||
f'recall_type={self.recall_type}',
|
||||
f'repo_name={self.repo_name}',
|
||||
f'repo_instructions={self.repo_instructions[:20]}...',
|
||||
f'runtime_hosts={self.runtime_hosts}',
|
||||
f'additional_agent_instructions={self.additional_agent_instructions[:20]}...',
|
||||
]
|
||||
|
||||
# Only include microagent_knowledge if it's not empty
|
||||
# Build a string representation
|
||||
fields = []
|
||||
if self.recall_type == RecallType.WORKSPACE_CONTEXT:
|
||||
fields.extend(
|
||||
[
|
||||
f'recall_type={self.recall_type}',
|
||||
f'repo_name={self.repo_name}',
|
||||
f'repo_instructions={self.repo_instructions[:20]}...',
|
||||
f'runtime_hosts={self.runtime_hosts}',
|
||||
f'additional_agent_instructions={self.additional_agent_instructions[:20]}...',
|
||||
]
|
||||
)
|
||||
else:
|
||||
fields.extend(
|
||||
[
|
||||
f'recall_type={self.recall_type}',
|
||||
]
|
||||
)
|
||||
if self.microagent_knowledge:
|
||||
fields.append(
|
||||
f'microagent_knowledge={", ".join([m.name for m in self.microagent_knowledge])}'
|
||||
fields.extend(
|
||||
[
|
||||
f'microagent_knowledge={", ".join([m.name for m in self.microagent_knowledge])}',
|
||||
]
|
||||
)
|
||||
|
||||
return f'**MicroagentObservation**\n{", ".join(fields)}'
|
||||
return f'**RecallObservation**\n{", ".join(fields)}'
|
||||
|
||||
@@ -122,7 +122,7 @@ def event_to_dict(event: 'Event') -> dict:
|
||||
# props is a dict whose values can include a complex object like an instance of a BaseModel subclass
|
||||
# such as CmdOutputMetadata
|
||||
# we serialize it along with the rest
|
||||
# we also handle the Enum conversion for MicroagentObservation
|
||||
# we also handle the Enum conversion for RecallObservation
|
||||
d['extras'] = {
|
||||
k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v))
|
||||
for k, v in props.items()
|
||||
|
||||
@@ -6,7 +6,7 @@ from openhands.events.observation.agent import (
|
||||
AgentStateChangedObservation,
|
||||
AgentThinkObservation,
|
||||
MicroagentKnowledge,
|
||||
MicroagentObservation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
@@ -43,7 +43,7 @@ observations = (
|
||||
UserRejectObservation,
|
||||
AgentCondensationObservation,
|
||||
AgentThinkObservation,
|
||||
MicroagentObservation,
|
||||
RecallObservation,
|
||||
)
|
||||
|
||||
OBSERVATION_TYPE_TO_CLASS = {
|
||||
@@ -114,7 +114,7 @@ def observation_from_dict(observation: dict) -> Observation:
|
||||
else:
|
||||
extras['metadata'] = CmdOutputMetadata()
|
||||
|
||||
if observation_class is MicroagentObservation:
|
||||
if observation_class is RecallObservation:
|
||||
# handle the Enum conversion
|
||||
if 'recall_type' in extras:
|
||||
extras['recall_type'] = RecallType(extras['recall_type'])
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
GitService,
|
||||
@@ -15,7 +16,7 @@ from openhands.integrations.service_types import (
|
||||
User,
|
||||
)
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class GitHubService(GitService):
|
||||
BASE_URL = 'https://api.github.com'
|
||||
@@ -25,6 +26,7 @@ class GitHubService(GitService):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
|
||||
@@ -31,7 +31,7 @@ from openhands.events.observation import (
|
||||
)
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
MicroagentObservation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
@@ -52,7 +52,6 @@ class ConversationMemory:
|
||||
initial_messages: list[Message],
|
||||
max_message_chars: int | None = None,
|
||||
vision_is_active: bool = False,
|
||||
enable_som_visual_browsing: bool = False,
|
||||
) -> list[Message]:
|
||||
"""Process state history into a list of messages for the LLM.
|
||||
|
||||
@@ -64,11 +63,13 @@ class ConversationMemory:
|
||||
max_message_chars: The maximum number of characters in the content of an event included
|
||||
in the prompt to the LLM. Larger observations are truncated.
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
|
||||
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model.
|
||||
"""
|
||||
|
||||
events = condensed_history
|
||||
|
||||
# log visual browsing status
|
||||
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
|
||||
|
||||
# Process special events first (system prompts, etc.)
|
||||
messages = initial_messages
|
||||
|
||||
@@ -385,19 +386,19 @@ class ConversationMemory:
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif (
|
||||
isinstance(obs, MicroagentObservation)
|
||||
isinstance(obs, RecallObservation)
|
||||
and self.agent_config.enable_prompt_extensions
|
||||
):
|
||||
if obs.recall_type == RecallType.WORKSPACE_CONTEXT:
|
||||
# everything is optional, check if they are present
|
||||
repo_info = (
|
||||
RepositoryInfo(
|
||||
if obs.repo_name or obs.repo_directory:
|
||||
repo_info = RepositoryInfo(
|
||||
repo_name=obs.repo_name or '',
|
||||
repo_directory=obs.repo_directory or '',
|
||||
)
|
||||
if obs.repo_name or obs.repo_directory
|
||||
else None
|
||||
)
|
||||
else:
|
||||
repo_info = None
|
||||
|
||||
if obs.runtime_hosts or obs.additional_agent_instructions:
|
||||
runtime_info = RuntimeInfo(
|
||||
available_hosts=obs.runtime_hosts,
|
||||
@@ -420,22 +421,49 @@ class ConversationMemory:
|
||||
)
|
||||
has_repo_instructions = bool(repo_instructions.strip())
|
||||
|
||||
# Build additional info if we have something to render
|
||||
# Filter and process microagent knowledge
|
||||
filtered_agents = []
|
||||
if obs.microagent_knowledge:
|
||||
# Exclude disabled microagents
|
||||
filtered_agents = [
|
||||
agent
|
||||
for agent in obs.microagent_knowledge
|
||||
if agent.name not in self.agent_config.disabled_microagents
|
||||
]
|
||||
|
||||
has_microagent_knowledge = bool(filtered_agents)
|
||||
|
||||
# Generate appropriate content based on what is present
|
||||
message_content = []
|
||||
|
||||
# Build the workspace context information
|
||||
if has_repo_info or has_runtime_info or has_repo_instructions:
|
||||
# ok, now we can build the additional info
|
||||
formatted_text = self.prompt_manager.build_additional_info(
|
||||
repository_info=repo_info,
|
||||
runtime_info=runtime_info,
|
||||
repo_instructions=repo_instructions,
|
||||
formatted_workspace_text = (
|
||||
self.prompt_manager.build_workspace_context(
|
||||
repository_info=repo_info,
|
||||
runtime_info=runtime_info,
|
||||
repo_instructions=repo_instructions,
|
||||
)
|
||||
)
|
||||
message = Message(
|
||||
role='user', content=[TextContent(text=formatted_text)]
|
||||
message_content.append(TextContent(text=formatted_workspace_text))
|
||||
|
||||
# Add microagent knowledge if present
|
||||
if has_microagent_knowledge:
|
||||
formatted_microagent_text = (
|
||||
self.prompt_manager.build_microagent_info(
|
||||
triggered_agents=filtered_agents,
|
||||
)
|
||||
)
|
||||
message_content.append(TextContent(text=formatted_microagent_text))
|
||||
|
||||
# Return the combined message if we have any content
|
||||
if message_content:
|
||||
message = Message(role='user', content=message_content)
|
||||
else:
|
||||
return []
|
||||
elif obs.recall_type == RecallType.KNOWLEDGE:
|
||||
# Use prompt manager to build the microagent info
|
||||
# First, filter out agents that appear in earlier MicroagentObservations
|
||||
# First, filter out agents that appear in earlier RecallObservations
|
||||
filtered_agents = self._filter_agents_in_microagent_obs(
|
||||
obs, current_index, events or []
|
||||
)
|
||||
@@ -464,7 +492,7 @@ class ConversationMemory:
|
||||
# Return empty list if no microagents to include or all were disabled
|
||||
return []
|
||||
elif (
|
||||
isinstance(obs, MicroagentObservation)
|
||||
isinstance(obs, RecallObservation)
|
||||
and not self.agent_config.enable_prompt_extensions
|
||||
):
|
||||
# If prompt extensions are disabled, we don't add any additional info
|
||||
@@ -504,12 +532,12 @@ class ConversationMemory:
|
||||
break
|
||||
|
||||
def _filter_agents_in_microagent_obs(
|
||||
self, obs: MicroagentObservation, current_index: int, events: list[Event]
|
||||
self, obs: RecallObservation, current_index: int, events: list[Event]
|
||||
) -> list[MicroagentKnowledge]:
|
||||
"""Filter out agents that appear in earlier MicroagentObservations.
|
||||
"""Filter out agents that appear in earlier RecallObservations.
|
||||
|
||||
Args:
|
||||
obs: The current MicroagentObservation to filter
|
||||
obs: The current RecallObservation to filter
|
||||
current_index: The index of the current event in the events list
|
||||
events: The list of all events
|
||||
|
||||
@@ -532,7 +560,7 @@ class ConversationMemory:
|
||||
def _has_agent_in_earlier_events(
|
||||
self, agent_name: str, current_index: int, events: list[Event]
|
||||
) -> bool:
|
||||
"""Check if an agent appears in any earlier MicroagentObservation in the event list.
|
||||
"""Check if an agent appears in any earlier RecallObservation in the event list.
|
||||
|
||||
Args:
|
||||
agent_name: The name of the agent to look for
|
||||
@@ -540,13 +568,11 @@ class ConversationMemory:
|
||||
events: The list of all events
|
||||
|
||||
Returns:
|
||||
bool: True if the agent appears in an earlier MicroagentObservation, False otherwise
|
||||
bool: True if the agent appears in an earlier RecallObservation, False otherwise
|
||||
"""
|
||||
for event in events[:current_index]:
|
||||
if (
|
||||
isinstance(event, MicroagentObservation)
|
||||
and event.recall_type == RecallType.KNOWLEDGE
|
||||
):
|
||||
# Note that this check includes the WORKSPACE_CONTEXT
|
||||
if isinstance(event, RecallObservation):
|
||||
if any(
|
||||
agent.name == agent_name for agent in event.microagent_knowledge
|
||||
):
|
||||
|
||||
+77
-55
@@ -9,7 +9,7 @@ from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import Event, EventSource, RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
MicroagentObservation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
@@ -31,7 +31,7 @@ GLOBAL_MICROAGENTS_DIR = os.path.join(
|
||||
class Memory:
|
||||
"""
|
||||
Memory is a component that listens to the EventStream for information retrieval actions
|
||||
(a RecallAction) and publishes observations with the content (such as MicroagentObservation).
|
||||
(a RecallAction) and publishes observations with the content (such as RecallObservation).
|
||||
"""
|
||||
|
||||
sid: str
|
||||
@@ -75,48 +75,59 @@ class Memory:
|
||||
async def _on_event(self, event: Event):
|
||||
"""Handle an event from the event stream asynchronously."""
|
||||
try:
|
||||
observation: MicroagentObservation | NullObservation | None = None
|
||||
|
||||
if isinstance(event, RecallAction):
|
||||
# if this is a workspace context recall (on first user message)
|
||||
# create and add a MicroagentObservation
|
||||
# with info about repo and runtime.
|
||||
# create and add a RecallObservation
|
||||
# with info about repo, runtime, instructions, etc. including microagent knowledge if any
|
||||
if (
|
||||
event.source == EventSource.USER
|
||||
and event.recall_type == RecallType.WORKSPACE_CONTEXT
|
||||
):
|
||||
observation = self._on_first_microagent_action(event)
|
||||
logger.debug('Workspace context recall')
|
||||
workspace_obs: RecallObservation | NullObservation | None = None
|
||||
|
||||
# continue with the next handler, to include knowledge microagents if suitable for this query
|
||||
assert observation is None or isinstance(
|
||||
observation, MicroagentObservation
|
||||
), f'Expected a MicroagentObservation, but got {type(observation)}'
|
||||
observation = self._on_microagent_action(
|
||||
event, prev_observation=observation
|
||||
)
|
||||
workspace_obs = self._on_workspace_context_recall(event)
|
||||
if workspace_obs is None:
|
||||
workspace_obs = NullObservation(content='')
|
||||
|
||||
if observation is None:
|
||||
observation = NullObservation(content='')
|
||||
# important: this will release the execution flow from waiting for the retrieval to complete
|
||||
workspace_obs._cause = event.id # type: ignore[union-attr]
|
||||
|
||||
# important: this will release the execution flow from waiting for the retrieval to complete
|
||||
observation._cause = event.id # type: ignore[union-attr]
|
||||
self.event_stream.add_event(workspace_obs, EventSource.ENVIRONMENT)
|
||||
return
|
||||
|
||||
self.event_stream.add_event(observation, EventSource.ENVIRONMENT)
|
||||
# Handle knowledge recall (triggered microagents)
|
||||
elif (
|
||||
event.source == EventSource.USER
|
||||
and event.recall_type == RecallType.KNOWLEDGE
|
||||
):
|
||||
logger.debug('Microagent knowledge recall')
|
||||
microagent_obs: RecallObservation | NullObservation | None = None
|
||||
microagent_obs = self._on_microagent_recall(event)
|
||||
if microagent_obs is None:
|
||||
microagent_obs = NullObservation(content='')
|
||||
|
||||
# important: this will release the execution flow from waiting for the retrieval to complete
|
||||
microagent_obs._cause = event.id # type: ignore[union-attr]
|
||||
|
||||
self.event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
|
||||
return
|
||||
except Exception as e:
|
||||
error_str = f'Error: {str(e.__class__.__name__)}'
|
||||
logger.error(error_str)
|
||||
self.send_error_message('STATUS$ERROR_MEMORY', error_str)
|
||||
return
|
||||
|
||||
def _on_first_microagent_action(
|
||||
def _on_workspace_context_recall(
|
||||
self, event: RecallAction
|
||||
) -> MicroagentObservation | None:
|
||||
"""Add repository and runtime information to the stream as a MicroagentObservation."""
|
||||
) -> RecallObservation | None:
|
||||
"""Add repository and runtime information to the stream as a RecallObservation."""
|
||||
|
||||
# Create ENVIRONMENT info:
|
||||
# Create WORKSPACE_CONTEXT info:
|
||||
# - repository_info
|
||||
# - runtime_info
|
||||
# - repository_instructions
|
||||
# - microagent_knowledge
|
||||
|
||||
# Collect raw repository instructions
|
||||
repo_instructions = ''
|
||||
@@ -130,9 +141,17 @@ class Memory:
|
||||
repo_instructions += '\n\n'
|
||||
repo_instructions += microagent.content
|
||||
|
||||
# Find any matched microagents based on the query
|
||||
microagent_knowledge = self._find_microagent_knowledge(event.query)
|
||||
|
||||
# Create observation if we have anything
|
||||
if self.repository_info or self.runtime_info or repo_instructions:
|
||||
obs = MicroagentObservation(
|
||||
if (
|
||||
self.repository_info
|
||||
or self.runtime_info
|
||||
or repo_instructions
|
||||
or microagent_knowledge
|
||||
):
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name=self.repository_info.repo_name
|
||||
if self.repository_info and self.repository_info.repo_name is not None
|
||||
@@ -149,29 +168,47 @@ class Memory:
|
||||
if self.runtime_info
|
||||
and self.runtime_info.additional_agent_instructions is not None
|
||||
else '',
|
||||
microagent_knowledge=[],
|
||||
content='Retrieved environment info',
|
||||
microagent_knowledge=microagent_knowledge,
|
||||
content='Added workspace context',
|
||||
)
|
||||
return obs
|
||||
return None
|
||||
|
||||
def _on_microagent_action(
|
||||
def _on_microagent_recall(
|
||||
self,
|
||||
event: RecallAction,
|
||||
prev_observation: MicroagentObservation | None = None,
|
||||
) -> MicroagentObservation | None:
|
||||
"""When a microagent action triggers microagents, create a MicroagentObservation with structured data."""
|
||||
# If there's no query, do nothing
|
||||
query = event.query.strip()
|
||||
if not query:
|
||||
return prev_observation
|
||||
) -> RecallObservation | None:
|
||||
"""When a microagent action triggers microagents, create a RecallObservation with structured data."""
|
||||
|
||||
assert prev_observation is None or isinstance(
|
||||
prev_observation, MicroagentObservation
|
||||
), f'Expected a MicroagentObservation, but got {type(prev_observation)}'
|
||||
# Find any matched microagents based on the query
|
||||
microagent_knowledge = self._find_microagent_knowledge(event.query)
|
||||
|
||||
# Process text to find suitable microagents and create a MicroagentObservation.
|
||||
# Create observation if we have anything
|
||||
if microagent_knowledge:
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=microagent_knowledge,
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
return obs
|
||||
return None
|
||||
|
||||
def _find_microagent_knowledge(self, query: str) -> list[MicroagentKnowledge]:
|
||||
"""Find microagent knowledge based on a query.
|
||||
|
||||
Args:
|
||||
query: The query to search for microagent triggers
|
||||
|
||||
Returns:
|
||||
A list of MicroagentKnowledge objects for matched triggers
|
||||
"""
|
||||
recalled_content: list[MicroagentKnowledge] = []
|
||||
|
||||
# skip empty queries
|
||||
if not query:
|
||||
return recalled_content
|
||||
|
||||
# Search for microagent triggers in the query
|
||||
for name, microagent in self.knowledge_microagents.items():
|
||||
trigger = microagent.match_trigger(query)
|
||||
if trigger:
|
||||
@@ -183,22 +220,7 @@ class Memory:
|
||||
content=microagent.content,
|
||||
)
|
||||
)
|
||||
|
||||
if recalled_content:
|
||||
if prev_observation is not None:
|
||||
# it may be on the first user message that already found some repo info etc
|
||||
prev_observation.microagent_knowledge.extend(recalled_content)
|
||||
else:
|
||||
# if it's not the first user message, we may not have found any information this step
|
||||
obs = MicroagentObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=recalled_content,
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
|
||||
return obs
|
||||
|
||||
return prev_observation
|
||||
return recalled_content
|
||||
|
||||
def load_user_workspace_microagents(
|
||||
self, user_microagents: list[BaseMicroAgent]
|
||||
|
||||
@@ -97,7 +97,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = False,
|
||||
github_user_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
self.sid = sid
|
||||
self.event_stream = event_stream
|
||||
@@ -130,7 +130,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
self, enable_llm_editor=config.get_agent_config().codeact_enable_llm_editor
|
||||
)
|
||||
|
||||
self.github_user_id = github_user_id
|
||||
self.user_id = user_id
|
||||
|
||||
def setup_initial_env(self) -> None:
|
||||
if self.attach_to_existing:
|
||||
@@ -220,9 +220,9 @@ class Runtime(FileEditRuntimeMixin):
|
||||
assert event.timeout is not None
|
||||
try:
|
||||
if isinstance(event, CmdRunAction):
|
||||
if self.github_user_id and '$GITHUB_TOKEN' in event.command:
|
||||
if self.user_id and '$GITHUB_TOKEN' in event.command:
|
||||
gh_client = GithubServiceImpl(
|
||||
user_id=self.github_user_id, external_token_manager=True
|
||||
external_auth_id=self.user_id, external_token_manager=True
|
||||
)
|
||||
token = await gh_client.get_latest_token()
|
||||
if token:
|
||||
|
||||
@@ -59,7 +59,7 @@ class ActionExecutionClient(Runtime):
|
||||
status_callback: Any | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
github_user_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
self.session = HttpSession()
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
@@ -75,7 +75,7 @@ class ActionExecutionClient(Runtime):
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
github_user_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable
|
||||
from urllib.parse import urlparse
|
||||
@@ -45,7 +46,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
status_callback: Callable | None = None,
|
||||
attach_to_existing: bool = False,
|
||||
headless_mode: bool = True,
|
||||
github_user_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
config,
|
||||
@@ -56,7 +57,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
status_callback,
|
||||
attach_to_existing,
|
||||
headless_mode,
|
||||
github_user_id,
|
||||
user_id,
|
||||
)
|
||||
if self.config.sandbox.api_key is None:
|
||||
raise ValueError(
|
||||
@@ -425,10 +426,11 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
return self._send_action_server_request_impl(method, url, **kwargs)
|
||||
|
||||
retry_decorator = tenacity.retry(
|
||||
retry=tenacity.retry_if_exception_type(ConnectionError),
|
||||
retry=tenacity.retry_if_exception_type(requests.ConnectionError),
|
||||
stop=tenacity.stop_after_attempt(3)
|
||||
| stop_if_should_exit()
|
||||
| self._stop_if_closed,
|
||||
before_sleep=tenacity.before_sleep_log(logger, logging.WARNING),
|
||||
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
|
||||
)
|
||||
return retry_decorator(self._send_action_server_request_impl)(
|
||||
|
||||
@@ -46,7 +46,12 @@ class ConversationManager(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def join_conversation(
|
||||
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
|
||||
self,
|
||||
sid: str,
|
||||
connection_id: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
github_user_id: str | None,
|
||||
) -> EventStream | None:
|
||||
"""Join a conversation and return its event stream."""
|
||||
|
||||
@@ -74,6 +79,7 @@ class ConversationManager(ABC):
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
initial_user_msg: MessageAction | None = None,
|
||||
github_user_id: str | None = None,
|
||||
) -> EventStream:
|
||||
"""Start an event loop if one is not already running"""
|
||||
|
||||
|
||||
@@ -106,7 +106,12 @@ class StandaloneConversationManager(ConversationManager):
|
||||
return c
|
||||
|
||||
async def join_conversation(
|
||||
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
|
||||
self,
|
||||
sid: str,
|
||||
connection_id: str,
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
github_user_id: str | None,
|
||||
):
|
||||
logger.info(
|
||||
f'join_conversation:{sid}:{connection_id}',
|
||||
@@ -116,7 +121,9 @@ class StandaloneConversationManager(ConversationManager):
|
||||
self._local_connection_id_to_session_id[connection_id] = sid
|
||||
event_stream = await self._get_event_stream(sid)
|
||||
if not event_stream:
|
||||
return await self.maybe_start_agent_loop(sid, settings, user_id)
|
||||
return await self.maybe_start_agent_loop(
|
||||
sid, settings, user_id, github_user_id=github_user_id
|
||||
)
|
||||
for event in event_stream.get_events(reverse=True):
|
||||
if isinstance(event, AgentStateChangedObservation):
|
||||
if event.agent_state in (
|
||||
@@ -187,14 +194,18 @@ class StandaloneConversationManager(ConversationManager):
|
||||
logger.error('error_cleaning_stale')
|
||||
await asyncio.sleep(_CLEANUP_INTERVAL)
|
||||
|
||||
async def _get_conversation_store(self, user_id: str | None) -> ConversationStore:
|
||||
async def _get_conversation_store(
|
||||
self, user_id: str | None, github_user_id: str | None
|
||||
) -> ConversationStore:
|
||||
conversation_store_class = self._conversation_store_class
|
||||
if not conversation_store_class:
|
||||
self._conversation_store_class = conversation_store_class = get_impl(
|
||||
ConversationStore, # type: ignore
|
||||
self.server_config.conversation_store_class,
|
||||
)
|
||||
store = await conversation_store_class.get_instance(self.config, user_id)
|
||||
store = await conversation_store_class.get_instance(
|
||||
self.config, user_id, github_user_id
|
||||
)
|
||||
return store
|
||||
|
||||
async def get_running_agent_loops(
|
||||
@@ -243,6 +254,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
settings: Settings,
|
||||
user_id: str | None,
|
||||
initial_user_msg: MessageAction | None = None,
|
||||
github_user_id: str | None = None,
|
||||
) -> EventStream:
|
||||
logger.info(f'maybe_start_agent_loop:{sid}', extra={'session_id': sid})
|
||||
session: Session | None = None
|
||||
@@ -256,7 +268,9 @@ class StandaloneConversationManager(ConversationManager):
|
||||
extra={'session_id': sid, 'user_id': user_id},
|
||||
)
|
||||
# Get the conversations sorted (oldest first)
|
||||
conversation_store = await self._get_conversation_store(user_id)
|
||||
conversation_store = await self._get_conversation_store(
|
||||
user_id, github_user_id
|
||||
)
|
||||
conversations = await conversation_store.get_all_metadata(response_ids)
|
||||
conversations.sort(key=_last_updated_at_key, reverse=True)
|
||||
|
||||
@@ -277,7 +291,9 @@ class StandaloneConversationManager(ConversationManager):
|
||||
try:
|
||||
session.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER,
|
||||
self._create_conversation_update_callback(user_id, sid),
|
||||
self._create_conversation_update_callback(
|
||||
user_id, github_user_id, sid
|
||||
),
|
||||
UPDATED_AT_CALLBACK_ID,
|
||||
)
|
||||
except ValueError:
|
||||
@@ -374,22 +390,23 @@ class StandaloneConversationManager(ConversationManager):
|
||||
)
|
||||
|
||||
def _create_conversation_update_callback(
|
||||
self, user_id: str | None, conversation_id: str
|
||||
self, user_id: str | None, github_user_id: str | None, conversation_id: str
|
||||
) -> Callable:
|
||||
def callback(*args, **kwargs):
|
||||
call_async_from_sync(
|
||||
self._update_timestamp_for_conversation,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
github_user_id,
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
async def _update_timestamp_for_conversation(
|
||||
self, user_id: str, conversation_id: str
|
||||
self, user_id: str, github_user_id: str, conversation_id: str
|
||||
):
|
||||
conversation_store = await self._get_conversation_store(user_id)
|
||||
conversation_store = await self._get_conversation_store(user_id, github_user_id)
|
||||
conversation = await conversation_store.get_metadata(conversation_id)
|
||||
conversation.last_updated_at = datetime.now(timezone.utc)
|
||||
await conversation_store.save_metadata(conversation)
|
||||
|
||||
@@ -6,10 +6,14 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import (
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.observation import (
|
||||
NullObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.events.observation.agent import (
|
||||
AgentStateChangedObservation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.events.stream import AsyncEventStreamWrapper
|
||||
from openhands.server.shared import (
|
||||
@@ -35,7 +39,9 @@ async def connect(connection_id: str, environ):
|
||||
|
||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||
conversation_validator = ConversationValidatorImpl()
|
||||
user_id = await conversation_validator.validate(conversation_id, cookies_str)
|
||||
user_id, github_user_id = await conversation_validator.validate(
|
||||
conversation_id, cookies_str
|
||||
)
|
||||
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
@@ -46,7 +52,7 @@ async def connect(connection_id: str, environ):
|
||||
)
|
||||
|
||||
event_stream = await conversation_manager.join_conversation(
|
||||
conversation_id, connection_id, settings, user_id
|
||||
conversation_id, connection_id, settings, user_id, github_user_id
|
||||
)
|
||||
|
||||
agent_state_changed = None
|
||||
@@ -54,10 +60,7 @@ async def connect(connection_id: str, environ):
|
||||
async for event in async_stream:
|
||||
if isinstance(
|
||||
event,
|
||||
(
|
||||
NullAction,
|
||||
NullObservation,
|
||||
),
|
||||
(NullAction, NullObservation, RecallAction, RecallObservation),
|
||||
):
|
||||
continue
|
||||
elif isinstance(event, AgentStateChangedObservation):
|
||||
|
||||
@@ -10,7 +10,12 @@ from openhands.events.action.message import MessageAction
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.runtime import get_runtime_cls
|
||||
from openhands.server.auth import get_provider_tokens, get_access_token, get_github_user_id
|
||||
from openhands.server.auth import (
|
||||
get_access_token,
|
||||
get_github_user_id,
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
)
|
||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
@@ -73,12 +78,12 @@ async def _create_new_conversation(
|
||||
logger.warn('Settings not present, not starting conversation')
|
||||
raise MissingSettingsError('Settings not found')
|
||||
|
||||
session_init_args['github_token'] = token or SecretStr('')
|
||||
session_init_args['provider_token'] = token
|
||||
session_init_args['selected_repository'] = selected_repository
|
||||
session_init_args['selected_branch'] = selected_branch
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id, None)
|
||||
logger.info('Conversation store loaded')
|
||||
|
||||
conversation_id = uuid.uuid4().hex
|
||||
@@ -100,7 +105,8 @@ async def _create_new_conversation(
|
||||
ConversationMetadata(
|
||||
conversation_id=conversation_id,
|
||||
title=conversation_title,
|
||||
github_user_id=user_id,
|
||||
user_id=user_id,
|
||||
github_user_id=None,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
)
|
||||
@@ -122,7 +128,10 @@ async def _create_new_conversation(
|
||||
image_urls=image_urls or [],
|
||||
)
|
||||
await conversation_manager.maybe_start_agent_loop(
|
||||
conversation_id, conversation_init_data, user_id, initial_message_action
|
||||
conversation_id,
|
||||
conversation_init_data,
|
||||
user_id,
|
||||
initial_user_msg=initial_message_action,
|
||||
)
|
||||
logger.info(f'Finished initializing conversation {conversation_id}')
|
||||
|
||||
@@ -158,7 +167,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
||||
try:
|
||||
# Create conversation with initial message
|
||||
conversation_id = await _create_new_conversation(
|
||||
user_id,
|
||||
get_user_id(request),
|
||||
github_token,
|
||||
selected_repository,
|
||||
selected_branch,
|
||||
@@ -197,7 +206,7 @@ async def search_conversations(
|
||||
limit: int = 20,
|
||||
) -> ConversationInfoResultSet:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_github_user_id(request)
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
)
|
||||
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
|
||||
|
||||
@@ -216,7 +225,7 @@ async def search_conversations(
|
||||
conversation.conversation_id for conversation in filtered_results
|
||||
)
|
||||
running_conversations = await conversation_manager.get_running_agent_loops(
|
||||
get_github_user_id(request), set(conversation_ids)
|
||||
get_user_id(request), set(conversation_ids)
|
||||
)
|
||||
result = ConversationInfoResultSet(
|
||||
results=await wait_all(
|
||||
@@ -236,7 +245,7 @@ async def get_conversation(
|
||||
conversation_id: str, request: Request
|
||||
) -> ConversationInfo | None:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_github_user_id(request)
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
)
|
||||
try:
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
@@ -252,7 +261,7 @@ async def update_conversation(
|
||||
request: Request, conversation_id: str, title: str = Body(embed=True)
|
||||
) -> bool:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_github_user_id(request)
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(conversation_id)
|
||||
if not metadata:
|
||||
@@ -268,7 +277,7 @@ async def delete_conversation(
|
||||
request: Request,
|
||||
) -> bool:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, get_github_user_id(request)
|
||||
config, get_user_id(request), get_github_user_id(request)
|
||||
)
|
||||
try:
|
||||
await conversation_store.get_metadata(conversation_id)
|
||||
|
||||
@@ -90,30 +90,38 @@ async def store_settings(
|
||||
existing_settings.user_consents_to_analytics
|
||||
)
|
||||
|
||||
if existing_settings.secrets_store:
|
||||
existing_providers = [
|
||||
provider.value
|
||||
for provider in existing_settings.secrets_store.provider_tokens
|
||||
]
|
||||
|
||||
# Merge incoming settings store with the existing one
|
||||
for provider, token_value in settings.provider_tokens.items():
|
||||
if provider in existing_providers and not token_value:
|
||||
provider_type = ProviderType(provider)
|
||||
existing_token = (
|
||||
existing_settings.secrets_store.provider_tokens.get(
|
||||
provider_type
|
||||
)
|
||||
)
|
||||
if existing_token and existing_token.token:
|
||||
settings.provider_tokens[provider] = (
|
||||
existing_token.token.get_secret_value()
|
||||
)
|
||||
|
||||
# Merge provider tokens with existing ones
|
||||
if settings.unset_github_token: # Only merge if not unsetting tokens
|
||||
if settings.unset_github_token:
|
||||
settings.secrets_store.provider_tokens = {}
|
||||
settings.provider_tokens = {}
|
||||
else: # Only merge if not unsetting tokens
|
||||
if settings.provider_tokens:
|
||||
if existing_settings.secrets_store:
|
||||
existing_providers = [
|
||||
provider.value
|
||||
for provider in existing_settings.secrets_store.provider_tokens
|
||||
]
|
||||
|
||||
# Merge incoming settings store with the existing one
|
||||
for provider, token_value in settings.provider_tokens.items():
|
||||
if provider in existing_providers and not token_value:
|
||||
provider_type = ProviderType(provider)
|
||||
existing_token = (
|
||||
existing_settings.secrets_store.provider_tokens.get(
|
||||
provider_type
|
||||
)
|
||||
)
|
||||
if existing_token and existing_token.token:
|
||||
settings.provider_tokens[provider] = (
|
||||
existing_token.token.get_secret_value()
|
||||
)
|
||||
else: # nothing passed in means keep current settings
|
||||
provider_tokens = existing_settings.secrets_store.provider_tokens
|
||||
settings.provider_tokens = {
|
||||
provider.value: data.token.get_secret_value()
|
||||
if data.token
|
||||
else None
|
||||
for provider, data in provider_tokens.items()
|
||||
}
|
||||
|
||||
# Update sandbox config with new settings
|
||||
if settings.remote_runtime_resource_factor is not None:
|
||||
|
||||
@@ -53,7 +53,7 @@ class AgentSession:
|
||||
sid: str,
|
||||
file_store: FileStore,
|
||||
status_callback: Callable | None = None,
|
||||
github_user_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
):
|
||||
"""Initializes a new instance of the Session class
|
||||
|
||||
@@ -66,9 +66,9 @@ class AgentSession:
|
||||
self.event_stream = EventStream(sid, file_store)
|
||||
self.file_store = file_store
|
||||
self._status_callback = status_callback
|
||||
self.github_user_id = github_user_id
|
||||
self.user_id = user_id
|
||||
self.logger = OpenHandsLoggerAdapter(
|
||||
extra={'session_id': sid, 'user_id': github_user_id}
|
||||
extra={'session_id': sid, 'user_id': user_id}
|
||||
)
|
||||
|
||||
async def start(
|
||||
@@ -241,7 +241,7 @@ class AgentSession:
|
||||
|
||||
kwargs = {}
|
||||
if runtime_cls == RemoteRuntime:
|
||||
kwargs['github_user_id'] = self.github_user_id
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
self.runtime = runtime_cls(
|
||||
config=config,
|
||||
|
||||
@@ -8,6 +8,6 @@ class ConversationInitData(Settings):
|
||||
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
|
||||
"""
|
||||
|
||||
github_token: SecretStr | None = Field(default=None)
|
||||
provider_token: SecretStr | None = Field(default=None)
|
||||
selected_repository: str | None = Field(default=None)
|
||||
selected_branch: str | None = Field(default=None)
|
||||
|
||||
@@ -61,7 +61,7 @@ class Session:
|
||||
sid,
|
||||
file_store,
|
||||
status_callback=self.queue_status_message,
|
||||
github_user_id=user_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
self.agent_session.event_stream.subscribe(
|
||||
EventStreamSubscriber.SERVER, self.on_event, self.sid
|
||||
@@ -123,11 +123,11 @@ class Session:
|
||||
|
||||
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
||||
|
||||
github_token = None
|
||||
provider_token = None
|
||||
selected_repository = None
|
||||
selected_branch = None
|
||||
if isinstance(settings, ConversationInitData):
|
||||
github_token = settings.github_token
|
||||
provider_token = settings.provider_token
|
||||
selected_repository = settings.selected_repository
|
||||
selected_branch = settings.selected_branch
|
||||
|
||||
@@ -140,7 +140,7 @@ class Session:
|
||||
max_budget_per_task=self.config.max_budget_per_task,
|
||||
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
|
||||
agent_configs=self.config.get_agent_configs(),
|
||||
github_token=github_token,
|
||||
github_token=provider_token,
|
||||
selected_repository=selected_repository,
|
||||
selected_branch=selected_branch,
|
||||
initial_message=initial_message,
|
||||
|
||||
@@ -43,7 +43,7 @@ class Settings(BaseModel):
|
||||
if context and context.get('expose_secrets', False):
|
||||
return llm_api_key.get_secret_value()
|
||||
|
||||
return pydantic_encoder(llm_api_key)
|
||||
return pydantic_encoder(llm_api_key) if llm_api_key else None
|
||||
|
||||
@staticmethod
|
||||
def _convert_token_value(
|
||||
|
||||
@@ -12,25 +12,36 @@ from openhands.utils.async_utils import wait_all
|
||||
|
||||
|
||||
class ConversationStore(ABC):
|
||||
"""
|
||||
Storage for conversation metadata. May or may not support multiple users depending on the environment
|
||||
"""
|
||||
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
|
||||
|
||||
@abstractmethod
|
||||
async def save_metadata(self, metadata: ConversationMetadata) -> None:
|
||||
"""Store conversation metadata"""
|
||||
"""Store conversation metadata."""
|
||||
|
||||
@abstractmethod
|
||||
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
|
||||
"""Load conversation metadata"""
|
||||
"""Load conversation metadata."""
|
||||
|
||||
async def validate_metadata(
|
||||
self, conversation_id: str, user_id: str, github_user_id: str
|
||||
) -> bool:
|
||||
"""Validate that conversation belongs to the current user."""
|
||||
# TODO: remove github_user_id after transition to Keycloak is complete.
|
||||
metadata = await self.get_metadata(conversation_id)
|
||||
if (not metadata.user_id and not metadata.github_user_id) or (
|
||||
metadata.user_id != user_id and metadata.github_user_id != github_user_id
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
async def delete_metadata(self, conversation_id: str) -> None:
|
||||
"""delete conversation metadata"""
|
||||
"""Delete conversation metadata."""
|
||||
|
||||
@abstractmethod
|
||||
async def exists(self, conversation_id: str) -> bool:
|
||||
"""Check if conversation exists"""
|
||||
"""Check if conversation exists."""
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
@@ -49,6 +60,6 @@ class ConversationStore(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
|
||||
) -> ConversationStore:
|
||||
"""Get a store for the user represented by the token given"""
|
||||
|
||||
@@ -7,7 +7,7 @@ class ConversationValidator:
|
||||
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
|
||||
|
||||
async def validate(self, conversation_id: str, cookies_str: str):
|
||||
return None
|
||||
return None, None
|
||||
|
||||
|
||||
conversation_validator_cls = os.environ.get(
|
||||
|
||||
@@ -101,7 +101,7 @@ class FileConversationStore(ConversationStore):
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
cls, config: AppConfig, user_id: str | None
|
||||
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
|
||||
) -> FileConversationStore:
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
return FileConversationStore(file_store)
|
||||
|
||||
@@ -5,6 +5,7 @@ from datetime import datetime, timezone
|
||||
@dataclass
|
||||
class ConversationMetadata:
|
||||
conversation_id: str
|
||||
user_id: str | None
|
||||
github_user_id: str | None
|
||||
selected_repository: str | None
|
||||
selected_branch: str | None = None
|
||||
|
||||
@@ -76,7 +76,7 @@ class PromptManager:
|
||||
if example_message:
|
||||
message.content.insert(0, TextContent(text=example_message))
|
||||
|
||||
def build_additional_info(
|
||||
def build_workspace_context(
|
||||
self,
|
||||
repository_info: RepositoryInfo | None,
|
||||
runtime_info: RuntimeInfo | None,
|
||||
|
||||
Generated
+37
-37
@@ -496,18 +496,18 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.37.11"
|
||||
version = "1.37.12"
|
||||
description = "The AWS SDK for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "boto3-1.37.11-py3-none-any.whl", hash = "sha256:da6c22fc8a7e9bca5d7fc465a877ac3d45b6b086d776bd1a6c55bdde60523741"},
|
||||
{file = "boto3-1.37.11.tar.gz", hash = "sha256:8eec08363ef5db05c2fbf58e89f0c0de6276cda2fdce01e76b3b5f423cd5c0f4"},
|
||||
{file = "boto3-1.37.12-py3-none-any.whl", hash = "sha256:516feaa0d2afaeda1515216fd09291368a1215754bbccb0f28414c0a91a830a2"},
|
||||
{file = "boto3-1.37.12.tar.gz", hash = "sha256:9412d404f103ad6d14f033eb29cd5e0cdca2b9b08cbfa9d4dabd1d7be2de2625"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
botocore = ">=1.37.11,<1.38.0"
|
||||
botocore = ">=1.37.12,<1.38.0"
|
||||
jmespath = ">=0.7.1,<2.0.0"
|
||||
s3transfer = ">=0.11.0,<0.12.0"
|
||||
|
||||
@@ -516,14 +516,14 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
|
||||
|
||||
[[package]]
|
||||
name = "botocore"
|
||||
version = "1.37.11"
|
||||
version = "1.37.12"
|
||||
description = "Low-level, data-driven core of boto 3."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "botocore-1.37.11-py3-none-any.whl", hash = "sha256:02505309b1235f9f15a6da79103ca224b3f3dc5f6a62f8630fbb2c6ed05e2da8"},
|
||||
{file = "botocore-1.37.11.tar.gz", hash = "sha256:72eb3a9a58b064be26ba154e5e56373633b58f951941c340ace0d379590d98b5"},
|
||||
{file = "botocore-1.37.12-py3-none-any.whl", hash = "sha256:ba1948c883bbabe20d95ff62c3e36954c9269686f7db9361857835677ca3e676"},
|
||||
{file = "botocore-1.37.12.tar.gz", hash = "sha256:ae2d5328ce6ad02eb615270507235a6e90fd3eeed615a6c0732b5a68b12f2017"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3547,14 +3547,14 @@ test = ["jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-jupyter[server] (>
|
||||
|
||||
[[package]]
|
||||
name = "jupyterlab"
|
||||
version = "4.3.5"
|
||||
version = "4.3.6"
|
||||
description = "JupyterLab computational environment"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["runtime"]
|
||||
files = [
|
||||
{file = "jupyterlab-4.3.5-py3-none-any.whl", hash = "sha256:571bbdee20e4c5321ab5195bc41cf92a75a5cff886be5e57ce78dfa37a5e9fdb"},
|
||||
{file = "jupyterlab-4.3.5.tar.gz", hash = "sha256:c779bf72ced007d7d29d5bcef128e7fdda96ea69299e19b04a43635a7d641f9d"},
|
||||
{file = "jupyterlab-4.3.6-py3-none-any.whl", hash = "sha256:fc9eb0455562a56a9bd6d2977cf090842f321fa1a298fcee9bf8c19de353d5fd"},
|
||||
{file = "jupyterlab-4.3.6.tar.gz", hash = "sha256:2900ffdbfca9ed37c4ad7fdda3eb76582fd945d46962af3ac64741ae2d6b2ff4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4251,14 +4251,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "modal"
|
||||
version = "0.73.98"
|
||||
version = "0.73.102"
|
||||
description = "Python client library for Modal"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main", "evaluation"]
|
||||
files = [
|
||||
{file = "modal-0.73.98-py3-none-any.whl", hash = "sha256:a49cd5f5b46d1a6c6a0d528618d3cbb73ac2908e199716590ec3a5275d79ed98"},
|
||||
{file = "modal-0.73.98.tar.gz", hash = "sha256:817f73c222fa39a16d6888a92eb7a6847ecae574e44ef04e2dce5e534bdd2df9"},
|
||||
{file = "modal-0.73.102-py3-none-any.whl", hash = "sha256:26151ef6164e0b93b0d1961f73d5a715deb72f23e2641215f5410cf58bf403d3"},
|
||||
{file = "modal-0.73.102.tar.gz", hash = "sha256:198876cf94ff13633283e251d8b37cc1f1bb5e27a7aa547e02072def1f29b66e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -4670,19 +4670,19 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "notebook"
|
||||
version = "7.3.2"
|
||||
version = "7.3.3"
|
||||
description = "Jupyter Notebook - A web-based notebook environment for interactive computing"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["runtime"]
|
||||
files = [
|
||||
{file = "notebook-7.3.2-py3-none-any.whl", hash = "sha256:e5f85fc59b69d3618d73cf27544418193ff8e8058d5bf61d315ce4f473556288"},
|
||||
{file = "notebook-7.3.2.tar.gz", hash = "sha256:705e83a1785f45b383bf3ee13cb76680b92d24f56fb0c7d2136fe1d850cd3ca8"},
|
||||
{file = "notebook-7.3.3-py3-none-any.whl", hash = "sha256:b193df0878956562d5171c8e25c9252b8e86c9fcc16163b8ee3fe6c5e3f422f7"},
|
||||
{file = "notebook-7.3.3.tar.gz", hash = "sha256:707a313fb882d35f921989eb3d204de942ed5132a44e4aa1fe0e8f24bb9dc25d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
jupyter-server = ">=2.4.0,<3"
|
||||
jupyterlab = ">=4.3.4,<4.4"
|
||||
jupyterlab = ">=4.3.6,<4.4"
|
||||
jupyterlab-server = ">=2.27.1,<3"
|
||||
notebook-shim = ">=0.2,<0.3"
|
||||
tornado = ">=6.2.0"
|
||||
@@ -6947,30 +6947,30 @@ pyasn1 = ">=0.1.3"
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.9.10"
|
||||
version = "0.11.0"
|
||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
groups = ["dev", "evaluation"]
|
||||
files = [
|
||||
{file = "ruff-0.9.10-py3-none-linux_armv6l.whl", hash = "sha256:eb4d25532cfd9fe461acc83498361ec2e2252795b4f40b17e80692814329e42d"},
|
||||
{file = "ruff-0.9.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:188a6638dab1aa9bb6228a7302387b2c9954e455fb25d6b4470cb0641d16759d"},
|
||||
{file = "ruff-0.9.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5284dcac6b9dbc2fcb71fdfc26a217b2ca4ede6ccd57476f52a587451ebe450d"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47678f39fa2a3da62724851107f438c8229a3470f533894b5568a39b40029c0c"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:99713a6e2766b7a17147b309e8c915b32b07a25c9efd12ada79f217c9c778b3e"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:524ee184d92f7c7304aa568e2db20f50c32d1d0caa235d8ddf10497566ea1a12"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:df92aeac30af821f9acf819fc01b4afc3dfb829d2782884f8739fb52a8119a16"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de42e4edc296f520bb84954eb992a07a0ec5a02fecb834498415908469854a52"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d257f95b65806104b6b1ffca0ea53f4ef98454036df65b1eda3693534813ecd1"},
|
||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b60dec7201c0b10d6d11be00e8f2dbb6f40ef1828ee75ed739923799513db24c"},
|
||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d838b60007da7a39c046fcdd317293d10b845001f38bcb55ba766c3875b01e43"},
|
||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ccaf903108b899beb8e09a63ffae5869057ab649c1e9231c05ae354ebc62066c"},
|
||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f9567d135265d46e59d62dc60c0bfad10e9a6822e231f5b24032dba5a55be6b5"},
|
||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5f202f0d93738c28a89f8ed9eaba01b7be339e5d8d642c994347eaa81c6d75b8"},
|
||||
{file = "ruff-0.9.10-py3-none-win32.whl", hash = "sha256:bfb834e87c916521ce46b1788fbb8484966e5113c02df216680102e9eb960029"},
|
||||
{file = "ruff-0.9.10-py3-none-win_amd64.whl", hash = "sha256:f2160eeef3031bf4b17df74e307d4c5fb689a6f3a26a2de3f7ef4044e3c484f1"},
|
||||
{file = "ruff-0.9.10-py3-none-win_arm64.whl", hash = "sha256:5fd804c0327a5e5ea26615550e706942f348b197d5475ff34c19733aee4b2e69"},
|
||||
{file = "ruff-0.9.10.tar.gz", hash = "sha256:9bacb735d7bada9cfb0f2c227d3658fc443d90a727b47f206fb33f52f3c0eac7"},
|
||||
{file = "ruff-0.11.0-py3-none-linux_armv6l.whl", hash = "sha256:dc67e32bc3b29557513eb7eeabb23efdb25753684b913bebb8a0c62495095acb"},
|
||||
{file = "ruff-0.11.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:38c23fd9bdec4eb437b4c1e3595905a0a8edfccd63a790f818b28c78fe345639"},
|
||||
{file = "ruff-0.11.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7c8661b0be91a38bd56db593e9331beaf9064a79028adee2d5f392674bbc5e88"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6c0e8d3d2db7e9f6efd884f44b8dc542d5b6b590fc4bb334fdbc624d93a29a2"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3c3156d3f4b42e57247275a0a7e15a851c165a4fc89c5e8fa30ea6da4f7407b8"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:490b1e147c1260545f6d041c4092483e3f6d8eba81dc2875eaebcf9140b53905"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1bc09a7419e09662983b1312f6fa5dab829d6ab5d11f18c3760be7ca521c9329"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bcfa478daf61ac8002214eb2ca5f3e9365048506a9d52b11bea3ecea822bb844"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6fbb2aed66fe742a6a3a0075ed467a459b7cedc5ae01008340075909d819df1e"},
|
||||
{file = "ruff-0.11.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92c0c1ff014351c0b0cdfdb1e35fa83b780f1e065667167bb9502d47ca41e6db"},
|
||||
{file = "ruff-0.11.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e4fd5ff5de5f83e0458a138e8a869c7c5e907541aec32b707f57cf9a5e124445"},
|
||||
{file = "ruff-0.11.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:96bc89a5c5fd21a04939773f9e0e276308be0935de06845110f43fd5c2e4ead7"},
|
||||
{file = "ruff-0.11.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a9352b9d767889ec5df1483f94870564e8102d4d7e99da52ebf564b882cdc2c7"},
|
||||
{file = "ruff-0.11.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:049a191969a10897fe052ef9cc7491b3ef6de79acd7790af7d7897b7a9bfbcb6"},
|
||||
{file = "ruff-0.11.0-py3-none-win32.whl", hash = "sha256:3191e9116b6b5bbe187447656f0c8526f0d36b6fd89ad78ccaad6bdc2fad7df2"},
|
||||
{file = "ruff-0.11.0-py3-none-win_amd64.whl", hash = "sha256:c58bfa00e740ca0a6c43d41fb004cd22d165302f360aaa56f7126d544db31a21"},
|
||||
{file = "ruff-0.11.0-py3-none-win_arm64.whl", hash = "sha256:868364fc23f5aa122b00c6f794211e85f7e78f5dffdf7c590ab90b8c4e69b657"},
|
||||
{file = "ruff-0.11.0.tar.gz", hash = "sha256:e55c620690a4a7ee6f1cccb256ec2157dc597d109400ae75bbf944fc9d6462e2"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -9056,4 +9056,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "6a644bc65782a717a49718496bd279ecb888807ec625d992af4448cc5d9271c1"
|
||||
content-hash = "9b74f62a4afa719a1f7167e0b3b45cdaf282c2e18fd2931da91c0f1b22776178"
|
||||
|
||||
+6
-1
@@ -80,7 +80,7 @@ daytona-sdk = "0.10.2"
|
||||
python-json-logger = "^3.2.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.9.10"
|
||||
ruff = "0.11.0"
|
||||
mypy = "1.15.0"
|
||||
pre-commit = "4.1.0"
|
||||
build = "*"
|
||||
@@ -154,3 +154,8 @@ style = "semver"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
openhands = "openhands.core.cli:main"
|
||||
|
||||
[tool.poetry.group.testgeneval.dependencies]
|
||||
fuzzywuzzy = "^0.18.0"
|
||||
rouge = "^1.0.1"
|
||||
python-levenshtein = "^0.26.1"
|
||||
|
||||
@@ -19,7 +19,7 @@ from openhands.events.event import RecallType
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import MicroagentObservation
|
||||
from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics, TokenUsage
|
||||
@@ -192,7 +192,7 @@ async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = MicroagentObservation(
|
||||
microagent_obs = RecallObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
@@ -249,7 +249,7 @@ async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = MicroagentObservation(
|
||||
microagent_obs = RecallObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
@@ -596,7 +596,7 @@ async def test_run_controller_max_iterations_has_metrics(
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = MicroagentObservation(
|
||||
microagent_obs = RecallObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
@@ -718,7 +718,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = MicroagentObservation(
|
||||
microagent_obs = RecallObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
@@ -795,7 +795,7 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
|
||||
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = MicroagentObservation(
|
||||
microagent_obs = RecallObservation(
|
||||
content='Test microagent content',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
)
|
||||
@@ -845,23 +845,30 @@ async def test_run_controller_with_memory_error(test_event_stream):
|
||||
config = AppConfig()
|
||||
event_stream = test_event_stream
|
||||
|
||||
# Create a propert agent that returns an action without an ID
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = config.get_llm_config()
|
||||
|
||||
# Create a real action to return from the mocked step function
|
||||
def agent_step_fn(state):
|
||||
return MessageAction(content='Agent returned a message')
|
||||
|
||||
agent.step = agent_step_fn
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
# Create a real Memory instance
|
||||
memory = Memory(event_stream=event_stream, sid='test-memory')
|
||||
|
||||
# Patch the _on_microagent_action method to raise our test exception
|
||||
def mock_on_microagent_action(*args, **kwargs):
|
||||
# Patch the _find_microagent_knowledge method to raise our test exception
|
||||
def mock_find_microagent_knowledge(*args, **kwargs):
|
||||
raise RuntimeError('Test memory error')
|
||||
|
||||
with patch.object(
|
||||
memory, '_on_microagent_action', side_effect=mock_on_microagent_action
|
||||
memory, '_find_microagent_knowledge', side_effect=mock_find_microagent_knowledge
|
||||
):
|
||||
state = await run_controller(
|
||||
config=config,
|
||||
|
||||
@@ -19,7 +19,7 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import Event, RecallType
|
||||
from openhands.events.observation.agent import MicroagentObservation
|
||||
from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
@@ -86,10 +86,10 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
|
||||
def on_event(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
# create a MicroagentObservation
|
||||
microagent_observation = MicroagentObservation(
|
||||
# create a RecallObservation
|
||||
microagent_observation = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
content='microagent',
|
||||
content='Found info',
|
||||
)
|
||||
microagent_observation._cause = event.id # ignore attr-defined warning
|
||||
mock_event_stream.add_event(microagent_observation, EventSource.ENVIRONMENT)
|
||||
@@ -111,14 +111,14 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
|
||||
# Give time for the async step() to execute
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Verify that a MicroagentObservation was added to the event stream
|
||||
# Verify that a RecallObservation was added to the event stream
|
||||
events = list(mock_event_stream.get_events())
|
||||
assert (
|
||||
mock_event_stream.get_latest_event_id() == 3
|
||||
) # Microagents and AgentChangeState
|
||||
|
||||
# a MicroagentObservation and an AgentDelegateAction should be in the list
|
||||
assert any(isinstance(event, MicroagentObservation) for event in events)
|
||||
# a RecallObservation and an AgentDelegateAction should be in the list
|
||||
assert any(isinstance(event, RecallObservation) for event in events)
|
||||
assert any(isinstance(event, AgentDelegateAction) for event in events)
|
||||
|
||||
# Verify that a delegate agent controller is created
|
||||
|
||||
@@ -6,11 +6,11 @@ from litellm import ChatCompletionMessageToolCall
|
||||
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
|
||||
from openhands.agenthub.codeact_agent.function_calling import (
|
||||
BrowserTool,
|
||||
CmdRunTool,
|
||||
IPythonTool,
|
||||
LLMBasedFileEditTool,
|
||||
StrReplaceEditorTool,
|
||||
WebReadTool,
|
||||
create_cmd_run_tool,
|
||||
create_str_replace_editor_tool,
|
||||
get_tools,
|
||||
response_to_actions,
|
||||
)
|
||||
@@ -119,6 +119,7 @@ def test_get_tools_with_options():
|
||||
|
||||
|
||||
def test_cmd_run_tool():
|
||||
CmdRunTool = create_cmd_run_tool()
|
||||
assert CmdRunTool['type'] == 'function'
|
||||
assert CmdRunTool['function']['name'] == 'execute_bash'
|
||||
assert 'command' in CmdRunTool['function']['parameters']['properties']
|
||||
@@ -149,6 +150,7 @@ def test_llm_based_file_edit_tool():
|
||||
|
||||
|
||||
def test_str_replace_editor_tool():
|
||||
StrReplaceEditorTool = create_str_replace_editor_tool()
|
||||
assert StrReplaceEditorTool['type'] == 'function'
|
||||
assert StrReplaceEditorTool['function']['name'] == 'str_replace_editor'
|
||||
|
||||
@@ -236,7 +238,11 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
mock_response.choices[0].message.content = 'Task completed'
|
||||
mock_response.choices[0].message.tool_calls = []
|
||||
|
||||
mock_config = Mock()
|
||||
mock_config.model = 'mock_model'
|
||||
|
||||
llm = Mock()
|
||||
llm.config = mock_config
|
||||
llm.completion = Mock(return_value=mock_response)
|
||||
llm.is_function_calling_active = Mock(return_value=True) # Enable function calling
|
||||
llm.is_caching_prompt_active = Mock(return_value=False)
|
||||
@@ -260,6 +266,28 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
assert action.content == 'Task completed'
|
||||
|
||||
|
||||
def test_correct_tool_description_loaded_based_on_model_name(mock_state: State):
|
||||
"""Tests that the simplified tool descriptions are loaded for specific models."""
|
||||
o3_mock_config = Mock()
|
||||
o3_mock_config.model = 'mock_o3_model'
|
||||
|
||||
llm = Mock()
|
||||
llm.config = o3_mock_config
|
||||
|
||||
agent = CodeActAgent(llm=llm, config=AgentConfig())
|
||||
for tool in agent.tools:
|
||||
# Assert all descriptions have less than 1024 characters
|
||||
assert len(tool['function']['description']) < 1024
|
||||
|
||||
sonnet_mock_config = Mock()
|
||||
sonnet_mock_config.model = 'mock_sonnet_model'
|
||||
|
||||
llm.config = sonnet_mock_config
|
||||
agent = CodeActAgent(llm=llm, config=AgentConfig())
|
||||
# Assert existence of the detailed tool descriptions that are longer than 1024 characters
|
||||
assert any(len(tool['function']['description']) > 1024 for tool in agent.tools)
|
||||
|
||||
|
||||
def test_mismatched_tool_call_events(mock_state: State):
|
||||
"""Tests that the agent can convert mismatched tool call events (i.e., an observation with no corresponding action) into messages."""
|
||||
agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
|
||||
|
||||
@@ -32,6 +32,7 @@ def _patch_store():
|
||||
'selected_repository': 'foobar',
|
||||
'conversation_id': 'some_conversation_id',
|
||||
'github_user_id': '12345',
|
||||
'user_id': '12345',
|
||||
'created_at': '2025-01-01T00:00:00+00:00',
|
||||
'last_updated_at': '2025-01-01T00:01:00+00:00',
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ from openhands.events.event import (
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentKnowledge,
|
||||
MicroagentObservation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
@@ -51,7 +51,7 @@ def agent_config():
|
||||
def conversation_memory(agent_config):
|
||||
prompt_manager = MagicMock(spec=PromptManager)
|
||||
prompt_manager.get_system_message.return_value = 'System message'
|
||||
prompt_manager.build_additional_info.return_value = (
|
||||
prompt_manager.build_workspace_context.return_value = (
|
||||
'Formatted repository and runtime info'
|
||||
)
|
||||
|
||||
@@ -353,10 +353,10 @@ def test_process_events_with_user_reject_observation(conversation_memory):
|
||||
|
||||
|
||||
def test_process_events_with_empty_environment_info(conversation_memory):
|
||||
"""Test that empty environment info observations return an empty list of messages without calling build_additional_info."""
|
||||
# Create a MicroagentObservation with empty info
|
||||
"""Test that empty environment info observations return an empty list of messages without calling build_workspace_context."""
|
||||
# Create a RecallObservation with empty info
|
||||
|
||||
empty_obs = MicroagentObservation(
|
||||
empty_obs = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='',
|
||||
repo_directory='',
|
||||
@@ -382,8 +382,8 @@ def test_process_events_with_empty_environment_info(conversation_memory):
|
||||
assert len(messages) == 1
|
||||
assert messages[0].role == 'system'
|
||||
|
||||
# Verify that build_additional_info was NOT called since all input values were empty
|
||||
conversation_memory.prompt_manager.build_additional_info.assert_not_called()
|
||||
# Verify that build_workspace_context was NOT called since all input values were empty
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
|
||||
|
||||
|
||||
def test_process_events_with_function_calling_observation(conversation_memory):
|
||||
@@ -527,8 +527,8 @@ def test_apply_prompt_caching(conversation_memory):
|
||||
|
||||
|
||||
def test_process_events_with_environment_microagent_observation(conversation_memory):
|
||||
"""Test processing a MicroagentObservation with ENVIRONMENT info type."""
|
||||
obs = MicroagentObservation(
|
||||
"""Test processing a RecallObservation with ENVIRONMENT info type."""
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='test-repo',
|
||||
repo_directory='/path/to/repo',
|
||||
@@ -556,8 +556,8 @@ def test_process_events_with_environment_microagent_observation(conversation_mem
|
||||
assert result.content[0].text == 'Formatted repository and runtime info'
|
||||
|
||||
# Verify the prompt_manager was called with the correct parameters
|
||||
conversation_memory.prompt_manager.build_additional_info.assert_called_once()
|
||||
call_args = conversation_memory.prompt_manager.build_additional_info.call_args[1]
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_called_once()
|
||||
call_args = conversation_memory.prompt_manager.build_workspace_context.call_args[1]
|
||||
assert isinstance(call_args['repository_info'], RepositoryInfo)
|
||||
assert call_args['repository_info'].repo_name == 'test-repo'
|
||||
assert call_args['repository_info'].repo_directory == '/path/to/repo'
|
||||
@@ -572,7 +572,7 @@ def test_process_events_with_environment_microagent_observation(conversation_mem
|
||||
def test_process_events_with_knowledge_microagent_microagent_observation(
|
||||
conversation_memory,
|
||||
):
|
||||
"""Test processing a MicroagentObservation with KNOWLEDGE type."""
|
||||
"""Test processing a RecallObservation with KNOWLEDGE type."""
|
||||
microagent_knowledge = [
|
||||
MicroagentKnowledge(
|
||||
name='test_agent',
|
||||
@@ -591,7 +591,7 @@ def test_process_events_with_knowledge_microagent_microagent_observation(
|
||||
),
|
||||
]
|
||||
|
||||
obs = MicroagentObservation(
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=microagent_knowledge,
|
||||
content='Retrieved knowledge from microagents',
|
||||
@@ -634,11 +634,11 @@ def test_process_events_with_knowledge_microagent_microagent_observation(
|
||||
def test_process_events_with_microagent_observation_extensions_disabled(
|
||||
agent_config, conversation_memory
|
||||
):
|
||||
"""Test processing a MicroagentObservation when prompt extensions are disabled."""
|
||||
"""Test processing a RecallObservation when prompt extensions are disabled."""
|
||||
# Modify the agent config to disable prompt extensions
|
||||
agent_config.enable_prompt_extensions = False
|
||||
|
||||
obs = MicroagentObservation(
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='test-repo',
|
||||
repo_directory='/path/to/repo',
|
||||
@@ -656,18 +656,18 @@ def test_process_events_with_microagent_observation_extensions_disabled(
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# When prompt extensions are disabled, the MicroagentObservation should be ignored
|
||||
# When prompt extensions are disabled, the RecallObservation should be ignored
|
||||
assert len(messages) == 1 # Only the initial system message
|
||||
assert messages[0].role == 'system'
|
||||
|
||||
# Verify the prompt_manager was not called
|
||||
conversation_memory.prompt_manager.build_additional_info.assert_not_called()
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
|
||||
conversation_memory.prompt_manager.build_microagent_info.assert_not_called()
|
||||
|
||||
|
||||
def test_process_events_with_empty_microagent_knowledge(conversation_memory):
|
||||
"""Test processing a MicroagentObservation with empty microagent knowledge."""
|
||||
obs = MicroagentObservation(
|
||||
"""Test processing a RecallObservation with empty microagent knowledge."""
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[],
|
||||
content='Retrieved knowledge from microagents',
|
||||
@@ -693,7 +693,7 @@ def test_process_events_with_empty_microagent_knowledge(conversation_memory):
|
||||
|
||||
|
||||
def test_conversation_memory_processes_microagent_observation(prompt_dir):
|
||||
"""Test that ConversationMemory processes MicroagentObservations correctly."""
|
||||
"""Test that ConversationMemory processes RecallObservations correctly."""
|
||||
# Create a microagent_info.j2 template file
|
||||
template_path = os.path.join(prompt_dir, 'microagent_info.j2')
|
||||
if not os.path.exists(template_path):
|
||||
@@ -722,8 +722,8 @@ It may or may not be relevant to the user's request.
|
||||
config=agent_config, prompt_manager=prompt_manager
|
||||
)
|
||||
|
||||
# Create a MicroagentObservation with microagent knowledge
|
||||
microagent_observation = MicroagentObservation(
|
||||
# Create a RecallObservation with microagent knowledge
|
||||
microagent_observation = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
@@ -761,7 +761,7 @@ This is triggered content for testing.
|
||||
|
||||
|
||||
def test_conversation_memory_processes_environment_microagent_observation(prompt_dir):
|
||||
"""Test that ConversationMemory processes environment info MicroagentObservations correctly."""
|
||||
"""Test that ConversationMemory processes environment info RecallObservations correctly."""
|
||||
# Create an additional_info.j2 template file
|
||||
template_path = os.path.join(prompt_dir, 'additional_info.j2')
|
||||
if not os.path.exists(template_path):
|
||||
@@ -802,8 +802,8 @@ each of which has a corresponding port:
|
||||
config=agent_config, prompt_manager=prompt_manager
|
||||
)
|
||||
|
||||
# Create a MicroagentObservation with environment info
|
||||
microagent_observation = MicroagentObservation(
|
||||
# Create a RecallObservation with environment info
|
||||
microagent_observation = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='owner/repo',
|
||||
repo_directory='/workspace/repo',
|
||||
@@ -839,13 +839,13 @@ each of which has a corresponding port:
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication(conversation_memory):
|
||||
"""Test that MicroagentObservations are properly deduplicated based on agent name.
|
||||
"""Test that RecallObservations are properly deduplicated based on agent name.
|
||||
|
||||
The deduplication logic should keep the FIRST occurrence of each microagent
|
||||
and filter out later occurrences to avoid redundant information.
|
||||
"""
|
||||
# Create a sequence of MicroagentObservations with overlapping agents
|
||||
obs1 = MicroagentObservation(
|
||||
# Create a sequence of RecallObservations with overlapping agents
|
||||
obs1 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
@@ -867,7 +867,7 @@ def test_process_events_with_microagent_observation_deduplication(conversation_m
|
||||
content='First retrieval',
|
||||
)
|
||||
|
||||
obs2 = MicroagentObservation(
|
||||
obs2 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
@@ -879,7 +879,7 @@ def test_process_events_with_microagent_observation_deduplication(conversation_m
|
||||
content='Second retrieval',
|
||||
)
|
||||
|
||||
obs3 = MicroagentObservation(
|
||||
obs3 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
@@ -918,8 +918,8 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent
|
||||
conversation_memory,
|
||||
):
|
||||
"""Test that disabled agents are filtered out and deduplication keeps the first occurrence."""
|
||||
# Create a sequence of MicroagentObservations with disabled agents
|
||||
obs1 = MicroagentObservation(
|
||||
# Create a sequence of RecallObservations with disabled agents
|
||||
obs1 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
@@ -936,7 +936,7 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent
|
||||
content='First retrieval',
|
||||
)
|
||||
|
||||
obs2 = MicroagentObservation(
|
||||
obs2 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
@@ -973,8 +973,8 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent
|
||||
def test_process_events_with_microagent_observation_deduplication_empty(
|
||||
conversation_memory,
|
||||
):
|
||||
"""Test that empty MicroagentObservations are handled correctly."""
|
||||
obs = MicroagentObservation(
|
||||
"""Test that empty RecallObservations are handled correctly."""
|
||||
obs = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[],
|
||||
content='Empty retrieval',
|
||||
@@ -991,7 +991,7 @@ def test_process_events_with_microagent_observation_deduplication_empty(
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that empty MicroagentObservations are handled gracefully
|
||||
# Verify that empty RecallObservations are handled gracefully
|
||||
assert (
|
||||
len(messages) == 1
|
||||
) # system message, because an empty microagent is not added to Messages
|
||||
@@ -999,8 +999,8 @@ def test_process_events_with_microagent_observation_deduplication_empty(
|
||||
|
||||
def test_has_agent_in_earlier_events(conversation_memory):
|
||||
"""Test the _has_agent_in_earlier_events helper method."""
|
||||
# Create test MicroagentObservations
|
||||
obs1 = MicroagentObservation(
|
||||
# Create test RecallObservations
|
||||
obs1 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
@@ -1012,7 +1012,7 @@ def test_has_agent_in_earlier_events(conversation_memory):
|
||||
content='First retrieval',
|
||||
)
|
||||
|
||||
obs2 = MicroagentObservation(
|
||||
obs2 = RecallObservation(
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
@@ -1024,7 +1024,7 @@ def test_has_agent_in_earlier_events(conversation_memory):
|
||||
content='Second retrieval',
|
||||
)
|
||||
|
||||
obs3 = MicroagentObservation(
|
||||
obs3 = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
content='Environment info',
|
||||
)
|
||||
|
||||
@@ -13,7 +13,8 @@ async def test_load_store():
|
||||
store = FileConversationStore(InMemoryFileStore({}))
|
||||
expected = ConversationMetadata(
|
||||
conversation_id='some-conversation-id',
|
||||
github_user_id='some-user-id',
|
||||
user_id='some-user-id',
|
||||
github_user_id='12345',
|
||||
selected_repository='some-repo',
|
||||
title="Let's talk about trains",
|
||||
)
|
||||
@@ -31,6 +32,7 @@ async def test_load_int_user_id():
|
||||
{
|
||||
'conversation_id': 'some-conversation-id',
|
||||
'github_user_id': 12345,
|
||||
'user_id': '67890',
|
||||
'selected_repository': 'some-repo',
|
||||
'title': "Let's talk about trains",
|
||||
'created_at': '2025-01-16T19:51:04.886331Z',
|
||||
@@ -41,6 +43,7 @@ async def test_load_int_user_id():
|
||||
)
|
||||
found = await store.get_metadata('some-conversation-id')
|
||||
assert found.github_user_id == '12345'
|
||||
assert found.user_id == '67890'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -61,6 +64,7 @@ async def test_search_basic():
|
||||
{
|
||||
'conversation_id': 'conv1',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'First conversation',
|
||||
'created_at': '2025-01-16T19:51:04Z',
|
||||
@@ -70,6 +74,7 @@ async def test_search_basic():
|
||||
{
|
||||
'conversation_id': 'conv2',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'Second conversation',
|
||||
'created_at': '2025-01-17T19:51:04Z',
|
||||
@@ -79,6 +84,7 @@ async def test_search_basic():
|
||||
{
|
||||
'conversation_id': 'conv3',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'Third conversation',
|
||||
'created_at': '2025-01-15T19:51:04Z',
|
||||
@@ -107,6 +113,7 @@ async def test_search_pagination():
|
||||
{
|
||||
'conversation_id': f'conv{i}',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': f'Conversation {i}',
|
||||
'created_at': f'2025-01-{15+i}T19:51:04Z',
|
||||
@@ -148,6 +155,7 @@ async def test_search_with_invalid_conversation():
|
||||
{
|
||||
'conversation_id': 'conv1',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'Valid conversation',
|
||||
'created_at': '2025-01-16T19:51:04Z',
|
||||
@@ -176,6 +184,7 @@ async def test_get_all_metadata():
|
||||
{
|
||||
'conversation_id': 'conv1',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'First conversation',
|
||||
'created_at': '2025-01-16T19:51:04Z',
|
||||
@@ -185,6 +194,7 @@ async def test_get_all_metadata():
|
||||
{
|
||||
'conversation_id': 'conv2',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'Second conversation',
|
||||
'created_at': '2025-01-17T19:51:04Z',
|
||||
|
||||
+21
-17
@@ -14,7 +14,7 @@ from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.observation.agent import (
|
||||
MicroagentObservation,
|
||||
RecallObservation,
|
||||
RecallType,
|
||||
)
|
||||
from openhands.events.stream import EventStream
|
||||
@@ -74,7 +74,7 @@ async def test_memory_on_event_exception_handling(memory, event_stream):
|
||||
|
||||
# Mock Memory method to raise an exception
|
||||
with patch.object(
|
||||
memory, '_on_first_microagent_action', side_effect=Exception('Test error')
|
||||
memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
|
||||
):
|
||||
state = await run_controller(
|
||||
config=AppConfig(),
|
||||
@@ -93,10 +93,10 @@ async def test_memory_on_event_exception_handling(memory, event_stream):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_on_first_microagent_action_exception_handling(
|
||||
async def test_memory_on_workspace_context_recall_exception_handling(
|
||||
memory, event_stream
|
||||
):
|
||||
"""Test that exceptions in Memory._on_first_microagent_action are properly handled via status callback."""
|
||||
"""Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback."""
|
||||
|
||||
# Create a dummy agent for the controller
|
||||
agent = MagicMock(spec=Agent)
|
||||
@@ -108,11 +108,11 @@ async def test_memory_on_first_microagent_action_exception_handling(
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
runtime.event_stream = event_stream
|
||||
|
||||
# Mock Memory._on_first_microagent_action to raise an exception
|
||||
# Mock Memory._on_workspace_context_recall to raise an exception
|
||||
with patch.object(
|
||||
memory,
|
||||
'_on_first_microagent_action',
|
||||
side_effect=Exception('Test error from _on_first_microagent_action'),
|
||||
'_find_microagent_knowledge',
|
||||
side_effect=Exception('Test error from _find_microagent_knowledge'),
|
||||
):
|
||||
state = await run_controller(
|
||||
config=AppConfig(),
|
||||
@@ -130,12 +130,13 @@ async def test_memory_on_first_microagent_action_exception_handling(
|
||||
assert state.last_error == 'Error: Exception'
|
||||
|
||||
|
||||
def test_memory_with_microagents():
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_with_microagents():
|
||||
"""Test that Memory loads microagents from the global directory and processes microagent actions.
|
||||
|
||||
This test verifies that:
|
||||
1. Memory loads microagents from the global GLOBAL_MICROAGENTS_DIR
|
||||
2. When a microagent action with a trigger word is processed, a MicroagentObservation is created
|
||||
2. When a microagent action with a trigger word is processed, a RecallObservation is created
|
||||
"""
|
||||
# Create a mock event stream
|
||||
event_stream = MagicMock(spec=EventStream)
|
||||
@@ -158,6 +159,9 @@ def test_memory_with_microagents():
|
||||
query='Hello, flarglebargle!', recall_type=RecallType.KNOWLEDGE
|
||||
)
|
||||
|
||||
# Set the source to USER
|
||||
microagent_action._source = EventSource.USER # type: ignore[attr-defined]
|
||||
|
||||
# Mock the event_stream.add_event method
|
||||
added_events = []
|
||||
|
||||
@@ -173,12 +177,12 @@ def test_memory_with_microagents():
|
||||
added_events.clear()
|
||||
|
||||
# Process the microagent action
|
||||
memory.on_event(microagent_action)
|
||||
await memory._on_event(microagent_action)
|
||||
|
||||
# Verify a MicroagentObservation was added to the event stream
|
||||
# Verify a RecallObservation was added to the event stream
|
||||
assert len(added_events) == 1
|
||||
observation, source = added_events[0]
|
||||
assert isinstance(observation, MicroagentObservation)
|
||||
assert isinstance(observation, RecallObservation)
|
||||
assert source == EventSource.ENVIRONMENT
|
||||
assert observation.recall_type == RecallType.KNOWLEDGE
|
||||
assert len(observation.microagent_knowledge) == 1
|
||||
@@ -188,7 +192,7 @@ def test_memory_with_microagents():
|
||||
|
||||
|
||||
def test_memory_repository_info(prompt_dir):
|
||||
"""Test that Memory adds repository info to MicroagentObservations."""
|
||||
"""Test that Memory adds repository info to RecallObservations."""
|
||||
# Create an in-memory file store and real event stream
|
||||
file_store = InMemoryFileStore()
|
||||
event_stream = EventStream(sid='test-session', file_store=file_store)
|
||||
@@ -241,15 +245,15 @@ REPOSITORY INSTRUCTIONS: This is a test repository.
|
||||
# Get all events from the stream
|
||||
events = list(event_stream.get_events())
|
||||
|
||||
# Find the MicroagentObservation event
|
||||
# Find the RecallObservation event
|
||||
microagent_obs_events = [
|
||||
event for event in events if isinstance(event, MicroagentObservation)
|
||||
event for event in events if isinstance(event, RecallObservation)
|
||||
]
|
||||
|
||||
# We should have at least one MicroagentObservation
|
||||
# We should have at least one RecallObservation
|
||||
assert len(microagent_obs_events) > 0
|
||||
|
||||
# Get the first MicroagentObservation
|
||||
# Get the first RecallObservation
|
||||
observation = microagent_obs_events[0]
|
||||
assert observation.recall_type == RecallType.WORKSPACE_CONTEXT
|
||||
assert observation.repo_name == 'owner/repo'
|
||||
|
||||
@@ -5,8 +5,8 @@ from openhands.events.observation import (
|
||||
CmdOutputMetadata,
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
MicroagentObservation,
|
||||
Observation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import MicroagentKnowledge
|
||||
from openhands.events.serialization import (
|
||||
@@ -245,9 +245,9 @@ def test_file_edit_observation_legacy_serialization():
|
||||
|
||||
def test_microagent_observation_serialization():
|
||||
original_observation_dict = {
|
||||
'observation': 'microagent',
|
||||
'observation': 'recall',
|
||||
'content': '',
|
||||
'message': "**MicroagentObservation**\nrecall_type=RecallType.WORKSPACE_CONTEXT, repo_name=some_repo_name, repo_instructions=complex_repo_instruc..., runtime_hosts={'host1': 8080, 'host2': 8081}, additional_agent_instructions=You know it all abou...",
|
||||
'message': 'Added workspace context',
|
||||
'extras': {
|
||||
'recall_type': 'workspace_context',
|
||||
'repo_name': 'some_repo_name',
|
||||
@@ -258,14 +258,14 @@ def test_microagent_observation_serialization():
|
||||
'microagent_knowledge': [],
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_observation_dict, MicroagentObservation)
|
||||
serialization_deserialization(original_observation_dict, RecallObservation)
|
||||
|
||||
|
||||
def test_microagent_observation_microagent_knowledge_serialization():
|
||||
original_observation_dict = {
|
||||
'observation': 'microagent',
|
||||
'observation': 'recall',
|
||||
'content': '',
|
||||
'message': '**MicroagentObservation**\nrecall_type=RecallType.KNOWLEDGE, repo_name=, repo_instructions=..., runtime_hosts={}, additional_agent_instructions=..., microagent_knowledge=microagent1, microagent2',
|
||||
'message': 'Added microagent knowledge',
|
||||
'extras': {
|
||||
'recall_type': 'knowledge',
|
||||
'repo_name': '',
|
||||
@@ -287,13 +287,13 @@ def test_microagent_observation_microagent_knowledge_serialization():
|
||||
],
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_observation_dict, MicroagentObservation)
|
||||
serialization_deserialization(original_observation_dict, RecallObservation)
|
||||
|
||||
|
||||
def test_microagent_observation_knowledge_microagent_serialization():
|
||||
"""Test serialization of a MicroagentObservation with KNOWLEDGE_MICROAGENT type."""
|
||||
# Create a MicroagentObservation with microagent knowledge content
|
||||
original = MicroagentObservation(
|
||||
"""Test serialization of a RecallObservation with KNOWLEDGE_MICROAGENT type."""
|
||||
# Create a RecallObservation with microagent knowledge content
|
||||
original = RecallObservation(
|
||||
content='Knowledge microagent information',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
microagent_knowledge=[
|
||||
@@ -314,13 +314,13 @@ def test_microagent_observation_knowledge_microagent_serialization():
|
||||
serialized = event_to_dict(original)
|
||||
|
||||
# Verify serialized data structure
|
||||
assert serialized['observation'] == ObservationType.MICROAGENT
|
||||
assert serialized['observation'] == ObservationType.RECALL
|
||||
assert serialized['content'] == 'Knowledge microagent information'
|
||||
assert serialized['extras']['recall_type'] == RecallType.KNOWLEDGE.value
|
||||
assert len(serialized['extras']['microagent_knowledge']) == 2
|
||||
assert serialized['extras']['microagent_knowledge'][0]['trigger'] == 'python'
|
||||
|
||||
# Deserialize back to MicroagentObservation
|
||||
# Deserialize back to RecallObservation
|
||||
deserialized = observation_from_dict(serialized)
|
||||
|
||||
# Verify properties are preserved
|
||||
@@ -336,9 +336,9 @@ def test_microagent_observation_knowledge_microagent_serialization():
|
||||
|
||||
|
||||
def test_microagent_observation_environment_serialization():
|
||||
"""Test serialization of a MicroagentObservation with ENVIRONMENT type."""
|
||||
# Create a MicroagentObservation with environment info
|
||||
original = MicroagentObservation(
|
||||
"""Test serialization of a RecallObservation with ENVIRONMENT type."""
|
||||
# Create a RecallObservation with environment info
|
||||
original = RecallObservation(
|
||||
content='Environment information',
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='OpenHands',
|
||||
@@ -352,7 +352,7 @@ def test_microagent_observation_environment_serialization():
|
||||
serialized = event_to_dict(original)
|
||||
|
||||
# Verify serialized data structure
|
||||
assert serialized['observation'] == ObservationType.MICROAGENT
|
||||
assert serialized['observation'] == ObservationType.RECALL
|
||||
assert serialized['content'] == 'Environment information'
|
||||
assert serialized['extras']['recall_type'] == RecallType.WORKSPACE_CONTEXT.value
|
||||
assert serialized['extras']['repo_name'] == 'OpenHands'
|
||||
@@ -364,7 +364,7 @@ def test_microagent_observation_environment_serialization():
|
||||
serialized['extras']['additional_agent_instructions']
|
||||
== 'You know it all about this runtime'
|
||||
)
|
||||
# Deserialize back to MicroagentObservation
|
||||
# Deserialize back to RecallObservation
|
||||
deserialized = observation_from_dict(serialized)
|
||||
|
||||
# Verify properties are preserved
|
||||
@@ -382,11 +382,11 @@ def test_microagent_observation_environment_serialization():
|
||||
|
||||
|
||||
def test_microagent_observation_combined_serialization():
|
||||
"""Test serialization of a MicroagentObservation with both types of information."""
|
||||
# Create a MicroagentObservation with both environment and microagent info
|
||||
"""Test serialization of a RecallObservation with both types of information."""
|
||||
# Create a RecallObservation with both environment and microagent info
|
||||
# Note: In practice, recall_type would still be one specific type,
|
||||
# but the object could contain both types of fields
|
||||
original = MicroagentObservation(
|
||||
original = RecallObservation(
|
||||
content='Combined information',
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
# Environment info
|
||||
@@ -419,7 +419,7 @@ def test_microagent_observation_combined_serialization():
|
||||
serialized['extras']['additional_agent_instructions']
|
||||
== 'You know it all about this runtime'
|
||||
)
|
||||
# Deserialize back to MicroagentObservation
|
||||
# Deserialize back to RecallObservation
|
||||
deserialized = observation_from_dict(serialized)
|
||||
|
||||
# Verify all properties are preserved
|
||||
|
||||
@@ -51,7 +51,7 @@ At the user's request, repository {{ repository_info.repo_name }} has been clone
|
||||
assert 'System prompt: bar' in system_msg
|
||||
|
||||
# Test building additional info
|
||||
additional_info = manager.build_additional_info(
|
||||
additional_info = manager.build_workspace_context(
|
||||
repository_info=repo_info, runtime_info=None, repo_instructions=''
|
||||
)
|
||||
assert '<REPOSITORY_INFO>' in additional_info
|
||||
@@ -199,7 +199,7 @@ def test_add_turns_left_reminder(prompt_dir):
|
||||
)
|
||||
|
||||
|
||||
def test_build_additional_info_with_repo_and_runtime(prompt_dir):
|
||||
def test_build_workspace_context_with_repo_and_runtime(prompt_dir):
|
||||
"""Test building additional info with repository and runtime information."""
|
||||
# Create an additional_info.j2 template file
|
||||
with open(os.path.join(prompt_dir, 'additional_info.j2'), 'w') as f:
|
||||
@@ -245,7 +245,7 @@ each of which has a corresponding port:
|
||||
repo_instructions = 'This repository contains important code.'
|
||||
|
||||
# Build additional info
|
||||
result = manager.build_additional_info(
|
||||
result = manager.build_workspace_context(
|
||||
repository_info=repo_info,
|
||||
runtime_info=runtime_info,
|
||||
repo_instructions=repo_instructions,
|
||||
|
||||
@@ -49,6 +49,7 @@ async def test_iterate_single_page():
|
||||
{
|
||||
'conversation_id': 'conv1',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'First conversation',
|
||||
'created_at': '2025-01-16T19:51:04Z',
|
||||
@@ -58,6 +59,7 @@ async def test_iterate_single_page():
|
||||
{
|
||||
'conversation_id': 'conv2',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'Second conversation',
|
||||
'created_at': '2025-01-17T19:51:04Z',
|
||||
@@ -86,6 +88,7 @@ async def test_iterate_multiple_pages():
|
||||
{
|
||||
'conversation_id': f'conv{i}',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': f'Conversation {i}',
|
||||
'created_at': f'2025-01-{15+i}T19:51:04Z',
|
||||
@@ -120,6 +123,7 @@ async def test_iterate_with_invalid_conversation():
|
||||
{
|
||||
'conversation_id': 'conv1',
|
||||
'github_user_id': '123',
|
||||
'user_id': '123',
|
||||
'selected_repository': 'repo1',
|
||||
'title': 'Valid conversation',
|
||||
'created_at': '2025-01-16T19:51:04Z',
|
||||
|
||||
@@ -61,7 +61,7 @@ async def test_init_new_local_session():
|
||||
'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), 1
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), 1, '12345'
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 1
|
||||
assert sio.enter_room.await_count == 1
|
||||
@@ -93,10 +93,18 @@ async def test_join_local_session():
|
||||
'new-session-id', ConversationInitData(), None
|
||||
)
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), None
|
||||
'new-session-id',
|
||||
'new-session-id',
|
||||
ConversationInitData(),
|
||||
None,
|
||||
'12345',
|
||||
)
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id', 'new-session-id', ConversationInitData(), None
|
||||
'new-session-id',
|
||||
'new-session-id',
|
||||
ConversationInitData(),
|
||||
None,
|
||||
'12345',
|
||||
)
|
||||
assert session_instance.initialize_agent.call_count == 1
|
||||
assert sio.enter_room.await_count == 2
|
||||
@@ -128,7 +136,7 @@ async def test_add_to_local_event_stream():
|
||||
'new-session-id', ConversationInitData(), 1
|
||||
)
|
||||
await conversation_manager.join_conversation(
|
||||
'new-session-id', 'connection-id', ConversationInitData(), 1
|
||||
'new-session-id', 'connection-id', ConversationInitData(), 1, '12345'
|
||||
)
|
||||
await conversation_manager.send_to_event_stream(
|
||||
'connection-id', {'event_type': 'some_event'}
|
||||
|
||||
@@ -23,7 +23,13 @@ def mock_event_stream():
|
||||
def mock_agent():
|
||||
agent = MagicMock()
|
||||
agent.llm = MagicMock()
|
||||
agent.llm.config = MagicMock()
|
||||
|
||||
# Create a step function that returns an action without an ID
|
||||
def agent_step_fn(state):
|
||||
return MessageAction(content='Agent returned a message')
|
||||
|
||||
agent.step = agent_step_fn
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user