[Fix]: token refresh for nested runtimes (#10637)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra
2025-08-27 12:20:34 -04:00
committed by GitHub
parent e68abf8d75
commit 36e0d8d3da
4 changed files with 61 additions and 18 deletions

View File

@@ -1,8 +1,10 @@
from __future__ import annotations from __future__ import annotations
import os
from types import MappingProxyType from types import MappingProxyType
from typing import Annotated, Any, Coroutine, Literal, cast, overload from typing import Annotated, Any, Coroutine, Literal, cast, overload
import httpx
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
ConfigDict, ConfigDict,
@@ -28,6 +30,7 @@ from openhands.integrations.service_types import (
Repository, Repository,
ResourceNotFoundError, ResourceNotFoundError,
SuggestedTask, SuggestedTask,
TokenResponse,
User, User,
) )
from openhands.microagent.types import MicroagentContentResponse, MicroagentResponse from openhands.microagent.types import MicroagentContentResponse, MicroagentResponse
@@ -112,6 +115,8 @@ class ProviderHandler:
external_auth_id: str | None = None, external_auth_id: str | None = None,
external_auth_token: SecretStr | None = None, external_auth_token: SecretStr | None = None,
external_token_manager: bool = False, external_token_manager: bool = False,
session_api_key: str | None = None,
sid: str | None = None,
): ):
if not isinstance(provider_tokens, MappingProxyType): if not isinstance(provider_tokens, MappingProxyType):
raise TypeError( raise TypeError(
@@ -127,7 +132,13 @@ class ProviderHandler:
self.external_auth_id = external_auth_id self.external_auth_id = external_auth_id
self.external_auth_token = external_auth_token self.external_auth_token = external_auth_token
self.external_token_manager = external_token_manager self.external_token_manager = external_token_manager
self.session_api_key = session_api_key
self.sid = sid
self._provider_tokens = provider_tokens self._provider_tokens = provider_tokens
WEB_HOST = os.getenv('WEB_HOST', '').strip()
self.REFRESH_TOKEN_URL = (
f'https://{WEB_HOST}/api/refresh-tokens' if WEB_HOST else None
)
@property @property
def provider_tokens(self) -> PROVIDER_TOKEN_TYPE: def provider_tokens(self) -> PROVIDER_TOKEN_TYPE:
@@ -161,8 +172,24 @@ class ProviderHandler:
self, provider: ProviderType self, provider: ProviderType
) -> SecretStr | None: ) -> SecretStr | None:
"""Get latest token from service""" """Get latest token from service"""
service = self._get_service(provider) try:
return await service.get_latest_token() async with httpx.AsyncClient() as client:
resp = await client.get(
self.REFRESH_TOKEN_URL,
headers={
'X-Session-API-Key': self.session_api_key,
},
params={'provider': provider.value, 'sid': self.sid},
)
resp.raise_for_status()
data = TokenResponse.model_validate_json(resp.text)
return SecretStr(data.token)
except Exception as e:
logger.warning(f'Failed to fetch latest token for provider {provider}: {e}')
return None
async def get_github_installations(self) -> list[str]: async def get_github_installations(self) -> list[str]:
service = cast(InstallationsService, self._get_service(ProviderType.GITHUB)) service = cast(InstallationsService, self._get_service(ProviderType.GITHUB))
@@ -356,7 +383,7 @@ class ProviderHandler:
else SecretStr('') else SecretStr('')
) )
if get_latest: if get_latest and self.REFRESH_TOKEN_URL and self.sid:
token = await self._get_latest_provider_token(provider) token = await self._get_latest_provider_token(provider)
if token: if token:

View File

@@ -14,6 +14,10 @@ from openhands.microagent.types import MicroagentContentResponse, MicroagentResp
from openhands.server.types import AppMode from openhands.server.types import AppMode
class TokenResponse(BaseModel):
token: str
class ProviderType(Enum): class ProviderType(Enum):
GITHUB = 'github' GITHUB = 'github'
GITLAB = 'gitlab' GITLAB = 'gitlab'

View File

@@ -323,9 +323,6 @@ class Runtime(FileEditRuntimeMixin):
async def _export_latest_git_provider_tokens(self, event: Action) -> None: async def _export_latest_git_provider_tokens(self, event: Action) -> None:
"""Refresh runtime provider tokens when agent attemps to run action with provider token""" """Refresh runtime provider tokens when agent attemps to run action with provider token"""
if not self.user_id:
return
providers_called = ProviderHandler.check_cmd_action_for_provider_token_ref( providers_called = ProviderHandler.check_cmd_action_for_provider_token_ref(
event event
) )
@@ -333,8 +330,17 @@ class Runtime(FileEditRuntimeMixin):
if not providers_called: if not providers_called:
return return
provider_handler = ProviderHandler(
provider_tokens=self.git_provider_tokens
or cast(PROVIDER_TOKEN_TYPE, MappingProxyType({})),
external_auth_id=self.user_id,
external_token_manager=True,
session_api_key=self.session_api_key,
sid=self.sid,
)
logger.info(f'Fetching latest provider tokens for runtime: {self.sid}') logger.info(f'Fetching latest provider tokens for runtime: {self.sid}')
env_vars = await self.provider_handler.get_env_vars( env_vars = await provider_handler.get_env_vars(
providers=providers_called, expose_secrets=False, get_latest=True providers=providers_called, expose_secrets=False, get_latest=True
) )
@@ -343,10 +349,10 @@ class Runtime(FileEditRuntimeMixin):
try: try:
if self.event_stream: if self.event_stream:
await self.provider_handler.set_event_stream_secrets( await provider_handler.set_event_stream_secrets(
self.event_stream, env_vars=env_vars self.event_stream, env_vars=env_vars
) )
self.add_env_vars(self.provider_handler.expose_env_vars(env_vars)) self.add_env_vars(provider_handler.expose_env_vars(env_vars))
except Exception as e: except Exception as e:
logger.warning( logger.warning(
f'Failed export latest github token to runtime: {self.sid}, {e}' f'Failed export latest github token to runtime: {self.sid}, {e}'

View File

@@ -1,5 +1,5 @@
from types import MappingProxyType from types import MappingProxyType
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from pydantic import SecretStr from pydantic import SecretStr
@@ -213,22 +213,28 @@ async def test_export_latest_git_provider_tokens_multiple_refs(temp_dir):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_export_latest_git_provider_tokens_token_update(runtime): async def test_export_latest_git_provider_tokens_token_update(runtime, monkeypatch):
"""Test that token updates are handled correctly""" """Test that token updates are handled correctly"""
# First export with initial token # First export with initial token
cmd = CmdRunAction(command='echo $GITHUB_TOKEN') cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
await runtime._export_latest_git_provider_tokens(cmd) await runtime._export_latest_git_provider_tokens(cmd)
# Update the token # Ensure refresh-token flow is enabled in ProviderHandler
monkeypatch.setenv('WEB_HOST', 'example.com')
# Simulate that provider handler will now fetch a new token from refresh endpoint
new_token = 'new_test_token' new_token = 'new_test_token'
runtime.provider_handler._provider_tokens = MappingProxyType(
{ProviderType.GITHUB: ProviderToken(token=SecretStr(new_token))}
)
# Export again with updated token # Patch ProviderHandler._get_latest_provider_token to return new SecretStr
await runtime._export_latest_git_provider_tokens(cmd) with patch.object(
ProviderHandler,
'_get_latest_provider_token',
new=AsyncMock(return_value=SecretStr(new_token)),
):
# Export again with updated token runtime should fetch latest and update EventStream secrets
await runtime._export_latest_git_provider_tokens(cmd)
# Verify that the new token was exported # Verify that the new token was exported to the event stream
assert runtime.event_stream.secrets == {'github_token': new_token} assert runtime.event_stream.secrets == {'github_token': new_token}