Ignore pid for loop detection (Was: override eq...) (#2045)

* rewrite, implement pid ignore in the controller

* make the helper method private
This commit is contained in:
Engel Nyst
2024-05-26 19:27:12 +02:00
committed by GitHub
parent 2c0a2dbc61
commit 783fea62a0
2 changed files with 124 additions and 13 deletions

View File

@@ -24,6 +24,7 @@ from opendevin.events.action import (
ModifyTaskAction,
NullAction,
)
from opendevin.events.action.commands import CmdKillAction
from opendevin.events.event import Event
from opendevin.events.observation import (
AgentDelegateObservation,
@@ -271,13 +272,29 @@ class AgentController:
if len(filtered_history) < 4:
return False
# FIXME rewrite this to be more readable
# Check if the last four (Action, Observation) tuples are too repetitive
last_four_tuples = filtered_history[-4:]
if all(last_four_tuples[-1] == _tuple for _tuple in last_four_tuples):
if all(
# (Action, Observation) tuples
# compare the last action to the last four actions
self._eq_no_pid(last_four_tuples[-1][0], _tuple[0])
for _tuple in last_four_tuples
) and all(
# compare the last observation to the last four observations
self._eq_no_pid(last_four_tuples[-1][1], _tuple[1])
for _tuple in last_four_tuples
):
logger.warning('Action, Observation loop detected')
return True
if all(last_four_tuples[-1][0] == _tuple[0] for _tuple in last_four_tuples):
# (action, error) tuples
if all(
self._eq_no_pid(last_four_tuples[-1][0], _tuple[0])
for _tuple in last_four_tuples
):
# It repeats the same action, give it a chance, but not if:
if all(
isinstance(_tuple[1], ErrorObservation) for _tuple in last_four_tuples
@@ -287,13 +304,35 @@ class AgentController:
# check if the agent repeats the same (Action, Observation)
# every other step in the last six tuples
if len(filtered_history) >= 6:
last_six_tuples = filtered_history[-6:]
if (
last_six_tuples[-1] == last_six_tuples[-3] == last_six_tuples[-5]
and last_six_tuples[-2] == last_six_tuples[-4] == last_six_tuples[-6]
# this pattern is every other step, like:
# (action_1, obs_1), (action_2, obs_2), (action_1, obs_1), (action_2, obs_2),...
self._eq_no_pid(last_six_tuples[-1][0], last_six_tuples[-3][0])
and self._eq_no_pid(last_six_tuples[-1][0], last_six_tuples[-5][0])
and self._eq_no_pid(last_six_tuples[-2][0], last_six_tuples[-4][0])
and self._eq_no_pid(last_six_tuples[-2][0], last_six_tuples[-6][0])
and self._eq_no_pid(last_six_tuples[-1][1], last_six_tuples[-3][1])
and self._eq_no_pid(last_six_tuples[-1][1], last_six_tuples[-5][1])
and self._eq_no_pid(last_six_tuples[-2][1], last_six_tuples[-4][1])
and self._eq_no_pid(last_six_tuples[-2][1], last_six_tuples[-6][1])
):
logger.warning('Action, Observation pattern detected')
return True
return False
def _eq_no_pid(self, obj1, obj2):
if isinstance(obj1, CmdOutputObservation) and isinstance(
obj2, CmdOutputObservation
):
# for loop detection, ignore command_id, which is the pid
return obj1.command == obj2.command and obj1.exit_code == obj2.exit_code
elif isinstance(obj1, CmdKillAction) and isinstance(obj2, CmdKillAction):
# for loop detection, ignore command_id, which is the pid
return obj1.thought == obj2.thought
else:
# this is the default comparison
return obj1 == obj2

View File

@@ -4,6 +4,7 @@ import pytest
from opendevin.controller.agent_controller import AgentController
from opendevin.events.action import CmdRunAction, FileReadAction, MessageAction
from opendevin.events.action.commands import CmdKillAction
from opendevin.events.observation import (
CmdOutputObservation,
FileReadObservation,
@@ -21,6 +22,9 @@ class TestAgentController:
controller._is_stuck = AgentController._is_stuck.__get__(
controller, AgentController
)
controller._eq_no_pid = AgentController._eq_no_pid.__get__(
controller, AgentController
)
controller.delegate = None
controller.state = Mock()
controller.state.history = []
@@ -75,7 +79,7 @@ class TestAgentController:
),
(
CmdRunAction(command='invalid_command'),
ErrorObservation(content='Command not found'),
ErrorObservation(content='Command still not found or another error'),
),
# user message shouldn't break detection
(message_action, NullObservation(content='')),
@@ -115,8 +119,9 @@ class TestAgentController:
),
(
CmdRunAction(command='ls'),
# command_id is ignored for the eq check, it's a pid
CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt'
command_id=2, command='ls', content='file1.txt\nfile2.txt'
),
),
(
@@ -128,7 +133,7 @@ class TestAgentController:
(
CmdRunAction(command='ls'),
CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt'
command_id=3, command='ls', content='file1.txt\nfile2.txt'
),
),
(
@@ -160,7 +165,8 @@ class TestAgentController:
),
(
CmdRunAction(command='pwd'),
CmdOutputObservation(command_id=1, command='pwd', content='/home/user'),
# command_id is ignored for the eq check, it's the pid
CmdOutputObservation(command_id=2, command='pwd', content='/home/user'),
),
(
FileReadAction(path='file2.txt'),
@@ -170,7 +176,7 @@ class TestAgentController:
(message_action, NullObservation(content='')),
(
CmdRunAction(command='pwd'),
CmdOutputObservation(command_id=1, command='pwd', content='/home/user'),
CmdOutputObservation(command_id=3, command='pwd', content='/home/user'),
),
(
FileReadAction(path='file2.txt'),
@@ -195,22 +201,88 @@ class TestAgentController:
),
(
CmdRunAction(command='ls'),
# command_id is ignored for the eq check, it's just the pid
CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt'
command_id=2, command='ls', content='file1.txt\nfile2.txt'
),
),
# message from the user
# message from the user shouldn't interfere with the detection
(message_action, NullObservation(content='')),
(
CmdRunAction(command='ls'),
CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt'
command_id=3, command='ls', content='file1.txt\nfile2.txt'
),
),
(
CmdRunAction(command='ls'),
CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt'
command_id=4, command='ls', content='file1.txt\nfile2.txt'
),
),
]
with patch('logging.Logger.warning') as mock_warning:
assert controller._is_stuck() is True
mock_warning.assert_called_once_with('Action, Observation loop detected')
def test_is_stuck_four_tuples_cmd_kill_and_output(self, controller):
message_action = MessageAction(content='Done', wait_for_response=False)
message_action._source = EventSource.USER
controller.state.history = [
(
MessageAction(content='Hello', wait_for_response=False),
Observation(content='Response 1'),
),
(
CmdKillAction(
command_id=1,
thought='It looks like storybook is stuck, lets kill it',
),
CmdOutputObservation(
content='Background command storybook has been killed.',
command_id=1,
command='storybook',
exit_code=0,
),
),
(
# command_id is ignored for the eq check, it's the pid
CmdKillAction(
command_id=2,
thought='It looks like storybook is stuck, lets kill it',
),
# command_id here too
CmdOutputObservation(
content='Background command storybook has been killed.',
command_id=2,
command='storybook',
exit_code=0,
),
),
# message from the user, shouldn't be counted
(message_action, NullObservation(content='')),
(
CmdKillAction(
command_id=3,
thought='It looks like storybook is stuck, lets kill it',
),
CmdOutputObservation(
content='Background command storybook has been killed.',
command_id=3,
command='storybook',
exit_code=0,
),
),
(
CmdKillAction(
command_id=4,
thought='It looks like storybook is stuck, lets kill it',
),
CmdOutputObservation(
content='Background command storybook has been killed.',
command_id=4,
command='storybook',
exit_code=0,
),
),
]