mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-08 22:38:05 -05:00
Refactor sessions a bit, and fix issue where runtimes get killed (#4900)
This commit is contained in:
2
.github/workflows/ghcr-build.yml
vendored
2
.github/workflows/ghcr-build.yml
vendored
@@ -286,7 +286,6 @@ jobs:
|
||||
image_name=ghcr.io/${{ github.repository_owner }}/runtime:${{ env.RELEVANT_SHA }}-${{ matrix.base_image }}
|
||||
image_name=$(echo $image_name | tr '[:upper:]' '[:lower:]')
|
||||
|
||||
SKIP_CONTAINER_LOGS=true \
|
||||
TEST_RUNTIME=eventstream \
|
||||
SANDBOX_USER_ID=$(id -u) \
|
||||
SANDBOX_RUNTIME_CONTAINER_IMAGE=$image_name \
|
||||
@@ -364,7 +363,6 @@ jobs:
|
||||
image_name=ghcr.io/${{ github.repository_owner }}/runtime:${{ env.RELEVANT_SHA }}-${{ matrix.base_image }}
|
||||
image_name=$(echo $image_name | tr '[:upper:]' '[:lower:]')
|
||||
|
||||
SKIP_CONTAINER_LOGS=true \
|
||||
TEST_RUNTIME=eventstream \
|
||||
SANDBOX_USER_ID=$(id -u) \
|
||||
SANDBOX_RUNTIME_CONTAINER_IMAGE=$image_name \
|
||||
|
||||
@@ -59,7 +59,7 @@ docker run # ...
|
||||
-e RUNTIME=remote \
|
||||
-e SANDBOX_REMOTE_RUNTIME_API_URL="https://runtime.app.all-hands.dev" \
|
||||
-e SANDBOX_API_KEY="your-all-hands-api-key" \
|
||||
-e SANDBOX_KEEP_REMOTE_RUNTIME_ALIVE="true" \
|
||||
-e SANDBOX_KEEP_RUNTIME_ALIVE="true" \
|
||||
# ...
|
||||
```
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ def get_config(
|
||||
browsergym_eval_env=env_id,
|
||||
api_key=os.environ.get('ALLHANDS_API_KEY', None),
|
||||
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
|
||||
keep_remote_runtime_alive=False,
|
||||
keep_runtime_alive=False,
|
||||
),
|
||||
# do not mount workspace
|
||||
workspace_base=None,
|
||||
|
||||
@@ -72,7 +72,7 @@ def get_config(
|
||||
timeout=300,
|
||||
api_key=os.environ.get('ALLHANDS_API_KEY', None),
|
||||
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
|
||||
keep_remote_runtime_alive=False,
|
||||
keep_runtime_alive=False,
|
||||
),
|
||||
# do not mount workspace
|
||||
workspace_base=None,
|
||||
|
||||
@@ -145,7 +145,7 @@ def get_config(
|
||||
platform='linux/amd64',
|
||||
api_key=os.environ.get('ALLHANDS_API_KEY', None),
|
||||
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
|
||||
keep_remote_runtime_alive=False,
|
||||
keep_runtime_alive=False,
|
||||
remote_runtime_init_timeout=1800,
|
||||
),
|
||||
# do not mount workspace
|
||||
|
||||
@@ -36,7 +36,7 @@ class SandboxConfig:
|
||||
|
||||
remote_runtime_api_url: str = 'http://localhost:8000'
|
||||
local_runtime_url: str = 'http://localhost'
|
||||
keep_remote_runtime_alive: bool = True
|
||||
keep_runtime_alive: bool = True
|
||||
api_key: str | None = None
|
||||
base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime
|
||||
runtime_container_image: str | None = None
|
||||
|
||||
18
openhands/runtime/impl/eventstream/containers.py
Normal file
18
openhands/runtime/impl/eventstream/containers.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import docker
|
||||
|
||||
|
||||
def remove_all_containers(prefix: str):
|
||||
docker_client = docker.from_env()
|
||||
|
||||
try:
|
||||
containers = docker_client.containers.list(all=True)
|
||||
for container in containers:
|
||||
try:
|
||||
if container.name.startswith(prefix):
|
||||
container.remove(force=True)
|
||||
except docker.errors.APIError:
|
||||
pass
|
||||
except docker.errors.NotFound:
|
||||
pass
|
||||
except docker.errors.NotFound: # yes, this can happen!
|
||||
pass
|
||||
@@ -1,8 +1,9 @@
|
||||
import atexit
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
from zipfile import ZipFile
|
||||
|
||||
@@ -35,6 +36,7 @@ from openhands.events.serialization import event_to_dict, observation_from_dict
|
||||
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.runtime.builder import DockerRuntimeBuilder
|
||||
from openhands.runtime.impl.eventstream.containers import remove_all_containers
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
from openhands.runtime.utils import find_available_tcp_port
|
||||
from openhands.runtime.utils.request import send_request
|
||||
@@ -42,6 +44,15 @@ from openhands.runtime.utils.runtime_build import build_runtime_image
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
CONTAINER_NAME_PREFIX = 'openhands-runtime-'
|
||||
|
||||
|
||||
def remove_all_runtime_containers():
|
||||
remove_all_containers(CONTAINER_NAME_PREFIX)
|
||||
|
||||
|
||||
atexit.register(remove_all_runtime_containers)
|
||||
|
||||
|
||||
class LogBuffer:
|
||||
"""Synchronous buffer for Docker container logs.
|
||||
@@ -114,8 +125,6 @@ class EventStreamRuntime(Runtime):
|
||||
env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None.
|
||||
"""
|
||||
|
||||
container_name_prefix = 'openhands-runtime-'
|
||||
|
||||
# Need to provide this method to allow inheritors to init the Runtime
|
||||
# without initting the EventStreamRuntime.
|
||||
def init_base_runtime(
|
||||
@@ -158,7 +167,7 @@ class EventStreamRuntime(Runtime):
|
||||
self.docker_client: docker.DockerClient = self._init_docker_client()
|
||||
self.base_container_image = self.config.sandbox.base_container_image
|
||||
self.runtime_container_image = self.config.sandbox.runtime_container_image
|
||||
self.container_name = self.container_name_prefix + sid
|
||||
self.container_name = CONTAINER_NAME_PREFIX + sid
|
||||
self.container = None
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
|
||||
@@ -173,10 +182,6 @@ class EventStreamRuntime(Runtime):
|
||||
f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}',
|
||||
)
|
||||
|
||||
self.skip_container_logs = (
|
||||
os.environ.get('SKIP_CONTAINER_LOGS', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
self.init_base_runtime(
|
||||
config,
|
||||
event_stream,
|
||||
@@ -189,7 +194,15 @@ class EventStreamRuntime(Runtime):
|
||||
|
||||
async def connect(self):
|
||||
self.send_status_message('STATUS$STARTING_RUNTIME')
|
||||
if not self.attach_to_existing:
|
||||
try:
|
||||
await call_sync_from_async(self._attach_to_container)
|
||||
except docker.errors.NotFound as e:
|
||||
if self.attach_to_existing:
|
||||
self.log(
|
||||
'error',
|
||||
f'Container {self.container_name} not found.',
|
||||
)
|
||||
raise e
|
||||
if self.runtime_container_image is None:
|
||||
if self.base_container_image is None:
|
||||
raise ValueError(
|
||||
@@ -210,13 +223,12 @@ class EventStreamRuntime(Runtime):
|
||||
await call_sync_from_async(self._init_container)
|
||||
self.log('info', f'Container started: {self.container_name}')
|
||||
|
||||
else:
|
||||
await call_sync_from_async(self._attach_to_container)
|
||||
|
||||
if not self.attach_to_existing:
|
||||
self.log('info', f'Waiting for client to become ready at {self.api_url}...')
|
||||
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
|
||||
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
|
||||
|
||||
await call_sync_from_async(self._wait_until_alive)
|
||||
|
||||
if not self.attach_to_existing:
|
||||
self.log('info', 'Runtime is ready.')
|
||||
|
||||
@@ -227,7 +239,8 @@ class EventStreamRuntime(Runtime):
|
||||
'debug',
|
||||
f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}',
|
||||
)
|
||||
self.send_status_message(' ')
|
||||
if not self.attach_to_existing:
|
||||
self.send_status_message(' ')
|
||||
|
||||
@staticmethod
|
||||
@lru_cache(maxsize=1)
|
||||
@@ -332,13 +345,12 @@ class EventStreamRuntime(Runtime):
|
||||
self.log('debug', f'Container started. Server url: {self.api_url}')
|
||||
self.send_status_message('STATUS$CONTAINER_STARTED')
|
||||
except docker.errors.APIError as e:
|
||||
# check 409 error
|
||||
if '409' in str(e):
|
||||
self.log(
|
||||
'warning',
|
||||
f'Container {self.container_name} already exists. Removing...',
|
||||
)
|
||||
self._close_containers(rm_all_containers=True)
|
||||
remove_all_containers(self.container_name)
|
||||
return self._init_container()
|
||||
|
||||
else:
|
||||
@@ -414,42 +426,18 @@ class EventStreamRuntime(Runtime):
|
||||
Parameters:
|
||||
- rm_all_containers (bool): Whether to remove all containers with the 'openhands-sandbox-' prefix
|
||||
"""
|
||||
|
||||
if self.log_buffer:
|
||||
self.log_buffer.close()
|
||||
|
||||
if self.session:
|
||||
self.session.close()
|
||||
|
||||
if self.attach_to_existing:
|
||||
if self.config.sandbox.keep_runtime_alive or self.attach_to_existing:
|
||||
return
|
||||
self._close_containers(rm_all_containers)
|
||||
|
||||
def _close_containers(self, rm_all_containers: bool = True):
|
||||
try:
|
||||
containers = self.docker_client.containers.list(all=True)
|
||||
for container in containers:
|
||||
try:
|
||||
# If the app doesn't shut down properly, it can leave runtime containers on the system. This ensures
|
||||
# that all 'openhands-sandbox-' containers are removed as well.
|
||||
if rm_all_containers and container.name.startswith(
|
||||
self.container_name_prefix
|
||||
):
|
||||
container.remove(force=True)
|
||||
elif container.name == self.container_name:
|
||||
if not self.skip_container_logs:
|
||||
logs = container.logs(tail=1000).decode('utf-8')
|
||||
self.log(
|
||||
'debug',
|
||||
f'==== Container logs on close ====\n{logs}\n==== End of container logs ====',
|
||||
)
|
||||
container.remove(force=True)
|
||||
except docker.errors.APIError:
|
||||
pass
|
||||
except docker.errors.NotFound:
|
||||
pass
|
||||
except docker.errors.NotFound: # yes, this can happen!
|
||||
pass
|
||||
close_prefix = (
|
||||
CONTAINER_NAME_PREFIX if rm_all_containers else self.container_name
|
||||
)
|
||||
remove_all_containers(close_prefix)
|
||||
|
||||
def run_action(self, action: Action) -> Observation:
|
||||
if isinstance(action, FileEditAction):
|
||||
|
||||
@@ -288,7 +288,6 @@ class RemoteRuntime(Runtime):
|
||||
assert runtime_data['runtime_id'] == self.runtime_id
|
||||
assert 'pod_status' in runtime_data
|
||||
pod_status = runtime_data['pod_status']
|
||||
self.log('debug', runtime_data)
|
||||
self.log('debug', f'Pod status: {pod_status}')
|
||||
|
||||
# FIXME: We should fix it at the backend of /start endpoint, make sure
|
||||
@@ -333,7 +332,7 @@ class RemoteRuntime(Runtime):
|
||||
raise RuntimeNotReadyError()
|
||||
|
||||
def close(self, timeout: int = 10):
|
||||
if self.config.sandbox.keep_remote_runtime_alive or self.attach_to_existing:
|
||||
if self.config.sandbox.keep_runtime_alive or self.attach_to_existing:
|
||||
self.session.close()
|
||||
return
|
||||
if self.runtime_id and self.session:
|
||||
|
||||
@@ -21,6 +21,8 @@ from openhands.runtime.utils.command import get_remote_startup_command
|
||||
from openhands.runtime.utils.request import send_request
|
||||
from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
CONTAINER_NAME_PREFIX = 'openhands-runtime-'
|
||||
|
||||
|
||||
class RunloopLogBuffer(LogBuffer):
|
||||
"""Synchronous buffer for Runloop devbox logs.
|
||||
@@ -115,7 +117,7 @@ class RunloopRuntime(EventStreamRuntime):
|
||||
bearer_token=config.runloop_api_key,
|
||||
)
|
||||
self.session = requests.Session()
|
||||
self.container_name = self.container_name_prefix + sid
|
||||
self.container_name = CONTAINER_NAME_PREFIX + sid
|
||||
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
|
||||
self.init_base_runtime(
|
||||
config,
|
||||
@@ -190,7 +192,7 @@ class RunloopRuntime(EventStreamRuntime):
|
||||
prebuilt='openhands',
|
||||
launch_parameters=LaunchParameters(
|
||||
available_ports=[self._sandbox_port],
|
||||
resource_size_request="LARGE",
|
||||
resource_size_request='LARGE',
|
||||
),
|
||||
metadata={'container-name': self.container_name},
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ import tempfile
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import jwt
|
||||
import requests
|
||||
@@ -74,14 +73,7 @@ file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
session_manager = SessionManager(config, file_store)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
global session_manager
|
||||
async with session_manager:
|
||||
yield
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
LocalhostCORSMiddleware,
|
||||
allow_credentials=True,
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.stream import session_exists
|
||||
from openhands.runtime.utils.shutdown_listener import should_continue
|
||||
from openhands.server.session.conversation import Conversation
|
||||
from openhands.server.session.session import Session
|
||||
from openhands.storage.files import FileStore
|
||||
@@ -18,78 +15,23 @@ from openhands.storage.files import FileStore
|
||||
class SessionManager:
|
||||
config: AppConfig
|
||||
file_store: FileStore
|
||||
cleanup_interval: int = 300
|
||||
session_timeout: int = 600
|
||||
_sessions: dict[str, Session] = field(default_factory=dict)
|
||||
_session_cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def __aenter__(self):
|
||||
if not self._session_cleanup_task:
|
||||
self._session_cleanup_task = asyncio.create_task(self._cleanup_sessions())
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
if self._session_cleanup_task:
|
||||
self._session_cleanup_task.cancel()
|
||||
self._session_cleanup_task = None
|
||||
|
||||
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
|
||||
if sid in self._sessions:
|
||||
self._sessions[sid].close()
|
||||
self._sessions[sid] = Session(
|
||||
return Session(
|
||||
sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
|
||||
)
|
||||
return self._sessions[sid]
|
||||
|
||||
def get_session(self, sid: str) -> Session | None:
|
||||
if sid not in self._sessions:
|
||||
return None
|
||||
return self._sessions.get(sid)
|
||||
|
||||
async def attach_to_conversation(self, sid: str) -> Conversation | None:
|
||||
start_time = time.time()
|
||||
if not await session_exists(sid, self.file_store):
|
||||
return None
|
||||
c = Conversation(sid, file_store=self.file_store, config=self.config)
|
||||
await c.connect()
|
||||
end_time = time.time()
|
||||
logger.info(
|
||||
f'Conversation {c.sid} connected in {end_time - start_time} seconds'
|
||||
)
|
||||
return c
|
||||
|
||||
async def detach_from_conversation(self, conversation: Conversation):
|
||||
await conversation.disconnect()
|
||||
|
||||
async def send(self, sid: str, data: dict[str, object]) -> bool:
|
||||
"""Sends data to the client."""
|
||||
session = self.get_session(sid)
|
||||
if session is None:
|
||||
logger.error(f'*** No session found for {sid}, skipping message ***')
|
||||
return False
|
||||
return await session.send(data)
|
||||
|
||||
async def send_error(self, sid: str, message: str) -> bool:
|
||||
"""Sends an error message to the client."""
|
||||
return await self.send(sid, {'error': True, 'message': message})
|
||||
|
||||
async def send_message(self, sid: str, message: str) -> bool:
|
||||
"""Sends a message to the client."""
|
||||
return await self.send(sid, {'message': message})
|
||||
|
||||
async def _cleanup_sessions(self):
|
||||
while should_continue():
|
||||
current_time = time.time()
|
||||
session_ids_to_remove = []
|
||||
for sid, session in list(self._sessions.items()):
|
||||
# if session inactive for a long time, remove it
|
||||
if (
|
||||
not session.is_alive
|
||||
and current_time - session.last_active_ts > self.session_timeout
|
||||
):
|
||||
session_ids_to_remove.append(sid)
|
||||
|
||||
for sid in session_ids_to_remove:
|
||||
to_del_session: Session | None = self._sessions.pop(sid, None)
|
||||
if to_del_session is not None:
|
||||
to_del_session.close()
|
||||
logger.debug(
|
||||
f'Session {sid} and related resource have been removed due to inactivity.'
|
||||
)
|
||||
|
||||
await asyncio.sleep(self.cleanup_interval)
|
||||
|
||||
@@ -224,6 +224,7 @@ def _load_runtime(
|
||||
config = load_app_config()
|
||||
config.run_as_openhands = run_as_openhands
|
||||
config.sandbox.force_rebuild_runtime = force_rebuild_runtime
|
||||
config.sandbox.keep_runtime_alive = False
|
||||
# Folder where all tests create their own folder
|
||||
global test_mount_path
|
||||
if use_workspace:
|
||||
|
||||
@@ -64,7 +64,7 @@ def get_config(
|
||||
timeout=300,
|
||||
api_key=os.environ.get('ALLHANDS_API_KEY', None),
|
||||
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
|
||||
keep_remote_runtime_alive=False,
|
||||
keep_runtime_alive=False,
|
||||
),
|
||||
# do not mount workspace
|
||||
workspace_base=None,
|
||||
|
||||
Reference in New Issue
Block a user