Fix/update controller is_stuck() (#1891)

* Refactor monologue to use the messages in state history

Remove the now-unused method

* is_stuck update

* fix is_stuck

* unit tests

* fix tests

* Revert "Refactor monologue to use the messages in state history"

This reverts commit 76b4b765ef.

* Override eq for CmdOutputObservation to ignore the pid, compare the actual command only

* Revert "Override eq for CmdOutputObservation to ignore the pid, compare the actual command only"

This reverts commit 6418d856b5.
This commit is contained in:
Engel Nyst
2024-05-21 16:56:59 +02:00
committed by GitHub
parent 4add8a5595
commit 1e51bb9276
2 changed files with 256 additions and 21 deletions

View File

@@ -244,32 +244,43 @@ class AgentController:
# check if delegate stuck
if self.delegate and self.delegate._is_stuck():
return True
if len(self.state.history) < 3:
# filter out MessageAction with source='user' from history
filtered_history = [
_tuple
for _tuple in self.state.history
if not (
isinstance(_tuple[0], MessageAction)
and _tuple[0].source == EventSource.USER
)
]
if len(filtered_history) < 4:
return False
# if the last three (Action, Observation) tuples are too repetitive
# the agent got stuck in a loop
if all(
[
self.state.history[-i][0] == self.state.history[-3][0]
for i in range(1, 3)
]
):
# it repeats same action, give it a chance, but not if:
# Check if the last four (Action, Observation) tuples are too repetitive
last_four_tuples = filtered_history[-4:]
if all(last_four_tuples[-1] == _tuple for _tuple in last_four_tuples):
logger.warning('Action, Observation loop detected')
return True
if all(last_four_tuples[-1][0] == _tuple[0] for _tuple in last_four_tuples):
# It repeats the same action, give it a chance, but not if:
if all(
isinstance(self.state.history[-i][1], NullObservation)
for i in range(1, 4)
isinstance(_tuple[1], ErrorObservation) for _tuple in last_four_tuples
):
# same (Action, NullObservation): like 'think' the same thought over and over
logger.warning('Action, NullObservation loop detected')
return True
elif all(
isinstance(self.state.history[-i][1], ErrorObservation)
for i in range(1, 4)
):
# (NullAction, ErrorObservation): errors coming from an exception
# (Action, ErrorObservation): the same action getting an error, even if not necessarily the same error
logger.warning('Action, ErrorObservation loop detected')
return True
# check if the agent repeats the same (Action, Observation)
# every other step in the last six tuples
if len(filtered_history) >= 6:
last_six_tuples = filtered_history[-6:]
if (
last_six_tuples[-1] == last_six_tuples[-3] == last_six_tuples[-5]
and last_six_tuples[-2] == last_six_tuples[-4] == last_six_tuples[-6]
):
logger.warning('Action, Observation pattern detected')
return True
return False

224
tests/unit/test_is_stuck.py Normal file
View File

@@ -0,0 +1,224 @@
from unittest.mock import Mock, patch
import pytest
from opendevin.controller.agent_controller import AgentController
from opendevin.events.action import CmdRunAction, FileReadAction, MessageAction
from opendevin.events.observation import (
CmdOutputObservation,
FileReadObservation,
Observation,
)
from opendevin.events.observation.empty import NullObservation
from opendevin.events.observation.error import ErrorObservation
from opendevin.events.stream import EventSource
class TestAgentController:
    """Unit tests for the loop detection in ``AgentController._is_stuck``.

    The real ``_is_stuck`` is bound onto a ``Mock`` controller, so each test
    only needs to provide ``delegate`` and ``state.history``, where the
    history is a list of ``(action, observation)`` tuples.
    """

    # ------------------------------------------------------------------
    # Builders for the (action, observation) tuples used in the histories.
    # Each call returns freshly constructed objects.
    # ------------------------------------------------------------------

    @staticmethod
    def _user_message(content):
        """Return a MessageAction whose source is set to the user."""
        action = MessageAction(content=content, wait_for_response=False)
        action._source = EventSource.USER
        return action

    @staticmethod
    def _greeting():
        """Return the opening ('Hello', 'Response 1') tuple; source is not set."""
        return (
            MessageAction(content='Hello', wait_for_response=False),
            Observation(content='Response 1'),
        )

    @staticmethod
    def _ls_null():
        """Return an ``ls`` command paired with an empty NullObservation."""
        return (CmdRunAction(command='ls'), NullObservation(content=''))

    @staticmethod
    def _ls_listing():
        """Return an ``ls`` command paired with its two-file output."""
        return (
            CmdRunAction(command='ls'),
            CmdOutputObservation(
                command_id=1, command='ls', content='file1.txt\nfile2.txt'
            ),
        )

    @staticmethod
    def _pwd():
        """Return a ``pwd`` command paired with its output."""
        return (
            CmdRunAction(command='pwd'),
            CmdOutputObservation(command_id=1, command='pwd', content='/home/user'),
        )

    @staticmethod
    def _read_file1():
        """Return a read of file1.txt paired with a FileReadObservation."""
        return (
            FileReadAction(path='file1.txt'),
            FileReadObservation(content='File content', path='file1.txt'),
        )

    @staticmethod
    def _read_file2():
        """Return a read of file2.txt paired with a plain Observation."""
        return (
            FileReadAction(path='file2.txt'),
            Observation(content='Another file content'),
        )

    @staticmethod
    def _invalid_command(error_text):
        """Return the failing command paired with an ErrorObservation."""
        return (
            CmdRunAction(command='invalid_command'),
            ErrorObservation(content=error_text),
        )

    # ------------------------------------------------------------------
    # Fixture
    # ------------------------------------------------------------------

    @pytest.fixture
    def controller(self):
        """Mock controller carrying the real ``_is_stuck`` and an empty history."""
        mock_controller = Mock(spec=AgentController)
        # Bind the genuine implementation so the mock runs the code under test.
        mock_controller._is_stuck = AgentController._is_stuck.__get__(
            mock_controller, AgentController
        )
        mock_controller.delegate = None
        mock_controller.state = Mock()
        mock_controller.state.history = []
        return mock_controller

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_history_too_short(self, controller):
        """A two-entry history is reported as not stuck."""
        controller.state.history = [self._greeting(), self._ls_listing()]

        assert controller._is_stuck() is False

    def test_is_stuck_repeating_action_null_observation(self, controller):
        """Four equal (ls, NullObservation) tuples are reported as a loop."""
        # message actions with source USER are not considered in the stuck check
        user_turn = (self._user_message('Done'), NullObservation(content=''))
        controller.state.history = [
            self._greeting(),
            self._ls_null(),
            self._ls_null(),
            # user message shouldn't break detection
            user_turn,
            self._ls_null(),
            self._ls_null(),
        ]

        with patch('logging.Logger.warning') as warn_spy:
            assert controller._is_stuck() is True
            warn_spy.assert_called_once_with('Action, Observation loop detected')

    def test_is_stuck_repeating_action_error_observation(self, controller):
        """The same action erroring four times is a loop, even if one error differs."""
        user_turn = (self._user_message('Done'), NullObservation(content=''))
        controller.state.history = [
            self._greeting(),
            self._invalid_command('Command not found'),
            self._invalid_command('Command not found'),
            # user message shouldn't break detection
            user_turn,
            self._invalid_command('Different error'),
            self._invalid_command('Command not found'),
        ]

        with patch('logging.Logger.warning') as warn_spy:
            assert controller._is_stuck() is True
            warn_spy.assert_called_once_with(
                'Action, ErrorObservation loop detected'
            )

    def test_is_stuck_repeating_action_observation_pattern(self, controller):
        """An alternating (ls, read) pattern over six tuples is reported as stuck."""
        # six tuples of action, observation
        user_msg = self._user_message('Come on')
        controller.state.history = [
            (user_msg, Observation(content='')),
            self._ls_listing(),
            self._read_file1(),
            self._ls_listing(),
            self._read_file1(),
            # insert a message just because they can, it shouldn't break detection
            (user_msg, NullObservation(content='')),
            self._ls_listing(),
            self._read_file1(),
        ]

        with patch('logging.Logger.warning') as warn_spy:
            assert controller._is_stuck() is True
            warn_spy.assert_called_once_with('Action, Observation pattern detected')

    def test_is_stuck_not_stuck(self, controller):
        """A varied history with only a partial repetition is not stuck."""
        user_turn = (self._user_message('Done'), NullObservation(content=''))
        controller.state.history = [
            self._greeting(),
            self._ls_listing(),
            self._read_file1(),
            self._pwd(),
            self._read_file2(),
            # insert a message from the user
            user_turn,
            self._pwd(),
            self._read_file2(),
        ]

        assert controller._is_stuck() is False

    def test_is_stuck_four_identical_tuples(self, controller):
        """Four equal (ls, output) tuples are reported as a loop."""
        user_turn = (self._user_message('Done'), NullObservation(content=''))
        controller.state.history = [
            self._greeting(),
            self._ls_listing(),
            self._ls_listing(),
            # message from the user
            user_turn,
            self._ls_listing(),
            self._ls_listing(),
        ]

        with patch('logging.Logger.warning') as warn_spy:
            assert controller._is_stuck() is True
            warn_spy.assert_called_once_with('Action, Observation loop detected')

    def test_is_stuck_delegate_stuck(self, controller):
        """A delegate that reports being stuck makes the controller stuck too."""
        stuck_delegate = Mock()
        stuck_delegate._is_stuck.return_value = True
        controller.delegate = stuck_delegate

        assert controller._is_stuck() is True