Compare commits

...

4 Commits

4 changed files with 78 additions and 40 deletions

View File

@@ -724,9 +724,6 @@ class AgentController:
# update iteration that is shared across agents
self.state.iteration = self.delegate.state.iteration
# close the delegate controller before adding new events
asyncio.get_event_loop().run_until_complete(self.delegate.close())
if delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
# retrieve delegate result
delegate_outputs = (
@@ -753,6 +750,11 @@ class AgentController:
content = f'Delegated agent finished with result:\n\n{content}'
# close the delegate controller
delegate = self.delegate
self.delegate = None # unset delegate so parent can resume normal handling
asyncio.get_event_loop().run_until_complete(delegate.close())
# emit the delegate result observation
obs = AgentDelegateObservation(outputs=delegate_outputs, content=content)
@@ -765,9 +767,6 @@ class AgentController:
self.event_stream.add_event(obs, EventSource.AGENT)
# unset delegate so parent can resume normal handling
self.delegate = None
async def _step(self) -> None:
"""Executes a single step of the parent or delegate agent. Detects stuck agents and limits on the number of iterations and the task budget."""
if self.get_agent_state() != AgentState.RUNNING:

View File

@@ -77,10 +77,12 @@ class EventStream(EventStore):
self._thread_loops[subscriber_id][callback_id] = loop
def close(self) -> None:
# Set stop flag and wait for queue thread to finish
self._stop_flag.set()
if self._queue_thread.is_alive():
self._queue_thread.join()
# Clean up all subscribers and their resources
subscriber_ids = list(self._subscribers.keys())
for subscriber_id in subscriber_ids:
callback_ids = list(self._subscribers[subscriber_id].keys())
@@ -91,6 +93,22 @@ class EventStream(EventStore):
while not self._queue.empty():
self._queue.get()
# Clean up queue loop
if self._queue_loop is not None:
try:
self._queue_loop.stop()
self._queue_loop.close()
except Exception as e:
logger.warning(f'Error closing queue loop: {e}')
self._queue_loop = None
# Reset state
self.cur_id = 0
self._write_page_cache = []
self._thread_pools.clear()
self._thread_loops.clear()
self._subscribers.clear()
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str) -> None:
if subscriber_id not in self._subscribers:
logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')

View File

@@ -15,6 +15,7 @@ from openhands.events import EventSource, EventStream
from openhands.events.action import (
AgentDelegateAction,
AgentFinishAction,
AgentRejectAction,
MessageAction,
)
from openhands.events.action.agent import RecallAction
@@ -28,11 +29,17 @@ from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def mock_event_stream(request):
    """Create an in-memory event stream that is closed at test teardown.

    A fresh session id is generated per fixture instance, and the stream is
    backed by an empty ``InMemoryFileStore``.
    """
    sid = f'test-{uuid4()}'
    file_store = InMemoryFileStore({})
    stream = EventStream(sid=sid, file_store=file_store)

    def cleanup():
        # Registered as a pytest finalizer below, so it runs after the
        # requesting test has finished.
        stream.close()

    request.addfinalizer(cleanup)
    return stream
@pytest.fixture
@@ -48,10 +55,13 @@ def mock_parent_agent():
# Add a proper system message mock
from openhands.events.action.message import SystemMessageAction
system_message = SystemMessageAction(content='Test system message')
system_message._source = EventSource.AGENT
system_message._id = -1 # Set invalid ID to avoid the ID check
agent.get_system_message.return_value = system_message
def get_system_message():
    """Build and return a brand-new system message on every call."""
    fresh_message = SystemMessageAction(content='Test system message')
    fresh_message._source = EventSource.AGENT
    # Set invalid ID to avoid the ID check
    fresh_message._id = Event.INVALID_ID
    return fresh_message
agent.get_system_message.side_effect = get_system_message
return agent
@@ -69,10 +79,13 @@ def mock_child_agent():
# Add a proper system message mock
from openhands.events.action.message import SystemMessageAction
system_message = SystemMessageAction(content='Test system message')
system_message._source = EventSource.AGENT
system_message._id = -1 # Set invalid ID to avoid the ID check
agent.get_system_message.return_value = system_message
def get_system_message():
    """Build and return a brand-new system message on every call."""
    fresh_message = SystemMessageAction(content='Test system message')
    fresh_message._source = EventSource.AGENT
    # Set invalid ID to avoid the ID check
    fresh_message._id = Event.INVALID_ID
    return fresh_message
agent.get_system_message.side_effect = get_system_message
return agent
@@ -207,24 +220,18 @@ async def test_delegate_step_different_states(
mock_delegate._step = AsyncMock()
mock_delegate.close = AsyncMock()
def call_on_event_with_new_loop():
    """
    In this thread, create and set a fresh event loop, so that the run_until_complete()
    calls inside controller.on_event(...) find a valid loop.
    """
    loop_in_thread = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop_in_thread)
        msg_action = MessageAction(content='Test message')
        msg_action._source = EventSource.USER
        controller.on_event(msg_action)
    finally:
        # Detach the loop from this thread before closing it, so a later
        # asyncio.get_event_loop() call in this thread cannot return a closed loop.
        asyncio.set_event_loop(None)
        loop_in_thread.close()
# First send a message to trigger the delegate state check
msg_action = MessageAction(content='Test message')
msg_action._source = EventSource.USER
await controller._on_event(msg_action)
loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as executor:
future = loop.run_in_executor(executor, call_on_event_with_new_loop)
await future
# Then send a reject action to trigger delegate cleanup
reject_action = AgentRejectAction()
reject_action._source = EventSource.USER
await controller._on_event(reject_action)
# Give a little time for async operations to complete
await asyncio.sleep(0.1)
if delegate_state == AgentState.RUNNING:
assert controller.delegate is not None

View File

@@ -12,7 +12,7 @@ from openhands.core.main import run_controller
from openhands.core.schema.agent import AgentState
from openhands.events.action.agent import RecallAction
from openhands.events.action.message import MessageAction, SystemMessageAction
from openhands.events.event import EventSource
from openhands.events.event import Event, EventSource
from openhands.events.observation.agent import (
RecallObservation,
RecallType,
@@ -34,9 +34,18 @@ def file_store():
@pytest.fixture
def event_stream(file_store, request):
    """Create a test event stream that is closed at test teardown.

    The stream uses the fixed session id ``'test_sid'`` and the ``file_store``
    fixture as its backing store.
    """
    stream = EventStream(sid='test_sid', file_store=file_store)

    def cleanup():
        # Ensure all subscribers are removed
        # NOTE(review): emptying the subscriber map before close() presumably
        # leaves close() with no per-subscriber cleanup to do -- confirm that
        # skipping it is intended.
        stream._subscribers.clear()
        # Close the stream
        stream.close()

    # Registered as a pytest finalizer, so cleanup runs after the requesting test.
    request.addfinalizer(cleanup)
    return stream
@pytest.fixture
@@ -69,10 +78,15 @@ def mock_agent():
agent.llm.config = AppConfig().get_llm_config()
# Add a proper system message mock
system_message = SystemMessageAction(content='Test system message')
system_message._source = EventSource.AGENT
system_message._id = -1 # Set invalid ID to avoid the ID check
agent.get_system_message.return_value = system_message
def get_system_message():
    """Build and return a brand-new system message on every call."""
    fresh_message = SystemMessageAction(content='Test system message')
    fresh_message._source = EventSource.AGENT
    # Set invalid ID to avoid the ID check
    fresh_message._id = Event.INVALID_ID
    return fresh_message
agent.get_system_message.side_effect = get_system_message
agent.name = 'TestAgent' # Add a name to the agent
return agent
@pytest.mark.asyncio