Refactor sessions a bit, and fix issue where runtimes get killed (#4900)

This commit is contained in:
Robert Brennan
2024-11-12 11:20:36 -05:00
committed by GitHub
parent 910b283ac2
commit 17f4c6e1a9
14 changed files with 71 additions and 131 deletions

View File

@@ -286,7 +286,6 @@ jobs:
image_name=ghcr.io/${{ github.repository_owner }}/runtime:${{ env.RELEVANT_SHA }}-${{ matrix.base_image }}
image_name=$(echo $image_name | tr '[:upper:]' '[:lower:]')
SKIP_CONTAINER_LOGS=true \
TEST_RUNTIME=eventstream \
SANDBOX_USER_ID=$(id -u) \
SANDBOX_RUNTIME_CONTAINER_IMAGE=$image_name \
@@ -364,7 +363,6 @@ jobs:
image_name=ghcr.io/${{ github.repository_owner }}/runtime:${{ env.RELEVANT_SHA }}-${{ matrix.base_image }}
image_name=$(echo $image_name | tr '[:upper:]' '[:lower:]')
SKIP_CONTAINER_LOGS=true \
TEST_RUNTIME=eventstream \
SANDBOX_USER_ID=$(id -u) \
SANDBOX_RUNTIME_CONTAINER_IMAGE=$image_name \

View File

@@ -59,7 +59,7 @@ docker run # ...
-e RUNTIME=remote \
-e SANDBOX_REMOTE_RUNTIME_API_URL="https://runtime.app.all-hands.dev" \
-e SANDBOX_API_KEY="your-all-hands-api-key" \
-e SANDBOX_KEEP_REMOTE_RUNTIME_ALIVE="true" \
-e SANDBOX_KEEP_RUNTIME_ALIVE="true" \
# ...
```

View File

@@ -66,7 +66,7 @@ def get_config(
browsergym_eval_env=env_id,
api_key=os.environ.get('ALLHANDS_API_KEY', None),
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
keep_remote_runtime_alive=False,
keep_runtime_alive=False,
),
# do not mount workspace
workspace_base=None,

View File

@@ -72,7 +72,7 @@ def get_config(
timeout=300,
api_key=os.environ.get('ALLHANDS_API_KEY', None),
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
keep_remote_runtime_alive=False,
keep_runtime_alive=False,
),
# do not mount workspace
workspace_base=None,

View File

@@ -145,7 +145,7 @@ def get_config(
platform='linux/amd64',
api_key=os.environ.get('ALLHANDS_API_KEY', None),
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
keep_remote_runtime_alive=False,
keep_runtime_alive=False,
remote_runtime_init_timeout=1800,
),
# do not mount workspace

View File

@@ -36,7 +36,7 @@ class SandboxConfig:
remote_runtime_api_url: str = 'http://localhost:8000'
local_runtime_url: str = 'http://localhost'
keep_remote_runtime_alive: bool = True
keep_runtime_alive: bool = True
api_key: str | None = None
base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime
runtime_container_image: str | None = None

View File

@@ -0,0 +1,18 @@
import docker
def remove_all_containers(prefix: str):
docker_client = docker.from_env()
try:
containers = docker_client.containers.list(all=True)
for container in containers:
try:
if container.name.startswith(prefix):
container.remove(force=True)
except docker.errors.APIError:
pass
except docker.errors.NotFound:
pass
except docker.errors.NotFound: # yes, this can happen!
pass

View File

@@ -1,8 +1,9 @@
import atexit
import os
from pathlib import Path
import tempfile
import threading
from functools import lru_cache
from pathlib import Path
from typing import Callable
from zipfile import ZipFile
@@ -35,6 +36,7 @@ from openhands.events.serialization import event_to_dict, observation_from_dict
from openhands.events.serialization.action import ACTION_TYPE_TO_CLASS
from openhands.runtime.base import Runtime
from openhands.runtime.builder import DockerRuntimeBuilder
from openhands.runtime.impl.eventstream.containers import remove_all_containers
from openhands.runtime.plugins import PluginRequirement
from openhands.runtime.utils import find_available_tcp_port
from openhands.runtime.utils.request import send_request
@@ -42,6 +44,15 @@ from openhands.runtime.utils.runtime_build import build_runtime_image
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.tenacity_stop import stop_if_should_exit
CONTAINER_NAME_PREFIX = 'openhands-runtime-'
def remove_all_runtime_containers():
remove_all_containers(CONTAINER_NAME_PREFIX)
atexit.register(remove_all_runtime_containers)
class LogBuffer:
"""Synchronous buffer for Docker container logs.
@@ -114,8 +125,6 @@ class EventStreamRuntime(Runtime):
env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None.
"""
container_name_prefix = 'openhands-runtime-'
# Need to provide this method to allow inheritors to init the Runtime
# without initting the EventStreamRuntime.
def init_base_runtime(
@@ -158,7 +167,7 @@ class EventStreamRuntime(Runtime):
self.docker_client: docker.DockerClient = self._init_docker_client()
self.base_container_image = self.config.sandbox.base_container_image
self.runtime_container_image = self.config.sandbox.runtime_container_image
self.container_name = self.container_name_prefix + sid
self.container_name = CONTAINER_NAME_PREFIX + sid
self.container = None
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
@@ -173,10 +182,6 @@ class EventStreamRuntime(Runtime):
f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}',
)
self.skip_container_logs = (
os.environ.get('SKIP_CONTAINER_LOGS', 'false').lower() == 'true'
)
self.init_base_runtime(
config,
event_stream,
@@ -189,7 +194,15 @@ class EventStreamRuntime(Runtime):
async def connect(self):
self.send_status_message('STATUS$STARTING_RUNTIME')
if not self.attach_to_existing:
try:
await call_sync_from_async(self._attach_to_container)
except docker.errors.NotFound as e:
if self.attach_to_existing:
self.log(
'error',
f'Container {self.container_name} not found.',
)
raise e
if self.runtime_container_image is None:
if self.base_container_image is None:
raise ValueError(
@@ -210,13 +223,12 @@ class EventStreamRuntime(Runtime):
await call_sync_from_async(self._init_container)
self.log('info', f'Container started: {self.container_name}')
else:
await call_sync_from_async(self._attach_to_container)
if not self.attach_to_existing:
self.log('info', f'Waiting for client to become ready at {self.api_url}...')
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
self.send_status_message('STATUS$WAITING_FOR_CLIENT')
await call_sync_from_async(self._wait_until_alive)
if not self.attach_to_existing:
self.log('info', 'Runtime is ready.')
@@ -227,7 +239,8 @@ class EventStreamRuntime(Runtime):
'debug',
f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}',
)
self.send_status_message(' ')
if not self.attach_to_existing:
self.send_status_message(' ')
@staticmethod
@lru_cache(maxsize=1)
@@ -332,13 +345,12 @@ class EventStreamRuntime(Runtime):
self.log('debug', f'Container started. Server url: {self.api_url}')
self.send_status_message('STATUS$CONTAINER_STARTED')
except docker.errors.APIError as e:
# check 409 error
if '409' in str(e):
self.log(
'warning',
f'Container {self.container_name} already exists. Removing...',
)
self._close_containers(rm_all_containers=True)
remove_all_containers(self.container_name)
return self._init_container()
else:
@@ -414,42 +426,18 @@ class EventStreamRuntime(Runtime):
Parameters:
- rm_all_containers (bool): Whether to remove all containers with the 'openhands-sandbox-' prefix
"""
if self.log_buffer:
self.log_buffer.close()
if self.session:
self.session.close()
if self.attach_to_existing:
if self.config.sandbox.keep_runtime_alive or self.attach_to_existing:
return
self._close_containers(rm_all_containers)
def _close_containers(self, rm_all_containers: bool = True):
try:
containers = self.docker_client.containers.list(all=True)
for container in containers:
try:
# If the app doesn't shut down properly, it can leave runtime containers on the system. This ensures
# that all 'openhands-sandbox-' containers are removed as well.
if rm_all_containers and container.name.startswith(
self.container_name_prefix
):
container.remove(force=True)
elif container.name == self.container_name:
if not self.skip_container_logs:
logs = container.logs(tail=1000).decode('utf-8')
self.log(
'debug',
f'==== Container logs on close ====\n{logs}\n==== End of container logs ====',
)
container.remove(force=True)
except docker.errors.APIError:
pass
except docker.errors.NotFound:
pass
except docker.errors.NotFound: # yes, this can happen!
pass
close_prefix = (
CONTAINER_NAME_PREFIX if rm_all_containers else self.container_name
)
remove_all_containers(close_prefix)
def run_action(self, action: Action) -> Observation:
if isinstance(action, FileEditAction):

View File

@@ -288,7 +288,6 @@ class RemoteRuntime(Runtime):
assert runtime_data['runtime_id'] == self.runtime_id
assert 'pod_status' in runtime_data
pod_status = runtime_data['pod_status']
self.log('debug', runtime_data)
self.log('debug', f'Pod status: {pod_status}')
# FIXME: We should fix it at the backend of /start endpoint, make sure
@@ -333,7 +332,7 @@ class RemoteRuntime(Runtime):
raise RuntimeNotReadyError()
def close(self, timeout: int = 10):
if self.config.sandbox.keep_remote_runtime_alive or self.attach_to_existing:
if self.config.sandbox.keep_runtime_alive or self.attach_to_existing:
self.session.close()
return
if self.runtime_id and self.session:

View File

@@ -21,6 +21,8 @@ from openhands.runtime.utils.command import get_remote_startup_command
from openhands.runtime.utils.request import send_request
from openhands.utils.tenacity_stop import stop_if_should_exit
CONTAINER_NAME_PREFIX = 'openhands-runtime-'
class RunloopLogBuffer(LogBuffer):
"""Synchronous buffer for Runloop devbox logs.
@@ -115,7 +117,7 @@ class RunloopRuntime(EventStreamRuntime):
bearer_token=config.runloop_api_key,
)
self.session = requests.Session()
self.container_name = self.container_name_prefix + sid
self.container_name = CONTAINER_NAME_PREFIX + sid
self.action_semaphore = threading.Semaphore(1) # Ensure one action at a time
self.init_base_runtime(
config,
@@ -190,7 +192,7 @@ class RunloopRuntime(EventStreamRuntime):
prebuilt='openhands',
launch_parameters=LaunchParameters(
available_ports=[self._sandbox_port],
resource_size_request="LARGE",
resource_size_request='LARGE',
),
metadata={'container-name': self.container_name},
)

View File

@@ -5,7 +5,6 @@ import tempfile
import time
import uuid
import warnings
from contextlib import asynccontextmanager
import jwt
import requests
@@ -74,14 +73,7 @@ file_store = get_file_store(config.file_store, config.file_store_path)
session_manager = SessionManager(config, file_store)
@asynccontextmanager
async def lifespan(app: FastAPI):
global session_manager
async with session_manager:
yield
app = FastAPI(lifespan=lifespan)
app = FastAPI()
app.add_middleware(
LocalhostCORSMiddleware,
allow_credentials=True,

View File

@@ -1,14 +1,11 @@
import asyncio
import time
from dataclasses import dataclass, field
from typing import Optional
from dataclasses import dataclass
from fastapi import WebSocket
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import session_exists
from openhands.runtime.utils.shutdown_listener import should_continue
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import Session
from openhands.storage.files import FileStore
@@ -18,78 +15,23 @@ from openhands.storage.files import FileStore
class SessionManager:
config: AppConfig
file_store: FileStore
cleanup_interval: int = 300
session_timeout: int = 600
_sessions: dict[str, Session] = field(default_factory=dict)
_session_cleanup_task: Optional[asyncio.Task] = None
async def __aenter__(self):
if not self._session_cleanup_task:
self._session_cleanup_task = asyncio.create_task(self._cleanup_sessions())
return self
async def __aexit__(self, exc_type, exc_value, traceback):
if self._session_cleanup_task:
self._session_cleanup_task.cancel()
self._session_cleanup_task = None
def add_or_restart_session(self, sid: str, ws_conn: WebSocket) -> Session:
if sid in self._sessions:
self._sessions[sid].close()
self._sessions[sid] = Session(
return Session(
sid=sid, file_store=self.file_store, ws=ws_conn, config=self.config
)
return self._sessions[sid]
def get_session(self, sid: str) -> Session | None:
if sid not in self._sessions:
return None
return self._sessions.get(sid)
async def attach_to_conversation(self, sid: str) -> Conversation | None:
start_time = time.time()
if not await session_exists(sid, self.file_store):
return None
c = Conversation(sid, file_store=self.file_store, config=self.config)
await c.connect()
end_time = time.time()
logger.info(
f'Conversation {c.sid} connected in {end_time - start_time} seconds'
)
return c
async def detach_from_conversation(self, conversation: Conversation):
await conversation.disconnect()
async def send(self, sid: str, data: dict[str, object]) -> bool:
"""Sends data to the client."""
session = self.get_session(sid)
if session is None:
logger.error(f'*** No session found for {sid}, skipping message ***')
return False
return await session.send(data)
async def send_error(self, sid: str, message: str) -> bool:
"""Sends an error message to the client."""
return await self.send(sid, {'error': True, 'message': message})
async def send_message(self, sid: str, message: str) -> bool:
"""Sends a message to the client."""
return await self.send(sid, {'message': message})
async def _cleanup_sessions(self):
while should_continue():
current_time = time.time()
session_ids_to_remove = []
for sid, session in list(self._sessions.items()):
# if session inactive for a long time, remove it
if (
not session.is_alive
and current_time - session.last_active_ts > self.session_timeout
):
session_ids_to_remove.append(sid)
for sid in session_ids_to_remove:
to_del_session: Session | None = self._sessions.pop(sid, None)
if to_del_session is not None:
to_del_session.close()
logger.debug(
f'Session {sid} and related resource have been removed due to inactivity.'
)
await asyncio.sleep(self.cleanup_interval)

View File

@@ -224,6 +224,7 @@ def _load_runtime(
config = load_app_config()
config.run_as_openhands = run_as_openhands
config.sandbox.force_rebuild_runtime = force_rebuild_runtime
config.sandbox.keep_runtime_alive = False
# Folder where all tests create their own folder
global test_mount_path
if use_workspace:

View File

@@ -64,7 +64,7 @@ def get_config(
timeout=300,
api_key=os.environ.get('ALLHANDS_API_KEY', None),
remote_runtime_api_url=os.environ.get('SANDBOX_REMOTE_RUNTIME_API_URL'),
keep_remote_runtime_alive=False,
keep_runtime_alive=False,
),
# do not mount workspace
workspace_base=None,