Mirror of https://github.com/All-Hands-AI/OpenHands.git (synced 2026-01-08 22:38:05 -05:00)
Lint all files in the repo (#9131)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
**.github/workflows/lint-fix.yml** (vendored, 2 lines changed)

```diff
@@ -74,7 +74,7 @@ jobs:
       - name: Fix python lint issues
         run: |
           # Run all pre-commit hooks and continue even if they modify files (exit code 1)
-          pre-commit run --config ./dev_config/python/.pre-commit-config.yaml --files openhands/**/* evaluation/**/* tests/**/* || true
+          pre-commit run --config ./dev_config/python/.pre-commit-config.yaml --all-files || true

       # Commit and push changes if any
       - name: Check for changes
```
**.github/workflows/lint.yml** (vendored, 2 lines changed)

```diff
@@ -53,7 +53,7 @@ jobs:
       - name: Install pre-commit
         run: pip install pre-commit==3.7.0
       - name: Run pre-commit hooks
-        run: pre-commit run --files openhands/**/* evaluation/**/* tests/**/* --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
+        run: pre-commit run --all-files --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml

   # Check version consistency across documentation
   check-version-consistency:
```
**.github/workflows/py-unit-tests.yml** (vendored, 1 line changed)

```diff
@@ -81,4 +81,3 @@ jobs:
         env:
           TEST_RUNTIME: local
           DEBUG: "1"
-
```
**Makefile** (2 lines changed)

```diff
@@ -189,7 +189,7 @@ install-pre-commit-hooks:

 lint-backend:
 	@echo "$(YELLOW)Running linters...$(RESET)"
-	@poetry run pre-commit run --files openhands/**/* evaluation/**/* tests/**/* --show-diff-on-failure --config $(PRE_COMMIT_CONFIG_PATH)
+	@poetry run pre-commit run --all-files --show-diff-on-failure --config $(PRE_COMMIT_CONFIG_PATH)

 lint-frontend:
 	@echo "$(YELLOW)Running linters for frontend...$(RESET)"
```
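Across the two workflows and the Makefile the change is the same: rather than passing shell globs that only cover `openhands/**/*`, `evaluation/**/*`, and `tests/**/*`, the `--all-files` flag lets pre-commit itself enumerate every file in the repository, which is what the commit title promises.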
````diff
@@ -4,7 +4,7 @@
 npm install -g mint
 ```

 or

 ```
 yarn global add mint
@@ -14,4 +14,4 @@ yarn global add mint

 ```
 mint dev
 ```
````
```diff
@@ -8,7 +8,7 @@ description: OpenHands uses LiteLLM to make calls to Google's chat models. You c
 When running OpenHands, you'll need to set the following in the OpenHands UI through the Settings under the `LLM` tab:
 - `LLM Provider` to `Gemini`
 - `LLM Model` to the model you will be using.
 If the model is not in the list, enable `Advanced` options, and enter it in `Custom Model`
 (e.g. gemini/<model-name> like `gemini/gemini-2.0-flash`).
 - `API Key` to your Gemini API key

@@ -26,5 +26,5 @@ VERTEXAI_LOCATION="<your-gcp-location>"
 Then set the following in the OpenHands UI through the Settings under the `LLM` tab:
 - `LLM Provider` to `VertexAI`
 - `LLM Model` to the model you will be using.
 If the model is not in the list, enable `Advanced` options, and enter it in `Custom Model`
 (e.g. vertex_ai/<model-name>).
```
```diff
@@ -8,7 +8,7 @@ description: OpenHands uses LiteLLM to make calls to chat models on Groq. You ca
 When running OpenHands, you'll need to set the following in the OpenHands UI through the Settings under the `LLM` tab:
 - `LLM Provider` to `Groq`
 - `LLM Model` to the model you will be using. [Visit here to see the list of
 models that Groq hosts](https://console.groq.com/docs/models). If the model is not in the list,
 enable `Advanced` options, and enter it in `Custom Model` (e.g. groq/<model-name> like `groq/llama3-70b-8192`).
 - `API key` to your Groq API key. To find or create your Groq API Key, [see here](https://console.groq.com/keys).

```
```diff
@@ -16,7 +16,7 @@ To use LiteLLM proxy with OpenHands, you need to:

 ## Supported Models

 The supported models depend on your LiteLLM proxy configuration. OpenHands supports any model that your LiteLLM proxy
 is configured to handle.

 Refer to your LiteLLM proxy configuration for the list of available models and their names.
```
```diff
@@ -9,6 +9,6 @@ When running OpenHands, you'll need to set the following in the OpenHands UI thr
 * `LLM Provider` to `OpenRouter`
 * `LLM Model` to the model you will be using.
 [Visit here to see a full list of OpenRouter models](https://openrouter.ai/models).
 If the model is not in the list, enable `Advanced` options, and enter it in
 `Custom Model` (e.g. openrouter/<model-name> like `openrouter/anthropic/claude-3.5-sonnet`).
 * `API Key` to your OpenRouter API key.
```
```diff
@@ -5,7 +5,7 @@ description: Organizations and users can define microagents that apply to all re

 ## Usage

 These microagents can be [any type of microagent](./microagents-overview#microagent-types) and will be loaded
 accordingly. However, they are applied to all repositories belonging to the organization or user.

 Add a `.openhands` repository under the organization or user and create a `microagents` directory and place the
```
```diff
@@ -15,7 +15,7 @@ Before using the Local Runtime, ensure that:
 1. You can run OpenHands using the [Development workflow](https://github.com/All-Hands-AI/OpenHands/blob/main/Development.md).
 2. For Linux and Mac, tmux is available on your system.
 3. For Windows, PowerShell is available on your system.
    - Only [CLI mode](../how-to/cli-mode) and [headless mode](../how-to/headless-mode) are supported in Windows with Local Runtime.

 ## Configuration

```
```diff
@@ -1,4 +1,4 @@
-TASK_INSTRUECTION="""
+TASK_INSTRUECTION = """
 Given the following GitHub problem description, your objective is to localize the specific files, classes or functions, and lines of code that need modification or contain key information to resolve the issue.

 Follow these steps to localize the issue:
@@ -66,4 +66,4 @@ FAKE_USER_MSG_FOR_LOC = (
     'Verify that you have carefully analyzed the impact of the found locations on the repository, especially their dependencies. '
     'If you think you have solved the task, please send your final answer (including the former answer and reranking) to user through message and then call `finish` to finish.\n'
     'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP.\n'
 )
```
```diff
@@ -27,7 +27,7 @@ You MUST plan extensively before each function call, and reflect extensively on
 5. Debug as needed. Use debugging techniques to isolate and resolve issues.
 6. Test frequently. Run tests after each change to verify correctness.
 7. Iterate until the root cause is fixed and all tests pass.
 8. Reflect and validate comprehensively. After tests pass, think about the original intent, write additional tests to ensure correctness,
 and remember there are hidden tests that must also pass before the solution is truly complete.

 Refer to the detailed sections below for more information on each step.
```
```diff
@@ -43,7 +43,7 @@ from openhands.core.config import (
     AgentConfig,
     OpenHandsConfig,
     get_llm_config_arg,
-    get_parser
+    get_parser,
 )
 from openhands.core.config.condenser_config import NoOpCondenserConfig
 from openhands.core.config.utils import get_condenser_config_arg
@@ -92,10 +92,12 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata) -> MessageActio
         elif 'gpt-4.1' in llm_model:
             template_name = 'swe_gpt4.j2'
         else:
-            template_name = 'swe_default.j2'  # Default for 'swe' mode (regular swe-bench)
+            template_name = (
+                'swe_default.j2'  # Default for 'swe' mode (regular swe-bench)
+            )
     else:
         # Fallback or error handling if mode is unexpected
-        logger.error(f"Unexpected evaluation mode: {mode}. Falling back to default.")
+        logger.error(f'Unexpected evaluation mode: {mode}. Falling back to default.')
         template_name = 'swe_default.j2'

     # Set up Jinja2 environment
@@ -117,7 +119,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata) -> MessageActio
             f'The following command can be used to run the tests: `{list(MAP_REPO_TO_TEST_FRAMEWORK_VERBOSE[instance.repo].values())[0]}`. Make sure they fail in the expected way.\n'
         )
     else:
         context['test_instructions'] = ''  # Ensure it's defined for other modes

     # Render the instruction
     instruction = template.render(context)
```
@@ -1,103 +1,102 @@ (whole-file re-wrap; the updated README follows)

# VersiCode benchmark

This project is used to evaluate the performance of the model on VersiCode. It includes:

- data: the test data needed and the model outputs
- inference_utils: inference scripts for our tasks and models
- metric: scripts for calculating the various metrics
- output_processing: process the model output to facilitate the calculation of model metrics

# Details

1. **Prepare the environment**

   ```shell
   # create conda environment
   conda create -n VersiCode python==3.12

   # install requirements
   pip install -r requirements.txt
   ```

2. **Experiment Data**

   To obtain the experimental data, please visit the Hugging Face link: https://huggingface.co/datasets/AstoneNg/VersiCode.
   Locate the files `VersiCode_block_completion.json` and `VersiCode_migration.json` under the `experiment_data` directory, and place them in the `data/test_data` directory of this project.

3. **Model inference**

   ```shell
   # cd into the inference_utils directory
   cd inference_utils

   # Script files starting with 'test' are used to test a local model
   # Script files starting with 'api' are used to test a model called through an API

   # block level code completion
   # Modify lines 10 and 12 to specify the base URL and model name
   python api_test_block_completion.py
   # Modify line 30 to specify the local model path
   python test_block.py

   # code migration (migration order is 'old_to_new')
   # Modify lines 10 and 12 to specify the base URL and model name
   python api_code_migration.py
   # Modify line 30 to specify the local model path
   python test_migration.py
   ```

4. **Process output**

   Process the model output: remove redundant content and extract the specified content to make the metric calculation easier.

   ```shell
   # cd into the output_processing directory
   cd output_processing

   # Extract the content between <start> and <end>
   # Modify lines 8 and 9 to specify the model and task granularity
   python clear_ans.py

   # In the block completion and migration tasks, the cdc@k metric is computed over key lines
   # Modify lines 76 and 79 to specify the data path
   python choose_core_line_from_block_versicode.py
   python choose_core_line_from_migration_versicode.py
   ```
5. **Metric**

   We have three metrics: pass@k, em@k, and cdc@k. Because we cannot automatically build a dynamic evaluation environment, we do not provide pass@k. (A sketch of the usual @k estimator follows this list.)

   ```shell
   # cd into the metric directory
   cd metric

   # Modify lines 137-140 in the migration task (compute_migration_cdc_score.py), or lines 143-145 in the block and line completion tasks (compute_versicode_cdc_score.py and compute_versicode_em_score.py), to specify the data path and the k-value of the metric
   python compute_migration_cdc_score.py
   python compute_versicode_cdc_score.py
   python compute_versicode_em_score.py

   # Notes
   # We found limitations in the ISM@k and PM@k metrics for evaluating code generation, so they are used only as reference in our experiments.
   # Modify lines 261-265 in the block and line completion tasks to specify the data path and the k-value of the metric
   python compute_ism_pm_score.py
   ```
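For intuition, metrics of this @k family are conventionally computed with the unbiased estimator popularized by the HumanEval benchmark; whether VersiCode's scripts use exactly this formula is an assumption here, so treat the sketch as illustrative:

```python
import math


def metric_at_k(n: int, c: int, k: int) -> float:
    """Probability that at least one of k samples drawn from n
    generations is correct, given that c of the n are correct.
    (Assumed formula; not taken from the VersiCode scripts.)"""
    if n - c < k:
        return 1.0
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)


# The inference scripts sample n=6 completions per task; with 2 correct:
print(metric_at_k(6, 2, 1))  # ~0.333
print(metric_at_k(6, 2, 3))  # 0.8
```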
# Citation

```
@article{versicode,
  author  = {Tongtong Wu and Weigang Wu and Xingyu Wang and Kang Xu and Suyu Ma and Bo Jiang and Ping Yang and Zhenchang Xing and Yuan-Fang Li and Gholamreza Haffari},
  title   = {VersiCode: Towards Version-controllable Code Generation},
  journal = {CoRR},
  volume  = {abs/2406.07411},
  year    = {2024},
  url     = {https://arxiv.org/abs/2406.07411},
}
```

**Github url**: https://github.com/wutong8023/VersiCode

# Contributor

[Tongtong Wu](https://scholar.google.com/citations?hl=zh-CN&user=u1Qp8lUAAAAJ&view_op=list_works&sortby=pubdate), [Weigang Wu](https://scholar.google.com/citations?hl=zh-CN&user=UneIZo8AAAAJ), [Xingyu Wang](https://scholar.google.com/citations?hl=zh-CN&user=wqPJcxcAAAAJ), [Kang Xu](https://scholar.google.com/citations?hl=zh-CN&user=N1UUDi0AAAAJ), [Suyu Ma](https://scholar.google.com/citations?hl=zh-CN&user=NJHR1ukAAAAJ), [Bo Jiang](https://wutong8023.site/VersiCode/), [Ping Yang](https://scholar.google.com/citations?view_op=list_works&hl=en&hl=en&user=hrogvxoAAAAJ), [Zhenchang Xing](https://scholar.google.com/citations?hl=zh-CN&user=0vCxuH4AAAAJ), [Yuan-Fang Li](https://scholar.google.com/citations?hl=zh-CN&user=wufXO1kAAAAJ), [Gholamreza Haffari](https://scholar.google.com/citations?hl=zh-CN&user=Perjx5EAAAAJ)
@@ -1,134 +1,134 @@ (whole-file reformat: quote style, import order, and wrapping; the updated file follows)

```python
"""
GPT performs line level generation prediction and truncates overly long tokens
"""

import json
import os

import tiktoken
from openai import OpenAI

max_tokens = 127000  # gpt3.5 is 16k tokens, gpt4o is 128k
model_name = ''

os.environ['OPENAI_API_KEY'] = ''
client = OpenAI()


def truncate_text(text, max_tokens):
    encoding = tiktoken.get_encoding('cl100k_base')
    disallowed_special = ()

    tokens = encoding.encode(text, disallowed_special=disallowed_special)
    print(len(tokens))

    if len(tokens) > max_tokens:
        tokens = tokens[:max_tokens]

    truncated_text = encoding.decode(tokens)

    return truncated_text


def predict(content, model_name):
    response = client.chat.completions.create(
        model=model_name,
        messages=[{'role': 'user', 'content': content}],
        frequency_penalty=0.1,
        max_tokens=128,
        logit_bias=None,
        logprobs=None,
        n=6,
        presence_penalty=0.0,
        seed=None,
        stop=None,
        stream=False,
        temperature=0.8,
        top_p=0.95,
    )
    ans_list = []
    choices_list = response.choices
    for c in choices_list:
        content = c.message.content
        ans_list.append(content)
    final_ans = str(ans_list)
    return final_ans


def bulid_prompt(description, old_version, old_code, new_version) -> str:
    """
    build prompt
    :param version:
    :param description:
    :param masked_code:
    :param options:
    :return:
    """
    prompt = f"""
    You are now a professional Python programming engineer. I will provide you with a code snippet and a description of its functionality,
    including the dependencies and versions used in the code. Then, I will provide the same dependencies but with a specified new version.
    Your task is to refactor the code using the methods provided by the specified new version and return the refactored code.
    Please note that you only need to return the refactored code and enclose it with <start> and <end>:
    ###Functionality description of the code
    {description}
    ###Dependency and old version
    {old_version}
    ###Old version code
    {old_code}
    ###Dependency and new version
    {new_version}
    ###Refactored new code
    """

    return prompt


json_path = '../data/test_data/VersiCode_migration.json'


with open(json_path, 'r', encoding='utf-8') as fr:
    lodict = json.load(fr)
data_dict = lodict
data_list = data_dict


for data in data_list:
    if 'model_output' in data:
        print(
            f'the {data_list.index(data) + 1} has already been predicted, skipping this data!'
        )
        continue
    try:
        print(f'Predicting {data_list.index(data) + 1} ')
        old_version = data['dependency'] + data['old_version']  # package == x.x.x
        new_version = data['dependency'] + data['new_version']  # package == x.x.x
        description = data['description']  # function description
        old_code = data['old_code']  # masked code

        instruction = bulid_prompt(description, old_version, old_code, new_version)
        truncated_text = truncate_text(instruction, max_tokens)
        prediction = predict(truncated_text, model_name)

        data['model_output'] = prediction
    except Exception as e:
        print(f'error:{e}')
        print('save current data')
        save_folder_path = os.path.join(
            '../data/result_data/code_migration', model_name
        )
        if not os.path.exists(save_folder_path):
            os.makedirs(save_folder_path)
        save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])

        with open(save_json_path, 'w', encoding='utf-8') as fw:
            json.dump(data_dict, fw, indent=4, ensure_ascii=False)
        break


save_folder_path = os.path.join('../data/result_data/code_migration', model_name)
if not os.path.exists(save_folder_path):
    os.makedirs(save_folder_path)
save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])

with open(save_json_path, 'w', encoding='utf-8') as fw:
    json.dump(data_dict, fw, indent=4, ensure_ascii=False)
```
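Both API scripts store the six sampled completions as `str(ans_list)`, the string form of a Python list, and each completion is expected to wrap its code in `<start>`/`<end>`. A minimal sketch of reading such a field back (illustrative only; the real post-processing lives in `output_processing/clear_ans.py`, whose exact logic is not shown here):

```python
import ast
import re

model_output = "['<start>x = 1<end>', '<start>y = 2<end>']"

# str(list) round-trips through ast.literal_eval, not json.loads
candidates = ast.literal_eval(model_output)

extracted = []
for candidate in candidates:
    # Keep only the code between the first <start> and the following <end>
    match = re.search(r'<start>(.*?)<end>', candidate, re.DOTALL)
    extracted.append(match.group(1).strip() if match else candidate)

print(extracted)  # ['x = 1', 'y = 2']
```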
@@ -1,141 +1,141 @@ (whole-file reformat: quote style, import order, and wrapping; the updated file follows)

```python
"""
GPT performs line level generation prediction and truncates overly long tokens
"""

import json
import os

import tiktoken
from openai import OpenAI

max_tokens = 127000  # gpt3.5 is 16k tokens, gpt4o is 128k
model_name = ''

os.environ['OPENAI_API_KEY'] = ''
client = OpenAI()


def truncate_text(text, max_tokens):
    encoding = tiktoken.get_encoding('cl100k_base')
    disallowed_special = ()

    tokens = encoding.encode(text, disallowed_special=disallowed_special)
    print(len(tokens))

    if len(tokens) > max_tokens:
        tokens = tokens[:max_tokens]

    truncated_text = encoding.decode(tokens)

    return truncated_text


def predict(content, model_name):
    response = client.chat.completions.create(
        model=model_name,
        messages=[{'role': 'user', 'content': content}],
        frequency_penalty=0.1,
        max_tokens=128,
        logit_bias=None,
        logprobs=None,
        n=6,
        presence_penalty=0.0,
        seed=None,
        stop=None,
        stream=False,
        temperature=0.8,
        top_p=0.95,
    )
    ans_list = []
    choices_list = response.choices
    for c in choices_list:
        content = c.message.content
        ans_list.append(content)
    final_ans = str(ans_list)
    return final_ans


def bulid_prompt(version, description) -> str:
    """
    build prompt
    :param version:
    :param description:
    :param masked_code:
    :param options:
    :return:
    """
    prompt = f"""
    You are a professional Python engineer, and I will provide functional descriptions and versions of specified dependency packages.
    You need to write code in Python to implement this feature based on the functional description and using the dependency package and version I specified.
    Please note that you only need to return the code that implements the function, and do not return any other content.
    Please use <start> and <end> to enclose the generated code. Here is an example:
    ###Function Description:
    The function of this code is to print the results predicted by calling the model using vllm.
    ###dependency and version:
    vllm==0.3.3
    ###response:
    <start>
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print("Prompt,Generated text")
    <end>

    ###Function Description:
    {description}
    ###dependency and version:
    {version}
    ###response:


    """
    return prompt


json_path = '../data/test_data/VersiCode_block_completion.json'


with open(json_path, 'r', encoding='utf-8') as fr:
    lodict = json.load(fr)
data_dict = lodict
data_list = data_dict


for data in data_list:
    if 'model_output' in data:
        print(
            f'the {data_list.index(data) + 1} has already been predicted, skipping this data!'
        )
        continue
    try:
        print(f'Predicting {data_list.index(data) + 1} ')
        version = data['dependency'] + data['version']  # package == x.x.x
        description = data['description']  # func description

        instruction = bulid_prompt(version, description)
        truncated_text = truncate_text(instruction, max_tokens)
        prediction = predict(truncated_text, model_name)

        data['model_output'] = prediction
    except Exception as e:
        print(f'error:{e}')
        print('save current data')
        save_folder_path = os.path.join(
            '../data/result_data/block_completion', model_name
        )
        if not os.path.exists(save_folder_path):
            os.makedirs(save_folder_path)
        save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])

        with open(save_json_path, 'w', encoding='utf-8') as fw:
            json.dump(data_dict, fw, indent=4, ensure_ascii=False)
        break


save_folder_path = os.path.join('../data/result_data/block_completion', model_name)
if not os.path.exists(save_folder_path):
    os.makedirs(save_folder_path)
save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])

with open(save_json_path, 'w', encoding='utf-8') as fw:
    json.dump(data_dict, fw, indent=4, ensure_ascii=False)
```
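As a quick sanity check of the truncation helper shared by both API scripts (assuming `tiktoken` is installed; `cl100k_base` is the encoding the scripts request):

```python
import tiktoken

encoding = tiktoken.get_encoding('cl100k_base')
tokens = encoding.encode('hello world ' * 1000)

# truncate_text keeps the first max_tokens tokens and decodes them back
print(len(tokens), '->', encoding.decode(tokens[:5]))
```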
@@ -1,118 +1,129 @@ (whole-file reformat; the updated file follows)

```python
"""
block completion
"""

import copy
import gc
import json
import os
import time
from multiprocessing import Process

import tiktoken
import torch
from vllm import LLM, SamplingParams

# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"


def truncate_text(text, max_tokens):
    encoding = tiktoken.get_encoding('cl100k_base')
    disallowed_special = ()

    tokens = encoding.encode(text, disallowed_special=disallowed_special)
    print(len(tokens))

    if len(tokens) > max_tokens:
        tokens = tokens[:max_tokens]

    truncated_text = encoding.decode(tokens)

    return truncated_text


model_list = ['/data2/base models/starcoder2-15b', '/data2/base models/CodeGemma-7B']


def run_inference(model_name, origin_data_list):
    temp_data_list = copy.deepcopy(origin_data_list)
    test_list = []
    for data in temp_data_list:
        version = data['dependency'] + data['version']  # package == x.x.x
        description = data['description']  # func description

        instruction = bulid_prompt(version, description)
        test_list.append(instruction)

    sampling_params = SamplingParams(n=6, temperature=0.8, top_p=0.95, max_tokens=64)
    llm = LLM(
        model=model_name,
        tensor_parallel_size=4,
        gpu_memory_utilization=0.9,
        swap_space=20,
    )

    outputs = llm.generate(test_list, sampling_params)
    for output in outputs:
        requests_id = int(output.request_id)
        temp_ans_list = []
        output_list = output.outputs
        for o in output_list:
            text = o.text
            temp_ans_list.append(text)

        temp_data_list[requests_id]['model_output'] = str(temp_ans_list)

    save_folder_path = os.path.join(
        '../data/result_data/block_completion', model_name.split('/')[-1]
    )
    if not os.path.exists(save_folder_path):
        os.makedirs(save_folder_path)

    save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])

    with open(save_json_path, 'w', encoding='utf-8') as fw:
        json.dump(temp_data_list, fw, indent=4, ensure_ascii=False)

    gc.collect()
    torch.cuda.empty_cache()


def bulid_prompt(version, description) -> str:
    """
    build prompt
    :param version:
    :param description:
    :param masked_code:
    :param options:
    :return:
    """
    prompt = f"""
    You are a professional Python engineer, and I will provide functional descriptions and versions of specified dependency packages.
    You need to write code in Python to implement this feature based on the functional description and using the dependency package and version I specified.
    Please note that you only need to return the code that implements the function, and do not return any other content.
    Please use <start> and <end> to enclose the generated code. Here is an example:
    ###Function Description:
    The function of this code is to print the results predicted by calling the model using vllm.
    ###dependency and version:
    vllm==0.3.3
    ###response:
    <start>
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print("Prompt,Generated text")
    <end>

    ###Function Description:
    {description}
    ###dependency and version:
    {version}
    ###response:


    """
    return prompt


json_path = '../data/test_data/VersiCode_block_completion.json'


with open(json_path, 'r', encoding='utf-8') as fr:
    lodict = json.load(fr)

origin_data_list = lodict

for model_name in model_list:
    process = Process(target=run_inference, args=(model_name, origin_data_list))
    process.start()
    process.join()
    time.sleep(120)
```
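One design note: each vLLM model runs in its own `Process`, joined before the next starts, so that the child's exit returns its GPU memory to the system between models; `gc.collect()` plus `torch.cuda.empty_cache()` inside a single long-lived process would not reliably free the weights of a previous `LLM` instance.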
@@ -1,111 +1,122 @@
|
|||||||
"""
|
"""
|
||||||
code migration
|
code migration
|
||||||
"""
|
"""
|
||||||
import copy
|
|
||||||
import json
|
import copy
|
||||||
import os
|
import gc
|
||||||
from vllm import LLM, SamplingParams
|
import json
|
||||||
import tiktoken
|
import os
|
||||||
import time
|
import time
|
||||||
import gc
|
from multiprocessing import Process
|
||||||
import torch
|
|
||||||
from multiprocessing import Process
|
import tiktoken
|
||||||
|
import torch
|
||||||
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
def truncate_text(text, max_tokens):
|
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||||
encoding = tiktoken.get_encoding("cl100k_base")
|
|
||||||
disallowed_special = ()
|
|
||||||
|
def truncate_text(text, max_tokens):
|
||||||
tokens = encoding.encode(text, disallowed_special=disallowed_special)
|
encoding = tiktoken.get_encoding('cl100k_base')
|
||||||
print(len(tokens))
|
disallowed_special = ()
|
||||||
|
|
||||||
if len(tokens) > max_tokens:
|
tokens = encoding.encode(text, disallowed_special=disallowed_special)
|
||||||
tokens = tokens[:max_tokens]
|
print(len(tokens))
|
||||||
|
|
||||||
truncated_text = encoding.decode(tokens)
|
if len(tokens) > max_tokens:
|
||||||
|
tokens = tokens[:max_tokens]
|
||||||
return truncated_text
|
|
||||||
|
truncated_text = encoding.decode(tokens)
|
||||||
model_list = ['/data2/base models/starcoder2-15b', '/data2/base models/CodeGemma-7B']
|
|
||||||
|
return truncated_text
|
||||||
def run_inference(model_name, origin_data_list):
|
|
||||||
temp_data_list = copy.deepcopy(origin_data_list)
|
|
||||||
test_list = []
|
model_list = ['/data2/base models/starcoder2-15b', '/data2/base models/CodeGemma-7B']
|
||||||
for data in temp_data_list:
|
|
||||||
old_version = data['dependency'] + data['old_version'] # package == x.x.x
|
|
||||||
new_version = data['dependency'] + data['new_version'] # package == x.x.x
|
def run_inference(model_name, origin_data_list):
|
||||||
description = data['description'] # 功能描述
|
temp_data_list = copy.deepcopy(origin_data_list)
|
||||||
old_code = data['old_code'] # mask后的代码
|
test_list = []
|
||||||
|
for data in temp_data_list:
|
||||||
instruction = bulid_prompt(description, old_version, old_code, new_version)
|
old_version = data['dependency'] + data['old_version'] # package == x.x.x
|
||||||
test_list.append(instruction)
|
new_version = data['dependency'] + data['new_version'] # package == x.x.x
|
||||||
|
description = data['description'] # 功能描述
|
||||||
sampling_params = SamplingParams(n=6, temperature=0.8, top_p=0.95, max_tokens=512)
|
old_code = data['old_code'] # mask后的代码
|
||||||
llm = LLM(model=model_name, tensor_parallel_size=4, gpu_memory_utilization=0.6, swap_space=40)
|
|
||||||
|
instruction = bulid_prompt(description, old_version, old_code, new_version)
|
||||||
outputs = llm.generate(test_list, sampling_params)
|
test_list.append(instruction)
|
||||||
for output in outputs:
|
|
||||||
requests_id = int(output.request_id)
|
sampling_params = SamplingParams(n=6, temperature=0.8, top_p=0.95, max_tokens=512)
|
||||||
temp_ans_list = []
|
llm = LLM(
|
||||||
output_list = output.outputs
|
model=model_name,
|
||||||
for o in output_list:
|
tensor_parallel_size=4,
|
||||||
text = o.text
|
gpu_memory_utilization=0.6,
|
||||||
temp_ans_list.append(text)
|
swap_space=40,
|
||||||
|
)
|
||||||
temp_data_list[requests_id]['model_output'] = str(temp_ans_list)
|
|
||||||
|
outputs = llm.generate(test_list, sampling_params)
|
||||||
save_folder_path = os.path.join('../data/result_data/code_migration', model_name.split('/')[-1])
|
for output in outputs:
|
||||||
if not os.path.exists(save_folder_path):
|
requests_id = int(output.request_id)
|
||||||
os.makedirs(save_folder_path)
|
temp_ans_list = []
|
||||||
|
output_list = output.outputs
|
||||||
        for o in output_list:
            text = o.text
            temp_ans_list.append(text)

        temp_data_list[requests_id]['model_output'] = str(temp_ans_list)

    save_folder_path = os.path.join(
        '../data/result_data/code_migration', model_name.split('/')[-1]
    )
    if not os.path.exists(save_folder_path):
        os.makedirs(save_folder_path)

    save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])

    with open(save_json_path, 'w', encoding='utf-8') as fw:
        json.dump(temp_data_list, fw, indent=4, ensure_ascii=False)

    gc.collect()
    torch.cuda.empty_cache()


def bulid_prompt(description, old_version, old_code, new_version) -> str:
    """
    build prompt
    :param version:
    :param description:
    :param masked_code:
    :param options:
    :return:
    """
    prompt = f"""
    You are now a professional Python programming engineer. I will provide you with a code snippet and a description of its functionality,
    including the dependencies and versions used in the code. Then, I will provide the same dependencies but with a specified new version.
    Your task is to refactor the code using the methods provided by the specified new version and return the refactored code.
    Please note that you only need to return the refactored code and enclose it with <start> and <end>:
    ###Functionality description of the code
    {description}
    ###Dependency and old version
    {old_version}
    ###Old version code
    {old_code}
    ###Dependency and new version
    {new_version}
    ###Refactored new code
    """

    return prompt


json_path = '../data/test_data/VersiCode_migration.json'

with open(json_path, 'r', encoding='utf-8') as fr:
    lodict = json.load(fr)

origin_data_list = lodict

for model_name in model_list:
    process = Process(target=run_inference, args=(model_name, origin_data_list))
    process.start()
    process.join()

    time.sleep(120)
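For reference, a hypothetical invocation of the prompt builder above; the argument values are illustrative stand-ins, not entries from the VersiCode dataset:

```
# Illustrative values only; real inputs come from VersiCode_migration.json.
prompt = bulid_prompt(
    description='Read a JSON file and return its contents.',
    old_version='pandas==0.25.3',
    old_code="df = pd.read_json('data.json')",
    new_version='pandas==2.2.0',
)
print(prompt)  # the model is expected to answer with <start>...<end>
```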
@@ -1,345 +1,356 @@
"""
Evaluate block-level prediction ability:
1. check whether the correct function name is present
2. check whether the code is syntactically valid
3. compute ISM and PM
"""

import io
import json
import math
import os
import re
import tokenize


def is_code_valid(code):
    try:
        compile(code, '<string>', 'exec')
        return True
    except Exception:
        return False


def longest_common_prefix_between_lists_with_elements(list1, list2):
    """
    Compute the longest prefix match length between elements of two string lists
    :param list1:
    :param list2:
    :return:
    """
    max_prefix_length = 0
    max_prefix_elements = ()
    for str1 in list1:
        for str2 in list2:
            prefix_length = 0
            min_len = min(len(str1), len(str2))
            for i in range(min_len):
                if str1[i] == str2[i]:
                    prefix_length += 1
                else:
                    break
            if prefix_length > max_prefix_length:
                max_prefix_length = prefix_length
                max_prefix_elements = (str1, str2)
    return max_prefix_length, max_prefix_elements


def get_token(ans_code: str, output_code: str):
    """
    Tokenize both code strings into identifiers and return the two identifier lists
    :param ans_code:
    :param output_code:
    :return:
    """
    output_flag = True
    ans_flag = True
    try:
        tokens_ans = tokenize.tokenize(io.BytesIO(ans_code.encode('utf-8')).readline)
    except Exception:
        tokens_ans = ans_code.splitlines()
        ans_flag = False

    try:
        tokens_output = tokenize.tokenize(
            io.BytesIO(output_code.encode('utf-8')).readline
        )
    except Exception:
        tokens_output = output_code.splitlines()
        output_flag = False

    identifiers_ans = []
    identifiers_output = []
    if ans_flag:
        try:
            for token in tokens_ans:
                if token.type == tokenize.NAME:
                    identifiers_ans.append(token.string)
        except Exception:
            identifiers_ans = tokens_ans
    else:
        identifiers_ans = tokens_ans

    if output_flag:
        try:
            for to in tokens_output:
                if to.type == tokenize.NAME:
                    identifiers_output.append(to.string)
        except Exception:
            identifiers_output = tokens_output
    else:
        identifiers_output = tokens_output

    return identifiers_ans, identifiers_output


def get_token_per_line(code: str):
    """
    Tokenize each line of code and record the identifiers on every line
    :param code: code string
    :return: list of per-line identifier lists
    """
    lines = code.split('\n')  # split the code into lines
    identifiers_per_line = []  # stores the identifier list for each line

    for line in lines:
        tokens = tokenize.tokenize(io.BytesIO(line.encode('utf-8')).readline)
        identifiers = []
        try:
            for token in tokens:
                if token.type == tokenize.NAME:
                    identifiers.append(token.string)
        except Exception:
            identifiers = line.split(' ')
        identifiers_per_line.append(identifiers)

    return identifiers_per_line


def get_ISM(answer_code: str, model_output_list: list, asnwer_name: str) -> list:
    """
    Compute ISM and return a sorted score list
    :return:
    """
    score_list = []
    for code in model_output_list:
        if '```python' in code:
            code = code.replace('```python', '')
            code = code.replace('```', '')
        if not re.search(rf'\b{re.escape(asnwer_name)}\b', code) or not is_code_valid(
            code
        ):
            score_list.append(0)
            continue

        # if asnwer_name not in code:
        #     score_list.append(0)
        #     continue

        identifiers_ans, identifiers_output = get_token(answer_code, code)
        max_len, elements = longest_common_prefix_between_lists_with_elements(
            identifiers_ans, identifiers_output
        )
        if max_len != 0:
            base_element_len = max(len(elements[0]), len(elements[1]))
            temp_score = max_len / base_element_len
            score_list.append(temp_score)
        else:
            score_list.append(0)
            # base_element_len = max(len(elements[0]), len(elements[1]))
            # temp_score = max_len/base_element_len
            # score_list.append(temp_score)

    score_list = sorted(score_list, reverse=True)
    return score_list


def get_ISM_without_verification(
    answer_code: str, model_output_list: list, asnwer_name: str
) -> list:
    """
    Compute ISM and return a sorted score list
    :return:
    """
    score_list = []
    for code in model_output_list:
        if asnwer_name not in code:
            score_list.append(0)
            continue

        # if asnwer_name not in code:
        #     score_list.append(0)
        #     continue

        identifiers_ans, identifiers_output = get_token(answer_code, code)
        max_len, elements = longest_common_prefix_between_lists_with_elements(
            identifiers_ans, identifiers_output
        )
        if max_len != 0:
            base_element_len = max(len(elements[0]), len(elements[1]))
            temp_score = max_len / base_element_len
            score_list.append(temp_score)
        else:
            score_list.append(0)
            # base_element_len = max(len(elements[0]), len(elements[1]))
            # temp_score = max_len/base_element_len
            # score_list.append(temp_score)

    score_list = sorted(score_list, reverse=True)
    return score_list


def longest_common_prefix_with_lengths(list1, list2):
    """
    Compute the longest prefix match length between sublists of two 2-D lists,
    and record the lengths of the two sublists that achieve it
    :param list1: first 2-D list
    :param list2: second 2-D list
    :return: the longest prefix match length and the lengths of those two sublists
    """
    max_length = 0
    len_list1 = 0
    len_list2 = 0
    for i, sublist1 in enumerate(list1):
        for j, sublist2 in enumerate(list2):
            match_length = 0
            min_length = min(len(sublist1), len(sublist2))
            for k in range(min_length):
                if sublist1[k] == sublist2[k]:
                    match_length += 1
                else:
                    break
            if match_length > max_length:
                max_length = match_length
                len_list1 = len(sublist1)
                len_list2 = len(sublist2)
    return max_length, len_list1, len_list2


def get_PM(answer_code: str, model_output_list: list, asnwer_name: str) -> list:
    """
    Compute PM and return a sorted score list
    :return:
    """
    score_list = []
    for code in model_output_list:
        if '```python' in code:
            code = code.replace('```python', '')
            code = code.replace('```', '')
        if not re.search(rf'\b{re.escape(asnwer_name)}\b', code) or not is_code_valid(
            code
        ):
            # if asnwer_name not in code or is_code_valid(code) == False:
            score_list.append(0)
            continue

        # if asnwer_name not in code:
        #     score_list.append(0)
        #     continue

        ans_list = get_token_per_line(answer_code)
        output_token_list = get_token_per_line(code)
        max_len, len1, len2 = longest_common_prefix_with_lengths(
            ans_list, output_token_list
        )
        base_element_len = max(len1, len2)

        if base_element_len != 0:
            temp_score = max_len / base_element_len
            score_list.append(temp_score)
        else:
            score_list.append(0)

    score_list = sorted(score_list, reverse=True)
    return score_list


def get_score(score_list: list, k):
    """
    Compute score@n,k
    :param score_list:
    :param k:
    :return:
    """
    n = len(score_list)
    sum = 0
    final = n - k + 1
    for i in range(1, final + 1):
        sum += math.comb(n - i, k - 1) * score_list[i - 1]

    final_score = sum / math.comb(n, k)

    return final_score


k = 1
task = 'block'  # block or line
json_name = f'Versicode_{task}_completion.json'

folder_path = f'../data/result_data/{task}_completion'
model_list = os.listdir(folder_path)

for model in model_list:
    model_json_path = os.path.join(folder_path, model, json_name)
    with open(model_json_path, 'r', encoding='utf-8') as fr:
        lodict = json.load(fr)
    data_dict = lodict
    data_list = data_dict
    data_len = len(data_list)
    sum_ISM = 0
    sum_PM = 0

    for data in data_list:
        # model_output_list = eval(data['model_output'])
        model_output_list = eval(data['model_output_clear'])[:1]
        temp_list = []
        for o in model_output_list:
            temp_out = o.replace('```python', '')
            temp_out = temp_out.replace('```', '')
            temp_list.append(temp_out)
        model_output_list = temp_list
        answer_code = data['code']
        answer_name = data['core_token']
        #
        # answer_code = data['new_code'] #code editing
        # answer_name = data['new_name'] #code editing

        # answer_code = data['old_code'] # code editing new to old
        # answer_name = data['old_name'] # code editing new to old
        #
        ISM_score_list = get_ISM(answer_code, model_output_list, answer_name)
        # ISM_score_without_verification_list = get_ISM_without_verification(answer_code, model_output_list, answer_name)  # added
        PM_score_list = get_PM(answer_code, model_output_list, answer_name)

        # if not ISM_score_without_verification_list == ISM_score_list:  # added
        #     for s in ISM_score_list:  # added
        #         if s != ISM_score_without_verification_list[ISM_score_list.index(s)]:  # added
        #             print('metadata:')  # added
        #             print(data)  # added
        #             print('answer:')  # added
        #             print(model_output_list[ISM_score_list.index(s)])  # added

        #             flag = int(input('enter 1 to continue, 0 to exit'))  # added
        #             if flag == 1:
        #                 continue

        ISM_score = get_score(ISM_score_list, k)
        PM_score = get_score(PM_score_list, k)

        sum_ISM += ISM_score
        sum_PM += PM_score
        # print(f'ISM score: {ISM_score}')
        # print(f'PM score: {PM_score}')

    print(f'{model}, {task} completion task, ISM@{k} score: {sum_ISM / data_len}')
    print(f'{model}, {task} completion task, PM@{k} score: {sum_PM / data_len}')


# def get_token(ans_code:str, output_code:str):
#     """
#     Tokenize both code strings into identifiers and return the two identifier lists
#     :param ans_code:
#     :param output_code:
#     :return:
#     """
#     tokens_ans = tokenize.tokenize(io.BytesIO(ans_code.encode('utf-8')).readline)
#     tokens_output = tokenize.tokenize(io.BytesIO(output_code.encode('utf-8')).readline)
#     identifiers_ans = []
#     identifiers_output = []
#     for token in tokens_ans:
#         if token.type == tokenize.NAME:
#             identifiers_ans.append(token.string)
#
#     for to in tokens_output:
#         if to.type == tokenize.NAME:
#             identifiers_output.append(to.string)
#
#     return identifiers_ans, identifiers_output
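`get_score` above appears to be an expected best-of-k estimator: with the n candidate scores sorted in descending order, the i-th score is the maximum of a uniformly random k-subset with probability C(n-i, k-1) / C(n, k). A minimal sketch checking that reading against a brute-force average, using made-up scores:

```
import math
from itertools import combinations

scores = [1.0, 0.5, 0.0]  # sorted descending, n = 3
n, k = len(scores), 2

# closed form used by get_score
closed = sum(
    math.comb(n - i, k - 1) * scores[i - 1] for i in range(1, n - k + 2)
) / math.comb(n, k)

# brute force: average of the best score over every k-subset
brute = sum(max(s) for s in combinations(scores, k)) / math.comb(n, k)

print(closed, brute)  # both 0.8333...
```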
@@ -1,165 +1,198 @@
"""
Calculate the cdc score for migration
"""

import json
import math
import os
import re

# warnings.filterwarnings("ignore", category=SyntaxWarning)


def is_correct_parameter_count(function_name, correct_code, test_code):
    """
    Check whether the number of parameters matches
    :param function_name:
    :param correct_code:
    :param test_code:
    :return:
    """
    # count the parameters in the correct code
    # return True
    pattern = rf'{function_name}\((.*?)\)'
    correct_match = re.search(pattern, correct_code)

    if correct_match:
        correct_params = correct_match.group(1).strip()
        correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]
        expected_count = len(correct_param_list)
    else:
        expected_count = 0  # no parameters means the expected count is 0

    # look for the function call in the code under test
    test_match = re.search(pattern, test_code)

    if test_match:
        test_params = test_match.group(1).strip()
        test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]
        return len(test_param_list) == expected_count  # compare parameter counts
    else:
        # no parentheses: just check that the function name appears in the string
        return expected_count == 0 and function_name in test_code


def check_keyword_parameters(function_name, correct_code, test_code):
    """
    Check whether keyword-argument assignments are used correctly
    :param function_name:
    :param correct_code:
    :param test_code:
    :return:
    """
    # regex-match the function call in the correct code
    # return True
    pattern = rf'{function_name}\((.*?)\)'
    correct_match = re.search(pattern, correct_code)

    if correct_match:
        correct_params = correct_match.group(1).strip()
        correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]

        # check the function call in the code under test
        test_match = re.search(pattern, test_code)

        if test_match:
            test_params = test_match.group(1).strip()
            test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]

            # make sure every keyword parameter is passed as a keyword argument
            for correct_param in correct_param_list:
                if '=' in correct_param:  # only when the correct code uses a keyword argument
                    param_name = correct_param.split('=')[0].strip()
                    if not any(
                        param_name in test_param and '=' in test_param
                        for test_param in test_param_list
                    ):
                        return False  # the matching parameter is not a keyword argument

            return True  # all keyword arguments match

    return False  # no match found


def with_correct(answer_code: str, model_output: str) -> bool:
    """
    When the answer is a with-statement, check whether the model output is one too
    :param answer_code:
    :param model_output:
    :return:
    """
    # return True
    if not answer_code.startswith('with') and not model_output.startswith('with'):
        return True
    elif answer_code.startswith('with') and model_output.startswith('with'):
        return True
    else:
        return False


def compute_block_score_k(
    answer: str,
    model_output: list,
    k: int,
    model_filled_code,
    core_line_in_core_block,
    core_line_in_output_clear,
):
    """
    cdc requires all five conditions; em only requires the first one
    """
    c = 0
    n = len(model_output)
    for index, code in enumerate(model_output):
        if (
            re.search(rf'\b{re.escape(answer)}\b', code)
            and is_code_valid(model_filled_code[index])
            and is_correct_parameter_count(
                answer, core_line_in_core_block, core_line_in_output_clear[index]
            )
            and with_correct(core_line_in_core_block, core_line_in_output_clear[index])
            and check_keyword_parameters(
                answer, core_line_in_core_block, core_line_in_output_clear[index]
            )
        ):  # block
            # if re.search(rf'\b{re.escape(answer)}\b', code):#block
            c += 1
    if n - c < k:
        return 1.0

    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

    return score


def is_code_valid(code):
    try:
        compile(code, '<string>', 'exec')
        return True
    except Exception:
        return False


def compute_score_k(answer: str, model_output: list, k: int):
    c = 0
    n = len(model_output)
    for output in model_output:
        if '```python' in output:
            output = output.replace('```python', '')
            output = output.replace('```', '')
        # if answer == output:

        if re.search(rf'\b{re.escape(answer)}\b', output) and is_code_valid(output):
            c += 1
    if n - c < k:
        return 1.0

    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

    return score


k = 1  # cdc@k
json_name = 'VersiCode_migration.json'
task = 'migration'
folder_path = '../data/result_data/code_migration'

model_list = os.listdir(folder_path)
for model in model_list:
    # if model != 'gpt-4o':
    #     continue
    model_json_path = os.path.join(folder_path, model, json_name)
    with open(model_json_path, 'r', encoding='utf-8') as fr:
        lodict = json.load(fr)
    data_list = lodict

    score_list = []
    for data in data_list:
        answer = data['new_name']  # old -> new
        model_output = data['model_output_clear']  # old -> new

        model_filled_code = model_output
        # core_line_in_core_block = data['core_line_in_new_core_block']# old -> new
        core_line_in_core_block = data['core_line_in_code']  # old -> new
        core_line_in_output_clear = data['core_line_in_output_clear']  # old -> new

        score_list.append(
            compute_block_score_k(
                answer,
                model_output,
                k,
                model_filled_code,
                core_line_in_core_block,
                core_line_in_output_clear,
            )
        )

    final_score = sum(score_list) / len(score_list)
    print(f'{model}, {task} task, cdc@{k} score: {final_score}')
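The cdc@k computed above reduces to the familiar pass@k-style estimator: if c of the n sampled outputs satisfy every condition, then 1 - C(n-c, k) / C(n, k) is the probability that a uniformly chosen k-subset contains at least one passing sample. A toy check with made-up counts, not benchmark results:

```
import math

n, c, k = 10, 3, 3  # suppose 3 of 10 samples pass all five cdc conditions
print(1 - math.comb(n - c, k) / math.comb(n, k))  # 0.7083...
```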
@@ -1,175 +1,225 @@
"""
Calculate the cdc score for line and block
"""

import json
import math
import os
import re

# warnings.filterwarnings("ignore", category=SyntaxWarning)


def is_code_valid(code):
    try:
        compile(code, '<string>', 'exec')
        return True
    except Exception:
        return False


def is_correct_parameter_count(function_name, correct_code, test_code):
    """
    Check whether the number of parameters matches
    :param function_name:
    :param correct_code:
    :param test_code:
    :return:
    """
    # count the parameters in the correct code
    # return True
    pattern = rf'{function_name}\((.*?)\)'
    correct_match = re.search(pattern, correct_code)

    if correct_match:
        correct_params = correct_match.group(1).strip()
        correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]
        expected_count = len(correct_param_list)
    else:
        expected_count = 0  # no parameters means the expected count is 0

    # look for the function call in the code under test
    test_match = re.search(pattern, test_code)

    if test_match:
        test_params = test_match.group(1).strip()
        test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]
        return len(test_param_list) == expected_count  # compare parameter counts
    else:
        # no parentheses: just check that the function name appears in the string
        return expected_count == 0 and function_name in test_code


def check_keyword_parameters(function_name, correct_code, test_code):
    """
    Check whether keyword-argument assignments are used correctly
    :param function_name:
    :param correct_code:
    :param test_code:
    :return:
    """
    # regex-match the function call in the correct code
    # return True
    pattern = rf'{function_name}\((.*?)\)'
    correct_match = re.search(pattern, correct_code)

    if correct_match:
        correct_params = correct_match.group(1).strip()
        correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]

        # check the function call in the code under test
        test_match = re.search(pattern, test_code)

        if test_match:
            test_params = test_match.group(1).strip()
            test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]

            # make sure every keyword parameter is passed as a keyword argument
            for correct_param in correct_param_list:
                if '=' in correct_param:  # only when the correct code uses a keyword argument
                    param_name = correct_param.split('=')[0].strip()
                    if not any(
                        param_name in test_param and '=' in test_param
                        for test_param in test_param_list
                    ):
                        return False  # the matching parameter is not a keyword argument

            return True  # all keyword arguments match

    return False  # no match found


def with_correct(answer_code: str, model_output: str) -> bool:
    """
    When the answer is a with-statement, check whether the model output is one too
    :param answer_code:
    :param model_output:
    :return:
    """
    # return True
    if not answer_code.startswith('with') and not model_output.startswith('with'):
        return True
    elif answer_code.startswith('with') and model_output.startswith('with'):
        return True
    else:
        return False


def compute_line_score_k(
    answer: str, model_output: list, k: int, model_filled_code, core_line
):
    c = 0
    n = len(model_output)
    for index, code in enumerate(model_output):
        if (
            re.search(rf'\b{re.escape(answer)}\b', code)
            and is_code_valid(model_filled_code[index])
            and is_correct_parameter_count(answer, core_line, code)
            and with_correct(core_line, code)
            and check_keyword_parameters(answer, core_line, code)
        ):  # line
            c += 1
    if n - c < k:
        return 1.0

    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

    return score


def compute_block_score_k(
    answer: str,
    model_output: list,
    k: int,
    model_filled_code,
    core_line_in_core_block,
    core_line_in_output_clear,
):
    c = 0
    n = len(model_output)
    for index, code in enumerate(model_output):
        if (
            re.search(rf'\b{re.escape(answer)}\b', code)
            and is_code_valid(model_filled_code[index])
            and is_correct_parameter_count(
                answer, core_line_in_core_block, core_line_in_output_clear[index]
            )
            and with_correct(core_line_in_core_block, core_line_in_output_clear[index])
            and check_keyword_parameters(
                answer, core_line_in_core_block, core_line_in_output_clear[index]
            )
        ):  # block
            c += 1
    if n - c < k:
        return 1.0

    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

    return score


def compute_score_k(answer: str, model_output: list, k: int):
    c = 0
    n = len(model_output)
    for index, code in enumerate(model_output):
        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(
            code
        ):  # block
            # if re.search(rf'\b{re.escape(answer)}\b', code):#line
            c += 1
    if n - c < k:
        return 1.0

    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

    return score


k = 3  # cdc@k
task = 'block'  # line or block
json_name = f'Versicode_{task}_completion.json'

folder_path = f'../data/result_data/{task}_completion'
model_list = os.listdir(folder_path)

for model in model_list:
    model_json_path = os.path.join(folder_path, model, json_name)
    with open(model_json_path, 'r', encoding='utf-8') as fr:
        lodict = json.load(fr)
    data_list = lodict

    if task == 'line':
        score_list = []
        for data in data_list:
            answer = data['core_token']
            model_output = eval(data['model_output_clear'])
            model_filled_code = [
                data['masked_code'].replace('<mask>', i) for i in model_output
            ]
            core_line = data['core_line']
            score_list.append(
                compute_line_score_k(
                    answer, model_output, k, model_filled_code, core_line
                )
            )
    else:
        score_list = []
        for data in data_list:
            answer = data['core_token']
            model_output = eval(data['model_output_clear'])
            model_filled_code = eval(data['model_output_clear'])
            core_line = data['core_line']
            core_line_in_output_clear = data['core_line_in_output_clear']
            score_list.append(
                compute_block_score_k(
                    answer,
                    model_output,
                    k,
                    model_filled_code,
                    core_line,
                    core_line_in_output_clear,
                )
            )

    final_score = sum(score_list) / len(score_list)
    print(f'{model}, {task} completion task, cdc@{k} score: {final_score}')
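A quick sanity check of the call-shape heuristics shared by these scoring scripts, using invented snippets rather than benchmark data:

```
correct = 'model.fit(x, y, epochs=10)'
good = 'model.fit(a, b, epochs=5)'
bad = 'model.fit(a, b, 5)'

print(is_correct_parameter_count('fit', correct, good))  # True: 3 args vs 3 args
print(check_keyword_parameters('fit', correct, good))    # True: epochs= preserved
print(check_keyword_parameters('fit', correct, bad))     # False: epochs= dropped
```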
@@ -1,175 +1,209 @@
|
|||||||
"""
|
"""
|
||||||
Calculate the cdc score for line and block
|
Calculate the cdc score for line and block
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import re
|
import os
|
||||||
import warnings
|
import re
|
||||||
# warnings.filterwarnings("ignore", category=SyntaxWarning)
|
|
||||||
|
# warnings.filterwarnings("ignore", category=SyntaxWarning)
|
||||||
def is_code_valid(code):
|
|
||||||
|
|
||||||
try:
|
def is_code_valid(code):
|
||||||
compile(code, '<string>', 'exec')
|
try:
|
||||||
return True
|
compile(code, '<string>', 'exec')
|
||||||
except:
|
return True
|
||||||
return False
|
except Exception:
|
||||||
|
return False
|
||||||
def is_correct_parameter_count(function_name, correct_code, test_code):
|
|
||||||
"""
|
|
||||||
判断参数数量是否一致
|
def is_correct_parameter_count(function_name, correct_code, test_code):
|
||||||
:param function_name:
|
"""
|
||||||
:param correct_code:
|
判断参数数量是否一致
|
||||||
:param test_code:
|
:param function_name:
|
||||||
:return:
|
:param correct_code:
|
||||||
"""
|
:param test_code:
|
||||||
# 获取正确代码中的参数数量
|
:return:
|
||||||
# return True
|
"""
|
||||||
pattern = rf'{function_name}\((.*?)\)'
|
# 获取正确代码中的参数数量
|
||||||
correct_match = re.search(pattern, correct_code)
|
# return True
|
||||||
|
pattern = rf'{function_name}\((.*?)\)'
|
||||||
if correct_match:
|
correct_match = re.search(pattern, correct_code)
|
||||||
correct_params = correct_match.group(1).strip()
|
|
||||||
correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]
|
if correct_match:
|
||||||
expected_count = len(correct_param_list)
|
correct_params = correct_match.group(1).strip()
|
||||||
else:
|
correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]
|
||||||
expected_count = 0 # 如果没有参数,期望数量为0
|
expected_count = len(correct_param_list)
|
||||||
|
else:
|
||||||
# 在需要判断的代码中查找函数调用
|
expected_count = 0 # 如果没有参数,期望数量为0
|
||||||
test_match = re.search(pattern, test_code)
|
|
||||||
|
# 在需要判断的代码中查找函数调用
|
||||||
if test_match:
|
test_match = re.search(pattern, test_code)
|
||||||
test_params = test_match.group(1).strip()
|
|
||||||
test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]
|
if test_match:
|
||||||
return len(test_param_list) == expected_count # 检查参数数量
|
test_params = test_match.group(1).strip()
|
||||||
else:
|
test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]
|
||||||
# 如果没有括号,检查函数名是否在字符串中
|
return len(test_param_list) == expected_count # 检查参数数量
|
||||||
return expected_count == 0 and function_name in test_code
|
else:
|
||||||
|
# 如果没有括号,检查函数名是否在字符串中
|
||||||
def check_keyword_parameters(function_name, correct_code, test_code):
|
return expected_count == 0 and function_name in test_code
|
||||||
"""
|
|
||||||
判断关键词参数赋值是否正确使用
|
|
||||||
:param function_name:
|
def check_keyword_parameters(function_name, correct_code, test_code):
|
||||||
:param correct_code:
|
"""
|
||||||
:param test_code:
|
判断关键词参数赋值是否正确使用
|
||||||
:return:
|
:param function_name:
|
||||||
"""
|
:param correct_code:
|
||||||
# 正则表达式匹配正确代码中的函数调用
|
:param test_code:
|
||||||
# return True
|
:return:
|
||||||
pattern = rf'{function_name}\((.*?)\)'
|
"""
|
||||||
correct_match = re.search(pattern, correct_code)
|
# 正则表达式匹配正确代码中的函数调用
|
||||||
|
# return True
|
||||||
if correct_match:
|
pattern = rf'{function_name}\((.*?)\)'
|
||||||
correct_params = correct_match.group(1).strip()
|
correct_match = re.search(pattern, correct_code)
|
||||||
correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]
|
|
||||||
|
if correct_match:
|
||||||
# 检查待检测代码中的函数调用
|
correct_params = correct_match.group(1).strip()
|
||||||
test_match = re.search(pattern, test_code)
|
correct_param_list = [p.strip() for p in correct_params.split(',') if p.strip()]
|
||||||
|
|
||||||
if test_match:
|
# 检查待检测代码中的函数调用
|
||||||
test_params = test_match.group(1).strip()
|
test_match = re.search(pattern, test_code)
|
||||||
test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]
|
|
||||||
|
if test_match:
|
||||||
# 确保待检测的每个参数都以关键字参数形式赋值
|
test_params = test_match.group(1).strip()
|
||||||
for correct_param in correct_param_list:
|
test_param_list = [p.strip() for p in test_params.split(',') if p.strip()]
|
||||||
if '=' in correct_param: # 仅当正确代码中有关键词参数
|
|
||||||
param_name = correct_param.split('=')[0].strip()
|
# 确保待检测的每个参数都以关键字参数形式赋值
|
||||||
if not any(param_name in test_param and '=' in test_param for test_param in test_param_list):
|
for correct_param in correct_param_list:
|
||||||
return False # 如果对应参数不是关键词参数,则返回False
|
if '=' in correct_param: # 仅当正确代码中有关键词参数
|
||||||
|
param_name = correct_param.split('=')[0].strip()
|
||||||
return True # 所有关键字参数匹配
|
if not any(
|
||||||
|
param_name in test_param and '=' in test_param
|
||||||
return False # 如果没有匹配,返回False
|
for test_param in test_param_list
|
||||||
|
):
|
||||||
def with_correct(answer_code:str, model_output:str)->bool:
|
return False # 如果对应参数不是关键词参数,则返回False
|
||||||
"""
|
|
||||||
当answer是with结构时,判断模型生成的是不是with结构
|
return True # 所有关键字参数匹配
|
||||||
:param answer_code:
|
|
||||||
:param model_output:
|
return False # 如果没有匹配,返回False
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# return True
|
def with_correct(answer_code: str, model_output: str) -> bool:
|
||||||
if not answer_code.startswith('with') and not model_output.startswith('with'):
|
"""
|
||||||
return True
|
当answer是with结构时,判断模型生成的是不是with结构
|
||||||
elif answer_code.startswith('with') and model_output.startswith('with'):
|
:param answer_code:
|
||||||
return True
|
:param model_output:
|
||||||
else:
|
:return:
|
||||||
return False
|
"""
|
||||||
|
# return True
|
||||||
def compute_line_score_k(answer:str, model_output:list, k:int, model_filled_code, core_line):
|
if not answer_code.startswith('with') and not model_output.startswith('with'):
|
||||||
|
return True
|
||||||
c = 0
|
elif answer_code.startswith('with') and model_output.startswith('with'):
|
||||||
n = len(model_output)
|
return True
|
||||||
for index, code in enumerate(model_output):
|
else:
|
||||||
if re.search(rf'\b{re.escape(answer)}\b', code):#line
|
return False
|
||||||
c += 1
|
|
||||||
if n-c < k:
|
|
||||||
return 1.0
|
def compute_line_score_k(
|
||||||
|
answer: str, model_output: list, k: int, model_filled_code, core_line
|
||||||
score = 1 - (math.comb(n - c, k))/(math.comb(n, k))
|
):
|
||||||
|
c = 0
|
||||||
return score
|
n = len(model_output)
|
||||||
|
for index, code in enumerate(model_output):
|
||||||
def compute_block_score_k(answer:str, model_output:list, k:int, model_filled_code, core_line_in_core_block, core_line_in_output_clear):
|
if re.search(rf'\b{re.escape(answer)}\b', code): # line
|
||||||
|
c += 1
|
||||||
c = 0
|
if n - c < k:
|
||||||
n = len(model_output)
|
return 1.0
|
||||||
for index, code in enumerate(model_output):
|
|
||||||
if re.search(rf'\b{re.escape(answer)}\b', code):#block
|
score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))
|
||||||
c += 1
|
|
||||||
if n-c < k:
|
return score
|
||||||
return 1.0
|
|
||||||
|
|
||||||
score = 1 - (math.comb(n - c, k))/(math.comb(n, k))
|
def compute_block_score_k(
|
||||||
|
answer: str,
|
||||||
return score
|
model_output: list,
|
||||||
|
k: int,
|
||||||
def compute_score_k(answer:str, model_output:list, k:int):
|
model_filled_code,
|
||||||
|
core_line_in_core_block,
|
||||||
c = 0
|
core_line_in_output_clear,
|
||||||
n = len(model_output)
|
):
|
||||||
for index, code in enumerate(model_output):
|
c = 0
|
||||||
if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(code):#block
|
n = len(model_output)
|
||||||
# if re.search(rf'\b{re.escape(answer)}\b', code):#line
|
for index, code in enumerate(model_output):
|
||||||
c += 1
|
if re.search(rf'\b{re.escape(answer)}\b', code): # block
|
||||||
if n-c < k:
|
c += 1
     if n - c < k:
         return 1.0
-    score = 1 - (math.comb(n - c, k))/(math.comb(n, k))
+    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))
     return score


 def compute_score_k(answer: str, model_output: list, k: int):
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
-        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(code):  # block
+        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(
+            code
+        ):  # block
             # if re.search(rf'\b{re.escape(answer)}\b', code):#line
             c += 1
     if n - c < k:
         return 1.0
-    score = 1 - (math.comb(n - c, k))/(math.comb(n, k))
+    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))
     return score


-k = 3 #em@k
-task = 'block' # line or block
-json_name = f"Versicode_{task}_completion.json"
+k = 3  # em@k
+task = 'block'  # line or block
+json_name = f'Versicode_{task}_completion.json'
 folder_path = f'../data/result_data/{task}_completion'
 model_list = os.listdir(folder_path)

 for model in model_list:
     model_json_path = os.path.join(folder_path, model, json_name)
-    with open(model_json_path, 'r', encoding='utf-8')as fr:
+    with open(model_json_path, 'r', encoding='utf-8') as fr:
         lodict = json.load(fr)
     data_list = lodict

     if task == 'line':
         score_list = []
         for data in data_list:
             answer = data['core_token']
             model_output = eval(data['model_output_clear'])
-            model_filled_code = [data['masked_code'].replace('<mask>', i) for i in model_output]
+            model_filled_code = [
+                data['masked_code'].replace('<mask>', i) for i in model_output
+            ]
             core_line = data['core_line']
-            score_list.append(compute_line_score_k(answer, model_output, k, model_filled_code, core_line))
+            score_list.append(
+                compute_line_score_k(
+                    answer, model_output, k, model_filled_code, core_line
+                )
+            )
     else:
         score_list = []
         for data in data_list:
             answer = data['core_token']
             model_output = eval(data['model_output_clear'])
             model_filled_code = eval(data['model_output_clear'])
             core_line = data['core_line']
             core_line_in_output_clear = data['core_line_in_output_clear']
-            score_list.append(compute_block_score_k(answer, model_output, k, model_filled_code, core_line, core_line_in_output_clear))
+            score_list.append(
+                compute_block_score_k(
+                    answer,
+                    model_output,
+                    k,
+                    model_filled_code,
+                    core_line,
+                    core_line_in_output_clear,
+                )
+            )

-    final_score = sum(score_list)/len(score_list)
-    print(f"{model}, {task} completion task, em@{k} score: {final_score}")
+    final_score = sum(score_list) / len(score_list)
+    print(f'{model}, {task} completion task, em@{k} score: {final_score}')
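For reference, both versions of `compute_score_k` above implement the standard unbiased pass@k-style estimator: with n samples of which c are exact matches, the probability that at least one of k draws contains a match is 1 - C(n-c, k) / C(n, k). A minimal self-contained sketch of that formula (the sample numbers are invented for illustration):

import math


def em_at_k(n: int, c: int, k: int) -> float:
    # Probability that at least one of k draws from n samples is an exact match.
    if n - c < k:
        # Fewer than k incorrect samples: every k-subset must contain a match.
        return 1.0
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)


print(em_at_k(n=10, c=3, k=3))  # 1 - C(7,3)/C(10,3) = 0.7083...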
@@ -1,107 +1,99 @@
 """
 Find the line of code generated by the model using the block in the version code
 """
-import os
-import re
-import json
-import random
+
+import json
+import os
+import random
+import re


 def process_line_mask(code_snippet, core_token):
     if not core_token:
         return None, None

     replaced_lines = {}
-    lines = code_snippet.split("\n")
+    lines = code_snippet.split('\n')

     in_multi_line_comment = False

     for i, line in enumerate(lines):
         if in_multi_line_comment:
-            if ('"""' in line or "'''" in line) and not re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
+            if ('"""' in line or "'''" in line) and not re.findall(
+                r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
+            ):
                 in_multi_line_comment = False
             continue
-        elif line.strip().startswith("#"):
+        elif line.strip().startswith('#'):
             continue
         elif re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
             continue
-        elif ('"""' in line or "'''" in line) and not re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
+        elif ('"""' in line or "'''" in line) and not re.findall(
+            r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
+        ):
             in_multi_line_comment = True
             continue
         else:
             if re.search(r'\bdef\s+task_function\b', line):
                 continue

             if re.search(r'\b{}\b(?!\s*=)'.format(re.escape(core_token)), line):
                 replaced_lines.update({i: line})

     if replaced_lines:
         random_line_location = random.choice(list(replaced_lines.keys()))

         masked_line = lines[random_line_location]
         leading_spaces = re.match(r'^\s*', masked_line).group(0)
         masked_line = masked_line.strip()
-        lines[random_line_location] = leading_spaces + "<line_mask>"
+        lines[random_line_location] = leading_spaces + '<line_mask>'

         masked_code = '\n'.join(lines)

         return masked_code, masked_line

     return None, None


 def load_json(file_path):
     with open(file_path, 'r', encoding='utf-8') as f:
         data = json.load(f)
     return data


 def save_json(file_path, data):
     with open(file_path, 'w', encoding='utf-8') as f:
         json.dump(data, f, ensure_ascii=False, indent=4)


-if __name__ == "__main__":
+if __name__ == '__main__':
     model_list = os.listdir('../data/result_data/block_completion')
     for model in model_list:
         input_json_file = f'../data/result_data/block_completion/{model}/VersiCode_block_completion.json'
         output_json_file = input_json_file
         data = load_json(input_json_file)

         for item in data:
             core_token = item['core_token']
             code = item['code']

             _, core_line_in_code = process_line_mask(code, core_token)
             if core_line_in_code:
                 item['core_line_in_code'] = core_line_in_code
             else:
-                item['core_line_in_code'] = "N/A"
+                item['core_line_in_code'] = 'N/A'

             model_output_clear = item['model_output_clear']
             core_line_in_output_list = []

             for entry in eval(model_output_clear):
                 _, core_line_in_output = process_line_mask(entry, core_token)
                 if core_line_in_output:
                     core_line_in_output_list.append(core_line_in_output)
                 else:
-                    core_line_in_output_list.append("N/A")
+                    core_line_in_output_list.append('N/A')

             item['core_line_in_output_clear'] = core_line_in_output_list

         save_json(output_json_file, data)
-        print("Done!")
+        print('Done!')
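For context, `process_line_mask` above picks one random line that uses the core token (skipping comments, docstring lines, and the `def task_function` line itself), swaps it for a `<line_mask>` placeholder, and returns both the masked code and the original line. A small usage sketch, with an invented snippet and token:

snippet = (
    'import torch\n'
    'def task_function(x):\n'
    '    y = torch.relu(x)\n'
    '    return y\n'
)
masked_code, masked_line = process_line_mask(snippet, 'relu')
print(masked_line)  # 'y = torch.relu(x)' (leading spaces stripped)
print(masked_code)  # the snippet with that line replaced by '<line_mask>'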
@@ -1,108 +1,102 @@
 """
 Find the line of code generated by the model using the block in the version code
 """
-import os
-import re
-import json
-import random
+
+import json
+import os
+import random
+import re


 def process_line_mask(code_snippet, core_token):
     if not core_token:
         return None, None

     replaced_lines = {}
-    lines = code_snippet.split("\n")
+    lines = code_snippet.split('\n')

     in_multi_line_comment = False

     for i, line in enumerate(lines):
         if in_multi_line_comment:
-            if ('"""' in line or "'''" in line) and not re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
+            if ('"""' in line or "'''" in line) and not re.findall(
+                r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
+            ):
                 in_multi_line_comment = False
             continue
-        elif line.strip().startswith("#"):
+        elif line.strip().startswith('#'):
             continue
         elif re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
             continue
-        elif ('"""' in line or "'''" in line) and not re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
+        elif ('"""' in line or "'''" in line) and not re.findall(
+            r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
+        ):
             in_multi_line_comment = True
             continue
         else:
             if re.search(r'\bdef\s+task_function\b', line):
                 continue

             if re.search(r'\b{}\b(?!\s*=)'.format(re.escape(core_token)), line):
                 replaced_lines.update({i: line})

     if replaced_lines:
         random_line_location = random.choice(list(replaced_lines.keys()))

         masked_line = lines[random_line_location]
         leading_spaces = re.match(r'^\s*', masked_line).group(0)
         masked_line = masked_line.strip()
-        lines[random_line_location] = leading_spaces + "<line_mask>"
+        lines[random_line_location] = leading_spaces + '<line_mask>'

         masked_code = '\n'.join(lines)

         return masked_code, masked_line

     return None, None


 def load_json(file_path):
     with open(file_path, 'r', encoding='utf-8') as f:
         data = json.load(f)
     return data


 def save_json(file_path, data):
     with open(file_path, 'w', encoding='utf-8') as f:
         json.dump(data, f, ensure_ascii=False, indent=4)


-if __name__ == "__main__":
+if __name__ == '__main__':
     model_list = os.listdir('../data/result_data/code_migration')
     for model in model_list:
-        input_json_file = f'../data/result_data/code_migration/{model}/VersiCode_migration.json'
+        input_json_file = (
+            f'../data/result_data/code_migration/{model}/VersiCode_migration.json'
+        )
         output_json_file = input_json_file
         data = load_json(input_json_file)

         for item in data:
             core_token = item['old_name']
             code = item['old_code']

             _, core_line_in_code = process_line_mask(code, core_token)
             if core_line_in_code:
                 item['core_line_in_code'] = core_line_in_code
             else:
-                item['core_line_in_code'] = "N/A"
+                item['core_line_in_code'] = 'N/A'

             model_output_clear = item['model_output_clear']
             core_line_in_output_list = []

             core_token = item['new_name']
             for entry in eval(model_output_clear):
                 _, core_line_in_output = process_line_mask(entry, core_token)
                 if core_line_in_output:
                     core_line_in_output_list.append(core_line_in_output)
                 else:
-                    core_line_in_output_list.append("N/A")
+                    core_line_in_output_list.append('N/A')

             item['core_line_in_output_clear'] = core_line_in_output_list

         save_json(output_json_file, data)
-        print("Done!")
+        print('Done!')
@@ -1,36 +1,38 @@
 """
 Clear the<start>and<end>generated by the model in inference
 """

 import json
-import os

 model_name = ''
 task = 'block_completion'

-result_path = f'../data/result_data/{task}/{model_name}/VersiCode_block_completion.json' #Modify the file according to the task format
+result_path = f'../data/result_data/{task}/{model_name}/VersiCode_block_completion.json'  # Modify the file according to the task format

-with open(result_path, 'r', encoding='utf-8')as fr:
+with open(result_path, 'r', encoding='utf-8') as fr:
     lodict = json.load(fr)
 data_dict = lodict
 data_list = data_dict

 for data in data_list:
     temp_list = []
     model_output_list = eval(data['model_output'])
     for output in model_output_list:
-        if "<start>" in output and "<end>" in output:
-            start_index = output.find("<start>") + len("<start>")
-            end_index = output.find("<end>")
-            content = output[start_index:end_index].replace('```python', '').replace('```', '')
+        if '<start>' in output and '<end>' in output:
+            start_index = output.find('<start>') + len('<start>')
+            end_index = output.find('<end>')
+            content = (
+                output[start_index:end_index]
+                .replace('```python', '')
+                .replace('```', '')
+            )
         else:
-            content = "no_answer"
+            content = 'no_answer'

         temp_list.append(content)

     data['model_output_clear'] = str(temp_list)

-with open(result_path, 'w', encoding='utf-8')as fw:
+with open(result_path, 'w', encoding='utf-8') as fw:
     json.dump(data_dict, fw, indent=4, ensure_ascii=False)
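This script slices each sampled output between the `<start>` and `<end>` sentinels and strips Markdown code fences; anything lacking either sentinel becomes `'no_answer'`. A standalone sketch of the extraction, with an invented output string:

output = "<start>```python\nprint('hi')\n```<end>"
if '<start>' in output and '<end>' in output:
    start_index = output.find('<start>') + len('<start>')
    end_index = output.find('<end>')
    content = (
        output[start_index:end_index].replace('```python', '').replace('```', '')
    )
else:
    content = 'no_answer'
print(repr(content))  # "\nprint('hi')\n"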
@@ -16,8 +16,8 @@ vi.mock("react-i18next", async () => {
       if (i18nKey === "SETTINGS$API_KEYS_DESCRIPTION") {
         return (
           <span>
             API keys allow you to authenticate with the OpenHands API programmatically.
             Keep your API keys secure; anyone with your API key can access your account.
             For more information on how to use the API, see our {components.a}
           </span>
         );
@@ -48,7 +48,7 @@ describe("ApiKeysManager", () => {

   it("should render the API documentation link", () => {
     renderComponent();

     // Find the link to the API documentation
     const link = screen.getByRole("link");
     expect(link).toBeInTheDocument();
@@ -56,4 +56,4 @@ describe("ApiKeysManager", () => {
     expect(link).toHaveAttribute("target", "_blank");
     expect(link).toHaveAttribute("rel", "noopener noreferrer");
   });
 });
@@ -39,4 +39,4 @@ describe("Check for hardcoded English strings in Home components", () => {
       expect(text).not.toContain(str);
     });
   });
 });
@@ -5,24 +5,23 @@
  * Mock Service Worker.
  * @see https://github.com/mswjs/msw
  * - Please do NOT modify this file.
- * - Please do NOT serve this file on production.
  */

-const PACKAGE_VERSION = '2.8.4'
-const INTEGRITY_CHECKSUM = '00729d72e3b82faf54ca8b9621dbb96f'
+const PACKAGE_VERSION = '2.10.2'
+const INTEGRITY_CHECKSUM = 'f5825c521429caf22a4dd13b66e243af'
 const IS_MOCKED_RESPONSE = Symbol('isMockedResponse')
 const activeClientIds = new Set()

-self.addEventListener('install', function () {
+addEventListener('install', function () {
   self.skipWaiting()
 })

-self.addEventListener('activate', function (event) {
+addEventListener('activate', function (event) {
   event.waitUntil(self.clients.claim())
 })

-self.addEventListener('message', async function (event) {
-  const clientId = event.source.id
+addEventListener('message', async function (event) {
+  const clientId = Reflect.get(event.source || {}, 'id')

   if (!clientId || !self.clients) {
     return
@@ -94,17 +93,18 @@ self.addEventListener('message', async function (event) {
   }
 })

-self.addEventListener('fetch', function (event) {
-  const { request } = event
-
+addEventListener('fetch', function (event) {
   // Bypass navigation requests.
-  if (request.mode === 'navigate') {
+  if (event.request.mode === 'navigate') {
     return
   }

   // Opening the DevTools triggers the "only-if-cached" request
   // that cannot be handled by the worker. Bypass such requests.
-  if (request.cache === 'only-if-cached' && request.mode !== 'same-origin') {
+  if (
+    event.request.cache === 'only-if-cached' &&
+    event.request.mode !== 'same-origin'
+  ) {
     return
   }

@@ -115,48 +115,62 @@ self.addEventListener('fetch', function (event) {
     return
   }

-  // Generate unique request ID.
   const requestId = crypto.randomUUID()
   event.respondWith(handleRequest(event, requestId))
 })

+/**
+ * @param {FetchEvent} event
+ * @param {string} requestId
+ */
 async function handleRequest(event, requestId) {
   const client = await resolveMainClient(event)
+  const requestCloneForEvents = event.request.clone()
   const response = await getResponse(event, client, requestId)

   // Send back the response clone for the "response:*" life-cycle events.
   // Ensure MSW is active and ready to handle the message, otherwise
   // this message will pend indefinitely.
   if (client && activeClientIds.has(client.id)) {
-    ;(async function () {
-      const responseClone = response.clone()
+    const serializedRequest = await serializeRequest(requestCloneForEvents)

-      sendToClient(
-        client,
-        {
-          type: 'RESPONSE',
-          payload: {
-            requestId,
-            isMockedResponse: IS_MOCKED_RESPONSE in response,
-            type: responseClone.type,
-            status: responseClone.status,
-            statusText: responseClone.statusText,
-            body: responseClone.body,
-            headers: Object.fromEntries(responseClone.headers.entries()),
-          },
-        },
-        [responseClone.body],
-      )
-    })()
+    // Clone the response so both the client and the library could consume it.
+    const responseClone = response.clone()
+
+    sendToClient(
+      client,
+      {
+        type: 'RESPONSE',
+        payload: {
+          isMockedResponse: IS_MOCKED_RESPONSE in response,
+          request: {
+            id: requestId,
+            ...serializedRequest,
+          },
+          response: {
+            type: responseClone.type,
+            status: responseClone.status,
+            statusText: responseClone.statusText,
+            headers: Object.fromEntries(responseClone.headers.entries()),
+            body: responseClone.body,
+          },
+        },
+      },
+      responseClone.body ? [serializedRequest.body, responseClone.body] : [],
+    )
   }

   return response
 }

-// Resolve the main client for the given event.
-// Client that issues a request doesn't necessarily equal the client
-// that registered the worker. It's with the latter the worker should
-// communicate with during the response resolving phase.
+/**
+ * Resolve the main client for the given event.
+ * Client that issues a request doesn't necessarily equal the client
+ * that registered the worker. It's with the latter the worker should
+ * communicate with during the response resolving phase.
+ * @param {FetchEvent} event
+ * @returns {Promise<Client | undefined>}
+ */
 async function resolveMainClient(event) {
   const client = await self.clients.get(event.clientId)

@@ -184,12 +198,16 @@ async function resolveMainClient(event) {
   })
 }

+/**
+ * @param {FetchEvent} event
+ * @param {Client | undefined} client
+ * @param {string} requestId
+ * @returns {Promise<Response>}
+ */
 async function getResponse(event, client, requestId) {
-  const { request } = event
-
   // Clone the request because it might've been already used
   // (i.e. its body has been read and sent to the client).
-  const requestClone = request.clone()
+  const requestClone = event.request.clone()

   function passthrough() {
     // Cast the request headers to a new Headers instance
@@ -230,29 +248,17 @@ async function getResponse(event, client, requestId) {
   }

   // Notify the client that a request has been intercepted.
-  const requestBuffer = await request.arrayBuffer()
+  const serializedRequest = await serializeRequest(event.request)
   const clientMessage = await sendToClient(
     client,
     {
       type: 'REQUEST',
       payload: {
         id: requestId,
-        url: request.url,
-        mode: request.mode,
-        method: request.method,
-        headers: Object.fromEntries(request.headers.entries()),
-        cache: request.cache,
-        credentials: request.credentials,
-        destination: request.destination,
-        integrity: request.integrity,
-        redirect: request.redirect,
-        referrer: request.referrer,
-        referrerPolicy: request.referrerPolicy,
-        body: requestBuffer,
-        keepalive: request.keepalive,
+        ...serializedRequest,
       },
     },
-    [requestBuffer],
+    [serializedRequest.body],
   )

   switch (clientMessage.type) {
@@ -268,6 +274,12 @@ async function getResponse(event, client, requestId) {
     return passthrough()
   }

+/**
+ * @param {Client} client
+ * @param {any} message
+ * @param {Array<Transferable>} transferrables
+ * @returns {Promise<any>}
+ */
 function sendToClient(client, message, transferrables = []) {
   return new Promise((resolve, reject) => {
     const channel = new MessageChannel()
@@ -280,14 +292,18 @@ function sendToClient(client, message, transferrables = []) {
       resolve(event.data)
     }

-    client.postMessage(
-      message,
-      [channel.port2].concat(transferrables.filter(Boolean)),
-    )
+    client.postMessage(message, [
+      channel.port2,
+      ...transferrables.filter(Boolean),
+    ])
   })
 }

-async function respondWithMock(response) {
+/**
+ * @param {Response} response
+ * @returns {Response}
+ */
+function respondWithMock(response) {
   // Setting response status code to 0 is a no-op.
   // However, when responding with a "Response.error()", the produced Response
   // instance will have status code set to 0. Since it's not possible to create
@@ -305,3 +321,24 @@ async function respondWithMock(response) {

   return mockedResponse
 }
+
+/**
+ * @param {Request} request
+ */
+async function serializeRequest(request) {
+  return {
+    url: request.url,
+    mode: request.mode,
+    method: request.method,
+    headers: Object.fromEntries(request.headers.entries()),
+    cache: request.cache,
+    credentials: request.credentials,
+    destination: request.destination,
+    integrity: request.integrity,
+    redirect: request.redirect,
+    referrer: request.referrer,
+    referrerPolicy: request.referrerPolicy,
+    body: await request.arrayBuffer(),
+    keepalive: request.keepalive,
+  }
+}
@@ -60,11 +60,11 @@ Object.entries(translationJson).forEach(([key, translations]) => {
 if (Object.keys(missingTranslations).length > 0) {
   console.error('\x1b[31m%s\x1b[0m', 'ERROR: Missing translations detected');
   console.error(`Found ${Object.keys(missingTranslations).length} translation keys with missing languages:`);

   Object.entries(missingTranslations).forEach(([key, langs]) => {
     console.error(`- Key "${key}" is missing translations for: ${langs.join(', ')}`);
   });

   console.error('\nPlease add the missing translations before committing.');
 }
@@ -72,11 +72,11 @@ if (Object.keys(missingTranslations).length > 0) {
 if (Object.keys(extraLanguages).length > 0) {
   console.error('\x1b[31m%s\x1b[0m', 'ERROR: Extra languages detected');
   console.error(`Found ${Object.keys(extraLanguages).length} translation keys with extra languages not in AvailableLanguages:`);

   Object.entries(extraLanguages).forEach(([key, langs]) => {
     console.error(`- Key "${key}" has translations for unsupported languages: ${langs.join(', ')}`);
   });

   console.error('\nPlease remove the extra languages before committing.');
 }
@@ -85,4 +85,4 @@ if (hasErrors) {
   process.exit(1);
 } else {
   console.log('\x1b[32m%s\x1b[0m', 'All translation keys have complete language coverage!');
 }
@@ -19,10 +19,10 @@ vi.mock("react-i18next", () => ({

 describe("RepositorySelectionForm", () => {
   const mockOnRepoSelection = vi.fn();

   beforeEach(() => {
     vi.resetAllMocks();

     // Mock the hooks with default values
     (useUserRepositories as any).mockReturnValue({
       data: [
@@ -32,7 +32,7 @@ describe("RepositorySelectionForm", () => {
       isLoading: false,
       isError: false,
     });

     (useRepositoryBranches as any).mockReturnValue({
       data: [
         { name: "main" },
@@ -41,90 +41,90 @@ describe("RepositorySelectionForm", () => {
       isLoading: false,
       isError: false,
     });

     (useCreateConversation as any).mockReturnValue({
       mutate: vi.fn(),
       isPending: false,
       isSuccess: false,
     });

     (useIsCreatingConversation as any).mockReturnValue(false);
   });

   it("should clear selected branch when input is empty", async () => {
     render(<RepositorySelectionForm onRepoSelection={mockOnRepoSelection} />);

     // First select a repository to enable the branch dropdown
     const repoDropdown = screen.getByTestId("repository-dropdown");
     fireEvent.change(repoDropdown, { target: { value: "test/repo1" } });

     // Get the branch dropdown and verify it's enabled
     const branchDropdown = screen.getByTestId("branch-dropdown");
     expect(branchDropdown).not.toBeDisabled();

     // Simulate deleting all text in the branch input
     fireEvent.change(branchDropdown, { target: { value: "" } });

     // Verify the branch input is cleared (no selected branch)
     expect(branchDropdown).toHaveValue("");
   });

   it("should clear selected branch when input contains only whitespace", async () => {
     render(<RepositorySelectionForm onRepoSelection={mockOnRepoSelection} />);

     // First select a repository to enable the branch dropdown
     const repoDropdown = screen.getByTestId("repository-dropdown");
     fireEvent.change(repoDropdown, { target: { value: "test/repo1" } });

     // Get the branch dropdown and verify it's enabled
     const branchDropdown = screen.getByTestId("branch-dropdown");
     expect(branchDropdown).not.toBeDisabled();

     // Simulate entering only whitespace in the branch input
     fireEvent.change(branchDropdown, { target: { value: " " } });

     // Verify the branch input is cleared (no selected branch)
     expect(branchDropdown).toHaveValue("");
   });

   it("should keep branch empty after being cleared even with auto-selection", async () => {
     render(<RepositorySelectionForm onRepoSelection={mockOnRepoSelection} />);

     // First select a repository to enable the branch dropdown
     const repoDropdown = screen.getByTestId("repository-dropdown");
     fireEvent.change(repoDropdown, { target: { value: "test/repo1" } });

     // Get the branch dropdown and verify it's enabled
     const branchDropdown = screen.getByTestId("branch-dropdown");
     expect(branchDropdown).not.toBeDisabled();

     // The branch should be auto-selected to "main" initially
     expect(branchDropdown).toHaveValue("main");

     // Simulate deleting all text in the branch input
     fireEvent.change(branchDropdown, { target: { value: "" } });

     // Verify the branch input is cleared (no selected branch)
     expect(branchDropdown).toHaveValue("");

     // Trigger a re-render by changing something else
     fireEvent.change(repoDropdown, { target: { value: "test/repo2" } });
     fireEvent.change(repoDropdown, { target: { value: "test/repo1" } });

     // The branch should be auto-selected to "main" again after repo change
     expect(branchDropdown).toHaveValue("main");

     // Clear it again
     fireEvent.change(branchDropdown, { target: { value: "" } });

     // Verify it stays empty
     expect(branchDropdown).toHaveValue("");

     // Simulate a component update without changing repos
     // This would normally trigger the useEffect if our fix wasn't working
     fireEvent.blur(branchDropdown);

     // Verify it still stays empty
     expect(branchDropdown).toHaveValue("");
   });
 });
@@ -208,7 +208,9 @@ Note:
             # for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
             # initialize and retrieve the first observation by issuing an noop OP
             # For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites
-            return BrowseInteractiveAction(browser_actions='noop(1000)', return_axtree=True)
+            return BrowseInteractiveAction(
+                browser_actions='noop(1000)', return_axtree=True
+            )

         for event in state.view:
             if isinstance(event, BrowseInteractiveAction):
@@ -54,6 +54,7 @@ class MCPStdioServerConfig(BaseModel):
             and set(self.env.items()) == set(other.env.items())
         )

+
 class MCPSHTTPServerConfig(BaseModel):
     url: str
     api_key: str | None = None
@@ -39,6 +39,7 @@ class GitHubService(BaseGitService, GitService):

     The class is instantiated via get_impl() in openhands.server.shared.py.
     """
+
     BASE_URL = 'https://api.github.com'
     token: SecretStr = SecretStr('')
     refresh = False
@@ -508,7 +509,6 @@ class GitHubService(BaseGitService, GitService):
         return response['html_url']


-
 github_service_cls = os.environ.get(
     'OPENHANDS_GITHUB_SERVICE_CLS',
     'openhands.integrations.github.github_service.GitHubService',
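The `github_service_cls` lookup above lets a deployment swap in a custom service class by dotted path, which `get_impl` then resolves. A hedged sketch of how such a resolver typically works (the real `get_impl` lives in OpenHands' import utilities and may differ in signature and detail):

import importlib
import os


def get_impl(base_cls: type, dotted_path: str) -> type:
    # Resolve 'pkg.module.ClassName' to the class object and check that it
    # subclasses the expected base. Sketch only; the body is an assumption.
    module_name, _, class_name = dotted_path.rpartition('.')
    cls = getattr(importlib.import_module(module_name), class_name)
    assert issubclass(cls, base_cls), f'{dotted_path} must subclass {base_cls.__name__}'
    return cls


github_service_cls = os.environ.get(
    'OPENHANDS_GITHUB_SERVICE_CLS',
    'openhands.integrations.github.github_service.GitHubService',
)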
@@ -32,6 +32,7 @@ class GitLabService(BaseGitService, GitService):

     The class is instantiated via get_impl() in openhands.server.shared.py.
     """
+
     BASE_URL = 'https://gitlab.com/api/v4'
     GRAPHQL_URL = 'https://gitlab.com/api/graphql'
     token: SecretStr = SecretStr('')
@@ -482,9 +483,7 @@ class GitLabService(BaseGitService, GitService):

         # Set default description if none provided
         if not description:
-            description = (
-                f'Merging changes from {source_branch} into {target_branch}'
-            )
+            description = f'Merging changes from {source_branch} into {target_branch}'

         # Prepare the request payload
         payload = {
@@ -499,11 +498,9 @@ class GitLabService(BaseGitService, GitService):
             url=url, params=payload, method=RequestMethod.POST
         )

-
         return response['web_url']


-
 gitlab_service_cls = os.environ.get(
     'OPENHANDS_GITLAB_SERVICE_CLS',
     'openhands.integrations.gitlab.gitlab_service.GitLabService',
@@ -1 +1 @@
-{{ issue_comment }}
+{{ issue_comment }}
@@ -1 +1 @@
-Please fix issue number #{{ issue_number }} in your repository.
+Please fix issue number #{{ issue_number }} in your repository.
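These one-line prompt templates are Jinja-style; the resolver fills in the issue fields at render time. A minimal sketch of such a render (the context value is invented for illustration):

from jinja2 import Template

prompt = Template(
    'Please fix issue number #{{ issue_number }} in your repository.'
).render(issue_number=1234)
print(prompt)  # Please fix issue number #1234 in your repository.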
@@ -1,7 +1,7 @@
 import os
+import typing
 from functools import lru_cache
 from typing import Callable
-import typing
 from uuid import UUID

 import docker
@@ -283,7 +283,9 @@ class DockerRuntime(ActionExecutionClient):
         self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'

         use_host_network = self.config.sandbox.use_host_network
-        network_mode: typing.Literal['host'] | None = 'host' if use_host_network else None
+        network_mode: typing.Literal['host'] | None = (
+            'host' if use_host_network else None
+        )

         # Initialize port mappings
         port_mapping: dict[str, list[dict[str, str]]] | None = None
@@ -356,7 +358,7 @@ class DockerRuntime(ActionExecutionClient):

         try:
             if self.runtime_container_image is None:
-                raise ValueError("Runtime container image is not set")
+                raise ValueError('Runtime container image is not set')
             self.container = self.docker_client.containers.run(
                 self.runtime_container_image,
                 command=command,
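The `network_mode` reflow above concerns how DockerRuntime launches its sandbox container: on the host network no ports are published, otherwise a port mapping is passed. A hedged sketch of that choice with the Docker SDK (the image and port are placeholders, not OpenHands defaults):

import docker

client = docker.from_env()

use_host_network = False
network_mode = 'host' if use_host_network else None
# Publish the container port only when not sharing the host network.
ports = None if use_host_network else {'8000/tcp': 8000}

container = client.containers.run(
    'python:3.12-slim',
    command='python -m http.server 8000',
    network_mode=network_mode,
    ports=ports,
    detach=True,
)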
@@ -363,7 +363,7 @@ class RemoteRuntime(ActionExecutionClient):
         self._session_api_key = start_response['session_api_key']
         self.log(
             'debug',
-            f'Session API key setted',
+            'Session API key set',
         )

     @property
@@ -59,7 +59,7 @@ class MCPProxyManager:
         """
         if len(self.config['mcpServers']) == 0:
             logger.info(
-                f"No MCP servers configured for FastMCP Proxy, skipping initialization."
+                'No MCP servers configured for FastMCP Proxy, skipping initialization.'
             )
             return None

@@ -70,7 +70,7 @@ class MCPProxyManager:
             api_key=self.api_key,
         )

-        logger.info(f"FastMCP Proxy initialized successfully")
+        logger.info('FastMCP Proxy initialized successfully')

     async def mount_to_app(
         self, app: FastAPI, allow_origins: Optional[list[str]] = None
@@ -83,9 +83,7 @@ class MCPProxyManager:
             allow_origins: List of allowed origins for CORS
         """
         if len(self.config['mcpServers']) == 0:
-            logger.info(
-                f"No MCP servers configured for FastMCP Proxy, skipping mount."
-            )
+            logger.info('No MCP servers configured for FastMCP Proxy, skipping mount.')
             return

         if not self.proxy:
@@ -101,8 +99,7 @@ class MCPProxyManager:
             app.routes.remove('/mcp')

         app.mount('/', mcp_app)
-        logger.info(f"Mounted FastMCP Proxy app at /mcp")
-
+        logger.info('Mounted FastMCP Proxy app at /mcp')

     async def update_and_remount(
         self,
@@ -122,10 +119,7 @@ class MCPProxyManager:
             tools: List of tool configurations
             allow_origins: List of allowed origins for CORS
         """
-        tools = {
-            t.name: t.model_dump()
-            for t in stdio_servers
-        }
+        tools = {t.name: t.model_dump() for t in stdio_servers}
        self.config['mcpServers'] = tools

         del self.proxy
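`update_and_remount` above rebuilds the proxy's `mcpServers` mapping from the stdio server models, keyed by name. A reduced sketch of that comprehension, using an invented stand-in for the real `MCPStdioServerConfig`:

from pydantic import BaseModel


class StdioServer(BaseModel):  # reduced stand-in, not the real config model
    name: str
    command: str
    args: list[str] = []


stdio_servers = [StdioServer(name='fetch', command='uvx', args=['mcp-server-fetch'])]
tools = {t.name: t.model_dump() for t in stdio_servers}
# {'fetch': {'name': 'fetch', 'command': 'uvx', 'args': ['mcp-server-fetch']}}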
@@ -27,7 +27,7 @@ from openhands.llm.metrics import Metrics
 from openhands.utils.chunk_localizer import Chunk, get_top_k_chunk_matches

 USER_MSG = """
 Code changes will be provided in the form of a draft. You will need to apply the draft to the original code.
 The original code will be enclosed within `<original_code>` tags.
 The draft will be enclosed within `<update_snippet>` tags.
 You need to output the update code within `<updated_code>` tags.
@@ -48,8 +48,8 @@ def _extract_code(string: str) -> str | None:

     content = str(matches[0])
     if content.startswith('#EDIT:'):
-        #Remove first line
-        content = content[content.find('\n') + 1:]
+        # Remove first line
+        content = content[content.find('\n') + 1 :]
     return content


@@ -1,7 +1,7 @@
 import random
 import socket
 import time
-from openhands.core.logger import openhands_logger as logger


 def check_port_available(port: int) -> bool:
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
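The hunk above only drops an unused logger import from the port utility. For reference, a bind-based availability check like `check_port_available` typically looks like the sketch below (everything past the first line of the body is an assumption, not the repo's exact code):

import socket


def check_port_available(port: int) -> bool:
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(('0.0.0.0', port))  # succeeds only if the port is free
        return True
    except OSError:
        return False
    finally:
        sock.close()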
@@ -53,7 +53,7 @@ def load_server_config() -> ServerConfig:
     logger.info(f'Using config class {config_cls}')

     server_config_cls = get_impl(ServerConfig, config_cls)
-    server_config : ServerConfig = server_config_cls()
+    server_config: ServerConfig = server_config_cls()
     server_config.verify_config()

     return server_config
@@ -15,7 +15,7 @@ from openhands.events.stream import EventStreamSubscriber, session_exists
 from openhands.server.config.server_config import ServerConfig
 from openhands.server.data_models.agent_loop_info import AgentLoopInfo
 from openhands.server.monitoring import MonitoringListener
-from openhands.server.session.agent_session import AgentSession, WAIT_TIME_BEFORE_CLOSE
+from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE, AgentSession
 from openhands.server.session.conversation import ServerConversation
 from openhands.server.session.session import ROOM_KEY, Session
 from openhands.storage.conversation.conversation_store import ConversationStore
@@ -508,7 +508,9 @@ class StandaloneConversationManager(ConversationManager):
             session_api_key=None,
             event_store=session.agent_session.event_stream,
             status=_get_status_from_session(session),
-            runtime_status=getattr(session.agent_session.runtime, 'runtime_status', None),
+            runtime_status=getattr(
+                session.agent_session.runtime, 'runtime_status', None
+            ),
         )

     def _get_conversation_url(self, conversation_id: str):
@@ -1,7 +1,6 @@
 from dataclasses import dataclass, field
 from datetime import datetime, timezone

-from openhands.core.schema.agent import AgentState
 from openhands.integrations.service_types import ProviderType
 from openhands.runtime.runtime_status import RuntimeStatus
 from openhands.storage.data_models.conversation_metadata import ConversationTrigger
@@ -5,13 +5,13 @@ from pydantic import BaseModel
 from openhands.core.logger import openhands_logger as logger
 from openhands.events.event_filter import EventFilter
 from openhands.events.serialization.event import event_to_dict
+from openhands.memory.memory import Memory
+from openhands.microagent.types import InputMetadata
 from openhands.runtime.base import Runtime
 from openhands.server.dependencies import get_dependencies
 from openhands.server.session.conversation import ServerConversation
 from openhands.server.shared import conversation_manager
 from openhands.server.utils import get_conversation
-from openhands.microagent.types import InputMetadata
-from openhands.memory.memory import Memory

 app = APIRouter(
     prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
@@ -216,7 +216,11 @@ async def get_microagents(
                     content=agent.content,
                     triggers=[],
                     inputs=agent.metadata.inputs,
-                    tools=[server.name for server in agent.metadata.mcp_tools.stdio_servers] if agent.metadata.mcp_tools else [],
+                    tools=[
+                        server.name for server in agent.metadata.mcp_tools.stdio_servers
+                    ]
+                    if agent.metadata.mcp_tools
+                    else [],
                 )
             )

@@ -229,7 +233,11 @@ async def get_microagents(
                     content=agent.content,
                     triggers=agent.triggers,
                     inputs=agent.metadata.inputs,
-                    tools=[server.name for server in agent.metadata.mcp_tools.stdio_servers] if agent.metadata.mcp_tools else [],
+                    tools=[
+                        server.name for server in agent.metadata.mcp_tools.stdio_servers
+                    ]
+                    if agent.metadata.mcp_tools
+                    else [],
                 )
             )

@@ -6,15 +6,19 @@ from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
 from openhands.events.serialization import event_to_dict
 from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
 from openhands.server.dependencies import get_dependencies
+from openhands.server.session.conversation import ServerConversation
 from openhands.server.utils import get_conversation
 from openhands.utils.async_utils import call_sync_from_async
-from openhands.server.session.conversation import ServerConversation

-app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
+app = APIRouter(
+    prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
+)


 @app.post('/submit-feedback')
-async def submit_feedback(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
+async def submit_feedback(
+    request: Request, conversation: ServerConversation = Depends(get_conversation)
+) -> JSONResponse:
     """Submit user feedback.

     This function stores the provided feedback data.
@@ -37,9 +41,7 @@ async def submit_feedback(request: Request, conversation: ServerConversation = D
     # Assuming the storage service is already configured in the backend
     # and there is a function to handle the storage.
     body = await request.json()
-    async_store = AsyncEventStoreWrapper(
-        conversation.event_stream, filter_hidden=True
-    )
+    async_store = AsyncEventStoreWrapper(conversation.event_stream, filter_hidden=True)
     trajectory = []
     async for event in async_store:
         trajectory.append(event_to_dict(event))
@@ -5,7 +5,6 @@ from fastapi import (
     APIRouter,
     Depends,
     HTTPException,
-    Request,
     status,
 )
 from fastapi.responses import FileResponse, JSONResponse
@@ -27,17 +26,15 @@ from openhands.server.dependencies import get_dependencies
 from openhands.server.file_config import (
     FILES_TO_IGNORE,
 )
-from openhands.server.shared import (
-    ConversationStoreImpl,
-    config,
-)
+from openhands.server.session.conversation import ServerConversation
 from openhands.server.user_auth import get_user_id
 from openhands.server.utils import get_conversation, get_conversation_store
 from openhands.storage.conversation.conversation_store import ConversationStore
 from openhands.utils.async_utils import call_sync_from_async
-from openhands.server.session.conversation import ServerConversation

-app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
+app = APIRouter(
+    prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
+)


 @app.get(
@@ -50,7 +47,7 @@ app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_
 )
 async def list_files(
     conversation: ServerConversation = Depends(get_conversation),
-    path: str | None = None
+    path: str | None = None,
 ) -> list[str] | JSONResponse:
     """List files in the specified path.

@@ -132,7 +129,9 @@ async def list_files(
         415: {'description': 'Unsupported media type', 'model': dict},
     },
 )
-async def select_file(file: str, conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse:
+async def select_file(
+    file: str, conversation: ServerConversation = Depends(get_conversation)
+) -> FileResponse | JSONResponse:
     """Retrieve the content of a specified file.

     To select a file:
@@ -196,7 +195,9 @@ async def select_file(file: str, conversation: ServerConversation = Depends(get_
         500: {'description': 'Error zipping workspace', 'model': dict},
     },
 )
-def zip_current_workspace(conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse:
+def zip_current_workspace(
+    conversation: ServerConversation = Depends(get_conversation),
+) -> FileResponse | JSONResponse:
     try:
         logger.debug('Zipping workspace')
         runtime: Runtime = conversation.runtime
@@ -1,6 +1,6 @@
 import itertools
-import re
 import os
+import re
 import uuid
 from datetime import datetime, timezone

@@ -9,19 +9,18 @@ from fastapi.responses import JSONResponse
 from jinja2 import Environment, FileSystemLoader
 from pydantic import BaseModel, Field

-from openhands.events.event_filter import EventFilter
-from openhands.events.stream import EventStream
+from openhands.core.config.llm_config import LLMConfig
+from openhands.core.logger import openhands_logger as logger
 from openhands.events.action import (
     ChangeAgentStateAction,
     NullAction,
 )
+from openhands.events.event_filter import EventFilter
 from openhands.events.observation import (
-    NullObservation,
     AgentStateChangedObservation,
+    NullObservation,
 )
-from openhands.core.config.llm_config import LLMConfig
-from openhands.core.logger import openhands_logger as logger
+from openhands.events.stream import EventStream
 from openhands.integrations.provider import (
     PROVIDER_TOKEN_TYPE,
     ProviderHandler,
@@ -38,10 +37,9 @@ from openhands.server.data_models.conversation_info import ConversationInfo
 from openhands.server.data_models.conversation_info_result_set import (
     ConversationInfoResultSet,
 )
-from openhands.server.services.conversation_service import create_new_conversation
-from openhands.server.session.conversation import ServerConversation
 from openhands.server.dependencies import get_dependencies
 from openhands.server.services.conversation_service import create_new_conversation
+from openhands.server.session.conversation import ServerConversation
 from openhands.server.shared import (
     ConversationStoreImpl,
     config,
@@ -53,11 +51,12 @@ from openhands.server.user_auth import (
     get_provider_tokens,
     get_user_id,
     get_user_secrets,
-    get_user_settings_store,
     get_user_settings,
+    get_user_settings_store,
 )
 from openhands.server.user_auth.user_auth import AuthType
-from openhands.server.utils import get_conversation_store, get_conversation as get_conversation_object
+from openhands.server.utils import get_conversation as get_conversation_object
+from openhands.server.utils import get_conversation_store
 from openhands.storage.conversation.conversation_store import ConversationStore
 from openhands.storage.data_models.conversation_metadata import (
     ConversationMetadata,
@@ -295,7 +294,7 @@ async def delete_conversation(
 async def get_prompt(
     event_id: int,
     user_settings: SettingsStore = Depends(get_user_settings_store),
-    conversation: ServerConversation | None = Depends(get_conversation_object)
+    conversation: ServerConversation | None = Depends(get_conversation_object),
 ):
     if conversation is None:
         return JSONResponse(
@@ -409,7 +408,6 @@ async def start_conversation(
     logger.info(f'Starting conversation: {conversation_id}')

     try:
-
         # Check that the conversation exists
         try:
             await conversation_store.get_metadata(conversation_id)
@@ -463,10 +461,17 @@ async def stop_conversation(

     try:
         # Check if the conversation is running
-        agent_loop_info = await conversation_manager.get_agent_loop_info(user_id=user_id, filter_to_sids={conversation_id})
-        conversation_status = agent_loop_info[0].status if agent_loop_info else ConversationStatus.STOPPED
+        agent_loop_info = await conversation_manager.get_agent_loop_info(
+            user_id=user_id, filter_to_sids={conversation_id}
+        )
+        conversation_status = (
+            agent_loop_info[0].status if agent_loop_info else ConversationStatus.STOPPED
+        )

-        if conversation_status not in (ConversationStatus.STARTING, ConversationStatus.RUNNING):
+        if conversation_status not in (
+            ConversationStatus.STARTING,
+            ConversationStatus.RUNNING,
+        ):
             return ConversationResponse(
                 status='ok',
                 conversation_id=conversation_id,
@@ -505,9 +510,13 @@ def _get_contextual_events(event_stream: EventStream, event_id: int) -> str:

     agent_event_filter = EventFilter(
         exclude_hidden=True,
-        exclude_types=(NullAction, NullObservation, ChangeAgentStateAction, AgentStateChangedObservation
+        exclude_types=(
+            NullAction,
+            NullObservation,
+            ChangeAgentStateAction,
+            AgentStateChangedObservation,
         ),
     ) # the types of events that can be in an agent's history

     # from event_id - context_size to event_id..
     context_before = event_stream.search_events(
@@ -87,7 +87,7 @@ async def create_pr(
     target_branch: Annotated[str, Field(description='Target branch on repo')],
     title: Annotated[str, Field(description='PR Title')],
     body: Annotated[str | None, Field(description='PR body')],
-    draft: Annotated[bool, Field(description='Whether PR opened is a draft')] = True
+    draft: Annotated[bool, Field(description='Whether PR opened is a draft')] = True,
 ) -> str:
     """Open a PR in GitHub"""

@@ -127,7 +127,7 @@ async def create_pr(
         target_branch=target_branch,
         title=title,
         body=body,
-        draft=draft
+        draft=draft,
     )

     if conversation_id:
@@ -148,7 +148,12 @@ async def create_mr(
     ],
     source_branch: Annotated[str, Field(description='Source branch on repo')],
     target_branch: Annotated[str, Field(description='Target branch on repo')],
-    title: Annotated[str, Field(description='MR Title. Start title with `DRAFT:` or `WIP:` if applicable.')],
+    title: Annotated[
+        str,
+        Field(
+            description='MR Title. Start title with `DRAFT:` or `WIP:` if applicable.'
+        ),
+    ],
     description: Annotated[str | None, Field(description='MR description')],
 ) -> str:
     """Open a MR in GitLab"""
@@ -8,14 +8,18 @@ from fastapi import (
 )

 from openhands.server.dependencies import get_dependencies
-from openhands.server.utils import get_conversation
 from openhands.server.session.conversation import ServerConversation
+from openhands.server.utils import get_conversation

-app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
+app = APIRouter(
+    prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
+)


 @app.route('/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
-async def security_api(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> Response:
+async def security_api(
+    request: Request, conversation: ServerConversation = Depends(get_conversation)
+) -> Response:
     """Catch-all route for security analyzer API requests.

     Each request is handled directly to the security analyzer.
@@ -35,6 +39,4 @@ async def security_api(request: Request, conversation: ServerConversation = Depe
         detail='Security analyzer not initialized',
     )

-    return await conversation.security_analyzer.handle_api_request(
-        request
-    )
+    return await conversation.security_analyzer.handle_api_request(request)
@@ -1,18 +1,22 @@
-from fastapi import APIRouter, Depends, Request, status
+from fastapi import APIRouter, Depends, status
 from fastapi.responses import JSONResponse

 from openhands.core.logger import openhands_logger as logger
 from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
 from openhands.events.serialization import event_to_trajectory
 from openhands.server.dependencies import get_dependencies
-from openhands.server.utils import get_conversation
 from openhands.server.session.conversation import ServerConversation
+from openhands.server.utils import get_conversation

-app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
+app = APIRouter(
+    prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
+)


 @app.get('/trajectory')
-async def get_trajectory(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
+async def get_trajectory(
+    conversation: ServerConversation = Depends(get_conversation),
+) -> JSONResponse:
     """Get trajectory.

     This function retrieves the current trajectory and returns it.
@@ -1,4 +1,3 @@
-import os
 import uuid
 from typing import Any

@@ -80,7 +79,6 @@ async def create_new_conversation(
     session_init_args['conversation_instructions'] = conversation_instructions
     conversation_init_data = ConversationInitData(**session_init_args)

-
     logger.info('Loading conversation store')
     conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
     logger.info('ServerConversation store loaded')
@@ -90,13 +88,14 @@ async def create_new_conversation(
         conversation_id = uuid.uuid4().hex

     if not await conversation_store.exists(conversation_id):

         logger.info(
             f'New conversation ID: {conversation_id}',
             extra={'user_id': user_id, 'session_id': conversation_id},
         )

-    conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(user_id, conversation_id, conversation_init_data)
+    conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(
+        user_id, conversation_id, conversation_init_data
+    )
     conversation_title = get_default_conversation_title(conversation_id)

     logger.info(f'Saving metadata for conversation {conversation_id}')
@@ -197,23 +197,21 @@ class AgentSession:
         finally:
             self._starting = False
             success = finished and runtime_connected
-            duration = (time.time() - started_at)
+            duration = time.time() - started_at

             log_metadata = {
                 'signal': 'agent_session_start',
                 'success': success,
                 'duration': duration,
-                'restored_state': restored_state
+                'restored_state': restored_state,
             }
             if success:
                 self.logger.info(
-                    f'Agent session start succeeded in {duration}s',
-                    extra=log_metadata
+                    f'Agent session start succeeded in {duration}s', extra=log_metadata
                 )
             else:
                 self.logger.error(
-                    f'Agent session start failed in {duration}s',
-                    extra=log_metadata
+                    f'Agent session start failed in {duration}s', extra=log_metadata
                 )

     async def close(self) -> None:
@@ -105,7 +105,12 @@ class FileConversationStore(ConversationStore):
     async def get_instance(
         cls, config: OpenHandsConfig, user_id: str | None
     ) -> FileConversationStore:
-        file_store = get_file_store(config.file_store, config.file_store_path, config.file_store_web_hook_url, config.file_store_web_hook_headers)
+        file_store = get_file_store(
+            config.file_store,
+            config.file_store_path,
+            config.file_store_web_hook_url,
+            config.file_store_web_hook_headers,
+        )
         return FileConversationStore(file_store)


@@ -43,6 +43,6 @@ class FileSecretsStore(SecretsStore):
             config.file_store,
             config.file_store_path,
             config.file_store_web_hook_url,
-            config.file_store_web_hook_headers
+            config.file_store_web_hook_headers,
         )
         return FileSecretsStore(file_store)
@@ -37,6 +37,6 @@ class FileSettingsStore(SettingsStore):
             config.file_store,
             config.file_store_path,
             config.file_store_web_hook_url,
-            config.file_store_web_hook_headers
+            config.file_store_web_hook_headers,
         )
         return FileSettingsStore(file_store)
@@ -10,24 +10,36 @@ class TestTranslationCompleteness(unittest.TestCase):

     def test_translation_completeness_check_runs(self):
         """Test that the translation completeness check script can be executed."""
-        frontend_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "frontend")
-        script_path = os.path.join(frontend_dir, "scripts", "check-translation-completeness.cjs")
+        frontend_dir = os.path.join(
+            os.path.dirname(
+                os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+            ),
+            'frontend',
+        )
+        script_path = os.path.join(
+            frontend_dir, 'scripts', 'check-translation-completeness.cjs'
+        )

         # Verify the script exists
-        self.assertTrue(os.path.exists(script_path), f"Script not found at {script_path}")
+        self.assertTrue(
+            os.path.exists(script_path), f'Script not found at {script_path}'
+        )

         # Verify the script is executable
-        self.assertTrue(os.access(script_path, os.X_OK), f"Script at {script_path} is not executable")
+        self.assertTrue(
+            os.access(script_path, os.X_OK),
+            f'Script at {script_path} is not executable',
+        )

         # Run the script (it may fail due to missing translations, but we just want to verify it runs)
         try:
             subprocess.run(
-                ["node", script_path],
+                ['node', script_path],
                 cwd=frontend_dir,
                 check=False,
                 capture_output=True,
-                text=True
+                text=True,
             )
             # We don't assert on the return code because it might fail due to missing translations
         except Exception as e:
-            self.fail(f"Failed to run translation completeness check: {e}")
+            self.fail(f'Failed to run translation completeness check: {e}')
@@ -100,7 +100,6 @@ def mock_conversation_instructions_template():
     return 'Instructions: {{ repo_instruction }}'


-
 @pytest.fixture
 def mock_followup_prompt_template():
     return 'Issue context: {{ issues }}\n\nReview comments: {{ review_comments }}\n\nReview threads: {{ review_threads }}\n\nFiles: {{ files }}\n\nThread comments: {{ thread_context }}\n\nPlease fix this issue.'
@@ -532,7 +531,11 @@ async def test_process_issue(
     handler_instance.guess_success.assert_not_called()


-def test_get_instruction(mock_user_instructions_template, mock_conversation_instructions_template, mock_followup_prompt_template):
+def test_get_instruction(
+    mock_user_instructions_template,
+    mock_conversation_instructions_template,
+    mock_followup_prompt_template,
+):
     issue = Issue(
         owner='test_owner',
         repo='test_repo',
@@ -545,7 +548,10 @@ def test_get_instruction(mock_user_instructions_template, mock_conversation_inst
         GithubIssueHandler('owner', 'repo', 'token'), mock_llm_config
     )
     instruction, conversation_instructions, images_urls = issue_handler.get_instruction(
-        issue, mock_user_instructions_template, mock_conversation_instructions_template, None
+        issue,
+        mock_user_instructions_template,
+        mock_conversation_instructions_template,
+        None,
     )
     expected_instruction = 'Issue: Test Issue\n\nThis is a test issue refer to image \n\nPlease fix this issue.'

@@ -576,7 +582,10 @@ def test_get_instruction(mock_user_instructions_template, mock_conversation_inst
         GithubPRHandler('owner', 'repo', 'token'), mock_llm_config
     )
     instruction, conversation_instructions, images_urls = pr_handler.get_instruction(
-        issue, mock_followup_prompt_template, mock_conversation_instructions_template, None
+        issue,
+        mock_followup_prompt_template,
+        mock_conversation_instructions_template,
+        None,
     )
     expected_instruction = "Issue context: [\n \"Issue 1 fix the type\"\n]\n\nReview comments: None\n\nReview threads: [\n \"There is still a typo 'pthon' instead of 'python'\"\n]\n\nFiles: []\n\nThread comments: I've left review comments, please address them\n---\nThis is a valid concern.\n\nPlease fix this issue."

@@ -601,7 +610,9 @@ def test_file_instruction():
     with open('openhands/resolver/prompts/resolve/basic.jinja', 'r') as f:
         prompt = f.read()

-    with open('openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r') as f:
+    with open(
+        'openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r'
+    ) as f:
         conversation_instructions_template = f.read()

     # Test without thread comments
@@ -610,7 +621,7 @@ def test_file_instruction():
         GithubIssueHandler('owner', 'repo', 'token'), mock_llm_config
     )
     instruction, conversation_instructions, images_urls = issue_handler.get_instruction(
-        issue, prompt,conversation_instructions_template, None
+        issue, prompt, conversation_instructions_template, None
     )
     expected_instruction = """Please fix the following issue for the repository in /workspace.
 An environment has been set up for you to start working. You may assume all necessary tools are installed.
@@ -620,7 +631,6 @@ Test Issue

 This is a test issue """

-
     expected_conversation_instructions = """IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.
 You SHOULD INCLUDE PROPER INDENTATION in your edit commands.

@@ -644,7 +654,9 @@ def test_file_instruction_with_repo_instruction():
     with open('openhands/resolver/prompts/resolve/basic.jinja', 'r') as f:
         prompt = f.read()

-    with open('openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r') as f:
+    with open(
+        'openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r'
+    ) as f:
         conversation_instructions_prompt = f.read()

     # load repo instruction from openhands/resolver/prompts/repo_instructions/all-hands-ai___openhands-resolver.txt
@@ -662,7 +674,6 @@ def test_file_instruction_with_repo_instruction():
         issue, prompt, conversation_instructions_prompt, repo_instruction
     )

-
     expected_instruction = """Please fix the following issue for the repository in /workspace.
 An environment has been set up for you to start working. You may assume all necessary tools are installed.

@@ -683,7 +694,6 @@ This is a Python repo for openhands-resolver, a library that attempts to resolve

 When you think you have fixed the issue through code changes, please finish the interaction."""

-
     assert instruction == expected_instruction
     assert conversation_instructions == expected_conversation_instructions
     assert conversation_instructions is not None
@@ -785,7 +795,9 @@ def test_instruction_with_thread_comments():
     with open('openhands/resolver/prompts/resolve/basic.jinja', 'r') as f:
         prompt = f.read()

-    with open('openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r') as f:
+    with open(
+        'openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r'
+    ) as f:
         conversation_instructions_template = f.read()

     llm_config = LLMConfig(model='test', api_key='test')
@@ -1,6 +1,3 @@
-from typing import Type
-from unittest.mock import MagicMock
-
 import pytest
 from pydantic import SecretStr

@@ -8,11 +5,11 @@ from openhands.core.config import LLMConfig
 from openhands.integrations.provider import ProviderType
 from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler
 from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler
-from openhands.resolver.issue_handler_factory import IssueHandlerFactory
 from openhands.resolver.interfaces.issue_definitions import (
     ServiceContextIssue,
     ServiceContextPR,
 )
+from openhands.resolver.issue_handler_factory import IssueHandlerFactory


 @pytest.fixture
@@ -45,33 +42,29 @@ test_cases = [


 @pytest.mark.parametrize(
-    'platform,issue_type,expected_context_type,expected_handler_type',
-    test_cases
+    'platform,issue_type,expected_context_type,expected_handler_type', test_cases
 )
 def test_handler_creation(
     factory_params,
     platform: ProviderType,
     issue_type: str,
-    expected_context_type: Type,
-    expected_handler_type: Type,
+    expected_context_type: type,
+    expected_handler_type: type,
 ):
     factory = IssueHandlerFactory(
-        **factory_params,
-        platform=platform,
-        issue_type=issue_type
+        **factory_params, platform=platform, issue_type=issue_type
     )

     handler = factory.create()

     assert isinstance(handler, expected_context_type)
     assert isinstance(handler._strategy, expected_handler_type)


 def test_invalid_issue_type(factory_params):
     factory = IssueHandlerFactory(
-        **factory_params,
-        platform=ProviderType.GITHUB,
-        issue_type='invalid'
+        **factory_params, platform=ProviderType.GITHUB, issue_type='invalid'
     )

     with pytest.raises(ValueError, match='Invalid issue type: invalid'):
         factory.create()
@@ -2,7 +2,7 @@ from unittest import mock

 import pytest

-from openhands.core.config import SandboxConfig,OpenHandsConfig
+from openhands.core.config import OpenHandsConfig, SandboxConfig
 from openhands.events.action import CmdRunAction
 from openhands.resolver.issue_resolver import IssueResolver

@@ -36,7 +36,8 @@ def test_setup_sandbox_config_default():
     )

     assert_sandbox_config(
-        openhands_config.sandbox, runtime_container_image='ghcr.io/all-hands-ai/runtime:mock-nikolaik'
+        openhands_config.sandbox,
+        runtime_container_image='ghcr.io/all-hands-ai/runtime:mock-nikolaik',
     )


@@ -68,7 +69,9 @@ def test_setup_sandbox_config_base_only():
     )

     assert_sandbox_config(
-        openhands_config.sandbox, base_container_image=base_image, runtime_container_image=None
+        openhands_config.sandbox,
+        base_container_image=base_image,
+        runtime_container_image=None,
     )


@@ -84,7 +87,9 @@ def test_setup_sandbox_config_runtime_only():
         is_experimental=False,
     )

-    assert_sandbox_config(openhands_config.sandbox, runtime_container_image=runtime_image)
+    assert_sandbox_config(
+        openhands_config.sandbox, runtime_container_image=runtime_image
+    )


 def test_setup_sandbox_config_experimental():
@@ -117,7 +122,9 @@ def test_setup_sandbox_config_gitlab_ci(mock_get_unique_uid, mock_getuid):
         is_experimental=False,
     )

-    assert_sandbox_config(openhands_config.sandbox, local_runtime_url='http://localhost')
+    assert_sandbox_config(
+        openhands_config.sandbox, local_runtime_url='http://localhost'
+    )


 @mock.patch('openhands.resolver.issue_resolver.os.getuid', return_value=1000)
@@ -134,7 +141,9 @@ def test_setup_sandbox_config_gitlab_ci_non_root(mock_getuid):
         is_experimental=False,
     )

-    assert_sandbox_config(openhands_config.sandbox, local_runtime_url='http://localhost')
+    assert_sandbox_config(
+        openhands_config.sandbox, local_runtime_url='http://localhost'
+    )


 @mock.patch('openhands.events.observation.CmdOutputObservation')