mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-08 22:38:05 -05:00
Fix duplicate state initialization (#6089)
This commit is contained in:
@@ -993,10 +993,12 @@ class AgentController:
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f'AgentController(id={self.id}, agent={self.agent!r}, '
|
||||
f'event_stream={self.event_stream!r}, '
|
||||
f'state={self.state!r}, '
|
||||
f'delegate={self.delegate!r}, _pending_action={self._pending_action!r})'
|
||||
f'AgentController(id={getattr(self, "id", "<uninitialized>")}, '
|
||||
f'agent={getattr(self, "agent", "<uninitialized>")!r}, '
|
||||
f'event_stream={getattr(self, "event_stream", "<uninitialized>")!r}, '
|
||||
f'state={getattr(self, "state", "<uninitialized>")!r}, '
|
||||
f'delegate={getattr(self, "delegate", "<uninitialized>")!r}, '
|
||||
f'_pending_action={getattr(self, "_pending_action", "<uninitialized>")!r})'
|
||||
)
|
||||
|
||||
def _is_awaiting_observation(self):
|
||||
|
||||
@@ -274,27 +274,25 @@ class AgentSession:
|
||||
confirmation_mode=confirmation_mode,
|
||||
headless_mode=False,
|
||||
status_callback=self._status_callback,
|
||||
initial_state=self._maybe_restore_state(),
|
||||
)
|
||||
|
||||
# Note: We now attempt to restore the state from session here,
|
||||
# but if it fails, we fall back to None and still initialize the controller
|
||||
# with a fresh state. That way, the controller will always load events from the event stream
|
||||
# even if the state file was corrupt.
|
||||
return controller
|
||||
|
||||
def _maybe_restore_state(self) -> State | None:
|
||||
"""Helper method to handle state restore logic."""
|
||||
restored_state = None
|
||||
|
||||
# Attempt to restore the state from session.
|
||||
# Use a heuristic to figure out if we should have a state:
|
||||
# if we have events in the stream.
|
||||
try:
|
||||
restored_state = State.restore_from_session(self.sid, self.file_store)
|
||||
logger.debug(f'Restored state from session, sid: {self.sid}')
|
||||
except Exception as e:
|
||||
if self.event_stream.get_latest_event_id() > 0:
|
||||
# if we have events, we should have a state
|
||||
logger.warning(f'State could not be restored: {e}')
|
||||
|
||||
# Set the initial state through the controller.
|
||||
controller.set_initial_state(restored_state, max_iterations, confirmation_mode)
|
||||
if restored_state:
|
||||
logger.debug(f'Restored agent state from session, sid: {self.sid}')
|
||||
else:
|
||||
logger.debug('New session state created.')
|
||||
|
||||
logger.debug('Agent controller initialized.')
|
||||
return controller
|
||||
else:
|
||||
logger.debug('No events found, no state to restore')
|
||||
return restored_state
|
||||
|
||||
186
tests/unit/test_agent_session.py
Normal file
186
tests/unit/test_agent_session.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AppConfig, LLMConfig
|
||||
from openhands.events import EventStream, EventStreamSubscriber
|
||||
from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.server.session.agent_session import AgentSession
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
"""Create a properly configured mock agent with all required nested attributes"""
|
||||
# Create the base mocks
|
||||
agent = MagicMock(spec=Agent)
|
||||
llm = MagicMock(spec=LLM)
|
||||
metrics = MagicMock(spec=Metrics)
|
||||
llm_config = MagicMock(spec=LLMConfig)
|
||||
|
||||
# Configure the LLM config
|
||||
llm_config.model = 'test-model'
|
||||
llm_config.base_url = 'http://test'
|
||||
llm_config.draft_editor = None
|
||||
llm_config.max_message_chars = 1000
|
||||
|
||||
# Set up the chain of mocks
|
||||
llm.metrics = metrics
|
||||
llm.config = llm_config
|
||||
agent.llm = llm
|
||||
agent.name = 'test-agent'
|
||||
agent.sandbox_plugins = []
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_session_start_with_no_state(mock_agent):
|
||||
"""Test that AgentSession.start() works correctly when there's no state to restore"""
|
||||
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(sid='test-session', file_store=file_store)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
mock_runtime = MagicMock(spec=Runtime)
|
||||
|
||||
# Mock the runtime creation to set up the runtime attribute
|
||||
async def mock_create_runtime(*args, **kwargs):
|
||||
session.runtime = mock_runtime
|
||||
|
||||
session._create_runtime = AsyncMock(side_effect=mock_create_runtime)
|
||||
|
||||
# Create a mock EventStream with no events
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
mock_event_stream.get_events.return_value = []
|
||||
mock_event_stream.subscribe = MagicMock()
|
||||
mock_event_stream.get_latest_event_id.return_value = 0
|
||||
|
||||
# Inject the mock event stream into the session
|
||||
session.event_stream = mock_event_stream
|
||||
|
||||
# Create a spy on set_initial_state
|
||||
class SpyAgentController(AgentController):
|
||||
set_initial_state_call_count = 0
|
||||
test_initial_state = None
|
||||
|
||||
def set_initial_state(self, *args, state=None, **kwargs):
|
||||
self.set_initial_state_call_count += 1
|
||||
self.test_initial_state = state
|
||||
super().set_initial_state(*args, state=state, **kwargs)
|
||||
|
||||
# Patch AgentController and State.restore_from_session to fail
|
||||
with patch(
|
||||
'openhands.server.session.agent_session.AgentController', SpyAgentController
|
||||
), patch(
|
||||
'openhands.server.session.agent_session.EventStream',
|
||||
return_value=mock_event_stream,
|
||||
), patch(
|
||||
'openhands.controller.state.state.State.restore_from_session',
|
||||
side_effect=Exception('No state found'),
|
||||
):
|
||||
await session.start(
|
||||
runtime_name='test-runtime',
|
||||
config=AppConfig(),
|
||||
agent=mock_agent,
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
# Verify EventStream.subscribe was called with correct parameters
|
||||
mock_event_stream.subscribe.assert_called_with(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER,
|
||||
session.controller.on_event,
|
||||
session.controller.id,
|
||||
)
|
||||
|
||||
# Verify set_initial_state was called once with None as state
|
||||
assert session.controller.set_initial_state_call_count == 1
|
||||
assert session.controller.test_initial_state is None
|
||||
assert session.controller.state.max_iterations == 10
|
||||
assert session.controller.agent.name == 'test-agent'
|
||||
assert session.controller.state.start_id == 0
|
||||
assert session.controller.state.end_id == -1
|
||||
assert session.controller.state.truncation_id == -1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_session_start_with_restored_state(mock_agent):
|
||||
"""Test that AgentSession.start() works correctly when there's a state to restore"""
|
||||
|
||||
# Setup
|
||||
file_store = InMemoryFileStore({})
|
||||
session = AgentSession(sid='test-session', file_store=file_store)
|
||||
|
||||
# Create a mock runtime and set it up
|
||||
mock_runtime = MagicMock(spec=Runtime)
|
||||
|
||||
# Mock the runtime creation to set up the runtime attribute
|
||||
async def mock_create_runtime(*args, **kwargs):
|
||||
session.runtime = mock_runtime
|
||||
|
||||
session._create_runtime = AsyncMock(side_effect=mock_create_runtime)
|
||||
|
||||
# Create a mock EventStream with some events
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
mock_event_stream.get_events.return_value = []
|
||||
mock_event_stream.subscribe = MagicMock()
|
||||
mock_event_stream.get_latest_event_id.return_value = 5 # Indicate some events exist
|
||||
|
||||
# Inject the mock event stream into the session
|
||||
session.event_stream = mock_event_stream
|
||||
|
||||
# Create a mock restored state
|
||||
mock_restored_state = MagicMock(spec=State)
|
||||
mock_restored_state.start_id = -1
|
||||
mock_restored_state.end_id = -1
|
||||
mock_restored_state.truncation_id = -1
|
||||
mock_restored_state.max_iterations = 5
|
||||
|
||||
# Create a spy on set_initial_state by subclassing AgentController
|
||||
class SpyAgentController(AgentController):
|
||||
set_initial_state_call_count = 0
|
||||
test_initial_state = None
|
||||
|
||||
def set_initial_state(self, *args, state=None, **kwargs):
|
||||
self.set_initial_state_call_count += 1
|
||||
self.test_initial_state = state
|
||||
super().set_initial_state(*args, state=state, **kwargs)
|
||||
|
||||
# Patch AgentController and State.restore_from_session to succeed
|
||||
with patch(
|
||||
'openhands.server.session.agent_session.AgentController', SpyAgentController
|
||||
), patch(
|
||||
'openhands.server.session.agent_session.EventStream',
|
||||
return_value=mock_event_stream,
|
||||
), patch(
|
||||
'openhands.controller.state.state.State.restore_from_session',
|
||||
return_value=mock_restored_state,
|
||||
):
|
||||
await session.start(
|
||||
runtime_name='test-runtime',
|
||||
config=AppConfig(),
|
||||
agent=mock_agent,
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
# Verify set_initial_state was called once with the restored state
|
||||
assert session.controller.set_initial_state_call_count == 1
|
||||
|
||||
# Verify EventStream.subscribe was called with correct parameters
|
||||
mock_event_stream.subscribe.assert_called_with(
|
||||
EventStreamSubscriber.AGENT_CONTROLLER,
|
||||
session.controller.on_event,
|
||||
session.controller.id,
|
||||
)
|
||||
assert session.controller.test_initial_state is mock_restored_state
|
||||
assert session.controller.state is mock_restored_state
|
||||
assert session.controller.state.max_iterations == 5
|
||||
assert session.controller.state.start_id == 0
|
||||
assert session.controller.state.end_id == -1
|
||||
assert session.controller.state.truncation_id == -1
|
||||
Reference in New Issue
Block a user