Compare commits

...

61 Commits

Author SHA1 Message Date
openhands bc27eae841 Remove CodeBLEU dependency and use tree-sitter directly 2025-03-17 16:45:19 +00:00
openhands f0e5c81272 Merge main branch 2025-03-17 16:43:36 +00:00
Engel Nyst a4b836b5f9 Don't try to send the new events in the UI (#7277) 2025-03-17 14:50:22 +01:00
Xingyao Wang a4d632498c SWE-Gym rollout stability fix & using a validated SWE-Gym set (#7182)
Co-authored-by: Robert Brennan <accounts@rbren.io>
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2025-03-17 21:15:01 +08:00
Engel Nyst 4f017081fc Quick fix docs (#7299)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-17 05:50:05 +00:00
Engel Nyst 51fb1fae88 RecallObservations (#7292) 2025-03-17 03:18:22 +01:00
Graham Neubig 106b230fea Update Slack invitation links (#7296)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-17 02:06:48 +00:00
Xingyao Wang 9b262dd057 fix retry on ConnectionError & retry on remote runtime by default (#7294) 2025-03-17 01:18:54 +00:00
chuckbutkus 8074b261d3 Move current user_id to github_user_id and create a new user_id field (#7231)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
2025-03-16 16:32:27 -04:00
dependabot[bot] 999a59f938 chore(deps): bump the version-all group with 5 updates (#7253)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-16 19:51:08 +00:00
chuckbutkus fbba57d3b5 Fix saving of settings (#7282) 2025-03-16 19:06:46 +00:00
Engel Nyst 3f6c8a2338 Fix visual browsing (#7278)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-16 16:50:25 +01:00
Engel Nyst dd09d46ccb Remove DelegatorAgent (fix #7280)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-03-16 16:49:28 +01:00
tofarr 8897b45eeb Fix for too much reaction in logs (#7276) 2025-03-16 08:21:30 -06:00
Ryan H. Tran 30109e8f20 Separate tool descriptions to support models with limited description length (#7258) 2025-03-16 09:48:13 +01:00
kjain14 4984bf6ee7 Merge pull request #6 from All-Hands-AI/main
Main of OpenHands
2025-02-17 14:45:13 -05:00
Kush Dave Jain 92ddc1b46c Merge branch 'All-Hands-AI-main' 2025-02-16 18:58:14 +00:00
Kush Dave Jain 367c8a9f83 Merging 2025-02-16 18:57:47 +00:00
Kush Dave Jain 09335d67be Merging 2025-02-16 17:14:17 +00:00
kjain14 eb36426d33 Adding so file for codebleu 2025-02-11 06:57:07 -05:00
openhands ace9e6e724 Update TestGenEval README to include dependency installation 2025-02-10 15:45:04 +00:00
Graham Neubig 1290a2599d Update lock 2025-02-10 10:43:49 -05:00
openhands 513dd9791d Restore testgeneval poetry group 2025-02-10 15:38:08 +00:00
Graham Neubig fd53378d06 Rename eval-infer 2025-02-10 10:21:07 -05:00
Graham Neubig 4471002c79 Remove unneeded input 2025-02-09 08:25:08 -05:00
Graham Neubig 326e75e829 Remove prompt truncation 2025-02-09 07:31:36 -05:00
Graham Neubig d8ad8babf6 Merge branch 'main' into main 2025-02-08 13:11:27 -05:00
kjain14 8782e3ae65 Merge pull request #3 from All-Hands-AI/main
Merging master
2025-02-06 14:54:49 -05:00
Kush Dave Jain eef0ed3410 Final prompt for final experiments 2025-02-06 19:48:04 +00:00
Kush Dave Jain e7a8daf3ec Fixing code to handle ablations 2025-02-04 03:10:11 +00:00
Kush Dave Jain 64abd4a95e Ablation outputs 2025-01-30 21:12:42 +00:00
Kush Dave Jain c7d575b4e1 Removing duplicate script 2025-01-28 21:29:30 +00:00
Kush Dave Jain f781bc8343 Fix prompting 2025-01-28 21:11:35 +00:00
Kush Dave Jain 9f9a65c787 More updates 2025-01-20 19:39:17 +00:00
Graham Neubig 3355baea4c Merge branch 'main' of github.com:kjain14/OpenHands into kjain14-main 2025-01-20 07:47:17 -05:00
Kush Dave Jain 8848e60c6d Only top level filtering 2025-01-17 19:33:19 +00:00
Kush Dave Jain d1e84093cc Update filtering 2025-01-17 18:19:52 +00:00
Kush Dave Jain 3f0f13d335 Update prompt 2025-01-10 18:20:16 +00:00
Kush Dave Jain 219a134bb0 Refine prompt 2025-01-10 18:06:20 +00:00
Kush Dave Jain efb525a463 Refine postprocessing 2025-01-09 22:53:15 +00:00
Kush Dave Jain 1ded123116 Reset to normal time 2025-01-08 19:08:29 +00:00
Kush Dave Jain 31b6967a87 Any and all pass 2025-01-08 18:04:50 +00:00
Graham Neubig 90422e5bfd Update lock file 2024-12-25 16:43:59 -05:00
Graham Neubig b47da9e894 Merge branch 'main' of github.com:All-Hands-AI/OpenHands into kjain14-main 2024-12-25 16:40:51 -05:00
openhands 3401bd610d Update TestGenEval README with comprehensive information 2024-12-25 21:23:53 +00:00
Kush Dave Jain 77a153e42f Final update, now working on all projects 2024-12-16 19:54:58 +00:00
Kush Dave Jain fb9bc87e35 testgeneval deps 2024-12-12 15:50:35 +00:00
Kush Dave Jain b685c67263 reset 2024-12-12 15:50:12 +00:00
Kush Dave Jain 2cd64bc636 Merge branch 'main' of https://github.com/kjain14/OpenHands 2024-12-11 22:25:32 +00:00
Kush Dave Jain 7c81deb132 Update README 2024-12-11 22:25:17 +00:00
kjain14 b19f735808 Merge pull request #2 from All-Hands-AI/main
merging openhands
2024-12-11 17:06:55 -05:00
Kush Dave Jain 3af6025303 + mutation testing 2024-12-11 22:05:46 +00:00
Kush Dave Jain 585dba9917 TestGenEval MVP 2024-12-11 20:36:20 +00:00
kjain14 bd66d09a33 Merge pull request #1 from All-Hands-AI/main
Merge OpenHands
2024-12-06 11:14:34 -05:00
Kush Dave Jain 791b7f9f60 Cleaning to not OOM 2024-12-05 16:12:37 +00:00
Kush Dave Jain 30197e616b Add option for starting point 2024-12-04 20:33:15 +00:00
Kush Dave Jain f7f25319e3 Fixing testing dependencies 2024-12-04 14:34:01 +00:00
Kush Dave Jain 75fba59588 Readability metrics 2024-11-29 20:45:42 +00:00
Kush Dave Jain 280baa24e2 Licensing 2024-11-29 20:32:46 +00:00
Kush Dave Jain c6206f5cf2 Initial pass for TestGenEval 2024-11-29 18:29:08 +00:00
Kush Dave Jain 7a4729c034 initial TestGenEval code 2024-11-29 17:57:57 +00:00
95 changed files with 12147 additions and 531 deletions
+2 -2
View File
@@ -12,7 +12,7 @@
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue"></a>
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
<br/>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community"></a>
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits"></a>
<br/>
@@ -96,7 +96,7 @@ troubleshooting resources, and advanced configuration options.
OpenHands is a community-driven project, and we welcome contributions from everyone. We do most of our communication
through Slack, so this is the best place to start, but we also are happy to have you contact us on Discord or Github:
- [Join our Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg) - Here we talk about research, architecture, and future development.
- [Join our Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw) - Here we talk about research, architecture, and future development.
- [Join our Discord server](https://discord.gg/ESHStjSjD4) - This is a community-run server for general discussion, questions, and feedback.
- [Read or post Github Issues](https://github.com/All-Hands-AI/OpenHands/issues) - Check out the issues we're working on, or add your own ideas.
@@ -42,7 +42,7 @@ Explorez le code source d'OpenHands sur [GitHub](https://github.com/All-Hands-AI
/>
</a>
<br></br>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg">
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw">
<img
src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge"
alt="Join our Slack community"
@@ -42,7 +42,7 @@ OpenHands 是一个**自主 AI 软件工程师**,能够执行复杂的工程
/>
</a>
<br></br>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg">
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw">
<img
src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge"
alt="Join our Slack community"
+1 -1
View File
@@ -8,7 +8,7 @@ function CustomFooter() {
<footer className="custom-footer">
<div className="footer-content">
<div className="footer-icons">
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg" target="_blank" rel="noopener noreferrer">
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw" target="_blank" rel="noopener noreferrer">
<FaSlack />
</a>
<a href="https://discord.gg/ESHStjSjD4" target="_blank" rel="noopener noreferrer">
@@ -46,7 +46,7 @@ export function HomepageHeader() {
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" /></a>
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License" /></a>
<br/>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community" /></a>
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community" /></a>
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community" /></a>
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits" /></a>
<br/>
+14 -1
View File
@@ -1,3 +1,4 @@
import copy
import json
import os
import subprocess
@@ -175,6 +176,11 @@ def process_instance(
logger.warning(
f'This is the {runtime_failure_count + 1}th attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
)
metadata = copy.deepcopy(metadata)
metadata.details['runtime_failure_count'] = runtime_failure_count
metadata.details['remote_runtime_resource_factor'] = (
config.sandbox.remote_runtime_resource_factor
)
try:
runtime = create_runtime(config)
@@ -296,14 +302,20 @@ def process_instance(
with open(test_output_path, 'w') as f:
f.write(test_output)
try:
extra_kwargs = {}
if 'SWE-Gym' in metadata.dataset:
# SWE-Gym uses a different version of the package, hence a different eval report argument
extra_kwargs['log_path'] = test_output_path
else:
extra_kwargs['test_log_path'] = test_output_path
_report = conditional_imports.get_eval_report(
test_spec=test_spec,
prediction={
'model_patch': model_patch,
'instance_id': instance_id,
},
test_log_path=test_output_path,
include_tests_status=True,
**extra_kwargs,
)
report = _report[instance_id]
logger.info(
@@ -463,6 +475,7 @@ if __name__ == '__main__':
.decode('utf-8')
.strip(), # Current commit
dataset=args.dataset, # Dataset name from args
details={},
)
# The evaluation harness constrains the signature of `process_instance_func` but we need to
File diff suppressed because it is too large Load Diff
@@ -23,7 +23,7 @@ def get_resource_mapping(dataset_name: str) -> dict[str, float]:
if dataset_name not in _global_resource_mapping:
file_path = os.path.join(CUR_DIR, f'{dataset_name}.json')
if not os.path.exists(file_path):
logger.warning(f'Resource mapping for {dataset_name} not found.')
logger.info(f'Resource mapping for {dataset_name} not found.')
return None
with open(file_path, 'r') as f:
+23 -19
View File
@@ -1,4 +1,5 @@
import asyncio
import copy
import json
import os
import tempfile
@@ -149,7 +150,8 @@ def get_config(
) -> AppConfig:
# We use a different instance image for the each instance of swe-bench eval
use_official_image = bool(
'verified' in metadata.dataset.lower() or 'lite' in metadata.dataset.lower()
('verified' in metadata.dataset.lower() or 'lite' in metadata.dataset.lower())
and 'swe-gym' not in metadata.dataset.lower()
)
base_container_image = get_instance_docker_image(
instance['instance_id'], use_official_image
@@ -475,6 +477,13 @@ def process_instance(
logger.warning(
f'This is the {runtime_failure_count + 1}th attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
)
metadata = copy.deepcopy(metadata)
metadata.details['runtime_failure_count'] = runtime_failure_count
metadata.details['remote_runtime_resource_factor'] = (
config.sandbox.remote_runtime_resource_factor
)
runtime = create_runtime(config)
call_async_from_sync(runtime.connect)
@@ -560,20 +569,6 @@ def filter_dataset(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
return dataset
# A list of instances that are known to be tricky to infer
# (will cause runtime failure even with resource factor = 8)
SWEGYM_EXCLUDE_IDS = [
'dask__dask-10422',
'pandas-dev__pandas-50548',
'pandas-dev__pandas-53672',
'pandas-dev__pandas-54174',
'pandas-dev__pandas-55518',
'pandas-dev__pandas-58383',
'pydata__xarray-6721',
'pytest-dev__pytest-10081',
'pytest-dev__pytest-7236',
]
if __name__ == '__main__':
parser = get_parser()
parser.add_argument(
@@ -598,11 +593,20 @@ if __name__ == '__main__':
f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks'
)
if 'SWE-Gym' in args.dataset:
swe_bench_tests = swe_bench_tests[
~swe_bench_tests['instance_id'].isin(SWEGYM_EXCLUDE_IDS)
]
with open(
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'split',
'swegym_verified_instances.json',
),
'r',
) as f:
swegym_verified_instances = json.load(f)
swe_bench_tests = swe_bench_tests[
swe_bench_tests['instance_id'].isin(swegym_verified_instances)
]
logger.info(
f'{len(swe_bench_tests)} tasks left after excluding SWE-Gym excluded tasks'
f'{len(swe_bench_tests)} tasks left after filtering for SWE-Gym verified instances'
)
llm_config = None
@@ -9,7 +9,7 @@ parser.add_argument(
'--dataset_name',
type=str,
help='Name of the dataset to download',
default='princeton-nlp/SWE-bench_Lite',
default='princeton-nlp/SWE-bench_Verified',
)
parser.add_argument('--split', type=str, help='Split to download', default='test')
args = parser.parse_args()
@@ -20,7 +20,12 @@ print(
f'Downloading gold patches from {args.dataset_name} (split: {args.split}) to {output_filepath}'
)
patches = [
{'instance_id': row['instance_id'], 'model_patch': row['patch']} for row in dataset
{
'instance_id': row['instance_id'],
'model_patch': row['patch'],
'model_name_or_path': 'gold',
}
for row in dataset
]
print(f'{len(patches)} gold patches loaded')
pd.DataFrame(patches).to_json(output_filepath, lines=True, orient='records')
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,12 @@
codamosa_ids = ['pydata__xarray-4750-16496', 'pydata__xarray-3239-16458', 'pydata__xarray-4966-16515', 'pydata__xarray-3302-16459', 'pydata__xarray-5126-16518', 'pydata__xarray-4994-16516', 'pydata__xarray-3905-16478', 'pydata__xarray-4182-16484', 'pydata__xarray-5131-16520', 'pydata__xarray-5662-16532', 'pydata__xarray-3364-16461', 'pydata__xarray-5731-16534', 'pydata__xarray-3239-16457', 'pydata__xarray-7203-16577', 'pydata__xarray-3156-16454', 'pydata__xarray-5126-16519', 'pydata__xarray-5365-16529', 'pydata__xarray-4629-16492', 'pydata__xarray-4248-16486', 'pydata__xarray-4339-16487', 'pydata__xarray-3151-16453', 'pydata__xarray-3114-16452', 'pydata__xarray-5033-16517', 'pydata__xarray-4802-16505', 'pydata__xarray-5455-16530', 'pydata__xarray-6400-16539', 'pydata__xarray-3239-16456', 'pydata__xarray-4419-16488']
pynguin_ids = ['pydata__xarray-6548-16541', 'pydata__xarray-7003-16557', 'pydata__xarray-3114-16452', 'pydata__xarray-4339-16487', 'pydata__xarray-6889-16549', 'pydata__xarray-3239-16458', 'pydata__xarray-3364-16461', 'pydata__xarray-3239-16457', 'pydata__xarray-5365-16529', 'pydata__xarray-5131-16520', 'pydata__xarray-7229-16578', 'pydata__xarray-6461-16540', 'pydata__xarray-4419-16488', 'pydata__xarray-7147-16571', 'pydata__xarray-3151-16453', 'pydata__xarray-4966-16515', 'pydata__xarray-4629-16492', 'pydata__xarray-3239-16456', 'pydata__xarray-7400-16582', 'pydata__xarray-4994-16516', 'pydata__xarray-3302-16459', 'pydata__xarray-6601-16544', 'pydata__xarray-6882-16548', 'pydata__xarray-6135-16535', 'pydata__xarray-7393-16581', 'pydata__xarray-5731-16534', 'pydata__xarray-7203-16577']
ids = ['pydata__xarray-3114-16452', 'pydata__xarray-3151-16453', 'pydata__xarray-3156-16454', 'pydata__xarray-3239-16456', 'pydata__xarray-3239-16457', 'pydata__xarray-3239-16458', 'pydata__xarray-3302-16459', 'pydata__xarray-3364-16461', 'pydata__xarray-3677-16471', 'pydata__xarray-3905-16478', 'pydata__xarray-4182-16484', 'pydata__xarray-4248-16486', 'pydata__xarray-4339-16487', 'pydata__xarray-4419-16488', 'pydata__xarray-4629-16492', 'pydata__xarray-4750-16496', 'pydata__xarray-4802-16505', 'pydata__xarray-4966-16515', 'pydata__xarray-4994-16516', 'pydata__xarray-5033-16517', 'pydata__xarray-5126-16518', 'pydata__xarray-5126-16519', 'pydata__xarray-5131-16520', 'pydata__xarray-5365-16529', 'pydata__xarray-5455-16530', 'pydata__xarray-5662-16532', 'pydata__xarray-5731-16534', 'pydata__xarray-6135-16535', 'pydata__xarray-6135-16536', 'pydata__xarray-6386-16537', 'pydata__xarray-6394-16538', 'pydata__xarray-6400-16539', 'pydata__xarray-6461-16540', 'pydata__xarray-6548-16541', 'pydata__xarray-6599-16543', 'pydata__xarray-6601-16544', 'pydata__xarray-6882-16548', 'pydata__xarray-6889-16549', 'pydata__xarray-7003-16557', 'pydata__xarray-7147-16571', 'pydata__xarray-7150-16572', 'pydata__xarray-7203-16577', 'pydata__xarray-7229-16578', 'pydata__xarray-7393-16581', 'pydata__xarray-7400-16582']
Command eval (our approach):
poetry run ./evaluation/benchmarks/testgeneval/scripts/eval_infer_remote.sh evaluation/evaluation_outputs/outputs/kjain14__testgeneval-test/CodeActAgent/gpt-4o_maxiter_25_N_v0.20.0-no-hint-run_1/output.jsonl 10 kjain14/testgeneval test true
Command run (our approach):
./evaluation/benchmarks/testgeneval/scripts/run_infer.sh llm.eval_gpt HEAD CodeActAgent -1 25 10 kjain14/testgeneval test 1 ../TestGenEval/results/testgeneval/preds/gpt-4o-2024-08-06__testgeneval__0.2__test.jsonl
@@ -0,0 +1,80 @@
# TestGenEval Benchmark Evaluation
This folder contains the evaluation harness for the TestGenEval benchmark, which is based on the original TestGenEval benchmark ([paper](https://arxiv.org/abs/2410.00752)). TestGenEval is designed to evaluate the ability of language models to generate unit tests for given Python functions.
## Setup Environment and LLM Configuration
1. Follow the instructions [here](../../README.md#setup) to set up your local development environment and configure your LLM.
2. Install the TestGenEval dependencies:
```bash
poetry install --with testgeneval
```
## Run Inference
To generate tests using your model, run the following command:
```bash
./evaluation/benchmarks/testgeneval/scripts/run_infer.sh [model_config] [git-version] [agent] [eval_limit] [max_iter] [num_workers] [dataset] [dataset_split]
# Example
./evaluation/benchmarks/testgeneval/scripts/run_infer.sh llm.eval_gpt4_1106_preview HEAD CodeActAgent 100 30 1 kjain14/testgenevallite test
```
Parameters:
- `model_config`: The config group name for your LLM settings (e.g., `eval_gpt4_1106_preview`)
- `git-version`: The git commit hash or release tag of OpenHands to evaluate (e.g., `HEAD` or `0.6.2`)
- `agent`: The name of the agent for benchmarks (default: `CodeActAgent`)
- `eval_limit`: Limit the evaluation to the first N instances (optional)
- `max_iter`: Maximum number of iterations for the agent to run (default: 30)
- `num_workers`: Number of parallel workers for evaluation (default: 1)
- `dataset`: HuggingFace dataset name (default: `kjain14/testgenevallite`)
- `dataset_split`: Dataset split to use (default: `test`)
After running the inference, you will obtain an `output.jsonl` file (by default saved to `evaluation/evaluation_outputs`).
## Evaluate Generated Tests
To evaluate the generated tests, use the `eval_infer.sh` script:
```bash
./evaluation/benchmarks/testgeneval/scripts/eval_infer.sh $YOUR_OUTPUT_JSONL [instance_id] [dataset_name] [split] [num_workers] [skip_mutation]
# Example
./evaluation/benchmarks/testgeneval/scripts/eval_infer.sh evaluation/evaluation_outputs/outputs/testgeneval/CodeActAgent/gpt-4-1106-preview_maxiter_50_N_v1.0/output.jsonl
```
Optional arguments:
- `instance_id`: Evaluate a single instance (optional)
- `dataset_name`: Name of the dataset to use (default: `kjain14/testgenevallite`)
- `split`: Dataset split to use (default: `test`)
- `num_workers`: Number of workers for running docker (default: 1)
- `skip_mutation`: Skip mutation testing (enter `true` if desired)
The evaluation results will be saved to `evaluation/evaluation_outputs/outputs/testgeneval/CodeActAgent/gpt-4-1106-preview_maxiter_50_N_v1.0/` with `output.testgeneval.jsonl` containing the metrics.
## Metrics
The TestGenEval benchmark evaluates generated tests based on the following metrics:
1. Correctness: Measures if the generated tests are syntactically correct and run without errors.
2. Coverage: Assesses the code coverage achieved by the generated tests.
3. Mutation Score: Evaluates the effectiveness of the tests in detecting intentionally introduced bugs (mutations).
4. Readability: Analyzes the readability of the generated tests using various metrics.
## Submit Your Evaluation Results
To contribute your evaluation results:
1. Fork [our HuggingFace evaluation outputs](https://huggingface.co/spaces/OpenHands/evaluation).
2. Add your results to the forked repository.
3. Submit a Pull Request with your evaluation results following the guide [here](https://huggingface.co/docs/hub/en/repositories-pull-requests-discussions#pull-requests-and-discussions).
## Additional Resources
- [TestGenEval Paper](https://arxiv.org/abs/2410.00752)
- [OpenHands Documentation](https://github.com/All-Hands-AI/OpenHands)
- [HuggingFace Datasets](https://huggingface.co/datasets)
For any questions or issues, please open an issue in the [OpenHands repository](https://github.com/All-Hands-AI/OpenHands/issues).
@@ -0,0 +1,356 @@
import math
import os
from pathlib import Path
from tree_sitter import Language, Parser
def total_byte_entropy_stats(python_code):
# Count the occurrence of each byte (character for simplicity)
byte_counts = {}
for byte in python_code.encode('utf-8'):
byte_counts[byte] = byte_counts.get(byte, 0) + 1
total_bytes = sum(byte_counts.values())
entropy = -sum(
(count / total_bytes) * math.log2(count / total_bytes)
for count in byte_counts.values()
)
return {'total_byte_entropy': entropy}
def average_nulls_stats(tree, num_lines):
total_nulls = 0
nulls_per_line = {} # Dictionary to count nulls per line
def traverse(node):
nonlocal total_nulls
if node.type == 'null_literal':
total_nulls += 1
line_number = node.start_point[0] # Get line number
if line_number in nulls_per_line:
nulls_per_line[line_number] += 1
else:
nulls_per_line[line_number] = 1
for child in node.children:
traverse(child)
traverse(tree.root_node)
# Calculate average nulls per line
avg_nulls = total_nulls / num_lines if num_lines > 0 else 0
# Calculate max nulls on any line
max_nulls_on_any_line = max(nulls_per_line.values()) if nulls_per_line else 0
return {
'avg_nulls': avg_nulls,
'total_nulls': total_nulls,
'max_nulls': max_nulls_on_any_line,
'has_nulls': 1 if total_nulls > 0 else 0,
}
def arithmetic_operations_stats(tree, num_lines):
# Dictionary to hold counts of each arithmetic operation
op_counts = {'+': 0, '-': 0, '*': 0, '/': 0, '%': 0}
total_ops = 0
# Function to traverse the AST and update operation counts
def traverse(node):
nonlocal total_ops
if node.type == 'binary_expression' or node.type == 'update_expression':
for child in node.children:
if child.type == 'operator':
op = child.text.decode('utf8')
if op in op_counts:
op_counts[op] += 1
total_ops += 1
else:
for child in node.children:
traverse(child)
traverse(tree.root_node)
return {
'total_arithmetic_operations': total_ops,
'avg_arithmetic_operations': total_ops / num_lines,
}
def numbers_floats_stats(tree, num_lines):
total_numbers = 0
total_floats = 0
def traverse(node):
nonlocal total_numbers, total_floats
if node.type in ['integer_literal', 'decimal_literal']:
total_numbers += 1
if (
'.' in node.text.decode('utf8')
or 'e' in node.text.decode('utf8').lower()
):
total_floats += 1
for child in node.children:
traverse(child)
traverse(tree.root_node)
return {'total_numbers': total_numbers, 'total_floats': total_floats}
def code_stats(python_code):
lines = python_code.strip().split('\n')
total_line_length = sum(len(line) for line in lines)
max_line_length = max(len(line) for line in lines)
return {
'total_line_length': total_line_length,
'max_line_length': max_line_length,
'avg_characters': total_line_length / len(lines),
}
def assertions_stats(tree, num_lines):
total_assertions = 0
def traverse(node):
nonlocal total_assertions
if node.type == 'assert_statement':
total_assertions += 1
for child in node.children:
traverse(child)
traverse(tree.root_node)
return {
'total_assertions': total_assertions,
'total_has_assertions': 1 if total_assertions > 0 else 0,
}
def class_instances_stats(tree, num_lines):
total_class_instances = 0
def traverse(node):
nonlocal total_class_instances
if node.type == 'object_creation_expression':
total_class_instances += 1
for child in node.children:
traverse(child)
traverse(tree.root_node)
return {'total_class_instances': total_class_instances}
def has_execeptions(tree, num_lines):
total_has_exceptions = 0
def traverse(node):
nonlocal total_has_exceptions
if node.type == 'try_statement':
total_has_exceptions += 1
for child in node.children:
traverse(child)
traverse(tree.root_node)
return {'total_has_exceptions': 1 if total_has_exceptions > 0 else 0}
def distinct_methods_stats(tree, num_lines):
method_names = set()
total_nodes = 0
def traverse(node):
nonlocal total_nodes
if node.type == 'method_declaration':
for child in node.children:
if child.type == 'identifier':
method_names.add(child.text.decode('utf8'))
break
total_nodes += 1
for child in node.children:
traverse(child)
traverse(tree.root_node)
total_distinct_methods = len(method_names)
total_method_ratio = (
total_distinct_methods / (total_nodes - total_distinct_methods)
if total_nodes > total_distinct_methods
else 0
)
return {
'total_distinct_methods': total_distinct_methods,
'total_method_ratio': total_method_ratio,
}
def loops_stats(tree, num_lines):
"""
Calculate the average number of loops.
"""
total_loops = 0
def traverse(node):
nonlocal total_loops
if node.type in ['for_statement', 'while_statement', 'do_statement']:
total_loops += 1
for child in node.children:
traverse(child)
traverse(tree.root_node)
avg_loops = total_loops / num_lines
return {'avg_loops': avg_loops}
def branches_stats(tree, num_lines):
"""
Calculate the average number of branches (conditional statements).
"""
total_branches = 0
def traverse(node):
nonlocal total_branches
if node.type in ['if_statement', 'switch_statement']:
total_branches += 1
for child in node.children:
traverse(child)
traverse(tree.root_node)
# Assuming each branch is its own, this might need refinement based on definition
avg_branches = total_branches / num_lines
return {'avg_branches': avg_branches}
def string_stats(tree, num_lines):
string_literals = []
# Function to traverse the AST and collect string literals
def traverse(node):
if node.type == 'string_literal':
# Extracting the string literal, excluding the quotation marks
literal_text = node.text.decode('utf8')[1:-1]
string_literals.append(literal_text)
for child in node.children:
traverse(child)
traverse(tree.root_node)
# Calculate the average string length
total_length = sum(len(s) for s in string_literals)
avg_length = total_length / num_lines
return {'avg_str_length': avg_length}
def identifier_stats(tree, num_lines):
root_node = tree.root_node
identifier_counts = {} # Dictionary to count occurrences of each identifier
total_nodes = 0 # Counter for all nodes
# Function to recursively count identifiers and all nodes, gathering their stats
def count(node):
nonlocal identifier_counts, total_nodes
iden_count = 0
max_length = 0
total_nodes += 1 # Increment total nodes for every node visited
if node.type == 'identifier':
identifier = node.text.decode('utf8') # Assuming UTF-8 encoding
iden_count += 1
identifier_counts[identifier] = identifier_counts.get(identifier, 0) + 1
iden_length = len(identifier)
if iden_length > max_length:
max_length = iden_length
for child in node.children:
child_count, child_max_length = count(child)
iden_count += child_count
if child_max_length > max_length:
max_length = child_max_length
return iden_count, max_length
total_identifiers, max_identifier_length = count(root_node)
total_unique_identifiers = len(identifier_counts)
total_identifier_length = sum(len(k) * v for k, v in identifier_counts.items())
avg_identifier_length = total_identifier_length / num_lines
# Calculate the identifier ratio as total identifiers over total nodes
identifier_ratio = total_identifiers / total_nodes if total_nodes > 0 else 0
return {
'total_identifiers': total_identifiers,
'total_identifier_length': total_identifier_length,
'max_identifier_length': max_identifier_length,
'avg_identifier_length': avg_identifier_length,
'total_unique_identifiers': total_unique_identifiers,
'total_identifier_ratio': identifier_ratio, # Include the new ratio in the returned dictionary
'total_nodes': total_nodes, # Include total node count for reference or further calculations
}
def compute_regression(results):
components = {
'total_line_length': -0.0001,
'max_line_length': -0.0021,
'total_identifiers': 0.0076,
'total_identifier_length': -0.0004,
'max_identifier_length': -0.0067,
'avg_identifier_length': -0.005,
'avg_arithmetic_operations': 0.0225,
'avg_branches': 0.9886,
'avg_loops': 0.1572,
'total_assertions': 0.0119,
'total_has_assertions': -0.0147,
'avg_characters': 0.1242,
'total_class_instances': -0.043,
'total_distinct_methods': -0.0127,
'avg_str_length': 0.0026,
'total_has_exceptions': 0.1206,
'total_unique_identifiers': -0.019,
'max_nulls': -0.0712,
'total_numbers': -0.0078,
'avg_nulls': 0.1444,
'total_identifier_ratio': 0.334,
'total_method_ratio': 0.0406,
'total_floats': -0.0174,
'total_byte_entropy': -0.3917,
}
test_score = 0
for component in components:
test_score += components[component] * results[component]
test_score += 5.7501
return test_score
def compute_readability(python_code):
parser = Parser()
this_dir = Path(os.path.dirname(os.path.realpath(__file__)))
parser.set_language(Language.build_library(
# Store the library in the `build` directory
this_dir / "build" / "my-languages.so",
# Include one or more languages
[
this_dir / "tree-sitter-python"
]
).get_language('python'))
results = code_stats(python_code)
num_lines = len(python_code.strip().split('\n'))
results.update(total_byte_entropy_stats(python_code))
tree = parser.parse(bytes(python_code, 'utf8'))
results.update(identifier_stats(tree, num_lines))
results.update(loops_stats(tree, num_lines))
results.update(branches_stats(tree, num_lines))
results.update(distinct_methods_stats(tree, num_lines))
results.update(has_execeptions(tree, num_lines))
results.update(class_instances_stats(tree, num_lines))
results.update(assertions_stats(tree, num_lines))
results.update(numbers_floats_stats(tree, num_lines))
results.update(average_nulls_stats(tree, num_lines))
results.update(arithmetic_operations_stats(tree, num_lines))
results.update(string_stats(tree, num_lines))
score = compute_regression(results)
return score
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,633 @@
import os
import tempfile
import time
from functools import partial
import pandas as pd
from report_utils import (
check_coverage,
check_mutation,
count_methods,
get_lines_of_code,
)
from evaluation.benchmarks.testgeneval.compute_readability import compute_readability
from evaluation.benchmarks.testgeneval.constants import (
COVERAGE_PREFIX,
MUTATION_BUFFER,
MUTATION_TEMPLATE,
MUTATION_TIMEOUT,
TESTS_SUFFIX,
)
from evaluation.benchmarks.testgeneval.metrics import (
bleu,
code_bleu,
edit_sim,
exact_match,
rouge_l,
)
from evaluation.benchmarks.testgeneval.pygments_utils import tokenize_code
from evaluation.benchmarks.testgeneval.run_infer import get_instance_docker_image
from evaluation.benchmarks.testgeneval.test_filter import filter_tests
from evaluation.benchmarks.testgeneval.test_spec import (
TestGenEvalInstance,
TestSpec,
make_test_spec,
)
from evaluation.benchmarks.testgeneval.utils import load_testgeneval_dataset
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
)
from openhands.core.config import AppConfig, SandboxConfig, get_parser
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime
from openhands.events.action import CmdRunAction
from openhands.events.observation import CmdOutputObservation
from openhands.utils.async_utils import call_async_from_sync
DOCKER_IMAGE_PREFIX = os.environ.get('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/kdjain/')
logger.info(f'Using docker image prefix: {DOCKER_IMAGE_PREFIX}')
def get_config(instance: pd.Series) -> AppConfig:
base_container_image = get_instance_docker_image(instance['instance_id_swebench'])
assert (
base_container_image
), f"Invalid container image for instance {instance['instance_id_swebench']}."
logger.info(f'Using instance container image: {base_container_image}.')
return AppConfig(
run_as_openhands=False,
runtime=os.environ.get('RUNTIME', 'eventstream'),
sandbox=SandboxConfig(
base_container_image=base_container_image,
use_host_network=False,
timeout=1800,
api_key=os.environ.get('ALLHANDS_API_KEY'),
remote_runtime_api_url=os.environ.get(
'SANDBOX_REMOTE_RUNTIME_API_URL', 'http://localhost:8000'
),
),
workspace_base=None,
workspace_mount_path=None,
)
def compute_lexical_metrics(pred_suite, gold_suite):
pred_loc = get_lines_of_code(pred_suite)
gold_loc = get_lines_of_code(gold_suite)
pred_methods = count_methods(pred_suite)
gold_methods = count_methods(gold_suite)
readability_pred = compute_readability(pred_suite)
readability_gold = compute_readability(gold_suite)
preds = tokenize_code(pred_suite)
golds = tokenize_code(gold_suite)
return {
'pred_loc': pred_loc,
'gold_loc': gold_loc,
'pred_readability': readability_pred,
'gold_readability': readability_gold,
'pred_methods': pred_methods,
'gold_methods': gold_methods,
'code_bleu': code_bleu(preds, golds, 'Python3'),
'bleu': bleu(preds, golds),
'xmatch': exact_match(preds, golds),
'edit_sim': edit_sim(preds, golds),
'rouge_f': rouge_l(golds, preds)['f'],
'rouge_p': rouge_l(golds, preds)['p'],
'rouge_r': rouge_l(golds, preds)['r'],
}
def run_command(runtime, command, timeout=600):
action = CmdRunAction(command=command)
action.set_hard_timeout(timeout)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert obs.exit_code == 0
return obs
def run_tests(runtime, instance, test_script, log_file='/tmp/test_output.log'):
action = CmdRunAction(command=f'bash {test_script} > {log_file} 2>&1 & echo $!')
action.set_hard_timeout(60)
obs = runtime.run_action(action)
assert isinstance(obs, CmdOutputObservation), 'Failed to start test script.'
pid = obs.content.split()[-1].strip()
logger.info(f'[{instance.instance_id}] Test process started with PID: {pid}')
start_time = time.time()
timeout = 1800
while True:
elapsed_time = time.time() - start_time
if elapsed_time > timeout:
logger.info(f'[{instance.instance_id}] Test process timed out.')
instance['test_result']['report']['test_timeout'] = True
break
check_action = CmdRunAction(command=f'ps -p {pid} > /dev/null; echo $?')
check_obs = runtime.run_action(check_action)
if (
isinstance(check_obs, CmdOutputObservation)
and len(check_obs.content.split()) > 0
and check_obs.content.split()[-1].strip() == '1'
):
logger.info(f'[{instance.instance_id}] Test process completed.')
break
time.sleep(30)
test_action = CmdRunAction(command=f'cat {log_file}')
test_action.set_hard_timeout(300)
test_obs = runtime.run_action(test_action)
assert isinstance(test_obs, CmdOutputObservation), 'Failed to retrieve test output.'
return test_obs.exit_code, test_obs.content, elapsed_time
def run_mutation_testing(
runtime, instance, mutation_script, log_file='/tmp/mutation_output.log'
):
action = CmdRunAction(command=f'bash {mutation_script} > {log_file} 2>&1 & echo $!')
action.set_hard_timeout(60)
obs = runtime.run_action(action)
assert isinstance(obs, CmdOutputObservation), 'Failed to start test script.'
pid = obs.content.split()[-1].strip()
logger.info(f'[{instance.instance_id}] Mutation process started with PID: {pid}')
start_time = time.time()
timeout = 4000
while True:
elapsed_time = time.time() - start_time
if elapsed_time > timeout:
logger.info(f'[{instance.instance_id}] Mutation process timed out.')
instance['test_result']['report']['mutation_timeout'] = True
break
check_action = CmdRunAction(command=f'ps -p {pid} > /dev/null; echo $?')
check_obs = runtime.run_action(check_action)
if (
isinstance(check_obs, CmdOutputObservation)
and len(check_obs.content.split()) > 0
and check_obs.content.split()[-1].strip() == '1'
):
logger.info(f'[{instance.instance_id}] Mutation process completed.')
break
time.sleep(30)
assert isinstance(obs, CmdOutputObservation), 'Failed to run mutation script.'
mutation_action = CmdRunAction(command=f'cat {log_file}')
mutation_action.set_hard_timeout(300)
mutation_obs = runtime.run_action(mutation_action)
assert isinstance(
mutation_obs, CmdOutputObservation
), 'Failed to retrieve mutation output.'
return mutation_obs.exit_code, mutation_obs.content
def grade_test_output(
test_suite: str, instance: pd.Series, test_output: str, test_spec: TestSpec, runtime
):
"""
Two-pass test grading with short-circuiting:
1. Run all tests to identify passing/failing tests
2. If no failing tests, evaluate coverage immediately
3. Otherwise, run only passing tests for coverage analysis
"""
unit_test_output, coverage_output = '', ''
if TESTS_SUFFIX in test_output:
unit_test_output = test_output.split(TESTS_SUFFIX)[0]
if not unit_test_output:
return (
False,
0,
'',
'',
{
'total_tests': 0,
'passing_tests': 0,
'failing_tests': 0,
'any_pass': False,
'all_pass': False,
'passing_test_names': [],
'failing_test_names': [],
},
)
logger.info('Calling filter unit tests')
filtered_content, passing_tests, failing_tests = filter_tests(
test_suite, unit_test_output, test_spec.repo
)
total_tests = len(passing_tests) + len(failing_tests)
test_stats = {
'total_tests': total_tests,
'passing_tests': len(passing_tests),
'failing_tests': len(failing_tests),
'any_pass': len(passing_tests) > 0,
'all_pass': len(failing_tests) == 0 and total_tests > 0,
'passing_test_names': passing_tests,
'failing_test_names': failing_tests,
}
if not passing_tests:
return False, 0, unit_test_output, coverage_output, test_stats
# If all tests pass, evaluate coverage immediately
if not failing_tests:
coverage = 0
cov_success = False
if COVERAGE_PREFIX in test_output:
coverage_output = test_output.split(COVERAGE_PREFIX)[1]
_, coverage = check_coverage(coverage_output, test_spec.code_file)
cov_success = True
# test_stats['filtered_suite'] = test_suite
return cov_success, coverage, unit_test_output, coverage_output, test_stats
cov_success = False
coverage = 0
# Second pass - run coverage on passing tests
if filtered_content:
with tempfile.TemporaryDirectory() as temp_dir:
test_suite_path = os.path.join(temp_dir, 'test_suite.py')
with open(test_suite_path, 'w') as f:
f.write(filtered_content)
runtime.copy_to(test_suite_path, '/tmp')
run_command(runtime, f'cp /tmp/test_suite.py /testbed/{test_spec.test_file}')
_, test_output_second_pass, _ = run_tests(runtime, instance, '/tmp/test.sh')
coverage, coverage_output, unit_test_output = 0, '', test_output_second_pass
if COVERAGE_PREFIX in test_output_second_pass:
coverage_output = test_output_second_pass.split(COVERAGE_PREFIX)[1]
unit_test_output = test_output_second_pass.split(TESTS_SUFFIX)[0]
_, coverage = check_coverage(coverage_output, test_spec.code_file)
cov_success = True
# test_stats['filtered_suite'] = filtered_content
return cov_success, coverage, unit_test_output, coverage_output, test_stats
def process_instance(
instance: pd.Series,
metadata: EvalMetadata,
reset_logger: bool = True,
log_dir: str | None = None,
) -> EvalOutput:
"""
Evaluate agent performance on a TestGenEval problem instance.
Note that this signature differs from the expected input to `run_evaluation`. Use
`functools.partial` to provide optional arguments before passing to the evaluation harness.
Args:
log_dir (str | None, default=None): Path to directory where log files will be written. Must
be provided if `reset_logger` is set.
Raises:
AssertionError: if the `reset_logger` flag is set without a provided log directory.
"""
if reset_logger:
assert (
log_dir is not None
), "Can't reset logger without a provided log directory."
os.makedirs(log_dir, exist_ok=True)
reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)
else:
logger.info(f'Starting evaluation for instance {instance.instance_id}.')
config = get_config(instance)
id = instance.instance_id
logger.info(f'Starting evaluation for instance {id}.')
instance['test_result']['id'] = id
instance['test_result']['report'] = {
'test_output': '',
# 'coverage_output': '',
# 'mutation_output': '',
'empty_generation': False,
'error_eval': False,
'all_tests_pass': False,
'tests_pass': False,
'test_timeout': False,
'mutation_timeout': False,
'coverage_success': False,
'mutation_success': False,
'coverage': 0,
'mutation_score': 0,
'mutation_error_interval': -1,
'num_mutants': -1,
}
instance['test_result']['lexical'] = {
'pred_loc': -1,
'gold_loc': -1,
'pred_readability': -1,
'gold_readability': -1,
'pred_methods': -1,
'gold_methods': -1,
'code_bleu': -1,
'bleu': -1,
'xmatch': -1,
'edit_sim': -1,
'rouge_f': -1,
'rouge_p': -1,
'rouge_r': -1,
}
if instance['test_suite'] == '' or instance['test_suite'] is None:
instance['test_result']['report']['empty_generation'] = True
return EvalOutput(
instance_id=instance.instance_id, test_result=instance['test_result']
)
if not args.skip_lexical:
lexical_metrics = compute_lexical_metrics(
instance['test_suite'], instance['instance']['test_src']
)
instance['test_result']['lexical'] = lexical_metrics
test_suite = instance['test_suite']
test_spec: TestSpec = instance['test_spec']
runtime = create_runtime(config)
call_async_from_sync(runtime.connect)
with tempfile.TemporaryDirectory() as temp_dir:
test_suite_path = os.path.join(temp_dir, 'test_suite.py')
with open(test_suite_path, 'w') as f:
f.write(test_suite)
runtime.copy_to(test_suite_path, '/tmp')
test_script_path = os.path.join(temp_dir, 'test.sh')
with open(test_script_path, 'w') as f:
f.write(test_spec.test_script)
runtime.copy_to(test_script_path, '/tmp')
mutation_script_path = os.path.join(temp_dir, 'mutation.sh')
with open(mutation_script_path, 'w') as f:
f.write(test_spec.mutation_script)
runtime.copy_to(mutation_script_path, '/tmp')
try:
run_command(runtime, 'chmod +x /tmp/test.sh /tmp/mutation.sh')
run_command(runtime, f'cp /tmp/test_suite.py /testbed/{test_spec.test_file}')
# First pass - run all tests
_, test_output, test_time = run_tests(runtime, instance, '/tmp/test.sh')
# Grade tests with two-pass approach
coverage_success, coverage, unit_test_output, coverage_output, test_stats = (
grade_test_output(test_suite, instance, test_output, test_spec, runtime)
)
# Update report with test statistics
instance['test_result']['report'].update(
{
'test_output': unit_test_output,
# 'coverage_output': coverage_output,
'tests_pass': test_stats['any_pass'], # Changed to use any_pass
'all_tests_pass': test_stats['all_pass'], # Added all_pass metric
'coverage_success': coverage_success,
'coverage': coverage if coverage_success else 0,
'test_stats': test_stats,
}
)
# Only run mutation testing if we have passing tests and coverage
if (
not args.skip_mutation
and coverage_success
and test_stats['any_pass']
and coverage > 0
):
mutation_timeout = max(10, 1.5 * test_time)
mutation_toml = MUTATION_TEMPLATE.format(
test_cmd=test_spec.test_cmd,
source_fp=test_spec.code_file,
timeout=mutation_timeout,
)
with tempfile.TemporaryDirectory() as temp_dir:
mutation_toml_path = os.path.join(temp_dir, 'mutation.toml')
with open(mutation_toml_path, 'w') as f:
f.write(mutation_toml)
runtime.copy_to(mutation_toml_path, '/tmp')
run_command(runtime, 'cp /tmp/mutation.toml /testbed/mutation.toml')
mutation_code, mutation_output = run_mutation_testing(
runtime, instance, '/tmp/mutation.sh'
)
# instance['test_result']['report']['mutation_output'] = mutation_output
if mutation_output and mutation_code == 0:
(
mutation_success,
num_mutants,
mutation_score,
mutation_confidence_interval,
) = check_mutation(mutation_output)
instance['test_result']['report']['num_mutants'] = num_mutants
instance['test_result']['report']['mutation_success'] = mutation_success
instance['test_result']['report']['mutation_score'] = mutation_score
instance['test_result']['report']['mutation_error_interval'] = (
mutation_confidence_interval
)
return EvalOutput(
instance_id=instance.instance_id, test_result=instance['test_result']
)
except Exception as e:
logger.error(f'Error processing instance {instance.instance_id}: {e}')
raise RuntimeError(
instance.instance_id,
'Unexpected output...',
logger,
)
finally:
runtime.close()
def count_and_log_fields(evaluated_predictions, fields, key):
"""
Count and log the sum of specified fields in the evaluated predictions,
ignoring fields with a value of -1. If all values for a field are -1,
return -1.
:param evaluated_predictions: DataFrame containing evaluation results
:param fields: List of field names to count
:param key: Key to access the field values ('report' or 'lexical')
"""
def count_field(row, field):
value = row['test_result'][key][field]
return (
value if value != -1 else None
) # Ignore -1 fields by treating them as None
for field in fields:
# Extract the valid values for the field, ignoring -1
valid_values = evaluated_predictions.apply(
count_field, args=(field,), axis=1
).dropna()
if valid_values.empty: # If all values are -1
logger.info(f'# {field}: -1 (All values are -1)')
else:
count = valid_values.sum() # Sum of valid values
length = len(valid_values) # Count of valid entries
logger.info(f'# {field}: {length}. ({count / length:.2f})')
if __name__ == '__main__':
parser = get_parser()
parser.add_argument(
'--input-file', type=str, required=True, help='Path to input predictions file'
)
parser.add_argument(
'--dataset',
type=str,
default='kjain14/testgeneval',
help='Dataset to evaluate on',
)
parser.add_argument(
'--split', type=str, default='test', help='Split to evaluate on'
)
parser.add_argument(
'--skip_mutation', action='store_true', help='Skip mutation testing'
)
parser.add_argument(
'--skip_lexical', action='store_true', help='Skip lexical metrics'
)
parser.add_argument(
'--mutation_timeout',
type=int,
default=MUTATION_TIMEOUT,
help='Mutation timeout',
)
parser.add_argument(
'--mutation_buffer',
type=int,
default=MUTATION_BUFFER,
help='Mutation buffer',
)
args, _ = parser.parse_known_args()
dataset: list[TestGenEvalInstance] = load_testgeneval_dataset(
args.dataset, args.split
)
logger.info(
f'Loaded dataset {args.dataset} with split {args.split} to run inference on.'
)
# Load predictions
assert args.input_file.endswith('.jsonl'), 'Input file must be a jsonl file.'
predictions = pd.read_json(args.input_file, lines=True)
assert (
'instance_id' in predictions.columns
), 'Input file must contain instance_id column.'
if 'test_suite' not in predictions.columns and (
'test_result' in predictions.columns
and 'test_suite' in predictions['test_result'].iloc(0)
):
raise ValueError(
'Input file must contain test_suite column OR test_result column with test_suite field.'
)
if 'instance_id_swebench' not in predictions.columns:
predictions['instance_id_swebench'] = predictions['instance'].apply(
lambda x: x['instance_id_swebench']
)
if 'instance_id' not in predictions.columns and (
'instance_id' in predictions['instance'].iloc(0)
):
raise ValueError(
'Input file must contain id column OR instance column with id field.'
)
if 'instance_id' not in predictions.columns:
predictions['instance_id'] = predictions['instance'].apply(
lambda x: x['instance_id']
)
if 'test_suite' not in predictions.columns:
predictions['test_suite'] = predictions['test_result'].apply(
lambda x: x['test_suite']
)
assert len(predictions['instance_id'].unique()) == len(
predictions
), 'instance_id column must be unique.'
assert {'instance_id_swebench', 'test_suite', 'instance_id'}.issubset(
set(predictions.columns)
), 'Input file must contain id, instance_id and test_suite columns.'
predictions['test_spec'] = predictions['instance'].apply(
lambda x: make_test_spec(x, args.mutation_timeout, args.mutation_buffer)
)
output_file = args.input_file.replace('.jsonl', '.testgeneval.jsonl')
instances = prepare_dataset(predictions, output_file, args.eval_n_limit)
# If possible, load the relevant metadata to avoid issues with `run_evaluation`.
metadata: EvalMetadata | None = None
metadata_filepath = os.path.join(os.path.dirname(args.input_file), 'metadata.json')
if os.path.exists(metadata_filepath):
with open(metadata_filepath, 'r') as metadata_file:
data = metadata_file.read()
metadata = EvalMetadata.model_validate_json(data)
# The evaluation harness constrains the signature of `process_instance_func` but we need to
# pass extra information. Build a new function object to avoid issues with multiprocessing.
process_instance_func = partial(
process_instance, log_dir=output_file.replace('.jsonl', '.logs')
)
run_evaluation(
instances,
metadata=None,
output_file=output_file,
num_workers=args.eval_num_workers,
process_instance_func=process_instance_func,
)
# Load evaluated predictions & print number of resolved predictions
evaluated_predictions = pd.read_json(output_file, lines=True)
report_fields = [
'coverage',
'mutation_score',
'tests_pass',
'all_tests_pass',
'empty_generation',
'coverage_success',
'test_timeout',
'error_eval',
]
lexical_fields = [
'pred_loc',
'gold_loc',
'pred_methods',
'gold_methods',
'code_bleu',
'bleu',
'xmatch',
'edit_sim',
'rouge_f',
'rouge_p',
'rouge_r',
]
# Log report and lexical fields
count_and_log_fields(evaluated_predictions, report_fields, key='report')
count_and_log_fields(evaluated_predictions, lexical_fields, key='lexical')
@@ -0,0 +1,291 @@
import re
from evaluation.benchmarks.testgeneval.constants import TestStatus
def parse_log_pytest(log: str) -> dict[str, str]:
"""
Parser for test logs generated with PyTest framework
Args:
log (str): log content
Returns:
dict: test case to test status mapping
"""
test_status_map = {}
for line in log.split('\n'):
if any([line.startswith(x.value) for x in TestStatus]):
# Additional parsing for FAILED status
if line.startswith(TestStatus.FAILED.value):
line = line.replace(' - ', ' ')
test_case = line.split()
if len(test_case) <= 1:
continue
test_status_map[test_case[1]] = test_case[0]
return test_status_map
def parse_log_pytest_options(log: str) -> dict[str, str]:
"""
Parser for test logs generated with PyTest framework with options
Args:
log (str): log content
Returns:
dict: test case to test status mapping
"""
option_pattern = re.compile(r'(.*?)\[(.*)\]')
test_status_map = {}
for line in log.split('\n'):
if any([line.startswith(x.value) for x in TestStatus]):
# Additional parsing for FAILED status
if line.startswith(TestStatus.FAILED.value):
line = line.replace(' - ', ' ')
test_case = line.split()
if len(test_case) <= 1:
continue
has_option = option_pattern.search(test_case[1])
if has_option:
main, option = has_option.groups()
if (
option.startswith('/')
and not option.startswith('//')
and '*' not in option
):
option = '/' + option.split('/')[-1]
test_name = f'{main}[{option}]'
else:
test_name = test_case[1]
test_status_map[test_name] = test_case[0]
return test_status_map
def parse_log_django(log: str) -> dict[str, str]:
"""
Parser for test logs generated with Django tester framework
Args:
log (str): log content
Returns:
dict: test case to test status mapping
"""
test_status_map = {}
lines = log.split('\n')
prev_test = None
for line in lines:
line = line.strip()
# This isn't ideal but the test output spans multiple lines
if '--version is equivalent to version' in line:
test_status_map['--version is equivalent to version'] = (
TestStatus.PASSED.value
)
# Log it in case of error
if ' ... ' in line:
prev_test = line.split(' ... ')[0]
pass_suffixes = (' ... ok', ' ... OK', ' ... OK')
for suffix in pass_suffixes:
if line.endswith(suffix):
# TODO: Temporary, exclusive fix for django__django-7188
# The proper fix should involve somehow getting the test results to
# print on a separate line, rather than the same line
if line.strip().startswith(
'Applying sites.0002_alter_domain_unique...test_no_migrations'
):
line = line.split('...', 1)[-1].strip()
test = line.rsplit(suffix, 1)[0]
test_status_map[test] = TestStatus.PASSED.value
break
if ' ... skipped' in line:
test = line.split(' ... skipped')[0]
test_status_map[test] = TestStatus.SKIPPED.value
if line.endswith(' ... FAIL'):
test = line.split(' ... FAIL')[0]
test_status_map[test] = TestStatus.FAILED.value
if line.startswith('FAIL:'):
test = line.split()[1].strip()
test_status_map[test] = TestStatus.FAILED.value
if line.endswith(' ... ERROR'):
test = line.split(' ... ERROR')[0]
test_status_map[test] = TestStatus.ERROR.value
if line.startswith('ERROR:'):
test = line.split()[1].strip()
test_status_map[test] = TestStatus.ERROR.value
if line.lstrip().startswith('ok') and prev_test is not None:
# It means the test passed, but there's some additional output (including new lines)
# between "..." and "ok" message
test = prev_test
test_status_map[test] = TestStatus.PASSED.value
# TODO: This is very brittle, we should do better
# There's a bug in the django logger, such that sometimes a test output near the end gets
# interrupted by a particular long multiline print statement.
# We have observed this in one of 3 forms:
# - "{test_name} ... Testing against Django installed in {*} silenced.\nok"
# - "{test_name} ... Internal Server Error: \/(.*)\/\nok"
# - "{test_name} ... System check identified no issues (0 silenced).\nok"
patterns = [
r'^(.*?)\s\.\.\.\sTesting\ against\ Django\ installed\ in\ ((?s:.*?))\ silenced\)\.\nok$',
r'^(.*?)\s\.\.\.\sInternal\ Server\ Error:\ \/(.*)\/\nok$',
r'^(.*?)\s\.\.\.\sSystem check identified no issues \(0 silenced\)\nok$',
]
for pattern in patterns:
for match in re.finditer(pattern, log, re.MULTILINE):
test_name = match.group(1)
test_status_map[test_name] = TestStatus.PASSED.value
return test_status_map
def parse_log_pytest_v2(log: str) -> dict[str, str]:
"""
Parser for test logs generated with PyTest framework (Later Version)
Args:
log (str): log content
Returns:
dict: test case to test status mapping
"""
test_status_map = {}
escapes = ''.join([chr(char) for char in range(1, 32)])
for line in log.split('\n'):
line = re.sub(r'\[(\d+)m', '', line)
translator = str.maketrans('', '', escapes)
line = line.translate(translator)
if any([line.startswith(x.value) for x in TestStatus]):
if line.startswith(TestStatus.FAILED.value):
line = line.replace(' - ', ' ')
test_case = line.split()
if len(test_case) >= 2:
test_status_map[test_case[1]] = test_case[0]
# Support older pytest versions by checking if the line ends with the test status
elif any([line.endswith(x.value) for x in TestStatus]):
test_case = line.split()
if len(test_case) >= 2:
test_status_map[test_case[0]] = test_case[1]
return test_status_map
def parse_log_seaborn(log: str) -> dict[str, str]:
"""
Parser for test logs generated with seaborn testing framework
Args:
log (str): log content
Returns:
dict: test case to test status mapping
"""
test_status_map = {}
for line in log.split('\n'):
if line.startswith(TestStatus.FAILED.value):
test_case = line.split()[1]
test_status_map[test_case] = TestStatus.FAILED.value
elif f' {TestStatus.PASSED.value} ' in line:
parts = line.split()
if parts[1] == TestStatus.PASSED.value:
test_case = parts[0]
test_status_map[test_case] = TestStatus.PASSED.value
elif line.startswith(TestStatus.PASSED.value):
parts = line.split()
test_case = parts[1]
test_status_map[test_case] = TestStatus.PASSED.value
return test_status_map
def parse_log_sympy(log: str) -> dict[str, str]:
"""
Parser for test logs generated with Sympy framework
Args:
log (str): log content
Returns:
dict: test case to test status mapping
"""
test_status_map = {}
pattern = r'(_*) (.*)\.py:(.*) (_*)'
matches = re.findall(pattern, log)
for match in matches:
test_case = f'{match[1]}.py:{match[2]}'
test_status_map[test_case] = TestStatus.FAILED.value
for line in log.split('\n'):
line = line.strip()
if line.startswith('test_'):
if line.endswith('[FAIL]') or line.endswith('[OK]'):
line = line[: line.rfind('[')]
line = line.strip()
if line.endswith(' E'):
test = line.split()[0]
test_status_map[test] = TestStatus.ERROR.value
if line.endswith(' F'):
test = line.split()[0]
test_status_map[test] = TestStatus.FAILED.value
if line.endswith(' ok'):
test = line.split()[0]
test_status_map[test] = TestStatus.PASSED.value
return test_status_map
def parse_log_matplotlib(log: str) -> dict[str, str]:
"""
Parser for test logs generated with PyTest framework
Args:
log (str): log content
Returns:
dict: test case to test status mapping
"""
test_status_map = {}
for line in log.split('\n'):
line = line.replace('MouseButton.LEFT', '1')
line = line.replace('MouseButton.RIGHT', '3')
if any([line.startswith(x.value) for x in TestStatus]):
# Additional parsing for FAILED status
if line.startswith(TestStatus.FAILED.value):
line = line.replace(' - ', ' ')
test_case = line.split()
if len(test_case) <= 1:
continue
test_status_map[test_case[1]] = test_case[0]
return test_status_map
parse_log_astroid = parse_log_pytest
parse_log_flask = parse_log_pytest
parse_log_marshmallow = parse_log_pytest
parse_log_pvlib = parse_log_pytest
parse_log_pyvista = parse_log_pytest
parse_log_sqlfluff = parse_log_pytest
parse_log_xarray = parse_log_pytest
parse_log_pydicom = parse_log_pytest_options
parse_log_requests = parse_log_pytest_options
parse_log_pylint = parse_log_pytest_options
parse_log_astropy = parse_log_pytest_v2
parse_log_scikit = parse_log_pytest_v2
parse_log_sphinx = parse_log_pytest_v2
MAP_REPO_TO_PARSER = {
'astropy/astropy': parse_log_astropy,
'django/django': parse_log_django,
'marshmallow-code/marshmallow': parse_log_marshmallow,
'matplotlib/matplotlib': parse_log_matplotlib,
'mwaskom/seaborn': parse_log_seaborn,
'pallets/flask': parse_log_flask,
'psf/requests': parse_log_requests,
'pvlib/pvlib-python': parse_log_pvlib,
'pydata/xarray': parse_log_xarray,
'pydicom/pydicom': parse_log_pydicom,
'pylint-dev/astroid': parse_log_astroid,
'pylint-dev/pylint': parse_log_pylint,
'pytest-dev/pytest': parse_log_pytest,
'pyvista/pyvista': parse_log_pyvista,
'scikit-learn/scikit-learn': parse_log_scikit,
'sqlfluff/sqlfluff': parse_log_sqlfluff,
'sphinx-doc/sphinx': parse_log_sphinx,
'sympy/sympy': parse_log_sympy,
}
@@ -0,0 +1,311 @@
import sys
from typing import Callable, Dict, List, Optional, Sequence, TypeVar, Union
import nltk
import numpy as np
from fuzzywuzzy import fuzz
from rouge import Rouge
# increase recursion depth to ensure ROUGE can be calculated for long sentences
if sys.getrecursionlimit() < 10_000:
sys.setrecursionlimit(10_000)
def bleu(gold: List[str], pred: List[str]) -> float:
"""
Calculate BLEU score, using smoothing method 2 with auto reweighting, in the range of 0~100.
:param gold: list of gold tokens
:param pred: list of predicted tokens
:return: BLEU score
"""
if len(pred) == 0 or len(gold) == 0:
return 0.0
return 100.0 * nltk.translate.bleu_score.sentence_bleu(
[gold],
pred,
smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method2,
auto_reweigh=True,
)
def batch_bleu(golds: List[List[str]], preds: List[List[str]]) -> List[float]:
"""
Calculate BLEU score for a batch of sentences.
:param golds: list of gold sentences
:param preds: list of predicted sentences
:return: list of BLEU scores
"""
if len(golds) != len(preds):
raise ValueError("golds and preds must have the same length")
return [bleu(gold, pred) for gold, pred in zip(golds, preds)]
def corpus_bleu(golds: List[List[str]], preds: List[List[str]]) -> float:
"""
Calculate corpus-level BLEU score for a batch of sentences.
:param golds: list of gold sentences
:param preds: list of predicted sentences
:return: corpus-level BLEU score
"""
if len(golds) != len(preds):
raise ValueError("golds and preds must have the same length")
return 100.0 * nltk.translate.bleu_score.corpus_bleu(
[[gold] for gold in golds],
preds,
smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method2,
auto_reweigh=True,
)
def edit_sim(
gold: Union[str, List[str]], pred: Union[str, List[str]], sep: str = " "
) -> float:
"""
Calculate char-level edit similarity, in the range of 0~100.
:param gold: gold sentence or list of gold tokens
:param pred: predicted sentence or list of predicted tokens
:param sep: separator between tokens
:return: char-level edit similarity
"""
if len(pred) == 0 or len(gold) == 0:
return 0.0
if isinstance(gold, list):
gold = sep.join(gold)
if isinstance(pred, list):
pred = sep.join(pred)
return fuzz.ratio(gold, pred)
def batch_edit_sim(
golds: List[Union[str, List[str]]],
preds: List[Union[str, List[str]]],
sep: str = " ",
) -> List[float]:
"""
Calculate char-level edit similarity for a batch of sentences.
:param golds: list of gold sentences
:param preds: list of predicted sentences
:param sep: separator between tokens
:return: list of char-level edit similarity
"""
if len(golds) != len(preds):
raise ValueError("golds and preds must have the same length")
return [edit_sim(gold, pred, sep) for gold, pred in zip(golds, preds)]
T = TypeVar("T")
def exact_match(gold: T, pred: T) -> float:
"""
Calculate exact match accuracy, in the range of {0, 100}.
:param gold: gold sentence or list of gold tokens
:param pred: predicted sentence or list of predicted tokens
:return: exact match accuracy
"""
if len(pred) == 0 or len(gold) == 0:
return 0.0
return 100.0 if gold == pred else 0.0
def batch_exact_match(golds: List[T], preds: List[T]) -> List[float]:
"""
Calculate exact match accuracy for a batch of sentences.
:param golds: list of gold sentences
:param preds: list of predicted sentences
:return: list of exact match accuracy
"""
if len(golds) != len(preds):
raise ValueError("golds and preds must have the same length")
return [exact_match(gold, pred) for gold, pred in zip(golds, preds)]
def rouge_l(
gold: Union[str, List[str]], pred: Union[str, List[str]], sep: str = " "
) -> Dict[str, float]:
"""
Calculate ROUGE-L F1, precision, and recall scores, in the range of 0~100.
:param gold: gold sentence or list of gold tokens
:param pred: predicted sentence or list of predicted tokens
:return: {"p": precision, "r": recall, "f": F1}
"""
if len(pred) == 0 or len(gold) == 0:
return {"p": 0.0, "r": 0.0, "f": 0.0}
if isinstance(gold, list):
gold = sep.join(gold)
if isinstance(pred, list):
pred = sep.join(pred)
try:
rouge = Rouge()
scores = rouge.get_scores(hyps=pred, refs=gold, avg=True)
return {x: scores["rouge-l"][x] * 100.0 for x in ["p", "r", "f"]}
except ValueError:
return {"p": 0.0, "r": 0.0, "f": 0.0}
def batch_rouge_l(
golds: List[Union[str, List[str]]],
preds: List[Union[str, List[str]]],
sep: str = " ",
) -> Dict[str, List[float]]:
"""
Calculate ROUGE-L F1, precision, and recall scores for a batch of sentences.
:param golds: list of gold sentences
:param preds: list of predicted sentences
:param sep: separator between tokens
:return: list of {"p": precision, "r": recall, "f": F1}
"""
if len(golds) != len(preds):
raise ValueError("golds and preds must have the same length")
scores = [rouge_l(gold, pred, sep) for gold, pred in zip(golds, preds)]
return {x: [score[x] for score in scores] for x in ["p", "r", "f"]}
def accuracy(
gold: List[str],
pred: List[str],
ignore: Optional[Sequence[str]] = None,
) -> float:
"""
Calculate token-level accuracy, in the range of 0~100.
If gold and pred are not the same length, the longer one would be truncated.
:param gold: list of gold tokens
:param pred: list of predicted tokens
:param ignore: list of (gold) tokens to ignore
:return: accuracy
"""
if len(pred) == 0 or len(gold) == 0:
return 0.0
if ignore is None:
ignore = []
i = 0
total = 0
match = 0
while i < len(gold) and i < len(pred):
if gold[i] in ignore:
i += 1
continue
total += 1
if gold[i] == pred[i]:
match += 1
i += 1
if total == 0:
return 0.0
return 100.0 * match / total
def batch_accuracy(
golds: List[List[str]],
preds: List[List[str]],
ignore: Optional[Sequence[str]] = None,
) -> List[float]:
"""
Calculate token-level accuracy for a batch of sentences.
:param golds: list of gold sentences
:param preds: list of predicted sentences
:param ignore: list of (gold) tokens to ignore
:return: list of accuracy
"""
if len(golds) != len(preds):
raise ValueError("golds and preds must have the same length")
return [accuracy(gold, pred, ignore) for gold, pred in zip(golds, preds)]
def first_match_to_topk(
first_match_list: List[int], k_values: List[int]
) -> Dict[int, List[float]]:
"""
Calculate top-k accuracy with the first match ranks (1-indexed).
:param first_match: first match ranks (1-indexed)
:param k_values: k values to consider
:return: a mapping from k to top-k accuracies (ranging from 0~100)
"""
return {k: [100.0 if x <= k else 0.0 for x in first_match_list] for k in k_values}
def pass_at_k(n: int, c: int, k: int) -> float:
"""
Sample pass@k metric according to the Codex paper, but in the scale of 0~100.
:param n: total number of samples
:param c: number of correct samples
:param k: k in pass@$k$
"""
if n < k or (n - c) < k:
# fallback to the (1 - (1-p)^k) formula
return (1 - (1 - (c / n)) ** k) * 100
else:
return (1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)).item()) * 100
def self_bleu(samples: List[List[str]]) -> float:
"""
Calculate self-BLEU among the samples.
:param samples: the chosen m samples
:return: self-BLEU
"""
if len(samples) == 0:
return 100.0
scores = []
for i in range(len(samples)):
scores.append(
100.0
* nltk.translate.bleu_score.sentence_bleu(
[samples[j] for j in range(len(samples)) if j != i],
samples[i],
smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method2,
auto_reweigh=True,
)
)
return np.mean(scores).item()
def self_edit_distance(samples: List[Union[str, List[str]]], sep=" ") -> float:
"""
Calculate self-edit-distance among the samples.
:param samples: the chosen m samples
:param sep: the separator between tokens
:return: self-edit-distance
"""
if len(samples) == 0:
return 0.0
scores = []
for i in range(len(samples)):
sample_i = samples[i]
if not isinstance(sample_i, str):
sample_i = sep.join(sample_i)
for j in range(len(samples)):
if i == j:
continue
sample_j = samples[j]
if not isinstance(sample_j, str):
sample_j = sep.join(sample_j)
scores.append(100 - fuzz.ratio(sample_i, sample_j))
return np.mean(scores).item()
QUALITY_METRICS: Dict[str, Callable[[List[str], List[str]], float]] = {
"bleu": bleu,
"xmatch": exact_match,
"edit-sim": edit_sim,
"rouge-f": lambda g, p: rouge_l(g, p)["f"],
"rouge-p": lambda g, p: rouge_l(g, p)["p"],
"rouge-r": lambda g, p: rouge_l(g, p)["r"],
}
+114
View File
@@ -0,0 +1,114 @@
CODEACT_TESTGEN_PROMPT_OLD = """Your goal is to generate a high-quality test suite (at least 20+ passing tests) for the code file: {code_file}. Output the test suite at {test_file}\n'
[current directory: /workspace/{workspace_dir_name}]
IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP
IMPORTANT: Follow instructions, if you have < 80 tests you should generate more tests rather than trying to fix the ones you have.
IMPORTANT: Code file to test:
```python
{code_src}
```
Here are additional imports that you may need:
{imports}
Look at code dependencies (NOT {code_file} since you already have contents) and test files you need context for to write a complete test suite.
Aim for 20+ test functions with asserts. Do not hestitate to use the Python interpreter to understand the input output behavior of the code you are testing.
Output your test suite at {test_file}. Each unit test must be a function starting with test_. Include all your test imports and setup before your first test. Do not include a main method to run the tests. Make sure to make it as comprehensive as possible, try to execute all the methods you saw.
When you think you've successfully generated a test suite, run it on for the current project using {coverage_command}.
If you have few tests GENERATE MORE TESTS rather than trying to fix the ones you have (it is possible to filter out failing tests later).
Then run coverage report -m --include {code_file} to see how well your test suite covers the code under test.
When you are trying to improve coverage pick a part of the code that is not covered (indicated by lines on coverage report), examine the code and then
try to generate a test for it. Feel free to use a code interpreter to understand the input output behavior. ONLY add tests
not remove them.
If you are unable to see passing and failing tests, FIX YOUR IMPORTS to use the same style as other test files.
You should NOT modify any existing test case files. You SHOULD add new test in a NEW file to reproduce the issue.
You should NEVER use web browsing or any other web-based tools.
You should NEVER install new packages, use existing packages only.
You should ALWAYS use the default Python interpreter available in the <execute_bash> environment to run code related to the provided issue and/or repository.
You should ALWAYS use local imports DO NOT import the general library.
When you think you have a fully adequate test suite, please run the following command: <execute_bash> exit </execute_bash>.
"""
CODEACT_TESTGEN_PROMPT = """
Your goal is to generate a comprehensive, **broad-coverage** test suite for the code below, ensuring you test as many lines and branches as possible on the first attempt.
Place your test suite in a new file named {test_file}.
IMPORTANT REQUIREMENTS:
1. **No external help or resources**—use only the snippet below.
2. **Focus on breadth over depth**: cover all major functions, classes, and code paths early to minimize coverage iterations.
3. Each test function must start with `test_` and use `assert` to verify behavior.
4. Include only necessary imports (standard library or local).
5. Do **not** modify existing test files—create a brand new one. No `main()` or other non-test code.
6. Produce **at least 20 test functions**; if coverage is lacking, add more tests rather than removing or changing existing ones.
7. Use the following commands to check coverage:
<execute_bash> {coverage_command} </execute_bash>
<execute_bash> coverage report -m --include {code_file} </execute_bash>
If lines remain uncovered, add new tests targeting them specifically.
8. When you're satisfied with coverage, finalize by running:
<execute_bash> exit </execute_bash>
Below is the **complete code snippet** to test:
<START_OF_CODE>
{code_src}
<END_OF_CODE>
NOTE: if you are testing django, you must use from django.test import SimpleTestCase and class based tests (i.e. class TestSomething(SimpleTestCase)).
NOTE: if there is an error executing tests you MUST fix it before exiting. DO NOT install new packages.
NOTE: if outputting a revised test suite REPLACE {test_file} with the revised suite
**Output the final test suite** (20+ tests) for {test_file} in a single code block, no extra commentary. MAKE SURE you run the tests and ensure you can see which tests passed and failed BEFORE exiting.
"""
CODEACT_TESTGEN_PROMPT_ITERATE = """
Your goal is to improve the test suite at {test_file} to achieve **broad-coverage** of the code below.
First run the test suite.
If no tests run, then remove {test_file} and create {test_file} with a new suite.
Otherwise, improve it aiming to improve code coverage.
IMPORTANT REQUIREMENTS:
1. Use the following commands to check coverage (RUN THIS FIRST):
<execute_bash> {coverage_command} </execute_bash>
<execute_bash> coverage report -m --include {code_file} </execute_bash>
If lines remain uncovered, add new tests targeting them specifically.
2. **No external help or resources**—use only the snippet below.
3. **Focus on breadth over depth**: cover all major functions, classes, and code paths early to minimize coverage iterations.
4. Each test function must use `assert` to verify behavior.
5. Include only necessary imports (standard library or local).
6. Do **not** modify other test files in the repository. No `main()` or other non-test code.
7. Produce **at least 20 test functions**; if coverage is lacking, add more tests rather than removing or changing existing ones.
8. When you're satisfied with coverage, finalize by running:
<execute_bash> exit </execute_bash>
Below is the **complete code snippet** to test:
<START_OF_CODE>
{code_src}
<END_OF_CODE>
NOTE: if you are testing django, you must use from django.test import SimpleTestCase and class based tests (i.e. class TestSomething(SimpleTestCase)).
NOTE: if there is an error executing tests you MUST fix it before exiting. DO NOT install new packages.
NOTE: if outputting a revised test suite REPLACE {test_file} with the revised suite
**Output the final test suite** (20+ tests) for {test_file} in a single code block, no extra commentary. MAKE SURE you run the tests and ensure you can see which tests passed and failed BEFORE exiting.
"""
@@ -0,0 +1,31 @@
import re
from pygments.lexers.python import PythonLexer
def tokenize_code(code):
lexer = PythonLexer()
tokens = process_pygments_tokens(lexer.get_tokens(code))
return tokens
def process_pygments_tokens(tokens):
new_tokens = []
for token in tokens:
if str(token[0]) == "Token.Text" and re.match(r'\s+', token[1]) or str(token[0]) == "Token.Text.Whitespace":
continue
new_tokens.append(token[1])
new_tokens_final = []
i = 0
while i < len(new_tokens)-2:
if new_tokens[i] == '"' and new_tokens[i+1]=='STR' and new_tokens[i+2] == '"':
new_tokens_final.append("\"STR\"")
i = i + 3
else:
new_tokens_final.append(new_tokens[i])
i = i + 1
for i in range(len(new_tokens)-2, len(new_tokens)):
if i >= 0:
new_tokens_final.append(new_tokens[i])
return new_tokens_final
@@ -0,0 +1,58 @@
import json
import re
def check_coverage(coverage_output, code_file):
json_cov = json.loads(coverage_output)
if code_file in json_cov['files'].keys():
file_data = json_cov['files'][code_file]
return True, file_data['summary']['percent_covered']
return False, 0
def check_mutation(mutation_output):
if 'total jobs: ' in mutation_output:
num_mutants = int(mutation_output.split('total jobs: ')[1].split('\n')[0])
final_conf = mutation_output.split('\n')[-1]
if len(final_conf.strip().split(' ')) == 3:
low, val, high = final_conf.split(' ')
low = float(low)
val = float(val)
high = float(high)
confidence_range = high - val
mutation_score = 100 - val
return True, num_mutants, mutation_score, confidence_range
return False, -1, 0, -1
def count_methods(code_str):
"""
Counts the number of methods/functions in a given string of code.
Args:
code_str (str): A string containing code.
Returns:
int: The number of methods/functions found.
"""
# Regular expression to find Python function definitions
pattern = r'\bdef\b\s+\w+\s*\('
matches = re.findall(pattern, code_str)
return len(matches)
def get_lines_of_code(code_str):
"""
Extracts lines of code from a given string.
Args:
code_str (str): A string containing code.
Returns:
list: A list of lines of code.
"""
return len(code_str.strip().split('\n'))
@@ -0,0 +1,577 @@
import asyncio
import json
import os
import tempfile
import time
import traceback
from typing import Any
import numpy as np
import pandas as pd
import toml
from datasets import load_dataset
import openhands.agenthub
from evaluation.benchmarks.testgeneval.constants import MAP_REPO_VERSION_TO_SPECS
from evaluation.benchmarks.testgeneval.prompt import (
CODEACT_TESTGEN_PROMPT,
CODEACT_TESTGEN_PROMPT_ITERATE,
)
from evaluation.benchmarks.testgeneval.utils import get_test_directives
from evaluation.utils.shared import (
EvalException,
EvalMetadata,
EvalOutput,
assert_and_raise,
codeact_user_response,
get_metrics,
is_fatal_evaluation_error,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
update_llm_config_for_completions_logging,
)
from openhands.controller.state.state import State
from openhands.core.config import (
AgentConfig,
AppConfig,
SandboxConfig,
get_llm_config_arg,
get_parser,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation, ErrorObservation
from openhands.events.serialization.event import event_to_dict
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
RUN_WITH_BROWSING = os.environ.get('RUN_WITH_BROWSING', 'false').lower() == 'true'
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActAgent': codeact_user_response,
}
def _preprocess_instance(d):
for key, value in d.items():
if isinstance(value, np.ndarray):
d[key] = value.tolist()
return d
def _get_swebench_workspace_dir_name(instance: pd.Series) -> str:
return f'{instance.repo}__{instance.version}'.replace('/', '__')
def get_instruction(instance: pd.Series, metadata: EvalMetadata):
# workspace_dir_name = _get_swebench_workspace_dir_name(instance)
# Prepare instruction
coverage_command = ' '.join(
[
MAP_REPO_VERSION_TO_SPECS[instance['repo']][instance['version']][
'test_cmd'
],
*get_test_directives(instance),
]
)
# Testing general agents
prompt_to_use = (
CODEACT_TESTGEN_PROMPT_ITERATE
if instance['full_pred'] is not None
else CODEACT_TESTGEN_PROMPT
)
instruction = prompt_to_use.format(
code_file=os.path.join('/testbed', instance.code_file),
test_file=os.path.join('/testbed', instance.test_file),
coverage_command=coverage_command,
code_src=instance['code_src'],
imports='\n'.join(instance.local_imports),
workspace_dir_name=_get_swebench_workspace_dir_name(instance),
)
if RUN_WITH_BROWSING:
instruction += (
'<IMPORTANT!>\n'
'You SHOULD NEVER attempt to browse the web. '
'</IMPORTANT!>\n'
)
return instruction
# TODO: migrate all swe-bench docker to ghcr.io/openhands
DOCKER_IMAGE_PREFIX = os.environ.get('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/kdjain/')
logger.info(f'Using docker image prefix: {DOCKER_IMAGE_PREFIX}')
def get_instance_docker_image(instance_id: str) -> str:
image_name = 'sweb.eval.x86_64.' + instance_id
image_name = image_name.replace(
'__', '_s_'
) # to comply with docker image naming convention
return DOCKER_IMAGE_PREFIX.rstrip('/') + '/' + image_name
def get_config(
instance: pd.Series,
metadata: EvalMetadata,
) -> AppConfig:
# We use a different instance image for the each instance of TestGenEval
base_container_image = get_instance_docker_image(instance['instance_id_swebench'])
logger.info(
f'Using instance container image: {base_container_image}. '
f'Please make sure this image exists. '
f'Submit an issue on https://github.com/All-Hands-AI/OpenHands if you run into any issues.'
)
config = AppConfig(
default_agent=metadata.agent_class,
run_as_openhands=False,
max_iterations=metadata.max_iterations,
runtime=os.environ.get('RUNTIME', 'eventstream'),
sandbox=SandboxConfig(
base_container_image=base_container_image,
enable_auto_lint=True,
use_host_network=False,
# large enough timeout, since some testcases take very long to run
timeout=300,
# Add platform to the sandbox config to solve issue 4401
platform='linux/amd64',
api_key=os.environ.get('ALLHANDS_API_KEY', None),
remote_runtime_api_url=os.environ.get(
'SANDBOX_REMOTE_RUNTIME_API_URL', 'http://localhost:8000'
),
keep_runtime_alive=False,
remote_runtime_init_timeout=3600,
),
# do not mount workspace
workspace_base=None,
workspace_mount_path=None,
)
config.set_llm_config(
update_llm_config_for_completions_logging(
metadata.llm_config, metadata.eval_output_dir, instance['id']
)
)
agent_config = AgentConfig(
codeact_enable_jupyter=False,
codeact_enable_browsing=RUN_WITH_BROWSING,
codeact_enable_llm_editor=False,
condenser=metadata.condenser_config,
)
config.set_agent_config(agent_config)
return config
def initialize_runtime(
runtime: Runtime,
instance: pd.Series, # this argument is not required
):
"""Initialize the runtime for the agent.
This function is called before the runtime is used to run the agent.
"""
logger.info('-' * 30)
logger.info('BEGIN Runtime Initialization Fn')
logger.info('-' * 30)
workspace_dir_name = _get_swebench_workspace_dir_name(instance)
obs: CmdOutputObservation
instance['instance_id'] = instance['instance_id_swebench']
# Set instance id
action = CmdRunAction(
command=f"""echo 'export SWE_INSTANCE_ID={instance['instance_id_swebench']}' >> ~/.bashrc && echo 'export PIP_CACHE_DIR=~/.cache/pip' >> ~/.bashrc && echo "alias git='git --no-pager'" >> ~/.bashrc"""
)
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
obs.exit_code == 0, f'Failed to export SWE_INSTANCE_ID: {str(obs)}'
)
action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """)
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {str(obs)}')
# inject the init script
script_dir = os.path.dirname(__file__)
# inject the instance info
action = CmdRunAction(command='mkdir -p /swe_util/eval_data/instances')
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
obs.exit_code == 0,
f'Failed to create /swe_util/eval_data/instances: {str(obs)}',
)
swe_instance_json_name = 'swe-bench-instance.json'
swe_prediction = 'test_suite.py'
with tempfile.TemporaryDirectory() as temp_dir:
# Construct the full path for the desired file name within the temporary directory
temp_file_path = os.path.join(temp_dir, swe_instance_json_name)
# Write to the file with the desired name within the temporary directory
with open(temp_file_path, 'w') as f:
if not isinstance(instance, dict):
preprocessed_instance = _preprocess_instance(instance.to_dict())
json.dump([preprocessed_instance], f)
else:
preprocessed_instance = _preprocess_instance(instance)
json.dump([preprocessed_instance], f)
# Copy the file to the desired location
runtime.copy_to(temp_file_path, '/swe_util/eval_data/instances/')
if instance['full_pred'] is not None:
temp_file_path_pred = os.path.join(temp_dir, swe_prediction)
with open(temp_file_path_pred, 'w') as f:
f.write(instance['full_pred'])
runtime.copy_to(temp_file_path_pred, '/tmp')
# Copy the file to the desired location
action = CmdRunAction(
command=f"cp /tmp/test_suite.py /testbed/{instance['test_file']}"
)
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
obs.exit_code == 0, f'Failed to copy test file: {str(obs)}'
)
action = CmdRunAction(
command='git -C /testbed add . && git -C /testbed commit -m "Add test file"'
)
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(obs.exit_code == 0, f'Failed to cat ~/.bashrc: {str(obs)}')
# inject the instance swe entry
runtime.copy_to(
str(os.path.join(script_dir, 'scripts/setup/instance_swe_entry.sh')),
'/swe_util/',
)
action = CmdRunAction(command='cat ~/.bashrc')
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(obs.exit_code == 0, f'Failed to cat ~/.bashrc: {str(obs)}')
action = CmdRunAction(command='source ~/.bashrc')
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
if isinstance(obs, ErrorObservation):
logger.error(f'Failed to source ~/.bashrc: {str(obs)}')
assert_and_raise(obs.exit_code == 0, f'Failed to source ~/.bashrc: {str(obs)}')
action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
obs.exit_code == 0,
f'Failed to source /swe_util/instance_swe_entry.sh: {str(obs)}',
)
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
obs.exit_code == 0,
f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
)
action = CmdRunAction(command='git reset --hard')
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(obs.exit_code == 0, f'Failed to git reset --hard: {str(obs)}')
action = CmdRunAction(
command='for remote_name in $(git remote); do git remote remove "${remote_name}"; done'
)
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(obs.exit_code == 0, f'Failed to remove git remotes: {str(obs)}')
logger.info('-' * 30)
logger.info('END Runtime Initialization Fn')
logger.info('-' * 30)
def complete_runtime(
runtime: Runtime,
instance: pd.Series, # this argument is not required, but it is used to get the workspace_dir_name
) -> dict[str, Any]:
"""Complete the runtime for the agent.
This function is called before the runtime is used to run the agent.
If you need to do something in the sandbox to get the correctness metric after
the agent has run, modify this function.
"""
try:
logger.info('-' * 30)
logger.info('BEGIN Runtime Completion Fn')
logger.info('-' * 30)
obs: CmdOutputObservation
workspace_dir_name = _get_swebench_workspace_dir_name(instance)
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
obs.exit_code == 0,
f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
)
action = CmdRunAction(command=f'cat {instance.test_file}')
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
assert_and_raise(
obs.exit_code == 0,
f'Failed to find file: {instance.test_file} in /workspace/{workspace_dir_name}',
)
test_suite = obs.content.strip()
except Exception:
# Print stack trace
print('Skipping, exception in complete_runtime')
print(traceback.format_exc())
test_suite = instance['full_pred'] if instance['full_pred'] is not None else ''
# action = CmdRunAction(command='git add -A')
# action.set_hard_timeout(600)
# logger.info(action, extra={'msg_type': 'ACTION'})
# obs = runtime.run_action(action)
# logger.info(obs, extra={'msg_type': 'OBSERVATION'})
# assert_and_raise(obs.exit_code == 0, f'Failed to git add -A: {str(obs)}')
logger.info('-' * 30)
logger.info('END Runtime Completion Fn')
logger.info('-' * 30)
return {
'test_suite': test_suite,
}
def process_instance(
instance: pd.Series,
metadata: EvalMetadata,
reset_logger: bool = True,
) -> EvalOutput:
config = get_config(instance, metadata)
start_time = time.time() # Track start time
# Setup the logger properly, so you can run multi-processing to parallelize the evaluation
if reset_logger:
log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
reset_logger_for_multiprocessing(logger, instance.id, log_dir)
else:
logger.info(f'Starting evaluation for instance {instance.id}.')
runtime = create_runtime(config)
call_async_from_sync(runtime.connect)
try:
initialize_runtime(runtime, instance)
instruction = get_instruction(instance, metadata)
# Here's how you can run the agent (similar to the `main` function) and get the final task state
state: State | None = asyncio.run(
run_controller(
config=config,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
metadata.agent_class
],
)
)
# if fatal error, throw EvalError to trigger re-run
if is_fatal_evaluation_error(state.last_error):
raise EvalException('Fatal error detected: ' + state.last_error)
# ======= THIS IS SWE-Bench specific =======
return_val = complete_runtime(runtime, instance)
test_suite = return_val['test_suite']
logger.info(
f'Got test suite for instance {instance.instance_id}:\n--------\n{test_suite}\n--------'
)
finally:
runtime.close()
end_time = time.time()
elapsed_time = end_time - start_time
logger.info(
f'Evaluation for instance {instance.instance_id} took {elapsed_time:.2f} seconds.'
)
# ==========================================
# ======= Attempt to evaluate the agent's edits =======
# we use eval_infer.sh to evaluate the agent's edits, not here
# because the agent may alter the environment / testcases
test_result = {
'test_suite': test_suite,
'elapsed_time': elapsed_time,
}
# If you are working on some simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
# You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
if state is None:
raise ValueError('State should not be None.')
histories = [event_to_dict(event) for event in state.history]
metrics = get_metrics(state)
# Save the output
output = EvalOutput(
instance_id=instance.id,
instruction=instruction,
instance=_preprocess_instance(instance.to_dict()), # SWE Bench specific
test_result=test_result,
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
)
# print(output)
return output
def prepare_dataset_pre(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.toml')
if os.path.exists(file_path):
with open(file_path, 'r') as file:
data = toml.load(file)
if 'selected_ids' in data:
selected_ids = data['selected_ids']
logger.info(
f'Filtering {len(selected_ids)} tasks from "selected_ids"...'
)
subset = dataset[dataset[filter_column].isin(selected_ids)]
logger.info(f'Retained {subset.shape[0]} tasks after filtering')
subset['instance_id_swebench'] = subset['instance_id']
subset['instance_id'] = subset['id']
return subset
dataset['instance_id_swebench'] = dataset['instance_id']
dataset['instance_id'] = dataset['id']
return dataset
if __name__ == '__main__':
parser = get_parser()
parser.add_argument(
'--dataset',
type=str,
default='kjain/testgenevallite',
help='data set to evaluate on, either full-test or lite-test',
)
parser.add_argument(
'--split',
type=str,
default='test',
help='split to evaluate on',
)
parser.add_argument(
'--testfile_start',
action='store_true',
help='Whether to start from the 0 shot test file',
)
parser.add_argument(
'--zero_shot_path',
type=str,
help='Path to the zero shot test file predictions',
)
args, _ = parser.parse_known_args()
if args.testfile_start and not args.zero_shot_path:
raise ValueError(
'If you want to start from the 0 shot test file, you must provide the path to the zero shot test file predictions'
)
preds_map = {}
if args.testfile_start:
with open(args.zero_shot_path, 'r') as f:
for line in f:
pred = json.loads(line)
preds_map[pred['id']] = pred['preds']['full'][0]
# NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
# so we don't need to manage file uploading to OpenHands's repo
dataset = load_dataset(args.dataset, split=args.split)
logger.info(f'Loaded dataset {args.dataset} with split {args.split}')
testgeneval_filepairs = prepare_dataset_pre(dataset.to_pandas(), 'id')
llm_config = None
if args.llm_config:
llm_config = get_llm_config_arg(args.llm_config)
llm_config.log_completions = True
# modify_params must be False for evaluation purpose, for reproducibility and accurancy of results
llm_config.modify_params = False
if llm_config is None:
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
details = {}
_agent_cls = openhands.agenthub.Agent.get_cls(args.agent_cls)
dataset_descrption = (
args.dataset.replace('/', '__') + '-' + args.split.replace('/', '__')
)
metadata = make_metadata(
llm_config,
dataset_descrption,
args.agent_cls,
args.max_iterations,
args.eval_note,
args.eval_output_dir,
details=details,
)
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
instances = prepare_dataset(testgeneval_filepairs, output_file, args.eval_n_limit)
if not instances.empty:
instances['full_pred'] = (
instances['instance_id']
.map(preds_map)
.apply(lambda x: x if pd.notna(x) else None)
)
run_evaluation(
instances, metadata, output_file, args.eval_num_workers, process_instance
)
@@ -0,0 +1,128 @@
import argparse
import os
import subprocess
from datasets import load_dataset
# Function to run shell commands
def run_command(command):
try:
subprocess.run(command, check=True, shell=True)
except subprocess.CalledProcessError as e:
print(f'An error occurred: {e}')
# Function to log in to Docker Hub
def docker_login():
print('Logging into Docker Hub...')
run_command('docker login')
# Function to generate Dockerfile content based on image type
def generate_dockerfile_content(
base_image, dependencies, datum, patch_path, test_patch_path
):
dockerfile_content = f"""
FROM {base_image}
SHELL ["/bin/bash", "-c"]
RUN source /opt/miniconda3/bin/activate && conda activate testbed && pip install {' '.join(dependencies)}
COPY {patch_path} /app/patch.diff
RUN git apply /app/patch.diff
RUN rm /app/patch.diff
COPY {test_patch_path} /app/patch.diff
RUN git apply /app/patch.diff
RUN git config --global user.email ""
RUN git config --global user.name "TestGenEval"
RUN rm /app/patch.diff
RUN rm {datum['test_file']}
"""
# Add specific content based on image type
dockerfile_content += 'RUN git add .\nRUN git commit -m "Testing fixes"'
return dockerfile_content
# Function to build, push, and clean up Docker images
def build_and_push_image(dockerfile_content, image_name):
with open('Dockerfile.temp', 'w') as dockerfile:
dockerfile.write(dockerfile_content)
run_command(f'docker build -f Dockerfile.temp -t {image_name} .')
run_command(f'docker push {image_name}')
run_command(f'docker rmi {image_name}')
os.remove('Dockerfile.temp')
# Function to process images with .eval in the name
def process_images(dataset, original_namespace, new_namespace, start_instance_id):
dependencies = ['coverage', 'cosmic-ray']
found_start = len(start_instance_id) == 0
for datum in dataset:
if not found_start and datum['instance_id'] == start_instance_id:
found_start = True
elif found_start:
full_image_name = f'{original_namespace}/sweb.eval.x86_64.{datum["instance_id"].replace("__", "_s_")}:latest'
print(f'Processing image: {full_image_name}')
run_command(f'docker pull {full_image_name}')
# Save patches and preds_context to regular files
patch_file_path = 'patch.diff'
test_patch_file_path = 'test_patch.diff'
with open(patch_file_path, 'w') as patch_file, open(
test_patch_file_path, 'w'
) as test_patch_file:
patch_file.write(datum['patch'])
test_patch_file.write(datum['test_patch'])
# Define image types and corresponding tags
new_image_name = f'{new_namespace}/sweb.eval.x86_64.{datum["instance_id"].replace("__", "_s_")}:latest'
dockerfile_content = generate_dockerfile_content(
full_image_name,
dependencies,
datum,
patch_file_path,
test_patch_file_path,
)
build_and_push_image(dockerfile_content, new_image_name)
# Cleanup regular files and images
os.remove(patch_file_path)
os.remove(test_patch_file_path)
run_command(f'docker rmi {full_image_name}')
run_command('docker system prune -f') # Clean up dangling resources
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Process Docker images with .eval in the name.'
)
parser.add_argument('--dataset', type=str, default='kjain14/testgeneval')
parser.add_argument('--split', type=str, default='test')
parser.add_argument(
'--new_namespace',
type=str,
default='kdjain',
help='The new Docker Hub namespace to push the images',
)
parser.add_argument(
'--original_namespace',
type=str,
default='xingyaoww',
help='The original Docker Hub namespace',
)
parser.add_argument(
'--start_instance_id',
type=str,
default='',
help='The instance_id to start processing from',
)
args = parser.parse_args()
dataset = load_dataset(args.dataset)[args.split]
docker_login()
process_images(
dataset, args.original_namespace, args.new_namespace, args.start_instance_id
)
@@ -0,0 +1,196 @@
sweb.base.x86_64:latest
sweb.env.x86_64.088a7e628bda9770f9757b:latest
sweb.env.x86_64.0d80c7dec81ee2f2f513e2:latest
sweb.env.x86_64.0f99bce2750f3109957bec:latest
sweb.env.x86_64.1b3b218535da0abf4469cb:latest
sweb.env.x86_64.1c1a6945f732f9391228c5:latest
sweb.env.x86_64.1f92e6d7cef88badc4f744:latest
sweb.env.x86_64.27dd9791e13f5c857a09f9:latest
sweb.env.x86_64.297af196949a2a635bce66:latest
sweb.env.x86_64.2baaea72acc974f6c02079:latest
sweb.env.x86_64.2e50125951bc69cddd7421:latest
sweb.env.x86_64.2f217c8b4490bfa0e2ba14:latest
sweb.env.x86_64.31244378a92e3bcce809ac:latest
sweb.env.x86_64.428468730904ff6b4232aa:latest
sweb.env.x86_64.5d1fda9d55d65d8a4e5bdb:latest
sweb.env.x86_64.6b007979cf533f0f3016e8:latest
sweb.env.x86_64.7037e8c448a4b8ebfe9b13:latest
sweb.env.x86_64.71498c7426dbf05599642f:latest
sweb.env.x86_64.756beac07713d7e8dc1129:latest
sweb.env.x86_64.78278ae2cf880e395f1337:latest
sweb.env.x86_64.8f1f7b974f0c57c7aeba39:latest
sweb.env.x86_64.934a137824256b612e9dc5:latest
sweb.env.x86_64.a0efca7a0fe6719dbf65c2:latest
sweb.env.x86_64.a18371b03f944585b4f08c:latest
sweb.env.x86_64.a33dddf55cdff5d8e23374:latest
sweb.env.x86_64.aa92880033da20ca313928:latest
sweb.env.x86_64.b649f0ff62fad147f7f073:latest
sweb.env.x86_64.b7ce4be3b3c35f68c61248:latest
sweb.env.x86_64.c70909fdac4897d1c685df:latest
sweb.env.x86_64.c795f4b88616b8462021ed:latest
sweb.env.x86_64.cc47cc71483942d0c3a15e:latest
sweb.env.x86_64.dc5ff4c0e3fe8db5afc4da:latest
sweb.env.x86_64.e3afd7f04b325a4de4982d:latest
sweb.env.x86_64.e5bb89bf78258a7d14c34b:latest
sweb.env.x86_64.e83e37f52c09532c62acfb:latest
sweb.env.x86_64.efa6065ed5bf204410fd53:latest
sweb.eval.x86_64.django_s_django-17087:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-10508:latest
sweb.eval.x86_64.django_s_django-14017:latest
sweb.eval.x86_64.django_s_django-11422:latest
sweb.eval.x86_64.sympy_s_sympy-14774:latest
sweb.eval.x86_64.django_s_django-14915:latest
sweb.eval.x86_64.sympy_s_sympy-22005:latest
sweb.eval.x86_64.pytest-dev_s_pytest-5221:latest
sweb.eval.x86_64.sympy_s_sympy-17022:latest
sweb.eval.x86_64.django_s_django-15996:latest
sweb.eval.x86_64.django_s_django-15252:latest
sweb.eval.x86_64.sympy_s_sympy-21171:latest
sweb.eval.x86_64.django_s_django-11797:latest
sweb.eval.x86_64.django_s_django-16046:latest
sweb.eval.x86_64.django_s_django-11583:latest
sweb.eval.x86_64.django_s_django-15738:latest
sweb.eval.x86_64.sympy_s_sympy-21612:latest
sweb.eval.x86_64.astropy_s_astropy-12907:latest
sweb.eval.x86_64.django_s_django-11620:latest
sweb.eval.x86_64.sympy_s_sympy-16792:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13779:latest
sweb.eval.x86_64.django_s_django-16041:latest
sweb.eval.x86_64.sympy_s_sympy-13471:latest
sweb.eval.x86_64.sympy_s_sympy-20442:latest
sweb.eval.x86_64.sympy_s_sympy-20049:latest
sweb.eval.x86_64.django_s_django-14411:latest
sweb.eval.x86_64.django_s_django-13447:latest
sweb.eval.x86_64.django_s_django-12856:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-10949:latest
sweb.eval.x86_64.django_s_django-14787:latest
sweb.eval.x86_64.django_s_django-11815:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13584:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-14087:latest
sweb.eval.x86_64.django_s_django-15388:latest
sweb.eval.x86_64.django_s_django-11179:latest
sweb.eval.x86_64.sympy_s_sympy-24102:latest
sweb.eval.x86_64.sympy_s_sympy-24213:latest
sweb.eval.x86_64.django_s_django-15781:latest
sweb.eval.x86_64.pytest-dev_s_pytest-8906:latest
sweb.eval.x86_64.django_s_django-13710:latest
sweb.eval.x86_64.django_s_django-13925:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-14092:latest
sweb.eval.x86_64.pytest-dev_s_pytest-7373:latest
sweb.eval.x86_64.matplotlib_s_matplotlib-25498:latest
sweb.eval.x86_64.pytest-dev_s_pytest-5227:latest
sweb.eval.x86_64.sympy_s_sympy-15678:latest
sweb.eval.x86_64.django_s_django-13551:latest
sweb.eval.x86_64.django_s_django-14155:latest
sweb.eval.x86_64.django_s_django-13933:latest
sweb.eval.x86_64.sympy_s_sympy-21055:latest
sweb.eval.x86_64.django_s_django-13660:latest
sweb.eval.x86_64.django_s_django-16527:latest
sweb.eval.x86_64.pytest-dev_s_pytest-5692:latest
sweb.eval.x86_64.mwaskom_s_seaborn-3010:latest
sweb.eval.x86_64.django_s_django-12700:latest
sweb.eval.x86_64.sympy_s_sympy-11400:latest
sweb.eval.x86_64.sympy_s_sympy-23117:latest
sweb.eval.x86_64.sympy_s_sympy-20639:latest
sweb.eval.x86_64.sympy_s_sympy-23262:latest
sweb.eval.x86_64.django_s_django-15498:latest
sweb.eval.x86_64.django_s_django-12453:latest
sweb.eval.x86_64.django_s_django-14999:latest
sweb.eval.x86_64.sympy_s_sympy-13480:latest
sweb.eval.x86_64.sympy_s_sympy-21847:latest
sweb.eval.x86_64.sympy_s_sympy-15011:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-25570:latest
sweb.eval.x86_64.sphinx-doc_s_sphinx-7975:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-14983:latest
sweb.eval.x86_64.django_s_django-14534:latest
sweb.eval.x86_64.sympy_s_sympy-14396:latest
sweb.eval.x86_64.matplotlib_s_matplotlib-25442:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-15535:latest
sweb.eval.x86_64.sympy_s_sympy-22714:latest
sweb.eval.x86_64.django_s_django-15789:latest
sweb.eval.x86_64.sympy_s_sympy-21627:latest
sweb.eval.x86_64.sympy_s_sympy-24066:latest
sweb.eval.x86_64.pylint-dev_s_pylint-7993:latest
sweb.eval.x86_64.django_s_django-14752:latest
sweb.eval.x86_64.sympy_s_sympy-18835:latest
sweb.eval.x86_64.django_s_django-17051:latest
sweb.eval.x86_64.sympy_s_sympy-12171:latest
sweb.eval.x86_64.pydata_s_xarray-3364:latest
sweb.eval.x86_64.mwaskom_s_seaborn-3190:latest
sweb.eval.x86_64.pytest-dev_s_pytest-7168:latest
sweb.eval.x86_64.django_s_django-12747:latest
sweb.eval.x86_64.django_s_django-15695:latest
sweb.eval.x86_64.matplotlib_s_matplotlib-22835:latest
sweb.eval.x86_64.sympy_s_sympy-12481:latest
sweb.eval.x86_64.django_s_django-15851:latest
sweb.eval.x86_64.sympy_s_sympy-14024:latest
sweb.eval.x86_64.django_s_django-14608:latest
sweb.eval.x86_64.pytest-dev_s_pytest-9359:latest
sweb.eval.x86_64.django_s_django-16873:latest
sweb.eval.x86_64.matplotlib_s_matplotlib-25433:latest
sweb.eval.x86_64.sympy_s_sympy-13031:latest
sweb.eval.x86_64.pytest-dev_s_pytest-7432:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-25747:latest
sweb.eval.x86_64.django_s_django-12286:latest
sweb.eval.x86_64.django_s_django-11910:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-12471:latest
sweb.eval.x86_64.pylint-dev_s_pylint-5859:latest
sweb.eval.x86_64.django_s_django-11133:latest
sweb.eval.x86_64.astropy_s_astropy-14365:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13496:latest
sweb.eval.x86_64.sympy_s_sympy-19487:latest
sweb.eval.x86_64.sympy_s_sympy-13895:latest
sweb.eval.x86_64.sympy_s_sympy-15345:latest
sweb.eval.x86_64.django_s_django-13590:latest
sweb.eval.x86_64.django_s_django-13757:latest
sweb.eval.x86_64.django_s_django-16379:latest
sweb.eval.x86_64.django_s_django-13768:latest
sweb.eval.x86_64.pytest-dev_s_pytest-8365:latest
sweb.eval.x86_64.django_s_django-14580:latest
sweb.eval.x86_64.sympy_s_sympy-20154:latest
sweb.eval.x86_64.sympy_s_sympy-12419:latest
sweb.eval.x86_64.django_s_django-12125:latest
sweb.eval.x86_64.sympy_s_sympy-24152:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-15512:latest
sweb.eval.x86_64.sympy_s_sympy-18621:latest
sweb.eval.x86_64.pydata_s_xarray-4248:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-11040:latest
sweb.eval.x86_64.django_s_django-11099:latest
sweb.eval.x86_64.django_s_django-16816:latest
sweb.eval.x86_64.django_s_django-13265:latest
sweb.eval.x86_64.django_s_django-16139:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-10297:latest
sweb.eval.x86_64.django_s_django-14016:latest
sweb.eval.x86_64.pallets_s_flask-5063:latest
sweb.eval.x86_64.astropy_s_astropy-7746:latest
sweb.eval.x86_64.matplotlib_s_matplotlib-24265:latest
sweb.eval.x86_64.django_s_django-13448:latest
sweb.eval.x86_64.django_s_django-12908:latest
sweb.eval.x86_64.sphinx-doc_s_sphinx-8627:latest
sweb.eval.x86_64.sympy_s_sympy-14317:latest
sweb.eval.x86_64.pytest-dev_s_pytest-6116:latest
sweb.eval.x86_64.sympy_s_sympy-23191:latest
sweb.eval.x86_64.pydata_s_xarray-5131:latest
sweb.eval.x86_64.django_s_django-11019:latest
sweb.eval.x86_64.matplotlib_s_matplotlib-23913:latest
sweb.eval.x86_64.django_s_django-15790:latest
sweb.eval.x86_64.django_s_django-12497:latest
sweb.eval.x86_64.matplotlib_s_matplotlib-26020:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-25638:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-25500:latest
sweb.eval.x86_64.sympy_s_sympy-19007:latest
sweb.eval.x86_64.django_s_django-12308:latest
sweb.eval.x86_64.pytest-dev_s_pytest-7220:latest
sweb.eval.x86_64.django_s_django-11848:latest
sweb.eval.x86_64.django_s_django-15347:latest
sweb.eval.x86_64.pytest-dev_s_pytest-7490:latest
sweb.eval.x86_64.sympy_s_sympy-18532:latest
sweb.eval.x86_64.django_s_django-14997:latest
sweb.eval.x86_64.sympy_s_sympy-24909:latest
sweb.eval.x86_64.django_s_django-13220:latest
sweb.eval.x86_64.sympy_s_sympy-21614:latest
sweb.eval.x86_64.django_s_django-15902:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13497:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13439:latest
sweb.eval.x86_64.scikit-learn_s_scikit-learn-14894:latest
sweb.eval.x86_64.django_s_django-12983:latest
@@ -0,0 +1,31 @@
def print_diff_ignore_order(file1, file2):
with open(file1, 'r') as f1, open(file2, 'r') as f2:
file1_lines = set(f1.readlines())
file2_lines = set(f2.readlines())
only_in_file1 = file1_lines - file2_lines
only_in_file2 = file2_lines - file1_lines
if only_in_file1:
print(f'Lines in {file1} but not in {file2}:')
for line in sorted(only_in_file1):
print(f'- {line.strip()}')
# if only_in_file2:
# print(f"Lines in {file2} but not in {file1}:")
# for line in sorted(only_in_file2):
# print(f"+ {line.strip()}")
if not only_in_file1 and not only_in_file2:
print('The files have the same content (ignoring line order).')
if __name__ == '__main__':
# Usage
lite1 = 'all-swebench-lite-instance-images.txt' # Replace with the path to your first file
lite2 = '../../swe_bench/scripts/docker/all-swebench-lite-instance-images.txt' # Replace with the path to your second file
print_diff_ignore_order(lite1, lite2)
full1 = 'all-swebench-full-instance-images.txt' # Replace with the path to your first file
full2 = '../../swe_bench/scripts/docker/all-swebench-full-instance-images.txt' # Replace with the path to your second file
print_diff_ignore_order(full1, full2)
@@ -0,0 +1,48 @@
#!/bin/bash
# Script will delete all repositories and tags in your Docker Hub account
set -e
# Set username and password from command-line arguments
UNAME=$1
UPASS=$2
# Get token to interact with Docker Hub
TOKEN=$(curl -s -H "Content-Type: application/json" -X POST -d '{"username": "'${UNAME}'", "password": "'${UPASS}'"}' https://hub.docker.com/v2/users/login/ | jq -r .token)
# Ensure token retrieval was successful
if [[ -z "$TOKEN" ]]; then
echo "Failed to obtain authentication token. Please check your credentials."
exit 1
fi
# Get list of repositories for that user account
echo "Listing repositories in Docker Hub account '${UNAME}':"
REPO_LIST=$(curl -s -H "Authorization: JWT ${TOKEN}" "https://hub.docker.com/v2/repositories/${UNAME}/?page_size=10000" | jq -r '.results|.[]|.name')
if [[ -z "$REPO_LIST" ]]; then
echo "No repositories found for user '${UNAME}' or failed to fetch repositories."
exit 1
fi
# Loop through each repository and delete its tags and the repository itself
for rep in ${REPO_LIST}; do
echo "Processing repository: ${UNAME}/${rep}"
# Get all tags for the repository
IMAGES=$(curl -s -H "Authorization: JWT ${TOKEN}" "https://hub.docker.com/v2/repositories/${UNAME}/${rep}/tags/?page_size=100")
IMAGE_TAGS=$(echo $IMAGES | jq -r '.results|.[]|.name')
# Delete each tag
for tag in ${IMAGE_TAGS}; do
echo "Deleting tag: ${UNAME}/${rep}:${tag}"
curl -s -X DELETE -H "Authorization: JWT ${TOKEN}" "https://hub.docker.com/v2/repositories/${UNAME}/${rep}/tags/${tag}/"
done
# Delete the repository itself
echo "Deleting repository: ${UNAME}/${rep}"
curl -s -X DELETE -H "Authorization: JWT ${TOKEN}" "https://hub.docker.com/v2/repositories/${UNAME}/${rep}/" || {
echo "Failed to delete repository '${UNAME}/${rep}'. Please check permissions or API limits."
}
sleep 1
done
echo "Script execution completed."
@@ -0,0 +1,18 @@
from datasets import load_dataset
def dataset_to_txt(dataset, txt_file, split='test'):
with open(txt_file, 'w') as f:
for datum in dataset[split]:
instance_id = datum['instance_id'].replace('__', '_s_')
f.write(f'sweb.eval.x86_64.{instance_id}:latest\n')
if __name__ == '__main__':
# Load the private dataset
dataset = load_dataset('kjain14/testgeneval')
dataset_lite = load_dataset('kjain14/testgenevallite')
dataset_to_txt(dataset_lite, 'all-swebench-lite-instance-images.txt', lite=True)
dataset_to_txt(dataset, 'all-swebench-full-instance-images.txt')
@@ -0,0 +1,173 @@
import argparse
import copy
import difflib
import json
import os
import traceback
def insert_line_in_string(input_string, new_str, insert_line):
"""
Inserts a new line into a string at the specified line number.
:param input_string: The original string.
:param new_str: The string to insert.
:param insert_line: The line number at which to insert (1-based index).
:return: The modified string.
"""
file_text = input_string.expandtabs()
new_str = new_str.expandtabs()
file_text_lines = file_text.split('\n')
new_str_lines = new_str.split('\n')
new_file_text_lines = (
file_text_lines[:insert_line] + new_str_lines + file_text_lines[insert_line:]
)
return '\n'.join(new_file_text_lines)
def print_string_diff(original, modified):
"""
Prints the differences between two strings line by line.
:param original: The original string.
:param modified: The modified string.
"""
original_lines = original.splitlines(keepends=True)
modified_lines = modified.splitlines(keepends=True)
diff = difflib.unified_diff(
original_lines,
modified_lines,
fromfile='original',
tofile='modified',
lineterm='',
)
print(''.join(diff))
def parse_json_files(root_dir, output_dir, metadata_objs, preds_objs):
final_output = {i: [] for i in range(25)}
for subdir in sorted(os.listdir(root_dir)): # Sorting ensures consistent order
subdir_path = os.path.join(root_dir, subdir)
# subdir_instance = subdir.rsplit('-', 1)[0]
metadata = metadata_objs[subdir]
orig_test_suite = metadata['test_result']['test_suite']
if os.path.isdir(subdir_path): # Check if it's a directory
print(f'Processing subdirectory: {subdir}')
# Now loop through the JSON files in this subdirectory
i = 0
test_suite = preds_objs[subdir] if subdir in preds_objs else ''
for file in sorted(
os.listdir(subdir_path)
): # Sorting ensures consistent order
metadata_copy = copy.deepcopy(metadata)
if file.endswith('.json'): # Check for JSON files
file_path = os.path.join(subdir_path, file)
try:
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f) # Load JSON data
try:
tool_calls = data['response']['choices'][0]['message'][
'tool_calls'
]
if tool_calls is not None:
for tool_call in tool_calls:
tool_call_dict = eval(
tool_call['function']['arguments']
)
if (
tool_call_dict is not None
and tool_call_dict != {}
):
command = tool_call_dict['command']
if command == 'create':
test_suite = tool_call_dict['file_text']
if (
command != 'str_replace'
and command != 'insert'
and 'coverage' not in command
):
print(command)
if command == 'insert':
test_suite_new = insert_line_in_string(
test_suite,
tool_call_dict['new_str'],
tool_call_dict['insert_line'],
)
test_suite = test_suite_new
if command == 'str_replace':
if (
test_suite.count(
tool_call_dict['old_str']
)
== 1
):
test_suite_new = test_suite.replace(
tool_call_dict['old_str'],
tool_call_dict['new_str'],
)
else:
continue
test_suite = test_suite_new
except Exception:
print(traceback.format_exc())
continue
metadata_copy['test_result']['test_suite'] = test_suite
if i < 25:
final_output[i].append(metadata_copy)
i += 1
except Exception as e:
print(traceback.format_exc())
print(f' Error loading {file_path}: {e}')
for j in range(i, 24):
final_output[j].append(metadata_copy)
metadata_orig = copy.deepcopy(metadata)
metadata_orig['test_result']['test_suite'] = orig_test_suite
final_output[24].append(metadata_orig)
for i in range(25):
output_file = os.path.join(output_dir, f'output_{i}.jsonl')
with open(output_file, 'w') as f:
for metadata in final_output[i]:
f.write(json.dumps(metadata) + '\n')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Parse JSON file')
parser.add_argument('--root_dir', type=str, help='Root directory', required=True)
parser.add_argument(
'--output_dir', type=str, help='Output directory', required=True
)
parser.add_argument(
'--starting_preds_file', type=str, help='Starting predictions', default=None
)
args = parser.parse_args()
output_file = os.path.join(args.output_dir, 'output.jsonl')
metadata_objs = {}
with open(output_file, 'r') as f:
content = f.readlines()
for line in content:
metadata = json.loads(line)
metadata_objs[metadata['instance_id']] = metadata
starting_preds_file = args.starting_preds_file
preds_objs = {}
if starting_preds_file is not None:
with open(starting_preds_file, 'r') as f:
content = f.readlines()
for line in content:
pred = json.loads(line)
preds_objs[pred['id']] = pred['preds']['full'][0]
parse_json_files(args.root_dir, args.output_dir, metadata_objs, preds_objs)
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
import argparse
import pandas as pd
parser = argparse.ArgumentParser(
description='Compare two TestGenEval output JSONL files and print the resolved diff'
)
parser.add_argument('input_file_1', type=str)
parser.add_argument('input_file_2', type=str)
args = parser.parse_args()
df1 = pd.read_json(args.input_file_1, orient='records', lines=True)
df2 = pd.read_json(args.input_file_2, orient='records', lines=True)
# Get the intersection of the ids
df = pd.merge(df1, df2, on='id', how='inner')
def _get_coverage(report):
if report is None:
return False
if isinstance(report, float):
return False
else:
return report.get('test_pass', False)
df['test_pass_x'] = df['test_pass_x'].apply(_get_coverage)
df['test_pass_y'] = df['test_pass_y'].apply(_get_coverage)
df['diff'] = df.apply(lambda x: x['test_pass_x'] != x['test_pass_y'], axis=1)
df_diff = df[df['diff']].sort_values(
by=['test_pass_x', 'test_pass_y'], ascending=[False, False]
)
# skip if any of the pass is nan, which means one of the eval is not finished yet
df_diff = df_diff[df_diff['test_pass_x'].notna() & df_diff['test_pass_y'].notna()]
print(f'X={args.input_file_1}')
print(f'Y={args.input_file_2}')
print(f'# diff={df_diff.shape[0]}')
df_diff = df_diff[['id', 'test_pass_x', 'test_pass_y', 'report_x', 'report_y']]
# x pass but y not
print('-' * 100)
df_diff_x_only = df_diff[df_diff['test_pass_x'] & ~df_diff['test_pass_y']].sort_values(
by='id'
)
print(f'# x pass but y not={df_diff_x_only.shape[0]}')
print(df_diff_x_only[['id', 'report_x', 'report_y']])
# y pass but x not
print('-' * 100)
df_diff_y_only = df_diff[~df_diff['test_pass_x'] & df_diff['test_pass_y']].sort_values(
by='id'
)
print(f'# y pass but x not={df_diff_y_only.shape[0]}')
print(df_diff_y_only[['id', 'report_x', 'report_y']])
# get instance_id from df_diff_y_only
print('-' * 100)
print('Instances that x pass but y not:')
print(df_diff_x_only['id'].tolist())
print('-' * 100)
print('Instances that y pass but x not:')
print(df_diff_y_only['id'].tolist())
@@ -0,0 +1,28 @@
#!/bin/bash
FOLDER_PATH=$1
NEW_FOLDER_PATH=${FOLDER_PATH}.swebench_submission
mkdir -p $NEW_FOLDER_PATH
# Build all_preds.jsonl
poetry run python evaluation/testgeneval/scripts/eval/convert_oh_output_to_swe_json.py $FOLDER_PATH/output.jsonl
mv $FOLDER_PATH/output.swebench.jsonl $NEW_FOLDER_PATH/all_preds.jsonl
# Build trajs/
mkdir -p $NEW_FOLDER_PATH/trajs
for instance_dir in $FOLDER_PATH/llm_completions/*/; do
instance_id=$(basename "$instance_dir")
latest_json=$(ls -t "$instance_dir"/*.json | head -n1)
if [ -n "$latest_json" ]; then
cat "$latest_json" | jq -r '.messages' > "$NEW_FOLDER_PATH/trajs/$instance_id.json"
fi
done
# Build logs/
# check if $FOLDER_PATH/eval_outputs exists, if so copy over - else raise error
if [ -d "$FOLDER_PATH/eval_outputs" ]; then
cp -r $FOLDER_PATH/eval_outputs $NEW_FOLDER_PATH/logs
else
echo "Error: $FOLDER_PATH/eval_outputs does not exist. You should run the local docker eval_infer.sh first."
exit 1
fi
@@ -0,0 +1,91 @@
#!/usr/bin/env python3
"""Convert OpenHands output to a readable markdown format for visualization."""
import argparse
import json
import os
import pandas as pd
from tqdm import tqdm
from evaluation.testgeneval.eval_infer import process_test_suite
from openhands.events.serialization import event_from_dict
tqdm.pandas()
parser = argparse.ArgumentParser()
parser.add_argument('oh_output_file', type=str)
args = parser.parse_args()
output_md_folder = args.oh_output_file.replace('.jsonl', '.viz')
print(f'Converting {args.oh_output_file} to markdown files in {output_md_folder}')
oh_format = pd.read_json(args.oh_output_file, orient='records', lines=True)
# model name is the folder name of oh_output_file
model_name = os.path.basename(os.path.dirname(args.oh_output_file))
def convert_history_to_str(history):
ret = ''
separator = '\n\n' + '-' * 100 + '\n'
for i, event in enumerate(history):
if i != 0:
ret += separator
if isinstance(event, list):
# "event" is a legacy pair of (action, observation)
event_obj = event_from_dict(event[0])
ret += f'## {i+1}| {event_obj.__class__.__name__}\n\n'
ret += str(event_obj)
ret += separator
event_obj = event_from_dict(event[1])
ret += f'## {i+1}| {event_obj.__class__.__name__}\n\n'
ret += str(event_obj)
else:
# "event" is a single event
event_obj = event_from_dict(event)
ret += f'## {i+1}| {event_obj.__class__.__name__}\n\n'
ret += str(event_obj)
return ret
def write_row_to_md_file(row):
if 'test_suite' in row:
test_suite = row['test_suite']
elif 'test_result' in row and 'test_suite' in row['test_result']:
test_suite = row['test_result']['test_suite']
else:
raise ValueError(f'Row {row} does not have a test_suite')
if 'report' in row:
coverage = row['report'].get('coverage', 0)
mutation = row['report'].get('mutation_score', 0)
else:
coverage = None
mutation = None
id = row['id']
filename = f'{id}.md'
os.makedirs(output_md_folder, exist_ok=True)
filepath = os.path.join(output_md_folder, filename)
with open(filepath, 'w') as f:
f.write(f'# {id} (coverage: {coverage})\n')
f.write(f'# {id} (mutation score: {mutation})\n')
# MetaData
f.write('## MetaData\n')
f.write('```json\n')
f.write(json.dumps(row['metadata'], indent=2))
f.write('\n```\n')
# Trajectory
f.write('## History\n')
f.write(convert_history_to_str(row['history']))
f.write('## Test Suite\n')
f.write(f'{test_suite}\n')
oh_format.progress_apply(write_row_to_md_file, axis=1)
@@ -0,0 +1,35 @@
import argparse
import os
import pandas as pd
from evaluation.swe_bench.eval_infer import process_git_patch
parser = argparse.ArgumentParser()
parser.add_argument('oh_output_file', type=str)
args = parser.parse_args()
output_filepath = args.oh_output_file.replace('.jsonl', '.swebench.jsonl')
print(f'Converting {args.oh_output_file} to {output_filepath}')
oh_format = pd.read_json(args.oh_output_file, orient='records', lines=True)
# model name is the folder name of oh_output_file
model_name = os.path.basename(os.path.dirname(args.oh_output_file))
def convert_row_to_swebench_format(row):
if 'git_patch' in row:
model_patch = row['git_patch']
elif 'test_result' in row and 'git_patch' in row['test_result']:
model_patch = row['test_result']['git_patch']
else:
raise ValueError(f'Row {row} does not have a git_patch')
return {
'instance_id': row['instance_id'],
'model_patch': process_git_patch(model_patch),
'model_name_or_path': model_name,
}
swebench_format = oh_format.apply(convert_row_to_swebench_format, axis=1)
swebench_format.to_json(output_filepath, lines=True, orient='records')
@@ -0,0 +1,27 @@
import argparse
import pandas as pd
from datasets import load_dataset
parser = argparse.ArgumentParser()
parser.add_argument('output_filepath', type=str, help='Path to save the output file')
parser.add_argument(
'--dataset_name',
type=str,
help='Name of the dataset to download',
default='kjain14/testgeneval',
)
parser.add_argument('--split', type=str, help='Split to download', default='test')
args = parser.parse_args()
dataset = load_dataset(args.dataset_name, split=args.split)
output_filepath = args.output_filepath
print(
f'Downloading gold test suites from {args.dataset_name} (split: {args.split}) to {output_filepath}'
)
test_suites = [
{'instance_id': row['instance_id'], 'test_suite': row['test_src']} for row in dataset
]
print(f'{len(test_suites)} test suites loaded')
pd.DataFrame(test_suites).to_json(output_filepath, lines=True, orient='records')
print(f'Test suites saved to {output_filepath}')
@@ -0,0 +1,122 @@
#!/usr/bin/env python3
import argparse
import json
from collections import Counter
from openhands.events.serialization import event_from_dict
from openhands.events.utils import get_pairs_from_events
ERROR_KEYWORDS = [
'Agent encountered an error while processing the last action',
'APIError',
'Action execution failed',
]
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('output_file', type=str, help='The file to summarize')
args = parser.parse_args()
with open(args.output_file, 'r') as file:
lines = file.readlines()
num_lines = len(lines)
num_error_lines = 0
num_agent_stuck_in_loop = 0
coverage = 0
mutation_score = 0
num_empty_suite = 0
error_counter = Counter()
main_agent_cost = []
editor_cost = []
num_turns = []
for line in lines:
_d = json.loads(line)
# Cost
costs = _d['metrics'].get('costs', [])
_cur_main_agent_cost = 0
_cur_editor_cost = 0
for cost in costs:
if isinstance(cost, float):
# backward compatible
_cur_main_agent_cost += cost
else:
if 'draft_editor' in cost['model']:
_cur_editor_cost += cost['cost']
else:
_cur_main_agent_cost += cost['cost']
main_agent_cost.append(_cur_main_agent_cost)
editor_cost.append(_cur_editor_cost)
# Turn status
history = _d.get('history', [])
events = [event_from_dict(event) for event in history]
pairs = get_pairs_from_events(events)
num_turns.append(len(pairs))
# Suite & resolve status
suite = _d.get('test_result', {}).get('test_suite', '')
if suite == '':
num_empty_suite += 1
continue
report = _d.get('report', {}) or {}
coverage += report.get('coverage', 0)
mutation_score += report.get('mutation_score', 0)
# Error
error = _d.get('error', None)
if error is not None and isinstance(error, str):
agent_stuck_in_loop = 'Agent got stuck in a loop' in error
contains_error = bool(error) and not agent_stuck_in_loop
if agent_stuck_in_loop:
error_counter['Agent got stuck in a loop'] += 1
num_agent_stuck_in_loop += 1
elif contains_error:
error_counter[error] += 1
continue
for keyword in ERROR_KEYWORDS:
if keyword in line:
error_counter[keyword] += 1
num_error_lines += 1
break
# print the error counter (with percentage)
print(
f'Average coverage for {num_lines} ({coverage / num_lines * 100:.2f}%)'
)
print(
f'Average mutation score for {num_lines} ({mutation_score / num_lines * 100:.2f}%)'
)
print(
f'Number of empty suite: {num_empty_suite} / {num_lines} ({num_empty_suite / num_lines * 100:.2f}%)'
)
print(
f'Number of error lines: {num_error_lines} / {num_lines} ({num_error_lines / num_lines * 100:.2f}%)'
)
print(
f'Number of agent stuck in loop: {num_agent_stuck_in_loop} / {num_lines} ({num_agent_stuck_in_loop / num_lines * 100:.2f}%)'
)
assert len(num_turns) == num_lines
assert len(main_agent_cost) == num_lines
assert len(editor_cost) == num_lines
print('## Statistics')
print(f'Avg. num of turns per instance: {sum(num_turns) / num_lines:.2f}')
print(f'Avg. agent cost per instance: {sum(main_agent_cost) / num_lines:.2f} USD')
print(f'Avg. editor cost per instance: {sum(editor_cost) / num_lines:.2f} USD')
print(
f'Avg. total cost per instance: {(sum(main_agent_cost) + sum(editor_cost)) / num_lines:.2f} USD'
)
print('## Detailed error breakdown:')
for error, count in error_counter.items():
print(f'{error}: {count} ({count / num_lines * 100:.2f}%)')
+53
View File
@@ -0,0 +1,53 @@
#!/bin/bash
set -eo pipefail
INPUT_FILE=$1
NUM_WORKERS=$2
DATASET=$3
SPLIT=$4
SKIP_MUTATION=$5
if [ -z "$INPUT_FILE" ]; then
echo "INPUT_FILE not specified (should be a path to a jsonl file)"
exit 1
fi
if [ -z "$DATASET" ]; then
echo "DATASET not specified, use default kjain14/testgenevallite"
DATASET="kjain14/testgenevallite"
fi
if [ -z "$SPLIT" ]; then
echo "SPLIT not specified, use default test"
SPLIT="test"
fi
if [ -z "$NUM_WORKERS" ]; then
echo "NUM_WORKERS not specified, use default 1"
NUM_WORKERS=1
fi
echo "... Evaluating on $INPUT_FILE ..."
COMMAND="poetry run python evaluation/benchmarks/testgeneval/eval_infer.py \
--eval-num-workers $NUM_WORKERS \
--input-file $INPUT_FILE \
--dataset $DATASET \
--split $SPLIT"
if [ "$SKIP_MUTATION" == "true" ]; then
echo "Skipping mutation evaluation"
COMMAND="$COMMAND --skip_mutation"
fi
if [ -n "$EVAL_LIMIT" ]; then
echo "EVAL_LIMIT: $EVAL_LIMIT"
COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
fi
echo $COMMAND
# Run the command
eval $COMMAND
# update the output with evaluation results
# poetry run python evaluation/benchmarks/testgeneval/scripts/eval/update_output_with_eval.py $INPUT_FILE
+122
View File
@@ -0,0 +1,122 @@
#!/bin/bash
set -eo pipefail
source "evaluation/utils/version_control.sh"
MODEL_CONFIG=$1
COMMIT_HASH=$2
AGENT=$3
EVAL_LIMIT=$4
MAX_ITER=$5
NUM_WORKERS=$6
DATASET=$7
SPLIT=$8
N_RUNS=$9
ZERO_SHOT_PATH=${10} # New argument for zero-shot path
if [ -z "$NUM_WORKERS" ]; then
NUM_WORKERS=1
echo "Number of workers not specified, use default $NUM_WORKERS"
fi
checkout_eval_branch
if [ -z "$AGENT" ]; then
echo "Agent not specified, use default CodeActAgent"
AGENT="CodeActAgent"
fi
if [ -z "$MAX_ITER" ]; then
echo "MAX_ITER not specified, use default 100"
MAX_ITER=100
fi
if [ -z "$USE_INSTANCE_IMAGE" ]; then
echo "USE_INSTANCE_IMAGE not specified, use default true"
USE_INSTANCE_IMAGE=true
fi
if [ -z "$RUN_WITH_BROWSING" ]; then
echo "RUN_WITH_BROWSING not specified, use default false"
RUN_WITH_BROWSING=false
fi
if [ -z "$DATASET" ]; then
echo "DATASET not specified, use default princeton-nlp/SWE-bench_Lite"
DATASET="princeton-nlp/SWE-bench_Lite"
fi
if [ -z "$SPLIT" ]; then
echo "SPLIT not specified, use default test"
SPLIT="test"
fi
export USE_INSTANCE_IMAGE=$USE_INSTANCE_IMAGE
echo "USE_INSTANCE_IMAGE: $USE_INSTANCE_IMAGE"
export RUN_WITH_BROWSING=$RUN_WITH_BROWSING
echo "RUN_WITH_BROWSING: $RUN_WITH_BROWSING"
get_openhands_version
echo "AGENT: $AGENT"
echo "OPENHANDS_VERSION: $OPENHANDS_VERSION"
echo "MODEL_CONFIG: $MODEL_CONFIG"
echo "DATASET: $DATASET"
echo "SPLIT: $SPLIT"
# Default to NOT use Hint
if [ -z "$USE_HINT_TEXT" ]; then
export USE_HINT_TEXT=false
fi
echo "USE_HINT_TEXT: $USE_HINT_TEXT"
EVAL_NOTE="$OPENHANDS_VERSION"
# if not using Hint, add -no-hint to the eval note
if [ "$USE_HINT_TEXT" = false ]; then
EVAL_NOTE="$EVAL_NOTE-no-hint"
fi
if [ "$RUN_WITH_BROWSING" = true ]; then
EVAL_NOTE="$EVAL_NOTE-with-browsing"
fi
if [ -n "$EXP_NAME" ]; then
EVAL_NOTE="$EVAL_NOTE-$EXP_NAME"
fi
function run_eval() {
local eval_note=$1
COMMAND="poetry run python evaluation/benchmarks/testgeneval/run_infer.py \
--agent-cls $AGENT \
--llm-config $MODEL_CONFIG \
--max-iterations $MAX_ITER \
--eval-num-workers $NUM_WORKERS \
--eval-note $eval_note \
--dataset $DATASET \
--split $SPLIT"
if [ -n "$EVAL_LIMIT" ]; then
echo "EVAL_LIMIT: $EVAL_LIMIT"
COMMAND="$COMMAND --eval-n-limit $EVAL_LIMIT"
fi
if [ -n "$ZERO_SHOT_PATH" ]; then
echo "ZERO_SHOT_PATH: $ZERO_SHOT_PATH"
COMMAND="$COMMAND --testfile_start --zero_shot_path $ZERO_SHOT_PATH"
fi
eval $COMMAND
}
unset SANDBOX_ENV_GITHUB_TOKEN # prevent the agent from using the github token to push
if [ -z "$N_RUNS" ]; then
N_RUNS=1
echo "N_RUNS not specified, use default $N_RUNS"
fi
for i in $(seq 1 $N_RUNS); do
current_eval_note="$EVAL_NOTE-run_$i"
echo "EVAL_NOTE: $current_eval_note"
run_eval $current_eval_note
done
checkout_original_branch
@@ -0,0 +1,40 @@
#!/bin/bash
source ~/.bashrc
SWEUTIL_DIR=/swe_util
# FIXME: Cannot read SWE_INSTANCE_ID from the environment variable
# SWE_INSTANCE_ID=django__django-11099
if [ -z "$SWE_INSTANCE_ID" ]; then
echo "Error: SWE_INSTANCE_ID is not set." >&2
exit 1
fi
# Read the swe-bench-test-lite.json file and extract the required item based on instance_id
item=$(jq --arg INSTANCE_ID "$SWE_INSTANCE_ID" '.[] | select(.instance_id == $INSTANCE_ID)' $SWEUTIL_DIR/eval_data/instances/swe-bench-instance.json)
if [[ -z "$item" ]]; then
echo "No item found for the provided instance ID."
exit 1
fi
WORKSPACE_NAME=$(echo "$item" | jq -r '(.repo | tostring) + "__" + (.version | tostring) | gsub("/"; "__")')
echo "WORKSPACE_NAME: $WORKSPACE_NAME"
# Clear the workspace
if [ -d /workspace ]; then
rm -rf /workspace/*
else
mkdir /workspace
fi
# Copy repo to workspace
if [ -d /workspace/$WORKSPACE_NAME ]; then
rm -rf /workspace/$WORKSPACE_NAME
fi
mkdir -p /workspace
ln -s /testbed /workspace/$WORKSPACE_NAME
# Activate instance-specific environment
. /opt/miniconda3/etc/profile.d/conda.sh
conda activate testbed
@@ -0,0 +1,27 @@
#!/bin/bash
set -e
EVAL_WORKSPACE="evaluation/swe_bench/eval_workspace"
mkdir -p $EVAL_WORKSPACE
# 1. Prepare REPO
echo "==== Prepare SWE-bench repo ===="
OH_SWE_BENCH_REPO_PATH="https://github.com/All-Hands-AI/SWE-bench.git"
OH_SWE_BENCH_REPO_BRANCH="eval"
git clone -b $OH_SWE_BENCH_REPO_BRANCH $OH_SWE_BENCH_REPO_PATH $EVAL_WORKSPACE/OH-SWE-bench
# 2. Prepare DATA
echo "==== Prepare SWE-bench data ===="
EVAL_IMAGE=ghcr.io/all-hands-ai/eval-swe-bench:builder_with_conda
EVAL_WORKSPACE=$(realpath $EVAL_WORKSPACE)
chmod +x $EVAL_WORKSPACE/OH-SWE-bench/swebench/harness/prepare_data.sh
if [ -d $EVAL_WORKSPACE/eval_data ]; then
rm -r $EVAL_WORKSPACE/eval_data
fi
docker run \
-v $EVAL_WORKSPACE:/workspace \
-w /workspace \
-u $(id -u):$(id -g) \
-e HF_DATASETS_CACHE="/tmp" \
--rm -it $EVAL_IMAGE \
bash -c "cd OH-SWE-bench/swebench/harness && /swe_util/miniforge3/bin/conda run -n swe-bench-eval ./prepare_data.sh && mv eval_data /workspace/"
@@ -0,0 +1,96 @@
#!/bin/bash
set -e
# assert user name is `root`
if [ "$USER" != "root" ]; then
echo "Error: This script is intended to be run by the 'root' user only." >&2
exit 1
fi
source ~/.bashrc
SWEUTIL_DIR=/swe_util
# Create logs directory
LOG_DIR=/openhands/logs
mkdir -p $LOG_DIR && chmod 777 $LOG_DIR
# FIXME: Cannot read SWE_INSTANCE_ID from the environment variable
# SWE_INSTANCE_ID=django__django-11099
if [ -z "$SWE_INSTANCE_ID" ]; then
echo "Error: SWE_INSTANCE_ID is not set." >&2
exit 1
fi
# Read the swe-bench-test-lite.json file and extract the required item based on instance_id
item=$(jq --arg INSTANCE_ID "$SWE_INSTANCE_ID" '.[] | select(.instance_id == $INSTANCE_ID)' $SWEUTIL_DIR/eval_data/instances/swe-bench-test-lite.json)
if [[ -z "$item" ]]; then
echo "No item found for the provided instance ID."
exit 1
fi
CONDA_ENV_NAME=$(echo "$item" | jq -r '.repo + "__" + .version | gsub("/"; "__")')
echo "CONDA_ENV_NAME: $CONDA_ENV_NAME"
SWE_TASK_DIR=/openhands/swe_tasks
mkdir -p $SWE_TASK_DIR
# Dump test_patch to /workspace/test.patch
echo "$item" | jq -r '.test_patch' > $SWE_TASK_DIR/test.patch
# Dump patch to /workspace/gold.patch
echo "$item" | jq -r '.patch' > $SWE_TASK_DIR/gold.patch
# Dump the item to /workspace/instance.json except for the "test_patch" and "patch" fields
echo "$item" | jq 'del(.test_patch, .patch)' > $SWE_TASK_DIR/instance.json
# Clear the workspace
rm -rf /workspace/*
# Copy repo to workspace
if [ -d /workspace/$CONDA_ENV_NAME ]; then
rm -rf /workspace/$CONDA_ENV_NAME
fi
cp -r $SWEUTIL_DIR/eval_data/testbeds/$CONDA_ENV_NAME /workspace
# Reset swe-bench testbed and install the repo
. $SWEUTIL_DIR/miniforge3/etc/profile.d/conda.sh
conda config --set changeps1 False
conda config --append channels conda-forge
conda activate swe-bench-eval
mkdir -p $SWE_TASK_DIR/reset_testbed_temp
mkdir -p $SWE_TASK_DIR/reset_testbed_log_dir
SWE_BENCH_DIR=/swe_util/OH-SWE-bench
output=$(
export PYTHONPATH=$SWE_BENCH_DIR && \
cd $SWE_BENCH_DIR && \
python swebench/harness/reset_swe_env.py \
--swe_bench_tasks $SWEUTIL_DIR/eval_data/instances/swe-bench-test.json \
--temp_dir $SWE_TASK_DIR/reset_testbed_temp \
--testbed /workspace \
--conda_path $SWEUTIL_DIR/miniforge3 \
--instance_id $SWE_INSTANCE_ID \
--log_dir $SWE_TASK_DIR/reset_testbed_log_dir \
--timeout 900 \
--verbose
)
REPO_PATH=$(echo "$output" | awk -F': ' '/repo_path:/ {print $2}')
TEST_CMD=$(echo "$output" | awk -F': ' '/test_cmd:/ {print $2}')
echo "Repo Path: $REPO_PATH"
echo "Test Command: $TEST_CMD"
echo "export SWE_BENCH_DIR=\"$SWE_BENCH_DIR\"" >> ~/.bashrc
echo "export REPO_PATH=\"$REPO_PATH\"" >> ~/.bashrc
echo "export TEST_CMD=\"$TEST_CMD\"" >> ~/.bashrc
if [[ "$REPO_PATH" == "None" ]]; then
echo "Error: Failed to retrieve repository path. Tests may not have passed or output was not as expected." >&2
exit 1
fi
# Activate instance-specific environment
. $SWEUTIL_DIR/miniforge3/etc/profile.d/conda.sh
conda activate $CONDA_ENV_NAME
set +e
@@ -0,0 +1,327 @@
import ast
import re
from typing import List, Tuple
from evaluation.benchmarks.testgeneval.constants import TestStatus
from evaluation.benchmarks.testgeneval.log_parsers import (
MAP_REPO_TO_PARSER,
parse_log_pytest,
)
def indent_text(text, indent_level):
return '\n'.join(
' ' * indent_level + line if line.strip() else line for line in text.split('\n')
)
def extract_preamble_classes_and_functions(code):
class_pattern = re.compile(
r'(?P<decorators>(?:^@[^\r\n]*(?:\r?\n(?:[ \t]+[^\r\n]*|^\)[^\r\n]*)*)*\r?\n)*?)'
r'^class\s+([\w]+)(?:\([^)]*\))?:', # the class line
re.MULTILINE,
)
# Capture methods with or without decorators
method_pattern = re.compile(r'(^(\s*@.*\s*)*^\s*def\s+[\w_]+\(.*\):)', re.MULTILINE)
# Capture functions with or without decorators
function_pattern = re.compile(
r'(?P<decorators>(?:^@[^\r\n]*(?:\r?\n(?:[ \t]+[^\r\n]*|^\)[^\r\n]*)*)*\r?\n)*?)'
r'^def\s+([\w_]+)\(.*\):', # the function line
re.MULTILINE,
)
preamble = ''
classes = []
test_functions = []
current_position = 0
def extract_class_body(code: str, start_index: int) -> Tuple[str, int]:
"""
Extracts the body of a class from the given code starting from the specified index.
Returns the class body and the end index of the class body.
"""
if not code or start_index < 0 or start_index >= len(code):
raise ValueError('Invalid code or start index')
# Split the code into lines
lines = code[start_index:].split('\n')
class_body_lines = []
# Find the starting indentation level of the class definition
class_start_line = lines[0]
start_indent = len(class_start_line) - len(class_start_line.lstrip())
inside_multiline_comment = False
end_index = start_index
for i, line in enumerate(lines[1:], start=1):
stripped_line = line.strip()
current_indent = len(line) - len(line.lstrip())
# Handle multiline comments or docstrings
if stripped_line.startswith('"""') or stripped_line.startswith("'''"):
if inside_multiline_comment:
inside_multiline_comment = False
else:
inside_multiline_comment = True
if not inside_multiline_comment:
# Stop when we reach a line with less indentation than the class definition
if current_indent <= start_indent and stripped_line:
break
# Add lines that are part of the class body
class_body_lines.append(line)
# Update the end index to the current line end
end_index = start_index + len('\n'.join(lines[: i + 1])) + 1
return code[start_index:end_index], end_index
while current_position < len(code):
class_match = class_pattern.search(code, current_position)
method_match = method_pattern.search(code, current_position)
if class_match and (
not method_match or class_match.start() < method_match.start()
):
class_name = class_match.group(0)
class_body, end_idx = extract_class_body(code, class_match.end())
current_position = end_idx
methods = []
class_prefix = class_name
set_prefix = False
for method_match in method_pattern.finditer(class_body):
method_name = method_match.group()
method_start = method_match.start()
if not set_prefix:
class_prefix = class_name + class_body[:method_start]
set_prefix = True
next_method = method_pattern.search(
class_body, method_start + len(method_name)
)
method_body = (
class_body[method_start : next_method.start()]
if next_method
else class_body[method_start:]
)
methods.append((method_name, method_body))
classes.append((class_prefix, methods, class_match.start()))
elif method_match:
function_name = method_match.group(0)
start_idx = method_match.start()
# Extract the current function's indentation level
lines = code[start_idx:].split('\n')
current_indent = len(lines[0]) - len(lines[0].lstrip())
next_function = function_pattern.search(
code, start_idx + len(function_name)
)
while next_function and (
class_match is None or next_function.start() < class_match.start()
):
# Calculate the indentation of the next function
next_function_start = next_function.start()
next_line = code[next_function_start:].split('\n', 1)[0]
next_indent = len(next_line) - len(next_line.lstrip())
# Check if the next function is top-level
if next_indent <= current_indent:
break
# Continue searching for the next top-level function
next_function = function_pattern.search(
code, next_function.start() + len(next_function.group(0))
)
if next_function:
next_function_start = next_function.start()
if class_match and next_function_start > class_match.start():
next_function_start = class_match.start()
function_body = code[start_idx:next_function_start]
else:
function_body = code[start_idx:]
test_functions.append((function_body, start_idx))
current_position = start_idx + len(function_body)
else:
break
if classes and test_functions:
preamble = code[: min(classes[0][2], test_functions[0][1])]
else:
preamble = (
code[: classes[0][2]]
if classes
else code[: test_functions[0][1]]
if test_functions
else code
)
return preamble.strip(), classes, test_functions
def filter_passing_tests(
test_content: str, test_output: str, repo: str
) -> Tuple[str, List[str], List[str]]:
"""
Filter tests based on their execution results.
Returns:
Tuple containing:
- Modified test content with only passing tests
- List of passing test names
- List of failing test names
"""
# Parse test results using appropriate parser
parser = MAP_REPO_TO_PARSER.get(repo, parse_log_pytest)
test_results = parser(test_output)
# Get passing and failing tests
passing_tests = []
failing_tests = []
for test_name, status in test_results.items():
if status == TestStatus.PASSED.value:
passing_tests.append(test_name)
else:
failing_tests.append(test_name)
if not passing_tests:
return '', passing_tests, failing_tests
# Extract test components
preamble, classes, functions = extract_preamble_classes_and_functions(test_content)
# Filter classes to only include passing methods
filtered_classes = []
for class_name, methods, start_idx in classes:
non_fail_methods = []
for method_name, method_body in methods:
# Extract the base method name for matching
method_full_name = (
method_name.split('.')[-1].split('(')[0].strip().split(' ')[-1]
)
# Check if the method name is in failing_tests or if any failing_test is in the method name
if not (
any(method_full_name in failing_test for failing_test in failing_tests)
or any(
failing_test in method_full_name for failing_test in failing_tests
)
):
non_fail_methods.append((method_name, method_body))
if non_fail_methods:
filtered_classes.append((class_name, non_fail_methods, start_idx))
# Filter standalone functions
filtered_functions = []
for func_body, start_idx in functions:
func_name = func_body.split('def ')[1].split('(')[0].strip()
if any(func_name in failing_test for failing_test in failing_tests) or any(
failing_test in func_name for failing_test in failing_tests
):
continue
filtered_functions.append((func_body, start_idx))
# Reconstruct test content with only passing tests
content_parts = [preamble]
# Add filtered classes
for class_name, methods, _ in filtered_classes:
class_content = class_name + '\n'
for _, method_body in methods:
class_content += method_body + '\n'
content_parts.append(class_content)
# Add filtered functions
for func_body, _ in filtered_functions:
content_parts.append(func_body)
return '\n\n'.join(content_parts), passing_tests, failing_tests
def filter_tests(
test_content: str, test_output: str, repo: str
) -> Tuple[str, List[str], List[str]]:
"""
Filter tests using AST parsing to remove failing test functions from the test file.
Non-test functions (e.g. setup or helper methods) and classes (even if all test methods are failing)
are preserved.
If AST processing fails (for example, because the test file cannot be parsed),
this function falls back on the existing regex-based filtering (filter_passing_tests).
Returns:
Tuple containing:
- Modified test content (as a string) containing only passing tests.
- List of passing test names.
- List of failing test names.
"""
try:
# Attempt to parse the test file using the AST.
tree = ast.parse(test_content)
# Parse test results using the appropriate parser.
parser = MAP_REPO_TO_PARSER.get(repo, parse_log_pytest)
test_results = parser(test_output)
passing_tests = [
name
for name, status in test_results.items()
if status == TestStatus.PASSED.value
]
failing_tests = [
name
for name, status in test_results.items()
if status != TestStatus.PASSED.value
]
# Helper function to decide if a test name should be considered failing.
def is_failing(name: str) -> bool:
for ft in failing_tests:
if name in ft or ft in name:
return True
return False
new_body = []
for node in tree.body:
# For top-level function definitions, only filter those that look like tests.
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
if node.name.startswith('test') and is_failing(node.name):
continue
new_body.append(node)
# For classes, filter out failing test methods but preserve other methods (e.g. setup).
elif isinstance(node, ast.ClassDef):
new_class_body = []
for subnode in node.body:
if isinstance(subnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
# Only consider filtering if the method is a test.
qualified_name = f'{node.name}.{subnode.name}'
if is_failing(subnode.name) or is_failing(qualified_name):
continue
new_class_body.append(subnode)
else:
new_class_body.append(subnode)
# Always include the class even if no test methods remain, as it might contain
# setup, teardown, or other necessary logic.
if new_class_body:
node.body = new_class_body
new_body.append(node)
else:
new_body.append(node)
tree.body = new_body
# Reconstruct the source code from the filtered AST.
# (Requires Python 3.9+ for ast.unparse; otherwise an exception will trigger the fallback.)
new_test_content = ast.unparse(tree)
return new_test_content, passing_tests, failing_tests
except Exception:
print('AST processing failed; falling back on regex-based filtering.')
# If AST processing fails for any reason, fall back on the original regex-based filtering.
return filter_passing_tests(test_content, test_output, repo)
@@ -0,0 +1,166 @@
from __future__ import annotations
from dataclasses import dataclass
from evaluation.benchmarks.testgeneval.constants import (
COVERAGE_PREFIX,
KEY_INSTANCE_ID,
MAP_REPO_VERSION_TO_SPECS,
TESTS_FAILED,
TESTS_SUFFIX,
UPDATE_TOX,
TestGenEvalInstance,
)
from evaluation.benchmarks.testgeneval.utils import (
get_test_directives,
)
DIFF_MODIFIED_FILE_REGEX = r'--- a/(.*)'
@dataclass
class TestSpec:
"""
A dataclass that represents a test specification for a single instance of SWE-bench.
"""
instance_id: str
id: str
repo: str
version: str
test_cmd: str
code_file: str
test_file: str
baseline_covs: dict
local_imports: list[str]
test_script_list: list[str]
mutation_script_list: list[str]
@property
def test_script(self):
return (
'\n'.join(['#!/bin/bash', 'set -uo pipefail'] + self.test_script_list)
+ '\n'
)
# Don't exit early because we need to revert tests at the end
@property
def mutation_script(self):
return (
'\n'.join(['#!/bin/bash', 'set -uo pipefail'] + self.mutation_script_list)
+ '\n'
)
# Don't exit early because we need to revert tests at the end
def make_test_setup(specs, env_name, repo_directory, includes_tox=False):
eval_commands = []
if includes_tox:
eval_commands.append(UPDATE_TOX)
eval_commands += [
'source /opt/miniconda3/bin/activate',
f'conda activate {env_name}',
f'cd {repo_directory}',
]
if 'eval_commands' in specs:
eval_commands += specs['eval_commands']
eval_commands += [
f'git config --global --add safe.directory {repo_directory}', # for nonroot user
f'cd {repo_directory}',
# This is just informational, so we have a record
'git status',
'git show',
'source /opt/miniconda3/bin/activate',
f'conda activate {env_name}',
]
if 'install' in specs:
eval_commands.append(specs['install'])
if includes_tox:
eval_commands.append('add_coverage_tox "tox.ini"')
eval_commands.append('[ -f ".coveragerc" ] && rm ".coveragerc"')
return eval_commands
def make_test_script_list(test_cmd, specs, env_name, repo_directory):
"""
Runs the tests.
"""
includes_tox = 'tox' in test_cmd
eval_commands = make_test_setup(specs, env_name, repo_directory, includes_tox)
eval_commands += [
f'{test_cmd} || {{ echo "{TESTS_FAILED}\n{TESTS_SUFFIX}\n" && exit 1; }}',
f'echo "{TESTS_SUFFIX}"\n',
'coverage json -o coverage.json',
f'echo "{COVERAGE_PREFIX}"\n',
'cat coverage.json',
]
return eval_commands
def make_mutation_script_list(specs, env_name, repo_directory, mutation_timeout):
"""
Runs the tests.
"""
eval_commands = make_test_setup(specs, env_name, repo_directory)
eval_commands += [
'cosmic-ray init mutation.toml mutation.sqlite',
f'timeout {mutation_timeout}s cosmic-ray exec mutation.toml mutation.sqlite',
'cr-report mutation.sqlite',
'cr-rate mutation.sqlite --estimate --confidence 95.0',
]
return eval_commands
def make_test_spec(
instance: TestGenEvalInstance, mutation_timeout: int, buffer: int
) -> TestSpec:
if isinstance(instance, TestSpec):
return instance
instance_id = instance[KEY_INSTANCE_ID]
id = instance['id']
repo = instance['repo']
version = instance['version']
baseline_covs = instance['baseline_covs']
code_file = instance['code_file']
test_file = instance['test_file']
local_imports = instance['local_imports']
env_name = 'testbed'
repo_directory = f'/{env_name}'
specs = MAP_REPO_VERSION_TO_SPECS[repo][version]
test_cmd = ' '.join(
[
MAP_REPO_VERSION_TO_SPECS[instance['repo']][instance['version']][
'test_cmd'
],
*get_test_directives(instance),
]
)
test_script_list = make_test_script_list(test_cmd, specs, env_name, repo_directory)
mutation_script_list = make_mutation_script_list(
specs, env_name, repo_directory, mutation_timeout - buffer
)
return TestSpec(
instance_id=instance_id,
id=id,
repo=repo,
test_script_list=test_script_list,
test_cmd=test_cmd,
local_imports=local_imports,
mutation_script_list=mutation_script_list,
code_file=code_file,
test_file=test_file,
baseline_covs=baseline_covs,
version=version,
)
@@ -0,0 +1,73 @@
import json
from pathlib import Path
from typing import cast
from datasets import Dataset, load_dataset
from evaluation.benchmarks.testgeneval.constants import (
KEY_INSTANCE_ID,
TestGenEvalInstance,
)
def get_test_directives(instance: TestGenEvalInstance) -> list:
"""
Get test directives from the test_patch of a task instance
Args:
instance (dict): task instance
Returns:
directives (list): List of test directives
"""
# For seq2seq code repos, testing command is fixed
if instance['repo'] == 'swe-bench/humaneval':
return ['test.py']
# Get test directives from test patch and remove non-test files
directives = [f"/testbed/{instance['test_file']}"]
# For Django tests, remove extension + "tests/" prefix and convert slashes to dots (module referencing)
if instance['repo'] == 'django/django':
directives = [instance['test_file']]
directives_transformed = []
for d in directives:
d = d[: -len('.py')] if d.endswith('.py') else d
d = d[len('tests/') :] if d.startswith('tests/') else d
d = d.replace('/', '.')
directives_transformed.append(d)
directives = directives_transformed
return directives
def load_testgeneval_dataset(
name='kjain14/testgeneval', split='test', ids=None
) -> list[TestGenEvalInstance]:
"""
Load SWE-bench dataset from Hugging Face Datasets or local .json/.jsonl file
"""
# check that all instance IDs are in the dataset
if ids:
ids = set(ids)
# Load from local .json/.jsonl file
if name.endswith('.json') or name.endswith('.jsonl'):
dataset = json.loads(Path(name).read_text())
dataset_ids = {instance[KEY_INSTANCE_ID] for instance in dataset}
else:
# Load from Hugging Face Datasets
if name.lower() in {'testgeneval'}:
name = 'kjain14/testgeneval'
elif name.lower() in {'testgeneval-lite', 'testgenevallite', 'lite'}:
name = 'kjain14/testgenevallite'
dataset = cast(Dataset, load_dataset(name, split=split))
dataset_ids = {instance['id'] for instance in dataset}
if ids:
if ids - dataset_ids:
raise ValueError(
(
"Some instance IDs not found in dataset!"
f"\nMissing IDs:\n{' '.join(ids - dataset_ids)}"
)
)
dataset = [instance for instance in dataset if instance['id'] in ids]
return [cast(TestGenEvalInstance, instance) for instance in dataset]
@@ -34,7 +34,6 @@ from openhands.utils.async_utils import call_async_from_sync
FAKE_RESPONSES = {
'CodeActAgent': fake_user_response,
'DelegatorAgent': fake_user_response,
'VisualBrowsingAgent': fake_user_response,
}
-2
View File
@@ -6,7 +6,6 @@ load_dotenv()
from openhands.agenthub import ( # noqa: E402
browsing_agent,
codeact_agent,
delegator_agent,
dummy_agent,
visualbrowsing_agent,
)
@@ -15,7 +14,6 @@ from openhands.controller.agent import Agent # noqa: E402
__all__ = [
'Agent',
'codeact_agent',
'delegator_agent',
'dummy_agent',
'browsing_agent',
'visualbrowsing_agent',
@@ -70,9 +70,10 @@ class CodeActAgent(Agent):
codeact_enable_browsing=self.config.codeact_enable_browsing,
codeact_enable_jupyter=self.config.codeact_enable_jupyter,
codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
llm=self.llm,
)
logger.debug(
f'TOOLS loaded for CodeActAgent: {', '.join([tool.get('function').get('name') for tool in self.tools])}'
f"TOOLS loaded for CodeActAgent: {', '.join([tool.get('function').get('name') for tool in self.tools])}"
)
self.prompt_manager = PromptManager(
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
@@ -12,13 +12,13 @@ from litellm import (
from openhands.agenthub.codeact_agent.tools import (
BrowserTool,
CmdRunTool,
FinishTool,
IPythonTool,
LLMBasedFileEditTool,
StrReplaceEditorTool,
ThinkTool,
WebReadTool,
create_cmd_run_tool,
create_str_replace_editor_tool,
)
from openhands.core.exceptions import (
FunctionCallNotExistsError,
@@ -39,6 +39,7 @@ from openhands.events.action import (
)
from openhands.events.event import FileEditSource, FileReadSource
from openhands.events.tool import ToolCallMetadata
from openhands.llm import LLM
def combine_thought(action: Action, thought: str) -> Action:
@@ -80,7 +81,7 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
# CmdRunTool (Bash)
# ================================================
if tool_call.function.name == CmdRunTool['function']['name']:
if tool_call.function.name == create_cmd_run_tool()['function']['name']:
if 'command' not in arguments:
raise FunctionCallValidationError(
f'Missing required argument "command" in tool call {tool_call.function.name}'
@@ -131,7 +132,10 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
start=arguments.get('start', 1),
end=arguments.get('end', -1),
)
elif tool_call.function.name == StrReplaceEditorTool['function']['name']:
elif (
tool_call.function.name
== create_str_replace_editor_tool()['function']['name']
):
if 'command' not in arguments:
raise FunctionCallValidationError(
f'Missing required argument "command" in tool call {tool_call.function.name}'
@@ -219,8 +223,22 @@ def get_tools(
codeact_enable_browsing: bool = False,
codeact_enable_llm_editor: bool = False,
codeact_enable_jupyter: bool = False,
llm: LLM | None = None,
) -> list[ChatCompletionToolParam]:
tools = [CmdRunTool, ThinkTool, FinishTool]
SIMPLIFIED_TOOL_DESCRIPTION_LLM_SUBSTRS = ['gpt-', 'o3', 'o1']
use_simplified_tool_desc = False
if llm is not None:
use_simplified_tool_desc = any(
model_substr in llm.config.model
for model_substr in SIMPLIFIED_TOOL_DESCRIPTION_LLM_SUBSTRS
)
tools = [
create_cmd_run_tool(use_simplified_description=use_simplified_tool_desc),
ThinkTool,
FinishTool,
]
if codeact_enable_browsing:
tools.append(WebReadTool)
tools.append(BrowserTool)
@@ -229,5 +247,9 @@ def get_tools(
if codeact_enable_llm_editor:
tools.append(LLMBasedFileEditTool)
else:
tools.append(StrReplaceEditorTool)
tools.append(
create_str_replace_editor_tool(
use_simplified_description=use_simplified_tool_desc
)
)
return tools
@@ -1,19 +1,19 @@
from .bash import CmdRunTool
from .bash import create_cmd_run_tool
from .browser import BrowserTool
from .finish import FinishTool
from .ipython import IPythonTool
from .llm_based_edit import LLMBasedFileEditTool
from .str_replace_editor import StrReplaceEditorTool
from .str_replace_editor import create_str_replace_editor_tool
from .think import ThinkTool
from .web_read import WebReadTool
__all__ = [
'BrowserTool',
'CmdRunTool',
'create_cmd_run_tool',
'FinishTool',
'IPythonTool',
'LLMBasedFileEditTool',
'StrReplaceEditorTool',
'create_str_replace_editor_tool',
'WebReadTool',
'ThinkTool',
]
+35 -21
View File
@@ -1,6 +1,6 @@
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
_BASH_DESCRIPTION = """Execute a bash command in the terminal within a persistent shell session.
_DETAILED_BASH_DESCRIPTION = """Execute a bash command in the terminal within a persistent shell session.
### Command Execution
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, use `&&` or `;` to chain them together.
@@ -22,25 +22,39 @@ _BASH_DESCRIPTION = """Execute a bash command in the terminal within a persisten
* Output truncation: If the output exceeds a maximum length, it will be truncated before being returned.
"""
CmdRunTool = ChatCompletionToolParam(
type='function',
function=ChatCompletionToolParamFunctionChunk(
name='execute_bash',
description=_BASH_DESCRIPTION,
parameters={
'type': 'object',
'properties': {
'command': {
'type': 'string',
'description': 'The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process. Note: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together.',
},
'is_input': {
'type': 'string',
'description': 'If True, the command is an input to the running process. If False, the command is a bash command to be executed in the terminal. Default is False.',
'enum': ['true', 'false'],
_SIMPLIFIED_BASH_DESCRIPTION = """Execute a bash command in the terminal.
* Long running commands: For commands that may run indefinitely, it should be run in the background and the output should be redirected to a file, e.g. command = `python3 app.py > server.log 2>&1 &`.
* Interact with running process: If a bash command returns exit code `-1`, this means the process is not yet finished. By setting `is_input` to `true`, the assistant can interact with the running process and send empty `command` to retrieve any additional logs, or send additional text (set `command` to the text) to STDIN of the running process, or send command like `C-c` (Ctrl+C), `C-d` (Ctrl+D), `C-z` (Ctrl+Z) to interrupt the process.
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together."""
def create_cmd_run_tool(
use_simplified_description: bool = False,
) -> ChatCompletionToolParam:
description = (
_SIMPLIFIED_BASH_DESCRIPTION
if use_simplified_description
else _DETAILED_BASH_DESCRIPTION
)
return ChatCompletionToolParam(
type='function',
function=ChatCompletionToolParamFunctionChunk(
name='execute_bash',
description=description,
parameters={
'type': 'object',
'properties': {
'command': {
'type': 'string',
'description': 'The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process. Note: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together.',
},
'is_input': {
'type': 'string',
'description': 'If True, the command is an input to the running process. If False, the command is a bash command to be executed in the terminal. Default is False.',
'enum': ['true', 'false'],
},
},
'required': ['command'],
},
'required': ['command'],
},
),
)
),
)
@@ -1,6 +1,6 @@
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
_DETAILED_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
* State is persistent across command calls and discussions with the user
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
* The `create` command cannot be used if the specified `path` already exists as a file
@@ -31,46 +31,73 @@ CRITICAL REQUIREMENTS FOR USING THIS TOOL:
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.
"""
StrReplaceEditorTool = ChatCompletionToolParam(
type='function',
function=ChatCompletionToolParamFunctionChunk(
name='str_replace_editor',
description=_STR_REPLACE_EDITOR_DESCRIPTION,
parameters={
'type': 'object',
'properties': {
'command': {
'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.',
'enum': ['view', 'create', 'str_replace', 'insert', 'undo_edit'],
'type': 'string',
},
'path': {
'description': 'Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.',
'type': 'string',
},
'file_text': {
'description': 'Required parameter of `create` command, with the content of the file to be created.',
'type': 'string',
},
'old_str': {
'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.',
'type': 'string',
},
'new_str': {
'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.',
'type': 'string',
},
'insert_line': {
'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.',
'type': 'integer',
},
'view_range': {
'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.',
'items': {'type': 'integer'},
'type': 'array',
_SIMPLIFIED_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
* State is persistent across command calls and discussions with the user
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
* The `create` command cannot be used if the specified `path` already exists as a file
* If a `command` generates a long output, it will be truncated and marked with `<response clipped>`
* The `undo_edit` command will revert the last edit made to the file at `path`
Notes for using the `str_replace` command:
* The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces!
* If the `old_str` parameter is not unique in the file, the replacement will not be performed. Make sure to include enough context in `old_str` to make it unique
* The `new_str` parameter should contain the edited lines that should replace the `old_str`
"""
def create_str_replace_editor_tool(
use_simplified_description: bool = False,
) -> ChatCompletionToolParam:
description = (
_SIMPLIFIED_STR_REPLACE_EDITOR_DESCRIPTION
if use_simplified_description
else _DETAILED_STR_REPLACE_EDITOR_DESCRIPTION
)
return ChatCompletionToolParam(
type='function',
function=ChatCompletionToolParamFunctionChunk(
name='str_replace_editor',
description=description,
parameters={
'type': 'object',
'properties': {
'command': {
'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.',
'enum': [
'view',
'create',
'str_replace',
'insert',
'undo_edit',
],
'type': 'string',
},
'path': {
'description': 'Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.',
'type': 'string',
},
'file_text': {
'description': 'Required parameter of `create` command, with the content of the file to be created.',
'type': 'string',
},
'old_str': {
'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.',
'type': 'string',
},
'new_str': {
'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.',
'type': 'string',
},
'insert_line': {
'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.',
'type': 'integer',
},
'view_range': {
'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.',
'items': {'type': 'integer'},
'type': 'array',
},
},
'required': ['command', 'path'],
},
'required': ['command', 'path'],
},
),
)
),
)
@@ -1,4 +0,0 @@
from openhands.agenthub.delegator_agent.agent import DelegatorAgent
from openhands.controller.agent import Agent
Agent.register('DelegatorAgent', DelegatorAgent)
@@ -1,87 +0,0 @@
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.events.action import Action, AgentDelegateAction, AgentFinishAction
from openhands.events.observation import AgentDelegateObservation, Observation
from openhands.llm.llm import LLM
class DelegatorAgent(Agent):
VERSION = '1.0'
"""
The Delegator Agent is responsible for delegating tasks to other agents based on the current task.
"""
current_delegate: str = ''
def __init__(self, llm: LLM, config: AgentConfig):
"""Initialize the Delegator Agent with an LLM
Parameters:
- llm (LLM): The llm to be used by this agent
"""
super().__init__(llm, config)
def step(self, state: State) -> Action:
"""Checks to see if current step is completed, returns AgentFinishAction if True.
Otherwise, delegates the task to the next agent in the pipeline.
Parameters:
- state (State): The current state given the previous actions and observations
Returns:
- AgentFinishAction: If the last state was 'completed', 'verified', or 'abandoned'
- AgentDelegateAction: The next agent to delegate the task to
"""
if self.current_delegate == '':
self.current_delegate = 'study'
task, _ = state.get_current_user_intent()
return AgentDelegateAction(
agent='StudyRepoForTaskAgent', inputs={'task': task}
)
# last observation in history should be from the delegate
last_observation = None
for event in reversed(state.history):
if isinstance(event, Observation):
last_observation = event
break
if not isinstance(last_observation, AgentDelegateObservation):
raise Exception('Last observation is not an AgentDelegateObservation')
goal, _ = state.get_current_user_intent()
if self.current_delegate == 'study':
self.current_delegate = 'coder'
return AgentDelegateAction(
agent='CoderAgent',
inputs={
'task': goal,
'summary': last_observation.outputs['summary'],
},
)
elif self.current_delegate == 'coder':
self.current_delegate = 'verifier'
return AgentDelegateAction(
agent='VerifierAgent',
inputs={
'task': goal,
},
)
elif self.current_delegate == 'verifier':
if (
'completed' in last_observation.outputs
and last_observation.outputs['completed']
):
return AgentFinishAction()
else:
self.current_delegate = 'coder'
return AgentDelegateAction(
agent='CoderAgent',
inputs={
'task': goal,
'summary': last_observation.outputs['summary'],
},
)
else:
raise Exception('Invalid delegate state')
@@ -202,6 +202,7 @@ Note:
tabs = ''
last_obs = None
last_action = None
set_of_marks = None # Initialize set_of_marks to None
if len(state.history) == 1:
# for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
@@ -217,6 +218,9 @@ Note:
# agent has responded, task finished.
return AgentFinishAction(outputs={'content': event.content})
elif isinstance(event, Observation):
# Only process BrowserOutputObservation and skip other observation types
if not isinstance(event, BrowserOutputObservation):
continue
last_obs = event
if len(prev_actions) >= 1: # ignore noop()
+2 -1
View File
@@ -15,6 +15,7 @@ class SandboxConfig(BaseModel):
timeout: The timeout for the default sandbox action execution.
remote_runtime_init_timeout: The timeout for the remote runtime to start.
remote_runtime_api_timeout: The timeout for the remote runtime API requests.
remote_runtime_enable_retries: Whether to enable retries (on recoverable errors like requests.ConnectionError) for the remote runtime API requests.
enable_auto_lint: Whether to enable auto-lint.
use_host_network: Whether to use the host network.
runtime_binding_address: The binding address for the runtime ports. It specifies which network interface on the host machine Docker should bind the runtime ports to.
@@ -53,7 +54,7 @@ class SandboxConfig(BaseModel):
timeout: int = Field(default=120)
remote_runtime_init_timeout: int = Field(default=180)
remote_runtime_api_timeout: int = Field(default=10)
remote_runtime_enable_retries: bool = Field(default=False)
remote_runtime_enable_retries: bool = Field(default=True)
remote_runtime_class: str | None = Field(
default=None
) # can be "None" (default to gvisor) or "sysbox" (support docker inside runtime + more stable)
+1 -1
View File
@@ -240,7 +240,7 @@ class SensitiveDataFilter(logging.Filter):
if (
len(value) > 2
and value != 'default'
and any(s in key_upper for s in ('SECRET', 'KEY', 'CODE', 'TOKEN'))
and any(s in key_upper for s in ('SECRET', '_KEY', '_CODE', '_TOKEN'))
):
sensitive_values.append(value)
+2 -2
View File
@@ -49,8 +49,8 @@ class ObservationTypeSchema(BaseModel):
CONDENSE: str = Field(default='condense')
"""Result of a condensation operation."""
MICROAGENT: str = Field(default='microagent')
"""Result of a microagent retrieval operation."""
RECALL: str = Field(default='recall')
"""Result of a recall operation. This can be the workspace context, a microagent, or other types of information."""
ObservationType = ObservationTypeSchema()
+2 -2
View File
@@ -3,7 +3,7 @@ from openhands.events.observation.agent import (
AgentCondensationObservation,
AgentStateChangedObservation,
AgentThinkObservation,
MicroagentObservation,
RecallObservation,
)
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.commands import (
@@ -42,6 +42,6 @@ __all__ = [
'SuccessObservation',
'UserRejectObservation',
'AgentCondensationObservation',
'MicroagentObservation',
'RecallObservation',
'RecallType',
]
+31 -17
View File
@@ -60,13 +60,13 @@ class MicroagentKnowledge:
@dataclass
class MicroagentObservation(Observation):
class RecallObservation(Observation):
"""The retrieval of content from a microagent or more microagents."""
recall_type: RecallType
observation: str = ObservationType.MICROAGENT
observation: str = ObservationType.RECALL
# environment
# workspace context
repo_name: str = ''
repo_directory: str = ''
repo_instructions: str = ''
@@ -95,22 +95,36 @@ class MicroagentObservation(Observation):
@property
def message(self) -> str:
return self.__str__()
return (
'Added workspace context'
if self.recall_type == RecallType.WORKSPACE_CONTEXT
else 'Added microagent knowledge'
)
def __str__(self) -> str:
# Build a string representation of all fields
fields = [
f'recall_type={self.recall_type}',
f'repo_name={self.repo_name}',
f'repo_instructions={self.repo_instructions[:20]}...',
f'runtime_hosts={self.runtime_hosts}',
f'additional_agent_instructions={self.additional_agent_instructions[:20]}...',
]
# Only include microagent_knowledge if it's not empty
# Build a string representation
fields = []
if self.recall_type == RecallType.WORKSPACE_CONTEXT:
fields.extend(
[
f'recall_type={self.recall_type}',
f'repo_name={self.repo_name}',
f'repo_instructions={self.repo_instructions[:20]}...',
f'runtime_hosts={self.runtime_hosts}',
f'additional_agent_instructions={self.additional_agent_instructions[:20]}...',
]
)
else:
fields.extend(
[
f'recall_type={self.recall_type}',
]
)
if self.microagent_knowledge:
fields.append(
f'microagent_knowledge={", ".join([m.name for m in self.microagent_knowledge])}'
fields.extend(
[
f'microagent_knowledge={", ".join([m.name for m in self.microagent_knowledge])}',
]
)
return f'**MicroagentObservation**\n{", ".join(fields)}'
return f'**RecallObservation**\n{", ".join(fields)}'
+1 -1
View File
@@ -122,7 +122,7 @@ def event_to_dict(event: 'Event') -> dict:
# props is a dict whose values can include a complex object like an instance of a BaseModel subclass
# such as CmdOutputMetadata
# we serialize it along with the rest
# we also handle the Enum conversion for MicroagentObservation
# we also handle the Enum conversion for RecallObservation
d['extras'] = {
k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v))
for k, v in props.items()
@@ -6,7 +6,7 @@ from openhands.events.observation.agent import (
AgentStateChangedObservation,
AgentThinkObservation,
MicroagentKnowledge,
MicroagentObservation,
RecallObservation,
)
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.commands import (
@@ -43,7 +43,7 @@ observations = (
UserRejectObservation,
AgentCondensationObservation,
AgentThinkObservation,
MicroagentObservation,
RecallObservation,
)
OBSERVATION_TYPE_TO_CLASS = {
@@ -114,7 +114,7 @@ def observation_from_dict(observation: dict) -> Observation:
else:
extras['metadata'] = CmdOutputMetadata()
if observation_class is MicroagentObservation:
if observation_class is RecallObservation:
# handle the Enum conversion
if 'recall_type' in extras:
extras['recall_type'] = RecallType(extras['recall_type'])
@@ -5,6 +5,7 @@ from typing import Any
import httpx
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import (
AuthenticationError,
GitService,
@@ -15,7 +16,7 @@ from openhands.integrations.service_types import (
User,
)
from openhands.utils.import_utils import get_impl
from openhands.core.logger import openhands_logger as logger
class GitHubService(GitService):
BASE_URL = 'https://api.github.com'
@@ -25,6 +26,7 @@ class GitHubService(GitService):
def __init__(
self,
user_id: str | None = None,
external_auth_id: str | None = None,
external_auth_token: SecretStr | None = None,
token: SecretStr | None = None,
external_token_manager: bool = False,
+54 -28
View File
@@ -31,7 +31,7 @@ from openhands.events.observation import (
)
from openhands.events.observation.agent import (
MicroagentKnowledge,
MicroagentObservation,
RecallObservation,
)
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.observation import Observation
@@ -52,7 +52,6 @@ class ConversationMemory:
initial_messages: list[Message],
max_message_chars: int | None = None,
vision_is_active: bool = False,
enable_som_visual_browsing: bool = False,
) -> list[Message]:
"""Process state history into a list of messages for the LLM.
@@ -64,11 +63,13 @@ class ConversationMemory:
max_message_chars: The maximum number of characters in the content of an event included
in the prompt to the LLM. Larger observations are truncated.
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model.
"""
events = condensed_history
# log visual browsing status
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
# Process special events first (system prompts, etc.)
messages = initial_messages
@@ -385,19 +386,19 @@ class ConversationMemory:
text = truncate_content(obs.content, max_message_chars)
message = Message(role='user', content=[TextContent(text=text)])
elif (
isinstance(obs, MicroagentObservation)
isinstance(obs, RecallObservation)
and self.agent_config.enable_prompt_extensions
):
if obs.recall_type == RecallType.WORKSPACE_CONTEXT:
# everything is optional, check if they are present
repo_info = (
RepositoryInfo(
if obs.repo_name or obs.repo_directory:
repo_info = RepositoryInfo(
repo_name=obs.repo_name or '',
repo_directory=obs.repo_directory or '',
)
if obs.repo_name or obs.repo_directory
else None
)
else:
repo_info = None
if obs.runtime_hosts or obs.additional_agent_instructions:
runtime_info = RuntimeInfo(
available_hosts=obs.runtime_hosts,
@@ -420,22 +421,49 @@ class ConversationMemory:
)
has_repo_instructions = bool(repo_instructions.strip())
# Build additional info if we have something to render
# Filter and process microagent knowledge
filtered_agents = []
if obs.microagent_knowledge:
# Exclude disabled microagents
filtered_agents = [
agent
for agent in obs.microagent_knowledge
if agent.name not in self.agent_config.disabled_microagents
]
has_microagent_knowledge = bool(filtered_agents)
# Generate appropriate content based on what is present
message_content = []
# Build the workspace context information
if has_repo_info or has_runtime_info or has_repo_instructions:
# ok, now we can build the additional info
formatted_text = self.prompt_manager.build_additional_info(
repository_info=repo_info,
runtime_info=runtime_info,
repo_instructions=repo_instructions,
formatted_workspace_text = (
self.prompt_manager.build_workspace_context(
repository_info=repo_info,
runtime_info=runtime_info,
repo_instructions=repo_instructions,
)
)
message = Message(
role='user', content=[TextContent(text=formatted_text)]
message_content.append(TextContent(text=formatted_workspace_text))
# Add microagent knowledge if present
if has_microagent_knowledge:
formatted_microagent_text = (
self.prompt_manager.build_microagent_info(
triggered_agents=filtered_agents,
)
)
message_content.append(TextContent(text=formatted_microagent_text))
# Return the combined message if we have any content
if message_content:
message = Message(role='user', content=message_content)
else:
return []
elif obs.recall_type == RecallType.KNOWLEDGE:
# Use prompt manager to build the microagent info
# First, filter out agents that appear in earlier MicroagentObservations
# First, filter out agents that appear in earlier RecallObservations
filtered_agents = self._filter_agents_in_microagent_obs(
obs, current_index, events or []
)
@@ -464,7 +492,7 @@ class ConversationMemory:
# Return empty list if no microagents to include or all were disabled
return []
elif (
isinstance(obs, MicroagentObservation)
isinstance(obs, RecallObservation)
and not self.agent_config.enable_prompt_extensions
):
# If prompt extensions are disabled, we don't add any additional info
@@ -504,12 +532,12 @@ class ConversationMemory:
break
def _filter_agents_in_microagent_obs(
self, obs: MicroagentObservation, current_index: int, events: list[Event]
self, obs: RecallObservation, current_index: int, events: list[Event]
) -> list[MicroagentKnowledge]:
"""Filter out agents that appear in earlier MicroagentObservations.
"""Filter out agents that appear in earlier RecallObservations.
Args:
obs: The current MicroagentObservation to filter
obs: The current RecallObservation to filter
current_index: The index of the current event in the events list
events: The list of all events
@@ -532,7 +560,7 @@ class ConversationMemory:
def _has_agent_in_earlier_events(
self, agent_name: str, current_index: int, events: list[Event]
) -> bool:
"""Check if an agent appears in any earlier MicroagentObservation in the event list.
"""Check if an agent appears in any earlier RecallObservation in the event list.
Args:
agent_name: The name of the agent to look for
@@ -540,13 +568,11 @@ class ConversationMemory:
events: The list of all events
Returns:
bool: True if the agent appears in an earlier MicroagentObservation, False otherwise
bool: True if the agent appears in an earlier RecallObservation, False otherwise
"""
for event in events[:current_index]:
if (
isinstance(event, MicroagentObservation)
and event.recall_type == RecallType.KNOWLEDGE
):
# Note that this check includes the WORKSPACE_CONTEXT
if isinstance(event, RecallObservation):
if any(
agent.name == agent_name for agent in event.microagent_knowledge
):
+77 -55
View File
@@ -9,7 +9,7 @@ from openhands.events.action.agent import RecallAction
from openhands.events.event import Event, EventSource, RecallType
from openhands.events.observation.agent import (
MicroagentKnowledge,
MicroagentObservation,
RecallObservation,
)
from openhands.events.observation.empty import NullObservation
from openhands.events.stream import EventStream, EventStreamSubscriber
@@ -31,7 +31,7 @@ GLOBAL_MICROAGENTS_DIR = os.path.join(
class Memory:
"""
Memory is a component that listens to the EventStream for information retrieval actions
(a RecallAction) and publishes observations with the content (such as MicroagentObservation).
(a RecallAction) and publishes observations with the content (such as RecallObservation).
"""
sid: str
@@ -75,48 +75,59 @@ class Memory:
async def _on_event(self, event: Event):
"""Handle an event from the event stream asynchronously."""
try:
observation: MicroagentObservation | NullObservation | None = None
if isinstance(event, RecallAction):
# if this is a workspace context recall (on first user message)
# create and add a MicroagentObservation
# with info about repo and runtime.
# create and add a RecallObservation
# with info about repo, runtime, instructions, etc. including microagent knowledge if any
if (
event.source == EventSource.USER
and event.recall_type == RecallType.WORKSPACE_CONTEXT
):
observation = self._on_first_microagent_action(event)
logger.debug('Workspace context recall')
workspace_obs: RecallObservation | NullObservation | None = None
# continue with the next handler, to include knowledge microagents if suitable for this query
assert observation is None or isinstance(
observation, MicroagentObservation
), f'Expected a MicroagentObservation, but got {type(observation)}'
observation = self._on_microagent_action(
event, prev_observation=observation
)
workspace_obs = self._on_workspace_context_recall(event)
if workspace_obs is None:
workspace_obs = NullObservation(content='')
if observation is None:
observation = NullObservation(content='')
# important: this will release the execution flow from waiting for the retrieval to complete
workspace_obs._cause = event.id # type: ignore[union-attr]
# important: this will release the execution flow from waiting for the retrieval to complete
observation._cause = event.id # type: ignore[union-attr]
self.event_stream.add_event(workspace_obs, EventSource.ENVIRONMENT)
return
self.event_stream.add_event(observation, EventSource.ENVIRONMENT)
# Handle knowledge recall (triggered microagents)
elif (
event.source == EventSource.USER
and event.recall_type == RecallType.KNOWLEDGE
):
logger.debug('Microagent knowledge recall')
microagent_obs: RecallObservation | NullObservation | None = None
microagent_obs = self._on_microagent_recall(event)
if microagent_obs is None:
microagent_obs = NullObservation(content='')
# important: this will release the execution flow from waiting for the retrieval to complete
microagent_obs._cause = event.id # type: ignore[union-attr]
self.event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT)
return
except Exception as e:
error_str = f'Error: {str(e.__class__.__name__)}'
logger.error(error_str)
self.send_error_message('STATUS$ERROR_MEMORY', error_str)
return
def _on_first_microagent_action(
def _on_workspace_context_recall(
self, event: RecallAction
) -> MicroagentObservation | None:
"""Add repository and runtime information to the stream as a MicroagentObservation."""
) -> RecallObservation | None:
"""Add repository and runtime information to the stream as a RecallObservation."""
# Create ENVIRONMENT info:
# Create WORKSPACE_CONTEXT info:
# - repository_info
# - runtime_info
# - repository_instructions
# - microagent_knowledge
# Collect raw repository instructions
repo_instructions = ''
@@ -130,9 +141,17 @@ class Memory:
repo_instructions += '\n\n'
repo_instructions += microagent.content
# Find any matched microagents based on the query
microagent_knowledge = self._find_microagent_knowledge(event.query)
# Create observation if we have anything
if self.repository_info or self.runtime_info or repo_instructions:
obs = MicroagentObservation(
if (
self.repository_info
or self.runtime_info
or repo_instructions
or microagent_knowledge
):
obs = RecallObservation(
recall_type=RecallType.WORKSPACE_CONTEXT,
repo_name=self.repository_info.repo_name
if self.repository_info and self.repository_info.repo_name is not None
@@ -149,29 +168,47 @@ class Memory:
if self.runtime_info
and self.runtime_info.additional_agent_instructions is not None
else '',
microagent_knowledge=[],
content='Retrieved environment info',
microagent_knowledge=microagent_knowledge,
content='Added workspace context',
)
return obs
return None
def _on_microagent_action(
def _on_microagent_recall(
self,
event: RecallAction,
prev_observation: MicroagentObservation | None = None,
) -> MicroagentObservation | None:
"""When a microagent action triggers microagents, create a MicroagentObservation with structured data."""
# If there's no query, do nothing
query = event.query.strip()
if not query:
return prev_observation
) -> RecallObservation | None:
"""When a microagent action triggers microagents, create a RecallObservation with structured data."""
assert prev_observation is None or isinstance(
prev_observation, MicroagentObservation
), f'Expected a MicroagentObservation, but got {type(prev_observation)}'
# Find any matched microagents based on the query
microagent_knowledge = self._find_microagent_knowledge(event.query)
# Process text to find suitable microagents and create a MicroagentObservation.
# Create observation if we have anything
if microagent_knowledge:
obs = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=microagent_knowledge,
content='Retrieved knowledge from microagents',
)
return obs
return None
def _find_microagent_knowledge(self, query: str) -> list[MicroagentKnowledge]:
"""Find microagent knowledge based on a query.
Args:
query: The query to search for microagent triggers
Returns:
A list of MicroagentKnowledge objects for matched triggers
"""
recalled_content: list[MicroagentKnowledge] = []
# skip empty queries
if not query:
return recalled_content
# Search for microagent triggers in the query
for name, microagent in self.knowledge_microagents.items():
trigger = microagent.match_trigger(query)
if trigger:
@@ -183,22 +220,7 @@ class Memory:
content=microagent.content,
)
)
if recalled_content:
if prev_observation is not None:
# it may be on the first user message that already found some repo info etc
prev_observation.microagent_knowledge.extend(recalled_content)
else:
# if it's not the first user message, we may not have found any information this step
obs = MicroagentObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=recalled_content,
content='Retrieved knowledge from microagents',
)
return obs
return prev_observation
return recalled_content
def load_user_workspace_microagents(
self, user_microagents: list[BaseMicroAgent]
+4 -4
View File
@@ -97,7 +97,7 @@ class Runtime(FileEditRuntimeMixin):
status_callback: Callable | None = None,
attach_to_existing: bool = False,
headless_mode: bool = False,
github_user_id: str | None = None,
user_id: str | None = None,
):
self.sid = sid
self.event_stream = event_stream
@@ -130,7 +130,7 @@ class Runtime(FileEditRuntimeMixin):
self, enable_llm_editor=config.get_agent_config().codeact_enable_llm_editor
)
self.github_user_id = github_user_id
self.user_id = user_id
def setup_initial_env(self) -> None:
if self.attach_to_existing:
@@ -220,9 +220,9 @@ class Runtime(FileEditRuntimeMixin):
assert event.timeout is not None
try:
if isinstance(event, CmdRunAction):
if self.github_user_id and '$GITHUB_TOKEN' in event.command:
if self.user_id and '$GITHUB_TOKEN' in event.command:
gh_client = GithubServiceImpl(
user_id=self.github_user_id, external_token_manager=True
external_auth_id=self.user_id, external_token_manager=True
)
token = await gh_client.get_latest_token()
if token:
@@ -59,7 +59,7 @@ class ActionExecutionClient(Runtime):
status_callback: Any | None = None,
attach_to_existing: bool = False,
headless_mode: bool = True,
github_user_id: str | None = None,
user_id: str | None = None,
):
self.session = HttpSession()
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
@@ -75,7 +75,7 @@ class ActionExecutionClient(Runtime):
status_callback,
attach_to_existing,
headless_mode,
github_user_id,
user_id,
)
@abstractmethod
@@ -1,3 +1,4 @@
import logging
import os
from typing import Callable
from urllib.parse import urlparse
@@ -45,7 +46,7 @@ class RemoteRuntime(ActionExecutionClient):
status_callback: Callable | None = None,
attach_to_existing: bool = False,
headless_mode: bool = True,
github_user_id: str | None = None,
user_id: str | None = None,
):
super().__init__(
config,
@@ -56,7 +57,7 @@ class RemoteRuntime(ActionExecutionClient):
status_callback,
attach_to_existing,
headless_mode,
github_user_id,
user_id,
)
if self.config.sandbox.api_key is None:
raise ValueError(
@@ -425,10 +426,11 @@ class RemoteRuntime(ActionExecutionClient):
return self._send_action_server_request_impl(method, url, **kwargs)
retry_decorator = tenacity.retry(
retry=tenacity.retry_if_exception_type(ConnectionError),
retry=tenacity.retry_if_exception_type(requests.ConnectionError),
stop=tenacity.stop_after_attempt(3)
| stop_if_should_exit()
| self._stop_if_closed,
before_sleep=tenacity.before_sleep_log(logger, logging.WARNING),
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
)
return retry_decorator(self._send_action_server_request_impl)(
@@ -46,7 +46,12 @@ class ConversationManager(ABC):
@abstractmethod
async def join_conversation(
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
self,
sid: str,
connection_id: str,
settings: Settings,
user_id: str | None,
github_user_id: str | None,
) -> EventStream | None:
"""Join a conversation and return its event stream."""
@@ -74,6 +79,7 @@ class ConversationManager(ABC):
settings: Settings,
user_id: str | None,
initial_user_msg: MessageAction | None = None,
github_user_id: str | None = None,
) -> EventStream:
"""Start an event loop if one is not already running"""
@@ -106,7 +106,12 @@ class StandaloneConversationManager(ConversationManager):
return c
async def join_conversation(
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
self,
sid: str,
connection_id: str,
settings: Settings,
user_id: str | None,
github_user_id: str | None,
):
logger.info(
f'join_conversation:{sid}:{connection_id}',
@@ -116,7 +121,9 @@ class StandaloneConversationManager(ConversationManager):
self._local_connection_id_to_session_id[connection_id] = sid
event_stream = await self._get_event_stream(sid)
if not event_stream:
return await self.maybe_start_agent_loop(sid, settings, user_id)
return await self.maybe_start_agent_loop(
sid, settings, user_id, github_user_id=github_user_id
)
for event in event_stream.get_events(reverse=True):
if isinstance(event, AgentStateChangedObservation):
if event.agent_state in (
@@ -187,14 +194,18 @@ class StandaloneConversationManager(ConversationManager):
logger.error('error_cleaning_stale')
await asyncio.sleep(_CLEANUP_INTERVAL)
async def _get_conversation_store(self, user_id: str | None) -> ConversationStore:
async def _get_conversation_store(
self, user_id: str | None, github_user_id: str | None
) -> ConversationStore:
conversation_store_class = self._conversation_store_class
if not conversation_store_class:
self._conversation_store_class = conversation_store_class = get_impl(
ConversationStore, # type: ignore
self.server_config.conversation_store_class,
)
store = await conversation_store_class.get_instance(self.config, user_id)
store = await conversation_store_class.get_instance(
self.config, user_id, github_user_id
)
return store
async def get_running_agent_loops(
@@ -243,6 +254,7 @@ class StandaloneConversationManager(ConversationManager):
settings: Settings,
user_id: str | None,
initial_user_msg: MessageAction | None = None,
github_user_id: str | None = None,
) -> EventStream:
logger.info(f'maybe_start_agent_loop:{sid}', extra={'session_id': sid})
session: Session | None = None
@@ -256,7 +268,9 @@ class StandaloneConversationManager(ConversationManager):
extra={'session_id': sid, 'user_id': user_id},
)
# Get the conversations sorted (oldest first)
conversation_store = await self._get_conversation_store(user_id)
conversation_store = await self._get_conversation_store(
user_id, github_user_id
)
conversations = await conversation_store.get_all_metadata(response_ids)
conversations.sort(key=_last_updated_at_key, reverse=True)
@@ -277,7 +291,9 @@ class StandaloneConversationManager(ConversationManager):
try:
session.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER,
self._create_conversation_update_callback(user_id, sid),
self._create_conversation_update_callback(
user_id, github_user_id, sid
),
UPDATED_AT_CALLBACK_ID,
)
except ValueError:
@@ -374,22 +390,23 @@ class StandaloneConversationManager(ConversationManager):
)
def _create_conversation_update_callback(
self, user_id: str | None, conversation_id: str
self, user_id: str | None, github_user_id: str | None, conversation_id: str
) -> Callable:
def callback(*args, **kwargs):
call_async_from_sync(
self._update_timestamp_for_conversation,
GENERAL_TIMEOUT,
user_id,
github_user_id,
conversation_id,
)
return callback
async def _update_timestamp_for_conversation(
self, user_id: str, conversation_id: str
self, user_id: str, github_user_id: str, conversation_id: str
):
conversation_store = await self._get_conversation_store(user_id)
conversation_store = await self._get_conversation_store(user_id, github_user_id)
conversation = await conversation_store.get_metadata(conversation_id)
conversation.last_updated_at = datetime.now(timezone.utc)
await conversation_store.save_metadata(conversation)
+10 -7
View File
@@ -6,10 +6,14 @@ from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
NullAction,
)
from openhands.events.action.agent import RecallAction
from openhands.events.observation import (
NullObservation,
)
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.observation.agent import (
AgentStateChangedObservation,
RecallObservation,
)
from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.server.shared import (
@@ -35,7 +39,9 @@ async def connect(connection_id: str, environ):
cookies_str = environ.get('HTTP_COOKIE', '')
conversation_validator = ConversationValidatorImpl()
user_id = await conversation_validator.validate(conversation_id, cookies_str)
user_id, github_user_id = await conversation_validator.validate(
conversation_id, cookies_str
)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()
@@ -46,7 +52,7 @@ async def connect(connection_id: str, environ):
)
event_stream = await conversation_manager.join_conversation(
conversation_id, connection_id, settings, user_id
conversation_id, connection_id, settings, user_id, github_user_id
)
agent_state_changed = None
@@ -54,10 +60,7 @@ async def connect(connection_id: str, environ):
async for event in async_stream:
if isinstance(
event,
(
NullAction,
NullObservation,
),
(NullAction, NullObservation, RecallAction, RecallObservation),
):
continue
elif isinstance(event, AgentStateChangedObservation):
+20 -11
View File
@@ -10,7 +10,12 @@ from openhands.events.action.message import MessageAction
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.provider import ProviderType
from openhands.runtime import get_runtime_cls
from openhands.server.auth import get_provider_tokens, get_access_token, get_github_user_id
from openhands.server.auth import (
get_access_token,
get_github_user_id,
get_provider_tokens,
get_user_id,
)
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
@@ -73,12 +78,12 @@ async def _create_new_conversation(
logger.warn('Settings not present, not starting conversation')
raise MissingSettingsError('Settings not found')
session_init_args['github_token'] = token or SecretStr('')
session_init_args['provider_token'] = token
session_init_args['selected_repository'] = selected_repository
session_init_args['selected_branch'] = selected_branch
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
conversation_store = await ConversationStoreImpl.get_instance(config, user_id, None)
logger.info('Conversation store loaded')
conversation_id = uuid.uuid4().hex
@@ -100,7 +105,8 @@ async def _create_new_conversation(
ConversationMetadata(
conversation_id=conversation_id,
title=conversation_title,
github_user_id=user_id,
user_id=user_id,
github_user_id=None,
selected_repository=selected_repository,
selected_branch=selected_branch,
)
@@ -122,7 +128,10 @@ async def _create_new_conversation(
image_urls=image_urls or [],
)
await conversation_manager.maybe_start_agent_loop(
conversation_id, conversation_init_data, user_id, initial_message_action
conversation_id,
conversation_init_data,
user_id,
initial_user_msg=initial_message_action,
)
logger.info(f'Finished initializing conversation {conversation_id}')
@@ -158,7 +167,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
try:
# Create conversation with initial message
conversation_id = await _create_new_conversation(
user_id,
get_user_id(request),
github_token,
selected_repository,
selected_branch,
@@ -197,7 +206,7 @@ async def search_conversations(
limit: int = 20,
) -> ConversationInfoResultSet:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_github_user_id(request)
config, get_user_id(request), get_github_user_id(request)
)
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
@@ -216,7 +225,7 @@ async def search_conversations(
conversation.conversation_id for conversation in filtered_results
)
running_conversations = await conversation_manager.get_running_agent_loops(
get_github_user_id(request), set(conversation_ids)
get_user_id(request), set(conversation_ids)
)
result = ConversationInfoResultSet(
results=await wait_all(
@@ -236,7 +245,7 @@ async def get_conversation(
conversation_id: str, request: Request
) -> ConversationInfo | None:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_github_user_id(request)
config, get_user_id(request), get_github_user_id(request)
)
try:
metadata = await conversation_store.get_metadata(conversation_id)
@@ -252,7 +261,7 @@ async def update_conversation(
request: Request, conversation_id: str, title: str = Body(embed=True)
) -> bool:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_github_user_id(request)
config, get_user_id(request), get_github_user_id(request)
)
metadata = await conversation_store.get_metadata(conversation_id)
if not metadata:
@@ -268,7 +277,7 @@ async def delete_conversation(
request: Request,
) -> bool:
conversation_store = await ConversationStoreImpl.get_instance(
config, get_github_user_id(request)
config, get_user_id(request), get_github_user_id(request)
)
try:
await conversation_store.get_metadata(conversation_id)
+30 -22
View File
@@ -90,30 +90,38 @@ async def store_settings(
existing_settings.user_consents_to_analytics
)
if existing_settings.secrets_store:
existing_providers = [
provider.value
for provider in existing_settings.secrets_store.provider_tokens
]
# Merge incoming settings store with the existing one
for provider, token_value in settings.provider_tokens.items():
if provider in existing_providers and not token_value:
provider_type = ProviderType(provider)
existing_token = (
existing_settings.secrets_store.provider_tokens.get(
provider_type
)
)
if existing_token and existing_token.token:
settings.provider_tokens[provider] = (
existing_token.token.get_secret_value()
)
# Merge provider tokens with existing ones
if settings.unset_github_token: # Only merge if not unsetting tokens
if settings.unset_github_token:
settings.secrets_store.provider_tokens = {}
settings.provider_tokens = {}
else: # Only merge if not unsetting tokens
if settings.provider_tokens:
if existing_settings.secrets_store:
existing_providers = [
provider.value
for provider in existing_settings.secrets_store.provider_tokens
]
# Merge incoming settings store with the existing one
for provider, token_value in settings.provider_tokens.items():
if provider in existing_providers and not token_value:
provider_type = ProviderType(provider)
existing_token = (
existing_settings.secrets_store.provider_tokens.get(
provider_type
)
)
if existing_token and existing_token.token:
settings.provider_tokens[provider] = (
existing_token.token.get_secret_value()
)
else: # nothing passed in means keep current settings
provider_tokens = existing_settings.secrets_store.provider_tokens
settings.provider_tokens = {
provider.value: data.token.get_secret_value()
if data.token
else None
for provider, data in provider_tokens.items()
}
# Update sandbox config with new settings
if settings.remote_runtime_resource_factor is not None:
+4 -4
View File
@@ -53,7 +53,7 @@ class AgentSession:
sid: str,
file_store: FileStore,
status_callback: Callable | None = None,
github_user_id: str | None = None,
user_id: str | None = None,
):
"""Initializes a new instance of the Session class
@@ -66,9 +66,9 @@ class AgentSession:
self.event_stream = EventStream(sid, file_store)
self.file_store = file_store
self._status_callback = status_callback
self.github_user_id = github_user_id
self.user_id = user_id
self.logger = OpenHandsLoggerAdapter(
extra={'session_id': sid, 'user_id': github_user_id}
extra={'session_id': sid, 'user_id': user_id}
)
async def start(
@@ -241,7 +241,7 @@ class AgentSession:
kwargs = {}
if runtime_cls == RemoteRuntime:
kwargs['github_user_id'] = self.github_user_id
kwargs['user_id'] = self.user_id
self.runtime = runtime_cls(
config=config,
@@ -8,6 +8,6 @@ class ConversationInitData(Settings):
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
"""
github_token: SecretStr | None = Field(default=None)
provider_token: SecretStr | None = Field(default=None)
selected_repository: str | None = Field(default=None)
selected_branch: str | None = Field(default=None)
+4 -4
View File
@@ -61,7 +61,7 @@ class Session:
sid,
file_store,
status_callback=self.queue_status_message,
github_user_id=user_id,
user_id=user_id,
)
self.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER, self.on_event, self.sid
@@ -123,11 +123,11 @@ class Session:
agent = Agent.get_cls(agent_cls)(llm, agent_config)
github_token = None
provider_token = None
selected_repository = None
selected_branch = None
if isinstance(settings, ConversationInitData):
github_token = settings.github_token
provider_token = settings.provider_token
selected_repository = settings.selected_repository
selected_branch = settings.selected_branch
@@ -140,7 +140,7 @@ class Session:
max_budget_per_task=self.config.max_budget_per_task,
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
agent_configs=self.config.get_agent_configs(),
github_token=github_token,
github_token=provider_token,
selected_repository=selected_repository,
selected_branch=selected_branch,
initial_message=initial_message,
+1 -1
View File
@@ -43,7 +43,7 @@ class Settings(BaseModel):
if context and context.get('expose_secrets', False):
return llm_api_key.get_secret_value()
return pydantic_encoder(llm_api_key)
return pydantic_encoder(llm_api_key) if llm_api_key else None
@staticmethod
def _convert_token_value(
@@ -12,25 +12,36 @@ from openhands.utils.async_utils import wait_all
class ConversationStore(ABC):
"""
Storage for conversation metadata. May or may not support multiple users depending on the environment
"""
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
@abstractmethod
async def save_metadata(self, metadata: ConversationMetadata) -> None:
"""Store conversation metadata"""
"""Store conversation metadata."""
@abstractmethod
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
"""Load conversation metadata"""
"""Load conversation metadata."""
async def validate_metadata(
self, conversation_id: str, user_id: str, github_user_id: str
) -> bool:
"""Validate that conversation belongs to the current user."""
# TODO: remove github_user_id after transition to Keycloak is complete.
metadata = await self.get_metadata(conversation_id)
if (not metadata.user_id and not metadata.github_user_id) or (
metadata.user_id != user_id and metadata.github_user_id != github_user_id
):
return False
else:
return True
@abstractmethod
async def delete_metadata(self, conversation_id: str) -> None:
"""delete conversation metadata"""
"""Delete conversation metadata."""
@abstractmethod
async def exists(self, conversation_id: str) -> bool:
"""Check if conversation exists"""
"""Check if conversation exists."""
@abstractmethod
async def search(
@@ -49,6 +60,6 @@ class ConversationStore(ABC):
@classmethod
@abstractmethod
async def get_instance(
cls, config: AppConfig, user_id: str | None
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
) -> ConversationStore:
"""Get a store for the user represented by the token given"""
@@ -7,7 +7,7 @@ class ConversationValidator:
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
async def validate(self, conversation_id: str, cookies_str: str):
return None
return None, None
conversation_validator_cls = os.environ.get(
@@ -101,7 +101,7 @@ class FileConversationStore(ConversationStore):
@classmethod
async def get_instance(
cls, config: AppConfig, user_id: str | None
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
) -> FileConversationStore:
file_store = get_file_store(config.file_store, config.file_store_path)
return FileConversationStore(file_store)
@@ -5,6 +5,7 @@ from datetime import datetime, timezone
@dataclass
class ConversationMetadata:
conversation_id: str
user_id: str | None
github_user_id: str | None
selected_repository: str | None
selected_branch: str | None = None
+1 -1
View File
@@ -76,7 +76,7 @@ class PromptManager:
if example_message:
message.content.insert(0, TextContent(text=example_message))
def build_additional_info(
def build_workspace_context(
self,
repository_info: RepositoryInfo | None,
runtime_info: RuntimeInfo | None,
Generated
+37 -37
View File
@@ -496,18 +496,18 @@ files = [
[[package]]
name = "boto3"
version = "1.37.11"
version = "1.37.12"
description = "The AWS SDK for Python"
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "boto3-1.37.11-py3-none-any.whl", hash = "sha256:da6c22fc8a7e9bca5d7fc465a877ac3d45b6b086d776bd1a6c55bdde60523741"},
{file = "boto3-1.37.11.tar.gz", hash = "sha256:8eec08363ef5db05c2fbf58e89f0c0de6276cda2fdce01e76b3b5f423cd5c0f4"},
{file = "boto3-1.37.12-py3-none-any.whl", hash = "sha256:516feaa0d2afaeda1515216fd09291368a1215754bbccb0f28414c0a91a830a2"},
{file = "boto3-1.37.12.tar.gz", hash = "sha256:9412d404f103ad6d14f033eb29cd5e0cdca2b9b08cbfa9d4dabd1d7be2de2625"},
]
[package.dependencies]
botocore = ">=1.37.11,<1.38.0"
botocore = ">=1.37.12,<1.38.0"
jmespath = ">=0.7.1,<2.0.0"
s3transfer = ">=0.11.0,<0.12.0"
@@ -516,14 +516,14 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
[[package]]
name = "botocore"
version = "1.37.11"
version = "1.37.12"
description = "Low-level, data-driven core of boto 3."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "botocore-1.37.11-py3-none-any.whl", hash = "sha256:02505309b1235f9f15a6da79103ca224b3f3dc5f6a62f8630fbb2c6ed05e2da8"},
{file = "botocore-1.37.11.tar.gz", hash = "sha256:72eb3a9a58b064be26ba154e5e56373633b58f951941c340ace0d379590d98b5"},
{file = "botocore-1.37.12-py3-none-any.whl", hash = "sha256:ba1948c883bbabe20d95ff62c3e36954c9269686f7db9361857835677ca3e676"},
{file = "botocore-1.37.12.tar.gz", hash = "sha256:ae2d5328ce6ad02eb615270507235a6e90fd3eeed615a6c0732b5a68b12f2017"},
]
[package.dependencies]
@@ -3547,14 +3547,14 @@ test = ["jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-jupyter[server] (>
[[package]]
name = "jupyterlab"
version = "4.3.5"
version = "4.3.6"
description = "JupyterLab computational environment"
optional = false
python-versions = ">=3.8"
groups = ["runtime"]
files = [
{file = "jupyterlab-4.3.5-py3-none-any.whl", hash = "sha256:571bbdee20e4c5321ab5195bc41cf92a75a5cff886be5e57ce78dfa37a5e9fdb"},
{file = "jupyterlab-4.3.5.tar.gz", hash = "sha256:c779bf72ced007d7d29d5bcef128e7fdda96ea69299e19b04a43635a7d641f9d"},
{file = "jupyterlab-4.3.6-py3-none-any.whl", hash = "sha256:fc9eb0455562a56a9bd6d2977cf090842f321fa1a298fcee9bf8c19de353d5fd"},
{file = "jupyterlab-4.3.6.tar.gz", hash = "sha256:2900ffdbfca9ed37c4ad7fdda3eb76582fd945d46962af3ac64741ae2d6b2ff4"},
]
[package.dependencies]
@@ -4251,14 +4251,14 @@ files = [
[[package]]
name = "modal"
version = "0.73.98"
version = "0.73.102"
description = "Python client library for Modal"
optional = false
python-versions = ">=3.9"
groups = ["main", "evaluation"]
files = [
{file = "modal-0.73.98-py3-none-any.whl", hash = "sha256:a49cd5f5b46d1a6c6a0d528618d3cbb73ac2908e199716590ec3a5275d79ed98"},
{file = "modal-0.73.98.tar.gz", hash = "sha256:817f73c222fa39a16d6888a92eb7a6847ecae574e44ef04e2dce5e534bdd2df9"},
{file = "modal-0.73.102-py3-none-any.whl", hash = "sha256:26151ef6164e0b93b0d1961f73d5a715deb72f23e2641215f5410cf58bf403d3"},
{file = "modal-0.73.102.tar.gz", hash = "sha256:198876cf94ff13633283e251d8b37cc1f1bb5e27a7aa547e02072def1f29b66e"},
]
[package.dependencies]
@@ -4670,19 +4670,19 @@ files = [
[[package]]
name = "notebook"
version = "7.3.2"
version = "7.3.3"
description = "Jupyter Notebook - A web-based notebook environment for interactive computing"
optional = false
python-versions = ">=3.8"
groups = ["runtime"]
files = [
{file = "notebook-7.3.2-py3-none-any.whl", hash = "sha256:e5f85fc59b69d3618d73cf27544418193ff8e8058d5bf61d315ce4f473556288"},
{file = "notebook-7.3.2.tar.gz", hash = "sha256:705e83a1785f45b383bf3ee13cb76680b92d24f56fb0c7d2136fe1d850cd3ca8"},
{file = "notebook-7.3.3-py3-none-any.whl", hash = "sha256:b193df0878956562d5171c8e25c9252b8e86c9fcc16163b8ee3fe6c5e3f422f7"},
{file = "notebook-7.3.3.tar.gz", hash = "sha256:707a313fb882d35f921989eb3d204de942ed5132a44e4aa1fe0e8f24bb9dc25d"},
]
[package.dependencies]
jupyter-server = ">=2.4.0,<3"
jupyterlab = ">=4.3.4,<4.4"
jupyterlab = ">=4.3.6,<4.4"
jupyterlab-server = ">=2.27.1,<3"
notebook-shim = ">=0.2,<0.3"
tornado = ">=6.2.0"
@@ -6947,30 +6947,30 @@ pyasn1 = ">=0.1.3"
[[package]]
name = "ruff"
version = "0.9.10"
version = "0.11.0"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
groups = ["dev", "evaluation"]
files = [
{file = "ruff-0.9.10-py3-none-linux_armv6l.whl", hash = "sha256:eb4d25532cfd9fe461acc83498361ec2e2252795b4f40b17e80692814329e42d"},
{file = "ruff-0.9.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:188a6638dab1aa9bb6228a7302387b2c9954e455fb25d6b4470cb0641d16759d"},
{file = "ruff-0.9.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5284dcac6b9dbc2fcb71fdfc26a217b2ca4ede6ccd57476f52a587451ebe450d"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47678f39fa2a3da62724851107f438c8229a3470f533894b5568a39b40029c0c"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:99713a6e2766b7a17147b309e8c915b32b07a25c9efd12ada79f217c9c778b3e"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:524ee184d92f7c7304aa568e2db20f50c32d1d0caa235d8ddf10497566ea1a12"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:df92aeac30af821f9acf819fc01b4afc3dfb829d2782884f8739fb52a8119a16"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de42e4edc296f520bb84954eb992a07a0ec5a02fecb834498415908469854a52"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d257f95b65806104b6b1ffca0ea53f4ef98454036df65b1eda3693534813ecd1"},
{file = "ruff-0.9.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b60dec7201c0b10d6d11be00e8f2dbb6f40ef1828ee75ed739923799513db24c"},
{file = "ruff-0.9.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d838b60007da7a39c046fcdd317293d10b845001f38bcb55ba766c3875b01e43"},
{file = "ruff-0.9.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ccaf903108b899beb8e09a63ffae5869057ab649c1e9231c05ae354ebc62066c"},
{file = "ruff-0.9.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f9567d135265d46e59d62dc60c0bfad10e9a6822e231f5b24032dba5a55be6b5"},
{file = "ruff-0.9.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5f202f0d93738c28a89f8ed9eaba01b7be339e5d8d642c994347eaa81c6d75b8"},
{file = "ruff-0.9.10-py3-none-win32.whl", hash = "sha256:bfb834e87c916521ce46b1788fbb8484966e5113c02df216680102e9eb960029"},
{file = "ruff-0.9.10-py3-none-win_amd64.whl", hash = "sha256:f2160eeef3031bf4b17df74e307d4c5fb689a6f3a26a2de3f7ef4044e3c484f1"},
{file = "ruff-0.9.10-py3-none-win_arm64.whl", hash = "sha256:5fd804c0327a5e5ea26615550e706942f348b197d5475ff34c19733aee4b2e69"},
{file = "ruff-0.9.10.tar.gz", hash = "sha256:9bacb735d7bada9cfb0f2c227d3658fc443d90a727b47f206fb33f52f3c0eac7"},
{file = "ruff-0.11.0-py3-none-linux_armv6l.whl", hash = "sha256:dc67e32bc3b29557513eb7eeabb23efdb25753684b913bebb8a0c62495095acb"},
{file = "ruff-0.11.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:38c23fd9bdec4eb437b4c1e3595905a0a8edfccd63a790f818b28c78fe345639"},
{file = "ruff-0.11.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7c8661b0be91a38bd56db593e9331beaf9064a79028adee2d5f392674bbc5e88"},
{file = "ruff-0.11.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6c0e8d3d2db7e9f6efd884f44b8dc542d5b6b590fc4bb334fdbc624d93a29a2"},
{file = "ruff-0.11.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3c3156d3f4b42e57247275a0a7e15a851c165a4fc89c5e8fa30ea6da4f7407b8"},
{file = "ruff-0.11.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:490b1e147c1260545f6d041c4092483e3f6d8eba81dc2875eaebcf9140b53905"},
{file = "ruff-0.11.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1bc09a7419e09662983b1312f6fa5dab829d6ab5d11f18c3760be7ca521c9329"},
{file = "ruff-0.11.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bcfa478daf61ac8002214eb2ca5f3e9365048506a9d52b11bea3ecea822bb844"},
{file = "ruff-0.11.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6fbb2aed66fe742a6a3a0075ed467a459b7cedc5ae01008340075909d819df1e"},
{file = "ruff-0.11.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92c0c1ff014351c0b0cdfdb1e35fa83b780f1e065667167bb9502d47ca41e6db"},
{file = "ruff-0.11.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e4fd5ff5de5f83e0458a138e8a869c7c5e907541aec32b707f57cf9a5e124445"},
{file = "ruff-0.11.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:96bc89a5c5fd21a04939773f9e0e276308be0935de06845110f43fd5c2e4ead7"},
{file = "ruff-0.11.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a9352b9d767889ec5df1483f94870564e8102d4d7e99da52ebf564b882cdc2c7"},
{file = "ruff-0.11.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:049a191969a10897fe052ef9cc7491b3ef6de79acd7790af7d7897b7a9bfbcb6"},
{file = "ruff-0.11.0-py3-none-win32.whl", hash = "sha256:3191e9116b6b5bbe187447656f0c8526f0d36b6fd89ad78ccaad6bdc2fad7df2"},
{file = "ruff-0.11.0-py3-none-win_amd64.whl", hash = "sha256:c58bfa00e740ca0a6c43d41fb004cd22d165302f360aaa56f7126d544db31a21"},
{file = "ruff-0.11.0-py3-none-win_arm64.whl", hash = "sha256:868364fc23f5aa122b00c6f794211e85f7e78f5dffdf7c590ab90b8c4e69b657"},
{file = "ruff-0.11.0.tar.gz", hash = "sha256:e55c620690a4a7ee6f1cccb256ec2157dc597d109400ae75bbf944fc9d6462e2"},
]
[[package]]
@@ -9056,4 +9056,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.1"
python-versions = "^3.12"
content-hash = "6a644bc65782a717a49718496bd279ecb888807ec625d992af4448cc5d9271c1"
content-hash = "9b74f62a4afa719a1f7167e0b3b45cdaf282c2e18fd2931da91c0f1b22776178"
+6 -1
View File
@@ -80,7 +80,7 @@ daytona-sdk = "0.10.2"
python-json-logger = "^3.2.1"
[tool.poetry.group.dev.dependencies]
ruff = "0.9.10"
ruff = "0.11.0"
mypy = "1.15.0"
pre-commit = "4.1.0"
build = "*"
@@ -154,3 +154,8 @@ style = "semver"
[tool.poetry.scripts]
openhands = "openhands.core.cli:main"
[tool.poetry.group.testgeneval.dependencies]
fuzzywuzzy = "^0.18.0"
rouge = "^1.0.1"
python-levenshtein = "^0.26.1"
+16 -9
View File
@@ -19,7 +19,7 @@ from openhands.events.event import RecallType
from openhands.events.observation import (
ErrorObservation,
)
from openhands.events.observation.agent import MicroagentObservation
from openhands.events.observation.agent import RecallObservation
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
from openhands.llm.metrics import Metrics, TokenUsage
@@ -192,7 +192,7 @@ async def test_run_controller_with_fatal_error(test_event_stream, mock_memory):
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = MicroagentObservation(
microagent_obs = RecallObservation(
content='Test microagent content',
recall_type=RecallType.KNOWLEDGE,
)
@@ -249,7 +249,7 @@ async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory):
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = MicroagentObservation(
microagent_obs = RecallObservation(
content='Test microagent content',
recall_type=RecallType.KNOWLEDGE,
)
@@ -596,7 +596,7 @@ async def test_run_controller_max_iterations_has_metrics(
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = MicroagentObservation(
microagent_obs = RecallObservation(
content='Test microagent content',
recall_type=RecallType.KNOWLEDGE,
)
@@ -718,7 +718,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = MicroagentObservation(
microagent_obs = RecallObservation(
content='Test microagent content',
recall_type=RecallType.KNOWLEDGE,
)
@@ -795,7 +795,7 @@ async def test_run_controller_with_context_window_exceeded_without_truncation(
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = MicroagentObservation(
microagent_obs = RecallObservation(
content='Test microagent content',
recall_type=RecallType.KNOWLEDGE,
)
@@ -845,23 +845,30 @@ async def test_run_controller_with_memory_error(test_event_stream):
config = AppConfig()
event_stream = test_event_stream
# Create a propert agent that returns an action without an ID
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
# Create a real action to return from the mocked step function
def agent_step_fn(state):
return MessageAction(content='Agent returned a message')
agent.step = agent_step_fn
runtime = MagicMock(spec=Runtime)
runtime.event_stream = event_stream
# Create a real Memory instance
memory = Memory(event_stream=event_stream, sid='test-memory')
# Patch the _on_microagent_action method to raise our test exception
def mock_on_microagent_action(*args, **kwargs):
# Patch the _find_microagent_knowledge method to raise our test exception
def mock_find_microagent_knowledge(*args, **kwargs):
raise RuntimeError('Test memory error')
with patch.object(
memory, '_on_microagent_action', side_effect=mock_on_microagent_action
memory, '_find_microagent_knowledge', side_effect=mock_find_microagent_knowledge
):
state = await run_controller(
config=config,
+7 -7
View File
@@ -19,7 +19,7 @@ from openhands.events.action import (
)
from openhands.events.action.agent import RecallAction
from openhands.events.event import Event, RecallType
from openhands.events.observation.agent import MicroagentObservation
from openhands.events.observation.agent import RecallObservation
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
@@ -86,10 +86,10 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
def on_event(event: Event):
if isinstance(event, RecallAction):
# create a MicroagentObservation
microagent_observation = MicroagentObservation(
# create a RecallObservation
microagent_observation = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
content='microagent',
content='Found info',
)
microagent_observation._cause = event.id # ignore attr-defined warning
mock_event_stream.add_event(microagent_observation, EventSource.ENVIRONMENT)
@@ -111,14 +111,14 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s
# Give time for the async step() to execute
await asyncio.sleep(1)
# Verify that a MicroagentObservation was added to the event stream
# Verify that a RecallObservation was added to the event stream
events = list(mock_event_stream.get_events())
assert (
mock_event_stream.get_latest_event_id() == 3
) # Microagents and AgentChangeState
# a MicroagentObservation and an AgentDelegateAction should be in the list
assert any(isinstance(event, MicroagentObservation) for event in events)
# a RecallObservation and an AgentDelegateAction should be in the list
assert any(isinstance(event, RecallObservation) for event in events)
assert any(isinstance(event, AgentDelegateAction) for event in events)
# Verify that a delegate agent controller is created
+30 -2
View File
@@ -6,11 +6,11 @@ from litellm import ChatCompletionMessageToolCall
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.agenthub.codeact_agent.function_calling import (
BrowserTool,
CmdRunTool,
IPythonTool,
LLMBasedFileEditTool,
StrReplaceEditorTool,
WebReadTool,
create_cmd_run_tool,
create_str_replace_editor_tool,
get_tools,
response_to_actions,
)
@@ -119,6 +119,7 @@ def test_get_tools_with_options():
def test_cmd_run_tool():
CmdRunTool = create_cmd_run_tool()
assert CmdRunTool['type'] == 'function'
assert CmdRunTool['function']['name'] == 'execute_bash'
assert 'command' in CmdRunTool['function']['parameters']['properties']
@@ -149,6 +150,7 @@ def test_llm_based_file_edit_tool():
def test_str_replace_editor_tool():
StrReplaceEditorTool = create_str_replace_editor_tool()
assert StrReplaceEditorTool['type'] == 'function'
assert StrReplaceEditorTool['function']['name'] == 'str_replace_editor'
@@ -236,7 +238,11 @@ def test_step_with_no_pending_actions(mock_state: State):
mock_response.choices[0].message.content = 'Task completed'
mock_response.choices[0].message.tool_calls = []
mock_config = Mock()
mock_config.model = 'mock_model'
llm = Mock()
llm.config = mock_config
llm.completion = Mock(return_value=mock_response)
llm.is_function_calling_active = Mock(return_value=True) # Enable function calling
llm.is_caching_prompt_active = Mock(return_value=False)
@@ -260,6 +266,28 @@ def test_step_with_no_pending_actions(mock_state: State):
assert action.content == 'Task completed'
def test_correct_tool_description_loaded_based_on_model_name(mock_state: State):
"""Tests that the simplified tool descriptions are loaded for specific models."""
o3_mock_config = Mock()
o3_mock_config.model = 'mock_o3_model'
llm = Mock()
llm.config = o3_mock_config
agent = CodeActAgent(llm=llm, config=AgentConfig())
for tool in agent.tools:
# Assert all descriptions have less than 1024 characters
assert len(tool['function']['description']) < 1024
sonnet_mock_config = Mock()
sonnet_mock_config.model = 'mock_sonnet_model'
llm.config = sonnet_mock_config
agent = CodeActAgent(llm=llm, config=AgentConfig())
# Assert existence of the detailed tool descriptions that are longer than 1024 characters
assert any(len(tool['function']['description']) > 1024 for tool in agent.tools)
def test_mismatched_tool_call_events(mock_state: State):
"""Tests that the agent can convert mismatched tool call events (i.e., an observation with no corresponding action) into messages."""
agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
+1
View File
@@ -32,6 +32,7 @@ def _patch_store():
'selected_repository': 'foobar',
'conversation_id': 'some_conversation_id',
'github_user_id': '12345',
'user_id': '12345',
'created_at': '2025-01-01T00:00:00+00:00',
'last_updated_at': '2025-01-01T00:01:00+00:00',
}
+40 -40
View File
@@ -22,7 +22,7 @@ from openhands.events.event import (
from openhands.events.observation import CmdOutputObservation
from openhands.events.observation.agent import (
MicroagentKnowledge,
MicroagentObservation,
RecallObservation,
)
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.commands import (
@@ -51,7 +51,7 @@ def agent_config():
def conversation_memory(agent_config):
prompt_manager = MagicMock(spec=PromptManager)
prompt_manager.get_system_message.return_value = 'System message'
prompt_manager.build_additional_info.return_value = (
prompt_manager.build_workspace_context.return_value = (
'Formatted repository and runtime info'
)
@@ -353,10 +353,10 @@ def test_process_events_with_user_reject_observation(conversation_memory):
def test_process_events_with_empty_environment_info(conversation_memory):
"""Test that empty environment info observations return an empty list of messages without calling build_additional_info."""
# Create a MicroagentObservation with empty info
"""Test that empty environment info observations return an empty list of messages without calling build_workspace_context."""
# Create a RecallObservation with empty info
empty_obs = MicroagentObservation(
empty_obs = RecallObservation(
recall_type=RecallType.WORKSPACE_CONTEXT,
repo_name='',
repo_directory='',
@@ -382,8 +382,8 @@ def test_process_events_with_empty_environment_info(conversation_memory):
assert len(messages) == 1
assert messages[0].role == 'system'
# Verify that build_additional_info was NOT called since all input values were empty
conversation_memory.prompt_manager.build_additional_info.assert_not_called()
# Verify that build_workspace_context was NOT called since all input values were empty
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
def test_process_events_with_function_calling_observation(conversation_memory):
@@ -527,8 +527,8 @@ def test_apply_prompt_caching(conversation_memory):
def test_process_events_with_environment_microagent_observation(conversation_memory):
"""Test processing a MicroagentObservation with ENVIRONMENT info type."""
obs = MicroagentObservation(
"""Test processing a RecallObservation with ENVIRONMENT info type."""
obs = RecallObservation(
recall_type=RecallType.WORKSPACE_CONTEXT,
repo_name='test-repo',
repo_directory='/path/to/repo',
@@ -556,8 +556,8 @@ def test_process_events_with_environment_microagent_observation(conversation_mem
assert result.content[0].text == 'Formatted repository and runtime info'
# Verify the prompt_manager was called with the correct parameters
conversation_memory.prompt_manager.build_additional_info.assert_called_once()
call_args = conversation_memory.prompt_manager.build_additional_info.call_args[1]
conversation_memory.prompt_manager.build_workspace_context.assert_called_once()
call_args = conversation_memory.prompt_manager.build_workspace_context.call_args[1]
assert isinstance(call_args['repository_info'], RepositoryInfo)
assert call_args['repository_info'].repo_name == 'test-repo'
assert call_args['repository_info'].repo_directory == '/path/to/repo'
@@ -572,7 +572,7 @@ def test_process_events_with_environment_microagent_observation(conversation_mem
def test_process_events_with_knowledge_microagent_microagent_observation(
conversation_memory,
):
"""Test processing a MicroagentObservation with KNOWLEDGE type."""
"""Test processing a RecallObservation with KNOWLEDGE type."""
microagent_knowledge = [
MicroagentKnowledge(
name='test_agent',
@@ -591,7 +591,7 @@ def test_process_events_with_knowledge_microagent_microagent_observation(
),
]
obs = MicroagentObservation(
obs = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=microagent_knowledge,
content='Retrieved knowledge from microagents',
@@ -634,11 +634,11 @@ def test_process_events_with_knowledge_microagent_microagent_observation(
def test_process_events_with_microagent_observation_extensions_disabled(
agent_config, conversation_memory
):
"""Test processing a MicroagentObservation when prompt extensions are disabled."""
"""Test processing a RecallObservation when prompt extensions are disabled."""
# Modify the agent config to disable prompt extensions
agent_config.enable_prompt_extensions = False
obs = MicroagentObservation(
obs = RecallObservation(
recall_type=RecallType.WORKSPACE_CONTEXT,
repo_name='test-repo',
repo_directory='/path/to/repo',
@@ -656,18 +656,18 @@ def test_process_events_with_microagent_observation_extensions_disabled(
vision_is_active=False,
)
# When prompt extensions are disabled, the MicroagentObservation should be ignored
# When prompt extensions are disabled, the RecallObservation should be ignored
assert len(messages) == 1 # Only the initial system message
assert messages[0].role == 'system'
# Verify the prompt_manager was not called
conversation_memory.prompt_manager.build_additional_info.assert_not_called()
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
conversation_memory.prompt_manager.build_microagent_info.assert_not_called()
def test_process_events_with_empty_microagent_knowledge(conversation_memory):
"""Test processing a MicroagentObservation with empty microagent knowledge."""
obs = MicroagentObservation(
"""Test processing a RecallObservation with empty microagent knowledge."""
obs = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=[],
content='Retrieved knowledge from microagents',
@@ -693,7 +693,7 @@ def test_process_events_with_empty_microagent_knowledge(conversation_memory):
def test_conversation_memory_processes_microagent_observation(prompt_dir):
"""Test that ConversationMemory processes MicroagentObservations correctly."""
"""Test that ConversationMemory processes RecallObservations correctly."""
# Create a microagent_info.j2 template file
template_path = os.path.join(prompt_dir, 'microagent_info.j2')
if not os.path.exists(template_path):
@@ -722,8 +722,8 @@ It may or may not be relevant to the user's request.
config=agent_config, prompt_manager=prompt_manager
)
# Create a MicroagentObservation with microagent knowledge
microagent_observation = MicroagentObservation(
# Create a RecallObservation with microagent knowledge
microagent_observation = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=[
MicroagentKnowledge(
@@ -761,7 +761,7 @@ This is triggered content for testing.
def test_conversation_memory_processes_environment_microagent_observation(prompt_dir):
"""Test that ConversationMemory processes environment info MicroagentObservations correctly."""
"""Test that ConversationMemory processes environment info RecallObservations correctly."""
# Create an additional_info.j2 template file
template_path = os.path.join(prompt_dir, 'additional_info.j2')
if not os.path.exists(template_path):
@@ -802,8 +802,8 @@ each of which has a corresponding port:
config=agent_config, prompt_manager=prompt_manager
)
# Create a MicroagentObservation with environment info
microagent_observation = MicroagentObservation(
# Create a RecallObservation with environment info
microagent_observation = RecallObservation(
recall_type=RecallType.WORKSPACE_CONTEXT,
repo_name='owner/repo',
repo_directory='/workspace/repo',
@@ -839,13 +839,13 @@ each of which has a corresponding port:
def test_process_events_with_microagent_observation_deduplication(conversation_memory):
"""Test that MicroagentObservations are properly deduplicated based on agent name.
"""Test that RecallObservations are properly deduplicated based on agent name.
The deduplication logic should keep the FIRST occurrence of each microagent
and filter out later occurrences to avoid redundant information.
"""
# Create a sequence of MicroagentObservations with overlapping agents
obs1 = MicroagentObservation(
# Create a sequence of RecallObservations with overlapping agents
obs1 = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=[
MicroagentKnowledge(
@@ -867,7 +867,7 @@ def test_process_events_with_microagent_observation_deduplication(conversation_m
content='First retrieval',
)
obs2 = MicroagentObservation(
obs2 = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=[
MicroagentKnowledge(
@@ -879,7 +879,7 @@ def test_process_events_with_microagent_observation_deduplication(conversation_m
content='Second retrieval',
)
obs3 = MicroagentObservation(
obs3 = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=[
MicroagentKnowledge(
@@ -918,8 +918,8 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent
conversation_memory,
):
"""Test that disabled agents are filtered out and deduplication keeps the first occurrence."""
# Create a sequence of MicroagentObservations with disabled agents
obs1 = MicroagentObservation(
# Create a sequence of RecallObservations with disabled agents
obs1 = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=[
MicroagentKnowledge(
@@ -936,7 +936,7 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent
content='First retrieval',
)
obs2 = MicroagentObservation(
obs2 = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=[
MicroagentKnowledge(
@@ -973,8 +973,8 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent
def test_process_events_with_microagent_observation_deduplication_empty(
conversation_memory,
):
"""Test that empty MicroagentObservations are handled correctly."""
obs = MicroagentObservation(
"""Test that empty RecallObservations are handled correctly."""
obs = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=[],
content='Empty retrieval',
@@ -991,7 +991,7 @@ def test_process_events_with_microagent_observation_deduplication_empty(
vision_is_active=False,
)
# Verify that empty MicroagentObservations are handled gracefully
# Verify that empty RecallObservations are handled gracefully
assert (
len(messages) == 1
) # system message, because an empty microagent is not added to Messages
@@ -999,8 +999,8 @@ def test_process_events_with_microagent_observation_deduplication_empty(
def test_has_agent_in_earlier_events(conversation_memory):
"""Test the _has_agent_in_earlier_events helper method."""
# Create test MicroagentObservations
obs1 = MicroagentObservation(
# Create test RecallObservations
obs1 = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=[
MicroagentKnowledge(
@@ -1012,7 +1012,7 @@ def test_has_agent_in_earlier_events(conversation_memory):
content='First retrieval',
)
obs2 = MicroagentObservation(
obs2 = RecallObservation(
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=[
MicroagentKnowledge(
@@ -1024,7 +1024,7 @@ def test_has_agent_in_earlier_events(conversation_memory):
content='Second retrieval',
)
obs3 = MicroagentObservation(
obs3 = RecallObservation(
recall_type=RecallType.WORKSPACE_CONTEXT,
content='Environment info',
)
+11 -1
View File
@@ -13,7 +13,8 @@ async def test_load_store():
store = FileConversationStore(InMemoryFileStore({}))
expected = ConversationMetadata(
conversation_id='some-conversation-id',
github_user_id='some-user-id',
user_id='some-user-id',
github_user_id='12345',
selected_repository='some-repo',
title="Let's talk about trains",
)
@@ -31,6 +32,7 @@ async def test_load_int_user_id():
{
'conversation_id': 'some-conversation-id',
'github_user_id': 12345,
'user_id': '67890',
'selected_repository': 'some-repo',
'title': "Let's talk about trains",
'created_at': '2025-01-16T19:51:04.886331Z',
@@ -41,6 +43,7 @@ async def test_load_int_user_id():
)
found = await store.get_metadata('some-conversation-id')
assert found.github_user_id == '12345'
assert found.user_id == '67890'
@pytest.mark.asyncio
@@ -61,6 +64,7 @@ async def test_search_basic():
{
'conversation_id': 'conv1',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'First conversation',
'created_at': '2025-01-16T19:51:04Z',
@@ -70,6 +74,7 @@ async def test_search_basic():
{
'conversation_id': 'conv2',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'Second conversation',
'created_at': '2025-01-17T19:51:04Z',
@@ -79,6 +84,7 @@ async def test_search_basic():
{
'conversation_id': 'conv3',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'Third conversation',
'created_at': '2025-01-15T19:51:04Z',
@@ -107,6 +113,7 @@ async def test_search_pagination():
{
'conversation_id': f'conv{i}',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': f'Conversation {i}',
'created_at': f'2025-01-{15+i}T19:51:04Z',
@@ -148,6 +155,7 @@ async def test_search_with_invalid_conversation():
{
'conversation_id': 'conv1',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'Valid conversation',
'created_at': '2025-01-16T19:51:04Z',
@@ -176,6 +184,7 @@ async def test_get_all_metadata():
{
'conversation_id': 'conv1',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'First conversation',
'created_at': '2025-01-16T19:51:04Z',
@@ -185,6 +194,7 @@ async def test_get_all_metadata():
{
'conversation_id': 'conv2',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'Second conversation',
'created_at': '2025-01-17T19:51:04Z',
+21 -17
View File
@@ -14,7 +14,7 @@ from openhands.events.action.agent import RecallAction
from openhands.events.action.message import MessageAction
from openhands.events.event import EventSource
from openhands.events.observation.agent import (
MicroagentObservation,
RecallObservation,
RecallType,
)
from openhands.events.stream import EventStream
@@ -74,7 +74,7 @@ async def test_memory_on_event_exception_handling(memory, event_stream):
# Mock Memory method to raise an exception
with patch.object(
memory, '_on_first_microagent_action', side_effect=Exception('Test error')
memory, '_on_workspace_context_recall', side_effect=Exception('Test error')
):
state = await run_controller(
config=AppConfig(),
@@ -93,10 +93,10 @@ async def test_memory_on_event_exception_handling(memory, event_stream):
@pytest.mark.asyncio
async def test_memory_on_first_microagent_action_exception_handling(
async def test_memory_on_workspace_context_recall_exception_handling(
memory, event_stream
):
"""Test that exceptions in Memory._on_first_microagent_action are properly handled via status callback."""
"""Test that exceptions in Memory._on_workspace_context_recall are properly handled via status callback."""
# Create a dummy agent for the controller
agent = MagicMock(spec=Agent)
@@ -108,11 +108,11 @@ async def test_memory_on_first_microagent_action_exception_handling(
runtime = MagicMock(spec=Runtime)
runtime.event_stream = event_stream
# Mock Memory._on_first_microagent_action to raise an exception
# Mock Memory._on_workspace_context_recall to raise an exception
with patch.object(
memory,
'_on_first_microagent_action',
side_effect=Exception('Test error from _on_first_microagent_action'),
'_find_microagent_knowledge',
side_effect=Exception('Test error from _find_microagent_knowledge'),
):
state = await run_controller(
config=AppConfig(),
@@ -130,12 +130,13 @@ async def test_memory_on_first_microagent_action_exception_handling(
assert state.last_error == 'Error: Exception'
def test_memory_with_microagents():
@pytest.mark.asyncio
async def test_memory_with_microagents():
"""Test that Memory loads microagents from the global directory and processes microagent actions.
This test verifies that:
1. Memory loads microagents from the global GLOBAL_MICROAGENTS_DIR
2. When a microagent action with a trigger word is processed, a MicroagentObservation is created
2. When a microagent action with a trigger word is processed, a RecallObservation is created
"""
# Create a mock event stream
event_stream = MagicMock(spec=EventStream)
@@ -158,6 +159,9 @@ def test_memory_with_microagents():
query='Hello, flarglebargle!', recall_type=RecallType.KNOWLEDGE
)
# Set the source to USER
microagent_action._source = EventSource.USER # type: ignore[attr-defined]
# Mock the event_stream.add_event method
added_events = []
@@ -173,12 +177,12 @@ def test_memory_with_microagents():
added_events.clear()
# Process the microagent action
memory.on_event(microagent_action)
await memory._on_event(microagent_action)
# Verify a MicroagentObservation was added to the event stream
# Verify a RecallObservation was added to the event stream
assert len(added_events) == 1
observation, source = added_events[0]
assert isinstance(observation, MicroagentObservation)
assert isinstance(observation, RecallObservation)
assert source == EventSource.ENVIRONMENT
assert observation.recall_type == RecallType.KNOWLEDGE
assert len(observation.microagent_knowledge) == 1
@@ -188,7 +192,7 @@ def test_memory_with_microagents():
def test_memory_repository_info(prompt_dir):
"""Test that Memory adds repository info to MicroagentObservations."""
"""Test that Memory adds repository info to RecallObservations."""
# Create an in-memory file store and real event stream
file_store = InMemoryFileStore()
event_stream = EventStream(sid='test-session', file_store=file_store)
@@ -241,15 +245,15 @@ REPOSITORY INSTRUCTIONS: This is a test repository.
# Get all events from the stream
events = list(event_stream.get_events())
# Find the MicroagentObservation event
# Find the RecallObservation event
microagent_obs_events = [
event for event in events if isinstance(event, MicroagentObservation)
event for event in events if isinstance(event, RecallObservation)
]
# We should have at least one MicroagentObservation
# We should have at least one RecallObservation
assert len(microagent_obs_events) > 0
# Get the first MicroagentObservation
# Get the first RecallObservation
observation = microagent_obs_events[0]
assert observation.recall_type == RecallType.WORKSPACE_CONTEXT
assert observation.repo_name == 'owner/repo'
+21 -21
View File
@@ -5,8 +5,8 @@ from openhands.events.observation import (
CmdOutputMetadata,
CmdOutputObservation,
FileEditObservation,
MicroagentObservation,
Observation,
RecallObservation,
)
from openhands.events.observation.agent import MicroagentKnowledge
from openhands.events.serialization import (
@@ -245,9 +245,9 @@ def test_file_edit_observation_legacy_serialization():
def test_microagent_observation_serialization():
original_observation_dict = {
'observation': 'microagent',
'observation': 'recall',
'content': '',
'message': "**MicroagentObservation**\nrecall_type=RecallType.WORKSPACE_CONTEXT, repo_name=some_repo_name, repo_instructions=complex_repo_instruc..., runtime_hosts={'host1': 8080, 'host2': 8081}, additional_agent_instructions=You know it all abou...",
'message': 'Added workspace context',
'extras': {
'recall_type': 'workspace_context',
'repo_name': 'some_repo_name',
@@ -258,14 +258,14 @@ def test_microagent_observation_serialization():
'microagent_knowledge': [],
},
}
serialization_deserialization(original_observation_dict, MicroagentObservation)
serialization_deserialization(original_observation_dict, RecallObservation)
def test_microagent_observation_microagent_knowledge_serialization():
original_observation_dict = {
'observation': 'microagent',
'observation': 'recall',
'content': '',
'message': '**MicroagentObservation**\nrecall_type=RecallType.KNOWLEDGE, repo_name=, repo_instructions=..., runtime_hosts={}, additional_agent_instructions=..., microagent_knowledge=microagent1, microagent2',
'message': 'Added microagent knowledge',
'extras': {
'recall_type': 'knowledge',
'repo_name': '',
@@ -287,13 +287,13 @@ def test_microagent_observation_microagent_knowledge_serialization():
],
},
}
serialization_deserialization(original_observation_dict, MicroagentObservation)
serialization_deserialization(original_observation_dict, RecallObservation)
def test_microagent_observation_knowledge_microagent_serialization():
"""Test serialization of a MicroagentObservation with KNOWLEDGE_MICROAGENT type."""
# Create a MicroagentObservation with microagent knowledge content
original = MicroagentObservation(
"""Test serialization of a RecallObservation with KNOWLEDGE_MICROAGENT type."""
# Create a RecallObservation with microagent knowledge content
original = RecallObservation(
content='Knowledge microagent information',
recall_type=RecallType.KNOWLEDGE,
microagent_knowledge=[
@@ -314,13 +314,13 @@ def test_microagent_observation_knowledge_microagent_serialization():
serialized = event_to_dict(original)
# Verify serialized data structure
assert serialized['observation'] == ObservationType.MICROAGENT
assert serialized['observation'] == ObservationType.RECALL
assert serialized['content'] == 'Knowledge microagent information'
assert serialized['extras']['recall_type'] == RecallType.KNOWLEDGE.value
assert len(serialized['extras']['microagent_knowledge']) == 2
assert serialized['extras']['microagent_knowledge'][0]['trigger'] == 'python'
# Deserialize back to MicroagentObservation
# Deserialize back to RecallObservation
deserialized = observation_from_dict(serialized)
# Verify properties are preserved
@@ -336,9 +336,9 @@ def test_microagent_observation_knowledge_microagent_serialization():
def test_microagent_observation_environment_serialization():
"""Test serialization of a MicroagentObservation with ENVIRONMENT type."""
# Create a MicroagentObservation with environment info
original = MicroagentObservation(
"""Test serialization of a RecallObservation with ENVIRONMENT type."""
# Create a RecallObservation with environment info
original = RecallObservation(
content='Environment information',
recall_type=RecallType.WORKSPACE_CONTEXT,
repo_name='OpenHands',
@@ -352,7 +352,7 @@ def test_microagent_observation_environment_serialization():
serialized = event_to_dict(original)
# Verify serialized data structure
assert serialized['observation'] == ObservationType.MICROAGENT
assert serialized['observation'] == ObservationType.RECALL
assert serialized['content'] == 'Environment information'
assert serialized['extras']['recall_type'] == RecallType.WORKSPACE_CONTEXT.value
assert serialized['extras']['repo_name'] == 'OpenHands'
@@ -364,7 +364,7 @@ def test_microagent_observation_environment_serialization():
serialized['extras']['additional_agent_instructions']
== 'You know it all about this runtime'
)
# Deserialize back to MicroagentObservation
# Deserialize back to RecallObservation
deserialized = observation_from_dict(serialized)
# Verify properties are preserved
@@ -382,11 +382,11 @@ def test_microagent_observation_environment_serialization():
def test_microagent_observation_combined_serialization():
"""Test serialization of a MicroagentObservation with both types of information."""
# Create a MicroagentObservation with both environment and microagent info
"""Test serialization of a RecallObservation with both types of information."""
# Create a RecallObservation with both environment and microagent info
# Note: In practice, recall_type would still be one specific type,
# but the object could contain both types of fields
original = MicroagentObservation(
original = RecallObservation(
content='Combined information',
recall_type=RecallType.WORKSPACE_CONTEXT,
# Environment info
@@ -419,7 +419,7 @@ def test_microagent_observation_combined_serialization():
serialized['extras']['additional_agent_instructions']
== 'You know it all about this runtime'
)
# Deserialize back to MicroagentObservation
# Deserialize back to RecallObservation
deserialized = observation_from_dict(serialized)
# Verify all properties are preserved
+3 -3
View File
@@ -51,7 +51,7 @@ At the user's request, repository {{ repository_info.repo_name }} has been clone
assert 'System prompt: bar' in system_msg
# Test building additional info
additional_info = manager.build_additional_info(
additional_info = manager.build_workspace_context(
repository_info=repo_info, runtime_info=None, repo_instructions=''
)
assert '<REPOSITORY_INFO>' in additional_info
@@ -199,7 +199,7 @@ def test_add_turns_left_reminder(prompt_dir):
)
def test_build_additional_info_with_repo_and_runtime(prompt_dir):
def test_build_workspace_context_with_repo_and_runtime(prompt_dir):
"""Test building additional info with repository and runtime information."""
# Create an additional_info.j2 template file
with open(os.path.join(prompt_dir, 'additional_info.j2'), 'w') as f:
@@ -245,7 +245,7 @@ each of which has a corresponding port:
repo_instructions = 'This repository contains important code.'
# Build additional info
result = manager.build_additional_info(
result = manager.build_workspace_context(
repository_info=repo_info,
runtime_info=runtime_info,
repo_instructions=repo_instructions,
+4
View File
@@ -49,6 +49,7 @@ async def test_iterate_single_page():
{
'conversation_id': 'conv1',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'First conversation',
'created_at': '2025-01-16T19:51:04Z',
@@ -58,6 +59,7 @@ async def test_iterate_single_page():
{
'conversation_id': 'conv2',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'Second conversation',
'created_at': '2025-01-17T19:51:04Z',
@@ -86,6 +88,7 @@ async def test_iterate_multiple_pages():
{
'conversation_id': f'conv{i}',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': f'Conversation {i}',
'created_at': f'2025-01-{15+i}T19:51:04Z',
@@ -120,6 +123,7 @@ async def test_iterate_with_invalid_conversation():
{
'conversation_id': 'conv1',
'github_user_id': '123',
'user_id': '123',
'selected_repository': 'repo1',
'title': 'Valid conversation',
'created_at': '2025-01-16T19:51:04Z',
@@ -61,7 +61,7 @@ async def test_init_new_local_session():
'new-session-id', ConversationInitData(), 1
)
await conversation_manager.join_conversation(
'new-session-id', 'new-session-id', ConversationInitData(), 1
'new-session-id', 'new-session-id', ConversationInitData(), 1, '12345'
)
assert session_instance.initialize_agent.call_count == 1
assert sio.enter_room.await_count == 1
@@ -93,10 +93,18 @@ async def test_join_local_session():
'new-session-id', ConversationInitData(), None
)
await conversation_manager.join_conversation(
'new-session-id', 'new-session-id', ConversationInitData(), None
'new-session-id',
'new-session-id',
ConversationInitData(),
None,
'12345',
)
await conversation_manager.join_conversation(
'new-session-id', 'new-session-id', ConversationInitData(), None
'new-session-id',
'new-session-id',
ConversationInitData(),
None,
'12345',
)
assert session_instance.initialize_agent.call_count == 1
assert sio.enter_room.await_count == 2
@@ -128,7 +136,7 @@ async def test_add_to_local_event_stream():
'new-session-id', ConversationInitData(), 1
)
await conversation_manager.join_conversation(
'new-session-id', 'connection-id', ConversationInitData(), 1
'new-session-id', 'connection-id', ConversationInitData(), 1, '12345'
)
await conversation_manager.send_to_event_stream(
'connection-id', {'event_type': 'some_event'}
+7 -1
View File
@@ -23,7 +23,13 @@ def mock_event_stream():
def mock_agent():
agent = MagicMock()
agent.llm = MagicMock()
agent.llm.config = MagicMock()
# Create a step function that returns an action without an ID
def agent_step_fn(state):
return MessageAction(content='Agent returned a message')
agent.step = agent_step_fn
return agent