Files
OpenHands/openhands/integrations/github/github_service.py
Rohit Malhotra 890796cc9d [Feat]: Git mcp server to open PRs (#8348)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
Co-authored-by: Robert Brennan <accounts@rbren.io>
2025-05-21 11:48:02 -04:00

511 lines
17 KiB
Python

import json
import os
from datetime import datetime
from typing import Any
import httpx
from pydantic import SecretStr
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.queries import (
suggested_task_issue_graphql_query,
suggested_task_pr_graphql_query,
)
from openhands.integrations.service_types import (
BaseGitService,
Branch,
GitService,
ProviderType,
Repository,
RequestMethod,
SuggestedTask,
TaskType,
UnknownException,
User,
)
from openhands.server.types import AppMode
from openhands.utils.import_utils import get_impl
class GitHubService(BaseGitService, GitService):
    """GitHub integration backed by the REST and GraphQL APIs.

    Talks to github.com by default; pass ``base_domain`` to target a
    GitHub Enterprise host instead.
    """

    BASE_URL = 'https://api.github.com'  # replaced per-instance for GHE hosts
    token: SecretStr = SecretStr('')  # current API token ('' until resolved)
    refresh = False  # when True, a 401 triggers a token refresh + retry

    def __init__(
        self,
        user_id: str | None = None,
        external_auth_id: str | None = None,
        external_auth_token: SecretStr | None = None,
        token: SecretStr | None = None,
        external_token_manager: bool = False,
        base_domain: str | None = None,
    ):
        self.user_id = user_id
        self.external_auth_id = external_auth_id
        self.external_auth_token = external_auth_token
        self.external_token_manager = external_token_manager

        if token:
            self.token = token

        # GitHub Enterprise serves the REST API under /api/v3 on the host.
        if base_domain and base_domain != 'github.com':
            self.BASE_URL = f'https://{base_domain}/api/v3'
@property
def provider(self) -> str:
    """Identifier of this Git provider (the GitHub enum value)."""
    return ProviderType.GITHUB.value
async def _get_github_headers(self) -> dict:
    """Build the auth/accept headers, lazily resolving the token if unset."""
    if not self.token:
        self.token = await self.get_latest_token()

    bearer = self.token.get_secret_value() if self.token else ''
    return {
        'Authorization': f'Bearer {bearer}',
        'Accept': 'application/vnd.github.v3+json',
    }
def _has_token_expired(self, status_code: int) -> bool:
    """Report whether a response status signals rejected credentials."""
    # GitHub answers 401 Unauthorized for bad or expired tokens.
    if status_code == 401:
        return True
    return False
async def get_latest_token(self) -> SecretStr | None:
    """Return the currently stored token (refresh hooks can override this)."""
    return self.token
async def _make_request(
    self,
    url: str,
    params: dict | None = None,
    method: RequestMethod = RequestMethod.GET,
) -> tuple[Any, dict]:
    """Perform an authenticated request against the GitHub REST API.

    Args:
        url: Absolute API endpoint URL.
        params: Query parameters (also used as the payload for POST —
            serialization is decided by execute_request on the base class).
        method: HTTP method to use.

    Returns:
        Tuple of (decoded JSON body, headers dict containing only the
        pagination 'Link' header when the response carried one).

    Raises:
        Whatever handle_http_status_error / handle_http_error translate
        the underlying httpx errors into.
    """
    try:
        async with httpx.AsyncClient() as client:
            github_headers = await self._get_github_headers()

            # Make initial request
            response = await self.execute_request(
                client=client,
                url=url,
                headers=github_headers,
                params=params,
                method=method,
            )

            # Handle token refresh if needed: on a 401 with refresh
            # enabled, fetch a fresh token and retry exactly once.
            # NOTE(review): get_latest_token()'s return value is discarded
            # and _get_github_headers() re-fetches only when self.token is
            # falsy — presumably refresh-capable subclasses update
            # self.token as a side effect; confirm against the subclass.
            if self.refresh and self._has_token_expired(response.status_code):
                await self.get_latest_token()
                github_headers = await self._get_github_headers()
                response = await self.execute_request(
                    client=client,
                    url=url,
                    headers=github_headers,
                    params=params,
                    method=method,
                )

            response.raise_for_status()

            # Only the pagination Link header is surfaced to callers.
            headers = {}
            if 'Link' in response.headers:
                headers['Link'] = response.headers['Link']

            return response.json(), headers
    except httpx.HTTPStatusError as e:
        raise self.handle_http_status_error(e)
    except httpx.HTTPError as e:
        raise self.handle_http_error(e)
async def get_user(self) -> User:
    """Fetch the authenticated user's profile."""
    data, _ = await self._make_request(f'{self.BASE_URL}/user')
    return User(
        id=data.get('id'),
        login=data.get('login'),
        avatar_url=data.get('avatar_url'),
        company=data.get('company'),
        name=data.get('name'),
        email=data.get('email'),
    )
async def verify_access(self) -> bool:
    """Verify if the token is valid by making a simple request."""
    # Hitting the API root raises (inside _make_request) on auth failure,
    # so reaching the return statement means the token worked.
    url = f'{self.BASE_URL}'
    await self._make_request(url)
    return True
async def _fetch_paginated_repos(
    self, url: str, params: dict, max_repos: int, extract_key: str | None = None
) -> list[dict]:
    """Collect repositories from a paginated GitHub endpoint.

    Args:
        url: The API endpoint URL.
        params: Base query parameters (page number is added per request).
        max_repos: Upper bound on the number of repositories returned.
        extract_key: When set, repositories live under this key of the
            JSON response; otherwise the response itself is the list.

    Returns:
        Up to ``max_repos`` repository dictionaries.
    """
    collected: list[dict] = []
    page = 1

    while len(collected) < max_repos:
        response, headers = await self._make_request(
            url, {**params, 'page': str(page)}
        )

        batch = response.get(extract_key, []) if extract_key else response
        if not batch:  # empty page: nothing more to fetch
            break

        collected.extend(batch)
        page += 1

        # GitHub announces further pages via rel="next" in the Link header.
        if 'rel="next"' not in headers.get('Link', ''):
            break

    return collected[:max_repos]  # trim any overshoot from the last page
def parse_pushed_at_date(self, repo):
    """Sort key: a repo's last-push time, or datetime.min when absent."""
    pushed_at = repo.get('pushed_at')
    if not pushed_at:
        return datetime.min
    return datetime.strptime(pushed_at, '%Y-%m-%dT%H:%M:%SZ')
async def get_repositories(self, sort: str, app_mode: AppMode) -> list[Repository]:
    """List repositories visible to the user, capped at 1000 entries."""
    MAX_REPOS = 1000
    PER_PAGE = 100  # Maximum allowed by GitHub API

    collected: list[dict] = []
    if app_mode == AppMode.SAAS:
        # SaaS mode: enumerate GitHub App installations and gather the
        # repositories reachable through each one.
        for installation_id in await self.get_installation_ids():
            installation_repos = await self._fetch_paginated_repos(
                f'{self.BASE_URL}/user/installations/{installation_id}/repositories',
                {'per_page': str(PER_PAGE)},
                MAX_REPOS - len(collected),
                extract_key='repositories',
            )
            collected.extend(installation_repos)

            # Stop early once the global cap is reached.
            if len(collected) >= MAX_REPOS:
                break

        # 'pushed' ordering is applied client-side across installations.
        if sort == 'pushed':
            collected.sort(key=self.parse_pushed_at_date, reverse=True)
    else:
        # Non-SaaS mode: a single paginated listing of the user's repos.
        collected = await self._fetch_paginated_repos(
            f'{self.BASE_URL}/user/repos',
            {'per_page': str(PER_PAGE), 'sort': sort},
            MAX_REPOS,
        )

    return [
        Repository(
            id=repo.get('id'),
            full_name=repo.get('full_name'),
            stargazers_count=repo.get('stargazers_count'),
            git_provider=ProviderType.GITHUB,
            is_public=not repo.get('private', True),
        )
        for repo in collected
    ]
async def get_installation_ids(self) -> list[int]:
    """Return the IDs of the user's GitHub App installations."""
    response, _ = await self._make_request(f'{self.BASE_URL}/user/installations')
    return [entry['id'] for entry in response.get('installations', [])]
async def search_repositories(
    self, query: str, per_page: int, sort: str, order: str
) -> list[Repository]:
    """Search GitHub for public repositories matching ``query``."""
    url = f'{self.BASE_URL}/search/repositories'

    # Append is:public so only public repositories are ever returned.
    response, _ = await self._make_request(
        url,
        {
            'q': f'{query} is:public',
            'per_page': per_page,
            'sort': sort,
            'order': order,
        },
    )

    return [
        Repository(
            id=item.get('id'),
            full_name=item.get('full_name'),
            stargazers_count=item.get('stargazers_count'),
            git_provider=ProviderType.GITHUB,
            is_public=True,
        )
        for item in response.get('items', [])
    ]
async def execute_graphql_query(
    self, query: str, variables: dict[str, Any]
) -> dict[str, Any]:
    """Execute a GraphQL query against the GitHub API.

    Raises:
        UnknownException: when the response carries GraphQL-level errors.
    """
    try:
        async with httpx.AsyncClient() as client:
            github_headers = await self._get_github_headers()
            response = await client.post(
                f'{self.BASE_URL}/graphql',
                headers=github_headers,
                json={'query': query, 'variables': variables},
            )
            response.raise_for_status()
            payload = response.json()
    except httpx.HTTPStatusError as e:
        raise self.handle_http_status_error(e)
    except httpx.HTTPError as e:
        raise self.handle_http_error(e)

    # GraphQL reports failures inside a 200 response; surface them here.
    # (Raised outside the try above, exactly as in the original flow,
    # since only httpx errors are translated.)
    if 'errors' in payload:
        raise UnknownException(
            f'GraphQL query error: {json.dumps(payload["errors"])}'
        )
    return dict(payload)
async def get_suggested_tasks(self) -> list[SuggestedTask]:
    """Get suggested tasks for the authenticated user across all repositories.

    Returns:
        - PRs authored by the user.
        - Issues assigned to the user.

    Note: Queries are split to avoid timeout issues.
    """
    # Get user info to use in queries
    user = await self.get_user()
    login = user.login

    tasks: list[SuggestedTask] = []
    variables = {'login': login}

    try:
        pr_response = await self.execute_graphql_query(
            suggested_task_pr_graphql_query, variables
        )
        pr_data = pr_response['data']['user']

        # Process pull requests
        for pr in pr_data['pullRequests']['nodes']:
            repo_name = pr['repository']['nameWithOwner']

            # Start with default task type
            task_type = TaskType.OPEN_PR

            # Check for specific states, in priority order:
            # merge conflicts > failing checks > unresolved review comments.
            if pr['mergeable'] == 'CONFLICTING':
                task_type = TaskType.MERGE_CONFLICTS
            elif (
                pr['commits']['nodes']
                and pr['commits']['nodes'][0]['commit']['statusCheckRollup']
                and pr['commits']['nodes'][0]['commit']['statusCheckRollup'][
                    'state'
                ]
                == 'FAILURE'
            ):
                task_type = TaskType.FAILING_CHECKS
            elif any(
                review['state'] in ['CHANGES_REQUESTED', 'COMMENTED']
                for review in pr['reviews']['nodes']
            ):
                task_type = TaskType.UNRESOLVED_COMMENTS

            # Only add the task if it's not OPEN_PR — PRs with nothing
            # actionable are skipped.
            if task_type != TaskType.OPEN_PR:
                tasks.append(
                    SuggestedTask(
                        git_provider=ProviderType.GITHUB,
                        task_type=task_type,
                        repo=repo_name,
                        issue_number=pr['number'],
                        title=pr['title'],
                    )
                )
    except Exception as e:
        # Best-effort: a failed PR query must not block the issue query below.
        logger.info(
            f'Error fetching suggested task for PRs: {e}',
            extra={
                'signal': 'github_suggested_tasks',
                'user_id': self.external_auth_id,
            },
        )

    try:
        # Execute issue query
        issue_response = await self.execute_graphql_query(
            suggested_task_issue_graphql_query, variables
        )
        issue_data = issue_response['data']['user']

        # Process issues: each returned issue becomes an OPEN_ISSUE task.
        for issue in issue_data['issues']['nodes']:
            repo_name = issue['repository']['nameWithOwner']
            tasks.append(
                SuggestedTask(
                    git_provider=ProviderType.GITHUB,
                    task_type=TaskType.OPEN_ISSUE,
                    repo=repo_name,
                    issue_number=issue['number'],
                    title=issue['title'],
                )
            )

        return tasks
    except Exception as e:
        # Best-effort: return whatever the PR pass already collected.
        logger.info(
            f'Error fetching suggested task for issues: {e}',
            extra={
                'signal': 'github_suggested_tasks',
                'user_id': self.external_auth_id,
            },
        )

    return tasks
async def get_repository_details_from_repo_name(
    self, repository: str
) -> Repository:
    """Fetch details for a single repository given its 'owner/repo' name."""
    data, _ = await self._make_request(f'{self.BASE_URL}/repos/{repository}')
    return Repository(
        id=data.get('id'),
        full_name=data.get('full_name'),
        stargazers_count=data.get('stargazers_count'),
        git_provider=ProviderType.GITHUB,
        is_public=not data.get('private', True),
    )
async def get_branches(self, repository: str) -> list[Branch]:
    """Get branches for a repository"""
    url = f'{self.BASE_URL}/repos/{repository}/branches'

    # Cap the scan at 10 pages of 100 branches each.
    MAX_BRANCHES = 1000
    PER_PAGE = 100

    branches: list[Branch] = []
    page = 1
    while page <= 10 and len(branches) < MAX_BRANCHES:
        response, headers = await self._make_request(
            url, {'per_page': str(PER_PAGE), 'page': str(page)}
        )
        if not response:  # empty page: no more branches
            break

        for item in response:
            # Dig out the last commit date when the nested payload has it.
            last_push_date = None
            commit = item.get('commit')
            if commit and commit.get('commit'):
                committer = commit['commit'].get('committer')
                if committer and committer.get('date'):
                    last_push_date = committer['date']

            branches.append(
                Branch(
                    name=item.get('name'),
                    commit_sha=item.get('commit', {}).get('sha', ''),
                    protected=item.get('protected', False),
                    last_push_date=last_push_date,
                )
            )

        page += 1

        # Stop when GitHub's Link header no longer advertises a next page.
        if 'rel="next"' not in headers.get('Link', ''):
            break

    return branches
async def create_pr(
    self,
    repo_name: str,
    source_branch: str,
    target_branch: str,
    title: str,
    body: str | None = None,
    draft: bool = True,
) -> str:
    """
    Creates a PR using user credentials

    Args:
        repo_name: The full name of the repository (owner/repo)
        source_branch: The name of the branch where your changes are implemented
        target_branch: The name of the branch you want the changes pulled into
        title: The title of the pull request
        body: The body/description of the pull request (optional)
        draft: Whether to create the PR as a draft (optional, defaults to True)

    Returns:
        - PR URL when successful
        - Error message when unsuccessful
    """
    try:
        url = f'{self.BASE_URL}/repos/{repo_name}/pulls'

        # Set default body if none provided
        if not body:
            body = f'Merging changes from {source_branch} into {target_branch}'

        # Prepare the request payload
        payload = {
            'title': title,
            'head': source_branch,
            'base': target_branch,
            'body': body,
            'draft': draft,
        }

        # Make the POST request to create the PR.
        # NOTE(review): the payload travels through _make_request's
        # `params`; how it is serialized for POST is decided by
        # execute_request (not visible here) — confirm it is sent as a
        # JSON body, not query parameters.
        response, _ = await self._make_request(
            url=url, params=payload, method=RequestMethod.POST
        )

        # Return the HTML URL of the created PR
        if 'html_url' in response:
            return response['html_url']
        else:
            return f'PR created but URL not found in response: {response}'
    except Exception as e:
        # Deliberate catch-all: callers receive an error string, never an
        # exception, from this method.
        return f'Error creating pull request: {str(e)}'
# Allow deployments to substitute a custom service class via environment
# variable; defaults to the GitHubService defined above.
github_service_cls = os.environ.get(
    'OPENHANDS_GITHUB_SERVICE_CLS',
    'openhands.integrations.github.github_service.GitHubService',
)
# Resolved implementation class used by the rest of the application.
GithubServiceImpl = get_impl(GitHubService, github_service_cls)