Compare commits

...

10 Commits

Author SHA1 Message Date
Robert Brennan
bb33535275 fix timeouts 2024-11-03 12:39:20 -05:00
Robert Brennan
00739e3270 revert some changes 2024-11-03 12:32:30 -05:00
Robert Brennan
4a41c979cb fix log stack 2024-11-03 12:23:49 -05:00
Robert Brennan
032a577a3a remove max_retries 2024-11-03 12:21:39 -05:00
Robert Brennan
ea511af181 revert log_gen 2024-11-03 12:20:27 -05:00
Robert Brennan
eb639fad62 fix try 2024-11-03 12:16:28 -05:00
openhands
6aff0459eb fix: improve runtime shutdown handling
- Add proper timeouts to runtime startup and shutdown
- Improve container cleanup with graceful shutdown
- Add better error handling in log buffer
- Add proper thread cleanup
- Add timeouts to all blocking operations

Key improvements:
1. Shorter timeouts for runtime startup (30s vs 120s)
2. Better container cleanup with graceful stop then force kill
3. Proper thread cleanup in log buffer
4. Better error handling during shutdown
5. Proper cleanup order (requests -> logs -> container)
6. Detailed error logging during cleanup
2024-11-03 17:12:25 +00:00
Robert Brennan
7738de78d3 fix indent 2024-11-03 12:03:55 -05:00
openhands
6280ee7a03 fix: improve agent session cleanup and startup
- Add timeouts to all critical operations in session cleanup
- Add proper task cancellation and cleanup order
- Improve error handling and logging during cleanup
- Add timeouts and error handling to session startup
- Add proper thread and event loop handling
- Add cleanup of background tasks

Key improvements:
1. Cancel agent tasks before cleanup
2. Add timeouts to all async operations
3. Proper cleanup order (tasks before loop)
4. Better error handling and recovery
5. Proper thread and loop management
6. Detailed error logging
2024-11-03 17:02:38 +00:00
openhands
f93701a76a fix: improve websocket handling robustness
- Add heartbeat mechanism to detect disconnected clients
- Add timeouts to all critical websocket operations
- Improve error handling and recovery
- Add proper cleanup of resources and background tasks
- Prevent deadlocks in event processing
- Add better logging for websocket errors

These changes make the websocket handling more robust and prevent
server lockups when connections close unexpectedly. Key improvements:

1. Heartbeat every 30s to detect disconnected clients
2. Timeouts on all critical operations
3. Better error handling and recovery
4. Proper cleanup of resources
5. Prevention of deadlocks
6. Detailed error logging
2024-11-03 16:55:16 +00:00
5 changed files with 402 additions and 122 deletions

View File

@@ -95,7 +95,7 @@ class Runtime(FileEditRuntimeMixin):
def log(self, level: str, message: str) -> None:
message = f'[runtime {self.sid}] {message}'
getattr(logger, level)(message)
getattr(logger, level)(message, stacklevel=2)
# ====================================================================

View File

@@ -42,7 +42,7 @@ from openhands.utils.tenacity_stop import stop_if_should_exit
class LogBuffer:
"""Synchronous buffer for Docker container logs.
"""Synchronous buffer for Docker container logs with proper shutdown handling.
This class provides a thread-safe way to collect, store, and retrieve logs
from a Docker container. It uses a list to store log lines and provides methods
@@ -51,53 +51,94 @@ class LogBuffer:
def __init__(self, container: docker.models.containers.Container, logFn: Callable):
self.init_msg = 'Runtime client initialized.'
self.container = container
self.buffer: list[str] = []
self.lock = threading.Lock()
self._stop_event = threading.Event()
self.log_generator = container.logs(stream=True, follow=True)
self._closed = False
self.log_generator = container.logs(stream=True, follow=True, tail=100)
self.log = logFn
self.log_stream_thread = threading.Thread(target=self.stream_logs)
self.log_stream_thread.daemon = True
self.log_stream_thread.start()
self.log = logFn
def append(self, log_line: str):
"""Thread-safe append to log buffer"""
if self._closed:
return
with self.lock:
self.buffer.append(log_line)
def get_and_clear(self) -> list[str]:
"""Thread-safe get and clear of log buffer"""
with self.lock:
logs = list(self.buffer)
self.buffer.clear()
return logs
def stream_logs(self):
"""Stream logs from the Docker container in a separate thread.
"""Stream logs from the Docker container in a separate thread with error handling.
This method runs in its own thread to handle the blocking
operation of reading log lines from the Docker SDK's synchronous generator.
"""
if not self.log_generator:
return
try:
for log_line in self.log_generator:
if self._stop_event.is_set():
while not self._stop_event.is_set():
try:
# Use a timeout when reading from generator
log_line = next(self.log_generator, None)
if log_line is None:
break
if log_line:
decoded_line = log_line.decode('utf-8').rstrip()
self.append(decoded_line)
except StopIteration:
break
except Exception as e:
if not self._stop_event.is_set():
self.log('error', f'Error reading docker logs: {e}')
break
if log_line:
decoded_line = log_line.decode('utf-8').rstrip()
self.append(decoded_line)
except Exception as e:
self.log('error', f'Error streaming docker logs: {e}')
if not self._stop_event.is_set():
self.log('error', f'Error in log stream thread: {e}')
finally:
self._closed = True
def __del__(self):
if self.log_stream_thread.is_alive():
"""Ensure proper cleanup on deletion"""
if not self._closed and hasattr(self, 'log_stream_thread') and self.log_stream_thread.is_alive():
self.log(
'warn',
"LogBuffer was not properly closed. Use 'log_buffer.close()' for clean shutdown.",
)
self.close(timeout=5)
self.close(timeout=2)
def close(self, timeout: float = 5.0):
"""Close the log buffer with proper cleanup
Args:
timeout (float): Maximum time to wait for thread shutdown
"""
if self._closed:
return
self._stop_event.set()
self.log_stream_thread.join(timeout)
self._closed = True
if hasattr(self, 'log_stream_thread') and self.log_stream_thread.is_alive():
try:
self.log_stream_thread.join(timeout)
except Exception as e:
self.log('error', f'Error joining log thread: {e}')
# Force kill thread if it's still alive
if self.log_stream_thread.is_alive():
self.log('warn', 'Log thread did not shut down cleanly')
class EventStreamRuntime(Runtime):
@@ -383,22 +424,25 @@ class EventStreamRuntime(Runtime):
)
@tenacity.retry(
stop=tenacity.stop_after_delay(120) | stop_if_should_exit(),
wait=tenacity.wait_exponential(multiplier=2, min=1, max=20),
stop=tenacity.stop_after_delay(120) | stop_if_should_exit(), # Reduced timeout
wait=tenacity.wait_exponential(multiplier=1, min=1, max=5), # Faster retries
reraise=(ConnectionRefusedError,),
)
def _wait_until_alive(self):
"""Wait for runtime to be ready with proper error handling and timeouts"""
self._refresh_logs()
if not self.log_buffer:
raise RuntimeError('Runtime client is not ready.')
# Use a shorter timeout for individual requests
response = send_request_with_retry(
self.session,
'GET',
f'{self.api_url}/alive',
retry_exceptions=[ConnectionRefusedError],
timeout=300, # 5 minutes gives the container time to be alive 🧟‍♂️
timeout=10, # Shorter timeout per request
)
if response.status_code == 200:
return
else:
@@ -406,22 +450,55 @@ class EventStreamRuntime(Runtime):
self.log('error', msg)
raise RuntimeError(msg)
def close(self, rm_all_containers: bool = True):
"""Closes the EventStreamRuntime and associated objects
"""Closes the EventStreamRuntime and associated objects with proper cleanup
Parameters:
- rm_all_containers (bool): Whether to remove all containers with the 'openhands-sandbox-' prefix
"""
if self.log_buffer:
self.log_buffer.close()
print("CLOSE RUNTIME")
# First stop any ongoing requests
if self.session:
self.session.close()
try:
self.session.close()
except Exception as e:
self.log('error', f'Error closing session: {e}')
self.session = None
# Then close log buffer
if self.log_buffer:
try:
self.log_buffer.close(timeout=2) # Short timeout for log buffer
except Exception as e:
self.log('error', f'Error closing log buffer: {e}')
self.log_buffer = None
# Skip container cleanup if we're attached
if self.attach_to_existing:
return
# Clean up container
if self.container:
try:
# Try to stop container gracefully first
self.container.stop(timeout=5)
except Exception as e:
self.log('error', f'Error stopping container: {e}')
try:
# Force kill if graceful stop fails
self.container.kill()
except Exception as e2:
self.log('error', f'Error killing container: {e2}')
try:
# Remove the container
self.container.remove(force=True)
except Exception as e:
self.log('error', f'Error removing container: {e}')
self.container = None
try:
containers = self.docker_client.containers.list(all=True)
for container in containers:

View File

@@ -301,60 +301,100 @@ async def websocket_endpoint(websocket: WebSocket):
{"action": "finish", "args": {}}
```
"""
# Get protocols from Sec-WebSocket-Protocol header
protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
session = None
try:
# Get protocols from Sec-WebSocket-Protocol header
protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
# The first protocol should be our real protocol (e.g. 'openhands')
# The second protocol should contain our auth token
if len(protocols) < 3:
logger.error('Expected 3 websocket protocols, got %d', len(protocols))
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
real_protocol = protocols[0]
jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
if not await authenticate_github_user(github_token):
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10)
if jwt_token:
sid = get_sid_from_token(jwt_token, config.jwt_secret)
if sid == '':
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
await websocket.close()
# The first protocol should be our real protocol (e.g. 'openhands')
# The second protocol should contain our auth token
if len(protocols) < 3:
logger.error('Expected 3 websocket protocols, got %d', len(protocols))
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
else:
sid = str(uuid.uuid4())
jwt_token = sign_token({'sid': sid}, config.jwt_secret)
logger.info(f'New session: {sid}')
session = session_manager.add_or_restart_session(sid, websocket)
await websocket.send_json({'token': jwt_token, 'status': 'ok'})
real_protocol = protocols[0]
jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
latest_event_id = -1
if websocket.query_params.get('latest_event_id'):
latest_event_id = int(websocket.query_params.get('latest_event_id'))
for event in session.agent_session.event_stream.get_events(
start_id=latest_event_id + 1
):
if isinstance(
event,
(
NullAction,
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
),
):
continue
await websocket.send_json(event_to_dict(event))
if not await authenticate_github_user(github_token):
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
return
await session.loop_recv()
try:
await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), timeout=10)
except asyncio.TimeoutError:
logger.error("WebSocket accept timed out")
await websocket.close(code=status.WS_1408_REQUEST_TIMEOUT)
return
if jwt_token:
sid = get_sid_from_token(jwt_token, config.jwt_secret)
if sid == '':
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
await websocket.close()
return
else:
sid = str(uuid.uuid4())
jwt_token = sign_token({'sid': sid}, config.jwt_secret)
logger.info(f'New session: {sid}')
session = session_manager.add_or_restart_session(sid, websocket)
try:
await asyncio.wait_for(
websocket.send_json({'token': jwt_token, 'status': 'ok'}),
timeout=5
)
except asyncio.TimeoutError:
logger.error("Failed to send initial response")
await session.close()
await websocket.close(code=status.WS_1408_REQUEST_TIMEOUT)
return
latest_event_id = -1
if websocket.query_params.get('latest_event_id'):
latest_event_id = int(websocket.query_params.get('latest_event_id'))
# Send historical events with timeout
try:
async with asyncio.timeout(30): # 30 second timeout for historical events
for event in session.agent_session.event_stream.get_events(
start_id=latest_event_id + 1
):
if isinstance(
event,
(
NullAction,
NullObservation,
ChangeAgentStateAction,
AgentStateChangedObservation,
),
):
continue
await websocket.send_json(event_to_dict(event))
except asyncio.TimeoutError:
logger.error("Timeout sending historical events")
await session.close()
await websocket.close(code=status.WS_1408_REQUEST_TIMEOUT)
return
# Start the main receive loop with heartbeat
await session.loop_recv()
except WebSocketDisconnect:
logger.info("WebSocket disconnected normally")
except Exception as e:
logger.exception("Error in websocket handler: %s", str(e))
finally:
if session:
try:
await asyncio.wait_for(session.close(), timeout=5)
except asyncio.TimeoutError:
logger.error("Timeout during session cleanup")
except Exception as e:
logger.exception("Error during session cleanup: %s", str(e))
@app.get('/api/options/models')

View File

@@ -55,22 +55,39 @@ class AgentSession:
agent_configs: dict[str, AgentConfig] | None = None,
status_message_callback: Optional[Callable] = None,
):
"""Starts the Agent session
"""Starts the Agent session with proper error handling and timeouts
Parameters:
- runtime_name: The name of the runtime associated with the session
- config:
- agent:
- max_iterations:
- max_budget_per_task:
- agent_to_llm_config:
- agent_configs:
- config: Application configuration
- agent: Agent instance to use
- max_iterations: Maximum number of iterations
- max_budget_per_task: Maximum budget per task
- agent_to_llm_config: LLM configurations for different agents
- agent_configs: Agent configurations
- status_message_callback: Callback for status updates
"""
if self.controller or self.runtime:
raise RuntimeError(
'Session already started. You need to close this session and start a new one.'
)
asyncio.get_event_loop().run_in_executor(
# Create a future to track the start operation
start_future = asyncio.Future()
def start_callback(future):
try:
exc = future.exception()
if exc:
start_future.set_exception(exc)
else:
start_future.set_result(None)
except asyncio.CancelledError:
start_future.cancel()
except Exception as e:
start_future.set_exception(e)
# Start the agent in a thread pool with proper error propagation
task = asyncio.get_event_loop().run_in_executor(
None,
self._start_thread,
runtime_name,
@@ -82,13 +99,49 @@ class AgentSession:
agent_configs,
status_message_callback,
)
task.add_done_callback(start_callback)
try:
# Wait for start with timeout
await asyncio.wait_for(start_future, timeout=120)
except asyncio.TimeoutError:
logger.error("Agent session start timed out")
# Cleanup if start times out
await self.close()
raise RuntimeError("Agent session start timed out")
except Exception as e:
logger.exception("Error starting agent session")
await self.close()
raise
def _start_thread(self, *args):
"""Start the agent in a separate thread with proper error handling"""
try:
asyncio.run(self._start(*args), debug=True)
except RuntimeError:
logger.error(f'Error starting session: {RuntimeError}', exc_info=True)
logger.debug('Session Finished')
# Create new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
# Run with timeout
loop.run_until_complete(
asyncio.wait_for(self._start(*args), timeout=25)
)
except asyncio.TimeoutError:
logger.error("Timeout in agent start")
raise RuntimeError("Timeout in agent start")
except Exception as e:
logger.exception("Error in agent start")
raise
finally:
try:
# Clean up the loop
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
except Exception as e:
logger.error(f"Error cleaning up thread loop: {e}")
except Exception as e:
logger.error(f"Fatal error in start thread: {e}")
raise
async def _start(
self,
@@ -101,7 +154,6 @@ class AgentSession:
agent_configs: dict[str, AgentConfig] | None = None,
status_message_callback: Optional[Callable] = None,
):
self.loop = asyncio.get_running_loop()
self._create_security_analyzer(config.security.security_analyzer)
await self._create_runtime(
runtime_name=runtime_name,
@@ -125,29 +177,60 @@ class AgentSession:
await self.controller.agent_task # type: ignore
async def close(self):
"""Closes the Agent session"""
"""Closes the Agent session with proper cleanup and timeouts"""
if self._closed:
return
if self.controller is not None:
end_state = self.controller.get_state()
end_state.save_to_session(self.sid, self.file_store)
await self.controller.close()
if self.runtime is not None:
self.runtime.close()
if self.security_analyzer is not None:
await self.security_analyzer.close()
if self.loop:
if self.loop.is_closed():
logger.debug(
'Trying to close already closed loop. (It probably never started correctly)'
)
else:
self.loop.stop()
self.loop = None
try:
# Set closed flag early to prevent multiple close attempts
self._closed = True
self._closed = True
# Save state with timeout
if self.controller is not None:
try:
async with asyncio.timeout(5):
end_state = self.controller.get_state()
end_state.save_to_session(self.sid, self.file_store)
except asyncio.TimeoutError:
logger.error("Timeout saving agent state")
except Exception as e:
logger.error(f"Error saving agent state: {e}")
# Close controller with timeout
if self.controller is not None:
try:
async with asyncio.timeout(5):
await self.controller.close()
except asyncio.TimeoutError:
logger.error("Timeout closing controller")
except Exception as e:
logger.error(f"Error closing controller: {e}")
self.controller = None
# Close runtime (this is synchronous but should be quick)
if self.runtime is not None:
try:
self.runtime.close()
except Exception as e:
logger.error(f"Error closing runtime: {e}")
self.runtime = None
# Close security analyzer with timeout
if self.security_analyzer is not None:
try:
async with asyncio.timeout(5):
await self.security_analyzer.close()
except asyncio.TimeoutError:
logger.error("Timeout closing security analyzer")
except Exception as e:
logger.error(f"Error closing security analyzer: {e}")
self.security_analyzer = None
except Exception as e:
logger.exception(f"Unexpected error during session cleanup: {e}")
finally:
# Ensure closed flag is set even if cleanup fails
self._closed = True
def _create_security_analyzer(self, security_analyzer: str | None):
"""Creates a SecurityAnalyzer instance that will be used to analyze the agent actions

View File

@@ -51,23 +51,70 @@ class Session:
self.is_alive = False
await self.agent_session.close()
async def _heartbeat(self):
"""Send periodic heartbeat to check connection is alive

Repeats while `self.is_alive` and `should_continue()` are both true. Any
failure to send (timeout, disconnect, or other error) sets `self.is_alive`
to False and ends the loop.
"""
while self.is_alive and should_continue():
try:
# Each send is bounded to 5 seconds so a stalled send cannot hang this task.
await asyncio.wait_for(
self.websocket.send_json({"type": "heartbeat"}),
timeout=5
)
await asyncio.sleep(1) # NOTE(review): pauses 1 second between heartbeats; an earlier comment here said 30 seconds - confirm the intended interval
except (asyncio.TimeoutError, WebSocketDisconnect, RuntimeError):
# Expected ways for a dead connection to surface: stop quietly.
self.is_alive = False
break
except Exception as e:
# Anything else is unexpected: log with traceback, then stop as well.
logger.exception("Error in heartbeat: %s", e)
self.is_alive = False
break
async def loop_recv(self):
"""Main websocket receive loop with heartbeat"""
if self.websocket is None:
return
# Start heartbeat in background task
heartbeat_task = asyncio.create_task(self._heartbeat())
try:
if self.websocket is None:
return
while should_continue():
while self.is_alive and should_continue():
try:
data = await self.websocket.receive_json()
# Use timeout to prevent blocking forever
data = await asyncio.wait_for(
self.websocket.receive_json(),
timeout=35 # Slightly longer than heartbeat interval
)
await self.dispatch(data)
except asyncio.TimeoutError:
# No message received within timeout, check if heartbeat is still alive
if not heartbeat_task.done():
continue
else:
logger.error("Heartbeat task failed, closing connection")
break
except ValueError:
await self.send_error('Invalid JSON')
continue
await self.dispatch(data)
except WebSocketDisconnect:
except WebSocketDisconnect:
logger.debug('WebSocket disconnected, sid: %s', self.sid)
break
except Exception as e:
logger.exception("Error processing message: %s", e)
await self.send_error(f"Error processing message: {str(e)}")
continue
except Exception as e:
logger.exception("Fatal error in receive loop: %s", e)
finally:
# Cancel heartbeat task
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
# Close the session
await self.close()
logger.debug('WebSocket disconnected, sid: %s', self.sid)
except RuntimeError as e:
await self.close()
logger.exception('Error in loop_recv: %s', e)
async def _initialize_agent(self, data: dict):
self.agent_session.event_stream.add_event(
@@ -152,12 +199,30 @@ class Session:
await self.send(event_dict)
async def dispatch(self, data: dict):
"""Dispatch incoming websocket messages to appropriate handlers"""
action = data.get('action', '')
# Handle initialization separately
if action == ActionType.INIT:
await self._initialize_agent(data)
try:
async with asyncio.timeout(120):
await self._initialize_agent(data)
except asyncio.TimeoutError:
await self.send_error('Agent initialization timed out')
except Exception as e:
logger.exception("Error initializing agent: %s", e)
await self.send_error(f'Failed to initialize agent: {str(e)}')
return
event = event_from_dict(data.copy())
# This checks if the model supports images
# Convert message to event
try:
event = event_from_dict(data.copy())
except Exception as e:
logger.error("Failed to parse event: %s", e)
await self.send_error('Invalid event format')
return
# Handle image validation
if isinstance(event, MessageAction) and event.images_urls:
controller = self.agent_session.controller
if controller:
@@ -171,10 +236,25 @@ class Session:
'Model does not support image upload, change to a different model or try without an image.'
)
return
if self.agent_session.loop:
asyncio.run_coroutine_threadsafe(
self._add_event(event, EventSource.USER), self.agent_session.loop
) # type: ignore
try:
# Use asyncio.wait_for to prevent blocking indefinitely
future = asyncio.run_coroutine_threadsafe(
self._add_event(event, EventSource.USER),
asyncio.get_running_loop()
)
# Wait for the event to be processed with timeout
await asyncio.wait_for(
asyncio.wrap_future(future),
timeout=10
)
except asyncio.TimeoutError:
logger.error("Event processing timed out")
await self.send_error('Event processing timed out')
except Exception as e:
logger.exception("Error processing event: %s", e)
await self.send_error(f'Failed to process event: {str(e)}')
async def _add_event(self, event, event_source):
self.agent_session.event_stream.add_event(event, EventSource.USER)