diff --git a/openhands/integrations/provider.py b/openhands/integrations/provider.py index 5b4c229768..c0bff80581 100644 --- a/openhands/integrations/provider.py +++ b/openhands/integrations/provider.py @@ -1,8 +1,10 @@ from __future__ import annotations +import os from types import MappingProxyType from typing import Annotated, Any, Coroutine, Literal, cast, overload +import httpx from pydantic import ( BaseModel, ConfigDict, @@ -28,6 +30,7 @@ from openhands.integrations.service_types import ( Repository, ResourceNotFoundError, SuggestedTask, + TokenResponse, User, ) from openhands.microagent.types import MicroagentContentResponse, MicroagentResponse @@ -112,6 +115,8 @@ class ProviderHandler: external_auth_id: str | None = None, external_auth_token: SecretStr | None = None, external_token_manager: bool = False, + session_api_key: str | None = None, + sid: str | None = None, ): if not isinstance(provider_tokens, MappingProxyType): raise TypeError( @@ -127,7 +132,13 @@ class ProviderHandler: self.external_auth_id = external_auth_id self.external_auth_token = external_auth_token self.external_token_manager = external_token_manager + self.session_api_key = session_api_key + self.sid = sid self._provider_tokens = provider_tokens + WEB_HOST = os.getenv('WEB_HOST', '').strip() + self.REFRESH_TOKEN_URL = ( + f'https://{WEB_HOST}/api/refresh-tokens' if WEB_HOST else None + ) @property def provider_tokens(self) -> PROVIDER_TOKEN_TYPE: @@ -161,8 +172,24 @@ class ProviderHandler: self, provider: ProviderType ) -> SecretStr | None: """Get latest token from service""" - service = self._get_service(provider) - return await service.get_latest_token() + try: + async with httpx.AsyncClient() as client: + resp = await client.get( + self.REFRESH_TOKEN_URL, + headers={ + 'X-Session-API-Key': self.session_api_key, + }, + params={'provider': provider.value, 'sid': self.sid}, + ) + + resp.raise_for_status() + data = TokenResponse.model_validate_json(resp.text) + return SecretStr(data.token) + + except Exception as e: + logger.warning(f'Failed to fetch latest token for provider {provider}: {e}') + + return None async def get_github_installations(self) -> list[str]: service = cast(InstallationsService, self._get_service(ProviderType.GITHUB)) @@ -356,7 +383,7 @@ class ProviderHandler: else SecretStr('') ) - if get_latest: + if get_latest and self.REFRESH_TOKEN_URL and self.sid: token = await self._get_latest_provider_token(provider) if token: diff --git a/openhands/integrations/service_types.py b/openhands/integrations/service_types.py index bd8091c088..8020c88e97 100644 --- a/openhands/integrations/service_types.py +++ b/openhands/integrations/service_types.py @@ -14,6 +14,10 @@ from openhands.microagent.types import MicroagentContentResponse, MicroagentResp from openhands.server.types import AppMode +class TokenResponse(BaseModel): + token: str + + class ProviderType(Enum): GITHUB = 'github' GITLAB = 'gitlab' diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py index e25b5636c3..3a377d11b0 100644 --- a/openhands/runtime/base.py +++ b/openhands/runtime/base.py @@ -323,9 +323,6 @@ class Runtime(FileEditRuntimeMixin): async def _export_latest_git_provider_tokens(self, event: Action) -> None: """Refresh runtime provider tokens when agent attemps to run action with provider token""" - if not self.user_id: - return - providers_called = ProviderHandler.check_cmd_action_for_provider_token_ref( event ) @@ -333,8 +330,17 @@ class Runtime(FileEditRuntimeMixin): if not providers_called: return + provider_handler = ProviderHandler( + provider_tokens=self.git_provider_tokens + or cast(PROVIDER_TOKEN_TYPE, MappingProxyType({})), + external_auth_id=self.user_id, + external_token_manager=True, + session_api_key=self.session_api_key, + sid=self.sid, + ) + logger.info(f'Fetching latest provider tokens for runtime: {self.sid}') - env_vars = await self.provider_handler.get_env_vars( + env_vars = await provider_handler.get_env_vars( providers=providers_called, expose_secrets=False, get_latest=True ) @@ -343,10 +349,10 @@ class Runtime(FileEditRuntimeMixin): try: if self.event_stream: - await self.provider_handler.set_event_stream_secrets( + await provider_handler.set_event_stream_secrets( self.event_stream, env_vars=env_vars ) - self.add_env_vars(self.provider_handler.expose_env_vars(env_vars)) + self.add_env_vars(provider_handler.expose_env_vars(env_vars)) except Exception as e: logger.warning( f'Failed export latest github token to runtime: {self.sid}, {e}' diff --git a/tests/unit/runtime/test_runtime_git_tokens.py b/tests/unit/runtime/test_runtime_git_tokens.py index aed15fd777..e1c4bb0613 100644 --- a/tests/unit/runtime/test_runtime_git_tokens.py +++ b/tests/unit/runtime/test_runtime_git_tokens.py @@ -1,5 +1,5 @@ from types import MappingProxyType -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from pydantic import SecretStr @@ -213,22 +213,28 @@ async def test_export_latest_git_provider_tokens_multiple_refs(temp_dir): @pytest.mark.asyncio -async def test_export_latest_git_provider_tokens_token_update(runtime): +async def test_export_latest_git_provider_tokens_token_update(runtime, monkeypatch): """Test that token updates are handled correctly""" # First export with initial token cmd = CmdRunAction(command='echo $GITHUB_TOKEN') await runtime._export_latest_git_provider_tokens(cmd) - # Update the token + # Ensure refresh-token flow is enabled in ProviderHandler + monkeypatch.setenv('WEB_HOST', 'example.com') + + # Simulate that provider handler will now fetch a new token from refresh endpoint new_token = 'new_test_token' - runtime.provider_handler._provider_tokens = MappingProxyType( - {ProviderType.GITHUB: ProviderToken(token=SecretStr(new_token))} - ) - # Export again with updated token - await runtime._export_latest_git_provider_tokens(cmd) + # Patch ProviderHandler._get_latest_provider_token to return new SecretStr + with patch.object( + ProviderHandler, + '_get_latest_provider_token', + new=AsyncMock(return_value=SecretStr(new_token)), + ): + # Export again with updated token – runtime should fetch latest and update EventStream secrets + await runtime._export_latest_git_provider_tokens(cmd) - # Verify that the new token was exported + # Verify that the new token was exported to the event stream assert runtime.event_stream.secrets == {'github_token': new_token}