mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
(enh) send status messages to UI during startup (#3771)
Co-authored-by: Robert Brennan <accounts@rbren.io> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com> Co-authored-by: Robert Brennan <contact@rbren.io> Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -228,3 +228,4 @@ runtime_*.tar
|
||||
# docker build
|
||||
containers/runtime/Dockerfile
|
||||
containers/runtime/project.tar.gz
|
||||
containers/runtime/code
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# Dynamic constructed Dockerfile
|
||||
# Dynamically constructed Dockerfile
|
||||
|
||||
This folder builds runtime image (sandbox), which will use a `Dockerfile` that is dynamically generated depends on the `base_image` AND a [Python source distribution](https://docs.python.org/3.10/distutils/sourcedist.html) that's based on the current commit of `openhands`.
|
||||
This folder builds a runtime image (sandbox), which will use a dynamically generated `Dockerfile`
|
||||
that depends on the `base_image` **AND** a [Python source distribution](https://docs.python.org/3.10/distutils/sourcedist.html) that is based on the current commit of `openhands`.
|
||||
|
||||
The following command will generate Dockerfile for `ubuntu:22.04` and the source distribution `.tar` into `containers/runtime`.
|
||||
The following command will generate a `Dockerfile` file for `nikolaik/python-nodejs:python3.11-nodejs22` (the default base image), an updated `config.sh` and the runtime source distribution files/folders into `containers/runtime`:
|
||||
|
||||
```bash
|
||||
poetry run python3 openhands/runtime/utils/runtime_build.py \
|
||||
--base_image ubuntu:22.04 \
|
||||
--base_image nikolaik/python-nodejs:python3.11-nodejs22 \
|
||||
--build_folder containers/runtime
|
||||
```
|
||||
|
||||
@@ -18,6 +18,7 @@ enum IndicatorColor {
|
||||
function AgentStatusBar() {
|
||||
const { t } = useTranslation();
|
||||
const { curAgentState } = useSelector((state: RootState) => state.agent);
|
||||
const { curStatusMessage } = useSelector((state: RootState) => state.status);
|
||||
|
||||
const AgentStatusMap: {
|
||||
[k: string]: { message: string; indicator: IndicatorColor };
|
||||
@@ -90,14 +91,25 @@ function AgentStatusBar() {
|
||||
}
|
||||
}, [curAgentState]);
|
||||
|
||||
const [statusMessage, setStatusMessage] = React.useState<string>("");
|
||||
|
||||
React.useEffect(() => {
|
||||
const trimmedCustomMessage = curStatusMessage.message.trim();
|
||||
if (trimmedCustomMessage) {
|
||||
setStatusMessage(t(trimmedCustomMessage));
|
||||
} else {
|
||||
setStatusMessage(AgentStatusMap[curAgentState].message);
|
||||
}
|
||||
}, [curAgentState, curStatusMessage.message]);
|
||||
|
||||
return (
|
||||
<div className="flex items-center">
|
||||
<div
|
||||
className={`w-3 h-3 mr-2 rounded-full animate-pulse ${AgentStatusMap[curAgentState].indicator}`}
|
||||
/>
|
||||
<span className="text-sm text-stone-400">
|
||||
{AgentStatusMap[curAgentState].message}
|
||||
</span>
|
||||
<div className="flex flex-col items-center">
|
||||
<div className="flex items-center">
|
||||
<div
|
||||
className={`w-3 h-3 mr-2 rounded-full animate-pulse ${AgentStatusMap[curAgentState].indicator}`}
|
||||
/>
|
||||
<span className="text-sm text-stone-400">{statusMessage}</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,10 +6,11 @@ import {
|
||||
ActionSecurityRisk,
|
||||
appendSecurityAnalyzerInput,
|
||||
} from "#/state/securityAnalyzerSlice";
|
||||
import { setCurStatusMessage } from "#/state/statusSlice";
|
||||
import { setRootTask } from "#/state/taskSlice";
|
||||
import store from "#/store";
|
||||
import ActionType from "#/types/ActionType";
|
||||
import { ActionMessage } from "#/types/Message";
|
||||
import { ActionMessage, StatusMessage } from "#/types/Message";
|
||||
import { SocketMessage } from "#/types/ResponseType";
|
||||
import { handleObservationMessage } from "./observations";
|
||||
import { getRootTask } from "./taskService";
|
||||
@@ -138,6 +139,16 @@ export function handleActionMessage(message: ActionMessage) {
|
||||
}
|
||||
}
|
||||
|
||||
export function handleStatusMessage(message: StatusMessage) {
|
||||
const msg = message.message == null ? "" : message.message.trim();
|
||||
store.dispatch(
|
||||
setCurStatusMessage({
|
||||
...message,
|
||||
message: msg,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
export function handleAssistantMessage(data: string | SocketMessage) {
|
||||
let socketMessage: SocketMessage;
|
||||
|
||||
@@ -149,7 +160,9 @@ export function handleAssistantMessage(data: string | SocketMessage) {
|
||||
|
||||
if ("action" in socketMessage) {
|
||||
handleActionMessage(socketMessage);
|
||||
} else {
|
||||
} else if ("observation" in socketMessage) {
|
||||
handleObservationMessage(socketMessage);
|
||||
} else if ("message" in socketMessage) {
|
||||
handleStatusMessage(socketMessage);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,11 +8,19 @@ import { I18nKey } from "#/i18n/declaration";
|
||||
|
||||
const translate = (key: I18nKey) => i18next.t(key);
|
||||
|
||||
// Define a type for the messages
|
||||
type Message = {
|
||||
action: ActionType;
|
||||
args: Record<string, unknown>;
|
||||
};
|
||||
|
||||
class Session {
|
||||
private static _socket: WebSocket | null = null;
|
||||
|
||||
private static _latest_event_id: number = -1;
|
||||
|
||||
private static _messageQueue: Message[] = [];
|
||||
|
||||
public static _history: Record<string, unknown>[] = [];
|
||||
|
||||
// callbacks contain a list of callable functions
|
||||
@@ -83,6 +91,7 @@ class Session {
|
||||
toast.success("ws", translate(I18nKey.SESSION$SERVER_CONNECTED_MESSAGE));
|
||||
Session._connecting = false;
|
||||
Session._initializeAgent();
|
||||
Session._flushQueue();
|
||||
Session.callbacks.open?.forEach((callback) => {
|
||||
callback(e);
|
||||
});
|
||||
@@ -94,7 +103,6 @@ class Session {
|
||||
data = JSON.parse(e.data);
|
||||
Session._history.push(data);
|
||||
} catch (err) {
|
||||
// TODO: report the error
|
||||
toast.error(
|
||||
"ws",
|
||||
translate(I18nKey.SESSION$SESSION_HANDLING_ERROR_MESSAGE),
|
||||
@@ -115,6 +123,7 @@ class Session {
|
||||
};
|
||||
|
||||
Session._socket.onerror = () => {
|
||||
// TODO report error
|
||||
toast.error(
|
||||
"ws",
|
||||
translate(I18nKey.SESSION$SESSION_CONNECTION_ERROR_MESSAGE),
|
||||
@@ -145,9 +154,20 @@ class Session {
|
||||
Session._socket = null;
|
||||
}
|
||||
|
||||
private static _flushQueue(): void {
|
||||
while (Session._messageQueue.length > 0) {
|
||||
const message = Session._messageQueue.shift();
|
||||
if (message) {
|
||||
setTimeout(() => Session.send(JSON.stringify(message)), 1000);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static send(message: string): void {
|
||||
const messageObject: Message = JSON.parse(message);
|
||||
|
||||
if (Session._connecting) {
|
||||
setTimeout(() => Session.send(message), 1000);
|
||||
Session._messageQueue.push(messageObject);
|
||||
return;
|
||||
}
|
||||
if (!Session.isConnected()) {
|
||||
|
||||
23
frontend/src/state/statusSlice.ts
Normal file
23
frontend/src/state/statusSlice.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
import { createSlice, PayloadAction } from "@reduxjs/toolkit";
|
||||
import { StatusMessage } from "#/types/Message";
|
||||
|
||||
const initialStatusMessage: StatusMessage = {
|
||||
message: "",
|
||||
is_error: false,
|
||||
};
|
||||
|
||||
export const statusSlice = createSlice({
|
||||
name: "status",
|
||||
initialState: {
|
||||
curStatusMessage: initialStatusMessage,
|
||||
},
|
||||
reducers: {
|
||||
setCurStatusMessage: (state, action: PayloadAction<StatusMessage>) => {
|
||||
state.curStatusMessage = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
export const { setCurStatusMessage } = statusSlice.actions;
|
||||
|
||||
export default statusSlice.reducer;
|
||||
@@ -8,6 +8,7 @@ import errorsReducer from "./state/errorsSlice";
|
||||
import taskReducer from "./state/taskSlice";
|
||||
import jupyterReducer from "./state/jupyterSlice";
|
||||
import securityAnalyzerReducer from "./state/securityAnalyzerSlice";
|
||||
import statusReducer from "./state/statusSlice";
|
||||
|
||||
export const rootReducer = combineReducers({
|
||||
browser: browserReducer,
|
||||
@@ -19,6 +20,7 @@ export const rootReducer = combineReducers({
|
||||
agent: agentReducer,
|
||||
jupyter: jupyterReducer,
|
||||
securityAnalyzer: securityAnalyzerReducer,
|
||||
status: statusReducer,
|
||||
});
|
||||
|
||||
const store = configureStore({
|
||||
|
||||
@@ -31,3 +31,12 @@ export interface ObservationMessage {
|
||||
// The timestamp of the message
|
||||
timestamp: string;
|
||||
}
|
||||
|
||||
export interface StatusMessage {
|
||||
// TODO not implemented yet
|
||||
// Whether the status is an error, default is false
|
||||
is_error: boolean;
|
||||
|
||||
// A status message to display to the user
|
||||
message: string;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { ActionMessage, ObservationMessage } from "./Message";
|
||||
import { ActionMessage, ObservationMessage, StatusMessage } from "./Message";
|
||||
|
||||
type SocketMessage = ActionMessage | ObservationMessage;
|
||||
type SocketMessage = ActionMessage | ObservationMessage | StatusMessage;
|
||||
|
||||
export { type SocketMessage };
|
||||
|
||||
@@ -55,7 +55,6 @@ def create_runtime(
|
||||
|
||||
config: The app config.
|
||||
sid: The session id.
|
||||
runtime_tools_config: (will be deprecated) The runtime tools config.
|
||||
"""
|
||||
# if sid is provided on the command line, use it as the name of the event stream
|
||||
# otherwise generate it on the basis of the configured jwt_secret
|
||||
|
||||
@@ -16,8 +16,10 @@ from pathlib import Path
|
||||
|
||||
import pexpect
|
||||
from fastapi import FastAPI, HTTPException, Request, UploadFile
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from starlette.exceptions import HTTPException as StarletteHTTPException
|
||||
from uvicorn import run
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -562,6 +564,35 @@ if __name__ == '__main__':
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# TODO below 3 exception handlers were recommended by Sonnet.
|
||||
# Are these something we should keep?
|
||||
@app.exception_handler(Exception)
|
||||
async def global_exception_handler(request: Request, exc: Exception):
|
||||
logger.exception('Unhandled exception occurred:')
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
'message': 'An unexpected error occurred. Please try again later.'
|
||||
},
|
||||
)
|
||||
|
||||
@app.exception_handler(StarletteHTTPException)
|
||||
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
|
||||
logger.error(f'HTTP exception occurred: {exc.detail}')
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code, content={'message': exc.detail}
|
||||
)
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
):
|
||||
logger.error(f'Validation error occurred: {exc}')
|
||||
return JSONResponse(
|
||||
status_code=422,
|
||||
content={'message': 'Invalid request parameters', 'details': exc.errors()},
|
||||
)
|
||||
|
||||
@app.middleware('http')
|
||||
async def one_request_at_a_time(request: Request, call_next):
|
||||
assert client is not None
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Callable
|
||||
from zipfile import ZipFile
|
||||
|
||||
import docker
|
||||
@@ -119,6 +120,7 @@ class EventStreamRuntime(Runtime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_message_callback: Callable | None = None,
|
||||
):
|
||||
self.config = config
|
||||
self._host_port = 30000 # initial dummy value
|
||||
@@ -130,12 +132,13 @@ class EventStreamRuntime(Runtime):
|
||||
self.instance_id = (
|
||||
sid + '_' + str(uuid.uuid4()) if sid is not None else str(uuid.uuid4())
|
||||
)
|
||||
self.status_message_callback = status_message_callback
|
||||
|
||||
self.send_status_message('STATUS$STARTING_RUNTIME')
|
||||
self.docker_client: docker.DockerClient = self._init_docker_client()
|
||||
self.base_container_image = self.config.sandbox.base_container_image
|
||||
self.runtime_container_image = self.config.sandbox.runtime_container_image
|
||||
self.container_name = self.container_name_prefix + self.instance_id
|
||||
|
||||
self.container = None
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
|
||||
@@ -146,9 +149,10 @@ class EventStreamRuntime(Runtime):
|
||||
self.log_buffer: LogBuffer | None = None
|
||||
|
||||
if self.config.sandbox.runtime_extra_deps:
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}'
|
||||
)
|
||||
|
||||
self.skip_container_logs = (
|
||||
os.environ.get('SKIP_CONTAINER_LOGS', 'false').lower() == 'true'
|
||||
)
|
||||
@@ -157,6 +161,8 @@ class EventStreamRuntime(Runtime):
|
||||
raise ValueError(
|
||||
'Neither runtime container image nor base container image is set'
|
||||
)
|
||||
logger.info('Preparing container, this might take a few minutes...')
|
||||
self.send_status_message('STATUS$STARTING_CONTAINER')
|
||||
self.runtime_container_image = build_runtime_image(
|
||||
self.base_container_image,
|
||||
self.runtime_builder,
|
||||
@@ -169,9 +175,13 @@ class EventStreamRuntime(Runtime):
|
||||
)
|
||||
|
||||
# will initialize both the event stream and the env vars
|
||||
super().__init__(config, event_stream, sid, plugins, env_vars)
|
||||
super().__init__(
|
||||
config, event_stream, sid, plugins, env_vars, status_message_callback
|
||||
)
|
||||
|
||||
logger.info('Waiting for client to become ready...')
|
||||
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
|
||||
|
||||
logger.info('Waiting for runtime container to be alive...')
|
||||
self._wait_until_alive()
|
||||
|
||||
self.setup_initial_env()
|
||||
@@ -179,6 +189,7 @@ class EventStreamRuntime(Runtime):
|
||||
logger.info(
|
||||
f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}'
|
||||
)
|
||||
self.send_status_message(' ')
|
||||
|
||||
@staticmethod
|
||||
def _init_docker_client() -> docker.DockerClient:
|
||||
@@ -201,9 +212,8 @@ class EventStreamRuntime(Runtime):
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
):
|
||||
try:
|
||||
logger.info(
|
||||
f'Starting container with image: {self.runtime_container_image} and name: {self.container_name}'
|
||||
)
|
||||
logger.info('Preparing to start container...')
|
||||
self.send_status_message('STATUS$PREPARING_CONTAINER')
|
||||
plugin_arg = ''
|
||||
if plugins is not None and len(plugins) > 0:
|
||||
plugin_arg = (
|
||||
@@ -241,17 +251,17 @@ class EventStreamRuntime(Runtime):
|
||||
if self.config.debug:
|
||||
environment['DEBUG'] = 'true'
|
||||
|
||||
logger.info(f'Workspace Base: {self.config.workspace_base}')
|
||||
logger.debug(f'Workspace Base: {self.config.workspace_base}')
|
||||
if mount_dir is not None and sandbox_workspace_dir is not None:
|
||||
# e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}}
|
||||
volumes = {mount_dir: {'bind': sandbox_workspace_dir, 'mode': 'rw'}}
|
||||
logger.info(f'Mount dir: {mount_dir}')
|
||||
logger.debug(f'Mount dir: {mount_dir}')
|
||||
else:
|
||||
logger.warn(
|
||||
'Warning: Mount dir is not set, will not mount the workspace directory to the container!\n'
|
||||
)
|
||||
volumes = None
|
||||
logger.info(f'Sandbox workspace: {sandbox_workspace_dir}')
|
||||
logger.debug(f'Sandbox workspace: {sandbox_workspace_dir}')
|
||||
|
||||
if self.config.sandbox.browsergym_eval_env is not None:
|
||||
browsergym_arg = (
|
||||
@@ -259,6 +269,7 @@ class EventStreamRuntime(Runtime):
|
||||
)
|
||||
else:
|
||||
browsergym_arg = ''
|
||||
|
||||
container = self.docker_client.containers.run(
|
||||
self.runtime_container_image,
|
||||
command=(
|
||||
@@ -281,6 +292,7 @@ class EventStreamRuntime(Runtime):
|
||||
)
|
||||
self.log_buffer = LogBuffer(container)
|
||||
logger.info(f'Container started. Server url: {self.api_url}')
|
||||
self.send_status_message('STATUS$CONTAINER_STARTED')
|
||||
return container
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -539,3 +551,8 @@ class EventStreamRuntime(Runtime):
|
||||
return port
|
||||
# If no port is found after max_attempts, return the last tried port
|
||||
return port
|
||||
|
||||
def send_status_message(self, message: str):
|
||||
"""Sends a status message if the callback function was provided."""
|
||||
if self.status_message_callback:
|
||||
self.status_message_callback(message)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.events.action import (
|
||||
FileReadAction,
|
||||
@@ -25,8 +27,15 @@ class E2BRuntime(Runtime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
sandbox: E2BSandbox | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
):
|
||||
super().__init__(config, event_stream, sid, plugins)
|
||||
super().__init__(
|
||||
config,
|
||||
event_stream,
|
||||
sid,
|
||||
plugins,
|
||||
status_message_callback=status_message_callback,
|
||||
)
|
||||
if sandbox is None:
|
||||
self.sandbox = E2BSandbox()
|
||||
if not isinstance(self.sandbox, E2BSandbox):
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
from typing import Callable, Optional
|
||||
from zipfile import ZipFile
|
||||
|
||||
import requests
|
||||
@@ -55,6 +56,7 @@ class RemoteRuntime(Runtime):
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
):
|
||||
self.config = config
|
||||
if self.config.sandbox.api_hostname == 'localhost':
|
||||
@@ -168,7 +170,9 @@ class RemoteRuntime(Runtime):
|
||||
)
|
||||
|
||||
# Initialize the eventstream and env vars
|
||||
super().__init__(config, event_stream, sid, plugins, env_vars)
|
||||
super().__init__(
|
||||
config, event_stream, sid, plugins, env_vars, status_message_callback
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Runtime initialized with plugins: {[plugin.name for plugin in self.plugins]}'
|
||||
|
||||
@@ -3,6 +3,7 @@ import copy
|
||||
import json
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from typing import Callable
|
||||
|
||||
from openhands.core.config import AppConfig, SandboxConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -58,11 +59,13 @@ class Runtime:
|
||||
sid: str = 'default',
|
||||
plugins: list[PluginRequirement] | None = None,
|
||||
env_vars: dict[str, str] | None = None,
|
||||
status_message_callback: Callable | None = None,
|
||||
):
|
||||
self.sid = sid
|
||||
self.event_stream = event_stream
|
||||
self.event_stream.subscribe(EventStreamSubscriber.RUNTIME, self.on_event)
|
||||
self.plugins = plugins if plugins is not None and len(plugins) > 0 else []
|
||||
self.status_message_callback = status_message_callback
|
||||
|
||||
self.config = copy.deepcopy(config)
|
||||
atexit.register(self.close)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
import asyncio
|
||||
from typing import Callable, Optional
|
||||
|
||||
from openhands.controller import AgentController
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
@@ -46,9 +49,9 @@ class AgentSession:
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
):
|
||||
"""Starts the Agent session
|
||||
|
||||
Parameters:
|
||||
- runtime_name: The name of the runtime associated with the session
|
||||
- config:
|
||||
@@ -58,13 +61,12 @@ class AgentSession:
|
||||
- agent_to_llm_config:
|
||||
- agent_configs:
|
||||
"""
|
||||
|
||||
if self.controller or self.runtime:
|
||||
raise RuntimeError(
|
||||
'Session already started. You need to close this session and start a new one.'
|
||||
)
|
||||
await self._create_security_analyzer(config.security.security_analyzer)
|
||||
await self._create_runtime(runtime_name, config, agent)
|
||||
await self._create_runtime(runtime_name, config, agent, status_message_callback)
|
||||
await self._create_controller(
|
||||
agent,
|
||||
config.security.confirmation_mode,
|
||||
@@ -96,13 +98,19 @@ class AgentSession:
|
||||
- security_analyzer: The name of the security analyzer to use
|
||||
"""
|
||||
|
||||
logger.info(f'Using security analyzer: {security_analyzer}')
|
||||
if security_analyzer:
|
||||
logger.debug(f'Using security analyzer: {security_analyzer}')
|
||||
self.security_analyzer = options.SecurityAnalyzers.get(
|
||||
security_analyzer, SecurityAnalyzer
|
||||
)(self.event_stream)
|
||||
|
||||
async def _create_runtime(self, runtime_name: str, config: AppConfig, agent: Agent):
|
||||
async def _create_runtime(
|
||||
self,
|
||||
runtime_name: str,
|
||||
config: AppConfig,
|
||||
agent: Agent,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
):
|
||||
"""Creates a runtime instance
|
||||
|
||||
Parameters:
|
||||
@@ -112,17 +120,27 @@ class AgentSession:
|
||||
"""
|
||||
|
||||
if self.runtime is not None:
|
||||
raise Exception('Runtime already created')
|
||||
raise RuntimeError('Runtime already created')
|
||||
|
||||
logger.info(f'Initializing runtime `{runtime_name}` now...')
|
||||
runtime_cls = get_runtime_cls(runtime_name)
|
||||
self.runtime = runtime_cls(
|
||||
|
||||
self.runtime = await asyncio.to_thread(
|
||||
runtime_cls,
|
||||
config=config,
|
||||
event_stream=self.event_stream,
|
||||
sid=self.sid,
|
||||
plugins=agent.sandbox_plugins,
|
||||
status_message_callback=status_message_callback,
|
||||
)
|
||||
|
||||
if self.runtime is not None:
|
||||
logger.debug(
|
||||
f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}'
|
||||
)
|
||||
else:
|
||||
logger.warning('Runtime initialization failed')
|
||||
|
||||
async def _create_controller(
|
||||
self,
|
||||
agent: Agent,
|
||||
@@ -178,5 +196,5 @@ class AgentSession:
|
||||
)
|
||||
logger.info(f'Restored agent state from session, sid: {self.sid}')
|
||||
except Exception as e:
|
||||
logger.info(f'Error restoring state: {e}')
|
||||
logger.info(f'State could not be restored: {e}')
|
||||
logger.info('Agent controller initialized.')
|
||||
@@ -35,9 +35,11 @@ class SessionManager:
|
||||
|
||||
async def send(self, sid: str, data: dict[str, object]) -> bool:
|
||||
"""Sends data to the client."""
|
||||
if sid not in self._sessions:
|
||||
session = self.get_session(sid)
|
||||
if session is None:
|
||||
logger.error(f'*** No session found for {sid}, skipping message ***')
|
||||
return False
|
||||
return await self._sessions[sid].send(data)
|
||||
return await session.send(data)
|
||||
|
||||
async def send_error(self, sid: str, message: str) -> bool:
|
||||
"""Sends an error message to the client."""
|
||||
|
||||
@@ -21,7 +21,7 @@ from openhands.events.serialization import event_from_dict, event_to_dict
|
||||
from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime.utils.shutdown_listener import should_continue
|
||||
from openhands.server.session.agent import AgentSession
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.storage.files import FileStore
|
||||
|
||||
DEL_DELT_SEC = 60 * 60 * 5
|
||||
@@ -33,6 +33,7 @@ class Session:
|
||||
last_active_ts: int = 0
|
||||
is_alive: bool = True
|
||||
agent_session: AgentSession
|
||||
loop: asyncio.AbstractEventLoop
|
||||
|
||||
def __init__(
|
||||
self, sid: str, ws: WebSocket | None, config: AppConfig, file_store: FileStore
|
||||
@@ -45,6 +46,7 @@ class Session:
|
||||
EventStreamSubscriber.SERVER, self.on_event
|
||||
)
|
||||
self.config = config
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
async def close(self):
|
||||
self.is_alive = False
|
||||
@@ -113,6 +115,7 @@ class Session:
|
||||
max_budget_per_task=self.config.max_budget_per_task,
|
||||
agent_to_llm_config=self.config.get_agent_to_llm_config_map(),
|
||||
agent_configs=self.config.get_agent_configs(),
|
||||
status_message_callback=self.queue_status_message,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f'Error creating controller: {e}')
|
||||
@@ -125,7 +128,8 @@ class Session:
|
||||
)
|
||||
|
||||
async def on_event(self, event: Event):
|
||||
"""Callback function for agent events.
|
||||
"""Callback function for events that mainly come from the agent.
|
||||
Event is the base class for any agent action and observation.
|
||||
|
||||
Args:
|
||||
event: The agent event (Observation or Action).
|
||||
@@ -135,7 +139,6 @@ class Session:
|
||||
if isinstance(event, NullObservation):
|
||||
return
|
||||
if event.source == EventSource.AGENT:
|
||||
logger.info('Server event')
|
||||
await self.send(event_to_dict(event))
|
||||
elif event.source == EventSource.USER and isinstance(
|
||||
event, CmdOutputObservation
|
||||
@@ -172,6 +175,9 @@ class Session:
|
||||
await asyncio.sleep(0.001) # This flushes the data to the client
|
||||
self.last_active_ts = int(time.time())
|
||||
return True
|
||||
except RuntimeError:
|
||||
self.is_alive = False
|
||||
return False
|
||||
except WebSocketDisconnect:
|
||||
self.is_alive = False
|
||||
return False
|
||||
@@ -195,3 +201,8 @@ class Session:
|
||||
return False
|
||||
self.is_alive = data.get('is_alive', False)
|
||||
return True
|
||||
|
||||
def queue_status_message(self, message: str):
|
||||
"""Queues a status message to be sent asynchronously."""
|
||||
# Ensure the coroutine runs in the main event loop
|
||||
asyncio.run_coroutine_threadsafe(self.send_message(message), self.loop)
|
||||
|
||||
Reference in New Issue
Block a user