diff --git a/frontend/__tests__/components/chat/chat-interface.test.tsx b/frontend/__tests__/components/chat/chat-interface.test.tsx index d27897c2d8..501389f989 100644 --- a/frontend/__tests__/components/chat/chat-interface.test.tsx +++ b/frontend/__tests__/components/chat/chat-interface.test.tsx @@ -128,14 +128,14 @@ describe.skip("ChatInterface", () => { timestamp: new Date().toISOString(), }, { - error: "Woops!", + error: true, + id: "", message: "Something went wrong", }, ]; renderChatInterface(messages); const error = screen.getByTestId("error-message"); - expect(within(error).getByText("Woops!")).toBeInTheDocument(); expect(within(error).getByText("Something went wrong")).toBeInTheDocument(); }); diff --git a/frontend/src/components/AgentStatusBar.tsx b/frontend/src/components/AgentStatusBar.tsx index c337a838f3..7de9ae0397 100644 --- a/frontend/src/components/AgentStatusBar.tsx +++ b/frontend/src/components/AgentStatusBar.tsx @@ -1,6 +1,7 @@ import React, { useEffect } from "react"; import { useTranslation } from "react-i18next"; import { useSelector } from "react-redux"; +import toast from "react-hot-toast"; import { I18nKey } from "#/i18n/declaration"; import { RootState } from "#/store"; import AgentState from "#/types/AgentState"; @@ -16,7 +17,7 @@ enum IndicatorColor { } function AgentStatusBar() { - const { t } = useTranslation(); + const { t, i18n } = useTranslation(); const { curAgentState } = useSelector((state: RootState) => state.agent); const { curStatusMessage } = useSelector((state: RootState) => state.status); @@ -94,15 +95,27 @@ function AgentStatusBar() { const [statusMessage, setStatusMessage] = React.useState(""); React.useEffect(() => { - if (curAgentState === AgentState.LOADING) { - const trimmedCustomMessage = curStatusMessage.status.trim(); - if (trimmedCustomMessage) { - setStatusMessage(t(trimmedCustomMessage)); - return; + let message = curStatusMessage.message || ""; + if (curStatusMessage?.id) { + const id = curStatusMessage.id.trim(); + if (i18n.exists(id)) { + message = t(curStatusMessage.id.trim()) || message; } } + if (curStatusMessage?.type === "error") { + toast.error(message); + return; + } + if (curAgentState === AgentState.LOADING && message.trim()) { + setStatusMessage(message); + } else { + setStatusMessage(AgentStatusMap[curAgentState].message); + } + }, [curStatusMessage.id]); + + React.useEffect(() => { setStatusMessage(AgentStatusMap[curAgentState].message); - }, [curAgentState, curStatusMessage.status]); + }, [curAgentState]); return (
diff --git a/frontend/src/components/chat-interface.tsx b/frontend/src/components/chat-interface.tsx index 25a5707369..9e9af07396 100644 --- a/frontend/src/components/chat-interface.tsx +++ b/frontend/src/components/chat-interface.tsx @@ -73,7 +73,7 @@ export function ChatInterface() { isErrorMessage(message) ? ( ) : ( diff --git a/frontend/src/components/chat/message.d.ts b/frontend/src/components/chat/message.d.ts index e7248fbd64..b2ccf43c99 100644 --- a/frontend/src/components/chat/message.d.ts +++ b/frontend/src/components/chat/message.d.ts @@ -6,6 +6,7 @@ type Message = { }; type ErrorMessage = { - error: string; + error: boolean; + id?: string; message: string; }; diff --git a/frontend/src/components/error-message.tsx b/frontend/src/components/error-message.tsx index cded8c3729..86454539f7 100644 --- a/frontend/src/components/error-message.tsx +++ b/frontend/src/components/error-message.tsx @@ -1,14 +1,41 @@ +import { useState, useEffect } from "react"; +import { useTranslation } from "react-i18next"; + interface ErrorMessageProps { - error: string; + id?: string; message: string; } -export function ErrorMessage({ error, message }: ErrorMessageProps) { +export function ErrorMessage({ id, message }: ErrorMessageProps) { + const { t, i18n } = useTranslation(); + const [showDetails, setShowDetails] = useState(true); + const [headline, setHeadline] = useState(""); + const [details, setDetails] = useState(message); + + useEffect(() => { + if (id && i18n.exists(id)) { + setHeadline(t(id)); + setDetails(message); + setShowDetails(false); + } + }, [id, message, i18n.language]); + return (
-

{error}

-

{message}

+ {headline &&

{headline}

} + {headline && ( + + )} + {showDetails &&

{details}

}
); diff --git a/frontend/src/i18n/translation.json b/frontend/src/i18n/translation.json index 795c60e051..6db2520b6b 100644 --- a/frontend/src/i18n/translation.json +++ b/frontend/src/i18n/translation.json @@ -1441,6 +1441,12 @@ "fr": "Privé", "tr": "Özel" }, + "ERROR_MESSAGE$SHOW_DETAILS": { + "en": "Show details" + }, + "ERROR_MESSAGE$HIDE_DETAILS": { + "en": "Hide details" + }, "STATUS$STARTING_RUNTIME": { "en": "Starting Runtime...", "zh-CN": "启动运行时...", @@ -1510,5 +1516,17 @@ "ar": "في انتظار جاهزية العميل...", "fr": "En attente que le client soit prêt...", "tr": "İstemcinin hazır olması bekleniyor..." + }, + "STATUS$ERROR_LLM_AUTHENTICATION": { + "en": "Error authenticating with the LLM provider. Please check your API key" + }, + "STATUS$ERROR_RUNTIME_DISCONNECTED": { + "en": "There was an error while connecting to the runtime. Please refresh the page." + }, + "AGENT_ERROR$BAD_ACTION": { + "en": "Agent tried to execute a malformed action." + }, + "AGENT_ERROR$ACTION_TIMEOUT": { + "en": "Action timed out." } } diff --git a/frontend/src/routes/_oh.app.tsx b/frontend/src/routes/_oh.app.tsx index a2a4067b56..50c933b64d 100644 --- a/frontend/src/routes/_oh.app.tsx +++ b/frontend/src/routes/_oh.app.tsx @@ -184,21 +184,6 @@ function App() { if (q) addIntialQueryToChat(q, files); }, [settings]); - const handleError = (message: string) => { - const [error, ...rest] = message.split(":"); - const details = rest.join(":"); - if (!details) { - dispatch( - addErrorMessage({ - error: "An error has occured", - message: error, - }), - ); - } else { - dispatch(addErrorMessage({ error, message: details })); - } - }; - const handleMessage = React.useCallback( (message: MessageEvent) => { // set token received from the server @@ -224,7 +209,12 @@ function App() { return; } if (isErrorObservation(parsed)) { - handleError(parsed.message); + dispatch( + addErrorMessage({ + id: parsed.extras?.error_id, + message: parsed.message, + }), + ); return; } diff --git a/frontend/src/services/actions.ts b/frontend/src/services/actions.ts index 46b6aad851..ccdff694e8 100644 --- a/frontend/src/services/actions.ts +++ b/frontend/src/services/actions.ts @@ -1,4 +1,8 @@ -import { addAssistantMessage, addUserMessage } from "#/state/chatSlice"; +import { + addAssistantMessage, + addUserMessage, + addErrorMessage, +} from "#/state/chatSlice"; import { setCode, setActiveFilepath } from "#/state/codeSlice"; import { appendJupyterInput } from "#/state/jupyterSlice"; import { @@ -119,13 +123,19 @@ export function handleActionMessage(message: ActionMessage) { } export function handleStatusMessage(message: StatusMessage) { - const msg = message.status == null ? "" : message.status.trim(); - store.dispatch( - setCurStatusMessage({ - ...message, - status: msg, - }), - ); + if (message.type === "info") { + store.dispatch( + setCurStatusMessage({ + ...message, + }), + ); + } else if (message.type === "error") { + store.dispatch( + addErrorMessage({ + ...message, + }), + ); + } } export function handleAssistantMessage(data: string | SocketMessage) { @@ -139,9 +149,11 @@ export function handleAssistantMessage(data: string | SocketMessage) { if ("action" in socketMessage) { handleActionMessage(socketMessage); - } else if ("status" in socketMessage) { + } else if ("observation" in socketMessage) { + handleObservationMessage(socketMessage); + } else if ("status_update" in socketMessage) { handleStatusMessage(socketMessage); } else { - handleObservationMessage(socketMessage); + console.error("Unknown message type", socketMessage); } } diff --git a/frontend/src/state/chatSlice.ts b/frontend/src/state/chatSlice.ts index 46f156ebdd..7d77901fee 100644 --- a/frontend/src/state/chatSlice.ts +++ b/frontend/src/state/chatSlice.ts @@ -39,10 +39,10 @@ export const chatSlice = createSlice({ addErrorMessage( state, - action: PayloadAction<{ error: string; message: string }>, + action: PayloadAction<{ id?: string; message: string }>, ) { - const { error, message } = action.payload; - state.messages.push({ error, message }); + const { id, message } = action.payload; + state.messages.push({ id, message, error: true }); }, clearMessages(state) { diff --git a/frontend/src/state/statusSlice.ts b/frontend/src/state/statusSlice.ts index b0b503d6c6..6f5158c9f6 100644 --- a/frontend/src/state/statusSlice.ts +++ b/frontend/src/state/statusSlice.ts @@ -2,8 +2,10 @@ import { createSlice, PayloadAction } from "@reduxjs/toolkit"; import { StatusMessage } from "#/types/Message"; const initialStatusMessage: StatusMessage = { - status: "", - is_error: false, + status_update: true, + type: "info", + id: "", + message: "", }; export const statusSlice = createSlice({ diff --git a/frontend/src/types/Message.tsx b/frontend/src/types/Message.tsx index d4d365d590..85b1d97064 100644 --- a/frontend/src/types/Message.tsx +++ b/frontend/src/types/Message.tsx @@ -33,10 +33,8 @@ export interface ObservationMessage { } export interface StatusMessage { - // TODO not implemented yet - // Whether the status is an error, default is false - is_error: boolean; - - // A status message to display to the user - status: string; + status_update: true; + type: string; + id: string; + message: string; } diff --git a/frontend/src/types/core/observations.ts b/frontend/src/types/core/observations.ts index 9de2a70e8b..21bafddf21 100644 --- a/frontend/src/types/core/observations.ts +++ b/frontend/src/types/core/observations.ts @@ -54,6 +54,9 @@ export interface BrowseObservation extends OpenHandsObservationEvent<"browse"> { export interface ErrorObservation extends OpenHandsObservationEvent<"error"> { source: "user"; + extras: { + error_id?: string; + }; } export type OpenHandsObservation = diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 98265a1915..3eb7a7d066 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -1,7 +1,7 @@ import asyncio import copy import traceback -from typing import Type +from typing import Callable, Type import litellm @@ -35,9 +35,7 @@ from openhands.events.event import Event from openhands.events.observation import ( AgentDelegateObservation, AgentStateChangedObservation, - CmdOutputObservation, ErrorObservation, - FatalErrorObservation, Observation, ) from openhands.events.serialization.event import truncate_content @@ -77,6 +75,7 @@ class AgentController: initial_state: State | None = None, is_delegate: bool = False, headless_mode: bool = True, + status_callback: Callable | None = None, ): """Initializes a new instance of the AgentController class. @@ -119,6 +118,7 @@ class AgentController: # stuck helper self._stuck_detector = StuckDetector(self.state) + self.status_callback = status_callback async def close(self): """Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.""" @@ -132,7 +132,7 @@ class AgentController: message (str): The message to log. """ message = f'[Agent Controller {self.id}] {message}' - getattr(logger, level)(message, extra=extra) + getattr(logger, level)(message, extra=extra, stacklevel=2) def update_state_before_step(self): self.state.iteration += 1 @@ -142,22 +142,16 @@ class AgentController: # update metrics especially for cost. Use deepcopy to avoid it being modified by agent.reset() self.state.local_metrics = copy.deepcopy(self.agent.llm.metrics) - async def report_error(self, message: str, exception: Exception | None = None): - """Reports an error to the user and sends the exception to the LLM next step, in the hope it can self-correct. - - This method should be called for a particular type of errors, which have: - - a user-friendly message, which will be shown in the chat box. This should not be a raw exception message. - - an ErrorObservation that can be sent to the LLM by the user role, with the exception message, so it can self-correct next time. - """ - self.state.last_error = message - if exception: - self.state.last_error += f': {exception}' - detail = str(exception) if exception is not None else '' - if exception is not None and isinstance(exception, litellm.AuthenticationError): - detail = 'Please check your credentials. Is your API key correct?' - self.event_stream.add_event( - ErrorObservation(f'{message}:{detail}'), EventSource.ENVIRONMENT - ) + async def _react_to_exception( + self, + e: Exception, + ): + await self.set_agent_state_to(AgentState.ERROR) + if self.status_callback is not None: + err_id = '' + if isinstance(e, litellm.AuthenticationError): + err_id = 'STATUS$ERROR_LLM_AUTHENTICATION' + self.status_callback('error', err_id, str(e)) async def start_step_loop(self): """The main loop for the agent's step-by-step execution.""" @@ -172,12 +166,7 @@ class AgentController: except Exception as e: traceback.print_exc() self.log('error', f'Error while running the agent: {e}') - self.log('error', traceback.format_exc()) - await self.report_error( - 'There was an unexpected error while running the agent', exception=e - ) - await self.set_agent_state_to(AgentState.ERROR) - break + await self._react_to_exception(e) await asyncio.sleep(0.1) @@ -227,15 +216,6 @@ class AgentController: Args: observation (observation): The observation to handle. """ - if ( - self._pending_action - and hasattr(self._pending_action, 'confirmation_state') - and self._pending_action.confirmation_state - == ActionConfirmationStatus.AWAITING_CONFIRMATION - ): - return - - # Make sure we print the observation in the same way as the LLM sees it observation_to_print = copy.deepcopy(observation) if len(observation_to_print.content) > self.agent.llm.config.max_message_chars: observation_to_print.content = truncate_content( @@ -243,7 +223,6 @@ class AgentController: ) self.log('debug', str(observation_to_print), extra={'msg_type': 'OBSERVATION'}) - # Merge with the metrics from the LLM - it will to synced to the controller's local metrics in update_state_after_step() if observation.llm_metrics is not None: self.agent.llm.metrics.merge(observation.llm_metrics) @@ -255,19 +234,11 @@ class AgentController: await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT) return - if isinstance(observation, CmdOutputObservation): - return - elif isinstance(observation, AgentDelegateObservation): + if isinstance(observation, AgentDelegateObservation): self.state.history.on_event(observation) elif isinstance(observation, ErrorObservation): if self.state.agent_state == AgentState.ERROR: self.state.metrics.merge(self.state.local_metrics) - elif isinstance(observation, FatalErrorObservation): - self.state.last_error = ( - f'There was a fatal error during agent execution: {str(observation)}' - ) - self.state.metrics.merge(self.state.local_metrics) - await self.set_agent_state_to(AgentState.ERROR) async def _handle_message_action(self, action: MessageAction): """Handles message actions from the event stream. @@ -420,13 +391,8 @@ class AgentController: await asyncio.sleep(1) return - # check if agent got stuck before taking any action if self._is_stuck(): - # This need to go BEFORE report_error to sync metrics - self.event_stream.add_event( - FatalErrorObservation('Agent got stuck in a loop'), - EventSource.ENVIRONMENT, - ) + await self._react_to_exception(RuntimeError('Agent got stuck in a loop')) return if self.delegate is not None: @@ -465,15 +431,12 @@ class AgentController: if action is None: raise LLMNoActionError('No action was returned') except (LLMMalformedActionError, LLMNoActionError, LLMResponseError) as e: - # report to the user - # and send the underlying exception to the LLM for self-correction - await self.report_error(str(e)) - return - # FIXME: more graceful handling of litellm.exceptions.ContextWindowExceededError - # e.g. try to condense the memory and try again - except litellm.exceptions.ContextWindowExceededError as e: - self.state.last_error = str(e) - await self.set_agent_state_to(AgentState.ERROR) + self.event_stream.add_event( + ErrorObservation( + content=str(e), + ), + EventSource.AGENT, + ) return if action.runnable: @@ -495,6 +458,7 @@ class AgentController: self.event_stream.add_event(action, EventSource.AGENT) await self.update_state_after_step() + self.log('debug', str(action), extra={'msg_type': 'ACTION'}) async def _delegate_step(self): @@ -524,7 +488,10 @@ class AgentController: self.delegate = None self.delegateAction = None - await self.report_error('Delegator agent encountered an error') + self.event_stream.add_event( + ErrorObservation('Delegate agent encountered an error'), + EventSource.AGENT, + ) elif delegate_state in (AgentState.FINISHED, AgentState.REJECTED): self.log('debug', 'Delegate agent has finished execution') # retrieve delegate result @@ -571,21 +538,18 @@ class AgentController: else: self.state.traffic_control_state = TrafficControlState.THROTTLING if self.headless_mode: - # This need to go BEFORE report_error to sync metrics - await self.set_agent_state_to(AgentState.ERROR) - # set to ERROR state if running in headless mode - # since user cannot resume on the web interface - await self.report_error( - f'Agent reached maximum {limit_type} in headless mode, task stopped. ' + e = RuntimeError( + f'Agent reached maximum {limit_type} in headless mode. ' f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}' ) + await self._react_to_exception(e) else: - await self.set_agent_state_to(AgentState.PAUSED) - await self.report_error( - f'Agent reached maximum {limit_type}, task paused. ' + e = RuntimeError( + f'Agent reached maximum {limit_type}. ' f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}. ' - f'{TRAFFIC_CONTROL_REMINDER}' ) + # FIXME: this isn't really an exception--we should have a different path + await self._react_to_exception(e) stop_step = True return stop_step diff --git a/openhands/controller/state/state.py b/openhands/controller/state/state.py index 52a21f0499..8e6c911c5e 100644 --- a/openhands/controller/state/state.py +++ b/openhands/controller/state/state.py @@ -11,6 +11,7 @@ from openhands.events.action import ( MessageAction, ) from openhands.events.action.agent import AgentFinishAction +from openhands.events.observation import ErrorObservation from openhands.llm.metrics import Metrics from openhands.memory.history import ShortTermHistory from openhands.storage.files import FileStore @@ -80,7 +81,6 @@ class State: history: ShortTermHistory = field(default_factory=ShortTermHistory) inputs: dict = field(default_factory=dict) outputs: dict = field(default_factory=dict) - last_error: str | None = None agent_state: AgentState = AgentState.LOADING resume_state: AgentState | None = None traffic_control_state: TrafficControlState = TrafficControlState.NORMAL @@ -97,6 +97,7 @@ class State: # NOTE: This will never be used by the controller, but it can be used by different # evaluation tasks to store extra data needed to track the progress/state of the task. extra_data: dict[str, Any] = field(default_factory=dict) + last_error: str = '' def save_to_session(self, sid: str, file_store: FileStore): pickled = pickle.dumps(self) @@ -124,9 +125,6 @@ class State: else: state.resume_state = None - # don't carry last_error anymore after restore - state.last_error = None - # first state after restore state.agent_state = AgentState.LOADING return state @@ -151,11 +149,9 @@ class State: if not hasattr(self, 'history'): self.history = ShortTermHistory() - # restore the relevant data in history from the state self.history.start_id = self.start_id self.history.end_id = self.end_id - # remove the restored data from the state if any def get_current_user_intent(self): """Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet.""" diff --git a/openhands/core/cli.py b/openhands/core/cli.py index acf39a71f5..2b369c0e3e 100644 --- a/openhands/core/cli.py +++ b/openhands/core/cli.py @@ -1,5 +1,6 @@ import asyncio import logging +import sys from typing import Type from termcolor import colored @@ -13,6 +14,7 @@ from openhands.core.config import ( load_app_config, ) from openhands.core.logger import openhands_logger as logger +from openhands.core.loop import run_agent_until_done from openhands.core.schema import AgentState from openhands.events import EventSource, EventStream, EventStreamSubscriber from openhands.events.action import ( @@ -114,7 +116,6 @@ async def main(): sid=sid, plugins=agent_cls.sandbox_plugins, ) - await runtime.connect() controller = AgentController( agent=agent, @@ -124,11 +125,14 @@ async def main(): event_stream=event_stream, ) - if controller is not None: - controller.agent_task = asyncio.create_task(controller.start_step_loop()) - async def prompt_for_next_task(): - next_message = input('How can I help? >> ') + # Run input() in a thread pool to avoid blocking the event loop + loop = asyncio.get_event_loop() + next_message = await loop.run_in_executor( + None, lambda: input('How can I help? >> ') + ) + if not next_message.strip(): + await prompt_for_next_task() if next_message == 'exit': event_stream.add_event( ChangeAgentStateAction(AgentState.STOPPED), EventSource.ENVIRONMENT @@ -140,31 +144,45 @@ async def main(): async def on_event(event: Event): display_event(event) if isinstance(event, AgentStateChangedObservation): - if event.agent_state == AgentState.ERROR: - print('An error occurred. Please try again.') if event.agent_state in [ AgentState.AWAITING_USER_INPUT, AgentState.FINISHED, - AgentState.ERROR, ]: await prompt_for_next_task() event_stream.subscribe(EventStreamSubscriber.MAIN, on_event) - await prompt_for_next_task() + await runtime.connect() - while controller.state.agent_state not in [ - AgentState.STOPPED, - ]: - await asyncio.sleep(1) # Give back control for a tick, so the agent can run + asyncio.create_task(prompt_for_next_task()) - print('Exiting...') - await controller.close() + await run_agent_until_done( + controller, runtime, [AgentState.STOPPED, AgentState.ERROR] + ) if __name__ == '__main__': - loop = asyncio.get_event_loop() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) try: loop.run_until_complete(main()) + except KeyboardInterrupt: + print('Received keyboard interrupt, shutting down...') + except ConnectionRefusedError as e: + print(f'Connection refused: {e}') + sys.exit(1) + except Exception as e: + print(f'An error occurred: {e}') + sys.exit(1) finally: - pass + try: + # Cancel all running tasks + pending = asyncio.all_tasks(loop) + for task in pending: + task.cancel() + # Wait for all tasks to complete with a timeout + loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + loop.close() + except Exception as e: + print(f'Error during cleanup: {e}') + sys.exit(1) diff --git a/openhands/core/loop.py b/openhands/core/loop.py new file mode 100644 index 0000000000..2a2808dd09 --- /dev/null +++ b/openhands/core/loop.py @@ -0,0 +1,50 @@ +import asyncio + +from openhands.controller import AgentController +from openhands.core.logger import openhands_logger as logger +from openhands.core.schema import AgentState +from openhands.runtime.base import Runtime + + +async def run_agent_until_done( + controller: AgentController, + runtime: Runtime, + end_states: list[AgentState], +): + """ + run_agent_until_done takes a controller and a runtime, and will run + the agent until it reaches a terminal state. + Note that runtime must be connected before being passed in here. + """ + controller.agent_task = asyncio.create_task(controller.start_step_loop()) + + def status_callback(msg_type, msg_id, msg): + if msg_type == 'error': + logger.error(msg) + if controller: + controller.state.last_error = msg + asyncio.create_task(controller.set_agent_state_to(AgentState.ERROR)) + else: + logger.info(msg) + + if hasattr(runtime, 'status_callback') and runtime.status_callback: + raise ValueError( + 'Runtime status_callback was set, but run_agent_until_done will override it' + ) + if hasattr(controller, 'status_callback') and controller.status_callback: + raise ValueError( + 'Controller status_callback was set, but run_agent_until_done will override it' + ) + + runtime.status_callback = status_callback + controller.status_callback = status_callback + + while controller.state.agent_state not in end_states: + await asyncio.sleep(1) + + if not controller.agent_task.done(): + controller.agent_task.cancel() + try: + await controller.agent_task + except asyncio.CancelledError: + pass diff --git a/openhands/core/main.py b/openhands/core/main.py index 0d653ea8b0..c338f35e6b 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -17,6 +17,7 @@ from openhands.core.config import ( parse_arguments, ) from openhands.core.logger import openhands_logger as logger +from openhands.core.loop import run_agent_until_done from openhands.core.schema import AgentState from openhands.events import EventSource, EventStream, EventStreamSubscriber from openhands.events.action import MessageAction @@ -122,7 +123,6 @@ async def run_controller( if runtime is None: runtime = create_runtime(config, sid=sid) - await runtime.connect() event_stream = runtime.event_stream # restore cli session if enabled @@ -147,9 +147,6 @@ async def run_controller( headless_mode=headless_mode, ) - if controller is not None: - controller.agent_task = asyncio.create_task(controller.start_step_loop()) - assert isinstance( initial_user_action, Action ), f'initial user actions must be an Action, got {type(initial_user_action)}' @@ -188,22 +185,27 @@ async def run_controller( event_stream.add_event(action, EventSource.USER) event_stream.subscribe(EventStreamSubscriber.MAIN, on_event) - while controller.state.agent_state not in [ + + await runtime.connect() + + end_states = [ AgentState.FINISHED, AgentState.REJECTED, AgentState.ERROR, AgentState.PAUSED, AgentState.STOPPED, - ]: - await asyncio.sleep(1) # Give back control for a tick, so the agent can run + ] + + try: + await run_agent_until_done(controller, runtime, end_states) + except Exception as e: + logger.error(f'Exception in main loop: {e}') # save session when we're about to close if config.enable_cli_session: end_state = controller.get_state() end_state.save_to_session(event_stream.sid, event_stream.file_store) - # close when done - await controller.close() state = controller.get_state() # save trajectories if applicable diff --git a/openhands/events/observation/__init__.py b/openhands/events/observation/__init__.py index a0fad86dfb..28525b09aa 100644 --- a/openhands/events/observation/__init__.py +++ b/openhands/events/observation/__init__.py @@ -6,7 +6,7 @@ from openhands.events.observation.commands import ( ) from openhands.events.observation.delegate import AgentDelegateObservation from openhands.events.observation.empty import NullObservation -from openhands.events.observation.error import ErrorObservation, FatalErrorObservation +from openhands.events.observation.error import ErrorObservation from openhands.events.observation.files import ( FileEditObservation, FileReadObservation, @@ -26,7 +26,6 @@ __all__ = [ 'FileWriteObservation', 'FileEditObservation', 'ErrorObservation', - 'FatalErrorObservation', 'AgentStateChangedObservation', 'AgentDelegateObservation', 'SuccessObservation', diff --git a/openhands/events/observation/error.py b/openhands/events/observation/error.py index cfbb291eb0..4ed05b89ac 100644 --- a/openhands/events/observation/error.py +++ b/openhands/events/observation/error.py @@ -13,6 +13,7 @@ class ErrorObservation(Observation): """ observation: str = ObservationType.ERROR + error_id: str = '' @property def message(self) -> str: @@ -20,17 +21,3 @@ class ErrorObservation(Observation): def __str__(self) -> str: return f'**ErrorObservation**\n{self.content}' - - -@dataclass -class FatalErrorObservation(Observation): - """This data class represents a fatal error encountered by the agent. - - This is the type of error that LLM CANNOT recover from, and the agent controller should stop the execution and report the error to the user. - E.g., Remote runtime action execution failure: 503 Server Error: Service Unavailable for url OR 404 Not Found. - """ - - observation: str = ObservationType.ERROR - - def __str__(self) -> str: - return f'**FatalErrorObservation**\n{self.content}' diff --git a/openhands/events/stream.py b/openhands/events/stream.py index c2a335c3b5..680aba511e 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -152,12 +152,16 @@ class EventStream: def add_event(self, event: Event, source: EventSource): try: - asyncio.get_running_loop().create_task(self.async_add_event(event, source)) + asyncio.get_running_loop().create_task(self._async_add_event(event, source)) except RuntimeError: # No event loop running... - asyncio.run(self.async_add_event(event, source)) + asyncio.run(self._async_add_event(event, source)) - async def async_add_event(self, event: Event, source: EventSource): + async def _async_add_event(self, event: Event, source: EventSource): + if hasattr(event, '_id') and event.id is not None: + raise ValueError( + 'Event already has an ID. It was probably added back to the EventStream from inside a handler, trigging a loop.' + ) with self._lock: event._id = self._cur_id # type: ignore [attr-defined] self._cur_id += 1 diff --git a/openhands/memory/history.py b/openhands/memory/history.py index 07e904087c..1e4cfb8b5f 100644 --- a/openhands/memory/history.py +++ b/openhands/memory/history.py @@ -12,7 +12,6 @@ from openhands.events.event import Event, EventSource from openhands.events.observation.agent import AgentStateChangedObservation from openhands.events.observation.delegate import AgentDelegateObservation from openhands.events.observation.empty import NullObservation -from openhands.events.observation.error import FatalErrorObservation from openhands.events.observation.observation import Observation from openhands.events.serialization.event import event_to_dict from openhands.events.stream import EventStream @@ -34,7 +33,6 @@ class ShortTermHistory(list[Event]): NullObservation, ChangeAgentStateAction, AgentStateChangedObservation, - FatalErrorObservation, ) def __init__(self): diff --git a/openhands/runtime/action_execution_server.py b/openhands/runtime/action_execution_server.py index 8b263dfbe3..883ea8e954 100644 --- a/openhands/runtime/action_execution_server.py +++ b/openhands/runtime/action_execution_server.py @@ -37,7 +37,6 @@ from openhands.events.action import ( from openhands.events.observation import ( CmdOutputObservation, ErrorObservation, - FatalErrorObservation, FileReadObservation, FileWriteObservation, IPythonRunCellObservation, @@ -168,7 +167,7 @@ class ActionExecutor: async def run( self, action: CmdRunAction - ) -> CmdOutputObservation | FatalErrorObservation: + ) -> CmdOutputObservation | ErrorObservation: return self.bash_session.run(action) async def run_ipython(self, action: IPythonRunCellAction) -> Observation: diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py index 1b1b01d32f..5970a8d840 100644 --- a/openhands/runtime/base.py +++ b/openhands/runtime/base.py @@ -5,6 +5,8 @@ import os from abc import abstractmethod from typing import Callable +from requests.exceptions import ConnectionError + from openhands.core.config import AppConfig, SandboxConfig from openhands.core.logger import openhands_logger as logger from openhands.events import EventSource, EventStream, EventStreamSubscriber @@ -31,6 +33,22 @@ from openhands.runtime.plugins import JupyterRequirement, PluginRequirement from openhands.runtime.utils.edit import FileEditRuntimeMixin from openhands.utils.async_utils import call_sync_from_async +STATUS_MESSAGES = { + 'STATUS$STARTING_RUNTIME': 'Starting runtime...', + 'STATUS$STARTING_CONTAINER': 'Starting container...', + 'STATUS$PREPARING_CONTAINER': 'Preparing container...', + 'STATUS$CONTAINER_STARTED': 'Container started.', + 'STATUS$WAITING_FOR_CLIENT': 'Waiting for client...', +} + + +class RuntimeNotReadyError(Exception): + pass + + +class RuntimeDisconnectedError(Exception): + pass + def _default_env_vars(sandbox_config: SandboxConfig) -> dict[str, str]: ret = {} @@ -54,6 +72,7 @@ class Runtime(FileEditRuntimeMixin): config: AppConfig initial_env_vars: dict[str, str] attach_to_existing: bool + status_callback: Callable | None def __init__( self, @@ -62,14 +81,14 @@ class Runtime(FileEditRuntimeMixin): sid: str = 'default', plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, - status_message_callback: Callable | None = None, + status_callback: Callable | None = None, attach_to_existing: bool = False, ): self.sid = sid self.event_stream = event_stream self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event) self.plugins = plugins if plugins is not None and len(plugins) > 0 else [] - self.status_message_callback = status_message_callback + self.status_callback = status_callback self.attach_to_existing = attach_to_existing self.config = copy.deepcopy(config) @@ -95,7 +114,17 @@ class Runtime(FileEditRuntimeMixin): def log(self, level: str, message: str) -> None: message = f'[runtime {self.sid}] {message}' - getattr(logger, level)(message) + getattr(logger, level)(message, stacklevel=2) + + def send_status_message(self, message_id: str): + """Sends a status message if the callback function was provided.""" + if self.status_callback: + msg = STATUS_MESSAGES.get(message_id, '') + self.status_callback('info', message_id, msg) + + def send_error_message(self, message_id: str, message: str): + if self.status_callback: + self.status_callback('error', message_id, message) # ==================================================================== @@ -131,15 +160,28 @@ class Runtime(FileEditRuntimeMixin): if event.timeout is None: event.timeout = self.config.sandbox.timeout assert event.timeout is not None - observation: Observation = await call_sync_from_async( - self.run_action, event - ) + try: + observation: Observation = await call_sync_from_async( + self.run_action, event + ) + except Exception as e: + err_id = '' + if isinstance(e, ConnectionError) or isinstance( + e, RuntimeDisconnectedError + ): + err_id = 'STATUS$ERROR_RUNTIME_DISCONNECTED' + self.log('error', f'Unexpected error while running action {e}') + self.log('error', f'Problematic action: {str(event)}') + self.send_error_message(err_id, str(e)) + self.close() + return + observation._cause = event.id # type: ignore[attr-defined] observation.tool_call_metadata = event.tool_call_metadata # this might be unnecessary, since source should be set by the event stream when we're here source = event.source if event.source else EventSource.AGENT - await self.event_stream.async_add_event(observation, source) # type: ignore[arg-type] + self.event_stream.add_event(observation, source) # type: ignore[arg-type] def run_action(self, action: Action) -> Observation: """Run an action and return the resulting observation. diff --git a/openhands/runtime/builder/remote.py b/openhands/runtime/builder/remote.py index f96afb38ee..b1b14752cb 100644 --- a/openhands/runtime/builder/remote.py +++ b/openhands/runtime/builder/remote.py @@ -7,7 +7,7 @@ import requests from openhands.core.logger import openhands_logger as logger from openhands.runtime.builder import RuntimeBuilder -from openhands.runtime.utils.request import is_429_error, send_request_with_retry +from openhands.runtime.utils.request import send_request from openhands.runtime.utils.shutdown_listener import ( should_continue, sleep_if_should_continue, @@ -45,18 +45,21 @@ class RemoteRuntimeBuilder(RuntimeBuilder): files.append(('tags', (None, tag))) # Send the POST request to /build (Begins the build process) - response = send_request_with_retry( - self.session, - 'POST', - f'{self.api_url}/build', - files=files, - timeout=30, - retry_fns=[is_429_error], - ) - - if response.status_code != 202: - logger.error(f'Build initiation failed: {response.text}') - raise RuntimeError(f'Build initiation failed: {response.text}') + try: + response = send_request( + self.session, + 'POST', + f'{self.api_url}/build', + files=files, + timeout=30, + ) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 429: + logger.warning('Build was rate limited. Retrying in 30 seconds.') + time.sleep(30) + return self.build(path, tags, platform) + else: + raise e build_data = response.json() build_id = build_data['build_id'] @@ -70,12 +73,11 @@ class RemoteRuntimeBuilder(RuntimeBuilder): logger.error('Build timed out after 30 minutes') raise RuntimeError('Build timed out after 30 minutes') - status_response = send_request_with_retry( + status_response = send_request( self.session, 'GET', f'{self.api_url}/build_status', params={'build_id': build_id}, - timeout=30, ) if status_response.status_code != 200: @@ -112,12 +114,11 @@ class RemoteRuntimeBuilder(RuntimeBuilder): def image_exists(self, image_name: str, pull_from_repo: bool = True) -> bool: """Checks if an image exists in the remote registry using the /image_exists endpoint.""" params = {'image': image_name} - response = send_request_with_retry( + response = send_request( self.session, 'GET', f'{self.api_url}/image_exists', params=params, - timeout=30, ) if response.status_code != 200: diff --git a/openhands/runtime/impl/e2b/e2b_runtime.py b/openhands/runtime/impl/e2b/e2b_runtime.py index b5233574f0..7c9c297f42 100644 --- a/openhands/runtime/impl/e2b/e2b_runtime.py +++ b/openhands/runtime/impl/e2b/e2b_runtime.py @@ -27,14 +27,14 @@ class E2BRuntime(Runtime): sid: str = 'default', plugins: list[PluginRequirement] | None = None, sandbox: E2BSandbox | None = None, - status_message_callback: Optional[Callable] = None, + status_callback: Optional[Callable] = None, ): super().__init__( config, event_stream, sid, plugins, - status_message_callback=status_message_callback, + status_callback=status_callback, ) if sandbox is None: self.sandbox = E2BSandbox() diff --git a/openhands/runtime/impl/eventstream/eventstream_runtime.py b/openhands/runtime/impl/eventstream/eventstream_runtime.py index b29b8d1b26..e90fb7680b 100644 --- a/openhands/runtime/impl/eventstream/eventstream_runtime.py +++ b/openhands/runtime/impl/eventstream/eventstream_runtime.py @@ -25,7 +25,7 @@ from openhands.events.action import ( ) from openhands.events.action.action import Action from openhands.events.observation import ( - FatalErrorObservation, + ErrorObservation, NullObservation, Observation, UserRejectObservation, @@ -36,8 +36,9 @@ from openhands.runtime.base import Runtime from openhands.runtime.builder import DockerRuntimeBuilder from openhands.runtime.plugins import PluginRequirement from openhands.runtime.utils import find_available_tcp_port -from openhands.runtime.utils.request import send_request_with_retry +from openhands.runtime.utils.request import send_request from openhands.runtime.utils.runtime_build import build_runtime_image +from openhands.utils.async_utils import call_sync_from_async from openhands.utils.tenacity_stop import stop_if_should_exit @@ -123,7 +124,7 @@ class EventStreamRuntime(Runtime): sid: str = 'default', plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, - status_message_callback: Callable | None = None, + status_callback: Callable | None = None, attach_to_existing: bool = False, ): super().__init__( @@ -132,7 +133,7 @@ class EventStreamRuntime(Runtime): sid, plugins, env_vars, - status_message_callback, + status_callback, attach_to_existing, ) @@ -143,7 +144,7 @@ class EventStreamRuntime(Runtime): sid: str = 'default', plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, - status_message_callback: Callable | None = None, + status_callback: Callable | None = None, attach_to_existing: bool = False, ): self.config = config @@ -151,7 +152,7 @@ class EventStreamRuntime(Runtime): self._container_port = 30001 # initial dummy value self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}' self.session = requests.Session() - self.status_message_callback = status_message_callback + self.status_callback = status_callback self.docker_client: docker.DockerClient = self._init_docker_client() self.base_container_image = self.config.sandbox.base_container_image @@ -181,7 +182,7 @@ class EventStreamRuntime(Runtime): sid, plugins, env_vars, - status_message_callback, + status_callback, attach_to_existing, ) @@ -205,21 +206,21 @@ class EventStreamRuntime(Runtime): self.log( 'info', f'Starting runtime with image: {self.runtime_container_image}' ) - self._init_container() + await call_sync_from_async(self._init_container) self.log('info', f'Container started: {self.container_name}') else: - self._attach_to_container() + await call_sync_from_async(self._attach_to_container) if not self.attach_to_existing: self.log('info', f'Waiting for client to become ready at {self.api_url}...') self.send_status_message('STATUS$WAITING_FOR_CLIENT') - self._wait_until_alive() + await call_sync_from_async(self._wait_until_alive) if not self.attach_to_existing: self.log('info', 'Runtime is ready.') if not self.attach_to_existing: - self.setup_initial_env() + await call_sync_from_async(self.setup_initial_env) self.log( 'debug', @@ -238,82 +239,74 @@ class EventStreamRuntime(Runtime): ) raise ex - @tenacity.retry( - stop=tenacity.stop_after_attempt(5) | stop_if_should_exit(), - wait=tenacity.wait_fixed(5), - ) def _init_container(self): - try: - self.log('debug', 'Preparing to start container...') - self.send_status_message('STATUS$PREPARING_CONTAINER') - plugin_arg = '' - if self.plugins is not None and len(self.plugins) > 0: - plugin_arg = ( - f'--plugins {" ".join([plugin.name for plugin in self.plugins])} ' - ) - - self._host_port = self._find_available_port() - self._container_port = ( - self._host_port - ) # in future this might differ from host port - self.api_url = ( - f'{self.config.sandbox.local_runtime_url}:{self._container_port}' + self.log('debug', 'Preparing to start container...') + self.send_status_message('STATUS$PREPARING_CONTAINER') + plugin_arg = '' + if self.plugins is not None and len(self.plugins) > 0: + plugin_arg = ( + f'--plugins {" ".join([plugin.name for plugin in self.plugins])} ' ) - use_host_network = self.config.sandbox.use_host_network - network_mode: str | None = 'host' if use_host_network else None - port_mapping: dict[str, list[dict[str, str]]] | None = ( - None - if use_host_network - else { - f'{self._container_port}/tcp': [{'HostPort': str(self._host_port)}] - } - ) + self._host_port = self._find_available_port() + self._container_port = ( + self._host_port + ) # in future this might differ from host port + self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}' - if use_host_network: - self.log( - 'warn', - 'Using host network mode. If you are using MacOS, please make sure you have the latest version of Docker Desktop and enabled host network feature: https://docs.docker.com/network/drivers/host/#docker-desktop', - ) + use_host_network = self.config.sandbox.use_host_network + network_mode: str | None = 'host' if use_host_network else None + port_mapping: dict[str, list[dict[str, str]]] | None = ( + None + if use_host_network + else {f'{self._container_port}/tcp': [{'HostPort': str(self._host_port)}]} + ) - # Combine environment variables - environment = { - 'port': str(self._container_port), - 'PYTHONUNBUFFERED': 1, - } - if self.config.debug or DEBUG: - environment['DEBUG'] = 'true' - - self.log('debug', f'Workspace Base: {self.config.workspace_base}') - if ( - self.config.workspace_mount_path is not None - and self.config.workspace_mount_path_in_sandbox is not None - ): - # e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}} - volumes = { - self.config.workspace_mount_path: { - 'bind': self.config.workspace_mount_path_in_sandbox, - 'mode': 'rw', - } - } - logger.debug(f'Mount dir: {self.config.workspace_mount_path}') - else: - logger.debug( - 'Mount dir is not set, will not mount the workspace directory to the container' - ) - volumes = None + if use_host_network: self.log( - 'debug', - f'Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}', + 'warn', + 'Using host network mode. If you are using MacOS, please make sure you have the latest version of Docker Desktop and enabled host network feature: https://docs.docker.com/network/drivers/host/#docker-desktop', ) - if self.config.sandbox.browsergym_eval_env is not None: - browsergym_arg = ( - f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}' - ) - else: - browsergym_arg = '' + # Combine environment variables + environment = { + 'port': str(self._container_port), + 'PYTHONUNBUFFERED': 1, + } + if self.config.debug or DEBUG: + environment['DEBUG'] = 'true' + self.log('debug', f'Workspace Base: {self.config.workspace_base}') + if ( + self.config.workspace_mount_path is not None + and self.config.workspace_mount_path_in_sandbox is not None + ): + # e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}} + volumes = { + self.config.workspace_mount_path: { + 'bind': self.config.workspace_mount_path_in_sandbox, + 'mode': 'rw', + } + } + logger.debug(f'Mount dir: {self.config.workspace_mount_path}') + else: + logger.debug( + 'Mount dir is not set, will not mount the workspace directory to the container' + ) + volumes = None + self.log( + 'debug', + f'Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}', + ) + + if self.config.sandbox.browsergym_eval_env is not None: + browsergym_arg = ( + f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}' + ) + else: + browsergym_arg = '' + + try: self.container = self.docker_client.containers.run( self.runtime_container_image, command=( @@ -337,6 +330,21 @@ class EventStreamRuntime(Runtime): self.log_buffer = LogBuffer(self.container, self.log) self.log('debug', f'Container started. Server url: {self.api_url}') self.send_status_message('STATUS$CONTAINER_STARTED') + except docker.errors.APIError as e: + # check 409 error + if '409' in str(e): + self.log( + 'warning', + f'Container {self.container_name} already exists. Removing...', + ) + self._close_containers(rm_all_containers=True) + return self._init_container() + + else: + self.log( + 'error', + f'Error: Instance {self.container_name} FAILED to start container!\n', + ) except Exception as e: self.log( 'error', @@ -384,27 +392,20 @@ class EventStreamRuntime(Runtime): @tenacity.retry( stop=tenacity.stop_after_delay(120) | stop_if_should_exit(), - wait=tenacity.wait_exponential(multiplier=2, min=1, max=20), reraise=(ConnectionRefusedError,), + wait=tenacity.wait_fixed(2), ) def _wait_until_alive(self): self._refresh_logs() if not self.log_buffer: raise RuntimeError('Runtime client is not ready.') - response = send_request_with_retry( + send_request( self.session, 'GET', f'{self.api_url}/alive', - retry_exceptions=[ConnectionRefusedError], - timeout=300, # 5 minutes gives the container time to be alive 🧟‍♂️ + timeout=5, ) - if response.status_code == 200: - return - else: - msg = f'Action execution API is not alive. Response: {response}' - self.log('error', msg) - raise RuntimeError(msg) def close(self, rm_all_containers: bool = True): """Closes the EventStreamRuntime and associated objects @@ -421,7 +422,9 @@ class EventStreamRuntime(Runtime): if self.attach_to_existing: return + self._close_containers(rm_all_containers) + def _close_containers(self, rm_all_containers: bool = True): try: containers = self.docker_client.containers.list(all=True) for container in containers: @@ -466,10 +469,11 @@ class EventStreamRuntime(Runtime): return NullObservation('') action_type = action.action # type: ignore[attr-defined] if action_type not in ACTION_TYPE_TO_CLASS: - return FatalErrorObservation(f'Action {action_type} does not exist.') + raise ValueError(f'Action {action_type} does not exist.') if not hasattr(self, action_type): - return FatalErrorObservation( - f'Action {action_type} is not supported in the current runtime.' + return ErrorObservation( + f'Action {action_type} is not supported in the current runtime.', + error_id='AGENT_ERROR$BAD_ACTION', ) if ( getattr(action, 'confirmation_state', None) @@ -484,33 +488,21 @@ class EventStreamRuntime(Runtime): assert action.timeout is not None try: - response = send_request_with_retry( + response = send_request( self.session, 'POST', f'{self.api_url}/execute_action', json={'action': event_to_dict(action)}, - timeout=action.timeout, + # wait a few more seconds to get the timeout error from client side + timeout=action.timeout + 5, ) - if response.status_code == 200: - output = response.json() - obs = observation_from_dict(output) - obs._cause = action.id # type: ignore[attr-defined] - else: - self.log('debug', f'action: {action}') - self.log('debug', f'response: {response}') - error_message = response.text - self.log('error', f'Error from server: {error_message}') - obs = FatalErrorObservation( - f'Action execution failed: {error_message}' - ) + output = response.json() + obs = observation_from_dict(output) + obs._cause = action.id # type: ignore[attr-defined] except requests.Timeout: - self.log('error', 'No response received within the timeout period.') - obs = FatalErrorObservation( - f'Action execution timed out after {action.timeout} seconds.' + raise RuntimeError( + f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s' ) - except Exception as e: - self.log('error', f'Error during action execution: {e}') - obs = FatalErrorObservation(f'Action execution failed: {str(e)}') self._refresh_logs() return obs @@ -567,7 +559,7 @@ class EventStreamRuntime(Runtime): params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()} - response = send_request_with_retry( + send_request( self.session, 'POST', f'{self.api_url}/upload_file', @@ -575,11 +567,6 @@ class EventStreamRuntime(Runtime): params=params, timeout=300, ) - if response.status_code == 200: - return - else: - error_message = response.text - raise Exception(f'Copy operation failed: {error_message}') except requests.Timeout: raise TimeoutError('Copy operation timed out') @@ -604,31 +591,25 @@ class EventStreamRuntime(Runtime): if path is not None: data['path'] = path - response = send_request_with_retry( + response = send_request( self.session, 'POST', f'{self.api_url}/list_files', json=data, - timeout=30, # 30 seconds because the container should already be alive + timeout=10, ) - if response.status_code == 200: - response_json = response.json() - assert isinstance(response_json, list) - return response_json - else: - error_message = response.text - raise Exception(f'List files operation failed: {error_message}') + response_json = response.json() + assert isinstance(response_json, list) + return response_json except requests.Timeout: raise TimeoutError('List files operation timed out') - except Exception as e: - raise RuntimeError(f'List files operation failed: {str(e)}') def copy_from(self, path: str) -> bytes: """Zip all files in the sandbox and return as a stream of bytes.""" self._refresh_logs() try: params = {'path': path} - response = send_request_with_retry( + response = send_request( self.session, 'GET', f'{self.api_url}/download_files', @@ -636,16 +617,10 @@ class EventStreamRuntime(Runtime): stream=True, timeout=30, ) - if response.status_code == 200: - data = response.content - return data - else: - error_message = response.text - raise Exception(f'Copy operation failed: {error_message}') + data = response.content + return data except requests.Timeout: raise TimeoutError('Copy operation timed out') - except Exception as e: - raise RuntimeError(f'Copy operation failed: {str(e)}') def _is_port_in_use_docker(self, port): containers = self.docker_client.containers.list() @@ -663,8 +638,3 @@ class EventStreamRuntime(Runtime): return port # If no port is found after max_attempts, return the last tried port return port - - def send_status_message(self, message: str): - """Sends a status message if the callback function was provided.""" - if self.status_message_callback: - self.status_message_callback(message) diff --git a/openhands/runtime/impl/modal/modal_runtime.py b/openhands/runtime/impl/modal/modal_runtime.py index 3a484c43e6..0e598a437f 100644 --- a/openhands/runtime/impl/modal/modal_runtime.py +++ b/openhands/runtime/impl/modal/modal_runtime.py @@ -75,7 +75,7 @@ class ModalRuntime(EventStreamRuntime): sid: str = 'default', plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, - status_message_callback: Callable | None = None, + status_callback: Callable | None = None, attach_to_existing: bool = False, ): assert config.modal_api_token_id, 'Modal API token id is required' @@ -102,7 +102,7 @@ class ModalRuntime(EventStreamRuntime): self.container_port = 3000 self.session = requests.Session() - self.status_message_callback = status_message_callback + self.status_callback = status_callback self.base_container_image_id = self.config.sandbox.base_container_image self.runtime_container_image_id = self.config.sandbox.runtime_container_image self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time @@ -122,7 +122,7 @@ class ModalRuntime(EventStreamRuntime): sid, plugins, env_vars, - status_message_callback, + status_callback, attach_to_existing, ) diff --git a/openhands/runtime/impl/remote/remote_runtime.py b/openhands/runtime/impl/remote/remote_runtime.py index 7e533269d1..1e6fdecf51 100644 --- a/openhands/runtime/impl/remote/remote_runtime.py +++ b/openhands/runtime/impl/remote/remote_runtime.py @@ -1,12 +1,11 @@ import os import tempfile import threading -import time from typing import Callable, Optional from zipfile import ZipFile import requests -from requests.exceptions import Timeout +import tenacity from openhands.core.config import AppConfig from openhands.events import EventStream @@ -21,22 +20,26 @@ from openhands.events.action import ( ) from openhands.events.action.action import Action from openhands.events.observation import ( - FatalErrorObservation, + ErrorObservation, NullObservation, Observation, ) from openhands.events.serialization import event_to_dict, observation_from_dict from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS -from openhands.runtime.base import Runtime +from openhands.runtime.base import ( + Runtime, + RuntimeDisconnectedError, + RuntimeNotReadyError, +) from openhands.runtime.builder.remote import RemoteRuntimeBuilder from openhands.runtime.plugins import PluginRequirement from openhands.runtime.utils.command import get_remote_startup_command from openhands.runtime.utils.request import ( - is_404_error, - is_503_error, - send_request_with_retry, + send_request, ) from openhands.runtime.utils.runtime_build import build_runtime_image +from openhands.utils.async_utils import call_sync_from_async +from openhands.utils.tenacity_stop import stop_if_should_exit class RemoteRuntime(Runtime): @@ -51,31 +54,32 @@ class RemoteRuntime(Runtime): sid: str = 'default', plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, - status_message_callback: Optional[Callable] = None, + status_callback: Optional[Callable] = None, attach_to_existing: bool = False, ): + # We need to set session and action_semaphore before the __init__ below, or we get odd errors + self.session = requests.Session() + self.action_semaphore = threading.Semaphore(1) + super().__init__( config, event_stream, sid, plugins, env_vars, - status_message_callback, + status_callback, attach_to_existing, ) - if self.config.sandbox.api_key is None: raise ValueError( 'API key is required to use the remote runtime. ' 'Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY).' ) - self.session = requests.Session() self.session.headers.update({'X-API-Key': self.config.sandbox.api_key}) - self.action_semaphore = threading.Semaphore(1) if self.config.workspace_base is not None: self.log( - 'warning', + 'debug', 'Setting workspace_base is not supported in the remote runtime.', ) @@ -86,9 +90,13 @@ class RemoteRuntime(Runtime): self.runtime_url: str | None = None async def connect(self): - self._start_or_attach_to_runtime() - self._wait_until_alive() - self.setup_initial_env() + await call_sync_from_async(self._start_or_attach_to_runtime) + try: + await call_sync_from_async(self._wait_until_alive) + except RuntimeNotReadyError: + self.log('error', 'Runtime failed to start, timed out before ready') + raise + await call_sync_from_async(self.setup_initial_env) def _start_or_attach_to_runtime(self): existing_runtime = self._check_existing_runtime() @@ -127,44 +135,40 @@ class RemoteRuntime(Runtime): def _check_existing_runtime(self) -> bool: try: - response = send_request_with_retry( - self.session, + response = self._send_request( 'GET', f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.sid}', timeout=5, ) - except Exception as e: + except requests.HTTPError as e: + if e.response.status_code == 404: + return False self.log('debug', f'Error while looking for remote runtime: {e}') - return False + raise - if response.status_code == 200: - data = response.json() - status = data.get('status') - if status == 'running': - self._parse_runtime_response(response) - return True - elif status == 'stopped': - self.log('debug', 'Found existing remote runtime, but it is stopped') - return False - elif status == 'paused': - self.log('debug', 'Found existing remote runtime, but it is paused') - self._parse_runtime_response(response) - self._resume_runtime() - return True - else: - self.log('error', f'Invalid response from runtime API: {data}') - return False + data = response.json() + status = data.get('status') + if status == 'running': + self._parse_runtime_response(response) + return True + elif status == 'stopped': + self.log('debug', 'Found existing remote runtime, but it is stopped') + return False + elif status == 'paused': + self.log('debug', 'Found existing remote runtime, but it is paused') + self._parse_runtime_response(response) + self._resume_runtime() + return True else: - self.log('debug', 'Could not find existing remote runtime') + self.log('error', f'Invalid response from runtime API: {data}') return False def _build_runtime(self): self.log('debug', f'Building RemoteRuntime config:\n{self.config}') - response = send_request_with_retry( - self.session, + response = self._send_request( 'GET', f'{self.config.sandbox.remote_runtime_api_url}/registry_prefix', - timeout=30, + timeout=10, ) response_json = response.json() registry_prefix = response_json['registry_prefix'] @@ -191,14 +195,13 @@ class RemoteRuntime(Runtime): force_rebuild=self.config.sandbox.force_rebuild_runtime, ) - response = send_request_with_retry( - self.session, + response = self._send_request( 'GET', f'{self.config.sandbox.remote_runtime_api_url}/image_exists', params={'image': self.container_image}, - timeout=30, + timeout=10, ) - if response.status_code != 200 or not response.json()['exists']: + if not response.json()['exists']: raise RuntimeError(f'Container image {self.container_image} does not exist') def _start_runtime(self): @@ -228,17 +231,11 @@ class RemoteRuntime(Runtime): } # Start the sandbox using the /start endpoint - response = send_request_with_retry( - self.session, + response = self._send_request( 'POST', f'{self.config.sandbox.remote_runtime_api_url}/start', json=start_request, - timeout=300, ) - if response.status_code != 201: - raise RuntimeError( - f'[Runtime (ID={self.runtime_id})] Failed to start runtime: {response.text}' - ) self._parse_runtime_response(response) self.log( 'debug', @@ -246,17 +243,12 @@ class RemoteRuntime(Runtime): ) def _resume_runtime(self): - response = send_request_with_retry( - self.session, + self._send_request( 'POST', f'{self.config.sandbox.remote_runtime_api_url}/resume', json={'runtime_id': self.runtime_id}, timeout=30, ) - if response.status_code != 200: - raise RuntimeError( - f'[Runtime (ID={self.runtime_id})] Failed to resume runtime: {response.text}' - ) self.log('debug', 'Runtime resumed.') def _parse_runtime_response(self, response: requests.Response): @@ -268,72 +260,57 @@ class RemoteRuntime(Runtime): {'X-Session-API-Key': start_response['session_api_key']} ) + @tenacity.retry( + stop=tenacity.stop_after_delay(180) | stop_if_should_exit(), + reraise=True, + retry=tenacity.retry_if_exception_type(RuntimeNotReadyError), + wait=tenacity.wait_fixed(2), + ) def _wait_until_alive(self): self.log('debug', f'Waiting for runtime to be alive at url: {self.runtime_url}') - # send GET request to /runtime/ - pod_running = False - max_not_found_count = 12 # 2 minutes - not_found_count = 0 - while not pod_running: - runtime_info_response = send_request_with_retry( - self.session, - 'GET', - f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.runtime_id}', - timeout=5, - ) - if runtime_info_response.status_code != 200: - raise RuntimeError( - f'Failed to get runtime status: {runtime_info_response.status_code}. Response: {runtime_info_response.text}' - ) - runtime_data = runtime_info_response.json() - assert runtime_data['runtime_id'] == self.runtime_id - pod_status = runtime_data['pod_status'] - self.log( - 'debug', - f'Waiting for runtime pod to be active. Current status: {pod_status}', - ) - if pod_status == 'Ready': - pod_running = True - break - elif pod_status == 'Not Found' and not_found_count < max_not_found_count: - not_found_count += 1 - self.log( - 'debug', - f'Runtime pod not found. Count: {not_found_count} / {max_not_found_count}', - ) - elif pod_status in ('Failed', 'Unknown', 'Not Found'): - # clean up the runtime - self.close() - raise RuntimeError( - f'Runtime (ID={self.runtime_id}) failed to start. Current status: {pod_status}' - ) - # Pending otherwise - add proper sleep - time.sleep(10) - - response = send_request_with_retry( - self.session, + runtime_info_response = self._send_request( 'GET', - f'{self.runtime_url}/alive', - # Retry 404 & 503 errors for the /alive endpoint - # because the runtime might just be starting up - # and have not registered the endpoint yet - retry_fns=[is_404_error, is_503_error], - # leave enough time for the runtime to start up - timeout=600, + f'{self.config.sandbox.remote_runtime_api_url}/runtime/{self.runtime_id}', ) - if response.status_code != 200: - msg = f'Runtime (ID={self.runtime_id}) is not alive yet. Status: {response.status_code}.' - self.log('warning', msg) - raise RuntimeError(msg) + runtime_data = runtime_info_response.json() + assert 'runtime_id' in runtime_data + assert runtime_data['runtime_id'] == self.runtime_id + assert 'pod_status' in runtime_data + pod_status = runtime_data['pod_status'] + if pod_status == 'Ready': + try: + self._send_request( + 'GET', + f'{self.runtime_url}/alive', + ) # will raise exception if we don't get 200 back. + except requests.HTTPError as e: + self.log( + 'warning', f"Runtime /alive failed, but pod says it's ready: {e}" + ) + raise RuntimeNotReadyError( + f'Runtime /alive failed to respond with 200: {e}' + ) + return + if pod_status in ('Failed', 'Unknown', 'Not Found'): + # clean up the runtime + self.close() + raise RuntimeError( + f'Runtime (ID={self.runtime_id}) failed to start. Current status: {pod_status}' + ) + + self.log( + 'debug', + f'Waiting for runtime pod to be active. Current status: {pod_status}', + ) + raise RuntimeNotReadyError() def close(self, timeout: int = 10): if self.config.sandbox.keep_remote_runtime_alive or self.attach_to_existing: self.session.close() return - if self.runtime_id: + if self.runtime_id and self.session: try: - response = send_request_with_retry( - self.session, + response = self._send_request( 'POST', f'{self.config.sandbox.remote_runtime_api_url}/stop', json={'runtime_id': self.runtime_id}, @@ -361,12 +338,11 @@ class RemoteRuntime(Runtime): return NullObservation('') action_type = action.action # type: ignore[attr-defined] if action_type not in ACTION_TYPE_TO_CLASS: - return FatalErrorObservation( - f'[Runtime (ID={self.runtime_id})] Action {action_type} does not exist.' - ) + raise ValueError(f'Action {action_type} does not exist.') if not hasattr(self, action_type): - return FatalErrorObservation( - f'[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.' + return ErrorObservation( + f'[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.', + error_id='AGENT_ERROR$BAD_ACTION', ) assert action.timeout is not None @@ -374,36 +350,37 @@ class RemoteRuntime(Runtime): try: request_body = {'action': event_to_dict(action)} self.log('debug', f'Request body: {request_body}') - response = send_request_with_retry( - self.session, + response = self._send_request( 'POST', f'{self.runtime_url}/execute_action', json=request_body, - timeout=action.timeout, + # wait a few more seconds to get the timeout error from client side + timeout=action.timeout + 5, ) - if response.status_code == 200: - output = response.json() - obs = observation_from_dict(output) - obs._cause = action.id # type: ignore[attr-defined] - return obs - else: - error_message = response.text - self.log('error', f'Error from server: {error_message}') - obs = FatalErrorObservation( - f'Action execution failed: {error_message}' - ) - except Timeout: - self.log('error', 'No response received within the timeout period.') - obs = FatalErrorObservation( - f'[Runtime (ID={self.runtime_id})] Action execution timed out' - ) - except Exception as e: - self.log('error', f'Error during action execution: {e}') - obs = FatalErrorObservation( - f'[Runtime (ID={self.runtime_id})] Action execution failed: {str(e)}' + output = response.json() + obs = observation_from_dict(output) + obs._cause = action.id # type: ignore[attr-defined] + except requests.Timeout: + raise RuntimeError( + f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s' ) return obs + def _send_request(self, method, url, **kwargs): + is_runtime_request = self.runtime_url and self.runtime_url in url + try: + return send_request(self.session, method, url, **kwargs) + except requests.Timeout: + self.log('error', 'No response received within the timeout period.') + raise + except requests.HTTPError as e: + if is_runtime_request and e.response.status_code == 404: + raise RuntimeDisconnectedError( + f'404 error while connecting to {self.runtime_url}' + ) + else: + raise e + def run(self, action: CmdRunAction) -> Observation: return self.run_action(action) @@ -450,32 +427,16 @@ class RemoteRuntime(Runtime): params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()} - response = send_request_with_retry( - self.session, + response = self._send_request( 'POST', f'{self.runtime_url}/upload_file', files=upload_data, params=params, timeout=300, ) - if response.status_code == 200: - self.log( - 'debug', - f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}', - ) - return - else: - error_message = response.text - raise Exception( - f'[Runtime (ID={self.runtime_id})] Copy operation failed: {error_message}' - ) - except TimeoutError: - raise TimeoutError( - f'[Runtime (ID={self.runtime_id})] Copy operation timed out' - ) - except Exception as e: - raise RuntimeError( - f'[Runtime (ID={self.runtime_id})] Copy operation failed: {str(e)}' + self.log( + 'debug', + f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}', ) finally: if recursive: @@ -485,64 +446,27 @@ class RemoteRuntime(Runtime): ) def list_files(self, path: str | None = None) -> list[str]: - try: - data = {} - if path is not None: - data['path'] = path + data = {} + if path is not None: + data['path'] = path - response = send_request_with_retry( - self.session, - 'POST', - f'{self.runtime_url}/list_files', - json=data, - timeout=30, - ) - if response.status_code == 200: - response_json = response.json() - assert isinstance(response_json, list) - return response_json - else: - error_message = response.text - raise Exception( - f'[Runtime (ID={self.runtime_id})] List files operation failed: {error_message}' - ) - except TimeoutError: - raise TimeoutError( - f'[Runtime (ID={self.runtime_id})] List files operation timed out' - ) - except Exception as e: - raise RuntimeError( - f'[Runtime (ID={self.runtime_id})] List files operation failed: {str(e)}' - ) + response = self._send_request( + 'POST', + f'{self.runtime_url}/list_files', + json=data, + timeout=30, + ) + response_json = response.json() + assert isinstance(response_json, list) + return response_json def copy_from(self, path: str) -> bytes: """Zip all files in the sandbox and return as a stream of bytes.""" - try: - params = {'path': path} - response = send_request_with_retry( - self.session, - 'GET', - f'{self.runtime_url}/download_files', - params=params, - timeout=30, - ) - if response.status_code == 200: - return response.content - else: - error_message = response.text - raise Exception( - f'[Runtime (ID={self.runtime_id})] Copy operation failed: {error_message}' - ) - except requests.Timeout: - raise TimeoutError( - f'[Runtime (ID={self.runtime_id})] Copy operation timed out' - ) - except Exception as e: - raise RuntimeError( - f'[Runtime (ID={self.runtime_id})] Copy operation failed: {str(e)}' - ) - - def send_status_message(self, message: str): - """Sends a status message if the callback function was provided.""" - if self.status_message_callback: - self.status_message_callback(message) + params = {'path': path} + response = self._send_request( + 'GET', + f'{self.runtime_url}/download_files', + params=params, + timeout=30, + ) + return response.content diff --git a/openhands/runtime/utils/bash.py b/openhands/runtime/utils/bash.py index fba16787c6..a5019315a0 100644 --- a/openhands/runtime/utils/bash.py +++ b/openhands/runtime/utils/bash.py @@ -9,7 +9,7 @@ from openhands.events.action import CmdRunAction from openhands.events.event import EventSource from openhands.events.observation import ( CmdOutputObservation, - FatalErrorObservation, + ErrorObservation, ) SOFT_TIMEOUT_SECONDS = 5 @@ -275,7 +275,7 @@ class BashSession: output += '\r\n' + bash_prompt return output, exit_code - def run(self, action: CmdRunAction) -> CmdOutputObservation | FatalErrorObservation: + def run(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservation: try: assert ( action.timeout is not None @@ -329,6 +329,6 @@ class BashSession: interpreter_details=python_interpreter, ) except UnicodeDecodeError as e: - return FatalErrorObservation( - f'Runtime bash execution failed: Command output could not be decoded as utf-8. {str(e)}' + return ErrorObservation( + f'Runtime bash execution failed: Command output could not be decoded as utf-8. {str(e)}', ) diff --git a/openhands/runtime/utils/edit.py b/openhands/runtime/utils/edit.py index 7743425525..6595760e06 100644 --- a/openhands/runtime/utils/edit.py +++ b/openhands/runtime/utils/edit.py @@ -13,7 +13,6 @@ from openhands.events.action import ( ) from openhands.events.observation import ( ErrorObservation, - FatalErrorObservation, FileEditObservation, FileReadObservation, FileWriteObservation, @@ -214,9 +213,7 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface): if isinstance(obs, ErrorObservation): return obs if not isinstance(obs, FileWriteObservation): - return FatalErrorObservation( - f'Fatal Runtime in editing: Expected FileWriteObservation, got {type(obs)}: {str(obs)}' - ) + raise ValueError(f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}') return FileEditObservation( content=get_diff('', action.content, action.path), path=action.path, @@ -225,9 +222,7 @@ class FileEditRuntimeMixin(FileEditRuntimeInterface): new_content=action.content, ) if not isinstance(obs, FileReadObservation): - return FatalErrorObservation( - f'Fatal Runtime in editing: Expected FileReadObservation, got {type(obs)}: {str(obs)}' - ) + raise ValueError(f'Expected FileReadObservation, got {type(obs)}: {str(obs)}') original_file_content = obs.content old_file_lines = original_file_content.split('\n') diff --git a/openhands/runtime/utils/request.py b/openhands/runtime/utils/request.py index 6940827a3a..655fe304e5 100644 --- a/openhands/runtime/utils/request.py +++ b/openhands/runtime/utils/request.py @@ -1,22 +1,12 @@ -from typing import Any, Callable, Type +from typing import Any import requests from requests.exceptions import ( ChunkedEncodingError, ConnectionError, ) -from tenacity import ( - retry, - retry_if_exception, - retry_if_exception_type, - stop_after_delay, - wait_exponential, -) from urllib3.exceptions import IncompleteRead -from openhands.core.logger import openhands_logger as logger -from openhands.utils.tenacity_stop import stop_if_should_exit - def is_server_error(exception): return ( @@ -60,37 +50,13 @@ DEFAULT_RETRY_EXCEPTIONS = [ ] -def send_request_with_retry( +def send_request( session: requests.Session, method: str, url: str, - timeout: int, - retry_exceptions: list[Type[Exception]] | None = None, - retry_fns: list[Callable[[Exception], bool]] | None = None, + timeout: int = 10, **kwargs: Any, ) -> requests.Response: - exceptions_to_catch = retry_exceptions or DEFAULT_RETRY_EXCEPTIONS - retry_condition = retry_if_exception_type( - tuple(exceptions_to_catch) - ) | retry_if_exception(is_502_error) - if retry_fns is not None: - for fn in retry_fns: - retry_condition |= retry_if_exception(fn) - # wait a few more seconds to get the timeout error from client side - kwargs['timeout'] = timeout + 10 - - @retry( - stop=stop_after_delay(timeout) | stop_if_should_exit(), - wait=wait_exponential(multiplier=1, min=4, max=20), - retry=retry_condition, - reraise=True, - before_sleep=lambda retry_state: logger.debug( - f'Retrying {method} request to {url} due to {retry_state.outcome.exception()}. Attempt {retry_state.attempt_number}' - ), - ) - def _send_request_with_retry(): - response = session.request(method, url, **kwargs) - response.raise_for_status() - return response - - return _send_request_with_retry() + response = session.request(method, url, **kwargs) + response.raise_for_status() + return response diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 89700ead5e..95757a1351 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -32,7 +32,12 @@ class AgentSession: _closed: bool = False loop: asyncio.AbstractEventLoop | None = None - def __init__(self, sid: str, file_store: FileStore): + def __init__( + self, + sid: str, + file_store: FileStore, + status_callback: Optional[Callable] = None, + ): """Initializes a new instance of the Session class Parameters: @@ -43,6 +48,7 @@ class AgentSession: self.sid = sid self.event_stream = EventStream(sid, file_store) self.file_store = file_store + self._status_callback = status_callback async def start( self, @@ -53,7 +59,6 @@ class AgentSession: max_budget_per_task: float | None = None, agent_to_llm_config: dict[str, LLMConfig] | None = None, agent_configs: dict[str, AgentConfig] | None = None, - status_message_callback: Optional[Callable] = None, ): """Starts the Agent session Parameters: @@ -80,7 +85,6 @@ class AgentSession: max_budget_per_task, agent_to_llm_config, agent_configs, - status_message_callback, ) def _start_thread(self, *args): @@ -99,14 +103,12 @@ class AgentSession: max_budget_per_task: float | None = None, agent_to_llm_config: dict[str, LLMConfig] | None = None, agent_configs: dict[str, AgentConfig] | None = None, - status_message_callback: Optional[Callable] = None, ): self._create_security_analyzer(config.security.security_analyzer) await self._create_runtime( runtime_name=runtime_name, config=config, agent=agent, - status_message_callback=status_message_callback, ) self._create_controller( agent, @@ -132,6 +134,10 @@ class AgentSession: asyncio.get_event_loop().run_in_executor(None, inner_close) + async def stop_agent_loop_for_error(self): + if self.controller is not None: + await self.controller.set_agent_state_to(AgentState.ERROR) + async def _close(self): if self._closed: return @@ -162,7 +168,6 @@ class AgentSession: runtime_name: str, config: AppConfig, agent: Agent, - status_message_callback: Optional[Callable] = None, ): """Creates a runtime instance @@ -182,13 +187,17 @@ class AgentSession: event_stream=self.event_stream, sid=self.sid, plugins=agent.sandbox_plugins, - status_message_callback=status_message_callback, + status_callback=self._status_callback, ) try: await self.runtime.connect() except Exception as e: logger.error(f'Runtime initialization failed: {e}', exc_info=True) + if self._status_callback: + self._status_callback( + 'error', 'STATUS$ERROR_RUNTIME_DISCONNECTED', str(e) + ) raise if self.runtime is not None: @@ -252,9 +261,8 @@ class AgentSession: agent_to_llm_config=agent_to_llm_config, agent_configs=agent_configs, confirmation_mode=confirmation_mode, - # AgentSession is designed to communicate with the frontend, so we don't want to - # run the agent in headless mode. headless_mode=False, + status_callback=self._status_callback, ) try: agent_state = State.restore_from_session(self.sid, self.file_store) diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index ef58ae052a..59dd8a0295 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -40,7 +40,9 @@ class Session: self.sid = sid self.websocket = ws self.last_active_ts = int(time.time()) - self.agent_session = AgentSession(sid, file_store) + self.agent_session = AgentSession( + sid, file_store, status_callback=self.queue_status_message + ) self.agent_session.event_stream.subscribe( EventStreamSubscriber.SERVER, self.on_event ) @@ -115,7 +117,6 @@ class Session: max_budget_per_task=self.config.max_budget_per_task, agent_to_llm_config=self.config.get_agent_to_llm_config_map(), agent_configs=self.config.get_agent_configs(), - status_message_callback=self.queue_status_message, ) except Exception as e: logger.exception(f'Error creating controller: {e}') @@ -171,14 +172,6 @@ class Session: 'Model does not support image upload, change to a different model or try without an image.' ) return - if self.loop: - asyncio.run_coroutine_threadsafe( - self._add_event(event, EventSource.USER), self.loop - ) # type: ignore - else: - raise RuntimeError('No event loop found') - - async def _add_event(self, event, event_source): self.agent_session.event_stream.add_event(event, EventSource.USER) async def send(self, data: dict[str, object]) -> bool: @@ -200,11 +193,17 @@ class Session: """Sends an error message to the client.""" return await self.send({'error': True, 'message': message}) - async def send_status_message(self, message: str) -> bool: + async def _send_status_message(self, msg_type: str, id: str, message: str) -> bool: """Sends a status message to the client.""" - return await self.send({'status': message}) + if msg_type == 'error': + await self.agent_session.stop_agent_loop_for_error() - def queue_status_message(self, message: str): + return await self.send( + {'status_update': True, 'type': msg_type, 'id': id, 'message': message} + ) + + def queue_status_message(self, msg_type: str, id: str, message: str): """Queues a status message to be sent asynchronously.""" - # Ensure the coroutine runs in the main event loop - asyncio.run_coroutine_threadsafe(self.send_status_message(message), self.loop) + asyncio.run_coroutine_threadsafe( + self._send_status_message(msg_type, id, message), self.loop + ) diff --git a/tests/runtime/test_stress_remote_runtime.py b/tests/runtime/test_stress_remote_runtime.py new file mode 100644 index 0000000000..4d96ee132a --- /dev/null +++ b/tests/runtime/test_stress_remote_runtime.py @@ -0,0 +1,231 @@ +"""Bash-related tests for the EventStreamRuntime, which connects to the ActionExecutor running in the sandbox.""" + +import asyncio +import os +import tempfile +from unittest.mock import MagicMock + +import pandas as pd +import pytest +from conftest import TEST_IN_CI + +from evaluation.utils.shared import ( + EvalException, + EvalMetadata, + EvalOutput, + assert_and_raise, + codeact_user_response, + make_metadata, + prepare_dataset, + reset_logger_for_multiprocessing, + run_evaluation, +) +from openhands.agenthub import Agent +from openhands.controller.state.state import State +from openhands.core.config import ( + AgentConfig, + AppConfig, + LLMConfig, + SandboxConfig, +) +from openhands.core.logger import openhands_logger as logger +from openhands.core.main import create_runtime, run_controller +from openhands.events.action import CmdRunAction, MessageAction +from openhands.events.observation import CmdOutputObservation +from openhands.events.serialization.event import event_to_dict +from openhands.llm import LLM +from openhands.runtime.base import Runtime +from openhands.utils.async_utils import call_async_from_sync + +AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = { + 'CodeActAgent': codeact_user_response, +} + + +def get_config( + metadata: EvalMetadata, +) -> AppConfig: + assert ( + os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL') is not None + ), 'SANDBOX_REMOTE_RUNTIME_API_URL must be set.' + assert ( + os.environ.get('ALLHANDS_API_KEY') is not None + ), 'ALLHANDS_API_KEY must be set.' + config = AppConfig( + default_agent=metadata.agent_class, + run_as_openhands=False, + max_iterations=metadata.max_iterations, + runtime='remote', + sandbox=SandboxConfig( + base_container_image='python:3.11-bookworm', + enable_auto_lint=True, + use_host_network=False, + # large enough timeout, since some testcases take very long to run + timeout=300, + api_key=os.environ.get('ALLHANDS_API_KEY', None), + remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'), + keep_remote_runtime_alive=False, + ), + # do not mount workspace + workspace_base=None, + workspace_mount_path=None, + ) + agent_config = AgentConfig( + codeact_enable_jupyter=False, + codeact_enable_browsing=False, + codeact_enable_llm_editor=False, + ) + config.set_agent_config(agent_config) + return config + + +def initialize_runtime( + runtime: Runtime, +): + """Initialize the runtime for the agent. + + This function is called before the runtime is used to run the agent. + """ + logger.info('-' * 30) + logger.info('BEGIN Runtime Initialization Fn') + logger.info('-' * 30) + obs: CmdOutputObservation + + action = CmdRunAction(command="""export USER=$(whoami); echo USER=${USER} """) + action.timeout = 600 + logger.info(action, extra={'msg_type': 'ACTION'}) + obs = runtime.run_action(action) + logger.info(obs, extra={'msg_type': 'OBSERVATION'}) + assert_and_raise(obs.exit_code == 0, f'Failed to export USER: {str(obs)}') + + action = CmdRunAction(command='mkdir -p /dummy_dir') + action.timeout = 600 + logger.info(action, extra={'msg_type': 'ACTION'}) + obs = runtime.run_action(action) + logger.info(obs, extra={'msg_type': 'OBSERVATION'}) + assert_and_raise( + obs.exit_code == 0, + f'Failed to create /dummy_dir: {str(obs)}', + ) + + with tempfile.TemporaryDirectory() as temp_dir: + # Construct the full path for the desired file name within the temporary directory + temp_file_path = os.path.join(temp_dir, 'dummy_file') + # Write to the file with the desired name within the temporary directory + with open(temp_file_path, 'w') as f: + f.write('dummy content') + + # Copy the file to the desired location + runtime.copy_to(temp_file_path, '/dummy_dir/') + + logger.info('-' * 30) + logger.info('END Runtime Initialization Fn') + logger.info('-' * 30) + + +def process_instance( + instance: pd.Series, + metadata: EvalMetadata, + reset_logger: bool = True, +) -> EvalOutput: + config = get_config(metadata) + + # Setup the logger properly, so you can run multi-processing to parallelize the evaluation + if reset_logger: + log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs') + reset_logger_for_multiprocessing(logger, instance.instance_id, log_dir) + else: + logger.info(f'Starting evaluation for instance {instance.instance_id}.') + + runtime = create_runtime(config) + call_async_from_sync(runtime.connect) + + try: + initialize_runtime(runtime) + + instruction = 'dummy instruction' + agent = Agent.get_cls(metadata.agent_class)( + llm=LLM(config=metadata.llm_config), + config=config.get_agent_config(metadata.agent_class), + ) + + def next_command(*args, **kwargs): + return CmdRunAction(command='ls -lah') + + agent.step = MagicMock(side_effect=next_command) + + # Here's how you can run the agent (similar to the `main` function) and get the final task state + state: State | None = asyncio.run( + run_controller( + config=config, + initial_user_action=MessageAction(content=instruction), + runtime=runtime, + fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[ + metadata.agent_class + ], + agent=agent, + ) + ) + + # if fatal error, throw EvalError to trigger re-run + if ( + state.last_error + and 'fatal error during agent execution' in state.last_error + and 'stuck in a loop' not in state.last_error + ): + raise EvalException('Fatal error detected: ' + state.last_error) + + finally: + runtime.close() + + test_result = {} + if state is None: + raise ValueError('State should not be None.') + histories = [event_to_dict(event) for event in state.history.get_events()] + metrics = state.metrics.get() if state.metrics else None + + # Save the output + output = EvalOutput( + instance_id=instance.instance_id, + instruction=instruction, + instance=instance.to_dict(), # SWE Bench specific + test_result=test_result, + metadata=metadata, + history=histories, + metrics=metrics, + error=state.last_error if state and state.last_error else None, + ) + return output + + +@pytest.mark.skipif( + TEST_IN_CI, + reason='This test should only be run locally, not in CI.', +) +def test_stress_remote_runtime(n_eval_workers: int = 64): + """Mimic evaluation setting to test remote runtime in a multi-processing setting.""" + + llm_config = LLMConfig() + metadata = make_metadata( + llm_config, + 'dummy_dataset_descrption', + 'CodeActAgent', + max_iterations=10, + eval_note='dummy_eval_note', + eval_output_dir='./dummy_eval_output_dir', + details={}, + ) + + # generate 300 random dummy instances + dummy_instance = pd.DataFrame( + { + 'instance_id': [f'dummy_instance_{i}' for i in range(300)], + } + ) + + output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl') + instances = prepare_dataset( + dummy_instance, output_file, eval_n_limit=len(dummy_instance) + ) + + run_evaluation(instances, metadata, output_file, n_eval_workers, process_instance) diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index 4d63c27405..0e38a92d1d 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -7,14 +7,12 @@ from openhands.controller.agent import Agent from openhands.controller.agent_controller import AgentController from openhands.controller.state.state import TrafficControlState from openhands.core.config import AppConfig -from openhands.core.exceptions import LLMMalformedActionError from openhands.core.main import run_controller from openhands.core.schema import AgentState from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction from openhands.events.observation import ( ErrorObservation, - FatalErrorObservation, ) from openhands.events.serialization import event_to_dict from openhands.llm import LLM @@ -45,6 +43,11 @@ def mock_event_stream(): return MagicMock(spec=EventStream) +@pytest.fixture +def mock_status_callback(): + return AsyncMock() + + @pytest.mark.asyncio async def test_set_agent_state(mock_agent, mock_event_stream): controller = AgentController( @@ -98,39 +101,19 @@ async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream) @pytest.mark.asyncio -async def test_report_error(mock_agent, mock_event_stream): +async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_callback): controller = AgentController( agent=mock_agent, event_stream=mock_event_stream, + status_callback=mock_status_callback, max_iterations=10, sid='test', confirmation_mode=False, headless_mode=True, ) error_message = 'Test error' - await controller.report_error(error_message) - assert controller.state.last_error == error_message - controller.event_stream.add_event.assert_called_once() - await controller.close() - - -@pytest.mark.asyncio -async def test_step_with_exception(mock_agent, mock_event_stream): - controller = AgentController( - agent=mock_agent, - event_stream=mock_event_stream, - max_iterations=10, - sid='test', - confirmation_mode=False, - headless_mode=True, - ) - controller.state.agent_state = AgentState.RUNNING - controller.report_error = AsyncMock() - controller.agent.step.side_effect = LLMMalformedActionError('Malformed action') - await controller._step() - - # Verify that report_error was called with the correct error message - controller.report_error.assert_called_once_with('Malformed action') + await controller._react_to_exception(RuntimeError(error_message)) + controller.status_callback.assert_called_once() await controller.close() @@ -141,21 +124,24 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream): event_stream = EventStream(sid='test', file_store=file_store) agent = MagicMock(spec=Agent) - # a random message to send to the runtime - event = CmdRunAction(command='ls') - agent.step.return_value = event + agent = MagicMock(spec=Agent) + + def agent_step_fn(state): + print(f'agent_step_fn received state: {state}') + return CmdRunAction(command='ls') + + agent.step = agent_step_fn agent.llm = MagicMock(spec=LLM) agent.llm.metrics = Metrics() agent.llm.config = config.get_llm_config() - fatal_error_obs = FatalErrorObservation('Fatal error detected') - fatal_error_obs._cause = event.id - runtime = MagicMock(spec=Runtime) async def on_event(event: Event): if isinstance(event, CmdRunAction): - await event_stream.async_add_event(fatal_error_obs, EventSource.USER) + error_obs = ErrorObservation('You messed around with Jim') + error_obs._cause = event.id + event_stream.add_event(error_obs, EventSource.USER) event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event) runtime.event_stream = event_stream @@ -170,30 +156,23 @@ async def test_run_controller_with_fatal_error(mock_agent, mock_event_stream): ) print(f'state: {state}') print(f'event_stream: {list(event_stream.get_events())}') - assert state.iteration == 1 - # it will first become AgentState.ERROR, then become AgentState.STOPPED - # in side run_controller (since the while loop + sleep no longer loop) - assert state.agent_state == AgentState.STOPPED - assert ( - state.last_error - == 'There was a fatal error during agent execution: **FatalErrorObservation**\nFatal error detected' - ) - assert len(list(event_stream.get_events())) == 5 + assert state.iteration == 4 + assert state.agent_state == AgentState.ERROR + assert state.last_error == 'Agent got stuck in a loop' + assert len(list(event_stream.get_events())) == 11 @pytest.mark.asyncio -async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream): +async def test_run_controller_stop_with_stuck(): config = AppConfig() file_store = get_file_store(config.file_store, config.file_store_path) event_stream = EventStream(sid='test', file_store=file_store) agent = MagicMock(spec=Agent) - # a random message to send to the runtime - event = CmdRunAction(command='ls') def agent_step_fn(state): print(f'agent_step_fn received state: {state}') - return event + return CmdRunAction(command='ls') agent.step = agent_step_fn agent.llm = MagicMock(spec=LLM) @@ -207,9 +186,7 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream): 'Non fatal error here to trigger loop' ) non_fatal_error_obs._cause = event.id - await event_stream.async_add_event( - non_fatal_error_obs, EventSource.ENVIRONMENT - ) + event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT) event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event) runtime.event_stream = event_stream @@ -228,7 +205,7 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream): print(f'event {i}: {event_to_dict(event)}') assert state.iteration == 4 - assert len(events) == 12 + assert len(events) == 11 # check the eventstream have 4 pairs of repeated actions and observations repeating_actions_and_observations = events[2:10] for action, observation in zip( @@ -246,13 +223,8 @@ async def test_run_controller_stop_with_stuck(mock_agent, mock_event_stream): assert last_event['extras']['agent_state'] == 'error' assert last_event['observation'] == 'agent_state_changed' - # it will first become AgentState.ERROR, then become AgentState.STOPPED - # in side run_controller (since the while loop + sleep no longer loop) - assert state.agent_state == AgentState.STOPPED - assert ( - state.last_error - == 'There was a fatal error during agent execution: **FatalErrorObservation**\nAgent got stuck in a loop' - ) + assert state.agent_state == AgentState.ERROR + assert state.last_error == 'Agent got stuck in a loop' @pytest.mark.asyncio @@ -319,7 +291,7 @@ async def test_step_max_iterations(mock_agent, mock_event_stream): assert controller.state.traffic_control_state == TrafficControlState.NORMAL await controller._step() assert controller.state.traffic_control_state == TrafficControlState.THROTTLING - assert controller.state.agent_state == AgentState.PAUSED + assert controller.state.agent_state == AgentState.ERROR await controller.close() @@ -359,7 +331,7 @@ async def test_step_max_budget(mock_agent, mock_event_stream): assert controller.state.traffic_control_state == TrafficControlState.NORMAL await controller._step() assert controller.state.traffic_control_state == TrafficControlState.THROTTLING - assert controller.state.agent_state == AgentState.PAUSED + assert controller.state.agent_state == AgentState.ERROR await controller.close() diff --git a/tests/unit/test_is_stuck.py b/tests/unit/test_is_stuck.py index 2fe8a683c8..af3ef6b83c 100644 --- a/tests/unit/test_is_stuck.py +++ b/tests/unit/test_is_stuck.py @@ -440,9 +440,10 @@ class TestStuckDetector: read_observation_2._cause = read_action_2._id event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT) - # one more message to break the pattern - message_null_observation = NullObservation(content='') + message_action = MessageAction(content='Come on', wait_for_response=False) event_stream.add_event(message_action, EventSource.USER) + + message_null_observation = NullObservation(content='') event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT) cmd_action_3 = CmdRunAction(command='ls')