mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
33 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 75a1fad77e | |||
| 030d934621 | |||
| b1ac189aaa | |||
| 0fa92ccfe4 | |||
| 8ddad5a52c | |||
| 8f1182135f | |||
| ce8d857690 | |||
| 84b2b5a062 | |||
| 4076445a7a | |||
| e58d6a9e35 | |||
| 7d5e64507c | |||
| f3ef5e84dc | |||
| 41d4cb5d29 | |||
| c06772fbc6 | |||
| 4f8baf3698 | |||
| aa5e9f792c | |||
| a0c4d5217b | |||
| 5aeeaca0f0 | |||
| ba014c957e | |||
| 6c67517f56 | |||
| 2825bb6dc3 | |||
| ea3076364f | |||
| f6245b9a99 | |||
| 2e6fa13550 | |||
| 315d586b14 | |||
| 3774a459df | |||
| 4fde183c0b | |||
| 95e60953f1 | |||
| aab80f2975 | |||
| 83783c44b3 | |||
| 207d628817 | |||
| f51ecec3e7 | |||
| b89f4c1748 |
@@ -74,7 +74,7 @@ jobs:
|
||||
- name: Fix python lint issues
|
||||
run: |
|
||||
# Run all pre-commit hooks and continue even if they modify files (exit code 1)
|
||||
pre-commit run --config ./dev_config/python/.pre-commit-config.yaml --all-files || true
|
||||
pre-commit run --config ./dev_config/python/.pre-commit-config.yaml --files openhands/**/* evaluation/**/* tests/**/* || true
|
||||
|
||||
# Commit and push changes if any
|
||||
- name: Check for changes
|
||||
|
||||
@@ -53,7 +53,7 @@ jobs:
|
||||
- name: Install pre-commit
|
||||
run: pip install pre-commit==3.7.0
|
||||
- name: Run pre-commit hooks
|
||||
run: pre-commit run --all-files --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
|
||||
run: pre-commit run --files openhands/**/* evaluation/**/* tests/**/* --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
|
||||
|
||||
# Check version consistency across documentation
|
||||
check-version-consistency:
|
||||
|
||||
@@ -81,3 +81,4 @@ jobs:
|
||||
env:
|
||||
TEST_RUNTIME: local
|
||||
DEBUG: "1"
|
||||
|
||||
|
||||
@@ -189,7 +189,7 @@ install-pre-commit-hooks:
|
||||
|
||||
lint-backend:
|
||||
@echo "$(YELLOW)Running linters...$(RESET)"
|
||||
@poetry run pre-commit run --all-files --show-diff-on-failure --config $(PRE_COMMIT_CONFIG_PATH)
|
||||
@poetry run pre-commit run --files openhands/**/* evaluation/**/* tests/**/* --show-diff-on-failure --config $(PRE_COMMIT_CONFIG_PATH)
|
||||
|
||||
lint-frontend:
|
||||
@echo "$(YELLOW)Running linters for frontend...$(RESET)"
|
||||
|
||||
+2
-2
@@ -4,7 +4,7 @@
|
||||
npm install -g mint
|
||||
```
|
||||
|
||||
or
|
||||
or
|
||||
|
||||
```
|
||||
yarn global add mint
|
||||
@@ -14,4 +14,4 @@ yarn global add mint
|
||||
|
||||
```
|
||||
mint dev
|
||||
```
|
||||
```
|
||||
@@ -1,8 +1,10 @@
|
||||
---
|
||||
title: Slack Integration (Beta)
|
||||
title: Slack Integration - Coming soon...
|
||||
description: This guide walks you through installing the OpenHands Slack app.
|
||||
---
|
||||
|
||||
<Warning>This integration is not live yet, but will be available soon.</Warning>
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- You are a slack workspace admin
|
||||
|
||||
@@ -17,14 +17,13 @@ for scripting.
|
||||
pip install openhands-ai
|
||||
```
|
||||
|
||||
2. Launch an interactive OpenHands conversation from the command line:
|
||||
2. Set your model, API key, and other preferences using environment variables or with the [`config.toml`](https://github.com/All-Hands-AI/OpenHands/blob/main/config.template.toml) file.
|
||||
3. Launch an interactive OpenHands conversation from the command line:
|
||||
|
||||
```bash
|
||||
openhands
|
||||
```
|
||||
|
||||
3. Set your model, API key, and other preferences using the UI (or alternatively environment variables, below).
|
||||
|
||||
This command opens an interactive prompt where you can type tasks or commands and get responses from OpenHands.
|
||||
|
||||
#### For Developers
|
||||
|
||||
@@ -27,7 +27,7 @@ You can use the Settings page at any time to:
|
||||
- [Configure MCP servers](/usage/mcp).
|
||||
- [Connect to GitHub](/usage/how-to/gui-mode#github-setup) and [connect to GitLab](/usage/how-to/gui-mode#gitlab-setup)
|
||||
- Set application settings like your preferred language, notifications and other preferences.
|
||||
- [Manage custom secrets](/usage/how-to/gui-mode#secrets-management).
|
||||
- Generate custom secrets.
|
||||
|
||||
#### GitHub Setup
|
||||
|
||||
@@ -122,36 +122,6 @@ OpenHands automatically exports a `GITLAB_TOKEN` to the shell environment if pro
|
||||
</Accordion>
|
||||
</AccordionGroup>
|
||||
|
||||
#### Secrets Management
|
||||
|
||||
OpenHands provides a secrets manager that allows you to securely store and manage sensitive information that can be accessed by the agent during runtime, such as API keys. These secrets are automatically exported as environment variables in the agent's runtime environment.
|
||||
|
||||
1. **Accessing the Secrets Manager**:
|
||||
- In the Settings page, navigate to the `Secrets` tab.
|
||||
- You'll see a list of all your existing custom secrets (if any).
|
||||
|
||||
2. **Adding a New Secret**:
|
||||
- Click the `Add New Secret` button.
|
||||
- Fill in the following fields:
|
||||
- **Name**: A unique identifier for your secret (e.g., `AWS_ACCESS_KEY`). This will be the environment variable name.
|
||||
- **Value**: The sensitive information you want to store.
|
||||
- **Description** (optional): A brief description of what the secret is used for, which is also provided to the agent.
|
||||
- Click `Add Secret` to save.
|
||||
|
||||
3. **Editing a Secret**:
|
||||
- Click the `Edit` button next to the secret you want to modify.
|
||||
- You can update the name and description of the secret.
|
||||
- Note: For security reasons, you cannot view or edit the value of an existing secret. If you need to change the value, delete the secret and create a new one.
|
||||
|
||||
4. **Deleting a Secret**:
|
||||
- Click the `Delete` button next to the secret you want to remove.
|
||||
- Confirm the deletion when prompted.
|
||||
|
||||
5. **Using Secrets in the Agent**:
|
||||
- All custom secrets are automatically exported as environment variables in the agent's runtime environment.
|
||||
- You can access them in your code using standard environment variable access methods (e.g., `os.environ['SECRET_NAME']` in Python).
|
||||
- Example: If you create a secret named `OPENAI_API_KEY`, you can access it in your code as `process.env.OPENAI_API_KEY` in JavaScript or `os.environ['OPENAI_API_KEY']` in Python.
|
||||
|
||||
#### Advanced Settings
|
||||
|
||||
The `Advanced` settings allows configuration of additional LLM settings. Inside the Settings page, under the `LLM` tab,
|
||||
@@ -184,7 +154,7 @@ is loaded. Typically these include:
|
||||
## Tips for Effective Use
|
||||
|
||||
- Be specific in your requests to get the most accurate and helpful responses, as described in the [prompting best practices](../prompting/prompting-best-practices).
|
||||
- Use one of the recommended models, as described in the [LLMs section](/usage/llms/llms).
|
||||
- Use one of the recommended models, as described in the [LLMs section](usage/llms/llms.md).
|
||||
|
||||
## Other Ways to Run Openhands
|
||||
- [Run OpenHands in a scriptable headless mode.](/usage/how-to/headless-mode)
|
||||
|
||||
@@ -8,7 +8,7 @@ description: OpenHands uses LiteLLM to make calls to Google's chat models. You c
|
||||
When running OpenHands, you'll need to set the following in the OpenHands UI through the Settings under the `LLM` tab:
|
||||
- `LLM Provider` to `Gemini`
|
||||
- `LLM Model` to the model you will be using.
|
||||
If the model is not in the list, enable `Advanced` options, and enter it in `Custom Model`
|
||||
If the model is not in the list, enable `Advanced` options, and enter it in `Custom Model`
|
||||
(e.g. gemini/<model-name> like `gemini/gemini-2.0-flash`).
|
||||
- `API Key` to your Gemini API key
|
||||
|
||||
@@ -26,5 +26,5 @@ VERTEXAI_LOCATION="<your-gcp-location>"
|
||||
Then set the following in the OpenHands UI through the Settings under the `LLM` tab:
|
||||
- `LLM Provider` to `VertexAI`
|
||||
- `LLM Model` to the model you will be using.
|
||||
If the model is not in the list, enable `Advanced` options, and enter it in `Custom Model`
|
||||
If the model is not in the list, enable `Advanced` options, and enter it in `Custom Model`
|
||||
(e.g. vertex_ai/<model-name>).
|
||||
|
||||
@@ -8,7 +8,7 @@ description: OpenHands uses LiteLLM to make calls to chat models on Groq. You ca
|
||||
When running OpenHands, you'll need to set the following in the OpenHands UI through the Settings under the `LLM` tab:
|
||||
- `LLM Provider` to `Groq`
|
||||
- `LLM Model` to the model you will be using. [Visit here to see the list of
|
||||
models that Groq hosts](https://console.groq.com/docs/models). If the model is not in the list,
|
||||
models that Groq hosts](https://console.groq.com/docs/models). If the model is not in the list,
|
||||
enable `Advanced` options, and enter it in `Custom Model` (e.g. groq/<model-name> like `groq/llama3-70b-8192`).
|
||||
- `API key` to your Groq API key. To find or create your Groq API Key, [see here](https://console.groq.com/keys).
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ To use LiteLLM proxy with OpenHands, you need to:
|
||||
|
||||
## Supported Models
|
||||
|
||||
The supported models depend on your LiteLLM proxy configuration. OpenHands supports any model that your LiteLLM proxy
|
||||
The supported models depend on your LiteLLM proxy configuration. OpenHands supports any model that your LiteLLM proxy
|
||||
is configured to handle.
|
||||
|
||||
Refer to your LiteLLM proxy configuration for the list of available models and their names.
|
||||
|
||||
@@ -9,6 +9,6 @@ When running OpenHands, you'll need to set the following in the OpenHands UI thr
|
||||
* `LLM Provider` to `OpenRouter`
|
||||
* `LLM Model` to the model you will be using.
|
||||
[Visit here to see a full list of OpenRouter models](https://openrouter.ai/models).
|
||||
If the model is not in the list, enable `Advanced` options, and enter it in
|
||||
If the model is not in the list, enable `Advanced` options, and enter it in
|
||||
`Custom Model` (e.g. openrouter/<model-name> like `openrouter/anthropic/claude-3.5-sonnet`).
|
||||
* `API Key` to your OpenRouter API key.
|
||||
|
||||
@@ -5,7 +5,7 @@ description: Organizations and users can define microagents that apply to all re
|
||||
|
||||
## Usage
|
||||
|
||||
These microagents can be [any type of microagent](./microagents-overview#microagent-types) and will be loaded
|
||||
These microagents can be [any type of microagent](./microagents-overview#microagent-types) and will be loaded
|
||||
accordingly. However, they are applied to all repositories belonging to the organization or user.
|
||||
|
||||
Add a `.openhands` repository under the organization or user and create a `microagents` directory and place the
|
||||
|
||||
@@ -15,7 +15,7 @@ Before using the Local Runtime, ensure that:
|
||||
1. You can run OpenHands using the [Development workflow](https://github.com/All-Hands-AI/OpenHands/blob/main/Development.md).
|
||||
2. For Linux and Mac, tmux is available on your system.
|
||||
3. For Windows, PowerShell is available on your system.
|
||||
- Only [CLI mode](../how-to/cli-mode) and [headless mode](../how-to/headless-mode) are supported in Windows with Local Runtime.
|
||||
- Only [CLI mode](../how-to/cli-mode) and [headless mode](../how-to/headless-mode) are supported in Windows with Local Runtime.
|
||||
|
||||
## Configuration
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ if __name__ == '__main__':
|
||||
llm_config = None
|
||||
if args.llm_config:
|
||||
llm_config = get_llm_config_arg(args.llm_config)
|
||||
# modify_params must be False for evaluation purpose, for reproducibility and accuracy of results
|
||||
# modify_params must be False for evaluation purpose, for reproducibility and accurancy of results
|
||||
llm_config.modify_params = False
|
||||
|
||||
if llm_config is None:
|
||||
|
||||
@@ -223,7 +223,7 @@ if __name__ == '__main__':
|
||||
llm_config = None
|
||||
if args.llm_config:
|
||||
llm_config = get_llm_config_arg(args.llm_config)
|
||||
# modify_params must be False for evaluation purpose, for reproducibility and accuracy of results
|
||||
# modify_params must be False for evaluation purpose, for reproducibility and accurancy of results
|
||||
llm_config.modify_params = False
|
||||
if llm_config is None:
|
||||
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
|
||||
This folder contains the evaluation harness that we built on top of the original [SWE-Bench benchmark](https://www.swebench.com/) ([paper](https://arxiv.org/abs/2310.06770)).
|
||||
|
||||
**UPDATE (6/15/2025): We now support running SWE-bench-Live evaluation (see the paper [here](https://arxiv.org/abs/2505.23419))! For how to run it, checkout [this README](./SWE-bench-Live.md).**
|
||||
|
||||
**UPDATE (5/26/2025): We now support running interactive SWE-Bench evaluation (see the paper [here](https://arxiv.org/abs/2502.13069))! For how to run it, checkout [this README](./SWE-Interact.md).**
|
||||
|
||||
**UPDATE (4/8/2025): We now support running SWT-Bench evaluation! For more details, checkout [the corresponding section](#SWT-Bench-Evaluation).**
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
# SWE-bench-Live
|
||||
|
||||
<p align="center">
|
||||
<a href="https://arxiv.org/abs/2505.23419">📃 Paper</a>
|
||||
•
|
||||
<a href="https://huggingface.co/SWE-bench-Live" >🤗 HuggingFace</a>
|
||||
•
|
||||
<a href="https://SWE-bench-Live.github.io" >📊 Leaderboard</a>
|
||||
</p>
|
||||
|
||||
SWE-bench-Live is a live benchmark for issue resolving, providing a dataset that contains the latest issue tasks. This document explains how to run the evaluation of OpenHands on SWE-bench-Live.
|
||||
|
||||
Since SWE-bench-Live has an almost identical setting to SWE-bench, you only need to simply change the dataset name to `SWE-bench-Live/SWE-bench-Live`, the other parts are basically the same as running on SWE-bench.
|
||||
|
||||
## Setting Up
|
||||
|
||||
Set up the development environment and configure your LLM provider by following the [README](README.md).
|
||||
|
||||
## Running Inference
|
||||
|
||||
Use the same script, but change the dataset name to `SWE-bench-Live` and select the split (either `lite` or `full`). The lite split contains 300 instances from the past six months, while the full split includes 1,319 instances created after 2024.
|
||||
|
||||
```shell
|
||||
./evaluation/benchmarks/swe_bench/scripts/run_infer.sh [model_config] [git-version] [agent] [eval_limit] [max_iter] [num_workers] [dataset] [dataset_split]
|
||||
```
|
||||
|
||||
In the original SWE-bench-Live paper, max_iterations is set to 100.
|
||||
|
||||
```shell
|
||||
./evaluation/benchmarks/swe_bench/scripts/run_infer.sh llm.your_llm HEAD CodeActAgent 300 100 3 SWE-bench-Live/SWE-bench-Live lite
|
||||
```
|
||||
|
||||
## Evaluating Results
|
||||
|
||||
After OpenHands generates patch results for each issue, we evaluate the results using the [SWE-bench-Live evaluation harness](https://github.com/microsoft/SWE-bench-Live).
|
||||
|
||||
Convert to the format of predictions for SWE benchmarks:
|
||||
|
||||
```shell
|
||||
# You can find output.jsonl in evaluation/evaluation_outputs
|
||||
python evaluation/benchmarks/swe_bench/scripts/live/convert.py --output_jsonl [path/to/evaluation/output.jsonl] > preds.jsonl
|
||||
```
|
||||
|
||||
Please refer to the original [SWE-bench-Live repository](https://github.com/microsoft/SWE-bench-Live) to set up the evaluation harness and use the provided scripts to generate the evaluation report:
|
||||
|
||||
```shell
|
||||
python -m swebench.harness.run_evaluation \
|
||||
--dataset_name SWE-bench-Live/SWE-bench-Live \
|
||||
--split lite \
|
||||
--namespace starryzhang \
|
||||
--predictions_path preds.jsonl \
|
||||
--max_workers 10 \
|
||||
--run_id openhands
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{zhang2025swebenchgoeslive,
|
||||
title={SWE-bench Goes Live!},
|
||||
author={Linghao Zhang and Shilin He and Chaoyun Zhang and Yu Kang and Bowen Li and Chengxing Xie and Junhao Wang and Maoquan Wang and Yufan Huang and Shengyu Fu and Elsie Nallipogu and Qingwei Lin and Yingnong Dang and Saravan Rajmohan and Dongmei Zhang},
|
||||
journal={arXiv preprint arXiv:2505.23419},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
@@ -1,80 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from evaluation.utils.shared import assert_and_raise
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import CmdRunAction
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
ErrorObservation,
|
||||
)
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.utils.shutdown_listener import sleep_if_should_continue
|
||||
|
||||
|
||||
def complete_runtime(
|
||||
runtime: Runtime,
|
||||
instance: pd.Series,
|
||||
) -> dict[str, Any]:
|
||||
"""Complete the runtime and export the git patch for SWE-bench-Live."""
|
||||
logger.info('-' * 30)
|
||||
logger.info('BEGIN Runtime Completion Fn')
|
||||
logger.info('-' * 30)
|
||||
obs: CmdOutputObservation
|
||||
workspace_dir_name = instance.instance_id
|
||||
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action)
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
|
||||
f'Failed to cd to /workspace/{workspace_dir_name}: {str(obs)}',
|
||||
)
|
||||
action = CmdRunAction(command='git config --global core.pager ""')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
|
||||
f'Failed to git config --global core.pager "": {str(obs)}',
|
||||
)
|
||||
action = CmdRunAction(command='git add -A')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
isinstance(obs, CmdOutputObservation) and obs.exit_code == 0,
|
||||
f'Failed to git add -A: {str(obs)}',
|
||||
)
|
||||
n_retries = 0
|
||||
git_patch = None
|
||||
while n_retries < 5:
|
||||
action = CmdRunAction(
|
||||
command=f'git diff --no-color --cached {instance["base_commit"]}',
|
||||
)
|
||||
action.set_hard_timeout(100 + 10 * n_retries)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
n_retries += 1
|
||||
if isinstance(obs, CmdOutputObservation):
|
||||
if obs.exit_code == 0:
|
||||
git_patch = obs.content.strip()
|
||||
break
|
||||
else:
|
||||
logger.info('Failed to get git diff, retrying...')
|
||||
sleep_if_should_continue(10)
|
||||
elif isinstance(obs, ErrorObservation):
|
||||
logger.error(f'Error occurred: {obs.content}. Retrying...')
|
||||
sleep_if_should_continue(10)
|
||||
else:
|
||||
assert_and_raise(False, f'Unexpected observation type: {str(obs)}')
|
||||
assert_and_raise(git_patch is not None, 'Failed to get git diff (None)')
|
||||
logger.info('-' * 30)
|
||||
logger.info('END Runtime Completion Fn')
|
||||
logger.info('-' * 30)
|
||||
return {'git_patch': git_patch}
|
||||
@@ -1,4 +1,4 @@
|
||||
TASK_INSTRUECTION = """
|
||||
TASK_INSTRUECTION="""
|
||||
Given the following GitHub problem description, your objective is to localize the specific files, classes or functions, and lines of code that need modification or contain key information to resolve the issue.
|
||||
|
||||
Follow these steps to localize the issue:
|
||||
@@ -66,4 +66,4 @@ FAKE_USER_MSG_FOR_LOC = (
|
||||
'Verify that you have carefully analyzed the impact of the found locations on the repository, especially their dependencies. '
|
||||
'If you think you have solved the task, please send your final answer (including the former answer and reranking) to user through message and then call `finish` to finish.\n'
|
||||
'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP.\n'
|
||||
)
|
||||
)
|
||||
@@ -27,7 +27,7 @@ You MUST plan extensively before each function call, and reflect extensively on
|
||||
5. Debug as needed. Use debugging techniques to isolate and resolve issues.
|
||||
6. Test frequently. Run tests after each change to verify correctness.
|
||||
7. Iterate until the root cause is fixed and all tests pass.
|
||||
8. Reflect and validate comprehensively. After tests pass, think about the original intent, write additional tests to ensure correctness,
|
||||
8. Reflect and validate comprehensively. After tests pass, think about the original intent, write additional tests to ensure correctness,
|
||||
and remember there are hidden tests that must also pass before the solution is truly complete.
|
||||
|
||||
Refer to the detailed sections below for more information on each step.
|
||||
|
||||
@@ -43,7 +43,7 @@ from openhands.core.config import (
|
||||
AgentConfig,
|
||||
OpenHandsConfig,
|
||||
get_llm_config_arg,
|
||||
get_parser,
|
||||
get_parser
|
||||
)
|
||||
from openhands.core.config.condenser_config import NoOpCondenserConfig
|
||||
from openhands.core.config.utils import get_condenser_config_arg
|
||||
@@ -66,26 +66,6 @@ RUN_WITH_BROWSING = os.environ.get('RUN_WITH_BROWSING', 'false').lower() == 'tru
|
||||
ENABLE_LLM_EDITOR = os.environ.get('ENABLE_LLM_EDITOR', 'false').lower() == 'true'
|
||||
BenchMode = Literal['swe', 'swt', 'swt-ci']
|
||||
|
||||
# Global variable to track dataset type
|
||||
DATASET_TYPE = 'SWE-bench'
|
||||
|
||||
|
||||
def set_dataset_type(dataset_name: str) -> str:
|
||||
"""Set dataset type based on dataset name."""
|
||||
global DATASET_TYPE
|
||||
name_lower = dataset_name.lower()
|
||||
|
||||
if 'swe-gym' in name_lower:
|
||||
DATASET_TYPE = 'SWE-Gym'
|
||||
elif 'swe-bench-live' in name_lower:
|
||||
DATASET_TYPE = 'SWE-bench-Live'
|
||||
elif 'multimodal' in name_lower:
|
||||
DATASET_TYPE = 'Multimodal'
|
||||
else:
|
||||
DATASET_TYPE = 'SWE-bench'
|
||||
|
||||
logger.info(f'Dataset type set to: {DATASET_TYPE}')
|
||||
|
||||
|
||||
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
|
||||
'CodeActAgent': codeact_user_response,
|
||||
@@ -93,10 +73,7 @@ AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
|
||||
|
||||
|
||||
def _get_swebench_workspace_dir_name(instance: pd.Series) -> str:
|
||||
if DATASET_TYPE == 'SWE-bench-Live':
|
||||
return instance.instance_id
|
||||
else:
|
||||
return f'{instance.repo}__{instance.version}'.replace('/', '__')
|
||||
return f'{instance.repo}__{instance.version}'.replace('/', '__')
|
||||
|
||||
|
||||
def get_instruction(instance: pd.Series, metadata: EvalMetadata) -> MessageAction:
|
||||
@@ -115,12 +92,10 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata) -> MessageActio
|
||||
elif 'gpt-4.1' in llm_model:
|
||||
template_name = 'swe_gpt4.j2'
|
||||
else:
|
||||
template_name = (
|
||||
'swe_default.j2' # Default for 'swe' mode (regular swe-bench)
|
||||
)
|
||||
template_name = 'swe_default.j2' # Default for 'swe' mode (regular swe-bench)
|
||||
else:
|
||||
# Fallback or error handling if mode is unexpected
|
||||
logger.error(f'Unexpected evaluation mode: {mode}. Falling back to default.')
|
||||
logger.error(f"Unexpected evaluation mode: {mode}. Falling back to default.")
|
||||
template_name = 'swe_default.j2'
|
||||
|
||||
# Set up Jinja2 environment
|
||||
@@ -142,7 +117,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata) -> MessageActio
|
||||
f'The following command can be used to run the tests: `{list(MAP_REPO_TO_TEST_FRAMEWORK_VERBOSE[instance.repo].values())[0]}`. Make sure they fail in the expected way.\n'
|
||||
)
|
||||
else:
|
||||
context['test_instructions'] = '' # Ensure it's defined for other modes
|
||||
context['test_instructions'] = '' # Ensure it's defined for other modes
|
||||
|
||||
# Render the instruction
|
||||
instruction = template.render(context)
|
||||
@@ -176,13 +151,9 @@ def get_instance_docker_image(
|
||||
if swebench_official_image:
|
||||
# Official SWE-Bench image
|
||||
# swebench/sweb.eval.x86_64.django_1776_django-11333:v1
|
||||
# SWE-bench-Live uses the same naming convention as SWE-Bench
|
||||
if DATASET_TYPE == 'SWE-bench-Live':
|
||||
docker_image_prefix = 'docker.io/starryzhang/'
|
||||
elif DATASET_TYPE == 'SWE-bench':
|
||||
docker_image_prefix = 'docker.io/swebench/'
|
||||
docker_image_prefix = 'docker.io/swebench/'
|
||||
repo, name = instance_id.split('__')
|
||||
image_name = f'{docker_image_prefix.rstrip("/")}/sweb.eval.x86_64.{repo}_1776_{name}:latest'.lower()
|
||||
image_name = f'swebench/sweb.eval.x86_64.{repo}_1776_{name}:latest'.lower()
|
||||
logger.debug(f'Using official SWE-Bench image: {image_name}')
|
||||
return image_name
|
||||
else:
|
||||
@@ -200,8 +171,7 @@ def get_config(
|
||||
metadata: EvalMetadata,
|
||||
) -> OpenHandsConfig:
|
||||
# We use a different instance image for the each instance of swe-bench eval
|
||||
use_swebench_official_image = DATASET_TYPE != 'SWE-Gym'
|
||||
|
||||
use_swebench_official_image = 'swe-gym' not in metadata.dataset.lower()
|
||||
base_container_image = get_instance_docker_image(
|
||||
instance['instance_id'],
|
||||
swebench_official_image=use_swebench_official_image,
|
||||
@@ -318,12 +288,8 @@ def initialize_runtime(
|
||||
runtime.copy_to(temp_file_path, '/swe_util/eval_data/instances/')
|
||||
|
||||
# inject the instance swe entry
|
||||
if DATASET_TYPE == 'SWE-bench-Live':
|
||||
entry_script_path = 'instance_swe_entry_live.sh'
|
||||
else:
|
||||
entry_script_path = 'instance_swe_entry.sh'
|
||||
runtime.copy_to(
|
||||
str(os.path.join(script_dir, f'scripts/setup/{entry_script_path}')),
|
||||
str(os.path.join(script_dir, 'scripts/setup/instance_swe_entry.sh')),
|
||||
'/swe_util/',
|
||||
)
|
||||
|
||||
@@ -343,14 +309,14 @@ def initialize_runtime(
|
||||
logger.error(f'Failed to source ~/.bashrc: {str(obs)}')
|
||||
assert_and_raise(obs.exit_code == 0, f'Failed to source ~/.bashrc: {str(obs)}')
|
||||
|
||||
action = CmdRunAction(command=f'source /swe_util/{entry_script_path}')
|
||||
action = CmdRunAction(command='source /swe_util/instance_swe_entry.sh')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
assert_and_raise(
|
||||
obs.exit_code == 0,
|
||||
f'Failed to source /swe_util/{entry_script_path}: {str(obs)}',
|
||||
f'Failed to source /swe_util/instance_swe_entry.sh: {str(obs)}',
|
||||
)
|
||||
|
||||
action = CmdRunAction(command=f'cd /workspace/{workspace_dir_name}')
|
||||
@@ -403,9 +369,9 @@ def initialize_runtime(
|
||||
obs = runtime.run_action(action)
|
||||
logger.info(obs, extra={'msg_type': 'OBSERVATION'})
|
||||
|
||||
if DATASET_TYPE != 'Multimodal' and DATASET_TYPE != 'SWE-bench-Live':
|
||||
if 'multimodal' not in metadata.dataset.lower():
|
||||
# Only for non-multimodal datasets, we need to activate the testbed environment for Python
|
||||
# SWE-Bench multimodal datasets and SWE-bench-Live are not using the testbed environment
|
||||
# SWE-Bench multimodal datasets are not using the testbed environment
|
||||
action = CmdRunAction(command='which python')
|
||||
action.set_hard_timeout(600)
|
||||
logger.info(action, extra={'msg_type': 'ACTION'})
|
||||
@@ -647,13 +613,7 @@ def process_instance(
|
||||
|
||||
# ======= THIS IS SWE-Bench specific =======
|
||||
# Get git patch
|
||||
if DATASET_TYPE == 'SWE-bench-Live':
|
||||
from evaluation.benchmarks.swe_bench.live_utils import (
|
||||
complete_runtime as complete_runtime_fn,
|
||||
)
|
||||
else:
|
||||
complete_runtime_fn = complete_runtime
|
||||
return_val = complete_runtime_fn(runtime, instance)
|
||||
return_val = complete_runtime(runtime, instance)
|
||||
git_patch = return_val['git_patch']
|
||||
logger.info(
|
||||
f'Got git diff for instance {instance.instance_id}:\n--------\n{git_patch}\n--------'
|
||||
@@ -758,15 +718,11 @@ if __name__ == '__main__':
|
||||
# NOTE: It is preferable to load datasets from huggingface datasets and perform post-processing
|
||||
# so we don't need to manage file uploading to OpenHands's repo
|
||||
dataset = load_dataset(args.dataset, split=args.split)
|
||||
|
||||
# Set the global dataset type based on dataset name
|
||||
set_dataset_type(args.dataset)
|
||||
|
||||
swe_bench_tests = filter_dataset(dataset.to_pandas(), 'instance_id')
|
||||
logger.info(
|
||||
f'Loaded dataset {args.dataset} with split {args.split}: {len(swe_bench_tests)} tasks'
|
||||
)
|
||||
if DATASET_TYPE == 'SWE-Gym':
|
||||
if 'SWE-Gym' in args.dataset:
|
||||
with open(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
|
||||
@@ -192,8 +192,6 @@ def get_config(
|
||||
dataset_name=metadata.dataset,
|
||||
instance_id=instance['instance_id'],
|
||||
)
|
||||
oh_aci_li_cmd = '/openhands/micromamba/bin/micromamba run -n openhands poetry run pip install openhands-aci[llama]'
|
||||
sandbox_config.runtime_extra_deps = oh_aci_li_cmd
|
||||
workspace_dir_name = _get_swebench_workspace_dir_name(instance)
|
||||
sandbox_config.runtime_startup_env_vars = {
|
||||
'REPO_PATH': f'/workspace/{workspace_dir_name}/',
|
||||
@@ -218,7 +216,6 @@ def get_config(
|
||||
enable_jupyter=False,
|
||||
enable_browsing=RUN_WITH_BROWSING,
|
||||
enable_llm_editor=False,
|
||||
enable_mcp=os.environ.get('ENABLE_MCP', False),
|
||||
condenser=metadata.condenser_config,
|
||||
enable_prompt_extensions=False,
|
||||
)
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
import argparse
|
||||
import json
|
||||
|
||||
|
||||
def main(output_jsonl: str):
|
||||
with open(output_jsonl, 'r') as f:
|
||||
for line in f:
|
||||
try:
|
||||
output = json.loads(line)
|
||||
pred = {
|
||||
'instance_id': output['instance_id'],
|
||||
'model_name_or_path': output['metadata']['llm_config']['model'],
|
||||
'model_patch': output['test_result']['git_patch'],
|
||||
}
|
||||
except Exception as e:
|
||||
print(
|
||||
f'Error while reading output of instance {output["instance_id"]}: {e}'
|
||||
)
|
||||
|
||||
print(json.dumps(pred))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--output_jsonl',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Path to the prediction file (.../outputs.jsonl)',
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.output_jsonl)
|
||||
@@ -1,41 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
source ~/.bashrc
|
||||
SWEUTIL_DIR=/swe_util
|
||||
|
||||
# FIXME: Cannot read SWE_INSTANCE_ID from the environment variable
|
||||
# SWE_INSTANCE_ID=django__django-11099
|
||||
if [ -z "$SWE_INSTANCE_ID" ]; then
|
||||
echo "Error: SWE_INSTANCE_ID is not set." >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Read the swe-bench-test-lite.json file and extract the required item based on instance_id
|
||||
item=$(jq --arg INSTANCE_ID "$SWE_INSTANCE_ID" '.[] | select(.instance_id == $INSTANCE_ID)' $SWEUTIL_DIR/eval_data/instances/swe-bench-instance.json)
|
||||
|
||||
if [[ -z "$item" ]]; then
|
||||
echo "No item found for the provided instance ID."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
echo "WORKSPACE_NAME: $SWE_INSTANCE_ID"
|
||||
|
||||
# Clear the workspace
|
||||
if [ -d /workspace ]; then
|
||||
rm -rf /workspace/*
|
||||
else
|
||||
mkdir /workspace
|
||||
fi
|
||||
# Copy repo to workspace
|
||||
if [ -d /workspace/$SWE_INSTANCE_ID ]; then
|
||||
rm -rf /workspace/$SWE_INSTANCE_ID
|
||||
fi
|
||||
mkdir -p /workspace
|
||||
cp -r /testbed /workspace/$SWE_INSTANCE_ID
|
||||
|
||||
# SWE-bench-Live does not use conda to manage Python
|
||||
# if [ -d /opt/miniconda3 ]; then
|
||||
# . /opt/miniconda3/etc/profile.d/conda.sh
|
||||
# conda activate testbed
|
||||
# fi
|
||||
@@ -1,102 +0,0 @@
|
||||
# VersiCode benchmark
|
||||
|
||||
This project is used to evaluate the performance of the model on VersiCode. It includes:
|
||||
|
||||
- data: the test data needed and the model outputs
|
||||
- inference_utils: inference scripts for ours tasks and models
|
||||
- metric: scripts for calculating various metric
|
||||
- output_processing: process the model output to facilitate the calculation of model metrics
|
||||
|
||||
# Details
|
||||
|
||||
1. **Prepare the environment**
|
||||
|
||||
```shell
|
||||
#create conda environment
|
||||
conda create -n VersiCode python==3.12
|
||||
|
||||
#install requirements
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. **Experiment Data**
|
||||
|
||||
To obtain the experimental data, please visit the Hugging Face link: https://huggingface.co/datasets/AstoneNg/VersiCode.
|
||||
Locate the files `VersiCode_block_completion.json` and `VersiCode_migration.json` under the `experiment_data` directory, and place them in the `/data/test_data directory` of this project.
|
||||
|
||||
|
||||
3. **Model inference**
|
||||
|
||||
```shell
|
||||
#cd inference_utils directory
|
||||
cd inference_utils
|
||||
|
||||
#The script file starting with 'test' is used to test the local model
|
||||
#The script file at the beginning of the API is used to test the API call model
|
||||
|
||||
#block level code completipn
|
||||
#Modify the 10th and 12th lines of code to specify the base URL and model name
|
||||
python api_test_block_completion.py
|
||||
#Modify the 30th line of code to specify the local model path
|
||||
python test_block.py
|
||||
|
||||
# code migration (migration order is 'old_to_new')
|
||||
#Modify the 10th and 12th lines of code to specify the base URL and model name
|
||||
python api_code_migration.py
|
||||
#Modify the 30th line of code to specify the local model path
|
||||
python test_migration.py
|
||||
```
|
||||
|
||||
4. **Process output**
|
||||
Process the output content of the model, remove redundant content, extract specified content for easy calculation of indicators.
|
||||
|
||||
```shell
|
||||
#cd output_processing
|
||||
cd output_processing
|
||||
|
||||
#Extract content from<start> and <end>
|
||||
#Modify the 8th and 9th lines of code to specify the model and task granularity
|
||||
python clear_ans.py
|
||||
|
||||
#In the block completion task and migration task, cdc@k The calculation of indicators needs to be targeted at key rows,
|
||||
#Modify lines 76 and 79 to specify the data path
|
||||
python choose_core_line_from_block_versicode.py
|
||||
python choose_core_line_from_migration_versicode.py
|
||||
```
|
||||
|
||||
5. **Metric**
|
||||
We have three metrics pass@k,em@k and cdc@k Due to our inability to automatically build a dynamic evaluation environment, we have not provided pass@k .
|
||||
|
||||
```shell
|
||||
#cd metric
|
||||
cd metric
|
||||
|
||||
#Modify lines 137-140 in migration task (compute_migration_cdc_score.py) or 143-145 in block and line completion task (compute_versicode_cdc_score.py and compute_versicode_em_score.py) of the code to specify the data path and calculate the k-value of the metric
|
||||
python compute_migration_cdc_score.py
|
||||
python compute_versicode_cdc_score.py
|
||||
python compute_versicode_em_score.py
|
||||
|
||||
#Notes
|
||||
#We found limitations in the ISM@k and PM@k metrics for evaluating code generation, so they are used only as reference in our experiments.
|
||||
#Modify lines 261-265 in block and line completion task of the code to specify the data path and calculate the k-value of the metric
|
||||
python compute_ism_pm_score.py
|
||||
```
|
||||
|
||||
# Citation
|
||||
|
||||
```
|
||||
@article{versicode,
|
||||
author={Tongtong Wu and Weigang Wu and Xingyu Wang and Kang Xu and Suyu Ma and Bo Jiang and Ping Yang and Zhenchang Xing and Yuan-Fang Li and Gholamreza Haffari},
|
||||
title = {VersiCode: Towards Version-controllable Code Generation},
|
||||
journal = {CoRR},
|
||||
volume = {abs/2406.07411},
|
||||
year = {2024},
|
||||
url = {https://arxiv.org/abs/2406.07411},
|
||||
}
|
||||
```
|
||||
|
||||
**Github url**: https://github.com/wutong8023/VersiCode
|
||||
|
||||
# Contributor
|
||||
|
||||
[Tongtong Wu](https://scholar.google.com/citations?hl=zh-CN&user=u1Qp8lUAAAAJ&view_op=list_works&sortby=pubdate), [Weigang Wu](https://scholar.google.com/citations?hl=zh-CN&user=UneIZo8AAAAJ), [Xingyu Wang](https://scholar.google.com/citations?hl=zh-CN&user=wqPJcxcAAAAJ), [Kang Xu](https://scholar.google.com/citations?hl=zh-CN&user=N1UUDi0AAAAJ), [Suyu Ma](https://scholar.google.com/citations?hl=zh-CN&user=NJHR1ukAAAAJ), [Bo Jiang](https://wutong8023.site/VersiCode/), [Ping Yang](https://scholar.google.com/citations?view_op=list_works&hl=en&hl=en&user=hrogvxoAAAAJ), [Zhenchang Xing](https://scholar.google.com/citations?hl=zh-CN&user=0vCxuH4AAAAJ), [Yuan-Fang Li](https://scholar.google.com/citations?hl=zh-CN&user=wufXO1kAAAAJ), [Gholamreza Haffari](https://scholar.google.com/citations?hl=zh-CN&user=Perjx5EAAAAJ)
|
||||
@@ -1,134 +0,0 @@
|
||||
"""
|
||||
GPT performs line level generation prediction and truncates overly long tokens
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import tiktoken
|
||||
from openai import OpenAI
|
||||
|
||||
max_tokens = 127000 # gpt3.5 is 16ktoken gpt4o is 128k
|
||||
model_name = ''
|
||||
|
||||
os.environ['OPENAI_API_KEY'] = ''
|
||||
client = OpenAI()
|
||||
|
||||
|
||||
def truncate_text(text, max_tokens):
|
||||
encoding = tiktoken.get_encoding('cl100k_base')
|
||||
disallowed_special = ()
|
||||
|
||||
tokens = encoding.encode(text, disallowed_special=disallowed_special)
|
||||
print(len(tokens))
|
||||
|
||||
if len(tokens) > max_tokens:
|
||||
tokens = tokens[:max_tokens]
|
||||
|
||||
truncated_text = encoding.decode(tokens)
|
||||
|
||||
return truncated_text
|
||||
|
||||
|
||||
def predict(content, model_name):
|
||||
response = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[{'role': 'user', 'content': content}],
|
||||
frequency_penalty=0.1,
|
||||
max_tokens=128,
|
||||
logit_bias=None,
|
||||
logprobs=None,
|
||||
n=6,
|
||||
presence_penalty=0.0,
|
||||
seed=None,
|
||||
stop=None,
|
||||
stream=False,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
)
|
||||
ans_list = []
|
||||
choices_list = response.choices
|
||||
for c in choices_list:
|
||||
content = c.message.content
|
||||
ans_list.append(content)
|
||||
final_ans = str(ans_list)
|
||||
return final_ans
|
||||
|
||||
|
||||
def bulid_prompt(description, old_version, old_code, new_version) -> str:
|
||||
"""
|
||||
build prompt
|
||||
:param version:
|
||||
:param description:
|
||||
:param masked_code:
|
||||
:param options:
|
||||
:return:
|
||||
"""
|
||||
prompt = f"""
|
||||
You are now a professional Python programming engineer. I will provide you with a code snippet and a description of its functionality,
|
||||
including the dependencies and versions used in the code. Then, I will provide the same dependencies but with a specified new version.
|
||||
Your task is to refactor the code using the methods provided by the specified new version and return the refactored code.
|
||||
Please note that you only need to return the refactored code and enclose it with <start> and <end>:
|
||||
###Functionality description of the code
|
||||
{description}
|
||||
###Dependency and old version
|
||||
{old_version}
|
||||
###Old version code
|
||||
{old_code}
|
||||
###Dependency and new version
|
||||
{new_version}
|
||||
###Refactored new code
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
json_path = '../data/test_data/VersiCode_migration.json'
|
||||
|
||||
|
||||
with open(json_path, 'r', encoding='utf-8') as fr:
|
||||
lodict = json.load(fr)
|
||||
data_dict = lodict
|
||||
data_list = data_dict
|
||||
|
||||
|
||||
for data in data_list:
|
||||
if 'model_output' in data:
|
||||
print(
|
||||
f'the {data_list.index(data) + 1} has already been predicted, skipping this data!'
|
||||
)
|
||||
continue
|
||||
try:
|
||||
print(f'Predicting {data_list.index(data) + 1} ')
|
||||
old_version = data['dependency'] + data['old_version'] # package == x.x.x
|
||||
new_version = data['dependency'] + data['new_version'] # package == x.x.x
|
||||
description = data['description'] # 功能描述
|
||||
old_code = data['old_code'] # mask后的代码
|
||||
|
||||
instruction = bulid_prompt(description, old_version, old_code, new_version)
|
||||
truncated_text = truncate_text(instruction, max_tokens)
|
||||
prediction = predict(truncated_text, model_name)
|
||||
|
||||
data['model_output'] = prediction
|
||||
except Exception as e:
|
||||
print(f'error:{e}')
|
||||
print('save current data')
|
||||
save_folder_path = os.path.join(
|
||||
'../data/result_data/code_migration', model_name
|
||||
)
|
||||
if not os.path.exists(save_folder_path):
|
||||
os.makedirs(save_folder_path)
|
||||
save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])
|
||||
|
||||
with open(save_json_path, 'w', encoding='utf-8') as fw:
|
||||
json.dump(data_dict, fw, indent=4, ensure_ascii=False)
|
||||
break
|
||||
|
||||
|
||||
save_folder_path = os.path.join('../data/result_data/code_migration', model_name)
|
||||
if not os.path.exists(save_folder_path):
|
||||
os.makedirs(save_folder_path)
|
||||
save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])
|
||||
|
||||
with open(save_json_path, 'w', encoding='utf-8') as fw:
|
||||
json.dump(data_dict, fw, indent=4, ensure_ascii=False)
|
||||
@@ -1,141 +0,0 @@
|
||||
"""
|
||||
GPT performs line level generation prediction and truncates overly long tokens
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import tiktoken
|
||||
from openai import OpenAI
|
||||
|
||||
max_tokens = 127000 # gpt3.5 is 16ktoken gpt4o is 128k
|
||||
model_name = ''
|
||||
|
||||
os.environ['OPENAI_API_KEY'] = ''
|
||||
client = OpenAI()
|
||||
|
||||
|
||||
def truncate_text(text, max_tokens):
|
||||
encoding = tiktoken.get_encoding('cl100k_base')
|
||||
disallowed_special = ()
|
||||
|
||||
tokens = encoding.encode(text, disallowed_special=disallowed_special)
|
||||
print(len(tokens))
|
||||
|
||||
if len(tokens) > max_tokens:
|
||||
tokens = tokens[:max_tokens]
|
||||
|
||||
truncated_text = encoding.decode(tokens)
|
||||
|
||||
return truncated_text
|
||||
|
||||
|
||||
def predict(content, model_name):
|
||||
response = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=[{'role': 'user', 'content': content}],
|
||||
frequency_penalty=0.1,
|
||||
max_tokens=128,
|
||||
logit_bias=None,
|
||||
logprobs=None,
|
||||
n=6,
|
||||
presence_penalty=0.0,
|
||||
seed=None,
|
||||
stop=None,
|
||||
stream=False,
|
||||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
)
|
||||
ans_list = []
|
||||
choices_list = response.choices
|
||||
for c in choices_list:
|
||||
content = c.message.content
|
||||
ans_list.append(content)
|
||||
final_ans = str(ans_list)
|
||||
return final_ans
|
||||
|
||||
|
||||
def bulid_prompt(version, description) -> str:
|
||||
"""
|
||||
build prompt
|
||||
:param version:
|
||||
:param description:
|
||||
:param masked_code:
|
||||
:param options:
|
||||
:return:
|
||||
"""
|
||||
prompt = f"""
|
||||
You are a professional Python engineer, and I will provide functional descriptions and versions of specified dependency packages.
|
||||
You need to write code in Python to implement this feature based on the functional description and using the dependency package and version I specified.
|
||||
Please note that you only need to return the code that implements the function, and do not return any other content.
|
||||
Please use <start> and <end> to enclose the generated code. Here is an example:
|
||||
###Function Description:
|
||||
The function of this code is to print the results predicted by calling the model using vllm.
|
||||
###dependeny and version:
|
||||
vllm==0.3.3
|
||||
###response:
|
||||
<start>
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print("Prompt,Generated text")
|
||||
<end>
|
||||
|
||||
###Function Description:
|
||||
{description}
|
||||
###dependeny and version:
|
||||
{version}
|
||||
###response:
|
||||
|
||||
|
||||
"""
|
||||
return prompt
|
||||
|
||||
|
||||
json_path = '../data/test_data/VersiCode_block_completion.json'
|
||||
|
||||
|
||||
with open(json_path, 'r', encoding='utf-8') as fr:
|
||||
lodict = json.load(fr)
|
||||
data_dict = lodict
|
||||
data_list = data_dict
|
||||
|
||||
|
||||
for data in data_list:
|
||||
if 'model_output' in data:
|
||||
print(
|
||||
f'the {data_list.index(data) + 1} has already been predicted, skipping this data!'
|
||||
)
|
||||
continue
|
||||
try:
|
||||
print(f'Predicting {data_list.index(data) + 1} ')
|
||||
version = data['dependency'] + data['version'] # package == x.x.x
|
||||
description = data['description'] # func description
|
||||
|
||||
instruction = bulid_prompt(version, description)
|
||||
truncated_text = truncate_text(instruction, max_tokens)
|
||||
prediction = predict(truncated_text, model_name)
|
||||
|
||||
data['model_output'] = prediction
|
||||
except Exception as e:
|
||||
print(f'error:{e}')
|
||||
print('save current data')
|
||||
save_folder_path = os.path.join(
|
||||
'../data/result_data/block_completion', model_name
|
||||
)
|
||||
if not os.path.exists(save_folder_path):
|
||||
os.makedirs(save_folder_path)
|
||||
save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])
|
||||
|
||||
with open(save_json_path, 'w', encoding='utf-8') as fw:
|
||||
json.dump(data_dict, fw, indent=4, ensure_ascii=False)
|
||||
break
|
||||
|
||||
|
||||
save_folder_path = os.path.join('../data/result_data/block_completion', model_name)
|
||||
if not os.path.exists(save_folder_path):
|
||||
os.makedirs(save_folder_path)
|
||||
save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])
|
||||
|
||||
with open(save_json_path, 'w', encoding='utf-8') as fw:
|
||||
json.dump(data_dict, fw, indent=4, ensure_ascii=False)
|
||||
@@ -1,129 +0,0 @@
|
||||
"""
|
||||
block completion
|
||||
"""
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from multiprocessing import Process
|
||||
|
||||
import tiktoken
|
||||
import torch
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
|
||||
def truncate_text(text, max_tokens):
|
||||
encoding = tiktoken.get_encoding('cl100k_base')
|
||||
disallowed_special = ()
|
||||
|
||||
tokens = encoding.encode(text, disallowed_special=disallowed_special)
|
||||
print(len(tokens))
|
||||
|
||||
if len(tokens) > max_tokens:
|
||||
tokens = tokens[:max_tokens]
|
||||
|
||||
truncated_text = encoding.decode(tokens)
|
||||
|
||||
return truncated_text
|
||||
|
||||
|
||||
model_list = ['/data2/base models/starcoder2-15b', '/data2/base models/CodeGemma-7B']
|
||||
|
||||
|
||||
def run_inference(model_name, origin_data_list):
|
||||
temp_data_list = copy.deepcopy(origin_data_list)
|
||||
test_list = []
|
||||
for data in temp_data_list:
|
||||
version = data['dependency'] + data['version'] # package == x.x.x
|
||||
description = data['description'] # func description
|
||||
|
||||
instruction = bulid_prompt(version, description)
|
||||
test_list.append(instruction)
|
||||
|
||||
sampling_params = SamplingParams(n=6, temperature=0.8, top_p=0.95, max_tokens=64)
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tensor_parallel_size=4,
|
||||
gpu_memory_utilization=0.9,
|
||||
swap_space=20,
|
||||
)
|
||||
|
||||
outputs = llm.generate(test_list, sampling_params)
|
||||
for output in outputs:
|
||||
requests_id = int(output.request_id)
|
||||
temp_ans_list = []
|
||||
output_list = output.outputs
|
||||
for o in output_list:
|
||||
text = o.text
|
||||
temp_ans_list.append(text)
|
||||
|
||||
temp_data_list[requests_id]['model_output'] = str(temp_ans_list)
|
||||
|
||||
save_folder_path = os.path.join(
|
||||
'../data/result_data/block_completion', model_name.split('/')[-1]
|
||||
)
|
||||
if not os.path.exists(save_folder_path):
|
||||
os.makedirs(save_folder_path)
|
||||
|
||||
save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])
|
||||
|
||||
with open(save_json_path, 'w', encoding='utf-8') as fw:
|
||||
json.dump(temp_data_list, fw, indent=4, ensure_ascii=False)
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def bulid_prompt(version, description) -> str:
|
||||
"""
|
||||
build prompt
|
||||
:param version:
|
||||
:param description:
|
||||
:param masked_code:
|
||||
:param options:
|
||||
:return:
|
||||
"""
|
||||
prompt = f"""
|
||||
You are a professional Python engineer, and I will provide functional descriptions and versions of specified dependency packages.
|
||||
You need to write code in Python to implement this feature based on the functional description and using the dependency package and version I specified.
|
||||
Please note that you only need to return the code that implements the function, and do not return any other content.
|
||||
Please use <start> and <end> to enclose the generated code. Here is an example:
|
||||
###Function Description:
|
||||
The function of this code is to print the results predicted by calling the model using vllm.
|
||||
###dependeny and version:
|
||||
vllm==0.3.3
|
||||
###response:
|
||||
<start>
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print("Prompt,Generated text")
|
||||
<end>
|
||||
|
||||
###Function Description:
|
||||
{description}
|
||||
###dependeny and version:
|
||||
{version}
|
||||
###response:
|
||||
|
||||
|
||||
"""
|
||||
return prompt
|
||||
|
||||
|
||||
json_path = '../data/test_data/VersiCode_block_completion.json'
|
||||
|
||||
with open(json_path, 'r', encoding='utf-8') as fr:
|
||||
lodict = json.load(fr)
|
||||
|
||||
origin_data_list = lodict
|
||||
|
||||
for model_name in model_list:
|
||||
process = Process(target=run_inference, args=(model_name, origin_data_list))
|
||||
process.start()
|
||||
process.join()
|
||||
time.sleep(120)
|
||||
@@ -1,122 +0,0 @@
|
||||
"""
|
||||
code migration
|
||||
"""
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from multiprocessing import Process
|
||||
|
||||
import tiktoken
|
||||
import torch
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
|
||||
def truncate_text(text, max_tokens):
|
||||
encoding = tiktoken.get_encoding('cl100k_base')
|
||||
disallowed_special = ()
|
||||
|
||||
tokens = encoding.encode(text, disallowed_special=disallowed_special)
|
||||
print(len(tokens))
|
||||
|
||||
if len(tokens) > max_tokens:
|
||||
tokens = tokens[:max_tokens]
|
||||
|
||||
truncated_text = encoding.decode(tokens)
|
||||
|
||||
return truncated_text
|
||||
|
||||
|
||||
model_list = ['/data2/base models/starcoder2-15b', '/data2/base models/CodeGemma-7B']
|
||||
|
||||
|
||||
def run_inference(model_name, origin_data_list):
|
||||
temp_data_list = copy.deepcopy(origin_data_list)
|
||||
test_list = []
|
||||
for data in temp_data_list:
|
||||
old_version = data['dependency'] + data['old_version'] # package == x.x.x
|
||||
new_version = data['dependency'] + data['new_version'] # package == x.x.x
|
||||
description = data['description'] # 功能描述
|
||||
old_code = data['old_code'] # mask后的代码
|
||||
|
||||
instruction = bulid_prompt(description, old_version, old_code, new_version)
|
||||
test_list.append(instruction)
|
||||
|
||||
sampling_params = SamplingParams(n=6, temperature=0.8, top_p=0.95, max_tokens=512)
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
tensor_parallel_size=4,
|
||||
gpu_memory_utilization=0.6,
|
||||
swap_space=40,
|
||||
)
|
||||
|
||||
outputs = llm.generate(test_list, sampling_params)
|
||||
for output in outputs:
|
||||
requests_id = int(output.request_id)
|
||||
temp_ans_list = []
|
||||
output_list = output.outputs
|
||||
for o in output_list:
|
||||
text = o.text
|
||||
temp_ans_list.append(text)
|
||||
|
||||
temp_data_list[requests_id]['model_output'] = str(temp_ans_list)
|
||||
|
||||
save_folder_path = os.path.join(
|
||||
'../data/result_data/code_migration', model_name.split('/')[-1]
|
||||
)
|
||||
if not os.path.exists(save_folder_path):
|
||||
os.makedirs(save_folder_path)
|
||||
|
||||
save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])
|
||||
|
||||
with open(save_json_path, 'w', encoding='utf-8') as fw:
|
||||
json.dump(temp_data_list, fw, indent=4, ensure_ascii=False)
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def bulid_prompt(description, old_version, old_code, new_version) -> str:
|
||||
"""
|
||||
build prompt
|
||||
:param version:
|
||||
:param description:
|
||||
:param masked_code:
|
||||
:param options:
|
||||
:return:
|
||||
"""
|
||||
prompt = f"""
|
||||
You are now a professional Python programming engineer. I will provide you with a code snippet and a description of its functionality,
|
||||
including the dependencies and versions used in the code. Then, I will provide the same dependencies but with a specified new version.
|
||||
Your task is to refactor the code using the methods provided by the specified new version and return the refactored code.
|
||||
Please note that you only need to return the refactored code and enclose it with <start> and <end>:
|
||||
###Functionality description of the code
|
||||
{description}
|
||||
###Dependency and old version
|
||||
{old_version}
|
||||
###Old version code
|
||||
{old_code}
|
||||
###Dependency and new version
|
||||
{new_version}
|
||||
###Refactored new code
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
json_path = '../data/test_data/VersiCode_migration.json'
|
||||
|
||||
with open(json_path, 'r', encoding='utf-8') as fr:
|
||||
lodict = json.load(fr)
|
||||
|
||||
origin_data_list = lodict
|
||||
|
||||
for model_name in model_list:
|
||||
process = Process(target=run_inference, args=(model_name, origin_data_list))
|
||||
process.start()
|
||||
process.join()
|
||||
time.sleep(120)
|
||||
@@ -1,356 +0,0 @@
|
||||
"""
|
||||
评测block的预测能力
|
||||
1、判断是否包含正确的函数名
|
||||
2、判断是否合法
|
||||
3、计算ISM,和PM
|
||||
"""
|
||||
|
||||
import io
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import tokenize
|
||||
|
||||
|
||||
def is_code_valid(code):
|
||||
try:
|
||||
compile(code, '<string>', 'exec')
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def longest_common_prefix_between_lists_with_elements(list1, list2):
|
||||
"""
|
||||
计算两个字符串列表中元素的最长前缀匹配长度
|
||||
:param list1:
|
||||
:param list2:
|
||||
:return:
|
||||
"""
|
||||
max_prefix_length = 0
|
||||
max_prefix_elements = ()
|
||||
for str1 in list1:
|
||||
for str2 in list2:
|
||||
prefix_length = 0
|
||||
min_len = min(len(str1), len(str2))
|
||||
for i in range(min_len):
|
||||
if str1[i] == str2[i]:
|
||||
prefix_length += 1
|
||||
else:
|
||||
break
|
||||
if prefix_length > max_prefix_length:
|
||||
max_prefix_length = prefix_length
|
||||
max_prefix_elements = (str1, str2)
|
||||
return max_prefix_length, max_prefix_elements
|
||||
|
||||
|
||||
def get_token(ans_code: str, output_code: str):
|
||||
"""
|
||||
对代码进行词法分析,分解成标识符,返回两个标识符列表
|
||||
:param ans_code:
|
||||
:param output_code:
|
||||
:return:
|
||||
"""
|
||||
output_flag = True
|
||||
ans_flag = True
|
||||
try:
|
||||
tokens_ans = tokenize.tokenize(io.BytesIO(ans_code.encode('utf-8')).readline)
|
||||
except Exception:
|
||||
tokens_ans = ans_code.splitlines()
|
||||
ans_flag = False
|
||||
|
||||
try:
|
||||
tokens_output = tokenize.tokenize(
|
||||
io.BytesIO(output_code.encode('utf-8')).readline
|
||||
)
|
||||
except Exception:
|
||||
tokens_output = output_code.splitlines()
|
||||
output_flag = False
|
||||
|
||||
identifiers_ans = []
|
||||
identifiers_output = []
|
||||
if ans_flag:
|
||||
try:
|
||||
for token in tokens_ans:
|
||||
if token.type == tokenize.NAME:
|
||||
identifiers_ans.append(token.string)
|
||||
except Exception:
|
||||
identifiers_ans = tokens_ans
|
||||
else:
|
||||
identifiers_ans = tokens_ans
|
||||
|
||||
if output_flag:
|
||||
try:
|
||||
for to in tokens_output:
|
||||
if to.type == tokenize.NAME:
|
||||
identifiers_output.append(to.string)
|
||||
except Exception:
|
||||
identifiers_output = tokens_output
|
||||
else:
|
||||
identifiers_output = tokens_output
|
||||
|
||||
return identifiers_ans, identifiers_output
|
||||
|
||||
|
||||
def get_token_per_line(code: str):
|
||||
"""
|
||||
对每一行代码进行词法分析,记录每一行的标识符
|
||||
:param code: 代码字符串
|
||||
:return: 每一行的标识符列表组成的列表
|
||||
"""
|
||||
lines = code.split('\n') # 将代码按行分割成列表
|
||||
identifiers_per_line = [] # 用于存储每一行的标识符列表的列表
|
||||
|
||||
for line in lines:
|
||||
tokens = tokenize.tokenize(io.BytesIO(line.encode('utf-8')).readline)
|
||||
identifiers = []
|
||||
try:
|
||||
for token in tokens:
|
||||
if token.type == tokenize.NAME:
|
||||
identifiers.append(token.string)
|
||||
except Exception:
|
||||
identifiers = line.split(' ')
|
||||
identifiers_per_line.append(identifiers)
|
||||
|
||||
return identifiers_per_line
|
||||
|
||||
|
||||
def get_ISM(answer_code: str, model_output_list: list, asnwer_name: str) -> list:
|
||||
"""
|
||||
计算ISM,返回一个有序的得分列表
|
||||
:return:
|
||||
"""
|
||||
score_list = []
|
||||
for code in model_output_list:
|
||||
if '```python' in code:
|
||||
code = code.replace('```python', '')
|
||||
code = code.replace('```', '')
|
||||
if not re.search(rf'\b{re.escape(asnwer_name)}\b', code) or not is_code_valid(
|
||||
code
|
||||
):
|
||||
score_list.append(0)
|
||||
continue
|
||||
|
||||
# if asnwer_name not in code:
|
||||
# score_list.append(0)
|
||||
# continue
|
||||
|
||||
identifiers_ans, identifiers_output = get_token(answer_code, code)
|
||||
max_len, elements = longest_common_prefix_between_lists_with_elements(
|
||||
identifiers_ans, identifiers_output
|
||||
)
|
||||
if max_len != 0:
|
||||
base_element_len = max(len(elements[0]), len(elements[1]))
|
||||
temp_score = max_len / base_element_len
|
||||
score_list.append(temp_score)
|
||||
else:
|
||||
score_list.append(0)
|
||||
# base_element_len = max(len(elements[0]), len(elements[1]))
|
||||
# temp_score = max_len/base_element_len
|
||||
# score_list.append(temp_score)
|
||||
|
||||
score_list = sorted(score_list, reverse=True)
|
||||
return score_list
|
||||
|
||||
|
||||
def get_ISM_without_verification(
|
||||
answer_code: str, model_output_list: list, asnwer_name: str
|
||||
) -> list:
|
||||
"""
|
||||
计算ISM,返回一个有序的得分列表
|
||||
:return:
|
||||
"""
|
||||
score_list = []
|
||||
for code in model_output_list:
|
||||
if asnwer_name not in code:
|
||||
score_list.append(0)
|
||||
continue
|
||||
|
||||
# if asnwer_name not in code:
|
||||
# score_list.append(0)
|
||||
# continue
|
||||
|
||||
identifiers_ans, identifiers_output = get_token(answer_code, code)
|
||||
max_len, elements = longest_common_prefix_between_lists_with_elements(
|
||||
identifiers_ans, identifiers_output
|
||||
)
|
||||
if max_len != 0:
|
||||
base_element_len = max(len(elements[0]), len(elements[1]))
|
||||
temp_score = max_len / base_element_len
|
||||
score_list.append(temp_score)
|
||||
else:
|
||||
score_list.append(0)
|
||||
# base_element_len = max(len(elements[0]), len(elements[1]))
|
||||
# temp_score = max_len/base_element_len
|
||||
# score_list.append(temp_score)
|
||||
|
||||
score_list = sorted(score_list, reverse=True)
|
||||
return score_list
|
||||
|
||||
|
||||
def longest_common_prefix_with_lengths(list1, list2):
|
||||
"""
|
||||
计算两个二维列表中每个子列表的最长前缀匹配长度,并记录拥有最长前缀匹配长度的两个子列表的长度
|
||||
:param list1: 第一个二维列表
|
||||
:param list2: 第二个二维列表
|
||||
:return: 最长前缀匹配长度以及拥有最长前缀匹配长度的两个子列表的长度
|
||||
"""
|
||||
max_length = 0
|
||||
len_list1 = 0
|
||||
len_list2 = 0
|
||||
for i, sublist1 in enumerate(list1):
|
||||
for j, sublist2 in enumerate(list2):
|
||||
match_length = 0
|
||||
min_length = min(len(sublist1), len(sublist2))
|
||||
for k in range(min_length):
|
||||
if sublist1[k] == sublist2[k]:
|
||||
match_length += 1
|
||||
else:
|
||||
break
|
||||
if match_length > max_length:
|
||||
max_length = match_length
|
||||
len_list1 = len(sublist1)
|
||||
len_list2 = len(sublist2)
|
||||
return max_length, len_list1, len_list2
|
||||
|
||||
|
||||
def get_PM(answer_code: str, model_output_list: list, asnwer_name: str) -> list:
|
||||
"""
|
||||
计算PM,返回一个有序的得分列表
|
||||
:return:
|
||||
"""
|
||||
score_list = []
|
||||
for code in model_output_list:
|
||||
if '```python' in code:
|
||||
code = code.replace('```python', '')
|
||||
code = code.replace('```', '')
|
||||
if not re.search(rf'\b{re.escape(asnwer_name)}\b', code) or not is_code_valid(
|
||||
code
|
||||
):
|
||||
# if asnwer_name not in code or is_code_valid(code) == False:
|
||||
score_list.append(0)
|
||||
continue
|
||||
|
||||
# if asnwer_name not in code:
|
||||
# score_list.append(0)
|
||||
# continue
|
||||
|
||||
ans_list = get_token_per_line(answer_code)
|
||||
output_token_list = get_token_per_line(code)
|
||||
max_len, len1, len2 = longest_common_prefix_with_lengths(
|
||||
ans_list, output_token_list
|
||||
)
|
||||
base_element_len = max(len1, len2)
|
||||
|
||||
if base_element_len != 0:
|
||||
temp_score = max_len / base_element_len
|
||||
score_list.append(temp_score)
|
||||
else:
|
||||
score_list.append(0)
|
||||
|
||||
score_list = sorted(score_list, reverse=True)
|
||||
return score_list
|
||||
|
||||
|
||||
def get_score(score_list: list, k):
|
||||
"""
|
||||
计算score@n,k
|
||||
:param score_list:
|
||||
:param k:
|
||||
:return:
|
||||
"""
|
||||
n = len(score_list)
|
||||
sum = 0
|
||||
final = n - k + 1
|
||||
for i in range(1, final + 1):
|
||||
sum += math.comb(n - i, k - 1) * score_list[i - 1]
|
||||
|
||||
final_score = sum / math.comb(n, k)
|
||||
|
||||
return final_score
|
||||
|
||||
|
||||
k = 1
|
||||
task = 'block' # block or line
|
||||
json_name = f'Versicode_{task}_completion.json'
|
||||
|
||||
folder_path = f'../data/result_data/{task}_completion'
|
||||
model_list = os.listdir(folder_path)
|
||||
|
||||
for model in model_list:
|
||||
model_json_path = os.path.join(folder_path, model, json_name)
|
||||
with open(model_json_path, 'r', encoding='utf-8') as fr:
|
||||
lodict = json.load(fr)
|
||||
data_dict = lodict
|
||||
data_list = data_dict
|
||||
data_len = len(data_list)
|
||||
sum_ISM = 0
|
||||
sum_PM = 0
|
||||
|
||||
for data in data_list:
|
||||
# model_output_list = eval(data['model_output'])
|
||||
model_output_list = eval(data['model_output_clear'])[:1]
|
||||
temp_list = []
|
||||
for o in model_output_list:
|
||||
temp_out = o.replace('```python', '')
|
||||
temp_out = temp_out.replace('```', '')
|
||||
temp_list.append(temp_out)
|
||||
model_output_list = temp_list
|
||||
answer_code = data['code']
|
||||
answer_name = data['core_token']
|
||||
#
|
||||
# answer_code = data['new_code'] #code editing
|
||||
# answer_name = data['new_name'] #code editing
|
||||
|
||||
# answer_code = data['old_code'] # code editing new to old
|
||||
# answer_name = data['old_name'] # code editing new to old
|
||||
#
|
||||
ISM_score_list = get_ISM(answer_code, model_output_list, answer_name)
|
||||
# ISM_score_without_verification_list = get_ISM_without_verification(answer_code, model_output_list, answer_name) #新增
|
||||
PM_score_list = get_PM(answer_code, model_output_list, answer_name)
|
||||
|
||||
# if not ISM_score_without_verification_list == ISM_score_list:#新增
|
||||
# for s in ISM_score_list:#新增
|
||||
# if s != ISM_score_without_verification_list[ISM_score_list.index(s)]:#新增
|
||||
# print('元数据如下')#新增
|
||||
# print(data)#新增
|
||||
# print('答案如下')#新增
|
||||
# print(model_output_list[ISM_score_list.index(s)])#新增
|
||||
|
||||
# flag = int(input('输入1继续,0退出'))#新增
|
||||
# if flag == 1:
|
||||
# continue
|
||||
|
||||
ISM_score = get_score(ISM_score_list, k)
|
||||
PM_score = get_score(PM_score_list, k)
|
||||
|
||||
sum_ISM += ISM_score
|
||||
sum_PM += PM_score
|
||||
# print(f"ISM分数:{ISM_score}")
|
||||
# print(f"PM分数:{PM_score}")
|
||||
|
||||
print(f'{model}, {task} completion task, ISM@{k} score: {sum_ISM / data_len}')
|
||||
print(f'{model}, {task} completion task, PM@{k} score: {sum_PM / data_len}')
|
||||
|
||||
|
||||
# def get_token(ans_code:str, output_code:str):
|
||||
# """
|
||||
# 对代码进行词法分析,分解成标识符,返回两个标识符列表
|
||||
# :param ans_code:
|
||||
# :param output_code:
|
||||
# :return:
|
||||
# """
|
||||
# tokens_ans = tokenize.tokenize(io.BytesIO(ans_code.encode('utf-8')).readline)
|
||||
# tokens_output = tokenize.tokenize(io.BytesIO(output_code.encode('utf-8')).readline)
|
||||
# identifiers_ans = []
|
||||
# identifiers_output = []
|
||||
# for token in tokens_ans:
|
||||
# if token.type == tokenize.NAME:
|
||||
# identifiers_ans.append(token.string)
|
||||
#
|
||||
# for to in tokens_output:
|
||||
# if to.type == tokenize.NAME:
|
||||
# identifiers_output.append(to.string)
|
||||
#
|
||||
# return identifiers_ans, identifiers_output
|
||||
@@ -1,198 +0,0 @@
|
||||
"""
|
||||
Calculate the cdc score for migration
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
|
||||
# warnings.filterwarnings("ignore", category=SyntaxWarning)
|
||||
|
||||
|
||||
def is_correct_parameter_count(function_name, correct_code, test_code):
|
||||
"""
|
||||
判断参数数量是否一致
|
||||
:param function_name:
|
||||
:param correct_code:
|
||||
:param test_code:
|
||||
:return:
|
||||
"""
|
||||
# 获取正确代码中的参数数量
|
||||
# return True
|
||||
pattern = rf'{function_name}\((.*?)\)'
|
||||
correct_match = re.search(pattern, correct_code)
|
||||
|
||||
if correct_match:
|
||||
correct_params = correct_match.group(1).strip()
|
||||
correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]
|
||||
expected_count = len(correct_param_list)
|
||||
else:
|
||||
expected_count = 0 # 如果没有参数,期望数量为0
|
||||
|
||||
# 在需要判断的代码中查找函数调用
|
||||
test_match = re.search(pattern, test_code)
|
||||
|
||||
if test_match:
|
||||
test_params = test_match.group(1).strip()
|
||||
test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]
|
||||
return len(test_param_list) == expected_count # 检查参数数量
|
||||
else:
|
||||
# 如果没有括号,检查函数名是否在字符串中
|
||||
return expected_count == 0 and function_name in test_code
|
||||
|
||||
|
||||
def check_keyword_parameters(function_name, correct_code, test_code):
|
||||
"""
|
||||
判断关键词参数赋值是否正确使用
|
||||
:param function_name:
|
||||
:param correct_code:
|
||||
:param test_code:
|
||||
:return:
|
||||
"""
|
||||
# 正则表达式匹配正确代码中的函数调用
|
||||
# return True
|
||||
pattern = rf'{function_name}\((.*?)\)'
|
||||
correct_match = re.search(pattern, correct_code)
|
||||
|
||||
if correct_match:
|
||||
correct_params = correct_match.group(1).strip()
|
||||
correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]
|
||||
|
||||
# 检查待检测代码中的函数调用
|
||||
test_match = re.search(pattern, test_code)
|
||||
|
||||
if test_match:
|
||||
test_params = test_match.group(1).strip()
|
||||
test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]
|
||||
|
||||
# 确保待检测的每个参数都以关键字参数形式赋值
|
||||
for correct_param in correct_param_list:
|
||||
if '=' in correct_param: # 仅当正确代码中有关键词参数
|
||||
param_name = correct_param.split('=')[0].strip()
|
||||
if not any(
|
||||
param_name in test_param and '=' in test_param
|
||||
for test_param in test_param_list
|
||||
):
|
||||
return False # 如果对应参数不是关键词参数,则返回False
|
||||
|
||||
return True # 所有关键字参数匹配
|
||||
|
||||
return False # 如果没有匹配,返回False
|
||||
|
||||
|
||||
def with_correct(answer_code: str, model_output: str) -> bool:
|
||||
"""
|
||||
当answer是with结构时,判断模型生成的是不是with结构
|
||||
:param answer_code:
|
||||
:param model_output:
|
||||
:return:
|
||||
"""
|
||||
# return True
|
||||
if not answer_code.startswith('with') and not model_output.startswith('with'):
|
||||
return True
|
||||
elif answer_code.startswith('with') and model_output.startswith('with'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def compute_block_score_k(
|
||||
answer: str,
|
||||
model_output: list,
|
||||
k: int,
|
||||
model_filled_code,
|
||||
core_line_in_core_block,
|
||||
core_line_in_output_clear,
|
||||
):
|
||||
"""
|
||||
cdc需要满足五个条件,em只需要满足第一个条件
|
||||
"""
|
||||
c = 0
|
||||
n = len(model_output)
|
||||
for index, code in enumerate(model_output):
|
||||
if (
|
||||
re.search(rf'\b{re.escape(answer)}\b', code)
|
||||
and is_code_valid(model_filled_code[index])
|
||||
and is_correct_parameter_count(
|
||||
answer, core_line_in_core_block, core_line_in_output_clear[index]
|
||||
)
|
||||
and with_correct(core_line_in_core_block, core_line_in_output_clear[index])
|
||||
and check_keyword_parameters(
|
||||
answer, core_line_in_core_block, core_line_in_output_clear[index]
|
||||
)
|
||||
): # block
|
||||
# if re.search(rf'\b{re.escape(answer)}\b', code):#block
|
||||
c += 1
|
||||
if n - c < k:
|
||||
return 1.0
|
||||
|
||||
score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def is_code_valid(code):
|
||||
try:
|
||||
compile(code, '<string>', 'exec')
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def compute_score_k(answer: str, model_output: list, k: int):
|
||||
c = 0
|
||||
n = len(model_output)
|
||||
for output in model_output:
|
||||
if '```python' in output:
|
||||
output = output.replace('```python', '')
|
||||
output = output.replace('```', '')
|
||||
# if answer == output:
|
||||
|
||||
if re.search(rf'\b{re.escape(answer)}\b', output) and is_code_valid(output):
|
||||
c += 1
|
||||
if n - c < k:
|
||||
return 1.0
|
||||
|
||||
score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))
|
||||
|
||||
return score
|
||||
|
||||
|
||||
k = 1 # cdc@k
|
||||
json_name = 'VersiCode_migration.json'
|
||||
task = 'migration'
|
||||
folder_path = '../data/result_data/code_migration'
|
||||
|
||||
model_list = os.listdir(folder_path)
|
||||
for model in model_list:
|
||||
# if model != 'gpt-4o':
|
||||
# continue
|
||||
model_json_path = os.path.join(folder_path, model, json_name)
|
||||
with open(model_json_path, 'r', encoding='utf-8') as fr:
|
||||
lodict = json.load(fr)
|
||||
data_list = lodict
|
||||
|
||||
score_list = []
|
||||
for data in data_list:
|
||||
answer = data['new_name'] # old -> new
|
||||
model_output = data['model_output_clear'] # old -> new
|
||||
|
||||
model_filled_code = model_output
|
||||
# core_line_in_core_block = data['core_line_in_new_core_block']# old -> new
|
||||
core_line_in_core_block = data['core_line_in_code'] # old -> new
|
||||
core_line_in_output_clear = data['core_line_in_output_clear'] # old -> new
|
||||
|
||||
score_list.append(
|
||||
compute_block_score_k(
|
||||
answer,
|
||||
model_output,
|
||||
k,
|
||||
model_filled_code,
|
||||
core_line_in_core_block,
|
||||
core_line_in_output_clear,
|
||||
)
|
||||
)
|
||||
|
||||
final_score = sum(score_list) / len(score_list)
|
||||
print(f'{model}, {task} task, cdc@{k} score: {final_score}')
|
||||
@@ -1,225 +0,0 @@
|
||||
"""
|
||||
Calculate the cdc score for line and block
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
|
||||
# warnings.filterwarnings("ignore", category=SyntaxWarning)
|
||||
|
||||
|
||||
def is_code_valid(code):
|
||||
try:
|
||||
compile(code, '<string>', 'exec')
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_correct_parameter_count(function_name, correct_code, test_code):
|
||||
"""
|
||||
判断参数数量是否一致
|
||||
:param function_name:
|
||||
:param correct_code:
|
||||
:param test_code:
|
||||
:return:
|
||||
"""
|
||||
# 获取正确代码中的参数数量
|
||||
# return True
|
||||
pattern = rf'{function_name}\((.*?)\)'
|
||||
correct_match = re.search(pattern, correct_code)
|
||||
|
||||
if correct_match:
|
||||
correct_params = correct_match.group(1).strip()
|
||||
correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]
|
||||
expected_count = len(correct_param_list)
|
||||
else:
|
||||
expected_count = 0 # 如果没有参数,期望数量为0
|
||||
|
||||
# 在需要判断的代码中查找函数调用
|
||||
test_match = re.search(pattern, test_code)
|
||||
|
||||
if test_match:
|
||||
test_params = test_match.group(1).strip()
|
||||
test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]
|
||||
return len(test_param_list) == expected_count # 检查参数数量
|
||||
else:
|
||||
# 如果没有括号,检查函数名是否在字符串中
|
||||
return expected_count == 0 and function_name in test_code
|
||||
|
||||
|
||||
def check_keyword_parameters(function_name, correct_code, test_code):
|
||||
"""
|
||||
判断关键词参数赋值是否正确使用
|
||||
:param function_name:
|
||||
:param correct_code:
|
||||
:param test_code:
|
||||
:return:
|
||||
"""
|
||||
# 正则表达式匹配正确代码中的函数调用
|
||||
# return True
|
||||
pattern = rf'{function_name}\((.*?)\)'
|
||||
correct_match = re.search(pattern, correct_code)
|
||||
|
||||
if correct_match:
|
||||
correct_params = correct_match.group(1).strip()
|
||||
correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]
|
||||
|
||||
# 检查待检测代码中的函数调用
|
||||
test_match = re.search(pattern, test_code)
|
||||
|
||||
if test_match:
|
||||
test_params = test_match.group(1).strip()
|
||||
test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]
|
||||
|
||||
# 确保待检测的每个参数都以关键字参数形式赋值
|
||||
for correct_param in correct_param_list:
|
||||
if '=' in correct_param: # 仅当正确代码中有关键词参数
|
||||
param_name = correct_param.split('=')[0].strip()
|
||||
if not any(
|
||||
param_name in test_param and '=' in test_param
|
||||
for test_param in test_param_list
|
||||
):
|
||||
return False # 如果对应参数不是关键词参数,则返回False
|
||||
|
||||
return True # 所有关键字参数匹配
|
||||
|
||||
return False # 如果没有匹配,返回False
|
||||
|
||||
|
||||
def with_correct(answer_code: str, model_output: str) -> bool:
|
||||
"""
|
||||
当answer是with结构时,判断模型生成的是不是with结构
|
||||
:param answer_code:
|
||||
:param model_output:
|
||||
:return:
|
||||
"""
|
||||
# return True
|
||||
if not answer_code.startswith('with') and not model_output.startswith('with'):
|
||||
return True
|
||||
elif answer_code.startswith('with') and model_output.startswith('with'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def compute_line_score_k(
|
||||
answer: str, model_output: list, k: int, model_filled_code, core_line
|
||||
):
|
||||
c = 0
|
||||
n = len(model_output)
|
||||
for index, code in enumerate(model_output):
|
||||
if (
|
||||
re.search(rf'\b{re.escape(answer)}\b', code)
|
||||
and is_code_valid(model_filled_code[index])
|
||||
and is_correct_parameter_count(answer, core_line, code)
|
||||
and with_correct(core_line, code)
|
||||
and check_keyword_parameters(answer, core_line, code)
|
||||
): # line
|
||||
c += 1
|
||||
if n - c < k:
|
||||
return 1.0
|
||||
|
||||
score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def compute_block_score_k(
|
||||
answer: str,
|
||||
model_output: list,
|
||||
k: int,
|
||||
model_filled_code,
|
||||
core_line_in_core_block,
|
||||
core_line_in_output_clear,
|
||||
):
|
||||
c = 0
|
||||
n = len(model_output)
|
||||
for index, code in enumerate(model_output):
|
||||
if (
|
||||
re.search(rf'\b{re.escape(answer)}\b', code)
|
||||
and is_code_valid(model_filled_code[index])
|
||||
and is_correct_parameter_count(
|
||||
answer, core_line_in_core_block, core_line_in_output_clear[index]
|
||||
)
|
||||
and with_correct(core_line_in_core_block, core_line_in_output_clear[index])
|
||||
and check_keyword_parameters(
|
||||
answer, core_line_in_core_block, core_line_in_output_clear[index]
|
||||
)
|
||||
): # block
|
||||
c += 1
|
||||
if n - c < k:
|
||||
return 1.0
|
||||
|
||||
score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def compute_score_k(answer: str, model_output: list, k: int):
|
||||
c = 0
|
||||
n = len(model_output)
|
||||
for index, code in enumerate(model_output):
|
||||
if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(
|
||||
code
|
||||
): # block
|
||||
# if re.search(rf'\b{re.escape(answer)}\b', code):#line
|
||||
c += 1
|
||||
if n - c < k:
|
||||
return 1.0
|
||||
|
||||
score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))
|
||||
|
||||
return score
|
||||
|
||||
|
||||
k = 3 # cdc@k
|
||||
task = 'block' # line or block
|
||||
json_name = f'Versicode_{task}_completion.json'
|
||||
|
||||
folder_path = f'../data/result_data/{task}_completion'
|
||||
model_list = os.listdir(folder_path)
|
||||
|
||||
for model in model_list:
|
||||
model_json_path = os.path.join(folder_path, model, json_name)
|
||||
with open(model_json_path, 'r', encoding='utf-8') as fr:
|
||||
lodict = json.load(fr)
|
||||
data_list = lodict
|
||||
|
||||
if task == 'line':
|
||||
score_list = []
|
||||
for data in data_list:
|
||||
answer = data['core_token']
|
||||
model_output = eval(data['model_output_clear'])
|
||||
model_filled_code = [
|
||||
data['masked_code'].replace('<mask>', i) for i in model_output
|
||||
]
|
||||
core_line = data['core_line']
|
||||
score_list.append(
|
||||
compute_line_score_k(
|
||||
answer, model_output, k, model_filled_code, core_line
|
||||
)
|
||||
)
|
||||
else:
|
||||
score_list = []
|
||||
for data in data_list:
|
||||
answer = data['core_token']
|
||||
model_output = eval(data['model_output_clear'])
|
||||
model_filled_code = eval(data['model_output_clear'])
|
||||
core_line = data['core_line']
|
||||
core_line_in_output_clear = data['core_line_in_output_clear']
|
||||
score_list.append(
|
||||
compute_block_score_k(
|
||||
answer,
|
||||
model_output,
|
||||
k,
|
||||
model_filled_code,
|
||||
core_line,
|
||||
core_line_in_output_clear,
|
||||
)
|
||||
)
|
||||
|
||||
final_score = sum(score_list) / len(score_list)
|
||||
print(f'{model}, {task} completion task, cdc@{k} score: {final_score}')
|
||||
@@ -1,209 +0,0 @@
|
||||
"""
|
||||
Calculate the cdc score for line and block
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
|
||||
# warnings.filterwarnings("ignore", category=SyntaxWarning)
|
||||
|
||||
|
||||
def is_code_valid(code):
|
||||
try:
|
||||
compile(code, '<string>', 'exec')
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def is_correct_parameter_count(function_name, correct_code, test_code):
|
||||
"""
|
||||
判断参数数量是否一致
|
||||
:param function_name:
|
||||
:param correct_code:
|
||||
:param test_code:
|
||||
:return:
|
||||
"""
|
||||
# 获取正确代码中的参数数量
|
||||
# return True
|
||||
pattern = rf'{function_name}\((.*?)\)'
|
||||
correct_match = re.search(pattern, correct_code)
|
||||
|
||||
if correct_match:
|
||||
correct_params = correct_match.group(1).strip()
|
||||
correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]
|
||||
expected_count = len(correct_param_list)
|
||||
else:
|
||||
expected_count = 0 # 如果没有参数,期望数量为0
|
||||
|
||||
# 在需要判断的代码中查找函数调用
|
||||
test_match = re.search(pattern, test_code)
|
||||
|
||||
if test_match:
|
||||
test_params = test_match.group(1).strip()
|
||||
test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]
|
||||
return len(test_param_list) == expected_count # 检查参数数量
|
||||
else:
|
||||
# 如果没有括号,检查函数名是否在字符串中
|
||||
return expected_count == 0 and function_name in test_code
|
||||
|
||||
|
||||
def check_keyword_parameters(function_name, correct_code, test_code):
|
||||
"""
|
||||
判断关键词参数赋值是否正确使用
|
||||
:param function_name:
|
||||
:param correct_code:
|
||||
:param test_code:
|
||||
:return:
|
||||
"""
|
||||
# 正则表达式匹配正确代码中的函数调用
|
||||
# return True
|
||||
pattern = rf'{function_name}\((.*?)\)'
|
||||
correct_match = re.search(pattern, correct_code)
|
||||
|
||||
if correct_match:
|
||||
correct_params = correct_match.group(1).strip()
|
||||
correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]
|
||||
|
||||
# 检查待检测代码中的函数调用
|
||||
test_match = re.search(pattern, test_code)
|
||||
|
||||
if test_match:
|
||||
test_params = test_match.group(1).strip()
|
||||
test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]
|
||||
|
||||
# 确保待检测的每个参数都以关键字参数形式赋值
|
||||
for correct_param in correct_param_list:
|
||||
if '=' in correct_param: # 仅当正确代码中有关键词参数
|
||||
param_name = correct_param.split('=')[0].strip()
|
||||
if not any(
|
||||
param_name in test_param and '=' in test_param
|
||||
for test_param in test_param_list
|
||||
):
|
||||
return False # 如果对应参数不是关键词参数,则返回False
|
||||
|
||||
return True # 所有关键字参数匹配
|
||||
|
||||
return False # 如果没有匹配,返回False
|
||||
|
||||
|
||||
def with_correct(answer_code: str, model_output: str) -> bool:
|
||||
"""
|
||||
当answer是with结构时,判断模型生成的是不是with结构
|
||||
:param answer_code:
|
||||
:param model_output:
|
||||
:return:
|
||||
"""
|
||||
# return True
|
||||
if not answer_code.startswith('with') and not model_output.startswith('with'):
|
||||
return True
|
||||
elif answer_code.startswith('with') and model_output.startswith('with'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def compute_line_score_k(
|
||||
answer: str, model_output: list, k: int, model_filled_code, core_line
|
||||
):
|
||||
c = 0
|
||||
n = len(model_output)
|
||||
for index, code in enumerate(model_output):
|
||||
if re.search(rf'\b{re.escape(answer)}\b', code): # line
|
||||
c += 1
|
||||
if n - c < k:
|
||||
return 1.0
|
||||
|
||||
score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def compute_block_score_k(
|
||||
answer: str,
|
||||
model_output: list,
|
||||
k: int,
|
||||
model_filled_code,
|
||||
core_line_in_core_block,
|
||||
core_line_in_output_clear,
|
||||
):
|
||||
c = 0
|
||||
n = len(model_output)
|
||||
for index, code in enumerate(model_output):
|
||||
if re.search(rf'\b{re.escape(answer)}\b', code): # block
|
||||
c += 1
|
||||
if n - c < k:
|
||||
return 1.0
|
||||
|
||||
score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def compute_score_k(answer: str, model_output: list, k: int):
|
||||
c = 0
|
||||
n = len(model_output)
|
||||
for index, code in enumerate(model_output):
|
||||
if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(
|
||||
code
|
||||
): # block
|
||||
# if re.search(rf'\b{re.escape(answer)}\b', code):#line
|
||||
c += 1
|
||||
if n - c < k:
|
||||
return 1.0
|
||||
|
||||
score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))
|
||||
|
||||
return score
|
||||
|
||||
|
||||
k = 3 # em@k
|
||||
task = 'block' # line or block
|
||||
json_name = f'Versicode_{task}_completion.json'
|
||||
|
||||
folder_path = f'../data/result_data/{task}_completion'
|
||||
model_list = os.listdir(folder_path)
|
||||
|
||||
for model in model_list:
|
||||
model_json_path = os.path.join(folder_path, model, json_name)
|
||||
with open(model_json_path, 'r', encoding='utf-8') as fr:
|
||||
lodict = json.load(fr)
|
||||
data_list = lodict
|
||||
|
||||
if task == 'line':
|
||||
score_list = []
|
||||
for data in data_list:
|
||||
answer = data['core_token']
|
||||
model_output = eval(data['model_output_clear'])
|
||||
model_filled_code = [
|
||||
data['masked_code'].replace('<mask>', i) for i in model_output
|
||||
]
|
||||
core_line = data['core_line']
|
||||
score_list.append(
|
||||
compute_line_score_k(
|
||||
answer, model_output, k, model_filled_code, core_line
|
||||
)
|
||||
)
|
||||
else:
|
||||
score_list = []
|
||||
for data in data_list:
|
||||
answer = data['core_token']
|
||||
model_output = eval(data['model_output_clear'])
|
||||
model_filled_code = eval(data['model_output_clear'])
|
||||
core_line = data['core_line']
|
||||
core_line_in_output_clear = data['core_line_in_output_clear']
|
||||
score_list.append(
|
||||
compute_block_score_k(
|
||||
answer,
|
||||
model_output,
|
||||
k,
|
||||
model_filled_code,
|
||||
core_line,
|
||||
core_line_in_output_clear,
|
||||
)
|
||||
)
|
||||
|
||||
final_score = sum(score_list) / len(score_list)
|
||||
print(f'{model}, {task} completion task, em@{k} score: {final_score}')
|
||||
-99
@@ -1,99 +0,0 @@
|
||||
"""
|
||||
Find the line of code generated by the model using the block in the version code
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
|
||||
|
||||
def process_line_mask(code_snippet, core_token):
|
||||
if not core_token:
|
||||
return None, None
|
||||
|
||||
replaced_lines = {}
|
||||
lines = code_snippet.split('\n')
|
||||
|
||||
in_multi_line_comment = False
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if in_multi_line_comment:
|
||||
if ('"""' in line or "'''" in line) and not re.findall(
|
||||
r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
|
||||
):
|
||||
in_multi_line_comment = False
|
||||
continue
|
||||
elif line.strip().startswith('#'):
|
||||
continue
|
||||
elif re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
|
||||
continue
|
||||
elif ('"""' in line or "'''" in line) and not re.findall(
|
||||
r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
|
||||
):
|
||||
in_multi_line_comment = True
|
||||
continue
|
||||
else:
|
||||
if re.search(r'\bdef\s+task_function\b', line):
|
||||
continue
|
||||
|
||||
if re.search(r'\b{}\b(?!\s*=)'.format(re.escape(core_token)), line):
|
||||
replaced_lines.update({i: line})
|
||||
|
||||
if replaced_lines:
|
||||
random_line_location = random.choice(list(replaced_lines.keys()))
|
||||
|
||||
masked_line = lines[random_line_location]
|
||||
leading_spaces = re.match(r'^\s*', masked_line).group(0)
|
||||
masked_line = masked_line.strip()
|
||||
lines[random_line_location] = leading_spaces + '<line_mask>'
|
||||
|
||||
masked_code = '\n'.join(lines)
|
||||
|
||||
return masked_code, masked_line
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def load_json(file_path):
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
|
||||
def save_json(file_path, data):
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_list = os.listdir('../data/result_data/block_completion')
|
||||
for model in model_list:
|
||||
input_json_file = f'../data/result_data/block_completion/{model}/VersiCode_block_completion.json'
|
||||
output_json_file = input_json_file
|
||||
data = load_json(input_json_file)
|
||||
|
||||
for item in data:
|
||||
core_token = item['core_token']
|
||||
code = item['code']
|
||||
|
||||
_, core_line_in_code = process_line_mask(code, core_token)
|
||||
if core_line_in_code:
|
||||
item['core_line_in_code'] = core_line_in_code
|
||||
else:
|
||||
item['core_line_in_code'] = 'N/A'
|
||||
|
||||
model_output_clear = item['model_output_clear']
|
||||
core_line_in_output_list = []
|
||||
|
||||
for entry in eval(model_output_clear):
|
||||
_, core_line_in_output = process_line_mask(entry, core_token)
|
||||
if core_line_in_output:
|
||||
core_line_in_output_list.append(core_line_in_output)
|
||||
else:
|
||||
core_line_in_output_list.append('N/A')
|
||||
|
||||
item['core_line_in_output_clear'] = core_line_in_output_list
|
||||
|
||||
save_json(output_json_file, data)
|
||||
print('Done!')
|
||||
-102
@@ -1,102 +0,0 @@
|
||||
"""
|
||||
Find the line of code generated by the model using the block in the version code
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
|
||||
|
||||
def process_line_mask(code_snippet, core_token):
|
||||
if not core_token:
|
||||
return None, None
|
||||
|
||||
replaced_lines = {}
|
||||
lines = code_snippet.split('\n')
|
||||
|
||||
in_multi_line_comment = False
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if in_multi_line_comment:
|
||||
if ('"""' in line or "'''" in line) and not re.findall(
|
||||
r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
|
||||
):
|
||||
in_multi_line_comment = False
|
||||
continue
|
||||
elif line.strip().startswith('#'):
|
||||
continue
|
||||
elif re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
|
||||
continue
|
||||
elif ('"""' in line or "'''" in line) and not re.findall(
|
||||
r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
|
||||
):
|
||||
in_multi_line_comment = True
|
||||
continue
|
||||
else:
|
||||
if re.search(r'\bdef\s+task_function\b', line):
|
||||
continue
|
||||
|
||||
if re.search(r'\b{}\b(?!\s*=)'.format(re.escape(core_token)), line):
|
||||
replaced_lines.update({i: line})
|
||||
|
||||
if replaced_lines:
|
||||
random_line_location = random.choice(list(replaced_lines.keys()))
|
||||
|
||||
masked_line = lines[random_line_location]
|
||||
leading_spaces = re.match(r'^\s*', masked_line).group(0)
|
||||
masked_line = masked_line.strip()
|
||||
lines[random_line_location] = leading_spaces + '<line_mask>'
|
||||
|
||||
masked_code = '\n'.join(lines)
|
||||
|
||||
return masked_code, masked_line
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
def load_json(file_path):
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
|
||||
def save_json(file_path, data):
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model_list = os.listdir('../data/result_data/code_migration')
|
||||
for model in model_list:
|
||||
input_json_file = (
|
||||
f'../data/result_data/code_migration/{model}/VersiCode_migration.json'
|
||||
)
|
||||
output_json_file = input_json_file
|
||||
data = load_json(input_json_file)
|
||||
|
||||
for item in data:
|
||||
core_token = item['old_name']
|
||||
code = item['old_code']
|
||||
|
||||
_, core_line_in_code = process_line_mask(code, core_token)
|
||||
if core_line_in_code:
|
||||
item['core_line_in_code'] = core_line_in_code
|
||||
else:
|
||||
item['core_line_in_code'] = 'N/A'
|
||||
|
||||
model_output_clear = item['model_output_clear']
|
||||
core_line_in_output_list = []
|
||||
|
||||
core_token = item['new_name']
|
||||
for entry in eval(model_output_clear):
|
||||
_, core_line_in_output = process_line_mask(entry, core_token)
|
||||
if core_line_in_output:
|
||||
core_line_in_output_list.append(core_line_in_output)
|
||||
else:
|
||||
core_line_in_output_list.append('N/A')
|
||||
|
||||
item['core_line_in_output_clear'] = core_line_in_output_list
|
||||
|
||||
save_json(output_json_file, data)
|
||||
print('Done!')
|
||||
@@ -1,38 +0,0 @@
|
||||
"""
|
||||
Clear the<start>and<end>generated by the model in inference
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
model_name = ''
|
||||
task = 'block_completion'
|
||||
|
||||
result_path = f'../data/result_data/{task}/{model_name}/VersiCode_block_completion.json' # Modify the file according to the task format
|
||||
|
||||
|
||||
with open(result_path, 'r', encoding='utf-8') as fr:
|
||||
lodict = json.load(fr)
|
||||
data_dict = lodict
|
||||
data_list = data_dict
|
||||
|
||||
for data in data_list:
|
||||
temp_list = []
|
||||
model_output_list = eval(data['model_output'])
|
||||
for output in model_output_list:
|
||||
if '<start>' in output and '<end>' in output:
|
||||
start_index = output.find('<start>') + len('<start>')
|
||||
end_index = output.find('<end>')
|
||||
content = (
|
||||
output[start_index:end_index]
|
||||
.replace('```python', '')
|
||||
.replace('```', '')
|
||||
)
|
||||
else:
|
||||
content = 'no_answer'
|
||||
|
||||
temp_list.append(content)
|
||||
|
||||
data['model_output_clear'] = str(temp_list)
|
||||
|
||||
with open(result_path, 'w', encoding='utf-8') as fw:
|
||||
json.dump(data_dict, fw, indent=4, ensure_ascii=False)
|
||||
@@ -1,146 +0,0 @@
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.11.18
|
||||
aiosignal==1.3.2
|
||||
airportsdata==20250224
|
||||
annotated-types==0.7.0
|
||||
anyio==4.9.0
|
||||
astor==0.8.1
|
||||
attrs==25.3.0
|
||||
blake3==1.0.4
|
||||
cachetools==5.5.2
|
||||
certifi==2025.1.31
|
||||
charset-normalizer==3.4.1
|
||||
click==8.1.8
|
||||
cloudpickle==3.1.1
|
||||
compressed-tensors==0.9.3
|
||||
cupy-cuda12x==13.4.1
|
||||
Deprecated==1.2.18
|
||||
depyf==0.18.0
|
||||
dill==0.4.0
|
||||
diskcache==5.6.3
|
||||
distro==1.9.0
|
||||
dnspython==2.7.0
|
||||
einops==0.8.1
|
||||
email_validator==2.2.0
|
||||
fastapi==0.115.12
|
||||
fastapi-cli==0.0.7
|
||||
fastrlock==0.8.3
|
||||
filelock==3.18.0
|
||||
frozenlist==1.6.0
|
||||
fsspec==2025.3.2
|
||||
gguf==0.16.2
|
||||
googleapis-common-protos==1.70.0
|
||||
grpcio==1.71.0
|
||||
h11==0.14.0
|
||||
hf-xet==1.0.3
|
||||
httpcore==1.0.8
|
||||
httptools==0.6.4
|
||||
httpx==0.28.1
|
||||
huggingface-hub==0.30.2
|
||||
idna==3.10
|
||||
importlib_metadata==8.0.0
|
||||
interegular==0.3.3
|
||||
Jinja2==3.1.6
|
||||
jiter==0.9.0
|
||||
jsonschema==4.23.0
|
||||
jsonschema-specifications==2024.10.1
|
||||
lark==1.2.2
|
||||
llguidance==0.7.16
|
||||
llvmlite==0.44.0
|
||||
lm-format-enforcer==0.10.11
|
||||
markdown-it-py==3.0.0
|
||||
MarkupSafe==3.0.2
|
||||
mdurl==0.1.2
|
||||
mistral_common==1.5.4
|
||||
mpmath==1.3.0
|
||||
msgpack==1.1.0
|
||||
msgspec==0.19.0
|
||||
multidict==6.4.3
|
||||
nest-asyncio==1.6.0
|
||||
networkx==3.4.2
|
||||
ninja==1.11.1.4
|
||||
numba==0.61.2
|
||||
numpy==2.2.5
|
||||
nvidia-cublas-cu12==12.4.5.8
|
||||
nvidia-cuda-cupti-cu12==12.4.127
|
||||
nvidia-cuda-nvrtc-cu12==12.4.127
|
||||
nvidia-cuda-runtime-cu12==12.4.127
|
||||
nvidia-cudnn-cu12==9.1.0.70
|
||||
nvidia-cufft-cu12==11.2.1.3
|
||||
nvidia-curand-cu12==10.3.5.147
|
||||
nvidia-cusolver-cu12==11.6.1.9
|
||||
nvidia-cusparse-cu12==12.3.1.170
|
||||
nvidia-cusparselt-cu12==0.6.2
|
||||
nvidia-nccl-cu12==2.21.5
|
||||
nvidia-nvjitlink-cu12==12.4.127
|
||||
nvidia-nvtx-cu12==12.4.127
|
||||
openai==1.75.0
|
||||
opencv-python-headless==4.11.0.86
|
||||
opentelemetry-api==1.26.0
|
||||
opentelemetry-exporter-otlp==1.26.0
|
||||
opentelemetry-exporter-otlp-proto-common==1.26.0
|
||||
opentelemetry-exporter-otlp-proto-grpc==1.26.0
|
||||
opentelemetry-exporter-otlp-proto-http==1.26.0
|
||||
opentelemetry-proto==1.26.0
|
||||
opentelemetry-sdk==1.26.0
|
||||
opentelemetry-semantic-conventions==0.47b0
|
||||
opentelemetry-semantic-conventions-ai==0.4.3
|
||||
outlines==0.1.11
|
||||
outlines_core==0.1.26
|
||||
packaging==25.0
|
||||
partial-json-parser==0.2.1.1.post5
|
||||
pillow==11.2.1
|
||||
prometheus-fastapi-instrumentator==7.1.0
|
||||
prometheus_client==0.21.1
|
||||
propcache==0.3.1
|
||||
protobuf==4.25.6
|
||||
psutil==7.0.0
|
||||
py-cpuinfo==9.0.0
|
||||
pycountry==24.6.1
|
||||
pydantic==2.11.3
|
||||
pydantic_core==2.33.1
|
||||
Pygments==2.19.1
|
||||
python-dotenv==1.1.0
|
||||
python-json-logger==3.3.0
|
||||
python-multipart==0.0.20
|
||||
PyYAML==6.0.2
|
||||
pyzmq==26.4.0
|
||||
ray==2.43.0
|
||||
referencing==0.36.2
|
||||
regex==2024.11.6
|
||||
requests==2.32.3
|
||||
rich==14.0.0
|
||||
rich-toolkit==0.14.1
|
||||
rpds-py==0.24.0
|
||||
safetensors==0.5.3
|
||||
scipy==1.15.2
|
||||
sentencepiece==0.2.0
|
||||
setuptools==75.8.0
|
||||
shellingham==1.5.4
|
||||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
starlette==0.46.2
|
||||
sympy==1.13.1
|
||||
tiktoken==0.9.0
|
||||
tokenizers==0.21.1
|
||||
torch==2.6.0
|
||||
torchaudio==2.6.0
|
||||
torchvision==0.21.0
|
||||
tqdm==4.67.1
|
||||
transformers==4.51.3
|
||||
triton==3.2.0
|
||||
typer==0.15.2
|
||||
typing-inspection==0.4.0
|
||||
typing_extensions==4.13.2
|
||||
urllib3==2.4.0
|
||||
uvicorn==0.34.2
|
||||
uvloop==0.21.0
|
||||
vllm==0.8.4
|
||||
watchfiles==1.0.5
|
||||
websockets==15.0.1
|
||||
wheel==0.45.1
|
||||
wrapt==1.17.2
|
||||
xformers==0.0.29.post2
|
||||
xgrammar==0.1.18
|
||||
yarl==1.20.0
|
||||
zipp==3.21.0
|
||||
@@ -212,7 +212,7 @@ if __name__ == '__main__':
|
||||
llm_config = None
|
||||
if args.llm_config:
|
||||
llm_config = get_llm_config_arg(args.llm_config)
|
||||
# modify_params must be False for evaluation purpose, for reproducibility and accuracy of results
|
||||
# modify_params must be False for evaluation purpose, for reproducibility and accurancy of results
|
||||
llm_config.modify_params = False
|
||||
if llm_config is None:
|
||||
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
|
||||
|
||||
@@ -263,19 +263,8 @@ def prepare_dataset(
|
||||
f'Randomly sampling {eval_n_limit} unique instances with random seed 42.'
|
||||
)
|
||||
|
||||
def make_serializable(instance: pd.Series) -> dict:
|
||||
import numpy as np
|
||||
|
||||
instance_dict = instance.to_dict()
|
||||
for k, v in instance_dict.items():
|
||||
if isinstance(v, np.ndarray):
|
||||
instance_dict[k] = v.tolist()
|
||||
elif isinstance(v, pd.Timestamp):
|
||||
instance_dict[k] = str(v)
|
||||
return instance_dict
|
||||
|
||||
new_dataset = [
|
||||
make_serializable(instance)
|
||||
instance
|
||||
for _, instance in dataset.iterrows()
|
||||
if str(instance[id_column]) not in finished_ids
|
||||
]
|
||||
|
||||
@@ -16,8 +16,8 @@ vi.mock("react-i18next", async () => {
|
||||
if (i18nKey === "SETTINGS$API_KEYS_DESCRIPTION") {
|
||||
return (
|
||||
<span>
|
||||
API keys allow you to authenticate with the OpenHands API programmatically.
|
||||
Keep your API keys secure; anyone with your API key can access your account.
|
||||
API keys allow you to authenticate with the OpenHands API programmatically.
|
||||
Keep your API keys secure; anyone with your API key can access your account.
|
||||
For more information on how to use the API, see our {components.a}
|
||||
</span>
|
||||
);
|
||||
@@ -48,7 +48,7 @@ describe("ApiKeysManager", () => {
|
||||
|
||||
it("should render the API documentation link", () => {
|
||||
renderComponent();
|
||||
|
||||
|
||||
// Find the link to the API documentation
|
||||
const link = screen.getByRole("link");
|
||||
expect(link).toBeInTheDocument();
|
||||
@@ -56,4 +56,4 @@ describe("ApiKeysManager", () => {
|
||||
expect(link).toHaveAttribute("target", "_blank");
|
||||
expect(link).toHaveAttribute("rel", "noopener noreferrer");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -39,4 +39,4 @@ describe("Check for hardcoded English strings in Home components", () => {
|
||||
expect(text).not.toContain(str);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -60,11 +60,11 @@ Object.entries(translationJson).forEach(([key, translations]) => {
|
||||
if (Object.keys(missingTranslations).length > 0) {
|
||||
console.error('\x1b[31m%s\x1b[0m', 'ERROR: Missing translations detected');
|
||||
console.error(`Found ${Object.keys(missingTranslations).length} translation keys with missing languages:`);
|
||||
|
||||
|
||||
Object.entries(missingTranslations).forEach(([key, langs]) => {
|
||||
console.error(`- Key "${key}" is missing translations for: ${langs.join(', ')}`);
|
||||
});
|
||||
|
||||
|
||||
console.error('\nPlease add the missing translations before committing.');
|
||||
}
|
||||
|
||||
@@ -72,11 +72,11 @@ if (Object.keys(missingTranslations).length > 0) {
|
||||
if (Object.keys(extraLanguages).length > 0) {
|
||||
console.error('\x1b[31m%s\x1b[0m', 'ERROR: Extra languages detected');
|
||||
console.error(`Found ${Object.keys(extraLanguages).length} translation keys with extra languages not in AvailableLanguages:`);
|
||||
|
||||
|
||||
Object.entries(extraLanguages).forEach(([key, langs]) => {
|
||||
console.error(`- Key "${key}" has translations for unsupported languages: ${langs.join(', ')}`);
|
||||
});
|
||||
|
||||
|
||||
console.error('\nPlease remove the extra languages before committing.');
|
||||
}
|
||||
|
||||
@@ -85,4 +85,4 @@ if (hasErrors) {
|
||||
process.exit(1);
|
||||
} else {
|
||||
console.log('\x1b[32m%s\x1b[0m', 'All translation keys have complete language coverage!');
|
||||
}
|
||||
}
|
||||
@@ -19,10 +19,10 @@ vi.mock("react-i18next", () => ({
|
||||
|
||||
describe("RepositorySelectionForm", () => {
|
||||
const mockOnRepoSelection = vi.fn();
|
||||
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
|
||||
|
||||
// Mock the hooks with default values
|
||||
(useUserRepositories as any).mockReturnValue({
|
||||
data: [
|
||||
@@ -32,7 +32,7 @@ describe("RepositorySelectionForm", () => {
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
});
|
||||
|
||||
|
||||
(useRepositoryBranches as any).mockReturnValue({
|
||||
data: [
|
||||
{ name: "main" },
|
||||
@@ -41,90 +41,90 @@ describe("RepositorySelectionForm", () => {
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
});
|
||||
|
||||
|
||||
(useCreateConversation as any).mockReturnValue({
|
||||
mutate: vi.fn(),
|
||||
isPending: false,
|
||||
isSuccess: false,
|
||||
});
|
||||
|
||||
|
||||
(useIsCreatingConversation as any).mockReturnValue(false);
|
||||
});
|
||||
|
||||
|
||||
it("should clear selected branch when input is empty", async () => {
|
||||
render(<RepositorySelectionForm onRepoSelection={mockOnRepoSelection} />);
|
||||
|
||||
|
||||
// First select a repository to enable the branch dropdown
|
||||
const repoDropdown = screen.getByTestId("repository-dropdown");
|
||||
fireEvent.change(repoDropdown, { target: { value: "test/repo1" } });
|
||||
|
||||
|
||||
// Get the branch dropdown and verify it's enabled
|
||||
const branchDropdown = screen.getByTestId("branch-dropdown");
|
||||
expect(branchDropdown).not.toBeDisabled();
|
||||
|
||||
|
||||
// Simulate deleting all text in the branch input
|
||||
fireEvent.change(branchDropdown, { target: { value: "" } });
|
||||
|
||||
|
||||
// Verify the branch input is cleared (no selected branch)
|
||||
expect(branchDropdown).toHaveValue("");
|
||||
});
|
||||
|
||||
|
||||
it("should clear selected branch when input contains only whitespace", async () => {
|
||||
render(<RepositorySelectionForm onRepoSelection={mockOnRepoSelection} />);
|
||||
|
||||
|
||||
// First select a repository to enable the branch dropdown
|
||||
const repoDropdown = screen.getByTestId("repository-dropdown");
|
||||
fireEvent.change(repoDropdown, { target: { value: "test/repo1" } });
|
||||
|
||||
|
||||
// Get the branch dropdown and verify it's enabled
|
||||
const branchDropdown = screen.getByTestId("branch-dropdown");
|
||||
expect(branchDropdown).not.toBeDisabled();
|
||||
|
||||
|
||||
// Simulate entering only whitespace in the branch input
|
||||
fireEvent.change(branchDropdown, { target: { value: " " } });
|
||||
|
||||
|
||||
// Verify the branch input is cleared (no selected branch)
|
||||
expect(branchDropdown).toHaveValue("");
|
||||
});
|
||||
|
||||
it("should keep branch empty after being cleared even with auto-selection", async () => {
|
||||
render(<RepositorySelectionForm onRepoSelection={mockOnRepoSelection} />);
|
||||
|
||||
|
||||
// First select a repository to enable the branch dropdown
|
||||
const repoDropdown = screen.getByTestId("repository-dropdown");
|
||||
fireEvent.change(repoDropdown, { target: { value: "test/repo1" } });
|
||||
|
||||
|
||||
// Get the branch dropdown and verify it's enabled
|
||||
const branchDropdown = screen.getByTestId("branch-dropdown");
|
||||
expect(branchDropdown).not.toBeDisabled();
|
||||
|
||||
|
||||
// The branch should be auto-selected to "main" initially
|
||||
expect(branchDropdown).toHaveValue("main");
|
||||
|
||||
|
||||
// Simulate deleting all text in the branch input
|
||||
fireEvent.change(branchDropdown, { target: { value: "" } });
|
||||
|
||||
|
||||
// Verify the branch input is cleared (no selected branch)
|
||||
expect(branchDropdown).toHaveValue("");
|
||||
|
||||
|
||||
// Trigger a re-render by changing something else
|
||||
fireEvent.change(repoDropdown, { target: { value: "test/repo2" } });
|
||||
fireEvent.change(repoDropdown, { target: { value: "test/repo1" } });
|
||||
|
||||
|
||||
// The branch should be auto-selected to "main" again after repo change
|
||||
expect(branchDropdown).toHaveValue("main");
|
||||
|
||||
|
||||
// Clear it again
|
||||
fireEvent.change(branchDropdown, { target: { value: "" } });
|
||||
|
||||
|
||||
// Verify it stays empty
|
||||
expect(branchDropdown).toHaveValue("");
|
||||
|
||||
|
||||
// Simulate a component update without changing repos
|
||||
// This would normally trigger the useEffect if our fix wasn't working
|
||||
fireEvent.blur(branchDropdown);
|
||||
|
||||
|
||||
// Verify it still stays empty
|
||||
expect(branchDropdown).toHaveValue("");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -8,6 +8,7 @@ from openhands.agenthub import ( # noqa: E402
|
||||
codeact_agent,
|
||||
dummy_agent,
|
||||
loc_agent,
|
||||
proxy_agent,
|
||||
readonly_agent,
|
||||
visualbrowsing_agent,
|
||||
)
|
||||
@@ -19,6 +20,7 @@ __all__ = [
|
||||
'dummy_agent',
|
||||
'browsing_agent',
|
||||
'visualbrowsing_agent',
|
||||
'proxy_agent',
|
||||
'readonly_agent',
|
||||
'loc_agent',
|
||||
]
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
# Proxy Agent
|
||||
|
||||
This folder is an implementation of a Proxy Agent.
|
||||
The Proxy Agent delegates a given task to an appropriate agent capable of accomplishing it.
|
||||
The list of available agents is defined in agent_list.json, located in this directory.
|
||||
|
||||
A key feature of the Proxy Agent is that, in addition to delegating task to different agents available locally within OpenHands, it can also send messages to agents hosted on different server, using A2A Protocol.
|
||||
|
||||
## How to run
|
||||
### Set as the initial agent
|
||||
This agent is designed to be the initial agent that receives user input at the start of a session.
|
||||
Configure the Proxy Agent as the initial agent of a session.
|
||||
```mermaid
|
||||
flowchart LR
|
||||
u((User)) --> A
|
||||
|
||||
subgraph Server1
|
||||
A["Proxy Agent"]
|
||||
B["Other Agents<br>(e.g. CodeActAgent)"]
|
||||
A -->|delegate| B
|
||||
end
|
||||
|
||||
subgraph Server2
|
||||
D["Other Agents"]
|
||||
end
|
||||
|
||||
A --->|Remote Delegation| D
|
||||
|
||||
```
|
||||
|
||||
### Place agent_list.json
|
||||
Place agent_list.json under openhands/agenthub/proxy_agent. Below is an example of its structure:
|
||||
```json
|
||||
{
|
||||
"local": {
|
||||
"CodeActAgent": {
|
||||
"agent_name": "CodeActAgent",
|
||||
"description": "A helpful AI assistant that can interact with a computer to solve tasks."
|
||||
}
|
||||
},
|
||||
"remote": {
|
||||
"FooAgent": {
|
||||
"agent_name": "FooAgent",
|
||||
"url": "http(s)://IP or FQDN:port",
|
||||
"description": "A brief description of FooAgent.",
|
||||
"protocol": "A2A"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
The contents of this JSON file are simply passed as a string to the agent as part of its prompt, assisting the LLM in selecting the most suitable agent.
|
||||
Therefore, there areno strict formatting requirements, but please keep the following points in mind:
|
||||
- Clearly specify whether the agent is available locally within the same instance or hosted on a different instance.
|
||||
- If an agent is hosted on a different instance, explicitly provide the URL where that instance is hosted.
|
||||
@@ -0,0 +1,4 @@
|
||||
from openhands.agenthub.proxy_agent.proxy_agent import ProxyAgent
|
||||
from openhands.controller.agent import Agent
|
||||
|
||||
Agent.register('ProxyAgent', ProxyAgent)
|
||||
@@ -0,0 +1,180 @@
|
||||
import json
|
||||
|
||||
from litellm import (
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
from openhands.core.exceptions import (
|
||||
FunctionCallNotExistsError,
|
||||
FunctionCallValidationError,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
|
||||
_DELEGATE_LOCAL = """Delegate a task to a local agent hosted on a same instance.
|
||||
"""
|
||||
|
||||
DelegateLocalTool = ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='delegate_local',
|
||||
description=_DELEGATE_LOCAL,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'agent_name': {
|
||||
'type': 'string',
|
||||
'description': 'The name of the agent to delegate to.',
|
||||
},
|
||||
'task': {
|
||||
'type': 'string',
|
||||
'description': 'The task to delegate.',
|
||||
},
|
||||
},
|
||||
'required': ['agent_name', 'task'],
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
_DELEGATE_REMOTE = """Delegate a task to a remote agent hosted on a remote server using A2A Protocol.
|
||||
"""
|
||||
|
||||
DelegateRemoteTool = ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='delegate_remote',
|
||||
description=_DELEGATE_REMOTE,
|
||||
parameters={
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'url': {
|
||||
'type': 'string',
|
||||
'description': 'The URL of the remote agent.',
|
||||
},
|
||||
'task': {
|
||||
'type': 'string',
|
||||
'description': 'The task to delegate.',
|
||||
},
|
||||
'session_id': {
|
||||
'type': 'string',
|
||||
'description': 'The session id of the remote agent.',
|
||||
},
|
||||
'task_id': {
|
||||
'type': 'string',
|
||||
'description': 'The task id of the remote agent.',
|
||||
}
|
||||
},
|
||||
'required': ['url', 'task'],
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
_FINISH_DESCRIPTION = """Finish the interaction when the task is complete OR if the assistant cannot proceed further with the task."""
|
||||
|
||||
FinishTool = ChatCompletionToolParam(
|
||||
type='function',
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name='finish',
|
||||
description=_FINISH_DESCRIPTION,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def combine_thought(action: Action, thought: str) -> Action:
|
||||
if not hasattr(action, 'thought'):
|
||||
return action
|
||||
if thought:
|
||||
action.thought = thought
|
||||
return action
|
||||
|
||||
|
||||
def response_to_action(response: ModelResponse) -> Action:
|
||||
action: Action = None # type: ignore
|
||||
assert len(response.choices) == 1, 'Only one choice is supported for now'
|
||||
assistant_msg = response.choices[0].message
|
||||
if assistant_msg.tool_calls:
|
||||
# Check if there's assistant_msg.content. If so, add it to the thought
|
||||
thought = ''
|
||||
if isinstance(assistant_msg.content, str):
|
||||
thought = assistant_msg.content
|
||||
elif isinstance(assistant_msg.content, list):
|
||||
for msg in assistant_msg.content:
|
||||
if msg['type'] == 'text':
|
||||
thought += msg['text']
|
||||
|
||||
# Assume only one tool call is returned
|
||||
if len(assistant_msg.tool_calls) != 1:
|
||||
logger.info(
|
||||
f'Expected only one tool call, but got {len(assistant_msg.tool_calls)}'
|
||||
)
|
||||
tool_call = assistant_msg.tool_calls[0]
|
||||
try:
|
||||
arguments = json.loads(tool_call.function.arguments)
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
raise RuntimeError(
|
||||
f'Failed to parse tool call arguments: {tool_call.function.arguments}'
|
||||
) from e
|
||||
|
||||
if tool_call.function.name == 'delegate_remote':
|
||||
for k in ['url', 'task']:
|
||||
if k not in arguments:
|
||||
raise FunctionCallValidationError(
|
||||
f'Missing required argument "{k}" in tool call {tool_call.function.name}'
|
||||
)
|
||||
|
||||
message = arguments['task']
|
||||
message = message.replace('\n', '\\\n')
|
||||
url = arguments['url']
|
||||
session_id = arguments.get('session_id')
|
||||
task_id = arguments.get('task_id')
|
||||
if session_id and task_id:
|
||||
code = (
|
||||
f'await send_task_A2A('
|
||||
f'message="{message}", '
|
||||
f'url="{url}", '
|
||||
f'session_id="{session_id}", '
|
||||
f'task_id="{task_id}")'
|
||||
)
|
||||
else:
|
||||
code = (
|
||||
f'await send_task_A2A('
|
||||
f'message="{message}", '
|
||||
f'url="{url}")'
|
||||
)
|
||||
|
||||
action = IPythonRunCellAction(code=code, include_extra=False)
|
||||
|
||||
elif tool_call.function.name == 'finish':
|
||||
action = AgentFinishAction()
|
||||
else:
|
||||
raise FunctionCallNotExistsError(
|
||||
f'Tool {tool_call.function.name} is not registered. (arguments: {arguments}). Please check the tool name and retry with an existing tool.'
|
||||
)
|
||||
|
||||
action = combine_thought(action, thought)
|
||||
# Add metadata for tool calling
|
||||
action.tool_call_metadata = ToolCallMetadata(
|
||||
tool_call_id=tool_call.id,
|
||||
function_name=tool_call.function.name,
|
||||
model_response=response,
|
||||
total_calls_in_response=len(assistant_msg.tool_calls),
|
||||
)
|
||||
|
||||
else:
|
||||
action = MessageAction(content=assistant_msg.content, wait_for_response=True)
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def get_tools() -> list[ChatCompletionToolParam]:
|
||||
tools = [DelegateLocalTool, DelegateRemoteTool, FinishTool]
|
||||
return tools
|
||||
@@ -0,0 +1,6 @@
|
||||
You are a Proxy Agent, a helpful AI assistant which is responsible for delegating tasks to other agents.
|
||||
You delegate tasks to agents that exist locally or are hosted remotely on another server.
|
||||
<IMPORTANT>
|
||||
* Never execute an action again once the action has been completed.
|
||||
* When you delegate a task to a remote-host agent, you must read the response of the remote agent and return a message to the user as if you were that agent.
|
||||
</IMPORTANT>
|
||||
@@ -0,0 +1,126 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import openhands.agenthub.proxy_agent.function_calling as proxy_function_calling
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.events.action import Action, MessageAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.conversation_memory import ConversationMemory
|
||||
from openhands.microagent.prompt_manager import PromptManager
|
||||
from openhands.runtime.plugins import (
|
||||
AgentSkillsRequirement,
|
||||
JupyterRequirement,
|
||||
PluginRequirement,
|
||||
)
|
||||
|
||||
|
||||
class ProxyAgent(Agent):
|
||||
sandbox_plugins: list[PluginRequirement] = [
|
||||
AgentSkillsRequirement(),
|
||||
JupyterRequirement(),
|
||||
]
|
||||
|
||||
def __init__(self, llm: LLM, config: AgentConfig) -> None:
|
||||
super().__init__(llm, config)
|
||||
self.reset()
|
||||
|
||||
self.mock_function_calling = False
|
||||
if not self.llm.is_function_calling_active():
|
||||
logger.info(
|
||||
f'Function calling not enabled for model {self.llm.config.model}. '
|
||||
'Mocking function calling via prompting.'
|
||||
)
|
||||
self.mock_function_calling = True
|
||||
|
||||
# Function calling mode
|
||||
self.tools = proxy_function_calling.get_tools()
|
||||
|
||||
self._prompt_manager = PromptManager(
|
||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||
)
|
||||
|
||||
# Create a ConversationMemory instance
|
||||
# _prompt_manager is guaranteed to be set at this point
|
||||
assert self._prompt_manager is not None
|
||||
self.conversation_memory = ConversationMemory(self.config, self._prompt_manager)
|
||||
|
||||
agent_list_path = os.path.join(os.path.dirname(__file__), 'agent_list.json')
|
||||
if not os.path.exists(agent_list_path):
|
||||
raise FileNotFoundError('agent list file not found')
|
||||
with open(agent_list_path, 'r') as f:
|
||||
self.agent_list = json.load(f)
|
||||
if self.agent_list == {}:
|
||||
raise ValueError('agent list file is empty')
|
||||
|
||||
def step(self, state: State) -> Action:
|
||||
# Prepare the message to send to the LLM
|
||||
initial_user_message = self._get_initial_user_message(state.history)
|
||||
messages = self._get_messages(state.history, initial_user_message)
|
||||
|
||||
params: dict = {
|
||||
'messages': self.llm.format_messages_for_llm(messages),
|
||||
}
|
||||
params['tools'] = self.tools
|
||||
if self.mock_function_calling:
|
||||
params['mock_function_calling'] = True
|
||||
response = self.llm.completion(**params)
|
||||
|
||||
# Assume only one tool call is returned
|
||||
action = proxy_function_calling.response_to_action(response)
|
||||
return action
|
||||
|
||||
def _get_initial_user_message(self, history: list[Event]) -> MessageAction:
|
||||
"""Finds the initial user message action from the full history."""
|
||||
initial_user_message: MessageAction | None = None
|
||||
for event in history:
|
||||
if isinstance(event, MessageAction) and event.source == 'user':
|
||||
initial_user_message = event
|
||||
break
|
||||
|
||||
if initial_user_message is None:
|
||||
# This should not happen in a valid conversation
|
||||
raise ValueError(
|
||||
'Initial user message not found in history. Please report this issue.'
|
||||
)
|
||||
return initial_user_message
|
||||
|
||||
def _get_messages(
|
||||
self, events: list[Event], initial_user_message: MessageAction
|
||||
) -> list[Message]:
|
||||
if not self.prompt_manager:
|
||||
raise Exception('Prompt Manager not instantiated.')
|
||||
|
||||
# Use ConversationMemory to process events (including SystemMessageAction)
|
||||
messages = self.conversation_memory.process_events(
|
||||
condensed_history=events,
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=self.llm.config.max_message_chars,
|
||||
vision_is_active=self.llm.vision_is_active(),
|
||||
)
|
||||
|
||||
agent_list_message = Message(
|
||||
role='system',
|
||||
content=[
|
||||
TextContent(
|
||||
text='Available agents are the following:'
|
||||
+ json.dumps(self.agent_list)
|
||||
)
|
||||
],
|
||||
)
|
||||
if len(messages) > 1:
|
||||
messages.insert(1, agent_list_message)
|
||||
else:
|
||||
messages.append(agent_list_message)
|
||||
|
||||
if self.llm.is_caching_prompt_active():
|
||||
self.conversation_memory.apply_prompt_caching(messages)
|
||||
|
||||
return messages
|
||||
|
||||
def reset(self) -> None:
|
||||
super().reset()
|
||||
@@ -208,9 +208,7 @@ Note:
|
||||
# for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
|
||||
# initialize and retrieve the first observation by issuing an noop OP
|
||||
# For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites
|
||||
return BrowseInteractiveAction(
|
||||
browser_actions='noop(1000)', return_axtree=True
|
||||
)
|
||||
return BrowseInteractiveAction(browser_actions='noop(1000)', return_axtree=True)
|
||||
|
||||
for event in state.view:
|
||||
if isinstance(event, BrowseInteractiveAction):
|
||||
|
||||
@@ -215,18 +215,10 @@ async def modify_llm_settings_basic(
|
||||
]
|
||||
provider_models = VERIFIED_ANTHROPIC_MODELS + provider_models
|
||||
|
||||
# Set default model to the best verified model for the provider
|
||||
if provider == 'anthropic' and VERIFIED_ANTHROPIC_MODELS:
|
||||
# Use the first model in the VERIFIED_ANTHROPIC_MODELS list as it's the best/newest
|
||||
default_model = VERIFIED_ANTHROPIC_MODELS[0]
|
||||
elif provider == 'openai' and VERIFIED_OPENAI_MODELS:
|
||||
# Use the first model in the VERIFIED_OPENAI_MODELS list as it's the best/newest
|
||||
default_model = VERIFIED_OPENAI_MODELS[0]
|
||||
else:
|
||||
# For other providers, use the first model in the list
|
||||
default_model = (
|
||||
provider_models[0] if provider_models else 'claude-sonnet-4-20250514'
|
||||
)
|
||||
# Set default model to the first model in the list (which will be a verified model if available)
|
||||
default_model = (
|
||||
provider_models[0] if provider_models else 'claude-sonnet-4-20250514'
|
||||
)
|
||||
|
||||
# Show the default model but allow changing it
|
||||
print_formatted_text(
|
||||
|
||||
@@ -158,17 +158,17 @@ VERIFIED_OPENAI_MODELS = [
|
||||
]
|
||||
|
||||
VERIFIED_ANTHROPIC_MODELS = [
|
||||
'claude-2',
|
||||
'claude-2.1',
|
||||
'claude-3-5-sonnet-20240620',
|
||||
'claude-3-5-sonnet-20241022',
|
||||
'claude-3-5-haiku-20241022',
|
||||
'claude-3-haiku-20240307',
|
||||
'claude-3-opus-20240229',
|
||||
'claude-3-sonnet-20240229',
|
||||
'claude-3-7-sonnet-20250219',
|
||||
'claude-sonnet-4-20250514',
|
||||
'claude-opus-4-20250514',
|
||||
'claude-3-7-sonnet-20250219',
|
||||
'claude-3-sonnet-20240229',
|
||||
'claude-3-opus-20240229',
|
||||
'claude-3-haiku-20240307',
|
||||
'claude-3-5-haiku-20241022',
|
||||
'claude-3-5-sonnet-20241022',
|
||||
'claude-3-5-sonnet-20240620',
|
||||
'claude-2.1',
|
||||
'claude-2',
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -54,7 +54,6 @@ class MCPStdioServerConfig(BaseModel):
|
||||
and set(self.env.items()) == set(other.env.items())
|
||||
)
|
||||
|
||||
|
||||
class MCPSHTTPServerConfig(BaseModel):
|
||||
url: str
|
||||
api_key: str | None = None
|
||||
|
||||
@@ -744,27 +744,6 @@ def get_parser() -> argparse.ArgumentParser:
|
||||
type=bool,
|
||||
default=False,
|
||||
)
|
||||
|
||||
# LLM configuration arguments for local models
|
||||
parser.add_argument(
|
||||
'--llm-model',
|
||||
help='LLM model to use (e.g., "lm_studio/devstral", "openai/gpt-4")',
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
'--llm-base-url',
|
||||
help='Base URL for LLM API (required for local models, e.g., "http://localhost:1234/v1")',
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
'--llm-api-key',
|
||||
help='API key for LLM (use "dummy" for local models)',
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -842,21 +821,6 @@ def setup_config_from_args(args: argparse.Namespace) -> OpenHandsConfig:
|
||||
raise ValueError(f'Invalid toml file, cannot read {args.llm_config}')
|
||||
config.set_llm_config(llm_config)
|
||||
|
||||
# Override LLM settings with direct CLI arguments
|
||||
if args.llm_model or args.llm_base_url or args.llm_api_key:
|
||||
from pydantic import SecretStr
|
||||
|
||||
llm_config = config.get_llm_config()
|
||||
|
||||
if args.llm_model:
|
||||
llm_config.model = args.llm_model
|
||||
if args.llm_base_url:
|
||||
llm_config.base_url = args.llm_base_url
|
||||
if args.llm_api_key:
|
||||
llm_config.api_key = SecretStr(args.llm_api_key)
|
||||
|
||||
config.set_llm_config(llm_config)
|
||||
|
||||
# Override default agent if provided
|
||||
if args.agent_cls:
|
||||
config.default_agent = args.agent_cls
|
||||
|
||||
@@ -39,7 +39,6 @@ class GitHubService(BaseGitService, GitService):
|
||||
|
||||
The class is instantiated via get_impl() in openhands.server.shared.py.
|
||||
"""
|
||||
|
||||
BASE_URL = 'https://api.github.com'
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
@@ -509,6 +508,7 @@ class GitHubService(BaseGitService, GitService):
|
||||
return response['html_url']
|
||||
|
||||
|
||||
|
||||
github_service_cls = os.environ.get(
|
||||
'OPENHANDS_GITHUB_SERVICE_CLS',
|
||||
'openhands.integrations.github.github_service.GitHubService',
|
||||
|
||||
@@ -32,7 +32,6 @@ class GitLabService(BaseGitService, GitService):
|
||||
|
||||
The class is instantiated via get_impl() in openhands.server.shared.py.
|
||||
"""
|
||||
|
||||
BASE_URL = 'https://gitlab.com/api/v4'
|
||||
GRAPHQL_URL = 'https://gitlab.com/api/graphql'
|
||||
token: SecretStr = SecretStr('')
|
||||
@@ -483,7 +482,9 @@ class GitLabService(BaseGitService, GitService):
|
||||
|
||||
# Set default description if none provided
|
||||
if not description:
|
||||
description = f'Merging changes from {source_branch} into {target_branch}'
|
||||
description = (
|
||||
f'Merging changes from {source_branch} into {target_branch}'
|
||||
)
|
||||
|
||||
# Prepare the request payload
|
||||
payload = {
|
||||
@@ -498,9 +499,11 @@ class GitLabService(BaseGitService, GitService):
|
||||
url=url, params=payload, method=RequestMethod.POST
|
||||
)
|
||||
|
||||
|
||||
return response['web_url']
|
||||
|
||||
|
||||
|
||||
gitlab_service_cls = os.environ.get(
|
||||
'OPENHANDS_GITLAB_SERVICE_CLS',
|
||||
'openhands.integrations.gitlab.gitlab_service.GitLabService',
|
||||
|
||||
@@ -1 +1 @@
|
||||
{{ issue_comment }}
|
||||
{{ issue_comment }}
|
||||
@@ -1 +1 @@
|
||||
Please fix issue number #{{ issue_number }} in your repository.
|
||||
Please fix issue number #{{ issue_number }} in your repository.
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import typing
|
||||
from functools import lru_cache
|
||||
from typing import Callable
|
||||
import typing
|
||||
from uuid import UUID
|
||||
|
||||
import docker
|
||||
@@ -283,9 +283,7 @@ class DockerRuntime(ActionExecutionClient):
|
||||
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'
|
||||
|
||||
use_host_network = self.config.sandbox.use_host_network
|
||||
network_mode: typing.Literal['host'] | None = (
|
||||
'host' if use_host_network else None
|
||||
)
|
||||
network_mode: typing.Literal['host'] | None = 'host' if use_host_network else None
|
||||
|
||||
# Initialize port mappings
|
||||
port_mapping: dict[str, list[dict[str, str]]] | None = None
|
||||
@@ -358,7 +356,7 @@ class DockerRuntime(ActionExecutionClient):
|
||||
|
||||
try:
|
||||
if self.runtime_container_image is None:
|
||||
raise ValueError('Runtime container image is not set')
|
||||
raise ValueError("Runtime container image is not set")
|
||||
self.container = self.docker_client.containers.run(
|
||||
self.runtime_container_image,
|
||||
command=command,
|
||||
|
||||
@@ -363,7 +363,7 @@ class RemoteRuntime(ActionExecutionClient):
|
||||
self._session_api_key = start_response['session_api_key']
|
||||
self.log(
|
||||
'debug',
|
||||
'Session API key set',
|
||||
f'Session API key setted',
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -59,7 +59,7 @@ class MCPProxyManager:
|
||||
"""
|
||||
if len(self.config['mcpServers']) == 0:
|
||||
logger.info(
|
||||
'No MCP servers configured for FastMCP Proxy, skipping initialization.'
|
||||
f"No MCP servers configured for FastMCP Proxy, skipping initialization."
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -70,7 +70,7 @@ class MCPProxyManager:
|
||||
api_key=self.api_key,
|
||||
)
|
||||
|
||||
logger.info('FastMCP Proxy initialized successfully')
|
||||
logger.info(f"FastMCP Proxy initialized successfully")
|
||||
|
||||
async def mount_to_app(
|
||||
self, app: FastAPI, allow_origins: Optional[list[str]] = None
|
||||
@@ -83,7 +83,9 @@ class MCPProxyManager:
|
||||
allow_origins: List of allowed origins for CORS
|
||||
"""
|
||||
if len(self.config['mcpServers']) == 0:
|
||||
logger.info('No MCP servers configured for FastMCP Proxy, skipping mount.')
|
||||
logger.info(
|
||||
f"No MCP servers configured for FastMCP Proxy, skipping mount."
|
||||
)
|
||||
return
|
||||
|
||||
if not self.proxy:
|
||||
@@ -99,7 +101,8 @@ class MCPProxyManager:
|
||||
app.routes.remove('/mcp')
|
||||
|
||||
app.mount('/', mcp_app)
|
||||
logger.info('Mounted FastMCP Proxy app at /mcp')
|
||||
logger.info(f"Mounted FastMCP Proxy app at /mcp")
|
||||
|
||||
|
||||
async def update_and_remount(
|
||||
self,
|
||||
@@ -119,7 +122,10 @@ class MCPProxyManager:
|
||||
tools: List of tool configurations
|
||||
allow_origins: List of allowed origins for CORS
|
||||
"""
|
||||
tools = {t.name: t.model_dump() for t in stdio_servers}
|
||||
tools = {
|
||||
t.name: t.model_dump()
|
||||
for t in stdio_servers
|
||||
}
|
||||
self.config['mcpServers'] = tools
|
||||
|
||||
del self.proxy
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
# A2A Client
|
||||
This is an implementation of an A2A Client, called by agents within runtime container.
|
||||
|
||||
This directory contains code from [A2A](https://github.com/google/A2A), originally licensed under the Apache License 2.0.
|
||||
The original source has been modified to fit the needs of this project.
|
||||
See third_party_license/LICENSE for the full license text.
|
||||
|
||||
## Modifications
|
||||
|
||||
- Removed unused components (e.g. PushNotfication) from original code.
|
||||
- Implemented 'send_task_a2a' with customed I/O to make it more convenient for AI Agent
|
||||
@@ -0,0 +1,9 @@
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client import a2a_client
|
||||
from openhands.runtime.plugins.agent_skills.utils.dependency import import_functions
|
||||
|
||||
import_functions(
|
||||
module=a2a_client,
|
||||
function_names=a2a_client.__all__,
|
||||
target_globals=globals(),
|
||||
)
|
||||
__all__ = a2a_client.__all__
|
||||
@@ -0,0 +1,80 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client.common.client import (
|
||||
A2ACardResolver,
|
||||
A2AClient,
|
||||
)
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client.common.types import (
|
||||
TaskState,
|
||||
)
|
||||
|
||||
|
||||
async def send_task_A2A(url, message, session_id=0, task_id=0):
|
||||
"""
|
||||
Send a task to an agent hosted on remote server, compatible with A2A protocol.
|
||||
"""
|
||||
## Get the agent card
|
||||
card_resolver = A2ACardResolver(url)
|
||||
card = card_resolver.get_agent_card()
|
||||
|
||||
print('======= Agent Card ========')
|
||||
print(card.model_dump_json(exclude_none=True))
|
||||
|
||||
client = A2AClient(agent_card=card)
|
||||
|
||||
if session_id == 0:
|
||||
session_id = uuid4().hex
|
||||
if task_id == 0:
|
||||
task_id = uuid4().hex
|
||||
|
||||
streaming = card.capabilities.streaming
|
||||
print('======= Session ID and Task ID ========')
|
||||
print(f'Session ID: {session_id}')
|
||||
print(f'Task ID: {task_id}')
|
||||
print('If you want to send more input, use the same session ID and task ID.')
|
||||
|
||||
print('========= starting a task ======== ')
|
||||
await completeTask(client, message, streaming, task_id, session_id)
|
||||
|
||||
|
||||
async def completeTask(client: A2AClient, message, streaming, task_id, session_id):
|
||||
prompt = message
|
||||
|
||||
message = {
|
||||
'role': 'user',
|
||||
'parts': [
|
||||
{
|
||||
'type': 'text',
|
||||
'text': prompt,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
payload = {
|
||||
'id': task_id,
|
||||
'sessionId': session_id,
|
||||
'acceptedOutputModes': ['text'],
|
||||
'message': message,
|
||||
}
|
||||
|
||||
taskResult = None
|
||||
if streaming:
|
||||
response_stream = client.send_task_streaming(payload)
|
||||
async for result in response_stream:
|
||||
print(f'stream event => {result.model_dump_json(exclude_none=True)}')
|
||||
taskResult = await client.get_task({'id': task_id})
|
||||
else:
|
||||
taskResult = await client.send_task(payload)
|
||||
print(f'\n{taskResult.model_dump_json(exclude_none=True)}')
|
||||
|
||||
## if the result is that more input is required, tell the user and exit.
|
||||
if taskResult.result:
|
||||
state = TaskState(taskResult.result.status.state)
|
||||
if state.name == TaskState.INPUT_REQUIRED.name:
|
||||
print('Task requires more input. Use this tool again to provide it.')
|
||||
else:
|
||||
## task is complete
|
||||
return True
|
||||
|
||||
|
||||
__all__ = ['send_task_A2A']
|
||||
@@ -0,0 +1,4 @@
|
||||
from .client import A2AClient
|
||||
from .card_resolver import A2ACardResolver
|
||||
|
||||
__all__ = ["A2AClient", "A2ACardResolver"]
|
||||
@@ -0,0 +1,21 @@
|
||||
import httpx
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client.common.types import (
|
||||
AgentCard,
|
||||
A2AClientJSONError,
|
||||
)
|
||||
import json
|
||||
|
||||
|
||||
class A2ACardResolver:
|
||||
def __init__(self, base_url, agent_card_path="/.well-known/agent.json"):
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.agent_card_path = agent_card_path.lstrip("/")
|
||||
|
||||
def get_agent_card(self) -> AgentCard:
|
||||
with httpx.Client() as client:
|
||||
response = client.get(self.base_url + "/" + self.agent_card_path)
|
||||
response.raise_for_status()
|
||||
try:
|
||||
return AgentCard(**response.json())
|
||||
except json.JSONDecodeError as e:
|
||||
raise A2AClientJSONError(str(e)) from e
|
||||
@@ -0,0 +1,88 @@
|
||||
import json
|
||||
from typing import Any, AsyncIterable
|
||||
|
||||
import httpx
|
||||
from httpx_sse import connect_sse
|
||||
|
||||
from openhands.runtime.plugins.agent_skills.a2a_client.common.types import (
|
||||
A2AClientHTTPError,
|
||||
A2AClientJSONError,
|
||||
AgentCard,
|
||||
CancelTaskRequest,
|
||||
CancelTaskResponse,
|
||||
GetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationResponse,
|
||||
GetTaskRequest,
|
||||
GetTaskResponse,
|
||||
JSONRPCRequest,
|
||||
SendTaskRequest,
|
||||
SendTaskResponse,
|
||||
SendTaskStreamingRequest,
|
||||
SendTaskStreamingResponse,
|
||||
SetTaskPushNotificationRequest,
|
||||
SetTaskPushNotificationResponse,
|
||||
)
|
||||
|
||||
|
||||
class A2AClient:
|
||||
def __init__(self, agent_card: AgentCard | None = None, url: str | None = None):
|
||||
if agent_card:
|
||||
self.url = agent_card.url
|
||||
elif url:
|
||||
self.url = url
|
||||
else:
|
||||
raise ValueError('Must provide either agent_card or url')
|
||||
|
||||
async def send_task(self, payload: dict[str, Any]) -> SendTaskResponse:
|
||||
request = SendTaskRequest(params=payload)
|
||||
return SendTaskResponse(**await self._send_request(request))
|
||||
|
||||
async def send_task_streaming(
|
||||
self, payload: dict[str, Any]
|
||||
) -> AsyncIterable[SendTaskStreamingResponse]:
|
||||
request = SendTaskStreamingRequest(params=payload)
|
||||
with httpx.Client(timeout=None) as client:
|
||||
with connect_sse(
|
||||
client, 'POST', self.url, json=request.model_dump()
|
||||
) as event_source:
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
yield SendTaskStreamingResponse(**json.loads(sse.data))
|
||||
except json.JSONDecodeError as e:
|
||||
raise A2AClientJSONError(str(e)) from e
|
||||
except httpx.RequestError as e:
|
||||
raise A2AClientHTTPError(400, str(e)) from e
|
||||
|
||||
async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
# Image generation could take time, adding timeout
|
||||
response = await client.post(
|
||||
self.url, json=request.model_dump(), timeout=30
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except httpx.HTTPStatusError as e:
|
||||
raise A2AClientHTTPError(e.response.status_code, str(e)) from e
|
||||
except json.JSONDecodeError as e:
|
||||
raise A2AClientJSONError(str(e)) from e
|
||||
|
||||
async def get_task(self, payload: dict[str, Any]) -> GetTaskResponse:
|
||||
request = GetTaskRequest(params=payload)
|
||||
return GetTaskResponse(**await self._send_request(request))
|
||||
|
||||
async def cancel_task(self, payload: dict[str, Any]) -> CancelTaskResponse:
|
||||
request = CancelTaskRequest(params=payload)
|
||||
return CancelTaskResponse(**await self._send_request(request))
|
||||
|
||||
async def set_task_callback(
|
||||
self, payload: dict[str, Any]
|
||||
) -> SetTaskPushNotificationResponse:
|
||||
request = SetTaskPushNotificationRequest(params=payload)
|
||||
return SetTaskPushNotificationResponse(**await self._send_request(request))
|
||||
|
||||
async def get_task_callback(
|
||||
self, payload: dict[str, Any]
|
||||
) -> GetTaskPushNotificationResponse:
|
||||
request = GetTaskPushNotificationRequest(params=payload)
|
||||
return GetTaskPushNotificationResponse(**await self._send_request(request))
|
||||
@@ -0,0 +1,365 @@
|
||||
from typing import Union, Any
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from typing import Literal, List, Annotated, Optional
|
||||
from datetime import datetime
|
||||
from pydantic import model_validator, ConfigDict, field_serializer
|
||||
from uuid import uuid4
|
||||
from enum import Enum
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
class TaskState(str, Enum):
|
||||
SUBMITTED = "submitted"
|
||||
WORKING = "working"
|
||||
INPUT_REQUIRED = "input-required"
|
||||
COMPLETED = "completed"
|
||||
CANCELED = "canceled"
|
||||
FAILED = "failed"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class TextPart(BaseModel):
|
||||
type: Literal["text"] = "text"
|
||||
text: str
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class FileContent(BaseModel):
|
||||
name: str | None = None
|
||||
mimeType: str | None = None
|
||||
bytes: str | None = None
|
||||
uri: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_content(self) -> Self:
|
||||
if not (self.bytes or self.uri):
|
||||
raise ValueError("Either 'bytes' or 'uri' must be present in the file data")
|
||||
if self.bytes and self.uri:
|
||||
raise ValueError(
|
||||
"Only one of 'bytes' or 'uri' can be present in the file data"
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class FilePart(BaseModel):
|
||||
type: Literal["file"] = "file"
|
||||
file: FileContent
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class DataPart(BaseModel):
|
||||
type: Literal["data"] = "data"
|
||||
data: dict[str, Any]
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
Part = Annotated[Union[TextPart, FilePart, DataPart], Field(discriminator="type")]
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: Literal["user", "agent"]
|
||||
parts: List[Part]
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskStatus(BaseModel):
|
||||
state: TaskState
|
||||
message: Message | None = None
|
||||
timestamp: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
@field_serializer("timestamp")
|
||||
def serialize_dt(self, dt: datetime, _info):
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
class Artifact(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
parts: List[Part]
|
||||
metadata: dict[str, Any] | None = None
|
||||
index: int = 0
|
||||
append: bool | None = None
|
||||
lastChunk: bool | None = None
|
||||
|
||||
|
||||
class Task(BaseModel):
|
||||
id: str
|
||||
sessionId: str | None = None
|
||||
status: TaskStatus
|
||||
artifacts: List[Artifact] | None = None
|
||||
history: List[Message] | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskStatusUpdateEvent(BaseModel):
|
||||
id: str
|
||||
status: TaskStatus
|
||||
final: bool = False
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskArtifactUpdateEvent(BaseModel):
|
||||
id: str
|
||||
artifact: Artifact
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class AuthenticationInfo(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
schemes: List[str]
|
||||
credentials: str | None = None
|
||||
|
||||
|
||||
class PushNotificationConfig(BaseModel):
|
||||
url: str
|
||||
token: str | None = None
|
||||
authentication: AuthenticationInfo | None = None
|
||||
|
||||
|
||||
class TaskIdParams(BaseModel):
|
||||
id: str
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskQueryParams(TaskIdParams):
|
||||
historyLength: int | None = None
|
||||
|
||||
|
||||
class TaskSendParams(BaseModel):
|
||||
id: str
|
||||
sessionId: str = Field(default_factory=lambda: uuid4().hex)
|
||||
message: Message
|
||||
acceptedOutputModes: Optional[List[str]] = None
|
||||
pushNotification: PushNotificationConfig | None = None
|
||||
historyLength: int | None = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TaskPushNotificationConfig(BaseModel):
|
||||
id: str
|
||||
pushNotificationConfig: PushNotificationConfig
|
||||
|
||||
|
||||
## RPC Messages
|
||||
|
||||
|
||||
class JSONRPCMessage(BaseModel):
|
||||
jsonrpc: Literal["2.0"] = "2.0"
|
||||
id: int | str | None = Field(default_factory=lambda: uuid4().hex)
|
||||
|
||||
|
||||
class JSONRPCRequest(JSONRPCMessage):
|
||||
method: str
|
||||
params: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class JSONRPCError(BaseModel):
|
||||
code: int
|
||||
message: str
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class JSONRPCResponse(JSONRPCMessage):
|
||||
result: Any | None = None
|
||||
error: JSONRPCError | None = None
|
||||
|
||||
|
||||
class SendTaskRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/send"] = "tasks/send"
|
||||
params: TaskSendParams
|
||||
|
||||
|
||||
class SendTaskResponse(JSONRPCResponse):
|
||||
result: Task | None = None
|
||||
|
||||
|
||||
class SendTaskStreamingRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/sendSubscribe"] = "tasks/sendSubscribe"
|
||||
params: TaskSendParams
|
||||
|
||||
|
||||
class SendTaskStreamingResponse(JSONRPCResponse):
|
||||
result: TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None = None
|
||||
|
||||
|
||||
class GetTaskRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/get"] = "tasks/get"
|
||||
params: TaskQueryParams
|
||||
|
||||
|
||||
class GetTaskResponse(JSONRPCResponse):
|
||||
result: Task | None = None
|
||||
|
||||
|
||||
class CancelTaskRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/cancel",] = "tasks/cancel"
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
class CancelTaskResponse(JSONRPCResponse):
|
||||
result: Task | None = None
|
||||
|
||||
|
||||
class SetTaskPushNotificationRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/pushNotification/set",] = "tasks/pushNotification/set"
|
||||
params: TaskPushNotificationConfig
|
||||
|
||||
|
||||
class SetTaskPushNotificationResponse(JSONRPCResponse):
|
||||
result: TaskPushNotificationConfig | None = None
|
||||
|
||||
|
||||
class GetTaskPushNotificationRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/pushNotification/get",] = "tasks/pushNotification/get"
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
class GetTaskPushNotificationResponse(JSONRPCResponse):
|
||||
result: TaskPushNotificationConfig | None = None
|
||||
|
||||
|
||||
class TaskResubscriptionRequest(JSONRPCRequest):
|
||||
method: Literal["tasks/resubscribe",] = "tasks/resubscribe"
|
||||
params: TaskIdParams
|
||||
|
||||
|
||||
A2ARequest = TypeAdapter(
|
||||
Annotated[
|
||||
Union[
|
||||
SendTaskRequest,
|
||||
GetTaskRequest,
|
||||
CancelTaskRequest,
|
||||
SetTaskPushNotificationRequest,
|
||||
GetTaskPushNotificationRequest,
|
||||
TaskResubscriptionRequest,
|
||||
SendTaskStreamingRequest,
|
||||
],
|
||||
Field(discriminator="method"),
|
||||
]
|
||||
)
|
||||
|
||||
## Error types
|
||||
|
||||
|
||||
class JSONParseError(JSONRPCError):
|
||||
code: int = -32700
|
||||
message: str = "Invalid JSON payload"
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class InvalidRequestError(JSONRPCError):
|
||||
code: int = -32600
|
||||
message: str = "Request payload validation error"
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class MethodNotFoundError(JSONRPCError):
|
||||
code: int = -32601
|
||||
message: str = "Method not found"
|
||||
data: None = None
|
||||
|
||||
|
||||
class InvalidParamsError(JSONRPCError):
|
||||
code: int = -32602
|
||||
message: str = "Invalid parameters"
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class InternalError(JSONRPCError):
|
||||
code: int = -32603
|
||||
message: str = "Internal error"
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class TaskNotFoundError(JSONRPCError):
|
||||
code: int = -32001
|
||||
message: str = "Task not found"
|
||||
data: None = None
|
||||
|
||||
|
||||
class TaskNotCancelableError(JSONRPCError):
|
||||
code: int = -32002
|
||||
message: str = "Task cannot be canceled"
|
||||
data: None = None
|
||||
|
||||
|
||||
class PushNotificationNotSupportedError(JSONRPCError):
|
||||
code: int = -32003
|
||||
message: str = "Push Notification is not supported"
|
||||
data: None = None
|
||||
|
||||
|
||||
class UnsupportedOperationError(JSONRPCError):
|
||||
code: int = -32004
|
||||
message: str = "This operation is not supported"
|
||||
data: None = None
|
||||
|
||||
|
||||
class ContentTypeNotSupportedError(JSONRPCError):
|
||||
code: int = -32005
|
||||
message: str = "Incompatible content types"
|
||||
data: None = None
|
||||
|
||||
|
||||
class AgentProvider(BaseModel):
|
||||
organization: str
|
||||
url: str | None = None
|
||||
|
||||
|
||||
class AgentCapabilities(BaseModel):
|
||||
streaming: bool = False
|
||||
pushNotifications: bool = False
|
||||
stateTransitionHistory: bool = False
|
||||
|
||||
|
||||
class AgentAuthentication(BaseModel):
|
||||
schemes: List[str]
|
||||
credentials: str | None = None
|
||||
|
||||
|
||||
class AgentSkill(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str | None = None
|
||||
tags: List[str] | None = None
|
||||
examples: List[str] | None = None
|
||||
inputModes: List[str] | None = None
|
||||
outputModes: List[str] | None = None
|
||||
|
||||
|
||||
class AgentCard(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
url: str
|
||||
provider: AgentProvider | None = None
|
||||
version: str
|
||||
documentationUrl: str | None = None
|
||||
capabilities: AgentCapabilities
|
||||
authentication: AgentAuthentication | None = None
|
||||
defaultInputModes: List[str] = ["text"]
|
||||
defaultOutputModes: List[str] = ["text"]
|
||||
skills: List[AgentSkill]
|
||||
|
||||
|
||||
class A2AClientError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class A2AClientHTTPError(A2AClientError):
|
||||
def __init__(self, status_code: int, message: str):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
super().__init__(f"HTTP Error {status_code}: {message}")
|
||||
|
||||
|
||||
class A2AClientJSONError(A2AClientError):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(f"JSON Error: {message}")
|
||||
|
||||
|
||||
class MissingAPIKeyError(Exception):
|
||||
"""Exception for missing API key."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,202 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -1,6 +1,10 @@
|
||||
from inspect import signature
|
||||
|
||||
from openhands.runtime.plugins.agent_skills import file_ops, file_reader
|
||||
from openhands.runtime.plugins.agent_skills import (
|
||||
a2a_client,
|
||||
file_ops,
|
||||
file_reader,
|
||||
)
|
||||
from openhands.runtime.plugins.agent_skills.utils.dependency import import_functions
|
||||
|
||||
import_functions(
|
||||
@@ -9,8 +13,13 @@ import_functions(
|
||||
import_functions(
|
||||
module=file_reader, function_names=file_reader.__all__, target_globals=globals()
|
||||
)
|
||||
import_functions(
|
||||
module=a2a_client,
|
||||
function_names=a2a_client.__all__,
|
||||
target_globals=globals(),
|
||||
)
|
||||
|
||||
__all__ = file_ops.__all__ + file_reader.__all__
|
||||
__all__ = file_ops.__all__ + file_reader.__all__ + a2a_client.__all__
|
||||
|
||||
try:
|
||||
from openhands.runtime.plugins.agent_skills import repo_ops
|
||||
|
||||
@@ -27,7 +27,7 @@ from openhands.llm.metrics import Metrics
|
||||
from openhands.utils.chunk_localizer import Chunk, get_top_k_chunk_matches
|
||||
|
||||
USER_MSG = """
|
||||
Code changes will be provided in the form of a draft. You will need to apply the draft to the original code.
|
||||
Code changes will be provided in the form of a draft. You will need to apply the draft to the original code.
|
||||
The original code will be enclosed within `<original_code>` tags.
|
||||
The draft will be enclosed within `<update_snippet>` tags.
|
||||
You need to output the update code within `<updated_code>` tags.
|
||||
@@ -48,8 +48,8 @@ def _extract_code(string: str) -> str | None:
|
||||
|
||||
content = str(matches[0])
|
||||
if content.startswith('#EDIT:'):
|
||||
# Remove first line
|
||||
content = content[content.find('\n') + 1 :]
|
||||
#Remove first line
|
||||
content = content[content.find('\n') + 1:]
|
||||
return content
|
||||
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ RUN /openhands/micromamba/bin/micromamba run -n openhands poetry install --only
|
||||
|
||||
# Install playwright and its dependencies
|
||||
RUN apt-get update && \
|
||||
/openhands/micromamba/bin/micromamba run -n openhands poetry run pip install playwright && \
|
||||
/openhands/micromamba/bin/micromamba run -n openhands poetry run pip install playwright httpx httpx-sse pydantic && \
|
||||
/openhands/micromamba/bin/micromamba run -n openhands poetry run playwright install --with-deps chromium
|
||||
|
||||
# Set environment variables and permissions
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import random
|
||||
import socket
|
||||
import time
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
def check_port_available(port: int) -> bool:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
|
||||
@@ -53,7 +53,7 @@ def load_server_config() -> ServerConfig:
|
||||
logger.info(f'Using config class {config_cls}')
|
||||
|
||||
server_config_cls = get_impl(ServerConfig, config_cls)
|
||||
server_config: ServerConfig = server_config_cls()
|
||||
server_config : ServerConfig = server_config_cls()
|
||||
server_config.verify_config()
|
||||
|
||||
return server_config
|
||||
|
||||
@@ -15,7 +15,7 @@ from openhands.events.stream import EventStreamSubscriber, session_exists
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
from openhands.server.monitoring import MonitoringListener
|
||||
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE, AgentSession
|
||||
from openhands.server.session.agent_session import AgentSession, WAIT_TIME_BEFORE_CLOSE
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
from openhands.server.session.session import ROOM_KEY, Session
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
@@ -508,9 +508,7 @@ class StandaloneConversationManager(ConversationManager):
|
||||
session_api_key=None,
|
||||
event_store=session.agent_session.event_stream,
|
||||
status=_get_status_from_session(session),
|
||||
runtime_status=getattr(
|
||||
session.agent_session.runtime, 'runtime_status', None
|
||||
),
|
||||
runtime_status=getattr(session.agent_session.runtime, 'runtime_status', None),
|
||||
)
|
||||
|
||||
def _get_conversation_url(self, conversation_id: str):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.runtime.runtime_status import RuntimeStatus
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
@@ -5,13 +5,13 @@ from pydantic import BaseModel
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.memory.memory import Memory
|
||||
from openhands.microagent.types import InputMetadata
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
from openhands.server.shared import conversation_manager
|
||||
from openhands.server.utils import get_conversation
|
||||
from openhands.microagent.types import InputMetadata
|
||||
from openhands.memory.memory import Memory
|
||||
|
||||
app = APIRouter(
|
||||
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
|
||||
@@ -216,11 +216,7 @@ async def get_microagents(
|
||||
content=agent.content,
|
||||
triggers=[],
|
||||
inputs=agent.metadata.inputs,
|
||||
tools=[
|
||||
server.name for server in agent.metadata.mcp_tools.stdio_servers
|
||||
]
|
||||
if agent.metadata.mcp_tools
|
||||
else [],
|
||||
tools=[server.name for server in agent.metadata.mcp_tools.stdio_servers] if agent.metadata.mcp_tools else [],
|
||||
)
|
||||
)
|
||||
|
||||
@@ -233,11 +229,7 @@ async def get_microagents(
|
||||
content=agent.content,
|
||||
triggers=agent.triggers,
|
||||
inputs=agent.metadata.inputs,
|
||||
tools=[
|
||||
server.name for server in agent.metadata.mcp_tools.stdio_servers
|
||||
]
|
||||
if agent.metadata.mcp_tools
|
||||
else [],
|
||||
tools=[server.name for server in agent.metadata.mcp_tools.stdio_servers] if agent.metadata.mcp_tools else [],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -6,19 +6,15 @@ from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
from openhands.server.utils import get_conversation
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
|
||||
app = APIRouter(
|
||||
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
|
||||
)
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
|
||||
|
||||
|
||||
@app.post('/submit-feedback')
|
||||
async def submit_feedback(
|
||||
request: Request, conversation: ServerConversation = Depends(get_conversation)
|
||||
) -> JSONResponse:
|
||||
async def submit_feedback(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
|
||||
"""Submit user feedback.
|
||||
|
||||
This function stores the provided feedback data.
|
||||
@@ -41,7 +37,9 @@ async def submit_feedback(
|
||||
# Assuming the storage service is already configured in the backend
|
||||
# and there is a function to handle the storage.
|
||||
body = await request.json()
|
||||
async_store = AsyncEventStoreWrapper(conversation.event_stream, filter_hidden=True)
|
||||
async_store = AsyncEventStoreWrapper(
|
||||
conversation.event_stream, filter_hidden=True
|
||||
)
|
||||
trajectory = []
|
||||
async for event in async_store:
|
||||
trajectory.append(event_to_dict(event))
|
||||
|
||||
@@ -5,6 +5,7 @@ from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
HTTPException,
|
||||
Request,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
@@ -26,15 +27,17 @@ from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.file_config import (
|
||||
FILES_TO_IGNORE,
|
||||
)
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
from openhands.server.shared import (
|
||||
ConversationStoreImpl,
|
||||
config,
|
||||
)
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.server.utils import get_conversation, get_conversation_store
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
|
||||
app = APIRouter(
|
||||
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
|
||||
)
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
|
||||
|
||||
|
||||
@app.get(
|
||||
@@ -47,7 +50,7 @@ app = APIRouter(
|
||||
)
|
||||
async def list_files(
|
||||
conversation: ServerConversation = Depends(get_conversation),
|
||||
path: str | None = None,
|
||||
path: str | None = None
|
||||
) -> list[str] | JSONResponse:
|
||||
"""List files in the specified path.
|
||||
|
||||
@@ -129,9 +132,7 @@ async def list_files(
|
||||
415: {'description': 'Unsupported media type', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def select_file(
|
||||
file: str, conversation: ServerConversation = Depends(get_conversation)
|
||||
) -> FileResponse | JSONResponse:
|
||||
async def select_file(file: str, conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse:
|
||||
"""Retrieve the content of a specified file.
|
||||
|
||||
To select a file:
|
||||
@@ -195,9 +196,7 @@ async def select_file(
|
||||
500: {'description': 'Error zipping workspace', 'model': dict},
|
||||
},
|
||||
)
|
||||
def zip_current_workspace(
|
||||
conversation: ServerConversation = Depends(get_conversation),
|
||||
) -> FileResponse | JSONResponse:
|
||||
def zip_current_workspace(conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse:
|
||||
try:
|
||||
logger.debug('Zipping workspace')
|
||||
runtime: Runtime = conversation.runtime
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import itertools
|
||||
import os
|
||||
import re
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -9,18 +9,19 @@ from fastapi.responses import JSONResponse
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.events.action import (
|
||||
ChangeAgentStateAction,
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
NullObservation,
|
||||
AgentStateChangedObservation,
|
||||
)
|
||||
from openhands.events.stream import EventStream
|
||||
|
||||
from openhands.core.config.llm_config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderHandler,
|
||||
@@ -37,9 +38,10 @@ from openhands.server.data_models.conversation_info import ConversationInfo
|
||||
from openhands.server.data_models.conversation_info_result_set import (
|
||||
ConversationInfoResultSet,
|
||||
)
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.services.conversation_service import create_new_conversation
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.services.conversation_service import create_new_conversation
|
||||
from openhands.server.shared import (
|
||||
ConversationStoreImpl,
|
||||
config,
|
||||
@@ -51,12 +53,11 @@ from openhands.server.user_auth import (
|
||||
get_provider_tokens,
|
||||
get_user_id,
|
||||
get_user_secrets,
|
||||
get_user_settings,
|
||||
get_user_settings_store,
|
||||
get_user_settings,
|
||||
)
|
||||
from openhands.server.user_auth.user_auth import AuthType
|
||||
from openhands.server.utils import get_conversation as get_conversation_object
|
||||
from openhands.server.utils import get_conversation_store
|
||||
from openhands.server.utils import get_conversation_store, get_conversation as get_conversation_object
|
||||
from openhands.storage.conversation.conversation_store import ConversationStore
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
@@ -294,7 +295,7 @@ async def delete_conversation(
|
||||
async def get_prompt(
|
||||
event_id: int,
|
||||
user_settings: SettingsStore = Depends(get_user_settings_store),
|
||||
conversation: ServerConversation | None = Depends(get_conversation_object),
|
||||
conversation: ServerConversation | None = Depends(get_conversation_object)
|
||||
):
|
||||
if conversation is None:
|
||||
return JSONResponse(
|
||||
@@ -408,6 +409,7 @@ async def start_conversation(
|
||||
logger.info(f'Starting conversation: {conversation_id}')
|
||||
|
||||
try:
|
||||
|
||||
# Check that the conversation exists
|
||||
try:
|
||||
await conversation_store.get_metadata(conversation_id)
|
||||
@@ -461,17 +463,10 @@ async def stop_conversation(
|
||||
|
||||
try:
|
||||
# Check if the conversation is running
|
||||
agent_loop_info = await conversation_manager.get_agent_loop_info(
|
||||
user_id=user_id, filter_to_sids={conversation_id}
|
||||
)
|
||||
conversation_status = (
|
||||
agent_loop_info[0].status if agent_loop_info else ConversationStatus.STOPPED
|
||||
)
|
||||
agent_loop_info = await conversation_manager.get_agent_loop_info(user_id=user_id, filter_to_sids={conversation_id})
|
||||
conversation_status = agent_loop_info[0].status if agent_loop_info else ConversationStatus.STOPPED
|
||||
|
||||
if conversation_status not in (
|
||||
ConversationStatus.STARTING,
|
||||
ConversationStatus.RUNNING,
|
||||
):
|
||||
if conversation_status not in (ConversationStatus.STARTING, ConversationStatus.RUNNING):
|
||||
return ConversationResponse(
|
||||
status='ok',
|
||||
conversation_id=conversation_id,
|
||||
@@ -510,13 +505,9 @@ def _get_contextual_events(event_stream: EventStream, event_id: int) -> str:
|
||||
|
||||
agent_event_filter = EventFilter(
|
||||
exclude_hidden=True,
|
||||
exclude_types=(
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
exclude_types=(NullAction, NullObservation, ChangeAgentStateAction, AgentStateChangedObservation
|
||||
),
|
||||
) # the types of events that can be in an agent's history
|
||||
) # the types of events that can be in an agent's history
|
||||
|
||||
# from event_id - context_size to event_id..
|
||||
context_before = event_stream.search_events(
|
||||
|
||||
@@ -27,7 +27,7 @@ mcp_server = FastMCP(
|
||||
)
|
||||
|
||||
HOST = f'https://{os.getenv("WEB_HOST", "app.all-hands.dev").strip()}'
|
||||
CONVO_URL = HOST + '/conversations/{}'
|
||||
CONVO_URL = HOST + '/{}'
|
||||
|
||||
|
||||
async def get_convo_link(service: GitService, conversation_id: str, body: str) -> str:
|
||||
@@ -87,7 +87,7 @@ async def create_pr(
|
||||
target_branch: Annotated[str, Field(description='Target branch on repo')],
|
||||
title: Annotated[str, Field(description='PR Title')],
|
||||
body: Annotated[str | None, Field(description='PR body')],
|
||||
draft: Annotated[bool, Field(description='Whether PR opened is a draft')] = True,
|
||||
draft: Annotated[bool, Field(description='Whether PR opened is a draft')] = True
|
||||
) -> str:
|
||||
"""Open a PR in GitHub"""
|
||||
|
||||
@@ -127,7 +127,7 @@ async def create_pr(
|
||||
target_branch=target_branch,
|
||||
title=title,
|
||||
body=body,
|
||||
draft=draft,
|
||||
draft=draft
|
||||
)
|
||||
|
||||
if conversation_id:
|
||||
@@ -148,12 +148,7 @@ async def create_mr(
|
||||
],
|
||||
source_branch: Annotated[str, Field(description='Source branch on repo')],
|
||||
target_branch: Annotated[str, Field(description='Target branch on repo')],
|
||||
title: Annotated[
|
||||
str,
|
||||
Field(
|
||||
description='MR Title. Start title with `DRAFT:` or `WIP:` if applicable.'
|
||||
),
|
||||
],
|
||||
title: Annotated[str, Field(description='MR Title. Start title with `DRAFT:` or `WIP:` if applicable.')],
|
||||
description: Annotated[str | None, Field(description='MR description')],
|
||||
) -> str:
|
||||
"""Open a MR in GitLab"""
|
||||
|
||||
@@ -8,18 +8,14 @@ from fastapi import (
|
||||
)
|
||||
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
from openhands.server.utils import get_conversation
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
|
||||
app = APIRouter(
|
||||
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
|
||||
)
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
|
||||
|
||||
|
||||
@app.route('/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
|
||||
async def security_api(
|
||||
request: Request, conversation: ServerConversation = Depends(get_conversation)
|
||||
) -> Response:
|
||||
async def security_api(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> Response:
|
||||
"""Catch-all route for security analyzer API requests.
|
||||
|
||||
Each request is handled directly to the security analyzer.
|
||||
@@ -39,4 +35,6 @@ async def security_api(
|
||||
detail='Security analyzer not initialized',
|
||||
)
|
||||
|
||||
return await conversation.security_analyzer.handle_api_request(request)
|
||||
return await conversation.security_analyzer.handle_api_request(
|
||||
request
|
||||
)
|
||||
|
||||
@@ -1,22 +1,18 @@
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from fastapi import APIRouter, Depends, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
|
||||
from openhands.events.serialization import event_to_trajectory
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
from openhands.server.utils import get_conversation
|
||||
from openhands.server.session.conversation import ServerConversation
|
||||
|
||||
app = APIRouter(
|
||||
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
|
||||
)
|
||||
app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
|
||||
|
||||
|
||||
@app.get('/trajectory')
|
||||
async def get_trajectory(
|
||||
conversation: ServerConversation = Depends(get_conversation),
|
||||
) -> JSONResponse:
|
||||
async def get_trajectory(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
|
||||
"""Get trajectory.
|
||||
|
||||
This function retrieves the current trajectory and returns it.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
@@ -79,6 +80,7 @@ async def create_new_conversation(
|
||||
session_init_args['conversation_instructions'] = conversation_instructions
|
||||
conversation_init_data = ConversationInitData(**session_init_args)
|
||||
|
||||
|
||||
logger.info('Loading conversation store')
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
logger.info('ServerConversation store loaded')
|
||||
@@ -88,14 +90,13 @@ async def create_new_conversation(
|
||||
conversation_id = uuid.uuid4().hex
|
||||
|
||||
if not await conversation_store.exists(conversation_id):
|
||||
|
||||
logger.info(
|
||||
f'New conversation ID: {conversation_id}',
|
||||
extra={'user_id': user_id, 'session_id': conversation_id},
|
||||
)
|
||||
|
||||
conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(
|
||||
user_id, conversation_id, conversation_init_data
|
||||
)
|
||||
conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(user_id, conversation_id, conversation_init_data)
|
||||
conversation_title = get_default_conversation_title(conversation_id)
|
||||
|
||||
logger.info(f'Saving metadata for conversation {conversation_id}')
|
||||
|
||||
@@ -197,21 +197,23 @@ class AgentSession:
|
||||
finally:
|
||||
self._starting = False
|
||||
success = finished and runtime_connected
|
||||
duration = time.time() - started_at
|
||||
duration = (time.time() - started_at)
|
||||
|
||||
log_metadata = {
|
||||
'signal': 'agent_session_start',
|
||||
'success': success,
|
||||
'duration': duration,
|
||||
'restored_state': restored_state,
|
||||
'restored_state': restored_state
|
||||
}
|
||||
if success:
|
||||
self.logger.info(
|
||||
f'Agent session start succeeded in {duration}s', extra=log_metadata
|
||||
f'Agent session start succeeded in {duration}s',
|
||||
extra=log_metadata
|
||||
)
|
||||
else:
|
||||
self.logger.error(
|
||||
f'Agent session start failed in {duration}s', extra=log_metadata
|
||||
f'Agent session start failed in {duration}s',
|
||||
extra=log_metadata
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
|
||||
@@ -105,12 +105,7 @@ class FileConversationStore(ConversationStore):
|
||||
async def get_instance(
|
||||
cls, config: OpenHandsConfig, user_id: str | None
|
||||
) -> FileConversationStore:
|
||||
file_store = get_file_store(
|
||||
config.file_store,
|
||||
config.file_store_path,
|
||||
config.file_store_web_hook_url,
|
||||
config.file_store_web_hook_headers,
|
||||
)
|
||||
file_store = get_file_store(config.file_store, config.file_store_path, config.file_store_web_hook_url, config.file_store_web_hook_headers)
|
||||
return FileConversationStore(file_store)
|
||||
|
||||
|
||||
|
||||
@@ -43,6 +43,6 @@ class FileSecretsStore(SecretsStore):
|
||||
config.file_store,
|
||||
config.file_store_path,
|
||||
config.file_store_web_hook_url,
|
||||
config.file_store_web_hook_headers,
|
||||
config.file_store_web_hook_headers
|
||||
)
|
||||
return FileSecretsStore(file_store)
|
||||
|
||||
@@ -37,6 +37,6 @@ class FileSettingsStore(SettingsStore):
|
||||
config.file_store,
|
||||
config.file_store_path,
|
||||
config.file_store_web_hook_url,
|
||||
config.file_store_web_hook_headers,
|
||||
config.file_store_web_hook_headers
|
||||
)
|
||||
return FileSettingsStore(file_store)
|
||||
|
||||
Generated
+11
-12
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
@@ -400,7 +400,7 @@ description = "LTS Port of Python audioop"
|
||||
optional = false
|
||||
python-versions = ">=3.13"
|
||||
groups = ["main"]
|
||||
markers = "python_version >= \"3.13\""
|
||||
markers = "python_version == \"3.13\""
|
||||
files = [
|
||||
{file = "audioop_lts-0.2.1-cp313-abi3-macosx_10_13_universal2.whl", hash = "sha256:fd1345ae99e17e6910f47ce7d52673c6a1a70820d78b67de1b7abb3af29c426a"},
|
||||
{file = "audioop_lts-0.2.1-cp313-abi3-macosx_10_13_x86_64.whl", hash = "sha256:e175350da05d2087e12cea8e72a70a1a8b14a17e92ed2022952a4419689ede5e"},
|
||||
@@ -1580,7 +1580,7 @@ files = [
|
||||
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
markers = {main = "platform_system == \"Windows\" or sys_platform == \"win32\" or os_name == \"nt\"", dev = "os_name == \"nt\" or sys_platform == \"win32\"", runtime = "sys_platform == \"win32\"", test = "platform_system == \"Windows\" or sys_platform == \"win32\""}
|
||||
markers = {main = "platform_system == \"Windows\" or os_name == \"nt\" or sys_platform == \"win32\"", dev = "os_name == \"nt\" or sys_platform == \"win32\"", runtime = "sys_platform == \"win32\"", test = "platform_system == \"Windows\" or sys_platform == \"win32\""}
|
||||
|
||||
[[package]]
|
||||
name = "comm"
|
||||
@@ -2974,8 +2974,8 @@ files = [
|
||||
google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]}
|
||||
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev"
|
||||
proto-plus = [
|
||||
{version = ">=1.22.3,<2.0.0dev"},
|
||||
{version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0dev"},
|
||||
]
|
||||
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev"
|
||||
|
||||
@@ -2997,8 +2997,8 @@ googleapis-common-protos = ">=1.56.2,<2.0.0"
|
||||
grpcio = {version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
|
||||
grpcio-status = {version = ">=1.49.1,<2.0.0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}
|
||||
proto-plus = [
|
||||
{version = ">=1.22.3,<2.0.0"},
|
||||
{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0"},
|
||||
]
|
||||
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
|
||||
requests = ">=2.18.0,<3.0.0"
|
||||
@@ -3216,8 +3216,8 @@ google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0", extras
|
||||
google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0"
|
||||
grpc-google-iam-v1 = ">=0.14.0,<1.0.0"
|
||||
proto-plus = [
|
||||
{version = ">=1.22.3,<2.0.0"},
|
||||
{version = ">=1.25.0,<2.0.0", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.22.3,<2.0.0"},
|
||||
]
|
||||
protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<7.0.0"
|
||||
|
||||
@@ -5422,7 +5422,7 @@ version = "0.61.0"
|
||||
description = "A module for monitoring memory usage of a python program"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
groups = ["runtime"]
|
||||
groups = ["main", "runtime"]
|
||||
files = [
|
||||
{file = "memory_profiler-0.61.0-py3-none-any.whl", hash = "sha256:400348e61031e3942ad4d4109d18753b2fb08c2f6fb8290671c5513a34182d84"},
|
||||
{file = "memory_profiler-0.61.0.tar.gz", hash = "sha256:4e5b73d7864a1d1292fb76a03e82a3e78ef934d06828a698d9dada76da2067b0"},
|
||||
@@ -6479,8 +6479,8 @@ files = [
|
||||
[package.dependencies]
|
||||
googleapis-common-protos = ">=1.52,<2.0"
|
||||
grpcio = [
|
||||
{version = ">=1.63.2,<2.0.0", markers = "python_version < \"3.13\""},
|
||||
{version = ">=1.66.2,<2.0.0", markers = "python_version >= \"3.13\""},
|
||||
{version = ">=1.63.2,<2.0.0", markers = "python_version < \"3.13\""},
|
||||
]
|
||||
opentelemetry-api = ">=1.15,<2.0"
|
||||
opentelemetry-exporter-otlp-proto-common = "1.34.1"
|
||||
@@ -9243,7 +9243,6 @@ files = [
|
||||
{file = "setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922"},
|
||||
{file = "setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c"},
|
||||
]
|
||||
markers = {evaluation = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
|
||||
|
||||
[package.extras]
|
||||
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
|
||||
@@ -9486,7 +9485,7 @@ description = "Standard library aifc redistribution. \"dead battery\"."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
markers = "python_version >= \"3.13\""
|
||||
markers = "python_version == \"3.13\""
|
||||
files = [
|
||||
{file = "standard_aifc-3.13.0-py3-none-any.whl", hash = "sha256:f7ae09cc57de1224a0dd8e3eb8f73830be7c3d0bc485de4c1f82b4a7f645ac66"},
|
||||
{file = "standard_aifc-3.13.0.tar.gz", hash = "sha256:64e249c7cb4b3daf2fdba4e95721f811bde8bdfc43ad9f936589b7bb2fae2e43"},
|
||||
@@ -9503,7 +9502,7 @@ description = "Standard library chunk redistribution. \"dead battery\"."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
markers = "python_version >= \"3.13\""
|
||||
markers = "python_version == \"3.13\""
|
||||
files = [
|
||||
{file = "standard_chunk-3.13.0-py3-none-any.whl", hash = "sha256:17880a26c285189c644bd5bd8f8ed2bdb795d216e3293e6dbe55bbd848e2982c"},
|
||||
{file = "standard_chunk-3.13.0.tar.gz", hash = "sha256:4ac345d37d7e686d2755e01836b8d98eda0d1a3ee90375e597ae43aaf064d654"},
|
||||
@@ -11665,4 +11664,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12,<3.14"
|
||||
content-hash = "47df4fc76b97147ff31169028edafaf35c1f4e661c7ab74bad48cb0ceea06aba"
|
||||
content-hash = "0b8da1a7da2d598f9ca4a8933245c99495f7a34bb26e1221eebd7ba2fa1d6ddc"
|
||||
|
||||
@@ -71,6 +71,11 @@ python-frontmatter = "^1.1.0"
|
||||
# TODO: Should these go into the runtime group?
|
||||
ipywidgets = "^8.1.5"
|
||||
qtconsole = "^5.6.1"
|
||||
memory-profiler = "^0.61.0"
|
||||
playwright = "^1.51.0"
|
||||
pydantic = "^2.11.3"
|
||||
httpx = "^0.28.1"
|
||||
httpx-sse = "^0.4.0"
|
||||
PyPDF2 = "*"
|
||||
python-pptx = "*"
|
||||
pylatexenc = "*"
|
||||
|
||||
@@ -10,36 +10,24 @@ class TestTranslationCompleteness(unittest.TestCase):
|
||||
|
||||
def test_translation_completeness_check_runs(self):
|
||||
"""Test that the translation completeness check script can be executed."""
|
||||
frontend_dir = os.path.join(
|
||||
os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
),
|
||||
'frontend',
|
||||
)
|
||||
script_path = os.path.join(
|
||||
frontend_dir, 'scripts', 'check-translation-completeness.cjs'
|
||||
)
|
||||
|
||||
frontend_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "frontend")
|
||||
script_path = os.path.join(frontend_dir, "scripts", "check-translation-completeness.cjs")
|
||||
|
||||
# Verify the script exists
|
||||
self.assertTrue(
|
||||
os.path.exists(script_path), f'Script not found at {script_path}'
|
||||
)
|
||||
|
||||
self.assertTrue(os.path.exists(script_path), f"Script not found at {script_path}")
|
||||
|
||||
# Verify the script is executable
|
||||
self.assertTrue(
|
||||
os.access(script_path, os.X_OK),
|
||||
f'Script at {script_path} is not executable',
|
||||
)
|
||||
|
||||
self.assertTrue(os.access(script_path, os.X_OK), f"Script at {script_path} is not executable")
|
||||
|
||||
# Run the script (it may fail due to missing translations, but we just want to verify it runs)
|
||||
try:
|
||||
subprocess.run(
|
||||
['node', script_path],
|
||||
cwd=frontend_dir,
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
["node", script_path],
|
||||
cwd=frontend_dir,
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
# We don't assert on the return code because it might fail due to missing translations
|
||||
except Exception as e:
|
||||
self.fail(f'Failed to run translation completeness check: {e}')
|
||||
self.fail(f"Failed to run translation completeness check: {e}")
|
||||
@@ -100,6 +100,7 @@ def mock_conversation_instructions_template():
|
||||
return 'Instructions: {{ repo_instruction }}'
|
||||
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_followup_prompt_template():
|
||||
return 'Issue context: {{ issues }}\n\nReview comments: {{ review_comments }}\n\nReview threads: {{ review_threads }}\n\nFiles: {{ files }}\n\nThread comments: {{ thread_context }}\n\nPlease fix this issue.'
|
||||
@@ -531,11 +532,7 @@ async def test_process_issue(
|
||||
handler_instance.guess_success.assert_not_called()
|
||||
|
||||
|
||||
def test_get_instruction(
|
||||
mock_user_instructions_template,
|
||||
mock_conversation_instructions_template,
|
||||
mock_followup_prompt_template,
|
||||
):
|
||||
def test_get_instruction(mock_user_instructions_template, mock_conversation_instructions_template, mock_followup_prompt_template):
|
||||
issue = Issue(
|
||||
owner='test_owner',
|
||||
repo='test_repo',
|
||||
@@ -548,10 +545,7 @@ def test_get_instruction(
|
||||
GithubIssueHandler('owner', 'repo', 'token'), mock_llm_config
|
||||
)
|
||||
instruction, conversation_instructions, images_urls = issue_handler.get_instruction(
|
||||
issue,
|
||||
mock_user_instructions_template,
|
||||
mock_conversation_instructions_template,
|
||||
None,
|
||||
issue, mock_user_instructions_template, mock_conversation_instructions_template, None
|
||||
)
|
||||
expected_instruction = 'Issue: Test Issue\n\nThis is a test issue refer to image \n\nPlease fix this issue.'
|
||||
|
||||
@@ -582,10 +576,7 @@ def test_get_instruction(
|
||||
GithubPRHandler('owner', 'repo', 'token'), mock_llm_config
|
||||
)
|
||||
instruction, conversation_instructions, images_urls = pr_handler.get_instruction(
|
||||
issue,
|
||||
mock_followup_prompt_template,
|
||||
mock_conversation_instructions_template,
|
||||
None,
|
||||
issue, mock_followup_prompt_template, mock_conversation_instructions_template, None
|
||||
)
|
||||
expected_instruction = "Issue context: [\n \"Issue 1 fix the type\"\n]\n\nReview comments: None\n\nReview threads: [\n \"There is still a typo 'pthon' instead of 'python'\"\n]\n\nFiles: []\n\nThread comments: I've left review comments, please address them\n---\nThis is a valid concern.\n\nPlease fix this issue."
|
||||
|
||||
@@ -610,9 +601,7 @@ def test_file_instruction():
|
||||
with open('openhands/resolver/prompts/resolve/basic.jinja', 'r') as f:
|
||||
prompt = f.read()
|
||||
|
||||
with open(
|
||||
'openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r'
|
||||
) as f:
|
||||
with open('openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r') as f:
|
||||
conversation_instructions_template = f.read()
|
||||
|
||||
# Test without thread comments
|
||||
@@ -621,7 +610,7 @@ def test_file_instruction():
|
||||
GithubIssueHandler('owner', 'repo', 'token'), mock_llm_config
|
||||
)
|
||||
instruction, conversation_instructions, images_urls = issue_handler.get_instruction(
|
||||
issue, prompt, conversation_instructions_template, None
|
||||
issue, prompt,conversation_instructions_template, None
|
||||
)
|
||||
expected_instruction = """Please fix the following issue for the repository in /workspace.
|
||||
An environment has been set up for you to start working. You may assume all necessary tools are installed.
|
||||
@@ -631,6 +620,7 @@ Test Issue
|
||||
|
||||
This is a test issue """
|
||||
|
||||
|
||||
expected_conversation_instructions = """IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.
|
||||
You SHOULD INCLUDE PROPER INDENTATION in your edit commands.
|
||||
|
||||
@@ -654,9 +644,7 @@ def test_file_instruction_with_repo_instruction():
|
||||
with open('openhands/resolver/prompts/resolve/basic.jinja', 'r') as f:
|
||||
prompt = f.read()
|
||||
|
||||
with open(
|
||||
'openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r'
|
||||
) as f:
|
||||
with open('openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r') as f:
|
||||
conversation_instructions_prompt = f.read()
|
||||
|
||||
# load repo instruction from openhands/resolver/prompts/repo_instructions/all-hands-ai___openhands-resolver.txt
|
||||
@@ -674,6 +662,7 @@ def test_file_instruction_with_repo_instruction():
|
||||
issue, prompt, conversation_instructions_prompt, repo_instruction
|
||||
)
|
||||
|
||||
|
||||
expected_instruction = """Please fix the following issue for the repository in /workspace.
|
||||
An environment has been set up for you to start working. You may assume all necessary tools are installed.
|
||||
|
||||
@@ -694,6 +683,7 @@ This is a Python repo for openhands-resolver, a library that attempts to resolve
|
||||
|
||||
When you think you have fixed the issue through code changes, please finish the interaction."""
|
||||
|
||||
|
||||
assert instruction == expected_instruction
|
||||
assert conversation_instructions == expected_conversation_instructions
|
||||
assert conversation_instructions is not None
|
||||
@@ -795,9 +785,7 @@ def test_instruction_with_thread_comments():
|
||||
with open('openhands/resolver/prompts/resolve/basic.jinja', 'r') as f:
|
||||
prompt = f.read()
|
||||
|
||||
with open(
|
||||
'openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r'
|
||||
) as f:
|
||||
with open('openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r') as f:
|
||||
conversation_instructions_template = f.read()
|
||||
|
||||
llm_config = LLMConfig(model='test', api_key='test')
|
||||
|
||||
@@ -483,7 +483,7 @@ def test_send_pull_request_with_reviewer(
|
||||
), # PR creation
|
||||
]
|
||||
|
||||
# Mock request reviewers response
|
||||
# Mock request reviwers response
|
||||
mock_put.side_effect = [
|
||||
MagicMock(status_code=200), # Reviewer request
|
||||
]
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from typing import Type
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
@@ -5,11 +8,11 @@ from openhands.core.config import LLMConfig
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler
|
||||
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler
|
||||
from openhands.resolver.issue_handler_factory import IssueHandlerFactory
|
||||
from openhands.resolver.interfaces.issue_definitions import (
|
||||
ServiceContextIssue,
|
||||
ServiceContextPR,
|
||||
)
|
||||
from openhands.resolver.issue_handler_factory import IssueHandlerFactory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -42,29 +45,33 @@ test_cases = [
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'platform,issue_type,expected_context_type,expected_handler_type', test_cases
|
||||
'platform,issue_type,expected_context_type,expected_handler_type',
|
||||
test_cases
|
||||
)
|
||||
def test_handler_creation(
|
||||
factory_params,
|
||||
platform: ProviderType,
|
||||
issue_type: str,
|
||||
expected_context_type: type,
|
||||
expected_handler_type: type,
|
||||
expected_context_type: Type,
|
||||
expected_handler_type: Type,
|
||||
):
|
||||
factory = IssueHandlerFactory(
|
||||
**factory_params, platform=platform, issue_type=issue_type
|
||||
**factory_params,
|
||||
platform=platform,
|
||||
issue_type=issue_type
|
||||
)
|
||||
|
||||
|
||||
handler = factory.create()
|
||||
|
||||
|
||||
assert isinstance(handler, expected_context_type)
|
||||
assert isinstance(handler._strategy, expected_handler_type)
|
||||
|
||||
|
||||
def test_invalid_issue_type(factory_params):
|
||||
factory = IssueHandlerFactory(
|
||||
**factory_params, platform=ProviderType.GITHUB, issue_type='invalid'
|
||||
**factory_params,
|
||||
platform=ProviderType.GITHUB,
|
||||
issue_type='invalid'
|
||||
)
|
||||
|
||||
|
||||
with pytest.raises(ValueError, match='Invalid issue type: invalid'):
|
||||
factory.create()
|
||||
factory.create()
|
||||
@@ -2,7 +2,7 @@ from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.core.config import OpenHandsConfig, SandboxConfig
|
||||
from openhands.core.config import SandboxConfig,OpenHandsConfig
|
||||
from openhands.events.action import CmdRunAction
|
||||
from openhands.resolver.issue_resolver import IssueResolver
|
||||
|
||||
@@ -36,8 +36,7 @@ def test_setup_sandbox_config_default():
|
||||
)
|
||||
|
||||
assert_sandbox_config(
|
||||
openhands_config.sandbox,
|
||||
runtime_container_image='ghcr.io/all-hands-ai/runtime:mock-nikolaik',
|
||||
openhands_config.sandbox, runtime_container_image='ghcr.io/all-hands-ai/runtime:mock-nikolaik'
|
||||
)
|
||||
|
||||
|
||||
@@ -69,9 +68,7 @@ def test_setup_sandbox_config_base_only():
|
||||
)
|
||||
|
||||
assert_sandbox_config(
|
||||
openhands_config.sandbox,
|
||||
base_container_image=base_image,
|
||||
runtime_container_image=None,
|
||||
openhands_config.sandbox, base_container_image=base_image, runtime_container_image=None
|
||||
)
|
||||
|
||||
|
||||
@@ -87,9 +84,7 @@ def test_setup_sandbox_config_runtime_only():
|
||||
is_experimental=False,
|
||||
)
|
||||
|
||||
assert_sandbox_config(
|
||||
openhands_config.sandbox, runtime_container_image=runtime_image
|
||||
)
|
||||
assert_sandbox_config(openhands_config.sandbox, runtime_container_image=runtime_image)
|
||||
|
||||
|
||||
def test_setup_sandbox_config_experimental():
|
||||
@@ -122,9 +117,7 @@ def test_setup_sandbox_config_gitlab_ci(mock_get_unique_uid, mock_getuid):
|
||||
is_experimental=False,
|
||||
)
|
||||
|
||||
assert_sandbox_config(
|
||||
openhands_config.sandbox, local_runtime_url='http://localhost'
|
||||
)
|
||||
assert_sandbox_config(openhands_config.sandbox, local_runtime_url='http://localhost')
|
||||
|
||||
|
||||
@mock.patch('openhands.resolver.issue_resolver.os.getuid', return_value=1000)
|
||||
@@ -141,9 +134,7 @@ def test_setup_sandbox_config_gitlab_ci_non_root(mock_getuid):
|
||||
is_experimental=False,
|
||||
)
|
||||
|
||||
assert_sandbox_config(
|
||||
openhands_config.sandbox, local_runtime_url='http://localhost'
|
||||
)
|
||||
assert_sandbox_config(openhands_config.sandbox, local_runtime_url='http://localhost')
|
||||
|
||||
|
||||
@mock.patch('openhands.events.observation.CmdOutputObservation')
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user