mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
Context Window Exceeded fix (#4977)
This commit is contained in:
@@ -5,6 +5,7 @@ import traceback
|
||||
from typing import Callable, ClassVar, Type
|
||||
|
||||
import litellm
|
||||
from litellm.exceptions import ContextWindowExceededError
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State, TrafficControlState
|
||||
@@ -485,6 +486,15 @@ class AgentController:
|
||||
EventSource.AGENT,
|
||||
)
|
||||
return
|
||||
except ContextWindowExceededError:
|
||||
# When context window is exceeded, keep roughly half of agent interactions
|
||||
self.state.history = self._apply_conversation_window(self.state.history)
|
||||
|
||||
# Save the ID of the first event in our truncated history for future reloading
|
||||
if self.state.history:
|
||||
self.state.start_id = self.state.history[0].id
|
||||
# Don't add error event - let the agent retry with reduced context
|
||||
return
|
||||
|
||||
if action.runnable:
|
||||
if self.state.confirmation_mode and (
|
||||
@@ -659,6 +669,12 @@ class AgentController:
|
||||
- For delegate events (between AgentDelegateAction and AgentDelegateObservation):
|
||||
- Excludes all events between the action and observation
|
||||
- Includes the delegate action and observation themselves
|
||||
|
||||
The history is loaded in two parts if truncation_id is set:
|
||||
1. First user message from start_id onwards
|
||||
2. Rest of history from truncation_id to the end
|
||||
|
||||
Otherwise loads normally from start_id.
|
||||
"""
|
||||
|
||||
# define range of events to fetch
|
||||
@@ -680,8 +696,33 @@ class AgentController:
|
||||
self.state.history = []
|
||||
return
|
||||
|
||||
# Get all events, filtering out backend events and hidden events
|
||||
events = list(
|
||||
events: list[Event] = []
|
||||
|
||||
# If we have a truncation point, get first user message and then rest of history
|
||||
if hasattr(self.state, 'truncation_id') and self.state.truncation_id > 0:
|
||||
# Find first user message from stream
|
||||
first_user_msg = next(
|
||||
(
|
||||
e
|
||||
for e in self.event_stream.get_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
reverse=False,
|
||||
filter_out_type=self.filter_out,
|
||||
filter_hidden=True,
|
||||
)
|
||||
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
)
|
||||
if first_user_msg:
|
||||
events.append(first_user_msg)
|
||||
|
||||
# the rest of the events are from the truncation point
|
||||
start_id = self.state.truncation_id
|
||||
|
||||
# Get rest of history
|
||||
events_to_add = list(
|
||||
self.event_stream.get_events(
|
||||
start_id=start_id,
|
||||
end_id=end_id,
|
||||
@@ -690,6 +731,7 @@ class AgentController:
|
||||
filter_hidden=True,
|
||||
)
|
||||
)
|
||||
events.extend(events_to_add)
|
||||
|
||||
# Find all delegate action/observation pairs
|
||||
delegate_ranges: list[tuple[int, int]] = []
|
||||
@@ -744,6 +786,92 @@ class AgentController:
|
||||
# make sure history is in sync
|
||||
self.state.start_id = start_id
|
||||
|
||||
def _apply_conversation_window(self, events: list[Event]) -> list[Event]:
|
||||
"""Cuts history roughly in half when context window is exceeded, preserving action-observation pairs
|
||||
and ensuring the first user message is always included.
|
||||
|
||||
The algorithm:
|
||||
1. Cut history in half
|
||||
2. Check first event in new history:
|
||||
- If Observation: find and include its Action
|
||||
- If MessageAction: ensure its related Action-Observation pair isn't split
|
||||
3. Always include the first user message
|
||||
|
||||
Args:
|
||||
events: List of events to filter
|
||||
|
||||
Returns:
|
||||
Filtered list of events keeping newest half while preserving pairs
|
||||
"""
|
||||
if not events:
|
||||
return events
|
||||
|
||||
# Find first user message - we'll need to ensure it's included
|
||||
first_user_msg = next(
|
||||
(
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# cut in half
|
||||
mid_point = max(1, len(events) // 2)
|
||||
kept_events = events[mid_point:]
|
||||
|
||||
# Handle first event in truncated history
|
||||
if kept_events:
|
||||
i = 0
|
||||
while i < len(kept_events):
|
||||
first_event = kept_events[i]
|
||||
if isinstance(first_event, Observation) and first_event.cause:
|
||||
# Find its action and include it
|
||||
matching_action = next(
|
||||
(
|
||||
e
|
||||
for e in reversed(events[:mid_point])
|
||||
if isinstance(e, Action) and e.id == first_event.cause
|
||||
),
|
||||
None,
|
||||
)
|
||||
if matching_action:
|
||||
kept_events = [matching_action] + kept_events
|
||||
else:
|
||||
self.log(
|
||||
'warning',
|
||||
f'Found Observation without matching Action at id={first_event.id}',
|
||||
)
|
||||
# drop this observation
|
||||
kept_events = kept_events[1:]
|
||||
break
|
||||
|
||||
elif isinstance(first_event, MessageAction) or (
|
||||
isinstance(first_event, Action)
|
||||
and first_event.source == EventSource.USER
|
||||
):
|
||||
# if it's a message action or a user action, keep it and continue to find the next event
|
||||
i += 1
|
||||
continue
|
||||
|
||||
else:
|
||||
# if it's an action with source == EventSource.AGENT, we're good
|
||||
break
|
||||
|
||||
# Save where to continue from in next reload
|
||||
if kept_events:
|
||||
self.state.truncation_id = kept_events[0].id
|
||||
|
||||
# Ensure first user message is included
|
||||
if first_user_msg and first_user_msg not in kept_events:
|
||||
kept_events = [first_user_msg] + kept_events
|
||||
|
||||
# start_id points to first user message
|
||||
if first_user_msg:
|
||||
self.state.start_id = first_user_msg.id
|
||||
|
||||
return kept_events
|
||||
|
||||
def _is_stuck(self):
|
||||
"""Checks if the agent or its delegate is stuck in a loop.
|
||||
|
||||
|
||||
@@ -92,6 +92,8 @@ class State:
|
||||
# start_id and end_id track the range of events in history
|
||||
start_id: int = -1
|
||||
end_id: int = -1
|
||||
# truncation_id tracks where to load history after context window truncation
|
||||
truncation_id: int = -1
|
||||
almost_stuck: int = 0
|
||||
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
|
||||
# NOTE: This will never be used by the controller, but it can be used by different
|
||||
|
||||
188
tests/unit/test_truncation.py
Normal file
188
tests/unit/test_truncation.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.observation import CmdOutputObservation
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_stream():
|
||||
stream = MagicMock()
|
||||
# Mock get_events to return an empty list by default
|
||||
stream.get_events.return_value = []
|
||||
return stream
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
agent = MagicMock()
|
||||
agent.llm = MagicMock()
|
||||
agent.llm.config = MagicMock()
|
||||
return agent
|
||||
|
||||
|
||||
class TestTruncation:
|
||||
def test_apply_conversation_window_basic(self, mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create a sequence of events with IDs
|
||||
first_msg = MessageAction(content='Hello, start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
cmd1 = CmdRunAction(command='ls')
|
||||
cmd1._id = 2
|
||||
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=2)
|
||||
obs1._id = 3
|
||||
obs1._cause = 2
|
||||
|
||||
cmd2 = CmdRunAction(command='pwd')
|
||||
cmd2._id = 4
|
||||
obs2 = CmdOutputObservation(command='pwd', content='/home', command_id=4)
|
||||
obs2._id = 5
|
||||
obs2._cause = 4
|
||||
|
||||
events = [first_msg, cmd1, obs1, cmd2, obs2]
|
||||
|
||||
# Apply truncation
|
||||
truncated = controller._apply_conversation_window(events)
|
||||
|
||||
# Should keep first user message and roughly half of other events
|
||||
assert (
|
||||
len(truncated) >= 3
|
||||
) # First message + at least one action-observation pair
|
||||
assert truncated[0] == first_msg # First message always preserved
|
||||
assert controller.state.start_id == first_msg._id
|
||||
assert controller.state.truncation_id is not None
|
||||
|
||||
# Verify pairs aren't split
|
||||
for i, event in enumerate(truncated[1:]):
|
||||
if isinstance(event, CmdOutputObservation):
|
||||
assert any(e._id == event._cause for e in truncated[: i + 1])
|
||||
|
||||
def test_context_window_exceeded_handling(self, mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Setup initial history with IDs
|
||||
first_msg = MessageAction(content='Start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
# Add agent question
|
||||
agent_msg = MessageAction(
|
||||
content='What task would you like me to perform?', wait_for_response=True
|
||||
)
|
||||
agent_msg._source = EventSource.AGENT
|
||||
agent_msg._id = 2
|
||||
|
||||
# Add user response
|
||||
user_response = MessageAction(
|
||||
content='Please list all files and show me current directory',
|
||||
wait_for_response=False,
|
||||
)
|
||||
user_response._source = EventSource.USER
|
||||
user_response._id = 3
|
||||
|
||||
cmd1 = CmdRunAction(command='ls')
|
||||
cmd1._id = 4
|
||||
obs1 = CmdOutputObservation(command='ls', content='file1.txt', command_id=4)
|
||||
obs1._id = 5
|
||||
obs1._cause = 4
|
||||
|
||||
# Update mock event stream to include new messages
|
||||
mock_event_stream.get_events.return_value = [
|
||||
first_msg,
|
||||
agent_msg,
|
||||
user_response,
|
||||
cmd1,
|
||||
obs1,
|
||||
]
|
||||
controller.state.history = [first_msg, agent_msg, user_response, cmd1, obs1]
|
||||
original_history_len = len(controller.state.history)
|
||||
|
||||
# Simulate ContextWindowExceededError and truncation
|
||||
controller.state.history = controller._apply_conversation_window(
|
||||
controller.state.history
|
||||
)
|
||||
|
||||
# Verify truncation occurred
|
||||
assert len(controller.state.history) < original_history_len
|
||||
assert controller.state.start_id == first_msg._id
|
||||
assert controller.state.truncation_id is not None
|
||||
assert controller.state.truncation_id > controller.state.start_id
|
||||
|
||||
def test_history_restoration_after_truncation(self, mock_event_stream, mock_agent):
|
||||
controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
|
||||
# Create events with IDs
|
||||
first_msg = MessageAction(content='Start task', wait_for_response=False)
|
||||
first_msg._source = EventSource.USER
|
||||
first_msg._id = 1
|
||||
|
||||
events = [first_msg]
|
||||
for i in range(5):
|
||||
cmd = CmdRunAction(command=f'cmd{i}')
|
||||
cmd._id = i + 2
|
||||
obs = CmdOutputObservation(
|
||||
command=f'cmd{i}', content=f'output{i}', command_id=cmd._id
|
||||
)
|
||||
obs._cause = cmd._id
|
||||
events.extend([cmd, obs])
|
||||
|
||||
# Set up initial history
|
||||
controller.state.history = events.copy()
|
||||
|
||||
# Force truncation
|
||||
controller.state.history = controller._apply_conversation_window(
|
||||
controller.state.history
|
||||
)
|
||||
|
||||
# Save state
|
||||
saved_start_id = controller.state.start_id
|
||||
saved_truncation_id = controller.state.truncation_id
|
||||
saved_history_len = len(controller.state.history)
|
||||
|
||||
# Set up mock event stream for new controller
|
||||
mock_event_stream.get_events.return_value = controller.state.history
|
||||
|
||||
# Create new controller with saved state
|
||||
new_controller = AgentController(
|
||||
agent=mock_agent,
|
||||
event_stream=mock_event_stream,
|
||||
max_iterations=10,
|
||||
sid='test_truncation',
|
||||
confirmation_mode=False,
|
||||
headless_mode=True,
|
||||
)
|
||||
new_controller.state.start_id = saved_start_id
|
||||
new_controller.state.truncation_id = saved_truncation_id
|
||||
new_controller.state.history = mock_event_stream.get_events()
|
||||
|
||||
# Verify restoration
|
||||
assert len(new_controller.state.history) == saved_history_len
|
||||
assert new_controller.state.history[0] == first_msg
|
||||
assert new_controller.state.start_id == saved_start_id
|
||||
Reference in New Issue
Block a user