diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 4abecefcb4..3ff44fc5a3 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -72,6 +72,7 @@ from openhands.events.observation import ( from openhands.events.serialization.event import event_to_trajectory, truncate_content from openhands.llm.llm import LLM from openhands.llm.metrics import Metrics, TokenUsage +from openhands.memory.view import View # note: RESUME is only available on web GUI TRAFFIC_CONTROL_REMINDER = ( @@ -1161,7 +1162,8 @@ class AgentController: def _handle_long_context_error(self) -> None: # When context window is exceeded, keep roughly half of agent interactions - kept_events = self._apply_conversation_window() + current_view = View.from_events(self.state.history) + kept_events = self._apply_conversation_window(current_view.events) kept_event_ids = {e.id for e in kept_events} self.log( @@ -1198,7 +1200,7 @@ class AgentController: EventSource.AGENT, ) - def _apply_conversation_window(self) -> list[Event]: + def _apply_conversation_window(self, history: list[Event]) -> list[Event]: """Cuts history roughly in half when context window is exceeded. It preserves action-observation pairs and ensures that the system message, @@ -1217,11 +1219,9 @@ class AgentController: Returns: Filtered list of events keeping newest half while preserving pairs and essential initial events. """ - if not self.state.history: + # Handle empty history + if not history: return [] - - history = self.state.history - # 1. 
Identify essential initial events system_message: SystemMessageAction | None = None first_user_msg: MessageAction | None = None @@ -1238,50 +1238,59 @@ class AgentController: and system_message.id == history[0].id ) - # Find First User Message, which MUST exist - first_user_msg = self._first_user_message() + # Find the first user message in the given history; it MUST exist there or in the event stream + first_user_msg = self._first_user_message(history) if first_user_msg is None: - raise RuntimeError('No first user message found in the event stream.') + # If not found in history, try the event stream + first_user_msg = self._first_user_message() + if first_user_msg is None: + raise RuntimeError('No first user message found in the event stream.') + self.log( + 'warning', + 'First user message not found in history. Using cached version from event stream.', + ) + # Find the first user message index in the history first_user_msg_index = -1 for i, event in enumerate(history): if isinstance(event, MessageAction) and event.source == EventSource.USER: - first_user_msg = event first_user_msg_index = i break # Find Recall Action and Observation related to the First User Message - if first_user_msg is not None and first_user_msg_index != -1: - # Look for RecallAction after the first user message - for i in range(first_user_msg_index + 1, len(history)): - event = history[i] - if ( - isinstance(event, RecallAction) - and event.query == first_user_msg.content - ): - # Found RecallAction, now look for its Observation - recall_action = event - for j in range(i + 1, len(history)): - obs_event = history[j] - # Check for Observation caused by this RecallAction - if ( - isinstance(obs_event, Observation) - and obs_event.cause == recall_action.id - ): - recall_observation = obs_event - break # Found the observation, stop inner loop - break # Found the recall action (and maybe obs), stop outer loop + # Look for RecallAction after the first user message + for i in range(first_user_msg_index + 1, len(history)): + event = 
history[i] + if ( + isinstance(event, RecallAction) + and event.query == first_user_msg.content + ): + # Found RecallAction, now look for its Observation + recall_action = event + for j in range(i + 1, len(history)): + obs_event = history[j] + # Check for Observation caused by this RecallAction + if ( + isinstance(obs_event, Observation) + and obs_event.cause == recall_action.id + ): + recall_observation = obs_event + break # Found the observation, stop inner loop + break # Found the recall action (and maybe obs), stop outer loop essential_events: list[Event] = [] if system_message: essential_events.append(system_message) - if first_user_msg: + # Only include first user message if history is not empty + if history: essential_events.append(first_user_msg) - # Also keep the RecallAction that triggered the essential RecallObservation - if recall_action: - essential_events.append(recall_action) - if recall_observation: - essential_events.append(recall_observation) + # Include recall action and observation if both exist + if recall_action and recall_observation: + essential_events.append(recall_action) + essential_events.append(recall_observation) + # Include recall action without observation for backward compatibility + elif recall_action: + essential_events.append(recall_action) # 2. Determine the slice of recent events to potentially keep num_non_essential_events = len(history) - len(essential_events) @@ -1430,15 +1439,32 @@ class AgentController: return result return False - def _first_user_message(self) -> MessageAction | None: + def _first_user_message( + self, events: list[Event] | None = None + ) -> MessageAction | None: """Get the first user message for this agent. For regular agents, this is the first user message from the beginning (start_id=0). For delegate agents, this is the first user message after the delegate's start_id. + Args: + events: Optional list of events to search through. If None, uses the event stream. 
+ Returns: MessageAction | None: The first user message, or None if no user message found """ + # If events list is provided, search through it + if events is not None: + return next( + ( + e + for e in events + if isinstance(e, MessageAction) and e.source == EventSource.USER + ), + None, + ) + + # Otherwise, use the original event stream logic with caching # Return cached message if any if self._cached_first_user_message is not None: return self._cached_first_user_message diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index 76f58c012c..c04fcf10d6 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -156,7 +156,7 @@ class Session: condensers=[ BrowserOutputCondenserConfig(attention_window=2), LLMSummarizingCondenserConfig( - llm_config=llm.config, keep_first=4, max_size=140 + llm_config=llm.config, keep_first=4, max_size=120 ), ] ) diff --git a/tests/unit/test_agent_history.py b/tests/unit/test_agent_history.py index d85280d422..5bbab8b91c 100644 --- a/tests/unit/test_agent_history.py +++ b/tests/unit/test_agent_history.py @@ -107,11 +107,8 @@ def controller_fixture(): ) controller.state = State(session_id='test_sid') - # Mock _first_user_message directly on the instance - mock_first_user_message = MagicMock(spec=MessageAction) - controller._first_user_message = MagicMock(return_value=mock_first_user_message) - - return controller, mock_first_user_message + # Don't mock _first_user_message anymore since we need it to work with history + return controller # ============================================= @@ -120,7 +117,7 @@ def controller_fixture(): def test_basic_truncation(controller_fixture): - controller, mock_first_user_message = controller_fixture + controller = controller_fixture controller.state.history = create_events( [ @@ -155,7 +152,6 @@ def test_basic_truncation(controller_fixture): }, # 10 ] ) - mock_first_user_message.id = 2 # Set the ID of the mocked first user message # 
Calculation (RecallAction now essential): # History len = 10 @@ -167,7 +163,7 @@ def test_basic_truncation(controller_fixture): # Validation: remove leading obs2(8). validated_slice = [cmd3(9), obs3(10)] # Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4), cmd3(9), obs3(10)] # Expected IDs: [1, 2, 3, 4, 9, 10]. Length 6. - truncated_events = controller._apply_conversation_window() + truncated_events = controller._apply_conversation_window(controller.state.history) assert len(truncated_events) == 6 expected_ids = [1, 2, 3, 4, 9, 10] @@ -179,7 +175,7 @@ def test_basic_truncation(controller_fixture): def test_no_system_message(controller_fixture): - controller, mock_first_user_message = controller_fixture + controller = controller_fixture controller.state.history = create_events( [ @@ -213,7 +209,7 @@ def test_no_system_message(controller_fixture): }, # 9 ] ) - mock_first_user_message.id = 1 + # No longer need to set mock ID # Calculation (RecallAction now essential): # History len = 9 @@ -225,7 +221,7 @@ def test_no_system_message(controller_fixture): # Validation: remove leading obs2(7). validated_slice = [cmd3(8), obs3(9)] # Final = essentials + validated_slice = [user(1), recall_act(2), recall_obs(3), cmd3(8), obs3(9)] # Expected IDs: [1, 2, 3, 8, 9]. Length 5. 
- truncated_events = controller._apply_conversation_window() + truncated_events = controller._apply_conversation_window(controller.state.history) assert len(truncated_events) == 5 expected_ids = [1, 2, 3, 8, 9] @@ -234,7 +230,7 @@ def test_no_system_message(controller_fixture): def test_no_recall_observation(controller_fixture): - controller, mock_first_user_message = controller_fixture + controller = controller_fixture controller.state.history = create_events( [ @@ -269,7 +265,6 @@ def test_no_recall_observation(controller_fixture): }, # 9 ] ) - mock_first_user_message.id = 2 # Calculation (RecallAction essential only if RecallObs exists): # History len = 9 @@ -281,7 +276,7 @@ def test_no_recall_observation(controller_fixture): # Validation: remove leading obs2(7). validated_slice = [cmd3(8), obs3(9)] # Final = essentials + validated_slice = [sys(1), user(2), recall_action(3), cmd_cat(8), obs_cat(9)] # Expected IDs: [1, 2, 3, 8, 9]. Length 5. - truncated_events = controller._apply_conversation_window() + truncated_events = controller._apply_conversation_window(controller.state.history) assert len(truncated_events) == 5 expected_ids = [1, 2, 3, 8, 9] @@ -290,7 +285,7 @@ def test_no_recall_observation(controller_fixture): def test_short_history_no_truncation(controller_fixture): - controller, mock_first_user_message = controller_fixture + controller = controller_fixture history = create_events( [ @@ -312,7 +307,6 @@ def test_short_history_no_truncation(controller_fixture): ] ) controller.state.history = history - mock_first_user_message.id = 2 # Calculation (RecallAction now essential): # History len = 6 @@ -324,7 +318,7 @@ def test_short_history_no_truncation(controller_fixture): # Validation: remove leading obs1(6). validated_slice = [] # Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)] # Expected IDs: [1, 2, 3, 4]. Length 4. 
- truncated_events = controller._apply_conversation_window() + truncated_events = controller._apply_conversation_window(controller.state.history) assert len(truncated_events) == 4 expected_ids = [1, 2, 3, 4] @@ -333,7 +327,7 @@ def test_short_history_no_truncation(controller_fixture): def test_only_essential_events(controller_fixture): - controller, mock_first_user_message = controller_fixture + controller = controller_fixture history = create_events( [ @@ -348,7 +342,6 @@ def test_only_essential_events(controller_fixture): ] ) controller.state.history = history - mock_first_user_message.id = 2 # Calculation (RecallAction now essential): # History len = 4 @@ -360,7 +353,7 @@ def test_only_essential_events(controller_fixture): # Validation: remove leading recall_obs(4). validated_slice = [] # Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)] # Expected IDs: [1, 2, 3, 4]. Length 4. - truncated_events = controller._apply_conversation_window() + truncated_events = controller._apply_conversation_window(controller.state.history) assert len(truncated_events) == 4 expected_ids = [1, 2, 3, 4] @@ -369,7 +362,7 @@ def test_only_essential_events(controller_fixture): def test_dangling_observations_at_cut_point(controller_fixture): - controller, mock_first_user_message = controller_fixture + controller = controller_fixture history_forced_dangle = create_events( [ @@ -409,7 +402,6 @@ def test_dangling_observations_at_cut_point(controller_fixture): ] ) # 10 events total controller.state.history = history_forced_dangle - mock_first_user_message.id = 2 # Calculation (RecallAction now essential): # History len = 10 @@ -421,7 +413,7 @@ def test_dangling_observations_at_cut_point(controller_fixture): # Validation: remove leading obs1(8). validated_slice = [cmd2(9), obs2(10)] # Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4), cmd2(9), obs2(10)] # Expected IDs: [1, 2, 3, 4, 9, 10]. Length 6. 
- truncated_events = controller._apply_conversation_window() + truncated_events = controller._apply_conversation_window(controller.state.history) assert len(truncated_events) == 6 expected_ids = [1, 2, 3, 4, 9, 10] @@ -431,7 +423,7 @@ def test_dangling_observations_at_cut_point(controller_fixture): def test_only_dangling_observations_in_recent_slice(controller_fixture): - controller, mock_first_user_message = controller_fixture + controller = controller_fixture history = create_events( [ @@ -457,7 +449,6 @@ def test_only_dangling_observations_in_recent_slice(controller_fixture): ] ) # 6 events total controller.state.history = history - mock_first_user_message.id = 2 # Calculation (RecallAction now essential): # History len = 6 @@ -472,7 +463,9 @@ def test_only_dangling_observations_in_recent_slice(controller_fixture): with patch( 'openhands.controller.agent_controller.logger.warning' ) as mock_log_warning: - truncated_events = controller._apply_conversation_window() + truncated_events = controller._apply_conversation_window( + controller.state.history + ) assert len(truncated_events) == 4 expected_ids = [1, 2, 3, 4] @@ -492,15 +485,15 @@ def test_only_dangling_observations_in_recent_slice(controller_fixture): def test_empty_history(controller_fixture): - controller, _ = controller_fixture + controller = controller_fixture controller.state.history = [] - truncated_events = controller._apply_conversation_window() + truncated_events = controller._apply_conversation_window(controller.state.history) assert truncated_events == [] def test_multiple_user_messages(controller_fixture): - controller, mock_first_user_message = controller_fixture + controller = controller_fixture history = create_events( [ @@ -544,7 +537,6 @@ def test_multiple_user_messages(controller_fixture): ] ) # 11 events total controller.state.history = history - mock_first_user_message.id = 2 # Explicitly set the first user message ID # Calculation (RecallAction now essential): # History len = 11 @@ 
-556,7 +548,7 @@ def test_multiple_user_messages(controller_fixture): # Validation: remove leading recall_obs2(9). validated_slice = [cmd2(10), obs2(11)] # Final = essentials + validated_slice = [sys(1), user1(2), recall_act1(3), recall_obs1(4)] + [cmd2(10), obs2(11)] # Expected IDs: [1, 2, 3, 4, 10, 11]. Length 6. - truncated_events = controller._apply_conversation_window() + truncated_events = controller._apply_conversation_window(controller.state.history) assert len(truncated_events) == 6 expected_ids = [1, 2, 3, 4, 10, 11]