from __future__ import annotations

import os
from collections.abc import Mapping
from types import MappingProxyType
from typing import cast
from urllib.parse import quote

import httpx
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    SecretStr,
)

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.azure_devops.azure_devops_service import (
    AzureDevOpsServiceImpl,
)
from openhands.integrations.bitbucket.bitbucket_service import BitBucketServiceImpl
from openhands.integrations.bitbucket_data_center.bitbucket_dc_service import (
    BitbucketDCServiceImpl,
)
from openhands.integrations.forgejo.forgejo_service import ForgejoServiceImpl
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.service_types import (
    AuthenticationError,
    Branch,
    GitService,
    InstallationsService,
    PaginatedBranchesResponse,
    ProviderTimeoutError,
    ProviderType,
    Repository,
    SuggestedTask,
    TokenResponse,
    User,
)
from openhands.server.types import AppMode
from openhands.utils.http_session import httpx_verify_option


class ProviderToken(BaseModel):
    """Immutable record of the credentials held for one git provider.

    Holds the secret token, an optional user id and an optional host
    (used by ``ProviderHandler`` as the service's ``base_domain``).
    """

    token: SecretStr | None = Field(default=None)
    user_id: str | None = Field(default=None)
    host: str | None = Field(default=None)

    model_config = ConfigDict(
        frozen=True,  # Makes the entire model immutable
        validate_assignment=True,
    )

    @classmethod
    def from_value(cls, token_value: ProviderToken | dict[str, str]) -> ProviderToken:
        """Factory method to create a ProviderToken from various input types.

        Args:
            token_value: Either an existing ``ProviderToken`` (returned as-is)
                or a dict with optional ``token``, ``user_id`` and ``host`` keys.

        Returns:
            The resulting ``ProviderToken``.

        Raises:
            ValueError: If ``token_value`` is neither a ``ProviderToken`` nor a dict.
        """
        if isinstance(token_value, cls):
            return token_value
        elif isinstance(token_value, dict):
            token_str = token_value.get('token', '')
            # Override with empty string if it was set to None
            # Cannot pass None to SecretStr
            if token_str is None:
                token_str = ''  # type: ignore[unreachable]
            user_id = token_value.get('user_id')
            host = token_value.get('host')
            return cls(token=SecretStr(token_str), user_id=user_id, host=host)

        else:
            raise ValueError('Unsupported Provider token type')


class CustomSecret(BaseModel):
    """Immutable record of a user-defined secret and its description."""

    secret: SecretStr = Field(default_factory=lambda: SecretStr(''))
    description: str = Field(default='')

    model_config = ConfigDict(
        frozen=True,  # Makes the entire model immutable
        validate_assignment=True,
    )

    @classmethod
    def from_value(cls, secret_value: CustomSecret | dict[str, str]) -> CustomSecret:
        """Factory method to create a CustomSecret from various input types.

        Args:
            secret_value: Either an existing ``CustomSecret`` (returned as-is)
                or a dict with optional ``secret`` and ``description`` keys
                (both default to the empty string when absent).

        Returns:
            The resulting ``CustomSecret``.

        Raises:
            ValueError: If ``secret_value`` is neither a ``CustomSecret`` nor a dict.
        """
        if isinstance(secret_value, CustomSecret):
            return secret_value
        elif isinstance(secret_value, dict):
            secret = secret_value.get('secret', '')
            description = secret_value.get('description', '')
            return cls(secret=SecretStr(secret), description=description)

        else:
            raise ValueError('Unsupport Provider token type')


PROVIDER_TOKEN_TYPE = Mapping[ProviderType, ProviderToken]
CUSTOM_SECRETS_TYPE = Mapping[str, CustomSecret]


class ProviderHandler:
    """Dispatches git operations to the service of each configured provider.

    The handler is built from a read-only mapping of provider tokens and
    instantiates the matching ``GitService`` implementation on demand.
    """

    # Class variable for provider domains
    PROVIDER_DOMAINS: dict[ProviderType, str] = {
        ProviderType.GITHUB: 'github.com',
        ProviderType.GITLAB: 'gitlab.com',
        ProviderType.BITBUCKET: 'bitbucket.org',
        ProviderType.FORGEJO: 'codeberg.org',
        ProviderType.AZURE_DEVOPS: 'dev.azure.com',
    }

    def __init__(
        self,
        provider_tokens: PROVIDER_TOKEN_TYPE,
        external_auth_id: str | None = None,
        external_auth_token: SecretStr | None = None,
        external_token_manager: bool = False,
        session_api_key: str | None = None,
        sid: str | None = None,
    ):
        """Store credentials and build the provider -> service class map.

        Args:
            provider_tokens: Read-only mapping of provider to token; must be a
                ``MappingProxyType``.
            external_auth_id: Forwarded to every service that is instantiated.
            external_auth_token: Forwarded to every service that is instantiated.
            external_token_manager: Forwarded to every service that is instantiated.
            session_api_key: Sent as the ``X-Session-API-Key`` header when
                refreshing tokens, if set.
            sid: Sent as the ``sid`` query parameter when refreshing tokens.

        Raises:
            TypeError: If ``provider_tokens`` is not a ``MappingProxyType``.
        """
        if not isinstance(provider_tokens, MappingProxyType):
            raise TypeError(
                f'provider_tokens must be a MappingProxyType, got {type(provider_tokens).__name__}'
            )

        self.service_class_map: dict[ProviderType, type[GitService]] = {
            ProviderType.GITHUB: GithubServiceImpl,
            ProviderType.GITLAB: GitLabServiceImpl,
            ProviderType.BITBUCKET: BitBucketServiceImpl,
            ProviderType.BITBUCKET_DATA_CENTER: BitbucketDCServiceImpl,
            ProviderType.FORGEJO: ForgejoServiceImpl,
            ProviderType.AZURE_DEVOPS: AzureDevOpsServiceImpl,
        }

        self.external_auth_id = external_auth_id
        self.external_auth_token = external_auth_token
        self.external_token_manager = external_token_manager
        self.session_api_key = session_api_key
        self.sid = sid
        self._provider_tokens = provider_tokens

        # The refresh endpoint is only available when WEB_HOST is configured.
        web_host = os.getenv('WEB_HOST', '').strip()
        self.REFRESH_TOKEN_URL = (
            f'https://{web_host}/api/refresh-tokens' if web_host else None
        )

    @property
    def provider_tokens(self) -> PROVIDER_TOKEN_TYPE:
        """Read-only access to provider tokens."""
        return self._provider_tokens

    def get_service(self, provider: ProviderType) -> GitService:
        """Helper method to instantiate a service for a given provider.

        Raises:
            KeyError: If no token or no service class is registered for
                ``provider``.
        """
        token = self.provider_tokens[provider]
        service_class = self.service_class_map[provider]
        return service_class(
            user_id=token.user_id,
            external_auth_id=self.external_auth_id,
            external_auth_token=self.external_auth_token,
            token=token.token,
            external_token_manager=self.external_token_manager,
            base_domain=token.host,
        )

    async def get_user(self) -> User:
        """Get user information from the first available provider.

        Providers are tried in mapping order; failures are collected and only
        logged once every provider has failed.

        Raises:
            AuthenticationError: If no provider returned a user.
        """
        exceptions: list[tuple[ProviderType, Exception]] = []
        for provider in self.provider_tokens:
            try:
                service = self.get_service(provider)
                return await service.get_user()
            except Exception as e:
                exceptions.append((provider, e))
                continue

        for provider, exc in exceptions:
            logger.warning(
                f'Failed to get user from provider {provider}: {exc}',
                exc_info=(type(exc), exc, exc.__traceback__),
            )
        raise AuthenticationError('Need valid provider token')

    async def _get_latest_provider_token(
        self, provider: ProviderType
    ) -> SecretStr | None:
        """Get latest token from service.

        Returns:
            The refreshed token, or ``None`` when no refresh URL is configured
            or the request / response parsing fails (the failure is logged).
        """
        if not self.REFRESH_TOKEN_URL:
            logger.warning('Refresh token URL not set')
            return None

        try:
            async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
                headers = (
                    {'X-Session-API-Key': self.session_api_key}
                    if self.session_api_key
                    else {}
                )
                resp = await client.get(
                    self.REFRESH_TOKEN_URL,
                    headers=headers,
                    params={'provider': provider.value, 'sid': self.sid},
                )

            resp.raise_for_status()
            data = TokenResponse.model_validate_json(resp.text)
            return SecretStr(data.token)

        except Exception as e:
            logger.error(
                f'Failed to fetch latest token for provider {provider}: {e}',
                exc_info=True,
            )

        return None

    async def get_github_installations(self) -> list[str]:
        """Return GitHub installations, or an empty list if the call fails."""
        service = cast(InstallationsService, self.get_service(ProviderType.GITHUB))
        try:
            return await service.get_installations()
        except Exception as e:
            logger.warning(f'Failed to get github installations {e}')

        return []

    async def get_bitbucket_workspaces(self) -> list[str]:
        """Return Bitbucket workspaces, or an empty list if the call fails."""
        service = cast(InstallationsService, self.get_service(ProviderType.BITBUCKET))
        try:
            return await service.get_installations()
        except Exception as e:
            logger.warning(f'Failed to get bitbucket workspaces {e}')

        return []

    async def get_bitbucket_dc_projects(self) -> list[str]:
        """Return Bitbucket Data Center projects, or an empty list on failure."""
        service = cast(
            InstallationsService,
            self.get_service(ProviderType.BITBUCKET_DATA_CENTER),
        )
        try:
            return await service.get_installations()
        except Exception as e:
            logger.warning(f'Failed to get bitbucket data center projects {e}')

        return []

    async def get_github_organizations(self) -> list[str]:
        """Return GitHub organizations, or an empty list if the call fails."""
        service = self.get_service(ProviderType.GITHUB)
        try:
            return await service.get_organizations_from_installations()  # type: ignore[attr-defined]
        except Exception as e:
            logger.warning(f'Failed to get github organizations {e}')

        return []

    async def get_gitlab_groups(self) -> list[str]:
        """Return the user's GitLab groups, or an empty list if the call fails."""
        service = self.get_service(ProviderType.GITLAB)
        try:
            return await service.get_user_groups()  # type: ignore[attr-defined]
        except Exception as e:
            logger.warning(f'Failed to get gitlab groups {e}')

        return []

    async def get_azure_devops_organizations(self) -> list[str]:
        """Return Azure DevOps organizations, or an empty list on failure."""
        service = cast(
            InstallationsService, self.get_service(ProviderType.AZURE_DEVOPS)
        )
        try:
            return await service.get_installations()
        except Exception as e:
            logger.warning(f'Failed to get azure devops organizations {e}')

        return []

    async def get_repositories(
        self,
        sort: str,
        app_mode: AppMode,
        selected_provider: ProviderType | None,
        page: int | None,
        per_page: int | None,
        installation_id: str | None,
    ) -> list[Repository]:
        """Get repositories from providers.

        With ``selected_provider`` set, a single paginated page is fetched from
        that provider. Otherwise all repositories of every configured provider
        are collected; providers that fail (other than by timeout) are logged
        and skipped.

        Raises:
            ValueError: If ``selected_provider`` is set but ``page`` or
                ``per_page`` is missing (or zero).
            ProviderTimeoutError: If a timeout occurs while fetching repos.
        """
        if selected_provider:
            if not page or not per_page:
                raise ValueError('Failed to provider params for paginating repos')

            service = self.get_service(selected_provider)
            return await service.get_paginated_repos(
                page, per_page, sort, installation_id
            )

        all_repos: list[Repository] = []
        for provider in self.provider_tokens:
            try:
                service = self.get_service(provider)
                service_repos = await service.get_all_repositories(sort, app_mode)
                all_repos.extend(service_repos)
            except ProviderTimeoutError:
                # Propagate timeout errors so callers can handle them appropriately
                raise
            except Exception as e:
                logger.warning(f'Error fetching repos from {provider}: {e}')

        return all_repos

    async def get_suggested_tasks(self) -> list[SuggestedTask]:
        """Get suggested tasks from providers.

        Providers that raise are logged and skipped, so the result may be
        partial or empty.
        """
        tasks: list[SuggestedTask] = []
        for provider in self.provider_tokens:
            try:
                service = self.get_service(provider)
                service_tasks = await service.get_suggested_tasks()
                tasks.extend(service_tasks)
            except Exception as e:
                logger.warning(f'Error fetching suggested tasks from {provider}: {e}')

        return tasks

    async def search_branches(
        self,
        selected_provider: ProviderType | None,
        repository: str,
        query: str,
        per_page: int = 30,
    ) -> list[Branch]:
        """Search for branches within a repository using the appropriate provider service.

        Returns an empty list (after logging a warning) when the search fails
        or the repository's provider cannot be determined.
        """
        if selected_provider:
            service = self.get_service(selected_provider)
            try:
                return await service.search_branches(repository, query, per_page)
            except Exception as e:
                logger.warning(
                    f'Error searching branches from selected provider {selected_provider}: {e}'
                )
                return []

        # If provider not specified, determine provider by verifying repository access
        try:
            repo_details = await self.verify_repo_provider(repository)
            service = self.get_service(repo_details.git_provider)
            return await service.search_branches(repository, query, per_page)
        except Exception as e:
            logger.warning(f'Error searching branches for {repository}: {e}')
            return []

    async def search_repositories(
        self,
        selected_provider: ProviderType | None,
        query: str,
        per_page: int,
        sort: str,
        order: str,
        app_mode: AppMode,
    ) -> list[Repository]:
        """Search repositories on one provider or across all configured ones.

        With ``selected_provider`` set, errors from the service propagate and
        the result is de-duplicated. Otherwise every provider is queried,
        failures are logged and skipped, and results are concatenated.
        """
        if selected_provider:
            service = self.get_service(selected_provider)
            public = self._is_repository_url(query, selected_provider)
            user_repos = await service.search_repositories(
                query, per_page, sort, order, public, app_mode
            )
            return self._deduplicate_repositories(user_repos)

        all_repos: list[Repository] = []
        for provider in self.provider_tokens:
            try:
                service = self.get_service(provider)
                public = self._is_repository_url(query, provider)
                service_repos = await service.search_repositories(
                    query, per_page, sort, order, public, app_mode
                )
                all_repos.extend(service_repos)
            except Exception as e:
                logger.warning(f'Error searching repos from {provider}: {e}')
                continue

        return all_repos

    def _is_repository_url(self, query: str, provider: ProviderType) -> bool:
        """Check if the query is a repository URL.

        True when ``query`` starts with ``http://`` or ``https://`` and
        contains either the provider token's custom host or the provider's
        default domain.
        """
        custom_host = self.provider_tokens[provider].host
        custom_host_exists = bool(custom_host and custom_host in query)
        default_domain = self.PROVIDER_DOMAINS.get(provider)
        default_host_exists = default_domain is not None and default_domain in query

        return query.startswith(('http://', 'https://')) and (
            custom_host_exists or default_host_exists
        )

    def _deduplicate_repositories(self, repos: list[Repository]) -> list[Repository]:
        """Remove duplicate repositories based on full_name.

        The first occurrence of each ``full_name`` is kept and input order is
        preserved.
        """
        seen: set[str] = set()
        unique_repos: list[Repository] = []
        for repo in repos:
            if repo.full_name not in seen:
                # Record the same key that is tested above; recording a
                # different attribute would let every duplicate through.
                seen.add(repo.full_name)
                unique_repos.append(repo)
        return unique_repos

    @classmethod
    def get_provider_env_key(cls, provider: ProviderType) -> str:
        """Map ProviderType value to the environment variable name in the runtime"""
        return f'{provider.value}_token'.lower()

    async def verify_repo_provider(
        self,
        repository: str,
        specified_provider: ProviderType | None = None,
        is_optional: bool = False,
    ) -> Repository:
        """Find the provider that can access ``repository`` and return its details.

        ``specified_provider`` (if given) is tried first, then every configured
        provider in mapping order.

        Args:
            repository: The repository name to look up.
            specified_provider: Optional provider to try before the others.
            is_optional: If True, the final failure is logged at debug level
                instead of error level.

        Raises:
            AuthenticationError: If no provider could return the repository.
        """
        errors = []

        if specified_provider:
            try:
                service = self.get_service(specified_provider)
                return await service.get_repository_details_from_repo_name(repository)
            except Exception as e:
                errors.append(f'{specified_provider.value}: {str(e)}')

        for provider in self.provider_tokens:
            try:
                service = self.get_service(provider)
                return await service.get_repository_details_from_repo_name(repository)
            except Exception as e:
                errors.append(f'{provider.value}: {str(e)}')

        # Log detailed error based on whether we had tokens or not
        # For optional repositories (like org-level microagents), use debug level
        log_fn = logger.debug if is_optional else logger.error
        if not self.provider_tokens:
            log_fn(
                f'Failed to access repository {repository}: No provider tokens available. '
                f'provider_tokens dict is empty.'
            )
        elif errors:
            log_fn(
                f'Failed to access repository {repository} with all available providers. '
                f'Tried providers: {list(self.provider_tokens.keys())}. '
                f'Errors: {"; ".join(errors)}'
            )
        else:
            log_fn(
                f'Failed to access repository {repository}: Unknown error (no providers tried, no errors recorded)'
            )
        raise AuthenticationError(f'Unable to access repo {repository}')

    async def get_branches(
        self,
        repository: str,
        specified_provider: ProviderType | None = None,
        page: int = 1,
        per_page: int = 30,
    ) -> PaginatedBranchesResponse:
        """Get branches for a repository

        Args:
            repository: The repository name
            specified_provider: Optional provider type to use
            page: Page number for pagination (default: 1)
            per_page: Number of branches per page (default: 30)

        Returns:
            A paginated response with branches for the repository; an empty
            response when no provider succeeded.
        """
        if specified_provider:
            try:
                service = self.get_service(specified_provider)
                return await service.get_paginated_branches(repository, page, per_page)
            except Exception as e:
                logger.warning(
                    f'Error fetching branches from {specified_provider}: {e}'
                )

        for provider in self.provider_tokens:
            try:
                service = self.get_service(provider)
                return await service.get_paginated_branches(repository, page, per_page)
            except Exception as e:
                logger.warning(f'Error fetching branches from {provider}: {e}')

        # Return empty response if no provider worked
        return PaginatedBranchesResponse(
            branches=[],
            has_next_page=False,
            current_page=page,
            per_page=per_page,
            total_count=0,
        )

    async def get_authenticated_git_url(
        self, repo_name: str, is_optional: bool = False
    ) -> str:
        """Get an authenticated git URL for a repository.

        Args:
            repo_name: Repository name (owner/repo)
            is_optional: If True, logs at debug level instead of error level when repo not found

        Returns:
            Authenticated git URL if credentials are available, otherwise regular HTTPS URL

        Raises:
            Exception: If no provider can access the repository.
            ValueError: If the configured host uses ``http://`` and
                ``ALLOW_INSECURE_GIT_ACCESS`` is not enabled.
        """
        try:
            repository = await self.verify_repo_provider(
                repo_name, is_optional=is_optional
            )
        except AuthenticationError as err:
            # Keep the original failure attached as the cause for debugging.
            raise Exception(
                'Git provider authentication issue when getting remote URL'
            ) from err

        provider = repository.git_provider
        repo_name = repository.full_name

        domain = self.PROVIDER_DOMAINS.get(provider, '')

        # If provider tokens are provided, use the host from the token if available
        # Note: For Azure DevOps, don't use the host field as it may contain org/project path
        if self.provider_tokens and provider in self.provider_tokens:
            if provider != ProviderType.AZURE_DEVOPS:
                domain = self.provider_tokens[provider].host or domain

        # Detect protocol before normalizing domain
        # Default to https, but preserve http if explicitly specified
        protocol = 'https'
        if domain and domain.strip().startswith('http://'):
            # Check if insecure HTTP access is allowed
            allow_insecure = os.environ.get(
                'ALLOW_INSECURE_GIT_ACCESS', 'false'
            ).lower() in ('true', '1', 'yes')
            if not allow_insecure:
                raise ValueError(
                    'Attempting to connect to an insecure git repository over HTTP. '
                    "If you'd like to allow this nonetheless, set "
                    'ALLOW_INSECURE_GIT_ACCESS=true as an environment variable.'
                )
            protocol = 'http'

        # Normalize domain to prevent double protocols or path segments
        if domain:
            domain = domain.strip()
            domain = domain.replace('https://', '').replace('http://', '')
            # Remove any trailing path like /api/v3 or /api/v4
            if '/' in domain:
                domain = domain.split('/')[0]

        # Try to use token if available, otherwise use public URL
        if self.provider_tokens and provider in self.provider_tokens:
            git_token = self.provider_tokens[provider].token
            if git_token:
                token_value = git_token.get_secret_value()
                if provider == ProviderType.GITLAB:
                    remote_url = (
                        f'{protocol}://oauth2:{token_value}@{domain}/{repo_name}.git'
                    )
                elif provider == ProviderType.BITBUCKET:
                    # For Bitbucket, handle username:app_password format
                    if ':' in token_value:
                        # App token format: username:app_password
                        remote_url = (
                            f'{protocol}://{token_value}@{domain}/{repo_name}.git'
                        )
                    else:
                        # Access token format: use x-token-auth
                        remote_url = f'{protocol}://x-token-auth:{token_value}@{domain}/{repo_name}.git'
                elif provider == ProviderType.BITBUCKET_DATA_CENTER:
                    # DC uses HTTP Basic auth — token must be in username:token format
                    project, repo_slug = (
                        repo_name.split('/', 1)
                        if '/' in repo_name
                        else (repo_name, repo_name)
                    )
                    scm_path = f'scm/{project.lower()}/{repo_slug}.git'
                    # Percent-encode each credential part so special characters
                    # (e.g. @, #, /) don't break the URL.
                    if ':' in token_value:
                        dc_user, dc_pass = token_value.split(':', 1)
                        url_creds = (
                            f'{quote(dc_user, safe="")}:{quote(dc_pass, safe="")}'
                        )
                    else:
                        url_creds = f'x-token-auth:{quote(token_value, safe="")}'
                    remote_url = f'{protocol}://{url_creds}@{domain}/{scm_path}'
                elif provider == ProviderType.AZURE_DEVOPS:
                    # Azure DevOps uses PAT with Basic auth
                    # Format: https://{anything}:{PAT}@dev.azure.com/{org}/{project}/_git/{repo}
                    # The username can be anything (it's ignored), but cannot be empty
                    # We use the org name as the username for clarity
                    # repo_name is in format: org/project/repo
                    logger.info(
                        f'[Azure DevOps] Constructing authenticated git URL for repository: {repo_name}'
                    )
                    logger.debug(f'[Azure DevOps] Original domain: {domain}')
                    logger.debug(
                        f'[Azure DevOps] Token available: {bool(token_value)}, '
                        f'Token length: {len(token_value) if token_value else 0}'
                    )

                    # Remove domain prefix if it exists in domain variable
                    clean_domain = domain.replace('https://', '').replace('http://', '')
                    logger.debug(f'[Azure DevOps] Cleaned domain: {clean_domain}')

                    parts = repo_name.split('/')
                    logger.debug(
                        f'[Azure DevOps] Repository parts: {parts} (length: {len(parts)})'
                    )

                    if len(parts) >= 3:
                        org, project, repo = parts[0], parts[1], parts[2]
                        logger.info(
                            f'[Azure DevOps] Parsed repository - org: {org}, project: {project}, repo: {repo}'
                        )

                        # URL-encode org, project, and repo to handle spaces and special characters
                        org_encoded = quote(org, safe='')
                        project_encoded = quote(project, safe='')
                        repo_encoded = quote(repo, safe='')
                        logger.debug(
                            f'[Azure DevOps] URL-encoded parts - org: {org_encoded}, project: {project_encoded}, repo: {repo_encoded}'
                        )

                        # Use org name as username (it's ignored by Azure DevOps but required for git)
                        remote_url = f'https://{org}:***@{clean_domain}/{org_encoded}/{project_encoded}/_git/{repo_encoded}'
                        logger.info(
                            f'[Azure DevOps] Constructed git URL (token masked): {remote_url}'
                        )

                        # Set the actual URL with token
                        remote_url = f'https://{org}:{token_value}@{clean_domain}/{org_encoded}/{project_encoded}/_git/{repo_encoded}'
                    else:
                        # Fallback if format is unexpected
                        logger.warning(
                            f'[Azure DevOps] Unexpected repository format: {repo_name}. '
                            f'Expected org/project/repo (3 parts), got {len(parts)} parts. '
                            'Using fallback URL format.'
                        )
                        remote_url = (
                            f'https://user:{token_value}@{clean_domain}/{repo_name}.git'
                        )
                        logger.warning(
                            f'[Azure DevOps] Fallback URL constructed (token masked): '
                            f'https://user:***@{clean_domain}/{repo_name}.git'
                        )
                else:
                    # GitHub, Forgejo
                    remote_url = f'{protocol}://{token_value}@{domain}/{repo_name}.git'
            else:
                remote_url = f'{protocol}://{domain}/{repo_name}.git'
        else:
            remote_url = f'{protocol}://{domain}/{repo_name}.git'

        return remote_url