mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
33 Commits
fix-logbuf
...
rb/socket-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dbda7bcc82 | ||
|
|
2f275dce9e | ||
|
|
340b9ead40 | ||
|
|
1fba72043c | ||
|
|
bcab72c981 | ||
|
|
5949f5bdb6 | ||
|
|
20d4a0cce2 | ||
|
|
37ee90cf6f | ||
|
|
69c646ebf9 | ||
|
|
71e57dc069 | ||
|
|
3f03fb69c8 | ||
|
|
551a5a72e1 | ||
|
|
2aa6bab85f | ||
|
|
3f0e619997 | ||
|
|
511ce8cc3a | ||
|
|
a1b53a9498 | ||
|
|
767878563e | ||
|
|
b2bd23ac86 | ||
|
|
f0487e6818 | ||
|
|
b85fbf39fd | ||
|
|
8b4d263319 | ||
|
|
ff7783ec81 | ||
|
|
38dc41ca42 | ||
|
|
8dee334236 | ||
|
|
d4b20c284d | ||
|
|
2501cce470 | ||
|
|
ca9aefd7b2 | ||
|
|
92586a090d | ||
|
|
18e774dd8a | ||
|
|
3a59e037fb | ||
|
|
5973c0c269 | ||
|
|
2fa8c4e14d | ||
|
|
c2a5fbceb7 |
80
frontend/package-lock.json
generated
80
frontend/package-lock.json
generated
@@ -40,6 +40,7 @@
|
||||
"react-textarea-autosize": "^8.5.4",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"sirv-cli": "^3.0.0",
|
||||
"socket.io-client": "^4.8.1",
|
||||
"tailwind-merge": "^2.5.4",
|
||||
"vite": "^5.4.9",
|
||||
"web-vitals": "^3.5.2",
|
||||
@@ -5576,6 +5577,11 @@
|
||||
"integrity": "sha512-+Fj43pSMwJs4KRrH/938Uf+uAELIgVBmQzg/q1YG10djyfA3TnrU8N8XzqCh/okZdszqBQTZf96idMfE5lnwTA==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@socket.io/component-emitter": {
|
||||
"version": "3.1.2",
|
||||
"resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz",
|
||||
"integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA=="
|
||||
},
|
||||
"node_modules/@svgr/babel-plugin-add-jsx-attribute": {
|
||||
"version": "8.0.0",
|
||||
"resolved": "https://registry.npmjs.org/@svgr/babel-plugin-add-jsx-attribute/-/babel-plugin-add-jsx-attribute-8.0.0.tgz",
|
||||
@@ -8469,6 +8475,46 @@
|
||||
"once": "^1.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/engine.io-client": {
|
||||
"version": "6.6.2",
|
||||
"resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.6.2.tgz",
|
||||
"integrity": "sha512-TAr+NKeoVTjEVW8P3iHguO1LO6RlUz9O5Y8o7EY0fU+gY1NYqas7NN3slpFtbXEsLMHk0h90fJMfKjRkQ0qUIw==",
|
||||
"dependencies": {
|
||||
"@socket.io/component-emitter": "~3.1.0",
|
||||
"debug": "~4.3.1",
|
||||
"engine.io-parser": "~5.2.1",
|
||||
"ws": "~8.17.1",
|
||||
"xmlhttprequest-ssl": "~2.1.1"
|
||||
}
|
||||
},
|
||||
"node_modules/engine.io-client/node_modules/ws": {
|
||||
"version": "8.17.1",
|
||||
"resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz",
|
||||
"integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==",
|
||||
"engines": {
|
||||
"node": ">=10.0.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"bufferutil": "^4.0.1",
|
||||
"utf-8-validate": ">=5.0.2"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"bufferutil": {
|
||||
"optional": true
|
||||
},
|
||||
"utf-8-validate": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/engine.io-parser": {
|
||||
"version": "5.2.3",
|
||||
"resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz",
|
||||
"integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==",
|
||||
"engines": {
|
||||
"node": ">=10.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/entities": {
|
||||
"version": "4.5.0",
|
||||
"resolved": "https://registry.npmjs.org/entities/-/entities-4.5.0.tgz",
|
||||
@@ -22587,6 +22633,32 @@
|
||||
"tslib": "^2.0.3"
|
||||
}
|
||||
},
|
||||
"node_modules/socket.io-client": {
|
||||
"version": "4.8.1",
|
||||
"resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.8.1.tgz",
|
||||
"integrity": "sha512-hJVXfu3E28NmzGk8o1sHhN3om52tRvwYeidbj7xKy2eIIse5IoKX3USlS6Tqt3BHAtflLIkCQBkzVrEEfWUyYQ==",
|
||||
"dependencies": {
|
||||
"@socket.io/component-emitter": "~3.1.0",
|
||||
"debug": "~4.3.2",
|
||||
"engine.io-client": "~6.6.1",
|
||||
"socket.io-parser": "~4.2.4"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/socket.io-parser": {
|
||||
"version": "4.2.4",
|
||||
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz",
|
||||
"integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==",
|
||||
"dependencies": {
|
||||
"@socket.io/component-emitter": "~3.1.0",
|
||||
"debug": "~4.3.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=10.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/source-map": {
|
||||
"version": "0.7.4",
|
||||
"resolved": "https://registry.npmjs.org/source-map/-/source-map-0.7.4.tgz",
|
||||
@@ -25317,6 +25389,14 @@
|
||||
"integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/xmlhttprequest-ssl": {
|
||||
"version": "2.1.2",
|
||||
"resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.1.2.tgz",
|
||||
"integrity": "sha512-TEU+nJVUUnA4CYJFLvK5X9AOeH4KvDvhIfm0vV1GaQRtchnG0hgK5p8hw/xjv8cunWYCsiPCSDzObPyhEwq3KQ==",
|
||||
"engines": {
|
||||
"node": ">=0.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/xtend": {
|
||||
"version": "4.0.2",
|
||||
"resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz",
|
||||
|
||||
@@ -39,6 +39,7 @@
|
||||
"react-textarea-autosize": "^8.5.4",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"sirv-cli": "^3.0.0",
|
||||
"socket.io-client": "^4.8.1",
|
||||
"tailwind-merge": "^2.5.4",
|
||||
"vite": "^5.4.9",
|
||||
"web-vitals": "^3.5.2",
|
||||
@@ -120,4 +121,4 @@
|
||||
"public"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import posthog from "posthog-js";
|
||||
import React from "react";
|
||||
import { io, Socket } from "socket.io-client";
|
||||
import { Settings } from "#/services/settings";
|
||||
import ActionType from "#/types/ActionType";
|
||||
import EventLogger from "#/utils/event-logger";
|
||||
import AgentState from "#/types/AgentState";
|
||||
import { handleAssistantMessage } from "#/services/actions";
|
||||
|
||||
const RECONNECT_RETRIES = 5;
|
||||
|
||||
export enum WsClientProviderStatus {
|
||||
STOPPED,
|
||||
OPENING,
|
||||
@@ -43,38 +41,46 @@ export function WsClientProvider({
|
||||
settings,
|
||||
children,
|
||||
}: React.PropsWithChildren<WsClientProviderProps>) {
|
||||
const wsRef = React.useRef<WebSocket | null>(null);
|
||||
const sioRef = React.useRef<Socket | null>(null);
|
||||
const tokenRef = React.useRef<string | null>(token);
|
||||
const ghTokenRef = React.useRef<string | null>(ghToken);
|
||||
const closeRef = React.useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
const disconnectRef = React.useRef<ReturnType<typeof setTimeout> | null>(
|
||||
null,
|
||||
);
|
||||
const [status, setStatus] = React.useState(WsClientProviderStatus.STOPPED);
|
||||
const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
|
||||
const [retryCount, setRetryCount] = React.useState(RECONNECT_RETRIES);
|
||||
|
||||
function send(event: Record<string, unknown>) {
|
||||
if (!wsRef.current) {
|
||||
if (!sioRef.current) {
|
||||
EventLogger.error("WebSocket is not connected.");
|
||||
return;
|
||||
}
|
||||
wsRef.current.send(JSON.stringify(event));
|
||||
sioRef.current.emit("oh_action", event);
|
||||
}
|
||||
|
||||
function handleOpen() {
|
||||
setRetryCount(RECONNECT_RETRIES);
|
||||
function handleConnect() {
|
||||
setStatus(WsClientProviderStatus.OPENING);
|
||||
const initEvent = {
|
||||
|
||||
const initEvent: Record<string, unknown> = {
|
||||
action: ActionType.INIT,
|
||||
args: settings,
|
||||
};
|
||||
if (token) {
|
||||
initEvent.token = token;
|
||||
}
|
||||
if (ghToken) {
|
||||
initEvent.github_token = ghToken;
|
||||
}
|
||||
if (events.length) {
|
||||
// Wrong. Events is out of sync here...
|
||||
initEvent.latest_event_id = `${events[events.length - 1].id}`;
|
||||
}
|
||||
send(initEvent);
|
||||
}
|
||||
|
||||
function handleMessage(messageEvent: MessageEvent) {
|
||||
const event = JSON.parse(messageEvent.data);
|
||||
function handleMessage(event: Record<string, unknown>) {
|
||||
setEvents((prevEvents) => [...prevEvents, event]);
|
||||
if (event.extras?.agent_state === AgentState.INIT) {
|
||||
setStatus(WsClientProviderStatus.ACTIVE);
|
||||
}
|
||||
const extras = event.extras as Record<string, unknown>;
|
||||
if (
|
||||
status !== WsClientProviderStatus.ACTIVE &&
|
||||
event?.observation === "error"
|
||||
@@ -82,93 +88,98 @@ export function WsClientProvider({
|
||||
setStatus(WsClientProviderStatus.ERROR);
|
||||
}
|
||||
|
||||
handleAssistantMessage(event);
|
||||
}
|
||||
|
||||
function handleClose() {
|
||||
if (retryCount) {
|
||||
setTimeout(() => {
|
||||
setRetryCount(retryCount - 1);
|
||||
}, 1000);
|
||||
if (event.token) {
|
||||
setStatus(WsClientProviderStatus.ACTIVE);
|
||||
} else {
|
||||
setStatus(WsClientProviderStatus.STOPPED);
|
||||
setEvents([]);
|
||||
handleAssistantMessage(event);
|
||||
}
|
||||
wsRef.current = null;
|
||||
}
|
||||
|
||||
function handleError(event: Event) {
|
||||
function handleDisconnect() {
|
||||
setStatus(WsClientProviderStatus.STOPPED);
|
||||
// setEvents([]);
|
||||
// sioRef.current = null;
|
||||
}
|
||||
|
||||
function handleError() {
|
||||
posthog.capture("socket_error");
|
||||
EventLogger.event(event, "SOCKET ERROR");
|
||||
setStatus(WsClientProviderStatus.ERROR);
|
||||
sioRef.current?.disconnect();
|
||||
}
|
||||
|
||||
// Connect websocket
|
||||
React.useEffect(() => {
|
||||
let ws = wsRef.current;
|
||||
let sio = sioRef.current;
|
||||
|
||||
// If disabled close any existing websockets...
|
||||
if (!enabled || !retryCount) {
|
||||
if (ws) {
|
||||
ws.close();
|
||||
// If disabled disconnect any existing websockets...
|
||||
if (!enabled) {
|
||||
if (sio) {
|
||||
sio.disconnect();
|
||||
}
|
||||
wsRef.current = null;
|
||||
return () => {};
|
||||
}
|
||||
|
||||
// If there is no websocket or the tokens have changed or the current websocket is closed,
|
||||
// If there is no websocket or the tokens have changed or the current websocket is disconnected,
|
||||
// create a new one
|
||||
if (
|
||||
!ws ||
|
||||
!sio ||
|
||||
(tokenRef.current && token !== tokenRef.current) ||
|
||||
ghToken !== ghTokenRef.current ||
|
||||
ws.readyState === WebSocket.CLOSED ||
|
||||
ws.readyState === WebSocket.CLOSING
|
||||
ghToken !== ghTokenRef.current
|
||||
) {
|
||||
ws?.close();
|
||||
sio?.disconnect();
|
||||
|
||||
const baseUrl =
|
||||
import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host;
|
||||
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
|
||||
let wsUrl = `${protocol}//${baseUrl}/ws`;
|
||||
if (events.length) {
|
||||
wsUrl += `?latest_event_id=${events[events.length - 1].id}`;
|
||||
}
|
||||
ws = new WebSocket(wsUrl, [
|
||||
"openhands",
|
||||
token || "NO_JWT",
|
||||
ghToken || "NO_GITHUB",
|
||||
]);
|
||||
sio = io(baseUrl, {
|
||||
transports: ["websocket"],
|
||||
// extraHeaders: {
|
||||
// Testy: "TESTER"
|
||||
// },
|
||||
// We force a new connection, because the headers may have changed.
|
||||
// forceNew: true,
|
||||
|
||||
// Had to do this for now because reconnection actually starts a new session,
|
||||
// which we don't want - The reconnect has the same headers as the original
|
||||
// which don't include the original session id
|
||||
// reconnection: false,
|
||||
// reconnectionDelay: 1000,
|
||||
// reconnectionDelayMax : 5000,
|
||||
// reconnectionAttempts: 5
|
||||
});
|
||||
}
|
||||
ws.addEventListener("open", handleOpen);
|
||||
ws.addEventListener("message", handleMessage);
|
||||
ws.addEventListener("error", handleError);
|
||||
ws.addEventListener("close", handleClose);
|
||||
wsRef.current = ws;
|
||||
sio.on("connect", handleConnect);
|
||||
sio.on("oh_event", handleMessage);
|
||||
sio.on("connect_error", handleError);
|
||||
sio.on("connect_failed", handleError);
|
||||
sio.on("disconnect", handleDisconnect);
|
||||
|
||||
sioRef.current = sio;
|
||||
tokenRef.current = token;
|
||||
ghTokenRef.current = ghToken;
|
||||
|
||||
return () => {
|
||||
ws.removeEventListener("open", handleOpen);
|
||||
ws.removeEventListener("message", handleMessage);
|
||||
ws.removeEventListener("error", handleError);
|
||||
ws.removeEventListener("close", handleClose);
|
||||
sio.off("connect", handleConnect);
|
||||
sio.off("oh_event", handleMessage);
|
||||
sio.off("connect_error", handleError);
|
||||
sio.off("connect_failed", handleError);
|
||||
sio.off("disconnect", handleDisconnect);
|
||||
};
|
||||
}, [enabled, token, ghToken, retryCount]);
|
||||
}, [enabled, token, ghToken, events]);
|
||||
|
||||
// Strict mode mounts and unmounts each component twice, so we have to wait in the destructor
|
||||
// before actually closing the socket and cancel the operation if the component gets remounted.
|
||||
// before actually disconnecting the socket and cancel the operation if the component gets remounted.
|
||||
React.useEffect(() => {
|
||||
const timeout = closeRef.current;
|
||||
const timeout = disconnectRef.current;
|
||||
if (timeout != null) {
|
||||
clearTimeout(timeout);
|
||||
}
|
||||
|
||||
return () => {
|
||||
closeRef.current = setTimeout(() => {
|
||||
const ws = wsRef.current;
|
||||
if (ws) {
|
||||
ws.removeEventListener("close", handleClose);
|
||||
ws.close();
|
||||
disconnectRef.current = setTimeout(() => {
|
||||
const sio = sioRef.current;
|
||||
if (sio) {
|
||||
sio.off("disconnect", handleDisconnect);
|
||||
sio.disconnect();
|
||||
}
|
||||
}, 100);
|
||||
};
|
||||
|
||||
@@ -12,6 +12,18 @@ import CodeEditorCompoonent from "./code-editor-component";
|
||||
import { useFiles } from "#/context/files";
|
||||
import { EditorActions } from "#/components/editor-actions";
|
||||
|
||||
const ASSET_FILE_TYPES = [
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".bmp",
|
||||
".gif",
|
||||
".pdf",
|
||||
".mp4",
|
||||
".webm",
|
||||
".ogg",
|
||||
];
|
||||
|
||||
export const clientLoader = async () => {
|
||||
const token = localStorage.getItem("token");
|
||||
return json({ token });
|
||||
@@ -104,6 +116,10 @@ function CodeEditor() {
|
||||
if (selectedPath) discardChanges(selectedPath);
|
||||
};
|
||||
|
||||
const isAssetFileType = selectedPath
|
||||
? ASSET_FILE_TYPES.some((ext) => selectedPath.endsWith(ext))
|
||||
: false;
|
||||
|
||||
return (
|
||||
<div className="flex h-full bg-neutral-900 relative">
|
||||
<FileExplorer
|
||||
@@ -112,7 +128,7 @@ function CodeEditor() {
|
||||
error={errors.getFiles}
|
||||
/>
|
||||
<div className="w-full">
|
||||
{selectedPath && (
|
||||
{selectedPath && !isAssetFileType && (
|
||||
<div className="flex w-full items-center justify-between self-end p-2">
|
||||
<span className="text-sm text-neutral-500">{selectedPath}</span>
|
||||
<EditorActions
|
||||
|
||||
@@ -82,6 +82,13 @@ export default defineConfig(({ mode }) => {
|
||||
changeOrigin: true,
|
||||
secure: !INSECURE_SKIP_VERIFY,
|
||||
},
|
||||
"/socket.io": {
|
||||
target: WS_URL,
|
||||
ws: true,
|
||||
changeOrigin: true,
|
||||
secure: !INSECURE_SKIP_VERIFY,
|
||||
//rewriteWsOrigin: true,
|
||||
}
|
||||
},
|
||||
},
|
||||
ssr: {
|
||||
|
||||
@@ -5,6 +5,7 @@ import traceback
|
||||
from typing import Callable, ClassVar, Type
|
||||
|
||||
import litellm
|
||||
from litellm.exceptions import ContextWindowExceededError
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State, TrafficControlState
|
||||
@@ -485,6 +486,15 @@ class AgentController:
|
||||
EventSource.AGENT,
|
||||
)
|
||||
return
|
||||
except ContextWindowExceededError:
|
||||
# When context window is exceeded, keep roughly half of agent interactions
|
||||
self.state.history = self._apply_conversation_window(self.state.history)
|
||||
|
||||
# Save the ID of the first event in our truncated history for future reloading
|
||||
if self.state.history:
|
||||
self.state.start_id = self.state.history[0].id
|
||||
# Don't add error event - let the agent retry with reduced context
|
||||
return
|
||||
|
||||
if action.runnable:
|
||||
if self.state.confirmation_mode and (
|
||||
@@ -659,6 +669,12 @@ class AgentController:
|
||||
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
|
||||
- Excludes all events between the action and observation
|
||||
- Includes the delegate action and observation themselves
|
||||
|
||||
The history is loaded in two parts if truncation_id is set:
|
||||
1. First user message from start_id onwards
|
||||
2. Rest of history from truncation_id to the end
|
||||
|
||||
Otherwise loads normally from start_id.
|
||||
"""
|
||||
|
||||
# define range of events to fetch
|
||||
@@ -680,8 +696,33 @@ class AgentController:
|
||||
self.state.history = []
|
||||
return
|
||||
|
||||
# Get all events, filtering out backend events and hidden events
|
||||
events = list(
|
||||
events: list[Event] = []
|
||||
|
||||
# If we have a truncation point, get first user message and then rest of history
|
||||
if hasattr(self.state, 'truncation_id') and self.state.truncation_id > 0:
|
||||
# Find first user message from stream
|
||||
first_user_msg = next(
|
||||
(
|
||||
e
|
||||
for e in self.event_stream.get_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=False,
|
||||
filter_out_type=self.filter_out,
|
||||
filter_hidden=True,
|
||||
)
|
||||
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
)
|
||||
if first_user_msg:
|
||||
events.append(first_user_msg)
|
||||
|
||||
# the rest of the events are from the truncation point
|
||||
start_id = self.state.truncation_id
|
||||
|
||||
# Get rest of history
|
||||
events_to_add = list(
|
||||
self.event_stream.get_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
@@ -690,6 +731,7 @@ class AgentController:
|
||||
filter_hidden=True,
|
||||
)
|
||||
)
|
||||
events.extend(events_to_add)
|
||||
|
||||
# Find all delegate action/observation pairs
|
||||
delegate_ranges: list[tuple[int, int]] = []
|
||||
@@ -744,6 +786,92 @@ class AgentController:
|
||||
# make sure history is in sync
|
||||
self.state.start_id = start_id
|
||||
|
||||
def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
|
||||
"""Cuts history roughly in half when context window is exceeded, preserving action-observation pairs
|
||||
and ensuring the first user message is always included.
|
||||
|
||||
The algorithm:
|
||||
1. Cut history in half
|
||||
2. Check first event in new history:
|
||||
- If Observation: find and include its Action
|
||||
- If MessageAction: ensure its related Action-Observation pair isn't split
|
||||
3. Always include the first user message
|
||||
|
||||
Args:
|
||||
events: List of events to filter
|
||||
|
||||
Returns:
|
||||
Filtered list of events keeping newest half while preserving pairs
|
||||
"""
|
||||
if not events:
|
||||
return events
|
||||
|
||||
# Find first user message - we'll need to ensure it's included
|
||||
first_user_msg = next(
|
||||
(
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# cut in half
|
||||
mid_point = max(1, len(events) // 2)
|
||||
kept_events = events[mid_point:]
|
||||
|
||||
# Handle first event in truncated history
|
||||
if kept_events:
|
||||
i = 0
|
||||
while i < len(kept_events):
|
||||
first_event = kept_events[i]
|
||||
if isinstance(first_event, Observation) and first_event.cause:
|
||||
# Find its action and include it
|
||||
matching_action = next(
|
||||
(
|
||||
e
|
||||
for e in reversed(events[:mid_point])
|
||||
if isinstance(e, Action) and e.id == first_event.cause
|
||||
),
|
||||
None,
|
||||
)
|
||||
if matching_action:
|
||||
kept_events = [matching_action] + kept_events
|
||||
else:
|
||||
self.log(
|
||||
'warning',
|
||||
f'Found Observation without matching Action at id={first_event.id}',
|
||||
)
|
||||
# drop this observation
|
||||
kept_events = kept_events[1:]
|
||||
break
|
||||
|
||||
elif isinstance(first_event, MessageAction) or (
|
||||
isinstance(first_event, Action)
|
||||
and first_event.source == EventSource.USER
|
||||
):
|
||||
# if it's a message action or a user action, keep it and continue to find the next event
|
||||
i += 1
|
||||
continue
|
||||
|
||||
else:
|
||||
# if it's an action with source == EventSource.AGENT, we're good
|
||||
break
|
||||
|
||||
# Save where to continue from in next reload
|
||||
if kept_events:
|
||||
self.state.truncation_id = kept_events[0].id
|
||||
|
||||
# Ensure first user message is included
|
||||
if first_user_msg and first_user_msg not in kept_events:
|
||||
kept_events = [first_user_msg] + kept_events
|
||||
|
||||
# start_id points to first user message
|
||||
if first_user_msg:
|
||||
self.state.start_id = first_user_msg.id
|
||||
|
||||
return kept_events
|
||||
|
||||
def _is_stuck(self):
|
||||
"""Checks if the agent or its delegate is stuck in a loop.
|
||||
|
||||
|
||||
@@ -92,6 +92,8 @@ class State:
|
||||
# start_id and end_id track the range of events in history
|
||||
start_id: int = -1
|
||||
end_id: int = -1
|
||||
# truncation_id tracks where to load history after context window truncation
|
||||
truncation_id: int = -1
|
||||
almost_stuck: int = 0
|
||||
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
|
||||
# NOTE: This will never be used by the controller, but it can be used by different
|
||||
|
||||
@@ -111,9 +111,6 @@ class LogBuffer:
|
||||
def close(self, timeout: float = 5.0):
|
||||
self._stop_event.set()
|
||||
self.log_stream_thread.join(timeout)
|
||||
# Close the log generator to release the file descriptor
|
||||
if hasattr(self.log_generator, 'close'):
|
||||
self.log_generator.close()
|
||||
|
||||
|
||||
class EventStreamRuntime(Runtime):
|
||||
@@ -235,8 +232,6 @@ class EventStreamRuntime(Runtime):
|
||||
f'Container started: {self.container_name}. VSCode URL: {self.vscode_url}',
|
||||
)
|
||||
|
||||
self.log_buffer = LogBuffer(self.container, self.log)
|
||||
|
||||
if not self.attach_to_existing:
|
||||
self.log('info', f'Waiting for client to become ready at {self.api_url}...')
|
||||
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
|
||||
@@ -363,6 +358,7 @@ class EventStreamRuntime(Runtime):
|
||||
environment=environment,
|
||||
volumes=volumes,
|
||||
)
|
||||
self.log_buffer = LogBuffer(self.container, self.log)
|
||||
self.log('debug', f'Container started. Server url: {self.api_url}')
|
||||
self.send_status_message('STATUS$CONTAINER_STARTED')
|
||||
except docker.errors.APIError as e:
|
||||
@@ -389,9 +385,11 @@ class EventStreamRuntime(Runtime):
|
||||
raise e
|
||||
|
||||
def _attach_to_container(self):
|
||||
container = self.docker_client.containers.get(self.container_name)
|
||||
self.log_buffer = LogBuffer(container, self.log)
|
||||
self.container = container
|
||||
self._container_port = 0
|
||||
self.container = self.docker_client.containers.get(self.container_name)
|
||||
for port in self.container.attrs['NetworkSettings']['Ports']: # type: ignore
|
||||
for port in container.attrs['NetworkSettings']['Ports']:
|
||||
self._container_port = int(port.split('/')[0])
|
||||
break
|
||||
self._host_port = self._container_port
|
||||
|
||||
@@ -115,15 +115,13 @@ async def get_github_user(token: str) -> str:
|
||||
github handle of the user
|
||||
"""
|
||||
logger.debug('Fetching GitHub user info from token')
|
||||
g = Github(token)
|
||||
try:
|
||||
g = Github(token)
|
||||
user = await call_sync_from_async(g.get_user)
|
||||
login = user.login
|
||||
logger.info(f'Successfully retrieved GitHub user: {login}')
|
||||
return login
|
||||
except GithubException as e:
|
||||
logger.error(f'Error making request to GitHub API: {str(e)}')
|
||||
logger.error(e)
|
||||
raise
|
||||
finally:
|
||||
g.close()
|
||||
login = user.login
|
||||
logger.info(f'Successfully retrieved GitHub user: {login}')
|
||||
return login
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
import socketio
|
||||
from pathspec import PathSpec
|
||||
from pathspec.patterns import GitWildMatchPattern
|
||||
|
||||
from openhands.core.schema.action import ActionType
|
||||
from openhands.security.options import SecurityAnalyzers
|
||||
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
|
||||
from openhands.server.github import (
|
||||
@@ -19,8 +19,9 @@ from openhands.server.github import (
|
||||
UserVerifier,
|
||||
authenticate_github_user,
|
||||
)
|
||||
from openhands.server.pg_socket import AsyncPostgresAdapter
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.async_utils import call_async_from_sync, call_sync_from_async
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
@@ -33,7 +34,6 @@ from fastapi import (
|
||||
HTTPException,
|
||||
Request,
|
||||
UploadFile,
|
||||
WebSocket,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
@@ -52,7 +52,6 @@ from openhands.events.action import (
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
AgentStateChangedObservation,
|
||||
ErrorObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
@@ -250,122 +249,6 @@ async def attach_session(request: Request, call_next):
|
||||
return response
|
||||
|
||||
|
||||
@app.websocket('/ws')
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
"""WebSocket endpoint for receiving events from the client (i.e., the browser).
|
||||
Once connected, the client can send various actions:
|
||||
- Initialize the agent:
|
||||
session management, and event streaming.
|
||||
```json
|
||||
{"action": "initialize", "args": {"LLM_MODEL": "ollama/llama3", "AGENT": "CodeActAgent", "LANGUAGE": "en", "LLM_API_KEY": "ollama"}}
|
||||
|
||||
Args:
|
||||
```
|
||||
websocket (WebSocket): The WebSocket connection object.
|
||||
- Start a new development task:
|
||||
```json
|
||||
{"action": "start", "args": {"task": "write a bash script that prints hello"}}
|
||||
```
|
||||
- Send a message:
|
||||
```json
|
||||
{"action": "message", "args": {"content": "Hello, how are you?", "image_urls": ["base64_url1", "base64_url2"]}}
|
||||
```
|
||||
- Write contents to a file:
|
||||
```json
|
||||
{"action": "write", "args": {"path": "./greetings.txt", "content": "Hello, OpenHands?"}}
|
||||
```
|
||||
- Read the contents of a file:
|
||||
```json
|
||||
{"action": "read", "args": {"path": "./greetings.txt"}}
|
||||
```
|
||||
- Run a command:
|
||||
```json
|
||||
{"action": "run", "args": {"command": "ls -l", "thought": "", "confirmation_state": "confirmed"}}
|
||||
```
|
||||
- Run an IPython command:
|
||||
```json
|
||||
{"action": "run_ipython", "args": {"command": "print('Hello, IPython!')"}}
|
||||
```
|
||||
- Open a web page:
|
||||
```json
|
||||
{"action": "browse", "args": {"url": "https://arxiv.org/html/2402.01030v2"}}
|
||||
```
|
||||
- Add a task to the root_task:
|
||||
```json
|
||||
{"action": "add_task", "args": {"task": "Implement feature X"}}
|
||||
```
|
||||
- Update a task in the root_task:
|
||||
```json
|
||||
{"action": "modify_task", "args": {"id": "0", "state": "in_progress", "thought": ""}}
|
||||
```
|
||||
- Change the agent's state:
|
||||
```json
|
||||
{"action": "change_agent_state", "args": {"state": "paused"}}
|
||||
```
|
||||
- Finish the task:
|
||||
```json
|
||||
{"action": "finish", "args": {}}
|
||||
```
|
||||
"""
|
||||
# Get protocols from Sec-WebSocket-Protocol header
|
||||
protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
|
||||
|
||||
# The first protocol should be our real protocol (e.g. 'openhands')
|
||||
# The second protocol should contain our auth token
|
||||
if len(protocols) < 3:
|
||||
logger.error('Expected 3 websocket protocols, got %d', len(protocols))
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
real_protocol = protocols[0]
|
||||
jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
|
||||
github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
|
||||
|
||||
if not await authenticate_github_user(github_token):
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10)
|
||||
|
||||
if jwt_token:
|
||||
sid = get_sid_from_token(jwt_token, config.jwt_secret)
|
||||
|
||||
if sid == '':
|
||||
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
|
||||
await websocket.close()
|
||||
return
|
||||
else:
|
||||
sid = str(uuid.uuid4())
|
||||
jwt_token = sign_token({'sid': sid}, config.jwt_secret)
|
||||
|
||||
logger.info(f'New session: {sid}')
|
||||
session = session_manager.add_or_restart_session(sid, websocket)
|
||||
await websocket.send_json({'token': jwt_token, 'status': 'ok'})
|
||||
|
||||
latest_event_id = -1
|
||||
if websocket.query_params.get('latest_event_id'):
|
||||
latest_event_id = int(websocket.query_params.get('latest_event_id'))
|
||||
|
||||
async_stream = AsyncEventStreamWrapper(
|
||||
session.agent_session.event_stream, latest_event_id + 1
|
||||
)
|
||||
|
||||
async for event in async_stream:
|
||||
if isinstance(
|
||||
event,
|
||||
(
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
),
|
||||
):
|
||||
continue
|
||||
await websocket.send_json(event_to_dict(event))
|
||||
|
||||
await session.loop_recv()
|
||||
|
||||
|
||||
@app.get('/api/options/models')
|
||||
async def get_litellm_models() -> list[str]:
|
||||
"""
|
||||
@@ -930,3 +813,135 @@ class SPAStaticFiles(StaticFiles):
|
||||
|
||||
|
||||
app.mount('/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist')
|
||||
|
||||
use_manager = os.getenv('DB_HOST') or os.getenv('GCP_DB_INSTANCE')
|
||||
manager = AsyncPostgresAdapter() if use_manager else None
|
||||
if manager:
|
||||
call_async_from_sync(manager.setup, 10)
|
||||
sio = socketio.AsyncServer(
|
||||
async_mode='asgi', cors_allowed_origins='*', client_manager=manager
|
||||
)
|
||||
app = socketio.ASGIApp(sio, other_asgi_app=app)
|
||||
|
||||
|
||||
@sio.event
|
||||
async def connect(connection_id: str, environ):
|
||||
logger.info(f'sio:connect: {connection_id}')
|
||||
|
||||
|
||||
@sio.event
|
||||
async def oh_action(connection_id: str, data: dict):
|
||||
"""WebSocket endpoint for receiving events from the client (i.e., the browser).
|
||||
Once connected, the client can send various actions:
|
||||
- Initialize the agent:
|
||||
session management, and event streaming.
|
||||
```json
|
||||
{"action": "initialize", "args": {"LLM_MODEL": "ollama/llama3", "AGENT": "CodeActAgent", "LANGUAGE": "en", "LLM_API_KEY": "ollama"}}
|
||||
|
||||
Args:
|
||||
```
|
||||
websocket (WebSocket): The WebSocket connection object.
|
||||
- Start a new development task:
|
||||
```json
|
||||
{"action": "start", "args": {"task": "write a bash script that prints hello"}}
|
||||
```
|
||||
- Send a message:
|
||||
```json
|
||||
{"action": "message", "args": {"content": "Hello, how are you?", "image_urls": ["base64_url1", "base64_url2"]}}
|
||||
```
|
||||
- Write contents to a file:
|
||||
```json
|
||||
{"action": "write", "args": {"path": "./greetings.txt", "content": "Hello, OpenHands?"}}
|
||||
```
|
||||
- Read the contents of a file:
|
||||
```json
|
||||
{"action": "read", "args": {"path": "./greetings.txt"}}
|
||||
```
|
||||
- Run a command:
|
||||
```json
|
||||
{"action": "run", "args": {"command": "ls -l", "thought": "", "confirmation_state": "confirmed"}}
|
||||
```
|
||||
- Run an IPython command:
|
||||
```json
|
||||
{"action": "run_ipython", "args": {"command": "print('Hello, IPython!')"}}
|
||||
```
|
||||
- Open a web page:
|
||||
```json
|
||||
{"action": "browse", "args": {"url": "https://arxiv.org/html/2402.01030v2"}}
|
||||
```
|
||||
- Add a task to the root_task:
|
||||
```json
|
||||
{"action": "add_task", "args": {"task": "Implement feature X"}}
|
||||
```
|
||||
- Update a task in the root_task:
|
||||
```json
|
||||
{"action": "modify_task", "args": {"id": "0", "state": "in_progress", "thought": ""}}
|
||||
```
|
||||
- Change the agent's state:
|
||||
```json
|
||||
{"action": "change_agent_state", "args": {"state": "paused"}}
|
||||
```
|
||||
- Finish the task:
|
||||
```json
|
||||
{"action": "finish", "args": {}}
|
||||
```
|
||||
"""
|
||||
|
||||
# If it's an init, we do it here.
|
||||
action = data.get('action', '')
|
||||
if action == ActionType.INIT:
|
||||
await init_connection(connection_id, data)
|
||||
return
|
||||
|
||||
logger.info(f'sio:oh_action:{connection_id}')
|
||||
session = session_manager.get_local_session(connection_id)
|
||||
await session.dispatch(data)
|
||||
|
||||
|
||||
async def init_connection(connection_id: str, data: dict):
|
||||
gh_token = data.pop('gh_token', None)
|
||||
if not await authenticate_github_user(gh_token):
|
||||
raise RuntimeError(status.WS_1008_POLICY_VIOLATION)
|
||||
|
||||
token = data.pop('token', None)
|
||||
if token:
|
||||
sid = get_sid_from_token(token, config.jwt_secret)
|
||||
if sid == '':
|
||||
await sio.send({'error': 'Invalid token', 'error_code': 401})
|
||||
return
|
||||
logger.info(f'Existing session: {sid}')
|
||||
else:
|
||||
sid = connection_id
|
||||
logger.info(f'New session: {sid}')
|
||||
|
||||
token = sign_token({'sid': sid}, config.jwt_secret)
|
||||
await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id)
|
||||
|
||||
latest_event_id = int(data.pop('latest_event_id', -1))
|
||||
|
||||
# The session in question should exist, but may not actually be running locally...
|
||||
session = await session_manager.init_or_join_local_session(
|
||||
sio, sid, connection_id, data
|
||||
)
|
||||
|
||||
# Send events
|
||||
async_stream = AsyncEventStreamWrapper(
|
||||
session.agent_session.event_stream, latest_event_id + 1
|
||||
)
|
||||
async for event in async_stream:
|
||||
if isinstance(
|
||||
event,
|
||||
(
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
),
|
||||
):
|
||||
continue
|
||||
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
|
||||
|
||||
|
||||
@sio.event
|
||||
async def disconnect(connection_id: str):
|
||||
logger.info(f'sio:disconnect:{connection_id}')
|
||||
await session_manager.disconnect_from_local_session(connection_id)
|
||||
|
||||
242
openhands/server/pg_socket.py
Normal file
242
openhands/server/pg_socket.py
Normal file
@@ -0,0 +1,242 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
from asyncio import Task
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import asyncpg
|
||||
from asyncpg.exceptions import PostgresError
|
||||
from asyncpg.pool import Pool
|
||||
from google.cloud.sql.connector import Connector
|
||||
from socketio.async_pubsub_manager import AsyncPubSubManager
|
||||
|
||||
|
||||
def has_binary(obj: Any, to_json: bool = False) -> bool:
|
||||
if not obj or not isinstance(obj, (dict, list, bytes, bytearray, memoryview)):
|
||||
return False
|
||||
|
||||
if isinstance(obj, (bytes, bytearray, memoryview)):
|
||||
return True
|
||||
|
||||
if isinstance(obj, list):
|
||||
return any(has_binary(item) for item in obj)
|
||||
|
||||
if isinstance(obj, dict):
|
||||
return any(has_binary(v) for v in obj.values())
|
||||
|
||||
if hasattr(obj, 'to_json') and callable(obj.to_json) and not to_json:
|
||||
return has_binary(obj.to_json(), True)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class AsyncPostgresAdapter(AsyncPubSubManager):
|
||||
def __init__(
|
||||
self,
|
||||
channel: str = 'socketio',
|
||||
table_name: str = 'socket_io_attachments',
|
||||
payload_threshold: int = 8000,
|
||||
cleanup_interval: int = 30000,
|
||||
):
|
||||
self.channel = channel
|
||||
super().__init__(channel=channel)
|
||||
self.table_name = table_name
|
||||
self.payload_threshold = payload_threshold
|
||||
self.cleanup_interval = cleanup_interval
|
||||
self._cleanup_timer: Optional[Task] = None
|
||||
self._client: Optional[asyncpg.Connection] = None
|
||||
self.pool: Optional[Pool] = None
|
||||
|
||||
# Connection configs
|
||||
self.db_host = os.environ.get('DB_HOST')
|
||||
self.db_user = os.environ.get('DB_USER')
|
||||
self.db_pass = os.environ.get('DB_PASS', '').strip()
|
||||
self.db_name = os.environ.get('DB_NAME')
|
||||
|
||||
# GCP configs
|
||||
self.gcp_instance = os.environ.get('GCP_DB_INSTANCE')
|
||||
self.gcp_project = os.environ.get('GCP_PROJECT')
|
||||
self.gcp_region = os.environ.get('GCP_REGION')
|
||||
|
||||
self.connector = Connector() if self.gcp_instance else None
|
||||
|
||||
async def setup(self):
|
||||
if not self.pool:
|
||||
self.pool = await self._create_pool()
|
||||
await self._create_db()
|
||||
await self._init_client()
|
||||
|
||||
async def _create_pool(self) -> Pool:
|
||||
if self.gcp_instance:
|
||||
return await asyncpg.create_pool(
|
||||
lambda: self._get_gcp_connection(self.db_name)
|
||||
)
|
||||
else:
|
||||
return await asyncpg.create_pool(
|
||||
user=self.db_user,
|
||||
password=self.db_pass,
|
||||
host=self.db_host,
|
||||
database=self.db_name,
|
||||
)
|
||||
|
||||
async def _get_gcp_connection(
|
||||
self, db_name: Optional[str] = None
|
||||
) -> asyncpg.Connection:
|
||||
instance_string = f'{self.gcp_project}:{self.gcp_region}:{self.gcp_instance}'
|
||||
|
||||
conn = await self.connector.connect_async(
|
||||
instance_connection_string=instance_string,
|
||||
driver='asyncpg',
|
||||
user=self.db_user,
|
||||
password=self.db_pass,
|
||||
db=db_name or self.db_name,
|
||||
)
|
||||
return conn
|
||||
|
||||
async def _init_client(self) -> None:
|
||||
if not self.pool:
|
||||
raise RuntimeError('Pool not initialized')
|
||||
|
||||
try:
|
||||
self._client = await self.pool.acquire()
|
||||
await self._client.execute(f'LISTEN "{self.channel}"')
|
||||
|
||||
self._client.add_listener(self.channel, self._on_notification)
|
||||
|
||||
if not self._cleanup_timer:
|
||||
self._schedule_cleanup()
|
||||
|
||||
except PostgresError as e:
|
||||
self.logger.error(f'Error initializing client: {e}')
|
||||
await asyncio.sleep(2)
|
||||
await self._init_client()
|
||||
|
||||
def _schedule_cleanup(self) -> None:
|
||||
async def cleanup() -> None:
|
||||
if not self.pool:
|
||||
return
|
||||
|
||||
try:
|
||||
await self.pool.execute(
|
||||
f"DELETE FROM {self.table_name} WHERE created_at < now() - interval '{self.cleanup_interval} milliseconds'"
|
||||
)
|
||||
except PostgresError as e:
|
||||
self.logger.error(f'Cleanup error: {e}')
|
||||
|
||||
self._cleanup_timer = asyncio.create_task(cleanup())
|
||||
await asyncio.sleep(self.cleanup_interval / 1000)
|
||||
|
||||
self._cleanup_timer = asyncio.create_task(cleanup())
|
||||
|
||||
async def _publish_with_attachment(self, data: Dict) -> None:
|
||||
if not self.pool:
|
||||
raise RuntimeError('Pool not initialized')
|
||||
|
||||
payload = pickle.dumps(data)
|
||||
result = await self.pool.fetchrow(
|
||||
f'INSERT INTO {self.table_name} (payload) VALUES ($1) RETURNING id', payload
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError('Failed to insert payload')
|
||||
|
||||
notification = {
|
||||
'uid': self.uid,
|
||||
'type': data['type'],
|
||||
'attachmentId': result['id'],
|
||||
}
|
||||
await self.pool.execute(
|
||||
'SELECT pg_notify($1, $2)', self.channel, json.dumps(notification)
|
||||
)
|
||||
|
||||
async def _publish(self, data: Dict) -> None:
|
||||
if not self.pool:
|
||||
raise RuntimeError('Pool not initialized')
|
||||
|
||||
try:
|
||||
data['uid'] = self.uid
|
||||
|
||||
if has_binary(data) or len(json.dumps(data)) > self.payload_threshold:
|
||||
await self._publish_with_attachment(data)
|
||||
return
|
||||
|
||||
await self.pool.execute(
|
||||
'SELECT pg_notify($1, $2)', self.channel, json.dumps(data)
|
||||
)
|
||||
|
||||
except PostgresError as e:
|
||||
self.logger.error(f'Publish error: {e}')
|
||||
raise
|
||||
|
||||
async def _on_notification(self, conn, pid, channel, payload) -> None:
|
||||
if not self.pool:
|
||||
return
|
||||
|
||||
try:
|
||||
data = json.loads(payload)
|
||||
|
||||
if data.get('uid') == self.uid:
|
||||
return
|
||||
|
||||
if 'attachmentId' in data:
|
||||
result = await self.pool.fetchrow(
|
||||
f'SELECT payload FROM {self.table_name} WHERE id = $1',
|
||||
data['attachmentId'],
|
||||
)
|
||||
if not result:
|
||||
self.logger.error(f"Attachment {data['attachmentId']} not found")
|
||||
return
|
||||
|
||||
data = pickle.loads(result['payload'])
|
||||
|
||||
await self._handle_message(data)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f'Notification error: {e}')
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._cleanup_timer:
|
||||
self._cleanup_timer.cancel()
|
||||
if self._client and self.pool:
|
||||
await self.pool.release(self._client)
|
||||
if self.connector:
|
||||
await self.connector.close_async()
|
||||
await super().close()
|
||||
|
||||
async def _create_db(self) -> None:
|
||||
try:
|
||||
# Connect to default postgres DB first
|
||||
if self.gcp_instance:
|
||||
sys_conn = await self._get_gcp_connection('postgres')
|
||||
else:
|
||||
sys_conn = await asyncpg.connect(
|
||||
user=self.db_user,
|
||||
password=self.db_pass,
|
||||
host=self.db_host,
|
||||
database='postgres',
|
||||
)
|
||||
|
||||
try:
|
||||
# Create DB if needed
|
||||
exists = await sys_conn.fetchval(
|
||||
'SELECT 1 FROM pg_database WHERE datname = $1', self.db_name
|
||||
)
|
||||
if not exists:
|
||||
await sys_conn.execute(f'CREATE DATABASE "{self.db_name}"')
|
||||
finally:
|
||||
await sys_conn.close()
|
||||
|
||||
# Create attachments table
|
||||
if not self.pool:
|
||||
raise RuntimeError('Pool not initialized')
|
||||
|
||||
await self.pool.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id SERIAL PRIMARY KEY,
|
||||
payload BYTEA NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
except Exception as e:
|
||||
self.logger.error(f'Database creation error: {e}')
|
||||
raise
|
||||
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fastapi import WebSocket
|
||||
import socketio
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -15,11 +17,8 @@ from openhands.storage.files import FileStore
|
||||
class SessionManager:
|
||||
config: AppConfig
|
||||
file_store: FileStore
|
||||
|
||||
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
|
||||
return Session(
|
||||
sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
|
||||
)
|
||||
local_sessions_by_sid: dict[str, Session] = field(default_factory=dict)
|
||||
local_sessions_by_connection_id: dict[str, Session] = field(default_factory=dict)
|
||||
|
||||
async def attach_to_conversation(self, sid: str) -> Conversation | None:
|
||||
start_time = time.time()
|
||||
@@ -35,3 +34,41 @@ class SessionManager:
|
||||
|
||||
async def detach_from_conversation(self, conversation: Conversation):
|
||||
await conversation.disconnect()
|
||||
|
||||
async def init_or_join_local_session(self, sio: socketio.AsyncServer, sid: str, connection_id: str, data: dict):
|
||||
""" If there is no local session running, initialize one """
|
||||
session = self.local_sessions_by_sid.get(sid)
|
||||
if not session:
|
||||
session = Session(
|
||||
sid=sid, file_store=self.file_store, config=self.config, sio=sio, ws=None
|
||||
)
|
||||
session.connect(connection_id)
|
||||
self.local_sessions_by_sid[sid] = session
|
||||
self.local_sessions_by_connection_id[connection_id] = session
|
||||
await session.initialize_agent(data)
|
||||
else:
|
||||
session.connect(connection_id)
|
||||
self.local_sessions_by_connection_id[connection_id] = session
|
||||
return session
|
||||
|
||||
def get_local_session(self, connection_id: str) -> Session:
|
||||
return self.local_sessions_by_connection_id[connection_id]
|
||||
|
||||
async def disconnect_from_local_session(self, connection_id: str):
|
||||
session = self.local_sessions_by_connection_id.pop(connection_id, None)
|
||||
if not session:
|
||||
# This can occur if the init action was never run.
|
||||
logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
|
||||
return
|
||||
if session.disconnect(connection_id):
|
||||
asyncio.create_task(self._check_and_close_session(session))
|
||||
|
||||
async def _check_and_close_session(self, session: Session):
|
||||
# Once there have been no connections to a session for a reasonable period, we close it
|
||||
try:
|
||||
await asyncio.sleep(15)
|
||||
finally:
|
||||
# If the sleep was cancelled, we still want to close these
|
||||
if not session.connection_ids:
|
||||
session.close()
|
||||
self.local_sessions_by_sid.pop(session.sid)
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import socketio
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
@@ -23,22 +24,30 @@ from openhands.events.stream import EventStreamSubscriber
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
from openhands.utils.async_utils import wait_all
|
||||
|
||||
|
||||
class Session:
|
||||
sid: str
|
||||
websocket: WebSocket | None
|
||||
sio: socketio.AsyncServer | None
|
||||
connection_ids: set[str]
|
||||
last_active_ts: int = 0
|
||||
is_alive: bool = True
|
||||
agent_session: AgentSession
|
||||
loop: asyncio.AbstractEventLoop
|
||||
|
||||
def __init__(
|
||||
self, sid: str, ws: WebSocket | None, config: AppConfig, file_store: FileStore
|
||||
self,
|
||||
sid: str,
|
||||
ws: WebSocket | None,
|
||||
config: AppConfig,
|
||||
file_store: FileStore,
|
||||
sio: socketio.AsyncServer | None,
|
||||
):
|
||||
self.sid = sid
|
||||
self.websocket = ws
|
||||
self.sio = sio
|
||||
self.last_active_ts = int(time.time())
|
||||
self.agent_session = AgentSession(
|
||||
sid, file_store, status_callback=self.queue_status_message
|
||||
@@ -47,36 +56,21 @@ class Session:
|
||||
EventStreamSubscriber.SERVER, self.on_event, self.sid
|
||||
)
|
||||
self.config = config
|
||||
self.connection_ids = set()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
def connect(self, connection_id: str):
|
||||
self.connection_ids.add(connection_id)
|
||||
|
||||
def disconnect(self, connection_id: str) -> bool:
|
||||
self.connection_ids.remove(connection_id)
|
||||
return not self.connection_ids
|
||||
|
||||
def close(self):
|
||||
self.is_alive = False
|
||||
try:
|
||||
if self.websocket is not None:
|
||||
self.websocket.close()
|
||||
self.websocket = None
|
||||
finally:
|
||||
self.agent_session.close()
|
||||
self.agent_session.close()
|
||||
|
||||
async def loop_recv(self):
|
||||
try:
|
||||
if self.websocket is None:
|
||||
return
|
||||
while should_continue():
|
||||
try:
|
||||
data = await self.websocket.receive_json()
|
||||
except ValueError:
|
||||
await self.send_error('Invalid JSON')
|
||||
continue
|
||||
await self.dispatch(data)
|
||||
except WebSocketDisconnect:
|
||||
logger.info('WebSocket disconnected, sid: %s', self.sid)
|
||||
self.close()
|
||||
except RuntimeError as e:
|
||||
logger.exception('Error in loop_recv: %s', e)
|
||||
self.close()
|
||||
|
||||
async def _initialize_agent(self, data: dict):
|
||||
async def initialize_agent(self, data: dict):
|
||||
self.agent_session.event_stream.add_event(
|
||||
ChangeAgentStateAction(AgentState.LOADING), EventSource.ENVIRONMENT
|
||||
)
|
||||
@@ -164,7 +158,7 @@ class Session:
|
||||
async def dispatch(self, data: dict):
|
||||
action = data.get('action', '')
|
||||
if action == ActionType.INIT:
|
||||
await self._initialize_agent(data)
|
||||
await self.initialize_agent(data)
|
||||
return
|
||||
event = event_from_dict(data.copy())
|
||||
# This checks if the model supports images
|
||||
@@ -185,9 +179,15 @@ class Session:
|
||||
|
||||
async def send(self, data: dict[str, object]) -> bool:
|
||||
try:
|
||||
if self.websocket is None or not self.is_alive:
|
||||
if not self.is_alive:
|
||||
return False
|
||||
await self.websocket.send_json(data)
|
||||
if self.websocket:
|
||||
await self.websocket.send_json(data)
|
||||
if self.sio:
|
||||
await wait_all(
|
||||
self.sio.emit('oh_event', data, to=connection_id)
|
||||
for connection_id in self.connection_ids
|
||||
)
|
||||
await asyncio.sleep(0.001) # This flushes the data to the client
|
||||
self.last_active_ts = int(time.time())
|
||||
return True
|
||||
|
||||
1100
poetry.lock
generated
1100
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -64,6 +64,9 @@ modal = "^0.64.145"
|
||||
runloop-api-client = "0.7.0"
|
||||
pygithub = "^2.5.0"
|
||||
openhands-aci = "^0.1.0"
|
||||
python-socketio = "^5.11.4"
|
||||
asyncpg = "^0.30.0"
|
||||
cloud-sql-python-connector = "^1.13.0"
|
||||
|
||||
[tool.poetry.group.llama-index.dependencies]
|
||||
llama-index = "*"
|
||||
@@ -95,6 +98,7 @@ reportlab = "*"
|
||||
[tool.coverage.run]
|
||||
concurrency = ["gevent"]
|
||||
|
||||
|
||||
[tool.poetry.group.runtime.dependencies]
|
||||
jupyterlab = "*"
|
||||
notebook = "*"
|
||||
@@ -125,6 +129,7 @@ ignore = ["D1"]
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
|
||||
[tool.poetry.group.evaluation.dependencies]
|
||||
streamlit = "*"
|
||||
whatthepatch = "*"
|
||||
|
||||
188
tests/unit/test_truncation.py
Normal file
188
tests/unit/test_truncation.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_stream():
|
||||
stream = MagicMock()
|
||||
# Mock get_events to return an empty list by default
|
||||
stream.get_events.return_value = []
|
||||
return stream
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
agent = MagicMock()
|
||||
agent.llm = MagicMock()
|
||||
agent.llm.config = MagicMock()
|
||||
return agent
|
||||
|
||||
|
||||
class TestTruncation:
|
||||
def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create a sequence of events with IDs
|
||||
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
cmd1 = CmdRunAction(command='ls')
|
||||
cmd1._id = 2
|
||||
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=2)
|
||||
obs1._id = 3
|
||||
obs1._cause = 2
|
||||
|
||||
cmd2 = CmdRunAction(command='pwd')
|
||||
cmd2._id = 4
|
||||
obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=4)
|
||||
obs2._id = 5
|
||||
obs2._cause = 4
|
||||
|
||||
events = [first_msg, cmd1, obs1, cmd2, obs2]
|
||||
|
||||
# Apply truncation
|
||||
truncated = controller._apply_conversation_window(events)
|
||||
|
||||
# Should keep first user message and roughly half of other events
|
||||
assert (
|
||||
len(truncated) >= 3
|
||||
) # First message + at least one action-observation pair
|
||||
assert truncated[0] == first_msg # First message always preserved
|
||||
assert controller.state.start_id == first_msg._id
|
||||
assert controller.state.truncation_id is not None
|
||||
|
||||
# Verify pairs aren't split
|
||||
for i, event in enumerate(truncated[1:]):
|
||||
if isinstance(event, CmdOutputObservation):
|
||||
assert any(e._id == event._cause for e in truncated[: i + 1])
|
||||
|
||||
def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Setup initial history with IDs
|
||||
first_msg = MessageAction(content='Start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
# Add agent question
|
||||
agent_msg = MessageAction(
|
||||
content='What task would you like me to perform?', wait_for_response=True
|
||||
)
|
||||
agent_msg._source = EventSource.AGENT
|
||||
agent_msg._id = 2
|
||||
|
||||
# Add user response
|
||||
user_response = MessageAction(
|
||||
content='Please list all files and show me current directory',
|
||||
wait_for_response=False,
|
||||
)
|
||||
user_response._source = EventSource.USER
|
||||
user_response._id = 3
|
||||
|
||||
cmd1 = CmdRunAction(command='ls')
|
||||
cmd1._id = 4
|
||||
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
|
||||
obs1._id = 5
|
||||
obs1._cause = 4
|
||||
|
||||
# Update mock event stream to include new messages
|
||||
mock_event_stream.get_events.return_value = [
|
||||
first_msg,
|
||||
agent_msg,
|
||||
user_response,
|
||||
cmd1,
|
||||
obs1,
|
||||
]
|
||||
controller.state.history = [first_msg, agent_msg, user_response, cmd1, obs1]
|
||||
original_history_len = len(controller.state.history)
|
||||
|
||||
# Simulate ContextWindowExceededError and truncation
|
||||
controller.state.history = controller._apply_conversation_window(
|
||||
controller.state.history
|
||||
)
|
||||
|
||||
# Verify truncation occurred
|
||||
assert len(controller.state.history) < original_history_len
|
||||
assert controller.state.start_id == first_msg._id
|
||||
assert controller.state.truncation_id is not None
|
||||
assert controller.state.truncation_id > controller.state.start_id
|
||||
|
||||
def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create events with IDs
|
||||
first_msg = MessageAction(content='Start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
events = [first_msg]
|
||||
for i in range(5):
|
||||
cmd = CmdRunAction(command=f'cmd{i}')
|
||||
cmd._id = i + 2
|
||||
obs = CmdOutputObservation(
|
||||
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
|
||||
)
|
||||
obs._cause = cmd._id
|
||||
events.extend([cmd, obs])
|
||||
|
||||
# Set up initial history
|
||||
controller.state.history = events.copy()
|
||||
|
||||
# Force truncation
|
||||
controller.state.history = controller._apply_conversation_window(
|
||||
controller.state.history
|
||||
)
|
||||
|
||||
# Save state
|
||||
saved_start_id = controller.state.start_id
|
||||
saved_truncation_id = controller.state.truncation_id
|
||||
saved_history_len = len(controller.state.history)
|
||||
|
||||
# Set up mock event stream for new controller
|
||||
mock_event_stream.get_events.return_value = controller.state.history
|
||||
|
||||
# Create new controller with saved state
|
||||
new_controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
new_controller.state.start_id = saved_start_id
|
||||
new_controller.state.truncation_id = saved_truncation_id
|
||||
new_controller.state.history = mock_event_stream.get_events()
|
||||
|
||||
# Verify restoration
|
||||
assert len(new_controller.state.history) == saved_history_len
|
||||
assert new_controller.state.history[0] == first_msg
|
||||
assert new_controller.state.start_id == saved_start_id
|
||||
Reference in New Issue
Block a user