From 8dee3342360d462e9ae32d47406e00799dc316db Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Thu, 14 Nov 2024 03:42:39 +0100 Subject: [PATCH] Context Window Exceeded fix (#4977) --- openhands/controller/agent_controller.py | 132 +++++++++++++++- openhands/controller/state/state.py | 2 + tests/unit/test_truncation.py | 188 +++++++++++++++++++++++ 3 files changed, 320 insertions(+), 2 deletions(-) create mode 100644 tests/unit/test_truncation.py diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 78b27c89ff..e6f4e2eb3e 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -5,6 +5,7 @@ import traceback from typing import Callable, ClassVar, Type import litellm +from litellm.exceptions import ContextWindowExceededError from openhands.controller.agent import Agent from openhands.controller.state.state import State, TrafficControlState @@ -485,6 +486,15 @@ class AgentController: EventSource.AGENT, ) return + except ContextWindowExceededError: + # When context window is exceeded, keep roughly half of agent interactions + self.state.history = self._apply_conversation_window(self.state.history) + + # Save the ID of the first event in our truncated history for future reloading + if self.state.history: + self.state.start_id = self.state.history[0].id + # Don't add error event - let the agent retry with reduced context + return if action.runnable: if self.state.confirmation_mode and ( @@ -659,6 +669,12 @@ class AgentController: - For delegate events (between AgentDelegateAction and AgentDelegateObservation): - Excludes all events between the action and observation - Includes the delegate action and observation themselves + + The history is loaded in two parts if truncation_id is set: + 1. First user message from start_id onwards + 2. Rest of history from truncation_id to the end + + Otherwise loads normally from start_id. 
""" # define range of events to fetch @@ -680,8 +696,33 @@ class AgentController: self.state.history = [] return - # Get all events, filtering out backend events and hidden events - events = list( + events: list[Event] = [] + + # If we have a truncation point, get first user message and then rest of history + if hasattr(self.state, 'truncation_id') and self.state.truncation_id > 0: + # Find first user message from stream + first_user_msg = next( + ( + e + for e in self.event_stream.get_events( + start_id=start_id, + end_id=end_id, + reverse=False, + filter_out_type=self.filter_out, + filter_hidden=True, + ) + if isinstance(e, MessageAction) and e.source == EventSource.USER + ), + None, + ) + if first_user_msg: + events.append(first_user_msg) + + # the rest of the events are from the truncation point + start_id = self.state.truncation_id + + # Get rest of history + events_to_add = list( self.event_stream.get_events( start_id=start_id, end_id=end_id, @@ -690,6 +731,7 @@ class AgentController: filter_hidden=True, ) ) + events.extend(events_to_add) # Find all delegate action/observation pairs delegate_ranges: list[tuple[int, int]] = [] @@ -744,6 +786,92 @@ class AgentController: # make sure history is in sync self.state.start_id = start_id + def _apply_conversation_window(self, events: list[Event]) -> list[Event]: + """Cuts history roughly in half when context window is exceeded, preserving action-observation pairs + and ensuring the first user message is always included. + + The algorithm: + 1. Cut history in half + 2. Check first event in new history: + - If Observation: find and include its Action + - If MessageAction: ensure its related Action-Observation pair isn't split + 3. 
Always include the first user message + + Args: + events: List of events to filter + + Returns: + Filtered list of events keeping newest half while preserving pairs + """ + if not events: + return events + + # Find first user message - we'll need to ensure it's included + first_user_msg = next( + ( + e + for e in events + if isinstance(e, MessageAction) and e.source == EventSource.USER + ), + None, + ) + + # cut in half + mid_point = max(1, len(events) // 2) + kept_events = events[mid_point:] + + # Handle first event in truncated history + if kept_events: + i = 0 + while i < len(kept_events): + first_event = kept_events[i] + if isinstance(first_event, Observation) and first_event.cause: + # Find its action and include it + matching_action = next( + ( + e + for e in reversed(events[:mid_point]) + if isinstance(e, Action) and e.id == first_event.cause + ), + None, + ) + if matching_action: + kept_events = [matching_action] + kept_events + else: + self.log( + 'warning', + f'Found Observation without matching Action at id={first_event.id}', + ) + # drop this observation + kept_events = kept_events[1:] + break + + elif isinstance(first_event, MessageAction) or ( + isinstance(first_event, Action) + and first_event.source == EventSource.USER + ): + # if it's a message action or a user action, keep it and continue to find the next event + i += 1 + continue + + else: + # if it's an action with source == EventSource.AGENT, we're good + break + + # Save where to continue from in next reload + if kept_events: + self.state.truncation_id = kept_events[0].id + + # Ensure first user message is included + if first_user_msg and first_user_msg not in kept_events: + kept_events = [first_user_msg] + kept_events + + # start_id points to first user message + if first_user_msg: + self.state.start_id = first_user_msg.id + + return kept_events + def _is_stuck(self): """Checks if the agent or its delegate is stuck in a loop. 
"""Unit tests for history truncation after a context-window overflow."""

from unittest.mock import MagicMock

import pytest

from openhands.controller.agent_controller import AgentController
from openhands.events import EventSource
from openhands.events.action import CmdRunAction, MessageAction
from openhands.events.observation import CmdOutputObservation


@pytest.fixture
def mock_event_stream():
    """Event stream stub whose ``get_events`` yields nothing by default."""
    stream = MagicMock()
    # Mock get_events to return an empty list by default
    stream.get_events.return_value = []
    return stream


@pytest.fixture
def mock_agent():
    """Agent stub exposing the ``llm.config`` attribute chain."""
    agent = MagicMock()
    agent.llm = MagicMock()
    agent.llm.config = MagicMock()
    return agent


def _make_controller(agent, event_stream):
    """Build the controller configuration shared by every test below."""
    return AgentController(
        agent=agent,
        event_stream=event_stream,
        max_iterations=10,
        sid='test_truncation',
        confirmation_mode=False,
        headless_mode=True,
    )


def _user_message(content, event_id):
    """Create a user-sourced MessageAction with an explicit id."""
    msg = MessageAction(content=content, wait_for_response=False)
    msg._source = EventSource.USER
    msg._id = event_id
    return msg


def _cmd_pair(command, output, action_id):
    """Create a CmdRunAction and its observation with consecutive ids."""
    cmd = CmdRunAction(command=command)
    cmd._id = action_id
    obs = CmdOutputObservation(command=command, content=output, command_id=action_id)
    obs._id = action_id + 1
    obs._cause = action_id
    return cmd, obs


class TestTruncation:
    def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
        """The first message survives and no action/observation pair is split."""
        controller = _make_controller(mock_agent, mock_event_stream)

        # Create a sequence of events with IDs
        first_msg = _user_message('Hello, start task', 1)
        cmd1, obs1 = _cmd_pair('ls', 'file1.txt', 2)
        cmd2, obs2 = _cmd_pair('pwd', '/home', 4)
        events = [first_msg, cmd1, obs1, cmd2, obs2]

        # Apply truncation
        truncated = controller._apply_conversation_window(events)

        # Should keep first user message and roughly half of other events:
        # first message + at least one action-observation pair
        assert len(truncated) >= 3
        assert truncated[0] == first_msg  # First message always preserved
        assert controller.state.start_id == first_msg._id
        # The State default is -1, so "is not None" would always pass;
        # a positive id proves truncation actually recorded a restart point.
        assert controller.state.truncation_id > 0

        # Verify pairs aren't split
        for i, event in enumerate(truncated[1:]):
            if isinstance(event, CmdOutputObservation):
                assert any(e._id == event._cause for e in truncated[: i + 1])

    def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
        """Truncation shortens the history and records both restart ids."""
        controller = _make_controller(mock_agent, mock_event_stream)

        # Setup initial history with IDs
        first_msg = _user_message('Start task', 1)

        # Add agent question
        agent_msg = MessageAction(
            content='What task would you like me to perform?', wait_for_response=True
        )
        agent_msg._source = EventSource.AGENT
        agent_msg._id = 2

        # Add user response
        user_response = _user_message(
            'Please list all files and show me current directory', 3
        )

        cmd1, obs1 = _cmd_pair('ls', 'file1.txt', 4)

        # Update mock event stream to include new messages
        history = [first_msg, agent_msg, user_response, cmd1, obs1]
        mock_event_stream.get_events.return_value = list(history)
        controller.state.history = list(history)
        original_history_len = len(controller.state.history)

        # Simulate ContextWindowExceededError and truncation
        controller.state.history = controller._apply_conversation_window(
            controller.state.history
        )

        # Verify truncation occurred
        assert len(controller.state.history) < original_history_len
        assert controller.state.start_id == first_msg._id
        assert controller.state.truncation_id > controller.state.start_id

    def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
        """Saved start/truncation ids can be carried over to a new controller."""
        controller = _make_controller(mock_agent, mock_event_stream)

        # Create events with IDs; every event gets its own unique id
        first_msg = _user_message('Start task', 1)
        events = [first_msg]
        for i in range(5):
            cmd, obs = _cmd_pair(f'cmd{i}', f'output{i}', 2 * i + 2)
            events.extend([cmd, obs])

        # Set up initial history
        controller.state.history = events.copy()

        # Force truncation
        controller.state.history = controller._apply_conversation_window(
            controller.state.history
        )

        # Save state
        saved_start_id = controller.state.start_id
        saved_truncation_id = controller.state.truncation_id
        saved_history_len = len(controller.state.history)
        assert saved_truncation_id > saved_start_id

        # Set up mock event stream for new controller
        mock_event_stream.get_events.return_value = controller.state.history

        # Create new controller with saved state
        new_controller = _make_controller(mock_agent, mock_event_stream)
        new_controller.state.start_id = saved_start_id
        new_controller.state.truncation_id = saved_truncation_id
        # NOTE(review): this reads the history straight back from the mock, so
        # it checks the saved ids round-trip but does not exercise the
        # controller's own history-loading code path.
        new_controller.state.history = mock_event_stream.get_events()

        # Verify restoration
        assert len(new_controller.state.history) == saved_history_len
        assert new_controller.state.history[0] == first_msg
        assert new_controller.state.start_id == saved_start_id