mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
64 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| bc27eae841 | |||
| f0e5c81272 | |||
| a4b836b5f9 | |||
| a4d632498c | |||
| 4f017081fc | |||
| 51fb1fae88 | |||
| 106b230fea | |||
| 9b262dd057 | |||
| 8074b261d3 | |||
| 999a59f938 | |||
| fbba57d3b5 | |||
| 3f6c8a2338 | |||
| dd09d46ccb | |||
| 8897b45eeb | |||
| 30109e8f20 | |||
| cc45f5d9c3 | |||
| e34a771e66 | |||
| ec763f8105 | |||
| 4984bf6ee7 | |||
| 92ddc1b46c | |||
| 367c8a9f83 | |||
| 09335d67be | |||
| eb36426d33 | |||
| ace9e6e724 | |||
| 1290a2599d | |||
| 513dd9791d | |||
| fd53378d06 | |||
| 4471002c79 | |||
| 326e75e829 | |||
| d8ad8babf6 | |||
| 8782e3ae65 | |||
| eef0ed3410 | |||
| e7a8daf3ec | |||
| 64abd4a95e | |||
| c7d575b4e1 | |||
| f781bc8343 | |||
| 9f9a65c787 | |||
| 3355baea4c | |||
| 8848e60c6d | |||
| d1e84093cc | |||
| 3f0f13d335 | |||
| 219a134bb0 | |||
| efb525a463 | |||
| 1ded123116 | |||
| 31b6967a87 | |||
| 90422e5bfd | |||
| b47da9e894 | |||
| 3401bd610d | |||
| 77a153e42f | |||
| fb9bc87e35 | |||
| b685c67263 | |||
| 2cd64bc636 | |||
| 7c81deb132 | |||
| b19f735808 | |||
| 3af6025303 | |||
| 585dba9917 | |||
| bd66d09a33 | |||
| 791b7f9f60 | |||
| 30197e616b | |||
| f7f25319e3 | |||
| 75fba59588 | |||
| 280baa24e2 | |||
| c6206f5cf2 | |||
| 7a4729c034 |
@@ -12,7 +12,7 @@
|
|||||||
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue"></a>
|
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue"></a>
|
||||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
|
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
|
||||||
<br/>
|
<br/>
|
||||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
|
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
|
||||||
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community"></a>
|
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community"></a>
|
||||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits"></a>
|
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits"></a>
|
||||||
<br/>
|
<br/>
|
||||||
@@ -96,7 +96,7 @@ troubleshooting resources, and advanced configuration options.
|
|||||||
OpenHands is a community-driven project, and we welcome contributions from everyone. We do most of our communication
|
OpenHands is a community-driven project, and we welcome contributions from everyone. We do most of our communication
|
||||||
through Slack, so this is the best place to start, but we also are happy to have you contact us on Discord or Github:
|
through Slack, so this is the best place to start, but we also are happy to have you contact us on Discord or Github:
|
||||||
|
|
||||||
- [Join our Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg) - Here we talk about research, architecture, and future development.
|
- [Join our Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw) - Here we talk about research, architecture, and future development.
|
||||||
- [Join our Discord server](https://discord.gg/ESHStjSjD4) - This is a community-run server for general discussion, questions, and feedback.
|
- [Join our Discord server](https://discord.gg/ESHStjSjD4) - This is a community-run server for general discussion, questions, and feedback.
|
||||||
- [Read or post Github Issues](https://github.com/All-Hands-AI/OpenHands/issues) - Check out the issues we're working on, or add your own ideas.
|
- [Read or post Github Issues](https://github.com/All-Hands-AI/OpenHands/issues) - Check out the issues we're working on, or add your own ideas.
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ Explorez le code source d'OpenHands sur [GitHub](https://github.com/All-Hands-AI
|
|||||||
/>
|
/>
|
||||||
</a>
|
</a>
|
||||||
<br></br>
|
<br></br>
|
||||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg">
|
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw">
|
||||||
<img
|
<img
|
||||||
src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge"
|
src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge"
|
||||||
alt="Join our Slack community"
|
alt="Join our Slack community"
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ OpenHands 是一个**自主 AI 软件工程师**,能够执行复杂的工程
|
|||||||
/>
|
/>
|
||||||
</a>
|
</a>
|
||||||
<br></br>
|
<br></br>
|
||||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg">
|
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw">
|
||||||
<img
|
<img
|
||||||
src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge"
|
src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge"
|
||||||
alt="Join our Slack community"
|
alt="Join our Slack community"
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ function CustomFooter() {
|
|||||||
<footer className="custom-footer">
|
<footer className="custom-footer">
|
||||||
<div className="footer-content">
|
<div className="footer-content">
|
||||||
<div className="footer-icons">
|
<div className="footer-icons">
|
||||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg" target="_blank" rel="noopener noreferrer">
|
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw" target="_blank" rel="noopener noreferrer">
|
||||||
<FaSlack />
|
<FaSlack />
|
||||||
</a>
|
</a>
|
||||||
<a href="https://discord.gg/ESHStjSjD4" target="_blank" rel="noopener noreferrer">
|
<a href="https://discord.gg/ESHStjSjD4" target="_blank" rel="noopener noreferrer">
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ export function HomepageHeader() {
|
|||||||
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" /></a>
|
<a href="https://codecov.io/github/All-Hands-AI/OpenHands?branch=main"><img alt="CodeCov" src="https://img.shields.io/codecov/c/github/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" /></a>
|
||||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License" /></a>
|
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License" /></a>
|
||||||
<br/>
|
<br/>
|
||||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ypg5jweb-d~6hObZDbXi_HEL8PDrbHg"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community" /></a>
|
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-2ngejmfw6-9gW4APWOC9XUp1n~SiQ6iw"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community" /></a>
|
||||||
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community" /></a>
|
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community" /></a>
|
||||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits" /></a>
|
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits" /></a>
|
||||||
<br/>
|
<br/>
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -175,6 +176,11 @@ def process_instance(
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f'This is the {runtime_failure_count + 1}th attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
|
f'This is the {runtime_failure_count + 1}th attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
|
||||||
)
|
)
|
||||||
|
metadata = copy.deepcopy(metadata)
|
||||||
|
metadata.details['runtime_failure_count'] = runtime_failure_count
|
||||||
|
metadata.details['remote_runtime_resource_factor'] = (
|
||||||
|
config.sandbox.remote_runtime_resource_factor
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
runtime = create_runtime(config)
|
runtime = create_runtime(config)
|
||||||
@@ -296,14 +302,20 @@ def process_instance(
|
|||||||
with open(test_output_path, 'w') as f:
|
with open(test_output_path, 'w') as f:
|
||||||
f.write(test_output)
|
f.write(test_output)
|
||||||
try:
|
try:
|
||||||
|
extra_kwargs = {}
|
||||||
|
if 'SWE-Gym' in metadata.dataset:
|
||||||
|
# SWE-Gym uses a different version of the package, hence a different eval report argument
|
||||||
|
extra_kwargs['log_path'] = test_output_path
|
||||||
|
else:
|
||||||
|
extra_kwargs['test_log_path'] = test_output_path
|
||||||
_report = conditional_imports.get_eval_report(
|
_report = conditional_imports.get_eval_report(
|
||||||
test_spec=test_spec,
|
test_spec=test_spec,
|
||||||
prediction={
|
prediction={
|
||||||
'model_patch': model_patch,
|
'model_patch': model_patch,
|
||||||
'instance_id': instance_id,
|
'instance_id': instance_id,
|
||||||
},
|
},
|
||||||
test_log_path=test_output_path,
|
|
||||||
include_tests_status=True,
|
include_tests_status=True,
|
||||||
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
report = _report[instance_id]
|
report = _report[instance_id]
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -463,6 +475,7 @@ if __name__ == '__main__':
|
|||||||
.decode('utf-8')
|
.decode('utf-8')
|
||||||
.strip(), # Current commit
|
.strip(), # Current commit
|
||||||
dataset=args.dataset, # Dataset name from args
|
dataset=args.dataset, # Dataset name from args
|
||||||
|
details={},
|
||||||
)
|
)
|
||||||
|
|
||||||
# The evaluation harness constrains the signature of `process_instance_func` but we need to
|
# The evaluation harness constrains the signature of `process_instance_func` but we need to
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -23,7 +23,7 @@ def get_resource_mapping(dataset_name: str) -> dict[str, float]:
|
|||||||
if dataset_name not in _global_resource_mapping:
|
if dataset_name not in _global_resource_mapping:
|
||||||
file_path = os.path.join(CUR_DIR, f'{dataset_name}.json')
|
file_path = os.path.join(CUR_DIR, f'{dataset_name}.json')
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
logger.warning(f'Resource mapping for {dataset_name} not found.')
|
logger.info(f'Resource mapping for {dataset_name} not found.')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -149,7 +150,8 @@ def get_config(
|
|||||||
) -> AppConfig:
|
) -> AppConfig:
|
||||||
# We use a different instance image for the each instance of swe-bench eval
|
# We use a different instance image for the each instance of swe-bench eval
|
||||||
use_official_image = bool(
|
use_official_image = bool(
|
||||||
'verified' in metadata.dataset.lower() or 'lite' in metadata.dataset.lower()
|
('verified' in metadata.dataset.lower() or 'lite' in metadata.dataset.lower())
|
||||||
|
and 'swe-gym' not in metadata.dataset.lower()
|
||||||
)
|
)
|
||||||
base_container_image = get_instance_docker_image(
|
base_container_image = get_instance_docker_image(
|
||||||
instance['instance_id'], use_official_image
|
instance['instance_id'], use_official_image
|
||||||
@@ -475,6 +477,13 @@ def process_instance(
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
f'This is the {runtime_failure_count + 1}th attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
|
f'This is the {runtime_failure_count + 1}th attempt for instance {instance.instance_id}, setting resource factor to {config.sandbox.remote_runtime_resource_factor}'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
metadata = copy.deepcopy(metadata)
|
||||||
|
metadata.details['runtime_failure_count'] = runtime_failure_count
|
||||||
|
metadata.details['remote_runtime_resource_factor'] = (
|
||||||
|
config.sandbox.remote_runtime_resource_factor
|
||||||
|
)
|
||||||
|
|
||||||
runtime = create_runtime(config)
|
runtime = create_runtime(config)
|
||||||
call_async_from_sync(runtime.connect)
|
call_async_from_sync(runtime.connect)
|
||||||
|
|
||||||
@@ -560,20 +569,6 @@ def filter_dataset(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
|
|||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
# A list of instances that are known to be tricky to infer
|
|
||||||
# (will cause runtime failure even with resource factor = 8)
|
|
||||||
SWEGYM_EXCLUDE_IDS = [
|
|
||||||
'dask__dask-10422',
|
|
||||||
'pandas-dev__pandas-50548',
|
|
||||||
'pandas-dev__pandas-53672',
|
|
||||||
'pandas-dev__pandas-54174',
|
|
||||||
'pandas-dev__pandas-55518',
|
|
||||||
'pandas-dev__pandas-58383',
|
|
||||||
'pydata__xarray-6721',
|
|
||||||
'pytest-dev__pytest-10081',
|
|
||||||
'pytest-dev__pytest-7236',
|
|
||||||
]
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -598,11 +593,20 @@ if __name__ == '__main__':
|
|||||||
f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks'
|
f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks'
|
||||||
)
|
)
|
||||||
if 'SWE-Gym' in args.dataset:
|
if 'SWE-Gym' in args.dataset:
|
||||||
swe_bench_tests = swe_bench_tests[
|
with open(
|
||||||
~swe_bench_tests['instance_id'].isin(SWEGYM_EXCLUDE_IDS)
|
os.path.join(
|
||||||
]
|
os.path.dirname(os.path.abspath(__file__)),
|
||||||
|
'split',
|
||||||
|
'swegym_verified_instances.json',
|
||||||
|
),
|
||||||
|
'r',
|
||||||
|
) as f:
|
||||||
|
swegym_verified_instances = json.load(f)
|
||||||
|
swe_bench_tests = swe_bench_tests[
|
||||||
|
swe_bench_tests['instance_id'].isin(swegym_verified_instances)
|
||||||
|
]
|
||||||
logger.info(
|
logger.info(
|
||||||
f'{len(swe_bench_tests)} tasks left after excluding SWE-Gym excluded tasks'
|
f'{len(swe_bench_tests)} tasks left after filtering for SWE-Gym verified instances'
|
||||||
)
|
)
|
||||||
|
|
||||||
llm_config = None
|
llm_config = None
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ parser.add_argument(
|
|||||||
'--dataset_name',
|
'--dataset_name',
|
||||||
type=str,
|
type=str,
|
||||||
help='Name of the dataset to download',
|
help='Name of the dataset to download',
|
||||||
default='princeton-nlp/SWE-bench_Lite',
|
default='princeton-nlp/SWE-bench_Verified',
|
||||||
)
|
)
|
||||||
parser.add_argument('--split', type=str, help='Split to download', default='test')
|
parser.add_argument('--split', type=str, help='Split to download', default='test')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -20,7 +20,12 @@ print(
|
|||||||
f'Downloading gold patches from {args.dataset_name} (split: {args.split}) to {output_filepath}'
|
f'Downloading gold patches from {args.dataset_name} (split: {args.split}) to {output_filepath}'
|
||||||
)
|
)
|
||||||
patches = [
|
patches = [
|
||||||
{'instance_id': row['instance_id'], 'model_patch': row['patch']} for row in dataset
|
{
|
||||||
|
'instance_id': row['instance_id'],
|
||||||
|
'model_patch': row['patch'],
|
||||||
|
'model_name_or_path': 'gold',
|
||||||
|
}
|
||||||
|
for row in dataset
|
||||||
]
|
]
|
||||||
print(f'{len(patches)} gold patches loaded')
|
print(f'{len(patches)} gold patches loaded')
|
||||||
pd.DataFrame(patches).to_json(output_filepath, lines=True, orient='records')
|
pd.DataFrame(patches).to_json(output_filepath, lines=True, orient='records')
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,12 @@
|
|||||||
|
codamosa_ids = ['pydata__xarray-4750-16496', 'pydata__xarray-3239-16458', 'pydata__xarray-4966-16515', 'pydata__xarray-3302-16459', 'pydata__xarray-5126-16518', 'pydata__xarray-4994-16516', 'pydata__xarray-3905-16478', 'pydata__xarray-4182-16484', 'pydata__xarray-5131-16520', 'pydata__xarray-5662-16532', 'pydata__xarray-3364-16461', 'pydata__xarray-5731-16534', 'pydata__xarray-3239-16457', 'pydata__xarray-7203-16577', 'pydata__xarray-3156-16454', 'pydata__xarray-5126-16519', 'pydata__xarray-5365-16529', 'pydata__xarray-4629-16492', 'pydata__xarray-4248-16486', 'pydata__xarray-4339-16487', 'pydata__xarray-3151-16453', 'pydata__xarray-3114-16452', 'pydata__xarray-5033-16517', 'pydata__xarray-4802-16505', 'pydata__xarray-5455-16530', 'pydata__xarray-6400-16539', 'pydata__xarray-3239-16456', 'pydata__xarray-4419-16488']
|
||||||
|
|
||||||
|
pynguin_ids = ['pydata__xarray-6548-16541', 'pydata__xarray-7003-16557', 'pydata__xarray-3114-16452', 'pydata__xarray-4339-16487', 'pydata__xarray-6889-16549', 'pydata__xarray-3239-16458', 'pydata__xarray-3364-16461', 'pydata__xarray-3239-16457', 'pydata__xarray-5365-16529', 'pydata__xarray-5131-16520', 'pydata__xarray-7229-16578', 'pydata__xarray-6461-16540', 'pydata__xarray-4419-16488', 'pydata__xarray-7147-16571', 'pydata__xarray-3151-16453', 'pydata__xarray-4966-16515', 'pydata__xarray-4629-16492', 'pydata__xarray-3239-16456', 'pydata__xarray-7400-16582', 'pydata__xarray-4994-16516', 'pydata__xarray-3302-16459', 'pydata__xarray-6601-16544', 'pydata__xarray-6882-16548', 'pydata__xarray-6135-16535', 'pydata__xarray-7393-16581', 'pydata__xarray-5731-16534', 'pydata__xarray-7203-16577']
|
||||||
|
|
||||||
|
ids = ['pydata__xarray-3114-16452', 'pydata__xarray-3151-16453', 'pydata__xarray-3156-16454', 'pydata__xarray-3239-16456', 'pydata__xarray-3239-16457', 'pydata__xarray-3239-16458', 'pydata__xarray-3302-16459', 'pydata__xarray-3364-16461', 'pydata__xarray-3677-16471', 'pydata__xarray-3905-16478', 'pydata__xarray-4182-16484', 'pydata__xarray-4248-16486', 'pydata__xarray-4339-16487', 'pydata__xarray-4419-16488', 'pydata__xarray-4629-16492', 'pydata__xarray-4750-16496', 'pydata__xarray-4802-16505', 'pydata__xarray-4966-16515', 'pydata__xarray-4994-16516', 'pydata__xarray-5033-16517', 'pydata__xarray-5126-16518', 'pydata__xarray-5126-16519', 'pydata__xarray-5131-16520', 'pydata__xarray-5365-16529', 'pydata__xarray-5455-16530', 'pydata__xarray-5662-16532', 'pydata__xarray-5731-16534', 'pydata__xarray-6135-16535', 'pydata__xarray-6135-16536', 'pydata__xarray-6386-16537', 'pydata__xarray-6394-16538', 'pydata__xarray-6400-16539', 'pydata__xarray-6461-16540', 'pydata__xarray-6548-16541', 'pydata__xarray-6599-16543', 'pydata__xarray-6601-16544', 'pydata__xarray-6882-16548', 'pydata__xarray-6889-16549', 'pydata__xarray-7003-16557', 'pydata__xarray-7147-16571', 'pydata__xarray-7150-16572', 'pydata__xarray-7203-16577', 'pydata__xarray-7229-16578', 'pydata__xarray-7393-16581', 'pydata__xarray-7400-16582']
|
||||||
|
|
||||||
|
|
||||||
|
Command eval (our approach):
|
||||||
|
poetry run ./evaluation/benchmarks/testgeneval/scripts/eval_infer_remote.sh evaluation/evaluation_outputs/outputs/kjain14__testgeneval-test/CodeActAgent/gpt-4o_maxiter_25_N_v0.20.0-no-hint-run_1/output.jsonl 10 kjain14/testgeneval test true
|
||||||
|
|
||||||
|
Command run (our approach):
|
||||||
|
./evaluation/benchmarks/testgeneval/scripts/run_infer.sh llm.eval_gpt HEAD CodeActAgent -1 25 10 kjain14/testgeneval test 1 ../TestGenEval/results/testgeneval/preds/gpt-4o-2024-08-06__testgeneval__0.2__test.jsonl
|
||||||
@@ -0,0 +1,80 @@
|
|||||||
|
# TestGenEval Benchmark Evaluation
|
||||||
|
|
||||||
|
This folder contains the evaluation harness for TestGenEval, built on the original TestGenEval benchmark ([paper](https://arxiv.org/abs/2410.00752)). TestGenEval is designed to evaluate the ability of language models to generate unit tests for given Python functions.
|
||||||
|
|
||||||
|
## Setup Environment and LLM Configuration
|
||||||
|
|
||||||
|
1. Follow the instructions [here](../../README.md#setup) to set up your local development environment and configure your LLM.
|
||||||
|
|
||||||
|
2. Install the TestGenEval dependencies:
|
||||||
|
```bash
|
||||||
|
poetry install --with testgeneval
|
||||||
|
```
|
||||||
|
|
||||||
|
## Run Inference
|
||||||
|
|
||||||
|
To generate tests using your model, run the following command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./evaluation/benchmarks/testgeneval/scripts/run_infer.sh [model_config] [git-version] [agent] [eval_limit] [max_iter] [num_workers] [dataset] [dataset_split]
|
||||||
|
|
||||||
|
# Example
|
||||||
|
./evaluation/benchmarks/testgeneval/scripts/run_infer.sh llm.eval_gpt4_1106_preview HEAD CodeActAgent 100 30 1 kjain14/testgenevallite test
|
||||||
|
```
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- `model_config`: The config group name for your LLM settings (e.g., `eval_gpt4_1106_preview`)
|
||||||
|
- `git-version`: The git commit hash or release tag of OpenHands to evaluate (e.g., `HEAD` or `0.6.2`)
|
||||||
|
- `agent`: The name of the agent for benchmarks (default: `CodeActAgent`)
|
||||||
|
- `eval_limit`: Limit the evaluation to the first N instances (optional)
|
||||||
|
- `max_iter`: Maximum number of iterations for the agent to run (default: 30)
|
||||||
|
- `num_workers`: Number of parallel workers for evaluation (default: 1)
|
||||||
|
- `dataset`: HuggingFace dataset name (default: `kjain14/testgenevallite`)
|
||||||
|
- `dataset_split`: Dataset split to use (default: `test`)
|
||||||
|
|
||||||
|
After running the inference, you will obtain an `output.jsonl` file (by default saved to `evaluation/evaluation_outputs`).
|
||||||
|
|
||||||
|
## Evaluate Generated Tests
|
||||||
|
|
||||||
|
To evaluate the generated tests, use the `eval_infer.sh` script:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./evaluation/benchmarks/testgeneval/scripts/eval_infer.sh $YOUR_OUTPUT_JSONL [instance_id] [dataset_name] [split] [num_workers] [skip_mutation]
|
||||||
|
|
||||||
|
# Example
|
||||||
|
./evaluation/benchmarks/testgeneval/scripts/eval_infer.sh evaluation/evaluation_outputs/outputs/testgeneval/CodeActAgent/gpt-4-1106-preview_maxiter_50_N_v1.0/output.jsonl
|
||||||
|
```
|
||||||
|
|
||||||
|
Optional arguments:
|
||||||
|
- `instance_id`: Evaluate a single instance (optional)
|
||||||
|
- `dataset_name`: Name of the dataset to use (default: `kjain14/testgenevallite`)
|
||||||
|
- `split`: Dataset split to use (default: `test`)
|
||||||
|
- `num_workers`: Number of workers for running docker (default: 1)
|
||||||
|
- `skip_mutation`: Skip mutation testing (enter `true` if desired)
|
||||||
|
|
||||||
|
For the example above, the evaluation results will be saved to `evaluation/evaluation_outputs/outputs/testgeneval/CodeActAgent/gpt-4-1106-preview_maxiter_50_N_v1.0/`, with `output.testgeneval.jsonl` containing the metrics.
|
||||||
|
|
||||||
|
## Metrics
|
||||||
|
|
||||||
|
The TestGenEval benchmark evaluates generated tests based on the following metrics:
|
||||||
|
|
||||||
|
1. Correctness: Measures if the generated tests are syntactically correct and run without errors.
|
||||||
|
2. Coverage: Assesses the code coverage achieved by the generated tests.
|
||||||
|
3. Mutation Score: Evaluates the effectiveness of the tests in detecting intentionally introduced bugs (mutations).
|
||||||
|
4. Readability: Analyzes the readability of the generated tests using various metrics.
|
||||||
|
|
||||||
|
## Submit Your Evaluation Results
|
||||||
|
|
||||||
|
To contribute your evaluation results:
|
||||||
|
|
||||||
|
1. Fork [our HuggingFace evaluation outputs](https://huggingface.co/spaces/OpenHands/evaluation).
|
||||||
|
2. Add your results to the forked repository.
|
||||||
|
3. Submit a Pull Request with your evaluation results following the guide [here](https://huggingface.co/docs/hub/en/repositories-pull-requests-discussions#pull-requests-and-discussions).
|
||||||
|
|
||||||
|
## Additional Resources
|
||||||
|
|
||||||
|
- [TestGenEval Paper](https://arxiv.org/abs/2410.00752)
|
||||||
|
- [OpenHands Documentation](https://github.com/All-Hands-AI/OpenHands)
|
||||||
|
- [HuggingFace Datasets](https://huggingface.co/datasets)
|
||||||
|
|
||||||
|
For any questions or issues, please open an issue in the [OpenHands repository](https://github.com/All-Hands-AI/OpenHands/issues).
|
||||||
@@ -0,0 +1,356 @@
|
|||||||
|
import math
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from tree_sitter import Language, Parser
|
||||||
|
|
||||||
|
|
||||||
|
def total_byte_entropy_stats(python_code):
    """Compute the Shannon entropy, in bits per byte, of ``python_code``.

    The text is encoded as UTF-8 and the entropy is taken over the
    frequency distribution of the resulting byte values (not characters).

    Args:
        python_code: Source text to analyse.

    Returns:
        A dict with the single key ``'total_byte_entropy'``. The value is
        always a non-negative float; it is ``0.0`` for empty input and for
        input made of one repeated byte value.
    """
    # Local import so the function does not depend on a file-level import.
    from collections import Counter

    byte_counts = Counter(python_code.encode('utf-8'))
    total_bytes = sum(byte_counts.values())

    # `0.0 - x` equals `-x` for every non-zero x, but yields +0.0 (a float)
    # where plain negation would give the int 0 (empty input, empty sum) or
    # -0.0 (exactly one distinct byte value).
    entropy = 0.0 - sum(
        (count / total_bytes) * math.log2(count / total_bytes)
        for count in byte_counts.values()
    )

    return {'total_byte_entropy': entropy}
|
||||||
|
|
||||||
|
|
||||||
|
def average_nulls_stats(tree, num_lines):
    """Count ``null_literal`` nodes in a parsed tree and summarise them.

    Args:
        tree: Parsed tree exposing ``root_node``; every node must provide
            ``type``, ``children`` and ``start_point`` (row first).
        num_lines: Number of source lines, used as the averaging denominator.

    Returns:
        A dict with:
            ``'avg_nulls'``   - nulls per line (``0`` when ``num_lines <= 0``),
            ``'total_nulls'`` - number of matching nodes,
            ``'max_nulls'``   - largest count found on any single line,
            ``'has_nulls'``   - ``1`` if at least one was found, else ``0``.
    """
    # NOTE(review): confirm the grammar in use emits `null_literal`; the
    # tree-sitter Python grammar may name this node differently.
    total_nulls = 0
    nulls_per_line = {}  # row index -> number of null literals on that row

    # Explicit stack instead of recursion: deeply nested trees would
    # otherwise exceed Python's recursion limit.
    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if node.type == 'null_literal':
            total_nulls += 1
            line_number = node.start_point[0]
            nulls_per_line[line_number] = nulls_per_line.get(line_number, 0) + 1
        pending.extend(node.children)

    avg_nulls = total_nulls / num_lines if num_lines > 0 else 0
    max_nulls_on_any_line = max(nulls_per_line.values()) if nulls_per_line else 0

    return {
        'avg_nulls': avg_nulls,
        'total_nulls': total_nulls,
        'max_nulls': max_nulls_on_any_line,
        'has_nulls': 1 if total_nulls > 0 else 0,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def arithmetic_operations_stats(tree, num_lines):
    """Count arithmetic operators (``+ - * / %``) in a parsed tree.

    Args:
        tree: Parsed tree exposing ``root_node``; every node must provide
            ``type``, ``children`` and ``text`` (bytes).
        num_lines: Number of source lines, used as the averaging denominator.

    Returns:
        A dict with ``'total_arithmetic_operations'`` and
        ``'avg_arithmetic_operations'`` (``0`` when ``num_lines <= 0``
        instead of raising ``ZeroDivisionError``).
    """
    # Per-operator tallies; only the grand total is reported.
    op_counts = {'+': 0, '-': 0, '*': 0, '/': 0, '%': 0}
    total_ops = 0

    # Explicit stack instead of recursion: deeply nested trees would
    # otherwise exceed Python's recursion limit.
    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if node.type == 'binary_expression' or node.type == 'update_expression':
            # NOTE(review): operands of a matched expression are not visited,
            # so arithmetic nested inside them is not counted. Kept as-is to
            # preserve the existing metric; confirm this is intended.
            for child in node.children:
                if child.type == 'operator':
                    op = child.text.decode('utf8')
                    if op in op_counts:
                        op_counts[op] += 1
                        total_ops += 1
        else:
            pending.extend(node.children)

    return {
        'total_arithmetic_operations': total_ops,
        # Same zero-line guard that average_nulls_stats applies.
        'avg_arithmetic_operations': total_ops / num_lines if num_lines > 0 else 0,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def numbers_floats_stats(tree, num_lines):
    """Count numeric literal nodes and how many of them look like floats.

    A node counts as a number when its type is ``integer_literal`` or
    ``decimal_literal``. It additionally counts as a float when its text
    contains ``.`` or the letter ``e`` (case-insensitive).

    Args:
        tree: Parsed tree exposing ``root_node``; every node must provide
            ``type``, ``children`` and ``text`` (bytes).
        num_lines: Unused; kept so all ``*_stats`` helpers share a signature.

    Returns:
        A dict with ``'total_numbers'`` and ``'total_floats'``.
    """
    total_numbers = 0
    total_floats = 0

    # Explicit stack instead of recursion: deeply nested trees would
    # otherwise exceed Python's recursion limit.
    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if node.type in ['integer_literal', 'decimal_literal']:
            total_numbers += 1
            literal = node.text.decode('utf8')  # decode once, test twice
            # NOTE(review): the 'e' test also matches hex digits such as 0xE.
            if '.' in literal or 'e' in literal.lower():
                total_floats += 1
        pending.extend(node.children)

    return {'total_numbers': total_numbers, 'total_floats': total_floats}
|
||||||
|
|
||||||
|
|
||||||
|
def code_stats(python_code):
    """Compute simple line-length statistics for a piece of source text.

    Leading and trailing whitespace is stripped from the whole text first,
    so empty input is treated as one empty line.

    Args:
        python_code: Source text to analyse.

    Returns:
        A dict with ``'total_line_length'`` (sum of line lengths, newlines
        excluded), ``'max_line_length'`` and ``'avg_characters'`` (mean
        characters per line).
    """
    # Normalise Windows line endings so a trailing '\r' is not counted as
    # part of each line's length.
    lines = python_code.strip().replace('\r\n', '\n').split('\n')
    line_lengths = [len(line) for line in lines]
    total_line_length = sum(line_lengths)

    return {
        'total_line_length': total_line_length,
        'max_line_length': max(line_lengths),
        'avg_characters': total_line_length / len(lines),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def assertions_stats(tree, num_lines):
    """Count ``assert_statement`` nodes in a parsed tree.

    Args:
        tree: Parsed tree exposing ``root_node``; every node must provide
            ``type`` and ``children``.
        num_lines: Unused; kept so all ``*_stats`` helpers share a signature.

    Returns:
        A dict with ``'total_assertions'`` and ``'total_has_assertions'``
        (``1`` if at least one assertion was found, else ``0``).
    """
    total_assertions = 0

    # Explicit stack instead of recursion: deeply nested trees would
    # otherwise exceed Python's recursion limit.
    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if node.type == 'assert_statement':
            total_assertions += 1
        pending.extend(node.children)

    return {
        'total_assertions': total_assertions,
        'total_has_assertions': 1 if total_assertions > 0 else 0,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def class_instances_stats(tree, num_lines):
    """Count ``object_creation_expression`` nodes in a parsed tree.

    Args:
        tree: Parsed tree exposing ``root_node``; every node must provide
            ``type`` and ``children``.
        num_lines: Unused; kept so all ``*_stats`` helpers share a signature.

    Returns:
        A dict with the single key ``'total_class_instances'``.
    """
    # NOTE(review): confirm the grammar in use emits
    # `object_creation_expression`; the tree-sitter Python grammar may not.
    total_class_instances = 0

    # Explicit stack instead of recursion: deeply nested trees would
    # otherwise exceed Python's recursion limit.
    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if node.type == 'object_creation_expression':
            total_class_instances += 1
        pending.extend(node.children)

    return {'total_class_instances': total_class_instances}
|
||||||
|
|
||||||
|
|
||||||
|
def has_execeptions(tree, num_lines):
    """Report whether a parsed tree contains at least one ``try_statement``.

    The function name's spelling is kept unchanged because callers
    reference it.

    Args:
        tree: Parsed tree exposing ``root_node``; every node must provide
            ``type`` and ``children``.
        num_lines: Unused; kept so all ``*_stats`` helpers share a signature.

    Returns:
        ``{'total_has_exceptions': 1}`` if a ``try_statement`` node exists,
        otherwise ``{'total_has_exceptions': 0}``.
    """
    # Explicit stack instead of recursion: deeply nested trees would
    # otherwise exceed Python's recursion limit.
    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if node.type == 'try_statement':
            # Only presence is reported, so stop at the first match.
            return {'total_has_exceptions': 1}
        pending.extend(node.children)

    return {'total_has_exceptions': 0}
|
||||||
|
|
||||||
|
|
||||||
|
def distinct_methods_stats(tree, num_lines):
    """Count distinct method names and relate them to the total node count.

    For every ``method_declaration`` node, the text of its first direct child
    of type ``identifier`` is taken as the method name.

    NOTE(review): ``method_declaration`` looks like a Java-grammar node name,
    while ``compute_readability`` builds a tree-sitter-python parser --
    confirm that this node type can actually occur.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type``,
            ``children`` and (for identifiers) a bytes ``text``.
        num_lines: Unused here; kept for a uniform ``*_stats`` signature.

    Returns:
        dict: ``total_distinct_methods`` and ``total_method_ratio``, the
        latter being distinct / (all nodes - distinct), or 0 when that
        denominator is not positive.
    """
    seen_names = set()
    node_count = 0
    worklist = [tree.root_node]
    while worklist:
        node = worklist.pop()
        node_count += 1
        if node.type == 'method_declaration':
            # Only the first identifier child names the method.
            name_node = next(
                (kid for kid in node.children if kid.type == 'identifier'), None
            )
            if name_node is not None:
                seen_names.add(name_node.text.decode('utf8'))
        worklist.extend(node.children)

    distinct = len(seen_names)
    remainder = node_count - distinct
    ratio = distinct / remainder if remainder > 0 else 0

    return {
        'total_distinct_methods': distinct,
        'total_method_ratio': ratio,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def loops_stats(tree, num_lines):
    """
    Calculate the average number of loops per line.

    Nodes of type ``for_statement``, ``while_statement`` and ``do_statement``
    are counted and the count is divided by ``num_lines``.

    Args:
        tree: Parsed tree exposing ``root_node``; every node exposes ``type``
            and ``children``.
        num_lines: Number of source lines used as the divisor (must be
            non-zero).

    Returns:
        dict: ``avg_loops``.
    """
    loop_kinds = ('for_statement', 'while_statement', 'do_statement')

    def walk(node):
        # Pre-order generator over the subtree rooted at ``node``.
        yield node
        for child in node.children:
            yield from walk(child)

    loop_total = sum(1 for node in walk(tree.root_node) if node.type in loop_kinds)
    return {'avg_loops': loop_total / num_lines}
|
||||||
|
|
||||||
|
|
||||||
|
def branches_stats(tree, num_lines):
    """
    Calculate the average number of branches (conditional statements) per line.

    Every ``if_statement`` or ``switch_statement`` node counts as one branch;
    the count is divided by ``num_lines``.

    Args:
        tree: Parsed tree exposing ``root_node``; every node exposes ``type``
            and ``children``.
        num_lines: Number of source lines used as the divisor (must be
            non-zero).

    Returns:
        dict: ``avg_branches``.
    """
    branch_kinds = ('if_statement', 'switch_statement')
    branch_total = 0
    todo = [tree.root_node]
    while todo:
        node = todo.pop()
        if node.type in branch_kinds:
            branch_total += 1
        todo.extend(node.children)

    # Each conditional statement is counted once, regardless of its arms.
    return {'avg_branches': branch_total / num_lines}
|
||||||
|
|
||||||
|
|
||||||
|
def string_stats(tree, num_lines):
    """Compute the total length of string literals divided by the line count.

    Each ``string_literal`` node contributes the length of its decoded text
    with the first and last character (the quotation marks) removed.

    NOTE(review): ``string_literal`` looks like a Java-grammar node name, while
    ``compute_readability`` builds a tree-sitter-python parser -- confirm that
    this node type can actually occur.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type``,
            ``children`` and a bytes ``text``.
        num_lines: Number of source lines used as the divisor (must be
            non-zero).

    Returns:
        dict: ``avg_str_length``.
    """

    def literal_bodies(node):
        # Yield the unquoted text of every string literal in the subtree.
        if node.type == 'string_literal':
            yield node.text.decode('utf8')[1:-1]
        for child in node.children:
            yield from literal_bodies(child)

    char_total = sum(len(body) for body in literal_bodies(tree.root_node))
    return {'avg_str_length': char_total / num_lines}
|
||||||
|
|
||||||
|
|
||||||
|
def identifier_stats(tree, num_lines):
    """Gather identifier statistics over every node of a parsed syntax tree.

    Args:
        tree: Parsed tree exposing ``root_node``; nodes expose ``type``,
            ``children`` and (for identifiers) a bytes ``text``.
        num_lines: Number of source lines; divisor for
            ``avg_identifier_length`` (must be non-zero).

    Returns:
        dict with:
            ``total_identifiers``: number of ``identifier`` nodes,
            ``total_identifier_length``: summed length of all occurrences,
            ``max_identifier_length``: longest identifier (0 if none),
            ``avg_identifier_length``: total identifier length / ``num_lines``,
            ``total_unique_identifiers``: number of distinct names,
            ``total_identifier_ratio``: identifiers / all nodes (0 if no nodes),
            ``total_nodes``: number of nodes visited.
    """
    occurrences = {}  # identifier text -> number of occurrences
    visited = 0
    ident_total = 0
    longest = 0

    stack = [tree.root_node]
    while stack:
        node = stack.pop()
        visited += 1
        if node.type == 'identifier':
            name = node.text.decode('utf8')  # Assuming UTF-8 encoding
            ident_total += 1
            occurrences[name] = occurrences.get(name, 0) + 1
            longest = max(longest, len(name))
        stack.extend(node.children)

    length_sum = sum(len(name) * hits for name, hits in occurrences.items())

    return {
        'total_identifiers': ident_total,
        'total_identifier_length': length_sum,
        'max_identifier_length': longest,
        'avg_identifier_length': length_sum / num_lines,
        'total_unique_identifiers': len(occurrences),
        'total_identifier_ratio': ident_total / visited if visited > 0 else 0,
        'total_nodes': visited,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def compute_regression(results):
    """Combine the collected metrics into one score with a linear model.

    The score is the weighted sum of the metrics listed below plus a constant
    intercept of 5.7501.

    Args:
        results (dict): Metric name -> numeric value. Must contain every key
            of the weight table; extra keys are ignored.

    Returns:
        float: The regression score.

    Raises:
        KeyError: If a required metric is missing from ``results``.
    """
    weights = {
        'total_line_length': -0.0001,
        'max_line_length': -0.0021,
        'total_identifiers': 0.0076,
        'total_identifier_length': -0.0004,
        'max_identifier_length': -0.0067,
        'avg_identifier_length': -0.005,
        'avg_arithmetic_operations': 0.0225,
        'avg_branches': 0.9886,
        'avg_loops': 0.1572,
        'total_assertions': 0.0119,
        'total_has_assertions': -0.0147,
        'avg_characters': 0.1242,
        'total_class_instances': -0.043,
        'total_distinct_methods': -0.0127,
        'avg_str_length': 0.0026,
        'total_has_exceptions': 0.1206,
        'total_unique_identifiers': -0.019,
        'max_nulls': -0.0712,
        'total_numbers': -0.0078,
        'avg_nulls': 0.1444,
        'total_identifier_ratio': 0.334,
        'total_method_ratio': 0.0406,
        'total_floats': -0.0174,
        'total_byte_entropy': -0.3917,
    }
    intercept = 5.7501

    # Accumulate term by term, in table order, so the floating-point result
    # stays bit-for-bit stable (the builtin sum() may round differently).
    score = 0
    for feature, coefficient in weights.items():
        score += coefficient * results[feature]

    return score + intercept
|
||||||
|
|
||||||
|
|
||||||
|
def compute_readability(python_code):
    """Compute the regression-based readability score of a source string.

    A tree-sitter parser is set up from the ``tree-sitter-python`` directory
    next to this file (compiled library stored under ``build/``), the code is
    parsed, every metric helper is run, and the merged metrics are passed to
    ``compute_regression``.

    Args:
        python_code (str): Source text to score.

    Returns:
        The value returned by ``compute_regression`` for the collected metrics.
    """
    parser = Parser()
    this_dir = Path(os.path.dirname(os.path.realpath(__file__)))
    # Store the library in the `build` directory
    library_path = this_dir / "build" / "my-languages.so"
    # Include one or more languages
    grammar_dirs = [this_dir / "tree-sitter-python"]
    # NOTE(review): in py-tree-sitter releases that ship Language.build_library,
    # that call appears to return a bool rather than a Language -- confirm that
    # chaining .get_language() works with the pinned tree_sitter version.
    python_language = Language.build_library(library_path, grammar_dirs).get_language(
        'python'
    )
    parser.set_language(python_language)

    # Text-level metrics do not need the syntax tree.
    results = code_stats(python_code)
    num_lines = len(python_code.strip().split('\n'))
    results.update(total_byte_entropy_stats(python_code))

    tree = parser.parse(python_code.encode('utf8'))

    # Tree-level metrics, applied in a fixed order; later helpers would
    # overwrite earlier ones if they ever produced the same key.
    tree_metrics = (
        identifier_stats,
        loops_stats,
        branches_stats,
        distinct_methods_stats,
        has_execeptions,
        class_instances_stats,
        assertions_stats,
        numbers_floats_stats,
        average_nulls_stats,
        arithmetic_operations_stats,
        string_stats,
    )
    for metric in tree_metrics:
        results.update(metric(tree, num_lines))

    return compute_regression(results)
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,633 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from report_utils import (
|
||||||
|
check_coverage,
|
||||||
|
check_mutation,
|
||||||
|
count_methods,
|
||||||
|
get_lines_of_code,
|
||||||
|
)
|
||||||
|
|
||||||
|
from evaluation.benchmarks.testgeneval.compute_readability import compute_readability
|
||||||
|
from evaluation.benchmarks.testgeneval.constants import (
|
||||||
|
COVERAGE_PREFIX,
|
||||||
|
MUTATION_BUFFER,
|
||||||
|
MUTATION_TEMPLATE,
|
||||||
|
MUTATION_TIMEOUT,
|
||||||
|
TESTS_SUFFIX,
|
||||||
|
)
|
||||||
|
from evaluation.benchmarks.testgeneval.metrics import (
|
||||||
|
bleu,
|
||||||
|
code_bleu,
|
||||||
|
edit_sim,
|
||||||
|
exact_match,
|
||||||
|
rouge_l,
|
||||||
|
)
|
||||||
|
from evaluation.benchmarks.testgeneval.pygments_utils import tokenize_code
|
||||||
|
from evaluation.benchmarks.testgeneval.run_infer import get_instance_docker_image
|
||||||
|
from evaluation.benchmarks.testgeneval.test_filter import filter_tests
|
||||||
|
from evaluation.benchmarks.testgeneval.test_spec import (
|
||||||
|
TestGenEvalInstance,
|
||||||
|
TestSpec,
|
||||||
|
make_test_spec,
|
||||||
|
)
|
||||||
|
from evaluation.benchmarks.testgeneval.utils import load_testgeneval_dataset
|
||||||
|
from evaluation.utils.shared import (
|
||||||
|
EvalMetadata,
|
||||||
|
EvalOutput,
|
||||||
|
prepare_dataset,
|
||||||
|
reset_logger_for_multiprocessing,
|
||||||
|
run_evaluation,
|
||||||
|
)
|
||||||
|
from openhands.core.config import AppConfig, SandboxConfig, get_parser
|
||||||
|
from openhands.core.logger import openhands_logger as logger
|
||||||
|
from openhands.core.main import create_runtime
|
||||||
|
from openhands.events.action import CmdRunAction
|
||||||
|
from openhands.events.observation import CmdOutputObservation
|
||||||
|
from openhands.utils.async_utils import call_async_from_sync
|
||||||
|
|
||||||
|
# Prefix used for evaluation docker images; can be overridden through the
# EVAL_DOCKER_IMAGE_PREFIX environment variable.
DOCKER_IMAGE_PREFIX = os.getenv('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/kdjain/')
logger.info(f'Using docker image prefix: {DOCKER_IMAGE_PREFIX}')
|
||||||
|
|
||||||
|
|
||||||
|
def get_config(instance: pd.Series) -> AppConfig:
    """Build the application config used to evaluate one instance.

    The sandbox image is resolved from the instance's ``instance_id_swebench``
    field. The runtime kind, API key and remote-runtime URL are read from the
    ``RUNTIME``, ``ALLHANDS_API_KEY`` and ``SANDBOX_REMOTE_RUNTIME_API_URL``
    environment variables.

    Args:
        instance: Row that provides ``instance_id_swebench``.

    Returns:
        AppConfig: Configuration with no workspace mounted.

    Raises:
        AssertionError: If no container image could be resolved.
    """
    base_container_image = get_instance_docker_image(instance['instance_id_swebench'])
    assert (
        base_container_image
    ), f"Invalid container image for instance {instance['instance_id_swebench']}."
    logger.info(f'Using instance container image: {base_container_image}.')

    runtime_api_url = os.environ.get(
        'SANDBOX_REMOTE_RUNTIME_API_URL', 'http://localhost:8000'
    )
    sandbox_config = SandboxConfig(
        base_container_image=base_container_image,
        use_host_network=False,
        timeout=1800,
        api_key=os.environ.get('ALLHANDS_API_KEY'),
        remote_runtime_api_url=runtime_api_url,
    )
    return AppConfig(
        run_as_openhands=False,
        runtime=os.environ.get('RUNTIME', 'eventstream'),
        sandbox=sandbox_config,
        workspace_base=None,
        workspace_mount_path=None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def compute_lexical_metrics(pred_suite, gold_suite):
    """Compute size, readability and token-overlap metrics for a test suite.

    Args:
        pred_suite (str): Generated (predicted) test-suite source.
        gold_suite (str): Reference test-suite source.

    Returns:
        dict: Lines of code, readability and method counts for both suites,
        plus ``code_bleu``, ``bleu``, ``xmatch``, ``edit_sim`` and the ROUGE-L
        ``rouge_f`` / ``rouge_p`` / ``rouge_r`` scores of the tokenized suites.
    """
    pred_loc = get_lines_of_code(pred_suite)
    gold_loc = get_lines_of_code(gold_suite)
    pred_methods = count_methods(pred_suite)
    gold_methods = count_methods(gold_suite)
    readability_pred = compute_readability(pred_suite)
    readability_gold = compute_readability(gold_suite)

    preds = tokenize_code(pred_suite)
    golds = tokenize_code(gold_suite)

    # ROUGE-L is computed once and its three components reused; previously the
    # same call was repeated for each of 'f', 'p' and 'r'.
    # NOTE(review): assumes rouge_l is deterministic for identical inputs.
    # The (golds, preds) argument order -- the reverse of the other metrics --
    # is preserved from the original code.
    rouge_scores = rouge_l(golds, preds)

    return {
        'pred_loc': pred_loc,
        'gold_loc': gold_loc,
        'pred_readability': readability_pred,
        'gold_readability': readability_gold,
        'pred_methods': pred_methods,
        'gold_methods': gold_methods,
        'code_bleu': code_bleu(preds, golds, 'Python3'),
        'bleu': bleu(preds, golds),
        'xmatch': exact_match(preds, golds),
        'edit_sim': edit_sim(preds, golds),
        'rouge_f': rouge_scores['f'],
        'rouge_p': rouge_scores['p'],
        'rouge_r': rouge_scores['r'],
    }
|
||||||
|
|
||||||
|
|
||||||
|
def run_command(runtime, command, timeout=600):
    """Run a shell command in the runtime and require that it succeeds.

    Args:
        runtime: Runtime object providing ``run_action``.
        command (str): Shell command to execute.
        timeout (int): Hard timeout, in seconds, set on the action.

    Returns:
        The observation returned by the runtime.

    Raises:
        AssertionError: If the command exits with a non-zero code; the message
            names the command and the exit code.
    """
    action = CmdRunAction(command=command)
    action.set_hard_timeout(timeout)
    logger.info(action, extra={'msg_type': 'ACTION'})
    obs = runtime.run_action(action)
    logger.info(obs, extra={'msg_type': 'OBSERVATION'})
    # Include the failing command in the error so that a failure deep inside
    # the evaluation can be traced without digging through the logs.
    assert (
        obs.exit_code == 0
    ), f'Command failed with exit code {obs.exit_code}: {command}'
    return obs
|
||||||
|
|
||||||
|
|
||||||
|
def run_tests(runtime, instance, test_script, log_file='/tmp/test_output.log'):
    """Run the test script in the background and wait for it to finish.

    The script is started with its output redirected to ``log_file``; its PID
    is then polled with ``ps`` every 30 seconds until the process is gone or
    1800 seconds have passed. On timeout the ``test_timeout`` flag of the
    instance report is set. Finally the log file is read back.

    Args:
        runtime: Runtime object providing ``run_action``.
        instance: Instance row; provides ``instance_id`` and the
            ``test_result`` report that receives the timeout flag.
        test_script (str): Path of the script inside the runtime.
        log_file (str): Path inside the runtime that receives the output.

    Returns:
        tuple: ``(exit_code, content, elapsed_time)`` where the first two come
        from reading the log file and ``elapsed_time`` is the number of
        seconds measured at the start of the last polling iteration.

    Raises:
        AssertionError: If starting the script or reading the log does not
            yield a ``CmdOutputObservation``.
    """
    launch = CmdRunAction(command=f'bash {test_script} > {log_file} 2>&1 & echo $!')
    launch.set_hard_timeout(60)
    launch_obs = runtime.run_action(launch)

    assert isinstance(launch_obs, CmdOutputObservation), 'Failed to start test script.'
    # `echo $!` prints the background PID as the last token of the output.
    pid = launch_obs.content.split()[-1].strip()
    logger.info(f'[{instance.instance_id}] Test process started with PID: {pid}')

    started_at = time.time()
    max_wait_seconds = 1800
    while True:
        elapsed_time = time.time() - started_at
        if elapsed_time > max_wait_seconds:
            logger.info(f'[{instance.instance_id}] Test process timed out.')
            instance['test_result']['report']['test_timeout'] = True
            break

        # `ps -p` exits with 1 once the PID no longer exists.
        status_obs = runtime.run_action(
            CmdRunAction(command=f'ps -p {pid} > /dev/null; echo $?')
        )
        process_gone = False
        if isinstance(status_obs, CmdOutputObservation):
            tokens = status_obs.content.split()
            process_gone = len(tokens) > 0 and tokens[-1].strip() == '1'
        if process_gone:
            logger.info(f'[{instance.instance_id}] Test process completed.')
            break
        time.sleep(30)

    fetch = CmdRunAction(command=f'cat {log_file}')
    fetch.set_hard_timeout(300)
    fetch_obs = runtime.run_action(fetch)
    assert isinstance(fetch_obs, CmdOutputObservation), 'Failed to retrieve test output.'
    return fetch_obs.exit_code, fetch_obs.content, elapsed_time
|
||||||
|
|
||||||
|
|
||||||
|
def run_mutation_testing(
    runtime, instance, mutation_script, log_file='/tmp/mutation_output.log'
):
    """Run the mutation-testing script in the background and collect its log.

    The script is started with its output redirected to ``log_file``; its PID
    is polled with ``ps`` every 30 seconds until the process is gone or 4000
    seconds have passed. On timeout the ``mutation_timeout`` flag of the
    instance report is set. Finally the log file is read back.

    Args:
        runtime: Runtime object providing ``run_action``.
        instance: Instance row; provides ``instance_id`` and the
            ``test_result`` report that receives the timeout flag.
        mutation_script (str): Path of the script inside the runtime.
        log_file (str): Path inside the runtime that receives the output.

    Returns:
        tuple: ``(exit_code, content)`` of the command that reads the log.

    Raises:
        AssertionError: If starting the script or reading the log does not
            yield a ``CmdOutputObservation``.
    """
    action = CmdRunAction(command=f'bash {mutation_script} > {log_file} 2>&1 & echo $!')
    action.set_hard_timeout(60)
    obs = runtime.run_action(action)

    # The message previously said "test script" (copied from run_tests).
    assert isinstance(obs, CmdOutputObservation), 'Failed to start mutation script.'
    # `echo $!` prints the background PID as the last token of the output.
    pid = obs.content.split()[-1].strip()
    logger.info(f'[{instance.instance_id}] Mutation process started with PID: {pid}')

    start_time = time.time()
    timeout = 4000
    while True:
        elapsed_time = time.time() - start_time
        if elapsed_time > timeout:
            logger.info(f'[{instance.instance_id}] Mutation process timed out.')
            instance['test_result']['report']['mutation_timeout'] = True
            break

        # `ps -p` exits with 1 once the PID no longer exists.
        check_action = CmdRunAction(command=f'ps -p {pid} > /dev/null; echo $?')
        check_obs = runtime.run_action(check_action)
        if (
            isinstance(check_obs, CmdOutputObservation)
            and len(check_obs.content.split()) > 0
            and check_obs.content.split()[-1].strip() == '1'
        ):
            logger.info(f'[{instance.instance_id}] Mutation process completed.')
            break
        time.sleep(30)

    # A second isinstance assertion on `obs` used to sit here; it re-checked
    # the launch observation that was already validated above, so it was removed.
    mutation_action = CmdRunAction(command=f'cat {log_file}')
    mutation_action.set_hard_timeout(300)
    mutation_obs = runtime.run_action(mutation_action)
    assert isinstance(
        mutation_obs, CmdOutputObservation
    ), 'Failed to retrieve mutation output.'
    return mutation_obs.exit_code, mutation_obs.content
|
||||||
|
|
||||||
|
|
||||||
|
def grade_test_output(
    test_suite: str, instance: pd.Series, test_output: str, test_spec: TestSpec, runtime
):
    """
    Two-pass test grading with short-circuiting:
    1. Split the first-pass output into passing and failing tests.
    2. If nothing passes, or everything passes, grade from the first pass.
    3. Otherwise re-run only the passing tests and grade coverage from that run.

    Args:
        test_suite: Source of the full generated test suite.
        instance: Instance row, forwarded to ``run_tests``.
        test_output: Raw output of the first test run.
        test_spec: Provides ``repo``, ``code_file`` and ``test_file``.
        runtime: Runtime used for the second pass.

    Returns:
        tuple: ``(coverage_success, coverage, unit_test_output,
        coverage_output, test_stats)``.
    """
    coverage_log = ''
    first_pass_log = (
        test_output.split(TESTS_SUFFIX)[0] if TESTS_SUFFIX in test_output else ''
    )

    # Without any unit-test section there is nothing to grade.
    if not first_pass_log:
        no_tests = {
            'total_tests': 0,
            'passing_tests': 0,
            'failing_tests': 0,
            'any_pass': False,
            'all_pass': False,
            'passing_test_names': [],
            'failing_test_names': [],
        }
        return False, 0, '', '', no_tests

    logger.info('Calling filter unit tests')
    filtered_content, passing_tests, failing_tests = filter_tests(
        test_suite, first_pass_log, test_spec.repo
    )

    n_passing = len(passing_tests)
    n_failing = len(failing_tests)
    total_tests = n_passing + n_failing
    test_stats = {
        'total_tests': total_tests,
        'passing_tests': n_passing,
        'failing_tests': n_failing,
        'any_pass': n_passing > 0,
        'all_pass': n_failing == 0 and total_tests > 0,
        'passing_test_names': passing_tests,
        'failing_test_names': failing_tests,
    }

    if not passing_tests:
        return False, 0, first_pass_log, coverage_log, test_stats

    # If all tests pass, evaluate coverage immediately from the first pass.
    if not failing_tests:
        if COVERAGE_PREFIX not in test_output:
            return False, 0, first_pass_log, coverage_log, test_stats
        coverage_log = test_output.split(COVERAGE_PREFIX)[1]
        _, coverage = check_coverage(coverage_log, test_spec.code_file)
        return True, coverage, first_pass_log, coverage_log, test_stats

    # Second pass - run coverage on passing tests only.
    cov_success, coverage = False, 0
    unit_log = first_pass_log
    if filtered_content:
        with tempfile.TemporaryDirectory() as scratch_dir:
            local_suite = os.path.join(scratch_dir, 'test_suite.py')
            with open(local_suite, 'w') as handle:
                handle.write(filtered_content)
            runtime.copy_to(local_suite, '/tmp')

        run_command(runtime, f'cp /tmp/test_suite.py /testbed/{test_spec.test_file}')
        _, second_pass_log, _ = run_tests(runtime, instance, '/tmp/test.sh')

        # Until a coverage section is found, report the raw second-pass output.
        unit_log = second_pass_log
        if COVERAGE_PREFIX in second_pass_log:
            coverage_log = second_pass_log.split(COVERAGE_PREFIX)[1]
            unit_log = second_pass_log.split(TESTS_SUFFIX)[0]
            _, coverage = check_coverage(coverage_log, test_spec.code_file)
            cov_success = True

    return cov_success, coverage, unit_log, coverage_log, test_stats
|
||||||
|
|
||||||
|
|
||||||
|
def _copy_text_to_runtime(runtime, directory: str, filename: str, content: str) -> None:
    """Write ``content`` to ``directory/filename`` locally and copy it to ``/tmp`` in the runtime."""
    local_path = os.path.join(directory, filename)
    with open(local_path, 'w') as f:
        f.write(content)
    runtime.copy_to(local_path, '/tmp')


def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
    log_dir: str | None = None,
) -> EvalOutput:
    """
    Evaluate agent performance on a TestGenEval problem instance.

    Note that this signature differs from the expected input to `run_evaluation`. Use
    `functools.partial` to provide optional arguments before passing to the evaluation harness.

    The function reads the module-level `args` namespace (`skip_lexical`,
    `skip_mutation`) that is populated in the `__main__` block.

    Args:
        instance: Row with the generated `test_suite`, the `test_spec`, the original
            `instance` data and a `test_result` mapping that is filled in place.
        metadata: Evaluation metadata; not used by this function.
        reset_logger: Whether to redirect logging to a per-instance log file.
        log_dir (str | None, default=None): Path to directory where log files will be written. Must
            be provided if `reset_logger` is set.

    Returns:
        EvalOutput: The instance id together with the populated `test_result`.

    Raises:
        AssertionError: if the `reset_logger` flag is set without a provided log directory.
        RuntimeError: if running or grading the tests fails; the original
            exception is attached as the cause.
    """
    if reset_logger:
        assert (
            log_dir is not None
        ), "Can't reset logger without a provided log directory."
        os.makedirs(log_dir, exist_ok=True)
        reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir)

    config = get_config(instance)
    # Named `instance_id` rather than `id` to avoid shadowing the builtin.
    instance_id = instance.instance_id
    # Logged once for both logger modes (it used to be logged twice when
    # `reset_logger` was False).
    logger.info(f'Starting evaluation for instance {instance_id}.')

    instance['test_result']['id'] = instance_id
    instance['test_result']['report'] = {
        'test_output': '',
        # 'coverage_output': '',
        # 'mutation_output': '',
        'empty_generation': False,
        'error_eval': False,
        'all_tests_pass': False,
        'tests_pass': False,
        'test_timeout': False,
        'mutation_timeout': False,
        'coverage_success': False,
        'mutation_success': False,
        'coverage': 0,
        'mutation_score': 0,
        'mutation_error_interval': -1,
        'num_mutants': -1,
    }

    # -1 marks "not computed"; count_and_log_fields ignores such values.
    instance['test_result']['lexical'] = {
        'pred_loc': -1,
        'gold_loc': -1,
        'pred_readability': -1,
        'gold_readability': -1,
        'pred_methods': -1,
        'gold_methods': -1,
        'code_bleu': -1,
        'bleu': -1,
        'xmatch': -1,
        'edit_sim': -1,
        'rouge_f': -1,
        'rouge_p': -1,
        'rouge_r': -1,
    }

    if instance['test_suite'] == '' or instance['test_suite'] is None:
        instance['test_result']['report']['empty_generation'] = True
        return EvalOutput(
            instance_id=instance.instance_id, test_result=instance['test_result']
        )

    if not args.skip_lexical:
        lexical_metrics = compute_lexical_metrics(
            instance['test_suite'], instance['instance']['test_src']
        )
        instance['test_result']['lexical'] = lexical_metrics

    test_suite = instance['test_suite']
    test_spec: TestSpec = instance['test_spec']
    runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)

    # Everything after a successful connect runs under try/finally so that the
    # runtime is closed even if uploading the scripts fails.
    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            _copy_text_to_runtime(runtime, temp_dir, 'test_suite.py', test_suite)
            _copy_text_to_runtime(runtime, temp_dir, 'test.sh', test_spec.test_script)
            _copy_text_to_runtime(
                runtime, temp_dir, 'mutation.sh', test_spec.mutation_script
            )

        try:
            run_command(runtime, 'chmod +x /tmp/test.sh /tmp/mutation.sh')
            run_command(runtime, f'cp /tmp/test_suite.py /testbed/{test_spec.test_file}')

            # First pass - run all tests
            _, test_output, test_time = run_tests(runtime, instance, '/tmp/test.sh')

            # Grade tests with two-pass approach
            coverage_success, coverage, unit_test_output, coverage_output, test_stats = (
                grade_test_output(test_suite, instance, test_output, test_spec, runtime)
            )

            # Update report with test statistics
            instance['test_result']['report'].update(
                {
                    'test_output': unit_test_output,
                    # 'coverage_output': coverage_output,
                    'tests_pass': test_stats['any_pass'],
                    'all_tests_pass': test_stats['all_pass'],
                    'coverage_success': coverage_success,
                    'coverage': coverage if coverage_success else 0,
                    'test_stats': test_stats,
                }
            )

            # Only run mutation testing if we have passing tests and coverage
            if (
                not args.skip_mutation
                and coverage_success
                and test_stats['any_pass']
                and coverage > 0
            ):
                # Per-mutant timeout: 1.5x the measured test time, at least 10.
                mutation_timeout = max(10, 1.5 * test_time)
                mutation_toml = MUTATION_TEMPLATE.format(
                    test_cmd=test_spec.test_cmd,
                    source_fp=test_spec.code_file,
                    timeout=mutation_timeout,
                )

                with tempfile.TemporaryDirectory() as temp_dir:
                    _copy_text_to_runtime(
                        runtime, temp_dir, 'mutation.toml', mutation_toml
                    )

                run_command(runtime, 'cp /tmp/mutation.toml /testbed/mutation.toml')

                mutation_code, mutation_output = run_mutation_testing(
                    runtime, instance, '/tmp/mutation.sh'
                )
                # instance['test_result']['report']['mutation_output'] = mutation_output
                if mutation_output and mutation_code == 0:
                    (
                        mutation_success,
                        num_mutants,
                        mutation_score,
                        mutation_confidence_interval,
                    ) = check_mutation(mutation_output)
                    report = instance['test_result']['report']
                    report['num_mutants'] = num_mutants
                    report['mutation_success'] = mutation_success
                    report['mutation_score'] = mutation_score
                    report['mutation_error_interval'] = mutation_confidence_interval

            return EvalOutput(
                instance_id=instance.instance_id, test_result=instance['test_result']
            )
        except Exception as e:
            logger.error(f'Error processing instance {instance.instance_id}: {e}')
            # Chain the original exception so the root cause stays visible.
            raise RuntimeError(
                instance.instance_id,
                'Unexpected output...',
                logger,
            ) from e
    finally:
        runtime.close()
|
||||||
|
|
||||||
|
|
||||||
|
def count_and_log_fields(evaluated_predictions, fields, key):
    """
    Log summary statistics for the given fields of the evaluated predictions.

    For every field, values equal to -1 (and missing values) are ignored. The
    number of remaining entries and their mean are logged; if nothing remains,
    a line stating that all values are -1 is logged instead. Nothing is
    returned.

    :param evaluated_predictions: DataFrame containing evaluation results
    :param fields: List of field names to summarise
    :param key: Key to access the field values ('report' or 'lexical')
    """
    for field in fields:
        # -1 entries become None so that dropna() removes them.
        per_row = evaluated_predictions.apply(
            lambda row: None
            if row['test_result'][key][field] == -1
            else row['test_result'][key][field],
            axis=1,
        )
        present = per_row.dropna()

        if present.empty:  # If all values are -1
            logger.info(f'# {field}: -1 (All values are -1)')
            continue

        total = present.sum()  # Sum of valid values
        n_valid = len(present)  # Count of valid entries
        logger.info(f'# {field}: {n_valid}. ({total / n_valid:.2f})')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Command-line entry point: load predictions, evaluate every instance and
    # log aggregate statistics of the results.
    parser = get_parser()
    parser.add_argument(
        '--input-file', type=str, required=True, help='Path to input predictions file'
    )
    parser.add_argument(
        '--dataset',
        type=str,
        default='kjain14/testgeneval',
        help='Dataset to evaluate on',
    )
    parser.add_argument(
        '--split', type=str, default='test', help='Split to evaluate on'
    )
    parser.add_argument(
        '--skip_mutation', action='store_true', help='Skip mutation testing'
    )
    parser.add_argument(
        '--skip_lexical', action='store_true', help='Skip lexical metrics'
    )
    parser.add_argument(
        '--mutation_timeout',
        type=int,
        default=MUTATION_TIMEOUT,
        help='Mutation timeout',
    )
    parser.add_argument(
        '--mutation_buffer',
        type=int,
        default=MUTATION_BUFFER,
        help='Mutation buffer',
    )
    # `args` is read as a module-level global by process_instance.
    args, _ = parser.parse_known_args()

    dataset: list[TestGenEvalInstance] = load_testgeneval_dataset(
        args.dataset, args.split
    )

    logger.info(
        f'Loaded dataset {args.dataset} with split {args.split} to run inference on.'
    )

    # Load predictions
    assert args.input_file.endswith('.jsonl'), 'Input file must be a jsonl file.'
    predictions = pd.read_json(args.input_file, lines=True)
    assert (
        'instance_id' in predictions.columns
    ), 'Input file must contain instance_id column.'

    def _first_row_has(column: str, field: str) -> bool:
        """Return True if the first row of `column` is a dict containing `field`."""
        if column not in predictions.columns or predictions.empty:
            return False
        # `.iloc[0]` selects the first element; the previous `.iloc(0)` called
        # the indexer instead and never looked inside the row.
        first_row = predictions[column].iloc[0]
        return isinstance(first_row, dict) and field in first_row

    # Fail only when the value is available neither as a column nor nested in
    # the fallback column (the old condition was inverted).
    if 'test_suite' not in predictions.columns and not _first_row_has(
        'test_result', 'test_suite'
    ):
        raise ValueError(
            'Input file must contain test_suite column OR test_result column with test_suite field.'
        )

    if 'instance_id_swebench' not in predictions.columns:
        predictions['instance_id_swebench'] = predictions['instance'].apply(
            lambda x: x['instance_id_swebench']
        )

    # Not reachable while the assertion above requires an `instance_id`
    # column; kept, with the condition corrected, in case that is relaxed.
    if 'instance_id' not in predictions.columns and not _first_row_has(
        'instance', 'instance_id'
    ):
        raise ValueError(
            'Input file must contain id column OR instance column with id field.'
        )

    if 'instance_id' not in predictions.columns:
        predictions['instance_id'] = predictions['instance'].apply(
            lambda x: x['instance_id']
        )

    if 'test_suite' not in predictions.columns:
        predictions['test_suite'] = predictions['test_result'].apply(
            lambda x: x['test_suite']
        )

    assert len(predictions['instance_id'].unique()) == len(
        predictions
    ), 'instance_id column must be unique.'

    assert {'instance_id_swebench', 'test_suite', 'instance_id'}.issubset(
        set(predictions.columns)
    ), 'Input file must contain id, instance_id and test_suite columns.'

    predictions['test_spec'] = predictions['instance'].apply(
        lambda x: make_test_spec(x, args.mutation_timeout, args.mutation_buffer)
    )

    output_file = args.input_file.replace('.jsonl', '.testgeneval.jsonl')
    instances = prepare_dataset(predictions, output_file, args.eval_n_limit)

    # If possible, load the relevant metadata to avoid issues with `run_evaluation`.
    metadata: EvalMetadata | None = None
    metadata_filepath = os.path.join(os.path.dirname(args.input_file), 'metadata.json')
    if os.path.exists(metadata_filepath):
        with open(metadata_filepath, 'r') as metadata_file:
            data = metadata_file.read()
            metadata = EvalMetadata.model_validate_json(data)

    # The evaluation harness constrains the signature of `process_instance_func` but we need to
    # pass extra information. Build a new function object to avoid issues with multiprocessing.
    process_instance_func = partial(
        process_instance, log_dir=output_file.replace('.jsonl', '.logs')
    )

    # NOTE(review): `metadata` is loaded above but `None` is passed here, so
    # the loaded value is unused -- confirm whether `metadata=metadata` was
    # intended. Left unchanged to preserve current behavior.
    run_evaluation(
        instances,
        metadata=None,
        output_file=output_file,
        num_workers=args.eval_num_workers,
        process_instance_func=process_instance_func,
    )

    # Load evaluated predictions & log aggregate statistics
    evaluated_predictions = pd.read_json(output_file, lines=True)
    report_fields = [
        'coverage',
        'mutation_score',
        'tests_pass',
        'all_tests_pass',
        'empty_generation',
        'coverage_success',
        'test_timeout',
        'error_eval',
    ]
    lexical_fields = [
        'pred_loc',
        'gold_loc',
        'pred_methods',
        'gold_methods',
        'code_bleu',
        'bleu',
        'xmatch',
        'edit_sim',
        'rouge_f',
        'rouge_p',
        'rouge_r',
    ]

    # Log report and lexical fields
    count_and_log_fields(evaluated_predictions, report_fields, key='report')
    count_and_log_fields(evaluated_predictions, lexical_fields, key='lexical')
|
||||||
@@ -0,0 +1,291 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
from evaluation.benchmarks.testgeneval.constants import TestStatus
|
||||||
|
|
||||||
|
|
||||||
|
def parse_log_pytest(log: str) -> dict[str, str]:
    """
    Build a test-name -> status mapping from a PyTest log.

    Only lines that begin with one of the ``TestStatus`` values are used. The
    first whitespace-separated token of such a line is the status, the second
    is the test name; lines with a single token are ignored.

    Args:
        log (str): log content
    Returns:
        dict: test case to test status mapping
    """
    status_prefixes = tuple(status.value for status in TestStatus)
    failed_prefix = TestStatus.FAILED.value

    results = {}
    for raw_line in log.split('\n'):
        if not raw_line.startswith(status_prefixes):
            continue
        # FAILED lines carry a " - <reason>" tail; flatten the separator so the
        # test name stays the second token.
        if raw_line.startswith(failed_prefix):
            raw_line = raw_line.replace(' - ', ' ')
        tokens = raw_line.split()
        if len(tokens) > 1:
            results[tokens[1]] = tokens[0]
    return results
|
||||||
|
|
||||||
|
|
||||||
|
def parse_log_pytest_options(log: str) -> dict[str, str]:
    """
    Build a test-name -> status mapping from a PyTest log whose test ids may
    carry a bracketed option (``name[option]``).

    When the option looks like a single absolute path (starts with one ``/``,
    not ``//``, and contains no ``*``) it is shortened to ``/<last component>``
    before the name is stored.

    Args:
        log (str): log content
    Returns:
        dict: test case to test status mapping
    """
    bracket_re = re.compile(r'(.*?)\[(.*)\]')
    status_prefixes = tuple(status.value for status in TestStatus)

    results = {}
    for raw_line in log.split('\n'):
        if not raw_line.startswith(status_prefixes):
            continue
        # FAILED lines carry a " - <reason>" tail; flatten the separator.
        if raw_line.startswith(TestStatus.FAILED.value):
            raw_line = raw_line.replace(' - ', ' ')
        tokens = raw_line.split()
        if len(tokens) <= 1:
            continue

        status, name = tokens[0], tokens[1]
        found = bracket_re.search(name)
        if found:
            prefix, param = found.groups()
            is_plain_abs_path = (
                param.startswith('/')
                and not param.startswith('//')
                and '*' not in param
            )
            if is_plain_abs_path:
                param = '/' + param.split('/')[-1]
            name = f'{prefix}[{param}]'
        results[name] = status
    return results
|
||||||
|
|
||||||
|
|
||||||
|
def parse_log_django(log: str) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Parser for test logs generated with Django tester framework
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log (str): log content
|
||||||
|
Returns:
|
||||||
|
dict: test case to test status mapping
|
||||||
|
"""
|
||||||
|
test_status_map = {}
|
||||||
|
lines = log.split('\n')
|
||||||
|
|
||||||
|
prev_test = None
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
|
||||||
|
# This isn't ideal but the test output spans multiple lines
|
||||||
|
if '--version is equivalent to version' in line:
|
||||||
|
test_status_map['--version is equivalent to version'] = (
|
||||||
|
TestStatus.PASSED.value
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log it in case of error
|
||||||
|
if ' ... ' in line:
|
||||||
|
prev_test = line.split(' ... ')[0]
|
||||||
|
|
||||||
|
pass_suffixes = (' ... ok', ' ... OK', ' ... OK')
|
||||||
|
for suffix in pass_suffixes:
|
||||||
|
if line.endswith(suffix):
|
||||||
|
# TODO: Temporary, exclusive fix for django__django-7188
|
||||||
|
# The proper fix should involve somehow getting the test results to
|
||||||
|
# print on a separate line, rather than the same line
|
||||||
|
if line.strip().startswith(
|
||||||
|
'Applying sites.0002_alter_domain_unique...test_no_migrations'
|
||||||
|
):
|
||||||
|
line = line.split('...', 1)[-1].strip()
|
||||||
|
test = line.rsplit(suffix, 1)[0]
|
||||||
|
test_status_map[test] = TestStatus.PASSED.value
|
||||||
|
break
|
||||||
|
if ' ... skipped' in line:
|
||||||
|
test = line.split(' ... skipped')[0]
|
||||||
|
test_status_map[test] = TestStatus.SKIPPED.value
|
||||||
|
if line.endswith(' ... FAIL'):
|
||||||
|
test = line.split(' ... FAIL')[0]
|
||||||
|
test_status_map[test] = TestStatus.FAILED.value
|
||||||
|
if line.startswith('FAIL:'):
|
||||||
|
test = line.split()[1].strip()
|
||||||
|
test_status_map[test] = TestStatus.FAILED.value
|
||||||
|
if line.endswith(' ... ERROR'):
|
||||||
|
test = line.split(' ... ERROR')[0]
|
||||||
|
test_status_map[test] = TestStatus.ERROR.value
|
||||||
|
if line.startswith('ERROR:'):
|
||||||
|
test = line.split()[1].strip()
|
||||||
|
test_status_map[test] = TestStatus.ERROR.value
|
||||||
|
|
||||||
|
if line.lstrip().startswith('ok') and prev_test is not None:
|
||||||
|
# It means the test passed, but there's some additional output (including new lines)
|
||||||
|
# between "..." and "ok" message
|
||||||
|
test = prev_test
|
||||||
|
test_status_map[test] = TestStatus.PASSED.value
|
||||||
|
|
||||||
|
# TODO: This is very brittle, we should do better
|
||||||
|
# There's a bug in the django logger, such that sometimes a test output near the end gets
|
||||||
|
# interrupted by a particular long multiline print statement.
|
||||||
|
# We have observed this in one of 3 forms:
|
||||||
|
# - "{test_name} ... Testing against Django installed in {*} silenced.\nok"
|
||||||
|
# - "{test_name} ... Internal Server Error: \/(.*)\/\nok"
|
||||||
|
# - "{test_name} ... System check identified no issues (0 silenced).\nok"
|
||||||
|
patterns = [
|
||||||
|
r'^(.*?)\s\.\.\.\sTesting\ against\ Django\ installed\ in\ ((?s:.*?))\ silenced\)\.\nok$',
|
||||||
|
r'^(.*?)\s\.\.\.\sInternal\ Server\ Error:\ \/(.*)\/\nok$',
|
||||||
|
r'^(.*?)\s\.\.\.\sSystem check identified no issues \(0 silenced\)\nok$',
|
||||||
|
]
|
||||||
|
for pattern in patterns:
|
||||||
|
for match in re.finditer(pattern, log, re.MULTILINE):
|
||||||
|
test_name = match.group(1)
|
||||||
|
test_status_map[test_name] = TestStatus.PASSED.value
|
||||||
|
return test_status_map
|
||||||
|
|
||||||
|
|
||||||
|
def parse_log_pytest_v2(log: str) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Parser for test logs generated with PyTest framework (Later Version)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log (str): log content
|
||||||
|
Returns:
|
||||||
|
dict: test case to test status mapping
|
||||||
|
"""
|
||||||
|
test_status_map = {}
|
||||||
|
escapes = ''.join([chr(char) for char in range(1, 32)])
|
||||||
|
for line in log.split('\n'):
|
||||||
|
line = re.sub(r'\[(\d+)m', '', line)
|
||||||
|
translator = str.maketrans('', '', escapes)
|
||||||
|
line = line.translate(translator)
|
||||||
|
if any([line.startswith(x.value) for x in TestStatus]):
|
||||||
|
if line.startswith(TestStatus.FAILED.value):
|
||||||
|
line = line.replace(' - ', ' ')
|
||||||
|
test_case = line.split()
|
||||||
|
if len(test_case) >= 2:
|
||||||
|
test_status_map[test_case[1]] = test_case[0]
|
||||||
|
# Support older pytest versions by checking if the line ends with the test status
|
||||||
|
elif any([line.endswith(x.value) for x in TestStatus]):
|
||||||
|
test_case = line.split()
|
||||||
|
if len(test_case) >= 2:
|
||||||
|
test_status_map[test_case[0]] = test_case[1]
|
||||||
|
return test_status_map
|
||||||
|
|
||||||
|
|
||||||
|
def parse_log_seaborn(log: str) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Parser for test logs generated with seaborn testing framework
|
||||||
|
|
||||||
|
Args:
|
||||||
|
log (str): log content
|
||||||
|
Returns:
|
||||||
|
dict: test case to test status mapping
|
||||||
|
"""
|
||||||
|
test_status_map = {}
|
||||||
|
for line in log.split('\n'):
|
||||||
|
if line.startswith(TestStatus.FAILED.value):
|
||||||
|
test_case = line.split()[1]
|
||||||
|
test_status_map[test_case] = TestStatus.FAILED.value
|
||||||
|
elif f' {TestStatus.PASSED.value} ' in line:
|
||||||
|
parts = line.split()
|
||||||
|
if parts[1] == TestStatus.PASSED.value:
|
||||||
|
test_case = parts[0]
|
||||||
|
test_status_map[test_case] = TestStatus.PASSED.value
|
||||||
|
elif line.startswith(TestStatus.PASSED.value):
|
||||||
|
parts = line.split()
|
||||||
|
test_case = parts[1]
|
||||||
|
test_status_map[test_case] = TestStatus.PASSED.value
|
||||||
|
return test_status_map
|
||||||
|
|
||||||
|
|
||||||
|
def parse_log_sympy(log: str) -> dict[str, str]:
    """
    Build a test-name -> status mapping from a Sympy test log.

    Underscore-framed ``<path>.py:<name>`` headers anywhere in the log are
    recorded as FAILED. Then every stripped line starting with ``test_`` is
    classified by its trailing marker (`` E``, `` F`` or `` ok``) after an
    optional ``[FAIL]`` / ``[OK]`` tag has been cut off.

    Args:
        log (str): log content
    Returns:
        dict: test case to test status mapping
    """
    results = {}

    for _, path, name, _ in re.findall(r'(_*) (.*)\.py:(.*) (_*)', log):
        results[f'{path}.py:{name}'] = TestStatus.FAILED.value

    suffix_to_status = (
        (' E', TestStatus.ERROR.value),
        (' F', TestStatus.FAILED.value),
        (' ok', TestStatus.PASSED.value),
    )
    for raw_line in log.split('\n'):
        entry = raw_line.strip()
        if not entry.startswith('test_'):
            continue
        if entry.endswith(('[FAIL]', '[OK]')):
            entry = entry[: entry.rfind('[')]
        entry = entry.strip()
        for suffix, status in suffix_to_status:
            if entry.endswith(suffix):
                results[entry.split()[0]] = status
    return results
|
||||||
|
|
||||||
|
|
||||||
|
def parse_log_matplotlib(log: str) -> dict[str, str]:
    """
    Build a test-name -> status mapping from a matplotlib PyTest log.

    Works like the plain PyTest parser, except that ``MouseButton.LEFT`` and
    ``MouseButton.RIGHT`` are rewritten to ``1`` and ``3`` on every line before
    the line is inspected.

    Args:
        log (str): log content
    Returns:
        dict: test case to test status mapping
    """
    status_prefixes = tuple(status.value for status in TestStatus)

    results = {}
    for raw_line in log.split('\n'):
        normalized = raw_line.replace('MouseButton.LEFT', '1').replace(
            'MouseButton.RIGHT', '3'
        )
        if not normalized.startswith(status_prefixes):
            continue
        # FAILED lines carry a " - <reason>" tail; flatten the separator.
        if normalized.startswith(TestStatus.FAILED.value):
            normalized = normalized.replace(' - ', ' ')
        tokens = normalized.split()
        if len(tokens) > 1:
            results[tokens[1]] = tokens[0]
    return results
|
||||||
|
|
||||||
|
|
||||||
|
# Per-repository parser aliases, grouped by the underlying log format.

# Plain PyTest summary lines.
parse_log_astroid = parse_log_flask = parse_log_marshmallow = parse_log_pytest
parse_log_pvlib = parse_log_pyvista = parse_log_sqlfluff = parse_log_pytest
parse_log_xarray = parse_log_pytest

# PyTest summary lines whose test ids carry a bracketed option.
parse_log_pydicom = parse_log_requests = parse_log_pylint = parse_log_pytest_options

# PyTest output handled by the "v2" parser.
parse_log_astropy = parse_log_scikit = parse_log_sphinx = parse_log_pytest_v2


# Maps a "<owner>/<repo>" identifier to the function that parses its test log.
MAP_REPO_TO_PARSER = {
    'astropy/astropy': parse_log_astropy,
    'django/django': parse_log_django,
    'marshmallow-code/marshmallow': parse_log_marshmallow,
    'matplotlib/matplotlib': parse_log_matplotlib,
    'mwaskom/seaborn': parse_log_seaborn,
    'pallets/flask': parse_log_flask,
    'psf/requests': parse_log_requests,
    'pvlib/pvlib-python': parse_log_pvlib,
    'pydata/xarray': parse_log_xarray,
    'pydicom/pydicom': parse_log_pydicom,
    'pylint-dev/astroid': parse_log_astroid,
    'pylint-dev/pylint': parse_log_pylint,
    'pytest-dev/pytest': parse_log_pytest,
    'pyvista/pyvista': parse_log_pyvista,
    'scikit-learn/scikit-learn': parse_log_scikit,
    'sqlfluff/sqlfluff': parse_log_sqlfluff,
    'sphinx-doc/sphinx': parse_log_sphinx,
    'sympy/sympy': parse_log_sympy,
}
|
||||||
@@ -0,0 +1,311 @@
|
|||||||
|
import sys
|
||||||
|
from typing import Callable, Dict, List, Optional, Sequence, TypeVar, Union
|
||||||
|
|
||||||
|
import nltk
|
||||||
|
import numpy as np
|
||||||
|
from fuzzywuzzy import fuzz
|
||||||
|
from rouge import Rouge
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ROUGE on long sentences needs a deeper call stack than the default, so make
# sure the interpreter limit is at least this high (an already-higher limit is
# left untouched).
_MIN_RECURSION_LIMIT = 10_000
if sys.getrecursionlimit() < _MIN_RECURSION_LIMIT:
    sys.setrecursionlimit(_MIN_RECURSION_LIMIT)
|
||||||
|
|
||||||
|
def bleu(gold: List[str], pred: List[str]) -> float:
    """
    Sentence-level BLEU of ``pred`` against the single reference ``gold``,
    scaled to 0~100 (smoothing method 2, auto reweighting).

    :param gold: list of gold tokens
    :param pred: list of predicted tokens
    :return: BLEU score; 0.0 when either token list is empty
    """
    if not len(pred) or not len(gold):
        return 0.0
    bleu_score = nltk.translate.bleu_score
    raw_score = bleu_score.sentence_bleu(
        [gold],
        pred,
        smoothing_function=bleu_score.SmoothingFunction().method2,
        auto_reweigh=True,
    )
    return 100.0 * raw_score
|
||||||
|
|
||||||
|
|
||||||
|
def batch_bleu(golds: List[List[str]], preds: List[List[str]]) -> List[float]:
    """
    Score every (gold, pred) pair with ``bleu``.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :return: list of BLEU scores, one per pair
    :raises ValueError: if the two lists differ in length
    """
    if len(golds) != len(preds):
        raise ValueError("golds and preds must have the same length")
    scores = []
    for gold, pred in zip(golds, preds):
        scores.append(bleu(gold, pred))
    return scores
|
||||||
|
|
||||||
|
|
||||||
|
def corpus_bleu(golds: List[List[str]], preds: List[List[str]]) -> float:
    """
    Corpus-level BLEU over a batch, scaled by 100 (smoothing method 2, auto
    reweighting). Each prediction is compared against exactly one reference.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :return: corpus-level BLEU score
    :raises ValueError: if the two lists differ in length
    """
    if len(golds) != len(preds):
        raise ValueError("golds and preds must have the same length")
    bleu_score = nltk.translate.bleu_score
    # nltk expects a list of reference lists per hypothesis.
    references = [[gold] for gold in golds]
    raw_score = bleu_score.corpus_bleu(
        references,
        preds,
        smoothing_function=bleu_score.SmoothingFunction().method2,
        auto_reweigh=True,
    )
    return 100.0 * raw_score
|
||||||
|
|
||||||
|
|
||||||
|
def edit_sim(
    gold: Union[str, List[str]], pred: Union[str, List[str]], sep: str = " "
) -> float:
    """
    Char-level edit similarity between ``gold`` and ``pred`` in the range
    0~100, computed with ``fuzz.ratio``.

    :param gold: gold sentence or list of gold tokens
    :param pred: predicted sentence or list of predicted tokens
    :param sep: separator used to join token lists into a string
    :return: char-level edit similarity; 0.0 when either input is empty
    """
    if len(pred) == 0 or len(gold) == 0:
        return 0.0
    gold_text = sep.join(gold) if isinstance(gold, list) else gold
    pred_text = sep.join(pred) if isinstance(pred, list) else pred
    return fuzz.ratio(gold_text, pred_text)
|
||||||
|
|
||||||
|
|
||||||
|
def batch_edit_sim(
    golds: List[Union[str, List[str]]],
    preds: List[Union[str, List[str]]],
    sep: str = " ",
) -> List[float]:
    """
    Score every (gold, pred) pair with ``edit_sim``.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :param sep: separator between tokens
    :return: list of char-level edit similarity, one per pair
    :raises ValueError: if the two lists differ in length
    """
    if len(golds) != len(preds):
        raise ValueError("golds and preds must have the same length")
    similarities = []
    for gold, pred in zip(golds, preds):
        similarities.append(edit_sim(gold, pred, sep))
    return similarities
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")


def exact_match(gold: T, pred: T) -> float:
    """
    Exact-match score: 100.0 when ``gold == pred``, otherwise 0.0.

    :param gold: gold sentence or list of gold tokens
    :param pred: predicted sentence or list of predicted tokens
    :return: exact match accuracy; 0.0 when either input is empty
    """
    if len(pred) == 0 or len(gold) == 0:
        return 0.0
    if gold == pred:
        return 100.0
    return 0.0


def batch_exact_match(golds: List[T], preds: List[T]) -> List[float]:
    """
    Score every (gold, pred) pair with ``exact_match``.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :return: list of exact match accuracy, one per pair
    :raises ValueError: if the two lists differ in length
    """
    if len(golds) != len(preds):
        raise ValueError("golds and preds must have the same length")
    results = []
    for gold, pred in zip(golds, preds):
        results.append(exact_match(gold, pred))
    return results
|
||||||
|
|
||||||
|
|
||||||
|
def rouge_l(
    gold: Union[str, List[str]], pred: Union[str, List[str]], sep: str = " "
) -> Dict[str, float]:
    """
    ROUGE-L precision, recall and F1 of ``pred`` against ``gold``, each scaled
    to 0~100.

    :param gold: gold sentence or list of gold tokens
    :param pred: predicted sentence or list of predicted tokens
    :param sep: separator used to join token lists into a string
    :return: {"p": precision, "r": recall, "f": F1}; all zeros when either
        input is empty or when scoring raises ``ValueError``
    """
    components = ["p", "r", "f"]
    if len(pred) == 0 or len(gold) == 0:
        return {name: 0.0 for name in components}

    gold_text = sep.join(gold) if isinstance(gold, list) else gold
    pred_text = sep.join(pred) if isinstance(pred, list) else pred

    try:
        scorer = Rouge()
        averaged = scorer.get_scores(hyps=pred_text, refs=gold_text, avg=True)
        return {name: averaged["rouge-l"][name] * 100.0 for name in components}
    except ValueError:
        return {name: 0.0 for name in components}
|
||||||
|
|
||||||
|
|
||||||
|
def batch_rouge_l(
    golds: List[Union[str, List[str]]],
    preds: List[Union[str, List[str]]],
    sep: str = " ",
) -> Dict[str, List[float]]:
    """
    Score every (gold, pred) pair with ``rouge_l`` and regroup the results by
    component.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :param sep: separator between tokens
    :return: {"p": [...], "r": [...], "f": [...]}, each list aligned with the
        input pairs
    :raises ValueError: if the two lists differ in length
    """
    if len(golds) != len(preds):
        raise ValueError("golds and preds must have the same length")
    per_pair = [rouge_l(gold, pred, sep) for gold, pred in zip(golds, preds)]
    by_component = {}
    for name in ["p", "r", "f"]:
        by_component[name] = [pair_scores[name] for pair_scores in per_pair]
    return by_component
|
||||||
|
|
||||||
|
|
||||||
|
def accuracy(
    gold: List[str],
    pred: List[str],
    ignore: Optional[Sequence[str]] = None,
) -> float:
    """
    Position-wise token accuracy in the range 0~100.

    Tokens are compared index by index over the length of the shorter list.
    Positions whose gold token is listed in ``ignore`` are left out of both
    the numerator and the denominator.

    :param gold: list of gold tokens
    :param pred: list of predicted tokens
    :param ignore: list of (gold) tokens to ignore
    :return: accuracy; 0.0 when either list is empty or nothing is counted
    """
    if len(pred) == 0 or len(gold) == 0:
        return 0.0
    skipped = [] if ignore is None else ignore

    counted = [
        (gold_tok, pred_tok)
        for gold_tok, pred_tok in zip(gold, pred)
        if gold_tok not in skipped
    ]
    if not counted:
        return 0.0
    hits = sum(1 for gold_tok, pred_tok in counted if gold_tok == pred_tok)
    return 100.0 * hits / len(counted)
|
||||||
|
|
||||||
|
|
||||||
|
def batch_accuracy(
    golds: List[List[str]],
    preds: List[List[str]],
    ignore: Optional[Sequence[str]] = None,
) -> List[float]:
    """
    Score every (gold, pred) pair with ``accuracy``.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :param ignore: list of (gold) tokens to ignore
    :return: list of accuracy, one per pair
    :raises ValueError: if the two lists differ in length
    """
    if len(golds) != len(preds):
        raise ValueError("golds and preds must have the same length")
    results = []
    for gold, pred in zip(golds, preds):
        results.append(accuracy(gold, pred, ignore))
    return results
|
||||||
|
|
||||||
|
|
||||||
|
def first_match_to_topk(
    first_match_list: List[int], k_values: List[int]
) -> Dict[int, List[float]]:
    """
    Turn first-match ranks (1-indexed) into per-item top-k hit scores.

    :param first_match_list: first match ranks (1-indexed), one per item
    :param k_values: k values to consider
    :return: a mapping from k to a list with 100.0 for every item whose rank
        is <= k and 0.0 otherwise
    """
    topk = {}
    for k in k_values:
        hits = []
        for rank in first_match_list:
            hits.append(100.0 if rank <= k else 0.0)
        topk[k] = hits
    return topk
|
||||||
|
|
||||||
|
|
||||||
|
def pass_at_k(n: int, c: int, k: int) -> float:
    """
    Sample pass@k metric according to the Codex paper, but in the scale of 0~100.

    When fewer than ``k`` samples are available, or fewer than ``k`` of them
    are incorrect, the product-form estimator is not used and the function
    falls back to ``1 - (1 - c/n)^k``.

    :param n: total number of samples
    :param c: number of correct samples
    :param k: k in pass@$k$
    :return: pass@k in the range 0~100; 0.0 when there are no samples
    """
    if n <= 0:
        # No samples at all: nothing passed, and c / n below would divide by zero.
        return 0.0
    if n < k or (n - c) < k:
        # fallback to the (1 - (1-p)^k) formula
        return (1 - (1 - (c / n)) ** k) * 100
    fail_all_k = np.prod(1.0 - k / np.arange(n - c + 1, n + 1)).item()
    return (1.0 - fail_all_k) * 100
|
||||||
|
|
||||||
|
|
||||||
|
def self_bleu(samples: List[List[str]]) -> float:
    """
    Mean BLEU (scaled by 100) of each sample scored against all the other
    samples as references (smoothing method 2, auto reweighting).

    :param samples: the chosen m samples
    :return: self-BLEU; 100.0 when ``samples`` is empty
    """
    if len(samples) == 0:
        return 100.0

    bleu_score = nltk.translate.bleu_score
    per_sample = []
    for idx, hypothesis in enumerate(samples):
        references = [other for j, other in enumerate(samples) if j != idx]
        raw_score = bleu_score.sentence_bleu(
            references,
            hypothesis,
            smoothing_function=bleu_score.SmoothingFunction().method2,
            auto_reweigh=True,
        )
        per_sample.append(100.0 * raw_score)
    return np.mean(per_sample).item()
|
||||||
|
|
||||||
|
|
||||||
|
def self_edit_distance(samples: List[Union[str, List[str]]], sep=" ") -> float:
    """
    Calculate self-edit-distance among the samples: the mean of
    ``100 - fuzz.ratio(a, b)`` over every ordered pair of distinct samples.

    :param samples: the chosen m samples (strings or token lists)
    :param sep: the separator between tokens
    :return: self-edit-distance; 0.0 when there are fewer than two samples
        (no pair exists, so averaging would otherwise yield NaN)
    """
    # Join token lists once up front instead of once per pair.
    texts = [
        sample if isinstance(sample, str) else sep.join(sample) for sample in samples
    ]
    if len(texts) < 2:
        return 0.0

    scores = [
        100 - fuzz.ratio(text_i, text_j)
        for i, text_i in enumerate(texts)
        for j, text_j in enumerate(texts)
        if i != j
    ]
    return np.mean(scores).item()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _rouge_l_component(component: str) -> Callable[[List[str], List[str]], float]:
    """Return a (gold, pred) metric that yields one ROUGE-L component."""

    def metric(g, p):
        return rouge_l(g, p)[component]

    return metric


# Registry of per-sample quality metrics, keyed by metric name; each entry
# takes (gold, pred) and returns a single score.
QUALITY_METRICS: Dict[str, Callable[[List[str], List[str]], float]] = {
    "bleu": bleu,
    "xmatch": exact_match,
    "edit-sim": edit_sim,
    "rouge-f": _rouge_l_component("f"),
    "rouge-p": _rouge_l_component("p"),
    "rouge-r": _rouge_l_component("r"),
}
|
||||||
@@ -0,0 +1,114 @@
|
|||||||
|
# Earlier version of the test-generation prompt.
# Placeholders: {code_file}, {test_file}, {workspace_dir_name}, {code_src},
# {imports}, {coverage_command} (presumably filled with str.format by the
# caller - confirm against the call site).
CODEACT_TESTGEN_PROMPT_OLD = """Your goal is to generate a high-quality test suite (at least 20+ passing tests) for the code file: {code_file}. Output the test suite at {test_file}

[current directory: /workspace/{workspace_dir_name}]

IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP

IMPORTANT: Follow instructions, if you have < 80 tests you should generate more tests rather than trying to fix the ones you have.

IMPORTANT: Code file to test:
```python
{code_src}
```

Here are additional imports that you may need:
{imports}

Look at code dependencies (NOT {code_file} since you already have contents) and test files you need context for to write a complete test suite.

Aim for 20+ test functions with asserts. Do not hesitate to use the Python interpreter to understand the input output behavior of the code you are testing.

Output your test suite at {test_file}. Each unit test must be a function starting with test_. Include all your test imports and setup before your first test. Do not include a main method to run the tests. Make sure to make it as comprehensive as possible, try to execute all the methods you saw.

When you think you've successfully generated a test suite, run it for the current project using {coverage_command}.

If you have few tests GENERATE MORE TESTS rather than trying to fix the ones you have (it is possible to filter out failing tests later).

Then run coverage report -m --include {code_file} to see how well your test suite covers the code under test.

When you are trying to improve coverage pick a part of the code that is not covered (indicated by lines on coverage report), examine the code and then
try to generate a test for it. Feel free to use a code interpreter to understand the input output behavior. ONLY add tests
not remove them.

If you are unable to see passing and failing tests, FIX YOUR IMPORTS to use the same style as other test files.

You should NOT modify any existing test case files. You SHOULD add new test in a NEW file to reproduce the issue.

You should NEVER use web browsing or any other web-based tools.

You should NEVER install new packages, use existing packages only.

You should ALWAYS use the default Python interpreter available in the <execute_bash> environment to run code related to the provided issue and/or repository.

You should ALWAYS use local imports DO NOT import the general library.

When you think you have a fully adequate test suite, please run the following command: <execute_bash> exit </execute_bash>.
"""
|
||||||
|
|
||||||
|
CODEACT_TESTGEN_PROMPT = """
|
||||||
|
Your goal is to generate a comprehensive, **broad-coverage** test suite for the code below, ensuring you test as many lines and branches as possible on the first attempt.
|
||||||
|
|
||||||
|
Place your test suite in a new file named {test_file}.
|
||||||
|
|
||||||
|
IMPORTANT REQUIREMENTS:
|
||||||
|
1. **No external help or resources**—use only the snippet below.
|
||||||
|
2. **Focus on breadth over depth**: cover all major functions, classes, and code paths early to minimize coverage iterations.
|
||||||
|
3. Each test function must start with `test_` and use `assert` to verify behavior.
|
||||||
|
4. Include only necessary imports (standard library or local).
|
||||||
|
5. Do **not** modify existing test files—create a brand new one. No `main()` or other non-test code.
|
||||||
|
6. Produce **at least 20 test functions**; if coverage is lacking, add more tests rather than removing or changing existing ones.
|
||||||
|
7. Use the following commands to check coverage:
|
||||||
|
<execute_bash> {coverage_command} </execute_bash>
|
||||||
|
<execute_bash> coverage report -m --include {code_file} </execute_bash>
|
||||||
|
If lines remain uncovered, add new tests targeting them specifically.
|
||||||
|
8. When you're satisfied with coverage, finalize by running:
|
||||||
|
<execute_bash> exit </execute_bash>
|
||||||
|
|
||||||
|
Below is the **complete code snippet** to test:
|
||||||
|
|
||||||
|
<START_OF_CODE>
|
||||||
|
{code_src}
|
||||||
|
<END_OF_CODE>
|
||||||
|
|
||||||
|
NOTE: if you are testing django, you must use from django.test import SimpleTestCase and class based tests (i.e. class TestSomething(SimpleTestCase)).
|
||||||
|
NOTE: if there is an error executing tests you MUST fix it before exiting. DO NOT install new packages.
|
||||||
|
NOTE: if outputting a revised test suite REPLACE {test_file} with the revised suite
|
||||||
|
|
||||||
|
**Output the final test suite** (20+ tests) for {test_file} in a single code block, no extra commentary. MAKE SURE you run the tests and ensure you can see which tests passed and failed BEFORE exiting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
CODEACT_TESTGEN_PROMPT_ITERATE = """
|
||||||
|
Your goal is to improve the test suite at {test_file} to achieve **broad-coverage** of the code below.
|
||||||
|
|
||||||
|
First run the test suite.
|
||||||
|
|
||||||
|
If no tests run, then remove {test_file} and create {test_file} with a new suite.
|
||||||
|
|
||||||
|
Otherwise, improve it aiming to improve code coverage.
|
||||||
|
|
||||||
|
IMPORTANT REQUIREMENTS:
|
||||||
|
1. Use the following commands to check coverage (RUN THIS FIRST):
|
||||||
|
<execute_bash> {coverage_command} </execute_bash>
|
||||||
|
<execute_bash> coverage report -m --include {code_file} </execute_bash>
|
||||||
|
If lines remain uncovered, add new tests targeting them specifically.
|
||||||
|
2. **No external help or resources**—use only the snippet below.
|
||||||
|
3. **Focus on breadth over depth**: cover all major functions, classes, and code paths early to minimize coverage iterations.
|
||||||
|
4. Each test function must use `assert` to verify behavior.
|
||||||
|
5. Include only necessary imports (standard library or local).
|
||||||
|
6. Do **not** modify other test files in the repository. No `main()` or other non-test code.
|
||||||
|
7. Produce **at least 20 test functions**; if coverage is lacking, add more tests rather than removing or changing existing ones.
|
||||||
|
8. When you're satisfied with coverage, finalize by running:
|
||||||
|
<execute_bash> exit </execute_bash>
|
||||||
|
|
||||||
|
Below is the **complete code snippet** to test:
|
||||||
|
|
||||||
|
<START_OF_CODE>
|
||||||
|
{code_src}
|
||||||
|
<END_OF_CODE>
|
||||||
|
|
||||||
|
NOTE: if you are testing django, you must use from django.test import SimpleTestCase and class based tests (i.e. class TestSomething(SimpleTestCase)).
|
||||||
|
NOTE: if there is an error executing tests you MUST fix it before exiting. DO NOT install new packages.
|
||||||
|
NOTE: if outputting a revised test suite REPLACE {test_file} with the revised suite
|
||||||
|
|
||||||
|
**Output the final test suite** (20+ tests) for {test_file} in a single code block, no extra commentary. MAKE SURE you run the tests and ensure you can see which tests passed and failed BEFORE exiting.
|
||||||
|
"""
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
import re
|
||||||
|
from pygments.lexers.python import PythonLexer
|
||||||
|
|
||||||
|
def tokenize_code(code):
    """Lex ``code`` with Pygments' ``PythonLexer`` and return the token texts
    after clean-up by ``process_pygments_tokens``."""
    raw_tokens = PythonLexer().get_tokens(code)
    return process_pygments_tokens(raw_tokens)
|
||||||
|
|
||||||
|
def process_pygments_tokens(tokens):
    """
    Reduce a stream of ``(token_type, text)`` pairs to a list of token texts.

    Whitespace is dropped (``Token.Text.Whitespace`` tokens, and ``Token.Text``
    tokens whose text starts with whitespace). Every run of the three texts
    ``"``, ``STR``, ``"`` is then merged into the single token ``"STR"``.
    Each input token is emitted at most once, including when such a run sits
    at or next to the end of the stream.

    Args:
        tokens: iterable of ``(token_type, text)`` pairs; ``token_type`` is
            compared through ``str()``.

    Returns:
        list: the cleaned token texts.
    """
    kept = []
    for token in tokens:
        kind = str(token[0])
        is_whitespace = kind == "Token.Text.Whitespace" or (
            kind == "Token.Text" and re.match(r'\s+', token[1])
        )
        if is_whitespace:
            continue
        kept.append(token[1])

    merged = []
    i = 0
    while i < len(kept):
        # The slice is shorter than three items near the end, so it can only
        # match when a full triple is present.
        if kept[i:i + 3] == ['"', 'STR', '"']:
            merged.append("\"STR\"")
            i += 3
        else:
            merged.append(kept[i])
            i += 1
    return merged
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def check_coverage(coverage_output, code_file):
    """Look up one file's coverage percentage in a JSON coverage report.

    Args:
        coverage_output (str): JSON text with a top-level ``files`` mapping.
        code_file (str): Key to look up inside ``files``.

    Returns:
        tuple: ``(True, percent_covered)`` when ``code_file`` is present in the
        report, otherwise ``(False, 0)``.
    """
    per_file = json.loads(coverage_output)['files']
    if code_file not in per_file:
        return False, 0
    return True, per_file[code_file]['summary']['percent_covered']
|
||||||
|
|
||||||
|
|
||||||
|
def check_mutation(mutation_output):
    """Parse the mutant count and score out of a mutation-testing report.

    The report must contain a ``total jobs: <int>`` line, and its last line must
    consist of exactly three space-separated numbers ``low val high``
    (presumably a rate estimate with its confidence bounds - confirm against
    the tool that produces this output).

    Args:
        mutation_output (str): Raw text of the mutation report.

    Returns:
        tuple: ``(True, num_mutants, mutation_score, confidence_range)`` where
        ``mutation_score`` is ``100 - val`` and ``confidence_range`` is
        ``high - val``; ``(False, -1, 0, -1)`` when the report does not have the
        expected shape.

    Raises:
        ValueError: If the job count or one of the three fields is not numeric.
    """
    failure = (False, -1, 0, -1)
    marker = 'total jobs: '
    if marker not in mutation_output:
        return failure

    num_mutants = int(mutation_output.split(marker)[1].split('\n')[0])

    # Split the *stripped* last line. The field count used to be checked on the
    # stripped text while the unpack used the raw text, so surrounding spaces
    # passed the check and then failed the unpack.
    fields = mutation_output.split('\n')[-1].strip().split(' ')
    if len(fields) != 3:
        return failure

    # The lower bound is parsed only so malformed input is still rejected.
    _low, val, high = (float(field) for field in fields)

    confidence_range = high - val
    mutation_score = 100 - val
    return True, num_mutants, mutation_score, confidence_range
|
||||||
|
|
||||||
|
|
||||||
|
def count_methods(code_str):
    """Count Python ``def`` headers appearing in a string of code.

    A header is the word ``def``, whitespace, an identifier, optional
    whitespace and an opening parenthesis. Matching is purely textual.

    Args:
        code_str (str): A string containing code.

    Returns:
        int: The number of matching function/method headers.
    """
    def_header = re.compile(r'\bdef\b\s+\w+\s*\(')
    return sum(1 for _ in def_header.finditer(code_str))
|
||||||
|
|
||||||
|
|
||||||
|
def get_lines_of_code(code_str):
    """Count the lines in a string of code.

    Leading and trailing whitespace is stripped first; interior blank lines are
    counted.

    Args:
        code_str (str): A string containing code.

    Returns:
        int: One more than the number of newline characters left after
        stripping, so an empty or whitespace-only string yields 1.
    """
    trimmed = code_str.strip()
    return trimmed.count('\n') + 1
|
||||||
@@ -0,0 +1,577 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import toml
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
import openhands.agenthub
|
||||||
|
from evaluation.benchmarks.testgeneval.constants import MAP_REPO_VERSION_TO_SPECS
|
||||||
|
from evaluation.benchmarks.testgeneval.prompt import (
|
||||||
|
CODEACT_TESTGEN_PROMPT,
|
||||||
|
CODEACT_TESTGEN_PROMPT_ITERATE,
|
||||||
|
)
|
||||||
|
from evaluation.benchmarks.testgeneval.utils import get_test_directives
|
||||||
|
from evaluation.utils.shared import (
|
||||||
|
EvalException,
|
||||||
|
EvalMetadata,
|
||||||
|
EvalOutput,
|
||||||
|
assert_and_raise,
|
||||||
|
codeact_user_response,
|
||||||
|
get_metrics,
|
||||||
|
is_fatal_evaluation_error,
|
||||||
|
make_metadata,
|
||||||
|
prepare_dataset,
|
||||||
|
reset_logger_for_multiprocessing,
|
||||||
|
run_evaluation,
|
||||||
|
update_llm_config_for_completions_logging,
|
||||||
|
)
|
||||||
|
from openhands.controller.state.state import State
|
||||||
|
from openhands.core.config import (
|
||||||
|
AgentConfig,
|
||||||
|
AppConfig,
|
||||||
|
SandboxConfig,
|
||||||
|
get_llm_config_arg,
|
||||||
|
get_parser,
|
||||||
|
)
|
||||||
|
from openhands.core.logger import openhands_logger as logger
|
||||||
|
from openhands.core.main import create_runtime, run_controller
|
||||||
|
from openhands.events.action import CmdRunAction, MessageAction
|
||||||
|
from openhands.events.observation import CmdOutputObservation, ErrorObservation
|
||||||
|
from openhands.events.serialization.event import event_to_dict
|
||||||
|
from openhands.runtime.base import Runtime
|
||||||
|
from openhands.utils.async_utils import call_async_from_sync
|
||||||
|
|
||||||
|
# Browsing is switched on only when the RUN_WITH_BROWSING environment variable
# is the string "true" (any letter case); it defaults to off.
RUN_WITH_BROWSING = os.getenv('RUN_WITH_BROWSING', 'false').lower() == 'true'

# Agent class name -> function passed to the controller as the fake user.
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = dict(CodeActAgent=codeact_user_response)
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_instance(d):
|
||||||
|
for key, value in d.items():
|
||||||
|
if isinstance(value, np.ndarray):
|
||||||
|
d[key] = value.tolist()
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def _get_swebench_workspace_dir_name(instance: pd.Series) -> str:
|
||||||
|
return f'{instance.repo}__{instance.version}'.replace('/', '__')
|
||||||
|
|
||||||
|
|
||||||
|
def get_instruction(instance: pd.Series, metadata: EvalMetadata):
    """Render the agent prompt for one instance.

    Args:
        instance: Dataset row; reads ``repo``, ``version``, ``code_file``,
            ``test_file``, ``code_src``, ``local_imports`` and ``full_pred``.
        metadata: Evaluation metadata (not used by this function).

    Returns:
        str: The formatted prompt. The iterate prompt is used when the row
        carries a previous prediction in ``full_pred``, the plain prompt
        otherwise.
    """
    # Coverage command = the repo/version test command followed by the
    # instance's test directives.
    specs = MAP_REPO_VERSION_TO_SPECS[instance['repo']][instance['version']]
    command_parts = [specs['test_cmd'], *get_test_directives(instance)]
    coverage_command = ' '.join(command_parts)

    if instance['full_pred'] is None:
        template = CODEACT_TESTGEN_PROMPT
    else:
        template = CODEACT_TESTGEN_PROMPT_ITERATE

    instruction = template.format(
        code_file=os.path.join('/testbed', instance.code_file),
        test_file=os.path.join('/testbed', instance.test_file),
        coverage_command=coverage_command,
        code_src=instance['code_src'],
        imports='\n'.join(instance.local_imports),
        workspace_dir_name=_get_swebench_workspace_dir_name(instance),
    )

    # NOTE(review): the "never browse" notice is appended when browsing is
    # enabled - confirm that this is the intended condition.
    if RUN_WITH_BROWSING:
        instruction += (
            '<IMPORTANT!>\n'
            'You SHOULD NEVER attempt to browse the web. '
            '</IMPORTANT!>\n'
        )

    return instruction
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: migrate all swe-bench docker to ghcr.io/openhands
# Prefix prepended to every per-instance image name; override it through the
# EVAL_DOCKER_IMAGE_PREFIX environment variable.
DOCKER_IMAGE_PREFIX = os.getenv('EVAL_DOCKER_IMAGE_PREFIX', 'docker.io/kdjain/')
logger.info('Using docker image prefix: ' + DOCKER_IMAGE_PREFIX)
|
||||||
|
|
||||||
|
|
||||||
|
def get_instance_docker_image(instance_id: str) -> str:
    """Return the full image reference for a SWE-bench instance id.

    The name is ``sweb.eval.x86_64.<instance_id>`` with every ``__`` replaced by
    ``_s_`` (to comply with docker image naming convention), joined to
    ``DOCKER_IMAGE_PREFIX`` with exactly one ``/``.
    """
    sanitized = f'sweb.eval.x86_64.{instance_id}'.replace('__', '_s_')
    return '/'.join([DOCKER_IMAGE_PREFIX.rstrip('/'), sanitized])
|
||||||
|
|
||||||
|
|
||||||
|
def get_config(
    instance: pd.Series,
    metadata: EvalMetadata,
) -> AppConfig:
    """Assemble the ``AppConfig`` used to run the agent on one instance.

    Args:
        instance: Dataset row; reads ``instance_id_swebench`` (to pick the
            container image) and ``id`` (for completion logging).
        metadata: Supplies the agent class, iteration limit, LLM config,
            output directory and condenser config.

    Returns:
        AppConfig: Configuration with sandbox, LLM and agent settings applied.
    """
    # We use a different instance image for the each instance of TestGenEval
    base_container_image = get_instance_docker_image(instance['instance_id_swebench'])
    logger.info(
        f'Using instance container image: {base_container_image}. '
        f'Please make sure this image exists. '
        f'Submit an issue on https://github.com/All-Hands-AI/OpenHands if you run into any issues.'
    )

    sandbox_config = SandboxConfig(
        base_container_image=base_container_image,
        enable_auto_lint=True,
        use_host_network=False,
        # large enough timeout, since some testcases take very long to run
        timeout=300,
        # Add platform to the sandbox config to solve issue 4401
        platform='linux/amd64',
        api_key=os.environ.get('ALLHANDS_API_KEY', None),
        remote_runtime_api_url=os.environ.get(
            'SANDBOX_REMOTE_RUNTIME_API_URL', 'http://localhost:8000'
        ),
        keep_runtime_alive=False,
        remote_runtime_init_timeout=3600,
    )

    config = AppConfig(
        default_agent=metadata.agent_class,
        run_as_openhands=False,
        max_iterations=metadata.max_iterations,
        runtime=os.environ.get('RUNTIME', 'eventstream'),
        sandbox=sandbox_config,
        # do not mount workspace
        workspace_base=None,
        workspace_mount_path=None,
    )

    llm_config = update_llm_config_for_completions_logging(
        metadata.llm_config, metadata.eval_output_dir, instance['id']
    )
    config.set_llm_config(llm_config)

    config.set_agent_config(
        AgentConfig(
            codeact_enable_jupyter=False,
            codeact_enable_browsing=RUN_WITH_BROWSING,
            codeact_enable_llm_editor=False,
            condenser=metadata.condenser_config,
        )
    )
    return config
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_runtime(
    runtime: Runtime,
    instance: pd.Series,  # this argument is not required
):
    """Initialize the runtime for the agent.

    This function is called before the runtime is used to run the agent. It
    exports the instance environment into ``~/.bashrc``, uploads the instance
    JSON (and, when ``instance['full_pred']`` is set, the previous test suite,
    which is copied over the test file and committed), runs the SWE entry
    script and leaves the shell in a clean checkout without git remotes.

    Args:
        runtime: Connected runtime in which the commands are executed.
        instance: Dataset row. Note that ``instance['instance_id']`` is
            overwritten with ``instance['instance_id_swebench']``.

    Raises:
        Whatever ``assert_and_raise`` raises when a command exits non-zero.
    """

    def run_checked(command: str, failure: str) -> CmdOutputObservation:
        """Run one shell command with a 600s hard timeout and require exit code 0."""
        action = CmdRunAction(command=command)
        action.set_hard_timeout(600)
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        if isinstance(obs, ErrorObservation):
            logger.error(f'{failure}: {str(obs)}')
        assert_and_raise(obs.exit_code == 0, f'{failure}: {str(obs)}')
        return obs

    logger.info('-' * 30)
    logger.info('BEGIN Runtime Initialization Fn')
    logger.info('-' * 30)
    workspace_dir_name = _get_swebench_workspace_dir_name(instance)

    instance['instance_id'] = instance['instance_id_swebench']

    # Set instance id
    run_checked(
        f"""echo 'export SWE_INSTANCE_ID={instance['instance_id_swebench']}' >> ~/.bashrc && echo 'export PIP_CACHE_DIR=~/.cache/pip' >> ~/.bashrc && echo "alias git='git --no-pager'" >> ~/.bashrc""",
        'Failed to export SWE_INSTANCE_ID',
    )
    run_checked(
        """export USER=$(whoami); echo USER=${USER} """, 'Failed to export USER'
    )

    # inject the init script
    script_dir = os.path.dirname(__file__)

    # inject the instance info
    run_checked(
        'mkdir -p /swe_util/eval_data/instances',
        'Failed to create /swe_util/eval_data/instances',
    )

    swe_instance_json_name = 'swe-bench-instance.json'
    swe_prediction = 'test_suite.py'
    with tempfile.TemporaryDirectory() as temp_dir:
        # Write the instance as a one-element JSON list under the expected name.
        temp_file_path = os.path.join(temp_dir, swe_instance_json_name)
        with open(temp_file_path, 'w') as f:
            raw_instance = instance if isinstance(instance, dict) else instance.to_dict()
            json.dump([_preprocess_instance(raw_instance)], f)

        # Copy the file to the desired location
        runtime.copy_to(temp_file_path, '/swe_util/eval_data/instances/')

        if instance['full_pred'] is not None:
            # Seed the sandbox with the previous prediction so the agent
            # iterates on it.
            temp_file_path_pred = os.path.join(temp_dir, swe_prediction)
            with open(temp_file_path_pred, 'w') as f:
                f.write(instance['full_pred'])

            runtime.copy_to(temp_file_path_pred, '/tmp')

            run_checked(
                f"cp /tmp/test_suite.py /testbed/{instance['test_file']}",
                'Failed to copy test file',
            )
            # This step used to report "Failed to cat ~/.bashrc" on failure.
            run_checked(
                'git -C /testbed add . && git -C /testbed commit -m "Add test file"',
                'Failed to commit test file',
            )

    # inject the instance swe entry
    runtime.copy_to(
        str(os.path.join(script_dir, 'scripts/setup/instance_swe_entry.sh')),
        '/swe_util/',
    )
    run_checked('cat ~/.bashrc', 'Failed to cat ~/.bashrc')
    run_checked('source ~/.bashrc', 'Failed to source ~/.bashrc')
    run_checked(
        'source /swe_util/instance_swe_entry.sh',
        'Failed to source /swe_util/instance_swe_entry.sh',
    )
    run_checked(
        f'cd /workspace/{workspace_dir_name}',
        f'Failed to cd to /workspace/{workspace_dir_name}',
    )
    run_checked('git reset --hard', 'Failed to git reset --hard')
    run_checked(
        'for remote_name in $(git remote); do git remote remove "${remote_name}"; done',
        'Failed to remove git remotes',
    )

    logger.info('-' * 30)
    logger.info('END Runtime Initialization Fn')
    logger.info('-' * 30)
|
||||||
|
|
||||||
|
|
||||||
|
def complete_runtime(
    runtime: Runtime,
    instance: pd.Series,  # this argument is not required, but it is used to get the workspace_dir_name
) -> dict[str, Any]:
    """Complete the runtime for the agent.

    This function is called after the agent has run. It reads the test file
    back from the sandbox. If you need to do something in the sandbox to get
    the correctness metric after the agent has run, modify this function.

    Args:
        runtime: Runtime in which the agent ran.
        instance: Dataset row; reads ``repo``/``version`` (workspace name),
            ``test_file`` and ``full_pred``.

    Returns:
        dict: ``{'test_suite': <text>}``, where the text is the stripped
        content of the test file, or - if reading it fails for any reason -
        ``instance['full_pred']`` when set, else ``''``.
    """
    try:
        logger.info('-' * 30)
        logger.info('BEGIN Runtime Completion Fn')
        logger.info('-' * 30)
        obs: CmdOutputObservation
        workspace_dir_name = _get_swebench_workspace_dir_name(instance)

        action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
        action.set_hard_timeout(600)
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        assert_and_raise(
            obs.exit_code == 0,
            f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
        )

        action = CmdRunAction(command=f'cat {instance.test_file}')
        action.set_hard_timeout(600)
        logger.info(action, extra={'msg_type': 'ACTION'})
        obs = runtime.run_action(action)
        logger.info(obs, extra={'msg_type': 'OBSERVATION'})
        assert_and_raise(
            obs.exit_code == 0,
            f'Failed to find file: {instance.test_file} in /workspace/{workspace_dir_name}',
        )

        test_suite = obs.content.strip()
    except Exception:
        # Best effort: fall back to the previous prediction instead of failing
        # the instance. Report through the logger so the trace reaches the same
        # destination as the rest of this function's output.
        logger.error('Skipping, exception in complete_runtime')
        logger.error(traceback.format_exc())
        test_suite = instance['full_pred'] if instance['full_pred'] is not None else ''

    logger.info('-' * 30)
    logger.info('END Runtime Completion Fn')
    logger.info('-' * 30)
    return {
        'test_suite': test_suite,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def process_instance(
    instance: pd.Series,
    metadata: EvalMetadata,
    reset_logger: bool = True,
) -> EvalOutput:
    """Run the agent on one instance and package the result.

    Args:
        instance: Dataset row to evaluate.
        metadata: Evaluation metadata (agent class, LLM config, output dir...).
        reset_logger: When True, route logging for this instance to its own
            file under ``<eval_output_dir>/infer_logs`` so that instances can
            be processed in parallel.

    Returns:
        EvalOutput: Instruction, event history, metrics and a ``test_result``
        holding the final test suite and the elapsed wall-clock time.

    Raises:
        ValueError: If the controller returns no state.
        EvalException: If the final state carries a fatal evaluation error.
    """
    config = get_config(instance, metadata)
    start_time = time.time()  # Track start time

    # Setup the logger properly, so you can run multi-processing to parallelize the evaluation
    if reset_logger:
        log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
        reset_logger_for_multiprocessing(logger, instance.id, log_dir)
    else:
        logger.info(f'Starting evaluation for instance {instance.id}.')

    runtime = create_runtime(config)
    call_async_from_sync(runtime.connect)

    try:
        initialize_runtime(runtime, instance)

        instruction = get_instruction(instance, metadata)

        # Here's how you can run the agent (similar to the `main` function) and get the final task state
        state: State | None = asyncio.run(
            run_controller(
                config=config,
                initial_user_action=MessageAction(content=instruction),
                runtime=runtime,
                fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
                    metadata.agent_class
                ],
            )
        )

        # Validate the state before its first use; this check used to sit
        # after `state.last_error` had already been dereferenced.
        if state is None:
            raise ValueError('State should not be None.')

        # if fatal error, throw EvalError to trigger re-run
        if is_fatal_evaluation_error(state.last_error):
            raise EvalException('Fatal error detected: ' + state.last_error)

        # ======= THIS IS SWE-Bench specific =======
        return_val = complete_runtime(runtime, instance)
        test_suite = return_val['test_suite']
        logger.info(
            f'Got test suite for instance {instance.instance_id}:\n--------\n{test_suite}\n--------'
        )
    finally:
        # Always release the runtime, including when an exception propagates.
        runtime.close()

    end_time = time.time()
    elapsed_time = end_time - start_time
    logger.info(
        f'Evaluation for instance {instance.instance_id} took {elapsed_time:.2f} seconds.'
    )
    # ==========================================

    # ======= Attempt to evaluate the agent's edits =======
    # we use eval_infer.sh to evaluate the agent's edits, not here
    # because the agent may alter the environment / testcases
    test_result = {
        'test_suite': test_suite,
        'elapsed_time': elapsed_time,
    }

    # If you are working on some simpler benchmark that only evaluates the final model output (e.g., in a MessageAction)
    # You can simply get the LAST `MessageAction` from the returned `state.history` and parse it for evaluation.
    histories = [event_to_dict(event) for event in state.history]
    metrics = get_metrics(state)

    # Save the output
    output = EvalOutput(
        instance_id=instance.id,
        instruction=instruction,
        instance=_preprocess_instance(instance.to_dict()),  # SWE Bench specific
        test_result=test_result,
        metadata=metadata,
        history=histories,
        metrics=metrics,
        error=state.last_error if state and state.last_error else None,
    )
    return output
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_dataset_pre(dataset: pd.DataFrame, filter_column: str) -> pd.DataFrame:
    """Optionally filter the dataset and remap its instance-id columns.

    If a ``config.toml`` next to this file defines ``selected_ids``, only rows
    whose ``filter_column`` value is in that list are kept. In every case the
    original ``instance_id`` is preserved as ``instance_id_swebench`` and
    ``instance_id`` is replaced by the ``id`` column.

    Args:
        dataset: Frame with at least ``instance_id``, ``id`` and
            ``filter_column`` columns.
        filter_column: Column matched against ``selected_ids``.

    Returns:
        pd.DataFrame: A new filtered frame when ``selected_ids`` is configured;
        otherwise ``dataset`` itself, modified in place.
    """
    selected = dataset
    file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.toml')
    if os.path.exists(file_path):
        with open(file_path, 'r') as file:
            data = toml.load(file)
        if 'selected_ids' in data:
            selected_ids = data['selected_ids']
            logger.info(
                f'Filtering {len(selected_ids)} tasks from "selected_ids"...'
            )
            # Take an explicit copy: assigning columns on a boolean-mask slice
            # is chained assignment and triggers pandas' SettingWithCopy path.
            selected = dataset[dataset[filter_column].isin(selected_ids)].copy()
            logger.info(f'Retained {selected.shape[0]} tasks after filtering')

    # Order matters: save the original id before overwriting it.
    selected['instance_id_swebench'] = selected['instance_id']
    selected['instance_id'] = selected['id']
    return selected
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Command-line entry point: parse arguments, load the dataset and any
    # zero-shot predictions, then run the evaluation over the instances.
    parser = get_parser()
    parser.add_argument(
        '--dataset',
        type=str,
        default='kjain/testgenevallite',
        help='data set to evaluate on, either full-test or lite-test',
    )
    parser.add_argument(
        '--split',
        type=str,
        default='test',
        help='split to evaluate on',
    )
    parser.add_argument(
        '--testfile_start',
        action='store_true',
        help='Whether to start from the 0 shot test file',
    )
    parser.add_argument(
        '--zero_shot_path',
        type=str,
        help='Path to the zero shot test file predictions',
    )
    args, _ = parser.parse_known_args()

    if args.testfile_start and not args.zero_shot_path:
        raise ValueError(
            'If you want to start from the 0 shot test file, you must provide the path to the zero shot test file predictions'
        )

    # Map instance id -> first full zero-shot prediction (JSON-lines input).
    preds_map = {}
    if args.testfile_start:
        with open(args.zero_shot_path, 'r') as f:
            for line in f:
                pred = json.loads(line)
                preds_map[pred['id']] = pred['preds']['full'][0]

    # NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
    # so we don't need to manage file uploading to OpenHands's repo
    dataset = load_dataset(args.dataset, split=args.split)
    logger.info(f'Loaded dataset {args.dataset} with split {args.split}')
    testgeneval_filepairs = prepare_dataset_pre(dataset.to_pandas(), 'id')

    llm_config = get_llm_config_arg(args.llm_config) if args.llm_config else None
    # Check for a missing config before touching its attributes; previously the
    # attribute assignments ran first and pre-empted this error.
    if llm_config is None:
        raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
    llm_config.log_completions = True
    # modify_params must be False for evaluation purpose, for reproducibility and accuracy of results
    llm_config.modify_params = False

    details = {}
    _agent_cls = openhands.agenthub.Agent.get_cls(args.agent_cls)

    dataset_description = (
        args.dataset.replace('/', '__') + '-' + args.split.replace('/', '__')
    )
    metadata = make_metadata(
        llm_config,
        dataset_description,
        args.agent_cls,
        args.max_iterations,
        args.eval_note,
        args.eval_output_dir,
        details=details,
    )

    output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
    instances = prepare_dataset(testgeneval_filepairs, output_file, args.eval_n_limit)

    if not instances.empty:
        # Attach the zero-shot prediction (or None) to every remaining row.
        instances['full_pred'] = (
            instances['instance_id']
            .map(preds_map)
            .apply(lambda x: x if pd.notna(x) else None)
        )

    run_evaluation(
        instances, metadata, output_file, args.eval_num_workers, process_instance
    )
|
||||||
@@ -0,0 +1,128 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
|
||||||
|
# Function to run shell commands
|
||||||
|
def run_command(command):
    """Run ``command`` through the shell; report a non-zero exit without raising.

    Args:
        command (str): Shell command line.

    Returns:
        None. A failing command only prints ``An error occurred: ...``.
    """
    try:
        subprocess.check_call(command, shell=True)
    except subprocess.CalledProcessError as err:
        print(f'An error occurred: {err}')
|
||||||
|
|
||||||
|
|
||||||
|
# Function to log in to Docker Hub
|
||||||
|
def docker_login():
    """Print a notice and run ``docker login`` through ``run_command``."""
    notice = 'Logging into Docker Hub...'
    print(notice)
    run_command('docker login')
|
||||||
|
|
||||||
|
|
||||||
|
# Function to generate Dockerfile content based on image type
|
||||||
|
def generate_dockerfile_content(
    base_image, dependencies, datum, patch_path, test_patch_path
):
    """Render the Dockerfile text that derives a TestGenEval image.

    Args:
        base_image (str): Image named in the ``FROM`` line.
        dependencies (list[str]): Packages passed to ``pip install``.
        datum: Mapping; only ``datum['test_file']`` is read (the file removed
            from the image).
        patch_path (str): Build-context path of the first patch to apply.
        test_patch_path (str): Build-context path of the second patch to apply.

    Returns:
        str: Dockerfile content, starting with an empty line and ending with
        the ``git commit`` instruction (no trailing newline).
    """
    pip_packages = ' '.join(dependencies)
    steps = [
        '',
        f'FROM {base_image}',
        'SHELL ["/bin/bash", "-c"]',
        f'RUN source /opt/miniconda3/bin/activate && conda activate testbed && pip install {pip_packages}',
        f'COPY {patch_path} /app/patch.diff',
        'RUN git apply /app/patch.diff',
        'RUN rm /app/patch.diff',
        f'COPY {test_patch_path} /app/patch.diff',
        'RUN git apply /app/patch.diff',
        'RUN git config --global user.email ""',
        'RUN git config --global user.name "TestGenEval"',
        'RUN rm /app/patch.diff',
        f"RUN rm {datum['test_file']}",
        # Commit the patched tree with the test file removed.
        'RUN git add .',
        'RUN git commit -m "Testing fixes"',
    ]
    return '\n'.join(steps)
|
||||||
|
|
||||||
|
|
||||||
|
# Function to build, push, and clean up Docker images
|
||||||
|
def build_and_push_image(dockerfile_content, image_name):
    """Build ``image_name`` from the given Dockerfile text, push it, then clean up.

    The text is written to ``Dockerfile.temp`` in the current directory, which
    is also the build context. The local image and the temporary file are
    removed afterwards.
    """
    temp_dockerfile = 'Dockerfile.temp'
    with open(temp_dockerfile, 'w') as handle:
        handle.write(dockerfile_content)

    docker_steps = (
        f'docker build -f {temp_dockerfile} -t {image_name} .',
        f'docker push {image_name}',
        f'docker rmi {image_name}',
    )
    for step in docker_steps:
        run_command(step)

    os.remove(temp_dockerfile)
|
||||||
|
|
||||||
|
|
||||||
|
# Function to process images with .eval in the name
|
||||||
|
def process_images(dataset, original_namespace, new_namespace, start_instance_id):
    """Rebuild and push one derived image per dataset entry.

    Args:
        dataset: Iterable of mappings with ``instance_id``, ``patch``,
            ``test_patch`` and ``test_file``.
        original_namespace (str): Namespace the base images are pulled from.
        new_namespace (str): Namespace the derived images are pushed to.
        start_instance_id (str): When non-empty, entries up to and including
            the one with this id are skipped; an empty string processes all.
    """
    dependencies = ['coverage', 'cosmic-ray']
    patch_file_path = 'patch.diff'
    test_patch_file_path = 'test_patch.diff'

    found_start = len(start_instance_id) == 0
    for datum in dataset:
        if not found_start:
            # Still searching for the resume marker; the marker entry itself
            # is not processed.
            if datum['instance_id'] == start_instance_id:
                found_start = True
            continue

        image_suffix = datum['instance_id'].replace('__', '_s_')
        full_image_name = f'{original_namespace}/sweb.eval.x86_64.{image_suffix}:latest'
        print(f'Processing image: {full_image_name}')
        run_command(f'docker pull {full_image_name}')

        # Write both patches into the build context.
        with open(patch_file_path, 'w') as patch_file:
            with open(test_patch_file_path, 'w') as test_patch_file:
                patch_file.write(datum['patch'])
                test_patch_file.write(datum['test_patch'])

        new_image_name = f'{new_namespace}/sweb.eval.x86_64.{image_suffix}:latest'
        dockerfile_content = generate_dockerfile_content(
            full_image_name,
            dependencies,
            datum,
            patch_file_path,
            test_patch_file_path,
        )
        build_and_push_image(dockerfile_content, new_image_name)

        # Cleanup regular files and images
        for leftover in (patch_file_path, test_patch_file_path):
            os.remove(leftover)
        run_command(f'docker rmi {full_image_name}')
        run_command('docker system prune -f')  # Clean up dangling resources
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # CLI entry point: load the dataset split, log in to the registry, then
    # rebuild and push an image for every entry.
    parser = argparse.ArgumentParser(
        description='Process Docker images with .eval in the name.'
    )
    parser.add_argument('--dataset', type=str, default='kjain14/testgeneval')
    parser.add_argument('--split', type=str, default='test')
    namespace_options = (
        ('--new_namespace', 'kdjain', 'The new Docker Hub namespace to push the images'),
        ('--original_namespace', 'xingyaoww', 'The original Docker Hub namespace'),
        ('--start_instance_id', '', 'The instance_id to start processing from'),
    )
    for flag, default_value, help_text in namespace_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    args = parser.parse_args()

    dataset = load_dataset(args.dataset)[args.split]

    docker_login()
    process_images(
        dataset, args.original_namespace, args.new_namespace, args.start_instance_id
    )
|
||||||
+1274
File diff suppressed because it is too large
Load Diff
+196
@@ -0,0 +1,196 @@
|
|||||||
|
sweb.base.x86_64:latest
|
||||||
|
sweb.env.x86_64.088a7e628bda9770f9757b:latest
|
||||||
|
sweb.env.x86_64.0d80c7dec81ee2f2f513e2:latest
|
||||||
|
sweb.env.x86_64.0f99bce2750f3109957bec:latest
|
||||||
|
sweb.env.x86_64.1b3b218535da0abf4469cb:latest
|
||||||
|
sweb.env.x86_64.1c1a6945f732f9391228c5:latest
|
||||||
|
sweb.env.x86_64.1f92e6d7cef88badc4f744:latest
|
||||||
|
sweb.env.x86_64.27dd9791e13f5c857a09f9:latest
|
||||||
|
sweb.env.x86_64.297af196949a2a635bce66:latest
|
||||||
|
sweb.env.x86_64.2baaea72acc974f6c02079:latest
|
||||||
|
sweb.env.x86_64.2e50125951bc69cddd7421:latest
|
||||||
|
sweb.env.x86_64.2f217c8b4490bfa0e2ba14:latest
|
||||||
|
sweb.env.x86_64.31244378a92e3bcce809ac:latest
|
||||||
|
sweb.env.x86_64.428468730904ff6b4232aa:latest
|
||||||
|
sweb.env.x86_64.5d1fda9d55d65d8a4e5bdb:latest
|
||||||
|
sweb.env.x86_64.6b007979cf533f0f3016e8:latest
|
||||||
|
sweb.env.x86_64.7037e8c448a4b8ebfe9b13:latest
|
||||||
|
sweb.env.x86_64.71498c7426dbf05599642f:latest
|
||||||
|
sweb.env.x86_64.756beac07713d7e8dc1129:latest
|
||||||
|
sweb.env.x86_64.78278ae2cf880e395f1337:latest
|
||||||
|
sweb.env.x86_64.8f1f7b974f0c57c7aeba39:latest
|
||||||
|
sweb.env.x86_64.934a137824256b612e9dc5:latest
|
||||||
|
sweb.env.x86_64.a0efca7a0fe6719dbf65c2:latest
|
||||||
|
sweb.env.x86_64.a18371b03f944585b4f08c:latest
|
||||||
|
sweb.env.x86_64.a33dddf55cdff5d8e23374:latest
|
||||||
|
sweb.env.x86_64.aa92880033da20ca313928:latest
|
||||||
|
sweb.env.x86_64.b649f0ff62fad147f7f073:latest
|
||||||
|
sweb.env.x86_64.b7ce4be3b3c35f68c61248:latest
|
||||||
|
sweb.env.x86_64.c70909fdac4897d1c685df:latest
|
||||||
|
sweb.env.x86_64.c795f4b88616b8462021ed:latest
|
||||||
|
sweb.env.x86_64.cc47cc71483942d0c3a15e:latest
|
||||||
|
sweb.env.x86_64.dc5ff4c0e3fe8db5afc4da:latest
|
||||||
|
sweb.env.x86_64.e3afd7f04b325a4de4982d:latest
|
||||||
|
sweb.env.x86_64.e5bb89bf78258a7d14c34b:latest
|
||||||
|
sweb.env.x86_64.e83e37f52c09532c62acfb:latest
|
||||||
|
sweb.env.x86_64.efa6065ed5bf204410fd53:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-17087:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-10508:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-14017:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-11422:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-14774:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-14915:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-22005:latest
|
||||||
|
sweb.eval.x86_64.pytest-dev_s_pytest-5221:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-17022:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-15996:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-15252:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-21171:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-11797:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-16046:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-11583:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-15738:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-21612:latest
|
||||||
|
sweb.eval.x86_64.astropy_s_astropy-12907:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-11620:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-16792:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13779:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-16041:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-13471:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-20442:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-20049:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-14411:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-13447:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-12856:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-10949:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-14787:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-11815:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13584:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-14087:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-15388:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-11179:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-24102:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-24213:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-15781:latest
|
||||||
|
sweb.eval.x86_64.pytest-dev_s_pytest-8906:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-13710:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-13925:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-14092:latest
|
||||||
|
sweb.eval.x86_64.pytest-dev_s_pytest-7373:latest
|
||||||
|
sweb.eval.x86_64.matplotlib_s_matplotlib-25498:latest
|
||||||
|
sweb.eval.x86_64.pytest-dev_s_pytest-5227:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-15678:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-13551:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-14155:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-13933:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-21055:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-13660:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-16527:latest
|
||||||
|
sweb.eval.x86_64.pytest-dev_s_pytest-5692:latest
|
||||||
|
sweb.eval.x86_64.mwaskom_s_seaborn-3010:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-12700:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-11400:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-23117:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-20639:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-23262:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-15498:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-12453:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-14999:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-13480:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-21847:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-15011:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-25570:latest
|
||||||
|
sweb.eval.x86_64.sphinx-doc_s_sphinx-7975:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-14983:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-14534:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-14396:latest
|
||||||
|
sweb.eval.x86_64.matplotlib_s_matplotlib-25442:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-15535:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-22714:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-15789:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-21627:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-24066:latest
|
||||||
|
sweb.eval.x86_64.pylint-dev_s_pylint-7993:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-14752:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-18835:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-17051:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-12171:latest
|
||||||
|
sweb.eval.x86_64.pydata_s_xarray-3364:latest
|
||||||
|
sweb.eval.x86_64.mwaskom_s_seaborn-3190:latest
|
||||||
|
sweb.eval.x86_64.pytest-dev_s_pytest-7168:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-12747:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-15695:latest
|
||||||
|
sweb.eval.x86_64.matplotlib_s_matplotlib-22835:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-12481:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-15851:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-14024:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-14608:latest
|
||||||
|
sweb.eval.x86_64.pytest-dev_s_pytest-9359:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-16873:latest
|
||||||
|
sweb.eval.x86_64.matplotlib_s_matplotlib-25433:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-13031:latest
|
||||||
|
sweb.eval.x86_64.pytest-dev_s_pytest-7432:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-25747:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-12286:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-11910:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-12471:latest
|
||||||
|
sweb.eval.x86_64.pylint-dev_s_pylint-5859:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-11133:latest
|
||||||
|
sweb.eval.x86_64.astropy_s_astropy-14365:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13496:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-19487:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-13895:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-15345:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-13590:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-13757:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-16379:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-13768:latest
|
||||||
|
sweb.eval.x86_64.pytest-dev_s_pytest-8365:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-14580:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-20154:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-12419:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-12125:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-24152:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-15512:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-18621:latest
|
||||||
|
sweb.eval.x86_64.pydata_s_xarray-4248:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-11040:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-11099:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-16816:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-13265:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-16139:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-10297:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-14016:latest
|
||||||
|
sweb.eval.x86_64.pallets_s_flask-5063:latest
|
||||||
|
sweb.eval.x86_64.astropy_s_astropy-7746:latest
|
||||||
|
sweb.eval.x86_64.matplotlib_s_matplotlib-24265:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-13448:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-12908:latest
|
||||||
|
sweb.eval.x86_64.sphinx-doc_s_sphinx-8627:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-14317:latest
|
||||||
|
sweb.eval.x86_64.pytest-dev_s_pytest-6116:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-23191:latest
|
||||||
|
sweb.eval.x86_64.pydata_s_xarray-5131:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-11019:latest
|
||||||
|
sweb.eval.x86_64.matplotlib_s_matplotlib-23913:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-15790:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-12497:latest
|
||||||
|
sweb.eval.x86_64.matplotlib_s_matplotlib-26020:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-25638:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-25500:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-19007:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-12308:latest
|
||||||
|
sweb.eval.x86_64.pytest-dev_s_pytest-7220:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-11848:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-15347:latest
|
||||||
|
sweb.eval.x86_64.pytest-dev_s_pytest-7490:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-18532:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-14997:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-24909:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-13220:latest
|
||||||
|
sweb.eval.x86_64.sympy_s_sympy-21614:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-15902:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13497:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-13439:latest
|
||||||
|
sweb.eval.x86_64.scikit-learn_s_scikit-learn-14894:latest
|
||||||
|
sweb.eval.x86_64.django_s_django-12983:latest
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
def print_diff_ignore_order(file1, file2):
    """Report lines of ``file1`` that do not occur in ``file2``, ignoring order.

    Both files are read as sets of lines (duplicates collapse, order is
    irrelevant). Lines present only in ``file1`` are printed, sorted and
    stripped of surrounding whitespace. Lines present only in ``file2`` are
    not listed (that report is disabled below), but they still prevent the
    "same content" message from being printed.

    :param file1: Path of the file whose extra lines are reported.
    :param file2: Path of the file to compare against.
    """
    with open(file1, 'r') as f1, open(file2, 'r') as f2:
        # Drop the line terminator before comparing: otherwise a final line
        # without a trailing newline never matches the same line in the other
        # file and shows up as a spurious difference.
        file1_lines = {line.rstrip('\n') for line in f1}
        file2_lines = {line.rstrip('\n') for line in f2}

    only_in_file1 = file1_lines - file2_lines
    only_in_file2 = file2_lines - file1_lines

    if only_in_file1:
        print(f'Lines in {file1} but not in {file2}:')
        for line in sorted(only_in_file1):
            print(f'- {line.strip()}')

    # Reporting the opposite direction is intentionally disabled:
    # if only_in_file2:
    #     print(f"Lines in {file2} but not in {file1}:")
    #     for line in sorted(only_in_file2):
    #         print(f"+ {line.strip()}")

    if not only_in_file1 and not only_in_file2:
        print('The files have the same content (ignoring line order).')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Usage: each pair is (first file, second file); replace the paths with
    # the files you want to compare. The lite lists are checked first, then
    # the full lists.
    file_pairs = (
        (
            'all-swebench-lite-instance-images.txt',
            '../../swe_bench/scripts/docker/all-swebench-lite-instance-images.txt',
        ),
        (
            'all-swebench-full-instance-images.txt',
            '../../swe_bench/scripts/docker/all-swebench-full-instance-images.txt',
        ),
    )
    for first_path, second_path in file_pairs:
        print_diff_ignore_order(first_path, second_path)
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
#!/bin/bash
# Script will delete all repositories and tags in your Docker Hub account.
#
# Usage: <this script> <dockerhub-username> <dockerhub-password>
# Requires: curl, jq.
#
# WARNING: this is destructive. Every repository listed for the given account
# is removed, together with its tags.
set -euo pipefail

if [[ $# -ne 2 ]]; then
    echo "Usage: $0 <dockerhub-username> <dockerhub-password>" >&2
    exit 1
fi

# Set username and password from command-line arguments
UNAME=$1
UPASS=$2

# Get token to interact with Docker Hub.
# The JSON body is built by jq so that quotes, backslashes or spaces in the
# credentials are escaped instead of corrupting (or injecting into) the payload.
LOGIN_PAYLOAD=$(jq -n --arg username "$UNAME" --arg password "$UPASS" '{username: $username, password: $password}')
# '.token // empty' prints nothing for a missing/null token; plain '.token'
# would print the literal string "null" and defeat the emptiness check below.
TOKEN=$(curl -s -H "Content-Type: application/json" -X POST -d "$LOGIN_PAYLOAD" https://hub.docker.com/v2/users/login/ | jq -r '.token // empty' || true)

# Ensure token retrieval was successful
if [[ -z "$TOKEN" ]]; then
    echo "Failed to obtain authentication token. Please check your credentials."
    exit 1
fi

# Get list of repositories for that user account.
# The listing is paginated: keep following the "next" link until it is absent,
# so accounts with more repositories than fit on one page are fully covered.
echo "Listing repositories in Docker Hub account '${UNAME}':"
REPO_LIST=""
NEXT_URL="https://hub.docker.com/v2/repositories/${UNAME}/?page_size=100"
while [[ -n "$NEXT_URL" ]]; do
    PAGE=$(curl -s -H "Authorization: JWT ${TOKEN}" "$NEXT_URL")
    PAGE_REPOS=$(jq -r '.results[]?.name' <<<"$PAGE" || true)
    REPO_LIST="${REPO_LIST} ${PAGE_REPOS}"
    NEXT_URL=$(jq -r '.next // empty' <<<"$PAGE" || true)
done

if [[ -z "${REPO_LIST//[[:space:]]/}" ]]; then
    echo "No repositories found for user '${UNAME}' or failed to fetch repositories."
    exit 1
fi

# Loop through each repository and delete its tags and the repository itself.
# REPO_LIST / IMAGE_TAGS are deliberately left unquoted: word splitting turns
# the whitespace-separated names into loop items.
for rep in ${REPO_LIST}; do
    echo "Processing repository: ${UNAME}/${rep}"

    # Get all tags for the repository
    IMAGES=$(curl -s -H "Authorization: JWT ${TOKEN}" "https://hub.docker.com/v2/repositories/${UNAME}/${rep}/tags/?page_size=100")
    IMAGE_TAGS=$(jq -r '.results[]?.name' <<<"$IMAGES" || true)

    # Delete each tag (-f makes curl fail on HTTP errors so they are reported;
    # a failed tag deletion is only a warning and does not stop the run).
    for tag in ${IMAGE_TAGS}; do
        echo "Deleting tag: ${UNAME}/${rep}:${tag}"
        curl -sf -X DELETE -H "Authorization: JWT ${TOKEN}" "https://hub.docker.com/v2/repositories/${UNAME}/${rep}/tags/${tag}/" || {
            echo "Failed to delete tag '${UNAME}/${rep}:${tag}'." >&2
        }
    done

    # Delete the repository itself
    echo "Deleting repository: ${UNAME}/${rep}"
    curl -sf -X DELETE -H "Authorization: JWT ${TOKEN}" "https://hub.docker.com/v2/repositories/${UNAME}/${rep}/" || {
        echo "Failed to delete repository '${UNAME}/${rep}'. Please check permissions or API limits."
    }
    sleep 1
done

echo "Script execution completed."
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_to_txt(dataset, txt_file, split='test'):
    """Write one evaluation image name per instance of a dataset split.

    For every row of ``dataset[split]`` the ``instance_id`` has each ``__``
    replaced by ``_s_`` and is written to ``txt_file`` as
    ``sweb.eval.x86_64.<instance_id>:latest`` on its own line. The file is
    overwritten.

    :param dataset: Mapping from split name to an iterable of rows with an
        ``instance_id`` entry.
    :param txt_file: Path of the text file to write.
    :param split: Name of the split to export.
    """
    with open(txt_file, 'w') as out:
        image_names = (
            'sweb.eval.x86_64.{}:latest\n'.format(
                row['instance_id'].replace('__', '_s_')
            )
            for row in dataset[split]
        )
        out.writelines(image_names)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Load the private dataset (full benchmark) and its lite variant.
    dataset = load_dataset('kjain14/testgeneval')

    dataset_lite = load_dataset('kjain14/testgenevallite')

    # dataset_to_txt(dataset, txt_file, split='test') has no ``lite``
    # parameter; passing one raised TypeError before any file was written.
    # The lite/full distinction is carried by the dataset and the file name.
    dataset_to_txt(dataset_lite, 'all-swebench-lite-instance-images.txt')
    dataset_to_txt(dataset, 'all-swebench-full-instance-images.txt')
|
||||||
@@ -0,0 +1,173 @@
|
|||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import difflib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
def insert_line_in_string(input_string, new_str, insert_line):
    """
    Splices ``new_str`` into ``input_string`` after its first ``insert_line`` lines.

    Tabs in both strings are expanded to spaces before splicing, and lines
    are delimited by ``'\\n'`` only.

    :param input_string: The original string.
    :param new_str: The string to insert (may span several lines).
    :param insert_line: Number of original lines kept ahead of the inserted
        text: 0 places it at the very top, 1 places it after the first line.
    :return: The modified string.
    """
    existing_lines = input_string.expandtabs().split('\n')
    added_lines = new_str.expandtabs().split('\n')

    head = existing_lines[:insert_line]
    tail = existing_lines[insert_line:]

    return '\n'.join([*head, *added_lines, *tail])
|
||||||
|
|
||||||
|
|
||||||
|
def print_string_diff(original, modified):
    """
    Prints the differences between two strings as a unified diff.

    The strings are split into lines without their terminators and the diff
    lines are joined with newlines afterwards. (Mixing ``keepends=True`` with
    ``lineterm=''`` left the ``---``/``+++``/``@@`` header lines without any
    separator, so they were printed glued together on one line.)

    :param original: The original string.
    :param modified: The modified string.
    """
    original_lines = original.splitlines()
    modified_lines = modified.splitlines()

    diff = difflib.unified_diff(
        original_lines,
        modified_lines,
        fromfile='original',
        tofile='modified',
        lineterm='',
    )

    # Identical inputs produce no diff lines, so only a blank line is printed.
    print('\n'.join(diff))
|
||||||
|
|
||||||
|
|
||||||
|
def parse_json_files(root_dir, output_dir, metadata_objs, preds_objs):
    """Replay logged editor tool calls and write one output file per step.

    Every sub-directory of ``root_dir`` is treated as one instance whose name
    is the key into ``metadata_objs``. Its ``*.json`` files are visited in
    sorted order; the tool calls found under
    ``response.choices[0].message.tool_calls`` are applied to a running test
    suite (``create`` replaces it, ``insert`` and ``str_replace`` edit it).
    After each file that was processed without error, a copy of the instance
    metadata carrying the current suite is stored for that step.

    Output layout (``output_<k>.jsonl`` in ``output_dir``):

    * ``k`` in 0..23: the suite after step ``k``; instances with fewer steps
      are padded with their last reconstructed suite.
    * ``k`` = 24: the suite already present in the metadata.

    :param root_dir: Directory containing one sub-directory per instance.
    :param output_dir: Existing directory receiving the ``output_<k>.jsonl`` files.
    :param metadata_objs: Mapping of instance id to its metadata dict, which
        must contain ``['test_result']['test_suite']``.
    :param preds_objs: Mapping of instance id to the starting test suite; an
        instance that is missing starts from an empty string.
    """
    # Slots 0..num_steps-1 hold per-step snapshots; slot num_steps is reserved
    # for the original suite, so step snapshots must never be written there.
    num_steps = 24
    final_output = {i: [] for i in range(num_steps + 1)}

    for subdir in sorted(os.listdir(root_dir)):  # Sorting ensures consistent order
        subdir_path = os.path.join(root_dir, subdir)
        if not os.path.isdir(subdir_path):
            # Stray files (notes, .DS_Store, ...) are not instances.
            continue

        metadata = metadata_objs.get(subdir)
        if metadata is None:
            print(f'Skipping subdirectory without metadata: {subdir}')
            continue
        orig_test_suite = metadata['test_result']['test_suite']

        print(f'Processing subdirectory: {subdir}')

        i = 0
        test_suite = preds_objs[subdir] if subdir in preds_objs else ''
        for file in sorted(os.listdir(subdir_path)):  # Sorting ensures consistent order
            if not file.endswith('.json'):
                continue
            file_path = os.path.join(subdir_path, file)

            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except Exception as e:
                print(traceback.format_exc())
                print(f'  Error loading {file_path}: {e}')
                continue

            try:
                tool_calls = data['response']['choices'][0]['message']['tool_calls']
                for tool_call in tool_calls or []:
                    tool_call_dict = _parse_tool_arguments(
                        tool_call['function']['arguments']
                    )
                    if tool_call_dict is not None and tool_call_dict != {}:
                        test_suite = _apply_tool_call(test_suite, tool_call_dict)
            except Exception:
                # A malformed step is logged and not recorded; edits applied
                # before the failure stay in effect for the following steps.
                print(traceback.format_exc())
                continue

            if i < num_steps:
                snapshot = copy.deepcopy(metadata)
                snapshot['test_result']['test_suite'] = test_suite
                final_output[i].append(snapshot)
            i += 1

        # Pad the remaining step slots with the last reconstructed suite. The
        # record is built here instead of reusing a loop variable, which would
        # be unbound (or belong to a previous instance) for an empty directory.
        for j in range(i, num_steps):
            padding = copy.deepcopy(metadata)
            padding['test_result']['test_suite'] = test_suite
            final_output[j].append(padding)

        metadata_orig = copy.deepcopy(metadata)
        metadata_orig['test_result']['test_suite'] = orig_test_suite
        final_output[num_steps].append(metadata_orig)

    for i in range(num_steps + 1):
        output_file = os.path.join(output_dir, f'output_{i}.jsonl')
        with open(output_file, 'w') as f:
            for metadata in final_output[i]:
                f.write(json.dumps(metadata) + '\n')


def _parse_tool_arguments(raw_arguments):
    """Decode a tool call's argument payload without executing it."""
    import ast

    if not isinstance(raw_arguments, str):
        return raw_arguments
    try:
        return json.loads(raw_arguments)
    except json.JSONDecodeError:
        # Fallback for payloads that were logged as Python literals
        # (single-quoted dicts); literal_eval never runs arbitrary code.
        return ast.literal_eval(raw_arguments)


def _apply_tool_call(test_suite, tool_call_dict):
    """Return ``test_suite`` after applying one editor command to it."""
    command = tool_call_dict['command']
    if command == 'create':
        test_suite = tool_call_dict['file_text']
    if (
        command != 'str_replace'
        and command != 'insert'
        and 'coverage' not in command
    ):
        print(command)
    if command == 'insert':
        test_suite = insert_line_in_string(
            test_suite,
            tool_call_dict['new_str'],
            tool_call_dict['insert_line'],
        )
    if command == 'str_replace':
        # Only unambiguous replacements are applied; anything else is a no-op.
        if test_suite.count(tool_call_dict['old_str']) == 1:
            test_suite = test_suite.replace(
                tool_call_dict['old_str'],
                tool_call_dict['new_str'],
            )
    return test_suite
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    cli = argparse.ArgumentParser(description='Parse JSON file')
    cli.add_argument('--root_dir', type=str, help='Root directory', required=True)
    cli.add_argument(
        '--output_dir', type=str, help='Output directory', required=True
    )
    cli.add_argument(
        '--starting_preds_file', type=str, help='Starting predictions', default=None
    )
    args = cli.parse_args()

    # Index the existing evaluation records by instance id.
    with open(os.path.join(args.output_dir, 'output.jsonl'), 'r') as f:
        records = [json.loads(raw_line) for raw_line in f.readlines()]
    metadata_objs = {record['instance_id']: record for record in records}

    # Optional starting predictions: the first "full" prediction per id.
    preds_objs = {}
    if args.starting_preds_file is not None:
        with open(args.starting_preds_file, 'r') as f:
            for raw_line in f.readlines():
                pred = json.loads(raw_line)
                preds_objs[pred['id']] = pred['preds']['full'][0]

    parse_json_files(args.root_dir, args.output_dir, metadata_objs, preds_objs)
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
# Two positional arguments: the JSONL outputs to compare (X first, then Y).
parser = argparse.ArgumentParser(
    description='Compare two TestGenEval output JSONL files and print the resolved diff'
)
for positional_name in ('input_file_1', 'input_file_2'):
    parser.add_argument(positional_name, type=str)
args = parser.parse_args()

df1, df2 = (
    pd.read_json(path, orient='records', lines=True)
    for path in (args.input_file_1, args.input_file_2)
)

# Get the intersection of the ids; columns present in both files receive
# the _x (first file) and _y (second file) suffixes.
df = df1.merge(df2, on='id', how='inner')
|
||||||
|
|
||||||
|
|
||||||
|
def _get_coverage(report):
|
||||||
|
if report is None:
|
||||||
|
return False
|
||||||
|
if isinstance(report, float):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return report.get('test_pass', False)
|
||||||
|
|
||||||
|
|
||||||
|
# Collapse each side's test_pass entry to a plain boolean column.
# NOTE(review): _get_coverage looks up 'test_pass' inside each value, so this
# presumably expects the test_pass column to hold report-like dicts — confirm
# against the eval output schema.
df['test_pass_x'] = df['test_pass_x'].apply(_get_coverage).astype(bool)
df['test_pass_y'] = df['test_pass_y'].apply(_get_coverage).astype(bool)

# Vectorised comparison. A row-wise df.apply(..., axis=1) returns a DataFrame
# instead of a Series when the merge is empty, which made this assignment fail
# for two files without common ids.
df['diff'] = df['test_pass_x'] != df['test_pass_y']

df_diff = df[df['diff']].sort_values(
    by=['test_pass_x', 'test_pass_y'], ascending=[False, False]
)
# Both columns are plain booleans at this point, so there is nothing left to
# filter for missing values.

print(f'X={args.input_file_1}')
print(f'Y={args.input_file_2}')
print(f'# diff={df_diff.shape[0]}')
df_diff = df_diff[['id', 'test_pass_x', 'test_pass_y', 'report_x', 'report_y']]

# x pass but y not
print('-' * 100)
df_diff_x_only = df_diff[df_diff['test_pass_x'] & ~df_diff['test_pass_y']].sort_values(
    by='id'
)
print(f'# x pass but y not={df_diff_x_only.shape[0]}')
print(df_diff_x_only[['id', 'report_x', 'report_y']])

# y pass but x not
print('-' * 100)
df_diff_y_only = df_diff[~df_diff['test_pass_x'] & df_diff['test_pass_y']].sort_values(
    by='id'
)
print(f'# y pass but x not={df_diff_y_only.shape[0]}')
print(df_diff_y_only[['id', 'report_x', 'report_y']])

# Plain id lists for both directions, convenient for copy/paste.
print('-' * 100)
print('Instances that x pass but y not:')
print(df_diff_x_only['id'].tolist())

print('-' * 100)
print('Instances that y pass but x not:')
print(df_diff_y_only['id'].tolist())
|
||||||
Executable
+28
@@ -0,0 +1,28 @@
|
|||||||
|
#!/bin/bash
# Package an evaluation output folder as <folder>.swebench_submission with:
#   all_preds.jsonl  - converted from <folder>/output.jsonl
#   trajs/           - the .messages of the newest completion JSON per instance
#   logs/            - a copy of <folder>/eval_outputs
#
# Usage: <this script> <output-folder>
# Requires: poetry, jq. All paths are quoted so folders containing spaces work.

if [ -z "$1" ]; then
    echo "Usage: $0 <output-folder>" >&2
    exit 1
fi

FOLDER_PATH=$1
NEW_FOLDER_PATH="${FOLDER_PATH}.swebench_submission"
mkdir -p "$NEW_FOLDER_PATH"

# Build all_preds.jsonl
# (the converter is expected to leave output.swebench.jsonl next to its input)
poetry run python evaluation/testgeneval/scripts/eval/convert_oh_output_to_swe_json.py "$FOLDER_PATH/output.jsonl"
mv "$FOLDER_PATH/output.swebench.jsonl" "$NEW_FOLDER_PATH/all_preds.jsonl"

# Build trajs/
mkdir -p "$NEW_FOLDER_PATH/trajs"
for instance_dir in "$FOLDER_PATH"/llm_completions/*/; do
    # When the glob matches nothing the pattern itself is returned; skip it.
    [ -d "$instance_dir" ] || continue
    instance_id=$(basename "$instance_dir")
    # Newest completion file by modification time; empty when there is none.
    latest_json=$(ls -t "$instance_dir"/*.json 2>/dev/null | head -n1)
    if [ -n "$latest_json" ]; then
        jq -r '.messages' "$latest_json" > "$NEW_FOLDER_PATH/trajs/$instance_id.json"
    fi
done

# Build logs/
# check if $FOLDER_PATH/eval_outputs exists, if so copy over - else raise error
if [ -d "$FOLDER_PATH/eval_outputs" ]; then
    cp -r "$FOLDER_PATH/eval_outputs" "$NEW_FOLDER_PATH/logs"
else
    echo "Error: $FOLDER_PATH/eval_outputs does not exist. You should run the local docker eval_infer.sh first."
    exit 1
fi
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Convert OpenHands output to a readable markdown format for visualization."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from evaluation.testgeneval.eval_infer import process_test_suite
|
||||||
|
from openhands.events.serialization import event_from_dict
|
||||||
|
|
||||||
|
# Registers the progress_* methods on pandas objects.
tqdm.pandas()

cli = argparse.ArgumentParser()
cli.add_argument('oh_output_file', type=str)
args = cli.parse_args()

# The markdown files go to a folder named after the input, with '.jsonl'
# replaced by '.viz'.
output_md_folder = args.oh_output_file.replace('.jsonl', '.viz')
print(f'Converting {args.oh_output_file} to markdown files in {output_md_folder}')

oh_format = pd.read_json(args.oh_output_file, orient='records', lines=True)

# model name is the folder name of oh_output_file
containing_folder = os.path.dirname(args.oh_output_file)
model_name = os.path.basename(containing_folder)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_history_to_str(history):
    """Render a history of serialized events as one text block.

    Each entry becomes a section headed ``## <position>| <EventClassName>``
    followed by ``str()`` of the event rebuilt with ``event_from_dict``.
    Sections are separated by a blank line and a rule of 100 dashes. An entry
    that is a list is a legacy (action, observation) pair: both halves are
    rendered under the same position with the separator between them.

    :param history: Iterable of event dicts or legacy two-element lists.
    :return: The rendered text; an empty string for an empty history.
    """
    divider = '\n\n' + '-' * 100 + '\n'

    def render(position, raw_event):
        event_obj = event_from_dict(raw_event)
        return f'## {position}| {event_obj.__class__.__name__}\n\n' + str(event_obj)

    sections = []
    for index, event in enumerate(history):
        position = index + 1
        if isinstance(event, list):
            # "event" is a legacy pair of (action, observation)
            action_text = render(position, event[0])
            observation_text = render(position, event[1])
            sections.append(action_text + divider + observation_text)
        else:
            # "event" is a single event
            sections.append(render(position, event))
    return divider.join(sections)
|
||||||
|
|
||||||
|
|
||||||
|
def write_row_to_md_file(row):
    """Write one output row as ``<id>.md`` inside ``output_md_folder``.

    The file contains two title lines (coverage and mutation score), the
    row's metadata as JSON, the rendered history and the test suite.

    :param row: Mapping-like row (e.g. a pandas row) with ``id``,
        ``metadata``, ``history`` and a test suite, either at top level
        (``test_suite``) or nested (``test_result['test_suite']``).
    :raises ValueError: If no test suite can be found in the row.
    """
    test_result = row['test_result'] if 'test_result' in row else None
    if 'test_suite' in row:
        test_suite = row['test_suite']
    elif isinstance(test_result, dict) and 'test_suite' in test_result:
        test_suite = test_result['test_suite']
    else:
        raise ValueError(f'Row {row} does not have a test_suite')

    # A pandas row contains the 'report' key even when this particular row
    # has no report (the cell is NaN), so check the value, not just the key.
    report = row['report'] if 'report' in row else None
    if isinstance(report, dict):
        coverage = report.get('coverage', 0)
        mutation = report.get('mutation_score', 0)
    else:
        coverage = None
        mutation = None

    instance_id = row['id']
    filename = f'{instance_id}.md'
    os.makedirs(output_md_folder, exist_ok=True)
    filepath = os.path.join(output_md_folder, filename)

    # Explicit UTF-8: histories may contain non-ASCII text that the platform
    # default encoding cannot represent.
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(f'# {instance_id} (coverage: {coverage})\n')
        f.write(f'# {instance_id} (mutation score: {mutation})\n')

        # MetaData
        f.write('## MetaData\n')
        f.write('```json\n')
        f.write(json.dumps(row['metadata'], indent=2))
        f.write('\n```\n')

        # Trajectory
        f.write('## History\n')
        f.write(convert_history_to_str(row['history']))

        f.write('## Test Suite\n')
        f.write(f'{test_suite}\n')
|
||||||
|
|
||||||
|
|
||||||
|
oh_format.progress_apply(write_row_to_md_file, axis=1)
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from evaluation.swe_bench.eval_infer import process_git_patch
|
||||||
|
|
||||||
|
cli = argparse.ArgumentParser()
cli.add_argument('oh_output_file', type=str)
args = cli.parse_args()

# The converted file sits next to the input, with '.jsonl' replaced by
# '.swebench.jsonl'.
output_filepath = args.oh_output_file.replace('.jsonl', '.swebench.jsonl')
print(f'Converting {args.oh_output_file} to {output_filepath}')

oh_format = pd.read_json(args.oh_output_file, orient='records', lines=True)

# model name is the folder name of oh_output_file
containing_folder = os.path.dirname(args.oh_output_file)
model_name = os.path.basename(containing_folder)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_row_to_swebench_format(row):
    """Build the prediction record for one output row.

    The patch is taken from ``row['git_patch']`` when present, otherwise from
    ``row['test_result']['git_patch']``, and passed through
    ``process_git_patch``.

    :param row: Mapping-like row with ``instance_id`` and a git patch.
    :return: Dict with ``instance_id``, ``model_patch`` and
        ``model_name_or_path`` (the module-level ``model_name``).
    :raises ValueError: If the row carries no git patch in either place.
    """
    if 'git_patch' not in row:
        has_nested_patch = 'test_result' in row and 'git_patch' in row['test_result']
        if not has_nested_patch:
            raise ValueError(f'Row {row} does not have a git_patch')
        patch_text = row['test_result']['git_patch']
    else:
        patch_text = row['git_patch']

    return dict(
        instance_id=row['instance_id'],
        model_patch=process_git_patch(patch_text),
        model_name_or_path=model_name,
    )
|
||||||
|
|
||||||
|
|
||||||
|
swebench_format = oh_format.apply(convert_row_to_swebench_format, axis=1)
|
||||||
|
swebench_format.to_json(output_filepath, lines=True, orient='records')
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
# CLI: where to write the JSONL file, plus which dataset / split to pull from.
parser = argparse.ArgumentParser()
parser.add_argument('output_filepath', type=str, help='Path to save the output file')
parser.add_argument(
    '--dataset_name',
    type=str,
    default='kjain14/testgeneval',
    help='Name of the dataset to download',
)
parser.add_argument('--split', type=str, default='test', help='Split to download')
args = parser.parse_args()

dataset = load_dataset(args.dataset_name, split=args.split)
output_filepath = args.output_filepath
print(
    f'Downloading gold test suites from {args.dataset_name} (split: {args.split}) to {output_filepath}'
)

# Keep only the instance id and its reference test source for each dataset row.
test_suites = []
for row in dataset:
    test_suites.append(
        {'instance_id': row['instance_id'], 'test_suite': row['test_src']}
    )
print(f'{len(test_suites)} test suites loaded')

# One JSON record per line.
suites_frame = pd.DataFrame(test_suites)
suites_frame.to_json(output_filepath, lines=True, orient='records')
print(f'Test suites saved to {output_filepath}')
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
from openhands.events.serialization import event_from_dict
|
||||||
|
from openhands.events.utils import get_pairs_from_events
|
||||||
|
|
||||||
|
# Substrings searched for in each raw output line; the first keyword found in a
# line is tallied in the error counter and the line is counted as an error line.
ERROR_KEYWORDS = [
    'Agent encountered an error while processing the last action',
    'APIError',
    'Action execution failed',
]
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Summarize an output JSONL file: coverage / mutation-score averages,
    # empty-suite and error counts, number of turns and cost per instance.
    parser = argparse.ArgumentParser()
    parser.add_argument('output_file', type=str, help='The file to summarize')
    args = parser.parse_args()

    with open(args.output_file, 'r') as file:
        # Skip blank lines (e.g. a trailing newline) so json.loads never sees ''.
        lines = [line for line in file if line.strip()]

    num_lines = len(lines)
    if num_lines == 0:
        # Every statistic below divides by num_lines; fail with a clear message
        # instead of a ZeroDivisionError.
        raise SystemExit(f'No records found in {args.output_file}')

    num_error_lines = 0
    num_agent_stuck_in_loop = 0

    coverage = 0
    mutation_score = 0
    num_empty_suite = 0

    error_counter = Counter()

    main_agent_cost = []
    editor_cost = []
    num_turns = []

    for line in lines:
        _d = json.loads(line)

        # Cost
        costs = (_d.get('metrics') or {}).get('costs', []) or []
        _cur_main_agent_cost = 0
        _cur_editor_cost = 0
        for cost in costs:
            if isinstance(cost, (int, float)):
                # backward compatible: plain numbers (int or float) are main-agent cost
                _cur_main_agent_cost += cost
            else:
                if 'draft_editor' in cost['model']:
                    _cur_editor_cost += cost['cost']
                else:
                    _cur_main_agent_cost += cost['cost']

        main_agent_cost.append(_cur_main_agent_cost)
        editor_cost.append(_cur_editor_cost)

        # Turn status
        history = _d.get('history', [])
        events = [event_from_dict(event) for event in history]
        pairs = get_pairs_from_events(events)
        num_turns.append(len(pairs))

        # Suite & resolve status
        suite = _d.get('test_result', {}).get('test_suite', '')
        if suite == '':
            # Cost and turns were already recorded above; an empty suite
            # contributes nothing to coverage / mutation score / error stats.
            num_empty_suite += 1
            continue

        report = _d.get('report', {}) or {}
        coverage += report.get('coverage', 0)
        mutation_score += report.get('mutation_score', 0)

        # Error
        error = _d.get('error', None)

        if error is not None and isinstance(error, str):
            agent_stuck_in_loop = 'Agent got stuck in a loop' in error
            contains_error = bool(error) and not agent_stuck_in_loop
            if agent_stuck_in_loop:
                error_counter['Agent got stuck in a loop'] += 1
                num_agent_stuck_in_loop += 1
            elif contains_error:
                error_counter[error] += 1
            # A record with an explicit error string is not scanned for keywords.
            continue

        for keyword in ERROR_KEYWORDS:
            if keyword in line:
                error_counter[keyword] += 1
                num_error_lines += 1
                break

    # print the error counter (with percentage)
    print(
        f'Average coverage for {num_lines} ({coverage / num_lines * 100:.2f}%)'
    )
    print(
        f'Average mutation score for {num_lines} ({mutation_score / num_lines * 100:.2f}%)'
    )

    print(
        f'Number of empty suite: {num_empty_suite} / {num_lines} ({num_empty_suite / num_lines * 100:.2f}%)'
    )
    print(
        f'Number of error lines: {num_error_lines} / {num_lines} ({num_error_lines / num_lines * 100:.2f}%)'
    )
    print(
        f'Number of agent stuck in loop: {num_agent_stuck_in_loop} / {num_lines} ({num_agent_stuck_in_loop / num_lines * 100:.2f}%)'
    )
    # Cost / turn lists are appended once per record before any `continue`.
    assert len(num_turns) == num_lines
    assert len(main_agent_cost) == num_lines
    assert len(editor_cost) == num_lines
    print('## Statistics')
    print(f'Avg. num of turns per instance: {sum(num_turns) / num_lines:.2f}')
    print(f'Avg. agent cost per instance: {sum(main_agent_cost) / num_lines:.2f} USD')
    print(f'Avg. editor cost per instance: {sum(editor_cost) / num_lines:.2f} USD')
    print(
        f'Avg. total cost per instance: {(sum(main_agent_cost) + sum(editor_cost)) / num_lines:.2f} USD'
    )

    print('## Detailed error breakdown:')
    for error, count in error_counter.items():
        print(f'{error}: {count} ({count / num_lines * 100:.2f}%)')
|
||||||
+53
@@ -0,0 +1,53 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -eo pipefail

# Positional arguments: input jsonl, worker count, dataset, split, skip-mutation flag.
INPUT_FILE=$1
NUM_WORKERS=$2
DATASET=$3
SPLIT=$4
SKIP_MUTATION=$5

if [ -z "$INPUT_FILE" ]; then
  echo "INPUT_FILE not specified (should be a path to a jsonl file)"
  exit 1
fi

if [ -z "$DATASET" ]; then
  echo "DATASET not specified, use default kjain14/testgenevallite"
  DATASET="kjain14/testgenevallite"
fi

if [ -z "$SPLIT" ]; then
  echo "SPLIT not specified, use default test"
  SPLIT="test"
fi

if [ -z "$NUM_WORKERS" ]; then
  echo "NUM_WORKERS not specified, use default 1"
  NUM_WORKERS=1
fi

echo "... Evaluating on $INPUT_FILE ..."

# Build the command as an array so that arguments containing spaces or glob
# characters reach Python intact (a flat string run through `eval` is re-parsed).
COMMAND=(poetry run python evaluation/benchmarks/testgeneval/eval_infer.py
  --eval-num-workers "$NUM_WORKERS"
  --input-file "$INPUT_FILE"
  --dataset "$DATASET"
  --split "$SPLIT")

if [ "$SKIP_MUTATION" == "true" ]; then
  echo "Skipping mutation evaluation"
  COMMAND+=(--skip_mutation)
fi

# EVAL_LIMIT is read from the environment, not from the positional arguments.
if [ -n "$EVAL_LIMIT" ]; then
  echo "EVAL_LIMIT: $EVAL_LIMIT"
  COMMAND+=(--eval-n-limit "$EVAL_LIMIT")
fi

echo "${COMMAND[*]}"
# Run the command
"${COMMAND[@]}"
|
||||||
|
|
||||||
|
# update the output with evaluation results
|
||||||
|
# poetry run python evaluation/benchmarks/testgeneval/scripts/eval/update_output_with_eval.py $INPUT_FILE
|
||||||
+122
@@ -0,0 +1,122 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -eo pipefail

# Provides checkout_eval_branch, get_openhands_version and checkout_original_branch.
source "evaluation/utils/version_control.sh"

MODEL_CONFIG=$1
COMMIT_HASH=$2
AGENT=$3
EVAL_LIMIT=$4
MAX_ITER=$5
NUM_WORKERS=$6
DATASET=$7
SPLIT=$8
N_RUNS=$9
ZERO_SHOT_PATH=${10} # New argument for zero-shot path

# --llm-config needs a value; fail before switching branches rather than later.
if [ -z "$MODEL_CONFIG" ]; then
  echo "MODEL_CONFIG not specified (first argument)" >&2
  exit 1
fi

if [ -z "$NUM_WORKERS" ]; then
  NUM_WORKERS=1
  echo "Number of workers not specified, use default $NUM_WORKERS"
fi
checkout_eval_branch

if [ -z "$AGENT" ]; then
  echo "Agent not specified, use default CodeActAgent"
  AGENT="CodeActAgent"
fi

if [ -z "$MAX_ITER" ]; then
  echo "MAX_ITER not specified, use default 100"
  MAX_ITER=100
fi

if [ -z "$USE_INSTANCE_IMAGE" ]; then
  echo "USE_INSTANCE_IMAGE not specified, use default true"
  USE_INSTANCE_IMAGE=true
fi

if [ -z "$RUN_WITH_BROWSING" ]; then
  echo "RUN_WITH_BROWSING not specified, use default false"
  RUN_WITH_BROWSING=false
fi


# Default to the TestGenEval lite dataset, matching the companion eval_infer.sh
# (this script drives testgeneval/run_infer.py, not the SWE-bench runner).
if [ -z "$DATASET" ]; then
  echo "DATASET not specified, use default kjain14/testgenevallite"
  DATASET="kjain14/testgenevallite"
fi

if [ -z "$SPLIT" ]; then
  echo "SPLIT not specified, use default test"
  SPLIT="test"
fi

export USE_INSTANCE_IMAGE=$USE_INSTANCE_IMAGE
echo "USE_INSTANCE_IMAGE: $USE_INSTANCE_IMAGE"
export RUN_WITH_BROWSING=$RUN_WITH_BROWSING
echo "RUN_WITH_BROWSING: $RUN_WITH_BROWSING"

get_openhands_version

echo "AGENT: $AGENT"
echo "OPENHANDS_VERSION: $OPENHANDS_VERSION"
echo "MODEL_CONFIG: $MODEL_CONFIG"
echo "DATASET: $DATASET"
echo "SPLIT: $SPLIT"

# Default to NOT use Hint
if [ -z "$USE_HINT_TEXT" ]; then
  export USE_HINT_TEXT=false
fi
echo "USE_HINT_TEXT: $USE_HINT_TEXT"
EVAL_NOTE="$OPENHANDS_VERSION"
# if not using Hint, add -no-hint to the eval note
if [ "$USE_HINT_TEXT" = false ]; then
  EVAL_NOTE="$EVAL_NOTE-no-hint"
fi

if [ "$RUN_WITH_BROWSING" = true ]; then
  EVAL_NOTE="$EVAL_NOTE-with-browsing"
fi

if [ -n "$EXP_NAME" ]; then
  EVAL_NOTE="$EVAL_NOTE-$EXP_NAME"
fi

# Run one inference pass; $1 is the eval note used to label this run.
function run_eval() {
  local eval_note=$1
  # Array form keeps every argument as a single word (no eval re-parsing).
  COMMAND=(poetry run python evaluation/benchmarks/testgeneval/run_infer.py
    --agent-cls "$AGENT"
    --llm-config "$MODEL_CONFIG"
    --max-iterations "$MAX_ITER"
    --eval-num-workers "$NUM_WORKERS"
    --eval-note "$eval_note"
    --dataset "$DATASET"
    --split "$SPLIT")

  if [ -n "$EVAL_LIMIT" ]; then
    echo "EVAL_LIMIT: $EVAL_LIMIT"
    COMMAND+=(--eval-n-limit "$EVAL_LIMIT")
  fi

  if [ -n "$ZERO_SHOT_PATH" ]; then
    echo "ZERO_SHOT_PATH: $ZERO_SHOT_PATH"
    COMMAND+=(--testfile_start --zero_shot_path "$ZERO_SHOT_PATH")
  fi

  "${COMMAND[@]}"
}

unset SANDBOX_ENV_GITHUB_TOKEN # prevent the agent from using the github token to push
if [ -z "$N_RUNS" ]; then
  N_RUNS=1
  echo "N_RUNS not specified, use default $N_RUNS"
fi

for i in $(seq 1 "$N_RUNS"); do
  current_eval_note="$EVAL_NOTE-run_$i"
  echo "EVAL_NOTE: $current_eval_note"
  run_eval "$current_eval_note"
done

checkout_original_branch
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
source ~/.bashrc
SWEUTIL_DIR=/swe_util

# FIXME: Cannot read SWE_INSTANCE_ID from the environment variable
# SWE_INSTANCE_ID=django__django-11099
if [ -z "$SWE_INSTANCE_ID" ]; then
    echo "Error: SWE_INSTANCE_ID is not set." >&2
    exit 1
fi

# Read swe-bench-instance.json and extract the item whose instance_id matches
item=$(jq --arg INSTANCE_ID "$SWE_INSTANCE_ID" '.[] | select(.instance_id == $INSTANCE_ID)' "$SWEUTIL_DIR/eval_data/instances/swe-bench-instance.json")

if [[ -z "$item" ]]; then
    echo "No item found for the provided instance ID."
    exit 1
fi

# Workspace name is "<repo>__<version>" with every "/" replaced by "__"
WORKSPACE_NAME=$(echo "$item" | jq -r '(.repo | tostring) + "__" + (.version | tostring) | gsub("/"; "__")')

echo "WORKSPACE_NAME: $WORKSPACE_NAME"

# Clear the workspace
if [ -d /workspace ]; then
    rm -rf /workspace/*
else
    mkdir /workspace
fi
# Link the repo into the workspace (symlink to /testbed, not a copy)
if [ -d "/workspace/$WORKSPACE_NAME" ]; then
    rm -rf "/workspace/$WORKSPACE_NAME"
fi
mkdir -p /workspace
ln -s /testbed "/workspace/$WORKSPACE_NAME"

# Activate instance-specific environment
. /opt/miniconda3/etc/profile.d/conda.sh
conda activate testbed
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e
EVAL_WORKSPACE="evaluation/swe_bench/eval_workspace"
mkdir -p "$EVAL_WORKSPACE"

# 1. Prepare REPO
echo "==== Prepare SWE-bench repo ===="
OH_SWE_BENCH_REPO_PATH="https://github.com/All-Hands-AI/SWE-bench.git"
OH_SWE_BENCH_REPO_BRANCH="eval"
git clone -b "$OH_SWE_BENCH_REPO_BRANCH" "$OH_SWE_BENCH_REPO_PATH" "$EVAL_WORKSPACE/OH-SWE-bench"

# 2. Prepare DATA
echo "==== Prepare SWE-bench data ===="
EVAL_IMAGE=ghcr.io/all-hands-ai/eval-swe-bench:builder_with_conda
# docker -v needs an absolute host path
EVAL_WORKSPACE=$(realpath "$EVAL_WORKSPACE")
chmod +x "$EVAL_WORKSPACE/OH-SWE-bench/swebench/harness/prepare_data.sh"
# Remove data left over from a previous run before regenerating it
if [ -d "$EVAL_WORKSPACE/eval_data" ]; then
    rm -r "$EVAL_WORKSPACE/eval_data"
fi
# Run as the invoking user so generated files are not owned by root
docker run \
    -v "$EVAL_WORKSPACE:/workspace" \
    -w /workspace \
    -u "$(id -u):$(id -g)" \
    -e HF_DATASETS_CACHE="/tmp" \
    --rm -it "$EVAL_IMAGE" \
    bash -c "cd OH-SWE-bench/swebench/harness && /swe_util/miniforge3/bin/conda run -n swe-bench-eval ./prepare_data.sh && mv eval_data /workspace/"
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
set -e

# assert user name is `root`
if [ "$USER" != "root" ]; then
    echo "Error: This script is intended to be run by the 'root' user only." >&2
    exit 1
fi

source ~/.bashrc

SWEUTIL_DIR=/swe_util

# Create logs directory
LOG_DIR=/openhands/logs
mkdir -p "$LOG_DIR" && chmod 777 "$LOG_DIR"

# FIXME: Cannot read SWE_INSTANCE_ID from the environment variable
# SWE_INSTANCE_ID=django__django-11099
if [ -z "$SWE_INSTANCE_ID" ]; then
    echo "Error: SWE_INSTANCE_ID is not set." >&2
    exit 1
fi

# Read the swe-bench-test-lite.json file and extract the required item based on instance_id
item=$(jq --arg INSTANCE_ID "$SWE_INSTANCE_ID" '.[] | select(.instance_id == $INSTANCE_ID)' "$SWEUTIL_DIR/eval_data/instances/swe-bench-test-lite.json")

if [[ -z "$item" ]]; then
    echo "No item found for the provided instance ID."
    exit 1
fi

# Environment name is "<repo>__<version>" with every "/" replaced by "__"
CONDA_ENV_NAME=$(echo "$item" | jq -r '.repo + "__" + .version | gsub("/"; "__")')

echo "CONDA_ENV_NAME: $CONDA_ENV_NAME"

SWE_TASK_DIR=/openhands/swe_tasks
mkdir -p "$SWE_TASK_DIR"
# Dump test_patch to $SWE_TASK_DIR/test.patch
echo "$item" | jq -r '.test_patch' > "$SWE_TASK_DIR/test.patch"
# Dump patch to $SWE_TASK_DIR/gold.patch
echo "$item" | jq -r '.patch' > "$SWE_TASK_DIR/gold.patch"
# Dump the item to $SWE_TASK_DIR/instance.json except for the "test_patch" and "patch" fields
echo "$item" | jq 'del(.test_patch, .patch)' > "$SWE_TASK_DIR/instance.json"

# Clear the workspace
rm -rf /workspace/*
# Copy repo to workspace
if [ -d "/workspace/$CONDA_ENV_NAME" ]; then
    rm -rf "/workspace/$CONDA_ENV_NAME"
fi
cp -r "$SWEUTIL_DIR/eval_data/testbeds/$CONDA_ENV_NAME" /workspace

# Reset swe-bench testbed and install the repo
. "$SWEUTIL_DIR/miniforge3/etc/profile.d/conda.sh"
conda config --set changeps1 False
conda config --append channels conda-forge
conda activate swe-bench-eval

mkdir -p "$SWE_TASK_DIR/reset_testbed_temp"
mkdir -p "$SWE_TASK_DIR/reset_testbed_log_dir"
SWE_BENCH_DIR=/swe_util/OH-SWE-bench
# Capture the reset script's output; repo_path / test_cmd are parsed from it below
output=$(
    export PYTHONPATH=$SWE_BENCH_DIR && \
    cd "$SWE_BENCH_DIR" && \
    python swebench/harness/reset_swe_env.py \
        --swe_bench_tasks "$SWEUTIL_DIR/eval_data/instances/swe-bench-test.json" \
        --temp_dir "$SWE_TASK_DIR/reset_testbed_temp" \
        --testbed /workspace \
        --conda_path "$SWEUTIL_DIR/miniforge3" \
        --instance_id "$SWE_INSTANCE_ID" \
        --log_dir "$SWE_TASK_DIR/reset_testbed_log_dir" \
        --timeout 900 \
        --verbose
)

REPO_PATH=$(echo "$output" | awk -F': ' '/repo_path:/ {print $2}')
TEST_CMD=$(echo "$output" | awk -F': ' '/test_cmd:/ {print $2}')
echo "Repo Path: $REPO_PATH"
echo "Test Command: $TEST_CMD"

# Persist the values for later shells. %q shell-escapes them so quotes, `$` or
# backticks inside TEST_CMD cannot corrupt ~/.bashrc or be expanded when it is sourced.
printf 'export SWE_BENCH_DIR=%q\n' "$SWE_BENCH_DIR" >> ~/.bashrc
printf 'export REPO_PATH=%q\n' "$REPO_PATH" >> ~/.bashrc
printf 'export TEST_CMD=%q\n' "$TEST_CMD" >> ~/.bashrc

if [[ "$REPO_PATH" == "None" ]]; then
    echo "Error: Failed to retrieve repository path. Tests may not have passed or output was not as expected." >&2
    exit 1
fi

# Activate instance-specific environment
. "$SWEUTIL_DIR/miniforge3/etc/profile.d/conda.sh"
conda activate "$CONDA_ENV_NAME"

set +e
|
||||||
@@ -0,0 +1,327 @@
|
|||||||
|
import ast
|
||||||
|
import re
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from evaluation.benchmarks.testgeneval.constants import TestStatus
|
||||||
|
from evaluation.benchmarks.testgeneval.log_parsers import (
|
||||||
|
MAP_REPO_TO_PARSER,
|
||||||
|
parse_log_pytest,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def indent_text(text, indent_level):
    """Prefix every non-blank line of ``text`` with ``indent_level`` spaces.

    Lines that are empty or contain only whitespace are left untouched.
    """
    prefix = ' ' * indent_level
    indented = []
    for raw_line in text.split('\n'):
        if raw_line.strip():
            indented.append(prefix + raw_line)
        else:
            indented.append(raw_line)
    return '\n'.join(indented)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_preamble_classes_and_functions(code):
    """Split a test file's source text into preamble, classes and functions using regexes.

    Args:
        code: Source text of the test file.

    Returns:
        A 3-tuple ``(preamble, classes, test_functions)`` where
        - ``preamble`` is the stripped text before the first class/function found,
          or the whole stripped text when none is found;
        - ``classes`` is a list of ``(class_prefix, methods, start_index)`` tuples;
          ``class_prefix`` is the matched class header plus any body text before
          the first method, and ``methods`` is a list of
          ``(method_header, method_body)`` string pairs;
        - ``test_functions`` is a list of ``(function_body, start_index)`` tuples.
    """
    # A `class Name(...):` line at column 0, optionally preceded by decorator lines.
    class_pattern = re.compile(
        r'(?P<decorators>(?:^@[^\r\n]*(?:\r?\n(?:[ \t]+[^\r\n]*|^\)[^\r\n]*)*)*\r?\n)*?)'
        r'^class\s+([\w]+)(?:\([^)]*\))?:',  # the class line
        re.MULTILINE,
    )
    # Capture methods with or without decorators
    method_pattern = re.compile(r'(^(\s*@.*\s*)*^\s*def\s+[\w_]+\(.*\):)', re.MULTILINE)

    # Capture functions with or without decorators
    function_pattern = re.compile(
        r'(?P<decorators>(?:^@[^\r\n]*(?:\r?\n(?:[ \t]+[^\r\n]*|^\)[^\r\n]*)*)*\r?\n)*?)'
        r'^def\s+([\w_]+)\(.*\):',  # the function line
        re.MULTILINE,
    )

    preamble = ''
    classes = []
    test_functions = []

    # Offset into `code` from which the next class/function search starts.
    current_position = 0

    def extract_class_body(code: str, start_index: int) -> Tuple[str, int]:
        """
        Extracts the body of a class from the given code starting from the specified index.
        Returns the class body and the end index of the class body.
        """
        if not code or start_index < 0 or start_index >= len(code):
            raise ValueError('Invalid code or start index')

        # Split the code into lines
        lines = code[start_index:].split('\n')
        # NOTE(review): class_body_lines is filled below but never read; the
        # returned body is sliced from `code` using end_index instead.
        class_body_lines = []

        # Find the starting indentation level of the class definition
        # NOTE(review): the caller passes class_match.end(), so lines[0] is
        # whatever follows the class header's ':' on that line — confirm that the
        # indentation measured here is the intended reference.
        class_start_line = lines[0]
        start_indent = len(class_start_line) - len(class_start_line.lstrip())

        inside_multiline_comment = False
        end_index = start_index
        for i, line in enumerate(lines[1:], start=1):
            stripped_line = line.strip()
            current_indent = len(line) - len(line.lstrip())

            # Handle multiline comments or docstrings
            # NOTE(review): any line that starts with triple quotes toggles the flag
            # once, including a one-line docstring that opens and closes on the
            # same line — confirm that is intended.
            if stripped_line.startswith('"""') or stripped_line.startswith("'''"):
                if inside_multiline_comment:
                    inside_multiline_comment = False
                else:
                    inside_multiline_comment = True

            if not inside_multiline_comment:
                # Stop when we reach a line with less indentation than the class definition
                if current_indent <= start_indent and stripped_line:
                    break

            # Add lines that are part of the class body
            class_body_lines.append(line)
            # Update the end index to the current line end
            end_index = start_index + len('\n'.join(lines[: i + 1])) + 1

        return code[start_index:end_index], end_index

    while current_position < len(code):
        # Find the next class and the next `def` from the current offset and
        # handle whichever starts first.
        class_match = class_pattern.search(code, current_position)
        method_match = method_pattern.search(code, current_position)

        if class_match and (
            not method_match or class_match.start() < method_match.start()
        ):
            # group(0) is the full matched text: decorators (if any) plus the class line.
            class_name = class_match.group(0)
            class_body, end_idx = extract_class_body(code, class_match.end())
            current_position = end_idx

            methods = []
            class_prefix = class_name
            set_prefix = False
            for method_match in method_pattern.finditer(class_body):
                method_name = method_match.group()
                method_start = method_match.start()
                if not set_prefix:
                    # Everything between the class header and the first method
                    # is kept as part of the class prefix.
                    class_prefix = class_name + class_body[:method_start]
                    set_prefix = True
                next_method = method_pattern.search(
                    class_body, method_start + len(method_name)
                )
                # A method's text runs up to the next method header, or to the
                # end of the class body for the last method.
                method_body = (
                    class_body[method_start : next_method.start()]
                    if next_method
                    else class_body[method_start:]
                )
                methods.append((method_name, method_body))

            classes.append((class_prefix, methods, class_match.start()))

        elif method_match:
            function_name = method_match.group(0)
            start_idx = method_match.start()

            # Extract the current function's indentation level
            lines = code[start_idx:].split('\n')
            current_indent = len(lines[0]) - len(lines[0].lstrip())

            next_function = function_pattern.search(
                code, start_idx + len(function_name)
            )
            # Only consider following functions that start before the next class.
            while next_function and (
                class_match is None or next_function.start() < class_match.start()
            ):
                # Calculate the indentation of the next function
                next_function_start = next_function.start()
                next_line = code[next_function_start:].split('\n', 1)[0]
                next_indent = len(next_line) - len(next_line.lstrip())

                # Check if the next function is top-level
                if next_indent <= current_indent:
                    break

                # Continue searching for the next top-level function
                next_function = function_pattern.search(
                    code, next_function.start() + len(next_function.group(0))
                )

            if next_function:
                next_function_start = next_function.start()
                # Never let a function body run into the next class.
                if class_match and next_function_start > class_match.start():
                    next_function_start = class_match.start()
                function_body = code[start_idx:next_function_start]
            else:
                function_body = code[start_idx:]

            test_functions.append((function_body, start_idx))
            current_position = start_idx + len(function_body)

        else:
            # Neither a class nor a function remains.
            break

    # The preamble is everything before the earliest class or function.
    if classes and test_functions:
        preamble = code[: min(classes[0][2], test_functions[0][1])]
    else:
        preamble = (
            code[: classes[0][2]]
            if classes
            else code[: test_functions[0][1]]
            if test_functions
            else code
        )

    return preamble.strip(), classes, test_functions
|
||||||
|
|
||||||
|
|
||||||
|
def _method_name_from_header(header: str) -> str:
    """Return the function name from a matched method header (decorators + def line)."""
    match = re.search(r'\bdef\s+(\w+)', header)
    if match:
        return match.group(1)
    # No `def` found: keep the legacy best-effort extraction.
    return header.split('.')[-1].split('(')[0].strip().split(' ')[-1]


def filter_passing_tests(
    test_content: str, test_output: str, repo: str
) -> Tuple[str, List[str], List[str]]:
    """
    Filter tests based on their execution results (regex-based).

    Args:
        test_content: Source text of the test file.
        test_output: Raw log of the test run, parsed by the repo's log parser
            (``parse_log_pytest`` when the repo has no dedicated parser).
        repo: Repository identifier used to select the log parser.

    Returns:
        Tuple containing:
        - Modified test content with only passing tests ('' when nothing passed)
        - List of passing test names
        - List of failing test names
    """
    # Parse test results using appropriate parser
    parser = MAP_REPO_TO_PARSER.get(repo, parse_log_pytest)
    test_results = parser(test_output)
    # Get passing and failing tests
    passing_tests = []
    failing_tests = []
    for test_name, status in test_results.items():
        if status == TestStatus.PASSED.value:
            passing_tests.append(test_name)
        else:
            failing_tests.append(test_name)

    if not passing_tests:
        return '', passing_tests, failing_tests

    def overlaps_failing(name: str) -> bool:
        # Loose substring match in both directions against every failing test id.
        return any(name in failed or failed in name for failed in failing_tests)

    # Extract test components
    preamble, classes, functions = extract_preamble_classes_and_functions(test_content)

    # Filter classes to only include passing methods
    filtered_classes = []
    for class_name, methods, start_idx in classes:
        non_fail_methods = []
        for method_name, method_body in methods:
            # `method_name` is the whole matched header and may include decorator
            # lines or dotted default values, so take the identifier that follows
            # `def` rather than splitting on '.' and '('.
            method_full_name = _method_name_from_header(method_name)
            if not overlaps_failing(method_full_name):
                non_fail_methods.append((method_name, method_body))

        if non_fail_methods:
            filtered_classes.append((class_name, non_fail_methods, start_idx))

    # Filter standalone functions
    filtered_functions = []
    for func_body, start_idx in functions:
        func_name = func_body.split('def ')[1].split('(')[0].strip()
        if overlaps_failing(func_name):
            continue

        filtered_functions.append((func_body, start_idx))

    # Reconstruct test content with only passing tests
    content_parts = [preamble]

    # Add filtered classes
    for class_name, methods, _ in filtered_classes:
        class_content = class_name + '\n'
        for _, method_body in methods:
            class_content += method_body + '\n'
        content_parts.append(class_content)

    # Add filtered functions
    for func_body, _ in filtered_functions:
        content_parts.append(func_body)

    return '\n\n'.join(content_parts), passing_tests, failing_tests
|
||||||
|
|
||||||
|
|
||||||
|
def filter_tests(
    test_content: str, test_output: str, repo: str
) -> Tuple[str, List[str], List[str]]:
    """
    Filter tests using AST parsing to remove failing test functions from the test file.
    Non-test functions (e.g. setup or helper methods) and classes (even if all test methods are failing)
    are preserved.

    If AST processing fails (for example, because the test file cannot be parsed),
    this function falls back on the existing regex-based filtering (filter_passing_tests).

    Args:
        test_content: Source text of the test file.
        test_output: Raw log of the test run, parsed by the repo's log parser
            (``parse_log_pytest`` when the repo has no dedicated parser).
        repo: Repository identifier used to select the log parser.

    Returns:
        Tuple containing:
        - Modified test content (as a string) containing only passing tests.
        - List of passing test names.
        - List of failing test names.
    """
    try:
        # Attempt to parse the test file using the AST.
        tree = ast.parse(test_content)

        # Parse test results using the appropriate parser.
        parser = MAP_REPO_TO_PARSER.get(repo, parse_log_pytest)
        test_results = parser(test_output)
        passing_tests = [
            name
            for name, status in test_results.items()
            if status == TestStatus.PASSED.value
        ]
        failing_tests = [
            name
            for name, status in test_results.items()
            if status != TestStatus.PASSED.value
        ]

        # Helper function to decide if a test name should be considered failing.
        # The match is a loose substring check in both directions.
        def is_failing(name: str) -> bool:
            return any(name in ft or ft in name for ft in failing_tests)

        new_body = []
        for node in tree.body:
            # For top-level function definitions, only filter those that look like tests.
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                if node.name.startswith('test') and is_failing(node.name):
                    continue
                new_body.append(node)
            # For classes, filter out failing test methods but preserve other methods (e.g. setup).
            elif isinstance(node, ast.ClassDef):
                new_class_body = []
                for subnode in node.body:
                    if isinstance(subnode, (ast.FunctionDef, ast.AsyncFunctionDef)):
                        # Only consider filtering if the method is a test, so that
                        # helpers such as setUp are never dropped just because
                        # their name is a substring of a failing test id.
                        qualified_name = f'{node.name}.{subnode.name}'
                        if subnode.name.startswith('test') and (
                            is_failing(subnode.name) or is_failing(qualified_name)
                        ):
                            continue
                    new_class_body.append(subnode)
                # Always include the class even if no test methods remain, as it might
                # be referenced elsewhere (e.g. as a base class). A class body cannot
                # be empty, so fall back to a single `pass`.
                node.body = new_class_body or [ast.Pass()]
                new_body.append(node)

            else:
                new_body.append(node)

        tree.body = new_body

        # Reconstruct the source code from the filtered AST.
        # (Requires Python 3.9+ for ast.unparse; otherwise an exception will trigger the fallback.)
        new_test_content = ast.unparse(tree)
        return new_test_content, passing_tests, failing_tests

    except Exception:
        print('AST processing failed; falling back on regex-based filtering.')
        # If AST processing fails for any reason, fall back on the original regex-based filtering.
        return filter_passing_tests(test_content, test_output, repo)
|
||||||
@@ -0,0 +1,166 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from evaluation.benchmarks.testgeneval.constants import (
|
||||||
|
COVERAGE_PREFIX,
|
||||||
|
KEY_INSTANCE_ID,
|
||||||
|
MAP_REPO_VERSION_TO_SPECS,
|
||||||
|
TESTS_FAILED,
|
||||||
|
TESTS_SUFFIX,
|
||||||
|
UPDATE_TOX,
|
||||||
|
TestGenEvalInstance,
|
||||||
|
)
|
||||||
|
from evaluation.benchmarks.testgeneval.utils import (
|
||||||
|
get_test_directives,
|
||||||
|
)
|
||||||
|
|
||||||
|
DIFF_MODIFIED_FILE_REGEX = r'--- a/(.*)'
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TestSpec:
    """
    Evaluation recipe for a single benchmark instance.

    ``test_script_list`` and ``mutation_script_list`` hold shell commands, one
    per entry. The ``test_script`` and ``mutation_script`` properties render
    them as bash scripts that start with a shebang and ``set -uo pipefail``.
    ``-e`` is not set: the scripts should not exit early because tests need to
    be reverted at the end.
    """

    instance_id: str
    id: str
    repo: str
    version: str
    test_cmd: str
    code_file: str
    test_file: str
    baseline_covs: dict
    local_imports: list[str]
    test_script_list: list[str]
    mutation_script_list: list[str]

    @property
    def test_script(self):
        """Bash script text for the test run, terminated by a newline."""
        preamble = ['#!/bin/bash', 'set -uo pipefail']
        rendered = '\n'.join(preamble + self.test_script_list)
        return f'{rendered}\n'

    @property
    def mutation_script(self):
        """Bash script text for the mutation run, terminated by a newline."""
        preamble = ['#!/bin/bash', 'set -uo pipefail']
        rendered = '\n'.join(preamble + self.mutation_script_list)
        return f'{rendered}\n'
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_setup(specs, env_name, repo_directory, includes_tox=False):
    """
    Build the shell commands that prepare the environment before a run.

    Args:
        specs: mapping for the repo/version; the optional keys
            'eval_commands' (a list of commands) and 'install' (a single
            command) are spliced into the result when present.
        env_name: name of the conda environment to activate.
        repo_directory: path of the repository checkout to work in.
        includes_tox: when true, prepend ``UPDATE_TOX`` and add the
            ``add_coverage_tox`` call after the install step.

    Returns:
        A new list of shell command strings, in execution order.
    """
    activate_env = [
        'source /opt/miniconda3/bin/activate',
        f'conda activate {env_name}',
    ]
    enter_repo = f'cd {repo_directory}'

    commands = [UPDATE_TOX] if includes_tox else []
    commands.extend(activate_env)
    commands.append(enter_repo)

    if 'eval_commands' in specs:
        commands.extend(specs['eval_commands'])

    # for nonroot user
    commands.append(f'git config --global --add safe.directory {repo_directory}')
    commands.append(enter_repo)
    # This is just informational, so we have a record
    commands.extend(['git status', 'git show'])
    commands.extend(activate_env)

    if 'install' in specs:
        commands.append(specs['install'])

    if includes_tox:
        commands.append('add_coverage_tox "tox.ini"')

    commands.append('[ -f ".coveragerc" ] && rm ".coveragerc"')
    return commands
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_script_list(test_cmd, specs, env_name, repo_directory):
    """
    Build the shell commands that run the tests and then dump coverage.

    The setup commands come from ``make_test_setup``; tox handling is enabled
    whenever the substring 'tox' occurs in ``test_cmd``. If the test command
    fails, the script echoes ``TESTS_FAILED`` and ``TESTS_SUFFIX`` and exits
    with status 1. Otherwise it echoes ``TESTS_SUFFIX``, writes
    ``coverage.json`` and prints it after ``COVERAGE_PREFIX``.

    Returns:
        List of shell command strings, in execution order.
    """
    uses_tox = 'tox' in test_cmd
    script_lines = make_test_setup(specs, env_name, repo_directory, uses_tox)

    failure_banner = f'{TESTS_FAILED}\n{TESTS_SUFFIX}\n'
    run_tests = f'{test_cmd} || {{ echo "{failure_banner}" && exit 1; }}'

    script_lines.extend(
        [
            run_tests,
            f'echo "{TESTS_SUFFIX}"\n',
            'coverage json -o coverage.json',
            f'echo "{COVERAGE_PREFIX}"\n',
            'cat coverage.json',
        ]
    )
    return script_lines
|
||||||
|
|
||||||
|
|
||||||
|
def make_mutation_script_list(specs, env_name, repo_directory, mutation_timeout):
    """
    Build the shell commands for a cosmic-ray mutation run.

    After the shared setup from ``make_test_setup`` (tox handling off), the
    script initialises ``mutation.sqlite`` from ``mutation.toml``, executes the
    mutations under ``timeout``, then prints the report and the rate estimate.

    Args:
        specs: repo/version mapping forwarded to ``make_test_setup``.
        env_name: conda environment name forwarded to ``make_test_setup``.
        repo_directory: repository path forwarded to ``make_test_setup``.
        mutation_timeout: number of seconds passed to ``timeout``.
            NOTE(review): GNU ``timeout`` appears to treat a duration of 0 as
            "no timeout" - confirm callers never pass a non-positive value.

    Returns:
        List of shell command strings, in execution order.
    """
    script_lines = make_test_setup(specs, env_name, repo_directory)
    mutation_steps = [
        'cosmic-ray init mutation.toml mutation.sqlite',
        f'timeout {mutation_timeout}s cosmic-ray exec mutation.toml mutation.sqlite',
        'cr-report mutation.sqlite',
        'cr-rate mutation.sqlite --estimate --confidence 95.0',
    ]
    script_lines.extend(mutation_steps)
    return script_lines
|
||||||
|
|
||||||
|
|
||||||
|
def make_test_spec(
    instance: TestGenEvalInstance, mutation_timeout: int, buffer: int
) -> TestSpec:
    """
    Turn a dataset instance into a ``TestSpec``.

    Args:
        instance: the dataset record; returned unchanged if it is already a
            ``TestSpec``.
        mutation_timeout: overall time budget for the mutation run.
        buffer: amount subtracted from ``mutation_timeout`` before it is
            handed to ``make_mutation_script_list``.

    Returns:
        A ``TestSpec`` carrying the instance fields plus the generated test
        command, test script commands and mutation script commands.
    """
    if isinstance(instance, TestSpec):
        return instance

    # Read every required field up front, in a fixed order.
    fields = {
        'instance_id': instance[KEY_INSTANCE_ID],
        'id': instance['id'],
        'repo': instance['repo'],
        'version': instance['version'],
        'baseline_covs': instance['baseline_covs'],
        'code_file': instance['code_file'],
        'test_file': instance['test_file'],
        'local_imports': instance['local_imports'],
    }

    env_name = 'testbed'
    repo_directory = f'/{env_name}'
    specs = MAP_REPO_VERSION_TO_SPECS[fields['repo']][fields['version']]

    # Base test command for this repo/version followed by the test targets.
    test_cmd = ' '.join([specs['test_cmd'], *get_test_directives(instance)])

    test_script_list = make_test_script_list(test_cmd, specs, env_name, repo_directory)
    mutation_script_list = make_mutation_script_list(
        specs, env_name, repo_directory, mutation_timeout - buffer
    )

    return TestSpec(
        test_cmd=test_cmd,
        test_script_list=test_script_list,
        mutation_script_list=mutation_script_list,
        **fields,
    )
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from datasets import Dataset, load_dataset
|
||||||
|
|
||||||
|
from evaluation.benchmarks.testgeneval.constants import (
|
||||||
|
KEY_INSTANCE_ID,
|
||||||
|
TestGenEvalInstance,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_directives(instance: TestGenEvalInstance) -> list:
    """
    Work out which test target to pass to the test command for an instance.

    Args:
        instance (dict): task instance; 'repo' is always read, 'test_file'
            is read for every repo except 'swe-bench/humaneval'.
    Returns:
        directives (list): single-element list holding the test target
    """
    repo = instance['repo']

    # For seq2seq code repos, testing command is fixed
    if repo == 'swe-bench/humaneval':
        return ['test.py']

    # Every other non-Django repo runs the test file by its path under /testbed.
    if repo != 'django/django':
        return [f"/testbed/{instance['test_file']}"]

    # Django takes a dotted module path: drop the ".py" extension and the
    # "tests/" prefix, then turn slashes into dots.
    module_path = instance['test_file'].removesuffix('.py').removeprefix('tests/')
    return [module_path.replace('/', '.')]
|
||||||
|
|
||||||
|
|
||||||
|
def load_testgeneval_dataset(
    name='kjain14/testgeneval', split='test', ids=None
) -> list[TestGenEvalInstance]:
    """
    Load the dataset from Hugging Face Datasets or a local .json/.jsonl file.

    Args:
        name: Hugging Face dataset name, one of the short aliases
            ('testgeneval', 'testgeneval-lite', 'testgenevallite', 'lite'),
            or a path ending in '.json' / '.jsonl'.
        split: dataset split, used only when loading from Hugging Face.
        ids: optional iterable of values to keep; they are matched against
            each instance's 'id' field for both local and remote datasets.

    Returns:
        List of instances, restricted to ``ids`` when given.

    Raises:
        ValueError: if any requested id is not present in the dataset.
    """
    if name.endswith('.json') or name.endswith('.jsonl'):
        # Load from local .json/.jsonl file
        text = Path(name).read_text()
        if name.endswith('.jsonl'):
            # JSON Lines: one JSON document per line. Split on '\n' only so
            # that unusual line separators inside JSON strings are left alone.
            try:
                dataset = [
                    json.loads(line) for line in text.split('\n') if line.strip()
                ]
            except json.JSONDecodeError:
                # Not line-delimited (e.g. a pretty-printed array saved with a
                # .jsonl name): parse the file as one JSON document instead.
                dataset = json.loads(text)
            else:
                # A single line holding a whole JSON array is that array.
                if len(dataset) == 1 and isinstance(dataset[0], list):
                    dataset = dataset[0]
        else:
            dataset = json.loads(text)
    else:
        # Load from Hugging Face Datasets
        if name.lower() in {'testgeneval'}:
            name = 'kjain14/testgeneval'
        elif name.lower() in {'testgeneval-lite', 'testgenevallite', 'lite'}:
            name = 'kjain14/testgenevallite'
        dataset = cast(Dataset, load_dataset(name, split=split))

    if ids:
        ids = set(ids)
        # Validate and filter on the same key, so that a request which passes
        # the check cannot silently filter down to nothing.
        dataset_ids = {instance['id'] for instance in dataset}
        missing = ids - dataset_ids
        if missing:
            raise ValueError(
                (
                    "Some instance IDs not found in dataset!"
                    f"\nMissing IDs:\n{' '.join(missing)}"
                )
            )
        dataset = [instance for instance in dataset if instance['id'] in ids]
    return [cast(TestGenEvalInstance, instance) for instance in dataset]
|
||||||
@@ -34,7 +34,6 @@ from openhands.utils.async_utils import call_async_from_sync
|
|||||||
|
|
||||||
FAKE_RESPONSES = {
|
FAKE_RESPONSES = {
|
||||||
'CodeActAgent': fake_user_response,
|
'CodeActAgent': fake_user_response,
|
||||||
'DelegatorAgent': fake_user_response,
|
|
||||||
'VisualBrowsingAgent': fake_user_response,
|
'VisualBrowsingAgent': fake_user_response,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import userEvent from "@testing-library/user-event";
|
|||||||
import { afterEach, beforeEach, describe, expect, it, test, vi } from "vitest";
|
import { afterEach, beforeEach, describe, expect, it, test, vi } from "vitest";
|
||||||
import OpenHands from "#/api/open-hands";
|
import OpenHands from "#/api/open-hands";
|
||||||
import { PaymentForm } from "#/components/features/payment/payment-form";
|
import { PaymentForm } from "#/components/features/payment/payment-form";
|
||||||
|
import * as featureFlags from "#/utils/feature-flags";
|
||||||
|
|
||||||
describe("PaymentForm", () => {
|
describe("PaymentForm", () => {
|
||||||
|
const billingSettingsSpy = vi.spyOn(featureFlags, "BILLING_SETTINGS");
|
||||||
const getBalanceSpy = vi.spyOn(OpenHands, "getBalance");
|
const getBalanceSpy = vi.spyOn(OpenHands, "getBalance");
|
||||||
const createCheckoutSessionSpy = vi.spyOn(OpenHands, "createCheckoutSession");
|
const createCheckoutSessionSpy = vi.spyOn(OpenHands, "createCheckoutSession");
|
||||||
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
|
const getConfigSpy = vi.spyOn(OpenHands, "getConfig");
|
||||||
@@ -26,6 +28,7 @@ describe("PaymentForm", () => {
|
|||||||
GITHUB_CLIENT_ID: "123",
|
GITHUB_CLIENT_ID: "123",
|
||||||
POSTHOG_CLIENT_KEY: "456",
|
POSTHOG_CLIENT_KEY: "456",
|
||||||
});
|
});
|
||||||
|
billingSettingsSpy.mockReturnValue(true);
|
||||||
});
|
});
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import { useQuery } from "@tanstack/react-query";
|
import { useQuery } from "@tanstack/react-query";
|
||||||
import { useConfig } from "./use-config";
|
import { useConfig } from "./use-config";
|
||||||
import OpenHands from "#/api/open-hands";
|
import OpenHands from "#/api/open-hands";
|
||||||
|
import { BILLING_SETTINGS } from "#/utils/feature-flags";
|
||||||
|
|
||||||
export const useBalance = () => {
|
export const useBalance = () => {
|
||||||
const { data: config } = useConfig();
|
const { data: config } = useConfig();
|
||||||
@@ -8,6 +9,6 @@ export const useBalance = () => {
|
|||||||
return useQuery({
|
return useQuery({
|
||||||
queryKey: ["user", "balance"],
|
queryKey: ["user", "balance"],
|
||||||
queryFn: OpenHands.getBalance,
|
queryFn: OpenHands.getBalance,
|
||||||
enabled: config?.APP_MODE === "saas",
|
enabled: config?.APP_MODE === "saas" && BILLING_SETTINGS(),
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ load_dotenv()
|
|||||||
from openhands.agenthub import ( # noqa: E402
|
from openhands.agenthub import ( # noqa: E402
|
||||||
browsing_agent,
|
browsing_agent,
|
||||||
codeact_agent,
|
codeact_agent,
|
||||||
delegator_agent,
|
|
||||||
dummy_agent,
|
dummy_agent,
|
||||||
visualbrowsing_agent,
|
visualbrowsing_agent,
|
||||||
)
|
)
|
||||||
@@ -15,7 +14,6 @@ from openhands.controller.agent import Agent # noqa: E402
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
'Agent',
|
'Agent',
|
||||||
'codeact_agent',
|
'codeact_agent',
|
||||||
'delegator_agent',
|
|
||||||
'dummy_agent',
|
'dummy_agent',
|
||||||
'browsing_agent',
|
'browsing_agent',
|
||||||
'visualbrowsing_agent',
|
'visualbrowsing_agent',
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
import openhands
|
|
||||||
import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
|
import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
|
||||||
from openhands.controller.agent import Agent
|
from openhands.controller.agent import Agent
|
||||||
from openhands.controller.state.state import State
|
from openhands.controller.state.state import State
|
||||||
@@ -72,23 +70,17 @@ class CodeActAgent(Agent):
|
|||||||
codeact_enable_browsing=self.config.codeact_enable_browsing,
|
codeact_enable_browsing=self.config.codeact_enable_browsing,
|
||||||
codeact_enable_jupyter=self.config.codeact_enable_jupyter,
|
codeact_enable_jupyter=self.config.codeact_enable_jupyter,
|
||||||
codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
|
codeact_enable_llm_editor=self.config.codeact_enable_llm_editor,
|
||||||
|
llm=self.llm,
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2, ensure_ascii=False).replace("\\n", "\n")}'
|
f"TOOLS loaded for CodeActAgent: {', '.join([tool.get('function').get('name') for tool in self.tools])}"
|
||||||
)
|
)
|
||||||
self.prompt_manager = PromptManager(
|
self.prompt_manager = PromptManager(
|
||||||
microagent_dir=os.path.join(
|
|
||||||
os.path.dirname(os.path.dirname(openhands.__file__)),
|
|
||||||
'microagents',
|
|
||||||
)
|
|
||||||
if self.config.enable_prompt_extensions
|
|
||||||
else None,
|
|
||||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||||
disabled_microagents=self.config.disabled_microagents,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a ConversationMemory instance
|
# Create a ConversationMemory instance
|
||||||
self.conversation_memory = ConversationMemory(self.prompt_manager)
|
self.conversation_memory = ConversationMemory(self.config, self.prompt_manager)
|
||||||
|
|
||||||
self.condenser = Condenser.from_config(self.config.condenser)
|
self.condenser = Condenser.from_config(self.config.condenser)
|
||||||
logger.debug(f'Using condenser: {type(self.condenser)}')
|
logger.debug(f'Using condenser: {type(self.condenser)}')
|
||||||
@@ -168,7 +160,7 @@ class CodeActAgent(Agent):
|
|||||||
if not self.prompt_manager:
|
if not self.prompt_manager:
|
||||||
raise Exception('Prompt Manager not instantiated.')
|
raise Exception('Prompt Manager not instantiated.')
|
||||||
|
|
||||||
# Use conversation_memory to process events instead of calling events_to_messages directly
|
# Use ConversationMemory to process initial messages
|
||||||
messages = self.conversation_memory.process_initial_messages(
|
messages = self.conversation_memory.process_initial_messages(
|
||||||
with_caching=self.llm.is_caching_prompt_active()
|
with_caching=self.llm.is_caching_prompt_active()
|
||||||
)
|
)
|
||||||
@@ -180,12 +172,12 @@ class CodeActAgent(Agent):
|
|||||||
f'Processing {len(events)} events from a total of {len(state.history)} events'
|
f'Processing {len(events)} events from a total of {len(state.history)} events'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Use ConversationMemory to process events
|
||||||
messages = self.conversation_memory.process_events(
|
messages = self.conversation_memory.process_events(
|
||||||
condensed_history=events,
|
condensed_history=events,
|
||||||
initial_messages=messages,
|
initial_messages=messages,
|
||||||
max_message_chars=self.llm.config.max_message_chars,
|
max_message_chars=self.llm.config.max_message_chars,
|
||||||
vision_is_active=self.llm.vision_is_active(),
|
vision_is_active=self.llm.vision_is_active(),
|
||||||
enable_som_visual_browsing=self.config.enable_som_visual_browsing,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = self._enhance_messages(messages)
|
messages = self._enhance_messages(messages)
|
||||||
@@ -216,14 +208,7 @@ class CodeActAgent(Agent):
|
|||||||
# compose the first user message with examples
|
# compose the first user message with examples
|
||||||
self.prompt_manager.add_examples_to_initial_message(msg)
|
self.prompt_manager.add_examples_to_initial_message(msg)
|
||||||
|
|
||||||
# and/or repo/runtime info
|
elif msg.role == 'user':
|
||||||
if self.config.enable_prompt_extensions:
|
|
||||||
self.prompt_manager.add_info_to_initial_message(msg)
|
|
||||||
|
|
||||||
# enhance the user message with additional context based on keywords matched
|
|
||||||
if msg.role == 'user':
|
|
||||||
self.prompt_manager.enhance_message(msg)
|
|
||||||
|
|
||||||
# Add double newline between consecutive user messages
|
# Add double newline between consecutive user messages
|
||||||
if prev_role == 'user' and len(msg.content) > 0:
|
if prev_role == 'user' and len(msg.content) > 0:
|
||||||
# Find the first TextContent in the message to add newlines
|
# Find the first TextContent in the message to add newlines
|
||||||
|
|||||||
@@ -12,13 +12,13 @@ from litellm import (
|
|||||||
|
|
||||||
from openhands.agenthub.codeact_agent.tools import (
|
from openhands.agenthub.codeact_agent.tools import (
|
||||||
BrowserTool,
|
BrowserTool,
|
||||||
CmdRunTool,
|
|
||||||
FinishTool,
|
FinishTool,
|
||||||
IPythonTool,
|
IPythonTool,
|
||||||
LLMBasedFileEditTool,
|
LLMBasedFileEditTool,
|
||||||
StrReplaceEditorTool,
|
|
||||||
ThinkTool,
|
ThinkTool,
|
||||||
WebReadTool,
|
WebReadTool,
|
||||||
|
create_cmd_run_tool,
|
||||||
|
create_str_replace_editor_tool,
|
||||||
)
|
)
|
||||||
from openhands.core.exceptions import (
|
from openhands.core.exceptions import (
|
||||||
FunctionCallNotExistsError,
|
FunctionCallNotExistsError,
|
||||||
@@ -39,6 +39,7 @@ from openhands.events.action import (
|
|||||||
)
|
)
|
||||||
from openhands.events.event import FileEditSource, FileReadSource
|
from openhands.events.event import FileEditSource, FileReadSource
|
||||||
from openhands.events.tool import ToolCallMetadata
|
from openhands.events.tool import ToolCallMetadata
|
||||||
|
from openhands.llm import LLM
|
||||||
|
|
||||||
|
|
||||||
def combine_thought(action: Action, thought: str) -> Action:
|
def combine_thought(action: Action, thought: str) -> Action:
|
||||||
@@ -80,7 +81,7 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
|
|||||||
# CmdRunTool (Bash)
|
# CmdRunTool (Bash)
|
||||||
# ================================================
|
# ================================================
|
||||||
|
|
||||||
if tool_call.function.name == CmdRunTool['function']['name']:
|
if tool_call.function.name == create_cmd_run_tool()['function']['name']:
|
||||||
if 'command' not in arguments:
|
if 'command' not in arguments:
|
||||||
raise FunctionCallValidationError(
|
raise FunctionCallValidationError(
|
||||||
f'Missing required argument "command" in tool call {tool_call.function.name}'
|
f'Missing required argument "command" in tool call {tool_call.function.name}'
|
||||||
@@ -131,7 +132,10 @@ def response_to_actions(response: ModelResponse) -> list[Action]:
|
|||||||
start=arguments.get('start', 1),
|
start=arguments.get('start', 1),
|
||||||
end=arguments.get('end', -1),
|
end=arguments.get('end', -1),
|
||||||
)
|
)
|
||||||
elif tool_call.function.name == StrReplaceEditorTool['function']['name']:
|
elif (
|
||||||
|
tool_call.function.name
|
||||||
|
== create_str_replace_editor_tool()['function']['name']
|
||||||
|
):
|
||||||
if 'command' not in arguments:
|
if 'command' not in arguments:
|
||||||
raise FunctionCallValidationError(
|
raise FunctionCallValidationError(
|
||||||
f'Missing required argument "command" in tool call {tool_call.function.name}'
|
f'Missing required argument "command" in tool call {tool_call.function.name}'
|
||||||
@@ -219,8 +223,22 @@ def get_tools(
|
|||||||
codeact_enable_browsing: bool = False,
|
codeact_enable_browsing: bool = False,
|
||||||
codeact_enable_llm_editor: bool = False,
|
codeact_enable_llm_editor: bool = False,
|
||||||
codeact_enable_jupyter: bool = False,
|
codeact_enable_jupyter: bool = False,
|
||||||
|
llm: LLM | None = None,
|
||||||
) -> list[ChatCompletionToolParam]:
|
) -> list[ChatCompletionToolParam]:
|
||||||
tools = [CmdRunTool, ThinkTool, FinishTool]
|
SIMPLIFIED_TOOL_DESCRIPTION_LLM_SUBSTRS = ['gpt-', 'o3', 'o1']
|
||||||
|
|
||||||
|
use_simplified_tool_desc = False
|
||||||
|
if llm is not None:
|
||||||
|
use_simplified_tool_desc = any(
|
||||||
|
model_substr in llm.config.model
|
||||||
|
for model_substr in SIMPLIFIED_TOOL_DESCRIPTION_LLM_SUBSTRS
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
create_cmd_run_tool(use_simplified_description=use_simplified_tool_desc),
|
||||||
|
ThinkTool,
|
||||||
|
FinishTool,
|
||||||
|
]
|
||||||
if codeact_enable_browsing:
|
if codeact_enable_browsing:
|
||||||
tools.append(WebReadTool)
|
tools.append(WebReadTool)
|
||||||
tools.append(BrowserTool)
|
tools.append(BrowserTool)
|
||||||
@@ -229,5 +247,9 @@ def get_tools(
|
|||||||
if codeact_enable_llm_editor:
|
if codeact_enable_llm_editor:
|
||||||
tools.append(LLMBasedFileEditTool)
|
tools.append(LLMBasedFileEditTool)
|
||||||
else:
|
else:
|
||||||
tools.append(StrReplaceEditorTool)
|
tools.append(
|
||||||
|
create_str_replace_editor_tool(
|
||||||
|
use_simplified_description=use_simplified_tool_desc
|
||||||
|
)
|
||||||
|
)
|
||||||
return tools
|
return tools
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{% if repository_info %}
|
{% if repository_info %}
|
||||||
<REPOSITORY_INFO>
|
<REPOSITORY_INFO>
|
||||||
At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}.
|
At the user's request, repository {{ repository_info.repo_name }} has been cloned to the current working directory {{ repository_info.repo_directory }}.
|
||||||
</REPOSITORY_INFO>
|
</REPOSITORY_INFO>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% if repository_instructions -%}
|
{% if repository_instructions -%}
|
||||||
@@ -20,6 +20,8 @@ When starting a web server, use the corresponding ports. You should also
|
|||||||
set any options to allow iframes and CORS requests, and allow the server to
|
set any options to allow iframes and CORS requests, and allow the server to
|
||||||
be accessed from any host (e.g. 0.0.0.0).
|
be accessed from any host (e.g. 0.0.0.0).
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
{% if runtime_info.additional_agent_instructions %}
|
||||||
{{ runtime_info.additional_agent_instructions }}
|
{{ runtime_info.additional_agent_instructions }}
|
||||||
|
{% endif %}
|
||||||
</RUNTIME_INFORMATION>
|
</RUNTIME_INFORMATION>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
{% for agent_info in triggered_agents %}
|
{% for agent_info in triggered_agents %}
|
||||||
<EXTRA_INFO>
|
<EXTRA_INFO>
|
||||||
The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}".
|
The following information has been included based on a keyword match for "{{ agent_info.trigger }}".
|
||||||
It may or may not be relevant to the user's request.
|
It may or may not be relevant to the user's request.
|
||||||
|
|
||||||
{{ agent_info.agent.content }}
|
{{ agent_info.content }}
|
||||||
</EXTRA_INFO>
|
</EXTRA_INFO>
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|||||||
@@ -1,19 +1,19 @@
|
|||||||
from .bash import CmdRunTool
|
from .bash import create_cmd_run_tool
|
||||||
from .browser import BrowserTool
|
from .browser import BrowserTool
|
||||||
from .finish import FinishTool
|
from .finish import FinishTool
|
||||||
from .ipython import IPythonTool
|
from .ipython import IPythonTool
|
||||||
from .llm_based_edit import LLMBasedFileEditTool
|
from .llm_based_edit import LLMBasedFileEditTool
|
||||||
from .str_replace_editor import StrReplaceEditorTool
|
from .str_replace_editor import create_str_replace_editor_tool
|
||||||
from .think import ThinkTool
|
from .think import ThinkTool
|
||||||
from .web_read import WebReadTool
|
from .web_read import WebReadTool
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BrowserTool',
|
'BrowserTool',
|
||||||
'CmdRunTool',
|
'create_cmd_run_tool',
|
||||||
'FinishTool',
|
'FinishTool',
|
||||||
'IPythonTool',
|
'IPythonTool',
|
||||||
'LLMBasedFileEditTool',
|
'LLMBasedFileEditTool',
|
||||||
'StrReplaceEditorTool',
|
'create_str_replace_editor_tool',
|
||||||
'WebReadTool',
|
'WebReadTool',
|
||||||
'ThinkTool',
|
'ThinkTool',
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
|
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
|
||||||
|
|
||||||
_BASH_DESCRIPTION = """Execute a bash command in the terminal within a persistent shell session.
|
_DETAILED_BASH_DESCRIPTION = """Execute a bash command in the terminal within a persistent shell session.
|
||||||
|
|
||||||
### Command Execution
|
### Command Execution
|
||||||
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, use `&&` or `;` to chain them together.
|
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, use `&&` or `;` to chain them together.
|
||||||
@@ -22,25 +22,39 @@ _BASH_DESCRIPTION = """Execute a bash command in the terminal within a persisten
|
|||||||
* Output truncation: If the output exceeds a maximum length, it will be truncated before being returned.
|
* Output truncation: If the output exceeds a maximum length, it will be truncated before being returned.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CmdRunTool = ChatCompletionToolParam(
|
_SIMPLIFIED_BASH_DESCRIPTION = """Execute a bash command in the terminal.
|
||||||
type='function',
|
* Long running commands: For commands that may run indefinitely, it should be run in the background and the output should be redirected to a file, e.g. command = `python3 app.py > server.log 2>&1 &`.
|
||||||
function=ChatCompletionToolParamFunctionChunk(
|
* Interact with running process: If a bash command returns exit code `-1`, this means the process is not yet finished. By setting `is_input` to `true`, the assistant can interact with the running process and send empty `command` to retrieve any additional logs, or send additional text (set `command` to the text) to STDIN of the running process, or send command like `C-c` (Ctrl+C), `C-d` (Ctrl+D), `C-z` (Ctrl+Z) to interrupt the process.
|
||||||
name='execute_bash',
|
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together."""
|
||||||
description=_BASH_DESCRIPTION,
|
|
||||||
parameters={
|
|
||||||
'type': 'object',
|
def create_cmd_run_tool(
|
||||||
'properties': {
|
use_simplified_description: bool = False,
|
||||||
'command': {
|
) -> ChatCompletionToolParam:
|
||||||
'type': 'string',
|
description = (
|
||||||
'description': 'The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process. Note: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together.',
|
_SIMPLIFIED_BASH_DESCRIPTION
|
||||||
},
|
if use_simplified_description
|
||||||
'is_input': {
|
else _DETAILED_BASH_DESCRIPTION
|
||||||
'type': 'string',
|
)
|
||||||
'description': 'If True, the command is an input to the running process. If False, the command is a bash command to be executed in the terminal. Default is False.',
|
return ChatCompletionToolParam(
|
||||||
'enum': ['true', 'false'],
|
type='function',
|
||||||
|
function=ChatCompletionToolParamFunctionChunk(
|
||||||
|
name='execute_bash',
|
||||||
|
description=description,
|
||||||
|
parameters={
|
||||||
|
'type': 'object',
|
||||||
|
'properties': {
|
||||||
|
'command': {
|
||||||
|
'type': 'string',
|
||||||
|
'description': 'The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process. Note: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together.',
|
||||||
|
},
|
||||||
|
'is_input': {
|
||||||
|
'type': 'string',
|
||||||
|
'description': 'If True, the command is an input to the running process. If False, the command is a bash command to be executed in the terminal. Default is False.',
|
||||||
|
'enum': ['true', 'false'],
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
'required': ['command'],
|
||||||
},
|
},
|
||||||
'required': ['command'],
|
),
|
||||||
},
|
)
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
|
from litellm import ChatCompletionToolParam, ChatCompletionToolParamFunctionChunk
|
||||||
|
|
||||||
_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
|
_DETAILED_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
|
||||||
* State is persistent across command calls and discussions with the user
|
* State is persistent across command calls and discussions with the user
|
||||||
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
|
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
|
||||||
* The `create` command cannot be used if the specified `path` already exists as a file
|
* The `create` command cannot be used if the specified `path` already exists as a file
|
||||||
@@ -31,46 +31,73 @@ CRITICAL REQUIREMENTS FOR USING THIS TOOL:
|
|||||||
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.
|
Remember: when making multiple file edits in a row to the same file, you should prefer to send all edits in a single message with multiple calls to this tool, rather than multiple messages with a single call each.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
StrReplaceEditorTool = ChatCompletionToolParam(
|
_SIMPLIFIED_STR_REPLACE_EDITOR_DESCRIPTION = """Custom editing tool for viewing, creating and editing files in plain-text format
|
||||||
type='function',
|
* State is persistent across command calls and discussions with the user
|
||||||
function=ChatCompletionToolParamFunctionChunk(
|
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
|
||||||
name='str_replace_editor',
|
* The `create` command cannot be used if the specified `path` already exists as a file
|
||||||
description=_STR_REPLACE_EDITOR_DESCRIPTION,
|
* If a `command` generates a long output, it will be truncated and marked with `<response clipped>`
|
||||||
parameters={
|
* The `undo_edit` command will revert the last edit made to the file at `path`
|
||||||
'type': 'object',
|
Notes for using the `str_replace` command:
|
||||||
'properties': {
|
* The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces!
|
||||||
'command': {
|
* If the `old_str` parameter is not unique in the file, the replacement will not be performed. Make sure to include enough context in `old_str` to make it unique
|
||||||
'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.',
|
* The `new_str` parameter should contain the edited lines that should replace the `old_str`
|
||||||
'enum': ['view', 'create', 'str_replace', 'insert', 'undo_edit'],
|
"""
|
||||||
'type': 'string',
|
|
||||||
},
|
|
||||||
'path': {
|
def create_str_replace_editor_tool(
|
||||||
'description': 'Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.',
|
use_simplified_description: bool = False,
|
||||||
'type': 'string',
|
) -> ChatCompletionToolParam:
|
||||||
},
|
description = (
|
||||||
'file_text': {
|
_SIMPLIFIED_STR_REPLACE_EDITOR_DESCRIPTION
|
||||||
'description': 'Required parameter of `create` command, with the content of the file to be created.',
|
if use_simplified_description
|
||||||
'type': 'string',
|
else _DETAILED_STR_REPLACE_EDITOR_DESCRIPTION
|
||||||
},
|
)
|
||||||
'old_str': {
|
return ChatCompletionToolParam(
|
||||||
'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.',
|
type='function',
|
||||||
'type': 'string',
|
function=ChatCompletionToolParamFunctionChunk(
|
||||||
},
|
name='str_replace_editor',
|
||||||
'new_str': {
|
description=description,
|
||||||
'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.',
|
parameters={
|
||||||
'type': 'string',
|
'type': 'object',
|
||||||
},
|
'properties': {
|
||||||
'insert_line': {
|
'command': {
|
||||||
'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.',
|
'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.',
|
||||||
'type': 'integer',
|
'enum': [
|
||||||
},
|
'view',
|
||||||
'view_range': {
|
'create',
|
||||||
'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.',
|
'str_replace',
|
||||||
'items': {'type': 'integer'},
|
'insert',
|
||||||
'type': 'array',
|
'undo_edit',
|
||||||
|
],
|
||||||
|
'type': 'string',
|
||||||
|
},
|
||||||
|
'path': {
|
||||||
|
'description': 'Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.',
|
||||||
|
'type': 'string',
|
||||||
|
},
|
||||||
|
'file_text': {
|
||||||
|
'description': 'Required parameter of `create` command, with the content of the file to be created.',
|
||||||
|
'type': 'string',
|
||||||
|
},
|
||||||
|
'old_str': {
|
||||||
|
'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.',
|
||||||
|
'type': 'string',
|
||||||
|
},
|
||||||
|
'new_str': {
|
||||||
|
'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.',
|
||||||
|
'type': 'string',
|
||||||
|
},
|
||||||
|
'insert_line': {
|
||||||
|
'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.',
|
||||||
|
'type': 'integer',
|
||||||
|
},
|
||||||
|
'view_range': {
|
||||||
|
'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.',
|
||||||
|
'items': {'type': 'integer'},
|
||||||
|
'type': 'array',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
'required': ['command', 'path'],
|
||||||
},
|
},
|
||||||
'required': ['command', 'path'],
|
),
|
||||||
},
|
)
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
from openhands.agenthub.delegator_agent.agent import DelegatorAgent
|
|
||||||
from openhands.controller.agent import Agent
|
|
||||||
|
|
||||||
Agent.register('DelegatorAgent', DelegatorAgent)
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
from openhands.controller.agent import Agent
|
|
||||||
from openhands.controller.state.state import State
|
|
||||||
from openhands.core.config import AgentConfig
|
|
||||||
from openhands.events.action import Action, AgentDelegateAction, AgentFinishAction
|
|
||||||
from openhands.events.observation import AgentDelegateObservation, Observation
|
|
||||||
from openhands.llm.llm import LLM
|
|
||||||
|
|
||||||
|
|
||||||
class DelegatorAgent(Agent):
|
|
||||||
VERSION = '1.0'
|
|
||||||
"""
|
|
||||||
The Delegator Agent is responsible for delegating tasks to other agents based on the current task.
|
|
||||||
"""
|
|
||||||
|
|
||||||
current_delegate: str = ''
|
|
||||||
|
|
||||||
def __init__(self, llm: LLM, config: AgentConfig):
|
|
||||||
"""Initialize the Delegator Agent with an LLM
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- llm (LLM): The llm to be used by this agent
|
|
||||||
"""
|
|
||||||
super().__init__(llm, config)
|
|
||||||
|
|
||||||
def step(self, state: State) -> Action:
|
|
||||||
"""Checks to see if current step is completed, returns AgentFinishAction if True.
|
|
||||||
Otherwise, delegates the task to the next agent in the pipeline.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- state (State): The current state given the previous actions and observations
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- AgentFinishAction: If the last state was 'completed', 'verified', or 'abandoned'
|
|
||||||
- AgentDelegateAction: The next agent to delegate the task to
|
|
||||||
"""
|
|
||||||
if self.current_delegate == '':
|
|
||||||
self.current_delegate = 'study'
|
|
||||||
task, _ = state.get_current_user_intent()
|
|
||||||
return AgentDelegateAction(
|
|
||||||
agent='StudyRepoForTaskAgent', inputs={'task': task}
|
|
||||||
)
|
|
||||||
|
|
||||||
# last observation in history should be from the delegate
|
|
||||||
last_observation = None
|
|
||||||
for event in reversed(state.history):
|
|
||||||
if isinstance(event, Observation):
|
|
||||||
last_observation = event
|
|
||||||
break
|
|
||||||
|
|
||||||
if not isinstance(last_observation, AgentDelegateObservation):
|
|
||||||
raise Exception('Last observation is not an AgentDelegateObservation')
|
|
||||||
|
|
||||||
goal, _ = state.get_current_user_intent()
|
|
||||||
if self.current_delegate == 'study':
|
|
||||||
self.current_delegate = 'coder'
|
|
||||||
return AgentDelegateAction(
|
|
||||||
agent='CoderAgent',
|
|
||||||
inputs={
|
|
||||||
'task': goal,
|
|
||||||
'summary': last_observation.outputs['summary'],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
elif self.current_delegate == 'coder':
|
|
||||||
self.current_delegate = 'verifier'
|
|
||||||
return AgentDelegateAction(
|
|
||||||
agent='VerifierAgent',
|
|
||||||
inputs={
|
|
||||||
'task': goal,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
elif self.current_delegate == 'verifier':
|
|
||||||
if (
|
|
||||||
'completed' in last_observation.outputs
|
|
||||||
and last_observation.outputs['completed']
|
|
||||||
):
|
|
||||||
return AgentFinishAction()
|
|
||||||
else:
|
|
||||||
self.current_delegate = 'coder'
|
|
||||||
return AgentDelegateAction(
|
|
||||||
agent='CoderAgent',
|
|
||||||
inputs={
|
|
||||||
'task': goal,
|
|
||||||
'summary': last_observation.outputs['summary'],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception('Invalid delegate state')
|
|
||||||
@@ -202,6 +202,7 @@ Note:
|
|||||||
tabs = ''
|
tabs = ''
|
||||||
last_obs = None
|
last_obs = None
|
||||||
last_action = None
|
last_action = None
|
||||||
|
set_of_marks = None # Initialize set_of_marks to None
|
||||||
|
|
||||||
if len(state.history) == 1:
|
if len(state.history) == 1:
|
||||||
# for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
|
# for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
|
||||||
@@ -217,6 +218,9 @@ Note:
|
|||||||
# agent has responded, task finished.
|
# agent has responded, task finished.
|
||||||
return AgentFinishAction(outputs={'content': event.content})
|
return AgentFinishAction(outputs={'content': event.content})
|
||||||
elif isinstance(event, Observation):
|
elif isinstance(event, Observation):
|
||||||
|
# Only process BrowserOutputObservation and skip other observation types
|
||||||
|
if not isinstance(event, BrowserOutputObservation):
|
||||||
|
continue
|
||||||
last_obs = event
|
last_obs = event
|
||||||
|
|
||||||
if len(prev_actions) >= 1: # ignore noop()
|
if len(prev_actions) >= 1: # ignore noop()
|
||||||
|
|||||||
@@ -29,7 +29,12 @@ from openhands.core.exceptions import (
|
|||||||
from openhands.core.logger import LOG_ALL_EVENTS
|
from openhands.core.logger import LOG_ALL_EVENTS
|
||||||
from openhands.core.logger import openhands_logger as logger
|
from openhands.core.logger import openhands_logger as logger
|
||||||
from openhands.core.schema import AgentState
|
from openhands.core.schema import AgentState
|
||||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
from openhands.events import (
|
||||||
|
EventSource,
|
||||||
|
EventStream,
|
||||||
|
EventStreamSubscriber,
|
||||||
|
RecallType,
|
||||||
|
)
|
||||||
from openhands.events.action import (
|
from openhands.events.action import (
|
||||||
Action,
|
Action,
|
||||||
ActionConfirmationStatus,
|
ActionConfirmationStatus,
|
||||||
@@ -42,6 +47,7 @@ from openhands.events.action import (
|
|||||||
MessageAction,
|
MessageAction,
|
||||||
NullAction,
|
NullAction,
|
||||||
)
|
)
|
||||||
|
from openhands.events.action.agent import RecallAction
|
||||||
from openhands.events.event import Event
|
from openhands.events.event import Event
|
||||||
from openhands.events.observation import (
|
from openhands.events.observation import (
|
||||||
AgentCondensationObservation,
|
AgentCondensationObservation,
|
||||||
@@ -89,7 +95,7 @@ class AgentController:
|
|||||||
max_budget_per_task: float | None = None,
|
max_budget_per_task: float | None = None,
|
||||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||||
agent_configs: dict[str, AgentConfig] | None = None,
|
agent_configs: dict[str, AgentConfig] | None = None,
|
||||||
sid: str = 'default',
|
sid: str | None = None,
|
||||||
confirmation_mode: bool = False,
|
confirmation_mode: bool = False,
|
||||||
initial_state: State | None = None,
|
initial_state: State | None = None,
|
||||||
is_delegate: bool = False,
|
is_delegate: bool = False,
|
||||||
@@ -116,7 +122,7 @@ class AgentController:
|
|||||||
status_callback: Optional callback function to handle status updates.
|
status_callback: Optional callback function to handle status updates.
|
||||||
replay_events: A list of logs to replay.
|
replay_events: A list of logs to replay.
|
||||||
"""
|
"""
|
||||||
self.id = sid
|
self.id = sid or event_stream.sid
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
self.headless_mode = headless_mode
|
self.headless_mode = headless_mode
|
||||||
self.is_delegate = is_delegate
|
self.is_delegate = is_delegate
|
||||||
@@ -287,8 +293,14 @@ class AgentController:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
if isinstance(event, Observation):
|
if isinstance(event, Observation):
|
||||||
if isinstance(event, NullObservation) or isinstance(
|
if (
|
||||||
event, AgentStateChangedObservation
|
isinstance(event, NullObservation)
|
||||||
|
and event.cause is not None
|
||||||
|
and event.cause > 0
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
if isinstance(event, AgentStateChangedObservation) or isinstance(
|
||||||
|
event, NullObservation
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
@@ -388,6 +400,7 @@ class AgentController:
|
|||||||
if observation.llm_metrics is not None:
|
if observation.llm_metrics is not None:
|
||||||
self.agent.llm.metrics.merge(observation.llm_metrics)
|
self.agent.llm.metrics.merge(observation.llm_metrics)
|
||||||
|
|
||||||
|
# this happens for runnable actions and microagent actions
|
||||||
if self._pending_action and self._pending_action.id == observation.cause:
|
if self._pending_action and self._pending_action.id == observation.cause:
|
||||||
if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
|
if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
|
||||||
return
|
return
|
||||||
@@ -431,6 +444,25 @@ class AgentController:
|
|||||||
'debug',
|
'debug',
|
||||||
f'Extended max iterations to {self.state.max_iterations} after user message',
|
f'Extended max iterations to {self.state.max_iterations} after user message',
|
||||||
)
|
)
|
||||||
|
# try to retrieve microagents relevant to the user message
|
||||||
|
# set pending_action while we search for information
|
||||||
|
|
||||||
|
# if this is the first user message for this agent, matters for the microagent info type
|
||||||
|
first_user_message = self._first_user_message()
|
||||||
|
is_first_user_message = (
|
||||||
|
action.id == first_user_message.id if first_user_message else False
|
||||||
|
)
|
||||||
|
recall_type = (
|
||||||
|
RecallType.WORKSPACE_CONTEXT
|
||||||
|
if is_first_user_message
|
||||||
|
else RecallType.KNOWLEDGE
|
||||||
|
)
|
||||||
|
|
||||||
|
recall_action = RecallAction(query=action.content, recall_type=recall_type)
|
||||||
|
self._pending_action = recall_action
|
||||||
|
# this is source=USER because the user message is the trigger for the microagent retrieval
|
||||||
|
self.event_stream.add_event(recall_action, EventSource.USER)
|
||||||
|
|
||||||
if self.get_agent_state() != AgentState.RUNNING:
|
if self.get_agent_state() != AgentState.RUNNING:
|
||||||
await self.set_agent_state_to(AgentState.RUNNING)
|
await self.set_agent_state_to(AgentState.RUNNING)
|
||||||
elif action.source == EventSource.AGENT and action.wait_for_response:
|
elif action.source == EventSource.AGENT and action.wait_for_response:
|
||||||
@@ -438,6 +470,7 @@ class AgentController:
|
|||||||
|
|
||||||
def _reset(self) -> None:
|
def _reset(self) -> None:
|
||||||
"""Resets the agent controller"""
|
"""Resets the agent controller"""
|
||||||
|
# Runnable actions need an Observation
|
||||||
# make sure there is an Observation with the tool call metadata to be recognized by the agent
|
# make sure there is an Observation with the tool call metadata to be recognized by the agent
|
||||||
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
|
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
|
||||||
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
|
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
|
||||||
@@ -459,6 +492,8 @@ class AgentController:
|
|||||||
obs._cause = self._pending_action.id # type: ignore[attr-defined]
|
obs._cause = self._pending_action.id # type: ignore[attr-defined]
|
||||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||||
|
|
||||||
|
# NOTE: RecallActions don't need an ErrorObservation upon reset, as long as they have no tool calls
|
||||||
|
|
||||||
# reset the pending action, this will be called when the agent is STOPPED or ERROR
|
# reset the pending action, this will be called when the agent is STOPPED or ERROR
|
||||||
self._pending_action = None
|
self._pending_action = None
|
||||||
self.agent.reset()
|
self.agent.reset()
|
||||||
@@ -1146,3 +1181,26 @@ class AgentController:
|
|||||||
result = event.agent_state == AgentState.RUNNING
|
result = event.agent_state == AgentState.RUNNING
|
||||||
return result
|
return result
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def _first_user_message(self) -> MessageAction | None:
|
||||||
|
"""
|
||||||
|
Get the first user message for this agent.
|
||||||
|
|
||||||
|
For regular agents, this is the first user message from the beginning (start_id=0).
|
||||||
|
For delegate agents, this is the first user message after the delegate's start_id.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MessageAction | None: The first user message, or None if no user message found
|
||||||
|
"""
|
||||||
|
# Find the first user message from the appropriate starting point
|
||||||
|
user_messages = list(self.event_stream.get_events(start_id=self.state.start_id))
|
||||||
|
|
||||||
|
# Get and return the first user message
|
||||||
|
return next(
|
||||||
|
(
|
||||||
|
e
|
||||||
|
for e in user_messages
|
||||||
|
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ class StuckDetector:
|
|||||||
# it takes 3 actions and 3 observations to detect a loop
|
# it takes 3 actions and 3 observations to detect a loop
|
||||||
# check if the last three actions are the same and result in errors
|
# check if the last three actions are the same and result in errors
|
||||||
|
|
||||||
if len(last_actions) < 4 or len(last_observations) < 4:
|
if len(last_actions) < 3 or len(last_observations) < 3:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# are the last three actions the "same"?
|
# are the last three actions the "same"?
|
||||||
|
|||||||
+13
-3
@@ -17,6 +17,7 @@ from openhands.core.schema import AgentState
|
|||||||
from openhands.core.setup import (
|
from openhands.core.setup import (
|
||||||
create_agent,
|
create_agent,
|
||||||
create_controller,
|
create_controller,
|
||||||
|
create_memory,
|
||||||
create_runtime,
|
create_runtime,
|
||||||
initialize_repository_for_runtime,
|
initialize_repository_for_runtime,
|
||||||
)
|
)
|
||||||
@@ -170,13 +171,22 @@ async def main(loop: asyncio.AbstractEventLoop):
|
|||||||
await runtime.connect()
|
await runtime.connect()
|
||||||
|
|
||||||
# Initialize repository if needed
|
# Initialize repository if needed
|
||||||
|
repo_directory = None
|
||||||
if config.sandbox.selected_repo:
|
if config.sandbox.selected_repo:
|
||||||
initialize_repository_for_runtime(
|
repo_directory = initialize_repository_for_runtime(
|
||||||
runtime,
|
runtime,
|
||||||
agent=agent,
|
|
||||||
selected_repository=config.sandbox.selected_repo,
|
selected_repository=config.sandbox.selected_repo,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# when memory is created, it will load the microagents from the selected repository
|
||||||
|
memory = create_memory(
|
||||||
|
runtime=runtime,
|
||||||
|
event_stream=event_stream,
|
||||||
|
sid=sid,
|
||||||
|
selected_repository=config.sandbox.selected_repo,
|
||||||
|
repo_directory=repo_directory,
|
||||||
|
)
|
||||||
|
|
||||||
if initial_user_action:
|
if initial_user_action:
|
||||||
# If there's an initial user action, enqueue it and do not prompt again
|
# If there's an initial user action, enqueue it and do not prompt again
|
||||||
event_stream.add_event(initial_user_action, EventSource.USER)
|
event_stream.add_event(initial_user_action, EventSource.USER)
|
||||||
@@ -185,7 +195,7 @@ async def main(loop: asyncio.AbstractEventLoop):
|
|||||||
asyncio.create_task(prompt_for_next_task())
|
asyncio.create_task(prompt_for_next_task())
|
||||||
|
|
||||||
await run_agent_until_done(
|
await run_agent_until_done(
|
||||||
controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
|
controller, runtime, memory, [AgentState.STOPPED, AgentState.ERROR]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ class SandboxConfig(BaseModel):
|
|||||||
timeout: The timeout for the default sandbox action execution.
|
timeout: The timeout for the default sandbox action execution.
|
||||||
remote_runtime_init_timeout: The timeout for the remote runtime to start.
|
remote_runtime_init_timeout: The timeout for the remote runtime to start.
|
||||||
remote_runtime_api_timeout: The timeout for the remote runtime API requests.
|
remote_runtime_api_timeout: The timeout for the remote runtime API requests.
|
||||||
|
remote_runtime_enable_retries: Whether to enable retries (on recoverable errors like requests.ConnectionError) for the remote runtime API requests.
|
||||||
enable_auto_lint: Whether to enable auto-lint.
|
enable_auto_lint: Whether to enable auto-lint.
|
||||||
use_host_network: Whether to use the host network.
|
use_host_network: Whether to use the host network.
|
||||||
runtime_binding_address: The binding address for the runtime ports. It specifies which network interface on the host machine Docker should bind the runtime ports to.
|
runtime_binding_address: The binding address for the runtime ports. It specifies which network interface on the host machine Docker should bind the runtime ports to.
|
||||||
@@ -53,7 +54,7 @@ class SandboxConfig(BaseModel):
|
|||||||
timeout: int = Field(default=120)
|
timeout: int = Field(default=120)
|
||||||
remote_runtime_init_timeout: int = Field(default=180)
|
remote_runtime_init_timeout: int = Field(default=180)
|
||||||
remote_runtime_api_timeout: int = Field(default=10)
|
remote_runtime_api_timeout: int = Field(default=10)
|
||||||
remote_runtime_enable_retries: bool = Field(default=False)
|
remote_runtime_enable_retries: bool = Field(default=True)
|
||||||
remote_runtime_class: str | None = Field(
|
remote_runtime_class: str | None = Field(
|
||||||
default=None
|
default=None
|
||||||
) # can be "None" (default to gvisor) or "sysbox" (support docker inside runtime + more stable)
|
) # can be "None" (default to gvisor) or "sysbox" (support docker inside runtime + more stable)
|
||||||
|
|||||||
@@ -240,7 +240,7 @@ class SensitiveDataFilter(logging.Filter):
|
|||||||
if (
|
if (
|
||||||
len(value) > 2
|
len(value) > 2
|
||||||
and value != 'default'
|
and value != 'default'
|
||||||
and any(s in key_upper for s in ('SECRET', 'KEY', 'CODE', 'TOKEN'))
|
and any(s in key_upper for s in ('SECRET', '_KEY', '_CODE', '_TOKEN'))
|
||||||
):
|
):
|
||||||
sensitive_values.append(value)
|
sensitive_values.append(value)
|
||||||
|
|
||||||
|
|||||||
@@ -3,12 +3,14 @@ import asyncio
|
|||||||
from openhands.controller import AgentController
|
from openhands.controller import AgentController
|
||||||
from openhands.core.logger import openhands_logger as logger
|
from openhands.core.logger import openhands_logger as logger
|
||||||
from openhands.core.schema import AgentState
|
from openhands.core.schema import AgentState
|
||||||
|
from openhands.memory.memory import Memory
|
||||||
from openhands.runtime.base import Runtime
|
from openhands.runtime.base import Runtime
|
||||||
|
|
||||||
|
|
||||||
async def run_agent_until_done(
|
async def run_agent_until_done(
|
||||||
controller: AgentController,
|
controller: AgentController,
|
||||||
runtime: Runtime,
|
runtime: Runtime,
|
||||||
|
memory: Memory,
|
||||||
end_states: list[AgentState],
|
end_states: list[AgentState],
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -37,6 +39,7 @@ async def run_agent_until_done(
|
|||||||
|
|
||||||
runtime.status_callback = status_callback
|
runtime.status_callback = status_callback
|
||||||
controller.status_callback = status_callback
|
controller.status_callback = status_callback
|
||||||
|
memory.status_callback = status_callback
|
||||||
|
|
||||||
while controller.state.agent_state not in end_states:
|
while controller.state.agent_state not in end_states:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|||||||
+17
-3
@@ -18,6 +18,7 @@ from openhands.core.schema import AgentState
|
|||||||
from openhands.core.setup import (
|
from openhands.core.setup import (
|
||||||
create_agent,
|
create_agent,
|
||||||
create_controller,
|
create_controller,
|
||||||
|
create_memory,
|
||||||
create_runtime,
|
create_runtime,
|
||||||
generate_sid,
|
generate_sid,
|
||||||
initialize_repository_for_runtime,
|
initialize_repository_for_runtime,
|
||||||
@@ -29,6 +30,7 @@ from openhands.events.event import Event
|
|||||||
from openhands.events.observation import AgentStateChangedObservation
|
from openhands.events.observation import AgentStateChangedObservation
|
||||||
from openhands.events.serialization import event_from_dict
|
from openhands.events.serialization import event_from_dict
|
||||||
from openhands.io import read_input, read_task
|
from openhands.io import read_input, read_task
|
||||||
|
from openhands.memory.memory import Memory
|
||||||
from openhands.runtime.base import Runtime
|
from openhands.runtime.base import Runtime
|
||||||
from openhands.utils.async_utils import call_async_from_sync
|
from openhands.utils.async_utils import call_async_from_sync
|
||||||
|
|
||||||
@@ -51,6 +53,7 @@ async def run_controller(
|
|||||||
exit_on_message: bool = False,
|
exit_on_message: bool = False,
|
||||||
fake_user_response_fn: FakeUserResponseFunc | None = None,
|
fake_user_response_fn: FakeUserResponseFunc | None = None,
|
||||||
headless_mode: bool = True,
|
headless_mode: bool = True,
|
||||||
|
memory: Memory | None = None,
|
||||||
) -> State | None:
|
) -> State | None:
|
||||||
"""Main coroutine to run the agent controller with task input flexibility.
|
"""Main coroutine to run the agent controller with task input flexibility.
|
||||||
|
|
||||||
@@ -93,6 +96,8 @@ async def run_controller(
|
|||||||
if agent is None:
|
if agent is None:
|
||||||
agent = create_agent(config)
|
agent = create_agent(config)
|
||||||
|
|
||||||
|
# when the runtime is created, it will be connected and clone the selected repository
|
||||||
|
repo_directory = None
|
||||||
if runtime is None:
|
if runtime is None:
|
||||||
runtime = create_runtime(
|
runtime = create_runtime(
|
||||||
config,
|
config,
|
||||||
@@ -105,14 +110,23 @@ async def run_controller(
|
|||||||
|
|
||||||
# Initialize repository if needed
|
# Initialize repository if needed
|
||||||
if config.sandbox.selected_repo:
|
if config.sandbox.selected_repo:
|
||||||
initialize_repository_for_runtime(
|
repo_directory = initialize_repository_for_runtime(
|
||||||
runtime,
|
runtime,
|
||||||
agent=agent,
|
|
||||||
selected_repository=config.sandbox.selected_repo,
|
selected_repository=config.sandbox.selected_repo,
|
||||||
)
|
)
|
||||||
|
|
||||||
event_stream = runtime.event_stream
|
event_stream = runtime.event_stream
|
||||||
|
|
||||||
|
# when memory is created, it will load the microagents from the selected repository
|
||||||
|
if memory is None:
|
||||||
|
memory = create_memory(
|
||||||
|
runtime=runtime,
|
||||||
|
event_stream=event_stream,
|
||||||
|
sid=sid,
|
||||||
|
selected_repository=config.sandbox.selected_repo,
|
||||||
|
repo_directory=repo_directory,
|
||||||
|
)
|
||||||
|
|
||||||
replay_events: list[Event] | None = None
|
replay_events: list[Event] | None = None
|
||||||
if config.replay_trajectory_path:
|
if config.replay_trajectory_path:
|
||||||
logger.info('Trajectory replay is enabled')
|
logger.info('Trajectory replay is enabled')
|
||||||
@@ -172,7 +186,7 @@ async def run_controller(
|
|||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await run_agent_until_done(controller, runtime, end_states)
|
await run_agent_until_done(controller, runtime, memory, end_states)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f'Exception in main loop: {e}')
|
logger.error(f'Exception in main loop: {e}')
|
||||||
|
|
||||||
|
|||||||
@@ -82,5 +82,8 @@ class ActionTypeSchema(BaseModel):
|
|||||||
SEND_PR: str = Field(default='send_pr')
|
SEND_PR: str = Field(default='send_pr')
|
||||||
"""Send a PR to github."""
|
"""Send a PR to github."""
|
||||||
|
|
||||||
|
RECALL: str = Field(default='recall')
|
||||||
|
"""Retrieves content from a user workspace, microagent, or other source."""
|
||||||
|
|
||||||
|
|
||||||
ActionType = ActionTypeSchema()
|
ActionType = ActionTypeSchema()
|
||||||
|
|||||||
@@ -49,5 +49,8 @@ class ObservationTypeSchema(BaseModel):
|
|||||||
CONDENSE: str = Field(default='condense')
|
CONDENSE: str = Field(default='condense')
|
||||||
"""Result of a condensation operation."""
|
"""Result of a condensation operation."""
|
||||||
|
|
||||||
|
RECALL: str = Field(default='recall')
|
||||||
|
"""Result of a recall operation. This can be the workspace context, a microagent, or other types of information."""
|
||||||
|
|
||||||
|
|
||||||
ObservationType = ObservationTypeSchema()
|
ObservationType = ObservationTypeSchema()
|
||||||
|
|||||||
+40
-10
@@ -1,7 +1,7 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Tuple, Type
|
from typing import Callable, Tuple, Type
|
||||||
|
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
@@ -16,6 +16,7 @@ from openhands.core.logger import openhands_logger as logger
|
|||||||
from openhands.events import EventStream
|
from openhands.events import EventStream
|
||||||
from openhands.events.event import Event
|
from openhands.events.event import Event
|
||||||
from openhands.llm.llm import LLM
|
from openhands.llm.llm import LLM
|
||||||
|
from openhands.memory.memory import Memory
|
||||||
from openhands.microagent.microagent import BaseMicroAgent
|
from openhands.microagent.microagent import BaseMicroAgent
|
||||||
from openhands.runtime import get_runtime_cls
|
from openhands.runtime import get_runtime_cls
|
||||||
from openhands.runtime.base import Runtime
|
from openhands.runtime.base import Runtime
|
||||||
@@ -83,7 +84,6 @@ def create_runtime(
|
|||||||
|
|
||||||
def initialize_repository_for_runtime(
|
def initialize_repository_for_runtime(
|
||||||
runtime: Runtime,
|
runtime: Runtime,
|
||||||
agent: Agent | None = None,
|
|
||||||
selected_repository: str | None = None,
|
selected_repository: str | None = None,
|
||||||
github_token: SecretStr | None = None,
|
github_token: SecretStr | None = None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
@@ -91,7 +91,6 @@ def initialize_repository_for_runtime(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
runtime: The runtime to initialize the repository for.
|
runtime: The runtime to initialize the repository for.
|
||||||
agent: (optional) The agent to load microagents for.
|
|
||||||
selected_repository: (optional) The GitHub repository to use.
|
selected_repository: (optional) The GitHub repository to use.
|
||||||
github_token: (optional) The GitHub token to use.
|
github_token: (optional) The GitHub token to use.
|
||||||
|
|
||||||
@@ -99,10 +98,10 @@ def initialize_repository_for_runtime(
|
|||||||
The repository directory path if a repository was cloned, None otherwise.
|
The repository directory path if a repository was cloned, None otherwise.
|
||||||
"""
|
"""
|
||||||
# clone selected repository if provided
|
# clone selected repository if provided
|
||||||
repo_directory = None
|
|
||||||
github_token = (
|
github_token = (
|
||||||
SecretStr(os.environ.get('GITHUB_TOKEN')) if not github_token else github_token
|
SecretStr(os.environ.get('GITHUB_TOKEN')) if not github_token else github_token
|
||||||
)
|
)
|
||||||
|
repo_directory = None
|
||||||
if selected_repository and github_token:
|
if selected_repository and github_token:
|
||||||
logger.debug(f'Selected repository {selected_repository}.')
|
logger.debug(f'Selected repository {selected_repository}.')
|
||||||
repo_directory = runtime.clone_repo(
|
repo_directory = runtime.clone_repo(
|
||||||
@@ -111,16 +110,47 @@ def initialize_repository_for_runtime(
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# load microagents from selected repository
|
return repo_directory
|
||||||
if agent and agent.prompt_manager and selected_repository and repo_directory:
|
|
||||||
agent.prompt_manager.set_runtime_info(runtime)
|
|
||||||
|
def create_memory(
|
||||||
|
runtime: Runtime,
|
||||||
|
event_stream: EventStream,
|
||||||
|
sid: str,
|
||||||
|
selected_repository: str | None = None,
|
||||||
|
repo_directory: str | None = None,
|
||||||
|
status_callback: Callable | None = None,
|
||||||
|
) -> Memory:
|
||||||
|
"""Create a memory for the agent to use.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runtime: The runtime to use.
|
||||||
|
event_stream: The event stream it will subscribe to.
|
||||||
|
sid: The session id.
|
||||||
|
selected_repository: The repository to clone and start with, if any.
|
||||||
|
repo_directory: The repository directory, if any.
|
||||||
|
status_callback: Optional callback function to handle status updates.
|
||||||
|
"""
|
||||||
|
memory = Memory(
|
||||||
|
event_stream=event_stream,
|
||||||
|
sid=sid,
|
||||||
|
status_callback=status_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if runtime:
|
||||||
|
# sets available hosts
|
||||||
|
memory.set_runtime_info(runtime)
|
||||||
|
|
||||||
|
# loads microagents from repo/.openhands/microagents
|
||||||
microagents: list[BaseMicroAgent] = runtime.get_microagents_from_selected_repo(
|
microagents: list[BaseMicroAgent] = runtime.get_microagents_from_selected_repo(
|
||||||
selected_repository
|
selected_repository
|
||||||
)
|
)
|
||||||
agent.prompt_manager.load_microagents(microagents)
|
memory.load_user_workspace_microagents(microagents)
|
||||||
agent.prompt_manager.set_repository_info(selected_repository, repo_directory)
|
|
||||||
|
|
||||||
return repo_directory
|
if selected_repository and repo_directory:
|
||||||
|
memory.set_repository_info(selected_repository, repo_directory)
|
||||||
|
|
||||||
|
return memory
|
||||||
|
|
||||||
|
|
||||||
def create_agent(config: AppConfig) -> Agent:
|
def create_agent(config: AppConfig) -> Agent:
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from openhands.events.event import Event, EventSource
|
from openhands.events.event import Event, EventSource, RecallType
|
||||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -6,4 +6,5 @@ __all__ = [
|
|||||||
'EventSource',
|
'EventSource',
|
||||||
'EventStream',
|
'EventStream',
|
||||||
'EventStreamSubscriber',
|
'EventStreamSubscriber',
|
||||||
|
'RecallType',
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from openhands.events.action.agent import (
|
|||||||
AgentSummarizeAction,
|
AgentSummarizeAction,
|
||||||
AgentThinkAction,
|
AgentThinkAction,
|
||||||
ChangeAgentStateAction,
|
ChangeAgentStateAction,
|
||||||
|
RecallAction,
|
||||||
)
|
)
|
||||||
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
||||||
from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction
|
from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction
|
||||||
@@ -35,4 +36,5 @@ __all__ = [
|
|||||||
'MessageAction',
|
'MessageAction',
|
||||||
'ActionConfirmationStatus',
|
'ActionConfirmationStatus',
|
||||||
'AgentThinkAction',
|
'AgentThinkAction',
|
||||||
|
'RecallAction',
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Any
|
|||||||
|
|
||||||
from openhands.core.schema import ActionType
|
from openhands.core.schema import ActionType
|
||||||
from openhands.events.action.action import Action
|
from openhands.events.action.action import Action
|
||||||
|
from openhands.events.event import RecallType
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -106,3 +107,22 @@ class AgentDelegateAction(Action):
|
|||||||
@property
|
@property
|
||||||
def message(self) -> str:
|
def message(self) -> str:
|
||||||
return f"I'm asking {self.agent} for help with this task."
|
return f"I'm asking {self.agent} for help with this task."
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RecallAction(Action):
|
||||||
|
"""This action is used for retrieving content, e.g., from the global directory or user workspace."""
|
||||||
|
|
||||||
|
recall_type: RecallType
|
||||||
|
query: str = ''
|
||||||
|
thought: str = ''
|
||||||
|
action: str = ActionType.RECALL
|
||||||
|
|
||||||
|
@property
|
||||||
|
def message(self) -> str:
|
||||||
|
return f'Retrieving content for: {self.query[:50]}'
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
ret = '**RecallAction**\n'
|
||||||
|
ret += f'QUERY: {self.query[:50]}'
|
||||||
|
return ret
|
||||||
|
|||||||
@@ -22,6 +22,16 @@ class FileReadSource(str, Enum):
|
|||||||
DEFAULT = 'default'
|
DEFAULT = 'default'
|
||||||
|
|
||||||
|
|
||||||
|
class RecallType(str, Enum):
|
||||||
|
"""The type of information that can be retrieved from microagents."""
|
||||||
|
|
||||||
|
WORKSPACE_CONTEXT = 'workspace_context'
|
||||||
|
"""Workspace context (repo instructions, runtime, etc.)"""
|
||||||
|
|
||||||
|
KNOWLEDGE = 'knowledge'
|
||||||
|
"""A knowledge microagent."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Event:
|
class Event:
|
||||||
INVALID_ID = -1
|
INVALID_ID = -1
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
|
from openhands.events.event import RecallType
|
||||||
from openhands.events.observation.agent import (
|
from openhands.events.observation.agent import (
|
||||||
AgentCondensationObservation,
|
AgentCondensationObservation,
|
||||||
AgentStateChangedObservation,
|
AgentStateChangedObservation,
|
||||||
AgentThinkObservation,
|
AgentThinkObservation,
|
||||||
|
RecallObservation,
|
||||||
)
|
)
|
||||||
from openhands.events.observation.browse import BrowserOutputObservation
|
from openhands.events.observation.browse import BrowserOutputObservation
|
||||||
from openhands.events.observation.commands import (
|
from openhands.events.observation.commands import (
|
||||||
@@ -40,4 +42,6 @@ __all__ = [
|
|||||||
'SuccessObservation',
|
'SuccessObservation',
|
||||||
'UserRejectObservation',
|
'UserRejectObservation',
|
||||||
'AgentCondensationObservation',
|
'AgentCondensationObservation',
|
||||||
|
'RecallObservation',
|
||||||
|
'RecallType',
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from openhands.core.schema import ObservationType
|
from openhands.core.schema import ObservationType
|
||||||
|
from openhands.events.event import RecallType
|
||||||
from openhands.events.observation.observation import Observation
|
from openhands.events.observation.observation import Observation
|
||||||
|
|
||||||
|
|
||||||
@@ -40,3 +41,90 @@ class AgentThinkObservation(Observation):
|
|||||||
@property
|
@property
|
||||||
def message(self) -> str:
|
def message(self) -> str:
|
||||||
return self.content
|
return self.content
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MicroagentKnowledge:
|
||||||
|
"""
|
||||||
|
Represents knowledge from a triggered microagent.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
name: The name of the microagent that was triggered
|
||||||
|
trigger: The word that triggered this microagent
|
||||||
|
content: The actual content/knowledge from the microagent
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
trigger: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RecallObservation(Observation):
|
||||||
|
"""The retrieval of content from a microagent or more microagents."""
|
||||||
|
|
||||||
|
recall_type: RecallType
|
||||||
|
observation: str = ObservationType.RECALL
|
||||||
|
|
||||||
|
# workspace context
|
||||||
|
repo_name: str = ''
|
||||||
|
repo_directory: str = ''
|
||||||
|
repo_instructions: str = ''
|
||||||
|
runtime_hosts: dict[str, int] = field(default_factory=dict)
|
||||||
|
additional_agent_instructions: str = ''
|
||||||
|
|
||||||
|
# knowledge
|
||||||
|
microagent_knowledge: list[MicroagentKnowledge] = field(default_factory=list)
|
||||||
|
"""
|
||||||
|
A list of MicroagentKnowledge objects, each containing information from a triggered microagent.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
[
|
||||||
|
MicroagentKnowledge(
|
||||||
|
name="python_best_practices",
|
||||||
|
trigger="python",
|
||||||
|
content="Always use virtual environments for Python projects."
|
||||||
|
),
|
||||||
|
MicroagentKnowledge(
|
||||||
|
name="git_workflow",
|
||||||
|
trigger="git",
|
||||||
|
content="Create a new branch for each feature or bugfix."
|
||||||
|
)
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def message(self) -> str:
|
||||||
|
return (
|
||||||
|
'Added workspace context'
|
||||||
|
if self.recall_type == RecallType.WORKSPACE_CONTEXT
|
||||||
|
else 'Added microagent knowledge'
|
||||||
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
# Build a string representation
|
||||||
|
fields = []
|
||||||
|
if self.recall_type == RecallType.WORKSPACE_CONTEXT:
|
||||||
|
fields.extend(
|
||||||
|
[
|
||||||
|
f'recall_type={self.recall_type}',
|
||||||
|
f'repo_name={self.repo_name}',
|
||||||
|
f'repo_instructions={self.repo_instructions[:20]}...',
|
||||||
|
f'runtime_hosts={self.runtime_hosts}',
|
||||||
|
f'additional_agent_instructions={self.additional_agent_instructions[:20]}...',
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
fields.extend(
|
||||||
|
[
|
||||||
|
f'recall_type={self.recall_type}',
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if self.microagent_knowledge:
|
||||||
|
fields.extend(
|
||||||
|
[
|
||||||
|
f'microagent_knowledge={", ".join([m.name for m in self.microagent_knowledge])}',
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return f'**RecallObservation**\n{", ".join(fields)}'
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from openhands.events.action.agent import (
|
|||||||
AgentRejectAction,
|
AgentRejectAction,
|
||||||
AgentThinkAction,
|
AgentThinkAction,
|
||||||
ChangeAgentStateAction,
|
ChangeAgentStateAction,
|
||||||
|
RecallAction,
|
||||||
)
|
)
|
||||||
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
||||||
from openhands.events.action.commands import (
|
from openhands.events.action.commands import (
|
||||||
@@ -35,6 +36,7 @@ actions = (
|
|||||||
AgentFinishAction,
|
AgentFinishAction,
|
||||||
AgentRejectAction,
|
AgentRejectAction,
|
||||||
AgentDelegateAction,
|
AgentDelegateAction,
|
||||||
|
RecallAction,
|
||||||
ChangeAgentStateAction,
|
ChangeAgentStateAction,
|
||||||
MessageAction,
|
MessageAction,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
@@ -102,6 +103,8 @@ def event_to_dict(event: 'Event') -> dict:
|
|||||||
d['timestamp'] = d['timestamp'].isoformat()
|
d['timestamp'] = d['timestamp'].isoformat()
|
||||||
if key == 'source' and 'source' in d:
|
if key == 'source' and 'source' in d:
|
||||||
d['source'] = d['source'].value
|
d['source'] = d['source'].value
|
||||||
|
if key == 'recall_type' and 'recall_type' in d:
|
||||||
|
d['recall_type'] = d['recall_type'].value
|
||||||
if key == 'tool_call_metadata' and 'tool_call_metadata' in d:
|
if key == 'tool_call_metadata' and 'tool_call_metadata' in d:
|
||||||
d['tool_call_metadata'] = d['tool_call_metadata'].model_dump()
|
d['tool_call_metadata'] = d['tool_call_metadata'].model_dump()
|
||||||
if key == 'llm_metrics' and 'llm_metrics' in d:
|
if key == 'llm_metrics' and 'llm_metrics' in d:
|
||||||
@@ -119,7 +122,11 @@ def event_to_dict(event: 'Event') -> dict:
|
|||||||
# props is a dict whose values can include a complex object like an instance of a BaseModel subclass
|
# props is a dict whose values can include a complex object like an instance of a BaseModel subclass
|
||||||
# such as CmdOutputMetadata
|
# such as CmdOutputMetadata
|
||||||
# we serialize it along with the rest
|
# we serialize it along with the rest
|
||||||
d['extras'] = {k: _convert_pydantic_to_dict(v) for k, v in props.items()}
|
# we also handle the Enum conversion for RecallObservation
|
||||||
|
d['extras'] = {
|
||||||
|
k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v))
|
||||||
|
for k, v in props.items()
|
||||||
|
}
|
||||||
# Include success field for CmdOutputObservation
|
# Include success field for CmdOutputObservation
|
||||||
if hasattr(event, 'success'):
|
if hasattr(event, 'success'):
|
||||||
d['success'] = event.success
|
d['success'] = event.success
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
from openhands.events.event import RecallType
|
||||||
from openhands.events.observation.agent import (
|
from openhands.events.observation.agent import (
|
||||||
AgentCondensationObservation,
|
AgentCondensationObservation,
|
||||||
AgentStateChangedObservation,
|
AgentStateChangedObservation,
|
||||||
AgentThinkObservation,
|
AgentThinkObservation,
|
||||||
|
MicroagentKnowledge,
|
||||||
|
RecallObservation,
|
||||||
)
|
)
|
||||||
from openhands.events.observation.browse import BrowserOutputObservation
|
from openhands.events.observation.browse import BrowserOutputObservation
|
||||||
from openhands.events.observation.commands import (
|
from openhands.events.observation.commands import (
|
||||||
@@ -40,6 +43,7 @@ observations = (
|
|||||||
UserRejectObservation,
|
UserRejectObservation,
|
||||||
AgentCondensationObservation,
|
AgentCondensationObservation,
|
||||||
AgentThinkObservation,
|
AgentThinkObservation,
|
||||||
|
RecallObservation,
|
||||||
)
|
)
|
||||||
|
|
||||||
OBSERVATION_TYPE_TO_CLASS = {
|
OBSERVATION_TYPE_TO_CLASS = {
|
||||||
@@ -110,4 +114,18 @@ def observation_from_dict(observation: dict) -> Observation:
|
|||||||
else:
|
else:
|
||||||
extras['metadata'] = CmdOutputMetadata()
|
extras['metadata'] = CmdOutputMetadata()
|
||||||
|
|
||||||
|
if observation_class is RecallObservation:
|
||||||
|
# handle the Enum conversion
|
||||||
|
if 'recall_type' in extras:
|
||||||
|
extras['recall_type'] = RecallType(extras['recall_type'])
|
||||||
|
|
||||||
|
# convert dicts in microagent_knowledge to MicroagentKnowledge objects
|
||||||
|
if 'microagent_knowledge' in extras and isinstance(
|
||||||
|
extras['microagent_knowledge'], list
|
||||||
|
):
|
||||||
|
extras['microagent_knowledge'] = [
|
||||||
|
MicroagentKnowledge(**item) if isinstance(item, dict) else item
|
||||||
|
for item in extras['microagent_knowledge']
|
||||||
|
]
|
||||||
|
|
||||||
return observation_class(content=content, **extras)
|
return observation_class(content=content, **extras)
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class EventStreamSubscriber(str, Enum):
|
|||||||
RESOLVER = 'openhands_resolver'
|
RESOLVER = 'openhands_resolver'
|
||||||
SERVER = 'server'
|
SERVER = 'server'
|
||||||
RUNTIME = 'runtime'
|
RUNTIME = 'runtime'
|
||||||
|
MEMORY = 'memory'
|
||||||
MAIN = 'main'
|
MAIN = 'main'
|
||||||
TEST = 'test'
|
TEST = 'test'
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from typing import Any
|
|||||||
import httpx
|
import httpx
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from openhands.core.logger import openhands_logger as logger
|
||||||
from openhands.integrations.service_types import (
|
from openhands.integrations.service_types import (
|
||||||
AuthenticationError,
|
AuthenticationError,
|
||||||
GitService,
|
GitService,
|
||||||
@@ -15,7 +16,7 @@ from openhands.integrations.service_types import (
|
|||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
from openhands.utils.import_utils import get_impl
|
from openhands.utils.import_utils import get_impl
|
||||||
from openhands.core.logger import openhands_logger as logger
|
|
||||||
|
|
||||||
class GitHubService(GitService):
|
class GitHubService(GitService):
|
||||||
BASE_URL = 'https://api.github.com'
|
BASE_URL = 'https://api.github.com'
|
||||||
@@ -25,6 +26,7 @@ class GitHubService(GitService):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
|
external_auth_id: str | None = None,
|
||||||
external_auth_token: SecretStr | None = None,
|
external_auth_token: SecretStr | None = None,
|
||||||
token: SecretStr | None = None,
|
token: SecretStr | None = None,
|
||||||
external_token_manager: bool = False,
|
external_token_manager: bool = False,
|
||||||
|
|||||||
@@ -249,7 +249,8 @@ class LLM(RetryMixin, DebugMixin):
|
|||||||
|
|
||||||
# if we mocked function calling, and we have tools, convert the response back to function calling format
|
# if we mocked function calling, and we have tools, convert the response back to function calling format
|
||||||
if mock_function_calling and mock_fncall_tools is not None:
|
if mock_function_calling and mock_fncall_tools is not None:
|
||||||
assert len(resp.choices) == 1
|
logger.debug(f'Response choices: {len(resp.choices)}')
|
||||||
|
assert len(resp.choices) >= 1
|
||||||
non_fncall_response_message = resp.choices[0].message
|
non_fncall_response_message = resp.choices[0].message
|
||||||
fn_call_messages_with_response = (
|
fn_call_messages_with_response = (
|
||||||
convert_non_fncall_messages_to_fncall_messages(
|
convert_non_fncall_messages_to_fncall_messages(
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from litellm import ModelResponse
|
from litellm import ModelResponse
|
||||||
|
|
||||||
|
from openhands.core.config.agent_config import AgentConfig
|
||||||
from openhands.core.logger import openhands_logger as logger
|
from openhands.core.logger import openhands_logger as logger
|
||||||
from openhands.core.message import ImageContent, Message, TextContent
|
from openhands.core.message import ImageContent, Message, TextContent
|
||||||
from openhands.core.schema import ActionType
|
from openhands.core.schema import ActionType
|
||||||
@@ -16,7 +17,7 @@ from openhands.events.action import (
|
|||||||
IPythonRunCellAction,
|
IPythonRunCellAction,
|
||||||
MessageAction,
|
MessageAction,
|
||||||
)
|
)
|
||||||
from openhands.events.event import Event
|
from openhands.events.event import Event, RecallType
|
||||||
from openhands.events.observation import (
|
from openhands.events.observation import (
|
||||||
AgentCondensationObservation,
|
AgentCondensationObservation,
|
||||||
AgentDelegateObservation,
|
AgentDelegateObservation,
|
||||||
@@ -28,16 +29,21 @@ from openhands.events.observation import (
|
|||||||
IPythonRunCellObservation,
|
IPythonRunCellObservation,
|
||||||
UserRejectObservation,
|
UserRejectObservation,
|
||||||
)
|
)
|
||||||
|
from openhands.events.observation.agent import (
|
||||||
|
MicroagentKnowledge,
|
||||||
|
RecallObservation,
|
||||||
|
)
|
||||||
from openhands.events.observation.error import ErrorObservation
|
from openhands.events.observation.error import ErrorObservation
|
||||||
from openhands.events.observation.observation import Observation
|
from openhands.events.observation.observation import Observation
|
||||||
from openhands.events.serialization.event import truncate_content
|
from openhands.events.serialization.event import truncate_content
|
||||||
from openhands.utils.prompt import PromptManager
|
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo
|
||||||
|
|
||||||
|
|
||||||
class ConversationMemory:
|
class ConversationMemory:
|
||||||
"""Processes event history into a coherent conversation for the agent."""
|
"""Processes event history into a coherent conversation for the agent."""
|
||||||
|
|
||||||
def __init__(self, prompt_manager: PromptManager):
|
def __init__(self, config: AgentConfig, prompt_manager: PromptManager):
|
||||||
|
self.agent_config = config
|
||||||
self.prompt_manager = prompt_manager
|
self.prompt_manager = prompt_manager
|
||||||
|
|
||||||
def process_events(
|
def process_events(
|
||||||
@@ -46,23 +52,24 @@ class ConversationMemory:
|
|||||||
initial_messages: list[Message],
|
initial_messages: list[Message],
|
||||||
max_message_chars: int | None = None,
|
max_message_chars: int | None = None,
|
||||||
vision_is_active: bool = False,
|
vision_is_active: bool = False,
|
||||||
enable_som_visual_browsing: bool = False,
|
|
||||||
) -> list[Message]:
|
) -> list[Message]:
|
||||||
"""Process state history into a list of messages for the LLM.
|
"""Process state history into a list of messages for the LLM.
|
||||||
|
|
||||||
Ensures that tool call actions are processed correctly in function calling mode.
|
Ensures that tool call actions are processed correctly in function calling mode.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: The state containing the history of events to convert
|
condensed_history: The condensed history of events to convert
|
||||||
condensed_history: The condensed list of events to process
|
initial_messages: The initial messages to include in the conversation
|
||||||
initial_messages: The initial messages to include in the result
|
|
||||||
max_message_chars: The maximum number of characters in the content of an event included
|
max_message_chars: The maximum number of characters in the content of an event included
|
||||||
in the prompt to the LLM. Larger observations are truncated.
|
in the prompt to the LLM. Larger observations are truncated.
|
||||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
|
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
|
||||||
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
events = condensed_history
|
events = condensed_history
|
||||||
|
|
||||||
|
# log visual browsing status
|
||||||
|
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
|
||||||
|
|
||||||
# Process special events first (system prompts, etc.)
|
# Process special events first (system prompts, etc.)
|
||||||
messages = initial_messages
|
messages = initial_messages
|
||||||
|
|
||||||
@@ -70,7 +77,7 @@ class ConversationMemory:
|
|||||||
pending_tool_call_action_messages: dict[str, Message] = {}
|
pending_tool_call_action_messages: dict[str, Message] = {}
|
||||||
tool_call_id_to_message: dict[str, Message] = {}
|
tool_call_id_to_message: dict[str, Message] = {}
|
||||||
|
|
||||||
for event in events:
|
for i, event in enumerate(events):
|
||||||
# create a regular message from an event
|
# create a regular message from an event
|
||||||
if isinstance(event, Action):
|
if isinstance(event, Action):
|
||||||
messages_to_add = self._process_action(
|
messages_to_add = self._process_action(
|
||||||
@@ -84,7 +91,9 @@ class ConversationMemory:
|
|||||||
tool_call_id_to_message=tool_call_id_to_message,
|
tool_call_id_to_message=tool_call_id_to_message,
|
||||||
max_message_chars=max_message_chars,
|
max_message_chars=max_message_chars,
|
||||||
vision_is_active=vision_is_active,
|
vision_is_active=vision_is_active,
|
||||||
enable_som_visual_browsing=enable_som_visual_browsing,
|
enable_som_visual_browsing=self.agent_config.enable_som_visual_browsing,
|
||||||
|
current_index=i,
|
||||||
|
events=events,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unknown event type: {type(event)}')
|
raise ValueError(f'Unknown event type: {type(event)}')
|
||||||
@@ -270,6 +279,8 @@ class ConversationMemory:
|
|||||||
max_message_chars: int | None = None,
|
max_message_chars: int | None = None,
|
||||||
vision_is_active: bool = False,
|
vision_is_active: bool = False,
|
||||||
enable_som_visual_browsing: bool = False,
|
enable_som_visual_browsing: bool = False,
|
||||||
|
current_index: int = 0,
|
||||||
|
events: list[Event] | None = None,
|
||||||
) -> list[Message]:
|
) -> list[Message]:
|
||||||
"""Converts an observation into a message format that can be sent to the LLM.
|
"""Converts an observation into a message format that can be sent to the LLM.
|
||||||
|
|
||||||
@@ -291,6 +302,8 @@ class ConversationMemory:
|
|||||||
max_message_chars: The maximum number of characters in the content of an observation included in the prompt to the LLM
|
max_message_chars: The maximum number of characters in the content of an observation included in the prompt to the LLM
|
||||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included
|
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included
|
||||||
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model
|
enable_som_visual_browsing: Whether to enable visual browsing for the SOM model
|
||||||
|
current_index: The index of the current event in the events list (for deduplication)
|
||||||
|
events: The list of all events (for deduplication)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[Message]: A list containing the formatted message(s) for the observation.
|
list[Message]: A list containing the formatted message(s) for the observation.
|
||||||
@@ -372,6 +385,119 @@ class ConversationMemory:
|
|||||||
elif isinstance(obs, AgentCondensationObservation):
|
elif isinstance(obs, AgentCondensationObservation):
|
||||||
text = truncate_content(obs.content, max_message_chars)
|
text = truncate_content(obs.content, max_message_chars)
|
||||||
message = Message(role='user', content=[TextContent(text=text)])
|
message = Message(role='user', content=[TextContent(text=text)])
|
||||||
|
elif (
|
||||||
|
isinstance(obs, RecallObservation)
|
||||||
|
and self.agent_config.enable_prompt_extensions
|
||||||
|
):
|
||||||
|
if obs.recall_type == RecallType.WORKSPACE_CONTEXT:
|
||||||
|
# everything is optional, check if they are present
|
||||||
|
if obs.repo_name or obs.repo_directory:
|
||||||
|
repo_info = RepositoryInfo(
|
||||||
|
repo_name=obs.repo_name or '',
|
||||||
|
repo_directory=obs.repo_directory or '',
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
repo_info = None
|
||||||
|
|
||||||
|
if obs.runtime_hosts or obs.additional_agent_instructions:
|
||||||
|
runtime_info = RuntimeInfo(
|
||||||
|
available_hosts=obs.runtime_hosts,
|
||||||
|
additional_agent_instructions=obs.additional_agent_instructions,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
runtime_info = None
|
||||||
|
|
||||||
|
repo_instructions = (
|
||||||
|
obs.repo_instructions if obs.repo_instructions else ''
|
||||||
|
)
|
||||||
|
|
||||||
|
# Have some meaningful content before calling the template
|
||||||
|
has_repo_info = repo_info is not None and (
|
||||||
|
repo_info.repo_name or repo_info.repo_directory
|
||||||
|
)
|
||||||
|
has_runtime_info = runtime_info is not None and (
|
||||||
|
runtime_info.available_hosts
|
||||||
|
or runtime_info.additional_agent_instructions
|
||||||
|
)
|
||||||
|
has_repo_instructions = bool(repo_instructions.strip())
|
||||||
|
|
||||||
|
# Filter and process microagent knowledge
|
||||||
|
filtered_agents = []
|
||||||
|
if obs.microagent_knowledge:
|
||||||
|
# Exclude disabled microagents
|
||||||
|
filtered_agents = [
|
||||||
|
agent
|
||||||
|
for agent in obs.microagent_knowledge
|
||||||
|
if agent.name not in self.agent_config.disabled_microagents
|
||||||
|
]
|
||||||
|
|
||||||
|
has_microagent_knowledge = bool(filtered_agents)
|
||||||
|
|
||||||
|
# Generate appropriate content based on what is present
|
||||||
|
message_content = []
|
||||||
|
|
||||||
|
# Build the workspace context information
|
||||||
|
if has_repo_info or has_runtime_info or has_repo_instructions:
|
||||||
|
formatted_workspace_text = (
|
||||||
|
self.prompt_manager.build_workspace_context(
|
||||||
|
repository_info=repo_info,
|
||||||
|
runtime_info=runtime_info,
|
||||||
|
repo_instructions=repo_instructions,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
message_content.append(TextContent(text=formatted_workspace_text))
|
||||||
|
|
||||||
|
# Add microagent knowledge if present
|
||||||
|
if has_microagent_knowledge:
|
||||||
|
formatted_microagent_text = (
|
||||||
|
self.prompt_manager.build_microagent_info(
|
||||||
|
triggered_agents=filtered_agents,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
message_content.append(TextContent(text=formatted_microagent_text))
|
||||||
|
|
||||||
|
# Return the combined message if we have any content
|
||||||
|
if message_content:
|
||||||
|
message = Message(role='user', content=message_content)
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
elif obs.recall_type == RecallType.KNOWLEDGE:
|
||||||
|
# Use prompt manager to build the microagent info
|
||||||
|
# First, filter out agents that appear in earlier RecallObservations
|
||||||
|
filtered_agents = self._filter_agents_in_microagent_obs(
|
||||||
|
obs, current_index, events or []
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create and return a message if there is microagent knowledge to include
|
||||||
|
if filtered_agents:
|
||||||
|
# Exclude disabled microagents
|
||||||
|
filtered_agents = [
|
||||||
|
agent
|
||||||
|
for agent in filtered_agents
|
||||||
|
if agent.name not in self.agent_config.disabled_microagents
|
||||||
|
]
|
||||||
|
|
||||||
|
# Only proceed if we still have agents after filtering out disabled ones
|
||||||
|
if filtered_agents:
|
||||||
|
formatted_text = self.prompt_manager.build_microagent_info(
|
||||||
|
triggered_agents=filtered_agents,
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
Message(
|
||||||
|
role='user', content=[TextContent(text=formatted_text)]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Return empty list if no microagents to include or all were disabled
|
||||||
|
return []
|
||||||
|
elif (
|
||||||
|
isinstance(obs, RecallObservation)
|
||||||
|
and not self.agent_config.enable_prompt_extensions
|
||||||
|
):
|
||||||
|
# If prompt extensions are disabled, we don't add any additional info
|
||||||
|
# TODO: test this
|
||||||
|
return []
|
||||||
else:
|
else:
|
||||||
# If an observation message is not returned, it will cause an error
|
# If an observation message is not returned, it will cause an error
|
||||||
# when the LLM tries to return the next message
|
# when the LLM tries to return the next message
|
||||||
@@ -404,3 +530,51 @@ class ConversationMemory:
|
|||||||
-1
|
-1
|
||||||
].cache_prompt = True # Last item inside the message content
|
].cache_prompt = True # Last item inside the message content
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def _filter_agents_in_microagent_obs(
|
||||||
|
self, obs: RecallObservation, current_index: int, events: list[Event]
|
||||||
|
) -> list[MicroagentKnowledge]:
|
||||||
|
"""Filter out agents that appear in earlier RecallObservations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs: The current RecallObservation to filter
|
||||||
|
current_index: The index of the current event in the events list
|
||||||
|
events: The list of all events
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[MicroagentKnowledge]: The filtered list of microagent knowledge
|
||||||
|
"""
|
||||||
|
if obs.recall_type != RecallType.KNOWLEDGE:
|
||||||
|
return obs.microagent_knowledge
|
||||||
|
|
||||||
|
# For each agent in the current microagent observation, check if it appears in any earlier microagent observation
|
||||||
|
filtered_agents = []
|
||||||
|
for agent in obs.microagent_knowledge:
|
||||||
|
# Keep this agent if it doesn't appear in any earlier observation
|
||||||
|
# that is, if this is the first microagent observation with this microagent
|
||||||
|
if not self._has_agent_in_earlier_events(agent.name, current_index, events):
|
||||||
|
filtered_agents.append(agent)
|
||||||
|
|
||||||
|
return filtered_agents
|
||||||
|
|
||||||
|
def _has_agent_in_earlier_events(
    self, agent_name: str, current_index: int, events: list[Event]
) -> bool:
    """Tell whether a microagent name was already seen before ``current_index``.

    Args:
        agent_name: The name of the agent to look for
        current_index: The index of the current event in the events list; the
            event at this index itself is not inspected
        events: The list of all events

    Returns:
        bool: True if some RecallObservation in ``events[:current_index]`` lists a
        microagent with this name, False otherwise
    """
    # Every RecallObservation counts here, whatever its recall type, so the
    # WORKSPACE_CONTEXT observation is part of the search as well.
    earlier_recalls = (
        earlier
        for earlier in events[:current_index]
        if isinstance(earlier, RecallObservation)
    )
    return any(
        knowledge.name == agent_name
        for recall in earlier_recalls
        for knowledge in recall.microagent_knowledge
    )
|
||||||
|
|||||||
@@ -0,0 +1,292 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
import openhands
|
||||||
|
from openhands.core.logger import openhands_logger as logger
|
||||||
|
from openhands.events.action.agent import RecallAction
|
||||||
|
from openhands.events.event import Event, EventSource, RecallType
|
||||||
|
from openhands.events.observation.agent import (
|
||||||
|
MicroagentKnowledge,
|
||||||
|
RecallObservation,
|
||||||
|
)
|
||||||
|
from openhands.events.observation.empty import NullObservation
|
||||||
|
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||||
|
from openhands.microagent import (
|
||||||
|
BaseMicroAgent,
|
||||||
|
KnowledgeMicroAgent,
|
||||||
|
RepoMicroAgent,
|
||||||
|
load_microagents_from_dir,
|
||||||
|
)
|
||||||
|
from openhands.runtime.base import Runtime
|
||||||
|
from openhands.utils.prompt import RepositoryInfo, RuntimeInfo
|
||||||
|
|
||||||
|
# Location of the global microagents: the `microagents` directory that sits
# beside the `openhands` package directory (two `dirname` hops up from the
# package's `__file__`).
GLOBAL_MICROAGENTS_DIR = os.path.join(
    os.path.dirname(os.path.dirname(openhands.__file__)), 'microagents'
)
|
||||||
|
|
||||||
|
|
||||||
|
class Memory:
    """
    Memory is a component that listens to the EventStream for information retrieval actions
    (a RecallAction) and publishes observations with the content (such as RecallObservation).

    It keeps two registries of microagents, keyed by name: repo microagents, whose
    content becomes the repository instructions of the workspace context, and
    knowledge microagents, which are matched against the recall query through
    their ``match_trigger`` method.
    """

    # Subscriber/callback id used when registering on the event stream.
    sid: str
    # Stream the component subscribes to and publishes observations on.
    event_stream: EventStream
    # Optional callable invoked as ``status_callback(msg_type, id, message)``.
    status_callback: Callable | None
    # Event loop captured lazily by ``send_error_message``.
    loop: asyncio.AbstractEventLoop | None

    def __init__(
        self,
        event_stream: EventStream,
        sid: str,
        status_callback: Callable | None = None,
    ):
        """Subscribe to the event stream and load the global microagents.

        Args:
            event_stream: The stream to listen on and publish observations to
            sid: Identifier used for the subscription; a random UUID string is
                generated when it is empty
            status_callback: Optional callable used to report errors
        """
        self.event_stream = event_stream
        self.sid = sid if sid else str(uuid.uuid4())
        self.status_callback = status_callback
        self.loop = None

        self.event_stream.subscribe(
            EventStreamSubscriber.MEMORY,
            self.on_event,
            self.sid,
        )

        # Additional placeholders to store user workspace microagents
        self.repo_microagents: dict[str, RepoMicroAgent] = {}
        self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {}

        # Store repository / runtime info to send them to the templating later
        self.repository_info: RepositoryInfo | None = None
        self.runtime_info: RuntimeInfo | None = None

        # Load global microagents (Knowledge + Repo)
        # from typically OpenHands/microagents (i.e., the PUBLIC microagents)
        self._load_global_microagents()

    def on_event(self, event: Event):
        """Handle an event from the event stream.

        Synchronous entry point registered with the stream: it drives the
        asynchronous handler to completion on the loop returned by
        ``asyncio.get_event_loop()``.
        """
        asyncio.get_event_loop().run_until_complete(self._on_event(event))

    async def _on_event(self, event: Event):
        """Handle an event from the event stream asynchronously.

        Only a RecallAction whose source is the user is acted upon:

        - WORKSPACE_CONTEXT recall (on first user message): publish a
          RecallObservation with info about repo, runtime, instructions, etc.,
          including microagent knowledge if any.
        - KNOWLEDGE recall: publish a RecallObservation with the triggered
          microagents.

        When the handler has nothing to report, a NullObservation is published
        instead. Any exception is logged and reported through
        ``send_error_message``; in that case no observation is published.
        """
        try:
            if not isinstance(event, RecallAction):
                return
            if event.source != EventSource.USER:
                return

            if event.recall_type == RecallType.WORKSPACE_CONTEXT:
                logger.debug('Workspace context recall')
                observation = self._on_workspace_context_recall(event)
            elif event.recall_type == RecallType.KNOWLEDGE:
                # Handle knowledge recall (triggered microagents)
                logger.debug('Microagent knowledge recall')
                observation = self._on_microagent_recall(event)
            else:
                return

            if observation is None:
                observation = NullObservation(content='')

            # important: this will release the execution flow from waiting for the retrieval to complete
            observation._cause = event.id  # type: ignore[union-attr]

            self.event_stream.add_event(observation, EventSource.ENVIRONMENT)
        except Exception as e:
            error_str = f'Error: {str(e.__class__.__name__)}'
            logger.error(error_str)
            self.send_error_message('STATUS$ERROR_MEMORY', error_str)
            return

    def _on_workspace_context_recall(
        self, event: RecallAction
    ) -> RecallObservation | None:
        """Build the WORKSPACE_CONTEXT RecallObservation for a recall action.

        The observation gathers:
        - repository_info
        - runtime_info
        - repository_instructions
        - microagent_knowledge

        Args:
            event: The recall action; its ``query`` is matched against the
                knowledge microagents' triggers

        Returns:
            The observation, or None when there is no repository info, runtime
            info, repository instructions or matched microagent to report.
        """
        # More than one repo microagent is unexpected but not fatal: the loop
        # below joins their contents. Raising here would be caught by
        # `_on_event`, which then publishes no observation for the recall
        # action, so only warn.
        if len(self.repo_microagents) > 1:
            logger.warning(
                'Expecting at most one repo microagent, but found %d: %s; '
                'concatenating their instructions',
                len(self.repo_microagents),
                list(self.repo_microagents.keys()),
            )

        # Collect raw repository instructions, separated by a blank line
        repo_instructions = ''
        for microagent in self.repo_microagents.values():
            if repo_instructions:
                repo_instructions += '\n\n'
            repo_instructions += microagent.content

        # Find any matched microagents based on the query
        microagent_knowledge = self._find_microagent_knowledge(event.query)

        # Create observation only if we have anything
        if not (
            self.repository_info
            or self.runtime_info
            or repo_instructions
            or microagent_knowledge
        ):
            return None

        # Missing info objects and None fields fall back to empty values.
        return RecallObservation(
            recall_type=RecallType.WORKSPACE_CONTEXT,
            repo_name=self.repository_info.repo_name
            if self.repository_info and self.repository_info.repo_name is not None
            else '',
            repo_directory=self.repository_info.repo_directory
            if self.repository_info
            and self.repository_info.repo_directory is not None
            else '',
            repo_instructions=repo_instructions,
            runtime_hosts=self.runtime_info.available_hosts
            if self.runtime_info and self.runtime_info.available_hosts is not None
            else {},
            additional_agent_instructions=self.runtime_info.additional_agent_instructions
            if self.runtime_info
            and self.runtime_info.additional_agent_instructions is not None
            else '',
            microagent_knowledge=microagent_knowledge,
            content='Added workspace context',
        )

    def _on_microagent_recall(
        self,
        event: RecallAction,
    ) -> RecallObservation | None:
        """When a microagent action triggers microagents, create a RecallObservation with structured data.

        Args:
            event: The recall action whose ``query`` is searched for triggers

        Returns:
            A KNOWLEDGE RecallObservation, or None when no microagent matched.
        """
        # Find any matched microagents based on the query
        microagent_knowledge = self._find_microagent_knowledge(event.query)

        # Create observation if we have anything
        if microagent_knowledge:
            return RecallObservation(
                recall_type=RecallType.KNOWLEDGE,
                microagent_knowledge=microagent_knowledge,
                content='Retrieved knowledge from microagents',
            )
        return None

    def _find_microagent_knowledge(self, query: str) -> list[MicroagentKnowledge]:
        """Find microagent knowledge based on a query.

        Args:
            query: The query to search for microagent triggers

        Returns:
            A list of MicroagentKnowledge objects for matched triggers; empty
            when the query is empty or nothing matched
        """
        recalled_content: list[MicroagentKnowledge] = []

        # skip empty queries
        if not query:
            return recalled_content

        # Search for microagent triggers in the query
        for name, microagent in self.knowledge_microagents.items():
            trigger = microagent.match_trigger(query)
            if trigger:
                logger.info("Microagent '%s' triggered by keyword '%s'", name, trigger)
                recalled_content.append(
                    MicroagentKnowledge(
                        name=microagent.name,
                        trigger=trigger,
                        content=microagent.content,
                    )
                )
        return recalled_content

    def load_user_workspace_microagents(
        self, user_microagents: list[BaseMicroAgent]
    ) -> None:
        """
        This method loads microagents from a user's cloned repo or workspace directory.

        This is typically called from agent_session or setup once the workspace is cloned.

        Each microagent is stored under its name, replacing any entry already
        registered with that name; microagents that are neither knowledge nor
        repo microagents are ignored.
        """
        logger.info(
            'Loading user workspace microagents: %s', [m.name for m in user_microagents]
        )
        for user_microagent in user_microagents:
            if isinstance(user_microagent, KnowledgeMicroAgent):
                self.knowledge_microagents[user_microagent.name] = user_microagent
            elif isinstance(user_microagent, RepoMicroAgent):
                self.repo_microagents[user_microagent.name] = user_microagent

    def _load_global_microagents(self) -> None:
        """
        Loads microagents from the global microagents_dir

        Entries returned by the loader are kept only when they are instances of
        the expected microagent class for their registry.
        """
        repo_agents, knowledge_agents, _ = load_microagents_from_dir(
            GLOBAL_MICROAGENTS_DIR
        )
        for name, agent in knowledge_agents.items():
            if isinstance(agent, KnowledgeMicroAgent):
                self.knowledge_microagents[name] = agent
        for name, agent in repo_agents.items():
            if isinstance(agent, RepoMicroAgent):
                self.repo_microagents[name] = agent

    def set_repository_info(self, repo_name: str, repo_directory: str) -> None:
        """Store repository info so we can reference it in an observation.

        The stored info is cleared (set to None) when both values are empty.
        """
        if repo_name or repo_directory:
            self.repository_info = RepositoryInfo(repo_name, repo_directory)
        else:
            self.repository_info = None

    def set_runtime_info(self, runtime: Runtime) -> None:
        """Store runtime info (web hosts, ports, etc.).

        The stored info is cleared (set to None) when the runtime exposes
        neither web hosts nor additional agent instructions.
        """
        # e.g. { '127.0.0.1': 8080 }
        if runtime.web_hosts or runtime.additional_agent_instructions:
            self.runtime_info = RuntimeInfo(
                available_hosts=runtime.web_hosts,
                additional_agent_instructions=runtime.additional_agent_instructions,
            )
        else:
            self.runtime_info = None

    def send_error_message(self, message_id: str, message: str):
        """Sends an error message if the callback function was provided.

        The running event loop is captured on first use and reused afterwards;
        the status message is then scheduled on it. If no loop is running, the
        resulting RuntimeError is logged and the message is dropped.
        """
        if self.status_callback:
            try:
                if self.loop is None:
                    # NOTE(review): the loop is cached from whichever thread
                    # gets here first - confirm later calls can still use it.
                    self.loop = asyncio.get_running_loop()
                asyncio.run_coroutine_threadsafe(
                    self._send_status_message('error', message_id, message), self.loop
                )
            except RuntimeError as e:
                logger.error(
                    f'Error sending status message: {e.__class__.__name__}',
                    stack_info=False,
                )

    async def _send_status_message(self, msg_type: str, id: str, message: str):
        """Sends a status message to the client.

        Invokes ``status_callback(msg_type, id, message)`` when a callback is set.
        """
        if self.status_callback:
            self.status_callback(msg_type, id, message)
|
||||||
@@ -97,7 +97,7 @@ class Runtime(FileEditRuntimeMixin):
|
|||||||
status_callback: Callable | None = None,
|
status_callback: Callable | None = None,
|
||||||
attach_to_existing: bool = False,
|
attach_to_existing: bool = False,
|
||||||
headless_mode: bool = False,
|
headless_mode: bool = False,
|
||||||
github_user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
self.sid = sid
|
self.sid = sid
|
||||||
self.event_stream = event_stream
|
self.event_stream = event_stream
|
||||||
@@ -130,7 +130,7 @@ class Runtime(FileEditRuntimeMixin):
|
|||||||
self, enable_llm_editor=config.get_agent_config().codeact_enable_llm_editor
|
self, enable_llm_editor=config.get_agent_config().codeact_enable_llm_editor
|
||||||
)
|
)
|
||||||
|
|
||||||
self.github_user_id = github_user_id
|
self.user_id = user_id
|
||||||
|
|
||||||
def setup_initial_env(self) -> None:
|
def setup_initial_env(self) -> None:
|
||||||
if self.attach_to_existing:
|
if self.attach_to_existing:
|
||||||
@@ -220,9 +220,9 @@ class Runtime(FileEditRuntimeMixin):
|
|||||||
assert event.timeout is not None
|
assert event.timeout is not None
|
||||||
try:
|
try:
|
||||||
if isinstance(event, CmdRunAction):
|
if isinstance(event, CmdRunAction):
|
||||||
if self.github_user_id and '$GITHUB_TOKEN' in event.command:
|
if self.user_id and '$GITHUB_TOKEN' in event.command:
|
||||||
gh_client = GithubServiceImpl(
|
gh_client = GithubServiceImpl(
|
||||||
user_id=self.github_user_id, external_token_manager=True
|
external_auth_id=self.user_id, external_token_manager=True
|
||||||
)
|
)
|
||||||
token = await gh_client.get_latest_token()
|
token = await gh_client.get_latest_token()
|
||||||
if token:
|
if token:
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ class ActionExecutionClient(Runtime):
|
|||||||
status_callback: Any | None = None,
|
status_callback: Any | None = None,
|
||||||
attach_to_existing: bool = False,
|
attach_to_existing: bool = False,
|
||||||
headless_mode: bool = True,
|
headless_mode: bool = True,
|
||||||
github_user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
self.session = HttpSession()
|
self.session = HttpSession()
|
||||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||||
@@ -75,7 +75,7 @@ class ActionExecutionClient(Runtime):
|
|||||||
status_callback,
|
status_callback,
|
||||||
attach_to_existing,
|
attach_to_existing,
|
||||||
headless_mode,
|
headless_mode,
|
||||||
github_user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
@@ -45,7 +46,7 @@ class RemoteRuntime(ActionExecutionClient):
|
|||||||
status_callback: Callable | None = None,
|
status_callback: Callable | None = None,
|
||||||
attach_to_existing: bool = False,
|
attach_to_existing: bool = False,
|
||||||
headless_mode: bool = True,
|
headless_mode: bool = True,
|
||||||
github_user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
config,
|
config,
|
||||||
@@ -56,7 +57,7 @@ class RemoteRuntime(ActionExecutionClient):
|
|||||||
status_callback,
|
status_callback,
|
||||||
attach_to_existing,
|
attach_to_existing,
|
||||||
headless_mode,
|
headless_mode,
|
||||||
github_user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
if self.config.sandbox.api_key is None:
|
if self.config.sandbox.api_key is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -425,10 +426,11 @@ class RemoteRuntime(ActionExecutionClient):
|
|||||||
return self._send_action_server_request_impl(method, url, **kwargs)
|
return self._send_action_server_request_impl(method, url, **kwargs)
|
||||||
|
|
||||||
retry_decorator = tenacity.retry(
|
retry_decorator = tenacity.retry(
|
||||||
retry=tenacity.retry_if_exception_type(ConnectionError),
|
retry=tenacity.retry_if_exception_type(requests.ConnectionError),
|
||||||
stop=tenacity.stop_after_attempt(3)
|
stop=tenacity.stop_after_attempt(3)
|
||||||
| stop_if_should_exit()
|
| stop_if_should_exit()
|
||||||
| self._stop_if_closed,
|
| self._stop_if_closed,
|
||||||
|
before_sleep=tenacity.before_sleep_log(logger, logging.WARNING),
|
||||||
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
|
wait=tenacity.wait_exponential(multiplier=1, min=4, max=60),
|
||||||
)
|
)
|
||||||
return retry_decorator(self._send_action_server_request_impl)(
|
return retry_decorator(self._send_action_server_request_impl)(
|
||||||
|
|||||||
@@ -46,7 +46,12 @@ class ConversationManager(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def join_conversation(
|
async def join_conversation(
|
||||||
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
|
self,
|
||||||
|
sid: str,
|
||||||
|
connection_id: str,
|
||||||
|
settings: Settings,
|
||||||
|
user_id: str | None,
|
||||||
|
github_user_id: str | None,
|
||||||
) -> EventStream | None:
|
) -> EventStream | None:
|
||||||
"""Join a conversation and return its event stream."""
|
"""Join a conversation and return its event stream."""
|
||||||
|
|
||||||
@@ -74,6 +79,7 @@ class ConversationManager(ABC):
|
|||||||
settings: Settings,
|
settings: Settings,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
initial_user_msg: MessageAction | None = None,
|
initial_user_msg: MessageAction | None = None,
|
||||||
|
github_user_id: str | None = None,
|
||||||
) -> EventStream:
|
) -> EventStream:
|
||||||
"""Start an event loop if one is not already running"""
|
"""Start an event loop if one is not already running"""
|
||||||
|
|
||||||
|
|||||||
@@ -106,7 +106,12 @@ class StandaloneConversationManager(ConversationManager):
|
|||||||
return c
|
return c
|
||||||
|
|
||||||
async def join_conversation(
|
async def join_conversation(
|
||||||
self, sid: str, connection_id: str, settings: Settings, user_id: str | None
|
self,
|
||||||
|
sid: str,
|
||||||
|
connection_id: str,
|
||||||
|
settings: Settings,
|
||||||
|
user_id: str | None,
|
||||||
|
github_user_id: str | None,
|
||||||
):
|
):
|
||||||
logger.info(
|
logger.info(
|
||||||
f'join_conversation:{sid}:{connection_id}',
|
f'join_conversation:{sid}:{connection_id}',
|
||||||
@@ -116,7 +121,9 @@ class StandaloneConversationManager(ConversationManager):
|
|||||||
self._local_connection_id_to_session_id[connection_id] = sid
|
self._local_connection_id_to_session_id[connection_id] = sid
|
||||||
event_stream = await self._get_event_stream(sid)
|
event_stream = await self._get_event_stream(sid)
|
||||||
if not event_stream:
|
if not event_stream:
|
||||||
return await self.maybe_start_agent_loop(sid, settings, user_id)
|
return await self.maybe_start_agent_loop(
|
||||||
|
sid, settings, user_id, github_user_id=github_user_id
|
||||||
|
)
|
||||||
for event in event_stream.get_events(reverse=True):
|
for event in event_stream.get_events(reverse=True):
|
||||||
if isinstance(event, AgentStateChangedObservation):
|
if isinstance(event, AgentStateChangedObservation):
|
||||||
if event.agent_state in (
|
if event.agent_state in (
|
||||||
@@ -187,14 +194,18 @@ class StandaloneConversationManager(ConversationManager):
|
|||||||
logger.error('error_cleaning_stale')
|
logger.error('error_cleaning_stale')
|
||||||
await asyncio.sleep(_CLEANUP_INTERVAL)
|
await asyncio.sleep(_CLEANUP_INTERVAL)
|
||||||
|
|
||||||
async def _get_conversation_store(self, user_id: str | None) -> ConversationStore:
|
async def _get_conversation_store(
|
||||||
|
self, user_id: str | None, github_user_id: str | None
|
||||||
|
) -> ConversationStore:
|
||||||
conversation_store_class = self._conversation_store_class
|
conversation_store_class = self._conversation_store_class
|
||||||
if not conversation_store_class:
|
if not conversation_store_class:
|
||||||
self._conversation_store_class = conversation_store_class = get_impl(
|
self._conversation_store_class = conversation_store_class = get_impl(
|
||||||
ConversationStore, # type: ignore
|
ConversationStore, # type: ignore
|
||||||
self.server_config.conversation_store_class,
|
self.server_config.conversation_store_class,
|
||||||
)
|
)
|
||||||
store = await conversation_store_class.get_instance(self.config, user_id)
|
store = await conversation_store_class.get_instance(
|
||||||
|
self.config, user_id, github_user_id
|
||||||
|
)
|
||||||
return store
|
return store
|
||||||
|
|
||||||
async def get_running_agent_loops(
|
async def get_running_agent_loops(
|
||||||
@@ -243,6 +254,7 @@ class StandaloneConversationManager(ConversationManager):
|
|||||||
settings: Settings,
|
settings: Settings,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
initial_user_msg: MessageAction | None = None,
|
initial_user_msg: MessageAction | None = None,
|
||||||
|
github_user_id: str | None = None,
|
||||||
) -> EventStream:
|
) -> EventStream:
|
||||||
logger.info(f'maybe_start_agent_loop:{sid}', extra={'session_id': sid})
|
logger.info(f'maybe_start_agent_loop:{sid}', extra={'session_id': sid})
|
||||||
session: Session | None = None
|
session: Session | None = None
|
||||||
@@ -256,7 +268,9 @@ class StandaloneConversationManager(ConversationManager):
|
|||||||
extra={'session_id': sid, 'user_id': user_id},
|
extra={'session_id': sid, 'user_id': user_id},
|
||||||
)
|
)
|
||||||
# Get the conversations sorted (oldest first)
|
# Get the conversations sorted (oldest first)
|
||||||
conversation_store = await self._get_conversation_store(user_id)
|
conversation_store = await self._get_conversation_store(
|
||||||
|
user_id, github_user_id
|
||||||
|
)
|
||||||
conversations = await conversation_store.get_all_metadata(response_ids)
|
conversations = await conversation_store.get_all_metadata(response_ids)
|
||||||
conversations.sort(key=_last_updated_at_key, reverse=True)
|
conversations.sort(key=_last_updated_at_key, reverse=True)
|
||||||
|
|
||||||
@@ -277,7 +291,9 @@ class StandaloneConversationManager(ConversationManager):
|
|||||||
try:
|
try:
|
||||||
session.agent_session.event_stream.subscribe(
|
session.agent_session.event_stream.subscribe(
|
||||||
EventStreamSubscriber.SERVER,
|
EventStreamSubscriber.SERVER,
|
||||||
self._create_conversation_update_callback(user_id, sid),
|
self._create_conversation_update_callback(
|
||||||
|
user_id, github_user_id, sid
|
||||||
|
),
|
||||||
UPDATED_AT_CALLBACK_ID,
|
UPDATED_AT_CALLBACK_ID,
|
||||||
)
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@@ -374,22 +390,23 @@ class StandaloneConversationManager(ConversationManager):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _create_conversation_update_callback(
|
def _create_conversation_update_callback(
|
||||||
self, user_id: str | None, conversation_id: str
|
self, user_id: str | None, github_user_id: str | None, conversation_id: str
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
def callback(*args, **kwargs):
|
def callback(*args, **kwargs):
|
||||||
call_async_from_sync(
|
call_async_from_sync(
|
||||||
self._update_timestamp_for_conversation,
|
self._update_timestamp_for_conversation,
|
||||||
GENERAL_TIMEOUT,
|
GENERAL_TIMEOUT,
|
||||||
user_id,
|
user_id,
|
||||||
|
github_user_id,
|
||||||
conversation_id,
|
conversation_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
async def _update_timestamp_for_conversation(
|
async def _update_timestamp_for_conversation(
|
||||||
self, user_id: str, conversation_id: str
|
self, user_id: str, github_user_id: str, conversation_id: str
|
||||||
):
|
):
|
||||||
conversation_store = await self._get_conversation_store(user_id)
|
conversation_store = await self._get_conversation_store(user_id, github_user_id)
|
||||||
conversation = await conversation_store.get_metadata(conversation_id)
|
conversation = await conversation_store.get_metadata(conversation_id)
|
||||||
conversation.last_updated_at = datetime.now(timezone.utc)
|
conversation.last_updated_at = datetime.now(timezone.utc)
|
||||||
await conversation_store.save_metadata(conversation)
|
await conversation_store.save_metadata(conversation)
|
||||||
|
|||||||
@@ -6,10 +6,14 @@ from openhands.core.logger import openhands_logger as logger
|
|||||||
from openhands.events.action import (
|
from openhands.events.action import (
|
||||||
NullAction,
|
NullAction,
|
||||||
)
|
)
|
||||||
|
from openhands.events.action.agent import RecallAction
|
||||||
from openhands.events.observation import (
|
from openhands.events.observation import (
|
||||||
NullObservation,
|
NullObservation,
|
||||||
)
|
)
|
||||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
from openhands.events.observation.agent import (
|
||||||
|
AgentStateChangedObservation,
|
||||||
|
RecallObservation,
|
||||||
|
)
|
||||||
from openhands.events.serialization import event_to_dict
|
from openhands.events.serialization import event_to_dict
|
||||||
from openhands.events.stream import AsyncEventStreamWrapper
|
from openhands.events.stream import AsyncEventStreamWrapper
|
||||||
from openhands.server.shared import (
|
from openhands.server.shared import (
|
||||||
@@ -35,7 +39,9 @@ async def connect(connection_id: str, environ):
|
|||||||
|
|
||||||
cookies_str = environ.get('HTTP_COOKIE', '')
|
cookies_str = environ.get('HTTP_COOKIE', '')
|
||||||
conversation_validator = ConversationValidatorImpl()
|
conversation_validator = ConversationValidatorImpl()
|
||||||
user_id = await conversation_validator.validate(conversation_id, cookies_str)
|
user_id, github_user_id = await conversation_validator.validate(
|
||||||
|
conversation_id, cookies_str
|
||||||
|
)
|
||||||
|
|
||||||
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
|
||||||
settings = await settings_store.load()
|
settings = await settings_store.load()
|
||||||
@@ -46,7 +52,7 @@ async def connect(connection_id: str, environ):
|
|||||||
)
|
)
|
||||||
|
|
||||||
event_stream = await conversation_manager.join_conversation(
|
event_stream = await conversation_manager.join_conversation(
|
||||||
conversation_id, connection_id, settings, user_id
|
conversation_id, connection_id, settings, user_id, github_user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
agent_state_changed = None
|
agent_state_changed = None
|
||||||
@@ -54,10 +60,7 @@ async def connect(connection_id: str, environ):
|
|||||||
async for event in async_stream:
|
async for event in async_stream:
|
||||||
if isinstance(
|
if isinstance(
|
||||||
event,
|
event,
|
||||||
(
|
(NullAction, NullObservation, RecallAction, RecallObservation),
|
||||||
NullAction,
|
|
||||||
NullObservation,
|
|
||||||
),
|
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
elif isinstance(event, AgentStateChangedObservation):
|
elif isinstance(event, AgentStateChangedObservation):
|
||||||
|
|||||||
@@ -10,7 +10,12 @@ from openhands.events.action.message import MessageAction
|
|||||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||||
from openhands.integrations.provider import ProviderType
|
from openhands.integrations.provider import ProviderType
|
||||||
from openhands.runtime import get_runtime_cls
|
from openhands.runtime import get_runtime_cls
|
||||||
from openhands.server.auth import get_provider_tokens, get_access_token, get_github_user_id
|
from openhands.server.auth import (
|
||||||
|
get_access_token,
|
||||||
|
get_github_user_id,
|
||||||
|
get_provider_tokens,
|
||||||
|
get_user_id,
|
||||||
|
)
|
||||||
from openhands.server.data_models.conversation_info import ConversationInfo
|
from openhands.server.data_models.conversation_info import ConversationInfo
|
||||||
from openhands.server.data_models.conversation_info_result_set import (
|
from openhands.server.data_models.conversation_info_result_set import (
|
||||||
ConversationInfoResultSet,
|
ConversationInfoResultSet,
|
||||||
@@ -73,12 +78,12 @@ async def _create_new_conversation(
|
|||||||
logger.warn('Settings not present, not starting conversation')
|
logger.warn('Settings not present, not starting conversation')
|
||||||
raise MissingSettingsError('Settings not found')
|
raise MissingSettingsError('Settings not found')
|
||||||
|
|
||||||
session_init_args['github_token'] = token or SecretStr('')
|
session_init_args['provider_token'] = token
|
||||||
session_init_args['selected_repository'] = selected_repository
|
session_init_args['selected_repository'] = selected_repository
|
||||||
session_init_args['selected_branch'] = selected_branch
|
session_init_args['selected_branch'] = selected_branch
|
||||||
conversation_init_data = ConversationInitData(**session_init_args)
|
conversation_init_data = ConversationInitData(**session_init_args)
|
||||||
logger.info('Loading conversation store')
|
logger.info('Loading conversation store')
|
||||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
conversation_store = await ConversationStoreImpl.get_instance(config, user_id, None)
|
||||||
logger.info('Conversation store loaded')
|
logger.info('Conversation store loaded')
|
||||||
|
|
||||||
conversation_id = uuid.uuid4().hex
|
conversation_id = uuid.uuid4().hex
|
||||||
@@ -100,7 +105,8 @@ async def _create_new_conversation(
|
|||||||
ConversationMetadata(
|
ConversationMetadata(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
title=conversation_title,
|
title=conversation_title,
|
||||||
github_user_id=user_id,
|
user_id=user_id,
|
||||||
|
github_user_id=None,
|
||||||
selected_repository=selected_repository,
|
selected_repository=selected_repository,
|
||||||
selected_branch=selected_branch,
|
selected_branch=selected_branch,
|
||||||
)
|
)
|
||||||
@@ -122,7 +128,10 @@ async def _create_new_conversation(
|
|||||||
image_urls=image_urls or [],
|
image_urls=image_urls or [],
|
||||||
)
|
)
|
||||||
await conversation_manager.maybe_start_agent_loop(
|
await conversation_manager.maybe_start_agent_loop(
|
||||||
conversation_id, conversation_init_data, user_id, initial_message_action
|
conversation_id,
|
||||||
|
conversation_init_data,
|
||||||
|
user_id,
|
||||||
|
initial_user_msg=initial_message_action,
|
||||||
)
|
)
|
||||||
logger.info(f'Finished initializing conversation {conversation_id}')
|
logger.info(f'Finished initializing conversation {conversation_id}')
|
||||||
|
|
||||||
@@ -158,7 +167,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
|
|||||||
try:
|
try:
|
||||||
# Create conversation with initial message
|
# Create conversation with initial message
|
||||||
conversation_id = await _create_new_conversation(
|
conversation_id = await _create_new_conversation(
|
||||||
user_id,
|
get_user_id(request),
|
||||||
github_token,
|
github_token,
|
||||||
selected_repository,
|
selected_repository,
|
||||||
selected_branch,
|
selected_branch,
|
||||||
@@ -197,7 +206,7 @@ async def search_conversations(
|
|||||||
limit: int = 20,
|
limit: int = 20,
|
||||||
) -> ConversationInfoResultSet:
|
) -> ConversationInfoResultSet:
|
||||||
conversation_store = await ConversationStoreImpl.get_instance(
|
conversation_store = await ConversationStoreImpl.get_instance(
|
||||||
config, get_github_user_id(request)
|
config, get_user_id(request), get_github_user_id(request)
|
||||||
)
|
)
|
||||||
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
|
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
|
||||||
|
|
||||||
@@ -216,7 +225,7 @@ async def search_conversations(
|
|||||||
conversation.conversation_id for conversation in filtered_results
|
conversation.conversation_id for conversation in filtered_results
|
||||||
)
|
)
|
||||||
running_conversations = await conversation_manager.get_running_agent_loops(
|
running_conversations = await conversation_manager.get_running_agent_loops(
|
||||||
get_github_user_id(request), set(conversation_ids)
|
get_user_id(request), set(conversation_ids)
|
||||||
)
|
)
|
||||||
result = ConversationInfoResultSet(
|
result = ConversationInfoResultSet(
|
||||||
results=await wait_all(
|
results=await wait_all(
|
||||||
@@ -236,7 +245,7 @@ async def get_conversation(
|
|||||||
conversation_id: str, request: Request
|
conversation_id: str, request: Request
|
||||||
) -> ConversationInfo | None:
|
) -> ConversationInfo | None:
|
||||||
conversation_store = await ConversationStoreImpl.get_instance(
|
conversation_store = await ConversationStoreImpl.get_instance(
|
||||||
config, get_github_user_id(request)
|
config, get_user_id(request), get_github_user_id(request)
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
metadata = await conversation_store.get_metadata(conversation_id)
|
metadata = await conversation_store.get_metadata(conversation_id)
|
||||||
@@ -252,7 +261,7 @@ async def update_conversation(
|
|||||||
request: Request, conversation_id: str, title: str = Body(embed=True)
|
request: Request, conversation_id: str, title: str = Body(embed=True)
|
||||||
) -> bool:
|
) -> bool:
|
||||||
conversation_store = await ConversationStoreImpl.get_instance(
|
conversation_store = await ConversationStoreImpl.get_instance(
|
||||||
config, get_github_user_id(request)
|
config, get_user_id(request), get_github_user_id(request)
|
||||||
)
|
)
|
||||||
metadata = await conversation_store.get_metadata(conversation_id)
|
metadata = await conversation_store.get_metadata(conversation_id)
|
||||||
if not metadata:
|
if not metadata:
|
||||||
@@ -268,7 +277,7 @@ async def delete_conversation(
|
|||||||
request: Request,
|
request: Request,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
conversation_store = await ConversationStoreImpl.get_instance(
|
conversation_store = await ConversationStoreImpl.get_instance(
|
||||||
config, get_github_user_id(request)
|
config, get_user_id(request), get_github_user_id(request)
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
await conversation_store.get_metadata(conversation_id)
|
await conversation_store.get_metadata(conversation_id)
|
||||||
|
|||||||
@@ -90,30 +90,38 @@ async def store_settings(
|
|||||||
existing_settings.user_consents_to_analytics
|
existing_settings.user_consents_to_analytics
|
||||||
)
|
)
|
||||||
|
|
||||||
if existing_settings.secrets_store:
|
if settings.unset_github_token:
|
||||||
existing_providers = [
|
|
||||||
provider.value
|
|
||||||
for provider in existing_settings.secrets_store.provider_tokens
|
|
||||||
]
|
|
||||||
|
|
||||||
# Merge incoming settings store with the existing one
|
|
||||||
for provider, token_value in settings.provider_tokens.items():
|
|
||||||
if provider in existing_providers and not token_value:
|
|
||||||
provider_type = ProviderType(provider)
|
|
||||||
existing_token = (
|
|
||||||
existing_settings.secrets_store.provider_tokens.get(
|
|
||||||
provider_type
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if existing_token and existing_token.token:
|
|
||||||
settings.provider_tokens[provider] = (
|
|
||||||
existing_token.token.get_secret_value()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Merge provider tokens with existing ones
|
|
||||||
if settings.unset_github_token: # Only merge if not unsetting tokens
|
|
||||||
settings.secrets_store.provider_tokens = {}
|
settings.secrets_store.provider_tokens = {}
|
||||||
settings.provider_tokens = {}
|
settings.provider_tokens = {}
|
||||||
|
else: # Only merge if not unsetting tokens
|
||||||
|
if settings.provider_tokens:
|
||||||
|
if existing_settings.secrets_store:
|
||||||
|
existing_providers = [
|
||||||
|
provider.value
|
||||||
|
for provider in existing_settings.secrets_store.provider_tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
# Merge incoming settings store with the existing one
|
||||||
|
for provider, token_value in settings.provider_tokens.items():
|
||||||
|
if provider in existing_providers and not token_value:
|
||||||
|
provider_type = ProviderType(provider)
|
||||||
|
existing_token = (
|
||||||
|
existing_settings.secrets_store.provider_tokens.get(
|
||||||
|
provider_type
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if existing_token and existing_token.token:
|
||||||
|
settings.provider_tokens[provider] = (
|
||||||
|
existing_token.token.get_secret_value()
|
||||||
|
)
|
||||||
|
else: # nothing passed in means keep current settings
|
||||||
|
provider_tokens = existing_settings.secrets_store.provider_tokens
|
||||||
|
settings.provider_tokens = {
|
||||||
|
provider.value: data.token.get_secret_value()
|
||||||
|
if data.token
|
||||||
|
else None
|
||||||
|
for provider, data in provider_tokens.items()
|
||||||
|
}
|
||||||
|
|
||||||
# Update sandbox config with new settings
|
# Update sandbox config with new settings
|
||||||
if settings.remote_runtime_resource_factor is not None:
|
if settings.remote_runtime_resource_factor is not None:
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ from openhands.core.schema.agent import AgentState
|
|||||||
from openhands.events.action import ChangeAgentStateAction, MessageAction
|
from openhands.events.action import ChangeAgentStateAction, MessageAction
|
||||||
from openhands.events.event import EventSource
|
from openhands.events.event import EventSource
|
||||||
from openhands.events.stream import EventStream
|
from openhands.events.stream import EventStream
|
||||||
from openhands.microagent import BaseMicroAgent
|
from openhands.memory.memory import Memory
|
||||||
|
from openhands.microagent.microagent import BaseMicroAgent
|
||||||
from openhands.runtime import get_runtime_cls
|
from openhands.runtime import get_runtime_cls
|
||||||
from openhands.runtime.base import Runtime
|
from openhands.runtime.base import Runtime
|
||||||
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
|
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
|
||||||
@@ -52,7 +53,7 @@ class AgentSession:
|
|||||||
sid: str,
|
sid: str,
|
||||||
file_store: FileStore,
|
file_store: FileStore,
|
||||||
status_callback: Callable | None = None,
|
status_callback: Callable | None = None,
|
||||||
github_user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
):
|
):
|
||||||
"""Initializes a new instance of the Session class
|
"""Initializes a new instance of the Session class
|
||||||
|
|
||||||
@@ -65,9 +66,9 @@ class AgentSession:
|
|||||||
self.event_stream = EventStream(sid, file_store)
|
self.event_stream = EventStream(sid, file_store)
|
||||||
self.file_store = file_store
|
self.file_store = file_store
|
||||||
self._status_callback = status_callback
|
self._status_callback = status_callback
|
||||||
self.github_user_id = github_user_id
|
self.user_id = user_id
|
||||||
self.logger = OpenHandsLoggerAdapter(
|
self.logger = OpenHandsLoggerAdapter(
|
||||||
extra={'session_id': sid, 'user_id': github_user_id}
|
extra={'session_id': sid, 'user_id': user_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
async def start(
|
async def start(
|
||||||
@@ -126,6 +127,15 @@ class AgentSession:
|
|||||||
agent_to_llm_config=agent_to_llm_config,
|
agent_to_llm_config=agent_to_llm_config,
|
||||||
agent_configs=agent_configs,
|
agent_configs=agent_configs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
repo_directory = None
|
||||||
|
if self.runtime and runtime_connected and selected_repository:
|
||||||
|
repo_directory = selected_repository.split('/')[-1]
|
||||||
|
self.memory = await self._create_memory(
|
||||||
|
selected_repository=selected_repository,
|
||||||
|
repo_directory=repo_directory,
|
||||||
|
)
|
||||||
|
|
||||||
if github_token:
|
if github_token:
|
||||||
self.event_stream.set_secrets(
|
self.event_stream.set_secrets(
|
||||||
{
|
{
|
||||||
@@ -231,7 +241,7 @@ class AgentSession:
|
|||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if runtime_cls == RemoteRuntime:
|
if runtime_cls == RemoteRuntime:
|
||||||
kwargs['github_user_id'] = self.github_user_id
|
kwargs['user_id'] = self.user_id
|
||||||
|
|
||||||
self.runtime = runtime_cls(
|
self.runtime = runtime_cls(
|
||||||
config=config,
|
config=config,
|
||||||
@@ -260,26 +270,14 @@ class AgentSession:
|
|||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
repo_directory = None
|
|
||||||
if selected_repository:
|
if selected_repository:
|
||||||
repo_directory = await call_sync_from_async(
|
await call_sync_from_async(
|
||||||
self.runtime.clone_repo,
|
self.runtime.clone_repo,
|
||||||
github_token,
|
github_token,
|
||||||
selected_repository,
|
selected_repository,
|
||||||
selected_branch,
|
selected_branch,
|
||||||
)
|
)
|
||||||
|
|
||||||
if agent.prompt_manager:
|
|
||||||
agent.prompt_manager.set_runtime_info(self.runtime)
|
|
||||||
microagents: list[BaseMicroAgent] = await call_sync_from_async(
|
|
||||||
self.runtime.get_microagents_from_selected_repo, selected_repository
|
|
||||||
)
|
|
||||||
agent.prompt_manager.load_microagents(microagents)
|
|
||||||
if selected_repository and repo_directory:
|
|
||||||
agent.prompt_manager.set_repository_info(
|
|
||||||
selected_repository, repo_directory
|
|
||||||
)
|
|
||||||
|
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
|
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
|
||||||
)
|
)
|
||||||
@@ -342,6 +340,29 @@ class AgentSession:
|
|||||||
|
|
||||||
return controller
|
return controller
|
||||||
|
|
||||||
|
async def _create_memory(
|
||||||
|
self, selected_repository: str | None, repo_directory: str | None
|
||||||
|
) -> Memory:
|
||||||
|
memory = Memory(
|
||||||
|
event_stream=self.event_stream,
|
||||||
|
sid=self.sid,
|
||||||
|
status_callback=self._status_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.runtime:
|
||||||
|
# sets available hosts and other runtime info
|
||||||
|
memory.set_runtime_info(self.runtime)
|
||||||
|
|
||||||
|
# loads microagents from repo/.openhands/microagents
|
||||||
|
microagents: list[BaseMicroAgent] = await call_sync_from_async(
|
||||||
|
self.runtime.get_microagents_from_selected_repo, selected_repository
|
||||||
|
)
|
||||||
|
memory.load_user_workspace_microagents(microagents)
|
||||||
|
|
||||||
|
if selected_repository and repo_directory:
|
||||||
|
memory.set_repository_info(selected_repository, repo_directory)
|
||||||
|
return memory
|
||||||
|
|
||||||
def _maybe_restore_state(self) -> State | None:
|
def _maybe_restore_state(self) -> State | None:
|
||||||
"""Helper method to handle state restore logic."""
|
"""Helper method to handle state restore logic."""
|
||||||
restored_state = None
|
restored_state = None
|
||||||
|
|||||||
@@ -8,6 +8,6 @@ class ConversationInitData(Settings):
|
|||||||
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
|
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
github_token: SecretStr | None = Field(default=None)
|
provider_token: SecretStr | None = Field(default=None)
|
||||||
selected_repository: str | None = Field(default=None)
|
selected_repository: str | None = Field(default=None)
|
||||||
selected_branch: str | None = Field(default=None)
|
selected_branch: str | None = Field(default=None)
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ class Session:
|
|||||||
sid,
|
sid,
|
||||||
file_store,
|
file_store,
|
||||||
status_callback=self.queue_status_message,
|
status_callback=self.queue_status_message,
|
||||||
github_user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
self.agent_session.event_stream.subscribe(
|
self.agent_session.event_stream.subscribe(
|
||||||
EventStreamSubscriber.SERVER, self.on_event, self.sid
|
EventStreamSubscriber.SERVER, self.on_event, self.sid
|
||||||
@@ -123,11 +123,11 @@ class Session:
|
|||||||
|
|
||||||
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
agent = Agent.get_cls(agent_cls)(llm, agent_config)
|
||||||
|
|
||||||
github_token = None
|
provider_token = None
|
||||||
selected_repository = None
|
selected_repository = None
|
||||||
selected_branch = None
|
selected_branch = None
|
||||||
if isinstance(settings, ConversationInitData):
|
if isinstance(settings, ConversationInitData):
|
||||||
github_token = settings.github_token
|
provider_token = settings.provider_token
|
||||||
selected_repository = settings.selected_repository
|
selected_repository = settings.selected_repository
|
||||||
selected_branch = settings.selected_branch
|
selected_branch = settings.selected_branch
|
||||||
|
|
||||||
@@ -140,7 +140,7 @@ class Session:
|
|||||||
max_budget_per_task=self.config.max_budget_per_task,
|
max_budget_per_task=self.config.max_budget_per_task,
|
||||||
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
|
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
|
||||||
agent_configs=self.config.get_agent_configs(),
|
agent_configs=self.config.get_agent_configs(),
|
||||||
github_token=github_token,
|
github_token=provider_token,
|
||||||
selected_repository=selected_repository,
|
selected_repository=selected_repository,
|
||||||
selected_branch=selected_branch,
|
selected_branch=selected_branch,
|
||||||
initial_message=initial_message,
|
initial_message=initial_message,
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class Settings(BaseModel):
|
|||||||
if context and context.get('expose_secrets', False):
|
if context and context.get('expose_secrets', False):
|
||||||
return llm_api_key.get_secret_value()
|
return llm_api_key.get_secret_value()
|
||||||
|
|
||||||
return pydantic_encoder(llm_api_key)
|
return pydantic_encoder(llm_api_key) if llm_api_key else None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_token_value(
|
def _convert_token_value(
|
||||||
|
|||||||
@@ -12,25 +12,36 @@ from openhands.utils.async_utils import wait_all
|
|||||||
|
|
||||||
|
|
||||||
class ConversationStore(ABC):
|
class ConversationStore(ABC):
|
||||||
"""
|
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
|
||||||
Storage for conversation metadata. May or may not support multiple users depending on the environment
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def save_metadata(self, metadata: ConversationMetadata) -> None:
|
async def save_metadata(self, metadata: ConversationMetadata) -> None:
|
||||||
"""Store conversation metadata"""
|
"""Store conversation metadata."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
|
async def get_metadata(self, conversation_id: str) -> ConversationMetadata:
|
||||||
"""Load conversation metadata"""
|
"""Load conversation metadata."""
|
||||||
|
|
||||||
|
async def validate_metadata(
|
||||||
|
self, conversation_id: str, user_id: str, github_user_id: str
|
||||||
|
) -> bool:
|
||||||
|
"""Validate that conversation belongs to the current user."""
|
||||||
|
# TODO: remove github_user_id after transition to Keycloak is complete.
|
||||||
|
metadata = await self.get_metadata(conversation_id)
|
||||||
|
if (not metadata.user_id and not metadata.github_user_id) or (
|
||||||
|
metadata.user_id != user_id and metadata.github_user_id != github_user_id
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def delete_metadata(self, conversation_id: str) -> None:
|
async def delete_metadata(self, conversation_id: str) -> None:
|
||||||
"""delete conversation metadata"""
|
"""Delete conversation metadata."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def exists(self, conversation_id: str) -> bool:
|
async def exists(self, conversation_id: str) -> bool:
|
||||||
"""Check if conversation exists"""
|
"""Check if conversation exists."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def search(
|
async def search(
|
||||||
@@ -49,6 +60,6 @@ class ConversationStore(ABC):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_instance(
|
async def get_instance(
|
||||||
cls, config: AppConfig, user_id: str | None
|
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
|
||||||
) -> ConversationStore:
|
) -> ConversationStore:
|
||||||
"""Get a store for the user represented by the token given"""
|
"""Get a store for the user represented by the token given"""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ class ConversationValidator:
|
|||||||
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
|
"""Storage for conversation metadata. May or may not support multiple users depending on the environment."""
|
||||||
|
|
||||||
async def validate(self, conversation_id: str, cookies_str: str):
|
async def validate(self, conversation_id: str, cookies_str: str):
|
||||||
return None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
conversation_validator_cls = os.environ.get(
|
conversation_validator_cls = os.environ.get(
|
||||||
|
|||||||
@@ -85,8 +85,8 @@ class FileConversationStore(ConversationStore):
|
|||||||
try:
|
try:
|
||||||
conversations.append(await self.get_metadata(conversation_id))
|
conversations.append(await self.get_metadata(conversation_id))
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error(
|
logger.warning(
|
||||||
f'Error loading conversation: {conversation_id}',
|
f'Could not load conversation metadata: {conversation_id}',
|
||||||
)
|
)
|
||||||
conversations.sort(key=_sort_key, reverse=True)
|
conversations.sort(key=_sort_key, reverse=True)
|
||||||
conversations = conversations[start:end]
|
conversations = conversations[start:end]
|
||||||
@@ -101,7 +101,7 @@ class FileConversationStore(ConversationStore):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_instance(
|
async def get_instance(
|
||||||
cls, config: AppConfig, user_id: str | None
|
cls, config: AppConfig, user_id: str | None, github_user_id: str | None
|
||||||
) -> FileConversationStore:
|
) -> FileConversationStore:
|
||||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||||
return FileConversationStore(file_store)
|
return FileConversationStore(file_store)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from datetime import datetime, timezone
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ConversationMetadata:
|
class ConversationMetadata:
|
||||||
conversation_id: str
|
conversation_id: str
|
||||||
|
user_id: str | None
|
||||||
github_user_id: str | None
|
github_user_id: str | None
|
||||||
selected_repository: str | None
|
selected_repository: str | None
|
||||||
selected_branch: str | None = None
|
selected_branch: str | None = None
|
||||||
|
|||||||
+19
-154
@@ -1,25 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
|
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
|
|
||||||
from openhands.controller.state.state import State
|
from openhands.controller.state.state import State
|
||||||
from openhands.core.logger import openhands_logger
|
|
||||||
from openhands.core.message import Message, TextContent
|
from openhands.core.message import Message, TextContent
|
||||||
from openhands.microagent import (
|
from openhands.events.observation.agent import MicroagentKnowledge
|
||||||
BaseMicroAgent,
|
|
||||||
KnowledgeMicroAgent,
|
|
||||||
RepoMicroAgent,
|
|
||||||
load_microagents_from_dir,
|
|
||||||
)
|
|
||||||
from openhands.runtime.base import Runtime
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RuntimeInfo:
|
class RuntimeInfo:
|
||||||
available_hosts: dict[str, int]
|
available_hosts: dict[str, int] = field(default_factory=dict)
|
||||||
additional_agent_instructions: str
|
additional_agent_instructions: str = ''
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -32,75 +25,23 @@ class RepositoryInfo:
|
|||||||
|
|
||||||
class PromptManager:
|
class PromptManager:
|
||||||
"""
|
"""
|
||||||
Manages prompt templates and micro-agents for AI interactions.
|
Manages prompt templates and includes information from the user's workspace micro-agents and global micro-agents.
|
||||||
|
|
||||||
This class handles loading and rendering of system and user prompt templates,
|
This class is dedicated to loading and rendering prompts (system prompt, user prompt).
|
||||||
as well as loading micro-agent specifications. It provides methods to access
|
|
||||||
rendered system and initial user messages for AI interactions.
|
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
prompt_dir (str): Directory containing prompt templates.
|
prompt_dir: Directory containing prompt templates.
|
||||||
microagent_dir (str): Directory containing microagent specifications.
|
|
||||||
disabled_microagents (list[str] | None): List of microagents to disable. If None, all microagents are enabled.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
prompt_dir: str,
|
prompt_dir: str,
|
||||||
microagent_dir: str | None = None,
|
|
||||||
disabled_microagents: list[str] | None = None,
|
|
||||||
):
|
):
|
||||||
self.disabled_microagents: list[str] = disabled_microagents or []
|
|
||||||
self.prompt_dir: str = prompt_dir
|
self.prompt_dir: str = prompt_dir
|
||||||
self.repository_info: RepositoryInfo | None = None
|
|
||||||
self.system_template: Template = self._load_template('system_prompt')
|
self.system_template: Template = self._load_template('system_prompt')
|
||||||
self.user_template: Template = self._load_template('user_prompt')
|
self.user_template: Template = self._load_template('user_prompt')
|
||||||
self.additional_info_template: Template = self._load_template('additional_info')
|
self.additional_info_template: Template = self._load_template('additional_info')
|
||||||
self.microagent_info_template: Template = self._load_template('microagent_info')
|
self.microagent_info_template: Template = self._load_template('microagent_info')
|
||||||
self.runtime_info = RuntimeInfo(
|
|
||||||
available_hosts={}, additional_agent_instructions=''
|
|
||||||
)
|
|
||||||
|
|
||||||
self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {}
|
|
||||||
self.repo_microagents: dict[str, RepoMicroAgent] = {}
|
|
||||||
|
|
||||||
if microagent_dir:
|
|
||||||
# This loads micro-agents from the microagent_dir
|
|
||||||
# which is typically the OpenHands/microagents (i.e., the PUBLIC microagents)
|
|
||||||
|
|
||||||
# Only load KnowledgeMicroAgents
|
|
||||||
repo_microagents, knowledge_microagents, _ = load_microagents_from_dir(
|
|
||||||
microagent_dir
|
|
||||||
)
|
|
||||||
assert all(
|
|
||||||
isinstance(microagent, KnowledgeMicroAgent)
|
|
||||||
for microagent in knowledge_microagents.values()
|
|
||||||
)
|
|
||||||
for name, microagent in knowledge_microagents.items():
|
|
||||||
if name not in self.disabled_microagents:
|
|
||||||
self.knowledge_microagents[name] = microagent
|
|
||||||
assert all(
|
|
||||||
isinstance(microagent, RepoMicroAgent)
|
|
||||||
for microagent in repo_microagents.values()
|
|
||||||
)
|
|
||||||
for name, microagent in repo_microagents.items():
|
|
||||||
if name not in self.disabled_microagents:
|
|
||||||
self.repo_microagents[name] = microagent
|
|
||||||
|
|
||||||
def load_microagents(self, microagents: list[BaseMicroAgent]) -> None:
|
|
||||||
"""Load microagents from a list of BaseMicroAgents.
|
|
||||||
|
|
||||||
This is typically used when loading microagents from inside a repo.
|
|
||||||
"""
|
|
||||||
openhands_logger.info('Loading microagents: %s', [m.name for m in microagents])
|
|
||||||
# Only keep KnowledgeMicroAgents and RepoMicroAgents
|
|
||||||
for microagent in microagents:
|
|
||||||
if microagent.name in self.disabled_microagents:
|
|
||||||
continue
|
|
||||||
if isinstance(microagent, KnowledgeMicroAgent):
|
|
||||||
self.knowledge_microagents[microagent.name] = microagent
|
|
||||||
elif isinstance(microagent, RepoMicroAgent):
|
|
||||||
self.repo_microagents[microagent.name] = microagent
|
|
||||||
|
|
||||||
def _load_template(self, template_name: str) -> Template:
|
def _load_template(self, template_name: str) -> Template:
|
||||||
if self.prompt_dir is None:
|
if self.prompt_dir is None:
|
||||||
@@ -114,27 +55,6 @@ class PromptManager:
|
|||||||
def get_system_message(self) -> str:
|
def get_system_message(self) -> str:
|
||||||
return self.system_template.render().strip()
|
return self.system_template.render().strip()
|
||||||
|
|
||||||
def set_runtime_info(self, runtime: Runtime) -> None:
|
|
||||||
self.runtime_info.available_hosts = runtime.web_hosts
|
|
||||||
self.runtime_info.additional_agent_instructions = (
|
|
||||||
runtime.additional_agent_instructions
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_repository_info(
|
|
||||||
self,
|
|
||||||
repo_name: str,
|
|
||||||
repo_directory: str,
|
|
||||||
) -> None:
|
|
||||||
"""Sets information about the GitHub repository that has been cloned.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
repo_name: The name of the GitHub repository (e.g. 'owner/repo')
|
|
||||||
repo_directory: The directory where the repository has been cloned
|
|
||||||
"""
|
|
||||||
self.repository_info = RepositoryInfo(
|
|
||||||
repo_name=repo_name, repo_directory=repo_directory
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_example_user_message(self) -> str:
|
def get_example_user_message(self) -> str:
|
||||||
"""This is the initial user message provided to the agent
|
"""This is the initial user message provided to the agent
|
||||||
before *actual* user instructions are provided.
|
before *actual* user instructions are provided.
|
||||||
@@ -148,45 +68,6 @@ class PromptManager:
|
|||||||
|
|
||||||
return self.user_template.render().strip()
|
return self.user_template.render().strip()
|
||||||
|
|
||||||
def enhance_message(self, message: Message) -> None:
|
|
||||||
"""Enhance the user message with additional context.
|
|
||||||
|
|
||||||
This method is used to enhance the user message with additional context
|
|
||||||
about the user's task. The additional context will convert the current
|
|
||||||
generic agent into a more specialized agent that is tailored to the user's task.
|
|
||||||
"""
|
|
||||||
if not message.content:
|
|
||||||
return
|
|
||||||
|
|
||||||
# if there were other texts included, they were before the user message
|
|
||||||
# so the last TextContent is the user message
|
|
||||||
# content can be a list of TextContent or ImageContent
|
|
||||||
message_content = ''
|
|
||||||
for content in reversed(message.content):
|
|
||||||
if isinstance(content, TextContent):
|
|
||||||
message_content = content.text
|
|
||||||
break
|
|
||||||
|
|
||||||
if not message_content:
|
|
||||||
return
|
|
||||||
|
|
||||||
triggered_agents = []
|
|
||||||
for name, microagent in self.knowledge_microagents.items():
|
|
||||||
trigger = microagent.match_trigger(message_content)
|
|
||||||
if trigger:
|
|
||||||
openhands_logger.info(
|
|
||||||
"Microagent '%s' triggered by keyword '%s'",
|
|
||||||
name,
|
|
||||||
trigger,
|
|
||||||
)
|
|
||||||
# Create a dictionary with the agent and trigger word
|
|
||||||
triggered_agents.append({'agent': microagent, 'trigger_word': trigger})
|
|
||||||
|
|
||||||
if triggered_agents:
|
|
||||||
formatted_text = self.build_microagent_info(triggered_agents)
|
|
||||||
# Insert the new content at the start of the TextContent list
|
|
||||||
message.content.insert(0, TextContent(text=formatted_text))
|
|
||||||
|
|
||||||
def add_examples_to_initial_message(self, message: Message) -> None:
|
def add_examples_to_initial_message(self, message: Message) -> None:
|
||||||
"""Add example_message to the first user message."""
|
"""Add example_message to the first user message."""
|
||||||
example_message = self.get_example_user_message() or None
|
example_message = self.get_example_user_message() or None
|
||||||
@@ -195,44 +76,28 @@ class PromptManager:
|
|||||||
if example_message:
|
if example_message:
|
||||||
message.content.insert(0, TextContent(text=example_message))
|
message.content.insert(0, TextContent(text=example_message))
|
||||||
|
|
||||||
def add_info_to_initial_message(
|
def build_workspace_context(
|
||||||
self,
|
self,
|
||||||
message: Message,
|
repository_info: RepositoryInfo | None,
|
||||||
) -> None:
|
runtime_info: RuntimeInfo | None,
|
||||||
"""Adds information about the repository and runtime to the initial user message.
|
repo_instructions: str = '',
|
||||||
|
) -> str:
|
||||||
Args:
|
"""Renders the additional info template with the stored repository/runtime info."""
|
||||||
message: The initial user message to add information to.
|
return self.additional_info_template.render(
|
||||||
"""
|
repository_info=repository_info,
|
||||||
repo_instructions = ''
|
|
||||||
assert (
|
|
||||||
len(self.repo_microagents) <= 1
|
|
||||||
), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'
|
|
||||||
for microagent in self.repo_microagents.values():
|
|
||||||
# We assume these are the repo instructions
|
|
||||||
if repo_instructions:
|
|
||||||
repo_instructions += '\n\n'
|
|
||||||
repo_instructions += microagent.content
|
|
||||||
|
|
||||||
additional_info = self.additional_info_template.render(
|
|
||||||
repository_instructions=repo_instructions,
|
repository_instructions=repo_instructions,
|
||||||
repository_info=self.repository_info,
|
runtime_info=runtime_info,
|
||||||
runtime_info=self.runtime_info,
|
|
||||||
).strip()
|
).strip()
|
||||||
|
|
||||||
# Insert the new content at the start of the TextContent list
|
|
||||||
if additional_info:
|
|
||||||
message.content.insert(0, TextContent(text=additional_info))
|
|
||||||
|
|
||||||
def build_microagent_info(
|
def build_microagent_info(
|
||||||
self,
|
self,
|
||||||
triggered_agents: list[dict],
|
triggered_agents: list[MicroagentKnowledge],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Renders the microagent info template with the triggered agents.
|
"""Renders the microagent info template with the triggered agents.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
triggered_agents: A list of dictionaries, each containing an "agent"
|
triggered_agents: A list of MicroagentKnowledge objects containing information
|
||||||
(KnowledgeMicroAgent) and a "trigger_word" (str).
|
about triggered microagents.
|
||||||
"""
|
"""
|
||||||
return self.microagent_info_template.render(
|
return self.microagent_info_template.render(
|
||||||
triggered_agents=triggered_agents
|
triggered_agents=triggered_agents
|
||||||
|
|||||||
Generated
+37
-37
@@ -496,18 +496,18 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "boto3"
|
name = "boto3"
|
||||||
version = "1.37.11"
|
version = "1.37.12"
|
||||||
description = "The AWS SDK for Python"
|
description = "The AWS SDK for Python"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "boto3-1.37.11-py3-none-any.whl", hash = "sha256:da6c22fc8a7e9bca5d7fc465a877ac3d45b6b086d776bd1a6c55bdde60523741"},
|
{file = "boto3-1.37.12-py3-none-any.whl", hash = "sha256:516feaa0d2afaeda1515216fd09291368a1215754bbccb0f28414c0a91a830a2"},
|
||||||
{file = "boto3-1.37.11.tar.gz", hash = "sha256:8eec08363ef5db05c2fbf58e89f0c0de6276cda2fdce01e76b3b5f423cd5c0f4"},
|
{file = "boto3-1.37.12.tar.gz", hash = "sha256:9412d404f103ad6d14f033eb29cd5e0cdca2b9b08cbfa9d4dabd1d7be2de2625"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
botocore = ">=1.37.11,<1.38.0"
|
botocore = ">=1.37.12,<1.38.0"
|
||||||
jmespath = ">=0.7.1,<2.0.0"
|
jmespath = ">=0.7.1,<2.0.0"
|
||||||
s3transfer = ">=0.11.0,<0.12.0"
|
s3transfer = ">=0.11.0,<0.12.0"
|
||||||
|
|
||||||
@@ -516,14 +516,14 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"]
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "botocore"
|
name = "botocore"
|
||||||
version = "1.37.11"
|
version = "1.37.12"
|
||||||
description = "Low-level, data-driven core of boto 3."
|
description = "Low-level, data-driven core of boto 3."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
groups = ["main"]
|
groups = ["main"]
|
||||||
files = [
|
files = [
|
||||||
{file = "botocore-1.37.11-py3-none-any.whl", hash = "sha256:02505309b1235f9f15a6da79103ca224b3f3dc5f6a62f8630fbb2c6ed05e2da8"},
|
{file = "botocore-1.37.12-py3-none-any.whl", hash = "sha256:ba1948c883bbabe20d95ff62c3e36954c9269686f7db9361857835677ca3e676"},
|
||||||
{file = "botocore-1.37.11.tar.gz", hash = "sha256:72eb3a9a58b064be26ba154e5e56373633b58f951941c340ace0d379590d98b5"},
|
{file = "botocore-1.37.12.tar.gz", hash = "sha256:ae2d5328ce6ad02eb615270507235a6e90fd3eeed615a6c0732b5a68b12f2017"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -3547,14 +3547,14 @@ test = ["jupyter-server (>=2.0.0)", "pytest (>=7.0)", "pytest-jupyter[server] (>
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "jupyterlab"
|
name = "jupyterlab"
|
||||||
version = "4.3.5"
|
version = "4.3.6"
|
||||||
description = "JupyterLab computational environment"
|
description = "JupyterLab computational environment"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
groups = ["runtime"]
|
groups = ["runtime"]
|
||||||
files = [
|
files = [
|
||||||
{file = "jupyterlab-4.3.5-py3-none-any.whl", hash = "sha256:571bbdee20e4c5321ab5195bc41cf92a75a5cff886be5e57ce78dfa37a5e9fdb"},
|
{file = "jupyterlab-4.3.6-py3-none-any.whl", hash = "sha256:fc9eb0455562a56a9bd6d2977cf090842f321fa1a298fcee9bf8c19de353d5fd"},
|
||||||
{file = "jupyterlab-4.3.5.tar.gz", hash = "sha256:c779bf72ced007d7d29d5bcef128e7fdda96ea69299e19b04a43635a7d641f9d"},
|
{file = "jupyterlab-4.3.6.tar.gz", hash = "sha256:2900ffdbfca9ed37c4ad7fdda3eb76582fd945d46962af3ac64741ae2d6b2ff4"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -4251,14 +4251,14 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "modal"
|
name = "modal"
|
||||||
version = "0.73.98"
|
version = "0.73.102"
|
||||||
description = "Python client library for Modal"
|
description = "Python client library for Modal"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
groups = ["main", "evaluation"]
|
groups = ["main", "evaluation"]
|
||||||
files = [
|
files = [
|
||||||
{file = "modal-0.73.98-py3-none-any.whl", hash = "sha256:a49cd5f5b46d1a6c6a0d528618d3cbb73ac2908e199716590ec3a5275d79ed98"},
|
{file = "modal-0.73.102-py3-none-any.whl", hash = "sha256:26151ef6164e0b93b0d1961f73d5a715deb72f23e2641215f5410cf58bf403d3"},
|
||||||
{file = "modal-0.73.98.tar.gz", hash = "sha256:817f73c222fa39a16d6888a92eb7a6847ecae574e44ef04e2dce5e534bdd2df9"},
|
{file = "modal-0.73.102.tar.gz", hash = "sha256:198876cf94ff13633283e251d8b37cc1f1bb5e27a7aa547e02072def1f29b66e"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
@@ -4670,19 +4670,19 @@ files = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "notebook"
|
name = "notebook"
|
||||||
version = "7.3.2"
|
version = "7.3.3"
|
||||||
description = "Jupyter Notebook - A web-based notebook environment for interactive computing"
|
description = "Jupyter Notebook - A web-based notebook environment for interactive computing"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
groups = ["runtime"]
|
groups = ["runtime"]
|
||||||
files = [
|
files = [
|
||||||
{file = "notebook-7.3.2-py3-none-any.whl", hash = "sha256:e5f85fc59b69d3618d73cf27544418193ff8e8058d5bf61d315ce4f473556288"},
|
{file = "notebook-7.3.3-py3-none-any.whl", hash = "sha256:b193df0878956562d5171c8e25c9252b8e86c9fcc16163b8ee3fe6c5e3f422f7"},
|
||||||
{file = "notebook-7.3.2.tar.gz", hash = "sha256:705e83a1785f45b383bf3ee13cb76680b92d24f56fb0c7d2136fe1d850cd3ca8"},
|
{file = "notebook-7.3.3.tar.gz", hash = "sha256:707a313fb882d35f921989eb3d204de942ed5132a44e4aa1fe0e8f24bb9dc25d"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
jupyter-server = ">=2.4.0,<3"
|
jupyter-server = ">=2.4.0,<3"
|
||||||
jupyterlab = ">=4.3.4,<4.4"
|
jupyterlab = ">=4.3.6,<4.4"
|
||||||
jupyterlab-server = ">=2.27.1,<3"
|
jupyterlab-server = ">=2.27.1,<3"
|
||||||
notebook-shim = ">=0.2,<0.3"
|
notebook-shim = ">=0.2,<0.3"
|
||||||
tornado = ">=6.2.0"
|
tornado = ">=6.2.0"
|
||||||
@@ -6947,30 +6947,30 @@ pyasn1 = ">=0.1.3"
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruff"
|
name = "ruff"
|
||||||
version = "0.9.10"
|
version = "0.11.0"
|
||||||
description = "An extremely fast Python linter and code formatter, written in Rust."
|
description = "An extremely fast Python linter and code formatter, written in Rust."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.7"
|
python-versions = ">=3.7"
|
||||||
groups = ["dev", "evaluation"]
|
groups = ["dev", "evaluation"]
|
||||||
files = [
|
files = [
|
||||||
{file = "ruff-0.9.10-py3-none-linux_armv6l.whl", hash = "sha256:eb4d25532cfd9fe461acc83498361ec2e2252795b4f40b17e80692814329e42d"},
|
{file = "ruff-0.11.0-py3-none-linux_armv6l.whl", hash = "sha256:dc67e32bc3b29557513eb7eeabb23efdb25753684b913bebb8a0c62495095acb"},
|
||||||
{file = "ruff-0.9.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:188a6638dab1aa9bb6228a7302387b2c9954e455fb25d6b4470cb0641d16759d"},
|
{file = "ruff-0.11.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:38c23fd9bdec4eb437b4c1e3595905a0a8edfccd63a790f818b28c78fe345639"},
|
||||||
{file = "ruff-0.9.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5284dcac6b9dbc2fcb71fdfc26a217b2ca4ede6ccd57476f52a587451ebe450d"},
|
{file = "ruff-0.11.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7c8661b0be91a38bd56db593e9331beaf9064a79028adee2d5f392674bbc5e88"},
|
||||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47678f39fa2a3da62724851107f438c8229a3470f533894b5568a39b40029c0c"},
|
{file = "ruff-0.11.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6c0e8d3d2db7e9f6efd884f44b8dc542d5b6b590fc4bb334fdbc624d93a29a2"},
|
||||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:99713a6e2766b7a17147b309e8c915b32b07a25c9efd12ada79f217c9c778b3e"},
|
{file = "ruff-0.11.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3c3156d3f4b42e57247275a0a7e15a851c165a4fc89c5e8fa30ea6da4f7407b8"},
|
||||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:524ee184d92f7c7304aa568e2db20f50c32d1d0caa235d8ddf10497566ea1a12"},
|
{file = "ruff-0.11.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:490b1e147c1260545f6d041c4092483e3f6d8eba81dc2875eaebcf9140b53905"},
|
||||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:df92aeac30af821f9acf819fc01b4afc3dfb829d2782884f8739fb52a8119a16"},
|
{file = "ruff-0.11.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1bc09a7419e09662983b1312f6fa5dab829d6ab5d11f18c3760be7ca521c9329"},
|
||||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de42e4edc296f520bb84954eb992a07a0ec5a02fecb834498415908469854a52"},
|
{file = "ruff-0.11.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bcfa478daf61ac8002214eb2ca5f3e9365048506a9d52b11bea3ecea822bb844"},
|
||||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d257f95b65806104b6b1ffca0ea53f4ef98454036df65b1eda3693534813ecd1"},
|
{file = "ruff-0.11.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6fbb2aed66fe742a6a3a0075ed467a459b7cedc5ae01008340075909d819df1e"},
|
||||||
{file = "ruff-0.9.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b60dec7201c0b10d6d11be00e8f2dbb6f40ef1828ee75ed739923799513db24c"},
|
{file = "ruff-0.11.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92c0c1ff014351c0b0cdfdb1e35fa83b780f1e065667167bb9502d47ca41e6db"},
|
||||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:d838b60007da7a39c046fcdd317293d10b845001f38bcb55ba766c3875b01e43"},
|
{file = "ruff-0.11.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e4fd5ff5de5f83e0458a138e8a869c7c5e907541aec32b707f57cf9a5e124445"},
|
||||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:ccaf903108b899beb8e09a63ffae5869057ab649c1e9231c05ae354ebc62066c"},
|
{file = "ruff-0.11.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:96bc89a5c5fd21a04939773f9e0e276308be0935de06845110f43fd5c2e4ead7"},
|
||||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:f9567d135265d46e59d62dc60c0bfad10e9a6822e231f5b24032dba5a55be6b5"},
|
{file = "ruff-0.11.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a9352b9d767889ec5df1483f94870564e8102d4d7e99da52ebf564b882cdc2c7"},
|
||||||
{file = "ruff-0.9.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5f202f0d93738c28a89f8ed9eaba01b7be339e5d8d642c994347eaa81c6d75b8"},
|
{file = "ruff-0.11.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:049a191969a10897fe052ef9cc7491b3ef6de79acd7790af7d7897b7a9bfbcb6"},
|
||||||
{file = "ruff-0.9.10-py3-none-win32.whl", hash = "sha256:bfb834e87c916521ce46b1788fbb8484966e5113c02df216680102e9eb960029"},
|
{file = "ruff-0.11.0-py3-none-win32.whl", hash = "sha256:3191e9116b6b5bbe187447656f0c8526f0d36b6fd89ad78ccaad6bdc2fad7df2"},
|
||||||
{file = "ruff-0.9.10-py3-none-win_amd64.whl", hash = "sha256:f2160eeef3031bf4b17df74e307d4c5fb689a6f3a26a2de3f7ef4044e3c484f1"},
|
{file = "ruff-0.11.0-py3-none-win_amd64.whl", hash = "sha256:c58bfa00e740ca0a6c43d41fb004cd22d165302f360aaa56f7126d544db31a21"},
|
||||||
{file = "ruff-0.9.10-py3-none-win_arm64.whl", hash = "sha256:5fd804c0327a5e5ea26615550e706942f348b197d5475ff34c19733aee4b2e69"},
|
{file = "ruff-0.11.0-py3-none-win_arm64.whl", hash = "sha256:868364fc23f5aa122b00c6f794211e85f7e78f5dffdf7c590ab90b8c4e69b657"},
|
||||||
{file = "ruff-0.9.10.tar.gz", hash = "sha256:9bacb735d7bada9cfb0f2c227d3658fc443d90a727b47f206fb33f52f3c0eac7"},
|
{file = "ruff-0.11.0.tar.gz", hash = "sha256:e55c620690a4a7ee6f1cccb256ec2157dc597d109400ae75bbf944fc9d6462e2"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -9056,4 +9056,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = "^3.12"
|
python-versions = "^3.12"
|
||||||
content-hash = "6a644bc65782a717a49718496bd279ecb888807ec625d992af4448cc5d9271c1"
|
content-hash = "9b74f62a4afa719a1f7167e0b3b45cdaf282c2e18fd2931da91c0f1b22776178"
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user