Fix up conversation initialization (#6430)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Robert Brennan
2025-01-24 13:43:02 -05:00
committed by GitHub
parent 19a4f1c3ec
commit 38e19d214d
20 changed files with 86 additions and 138 deletions

View File

@@ -1,20 +1,20 @@
import { describe, it, expect } from "vitest";
import store from "../src/store";
import {
setInitialQuery,
clearInitialQuery,
setInitialPrompt,
clearInitialPrompt,
} from "../src/state/initial-query-slice";
describe("Initial Query Behavior", () => {
it("should clear initial query when clearInitialQuery is dispatched", () => {
it("should clear initial query when clearInitialPrompt is dispatched", () => {
// Set up initial query in the store
store.dispatch(setInitialQuery("test query"));
expect(store.getState().initialQuery.initialQuery).toBe("test query");
store.dispatch(setInitialPrompt("test query"));
expect(store.getState().initialQuery.initialPrompt).toBe("test query");
// Clear the initial query
store.dispatch(clearInitialQuery());
store.dispatch(clearInitialPrompt());
// Verify initial query is cleared
expect(store.getState().initialQuery.initialQuery).toBeNull();
expect(store.getState().initialQuery.initialPrompt).toBeNull();
});
});

View File

@@ -244,10 +244,14 @@ class OpenHands {
static async createConversation(
githubToken?: string,
selectedRepository?: string,
initialUserMsg?: string,
imageUrls?: string[],
): Promise<Conversation> {
const body = {
github_token: githubToken,
selected_repository: selectedRepository,
initial_user_msg: initialUserMsg,
image_urls: imageUrls,
};
const { data } = await openHands.post<Conversation>(

View File

@@ -23,7 +23,7 @@ export const AGENT_STATUS_MAP: {
},
[AgentState.AWAITING_USER_INPUT]: {
message: I18nKey.CHAT_INTERFACE$AGENT_AWAITING_USER_INPUT_MESSAGE,
indicator: IndicatorColor.ORANGE,
indicator: IndicatorColor.BLUE,
},
[AgentState.PAUSED]: {
message: I18nKey.CHAT_INTERFACE$AGENT_PAUSED_MESSAGE,

View File

@@ -3,7 +3,7 @@ import { useNavigate } from "react-router";
import posthog from "posthog-js";
import { useDispatch, useSelector } from "react-redux";
import OpenHands from "#/api/open-hands";
import { setInitialQuery } from "#/state/initial-query-slice";
import { setInitialPrompt } from "#/state/initial-query-slice";
import { RootState } from "#/store";
import { useAuth } from "#/context/auth-context";
@@ -18,7 +18,7 @@ export const useCreateConversation = () => {
);
return useMutation({
mutationFn: (variables: { q?: string }) => {
mutationFn: async (variables: { q?: string }) => {
if (
!variables.q?.trim() &&
!selectedRepository &&
@@ -28,10 +28,13 @@ export const useCreateConversation = () => {
throw new Error("No query provided");
}
if (variables.q) dispatch(setInitialQuery(variables.q));
if (variables.q) dispatch(setInitialPrompt(variables.q));
return OpenHands.createConversation(
gitHubToken || undefined,
selectedRepository || undefined,
variables.q,
files,
);
},
onSuccess: async ({ conversation_id: conversationId }, { q }) => {

View File

@@ -1,10 +1,8 @@
import React from "react";
import { useWSStatusChange } from "./hooks/use-ws-status-change";
import { useHandleWSEvents } from "./hooks/use-handle-ws-events";
import { useHandleRuntimeActive } from "./hooks/use-handle-runtime-active";
export function EventHandler({ children }: React.PropsWithChildren) {
useWSStatusChange();
useHandleWSEvents();
useHandleRuntimeActive();

View File

@@ -1,68 +0,0 @@
import React from "react";
import { useDispatch, useSelector } from "react-redux";
import {
useWsClient,
WsClientProviderStatus,
} from "#/context/ws-client-provider";
import { createChatMessage } from "#/services/chat-service";
import { setCurrentAgentState } from "#/state/agent-slice";
import { addUserMessage } from "#/state/chat-slice";
import { clearFiles, clearInitialQuery } from "#/state/initial-query-slice";
import { RootState } from "#/store";
import { AgentState } from "#/types/agent-state";
export const useWSStatusChange = () => {
const { send, status } = useWsClient();
const { curAgentState } = useSelector((state: RootState) => state.agent);
const dispatch = useDispatch();
const statusRef = React.useRef<WsClientProviderStatus | null>(null);
const { files, initialQuery } = useSelector(
(state: RootState) => state.initialQuery,
);
const sendInitialQuery = (query: string, base64Files: string[]) => {
const timestamp = new Date().toISOString();
send(createChatMessage(query, base64Files, timestamp));
};
const dispatchInitialQuery = (query: string) => {
sendInitialQuery(query, files);
dispatch(clearFiles()); // reset selected files
dispatch(clearInitialQuery()); // reset initial query
};
const handleAgentInit = () => {
if (initialQuery) {
dispatchInitialQuery(initialQuery);
}
};
React.useEffect(() => {
if (curAgentState === AgentState.INIT) {
handleAgentInit();
}
}, [curAgentState]);
React.useEffect(() => {
if (statusRef.current === status) {
return; // This is a check because of strict mode - if the status did not change, don't do anything
}
statusRef.current = status;
if (status !== WsClientProviderStatus.DISCONNECTED && initialQuery) {
dispatch(
addUserMessage({
content: initialQuery,
imageUrls: files,
timestamp: new Date().toISOString(),
pending: true,
}),
);
}
if (status === WsClientProviderStatus.DISCONNECTED) {
dispatch(setCurrentAgentState(AgentState.STOPPED));
}
}, [status]);
};

View File

@@ -1,7 +1,7 @@
import { useDisclosure } from "@nextui-org/react";
import React from "react";
import { Outlet } from "react-router";
import { useDispatch } from "react-redux";
import { useDispatch, useSelector } from "react-redux";
import { FaServer } from "react-icons/fa";
import toast from "react-hot-toast";
import { useTranslation } from "react-i18next";
@@ -11,7 +11,7 @@ import {
useConversation,
} from "#/context/conversation-context";
import { Controls } from "#/components/features/controls/controls";
import { clearMessages } from "#/state/chat-slice";
import { clearMessages, addUserMessage } from "#/state/chat-slice";
import { clearTerminal } from "#/state/command-slice";
import { useEffectOnce } from "#/hooks/use-effect-once";
import CodeIcon from "#/icons/code.svg?react";
@@ -36,6 +36,8 @@ import { ServedAppLabel } from "#/components/layout/served-app-label";
import { TerminalStatusLabel } from "#/components/features/terminal/terminal-status-label";
import { useSettings } from "#/hooks/query/use-settings";
import { MULTI_CONVERSATION_UI } from "#/utils/feature-flags";
import { clearFiles, clearInitialPrompt } from "#/state/initial-query-slice";
import { RootState } from "#/store";
function AppContent() {
useConversationConfig();
@@ -46,6 +48,9 @@ function AppContent() {
const { data: conversation, isFetched } = useUserConversation(
conversationId || null,
);
const { initialPrompt, files } = useSelector(
(state: RootState) => state.initialQuery,
);
const dispatch = useDispatch();
const endSession = useEndSession();
@@ -74,6 +79,18 @@ function AppContent() {
dispatch(clearMessages());
dispatch(clearTerminal());
dispatch(clearJupyter());
if (conversationId && (initialPrompt || files.length > 0)) {
dispatch(
addUserMessage({
content: initialPrompt || "",
imageUrls: files || [],
timestamp: new Date().toISOString(),
pending: true,
}),
);
dispatch(clearInitialPrompt());
dispatch(clearFiles());
}
}, [conversationId]);
useEffectOnce(() => {

View File

@@ -2,14 +2,14 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit";
type SliceState = {
files: string[]; // base64 encoded images
initialQuery: string | null;
initialPrompt: string | null;
selectedRepository: string | null;
importedProjectZip: string | null; // base64 encoded zip
};
const initialState: SliceState = {
files: [],
initialQuery: null,
initialPrompt: null,
selectedRepository: null,
importedProjectZip: null,
};
@@ -27,11 +27,11 @@ export const selectedFilesSlice = createSlice({
clearFiles(state) {
state.files = [];
},
setInitialQuery(state, action: PayloadAction<string>) {
state.initialQuery = action.payload;
setInitialPrompt(state, action: PayloadAction<string>) {
state.initialPrompt = action.payload;
},
clearInitialQuery(state) {
state.initialQuery = null;
clearInitialPrompt(state) {
state.initialPrompt = null;
},
setSelectedRepository(state, action: PayloadAction<string | null>) {
state.selectedRepository = action.payload;
@@ -49,8 +49,8 @@ export const {
addFile,
removeFile,
clearFiles,
setInitialQuery,
clearInitialQuery,
setInitialPrompt,
clearInitialPrompt,
setSelectedRepository,
clearSelectedRepository,
setImportedProjectZip,

View File

@@ -501,10 +501,6 @@ class AgentController:
EventSource.ENVIRONMENT,
)
if new_state == AgentState.INIT and self.state.resume_state:
await self.set_agent_state_to(self.state.resume_state)
self.state.resume_state = None
def get_agent_state(self) -> AgentState:
"""Returns the current state of the agent.

View File

@@ -4,10 +4,6 @@ __all__ = ['ActionType']
class ActionTypeSchema(BaseModel):
INIT: str = Field(default='initialize')
"""Initializes the agent. Only sent by client.
"""
MESSAGE: str = Field(default='message')
"""Represents a message.
"""

View File

@@ -6,10 +6,6 @@ class AgentState(str, Enum):
"""The agent is loading.
"""
INIT = 'init'
"""The agent is initialized.
"""
RUNNING = 'running'
"""The agent is running.
"""

View File

@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
import socketio
from openhands.core.config import AppConfig
from openhands.events.action import MessageAction
from openhands.events.stream import EventStream
from openhands.server.session.conversation import Conversation
from openhands.server.settings import Settings
@@ -68,7 +69,7 @@ class ConversationManager(ABC):
sid: str,
settings: Settings,
user_id: str | None,
initial_user_msg: str | None = None,
initial_user_msg: MessageAction | None = None,
) -> EventStream:
"""Start an event loop if one is not already running"""

View File

@@ -9,6 +9,7 @@ from openhands.core.config.app_config import AppConfig
from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.stream import EventStream, session_exists
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import ROOM_KEY, Session
@@ -186,7 +187,7 @@ class StandaloneConversationManager(ConversationManager):
sid: str,
settings: Settings,
user_id: str | None,
initial_user_msg: str | None = None,
initial_user_msg: MessageAction | None = None,
) -> EventStream:
logger.info(f'maybe_start_agent_loop:{sid}')
session: Session | None = None

View File

@@ -5,7 +5,6 @@ from pydantic import SecretStr
from socketio.exceptions import ConnectionRefusedError
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import (
NullAction,
)
@@ -86,8 +85,6 @@ async def connect(connection_id: str, environ, auth):
):
continue
elif isinstance(event, AgentStateChangedObservation):
if event.agent_state == AgentState.INIT:
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
agent_state_changed = event
else:
await sio.emit('oh_event', event_to_dict(event), to=connection_id)

View File

@@ -2,7 +2,6 @@ import uvicorn
from fastapi import FastAPI, WebSocket
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import ActionType
from openhands.utils.shutdown_listener import should_continue
app = FastAPI()
@@ -11,10 +10,6 @@ app = FastAPI()
@app.websocket('/ws')
async def websocket_endpoint(websocket: WebSocket):
await websocket.accept()
# send message to mock connection
await websocket.send_json(
{'action': ActionType.INIT, 'message': 'Control loop started.'}
)
try:
while should_continue():

View File

@@ -7,6 +7,7 @@ from fastapi.responses import JSONResponse
from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.message import MessageAction
from openhands.events.stream import EventStreamSubscriber
from openhands.runtime import get_runtime_cls
from openhands.server.auth import get_user_id
@@ -34,6 +35,7 @@ class InitSessionRequest(BaseModel):
github_token: str | None = None
selected_repository: str | None = None
initial_user_msg: str | None = None
image_urls: list[str] | None = None
async def _create_new_conversation(
@@ -41,6 +43,7 @@ async def _create_new_conversation(
token: str | None,
selected_repository: str | None,
initial_user_msg: str | None,
image_urls: list[str] | None,
):
logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
@@ -94,8 +97,14 @@ async def _create_new_conversation(
)
logger.info(f'Starting agent loop for conversation {conversation_id}')
initial_message_action = None
if initial_user_msg or image_urls:
initial_message_action = MessageAction(
content=initial_user_msg or '',
image_urls=image_urls or [],
)
event_stream = await conversation_manager.maybe_start_agent_loop(
conversation_id, conversation_init_data, user_id, initial_user_msg
conversation_id, conversation_init_data, user_id, initial_message_action
)
try:
event_stream.subscribe(
@@ -121,10 +130,16 @@ async def new_conversation(request: Request, data: InitSessionRequest):
github_token = getattr(request.state, 'github_token', '') or data.github_token
selected_repository = data.selected_repository
initial_user_msg = data.initial_user_msg
image_urls = data.image_urls or []
try:
# Create conversation with initial message
conversation_id = await _create_new_conversation(
user_id, github_token, selected_repository, initial_user_msg
user_id,
github_token,
selected_repository,
initial_user_msg,
image_urls,
)
return JSONResponse(

View File

@@ -8,19 +8,12 @@ interruptions are recoverable.
There are 3 main server side event handlers:
* `connect` - Invoked when a new connection to the server is established. (This may be via http or WebSocket)
* `oh_action` - Invoked when a connected client sends an event (Such as `INIT` or a prompt for the Agent) -
* `oh_action` - Invoked when a connected client sends an event (such as a prompt for the Agent) -
this is distinct from the `oh_event` sent from the server to the client.
* `disconnect` - Invoked when a connected client disconnects from the server.
## Init
Each connection has a unique id, and when initially established, is not associated with any session. An
`INIT` event must be sent to the server in order to attach a connection to a session. The `INIT` event
may optionally include a GitHub token and a token to connect to an existing session. (Which may be running
locally or may need to be hydrated). If no token is received as part of the init event, it is assumed a
new session should be started.
## Disconnect
The (manager)[manager.py] manages connections and sessions. Each session may have zero or more connections
associated with it, managed by invocations of `INIT` and disconnect. When a session no longer has any
associated with it. When a session no longer has any
connections associated with it, after a set amount of time (determined by `config.sandbox.close_delay`),
the session and runtime are passivated (So will need to be rehydrated to continue.)

View File

@@ -9,8 +9,7 @@ from openhands.core.config import AgentConfig, AppConfig, LLMConfig
from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import ChangeAgentStateAction
from openhands.events.action.message import MessageAction
from openhands.events.action import ChangeAgentStateAction, MessageAction
from openhands.events.event import EventSource
from openhands.events.stream import EventStream
from openhands.microagent import BaseMicroAgent
@@ -72,7 +71,7 @@ class AgentSession:
agent_configs: dict[str, AgentConfig] | None = None,
github_token: str | None = None,
selected_repository: str | None = None,
initial_user_msg: str | None = None,
initial_message: MessageAction | None = None,
):
"""Starts the Agent session
Parameters:
@@ -111,15 +110,17 @@ class AgentSession:
agent_to_llm_config=agent_to_llm_config,
agent_configs=agent_configs,
)
self.event_stream.add_event(
ChangeAgentStateAction(AgentState.INIT), EventSource.ENVIRONMENT
)
if initial_user_msg:
if initial_message:
self.event_stream.add_event(initial_message, EventSource.USER)
self.event_stream.add_event(
MessageAction(content=initial_user_msg), EventSource.USER
ChangeAgentStateAction(AgentState.RUNNING), EventSource.ENVIRONMENT
)
else:
self.event_stream.add_event(
ChangeAgentStateAction(AgentState.AWAITING_USER_INPUT),
EventSource.ENVIRONMENT,
)
self._starting = False
async def close(self):

View File

@@ -11,6 +11,7 @@ from openhands.core.config import AppConfig
from openhands.core.exceptions import AgentRuntimeUnavailableError
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.stream import EventStream, session_exists
from openhands.server.session.agent_session import WAIT_TIME_BEFORE_CLOSE
from openhands.server.session.conversation import Conversation
@@ -446,7 +447,7 @@ class SessionManager:
sid: str,
settings: Settings,
user_id: str | None,
initial_user_msg: str | None = None,
initial_message: MessageAction | None = None,
) -> EventStream:
logger.info(f'maybe_start_agent_loop:{sid}')
session: Session | None = None
@@ -469,7 +470,7 @@ class SessionManager:
user_id=user_id,
)
self._local_agent_loops_by_sid[sid] = session
asyncio.create_task(session.initialize_agent(settings, initial_user_msg))
asyncio.create_task(session.initialize_agent(settings, initial_message))
event_stream = await self._get_event_stream(sid)
if not event_stream:

View File

@@ -74,7 +74,9 @@ class Session:
self.is_alive = False
await self.agent_session.close()
async def initialize_agent(self, settings: Settings, initial_user_msg: str | None):
async def initialize_agent(
self, settings: Settings, initial_message: MessageAction | None
):
self.agent_session.event_stream.add_event(
AgentStateChangedObservation('', AgentState.LOADING),
EventSource.ENVIRONMENT,
@@ -122,7 +124,7 @@ class Session:
agent_configs=self.config.get_agent_configs(),
github_token=github_token,
selected_repository=selected_repository,
initial_user_msg=initial_user_msg,
initial_message=initial_message,
)
except Exception as e:
logger.exception(f'Error creating agent_session: {e}')