mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-10 23:38:08 -05:00
Enhance dead-loop recovery by pausing agent and reprompting (#11439)
Co-authored-by: Engel Nyst <enyst@users.noreply.github.com> Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -47,6 +47,7 @@ from openhands.core.schema.exit_reason import ExitReason
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import (
|
||||
ChangeAgentStateAction,
|
||||
LoopRecoveryAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.stream import EventStream
|
||||
@@ -159,9 +160,9 @@ async def handle_commands(
|
||||
exit_reason = ExitReason.INTENTIONAL
|
||||
elif command == '/settings':
|
||||
await handle_settings_command(config, settings_store)
|
||||
elif command == '/resume':
|
||||
elif command.startswith('/resume'):
|
||||
close_repl, new_session_requested = await handle_resume_command(
|
||||
event_stream, agent_state
|
||||
command, event_stream, agent_state
|
||||
)
|
||||
elif command == '/mcp':
|
||||
await handle_mcp_command(config)
|
||||
@@ -294,6 +295,7 @@ async def handle_settings_command(
|
||||
# Setting the agent state to RUNNING will currently freeze the agent without continuing with the rest of the task.
|
||||
# This is a workaround to handle the resume command for the time being. Replace user message with the state change event once the issue is fixed.
|
||||
async def handle_resume_command(
|
||||
command: str,
|
||||
event_stream: EventStream,
|
||||
agent_state: str,
|
||||
) -> tuple[bool, bool]:
|
||||
@@ -309,10 +311,29 @@ async def handle_resume_command(
|
||||
)
|
||||
return close_repl, new_session_requested
|
||||
|
||||
event_stream.add_event(
|
||||
MessageAction(content='continue'),
|
||||
EventSource.USER,
|
||||
)
|
||||
# Check if this is a loop recovery resume with an option
|
||||
if command.strip() != '/resume':
|
||||
# Parse the option from the command (e.g., '/resume 1', '/resume 2')
|
||||
parts = command.strip().split()
|
||||
if len(parts) == 2 and parts[1] in ['1', '2']:
|
||||
option = parts[1]
|
||||
# Send the option as a message to be handled by the controller
|
||||
event_stream.add_event(
|
||||
LoopRecoveryAction(option=int(option)),
|
||||
EventSource.USER,
|
||||
)
|
||||
else:
|
||||
# Invalid format, send as regular resume
|
||||
event_stream.add_event(
|
||||
MessageAction(content='continue'),
|
||||
EventSource.USER,
|
||||
)
|
||||
else:
|
||||
# Regular resume without loop recovery option
|
||||
event_stream.add_event(
|
||||
MessageAction(content='continue'),
|
||||
EventSource.USER,
|
||||
)
|
||||
|
||||
# event_stream.add_event(
|
||||
# ChangeAgentStateAction(AgentState.RUNNING),
|
||||
|
||||
@@ -430,9 +430,25 @@ async def run_session(
|
||||
# No session restored, no initial action: prompt for the user's first message
|
||||
asyncio.create_task(prompt_for_next_task(''))
|
||||
|
||||
await run_agent_until_done(
|
||||
controller, runtime, memory, [AgentState.STOPPED, AgentState.ERROR]
|
||||
)
|
||||
skip_set_callback = False
|
||||
while True:
|
||||
await run_agent_until_done(
|
||||
controller,
|
||||
runtime,
|
||||
memory,
|
||||
[AgentState.STOPPED, AgentState.ERROR],
|
||||
skip_set_callback,
|
||||
)
|
||||
# Try loop recovery in CLI app
|
||||
if (
|
||||
controller.state.agent_state == AgentState.ERROR
|
||||
and controller.state.last_error.startswith('AgentStuckInLoopError')
|
||||
):
|
||||
controller.attempt_loop_recovery()
|
||||
skip_set_callback = True
|
||||
continue
|
||||
else:
|
||||
break
|
||||
|
||||
await cleanup_session(loop, agent, runtime, controller)
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
LoopDetectionObservation,
|
||||
MCPObservation,
|
||||
TaskTrackingObservation,
|
||||
)
|
||||
@@ -309,6 +310,8 @@ def display_event(event: Event, config: OpenHandsConfig) -> None:
|
||||
display_agent_state_change_message(event.agent_state)
|
||||
elif isinstance(event, ErrorObservation):
|
||||
display_error(event.content)
|
||||
elif isinstance(event, LoopDetectionObservation):
|
||||
handle_loop_recovery_state_observation(event)
|
||||
|
||||
|
||||
def display_message(message: str, is_agent_message: bool = False) -> None:
|
||||
@@ -1039,3 +1042,25 @@ class UserCancelledError(Exception):
|
||||
"""Raised when the user cancels an operation via key binding."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def handle_loop_recovery_state_observation(
    observation: LoopDetectionObservation,
) -> None:
    """Render a loop-detection observation in the terminal.

    The observation's ``content`` is placed in a read-only, line-wrapped
    text area styled with ``COLOR_GREY``, wrapped in a frame titled
    'Agent Loop Detection', and printed after one blank line.

    Args:
        observation: The observation whose ``content`` is displayed.
    """
    message_area = TextArea(
        text=observation.content,
        read_only=True,
        style=COLOR_GREY,
        wrap_lines=True,
    )
    framed_message = Frame(
        message_area,
        title='Agent Loop Detection',
        style=f'fg:{COLOR_GREY}',
    )

    # Blank line separates the frame from whatever was printed before it.
    print_formatted_text('')
    print_container(framed_message)
|
||||
|
||||
@@ -64,6 +64,7 @@ from openhands.events.action import (
|
||||
MessageAction,
|
||||
NullAction,
|
||||
SystemMessageAction,
|
||||
LoopRecoveryAction,
|
||||
)
|
||||
from openhands.events.action.agent import (
|
||||
CondensationAction,
|
||||
@@ -77,6 +78,7 @@ from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
NullObservation,
|
||||
Observation,
|
||||
LoopDetectionObservation,
|
||||
)
|
||||
from openhands.events.serialization.event import truncate_content
|
||||
from openhands.llm.metrics import Metrics
|
||||
@@ -523,6 +525,8 @@ class AgentController:
|
||||
elif isinstance(action, AgentRejectAction):
|
||||
self.state.outputs = action.outputs
|
||||
await self.set_agent_state_to(AgentState.REJECTED)
|
||||
elif isinstance(action, LoopRecoveryAction):
|
||||
await self._handle_loop_recovery_action(action)
|
||||
|
||||
async def _handle_observation(self, observation: Observation) -> None:
|
||||
"""Handles observation from the event stream.
|
||||
@@ -595,6 +599,25 @@ class AgentController:
|
||||
if action.wait_for_response:
|
||||
await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)
|
||||
|
||||
async def _handle_loop_recovery_action(self, action: LoopRecoveryAction) -> None:
    """Apply the recovery option carried by a ``LoopRecoveryAction``.

    Nothing happens unless the stuck detector currently holds a stuck
    analysis. With one present:

    * option 1 restarts from before the loop,
    * option 2 restarts with the last user message,
    * option 3 sets the agent state to ``AgentState.STOPPED``.

    Any other option value is ignored.

    Args:
        action: The action whose ``option`` selects the recovery strategy.
    """
    analysis = self._stuck_detector.stuck_analysis
    if not analysis:
        # No loop is on record, so there is nothing to recover from.
        return

    choice = action.option
    if choice == 1:
        await self._perform_loop_recovery(analysis)
    elif choice == 2:
        await self._restart_with_last_user_message(analysis)
    elif choice == 3:
        await self.set_agent_state_to(AgentState.STOPPED)
|
||||
|
||||
def _reset(self) -> None:
|
||||
"""Resets the agent controller."""
|
||||
# Runnable actions need an Observation
|
||||
@@ -1084,6 +1107,45 @@ class AgentController:
|
||||
|
||||
return self._stuck_detector.is_stuck(self.headless_mode)
|
||||
|
||||
def attempt_loop_recovery(self) -> bool:
    """Announce a detected loop and pause the agent so the user can pick a recovery option.

    Only supports CLI for now. When the stuck detector holds a stuck
    analysis, a ``LoopDetectionObservation`` listing the recovery commands
    is added to the event stream, followed by a ``ChangeAgentStateAction``
    that pauses the agent. No history is modified here.

    Returns:
        bool: True if the recovery prompt was emitted and the pause was
            requested, False if no stuck analysis is available.
    """
    analysis = self._stuck_detector.stuck_analysis
    if not analysis:
        # Not in a detected loop: nothing to announce.
        return False

    # Handle loop recovery in CLI mode by pausing the agent and presenting
    # recovery options.
    recovery_point = analysis.loop_start_idx

    # Present loop detection message
    self.event_stream.add_event(
        LoopDetectionObservation(
            content=f"""⚠️ Agent detected in a loop!
Loop type: {analysis.loop_type}
Loop detected at iteration {self.state.iteration_flag.current_value}
\nRecovery options:
/resume 1. Restart from before loop (preserves {recovery_point} events)
/resume 2. Restart with last user message (reuses your most recent instruction)
/exit. Quit directly
\nThe agent has been paused. Type '/resume 1', '/resume 2', or '/exit' to choose an option.
"""
        ),
        source=EventSource.ENVIRONMENT,
    )

    # Pause the agent using the same mechanism as Ctrl+P
    # This ensures consistent behavior and avoids event loop conflicts
    self.event_stream.add_event(
        ChangeAgentStateAction(AgentState.PAUSED),
        EventSource.ENVIRONMENT,  # Use ENVIRONMENT source to distinguish from user pause
    )
    return True
|
||||
|
||||
def _prepare_metrics_for_frontend(self, action: Action) -> None:
|
||||
"""Create a minimal metrics object for frontend display and log it.
|
||||
|
||||
@@ -1208,5 +1270,92 @@ class AgentController:
|
||||
)
|
||||
return self._cached_first_user_message
|
||||
|
||||
async def _perform_loop_recovery(
    self, stuck_analysis: StuckDetector.StuckAnalysis
) -> None:
    """Recover from a loop by cutting history back to where the loop began.

    The history is truncated at ``stuck_analysis.loop_start_idx``, the agent
    state is set to ``AWAITING_USER_INPUT`` so the user can give new
    instructions, and a ``LoopDetectionObservation`` confirming the reset
    is added to the event stream.

    Args:
        stuck_analysis: Analysis of the detected loop; only its
            ``loop_start_idx`` is read here.
    """
    await self._truncate_memory_to_point(stuck_analysis.loop_start_idx)

    # Hand control back to the user rather than resuming automatically.
    await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT)

    completion_notice = LoopDetectionObservation(
        content="""✅ Loop recovery completed. Agent has been reset to before the loop.
You can now provide new instructions to continue.
"""
    )
    self.event_stream.add_event(completion_notice, source=EventSource.ENVIRONMENT)
|
||||
|
||||
async def _truncate_memory_to_point(self, recovery_point: int) -> None:
    """Truncate the state history so only events before ``recovery_point`` remain.

    Args:
        recovery_point: Index into ``self.state.history``. Events at this
            index and later are dropped. Values outside
            ``[0, len(history))`` leave the history untouched.
    """
    history = self.state.history

    # A negative index would make the slice below count from the end and
    # silently keep the wrong events; an index at or past the end has
    # nothing to cut. Both are treated as a no-op.
    if recovery_point < 0 or recovery_point >= len(history):
        return

    kept = history[:recovery_point]
    self.state.history = kept

    # end_id follows the last retained event; -1 marks an empty history.
    self.state.end_id = kept[-1].id if kept else -1

    # Clear any cached messages
    self._cached_first_user_message = None
|
||||
|
||||
async def _restart_with_last_user_message(
    self, stuck_analysis: StuckDetector.StuckAnalysis
) -> None:
    """Restart the agent by replaying the most recent user message.

    The history is searched from newest to oldest for a ``MessageAction``
    whose source is the user. If one is found, the history is truncated at
    ``stuck_analysis.loop_start_idx``, the agent state is set to
    ``RUNNING``, a ``LoopDetectionObservation`` announcing the restart is
    added to the event stream, and a copy of the message is passed to
    ``_handle_action``. If none is found, a notice is printed and
    ``_perform_loop_recovery`` runs instead.

    Args:
        stuck_analysis: Analysis of the detected loop.
    """
    # Walk the history backwards to pick up the newest user-authored message.
    last_user_message = next(
        (
            event
            for event in reversed(self.state.history)
            if isinstance(event, MessageAction) and event.source == EventSource.USER
        ),
        None,
    )

    if not last_user_message:
        # If no user message found, fall back to regular recovery
        print('\n⚠️ No previous user message found. Using standard recovery.')
        await self._perform_loop_recovery(stuck_analysis)
        return

    # Truncate memory to just before the loop started
    await self._truncate_memory_to_point(stuck_analysis.loop_start_idx)

    await self.set_agent_state_to(AgentState.RUNNING)

    self.event_stream.add_event(
        LoopDetectionObservation(
            content=f"""\n✅ Restarting with your last instruction: {last_user_message.content}
Agent is now continuing with the same task...
"""
        ),
        source=EventSource.ENVIRONMENT,
    )

    # Replay the instruction as a fresh user-sourced action so the agent
    # picks the task up again.
    replayed_action = MessageAction(
        content=last_user_message.content, wait_for_response=False
    )
    replayed_action._source = EventSource.USER  # type: ignore [attr-defined]
    await self._handle_action(replayed_action)
|
||||
|
||||
def save_state(self):
|
||||
self.state_tracker.save_state()
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events import Event, EventSource
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.commands import IPythonRunCellAction
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.observation import (
|
||||
CmdOutputObservation,
|
||||
IPythonRunCellObservation,
|
||||
@@ -22,8 +25,15 @@ class StuckDetector:
|
||||
'SyntaxError: incomplete input',
|
||||
]
|
||||
|
||||
@dataclass
|
||||
class StuckAnalysis:
|
||||
loop_type: str
|
||||
loop_repeat_times: int
|
||||
loop_start_idx: int  # index in filtered_history plus filtered_history_offset
|
||||
|
||||
def __init__(self, state: State):
|
||||
self.state = state
|
||||
self.stuck_analysis: Optional[StuckDetector.StuckAnalysis] = None
|
||||
|
||||
def is_stuck(self, headless_mode: bool = True) -> bool:
|
||||
"""Checks if the agent is stuck in a loop.
|
||||
@@ -36,6 +46,7 @@ class StuckDetector:
|
||||
Returns:
|
||||
bool: True if the agent is stuck in a loop, False otherwise.
|
||||
"""
|
||||
filtered_history_offset = 0
|
||||
if not headless_mode:
|
||||
# In interactive mode, only look at history after the last user message
|
||||
last_user_msg_idx = -1
|
||||
@@ -46,7 +57,7 @@ class StuckDetector:
|
||||
):
|
||||
last_user_msg_idx = len(self.state.history) - i - 1
|
||||
break
|
||||
|
||||
filtered_history_offset = last_user_msg_idx + 1
|
||||
history_to_check = self.state.history[last_user_msg_idx + 1 :]
|
||||
else:
|
||||
# In headless mode, look at all history
|
||||
@@ -86,31 +97,45 @@ class StuckDetector:
|
||||
break
|
||||
|
||||
# scenario 1: same action, same observation
|
||||
if self._is_stuck_repeating_action_observation(last_actions, last_observations):
|
||||
if self._is_stuck_repeating_action_observation(
|
||||
last_actions, last_observations, filtered_history, filtered_history_offset
|
||||
):
|
||||
return True
|
||||
|
||||
# scenario 2: same action, errors
|
||||
if self._is_stuck_repeating_action_error(last_actions, last_observations):
|
||||
if self._is_stuck_repeating_action_error(
|
||||
last_actions, last_observations, filtered_history, filtered_history_offset
|
||||
):
|
||||
return True
|
||||
|
||||
# scenario 3: monologue
|
||||
if self._is_stuck_monologue(filtered_history):
|
||||
if self._is_stuck_monologue(filtered_history, filtered_history_offset):
|
||||
return True
|
||||
|
||||
# scenario 4: action, observation pattern on the last six steps
|
||||
if len(filtered_history) >= 6:
|
||||
if self._is_stuck_action_observation_pattern(filtered_history):
|
||||
if self._is_stuck_action_observation_pattern(
|
||||
filtered_history, filtered_history_offset
|
||||
):
|
||||
return True
|
||||
|
||||
# scenario 5: context window error loop
|
||||
if len(filtered_history) >= 10:
|
||||
if self._is_stuck_context_window_error(filtered_history):
|
||||
if self._is_stuck_context_window_error(
|
||||
filtered_history, filtered_history_offset
|
||||
):
|
||||
return True
|
||||
|
||||
# Empty stuck_analysis when not stuck
|
||||
self.stuck_analysis = None
|
||||
return False
|
||||
|
||||
def _is_stuck_repeating_action_observation(
|
||||
self, last_actions: list[Event], last_observations: list[Event]
|
||||
self,
|
||||
last_actions: list[Event],
|
||||
last_observations: list[Event],
|
||||
filtered_history: list[Event],
|
||||
filtered_history_offset: int = 0,
|
||||
) -> bool:
|
||||
# scenario 1: same action, same observation
|
||||
# it takes 4 actions and 4 observations to detect a loop
|
||||
@@ -128,12 +153,22 @@ class StuckDetector:
|
||||
|
||||
if actions_equal and observations_equal:
|
||||
logger.warning('Action, Observation loop detected')
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='repeating_action_observation',
|
||||
loop_repeat_times=4,
|
||||
loop_start_idx=filtered_history.index(last_actions[-1])
|
||||
+ filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_stuck_repeating_action_error(
|
||||
self, last_actions: list[Event], last_observations: list[Event]
|
||||
self,
|
||||
last_actions: list[Event],
|
||||
last_observations: list[Event],
|
||||
filtered_history: list[Event],
|
||||
filtered_history_offset: int = 0,
|
||||
) -> bool:
|
||||
# scenario 2: same action, errors
|
||||
# it takes 3 actions and 3 observations to detect a loop
|
||||
@@ -147,6 +182,12 @@ class StuckDetector:
|
||||
# and the last three observations are all errors?
|
||||
if all(isinstance(obs, ErrorObservation) for obs in last_observations[:3]):
|
||||
logger.warning('Action, ErrorObservation loop detected')
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='repeating_action_error',
|
||||
loop_repeat_times=3,
|
||||
loop_start_idx=filtered_history.index(last_actions[-1])
|
||||
+ filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
# or, are the last three observations all IPythonRunCellObservation with SyntaxError?
|
||||
elif all(
|
||||
@@ -167,6 +208,12 @@ class StuckDetector:
|
||||
error_message,
|
||||
):
|
||||
logger.warning(warning)
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='repeating_action_error',
|
||||
loop_repeat_times=3,
|
||||
loop_start_idx=filtered_history.index(last_actions[-1])
|
||||
+ filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
elif error_message in (
|
||||
'SyntaxError: invalid syntax. Perhaps you forgot a comma?',
|
||||
@@ -180,6 +227,12 @@ class StuckDetector:
|
||||
error_message,
|
||||
):
|
||||
logger.warning(warning)
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='repeating_action_error',
|
||||
loop_repeat_times=3,
|
||||
loop_start_idx=filtered_history.index(last_actions[-1])
|
||||
+ filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -255,7 +308,9 @@ class StuckDetector:
|
||||
# and the 3rd-to-last line is identical across all occurrences
|
||||
return len(error_lines) == 3 and len(set(error_lines)) == 1
|
||||
|
||||
def _is_stuck_monologue(self, filtered_history: list[Event]) -> bool:
|
||||
def _is_stuck_monologue(
|
||||
self, filtered_history: list[Event], filtered_history_offset: int = 0
|
||||
) -> bool:
|
||||
# scenario 3: monologue
|
||||
# check for repeated MessageActions with source=AGENT
|
||||
# see if the agent is engaged in a good old monologue, telling itself the same thing over and over
|
||||
@@ -286,11 +341,16 @@ class StuckDetector:
|
||||
|
||||
if not has_observation_between:
|
||||
logger.warning('Repeated MessageAction with source=AGENT detected')
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='monologue',
|
||||
loop_repeat_times=3,
|
||||
loop_start_idx=start_index + filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_stuck_action_observation_pattern(
|
||||
self, filtered_history: list[Event]
|
||||
self, filtered_history: list[Event], filtered_history_offset: int = 0
|
||||
) -> bool:
|
||||
# scenario 4: action, observation pattern on the last six steps
|
||||
# check if the agent repeats the same (Action, Observation)
|
||||
@@ -330,10 +390,18 @@ class StuckDetector:
|
||||
|
||||
if actions_equal and observations_equal:
|
||||
logger.warning('Action, Observation pattern detected')
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='repeating_action_observation_pattern',
|
||||
loop_repeat_times=3,
|
||||
loop_start_idx=filtered_history.index(last_six_actions[-1])
|
||||
+ filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_stuck_context_window_error(self, filtered_history: list[Event]) -> bool:
|
||||
def _is_stuck_context_window_error(
|
||||
self, filtered_history: list[Event], filtered_history_offset: int = 0
|
||||
) -> bool:
|
||||
"""Detects if we're stuck in a loop of context window errors.
|
||||
|
||||
This happens when we repeatedly get context window errors and try to trim,
|
||||
@@ -377,6 +445,11 @@ class StuckDetector:
|
||||
logger.warning(
|
||||
'Context window error loop detected - repeated condensation events'
|
||||
)
|
||||
self.stuck_analysis = StuckDetector.StuckAnalysis(
|
||||
loop_type='context_window_error',
|
||||
loop_repeat_times=2,
|
||||
loop_start_idx=start_idx + filtered_history_offset,
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -13,6 +13,7 @@ async def run_agent_until_done(
|
||||
runtime: Runtime,
|
||||
memory: Memory,
|
||||
end_states: list[AgentState],
|
||||
skip_set_callback: bool = False,
|
||||
) -> None:
|
||||
"""run_agent_until_done takes a controller and a runtime, and will run
|
||||
the agent until it reaches a terminal state.
|
||||
@@ -28,18 +29,19 @@ async def run_agent_until_done(
|
||||
else:
|
||||
logger.info(msg)
|
||||
|
||||
if hasattr(runtime, 'status_callback') and runtime.status_callback:
|
||||
raise ValueError(
|
||||
'Runtime status_callback was set, but run_agent_until_done will override it'
|
||||
)
|
||||
if hasattr(controller, 'status_callback') and controller.status_callback:
|
||||
raise ValueError(
|
||||
'Controller status_callback was set, but run_agent_until_done will override it'
|
||||
)
|
||||
if not skip_set_callback:
|
||||
if hasattr(runtime, 'status_callback') and runtime.status_callback:
|
||||
raise ValueError(
|
||||
'Runtime status_callback was set, but run_agent_until_done will override it'
|
||||
)
|
||||
if hasattr(controller, 'status_callback') and controller.status_callback:
|
||||
raise ValueError(
|
||||
'Controller status_callback was set, but run_agent_until_done will override it'
|
||||
)
|
||||
|
||||
runtime.status_callback = status_callback
|
||||
controller.status_callback = status_callback
|
||||
memory.status_callback = status_callback
|
||||
runtime.status_callback = status_callback
|
||||
controller.status_callback = status_callback
|
||||
memory.status_callback = status_callback
|
||||
|
||||
while controller.state.agent_state not in end_states:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@@ -97,3 +97,6 @@ class ActionType(str, Enum):
|
||||
|
||||
TASK_TRACKING = 'task_tracking'
|
||||
"""Views or updates the task list for task management."""
|
||||
|
||||
LOOP_RECOVERY = 'loop_recovery'
|
||||
"""Recover dead loop."""
|
||||
|
||||
@@ -58,3 +58,6 @@ class ObservationType(str, Enum):
|
||||
|
||||
TASK_TRACKING = 'task_tracking'
|
||||
"""Result of a task tracking operation"""
|
||||
|
||||
LOOP_DETECTION = 'loop_detection'
|
||||
"""Results of a dead-loop detection"""
|
||||
|
||||
@@ -9,6 +9,7 @@ from openhands.events.action.agent import (
|
||||
AgentRejectAction,
|
||||
AgentThinkAction,
|
||||
ChangeAgentStateAction,
|
||||
LoopRecoveryAction,
|
||||
RecallAction,
|
||||
TaskTrackingAction,
|
||||
)
|
||||
@@ -45,4 +46,5 @@ __all__ = [
|
||||
'MCPAction',
|
||||
'TaskTrackingAction',
|
||||
'ActionSecurityRisk',
|
||||
'LoopRecoveryAction',
|
||||
]
|
||||
|
||||
@@ -226,3 +226,17 @@ class TaskTrackingAction(Action):
|
||||
return 'Managing 1 task item.'
|
||||
else:
|
||||
return f'Managing {num_tasks} task items.'
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoopRecoveryAction(Action):
|
||||
"""An action that shows three ways to handle dead loop.
|
||||
This action should be invisible to the LLM.
|
||||
Attributes:
|
||||
option (int): 1 allow user to prompt again
|
||||
2 automatically use latest user prompt
|
||||
3 stop agent
|
||||
"""
|
||||
|
||||
option: int = 1
|
||||
action: str = ActionType.LOOP_RECOVERY
|
||||
|
||||
@@ -22,6 +22,7 @@ from openhands.events.observation.files import (
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
)
|
||||
from openhands.events.observation.loop_recovery import LoopDetectionObservation
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
@@ -47,6 +48,7 @@ __all__ = [
|
||||
'AgentCondensationObservation',
|
||||
'RecallObservation',
|
||||
'RecallType',
|
||||
'LoopDetectionObservation',
|
||||
'MCPObservation',
|
||||
'FileDownloadObservation',
|
||||
'TaskTrackingObservation',
|
||||
|
||||
18
openhands/events/observation/loop_recovery.py
Normal file
18
openhands/events/observation/loop_recovery.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoopDetectionObservation(Observation):
|
||||
"""Observation for loop recovery state changes.
|
||||
|
||||
This observation is used to notify the UI layer when agent
|
||||
is in loop recovery mode.
|
||||
|
||||
This observation is CLI-specific and should only be displayed
|
||||
in CLI/TUI mode, not in GUI or other UI modes.
|
||||
"""
|
||||
|
||||
observation: str = ObservationType.LOOP_DETECTION
|
||||
@@ -10,6 +10,7 @@ from openhands.events.action.agent import (
|
||||
ChangeAgentStateAction,
|
||||
CondensationAction,
|
||||
CondensationRequestAction,
|
||||
LoopRecoveryAction,
|
||||
RecallAction,
|
||||
TaskTrackingAction,
|
||||
)
|
||||
@@ -48,6 +49,7 @@ actions = (
|
||||
CondensationRequestAction,
|
||||
MCPAction,
|
||||
TaskTrackingAction,
|
||||
LoopRecoveryAction,
|
||||
)
|
||||
|
||||
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
|
||||
|
||||
@@ -26,6 +26,7 @@ from openhands.events.observation.files import (
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
)
|
||||
from openhands.events.observation.loop_recovery import LoopDetectionObservation
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
@@ -51,6 +52,7 @@ observations = (
|
||||
MCPObservation,
|
||||
FileDownloadObservation,
|
||||
TaskTrackingObservation,
|
||||
LoopDetectionObservation,
|
||||
)
|
||||
|
||||
OBSERVATION_TYPE_TO_CLASS = {
|
||||
|
||||
@@ -33,6 +33,7 @@ from openhands.events.observation import (
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
IPythonRunCellObservation,
|
||||
LoopDetectionObservation,
|
||||
TaskTrackingObservation,
|
||||
UserRejectObservation,
|
||||
)
|
||||
@@ -524,6 +525,9 @@ class ConversationMemory:
|
||||
elif isinstance(obs, FileDownloadObservation):
|
||||
text = truncate_content(obs.content, max_message_chars)
|
||||
message = Message(role='user', content=[TextContent(text=text)])
|
||||
elif isinstance(obs, LoopDetectionObservation):
|
||||
# LoopRecovery should not be observed by llm, handled internally.
|
||||
return []
|
||||
elif (
|
||||
isinstance(obs, RecallObservation)
|
||||
and self.agent_config.enable_prompt_extensions
|
||||
|
||||
@@ -632,7 +632,7 @@ async def test_main_with_session_name_passes_name_to_run_session(
|
||||
) # For REPL control
|
||||
@patch('openhands.cli.main.handle_commands', new_callable=AsyncMock) # For REPL control
|
||||
@patch('openhands.core.setup.State.restore_from_session') # Key mock
|
||||
@patch('openhands.controller.AgentController.__init__') # To check initial_state
|
||||
@patch('openhands.cli.main.create_controller') # To check initial_state
|
||||
@patch('openhands.cli.main.display_runtime_initialization_message') # Cosmetic
|
||||
@patch('openhands.cli.main.display_initialization_animation') # Cosmetic
|
||||
@patch('openhands.cli.main.initialize_repository_for_runtime') # Cosmetic / setup
|
||||
@@ -644,7 +644,7 @@ async def test_run_session_with_name_attempts_state_restore(
|
||||
mock_initialize_repo,
|
||||
mock_display_init_anim,
|
||||
mock_display_runtime_init,
|
||||
mock_agent_controller_init,
|
||||
mock_create_controller,
|
||||
mock_restore_from_session,
|
||||
mock_handle_commands,
|
||||
mock_read_prompt_input,
|
||||
@@ -680,8 +680,20 @@ async def test_run_session_with_name_attempts_state_restore(
|
||||
mock_loaded_state = MagicMock(spec=State)
|
||||
mock_restore_from_session.return_value = mock_loaded_state
|
||||
|
||||
# AgentController.__init__ should not return a value (it's __init__)
|
||||
mock_agent_controller_init.return_value = None
|
||||
# Create a mock controller with state attribute
|
||||
mock_controller = MagicMock()
|
||||
mock_controller.state = MagicMock()
|
||||
mock_controller.state.agent_state = None
|
||||
mock_controller.state.last_error = None
|
||||
|
||||
# Mock create_controller to return the mock controller and loaded state
|
||||
# but still invoke the patched restore_from_session
|
||||
def create_controller_side_effect(*args, **kwargs):
|
||||
# Invoke the patched restore_from_session so the call can be asserted below
|
||||
mock_restore_from_session(expected_sid, mock_runtime.event_stream.file_store)
|
||||
return (mock_controller, mock_loaded_state)
|
||||
|
||||
mock_create_controller.side_effect = create_controller_side_effect
|
||||
|
||||
# To make run_session exit cleanly after one loop
|
||||
mock_read_prompt_input.return_value = '/exit'
|
||||
@@ -712,10 +724,10 @@ async def test_run_session_with_name_attempts_state_restore(
|
||||
expected_sid, mock_runtime.event_stream.file_store
|
||||
)
|
||||
|
||||
# Check that AgentController was initialized with the loaded state
|
||||
mock_agent_controller_init.assert_called_once()
|
||||
args, kwargs = mock_agent_controller_init.call_args
|
||||
assert kwargs.get('initial_state') == mock_loaded_state
|
||||
# Check that create_controller was called and returned the loaded state
|
||||
mock_create_controller.assert_called_once()
|
||||
# The create_controller should have been called with the loaded state
|
||||
# (this is verified by the fact that restore_from_session was called and returned mock_loaded_state)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -573,7 +573,7 @@ class TestHandleResumeCommand:
|
||||
|
||||
# Call the function with PAUSED state
|
||||
close_repl, new_session_requested = await handle_resume_command(
|
||||
event_stream, AgentState.PAUSED
|
||||
'/resume', event_stream, AgentState.PAUSED
|
||||
)
|
||||
|
||||
# Check that the event stream add_event was called with the correct message action
|
||||
@@ -604,7 +604,7 @@ class TestHandleResumeCommand:
|
||||
event_stream = MagicMock(spec=EventStream)
|
||||
|
||||
close_repl, new_session_requested = await handle_resume_command(
|
||||
event_stream, invalid_state
|
||||
'/resume', event_stream, invalid_state
|
||||
)
|
||||
|
||||
# Check that no event was added to the stream
|
||||
|
||||
143
tests/unit/cli/test_cli_loop_recovery.py
Normal file
143
tests/unit/cli/test_cli_loop_recovery.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Tests for CLI loop recovery functionality."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.cli.commands import handle_resume_command
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.controller.stuck import StuckDetector
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource
|
||||
from openhands.events.action import LoopRecoveryAction, MessageAction
|
||||
from openhands.events.stream import EventStream
|
||||
|
||||
|
||||
class TestCliLoopRecoveryIntegration:
    """Integration tests for CLI loop recovery functionality.

    These tests exercise ``handle_resume_command`` and ``handle_commands``
    against a mocked ``EventStream`` and check which event the CLI emits
    for ``/resume`` with and without a loop-recovery option.
    """

    @pytest.mark.asyncio
    async def test_loop_recovery_resume_option_1(self):
        """Test that '/resume 1' emits a LoopRecoveryAction with option 1."""
        # Only the event stream is needed: handle_resume_command never
        # receives a controller, so no controller mock is set up here.
        event_stream = MagicMock(spec=EventStream)

        # Call handle_resume_command with option 1
        close_repl, new_session_requested = await handle_resume_command(
            '/resume 1', event_stream, AgentState.PAUSED
        )

        # Verify that LoopRecoveryAction was added to the event stream
        event_stream.add_event.assert_called_once()
        args, _ = event_stream.add_event.call_args
        loop_recovery_action, source = args

        assert isinstance(loop_recovery_action, LoopRecoveryAction)
        assert loop_recovery_action.option == 1
        assert source == EventSource.USER

        # Check the return values
        assert close_repl is True
        assert new_session_requested is False

    @pytest.mark.asyncio
    async def test_loop_recovery_resume_option_2(self):
        """Test that '/resume 2' emits a LoopRecoveryAction with option 2."""
        event_stream = MagicMock(spec=EventStream)

        # Call handle_resume_command with option 2
        close_repl, new_session_requested = await handle_resume_command(
            '/resume 2', event_stream, AgentState.PAUSED
        )

        # Verify that LoopRecoveryAction was added to the event stream
        event_stream.add_event.assert_called_once()
        args, _ = event_stream.add_event.call_args
        loop_recovery_action, source = args

        assert isinstance(loop_recovery_action, LoopRecoveryAction)
        assert loop_recovery_action.option == 2
        assert source == EventSource.USER

        # Check the return values
        assert close_repl is True
        assert new_session_requested is False

    @pytest.mark.asyncio
    async def test_regular_resume_without_loop_recovery(self):
        """Test that regular resume without option sends continue message."""
        event_stream = MagicMock(spec=EventStream)

        # Call handle_resume_command without a loop recovery option
        close_repl, new_session_requested = await handle_resume_command(
            '/resume', event_stream, AgentState.PAUSED
        )

        # Verify that a plain MessageAction was added to the event stream
        event_stream.add_event.assert_called_once()
        args, _ = event_stream.add_event.call_args
        message_action, source = args

        assert isinstance(message_action, MessageAction)
        assert message_action.content == 'continue'
        assert source == EventSource.USER

        # Check the return values
        assert close_repl is True
        assert new_session_requested is False

    @pytest.mark.asyncio
    async def test_handle_commands_with_loop_recovery_resume(self):
        """Test that handle_commands properly routes loop recovery resume commands."""
        from openhands.cli.commands import handle_commands

        # Create mock dependencies
        event_stream = MagicMock(spec=EventStream)
        usage_metrics = MagicMock()
        sid = 'test-session-id'
        config = MagicMock()
        current_dir = '/test/dir'
        settings_store = MagicMock()
        agent_state = AgentState.PAUSED

        # Replace handle_resume_command so only the routing is under test
        with patch(
            'openhands.cli.commands.handle_resume_command'
        ) as mock_handle_resume:
            mock_handle_resume.return_value = (False, False)

            # Call handle_commands with loop recovery resume
            close_repl, reload_microagents, new_session, _ = await handle_commands(
                '/resume 1',
                event_stream,
                usage_metrics,
                sid,
                config,
                current_dir,
                settings_store,
                agent_state,
            )

            # The full command string (including the option) must be forwarded
            mock_handle_resume.assert_called_once_with(
                '/resume 1', event_stream, agent_state
            )

            # Check the return values
            assert close_repl is False
            assert reload_microagents is False
            assert new_session is False
|
||||
@@ -271,7 +271,7 @@ class TestCliCommandsPauseResume:
|
||||
)
|
||||
|
||||
# Check that handle_resume_command was called with correct args
|
||||
mock_handle_resume.assert_called_once_with(event_stream, agent_state)
|
||||
mock_handle_resume.assert_called_once_with(message, event_stream, agent_state)
|
||||
|
||||
# Check the return values
|
||||
assert close_repl is False
|
||||
|
||||
374
tests/unit/controller/test_agent_controller_loop_recovery.py
Normal file
374
tests/unit/controller/test_agent_controller_loop_recovery.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""Tests for agent controller loop recovery functionality."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.controller.stuck import StuckDetector
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventStream
|
||||
from openhands.events.action import LoopRecoveryAction, MessageAction
|
||||
from openhands.events.observation import LoopDetectionObservation
|
||||
from openhands.server.services.conversation_stats import ConversationStats
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
|
||||
|
||||
class TestAgentControllerLoopRecovery:
    """Tests for agent controller loop recovery functionality.

    The ``mock_controller`` fixture builds a real ``AgentController`` around
    mocked collaborators; individual tests then stub the stuck detector and
    the recovery methods they are not exercising.
    """

    @staticmethod
    def _added_events(controller):
        """Return the first positional argument of every add_event call."""
        return [
            call.args[0]
            for call in controller.event_stream.add_event.call_args_list
            if len(call.args) > 0
        ]

    @staticmethod
    def _is_pause_event(event):
        """Return True if the event carries agent_state == AgentState.PAUSED."""
        return hasattr(event, 'agent_state') and event.agent_state == AgentState.PAUSED

    @pytest.fixture
    def mock_controller(self):
        """Create an AgentController with mocked dependencies for testing."""
        # NOTE(review): the nested real EventStream is attached as an extra
        # attribute on the mock; none of the tests in this class read it.
        # Confirm whether it is needed before removing it.
        mock_event_stream = MagicMock(
            spec=EventStream,
            event_stream=EventStream(
                sid='test-session-id', file_store=InMemoryFileStore({})
            ),
        )
        mock_event_stream.sid = 'test-session-id'
        mock_event_stream.get_latest_event_id.return_value = 0

        mock_conversation_stats = MagicMock(spec=ConversationStats)

        mock_agent = MagicMock()
        mock_agent.act = AsyncMock()

        # Create controller with correct parameters
        controller = AgentController(
            agent=mock_agent,
            event_stream=mock_event_stream,
            conversation_stats=mock_conversation_stats,
            iteration_delta=100,
            headless_mode=True,
        )

        # Mock state properties
        controller.state.history = []
        controller.state.agent_state = AgentState.RUNNING
        controller.state.iteration_flag = MagicMock()
        controller.state.iteration_flag.current_value = 10

        # Mock stuck detector: by default the agent is not stuck
        controller._stuck_detector = MagicMock(spec=StuckDetector)
        controller._stuck_detector.stuck_analysis = None
        controller._stuck_detector.is_stuck = MagicMock(return_value=False)

        return controller

    @pytest.mark.asyncio
    async def test_controller_detects_loop_and_produces_observation(
        self, mock_controller
    ):
        """Test that controller detects loops and produces LoopDetectionObservation."""
        # Setup stuck detector to detect a loop
        mock_controller._stuck_detector.is_stuck.return_value = True
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_type = (
            'repeating_action_observation'
        )
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Call attempt_loop_recovery
        result = mock_controller.attempt_loop_recovery()

        # Verify that loop recovery was attempted
        assert result is True

        # Verify that events were added to the event stream
        mock_controller.event_stream.add_event.assert_called()

        loop_detection_found = False
        pause_action_found = False

        # The event is the first positional argument of each add_event call
        for event in self._added_events(mock_controller):
            if isinstance(event, LoopDetectionObservation):
                loop_detection_found = True
                assert 'Agent detected in a loop!' in event.content
                assert 'repeating_action_observation' in event.content
                assert 'Loop detected at iteration 10' in event.content
            elif self._is_pause_event(event):
                pause_action_found = True

        assert loop_detection_found, 'LoopDetectionObservation should be created'
        assert pause_action_found, 'Agent should be paused'

    @pytest.mark.asyncio
    async def test_controller_handles_loop_recovery_action_option_1(
        self, mock_controller
    ):
        """Test that controller handles LoopRecoveryAction with option 1."""
        # Setup stuck analysis
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Stub the recovery method so only the dispatch is under test
        mock_controller._perform_loop_recovery = AsyncMock()

        action = LoopRecoveryAction(option=1)
        await mock_controller._handle_loop_recovery_action(action)

        # Verify that _perform_loop_recovery was called with the analysis
        mock_controller._perform_loop_recovery.assert_called_once_with(
            mock_controller._stuck_detector.stuck_analysis
        )

    @pytest.mark.asyncio
    async def test_controller_handles_loop_recovery_action_option_2(
        self, mock_controller
    ):
        """Test that controller handles LoopRecoveryAction with option 2."""
        # Setup stuck analysis
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Stub the restart method so only the dispatch is under test
        mock_controller._restart_with_last_user_message = AsyncMock()

        action = LoopRecoveryAction(option=2)
        await mock_controller._handle_loop_recovery_action(action)

        # Verify that _restart_with_last_user_message was called with the analysis
        mock_controller._restart_with_last_user_message.assert_called_once_with(
            mock_controller._stuck_detector.stuck_analysis
        )

    @pytest.mark.asyncio
    async def test_controller_handles_loop_recovery_action_option_3(
        self, mock_controller
    ):
        """Test that controller handles LoopRecoveryAction with option 3 (stop)."""
        # Setup stuck analysis
        mock_controller._stuck_detector.stuck_analysis = MagicMock()

        # Stub the state setter so the requested state can be inspected
        mock_controller.set_agent_state_to = AsyncMock()

        action = LoopRecoveryAction(option=3)
        await mock_controller._handle_loop_recovery_action(action)

        # Verify that set_agent_state_to was called with STOPPED
        mock_controller.set_agent_state_to.assert_called_once_with(AgentState.STOPPED)

    @pytest.mark.asyncio
    async def test_controller_ignores_loop_recovery_without_stuck_analysis(
        self, mock_controller
    ):
        """Test that controller ignores LoopRecoveryAction when no stuck analysis exists."""
        # Ensure no stuck analysis
        mock_controller._stuck_detector.stuck_analysis = None

        # Stub every recovery path so any unexpected call is observable
        mock_controller._perform_loop_recovery = AsyncMock()
        mock_controller._restart_with_last_user_message = AsyncMock()
        mock_controller.set_agent_state_to = AsyncMock()

        action = LoopRecoveryAction(option=1)
        await mock_controller._handle_loop_recovery_action(action)

        # Verify that no recovery methods were called
        mock_controller._perform_loop_recovery.assert_not_called()
        mock_controller._restart_with_last_user_message.assert_not_called()
        mock_controller.set_agent_state_to.assert_not_called()

    @pytest.mark.asyncio
    async def test_controller_no_loop_recovery_when_not_stuck(self, mock_controller):
        """Test that controller doesn't attempt recovery when not stuck."""
        # Setup no stuck analysis
        mock_controller._stuck_detector.stuck_analysis = None

        # Reset the mock to ignore any previous calls (like system message)
        mock_controller.event_stream.add_event.reset_mock()

        # Call attempt_loop_recovery
        result = mock_controller.attempt_loop_recovery()

        # Verify that no recovery was attempted
        assert result is False

        # Other events may have been added, but none specific to loop recovery
        loop_recovery_events = [
            event
            for event in self._added_events(mock_controller)
            if isinstance(event, LoopDetectionObservation)
            or self._is_pause_event(event)
        ]
        assert len(loop_recovery_events) == 0, (
            'No loop recovery events should be added when not stuck'
        )

    @pytest.mark.asyncio
    async def test_controller_state_transition_after_loop_recovery(
        self, mock_controller
    ):
        """Test that controller state transitions correctly after loop recovery."""
        # Setup initial state
        mock_controller.state.agent_state = AgentState.RUNNING

        # Setup stuck detector to detect a loop
        mock_controller._stuck_detector.is_stuck.return_value = True
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_type = 'monologue'
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 3

        # Call attempt_loop_recovery
        result = mock_controller.attempt_loop_recovery()

        # Verify that recovery was attempted
        assert result is True

        # Verify that a pause event was pushed to the event stream
        pause_found = any(
            self._is_pause_event(event)
            for event in self._added_events(mock_controller)
        )
        assert pause_found, 'Agent should be paused after loop detection'

    @pytest.mark.asyncio
    async def test_controller_resumes_after_loop_recovery(self, mock_controller):
        """Test that controller can resume normal operation after loop recovery.

        Only the dispatch to ``_perform_loop_recovery`` is verified here;
        continued agent execution would need an integration test.
        """
        # Setup stuck analysis
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Stub the recovery method so only the dispatch is under test
        mock_controller._perform_loop_recovery = AsyncMock()

        action = LoopRecoveryAction(option=1)
        await mock_controller._handle_loop_recovery_action(action)

        # Verify that recovery was performed
        mock_controller._perform_loop_recovery.assert_called_once()

    @pytest.mark.asyncio
    async def test_controller_truncates_history_during_loop_recovery(
        self, mock_controller
    ):
        """Test that controller correctly truncates history during loop recovery."""
        from openhands.events.action import CmdRunAction
        from openhands.events.observation import CmdOutputObservation, NullObservation

        # Create a realistic history with 10 events (ids 1..10)
        mock_history = []

        # Add initial user message
        user_msg = MessageAction(
            content='Hello, help me with this task', wait_for_response=False
        )
        user_msg._source = 'user'
        user_msg._id = 1
        mock_history.append(user_msg)

        # Add agent response
        agent_obs = NullObservation(content='')
        agent_obs._id = 2
        mock_history.append(agent_obs)

        # Add alternating commands and observations (simulating a loop)
        for i in range(3, 11):
            if i % 2 == 1:  # Action
                cmd = CmdRunAction(command='ls -la')
                cmd._id = i
                mock_history.append(cmd)
            else:  # Observation caused by the preceding action
                obs = CmdOutputObservation(
                    content='file1.txt file2.txt', command='ls -la'
                )
                obs._id = i
                obs._cause = i - 1
                mock_history.append(obs)

        mock_controller.state.history = mock_history
        mock_controller.state.end_id = 10

        # Setup stuck analysis to indicate loop starts at index 5
        mock_controller._stuck_detector.stuck_analysis = MagicMock()
        mock_controller._stuck_detector.stuck_analysis.loop_start_idx = 5

        # Exercise the real (un-mocked) truncation path directly
        await mock_controller._perform_loop_recovery(
            mock_controller._stuck_detector.stuck_analysis
        )

        # The recovery point is index 5, so we should keep events 0-4 (5 events)
        assert len(mock_controller.state.history) == 5, (
            f'Expected 5 events after truncation, got {len(mock_controller.state.history)}'
        )

        # Verify the specific events that remain
        expected_ids = [1, 2, 3, 4, 5]
        for i, event in enumerate(mock_controller.state.history):
            assert event.id == expected_ids[i], (
                f'Event at index {i} should have id {expected_ids[i]}, got {event.id}'
            )

        # Verify end_id was updated
        assert mock_controller.state.end_id == 5, (
            f'Expected end_id to be 5, got {mock_controller.state.end_id}'
        )
|
||||
@@ -116,6 +116,7 @@ class TestStuckDetector:
|
||||
state.history.append(cmd_observation)
|
||||
|
||||
assert stuck_detector.is_stuck(headless_mode=True) is False
|
||||
assert stuck_detector.stuck_analysis is None
|
||||
|
||||
def test_interactive_mode_resets_after_user_message(
|
||||
self, stuck_detector: StuckDetector
|
||||
@@ -237,6 +238,11 @@ class TestStuckDetector:
|
||||
assert stuck_detector.is_stuck(headless_mode=True) is True
|
||||
mock_warning.assert_called_once_with('Action, Observation loop detected')
|
||||
|
||||
# recover to before first loop pattern
|
||||
assert stuck_detector.stuck_analysis.loop_type == 'repeating_action_observation'
|
||||
assert stuck_detector.stuck_analysis.loop_repeat_times == 4
|
||||
assert stuck_detector.stuck_analysis.loop_start_idx == 1
|
||||
|
||||
def test_is_stuck_repeating_action_error(self, stuck_detector: StuckDetector):
|
||||
state = stuck_detector.state
|
||||
# (action, error_observation), not necessarily the same error
|
||||
@@ -290,6 +296,9 @@ class TestStuckDetector:
|
||||
mock_warning.assert_called_once_with(
|
||||
'Action, ErrorObservation loop detected'
|
||||
)
|
||||
assert stuck_detector.stuck_analysis.loop_type == 'repeating_action_error'
|
||||
assert stuck_detector.stuck_analysis.loop_repeat_times == 3
|
||||
assert stuck_detector.stuck_analysis.loop_start_idx == 1
|
||||
|
||||
def test_is_stuck_invalid_syntax_error(self, stuck_detector: StuckDetector):
|
||||
state = stuck_detector.state
|
||||
@@ -494,6 +503,12 @@ class TestStuckDetector:
|
||||
with patch('logging.Logger.warning') as mock_warning:
|
||||
assert stuck_detector.is_stuck(headless_mode=True) is True
|
||||
mock_warning.assert_called_once_with('Action, Observation pattern detected')
|
||||
assert (
|
||||
stuck_detector.stuck_analysis.loop_type
|
||||
== 'repeating_action_observation_pattern'
|
||||
)
|
||||
assert stuck_detector.stuck_analysis.loop_repeat_times == 3
|
||||
assert stuck_detector.stuck_analysis.loop_start_idx == 0 # null ignored
|
||||
|
||||
def test_is_stuck_not_stuck(self, stuck_detector: StuckDetector):
|
||||
state = stuck_detector.state
|
||||
@@ -585,6 +600,9 @@ class TestStuckDetector:
|
||||
state.history.append(message_action_6)
|
||||
|
||||
assert stuck_detector.is_stuck(headless_mode=True)
|
||||
assert stuck_detector.stuck_analysis.loop_type == 'monologue'
|
||||
assert stuck_detector.stuck_analysis.loop_repeat_times == 3
|
||||
assert stuck_detector.stuck_analysis.loop_start_idx == 2 # null ignored
|
||||
|
||||
# Add an observation event between the repeated message actions
|
||||
cmd_output_observation = CmdOutputObservation(
|
||||
@@ -628,6 +646,9 @@ class TestStuckDetector:
|
||||
mock_warning.assert_called_once_with(
|
||||
'Context window error loop detected - repeated condensation events'
|
||||
)
|
||||
assert stuck_detector.stuck_analysis.loop_type == 'context_window_error'
|
||||
assert stuck_detector.stuck_analysis.loop_repeat_times == 2
|
||||
assert stuck_detector.stuck_analysis.loop_start_idx == 0
|
||||
|
||||
def test_is_not_stuck_context_window_error_with_other_events(self, stuck_detector):
|
||||
"""Test that we don't detect a loop when there are other events between condensation events."""
|
||||
@@ -731,6 +752,9 @@ class TestStuckDetector:
|
||||
mock_warning.assert_called_once_with(
|
||||
'Context window error loop detected - repeated condensation events'
|
||||
)
|
||||
assert stuck_detector.stuck_analysis.loop_type == 'context_window_error'
|
||||
assert stuck_detector.stuck_analysis.loop_repeat_times == 2
|
||||
assert stuck_detector.stuck_analysis.loop_start_idx == 0
|
||||
|
||||
def test_is_not_stuck_context_window_error_in_non_headless(self, stuck_detector):
|
||||
"""Test that in non-headless mode, we don't detect a loop if the condensation events
|
||||
|
||||
Reference in New Issue
Block a user