mirror of https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-08 22:38:05 -05:00
Lint all files in the repo (#9131)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
2 .github/workflows/lint-fix.yml vendored
@@ -74,7 +74,7 @@ jobs:
       - name: Fix python lint issues
         run: |
           # Run all pre-commit hooks and continue even if they modify files (exit code 1)
-          pre-commit run --config ./dev_config/python/.pre-commit-config.yaml --files openhands/**/* evaluation/**/* tests/**/* || true
+          pre-commit run --config ./dev_config/python/.pre-commit-config.yaml --all-files || true

       # Commit and push changes if any
       - name: Check for changes
2 .github/workflows/lint.yml vendored
@@ -53,7 +53,7 @@ jobs:
       - name: Install pre-commit
         run: pip install pre-commit==3.7.0
       - name: Run pre-commit hooks
-        run: pre-commit run --files openhands/**/* evaluation/**/* tests/**/* --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
+        run: pre-commit run --all-files --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml

   # Check version consistency across documentation
   check-version-consistency:
1 .github/workflows/py-unit-tests.yml vendored
@@ -81,4 +81,3 @@ jobs:
         env:
           TEST_RUNTIME: local
           DEBUG: "1"
-
2 Makefile
@@ -189,7 +189,7 @@ install-pre-commit-hooks:

 lint-backend:
 	@echo "$(YELLOW)Running linters...$(RESET)"
-	@poetry run pre-commit run --files openhands/**/* evaluation/**/* tests/**/* --show-diff-on-failure --config $(PRE_COMMIT_CONFIG_PATH)
+	@poetry run pre-commit run --all-files --show-diff-on-failure --config $(PRE_COMMIT_CONFIG_PATH)

 lint-frontend:
 	@echo "$(YELLOW)Running linters for frontend...$(RESET)"
@@ -43,7 +43,7 @@ from openhands.core.config import (
     AgentConfig,
     OpenHandsConfig,
     get_llm_config_arg,
-    get_parser
+    get_parser,
 )
 from openhands.core.config.condenser_config import NoOpCondenserConfig
 from openhands.core.config.utils import get_condenser_config_arg
@@ -92,10 +92,12 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata) -> MessageActio
         elif 'gpt-4.1' in llm_model:
             template_name = 'swe_gpt4.j2'
         else:
-            template_name = 'swe_default.j2' # Default for 'swe' mode (regular swe-bench)
+            template_name = (
+                'swe_default.j2'  # Default for 'swe' mode (regular swe-bench)
+            )
     else:
         # Fallback or error handling if mode is unexpected
-        logger.error(f"Unexpected evaluation mode: {mode}. Falling back to default.")
+        logger.error(f'Unexpected evaluation mode: {mode}. Falling back to default.')
         template_name = 'swe_default.j2'

     # Set up Jinja2 environment
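The hunk above ends on the '# Set up Jinja2 environment' context line; the rendering step itself is not shown in this commit. A minimal sketch of the usual Jinja2 pattern — the loader path and render arguments here are assumptions, not taken from the diff:

from jinja2 import Environment, FileSystemLoader

# Sketch only: the template directory and render kwargs are illustrative.
env = Environment(loader=FileSystemLoader('prompts'))
template = env.get_template(template_name)  # e.g. 'swe_default.j2'
instruction_text = template.render(instance=instance)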
@@ -100,4 +100,3 @@ This project is used to evaluate the performance of the model on VersiCode. It i
 # Contributor

 [Tongtong Wu](https://scholar.google.com/citations?hl=zh-CN&user=u1Qp8lUAAAAJ&view_op=list_works&sortby=pubdate), [Weigang Wu](https://scholar.google.com/citations?hl=zh-CN&user=UneIZo8AAAAJ), [Xingyu Wang](https://scholar.google.com/citations?hl=zh-CN&user=wqPJcxcAAAAJ), [Kang Xu](https://scholar.google.com/citations?hl=zh-CN&user=N1UUDi0AAAAJ), [Suyu Ma](https://scholar.google.com/citations?hl=zh-CN&user=NJHR1ukAAAAJ), [Bo Jiang](https://wutong8023.site/VersiCode/), [Ping Yang](https://scholar.google.com/citations?view_op=list_works&hl=en&hl=en&user=hrogvxoAAAAJ), [Zhenchang Xing](https://scholar.google.com/citations?hl=zh-CN&user=0vCxuH4AAAAJ), [Yuan-Fang Li](https://scholar.google.com/citations?hl=zh-CN&user=wufXO1kAAAAJ), [Gholamreza Haffari](https://scholar.google.com/citations?hl=zh-CN&user=Perjx5EAAAAJ)
-
@@ -1,19 +1,22 @@
 """
 GPT performs line level generation prediction and truncates overly long tokens
 """
-import json
-import openai
-from openai import OpenAI
-import os
-import tiktoken
-max_tokens = 127000 #gpt3.5 is 16ktoken gpt4o is 128k
-model_name = ""
-
-os.environ["OPENAI_API_KEY"] = ""
+import json
+import os
+
+import tiktoken
+from openai import OpenAI
+
+max_tokens = 127000  # gpt3.5 is 16ktoken gpt4o is 128k
+model_name = ''
+
+os.environ['OPENAI_API_KEY'] = ''
 client = OpenAI()


 def truncate_text(text, max_tokens):
-    encoding = tiktoken.get_encoding("cl100k_base")
+    encoding = tiktoken.get_encoding('cl100k_base')
     disallowed_special = ()

     tokens = encoding.encode(text, disallowed_special=disallowed_special)
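The body of truncate_text between these two hunks is elided by the diff view. From the encode call above and the `return truncated_text` below, it presumably slices the token list and decodes it back; a minimal sketch under that assumption:

import tiktoken

def truncate_text(text, max_tokens):
    # The encode call mirrors the lines shown in the hunks; the
    # slicing/decoding middle is an assumed reconstruction, not quoted code.
    encoding = tiktoken.get_encoding('cl100k_base')
    tokens = encoding.encode(text, disallowed_special=())
    truncated_text = encoding.decode(tokens[:max_tokens])
    return truncated_text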
@@ -26,15 +29,11 @@ def truncate_text(text, max_tokens):

     return truncated_text


 def predict(content, model_name):
     response = client.chat.completions.create(
         model=model_name,
-        messages=[
-            {
-                "role": "user",
-                "content": content
-            }
-        ],
+        messages=[{'role': 'user', 'content': content}],
         frequency_penalty=0.1,
         max_tokens=128,
         logit_bias=None,
@@ -45,7 +44,7 @@ def predict(content, model_name):
         stop=None,
         stream=False,
         temperature=0.8,
-        top_p=0.95
+        top_p=0.95,
     )
     ans_list = []
     choices_list = response.choices
@@ -55,6 +54,7 @@ def predict(content, model_name):
     final_ans = str(ans_list)
     return final_ans


 def bulid_prompt(description, old_version, old_code, new_version) -> str:
     """
     build prompt
@@ -93,11 +93,13 @@ data_list = data_dict


 for data in data_list:
-    if "model_output" in data:
-        print(f"the {data_list.index(data) + 1} has already been predicted, skipping this data!")
+    if 'model_output' in data:
+        print(
+            f'the {data_list.index(data) + 1} has already been predicted, skipping this data!'
+        )
         continue
     try:
-        print(f"Predicting {data_list.index(data) + 1} ")
+        print(f'Predicting {data_list.index(data) + 1} ')
         old_version = data['dependency'] + data['old_version']  # package == x.x.x
         new_version = data['dependency'] + data['new_version']  # package == x.x.x
         description = data['description']  # feature description
@@ -109,9 +111,11 @@ for data in data_list:

         data['model_output'] = prediction
     except Exception as e:
-        print(f"error:{e}")
-        print("save current data")
-        save_folder_path = os.path.join('../data/result_data/code_migration', model_name)
+        print(f'error:{e}')
+        print('save current data')
+        save_folder_path = os.path.join(
+            '../data/result_data/code_migration', model_name
+        )
         if not os.path.exists(save_folder_path):
             os.makedirs(save_folder_path)
         save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])
@@ -121,7 +125,6 @@ for data in data_list:
         break

-
 save_folder_path = os.path.join('../data/result_data/code_migration', model_name)
 if not os.path.exists(save_folder_path):
     os.makedirs(save_folder_path)
@@ -129,6 +132,3 @@ save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])

 with open(save_json_path, 'w', encoding='utf-8') as fw:
     json.dump(data_dict, fw, indent=4, ensure_ascii=False)
-
-
-
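The except branch above persists partial results with an exists()/makedirs() pair before dumping JSON. A more idiomatic equivalent — a sketch, not part of this commit — folds that into os.makedirs with exist_ok:

import json
import os

def save_partial_results(data, save_folder_path, json_name):
    # exist_ok=True replaces the explicit os.path.exists() check.
    os.makedirs(save_folder_path, exist_ok=True)
    save_json_path = os.path.join(save_folder_path, json_name)
    with open(save_json_path, 'w', encoding='utf-8') as fw:
        json.dump(data, fw, indent=4, ensure_ascii=False)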
@@ -1,19 +1,22 @@
 """
 GPT performs line level generation prediction and truncates overly long tokens
 """
-import json
-import openai
-from openai import OpenAI
-import os
-import tiktoken
-max_tokens = 127000 #gpt3.5 is 16ktoken gpt4o is 128k
-model_name = ""
-
-os.environ["OPENAI_API_KEY"] = ""
+import json
+import os
+
+import tiktoken
+from openai import OpenAI
+
+max_tokens = 127000  # gpt3.5 is 16ktoken gpt4o is 128k
+model_name = ''
+
+os.environ['OPENAI_API_KEY'] = ''
 client = OpenAI()


 def truncate_text(text, max_tokens):
-    encoding = tiktoken.get_encoding("cl100k_base")
+    encoding = tiktoken.get_encoding('cl100k_base')
     disallowed_special = ()

     tokens = encoding.encode(text, disallowed_special=disallowed_special)
@@ -26,15 +29,11 @@ def truncate_text(text, max_tokens):

     return truncated_text


 def predict(content, model_name):
     response = client.chat.completions.create(
         model=model_name,
-        messages=[
-            {
-                "role": "user",
-                "content": content
-            }
-        ],
+        messages=[{'role': 'user', 'content': content}],
         frequency_penalty=0.1,
         max_tokens=128,
         logit_bias=None,
@@ -45,7 +44,7 @@ def predict(content, model_name):
         stop=None,
         stream=False,
         temperature=0.8,
-        top_p=0.95
+        top_p=0.95,
     )
     ans_list = []
     choices_list = response.choices
@@ -55,6 +54,7 @@ def predict(content, model_name):
     final_ans = str(ans_list)
     return final_ans


 def bulid_prompt(version, description) -> str:
     """
     build prompt
@@ -64,7 +64,7 @@ def bulid_prompt(version, description) -> str:
     :param options:
     :return:
     """
-    prompt = f'''
+    prompt = f"""
     You are a professional Python engineer, and I will provide functional descriptions and versions of specified dependency packages.
     You need to write code in Python to implement this feature based on the functional description and using the dependency package and version I specified.
     Please note that you only need to return the code that implements the function, and do not return any other content.
@@ -88,7 +88,7 @@ def bulid_prompt(version, description) -> str:
     ###response:


-    '''
+    """
     return prompt
@@ -102,11 +102,13 @@ data_list = data_dict


 for data in data_list:
-    if "model_output" in data:
-        print(f"the {data_list.index(data) + 1} has already been predicted, skipping this data!")
+    if 'model_output' in data:
+        print(
+            f'the {data_list.index(data) + 1} has already been predicted, skipping this data!'
+        )
         continue
     try:
-        print(f"Predicting {data_list.index(data) + 1} ")
+        print(f'Predicting {data_list.index(data) + 1} ')
         version = data['dependency'] + data['version']  # package == x.x.x
         description = data['description']  # func description
@@ -116,9 +118,11 @@ for data in data_list:

         data['model_output'] = prediction
     except Exception as e:
-        print(f"error:{e}")
-        print("save current data")
-        save_folder_path = os.path.join('../data/result_data/block_completion', model_name)
+        print(f'error:{e}')
+        print('save current data')
+        save_folder_path = os.path.join(
+            '../data/result_data/block_completion', model_name
+        )
         if not os.path.exists(save_folder_path):
             os.makedirs(save_folder_path)
         save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])
@@ -128,7 +132,6 @@ for data in data_list:
         break

-
 save_folder_path = os.path.join('../data/result_data/block_completion', model_name)
 if not os.path.exists(save_folder_path):
     os.makedirs(save_folder_path)
@@ -136,6 +139,3 @@ save_json_path = os.path.join(save_folder_path, json_path.split('/')[-1])

 with open(save_json_path, 'w', encoding='utf-8') as fw:
     json.dump(data_dict, fw, indent=4, ensure_ascii=False)
-
-
-
@@ -1,20 +1,23 @@
 """
 block completion
 """

 import copy
 import gc
 import json
 import os
-from vllm import LLM, SamplingParams
-import tiktoken
 import time
-import gc
-import torch
 from multiprocessing import Process
+
+import tiktoken
+import torch
+from vllm import LLM, SamplingParams

 # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"


 def truncate_text(text, max_tokens):
-    encoding = tiktoken.get_encoding("cl100k_base")
+    encoding = tiktoken.get_encoding('cl100k_base')
     disallowed_special = ()

     tokens = encoding.encode(text, disallowed_special=disallowed_special)
@@ -27,8 +30,10 @@ def truncate_text(text, max_tokens):

     return truncated_text


 model_list = ['/data2/base models/starcoder2-15b', '/data2/base models/CodeGemma-7B']


 def run_inference(model_name, origin_data_list):
     temp_data_list = copy.deepcopy(origin_data_list)
     test_list = []
@@ -40,7 +45,12 @@ def run_inference(model_name, origin_data_list):
         test_list.append(instruction)

     sampling_params = SamplingParams(n=6, temperature=0.8, top_p=0.95, max_tokens=64)
-    llm = LLM(model=model_name, tensor_parallel_size=4, gpu_memory_utilization=0.9, swap_space=20)
+    llm = LLM(
+        model=model_name,
+        tensor_parallel_size=4,
+        gpu_memory_utilization=0.9,
+        swap_space=20,
+    )

     outputs = llm.generate(test_list, sampling_params)
     for output in outputs:
@@ -53,7 +63,9 @@ def run_inference(model_name, origin_data_list):

         temp_data_list[requests_id]['model_output'] = str(temp_ans_list)

-    save_folder_path = os.path.join('../data/result_data/block_completion', model_name.split('/')[-1])
+    save_folder_path = os.path.join(
+        '../data/result_data/block_completion', model_name.split('/')[-1]
+    )
     if not os.path.exists(save_folder_path):
         os.makedirs(save_folder_path)
@@ -75,7 +87,7 @@ def bulid_prompt(version, description) -> str:
     :param options:
     :return:
     """
-    prompt = f'''
+    prompt = f"""
     You are a professional Python engineer, and I will provide functional descriptions and versions of specified dependency packages.
     You need to write code in Python to implement this feature based on the functional description and using the dependency package and version I specified.
     Please note that you only need to return the code that implements the function, and do not return any other content.
@@ -99,7 +111,7 @@ def bulid_prompt(version, description) -> str:
     ###response:


-    '''
+    """
     return prompt
@@ -115,4 +127,3 @@ for model_name in model_list:
     process.start()
     process.join()
     time.sleep(120)
-
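The @@ -115 hunk above shows the tail of this file's driver loop: each model runs in its own process so vLLM's GPU memory is fully released before the next model loads. Reassembled from the lines shown (loading of origin_data_list is elided in the diff):

import time
from multiprocessing import Process

for model_name in model_list:
    # One process per model: CUDA memory is reclaimed when the process exits.
    process = Process(target=run_inference, args=(model_name, origin_data_list))
    process.start()
    process.join()
    time.sleep(120)  # let the GPUs settle before loading the next model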
@@ -1,20 +1,23 @@
 """
 code migration
 """

 import copy
 import gc
 import json
 import os
-from vllm import LLM, SamplingParams
-import tiktoken
 import time
-import gc
-import torch
 from multiprocessing import Process
+
+import tiktoken
+import torch
+from vllm import LLM, SamplingParams

 # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"


 def truncate_text(text, max_tokens):
-    encoding = tiktoken.get_encoding("cl100k_base")
+    encoding = tiktoken.get_encoding('cl100k_base')
     disallowed_special = ()

     tokens = encoding.encode(text, disallowed_special=disallowed_special)
@@ -27,8 +30,10 @@ def truncate_text(text, max_tokens):

     return truncated_text


 model_list = ['/data2/base models/starcoder2-15b', '/data2/base models/CodeGemma-7B']


 def run_inference(model_name, origin_data_list):
     temp_data_list = copy.deepcopy(origin_data_list)
     test_list = []
@@ -42,7 +47,12 @@ def run_inference(model_name, origin_data_list):
         test_list.append(instruction)

     sampling_params = SamplingParams(n=6, temperature=0.8, top_p=0.95, max_tokens=512)
-    llm = LLM(model=model_name, tensor_parallel_size=4, gpu_memory_utilization=0.6, swap_space=40)
+    llm = LLM(
+        model=model_name,
+        tensor_parallel_size=4,
+        gpu_memory_utilization=0.6,
+        swap_space=40,
+    )

     outputs = llm.generate(test_list, sampling_params)
     for output in outputs:
@@ -55,7 +65,9 @@ def run_inference(model_name, origin_data_list):

         temp_data_list[requests_id]['model_output'] = str(temp_ans_list)

-    save_folder_path = os.path.join('../data/result_data/code_migration', model_name.split('/')[-1])
+    save_folder_path = os.path.join(
+        '../data/result_data/code_migration', model_name.split('/')[-1]
+    )
     if not os.path.exists(save_folder_path):
         os.makedirs(save_folder_path)
@@ -108,4 +120,3 @@ for model_name in model_list:
     process.start()
     process.join()
     time.sleep(120)
-
@@ -4,20 +4,20 @@
 2. Check whether the code is valid
 3. Compute ISM and PM
 """
-import json
-import tokenize

 import io
+import json
 import math
-import ast
-import re
 import os
+import re
+import tokenize


 def is_code_valid(code):

     try:
         compile(code, '<string>', 'exec')
         return True
-    except:
+    except Exception:
         return False
@@ -44,6 +44,7 @@ def longest_common_prefix_between_lists_with_elements(list1, list2):
             max_prefix_elements = (str1, str2)
     return max_prefix_length, max_prefix_elements


 def get_token(ans_code: str, output_code: str):
     """
     Lexically analyze the code, split it into identifiers, and return two identifier lists
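Only the tail of longest_common_prefix_between_lists_with_elements is visible in the @@ -44 hunk above. From that tail it evidently scans string pairs across the two identifier lists and keeps the pair with the longest common prefix; a plausible reconstruction — an assumption, not the committed code:

def longest_common_prefix_between_lists_with_elements(list1, list2):
    max_prefix_length = 0
    max_prefix_elements = ('', '')
    for str1 in list1:
        for str2 in list2:
            # Count matching leading characters of this pair.
            prefix_length = 0
            for ch1, ch2 in zip(str1, str2):
                if ch1 != ch2:
                    break
                prefix_length += 1
            if prefix_length > max_prefix_length:
                max_prefix_length = prefix_length
                max_prefix_elements = (str1, str2)
    return max_prefix_length, max_prefix_elements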
@@ -55,40 +56,40 @@ def get_token(ans_code:str, output_code:str):
     ans_flag = True
     try:
         tokens_ans = tokenize.tokenize(io.BytesIO(ans_code.encode('utf-8')).readline)
-    except Exception as e:
+    except Exception:
         tokens_ans = ans_code.splitlines()
         ans_flag = False

     try:
-        tokens_output = tokenize.tokenize(io.BytesIO(output_code.encode('utf-8')).readline)
-    except Exception as e:
+        tokens_output = tokenize.tokenize(
+            io.BytesIO(output_code.encode('utf-8')).readline
+        )
+    except Exception:
         tokens_output = output_code.splitlines()
         output_flag = False


     identifiers_ans = []
     identifiers_output = []
-    if ans_flag == True:
+    if ans_flag:
         try:
             for token in tokens_ans:
                 if token.type == tokenize.NAME:
                     identifiers_ans.append(token.string)
-        except Exception as e:
+        except Exception:
             identifiers_ans = tokens_ans
     else:
         identifiers_ans = tokens_ans

-    if output_flag == True:
+    if output_flag:
         try:
             for to in tokens_output:
                 if to.type == tokenize.NAME:
                     identifiers_output.append(to.string)
-        except Exception as e:
+        except Exception:
             identifiers_output = tokens_output
     else:
         identifiers_output = tokens_output


     return identifiers_ans, identifiers_output
@@ -108,14 +109,13 @@ def get_token_per_line(code: str):
             for token in tokens:
                 if token.type == tokenize.NAME:
                     identifiers.append(token.string)
-        except:
+        except Exception:
             identifiers = line.split(' ')
         identifiers_per_line.append(identifiers)

     return identifiers_per_line


-
 def get_ISM(answer_code: str, model_output_list: list, asnwer_name: str) -> list:
     """
     Compute ISM and return a sorted score list
@@ -126,7 +126,9 @@ def get_ISM(answer_code:str, model_output_list:list, asnwer_name:str)->list:
         if '```python' in code:
             code = code.replace('```python', '')
             code = code.replace('```', '')
-        if not re.search(rf'\b{re.escape(asnwer_name)}\b', code) or is_code_valid(code) == False:
+        if not re.search(rf'\b{re.escape(asnwer_name)}\b', code) or not is_code_valid(
+            code
+        ):
             score_list.append(0)
             continue
@@ -135,7 +137,9 @@ def get_ISM(answer_code:str, model_output_list:list, asnwer_name:str)->list:
         #     continue

         identifiers_ans, identifiers_output = get_token(answer_code, code)
-        max_len, elements = longest_common_prefix_between_lists_with_elements(identifiers_ans, identifiers_output)
+        max_len, elements = longest_common_prefix_between_lists_with_elements(
+            identifiers_ans, identifiers_output
+        )
         if max_len != 0:
             base_element_len = max(len(elements[0]), len(elements[1]))
             temp_score = max_len / base_element_len
@@ -149,14 +153,16 @@ def get_ISM(answer_code:str, model_output_list:list, asnwer_name:str)->list:
     score_list = sorted(score_list, reverse=True)
     return score_list

-def get_ISM_without_verification(answer_code:str, model_output_list:list, asnwer_name:str)->list:
+
+def get_ISM_without_verification(
+    answer_code: str, model_output_list: list, asnwer_name: str
+) -> list:
     """
     Compute ISM and return a sorted score list
     :return:
     """
     score_list = []
     for code in model_output_list:

         if asnwer_name not in code:
             score_list.append(0)
             continue
@@ -166,7 +172,9 @@ def get_ISM_without_verification(answer_code:str, model_output_list:list, asnwer
         #     continue

         identifiers_ans, identifiers_output = get_token(answer_code, code)
-        max_len, elements = longest_common_prefix_between_lists_with_elements(identifiers_ans, identifiers_output)
+        max_len, elements = longest_common_prefix_between_lists_with_elements(
+            identifiers_ans, identifiers_output
+        )
         if max_len != 0:
             base_element_len = max(len(elements[0]), len(elements[1]))
             temp_score = max_len / base_element_len
@@ -180,6 +188,7 @@ def get_ISM_without_verification(answer_code:str, model_output_list:list, asnwer
     score_list = sorted(score_list, reverse=True)
     return score_list


 def longest_common_prefix_with_lengths(list1, list2):
     """
     Compute the longest prefix match length for each sublist of two 2-D lists, and record the lengths of the two sublists with that longest match
@@ -216,8 +225,9 @@ def get_PM(answer_code:str, model_output_list:list, asnwer_name:str)->list:
         if '```python' in code:
             code = code.replace('```python', '')
             code = code.replace('```', '')
-        if not re.search(rf'\b{re.escape(asnwer_name)}\b', code) or is_code_valid(code) == False:
+
+        if not re.search(rf'\b{re.escape(asnwer_name)}\b', code) or not is_code_valid(
+            code
+        ):
             # if asnwer_name not in code or is_code_valid(code) == False:
             score_list.append(0)
             continue
@@ -228,7 +238,9 @@ def get_PM(answer_code:str, model_output_list:list, asnwer_name:str)->list:

         ans_list = get_token_per_line(answer_code)
         output_token_list = get_token_per_line(code)
-        max_len, len1, len2 = longest_common_prefix_with_lengths(ans_list, output_token_list)
+        max_len, len1, len2 = longest_common_prefix_with_lengths(
+            ans_list, output_token_list
+        )
         base_element_len = max(len1, len2)

         if base_element_len != 0:
@@ -240,6 +252,7 @@ def get_PM(answer_code:str, model_output_list:list, asnwer_name:str)->list:
     score_list = sorted(score_list, reverse=True)
     return score_list


 def get_score(score_list: list, k):
     """
     Compute score@n,k
@@ -260,7 +273,7 @@ def get_score(score_list:list, k):

 k = 1
 task = 'block'  # block or line
-json_name = f"Versicode_{task}_completion.json"
+json_name = f'Versicode_{task}_completion.json'

 folder_path = f'../data/result_data/{task}_completion'
 model_list = os.listdir(folder_path)
@@ -309,7 +322,6 @@ for model in model_list:
         # if flag == 1:
         #     continue

-
     ISM_score = get_score(ISM_score_list, k)
     PM_score = get_score(PM_score_list, k)

@@ -318,9 +330,8 @@ for model in model_list:
     # print(f"ISM score: {ISM_score}")
     # print(f"PM score: {PM_score}")

-    print(f"{model}, {task} completion task, ISM@{k} score: {sum_ISM/data_len}")
-    print(f"{model}, {task} completion task, PM@{k} score: {sum_PM/data_len}")
-
+    print(f'{model}, {task} completion task, ISM@{k} score: {sum_ISM / data_len}')
+    print(f'{model}, {task} completion task, PM@{k} score: {sum_PM / data_len}')


 # def get_token(ans_code:str, output_code:str):
@@ -1,13 +1,15 @@
 """
 Calculate the cdc score for migration
 """
-import os
-
 import json
 import math
+import os
 import re
 import warnings

 # warnings.filterwarnings("ignore", category=SyntaxWarning)


 def is_correct_parameter_count(function_name, correct_code, test_code):
     """
     Check whether the number of parameters matches
@@ -39,6 +41,7 @@ def is_correct_parameter_count(function_name, correct_code, test_code):
     # If there are no parentheses, check whether the function name appears in the string
     return expected_count == 0 and function_name in test_code


 def check_keyword_parameters(function_name, correct_code, test_code):
     """
     Check whether keyword-argument assignment is used correctly
@@ -67,13 +70,17 @@ def check_keyword_parameters(function_name, correct_code, test_code):
         for correct_param in correct_param_list:
             if '=' in correct_param:  # only when the correct code uses keyword arguments
                 param_name = correct_param.split('=')[0].strip()
-                if not any(param_name in test_param and '=' in test_param for test_param in test_param_list):
+                if not any(
+                    param_name in test_param and '=' in test_param
+                    for test_param in test_param_list
+                ):
                     return False  # return False if the matching parameter is not a keyword argument

         return True  # all keyword arguments match

     return False  # no match found, return False


 def with_correct(answer_code: str, model_output: str) -> bool:
     """
     When the answer is a with statement, check whether the model output is also a with statement
@@ -89,14 +96,32 @@ def with_correct(answer_code:str, model_output:str)->bool:
     else:
         return False

-def compute_block_score_k(answer:str, model_output:list, k:int, model_filled_code, core_line_in_core_block, core_line_in_output_clear):
+
+def compute_block_score_k(
+    answer: str,
+    model_output: list,
+    k: int,
+    model_filled_code,
+    core_line_in_core_block,
+    core_line_in_output_clear,
+):
     """
     cdc must satisfy five conditions; em only needs the first one
     """
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
-        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(model_filled_code[index]) and is_correct_parameter_count(answer, core_line_in_core_block, core_line_in_output_clear[index]) and with_correct(core_line_in_core_block, core_line_in_output_clear[index]) and check_keyword_parameters(answer, core_line_in_core_block, core_line_in_output_clear[index]):#block
+        if (
+            re.search(rf'\b{re.escape(answer)}\b', code)
+            and is_code_valid(model_filled_code[index])
+            and is_correct_parameter_count(
+                answer, core_line_in_core_block, core_line_in_output_clear[index]
+            )
+            and with_correct(core_line_in_core_block, core_line_in_output_clear[index])
+            and check_keyword_parameters(
+                answer, core_line_in_core_block, core_line_in_output_clear[index]
+            )
+        ):  # block
             # if re.search(rf'\b{re.escape(answer)}\b', code):#block
             c += 1
     if n - c < k:
@@ -108,15 +133,14 @@ def compute_block_score_k(answer:str, model_output:list, k:int, model_filled_cod


 def is_code_valid(code):

     try:
         compile(code, '<string>', 'exec')
         return True
-    except:
+    except Exception:
         return False

-def compute_score_k(answer:str, model_output:list, k:int):
+
+def compute_score_k(answer: str, model_output: list, k: int):
     c = 0
     n = len(model_output)
     for output in model_output:
@@ -125,7 +149,7 @@ def compute_score_k(answer:str, model_output:list, k:int):
         output = output.replace('```', '')
         # if answer == output:

-        if re.search(rf'\b{re.escape(answer)}\b', output) and is_code_valid(output) == True:
+        if re.search(rf'\b{re.escape(answer)}\b', output) and is_code_valid(output):
             c += 1
     if n - c < k:
         return 1.0
@@ -134,10 +158,11 @@ def compute_score_k(answer:str, model_output:list, k:int):

     return score


 k = 1  # cdc@k
 json_name = 'VersiCode_migration.json'
 task = 'migration'
-folder_path = f'../data/result_data/code_migration'
+folder_path = '../data/result_data/code_migration'

 model_list = os.listdir(folder_path)
 for model in model_list:
@@ -151,15 +176,23 @@ for model in model_list:
     score_list = []
     for data in data_list:
         answer = data['new_name']  # old -> new
-        model_output = data[f'model_output_clear']# old -> new
+        model_output = data['model_output_clear']  # old -> new

         model_filled_code = model_output
         # core_line_in_core_block = data['core_line_in_new_core_block']# old -> new
         core_line_in_core_block = data['core_line_in_code']  # old -> new
         core_line_in_output_clear = data['core_line_in_output_clear']  # old -> new


-        score_list.append(compute_block_score_k(answer, model_output, k, model_filled_code, core_line_in_core_block, core_line_in_output_clear))
+        score_list.append(
+            compute_block_score_k(
+                answer,
+                model_output,
+                k,
+                model_filled_code,
+                core_line_in_core_block,
+                core_line_in_output_clear,
+            )
+        )

     final_score = sum(score_list) / len(score_list)
-    print(f"{model}, {task} task, cdc@{k} score: {final_score}")
+    print(f'{model}, {task} task, cdc@{k} score: {final_score}')
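The final score computation in compute_block_score_k and compute_score_k is elided by these hunks. Given the `if n - c < k: return 1.0` guard visible above, the missing line presumably follows the standard unbiased pass@k estimator; a sketch of that formula (an assumption, not quoted from the commit):

import math

def pass_at_k(n, c, k):
    # n samples, c of them correct; probability that at least one of k draws
    # (without replacement) is correct.
    if n - c < k:
        return 1.0
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)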
@@ -1,21 +1,23 @@
 """
 Calculate the cdc score for line and block
 """
-import os
-
 import json
 import math
+import os
 import re
 import warnings

 # warnings.filterwarnings("ignore", category=SyntaxWarning)

-def is_code_valid(code):

+
+def is_code_valid(code):
     try:
         compile(code, '<string>', 'exec')
         return True
-    except:
+    except Exception:
         return False


 def is_correct_parameter_count(function_name, correct_code, test_code):
     """
     Check whether the number of parameters matches
@@ -47,6 +49,7 @@ def is_correct_parameter_count(function_name, correct_code, test_code):
     # If there are no parentheses, check whether the function name appears in the string
     return expected_count == 0 and function_name in test_code


 def check_keyword_parameters(function_name, correct_code, test_code):
     """
     Check whether keyword-argument assignment is used correctly
@@ -75,13 +78,17 @@ def check_keyword_parameters(function_name, correct_code, test_code):
         for correct_param in correct_param_list:
             if '=' in correct_param:  # only when the correct code uses keyword arguments
                 param_name = correct_param.split('=')[0].strip()
-                if not any(param_name in test_param and '=' in test_param for test_param in test_param_list):
+                if not any(
+                    param_name in test_param and '=' in test_param
+                    for test_param in test_param_list
+                ):
                     return False  # return False if the matching parameter is not a keyword argument

         return True  # all keyword arguments match

     return False  # no match found, return False


 def with_correct(answer_code: str, model_output: str) -> bool:
     """
     When the answer is a with statement, check whether the model output is also a with statement
@@ -97,12 +104,20 @@ def with_correct(answer_code:str, model_output:str)->bool:
     else:
         return False

-def compute_line_score_k(answer:str, model_output:list, k:int, model_filled_code, core_line):
+
+def compute_line_score_k(
+    answer: str, model_output: list, k: int, model_filled_code, core_line
+):
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
-        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(model_filled_code[index]) == True and is_correct_parameter_count(answer, core_line, code) and with_correct(core_line, code) and check_keyword_parameters(answer, core_line, code):#line
+        if (
+            re.search(rf'\b{re.escape(answer)}\b', code)
+            and is_code_valid(model_filled_code[index])
+            and is_correct_parameter_count(answer, core_line, code)
+            and with_correct(core_line, code)
+            and check_keyword_parameters(answer, core_line, code)
+        ):  # line
             c += 1
     if n - c < k:
         return 1.0
@@ -111,12 +126,29 @@ def compute_line_score_k(answer:str, model_output:list, k:int, model_filled_code

     return score

-def compute_block_score_k(answer:str, model_output:list, k:int, model_filled_code, core_line_in_core_block, core_line_in_output_clear):
+
+def compute_block_score_k(
+    answer: str,
+    model_output: list,
+    k: int,
+    model_filled_code,
+    core_line_in_core_block,
+    core_line_in_output_clear,
+):
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
-        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(model_filled_code[index]) and is_correct_parameter_count(answer, core_line_in_core_block, core_line_in_output_clear[index]) and with_correct(core_line_in_core_block, core_line_in_output_clear[index]) and check_keyword_parameters(answer, core_line_in_core_block, core_line_in_output_clear[index]):#block
+        if (
+            re.search(rf'\b{re.escape(answer)}\b', code)
+            and is_code_valid(model_filled_code[index])
+            and is_correct_parameter_count(
+                answer, core_line_in_core_block, core_line_in_output_clear[index]
+            )
+            and with_correct(core_line_in_core_block, core_line_in_output_clear[index])
+            and check_keyword_parameters(
+                answer, core_line_in_core_block, core_line_in_output_clear[index]
+            )
+        ):  # block
             c += 1
     if n - c < k:
         return 1.0
@@ -125,12 +157,14 @@ def compute_block_score_k(answer:str, model_output:list, k:int, model_filled_cod

     return score


 def compute_score_k(answer: str, model_output: list, k: int):

     c = 0
     n = len(model_output)
     for output in model_output:
-        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(code):#block
+        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(
+            code
+        ):  # block
            # if re.search(rf'\b{re.escape(answer)}\b', code):#line
            c += 1
     if n - c < k:
@@ -140,9 +174,10 @@ def compute_score_k(answer:str, model_output:list, k:int):

     return score


 k = 3  # cdc@k
 task = 'block'  # line or block
-json_name = f"Versicode_{task}_completion.json"
+json_name = f'Versicode_{task}_completion.json'

 folder_path = f'../data/result_data/{task}_completion'
 model_list = os.listdir(folder_path)
@@ -158,9 +193,15 @@ for model in model_list:
         for data in data_list:
             answer = data['core_token']
             model_output = eval(data['model_output_clear'])
-            model_filled_code = [data['masked_code'].replace('<mask>', i) for i in model_output]
+            model_filled_code = [
+                data['masked_code'].replace('<mask>', i) for i in model_output
+            ]
             core_line = data['core_line']
-            score_list.append(compute_line_score_k(answer, model_output, k, model_filled_code, core_line))
+            score_list.append(
+                compute_line_score_k(
+                    answer, model_output, k, model_filled_code, core_line
+                )
+            )
     else:
         score_list = []
         for data in data_list:
@@ -169,7 +210,16 @@ for model in model_list:
             model_filled_code = eval(data['model_output_clear'])
             core_line = data['core_line']
             core_line_in_output_clear = data['core_line_in_output_clear']
-            score_list.append(compute_block_score_k(answer, model_output, k, model_filled_code, core_line, core_line_in_output_clear))
+            score_list.append(
+                compute_block_score_k(
+                    answer,
+                    model_output,
+                    k,
+                    model_filled_code,
+                    core_line,
+                    core_line_in_output_clear,
+                )
+            )

     final_score = sum(score_list) / len(score_list)
-    print(f"{model}, {task} completion task, cdc@{k} score: {final_score}")
+    print(f'{model}, {task} completion task, cdc@{k} score: {final_score}')
@@ -1,21 +1,23 @@
 """
 Calculate the cdc score for line and block
 """
-import os
-
 import json
 import math
+import os
 import re
 import warnings

 # warnings.filterwarnings("ignore", category=SyntaxWarning)

-def is_code_valid(code):

+
+def is_code_valid(code):
     try:
         compile(code, '<string>', 'exec')
         return True
-    except:
+    except Exception:
         return False


 def is_correct_parameter_count(function_name, correct_code, test_code):
     """
     Check whether the number of parameters matches
@@ -47,6 +49,7 @@ def is_correct_parameter_count(function_name, correct_code, test_code):
     # If there are no parentheses, check whether the function name appears in the string
     return expected_count == 0 and function_name in test_code


 def check_keyword_parameters(function_name, correct_code, test_code):
     """
     Check whether keyword-argument assignment is used correctly
@@ -75,13 +78,17 @@ def check_keyword_parameters(function_name, correct_code, test_code):
         for correct_param in correct_param_list:
             if '=' in correct_param:  # only when the correct code uses keyword arguments
                 param_name = correct_param.split('=')[0].strip()
-                if not any(param_name in test_param and '=' in test_param for test_param in test_param_list):
+                if not any(
+                    param_name in test_param and '=' in test_param
+                    for test_param in test_param_list
+                ):
                     return False  # return False if the matching parameter is not a keyword argument

         return True  # all keyword arguments match

     return False  # no match found, return False


 def with_correct(answer_code: str, model_output: str) -> bool:
     """
     When the answer is a with statement, check whether the model output is also a with statement
@@ -97,8 +104,10 @@ def with_correct(answer_code:str, model_output:str)->bool:
     else:
         return False

-def compute_line_score_k(answer:str, model_output:list, k:int, model_filled_code, core_line):
+
+def compute_line_score_k(
+    answer: str, model_output: list, k: int, model_filled_code, core_line
+):
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
@@ -111,8 +120,15 @@ def compute_line_score_k(answer:str, model_output:list, k:int, model_filled_code

     return score

-def compute_block_score_k(answer:str, model_output:list, k:int, model_filled_code, core_line_in_core_block, core_line_in_output_clear):
+
+def compute_block_score_k(
+    answer: str,
+    model_output: list,
+    k: int,
+    model_filled_code,
+    core_line_in_core_block,
+    core_line_in_output_clear,
+):
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
@@ -125,12 +141,14 @@ def compute_block_score_k(answer:str, model_output:list, k:int, model_filled_cod

     return score

-def compute_score_k(answer:str, model_output:list, k:int):
+
+def compute_score_k(answer: str, model_output: list, k: int):
     c = 0
     n = len(model_output)
     for index, code in enumerate(model_output):
-        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(code):#block
+        if re.search(rf'\b{re.escape(answer)}\b', code) and is_code_valid(
+            code
+        ):  # block
             # if re.search(rf'\b{re.escape(answer)}\b', code):#line
             c += 1
     if n - c < k:
@@ -140,9 +158,10 @@ def compute_score_k(answer:str, model_output:list, k:int):

     return score


 k = 3  # em@k
 task = 'block'  # line or block
-json_name = f"Versicode_{task}_completion.json"
+json_name = f'Versicode_{task}_completion.json'

 folder_path = f'../data/result_data/{task}_completion'
 model_list = os.listdir(folder_path)
@@ -158,9 +177,15 @@ for model in model_list:
         for data in data_list:
             answer = data['core_token']
             model_output = eval(data['model_output_clear'])
-            model_filled_code = [data['masked_code'].replace('<mask>', i) for i in model_output]
+            model_filled_code = [
+                data['masked_code'].replace('<mask>', i) for i in model_output
+            ]
             core_line = data['core_line']
-            score_list.append(compute_line_score_k(answer, model_output, k, model_filled_code, core_line))
+            score_list.append(
+                compute_line_score_k(
+                    answer, model_output, k, model_filled_code, core_line
+                )
+            )
     else:
         score_list = []
         for data in data_list:
@@ -169,7 +194,16 @@ for model in model_list:
             model_filled_code = eval(data['model_output_clear'])
             core_line = data['core_line']
             core_line_in_output_clear = data['core_line_in_output_clear']
-            score_list.append(compute_block_score_k(answer, model_output, k, model_filled_code, core_line, core_line_in_output_clear))
+            score_list.append(
+                compute_block_score_k(
+                    answer,
+                    model_output,
+                    k,
+                    model_filled_code,
+                    core_line,
+                    core_line_in_output_clear,
+                )
+            )

     final_score = sum(score_list) / len(score_list)
-    print(f"{model}, {task} completion task, em@{k} score: {final_score}")
+    print(f'{model}, {task} completion task, em@{k} score: {final_score}')
@@ -1,48 +1,43 @@
 """
 Find the line of code generated by the model using the block in the version code
 """
-import os
-import re
-
 import json
+import os
 import random
+import re


 def process_line_mask(code_snippet, core_token):
     if not core_token:
         return None, None

     replaced_lines = {}
-    lines = code_snippet.split("\n")
+    lines = code_snippet.split('\n')

     in_multi_line_comment = False

     for i, line in enumerate(lines):
         if in_multi_line_comment:
-            if ('"""' in line or "'''" in line) and not re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
+            if ('"""' in line or "'''" in line) and not re.findall(
+                r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
+            ):
                 in_multi_line_comment = False
             continue
-        elif line.strip().startswith("#"):
+        elif line.strip().startswith('#'):
             continue
         elif re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
             continue
-        elif ('"""' in line or "'''" in line) and not re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
+        elif ('"""' in line or "'''" in line) and not re.findall(
+            r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
+        ):
             in_multi_line_comment = True
             continue
         else:
             if re.search(r'\bdef\s+task_function\b', line):
                 continue

             if re.search(r'\b{}\b(?!\s*=)'.format(re.escape(core_token)), line):
                 replaced_lines.update({i: line})

     if replaced_lines:
@@ -51,7 +46,7 @@ def process_line_mask(code_snippet, core_token):
         masked_line = lines[random_line_location]
         leading_spaces = re.match(r'^\s*', masked_line).group(0)
         masked_line = masked_line.strip()
-        lines[random_line_location] = leading_spaces + "<line_mask>"
+        lines[random_line_location] = leading_spaces + '<line_mask>'

         masked_code = '\n'.join(lines)
@@ -71,11 +66,9 @@ def save_json(file_path, data):
         json.dump(data, f, ensure_ascii=False, indent=4)


-
-if __name__ == "__main__":
+if __name__ == '__main__':
     model_list = os.listdir('../data/result_data/block_completion')
     for model in model_list:
-
         input_json_file = f'../data/result_data/block_completion/{model}/VersiCode_block_completion.json'
         output_json_file = input_json_file
         data = load_json(input_json_file)
@@ -88,7 +81,7 @@ if __name__ == "__main__":
             if core_line_in_code:
                 item['core_line_in_code'] = core_line_in_code
             else:
-                item['core_line_in_code'] = "N/A"
+                item['core_line_in_code'] = 'N/A'

             model_output_clear = item['model_output_clear']
             core_line_in_output_list = []
@@ -98,10 +91,9 @@ if __name__ == "__main__":
                 if core_line_in_output:
                     core_line_in_output_list.append(core_line_in_output)
                 else:
-                    core_line_in_output_list.append("N/A")
+                    core_line_in_output_list.append('N/A')

             item['core_line_in_output_clear'] = core_line_in_output_list

         save_json(output_json_file, data)
-        print("Done!")
-
+        print('Done!')
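load_json and save_json appear only partially in the hunks above (just the json.dump tail of save_json). They are presumably thin JSON wrappers like the following sketch; the bodies are assumptions, not quoted from the commit:

import json

def load_json(file_path):
    # Assumed counterpart to the save_json tail visible above.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def save_json(file_path, data):
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)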
@@ -1,48 +1,43 @@
 """
 Find the line of code generated by the model using the block in the version code
 """
-import os
-import re
-
 import json
+import os
 import random
+import re


 def process_line_mask(code_snippet, core_token):
     if not core_token:
         return None, None

     replaced_lines = {}
-    lines = code_snippet.split("\n")
+    lines = code_snippet.split('\n')

     in_multi_line_comment = False

     for i, line in enumerate(lines):
         if in_multi_line_comment:
-            if ('"""' in line or "'''" in line) and not re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
+            if ('"""' in line or "'''" in line) and not re.findall(
+                r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
+            ):
                 in_multi_line_comment = False
             continue
-        elif line.strip().startswith("#"):
+        elif line.strip().startswith('#'):
             continue
         elif re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
             continue
-        elif ('"""' in line or "'''" in line) and not re.findall(r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line):
+        elif ('"""' in line or "'''" in line) and not re.findall(
+            r"'''(.*?)'''|\"\"\"(.*?)\"\"\"", line
+        ):
             in_multi_line_comment = True
             continue
         else:
             if re.search(r'\bdef\s+task_function\b', line):
                 continue

             if re.search(r'\b{}\b(?!\s*=)'.format(re.escape(core_token)), line):
                 replaced_lines.update({i: line})

     if replaced_lines:
@@ -51,7 +46,7 @@ def process_line_mask(code_snippet, core_token):
         masked_line = lines[random_line_location]
         leading_spaces = re.match(r'^\s*', masked_line).group(0)
         masked_line = masked_line.strip()
-        lines[random_line_location] = leading_spaces + "<line_mask>"
+        lines[random_line_location] = leading_spaces + '<line_mask>'

         masked_code = '\n'.join(lines)
@@ -71,12 +66,12 @@ def save_json(file_path, data):
         json.dump(data, f, ensure_ascii=False, indent=4)


-
-if __name__ == "__main__":
+if __name__ == '__main__':
     model_list = os.listdir('../data/result_data/code_migration')
     for model in model_list:
-
-        input_json_file = f'../data/result_data/code_migration/{model}/VersiCode_migration.json'
+        input_json_file = (
+            f'../data/result_data/code_migration/{model}/VersiCode_migration.json'
+        )
         output_json_file = input_json_file
         data = load_json(input_json_file)

@@ -88,7 +83,7 @@ if __name__ == "__main__":
             if core_line_in_code:
                 item['core_line_in_code'] = core_line_in_code
             else:
-                item['core_line_in_code'] = "N/A"
+                item['core_line_in_code'] = 'N/A'

             model_output_clear = item['model_output_clear']
             core_line_in_output_list = []
@@ -99,10 +94,9 @@ if __name__ == "__main__":
                 if core_line_in_output:
                     core_line_in_output_list.append(core_line_in_output)
                 else:
-                    core_line_in_output_list.append("N/A")
+                    core_line_in_output_list.append('N/A')

             item['core_line_in_output_clear'] = core_line_in_output_list

         save_json(output_json_file, data)
-        print("Done!")
-
+        print('Done!')
@@ -3,7 +3,6 @@ Clear the <start> and <end> generated by the model in inference
 """

-
 import json
 import os

 model_name = ''
 task = 'block_completion'
@@ -20,13 +19,16 @@ for data in data_list:
     temp_list = []
     model_output_list = eval(data['model_output'])
     for output in model_output_list:
-        if "<start>" in output and "<end>" in output:
-            start_index = output.find("<start>") + len("<start>")
-            end_index = output.find("<end>")
-            content = output[start_index:end_index].replace('```python', '').replace('```', '')
+        if '<start>' in output and '<end>' in output:
+            start_index = output.find('<start>') + len('<start>')
+            end_index = output.find('<end>')
+            content = (
+                output[start_index:end_index]
+                .replace('```python', '')
+                .replace('```', '')
+            )
         else:
-            content = "no_answer"
+            content = 'no_answer'

         temp_list.append(content)
@@ -5,24 +5,23 @@
  * Mock Service Worker.
  * @see https://github.com/mswjs/msw
  * - Please do NOT modify this file.
  * - Please do NOT serve this file on production.
  */

-const PACKAGE_VERSION = '2.8.4'
-const INTEGRITY_CHECKSUM = '00729d72e3b82faf54ca8b9621dbb96f'
+const PACKAGE_VERSION = '2.10.2'
+const INTEGRITY_CHECKSUM = 'f5825c521429caf22a4dd13b66e243af'
 const IS_MOCKED_RESPONSE = Symbol('isMockedResponse')
 const activeClientIds = new Set()

-self.addEventListener('install', function () {
+addEventListener('install', function () {
   self.skipWaiting()
 })

-self.addEventListener('activate', function (event) {
+addEventListener('activate', function (event) {
   event.waitUntil(self.clients.claim())
 })

-self.addEventListener('message', async function (event) {
-  const clientId = event.source.id
+addEventListener('message', async function (event) {
+  const clientId = Reflect.get(event.source || {}, 'id')

   if (!clientId || !self.clients) {
     return
@@ -94,17 +93,18 @@ self.addEventListener('message', async function (event) {
   }
 })

-self.addEventListener('fetch', function (event) {
-  const { request } = event
-
+addEventListener('fetch', function (event) {
   // Bypass navigation requests.
-  if (request.mode === 'navigate') {
+  if (event.request.mode === 'navigate') {
     return
   }

   // Opening the DevTools triggers the "only-if-cached" request
   // that cannot be handled by the worker. Bypass such requests.
-  if (request.cache === 'only-if-cached' && request.mode !== 'same-origin') {
+  if (
+    event.request.cache === 'only-if-cached' &&
+    event.request.mode !== 'same-origin'
+  ) {
     return
   }
@@ -115,20 +115,26 @@ self.addEventListener('fetch', function (event) {
     return
   }

   // Generate unique request ID.
   const requestId = crypto.randomUUID()
   event.respondWith(handleRequest(event, requestId))
 })

+/**
+ * @param {FetchEvent} event
+ * @param {string} requestId
+ */
 async function handleRequest(event, requestId) {
   const client = await resolveMainClient(event)
+  const requestCloneForEvents = event.request.clone()
   const response = await getResponse(event, client, requestId)

   // Send back the response clone for the "response:*" life-cycle events.
   // Ensure MSW is active and ready to handle the message, otherwise
   // this message will pend indefinitely.
   if (client && activeClientIds.has(client.id)) {
     ;(async function () {
+      const serializedRequest = await serializeRequest(requestCloneForEvents)
+
       // Clone the response so both the client and the library could consume it.
       const responseClone = response.clone()

       sendToClient(
@@ -136,27 +142,35 @@ async function handleRequest(event, requestId) {
         {
           type: 'RESPONSE',
           payload: {
             requestId,
             isMockedResponse: IS_MOCKED_RESPONSE in response,
+            request: {
+              id: requestId,
+              ...serializedRequest,
+            },
             response: {
               type: responseClone.type,
               status: responseClone.status,
               statusText: responseClone.statusText,
-              body: responseClone.body,
               headers: Object.fromEntries(responseClone.headers.entries()),
+              body: responseClone.body,
             },
           },
         },
-        [responseClone.body],
+        responseClone.body ? [serializedRequest.body, responseClone.body] : [],
       )
     })()
   }

   return response
 }

-// Resolve the main client for the given event.
-// Client that issues a request doesn't necessarily equal the client
-// that registered the worker. It's with the latter the worker should
-// communicate with during the response resolving phase.
+/**
+ * Resolve the main client for the given event.
+ * Client that issues a request doesn't necessarily equal the client
+ * that registered the worker. It's with the latter the worker should
+ * communicate with during the response resolving phase.
+ * @param {FetchEvent} event
+ * @returns {Promise<Client | undefined>}
+ */
 async function resolveMainClient(event) {
   const client = await self.clients.get(event.clientId)
@@ -184,12 +198,16 @@ async function resolveMainClient(event) {
   })
 }

+/**
+ * @param {FetchEvent} event
+ * @param {Client | undefined} client
+ * @param {string} requestId
+ * @returns {Promise<Response>}
+ */
 async function getResponse(event, client, requestId) {
-  const { request } = event
-
   // Clone the request because it might've been already used
   // (i.e. its body has been read and sent to the client).
-  const requestClone = request.clone()
+  const requestClone = event.request.clone()

   function passthrough() {
     // Cast the request headers to a new Headers instance
@@ -230,29 +248,17 @@ async function getResponse(event, client, requestId) {
   }

   // Notify the client that a request has been intercepted.
-  const requestBuffer = await request.arrayBuffer()
+  const serializedRequest = await serializeRequest(event.request)
   const clientMessage = await sendToClient(
     client,
     {
       type: 'REQUEST',
       payload: {
         id: requestId,
-        url: request.url,
-        mode: request.mode,
-        method: request.method,
-        headers: Object.fromEntries(request.headers.entries()),
-        cache: request.cache,
-        credentials: request.credentials,
-        destination: request.destination,
-        integrity: request.integrity,
-        redirect: request.redirect,
-        referrer: request.referrer,
-        referrerPolicy: request.referrerPolicy,
-        body: requestBuffer,
-        keepalive: request.keepalive,
+        ...serializedRequest,
       },
     },
-    [requestBuffer],
+    [serializedRequest.body],
   )

   switch (clientMessage.type) {
@@ -268,6 +274,12 @@ async function getResponse(event, client, requestId) {
     return passthrough()
   }

+/**
+ * @param {Client} client
+ * @param {any} message
+ * @param {Array<Transferable>} transferrables
+ * @returns {Promise<any>}
+ */
 function sendToClient(client, message, transferrables = []) {
   return new Promise((resolve, reject) => {
     const channel = new MessageChannel()
@@ -280,14 +292,18 @@ function sendToClient(client, message, transferrables = []) {
|
||||
resolve(event.data)
|
||||
}
|
||||
|
||||
client.postMessage(
|
||||
message,
|
||||
[channel.port2].concat(transferrables.filter(Boolean)),
|
||||
)
|
||||
client.postMessage(message, [
|
||||
channel.port2,
|
||||
...transferrables.filter(Boolean),
|
||||
])
|
||||
})
|
||||
}
|
||||
|
||||
async function respondWithMock(response) {
|
||||
/**
|
||||
* @param {Response} response
|
||||
* @returns {Response}
|
||||
*/
|
||||
function respondWithMock(response) {
|
||||
// Setting response status code to 0 is a no-op.
|
||||
// However, when responding with a "Response.error()", the produced Response
|
||||
// instance will have status code set to 0. Since it's not possible to create
|
||||
@@ -305,3 +321,24 @@ async function respondWithMock(response) {
|
||||
|
||||
return mockedResponse
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Request} request
|
||||
*/
|
||||
async function serializeRequest(request) {
|
||||
return {
|
||||
url: request.url,
|
||||
mode: request.mode,
|
||||
method: request.method,
|
||||
headers: Object.fromEntries(request.headers.entries()),
|
||||
cache: request.cache,
|
||||
credentials: request.credentials,
|
||||
destination: request.destination,
|
||||
integrity: request.integrity,
|
||||
redirect: request.redirect,
|
||||
referrer: request.referrer,
|
||||
referrerPolicy: request.referrerPolicy,
|
||||
body: await request.arrayBuffer(),
|
||||
keepalive: request.keepalive,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -208,7 +208,9 @@ Note:
|
||||
# for visualwebarena, webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env
|
||||
# initialize and retrieve the first observation by issuing an noop OP
|
||||
# For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites
|
||||
return BrowseInteractiveAction(browser_actions='noop(1000)', return_axtree=True)
|
||||
return BrowseInteractiveAction(
|
||||
browser_actions='noop(1000)', return_axtree=True
|
||||
)
|
||||
|
||||
for event in state.view:
|
||||
if isinstance(event, BrowseInteractiveAction):
|
||||
|
||||
@@ -54,6 +54,7 @@ class MCPStdioServerConfig(BaseModel):
|
||||
and set(self.env.items()) == set(other.env.items())
|
||||
)
|
||||
|
||||
|
||||
class MCPSHTTPServerConfig(BaseModel):
|
||||
url: str
|
||||
api_key: str | None = None
|
||||
|
||||
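The `__eq__` visible in the MCPStdioServerConfig hunk compares `env` as a set of items, so two configs with the same variables declared in a different order still compare equal. A minimal, self-contained sketch of that idea (illustrative field names, not the actual OpenHands class):

from pydantic import BaseModel


class StdioServerConfig(BaseModel):
    name: str = ''
    command: str = ''
    args: list[str] = []
    env: dict[str, str] = {}

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, StdioServerConfig):
            return NotImplemented
        return (
            self.name == other.name
            and self.command == other.command
            and self.args == other.args
            # Order-insensitive: {'A': '1', 'B': '2'} == {'B': '2', 'A': '1'}
            and set(self.env.items()) == set(other.env.items())
        )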
@@ -39,6 +39,7 @@ class GitHubService(BaseGitService, GitService):

The class is instantiated via get_impl() in openhands.server.shared.py.
"""

BASE_URL = 'https://api.github.com'
token: SecretStr = SecretStr('')
refresh = False
@@ -508,7 +509,6 @@ class GitHubService(BaseGitService, GitService):
return response['html_url']



github_service_cls = os.environ.get(
'OPENHANDS_GITHUB_SERVICE_CLS',
'openhands.integrations.github.github_service.GitHubService',
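Both service classes note that they are instantiated via get_impl(), with the concrete class path read from an environment variable. A hedged sketch of that pattern (the real helper lives elsewhere in OpenHands and its name and signature may differ):

import importlib
import os


def load_service_class(env_var: str, default_path: str) -> type:
    # Resolve a dotted path such as
    # 'openhands.integrations.github.github_service.GitHubService'.
    dotted = os.environ.get(env_var, default_path)
    module_name, _, class_name = dotted.rpartition('.')
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

This lets a deployment swap in a subclass (for example, an enterprise variant) by setting one environment variable instead of patching the codebase.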
@@ -32,6 +32,7 @@ class GitLabService(BaseGitService, GitService):

The class is instantiated via get_impl() in openhands.server.shared.py.
"""

BASE_URL = 'https://gitlab.com/api/v4'
GRAPHQL_URL = 'https://gitlab.com/api/graphql'
token: SecretStr = SecretStr('')
@@ -482,9 +483,7 @@ class GitLabService(BaseGitService, GitService):

# Set default description if none provided
if not description:
description = (
f'Merging changes from {source_branch} into {target_branch}'
)
description = f'Merging changes from {source_branch} into {target_branch}'

# Prepare the request payload
payload = {
@@ -499,11 +498,9 @@ class GitLabService(BaseGitService, GitService):
url=url, params=payload, method=RequestMethod.POST
)


return response['web_url']



gitlab_service_cls = os.environ.get(
'OPENHANDS_GITLAB_SERVICE_CLS',
'openhands.integrations.gitlab.gitlab_service.GitLabService',

@@ -1,7 +1,7 @@
import os
import typing
from functools import lru_cache
from typing import Callable
import typing
from uuid import UUID

import docker
@@ -283,7 +283,9 @@ class DockerRuntime(ActionExecutionClient):
self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}'

use_host_network = self.config.sandbox.use_host_network
network_mode: typing.Literal['host'] | None = 'host' if use_host_network else None
network_mode: typing.Literal['host'] | None = (
'host' if use_host_network else None
)

# Initialize port mappings
port_mapping: dict[str, list[dict[str, str]]] | None = None
@@ -356,7 +358,7 @@ class DockerRuntime(ActionExecutionClient):

try:
if self.runtime_container_image is None:
raise ValueError("Runtime container image is not set")
raise ValueError('Runtime container image is not set')
self.container = self.docker_client.containers.run(
self.runtime_container_image,
command=command,

@@ -363,7 +363,7 @@ class RemoteRuntime(ActionExecutionClient):
self._session_api_key = start_response['session_api_key']
self.log(
'debug',
f'Session API key setted',
'Session API key set',
)

@property

@@ -59,7 +59,7 @@ class MCPProxyManager:
"""
if len(self.config['mcpServers']) == 0:
logger.info(
f"No MCP servers configured for FastMCP Proxy, skipping initialization."
'No MCP servers configured for FastMCP Proxy, skipping initialization.'
)
return None

@@ -70,7 +70,7 @@ class MCPProxyManager:
api_key=self.api_key,
)

logger.info(f"FastMCP Proxy initialized successfully")
logger.info('FastMCP Proxy initialized successfully')

async def mount_to_app(
self, app: FastAPI, allow_origins: Optional[list[str]] = None
@@ -83,9 +83,7 @@ class MCPProxyManager:
allow_origins: List of allowed origins for CORS
"""
if len(self.config['mcpServers']) == 0:
logger.info(
f"No MCP servers configured for FastMCP Proxy, skipping mount."
)
logger.info('No MCP servers configured for FastMCP Proxy, skipping mount.')
return

if not self.proxy:
@@ -101,8 +99,7 @@ class MCPProxyManager:
app.routes.remove('/mcp')

app.mount('/', mcp_app)
logger.info(f"Mounted FastMCP Proxy app at /mcp")

logger.info('Mounted FastMCP Proxy app at /mcp')

async def update_and_remount(
self,
@@ -122,10 +119,7 @@ class MCPProxyManager:
tools: List of tool configurations
allow_origins: List of allowed origins for CORS
"""
tools = {
t.name: t.model_dump()
for t in stdio_servers
}
tools = {t.name: t.model_dump() for t in stdio_servers}
self.config['mcpServers'] = tools

del self.proxy
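The update_and_remount hunk collapses a three-line dict comprehension into one line without changing behavior. The shape of that mapping, sketched with a stand-in model (names here are illustrative, not the OpenHands classes):

from pydantic import BaseModel


class Server(BaseModel):
    name: str
    command: str


stdio_servers = [Server(name='fetch', command='uvx mcp-server-fetch')]
# Key each server's dumped fields by its name, as the proxy config expects.
tools = {t.name: t.model_dump() for t in stdio_servers}
assert tools == {'fetch': {'name': 'fetch', 'command': 'uvx mcp-server-fetch'}}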
@@ -1,7 +1,7 @@
import random
import socket
import time
from openhands.core.logger import openhands_logger as logger


def check_port_available(port: int) -> bool:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
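The hunk cuts off right after the socket is created. For context, a port-availability check of this shape typically tries to bind the port and treats failure as "in use"; a hedged completion (the real OpenHands function may differ in details):

import socket


def check_port_available(port: int) -> bool:
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Binding succeeds only if no other process holds the port.
        sock.bind(('0.0.0.0', port))
        return True
    except OSError:
        return False
    finally:
        sock.close()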
@@ -15,7 +15,7 @@ from openhands.events.stream import EventStreamSubscriber, session_exists
from openhands.server.config.server_config import ServerConfig
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
from openhands.server.monitoring import MonitoringListener
from openhands.server.session.agent_session import AgentSession, WAIT_TIME_BEFORE_CLOSE
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE, AgentSession
from openhands.server.session.conversation import ServerConversation
from openhands.server.session.session import ROOM_KEY, Session
from openhands.storage.conversation.conversation_store import ConversationStore
@@ -508,7 +508,9 @@ class StandaloneConversationManager(ConversationManager):
session_api_key=None,
event_store=session.agent_session.event_stream,
status=_get_status_from_session(session),
runtime_status=getattr(session.agent_session.runtime, 'runtime_status', None),
runtime_status=getattr(
session.agent_session.runtime, 'runtime_status', None
),
)

def _get_conversation_url(self, conversation_id: str):
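The getattr(..., 'runtime_status', None) call here is defensive: during startup or teardown the runtime may not expose the attribute yet, and a plain attribute access would raise. In isolation:

class _RuntimeWithoutStatus:
    pass


runtime = _RuntimeWithoutStatus()
# Falls back to None instead of raising AttributeError.
runtime_status = getattr(runtime, 'runtime_status', None)
assert runtime_status is None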
@@ -1,7 +1,6 @@
from dataclasses import dataclass, field
from datetime import datetime, timezone

from openhands.core.schema.agent import AgentState
from openhands.integrations.service_types import ProviderType
from openhands.runtime.runtime_status import RuntimeStatus
from openhands.storage.data_models.conversation_metadata import ConversationTrigger

@@ -5,13 +5,13 @@ from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.events.event_filter import EventFilter
from openhands.events.serialization.event import event_to_dict
from openhands.memory.memory import Memory
from openhands.microagent.types import InputMetadata
from openhands.runtime.base import Runtime
from openhands.server.dependencies import get_dependencies
from openhands.server.session.conversation import ServerConversation
from openhands.server.shared import conversation_manager
from openhands.server.utils import get_conversation
from openhands.microagent.types import InputMetadata
from openhands.memory.memory import Memory

app = APIRouter(
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
@@ -216,7 +216,11 @@ async def get_microagents(
content=agent.content,
triggers=[],
inputs=agent.metadata.inputs,
tools=[server.name for server in agent.metadata.mcp_tools.stdio_servers] if agent.metadata.mcp_tools else [],
tools=[
server.name for server in agent.metadata.mcp_tools.stdio_servers
]
if agent.metadata.mcp_tools
else [],
)
)

@@ -229,7 +233,11 @@ async def get_microagents(
content=agent.content,
triggers=agent.triggers,
inputs=agent.metadata.inputs,
tools=[server.name for server in agent.metadata.mcp_tools.stdio_servers] if agent.metadata.mcp_tools else [],
tools=[
server.name for server in agent.metadata.mcp_tools.stdio_servers
]
if agent.metadata.mcp_tools
else [],
)
)


@@ -6,15 +6,19 @@ from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
from openhands.events.serialization import event_to_dict
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
from openhands.server.dependencies import get_dependencies
from openhands.server.session.conversation import ServerConversation
from openhands.server.utils import get_conversation
from openhands.utils.async_utils import call_sync_from_async
from openhands.server.session.conversation import ServerConversation

app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
app = APIRouter(
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
)


@app.post('/submit-feedback')
async def submit_feedback(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
async def submit_feedback(
request: Request, conversation: ServerConversation = Depends(get_conversation)
) -> JSONResponse:
"""Submit user feedback.

This function stores the provided feedback data.
@@ -37,9 +41,7 @@ async def submit_feedback(request: Request, conversation: ServerConversation = D
# Assuming the storage service is already configured in the backend
# and there is a function to handle the storage.
body = await request.json()
async_store = AsyncEventStoreWrapper(
conversation.event_stream, filter_hidden=True
)
async_store = AsyncEventStoreWrapper(conversation.event_stream, filter_hidden=True)
trajectory = []
async for event in async_store:
trajectory.append(event_to_dict(event))
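The trajectory loop above drains the wrapped event store with async for. The same pattern against a stand-in async generator (AsyncEventStoreWrapper itself is OpenHands-specific; this is only a sketch of the iteration shape):

import asyncio


async def fake_event_store():
    # Stand-in for AsyncEventStoreWrapper: yields events asynchronously.
    for event in ({'id': 1}, {'id': 2}):
        yield event


async def main() -> None:
    trajectory = []
    async for event in fake_event_store():
        trajectory.append(event)
    assert len(trajectory) == 2


asyncio.run(main())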
@@ -5,7 +5,6 @@ from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
status,
)
from fastapi.responses import FileResponse, JSONResponse
@@ -27,17 +26,15 @@ from openhands.server.dependencies import get_dependencies
from openhands.server.file_config import (
FILES_TO_IGNORE,
)
from openhands.server.shared import (
ConversationStoreImpl,
config,
)
from openhands.server.session.conversation import ServerConversation
from openhands.server.user_auth import get_user_id
from openhands.server.utils import get_conversation, get_conversation_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.server.session.conversation import ServerConversation

app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
app = APIRouter(
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
)


@app.get(
@@ -50,7 +47,7 @@ app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_
)
async def list_files(
conversation: ServerConversation = Depends(get_conversation),
path: str | None = None
path: str | None = None,
) -> list[str] | JSONResponse:
"""List files in the specified path.

@@ -132,7 +129,9 @@ async def list_files(
415: {'description': 'Unsupported media type', 'model': dict},
},
)
async def select_file(file: str, conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse:
async def select_file(
file: str, conversation: ServerConversation = Depends(get_conversation)
) -> FileResponse | JSONResponse:
"""Retrieve the content of a specified file.

To select a file:
@@ -196,7 +195,9 @@ async def select_file(file: str, conversation: ServerConversation = Depends(get_
500: {'description': 'Error zipping workspace', 'model': dict},
},
)
def zip_current_workspace(conversation: ServerConversation = Depends(get_conversation)) -> FileResponse | JSONResponse:
def zip_current_workspace(
conversation: ServerConversation = Depends(get_conversation),
) -> FileResponse | JSONResponse:
try:
logger.debug('Zipping workspace')
runtime: Runtime = conversation.runtime

@@ -1,6 +1,6 @@
import itertools
import re
import os
import re
import uuid
from datetime import datetime, timezone

@@ -9,19 +9,18 @@ from fastapi.responses import JSONResponse
from jinja2 import Environment, FileSystemLoader
from pydantic import BaseModel, Field

from openhands.events.event_filter import EventFilter
from openhands.events.stream import EventStream
from openhands.core.config.llm_config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
ChangeAgentStateAction,
NullAction,
)
from openhands.events.event_filter import EventFilter
from openhands.events.observation import (
NullObservation,
AgentStateChangedObservation,
NullObservation,
)

from openhands.core.config.llm_config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import EventStream
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderHandler,
@@ -38,10 +37,9 @@ from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
ConversationInfoResultSet,
)
from openhands.server.services.conversation_service import create_new_conversation
from openhands.server.session.conversation import ServerConversation
from openhands.server.dependencies import get_dependencies
from openhands.server.services.conversation_service import create_new_conversation
from openhands.server.session.conversation import ServerConversation
from openhands.server.shared import (
ConversationStoreImpl,
config,
@@ -53,11 +51,12 @@ from openhands.server.user_auth import (
get_provider_tokens,
get_user_id,
get_user_secrets,
get_user_settings_store,
get_user_settings,
get_user_settings_store,
)
from openhands.server.user_auth.user_auth import AuthType
from openhands.server.utils import get_conversation_store, get_conversation as get_conversation_object
from openhands.server.utils import get_conversation as get_conversation_object
from openhands.server.utils import get_conversation_store
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import (
ConversationMetadata,
@@ -295,7 +294,7 @@ async def delete_conversation(
async def get_prompt(
event_id: int,
user_settings: SettingsStore = Depends(get_user_settings_store),
conversation: ServerConversation | None = Depends(get_conversation_object)
conversation: ServerConversation | None = Depends(get_conversation_object),
):
if conversation is None:
return JSONResponse(
@@ -409,7 +408,6 @@ async def start_conversation(
logger.info(f'Starting conversation: {conversation_id}')

try:

# Check that the conversation exists
try:
await conversation_store.get_metadata(conversation_id)
@@ -463,10 +461,17 @@ async def stop_conversation(

try:
# Check if the conversation is running
agent_loop_info = await conversation_manager.get_agent_loop_info(user_id=user_id, filter_to_sids={conversation_id})
conversation_status = agent_loop_info[0].status if agent_loop_info else ConversationStatus.STOPPED
agent_loop_info = await conversation_manager.get_agent_loop_info(
user_id=user_id, filter_to_sids={conversation_id}
)
conversation_status = (
agent_loop_info[0].status if agent_loop_info else ConversationStatus.STOPPED
)

if conversation_status not in (ConversationStatus.STARTING, ConversationStatus.RUNNING):
if conversation_status not in (
ConversationStatus.STARTING,
ConversationStatus.RUNNING,
):
return ConversationResponse(
status='ok',
conversation_id=conversation_id,
@@ -505,7 +510,11 @@ def _get_contextual_events(event_stream: EventStream, event_id: int) -> str:

agent_event_filter = EventFilter(
exclude_hidden=True,
exclude_types=(NullAction, NullObservation, ChangeAgentStateAction, AgentStateChangedObservation),
exclude_types=(
NullAction,
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
),
) # the types of events that can be in an agent's history
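_get_contextual_events filters bookkeeping events out before building context. The core of exclude_types is a per-event isinstance check, roughly like this (EventFilter's real matching logic is richer; the classes below are stand-ins):

class NullAction: ...


class NullObservation: ...


class MessageAction: ...


EXCLUDE_TYPES = (NullAction, NullObservation)
events = [NullAction(), MessageAction(), NullObservation()]
# Keep only the events an agent's history should contain.
history = [e for e in events if not isinstance(e, EXCLUDE_TYPES)]
assert len(history) == 1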
@@ -87,7 +87,7 @@ async def create_pr(
target_branch: Annotated[str, Field(description='Target branch on repo')],
title: Annotated[str, Field(description='PR Title')],
body: Annotated[str | None, Field(description='PR body')],
draft: Annotated[bool, Field(description='Whether PR opened is a draft')] = True
draft: Annotated[bool, Field(description='Whether PR opened is a draft')] = True,
) -> str:
"""Open a PR in GitHub"""

@@ -127,7 +127,7 @@ async def create_pr(
target_branch=target_branch,
title=title,
body=body,
draft=draft
draft=draft,
)

if conversation_id:
@@ -148,7 +148,12 @@ async def create_mr(
],
source_branch: Annotated[str, Field(description='Source branch on repo')],
target_branch: Annotated[str, Field(description='Target branch on repo')],
title: Annotated[str, Field(description='MR Title. Start title with `DRAFT:` or `WIP:` if applicable.')],
title: Annotated[
str,
Field(
description='MR Title. Start title with `DRAFT:` or `WIP:` if applicable.'
),
],
description: Annotated[str | None, Field(description='MR description')],
) -> str:
"""Open a MR in GitLab"""
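Annotated[...] with pydantic's Field attaches a human-readable description to each tool parameter without changing the runtime type; frameworks recover the metadata from the type hints. A stripped-down illustration (the function and parameters here are hypothetical stand-ins):

from typing import Annotated, get_type_hints

from pydantic import Field


def create_mr(
    title: Annotated[str, Field(description='MR Title.')],
    description: Annotated[str | None, Field(description='MR description')] = None,
) -> str:
    return title


# Read the attached metadata back via the type hints.
hints = get_type_hints(create_mr, include_extras=True)
print(hints['title'].__metadata__[0].description)  # -> 'MR Title.'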
@@ -8,14 +8,18 @@ from fastapi import (
)

from openhands.server.dependencies import get_dependencies
from openhands.server.utils import get_conversation
from openhands.server.session.conversation import ServerConversation
from openhands.server.utils import get_conversation

app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
app = APIRouter(
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
)


@app.route('/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE'])
async def security_api(request: Request, conversation: ServerConversation = Depends(get_conversation)) -> Response:
async def security_api(
request: Request, conversation: ServerConversation = Depends(get_conversation)
) -> Response:
"""Catch-all route for security analyzer API requests.

Each request is handled directly to the security analyzer.
@@ -35,6 +39,4 @@ async def security_api(request: Request, conversation: ServerConversation = Depe
detail='Security analyzer not initialized',
)

return await conversation.security_analyzer.handle_api_request(
request
)
return await conversation.security_analyzer.handle_api_request(request)

@@ -1,18 +1,22 @@
from fastapi import APIRouter, Depends, Request, status
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse

from openhands.core.logger import openhands_logger as logger
from openhands.events.async_event_store_wrapper import AsyncEventStoreWrapper
from openhands.events.serialization import event_to_trajectory
from openhands.server.dependencies import get_dependencies
from openhands.server.utils import get_conversation
from openhands.server.session.conversation import ServerConversation
from openhands.server.utils import get_conversation

app = APIRouter(prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies())
app = APIRouter(
prefix='/api/conversations/{conversation_id}', dependencies=get_dependencies()
)


@app.get('/trajectory')
async def get_trajectory(conversation: ServerConversation = Depends(get_conversation)) -> JSONResponse:
async def get_trajectory(
conversation: ServerConversation = Depends(get_conversation),
) -> JSONResponse:
"""Get trajectory.

This function retrieves the current trajectory and returns it.

@@ -1,4 +1,3 @@
import os
import uuid
from typing import Any

@@ -80,7 +79,6 @@ async def create_new_conversation(
session_init_args['conversation_instructions'] = conversation_instructions
conversation_init_data = ConversationInitData(**session_init_args)


logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
logger.info('ServerConversation store loaded')
@@ -90,13 +88,14 @@ async def create_new_conversation(
conversation_id = uuid.uuid4().hex

if not await conversation_store.exists(conversation_id):

logger.info(
f'New conversation ID: {conversation_id}',
extra={'user_id': user_id, 'session_id': conversation_id},
)

conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(user_id, conversation_id, conversation_init_data)
conversation_init_data = ExperimentManagerImpl.run_conversation_variant_test(
user_id, conversation_id, conversation_init_data
)
conversation_title = get_default_conversation_title(conversation_id)

logger.info(f'Saving metadata for conversation {conversation_id}')

@@ -197,23 +197,21 @@ class AgentSession:
finally:
self._starting = False
success = finished and runtime_connected
duration = (time.time() - started_at)
duration = time.time() - started_at

log_metadata = {
'signal': 'agent_session_start',
'success': success,
'duration': duration,
'restored_state': restored_state
'restored_state': restored_state,
}
if success:
self.logger.info(
f'Agent session start succeeded in {duration}s',
extra=log_metadata
f'Agent session start succeeded in {duration}s', extra=log_metadata
)
else:
self.logger.error(
f'Agent session start failed in {duration}s',
extra=log_metadata
f'Agent session start failed in {duration}s', extra=log_metadata
)

async def close(self) -> None:
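The AgentSession hunk keeps one metadata dict and reuses it for both the success and failure log lines; with standard logging, extra= carries those fields to handlers for structured output. In isolation (the logger name is illustrative):

import logging
import time

logger = logging.getLogger('agent_session')
logging.basicConfig(level=logging.INFO)

started_at = time.time()
# ... session startup work would happen here ...
duration = time.time() - started_at
log_metadata = {
    'signal': 'agent_session_start',
    'success': True,
    'duration': duration,
}
# extra= attaches the fields to the LogRecord for structured handlers.
logger.info('Agent session start succeeded in %.3fs', duration, extra=log_metadata)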
@@ -105,7 +105,12 @@ class FileConversationStore(ConversationStore):
async def get_instance(
cls, config: OpenHandsConfig, user_id: str | None
) -> FileConversationStore:
file_store = get_file_store(config.file_store, config.file_store_path, config.file_store_web_hook_url, config.file_store_web_hook_headers)
file_store = get_file_store(
config.file_store,
config.file_store_path,
config.file_store_web_hook_url,
config.file_store_web_hook_headers,
)
return FileConversationStore(file_store)



@@ -43,6 +43,6 @@ class FileSecretsStore(SecretsStore):
config.file_store,
config.file_store_path,
config.file_store_web_hook_url,
config.file_store_web_hook_headers
config.file_store_web_hook_headers,
)
return FileSecretsStore(file_store)

@@ -37,6 +37,6 @@ class FileSettingsStore(SettingsStore):
config.file_store,
config.file_store_path,
config.file_store_web_hook_url,
config.file_store_web_hook_headers
config.file_store_web_hook_headers,
)
return FileSettingsStore(file_store)
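All three file-backed stores build their file store the same way, from the same four settings on the config. The shared shape of those get_instance() factories, sketched with stand-ins (the positional order matches the hunks above; the config class and helper body here are toys):

from dataclasses import dataclass


@dataclass
class ToyConfig:
    file_store: str = 'local'
    file_store_path: str = '/tmp/openhands'
    file_store_web_hook_url: str | None = None
    file_store_web_hook_headers: dict | None = None


def get_file_store(kind, path, web_hook_url, web_hook_headers):
    # Stand-in: the real helper returns a FileStore implementation.
    return {'kind': kind, 'path': path}


config = ToyConfig()
file_store = get_file_store(
    config.file_store,
    config.file_store_path,
    config.file_store_web_hook_url,
    config.file_store_web_hook_headers,
)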
@@ -10,24 +10,36 @@ class TestTranslationCompleteness(unittest.TestCase):

def test_translation_completeness_check_runs(self):
"""Test that the translation completeness check script can be executed."""
frontend_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), "frontend")
script_path = os.path.join(frontend_dir, "scripts", "check-translation-completeness.cjs")
frontend_dir = os.path.join(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
),
'frontend',
)
script_path = os.path.join(
frontend_dir, 'scripts', 'check-translation-completeness.cjs'
)

# Verify the script exists
self.assertTrue(os.path.exists(script_path), f"Script not found at {script_path}")
self.assertTrue(
os.path.exists(script_path), f'Script not found at {script_path}'
)

# Verify the script is executable
self.assertTrue(os.access(script_path, os.X_OK), f"Script at {script_path} is not executable")
self.assertTrue(
os.access(script_path, os.X_OK),
f'Script at {script_path} is not executable',
)

# Run the script (it may fail due to missing translations, but we just want to verify it runs)
try:
subprocess.run(
["node", script_path],
['node', script_path],
cwd=frontend_dir,
check=False,
capture_output=True,
text=True
text=True,
)
# We don't assert on the return code because it might fail due to missing translations
except Exception as e:
self.fail(f"Failed to run translation completeness check: {e}")
self.fail(f'Failed to run translation completeness check: {e}')
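The test's subprocess call is the standard "run but don't raise" pattern: check=False tolerates a non-zero exit, while capture_output and text make the output inspectable as strings. Stand-alone (using the Python interpreter instead of node so the sketch runs anywhere):

import subprocess
import sys

result = subprocess.run(
    [sys.executable, '--version'],
    check=False,          # don't raise CalledProcessError on non-zero exit
    capture_output=True,  # collect stdout/stderr instead of inheriting them
    text=True,            # decode bytes to str
)
print(result.returncode, result.stdout.strip())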
@@ -100,7 +100,6 @@ def mock_conversation_instructions_template():
return 'Instructions: {{ repo_instruction }}'



@pytest.fixture
def mock_followup_prompt_template():
return 'Issue context: {{ issues }}\n\nReview comments: {{ review_comments }}\n\nReview threads: {{ review_threads }}\n\nFiles: {{ files }}\n\nThread comments: {{ thread_context }}\n\nPlease fix this issue.'
@@ -532,7 +531,11 @@ async def test_process_issue(
handler_instance.guess_success.assert_not_called()


def test_get_instruction(mock_user_instructions_template, mock_conversation_instructions_template, mock_followup_prompt_template):
def test_get_instruction(
mock_user_instructions_template,
mock_conversation_instructions_template,
mock_followup_prompt_template,
):
issue = Issue(
owner='test_owner',
repo='test_repo',
@@ -545,7 +548,10 @@ def test_get_instruction(mock_user_instructions_template, mock_conversation_inst
GithubIssueHandler('owner', 'repo', 'token'), mock_llm_config
)
instruction, conversation_instructions, images_urls = issue_handler.get_instruction(
issue, mock_user_instructions_template, mock_conversation_instructions_template, None
issue,
mock_user_instructions_template,
mock_conversation_instructions_template,
None,
)
expected_instruction = 'Issue: Test Issue\n\nThis is a test issue refer to image \n\nPlease fix this issue.'

@@ -576,7 +582,10 @@ def test_get_instruction(mock_user_instructions_template, mock_conversation_inst
GithubPRHandler('owner', 'repo', 'token'), mock_llm_config
)
instruction, conversation_instructions, images_urls = pr_handler.get_instruction(
issue, mock_followup_prompt_template, mock_conversation_instructions_template, None
issue,
mock_followup_prompt_template,
mock_conversation_instructions_template,
None,
)
expected_instruction = "Issue context: [\n \"Issue 1 fix the type\"\n]\n\nReview comments: None\n\nReview threads: [\n \"There is still a typo 'pthon' instead of 'python'\"\n]\n\nFiles: []\n\nThread comments: I've left review comments, please address them\n---\nThis is a valid concern.\n\nPlease fix this issue."

@@ -601,7 +610,9 @@ def test_file_instruction():
with open('openhands/resolver/prompts/resolve/basic.jinja', 'r') as f:
prompt = f.read()

with open('openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r') as f:
with open(
'openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r'
) as f:
conversation_instructions_template = f.read()

# Test without thread comments
@@ -620,7 +631,6 @@ Test Issue

This is a test issue """


expected_conversation_instructions = """IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.
You SHOULD INCLUDE PROPER INDENTATION in your edit commands.

@@ -644,7 +654,9 @@ def test_file_instruction_with_repo_instruction():
with open('openhands/resolver/prompts/resolve/basic.jinja', 'r') as f:
prompt = f.read()

with open('openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r') as f:
with open(
'openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r'
) as f:
conversation_instructions_prompt = f.read()

# load repo instruction from openhands/resolver/prompts/repo_instructions/all-hands-ai___openhands-resolver.txt
@@ -662,7 +674,6 @@ def test_file_instruction_with_repo_instruction():
issue, prompt, conversation_instructions_prompt, repo_instruction
)


expected_instruction = """Please fix the following issue for the repository in /workspace.
An environment has been set up for you to start working. You may assume all necessary tools are installed.

@@ -683,7 +694,6 @@ This is a Python repo for openhands-resolver, a library that attempts to resolve

When you think you have fixed the issue through code changes, please finish the interaction."""


assert instruction == expected_instruction
assert conversation_instructions == expected_conversation_instructions
assert conversation_instructions is not None
@@ -785,7 +795,9 @@ def test_instruction_with_thread_comments():
with open('openhands/resolver/prompts/resolve/basic.jinja', 'r') as f:
prompt = f.read()

with open('openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r') as f:
with open(
'openhands/resolver/prompts/resolve/basic-conversation-instructions.jinja', 'r'
) as f:
conversation_instructions_template = f.read()

llm_config = LLMConfig(model='test', api_key='test')

@@ -1,6 +1,3 @@
from typing import Type
from unittest.mock import MagicMock

import pytest
from pydantic import SecretStr

@@ -8,11 +5,11 @@ from openhands.core.config import LLMConfig
from openhands.integrations.provider import ProviderType
from openhands.resolver.interfaces.github import GithubIssueHandler, GithubPRHandler
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler, GitlabPRHandler
from openhands.resolver.issue_handler_factory import IssueHandlerFactory
from openhands.resolver.interfaces.issue_definitions import (
ServiceContextIssue,
ServiceContextPR,
)
from openhands.resolver.issue_handler_factory import IssueHandlerFactory


@pytest.fixture
@@ -45,20 +42,17 @@ test_cases = [


@pytest.mark.parametrize(
'platform,issue_type,expected_context_type,expected_handler_type',
test_cases
'platform,issue_type,expected_context_type,expected_handler_type', test_cases
)
def test_handler_creation(
factory_params,
platform: ProviderType,
issue_type: str,
expected_context_type: Type,
expected_handler_type: Type,
expected_context_type: type,
expected_handler_type: type,
):
factory = IssueHandlerFactory(
**factory_params,
platform=platform,
issue_type=issue_type
**factory_params, platform=platform, issue_type=issue_type
)

handler = factory.create()
@@ -66,11 +60,10 @@ def test_handler_creation(
assert isinstance(handler, expected_context_type)
assert isinstance(handler._strategy, expected_handler_type)


def test_invalid_issue_type(factory_params):
factory = IssueHandlerFactory(
**factory_params,
platform=ProviderType.GITHUB,
issue_type='invalid'
**factory_params, platform=ProviderType.GITHUB, issue_type='invalid'
)

with pytest.raises(ValueError, match='Invalid issue type: invalid'):
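The annotation change from typing.Type to the builtin type reflects PEP 585: since Python 3.9 the builtin is itself usable as a generic annotation, so the typing alias is redundant and lint rules (for example pyupgrade, or ruff's UP006) rewrite it and drop the import. A minimal illustration:

def make_instance(cls: type) -> object:
    # Equivalent to the old `cls: typing.Type`; no import needed.
    return cls()


assert isinstance(make_instance(dict), dict)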
@@ -2,7 +2,7 @@ from unittest import mock

import pytest

from openhands.core.config import SandboxConfig,OpenHandsConfig
from openhands.core.config import OpenHandsConfig, SandboxConfig
from openhands.events.action import CmdRunAction
from openhands.resolver.issue_resolver import IssueResolver

@@ -36,7 +36,8 @@ def test_setup_sandbox_config_default():
)

assert_sandbox_config(
openhands_config.sandbox, runtime_container_image='ghcr.io/all-hands-ai/runtime:mock-nikolaik'
openhands_config.sandbox,
runtime_container_image='ghcr.io/all-hands-ai/runtime:mock-nikolaik',
)


@@ -68,7 +69,9 @@ def test_setup_sandbox_config_base_only():
)

assert_sandbox_config(
openhands_config.sandbox, base_container_image=base_image, runtime_container_image=None
openhands_config.sandbox,
base_container_image=base_image,
runtime_container_image=None,
)


@@ -84,7 +87,9 @@ def test_setup_sandbox_config_runtime_only():
is_experimental=False,
)

assert_sandbox_config(openhands_config.sandbox, runtime_container_image=runtime_image)
assert_sandbox_config(
openhands_config.sandbox, runtime_container_image=runtime_image
)


def test_setup_sandbox_config_experimental():
@@ -117,7 +122,9 @@ def test_setup_sandbox_config_gitlab_ci(mock_get_unique_uid, mock_getuid):
is_experimental=False,
)

assert_sandbox_config(openhands_config.sandbox, local_runtime_url='http://localhost')
assert_sandbox_config(
openhands_config.sandbox, local_runtime_url='http://localhost'
)


@mock.patch('openhands.resolver.issue_resolver.os.getuid', return_value=1000)
@@ -134,7 +141,9 @@ def test_setup_sandbox_config_gitlab_ci_non_root(mock_getuid):
is_experimental=False,
)

assert_sandbox_config(openhands_config.sandbox, local_runtime_url='http://localhost')
assert_sandbox_config(
openhands_config.sandbox, local_runtime_url='http://localhost'
)


@mock.patch('openhands.events.observation.CmdOutputObservation')