Files
OpenHands/openhands/integrations/provider.py
2026-04-27 12:23:43 -06:00

649 lines
26 KiB
Python

from __future__ import annotations
import os
from collections.abc import Mapping
from types import MappingProxyType
from typing import cast
from urllib.parse import quote
import httpx
from pydantic import (
BaseModel,
ConfigDict,
Field,
SecretStr,
)
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.azure_devops.azure_devops_service import (
AzureDevOpsServiceImpl,
)
from openhands.integrations.bitbucket.bitbucket_service import BitBucketServiceImpl
from openhands.integrations.bitbucket_data_center.bitbucket_dc_service import (
BitbucketDCServiceImpl,
)
from openhands.integrations.forgejo.forgejo_service import ForgejoServiceImpl
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.service_types import (
AuthenticationError,
Branch,
GitService,
InstallationsService,
PaginatedBranchesResponse,
ProviderTimeoutError,
ProviderType,
Repository,
SuggestedTask,
TokenResponse,
User,
)
from openhands.server.types import AppMode
from openhands.utils.http_session import httpx_verify_option
class ProviderToken(BaseModel):
    # Credentials for a single git provider: the secret token, the provider
    # side user id and an optional host override. All three are optional.
    token: SecretStr | None = Field(default=None)
    user_id: str | None = Field(default=None)
    host: str | None = Field(default=None)

    model_config = ConfigDict(
        frozen=True,  # Instances are immutable once created
        validate_assignment=True,
    )

    @classmethod
    def from_value(cls, token_value: ProviderToken | dict[str, str]) -> ProviderToken:
        """Build a ProviderToken from an existing instance or a plain dict.

        Args:
            token_value: Either a ProviderToken (returned unchanged) or a dict
                that may carry the keys ``token``, ``user_id`` and ``host``.

        Returns:
            The ProviderToken described by ``token_value``.

        Raises:
            ValueError: If ``token_value`` is neither of the supported types.
        """
        if isinstance(token_value, cls):
            return token_value
        if not isinstance(token_value, dict):
            raise ValueError('Unsupported Provider token type')

        raw_token = token_value.get('token', '')
        # The key may be present but explicitly None; SecretStr cannot wrap
        # None, so fall back to the empty string.
        if raw_token is None:
            raw_token = ''  # type: ignore[unreachable]
        return cls(
            token=SecretStr(raw_token),
            user_id=token_value.get('user_id'),
            host=token_value.get('host'),
        )
class CustomSecret(BaseModel):
    # A user-defined secret value together with a free-text description.
    secret: SecretStr = Field(default_factory=lambda: SecretStr(''))
    description: str = Field(default='')

    model_config = ConfigDict(
        frozen=True,  # Makes the entire model immutable
        validate_assignment=True,
    )

    @classmethod
    def from_value(cls, secret_value: CustomSecret | dict[str, str]) -> CustomSecret:
        """Factory method to create a CustomSecret from various input types.

        Args:
            secret_value: Either a CustomSecret (returned unchanged) or a dict
                that may carry the keys ``secret`` and ``description``.

        Returns:
            The CustomSecret described by ``secret_value``.

        Raises:
            ValueError: If ``secret_value`` is neither of the supported types.
        """
        if isinstance(secret_value, CustomSecret):
            return secret_value
        if not isinstance(secret_value, dict):
            raise ValueError('Unsupported custom secret type')

        secret = secret_value.get('secret', '')
        description = secret_value.get('description', '')
        # ``.get`` defaults only cover a missing key. A key explicitly set to
        # None must be normalised too (same handling as ProviderToken):
        # SecretStr cannot wrap None and ``description`` is a plain ``str``.
        if secret is None:
            secret = ''  # type: ignore[unreachable]
        if description is None:
            description = ''  # type: ignore[unreachable]
        return cls(secret=SecretStr(secret), description=description)
# Mapping from a provider to the token configured for it. Typed as the abstract
# ``Mapping`` so that immutable mappings (ProviderHandler requires a
# MappingProxyType) satisfy the annotation.
PROVIDER_TOKEN_TYPE = Mapping[ProviderType, ProviderToken]
# Mapping from a string key (presumably the secret's name - confirm against
# callers) to the corresponding CustomSecret.
CUSTOM_SECRETS_TYPE = Mapping[str, CustomSecret]
class ProviderHandler:
    """Routes git-provider operations to the per-provider service classes.

    A handler is built from a mapping of provider tokens and instantiates the
    matching service implementation on demand for each operation.
    """

    # Default domain for each provider, used when a token carries no ``host``
    # override. BITBUCKET_DATA_CENTER has no entry here (presumably because it
    # is always self-hosted - confirm), so lookups for it must use ``.get``.
    PROVIDER_DOMAINS: dict[ProviderType, str] = {
        ProviderType.GITHUB: 'github.com',
        ProviderType.GITLAB: 'gitlab.com',
        ProviderType.BITBUCKET: 'bitbucket.org',
        ProviderType.FORGEJO: 'codeberg.org',
        ProviderType.AZURE_DEVOPS: 'dev.azure.com',
    }
def __init__(
    self,
    provider_tokens: PROVIDER_TOKEN_TYPE,
    external_auth_id: str | None = None,
    external_auth_token: SecretStr | None = None,
    external_token_manager: bool = False,
    session_api_key: str | None = None,
    sid: str | None = None,
):
    """Store the provider tokens and auth context used by every operation.

    Args:
        provider_tokens: Token per provider; must be a MappingProxyType.
        external_auth_id: Forwarded unchanged to each service instance.
        external_auth_token: Forwarded unchanged to each service instance.
        external_token_manager: Forwarded unchanged to each service instance.
        session_api_key: Sent as ``X-Session-API-Key`` when refreshing tokens.
        sid: Sent as the ``sid`` query parameter when refreshing tokens.

    Raises:
        TypeError: If ``provider_tokens`` is not a MappingProxyType.
    """
    if not isinstance(provider_tokens, MappingProxyType):
        received = type(provider_tokens).__name__
        raise TypeError(
            f'provider_tokens must be a MappingProxyType, got {received}'
        )

    self._provider_tokens = provider_tokens
    self.external_auth_id = external_auth_id
    self.external_auth_token = external_auth_token
    self.external_token_manager = external_token_manager
    self.session_api_key = session_api_key
    self.sid = sid

    # Service implementation to instantiate for each supported provider.
    self.service_class_map: dict[ProviderType, type[GitService]] = {
        ProviderType.GITHUB: GithubServiceImpl,
        ProviderType.GITLAB: GitLabServiceImpl,
        ProviderType.BITBUCKET: BitBucketServiceImpl,
        ProviderType.BITBUCKET_DATA_CENTER: BitbucketDCServiceImpl,
        ProviderType.FORGEJO: ForgejoServiceImpl,
        ProviderType.AZURE_DEVOPS: AzureDevOpsServiceImpl,
    }

    # The refresh endpoint only exists when WEB_HOST is configured.
    web_host = os.getenv('WEB_HOST', '').strip()
    if web_host:
        self.REFRESH_TOKEN_URL = f'https://{web_host}/api/refresh-tokens'
    else:
        self.REFRESH_TOKEN_URL = None
@property
def provider_tokens(self) -> PROVIDER_TOKEN_TYPE:
    """Expose the stored provider-token mapping; no setter is defined."""
    tokens = self._provider_tokens
    return tokens
def get_service(self, provider: ProviderType) -> GitService:
    """Instantiate the service implementation registered for ``provider``.

    Raises:
        KeyError: If no token or no service class exists for ``provider``.
    """
    provider_token = self.provider_tokens[provider]
    implementation = self.service_class_map[provider]
    init_kwargs = {
        'user_id': provider_token.user_id,
        'external_auth_id': self.external_auth_id,
        'external_auth_token': self.external_auth_token,
        'token': provider_token.token,
        'external_token_manager': self.external_token_manager,
        # The token's host override doubles as the service's base domain.
        'base_domain': provider_token.host,
    }
    return implementation(**init_kwargs)
async def get_user(self) -> User:
    """Return the user reported by the first provider that answers.

    Providers are tried in mapping order. Failures are only logged (as
    warnings, with traceback) once every provider has failed.

    Raises:
        AuthenticationError: If no provider returned a user.
    """
    failures: list[tuple[ProviderType, Exception]] = []
    for provider in self.provider_tokens:
        try:
            return await self.get_service(provider).get_user()
        except Exception as err:
            failures.append((provider, err))

    for failed_provider, failure in failures:
        logger.warning(
            f'Failed to get user from provider {failed_provider}: {failure}',
            exc_info=(type(failure), failure, failure.__traceback__),
        )
    raise AuthenticationError('Need valid provider token')
async def _get_latest_provider_token(
    self, provider: ProviderType
) -> SecretStr | None:
    """Fetch the current token for ``provider`` from the refresh endpoint.

    Returns:
        The refreshed token, or None when no refresh URL is configured or the
        request/response handling fails (the failure is logged, not raised).
    """
    refresh_url = self.REFRESH_TOKEN_URL
    if not refresh_url:
        logger.warning('Refresh token URL not set')
        return None

    try:
        async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
            # The session key header is only sent when a key is configured.
            request_headers = {}
            if self.session_api_key:
                request_headers = {'X-Session-API-Key': self.session_api_key}
            response = await client.get(
                refresh_url,
                headers=request_headers,
                params={'provider': provider.value, 'sid': self.sid},
            )
            response.raise_for_status()
            payload = TokenResponse.model_validate_json(response.text)
            return SecretStr(payload.token)
    except Exception as err:
        logger.error(
            f'Failed to fetch latest token for provider {provider}: {err}',
            exc_info=True,
        )
        return None
async def get_github_installations(self) -> list[str]:
    """Return the GitHub service's installations, or [] if the call fails."""
    github = cast(InstallationsService, self.get_service(ProviderType.GITHUB))
    try:
        installations = await github.get_installations()
    except Exception as err:
        logger.warning(f'Failed to get github installations {err}')
        return []
    return installations
async def get_bitbucket_workspaces(self) -> list[str]:
    """Return the Bitbucket service's installations, or [] if the call fails."""
    bitbucket = cast(
        InstallationsService, self.get_service(ProviderType.BITBUCKET)
    )
    try:
        workspaces = await bitbucket.get_installations()
    except Exception as err:
        logger.warning(f'Failed to get bitbucket workspaces {err}')
        return []
    return workspaces
async def get_bitbucket_dc_projects(self) -> list[str]:
    """Return the Bitbucket Data Center service's installations, or [] on failure."""
    data_center = self.get_service(ProviderType.BITBUCKET_DATA_CENTER)
    installations_api = cast(InstallationsService, data_center)
    try:
        projects = await installations_api.get_installations()
    except Exception as err:
        logger.warning(f'Failed to get bitbucket data center projects {err}')
        return []
    return projects
async def get_github_organizations(self) -> list[str]:
    """Return organizations derived from GitHub installations, or [] on failure."""
    github = self.get_service(ProviderType.GITHUB)
    try:
        organizations = await github.get_organizations_from_installations()  # type: ignore[attr-defined]
    except Exception as err:
        logger.warning(f'Failed to get github organizations {err}')
        return []
    return organizations
async def get_gitlab_groups(self) -> list[str]:
    """Return the GitLab service's user groups, or [] if the call fails."""
    gitlab = self.get_service(ProviderType.GITLAB)
    try:
        groups = await gitlab.get_user_groups()  # type: ignore[attr-defined]
    except Exception as err:
        logger.warning(f'Failed to get gitlab groups {err}')
        return []
    return groups
async def get_azure_devops_organizations(self) -> list[str]:
    """Return the Azure DevOps service's installations, or [] if the call fails."""
    azure = self.get_service(ProviderType.AZURE_DEVOPS)
    installations_api = cast(InstallationsService, azure)
    try:
        organizations = await installations_api.get_installations()
    except Exception as err:
        logger.warning(f'Failed to get azure devops organizations {err}')
        return []
    return organizations
async def get_repositories(
    self,
    sort: str,
    app_mode: AppMode,
    selected_provider: ProviderType | None,
    page: int | None,
    per_page: int | None,
    installation_id: str | None,
) -> list[Repository]:
    """List repositories from one provider (paginated) or from all of them.

    With ``selected_provider`` set, a single page is requested from that
    provider. Otherwise every configured provider is asked for all of its
    repositories and the results are concatenated in provider order.

    Raises:
        ValueError: If a provider is selected but ``page`` or ``per_page``
            is missing or zero.
        ProviderTimeoutError: If a timeout occurs while fetching repos.
    """
    if selected_provider:
        if not (page and per_page):
            raise ValueError('Failed to provider params for paginating repos')
        provider_service = self.get_service(selected_provider)
        return await provider_service.get_paginated_repos(
            page, per_page, sort, installation_id
        )

    collected: list[Repository] = []
    for provider in self.provider_tokens:
        try:
            provider_service = self.get_service(provider)
            collected.extend(
                await provider_service.get_all_repositories(sort, app_mode)
            )
        except ProviderTimeoutError:
            # Timeouts must reach the caller; only other failures are
            # downgraded to a warning.
            raise
        except Exception as err:
            logger.warning(f'Error fetching repos from {provider}: {err}')
    return collected
async def get_suggested_tasks(self) -> list[SuggestedTask]:
    """Collect suggested tasks from every configured provider.

    A provider that fails is logged as a warning and skipped, so the result
    holds the tasks of all providers that answered, in provider order.
    """
    tasks: list[SuggestedTask] = []
    for provider in self.provider_tokens:
        try:
            service = self.get_service(provider)
            provider_tasks = await service.get_suggested_tasks()
            tasks.extend(provider_tasks)
        except Exception as e:
            # The message names suggested tasks (not repos) so the warning
            # can be told apart from repository-listing failures.
            logger.warning(f'Error fetching suggested tasks from {provider}: {e}')
    return tasks
async def search_branches(
    self,
    selected_provider: ProviderType | None,
    repository: str,
    query: str,
    per_page: int = 30,
) -> list[Branch]:
    """Search a repository's branches through the matching provider service.

    When no provider is given, the provider is discovered by checking which
    configured provider can access ``repository``. Search failures are logged
    as warnings and produce an empty list.
    """
    if selected_provider:
        chosen = self.get_service(selected_provider)
        try:
            matches = await chosen.search_branches(repository, query, per_page)
        except Exception as err:
            logger.warning(
                f'Error searching branches from selected provider {selected_provider}: {err}'
            )
            return []
        return matches

    # No provider given: find the one that can actually see the repository.
    try:
        repo_details = await self.verify_repo_provider(repository)
        inferred = self.get_service(repo_details.git_provider)
        return await inferred.search_branches(repository, query, per_page)
    except Exception as err:
        logger.warning(f'Error searching branches for {repository}: {err}')
        return []
async def search_repositories(
    self,
    selected_provider: ProviderType | None,
    query: str,
    per_page: int,
    sort: str,
    order: str,
    app_mode: AppMode,
) -> list[Repository]:
    """Search repositories on one provider or across all configured ones.

    With ``selected_provider`` set, that provider's results are returned
    de-duplicated and any error propagates. Otherwise every provider is
    searched, failures are logged as warnings, and the raw results are
    concatenated in provider order.
    """

    async def _search_on(provider: ProviderType) -> list[Repository]:
        # A query that is a URL on the provider's host is searched as public.
        provider_service = self.get_service(provider)
        is_public = self._is_repository_url(query, provider)
        return await provider_service.search_repositories(
            query, per_page, sort, order, is_public, app_mode
        )

    if selected_provider:
        found = await _search_on(selected_provider)
        return self._deduplicate_repositories(found)

    merged: list[Repository] = []
    for provider in self.provider_tokens:
        try:
            found = await _search_on(provider)
            merged.extend(found)
        except Exception as err:
            logger.warning(f'Error searching repos from {provider}: {err}')
    return merged
def _is_repository_url(self, query: str, provider: ProviderType) -> bool:
    """Tell whether ``query`` looks like a repository URL on ``provider``.

    True only when the query starts with ``http://`` or ``https://`` and
    contains either the token's host override or the provider's default
    domain.
    """
    token_host = self.provider_tokens[provider].host
    fallback_domain = self.PROVIDER_DOMAINS.get(provider)

    mentions_token_host = True if (token_host and token_host in query) else False
    mentions_fallback = fallback_domain is not None and fallback_domain in query

    if not query.startswith(('http://', 'https://')):
        return False
    return mentions_token_host or mentions_fallback
def _deduplicate_repositories(self, repos: list[Repository]) -> list[Repository]:
    """Remove duplicate repositories based on full_name.

    The first occurrence of each ``full_name`` is kept and the input order
    is preserved.
    """
    seen_names: set[str] = set()
    unique_repos: list[Repository] = []
    for repo in repos:
        if repo.full_name in seen_names:
            continue
        # Record the same key that is tested above; recording ``repo.id``
        # instead would mean no duplicate is ever detected.
        seen_names.add(repo.full_name)
        unique_repos.append(repo)
    return unique_repos
@classmethod
def get_provider_env_key(cls, provider: ProviderType) -> str:
    """Return the runtime env-var name for ``provider``: ``<value>_token``, lowercased."""
    env_key = f'{provider.value}_token'
    return env_key.lower()
async def verify_repo_provider(
    self,
    repository: str,
    specified_provider: ProviderType | None = None,
    is_optional: bool = False,
) -> Repository:
    """Find the provider that can access ``repository`` and return its details.

    The specified provider (if any) is tried first, then every configured
    provider in mapping order. The first successful lookup wins.

    Args:
        repository: Repository name to look up.
        specified_provider: Provider to try before all others.
        is_optional: Log the final failure at debug instead of error level.

    Raises:
        AuthenticationError: If no provider could access the repository.
    """
    attempts = [specified_provider] if specified_provider else []
    attempts.extend(self.provider_tokens)

    failures = []
    for candidate in attempts:
        try:
            candidate_service = self.get_service(candidate)
            return await candidate_service.get_repository_details_from_repo_name(
                repository
            )
        except Exception as err:
            failures.append(f'{candidate.value}: {str(err)}')

    # Optional repositories (like org-level microagents) are expected to be
    # missing sometimes, so their failure is only worth a debug line.
    report = logger.debug if is_optional else logger.error
    if not self.provider_tokens:
        report(
            f'Failed to access repository {repository}: No provider tokens available. '
            f'provider_tokens dict is empty.'
        )
    elif failures:
        report(
            f'Failed to access repository {repository} with all available providers. '
            f'Tried providers: {list(self.provider_tokens.keys())}. '
            f'Errors: {"; ".join(failures)}'
        )
    else:
        report(
            f'Failed to access repository {repository}: Unknown error (no providers tried, no errors recorded)'
        )
    raise AuthenticationError(f'Unable to access repo {repository}')
async def get_branches(
    self,
    repository: str,
    specified_provider: ProviderType | None = None,
    page: int = 1,
    per_page: int = 30,
) -> PaginatedBranchesResponse:
    """Fetch one page of branches for a repository.

    The specified provider (if any) is tried first, then every configured
    provider in mapping order; each failure is logged as a warning.

    Args:
        repository: The repository name
        specified_provider: Optional provider type to try first
        page: Page number for pagination (default: 1)
        per_page: Number of branches per page (default: 30)

    Returns:
        The first successful paginated response, or an empty response for
        the requested page when no provider succeeds.
    """
    candidates = [specified_provider] if specified_provider else []
    candidates.extend(self.provider_tokens)

    for candidate in candidates:
        try:
            candidate_service = self.get_service(candidate)
            return await candidate_service.get_paginated_branches(
                repository, page, per_page
            )
        except Exception as err:
            logger.warning(f'Error fetching branches from {candidate}: {err}')

    # Nothing worked: answer with an empty page rather than raising.
    return PaginatedBranchesResponse(
        branches=[],
        has_next_page=False,
        current_page=page,
        per_page=per_page,
        total_count=0,
    )
async def get_authenticated_git_url(
    self, repo_name: str, is_optional: bool = False
) -> str:
    """Get an authenticated git URL for a repository.

    The provider is discovered via ``verify_repo_provider``; the URL embeds
    the provider token in the form each provider expects.

    Args:
        repo_name: Repository name (owner/repo)
        is_optional: If True, logs at debug level instead of error level when repo not found

    Returns:
        Authenticated git URL if credentials are available, otherwise regular HTTPS URL

    Raises:
        Exception: If no configured provider can access the repository.
        ValueError: If the provider host is an ``http://`` URL and
            ALLOW_INSECURE_GIT_ACCESS is not enabled.
    """
    try:
        repository = await self.verify_repo_provider(
            repo_name, is_optional=is_optional
        )
    except AuthenticationError as err:
        raise Exception(
            'Git provider authentication issue when getting remote URL'
        ) from err

    provider = repository.git_provider
    repo_name = repository.full_name
    has_provider_token = bool(
        self.provider_tokens and provider in self.provider_tokens
    )

    domain = self.PROVIDER_DOMAINS.get(provider, '')
    # If provider tokens are provided, use the host from the token if available
    # Note: For Azure DevOps, don't use the host field as it may contain org/project path
    if has_provider_token and provider != ProviderType.AZURE_DEVOPS:
        domain = self.provider_tokens[provider].host or domain

    # Detect protocol before normalizing domain
    # Default to https, but preserve http if explicitly specified
    protocol = 'https'
    if domain and domain.strip().startswith('http://'):
        # Check if insecure HTTP access is allowed
        allow_insecure = os.environ.get(
            'ALLOW_INSECURE_GIT_ACCESS', 'false'
        ).lower() in ('true', '1', 'yes')
        if not allow_insecure:
            raise ValueError(
                'Attempting to connect to an insecure git repository over HTTP. '
                "If you'd like to allow this nonetheless, set "
                'ALLOW_INSECURE_GIT_ACCESS=true as an environment variable.'
            )
        protocol = 'http'

    # Normalize domain to prevent double protocols or path segments
    if domain:
        domain = domain.strip()
        domain = domain.replace('https://', '').replace('http://', '')
        # Remove any trailing path like /api/v3 or /api/v4
        if '/' in domain:
            domain = domain.split('/')[0]

    # Try to use token if available, otherwise use public URL
    git_token = self.provider_tokens[provider].token if has_provider_token else None
    if not git_token:
        return f'{protocol}://{domain}/{repo_name}.git'

    token_value = git_token.get_secret_value()
    if provider == ProviderType.GITLAB:
        remote_url = f'{protocol}://oauth2:{token_value}@{domain}/{repo_name}.git'
    elif provider == ProviderType.BITBUCKET:
        # For Bitbucket, handle username:app_password format
        if ':' in token_value:
            # App token format: username:app_password
            remote_url = f'{protocol}://{token_value}@{domain}/{repo_name}.git'
        else:
            # Access token format: use x-token-auth
            remote_url = (
                f'{protocol}://x-token-auth:{token_value}@{domain}/{repo_name}.git'
            )
    elif provider == ProviderType.BITBUCKET_DATA_CENTER:
        # DC uses HTTP Basic auth - token must be in username:token format
        project, repo_slug = (
            repo_name.split('/', 1) if '/' in repo_name else (repo_name, repo_name)
        )
        scm_path = f'scm/{project.lower()}/{repo_slug}.git'
        # Percent-encode each credential part so special characters
        # (e.g. @, #, /) don't break the URL.
        if ':' in token_value:
            dc_user, dc_pass = token_value.split(':', 1)
            url_creds = f'{quote(dc_user, safe="")}:{quote(dc_pass, safe="")}'
        else:
            url_creds = f'x-token-auth:{quote(token_value, safe="")}'
        remote_url = f'{protocol}://{url_creds}@{domain}/{scm_path}'
    elif provider == ProviderType.AZURE_DEVOPS:
        # Azure DevOps uses PAT with Basic auth
        # Format: https://{anything}:{PAT}@dev.azure.com/{org}/{project}/_git/{repo}
        # The username can be anything (it's ignored), but cannot be empty
        # We use the org name as the username for clarity
        # repo_name is in format: org/project/repo
        logger.info(
            f'[Azure DevOps] Constructing authenticated git URL for repository: {repo_name}'
        )
        logger.debug(f'[Azure DevOps] Original domain: {domain}')
        logger.debug(
            f'[Azure DevOps] Token available: {bool(token_value)}, '
            f'Token length: {len(token_value) if token_value else 0}'
        )

        # Remove domain prefix if it exists in domain variable
        clean_domain = domain.replace('https://', '').replace('http://', '')
        logger.debug(f'[Azure DevOps] Cleaned domain: {clean_domain}')

        parts = repo_name.split('/')
        logger.debug(
            f'[Azure DevOps] Repository parts: {parts} (length: {len(parts)})'
        )
        if len(parts) >= 3:
            org, project, repo = parts[0], parts[1], parts[2]
            logger.info(
                f'[Azure DevOps] Parsed repository - org: {org}, project: {project}, repo: {repo}'
            )
            # URL-encode org, project, and repo to handle spaces and special characters
            org_encoded = quote(org, safe='')
            project_encoded = quote(project, safe='')
            repo_encoded = quote(repo, safe='')
            logger.debug(
                f'[Azure DevOps] URL-encoded parts - org: {org_encoded}, project: {project_encoded}, repo: {repo_encoded}'
            )

            azure_path = (
                f'{clean_domain}/{org_encoded}/{project_encoded}/_git/{repo_encoded}'
            )
            # Use org name as username (it's ignored by Azure DevOps but required for git).
            # The encoded form is required here as well: a raw org containing a
            # space, '@' or ':' would corrupt the userinfo part of the URL.
            logger.info(
                f'[Azure DevOps] Constructed git URL (token masked): '
                f'https://{org_encoded}:***@{azure_path}'
            )
            remote_url = f'https://{org_encoded}:{token_value}@{azure_path}'
        else:
            # Fallback if format is unexpected
            logger.warning(
                f'[Azure DevOps] Unexpected repository format: {repo_name}. '
                f'Expected org/project/repo (3 parts), got {len(parts)} parts. '
                'Using fallback URL format.'
            )
            remote_url = f'https://user:{token_value}@{clean_domain}/{repo_name}.git'
            logger.warning(
                f'[Azure DevOps] Fallback URL constructed (token masked): '
                f'https://user:***@{clean_domain}/{repo_name}.git'
            )
    else:
        # GitHub, Forgejo
        remote_url = f'{protocol}://{token_value}@{domain}/{repo_name}.git'
    return remote_url