From 000055ba73ffff6a2ac63f6deced72fbd889e204 Mon Sep 17 00:00:00 2001 From: Rohit Malhotra Date: Fri, 17 Jan 2025 09:43:03 -0500 Subject: [PATCH] Add initial user msg to /new_conversation route (#6314) --- openhands/server/routes/manage_conversations.py | 8 ++++++-- openhands/server/session/agent_session.py | 8 ++++++++ openhands/server/session/manager.py | 8 ++++++-- openhands/server/session/session.py | 6 ++---- 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 0e8a136670..767c52b706 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -32,12 +32,14 @@ UPDATED_AT_CALLBACK_ID = 'updated_at_callback_id' class InitSessionRequest(BaseModel): github_token: str | None = None selected_repository: str | None = None + initial_user_msg: str | None = None async def _create_new_conversation( user_id: str | None, token: str | None, selected_repository: str | None, + initial_user_msg: str | None, ): logger.info('Loading settings') settings_store = await SettingsStoreImpl.get_instance(config, user_id) @@ -89,7 +91,7 @@ async def _create_new_conversation( logger.info(f'Starting agent loop for conversation {conversation_id}') event_stream = await session_manager.maybe_start_agent_loop( - conversation_id, conversation_init_data, user_id + conversation_id, conversation_init_data, user_id, initial_user_msg ) try: event_stream.subscribe( @@ -114,10 +116,11 @@ async def new_conversation(request: Request, data: InitSessionRequest): user_id = get_user_id(request) github_token = getattr(request.state, 'github_token', '') or data.github_token selected_repository = data.selected_repository + initial_user_msg = data.initial_user_msg try: conversation_id = await _create_new_conversation( - user_id, github_token, selected_repository + user_id, github_token, selected_repository, initial_user_msg ) return JSONResponse( @@ -140,6 +143,7 @@ async def new_conversation(request: Request, data: InitSessionRequest): 'message': str(e), 'msg_id': 'STATUS$ERROR_LLM_AUTHENTICATION', }, + status_code=400, ) diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 285acccbfb..d876e45788 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -10,6 +10,7 @@ from openhands.core.exceptions import AgentRuntimeUnavailableError from openhands.core.logger import openhands_logger as logger from openhands.core.schema.agent import AgentState from openhands.events.action import ChangeAgentStateAction +from openhands.events.action.message import MessageAction from openhands.events.event import EventSource from openhands.events.stream import EventStream from openhands.microagent import BaseMicroAgent @@ -71,6 +72,7 @@ class AgentSession: agent_configs: dict[str, AgentConfig] | None = None, github_token: str | None = None, selected_repository: str | None = None, + initial_user_msg: str | None = None, ): """Starts the Agent session Parameters: @@ -112,6 +114,12 @@ class AgentSession: self.event_stream.add_event( ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT ) + + if initial_user_msg: + self.event_stream.add_event( + MessageAction(content=initial_user_msg), EventSource.USER + ) + self._starting = False async def close(self): diff --git a/openhands/server/session/manager.py b/openhands/server/session/manager.py index f3158f81b1..203a8dc3b2 100644 --- a/openhands/server/session/manager.py +++ b/openhands/server/session/manager.py @@ -442,7 +442,11 @@ class SessionManager: self._connection_queries.pop(query_id, None) async def maybe_start_agent_loop( - self, sid: str, settings: Settings, user_id: str | None + self, + sid: str, + settings: Settings, + user_id: str | None, + initial_user_msg: str | None = None, ) -> EventStream: logger.info(f'maybe_start_agent_loop:{sid}') session: Session | None = None @@ -462,7 +466,7 @@ class SessionManager: user_id=user_id, ) self._local_agent_loops_by_sid[sid] = session - asyncio.create_task(session.initialize_agent(settings)) + asyncio.create_task(session.initialize_agent(settings, initial_user_msg)) event_stream = await self._get_event_stream(sid) if not event_stream: diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index e77a77101b..b24b297020 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -74,10 +74,7 @@ class Session: self.is_alive = False await self.agent_session.close() - async def initialize_agent( - self, - settings: Settings, - ): + async def initialize_agent(self, settings: Settings, initial_user_msg: str | None): self.agent_session.event_stream.add_event( AgentStateChangedObservation('', AgentState.LOADING), EventSource.ENVIRONMENT, @@ -122,6 +119,7 @@ class Session: agent_configs=self.config.get_agent_configs(), github_token=github_token, selected_repository=selected_repository, + initial_user_msg=initial_user_msg, ) except Exception as e: logger.exception(f'Error creating agent_session: {e}')