mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
4 Commits
replace-si
...
openhands-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
68f5275c4c | ||
|
|
485a844d5e | ||
|
|
2930f67451 | ||
|
|
dd3e54f27c |
@@ -724,9 +724,6 @@ class AgentController:
|
||||
# update iteration that is shared across agents
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
|
||||
# close the delegate controller before adding new events
|
||||
asyncio.get_event_loop().run_until_complete(self.delegate.close())
|
||||
|
||||
if delegate_state in (AgentState.FINISHED, AgentState.REJECTED):
|
||||
# retrieve delegate result
|
||||
delegate_outputs = (
|
||||
@@ -753,6 +750,11 @@ class AgentController:
|
||||
|
||||
content = f'Delegated agent finished with result:\n\n{content}'
|
||||
|
||||
# close the delegate controller
|
||||
delegate = self.delegate
|
||||
self.delegate = None # unset delegate so parent can resume normal handling
|
||||
asyncio.get_event_loop().run_until_complete(delegate.close())
|
||||
|
||||
# emit the delegate result observation
|
||||
obs = AgentDelegateObservation(outputs=delegate_outputs, content=content)
|
||||
|
||||
@@ -765,9 +767,6 @@ class AgentController:
|
||||
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
|
||||
# unset delegate so parent can resume normal handling
|
||||
self.delegate = None
|
||||
|
||||
async def _step(self) -> None:
|
||||
"""Executes a single step of the parent or delegate agent. Detects stuck agents and limits on the number of iterations and the task budget."""
|
||||
if self.get_agent_state() != AgentState.RUNNING:
|
||||
|
||||
@@ -77,10 +77,12 @@ class EventStream(EventStore):
|
||||
self._thread_loops[subscriber_id][callback_id] = loop
|
||||
|
||||
def close(self) -> None:
|
||||
# Set stop flag and wait for queue thread to finish
|
||||
self._stop_flag.set()
|
||||
if self._queue_thread.is_alive():
|
||||
self._queue_thread.join()
|
||||
|
||||
# Clean up all subscribers and their resources
|
||||
subscriber_ids = list(self._subscribers.keys())
|
||||
for subscriber_id in subscriber_ids:
|
||||
callback_ids = list(self._subscribers[subscriber_id].keys())
|
||||
@@ -91,6 +93,22 @@ class EventStream(EventStore):
|
||||
while not self._queue.empty():
|
||||
self._queue.get()
|
||||
|
||||
# Clean up queue loop
|
||||
if self._queue_loop is not None:
|
||||
try:
|
||||
self._queue_loop.stop()
|
||||
self._queue_loop.close()
|
||||
except Exception as e:
|
||||
logger.warning(f'Error closing queue loop: {e}')
|
||||
self._queue_loop = None
|
||||
|
||||
# Reset state
|
||||
self.cur_id = 0
|
||||
self._write_page_cache = []
|
||||
self._thread_pools.clear()
|
||||
self._thread_loops.clear()
|
||||
self._subscribers.clear()
|
||||
|
||||
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str) -> None:
|
||||
if subscriber_id not in self._subscribers:
|
||||
logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')
|
||||
|
||||
@@ -15,6 +15,7 @@ from openhands.events import EventSource, EventStream
|
||||
from openhands.events.action import (
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.agent import RecallAction
|
||||
@@ -28,11 +29,17 @@ from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_stream():
|
||||
def mock_event_stream(request):
|
||||
"""Creates an event stream in memory."""
|
||||
sid = f'test-{uuid4()}'
|
||||
file_store = InMemoryFileStore({})
|
||||
return EventStream(sid=sid, file_store=file_store)
|
||||
stream = EventStream(sid=sid, file_store=file_store)
|
||||
|
||||
def cleanup():
|
||||
stream.close()
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
return stream
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -48,10 +55,13 @@ def mock_parent_agent():
|
||||
# Add a proper system message mock
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = -1 # Set invalid ID to avoid the ID check
|
||||
agent.get_system_message.return_value = system_message
|
||||
def get_system_message():
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = Event.INVALID_ID # Set invalid ID to avoid the ID check
|
||||
return system_message
|
||||
|
||||
agent.get_system_message.side_effect = get_system_message
|
||||
|
||||
return agent
|
||||
|
||||
@@ -69,10 +79,13 @@ def mock_child_agent():
|
||||
# Add a proper system message mock
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = -1 # Set invalid ID to avoid the ID check
|
||||
agent.get_system_message.return_value = system_message
|
||||
def get_system_message():
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = Event.INVALID_ID # Set invalid ID to avoid the ID check
|
||||
return system_message
|
||||
|
||||
agent.get_system_message.side_effect = get_system_message
|
||||
|
||||
return agent
|
||||
|
||||
@@ -207,24 +220,18 @@ async def test_delegate_step_different_states(
|
||||
mock_delegate._step = AsyncMock()
|
||||
mock_delegate.close = AsyncMock()
|
||||
|
||||
def call_on_event_with_new_loop():
|
||||
"""
|
||||
In this thread, create and set a fresh event loop, so that the run_until_complete()
|
||||
calls inside controller.on_event(...) find a valid loop.
|
||||
"""
|
||||
loop_in_thread = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop_in_thread)
|
||||
msg_action = MessageAction(content='Test message')
|
||||
msg_action._source = EventSource.USER
|
||||
controller.on_event(msg_action)
|
||||
finally:
|
||||
loop_in_thread.close()
|
||||
# First send a message to trigger the delegate state check
|
||||
msg_action = MessageAction(content='Test message')
|
||||
msg_action._source = EventSource.USER
|
||||
await controller._on_event(msg_action)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
with ThreadPoolExecutor() as executor:
|
||||
future = loop.run_in_executor(executor, call_on_event_with_new_loop)
|
||||
await future
|
||||
# Then send a reject action to trigger delegate cleanup
|
||||
reject_action = AgentRejectAction()
|
||||
reject_action._source = EventSource.USER
|
||||
await controller._on_event(reject_action)
|
||||
|
||||
# Give a little time for async operations to complete
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if delegate_state == AgentState.RUNNING:
|
||||
assert controller.delegate is not None
|
||||
|
||||
@@ -12,7 +12,7 @@ from openhands.core.main import run_controller
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.action.message import MessageAction, SystemMessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation.agent import (
|
||||
RecallObservation,
|
||||
RecallType,
|
||||
@@ -34,9 +34,18 @@ def file_store():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def event_stream(file_store):
|
||||
def event_stream(file_store, request):
|
||||
"""Create a test event stream."""
|
||||
return EventStream(sid='test_sid', file_store=file_store)
|
||||
stream = EventStream(sid='test_sid', file_store=file_store)
|
||||
|
||||
def cleanup():
|
||||
# Ensure all subscribers are removed
|
||||
stream._subscribers.clear()
|
||||
# Close the stream
|
||||
stream.close()
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
return stream
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -69,10 +78,15 @@ def mock_agent():
|
||||
agent.llm.config = AppConfig().get_llm_config()
|
||||
|
||||
# Add a proper system message mock
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = -1 # Set invalid ID to avoid the ID check
|
||||
agent.get_system_message.return_value = system_message
|
||||
def get_system_message():
|
||||
system_message = SystemMessageAction(content='Test system message')
|
||||
system_message._source = EventSource.AGENT
|
||||
system_message._id = Event.INVALID_ID # Set invalid ID to avoid the ID check
|
||||
return system_message
|
||||
|
||||
agent.get_system_message.side_effect = get_system_message
|
||||
agent.name = 'TestAgent' # Add a name to the agent
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user