mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
Add tests for agent controller (#3357)
* Add tests for agent controller * Remove dead code * Remove dead code
This commit is contained in:
194
tests/unit/test_agent_controller.py
Normal file
194
tests/unit/test_agent_controller.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from opendevin.controller.agent import Agent
|
||||
from opendevin.controller.agent_controller import AgentController
|
||||
from opendevin.controller.state.state import TrafficControlState
|
||||
from opendevin.core.exceptions import LLMMalformedActionError
|
||||
from opendevin.core.schema import AgentState
|
||||
from opendevin.events import EventStream
|
||||
from opendevin.events.action import ChangeAgentStateAction, MessageAction
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
|
||||
return str(tmp_path_factory.mktemp('test_event_stream'))
|
||||
|
||||
|
||||
@pytest.fixture(scope='function')
|
||||
def event_loop():
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
return MagicMock(spec=Agent)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_stream():
|
||||
return MagicMock(spec=EventStream)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_agent_state(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
await controller.set_agent_state_to(AgentState.RUNNING)
|
||||
assert controller.get_agent_state() == AgentState.RUNNING
|
||||
|
||||
await controller.set_agent_state_to(AgentState.PAUSED)
|
||||
assert controller.get_agent_state() == AgentState.PAUSED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_event_message_action(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
message_action = MessageAction(content='Test message')
|
||||
await controller.on_event(message_action)
|
||||
assert controller.get_agent_state() == AgentState.RUNNING
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_event_change_agent_state_action(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
change_state_action = ChangeAgentStateAction(agent_state=AgentState.PAUSED)
|
||||
await controller.on_event(change_state_action)
|
||||
assert controller.get_agent_state() == AgentState.PAUSED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_error(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
error_message = 'Test error'
|
||||
await controller.report_error(error_message)
|
||||
assert controller.state.last_error == error_message
|
||||
controller.event_stream.add_event.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_with_exception(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
controller.report_error = AsyncMock()
|
||||
controller.agent.step.side_effect = LLMMalformedActionError('Malformed action')
|
||||
await controller._step()
|
||||
|
||||
# Verify that report_error was called with the correct error message
|
||||
controller.report_error.assert_called_once_with('Malformed action')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_max_iterations(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=False,
|
||||
)
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
controller.state.iteration = 10
|
||||
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
||||
await controller._step()
|
||||
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
assert controller.state.agent_state == AgentState.PAUSED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_max_iterations_headless(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
controller.state.iteration = 10
|
||||
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
||||
await controller._step()
|
||||
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
# In headless mode, throttling results in an error
|
||||
assert controller.state.agent_state == AgentState.ERROR
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_max_budget(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
max_budget_per_task=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=False,
|
||||
)
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
controller.state.metrics.accumulated_cost = 10.1
|
||||
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
||||
await controller._step()
|
||||
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
assert controller.state.agent_state == AgentState.PAUSED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_step_max_budget_headless(mock_agent, mock_event_stream):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
max_budget_per_task=10,
|
||||
sid='test',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
controller.state.agent_state = AgentState.RUNNING
|
||||
controller.state.metrics.accumulated_cost = 10.1
|
||||
assert controller.state.traffic_control_state == TrafficControlState.NORMAL
|
||||
await controller._step()
|
||||
assert controller.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
# In headless mode, throttling results in an error
|
||||
assert controller.state.agent_state == AgentState.ERROR
|
||||
Reference in New Issue
Block a user