Enhance dead-loop recovery by pausing agent and reprompting (#11439)

Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
softpudding
2025-10-24 19:25:14 +08:00
committed by GitHub
parent 4b303ec9b4
commit 17e32af6fe
21 changed files with 932 additions and 43 deletions

View File

@@ -47,6 +47,7 @@ from openhands.core.schema.exit_reason import ExitReason
from openhands.events import EventSource
from openhands.events.action import (
ChangeAgentStateAction,
LoopRecoveryAction,
MessageAction,
)
from openhands.events.stream import EventStream
@@ -159,9 +160,9 @@ async def handle_commands(
exit_reason = ExitReason.INTENTIONAL
elif command == '/settings':
await handle_settings_command(config, settings_store)
elif command == '/resume':
elif command.startswith('/resume'):
close_repl, new_session_requested = await handle_resume_command(
event_stream, agent_state
command, event_stream, agent_state
)
elif command == '/mcp':
await handle_mcp_command(config)
@@ -294,6 +295,7 @@ async def handle_settings_command(
# Setting the agent state to RUNNING will currently freeze the agent without continuing with the rest of the task.
# This is a workaround to handle the resume command for the time being. Replace user message with the state change event once the issue is fixed.
async def handle_resume_command(
command: str,
event_stream: EventStream,
agent_state: str,
) -> tuple[bool, bool]:
@@ -309,10 +311,29 @@ async def handle_resume_command(
)
return close_repl, new_session_requested
event_stream.add_event(
MessageAction(content='continue'),
EventSource.USER,
)
# Check if this is a loop recovery resume with an option
if command.strip() != '/resume':
# Parse the option from the command (e.g., '/resume 1', '/resume 2')
parts = command.strip().split()
if len(parts) == 2 and parts[1] in ['1', '2']:
option = parts[1]
# Send the option as a message to be handled by the controller
event_stream.add_event(
LoopRecoveryAction(option=int(option)),
EventSource.USER,
)
else:
# Invalid format, send as regular resume
event_stream.add_event(
MessageAction(content='continue'),
EventSource.USER,
)
else:
# Regular resume without loop recovery option
event_stream.add_event(
MessageAction(content='continue'),
EventSource.USER,
)
# event_stream.add_event(
# ChangeAgentStateAction(AgentState.RUNNING),

View File

@@ -430,9 +430,25 @@ async def run_session(
# No session restored, no initial action: prompt for the user's first message
asyncio.create_task(prompt_for_next_task(''))
await run_agent_until_done(
controller, runtime, memory, [AgentState.STOPPED, AgentState.ERROR]
)
skip_set_callback = False
while True:
await run_agent_until_done(
controller,
runtime,
memory,
[AgentState.STOPPED, AgentState.ERROR],
skip_set_callback,
)
# Try loop recovery in CLI app
if (
controller.state.agent_state == AgentState.ERROR
and controller.state.last_error.startswith('AgentStuckInLoopError')
):
controller.attempt_loop_recovery()
skip_set_callback = True
continue
else:
break
await cleanup_session(loop, agent, runtime, controller)

View File

@@ -59,6 +59,7 @@ from openhands.events.observation import (
ErrorObservation,
FileEditObservation,
FileReadObservation,
LoopDetectionObservation,
MCPObservation,
TaskTrackingObservation,
)
@@ -309,6 +310,8 @@ def display_event(event: Event, config: OpenHandsConfig) -> None:
display_agent_state_change_message(event.agent_state)
elif isinstance(event, ErrorObservation):
display_error(event.content)
elif isinstance(event, LoopDetectionObservation):
handle_loop_recovery_state_observation(event)
def display_message(message: str, is_agent_message: bool = False) -> None:
@@ -1039,3 +1042,25 @@ class UserCancelledError(Exception):
"""Raised when the user cancels an operation via key binding."""
pass
def handle_loop_recovery_state_observation(
    observation: LoopDetectionObservation,
) -> None:
    """Render a loop-detection observation in the terminal.

    The observation's text is shown read-only inside a grey frame titled
    'Agent Loop Detection', preceded by a blank line. This function only
    prints; it does not store or modify any loop recovery state.

    Args:
        observation: The loop-detection event whose ``content`` is displayed.
    """
    message_area = TextArea(
        text=observation.content,
        read_only=True,
        style=COLOR_GREY,
        wrap_lines=True,
    )
    container = Frame(
        message_area,
        title='Agent Loop Detection',
        style=f'fg:{COLOR_GREY}',
    )
    # Blank line first so the frame does not touch the previous output.
    print_formatted_text('')
    print_container(container)

View File

@@ -64,6 +64,7 @@ from openhands.events.action import (
MessageAction,
NullAction,
SystemMessageAction,
LoopRecoveryAction,
)
from openhands.events.action.agent import (
CondensationAction,
@@ -77,6 +78,7 @@ from openhands.events.observation import (
ErrorObservation,
NullObservation,
Observation,
LoopDetectionObservation,
)
from openhands.events.serialization.event import truncate_content
from openhands.llm.metrics import Metrics
@@ -523,6 +525,8 @@ class AgentController:
elif isinstance(action, AgentRejectAction):
self.state.outputs = action.outputs
await self.set_agent_state_to(AgentState.REJECTED)
elif isinstance(action, LoopRecoveryAction):
await self._handle_loop_recovery_action(action)
async def _handle_observation(self, observation: Observation) -> None:
"""Handles observation from the event stream.
@@ -595,6 +599,25 @@ class AgentController:
if action.wait_for_response:
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
async def _handle_loop_recovery_action(self, action: LoopRecoveryAction) -> None:
    """Apply the recovery option carried by a ``LoopRecoveryAction``.

    The action is ignored when the stuck detector holds no analysis, and
    option values other than 1, 2 or 3 do nothing.

    Args:
        action: The recovery request; ``action.option`` selects the strategy.
    """
    analysis = self._stuck_detector.stuck_analysis
    if not analysis:
        # No loop on record, so there is nothing to recover from.
        return

    choice = action.option
    if choice == 1:
        # Restart from before the loop.
        await self._perform_loop_recovery(analysis)
    elif choice == 2:
        # Restart with the last user message.
        await self._restart_with_last_user_message(analysis)
    elif choice == 3:
        # Stop the agent completely.
        await self.set_agent_state_to(AgentState.STOPPED)
def _reset(self) -> None:
"""Resets the agent controller."""
# Runnable actions need an Observation
@@ -1084,6 +1107,45 @@ class AgentController:
return self._stuck_detector.is_stuck(self.headless_mode)
def attempt_loop_recovery(self) -> bool:
    """Announce a detected loop and pause the agent so the user can pick a recovery option.

    Only supports CLI for now. The method itself does not repair anything: it
    queues a ``LoopDetectionObservation`` that lists the recovery commands and
    then queues a state change to ``PAUSED``.

    Returns:
        bool: True if the loop notice and the pause request were queued on the
            event stream, False if the stuck detector holds no analysis (in
            which case nothing is emitted).
    """
    # Bind once: the attribute is Optional, and every later use must refer to
    # the same analysis that passed this check.
    analysis = self._stuck_detector.stuck_analysis
    if not analysis:
        return False

    # Handle loop recovery in CLI mode by pausing the agent and presenting
    # recovery options.
    recovery_point = analysis.loop_start_idx

    # Present loop detection message
    self.event_stream.add_event(
        LoopDetectionObservation(
            content=f"""⚠️ Agent detected in a loop!
Loop type: {analysis.loop_type}
Loop detected at iteration {self.state.iteration_flag.current_value}
\nRecovery options:
/resume 1. Restart from before loop (preserves {recovery_point} events)
/resume 2. Restart with last user message (reuses your most recent instruction)
/exit. Quit directly
\nThe agent has been paused. Type '/resume 1', '/resume 2', or '/exit' to choose an option.
"""
        ),
        source=EventSource.ENVIRONMENT,
    )

    # Pause the agent using the same mechanism as Ctrl+P
    # This ensures consistent behavior and avoids event loop conflicts
    self.event_stream.add_event(
        ChangeAgentStateAction(AgentState.PAUSED),
        EventSource.ENVIRONMENT,  # Use ENVIRONMENT source to distinguish from user pause
    )
    return True
def _prepare_metrics_for_frontend(self, action: Action) -> None:
"""Create a minimal metrics object for frontend display and log it.
@@ -1208,5 +1270,92 @@ class AgentController:
)
return self._cached_first_user_message
async def _perform_loop_recovery(
    self, stuck_analysis: StuckDetector.StuckAnalysis
) -> None:
    """Cut history back to the start of the loop and hand control to the user.

    Steps, in order: truncate history at ``stuck_analysis.loop_start_idx``,
    switch the agent to ``AWAITING_USER_INPUT``, then emit a confirmation
    observation on the event stream.

    Args:
        stuck_analysis: The analysis describing where the loop started.
    """
    await self._truncate_memory_to_point(stuck_analysis.loop_start_idx)

    # Wait for fresh instructions instead of resuming automatically.
    await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)

    notice = LoopDetectionObservation(
        content="""✅ Loop recovery completed. Agent has been reset to before the loop.
You can now provide new instructions to continue.
"""
    )
    self.event_stream.add_event(notice, source=EventSource.ENVIRONMENT)
async def _truncate_memory_to_point(self, recovery_point: int) -> None:
    """Drop every history event at or after ``recovery_point``.

    When ``recovery_point`` is at or beyond the end of the history nothing is
    changed. Otherwise the history is replaced by its first
    ``recovery_point`` events, ``end_id`` becomes the id of the last kept
    event (or -1 when none are kept), and the cached first user message is
    cleared.

    Args:
        recovery_point: Number of leading history events to keep.
    """
    history = self.state.history
    if recovery_point >= len(history):
        return

    kept = history[:recovery_point]
    self.state.history = kept
    # Keep end_id in step with the shortened history.
    self.state.end_id = kept[-1].id if kept else -1
    # The cache may point at an event that was just removed.
    self._cached_first_user_message = None
async def _restart_with_last_user_message(
    self, stuck_analysis: StuckDetector.StuckAnalysis
) -> None:
    """Rewind to before the loop and replay the most recent user message.

    The history is searched backwards for the latest ``MessageAction`` whose
    source is the user. If one exists, history is truncated at the loop start,
    the agent is set to ``RUNNING``, a confirmation observation is emitted, and
    a copy of that message is passed to ``_handle_action``. If none exists,
    the method falls back to ``_perform_loop_recovery``.

    Args:
        stuck_analysis: The analysis describing where the loop started.
    """
    last_user_message = next(
        (
            event
            for event in reversed(self.state.history)
            if isinstance(event, MessageAction) and event.source == EventSource.USER
        ),
        None,
    )

    if not last_user_message:
        # Nothing to replay, so use the standard recovery path instead.
        print('\n⚠️ No previous user message found. Using standard recovery.')
        await self._perform_loop_recovery(stuck_analysis)
        return

    # Truncate memory to just before the loop started.
    await self._truncate_memory_to_point(stuck_analysis.loop_start_idx)
    await self.set_agent_state_to(AgentState.RUNNING)

    self.event_stream.add_event(
        LoopDetectionObservation(
            content=f"""\n✅ Restarting with your last instruction: {last_user_message.content}
Agent is now continuing with the same task...
"""
        ),
        source=EventSource.ENVIRONMENT,
    )

    # Build a fresh action carrying the same text and mark it as user-sourced.
    replayed = MessageAction(
        content=last_user_message.content, wait_for_response=False
    )
    replayed._source = EventSource.USER  # type: ignore [attr-defined]
    await self._handle_action(replayed)
def save_state(self):
self.state_tracker.save_state()

View File

@@ -1,10 +1,13 @@
from dataclasses import dataclass
from typing import Optional
from openhands.controller.state.state import State
from openhands.core.logger import openhands_logger as logger
from openhands.events import Event, EventSource
from openhands.events.action.action import Action
from openhands.events.action.commands import IPythonRunCellAction
from openhands.events.action.empty import NullAction
from openhands.events.action.message import MessageAction
from openhands.events.event import Event, EventSource
from openhands.events.observation import (
CmdOutputObservation,
IPythonRunCellObservation,
@@ -22,8 +25,15 @@ class StuckDetector:
'SyntaxError: incomplete input',
]
@dataclass
class StuckAnalysis:
    """Details of the loop found by the most recent positive stuck check."""

    # Name of the detection scenario that fired, e.g. 'monologue' or
    # 'repeating_action_error'.
    loop_type: str
    # Repetition count recorded by the scenario that fired.
    loop_repeat_times: int
    # Position where the loop starts: an index found in filtered_history plus
    # filtered_history_offset.
    # NOTE(review): the controller uses this value to slice state.history; if
    # filtered_history omits events, the two indexings may not line up exactly
    # - confirm.
    loop_start_idx: int
def __init__(self, state: State):
self.state = state
self.stuck_analysis: Optional[StuckDetector.StuckAnalysis] = None
def is_stuck(self, headless_mode: bool = True) -> bool:
"""Checks if the agent is stuck in a loop.
@@ -36,6 +46,7 @@ class StuckDetector:
Returns:
bool: True if the agent is stuck in a loop, False otherwise.
"""
filtered_history_offset = 0
if not headless_mode:
# In interactive mode, only look at history after the last user message
last_user_msg_idx = -1
@@ -46,7 +57,7 @@ class StuckDetector:
):
last_user_msg_idx = len(self.state.history) - i - 1
break
filtered_history_offset = last_user_msg_idx + 1
history_to_check = self.state.history[last_user_msg_idx + 1 :]
else:
# In headless mode, look at all history
@@ -86,31 +97,45 @@ class StuckDetector:
break
# scenario 1: same action, same observation
if self._is_stuck_repeating_action_observation(last_actions, last_observations):
if self._is_stuck_repeating_action_observation(
last_actions, last_observations, filtered_history, filtered_history_offset
):
return True
# scenario 2: same action, errors
if self._is_stuck_repeating_action_error(last_actions, last_observations):
if self._is_stuck_repeating_action_error(
last_actions, last_observations, filtered_history, filtered_history_offset
):
return True
# scenario 3: monologue
if self._is_stuck_monologue(filtered_history):
if self._is_stuck_monologue(filtered_history, filtered_history_offset):
return True
# scenario 4: action, observation pattern on the last six steps
if len(filtered_history) >= 6:
if self._is_stuck_action_observation_pattern(filtered_history):
if self._is_stuck_action_observation_pattern(
filtered_history, filtered_history_offset
):
return True
# scenario 5: context window error loop
if len(filtered_history) >= 10:
if self._is_stuck_context_window_error(filtered_history):
if self._is_stuck_context_window_error(
filtered_history, filtered_history_offset
):
return True
# Empty stuck_analysis when not stuck
self.stuck_analysis = None
return False
def _is_stuck_repeating_action_observation(
self, last_actions: list[Event], last_observations: list[Event]
self,
last_actions: list[Event],
last_observations: list[Event],
filtered_history: list[Event],
filtered_history_offset: int = 0,
) -> bool:
# scenario 1: same action, same observation
# it takes 4 actions and 4 observations to detect a loop
@@ -128,12 +153,22 @@ class StuckDetector:
if actions_equal and observations_equal:
logger.warning('Action, Observation loop detected')
self.stuck_analysis = StuckDetector.StuckAnalysis(
loop_type='repeating_action_observation',
loop_repeat_times=4,
loop_start_idx=filtered_history.index(last_actions[-1])
+ filtered_history_offset,
)
return True
return False
def _is_stuck_repeating_action_error(
self, last_actions: list[Event], last_observations: list[Event]
self,
last_actions: list[Event],
last_observations: list[Event],
filtered_history: list[Event],
filtered_history_offset: int = 0,
) -> bool:
# scenario 2: same action, errors
# it takes 3 actions and 3 observations to detect a loop
@@ -147,6 +182,12 @@ class StuckDetector:
# and the last three observations are all errors?
if all(isinstance(obs, ErrorObservation) for obs in last_observations[:3]):
logger.warning('Action, ErrorObservation loop detected')
self.stuck_analysis = StuckDetector.StuckAnalysis(
loop_type='repeating_action_error',
loop_repeat_times=3,
loop_start_idx=filtered_history.index(last_actions[-1])
+ filtered_history_offset,
)
return True
# or, are the last three observations all IPythonRunCellObservation with SyntaxError?
elif all(
@@ -167,6 +208,12 @@ class StuckDetector:
error_message,
):
logger.warning(warning)
self.stuck_analysis = StuckDetector.StuckAnalysis(
loop_type='repeating_action_error',
loop_repeat_times=3,
loop_start_idx=filtered_history.index(last_actions[-1])
+ filtered_history_offset,
)
return True
elif error_message in (
'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
@@ -180,6 +227,12 @@ class StuckDetector:
error_message,
):
logger.warning(warning)
self.stuck_analysis = StuckDetector.StuckAnalysis(
loop_type='repeating_action_error',
loop_repeat_times=3,
loop_start_idx=filtered_history.index(last_actions[-1])
+ filtered_history_offset,
)
return True
return False
@@ -255,7 +308,9 @@ class StuckDetector:
# and the 3rd-to-last line is identical across all occurrences
return len(error_lines) == 3 and len(set(error_lines)) == 1
def _is_stuck_monologue(self, filtered_history: list[Event]) -> bool:
def _is_stuck_monologue(
self, filtered_history: list[Event], filtered_history_offset: int = 0
) -> bool:
# scenario 3: monologue
# check for repeated MessageActions with source=AGENT
# see if the agent is engaged in a good old monologue, telling itself the same thing over and over
@@ -286,11 +341,16 @@ class StuckDetector:
if not has_observation_between:
logger.warning('Repeated MessageAction with source=AGENT detected')
self.stuck_analysis = StuckDetector.StuckAnalysis(
loop_type='monologue',
loop_repeat_times=3,
loop_start_idx=start_index + filtered_history_offset,
)
return True
return False
def _is_stuck_action_observation_pattern(
self, filtered_history: list[Event]
self, filtered_history: list[Event], filtered_history_offset: int = 0
) -> bool:
# scenario 4: action, observation pattern on the last six steps
# check if the agent repeats the same (Action, Observation)
@@ -330,10 +390,18 @@ class StuckDetector:
if actions_equal and observations_equal:
logger.warning('Action, Observation pattern detected')
self.stuck_analysis = StuckDetector.StuckAnalysis(
loop_type='repeating_action_observation_pattern',
loop_repeat_times=3,
loop_start_idx=filtered_history.index(last_six_actions[-1])
+ filtered_history_offset,
)
return True
return False
def _is_stuck_context_window_error(self, filtered_history: list[Event]) -> bool:
def _is_stuck_context_window_error(
self, filtered_history: list[Event], filtered_history_offset: int = 0
) -> bool:
"""Detects if we're stuck in a loop of context window errors.
This happens when we repeatedly get context window errors and try to trim,
@@ -377,6 +445,11 @@ class StuckDetector:
logger.warning(
'Context window error loop detected - repeated condensation events'
)
self.stuck_analysis = StuckDetector.StuckAnalysis(
loop_type='context_window_error',
loop_repeat_times=2,
loop_start_idx=start_idx + filtered_history_offset,
)
return True
return False

View File

@@ -13,6 +13,7 @@ async def run_agent_until_done(
runtime: Runtime,
memory: Memory,
end_states: list[AgentState],
skip_set_callback: bool = False,
) -> None:
"""run_agent_until_done takes a controller and a runtime, and will run
the agent until it reaches a terminal state.
@@ -28,18 +29,19 @@ async def run_agent_until_done(
else:
logger.info(msg)
if hasattr(runtime, 'status_callback') and runtime.status_callback:
raise ValueError(
'Runtime status_callback was set, but run_agent_until_done will override it'
)
if hasattr(controller, 'status_callback') and controller.status_callback:
raise ValueError(
'Controller status_callback was set, but run_agent_until_done will override it'
)
if not skip_set_callback:
if hasattr(runtime, 'status_callback') and runtime.status_callback:
raise ValueError(
'Runtime status_callback was set, but run_agent_until_done will override it'
)
if hasattr(controller, 'status_callback') and controller.status_callback:
raise ValueError(
'Controller status_callback was set, but run_agent_until_done will override it'
)
runtime.status_callback = status_callback
controller.status_callback = status_callback
memory.status_callback = status_callback
runtime.status_callback = status_callback
controller.status_callback = status_callback
memory.status_callback = status_callback
while controller.state.agent_state not in end_states:
await asyncio.sleep(1)

View File

@@ -97,3 +97,6 @@ class ActionType(str, Enum):
TASK_TRACKING = 'task_tracking'
"""Views or updates the task list for task management."""
LOOP_RECOVERY = 'loop_recovery'
"""Recover dead loop."""

View File

@@ -58,3 +58,6 @@ class ObservationType(str, Enum):
TASK_TRACKING = 'task_tracking'
"""Result of a task tracking operation"""
LOOP_DETECTION = 'loop_detection'
"""Results of a dead-loop detection"""

View File

@@ -9,6 +9,7 @@ from openhands.events.action.agent import (
AgentRejectAction,
AgentThinkAction,
ChangeAgentStateAction,
LoopRecoveryAction,
RecallAction,
TaskTrackingAction,
)
@@ -45,4 +46,5 @@ __all__ = [
'MCPAction',
'TaskTrackingAction',
'ActionSecurityRisk',
'LoopRecoveryAction',
]

View File

@@ -226,3 +226,17 @@ class TaskTrackingAction(Action):
return 'Managing 1 task item.'
else:
return f'Managing {num_tasks} task items.'
@dataclass
class LoopRecoveryAction(Action):
    """An action carrying the option chosen to handle a dead loop.

    The class should be invisible to LLM.

    Attributes:
        option (int): 1 allow user to prompt again
                      2 automatically use latest user prompt
                      3 stop agent
    """

    # Selected recovery option; defaults to option 1.
    option: int = 1
    # Action type tag for this class.
    action: str = ActionType.LOOP_RECOVERY

View File

@@ -22,6 +22,7 @@ from openhands.events.observation.files import (
FileReadObservation,
FileWriteObservation,
)
from openhands.events.observation.loop_recovery import LoopDetectionObservation
from openhands.events.observation.mcp import MCPObservation
from openhands.events.observation.observation import Observation
from openhands.events.observation.reject import UserRejectObservation
@@ -47,6 +48,7 @@ __all__ = [
'AgentCondensationObservation',
'RecallObservation',
'RecallType',
'LoopDetectionObservation',
'MCPObservation',
'FileDownloadObservation',
'TaskTrackingObservation',

View File

@@ -0,0 +1,18 @@
from dataclasses import dataclass
from openhands.core.schema import ObservationType
from openhands.events.observation.observation import Observation
@dataclass
class LoopDetectionObservation(Observation):
    """Observation for loop recovery state changes.

    This observation is used to notify the UI layer when the agent
    is in loop recovery mode.

    This observation is CLI-specific and should only be displayed
    in CLI/TUI mode, not in GUI or other UI modes.
    """

    # Observation type tag for this class.
    observation: str = ObservationType.LOOP_DETECTION

View File

@@ -10,6 +10,7 @@ from openhands.events.action.agent import (
ChangeAgentStateAction,
CondensationAction,
CondensationRequestAction,
LoopRecoveryAction,
RecallAction,
TaskTrackingAction,
)
@@ -48,6 +49,7 @@ actions = (
CondensationRequestAction,
MCPAction,
TaskTrackingAction,
LoopRecoveryAction,
)
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]

View File

@@ -26,6 +26,7 @@ from openhands.events.observation.files import (
FileReadObservation,
FileWriteObservation,
)
from openhands.events.observation.loop_recovery import LoopDetectionObservation
from openhands.events.observation.mcp import MCPObservation
from openhands.events.observation.observation import Observation
from openhands.events.observation.reject import UserRejectObservation
@@ -51,6 +52,7 @@ observations = (
MCPObservation,
FileDownloadObservation,
TaskTrackingObservation,
LoopDetectionObservation,
)
OBSERVATION_TYPE_TO_CLASS = {

View File

@@ -33,6 +33,7 @@ from openhands.events.observation import (
FileEditObservation,
FileReadObservation,
IPythonRunCellObservation,
LoopDetectionObservation,
TaskTrackingObservation,
UserRejectObservation,
)
@@ -524,6 +525,9 @@ class ConversationMemory:
elif isinstance(obs, FileDownloadObservation):
text = truncate_content(obs.content, max_message_chars)
message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, LoopDetectionObservation):
# LoopDetectionObservation is for the UI only; it is handled internally and must not be sent to the LLM.
return []
elif (
isinstance(obs, RecallObservation)
and self.agent_config.enable_prompt_extensions

View File

@@ -632,7 +632,7 @@ async def test_main_with_session_name_passes_name_to_run_session(
) # For REPL control
@patch('openhands.cli.main.handle_commands', new_callable=AsyncMock) # For REPL control
@patch('openhands.core.setup.State.restore_from_session') # Key mock
@patch('openhands.controller.AgentController.__init__') # To check initial_state
@patch('openhands.cli.main.create_controller') # To check initial_state
@patch('openhands.cli.main.display_runtime_initialization_message') # Cosmetic
@patch('openhands.cli.main.display_initialization_animation') # Cosmetic
@patch('openhands.cli.main.initialize_repository_for_runtime') # Cosmetic / setup
@@ -644,7 +644,7 @@ async def test_run_session_with_name_attempts_state_restore(
mock_initialize_repo,
mock_display_init_anim,
mock_display_runtime_init,
mock_agent_controller_init,
mock_create_controller,
mock_restore_from_session,
mock_handle_commands,
mock_read_prompt_input,
@@ -680,8 +680,20 @@ async def test_run_session_with_name_attempts_state_restore(
mock_loaded_state = MagicMock(spec=State)
mock_restore_from_session.return_value = mock_loaded_state
# AgentController.__init__ should not return a value (it's __init__)
mock_agent_controller_init.return_value = None
# Create a mock controller with state attribute
mock_controller = MagicMock()
mock_controller.state = MagicMock()
mock_controller.state.agent_state = None
mock_controller.state.last_error = None
# Mock create_controller to return the mock controller and loaded state
# but still call the real restore_from_session
def create_controller_side_effect(*args, **kwargs):
# Call the real restore_from_session to verify it's called
mock_restore_from_session(expected_sid, mock_runtime.event_stream.file_store)
return (mock_controller, mock_loaded_state)
mock_create_controller.side_effect = create_controller_side_effect
# To make run_session exit cleanly after one loop
mock_read_prompt_input.return_value = '/exit'
@@ -712,10 +724,10 @@ async def test_run_session_with_name_attempts_state_restore(
expected_sid, mock_runtime.event_stream.file_store
)
# Check that AgentController was initialized with the loaded state
mock_agent_controller_init.assert_called_once()
args, kwargs = mock_agent_controller_init.call_args
assert kwargs.get('initial_state') == mock_loaded_state
# Check that create_controller was called and returned the loaded state
mock_create_controller.assert_called_once()
# The create_controller should have been called with the loaded state
# (this is verified by the fact that restore_from_session was called and returned mock_loaded_state)
@pytest.mark.asyncio

View File

@@ -573,7 +573,7 @@ class TestHandleResumeCommand:
# Call the function with PAUSED state
close_repl, new_session_requested = await handle_resume_command(
event_stream, AgentState.PAUSED
'/resume', event_stream, AgentState.PAUSED
)
# Check that the event stream add_event was called with the correct message action
@@ -604,7 +604,7 @@ class TestHandleResumeCommand:
event_stream = MagicMock(spec=EventStream)
close_repl, new_session_requested = await handle_resume_command(
event_stream, invalid_state
'/resume', event_stream, invalid_state
)
# Check that no event was added to the stream

View File

@@ -0,0 +1,143 @@
"""Tests for CLI loop recovery functionality."""
from unittest.mock import MagicMock, patch
import pytest
from openhands.cli.commands import handle_resume_command
from openhands.controller.agent_controller import AgentController
from openhands.controller.stuck import StuckDetector
from openhands.core.schema import AgentState
from openhands.events import EventSource
from openhands.events.action import LoopRecoveryAction, MessageAction
from openhands.events.stream import EventStream
class TestCliLoopRecoveryIntegration:
    """Tests for how the CLI '/resume' command maps to loop recovery events.

    These tests exercise ``handle_resume_command`` and ``handle_commands``
    against a mocked ``EventStream``; no real controller is run.
    """

    @pytest.mark.asyncio
    async def test_loop_recovery_resume_option_1(self):
        """Test that '/resume 1' queues a LoopRecoveryAction with option 1."""
        # NOTE(review): mock_controller below is configured but never passed to
        # handle_resume_command or asserted on, so it does not affect this
        # test - consider removing it together with the imports it needs.
        mock_controller = MagicMock(spec=AgentController)
        mock_controller._stuck_detector = MagicMock(spec=StuckDetector)
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Mock the loop recovery methods
        mock_controller._perform_loop_recovery = MagicMock()
        mock_controller._restart_with_last_user_message = MagicMock()
        mock_controller.set_agent_state_to = MagicMock()
        mock_controller._loop_recovery_info = None

        # Create a mock event stream
        event_stream = MagicMock(spec=EventStream)

        # Call handle_resume_command with option 1
        close_repl, new_session_requested = await handle_resume_command(
            '/resume 1', event_stream, AgentState.PAUSED
        )

        # Verify that LoopRecoveryAction was added to the event stream
        event_stream.add_event.assert_called_once()
        args, kwargs = event_stream.add_event.call_args
        loop_recovery_action, source = args
        assert isinstance(loop_recovery_action, LoopRecoveryAction)
        assert loop_recovery_action.option == 1
        assert source == EventSource.USER

        # Check the return values
        assert close_repl is True
        assert new_session_requested is False

    @pytest.mark.asyncio
    async def test_loop_recovery_resume_option_2(self):
        """Test that '/resume 2' queues a LoopRecoveryAction with option 2."""
        # Create a mock event stream
        event_stream = MagicMock(spec=EventStream)

        # Call handle_resume_command with option 2
        close_repl, new_session_requested = await handle_resume_command(
            '/resume 2', event_stream, AgentState.PAUSED
        )

        # Verify that LoopRecoveryAction was added to the event stream
        event_stream.add_event.assert_called_once()
        args, kwargs = event_stream.add_event.call_args
        loop_recovery_action, source = args
        assert isinstance(loop_recovery_action, LoopRecoveryAction)
        assert loop_recovery_action.option == 2
        assert source == EventSource.USER

        # Check the return values
        assert close_repl is True
        assert new_session_requested is False

    @pytest.mark.asyncio
    async def test_regular_resume_without_loop_recovery(self):
        """Test that regular resume without option sends continue message."""
        # Create a mock event stream
        event_stream = MagicMock(spec=EventStream)

        # Call handle_resume_command without loop recovery option
        close_repl, new_session_requested = await handle_resume_command(
            '/resume', event_stream, AgentState.PAUSED
        )

        # Verify that MessageAction was added to the event stream
        event_stream.add_event.assert_called_once()
        args, kwargs = event_stream.add_event.call_args
        message_action, source = args
        assert isinstance(message_action, MessageAction)
        assert message_action.content == 'continue'
        assert source == EventSource.USER

        # Check the return values
        assert close_repl is True
        assert new_session_requested is False

    @pytest.mark.asyncio
    async def test_handle_commands_with_loop_recovery_resume(self):
        """Test that handle_commands properly routes loop recovery resume commands."""
        from openhands.cli.commands import handle_commands

        # Create mock dependencies
        event_stream = MagicMock(spec=EventStream)
        usage_metrics = MagicMock()
        sid = 'test-session-id'
        config = MagicMock()
        current_dir = '/test/dir'
        settings_store = MagicMock()
        agent_state = AgentState.PAUSED

        # Mock handle_resume_command
        with patch(
            'openhands.cli.commands.handle_resume_command'
        ) as mock_handle_resume:
            mock_handle_resume.return_value = (False, False)

            # Call handle_commands with loop recovery resume
            close_repl, reload_microagents, new_session, _ = await handle_commands(
                '/resume 1',
                event_stream,
                usage_metrics,
                sid,
                config,
                current_dir,
                settings_store,
                agent_state,
            )

            # Check that handle_resume_command was called with correct args
            # (the full command string is forwarded so the option can be parsed)
            mock_handle_resume.assert_called_once_with(
                '/resume 1', event_stream, agent_state
            )

            # Check the return values
            assert close_repl is False
            assert reload_microagents is False
            assert new_session is False

View File

@@ -271,7 +271,7 @@ class TestCliCommandsPauseResume:
)
# Check that handle_resume_command was called with correct args
mock_handle_resume.assert_called_once_with(event_stream, agent_state)
mock_handle_resume.assert_called_once_with(message, event_stream, agent_state)
# Check the return values
assert close_repl is False

View File

@@ -0,0 +1,374 @@
"""Tests for agent controller loop recovery functionality."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from openhands.controller.agent_controller import AgentController
from openhands.controller.stuck import StuckDetector
from openhands.core.schema import AgentState
from openhands.events import EventStream
from openhands.events.action import LoopRecoveryAction, MessageAction
from openhands.events.observation import LoopDetectionObservation
from openhands.server.services.conversation_stats import ConversationStats
from openhands.storage.memory import InMemoryFileStore
class TestAgentControllerLoopRecovery:
    """Tests for agent controller loop recovery functionality.

    The tests cover two entry points of ``AgentController``:

    * ``attempt_loop_recovery()`` -- expected to emit a
      ``LoopDetectionObservation`` plus a pause event when the stuck detector
      holds an analysis, and to do nothing otherwise.
    * ``_handle_loop_recovery_action()`` -- expected to dispatch a
      ``LoopRecoveryAction`` to the recovery routine matching its ``option``.
    """

    @staticmethod
    def _added_events(controller):
        """Return the events passed positionally to ``event_stream.add_event``.

        The event is the first positional argument of ``add_event`` (the
        source follows it). Calls recorded without positional arguments are
        skipped so that callers never index into an empty tuple.
        """
        return [
            call.args[0]
            for call in controller.event_stream.add_event.call_args_list
            if call.args
        ]

    @staticmethod
    def _is_pause_event(event):
        """Return True if ``event`` carries ``agent_state == AgentState.PAUSED``."""
        return getattr(event, 'agent_state', None) == AgentState.PAUSED

    @pytest.fixture
    def mock_controller(self):
        """Create an ``AgentController`` wired to mocked collaborators.

        The event stream, conversation stats, agent and stuck detector are
        all mocks. The controller starts in ``RUNNING`` state with an empty
        history, an iteration counter of 10, and a stuck detector that
        reports "not stuck" with no analysis.
        """
        # NOTE(review): the extra ``event_stream=`` keyword attaches a real
        # in-memory EventStream as an attribute of the mock; nothing in this
        # class reads it directly -- confirm whether it is still needed.
        mock_event_stream = MagicMock(
            spec=EventStream,
            event_stream=EventStream(
                sid='test-session-id', file_store=InMemoryFileStore({})
            ),
        )
        mock_event_stream.sid = 'test-session-id'
        mock_event_stream.get_latest_event_id.return_value = 0

        mock_conversation_stats = MagicMock(spec=ConversationStats)

        mock_agent = MagicMock()
        mock_agent.act = AsyncMock()

        controller = AgentController(
            agent=mock_agent,
            event_stream=mock_event_stream,
            conversation_stats=mock_conversation_stats,
            iteration_delta=100,
            headless_mode=True,
        )

        # Pin the state fields the loop-recovery code paths are checked against.
        controller.state.history = []
        controller.state.agent_state = AgentState.RUNNING
        controller.state.iteration_flag = MagicMock()
        controller.state.iteration_flag.current_value = 10

        # Replace the real detector so each test decides what "stuck" means.
        controller._stuck_detector = MagicMock(spec=StuckDetector)
        controller._stuck_detector.stuck_analysis = None
        controller._stuck_detector.is_stuck = MagicMock(return_value=False)

        return controller

    @pytest.mark.asyncio
    async def test_controller_detects_loop_and_produces_observation(
        self, mock_controller
    ):
        """Test that controller detects loops and produces LoopDetectionObservation."""
        # Setup stuck detector to detect a loop
        detector = mock_controller._stuck_detector
        detector.is_stuck.return_value = True
        detector.stuck_analysis = MagicMock()
        detector.stuck_analysis.loop_type = 'repeating_action_observation'
        detector.stuck_analysis.loop_start_idx = 5

        result = mock_controller.attempt_loop_recovery()

        # Verify that loop recovery was attempted
        assert result is True
        mock_controller.event_stream.add_event.assert_called()

        events = self._added_events(mock_controller)
        loop_observations = [
            event for event in events if isinstance(event, LoopDetectionObservation)
        ]

        assert loop_observations, 'LoopDetectionObservation should be created'
        for observation in loop_observations:
            assert 'Agent detected in a loop!' in observation.content
            assert 'repeating_action_observation' in observation.content
            # 10 is the iteration counter configured by the fixture.
            assert 'Loop detected at iteration 10' in observation.content

        assert any(self._is_pause_event(event) for event in events), (
            'Agent should be paused'
        )

    @pytest.mark.asyncio
    async def test_controller_handles_loop_recovery_action_option_1(
        self, mock_controller
    ):
        """Test that controller handles LoopRecoveryAction with option 1."""
        # Setup stuck analysis
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5
        # Stub the recovery routine so only the dispatch is exercised.
        mock_controller._perform_loop_recovery = AsyncMock()

        await mock_controller._handle_loop_recovery_action(
            LoopRecoveryAction(option=1)
        )

        mock_controller._perform_loop_recovery.assert_called_once_with(
            mock_controller._stuck_detector.stuck_analysis
        )

    @pytest.mark.asyncio
    async def test_controller_handles_loop_recovery_action_option_2(
        self, mock_controller
    ):
        """Test that controller handles LoopRecoveryAction with option 2."""
        # Setup stuck analysis
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5
        # Stub the restart routine so only the dispatch is exercised.
        mock_controller._restart_with_last_user_message = AsyncMock()

        await mock_controller._handle_loop_recovery_action(
            LoopRecoveryAction(option=2)
        )

        mock_controller._restart_with_last_user_message.assert_called_once_with(
            mock_controller._stuck_detector.stuck_analysis
        )

    @pytest.mark.asyncio
    async def test_controller_handles_loop_recovery_action_option_3(
        self, mock_controller
    ):
        """Test that controller handles LoopRecoveryAction with option 3 (stop)."""
        # Setup stuck analysis
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        # Stub the state transition so it can be asserted on.
        mock_controller.set_agent_state_to = AsyncMock()

        await mock_controller._handle_loop_recovery_action(
            LoopRecoveryAction(option=3)
        )

        mock_controller.set_agent_state_to.assert_called_once_with(AgentState.STOPPED)

    @pytest.mark.asyncio
    async def test_controller_ignores_loop_recovery_without_stuck_analysis(
        self, mock_controller
    ):
        """Test that controller ignores LoopRecoveryAction when no stuck analysis exists."""
        # Ensure no stuck analysis
        mock_controller._stuck_detector.stuck_analysis = None
        # Stub every recovery path so any unexpected call is observable.
        mock_controller._perform_loop_recovery = AsyncMock()
        mock_controller._restart_with_last_user_message = AsyncMock()
        mock_controller.set_agent_state_to = AsyncMock()

        await mock_controller._handle_loop_recovery_action(
            LoopRecoveryAction(option=1)
        )

        # Verify that no recovery methods were called
        mock_controller._perform_loop_recovery.assert_not_called()
        mock_controller._restart_with_last_user_message.assert_not_called()
        mock_controller.set_agent_state_to.assert_not_called()

    @pytest.mark.asyncio
    async def test_controller_no_loop_recovery_when_not_stuck(self, mock_controller):
        """Test that controller doesn't attempt recovery when not stuck."""
        # Setup no stuck analysis
        mock_controller._stuck_detector.stuck_analysis = None
        # Reset the mock to ignore any previous calls (like system message)
        mock_controller.event_stream.add_event.reset_mock()

        result = mock_controller.attempt_loop_recovery()

        # Verify that no recovery was attempted
        assert result is False

        # Other events may be added, but none specific to loop recovery.
        loop_recovery_events = [
            event
            for event in self._added_events(mock_controller)
            if isinstance(event, LoopDetectionObservation)
            or self._is_pause_event(event)
        ]
        assert not loop_recovery_events, (
            'No loop recovery events should be added when not stuck'
        )

    @pytest.mark.asyncio
    async def test_controller_state_transition_after_loop_recovery(
        self, mock_controller
    ):
        """Test that controller state transitions correctly after loop recovery."""
        # Setup initial state
        mock_controller.state.agent_state = AgentState.RUNNING
        # Setup stuck detector to detect a loop
        detector = mock_controller._stuck_detector
        detector.is_stuck.return_value = True
        detector.stuck_analysis = MagicMock()
        detector.stuck_analysis.loop_type = 'monologue'
        detector.stuck_analysis.loop_start_idx = 3

        result = mock_controller.attempt_loop_recovery()

        # Verify that recovery was attempted
        assert result is True
        # Verify that a pause event was pushed onto the event stream
        assert any(
            self._is_pause_event(event)
            for event in self._added_events(mock_controller)
        ), 'Agent should be paused after loop detection'

    @pytest.mark.asyncio
    async def test_controller_resumes_after_loop_recovery(self, mock_controller):
        """Test that a LoopRecoveryAction with option 1 triggers recovery exactly once.

        Whether the agent then continues normal operation is left to
        integration tests with actual agent execution.
        """
        # Setup stuck analysis
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5
        # Stub the recovery routine so only the dispatch is exercised.
        mock_controller._perform_loop_recovery = AsyncMock()

        await mock_controller._handle_loop_recovery_action(
            LoopRecoveryAction(option=1)
        )

        # Verify that recovery was performed
        mock_controller._perform_loop_recovery.assert_called_once()

    @pytest.mark.asyncio
    async def test_controller_truncates_history_during_loop_recovery(
        self, mock_controller
    ):
        """Test that controller correctly truncates history during loop recovery."""
        from openhands.events.action import CmdRunAction
        from openhands.events.observation import CmdOutputObservation, NullObservation

        # Build a 10-event history: a user message, a null observation, then
        # four identical (command, output) pairs simulating a loop.
        user_msg = MessageAction(
            content='Hello, help me with this task', wait_for_response=False
        )
        user_msg._source = 'user'
        user_msg._id = 1

        agent_obs = NullObservation(content='')
        agent_obs._id = 2

        mock_history = [user_msg, agent_obs]
        for event_id in range(3, 11):
            if event_id % 2 == 1:  # odd ids are actions
                cmd = CmdRunAction(command='ls -la')
                cmd._id = event_id
                mock_history.append(cmd)
            else:  # even ids are the observations caused by the previous action
                obs = CmdOutputObservation(
                    content='file1.txt file2.txt', command='ls -la'
                )
                obs._id = event_id
                obs._cause = event_id - 1
                mock_history.append(obs)

        mock_controller.state.history = mock_history
        mock_controller.state.end_id = 10

        # Setup stuck analysis to indicate loop starts at index 5
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Exercise the real (un-mocked) recovery routine.
        await mock_controller._perform_loop_recovery(
            mock_controller._stuck_detector.stuck_analysis
        )

        # The recovery point is index 5, so events 0-4 (5 events) must remain.
        history = mock_controller.state.history
        assert len(history) == 5, (
            f'Expected 5 events after truncation, got {len(history)}'
        )

        # Verify the specific events that remain
        expected_ids = [1, 2, 3, 4, 5]
        for index, (event, expected_id) in enumerate(zip(history, expected_ids)):
            assert event.id == expected_id, (
                f'Event at index {index} should have id {expected_id}, got {event.id}'
            )

        # Verify end_id was updated
        assert mock_controller.state.end_id == 5, (
            f'Expected end_id to be 5, got {mock_controller.state.end_id}'
        )

View File

@@ -116,6 +116,7 @@ class TestStuckDetector:
state.history.append(cmd_observation)
assert stuck_detector.is_stuck(headless_mode=True) is False
assert stuck_detector.stuck_analysis is None
def test_interactive_mode_resets_after_user_message(
self, stuck_detector: StuckDetector
@@ -237,6 +238,11 @@ class TestStuckDetector:
assert stuck_detector.is_stuck(headless_mode=True) is True
mock_warning.assert_called_once_with('Action, Observation loop detected')
# recover to before first loop pattern
assert stuck_detector.stuck_analysis.loop_type == 'repeating_action_observation'
assert stuck_detector.stuck_analysis.loop_repeat_times == 4
assert stuck_detector.stuck_analysis.loop_start_idx == 1
def test_is_stuck_repeating_action_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
# (action, error_observation), not necessarily the same error
@@ -290,6 +296,9 @@ class TestStuckDetector:
mock_warning.assert_called_once_with(
'Action, ErrorObservation loop detected'
)
assert stuck_detector.stuck_analysis.loop_type == 'repeating_action_error'
assert stuck_detector.stuck_analysis.loop_repeat_times == 3
assert stuck_detector.stuck_analysis.loop_start_idx == 1
def test_is_stuck_invalid_syntax_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
@@ -494,6 +503,12 @@ class TestStuckDetector:
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck(headless_mode=True) is True
mock_warning.assert_called_once_with('Action, Observation pattern detected')
assert (
stuck_detector.stuck_analysis.loop_type
== 'repeating_action_observation_pattern'
)
assert stuck_detector.stuck_analysis.loop_repeat_times == 3
assert stuck_detector.stuck_analysis.loop_start_idx == 0 # null ignored
def test_is_stuck_not_stuck(self, stuck_detector: StuckDetector):
state = stuck_detector.state
@@ -585,6 +600,9 @@ class TestStuckDetector:
state.history.append(message_action_6)
assert stuck_detector.is_stuck(headless_mode=True)
assert stuck_detector.stuck_analysis.loop_type == 'monologue'
assert stuck_detector.stuck_analysis.loop_repeat_times == 3
assert stuck_detector.stuck_analysis.loop_start_idx == 2 # null ignored
# Add an observation event between the repeated message actions
cmd_output_observation = CmdOutputObservation(
@@ -628,6 +646,9 @@ class TestStuckDetector:
mock_warning.assert_called_once_with(
'Context window error loop detected - repeated condensation events'
)
assert stuck_detector.stuck_analysis.loop_type == 'context_window_error'
assert stuck_detector.stuck_analysis.loop_repeat_times == 2
assert stuck_detector.stuck_analysis.loop_start_idx == 0
def test_is_not_stuck_context_window_error_with_other_events(self, stuck_detector):
"""Test that we don't detect a loop when there are other events between condensation events."""
@@ -731,6 +752,9 @@ class TestStuckDetector:
mock_warning.assert_called_once_with(
'Context window error loop detected - repeated condensation events'
)
assert stuck_detector.stuck_analysis.loop_type == 'context_window_error'
assert stuck_detector.stuck_analysis.loop_repeat_times == 2
assert stuck_detector.stuck_analysis.loop_start_idx == 0
def test_is_not_stuck_context_window_error_in_non_headless(self, stuck_detector):
"""Test that in non-headless mode, we don't detect a loop if the condensation events