Add close method to EventStream (#6093)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: tofarr <tofarr@gmail.com>
This commit is contained in:
Robert Brennan
2025-01-06 16:59:42 -05:00
committed by GitHub
parent 9515ac5e62
commit 8cfcdd7ba3
4 changed files with 77 additions and 11 deletions

View File

@@ -1,9 +1,10 @@
import asyncio
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from queue import Queue
from functools import partial
from typing import Callable, Iterable
from openhands.core.logger import openhands_logger as logger
@@ -61,12 +62,19 @@ class EventStream:
_subscribers: dict[str, dict[str, Callable]]
_cur_id: int = 0
_lock: threading.Lock
_queue: queue.Queue[Event]
_queue_thread: threading.Thread
_queue_loop: asyncio.AbstractEventLoop | None
_thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]]
def __init__(self, sid: str, file_store: FileStore, num_workers: int = 1):
def __init__(self, sid: str, file_store: FileStore):
self.sid = sid
self.file_store = file_store
self._queue: Queue[Event] = Queue()
self._stop_flag = threading.Event()
self._queue: queue.Queue[Event] = queue.Queue()
self._thread_pools: dict[str, dict[str, ThreadPoolExecutor]] = {}
self._thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]] = {}
self._queue_loop = None
self._queue_thread = threading.Thread(target=self._run_queue_loop)
self._queue_thread.daemon = True
self._queue_thread.start()
@@ -91,9 +99,54 @@ class EventStream:
if id >= self._cur_id:
self._cur_id = id + 1
def _init_thread_loop(self):
def _init_thread_loop(self, subscriber_id: str, callback_id: str):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if subscriber_id not in self._thread_loops:
self._thread_loops[subscriber_id] = {}
self._thread_loops[subscriber_id][callback_id] = loop
def close(self):
self._stop_flag.set()
if self._queue_thread.is_alive():
self._queue_thread.join()
subscriber_ids = list(self._subscribers.keys())
for subscriber_id in subscriber_ids:
callback_ids = list(self._subscribers[subscriber_id].keys())
for callback_id in callback_ids:
self._clean_up_subscriber(subscriber_id, callback_id)
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str):
if subscriber_id not in self._subscribers:
logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')
return
if callback_id not in self._subscribers[subscriber_id]:
logger.warning(f'Callback not found during cleanup: {callback_id}')
return
if (
subscriber_id in self._thread_loops
and callback_id in self._thread_loops[subscriber_id]
):
loop = self._thread_loops[subscriber_id][callback_id]
try:
loop.stop()
loop.close()
except Exception as e:
logger.warning(
f'Error closing loop for {subscriber_id}/{callback_id}: {e}'
)
del self._thread_loops[subscriber_id][callback_id]
if (
subscriber_id in self._thread_pools
and callback_id in self._thread_pools[subscriber_id]
):
pool = self._thread_pools[subscriber_id][callback_id]
pool.shutdown()
del self._thread_pools[subscriber_id][callback_id]
del self._subscribers[subscriber_id][callback_id]
def _get_filename_for_id(self, id: int) -> str:
return get_conversation_event_filename(self.sid, id)
@@ -176,7 +229,8 @@ class EventStream:
def subscribe(
self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
):
pool = ThreadPoolExecutor(max_workers=1, initializer=self._init_thread_loop)
initializer = partial(self._init_thread_loop, subscriber_id, callback_id)
pool = ThreadPoolExecutor(max_workers=1, initializer=initializer)
if subscriber_id not in self._subscribers:
self._subscribers[subscriber_id] = {}
self._thread_pools[subscriber_id] = {}
@@ -198,7 +252,7 @@ class EventStream:
logger.warning(f'Callback not found during unsubscribe: {callback_id}')
return
del self._subscribers[subscriber_id][callback_id]
self._clean_up_subscriber(subscriber_id, callback_id)
def add_event(self, event: Event, source: EventSource):
if hasattr(event, '_id') and event.id is not None:
@@ -217,13 +271,20 @@ class EventStream:
self._queue.put(event)
def _run_queue_loop(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._process_queue())
self._queue_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._queue_loop)
try:
self._queue_loop.run_until_complete(self._process_queue())
finally:
self._queue_loop.close()
async def _process_queue(self):
while should_continue():
event = self._queue.get()
while should_continue() and not self._stop_flag.is_set():
event = None
try:
event = self._queue.get(timeout=0.1)
except queue.Empty:
continue
for key in sorted(self._subscribers.keys()):
callbacks = self._subscribers[key]
for callback_id in callbacks:

View File

@@ -131,6 +131,8 @@ class AgentSession:
f'Waited too long for initialization to finish before closing session {self.sid}'
)
break
if self.event_stream is not None:
self.event_stream.close()
if self.controller is not None:
end_state = self.controller.get_state()
end_state.save_to_session(self.sid, self.file_store)

View File

@@ -43,4 +43,6 @@ class Conversation:
await self.runtime.connect()
async def disconnect(self):
if self.event_stream:
self.event_stream.close()
asyncio.create_task(call_sync_from_async(self.runtime.close))

View File

@@ -201,6 +201,7 @@ class SessionManager:
await c.connect()
except AgentRuntimeUnavailableError as e:
logger.error(f'Error connecting to conversation {c.sid}: {e}')
await c.disconnect()
return None
end_time = time.time()
logger.info(