mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
[Fix]: token refresh for nested runtimes (#10637)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -1,8 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import Annotated, Any, Coroutine, Literal, cast, overload
|
from typing import Annotated, Any, Coroutine, Literal, cast, overload
|
||||||
|
|
||||||
|
import httpx
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
ConfigDict,
|
ConfigDict,
|
||||||
@@ -28,6 +30,7 @@ from openhands.integrations.service_types import (
|
|||||||
Repository,
|
Repository,
|
||||||
ResourceNotFoundError,
|
ResourceNotFoundError,
|
||||||
SuggestedTask,
|
SuggestedTask,
|
||||||
|
TokenResponse,
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
from openhands.microagent.types import MicroagentContentResponse, MicroagentResponse
|
from openhands.microagent.types import MicroagentContentResponse, MicroagentResponse
|
||||||
@@ -112,6 +115,8 @@ class ProviderHandler:
|
|||||||
external_auth_id: str | None = None,
|
external_auth_id: str | None = None,
|
||||||
external_auth_token: SecretStr | None = None,
|
external_auth_token: SecretStr | None = None,
|
||||||
external_token_manager: bool = False,
|
external_token_manager: bool = False,
|
||||||
|
session_api_key: str | None = None,
|
||||||
|
sid: str | None = None,
|
||||||
):
|
):
|
||||||
if not isinstance(provider_tokens, MappingProxyType):
|
if not isinstance(provider_tokens, MappingProxyType):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@@ -127,7 +132,13 @@ class ProviderHandler:
|
|||||||
self.external_auth_id = external_auth_id
|
self.external_auth_id = external_auth_id
|
||||||
self.external_auth_token = external_auth_token
|
self.external_auth_token = external_auth_token
|
||||||
self.external_token_manager = external_token_manager
|
self.external_token_manager = external_token_manager
|
||||||
|
self.session_api_key = session_api_key
|
||||||
|
self.sid = sid
|
||||||
self._provider_tokens = provider_tokens
|
self._provider_tokens = provider_tokens
|
||||||
|
WEB_HOST = os.getenv('WEB_HOST', '').strip()
|
||||||
|
self.REFRESH_TOKEN_URL = (
|
||||||
|
f'https://{WEB_HOST}/api/refresh-tokens' if WEB_HOST else None
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def provider_tokens(self) -> PROVIDER_TOKEN_TYPE:
|
def provider_tokens(self) -> PROVIDER_TOKEN_TYPE:
|
||||||
@@ -161,8 +172,24 @@ class ProviderHandler:
|
|||||||
self, provider: ProviderType
|
self, provider: ProviderType
|
||||||
) -> SecretStr | None:
|
) -> SecretStr | None:
|
||||||
"""Get latest token from service"""
|
"""Get latest token from service"""
|
||||||
service = self._get_service(provider)
|
try:
|
||||||
return await service.get_latest_token()
|
async with httpx.AsyncClient() as client:
|
||||||
|
resp = await client.get(
|
||||||
|
self.REFRESH_TOKEN_URL,
|
||||||
|
headers={
|
||||||
|
'X-Session-API-Key': self.session_api_key,
|
||||||
|
},
|
||||||
|
params={'provider': provider.value, 'sid': self.sid},
|
||||||
|
)
|
||||||
|
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = TokenResponse.model_validate_json(resp.text)
|
||||||
|
return SecretStr(data.token)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f'Failed to fetch latest token for provider {provider}: {e}')
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def get_github_installations(self) -> list[str]:
|
async def get_github_installations(self) -> list[str]:
|
||||||
service = cast(InstallationsService, self._get_service(ProviderType.GITHUB))
|
service = cast(InstallationsService, self._get_service(ProviderType.GITHUB))
|
||||||
@@ -356,7 +383,7 @@ class ProviderHandler:
|
|||||||
else SecretStr('')
|
else SecretStr('')
|
||||||
)
|
)
|
||||||
|
|
||||||
if get_latest:
|
if get_latest and self.REFRESH_TOKEN_URL and self.sid:
|
||||||
token = await self._get_latest_provider_token(provider)
|
token = await self._get_latest_provider_token(provider)
|
||||||
|
|
||||||
if token:
|
if token:
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ from openhands.microagent.types import MicroagentContentResponse, MicroagentResp
|
|||||||
from openhands.server.types import AppMode
|
from openhands.server.types import AppMode
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
token: str
|
||||||
|
|
||||||
|
|
||||||
class ProviderType(Enum):
|
class ProviderType(Enum):
|
||||||
GITHUB = 'github'
|
GITHUB = 'github'
|
||||||
GITLAB = 'gitlab'
|
GITLAB = 'gitlab'
|
||||||
|
|||||||
@@ -323,9 +323,6 @@ class Runtime(FileEditRuntimeMixin):
|
|||||||
|
|
||||||
async def _export_latest_git_provider_tokens(self, event: Action) -> None:
|
async def _export_latest_git_provider_tokens(self, event: Action) -> None:
|
||||||
"""Refresh runtime provider tokens when agent attemps to run action with provider token"""
|
"""Refresh runtime provider tokens when agent attemps to run action with provider token"""
|
||||||
if not self.user_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
providers_called = ProviderHandler.check_cmd_action_for_provider_token_ref(
|
providers_called = ProviderHandler.check_cmd_action_for_provider_token_ref(
|
||||||
event
|
event
|
||||||
)
|
)
|
||||||
@@ -333,8 +330,17 @@ class Runtime(FileEditRuntimeMixin):
|
|||||||
if not providers_called:
|
if not providers_called:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
provider_handler = ProviderHandler(
|
||||||
|
provider_tokens=self.git_provider_tokens
|
||||||
|
or cast(PROVIDER_TOKEN_TYPE, MappingProxyType({})),
|
||||||
|
external_auth_id=self.user_id,
|
||||||
|
external_token_manager=True,
|
||||||
|
session_api_key=self.session_api_key,
|
||||||
|
sid=self.sid,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f'Fetching latest provider tokens for runtime: {self.sid}')
|
logger.info(f'Fetching latest provider tokens for runtime: {self.sid}')
|
||||||
env_vars = await self.provider_handler.get_env_vars(
|
env_vars = await provider_handler.get_env_vars(
|
||||||
providers=providers_called, expose_secrets=False, get_latest=True
|
providers=providers_called, expose_secrets=False, get_latest=True
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -343,10 +349,10 @@ class Runtime(FileEditRuntimeMixin):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if self.event_stream:
|
if self.event_stream:
|
||||||
await self.provider_handler.set_event_stream_secrets(
|
await provider_handler.set_event_stream_secrets(
|
||||||
self.event_stream, env_vars=env_vars
|
self.event_stream, env_vars=env_vars
|
||||||
)
|
)
|
||||||
self.add_env_vars(self.provider_handler.expose_env_vars(env_vars))
|
self.add_env_vars(provider_handler.expose_env_vars(env_vars))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f'Failed export latest github token to runtime: {self.sid}, {e}'
|
f'Failed export latest github token to runtime: {self.sid}, {e}'
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import SecretStr
|
from pydantic import SecretStr
|
||||||
@@ -213,22 +213,28 @@ async def test_export_latest_git_provider_tokens_multiple_refs(temp_dir):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_export_latest_git_provider_tokens_token_update(runtime):
|
async def test_export_latest_git_provider_tokens_token_update(runtime, monkeypatch):
|
||||||
"""Test that token updates are handled correctly"""
|
"""Test that token updates are handled correctly"""
|
||||||
# First export with initial token
|
# First export with initial token
|
||||||
cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
|
cmd = CmdRunAction(command='echo $GITHUB_TOKEN')
|
||||||
await runtime._export_latest_git_provider_tokens(cmd)
|
await runtime._export_latest_git_provider_tokens(cmd)
|
||||||
|
|
||||||
# Update the token
|
# Ensure refresh-token flow is enabled in ProviderHandler
|
||||||
|
monkeypatch.setenv('WEB_HOST', 'example.com')
|
||||||
|
|
||||||
|
# Simulate that provider handler will now fetch a new token from refresh endpoint
|
||||||
new_token = 'new_test_token'
|
new_token = 'new_test_token'
|
||||||
runtime.provider_handler._provider_tokens = MappingProxyType(
|
|
||||||
{ProviderType.GITHUB: ProviderToken(token=SecretStr(new_token))}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Export again with updated token
|
# Patch ProviderHandler._get_latest_provider_token to return new SecretStr
|
||||||
await runtime._export_latest_git_provider_tokens(cmd)
|
with patch.object(
|
||||||
|
ProviderHandler,
|
||||||
|
'_get_latest_provider_token',
|
||||||
|
new=AsyncMock(return_value=SecretStr(new_token)),
|
||||||
|
):
|
||||||
|
# Export again with updated token – runtime should fetch latest and update EventStream secrets
|
||||||
|
await runtime._export_latest_git_provider_tokens(cmd)
|
||||||
|
|
||||||
# Verify that the new token was exported
|
# Verify that the new token was exported to the event stream
|
||||||
assert runtime.event_stream.secrets == {'github_token': new_token}
|
assert runtime.event_stream.secrets == {'github_token': new_token}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user