[Fix]: Replace duplicate enums for providers in resolver (#7954)

This commit is contained in:
Rohit Malhotra
2025-04-20 14:06:18 -04:00
committed by GitHub
parent 20bf48b693
commit 0637b5b912
10 changed files with 102 additions and 115 deletions

View File

@@ -9,6 +9,7 @@ from pydantic import SecretStr
from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import ProviderType
from openhands.llm.llm import LLM
from openhands.resolver.interfaces.github import GithubIssueHandler
from openhands.resolver.interfaces.gitlab import GitlabIssueHandler
@@ -20,10 +21,7 @@ from openhands.resolver.io_utils import (
)
from openhands.resolver.patching import apply_diff, parse_patch
from openhands.resolver.resolver_output import ResolverOutput
from openhands.resolver.utils import (
Platform,
identify_token,
)
from openhands.resolver.utils import identify_token
def apply_patch(repo_dir: str, patch: str) -> None:
@@ -227,7 +225,7 @@ def send_pull_request(
issue: Issue,
token: str,
username: str | None,
platform: Platform,
platform: ProviderType,
patch_dir: str,
pr_type: str,
fork_owner: str | None = None,
@@ -258,10 +256,10 @@ def send_pull_request(
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
base_domain = 'github.com' if platform == ProviderType.GITHUB else 'gitlab.com'
handler = None
if platform == Platform.GITHUB:
if platform == ProviderType.GITHUB:
handler = ServiceContextIssue(
GithubIssueHandler(issue.owner, issue.repo, token, username, base_domain),
None,
@@ -329,7 +327,7 @@ def send_pull_request(
# For cross repo pull request, we need to send head parameter like fork_owner:branch as per git documentation here : https://docs.github.com/en/rest/pulls/pulls?apiVersion=2022-11-28#create-a-pull-request
# head parameter usage : The name of the branch where your changes are implemented. For cross-repository pull requests in the same network, namespace head with a user like this: username:branch.
if fork_owner and platform == Platform.GITHUB:
if fork_owner and platform == ProviderType.GITHUB:
head_branch = f'{fork_owner}:{branch_name}'
else:
head_branch = branch_name
@@ -341,9 +339,13 @@ def send_pull_request(
# Prepare the PR for the GitHub API
data = {
'title': final_pr_title,
('body' if platform == Platform.GITHUB else 'description'): pr_body,
('head' if platform == Platform.GITHUB else 'source_branch'): head_branch,
('base' if platform == Platform.GITHUB else 'target_branch'): base_branch,
('body' if platform == ProviderType.GITHUB else 'description'): pr_body,
(
'head' if platform == ProviderType.GITHUB else 'source_branch'
): head_branch,
(
'base' if platform == ProviderType.GITHUB else 'target_branch'
): base_branch,
'draft': pr_type == 'draft',
}
@@ -366,7 +368,7 @@ def update_existing_pull_request(
issue: Issue,
token: str,
username: str | None,
platform: Platform,
platform: ProviderType,
patch_dir: str,
llm_config: LLMConfig,
comment_message: str | None = None,
@@ -390,10 +392,10 @@ def update_existing_pull_request(
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
base_domain = 'github.com' if platform == ProviderType.GITHUB else 'gitlab.com'
handler = None
if platform == Platform.GITHUB:
if platform == ProviderType.GITHUB:
handler = ServiceContextIssue(
GithubIssueHandler(issue.owner, issue.repo, token, username, base_domain),
llm_config,
@@ -476,7 +478,7 @@ def process_single_issue(
resolver_output: ResolverOutput,
token: str,
username: str,
platform: Platform,
platform: ProviderType,
pr_type: str,
llm_config: LLMConfig,
fork_owner: str | None,
@@ -488,7 +490,7 @@ def process_single_issue(
) -> None:
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
base_domain = 'github.com' if platform == ProviderType.GITHUB else 'gitlab.com'
if not resolver_output.success and not send_on_failure:
logger.info(
f'Issue {resolver_output.issue.number} was not successfully resolved. Skipping PR creation.'
@@ -550,7 +552,7 @@ def process_all_successful_issues(
output_dir: str,
token: str,
username: str,
platform: Platform,
platform: ProviderType,
pr_type: str,
llm_config: LLMConfig,
fork_owner: str | None,
@@ -558,7 +560,7 @@ def process_all_successful_issues(
) -> None:
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
base_domain = 'github.com' if platform == ProviderType.GITHUB else 'gitlab.com'
output_path = os.path.join(output_dir, 'output.jsonl')
for resolver_output in load_all_resolver_outputs(output_path):
if resolver_output.success:
@@ -684,8 +686,6 @@ def main() -> None:
username = my_args.username if my_args.username else os.getenv('GIT_USERNAME')
platform = identify_token(token, my_args.selected_repo, my_args.base_domain)
if platform == Platform.INVALID:
raise ValueError('Token is invalid.')
api_key = my_args.llm_api_key or os.environ['LLM_API_KEY']
llm_config = LLMConfig(