(enh) send status messages to UI during startup (#3771)

Co-authored-by: Robert Brennan <accounts@rbren.io>
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Co-authored-by: Robert Brennan <contact@rbren.io>
Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
This commit is contained in:
tobitege
2024-09-24 20:46:58 +02:00
committed by GitHub
parent 7b2b1eff57
commit c32cec7f89
19 changed files with 992 additions and 128 deletions

1
.gitignore vendored
View File

@@ -228,3 +228,4 @@ runtime_*.tar
# docker build
containers/runtime/Dockerfile
containers/runtime/project.tar.gz
containers/runtime/code

View File

@@ -1,11 +1,12 @@
# Dynamic constructed Dockerfile
# Dynamically constructed Dockerfile
This folder builds runtime image (sandbox), which will use a `Dockerfile` that is dynamically generated depends on the `base_image` AND a [Python source distribution](https://docs.python.org/3.10/distutils/sourcedist.html) that's based on the current commit of `openhands`.
This folder builds a runtime image (sandbox), which will use a dynamically generated `Dockerfile`
that depends on the `base_image` **AND** a [Python source distribution](https://docs.python.org/3.10/distutils/sourcedist.html) that is based on the current commit of `openhands`.
The following command will generate Dockerfile for `ubuntu:22.04` and the source distribution `.tar` into `containers/runtime`.
The following command will generate a `Dockerfile` file for `nikolaik/python-nodejs:python3.11-nodejs22` (the default base image), an updated `config.sh` and the runtime source distribution files/folders into `containers/runtime`:
```bash
poetry run python3 openhands/runtime/utils/runtime_build.py \
--base_image ubuntu:22.04 \
--base_image nikolaik/python-nodejs:python3.11-nodejs22 \
--build_folder containers/runtime
```

View File

@@ -18,6 +18,7 @@ enum IndicatorColor {
function AgentStatusBar() {
const { t } = useTranslation();
const { curAgentState } = useSelector((state: RootState) => state.agent);
const { curStatusMessage } = useSelector((state: RootState) => state.status);
const AgentStatusMap: {
[k: string]: { message: string; indicator: IndicatorColor };
@@ -90,14 +91,25 @@ function AgentStatusBar() {
}
}, [curAgentState]);
const [statusMessage, setStatusMessage] = React.useState<string>("");
React.useEffect(() => {
const trimmedCustomMessage = curStatusMessage.message.trim();
if (trimmedCustomMessage) {
setStatusMessage(t(trimmedCustomMessage));
} else {
setStatusMessage(AgentStatusMap[curAgentState].message);
}
}, [curAgentState, curStatusMessage.message]);
return (
<div className="flex items-center">
<div
className={`w-3 h-3 mr-2 rounded-full animate-pulse ${AgentStatusMap[curAgentState].indicator}`}
/>
<span className="text-sm text-stone-400">
{AgentStatusMap[curAgentState].message}
</span>
<div className="flex flex-col items-center">
<div className="flex items-center">
<div
className={`w-3 h-3 mr-2 rounded-full animate-pulse ${AgentStatusMap[curAgentState].indicator}`}
/>
<span className="text-sm text-stone-400">{statusMessage}</span>
</div>
</div>
);
}

File diff suppressed because it is too large Load Diff

View File

@@ -6,10 +6,11 @@ import {
ActionSecurityRisk,
appendSecurityAnalyzerInput,
} from "#/state/securityAnalyzerSlice";
import { setCurStatusMessage } from "#/state/statusSlice";
import { setRootTask } from "#/state/taskSlice";
import store from "#/store";
import ActionType from "#/types/ActionType";
import { ActionMessage } from "#/types/Message";
import { ActionMessage, StatusMessage } from "#/types/Message";
import { SocketMessage } from "#/types/ResponseType";
import { handleObservationMessage } from "./observations";
import { getRootTask } from "./taskService";
@@ -138,6 +139,16 @@ export function handleActionMessage(message: ActionMessage) {
}
}
export function handleStatusMessage(message: StatusMessage) {
const msg = message.message == null ? "" : message.message.trim();
store.dispatch(
setCurStatusMessage({
...message,
message: msg,
}),
);
}
export function handleAssistantMessage(data: string | SocketMessage) {
let socketMessage: SocketMessage;
@@ -149,7 +160,9 @@ export function handleAssistantMessage(data: string | SocketMessage) {
if ("action" in socketMessage) {
handleActionMessage(socketMessage);
} else {
} else if ("observation" in socketMessage) {
handleObservationMessage(socketMessage);
} else if ("message" in socketMessage) {
handleStatusMessage(socketMessage);
}
}

View File

@@ -8,11 +8,19 @@ import { I18nKey } from "#/i18n/declaration";
const translate = (key: I18nKey) => i18next.t(key);
// Define a type for the messages
type Message = {
action: ActionType;
args: Record<string, unknown>;
};
class Session {
private static _socket: WebSocket | null = null;
private static _latest_event_id: number = -1;
private static _messageQueue: Message[] = [];
public static _history: Record<string, unknown>[] = [];
// callbacks contain a list of callable functions
@@ -83,6 +91,7 @@ class Session {
toast.success("ws", translate(I18nKey.SESSION$SERVER_CONNECTED_MESSAGE));
Session._connecting = false;
Session._initializeAgent();
Session._flushQueue();
Session.callbacks.open?.forEach((callback) => {
callback(e);
});
@@ -94,7 +103,6 @@ class Session {
data = JSON.parse(e.data);
Session._history.push(data);
} catch (err) {
// TODO: report the error
toast.error(
"ws",
translate(I18nKey.SESSION$SESSION_HANDLING_ERROR_MESSAGE),
@@ -115,6 +123,7 @@ class Session {
};
Session._socket.onerror = () => {
// TODO report error
toast.error(
"ws",
translate(I18nKey.SESSION$SESSION_CONNECTION_ERROR_MESSAGE),
@@ -145,9 +154,20 @@ class Session {
Session._socket = null;
}
private static _flushQueue(): void {
while (Session._messageQueue.length > 0) {
const message = Session._messageQueue.shift();
if (message) {
setTimeout(() => Session.send(JSON.stringify(message)), 1000);
}
}
}
static send(message: string): void {
const messageObject: Message = JSON.parse(message);
if (Session._connecting) {
setTimeout(() => Session.send(message), 1000);
Session._messageQueue.push(messageObject);
return;
}
if (!Session.isConnected()) {

View File

@@ -0,0 +1,23 @@
import { createSlice, PayloadAction } from "@reduxjs/toolkit";
import { StatusMessage } from "#/types/Message";
const initialStatusMessage: StatusMessage = {
message: "",
is_error: false,
};
export const statusSlice = createSlice({
name: "status",
initialState: {
curStatusMessage: initialStatusMessage,
},
reducers: {
setCurStatusMessage: (state, action: PayloadAction<StatusMessage>) => {
state.curStatusMessage = action.payload;
},
},
});
export const { setCurStatusMessage } = statusSlice.actions;
export default statusSlice.reducer;

View File

@@ -8,6 +8,7 @@ import errorsReducer from "./state/errorsSlice";
import taskReducer from "./state/taskSlice";
import jupyterReducer from "./state/jupyterSlice";
import securityAnalyzerReducer from "./state/securityAnalyzerSlice";
import statusReducer from "./state/statusSlice";
export const rootReducer = combineReducers({
browser: browserReducer,
@@ -19,6 +20,7 @@ export const rootReducer = combineReducers({
agent: agentReducer,
jupyter: jupyterReducer,
securityAnalyzer: securityAnalyzerReducer,
status: statusReducer,
});
const store = configureStore({

View File

@@ -31,3 +31,12 @@ export interface ObservationMessage {
// The timestamp of the message
timestamp: string;
}
export interface StatusMessage {
// TODO not implemented yet
// Whether the status is an error, default is false
is_error: boolean;
// A status message to display to the user
message: string;
}

View File

@@ -1,5 +1,5 @@
import { ActionMessage, ObservationMessage } from "./Message";
import { ActionMessage, ObservationMessage, StatusMessage } from "./Message";
type SocketMessage = ActionMessage | ObservationMessage;
type SocketMessage = ActionMessage | ObservationMessage | StatusMessage;
export { type SocketMessage };

View File

@@ -55,7 +55,6 @@ def create_runtime(
config: The app config.
sid: The session id.
runtime_tools_config: (will be deprecated) The runtime tools config.
"""
# if sid is provided on the command line, use it as the name of the event stream
# otherwise generate it on the basis of the configured jwt_secret

View File

@@ -16,8 +16,10 @@ from pathlib import Path
import pexpect
from fastapi import FastAPI, HTTPException, Request, UploadFile
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from starlette.exceptions import HTTPException as StarletteHTTPException
from uvicorn import run
from openhands.core.logger import openhands_logger as logger
@@ -562,6 +564,35 @@ if __name__ == '__main__':
app = FastAPI(lifespan=lifespan)
# TODO below 3 exception handlers were recommended by Sonnet.
# Are these something we should keep?
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
logger.exception('Unhandled exception occurred:')
return JSONResponse(
status_code=500,
content={
'message': 'An unexpected error occurred. Please try again later.'
},
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
logger.error(f'HTTP exception occurred: {exc.detail}')
return JSONResponse(
status_code=exc.status_code, content={'message': exc.detail}
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
request: Request, exc: RequestValidationError
):
logger.error(f'Validation error occurred: {exc}')
return JSONResponse(
status_code=422,
content={'message': 'Invalid request parameters', 'details': exc.errors()},
)
@app.middleware('http')
async def one_request_at_a_time(request: Request, call_next):
assert client is not None

View File

@@ -2,6 +2,7 @@ import os
import tempfile
import threading
import uuid
from typing import Callable
from zipfile import ZipFile
import docker
@@ -119,6 +120,7 @@ class EventStreamRuntime(Runtime):
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_message_callback: Callable | None = None,
):
self.config = config
self._host_port = 30000 # initial dummy value
@@ -130,12 +132,13 @@ class EventStreamRuntime(Runtime):
self.instance_id = (
sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
)
self.status_message_callback = status_message_callback
self.send_status_message('STATUS$STARTING_RUNTIME')
self.docker_client: docker.DockerClient = self._init_docker_client()
self.base_container_image = self.config.sandbox.base_container_image
self.runtime_container_image = self.config.sandbox.runtime_container_image
self.container_name = self.container_name_prefix + self.instance_id
self.container = None
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
@@ -146,9 +149,10 @@ class EventStreamRuntime(Runtime):
self.log_buffer: LogBuffer | None = None
if self.config.sandbox.runtime_extra_deps:
logger.info(
logger.debug(
f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}'
)
self.skip_container_logs = (
os.environ.get('SKIP_CONTAINER_LOGS', 'false').lower() == 'true'
)
@@ -157,6 +161,8 @@ class EventStreamRuntime(Runtime):
raise ValueError(
'Neither runtime container image nor base container image is set'
)
logger.info('Preparing container, this might take a few minutes...')
self.send_status_message('STATUS$STARTING_CONTAINER')
self.runtime_container_image = build_runtime_image(
self.base_container_image,
self.runtime_builder,
@@ -169,9 +175,13 @@ class EventStreamRuntime(Runtime):
)
# will initialize both the event stream and the env vars
super().__init__(config, event_stream, sid, plugins, env_vars)
super().__init__(
config, event_stream, sid, plugins, env_vars, status_message_callback
)
logger.info('Waiting for client to become ready...')
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
logger.info('Waiting for runtime container to be alive...')
self._wait_until_alive()
self.setup_initial_env()
@@ -179,6 +189,7 @@ class EventStreamRuntime(Runtime):
logger.info(
f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}'
)
self.send_status_message(' ')
@staticmethod
def _init_docker_client() -> docker.DockerClient:
@@ -201,9 +212,8 @@ class EventStreamRuntime(Runtime):
plugins: list[PluginRequirement] | None = None,
):
try:
logger.info(
f'Starting container with image: {self.runtime_container_image} and name: {self.container_name}'
)
logger.info('Preparing to start container...')
self.send_status_message('STATUS$PREPARING_CONTAINER')
plugin_arg = ''
if plugins is not None and len(plugins) > 0:
plugin_arg = (
@@ -241,17 +251,17 @@ class EventStreamRuntime(Runtime):
if self.config.debug:
environment['DEBUG'] = 'true'
logger.info(f'Workspace Base: {self.config.workspace_base}')
logger.debug(f'Workspace Base: {self.config.workspace_base}')
if mount_dir is not None and sandbox_workspace_dir is not None:
# e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}}
volumes = {mount_dir: {'bind': sandbox_workspace_dir, 'mode': 'rw'}}
logger.info(f'Mount dir: {mount_dir}')
logger.debug(f'Mount dir: {mount_dir}')
else:
logger.warn(
'Warning: Mount dir is not set, will not mount the workspace directory to the container!\n'
)
volumes = None
logger.info(f'Sandbox workspace: {sandbox_workspace_dir}')
logger.debug(f'Sandbox workspace: {sandbox_workspace_dir}')
if self.config.sandbox.browsergym_eval_env is not None:
browsergym_arg = (
@@ -259,6 +269,7 @@ class EventStreamRuntime(Runtime):
)
else:
browsergym_arg = ''
container = self.docker_client.containers.run(
self.runtime_container_image,
command=(
@@ -281,6 +292,7 @@ class EventStreamRuntime(Runtime):
)
self.log_buffer = LogBuffer(container)
logger.info(f'Container started. Server url: {self.api_url}')
self.send_status_message('STATUS$CONTAINER_STARTED')
return container
except Exception as e:
logger.error(
@@ -539,3 +551,8 @@ class EventStreamRuntime(Runtime):
return port
# If no port is found after max_attempts, return the last tried port
return port
def send_status_message(self, message: str):
"""Sends a status message if the callback function was provided."""
if self.status_message_callback:
self.status_message_callback(message)

View File

@@ -1,3 +1,5 @@
from typing import Callable, Optional
from openhands.core.config import AppConfig
from openhands.events.action import (
FileReadAction,
@@ -25,8 +27,15 @@ class E2BRuntime(Runtime):
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
sandbox: E2BSandbox | None = None,
status_message_callback: Optional[Callable] = None,
):
super().__init__(config, event_stream, sid, plugins)
super().__init__(
config,
event_stream,
sid,
plugins,
status_message_callback=status_message_callback,
)
if sandbox is None:
self.sandbox = E2BSandbox()
if not isinstance(self.sandbox, E2BSandbox):

View File

@@ -2,6 +2,7 @@ import os
import tempfile
import threading
import uuid
from typing import Callable, Optional
from zipfile import ZipFile
import requests
@@ -55,6 +56,7 @@ class RemoteRuntime(Runtime):
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_message_callback: Optional[Callable] = None,
):
self.config = config
if self.config.sandbox.api_hostname == 'localhost':
@@ -168,7 +170,9 @@ class RemoteRuntime(Runtime):
)
# Initialize the eventstream and env vars
super().__init__(config, event_stream, sid, plugins, env_vars)
super().__init__(
config, event_stream, sid, plugins, env_vars, status_message_callback
)
logger.info(
f'Runtime initialized with plugins: {[plugin.name for plugin in self.plugins]}'

View File

@@ -3,6 +3,7 @@ import copy
import json
import os
from abc import abstractmethod
from typing import Callable
from openhands.core.config import AppConfig, SandboxConfig
from openhands.core.logger import openhands_logger as logger
@@ -58,11 +59,13 @@ class Runtime:
sid: str = 'default',
plugins: list[PluginRequirement] | None = None,
env_vars: dict[str, str] | None = None,
status_message_callback: Callable | None = None,
):
self.sid = sid
self.event_stream = event_stream
self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
self.status_message_callback = status_message_callback
self.config = copy.deepcopy(config)
atexit.register(self.close)

View File

@@ -1,3 +1,6 @@
import asyncio
from typing import Callable, Optional
from openhands.controller import AgentController
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
@@ -46,9 +49,9 @@ class AgentSession:
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
status_message_callback: Optional[Callable] = None,
):
"""Starts the Agent session
Parameters:
- runtime_name: The name of the runtime associated with the session
- config:
@@ -58,13 +61,12 @@ class AgentSession:
- agent_to_llm_config:
- agent_configs:
"""
if self.controller or self.runtime:
raise RuntimeError(
'Session already started. You need to close this session and start a new one.'
)
await self._create_security_analyzer(config.security.security_analyzer)
await self._create_runtime(runtime_name, config, agent)
await self._create_runtime(runtime_name, config, agent, status_message_callback)
await self._create_controller(
agent,
config.security.confirmation_mode,
@@ -96,13 +98,19 @@ class AgentSession:
- security_analyzer: The name of the security analyzer to use
"""
logger.info(f'Using security analyzer: {security_analyzer}')
if security_analyzer:
logger.debug(f'Using security analyzer: {security_analyzer}')
self.security_analyzer = options.SecurityAnalyzers.get(
security_analyzer, SecurityAnalyzer
)(self.event_stream)
async def _create_runtime(self, runtime_name: str, config: AppConfig, agent: Agent):
async def _create_runtime(
self,
runtime_name: str,
config: AppConfig,
agent: Agent,
status_message_callback: Optional[Callable] = None,
):
"""Creates a runtime instance
Parameters:
@@ -112,17 +120,27 @@ class AgentSession:
"""
if self.runtime is not None:
raise Exception('Runtime already created')
raise RuntimeError('Runtime already created')
logger.info(f'Initializing runtime `{runtime_name}` now...')
runtime_cls = get_runtime_cls(runtime_name)
self.runtime = runtime_cls(
self.runtime = await asyncio.to_thread(
runtime_cls,
config=config,
event_stream=self.event_stream,
sid=self.sid,
plugins=agent.sandbox_plugins,
status_message_callback=status_message_callback,
)
if self.runtime is not None:
logger.debug(
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
)
else:
logger.warning('Runtime initialization failed')
async def _create_controller(
self,
agent: Agent,
@@ -178,5 +196,5 @@ class AgentSession:
)
logger.info(f'Restored agent state from session, sid: {self.sid}')
except Exception as e:
logger.info(f'Error restoring state: {e}')
logger.info(f'State could not be restored: {e}')
logger.info('Agent controller initialized.')

View File

@@ -35,9 +35,11 @@ class SessionManager:
async def send(self, sid: str, data: dict[str, object]) -> bool:
"""Sends data to the client."""
if sid not in self._sessions:
session = self.get_session(sid)
if session is None:
logger.error(f'*** No session found for {sid}, skipping message ***')
return False
return await self._sessions[sid].send(data)
return await session.send(data)
async def send_error(self, sid: str, message: str) -> bool:
"""Sends an error message to the client."""

View File

@@ -21,7 +21,7 @@ from openhands.events.serialization import event_from_dict, event_to_dict
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.runtime.utils.shutdown_listener import should_continue
from openhands.server.session.agent import AgentSession
from openhands.server.session.agent_session import AgentSession
from openhands.storage.files import FileStore
DEL_DELT_SEC = 60 * 60 * 5
@@ -33,6 +33,7 @@ class Session:
last_active_ts: int = 0
is_alive: bool = True
agent_session: AgentSession
loop: asyncio.AbstractEventLoop
def __init__(
self, sid: str, ws: WebSocket | None, config: AppConfig, file_store: FileStore
@@ -45,6 +46,7 @@ class Session:
EventStreamSubscriber.SERVER, self.on_event
)
self.config = config
self.loop = asyncio.get_event_loop()
async def close(self):
self.is_alive = False
@@ -113,6 +115,7 @@ class Session:
max_budget_per_task=self.config.max_budget_per_task,
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
agent_configs=self.config.get_agent_configs(),
status_message_callback=self.queue_status_message,
)
except Exception as e:
logger.exception(f'Error creating controller: {e}')
@@ -125,7 +128,8 @@ class Session:
)
async def on_event(self, event: Event):
"""Callback function for agent events.
"""Callback function for events that mainly come from the agent.
Event is the base class for any agent action and observation.
Args:
event: The agent event (Observation or Action).
@@ -135,7 +139,6 @@ class Session:
if isinstance(event, NullObservation):
return
if event.source == EventSource.AGENT:
logger.info('Server event')
await self.send(event_to_dict(event))
elif event.source == EventSource.USER and isinstance(
event, CmdOutputObservation
@@ -172,6 +175,9 @@ class Session:
await asyncio.sleep(0.001) # This flushes the data to the client
self.last_active_ts = int(time.time())
return True
except RuntimeError:
self.is_alive = False
return False
except WebSocketDisconnect:
self.is_alive = False
return False
@@ -195,3 +201,8 @@ class Session:
return False
self.is_alive = data.get('is_alive', False)
return True
def queue_status_message(self, message: str):
"""Queues a status message to be sent asynchronously."""
# Ensure the coroutine runs in the main event loop
asyncio.run_coroutine_threadsafe(self.send_message(message), self.loop)