From 4b1ed30e97f22e7c4907a249902b0a9bda4b612d Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Mon, 28 Apr 2025 22:43:41 +0200 Subject: [PATCH] Fix truncation, ensure first user message and log (#8103) Co-authored-by: openhands --- .../agenthub/codeact_agent/codeact_agent.py | 33 +- openhands/controller/agent_controller.py | 179 ++++-- openhands/memory/conversation_memory.py | 44 +- openhands/memory/view.py | 3 + tests/unit/test_agent_controller.py | 135 +---- tests/unit/test_agent_history.py | 569 ++++++++++++++++++ tests/unit/test_agents.py | 82 ++- tests/unit/test_conversation_memory.py | 512 ++++++++++++++-- tests/unit/test_prompt_caching.py | 7 +- 9 files changed, 1324 insertions(+), 240 deletions(-) create mode 100644 tests/unit/test_agent_history.py diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index e246e77467..84cb70e9be 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -20,10 +20,7 @@ from openhands.controller.state.state import State from openhands.core.config import AgentConfig from openhands.core.logger import openhands_logger as logger from openhands.core.message import Message -from openhands.events.action import ( - Action, - AgentFinishAction, -) +from openhands.events.action import Action, AgentFinishAction, MessageAction from openhands.events.event import Event from openhands.llm.llm import LLM from openhands.memory.condenser import Condenser @@ -173,7 +170,8 @@ class CodeActAgent(Agent): f'Processing {len(condensed_history)} events from a total of {len(state.history)} events' ) - messages = self._get_messages(condensed_history) + initial_user_message = self._get_initial_user_message(state.history) + messages = self._get_messages(condensed_history, initial_user_message) params: dict = { 'messages': self.llm.format_messages_for_llm(messages), } @@ -216,7 +214,29 @@ class CodeActAgent(Agent): 
self.pending_actions.append(action) return self.pending_actions.popleft() - def _get_messages(self, events: list[Event]) -> list[Message]: + def _get_initial_user_message(self, history: list[Event]) -> MessageAction: + """Finds the initial user message action from the full history.""" + initial_user_message: MessageAction | None = None + for event in history: + if isinstance(event, MessageAction) and event.source == 'user': + initial_user_message = event + break + + if initial_user_message is None: + # This should not happen in a valid conversation + logger.error( + f'CRITICAL: Could not find the initial user MessageAction in the full {len(history)} events history.' + ) + # Depending on desired robustness, could raise error or create a dummy action + # and log the error + raise ValueError( + 'Initial user message not found in history. Please report this issue.' + ) + return initial_user_message + + def _get_messages( + self, events: list[Event], initial_user_message: MessageAction + ) -> list[Message]: """Constructs the message history for the LLM conversation. 
This method builds a structured conversation history by processing events from the state @@ -253,6 +273,7 @@ class CodeActAgent(Agent): # Use ConversationMemory to process events (including SystemMessageAction) messages = self.conversation_memory.process_events( condensed_history=events, + initial_user_action=initial_user_message, max_message_chars=self.llm.config.max_message_chars, vision_is_active=self.llm.vision_is_active(), ) diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 2461254ebb..919996cfff 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -190,7 +190,7 @@ class AgentController: logger.debug(f'System message got from agent: {system_message}') if system_message: self.event_stream.add_event(system_message, EventSource.AGENT) - logger.debug(f'System message added to event stream: {system_message}') + logger.info(f'System message added to event stream: {system_message}') async def close(self, set_stop_state: bool = True) -> None: """Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream. @@ -1020,7 +1020,7 @@ class AgentController: self.state.start_id = 0 self.log( - 'debug', + 'info', f'AgentController {self.id} - created new state. start_id: {self.state.start_id}', ) else: @@ -1030,7 +1030,7 @@ class AgentController: self.state.start_id = 0 self.log( - 'debug', + 'info', f'AgentController {self.id} initializing history from event {self.state.start_id}', ) @@ -1143,70 +1143,169 @@ class AgentController: def _handle_long_context_error(self) -> None: # When context window is exceeded, keep roughly half of agent interactions - kept_event_ids = { - e.id for e in self._apply_conversation_window(self.state.history) - } + kept_events = self._apply_conversation_window() + kept_event_ids = {e.id for e in kept_events} + + self.log( + 'info', + f'Context window exceeded. 
Keeping events with IDs: {kept_event_ids}', + ) + + # The events to forget are those that are not in the kept set forgotten_event_ids = {e.id for e in self.state.history} - kept_event_ids - # Save the ID of the first event in our truncated history for future reloading - if self.state.history: - self.state.start_id = self.state.history[0].id + if len(kept_event_ids) == 0: + self.log( + 'warning', + 'No events kept after applying conversation window. This should not happen.', + ) + + # verify that the first event id in kept_event_ids is the same as the start_id + if len(kept_event_ids) > 0 and self.state.history[0].id not in kept_event_ids: + self.log( + 'warning', + f'First event after applying conversation window was not kept: {self.state.history[0].id} not in {kept_event_ids}', + ) # Add an error event to trigger another step by the agent self.event_stream.add_event( CondensationAction( - forgotten_events_start_id=min(forgotten_event_ids), - forgotten_events_end_id=max(forgotten_event_ids), + forgotten_events_start_id=min(forgotten_event_ids) + if forgotten_event_ids + else 0, + forgotten_events_end_id=max(forgotten_event_ids) + if forgotten_event_ids + else 0, ), EventSource.AGENT, ) - def _apply_conversation_window(self, events: list[Event]) -> list[Event]: + def _apply_conversation_window(self) -> list[Event]: """Cuts history roughly in half when context window is exceeded. - It preserves action-observation pairs and ensures that the first user message is always included. + It preserves action-observation pairs and ensures that the system message, + the first user message, and its associated recall observation are always included + at the beginning of the context window. The algorithm: - 1. Cut history in half - 2. Check first event in new history: - - If Observation: find and include its Action - - If MessageAction: ensure its related Action-Observation pair isn't split - 3. Always include the first user message + 1. 
Identify essential initial events: System Message, First User Message, Recall Observation. + 2. Determine the slice of recent events to potentially keep. + 3. Validate the start of the recent slice for dangling observations. + 4. Combine essential events and validated recent events, ensuring essentials come first. Args: events: List of events to filter Returns: - Filtered list of events keeping newest half while preserving pairs + Filtered list of events keeping newest half while preserving pairs and essential initial events. """ - if not events: - return events + if not self.state.history: + return [] - # Find first user message - we'll need to ensure it's included - first_user_msg = next( - ( - e - for e in events - if isinstance(e, MessageAction) and e.source == EventSource.USER - ), - None, + history = self.state.history + + # 1. Identify essential initial events + system_message: SystemMessageAction | None = None + first_user_msg: MessageAction | None = None + recall_action: RecallAction | None = None + recall_observation: Observation | None = None + + # Find System Message (should be the first event, if it exists) + system_message = next( + (e for e in history if isinstance(e, SystemMessageAction)), None + ) + assert ( + system_message is None + or isinstance(system_message, SystemMessageAction) + and system_message.id == history[0].id ) - # cut in half - mid_point = max(1, len(events) // 2) - kept_events = events[mid_point:] - if len(kept_events) > 0 and isinstance(kept_events[0], Observation): - kept_events = kept_events[1:] + # Find First User Message, which MUST exist + first_user_msg = self._first_user_message() + if first_user_msg is None: + raise RuntimeError('No first user message found in the event stream.') - # Ensure first user message is included - if first_user_msg and first_user_msg not in kept_events: - kept_events = [first_user_msg] + kept_events + first_user_msg_index = -1 + for i, event in enumerate(history): + if isinstance(event, 
MessageAction) and event.source == EventSource.USER: + first_user_msg = event + first_user_msg_index = i + break - # start_id points to first user message + # Find Recall Action and Observation related to the First User Message + if first_user_msg is not None and first_user_msg_index != -1: + # Look for RecallAction after the first user message + for i in range(first_user_msg_index + 1, len(history)): + event = history[i] + if ( + isinstance(event, RecallAction) + and event.query == first_user_msg.content + ): + # Found RecallAction, now look for its Observation + recall_action = event + for j in range(i + 1, len(history)): + obs_event = history[j] + # Check for Observation caused by this RecallAction + if ( + isinstance(obs_event, Observation) + and obs_event.cause == recall_action.id + ): + recall_observation = obs_event + break # Found the observation, stop inner loop + break # Found the recall action (and maybe obs), stop outer loop + + essential_events: list[Event] = [] + if system_message: + essential_events.append(system_message) if first_user_msg: - self.state.start_id = first_user_msg.id + essential_events.append(first_user_msg) + # Also keep the RecallAction that triggered the essential RecallObservation + if recall_action: + essential_events.append(recall_action) + if recall_observation: + essential_events.append(recall_observation) - return kept_events + # 2. Determine the slice of recent events to potentially keep + num_non_essential_events = len(history) - len(essential_events) + # Keep roughly half of the non-essential events, minimum 1 + num_recent_to_keep = max(1, num_non_essential_events // 2) + + # Calculate the starting index for the recent slice + slice_start_index = len(history) - num_recent_to_keep + slice_start_index = max(0, slice_start_index) # Ensure index is not negative + recent_events_slice = history[slice_start_index:] + + # 3. 
Validate the start of the recent slice for dangling observations + # IMPORTANT: Most observations in history are tool call results, which cannot be without their action, or we get an LLM API error + first_valid_event_index = 0 + for i, event in enumerate(recent_events_slice): + if isinstance(event, Observation): + first_valid_event_index += 1 + else: + break + # If all events in the slice are dangling observations, none of them is kept and only the essential events remain + if first_valid_event_index == len(recent_events_slice): + self.log( + 'warning', + 'All recent events are dangling observations, which we truncate. This means the agent has only the essential first events. This should not happen.', + ) + + # Adjust the recent_events_slice if dangling observations were found at the start + if first_valid_event_index < len(recent_events_slice): + validated_recent_events = recent_events_slice[first_valid_event_index:] + if first_valid_event_index > 0: + self.log( + 'debug', + f'Removed {first_valid_event_index} dangling observation(s) from the start of recent event slice.', + ) + else: + validated_recent_events = [] + + # 4. Combine essential events and validated recent events + events_to_keep: list[Event] = essential_events + validated_recent_events + self.log('debug', f'History truncated. Kept {len(events_to_keep)} events.') + + return events_to_keep def _is_stuck(self) -> bool: """Checks if the agent or its delegate is stuck in a loop. 
diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py index a4f2831401..b643ead2ba 100644 --- a/openhands/memory/conversation_memory.py +++ b/openhands/memory/conversation_memory.py @@ -54,6 +54,7 @@ class ConversationMemory: def process_events( self, condensed_history: list[Event], + initial_user_action: MessageAction, max_message_chars: int | None = None, vision_is_active: bool = False, ) -> list[Message]: @@ -66,12 +67,14 @@ class ConversationMemory: max_message_chars: The maximum number of characters in the content of an event included in the prompt to the LLM. Larger observations are truncated. vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included. + initial_user_action: The initial user message action, if available. Used to ensure the conversation starts correctly. """ events = condensed_history - # Ensure the system message exists (handles legacy cases) + # Ensure the event list starts with SystemMessageAction, then MessageAction(source='user') self._ensure_system_message(events) + self._ensure_initial_user_message(events, initial_user_action) # log visual browsing status logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}') @@ -699,6 +702,43 @@ class ConversationMemory: system_message = SystemMessageAction(content=system_prompt) # Insert the system message directly at the beginning of the events list events.insert(0, system_message) - logger.debug( + logger.info( '[ConversationMemory] Added SystemMessageAction for backward compatibility' ) + + def _ensure_initial_user_message( + self, events: list[Event], initial_user_action: MessageAction + ) -> None: + """Checks if the second event is a user MessageAction and inserts the provided one if needed.""" + if ( + not events + ): # Should have system message from previous step, but safety check + logger.error('Cannot ensure initial user message: event list is empty.') + # Or raise? 
Let's log for now, _ensure_system_message should handle this. + return + + # We expect events[0] to be SystemMessageAction after _ensure_system_message + if len(events) == 1: + # Only system message exists + logger.info( + 'Initial user message action was missing. Inserting the initial user message.' + ) + events.insert(1, initial_user_action) + elif not isinstance(events[1], MessageAction) or events[1].source != 'user': + # The second event exists but is not the correct initial user message action. + # We will insert the correct one provided. + logger.info( + 'Second event was not the initial user message action. Inserting correct one at index 1.' + ) + + # Insert the user message event at index 1. This will be the second message as LLM APIs expect + # but something was wrong with the history, so log all we can. + events.insert(1, initial_user_action) + + # Else: events[1] is already a user MessageAction. + # Check if it matches the one provided (on any discrepancy, log at debug level and proceed). + elif events[1] != initial_user_action: + logger.debug( + 'The user MessageAction at index 1 does not match the provided initial_user_action. ' + 'Proceeding with the one found in condensed history.' 
+ ) diff --git a/openhands/memory/view.py b/openhands/memory/view.py index d5c8259f64..6809b1206b 100644 --- a/openhands/memory/view.py +++ b/openhands/memory/view.py @@ -4,6 +4,7 @@ from typing import overload from pydantic import BaseModel +from openhands.core.logger import openhands_logger as logger from openhands.events.action.agent import CondensationAction from openhands.events.event import Event from openhands.events.observation.agent import AgentCondensationObservation @@ -65,6 +66,8 @@ class View(BaseModel): break if summary is not None and summary_offset is not None: + logger.info(f'Inserting summary at offset {summary_offset}') + kept_events.insert( summary_offset, AgentCondensationObservation(content=summary) ) diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index c44ee05108..91fb3eada7 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -22,7 +22,6 @@ from openhands.events.observation import ( ErrorObservation, ) from openhands.events.observation.agent import RecallObservation -from openhands.events.observation.commands import CmdOutputObservation from openhands.events.observation.empty import NullObservation from openhands.events.serialization import event_to_dict from openhands.llm import LLM @@ -765,7 +764,7 @@ async def test_context_window_exceeded_error_handling( # We do that by playing the role of the recall module -- subscribe to the # event stream and respond to recall actions by inserting fake recall - # obesrvations. + # observations. def on_event_memory(event: Event): if isinstance(event, RecallAction): microagent_obs = RecallObservation( @@ -807,13 +806,19 @@ async def test_context_window_exceeded_error_handling( # size (because we return a message action, which triggers a recall, which # triggers a recall response). But if the pre/post-views are on the turn # when we throw the context window exceeded error, we should see the - # post-step view compressed. 
+ # post-step view compressed (or rather, a CondensationAction added). for index, (first_view, second_view) in enumerate( zip(step_state.views[:-1], step_state.views[1:]) ): if index == error_after: - assert len(first_view) > len(second_view) + # Verify that the CondensationAction is present in the second view (after error) + # but not in the first view (before error) + assert not any(isinstance(e, CondensationAction) for e in first_view.events) + assert any(isinstance(e, CondensationAction) for e in second_view.events) + # The length might not strictly decrease due to CondensationAction being added + assert len(first_view) == len(second_view) else: + # Before the error, the view length should increase assert len(first_view) < len(second_view) # The final state's history should contain: @@ -886,7 +891,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation( def step(self, state: State): # If the state has more than one message and we haven't errored yet, # throw the context window exceeded error - if len(state.history) > 3 and not self.has_errored: + if len(state.history) > 5 and not self.has_errored: error = ContextWindowExceededError( message='prompt is too long: 233885 tokens > 200000 maximum', model='', @@ -1467,126 +1472,6 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen ), 'should_step should return False for NullObservation with cause = 0' -def test_apply_conversation_window_basic(mock_event_stream, mock_agent): - """Test that the _apply_conversation_window method correctly prunes a list of events.""" - controller = AgentController( - agent=mock_agent, - event_stream=mock_event_stream, - max_iterations=10, - sid='test_apply_conversation_window_basic', - confirmation_mode=False, - headless_mode=True, - ) - - # Create a sequence of events with IDs - first_msg = MessageAction(content='Hello, start task', wait_for_response=False) - first_msg._source = EventSource.USER - first_msg._id = 1 - - # Add agent 
question - agent_msg = MessageAction( - content='What task would you like me to perform?', wait_for_response=True - ) - agent_msg._source = EventSource.AGENT - agent_msg._id = 2 - - # Add user response - user_response = MessageAction( - content='Please list all files and show me current directory', - wait_for_response=False, - ) - user_response._source = EventSource.USER - user_response._id = 3 - - cmd1 = CmdRunAction(command='ls') - cmd1._id = 4 - obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4) - obs1._id = 5 - obs1._cause = 4 - - cmd2 = CmdRunAction(command='pwd') - cmd2._id = 6 - obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=6) - obs2._id = 7 - obs2._cause = 6 - - events = [first_msg, agent_msg, user_response, cmd1, obs1, cmd2, obs2] - - # Apply truncation - truncated = controller._apply_conversation_window(events) - - # Verify truncation occured - # Should keep first user message and roughly half of other events - assert ( - 3 <= len(truncated) < len(events) - ) # First message + at least one action-observation pair - assert truncated[0] == first_msg # First message always preserved - assert controller.state.start_id == first_msg._id - - # Verify pairs aren't split - for i, event in enumerate(truncated[1:]): - if isinstance(event, CmdOutputObservation): - assert any(e._id == event._cause for e in truncated[: i + 1]) - - -def test_history_restoration_after_truncation(mock_event_stream, mock_agent): - controller = AgentController( - agent=mock_agent, - event_stream=mock_event_stream, - max_iterations=10, - sid='test_truncation', - confirmation_mode=False, - headless_mode=True, - ) - - # Create events with IDs - first_msg = MessageAction(content='Start task', wait_for_response=False) - first_msg._source = EventSource.USER - first_msg._id = 1 - - events = [first_msg] - for i in range(5): - cmd = CmdRunAction(command=f'cmd{i}') - cmd._id = i + 2 - obs = CmdOutputObservation( - command=f'cmd{i}', 
content=f'output{i}', command_id=cmd._id - ) - obs._cause = cmd._id - events.extend([cmd, obs]) - - # Set up initial history - controller.state.history = events.copy() - - # Force truncation - controller.state.history = controller._apply_conversation_window( - controller.state.history - ) - - # Save state - saved_start_id = controller.state.start_id - saved_history_len = len(controller.state.history) - - # Set up mock event stream for new controller - mock_event_stream.get_events.return_value = controller.state.history - - # Create new controller with saved state - new_controller = AgentController( - agent=mock_agent, - event_stream=mock_event_stream, - max_iterations=10, - sid='test_truncation', - confirmation_mode=False, - headless_mode=True, - ) - new_controller.state.start_id = saved_start_id - new_controller.state.history = mock_event_stream.get_events() - - # Verify restoration - assert len(new_controller.state.history) == saved_history_len - assert new_controller.state.history[0] == first_msg - assert new_controller.state.start_id == saved_start_id - - def test_system_message_in_event_stream(mock_agent, test_event_stream): """Test that SystemMessageAction is added to event stream in AgentController.""" _ = AgentController( diff --git a/tests/unit/test_agent_history.py b/tests/unit/test_agent_history.py new file mode 100644 index 0000000000..1c53a03a26 --- /dev/null +++ b/tests/unit/test_agent_history.py @@ -0,0 +1,569 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from openhands.controller.agent import Agent +from openhands.controller.agent_controller import AgentController +from openhands.controller.state.state import State +from openhands.core.config import AppConfig +from openhands.events import EventSource +from openhands.events.action import CmdRunAction, MessageAction, RecallAction +from openhands.events.action.message import SystemMessageAction +from openhands.events.event import RecallType +from openhands.events.observation import 
( + CmdOutputObservation, + Observation, + RecallObservation, +) +from openhands.events.stream import EventStream +from openhands.llm.llm import LLM +from openhands.llm.metrics import Metrics +from openhands.storage.memory import InMemoryFileStore + + +# Helper function to create events with sequential IDs and causes +def create_events(event_data): + events = [] + # Import necessary types here to avoid repeated imports inside the loop + from openhands.events.action import CmdRunAction, RecallAction + from openhands.events.observation import CmdOutputObservation, RecallObservation + + for i, data in enumerate(event_data): + event_type = data['type'] + source = data.get('source', EventSource.AGENT) + kwargs = {} # Arguments for the event constructor + + # Determine arguments based on event type + if event_type == RecallAction: + kwargs['query'] = data.get('query', '') + kwargs['recall_type'] = data.get('recall_type', RecallType.KNOWLEDGE) + elif event_type == RecallObservation: + kwargs['content'] = data.get('content', '') + kwargs['recall_type'] = data.get('recall_type', RecallType.KNOWLEDGE) + elif event_type == CmdRunAction: + kwargs['command'] = data.get('command', '') + elif event_type == CmdOutputObservation: + # Required args for CmdOutputObservation + kwargs['content'] = data.get('content', '') + kwargs['command'] = data.get('command', '') + # Pass command_id via kwargs if present in data + if 'command_id' in data: + kwargs['command_id'] = data['command_id'] + # Pass metadata if present + if 'metadata' in data: + kwargs['metadata'] = data['metadata'] + else: # Default for MessageAction, SystemMessageAction, etc. 
+ kwargs['content'] = data.get('content', '') + + # Instantiate the event + event = event_type(**kwargs) + + # Assign internal attributes AFTER instantiation + event._id = i + 1 # Assign sequential IDs starting from 1 + event._source = source + # Assign _cause using cause_id from data, AFTER event._id is set + if 'cause_id' in data: + event._cause = data['cause_id'] + # If command_id was NOT passed via kwargs but cause_id exists, + # pass cause_id as command_id to __init__ via kwargs for legacy handling + # This needs to happen *before* instantiation if we want __init__ to handle it + # Let's adjust the logic slightly: + if event_type == CmdOutputObservation: + if 'command_id' not in kwargs and 'cause_id' in data: + kwargs['command_id'] = data['cause_id'] # Let __init__ handle this + # Re-instantiate if we added command_id + if 'command_id' in kwargs and event.command_id != kwargs['command_id']: + event = event_type(**kwargs) + event._id = i + 1 + event._source = source + + # Now assign _cause if it exists in data, after potential re-instantiation + if 'cause_id' in data: + event._cause = data['cause_id'] + + events.append(event) + return events + + +@pytest.fixture +def controller_fixture(): + mock_agent = MagicMock(spec=Agent) + mock_agent.llm = MagicMock(spec=LLM) + mock_agent.llm.metrics = Metrics() + mock_agent.llm.config = AppConfig().get_llm_config() + mock_agent.config = AppConfig().get_agent_config('CodeActAgent') + + mock_event_stream = MagicMock(spec=EventStream) + mock_event_stream.sid = 'test_sid' + mock_event_stream.file_store = InMemoryFileStore({}) + # Ensure get_latest_event_id returns an integer + mock_event_stream.get_latest_event_id.return_value = -1 + + controller = AgentController( + agent=mock_agent, + event_stream=mock_event_stream, + max_iterations=10, + sid='test_sid', + ) + controller.state = State(session_id='test_sid') + + # Mock _first_user_message directly on the instance + mock_first_user_message = MagicMock(spec=MessageAction) + 
controller._first_user_message = MagicMock(return_value=mock_first_user_message) + + return controller, mock_first_user_message + + +# ============================================= +# Test Cases for _apply_conversation_window +# ============================================= + + +def test_basic_truncation(controller_fixture): + controller, mock_first_user_message = controller_fixture + + controller.state.history = create_events( + [ + {'type': SystemMessageAction, 'content': 'System Prompt'}, # 1 + { + 'type': MessageAction, + 'content': 'User Task 1', + 'source': EventSource.USER, + }, # 2 + {'type': RecallAction, 'query': 'User Task 1'}, # 3 + {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4 + {'type': CmdRunAction, 'command': 'ls'}, # 5 + { + 'type': CmdOutputObservation, + 'content': 'file1', + 'command': 'ls', + 'cause_id': 5, + }, # 6 + {'type': CmdRunAction, 'command': 'pwd'}, # 7 + { + 'type': CmdOutputObservation, + 'content': '/dir', + 'command': 'pwd', + 'cause_id': 7, + }, # 8 + {'type': CmdRunAction, 'command': 'cat file1'}, # 9 + { + 'type': CmdOutputObservation, + 'content': 'content', + 'command': 'cat file1', + 'cause_id': 9, + }, # 10 + ] + ) + mock_first_user_message.id = 2 # Set the ID of the mocked first user message + + # Calculation (RecallAction now essential): + # History len = 10 + # Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4) + # Non-essential count = 10 - 4 = 6 + # num_recent_to_keep = max(1, 6 // 2) = 3 + # slice_start_index = 10 - 3 = 7 + # recent_events_slice = history[7:] = [obs2(8), cmd3(9), obs3(10)] + # Validation: remove leading obs2(8). validated_slice = [cmd3(9), obs3(10)] + # Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4), cmd3(9), obs3(10)] + # Expected IDs: [1, 2, 3, 4, 9, 10]. Length 6. 
+ truncated_events = controller._apply_conversation_window() + + assert len(truncated_events) == 6 + expected_ids = [1, 2, 3, 4, 9, 10] + actual_ids = [e.id for e in truncated_events] + assert actual_ids == expected_ids + # Check no dangling observations at the start of the recent slice part + # The first event of the validated slice is cmd3(9) + assert not isinstance(truncated_events[4], Observation) # Index adjusted + + +def test_no_system_message(controller_fixture): + controller, mock_first_user_message = controller_fixture + + controller.state.history = create_events( + [ + { + 'type': MessageAction, + 'content': 'User Task 1', + 'source': EventSource.USER, + }, # 1 + {'type': RecallAction, 'query': 'User Task 1'}, # 2 + {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 2}, # 3 + {'type': CmdRunAction, 'command': 'ls'}, # 4 + { + 'type': CmdOutputObservation, + 'content': 'file1', + 'command': 'ls', + 'cause_id': 4, + }, # 5 + {'type': CmdRunAction, 'command': 'pwd'}, # 6 + { + 'type': CmdOutputObservation, + 'content': '/dir', + 'command': 'pwd', + 'cause_id': 6, + }, # 7 + {'type': CmdRunAction, 'command': 'cat file1'}, # 8 + { + 'type': CmdOutputObservation, + 'content': 'content', + 'command': 'cat file1', + 'cause_id': 8, + }, # 9 + ] + ) + mock_first_user_message.id = 1 + + # Calculation (RecallAction now essential): + # History len = 9 + # Essentials = [user(1), recall_act(2), recall_obs(3)] (len=3) + # Non-essential count = 9 - 3 = 6 + # num_recent_to_keep = max(1, 6 // 2) = 3 + # slice_start_index = 9 - 3 = 6 + # recent_events_slice = history[6:] = [obs2(7), cmd3(8), obs3(9)] + # Validation: remove leading obs2(7). validated_slice = [cmd3(8), obs3(9)] + # Final = essentials + validated_slice = [user(1), recall_act(2), recall_obs(3), cmd3(8), obs3(9)] + # Expected IDs: [1, 2, 3, 8, 9]. Length 5. 
+ truncated_events = controller._apply_conversation_window() + + assert len(truncated_events) == 5 + expected_ids = [1, 2, 3, 8, 9] + actual_ids = [e.id for e in truncated_events] + assert actual_ids == expected_ids + + +def test_no_recall_observation(controller_fixture): + controller, mock_first_user_message = controller_fixture + + controller.state.history = create_events( + [ + {'type': SystemMessageAction, 'content': 'System Prompt'}, # 1 + { + 'type': MessageAction, + 'content': 'User Task 1', + 'source': EventSource.USER, + }, # 2 + {'type': RecallAction, 'query': 'User Task 1'}, # 3 (Recall Action exists) + # Recall Observation is missing + {'type': CmdRunAction, 'command': 'ls'}, # 4 + { + 'type': CmdOutputObservation, + 'content': 'file1', + 'command': 'ls', + 'cause_id': 4, + }, # 5 + {'type': CmdRunAction, 'command': 'pwd'}, # 6 + { + 'type': CmdOutputObservation, + 'content': '/dir', + 'command': 'pwd', + 'cause_id': 6, + }, # 7 + {'type': CmdRunAction, 'command': 'cat file1'}, # 8 + { + 'type': CmdOutputObservation, + 'content': 'content', + 'command': 'cat file1', + 'cause_id': 8, + }, # 9 + ] + ) + mock_first_user_message.id = 2 + + # Calculation (RecallAction stays essential even though its RecallObservation is missing): + # History len = 9 + # Essentials = [sys(1), user(2), recall_act(3)] (len=3) - RecallObs missing, but the matching RecallAction is still kept + # Non-essential count = 9 - 3 = 6 + # num_recent_to_keep = max(1, 6 // 2) = 3 + # slice_start_index = 9 - 3 = 6 + # recent_events_slice = history[6:] = [obs2(7), cmd3(8), obs3(9)] + # Validation: remove leading obs2(7). validated_slice = [cmd3(8), obs3(9)] + # Final = essentials + validated_slice = [sys(1), user(2), recall_action(3), cmd_cat(8), obs_cat(9)] + # Expected IDs: [1, 2, 3, 8, 9]. Length 5. 
+ truncated_events = controller._apply_conversation_window() + + assert len(truncated_events) == 5 + expected_ids = [1, 2, 3, 8, 9] + actual_ids = [e.id for e in truncated_events] + assert actual_ids == expected_ids + + +def test_short_history_no_truncation(controller_fixture): + controller, mock_first_user_message = controller_fixture + + history = create_events( + [ + {'type': SystemMessageAction, 'content': 'System Prompt'}, # 1 + { + 'type': MessageAction, + 'content': 'User Task 1', + 'source': EventSource.USER, + }, # 2 + {'type': RecallAction, 'query': 'User Task 1'}, # 3 + {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4 + {'type': CmdRunAction, 'command': 'ls'}, # 5 + { + 'type': CmdOutputObservation, + 'content': 'file1', + 'command': 'ls', + 'cause_id': 5, + }, # 6 + ] + ) + controller.state.history = history + mock_first_user_message.id = 2 + + # Calculation (RecallAction now essential): + # History len = 6 + # Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4) + # Non-essential count = 6 - 4 = 2 + # num_recent_to_keep = max(1, 2 // 2) = 1 + # slice_start_index = 6 - 1 = 5 + # recent_events_slice = history[5:] = [obs1(6)] + # Validation: remove leading obs1(6). validated_slice = [] + # Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)] + # Expected IDs: [1, 2, 3, 4]. Length 4. 
+ truncated_events = controller._apply_conversation_window() + + assert len(truncated_events) == 4 + expected_ids = [1, 2, 3, 4] + actual_ids = [e.id for e in truncated_events] + assert actual_ids == expected_ids + + +def test_only_essential_events(controller_fixture): + controller, mock_first_user_message = controller_fixture + + history = create_events( + [ + {'type': SystemMessageAction, 'content': 'System Prompt'}, # 1 + { + 'type': MessageAction, + 'content': 'User Task 1', + 'source': EventSource.USER, + }, # 2 + {'type': RecallAction, 'query': 'User Task 1'}, # 3 + {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4 + ] + ) + controller.state.history = history + mock_first_user_message.id = 2 + + # Calculation (RecallAction now essential): + # History len = 4 + # Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4) + # Non-essential count = 4 - 4 = 0 + # num_recent_to_keep = max(1, 0 // 2) = 1 + # slice_start_index = 4 - 1 = 3 + # recent_events_slice = history[3:] = [recall_obs(4)] + # Validation: remove leading recall_obs(4). validated_slice = [] + # Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)] + # Expected IDs: [1, 2, 3, 4]. Length 4. 
+ truncated_events = controller._apply_conversation_window() + + assert len(truncated_events) == 4 + expected_ids = [1, 2, 3, 4] + actual_ids = [e.id for e in truncated_events] + assert actual_ids == expected_ids + + +def test_dangling_observations_at_cut_point(controller_fixture): + controller, mock_first_user_message = controller_fixture + + history_forced_dangle = create_events( + [ + {'type': SystemMessageAction, 'content': 'System Prompt'}, # 1 + { + 'type': MessageAction, + 'content': 'User Task 1', + 'source': EventSource.USER, + }, # 2 + {'type': RecallAction, 'query': 'User Task 1'}, # 3 + {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4 + # --- Slice calculation should start here --- + { + 'type': CmdOutputObservation, + 'content': 'dangle1', + 'command': 'cmd_unknown', + }, # 5 (Dangling) + { + 'type': CmdOutputObservation, + 'content': 'dangle2', + 'command': 'cmd_unknown', + }, # 6 (Dangling) + {'type': CmdRunAction, 'command': 'cmd1'}, # 7 + { + 'type': CmdOutputObservation, + 'content': 'obs1', + 'command': 'cmd1', + 'cause_id': 7, + }, # 8 + {'type': CmdRunAction, 'command': 'cmd2'}, # 9 + { + 'type': CmdOutputObservation, + 'content': 'obs2', + 'command': 'cmd2', + 'cause_id': 9, + }, # 10 + ] + ) # 10 events total + controller.state.history = history_forced_dangle + mock_first_user_message.id = 2 + + # Calculation (RecallAction now essential): + # History len = 10 + # Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4) + # Non-essential count = 10 - 4 = 6 + # num_recent_to_keep = max(1, 6 // 2) = 3 + # slice_start_index = 10 - 3 = 7 + # recent_events_slice = history[7:] = [obs1(8), cmd2(9), obs2(10)] + # Validation: remove leading obs1(8). validated_slice = [cmd2(9), obs2(10)] + # Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4), cmd2(9), obs2(10)] + # Expected IDs: [1, 2, 3, 4, 9, 10]. Length 6. 
+ truncated_events = controller._apply_conversation_window() + + assert len(truncated_events) == 6 + expected_ids = [1, 2, 3, 4, 9, 10] + actual_ids = [e.id for e in truncated_events] + assert actual_ids == expected_ids + # Verify dangling observations 5 and 6 were removed (implicitly by slice start and validation) + + +def test_only_dangling_observations_in_recent_slice(controller_fixture): + controller, mock_first_user_message = controller_fixture + + history = create_events( + [ + {'type': SystemMessageAction, 'content': 'System Prompt'}, # 1 + { + 'type': MessageAction, + 'content': 'User Task 1', + 'source': EventSource.USER, + }, # 2 + {'type': RecallAction, 'query': 'User Task 1'}, # 3 + {'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4 + # --- Slice calculation should start here --- + { + 'type': CmdOutputObservation, + 'content': 'dangle1', + 'command': 'cmd_unknown', + }, # 5 (Dangling) + { + 'type': CmdOutputObservation, + 'content': 'dangle2', + 'command': 'cmd_unknown', + }, # 6 (Dangling) + ] + ) # 6 events total + controller.state.history = history + mock_first_user_message.id = 2 + + # Calculation (RecallAction now essential): + # History len = 6 + # Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4) + # Non-essential count = 6 - 4 = 2 + # num_recent_to_keep = max(1, 2 // 2) = 1 + # slice_start_index = 6 - 1 = 5 + # recent_events_slice = history[5:] = [dangle2(6)] + # Validation: remove leading dangle2(6). validated_slice = [] (Corrected based on user feedback/bugfix) + # Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)] + # Expected IDs: [1, 2, 3, 4]. Length 4. 
+ with patch( + 'openhands.controller.agent_controller.logger.warning' + ) as mock_log_warning: + truncated_events = controller._apply_conversation_window() + + assert len(truncated_events) == 4 + expected_ids = [1, 2, 3, 4] + actual_ids = [e.id for e in truncated_events] + assert actual_ids == expected_ids + # Verify dangling observations 5 and 6 were removed + + # Check that the specific warning was logged exactly once + assert mock_log_warning.call_count == 1 + + # Check the essential parts of the arguments, allowing for variations like stacklevel + call_args, call_kwargs = mock_log_warning.call_args + expected_message_substring = 'All recent events are dangling observations, which we truncate. This means the agent has only the essential first events. This should not happen.' + assert expected_message_substring in call_args[0] + assert 'extra' in call_kwargs + assert call_kwargs['extra'].get('session_id') == 'test_sid' + + +def test_empty_history(controller_fixture): + controller, _ = controller_fixture + controller.state.history = [] + + truncated_events = controller._apply_conversation_window() + assert truncated_events == [] + + +def test_multiple_user_messages(controller_fixture): + controller, mock_first_user_message = controller_fixture + + history = create_events( + [ + {'type': SystemMessageAction, 'content': 'System Prompt'}, # 1 + { + 'type': MessageAction, + 'content': 'User Task 1', + 'source': EventSource.USER, + }, # 2 (First) + {'type': RecallAction, 'query': 'User Task 1'}, # 3 + { + 'type': RecallObservation, + 'content': 'Recall result 1', + 'cause_id': 3, + }, # 4 + {'type': CmdRunAction, 'command': 'cmd1'}, # 5 + { + 'type': CmdOutputObservation, + 'content': 'obs1', + 'command': 'cmd1', + 'cause_id': 5, + }, # 6 + { + 'type': MessageAction, + 'content': 'User Task 2', + 'source': EventSource.USER, + }, # 7 (Second) + {'type': RecallAction, 'query': 'User Task 2'}, # 8 + { + 'type': RecallObservation, + 'content': 'Recall result 2', + 
'cause_id': 8, + }, # 9 + {'type': CmdRunAction, 'command': 'cmd2'}, # 10 + { + 'type': CmdOutputObservation, + 'content': 'obs2', + 'command': 'cmd2', + 'cause_id': 10, + }, # 11 + ] + ) # 11 events total + controller.state.history = history + mock_first_user_message.id = 2 # Explicitly set the first user message ID + + # Calculation (RecallAction now essential): + # History len = 11 + # Essentials = [sys(1), user1(2), recall_act1(3), recall_obs1(4)] (len=4) + # Non-essential count = 11 - 4 = 7 + # num_recent_to_keep = max(1, 7 // 2) = 3 + # slice_start_index = 11 - 3 = 8 + # recent_events_slice = history[8:] = [recall_obs2(9), cmd2(10), obs2(11)] + # Validation: remove leading recall_obs2(9). validated_slice = [cmd2(10), obs2(11)] + # Final = essentials + validated_slice = [sys(1), user1(2), recall_act1(3), recall_obs1(4)] + [cmd2(10), obs2(11)] + # Expected IDs: [1, 2, 3, 4, 10, 11]. Length 6. + truncated_events = controller._apply_conversation_window() + + assert len(truncated_events) == 6 + expected_ids = [1, 2, 3, 4, 10, 11] + actual_ids = [e.id for e in truncated_events] + assert actual_ids == expected_ids + + # Verify the second user message (ID 7) was NOT kept + assert not any(event.id == 7 for event in truncated_events) + # Verify the first user message (ID 2) is present + assert any(event.id == 2 for event in truncated_events) diff --git a/tests/unit/test_agents.py b/tests/unit/test_agents.py index b72ef9d688..641411e03f 100644 --- a/tests/unit/test_agents.py +++ b/tests/unit/test_agents.py @@ -44,6 +44,7 @@ from openhands.events.observation.commands import ( ) from openhands.events.tool import ToolCallMetadata from openhands.llm.llm import LLM +from openhands.memory.condenser import View @pytest.fixture(params=['CodeActAgent', 'ReadOnlyAgent']) @@ -97,6 +98,12 @@ def test_reset(agent): action._source = EventSource.AGENT agent.pending_actions.append(action) + # Create a mock state with initial user message + mock_state = Mock(spec=State) + 
initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER + mock_state.history = [initial_user_message] + # Reset agent.reset() @@ -110,8 +117,14 @@ def test_step_with_pending_actions(agent): pending_action._source = EventSource.AGENT agent.pending_actions.append(pending_action) + # Create a mock state with initial user message + mock_state = Mock(spec=State) + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER + mock_state.history = [initial_user_message] + # Step should return the pending action - result = agent.step(Mock()) + result = agent.step(mock_state) assert result == pending_action assert len(agent.pending_actions) == 0 @@ -260,6 +273,11 @@ def test_step_with_no_pending_actions(mock_state: State): mock_state.latest_user_message_llm_metrics = None mock_state.latest_user_message_tool_call_metadata = None + # Add initial user message to history + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER + mock_state.history = [initial_user_message] + action = agent.step(mock_state) assert isinstance(action, MessageAction) assert action.content == 'Task completed' @@ -330,42 +348,56 @@ def test_mismatched_tool_call_events_and_auto_add_system_message( ) action = CmdRunAction('foo') - action._source = 'agent' + action._source = EventSource.AGENT action.tool_call_metadata = tool_call_metadata observation = CmdOutputObservation(content='', command_id=0, command='foo') observation.tool_call_metadata = tool_call_metadata + # Add initial user message + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER + # When both events are provided, the agent should get three messages: # 1. The system message (added automatically for backward compatibility) # 2. The action message # 3. 
The observation message - mock_state.history = [action, observation] - messages = agent._get_messages(mock_state.history) - assert len(messages) == 3 + mock_state.history = [initial_user_message, action, observation] + messages = agent._get_messages(mock_state.history, initial_user_message) + assert len(messages) == 4 # System + initial user + action + observation assert messages[0].role == 'system' # First message should be the system message - assert messages[1].role == 'assistant' # Second message should be the action - assert messages[2].role == 'tool' # Third message should be the observation + assert ( + messages[1].role == 'user' + ) # Second message should be the initial user message + assert messages[2].role == 'assistant' # Third message should be the action + assert messages[3].role == 'tool' # Fourth message should be the observation # The same should hold if the events are presented out-of-order - mock_state.history = [observation, action] - messages = agent._get_messages(mock_state.history) - assert len(messages) == 3 + mock_state.history = [initial_user_message, observation, action] + messages = agent._get_messages(mock_state.history, initial_user_message) + assert len(messages) == 4 assert messages[0].role == 'system' # First message should be the system message + assert ( + messages[1].role == 'user' + ) # Second message should be the initial user message # If only one of the two events is present, then we should just get the system message # plus any valid message from the event - mock_state.history = [action] - messages = agent._get_messages(mock_state.history) + mock_state.history = [initial_user_message, action] + messages = agent._get_messages(mock_state.history, initial_user_message) assert ( - len(messages) == 1 - ) # Only system message, action is waiting for its observation + len(messages) == 2 + ) # System + initial user message, action is waiting for its observation assert messages[0].role == 'system' + assert messages[1].role == 'user' 
- mock_state.history = [observation] - messages = agent._get_messages(mock_state.history) - assert len(messages) == 1 # Only system message, observation has no matching action + mock_state.history = [initial_user_message, observation] + messages = agent._get_messages(mock_state.history, initial_user_message) + assert ( + len(messages) == 2 + ) # System + initial user message, observation has no matching action assert messages[0].role == 'system' + assert messages[1].role == 'user' def test_grep_tool(): @@ -470,3 +502,19 @@ def test_get_system_message(): assert len(result.tools) > 0 assert any(tool['function']['name'] == 'execute_bash' for tool in result.tools) assert result._source == EventSource.AGENT + + +def test_step_raises_error_if_no_initial_user_message( + agent: CodeActAgent, mock_state: State +): + """Tests that step raises ValueError if the initial user message is not found.""" + # Ensure history does NOT contain a user MessageAction + assistant_message = MessageAction(content='Assistant message') + assistant_message._source = EventSource.AGENT + mock_state.history = [assistant_message] + # Mock the condenser to return the history as is + agent.condenser = Mock() + agent.condenser.condensed_history.return_value = View(events=mock_state.history) + + with pytest.raises(ValueError, match='Initial user message not found'): + agent.step(mock_state) diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py index fff3d30efd..4bce0a498d 100644 --- a/tests/unit/test_conversation_memory.py +++ b/tests/unit/test_conversation_memory.py @@ -100,6 +100,7 @@ def test_process_events_with_message_action(conversation_memory): # Process events messages = conversation_memory.process_events( condensed_history=[system_message, user_message, assistant_message], + initial_user_action=user_message, max_message_chars=None, vision_is_active=False, ) @@ -108,10 +109,178 @@ def test_process_events_with_message_action(conversation_memory): assert 
len(messages) == 3 assert messages[0].role == 'system' assert messages[0].content[0].text == 'System message' + + +# Test cases for _ensure_system_message +def test_ensure_system_message_adds_if_missing(conversation_memory): + """Test that _ensure_system_message adds a system message if none exists.""" + user_message = MessageAction(content='User message') + user_message._source = EventSource.USER + events = [user_message] + conversation_memory._ensure_system_message(events) + assert len(events) == 2 + assert isinstance(events[0], SystemMessageAction) + assert events[0].content == 'System message' # From fixture + assert isinstance(events[1], MessageAction) # Original event is still there + + +def test_ensure_system_message_does_nothing_if_present(conversation_memory): + """Test that _ensure_system_message does nothing if a system message is already present.""" + original_system_message = SystemMessageAction(content='Existing system message') + user_message = MessageAction(content='User message') + user_message._source = EventSource.USER + events = [ + original_system_message, + user_message, + ] + original_events = list(events) # Copy before modification + conversation_memory._ensure_system_message(events) + assert events == original_events # List should be unchanged + + +# Test cases for _ensure_initial_user_message +@pytest.fixture +def initial_user_action(): + msg = MessageAction(content='Initial User Message') + msg._source = EventSource.USER + return msg + + +def test_ensure_initial_user_message_adds_if_only_system( + conversation_memory, initial_user_action +): + """Test adding the initial user message when only the system message exists.""" + system_message = SystemMessageAction(content='System') + system_message._source = EventSource.AGENT + events = [system_message] + conversation_memory._ensure_initial_user_message(events, initial_user_action) + assert len(events) == 2 + assert events[0] == system_message + assert events[1] == initial_user_action + + 
+def test_ensure_initial_user_message_correct_already_present( + conversation_memory, initial_user_action +): + """Test that nothing changes if the correct initial user message is at index 1.""" + system_message = SystemMessageAction(content='System') + agent_message = MessageAction(content='Assistant') + agent_message._source = EventSource.USER + events = [ + system_message, + initial_user_action, + agent_message, + ] + original_events = list(events) + conversation_memory._ensure_initial_user_message(events, initial_user_action) + assert events == original_events + + +def test_ensure_initial_user_message_incorrect_at_index_1( + conversation_memory, initial_user_action +): + """Test inserting the correct initial user message when an incorrect message is at index 1.""" + system_message = SystemMessageAction(content='System') + incorrect_second_message = MessageAction(content='Assistant') + incorrect_second_message._source = EventSource.AGENT + events = [system_message, incorrect_second_message] + conversation_memory._ensure_initial_user_message(events, initial_user_action) + assert len(events) == 3 + assert events[0] == system_message + assert events[1] == initial_user_action # Correct one inserted + assert events[2] == incorrect_second_message # Original second message shifted + + +def test_ensure_initial_user_message_correct_present_later( + conversation_memory, initial_user_action +): + """Test inserting the correct initial user message at index 1 even if it exists later.""" + system_message = SystemMessageAction(content='System') + incorrect_second_message = MessageAction(content='Assistant') + incorrect_second_message._source = EventSource.AGENT + # NOTE: initial_user_action is not actually in this list; only the insertion at index 1 is exercised + events = [system_message, incorrect_second_message] + conversation_memory._ensure_system_message(events) + conversation_memory._ensure_initial_user_message(events, initial_user_action) + assert len(events) == 3 # Should still insert at index 1, not remove 
the later one + assert events[0] == system_message + assert events[1] == initial_user_action # Correct one inserted at index 1 + assert events[2] == incorrect_second_message # Original second message shifted + # No later copy of initial_user_action was placed in this list, so the length of 3 only reflects the insertion at index 1 + + +def test_ensure_initial_user_message_different_user_msg_at_index_1( + conversation_memory, initial_user_action +): + """Test that a *different* user message already at index 1 is kept and the initial user message is not inserted.""" + system_message = SystemMessageAction(content='System') + different_user_message = MessageAction(content='Different User Message') + different_user_message._source = EventSource.USER + events = [system_message, different_user_message] + conversation_memory._ensure_initial_user_message(events, initial_user_action) + assert len(events) == 2 + assert events[0] == system_message + assert events[1] == different_user_message # Original second message remains + + +def test_ensure_initial_user_message_different_user_msg_at_index_1_and_orphaned_obs( + conversation_memory, initial_user_action +): + """ + Test process_events when a different user message is at index 1 AND + an orphaned observation (with tool_call_metadata but no matching action) exists. + Expect: System msg, then the different user msg left at index 1 (nothing inserted). + The orphaned observation should be filtered out. 
+ """ + system_message = SystemMessageAction(content='System') + different_user_message = MessageAction(content='Different User Message') + different_user_message._source = EventSource.USER + + # Create an orphaned observation (no matching action/tool call request will exist) + # Use a dictionary that mimics ModelResponse structure to satisfy Pydantic + mock_response = { + 'id': 'mock_response_id', + 'choices': [{'message': {'content': None, 'tool_calls': []}}], + 'created': 0, + 'model': '', + 'object': '', + 'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0}, + } + orphaned_obs = CmdOutputObservation( + command='orphan_cmd', + content='Orphaned output', + command_id=99, + exit_code=0, + ) + orphaned_obs.tool_call_metadata = ToolCallMetadata( + tool_call_id='orphan_call_id', + function_name='execute_bash', + model_response=mock_response, + total_calls_in_response=1, + ) + + # Initial events list: system, wrong user message, orphaned observation + events = [system_message, different_user_message, orphaned_obs] + + # Call the main process_events method + messages = conversation_memory.process_events( + condensed_history=events, + initial_user_action=initial_user_action, # Provide the *correct* initial action + max_message_chars=None, + vision_is_active=False, + ) + + # Assertions on the final messages list + assert len(messages) == 2 + # 1. System message should be first + assert messages[0].role == 'system' + assert messages[0].content[0].text == 'System' + + # 2. 
The different user message should be left at index 1 assert messages[1].role == 'user' - assert messages[1].content[0].text == 'Hello' - assert messages[2].role == 'assistant' - assert messages[2].content[0].text == 'Hi there' + assert messages[1].content[0].text == different_user_message.content + + # Implicitly assert that the orphaned_obs was filtered out by checking the length (2) def test_process_events_with_cmd_output_observation(conversation_memory): @@ -125,14 +294,17 @@ def test_process_events_with_cmd_output_observation(conversation_memory): ), ) + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_message, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 # System + initial user + result + result = messages[2] # The actual result is now at index 2 assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -148,14 +320,17 @@ def test_process_events_with_ipython_run_cell_observation(conversation_memory): content='IPython output\n![image](data:image/png;base64,ABC123)', ) + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_message, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 # System + initial user + result + result = messages[2] # The actual result is now at index 2 assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -172,14 +347,17 @@ def test_process_events_with_agent_delegate_observation(conversation_memory): content='Content', 
outputs={'content': 'Delegated agent output'} ) + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_message, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 # System + initial user + result + result = messages[2] # The actual result is now at index 2 assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -189,14 +367,17 @@ def test_process_events_with_agent_delegate_observation(conversation_memory): def test_process_events_with_error_observation(conversation_memory): obs = ErrorObservation('Error message') + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_message, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 # System + initial user + result + result = messages[2] # The actual result is now at index 2 assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -207,10 +388,13 @@ def test_process_events_with_error_observation(conversation_memory): def test_process_events_with_unknown_observation(conversation_memory): # Create a mock that inherits from Event but not Action or Observation obs = Mock(spec=Event) + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER with pytest.raises(ValueError, match='Unknown event type'): conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_message, max_message_chars=None, vision_is_active=False, ) @@ 
-226,14 +410,17 @@ def test_process_events_with_file_edit_observation(conversation_memory): impl_source=FileEditSource.LLM_BASED_EDIT, ) + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_message, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 # System + initial user + result + result = messages[2] # The actual result is now at index 2 assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -247,18 +434,21 @@ def test_process_events_with_file_read_observation(conversation_memory): impl_source=FileReadSource.DEFAULT, ) + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_message, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 # System + initial user + result + result = messages[2] # The actual result is now at index 2 assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) - assert result.content[0].text == 'File content' + assert result.content[0].text == '\n\nFile content' def test_process_events_with_browser_output_observation(conversation_memory): @@ -270,14 +460,17 @@ def test_process_events_with_browser_output_observation(conversation_memory): error=False, ) + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_message, max_message_chars=None, vision_is_active=False, ) - 
assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 # System + initial user + result + result = messages[2] # The actual result is now at index 2 assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -287,14 +480,17 @@ def test_process_events_with_browser_output_observation(conversation_memory): def test_process_events_with_user_reject_observation(conversation_memory): obs = UserRejectObservation('Action rejected') + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_message, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 # System + initial user + result + result = messages[2] # The actual result is now at index 2 assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -317,14 +513,17 @@ def test_process_events_with_empty_environment_info(conversation_memory): content='Retrieved environment info', ) + initial_user_message = MessageAction(content='Initial user message') + initial_user_message._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[empty_obs], + initial_user_action=initial_user_message, max_message_chars=None, vision_is_active=False, ) - # Should only contain no messages except system message - assert len(messages) == 1 + # Should only contain system message and initial user message + assert len(messages) == 2 # Verify that build_workspace_context was NOT called since all input values were empty conversation_memory.prompt_manager.build_workspace_context.assert_not_called() @@ -348,14 +547,20 @@ def test_process_events_with_function_calling_observation(conversation_memory): model_response=mock_response, 
total_calls_in_response=1, ) + # Define initial user action + initial_user_action = MessageAction(content='Initial user message') + initial_user_action._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_action, max_message_chars=None, vision_is_active=False, ) # No direct message when using function calling - assert len(messages) == 1 # should be no messages except system message + assert ( + len(messages) == 2 + ) # should be no messages except system message and initial user message def test_process_events_with_message_action_with_image(conversation_memory): @@ -365,14 +570,18 @@ def test_process_events_with_message_action_with_image(conversation_memory): ) action._source = EventSource.AGENT + # Define initial user action + initial_user_action = MessageAction(content='Initial user message') + initial_user_action._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[action], + initial_user_action=initial_user_action, max_message_chars=None, vision_is_active=True, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 + result = messages[2] assert result.role == 'assistant' assert len(result.content) == 2 assert isinstance(result.content[0], TextContent) @@ -385,14 +594,18 @@ def test_process_events_with_user_cmd_action(conversation_memory): action = CmdRunAction(command='ls -l') action._source = EventSource.USER + # Define initial user action + initial_user_action = MessageAction(content='Initial user message') + initial_user_action._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[action], + initial_user_action=initial_user_action, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 + result = messages[2] assert result.role == 'user' assert len(result.content) == 1 assert 
isinstance(result.content[0], TextContent) @@ -418,14 +631,18 @@ def test_process_events_with_agent_finish_action_with_tool_metadata( total_calls_in_response=1, ) + # Define initial user action + initial_user_action = MessageAction(content='Initial user message') + initial_user_action._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[action], + initial_user_action=initial_user_action, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 + result = messages[2] assert result.role == 'assistant' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -461,18 +678,22 @@ def test_process_events_with_environment_microagent_observation(conversation_mem content='Retrieved environment info', ) + # Define initial user action + initial_user_action = MessageAction(content='Initial user message') + initial_user_action._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_action, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 + result = messages[2] assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) - assert result.content[0].text == 'Formatted repository and runtime info' + assert result.content[0].text == '\n\nFormatted repository and runtime info' # Verify the prompt_manager was called with the correct parameters conversation_memory.prompt_manager.build_workspace_context.assert_called_once() @@ -516,14 +737,18 @@ def test_process_events_with_knowledge_microagent_microagent_observation( content='Retrieved knowledge from microagents', ) + # Define initial user action + initial_user_action = MessageAction(content='Initial user message') + initial_user_action._source = EventSource.USER messages = 
conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_action, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 2 - result = messages[1] + assert len(messages) == 3 # System + Initial User + Result + result = messages[2] # Result is now at index 2 assert result.role == 'user' assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) @@ -559,14 +784,18 @@ def test_process_events_with_microagent_observation_extensions_disabled( content='Retrieved environment info', ) + # Define initial user action + initial_user_action = MessageAction(content='Initial user message') + initial_user_action._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_action, max_message_chars=None, vision_is_active=False, ) # When prompt extensions are disabled, the RecallObservation should be ignored - assert len(messages) == 1 # should be no messages except system message + assert len(messages) == 2 # System + Initial User # Verify the prompt_manager was not called conversation_memory.prompt_manager.build_workspace_context.assert_not_called() @@ -581,14 +810,18 @@ def test_process_events_with_empty_microagent_knowledge(conversation_memory): content='Retrieved knowledge from microagents', ) + # Define initial user action + initial_user_action = MessageAction(content='Initial user message') + initial_user_action._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_action, max_message_chars=None, vision_is_active=False, ) # The implementation returns an empty string and it doesn't creates a message - assert len(messages) == 1 # should be no messages except system message + assert len(messages) == 2 # System + Initial User # When there are no triggered agents, build_microagent_info is not called 
conversation_memory.prompt_manager.build_microagent_info.assert_not_called() @@ -793,19 +1026,23 @@ def test_process_events_with_microagent_observation_deduplication(conversation_m content='Third retrieval', ) + # Define initial user action + initial_user_action = MessageAction(content='Initial user message') + initial_user_action._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs1, obs2, obs3], + initial_user_action=initial_user_action, max_message_chars=None, vision_is_active=False, ) # Verify that only the first occurrence of content for each agent is included - assert len(messages) == 2 # with system message - + assert len(messages) == 3 # System + Initial User + Result + # Result is now at index 2 # First microagent should include all agents since they appear here first - assert 'Image best practices v1' in messages[1].content[0].text - assert 'Git best practices v1' in messages[1].content[0].text - assert 'Python best practices v1' in messages[1].content[0].text + assert 'Image best practices v1' in messages[2].content[0].text + assert 'Git best practices v1' in messages[2].content[0].text + assert 'Python best practices v1' in messages[2].content[0].text def test_process_events_with_microagent_observation_deduplication_disabled_agents( @@ -842,18 +1079,22 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent content='Second retrieval', ) + # Define initial user action + initial_user_action = MessageAction(content='Initial user message') + initial_user_action._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs1, obs2], + initial_user_action=initial_user_action, max_message_chars=None, vision_is_active=False, ) # Verify that disabled agents are filtered out and only the first occurrence of enabled agents is included - assert len(messages) == 2 - + assert len(messages) == 3 # System + Initial User + Result + # Result is now at index 2 # First 
microagent should include enabled_agent but not disabled_agent - assert 'Disabled agent content' not in messages[1].content[0].text - assert 'Enabled agent content v1' in messages[1].content[0].text + assert 'Disabled agent content' not in messages[2].content[0].text + assert 'Enabled agent content v1' in messages[2].content[0].text def test_process_events_with_microagent_observation_deduplication_empty( @@ -866,17 +1107,22 @@ def test_process_events_with_microagent_observation_deduplication_empty( content='Empty retrieval', ) + # Define initial user action + initial_user_action = MessageAction(content='Initial user message') + initial_user_action._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[obs], + initial_user_action=initial_user_action, max_message_chars=None, vision_is_active=False, ) # Verify that empty RecallObservations are handled gracefully assert ( - len(messages) == 1 - ) # an empty microagent is not added to Messages, only system message is found + len(messages) == 2 # System + Initial User + ) # an empty microagent is not added to Messages assert messages[0].role == 'system' + assert messages[1].role == 'user' # Initial user message def test_has_agent_in_earlier_events(conversation_memory): @@ -1088,13 +1334,183 @@ def test_system_message_in_events(conversation_memory): system_message._source = EventSource.AGENT # Process events with the system message in condensed_history + # Define initial user action + initial_user_action = MessageAction(content='Initial user message') + initial_user_action._source = EventSource.USER messages = conversation_memory.process_events( condensed_history=[system_message], + initial_user_action=initial_user_action, max_message_chars=None, vision_is_active=False, ) # Check that the system message was processed correctly - assert len(messages) == 1 + assert len(messages) == 2 # System + Initial User assert messages[0].role == 'system' assert messages[0].content[0].text == 
'System message' + assert messages[1].role == 'user' # Initial user message + + +# Helper function to create mock tool call metadata +def _create_mock_tool_call_metadata( + tool_call_id: str, function_name: str, response_id: str = 'mock_response_id' +) -> ToolCallMetadata: + # Use a dictionary that mimics ModelResponse structure to satisfy Pydantic + mock_response = { + 'id': response_id, + 'choices': [ + { + 'message': { + 'role': 'assistant', + 'content': None, # Content is None for tool calls + 'tool_calls': [ + { + 'id': tool_call_id, + 'type': 'function', + 'function': { + 'name': function_name, + 'arguments': '{}', + }, # Args don't matter for this test + } + ], + } + } + ], + 'created': 0, + 'model': 'mock_model', + 'object': 'chat.completion', + 'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0}, + } + return ToolCallMetadata( + tool_call_id=tool_call_id, + function_name=function_name, + model_response=mock_response, + total_calls_in_response=1, + ) + + +def test_process_events_partial_history(conversation_memory): + """ + Tests process_events with full and partial histories to verify + _ensure_system_message, _ensure_initial_user_message, and tool call matching logic. 
+ """ + # --- Define Common Events --- + system_message = SystemMessageAction(content='System message') + system_message._source = EventSource.AGENT + + user_message = MessageAction( + content='Initial user query' + ) # This is the crucial initial_user_action + user_message._source = EventSource.USER + + recall_obs = RecallObservation( + recall_type=RecallType.WORKSPACE_CONTEXT, + repo_name='test-repo', + repo_directory='/path/to/repo', + content='Retrieved environment info', + ) + recall_obs._source = EventSource.AGENT + + cmd_action = CmdRunAction(command='ls', thought='Running ls') + cmd_action._source = EventSource.AGENT + cmd_action.tool_call_metadata = _create_mock_tool_call_metadata( + tool_call_id='call_ls_1', function_name='execute_bash', response_id='resp_ls_1' + ) + + cmd_obs = CmdOutputObservation( + command_id=1, command='ls', content='file1.txt\nfile2.py', exit_code=0 + ) + cmd_obs._source = EventSource.AGENT + cmd_obs.tool_call_metadata = _create_mock_tool_call_metadata( + tool_call_id='call_ls_1', function_name='execute_bash', response_id='resp_ls_1' + ) + + # --- Scenario 1: Full History --- + full_history: list[Event] = [ + system_message, + user_message, # Correct initial user message at index 1 + recall_obs, + cmd_action, + cmd_obs, + ] + messages_full = conversation_memory.process_events( + condensed_history=list(full_history), # Pass a copy + initial_user_action=user_message, # Provide the initial action + max_message_chars=None, + vision_is_active=False, + ) + + # Expected: System, User, Recall (formatted), Assistant (tool call), Tool Response + assert len(messages_full) == 5 + assert messages_full[0].role == 'system' + assert messages_full[0].content[0].text == 'System message' + assert messages_full[1].role == 'user' + assert messages_full[1].content[0].text == 'Initial user query' + assert messages_full[2].role == 'user' # Recall obs becomes user message + assert ( + 'Formatted repository and runtime info' in 
messages_full[2].content[0].text + ) # From fixture mock + assert messages_full[3].role == 'assistant' + assert messages_full[3].tool_calls is not None + assert len(messages_full[3].tool_calls) == 1 + assert messages_full[3].tool_calls[0].id == 'call_ls_1' + assert messages_full[4].role == 'tool' + assert messages_full[4].tool_call_id == 'call_ls_1' + assert 'file1.txt' in messages_full[4].content[0].text + + # --- Scenario 2: Partial History (Action + Observation) --- + # Simulates processing only the last action/observation pair + partial_history_action_obs: list[Event] = [ + cmd_action, + cmd_obs, + ] + messages_partial_action_obs = conversation_memory.process_events( + condensed_history=list(partial_history_action_obs), # Pass a copy + initial_user_action=user_message, # Provide the initial action + max_message_chars=None, + vision_is_active=False, + ) + + # Expected: System (added), Initial User (added), Assistant (tool call), Tool Response + assert len(messages_partial_action_obs) == 4 + assert ( + messages_partial_action_obs[0].role == 'system' + ) # Added by _ensure_system_message + assert messages_partial_action_obs[0].content[0].text == 'System message' + assert ( + messages_partial_action_obs[1].role == 'user' + ) # Added by _ensure_initial_user_message + assert messages_partial_action_obs[1].content[0].text == 'Initial user query' + assert messages_partial_action_obs[2].role == 'assistant' + assert messages_partial_action_obs[2].tool_calls is not None + assert len(messages_partial_action_obs[2].tool_calls) == 1 + assert messages_partial_action_obs[2].tool_calls[0].id == 'call_ls_1' + assert messages_partial_action_obs[3].role == 'tool' + assert messages_partial_action_obs[3].tool_call_id == 'call_ls_1' + assert 'file1.txt' in messages_partial_action_obs[3].content[0].text + + # --- Scenario 3: Partial History (Observation Only) --- + # Simulates processing only the last observation + partial_history_obs_only: list[Event] = [ + cmd_obs, + ] + 
messages_partial_obs_only = conversation_memory.process_events( + condensed_history=list(partial_history_obs_only), # Pass a copy + initial_user_action=user_message, # Provide the initial action + max_message_chars=None, + vision_is_active=False, + ) + + # Expected: System (added), Initial User (added). + # The CmdOutputObservation has tool_call_metadata, but there's no corresponding + # assistant message (from CmdRunAction) with the matching tool_call.id in the input history. + # Therefore, _filter_unmatched_tool_calls should remove the tool response message. + assert len(messages_partial_obs_only) == 2 + assert ( + messages_partial_obs_only[0].role == 'system' + ) # Added by _ensure_system_message + assert messages_partial_obs_only[0].content[0].text == 'System message' + assert ( + messages_partial_obs_only[1].role == 'user' + ) # Added by _ensure_initial_user_message + assert messages_partial_obs_only[1].content[0].text == 'Initial user query' diff --git a/tests/unit/test_prompt_caching.py b/tests/unit/test_prompt_caching.py index 2b323da0dc..963b590d3f 100644 --- a/tests/unit/test_prompt_caching.py +++ b/tests/unit/test_prompt_caching.py @@ -76,7 +76,7 @@ def test_get_messages(codeact_agent: CodeActAgent): history.append(message_action_5) codeact_agent.reset() - messages = codeact_agent._get_messages(history) + messages = codeact_agent._get_messages(history, message_action_1) assert ( len(messages) == 6 @@ -106,16 +106,19 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent): history.append(system_message_action) # Add multiple user and agent messages + initial_user_message = None # Keep track of the first user message for i in range(15): message_action_user = MessageAction(f'User message {i}') message_action_user._source = 'user' + if initial_user_message is None: + initial_user_message = message_action_user # Store the first one history.append(message_action_user) message_action_agent = MessageAction(f'Agent message {i}') 
message_action_agent._source = 'agent' history.append(message_action_agent) codeact_agent.reset() - messages = codeact_agent._get_messages(history) + messages = codeact_agent._get_messages(history, initial_user_message) # Check that only the last two user messages have cache_prompt=True cached_user_messages = [