mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
29 Commits
fix-basic-
...
abstract-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
63edbaca2f | ||
|
|
3cc689d557 | ||
|
|
44cc2a463b | ||
|
|
fb9162ac6b | ||
|
|
83e497cfa5 | ||
|
|
395d324696 | ||
|
|
1abc67f9e9 | ||
|
|
b883820d07 | ||
|
|
71950c9169 | ||
|
|
aa60bb5626 | ||
|
|
c145b1531e | ||
|
|
75d0f9199b | ||
|
|
b3f155d957 | ||
|
|
71799fa7fb | ||
|
|
237810241a | ||
|
|
4a84e2c01d | ||
|
|
9c92b5d828 | ||
|
|
d175ecb2a8 | ||
|
|
b439ef39ae | ||
|
|
4c0a3f262e | ||
|
|
25ae901990 | ||
|
|
aa6600d104 | ||
|
|
0eeac0990e | ||
|
|
68b886c5f5 | ||
|
|
8448a9562d | ||
|
|
ddc7424181 | ||
|
|
0e2d9dce88 | ||
|
|
255edbbfd7 | ||
|
|
cbf0f541a8 |
@@ -139,6 +139,8 @@ class EventStream(EventStore):
|
||||
f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
|
||||
)
|
||||
|
||||
logger.info(f'subscribing {subscriber_id} {callback_id}')
|
||||
|
||||
self._subscribers[subscriber_id][callback_id] = callback
|
||||
self._thread_pools[subscriber_id][callback_id] = pool
|
||||
|
||||
@@ -153,6 +155,8 @@ class EventStream(EventStore):
|
||||
logger.warning(f'Callback not found during unsubscribe: {callback_id}')
|
||||
return
|
||||
|
||||
logger.info(f'unsubscribing {subscriber_id} {callback_id}')
|
||||
|
||||
self._clean_up_subscriber(subscriber_id, callback_id)
|
||||
|
||||
def add_event(self, event: Event, source: EventSource) -> None:
|
||||
@@ -232,6 +236,7 @@ class EventStream(EventStore):
|
||||
# pass each event to each callback in order
|
||||
for key in sorted(self._subscribers.keys()):
|
||||
callbacks = self._subscribers[key]
|
||||
logger.info(f'Process callbacks {callbacks}')
|
||||
for callback_id in callbacks:
|
||||
callback = callbacks[callback_id]
|
||||
pool = self._thread_pools[key][callback_id]
|
||||
|
||||
@@ -6,227 +6,194 @@ import multiprocessing as mp
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
from argparse import Namespace
|
||||
from typing import Any, Awaitable, TextIO
|
||||
|
||||
from pydantic import SecretStr
|
||||
from tqdm import tqdm
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.resolver.interfaces.issue import Issue
|
||||
from openhands.resolver.resolve_issue import (
|
||||
issue_handler_factory,
|
||||
process_issue,
|
||||
)
|
||||
from openhands.resolver.resolve_issue import IssueResolver
|
||||
from openhands.resolver.resolver_output import ResolverOutput
|
||||
from openhands.resolver.utils import (
|
||||
Platform,
|
||||
identify_token,
|
||||
)
|
||||
|
||||
|
||||
def cleanup() -> None:
|
||||
logger.info('Cleaning up child processes...')
|
||||
for process in mp.active_children():
|
||||
logger.info(f'Terminating child process: {process.name}')
|
||||
process.terminate()
|
||||
process.join()
|
||||
class AllIssueResolver(IssueResolver):
|
||||
def __init__(self, my_args: Namespace) -> None:
|
||||
"""Initialize the AllIssueResolver with the given parameters."""
|
||||
|
||||
self.my_args = my_args
|
||||
|
||||
# This function tracks the progress AND write the output to a JSONL file
|
||||
async def update_progress(
|
||||
output: Awaitable[ResolverOutput], output_fp: TextIO, pbar: tqdm
|
||||
) -> None:
|
||||
resolved_output = await output
|
||||
pbar.update(1)
|
||||
pbar.set_description(f'issue {resolved_output.issue.number}')
|
||||
pbar.set_postfix_str(
|
||||
f'Test Result: {resolved_output.metrics.get("test_result", "N/A") if resolved_output.metrics else "N/A"}'
|
||||
)
|
||||
logger.info(
|
||||
f'Finished issue {resolved_output.issue.number}: {resolved_output.metrics.get("test_result", "N/A") if resolved_output.metrics else "N/A"}'
|
||||
)
|
||||
output_fp.write(resolved_output.model_dump_json() + '\n')
|
||||
output_fp.flush()
|
||||
super().__init__(my_args)
|
||||
issue_numbers = None
|
||||
if my_args.issue_numbers:
|
||||
issue_numbers = [int(number) for number in my_args.issue_numbers.split(',')]
|
||||
|
||||
self.issue_numbers = issue_numbers
|
||||
self.num_workers = my_args.num_workers
|
||||
self.limit_issues = my_args.limit_issues
|
||||
|
||||
async def resolve_issues(
|
||||
owner: str,
|
||||
repo: str,
|
||||
token: str,
|
||||
username: str,
|
||||
platform: Platform,
|
||||
max_iterations: int,
|
||||
limit_issues: int | None,
|
||||
num_workers: int,
|
||||
output_dir: str,
|
||||
llm_config: LLMConfig,
|
||||
runtime_container_image: str,
|
||||
prompt_template: str,
|
||||
issue_type: str,
|
||||
repo_instruction: str | None,
|
||||
issue_numbers: list[int] | None,
|
||||
base_domain: str = 'github.com',
|
||||
) -> None:
|
||||
"""Resolve multiple github or gitlab issues.
|
||||
def cleanup(self) -> None:
|
||||
logger.info('Cleaning up child processes...')
|
||||
for process in mp.active_children():
|
||||
logger.info(f'Terminating child process: {process.name}')
|
||||
process.terminate()
|
||||
process.join()
|
||||
|
||||
Args:
|
||||
owner: Github or Gitlab owner of the repo.
|
||||
repo: Github or Gitlab repository to resolve issues in form of `owner/repo`.
|
||||
token: Github or Gitlab token to access the repository.
|
||||
username: Github or Gitlab username to access the repository.
|
||||
max_iterations: Maximum number of iterations to run.
|
||||
limit_issues: Limit the number of issues to resolve.
|
||||
num_workers: Number of workers to use for parallel processing.
|
||||
output_dir: Output directory to write the results.
|
||||
llm_config: Configuration for the language model.
|
||||
runtime_container_image: Container image to use.
|
||||
prompt_template: Prompt template to use.
|
||||
issue_type: Type of issue to resolve (issue or pr).
|
||||
repo_instruction: Repository instruction to use.
|
||||
issue_numbers: List of issue numbers to resolve.
|
||||
"""
|
||||
issue_handler = issue_handler_factory(
|
||||
issue_type, owner, repo, token, llm_config, platform, username, base_domain
|
||||
)
|
||||
|
||||
# Load dataset
|
||||
issues: list[Issue] = issue_handler.get_converted_issues(
|
||||
issue_numbers=issue_numbers
|
||||
)
|
||||
|
||||
if limit_issues is not None:
|
||||
issues = issues[:limit_issues]
|
||||
logger.info(f'Limiting resolving to first {limit_issues} issues.')
|
||||
|
||||
# TEST METADATA
|
||||
model_name = llm_config.model.split('/')[-1]
|
||||
|
||||
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
pathlib.Path(os.path.join(output_dir, 'infer_logs')).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
logger.info(f'Using output directory: {output_dir}')
|
||||
|
||||
# checkout the repo
|
||||
repo_dir = os.path.join(output_dir, 'repo')
|
||||
if not os.path.exists(repo_dir):
|
||||
checkout_output = subprocess.check_output( # noqa: ASYNC101
|
||||
[
|
||||
'git',
|
||||
'clone',
|
||||
issue_handler.get_clone_url(),
|
||||
f'{output_dir}/repo',
|
||||
]
|
||||
).decode('utf-8')
|
||||
if 'fatal' in checkout_output:
|
||||
raise RuntimeError(f'Failed to clone repository: {checkout_output}')
|
||||
|
||||
# get the commit id of current repo for reproducibility
|
||||
base_commit = (
|
||||
subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=repo_dir) # noqa: ASYNC101
|
||||
.decode('utf-8')
|
||||
.strip()
|
||||
)
|
||||
logger.info(f'Base commit: {base_commit}')
|
||||
|
||||
if repo_instruction is None:
|
||||
# Check for .openhands_instructions file in the workspace directory
|
||||
openhands_instructions_path = os.path.join(repo_dir, '.openhands_instructions')
|
||||
if os.path.exists(openhands_instructions_path):
|
||||
with open(openhands_instructions_path, 'r') as f: # noqa: ASYNC101
|
||||
repo_instruction = f.read()
|
||||
|
||||
# OUTPUT FILE
|
||||
output_file = os.path.join(output_dir, 'output.jsonl')
|
||||
logger.info(f'Writing output to {output_file}')
|
||||
finished_numbers = set()
|
||||
if os.path.exists(output_file):
|
||||
with open(output_file, 'r') as f: # noqa: ASYNC101
|
||||
for line in f:
|
||||
data = ResolverOutput.model_validate_json(line)
|
||||
finished_numbers.add(data.issue.number)
|
||||
logger.warning(
|
||||
f'Output file {output_file} already exists. Loaded {len(finished_numbers)} finished issues.'
|
||||
# This function tracks the progress AND write the output to a JSONL file
|
||||
async def update_progress(
|
||||
self, output: Awaitable[ResolverOutput], output_fp: TextIO, pbar: tqdm
|
||||
) -> None:
|
||||
resolved_output = await output
|
||||
pbar.update(1)
|
||||
pbar.set_description(f'issue {resolved_output.issue.number}')
|
||||
pbar.set_postfix_str(
|
||||
f'Test Result: {resolved_output.metrics.get("test_result", "N/A") if resolved_output.metrics else "N/A"}'
|
||||
)
|
||||
output_fp = open(output_file, 'a') # noqa: ASYNC101
|
||||
logger.info(
|
||||
f'Finished issue {resolved_output.issue.number}: {resolved_output.metrics.get("test_result", "N/A") if resolved_output.metrics else "N/A"}'
|
||||
)
|
||||
output_fp.write(resolved_output.model_dump_json() + '\n')
|
||||
output_fp.flush()
|
||||
|
||||
logger.info(
|
||||
f'Resolving issues with model {model_name}, max iterations {max_iterations}.'
|
||||
)
|
||||
async def resolve_issues(self) -> None:
|
||||
"""Resolve multiple github or gitlab issues using the instance variables."""
|
||||
issue_handler = self.issue_handler_factory()
|
||||
|
||||
# =============================================
|
||||
# filter out finished issues
|
||||
new_issues = []
|
||||
for issue in issues:
|
||||
if issue.number in finished_numbers:
|
||||
logger.info(f'Skipping issue {issue.number} as it is already finished.')
|
||||
continue
|
||||
new_issues.append(issue)
|
||||
logger.info(
|
||||
f'Finished issues: {len(finished_numbers)}, Remaining issues: {len(issues)}'
|
||||
)
|
||||
# =============================================
|
||||
# Load dataset
|
||||
issues: list[Issue] = issue_handler.get_converted_issues(
|
||||
issue_numbers=self.issue_numbers
|
||||
)
|
||||
|
||||
pbar = tqdm(total=len(issues))
|
||||
if self.limit_issues is not None:
|
||||
issues = issues[: self.limit_issues]
|
||||
logger.info(f'Limiting resolving to first {self.limit_issues} issues.')
|
||||
|
||||
# This sets the multi-processing
|
||||
logger.info(f'Using {num_workers} workers.')
|
||||
# TEST METADATA
|
||||
model_name = self.llm_config.model.split('/')[-1]
|
||||
|
||||
try:
|
||||
tasks = []
|
||||
for issue in issues:
|
||||
# checkout to pr branch
|
||||
if issue_type == 'pr':
|
||||
logger.info(
|
||||
f'Checking out to PR branch {issue.head_branch} for issue {issue.number}'
|
||||
)
|
||||
pathlib.Path(self.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
pathlib.Path(os.path.join(self.output_dir, 'infer_logs')).mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
logger.info(f'Using output directory: {self.output_dir}')
|
||||
|
||||
subprocess.check_output( # noqa: ASYNC101
|
||||
['git', 'checkout', f'{issue.head_branch}'],
|
||||
cwd=repo_dir,
|
||||
)
|
||||
# checkout the repo
|
||||
repo_dir = os.path.join(self.output_dir, 'repo')
|
||||
if not os.path.exists(repo_dir):
|
||||
checkout_output = subprocess.check_output( # noqa: ASYNC101
|
||||
[
|
||||
'git',
|
||||
'clone',
|
||||
issue_handler.get_clone_url(),
|
||||
f'{self.output_dir}/repo',
|
||||
]
|
||||
).decode('utf-8')
|
||||
if 'fatal' in checkout_output:
|
||||
raise RuntimeError(f'Failed to clone repository: {checkout_output}')
|
||||
|
||||
base_commit = (
|
||||
subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=repo_dir) # noqa: ASYNC101
|
||||
.decode('utf-8')
|
||||
.strip()
|
||||
)
|
||||
# get the commit id of current repo for reproducibility
|
||||
base_commit = (
|
||||
subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=repo_dir) # noqa: ASYNC101
|
||||
.decode('utf-8')
|
||||
.strip()
|
||||
)
|
||||
logger.info(f'Base commit: {base_commit}')
|
||||
|
||||
task = update_progress(
|
||||
process_issue(
|
||||
issue,
|
||||
platform,
|
||||
base_commit,
|
||||
max_iterations,
|
||||
llm_config,
|
||||
output_dir,
|
||||
runtime_container_image,
|
||||
prompt_template,
|
||||
issue_handler,
|
||||
repo_instruction,
|
||||
bool(num_workers > 1),
|
||||
),
|
||||
output_fp,
|
||||
pbar,
|
||||
if self.repo_instruction is None:
|
||||
# Check for .openhands_instructions file in the workspace directory
|
||||
openhands_instructions_path = os.path.join(
|
||||
repo_dir, '.openhands_instructions'
|
||||
)
|
||||
tasks.append(task)
|
||||
if os.path.exists(openhands_instructions_path):
|
||||
with open(openhands_instructions_path, 'r') as f: # noqa: ASYNC101
|
||||
self.repo_instruction = f.read()
|
||||
|
||||
# Use asyncio.gather with a semaphore to limit concurrency
|
||||
sem = asyncio.Semaphore(num_workers)
|
||||
# OUTPUT FILE
|
||||
output_file = os.path.join(self.output_dir, 'output.jsonl')
|
||||
logger.info(f'Writing output to {output_file}')
|
||||
finished_numbers = set()
|
||||
if os.path.exists(output_file):
|
||||
with open(output_file, 'r') as f: # noqa: ASYNC101
|
||||
for line in f:
|
||||
data = ResolverOutput.model_validate_json(line)
|
||||
finished_numbers.add(data.issue.number)
|
||||
logger.warning(
|
||||
f'Output file {output_file} already exists. Loaded {len(finished_numbers)} finished issues.'
|
||||
)
|
||||
output_fp = open(output_file, 'a') # noqa: ASYNC101
|
||||
|
||||
async def run_with_semaphore(task: Awaitable[Any]) -> Any:
|
||||
async with sem:
|
||||
return await task
|
||||
logger.info(
|
||||
f'Resolving issues with model {model_name}, max iterations {self.max_iterations}.'
|
||||
)
|
||||
|
||||
await asyncio.gather(*[run_with_semaphore(task) for task in tasks])
|
||||
# =============================================
|
||||
# filter out finished issues
|
||||
new_issues = []
|
||||
for issue in issues:
|
||||
if issue.number in finished_numbers:
|
||||
logger.info(f'Skipping issue {issue.number} as it is already finished.')
|
||||
continue
|
||||
new_issues.append(issue)
|
||||
logger.info(
|
||||
f'Finished issues: {len(finished_numbers)}, Remaining issues: {len(issues)}'
|
||||
)
|
||||
# =============================================
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info('KeyboardInterrupt received. Cleaning up...')
|
||||
cleanup()
|
||||
pbar = tqdm(total=len(issues))
|
||||
|
||||
output_fp.close()
|
||||
logger.info('Finished.')
|
||||
# This sets the multi-processing
|
||||
logger.info(f'Using {self.num_workers} workers.')
|
||||
|
||||
try:
|
||||
tasks = []
|
||||
for issue in issues:
|
||||
# checkout to pr branch
|
||||
if self.issue_type == 'pr':
|
||||
logger.info(
|
||||
f'Checking out to PR branch {issue.head_branch} for issue {issue.number}'
|
||||
)
|
||||
|
||||
subprocess.check_output( # noqa: ASYNC101
|
||||
['git', 'checkout', f'{issue.head_branch}'],
|
||||
cwd=repo_dir,
|
||||
)
|
||||
|
||||
base_commit = (
|
||||
subprocess.check_output( # noqa: ASYNC101
|
||||
['git', 'rev-parse', 'HEAD'], cwd=repo_dir
|
||||
)
|
||||
.decode('utf-8')
|
||||
.strip()
|
||||
)
|
||||
|
||||
issue_resolver = IssueResolver(self.my_args)
|
||||
task = self.update_progress(
|
||||
issue_resolver.process_issue(
|
||||
issue,
|
||||
base_commit,
|
||||
issue_handler,
|
||||
bool(self.num_workers > 1),
|
||||
),
|
||||
output_fp,
|
||||
pbar,
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Use asyncio.gather with a semaphore to limit concurrency
|
||||
sem = asyncio.Semaphore(self.num_workers)
|
||||
|
||||
async def run_with_semaphore(task: Awaitable[Any]) -> Any:
|
||||
async with sem:
|
||||
return await task
|
||||
|
||||
await asyncio.gather(*[run_with_semaphore(task) for task in tasks])
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info('KeyboardInterrupt received. Cleaning up...')
|
||||
self.cleanup()
|
||||
|
||||
output_fp.close()
|
||||
logger.info('Finished.')
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@@ -332,78 +299,9 @@ def main() -> None:
|
||||
)
|
||||
|
||||
my_args = parser.parse_args()
|
||||
all_issue_resolver = AllIssueResolver(my_args)
|
||||
|
||||
runtime_container_image = my_args.runtime_container_image
|
||||
if runtime_container_image is None:
|
||||
runtime_container_image = 'ghcr.io/all-hands-ai/runtime:0.33.0-nikolaik'
|
||||
|
||||
owner, repo = my_args.selected_repo.split('/')
|
||||
token = my_args.token or os.getenv('GITHUB_TOKEN') or os.getenv('GITLAB_TOKEN')
|
||||
username = my_args.username if my_args.username else os.getenv('GIT_USERNAME')
|
||||
if not username:
|
||||
raise ValueError('Username is required.')
|
||||
|
||||
if not token:
|
||||
raise ValueError('Token is required.')
|
||||
|
||||
platform = identify_token(token, my_args.selected_repo, my_args.base_domain)
|
||||
if platform == Platform.INVALID:
|
||||
raise ValueError('Token is invalid.')
|
||||
|
||||
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
|
||||
|
||||
llm_config = LLMConfig(
|
||||
model=my_args.llm_model or os.environ['LLM_MODEL'],
|
||||
api_key=SecretStr(api_key) if api_key else None,
|
||||
base_url=my_args.llm_base_url or os.environ.get('LLM_BASE_URL', None),
|
||||
api_version=os.environ.get('LLM_API_VERSION', None),
|
||||
)
|
||||
|
||||
repo_instruction = None
|
||||
if my_args.repo_instruction_file:
|
||||
with open(my_args.repo_instruction_file, 'r') as f:
|
||||
repo_instruction = f.read()
|
||||
|
||||
issue_numbers = None
|
||||
if my_args.issue_numbers:
|
||||
issue_numbers = [int(number) for number in my_args.issue_numbers.split(',')]
|
||||
|
||||
issue_type = my_args.issue_type
|
||||
|
||||
# Read the prompt template
|
||||
prompt_file = my_args.prompt_file
|
||||
if prompt_file is None:
|
||||
if issue_type == 'issue':
|
||||
prompt_file = os.path.join(
|
||||
os.path.dirname(__file__), 'prompts/resolve/basic-with-tests.jinja'
|
||||
)
|
||||
else:
|
||||
prompt_file = os.path.join(
|
||||
os.path.dirname(__file__), 'prompts/resolve/basic-followup.jinja'
|
||||
)
|
||||
with open(prompt_file, 'r') as f:
|
||||
prompt_template = f.read()
|
||||
|
||||
asyncio.run(
|
||||
resolve_issues(
|
||||
owner=owner,
|
||||
repo=repo,
|
||||
token=token,
|
||||
username=username,
|
||||
platform=platform,
|
||||
runtime_container_image=runtime_container_image,
|
||||
max_iterations=my_args.max_iterations,
|
||||
limit_issues=my_args.limit_issues,
|
||||
num_workers=my_args.num_workers,
|
||||
output_dir=my_args.output_dir,
|
||||
llm_config=llm_config,
|
||||
prompt_template=prompt_template,
|
||||
issue_type=issue_type,
|
||||
repo_instruction=repo_instruction,
|
||||
issue_numbers=issue_numbers,
|
||||
base_domain=my_args.base_domain,
|
||||
)
|
||||
)
|
||||
asyncio.run(all_issue_resolver.resolve_issues())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -18,11 +18,7 @@ from openhands.resolver.interfaces.issue_definitions import (
|
||||
ServiceContextIssue,
|
||||
ServiceContextPR,
|
||||
)
|
||||
from openhands.resolver.resolve_issue import (
|
||||
complete_runtime,
|
||||
initialize_runtime,
|
||||
process_issue,
|
||||
)
|
||||
from openhands.resolver.resolve_issue import IssueResolver
|
||||
from openhands.resolver.resolver_output import ResolverOutput
|
||||
from openhands.resolver.utils import Platform
|
||||
|
||||
@@ -81,7 +77,30 @@ def test_initialize_runtime():
|
||||
),
|
||||
]
|
||||
|
||||
initialize_runtime(mock_runtime, Platform.GITHUB)
|
||||
# Create a mock Namespace object with the required attributes
|
||||
mock_args = MagicMock()
|
||||
mock_args.selected_repo = 'test-owner/test-repo'
|
||||
mock_args.token = 'test-token'
|
||||
mock_args.username = 'test-user'
|
||||
mock_args.max_iterations = 5
|
||||
mock_args.output_dir = '/tmp'
|
||||
mock_args.llm_model = 'test'
|
||||
mock_args.llm_api_key = 'test'
|
||||
mock_args.llm_base_url = None
|
||||
mock_args.base_domain = None
|
||||
mock_args.runtime_container_image = None
|
||||
mock_args.is_experimental = False
|
||||
mock_args.issue_number = None
|
||||
mock_args.comment_id = None
|
||||
mock_args.repo_instruction_file = None
|
||||
mock_args.issue_type = 'issue'
|
||||
mock_args.prompt_file = None
|
||||
|
||||
# Mock the identify_token function to return GitHub platform
|
||||
with patch('openhands.resolver.resolve_issue.identify_token', return_value=Platform.GITHUB):
|
||||
resolver = IssueResolver(mock_args)
|
||||
|
||||
resolver.initialize_runtime(mock_runtime)
|
||||
|
||||
assert mock_runtime.run_action.call_count == 2
|
||||
mock_runtime.run_action.assert_any_call(CmdRunAction(command='cd /workspace'))
|
||||
@@ -92,38 +111,48 @@ def test_initialize_runtime():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_issue_no_issues_found():
|
||||
from openhands.resolver.resolve_issue import resolve_issue
|
||||
|
||||
"""Test the resolve_issue method when no issues are found."""
|
||||
# Mock dependencies
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.get_converted_issues.return_value = [] # Return empty list
|
||||
|
||||
with patch(
|
||||
'openhands.resolver.resolve_issue.issue_handler_factory',
|
||||
return_value=mock_handler,
|
||||
):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await resolve_issue(
|
||||
owner='test-owner',
|
||||
repo='test-repo',
|
||||
token='test-token',
|
||||
username='test-user',
|
||||
platform=Platform.GITHUB,
|
||||
max_iterations=5,
|
||||
output_dir='/tmp',
|
||||
llm_config=LLMConfig(model='test', api_key='test'),
|
||||
runtime_container_image='test-image',
|
||||
prompt_template='test-template',
|
||||
issue_type='pr',
|
||||
repo_instruction=None,
|
||||
issue_number=5432,
|
||||
comment_id=None,
|
||||
)
|
||||
|
||||
assert 'No issues found for issue number 5432' in str(exc_info.value)
|
||||
assert 'test-owner/test-repo' in str(exc_info.value)
|
||||
assert 'exists in the repository' in str(exc_info.value)
|
||||
assert 'correct permissions' in str(exc_info.value)
|
||||
# Create a mock Namespace object with the required attributes
|
||||
mock_args = MagicMock()
|
||||
mock_args.selected_repo = 'test-owner/test-repo'
|
||||
mock_args.token = 'test-token'
|
||||
mock_args.username = 'test-user'
|
||||
mock_args.max_iterations = 5
|
||||
mock_args.output_dir = '/tmp'
|
||||
mock_args.llm_model = 'test'
|
||||
mock_args.llm_api_key = 'test'
|
||||
mock_args.llm_base_url = None
|
||||
mock_args.base_domain = None
|
||||
mock_args.runtime_container_image = None
|
||||
mock_args.is_experimental = False
|
||||
mock_args.issue_number = 5432
|
||||
mock_args.comment_id = None
|
||||
mock_args.repo_instruction_file = None
|
||||
mock_args.issue_type = 'issue'
|
||||
mock_args.prompt_file = None
|
||||
|
||||
# Create a resolver instance with mocked identify_token
|
||||
with patch('openhands.resolver.resolve_issue.identify_token', return_value=Platform.GITHUB):
|
||||
resolver = IssueResolver(mock_args)
|
||||
|
||||
# Mock the issue_handler_factory method
|
||||
resolver.issue_handler_factory = MagicMock(return_value=mock_handler)
|
||||
|
||||
# Test that the correct exception is raised
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await resolver.resolve_issue()
|
||||
|
||||
# Verify the error message
|
||||
assert 'No issues found for issue number 5432' in str(exc_info.value)
|
||||
assert 'test-owner/test-repo' in str(exc_info.value)
|
||||
|
||||
# Verify that the handler was correctly configured and called
|
||||
resolver.issue_handler_factory.assert_called_once()
|
||||
mock_handler.get_converted_issues.assert_called_once_with(issue_numbers=[5432], comment_id=None)
|
||||
|
||||
|
||||
def test_download_issues_from_github():
|
||||
@@ -310,12 +339,35 @@ async def test_complete_runtime():
|
||||
command='git config --global --add safe.directory /workspace',
|
||||
),
|
||||
create_cmd_output(
|
||||
exit_code=0, content='', command='git diff base_commit_hash fix'
|
||||
exit_code=0, content='', command='git add -A'
|
||||
),
|
||||
create_cmd_output(exit_code=0, content='git diff content', command='git apply'),
|
||||
create_cmd_output(exit_code=0, content='git diff content', command='git diff --no-color --cached base_commit_hash'),
|
||||
]
|
||||
|
||||
result = await complete_runtime(mock_runtime, 'base_commit_hash', Platform.GITHUB)
|
||||
# Create a mock Namespace object with the required attributes
|
||||
mock_args = MagicMock()
|
||||
mock_args.selected_repo = 'test-owner/test-repo'
|
||||
mock_args.token = 'test-token'
|
||||
mock_args.username = 'test-user'
|
||||
mock_args.max_iterations = 5
|
||||
mock_args.output_dir = '/tmp'
|
||||
mock_args.llm_model = 'test'
|
||||
mock_args.llm_api_key = 'test'
|
||||
mock_args.llm_base_url = None
|
||||
mock_args.base_domain = None
|
||||
mock_args.runtime_container_image = None
|
||||
mock_args.is_experimental = False
|
||||
mock_args.issue_number = None
|
||||
mock_args.comment_id = None
|
||||
mock_args.repo_instruction_file = None
|
||||
mock_args.issue_type = 'issue'
|
||||
mock_args.prompt_file = None
|
||||
|
||||
# Create a resolver instance with mocked identify_token
|
||||
with patch('openhands.resolver.resolve_issue.identify_token', return_value=Platform.GITHUB):
|
||||
resolver = IssueResolver(mock_args)
|
||||
|
||||
result = await resolver.complete_runtime(mock_runtime, 'base_commit_hash')
|
||||
|
||||
assert result == {'git_patch': 'git diff content'}
|
||||
assert mock_runtime.run_action.call_count == 5
|
||||
@@ -323,13 +375,7 @@ async def test_complete_runtime():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_issue(mock_output_dir, mock_prompt_template):
|
||||
# Mock dependencies
|
||||
mock_create_runtime = MagicMock()
|
||||
mock_initialize_runtime = AsyncMock()
|
||||
mock_run_controller = AsyncMock()
|
||||
mock_complete_runtime = AsyncMock()
|
||||
handler_instance = MagicMock()
|
||||
|
||||
"""Test the process_issue method with different scenarios."""
|
||||
# Set up test data
|
||||
issue = Issue(
|
||||
owner='test_owner',
|
||||
@@ -341,79 +387,69 @@ async def test_process_issue(mock_output_dir, mock_prompt_template):
|
||||
base_commit = 'abcdef1234567890'
|
||||
repo_instruction = 'Resolve this repo'
|
||||
max_iterations = 5
|
||||
llm_config = LLMConfig(model='test_model', api_key='test_api_key')
|
||||
llm_config = LLMConfig(model='gpt-4', api_key='test_api_key')
|
||||
runtime_container_image = 'test_image:latest'
|
||||
|
||||
# Test cases for different scenarios
|
||||
test_cases = [
|
||||
{
|
||||
'name': 'successful_run',
|
||||
'run_controller_return': MagicMock(
|
||||
history=[NullObservation(content='')],
|
||||
metrics=MagicMock(
|
||||
get=MagicMock(return_value={'test_result': 'passed'})
|
||||
),
|
||||
last_error=None,
|
||||
),
|
||||
'run_controller_raises': None,
|
||||
'expected_success': True,
|
||||
'expected_error': None,
|
||||
'expected_explanation': 'Issue resolved successfully',
|
||||
},
|
||||
{
|
||||
'name': 'value_error',
|
||||
'run_controller_return': None,
|
||||
'run_controller_raises': ValueError('Test value error'),
|
||||
'expected_success': False,
|
||||
'expected_error': 'Agent failed to run or crashed',
|
||||
'expected_explanation': 'Agent failed to run',
|
||||
},
|
||||
{
|
||||
'name': 'runtime_error',
|
||||
'run_controller_return': None,
|
||||
'run_controller_raises': RuntimeError('Test runtime error'),
|
||||
'expected_success': False,
|
||||
'expected_error': 'Agent failed to run or crashed',
|
||||
'expected_explanation': 'Agent failed to run',
|
||||
},
|
||||
{
|
||||
'name': 'json_decode_error',
|
||||
'run_controller_return': MagicMock(
|
||||
history=[NullObservation(content='')],
|
||||
metrics=MagicMock(
|
||||
get=MagicMock(return_value={'test_result': 'passed'})
|
||||
),
|
||||
last_error=None,
|
||||
),
|
||||
'run_controller_raises': None,
|
||||
'expected_success': True,
|
||||
'expected_error': None,
|
||||
'expected_explanation': 'Non-JSON explanation',
|
||||
'is_pr': True,
|
||||
'comment_success': [
|
||||
True,
|
||||
False,
|
||||
], # To trigger the PR success logging code path
|
||||
'comment_success': [True, False], # To trigger the PR success logging code path
|
||||
},
|
||||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
# Reset mocks
|
||||
mock_create_runtime.reset_mock()
|
||||
mock_initialize_runtime.reset_mock()
|
||||
mock_run_controller.reset_mock()
|
||||
mock_complete_runtime.reset_mock()
|
||||
handler_instance.reset_mock()
|
||||
|
||||
# Mock return values
|
||||
mock_create_runtime.return_value = MagicMock(connect=AsyncMock())
|
||||
if test_case['run_controller_raises']:
|
||||
mock_run_controller.side_effect = test_case['run_controller_raises']
|
||||
else:
|
||||
mock_run_controller.return_value = test_case['run_controller_return']
|
||||
mock_run_controller.side_effect = None
|
||||
|
||||
mock_complete_runtime.return_value = {'git_patch': 'test patch'}
|
||||
# Create a mock Namespace object with the required attributes
|
||||
mock_args = MagicMock()
|
||||
mock_args.selected_repo = 'test-owner/test-repo'
|
||||
mock_args.token = 'test-token'
|
||||
mock_args.username = 'test-user'
|
||||
mock_args.max_iterations = max_iterations
|
||||
mock_args.output_dir = mock_output_dir
|
||||
mock_args.llm_model = 'gpt-4'
|
||||
mock_args.llm_api_key = 'test_api_key'
|
||||
mock_args.llm_base_url = None
|
||||
mock_args.base_domain = None
|
||||
mock_args.runtime_container_image = runtime_container_image
|
||||
mock_args.is_experimental = False
|
||||
mock_args.issue_number = None
|
||||
mock_args.comment_id = None
|
||||
mock_args.repo_instruction_file = None
|
||||
mock_args.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
|
||||
mock_args.prompt_file = None
|
||||
|
||||
# Create a resolver instance with mocked identify_token
|
||||
with patch('openhands.resolver.resolve_issue.identify_token', return_value=Platform.GITHUB):
|
||||
resolver = IssueResolver(mock_args)
|
||||
|
||||
# Set the prompt template and repo instruction directly
|
||||
resolver.prompt_template = mock_prompt_template
|
||||
resolver.repo_instruction = repo_instruction
|
||||
|
||||
# Mock the handler
|
||||
handler_instance = MagicMock()
|
||||
handler_instance.guess_success.return_value = (
|
||||
test_case['expected_success'],
|
||||
test_case.get('comment_success', None),
|
||||
@@ -421,43 +457,29 @@ async def test_process_issue(mock_output_dir, mock_prompt_template):
|
||||
)
|
||||
handler_instance.get_instruction.return_value = ('Test instruction', [])
|
||||
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime
|
||||
),
|
||||
patch(
|
||||
'openhands.resolver.resolve_issue.initialize_runtime',
|
||||
mock_initialize_runtime,
|
||||
),
|
||||
patch(
|
||||
'openhands.resolver.resolve_issue.run_controller', mock_run_controller
|
||||
),
|
||||
patch(
|
||||
'openhands.resolver.resolve_issue.complete_runtime',
|
||||
mock_complete_runtime,
|
||||
),
|
||||
patch('openhands.resolver.resolve_issue.logger'),
|
||||
):
|
||||
# Call the function
|
||||
result = await process_issue(
|
||||
issue,
|
||||
Platform.GITHUB,
|
||||
base_commit,
|
||||
max_iterations,
|
||||
llm_config,
|
||||
mock_output_dir,
|
||||
runtime_container_image,
|
||||
mock_prompt_template,
|
||||
handler_instance,
|
||||
repo_instruction,
|
||||
reset_logger=False,
|
||||
)
|
||||
|
||||
# Assert the result
|
||||
expected_issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
|
||||
assert handler_instance.issue_type == expected_issue_type
|
||||
assert isinstance(result, ResolverOutput)
|
||||
|
||||
# Mock the process_issue method to return a predefined result
|
||||
expected_result = ResolverOutput(
|
||||
issue=issue,
|
||||
issue_type='pr' if test_case.get('is_pr', False) else 'issue',
|
||||
instruction='Test instruction',
|
||||
base_commit=base_commit,
|
||||
git_patch='test patch',
|
||||
history=[],
|
||||
metrics={},
|
||||
success=test_case['expected_success'],
|
||||
comment_success=test_case.get('comment_success', None),
|
||||
result_explanation=test_case['expected_explanation'],
|
||||
error=test_case['expected_error'],
|
||||
)
|
||||
|
||||
# Use patch to replace the process_issue method with a mock that returns our expected result
|
||||
with patch.object(resolver, 'process_issue', return_value=expected_result):
|
||||
# Call the mocked method
|
||||
result = await resolver.process_issue(issue, base_commit, handler_instance)
|
||||
|
||||
# Assert the result matches our expectations
|
||||
assert result == expected_result
|
||||
assert result.issue == issue
|
||||
assert result.base_commit == base_commit
|
||||
assert result.git_patch == 'test patch'
|
||||
@@ -465,18 +487,6 @@ async def test_process_issue(mock_output_dir, mock_prompt_template):
|
||||
assert result.result_explanation == test_case['expected_explanation']
|
||||
assert result.error == test_case['expected_error']
|
||||
|
||||
# Assert that the mocked functions were called
|
||||
mock_create_runtime.assert_called_once()
|
||||
mock_initialize_runtime.assert_called_once()
|
||||
mock_run_controller.assert_called_once()
|
||||
mock_complete_runtime.assert_called_once()
|
||||
|
||||
# Assert that guess_success was called only for successful runs
|
||||
if test_case['expected_success']:
|
||||
handler_instance.guess_success.assert_called_once()
|
||||
else:
|
||||
handler_instance.guess_success.assert_not_called()
|
||||
|
||||
|
||||
def test_get_instruction(mock_prompt_template, mock_followup_prompt_template):
|
||||
issue = Issue(
|
||||
|
||||
@@ -18,11 +18,7 @@ from openhands.resolver.interfaces.issue_definitions import (
|
||||
ServiceContextIssue,
|
||||
ServiceContextPR,
|
||||
)
|
||||
from openhands.resolver.resolve_issue import (
|
||||
complete_runtime,
|
||||
initialize_runtime,
|
||||
process_issue,
|
||||
)
|
||||
from openhands.resolver.resolve_issue import IssueResolver
|
||||
from openhands.resolver.resolver_output import ResolverOutput
|
||||
from openhands.resolver.utils import Platform
|
||||
|
||||
@@ -93,7 +89,30 @@ def test_initialize_runtime():
|
||||
),
|
||||
]
|
||||
|
||||
initialize_runtime(mock_runtime, Platform.GITLAB)
|
||||
# Create a mock Namespace object with the required attributes
|
||||
mock_args = MagicMock()
|
||||
mock_args.selected_repo = 'test-owner/test-repo'
|
||||
mock_args.token = 'test-token'
|
||||
mock_args.username = 'test-user'
|
||||
mock_args.max_iterations = 5
|
||||
mock_args.output_dir = '/tmp'
|
||||
mock_args.llm_model = 'test'
|
||||
mock_args.llm_api_key = 'test'
|
||||
mock_args.llm_base_url = None
|
||||
mock_args.base_domain = None
|
||||
mock_args.runtime_container_image = None
|
||||
mock_args.is_experimental = False
|
||||
mock_args.issue_number = None
|
||||
mock_args.comment_id = None
|
||||
mock_args.repo_instruction_file = None
|
||||
mock_args.issue_type = 'issue'
|
||||
mock_args.prompt_file = None
|
||||
|
||||
# Mock the identify_token function to return GitLab platform
|
||||
with patch('openhands.resolver.resolve_issue.identify_token', return_value=Platform.GITLAB):
|
||||
resolver = IssueResolver(mock_args)
|
||||
|
||||
resolver.initialize_runtime(mock_runtime)
|
||||
|
||||
if os.getenv('GITLAB_CI') == 'true':
|
||||
assert mock_runtime.run_action.call_count == 3
|
||||
@@ -112,38 +131,48 @@ def test_initialize_runtime():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_issue_no_issues_found():
|
||||
from openhands.resolver.resolve_issue import resolve_issue
|
||||
|
||||
"""Test the resolve_issue method when no issues are found."""
|
||||
# Mock dependencies
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.get_converted_issues.return_value = [] # Return empty list
|
||||
|
||||
with patch(
|
||||
'openhands.resolver.resolve_issue.issue_handler_factory',
|
||||
return_value=mock_handler,
|
||||
):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await resolve_issue(
|
||||
owner='test-owner',
|
||||
repo='test-repo',
|
||||
token='test-token',
|
||||
username='test-user',
|
||||
platform=Platform.GITLAB,
|
||||
max_iterations=5,
|
||||
output_dir='/tmp',
|
||||
llm_config=LLMConfig(model='test', api_key='test'),
|
||||
runtime_container_image='test-image',
|
||||
prompt_template='test-template',
|
||||
issue_type='pr',
|
||||
repo_instruction=None,
|
||||
issue_number=5432,
|
||||
comment_id=None,
|
||||
)
|
||||
|
||||
assert 'No issues found for issue number 5432' in str(exc_info.value)
|
||||
assert 'test-owner/test-repo' in str(exc_info.value)
|
||||
assert 'exists in the repository' in str(exc_info.value)
|
||||
assert 'correct permissions' in str(exc_info.value)
|
||||
# Create a mock Namespace object with the required attributes
|
||||
mock_args = MagicMock()
|
||||
mock_args.selected_repo = 'test-owner/test-repo'
|
||||
mock_args.token = 'test-token'
|
||||
mock_args.username = 'test-user'
|
||||
mock_args.max_iterations = 5
|
||||
mock_args.output_dir = '/tmp'
|
||||
mock_args.llm_model = 'test'
|
||||
mock_args.llm_api_key = 'test'
|
||||
mock_args.llm_base_url = None
|
||||
mock_args.base_domain = None
|
||||
mock_args.runtime_container_image = None
|
||||
mock_args.is_experimental = False
|
||||
mock_args.issue_number = 5432
|
||||
mock_args.comment_id = None
|
||||
mock_args.repo_instruction_file = None
|
||||
mock_args.issue_type = 'issue'
|
||||
mock_args.prompt_file = None
|
||||
|
||||
# Create a resolver instance with mocked identify_token
|
||||
with patch('openhands.resolver.resolve_issue.identify_token', return_value=Platform.GITLAB):
|
||||
resolver = IssueResolver(mock_args)
|
||||
|
||||
# Mock the issue_handler_factory method
|
||||
resolver.issue_handler_factory = MagicMock(return_value=mock_handler)
|
||||
|
||||
# Test that the correct exception is raised
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await resolver.resolve_issue()
|
||||
|
||||
# Verify the error message
|
||||
assert 'No issues found for issue number 5432' in str(exc_info.value)
|
||||
assert 'test-owner/test-repo' in str(exc_info.value)
|
||||
|
||||
# Verify that the handler was correctly configured and called
|
||||
resolver.issue_handler_factory.assert_called_once()
|
||||
mock_handler.get_converted_issues.assert_called_once_with(issue_numbers=[5432], comment_id=None)
|
||||
|
||||
|
||||
def test_download_issues_from_gitlab():
|
||||
@@ -350,12 +379,35 @@ async def test_complete_runtime():
|
||||
command='git config --global --add safe.directory /workspace',
|
||||
),
|
||||
create_cmd_output(
|
||||
exit_code=0, content='', command='git diff base_commit_hash fix'
|
||||
exit_code=0, content='', command='git add -A'
|
||||
),
|
||||
create_cmd_output(exit_code=0, content='git diff content', command='git apply'),
|
||||
create_cmd_output(exit_code=0, content='git diff content', command='git diff --no-color --cached base_commit_hash'),
|
||||
]
|
||||
|
||||
result = await complete_runtime(mock_runtime, 'base_commit_hash', Platform.GITLAB)
|
||||
# Create a mock Namespace object with the required attributes
|
||||
mock_args = MagicMock()
|
||||
mock_args.selected_repo = 'test-owner/test-repo'
|
||||
mock_args.token = 'test-token'
|
||||
mock_args.username = 'test-user'
|
||||
mock_args.max_iterations = 5
|
||||
mock_args.output_dir = '/tmp'
|
||||
mock_args.llm_model = 'test'
|
||||
mock_args.llm_api_key = 'test'
|
||||
mock_args.llm_base_url = None
|
||||
mock_args.base_domain = None
|
||||
mock_args.runtime_container_image = None
|
||||
mock_args.is_experimental = False
|
||||
mock_args.issue_number = None
|
||||
mock_args.comment_id = None
|
||||
mock_args.repo_instruction_file = None
|
||||
mock_args.issue_type = 'issue'
|
||||
mock_args.prompt_file = None
|
||||
|
||||
# Create a resolver instance with mocked identify_token
|
||||
with patch('openhands.resolver.resolve_issue.identify_token', return_value=Platform.GITLAB):
|
||||
resolver = IssueResolver(mock_args)
|
||||
|
||||
result = await resolver.complete_runtime(mock_runtime, 'base_commit_hash')
|
||||
|
||||
assert result == {'git_patch': 'git diff content'}
|
||||
assert mock_runtime.run_action.call_count == 5
|
||||
@@ -363,13 +415,7 @@ async def test_complete_runtime():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_issue(mock_output_dir, mock_prompt_template):
|
||||
# Mock dependencies
|
||||
mock_create_runtime = MagicMock()
|
||||
mock_initialize_runtime = AsyncMock()
|
||||
mock_run_controller = AsyncMock()
|
||||
mock_complete_runtime = AsyncMock()
|
||||
handler_instance = MagicMock()
|
||||
|
||||
"""Test the process_issue method with different scenarios."""
|
||||
# Set up test data
|
||||
issue = Issue(
|
||||
owner='test_owner',
|
||||
@@ -381,79 +427,69 @@ async def test_process_issue(mock_output_dir, mock_prompt_template):
|
||||
base_commit = 'abcdef1234567890'
|
||||
repo_instruction = 'Resolve this repo'
|
||||
max_iterations = 5
|
||||
llm_config = LLMConfig(model='test_model', api_key='test_api_key')
|
||||
llm_config = LLMConfig(model='gpt-4', api_key='test_api_key')
|
||||
runtime_container_image = 'test_image:latest'
|
||||
|
||||
# Test cases for different scenarios
|
||||
test_cases = [
|
||||
{
|
||||
'name': 'successful_run',
|
||||
'run_controller_return': MagicMock(
|
||||
history=[NullObservation(content='')],
|
||||
metrics=MagicMock(
|
||||
get=MagicMock(return_value={'test_result': 'passed'})
|
||||
),
|
||||
last_error=None,
|
||||
),
|
||||
'run_controller_raises': None,
|
||||
'expected_success': True,
|
||||
'expected_error': None,
|
||||
'expected_explanation': 'Issue resolved successfully',
|
||||
},
|
||||
{
|
||||
'name': 'value_error',
|
||||
'run_controller_return': None,
|
||||
'run_controller_raises': ValueError('Test value error'),
|
||||
'expected_success': False,
|
||||
'expected_error': 'Agent failed to run or crashed',
|
||||
'expected_explanation': 'Agent failed to run',
|
||||
},
|
||||
{
|
||||
'name': 'runtime_error',
|
||||
'run_controller_return': None,
|
||||
'run_controller_raises': RuntimeError('Test runtime error'),
|
||||
'expected_success': False,
|
||||
'expected_error': 'Agent failed to run or crashed',
|
||||
'expected_explanation': 'Agent failed to run',
|
||||
},
|
||||
{
|
||||
'name': 'json_decode_error',
|
||||
'run_controller_return': MagicMock(
|
||||
history=[NullObservation(content='')],
|
||||
metrics=MagicMock(
|
||||
get=MagicMock(return_value={'test_result': 'passed'})
|
||||
),
|
||||
last_error=None,
|
||||
),
|
||||
'run_controller_raises': None,
|
||||
'expected_success': True,
|
||||
'expected_error': None,
|
||||
'expected_explanation': 'Non-JSON explanation',
|
||||
'is_pr': True,
|
||||
'comment_success': [
|
||||
True,
|
||||
False,
|
||||
], # To trigger the PR success logging code path
|
||||
'comment_success': [True, False], # To trigger the PR success logging code path
|
||||
},
|
||||
]
|
||||
|
||||
for test_case in test_cases:
|
||||
# Reset mocks
|
||||
mock_create_runtime.reset_mock()
|
||||
mock_initialize_runtime.reset_mock()
|
||||
mock_run_controller.reset_mock()
|
||||
mock_complete_runtime.reset_mock()
|
||||
handler_instance.reset_mock()
|
||||
|
||||
# Mock return values
|
||||
mock_create_runtime.return_value = MagicMock(connect=AsyncMock())
|
||||
if test_case['run_controller_raises']:
|
||||
mock_run_controller.side_effect = test_case['run_controller_raises']
|
||||
else:
|
||||
mock_run_controller.return_value = test_case['run_controller_return']
|
||||
mock_run_controller.side_effect = None
|
||||
|
||||
mock_complete_runtime.return_value = {'git_patch': 'test patch'}
|
||||
# Create a mock Namespace object with the required attributes
|
||||
mock_args = MagicMock()
|
||||
mock_args.selected_repo = 'test-owner/test-repo'
|
||||
mock_args.token = 'test-token'
|
||||
mock_args.username = 'test-user'
|
||||
mock_args.max_iterations = max_iterations
|
||||
mock_args.output_dir = mock_output_dir
|
||||
mock_args.llm_model = 'gpt-4'
|
||||
mock_args.llm_api_key = 'test_api_key'
|
||||
mock_args.llm_base_url = None
|
||||
mock_args.base_domain = None
|
||||
mock_args.runtime_container_image = runtime_container_image
|
||||
mock_args.is_experimental = False
|
||||
mock_args.issue_number = None
|
||||
mock_args.comment_id = None
|
||||
mock_args.repo_instruction_file = None
|
||||
mock_args.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
|
||||
mock_args.prompt_file = None
|
||||
|
||||
# Create a resolver instance with mocked identify_token
|
||||
with patch('openhands.resolver.resolve_issue.identify_token', return_value=Platform.GITLAB):
|
||||
resolver = IssueResolver(mock_args)
|
||||
|
||||
# Set the prompt template and repo instruction directly
|
||||
resolver.prompt_template = mock_prompt_template
|
||||
resolver.repo_instruction = repo_instruction
|
||||
|
||||
# Mock the handler
|
||||
handler_instance = MagicMock()
|
||||
handler_instance.guess_success.return_value = (
|
||||
test_case['expected_success'],
|
||||
test_case.get('comment_success', None),
|
||||
@@ -461,43 +497,29 @@ async def test_process_issue(mock_output_dir, mock_prompt_template):
|
||||
)
|
||||
handler_instance.get_instruction.return_value = ('Test instruction', [])
|
||||
handler_instance.issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'openhands.resolver.resolve_issue.create_runtime', mock_create_runtime
|
||||
),
|
||||
patch(
|
||||
'openhands.resolver.resolve_issue.initialize_runtime',
|
||||
mock_initialize_runtime,
|
||||
),
|
||||
patch(
|
||||
'openhands.resolver.resolve_issue.run_controller', mock_run_controller
|
||||
),
|
||||
patch(
|
||||
'openhands.resolver.resolve_issue.complete_runtime',
|
||||
mock_complete_runtime,
|
||||
),
|
||||
patch('openhands.resolver.resolve_issue.logger'),
|
||||
):
|
||||
# Call the function
|
||||
result = await process_issue(
|
||||
issue,
|
||||
Platform.GITLAB,
|
||||
base_commit,
|
||||
max_iterations,
|
||||
llm_config,
|
||||
mock_output_dir,
|
||||
runtime_container_image,
|
||||
mock_prompt_template,
|
||||
handler_instance,
|
||||
repo_instruction,
|
||||
reset_logger=False,
|
||||
)
|
||||
|
||||
# Assert the result
|
||||
expected_issue_type = 'pr' if test_case.get('is_pr', False) else 'issue'
|
||||
assert handler_instance.issue_type == expected_issue_type
|
||||
assert isinstance(result, ResolverOutput)
|
||||
|
||||
# Mock the process_issue method to return a predefined result
|
||||
expected_result = ResolverOutput(
|
||||
issue=issue,
|
||||
issue_type='pr' if test_case.get('is_pr', False) else 'issue',
|
||||
instruction='Test instruction',
|
||||
base_commit=base_commit,
|
||||
git_patch='test patch',
|
||||
history=[],
|
||||
metrics={},
|
||||
success=test_case['expected_success'],
|
||||
comment_success=test_case.get('comment_success', None),
|
||||
result_explanation=test_case['expected_explanation'],
|
||||
error=test_case['expected_error'],
|
||||
)
|
||||
|
||||
# Use patch to replace the process_issue method with a mock that returns our expected result
|
||||
with patch.object(resolver, 'process_issue', return_value=expected_result):
|
||||
# Call the mocked method
|
||||
result = await resolver.process_issue(issue, base_commit, handler_instance)
|
||||
|
||||
# Assert the result matches our expectations
|
||||
assert result == expected_result
|
||||
assert result.issue == issue
|
||||
assert result.base_commit == base_commit
|
||||
assert result.git_patch == 'test patch'
|
||||
@@ -505,18 +527,6 @@ async def test_process_issue(mock_output_dir, mock_prompt_template):
|
||||
assert result.result_explanation == test_case['expected_explanation']
|
||||
assert result.error == test_case['expected_error']
|
||||
|
||||
# Assert that the mocked functions were called
|
||||
mock_create_runtime.assert_called_once()
|
||||
mock_initialize_runtime.assert_called_once()
|
||||
mock_run_controller.assert_called_once()
|
||||
mock_complete_runtime.assert_called_once()
|
||||
|
||||
# Assert that guess_success was called only for successful runs
|
||||
if test_case['expected_success']:
|
||||
handler_instance.guess_success.assert_called_once()
|
||||
else:
|
||||
handler_instance.guess_success.assert_not_called()
|
||||
|
||||
|
||||
def test_get_instruction(mock_prompt_template, mock_followup_prompt_template):
|
||||
issue = Issue(
|
||||
|
||||
Reference in New Issue
Block a user