[Fix]: token refresh for nested runtimes (#10637)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Rohit Malhotra
2025-08-27 12:20:34 -04:00
committed by GitHub
parent e68abf8d75
commit 36e0d8d3da
4 changed files with 61 additions and 18 deletions

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import os
from types import MappingProxyType
from typing import Annotated, Any, Coroutine, Literal, cast, overload
import httpx
from pydantic import (
BaseModel,
ConfigDict,
@@ -28,6 +30,7 @@ from openhands.integrations.service_types import (
Repository,
ResourceNotFoundError,
SuggestedTask,
TokenResponse,
User,
)
from openhands.microagent.types import MicroagentContentResponse, MicroagentResponse
@@ -112,6 +115,8 @@ class ProviderHandler:
external_auth_id: str | None = None,
external_auth_token: SecretStr | None = None,
external_token_manager: bool = False,
session_api_key: str | None = None,
sid: str | None = None,
):
if not isinstance(provider_tokens, MappingProxyType):
raise TypeError(
@@ -127,7 +132,13 @@ class ProviderHandler:
self.external_auth_id = external_auth_id
self.external_auth_token = external_auth_token
self.external_token_manager = external_token_manager
self.session_api_key = session_api_key
self.sid = sid
self._provider_tokens = provider_tokens
WEB_HOST = os.getenv('WEB_HOST', '').strip()
self.REFRESH_TOKEN_URL = (
f'https://{WEB_HOST}/api/refresh-tokens' if WEB_HOST else None
)
@property
def provider_tokens(self) -> PROVIDER_TOKEN_TYPE:
@@ -161,8 +172,24 @@ class ProviderHandler:
self, provider: ProviderType
) -> SecretStr | None:
"""Get latest token from service"""
service = self._get_service(provider)
return await service.get_latest_token()
try:
async with httpx.AsyncClient() as client:
resp = await client.get(
self.REFRESH_TOKEN_URL,
headers={
'X-Session-API-Key': self.session_api_key,
},
params={'provider': provider.value, 'sid': self.sid},
)
resp.raise_for_status()
data = TokenResponse.model_validate_json(resp.text)
return SecretStr(data.token)
except Exception as e:
logger.warning(f'Failed to fetch latest token for provider {provider}: {e}')
return None
async def get_github_installations(self) -> list[str]:
service = cast(InstallationsService, self._get_service(ProviderType.GITHUB))
@@ -356,7 +383,7 @@ class ProviderHandler:
else SecretStr('')
)
if get_latest:
if get_latest and self.REFRESH_TOKEN_URL and self.sid:
token = await self._get_latest_provider_token(provider)
if token:

View File

@@ -14,6 +14,10 @@ from openhands.microagent.types import MicroagentContentResponse, MicroagentResp
from openhands.server.types import AppMode
class TokenResponse(BaseModel):
token: str
class ProviderType(Enum):
GITHUB = 'github'
GITLAB = 'gitlab'

View File

@@ -323,9 +323,6 @@ class Runtime(FileEditRuntimeMixin):
async def _export_latest_git_provider_tokens(self, event: Action) -> None:
"""Refresh runtime provider tokens when agent attemps to run action with provider token"""
if not self.user_id:
return
providers_called = ProviderHandler.check_cmd_action_for_provider_token_ref(
event
)
@@ -333,8 +330,17 @@ class Runtime(FileEditRuntimeMixin):
if not providers_called:
return
provider_handler = ProviderHandler(
provider_tokens=self.git_provider_tokens
or cast(PROVIDER_TOKEN_TYPE, MappingProxyType({})),
external_auth_id=self.user_id,
external_token_manager=True,
session_api_key=self.session_api_key,
sid=self.sid,
)
logger.info(f'Fetching latest provider tokens for runtime: {self.sid}')
env_vars = await self.provider_handler.get_env_vars(
env_vars = await provider_handler.get_env_vars(
providers=providers_called, expose_secrets=False, get_latest=True
)
@@ -343,10 +349,10 @@ class Runtime(FileEditRuntimeMixin):
try:
if self.event_stream:
await self.provider_handler.set_event_stream_secrets(
await provider_handler.set_event_stream_secrets(
self.event_stream, env_vars=env_vars
)
self.add_env_vars(self.provider_handler.expose_env_vars(env_vars))
self.add_env_vars(provider_handler.expose_env_vars(env_vars))
except Exception as e:
logger.warning(
f'Failed export latest github token to runtime: {self.sid}, {e}'

View File

@@ -1,5 +1,5 @@
from types import MappingProxyType
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
@@ -213,22 +213,28 @@ async def test_export_latest_git_provider_tokens_multiple_refs(temp_dir):
@pytest.mark.asyncio
async def test_export_latest_git_provider_tokens_token_update(runtime):
async def test_export_latest_git_provider_tokens_token_update(runtime, monkeypatch):
"""Test that token updates are handled correctly"""
# First export with initial token
cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
await runtime._export_latest_git_provider_tokens(cmd)
# Update the token
# Ensure refresh-token flow is enabled in ProviderHandler
monkeypatch.setenv('WEB_HOST', 'example.com')
# Simulate that provider handler will now fetch a new token from refresh endpoint
new_token = 'new_test_token'
runtime.provider_handler._provider_tokens = MappingProxyType(
{ProviderType.GITHUB: ProviderToken(token=SecretStr(new_token))}
)
# Export again with updated token
await runtime._export_latest_git_provider_tokens(cmd)
# Patch ProviderHandler._get_latest_provider_token to return new SecretStr
with patch.object(
ProviderHandler,
'_get_latest_provider_token',
new=AsyncMock(return_value=SecretStr(new_token)),
):
# Export again with updated token runtime should fetch latest and update EventStream secrets
await runtime._export_latest_git_provider_tokens(cmd)
# Verify that the new token was exported
# Verify that the new token was exported to the event stream
assert runtime.event_stream.secrets == {'github_token': new_token}