From 7d331acffa83e9bb709299f26c408f3b361b9941 Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Wed, 14 Aug 2024 09:47:31 -0400 Subject: [PATCH] Handle error observations in codeact (#3383) * Handle error observations in codeact * Remove comments --- agenthub/codeact_agent/codeact_agent.py | 10 ++- opendevin/controller/agent_controller.py | 4 +- opendevin/events/action/agent.py | 11 ++- opendevin/events/observation/delegate.py | 8 +- tests/unit/test_codeact_agent.py | 97 ++++++++++++++++++++++++ 5 files changed, 125 insertions(+), 5 deletions(-) create mode 100644 tests/unit/test_codeact_agent.py diff --git a/agenthub/codeact_agent/codeact_agent.py b/agenthub/codeact_agent/codeact_agent.py index 56d1b5f5c9..559e53641f 100644 --- a/agenthub/codeact_agent/codeact_agent.py +++ b/agenthub/codeact_agent/codeact_agent.py @@ -22,6 +22,7 @@ from opendevin.events.observation import ( CmdOutputObservation, IPythonRunCellObservation, ) +from opendevin.events.observation.error import ErrorObservation from opendevin.events.observation.observation import Observation from opendevin.events.serialization.event import truncate_content from opendevin.llm.llm import LLM @@ -169,7 +170,14 @@ class CodeActAgent(Agent): str(obs.outputs), max_message_chars ) return Message(role='user', content=[TextContent(text=text)]) - return None + elif isinstance(obs, ErrorObservation): + text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars) + text += '\n[Error occurred in processing last action]' + return Message(role='user', content=[TextContent(text=text)]) + else: + # If an observation message is not returned, it will cause an error + # when the LLM tries to return the next message + raise ValueError(f'Unknown observation type: {type(obs)}') def reset(self) -> None: """Resets the CodeAct Agent.""" diff --git a/opendevin/controller/agent_controller.py b/opendevin/controller/agent_controller.py index 350ae49d5c..46bb209258 100644 --- a/opendevin/controller/agent_controller.py +++ b/opendevin/controller/agent_controller.py @@ -176,11 +176,11 @@ class AgentController: elif isinstance(event, ModifyTaskAction): self.state.root_task.set_subtask_state(event.task_id, event.state) elif isinstance(event, AgentFinishAction): - self.state.outputs = event.outputs # type: ignore[attr-defined] + self.state.outputs = event.outputs self.state.metrics.merge(self.state.local_metrics) await self.set_agent_state_to(AgentState.FINISHED) elif isinstance(event, AgentRejectAction): - self.state.outputs = event.outputs # type: ignore[attr-defined] + self.state.outputs = event.outputs self.state.metrics.merge(self.state.local_metrics) await self.set_agent_state_to(AgentState.REJECTED) elif isinstance(event, Observation): diff --git a/opendevin/events/action/agent.py b/opendevin/events/action/agent.py index c9d4ec6a3e..d57e02b038 100644 --- a/opendevin/events/action/agent.py +++ b/opendevin/events/action/agent.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Any from opendevin.core.schema import ActionType @@ -35,7 +36,15 @@ class AgentSummarizeAction(Action): @dataclass class AgentFinishAction(Action): - outputs: dict = field(default_factory=dict) + """An action where the agent finishes the task. + + Attributes: + outputs (dict): The outputs of the agent, for instance "content". + thought (str): The agent's explanation of its actions. + action (str): The action type, namely ActionType.FINISH. + """ + + outputs: dict[str, Any] = field(default_factory=dict) thought: str = '' action: str = ActionType.FINISH diff --git a/opendevin/events/observation/delegate.py b/opendevin/events/observation/delegate.py index fe3d9fc4b9..c50a0a37da 100644 --- a/opendevin/events/observation/delegate.py +++ b/opendevin/events/observation/delegate.py @@ -7,7 +7,13 @@ from .observation import Observation @dataclass class AgentDelegateObservation(Observation): - """This data class represents the result from delegating to another agent""" + """This data class represents the result from delegating to another agent. + + Attributes: + content (str): The content of the observation. + outputs (dict): The outputs of the delegated agent. + observation (str): The type of observation. + """ outputs: dict observation: str = ObservationType.DELEGATE diff --git a/tests/unit/test_codeact_agent.py b/tests/unit/test_codeact_agent.py new file mode 100644 index 0000000000..b152dc27f9 --- /dev/null +++ b/tests/unit/test_codeact_agent.py @@ -0,0 +1,97 @@ +from unittest.mock import Mock + +import pytest + +from agenthub.codeact_agent.codeact_agent import CodeActAgent +from opendevin.core.config import LLMConfig +from opendevin.core.message import TextContent +from opendevin.events.observation.commands import ( + CmdOutputObservation, + IPythonRunCellObservation, +) +from opendevin.events.observation.delegate import AgentDelegateObservation +from opendevin.events.observation.error import ErrorObservation +from opendevin.llm.llm import LLM + + +@pytest.fixture +def agent() -> CodeActAgent: + agent = CodeActAgent(llm=LLM(LLMConfig())) + agent.llm = Mock() + agent.llm.config = Mock() + agent.llm.config.max_message_chars = 100 + return agent + + +def test_cmd_output_observation_message(agent: CodeActAgent): + obs = CmdOutputObservation( + command='echo hello', content='Command output', command_id=1, exit_code=0 + ) + + result = agent.get_observation_message(obs) + + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'OBSERVATION:' in result.content[0].text + assert 'Command output' in result.content[0].text + assert 'Command 1 finished with exit code 0' in result.content[0].text + + +def test_ipython_run_cell_observation_message(agent: CodeActAgent): + obs = IPythonRunCellObservation( + code='plt.plot()', + content='IPython output\n![image]()', + ) + + result = agent.get_observation_message(obs) + + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'OBSERVATION:' in result.content[0].text + assert 'IPython output' in result.content[0].text + assert ( + '![image](data:image/png;base64, ...) already displayed to user' + in result.content[0].text + ) + assert 'ABC123' not in result.content[0].text + + +def test_agent_delegate_observation_message(agent: CodeActAgent): + obs = AgentDelegateObservation( + content='Content', outputs={'content': 'Delegated agent output'} + ) + + result = agent.get_observation_message(obs) + + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'OBSERVATION:' in result.content[0].text + assert 'Delegated agent output' in result.content[0].text + + +def test_error_observation_message(agent: CodeActAgent): + obs = ErrorObservation('Error message') + + result = agent.get_observation_message(obs) + + assert result is not None + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'OBSERVATION:' in result.content[0].text + assert 'Error message' in result.content[0].text + assert 'Error occurred in processing last action' in result.content[0].text + + +def test_unknown_observation_message(agent: CodeActAgent): + obs = Mock() + + with pytest.raises(ValueError) as excinfo: + agent.get_observation_message(obs) + assert 'Unknown observation type:' in str(excinfo.value)