Compare commits

..

33 Commits

Author SHA1 Message Date
Robert Brennan
dbda7bcc82 add cloud sql 2024-11-15 13:55:50 -05:00
Robert Brennan
2f275dce9e new pg impl 2024-11-15 13:55:04 -05:00
Robert Brennan
340b9ead40 debug logs 2024-11-15 13:46:01 -05:00
Robert Brennan
1fba72043c create db if not exists 2024-11-15 13:38:06 -05:00
Robert Brennan
bcab72c981 add logging 2024-11-15 13:26:36 -05:00
Robert Brennan
5949f5bdb6 fix import 2024-11-15 13:19:32 -05:00
Robert Brennan
20d4a0cce2 add asyncpg 2024-11-15 13:18:28 -05:00
Robert Brennan
37ee90cf6f add gcp support 2024-11-15 13:17:46 -05:00
Robert Brennan
69c646ebf9 add pg socket 2024-11-15 12:56:25 -05:00
Robert Brennan
71e57dc069 better pg creds 2024-11-15 12:52:32 -05:00
Robert Brennan
3f03fb69c8 add pg socket 2024-11-15 12:49:38 -05:00
Tim O'Farrell
551a5a72e1 Client side error fix 2024-11-15 09:55:39 -07:00
Tim O'Farrell
2aa6bab85f Fix initial startup 2024-11-15 09:27:44 -07:00
Tim O'Farrell
3f0e619997 WIP 2024-11-15 08:52:12 -07:00
Tim O'Farrell
511ce8cc3a Fixed closing 2024-11-15 08:29:51 -07:00
Tim O'Farrell
a1b53a9498 Culled dead code 2024-11-15 07:39:30 -07:00
Tim O'Farrell
767878563e WIP 2024-11-15 07:15:21 -07:00
Tim O'Farrell
b2bd23ac86 Fix for test errors - client side seems to work 2024-11-14 18:46:13 -07:00
Tim O'Farrell
f0487e6818 Server side code is working - now on the client side 2024-11-14 18:37:33 -07:00
Tim O'Farrell
b85fbf39fd Tokens in init event 2024-11-14 15:45:57 -07:00
Tim O'Farrell
8b4d263319 Merge branch 'main' into feat-socket-io 2024-11-14 07:25:01 -07:00
Tim O'Farrell
ff7783ec81 Removed retry - socketio does this anyway 2024-11-14 07:24:37 -07:00
Rohit Malhotra
38dc41ca42 Fix: [Bug] Do not render editor action buttons (save/discard) when displaying non-code files (#4903) 2024-11-14 09:09:28 +02:00
Engel Nyst
8dee334236 Context Window Exceeded fix (#4977) 2024-11-14 02:42:39 +00:00
Tim O'Farrell
d4b20c284d Now using socket io 2024-11-13 16:38:03 -07:00
Tim O'Farrell
2501cce470 Merge branch 'main' into feat-socket-io 2024-11-13 16:20:49 -07:00
Tim O'Farrell
ca9aefd7b2 WIP - frontend is still janky 2024-11-13 16:18:39 -07:00
Tim O'Farrell
92586a090d Server side init is working 2024-11-13 14:41:57 -07:00
Tim O'Farrell
18e774dd8a Fix for merge error 2024-11-13 13:37:07 -07:00
Tim O'Farrell
3a59e037fb Merge branch 'main' into feat-socket-io 2024-11-13 12:56:45 -07:00
Tim O'Farrell
5973c0c269 Working through socket issues 2024-11-13 12:55:51 -07:00
Tim O'Farrell
2fa8c4e14d Initial stab at reimplementing session management 2024-11-13 12:31:02 -07:00
Tim O'Farrell
c2a5fbceb7 Server side work for socketio.
Still needs testing and integration of proper client (Rather than the mock I've been using)
2024-11-13 11:45:43 -07:00
16 changed files with 1595 additions and 717 deletions

View File

@@ -40,6 +40,7 @@
"react-textarea-autosize": "^8.5.4",
"remark-gfm": "^4.0.0",
"sirv-cli": "^3.0.0",
"socket.io-client": "^4.8.1",
"tailwind-merge": "^2.5.4",
"vite": "^5.4.9",
"web-vitals": "^3.5.2",
@@ -5576,6 +5577,11 @@
"integrity": "sha512-+Fj43pSMwJs4KRrH/938Uf+uAELIgVBmQzg/q1YG10djyfA3TnrU8N8XzqCh/okZdszqBQTZf96idMfE5lnwTA==",
"dev": true
},
"node_modules/@socket.io/component-emitter": {
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz",
"integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA=="
},
"node_modules/@svgr/babel-plugin-add-jsx-attribute": {
"version": "8.0.0",
"resolved": "https://registry.npmjs.org/@svgr/babel-plugin-add-jsx-attribute/-/babel-plugin-add-jsx-attribute-8.0.0.tgz",
@@ -8469,6 +8475,46 @@
"once": "^1.4.0"
}
},
"node_modules/engine.io-client": {
"version": "6.6.2",
"resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.6.2.tgz",
"integrity": "sha512-TAr+NKeoVTjEVW8P3iHguO1LO6RlUz9O5Y8o7EY0fU+gY1NYqas7NN3slpFtbXEsLMHk0h90fJMfKjRkQ0qUIw==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1",
"engine.io-parser": "~5.2.1",
"ws": "~8.17.1",
"xmlhttprequest-ssl": "~2.1.1"
}
},
"node_modules/engine.io-client/node_modules/ws": {
"version": "8.17.1",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz",
"integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==",
"engines": {
"node": ">=10.0.0"
},
"peerDependencies": {
"bufferutil": "^4.0.1",
"utf-8-validate": ">=5.0.2"
},
"peerDependenciesMeta": {
"bufferutil": {
"optional": true
},
"utf-8-validate": {
"optional": true
}
}
},
"node_modules/engine.io-parser": {
"version": "5.2.3",
"resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz",
"integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==",
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/entities": {
"version": "4.5.0",
"resolved": "https://registry.npmjs.org/entities/-/entities-4.5.0.tgz",
@@ -22587,6 +22633,32 @@
"tslib": "^2.0.3"
}
},
"node_modules/socket.io-client": {
"version": "4.8.1",
"resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.8.1.tgz",
"integrity": "sha512-hJVXfu3E28NmzGk8o1sHhN3om52tRvwYeidbj7xKy2eIIse5IoKX3USlS6Tqt3BHAtflLIkCQBkzVrEEfWUyYQ==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.2",
"engine.io-client": "~6.6.1",
"socket.io-parser": "~4.2.4"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/socket.io-parser": {
"version": "4.2.4",
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz",
"integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/source-map": {
"version": "0.7.4",
"resolved": "https://registry.npmjs.org/source-map/-/source-map-0.7.4.tgz",
@@ -25317,6 +25389,14 @@
"integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==",
"dev": true
},
"node_modules/xmlhttprequest-ssl": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.1.2.tgz",
"integrity": "sha512-TEU+nJVUUnA4CYJFLvK5X9AOeH4KvDvhIfm0vV1GaQRtchnG0hgK5p8hw/xjv8cunWYCsiPCSDzObPyhEwq3KQ==",
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/xtend": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz",

View File

@@ -39,6 +39,7 @@
"react-textarea-autosize": "^8.5.4",
"remark-gfm": "^4.0.0",
"sirv-cli": "^3.0.0",
"socket.io-client": "^4.8.1",
"tailwind-merge": "^2.5.4",
"vite": "^5.4.9",
"web-vitals": "^3.5.2",
@@ -120,4 +121,4 @@
"public"
]
}
}
}

View File

@@ -1,13 +1,11 @@
import posthog from "posthog-js";
import React from "react";
import { io, Socket } from "socket.io-client";
import { Settings } from "#/services/settings";
import ActionType from "#/types/ActionType";
import EventLogger from "#/utils/event-logger";
import AgentState from "#/types/AgentState";
import { handleAssistantMessage } from "#/services/actions";
const RECONNECT_RETRIES = 5;
export enum WsClientProviderStatus {
STOPPED,
OPENING,
@@ -43,38 +41,46 @@ export function WsClientProvider({
settings,
children,
}: React.PropsWithChildren<WsClientProviderProps>) {
const wsRef = React.useRef<WebSocket | null>(null);
const sioRef = React.useRef<Socket | null>(null);
const tokenRef = React.useRef<string | null>(token);
const ghTokenRef = React.useRef<string | null>(ghToken);
const closeRef = React.useRef<ReturnType<typeof setTimeout> | null>(null);
const disconnectRef = React.useRef<ReturnType<typeof setTimeout> | null>(
null,
);
const [status, setStatus] = React.useState(WsClientProviderStatus.STOPPED);
const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
const [retryCount, setRetryCount] = React.useState(RECONNECT_RETRIES);
function send(event: Record<string, unknown>) {
if (!wsRef.current) {
if (!sioRef.current) {
EventLogger.error("WebSocket is not connected.");
return;
}
wsRef.current.send(JSON.stringify(event));
sioRef.current.emit("oh_action", event);
}
function handleOpen() {
setRetryCount(RECONNECT_RETRIES);
function handleConnect() {
setStatus(WsClientProviderStatus.OPENING);
const initEvent = {
const initEvent: Record<string, unknown> = {
action: ActionType.INIT,
args: settings,
};
if (token) {
initEvent.token = token;
}
if (ghToken) {
initEvent.github_token = ghToken;
}
if (events.length) {
// Wrong. Events is out of sync here...
initEvent.latest_event_id = `${events[events.length - 1].id}`;
}
send(initEvent);
}
function handleMessage(messageEvent: MessageEvent) {
const event = JSON.parse(messageEvent.data);
function handleMessage(event: Record<string, unknown>) {
setEvents((prevEvents) => [...prevEvents, event]);
if (event.extras?.agent_state === AgentState.INIT) {
setStatus(WsClientProviderStatus.ACTIVE);
}
const extras = event.extras as Record<string, unknown>;
if (
status !== WsClientProviderStatus.ACTIVE &&
event?.observation === "error"
@@ -82,93 +88,98 @@ export function WsClientProvider({
setStatus(WsClientProviderStatus.ERROR);
}
handleAssistantMessage(event);
}
function handleClose() {
if (retryCount) {
setTimeout(() => {
setRetryCount(retryCount - 1);
}, 1000);
if (event.token) {
setStatus(WsClientProviderStatus.ACTIVE);
} else {
setStatus(WsClientProviderStatus.STOPPED);
setEvents([]);
handleAssistantMessage(event);
}
wsRef.current = null;
}
function handleError(event: Event) {
function handleDisconnect() {
setStatus(WsClientProviderStatus.STOPPED);
// setEvents([]);
// sioRef.current = null;
}
function handleError() {
posthog.capture("socket_error");
EventLogger.event(event, "SOCKET ERROR");
setStatus(WsClientProviderStatus.ERROR);
sioRef.current?.disconnect();
}
// Connect websocket
React.useEffect(() => {
let ws = wsRef.current;
let sio = sioRef.current;
// If disabled close any existing websockets...
if (!enabled || !retryCount) {
if (ws) {
ws.close();
// If disabled disconnect any existing websockets...
if (!enabled) {
if (sio) {
sio.disconnect();
}
wsRef.current = null;
return () => {};
}
// If there is no websocket or the tokens have changed or the current websocket is closed,
// If there is no websocket or the tokens have changed or the current websocket is disconnected,
// create a new one
if (
!ws ||
!sio ||
(tokenRef.current && token !== tokenRef.current) ||
ghToken !== ghTokenRef.current ||
ws.readyState === WebSocket.CLOSED ||
ws.readyState === WebSocket.CLOSING
ghToken !== ghTokenRef.current
) {
ws?.close();
sio?.disconnect();
const baseUrl =
import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host;
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
let wsUrl = `${protocol}//${baseUrl}/ws`;
if (events.length) {
wsUrl += `?latest_event_id=${events[events.length - 1].id}`;
}
ws = new WebSocket(wsUrl, [
"openhands",
token || "NO_JWT",
ghToken || "NO_GITHUB",
]);
sio = io(baseUrl, {
transports: ["websocket"],
// extraHeaders: {
// Testy: "TESTER"
// },
// We force a new connection, because the headers may have changed.
// forceNew: true,
// Had to do this for now because reconnection actually starts a new session,
// which we don't want - The reconnect has the same headers as the original
// which don't include the original session id
// reconnection: false,
// reconnectionDelay: 1000,
// reconnectionDelayMax : 5000,
// reconnectionAttempts: 5
});
}
ws.addEventListener("open", handleOpen);
ws.addEventListener("message", handleMessage);
ws.addEventListener("error", handleError);
ws.addEventListener("close", handleClose);
wsRef.current = ws;
sio.on("connect", handleConnect);
sio.on("oh_event", handleMessage);
sio.on("connect_error", handleError);
sio.on("connect_failed", handleError);
sio.on("disconnect", handleDisconnect);
sioRef.current = sio;
tokenRef.current = token;
ghTokenRef.current = ghToken;
return () => {
ws.removeEventListener("open", handleOpen);
ws.removeEventListener("message", handleMessage);
ws.removeEventListener("error", handleError);
ws.removeEventListener("close", handleClose);
sio.off("connect", handleConnect);
sio.off("oh_event", handleMessage);
sio.off("connect_error", handleError);
sio.off("connect_failed", handleError);
sio.off("disconnect", handleDisconnect);
};
}, [enabled, token, ghToken, retryCount]);
}, [enabled, token, ghToken, events]);
// Strict mode mounts and unmounts each component twice, so we have to wait in the destructor
// before actually closing the socket and cancel the operation if the component gets remounted.
// before actually disconnecting the socket and cancel the operation if the component gets remounted.
React.useEffect(() => {
const timeout = closeRef.current;
const timeout = disconnectRef.current;
if (timeout != null) {
clearTimeout(timeout);
}
return () => {
closeRef.current = setTimeout(() => {
const ws = wsRef.current;
if (ws) {
ws.removeEventListener("close", handleClose);
ws.close();
disconnectRef.current = setTimeout(() => {
const sio = sioRef.current;
if (sio) {
sio.off("disconnect", handleDisconnect);
sio.disconnect();
}
}, 100);
};

View File

@@ -12,6 +12,18 @@ import CodeEditorCompoonent from "./code-editor-component";
import { useFiles } from "#/context/files";
import { EditorActions } from "#/components/editor-actions";
const ASSET_FILE_TYPES = [
".png",
".jpg",
".jpeg",
".bmp",
".gif",
".pdf",
".mp4",
".webm",
".ogg",
];
export const clientLoader = async () => {
const token = localStorage.getItem("token");
return json({ token });
@@ -104,6 +116,10 @@ function CodeEditor() {
if (selectedPath) discardChanges(selectedPath);
};
const isAssetFileType = selectedPath
? ASSET_FILE_TYPES.some((ext) => selectedPath.endsWith(ext))
: false;
return (
<div className="flex h-full bg-neutral-900 relative">
<FileExplorer
@@ -112,7 +128,7 @@ function CodeEditor() {
error={errors.getFiles}
/>
<div className="w-full">
{selectedPath && (
{selectedPath && !isAssetFileType && (
<div className="flex w-full items-center justify-between self-end p-2">
<span className="text-sm text-neutral-500">{selectedPath}</span>
<EditorActions

View File

@@ -82,6 +82,13 @@ export default defineConfig(({ mode }) => {
changeOrigin: true,
secure: !INSECURE_SKIP_VERIFY,
},
"/socket.io": {
target: WS_URL,
ws: true,
changeOrigin: true,
secure: !INSECURE_SKIP_VERIFY,
//rewriteWsOrigin: true,
}
},
},
ssr: {

View File

@@ -5,6 +5,7 @@ import traceback
from typing import Callable, ClassVar, Type
import litellm
from litellm.exceptions import ContextWindowExceededError
from openhands.controller.agent import Agent
from openhands.controller.state.state import State, TrafficControlState
@@ -485,6 +486,15 @@ class AgentController:
EventSource.AGENT,
)
return
except ContextWindowExceededError:
# When context window is exceeded, keep roughly half of agent interactions
self.state.history = self._apply_conversation_window(self.state.history)
# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id
# Don't add error event - let the agent retry with reduced context
return
if action.runnable:
if self.state.confirmation_mode and (
@@ -659,6 +669,12 @@ class AgentController:
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
- Excludes all events between the action and observation
- Includes the delegate action and observation themselves
The history is loaded in two parts if truncation_id is set:
1. First user message from start_id onwards
2. Rest of history from truncation_id to the end
Otherwise loads normally from start_id.
"""
# define range of events to fetch
@@ -680,8 +696,33 @@ class AgentController:
self.state.history = []
return
# Get all events, filtering out backend events and hidden events
events = list(
events: list[Event] = []
# If we have a truncation point, get first user message and then rest of history
if hasattr(self.state, 'truncation_id') and self.state.truncation_id > 0:
# Find first user message from stream
first_user_msg = next(
(
e
for e in self.event_stream.get_events(
start_id=start_id,
end_id=end_id,
reverse=False,
filter_out_type=self.filter_out,
filter_hidden=True,
)
if isinstance(e, MessageAction) and e.source == EventSource.USER
),
None,
)
if first_user_msg:
events.append(first_user_msg)
# the rest of the events are from the truncation point
start_id = self.state.truncation_id
# Get rest of history
events_to_add = list(
self.event_stream.get_events(
start_id=start_id,
end_id=end_id,
@@ -690,6 +731,7 @@ class AgentController:
filter_hidden=True,
)
)
events.extend(events_to_add)
# Find all delegate action/observation pairs
delegate_ranges: list[tuple[int, int]] = []
@@ -744,6 +786,92 @@ class AgentController:
# make sure history is in sync
self.state.start_id = start_id
def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
"""Cuts history roughly in half when context window is exceeded, preserving action-observation pairs
and ensuring the first user message is always included.
The algorithm:
1. Cut history in half
2. Check first event in new history:
- If Observation: find and include its Action
- If MessageAction: ensure its related Action-Observation pair isn't split
3. Always include the first user message
Args:
events: List of events to filter
Returns:
Filtered list of events keeping newest half while preserving pairs
"""
if not events:
return events
# Find first user message - we'll need to ensure it's included
first_user_msg = next(
(
e
for e in events
if isinstance(e, MessageAction) and e.source == EventSource.USER
),
None,
)
# cut in half
mid_point = max(1, len(events) // 2)
kept_events = events[mid_point:]
# Handle first event in truncated history
if kept_events:
i = 0
while i < len(kept_events):
first_event = kept_events[i]
if isinstance(first_event, Observation) and first_event.cause:
# Find its action and include it
matching_action = next(
(
e
for e in reversed(events[:mid_point])
if isinstance(e, Action) and e.id == first_event.cause
),
None,
)
if matching_action:
kept_events = [matching_action] + kept_events
else:
self.log(
'warning',
f'Found Observation without matching Action at id={first_event.id}',
)
# drop this observation
kept_events = kept_events[1:]
break
elif isinstance(first_event, MessageAction) or (
isinstance(first_event, Action)
and first_event.source == EventSource.USER
):
# if it's a message action or a user action, keep it and continue to find the next event
i += 1
continue
else:
# if it's an action with source == EventSource.AGENT, we're good
break
# Save where to continue from in next reload
if kept_events:
self.state.truncation_id = kept_events[0].id
# Ensure first user message is included
if first_user_msg and first_user_msg not in kept_events:
kept_events = [first_user_msg] + kept_events
# start_id points to first user message
if first_user_msg:
self.state.start_id = first_user_msg.id
return kept_events
def _is_stuck(self):
"""Checks if the agent or its delegate is stuck in a loop.

View File

@@ -92,6 +92,8 @@ class State:
# start_id and end_id track the range of events in history
start_id: int = -1
end_id: int = -1
# truncation_id tracks where to load history after context window truncation
truncation_id: int = -1
almost_stuck: int = 0
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
# NOTE: This will never be used by the controller, but it can be used by different

View File

@@ -111,9 +111,6 @@ class LogBuffer:
def close(self, timeout: float = 5.0):
self._stop_event.set()
self.log_stream_thread.join(timeout)
# Close the log generator to release the file descriptor
if hasattr(self.log_generator, 'close'):
self.log_generator.close()
class EventStreamRuntime(Runtime):
@@ -235,8 +232,6 @@ class EventStreamRuntime(Runtime):
f'Container started: {self.container_name}. VSCode URL: {self.vscode_url}',
)
self.log_buffer = LogBuffer(self.container, self.log)
if not self.attach_to_existing:
self.log('info', f'Waiting for client to become ready at {self.api_url}...')
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
@@ -363,6 +358,7 @@ class EventStreamRuntime(Runtime):
environment=environment,
volumes=volumes,
)
self.log_buffer = LogBuffer(self.container, self.log)
self.log('debug', f'Container started. Server url: {self.api_url}')
self.send_status_message('STATUS$CONTAINER_STARTED')
except docker.errors.APIError as e:
@@ -389,9 +385,11 @@ class EventStreamRuntime(Runtime):
raise e
def _attach_to_container(self):
container = self.docker_client.containers.get(self.container_name)
self.log_buffer = LogBuffer(container, self.log)
self.container = container
self._container_port = 0
self.container = self.docker_client.containers.get(self.container_name)
for port in self.container.attrs['NetworkSettings']['Ports']: # type: ignore
for port in container.attrs['NetworkSettings']['Ports']:
self._container_port = int(port.split('/')[0])
break
self._host_port = self._container_port

View File

@@ -115,15 +115,13 @@ async def get_github_user(token: str) -> str:
github handle of the user
"""
logger.debug('Fetching GitHub user info from token')
g = Github(token)
try:
g = Github(token)
user = await call_sync_from_async(g.get_user)
login = user.login
logger.info(f'Successfully retrieved GitHub user: {login}')
return login
except GithubException as e:
logger.error(f'Error making request to GitHub API: {str(e)}')
logger.error(e)
raise
finally:
g.close()
login = user.login
logger.info(f'Successfully retrieved GitHub user: {login}')
return login

View File

@@ -1,16 +1,16 @@
import asyncio
import os
import re
import tempfile
import time
import uuid
import warnings
import jwt
import requests
import socketio
from pathspec import PathSpec
from pathspec.patterns import GitWildMatchPattern
from openhands.core.schema.action import ActionType
from openhands.security.options import SecurityAnalyzers
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
from openhands.server.github import (
@@ -19,8 +19,9 @@ from openhands.server.github import (
UserVerifier,
authenticate_github_user,
)
from openhands.server.pg_socket import AsyncPostgresAdapter
from openhands.storage import get_file_store
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.async_utils import call_async_from_sync, call_sync_from_async
with warnings.catch_warnings():
warnings.simplefilter('ignore')
@@ -33,7 +34,6 @@ from fastapi import (
HTTPException,
Request,
UploadFile,
WebSocket,
status,
)
from fastapi.responses import FileResponse, JSONResponse
@@ -52,7 +52,6 @@ from openhands.events.action import (
NullAction,
)
from openhands.events.observation import (
AgentStateChangedObservation,
ErrorObservation,
FileReadObservation,
FileWriteObservation,
@@ -250,122 +249,6 @@ async def attach_session(request: Request, call_next):
return response
@app.websocket('/ws')
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket endpoint for receiving events from the client (i.e., the browser).
Once connected, the client can send various actions:
- Initialize the agent:
session management, and event streaming.
```json
{"action": "initialize", "args": {"LLM_MODEL": "ollama/llama3", "AGENT": "CodeActAgent", "LANGUAGE": "en", "LLM_API_KEY": "ollama"}}
Args:
```
websocket (WebSocket): The WebSocket connection object.
- Start a new development task:
```json
{"action": "start", "args": {"task": "write a bash script that prints hello"}}
```
- Send a message:
```json
{"action": "message", "args": {"content": "Hello, how are you?", "image_urls": ["base64_url1", "base64_url2"]}}
```
- Write contents to a file:
```json
{"action": "write", "args": {"path": "./greetings.txt", "content": "Hello, OpenHands?"}}
```
- Read the contents of a file:
```json
{"action": "read", "args": {"path": "./greetings.txt"}}
```
- Run a command:
```json
{"action": "run", "args": {"command": "ls -l", "thought": "", "confirmation_state": "confirmed"}}
```
- Run an IPython command:
```json
{"action": "run_ipython", "args": {"command": "print('Hello, IPython!')"}}
```
- Open a web page:
```json
{"action": "browse", "args": {"url": "https://arxiv.org/html/2402.01030v2"}}
```
- Add a task to the root_task:
```json
{"action": "add_task", "args": {"task": "Implement feature X"}}
```
- Update a task in the root_task:
```json
{"action": "modify_task", "args": {"id": "0", "state": "in_progress", "thought": ""}}
```
- Change the agent's state:
```json
{"action": "change_agent_state", "args": {"state": "paused"}}
```
- Finish the task:
```json
{"action": "finish", "args": {}}
```
"""
# Get protocols from Sec-WebSocket-Protocol header
protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
# The first protocol should be our real protocol (e.g. 'openhands')
# The second protocol should contain our auth token
if len(protocols) < 3:
logger.error('Expected 3 websocket protocols, got %d', len(protocols))
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
real_protocol = protocols[0]
jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
if not await authenticate_github_user(github_token):
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10)
if jwt_token:
sid = get_sid_from_token(jwt_token, config.jwt_secret)
if sid == '':
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
await websocket.close()
return
else:
sid = str(uuid.uuid4())
jwt_token = sign_token({'sid': sid}, config.jwt_secret)
logger.info(f'New session: {sid}')
session = session_manager.add_or_restart_session(sid, websocket)
await websocket.send_json({'token': jwt_token, 'status': 'ok'})
latest_event_id = -1
if websocket.query_params.get('latest_event_id'):
latest_event_id = int(websocket.query_params.get('latest_event_id'))
async_stream = AsyncEventStreamWrapper(
session.agent_session.event_stream, latest_event_id + 1
)
async for event in async_stream:
if isinstance(
event,
(
NullAction,
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
),
):
continue
await websocket.send_json(event_to_dict(event))
await session.loop_recv()
@app.get('/api/options/models')
async def get_litellm_models() -> list[str]:
"""
@@ -930,3 +813,135 @@ class SPAStaticFiles(StaticFiles):
app.mount('/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist')
use_manager = os.getenv('DB_HOST') or os.getenv('GCP_DB_INSTANCE')
manager = AsyncPostgresAdapter() if use_manager else None
if manager:
call_async_from_sync(manager.setup, 10)
sio = socketio.AsyncServer(
async_mode='asgi', cors_allowed_origins='*', client_manager=manager
)
app = socketio.ASGIApp(sio, other_asgi_app=app)
@sio.event
async def connect(connection_id: str, environ):
logger.info(f'sio:connect: {connection_id}')
@sio.event
async def oh_action(connection_id: str, data: dict):
"""WebSocket endpoint for receiving events from the client (i.e., the browser).
Once connected, the client can send various actions:
- Initialize the agent:
session management, and event streaming.
```json
{"action": "initialize", "args": {"LLM_MODEL": "ollama/llama3", "AGENT": "CodeActAgent", "LANGUAGE": "en", "LLM_API_KEY": "ollama"}}
Args:
```
websocket (WebSocket): The WebSocket connection object.
- Start a new development task:
```json
{"action": "start", "args": {"task": "write a bash script that prints hello"}}
```
- Send a message:
```json
{"action": "message", "args": {"content": "Hello, how are you?", "image_urls": ["base64_url1", "base64_url2"]}}
```
- Write contents to a file:
```json
{"action": "write", "args": {"path": "./greetings.txt", "content": "Hello, OpenHands?"}}
```
- Read the contents of a file:
```json
{"action": "read", "args": {"path": "./greetings.txt"}}
```
- Run a command:
```json
{"action": "run", "args": {"command": "ls -l", "thought": "", "confirmation_state": "confirmed"}}
```
- Run an IPython command:
```json
{"action": "run_ipython", "args": {"command": "print('Hello, IPython!')"}}
```
- Open a web page:
```json
{"action": "browse", "args": {"url": "https://arxiv.org/html/2402.01030v2"}}
```
- Add a task to the root_task:
```json
{"action": "add_task", "args": {"task": "Implement feature X"}}
```
- Update a task in the root_task:
```json
{"action": "modify_task", "args": {"id": "0", "state": "in_progress", "thought": ""}}
```
- Change the agent's state:
```json
{"action": "change_agent_state", "args": {"state": "paused"}}
```
- Finish the task:
```json
{"action": "finish", "args": {}}
```
"""
# If it's an init, we do it here.
action = data.get('action', '')
if action == ActionType.INIT:
await init_connection(connection_id, data)
return
logger.info(f'sio:oh_action:{connection_id}')
session = session_manager.get_local_session(connection_id)
await session.dispatch(data)
async def init_connection(connection_id: str, data: dict):
gh_token = data.pop('gh_token', None)
if not await authenticate_github_user(gh_token):
raise RuntimeError(status.WS_1008_POLICY_VIOLATION)
token = data.pop('token', None)
if token:
sid = get_sid_from_token(token, config.jwt_secret)
if sid == '':
await sio.send({'error': 'Invalid token', 'error_code': 401})
return
logger.info(f'Existing session: {sid}')
else:
sid = connection_id
logger.info(f'New session: {sid}')
token = sign_token({'sid': sid}, config.jwt_secret)
await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id)
latest_event_id = int(data.pop('latest_event_id', -1))
# The session in question should exist, but may not actually be running locally...
session = await session_manager.init_or_join_local_session(
sio, sid, connection_id, data
)
# Send events
async_stream = AsyncEventStreamWrapper(
session.agent_session.event_stream, latest_event_id + 1
)
async for event in async_stream:
if isinstance(
event,
(
NullAction,
NullObservation,
ChangeAgentStateAction,
),
):
continue
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
@sio.event
async def disconnect(connection_id: str):
logger.info(f'sio:disconnect:{connection_id}')
await session_manager.disconnect_from_local_session(connection_id)

View File

@@ -0,0 +1,242 @@
import asyncio
import json
import os
import pickle
from asyncio import Task
from typing import Any, Dict, Optional
import asyncpg
from asyncpg.exceptions import PostgresError
from asyncpg.pool import Pool
from google.cloud.sql.connector import Connector
from socketio.async_pubsub_manager import AsyncPubSubManager
def has_binary(obj: Any, to_json: bool = False) -> bool:
if not obj or not isinstance(obj, (dict, list, bytes, bytearray, memoryview)):
return False
if isinstance(obj, (bytes, bytearray, memoryview)):
return True
if isinstance(obj, list):
return any(has_binary(item) for item in obj)
if isinstance(obj, dict):
return any(has_binary(v) for v in obj.values())
if hasattr(obj, 'to_json') and callable(obj.to_json) and not to_json:
return has_binary(obj.to_json(), True)
return False
class AsyncPostgresAdapter(AsyncPubSubManager):
def __init__(
self,
channel: str = 'socketio',
table_name: str = 'socket_io_attachments',
payload_threshold: int = 8000,
cleanup_interval: int = 30000,
):
self.channel = channel
super().__init__(channel=channel)
self.table_name = table_name
self.payload_threshold = payload_threshold
self.cleanup_interval = cleanup_interval
self._cleanup_timer: Optional[Task] = None
self._client: Optional[asyncpg.Connection] = None
self.pool: Optional[Pool] = None
# Connection configs
self.db_host = os.environ.get('DB_HOST')
self.db_user = os.environ.get('DB_USER')
self.db_pass = os.environ.get('DB_PASS', '').strip()
self.db_name = os.environ.get('DB_NAME')
# GCP configs
self.gcp_instance = os.environ.get('GCP_DB_INSTANCE')
self.gcp_project = os.environ.get('GCP_PROJECT')
self.gcp_region = os.environ.get('GCP_REGION')
self.connector = Connector() if self.gcp_instance else None
async def setup(self):
if not self.pool:
self.pool = await self._create_pool()
await self._create_db()
await self._init_client()
async def _create_pool(self) -> Pool:
if self.gcp_instance:
return await asyncpg.create_pool(
lambda: self._get_gcp_connection(self.db_name)
)
else:
return await asyncpg.create_pool(
user=self.db_user,
password=self.db_pass,
host=self.db_host,
database=self.db_name,
)
async def _get_gcp_connection(
self, db_name: Optional[str] = None
) -> asyncpg.Connection:
instance_string = f'{self.gcp_project}:{self.gcp_region}:{self.gcp_instance}'
conn = await self.connector.connect_async(
instance_connection_string=instance_string,
driver='asyncpg',
user=self.db_user,
password=self.db_pass,
db=db_name or self.db_name,
)
return conn
async def _init_client(self) -> None:
if not self.pool:
raise RuntimeError('Pool not initialized')
try:
self._client = await self.pool.acquire()
await self._client.execute(f'LISTEN "{self.channel}"')
self._client.add_listener(self.channel, self._on_notification)
if not self._cleanup_timer:
self._schedule_cleanup()
except PostgresError as e:
self.logger.error(f'Error initializing client: {e}')
await asyncio.sleep(2)
await self._init_client()
def _schedule_cleanup(self) -> None:
async def cleanup() -> None:
if not self.pool:
return
try:
await self.pool.execute(
f"DELETE FROM {self.table_name} WHERE created_at < now() - interval '{self.cleanup_interval} milliseconds'"
)
except PostgresError as e:
self.logger.error(f'Cleanup error: {e}')
self._cleanup_timer = asyncio.create_task(cleanup())
await asyncio.sleep(self.cleanup_interval / 1000)
self._cleanup_timer = asyncio.create_task(cleanup())
async def _publish_with_attachment(self, data: Dict) -> None:
if not self.pool:
raise RuntimeError('Pool not initialized')
payload = pickle.dumps(data)
result = await self.pool.fetchrow(
f'INSERT INTO {self.table_name} (payload) VALUES ($1) RETURNING id', payload
)
if not result:
raise RuntimeError('Failed to insert payload')
notification = {
'uid': self.uid,
'type': data['type'],
'attachmentId': result['id'],
}
await self.pool.execute(
'SELECT pg_notify($1, $2)', self.channel, json.dumps(notification)
)
async def _publish(self, data: Dict) -> None:
if not self.pool:
raise RuntimeError('Pool not initialized')
try:
data['uid'] = self.uid
if has_binary(data) or len(json.dumps(data)) > self.payload_threshold:
await self._publish_with_attachment(data)
return
await self.pool.execute(
'SELECT pg_notify($1, $2)', self.channel, json.dumps(data)
)
except PostgresError as e:
self.logger.error(f'Publish error: {e}')
raise
async def _on_notification(self, conn, pid, channel, payload) -> None:
if not self.pool:
return
try:
data = json.loads(payload)
if data.get('uid') == self.uid:
return
if 'attachmentId' in data:
result = await self.pool.fetchrow(
f'SELECT payload FROM {self.table_name} WHERE id = $1',
data['attachmentId'],
)
if not result:
self.logger.error(f"Attachment {data['attachmentId']} not found")
return
data = pickle.loads(result['payload'])
await self._handle_message(data)
except Exception as e:
self.logger.error(f'Notification error: {e}')
async def close(self) -> None:
if self._cleanup_timer:
self._cleanup_timer.cancel()
if self._client and self.pool:
await self.pool.release(self._client)
if self.connector:
await self.connector.close_async()
await super().close()
async def _create_db(self) -> None:
try:
# Connect to default postgres DB first
if self.gcp_instance:
sys_conn = await self._get_gcp_connection('postgres')
else:
sys_conn = await asyncpg.connect(
user=self.db_user,
password=self.db_pass,
host=self.db_host,
database='postgres',
)
try:
# Create DB if needed
exists = await sys_conn.fetchval(
'SELECT 1 FROM pg_database WHERE datname = $1', self.db_name
)
if not exists:
await sys_conn.execute(f'CREATE DATABASE "{self.db_name}"')
finally:
await sys_conn.close()
# Create attachments table
if not self.pool:
raise RuntimeError('Pool not initialized')
await self.pool.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
id SERIAL PRIMARY KEY,
payload BYTEA NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
except Exception as e:
self.logger.error(f'Database creation error: {e}')
raise

View File

@@ -1,7 +1,9 @@
import asyncio
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from fastapi import WebSocket
import socketio
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
@@ -15,11 +17,8 @@ from openhands.storage.files import FileStore
class SessionManager:
config: AppConfig
file_store: FileStore
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
return Session(
sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
)
local_sessions_by_sid: dict[str, Session] = field(default_factory=dict)
local_sessions_by_connection_id: dict[str, Session] = field(default_factory=dict)
async def attach_to_conversation(self, sid: str) -> Conversation | None:
start_time = time.time()
@@ -35,3 +34,41 @@ class SessionManager:
async def detach_from_conversation(self, conversation: Conversation):
await conversation.disconnect()
async def init_or_join_local_session(self, sio: socketio.AsyncServer, sid: str, connection_id: str, data: dict):
""" If there is no local session running, initialize one """
session = self.local_sessions_by_sid.get(sid)
if not session:
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=sio, ws=None
)
session.connect(connection_id)
self.local_sessions_by_sid[sid] = session
self.local_sessions_by_connection_id[connection_id] = session
await session.initialize_agent(data)
else:
session.connect(connection_id)
self.local_sessions_by_connection_id[connection_id] = session
return session
def get_local_session(self, connection_id: str) -> Session:
return self.local_sessions_by_connection_id[connection_id]
async def disconnect_from_local_session(self, connection_id: str):
session = self.local_sessions_by_connection_id.pop(connection_id, None)
if not session:
# This can occur if the init action was never run.
logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
return
if session.disconnect(connection_id):
asyncio.create_task(self._check_and_close_session(session))
async def _check_and_close_session(self, session: Session):
# Once there have been no connections to a session for a reasonable period, we close it
try:
await asyncio.sleep(15)
finally:
# If the sleep was cancelled, we still want to close these
if not session.connection_ids:
session.close()
self.local_sessions_by_sid.pop(session.sid)

View File

@@ -1,6 +1,7 @@
import asyncio
import time
import socketio
from fastapi import WebSocket, WebSocketDisconnect
from openhands.controller.agent import Agent
@@ -23,22 +24,30 @@ from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.server.session.agent_session import AgentSession
from openhands.storage.files import FileStore
from openhands.utils.shutdown_listener import should_continue
from openhands.utils.async_utils import wait_all
class Session:
sid: str
websocket: WebSocket | None
sio: socketio.AsyncServer | None
connection_ids: set[str]
last_active_ts: int = 0
is_alive: bool = True
agent_session: AgentSession
loop: asyncio.AbstractEventLoop
def __init__(
self, sid: str, ws: WebSocket | None, config: AppConfig, file_store: FileStore
self,
sid: str,
ws: WebSocket | None,
config: AppConfig,
file_store: FileStore,
sio: socketio.AsyncServer | None,
):
self.sid = sid
self.websocket = ws
self.sio = sio
self.last_active_ts = int(time.time())
self.agent_session = AgentSession(
sid, file_store, status_callback=self.queue_status_message
@@ -47,36 +56,21 @@ class Session:
EventStreamSubscriber.SERVER, self.on_event, self.sid
)
self.config = config
self.connection_ids = set()
self.loop = asyncio.get_event_loop()
def connect(self, connection_id: str):
self.connection_ids.add(connection_id)
def disconnect(self, connection_id: str) -> bool:
self.connection_ids.remove(connection_id)
return not self.connection_ids
def close(self):
self.is_alive = False
try:
if self.websocket is not None:
self.websocket.close()
self.websocket = None
finally:
self.agent_session.close()
self.agent_session.close()
async def loop_recv(self):
try:
if self.websocket is None:
return
while should_continue():
try:
data = await self.websocket.receive_json()
except ValueError:
await self.send_error('Invalid JSON')
continue
await self.dispatch(data)
except WebSocketDisconnect:
logger.info('WebSocket disconnected, sid: %s', self.sid)
self.close()
except RuntimeError as e:
logger.exception('Error in loop_recv: %s', e)
self.close()
async def _initialize_agent(self, data: dict):
async def initialize_agent(self, data: dict):
self.agent_session.event_stream.add_event(
ChangeAgentStateAction(AgentState.LOADING), EventSource.ENVIRONMENT
)
@@ -164,7 +158,7 @@ class Session:
async def dispatch(self, data: dict):
action = data.get('action', '')
if action == ActionType.INIT:
await self._initialize_agent(data)
await self.initialize_agent(data)
return
event = event_from_dict(data.copy())
# This checks if the model supports images
@@ -185,9 +179,15 @@ class Session:
async def send(self, data: dict[str, object]) -> bool:
try:
if self.websocket is None or not self.is_alive:
if not self.is_alive:
return False
await self.websocket.send_json(data)
if self.websocket:
await self.websocket.send_json(data)
if self.sio:
await wait_all(
self.sio.emit('oh_event', data, to=connection_id)
for connection_id in self.connection_ids
)
await asyncio.sleep(0.001) # This flushes the data to the client
self.last_active_ts = int(time.time())
return True

1100
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -64,6 +64,9 @@ modal = "^0.64.145"
runloop-api-client = "0.7.0"
pygithub = "^2.5.0"
openhands-aci = "^0.1.0"
python-socketio = "^5.11.4"
asyncpg = "^0.30.0"
cloud-sql-python-connector = "^1.13.0"
[tool.poetry.group.llama-index.dependencies]
llama-index = "*"
@@ -95,6 +98,7 @@ reportlab = "*"
[tool.coverage.run]
concurrency = ["gevent"]
[tool.poetry.group.runtime.dependencies]
jupyterlab = "*"
notebook = "*"
@@ -125,6 +129,7 @@ ignore = ["D1"]
[tool.ruff.lint.pydocstyle]
convention = "google"
[tool.poetry.group.evaluation.dependencies]
streamlit = "*"
whatthepatch = "*"

View File

@@ -0,0 +1,188 @@
from unittest.mock import MagicMock
import pytest
from openhands.controller.agent_controller import AgentController
from openhands.events import EventSource
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation
@pytest.fixture
def mock_event_stream():
stream = MagicMock()
# Mock get_events to return an empty list by default
stream.get_events.return_value = []
return stream
@pytest.fixture
def mock_agent():
agent = MagicMock()
agent.llm = MagicMock()
agent.llm.config = MagicMock()
return agent
class TestTruncation:
def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)
# Create a sequence of events with IDs
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1
cmd1 = CmdRunAction(command='ls')
cmd1._id = 2
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=2)
obs1._id = 3
obs1._cause = 2
cmd2 = CmdRunAction(command='pwd')
cmd2._id = 4
obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=4)
obs2._id = 5
obs2._cause = 4
events = [first_msg, cmd1, obs1, cmd2, obs2]
# Apply truncation
truncated = controller._apply_conversation_window(events)
# Should keep first user message and roughly half of other events
assert (
len(truncated) >= 3
) # First message + at least one action-observation pair
assert truncated[0] == first_msg # First message always preserved
assert controller.state.start_id == first_msg._id
assert controller.state.truncation_id is not None
# Verify pairs aren't split
for i, event in enumerate(truncated[1:]):
if isinstance(event, CmdOutputObservation):
assert any(e._id == event._cause for e in truncated[: i + 1])
def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)
# Setup initial history with IDs
first_msg = MessageAction(content='Start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1
# Add agent question
agent_msg = MessageAction(
content='What task would you like me to perform?', wait_for_response=True
)
agent_msg._source = EventSource.AGENT
agent_msg._id = 2
# Add user response
user_response = MessageAction(
content='Please list all files and show me current directory',
wait_for_response=False,
)
user_response._source = EventSource.USER
user_response._id = 3
cmd1 = CmdRunAction(command='ls')
cmd1._id = 4
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
obs1._id = 5
obs1._cause = 4
# Update mock event stream to include new messages
mock_event_stream.get_events.return_value = [
first_msg,
agent_msg,
user_response,
cmd1,
obs1,
]
controller.state.history = [first_msg, agent_msg, user_response, cmd1, obs1]
original_history_len = len(controller.state.history)
# Simulate ContextWindowExceededError and truncation
controller.state.history = controller._apply_conversation_window(
controller.state.history
)
# Verify truncation occurred
assert len(controller.state.history) < original_history_len
assert controller.state.start_id == first_msg._id
assert controller.state.truncation_id is not None
assert controller.state.truncation_id > controller.state.start_id
def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)
# Create events with IDs
first_msg = MessageAction(content='Start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1
events = [first_msg]
for i in range(5):
cmd = CmdRunAction(command=f'cmd{i}')
cmd._id = i + 2
obs = CmdOutputObservation(
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
)
obs._cause = cmd._id
events.extend([cmd, obs])
# Set up initial history
controller.state.history = events.copy()
# Force truncation
controller.state.history = controller._apply_conversation_window(
controller.state.history
)
# Save state
saved_start_id = controller.state.start_id
saved_truncation_id = controller.state.truncation_id
saved_history_len = len(controller.state.history)
# Set up mock event stream for new controller
mock_event_stream.get_events.return_value = controller.state.history
# Create new controller with saved state
new_controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)
new_controller.state.start_id = saved_start_id
new_controller.state.truncation_id = saved_truncation_id
new_controller.state.history = mock_event_stream.get_events()
# Verify restoration
assert len(new_controller.state.history) == saved_history_len
assert new_controller.state.history[0] == first_msg
assert new_controller.state.start_id == saved_start_id