mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 88d96c4f65 | |||
| 29886ba0ff | |||
| d099c21f5d | |||
| 4c89b5ad91 | |||
| 729c181313 |
@@ -258,8 +258,4 @@ containers/runtime/code
|
||||
# test results
|
||||
test-results
|
||||
.sessions
|
||||
|
||||
# ignore agent-sdk embedded repo if present
|
||||
agent-sdk/
|
||||
|
||||
.eval_sessions
|
||||
|
||||
@@ -88,13 +88,15 @@ function GitChanges() {
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
gitChanges.map((change) => (
|
||||
<FileDiffViewer
|
||||
key={change.path}
|
||||
path={change.path}
|
||||
type={change.status}
|
||||
/>
|
||||
))
|
||||
gitChanges
|
||||
.slice(0, 100)
|
||||
.map((change) => (
|
||||
<FileDiffViewer
|
||||
key={change.path}
|
||||
path={change.path}
|
||||
type={change.status}
|
||||
/>
|
||||
))
|
||||
)}
|
||||
</main>
|
||||
);
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
from .server import ACPAgentServer, run_stdio_server
|
||||
|
||||
__all__ = [
|
||||
'ACPAgentServer',
|
||||
'run_stdio_server',
|
||||
]
|
||||
@@ -1,11 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
from .server import run_stdio_server
|
||||
|
||||
|
||||
def main() -> None:
|
||||
asyncio.run(run_stdio_server())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -1,154 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable
|
||||
|
||||
Json = dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Request:
|
||||
id: int
|
||||
method: str
|
||||
params: Any | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Response:
|
||||
id: int
|
||||
result: Any | None = None
|
||||
error: Any | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Notification:
|
||||
method: str
|
||||
params: Any | None
|
||||
|
||||
|
||||
class NDJsonStdio:
|
||||
"""Simple newline-delimited JSON over stdio.
|
||||
|
||||
This intentionally follows the ACP typescript ndJsonStream helper for simplicity.
|
||||
"""
|
||||
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
self.reader = reader
|
||||
self.writer = writer
|
||||
self._write_lock = asyncio.Lock()
|
||||
|
||||
async def write(self, obj: Any) -> None:
|
||||
data = json.dumps(obj, separators=(',', ':')) + '\n'
|
||||
async with self._write_lock:
|
||||
self.writer.write(data.encode('utf-8'))
|
||||
await self.writer.drain()
|
||||
|
||||
async def read(self) -> AsyncIterator[Any]:
|
||||
while not self.reader.at_eof():
|
||||
line = await self.reader.readline()
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
yield json.loads(line)
|
||||
except Exception:
|
||||
# ignore malformed lines
|
||||
continue
|
||||
|
||||
|
||||
class JsonRpcConnection:
|
||||
def __init__(self, stream: NDJsonStdio):
|
||||
self.stream = stream
|
||||
self._id = 0
|
||||
self._pending: dict[int, asyncio.Future[Any]] = {}
|
||||
self._closed = asyncio.Event()
|
||||
self._tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
async def send_request(self, method: str, params: Any | None = None) -> Any:
|
||||
self._id += 1
|
||||
req_id = self._id
|
||||
fut: asyncio.Future[Any] = asyncio.get_running_loop().create_future()
|
||||
self._pending[req_id] = fut
|
||||
await self.stream.write(
|
||||
{'jsonrpc': '2.0', 'id': req_id, 'method': method, 'params': params}
|
||||
)
|
||||
return await fut
|
||||
|
||||
async def send_notification(self, method: str, params: Any | None = None) -> None:
|
||||
await self.stream.write({'jsonrpc': '2.0', 'method': method, 'params': params})
|
||||
|
||||
async def send_response(
|
||||
self, id: int, result: Any | None = None, error: Any | None = None
|
||||
) -> None:
|
||||
if error is not None:
|
||||
await self.stream.write({'jsonrpc': '2.0', 'id': id, 'error': error})
|
||||
else:
|
||||
await self.stream.write({'jsonrpc': '2.0', 'id': id, 'result': result})
|
||||
|
||||
def _create_task(self, coro: Awaitable[Any]) -> None:
|
||||
task: asyncio.Task[Any] = asyncio.create_task(coro) # type: ignore[arg-type]
|
||||
self._tasks.add(task)
|
||||
task.add_done_callback(self._tasks.discard) # type: ignore[arg-type]
|
||||
|
||||
async def serve(
|
||||
self,
|
||||
on_request: Callable[[str, Any | None], Awaitable[Any | None]],
|
||||
on_notification: Callable[[str, Any | None], Awaitable[None]] | None = None,
|
||||
) -> None:
|
||||
async for msg in self.stream.read():
|
||||
try:
|
||||
if not isinstance(msg, dict) or msg.get('jsonrpc') != '2.0':
|
||||
continue
|
||||
if 'method' in msg:
|
||||
method = msg['method']
|
||||
params = msg.get('params')
|
||||
if 'id' in msg:
|
||||
req_id = msg['id']
|
||||
|
||||
async def handle_req(
|
||||
method: str = method,
|
||||
params: Any | None = params,
|
||||
req_id: int = req_id,
|
||||
) -> None:
|
||||
try:
|
||||
result = await on_request(method, params)
|
||||
await self.send_response(
|
||||
req_id, result=result if result is not None else {}
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
await self.send_response(
|
||||
req_id,
|
||||
error={'code': -32800, 'message': 'cancelled'},
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
await self.send_response(
|
||||
req_id, error={'code': -32603, 'message': str(e)}
|
||||
)
|
||||
|
||||
self._create_task(handle_req())
|
||||
else:
|
||||
if on_notification is not None:
|
||||
self._create_task(on_notification(method, params))
|
||||
elif 'id' in msg:
|
||||
fut = self._pending.pop(int(msg['id']), None)
|
||||
if fut:
|
||||
if 'result' in msg:
|
||||
fut.set_result(msg['result'])
|
||||
else:
|
||||
fut.set_exception(
|
||||
RuntimeError(msg.get('error') or 'unknown error')
|
||||
)
|
||||
except Exception:
|
||||
# ignore
|
||||
continue
|
||||
# Wait a brief moment for any straggling tasks
|
||||
if self._tasks:
|
||||
await asyncio.wait(self._tasks, timeout=1.0)
|
||||
self._closed.set()
|
||||
|
||||
async def wait_closed(self) -> None:
|
||||
await self._closed.wait()
|
||||
@@ -1,166 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from openhands.acp.jsonrpc import JsonRpcConnection, NDJsonStdio
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
PROTOCOL_VERSION = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionState:
|
||||
pending_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
class ACPAgentServer:
|
||||
"""Minimal ACP adapter to expose OpenHands as an ACP Agent over stdio NDJSON.
|
||||
|
||||
Implements initialize, session/new, session/prompt and session/cancel,
|
||||
and provides client-facing notifications session/update and requests
|
||||
like session/request_permission in the future. This is a minimal MVP to
|
||||
integrate with Zed ACP client.
|
||||
"""
|
||||
|
||||
def __init__(self, rpc: JsonRpcConnection):
|
||||
self.rpc = rpc
|
||||
self.sessions: dict[str, SessionState] = {}
|
||||
|
||||
async def handle_request(self, method: str, params: Any | None) -> Any | None:
|
||||
if method == 'initialize':
|
||||
return await self._initialize(params)
|
||||
if method == 'session/new':
|
||||
return await self._session_new(params)
|
||||
if method == 'session/prompt':
|
||||
return await self._session_prompt(params)
|
||||
if method == 'session/cancel':
|
||||
# Spec: cancel is a notification, but handle gracefully if sent as request
|
||||
await self._session_cancel(params)
|
||||
return {}
|
||||
if method == 'authenticate':
|
||||
# No-op for now
|
||||
return {}
|
||||
if method == 'session/set_mode':
|
||||
return {}
|
||||
raise RuntimeError(f'Method not implemented: {method}')
|
||||
|
||||
async def handle_notification(self, method: str, params: Any | None) -> None:
|
||||
if method == 'session/cancel':
|
||||
await self._session_cancel(params)
|
||||
|
||||
async def _initialize(self, params: dict[str, Any] | None) -> dict[str, Any]:
|
||||
return {
|
||||
'protocolVersion': PROTOCOL_VERSION,
|
||||
'agentCapabilities': {
|
||||
'loadSession': False,
|
||||
},
|
||||
'promptCapabilities': {
|
||||
'supportsImage': True,
|
||||
'supportsAudio': False,
|
||||
'supportsResources': True,
|
||||
},
|
||||
}
|
||||
|
||||
async def _session_new(self, params: dict[str, Any] | None) -> dict[str, Any]:
|
||||
# Client may provide preferred model or workspace details; ignore for MVP
|
||||
session_id = await self._generate_session_id()
|
||||
self.sessions[session_id] = SessionState()
|
||||
return {'sessionId': session_id}
|
||||
|
||||
async def _session_prompt(self, params: dict[str, Any] | None) -> dict[str, Any]:
|
||||
assert params is not None
|
||||
session_id = params.get('sessionId', '')
|
||||
# Accept either 'messages' (python test harness) or 'prompt' (ACP TS client)
|
||||
_messages = (
|
||||
params.get('messages') if 'messages' in params else params.get('prompt', [])
|
||||
)
|
||||
# For MVP we just echo a text agent message chunk and end_turn
|
||||
state = self.sessions.get(session_id)
|
||||
if state is None:
|
||||
raise RuntimeError(f'Unknown session {session_id}')
|
||||
|
||||
# cancel any pending prompt
|
||||
if state.pending_task and not state.pending_task.done():
|
||||
state.pending_task.cancel()
|
||||
try:
|
||||
await state.pending_task
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
async def run_turn() -> None:
|
||||
try:
|
||||
await self.rpc.send_notification(
|
||||
'session/update',
|
||||
{
|
||||
'sessionId': session_id,
|
||||
'update': {
|
||||
'sessionUpdate': 'agent_message_chunk',
|
||||
'content': {
|
||||
'type': 'text',
|
||||
'text': 'OpenHands is thinking...',
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
await asyncio.sleep(0.2)
|
||||
await self.rpc.send_notification(
|
||||
'session/update',
|
||||
{
|
||||
'sessionId': session_id,
|
||||
'update': {
|
||||
'sessionUpdate': 'agent_message_chunk',
|
||||
'content': {
|
||||
'type': 'text',
|
||||
'text': 'This is a minimal ACP adapter.',
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
# Send nothing more
|
||||
raise
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.exception('Error in prompt run: %s', e)
|
||||
|
||||
task = asyncio.create_task(run_turn())
|
||||
state.pending_task = task
|
||||
try:
|
||||
await task
|
||||
stop_reason = 'end_turn'
|
||||
except asyncio.CancelledError:
|
||||
stop_reason = 'cancelled'
|
||||
return {'stopReason': stop_reason}
|
||||
|
||||
async def _session_cancel(self, params: dict[str, Any] | None) -> None:
|
||||
if not params:
|
||||
return
|
||||
session_id = params.get('sessionId')
|
||||
if not session_id:
|
||||
return
|
||||
state = self.sessions.get(session_id)
|
||||
if state and state.pending_task and not state.pending_task.done():
|
||||
state.pending_task.cancel()
|
||||
|
||||
async def _generate_session_id(self) -> str:
|
||||
# Simple increasing counter based id
|
||||
return f'sess-{len(self.sessions) + 1:04d}'
|
||||
|
||||
|
||||
async def run_stdio_server() -> None:
|
||||
import sys
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
reader = asyncio.StreamReader()
|
||||
reader_protocol = asyncio.StreamReaderProtocol(reader)
|
||||
await loop.connect_read_pipe(lambda: reader_protocol, sys.stdin)
|
||||
write_transport, write_protocol = await loop.connect_write_pipe(
|
||||
asyncio.streams.FlowControlMixin, sys.stdout
|
||||
)
|
||||
writer = asyncio.StreamWriter(write_transport, write_protocol, reader, loop)
|
||||
|
||||
stream = NDJsonStdio(reader, writer)
|
||||
rpc = JsonRpcConnection(stream)
|
||||
server = ACPAgentServer(rpc)
|
||||
await rpc.serve(server.handle_request, server.handle_notification)
|
||||
@@ -28,6 +28,7 @@ Your primary role is to assist users by executing commands, modifying code, and
|
||||
* Before implementing any changes, first thoroughly understand the codebase through exploration.
|
||||
* If you are adding a lot of code to a function or file, consider splitting the function or file into smaller pieces when appropriate.
|
||||
* Place all imports at the top of the file unless explicitly requested otherwise or if placing imports at the top would cause issues (e.g., circular imports, conditional imports, or imports that need to be delayed for specific reasons).
|
||||
* If working in a git repo, before you commit code create a .gitignore file if one doesn't exist. And if there are existing files that should not be included then update the .gitignore file as appropriate.
|
||||
</CODE_QUALITY>
|
||||
|
||||
<VERSION_CONTROL>
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import pathlib
|
||||
import platform
|
||||
import sys
|
||||
import traceback
|
||||
from ast import literal_eval
|
||||
from types import UnionType
|
||||
from typing import Any, MutableMapping, get_args, get_origin, get_type_hints
|
||||
@@ -825,10 +826,18 @@ def load_openhands_config(
|
||||
set_logging_levels: Whether to set the global variables for logging levels.
|
||||
config_file: Path to the config file. Defaults to 'config.toml' in the current directory.
|
||||
"""
|
||||
logger.openhands_logger.info('load_openhands_config stack trace:')
|
||||
logger.openhands_logger.info(''.join(traceback.format_stack()))
|
||||
config = OpenHandsConfig()
|
||||
load_from_toml(config, config_file)
|
||||
logger.openhands_logger.info(
|
||||
f'Config from TOML file {config_file}: {config.model_dump_json()}'
|
||||
)
|
||||
logger.openhands_logger.info(f'Env vars: {dict(os.environ)}')
|
||||
load_from_env(config, os.environ)
|
||||
logger.openhands_logger.info(f'Config from env: {config.model_dump_json()}')
|
||||
finalize_config(config)
|
||||
logger.openhands_logger.info(f'Config finalized: {config.model_dump_json()}')
|
||||
register_custom_agents(config)
|
||||
if set_logging_levels:
|
||||
logger.DEBUG = config.debug
|
||||
|
||||
@@ -466,10 +466,22 @@ class ProviderHandler:
|
||||
except Exception as e:
|
||||
errors.append(f'{provider.value}: {str(e)}')
|
||||
|
||||
# Log all accumulated errors before raising AuthenticationError
|
||||
logger.error(
|
||||
f'Failed to access repository {repository} with all available providers. Errors: {"; ".join(errors)}'
|
||||
)
|
||||
# Log detailed error based on whether we had tokens or not
|
||||
if not self.provider_tokens:
|
||||
logger.error(
|
||||
f'Failed to access repository {repository}: No provider tokens available. '
|
||||
f'provider_tokens dict is empty.'
|
||||
)
|
||||
elif errors:
|
||||
logger.error(
|
||||
f'Failed to access repository {repository} with all available providers. '
|
||||
f'Tried providers: {list(self.provider_tokens.keys())}. '
|
||||
f'Errors: {"; ".join(errors)}'
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f'Failed to access repository {repository}: Unknown error (no providers tried, no errors recorded)'
|
||||
)
|
||||
raise AuthenticationError(f'Unable to access repo {repository}')
|
||||
|
||||
async def get_branches(
|
||||
|
||||
@@ -462,6 +462,7 @@ async def start_conversation(
|
||||
providers_set: ProvidersSetModel,
|
||||
conversation_id: str = Depends(validate_conversation_id),
|
||||
user_id: str = Depends(get_user_id),
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE = Depends(get_provider_tokens),
|
||||
settings: Settings = Depends(get_user_settings),
|
||||
conversation_store: ConversationStore = Depends(get_conversation_store),
|
||||
) -> ConversationResponse:
|
||||
@@ -471,7 +472,22 @@ async def start_conversation(
|
||||
to start a conversation. If the conversation is already running, it will
|
||||
return the existing agent loop info.
|
||||
"""
|
||||
logger.info(f'Starting conversation: {conversation_id}')
|
||||
logger.info(
|
||||
f'Starting conversation: {conversation_id}',
|
||||
extra={'session_id': conversation_id},
|
||||
)
|
||||
|
||||
# Log token fetch status
|
||||
if provider_tokens:
|
||||
logger.info(
|
||||
f'/start endpoint: Fetched provider tokens: {list(provider_tokens.keys())}',
|
||||
extra={'session_id': conversation_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'/start endpoint: No provider tokens fetched (provider_tokens is None/empty)',
|
||||
extra={'session_id': conversation_id},
|
||||
)
|
||||
|
||||
try:
|
||||
# Check that the conversation exists
|
||||
@@ -488,7 +504,7 @@ async def start_conversation(
|
||||
|
||||
# Set up conversation init data with provider information
|
||||
conversation_init_data = await setup_init_conversation_settings(
|
||||
user_id, conversation_id, providers_set.providers_set or []
|
||||
user_id, conversation_id, providers_set.providers_set or [], provider_tokens
|
||||
)
|
||||
|
||||
# Start the agent loop
|
||||
|
||||
@@ -11,15 +11,10 @@ from openhands.server.routes.secrets import invalidate_legacy_secrets_store
|
||||
from openhands.server.settings import (
|
||||
GETSettingsModel,
|
||||
)
|
||||
from openhands.server.settings_validation import (
|
||||
check_llm_settings_changes,
|
||||
validate_llm_settings_access,
|
||||
)
|
||||
from openhands.server.shared import config
|
||||
from openhands.server.user_auth import (
|
||||
get_provider_tokens,
|
||||
get_secrets_store,
|
||||
get_user_id,
|
||||
get_user_settings,
|
||||
get_user_settings_store,
|
||||
)
|
||||
@@ -140,34 +135,17 @@ async def store_llm_settings(
|
||||
response_model=None,
|
||||
responses={
|
||||
200: {'description': 'Settings stored successfully', 'model': dict},
|
||||
403: {'description': 'Subscription required for pro models', 'model': dict},
|
||||
500: {'description': 'Error storing settings', 'model': dict},
|
||||
},
|
||||
)
|
||||
async def store_settings(
|
||||
settings: Settings,
|
||||
settings_store: SettingsStore = Depends(get_user_settings_store),
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> JSONResponse:
|
||||
# Check provider tokens are valid
|
||||
try:
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
# Check if any LLM-related settings are being changed
|
||||
llm_settings_being_changed = check_llm_settings_changes(
|
||||
settings, existing_settings
|
||||
)
|
||||
|
||||
if llm_settings_being_changed:
|
||||
has_access = await validate_llm_settings_access(user_id)
|
||||
if not has_access:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content={
|
||||
'error': 'Modifying LLM settings requires an active OpenHands Pro subscription. Please upgrade your account to access LLM configuration.',
|
||||
'detail': 'Subscription required for LLM settings modifications',
|
||||
},
|
||||
)
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
if existing_settings:
|
||||
settings = await store_llm_settings(settings, settings_store)
|
||||
|
||||
@@ -215,7 +215,10 @@ def create_provider_tokens_object(
|
||||
|
||||
|
||||
async def setup_init_conversation_settings(
|
||||
user_id: str | None, conversation_id: str, providers_set: list[ProviderType]
|
||||
user_id: str | None,
|
||||
conversation_id: str,
|
||||
providers_set: list[ProviderType],
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = None,
|
||||
) -> ConversationInitData:
|
||||
"""Set up conversation initialization data with provider tokens.
|
||||
|
||||
@@ -223,6 +226,7 @@ async def setup_init_conversation_settings(
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
providers_set: List of provider types to set up tokens for
|
||||
provider_tokens: Optional provider tokens to use (for SAAS mode resume)
|
||||
|
||||
Returns:
|
||||
ConversationInitData with provider tokens configured
|
||||
@@ -243,11 +247,30 @@ async def setup_init_conversation_settings(
|
||||
session_init_args: dict = {}
|
||||
session_init_args = {**settings.__dict__, **session_init_args}
|
||||
|
||||
git_provider_tokens = create_provider_tokens_object(providers_set)
|
||||
logger.info(f'Git provider scaffold: {git_provider_tokens}')
|
||||
# Use provided tokens if available (for SAAS resume), otherwise create scaffold
|
||||
if provider_tokens:
|
||||
logger.info(
|
||||
f'Using provided provider_tokens: {list(provider_tokens.keys())}',
|
||||
extra={'session_id': conversation_id},
|
||||
)
|
||||
git_provider_tokens = provider_tokens
|
||||
else:
|
||||
logger.info(
|
||||
f'No provider_tokens provided, creating scaffold for: {providers_set}',
|
||||
extra={'session_id': conversation_id},
|
||||
)
|
||||
git_provider_tokens = create_provider_tokens_object(providers_set)
|
||||
logger.info(
|
||||
f'Git provider scaffold: {git_provider_tokens}',
|
||||
extra={'session_id': conversation_id},
|
||||
)
|
||||
|
||||
if server_config.app_mode != AppMode.SAAS and user_secrets:
|
||||
git_provider_tokens = user_secrets.provider_tokens
|
||||
if server_config.app_mode != AppMode.SAAS and user_secrets:
|
||||
logger.info(
|
||||
f'Non-SaaS mode: Overriding with user_secrets provider tokens: {list(user_secrets.provider_tokens.keys())}',
|
||||
extra={'session_id': conversation_id},
|
||||
)
|
||||
git_provider_tokens = user_secrets.provider_tokens
|
||||
|
||||
session_init_args['git_provider_tokens'] = git_provider_tokens
|
||||
if user_secrets:
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
"""Settings validation utilities for LLM settings access control."""
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
def _is_llm_setting_changing(setting_name: str, new_value, existing_settings) -> bool:
|
||||
"""Check if a specific LLM setting is being changed from its existing value.
|
||||
|
||||
Args:
|
||||
setting_name: Name of the setting to check
|
||||
new_value: New value being set
|
||||
existing_settings: Existing settings object (can be None)
|
||||
|
||||
Returns:
|
||||
bool: True if the setting is being changed, False otherwise
|
||||
"""
|
||||
if new_value is None:
|
||||
return False
|
||||
|
||||
# Handle special case for enable_default_condenser with default value
|
||||
if setting_name == 'enable_default_condenser':
|
||||
if not existing_settings:
|
||||
# First time setting - only validate if setting to non-default value
|
||||
return not new_value
|
||||
else:
|
||||
# Changing existing value
|
||||
return new_value != existing_settings.enable_default_condenser
|
||||
|
||||
# For other settings, validate if explicitly provided and different from existing
|
||||
if not existing_settings:
|
||||
return True
|
||||
|
||||
existing_value = getattr(existing_settings, setting_name, None)
|
||||
return new_value != existing_value
|
||||
|
||||
|
||||
def check_llm_settings_changes(settings: Settings, existing_settings) -> bool:
|
||||
"""Check if any LLM-related settings are being changed.
|
||||
|
||||
Validates both core LLM settings (model, API key, base URL) and advanced settings
|
||||
shown to SaaS users (confirmation mode, security analyzer, memory condenser settings).
|
||||
|
||||
Args:
|
||||
settings: New settings being applied
|
||||
existing_settings: Current settings (can be None)
|
||||
|
||||
Returns:
|
||||
bool: True if any LLM settings are being changed, False otherwise
|
||||
"""
|
||||
# Core LLM settings - always validate if provided
|
||||
core_llm_changes = any(
|
||||
[
|
||||
settings.llm_model is not None,
|
||||
settings.llm_api_key is not None,
|
||||
settings.llm_base_url is not None,
|
||||
]
|
||||
)
|
||||
|
||||
if core_llm_changes:
|
||||
return True
|
||||
|
||||
# Additional LLM settings shown to SaaS users - validate if actually changing
|
||||
advanced_llm_changes = any(
|
||||
[
|
||||
_is_llm_setting_changing(
|
||||
'confirmation_mode', settings.confirmation_mode, existing_settings
|
||||
),
|
||||
_is_llm_setting_changing(
|
||||
'security_analyzer', settings.security_analyzer, existing_settings
|
||||
),
|
||||
_is_llm_setting_changing(
|
||||
'enable_default_condenser',
|
||||
settings.enable_default_condenser,
|
||||
existing_settings,
|
||||
),
|
||||
_is_llm_setting_changing(
|
||||
'condenser_max_size', settings.condenser_max_size, existing_settings
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
return advanced_llm_changes
|
||||
|
||||
|
||||
async def validate_llm_settings_access(user_id: str) -> bool:
|
||||
"""Validate if user has access to modify LLM settings in SaaS mode.
|
||||
|
||||
In SaaS mode, only pro users with active subscriptions can modify LLM settings.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check subscription for
|
||||
|
||||
Returns:
|
||||
bool: True if user can modify LLM settings, False otherwise
|
||||
"""
|
||||
# Skip validation in non-SaaS mode
|
||||
if server_config.app_mode != AppMode.SAAS:
|
||||
return True
|
||||
|
||||
# In SaaS mode, check for active subscription for ANY LLM settings changes
|
||||
try:
|
||||
# Import here to avoid circular imports and handle enterprise mode gracefully
|
||||
from enterprise.server.routes.billing import get_subscription_access
|
||||
|
||||
subscription = await get_subscription_access(user_id)
|
||||
# The get_subscription_access function already filters for ACTIVE status,
|
||||
# so if we get a subscription back, it means it's active
|
||||
return subscription is not None
|
||||
except ImportError:
|
||||
# Enterprise billing module not available - in SaaS mode, this means
|
||||
# we can't validate subscriptions, so deny access to be safe
|
||||
logger.warning(
|
||||
'Enterprise billing module not available in SaaS mode, denying LLM settings access'
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
# On error, deny access to be safe
|
||||
logger.warning(f'Error checking subscription access for user {user_id}: {e}')
|
||||
return False
|
||||
@@ -1,329 +0,0 @@
|
||||
"""Security tests for settings API to ensure pro-only features are properly validated on backend."""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.server.app import app
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
from openhands.storage.settings.file_settings_store import FileSettingsStore
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
|
||||
class MockUserAuthNonPro(UserAuth):
|
||||
"""Mock implementation of UserAuth for non-pro user testing"""
|
||||
|
||||
def __init__(self):
|
||||
self._settings = None
|
||||
self._settings_store = MagicMock()
|
||||
self._settings_store.load = AsyncMock(return_value=None)
|
||||
self._settings_store.store = AsyncMock()
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return 'test-user-nonpro'
|
||||
|
||||
async def get_user_email(self) -> str | None:
|
||||
return 'nonpro@test.com'
|
||||
|
||||
async def get_access_token(self) -> SecretStr | None:
|
||||
return SecretStr('test-token-nonpro')
|
||||
|
||||
async def get_provider_tokens(self) -> dict[ProviderType, ProviderToken] | None:
|
||||
return None
|
||||
|
||||
async def get_user_settings_store(self) -> SettingsStore | None:
|
||||
return self._settings_store
|
||||
|
||||
async def get_secrets_store(self) -> SecretsStore | None:
|
||||
return None
|
||||
|
||||
async def get_user_secrets(self) -> UserSecrets | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, request: Request) -> UserAuth:
|
||||
return MockUserAuthNonPro()
|
||||
|
||||
|
||||
class MockUserAuthPro(UserAuth):
|
||||
"""Mock implementation of UserAuth for pro user testing"""
|
||||
|
||||
def __init__(self):
|
||||
self._settings = None
|
||||
self._settings_store = MagicMock()
|
||||
self._settings_store.load = AsyncMock(return_value=None)
|
||||
self._settings_store.store = AsyncMock()
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return 'test-user-pro'
|
||||
|
||||
async def get_user_email(self) -> str | None:
|
||||
return 'pro@test.com'
|
||||
|
||||
async def get_access_token(self) -> SecretStr | None:
|
||||
return SecretStr('test-token-pro')
|
||||
|
||||
async def get_provider_tokens(self) -> dict[ProviderType, ProviderToken] | None:
|
||||
return None
|
||||
|
||||
async def get_user_settings_store(self) -> SettingsStore | None:
|
||||
return self._settings_store
|
||||
|
||||
async def get_secrets_store(self) -> SecretsStore | None:
|
||||
return None
|
||||
|
||||
async def get_user_secrets(self) -> UserSecrets | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls, request: Request) -> UserAuth:
|
||||
return MockUserAuthPro()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client_non_pro():
|
||||
"""Test client for non-pro user"""
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ, {'SESSION_API_KEY': '', 'APP_MODE': 'saas'}, clear=False
|
||||
),
|
||||
patch('openhands.server.dependencies._SESSION_API_KEY', None),
|
||||
patch('openhands.server.shared.server_config.app_mode', AppMode.SAAS),
|
||||
patch(
|
||||
'openhands.server.user_auth.user_auth.UserAuth.get_instance',
|
||||
return_value=MockUserAuthNonPro(),
|
||||
),
|
||||
patch(
|
||||
'openhands.storage.settings.file_settings_store.FileSettingsStore.get_instance',
|
||||
AsyncMock(return_value=FileSettingsStore(InMemoryFileStore())),
|
||||
),
|
||||
# Mock the validation function at the routes level to return False (no access)
|
||||
patch(
|
||||
'openhands.server.routes.settings.validate_llm_settings_access',
|
||||
AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
client = TestClient(app)
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client_pro():
|
||||
"""Test client for pro user"""
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ, {'SESSION_API_KEY': '', 'APP_MODE': 'saas'}, clear=False
|
||||
),
|
||||
patch('openhands.server.dependencies._SESSION_API_KEY', None),
|
||||
patch('openhands.server.shared.server_config.app_mode', AppMode.SAAS),
|
||||
patch(
|
||||
'openhands.server.user_auth.user_auth.UserAuth.get_instance',
|
||||
return_value=MockUserAuthPro(),
|
||||
),
|
||||
patch(
|
||||
'openhands.storage.settings.file_settings_store.FileSettingsStore.get_instance',
|
||||
AsyncMock(return_value=FileSettingsStore(InMemoryFileStore())),
|
||||
),
|
||||
# Mock the validation function at the routes level to return True (has access)
|
||||
patch(
|
||||
'openhands.server.routes.settings.validate_llm_settings_access',
|
||||
AsyncMock(return_value=True),
|
||||
),
|
||||
):
|
||||
client = TestClient(app)
|
||||
yield client
|
||||
|
||||
|
||||
# Test data constants
|
||||
OPENHANDS_PRO_MODELS = [
|
||||
'openhands/claude-sonnet-4-20250514',
|
||||
'openhands/gpt-5-2025-08-07',
|
||||
'openhands/gpt-5-mini-2025-08-07',
|
||||
'openhands/claude-opus-4-20250514',
|
||||
'openhands/claude-opus-4-1-20250805',
|
||||
'openhands/gemini-2.5-pro',
|
||||
'openhands/o3',
|
||||
'openhands/o4-mini',
|
||||
]
|
||||
|
||||
DEFAULT_MODEL = 'claude-sonnet-4-20250514'
|
||||
|
||||
USER_PROVIDED_MODELS = [
|
||||
'anthropic/claude-3-5-sonnet-20241022',
|
||||
'openai/gpt-4o',
|
||||
'mistral/mistral-large',
|
||||
]
|
||||
|
||||
|
||||
# Helper functions
|
||||
def create_base_settings(**overrides):
|
||||
"""Create base settings data with optional overrides"""
|
||||
base_settings = {
|
||||
'language': 'en',
|
||||
'agent': 'test-agent',
|
||||
'max_iterations': 100,
|
||||
}
|
||||
base_settings.update(overrides)
|
||||
return base_settings
|
||||
|
||||
|
||||
def assert_forbidden_response(response, model_or_setting_name=''):
|
||||
"""Assert that response is 403 with subscription-related error"""
|
||||
assert response.status_code == 403, (
|
||||
f'{model_or_setting_name} should be forbidden for non-pro users'
|
||||
)
|
||||
response_data = response.json()
|
||||
assert any(
|
||||
keyword in response_data.get('detail', '').lower()
|
||||
or keyword in response_data.get('error', '').lower()
|
||||
for keyword in ['subscription', 'pro', 'upgrade']
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'model',
|
||||
[
|
||||
'openhands/claude-sonnet-4-20250514',
|
||||
'openhands/gpt-5-2025-08-07',
|
||||
'openhands/claude-opus-4-20250514',
|
||||
DEFAULT_MODEL,
|
||||
]
|
||||
+ USER_PROVIDED_MODELS,
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_pro_user_cannot_set_any_llm_model(test_client_non_pro, model):
|
||||
"""SECURITY TEST: Non-pro user should not be able to set any LLM model"""
|
||||
settings_data = create_base_settings(llm_model=model, llm_api_key='test-key')
|
||||
response = test_client_non_pro.post('/api/settings', json=settings_data)
|
||||
assert_forbidden_response(response, f'Model {model}')
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'llm_setting,value',
|
||||
[
|
||||
('llm_api_key', 'new-api-key'),
|
||||
('llm_base_url', 'https://custom-api.example.com'),
|
||||
('llm_model', DEFAULT_MODEL),
|
||||
('confirmation_mode', True),
|
||||
('security_analyzer', 'llm'),
|
||||
('enable_default_condenser', False),
|
||||
('condenser_max_size', 50),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_pro_user_cannot_set_individual_llm_settings(
|
||||
test_client_non_pro, llm_setting, value
|
||||
):
|
||||
"""SECURITY TEST: Non-pro user should not be able to set individual LLM settings"""
|
||||
settings_data = create_base_settings(**{llm_setting: value})
|
||||
response = test_client_non_pro.post('/api/settings', json=settings_data)
|
||||
assert_forbidden_response(response, f'LLM setting {llm_setting}')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_pro_user_can_set_non_llm_settings(test_client_non_pro):
|
||||
"""Non-pro users should still be able to modify non-LLM settings"""
|
||||
# Only use settings that definitely don't trigger LLM validation
|
||||
settings_data = {
|
||||
'language': 'fr',
|
||||
'max_iterations': 50,
|
||||
'user_consents_to_analytics': True,
|
||||
'git_user_name': 'test-user',
|
||||
'git_user_email': 'test@example.com',
|
||||
}
|
||||
response = test_client_non_pro.post('/api/settings', json=settings_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pro_user_can_set_llm_models(test_client_pro):
|
||||
"""Pro user should be able to set any LLM models"""
|
||||
settings_data = create_base_settings(
|
||||
llm_model='openhands/claude-sonnet-4-20250514', llm_api_key='test-key'
|
||||
)
|
||||
response = test_client_pro.post('/api/settings', json=settings_data)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_expired_subscription_cannot_access_llm_settings():
|
||||
"""SECURITY TEST: User with expired subscription should not access LLM settings"""
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ, {'SESSION_API_KEY': '', 'APP_MODE': 'saas'}, clear=False
|
||||
),
|
||||
patch('openhands.server.dependencies._SESSION_API_KEY', None),
|
||||
patch('openhands.server.shared.server_config.app_mode', AppMode.SAAS),
|
||||
patch(
|
||||
'openhands.server.user_auth.user_auth.UserAuth.get_instance',
|
||||
return_value=MockUserAuthPro(),
|
||||
),
|
||||
patch(
|
||||
'openhands.storage.settings.file_settings_store.FileSettingsStore.get_instance',
|
||||
AsyncMock(return_value=FileSettingsStore(InMemoryFileStore())),
|
||||
),
|
||||
# Mock validation to return False (expired subscription, no access)
|
||||
patch(
|
||||
'openhands.server.routes.settings.validate_llm_settings_access',
|
||||
AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
client = TestClient(app)
|
||||
settings_data = create_base_settings(
|
||||
llm_model='openhands/claude-sonnet-4-20250514', llm_api_key='test-key'
|
||||
)
|
||||
response = client.post('/api/settings', json=settings_data)
|
||||
assert_forbidden_response(response, 'Expired subscription')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_api_bypass_prevention(test_client_non_pro):
|
||||
"""SECURITY TEST: Direct API calls should still validate subscription status"""
|
||||
settings_data = create_base_settings(
|
||||
llm_model='openhands/claude-sonnet-4-20250514',
|
||||
llm_api_key='fake-api-key',
|
||||
llm_base_url='https://api.anthropic.com',
|
||||
remote_runtime_resource_factor=4,
|
||||
)
|
||||
|
||||
response = test_client_non_pro.post(
|
||||
'/api/settings',
|
||||
json=settings_data,
|
||||
headers={
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': 'DirectAPIClient/1.0',
|
||||
},
|
||||
)
|
||||
assert_forbidden_response(response, 'Direct API bypass attempt')
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'malicious_model',
|
||||
[
|
||||
'openhands/claude-sonnet-4-20250514', # Direct
|
||||
'OPENHANDS/claude-sonnet-4-20250514', # Case manipulation
|
||||
' openhands/claude-sonnet-4-20250514', # Leading space
|
||||
'openhands//claude-sonnet-4-20250514', # Double slash
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_prefix_bypass_attempts_blocked(
|
||||
test_client_non_pro, malicious_model
|
||||
):
|
||||
"""SECURITY TEST: Various prefix bypass attempts should be blocked"""
|
||||
settings_data = create_base_settings(
|
||||
llm_model=malicious_model, llm_api_key='test-key'
|
||||
)
|
||||
response = test_client_non_pro.post('/api/settings', json=settings_data)
|
||||
assert_forbidden_response(
|
||||
response, f"Bypass attempt with model '{malicious_model}'"
|
||||
)
|
||||
@@ -0,0 +1,150 @@
|
||||
"""Tests for /start endpoint provider token handling."""
|
||||
|
||||
from types import MappingProxyType
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.server.data_models.agent_loop_info import AgentLoopInfo
|
||||
from openhands.server.routes.manage_conversations import (
|
||||
ProvidersSetModel,
|
||||
start_conversation,
|
||||
)
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
"""Create a real Settings object with minimal required fields."""
|
||||
return Settings(
|
||||
language='en',
|
||||
agent='CodeActAgent',
|
||||
max_iterations=100,
|
||||
llm_model='anthropic/claude-3-5-sonnet-20241022',
|
||||
llm_api_key=SecretStr('test_api_key_12345'),
|
||||
llm_base_url=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_tokens():
|
||||
"""Create real provider tokens to test with."""
|
||||
return MappingProxyType(
|
||||
{
|
||||
ProviderType.GITHUB: ProviderToken(
|
||||
token=SecretStr('ghp_real_token_test123'), user_id='test_user_456'
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_metadata():
|
||||
"""Create a real ConversationMetadata object."""
|
||||
return ConversationMetadata(
|
||||
conversation_id='test_conv_123',
|
||||
user_id='test_user_456',
|
||||
title='Test Conversation',
|
||||
selected_repository='test/repo',
|
||||
selected_branch='main',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_endpoint_passes_provider_tokens(
|
||||
mock_settings, mock_provider_tokens, mock_conversation_metadata
|
||||
):
|
||||
"""Test that /start endpoint passes provider_tokens to setup_init_conversation_settings.
|
||||
|
||||
This test verifies the full end-to-end flow with real tokens through to ConversationInitData.
|
||||
"""
|
||||
conversation_id = 'test_conv_123'
|
||||
user_id = 'test_user_456'
|
||||
providers_set = ProvidersSetModel(providers_set=[ProviderType.GITHUB])
|
||||
|
||||
# Mock conversation store
|
||||
mock_conversation_store = AsyncMock()
|
||||
mock_conversation_store.get_metadata = AsyncMock(
|
||||
return_value=mock_conversation_metadata
|
||||
)
|
||||
|
||||
# Mock agent loop info that will be returned
|
||||
mock_agent_loop_info = AgentLoopInfo(
|
||||
conversation_id=conversation_id,
|
||||
url=None,
|
||||
session_api_key=None,
|
||||
event_store=None,
|
||||
status=ConversationStatus.RUNNING,
|
||||
)
|
||||
|
||||
# Mock only infrastructure - let setup_init_conversation_settings run for real
|
||||
with patch(
|
||||
'openhands.server.routes.manage_conversations.conversation_manager'
|
||||
) as mock_manager:
|
||||
# Mock the stores that setup_init_conversation_settings needs
|
||||
with patch(
|
||||
'openhands.server.services.conversation_service.SettingsStoreImpl.get_instance'
|
||||
) as mock_settings_store_cls:
|
||||
with patch(
|
||||
'openhands.server.services.conversation_service.SecretsStoreImpl.get_instance'
|
||||
) as mock_secrets_store_cls:
|
||||
with patch(
|
||||
'openhands.server.services.conversation_service.server_config'
|
||||
) as mock_server_config:
|
||||
# Setup store mocks
|
||||
mock_settings_store = AsyncMock()
|
||||
mock_settings_store.load = AsyncMock(return_value=mock_settings)
|
||||
mock_settings_store_cls.return_value = mock_settings_store
|
||||
|
||||
mock_secrets_store = AsyncMock()
|
||||
mock_secrets_store.load = AsyncMock(return_value=None)
|
||||
mock_secrets_store_cls.return_value = mock_secrets_store
|
||||
|
||||
mock_server_config.app_mode = AppMode.SAAS
|
||||
|
||||
mock_manager.maybe_start_agent_loop = AsyncMock(
|
||||
return_value=mock_agent_loop_info
|
||||
)
|
||||
|
||||
# Call endpoint with provider tokens
|
||||
response = await start_conversation(
|
||||
providers_set=providers_set,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
provider_tokens=mock_provider_tokens,
|
||||
settings=mock_settings,
|
||||
conversation_store=mock_conversation_store,
|
||||
)
|
||||
|
||||
# Verify ConversationInitData has real provider tokens
|
||||
mock_manager.maybe_start_agent_loop.assert_called_once()
|
||||
call_kwargs = mock_manager.maybe_start_agent_loop.call_args[1]
|
||||
conversation_init_data = call_kwargs['settings']
|
||||
|
||||
assert conversation_init_data.git_provider_tokens is not None
|
||||
assert (
|
||||
conversation_init_data.git_provider_tokens
|
||||
== mock_provider_tokens
|
||||
)
|
||||
assert (
|
||||
ProviderType.GITHUB
|
||||
in conversation_init_data.git_provider_tokens
|
||||
)
|
||||
|
||||
github_token = conversation_init_data.git_provider_tokens[
|
||||
ProviderType.GITHUB
|
||||
]
|
||||
assert (
|
||||
github_token.token.get_secret_value()
|
||||
== 'ghp_real_token_test123'
|
||||
)
|
||||
assert github_token.user_id == 'test_user_456'
|
||||
|
||||
assert response.status == 'ok'
|
||||
assert response.conversation_id == conversation_id
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Unit tests for conversation_service.py - specifically testing provider token handling.
|
||||
|
||||
These tests verify that setup_init_conversation_settings correctly handles provider tokens
|
||||
in different scenarios (provided tokens vs scaffold creation).
|
||||
"""
|
||||
|
||||
from types import MappingProxyType
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.server.services.conversation_service import (
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
"""Create a real Settings object with minimal required fields."""
|
||||
return Settings(
|
||||
language='en',
|
||||
agent='CodeActAgent',
|
||||
max_iterations=100,
|
||||
llm_model='anthropic/claude-3-5-sonnet-20241022',
|
||||
llm_api_key=SecretStr('test_api_key_12345'),
|
||||
llm_base_url=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_tokens():
|
||||
"""Create real provider tokens to test with."""
|
||||
return MappingProxyType(
|
||||
{
|
||||
ProviderType.GITHUB: ProviderToken(
|
||||
token=SecretStr('ghp_real_token_test123'), user_id='test_user_456'
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_with_provided_tokens_uses_real_tokens(
|
||||
mock_settings, mock_provider_tokens
|
||||
):
|
||||
"""Test that real tokens are used when provided in SAAS mode.
|
||||
|
||||
Verifies provider tokens passed in are used in ConversationInitData.
|
||||
"""
|
||||
user_id = 'test_user_456'
|
||||
conversation_id = 'test_conv_123'
|
||||
providers_set = [ProviderType.GITHUB]
|
||||
|
||||
# Mock the stores to return our test settings
|
||||
with patch(
|
||||
'openhands.server.services.conversation_service.SettingsStoreImpl.get_instance'
|
||||
) as mock_settings_store_cls:
|
||||
with patch(
|
||||
'openhands.server.services.conversation_service.SecretsStoreImpl.get_instance'
|
||||
) as mock_secrets_store_cls:
|
||||
with patch(
|
||||
'openhands.server.services.conversation_service.server_config'
|
||||
) as mock_server_config:
|
||||
# Setup mocks
|
||||
mock_settings_store = AsyncMock()
|
||||
mock_settings_store.load = AsyncMock(return_value=mock_settings)
|
||||
mock_settings_store_cls.return_value = mock_settings_store
|
||||
|
||||
mock_secrets_store = AsyncMock()
|
||||
mock_secrets_store.load = AsyncMock(return_value=None)
|
||||
mock_secrets_store_cls.return_value = mock_secrets_store
|
||||
|
||||
mock_server_config.app_mode = AppMode.SAAS
|
||||
|
||||
# Call with real tokens
|
||||
result = await setup_init_conversation_settings(
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
providers_set=providers_set,
|
||||
provider_tokens=mock_provider_tokens,
|
||||
)
|
||||
|
||||
# Verify real tokens are used
|
||||
assert result.git_provider_tokens is not None
|
||||
assert result.git_provider_tokens == mock_provider_tokens
|
||||
assert ProviderType.GITHUB in result.git_provider_tokens, (
|
||||
'GitHub provider should be in tokens'
|
||||
)
|
||||
|
||||
github_token = result.git_provider_tokens[ProviderType.GITHUB]
|
||||
assert (
|
||||
github_token.token.get_secret_value() == 'ghp_real_token_test123'
|
||||
), 'Should use real token, not None'
|
||||
assert github_token.user_id == 'test_user_456', (
|
||||
'Should preserve user_id from real token'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_without_tokens_non_saas_uses_user_secrets(mock_settings):
|
||||
"""Test that OSS mode uses user_secrets.provider_tokens when no tokens provided.
|
||||
|
||||
This test verifies OSS mode backward compatibility - tokens come from local config, not endpoint.
|
||||
"""
|
||||
user_id = 'test_user_456'
|
||||
conversation_id = 'test_conv_123'
|
||||
providers_set = [ProviderType.GITHUB]
|
||||
|
||||
# Create user_secrets with real tokens
|
||||
mock_user_secrets = MagicMock()
|
||||
mock_user_secrets.provider_tokens = MappingProxyType(
|
||||
{
|
||||
ProviderType.GITHUB: ProviderToken(
|
||||
token=SecretStr('ghp_local_token_from_config'),
|
||||
user_id='local_user_123',
|
||||
)
|
||||
}
|
||||
)
|
||||
mock_user_secrets.custom_secrets = MappingProxyType({}) # Empty dict is fine
|
||||
|
||||
with patch(
|
||||
'openhands.server.services.conversation_service.SettingsStoreImpl.get_instance'
|
||||
) as mock_settings_store_cls:
|
||||
with patch(
|
||||
'openhands.server.services.conversation_service.SecretsStoreImpl.get_instance'
|
||||
) as mock_secrets_store_cls:
|
||||
with patch(
|
||||
'openhands.server.services.conversation_service.server_config'
|
||||
) as mock_server_config:
|
||||
# Setup mocks
|
||||
mock_settings_store = AsyncMock()
|
||||
mock_settings_store.load = AsyncMock(return_value=mock_settings)
|
||||
mock_settings_store_cls.return_value = mock_settings_store
|
||||
|
||||
mock_secrets_store = AsyncMock()
|
||||
mock_secrets_store.load = AsyncMock(return_value=mock_user_secrets)
|
||||
mock_secrets_store_cls.return_value = mock_secrets_store
|
||||
|
||||
mock_server_config.app_mode = AppMode.OSS
|
||||
|
||||
# Call without endpoint tokens
|
||||
result = await setup_init_conversation_settings(
|
||||
user_id=user_id,
|
||||
conversation_id=conversation_id,
|
||||
providers_set=providers_set,
|
||||
provider_tokens=None,
|
||||
)
|
||||
|
||||
# Verify user_secrets tokens are used
|
||||
assert result.git_provider_tokens is not None
|
||||
assert ProviderType.GITHUB in result.git_provider_tokens
|
||||
|
||||
github_token = result.git_provider_tokens[ProviderType.GITHUB]
|
||||
assert (
|
||||
github_token.token.get_secret_value()
|
||||
== 'ghp_local_token_from_config'
|
||||
)
|
||||
assert github_token.user_id == 'local_user_123'
|
||||
@@ -1,122 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.acp.jsonrpc import JsonRpcConnection, NDJsonStdio
|
||||
from openhands.acp.server import ACPAgentServer
|
||||
|
||||
|
||||
class MemoryRW:
|
||||
def __init__(self):
|
||||
self.read_q: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self.write_q: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
|
||||
def get_streams(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
async def reader_gen(reader: asyncio.StreamReader):
|
||||
while True:
|
||||
data = await self.read_q.get()
|
||||
reader.feed_data(data)
|
||||
if data == b'':
|
||||
reader.feed_eof()
|
||||
break
|
||||
|
||||
async def make_reader():
|
||||
reader = asyncio.StreamReader()
|
||||
asyncio.create_task(reader_gen(reader))
|
||||
return reader
|
||||
|
||||
class DummyProto(asyncio.Protocol):
|
||||
async def _drain_helper(self) -> None: # satisfy StreamWriter.drain()
|
||||
return None
|
||||
|
||||
async def make_writer(reader: asyncio.StreamReader):
|
||||
class DummyTransport(asyncio.Transport):
|
||||
def write(inner_self, data: bytes) -> None:
|
||||
self.write_q.put_nowait(data)
|
||||
|
||||
def is_closing(inner_self) -> bool: # noqa: PLW3201
|
||||
return False
|
||||
|
||||
return asyncio.StreamWriter(DummyTransport(), DummyProto(), reader, loop)
|
||||
|
||||
return make_reader, make_writer
|
||||
|
||||
|
||||
async def rpc_pair():
|
||||
mem = MemoryRW()
|
||||
make_reader, make_writer = mem.get_streams()
|
||||
reader = await make_reader()
|
||||
writer = await make_writer(reader)
|
||||
|
||||
stream = NDJsonStdio(reader, writer)
|
||||
rpc = JsonRpcConnection(stream)
|
||||
server = ACPAgentServer(rpc)
|
||||
|
||||
async def serve():
|
||||
await rpc.serve(server.handle_request, server.handle_notification)
|
||||
|
||||
task = asyncio.create_task(serve())
|
||||
return mem, task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_minimal_initialize_and_prompt():
|
||||
mem, task = await rpc_pair()
|
||||
|
||||
def encode(obj: Any) -> bytes:
|
||||
return (json.dumps(obj) + '\n').encode()
|
||||
|
||||
# send initialize request
|
||||
mem.read_q.put_nowait(
|
||||
encode({'jsonrpc': '2.0', 'id': 1, 'method': 'initialize', 'params': {}})
|
||||
)
|
||||
|
||||
# read initialize response
|
||||
data = await mem.write_q.get()
|
||||
msg = json.loads(data.decode())
|
||||
assert msg['id'] == 1
|
||||
assert 'result' in msg
|
||||
assert msg['result']['protocolVersion'] == 1
|
||||
|
||||
# new session
|
||||
mem.read_q.put_nowait(
|
||||
encode({'jsonrpc': '2.0', 'id': 2, 'method': 'session/new', 'params': {}})
|
||||
)
|
||||
msg = json.loads((await mem.write_q.get()).decode())
|
||||
assert msg['id'] == 2
|
||||
session_id = msg['result']['sessionId']
|
||||
|
||||
# prompt
|
||||
mem.read_q.put_nowait(
|
||||
encode(
|
||||
{
|
||||
'jsonrpc': '2.0',
|
||||
'id': 3,
|
||||
'method': 'session/prompt',
|
||||
'params': {'sessionId': session_id, 'messages': []},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Expect one or more session/update notifications before the result
|
||||
while True:
|
||||
msg = json.loads((await mem.write_q.get()).decode())
|
||||
if 'method' in msg:
|
||||
assert msg['method'] == 'session/update'
|
||||
assert msg['params']['sessionId'] == session_id
|
||||
continue
|
||||
# Then response to prompt
|
||||
assert msg['id'] == 3
|
||||
assert msg['result']['stopReason'] in ('end_turn', 'cancelled')
|
||||
break
|
||||
|
||||
# Close
|
||||
mem.read_q.put_nowait(b'')
|
||||
await asyncio.sleep(0) # let server finish
|
||||
task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
Reference in New Issue
Block a user