mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
refactor: introduce HTTPClient protocol for git service integrations (#10731)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import Any
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.protocols.http_client import HTTPClient
|
||||
from openhands.integrations.service_types import (
|
||||
BaseGitService,
|
||||
OwnerType,
|
||||
@@ -15,14 +16,12 @@ from openhands.integrations.service_types import (
|
||||
)
|
||||
|
||||
|
||||
class BitBucketMixinBase(BaseGitService):
|
||||
class BitBucketMixinBase(BaseGitService, HTTPClient):
|
||||
"""
|
||||
Base mixin for BitBucket service containing common functionality
|
||||
"""
|
||||
|
||||
BASE_URL = 'https://api.bitbucket.org/2.0'
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def _extract_owner_and_repo(self, repository: str) -> tuple[str, str]:
|
||||
"""Extract owner and repo from repository string.
|
||||
@@ -49,7 +48,7 @@ class BitBucketMixinBase(BaseGitService):
|
||||
def _has_token_expired(self, status_code: int) -> bool:
|
||||
return status_code == 401
|
||||
|
||||
async def _get_bitbucket_headers(self) -> dict[str, str]:
|
||||
async def _get_headers(self) -> dict[str, str]:
|
||||
"""Get headers for Bitbucket API requests."""
|
||||
token_value = self.token.get_secret_value()
|
||||
|
||||
@@ -85,13 +84,13 @@ class BitBucketMixinBase(BaseGitService):
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
bitbucket_headers = await self._get_bitbucket_headers()
|
||||
bitbucket_headers = await self._get_headers()
|
||||
response = await self.execute_request(
|
||||
client, url, bitbucket_headers, params, method
|
||||
)
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
bitbucket_headers = await self._get_bitbucket_headers()
|
||||
bitbucket_headers = await self._get_headers()
|
||||
response = await self.execute_request(
|
||||
client=client,
|
||||
url=url,
|
||||
|
||||
@@ -43,8 +43,6 @@ class GitHubService(
|
||||
|
||||
BASE_URL = 'https://api.github.com'
|
||||
GRAPHQL_URL = 'https://api.github.com/graphql'
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# openhands/integrations/github/service/__init__.py
|
||||
|
||||
from .base import GitHubMixinBase
|
||||
from .branches_prs import GitHubBranchesMixin
|
||||
from .features import GitHubFeaturesMixin
|
||||
from .prs import GitHubPRsMixin
|
||||
@@ -7,6 +8,7 @@ from .repos import GitHubReposMixin
|
||||
from .resolver import GitHubResolverMixin
|
||||
|
||||
__all__ = [
|
||||
'GitHubMixinBase',
|
||||
'GitHubBranchesMixin',
|
||||
'GitHubFeaturesMixin',
|
||||
'GitHubPRsMixin',
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, cast
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.protocols.http_client import HTTPClient
|
||||
from openhands.integrations.service_types import (
|
||||
BaseGitService,
|
||||
RequestMethod,
|
||||
@@ -12,19 +13,15 @@ from openhands.integrations.service_types import (
|
||||
)
|
||||
|
||||
|
||||
class GitHubMixinBase(BaseGitService):
|
||||
class GitHubMixinBase(BaseGitService, HTTPClient):
|
||||
"""
|
||||
Declares common attributes and method signatures used across mixins.
|
||||
"""
|
||||
|
||||
BASE_URL: str
|
||||
GRAPHQL_URL: str
|
||||
token: SecretStr
|
||||
refresh: bool
|
||||
external_auth_id: str | None
|
||||
base_domain: str | None
|
||||
|
||||
async def _get_github_headers(self) -> dict:
|
||||
async def _get_headers(self) -> dict:
|
||||
"""Retrieve the GH Token from settings store to construct the headers."""
|
||||
if not self.token:
|
||||
latest_token = await self.get_latest_token()
|
||||
@@ -47,7 +44,7 @@ class GitHubMixinBase(BaseGitService):
|
||||
) -> tuple[Any, dict]: # type: ignore[override]
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
github_headers = await self._get_github_headers()
|
||||
github_headers = await self._get_headers()
|
||||
|
||||
# Make initial request
|
||||
response = await self.execute_request(
|
||||
@@ -61,7 +58,7 @@ class GitHubMixinBase(BaseGitService):
|
||||
# Handle token refresh if needed
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
github_headers = await self._get_github_headers()
|
||||
github_headers = await self._get_headers()
|
||||
response = await self.execute_request(
|
||||
client=client,
|
||||
url=url,
|
||||
@@ -87,7 +84,7 @@ class GitHubMixinBase(BaseGitService):
|
||||
) -> dict[str, Any]:
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
github_headers = await self._get_github_headers()
|
||||
github_headers = await self._get_headers()
|
||||
|
||||
response = await client.post(
|
||||
self.GRAPHQL_URL,
|
||||
|
||||
@@ -41,8 +41,6 @@ class GitLabService(
|
||||
|
||||
BASE_URL = 'https://gitlab.com/api/v4'
|
||||
GRAPHQL_URL = 'https://gitlab.com/api/graphql'
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# openhands/integrations/gitlab/service/__init__.py
|
||||
|
||||
from .base import GitLabMixinBase
|
||||
from .branches import GitLabBranchesMixin
|
||||
from .features import GitLabFeaturesMixin
|
||||
from .prs import GitLabPRsMixin
|
||||
@@ -7,6 +8,7 @@ from .repos import GitLabReposMixin
|
||||
from .resolver import GitLabResolverMixin
|
||||
|
||||
__all__ = [
|
||||
'GitLabMixinBase',
|
||||
'GitLabBranchesMixin',
|
||||
'GitLabFeaturesMixin',
|
||||
'GitLabPRsMixin',
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Any
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.protocols.http_client import HTTPClient
|
||||
from openhands.integrations.service_types import (
|
||||
BaseGitService,
|
||||
RequestMethod,
|
||||
@@ -11,19 +12,15 @@ from openhands.integrations.service_types import (
|
||||
)
|
||||
|
||||
|
||||
class GitLabMixinBase(BaseGitService):
|
||||
class GitLabMixinBase(BaseGitService, HTTPClient):
|
||||
"""
|
||||
Declares common attributes and method signatures used across mixins.
|
||||
"""
|
||||
|
||||
BASE_URL: str
|
||||
GRAPHQL_URL: str
|
||||
token: SecretStr
|
||||
refresh: bool
|
||||
external_auth_id: str | None
|
||||
base_domain: str | None
|
||||
|
||||
async def _get_gitlab_headers(self) -> dict[str, Any]:
|
||||
async def _get_headers(self) -> dict[str, Any]:
|
||||
"""Retrieve the GitLab Token to construct the headers"""
|
||||
if not self.token:
|
||||
latest_token = await self.get_latest_token()
|
||||
@@ -45,7 +42,7 @@ class GitLabMixinBase(BaseGitService):
|
||||
) -> tuple[Any, dict]: # type: ignore[override]
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
gitlab_headers = await self._get_headers()
|
||||
|
||||
# Make initial request
|
||||
response = await self.execute_request(
|
||||
@@ -59,7 +56,7 @@ class GitLabMixinBase(BaseGitService):
|
||||
# Handle token refresh if needed
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
gitlab_headers = await self._get_headers()
|
||||
response = await self.execute_request(
|
||||
client=client,
|
||||
url=url,
|
||||
@@ -103,7 +100,7 @@ class GitLabMixinBase(BaseGitService):
|
||||
variables = {}
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
gitlab_headers = await self._get_headers()
|
||||
# Add content type header for GraphQL
|
||||
gitlab_headers['Content-Type'] = 'application/json'
|
||||
|
||||
@@ -118,7 +115,7 @@ class GitLabMixinBase(BaseGitService):
|
||||
|
||||
if self.refresh and self._has_token_expired(response.status_code):
|
||||
await self.get_latest_token()
|
||||
gitlab_headers = await self._get_gitlab_headers()
|
||||
gitlab_headers = await self._get_headers()
|
||||
gitlab_headers['Content-Type'] = 'application/json'
|
||||
response = await client.post(
|
||||
self.GRAPHQL_URL, headers=gitlab_headers, json=payload
|
||||
|
||||
99
openhands/integrations/protocols/http_client.py
Normal file
99
openhands/integrations/protocols/http_client.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""HTTP Client Protocol for Git Service Integrations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from httpx import AsyncClient, HTTPError, HTTPStatusError
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
RateLimitError,
|
||||
RequestMethod,
|
||||
ResourceNotFoundError,
|
||||
UnknownException,
|
||||
)
|
||||
|
||||
|
||||
class HTTPClient(ABC):
|
||||
"""Abstract base class defining the HTTP client interface for Git service integrations.
|
||||
|
||||
This class abstracts the common HTTP client functionality needed by all
|
||||
Git service providers (GitHub, GitLab, BitBucket) while keeping inheritance in place.
|
||||
"""
|
||||
|
||||
# Default attributes (subclasses may override)
|
||||
token: SecretStr = SecretStr('')
|
||||
refresh: bool = False
|
||||
external_auth_id: str | None = None
|
||||
external_auth_token: SecretStr | None = None
|
||||
external_token_manager: bool = False
|
||||
base_domain: str | None = None
|
||||
|
||||
# Provider identification must be implemented by subclasses
|
||||
@property
|
||||
@abstractmethod
|
||||
def provider(self) -> str: ...
|
||||
|
||||
# Abstract methods that concrete classes must implement
|
||||
@abstractmethod
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
"""Get the latest working token for the service."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def _get_headers(self) -> dict[str, Any]:
|
||||
"""Get HTTP headers for API requests."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def _make_request(
|
||||
self,
|
||||
url: str,
|
||||
params: dict | None = None,
|
||||
method: RequestMethod = RequestMethod.GET,
|
||||
) -> tuple[Any, dict]:
|
||||
"""Make an HTTP request to the Git service API."""
|
||||
...
|
||||
|
||||
def _has_token_expired(self, status_code: int) -> bool:
|
||||
"""Check if the token has expired based on HTTP status code."""
|
||||
return status_code == 401
|
||||
|
||||
async def execute_request(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
url: str,
|
||||
headers: dict,
|
||||
params: dict | None,
|
||||
method: RequestMethod = RequestMethod.GET,
|
||||
):
|
||||
"""Execute an HTTP request using the provided client."""
|
||||
if method == RequestMethod.POST:
|
||||
return await client.post(url, headers=headers, json=params)
|
||||
return await client.get(url, headers=headers, params=params)
|
||||
|
||||
def handle_http_status_error(
|
||||
self, e: HTTPStatusError
|
||||
) -> (
|
||||
AuthenticationError | RateLimitError | ResourceNotFoundError | UnknownException
|
||||
):
|
||||
"""Handle HTTP status errors and convert them to appropriate exceptions."""
|
||||
if e.response.status_code == 401:
|
||||
return AuthenticationError(f'Invalid {self.provider} token')
|
||||
elif e.response.status_code == 404:
|
||||
return ResourceNotFoundError(
|
||||
f'Resource not found on {self.provider} API: {e}'
|
||||
)
|
||||
elif e.response.status_code == 429:
|
||||
logger.warning(f'Rate limit exceeded on {self.provider} API: {e}')
|
||||
return RateLimitError(f'{self.provider} API rate limit exceeded')
|
||||
|
||||
logger.warning(f'Status error on {self.provider} API: {e}')
|
||||
return UnknownException(f'Unknown error: {e}')
|
||||
|
||||
def handle_http_error(self, e: HTTPError) -> UnknownException:
|
||||
"""Handle general HTTP errors."""
|
||||
logger.warning(f'HTTP error on {self.provider} API: {type(e).__name__} : {e}')
|
||||
return UnknownException(f'HTTP error {type(e).__name__} : {e}')
|
||||
@@ -4,7 +4,6 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
from httpx import AsyncClient, HTTPError, HTTPStatusError
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
@@ -242,40 +241,6 @@ class BaseGitService(ABC):
|
||||
"""Extract file path from directory item."""
|
||||
...
|
||||
|
||||
async def execute_request(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
url: str,
|
||||
headers: dict,
|
||||
params: dict | None,
|
||||
method: RequestMethod = RequestMethod.GET,
|
||||
):
|
||||
if method == RequestMethod.POST:
|
||||
return await client.post(url, headers=headers, json=params)
|
||||
return await client.get(url, headers=headers, params=params)
|
||||
|
||||
def handle_http_status_error(
|
||||
self, e: HTTPStatusError
|
||||
) -> (
|
||||
AuthenticationError | RateLimitError | ResourceNotFoundError | UnknownException
|
||||
):
|
||||
if e.response.status_code == 401:
|
||||
return AuthenticationError(f'Invalid {self.provider} token')
|
||||
elif e.response.status_code == 404:
|
||||
return ResourceNotFoundError(
|
||||
f'Resource not found on {self.provider} API: {e}'
|
||||
)
|
||||
elif e.response.status_code == 429:
|
||||
logger.warning(f'Rate limit exceeded on {self.provider} API: {e}')
|
||||
return RateLimitError('GitHub API rate limit exceeded')
|
||||
|
||||
logger.warning(f'Status error on {self.provider} API: {e}')
|
||||
return UnknownException(f'Unknown error: {e}')
|
||||
|
||||
def handle_http_error(self, e: HTTPError) -> UnknownException:
|
||||
logger.warning(f'HTTP error on {self.provider} API: {type(e).__name__} : {e}')
|
||||
return UnknownException(f'HTTP error {type(e).__name__} : {e}')
|
||||
|
||||
def _determine_microagents_path(self, repository_name: str) -> str:
|
||||
"""Determine the microagents directory path based on repository name."""
|
||||
actual_repo_name = repository_name.split('/')[-1]
|
||||
@@ -462,9 +427,6 @@ class BaseGitService(ABC):
|
||||
return comment_body[:max_comment_length] + '...'
|
||||
return comment_body
|
||||
|
||||
def _has_token_expired(self, status_code: int) -> bool:
|
||||
return status_code == 401
|
||||
|
||||
|
||||
class InstallationsService(Protocol):
|
||||
async def get_installations(self) -> list[str]:
|
||||
|
||||
Reference in New Issue
Block a user