From 804674bb9f601dd42e2489df2cae5208c8145617 Mon Sep 17 00:00:00 2001 From: niliy01 Date: Tue, 17 Sep 2024 02:13:52 +0800 Subject: [PATCH] refactor the logic in agent_controller to imporve readability (#3873) Signed-off-by: Yi Lin --- openhands/controller/agent_controller.py | 321 +++++++++++++---------- tests/unit/test_agent_controller.py | 51 +++- 2 files changed, 226 insertions(+), 146 deletions(-) diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index ae32f0f351..09d1c02a46 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -172,56 +172,83 @@ class AgentController: Args: event (Event): The incoming event to process. """ - if isinstance(event, ChangeAgentStateAction): - await self.set_agent_state_to(event.agent_state) # type: ignore - elif isinstance(event, MessageAction): - if event.source == EventSource.USER: - logger.info( - event, - extra={'msg_type': 'ACTION', 'event_source': EventSource.USER}, - ) - if self.get_agent_state() != AgentState.RUNNING: - await self.set_agent_state_to(AgentState.RUNNING) - elif event.source == EventSource.AGENT and event.wait_for_response: - await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT) - elif isinstance(event, AgentDelegateAction): - await self.start_delegate(event) - elif isinstance(event, AddTaskAction): - self.state.root_task.add_subtask(event.parent, event.goal, event.subtasks) - elif isinstance(event, ModifyTaskAction): - self.state.root_task.set_subtask_state(event.task_id, event.state) - elif isinstance(event, AgentFinishAction): - self.state.outputs = event.outputs + if isinstance(event, Action): + await self._handle_action(event) + elif isinstance(event, Observation): + await self._handle_observation(event) + + async def _handle_action(self, action: Action): + """Handles actions from the event stream. + + Args: + action (Action): The action to handle. + """ + if isinstance(action, ChangeAgentStateAction): + await self.set_agent_state_to(action.agent_state) # type: ignore + elif isinstance(action, MessageAction): + await self._handle_message_action(action) + elif isinstance(action, AgentDelegateAction): + await self.start_delegate(action) + elif isinstance(action, AddTaskAction): + self.state.root_task.add_subtask( + action.parent, action.goal, action.subtasks + ) + elif isinstance(action, ModifyTaskAction): + self.state.root_task.set_subtask_state(action.task_id, action.state) + elif isinstance(action, AgentFinishAction): + self.state.outputs = action.outputs self.state.metrics.merge(self.state.local_metrics) await self.set_agent_state_to(AgentState.FINISHED) - elif isinstance(event, AgentRejectAction): - self.state.outputs = event.outputs + elif isinstance(action, AgentRejectAction): + self.state.outputs = action.outputs self.state.metrics.merge(self.state.local_metrics) await self.set_agent_state_to(AgentState.REJECTED) - elif isinstance(event, Observation): - if ( - self._pending_action - and hasattr(self._pending_action, 'is_confirmed') - and self._pending_action.is_confirmed - == ActionConfirmationStatus.AWAITING_CONFIRMATION - ): - return - if self._pending_action and self._pending_action.id == event.cause: - self._pending_action = None - if self.state.agent_state == AgentState.USER_CONFIRMED: - await self.set_agent_state_to(AgentState.RUNNING) - if self.state.agent_state == AgentState.USER_REJECTED: - await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT) - logger.info(event, extra={'msg_type': 'OBSERVATION'}) - elif isinstance(event, CmdOutputObservation): - logger.info(event, extra={'msg_type': 'OBSERVATION'}) - elif isinstance(event, AgentDelegateObservation): - self.state.history.on_event(event) - logger.info(event, extra={'msg_type': 'OBSERVATION'}) - elif isinstance(event, ErrorObservation): - logger.info(event, extra={'msg_type': 'OBSERVATION'}) - if self.state.agent_state == AgentState.ERROR: - self.state.metrics.merge(self.state.local_metrics) + + async def _handle_observation(self, observation: Observation): + """Handles observation from the event stream. + + Args: + observation (observation): The observation to handle. + """ + if ( + self._pending_action + and hasattr(self._pending_action, 'is_confirmed') + and self._pending_action.is_confirmed + == ActionConfirmationStatus.AWAITING_CONFIRMATION + ): + return + + logger.info(observation, extra={'msg_type': 'OBSERVATION'}) + if self._pending_action and self._pending_action.id == observation.cause: + self._pending_action = None + if self.state.agent_state == AgentState.USER_CONFIRMED: + await self.set_agent_state_to(AgentState.RUNNING) + if self.state.agent_state == AgentState.USER_REJECTED: + await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT) + return + + if isinstance(observation, CmdOutputObservation): + return + elif isinstance(observation, AgentDelegateObservation): + self.state.history.on_event(observation) + elif isinstance(observation, ErrorObservation): + if self.state.agent_state == AgentState.ERROR: + self.state.metrics.merge(self.state.local_metrics) + + async def _handle_message_action(self, action: MessageAction): + """Handles message actions from the event stream. + + Args: + action (MessageAction): The message action to handle. + """ + if action.source == EventSource.USER: + logger.info( + action, extra={'msg_type': 'ACTION', 'event_source': EventSource.USER} + ) + if self.get_agent_state() != AgentState.RUNNING: + await self.set_agent_state_to(AgentState.RUNNING) + elif action.source == EventSource.AGENT and action.wait_for_response: + await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT) def reset_task(self): """Resets the agent's task.""" @@ -242,9 +269,11 @@ class AgentController: if new_state == self.state.agent_state: return - if ( - self.state.agent_state == AgentState.PAUSED - and new_state == AgentState.RUNNING + if new_state == AgentState.STOPPED or new_state == AgentState.ERROR: + self.reset_task() + elif ( + new_state == AgentState.RUNNING + and self.state.agent_state == AgentState.PAUSED and self.state.traffic_control_state == TrafficControlState.THROTTLING ): # user intends to interrupt traffic control and let the task resume temporarily @@ -257,6 +286,7 @@ class AgentController: ): if self.state.iteration >= self.state.max_iterations: self.state.max_iterations += self._initial_max_iterations + if ( self.state.metrics.accumulated_cost is not None and self.max_budget_per_task is not None @@ -264,12 +294,7 @@ class AgentController: ): if self.state.metrics.accumulated_cost >= self.max_budget_per_task: self.max_budget_per_task += self._initial_max_budget_per_task - - self.state.agent_state = new_state - if new_state == AgentState.STOPPED or new_state == AgentState.ERROR: - self.reset_task() - - if self._pending_action is not None and ( + elif self._pending_action is not None and ( new_state == AgentState.USER_CONFIRMED or new_state == AgentState.USER_REJECTED ): @@ -281,6 +306,7 @@ class AgentController: self._pending_action.is_confirmed = ActionConfirmationStatus.REJECTED # type: ignore[attr-defined] self.event_stream.add_event(self._pending_action, EventSource.AGENT) + self.state.agent_state = new_state self.event_stream.add_event( AgentStateChangedObservation('', self.state.agent_state), EventSource.AGENT ) @@ -355,56 +381,8 @@ class AgentController: return if self.delegate is not None: - logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...') assert self.delegate != self - await self.delegate._step() - logger.debug(f'[Agent Controller {self.id}] Delegate step done') - assert self.delegate is not None - delegate_state = self.delegate.get_agent_state() - logger.debug( - f'[Agent Controller {self.id}] Delegate state: {delegate_state}' - ) - if delegate_state == AgentState.ERROR: - # update iteration that shall be shared across agents - self.state.iteration = self.delegate.state.iteration - - # close the delegate upon error - await self.delegate.close() - self.delegate = None - self.delegateAction = None - - await self.report_error('Delegator agent encounters an error') - return - delegate_done = delegate_state in (AgentState.FINISHED, AgentState.REJECTED) - if delegate_done: - logger.info( - f'[Agent Controller {self.id}] Delegate agent has finished execution' - ) - # retrieve delegate result - outputs = self.delegate.state.outputs if self.delegate.state else {} - - # update iteration that shall be shared across agents - self.state.iteration = self.delegate.state.iteration - - # close delegate controller: we must close the delegate controller before adding new events - await self.delegate.close() - - # update delegate result observation - # TODO: replace this with AI-generated summary (#2395) - formatted_output = ', '.join( - f'{key}: {value}' for key, value in outputs.items() - ) - content = ( - f'{self.delegate.agent.name} finishes task with {formatted_output}' - ) - obs: Observation = AgentDelegateObservation( - outputs=outputs, content=content - ) - - # clean up delegate status - self.delegate = None - self.delegateAction = None - self.event_stream.add_event(obs, EventSource.AGENT) + await self._delegate_step() return logger.info( @@ -412,50 +390,20 @@ class AgentController: extra={'msg_type': 'STEP'}, ) + # check if agent hit the resources limit + stop_step = False if self.state.iteration >= self.state.max_iterations: - if self.state.traffic_control_state == TrafficControlState.PAUSED: - logger.info( - 'Hitting traffic control, temporarily resume upon user request' - ) - self.state.traffic_control_state = TrafficControlState.NORMAL - else: - self.state.traffic_control_state = TrafficControlState.THROTTLING - if self.headless_mode: - # set to ERROR state if running in headless mode - # since user cannot resume on the web interface - await self.report_error( - 'Agent reached maximum number of iterations in headless mode, task stopped.' - ) - await self.set_agent_state_to(AgentState.ERROR) - else: - await self.report_error( - f'Agent reached maximum number of iterations, task paused. {TRAFFIC_CONTROL_REMINDER}' - ) - await self.set_agent_state_to(AgentState.PAUSED) - return - elif self.max_budget_per_task is not None: + stop_step = await self._handle_traffic_control( + 'iteration', self.state.iteration, self.state.max_iterations + ) + if self.max_budget_per_task is not None: current_cost = self.state.metrics.accumulated_cost if current_cost > self.max_budget_per_task: - if self.state.traffic_control_state == TrafficControlState.PAUSED: - logger.info( - 'Hitting traffic control, temporarily resume upon user request' - ) - self.state.traffic_control_state = TrafficControlState.NORMAL - else: - self.state.traffic_control_state = TrafficControlState.THROTTLING - if self.headless_mode: - # set to ERROR state if running in headless mode - # there is no way to resume - await self.report_error( - f'Task budget exceeded. Current cost: {current_cost:.2f}, max budget: {self.max_budget_per_task:.2f}, task stopped.' - ) - await self.set_agent_state_to(AgentState.ERROR) - else: - await self.report_error( - f'Task budget exceeded. Current cost: {current_cost:.2f}, Max budget: {self.max_budget_per_task:.2f}, task paused. {TRAFFIC_CONTROL_REMINDER}' - ) - await self.set_agent_state_to(AgentState.PAUSED) - return + stop_step = await self._handle_traffic_control( + 'budget', current_cost, self.max_budget_per_task + ) + if stop_step: + return self.update_state_before_step() action: Action = NullAction() @@ -492,6 +440,89 @@ class AgentController: await self.report_error('Agent got stuck in a loop') await self.set_agent_state_to(AgentState.ERROR) + async def _delegate_step(self): + """Executes a single step of the delegate agent.""" + logger.debug(f'[Agent Controller {self.id}] Delegate not none, awaiting...') + await self.delegate._step() # type: ignore[union-attr] + logger.debug(f'[Agent Controller {self.id}] Delegate step done') + assert self.delegate is not None + delegate_state = self.delegate.get_agent_state() + logger.debug(f'[Agent Controller {self.id}] Delegate state: {delegate_state}') + if delegate_state == AgentState.ERROR: + # update iteration that shall be shared across agents + self.state.iteration = self.delegate.state.iteration + + # close the delegate upon error + await self.delegate.close() + self.delegate = None + self.delegateAction = None + + await self.report_error('Delegator agent encounters an error') + elif delegate_state in (AgentState.FINISHED, AgentState.REJECTED): + logger.info( + f'[Agent Controller {self.id}] Delegate agent has finished execution' + ) + # retrieve delegate result + outputs = self.delegate.state.outputs if self.delegate.state else {} + + # update iteration that shall be shared across agents + self.state.iteration = self.delegate.state.iteration + + # close delegate controller: we must close the delegate controller before adding new events + await self.delegate.close() + + # update delegate result observation + # TODO: replace this with AI-generated summary (#2395) + formatted_output = ', '.join( + f'{key}: {value}' for key, value in outputs.items() + ) + content = ( + f'{self.delegate.agent.name} finishes task with {formatted_output}' + ) + obs: Observation = AgentDelegateObservation( + outputs=outputs, content=content + ) + + # clean up delegate status + self.delegate = None + self.delegateAction = None + self.event_stream.add_event(obs, EventSource.AGENT) + return + + async def _handle_traffic_control( + self, limit_type: str, current_value: float, max_value: float + ): + """Handles agent state after hitting the traffic control limit. + + Args: + limit_type (str): The type of limit that was hit. + current_value (float): The current value of the limit. + max_value (float): The maximum value of the limit. + """ + stop_step = False + if self.state.traffic_control_state == TrafficControlState.PAUSED: + logger.info('Hitting traffic control, temporarily resume upon user request') + self.state.traffic_control_state = TrafficControlState.NORMAL + else: + self.state.traffic_control_state = TrafficControlState.THROTTLING + if self.headless_mode: + # set to ERROR state if running in headless mode + # since user cannot resume on the web interface + await self.report_error( + f'Agent reached maximum {limit_type} in headless mode, task stopped. ' + f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}' + ) + await self.set_agent_state_to(AgentState.ERROR) + else: + await self.report_error( + f'Agent reached maximum {limit_type}, task paused. ' + f'Current {limit_type}: {current_value:.2f}, max {limit_type}: {max_value:.2f}. ' + f'{TRAFFIC_CONTROL_REMINDER}' + ) + await self.set_agent_state_to(AgentState.PAUSED) + stop_step = True + return stop_step + def get_state(self): """Returns the current running state object. diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index 20fc9292f3..fbaa225c90 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -1,5 +1,5 @@ import asyncio -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, Mock import pytest @@ -123,6 +123,55 @@ async def test_step_with_exception(mock_agent, mock_event_stream): await controller.close() +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'delegate_state', + [ + AgentState.RUNNING, + AgentState.FINISHED, + AgentState.ERROR, + AgentState.REJECTED, + ], +) +async def test_delegate_step_different_states( + mock_agent, mock_event_stream, delegate_state +): + controller = AgentController( + agent=mock_agent, + event_stream=mock_event_stream, + max_iterations=10, + sid='test', + confirmation_mode=False, + headless_mode=True, + ) + + mock_delegate = AsyncMock() + controller.delegate = mock_delegate + + mock_delegate.state.iteration = 5 + mock_delegate.state.outputs = {'result': 'test'} + mock_delegate.agent.name = 'TestDelegate' + + mock_delegate.get_agent_state = Mock(return_value=delegate_state) + mock_delegate._step = AsyncMock() + mock_delegate.close = AsyncMock() + + await controller._delegate_step() + + mock_delegate._step.assert_called_once() + + if delegate_state == AgentState.RUNNING: + assert controller.delegate is not None + assert controller.state.iteration == 0 + mock_delegate.close.assert_not_called() + else: + assert controller.delegate is None + assert controller.state.iteration == 5 + mock_delegate.close.assert_called_once() + + await controller.close() + + @pytest.mark.asyncio async def test_step_max_iterations(mock_agent, mock_event_stream): controller = AgentController(