mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-08 22:38:05 -05:00
[Fix]: token refresh for nested runtimes (#10637)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from types import MappingProxyType
|
||||
from typing import Annotated, Any, Coroutine, Literal, cast, overload
|
||||
|
||||
import httpx
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@@ -28,6 +30,7 @@ from openhands.integrations.service_types import (
|
||||
Repository,
|
||||
ResourceNotFoundError,
|
||||
SuggestedTask,
|
||||
TokenResponse,
|
||||
User,
|
||||
)
|
||||
from openhands.microagent.types import MicroagentContentResponse, MicroagentResponse
|
||||
@@ -112,6 +115,8 @@ class ProviderHandler:
|
||||
external_auth_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
session_api_key: str | None = None,
|
||||
sid: str | None = None,
|
||||
):
|
||||
if not isinstance(provider_tokens, MappingProxyType):
|
||||
raise TypeError(
|
||||
@@ -127,7 +132,13 @@ class ProviderHandler:
|
||||
self.external_auth_id = external_auth_id
|
||||
self.external_auth_token = external_auth_token
|
||||
self.external_token_manager = external_token_manager
|
||||
self.session_api_key = session_api_key
|
||||
self.sid = sid
|
||||
self._provider_tokens = provider_tokens
|
||||
WEB_HOST = os.getenv('WEB_HOST', '').strip()
|
||||
self.REFRESH_TOKEN_URL = (
|
||||
f'https://{WEB_HOST}/api/refresh-tokens' if WEB_HOST else None
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_tokens(self) -> PROVIDER_TOKEN_TYPE:
|
||||
@@ -161,8 +172,24 @@ class ProviderHandler:
|
||||
self, provider: ProviderType
|
||||
) -> SecretStr | None:
|
||||
"""Get latest token from service"""
|
||||
service = self._get_service(provider)
|
||||
return await service.get_latest_token()
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
self.REFRESH_TOKEN_URL,
|
||||
headers={
|
||||
'X-Session-API-Key': self.session_api_key,
|
||||
},
|
||||
params={'provider': provider.value, 'sid': self.sid},
|
||||
)
|
||||
|
||||
resp.raise_for_status()
|
||||
data = TokenResponse.model_validate_json(resp.text)
|
||||
return SecretStr(data.token)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f'Failed to fetch latest token for provider {provider}: {e}')
|
||||
|
||||
return None
|
||||
|
||||
async def get_github_installations(self) -> list[str]:
|
||||
service = cast(InstallationsService, self._get_service(ProviderType.GITHUB))
|
||||
@@ -356,7 +383,7 @@ class ProviderHandler:
|
||||
else SecretStr('')
|
||||
)
|
||||
|
||||
if get_latest:
|
||||
if get_latest and self.REFRESH_TOKEN_URL and self.sid:
|
||||
token = await self._get_latest_provider_token(provider)
|
||||
|
||||
if token:
|
||||
|
||||
@@ -14,6 +14,10 @@ from openhands.microagent.types import MicroagentContentResponse, MicroagentResp
|
||||
from openhands.server.types import AppMode
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
GITHUB = 'github'
|
||||
GITLAB = 'gitlab'
|
||||
|
||||
@@ -323,9 +323,6 @@ class Runtime(FileEditRuntimeMixin):
|
||||
|
||||
async def _export_latest_git_provider_tokens(self, event: Action) -> None:
|
||||
"""Refresh runtime provider tokens when agent attemps to run action with provider token"""
|
||||
if not self.user_id:
|
||||
return
|
||||
|
||||
providers_called = ProviderHandler.check_cmd_action_for_provider_token_ref(
|
||||
event
|
||||
)
|
||||
@@ -333,8 +330,17 @@ class Runtime(FileEditRuntimeMixin):
|
||||
if not providers_called:
|
||||
return
|
||||
|
||||
provider_handler = ProviderHandler(
|
||||
provider_tokens=self.git_provider_tokens
|
||||
or cast(PROVIDER_TOKEN_TYPE, MappingProxyType({})),
|
||||
external_auth_id=self.user_id,
|
||||
external_token_manager=True,
|
||||
session_api_key=self.session_api_key,
|
||||
sid=self.sid,
|
||||
)
|
||||
|
||||
logger.info(f'Fetching latest provider tokens for runtime: {self.sid}')
|
||||
env_vars = await self.provider_handler.get_env_vars(
|
||||
env_vars = await provider_handler.get_env_vars(
|
||||
providers=providers_called, expose_secrets=False, get_latest=True
|
||||
)
|
||||
|
||||
@@ -343,10 +349,10 @@ class Runtime(FileEditRuntimeMixin):
|
||||
|
||||
try:
|
||||
if self.event_stream:
|
||||
await self.provider_handler.set_event_stream_secrets(
|
||||
await provider_handler.set_event_stream_secrets(
|
||||
self.event_stream, env_vars=env_vars
|
||||
)
|
||||
self.add_env_vars(self.provider_handler.expose_env_vars(env_vars))
|
||||
self.add_env_vars(provider_handler.expose_env_vars(env_vars))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Failed export latest github token to runtime: {self.sid}, {e}'
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from types import MappingProxyType
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
@@ -213,22 +213,28 @@ async def test_export_latest_git_provider_tokens_multiple_refs(temp_dir):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_export_latest_git_provider_tokens_token_update(runtime):
|
||||
async def test_export_latest_git_provider_tokens_token_update(runtime, monkeypatch):
|
||||
"""Test that token updates are handled correctly"""
|
||||
# First export with initial token
|
||||
cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
|
||||
await runtime._export_latest_git_provider_tokens(cmd)
|
||||
|
||||
# Update the token
|
||||
# Ensure refresh-token flow is enabled in ProviderHandler
|
||||
monkeypatch.setenv('WEB_HOST', 'example.com')
|
||||
|
||||
# Simulate that provider handler will now fetch a new token from refresh endpoint
|
||||
new_token = 'new_test_token'
|
||||
runtime.provider_handler._provider_tokens = MappingProxyType(
|
||||
{ProviderType.GITHUB: ProviderToken(token=SecretStr(new_token))}
|
||||
)
|
||||
|
||||
# Export again with updated token
|
||||
await runtime._export_latest_git_provider_tokens(cmd)
|
||||
# Patch ProviderHandler._get_latest_provider_token to return new SecretStr
|
||||
with patch.object(
|
||||
ProviderHandler,
|
||||
'_get_latest_provider_token',
|
||||
new=AsyncMock(return_value=SecretStr(new_token)),
|
||||
):
|
||||
# Export again with updated token – runtime should fetch latest and update EventStream secrets
|
||||
await runtime._export_latest_git_provider_tokens(cmd)
|
||||
|
||||
# Verify that the new token was exported
|
||||
# Verify that the new token was exported to the event stream
|
||||
assert runtime.event_stream.secrets == {'github_token': new_token}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user