mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
10 Commits
rb/github-
...
rb/websock
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bb33535275 | ||
|
|
00739e3270 | ||
|
|
4a41c979cb | ||
|
|
032a577a3a | ||
|
|
ea511af181 | ||
|
|
eb639fad62 | ||
|
|
6aff0459eb | ||
|
|
7738de78d3 | ||
|
|
6280ee7a03 | ||
|
|
f93701a76a |
@@ -95,7 +95,7 @@ class Runtime(FileEditRuntimeMixin):
|
||||
|
||||
def log(self, level: str, message: str) -> None:
|
||||
message = f'[runtime {self.sid}] {message}'
|
||||
getattr(logger, level)(message)
|
||||
getattr(logger, level)(message, stacklevel=2)
|
||||
|
||||
# ====================================================================
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
|
||||
|
||||
class LogBuffer:
|
||||
"""Synchronous buffer for Docker container logs.
|
||||
"""Synchronous buffer for Docker container logs with proper shutdown handling.
|
||||
|
||||
This class provides a thread-safe way to collect, store, and retrieve logs
|
||||
from a Docker container. It uses a list to store log lines and provides methods
|
||||
@@ -51,53 +51,94 @@ class LogBuffer:
|
||||
|
||||
def __init__(self, container: docker.models.containers.Container, logFn: Callable):
|
||||
self.init_msg = 'Runtime client initialized.'
|
||||
|
||||
self.container = container
|
||||
self.buffer: list[str] = []
|
||||
self.lock = threading.Lock()
|
||||
self._stop_event = threading.Event()
|
||||
self.log_generator = container.logs(stream=True, follow=True)
|
||||
self._closed = False
|
||||
|
||||
self.log_generator = container.logs(stream=True, follow=True, tail=100)
|
||||
|
||||
self.log = logFn
|
||||
self.log_stream_thread = threading.Thread(target=self.stream_logs)
|
||||
self.log_stream_thread.daemon = True
|
||||
self.log_stream_thread.start()
|
||||
self.log = logFn
|
||||
|
||||
def append(self, log_line: str):
|
||||
"""Thread-safe append to log buffer"""
|
||||
if self._closed:
|
||||
return
|
||||
with self.lock:
|
||||
self.buffer.append(log_line)
|
||||
|
||||
def get_and_clear(self) -> list[str]:
|
||||
"""Thread-safe get and clear of log buffer"""
|
||||
with self.lock:
|
||||
logs = list(self.buffer)
|
||||
self.buffer.clear()
|
||||
return logs
|
||||
|
||||
def stream_logs(self):
|
||||
"""Stream logs from the Docker container in a separate thread.
|
||||
"""Stream logs from the Docker container in a separate thread with error handling.
|
||||
|
||||
This method runs in its own thread to handle the blocking
|
||||
operation of reading log lines from the Docker SDK's synchronous generator.
|
||||
"""
|
||||
if not self.log_generator:
|
||||
return
|
||||
|
||||
try:
|
||||
for log_line in self.log_generator:
|
||||
if self._stop_event.is_set():
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
# Use a timeout when reading from generator
|
||||
log_line = next(self.log_generator, None)
|
||||
if log_line is None:
|
||||
break
|
||||
if log_line:
|
||||
decoded_line = log_line.decode('utf-8').rstrip()
|
||||
self.append(decoded_line)
|
||||
except StopIteration:
|
||||
break
|
||||
except Exception as e:
|
||||
if not self._stop_event.is_set():
|
||||
self.log('error', f'Error reading docker logs: {e}')
|
||||
break
|
||||
if log_line:
|
||||
decoded_line = log_line.decode('utf-8').rstrip()
|
||||
self.append(decoded_line)
|
||||
except Exception as e:
|
||||
self.log('error', f'Error streaming docker logs: {e}')
|
||||
if not self._stop_event.is_set():
|
||||
self.log('error', f'Error in log stream thread: {e}')
|
||||
finally:
|
||||
self._closed = True
|
||||
|
||||
def __del__(self):
|
||||
if self.log_stream_thread.is_alive():
|
||||
"""Ensure proper cleanup on deletion"""
|
||||
if not self._closed and hasattr(self, 'log_stream_thread') and self.log_stream_thread.is_alive():
|
||||
self.log(
|
||||
'warn',
|
||||
"LogBuffer was not properly closed. Use 'log_buffer.close()' for clean shutdown.",
|
||||
)
|
||||
self.close(timeout=5)
|
||||
self.close(timeout=2)
|
||||
|
||||
def close(self, timeout: float = 5.0):
|
||||
"""Close the log buffer with proper cleanup
|
||||
|
||||
Args:
|
||||
timeout (float): Maximum time to wait for thread shutdown
|
||||
"""
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._stop_event.set()
|
||||
self.log_stream_thread.join(timeout)
|
||||
self._closed = True
|
||||
|
||||
if hasattr(self, 'log_stream_thread') and self.log_stream_thread.is_alive():
|
||||
try:
|
||||
self.log_stream_thread.join(timeout)
|
||||
except Exception as e:
|
||||
self.log('error', f'Error joining log thread: {e}')
|
||||
|
||||
# Force kill thread if it's still alive
|
||||
if self.log_stream_thread.is_alive():
|
||||
self.log('warn', 'Log thread did not shut down cleanly')
|
||||
|
||||
|
||||
class EventStreamRuntime(Runtime):
|
||||
@@ -383,22 +424,25 @@ class EventStreamRuntime(Runtime):
|
||||
)
|
||||
|
||||
@tenacity.retry(
|
||||
stop=tenacity.stop_after_delay(120) | stop_if_should_exit(),
|
||||
wait=tenacity.wait_exponential(multiplier=2, min=1, max=20),
|
||||
stop=tenacity.stop_after_delay(120) | stop_if_should_exit(), # Reduced timeout
|
||||
wait=tenacity.wait_exponential(multiplier=1, min=1, max=5), # Faster retries
|
||||
reraise=(ConnectionRefusedError,),
|
||||
)
|
||||
def _wait_until_alive(self):
|
||||
"""Wait for runtime to be ready with proper error handling and timeouts"""
|
||||
self._refresh_logs()
|
||||
if not self.log_buffer:
|
||||
raise RuntimeError('Runtime client is not ready.')
|
||||
|
||||
# Use a shorter timeout for individual requests
|
||||
response = send_request_with_retry(
|
||||
self.session,
|
||||
'GET',
|
||||
f'{self.api_url}/alive',
|
||||
retry_exceptions=[ConnectionRefusedError],
|
||||
timeout=300, # 5 minutes gives the container time to be alive 🧟♂️
|
||||
timeout=10, # Shorter timeout per request
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
@@ -406,22 +450,55 @@ class EventStreamRuntime(Runtime):
|
||||
self.log('error', msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
|
||||
def close(self, rm_all_containers: bool = True):
|
||||
"""Closes the EventStreamRuntime and associated objects
|
||||
"""Closes the EventStreamRuntime and associated objects with proper cleanup
|
||||
|
||||
Parameters:
|
||||
- rm_all_containers (bool): Whether to remove all containers with the 'openhands-sandbox-' prefix
|
||||
"""
|
||||
|
||||
if self.log_buffer:
|
||||
self.log_buffer.close()
|
||||
|
||||
print("CLOSE RUNTIME")
|
||||
# First stop any ongoing requests
|
||||
if self.session:
|
||||
self.session.close()
|
||||
try:
|
||||
self.session.close()
|
||||
except Exception as e:
|
||||
self.log('error', f'Error closing session: {e}')
|
||||
self.session = None
|
||||
|
||||
# Then close log buffer
|
||||
if self.log_buffer:
|
||||
try:
|
||||
self.log_buffer.close(timeout=2) # Short timeout for log buffer
|
||||
except Exception as e:
|
||||
self.log('error', f'Error closing log buffer: {e}')
|
||||
self.log_buffer = None
|
||||
|
||||
# Skip container cleanup if we're attached
|
||||
if self.attach_to_existing:
|
||||
return
|
||||
|
||||
# Clean up container
|
||||
if self.container:
|
||||
try:
|
||||
# Try to stop container gracefully first
|
||||
self.container.stop(timeout=5)
|
||||
except Exception as e:
|
||||
self.log('error', f'Error stopping container: {e}')
|
||||
try:
|
||||
# Force kill if graceful stop fails
|
||||
self.container.kill()
|
||||
except Exception as e2:
|
||||
self.log('error', f'Error killing container: {e2}')
|
||||
|
||||
try:
|
||||
# Remove the container
|
||||
self.container.remove(force=True)
|
||||
except Exception as e:
|
||||
self.log('error', f'Error removing container: {e}')
|
||||
|
||||
self.container = None
|
||||
|
||||
try:
|
||||
containers = self.docker_client.containers.list(all=True)
|
||||
for container in containers:
|
||||
|
||||
@@ -301,60 +301,100 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
{"action": "finish", "args": {}}
|
||||
```
|
||||
"""
|
||||
# Get protocols from Sec-WebSocket-Protocol header
|
||||
protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
|
||||
session = None
|
||||
try:
|
||||
# Get protocols from Sec-WebSocket-Protocol header
|
||||
protocols = websocket.headers.get('sec-websocket-protocol', '').split(', ')
|
||||
|
||||
# The first protocol should be our real protocol (e.g. 'openhands')
|
||||
# The second protocol should contain our auth token
|
||||
if len(protocols) < 3:
|
||||
logger.error('Expected 3 websocket protocols, got %d', len(protocols))
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
real_protocol = protocols[0]
|
||||
jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
|
||||
github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
|
||||
|
||||
if not await authenticate_github_user(github_token):
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), 10)
|
||||
|
||||
if jwt_token:
|
||||
sid = get_sid_from_token(jwt_token, config.jwt_secret)
|
||||
|
||||
if sid == '':
|
||||
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
|
||||
await websocket.close()
|
||||
# The first protocol should be our real protocol (e.g. 'openhands')
|
||||
# The second protocol should contain our auth token
|
||||
if len(protocols) < 3:
|
||||
logger.error('Expected 3 websocket protocols, got %d', len(protocols))
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
else:
|
||||
sid = str(uuid.uuid4())
|
||||
jwt_token = sign_token({'sid': sid}, config.jwt_secret)
|
||||
|
||||
logger.info(f'New session: {sid}')
|
||||
session = session_manager.add_or_restart_session(sid, websocket)
|
||||
await websocket.send_json({'token': jwt_token, 'status': 'ok'})
|
||||
real_protocol = protocols[0]
|
||||
jwt_token = protocols[1] if protocols[1] != 'NO_JWT' else ''
|
||||
github_token = protocols[2] if protocols[2] != 'NO_GITHUB' else ''
|
||||
|
||||
latest_event_id = -1
|
||||
if websocket.query_params.get('latest_event_id'):
|
||||
latest_event_id = int(websocket.query_params.get('latest_event_id'))
|
||||
for event in session.agent_session.event_stream.get_events(
|
||||
start_id=latest_event_id + 1
|
||||
):
|
||||
if isinstance(
|
||||
event,
|
||||
(
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
),
|
||||
):
|
||||
continue
|
||||
await websocket.send_json(event_to_dict(event))
|
||||
if not await authenticate_github_user(github_token):
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
|
||||
return
|
||||
|
||||
await session.loop_recv()
|
||||
try:
|
||||
await asyncio.wait_for(websocket.accept(subprotocol=real_protocol), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("WebSocket accept timed out")
|
||||
await websocket.close(code=status.WS_1408_REQUEST_TIMEOUT)
|
||||
return
|
||||
|
||||
if jwt_token:
|
||||
sid = get_sid_from_token(jwt_token, config.jwt_secret)
|
||||
|
||||
if sid == '':
|
||||
await websocket.send_json({'error': 'Invalid token', 'error_code': 401})
|
||||
await websocket.close()
|
||||
return
|
||||
else:
|
||||
sid = str(uuid.uuid4())
|
||||
jwt_token = sign_token({'sid': sid}, config.jwt_secret)
|
||||
|
||||
logger.info(f'New session: {sid}')
|
||||
session = session_manager.add_or_restart_session(sid, websocket)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
websocket.send_json({'token': jwt_token, 'status': 'ok'}),
|
||||
timeout=5
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Failed to send initial response")
|
||||
await session.close()
|
||||
await websocket.close(code=status.WS_1408_REQUEST_TIMEOUT)
|
||||
return
|
||||
|
||||
latest_event_id = -1
|
||||
if websocket.query_params.get('latest_event_id'):
|
||||
latest_event_id = int(websocket.query_params.get('latest_event_id'))
|
||||
|
||||
# Send historical events with timeout
|
||||
try:
|
||||
async with asyncio.timeout(30): # 30 second timeout for historical events
|
||||
for event in session.agent_session.event_stream.get_events(
|
||||
start_id=latest_event_id + 1
|
||||
):
|
||||
if isinstance(
|
||||
event,
|
||||
(
|
||||
NullAction,
|
||||
NullObservation,
|
||||
ChangeAgentStateAction,
|
||||
AgentStateChangedObservation,
|
||||
),
|
||||
):
|
||||
continue
|
||||
await websocket.send_json(event_to_dict(event))
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Timeout sending historical events")
|
||||
await session.close()
|
||||
await websocket.close(code=status.WS_1408_REQUEST_TIMEOUT)
|
||||
return
|
||||
|
||||
# Start the main receive loop with heartbeat
|
||||
await session.loop_recv()
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info("WebSocket disconnected normally")
|
||||
except Exception as e:
|
||||
logger.exception("Error in websocket handler: %s", str(e))
|
||||
finally:
|
||||
if session:
|
||||
try:
|
||||
await asyncio.wait_for(session.close(), timeout=5)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Timeout during session cleanup")
|
||||
except Exception as e:
|
||||
logger.exception("Error during session cleanup: %s", str(e))
|
||||
|
||||
|
||||
@app.get('/api/options/models')
|
||||
|
||||
@@ -55,22 +55,39 @@ class AgentSession:
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
):
|
||||
"""Starts the Agent session
|
||||
"""Starts the Agent session with proper error handling and timeouts
|
||||
Parameters:
|
||||
- runtime_name: The name of the runtime associated with the session
|
||||
- config:
|
||||
- agent:
|
||||
- max_iterations:
|
||||
- max_budget_per_task:
|
||||
- agent_to_llm_config:
|
||||
- agent_configs:
|
||||
- config: Application configuration
|
||||
- agent: Agent instance to use
|
||||
- max_iterations: Maximum number of iterations
|
||||
- max_budget_per_task: Maximum budget per task
|
||||
- agent_to_llm_config: LLM configurations for different agents
|
||||
- agent_configs: Agent configurations
|
||||
- status_message_callback: Callback for status updates
|
||||
"""
|
||||
if self.controller or self.runtime:
|
||||
raise RuntimeError(
|
||||
'Session already started. You need to close this session and start a new one.'
|
||||
)
|
||||
|
||||
asyncio.get_event_loop().run_in_executor(
|
||||
# Create a future to track the start operation
|
||||
start_future = asyncio.Future()
|
||||
|
||||
def start_callback(future):
|
||||
try:
|
||||
exc = future.exception()
|
||||
if exc:
|
||||
start_future.set_exception(exc)
|
||||
else:
|
||||
start_future.set_result(None)
|
||||
except asyncio.CancelledError:
|
||||
start_future.cancel()
|
||||
except Exception as e:
|
||||
start_future.set_exception(e)
|
||||
|
||||
# Start the agent in a thread pool with proper error propagation
|
||||
task = asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
self._start_thread,
|
||||
runtime_name,
|
||||
@@ -82,13 +99,49 @@ class AgentSession:
|
||||
agent_configs,
|
||||
status_message_callback,
|
||||
)
|
||||
task.add_done_callback(start_callback)
|
||||
|
||||
try:
|
||||
# Wait for start with timeout
|
||||
await asyncio.wait_for(start_future, timeout=120)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Agent session start timed out")
|
||||
# Cleanup if start times out
|
||||
await self.close()
|
||||
raise RuntimeError("Agent session start timed out")
|
||||
except Exception as e:
|
||||
logger.exception("Error starting agent session")
|
||||
await self.close()
|
||||
raise
|
||||
|
||||
def _start_thread(self, *args):
|
||||
"""Start the agent in a separate thread with proper error handling"""
|
||||
try:
|
||||
asyncio.run(self._start(*args), debug=True)
|
||||
except RuntimeError:
|
||||
logger.error(f'Error starting session: {RuntimeError}', exc_info=True)
|
||||
logger.debug('Session Finished')
|
||||
# Create new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
# Run with timeout
|
||||
loop.run_until_complete(
|
||||
asyncio.wait_for(self._start(*args), timeout=25)
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Timeout in agent start")
|
||||
raise RuntimeError("Timeout in agent start")
|
||||
except Exception as e:
|
||||
logger.exception("Error in agent start")
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
# Clean up the loop
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
loop.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up thread loop: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in start thread: {e}")
|
||||
raise
|
||||
|
||||
async def _start(
|
||||
self,
|
||||
@@ -101,7 +154,6 @@ class AgentSession:
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
status_message_callback: Optional[Callable] = None,
|
||||
):
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self._create_security_analyzer(config.security.security_analyzer)
|
||||
await self._create_runtime(
|
||||
runtime_name=runtime_name,
|
||||
@@ -125,29 +177,60 @@ class AgentSession:
|
||||
await self.controller.agent_task # type: ignore
|
||||
|
||||
async def close(self):
|
||||
"""Closes the Agent session"""
|
||||
|
||||
"""Closes the Agent session with proper cleanup and timeouts"""
|
||||
if self._closed:
|
||||
return
|
||||
if self.controller is not None:
|
||||
end_state = self.controller.get_state()
|
||||
end_state.save_to_session(self.sid, self.file_store)
|
||||
await self.controller.close()
|
||||
if self.runtime is not None:
|
||||
self.runtime.close()
|
||||
if self.security_analyzer is not None:
|
||||
await self.security_analyzer.close()
|
||||
|
||||
if self.loop:
|
||||
if self.loop.is_closed():
|
||||
logger.debug(
|
||||
'Trying to close already closed loop. (It probably never started correctly)'
|
||||
)
|
||||
else:
|
||||
self.loop.stop()
|
||||
self.loop = None
|
||||
try:
|
||||
# Set closed flag early to prevent multiple close attempts
|
||||
self._closed = True
|
||||
|
||||
self._closed = True
|
||||
# Save state with timeout
|
||||
if self.controller is not None:
|
||||
try:
|
||||
async with asyncio.timeout(5):
|
||||
end_state = self.controller.get_state()
|
||||
end_state.save_to_session(self.sid, self.file_store)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Timeout saving agent state")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving agent state: {e}")
|
||||
|
||||
# Close controller with timeout
|
||||
if self.controller is not None:
|
||||
try:
|
||||
async with asyncio.timeout(5):
|
||||
await self.controller.close()
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Timeout closing controller")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing controller: {e}")
|
||||
self.controller = None
|
||||
|
||||
# Close runtime (this is synchronous but should be quick)
|
||||
if self.runtime is not None:
|
||||
try:
|
||||
self.runtime.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing runtime: {e}")
|
||||
self.runtime = None
|
||||
|
||||
# Close security analyzer with timeout
|
||||
if self.security_analyzer is not None:
|
||||
try:
|
||||
async with asyncio.timeout(5):
|
||||
await self.security_analyzer.close()
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Timeout closing security analyzer")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing security analyzer: {e}")
|
||||
self.security_analyzer = None
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error during session cleanup: {e}")
|
||||
finally:
|
||||
# Ensure closed flag is set even if cleanup fails
|
||||
self._closed = True
|
||||
|
||||
def _create_security_analyzer(self, security_analyzer: str | None):
|
||||
"""Creates a SecurityAnalyzer instance that will be used to analyze the agent actions
|
||||
|
||||
@@ -51,23 +51,70 @@ class Session:
|
||||
self.is_alive = False
|
||||
await self.agent_session.close()
|
||||
|
||||
async def _heartbeat(self):
|
||||
"""Send periodic heartbeat to check connection is alive"""
|
||||
while self.is_alive and should_continue():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.websocket.send_json({"type": "heartbeat"}),
|
||||
timeout=5
|
||||
)
|
||||
await asyncio.sleep(1) # Send heartbeat every 30 seconds
|
||||
except (asyncio.TimeoutError, WebSocketDisconnect, RuntimeError):
|
||||
self.is_alive = False
|
||||
break
|
||||
except Exception as e:
|
||||
logger.exception("Error in heartbeat: %s", e)
|
||||
self.is_alive = False
|
||||
break
|
||||
|
||||
async def loop_recv(self):
|
||||
"""Main websocket receive loop with heartbeat"""
|
||||
if self.websocket is None:
|
||||
return
|
||||
|
||||
# Start heartbeat in background task
|
||||
heartbeat_task = asyncio.create_task(self._heartbeat())
|
||||
|
||||
try:
|
||||
if self.websocket is None:
|
||||
return
|
||||
while should_continue():
|
||||
while self.is_alive and should_continue():
|
||||
try:
|
||||
data = await self.websocket.receive_json()
|
||||
# Use timeout to prevent blocking forever
|
||||
data = await asyncio.wait_for(
|
||||
self.websocket.receive_json(),
|
||||
timeout=35 # Slightly longer than heartbeat interval
|
||||
)
|
||||
await self.dispatch(data)
|
||||
except asyncio.TimeoutError:
|
||||
# No message received within timeout, check if heartbeat is still alive
|
||||
if not heartbeat_task.done():
|
||||
continue
|
||||
else:
|
||||
logger.error("Heartbeat task failed, closing connection")
|
||||
break
|
||||
except ValueError:
|
||||
await self.send_error('Invalid JSON')
|
||||
continue
|
||||
await self.dispatch(data)
|
||||
except WebSocketDisconnect:
|
||||
except WebSocketDisconnect:
|
||||
logger.debug('WebSocket disconnected, sid: %s', self.sid)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.exception("Error processing message: %s", e)
|
||||
await self.send_error(f"Error processing message: {str(e)}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Fatal error in receive loop: %s", e)
|
||||
finally:
|
||||
# Cancel heartbeat task
|
||||
heartbeat_task.cancel()
|
||||
try:
|
||||
await heartbeat_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Close the session
|
||||
await self.close()
|
||||
logger.debug('WebSocket disconnected, sid: %s', self.sid)
|
||||
except RuntimeError as e:
|
||||
await self.close()
|
||||
logger.exception('Error in loop_recv: %s', e)
|
||||
|
||||
async def _initialize_agent(self, data: dict):
|
||||
self.agent_session.event_stream.add_event(
|
||||
@@ -152,12 +199,30 @@ class Session:
|
||||
await self.send(event_dict)
|
||||
|
||||
async def dispatch(self, data: dict):
|
||||
"""Dispatch incoming websocket messages to appropriate handlers"""
|
||||
action = data.get('action', '')
|
||||
|
||||
# Handle initialization separately
|
||||
if action == ActionType.INIT:
|
||||
await self._initialize_agent(data)
|
||||
try:
|
||||
async with asyncio.timeout(120):
|
||||
await self._initialize_agent(data)
|
||||
except asyncio.TimeoutError:
|
||||
await self.send_error('Agent initialization timed out')
|
||||
except Exception as e:
|
||||
logger.exception("Error initializing agent: %s", e)
|
||||
await self.send_error(f'Failed to initialize agent: {str(e)}')
|
||||
return
|
||||
event = event_from_dict(data.copy())
|
||||
# This checks if the model supports images
|
||||
|
||||
# Convert message to event
|
||||
try:
|
||||
event = event_from_dict(data.copy())
|
||||
except Exception as e:
|
||||
logger.error("Failed to parse event: %s", e)
|
||||
await self.send_error('Invalid event format')
|
||||
return
|
||||
|
||||
# Handle image validation
|
||||
if isinstance(event, MessageAction) and event.images_urls:
|
||||
controller = self.agent_session.controller
|
||||
if controller:
|
||||
@@ -171,10 +236,25 @@ class Session:
|
||||
'Model does not support image upload, change to a different model or try without an image.'
|
||||
)
|
||||
return
|
||||
if self.agent_session.loop:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._add_event(event, EventSource.USER), self.agent_session.loop
|
||||
) # type: ignore
|
||||
|
||||
|
||||
try:
|
||||
# Use asyncio.wait_for to prevent blocking indefinitely
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._add_event(event, EventSource.USER),
|
||||
asyncio.get_running_loop()
|
||||
)
|
||||
# Wait for the event to be processed with timeout
|
||||
await asyncio.wait_for(
|
||||
asyncio.wrap_future(future),
|
||||
timeout=10
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Event processing timed out")
|
||||
await self.send_error('Event processing timed out')
|
||||
except Exception as e:
|
||||
logger.exception("Error processing event: %s", e)
|
||||
await self.send_error(f'Failed to process event: {str(e)}')
|
||||
|
||||
async def _add_event(self, event, event_source):
|
||||
self.agent_session.event_stream.add_event(event, EventSource.USER)
|
||||
|
||||
Reference in New Issue
Block a user