mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 23:08:04 -05:00
fix: Recover from ContextWindowExceededError (#6519)
Co-authored-by: Calvin Smith <calvin@all-hands.dev>
This commit is contained in:
@@ -8,6 +8,7 @@ import litellm
|
||||
from litellm.exceptions import (
|
||||
BadRequestError,
|
||||
ContextWindowExceededError,
|
||||
OpenAIError,
|
||||
RateLimitError,
|
||||
)
|
||||
|
||||
@@ -42,6 +43,7 @@ from openhands.events.action import (
|
||||
)
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
AgentDelegateObservation,
|
||||
AgentStateChangedObservation,
|
||||
ErrorObservation,
|
||||
@@ -672,7 +674,7 @@ class AgentController:
|
||||
EventSource.AGENT,
|
||||
)
|
||||
return
|
||||
except (ContextWindowExceededError, BadRequestError) as e:
|
||||
except (ContextWindowExceededError, BadRequestError, OpenAIError) as e:
|
||||
# FIXME: this is a hack until a litellm fix is confirmed
|
||||
# Check if this is a nested context window error
|
||||
error_str = str(e).lower()
|
||||
@@ -689,9 +691,16 @@ class AgentController:
|
||||
# Save the ID of the first event in our truncated history for future reloading
|
||||
if self.state.history:
|
||||
self.state.start_id = self.state.history[0].id
|
||||
# Don't add error event - let the agent retry with reduced context
|
||||
|
||||
# Add an error event to trigger another step by the agent
|
||||
self.event_stream.add_event(
|
||||
AgentCondensationObservation(
|
||||
content='Trimming prompt to meet context window limitations'
|
||||
),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
return
|
||||
raise
|
||||
raise e
|
||||
|
||||
if action.runnable:
|
||||
if self.state.confirmation_mode and (
|
||||
|
||||
@@ -40,6 +40,7 @@ def mock_agent():
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = AppConfig().get_llm_config()
|
||||
return agent
|
||||
|
||||
|
||||
@@ -50,6 +51,14 @@ def mock_event_stream():
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime() -> Runtime:
|
||||
return MagicMock(
|
||||
spec=Runtime,
|
||||
event_stream=EventStream(sid='test', file_store=InMemoryFileStore({})),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_status_callback():
|
||||
return AsyncMock()
|
||||
@@ -599,3 +608,61 @@ async def test_context_window_exceeded_error_handling(mock_agent, mock_event_str
|
||||
# Check that the error was thrown and the history has been truncated
|
||||
assert state.has_errored
|
||||
assert controller.state.history == [MessageAction(content='Test message 1')]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_controller_with_context_window_exceeded(mock_agent, mock_runtime):
|
||||
"""Tests that the controller can make progress after handling context window exceeded errors."""
|
||||
|
||||
class StepState:
|
||||
def __init__(self):
|
||||
self.has_errored = False
|
||||
|
||||
def step(self, state: State):
|
||||
# If the state has more than one message and we haven't errored yet,
|
||||
# throw the context window exceeded error
|
||||
if len(state.history) > 1 and not self.has_errored:
|
||||
error = ContextWindowExceededError(
|
||||
message='prompt is too long: 233885 tokens > 200000 maximum',
|
||||
model='',
|
||||
llm_provider='',
|
||||
)
|
||||
self.has_errored = True
|
||||
raise error
|
||||
|
||||
return MessageAction(content=f'STEP {len(state.history)}')
|
||||
|
||||
step_state = StepState()
|
||||
mock_agent.step = step_state.step
|
||||
|
||||
try:
|
||||
state = await asyncio.wait_for(
|
||||
run_controller(
|
||||
config=AppConfig(max_iterations=3),
|
||||
initial_user_action=MessageAction(content='INITIAL'),
|
||||
runtime=mock_runtime,
|
||||
sid='test',
|
||||
agent=mock_agent,
|
||||
fake_user_response_fn=lambda _: 'repeat',
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
|
||||
# A timeout error indicates the run_controller entrypoint is not making
|
||||
# progress
|
||||
except asyncio.TimeoutError as e:
|
||||
raise AssertionError(
|
||||
'The run_controller function did not complete in time.'
|
||||
) from e
|
||||
|
||||
# Hitting the iteration limit indicates the controller is failing for the
|
||||
# expected reason
|
||||
assert state.iteration == 3
|
||||
assert state.agent_state == AgentState.ERROR
|
||||
assert (
|
||||
state.last_error
|
||||
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
|
||||
)
|
||||
|
||||
# Check that the context window exceeded error was raised during the run
|
||||
assert step_state.has_errored
|
||||
|
||||
Reference in New Issue
Block a user