Compare commits

...

31 Commits

Author SHA1 Message Date
Robert Brennan dbda7bcc82 add cloud sql 2024-11-15 13:55:50 -05:00
Robert Brennan 2f275dce9e new pg impl 2024-11-15 13:55:04 -05:00
Robert Brennan 340b9ead40 debug logs 2024-11-15 13:46:01 -05:00
Robert Brennan 1fba72043c create db if not exists 2024-11-15 13:38:06 -05:00
Robert Brennan bcab72c981 add logging 2024-11-15 13:26:36 -05:00
Robert Brennan 5949f5bdb6 fix import 2024-11-15 13:19:32 -05:00
Robert Brennan 20d4a0cce2 add asyncpg 2024-11-15 13:18:28 -05:00
Robert Brennan 37ee90cf6f add gcp support 2024-11-15 13:17:46 -05:00
Robert Brennan 69c646ebf9 add pg socket 2024-11-15 12:56:25 -05:00
Robert Brennan 71e57dc069 better pg creds 2024-11-15 12:52:32 -05:00
Robert Brennan 3f03fb69c8 add pg socket 2024-11-15 12:49:38 -05:00
Tim O'Farrell 551a5a72e1 Client side error fix 2024-11-15 09:55:39 -07:00
Tim O'Farrell 2aa6bab85f Fix initial startup 2024-11-15 09:27:44 -07:00
Tim O'Farrell 3f0e619997 WIP 2024-11-15 08:52:12 -07:00
Tim O'Farrell 511ce8cc3a Fixed closing 2024-11-15 08:29:51 -07:00
Tim O'Farrell a1b53a9498 Culled dead code 2024-11-15 07:39:30 -07:00
Tim O'Farrell 767878563e WIP 2024-11-15 07:15:21 -07:00
Tim O'Farrell b2bd23ac86 Fix for test errors - client side seems to work 2024-11-14 18:46:13 -07:00
Tim O'Farrell f0487e6818 Server side code is working - now on the client side 2024-11-14 18:37:33 -07:00
Tim O'Farrell b85fbf39fd Tokens in init event 2024-11-14 15:45:57 -07:00
Tim O'Farrell 8b4d263319 Merge branch 'main' into feat-socket-io 2024-11-14 07:25:01 -07:00
Tim O'Farrell ff7783ec81 Removed retry - socketio does this anyway 2024-11-14 07:24:37 -07:00
Tim O'Farrell d4b20c284d Now using socket io 2024-11-13 16:38:03 -07:00
Tim O'Farrell 2501cce470 Merge branch 'main' into feat-socket-io 2024-11-13 16:20:49 -07:00
Tim O'Farrell ca9aefd7b2 WIP - frontend is still janky 2024-11-13 16:18:39 -07:00
Tim O'Farrell 92586a090d Server side init is working 2024-11-13 14:41:57 -07:00
Tim O'Farrell 18e774dd8a Fix for merge error 2024-11-13 13:37:07 -07:00
Tim O'Farrell 3a59e037fb Merge branch 'main' into feat-socket-io 2024-11-13 12:56:45 -07:00
Tim O'Farrell 5973c0c269 Working through socket issues 2024-11-13 12:55:51 -07:00
Tim O'Farrell 2fa8c4e14d Initial stab at reimplementing session management 2024-11-13 12:31:02 -07:00
Tim O'Farrell c2a5fbceb7 Server side work for socketio.
Still needs testing and integration of proper client (Rather than the mock I've been using)
2024-11-13 11:45:43 -07:00
10 changed files with 1248 additions and 695 deletions
+80
View File
@@ -40,6 +40,7 @@
"react-textarea-autosize": "^8.5.4",
"remark-gfm": "^4.0.0",
"sirv-cli": "^3.0.0",
"socket.io-client": "^4.8.1",
"tailwind-merge": "^2.5.4",
"vite": "^5.4.9",
"web-vitals": "^3.5.2",
@@ -5576,6 +5577,11 @@
"integrity": "sha512-+Fj43pSMwJs4KRrH/938Uf+uAELIgVBmQzg/q1YG10djyfA3TnrU8N8XzqCh/okZdszqBQTZf96idMfE5lnwTA==",
"dev": true
},
"node_modules/@socket.io/component-emitter": {
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/@socket.io/component-emitter/-/component-emitter-3.1.2.tgz",
"integrity": "sha512-9BCxFwvbGg/RsZK9tjXd8s4UcwR0MWeFQ1XEKIQVVvAGJyINdrqKMcTRyLoK8Rse1GjzLV9cwjWV1olXRWEXVA=="
},
"node_modules/@svgr/babel-plugin-add-jsx-attribute": {
"version": "8.0.0",
"resolved": "https://registry.npmjs.org/@svgr/babel-plugin-add-jsx-attribute/-/babel-plugin-add-jsx-attribute-8.0.0.tgz",
@@ -8469,6 +8475,46 @@
"once": "^1.4.0"
}
},
"node_modules/engine.io-client": {
"version": "6.6.2",
"resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.6.2.tgz",
"integrity": "sha512-TAr+NKeoVTjEVW8P3iHguO1LO6RlUz9O5Y8o7EY0fU+gY1NYqas7NN3slpFtbXEsLMHk0h90fJMfKjRkQ0qUIw==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1",
"engine.io-parser": "~5.2.1",
"ws": "~8.17.1",
"xmlhttprequest-ssl": "~2.1.1"
}
},
"node_modules/engine.io-client/node_modules/ws": {
"version": "8.17.1",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz",
"integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==",
"engines": {
"node": ">=10.0.0"
},
"peerDependencies": {
"bufferutil": "^4.0.1",
"utf-8-validate": ">=5.0.2"
},
"peerDependenciesMeta": {
"bufferutil": {
"optional": true
},
"utf-8-validate": {
"optional": true
}
}
},
"node_modules/engine.io-parser": {
"version": "5.2.3",
"resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.3.tgz",
"integrity": "sha512-HqD3yTBfnBxIrbnM1DoD6Pcq8NECnh8d4As1Qgh0z5Gg3jRRIqijury0CL3ghu/edArpUYiYqQiDUQBIs4np3Q==",
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/entities": {
"version": "4.5.0",
"resolved": "https://registry.npmjs.org/entities/-/entities-4.5.0.tgz",
@@ -22587,6 +22633,32 @@
"tslib": "^2.0.3"
}
},
"node_modules/socket.io-client": {
"version": "4.8.1",
"resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.8.1.tgz",
"integrity": "sha512-hJVXfu3E28NmzGk8o1sHhN3om52tRvwYeidbj7xKy2eIIse5IoKX3USlS6Tqt3BHAtflLIkCQBkzVrEEfWUyYQ==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.2",
"engine.io-client": "~6.6.1",
"socket.io-parser": "~4.2.4"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/socket.io-parser": {
"version": "4.2.4",
"resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.2.4.tgz",
"integrity": "sha512-/GbIKmo8ioc+NIWIhwdecY0ge+qVBSMdgxGygevmdHj24bsfgtCmcUUcQ5ZzcylGFHsN3k4HB4Cgkl96KVnuew==",
"dependencies": {
"@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1"
},
"engines": {
"node": ">=10.0.0"
}
},
"node_modules/source-map": {
"version": "0.7.4",
"resolved": "https://registry.npmjs.org/source-map/-/source-map-0.7.4.tgz",
@@ -25317,6 +25389,14 @@
"integrity": "sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==",
"dev": true
},
"node_modules/xmlhttprequest-ssl": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/xmlhttprequest-ssl/-/xmlhttprequest-ssl-2.1.2.tgz",
"integrity": "sha512-TEU+nJVUUnA4CYJFLvK5X9AOeH4KvDvhIfm0vV1GaQRtchnG0hgK5p8hw/xjv8cunWYCsiPCSDzObPyhEwq3KQ==",
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/xtend": {
"version": "4.0.2",
"resolved": "https://registry.npmjs.org/xtend/-/xtend-4.0.2.tgz",
+2 -1
View File
@@ -39,6 +39,7 @@
"react-textarea-autosize": "^8.5.4",
"remark-gfm": "^4.0.0",
"sirv-cli": "^3.0.0",
"socket.io-client": "^4.8.1",
"tailwind-merge": "^2.5.4",
"vite": "^5.4.9",
"web-vitals": "^3.5.2",
@@ -120,4 +121,4 @@
"public"
]
}
}
}
+79 -68
View File
@@ -1,13 +1,11 @@
import posthog from "posthog-js";
import React from "react";
import { io, Socket } from "socket.io-client";
import { Settings } from "#/services/settings";
import ActionType from "#/types/ActionType";
import EventLogger from "#/utils/event-logger";
import AgentState from "#/types/AgentState";
import { handleAssistantMessage } from "#/services/actions";
const RECONNECT_RETRIES = 5;
export enum WsClientProviderStatus {
STOPPED,
OPENING,
@@ -43,38 +41,46 @@ export function WsClientProvider({
settings,
children,
}: React.PropsWithChildren<WsClientProviderProps>) {
const wsRef = React.useRef<WebSocket | null>(null);
const sioRef = React.useRef<Socket | null>(null);
const tokenRef = React.useRef<string | null>(token);
const ghTokenRef = React.useRef<string | null>(ghToken);
const closeRef = React.useRef<ReturnType<typeof setTimeout> | null>(null);
const disconnectRef = React.useRef<ReturnType<typeof setTimeout> | null>(
null,
);
const [status, setStatus] = React.useState(WsClientProviderStatus.STOPPED);
const [events, setEvents] = React.useState<Record<string, unknown>[]>([]);
const [retryCount, setRetryCount] = React.useState(RECONNECT_RETRIES);
function send(event: Record<string, unknown>) {
if (!wsRef.current) {
if (!sioRef.current) {
EventLogger.error("WebSocket is not connected.");
return;
}
wsRef.current.send(JSON.stringify(event));
sioRef.current.emit("oh_action", event);
}
function handleOpen() {
setRetryCount(RECONNECT_RETRIES);
function handleConnect() {
setStatus(WsClientProviderStatus.OPENING);
const initEvent = {
const initEvent: Record<string, unknown> = {
action: ActionType.INIT,
args: settings,
};
if (token) {
initEvent.token = token;
}
if (ghToken) {
initEvent.github_token = ghToken;
}
if (events.length) {
// Wrong. Events is out of sync here...
initEvent.latest_event_id = `${events[events.length - 1].id}`;
}
send(initEvent);
}
function handleMessage(messageEvent: MessageEvent) {
const event = JSON.parse(messageEvent.data);
function handleMessage(event: Record<string, unknown>) {
setEvents((prevEvents) => [...prevEvents, event]);
if (event.extras?.agent_state === AgentState.INIT) {
setStatus(WsClientProviderStatus.ACTIVE);
}
const extras = event.extras as Record<string, unknown>;
if (
status !== WsClientProviderStatus.ACTIVE &&
event?.observation === "error"
@@ -82,93 +88,98 @@ export function WsClientProvider({
setStatus(WsClientProviderStatus.ERROR);
}
handleAssistantMessage(event);
}
function handleClose() {
if (retryCount) {
setTimeout(() => {
setRetryCount(retryCount - 1);
}, 1000);
if (event.token) {
setStatus(WsClientProviderStatus.ACTIVE);
} else {
setStatus(WsClientProviderStatus.STOPPED);
setEvents([]);
handleAssistantMessage(event);
}
wsRef.current = null;
}
function handleError(event: Event) {
function handleDisconnect() {
setStatus(WsClientProviderStatus.STOPPED);
// setEvents([]);
// sioRef.current = null;
}
function handleError() {
posthog.capture("socket_error");
EventLogger.event(event, "SOCKET ERROR");
setStatus(WsClientProviderStatus.ERROR);
sioRef.current?.disconnect();
}
// Connect websocket
React.useEffect(() => {
let ws = wsRef.current;
let sio = sioRef.current;
// If disabled close any existing websockets...
if (!enabled || !retryCount) {
if (ws) {
ws.close();
// If disabled disconnect any existing websockets...
if (!enabled) {
if (sio) {
sio.disconnect();
}
wsRef.current = null;
return () => {};
}
// If there is no websocket or the tokens have changed or the current websocket is closed,
// If there is no websocket or the tokens have changed or the current websocket is disconnected,
// create a new one
if (
!ws ||
!sio ||
(tokenRef.current && token !== tokenRef.current) ||
ghToken !== ghTokenRef.current ||
ws.readyState === WebSocket.CLOSED ||
ws.readyState === WebSocket.CLOSING
ghToken !== ghTokenRef.current
) {
ws?.close();
sio?.disconnect();
const baseUrl =
import.meta.env.VITE_BACKEND_BASE_URL || window?.location.host;
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
let wsUrl = `${protocol}//${baseUrl}/ws`;
if (events.length) {
wsUrl += `?latest_event_id=${events[events.length - 1].id}`;
}
ws = new WebSocket(wsUrl, [
"openhands",
token || "NO_JWT",
ghToken || "NO_GITHUB",
]);
sio = io(baseUrl, {
transports: ["websocket"],
// extraHeaders: {
// Testy: "TESTER"
// },
// We force a new connection, because the headers may have changed.
// forceNew: true,
// Had to do this for now because reconnection actually starts a new session,
// which we don't want - The reconnect has the same headers as the original
// which don't include the original session id
// reconnection: false,
// reconnectionDelay: 1000,
// reconnectionDelayMax : 5000,
// reconnectionAttempts: 5
});
}
ws.addEventListener("open", handleOpen);
ws.addEventListener("message", handleMessage);
ws.addEventListener("error", handleError);
ws.addEventListener("close", handleClose);
wsRef.current = ws;
sio.on("connect", handleConnect);
sio.on("oh_event", handleMessage);
sio.on("connect_error", handleError);
sio.on("connect_failed", handleError);
sio.on("disconnect", handleDisconnect);
sioRef.current = sio;
tokenRef.current = token;
ghTokenRef.current = ghToken;
return () => {
ws.removeEventListener("open", handleOpen);
ws.removeEventListener("message", handleMessage);
ws.removeEventListener("error", handleError);
ws.removeEventListener("close", handleClose);
sio.off("connect", handleConnect);
sio.off("oh_event", handleMessage);
sio.off("connect_error", handleError);
sio.off("connect_failed", handleError);
sio.off("disconnect", handleDisconnect);
};
}, [enabled, token, ghToken, retryCount]);
}, [enabled, token, ghToken, events]);
// Strict mode mounts and unmounts each component twice, so we have to wait in the destructor
// before actually closing the socket and cancel the operation if the component gets remounted.
// before actually disconnecting the socket and cancel the operation if the component gets remounted.
React.useEffect(() => {
const timeout = closeRef.current;
const timeout = disconnectRef.current;
if (timeout != null) {
clearTimeout(timeout);
}
return () => {
closeRef.current = setTimeout(() => {
const ws = wsRef.current;
if (ws) {
ws.removeEventListener("close", handleClose);
ws.close();
disconnectRef.current = setTimeout(() => {
const sio = sioRef.current;
if (sio) {
sio.off("disconnect", handleDisconnect);
sio.disconnect();
}
}, 100);
};
+7
View File
@@ -82,6 +82,13 @@ export default defineConfig(({ mode }) => {
changeOrigin: true,
secure: !INSECURE_SKIP_VERIFY,
},
"/socket.io": {
target: WS_URL,
ws: true,
changeOrigin: true,
secure: !INSECURE_SKIP_VERIFY,
//rewriteWsOrigin: true,
}
},
},
ssr: {
+136 -121
View File
@@ -1,16 +1,16 @@
import asyncio
import os
import re
import tempfile
import time
import uuid
import warnings
import jwt
import requests
import socketio
from pathspec import PathSpec
from pathspec.patterns import GitWildMatchPattern
from openhands.core.schema.action import ActionType
from openhands.security.options import SecurityAnalyzers
from openhands.server.data_models.feedback import FeedbackDataModel, store_feedback
from openhands.server.github import (
@@ -19,8 +19,9 @@ from openhands.server.github import (
UserVerifier,
authenticate_github_user,
)
from openhands.server.pg_socket import AsyncPostgresAdapter
from openhands.storage import get_file_store
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.async_utils import call_async_from_sync, call_sync_from_async
with warnings.catch_warnings():
warnings.simplefilter('ignore')
@@ -33,7 +34,6 @@ from fastapi import (
HTTPException,
Request,
UploadFile,
WebSocket,
status,
)
from fastapi.responses import FileResponse, JSONResponse
@@ -52,7 +52,6 @@ from openhands.events.action import (
NullAction,
)
from openhands.events.observation import (
AgentStateChangedObservation,
ErrorObservation,
FileReadObservation,
FileWriteObservation,
@@ -250,122 +249,6 @@ async def attach_session(request: Request, call_next):
return response
@app.websocket('/ws')
async def websocket_endpoint(websocket: WebSocket):
"""WebSocket endpoint for receiving events from the client (i.e., the browser).
Once connected, the client can send various actions:
- Initialize the agent:
session management, and event streaming.
```json
{"action": "initialize", "args": {"LLM_MODEL": "ollama/llama3", "AGENT": "CodeActAgent", "LANGUAGE": "en", "LLM_API_KEY": "ollama"}}
Args:
```
websocket (WebSocket): The WebSocket connection object.
- Start a new development task:
```json
{"action": "start", "args": {"task": "write a bash script that prints hello"}}
```
- Send a message:
```json
{"action": "message", "args": {"content": "Hello, how are you?", "image_urls": ["base64_url1", "base64_url2"]}}
```
- Write contents to a file:
```json
{"action": "write", "args": {"path": "./greetings.txt", "content": "Hello, OpenHands?"}}
```
- Read the contents of a file:
```json
{"action": "read", "args": {"path": "./greetings.txt"}}
```
- Run a command:
```json
{"action": "run", "args": {"command": "ls -l", "thought": "", "confirmation_state": "confirmed"}}
```
- Run an IPython command:
```json
{"action": "run_ipython", "args": {"command": "print('Hello, IPython!')"}}
```
- Open a web page:
```json
{"action": "browse", "args": {"url": "https://arxiv.org/html/2402.01030v2"}}
```
- Add a task to the root_task:
```json
{"action": "add_task", "args": {"task": "Implement feature X"}}
```
- Update a task in the root_task:
```json
{"action": "modify_task", "args": {"id": "0", "state": "in_progress", "thought": ""}}
```
- Change the agent's state:
```json
{"action": "change_agent_state", "args": {"state": "paused"}}
```
- Finish the task:
```json
{"action": "finish", "args": {}}
```
"""
# Get protocols from Sec-WebSocket-Protocol header
protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
# The first protocol should be our real protocol (e.g. 'openhands')
# The second protocol should contain our auth token
if len(protocols) < 3:
logger.error('Expected 3 websocket protocols, got %d', len(protocols))
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
real_protocol = protocols[0]
jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
if not await authenticate_github_user(github_token):
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10)
if jwt_token:
sid = get_sid_from_token(jwt_token, config.jwt_secret)
if sid == '':
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
await websocket.close()
return
else:
sid = str(uuid.uuid4())
jwt_token = sign_token({'sid': sid}, config.jwt_secret)
logger.info(f'New session: {sid}')
session = session_manager.add_or_restart_session(sid, websocket)
await websocket.send_json({'token': jwt_token, 'status': 'ok'})
latest_event_id = -1
if websocket.query_params.get('latest_event_id'):
latest_event_id = int(websocket.query_params.get('latest_event_id'))
async_stream = AsyncEventStreamWrapper(
session.agent_session.event_stream, latest_event_id + 1
)
async for event in async_stream:
if isinstance(
event,
(
NullAction,
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
),
):
continue
await websocket.send_json(event_to_dict(event))
await session.loop_recv()
@app.get('/api/options/models')
async def get_litellm_models() -> list[str]:
"""
@@ -930,3 +813,135 @@ class SPAStaticFiles(StaticFiles):
app.mount('/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist')
use_manager = os.getenv('DB_HOST') or os.getenv('GCP_DB_INSTANCE')
manager = AsyncPostgresAdapter() if use_manager else None
if manager:
call_async_from_sync(manager.setup, 10)
sio = socketio.AsyncServer(
async_mode='asgi', cors_allowed_origins='*', client_manager=manager
)
app = socketio.ASGIApp(sio, other_asgi_app=app)
@sio.event
async def connect(connection_id: str, environ):
logger.info(f'sio:connect: {connection_id}')
@sio.event
async def oh_action(connection_id: str, data: dict):
"""WebSocket endpoint for receiving events from the client (i.e., the browser).
Once connected, the client can send various actions:
- Initialize the agent:
session management, and event streaming.
```json
{"action": "initialize", "args": {"LLM_MODEL": "ollama/llama3", "AGENT": "CodeActAgent", "LANGUAGE": "en", "LLM_API_KEY": "ollama"}}
Args:
```
websocket (WebSocket): The WebSocket connection object.
- Start a new development task:
```json
{"action": "start", "args": {"task": "write a bash script that prints hello"}}
```
- Send a message:
```json
{"action": "message", "args": {"content": "Hello, how are you?", "image_urls": ["base64_url1", "base64_url2"]}}
```
- Write contents to a file:
```json
{"action": "write", "args": {"path": "./greetings.txt", "content": "Hello, OpenHands?"}}
```
- Read the contents of a file:
```json
{"action": "read", "args": {"path": "./greetings.txt"}}
```
- Run a command:
```json
{"action": "run", "args": {"command": "ls -l", "thought": "", "confirmation_state": "confirmed"}}
```
- Run an IPython command:
```json
{"action": "run_ipython", "args": {"command": "print('Hello, IPython!')"}}
```
- Open a web page:
```json
{"action": "browse", "args": {"url": "https://arxiv.org/html/2402.01030v2"}}
```
- Add a task to the root_task:
```json
{"action": "add_task", "args": {"task": "Implement feature X"}}
```
- Update a task in the root_task:
```json
{"action": "modify_task", "args": {"id": "0", "state": "in_progress", "thought": ""}}
```
- Change the agent's state:
```json
{"action": "change_agent_state", "args": {"state": "paused"}}
```
- Finish the task:
```json
{"action": "finish", "args": {}}
```
"""
# If it's an init, we do it here.
action = data.get('action', '')
if action == ActionType.INIT:
await init_connection(connection_id, data)
return
logger.info(f'sio:oh_action:{connection_id}')
session = session_manager.get_local_session(connection_id)
await session.dispatch(data)
async def init_connection(connection_id: str, data: dict):
gh_token = data.pop('gh_token', None)
if not await authenticate_github_user(gh_token):
raise RuntimeError(status.WS_1008_POLICY_VIOLATION)
token = data.pop('token', None)
if token:
sid = get_sid_from_token(token, config.jwt_secret)
if sid == '':
await sio.send({'error': 'Invalid token', 'error_code': 401})
return
logger.info(f'Existing session: {sid}')
else:
sid = connection_id
logger.info(f'New session: {sid}')
token = sign_token({'sid': sid}, config.jwt_secret)
await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id)
latest_event_id = int(data.pop('latest_event_id', -1))
# The session in question should exist, but may not actually be running locally...
session = await session_manager.init_or_join_local_session(
sio, sid, connection_id, data
)
# Send events
async_stream = AsyncEventStreamWrapper(
session.agent_session.event_stream, latest_event_id + 1
)
async for event in async_stream:
if isinstance(
event,
(
NullAction,
NullObservation,
ChangeAgentStateAction,
),
):
continue
await sio.emit('oh_event', event_to_dict(event), to=connection_id)
@sio.event
async def disconnect(connection_id: str):
logger.info(f'sio:disconnect:{connection_id}')
await session_manager.disconnect_from_local_session(connection_id)
+242
View File
@@ -0,0 +1,242 @@
import asyncio
import json
import os
import pickle
from asyncio import Task
from typing import Any, Dict, Optional
import asyncpg
from asyncpg.exceptions import PostgresError
from asyncpg.pool import Pool
from google.cloud.sql.connector import Connector
from socketio.async_pubsub_manager import AsyncPubSubManager
def has_binary(obj: Any, to_json: bool = False) -> bool:
if not obj or not isinstance(obj, (dict, list, bytes, bytearray, memoryview)):
return False
if isinstance(obj, (bytes, bytearray, memoryview)):
return True
if isinstance(obj, list):
return any(has_binary(item) for item in obj)
if isinstance(obj, dict):
return any(has_binary(v) for v in obj.values())
if hasattr(obj, 'to_json') and callable(obj.to_json) and not to_json:
return has_binary(obj.to_json(), True)
return False
class AsyncPostgresAdapter(AsyncPubSubManager):
def __init__(
self,
channel: str = 'socketio',
table_name: str = 'socket_io_attachments',
payload_threshold: int = 8000,
cleanup_interval: int = 30000,
):
self.channel = channel
super().__init__(channel=channel)
self.table_name = table_name
self.payload_threshold = payload_threshold
self.cleanup_interval = cleanup_interval
self._cleanup_timer: Optional[Task] = None
self._client: Optional[asyncpg.Connection] = None
self.pool: Optional[Pool] = None
# Connection configs
self.db_host = os.environ.get('DB_HOST')
self.db_user = os.environ.get('DB_USER')
self.db_pass = os.environ.get('DB_PASS', '').strip()
self.db_name = os.environ.get('DB_NAME')
# GCP configs
self.gcp_instance = os.environ.get('GCP_DB_INSTANCE')
self.gcp_project = os.environ.get('GCP_PROJECT')
self.gcp_region = os.environ.get('GCP_REGION')
self.connector = Connector() if self.gcp_instance else None
async def setup(self):
if not self.pool:
self.pool = await self._create_pool()
await self._create_db()
await self._init_client()
async def _create_pool(self) -> Pool:
if self.gcp_instance:
return await asyncpg.create_pool(
lambda: self._get_gcp_connection(self.db_name)
)
else:
return await asyncpg.create_pool(
user=self.db_user,
password=self.db_pass,
host=self.db_host,
database=self.db_name,
)
async def _get_gcp_connection(
self, db_name: Optional[str] = None
) -> asyncpg.Connection:
instance_string = f'{self.gcp_project}:{self.gcp_region}:{self.gcp_instance}'
conn = await self.connector.connect_async(
instance_connection_string=instance_string,
driver='asyncpg',
user=self.db_user,
password=self.db_pass,
db=db_name or self.db_name,
)
return conn
async def _init_client(self) -> None:
if not self.pool:
raise RuntimeError('Pool not initialized')
try:
self._client = await self.pool.acquire()
await self._client.execute(f'LISTEN "{self.channel}"')
self._client.add_listener(self.channel, self._on_notification)
if not self._cleanup_timer:
self._schedule_cleanup()
except PostgresError as e:
self.logger.error(f'Error initializing client: {e}')
await asyncio.sleep(2)
await self._init_client()
def _schedule_cleanup(self) -> None:
async def cleanup() -> None:
if not self.pool:
return
try:
await self.pool.execute(
f"DELETE FROM {self.table_name} WHERE created_at < now() - interval '{self.cleanup_interval} milliseconds'"
)
except PostgresError as e:
self.logger.error(f'Cleanup error: {e}')
self._cleanup_timer = asyncio.create_task(cleanup())
await asyncio.sleep(self.cleanup_interval / 1000)
self._cleanup_timer = asyncio.create_task(cleanup())
async def _publish_with_attachment(self, data: Dict) -> None:
if not self.pool:
raise RuntimeError('Pool not initialized')
payload = pickle.dumps(data)
result = await self.pool.fetchrow(
f'INSERT INTO {self.table_name} (payload) VALUES ($1) RETURNING id', payload
)
if not result:
raise RuntimeError('Failed to insert payload')
notification = {
'uid': self.uid,
'type': data['type'],
'attachmentId': result['id'],
}
await self.pool.execute(
'SELECT pg_notify($1, $2)', self.channel, json.dumps(notification)
)
async def _publish(self, data: Dict) -> None:
if not self.pool:
raise RuntimeError('Pool not initialized')
try:
data['uid'] = self.uid
if has_binary(data) or len(json.dumps(data)) > self.payload_threshold:
await self._publish_with_attachment(data)
return
await self.pool.execute(
'SELECT pg_notify($1, $2)', self.channel, json.dumps(data)
)
except PostgresError as e:
self.logger.error(f'Publish error: {e}')
raise
async def _on_notification(self, conn, pid, channel, payload) -> None:
if not self.pool:
return
try:
data = json.loads(payload)
if data.get('uid') == self.uid:
return
if 'attachmentId' in data:
result = await self.pool.fetchrow(
f'SELECT payload FROM {self.table_name} WHERE id = $1',
data['attachmentId'],
)
if not result:
self.logger.error(f"Attachment {data['attachmentId']} not found")
return
data = pickle.loads(result['payload'])
await self._handle_message(data)
except Exception as e:
self.logger.error(f'Notification error: {e}')
async def close(self) -> None:
if self._cleanup_timer:
self._cleanup_timer.cancel()
if self._client and self.pool:
await self.pool.release(self._client)
if self.connector:
await self.connector.close_async()
await super().close()
async def _create_db(self) -> None:
try:
# Connect to default postgres DB first
if self.gcp_instance:
sys_conn = await self._get_gcp_connection('postgres')
else:
sys_conn = await asyncpg.connect(
user=self.db_user,
password=self.db_pass,
host=self.db_host,
database='postgres',
)
try:
# Create DB if needed
exists = await sys_conn.fetchval(
'SELECT 1 FROM pg_database WHERE datname = $1', self.db_name
)
if not exists:
await sys_conn.execute(f'CREATE DATABASE "{self.db_name}"')
finally:
await sys_conn.close()
# Create attachments table
if not self.pool:
raise RuntimeError('Pool not initialized')
await self.pool.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
id SERIAL PRIMARY KEY,
payload BYTEA NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
except Exception as e:
self.logger.error(f'Database creation error: {e}')
raise
+43 -6
View File
@@ -1,7 +1,9 @@
import asyncio
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from fastapi import WebSocket
import socketio
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
@@ -15,11 +17,8 @@ from openhands.storage.files import FileStore
class SessionManager:
config: AppConfig
file_store: FileStore
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
return Session(
sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
)
local_sessions_by_sid: dict[str, Session] = field(default_factory=dict)
local_sessions_by_connection_id: dict[str, Session] = field(default_factory=dict)
async def attach_to_conversation(self, sid: str) -> Conversation | None:
start_time = time.time()
@@ -35,3 +34,41 @@ class SessionManager:
async def detach_from_conversation(self, conversation: Conversation):
await conversation.disconnect()
async def init_or_join_local_session(self, sio: socketio.AsyncServer, sid: str, connection_id: str, data: dict):
""" If there is no local session running, initialize one """
session = self.local_sessions_by_sid.get(sid)
if not session:
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=sio, ws=None
)
session.connect(connection_id)
self.local_sessions_by_sid[sid] = session
self.local_sessions_by_connection_id[connection_id] = session
await session.initialize_agent(data)
else:
session.connect(connection_id)
self.local_sessions_by_connection_id[connection_id] = session
return session
def get_local_session(self, connection_id: str) -> Session:
return self.local_sessions_by_connection_id[connection_id]
async def disconnect_from_local_session(self, connection_id: str):
session = self.local_sessions_by_connection_id.pop(connection_id, None)
if not session:
# This can occur if the init action was never run.
logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
return
if session.disconnect(connection_id):
asyncio.create_task(self._check_and_close_session(session))
async def _check_and_close_session(self, session: Session):
# Once there have been no connections to a session for a reasonable period, we close it
try:
await asyncio.sleep(15)
finally:
# If the sleep was cancelled, we still want to close these
if not session.connection_ids:
session.close()
self.local_sessions_by_sid.pop(session.sid)
+29 -24
View File
@@ -1,6 +1,7 @@
import asyncio
import time
import socketio
from fastapi import WebSocket, WebSocketDisconnect
from openhands.controller.agent import Agent
@@ -23,22 +24,30 @@ from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.server.session.agent_session import AgentSession
from openhands.storage.files import FileStore
from openhands.utils.shutdown_listener import should_continue
from openhands.utils.async_utils import wait_all
class Session:
sid: str
websocket: WebSocket | None
sio: socketio.AsyncServer | None
connection_ids: set[str]
last_active_ts: int = 0
is_alive: bool = True
agent_session: AgentSession
loop: asyncio.AbstractEventLoop
def __init__(
self, sid: str, ws: WebSocket | None, config: AppConfig, file_store: FileStore
self,
sid: str,
ws: WebSocket | None,
config: AppConfig,
file_store: FileStore,
sio: socketio.AsyncServer | None,
):
self.sid = sid
self.websocket = ws
self.sio = sio
self.last_active_ts = int(time.time())
self.agent_session = AgentSession(
sid, file_store, status_callback=self.queue_status_message
@@ -47,31 +56,21 @@ class Session:
EventStreamSubscriber.SERVER, self.on_event, self.sid
)
self.config = config
self.connection_ids = set()
self.loop = asyncio.get_event_loop()
def connect(self, connection_id: str):
self.connection_ids.add(connection_id)
def disconnect(self, connection_id: str) -> bool:
self.connection_ids.remove(connection_id)
return not self.connection_ids
def close(self):
self.is_alive = False
self.agent_session.close()
async def loop_recv(self):
try:
if self.websocket is None:
return
while should_continue():
try:
data = await self.websocket.receive_json()
except ValueError:
await self.send_error('Invalid JSON')
continue
await self.dispatch(data)
except WebSocketDisconnect:
logger.info('WebSocket disconnected, sid: %s', self.sid)
self.close()
except RuntimeError as e:
logger.exception('Error in loop_recv: %s', e)
self.close()
async def _initialize_agent(self, data: dict):
async def initialize_agent(self, data: dict):
self.agent_session.event_stream.add_event(
ChangeAgentStateAction(AgentState.LOADING), EventSource.ENVIRONMENT
)
@@ -159,7 +158,7 @@ class Session:
async def dispatch(self, data: dict):
action = data.get('action', '')
if action == ActionType.INIT:
await self._initialize_agent(data)
await self.initialize_agent(data)
return
event = event_from_dict(data.copy())
# This checks if the model supports images
@@ -180,9 +179,15 @@ class Session:
async def send(self, data: dict[str, object]) -> bool:
try:
if self.websocket is None or not self.is_alive:
if not self.is_alive:
return False
await self.websocket.send_json(data)
if self.websocket:
await self.websocket.send_json(data)
if self.sio:
await wait_all(
self.sio.emit('oh_event', data, to=connection_id)
for connection_id in self.connection_ids
)
await asyncio.sleep(0.001) # This flushes the data to the client
self.last_active_ts = int(time.time())
return True
Generated
+625 -475
View File
File diff suppressed because it is too large Load Diff
+5
View File
@@ -64,6 +64,9 @@ modal = "^0.64.145"
runloop-api-client = "0.7.0"
pygithub = "^2.5.0"
openhands-aci = "^0.1.0"
python-socketio = "^5.11.4"
asyncpg = "^0.30.0"
cloud-sql-python-connector = "^1.13.0"
[tool.poetry.group.llama-index.dependencies]
llama-index = "*"
@@ -95,6 +98,7 @@ reportlab = "*"
[tool.coverage.run]
concurrency = ["gevent"]
[tool.poetry.group.runtime.dependencies]
jupyterlab = "*"
notebook = "*"
@@ -125,6 +129,7 @@ ignore = ["D1"]
[tool.ruff.lint.pydocstyle]
convention = "google"
[tool.poetry.group.evaluation.dependencies]
streamlit = "*"
whatthepatch = "*"