Feat: Introduce class for SessionInitData rather than using a dict (#5406)

This commit is contained in:
tofarr
2024-12-05 13:11:00 -07:00
committed by GitHub
parent 1146b6248b
commit de81020a8d
6 changed files with 68 additions and 45 deletions

View File

@@ -11,6 +11,7 @@ from openhands.events.stream import EventStream, session_exists
from openhands.runtime.base import RuntimeUnavailableError
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import ROOM_KEY, Session
from openhands.server.session.session_init_data import SessionInitData
from openhands.storage.files import FileStore
from openhands.utils.shutdown_listener import should_continue
@@ -141,7 +142,7 @@ class SessionManager:
async def detach_from_conversation(self, conversation: Conversation):
await conversation.disconnect()
async def init_or_join_session(self, sid: str, connection_id: str, data: dict):
async def init_or_join_session(self, sid: str, connection_id: str, session_init_data: SessionInitData):
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self.local_connection_id_to_session_id[connection_id] = sid
@@ -156,7 +157,7 @@ class SessionManager:
if redis_client and await self._is_session_running_in_cluster(sid):
return EventStream(sid, self.file_store)
return await self.start_local_session(sid, data)
return await self.start_local_session(sid, session_init_data)
async def _is_session_running_in_cluster(self, sid: str) -> bool:
"""Ask the rest of the cluster if a session is running. Wait for a short timeout for a reply."""
@@ -210,14 +211,14 @@ class SessionManager:
finally:
self._has_remote_connections_flags.pop(sid)
async def start_local_session(self, sid: str, data: dict):
async def start_local_session(self, sid: str, session_init_data: SessionInitData):
# Start a new local session
logger.info(f'start_new_local_session:{sid}')
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
)
self.local_sessions_by_sid[sid] = session
await session.initialize_agent(data)
await session.initialize_agent(session_init_data)
return session.agent_session.event_stream
async def send_to_event_stream(self, connection_id: str, data: dict):

View File

@@ -1,4 +1,5 @@
import asyncio
from copy import deepcopy
import time
import socketio
@@ -21,6 +22,7 @@ from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.server.session.agent_session import AgentSession
from openhands.server.session.session_init_data import SessionInitData
from openhands.storage.files import FileStore
ROOM_KEY = 'room:{sid}'
@@ -34,7 +36,6 @@ class Session:
agent_session: AgentSession
loop: asyncio.AbstractEventLoop
config: AppConfig
settings: dict | None
def __init__(
self,
@@ -52,41 +53,31 @@ class Session:
self.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER, self.on_event, self.sid
)
self.config = config
# Copying this means that when we update variables they are not applied to the shared global configuration!
self.config = deepcopy(config)
self.loop = asyncio.get_event_loop()
self.settings = None
def close(self):
self.is_alive = False
self.agent_session.close()
async def initialize_agent(self, data: dict):
self.settings = data
async def initialize_agent(self, session_init_data: SessionInitData):
self.agent_session.event_stream.add_event(
AgentStateChangedObservation('', AgentState.LOADING),
EventSource.ENVIRONMENT,
)
# Extract the agent-relevant arguments from the request
args = {key: value for key, value in data.get('args', {}).items()}
agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
self.config.security.confirmation_mode = args.get(
ConfigType.CONFIRMATION_MODE, self.config.security.confirmation_mode
)
self.config.security.security_analyzer = data.get('args', {}).get(
ConfigType.SECURITY_ANALYZER, self.config.security.security_analyzer
)
max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)
agent_cls = session_init_data.agent or self.config.default_agent
self.config.security.confirmation_mode = self.config.security.confirmation_mode if session_init_data.confirmation_mode is None else session_init_data.confirmation_mode
self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
max_iterations = session_init_data.max_iterations or self.config.max_iterations
# override default LLM config
default_llm_config = self.config.get_llm_config()
default_llm_config.model = args.get(
ConfigType.LLM_MODEL, default_llm_config.model
)
default_llm_config.api_key = args.get(
ConfigType.LLM_API_KEY, default_llm_config.api_key
)
default_llm_config.base_url = args.get(
ConfigType.LLM_BASE_URL, default_llm_config.base_url
)
default_llm_config.model = session_init_data.llm_model or default_llm_config.model
default_llm_config.api_key = session_init_data.llm_api_key or default_llm_config.api_key
default_llm_config.base_url = session_init_data.llm_base_url or default_llm_config.base_url
# TODO: override other LLM config & agent config groups (#2075)

View File

@@ -0,0 +1,18 @@
from dataclasses import dataclass, field
@dataclass
class SessionInitData:
"""
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
"""
language: str | None = None
agent: str | None = None
max_iterations: int | None = None
security_analyzer: str | None = None
confirmation_mode: bool | None = None
llm_model: str | None = None
llm_api_key: str | None = None
llm_base_url: str | None = None

View File

@@ -13,6 +13,7 @@ from openhands.events.serialization import event_to_dict
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.server.auth import get_sid_from_token, sign_token
from openhands.server.github_utils import authenticate_github_user
from openhands.server.session.session_init_data import SessionInitData
from openhands.server.shared import config, session_manager, sio
@@ -26,19 +27,30 @@ async def oh_action(connection_id: str, data: dict):
# If it's an init, we do it here.
action = data.get('action', '')
if action == ActionType.INIT:
await init_connection(connection_id, data)
token = data.pop('token', None)
github_token = data.pop('github_token', None)
latest_event_id = int(data.pop('latest_event_id', -1))
kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()}
session_init_data = SessionInitData(**kwargs)
await init_connection(
connection_id, token, github_token, session_init_data, latest_event_id
)
return
logger.info(f'sio:oh_action:{connection_id}')
await session_manager.send_to_event_stream(connection_id, data)
async def init_connection(connection_id: str, data: dict):
gh_token = data.pop('github_token', None)
async def init_connection(
connection_id: str,
token: str | None,
gh_token: str | None,
session_init_data: SessionInitData,
latest_event_id: int,
):
if not await authenticate_github_user(gh_token):
raise RuntimeError(status.WS_1008_POLICY_VIOLATION)
token = data.pop('token', None)
if token:
sid = get_sid_from_token(token, config.jwt_secret)
if sid == '':
@@ -52,10 +64,10 @@ async def init_connection(connection_id: str, data: dict):
token = sign_token({'sid': sid}, config.jwt_secret)
await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id)
latest_event_id = int(data.pop('latest_event_id', -1))
# The session in question should exist, but may not actually be running locally...
event_stream = await session_manager.init_or_join_session(sid, connection_id, data)
event_stream = await session_manager.init_or_join_session(
sid, connection_id, session_init_data
)
# Send events
agent_state_changed = None