mirror of https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-08 22:38:05 -05:00
Lint all files in the repo (#9131)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
2  .github/workflows/lint-fix.yml (vendored)

@@ -74,7 +74,7 @@ jobs:
       - name: Fix python lint issues
         run: |
           # Run all pre-commit hooks and continue even if they modify files (exit code 1)
-          pre-commit run --config ./dev_config/python/.pre-commit-config.yaml --files openhands/**/* evaluation/**/* tests/**/* || true
+          pre-commit run --config ./dev_config/python/.pre-commit-config.yaml --all-files || true

           # Commit and push changes if any
       - name: Check for changes
2  .github/workflows/lint.yml (vendored)

@@ -53,7 +53,7 @@ jobs:
      - name: Install pre-commit
        run: pip install pre-commit==3.7.0
      - name: Run pre-commit hooks
-       run: pre-commit run --files openhands/**/* evaluation/**/* tests/**/* --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
+       run: pre-commit run --all-files --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml

  # Check version consistency across documentation
  check-version-consistency:
1  .github/workflows/py-unit-tests.yml (vendored)

@@ -81,4 +81,3 @@ jobs:
        env:
          TEST_RUNTIME: local
          DEBUG: "1"
-
2  Makefile

@@ -189,7 +189,7 @@ install-pre-commit-hooks:

lint-backend:
	@echo "$(YELLOW)Running linters...$(RESET)"
-	@poetry run pre-commit run --files openhands/**/* evaluation/**/* tests/**/* --show-diff-on-failure --config $(PRE_COMMIT_CONFIG_PATH)
+	@poetry run pre-commit run --all-files --show-diff-on-failure --config $(PRE_COMMIT_CONFIG_PATH)

lint-frontend:
	@echo "$(YELLOW)Running linters for frontend...$(RESET)"
@@ -1,4 +1,4 @@
-TASK_INSTRUECTION="""
+TASK_INSTRUECTION = """
 Given the following GitHub problem description, your objective is to localize the specific files, classes or functions, and lines of code that need modification or contain key information to resolve the issue.

 Follow these steps to localize the issue:
@@ -43,7 +43,7 @@ from openhands.core.config import (
     AgentConfig,
     OpenHandsConfig,
     get_llm_config_arg,
-    get_parser
+    get_parser,
 )
 from openhands.core.config.condenser_config import NoOpCondenserConfig
 from openhands.core.config.utils import get_condenser_config_arg
@@ -92,10 +92,12 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata) -> MessageActio
     elif 'gpt-4.1' in llm_model:
         template_name = 'swe_gpt4.j2'
     else:
-        template_name = 'swe_default.j2'  # Default for 'swe' mode (regular swe-bench)
+        template_name = (
+            'swe_default.j2'  # Default for 'swe' mode (regular swe-bench)
+        )
 else:
     # Fallback or error handling if mode is unexpected
-    logger.error(f"Unexpected evaluation mode: {mode}. Falling back to default.")
+    logger.error(f'Unexpected evaluation mode: {mode}. Falling back to default.')
     template_name = 'swe_default.j2'

 # Set up Jinja2 environment
@@ -100,4 +100,3 @@ This project is used to evaluate the performance of the model on VersiCode. It i
 # Contributor

 [Tongtong Wu](https://scholar.google.com/citations?hl=zh-CN&user=u1Qp8lUAAAAJ&view_op=list_works&sortby=pubdate), [Weigang Wu](https://scholar.google.com/citations?hl=zh-CN&user=UneIZo8AAAAJ), [Xingyu Wang](https://scholar.google.com/citations?hl=zh-CN&user=wqPJcxcAAAAJ), [Kang Xu](https://scholar.google.com/citations?hl=zh-CN&user=N1UUDi0AAAAJ), [Suyu Ma](https://scholar.google.com/citations?hl=zh-CN&user=NJHR1ukAAAAJ), [Bo Jiang](https://wutong8023.site/VersiCode/), [Ping Yang](https://scholar.google.com/citations?view_op=list_works&hl=en&hl=en&user=hrogvxoAAAAJ), [Zhenchang Xing](https://scholar.google.com/citations?hl=zh-CN&user=0vCxuH4AAAAJ), [Yuan-Fang Li](https://scholar.google.com/citations?hl=zh-CN&user=wufXO1kAAAAJ), [Gholamreza Haffari](https://scholar.google.com/citations?hl=zh-CN&user=Perjx5EAAAAJ)
-
@@ -1,19 +1,22 @@
 """
 GPT performs line level generation prediction and truncates overly long tokens
 """
-import json
-import openai
-from openai import OpenAI
-import os
-import tiktoken
-max_tokens = 127000 #gpt3.5 is 16ktoken gpt4o is 128k
-model_name = ""
-
-os.environ["OPENAI_API_KEY"] = ""
+import json
+import os
+
+import tiktoken
+from openai import OpenAI
+
+max_tokens = 127000  # gpt3.5 is 16ktoken gpt4o is 128k
+model_name = ''
+
+os.environ['OPENAI_API_KEY'] = ''
 client = OpenAI()


 def truncate_text(text, max_tokens):
-    encoding = tiktoken.get_encoding("cl100k_base")
+    encoding = tiktoken.get_encoding('cl100k_base')
     disallowed_special = ()

     tokens = encoding.encode(text, disallowed_special=disallowed_special)
@@ -26,15 +29,11 @@ def truncate_text(text, max_tokens):

     return truncated_text


 def predict(content, model_name):
     response = client.chat.completions.create(
         model=model_name,
-        messages=[
-            {
-                "role": "user",
-                "content": content
-            }
-        ],
+        messages=[{'role': 'user', 'content': content}],
         frequency_penalty=0.1,
         max_tokens=128,
         logit_bias=None,
@@ -45,7 +44,7 @@ def predict(content, model_name):
         stop=None,
         stream=False,
         temperature=0.8,
-        top_p=0.95
+        top_p=0.95,
     )
     ans_list = []
     choices_list = response.choices
@@ -55,6 +54,7 @@ def predict(content, model_name):
     final_ans = str(ans_list)
     return final_ans


 def bulid_prompt(description, old_version, old_code, new_version) -> str:
     """
     build prompt
@@ -86,18 +86,20 @@ def bulid_prompt(description, old_version, old_code, new_version) -> str
 json_path = '../data/test_data/VersiCode_migration.json'


-with open(json_path, 'r', encoding='utf-8')as fr:
+with open(json_path, 'r', encoding='utf-8') as fr:
     lodict = json.load(fr)
 data_dict = lodict
 data_list = data_dict


 for data in data_list:
-    if "model_output" in data:
-        print(f"the {data_list.index(data) + 1} has already been predicted, skipping this data!")
+    if 'model_output' in data:
+        print(
+            f'the {data_list.index(data) + 1} has already been predicted, skipping this data!'
+        )
         continue
     try:
-        print(f"Predicting {data_list.index(data) + 1} ")
+        print(f'Predicting {data_list.index(data) + 1} ')
         old_version = data['dependency'] + data['old_version']  # package == x.x.x
         new_version = data['dependency'] + data['new_version']  # package == x.x.x
         description = data['description']  # function description
@@ -109,9 +111,11 @@ for data in data_list:

         data['model_output'] = prediction
     except Exception as e:
-        print(f"error:{e}")
-        print("save current data")
-        save_folder_path = os.path.join('../data/result_data/code_migration', model_name)
+        print(f'error:{e}')
+        print('save current data')
+        save_folder_path = os.path.join(
+            '../data/result_data/code_migration', model_name
+        )
         if not os.path.exists(save_folder_path):
             os.makedirs(save_folder_path)
         save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])
@@ -121,14 +125,10 @@ for data in data_list:
         break
-
-
-

 save_folder_path = os.path.join('../data/result_data/code_migration', model_name)
 if not os.path.exists(save_folder_path):
     os.makedirs(save_folder_path)
 save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])

-with open(save_json_path, 'w', encoding='utf-8')as fw:
+with open(save_json_path, 'w', encoding='utf-8') as fw:
     json.dump(data_dict, fw, indent=4, ensure_ascii=False)
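The hunks above only show the edges of truncate_text, so for orientation here is a minimal sketch of what the whole helper plausibly does, assuming the elided middle simply decodes the first max_tokens tokens back to text (the early-return branch and decode call are inferred, not taken from the diff):

    import tiktoken

    def truncate_text(text, max_tokens):
        encoding = tiktoken.get_encoding('cl100k_base')
        # () disables no special tokens, so raw '<|...|>' strings still encode.
        tokens = encoding.encode(text, disallowed_special=())
        if len(tokens) <= max_tokens:
            return text
        # Assumption: the elided lines decode the truncated token list.
        return encoding.decode(tokens[:max_tokens])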
@@ -1,19 +1,22 @@
 """
 GPT performs line level generation prediction and truncates overly long tokens
 """
-import json
-import openai
-from openai import OpenAI
-import os
-import tiktoken
-max_tokens = 127000 #gpt3.5 is 16ktoken gpt4o is 128k
-model_name = ""
-
-os.environ["OPENAI_API_KEY"] = ""
+import json
+import os
+
+import tiktoken
+from openai import OpenAI
+
+max_tokens = 127000  # gpt3.5 is 16ktoken gpt4o is 128k
+model_name = ''
+
+os.environ['OPENAI_API_KEY'] = ''
 client = OpenAI()


 def truncate_text(text, max_tokens):
-    encoding = tiktoken.get_encoding("cl100k_base")
+    encoding = tiktoken.get_encoding('cl100k_base')
     disallowed_special = ()

     tokens = encoding.encode(text, disallowed_special=disallowed_special)
@@ -26,15 +29,11 @@ def truncate_text(text, max_tokens):

     return truncated_text


 def predict(content, model_name):
     response = client.chat.completions.create(
         model=model_name,
-        messages=[
-            {
-                "role": "user",
-                "content": content
-            }
-        ],
+        messages=[{'role': 'user', 'content': content}],
         frequency_penalty=0.1,
         max_tokens=128,
         logit_bias=None,
@@ -45,7 +44,7 @@ def predict(content, model_name):
         stop=None,
         stream=False,
         temperature=0.8,
-        top_p=0.95
+        top_p=0.95,
     )
     ans_list = []
     choices_list = response.choices
@@ -55,6 +54,7 @@ def predict(content, model_name):
     final_ans = str(ans_list)
     return final_ans


 def bulid_prompt(version, description) -> str:
     """
     build prompt
@@ -64,7 +64,7 @@ def bulid_prompt(version, description) -> str:
     :param options:
     :return:
     """
-    prompt = f'''
+    prompt = f"""
 You are a professional Python engineer, and I will provide functional descriptions and versions of specified dependency packages.
 You need to write code in Python to implement this feature based on the functional description and using the dependency package and version I specified.
 Please note that you only need to return the code that implements the function, and do not return any other content.
@@ -88,25 +88,27 @@ def bulid_prompt(version, description) -> str:
 ###response:


-    '''
+    """
     return prompt


 json_path = '../data/test_data/VersiCode_block_completion.json'


-with open(json_path, 'r', encoding='utf-8')as fr:
+with open(json_path, 'r', encoding='utf-8') as fr:
     lodict = json.load(fr)
 data_dict = lodict
 data_list = data_dict


 for data in data_list:
-    if "model_output" in data:
-        print(f"the {data_list.index(data) + 1} has already been predicted, skipping this data!")
+    if 'model_output' in data:
+        print(
+            f'the {data_list.index(data) + 1} has already been predicted, skipping this data!'
+        )
         continue
     try:
-        print(f"Predicting {data_list.index(data) + 1} ")
+        print(f'Predicting {data_list.index(data) + 1} ')
         version = data['dependency'] + data['version']  # package == x.x.x
         description = data['description']  # func description

@@ -116,9 +118,11 @@ for data in data_list:

         data['model_output'] = prediction
     except Exception as e:
-        print(f"error:{e}")
-        print("save current data")
-        save_folder_path = os.path.join('../data/result_data/block_completion', model_name)
+        print(f'error:{e}')
+        print('save current data')
+        save_folder_path = os.path.join(
+            '../data/result_data/block_completion', model_name
+        )
         if not os.path.exists(save_folder_path):
             os.makedirs(save_folder_path)
         save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])
@@ -128,14 +132,10 @@ for data in data_list:
         break
-
-
-

 save_folder_path = os.path.join('../data/result_data/block_completion', model_name)
 if not os.path.exists(save_folder_path):
     os.makedirs(save_folder_path)
 save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])

-with open(save_json_path, 'w', encoding='utf-8')as fw:
+with open(save_json_path, 'w', encoding='utf-8') as fw:
     json.dump(data_dict, fw, indent=4, ensure_ascii=False)
@@ -1,20 +1,23 @@
 """
 block completion
 """

 import copy
 import gc
 import json
 import os
-from vllm import LLM, SamplingParams
-import tiktoken
 import time
-import gc
-import torch
 from multiprocessing import Process

+import tiktoken
+import torch
+from vllm import LLM, SamplingParams
+
 # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"


 def truncate_text(text, max_tokens):
-    encoding = tiktoken.get_encoding("cl100k_base")
+    encoding = tiktoken.get_encoding('cl100k_base')
     disallowed_special = ()

     tokens = encoding.encode(text, disallowed_special=disallowed_special)
@@ -27,8 +30,10 @@ def truncate_text(text, max_tokens):

     return truncated_text


 model_list = ['/data2/base models/starcoder2-15b', '/data2/base models/CodeGemma-7B']


 def run_inference(model_name, origin_data_list):
     temp_data_list = copy.deepcopy(origin_data_list)
     test_list = []
@@ -40,7 +45,12 @@ def run_inference(model_name, origin_data_list):
         test_list.append(instruction)

     sampling_params = SamplingParams(n=6, temperature=0.8, top_p=0.95, max_tokens=64)
-    llm = LLM(model=model_name, tensor_parallel_size=4, gpu_memory_utilization=0.9, swap_space=20)
+    llm = LLM(
+        model=model_name,
+        tensor_parallel_size=4,
+        gpu_memory_utilization=0.9,
+        swap_space=20,
+    )

     outputs = llm.generate(test_list, sampling_params)
     for output in outputs:
@@ -53,7 +63,9 @@ def run_inference(model_name, origin_data_list):

     temp_data_list[requests_id]['model_output'] = str(temp_ans_list)

-    save_folder_path = os.path.join('../data/result_data/block_completion', model_name.split('/')[-1])
+    save_folder_path = os.path.join(
+        '../data/result_data/block_completion', model_name.split('/')[-1]
+    )
     if not os.path.exists(save_folder_path):
         os.makedirs(save_folder_path)

@@ -75,7 +87,7 @@ def bulid_prompt(version, description) -> str:
     :param options:
     :return:
     """
-    prompt = f'''
+    prompt = f"""
 You are a professional Python engineer, and I will provide functional descriptions and versions of specified dependency packages.
 You need to write code in Python to implement this feature based on the functional description and using the dependency package and version I specified.
 Please note that you only need to return the code that implements the function, and do not return any other content.
@@ -99,13 +111,13 @@ def bulid_prompt(version, description) -> str:
 ###response:


-    '''
+    """
     return prompt


 json_path = '../data/test_data/VersiCode_block_completion.json'

-with open(json_path, 'r', encoding='utf-8')as fr:
+with open(json_path, 'r', encoding='utf-8') as fr:
     lodict = json.load(fr)

 origin_data_list = lodict
@@ -115,4 +127,3 @@ for model_name in model_list:
     process.start()
     process.join()
     time.sleep(120)
-
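For reference, a condensed sketch of the vLLM sampling flow this file drives, assuming the model path and prompts are placeholders; the parameter values are the ones visible in the diff above:

    from vllm import LLM, SamplingParams

    # n=6 candidate completions per prompt, as in the diff.
    sampling_params = SamplingParams(n=6, temperature=0.8, top_p=0.95, max_tokens=64)
    llm = LLM(
        model='/path/to/model',  # placeholder; the diff uses local model dirs
        tensor_parallel_size=4,
        gpu_memory_utilization=0.9,
        swap_space=20,
    )
    outputs = llm.generate(['prompt one', 'prompt two'], sampling_params)
    for output in outputs:
        # Each RequestOutput carries one CompletionOutput per sample.
        candidates = [completion.text for completion in output.outputs]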
@@ -1,20 +1,23 @@
 """
 code migration
 """

 import copy
 import gc
 import json
 import os
-from vllm import LLM, SamplingParams
-import tiktoken
 import time
-import gc
-import torch
 from multiprocessing import Process

+import tiktoken
+import torch
+from vllm import LLM, SamplingParams
+
 # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"


 def truncate_text(text, max_tokens):
-    encoding = tiktoken.get_encoding("cl100k_base")
+    encoding = tiktoken.get_encoding('cl100k_base')
     disallowed_special = ()

     tokens = encoding.encode(text, disallowed_special=disallowed_special)
@@ -27,8 +30,10 @@ def truncate_text(text, max_tokens):

     return truncated_text


 model_list = ['/data2/base models/starcoder2-15b', '/data2/base models/CodeGemma-7B']


 def run_inference(model_name, origin_data_list):
     temp_data_list = copy.deepcopy(origin_data_list)
     test_list = []
@@ -42,7 +47,12 @@ def run_inference(model_name, origin_data_list):
         test_list.append(instruction)

     sampling_params = SamplingParams(n=6, temperature=0.8, top_p=0.95, max_tokens=512)
-    llm = LLM(model=model_name, tensor_parallel_size=4, gpu_memory_utilization=0.6, swap_space=40)
+    llm = LLM(
+        model=model_name,
+        tensor_parallel_size=4,
+        gpu_memory_utilization=0.6,
+        swap_space=40,
+    )

     outputs = llm.generate(test_list, sampling_params)
     for output in outputs:
@@ -55,7 +65,9 @@ def run_inference(model_name, origin_data_list):

     temp_data_list[requests_id]['model_output'] = str(temp_ans_list)

-    save_folder_path = os.path.join('../data/result_data/code_migration', model_name.split('/')[-1])
+    save_folder_path = os.path.join(
+        '../data/result_data/code_migration', model_name.split('/')[-1]
+    )
     if not os.path.exists(save_folder_path):
         os.makedirs(save_folder_path)

@@ -98,7 +110,7 @@ def bulid_prompt(description, old_version, old_code, new_version) -> str:

 json_path = '../data/test_data/VersiCode_migration.json'

-with open(json_path, 'r', encoding='utf-8')as fr:
+with open(json_path, 'r', encoding='utf-8') as fr:
     lodict = json.load(fr)

 origin_data_list = lodict
@@ -108,4 +120,3 @@ for model_name in model_list:
     process.start()
     process.join()
     time.sleep(120)
-
@@ -4,20 +4,20 @@
 2. Check whether the code is valid
 3. Compute ISM and PM
 """
-import json
-import tokenize
-import io
-import math
-import ast
-import re
-import os
+import io
+import json
+import math
+import os
+import re
+import tokenize


 def is_code_valid(code):
     try:
         compile(code, '<string>', 'exec')
         return True
-    except:
+    except Exception:
         return False


@@ -44,7 +44,8 @@ def longest_common_prefix_between_lists_with_elements(list1, list2):
                 max_prefix_elements = (str1, str2)
     return max_prefix_length, max_prefix_elements

-def get_token(ans_code:str, output_code:str):
+
+def get_token(ans_code: str, output_code: str):
     """
     Tokenize both code strings into identifiers and return the two identifier lists
     :param ans_code:
@@ -55,40 +56,40 @@ def get_token(ans_code:str, output_code:str):
     ans_flag = True
     try:
         tokens_ans = tokenize.tokenize(io.BytesIO(ans_code.encode('utf-8')).readline)
-    except Exception as e:
+    except Exception:
         tokens_ans = ans_code.splitlines()
         ans_flag = False

     try:
-        tokens_output = tokenize.tokenize(io.BytesIO(output_code.encode('utf-8')).readline)
-    except Exception as e:
+        tokens_output = tokenize.tokenize(
+            io.BytesIO(output_code.encode('utf-8')).readline
+        )
+    except Exception:
         tokens_output = output_code.splitlines()
         output_flag = False

     identifiers_ans = []
     identifiers_output = []
-    if ans_flag == True:
+    if ans_flag:
         try:
             for token in tokens_ans:
                 if token.type == tokenize.NAME:
                     identifiers_ans.append(token.string)
-        except Exception as e:
+        except Exception:
             identifiers_ans = tokens_ans
     else:
         identifiers_ans = tokens_ans

-    if output_flag == True:
+    if output_flag:
         try:
             for to in tokens_output:
                 if to.type == tokenize.NAME:
                     identifiers_output.append(to.string)
-        except Exception as e:
+        except Exception:
             identifiers_output = tokens_output
     else:
         identifiers_output = tokens_output

     return identifiers_ans, identifiers_output


@@ -108,15 +109,14 @@ def get_token_per_line(code: str):
             for token in tokens:
                 if token.type == tokenize.NAME:
                     identifiers.append(token.string)
-        except:
+        except Exception:
             identifiers = line.split(' ')
         identifiers_per_line.append(identifiers)

     return identifiers_per_line


-
-def get_ISM(answer_code:str, model_output_list:list, asnwer_name:str)->list:
+def get_ISM(answer_code: str, model_output_list: list, asnwer_name: str) -> list:
     """
     Compute ISM and return a sorted score list
     :return:
@@ -126,7 +126,9 @@ def get_ISM(answer_code:str, model_output_list:list, asnwer_name:str)->list:
         if '```python' in code:
             code = code.replace('```python', '')
             code = code.replace('```', '')
-        if not re.search(rf'\b{re.escape(asnwer_name)}\b', code) or is_code_valid(code) == False:
+        if not re.search(rf'\b{re.escape(asnwer_name)}\b', code) or not is_code_valid(
+            code
+        ):
             score_list.append(0)
             continue

@@ -135,10 +137,12 @@ def get_ISM(answer_code:str, model_output_list:list, asnwer_name:str)->list:
         #     continue

         identifiers_ans, identifiers_output = get_token(answer_code, code)
-        max_len, elements = longest_common_prefix_between_lists_with_elements(identifiers_ans, identifiers_output)
+        max_len, elements = longest_common_prefix_between_lists_with_elements(
+            identifiers_ans, identifiers_output
+        )
         if max_len != 0:
             base_element_len = max(len(elements[0]), len(elements[1]))
-            temp_score = max_len/base_element_len
+            temp_score = max_len / base_element_len
             score_list.append(temp_score)
         else:
             score_list.append(0)
@@ -149,14 +153,16 @@ def get_ISM(answer_code:str, model_output_list:list, asnwer_name:str)->list:
     score_list = sorted(score_list, reverse=True)
     return score_list

-def get_ISM_without_verification(answer_code:str, model_output_list:list, asnwer_name:str)->list:
+
+def get_ISM_without_verification(
+    answer_code: str, model_output_list: list, asnwer_name: str
+) -> list:
     """
     Compute ISM and return a sorted score list
     :return:
     """
     score_list = []
     for code in model_output_list:

         if asnwer_name not in code:
             score_list.append(0)
             continue
@@ -166,10 +172,12 @@ def get_ISM_without_verification(answer_code:str, model_output_list:list, asnwer
         #     continue

         identifiers_ans, identifiers_output = get_token(answer_code, code)
-        max_len, elements = longest_common_prefix_between_lists_with_elements(identifiers_ans, identifiers_output)
+        max_len, elements = longest_common_prefix_between_lists_with_elements(
+            identifiers_ans, identifiers_output
+        )
         if max_len != 0:
             base_element_len = max(len(elements[0]), len(elements[1]))
-            temp_score = max_len/base_element_len
+            temp_score = max_len / base_element_len
             score_list.append(temp_score)
         else:
             score_list.append(0)
@@ -180,6 +188,7 @@ def get_ISM_without_verification(answer_code:str, model_output_list:list, asnwer
     score_list = sorted(score_list, reverse=True)
     return score_list


 def longest_common_prefix_with_lengths(list1, list2):
     """
     For each pair of sublists in the two 2-D lists, compute the longest prefix match length, and record the lengths of the two sublists that achieve it
@@ -206,7 +215,7 @@ def longest_common_prefix_with_lengths(list1, list2):
     return max_length, len_list1, len_list2


-def get_PM(answer_code:str, model_output_list:list, asnwer_name:str)->list:
+def get_PM(answer_code: str, model_output_list: list, asnwer_name: str) -> list:
     """
     Compute PM and return a sorted score list
     :return:
@@ -216,8 +225,9 @@ def get_PM(answer_code:str, model_output_list:list, asnwer_name:str)->list:
         if '```python' in code:
             code = code.replace('```python', '')
             code = code.replace('```', '')
-        if not re.search(rf'\b{re.escape(asnwer_name)}\b', code) or is_code_valid(code) == False:
+
+        if not re.search(rf'\b{re.escape(asnwer_name)}\b', code) or not is_code_valid(
+            code
+        ):
             # if asnwer_name not in code or is_code_valid(code) == False:
             score_list.append(0)
             continue
@@ -228,11 +238,13 @@ def get_PM(answer_code:str, model_output_list:list, asnwer_name:str)->list:

         ans_list = get_token_per_line(answer_code)
         output_token_list = get_token_per_line(code)
-        max_len, len1, len2 = longest_common_prefix_with_lengths(ans_list, output_token_list)
+        max_len, len1, len2 = longest_common_prefix_with_lengths(
+            ans_list, output_token_list
+        )
         base_element_len = max(len1, len2)

         if base_element_len != 0:
-            temp_score = max_len/base_element_len
+            temp_score = max_len / base_element_len
             score_list.append(temp_score)
         else:
             score_list.append(0)
@@ -240,7 +252,8 @@ def get_PM(answer_code:str, model_output_list:list, asnwer_name:str)->list:
     score_list = sorted(score_list, reverse=True)
     return score_list

-def get_score(score_list:list, k):
+
+def get_score(score_list: list, k):
     """
     Compute score@n,k
     :param score_list:
@@ -249,25 +262,25 @@ def get_score(score_list:list, k):
     """
     n = len(score_list)
     sum = 0
-    final = n-k+1
-    for i in range(1, final+1):
-        sum += math.comb(n-i, k-1) * score_list[i-1]
+    final = n - k + 1
+    for i in range(1, final + 1):
+        sum += math.comb(n - i, k - 1) * score_list[i - 1]

-    final_score = sum/math.comb(n, k)
+    final_score = sum / math.comb(n, k)

     return final_score


 k = 1
 task = 'block'  # block or line
-json_name = f"Versicode_{task}_completion.json"
+json_name = f'Versicode_{task}_completion.json'

 folder_path = f'../data/result_data/{task}_completion'
 model_list = os.listdir(folder_path)

 for model in model_list:
     model_json_path = os.path.join(folder_path, model, json_name)
-    with open(model_json_path, 'r', encoding='utf-8')as fr:
+    with open(model_json_path, 'r', encoding='utf-8') as fr:
         lodict = json.load(fr)
     data_dict = lodict
     data_list = data_dict
@@ -309,7 +322,6 @@ for model in model_list:
         # if flag == 1:
         #     continue
-

         ISM_score = get_score(ISM_score_list, k)
         PM_score = get_score(PM_score_list, k)

@@ -318,9 +330,8 @@ for model in model_list:
         # print(f"ISM score: {ISM_score}")
         # print(f"PM score: {PM_score}")

-    print(f"{model}, {task} completion task, ISM@{k} score: {sum_ISM/data_len}")
-    print(f"{model}, {task} completion task, PM@{k} score: {sum_PM/data_len}")
-
+    print(f'{model}, {task} completion task, ISM@{k} score: {sum_ISM / data_len}')
+    print(f'{model}, {task} completion task, PM@{k} score: {sum_PM / data_len}')


 # def get_token(ans_code:str, output_code:str):
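The get_score change above is purely cosmetic, but the formula is worth spelling out: with score_list sorted in descending order, it computes the expected best score among k of the n generations, i.e. the sum over i of C(n-i, k-1) * s_i / C(n, k). A minimal standalone sketch (the function name here is ours, not the repository's):

    import math

    def expected_best_at_k(score_list, k):
        # Assumes score_list is sorted in descending order and n = len(score_list) >= k.
        n = len(score_list)
        total = sum(
            math.comb(n - i, k - 1) * score_list[i - 1] for i in range(1, n - k + 2)
        )
        return total / math.comb(n, k)

    # Sanity check: with scores [1.0, 0.0] and k=1 this gives 0.5 (the mean),
    # and with k=2 it gives 1.0 (the better of both samples is always picked).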
@@ -1,13 +1,15 @@
 """
 Calculate the cdc score for migration
 """
-import os
 import json
 import math
+import os
 import re
 import warnings

 # warnings.filterwarnings("ignore", category=SyntaxWarning)


 def is_correct_parameter_count(function_name, correct_code, test_code):
     """
     Check whether the number of parameters matches
@@ -39,6 +41,7 @@ def is_correct_parameter_count(function_name, correct_code, test_code):
         # If there are no parentheses, check whether the function name appears in the string
         return expected_count == 0 and function_name in test_code


 def check_keyword_parameters(function_name, correct_code, test_code):
     """
     Check whether keyword-argument assignment is used correctly
@@ -67,14 +70,18 @@ def check_keyword_parameters(function_name, correct_code, test_code):
         for correct_param in correct_param_list:
             if '=' in correct_param:  # only when the correct code uses a keyword argument
                 param_name = correct_param.split('=')[0].strip()
-                if not any(param_name in test_param and '=' in test_param for test_param in test_param_list):
+                if not any(
+                    param_name in test_param and '=' in test_param
+                    for test_param in test_param_list
+                ):
                     return False  # return False if the matching parameter is not passed as a keyword argument

         return True  # all keyword arguments match

     return False  # return False if nothing matches

-def with_correct(answer_code:str, model_output:str)->bool:
+
+def with_correct(answer_code: str, model_output: str) -> bool:
     """
     When the answer is a with-statement, check whether the model output is one as well
     :param answer_code:
@@ -89,34 +96,51 @@ def with_correct(answer_code:str, model_output:str)->bool:
     else:
         return False

-def compute_block_score_k(answer:str, model_output:list, k:int, model_filled_code, core_line_in_core_block, core_line_in_output_clear):
+
+def compute_block_score_k(
+    answer: str,
+    model_output: list,
+    k: int,
+    model_filled_code,
+    core_line_in_core_block,
+    core_line_in_output_clear,
+):
     """
     cdc must satisfy five conditions; em only needs to satisfy the first one
     """
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
-        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(model_filled_code[index]) and is_correct_parameter_count(answer, core_line_in_core_block, core_line_in_output_clear[index]) and with_correct(core_line_in_core_block, core_line_in_output_clear[index]) and check_keyword_parameters(answer, core_line_in_core_block, core_line_in_output_clear[index]):#block
+        if (
+            re.search(rf'\b{re.escape(answer)}\b', code)
+            and is_code_valid(model_filled_code[index])
+            and is_correct_parameter_count(
+                answer, core_line_in_core_block, core_line_in_output_clear[index]
+            )
+            and with_correct(core_line_in_core_block, core_line_in_output_clear[index])
+            and check_keyword_parameters(
+                answer, core_line_in_core_block, core_line_in_output_clear[index]
+            )
+        ):  # block
             # if re.search(rf'\b{re.escape(answer)}\b', code):#block
             c += 1
-    if n-c < k:
+    if n - c < k:
         return 1.0

-    score = 1 - (math.comb(n - c, k))/(math.comb(n, k))
+    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

     return score


 def is_code_valid(code):
     try:
         compile(code, '<string>', 'exec')
         return True
-    except:
+    except Exception:
         return False

-def compute_score_k(answer:str, model_output:list, k:int):
+
+def compute_score_k(answer: str, model_output: list, k: int):
     c = 0
     n = len(model_output)
     for output in model_output:
@@ -125,41 +149,50 @@ def compute_score_k(answer:str, model_output:list, k:int):
         output = output.replace('```', '')
         # if answer == output:

-        if re.search(rf'\b{re.escape(answer)}\b', output) and is_code_valid(output) == True:
+        if re.search(rf'\b{re.escape(answer)}\b', output) and is_code_valid(output):
            c += 1
-    if n-c < k:
+    if n - c < k:
         return 1.0

-    score = 1 - (math.comb(n - c, k))/(math.comb(n, k))
+    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

     return score

-k = 1 #cdc@k
+
+k = 1  # cdc@k
 json_name = 'VersiCode_migration.json'
 task = 'migration'
-folder_path = f'../data/result_data/code_migration'
+folder_path = '../data/result_data/code_migration'

 model_list = os.listdir(folder_path)
 for model in model_list:
     # if model != 'gpt-4o':
     #     continue
     model_json_path = os.path.join(folder_path, model, json_name)
-    with open(model_json_path, 'r', encoding='utf-8')as fr:
+    with open(model_json_path, 'r', encoding='utf-8') as fr:
         lodict = json.load(fr)
     data_list = lodict

     score_list = []
     for data in data_list:
-        answer = data['new_name']# old -> new
-        model_output = data[f'model_output_clear']# old -> new
+        answer = data['new_name']  # old -> new
+        model_output = data['model_output_clear']  # old -> new

         model_filled_code = model_output
         # core_line_in_core_block = data['core_line_in_new_core_block']# old -> new
         core_line_in_core_block = data['core_line_in_code']  # old -> new
-        core_line_in_output_clear = data['core_line_in_output_clear']# old -> new
+        core_line_in_output_clear = data['core_line_in_output_clear']  # old -> new

-        score_list.append(compute_block_score_k(answer, model_output, k, model_filled_code, core_line_in_core_block, core_line_in_output_clear))
+        score_list.append(
+            compute_block_score_k(
+                answer,
+                model_output,
+                k,
+                model_filled_code,
+                core_line_in_core_block,
+                core_line_in_output_clear,
+            )
+        )

-    final_score = sum(score_list)/len(score_list)
-    print(f"{model}, {task} task, cdc@{k} score: {final_score}")
+    final_score = sum(score_list) / len(score_list)
+    print(f'{model}, {task} task, cdc@{k} score: {final_score}')
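The reformatted compute_score_k above is the standard pass@k-style estimator: with c of n generations judged correct, 1 - C(n-c, k)/C(n, k) is the probability that at least one of k samples drawn without replacement from the n generations is correct. A minimal sketch (the function name is ours):

    import math

    def score_at_k(n, c, k):
        # Probability that at least one of k samples (without replacement) is correct.
        if n - c < k:
            return 1.0
        return 1 - math.comb(n - c, k) / math.comb(n, k)

    # Example: n=2 generations, c=1 correct, k=1 sample -> 0.5.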
@@ -1,21 +1,23 @@
 """
 Calculate the cdc score for line and block
 """
-import os
 import json
 import math
+import os
 import re
 import warnings

 # warnings.filterwarnings("ignore", category=SyntaxWarning)

-def is_code_valid(code):
+
+def is_code_valid(code):
     try:
         compile(code, '<string>', 'exec')
         return True
-    except:
+    except Exception:
         return False


 def is_correct_parameter_count(function_name, correct_code, test_code):
     """
     Check whether the number of parameters matches
@@ -47,6 +49,7 @@ def is_correct_parameter_count(function_name, correct_code, test_code):
         # If there are no parentheses, check whether the function name appears in the string
         return expected_count == 0 and function_name in test_code


 def check_keyword_parameters(function_name, correct_code, test_code):
     """
     Check whether keyword-argument assignment is used correctly
@@ -75,14 +78,18 @@ def check_keyword_parameters(function_name, correct_code, test_code):
         for correct_param in correct_param_list:
             if '=' in correct_param:  # only when the correct code uses a keyword argument
                 param_name = correct_param.split('=')[0].strip()
-                if not any(param_name in test_param and '=' in test_param for test_param in test_param_list):
+                if not any(
+                    param_name in test_param and '=' in test_param
+                    for test_param in test_param_list
+                ):
                     return False  # return False if the matching parameter is not passed as a keyword argument

         return True  # all keyword arguments match

     return False  # return False if nothing matches

-def with_correct(answer_code:str, model_output:str)->bool:
+
+def with_correct(answer_code: str, model_output: str) -> bool:
     """
     When the answer is a with-statement, check whether the model output is one as well
     :param answer_code:
@@ -97,59 +104,87 @@ def with_correct(answer_code:str, model_output:str)->bool:
     else:
         return False

-def compute_line_score_k(answer:str, model_output:list, k:int, model_filled_code, core_line):
+
+def compute_line_score_k(
+    answer: str, model_output: list, k: int, model_filled_code, core_line
+):
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
-        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(model_filled_code[index]) == True and is_correct_parameter_count(answer, core_line, code) and with_correct(core_line, code) and check_keyword_parameters(answer, core_line, code):#line
+        if (
+            re.search(rf'\b{re.escape(answer)}\b', code)
+            and is_code_valid(model_filled_code[index])
+            and is_correct_parameter_count(answer, core_line, code)
+            and with_correct(core_line, code)
+            and check_keyword_parameters(answer, core_line, code)
+        ):  # line
             c += 1
-    if n-c < k:
+    if n - c < k:
         return 1.0

-    score = 1 - (math.comb(n - c, k))/(math.comb(n, k))
+    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

     return score

-def compute_block_score_k(answer:str, model_output:list, k:int, model_filled_code, core_line_in_core_block, core_line_in_output_clear):
+
+def compute_block_score_k(
+    answer: str,
+    model_output: list,
+    k: int,
+    model_filled_code,
+    core_line_in_core_block,
+    core_line_in_output_clear,
+):
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
-        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(model_filled_code[index]) and is_correct_parameter_count(answer, core_line_in_core_block, core_line_in_output_clear[index]) and with_correct(core_line_in_core_block, core_line_in_output_clear[index]) and check_keyword_parameters(answer, core_line_in_core_block, core_line_in_output_clear[index]):#block
+        if (
+            re.search(rf'\b{re.escape(answer)}\b', code)
+            and is_code_valid(model_filled_code[index])
+            and is_correct_parameter_count(
+                answer, core_line_in_core_block, core_line_in_output_clear[index]
+            )
+            and with_correct(core_line_in_core_block, core_line_in_output_clear[index])
+            and check_keyword_parameters(
+                answer, core_line_in_core_block, core_line_in_output_clear[index]
+            )
+        ):  # block
             c += 1
-    if n-c < k:
+    if n - c < k:
         return 1.0

-    score = 1 - (math.comb(n - c, k))/(math.comb(n, k))
+    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

     return score

-def compute_score_k(answer:str, model_output:list, k:int):
+
+def compute_score_k(answer: str, model_output: list, k: int):
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
-        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(code):#block
+        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(
+            code
+        ):  # block
             # if re.search(rf'\b{re.escape(answer)}\b', code):#line
             c += 1
-    if n-c < k:
+    if n - c < k:
         return 1.0

-    score = 1 - (math.comb(n - c, k))/(math.comb(n, k))
+    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

     return score

-k = 3 #cdc@k
+
+k = 3  # cdc@k
 task = 'block'  # line or block
-json_name = f"Versicode_{task}_completion.json"
+json_name = f'Versicode_{task}_completion.json'

 folder_path = f'../data/result_data/{task}_completion'
 model_list = os.listdir(folder_path)

 for model in model_list:
     model_json_path = os.path.join(folder_path, model, json_name)
-    with open(model_json_path, 'r', encoding='utf-8')as fr:
+    with open(model_json_path, 'r', encoding='utf-8') as fr:
         lodict = json.load(fr)
     data_list = lodict

@@ -158,9 +193,15 @@ for model in model_list:
         for data in data_list:
             answer = data['core_token']
             model_output = eval(data['model_output_clear'])
-            model_filled_code = [data['masked_code'].replace('<mask>', i) for i in model_output]
+            model_filled_code = [
+                data['masked_code'].replace('<mask>', i) for i in model_output
+            ]
             core_line = data['core_line']
-            score_list.append(compute_line_score_k(answer, model_output, k, model_filled_code, core_line))
+            score_list.append(
+                compute_line_score_k(
+                    answer, model_output, k, model_filled_code, core_line
+                )
+            )
     else:
         score_list = []
         for data in data_list:
@@ -169,7 +210,16 @@ for model in model_list:
             model_filled_code = eval(data['model_output_clear'])
             core_line = data['core_line']
             core_line_in_output_clear = data['core_line_in_output_clear']
-            score_list.append(compute_block_score_k(answer, model_output, k, model_filled_code, core_line, core_line_in_output_clear))
+            score_list.append(
+                compute_block_score_k(
+                    answer,
+                    model_output,
+                    k,
+                    model_filled_code,
+                    core_line,
+                    core_line_in_output_clear,
+                )
+            )

-    final_score = sum(score_list)/len(score_list)
-    print(f"{model}, {task} completion task, cdc@{k} score: {final_score}")
+    final_score = sum(score_list) / len(score_list)
+    print(f'{model}, {task} completion task, cdc@{k} score: {final_score}')
@@ -1,21 +1,23 @@
 """
 Calculate the cdc score for line and block
 """
-import os
 import json
 import math
+import os
 import re
 import warnings

 # warnings.filterwarnings("ignore", category=SyntaxWarning)

-def is_code_valid(code):
+
+def is_code_valid(code):
     try:
         compile(code, '<string>', 'exec')
         return True
-    except:
+    except Exception:
         return False


 def is_correct_parameter_count(function_name, correct_code, test_code):
     """
     Check whether the number of parameters matches
@@ -47,6 +49,7 @@ def is_correct_parameter_count(function_name, correct_code, test_code):
         # If there are no parentheses, check whether the function name appears in the string
         return expected_count == 0 and function_name in test_code


 def check_keyword_parameters(function_name, correct_code, test_code):
     """
     Check whether keyword-argument assignment is used correctly
@@ -75,14 +78,18 @@ def check_keyword_parameters(function_name, correct_code, test_code):
         for correct_param in correct_param_list:
             if '=' in correct_param:  # only when the correct code uses a keyword argument
                 param_name = correct_param.split('=')[0].strip()
-                if not any(param_name in test_param and '=' in test_param for test_param in test_param_list):
+                if not any(
+                    param_name in test_param and '=' in test_param
+                    for test_param in test_param_list
+                ):
                     return False  # return False if the matching parameter is not passed as a keyword argument

         return True  # all keyword arguments match

     return False  # return False if nothing matches

-def with_correct(answer_code:str, model_output:str)->bool:
+
+def with_correct(answer_code: str, model_output: str) -> bool:
     """
     When the answer is a with-statement, check whether the model output is one as well
     :param answer_code:
@@ -97,59 +104,71 @@ def with_correct(answer_code:str, model_output:str)->bool:
     else:
         return False

-def compute_line_score_k(answer:str, model_output:list, k:int, model_filled_code, core_line):
+
+def compute_line_score_k(
+    answer: str, model_output: list, k: int, model_filled_code, core_line
+):
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
-        if re.search(rf'\b{re.escape(answer)}\b', code):#line
+        if re.search(rf'\b{re.escape(answer)}\b', code):  # line
            c += 1
-    if n-c < k:
+    if n - c < k:
         return 1.0

-    score = 1 - (math.comb(n - c, k))/(math.comb(n, k))
+    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

     return score

-def compute_block_score_k(answer:str, model_output:list, k:int, model_filled_code, core_line_in_core_block, core_line_in_output_clear):
+
+def compute_block_score_k(
+    answer: str,
+    model_output: list,
+    k: int,
+    model_filled_code,
+    core_line_in_core_block,
+    core_line_in_output_clear,
+):
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
-        if re.search(rf'\b{re.escape(answer)}\b', code):#block
+        if re.search(rf'\b{re.escape(answer)}\b', code):  # block
            c += 1
-    if n-c < k:
+    if n - c < k:
         return 1.0

-    score = 1 - (math.comb(n - c, k))/(math.comb(n, k))
+    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

     return score

-def compute_score_k(answer:str, model_output:list, k:int):
+
+def compute_score_k(answer: str, model_output: list, k: int):
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
-        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(code):#block
+        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(
+            code
+        ):  # block
             # if re.search(rf'\b{re.escape(answer)}\b', code):#line
             c += 1
-    if n-c < k:
+    if n - c < k:
         return 1.0

-    score = 1 - (math.comb(n - c, k))/(math.comb(n, k))
+    score = 1 - (math.comb(n - c, k)) / (math.comb(n, k))

     return score

-k = 3 #em@k
+
+k = 3  # em@k
 task = 'block'  # line or block
-json_name = f"Versicode_{task}_completion.json"
+json_name = f'Versicode_{task}_completion.json'

 folder_path = f'../data/result_data/{task}_completion'
 model_list = os.listdir(folder_path)

 for model in model_list:
     model_json_path = os.path.join(folder_path, model, json_name)
-    with open(model_json_path, 'r', encoding='utf-8')as fr:
+    with open(model_json_path, 'r', encoding='utf-8') as fr:
         lodict = json.load(fr)
     data_list = lodict

@@ -158,9 +177,15 @@ for model in model_list:
         for data in data_list:
             answer = data['core_token']
             model_output = eval(data['model_output_clear'])
-            model_filled_code = [data['masked_code'].replace('<mask>', i) for i in model_output]
+            model_filled_code = [
+                data['masked_code'].replace('<mask>', i) for i in model_output
+            ]
             core_line = data['core_line']
-            score_list.append(compute_line_score_k(answer, model_output, k, model_filled_code, core_line))
+            score_list.append(
+                compute_line_score_k(
+                    answer, model_output, k, model_filled_code, core_line
+                )
+            )
     else:
         score_list = []
         for data in data_list:
@@ -169,7 +194,16 @@ for model in model_list:
            model_filled_code = eval(data['model_output_clear'])
            core_line = data['core_line']
            core_line_in_output_clear = data['core_line_in_output_clear']
-            score_list.append(compute_block_score_k(answer, model_output, k, model_filled_code, core_line, core_line_in_output_clear))
+            score_list.append(
+                compute_block_score_k(
+                    answer,
+                    model_output,
+                    k,
+                    model_filled_code,
+                    core_line,
+                    core_line_in_output_clear,
+                )
+            )

-    final_score = sum(score_list)/len(score_list)
-    print(f"{model}, {task} completion task, em@{k} score: {final_score}")
+    final_score = sum(score_list) / len(score_list)
+    print(f'{model}, {task} completion task, em@{k} score: {final_score}')
@@ -1,48 +1,43 @@
 """
 Find the line of code generated by the model using the block in the version code
 """
-import os
-import re
-import json
-import random
+import json
+import os
+import random
+import re


 def process_line_mask(code_snippet, core_token):
     if not core_token:
         return None, None

     replaced_lines = {}
-    lines = code_snippet.split("\n")
+    lines = code_snippet.split('\n')

     in_multi_line_comment = False

     for i, line in enumerate(lines):
         if in_multi_line_comment:
-            if ('"""' in line or "'''" in line) and not re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
+            if ('"""' in line or "'''" in line) and not re.findall(
+                r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
+            ):
                 in_multi_line_comment = False
             continue
-        elif line.strip().startswith("#"):
+        elif line.strip().startswith('#'):
             continue
         elif re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
             continue
-        elif ('"""' in line or "'''" in line) and not re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
+        elif ('"""' in line or "'''" in line) and not re.findall(
+            r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
+        ):
             in_multi_line_comment = True
             continue
         else:
             if re.search(r'\bdef\s+task_function\b', line):
                 continue

             if re.search(r'\b{}\b(?!\s*=)'.format(re.escape(core_token)), line):
                 replaced_lines.update({i: line})

     if replaced_lines:
@@ -51,7 +46,7 @@ def process_line_mask(code_snippet, core_token):
         masked_line = lines[random_line_location]
         leading_spaces = re.match(r'^\s*', masked_line).group(0)
         masked_line = masked_line.strip()
-        lines[random_line_location] = leading_spaces + "<line_mask>"
+        lines[random_line_location] = leading_spaces + '<line_mask>'

         masked_code = '\n'.join(lines)

@@ -71,11 +66,9 @@ def save_json(file_path, data):
         json.dump(data, f, ensure_ascii=False, indent=4)


-
-if __name__ == "__main__":
+if __name__ == '__main__':
     model_list = os.listdir('../data/result_data/block_completion')
     for model in model_list:
-
         input_json_file = f'../data/result_data/block_completion/{model}/VersiCode_block_completion.json'
         output_json_file = input_json_file
         data = load_json(input_json_file)
@@ -88,7 +81,7 @@ if __name__ == "__main__":
             if core_line_in_code:
                 item['core_line_in_code'] = core_line_in_code
             else:
-                item['core_line_in_code'] = "N/A"
+                item['core_line_in_code'] = 'N/A'

             model_output_clear = item['model_output_clear']
             core_line_in_output_list = []
@@ -98,10 +91,9 @@ if __name__ == "__main__":
                 if core_line_in_output:
                     core_line_in_output_list.append(core_line_in_output)
                 else:
-                    core_line_in_output_list.append("N/A")
+                    core_line_in_output_list.append('N/A')

             item['core_line_in_output_clear'] = core_line_in_output_list

         save_json(output_json_file, data)
-    print("Done!")
-
+    print('Done!')
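The heart of process_line_mask is the lookahead regex reformatted above: a line counts as using the core token only if the token appears as a whole word and is not immediately followed by '=', so plain assignments to that name are skipped. A small illustration (the example strings are ours):

    import re

    core_token = 'generate'
    pattern = r'\b{}\b(?!\s*=)'.format(re.escape(core_token))

    print(bool(re.search(pattern, 'result = model.generate(prompt)')))  # True: real usage
    print(bool(re.search(pattern, 'generate = make_generator()')))      # False: assignment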
@@ -1,48 +1,43 @@
 """
 Find the line of code generated by the model using the block in the version code
 """
-import os
-import re
-import json
-import random
+import json
+import os
+import random
+import re


 def process_line_mask(code_snippet, core_token):
     if not core_token:
         return None, None

     replaced_lines = {}
-    lines = code_snippet.split("\n")
+    lines = code_snippet.split('\n')

     in_multi_line_comment = False

     for i, line in enumerate(lines):
         if in_multi_line_comment:
-            if ('"""' in line or "'''" in line) and not re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
+            if ('"""' in line or "'''" in line) and not re.findall(
+                r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
+            ):
                 in_multi_line_comment = False
             continue
-        elif line.strip().startswith("#"):
+        elif line.strip().startswith('#'):
             continue
         elif re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
             continue
-        elif ('"""' in line or "'''" in line) and not re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
+        elif ('"""' in line or "'''" in line) and not re.findall(
+            r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
+        ):
             in_multi_line_comment = True
             continue
         else:
             if re.search(r'\bdef\s+task_function\b', line):
                 continue

             if re.search(r'\b{}\b(?!\s*=)'.format(re.escape(core_token)), line):
                 replaced_lines.update({i: line})

     if replaced_lines:
@@ -51,7 +46,7 @@ def process_line_mask(code_snippet, core_token):
         masked_line = lines[random_line_location]
         leading_spaces = re.match(r'^\s*', masked_line).group(0)
         masked_line = masked_line.strip()
-        lines[random_line_location] = leading_spaces + "<line_mask>"
+        lines[random_line_location] = leading_spaces + '<line_mask>'

         masked_code = '\n'.join(lines)

@@ -71,12 +66,12 @@ def save_json(file_path, data):
         json.dump(data, f, ensure_ascii=False, indent=4)


-
-if __name__ == "__main__":
+if __name__ == '__main__':
     model_list = os.listdir('../data/result_data/code_migration')
     for model in model_list:
-
-        input_json_file = f'../data/result_data/code_migration/{model}/VersiCode_migration.json'
+        input_json_file = (
+            f'../data/result_data/code_migration/{model}/VersiCode_migration.json'
+        )
         output_json_file = input_json_file
         data = load_json(input_json_file)

@@ -88,7 +83,7 @@ if __name__ == "__main__":
             if core_line_in_code:
                 item['core_line_in_code'] = core_line_in_code
             else:
-                item['core_line_in_code'] = "N/A"
+                item['core_line_in_code'] = 'N/A'

             model_output_clear = item['model_output_clear']
             core_line_in_output_list = []
@@ -99,10 +94,9 @@ if __name__ == "__main__":
                 if core_line_in_output:
                     core_line_in_output_list.append(core_line_in_output)
                 else:
-                    core_line_in_output_list.append("N/A")
+                    core_line_in_output_list.append('N/A')

             item['core_line_in_output_clear'] = core_line_in_output_list

         save_json(output_json_file, data)
-    print("Done!")
-
+    print('Done!')
@@ -3,15 +3,14 @@ Clear the<start>and<end>generated by the model in inference
"""

import json
import os

model_name = ''
task = 'block_completion'

result_path = f'../data/result_data/{task}/{model_name}/VersiCode_block_completion.json' #Modify the file according to the task format
result_path = f'../data/result_data/{task}/{model_name}/VersiCode_block_completion.json' # Modify the file according to the task format

with open(result_path, 'r', encoding='utf-8')as fr:
with open(result_path, 'r', encoding='utf-8') as fr:
lodict = json.load(fr)
data_dict = lodict
data_list = data_dict
@@ -20,17 +19,20 @@ for data in data_list:
temp_list = []
model_output_list = eval(data['model_output'])
for output in model_output_list:

if "<start>" in output and "<end>" in output:
start_index = output.find("<start>") + len("<start>")
end_index = output.find("<end>")
content = output[start_index:end_index].replace('```python', '').replace('```', '')
if '<start>' in output and '<end>' in output:
start_index = output.find('<start>') + len('<start>')
end_index = output.find('<end>')
content = (
output[start_index:end_index]
.replace('```python', '')
.replace('```', '')
)
else:
content = "no_answer"
content = 'no_answer'

temp_list.append(content)

data['model_output_clear'] = str(temp_list)

with open(result_path, 'w', encoding='utf-8')as fw:
with open(result_path, 'w', encoding='utf-8') as fw:
json.dump(data_dict, fw, indent=4, ensure_ascii=False)
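
The hunk above extracts whatever the model emitted between <start> and <end> markers, strips Markdown fences, and falls back to a sentinel when the markers are missing. A standalone sketch of the same idea (the function name and sample input are illustrative):

def extract_between(output: str, start: str = '<start>', end: str = '<end>') -> str:
    # Return the text between the first start/end marker pair, with Markdown
    # code fences stripped; return 'no_answer' when the markers are absent.
    if start in output and end in output:
        begin = output.find(start) + len(start)
        stop = output.find(end)
        return output[begin:stop].replace('```python', '').replace('```', '')
    return 'no_answer'

print(extract_between('<start>```python\nx = 1\n```<end>'))  # '\nx = 1\n'
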
@@ -5,24 +5,23 @@
* Mock Service Worker.
* @see https://github.com/mswjs/msw
* - Please do NOT modify this file.
* - Please do NOT serve this file on production.
*/

const PACKAGE_VERSION = '2.8.4'
const INTEGRITY_CHECKSUM = '00729d72e3b82faf54ca8b9621dbb96f'
const PACKAGE_VERSION = '2.10.2'
const INTEGRITY_CHECKSUM = 'f5825c521429caf22a4dd13b66e243af'
const IS_MOCKED_RESPONSE = Symbol('isMockedResponse')
const activeClientIds = new Set()

self.addEventListener('install', function () {
addEventListener('install', function () {
self.skipWaiting()
})

self.addEventListener('activate', function (event) {
addEventListener('activate', function (event) {
event.waitUntil(self.clients.claim())
})

self.addEventListener('message', async function (event) {
const clientId = event.source.id
addEventListener('message', async function (event) {
const clientId = Reflect.get(event.source || {}, 'id')

if (!clientId || !self.clients) {
return
@@ -94,17 +93,18 @@ self.addEventListener('message', async function (event) {
}
})

self.addEventListener('fetch', function (event) {
const { request } = event

addEventListener('fetch', function (event) {
// Bypass navigation requests.
if (request.mode === 'navigate') {
if (event.request.mode === 'navigate') {
return
}

// Opening the DevTools triggers the "only-if-cached" request
// that cannot be handled by the worker. Bypass such requests.
if (request.cache === 'only-if-cached' && request.mode !== 'same-origin') {
if (
event.request.cache === 'only-if-cached' &&
event.request.mode !== 'same-origin'
) {
return
}

@@ -115,20 +115,26 @@ self.addEventListener('fetch', function (event) {
return
}

// Generate unique request ID.
const requestId = crypto.randomUUID()
event.respondWith(handleRequest(event, requestId))
})

/**
* @param {FetchEvent} event
* @param {string} requestId
*/
async function handleRequest(event, requestId) {
const client = await resolveMainClient(event)
const requestCloneForEvents = event.request.clone()
const response = await getResponse(event, client, requestId)

// Send back the response clone for the "response:*" life-cycle events.
// Ensure MSW is active and ready to handle the message, otherwise
// this message will pend indefinitely.
if (client && activeClientIds.has(client.id)) {
;(async function () {
const serializedRequest = await serializeRequest(requestCloneForEvents)

// Clone the response so both the client and the library could consume it.
const responseClone = response.clone()

sendToClient(
@@ -136,27 +142,35 @@ async function handleRequest(event, requestId) {
{
type: 'RESPONSE',
payload: {
requestId,
isMockedResponse: IS_MOCKED_RESPONSE in response,
request: {
id: requestId,
...serializedRequest,
},
response: {
type: responseClone.type,
status: responseClone.status,
statusText: responseClone.statusText,
body: responseClone.body,
headers: Object.fromEntries(responseClone.headers.entries()),
body: responseClone.body,
},
},
[responseClone.body],
},
responseClone.body ? [serializedRequest.body, responseClone.body] : [],
)
})()
}

return response
}

// Resolve the main client for the given event.
// Client that issues a request doesn't necessarily equal the client
// that registered the worker. It's with the latter the worker should
// communicate with during the response resolving phase.
/**
* Resolve the main client for the given event.
* Client that issues a request doesn't necessarily equal the client
* that registered the worker. It's with the latter the worker should
* communicate with during the response resolving phase.
* @param {FetchEvent} event
* @returns {Promise<Client | undefined>}
*/
async function resolveMainClient(event) {
const client = await self.clients.get(event.clientId)

@@ -184,12 +198,16 @@ async function resolveMainClient(event) {
})
}

/**
* @param {FetchEvent} event
* @param {Client | undefined} client
* @param {string} requestId
* @returns {Promise<Response>}
*/
async function getResponse(event, client, requestId) {
const { request } = event

// Clone the request because it might've been already used
// (i.e. its body has been read and sent to the client).
const requestClone = request.clone()
const requestClone = event.request.clone()

function passthrough() {
// Cast the request headers to a new Headers instance
@@ -230,29 +248,17 @@ async function getResponse(event, client, requestId) {
}

// Notify the client that a request has been intercepted.
const requestBuffer = await request.arrayBuffer()
const serializedRequest = await serializeRequest(event.request)
const clientMessage = await sendToClient(
client,
{
type: 'REQUEST',
payload: {
id: requestId,
url: request.url,
mode: request.mode,
method: request.method,
headers: Object.fromEntries(request.headers.entries()),
cache: request.cache,
credentials: request.credentials,
destination: request.destination,
integrity: request.integrity,
redirect: request.redirect,
referrer: request.referrer,
referrerPolicy: request.referrerPolicy,
body: requestBuffer,
keepalive: request.keepalive,
...serializedRequest,
},
},
[requestBuffer],
[serializedRequest.body],
)

switch (clientMessage.type) {
@@ -268,6 +274,12 @@ async function getResponse(event, client, requestId) {
return passthrough()
}

/**
* @param {Client} client
* @param {any} message
* @param {Array<Transferable>} transferrables
* @returns {Promise<any>}
*/
function sendToClient(client, message, transferrables = []) {
return new Promise((resolve, reject) => {
const channel = new MessageChannel()
@@ -280,14 +292,18 @@ function sendToClient(client, message, transferrables = []) {
resolve(event.data)
}

client.postMessage(
message,
[channel.port2].concat(transferrables.filter(Boolean)),
)
client.postMessage(message, [
channel.port2,
...transferrables.filter(Boolean),
])
})
}

async function respondWithMock(response) {
/**
* @param {Response} response
* @returns {Response}
*/
function respondWithMock(response) {
// Setting response status code to 0 is a no-op.
// However, when responding with a "Response.error()", the produced Response
// instance will have status code set to 0. Since it's not possible to create
@@ -305,3 +321,24 @@ async function respondWithMock(response) {

return mockedResponse
}

/**
* @param {Request} request
*/
async function serializeRequest(request) {
return {
url: request.url,
mode: request.mode,
method: request.method,
headers: Object.fromEntries(request.headers.entries()),
cache: request.cache,
credentials: request.credentials,
destination: request.destination,
integrity: request.integrity,
redirect: request.redirect,
referrer: request.referrer,
referrerPolicy: request.referrerPolicy,
body: await request.arrayBuffer(),
keepalive: request.keepalive,
}
}

@@ -208,7 +208,9 @@ Note:
# for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
# initialize and retrieve the first observation by issuing an noop OP
# For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites
return BrowseInteractiveAction(browser_actions='noop(1000)', return_axtree=True)
return BrowseInteractiveAction(
browser_actions='noop(1000)', return_axtree=True
)

for event in state.view:
if isinstance(event, BrowseInteractiveAction):

@@ -54,6 +54,7 @@ class MCPStdioServerConfig(BaseModel):
and set(self.env.items()) == set(other.env.items())
)

class MCPSHTTPServerConfig(BaseModel):
url: str
api_key: str | None = None

@@ -39,6 +39,7 @@ class GitHubService(BaseGitService, GitService):

The class is instantiated via get_impl() in openhands.server.shared.py.
"""

BASE_URL = 'https://api.github.com'
token: SecretStr = SecretStr('')
refresh = False
@@ -508,7 +509,6 @@ class GitHubService(BaseGitService, GitService):
return response['html_url']

github_service_cls = os.environ.get(
'OPENHANDS_GITHUB_SERVICE_CLS',
'openhands.integrations.github.github_service.GitHubService',

@@ -32,6 +32,7 @@ class GitLabService(BaseGitService, GitService):

The class is instantiated via get_impl() in openhands.server.shared.py.
"""

BASE_URL = 'https://gitlab.com/api/v4'
GRAPHQL_URL = 'https://gitlab.com/api/graphql'
token: SecretStr = SecretStr('')
@@ -482,9 +483,7 @@ class GitLabService(BaseGitService, GitService):

# Set default description if none provided
if not description:
description = (
f'Merging changes from {source_branch} into {target_branch}'
)
description = f'Merging changes from {source_branch} into {target_branch}'

# Prepare the request payload
payload = {
@@ -499,11 +498,9 @@ class GitLabService(BaseGitService, GitService):
url=url, params=payload, method=RequestMethod.POST
)

return response['web_url']

gitlab_service_cls = os.environ.get(
'OPENHANDS_GITLAB_SERVICE_CLS',
'openhands.integrations.gitlab.gitlab_service.GitLabService',

@@ -1,7 +1,7 @@
import os
import typing
from functools import lru_cache
from typing import Callable
import typing
from uuid import UUID

import docker
@@ -283,7 +283,9 @@ class DockerRuntime(ActionExecutionClient):
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'

use_host_network = self.config.sandbox.use_host_network
network_mode: typing.Literal['host'] | None = 'host' if use_host_network else None
network_mode: typing.Literal['host'] | None = (
'host' if use_host_network else None
)

# Initialize port mappings
port_mapping: dict[str, list[dict[str, str]]] | None = None
@@ -356,7 +358,7 @@ class DockerRuntime(ActionExecutionClient):

try:
if self.runtime_container_image is None:
raise ValueError("Runtime container image is not set")
raise ValueError('Runtime container image is not set')
self.container = self.docker_client.containers.run(
self.runtime_container_image,
command=command,

@@ -363,7 +363,7 @@ class RemoteRuntime(ActionExecutionClient):
self._session_api_key = start_response['session_api_key']
self.log(
'debug',
f'Session API key setted',
'Session API key set',
)

@property

@@ -59,7 +59,7 @@ class MCPProxyManager:
"""
if len(self.config['mcpServers']) == 0:
logger.info(
f"No MCP servers configured for FastMCP Proxy, skipping initialization."
'No MCP servers configured for FastMCP Proxy, skipping initialization.'
)
return None

@@ -70,7 +70,7 @@ class MCPProxyManager:
api_key=self.api_key,
)

logger.info(f"FastMCP Proxy initialized successfully")
logger.info('FastMCP Proxy initialized successfully')

async def mount_to_app(
self, app: FastAPI, allow_origins: Optional[list[str]] = None
@@ -83,9 +83,7 @@ class MCPProxyManager:
allow_origins: List of allowed origins for CORS
"""
if len(self.config['mcpServers']) == 0:
logger.info(
f"No MCP servers configured for FastMCP Proxy, skipping mount."
)
logger.info('No MCP servers configured for FastMCP Proxy, skipping mount.')
return

if not self.proxy:
@@ -101,8 +99,7 @@ class MCPProxyManager:
app.routes.remove('/mcp')

app.mount('/', mcp_app)
logger.info(f"Mounted FastMCP Proxy app at /mcp")

logger.info('Mounted FastMCP Proxy app at /mcp')

async def update_and_remount(
self,
@@ -122,10 +119,7 @@ class MCPProxyManager:
tools: List of tool configurations
allow_origins: List of allowed origins for CORS
"""
tools = {
t.name: t.model_dump()
for t in stdio_servers
}
tools = {t.name: t.model_dump() for t in stdio_servers}
self.config['mcpServers'] = tools

del self.proxy

@@ -48,8 +48,8 @@ def _extract_code(string: str) -> str | None:

content = str(matches[0])
if content.startswith('#EDIT:'):
#Remove first line
content = content[content.find('\n') + 1:]
# Remove first line
content = content[content.find('\n') + 1 :]
return content

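The `_extract_code` hunk above strips a leading `#EDIT:` marker line before returning the extracted snippet. A self-contained sketch of that first-line removal (the helper name and sample input are illustrative):

def strip_edit_header(content: str) -> str:
    # Drop the first line when the snippet begins with an #EDIT: marker.
    if content.startswith('#EDIT:'):
        content = content[content.find('\n') + 1 :]
    return content

print(strip_edit_header('#EDIT: fix typo\nx = 1'))  # x = 1
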
@@ -1,7 +1,7 @@
import random
import socket
import time
from openhands.core.logger import openhands_logger as logger

def check_port_available(port: int) -> bool:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

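`check_port_available` above is cut off by the hunk boundary; a common way such a check is completed is to attempt a bind and treat failure as "port in use". A hedged sketch under that assumption (not necessarily the repo's exact logic):

import socket

def check_port_available(port: int) -> bool:
    # Try to bind the port on all interfaces; a failed bind means it is taken.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(('0.0.0.0', port))
        return True
    except OSError:
        return False
    finally:
        sock.close()
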
@@ -53,7 +53,7 @@ def load_server_config() -> ServerConfig:
logger.info(f'Using config class {config_cls}')

server_config_cls = get_impl(ServerConfig, config_cls)
server_config : ServerConfig = server_config_cls()
server_config: ServerConfig = server_config_cls()
server_config.verify_config()

return server_config

@@ -15,7 +15,7 @@ from openhands.events.stream import EventStreamSubscriber, session_exists
from openhands.server.config.server_config import ServerConfig
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
from openhands.server.monitoring import MonitoringListener
from openhands.server.session.agent_session import AgentSession, WAIT_TIME_BEFORE_CLOSE
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE, AgentSession
from openhands.server.session.conversation import ServerConversation
from openhands.server.session.session import ROOM_KEY, Session
from openhands.storage.conversation.conversation_store import ConversationStore
@@ -508,7 +508,9 @@ class StandaloneConversationManager(ConversationManager):
session_api_key=None,
event_store=session.agent_session.event_stream,
status=_get_status_from_session(session),
runtime_status=getattr(session.agent_session.runtime, 'runtime_status', None),
runtime_status=getattr(
session.agent_session.runtime, 'runtime_status', None
),
)

def _get_conversation_url(self, conversation_id: str):

@@ -1,7 +1,6 @@
from dataclasses import dataclass, field
from datetime import datetime, timezone

from openhands.core.schema.agent import AgentState
from openhands.integrations.service_types import ProviderType
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.storage.data_models.conversation_metadata import ConversationTrigger

@@ -5,13 +5,13 @@ from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.events.event_filter import EventFilter
from openhands.events.serialization.event import event_to_dict
from openhands.memory.memory import Memory
from openhands.microagent.types import InputMetadata
from openhands.runtime.base import Runtime
from openhands.server.dependencies import get_dependencies
from openhands.server.session.conversation import ServerConversation
from openhands.server.shared import conversation_manager
from openhands.server.utils import get_conversation
from openhands.microagent.types import InputMetadata
from openhands.memory.memory import Memory

app = APIRouter(
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
@@ -216,7 +216,11 @@ async def get_microagents(
content=agent.content,
triggers=[],
inputs=agent.metadata.inputs,
tools=[server.name for server in agent.metadata.mcp_tools.stdio_servers] if agent.metadata.mcp_tools else [],
tools=[
server.name for server in agent.metadata.mcp_tools.stdio_servers
]
if agent.metadata.mcp_tools
else [],
)
)

@@ -229,7 +233,11 @@ async def get_microagents(
content=agent.content,
triggers=agent.triggers,
inputs=agent.metadata.inputs,
tools=[server.name for server in agent.metadata.mcp_tools.stdio_servers] if agent.metadata.mcp_tools else [],
tools=[
server.name for server in agent.metadata.mcp_tools.stdio_servers
]
if agent.metadata.mcp_tools
else [],
)
)

@@ -6,15 +6,19 @@ from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
from openhands.events.serialization import event_to_dict
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
from openhands.server.dependencies import get_dependencies
from openhands.server.session.conversation import ServerConversation
from openhands.server.utils import get_conversation
from openhands.utils.async_utils import call_sync_from_async
from openhands.server.session.conversation import ServerConversation

app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
app = APIRouter(
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
)

@app.post('/submit-feedback')
async def submit_feedback(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
async def submit_feedback(
request: Request, conversation: ServerConversation = Depends(get_conversation)
) -> JSONResponse:
"""Submit user feedback.

This function stores the provided feedback data.
@@ -37,9 +41,7 @@ async def submit_feedback(request: Request, conversation: ServerConversation = D
# Assuming the storage service is already configured in the backend
# and there is a function to handle the storage.
body = await request.json()
async_store = AsyncEventStoreWrapper(
conversation.event_stream, filter_hidden=True
)
async_store = AsyncEventStoreWrapper(conversation.event_stream, filter_hidden=True)
trajectory = []
async for event in async_store:
trajectory.append(event_to_dict(event))

@@ -5,7 +5,6 @@ from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
status,
)
from fastapi.responses import FileResponse, JSONResponse
@@ -27,17 +26,15 @@ from openhands.server.dependencies import get_dependencies
from openhands.server.file_config import (
FILES_TO_IGNORE,
)
from openhands.server.shared import (
ConversationStoreImpl,
config,
)
from openhands.server.session.conversation import ServerConversation
from openhands.server.user_auth import get_user_id
from openhands.server.utils import get_conversation, get_conversation_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.server.session.conversation import ServerConversation

app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
app = APIRouter(
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
)

@app.get(
@@ -50,7 +47,7 @@ app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_
)
async def list_files(
conversation: ServerConversation = Depends(get_conversation),
path: str | None = None
path: str | None = None,
) -> list[str] | JSONResponse:
"""List files in the specified path.

@@ -132,7 +129,9 @@ async def list_files(
415: {'description': 'Unsupported media type', 'model': dict},
},
)
async def select_file(file: str, conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse:
async def select_file(
file: str, conversation: ServerConversation = Depends(get_conversation)
) -> FileResponse | JSONResponse:
"""Retrieve the content of a specified file.

To select a file:
@@ -196,7 +195,9 @@ async def select_file(file: str, conversation: ServerConversation = Depends(get_
500: {'description': 'Error zipping workspace', 'model': dict},
},
)
def zip_current_workspace(conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse:
def zip_current_workspace(
conversation: ServerConversation = Depends(get_conversation),
) -> FileResponse | JSONResponse:
try:
logger.debug('Zipping workspace')
runtime: Runtime = conversation.runtime

@@ -1,6 +1,6 @@
import itertools
import re
import os
import re
import uuid
from datetime import datetime, timezone

@@ -9,19 +9,18 @@ from fastapi.responses import JSONResponse
from jinja2 import Environment, FileSystemLoader
from pydantic import BaseModel, Field

from openhands.events.event_filter import EventFilter
from openhands.events.stream import EventStream
from openhands.core.config.llm_config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
ChangeAgentStateAction,
NullAction,
)
from openhands.events.event_filter import EventFilter
from openhands.events.observation import (
NullObservation,
AgentStateChangedObservation,
NullObservation,
)

from openhands.core.config.llm_config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import EventStream
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderHandler,
@@ -38,10 +37,9 @@ from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
)
from openhands.server.services.conversation_service import create_new_conversation
from openhands.server.session.conversation import ServerConversation
from openhands.server.dependencies import get_dependencies
from openhands.server.services.conversation_service import create_new_conversation
from openhands.server.session.conversation import ServerConversation
from openhands.server.shared import (
ConversationStoreImpl,
config,
@@ -53,11 +51,12 @@ from openhands.server.user_auth import (
get_provider_tokens,
get_user_id,
get_user_secrets,
get_user_settings_store,
get_user_settings,
get_user_settings_store,
)
from openhands.server.user_auth.user_auth import AuthType
from openhands.server.utils import get_conversation_store, get_conversation as get_conversation_object
from openhands.server.utils import get_conversation as get_conversation_object
from openhands.server.utils import get_conversation_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import (
ConversationMetadata,
@@ -295,7 +294,7 @@ async def delete_conversation(
async def get_prompt(
event_id: int,
user_settings: SettingsStore = Depends(get_user_settings_store),
conversation: ServerConversation | None = Depends(get_conversation_object)
conversation: ServerConversation | None = Depends(get_conversation_object),
):
if conversation is None:
return JSONResponse(
@@ -409,7 +408,6 @@ async def start_conversation(
logger.info(f'Starting conversation: {conversation_id}')

try:

# Check that the conversation exists
try:
await conversation_store.get_metadata(conversation_id)
@@ -463,10 +461,17 @@ async def stop_conversation(

try:
# Check if the conversation is running
agent_loop_info = await conversation_manager.get_agent_loop_info(user_id=user_id, filter_to_sids={conversation_id})
conversation_status = agent_loop_info[0].status if agent_loop_info else ConversationStatus.STOPPED
agent_loop_info = await conversation_manager.get_agent_loop_info(
user_id=user_id, filter_to_sids={conversation_id}
)
conversation_status = (
agent_loop_info[0].status if agent_loop_info else ConversationStatus.STOPPED
)

if conversation_status not in (ConversationStatus.STARTING, ConversationStatus.RUNNING):
if conversation_status not in (
ConversationStatus.STARTING,
ConversationStatus.RUNNING,
):
return ConversationResponse(
status='ok',
conversation_id=conversation_id,
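
The stop_conversation hunk above only proceeds when the conversation is STARTING or RUNNING. A minimal sketch of that guard (the enum member names mirror the diff; the enum values and function are illustrative):

from enum import Enum

class ConversationStatus(Enum):
    STARTING = 'STARTING'
    RUNNING = 'RUNNING'
    STOPPED = 'STOPPED'

def needs_stop(status: ConversationStatus) -> bool:
    # Only conversations that are starting or running should be stopped.
    return status in (ConversationStatus.STARTING, ConversationStatus.RUNNING)

assert needs_stop(ConversationStatus.RUNNING)
assert not needs_stop(ConversationStatus.STOPPED)
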
@@ -505,7 +510,11 @@ def _get_contextual_events(event_stream: EventStream, event_id: int) -> str:

agent_event_filter = EventFilter(
exclude_hidden=True,
exclude_types=(NullAction, NullObservation, ChangeAgentStateAction, AgentStateChangedObservation
exclude_types=(
NullAction,
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
),
) # the types of events that can be in an agent's history

@@ -87,7 +87,7 @@ async def create_pr(
target_branch: Annotated[str, Field(description='Target branch on repo')],
title: Annotated[str, Field(description='PR Title')],
body: Annotated[str | None, Field(description='PR body')],
draft: Annotated[bool, Field(description='Whether PR opened is a draft')] = True
draft: Annotated[bool, Field(description='Whether PR opened is a draft')] = True,
) -> str:
"""Open a PR in GitHub"""

@@ -127,7 +127,7 @@ async def create_pr(
target_branch=target_branch,
title=title,
body=body,
draft=draft
draft=draft,
)

if conversation_id:
@@ -148,7 +148,12 @@ async def create_mr(
],
source_branch: Annotated[str, Field(description='Source branch on repo')],
target_branch: Annotated[str, Field(description='Target branch on repo')],
title: Annotated[str, Field(description='MR Title. Start title with `DRAFT:` or `WIP:` if applicable.')],
title: Annotated[
str,
Field(
description='MR Title. Start title with `DRAFT:` or `WIP:` if applicable.'
),
],
description: Annotated[str | None, Field(description='MR description')],
) -> str:
"""Open a MR in GitLab"""

@@ -8,14 +8,18 @@ from fastapi import (
)

from openhands.server.dependencies import get_dependencies
from openhands.server.utils import get_conversation
from openhands.server.session.conversation import ServerConversation
from openhands.server.utils import get_conversation

app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
app = APIRouter(
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
)

@app.route('/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
async def security_api(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> Response:
async def security_api(
request: Request, conversation: ServerConversation = Depends(get_conversation)
) -> Response:
"""Catch-all route for security analyzer API requests.

Each request is handled directly to the security analyzer.
@@ -35,6 +39,4 @@ async def security_api(request: Request, conversation: ServerConversation = Depe
detail='Security analyzer not initialized',
)

return await conversation.security_analyzer.handle_api_request(
request
)
return await conversation.security_analyzer.handle_api_request(request)

@@ -1,18 +1,22 @@
from fastapi import APIRouter, Depends, Request, status
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse

from openhands.core.logger import openhands_logger as logger
from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
from openhands.events.serialization import event_to_trajectory
from openhands.server.dependencies import get_dependencies
from openhands.server.utils import get_conversation
from openhands.server.session.conversation import ServerConversation
from openhands.server.utils import get_conversation

app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
app = APIRouter(
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
)

@app.get('/trajectory')
async def get_trajectory(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
async def get_trajectory(
conversation: ServerConversation = Depends(get_conversation),
) -> JSONResponse:
"""Get trajectory.

This function retrieves the current trajectory and returns it.

@@ -1,4 +1,3 @@
import os
import uuid
from typing import Any

@@ -80,7 +79,6 @@ async def create_new_conversation(
session_init_args['conversation_instructions'] = conversation_instructions
conversation_init_data = ConversationInitData(**session_init_args)

logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
logger.info('ServerConversation store loaded')
@@ -90,13 +88,14 @@ async def create_new_conversation(
conversation_id = uuid.uuid4().hex

if not await conversation_store.exists(conversation_id):

logger.info(
f'New conversation ID: {conversation_id}',
extra={'user_id': user_id, 'session_id': conversation_id},
)

conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(user_id, conversation_id, conversation_init_data)
conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(
user_id, conversation_id, conversation_init_data
)
conversation_title = get_default_conversation_title(conversation_id)

logger.info(f'Saving metadata for conversation {conversation_id}')

@@ -197,23 +197,21 @@ class AgentSession:
finally:
self._starting = False
success = finished and runtime_connected
duration = (time.time() - started_at)
duration = time.time() - started_at

log_metadata = {
'signal': 'agent_session_start',
'success': success,
'duration': duration,
'restored_state': restored_state
'restored_state': restored_state,
}
if success:
self.logger.info(
f'Agent session start succeeded in {duration}s',
extra=log_metadata
f'Agent session start succeeded in {duration}s', extra=log_metadata
)
else:
self.logger.error(
f'Agent session start failed in {duration}s',
extra=log_metadata
f'Agent session start failed in {duration}s', extra=log_metadata
)

async def close(self) -> None:

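The AgentSession hunk above times session start and attaches a metadata dict to the success/failure log line. A compact sketch of the same pattern with the standard logging module (the logger setup and the callable are illustrative):

import logging
import time

logger = logging.getLogger('agent_session')

def timed_start(start_fn) -> bool:
    # Run start_fn, measure its duration, and log the outcome with metadata.
    started_at = time.time()
    success = False
    try:
        success = start_fn()
    finally:
        duration = time.time() - started_at
        log_metadata = {'signal': 'agent_session_start', 'success': success, 'duration': duration}
        if success:
            logger.info('Agent session start succeeded in %ss', duration, extra=log_metadata)
        else:
            logger.error('Agent session start failed in %ss', duration, extra=log_metadata)
    return success
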
@@ -105,7 +105,12 @@ class FileConversationStore(ConversationStore):
async def get_instance(
cls, config: OpenHandsConfig, user_id: str | None
) -> FileConversationStore:
file_store = get_file_store(config.file_store, config.file_store_path, config.file_store_web_hook_url, config.file_store_web_hook_headers)
file_store = get_file_store(
config.file_store,
config.file_store_path,
config.file_store_web_hook_url,
config.file_store_web_hook_headers,
)
return FileConversationStore(file_store)

@@ -43,6 +43,6 @@ class FileSecretsStore(SecretsStore):
config.file_store,
config.file_store_path,
config.file_store_web_hook_url,
config.file_store_web_hook_headers
config.file_store_web_hook_headers,
)
return FileSecretsStore(file_store)

@@ -37,6 +37,6 @@ class FileSettingsStore(SettingsStore):
config.file_store,
config.file_store_path,
config.file_store_web_hook_url,
config.file_store_web_hook_headers
config.file_store_web_hook_headers,
)
return FileSettingsStore(file_store)

@@ -10,24 +10,36 @@ class TestTranslationCompleteness(unittest.TestCase):

def test_translation_completeness_check_runs(self):
"""Test that the translation completeness check script can be executed."""
frontend_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "frontend")
script_path = os.path.join(frontend_dir, "scripts", "check-translation-completeness.cjs")
frontend_dir = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
),
'frontend',
)
script_path = os.path.join(
frontend_dir, 'scripts', 'check-translation-completeness.cjs'
)

# Verify the script exists
self.assertTrue(os.path.exists(script_path), f"Script not found at {script_path}")
self.assertTrue(
os.path.exists(script_path), f'Script not found at {script_path}'
)

# Verify the script is executable
self.assertTrue(os.access(script_path, os.X_OK), f"Script at {script_path} is not executable")
self.assertTrue(
os.access(script_path, os.X_OK),
f'Script at {script_path} is not executable',
)

# Run the script (it may fail due to missing translations, but we just want to verify it runs)
try:
subprocess.run(
["node", script_path],
['node', script_path],
cwd=frontend_dir,
check=False,
capture_output=True,
text=True
text=True,
)
# We don't assert on the return code because it might fail due to missing translations
except Exception as e:
self.fail(f"Failed to run translation completeness check: {e}")
self.fail(f'Failed to run translation completeness check: {e}')

@@ -100,7 +100,6 @@ def mock_conversation_instructions_template():
return 'Instructions: {{ repo_instruction }}'

@pytest.fixture
def mock_followup_prompt_template():
return 'Issue context: {{ issues }}\n\nReview comments: {{ review_comments }}\n\nReview threads: {{ review_threads }}\n\nFiles: {{ files }}\n\nThread comments: {{ thread_context }}\n\nPlease fix this issue.'
@@ -532,7 +531,11 @@ async def test_process_issue(
handler_instance.guess_success.assert_not_called()

def test_get_instruction(mock_user_instructions_template, mock_conversation_instructions_template, mock_followup_prompt_template):
def test_get_instruction(
mock_user_instructions_template,
mock_conversation_instructions_template,
mock_followup_prompt_template,
):
issue = Issue(
owner='test_owner',
repo='test_repo',
@@ -545,7 +548,10 @@ def test_get_instruction(mock_user_instructions_template, mock_conversation_inst
GithubIssueHandler('owner', 'repo', 'token'), mock_llm_config
)
instruction, conversation_instructions, images_urls = issue_handler.get_instruction(
issue, mock_user_instructions_template, mock_conversation_instructions_template, None
issue,
mock_user_instructions_template,
mock_conversation_instructions_template,
None,
)
expected_instruction = 'Issue: Test Issue\n\nThis is a test issue refer to image \n\nPlease fix this issue.'

@@ -576,7 +582,10 @@ def test_get_instruction(mock_user_instructions_template, mock_conversation_inst
GithubPRHandler('owner', 'repo', 'token'), mock_llm_config
)
instruction, conversation_instructions, images_urls = pr_handler.get_instruction(
issue, mock_followup_prompt_template, mock_conversation_instructions_template, None
issue,
mock_followup_prompt_template,
mock_conversation_instructions_template,
None,
)
expected_instruction = "Issue context: [\n \"Issue 1 fix the type\"\n]\n\nReview comments: None\n\nReview threads: [\n \"There is still a typo 'pthon' instead of 'python'\"\n]\n\nFiles: []\n\nThread comments: I've left review comments, please address them\n---\nThis is a valid concern.\n\nPlease fix this issue."

@@ -601,7 +610,9 @@ def test_file_instruction():
with open('openhands/resolver/prompts/resolve/basic.jinja', 'r') as f:
prompt = f.read()

with open('openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r') as f:
with open(
'openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r'
) as f:
conversation_instructions_template = f.read()

# Test without thread comments
@@ -610,7 +621,7 @@ def test_file_instruction():
GithubIssueHandler('owner', 'repo', 'token'), mock_llm_config
)
instruction, conversation_instructions, images_urls = issue_handler.get_instruction(
issue, prompt,conversation_instructions_template, None
issue, prompt, conversation_instructions_template, None
)
expected_instruction = """Please fix the following issue for the repository in /workspace.
An environment has been set up for you to start working. You may assume all necessary tools are installed.
@@ -620,7 +631,6 @@ Test Issue

This is a test issue """

expected_conversation_instructions = """IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.
You SHOULD INCLUDE PROPER INDENTATION in your edit commands.

@@ -644,7 +654,9 @@ def test_file_instruction_with_repo_instruction():
with open('openhands/resolver/prompts/resolve/basic.jinja', 'r') as f:
prompt = f.read()

with open('openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r') as f:
with open(
'openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r'
) as f:
conversation_instructions_prompt = f.read()

# load repo instruction from openhands/resolver/prompts/repo_instructions/all-hands-ai___openhands-resolver.txt
@@ -662,7 +674,6 @@ def test_file_instruction_with_repo_instruction():
issue, prompt, conversation_instructions_prompt, repo_instruction
)

expected_instruction = """Please fix the following issue for the repository in /workspace.
An environment has been set up for you to start working. You may assume all necessary tools are installed.

@@ -683,7 +694,6 @@ This is a Python repo for openhands-resolver, a library that attempts to resolve

When you think you have fixed the issue through code changes, please finish the interaction."""

assert instruction == expected_instruction
assert conversation_instructions == expected_conversation_instructions
assert conversation_instructions is not None
@@ -785,7 +795,9 @@ def test_instruction_with_thread_comments():
with open('openhands/resolver/prompts/resolve/basic.jinja', 'r') as f:
prompt = f.read()

with open('openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r') as f:
with open(
'openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r'
) as f:
conversation_instructions_template = f.read()

llm_config = LLMConfig(model='test', api_key='test')

@@ -1,6 +1,3 @@
from typing import Type
from unittest.mock import MagicMock

import pytest
from pydantic import SecretStr

@@ -8,11 +5,11 @@ from openhands.core.config import LLMConfig
from openhands.integrations.provider import ProviderType
from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler
from openhands.resolver.issue_handler_factory import IssueHandlerFactory
from openhands.resolver.interfaces.issue_definitions import (
ServiceContextIssue,
ServiceContextPR,
)
from openhands.resolver.issue_handler_factory import IssueHandlerFactory

@pytest.fixture
@@ -45,20 +42,17 @@ test_cases = [

@pytest.mark.parametrize(
'platform,issue_type,expected_context_type,expected_handler_type',
test_cases
'platform,issue_type,expected_context_type,expected_handler_type', test_cases
)
def test_handler_creation(
factory_params,
platform: ProviderType,
issue_type: str,
expected_context_type: Type,
expected_handler_type: Type,
expected_context_type: type,
expected_handler_type: type,
):
factory = IssueHandlerFactory(
**factory_params,
platform=platform,
issue_type=issue_type
**factory_params, platform=platform, issue_type=issue_type
)

handler = factory.create()
@@ -66,11 +60,10 @@ def test_handler_creation(
assert isinstance(handler, expected_context_type)
assert isinstance(handler._strategy, expected_handler_type)

def test_invalid_issue_type(factory_params):
factory = IssueHandlerFactory(
**factory_params,
platform=ProviderType.GITHUB,
issue_type='invalid'
**factory_params, platform=ProviderType.GITHUB, issue_type='invalid'
)

with pytest.raises(ValueError, match='Invalid issue type: invalid'):

@@ -2,7 +2,7 @@ from unittest import mock

import pytest

from openhands.core.config import SandboxConfig,OpenHandsConfig
from openhands.core.config import OpenHandsConfig, SandboxConfig
from openhands.events.action import CmdRunAction
from openhands.resolver.issue_resolver import IssueResolver

@@ -36,7 +36,8 @@ def test_setup_sandbox_config_default():
)

assert_sandbox_config(
openhands_config.sandbox, runtime_container_image='ghcr.io/all-hands-ai/runtime:mock-nikolaik'
openhands_config.sandbox,
runtime_container_image='ghcr.io/all-hands-ai/runtime:mock-nikolaik',
)

@@ -68,7 +69,9 @@ def test_setup_sandbox_config_base_only():
)

assert_sandbox_config(
openhands_config.sandbox, base_container_image=base_image, runtime_container_image=None
openhands_config.sandbox,
base_container_image=base_image,
runtime_container_image=None,
)

@@ -84,7 +87,9 @@ def test_setup_sandbox_config_runtime_only():
is_experimental=False,
)

assert_sandbox_config(openhands_config.sandbox, runtime_container_image=runtime_image)
assert_sandbox_config(
openhands_config.sandbox, runtime_container_image=runtime_image
)

def test_setup_sandbox_config_experimental():
@@ -117,7 +122,9 @@ def test_setup_sandbox_config_gitlab_ci(mock_get_unique_uid, mock_getuid):
is_experimental=False,
)

assert_sandbox_config(openhands_config.sandbox, local_runtime_url='http://localhost')
assert_sandbox_config(
openhands_config.sandbox, local_runtime_url='http://localhost'
)

@mock.patch('openhands.resolver.issue_resolver.os.getuid', return_value=1000)
@@ -134,7 +141,9 @@ def test_setup_sandbox_config_gitlab_ci_non_root(mock_getuid):
is_experimental=False,
)

assert_sandbox_config(openhands_config.sandbox, local_runtime_url='http://localhost')
assert_sandbox_config(
openhands_config.sandbox, local_runtime_url='http://localhost'
)

@mock.patch('openhands.events.observation.CmdOutputObservation')