mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 06:48:02 -05:00
Fix truncation, ensure first user message and log (#8103)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -20,10 +20,7 @@ from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import Message
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
AgentFinishAction,
|
||||
)
|
||||
from openhands.events.action import Action, AgentFinishAction, MessageAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.condenser import Condenser
|
||||
@@ -173,7 +170,8 @@ class CodeActAgent(Agent):
|
||||
f'Processing {len(condensed_history)} events from a total of {len(state.history)} events'
|
||||
)
|
||||
|
||||
messages = self._get_messages(condensed_history)
|
||||
initial_user_message = self._get_initial_user_message(state.history)
|
||||
messages = self._get_messages(condensed_history, initial_user_message)
|
||||
params: dict = {
|
||||
'messages': self.llm.format_messages_for_llm(messages),
|
||||
}
|
||||
@@ -216,7 +214,29 @@ class CodeActAgent(Agent):
|
||||
self.pending_actions.append(action)
|
||||
return self.pending_actions.popleft()
|
||||
|
||||
def _get_messages(self, events: list[Event]) -> list[Message]:
|
||||
def _get_initial_user_message(self, history: list[Event]) -> MessageAction:
|
||||
"""Finds the initial user message action from the full history."""
|
||||
initial_user_message: MessageAction | None = None
|
||||
for event in history:
|
||||
if isinstance(event, MessageAction) and event.source == 'user':
|
||||
initial_user_message = event
|
||||
break
|
||||
|
||||
if initial_user_message is None:
|
||||
# This should not happen in a valid conversation
|
||||
logger.error(
|
||||
f'CRITICAL: Could not find the initial user MessageAction in the full {len(history)} events history.'
|
||||
)
|
||||
# Depending on desired robustness, could raise error or create a dummy action
|
||||
# and log the error
|
||||
raise ValueError(
|
||||
'Initial user message not found in history. Please report this issue.'
|
||||
)
|
||||
return initial_user_message
|
||||
|
||||
def _get_messages(
|
||||
self, events: list[Event], initial_user_message: MessageAction
|
||||
) -> list[Message]:
|
||||
"""Constructs the message history for the LLM conversation.
|
||||
|
||||
This method builds a structured conversation history by processing events from the state
|
||||
@@ -253,6 +273,7 @@ class CodeActAgent(Agent):
|
||||
# Use ConversationMemory to process events (including SystemMessageAction)
|
||||
messages = self.conversation_memory.process_events(
|
||||
condensed_history=events,
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=self.llm.config.max_message_chars,
|
||||
vision_is_active=self.llm.vision_is_active(),
|
||||
)
|
||||
|
||||
@@ -190,7 +190,7 @@ class AgentController:
|
||||
logger.debug(f'System message got from agent: {system_message}')
|
||||
if system_message:
|
||||
self.event_stream.add_event(system_message, EventSource.AGENT)
|
||||
logger.debug(f'System message added to event stream: {system_message}')
|
||||
logger.info(f'System message added to event stream: {system_message}')
|
||||
|
||||
async def close(self, set_stop_state: bool = True) -> None:
|
||||
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
|
||||
@@ -1020,7 +1020,7 @@ class AgentController:
|
||||
self.state.start_id = 0
|
||||
|
||||
self.log(
|
||||
'debug',
|
||||
'info',
|
||||
f'AgentController {self.id} - created new state. start_id: {self.state.start_id}',
|
||||
)
|
||||
else:
|
||||
@@ -1030,7 +1030,7 @@ class AgentController:
|
||||
self.state.start_id = 0
|
||||
|
||||
self.log(
|
||||
'debug',
|
||||
'info',
|
||||
f'AgentController {self.id} initializing history from event {self.state.start_id}',
|
||||
)
|
||||
|
||||
@@ -1143,70 +1143,169 @@ class AgentController:
|
||||
|
||||
def _handle_long_context_error(self) -> None:
|
||||
# When context window is exceeded, keep roughly half of agent interactions
|
||||
kept_event_ids = {
|
||||
e.id for e in self._apply_conversation_window(self.state.history)
|
||||
}
|
||||
kept_events = self._apply_conversation_window()
|
||||
kept_event_ids = {e.id for e in kept_events}
|
||||
|
||||
self.log(
|
||||
'info',
|
||||
f'Context window exceeded. Keeping events with IDs: {kept_event_ids}',
|
||||
)
|
||||
|
||||
# The events to forget are those that are not in the kept set
|
||||
forgotten_event_ids = {e.id for e in self.state.history} - kept_event_ids
|
||||
|
||||
# Save the ID of the first event in our truncated history for future reloading
|
||||
if self.state.history:
|
||||
self.state.start_id = self.state.history[0].id
|
||||
if len(kept_event_ids) == 0:
|
||||
self.log(
|
||||
'warning',
|
||||
'No events kept after applying conversation window. This should not happen.',
|
||||
)
|
||||
|
||||
# verify that the first event id in kept_event_ids is the same as the start_id
|
||||
if len(kept_event_ids) > 0 and self.state.history[0].id not in kept_event_ids:
|
||||
self.log(
|
||||
'warning',
|
||||
f'First event after applying conversation window was not kept: {self.state.history[0].id} not in {kept_event_ids}',
|
||||
)
|
||||
|
||||
# Add an error event to trigger another step by the agent
|
||||
self.event_stream.add_event(
|
||||
CondensationAction(
|
||||
forgotten_events_start_id=min(forgotten_event_ids),
|
||||
forgotten_events_end_id=max(forgotten_event_ids),
|
||||
forgotten_events_start_id=min(forgotten_event_ids)
|
||||
if forgotten_event_ids
|
||||
else 0,
|
||||
forgotten_events_end_id=max(forgotten_event_ids)
|
||||
if forgotten_event_ids
|
||||
else 0,
|
||||
),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
|
||||
def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
|
||||
def _apply_conversation_window(self) -> list[Event]:
|
||||
"""Cuts history roughly in half when context window is exceeded.
|
||||
|
||||
It preserves action-observation pairs and ensures that the first user message is always included.
|
||||
It preserves action-observation pairs and ensures that the system message,
|
||||
the first user message, and its associated recall observation are always included
|
||||
at the beginning of the context window.
|
||||
|
||||
The algorithm:
|
||||
1. Cut history in half
|
||||
2. Check first event in new history:
|
||||
- If Observation: find and include its Action
|
||||
- If MessageAction: ensure its related Action-Observation pair isn't split
|
||||
3. Always include the first user message
|
||||
1. Identify essential initial events: System Message, First User Message, Recall Observation.
|
||||
2. Determine the slice of recent events to potentially keep.
|
||||
3. Validate the start of the recent slice for dangling observations.
|
||||
4. Combine essential events and validated recent events, ensuring essentials come first.
|
||||
|
||||
Args:
|
||||
events: List of events to filter
|
||||
|
||||
Returns:
|
||||
Filtered list of events keeping newest half while preserving pairs
|
||||
Filtered list of events keeping newest half while preserving pairs and essential initial events.
|
||||
"""
|
||||
if not events:
|
||||
return events
|
||||
if not self.state.history:
|
||||
return []
|
||||
|
||||
# Find first user message - we'll need to ensure it's included
|
||||
first_user_msg = next(
|
||||
(
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
history = self.state.history
|
||||
|
||||
# 1. Identify essential initial events
|
||||
system_message: SystemMessageAction | None = None
|
||||
first_user_msg: MessageAction | None = None
|
||||
recall_action: RecallAction | None = None
|
||||
recall_observation: Observation | None = None
|
||||
|
||||
# Find System Message (should be the first event, if it exists)
|
||||
system_message = next(
|
||||
(e for e in history if isinstance(e, SystemMessageAction)), None
|
||||
)
|
||||
assert (
|
||||
system_message is None
|
||||
or isinstance(system_message, SystemMessageAction)
|
||||
and system_message.id == history[0].id
|
||||
)
|
||||
|
||||
# cut in half
|
||||
mid_point = max(1, len(events) // 2)
|
||||
kept_events = events[mid_point:]
|
||||
if len(kept_events) > 0 and isinstance(kept_events[0], Observation):
|
||||
kept_events = kept_events[1:]
|
||||
# Find First User Message, which MUST exist
|
||||
first_user_msg = self._first_user_message()
|
||||
if first_user_msg is None:
|
||||
raise RuntimeError('No first user message found in the event stream.')
|
||||
|
||||
# Ensure first user message is included
|
||||
if first_user_msg and first_user_msg not in kept_events:
|
||||
kept_events = [first_user_msg] + kept_events
|
||||
first_user_msg_index = -1
|
||||
for i, event in enumerate(history):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.USER:
|
||||
first_user_msg = event
|
||||
first_user_msg_index = i
|
||||
break
|
||||
|
||||
# start_id points to first user message
|
||||
# Find Recall Action and Observation related to the First User Message
|
||||
if first_user_msg is not None and first_user_msg_index != -1:
|
||||
# Look for RecallAction after the first user message
|
||||
for i in range(first_user_msg_index + 1, len(history)):
|
||||
event = history[i]
|
||||
if (
|
||||
isinstance(event, RecallAction)
|
||||
and event.query == first_user_msg.content
|
||||
):
|
||||
# Found RecallAction, now look for its Observation
|
||||
recall_action = event
|
||||
for j in range(i + 1, len(history)):
|
||||
obs_event = history[j]
|
||||
# Check for Observation caused by this RecallAction
|
||||
if (
|
||||
isinstance(obs_event, Observation)
|
||||
and obs_event.cause == recall_action.id
|
||||
):
|
||||
recall_observation = obs_event
|
||||
break # Found the observation, stop inner loop
|
||||
break # Found the recall action (and maybe obs), stop outer loop
|
||||
|
||||
essential_events: list[Event] = []
|
||||
if system_message:
|
||||
essential_events.append(system_message)
|
||||
if first_user_msg:
|
||||
self.state.start_id = first_user_msg.id
|
||||
essential_events.append(first_user_msg)
|
||||
# Also keep the RecallAction that triggered the essential RecallObservation
|
||||
if recall_action:
|
||||
essential_events.append(recall_action)
|
||||
if recall_observation:
|
||||
essential_events.append(recall_observation)
|
||||
|
||||
return kept_events
|
||||
# 2. Determine the slice of recent events to potentially keep
|
||||
num_non_essential_events = len(history) - len(essential_events)
|
||||
# Keep roughly half of the non-essential events, minimum 1
|
||||
num_recent_to_keep = max(1, num_non_essential_events // 2)
|
||||
|
||||
# Calculate the starting index for the recent slice
|
||||
slice_start_index = len(history) - num_recent_to_keep
|
||||
slice_start_index = max(0, slice_start_index) # Ensure index is not negative
|
||||
recent_events_slice = history[slice_start_index:]
|
||||
|
||||
# 3. Validate the start of the recent slice for dangling observations
|
||||
# IMPORTANT: Most observations in history are tool call results, which cannot be without their action, or we get an LLM API error
|
||||
first_valid_event_index = 0
|
||||
for i, event in enumerate(recent_events_slice):
|
||||
if isinstance(event, Observation):
|
||||
first_valid_event_index += 1
|
||||
else:
|
||||
break
|
||||
# If all events in the slice are dangling observations, we need to keep at least one
|
||||
if first_valid_event_index == len(recent_events_slice):
|
||||
self.log(
|
||||
'warning',
|
||||
'All recent events are dangling observations, which we truncate. This means the agent has only the essential first events. This should not happen.',
|
||||
)
|
||||
|
||||
# Adjust the recent_events_slice if dangling observations were found at the start
|
||||
if first_valid_event_index < len(recent_events_slice):
|
||||
validated_recent_events = recent_events_slice[first_valid_event_index:]
|
||||
if first_valid_event_index > 0:
|
||||
self.log(
|
||||
'debug',
|
||||
f'Removed {first_valid_event_index} dangling observation(s) from the start of recent event slice.',
|
||||
)
|
||||
else:
|
||||
validated_recent_events = []
|
||||
|
||||
# 4. Combine essential events and validated recent events
|
||||
events_to_keep: list[Event] = essential_events + validated_recent_events
|
||||
self.log('debug', f'History truncated. Kept {len(events_to_keep)} events.')
|
||||
|
||||
return events_to_keep
|
||||
|
||||
def _is_stuck(self) -> bool:
|
||||
"""Checks if the agent or its delegate is stuck in a loop.
|
||||
|
||||
@@ -54,6 +54,7 @@ class ConversationMemory:
|
||||
def process_events(
|
||||
self,
|
||||
condensed_history: list[Event],
|
||||
initial_user_action: MessageAction,
|
||||
max_message_chars: int | None = None,
|
||||
vision_is_active: bool = False,
|
||||
) -> list[Message]:
|
||||
@@ -66,12 +67,14 @@ class ConversationMemory:
|
||||
max_message_chars: The maximum number of characters in the content of an event included
|
||||
in the prompt to the LLM. Larger observations are truncated.
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
|
||||
initial_user_action: The initial user message action, if available. Used to ensure the conversation starts correctly.
|
||||
"""
|
||||
|
||||
events = condensed_history
|
||||
|
||||
# Ensure the system message exists (handles legacy cases)
|
||||
# Ensure the event list starts with SystemMessageAction, then MessageAction(source='user')
|
||||
self._ensure_system_message(events)
|
||||
self._ensure_initial_user_message(events, initial_user_action)
|
||||
|
||||
# log visual browsing status
|
||||
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
|
||||
@@ -699,6 +702,43 @@ class ConversationMemory:
|
||||
system_message = SystemMessageAction(content=system_prompt)
|
||||
# Insert the system message directly at the beginning of the events list
|
||||
events.insert(0, system_message)
|
||||
logger.debug(
|
||||
logger.info(
|
||||
'[ConversationMemory] Added SystemMessageAction for backward compatibility'
|
||||
)
|
||||
|
||||
def _ensure_initial_user_message(
|
||||
self, events: list[Event], initial_user_action: MessageAction
|
||||
) -> None:
|
||||
"""Checks if the second event is a user MessageAction and inserts the provided one if needed."""
|
||||
if (
|
||||
not events
|
||||
): # Should have system message from previous step, but safety check
|
||||
logger.error('Cannot ensure initial user message: event list is empty.')
|
||||
# Or raise? Let's log for now, _ensure_system_message should handle this.
|
||||
return
|
||||
|
||||
# We expect events[0] to be SystemMessageAction after _ensure_system_message
|
||||
if len(events) == 1:
|
||||
# Only system message exists
|
||||
logger.info(
|
||||
'Initial user message action was missing. Inserting the initial user message.'
|
||||
)
|
||||
events.insert(1, initial_user_action)
|
||||
elif not isinstance(events[1], MessageAction) or events[1].source != 'user':
|
||||
# The second event exists but is not the correct initial user message action.
|
||||
# We will insert the correct one provided.
|
||||
logger.info(
|
||||
'Second event was not the initial user message action. Inserting correct one at index 1.'
|
||||
)
|
||||
|
||||
# Insert the user message event at index 1. This will be the second message as LLM APIs expect
|
||||
# but something was wrong with the history, so log all we can.
|
||||
events.insert(1, initial_user_action)
|
||||
|
||||
# Else: events[1] is already a user MessageAction.
|
||||
# Check if it matches the one provided (if any discrepancy, log warning but proceed).
|
||||
elif events[1] != initial_user_action:
|
||||
logger.debug(
|
||||
'The user MessageAction at index 1 does not match the provided initial_user_action. '
|
||||
'Proceeding with the one found in condensed history.'
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import overload
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action.agent import CondensationAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation.agent import AgentCondensationObservation
|
||||
@@ -65,6 +66,8 @@ class View(BaseModel):
|
||||
break
|
||||
|
||||
if summary is not None and summary_offset is not None:
|
||||
logger.info(f'Inserting summary at offset {summary_offset}')
|
||||
|
||||
kept_events.insert(
|
||||
summary_offset, AgentCondensationObservation(content=summary)
|
||||
)
|
||||
|
||||
@@ -22,7 +22,6 @@ from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import RecallObservation
|
||||
from openhands.events.observation.commands import CmdOutputObservation
|
||||
from openhands.events.observation.empty import NullObservation
|
||||
from openhands.events.serialization import event_to_dict
|
||||
from openhands.llm import LLM
|
||||
@@ -765,7 +764,7 @@ async def test_context_window_exceeded_error_handling(
|
||||
|
||||
# We do that by playing the role of the recall module -- subscribe to the
|
||||
# event stream and respond to recall actions by inserting fake recall
|
||||
# obesrvations.
|
||||
# observations.
|
||||
def on_event_memory(event: Event):
|
||||
if isinstance(event, RecallAction):
|
||||
microagent_obs = RecallObservation(
|
||||
@@ -807,13 +806,19 @@ async def test_context_window_exceeded_error_handling(
|
||||
# size (because we return a message action, which triggers a recall, which
|
||||
# triggers a recall response). But if the pre/post-views are on the turn
|
||||
# when we throw the context window exceeded error, we should see the
|
||||
# post-step view compressed.
|
||||
# post-step view compressed (or rather, a CondensationAction added).
|
||||
for index, (first_view, second_view) in enumerate(
|
||||
zip(step_state.views[:-1], step_state.views[1:])
|
||||
):
|
||||
if index == error_after:
|
||||
assert len(first_view) > len(second_view)
|
||||
# Verify that the CondensationAction is present in the second view (after error)
|
||||
# but not in the first view (before error)
|
||||
assert not any(isinstance(e, CondensationAction) for e in first_view.events)
|
||||
assert any(isinstance(e, CondensationAction) for e in second_view.events)
|
||||
# The length might not strictly decrease due to CondensationAction being added
|
||||
assert len(first_view) == len(second_view)
|
||||
else:
|
||||
# Before the error, the view length should increase
|
||||
assert len(first_view) < len(second_view)
|
||||
|
||||
# The final state's history should contain:
|
||||
@@ -886,7 +891,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation(
|
||||
def step(self, state: State):
|
||||
# If the state has more than one message and we haven't errored yet,
|
||||
# throw the context window exceeded error
|
||||
if len(state.history) > 3 and not self.has_errored:
|
||||
if len(state.history) > 5 and not self.has_errored:
|
||||
error = ContextWindowExceededError(
|
||||
message='prompt is too long: 233885 tokens > 200000 maximum',
|
||||
model='',
|
||||
@@ -1467,126 +1472,6 @@ def test_agent_controller_should_step_with_null_observation_cause_zero(mock_agen
|
||||
), 'should_step should return False for NullObservation with cause = 0'
|
||||
|
||||
|
||||
def test_apply_conversation_window_basic(mock_event_stream, mock_agent):
|
||||
"""Test that the _apply_conversation_window method correctly prunes a list of events."""
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_apply_conversation_window_basic',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create a sequence of events with IDs
|
||||
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
# Add agent question
|
||||
agent_msg = MessageAction(
|
||||
content='What task would you like me to perform?', wait_for_response=True
|
||||
)
|
||||
agent_msg._source = EventSource.AGENT
|
||||
agent_msg._id = 2
|
||||
|
||||
# Add user response
|
||||
user_response = MessageAction(
|
||||
content='Please list all files and show me current directory',
|
||||
wait_for_response=False,
|
||||
)
|
||||
user_response._source = EventSource.USER
|
||||
user_response._id = 3
|
||||
|
||||
cmd1 = CmdRunAction(command='ls')
|
||||
cmd1._id = 4
|
||||
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
|
||||
obs1._id = 5
|
||||
obs1._cause = 4
|
||||
|
||||
cmd2 = CmdRunAction(command='pwd')
|
||||
cmd2._id = 6
|
||||
obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=6)
|
||||
obs2._id = 7
|
||||
obs2._cause = 6
|
||||
|
||||
events = [first_msg, agent_msg, user_response, cmd1, obs1, cmd2, obs2]
|
||||
|
||||
# Apply truncation
|
||||
truncated = controller._apply_conversation_window(events)
|
||||
|
||||
# Verify truncation occured
|
||||
# Should keep first user message and roughly half of other events
|
||||
assert (
|
||||
3 <= len(truncated) < len(events)
|
||||
) # First message + at least one action-observation pair
|
||||
assert truncated[0] == first_msg # First message always preserved
|
||||
assert controller.state.start_id == first_msg._id
|
||||
|
||||
# Verify pairs aren't split
|
||||
for i, event in enumerate(truncated[1:]):
|
||||
if isinstance(event, CmdOutputObservation):
|
||||
assert any(e._id == event._cause for e in truncated[: i + 1])
|
||||
|
||||
|
||||
def test_history_restoration_after_truncation(mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create events with IDs
|
||||
first_msg = MessageAction(content='Start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
events = [first_msg]
|
||||
for i in range(5):
|
||||
cmd = CmdRunAction(command=f'cmd{i}')
|
||||
cmd._id = i + 2
|
||||
obs = CmdOutputObservation(
|
||||
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
|
||||
)
|
||||
obs._cause = cmd._id
|
||||
events.extend([cmd, obs])
|
||||
|
||||
# Set up initial history
|
||||
controller.state.history = events.copy()
|
||||
|
||||
# Force truncation
|
||||
controller.state.history = controller._apply_conversation_window(
|
||||
controller.state.history
|
||||
)
|
||||
|
||||
# Save state
|
||||
saved_start_id = controller.state.start_id
|
||||
saved_history_len = len(controller.state.history)
|
||||
|
||||
# Set up mock event stream for new controller
|
||||
mock_event_stream.get_events.return_value = controller.state.history
|
||||
|
||||
# Create new controller with saved state
|
||||
new_controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
new_controller.state.start_id = saved_start_id
|
||||
new_controller.state.history = mock_event_stream.get_events()
|
||||
|
||||
# Verify restoration
|
||||
assert len(new_controller.state.history) == saved_history_len
|
||||
assert new_controller.state.history[0] == first_msg
|
||||
assert new_controller.state.start_id == saved_start_id
|
||||
|
||||
|
||||
def test_system_message_in_event_stream(mock_agent, test_event_stream):
|
||||
"""Test that SystemMessageAction is added to event stream in AgentController."""
|
||||
_ = AgentController(
|
||||
|
||||
569
tests/unit/test_agent_history.py
Normal file
569
tests/unit/test_agent_history.py
Normal file
@@ -0,0 +1,569 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AppConfig
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import CmdRunAction, MessageAction, RecallAction
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
Observation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.stream import EventStream
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
# Helper function to create events with sequential IDs and causes
|
||||
def create_events(event_data):
|
||||
events = []
|
||||
# Import necessary types here to avoid repeated imports inside the loop
|
||||
from openhands.events.action import CmdRunAction, RecallAction
|
||||
from openhands.events.observation import CmdOutputObservation, RecallObservation
|
||||
|
||||
for i, data in enumerate(event_data):
|
||||
event_type = data['type']
|
||||
source = data.get('source', EventSource.AGENT)
|
||||
kwargs = {} # Arguments for the event constructor
|
||||
|
||||
# Determine arguments based on event type
|
||||
if event_type == RecallAction:
|
||||
kwargs['query'] = data.get('query', '')
|
||||
kwargs['recall_type'] = data.get('recall_type', RecallType.KNOWLEDGE)
|
||||
elif event_type == RecallObservation:
|
||||
kwargs['content'] = data.get('content', '')
|
||||
kwargs['recall_type'] = data.get('recall_type', RecallType.KNOWLEDGE)
|
||||
elif event_type == CmdRunAction:
|
||||
kwargs['command'] = data.get('command', '')
|
||||
elif event_type == CmdOutputObservation:
|
||||
# Required args for CmdOutputObservation
|
||||
kwargs['content'] = data.get('content', '')
|
||||
kwargs['command'] = data.get('command', '')
|
||||
# Pass command_id via kwargs if present in data
|
||||
if 'command_id' in data:
|
||||
kwargs['command_id'] = data['command_id']
|
||||
# Pass metadata if present
|
||||
if 'metadata' in data:
|
||||
kwargs['metadata'] = data['metadata']
|
||||
else: # Default for MessageAction, SystemMessageAction, etc.
|
||||
kwargs['content'] = data.get('content', '')
|
||||
|
||||
# Instantiate the event
|
||||
event = event_type(**kwargs)
|
||||
|
||||
# Assign internal attributes AFTER instantiation
|
||||
event._id = i + 1 # Assign sequential IDs starting from 1
|
||||
event._source = source
|
||||
# Assign _cause using cause_id from data, AFTER event._id is set
|
||||
if 'cause_id' in data:
|
||||
event._cause = data['cause_id']
|
||||
# If command_id was NOT passed via kwargs but cause_id exists,
|
||||
# pass cause_id as command_id to __init__ via kwargs for legacy handling
|
||||
# This needs to happen *before* instantiation if we want __init__ to handle it
|
||||
# Let's adjust the logic slightly:
|
||||
if event_type == CmdOutputObservation:
|
||||
if 'command_id' not in kwargs and 'cause_id' in data:
|
||||
kwargs['command_id'] = data['cause_id'] # Let __init__ handle this
|
||||
# Re-instantiate if we added command_id
|
||||
if 'command_id' in kwargs and event.command_id != kwargs['command_id']:
|
||||
event = event_type(**kwargs)
|
||||
event._id = i + 1
|
||||
event._source = source
|
||||
|
||||
# Now assign _cause if it exists in data, after potential re-instantiation
|
||||
if 'cause_id' in data:
|
||||
event._cause = data['cause_id']
|
||||
|
||||
events.append(event)
|
||||
return events
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def controller_fixture():
|
||||
mock_agent = MagicMock(spec=Agent)
|
||||
mock_agent.llm = MagicMock(spec=LLM)
|
||||
mock_agent.llm.metrics = Metrics()
|
||||
mock_agent.llm.config = AppConfig().get_llm_config()
|
||||
mock_agent.config = AppConfig().get_agent_config('CodeActAgent')
|
||||
|
||||
mock_event_stream = MagicMock(spec=EventStream)
|
||||
mock_event_stream.sid = 'test_sid'
|
||||
mock_event_stream.file_store = InMemoryFileStore({})
|
||||
# Ensure get_latest_event_id returns an integer
|
||||
mock_event_stream.get_latest_event_id.return_value = -1
|
||||
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_sid',
|
||||
)
|
||||
controller.state = State(session_id='test_sid')
|
||||
|
||||
# Mock _first_user_message directly on the instance
|
||||
mock_first_user_message = MagicMock(spec=MessageAction)
|
||||
controller._first_user_message = MagicMock(return_value=mock_first_user_message)
|
||||
|
||||
return controller, mock_first_user_message
|
||||
|
||||
|
||||
# =============================================
|
||||
# Test Cases for _apply_conversation_window
|
||||
# =============================================
|
||||
|
||||
|
||||
def test_basic_truncation(controller_fixture):
|
||||
controller, mock_first_user_message = controller_fixture
|
||||
|
||||
controller.state.history = create_events(
|
||||
[
|
||||
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
|
||||
{
|
||||
'type': MessageAction,
|
||||
'content': 'User Task 1',
|
||||
'source': EventSource.USER,
|
||||
}, # 2
|
||||
{'type': RecallAction, 'query': 'User Task 1'}, # 3
|
||||
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
|
||||
{'type': CmdRunAction, 'command': 'ls'}, # 5
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'file1',
|
||||
'command': 'ls',
|
||||
'cause_id': 5,
|
||||
}, # 6
|
||||
{'type': CmdRunAction, 'command': 'pwd'}, # 7
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': '/dir',
|
||||
'command': 'pwd',
|
||||
'cause_id': 7,
|
||||
}, # 8
|
||||
{'type': CmdRunAction, 'command': 'cat file1'}, # 9
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'content',
|
||||
'command': 'cat file1',
|
||||
'cause_id': 9,
|
||||
}, # 10
|
||||
]
|
||||
)
|
||||
mock_first_user_message.id = 2 # Set the ID of the mocked first user message
|
||||
|
||||
# Calculation (RecallAction now essential):
|
||||
# History len = 10
|
||||
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
|
||||
# Non-essential count = 10 - 4 = 6
|
||||
# num_recent_to_keep = max(1, 6 // 2) = 3
|
||||
# slice_start_index = 10 - 3 = 7
|
||||
# recent_events_slice = history[7:] = [obs2(8), cmd3(9), obs3(10)]
|
||||
# Validation: remove leading obs2(8). validated_slice = [cmd3(9), obs3(10)]
|
||||
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4), cmd3(9), obs3(10)]
|
||||
# Expected IDs: [1, 2, 3, 4, 9, 10]. Length 6.
|
||||
truncated_events = controller._apply_conversation_window()
|
||||
|
||||
assert len(truncated_events) == 6
|
||||
expected_ids = [1, 2, 3, 4, 9, 10]
|
||||
actual_ids = [e.id for e in truncated_events]
|
||||
assert actual_ids == expected_ids
|
||||
# Check no dangling observations at the start of the recent slice part
|
||||
# The first event of the validated slice is cmd3(9)
|
||||
assert not isinstance(truncated_events[4], Observation) # Index adjusted
|
||||
|
||||
|
||||
def test_no_system_message(controller_fixture):
|
||||
controller, mock_first_user_message = controller_fixture
|
||||
|
||||
controller.state.history = create_events(
|
||||
[
|
||||
{
|
||||
'type': MessageAction,
|
||||
'content': 'User Task 1',
|
||||
'source': EventSource.USER,
|
||||
}, # 1
|
||||
{'type': RecallAction, 'query': 'User Task 1'}, # 2
|
||||
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 2}, # 3
|
||||
{'type': CmdRunAction, 'command': 'ls'}, # 4
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'file1',
|
||||
'command': 'ls',
|
||||
'cause_id': 4,
|
||||
}, # 5
|
||||
{'type': CmdRunAction, 'command': 'pwd'}, # 6
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': '/dir',
|
||||
'command': 'pwd',
|
||||
'cause_id': 6,
|
||||
}, # 7
|
||||
{'type': CmdRunAction, 'command': 'cat file1'}, # 8
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'content',
|
||||
'command': 'cat file1',
|
||||
'cause_id': 8,
|
||||
}, # 9
|
||||
]
|
||||
)
|
||||
mock_first_user_message.id = 1
|
||||
|
||||
# Calculation (RecallAction now essential):
|
||||
# History len = 9
|
||||
# Essentials = [user(1), recall_act(2), recall_obs(3)] (len=3)
|
||||
# Non-essential count = 9 - 3 = 6
|
||||
# num_recent_to_keep = max(1, 6 // 2) = 3
|
||||
# slice_start_index = 9 - 3 = 6
|
||||
# recent_events_slice = history[6:] = [obs2(7), cmd3(8), obs3(9)]
|
||||
# Validation: remove leading obs2(7). validated_slice = [cmd3(8), obs3(9)]
|
||||
# Final = essentials + validated_slice = [user(1), recall_act(2), recall_obs(3), cmd3(8), obs3(9)]
|
||||
# Expected IDs: [1, 2, 3, 8, 9]. Length 5.
|
||||
truncated_events = controller._apply_conversation_window()
|
||||
|
||||
assert len(truncated_events) == 5
|
||||
expected_ids = [1, 2, 3, 8, 9]
|
||||
actual_ids = [e.id for e in truncated_events]
|
||||
assert actual_ids == expected_ids
|
||||
|
||||
|
||||
def test_no_recall_observation(controller_fixture):
|
||||
controller, mock_first_user_message = controller_fixture
|
||||
|
||||
controller.state.history = create_events(
|
||||
[
|
||||
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
|
||||
{
|
||||
'type': MessageAction,
|
||||
'content': 'User Task 1',
|
||||
'source': EventSource.USER,
|
||||
}, # 2
|
||||
{'type': RecallAction, 'query': 'User Task 1'}, # 3 (Recall Action exists)
|
||||
# Recall Observation is missing
|
||||
{'type': CmdRunAction, 'command': 'ls'}, # 4
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'file1',
|
||||
'command': 'ls',
|
||||
'cause_id': 4,
|
||||
}, # 5
|
||||
{'type': CmdRunAction, 'command': 'pwd'}, # 6
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': '/dir',
|
||||
'command': 'pwd',
|
||||
'cause_id': 6,
|
||||
}, # 7
|
||||
{'type': CmdRunAction, 'command': 'cat file1'}, # 8
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'content',
|
||||
'command': 'cat file1',
|
||||
'cause_id': 8,
|
||||
}, # 9
|
||||
]
|
||||
)
|
||||
mock_first_user_message.id = 2
|
||||
|
||||
# Calculation (RecallAction essential only if RecallObs exists):
|
||||
# History len = 9
|
||||
# Essentials = [sys(1), user(2)] (len=2) - RecallObs missing, so RecallAction not essential here
|
||||
# Non-essential count = 9 - 2 = 7
|
||||
# num_recent_to_keep = max(1, 7 // 2) = 3
|
||||
# slice_start_index = 9 - 3 = 6
|
||||
# recent_events_slice = history[6:] = [obs2(7), cmd3(8), obs3(9)]
|
||||
# Validation: remove leading obs2(7). validated_slice = [cmd3(8), obs3(9)]
|
||||
# Final = essentials + validated_slice = [sys(1), user(2), recall_action(3), cmd_cat(8), obs_cat(9)]
|
||||
# Expected IDs: [1, 2, 3, 8, 9]. Length 5.
|
||||
truncated_events = controller._apply_conversation_window()
|
||||
|
||||
assert len(truncated_events) == 5
|
||||
expected_ids = [1, 2, 3, 8, 9]
|
||||
actual_ids = [e.id for e in truncated_events]
|
||||
assert actual_ids == expected_ids
|
||||
|
||||
|
||||
def test_short_history_no_truncation(controller_fixture):
|
||||
controller, mock_first_user_message = controller_fixture
|
||||
|
||||
history = create_events(
|
||||
[
|
||||
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
|
||||
{
|
||||
'type': MessageAction,
|
||||
'content': 'User Task 1',
|
||||
'source': EventSource.USER,
|
||||
}, # 2
|
||||
{'type': RecallAction, 'query': 'User Task 1'}, # 3
|
||||
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
|
||||
{'type': CmdRunAction, 'command': 'ls'}, # 5
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'file1',
|
||||
'command': 'ls',
|
||||
'cause_id': 5,
|
||||
}, # 6
|
||||
]
|
||||
)
|
||||
controller.state.history = history
|
||||
mock_first_user_message.id = 2
|
||||
|
||||
# Calculation (RecallAction now essential):
|
||||
# History len = 6
|
||||
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
|
||||
# Non-essential count = 6 - 4 = 2
|
||||
# num_recent_to_keep = max(1, 2 // 2) = 1
|
||||
# slice_start_index = 6 - 1 = 5
|
||||
# recent_events_slice = history[5:] = [obs1(6)]
|
||||
# Validation: remove leading obs1(6). validated_slice = []
|
||||
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)]
|
||||
# Expected IDs: [1, 2, 3, 4]. Length 4.
|
||||
truncated_events = controller._apply_conversation_window()
|
||||
|
||||
assert len(truncated_events) == 4
|
||||
expected_ids = [1, 2, 3, 4]
|
||||
actual_ids = [e.id for e in truncated_events]
|
||||
assert actual_ids == expected_ids
|
||||
|
||||
|
||||
def test_only_essential_events(controller_fixture):
|
||||
controller, mock_first_user_message = controller_fixture
|
||||
|
||||
history = create_events(
|
||||
[
|
||||
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
|
||||
{
|
||||
'type': MessageAction,
|
||||
'content': 'User Task 1',
|
||||
'source': EventSource.USER,
|
||||
}, # 2
|
||||
{'type': RecallAction, 'query': 'User Task 1'}, # 3
|
||||
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
|
||||
]
|
||||
)
|
||||
controller.state.history = history
|
||||
mock_first_user_message.id = 2
|
||||
|
||||
# Calculation (RecallAction now essential):
|
||||
# History len = 4
|
||||
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
|
||||
# Non-essential count = 4 - 4 = 0
|
||||
# num_recent_to_keep = max(1, 0 // 2) = 1
|
||||
# slice_start_index = 4 - 1 = 3
|
||||
# recent_events_slice = history[3:] = [recall_obs(4)]
|
||||
# Validation: remove leading recall_obs(4). validated_slice = []
|
||||
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)]
|
||||
# Expected IDs: [1, 2, 3, 4]. Length 4.
|
||||
truncated_events = controller._apply_conversation_window()
|
||||
|
||||
assert len(truncated_events) == 4
|
||||
expected_ids = [1, 2, 3, 4]
|
||||
actual_ids = [e.id for e in truncated_events]
|
||||
assert actual_ids == expected_ids
|
||||
|
||||
|
||||
def test_dangling_observations_at_cut_point(controller_fixture):
|
||||
controller, mock_first_user_message = controller_fixture
|
||||
|
||||
history_forced_dangle = create_events(
|
||||
[
|
||||
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
|
||||
{
|
||||
'type': MessageAction,
|
||||
'content': 'User Task 1',
|
||||
'source': EventSource.USER,
|
||||
}, # 2
|
||||
{'type': RecallAction, 'query': 'User Task 1'}, # 3
|
||||
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
|
||||
# --- Slice calculation should start here ---
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'dangle1',
|
||||
'command': 'cmd_unknown',
|
||||
}, # 5 (Dangling)
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'dangle2',
|
||||
'command': 'cmd_unknown',
|
||||
}, # 6 (Dangling)
|
||||
{'type': CmdRunAction, 'command': 'cmd1'}, # 7
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'obs1',
|
||||
'command': 'cmd1',
|
||||
'cause_id': 7,
|
||||
}, # 8
|
||||
{'type': CmdRunAction, 'command': 'cmd2'}, # 9
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'obs2',
|
||||
'command': 'cmd2',
|
||||
'cause_id': 9,
|
||||
}, # 10
|
||||
]
|
||||
) # 10 events total
|
||||
controller.state.history = history_forced_dangle
|
||||
mock_first_user_message.id = 2
|
||||
|
||||
# Calculation (RecallAction now essential):
|
||||
# History len = 10
|
||||
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
|
||||
# Non-essential count = 10 - 4 = 6
|
||||
# num_recent_to_keep = max(1, 6 // 2) = 3
|
||||
# slice_start_index = 10 - 3 = 7
|
||||
# recent_events_slice = history[7:] = [obs1(8), cmd2(9), obs2(10)]
|
||||
# Validation: remove leading obs1(8). validated_slice = [cmd2(9), obs2(10)]
|
||||
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4), cmd2(9), obs2(10)]
|
||||
# Expected IDs: [1, 2, 3, 4, 9, 10]. Length 6.
|
||||
truncated_events = controller._apply_conversation_window()
|
||||
|
||||
assert len(truncated_events) == 6
|
||||
expected_ids = [1, 2, 3, 4, 9, 10]
|
||||
actual_ids = [e.id for e in truncated_events]
|
||||
assert actual_ids == expected_ids
|
||||
# Verify dangling observations 5 and 6 were removed (implicitly by slice start and validation)
|
||||
|
||||
|
||||
def test_only_dangling_observations_in_recent_slice(controller_fixture):
|
||||
controller, mock_first_user_message = controller_fixture
|
||||
|
||||
history = create_events(
|
||||
[
|
||||
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
|
||||
{
|
||||
'type': MessageAction,
|
||||
'content': 'User Task 1',
|
||||
'source': EventSource.USER,
|
||||
}, # 2
|
||||
{'type': RecallAction, 'query': 'User Task 1'}, # 3
|
||||
{'type': RecallObservation, 'content': 'Recall result', 'cause_id': 3}, # 4
|
||||
# --- Slice calculation should start here ---
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'dangle1',
|
||||
'command': 'cmd_unknown',
|
||||
}, # 5 (Dangling)
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'dangle2',
|
||||
'command': 'cmd_unknown',
|
||||
}, # 6 (Dangling)
|
||||
]
|
||||
) # 6 events total
|
||||
controller.state.history = history
|
||||
mock_first_user_message.id = 2
|
||||
|
||||
# Calculation (RecallAction now essential):
|
||||
# History len = 6
|
||||
# Essentials = [sys(1), user(2), recall_act(3), recall_obs(4)] (len=4)
|
||||
# Non-essential count = 6 - 4 = 2
|
||||
# num_recent_to_keep = max(1, 2 // 2) = 1
|
||||
# slice_start_index = 6 - 1 = 5
|
||||
# recent_events_slice = history[5:] = [dangle2(6)]
|
||||
# Validation: remove leading dangle2(6). validated_slice = [] (Corrected based on user feedback/bugfix)
|
||||
# Final = essentials + validated_slice = [sys(1), user(2), recall_act(3), recall_obs(4)]
|
||||
# Expected IDs: [1, 2, 3, 4]. Length 4.
|
||||
with patch(
|
||||
'openhands.controller.agent_controller.logger.warning'
|
||||
) as mock_log_warning:
|
||||
truncated_events = controller._apply_conversation_window()
|
||||
|
||||
assert len(truncated_events) == 4
|
||||
expected_ids = [1, 2, 3, 4]
|
||||
actual_ids = [e.id for e in truncated_events]
|
||||
assert actual_ids == expected_ids
|
||||
# Verify dangling observations 5 and 6 were removed
|
||||
|
||||
# Check that the specific warning was logged exactly once
|
||||
assert mock_log_warning.call_count == 1
|
||||
|
||||
# Check the essential parts of the arguments, allowing for variations like stacklevel
|
||||
call_args, call_kwargs = mock_log_warning.call_args
|
||||
expected_message_substring = 'All recent events are dangling observations, which we truncate. This means the agent has only the essential first events. This should not happen.'
|
||||
assert expected_message_substring in call_args[0]
|
||||
assert 'extra' in call_kwargs
|
||||
assert call_kwargs['extra'].get('session_id') == 'test_sid'
|
||||
|
||||
|
||||
def test_empty_history(controller_fixture):
|
||||
controller, _ = controller_fixture
|
||||
controller.state.history = []
|
||||
|
||||
truncated_events = controller._apply_conversation_window()
|
||||
assert truncated_events == []
|
||||
|
||||
|
||||
def test_multiple_user_messages(controller_fixture):
|
||||
controller, mock_first_user_message = controller_fixture
|
||||
|
||||
history = create_events(
|
||||
[
|
||||
{'type': SystemMessageAction, 'content': 'System Prompt'}, # 1
|
||||
{
|
||||
'type': MessageAction,
|
||||
'content': 'User Task 1',
|
||||
'source': EventSource.USER,
|
||||
}, # 2 (First)
|
||||
{'type': RecallAction, 'query': 'User Task 1'}, # 3
|
||||
{
|
||||
'type': RecallObservation,
|
||||
'content': 'Recall result 1',
|
||||
'cause_id': 3,
|
||||
}, # 4
|
||||
{'type': CmdRunAction, 'command': 'cmd1'}, # 5
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'obs1',
|
||||
'command': 'cmd1',
|
||||
'cause_id': 5,
|
||||
}, # 6
|
||||
{
|
||||
'type': MessageAction,
|
||||
'content': 'User Task 2',
|
||||
'source': EventSource.USER,
|
||||
}, # 7 (Second)
|
||||
{'type': RecallAction, 'query': 'User Task 2'}, # 8
|
||||
{
|
||||
'type': RecallObservation,
|
||||
'content': 'Recall result 2',
|
||||
'cause_id': 8,
|
||||
}, # 9
|
||||
{'type': CmdRunAction, 'command': 'cmd2'}, # 10
|
||||
{
|
||||
'type': CmdOutputObservation,
|
||||
'content': 'obs2',
|
||||
'command': 'cmd2',
|
||||
'cause_id': 10,
|
||||
}, # 11
|
||||
]
|
||||
) # 11 events total
|
||||
controller.state.history = history
|
||||
mock_first_user_message.id = 2 # Explicitly set the first user message ID
|
||||
|
||||
# Calculation (RecallAction now essential):
|
||||
# History len = 11
|
||||
# Essentials = [sys(1), user1(2), recall_act1(3), recall_obs1(4)] (len=4)
|
||||
# Non-essential count = 11 - 4 = 7
|
||||
# num_recent_to_keep = max(1, 7 // 2) = 3
|
||||
# slice_start_index = 11 - 3 = 8
|
||||
# recent_events_slice = history[8:] = [recall_obs2(9), cmd2(10), obs2(11)]
|
||||
# Validation: remove leading recall_obs2(9). validated_slice = [cmd2(10), obs2(11)]
|
||||
# Final = essentials + validated_slice = [sys(1), user1(2), recall_act1(3), recall_obs1(4)] + [cmd2(10), obs2(11)]
|
||||
# Expected IDs: [1, 2, 3, 4, 10, 11]. Length 6.
|
||||
truncated_events = controller._apply_conversation_window()
|
||||
|
||||
assert len(truncated_events) == 6
|
||||
expected_ids = [1, 2, 3, 4, 10, 11]
|
||||
actual_ids = [e.id for e in truncated_events]
|
||||
assert actual_ids == expected_ids
|
||||
|
||||
# Verify the second user message (ID 7) was NOT kept
|
||||
assert not any(event.id == 7 for event in truncated_events)
|
||||
# Verify the first user message (ID 2) is present
|
||||
assert any(event.id == 2 for event in truncated_events)
|
||||
@@ -44,6 +44,7 @@ from openhands.events.observation.commands import (
|
||||
)
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.condenser import View
|
||||
|
||||
|
||||
@pytest.fixture(params=['CodeActAgent', 'ReadOnlyAgent'])
|
||||
@@ -97,6 +98,12 @@ def test_reset(agent):
|
||||
action._source = EventSource.AGENT
|
||||
agent.pending_actions.append(action)
|
||||
|
||||
# Create a mock state with initial user message
|
||||
mock_state = Mock(spec=State)
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
mock_state.history = [initial_user_message]
|
||||
|
||||
# Reset
|
||||
agent.reset()
|
||||
|
||||
@@ -110,8 +117,14 @@ def test_step_with_pending_actions(agent):
|
||||
pending_action._source = EventSource.AGENT
|
||||
agent.pending_actions.append(pending_action)
|
||||
|
||||
# Create a mock state with initial user message
|
||||
mock_state = Mock(spec=State)
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
mock_state.history = [initial_user_message]
|
||||
|
||||
# Step should return the pending action
|
||||
result = agent.step(Mock())
|
||||
result = agent.step(mock_state)
|
||||
assert result == pending_action
|
||||
assert len(agent.pending_actions) == 0
|
||||
|
||||
@@ -260,6 +273,11 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
mock_state.latest_user_message_llm_metrics = None
|
||||
mock_state.latest_user_message_tool_call_metadata = None
|
||||
|
||||
# Add initial user message to history
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
mock_state.history = [initial_user_message]
|
||||
|
||||
action = agent.step(mock_state)
|
||||
assert isinstance(action, MessageAction)
|
||||
assert action.content == 'Task completed'
|
||||
@@ -330,42 +348,56 @@ def test_mismatched_tool_call_events_and_auto_add_system_message(
|
||||
)
|
||||
|
||||
action = CmdRunAction('foo')
|
||||
action._source = 'agent'
|
||||
action._source = EventSource.AGENT
|
||||
action.tool_call_metadata = tool_call_metadata
|
||||
|
||||
observation = CmdOutputObservation(content='', command_id=0, command='foo')
|
||||
observation.tool_call_metadata = tool_call_metadata
|
||||
|
||||
# Add initial user message
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
|
||||
# When both events are provided, the agent should get three messages:
|
||||
# 1. The system message (added automatically for backward compatibility)
|
||||
# 2. The action message
|
||||
# 3. The observation message
|
||||
mock_state.history = [action, observation]
|
||||
messages = agent._get_messages(mock_state.history)
|
||||
assert len(messages) == 3
|
||||
mock_state.history = [initial_user_message, action, observation]
|
||||
messages = agent._get_messages(mock_state.history, initial_user_message)
|
||||
assert len(messages) == 4 # System + initial user + action + observation
|
||||
assert messages[0].role == 'system' # First message should be the system message
|
||||
assert messages[1].role == 'assistant' # Second message should be the action
|
||||
assert messages[2].role == 'tool' # Third message should be the observation
|
||||
assert (
|
||||
messages[1].role == 'user'
|
||||
) # Second message should be the initial user message
|
||||
assert messages[2].role == 'assistant' # Third message should be the action
|
||||
assert messages[3].role == 'tool' # Fourth message should be the observation
|
||||
|
||||
# The same should hold if the events are presented out-of-order
|
||||
mock_state.history = [observation, action]
|
||||
messages = agent._get_messages(mock_state.history)
|
||||
assert len(messages) == 3
|
||||
mock_state.history = [initial_user_message, observation, action]
|
||||
messages = agent._get_messages(mock_state.history, initial_user_message)
|
||||
assert len(messages) == 4
|
||||
assert messages[0].role == 'system' # First message should be the system message
|
||||
assert (
|
||||
messages[1].role == 'user'
|
||||
) # Second message should be the initial user message
|
||||
|
||||
# If only one of the two events is present, then we should just get the system message
|
||||
# plus any valid message from the event
|
||||
mock_state.history = [action]
|
||||
messages = agent._get_messages(mock_state.history)
|
||||
mock_state.history = [initial_user_message, action]
|
||||
messages = agent._get_messages(mock_state.history, initial_user_message)
|
||||
assert (
|
||||
len(messages) == 1
|
||||
) # Only system message, action is waiting for its observation
|
||||
len(messages) == 2
|
||||
) # System + initial user message, action is waiting for its observation
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[1].role == 'user'
|
||||
|
||||
mock_state.history = [observation]
|
||||
messages = agent._get_messages(mock_state.history)
|
||||
assert len(messages) == 1 # Only system message, observation has no matching action
|
||||
mock_state.history = [initial_user_message, observation]
|
||||
messages = agent._get_messages(mock_state.history, initial_user_message)
|
||||
assert (
|
||||
len(messages) == 2
|
||||
) # System + initial user message, observation has no matching action
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[1].role == 'user'
|
||||
|
||||
|
||||
def test_grep_tool():
|
||||
@@ -470,3 +502,19 @@ def test_get_system_message():
|
||||
assert len(result.tools) > 0
|
||||
assert any(tool['function']['name'] == 'execute_bash' for tool in result.tools)
|
||||
assert result._source == EventSource.AGENT
|
||||
|
||||
|
||||
def test_step_raises_error_if_no_initial_user_message(
|
||||
agent: CodeActAgent, mock_state: State
|
||||
):
|
||||
"""Tests that step raises ValueError if the initial user message is not found."""
|
||||
# Ensure history does NOT contain a user MessageAction
|
||||
assistant_message = MessageAction(content='Assistant message')
|
||||
assistant_message._source = EventSource.AGENT
|
||||
mock_state.history = [assistant_message]
|
||||
# Mock the condenser to return the history as is
|
||||
agent.condenser = Mock()
|
||||
agent.condenser.condensed_history.return_value = View(events=mock_state.history)
|
||||
|
||||
with pytest.raises(ValueError, match='Initial user message not found'):
|
||||
agent.step(mock_state)
|
||||
|
||||
@@ -100,6 +100,7 @@ def test_process_events_with_message_action(conversation_memory):
|
||||
# Process events
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[system_message, user_message, assistant_message],
|
||||
initial_user_action=user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
@@ -108,10 +109,178 @@ def test_process_events_with_message_action(conversation_memory):
|
||||
assert len(messages) == 3
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[0].content[0].text == 'System message'
|
||||
|
||||
|
||||
# Test cases for _ensure_system_message
|
||||
def test_ensure_system_message_adds_if_missing(conversation_memory):
|
||||
"""Test that _ensure_system_message adds a system message if none exists."""
|
||||
user_message = MessageAction(content='User message')
|
||||
user_message._source = EventSource.USER
|
||||
events = [user_message]
|
||||
conversation_memory._ensure_system_message(events)
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], SystemMessageAction)
|
||||
assert events[0].content == 'System message' # From fixture
|
||||
assert isinstance(events[1], MessageAction) # Original event is still there
|
||||
|
||||
|
||||
def test_ensure_system_message_does_nothing_if_present(conversation_memory):
|
||||
"""Test that _ensure_system_message does nothing if a system message is already present."""
|
||||
original_system_message = SystemMessageAction(content='Existing system message')
|
||||
user_message = MessageAction(content='User message')
|
||||
user_message._source = EventSource.USER
|
||||
events = [
|
||||
original_system_message,
|
||||
user_message,
|
||||
]
|
||||
original_events = list(events) # Copy before modification
|
||||
conversation_memory._ensure_system_message(events)
|
||||
assert events == original_events # List should be unchanged
|
||||
|
||||
|
||||
# Test cases for _ensure_initial_user_message
|
||||
@pytest.fixture
|
||||
def initial_user_action():
|
||||
msg = MessageAction(content='Initial User Message')
|
||||
msg._source = EventSource.USER
|
||||
return msg
|
||||
|
||||
|
||||
def test_ensure_initial_user_message_adds_if_only_system(
|
||||
conversation_memory, initial_user_action
|
||||
):
|
||||
"""Test adding the initial user message when only the system message exists."""
|
||||
system_message = SystemMessageAction(content='System')
|
||||
system_message._source = EventSource.AGENT
|
||||
events = [system_message]
|
||||
conversation_memory._ensure_initial_user_message(events, initial_user_action)
|
||||
assert len(events) == 2
|
||||
assert events[0] == system_message
|
||||
assert events[1] == initial_user_action
|
||||
|
||||
|
||||
def test_ensure_initial_user_message_correct_already_present(
|
||||
conversation_memory, initial_user_action
|
||||
):
|
||||
"""Test that nothing changes if the correct initial user message is at index 1."""
|
||||
system_message = SystemMessageAction(content='System')
|
||||
agent_message = MessageAction(content='Assistant')
|
||||
agent_message._source = EventSource.USER
|
||||
events = [
|
||||
system_message,
|
||||
initial_user_action,
|
||||
agent_message,
|
||||
]
|
||||
original_events = list(events)
|
||||
conversation_memory._ensure_initial_user_message(events, initial_user_action)
|
||||
assert events == original_events
|
||||
|
||||
|
||||
def test_ensure_initial_user_message_incorrect_at_index_1(
|
||||
conversation_memory, initial_user_action
|
||||
):
|
||||
"""Test inserting the correct initial user message when an incorrect message is at index 1."""
|
||||
system_message = SystemMessageAction(content='System')
|
||||
incorrect_second_message = MessageAction(content='Assistant')
|
||||
incorrect_second_message._source = EventSource.AGENT
|
||||
events = [system_message, incorrect_second_message]
|
||||
conversation_memory._ensure_initial_user_message(events, initial_user_action)
|
||||
assert len(events) == 3
|
||||
assert events[0] == system_message
|
||||
assert events[1] == initial_user_action # Correct one inserted
|
||||
assert events[2] == incorrect_second_message # Original second message shifted
|
||||
|
||||
|
||||
def test_ensure_initial_user_message_correct_present_later(
|
||||
conversation_memory, initial_user_action
|
||||
):
|
||||
"""Test inserting the correct initial user message at index 1 even if it exists later."""
|
||||
system_message = SystemMessageAction(content='System')
|
||||
incorrect_second_message = MessageAction(content='Assistant')
|
||||
incorrect_second_message._source = EventSource.AGENT
|
||||
# Correct initial message is present, but later in the list
|
||||
events = [system_message, incorrect_second_message]
|
||||
conversation_memory._ensure_system_message(events)
|
||||
conversation_memory._ensure_initial_user_message(events, initial_user_action)
|
||||
assert len(events) == 3 # Should still insert at index 1, not remove the later one
|
||||
assert events[0] == system_message
|
||||
assert events[1] == initial_user_action # Correct one inserted at index 1
|
||||
assert events[2] == incorrect_second_message # Original second message shifted
|
||||
# The duplicate initial_user_action originally at index 2 is now at index 3 (implicitly tested by length and content)
|
||||
|
||||
|
||||
def test_ensure_initial_user_message_different_user_msg_at_index_1(
|
||||
conversation_memory, initial_user_action
|
||||
):
|
||||
"""Test inserting the correct initial user message when a *different* user message is at index 1."""
|
||||
system_message = SystemMessageAction(content='System')
|
||||
different_user_message = MessageAction(content='Different User Message')
|
||||
different_user_message._source = EventSource.USER
|
||||
events = [system_message, different_user_message]
|
||||
conversation_memory._ensure_initial_user_message(events, initial_user_action)
|
||||
assert len(events) == 2
|
||||
assert events[0] == system_message
|
||||
assert events[1] == different_user_message # Original second message remains
|
||||
|
||||
|
||||
def test_ensure_initial_user_message_different_user_msg_at_index_1_and_orphaned_obs(
|
||||
conversation_memory, initial_user_action
|
||||
):
|
||||
"""
|
||||
Test process_events when an incorrect user message is at index 1 AND
|
||||
an orphaned observation (with tool_call_metadata but no matching action) exists.
|
||||
Expect: System msg, CORRECT initial user msg, the incorrect user msg (shifted).
|
||||
The orphaned observation should be filtered out.
|
||||
"""
|
||||
system_message = SystemMessageAction(content='System')
|
||||
different_user_message = MessageAction(content='Different User Message')
|
||||
different_user_message._source = EventSource.USER
|
||||
|
||||
# Create an orphaned observation (no matching action/tool call request will exist)
|
||||
# Use a dictionary that mimics ModelResponse structure to satisfy Pydantic
|
||||
mock_response = {
|
||||
'id': 'mock_response_id',
|
||||
'choices': [{'message': {'content': None, 'tool_calls': []}}],
|
||||
'created': 0,
|
||||
'model': '',
|
||||
'object': '',
|
||||
'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0},
|
||||
}
|
||||
orphaned_obs = CmdOutputObservation(
|
||||
command='orphan_cmd',
|
||||
content='Orphaned output',
|
||||
command_id=99,
|
||||
exit_code=0,
|
||||
)
|
||||
orphaned_obs.tool_call_metadata = ToolCallMetadata(
|
||||
tool_call_id='orphan_call_id',
|
||||
function_name='execute_bash',
|
||||
model_response=mock_response,
|
||||
total_calls_in_response=1,
|
||||
)
|
||||
|
||||
# Initial events list: system, wrong user message, orphaned observation
|
||||
events = [system_message, different_user_message, orphaned_obs]
|
||||
|
||||
# Call the main process_events method
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=events,
|
||||
initial_user_action=initial_user_action, # Provide the *correct* initial action
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Assertions on the final messages list
|
||||
assert len(messages) == 2
|
||||
# 1. System message should be first
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[0].content[0].text == 'System'
|
||||
|
||||
# 2. The different user message should be left at index 1
|
||||
assert messages[1].role == 'user'
|
||||
assert messages[1].content[0].text == 'Hello'
|
||||
assert messages[2].role == 'assistant'
|
||||
assert messages[2].content[0].text == 'Hi there'
|
||||
assert messages[1].content[0].text == different_user_message.content
|
||||
|
||||
# Implicitly assert that the orphaned_obs was filtered out by checking the length (2)
|
||||
|
||||
|
||||
def test_process_events_with_cmd_output_observation(conversation_memory):
|
||||
@@ -125,14 +294,17 @@ def test_process_events_with_cmd_output_observation(conversation_memory):
|
||||
),
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -148,14 +320,17 @@ def test_process_events_with_ipython_run_cell_observation(conversation_memory):
|
||||
content='IPython output\n',
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -172,14 +347,17 @@ def test_process_events_with_agent_delegate_observation(conversation_memory):
|
||||
content='Content', outputs={'content': 'Delegated agent output'}
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -189,14 +367,17 @@ def test_process_events_with_agent_delegate_observation(conversation_memory):
|
||||
def test_process_events_with_error_observation(conversation_memory):
|
||||
obs = ErrorObservation('Error message')
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -207,10 +388,13 @@ def test_process_events_with_error_observation(conversation_memory):
|
||||
def test_process_events_with_unknown_observation(conversation_memory):
|
||||
# Create a mock that inherits from Event but not Action or Observation
|
||||
obs = Mock(spec=Event)
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
|
||||
with pytest.raises(ValueError, match='Unknown event type'):
|
||||
conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
@@ -226,14 +410,17 @@ def test_process_events_with_file_edit_observation(conversation_memory):
|
||||
impl_source=FileEditSource.LLM_BASED_EDIT,
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -247,18 +434,21 @@ def test_process_events_with_file_read_observation(conversation_memory):
|
||||
impl_source=FileReadSource.DEFAULT,
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert result.content[0].text == 'File content'
|
||||
assert result.content[0].text == '\n\nFile content'
|
||||
|
||||
|
||||
def test_process_events_with_browser_output_observation(conversation_memory):
|
||||
@@ -270,14 +460,17 @@ def test_process_events_with_browser_output_observation(conversation_memory):
|
||||
error=False,
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -287,14 +480,17 @@ def test_process_events_with_browser_output_observation(conversation_memory):
|
||||
def test_process_events_with_user_reject_observation(conversation_memory):
|
||||
obs = UserRejectObservation('Action rejected')
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + initial user + result
|
||||
result = messages[2] # The actual result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -317,14 +513,17 @@ def test_process_events_with_empty_environment_info(conversation_memory):
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
initial_user_message = MessageAction(content='Initial user message')
|
||||
initial_user_message._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[empty_obs],
|
||||
initial_user_action=initial_user_message,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Should only contain no messages except system message
|
||||
assert len(messages) == 1
|
||||
# Should only contain system message and initial user message
|
||||
assert len(messages) == 2
|
||||
|
||||
# Verify that build_workspace_context was NOT called since all input values were empty
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
|
||||
@@ -348,14 +547,20 @@ def test_process_events_with_function_calling_observation(conversation_memory):
|
||||
model_response=mock_response,
|
||||
total_calls_in_response=1,
|
||||
)
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# No direct message when using function calling
|
||||
assert len(messages) == 1 # should be no messages except system message
|
||||
assert (
|
||||
len(messages) == 2
|
||||
) # should be no messages except system message and initial user message
|
||||
|
||||
|
||||
def test_process_events_with_message_action_with_image(conversation_memory):
|
||||
@@ -365,14 +570,18 @@ def test_process_events_with_message_action_with_image(conversation_memory):
|
||||
)
|
||||
action._source = EventSource.AGENT
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[action],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=True,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3
|
||||
result = messages[2]
|
||||
assert result.role == 'assistant'
|
||||
assert len(result.content) == 2
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -385,14 +594,18 @@ def test_process_events_with_user_cmd_action(conversation_memory):
|
||||
action = CmdRunAction(command='ls -l')
|
||||
action._source = EventSource.USER
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[action],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3
|
||||
result = messages[2]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -418,14 +631,18 @@ def test_process_events_with_agent_finish_action_with_tool_metadata(
|
||||
total_calls_in_response=1,
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[action],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3
|
||||
result = messages[2]
|
||||
assert result.role == 'assistant'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -461,18 +678,22 @@ def test_process_events_with_environment_microagent_observation(conversation_mem
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3
|
||||
result = messages[2]
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
assert result.content[0].text == 'Formatted repository and runtime info'
|
||||
assert result.content[0].text == '\n\nFormatted repository and runtime info'
|
||||
|
||||
# Verify the prompt_manager was called with the correct parameters
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_called_once()
|
||||
@@ -516,14 +737,18 @@ def test_process_events_with_knowledge_microagent_microagent_observation(
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
assert len(messages) == 2
|
||||
result = messages[1]
|
||||
assert len(messages) == 3 # System + Initial User + Result
|
||||
result = messages[2] # Result is now at index 2
|
||||
assert result.role == 'user'
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], TextContent)
|
||||
@@ -559,14 +784,18 @@ def test_process_events_with_microagent_observation_extensions_disabled(
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# When prompt extensions are disabled, the RecallObservation should be ignored
|
||||
assert len(messages) == 1 # should be no messages except system message
|
||||
assert len(messages) == 2 # System + Initial User
|
||||
|
||||
# Verify the prompt_manager was not called
|
||||
conversation_memory.prompt_manager.build_workspace_context.assert_not_called()
|
||||
@@ -581,14 +810,18 @@ def test_process_events_with_empty_microagent_knowledge(conversation_memory):
|
||||
content='Retrieved knowledge from microagents',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# The implementation returns an empty string and it doesn't creates a message
|
||||
assert len(messages) == 1 # should be no messages except system message
|
||||
assert len(messages) == 2 # System + Initial User
|
||||
|
||||
# When there are no triggered agents, build_microagent_info is not called
|
||||
conversation_memory.prompt_manager.build_microagent_info.assert_not_called()
|
||||
@@ -793,19 +1026,23 @@ def test_process_events_with_microagent_observation_deduplication(conversation_m
|
||||
content='Third retrieval',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs1, obs2, obs3],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that only the first occurrence of content for each agent is included
|
||||
assert len(messages) == 2 # with system message
|
||||
|
||||
assert len(messages) == 3 # System + Initial User + Result
|
||||
# Result is now at index 2
|
||||
# First microagent should include all agents since they appear here first
|
||||
assert 'Image best practices v1' in messages[1].content[0].text
|
||||
assert 'Git best practices v1' in messages[1].content[0].text
|
||||
assert 'Python best practices v1' in messages[1].content[0].text
|
||||
assert 'Image best practices v1' in messages[2].content[0].text
|
||||
assert 'Git best practices v1' in messages[2].content[0].text
|
||||
assert 'Python best practices v1' in messages[2].content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication_disabled_agents(
|
||||
@@ -842,18 +1079,22 @@ def test_process_events_with_microagent_observation_deduplication_disabled_agent
|
||||
content='Second retrieval',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs1, obs2],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that disabled agents are filtered out and only the first occurrence of enabled agents is included
|
||||
assert len(messages) == 2
|
||||
|
||||
assert len(messages) == 3 # System + Initial User + Result
|
||||
# Result is now at index 2
|
||||
# First microagent should include enabled_agent but not disabled_agent
|
||||
assert 'Disabled agent content' not in messages[1].content[0].text
|
||||
assert 'Enabled agent content v1' in messages[1].content[0].text
|
||||
assert 'Disabled agent content' not in messages[2].content[0].text
|
||||
assert 'Enabled agent content v1' in messages[2].content[0].text
|
||||
|
||||
|
||||
def test_process_events_with_microagent_observation_deduplication_empty(
|
||||
@@ -866,17 +1107,22 @@ def test_process_events_with_microagent_observation_deduplication_empty(
|
||||
content='Empty retrieval',
|
||||
)
|
||||
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[obs],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Verify that empty RecallObservations are handled gracefully
|
||||
assert (
|
||||
len(messages) == 1
|
||||
) # an empty microagent is not added to Messages, only system message is found
|
||||
len(messages) == 2 # System + Initial User
|
||||
) # an empty microagent is not added to Messages
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[1].role == 'user' # Initial user message
|
||||
|
||||
|
||||
def test_has_agent_in_earlier_events(conversation_memory):
|
||||
@@ -1088,13 +1334,183 @@ def test_system_message_in_events(conversation_memory):
|
||||
system_message._source = EventSource.AGENT
|
||||
|
||||
# Process events with the system message in condensed_history
|
||||
# Define initial user action
|
||||
initial_user_action = MessageAction(content='Initial user message')
|
||||
initial_user_action._source = EventSource.USER
|
||||
messages = conversation_memory.process_events(
|
||||
condensed_history=[system_message],
|
||||
initial_user_action=initial_user_action,
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Check that the system message was processed correctly
|
||||
assert len(messages) == 1
|
||||
assert len(messages) == 2 # System + Initial User
|
||||
assert messages[0].role == 'system'
|
||||
assert messages[0].content[0].text == 'System message'
|
||||
assert messages[1].role == 'user' # Initial user message
|
||||
|
||||
|
||||
# Helper function to create mock tool call metadata
|
||||
def _create_mock_tool_call_metadata(
|
||||
tool_call_id: str, function_name: str, response_id: str = 'mock_response_id'
|
||||
) -> ToolCallMetadata:
|
||||
# Use a dictionary that mimics ModelResponse structure to satisfy Pydantic
|
||||
mock_response = {
|
||||
'id': response_id,
|
||||
'choices': [
|
||||
{
|
||||
'message': {
|
||||
'role': 'assistant',
|
||||
'content': None, # Content is None for tool calls
|
||||
'tool_calls': [
|
||||
{
|
||||
'id': tool_call_id,
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': function_name,
|
||||
'arguments': '{}',
|
||||
}, # Args don't matter for this test
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
],
|
||||
'created': 0,
|
||||
'model': 'mock_model',
|
||||
'object': 'chat.completion',
|
||||
'usage': {'completion_tokens': 0, 'prompt_tokens': 0, 'total_tokens': 0},
|
||||
}
|
||||
return ToolCallMetadata(
|
||||
tool_call_id=tool_call_id,
|
||||
function_name=function_name,
|
||||
model_response=mock_response,
|
||||
total_calls_in_response=1,
|
||||
)
|
||||
|
||||
|
||||
def test_process_events_partial_history(conversation_memory):
|
||||
"""
|
||||
Tests process_events with full and partial histories to verify
|
||||
_ensure_system_message, _ensure_initial_user_message, and tool call matching logic.
|
||||
"""
|
||||
# --- Define Common Events ---
|
||||
system_message = SystemMessageAction(content='System message')
|
||||
system_message._source = EventSource.AGENT
|
||||
|
||||
user_message = MessageAction(
|
||||
content='Initial user query'
|
||||
) # This is the crucial initial_user_action
|
||||
user_message._source = EventSource.USER
|
||||
|
||||
recall_obs = RecallObservation(
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='test-repo',
|
||||
repo_directory='/path/to/repo',
|
||||
content='Retrieved environment info',
|
||||
)
|
||||
recall_obs._source = EventSource.AGENT
|
||||
|
||||
cmd_action = CmdRunAction(command='ls', thought='Running ls')
|
||||
cmd_action._source = EventSource.AGENT
|
||||
cmd_action.tool_call_metadata = _create_mock_tool_call_metadata(
|
||||
tool_call_id='call_ls_1', function_name='execute_bash', response_id='resp_ls_1'
|
||||
)
|
||||
|
||||
cmd_obs = CmdOutputObservation(
|
||||
command_id=1, command='ls', content='file1.txt\nfile2.py', exit_code=0
|
||||
)
|
||||
cmd_obs._source = EventSource.AGENT
|
||||
cmd_obs.tool_call_metadata = _create_mock_tool_call_metadata(
|
||||
tool_call_id='call_ls_1', function_name='execute_bash', response_id='resp_ls_1'
|
||||
)
|
||||
|
||||
# --- Scenario 1: Full History ---
|
||||
full_history: list[Event] = [
|
||||
system_message,
|
||||
user_message, # Correct initial user message at index 1
|
||||
recall_obs,
|
||||
cmd_action,
|
||||
cmd_obs,
|
||||
]
|
||||
messages_full = conversation_memory.process_events(
|
||||
condensed_history=list(full_history), # Pass a copy
|
||||
initial_user_action=user_message, # Provide the initial action
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Expected: System, User, Recall (formatted), Assistant (tool call), Tool Response
|
||||
assert len(messages_full) == 5
|
||||
assert messages_full[0].role == 'system'
|
||||
assert messages_full[0].content[0].text == 'System message'
|
||||
assert messages_full[1].role == 'user'
|
||||
assert messages_full[1].content[0].text == 'Initial user query'
|
||||
assert messages_full[2].role == 'user' # Recall obs becomes user message
|
||||
assert (
|
||||
'Formatted repository and runtime info' in messages_full[2].content[0].text
|
||||
) # From fixture mock
|
||||
assert messages_full[3].role == 'assistant'
|
||||
assert messages_full[3].tool_calls is not None
|
||||
assert len(messages_full[3].tool_calls) == 1
|
||||
assert messages_full[3].tool_calls[0].id == 'call_ls_1'
|
||||
assert messages_full[4].role == 'tool'
|
||||
assert messages_full[4].tool_call_id == 'call_ls_1'
|
||||
assert 'file1.txt' in messages_full[4].content[0].text
|
||||
|
||||
# --- Scenario 2: Partial History (Action + Observation) ---
|
||||
# Simulates processing only the last action/observation pair
|
||||
partial_history_action_obs: list[Event] = [
|
||||
cmd_action,
|
||||
cmd_obs,
|
||||
]
|
||||
messages_partial_action_obs = conversation_memory.process_events(
|
||||
condensed_history=list(partial_history_action_obs), # Pass a copy
|
||||
initial_user_action=user_message, # Provide the initial action
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Expected: System (added), Initial User (added), Assistant (tool call), Tool Response
|
||||
assert len(messages_partial_action_obs) == 4
|
||||
assert (
|
||||
messages_partial_action_obs[0].role == 'system'
|
||||
) # Added by _ensure_system_message
|
||||
assert messages_partial_action_obs[0].content[0].text == 'System message'
|
||||
assert (
|
||||
messages_partial_action_obs[1].role == 'user'
|
||||
) # Added by _ensure_initial_user_message
|
||||
assert messages_partial_action_obs[1].content[0].text == 'Initial user query'
|
||||
assert messages_partial_action_obs[2].role == 'assistant'
|
||||
assert messages_partial_action_obs[2].tool_calls is not None
|
||||
assert len(messages_partial_action_obs[2].tool_calls) == 1
|
||||
assert messages_partial_action_obs[2].tool_calls[0].id == 'call_ls_1'
|
||||
assert messages_partial_action_obs[3].role == 'tool'
|
||||
assert messages_partial_action_obs[3].tool_call_id == 'call_ls_1'
|
||||
assert 'file1.txt' in messages_partial_action_obs[3].content[0].text
|
||||
|
||||
# --- Scenario 3: Partial History (Observation Only) ---
|
||||
# Simulates processing only the last observation
|
||||
partial_history_obs_only: list[Event] = [
|
||||
cmd_obs,
|
||||
]
|
||||
messages_partial_obs_only = conversation_memory.process_events(
|
||||
condensed_history=list(partial_history_obs_only), # Pass a copy
|
||||
initial_user_action=user_message, # Provide the initial action
|
||||
max_message_chars=None,
|
||||
vision_is_active=False,
|
||||
)
|
||||
|
||||
# Expected: System (added), Initial User (added).
|
||||
# The CmdOutputObservation has tool_call_metadata, but there's no corresponding
|
||||
# assistant message (from CmdRunAction) with the matching tool_call.id in the input history.
|
||||
# Therefore, _filter_unmatched_tool_calls should remove the tool response message.
|
||||
assert len(messages_partial_obs_only) == 2
|
||||
assert (
|
||||
messages_partial_obs_only[0].role == 'system'
|
||||
) # Added by _ensure_system_message
|
||||
assert messages_partial_obs_only[0].content[0].text == 'System message'
|
||||
assert (
|
||||
messages_partial_obs_only[1].role == 'user'
|
||||
) # Added by _ensure_initial_user_message
|
||||
assert messages_partial_obs_only[1].content[0].text == 'Initial user query'
|
||||
|
||||
@@ -76,7 +76,7 @@ def test_get_messages(codeact_agent: CodeActAgent):
|
||||
history.append(message_action_5)
|
||||
|
||||
codeact_agent.reset()
|
||||
messages = codeact_agent._get_messages(history)
|
||||
messages = codeact_agent._get_messages(history, message_action_1)
|
||||
|
||||
assert (
|
||||
len(messages) == 6
|
||||
@@ -106,16 +106,19 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
|
||||
history.append(system_message_action)
|
||||
|
||||
# Add multiple user and agent messages
|
||||
initial_user_message = None # Keep track of the first user message
|
||||
for i in range(15):
|
||||
message_action_user = MessageAction(f'User message {i}')
|
||||
message_action_user._source = 'user'
|
||||
if initial_user_message is None:
|
||||
initial_user_message = message_action_user # Store the first one
|
||||
history.append(message_action_user)
|
||||
message_action_agent = MessageAction(f'Agent message {i}')
|
||||
message_action_agent._source = 'agent'
|
||||
history.append(message_action_agent)
|
||||
|
||||
codeact_agent.reset()
|
||||
messages = codeact_agent._get_messages(history)
|
||||
messages = codeact_agent._get_messages(history, initial_user_message)
|
||||
|
||||
# Check that only the last two user messages have cache_prompt=True
|
||||
cached_user_messages = [
|
||||
|
||||
Reference in New Issue
Block a user