From de81020a8dc9ce1dfdcfddf7cf85fe0615b25c73 Mon Sep 17 00:00:00 2001 From: tofarr Date: Thu, 5 Dec 2024 13:11:00 -0700 Subject: [PATCH] Feat: Introduce class for SessionInitData rather than using a dict (#5406) --- config.template.toml | 10 ++--- openhands/server/session/manager.py | 9 +++-- openhands/server/session/session.py | 37 +++++++------------ openhands/server/session/session_init_data.py | 18 +++++++++ openhands/server/socket.py | 26 +++++++++---- tests/unit/test_manager.py | 13 ++++--- 6 files changed, 68 insertions(+), 45 deletions(-) create mode 100644 openhands/server/session/session_init_data.py diff --git a/config.template.toml b/config.template.toml index d19ff6085e..9d84c5a7ff 100644 --- a/config.template.toml +++ b/config.template.toml @@ -95,10 +95,10 @@ workspace_base = "./workspace" # AWS secret access key #aws_secret_access_key = "" -# API key to use +# API key to use (For Headless / CLI only - In Web this is overridden by Session Init) api_key = "your-api-key" -# API base URL +# API base URL (For Headless / CLI only - In Web this is overridden by Session Init) #base_url = "" # API version @@ -131,7 +131,7 @@ embedding_model = "local" # Maximum number of output tokens #max_output_tokens = 0 -# Model to use +# Model to use. (For Headless / CLI only - In Web this is overridden by Session Init) model = "gpt-4o" # Number of retries to attempt when an operation fails with the LLM. @@ -237,10 +237,10 @@ llm_config = 'gpt3' ############################################################################## [security] -# Enable confirmation mode +# Enable confirmation mode (For Headless / CLI only - In Web this is overridden by Session Init) #confirmation_mode = false -# The security analyzer to use +# The security analyzer to use (For Headless / CLI only - In Web this is overridden by Session Init) #security_analyzer = "" #################################### Eval #################################### diff --git a/openhands/server/session/manager.py b/openhands/server/session/manager.py index 7447ce0f8e..10b34b5dd2 100644 --- a/openhands/server/session/manager.py +++ b/openhands/server/session/manager.py @@ -11,6 +11,7 @@ from openhands.events.stream import EventStream, session_exists from openhands.runtime.base import RuntimeUnavailableError from openhands.server.session.conversation import Conversation from openhands.server.session.session import ROOM_KEY, Session +from openhands.server.session.session_init_data import SessionInitData from openhands.storage.files import FileStore from openhands.utils.shutdown_listener import should_continue @@ -141,7 +142,7 @@ class SessionManager: async def detach_from_conversation(self, conversation: Conversation): await conversation.disconnect() - async def init_or_join_session(self, sid: str, connection_id: str, data: dict): + async def init_or_join_session(self, sid: str, connection_id: str, session_init_data: SessionInitData): await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid)) self.local_connection_id_to_session_id[connection_id] = sid @@ -156,7 +157,7 @@ class SessionManager: if redis_client and await self._is_session_running_in_cluster(sid): return EventStream(sid, self.file_store) - return await self.start_local_session(sid, data) + return await self.start_local_session(sid, session_init_data) async def _is_session_running_in_cluster(self, sid: str) -> bool: """As the rest of the cluster if a session is running. Wait a for a short timeout for a reply""" @@ -210,14 +211,14 @@ class SessionManager: finally: self._has_remote_connections_flags.pop(sid) - async def start_local_session(self, sid: str, data: dict): + async def start_local_session(self, sid: str, session_init_data: SessionInitData): # Start a new local session logger.info(f'start_new_local_session:{sid}') session = Session( sid=sid, file_store=self.file_store, config=self.config, sio=self.sio ) self.local_sessions_by_sid[sid] = session - await session.initialize_agent(data) + await session.initialize_agent(session_init_data) return session.agent_session.event_stream async def send_to_event_stream(self, connection_id: str, data: dict): diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index a6edb82f3d..1039d954bc 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -1,4 +1,5 @@ import asyncio +from copy import deepcopy import time import socketio @@ -21,6 +22,7 @@ from openhands.events.serialization import event_from_dict, event_to_dict from openhands.events.stream import EventStreamSubscriber from openhands.llm.llm import LLM from openhands.server.session.agent_session import AgentSession +from openhands.server.session.session_init_data import SessionInitData from openhands.storage.files import FileStore ROOM_KEY = 'room:{sid}' @@ -34,7 +36,6 @@ class Session: agent_session: AgentSession loop: asyncio.AbstractEventLoop config: AppConfig - settings: dict | None def __init__( self, @@ -52,41 +53,31 @@ class Session: self.agent_session.event_stream.subscribe( EventStreamSubscriber.SERVER, self.on_event, self.sid ) - self.config = config + # Copying this means that when we update variables they are not applied to the shared global configuration! + self.config = deepcopy(config) self.loop = asyncio.get_event_loop() - self.settings = None def close(self): self.is_alive = False self.agent_session.close() - async def initialize_agent(self, data: dict): - self.settings = data + async def initialize_agent(self, session_init_data: SessionInitData): self.agent_session.event_stream.add_event( AgentStateChangedObservation('', AgentState.LOADING), EventSource.ENVIRONMENT, ) # Extract the agent-relevant arguments from the request - args = {key: value for key, value in data.get('args', {}).items()} - agent_cls = args.get(ConfigType.AGENT, self.config.default_agent) - self.config.security.confirmation_mode = args.get( - ConfigType.CONFIRMATION_MODE, self.config.security.confirmation_mode - ) - self.config.security.security_analyzer = data.get('args', {}).get( - ConfigType.SECURITY_ANALYZER, self.config.security.security_analyzer - ) - max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations) + agent_cls = session_init_data.agent or self.config.default_agent + self.config.security.confirmation_mode = self.config.security.confirmation_mode if session_init_data.confirmation_mode is None else session_init_data.confirmation_mode + self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer + max_iterations = session_init_data.max_iterations or self.config.max_iterations # override default LLM config + + default_llm_config = self.config.get_llm_config() - default_llm_config.model = args.get( - ConfigType.LLM_MODEL, default_llm_config.model - ) - default_llm_config.api_key = args.get( - ConfigType.LLM_API_KEY, default_llm_config.api_key - ) - default_llm_config.base_url = args.get( - ConfigType.LLM_BASE_URL, default_llm_config.base_url - ) + default_llm_config.model = session_init_data.llm_model or default_llm_config.model + default_llm_config.api_key = session_init_data.llm_api_key or default_llm_config.api_key + default_llm_config.base_url = session_init_data.llm_base_url or default_llm_config.base_url # TODO: override other LLM config & agent config groups (#2075) diff --git a/openhands/server/session/session_init_data.py b/openhands/server/session/session_init_data.py new file mode 100644 index 0000000000..2c030f0714 --- /dev/null +++ b/openhands/server/session/session_init_data.py @@ -0,0 +1,18 @@ + + +from dataclasses import dataclass + + +@dataclass +class SessionInitData: + """ + Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data. + """ + language: str | None = None + agent: str | None = None + max_iterations: int | None = None + security_analyzer: str | None = None + confirmation_mode: bool | None = None + llm_model: str | None = None + llm_api_key: str | None = None + llm_base_url: str | None = None diff --git a/openhands/server/socket.py b/openhands/server/socket.py index 0dd121d9f5..8e5ab058af 100644 --- a/openhands/server/socket.py +++ b/openhands/server/socket.py @@ -13,6 +13,7 @@ from openhands.events.serialization import event_to_dict from openhands.events.stream import AsyncEventStreamWrapper from openhands.server.auth import get_sid_from_token, sign_token from openhands.server.github_utils import authenticate_github_user +from openhands.server.session.session_init_data import SessionInitData from openhands.server.shared import config, session_manager, sio @@ -26,19 +27,30 @@ async def oh_action(connection_id: str, data: dict): # If it's an init, we do it here. action = data.get('action', '') if action == ActionType.INIT: - await init_connection(connection_id, data) + token = data.pop('token', None) + github_token = data.pop('github_token', None) + latest_event_id = int(data.pop('latest_event_id', -1)) + kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()} + session_init_data = SessionInitData(**kwargs) + await init_connection( + connection_id, token, github_token, session_init_data, latest_event_id + ) return logger.info(f'sio:oh_action:{connection_id}') await session_manager.send_to_event_stream(connection_id, data) -async def init_connection(connection_id: str, data: dict): - gh_token = data.pop('github_token', None) +async def init_connection( + connection_id: str, + token: str | None, + gh_token: str | None, + session_init_data: SessionInitData, + latest_event_id: int, +): if not await authenticate_github_user(gh_token): raise RuntimeError(status.WS_1008_POLICY_VIOLATION) - token = data.pop('token', None) if token: sid = get_sid_from_token(token, config.jwt_secret) if sid == '': @@ -52,10 +64,10 @@ async def init_connection(connection_id: str, data: dict): token = sign_token({'sid': sid}, config.jwt_secret) await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id) - latest_event_id = int(data.pop('latest_event_id', -1)) - # The session in question should exist, but may not actually be running locally... - event_stream = await session_manager.init_or_join_session(sid, connection_id, data) + event_stream = await session_manager.init_or_join_session( + sid, connection_id, session_init_data + ) # Send events agent_state_changed = None diff --git a/tests/unit/test_manager.py b/tests/unit/test_manager.py index d17fc94bf1..9ec9e4ac31 100644 --- a/tests/unit/test_manager.py +++ b/tests/unit/test_manager.py @@ -7,6 +7,7 @@ import pytest from openhands.core.config.app_config import AppConfig from openhands.server.session.manager import SessionManager +from openhands.server.session.session_init_data import SessionInitData from openhands.storage.memory import InMemoryFileStore @@ -100,7 +101,7 @@ async def test_init_new_local_session(): sio, AppConfig(), InMemoryFileStore() ) as session_manager: await session_manager.init_or_join_session( - 'new-session-id', 'new-session-id', {'type': 'mock-settings'} + 'new-session-id', 'new-session-id', SessionInitData() ) assert session_instance.initialize_agent.call_count == 1 assert sio.enter_room.await_count == 1 @@ -132,11 +133,11 @@ async def test_join_local_session(): ) as session_manager: # First call initializes await session_manager.init_or_join_session( - 'new-session-id', 'new-session-id', {'type': 'mock-settings'} + 'new-session-id', 'new-session-id', SessionInitData() ) # Second call joins await session_manager.init_or_join_session( - 'new-session-id', 'extra-connection-id', {'type': 'mock-settings'} + 'new-session-id', 'extra-connection-id', SessionInitData() ) assert session_instance.initialize_agent.call_count == 1 assert sio.enter_room.await_count == 2 @@ -168,7 +169,7 @@ async def test_join_cluster_session(): ) as session_manager: # First call initializes await session_manager.init_or_join_session( - 'new-session-id', 'new-session-id', {'type': 'mock-settings'} + 'new-session-id', 'new-session-id', SessionInitData() ) assert session_instance.initialize_agent.call_count == 0 assert sio.enter_room.await_count == 1 @@ -199,7 +200,7 @@ async def test_add_to_local_event_stream(): sio, AppConfig(), InMemoryFileStore() ) as session_manager: await session_manager.init_or_join_session( - 'new-session-id', 'connection-id', {'type': 'mock-settings'} + 'new-session-id', 'connection-id', SessionInitData() ) await session_manager.send_to_event_stream( 'connection-id', {'event_type': 'some_event'} @@ -232,7 +233,7 @@ async def test_add_to_cluster_event_stream(): sio, AppConfig(), InMemoryFileStore() ) as session_manager: await session_manager.init_or_join_session( - 'new-session-id', 'connection-id', {'type': 'mock-settings'} + 'new-session-id', 'connection-id', SessionInitData() ) await session_manager.send_to_event_stream( 'connection-id', {'event_type': 'some_event'}