Add base_domain parameter for GitHub Enterprise support (#7754)

Co-authored-by: Tom Deckers <tdeckers@cisco.com>
Co-authored-by: Robert Brennan <accounts@rbren.io>
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
This commit is contained in:
Tom Deckers
2025-04-16 02:00:32 +02:00
committed by GitHub
parent d7e8f843ad
commit 7e14a512e0
8 changed files with 203 additions and 49 deletions

View File

@@ -235,6 +235,7 @@ def send_pull_request(
target_branch: str | None = None,
reviewer: str | None = None,
pr_title: str | None = None,
base_domain: str | None = None,
) -> str:
"""Send a pull request to a GitHub or Gitlab repository.
@@ -250,18 +251,25 @@ def send_pull_request(
target_branch: The target branch to create the pull request against (defaults to repository default branch)
reviewer: The GitHub or Gitlab username of the reviewer to assign
pr_title: Custom title for the pull request (optional)
base_domain: The base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)
"""
if pr_type not in ['branch', 'draft', 'ready']:
raise ValueError(f'Invalid pr_type: {pr_type}')
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
handler = None
if platform == Platform.GITHUB:
handler = ServiceContextIssue(
GithubIssueHandler(issue.owner, issue.repo, token, username), None
GithubIssueHandler(issue.owner, issue.repo, token, username, base_domain),
None,
)
else: # platform == Platform.GITLAB
handler = ServiceContextIssue(
GitlabIssueHandler(issue.owner, issue.repo, token, username), None
GitlabIssueHandler(issue.owner, issue.repo, token, username, base_domain),
None,
)
# Create a new branch with a unique name
@@ -363,6 +371,7 @@ def update_existing_pull_request(
llm_config: LLMConfig,
comment_message: str | None = None,
additional_message: str | None = None,
base_domain: str | None = None,
) -> str:
"""Update an existing pull request with the new patches.
@@ -375,17 +384,24 @@ def update_existing_pull_request(
llm_config: The LLM configuration to use for summarizing changes.
comment_message: The main message to post as a comment on the PR.
additional_message: The additional messages to post as a comment on the PR in json list format.
base_domain: The base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)
"""
# Set up headers and base URL for GitHub or GitLab API
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
handler = None
if platform == Platform.GITHUB:
handler = ServiceContextIssue(
GithubIssueHandler(issue.owner, issue.repo, token, username), llm_config
GithubIssueHandler(issue.owner, issue.repo, token, username, base_domain),
llm_config,
)
else: # platform == Platform.GITLAB
handler = ServiceContextIssue(
GitlabIssueHandler(issue.owner, issue.repo, token, username), llm_config
GitlabIssueHandler(issue.owner, issue.repo, token, username, base_domain),
llm_config,
)
branch_name = issue.head_branch
@@ -468,7 +484,11 @@ def process_single_issue(
target_branch: str | None = None,
reviewer: str | None = None,
pr_title: str | None = None,
base_domain: str | None = None,
) -> None:
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
if not resolver_output.success and not send_on_failure:
logger.info(
f'Issue {resolver_output.issue.number} was not successfully resolved. Skipping PR creation.'
@@ -507,6 +527,7 @@ def process_single_issue(
patch_dir=patched_repo_dir,
additional_message=resolver_output.result_explanation,
llm_config=llm_config,
base_domain=base_domain,
)
else:
send_pull_request(
@@ -521,6 +542,7 @@ def process_single_issue(
target_branch=target_branch,
reviewer=reviewer,
pr_title=pr_title,
base_domain=base_domain,
)
@@ -532,7 +554,11 @@ def process_all_successful_issues(
pr_type: str,
llm_config: LLMConfig,
fork_owner: str | None,
base_domain: str | None = None,
) -> None:
# Determine default base_domain based on platform
if base_domain is None:
base_domain = 'github.com' if platform == Platform.GITHUB else 'gitlab.com'
output_path = os.path.join(output_dir, 'output.jsonl')
for resolver_output in load_all_resolver_outputs(output_path):
if resolver_output.success:
@@ -548,6 +574,9 @@ def process_all_successful_issues(
fork_owner,
False,
None,
None,
None,
base_domain,
)
@@ -633,6 +662,12 @@ def main() -> None:
help='Custom title for the pull request',
default=None,
)
parser.add_argument(
'--base-domain',
type=str,
default=None,
help='Base domain for the git server (defaults to "github.com" for GitHub and "gitlab.com" for GitLab)',
)
my_args = parser.parse_args()
token = my_args.token or os.getenv('GITHUB_TOKEN') or os.getenv('GITLAB_TOKEN')
@@ -642,7 +677,7 @@ def main() -> None:
)
username = my_args.username if my_args.username else os.getenv('GIT_USERNAME')
platform = identify_token(token)
platform = identify_token(token, None, my_args.base_domain)
if platform == Platform.INVALID:
raise ValueError('Token is invalid.')
@@ -667,6 +702,7 @@ def main() -> None:
my_args.pr_type,
llm_config,
my_args.fork_owner,
my_args.base_domain,
)
else:
if not my_args.issue_number.isdigit():
@@ -689,6 +725,7 @@ def main() -> None:
my_args.target_branch,
my_args.reviewer,
my_args.pr_title,
my_args.base_domain,
)