Override eq for CmdOutputObservation to ignore the pid, compare the actual command only

This commit is contained in:
Engel Nyst
2024-05-18 22:40:55 +02:00
parent 6d1963d7ae
commit 6418d856b5
2 changed files with 19 additions and 7 deletions

View File

@@ -27,6 +27,15 @@ class CmdOutputObservation(Observation):
def __str__(self) -> str:
return f'**CmdOutputObservation (exit code={self.exit_code})**\n{self.content}'
def __eq__(self, other: object) -> bool:
"""
Compare two CmdOutputObservation objects ignoring the command_id/pid.
"""
if isinstance(other, CmdOutputObservation):
# for loop detection purpose, we care about running the same command, not the same pid
return self.command == other.command and self.exit_code == other.exit_code
return False
@dataclass
class IPythonRunCellObservation(Observation):

View File

@@ -115,8 +115,9 @@ class TestAgentController:
),
(
CmdRunAction(command='ls'),
# command_id is ignored for the eq check, it could be a pid
CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt'
command_id=2, command='ls', content='file1.txt\nfile2.txt'
),
),
(
@@ -128,7 +129,7 @@ class TestAgentController:
(
CmdRunAction(command='ls'),
CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt'
command_id=3, command='ls', content='file1.txt\nfile2.txt'
),
),
(
@@ -160,7 +161,8 @@ class TestAgentController:
),
(
CmdRunAction(command='pwd'),
CmdOutputObservation(command_id=1, command='pwd', content='/home/user'),
# command_id is ignored for the eq check, it could be a pid
CmdOutputObservation(command_id=2, command='pwd', content='/home/user'),
),
(
FileReadAction(path='file2.txt'),
@@ -170,7 +172,7 @@ class TestAgentController:
(message_action, NullObservation(content='')),
(
CmdRunAction(command='pwd'),
CmdOutputObservation(command_id=1, command='pwd', content='/home/user'),
CmdOutputObservation(command_id=3, command='pwd', content='/home/user'),
),
(
FileReadAction(path='file2.txt'),
@@ -195,8 +197,9 @@ class TestAgentController:
),
(
CmdRunAction(command='ls'),
# command_id is ignored for the eq check, it could be a pid
CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt'
command_id=2, command='ls', content='file1.txt\nfile2.txt'
),
),
# message from the user
@@ -204,13 +207,13 @@ class TestAgentController:
(
CmdRunAction(command='ls'),
CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt'
command_id=3, command='ls', content='file1.txt\nfile2.txt'
),
),
(
CmdRunAction(command='ls'),
CmdOutputObservation(
command_id=1, command='ls', content='file1.txt\nfile2.txt'
command_id=4, command='ls', content='file1.txt\nfile2.txt'
),
),
]