fix: Recover from ContextWindowExceededError (#6519)

Co-authored-by: Calvin Smith <calvin@all-hands.dev>
This commit is contained in:
Calvin Smith
2025-01-29 15:25:46 -07:00
committed by GitHub
parent a6eed5b7e9
commit 473fcae57e
2 changed files with 79 additions and 3 deletions

View File

@@ -8,6 +8,7 @@ import litellm
from litellm.exceptions import (
BadRequestError,
ContextWindowExceededError,
OpenAIError,
RateLimitError,
)
@@ -42,6 +43,7 @@ from openhands.events.action import (
)
from openhands.events.event import Event
from openhands.events.observation import (
AgentCondensationObservation,
AgentDelegateObservation,
AgentStateChangedObservation,
ErrorObservation,
@@ -672,7 +674,7 @@ class AgentController:
EventSource.AGENT,
)
return
except (ContextWindowExceededError, BadRequestError) as e:
except (ContextWindowExceededError, BadRequestError, OpenAIError) as e:
# FIXME: this is a hack until a litellm fix is confirmed
# Check if this is a nested context window error
error_str = str(e).lower()
@@ -689,9 +691,16 @@ class AgentController:
# Save the ID of the first event in our truncated history for future reloading
if self.state.history:
self.state.start_id = self.state.history[0].id
# Don't add error event - let the agent retry with reduced context
# Add an error event to trigger another step by the agent
self.event_stream.add_event(
AgentCondensationObservation(
content='Trimming prompt to meet context window limitations'
),
EventSource.AGENT,
)
return
raise
raise e
if action.runnable:
if self.state.confirmation_mode and (

View File

@@ -40,6 +40,7 @@ def mock_agent():
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = AppConfig().get_llm_config()
return agent
@@ -50,6 +51,14 @@ def mock_event_stream():
return mock
@pytest.fixture
def mock_runtime() -> Runtime:
return MagicMock(
spec=Runtime,
event_stream=EventStream(sid='test', file_store=InMemoryFileStore({})),
)
@pytest.fixture
def mock_status_callback():
return AsyncMock()
@@ -599,3 +608,61 @@ async def test_context_window_exceeded_error_handling(mock_agent, mock_event_str
# Check that the error was thrown and the history has been truncated
assert state.has_errored
assert controller.state.history == [MessageAction(content='Test message 1')]
@pytest.mark.asyncio
async def test_run_controller_with_context_window_exceeded(mock_agent, mock_runtime):
"""Tests that the controller can make progress after handling context window exceeded errors."""
class StepState:
def __init__(self):
self.has_errored = False
def step(self, state: State):
# If the state has more than one message and we haven't errored yet,
# throw the context window exceeded error
if len(state.history) > 1 and not self.has_errored:
error = ContextWindowExceededError(
message='prompt is too long: 233885 tokens > 200000 maximum',
model='',
llm_provider='',
)
self.has_errored = True
raise error
return MessageAction(content=f'STEP {len(state.history)}')
step_state = StepState()
mock_agent.step = step_state.step
try:
state = await asyncio.wait_for(
run_controller(
config=AppConfig(max_iterations=3),
initial_user_action=MessageAction(content='INITIAL'),
runtime=mock_runtime,
sid='test',
agent=mock_agent,
fake_user_response_fn=lambda _: 'repeat',
),
timeout=10,
)
# A timeout error indicates the run_controller entrypoint is not making
# progress
except asyncio.TimeoutError as e:
raise AssertionError(
'The run_controller function did not complete in time.'
) from e
# Hitting the iteration limit indicates the controller is failing for the
# expected reason
assert state.iteration == 3
assert state.agent_state == AgentState.ERROR
assert (
state.last_error
== 'RuntimeError: Agent reached maximum iteration in headless mode. Current iteration: 3, max iteration: 3'
)
# Check that the context window exceeded error was raised during the run
assert step_state.has_errored