Compare commits

...

2 Commits

Author SHA1 Message Date
openhands 1e739bddb7 fix: improve stuck detection in UI mode
- Add UI mode awareness to stuck detection
- Only consider history after last user message in UI mode
- Keep existing behavior in headless mode
- Add comprehensive tests for both modes

Fix: #5480
2024-12-14 16:46:07 +00:00
openhands 212787c072 refactor: remove unused almost_stuck functionality
- Remove almost_stuck field from State class
- Remove almost_stuck counter from StuckDetector
- Simplify stuck detection logic to focus on actual loop detection
- Update tests to remove almost_stuck assertions
2024-12-14 16:01:02 +00:00
4 changed files with 95 additions and 81 deletions
+7 -4
View File
@@ -319,7 +319,7 @@ class AgentController:
def _reset(self) -> None:
"""Resets the agent controller"""
self.almost_stuck = 0
self._pending_action = None
self.agent.reset()
@@ -902,17 +902,20 @@ class AgentController:
return kept_events
def _is_stuck(self) -> bool:
def _is_stuck(self, ui_mode: bool | None = None) -> bool:
"""Checks if the agent or its delegate is stuck in a loop.
Args:
ui_mode: Optional override for UI mode. If not provided, uses not self.headless_mode.
Returns:
bool: True if the agent is stuck, False otherwise.
"""
# check if delegate stuck
if self.delegate and self.delegate._is_stuck():
if self.delegate and self.delegate._is_stuck(ui_mode):
return True
return self._stuck_detector.is_stuck()
return self._stuck_detector.is_stuck(ui_mode if ui_mode is not None else not self.headless_mode)
def __repr__(self):
return (
+1 -1
View File
@@ -94,7 +94,7 @@ class State:
end_id: int = -1
# truncation_id tracks where to load history after context window truncation
truncation_id: int = -1
almost_stuck: int = 0
delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict)
# NOTE: This will never be used by the controller, but it can be used by different
# evaluation tasks to store extra data needed to track the progress/state of the task.
+22 -36
View File
@@ -24,16 +24,26 @@ class StuckDetector:
def __init__(self, state: State):
self.state = state
def is_stuck(self):
# filter out MessageAction with source='user' from history
def is_stuck(self, ui_mode: bool = False):
if ui_mode:
# In UI mode, only look at history after the last user message
last_user_msg_idx = -1
for i, event in enumerate(self.state.history):
if isinstance(event, MessageAction) and event.source == EventSource.USER:
last_user_msg_idx = i
history_to_check = self.state.history[last_user_msg_idx + 1:]
else:
# In headless mode, look at all history
history_to_check = self.state.history
# Filter out user messages and null events
filtered_history = [
event
for event in self.state.history
for event in history_to_check
if not (
(isinstance(event, MessageAction) and event.source == EventSource.USER)
or
# there might be some NullAction or NullObservation in the history at least for now
isinstance(event, (NullAction, NullObservation))
or isinstance(event, (NullAction, NullObservation))
)
]
@@ -81,43 +91,19 @@ class StuckDetector:
# it takes 4 actions and 4 observations to detect a loop
# assert len(last_actions) == 4 and len(last_observations) == 4
# reset almost_stuck reminder
self.state.almost_stuck = 0
# almost stuck? if two actions, obs are the same, we're almost stuck
if len(last_actions) >= 2 and len(last_observations) >= 2:
# Check for a loop of 4 identical action-observation pairs
if len(last_actions) == 4 and len(last_observations) == 4:
actions_equal = all(
self._eq_no_pid(last_actions[0], action) for action in last_actions[:2]
self._eq_no_pid(last_actions[0], action) for action in last_actions
)
observations_equal = all(
self._eq_no_pid(last_observations[0], observation)
for observation in last_observations[:2]
for observation in last_observations
)
# the last two actions and obs are the same?
if actions_equal and observations_equal:
self.state.almost_stuck = 2
# the last three actions and observations are the same?
if len(last_actions) >= 3 and len(last_observations) >= 3:
if (
actions_equal
and observations_equal
and self._eq_no_pid(last_actions[0], last_actions[2])
and self._eq_no_pid(last_observations[0], last_observations[2])
):
self.state.almost_stuck = 1
if len(last_actions) == 4 and len(last_observations) == 4:
if (
actions_equal
and observations_equal
and self._eq_no_pid(last_actions[0], last_actions[3])
and self._eq_no_pid(last_observations[0], last_observations[3])
):
logger.warning('Action, Observation loop detected')
self.state.almost_stuck = 0
return True
logger.warning('Action, Observation loop detected')
return True
return False
+65 -40
View File
@@ -112,7 +112,53 @@ class TestStuckDetector:
# cmd_observation._cause = cmd_action._id
state.history.append(cmd_observation)
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False
assert stuck_detector.is_stuck(ui_mode=True) is False
def test_ui_mode_resets_after_user_message(self, stuck_detector: StuckDetector):
state = stuck_detector.state
# First add some actions that would be stuck in non-UI mode
for i in range(4):
cmd_action = CmdRunAction(command='ls')
cmd_action._id = i
state.history.append(cmd_action)
cmd_observation = CmdOutputObservation(content='', command='ls', command_id=i)
cmd_observation._cause = cmd_action._id
state.history.append(cmd_observation)
# In non-UI mode, this should be stuck
assert stuck_detector.is_stuck(ui_mode=False) is True
# Add a user message
message_action = MessageAction(content='Hello', wait_for_response=False)
message_action._source = EventSource.USER
state.history.append(message_action)
# In UI mode, this should not be stuck because we ignore history before user message
assert stuck_detector.is_stuck(ui_mode=True) is False
# Add two more identical actions - still not stuck because we need at least 3
for i in range(2):
cmd_action = CmdRunAction(command='ls')
cmd_action._id = i + 4
state.history.append(cmd_action)
cmd_observation = CmdOutputObservation(content='', command='ls', command_id=i + 4)
cmd_observation._cause = cmd_action._id
state.history.append(cmd_observation)
assert stuck_detector.is_stuck(ui_mode=True) is False
# Add two more identical actions - now it should be stuck
for i in range(2):
cmd_action = CmdRunAction(command='ls')
cmd_action._id = i + 6
state.history.append(cmd_action)
cmd_observation = CmdOutputObservation(content='', command='ls', command_id=i + 6)
cmd_observation._cause = cmd_action._id
state.history.append(cmd_observation)
assert stuck_detector.is_stuck(ui_mode=True) is True
def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetector):
state = stuck_detector.state
@@ -148,8 +194,7 @@ class TestStuckDetector:
state.history.append(message_null_observation)
# 8 events
assert stuck_detector.is_stuck() is False
assert stuck_detector.state.almost_stuck == 2
assert stuck_detector.is_stuck(ui_mode=False) is False
cmd_action_3 = CmdRunAction(command='ls')
cmd_action_3._id = 3
@@ -160,20 +205,7 @@ class TestStuckDetector:
# 10 events
assert len(state.history) == 10
assert (
len(state.history) == 10
) # Adjusted since history is a list and the controller is not running
# FIXME are we still testing this without this test?
# assert (
# len(
# get_pairs_from_events(state.history)
# )
# == 5
# )
assert stuck_detector.is_stuck() is False
assert stuck_detector.state.almost_stuck == 1
assert stuck_detector.is_stuck(ui_mode=False) is False
cmd_action_4 = CmdRunAction(command='ls')
cmd_action_4._id = 4
@@ -184,16 +216,9 @@ class TestStuckDetector:
# 12 events
assert len(state.history) == 12
# assert (
# len(
# get_pairs_from_events(state.history)
# )
# == 6
# )
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.state.almost_stuck == 0
assert stuck_detector.is_stuck(ui_mode=False) is True
mock_warning.assert_called_once_with('Action, Observation loop detected')
def test_is_stuck_repeating_action_error(self, stuck_detector: StuckDetector):
@@ -245,7 +270,7 @@ class TestStuckDetector:
# 12 events
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True
mock_warning.assert_called_once_with(
'Action, ErrorObservation loop detected'
)
@@ -259,7 +284,7 @@ class TestStuckDetector:
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True
def test_is_not_stuck_invalid_syntax_error_random_lines(
self, stuck_detector: StuckDetector
@@ -272,7 +297,7 @@ class TestStuckDetector:
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False
def test_is_not_stuck_invalid_syntax_error_only_three_incidents(
self, stuck_detector: StuckDetector
@@ -286,7 +311,7 @@ class TestStuckDetector:
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False
def test_is_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
@@ -297,7 +322,7 @@ class TestStuckDetector:
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True
def test_is_not_stuck_incomplete_input_error(self, stuck_detector: StuckDetector):
state = stuck_detector.state
@@ -308,7 +333,7 @@ class TestStuckDetector:
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False
def test_is_not_stuck_ipython_unterminated_string_error_random_lines(
self, stuck_detector: StuckDetector
@@ -317,7 +342,7 @@ class TestStuckDetector:
self._impl_unterminated_string_error_events(state, random_line=True)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False
def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents(
self, stuck_detector: StuckDetector
@@ -328,7 +353,7 @@ class TestStuckDetector:
)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False
def test_is_stuck_ipython_unterminated_string_error(
self, stuck_detector: StuckDetector
@@ -337,7 +362,7 @@ class TestStuckDetector:
self._impl_unterminated_string_error_events(state, random_line=False)
with patch('logging.Logger.warning'):
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True
def test_is_not_stuck_ipython_syntax_error_not_at_end(
self, stuck_detector: StuckDetector
@@ -382,7 +407,7 @@ class TestStuckDetector:
state.history.append(ipython_observation_4)
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False
mock_warning.assert_not_called()
def test_is_stuck_repeating_action_observation_pattern(
@@ -451,7 +476,7 @@ class TestStuckDetector:
state.history.append(read_observation_3)
with patch('logging.Logger.warning') as mock_warning:
assert stuck_detector.is_stuck() is True
assert stuck_detector.is_stuck(ui_mode=False) is True
mock_warning.assert_called_once_with('Action, Observation pattern detected')
def test_is_stuck_not_stuck(self, stuck_detector: StuckDetector):
@@ -517,7 +542,7 @@ class TestStuckDetector:
# read_observation_3._cause = read_action_3._id
state.history.append(read_observation_3)
assert stuck_detector.is_stuck() is False
assert stuck_detector.is_stuck(ui_mode=False) is False
def test_is_stuck_monologue(self, stuck_detector):
state = stuck_detector.state
@@ -547,7 +572,7 @@ class TestStuckDetector:
message_action_6._source = EventSource.AGENT
state.history.append(message_action_6)
assert stuck_detector.is_stuck()
assert stuck_detector.is_stuck(ui_mode=False)
# Add an observation event between the repeated message actions
cmd_output_observation = CmdOutputObservation(
@@ -567,7 +592,7 @@ class TestStuckDetector:
state.history.append(message_action_8)
with patch('logging.Logger.warning'):
assert not stuck_detector.is_stuck()
assert not stuck_detector.is_stuck(ui_mode=False)
class TestAgentController:
@@ -584,4 +609,4 @@ class TestAgentController:
def test_is_stuck_delegate_stuck(self, controller: AgentController):
controller.delegate = Mock()
controller.delegate._is_stuck.return_value = True
assert controller._is_stuck() is True
assert controller._is_stuck(ui_mode=False) is True