[chore] Run full agent pre-commit (#8235)

This commit is contained in:
Engel Nyst
2025-05-03 17:24:03 +02:00
committed by GitHub
parent 98cb2e24ee
commit 985e20d529
27 changed files with 186 additions and 147 deletions

View File

@@ -36,13 +36,12 @@ from openhands.core.config import (
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import CmdRunAction, MessageAction, FileReadAction
from openhands.events.action import CmdRunAction, FileReadAction, MessageAction
from openhands.events.observation import CmdOutputObservation, ErrorObservation
from openhands.events.serialization.event import event_to_dict
from openhands.runtime.base import Runtime
from openhands.utils.async_utils import call_async_from_sync
from openhands.utils.shutdown_listener import sleep_if_should_continue
import pdb
USE_HINT_TEXT = os.environ.get('USE_HINT_TEXT', 'false').lower() == 'true'
USE_INSTANCE_IMAGE = os.environ.get('USE_INSTANCE_IMAGE', 'true').lower() == 'true'
@@ -51,7 +50,7 @@ RUN_WITH_BROWSING = os.environ.get('RUN_WITH_BROWSING', 'false').lower() == 'tru
# TODO: migrate all swe-bench docker to ghcr.io/openhands
# TODO: 适应所有的语言
DOCKER_IMAGE_PREFIX = os.environ.get('EVAL_DOCKER_IMAGE_PREFIX', '')
LANGUAGE =os.environ.get('LANGUAGE', 'python')
LANGUAGE = os.environ.get('LANGUAGE', 'python')
logger.info(f'Using docker image prefix: {DOCKER_IMAGE_PREFIX}')
@@ -71,7 +70,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
# Instruction based on Anthropic's official trajectory
# https://github.com/eschluntz/swe-bench-experiments/tree/main/evaluation/verified/20241022_tools_claude-3-5-sonnet-updated/trajs
instructions = {
"python":(
'python': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@@ -96,7 +95,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"java": (
'java': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@@ -121,7 +120,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
" Make sure all these tests pass with your changes.\n"
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"go": (
'go': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@@ -146,7 +145,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"c": (
'c': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@@ -171,7 +170,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"cpp": (
'cpp': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@@ -196,7 +195,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"javascript": (
'javascript': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@@ -221,7 +220,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"typescript":(
'typescript': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@@ -246,7 +245,7 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
),
"rust":(
'rust': (
'<uploaded_files>\n'
f'/workspace/{workspace_dir_name}\n'
'</uploaded_files>\n'
@@ -270,11 +269,10 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
' - The functions you changed\n'
' Make sure all these tests pass with your changes.\n'
"Your thinking should be thorough and so it's fine if it's very long.\n"
)
),
}
instruction = instructions.get(LANGUAGE.lower())
if instruction and RUN_WITH_BROWSING:
instruction += (
'<IMPORTANT!>\n'
@@ -284,7 +282,6 @@ def get_instruction(instance: pd.Series, metadata: EvalMetadata):
return instruction
# TODO: 适应所有的语言
# def get_instance_docker_image(instance_id: str) -> str:
# image_name = 'sweb.eval.x86_64.' + instance_id
@@ -307,16 +304,15 @@ def get_instance_docker_image(instance: pd.Series):
container_name = container_name.replace('/', '_m_')
instance_id = instance.get('instance_id', '')
tag_suffix = instance_id.split('-')[-1] if instance_id else ''
container_tag = f"pr-{tag_suffix}"
container_tag = f'pr-{tag_suffix}'
# pdb.set_trace()
return f"mswebench/{container_name}:{container_tag}"
return f'mswebench/{container_name}:{container_tag}'
# return "kong/insomnia:pr-8284"
# return "'sweb.eval.x86_64.local_insomnia"
# return "local_insomnia_why"
# return "local/kong-insomnia:pr-8117"
def get_config(
instance: pd.Series,
metadata: EvalMetadata,
@@ -569,7 +565,6 @@ def complete_runtime(
f'Failed to git config --global core.pager "": {str(obs)}',
)
action = CmdRunAction(command='git add -A')
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
@@ -582,14 +577,14 @@ def complete_runtime(
##删除二进制文件
action = CmdRunAction(
command=f'''
command="""
for file in $(git status --porcelain | grep -E "^(M| M|\\?\\?|A| A)" | cut -c4-); do
if [ -f "$file" ] && (file "$file" | grep -q "executable" || git check-attr binary "$file" | grep -q "binary: set"); then
git rm -f "$file" 2>/dev/null || rm -f "$file"
echo "Removed: $file"
fi
done
'''
"""
)
action.set_hard_timeout(600)
logger.info(action, extra={'msg_type': 'ACTION'})
@@ -626,9 +621,7 @@ def complete_runtime(
else:
assert_and_raise(False, f'Unexpected observation type: {str(obs)}')
action = FileReadAction(
path='patch.diff'
)
action = FileReadAction(path='patch.diff')
action.set_hard_timeout(max(300 + 100 * n_retries, 600))
logger.info(action, extra={'msg_type': 'ACTION'})
obs = runtime.run_action(action)
@@ -714,12 +707,12 @@ def process_instance(
is_binary_block = False
for line in lines:
if line.startswith("diff --git "):
if line.startswith('diff --git '):
if block and not is_binary_block:
cleaned_lines.extend(block)
block = [line]
is_binary_block = False
elif "Binary files" in line:
elif 'Binary files' in line:
is_binary_block = True
block.append(line)
else:
@@ -727,7 +720,8 @@ def process_instance(
if block and not is_binary_block:
cleaned_lines.extend(block)
return "\n".join(cleaned_lines)
return '\n'.join(cleaned_lines)
git_patch = remove_binary_diffs(git_patch)
test_result = {
'git_patch': git_patch,
@@ -797,7 +791,7 @@ if __name__ == '__main__':
# so we don't need to manage file uploading to OpenHands's repo
# dataset = load_dataset(args.dataset, split=args.split)
# dataset = load_dataset(args.dataset)
dataset = load_dataset("json", data_files = args.dataset)
dataset = load_dataset('json', data_files=args.dataset)
dataset = dataset[args.split]
swe_bench_tests = filter_dataset(dataset.to_pandas(), 'instance_id')
logger.info(

View File

@@ -3,7 +3,9 @@ import json
input_file = 'XXX.jsonl'
output_file = 'YYY.jsonl'
with open(input_file, 'r', encoding='utf-8') as fin, open(output_file, 'w', encoding='utf-8') as fout:
with open(input_file, 'r', encoding='utf-8') as fin, open(
output_file, 'w', encoding='utf-8'
) as fout:
for line in fin:
line = line.strip()
if not line:
@@ -13,18 +15,22 @@ with open(input_file, 'r', encoding='utf-8') as fin, open(output_file, 'w', enco
item = data
# 提取原始数据
org = item.get("org", "")
repo = item.get("repo", "")
number = str(item.get("number", ""))
org = item.get('org', '')
repo = item.get('repo', '')
number = str(item.get('number', ''))
new_item = {}
new_item["repo"] = f"{org}/{repo}"
new_item["instance_id"] = f"{org}__{repo}-{number}"
new_item["problem_statement"] = item["resolved_issues"][0].get("title", "") + "\n" + item["resolved_issues"][0].get("body", "")
new_item["FAIL_TO_PASS"] = []
new_item["PASS_TO_PASS"] = []
new_item["base_commit"] = item['base'].get("sha","")
new_item["version"] = "0.1" # depends
new_item['repo'] = f'{org}/{repo}'
new_item['instance_id'] = f'{org}__{repo}-{number}'
new_item['problem_statement'] = (
item['resolved_issues'][0].get('title', '')
+ '\n'
+ item['resolved_issues'][0].get('body', '')
)
new_item['FAIL_TO_PASS'] = []
new_item['PASS_TO_PASS'] = []
new_item['base_commit'] = item['base'].get('sha', '')
new_item['version'] = '0.1' # depends
output_data = new_item
fout.write(json.dumps(output_data, ensure_ascii=False) + "\n")
fout.write(json.dumps(output_data, ensure_ascii=False) + '\n')

View File

@@ -15,7 +15,7 @@ def main():
'org': groups.group(1),
'repo': groups.group(2),
'number': groups.group(3),
'fix_patch': data['test_result']['git_patch']
'fix_patch': data['test_result']['git_patch'],
}
fout.write(json.dumps(patch) + '\n')

View File

@@ -390,7 +390,9 @@ class GitHubService(BaseGitService, GitService):
except Exception:
return []
async def get_repository_details_from_repo_name(self, repository: str) -> Repository:
async def get_repository_details_from_repo_name(
self, repository: str
) -> Repository:
url = f'{self.BASE_URL}/repos/{repository}'
repo, _ = await self._make_request(url)

View File

@@ -382,9 +382,10 @@ class GitLabService(BaseGitService, GitService):
except Exception:
return []
async def get_repository_details_from_repo_name(self, repository: str) -> Repository:
encoded_name = repository.replace("/", "%2F")
async def get_repository_details_from_repo_name(
self, repository: str
) -> Repository:
encoded_name = repository.replace('/', '%2F')
url = f'{self.BASE_URL}/projects/{encoded_name}'
repo, _ = await self._make_request(url)
@@ -398,8 +399,6 @@ class GitLabService(BaseGitService, GitService):
)
gitlab_service_cls = os.environ.get(
'OPENHANDS_GITLAB_SERVICE_CLS',
'openhands.integrations.gitlab.gitlab_service.GitLabService',

View File

@@ -1,4 +1,3 @@
import asyncio
import os
import tempfile
import threading
@@ -46,6 +45,7 @@ from openhands.runtime.utils.request import send_request
from openhands.utils.http_session import HttpSession
from openhands.utils.tenacity_stop import stop_if_should_exit
def _is_retryable_error(exception):
return isinstance(
exception, (httpx.RemoteProtocolError, httpcore.RemoteProtocolError)
@@ -358,7 +358,8 @@ class ActionExecutionClient(Runtime):
async def call_tool_mcp(self, action: MCPAction) -> Observation:
# Import here to avoid circular imports
from openhands.mcp.utils import create_mcp_clients, call_tool_mcp as call_tool_mcp_handler
from openhands.mcp.utils import call_tool_mcp as call_tool_mcp_handler
from openhands.mcp.utils import create_mcp_clients
# Get the updated MCP config
updated_mcp_config = self.get_updated_mcp_config()

View File

@@ -10,8 +10,8 @@ from openhands.events.event_store import EventStore
from openhands.server.config.server_config import ServerConfig
from openhands.server.monitoring import MonitoringListener
from openhands.server.session.conversation import Conversation
from openhands.storage.data_models.settings import Settings
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.settings import Settings
from openhands.storage.files import FileStore

View File

@@ -18,9 +18,9 @@ from openhands.server.monitoring import MonitoringListener
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import ROOM_KEY, Session
from openhands.storage.data_models.settings import Settings
from openhands.storage.conversation.conversation_store import ConversationStore
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
from openhands.storage.data_models.settings import Settings
from openhands.storage.files import FileStore
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync, wait_all
from openhands.utils.import_utils import get_impl

View File

@@ -14,7 +14,11 @@ from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderHandler,
)
from openhands.integrations.service_types import AuthenticationError, ProviderType, Repository, SuggestedTask
from openhands.integrations.service_types import (
AuthenticationError,
ProviderType,
SuggestedTask,
)
from openhands.runtime import get_runtime_cls
from openhands.server.data_models.conversation_info import ConversationInfo
from openhands.server.data_models.conversation_info_result_set import (
@@ -45,7 +49,6 @@ from openhands.storage.data_models.conversation_status import ConversationStatus
from openhands.utils.async_utils import wait_all
from openhands.utils.conversation_summary import generate_conversation_title
app = APIRouter(prefix='/api')
@@ -71,10 +74,13 @@ async def _create_new_conversation(
conversation_trigger: ConversationTrigger = ConversationTrigger.GUI,
attach_convo_id: bool = False,
):
logger.info(
'Creating conversation',
extra={'signal': 'create_conversation', 'user_id': user_id, 'trigger': conversation_trigger.value},
extra={
'signal': 'create_conversation',
'user_id': user_id,
'trigger': conversation_trigger.value,
},
)
logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
@@ -163,7 +169,7 @@ async def new_conversation(
data: InitSessionRequest,
user_id: str = Depends(get_user_id),
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
auth_type: AuthType | None = Depends(get_auth_type)
auth_type: AuthType | None = Depends(get_auth_type),
):
"""Initialize a new session or join an existing one.
@@ -202,7 +208,7 @@ async def new_conversation(
initial_user_msg=initial_user_msg,
image_urls=image_urls,
replay_json=replay_json,
conversation_trigger=conversation_trigger
conversation_trigger=conversation_trigger,
)
return JSONResponse(
@@ -233,7 +239,7 @@ async def new_conversation(
content={
'status': 'error',
'message': str(e),
'msg_id': 'STATUS$GIT_PROVIDER_AUTHENTICATION_ERROR'
'msg_id': 'STATUS$GIT_PROVIDER_AUTHENTICATION_ERROR',
},
status_code=status.HTTP_400_BAD_REQUEST,
)

View File

@@ -2,9 +2,8 @@ from typing import Any
from fastapi import APIRouter
from openhands.security.options import SecurityAnalyzers
from openhands.controller.agent import Agent
from openhands.security.options import SecurityAnalyzers
from openhands.server.shared import config, server_config
from openhands.utils.llm import get_supported_llm_models

View File

@@ -15,12 +15,12 @@ from openhands.server.settings import (
POSTSettingsModel,
)
from openhands.server.shared import config
from openhands.storage.data_models.settings import Settings
from openhands.server.user_auth import (
get_provider_tokens,
get_user_settings,
get_user_settings_store,
)
from openhands.storage.data_models.settings import Settings
from openhands.storage.settings.settings_store import SettingsStore
app = APIRouter(prefix='/api')
@@ -38,7 +38,7 @@ async def load_settings(
content={'error': 'Settings not found'},
)
provider_tokens_set: dict[ProviderType, str | None] = {}
provider_tokens_set: dict[ProviderType, str | None] = {}
if provider_tokens:
for provider_type, provider_token in provider_tokens.items():
if provider_token.token or provider_token.user_id:
@@ -227,8 +227,7 @@ async def store_provider_tokens(
if existing_settings:
if existing_settings.secrets_store:
existing_providers = [
provider
for provider in existing_settings.secrets_store.provider_tokens
provider for provider in existing_settings.secrets_store.provider_tokens
]
# Merge incoming settings store with the existing one
@@ -334,7 +333,11 @@ def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings
# Create new provider tokens immutably
if settings_with_token_data.provider_tokens:
settings = settings.model_copy(
update={'secrets_store': SecretStore(provider_tokens=settings_with_token_data.provider_tokens)}
update={
'secrets_store': SecretStore(
provider_tokens=settings_with_token_data.provider_tokens
)
}
)
return settings

View File

@@ -17,7 +17,6 @@ from openhands.events.action import ChangeAgentStateAction, MessageAction
from openhands.events.event import Event, EventSource
from openhands.events.stream import EventStream
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.mcp import add_mcp_tools_to_agent
from openhands.memory.memory import Memory
from openhands.microagent.microagent import BaseMicroagent
@@ -420,9 +419,7 @@ class AgentSession:
memory.load_user_workspace_microagents(microagents)
if selected_repository and repo_directory:
memory.set_repository_info(
selected_repository, repo_directory
)
memory.set_repository_info(selected_repository, repo_directory)
return memory
def _maybe_restore_state(self) -> State | None:

View File

@@ -21,8 +21,8 @@ from openhands.events.observation import (
CmdOutputObservation,
NullObservation,
)
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.agent import RecallObservation
from openhands.events.observation.error import ErrorObservation
from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
@@ -214,7 +214,8 @@ class Session:
await self.send(event_to_dict(event))
# NOTE: ipython observations are not sent here currently
elif event.source == EventSource.ENVIRONMENT and isinstance(
event, (CmdOutputObservation, AgentStateChangedObservation, RecallObservation)
event,
(CmdOutputObservation, AgentStateChangedObservation, RecallObservation),
):
# feedback from the environment to agent actions is understood as agent events by the UI
event_dict = event_to_dict(event)

View File

@@ -51,7 +51,6 @@ class DefaultUserAuth(UserAuth):
provider_tokens = getattr(secrets_store, 'provider_tokens', None)
return provider_tokens
@classmethod
async def get_instance(cls, request: Request) -> UserAuth:
user_auth = DefaultUserAuth()

View File

@@ -14,8 +14,8 @@ from openhands.utils.import_utils import get_impl
class AuthType(Enum):
COOKIE = "cookie"
BEARER = "bearer"
COOKIE = 'cookie'
BEARER = 'bearer'
class UserAuth(ABC):

View File

@@ -4,8 +4,8 @@ import json
from dataclasses import dataclass
from openhands.core.config.app_config import AppConfig
from openhands.storage.data_models.settings import Settings
from openhands.storage import get_file_store
from openhands.storage.data_models.settings import Settings
from openhands.storage.files import FileStore
from openhands.storage.settings.settings_store import SettingsStore
from openhands.utils.async_utils import call_sync_from_async

View File

@@ -56,7 +56,10 @@ def mock_github_token():
This eliminates the need for repeated patching in each test function.
"""
with patch('openhands.resolver.resolve_issue.identify_token', return_value=ProviderType.GITHUB) as patched:
with patch(
'openhands.resolver.resolve_issue.identify_token',
return_value=ProviderType.GITHUB,
) as patched:
yield patched
@@ -152,7 +155,9 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_github_toke
# Verify that the handler was correctly configured and called
resolver.issue_handler_factory.assert_called_once()
mock_handler.get_converted_issues.assert_called_once_with(issue_numbers=[5432], comment_id=None)
mock_handler.get_converted_issues.assert_called_once_with(
issue_numbers=[5432], comment_id=None
)
def test_download_issues_from_github():
@@ -348,9 +353,7 @@ async def test_complete_runtime(default_mock_args, mock_github_token):
# Create resolver with mocked token identification
resolver = IssueResolver(default_mock_args)
result = await resolver.complete_runtime(
mock_runtime, 'base_commit_hash'
)
result = await resolver.complete_runtime(mock_runtime, 'base_commit_hash')
assert result == {'git_patch': 'git diff content'}
assert mock_runtime.run_action.call_count == 5
@@ -358,7 +361,7 @@ async def test_complete_runtime(default_mock_args, mock_github_token):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_case",
'test_case',
[
{
'name': 'successful_run',
@@ -410,11 +413,20 @@ async def test_complete_runtime(default_mock_args, mock_github_token):
'expected_error': None,
'expected_explanation': 'Non-JSON explanation',
'is_pr': True,
'comment_success': [True, False], # To trigger the PR success logging code path
'comment_success': [
True,
False,
], # To trigger the PR success logging code path
},
],
)
async def test_process_issue(default_mock_args, mock_github_token, mock_output_dir, mock_prompt_template, test_case):
async def test_process_issue(
default_mock_args,
mock_github_token,
mock_output_dir,
mock_prompt_template,
test_case,
):
"""Test the process_issue method with different scenarios."""
# Set up test data
@@ -466,15 +478,16 @@ async def test_process_issue(default_mock_args, mock_github_token, mock_output_d
mock_run_controller.return_value = test_case['run_controller_return']
# Patch the necessary functions and methods
with patch('openhands.resolver.resolve_issue.create_runtime', mock_create_runtime), \
patch('openhands.resolver.resolve_issue.run_controller', mock_run_controller), \
patch.object(resolver, 'complete_runtime', return_value={'git_patch': 'test patch'}), \
patch.object(resolver, 'initialize_runtime') as mock_initialize_runtime:
with patch(
'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime
), patch(
'openhands.resolver.resolve_issue.run_controller', mock_run_controller
), patch.object(
resolver, 'complete_runtime', return_value={'git_patch': 'test patch'}
), patch.object(resolver, 'initialize_runtime') as mock_initialize_runtime:
# Call the process_issue method
result = await resolver.process_issue(issue, base_commit, handler_instance)
# Assert the result matches our expectations
assert isinstance(result, ResolverOutput)
assert result.issue == issue
@@ -496,10 +509,11 @@ async def test_process_issue(default_mock_args, mock_github_token, mock_output_d
# Check that the first positional argument is a config
assert 'config' in mock_run_controller.call_args[1]
# Check that initial_user_action is a MessageAction with the right content
assert isinstance(mock_run_controller.call_args[1]['initial_user_action'], MessageAction)
assert isinstance(
mock_run_controller.call_args[1]['initial_user_action'], MessageAction
)
assert mock_run_controller.call_args[1]['runtime'] == mock_runtime
# Assert that guess_success was called only for successful runs
if test_case['expected_success']:
handler_instance.guess_success.assert_called_once()

View File

@@ -19,7 +19,9 @@ from openhands.resolver.interfaces.issue_definitions import (
ServiceContextIssue,
ServiceContextPR,
)
from openhands.resolver.resolve_issue import IssueResolver, SandboxConfig, AppConfig, AgentConfig
from openhands.resolver.resolve_issue import (
IssueResolver,
)
from openhands.resolver.resolver_output import ResolverOutput
@@ -55,7 +57,10 @@ def mock_gitlab_token():
This eliminates the need for repeated patching in each test function.
"""
with patch('openhands.resolver.resolve_issue.identify_token', return_value=ProviderType.GITLAB) as patched:
with patch(
'openhands.resolver.resolve_issue.identify_token',
return_value=ProviderType.GITLAB,
) as patched:
yield patched
@@ -171,7 +176,9 @@ async def test_resolve_issue_no_issues_found(default_mock_args, mock_gitlab_toke
# Verify that the handler was correctly configured and called
resolver.issue_handler_factory.assert_called_once()
mock_handler.get_converted_issues.assert_called_once_with(issue_numbers=[5432], comment_id=None)
mock_handler.get_converted_issues.assert_called_once_with(
issue_numbers=[5432], comment_id=None
)
def test_download_issues_from_gitlab():
@@ -377,10 +384,12 @@ async def test_complete_runtime(default_mock_args, mock_gitlab_token):
content='',
command='git config --global --add safe.directory /workspace',
),
create_cmd_output(exit_code=0, content='', command='git add -A'),
create_cmd_output(
exit_code=0, content='', command='git add -A'
exit_code=0,
content='git diff content',
command='git diff --no-color --cached base_commit_hash',
),
create_cmd_output(exit_code=0, content='git diff content', command='git diff --no-color --cached base_commit_hash'),
]
# Create a resolver instance with mocked token identification
@@ -394,7 +403,7 @@ async def test_complete_runtime(default_mock_args, mock_gitlab_token):
@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_case",
'test_case',
[
{
'name': 'successful_run',
@@ -448,7 +457,13 @@ async def test_complete_runtime(default_mock_args, mock_gitlab_token):
},
],
)
async def test_process_issue(default_mock_args, mock_gitlab_token, mock_output_dir, mock_prompt_template, test_case):
async def test_process_issue(
default_mock_args,
mock_gitlab_token,
mock_output_dir,
mock_prompt_template,
test_case,
):
"""Test the process_issue method with different scenarios."""
# Set up test data
issue = Issue(
@@ -491,13 +506,15 @@ async def test_process_issue(default_mock_args, mock_gitlab_token, mock_output_d
mock_run_controller.return_value = test_case['run_controller_return']
# Patch the necessary functions and methods
with patch('openhands.resolver.resolve_issue.create_runtime', mock_create_runtime), \
patch('openhands.resolver.resolve_issue.run_controller', mock_run_controller), \
patch.object(resolver, 'complete_runtime', return_value={'git_patch': 'test patch'}), \
patch.object(resolver, 'initialize_runtime') as mock_initialize_runtime, \
patch('openhands.resolver.resolve_issue.SandboxConfig', return_value=MagicMock()), \
patch('openhands.resolver.resolve_issue.AppConfig', return_value=MagicMock()):
with patch(
'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime
), patch(
'openhands.resolver.resolve_issue.run_controller', mock_run_controller
), patch.object(
resolver, 'complete_runtime', return_value={'git_patch': 'test patch'}
), patch.object(resolver, 'initialize_runtime') as mock_initialize_runtime, patch(
'openhands.resolver.resolve_issue.SandboxConfig', return_value=MagicMock()
), patch('openhands.resolver.resolve_issue.AppConfig', return_value=MagicMock()):
# Call the process_issue method
result = await resolver.process_issue(issue, base_commit, handler_instance)
@@ -521,6 +538,7 @@ async def test_process_issue(default_mock_args, mock_gitlab_token, mock_output_d
else:
handler_instance.guess_success.assert_not_called()
def test_get_instruction(mock_prompt_template, mock_followup_prompt_template):
issue = Issue(
owner='test_owner',