Fix truncation, ensure first user message and log (#8103)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Engel Nyst
2025-04-28 22:43:41 +02:00
committed by GitHub
parent 998de564cd
commit 4b1ed30e97
9 changed files with 1324 additions and 240 deletions

View File

@@ -20,10 +20,7 @@ from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message
from openhands.events.action import (
Action,
AgentFinishAction,
)
from openhands.events.action import Action, AgentFinishAction, MessageAction
from openhands.events.event import Event
from openhands.llm.llm import LLM
from openhands.memory.condenser import Condenser
@@ -173,7 +170,8 @@ class CodeActAgent(Agent):
f'Processing {len(condensed_history)} events from a total of {len(state.history)} events'
)
messages = self._get_messages(condensed_history)
initial_user_message = self._get_initial_user_message(state.history)
messages = self._get_messages(condensed_history, initial_user_message)
params: dict = {
'messages': self.llm.format_messages_for_llm(messages),
}
@@ -216,7 +214,29 @@ class CodeActAgent(Agent):
self.pending_actions.append(action)
return self.pending_actions.popleft()
def _get_messages(self, events: list[Event]) -> list[Message]:
def _get_initial_user_message(self, history: list[Event]) -> MessageAction:
"""Finds the initial user message action from the full history."""
initial_user_message: MessageAction | None = None
for event in history:
if isinstance(event, MessageAction) and event.source == 'user':
initial_user_message = event
break
if initial_user_message is None:
# This should not happen in a valid conversation
logger.error(
f'CRITICAL: Could not find the initial user MessageAction in the full {len(history)} events history.'
)
# Depending on desired robustness, could raise error or create a dummy action
# and log the error
raise ValueError(
'Initial user message not found in history. Please report this issue.'
)
return initial_user_message
def _get_messages(
self, events: list[Event], initial_user_message: MessageAction
) -> list[Message]:
"""Constructs the message history for the LLM conversation.
This method builds a structured conversation history by processing events from the state
@@ -253,6 +273,7 @@ class CodeActAgent(Agent):
# Use ConversationMemory to process events (including SystemMessageAction)
messages = self.conversation_memory.process_events(
condensed_history=events,
initial_user_action=initial_user_message,
max_message_chars=self.llm.config.max_message_chars,
vision_is_active=self.llm.vision_is_active(),
)

View File

@@ -190,7 +190,7 @@ class AgentController:
logger.debug(f'System message got from agent: {system_message}')
if system_message:
self.event_stream.add_event(system_message, EventSource.AGENT)
logger.debug(f'System message added to event stream: {system_message}')
logger.info(f'System message added to event stream: {system_message}')
async def close(self, set_stop_state: bool = True) -> None:
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
@@ -1020,7 +1020,7 @@ class AgentController:
self.state.start_id = 0
self.log(
'debug',
'info',
f'AgentController {self.id} - created new state. start_id: {self.state.start_id}',
)
else:
@@ -1030,7 +1030,7 @@ class AgentController:
self.state.start_id = 0
self.log(
'debug',
'info',
f'AgentController {self.id} initializing history from event {self.state.start_id}',
)
@@ -1143,70 +1143,169 @@ class AgentController:
def _handle_long_context_error(self) -> None:
# When context window is exceeded, keep roughly half of agent interactions
kept_event_ids = {
e.id for e in self._apply_conversation_window(self.state.history)
}
kept_events = self._apply_conversation_window()
kept_event_ids = {e.id for e in kept_events}
self.log(
'info',
f'Context window exceeded. Keeping events with IDs: {kept_event_ids}',
)
# The events to forget are those that are not in the kept set
forgotten_event_ids = {e.id for e in self.state.history} - kept_event_ids
# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id
if len(kept_event_ids) == 0:
self.log(
'warning',
'No events kept after applying conversation window. This should not happen.',
)
# verify that the first event id in kept_event_ids is the same as the start_id
if len(kept_event_ids) > 0 and self.state.history[0].id not in kept_event_ids:
self.log(
'warning',
f'First event after applying conversation window was not kept: {self.state.history[0].id} not in {kept_event_ids}',
)
# Add an error event to trigger another step by the agent
self.event_stream.add_event(
CondensationAction(
forgotten_events_start_id=min(forgotten_event_ids),
forgotten_events_end_id=max(forgotten_event_ids),
forgotten_events_start_id=min(forgotten_event_ids)
if forgotten_event_ids
else 0,
forgotten_events_end_id=max(forgotten_event_ids)
if forgotten_event_ids
else 0,
),
EventSource.AGENT,
)
def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
def _apply_conversation_window(self) -> list[Event]:
"""Cuts history roughly in half when context window is exceeded.
It preserves action-observation pairs and ensures that the first user message is always included.
It preserves action-observation pairs and ensures that the system message,
the first user message, and its associated recall observation are always included
at the beginning of the context window.
The algorithm:
1. Cut history in half
2. Check first event in new history:
- If Observation: find and include its Action
- If MessageAction: ensure its related Action-Observation pair isn't split
3. Always include the first user message
1. Identify essential initial events: System Message, First User Message, Recall Observation.
2. Determine the slice of recent events to potentially keep.
3. Validate the start of the recent slice for dangling observations.
4. Combine essential events and validated recent events, ensuring essentials come first.
Args:
events: List of events to filter
Returns:
Filtered list of events keeping newest half while preserving pairs
Filtered list of events keeping newest half while preserving pairs and essential initial events.
"""
if not events:
return events
if not self.state.history:
return []
# Find first user message - we'll need to ensure it's included
first_user_msg = next(
(
e
for e in events
if isinstance(e, MessageAction) and e.source == EventSource.USER
),
None,
history = self.state.history
# 1. Identify essential initial events
system_message: SystemMessageAction | None = None
first_user_msg: MessageAction | None = None
recall_action: RecallAction | None = None
recall_observation: Observation | None = None
# Find System Message (should be the first event, if it exists)
system_message = next(
(e for e in history if isinstance(e, SystemMessageAction)), None
)
assert (
system_message is None
or isinstance(system_message, SystemMessageAction)
and system_message.id == history[0].id
)
# cut in half
mid_point = max(1, len(events) // 2)
kept_events = events[mid_point:]
if len(kept_events) > 0 and isinstance(kept_events[0], Observation):
kept_events = kept_events[1:]
# Find First User Message, which MUST exist
first_user_msg = self._first_user_message()
if first_user_msg is None:
raise RuntimeError('No first user message found in the event stream.')
# Ensure first user message is included
if first_user_msg and first_user_msg not in kept_events:
kept_events = [first_user_msg] + kept_events
first_user_msg_index = -1
for i, event in enumerate(history):
if isinstance(event, MessageAction) and event.source == EventSource.USER:
first_user_msg = event
first_user_msg_index = i
break
# start_id points to first user message
# Find Recall Action and Observation related to the First User Message
if first_user_msg is not None and first_user_msg_index != -1:
# Look for RecallAction after the first user message
for i in range(first_user_msg_index + 1, len(history)):
event = history[i]
if (
isinstance(event, RecallAction)
and event.query == first_user_msg.content
):
# Found RecallAction, now look for its Observation
recall_action = event
for j in range(i + 1, len(history)):
obs_event = history[j]
# Check for Observation caused by this RecallAction
if (
isinstance(obs_event, Observation)
and obs_event.cause == recall_action.id
):
recall_observation = obs_event
break # Found the observation, stop inner loop
break # Found the recall action (and maybe obs), stop outer loop
essential_events: list[Event] = []
if system_message:
essential_events.append(system_message)
if first_user_msg:
self.state.start_id = first_user_msg.id
essential_events.append(first_user_msg)
# Also keep the RecallAction that triggered the essential RecallObservation
if recall_action:
essential_events.append(recall_action)
if recall_observation:
essential_events.append(recall_observation)
return kept_events
# 2. Determine the slice of recent events to potentially keep
num_non_essential_events = len(history) - len(essential_events)
# Keep roughly half of the non-essential events, minimum 1
num_recent_to_keep = max(1, num_non_essential_events // 2)
# Calculate the starting index for the recent slice
slice_start_index = len(history) - num_recent_to_keep
slice_start_index = max(0, slice_start_index) # Ensure index is not negative
recent_events_slice = history[slice_start_index:]
# 3. Validate the start of the recent slice for dangling observations
# IMPORTANT: Most observations in history are tool call results, which cannot be without their action, or we get an LLM API error
first_valid_event_index = 0
for i, event in enumerate(recent_events_slice):
if isinstance(event, Observation):
first_valid_event_index += 1
else:
break
# If all events in the slice are dangling observations, we need to keep at least one
if first_valid_event_index == len(recent_events_slice):
self.log(
'warning',
'All recent events are dangling observations, which we truncate. This means the agent has only the essential first events. This should not happen.',
)
# Adjust the recent_events_slice if dangling observations were found at the start
if first_valid_event_index < len(recent_events_slice):
validated_recent_events = recent_events_slice[first_valid_event_index:]
if first_valid_event_index > 0:
self.log(
'debug',
f'Removed {first_valid_event_index} dangling observation(s) from the start of recent event slice.',
)
else:
validated_recent_events = []
# 4. Combine essential events and validated recent events
events_to_keep: list[Event] = essential_events + validated_recent_events
self.log('debug', f'History truncated. Kept {len(events_to_keep)} events.')
return events_to_keep
def _is_stuck(self) -> bool:
"""Checks if the agent or its delegate is stuck in a loop.

View File

@@ -54,6 +54,7 @@ class ConversationMemory:
def process_events(
self,
condensed_history: list[Event],
initial_user_action: MessageAction,
max_message_chars: int | None = None,
vision_is_active: bool = False,
) -> list[Message]:
@@ -66,12 +67,14 @@ class ConversationMemory:
max_message_chars: The maximum number of characters in the content of an event included
in the prompt to the LLM. Larger observations are truncated.
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
initial_user_action: The initial user message action, if available. Used to ensure the conversation starts correctly.
"""
events = condensed_history
# Ensure the system message exists (handles legacy cases)
# Ensure the event list starts with SystemMessageAction, then MessageAction(source='user')
self._ensure_system_message(events)
self._ensure_initial_user_message(events, initial_user_action)
# log visual browsing status
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
@@ -699,6 +702,43 @@ class ConversationMemory:
system_message = SystemMessageAction(content=system_prompt)
# Insert the system message directly at the beginning of the events list
events.insert(0, system_message)
logger.debug(
logger.info(
'[ConversationMemory] Added SystemMessageAction for backward compatibility'
)
def _ensure_initial_user_message(
self, events: list[Event], initial_user_action: MessageAction
) -> None:
"""Checks if the second event is a user MessageAction and inserts the provided one if needed."""
if (
not events
): # Should have system message from previous step, but safety check
logger.error('Cannot ensure initial user message: event list is empty.')
# Or raise? Let's log for now, _ensure_system_message should handle this.
return
# We expect events[0] to be SystemMessageAction after _ensure_system_message
if len(events) == 1:
# Only system message exists
logger.info(
'Initial user message action was missing. Inserting the initial user message.'
)
events.insert(1, initial_user_action)
elif not isinstance(events[1], MessageAction) or events[1].source != 'user':
# The second event exists but is not the correct initial user message action.
# We will insert the correct one provided.
logger.info(
'Second event was not the initial user message action. Inserting correct one at index 1.'
)
# Insert the user message event at index 1. This will be the second message as LLM APIs expect
# but something was wrong with the history, so log all we can.
events.insert(1, initial_user_action)
# Else: events[1] is already a user MessageAction.
# Check if it matches the one provided (if any discrepancy, log warning but proceed).
elif events[1] != initial_user_action:
logger.debug(
'The user MessageAction at index 1 does not match the provided initial_user_action. '
'Proceeding with the one found in condensed history.'
)

View File

@@ -4,6 +4,7 @@ from typing import overload
from pydantic import BaseModel
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.agent import CondensationAction
from openhands.events.event import Event
from openhands.events.observation.agent import AgentCondensationObservation
@@ -65,6 +66,8 @@ class View(BaseModel):
break
if summary is not None and summary_offset is not None:
logger.info(f'Inserting summary at offset {summary_offset}')
kept_events.insert(
summary_offset, AgentCondensationObservation(content=summary)
)

View File

@@ -22,7 +22,6 @@ from openhands.events.observation import (
ErrorObservation,
)
from openhands.events.observation.agent import RecallObservation
from openhands.events.observation.commands import CmdOutputObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
@@ -765,7 +764,7 @@ async def test_context_window_exceeded_error_handling(
# We do that by playing the role of the recall module -- subscribe to the
# event stream and respond to recall actions by inserting fake recall
# obesrvations.
# observations.
def on_event_memory(event: Event):
if isinstance(event, RecallAction):
microagent_obs = RecallObservation(
@@ -807,13 +806,19 @@ async def test_context_window_exceeded_error_handling(
# size (because we return a message action, which triggers a recall, which
# triggers a recall response). But if the pre/post-views are on the turn
# when we throw the context window exceeded error, we should see the
# post-step view compressed.
# post-step view compressed (or rather, a CondensationAction added).
for index, (first_view, second_view) in enumerate(
zip(step_state.views[:-1], step_state.views[1:])
):
if index == error_after:
assert len(first_view) > len(second_view)
# Verify that the CondensationAction is present in the second view (after error)
# but not in the first view (before error)
assert not any(isinstance(e, CondensationAction) for e in first_view.events)
assert any(isinstance(e, CondensationAction) for e in second_view.events)
# The length might not strictly decrease due to CondensationAction being added
assert len(first_view) == len(second_view)
else:
# Before the error, the view length should increase
assert len(first_view) < len(second_view)
# The final state's history should contain:
@@ -886,7 +891,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
def step(self, state: State):
# If the state has more than one message and we haven't errored yet,
# throw the context window exceeded error
if len(state.history) > 3 and not self.has_errored:
if len(state.history) > 5 and not self.has_errored:
error = ContextWindowExceededError(
message='prompt is too long: 233885 tokens > 200000 maximum',
model='',
@@ -1467,126 +1472,6 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
), 'should_step should return False for NullObservation with cause = 0'
def test_apply_conversation_window_basic(mock_event_stream, mock_agent):
"""Test that the _apply_conversation_window method correctly prunes a list of events."""
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_apply_conversation_window_basic',
confirmation_mode=False,
headless_mode=True,
)
# Create a sequence of events with IDs
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1
# Add agent question
agent_msg = MessageAction(
content='What task would you like me to perform?', wait_for_response=True
)
agent_msg._source = EventSource.AGENT
agent_msg._id = 2
# Add user response
user_response = MessageAction(
content='Please list all files and show me current directory',
wait_for_response=False,
)
user_response._source = EventSource.USER
user_response._id = 3
cmd1 = CmdRunAction(command='ls')
cmd1._id = 4
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
obs1._id = 5
obs1._cause = 4
cmd2 = CmdRunAction(command='pwd')
cmd2._id = 6
obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=6)
obs2._id = 7
obs2._cause = 6
events = [first_msg, agent_msg, user_response, cmd1, obs1, cmd2, obs2]
# Apply truncation
truncated = controller._apply_conversation_window(events)
# Verify truncation occured
# Should keep first user message and roughly half of other events
assert (
3 <= len(truncated) < len(events)
) # First message + at least one action-observation pair
assert truncated[0] == first_msg # First message always preserved
assert controller.state.start_id == first_msg._id
# Verify pairs aren't split
for i, event in enumerate(truncated[1:]):
if isinstance(event, CmdOutputObservation):
assert any(e._id == event._cause for e in truncated[: i + 1])
def test_history_restoration_after_truncation(mock_event_stream, mock_agent):
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)
# Create events with IDs
first_msg = MessageAction(content='Start task', wait_for_response=False)
first_msg._source = EventSource.USER
first_msg._id = 1
events = [first_msg]
for i in range(5):
cmd = CmdRunAction(command=f'cmd{i}')
cmd._id = i + 2
obs = CmdOutputObservation(
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
)
obs._cause = cmd._id
events.extend([cmd, obs])
# Set up initial history
controller.state.history = events.copy()
# Force truncation
controller.state.history = controller._apply_conversation_window(
controller.state.history
)
# Save state
saved_start_id = controller.state.start_id
saved_history_len = len(controller.state.history)
# Set up mock event stream for new controller
mock_event_stream.get_events.return_value = controller.state.history
# Create new controller with saved state
new_controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_truncation',
confirmation_mode=False,
headless_mode=True,
)
new_controller.state.start_id = saved_start_id
new_controller.state.history = mock_event_stream.get_events()
# Verify restoration
assert len(new_controller.state.history) == saved_history_len
assert new_controller.state.history[0] == first_msg
assert new_controller.state.start_id == saved_start_id
def test_system_message_in_event_stream(mock_agent, test_event_stream):
"""Test that SystemMessageAction is added to event stream in AgentController."""
_ = AgentController(

View File

@@ -0,0 +1,569 @@
from unittest.mock import MagicMock, patch
import pytest
from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import AppConfig
from openhands.events import EventSource
from openhands.events.action import CmdRunAction, MessageAction, RecallAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import RecallType
from openhands.events.observation import (
CmdOutputObservation,
Observation,
RecallObservation,
)
from openhands.events.stream import EventStream
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.storage.memory import InMemoryFileStore
# Helper function to create events with sequential IDs and causes
def create_events(event_data):
events = []
# Import necessary types here to avoid repeated imports inside the loop
from openhands.events.action import CmdRunAction, RecallAction
from openhands.events.observation import CmdOutputObservation, RecallObservation
for i, data in enumerate(event_data):
event_type = data['type']
source = data.get('source', EventSource.AGENT)
kwargs = {} # Arguments for the event constructor
# Determine arguments based on event type
if event_type == RecallAction:
kwargs['query'] = data.get('query', '')
kwargs['recall_type'] = data.get('recall_type', RecallType.KNOWLEDGE)
elif event_type == RecallObservation:
kwargs['content'] = data.get('content', '')
kwargs['recall_type'] = data.get('recall_type', RecallType.KNOWLEDGE)
elif event_type == CmdRunAction:
kwargs['command'] = data.get('command', '')
elif event_type == CmdOutputObservation:
# Required args for CmdOutputObservation
kwargs['content'] = data.get('content', '')
kwargs['command'] = data.get('command', '')
# Pass command_id via kwargs if present in data
if 'command_id' in data:
kwargs['command_id'] = data['command_id']
# Pass metadata if present
if 'metadata' in data:
kwargs['metadata'] = data['metadata']
else: # Default for MessageAction, SystemMessageAction, etc.
kwargs['content'] = data.get('content', '')
# Instantiate the event
event = event_type(**kwargs)
# Assign internal attributes AFTER instantiation
event._id = i + 1 # Assign sequential IDs starting from 1
event._source = source
# Assign _cause using cause_id from data, AFTER event._id is set
if 'cause_id' in data:
event._cause = data['cause_id']
# If command_id was NOT passed via kwargs but cause_id exists,
# pass cause_id as command_id to __init__ via kwargs for legacy handling
# This needs to happen *before* instantiation if we want __init__ to handle it
# Let's adjust the logic slightly:
if event_type == CmdOutputObservation:
if 'command_id' not in kwargs and 'cause_id' in data:
kwargs['command_id'] = data['cause_id'] # Let __init__ handle this
# Re-instantiate if we added command_id
if 'command_id' in kwargs and event.command_id != kwargs['command_id']:
event = event_type(**kwargs)
event._id = i + 1
event._source = source
# Now assign _cause if it exists in data, after potential re-instantiation
if 'cause_id' in data:
event._cause = data['cause_id']
events.append(event)
return events
@pytest.fixture
def controller_fixture():
mock_agent = MagicMock(spec=Agent)
mock_agent.llm = MagicMock(spec=LLM)
mock_agent.llm.metrics = Metrics()
mock_agent.llm.config = AppConfig().get_llm_config()
mock_agent.config = AppConfig().get_agent_config('CodeActAgent')
mock_event_stream = MagicMock(spec=EventStream)
mock_event_stream.sid = 'test_sid'
mock_event_stream.file_store = InMemoryFileStore({})
# Ensure get_latest_event_id returns an integer
mock_event_stream.get_latest_event_id.return_value = -1
controller = AgentController(
agent=mock_agent,
event_stream=mock_event_stream,
max_iterations=10,
sid='test_sid',
)
controller.state = State(session_id='test_sid')
# Mock _first_user_message directly on the instance
mock_first_user_message = MagicMock(spec=MessageAction)
controller._first_user_message = MagicMock(return_value=mock_first_user_message)
return controller, mock_first_user_message
# =============================================
# Test Cases for _apply_conversation_window
# =============================================
def test_basic_truncation(controller_fixture):
controller, mock_first_user_message = controller_fixture
controller.state.history = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2
{'type': RecallAction, 'query': 'User Task 1'}, # 3
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
{'type': CmdRunAction, 'command': 'ls'}, # 5
{
'type': CmdOutputObservation,
'content': 'file1',
'command': 'ls',
'cause_id': 5,
}, # 6
{'type': CmdRunAction, 'command': 'pwd'}, # 7
{
'type': CmdOutputObservation,
'content': '/dir',
'command': 'pwd',
'cause_id': 7,
}, # 8
{'type': CmdRunAction, 'command': 'cat file1'}, # 9
{
'type': CmdOutputObservation,
'content': 'content',
'command': 'cat file1',
'cause_id': 9,
}, # 10
]
)
mock_first_user_message.id = 2 # Set the ID of the mocked first user message
# Calculation (RecallAction now essential):
# History len = 10
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
# Non-essential count = 10 - 4 = 6
# num_recent_to_keep = max(1, 6 // 2) = 3
# slice_start_index = 10 - 3 = 7
# recent_events_slice = history[7:] = [obs2(8), cmd3(9), obs3(10)]
# Validation: remove leading obs2(8). validated_slice = [cmd3(9), obs3(10)]
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4), cmd3(9), obs3(10)]
# Expected IDs: [1, 2, 3, 4, 9, 10]. Length 6.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 6
expected_ids = [1, 2, 3, 4, 9, 10]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
# Check no dangling observations at the start of the recent slice part
# The first event of the validated slice is cmd3(9)
assert not isinstance(truncated_events[4], Observation) # Index adjusted
def test_no_system_message(controller_fixture):
controller, mock_first_user_message = controller_fixture
controller.state.history = create_events(
[
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 1
{'type': RecallAction, 'query': 'User Task 1'}, # 2
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 2}, # 3
{'type': CmdRunAction, 'command': 'ls'}, # 4
{
'type': CmdOutputObservation,
'content': 'file1',
'command': 'ls',
'cause_id': 4,
}, # 5
{'type': CmdRunAction, 'command': 'pwd'}, # 6
{
'type': CmdOutputObservation,
'content': '/dir',
'command': 'pwd',
'cause_id': 6,
}, # 7
{'type': CmdRunAction, 'command': 'cat file1'}, # 8
{
'type': CmdOutputObservation,
'content': 'content',
'command': 'cat file1',
'cause_id': 8,
}, # 9
]
)
mock_first_user_message.id = 1
# Calculation (RecallAction now essential):
# History len = 9
# Essentials = [user(1), recall_act(2), recall_obs(3)] (len=3)
# Non-essential count = 9 - 3 = 6
# num_recent_to_keep = max(1, 6 // 2) = 3
# slice_start_index = 9 - 3 = 6
# recent_events_slice = history[6:] = [obs2(7), cmd3(8), obs3(9)]
# Validation: remove leading obs2(7). validated_slice = [cmd3(8), obs3(9)]
# Final = essentials + validated_slice = [user(1), recall_act(2), recall_obs(3), cmd3(8), obs3(9)]
# Expected IDs: [1, 2, 3, 8, 9]. Length 5.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 5
expected_ids = [1, 2, 3, 8, 9]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
def test_no_recall_observation(controller_fixture):
controller, mock_first_user_message = controller_fixture
controller.state.history = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2
{'type': RecallAction, 'query': 'User Task 1'}, # 3 (Recall Action exists)
# Recall Observation is missing
{'type': CmdRunAction, 'command': 'ls'}, # 4
{
'type': CmdOutputObservation,
'content': 'file1',
'command': 'ls',
'cause_id': 4,
}, # 5
{'type': CmdRunAction, 'command': 'pwd'}, # 6
{
'type': CmdOutputObservation,
'content': '/dir',
'command': 'pwd',
'cause_id': 6,
}, # 7
{'type': CmdRunAction, 'command': 'cat file1'}, # 8
{
'type': CmdOutputObservation,
'content': 'content',
'command': 'cat file1',
'cause_id': 8,
}, # 9
]
)
mock_first_user_message.id = 2
# Calculation (RecallAction essential only if RecallObs exists):
# History len = 9
# Essentials = [sys(1), user(2)] (len=2) - RecallObs missing, so RecallAction not essential here
# Non-essential count = 9 - 2 = 7
# num_recent_to_keep = max(1, 7 // 2) = 3
# slice_start_index = 9 - 3 = 6
# recent_events_slice = history[6:] = [obs2(7), cmd3(8), obs3(9)]
# Validation: remove leading obs2(7). validated_slice = [cmd3(8), obs3(9)]
# Final = essentials + validated_slice = [sys(1), user(2), recall_action(3), cmd_cat(8), obs_cat(9)]
# Expected IDs: [1, 2, 3, 8, 9]. Length 5.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 5
expected_ids = [1, 2, 3, 8, 9]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
def test_short_history_no_truncation(controller_fixture):
controller, mock_first_user_message = controller_fixture
history = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2
{'type': RecallAction, 'query': 'User Task 1'}, # 3
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
{'type': CmdRunAction, 'command': 'ls'}, # 5
{
'type': CmdOutputObservation,
'content': 'file1',
'command': 'ls',
'cause_id': 5,
}, # 6
]
)
controller.state.history = history
mock_first_user_message.id = 2
# Calculation (RecallAction now essential):
# History len = 6
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
# Non-essential count = 6 - 4 = 2
# num_recent_to_keep = max(1, 2 // 2) = 1
# slice_start_index = 6 - 1 = 5
# recent_events_slice = history[5:] = [obs1(6)]
# Validation: remove leading obs1(6). validated_slice = []
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)]
# Expected IDs: [1, 2, 3, 4]. Length 4.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 4
expected_ids = [1, 2, 3, 4]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
def test_only_essential_events(controller_fixture):
controller, mock_first_user_message = controller_fixture
history = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2
{'type': RecallAction, 'query': 'User Task 1'}, # 3
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
]
)
controller.state.history = history
mock_first_user_message.id = 2
# Calculation (RecallAction now essential):
# History len = 4
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
# Non-essential count = 4 - 4 = 0
# num_recent_to_keep = max(1, 0 // 2) = 1
# slice_start_index = 4 - 1 = 3
# recent_events_slice = history[3:] = [recall_obs(4)]
# Validation: remove leading recall_obs(4). validated_slice = []
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)]
# Expected IDs: [1, 2, 3, 4]. Length 4.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 4
expected_ids = [1, 2, 3, 4]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
def test_dangling_observations_at_cut_point(controller_fixture):
controller, mock_first_user_message = controller_fixture
history_forced_dangle = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2
{'type': RecallAction, 'query': 'User Task 1'}, # 3
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
# --- Slice calculation should start here ---
{
'type': CmdOutputObservation,
'content': 'dangle1',
'command': 'cmd_unknown',
}, # 5 (Dangling)
{
'type': CmdOutputObservation,
'content': 'dangle2',
'command': 'cmd_unknown',
}, # 6 (Dangling)
{'type': CmdRunAction, 'command': 'cmd1'}, # 7
{
'type': CmdOutputObservation,
'content': 'obs1',
'command': 'cmd1',
'cause_id': 7,
}, # 8
{'type': CmdRunAction, 'command': 'cmd2'}, # 9
{
'type': CmdOutputObservation,
'content': 'obs2',
'command': 'cmd2',
'cause_id': 9,
}, # 10
]
) # 10 events total
controller.state.history = history_forced_dangle
mock_first_user_message.id = 2
# Calculation (RecallAction now essential):
# History len = 10
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
# Non-essential count = 10 - 4 = 6
# num_recent_to_keep = max(1, 6 // 2) = 3
# slice_start_index = 10 - 3 = 7
# recent_events_slice = history[7:] = [obs1(8), cmd2(9), obs2(10)]
# Validation: remove leading obs1(8). validated_slice = [cmd2(9), obs2(10)]
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4), cmd2(9), obs2(10)]
# Expected IDs: [1, 2, 3, 4, 9, 10]. Length 6.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 6
expected_ids = [1, 2, 3, 4, 9, 10]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
# Verify dangling observations 5 and 6 were removed (implicitly by slice start and validation)
def test_only_dangling_observations_in_recent_slice(controller_fixture):
controller, mock_first_user_message = controller_fixture
history = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2
{'type': RecallAction, 'query': 'User Task 1'}, # 3
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
# --- Slice calculation should start here ---
{
'type': CmdOutputObservation,
'content': 'dangle1',
'command': 'cmd_unknown',
}, # 5 (Dangling)
{
'type': CmdOutputObservation,
'content': 'dangle2',
'command': 'cmd_unknown',
}, # 6 (Dangling)
]
) # 6 events total
controller.state.history = history
mock_first_user_message.id = 2
# Calculation (RecallAction now essential):
# History len = 6
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
# Non-essential count = 6 - 4 = 2
# num_recent_to_keep = max(1, 2 // 2) = 1
# slice_start_index = 6 - 1 = 5
# recent_events_slice = history[5:] = [dangle2(6)]
# Validation: remove leading dangle2(6). validated_slice = [] (Corrected based on user feedback/bugfix)
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)]
# Expected IDs: [1, 2, 3, 4]. Length 4.
with patch(
'openhands.controller.agent_controller.logger.warning'
) as mock_log_warning:
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 4
expected_ids = [1, 2, 3, 4]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
# Verify dangling observations 5 and 6 were removed
# Check that the specific warning was logged exactly once
assert mock_log_warning.call_count == 1
# Check the essential parts of the arguments, allowing for variations like stacklevel
call_args, call_kwargs = mock_log_warning.call_args
expected_message_substring = 'All recent events are dangling observations, which we truncate. This means the agent has only the essential first events. This should not happen.'
assert expected_message_substring in call_args[0]
assert 'extra' in call_kwargs
assert call_kwargs['extra'].get('session_id') == 'test_sid'
def test_empty_history(controller_fixture):
controller, _ = controller_fixture
controller.state.history = []
truncated_events = controller._apply_conversation_window()
assert truncated_events == []
def test_multiple_user_messages(controller_fixture):
controller, mock_first_user_message = controller_fixture
history = create_events(
[
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
{
'type': MessageAction,
'content': 'User Task 1',
'source': EventSource.USER,
}, # 2 (First)
{'type': RecallAction, 'query': 'User Task 1'}, # 3
{
'type': RecallObservation,
'content': 'Recall result 1',
'cause_id': 3,
}, # 4
{'type': CmdRunAction, 'command': 'cmd1'}, # 5
{
'type': CmdOutputObservation,
'content': 'obs1',
'command': 'cmd1',
'cause_id': 5,
}, # 6
{
'type': MessageAction,
'content': 'User Task 2',
'source': EventSource.USER,
}, # 7 (Second)
{'type': RecallAction, 'query': 'User Task 2'}, # 8
{
'type': RecallObservation,
'content': 'Recall result 2',
'cause_id': 8,
}, # 9
{'type': CmdRunAction, 'command': 'cmd2'}, # 10
{
'type': CmdOutputObservation,
'content': 'obs2',
'command': 'cmd2',
'cause_id': 10,
}, # 11
]
) # 11 events total
controller.state.history = history
mock_first_user_message.id = 2 # Explicitly set the first user message ID
# Calculation (RecallAction now essential):
# History len = 11
# Essentials = [sys(1), user1(2), recall_act1(3), recall_obs1(4)] (len=4)
# Non-essential count = 11 - 4 = 7
# num_recent_to_keep = max(1, 7 // 2) = 3
# slice_start_index = 11 - 3 = 8
# recent_events_slice = history[8:] = [recall_obs2(9), cmd2(10), obs2(11)]
# Validation: remove leading recall_obs2(9). validated_slice = [cmd2(10), obs2(11)]
# Final = essentials + validated_slice = [sys(1), user1(2), recall_act1(3), recall_obs1(4)] + [cmd2(10), obs2(11)]
# Expected IDs: [1, 2, 3, 4, 10, 11]. Length 6.
truncated_events = controller._apply_conversation_window()
assert len(truncated_events) == 6
expected_ids = [1, 2, 3, 4, 10, 11]
actual_ids = [e.id for e in truncated_events]
assert actual_ids == expected_ids
# Verify the second user message (ID 7) was NOT kept
assert not any(event.id == 7 for event in truncated_events)
# Verify the first user message (ID 2) is present
assert any(event.id == 2 for event in truncated_events)

View File

@@ -44,6 +44,7 @@ from openhands.events.observation.commands import (
)
from openhands.events.tool import ToolCallMetadata
from openhands.llm.llm import LLM
from openhands.memory.condenser import View
@pytest.fixture(params=['CodeActAgent', 'ReadOnlyAgent'])
@@ -97,6 +98,12 @@ def test_reset(agent):
action._source = EventSource.AGENT
agent.pending_actions.append(action)
# Create a mock state with initial user message
mock_state = Mock(spec=State)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
mock_state.history = [initial_user_message]
# Reset
agent.reset()
@@ -110,8 +117,14 @@ def test_step_with_pending_actions(agent):
pending_action._source = EventSource.AGENT
agent.pending_actions.append(pending_action)
# Create a mock state with initial user message
mock_state = Mock(spec=State)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
mock_state.history = [initial_user_message]
# Step should return the pending action
result = agent.step(Mock())
result = agent.step(mock_state)
assert result == pending_action
assert len(agent.pending_actions) == 0
@@ -260,6 +273,11 @@ def test_step_with_no_pending_actions(mock_state: State):
mock_state.latest_user_message_llm_metrics = None
mock_state.latest_user_message_tool_call_metadata = None
# Add initial user message to history
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
mock_state.history = [initial_user_message]
action = agent.step(mock_state)
assert isinstance(action, MessageAction)
assert action.content == 'Task completed'
@@ -330,42 +348,56 @@ def test_mismatched_tool_call_events_and_auto_add_system_message(
)
action = CmdRunAction('foo')
action._source = 'agent'
action._source = EventSource.AGENT
action.tool_call_metadata = tool_call_metadata
observation = CmdOutputObservation(content='', command_id=0, command='foo')
observation.tool_call_metadata = tool_call_metadata
# Add initial user message
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
# When both events are provided, the agent should get three messages:
# 1. The system message (added automatically for backward compatibility)
# 2. The action message
# 3. The observation message
mock_state.history = [action, observation]
messages = agent._get_messages(mock_state.history)
assert len(messages) == 3
mock_state.history = [initial_user_message, action, observation]
messages = agent._get_messages(mock_state.history, initial_user_message)
assert len(messages) == 4 # System + initial user + action + observation
assert messages[0].role == 'system' # First message should be the system message
assert messages[1].role == 'assistant' # Second message should be the action
assert messages[2].role == 'tool' # Third message should be the observation
assert (
messages[1].role == 'user'
) # Second message should be the initial user message
assert messages[2].role == 'assistant' # Third message should be the action
assert messages[3].role == 'tool' # Fourth message should be the observation
# The same should hold if the events are presented out-of-order
mock_state.history = [observation, action]
messages = agent._get_messages(mock_state.history)
assert len(messages) == 3
mock_state.history = [initial_user_message, observation, action]
messages = agent._get_messages(mock_state.history, initial_user_message)
assert len(messages) == 4
assert messages[0].role == 'system' # First message should be the system message
assert (
messages[1].role == 'user'
) # Second message should be the initial user message
# If only one of the two events is present, then we should just get the system message
# plus any valid message from the event
mock_state.history = [action]
messages = agent._get_messages(mock_state.history)
mock_state.history = [initial_user_message, action]
messages = agent._get_messages(mock_state.history, initial_user_message)
assert (
len(messages) == 1
) # Only system message, action is waiting for its observation
len(messages) == 2
) # System + initial user message, action is waiting for its observation
assert messages[0].role == 'system'
assert messages[1].role == 'user'
mock_state.history = [observation]
messages = agent._get_messages(mock_state.history)
assert len(messages) == 1 # Only system message, observation has no matching action
mock_state.history = [initial_user_message, observation]
messages = agent._get_messages(mock_state.history, initial_user_message)
assert (
len(messages) == 2
) # System + initial user message, observation has no matching action
assert messages[0].role == 'system'
assert messages[1].role == 'user'
def test_grep_tool():
@@ -470,3 +502,19 @@ def test_get_system_message():
assert len(result.tools) > 0
assert any(tool['function']['name'] == 'execute_bash' for tool in result.tools)
assert result._source == EventSource.AGENT
def test_step_raises_error_if_no_initial_user_message(
agent: CodeActAgent, mock_state: State
):
"""Tests that step raises ValueError if the initial user message is not found."""
# Ensure history does NOT contain a user MessageAction
assistant_message = MessageAction(content='Assistant message')
assistant_message._source = EventSource.AGENT
mock_state.history = [assistant_message]
# Mock the condenser to return the history as is
agent.condenser = Mock()
agent.condenser.condensed_history.return_value = View(events=mock_state.history)
with pytest.raises(ValueError, match='Initial user message not found'):
agent.step(mock_state)

View File

@@ -100,6 +100,7 @@ def test_process_events_with_message_action(conversation_memory):
# Process events
messages = conversation_memory.process_events(
condensed_history=[system_message, user_message, assistant_message],
initial_user_action=user_message,
max_message_chars=None,
vision_is_active=False,
)
@@ -108,10 +109,178 @@ def test_process_events_with_message_action(conversation_memory):
assert len(messages) == 3
assert messages[0].role == 'system'
assert messages[0].content[0].text == 'System message'
# Test cases for _ensure_system_message
def test_ensure_system_message_adds_if_missing(conversation_memory):
"""Test that _ensure_system_message adds a system message if none exists."""
user_message = MessageAction(content='User message')
user_message._source = EventSource.USER
events = [user_message]
conversation_memory._ensure_system_message(events)
assert len(events) == 2
assert isinstance(events[0], SystemMessageAction)
assert events[0].content == 'System message' # From fixture
assert isinstance(events[1], MessageAction) # Original event is still there
def test_ensure_system_message_does_nothing_if_present(conversation_memory):
"""Test that _ensure_system_message does nothing if a system message is already present."""
original_system_message = SystemMessageAction(content='Existing system message')
user_message = MessageAction(content='User message')
user_message._source = EventSource.USER
events = [
original_system_message,
user_message,
]
original_events = list(events) # Copy before modification
conversation_memory._ensure_system_message(events)
assert events == original_events # List should be unchanged
# Test cases for _ensure_initial_user_message
@pytest.fixture
def initial_user_action():
msg = MessageAction(content='Initial User Message')
msg._source = EventSource.USER
return msg
def test_ensure_initial_user_message_adds_if_only_system(
conversation_memory, initial_user_action
):
"""Test adding the initial user message when only the system message exists."""
system_message = SystemMessageAction(content='System')
system_message._source = EventSource.AGENT
events = [system_message]
conversation_memory._ensure_initial_user_message(events, initial_user_action)
assert len(events) == 2
assert events[0] == system_message
assert events[1] == initial_user_action
def test_ensure_initial_user_message_correct_already_present(
conversation_memory, initial_user_action
):
"""Test that nothing changes if the correct initial user message is at index 1."""
system_message = SystemMessageAction(content='System')
agent_message = MessageAction(content='Assistant')
agent_message._source = EventSource.USER
events = [
system_message,
initial_user_action,
agent_message,
]
original_events = list(events)
conversation_memory._ensure_initial_user_message(events, initial_user_action)
assert events == original_events
def test_ensure_initial_user_message_incorrect_at_index_1(
conversation_memory, initial_user_action
):
"""Test inserting the correct initial user message when an incorrect message is at index 1."""
system_message = SystemMessageAction(content='System')
incorrect_second_message = MessageAction(content='Assistant')
incorrect_second_message._source = EventSource.AGENT
events = [system_message, incorrect_second_message]
conversation_memory._ensure_initial_user_message(events, initial_user_action)
assert len(events) == 3
assert events[0] == system_message
assert events[1] == initial_user_action # Correct one inserted
assert events[2] == incorrect_second_message # Original second message shifted
def test_ensure_initial_user_message_correct_present_later(
conversation_memory, initial_user_action
):
"""Test inserting the correct initial user message at index 1 even if it exists later."""
system_message = SystemMessageAction(content='System')
incorrect_second_message = MessageAction(content='Assistant')
incorrect_second_message._source = EventSource.AGENT
# Correct initial message is present, but later in the list
events = [system_message, incorrect_second_message]
conversation_memory._ensure_system_message(events)
conversation_memory._ensure_initial_user_message(events, initial_user_action)
assert len(events) == 3 # Should still insert at index 1, not remove the later one
assert events[0] == system_message
assert events[1] == initial_user_action # Correct one inserted at index 1
assert events[2] == incorrect_second_message # Original second message shifted
# The duplicate initial_user_action originally at index 2 is now at index 3 (implicitly tested by length and content)
def test_ensure_initial_user_message_different_user_msg_at_index_1(
conversation_memory, initial_user_action
):
"""Test inserting the correct initial user message when a *different* user message is at index 1."""
system_message = SystemMessageAction(content='System')
different_user_message = MessageAction(content='Different User Message')
different_user_message._source = EventSource.USER
events = [system_message, different_user_message]
conversation_memory._ensure_initial_user_message(events, initial_user_action)
assert len(events) == 2
assert events[0] == system_message
assert events[1] == different_user_message # Original second message remains
def test_ensure_initial_user_message_different_user_msg_at_index_1_and_orphaned_obs(
conversation_memory, initial_user_action
):
"""
Test process_events when an incorrect user message is at index 1 AND
an orphaned observation (with tool_call_metadata but no matching action) exists.
Expect: System msg, CORRECT initial user msg, the incorrect user msg (shifted).
The orphaned observation should be filtered out.
"""
system_message = SystemMessageAction(content='System')
different_user_message = MessageAction(content='Different User Message')
different_user_message._source = EventSource.USER
# Create an orphaned observation (no matching action/tool call request will exist)
# Use a dictionary that mimics ModelResponse structure to satisfy Pydantic
mock_response = {
'id': 'mock_response_id',
'choices': [{'message': {'content': None, 'tool_calls': []}}],
'created': 0,
'model': '',
'object': '',
'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0},
}
orphaned_obs = CmdOutputObservation(
command='orphan_cmd',
content='Orphaned output',
command_id=99,
exit_code=0,
)
orphaned_obs.tool_call_metadata = ToolCallMetadata(
tool_call_id='orphan_call_id',
function_name='execute_bash',
model_response=mock_response,
total_calls_in_response=1,
)
# Initial events list: system, wrong user message, orphaned observation
events = [system_message, different_user_message, orphaned_obs]
# Call the main process_events method
messages = conversation_memory.process_events(
condensed_history=events,
initial_user_action=initial_user_action, # Provide the *correct* initial action
max_message_chars=None,
vision_is_active=False,
)
# Assertions on the final messages list
assert len(messages) == 2
# 1. System message should be first
assert messages[0].role == 'system'
assert messages[0].content[0].text == 'System'
# 2. The different user message should be left at index 1
assert messages[1].role == 'user'
assert messages[1].content[0].text == 'Hello'
assert messages[2].role == 'assistant'
assert messages[2].content[0].text == 'Hi there'
assert messages[1].content[0].text == different_user_message.content
# Implicitly assert that the orphaned_obs was filtered out by checking the length (2)
def test_process_events_with_cmd_output_observation(conversation_memory):
@@ -125,14 +294,17 @@ def test_process_events_with_cmd_output_observation(conversation_memory):
),
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -148,14 +320,17 @@ def test_process_events_with_ipython_run_cell_observation(conversation_memory):
content='IPython output\n![image](data:image/png;base64,ABC123)',
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -172,14 +347,17 @@ def test_process_events_with_agent_delegate_observation(conversation_memory):
content='Content', outputs={'content': 'Delegated agent output'}
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -189,14 +367,17 @@ def test_process_events_with_agent_delegate_observation(conversation_memory):
def test_process_events_with_error_observation(conversation_memory):
obs = ErrorObservation('Error message')
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -207,10 +388,13 @@ def test_process_events_with_error_observation(conversation_memory):
def test_process_events_with_unknown_observation(conversation_memory):
# Create a mock that inherits from Event but not Action or Observation
obs = Mock(spec=Event)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
with pytest.raises(ValueError, match='Unknown event type'):
conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
@@ -226,14 +410,17 @@ def test_process_events_with_file_edit_observation(conversation_memory):
impl_source=FileEditSource.LLM_BASED_EDIT,
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -247,18 +434,21 @@ def test_process_events_with_file_read_observation(conversation_memory):
impl_source=FileReadSource.DEFAULT,
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == 'File content'
assert result.content[0].text == '\n\nFile content'
def test_process_events_with_browser_output_observation(conversation_memory):
@@ -270,14 +460,17 @@ def test_process_events_with_browser_output_observation(conversation_memory):
error=False,
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -287,14 +480,17 @@ def test_process_events_with_browser_output_observation(conversation_memory):
def test_process_events_with_user_reject_observation(conversation_memory):
obs = UserRejectObservation('Action rejected')
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3 # System + initial user + result
result = messages[2] # The actual result is now at index 2
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -317,14 +513,17 @@ def test_process_events_with_empty_environment_info(conversation_memory):
content='Retrieved environment info',
)
initial_user_message = MessageAction(content='Initial user message')
initial_user_message._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[empty_obs],
initial_user_action=initial_user_message,
max_message_chars=None,
vision_is_active=False,
)
# Should only contain no messages except system message
assert len(messages) == 1
# Should only contain system message and initial user message
assert len(messages) == 2
# Verify that build_workspace_context was NOT called since all input values were empty
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
@@ -348,14 +547,20 @@ def test_process_events_with_function_calling_observation(conversation_memory):
model_response=mock_response,
total_calls_in_response=1,
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# No direct message when using function calling
assert len(messages) == 1 # should be no messages except system message
assert (
len(messages) == 2
) # should be no messages except system message and initial user message
def test_process_events_with_message_action_with_image(conversation_memory):
@@ -365,14 +570,18 @@ def test_process_events_with_message_action_with_image(conversation_memory):
)
action._source = EventSource.AGENT
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[action],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=True,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3
result = messages[2]
assert result.role == 'assistant'
assert len(result.content) == 2
assert isinstance(result.content[0], TextContent)
@@ -385,14 +594,18 @@ def test_process_events_with_user_cmd_action(conversation_memory):
action = CmdRunAction(command='ls -l')
action._source = EventSource.USER
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[action],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3
result = messages[2]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -418,14 +631,18 @@ def test_process_events_with_agent_finish_action_with_tool_metadata(
total_calls_in_response=1,
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[action],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3
result = messages[2]
assert result.role == 'assistant'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -461,18 +678,22 @@ def test_process_events_with_environment_microagent_observation(conversation_mem
content='Retrieved environment info',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3
result = messages[2]
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == 'Formatted repository and runtime info'
assert result.content[0].text == '\n\nFormatted repository and runtime info'
# Verify the prompt_manager was called with the correct parameters
conversation_memory.prompt_manager.build_workspace_context.assert_called_once()
@@ -516,14 +737,18 @@ def test_process_events_with_knowledge_microagent_microagent_observation(
content='Retrieved knowledge from microagents',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
assert len(messages) == 2
result = messages[1]
assert len(messages) == 3 # System + Initial User + Result
result = messages[2] # Result is now at index 2
assert result.role == 'user'
assert len(result.content) == 1
assert isinstance(result.content[0], TextContent)
@@ -559,14 +784,18 @@ def test_process_events_with_microagent_observation_extensions_disabled(
content='Retrieved environment info',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# When prompt extensions are disabled, the RecallObservation should be ignored
assert len(messages) == 1 # should be no messages except system message
assert len(messages) == 2 # System + Initial User
# Verify the prompt_manager was not called
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
@@ -581,14 +810,18 @@ def test_process_events_with_empty_microagent_knowledge(conversation_memory):
content='Retrieved knowledge from microagents',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# The implementation returns an empty string and it doesn't creates a message
assert len(messages) == 1 # should be no messages except system message
assert len(messages) == 2 # System + Initial User
# When there are no triggered agents, build_microagent_info is not called
conversation_memory.prompt_manager.build_microagent_info.assert_not_called()
@@ -793,19 +1026,23 @@ def test_process_events_with_microagent_observation_deduplication(conversation_m
content='Third retrieval',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs1, obs2, obs3],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# Verify that only the first occurrence of content for each agent is included
assert len(messages) == 2 # with system message
assert len(messages) == 3 # System + Initial User + Result
# Result is now at index 2
# First microagent should include all agents since they appear here first
assert 'Image best practices v1' in messages[1].content[0].text
assert 'Git best practices v1' in messages[1].content[0].text
assert 'Python best practices v1' in messages[1].content[0].text
assert 'Image best practices v1' in messages[2].content[0].text
assert 'Git best practices v1' in messages[2].content[0].text
assert 'Python best practices v1' in messages[2].content[0].text
def test_process_events_with_microagent_observation_deduplication_disabled_agents(
@@ -842,18 +1079,22 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent
content='Second retrieval',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs1, obs2],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# Verify that disabled agents are filtered out and only the first occurrence of enabled agents is included
assert len(messages) == 2
assert len(messages) == 3 # System + Initial User + Result
# Result is now at index 2
# First microagent should include enabled_agent but not disabled_agent
assert 'Disabled agent content' not in messages[1].content[0].text
assert 'Enabled agent content v1' in messages[1].content[0].text
assert 'Disabled agent content' not in messages[2].content[0].text
assert 'Enabled agent content v1' in messages[2].content[0].text
def test_process_events_with_microagent_observation_deduplication_empty(
@@ -866,17 +1107,22 @@ def test_process_events_with_microagent_observation_deduplication_empty(
content='Empty retrieval',
)
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[obs],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# Verify that empty RecallObservations are handled gracefully
assert (
len(messages) == 1
) # an empty microagent is not added to Messages, only system message is found
len(messages) == 2 # System + Initial User
) # an empty microagent is not added to Messages
assert messages[0].role == 'system'
assert messages[1].role == 'user' # Initial user message
def test_has_agent_in_earlier_events(conversation_memory):
@@ -1088,13 +1334,183 @@ def test_system_message_in_events(conversation_memory):
system_message._source = EventSource.AGENT
# Process events with the system message in condensed_history
# Define initial user action
initial_user_action = MessageAction(content='Initial user message')
initial_user_action._source = EventSource.USER
messages = conversation_memory.process_events(
condensed_history=[system_message],
initial_user_action=initial_user_action,
max_message_chars=None,
vision_is_active=False,
)
# Check that the system message was processed correctly
assert len(messages) == 1
assert len(messages) == 2 # System + Initial User
assert messages[0].role == 'system'
assert messages[0].content[0].text == 'System message'
assert messages[1].role == 'user' # Initial user message
# Helper function to create mock tool call metadata
def _create_mock_tool_call_metadata(
tool_call_id: str, function_name: str, response_id: str = 'mock_response_id'
) -> ToolCallMetadata:
# Use a dictionary that mimics ModelResponse structure to satisfy Pydantic
mock_response = {
'id': response_id,
'choices': [
{
'message': {
'role': 'assistant',
'content': None, # Content is None for tool calls
'tool_calls': [
{
'id': tool_call_id,
'type': 'function',
'function': {
'name': function_name,
'arguments': '{}',
}, # Args don't matter for this test
}
],
}
}
],
'created': 0,
'model': 'mock_model',
'object': 'chat.completion',
'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0},
}
return ToolCallMetadata(
tool_call_id=tool_call_id,
function_name=function_name,
model_response=mock_response,
total_calls_in_response=1,
)
def test_process_events_partial_history(conversation_memory):
"""
Tests process_events with full and partial histories to verify
_ensure_system_message, _ensure_initial_user_message, and tool call matching logic.
"""
# --- Define Common Events ---
system_message = SystemMessageAction(content='System message')
system_message._source = EventSource.AGENT
user_message = MessageAction(
content='Initial user query'
) # This is the crucial initial_user_action
user_message._source = EventSource.USER
recall_obs = RecallObservation(
recall_type=RecallType.WORKSPACE_CONTEXT,
repo_name='test-repo',
repo_directory='/path/to/repo',
content='Retrieved environment info',
)
recall_obs._source = EventSource.AGENT
cmd_action = CmdRunAction(command='ls', thought='Running ls')
cmd_action._source = EventSource.AGENT
cmd_action.tool_call_metadata = _create_mock_tool_call_metadata(
tool_call_id='call_ls_1', function_name='execute_bash', response_id='resp_ls_1'
)
cmd_obs = CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.py', exit_code=0
)
cmd_obs._source = EventSource.AGENT
cmd_obs.tool_call_metadata = _create_mock_tool_call_metadata(
tool_call_id='call_ls_1', function_name='execute_bash', response_id='resp_ls_1'
)
# --- Scenario 1: Full History ---
full_history: list[Event] = [
system_message,
user_message, # Correct initial user message at index 1
recall_obs,
cmd_action,
cmd_obs,
]
messages_full = conversation_memory.process_events(
condensed_history=list(full_history), # Pass a copy
initial_user_action=user_message, # Provide the initial action
max_message_chars=None,
vision_is_active=False,
)
# Expected: System, User, Recall (formatted), Assistant (tool call), Tool Response
assert len(messages_full) == 5
assert messages_full[0].role == 'system'
assert messages_full[0].content[0].text == 'System message'
assert messages_full[1].role == 'user'
assert messages_full[1].content[0].text == 'Initial user query'
assert messages_full[2].role == 'user' # Recall obs becomes user message
assert (
'Formatted repository and runtime info' in messages_full[2].content[0].text
) # From fixture mock
assert messages_full[3].role == 'assistant'
assert messages_full[3].tool_calls is not None
assert len(messages_full[3].tool_calls) == 1
assert messages_full[3].tool_calls[0].id == 'call_ls_1'
assert messages_full[4].role == 'tool'
assert messages_full[4].tool_call_id == 'call_ls_1'
assert 'file1.txt' in messages_full[4].content[0].text
# --- Scenario 2: Partial History (Action + Observation) ---
# Simulates processing only the last action/observation pair
partial_history_action_obs: list[Event] = [
cmd_action,
cmd_obs,
]
messages_partial_action_obs = conversation_memory.process_events(
condensed_history=list(partial_history_action_obs), # Pass a copy
initial_user_action=user_message, # Provide the initial action
max_message_chars=None,
vision_is_active=False,
)
# Expected: System (added), Initial User (added), Assistant (tool call), Tool Response
assert len(messages_partial_action_obs) == 4
assert (
messages_partial_action_obs[0].role == 'system'
) # Added by _ensure_system_message
assert messages_partial_action_obs[0].content[0].text == 'System message'
assert (
messages_partial_action_obs[1].role == 'user'
) # Added by _ensure_initial_user_message
assert messages_partial_action_obs[1].content[0].text == 'Initial user query'
assert messages_partial_action_obs[2].role == 'assistant'
assert messages_partial_action_obs[2].tool_calls is not None
assert len(messages_partial_action_obs[2].tool_calls) == 1
assert messages_partial_action_obs[2].tool_calls[0].id == 'call_ls_1'
assert messages_partial_action_obs[3].role == 'tool'
assert messages_partial_action_obs[3].tool_call_id == 'call_ls_1'
assert 'file1.txt' in messages_partial_action_obs[3].content[0].text
# --- Scenario 3: Partial History (Observation Only) ---
# Simulates processing only the last observation
partial_history_obs_only: list[Event] = [
cmd_obs,
]
messages_partial_obs_only = conversation_memory.process_events(
condensed_history=list(partial_history_obs_only), # Pass a copy
initial_user_action=user_message, # Provide the initial action
max_message_chars=None,
vision_is_active=False,
)
# Expected: System (added), Initial User (added).
# The CmdOutputObservation has tool_call_metadata, but there's no corresponding
# assistant message (from CmdRunAction) with the matching tool_call.id in the input history.
# Therefore, _filter_unmatched_tool_calls should remove the tool response message.
assert len(messages_partial_obs_only) == 2
assert (
messages_partial_obs_only[0].role == 'system'
) # Added by _ensure_system_message
assert messages_partial_obs_only[0].content[0].text == 'System message'
assert (
messages_partial_obs_only[1].role == 'user'
) # Added by _ensure_initial_user_message
assert messages_partial_obs_only[1].content[0].text == 'Initial user query'

View File

@@ -76,7 +76,7 @@ def test_get_messages(codeact_agent: CodeActAgent):
history.append(message_action_5)
codeact_agent.reset()
messages = codeact_agent._get_messages(history)
messages = codeact_agent._get_messages(history, message_action_1)
assert (
len(messages) == 6
@@ -106,16 +106,19 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
history.append(system_message_action)
# Add multiple user and agent messages
initial_user_message = None # Keep track of the first user message
for i in range(15):
message_action_user = MessageAction(f'User message {i}')
message_action_user._source = 'user'
if initial_user_message is None:
initial_user_message = message_action_user # Store the first one
history.append(message_action_user)
message_action_agent = MessageAction(f'Agent message {i}')
message_action_agent._source = 'agent'
history.append(message_action_agent)
codeact_agent.reset()
messages = codeact_agent._get_messages(history)
messages = codeact_agent._get_messages(history, initial_user_message)
# Check that only the last two user messages have cache_prompt=True
cached_user_messages = [