From aabbbb6c6a4e43470ce60401b96038f5ddb533cc Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Tue, 7 Jan 2025 23:22:43 +0100 Subject: [PATCH] Fix duplicate state initialization (#6089) --- openhands/controller/agent_controller.py | 10 +- openhands/server/session/agent_session.py | 26 ++- tests/unit/test_agent_session.py | 186 ++++++++++++++++++++++ 3 files changed, 204 insertions(+), 18 deletions(-) create mode 100644 tests/unit/test_agent_session.py diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index f218716acd..7899c2dcfa 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -993,10 +993,12 @@ class AgentController: def __repr__(self): return ( - f'AgentController(id={self.id}, agent={self.agent!r}, ' - f'event_stream={self.event_stream!r}, ' - f'state={self.state!r}, ' - f'delegate={self.delegate!r}, _pending_action={self._pending_action!r})' + f'AgentController(id={getattr(self, "id", "")}, ' + f'agent={getattr(self, "agent", "")!r}, ' + f'event_stream={getattr(self, "event_stream", "")!r}, ' + f'state={getattr(self, "state", "")!r}, ' + f'delegate={getattr(self, "delegate", "")!r}, ' + f'_pending_action={getattr(self, "_pending_action", "")!r})' ) def _is_awaiting_observation(self): diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 3a0c96804c..b8a440d32f 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -274,27 +274,25 @@ class AgentSession: confirmation_mode=confirmation_mode, headless_mode=False, status_callback=self._status_callback, + initial_state=self._maybe_restore_state(), ) - # Note: We now attempt to restore the state from session here, - # but if it fails, we fall back to None and still initialize the controller - # with a fresh state. That way, the controller will always load events from the event stream - # even if the state file was corrupt. + return controller + def _maybe_restore_state(self) -> State | None: + """Helper method to handle state restore logic.""" restored_state = None + + # Attempt to restore the state from session. + # Use a heuristic to figure out if we should have a state: + # if we have events in the stream. try: restored_state = State.restore_from_session(self.sid, self.file_store) + logger.debug(f'Restored state from session, sid: {self.sid}') except Exception as e: if self.event_stream.get_latest_event_id() > 0: # if we have events, we should have a state logger.warning(f'State could not be restored: {e}') - - # Set the initial state through the controller. - controller.set_initial_state(restored_state, max_iterations, confirmation_mode) - if restored_state: - logger.debug(f'Restored agent state from session, sid: {self.sid}') - else: - logger.debug('New session state created.') - - logger.debug('Agent controller initialized.') - return controller + else: + logger.debug('No events found, no state to restore') + return restored_state diff --git a/tests/unit/test_agent_session.py b/tests/unit/test_agent_session.py new file mode 100644 index 0000000000..90fca71e39 --- /dev/null +++ b/tests/unit/test_agent_session.py @@ -0,0 +1,186 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from openhands.controller.agent import Agent +from openhands.controller.agent_controller import AgentController +from openhands.controller.state.state import State +from openhands.core.config import AppConfig, LLMConfig +from openhands.events import EventStream, EventStreamSubscriber +from openhands.llm import LLM +from openhands.llm.metrics import Metrics +from openhands.runtime.base import Runtime +from openhands.server.session.agent_session import AgentSession +from openhands.storage.memory import InMemoryFileStore + + +@pytest.fixture +def mock_agent(): + """Create a properly configured mock agent with all required nested attributes""" + # Create the base mocks + agent = MagicMock(spec=Agent) + llm = MagicMock(spec=LLM) + metrics = MagicMock(spec=Metrics) + llm_config = MagicMock(spec=LLMConfig) + + # Configure the LLM config + llm_config.model = 'test-model' + llm_config.base_url = 'http://test' + llm_config.draft_editor = None + llm_config.max_message_chars = 1000 + + # Set up the chain of mocks + llm.metrics = metrics + llm.config = llm_config + agent.llm = llm + agent.name = 'test-agent' + agent.sandbox_plugins = [] + + return agent + + +@pytest.mark.asyncio +async def test_agent_session_start_with_no_state(mock_agent): + """Test that AgentSession.start() works correctly when there's no state to restore""" + + # Setup + file_store = InMemoryFileStore({}) + session = AgentSession(sid='test-session', file_store=file_store) + + # Create a mock runtime and set it up + mock_runtime = MagicMock(spec=Runtime) + + # Mock the runtime creation to set up the runtime attribute + async def mock_create_runtime(*args, **kwargs): + session.runtime = mock_runtime + + session._create_runtime = AsyncMock(side_effect=mock_create_runtime) + + # Create a mock EventStream with no events + mock_event_stream = MagicMock(spec=EventStream) + mock_event_stream.get_events.return_value = [] + mock_event_stream.subscribe = MagicMock() + mock_event_stream.get_latest_event_id.return_value = 0 + + # Inject the mock event stream into the session + session.event_stream = mock_event_stream + + # Create a spy on set_initial_state + class SpyAgentController(AgentController): + set_initial_state_call_count = 0 + test_initial_state = None + + def set_initial_state(self, *args, state=None, **kwargs): + self.set_initial_state_call_count += 1 + self.test_initial_state = state + super().set_initial_state(*args, state=state, **kwargs) + + # Patch AgentController and State.restore_from_session to fail + with patch( + 'openhands.server.session.agent_session.AgentController', SpyAgentController + ), patch( + 'openhands.server.session.agent_session.EventStream', + return_value=mock_event_stream, + ), patch( + 'openhands.controller.state.state.State.restore_from_session', + side_effect=Exception('No state found'), + ): + await session.start( + runtime_name='test-runtime', + config=AppConfig(), + agent=mock_agent, + max_iterations=10, + ) + + # Verify EventStream.subscribe was called with correct parameters + mock_event_stream.subscribe.assert_called_with( + EventStreamSubscriber.AGENT_CONTROLLER, + session.controller.on_event, + session.controller.id, + ) + + # Verify set_initial_state was called once with None as state + assert session.controller.set_initial_state_call_count == 1 + assert session.controller.test_initial_state is None + assert session.controller.state.max_iterations == 10 + assert session.controller.agent.name == 'test-agent' + assert session.controller.state.start_id == 0 + assert session.controller.state.end_id == -1 + assert session.controller.state.truncation_id == -1 + + +@pytest.mark.asyncio +async def test_agent_session_start_with_restored_state(mock_agent): + """Test that AgentSession.start() works correctly when there's a state to restore""" + + # Setup + file_store = InMemoryFileStore({}) + session = AgentSession(sid='test-session', file_store=file_store) + + # Create a mock runtime and set it up + mock_runtime = MagicMock(spec=Runtime) + + # Mock the runtime creation to set up the runtime attribute + async def mock_create_runtime(*args, **kwargs): + session.runtime = mock_runtime + + session._create_runtime = AsyncMock(side_effect=mock_create_runtime) + + # Create a mock EventStream with some events + mock_event_stream = MagicMock(spec=EventStream) + mock_event_stream.get_events.return_value = [] + mock_event_stream.subscribe = MagicMock() + mock_event_stream.get_latest_event_id.return_value = 5 # Indicate some events exist + + # Inject the mock event stream into the session + session.event_stream = mock_event_stream + + # Create a mock restored state + mock_restored_state = MagicMock(spec=State) + mock_restored_state.start_id = -1 + mock_restored_state.end_id = -1 + mock_restored_state.truncation_id = -1 + mock_restored_state.max_iterations = 5 + + # Create a spy on set_initial_state by subclassing AgentController + class SpyAgentController(AgentController): + set_initial_state_call_count = 0 + test_initial_state = None + + def set_initial_state(self, *args, state=None, **kwargs): + self.set_initial_state_call_count += 1 + self.test_initial_state = state + super().set_initial_state(*args, state=state, **kwargs) + + # Patch AgentController and State.restore_from_session to succeed + with patch( + 'openhands.server.session.agent_session.AgentController', SpyAgentController + ), patch( + 'openhands.server.session.agent_session.EventStream', + return_value=mock_event_stream, + ), patch( + 'openhands.controller.state.state.State.restore_from_session', + return_value=mock_restored_state, + ): + await session.start( + runtime_name='test-runtime', + config=AppConfig(), + agent=mock_agent, + max_iterations=10, + ) + + # Verify set_initial_state was called once with the restored state + assert session.controller.set_initial_state_call_count == 1 + + # Verify EventStream.subscribe was called with correct parameters + mock_event_stream.subscribe.assert_called_with( + EventStreamSubscriber.AGENT_CONTROLLER, + session.controller.on_event, + session.controller.id, + ) + assert session.controller.test_initial_state is mock_restored_state + assert session.controller.state is mock_restored_state + assert session.controller.state.max_iterations == 5 + assert session.controller.state.start_id == 0 + assert session.controller.state.end_id == -1 + assert session.controller.state.truncation_id == -1