mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
28 Commits
fix/events
...
openhands-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7be21240ea | ||
|
|
d30c21a66d | ||
|
|
e6ebbe67f8 | ||
|
|
f6a4587327 | ||
|
|
4cea1fd343 | ||
|
|
de8753cef6 | ||
|
|
be8f0782fb | ||
|
|
ec784e1535 | ||
|
|
db80ee68d5 | ||
|
|
579a9edbb0 | ||
|
|
e82aa13fe1 | ||
|
|
7a752364b5 | ||
|
|
dfc8db8a90 | ||
|
|
c513bee4ff | ||
|
|
8f9a7a131c | ||
|
|
a765196e53 | ||
|
|
60aff59c3f | ||
|
|
86a844ebbb | ||
|
|
945e81d262 | ||
|
|
0fe1f58d97 | ||
|
|
028e715c77 | ||
|
|
f26710b7ae | ||
|
|
8a6d1344a9 | ||
|
|
1e17bbc228 | ||
|
|
ab3d004bc9 | ||
|
|
bd5637fd3d | ||
|
|
1f6be838a7 | ||
|
|
e42a3a6407 |
@@ -1,10 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
from collections import deque
|
||||
|
||||
from litellm import ModelResponse
|
||||
|
||||
import openhands
|
||||
import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.state.state import State
|
||||
@@ -23,6 +21,8 @@ from openhands.events.action import (
|
||||
FileReadAction,
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
PromptExtensionAction,
|
||||
SystemMessageAction,
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
@@ -44,7 +44,6 @@ from openhands.runtime.plugins import (
|
||||
JupyterRequirement,
|
||||
PluginRequirement,
|
||||
)
|
||||
from openhands.utils.prompt import PromptManager
|
||||
|
||||
|
||||
class CodeActAgent(Agent):
|
||||
@@ -99,20 +98,30 @@ class CodeActAgent(Agent):
|
||||
logger.debug(
|
||||
f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2, ensure_ascii=False).replace("\\n", "\n")}'
|
||||
)
|
||||
self.prompt_manager = PromptManager(
|
||||
microagent_dir=os.path.join(
|
||||
os.path.dirname(os.path.dirname(openhands.__file__)),
|
||||
'microagents',
|
||||
)
|
||||
if self.config.enable_prompt_extensions
|
||||
else None,
|
||||
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
|
||||
disabled_microagents=self.config.disabled_microagents,
|
||||
)
|
||||
|
||||
self.condenser = Condenser.from_config(self.config.condenser)
|
||||
logger.debug(f'Using condenser: {self.condenser}')
|
||||
|
||||
def get_system_message(self) -> str:
    """Get the system message for this agent.

    The message is read from ``prompts/system_message.txt``, resolved
    relative to the directory that contains this module. When that file is
    not present, a generic fallback message is returned instead.

    Returns:
        The system message as a string, with leading and trailing
        whitespace stripped, or the fallback message if the file is missing.
    """
    import os

    # Resolve <directory of this module>/prompts/system_message.txt.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    system_message_path = os.path.join(
        current_dir, 'prompts', 'system_message.txt'
    )

    # Open directly instead of exists()-then-open: the file could vanish
    # between the check and the read. The encoding is pinned so the prompt
    # decodes the same way regardless of the platform's locale.
    try:
        with open(system_message_path, 'r', encoding='utf-8') as f:
            return f.read().strip()
    except (FileNotFoundError, NotADirectoryError):
        # Same outcome the original produced when os.path.exists() was False.
        return "You are a helpful AI assistant."
|
||||
|
||||
|
||||
def get_action_message(
|
||||
self,
|
||||
action: Action,
|
||||
@@ -231,6 +240,22 @@ class CodeActAgent(Agent):
|
||||
content=content,
|
||||
)
|
||||
]
|
||||
elif isinstance(action, SystemMessageAction):
|
||||
return [
|
||||
Message(
|
||||
role='system',
|
||||
content=[TextContent(text=action.content)],
|
||||
)
|
||||
]
|
||||
elif isinstance(action, PromptExtensionAction):
|
||||
# For prompt extensions, we add them as assistant messages
|
||||
# This way they appear in the conversation history but don't interfere with the system message
|
||||
return [
|
||||
Message(
|
||||
role='assistant',
|
||||
content=[TextContent(text=action.content)],
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
def get_observation_message(
|
||||
@@ -427,20 +452,7 @@ class CodeActAgent(Agent):
|
||||
- Messages from the same role are combined to prevent consecutive same-role messages
|
||||
- For Anthropic models, specific messages are cached according to their documentation
|
||||
"""
|
||||
if not self.prompt_manager:
|
||||
raise Exception('Prompt Manager not instantiated.')
|
||||
|
||||
messages: list[Message] = [
|
||||
Message(
|
||||
role='system',
|
||||
content=[
|
||||
TextContent(
|
||||
text=self.prompt_manager.get_system_message(),
|
||||
cache_prompt=self.llm.is_caching_prompt_active(),
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
messages: list[Message] = []
|
||||
|
||||
pending_tool_call_action_messages: dict[str, Message] = {}
|
||||
tool_call_id_to_message: dict[str, Message] = {}
|
||||
@@ -448,7 +460,6 @@ class CodeActAgent(Agent):
|
||||
# Condense the events from the state.
|
||||
events = self.condenser.condensed_history(state)
|
||||
|
||||
is_first_message_handled = False
|
||||
for event in events:
|
||||
# create a regular message from an event
|
||||
if isinstance(event, Action):
|
||||
@@ -492,19 +503,6 @@ class CodeActAgent(Agent):
|
||||
|
||||
for msg in messages_to_add:
|
||||
if msg:
|
||||
if msg.role == 'user' and not is_first_message_handled:
|
||||
is_first_message_handled = True
|
||||
# compose the first user message with examples
|
||||
self.prompt_manager.add_examples_to_initial_message(msg)
|
||||
|
||||
# and/or repo/runtime info
|
||||
if self.config.enable_prompt_extensions:
|
||||
self.prompt_manager.add_info_to_initial_message(msg)
|
||||
|
||||
# enhance the user message with additional context based on keywords matched
|
||||
if msg.role == 'user':
|
||||
self.prompt_manager.enhance_message(msg)
|
||||
|
||||
messages.append(msg)
|
||||
|
||||
if self.llm.is_caching_prompt_active():
|
||||
|
||||
@@ -111,3 +111,22 @@ class Agent(ABC):
|
||||
if not bool(cls._registry):
|
||||
raise AgentNotRegisteredError()
|
||||
return list(cls._registry.keys())
|
||||
|
||||
def get_prompt_manager(
|
||||
self,
|
||||
microagent_dir: str | None = None,
|
||||
disabled_microagents: list[str] | None = None,
|
||||
) -> 'PromptManager | None':
|
||||
"""Get a prompt manager instance for this agent.
|
||||
|
||||
This method can be overridden by subclasses to return a specialized prompt manager.
|
||||
By default, it returns None.
|
||||
|
||||
Args:
|
||||
microagent_dir: Directory containing microagent prompts
|
||||
disabled_microagents: List of microagents to disable
|
||||
|
||||
Returns:
|
||||
A PromptManager instance, a subclass of PromptManager, or None
|
||||
"""
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Callable, ClassVar, Type
|
||||
|
||||
@@ -12,6 +13,7 @@ from litellm.exceptions import (
|
||||
RateLimitError,
|
||||
)
|
||||
|
||||
import openhands
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.controller.replay import ReplayManager
|
||||
from openhands.controller.state.state import State, TrafficControlState
|
||||
@@ -27,6 +29,7 @@ from openhands.core.exceptions import (
|
||||
)
|
||||
from openhands.core.logger import LOG_ALL_EVENTS
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.message import Message, TextContent
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import (
|
||||
@@ -40,6 +43,8 @@ from openhands.events.action import (
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
NullAction,
|
||||
PromptExtensionAction,
|
||||
SystemMessageAction,
|
||||
)
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
@@ -147,6 +152,18 @@ class AgentController:
|
||||
# replay-related
|
||||
self._replay_manager = ReplayManager(replay_events)
|
||||
|
||||
# First user message tracking
|
||||
self._first_user_message_received = False
|
||||
|
||||
# Send system message at initialization if the agent provides one
|
||||
if not self.is_delegate and hasattr(agent, 'get_system_message'):
|
||||
system_message = agent.get_system_message()
|
||||
if system_message:
|
||||
self.event_stream.add_event(
|
||||
SystemMessageAction(content=system_message),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
|
||||
|
||||
@@ -401,11 +418,58 @@ class AgentController:
|
||||
self.state.max_iterations = (
|
||||
self.state.iteration + self._initial_max_iterations
|
||||
)
|
||||
if (
|
||||
self.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
or self.state.traffic_control_state == TrafficControlState.PAUSED
|
||||
):
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
|
||||
# Add the message to the event stream
|
||||
self.event_stream.add_event(action, EventSource.USER)
|
||||
|
||||
# Handle first user message to trigger prompt extensions
|
||||
if (
|
||||
not self._first_user_message_received
|
||||
and self.prompt_manager
|
||||
and self.agent.config.enable_prompt_extensions
|
||||
):
|
||||
self._first_user_message_received = True
|
||||
# Create a Message object from the action content
|
||||
msg = Message(
|
||||
role='user',
|
||||
content=[TextContent(text=action.content)],
|
||||
)
|
||||
|
||||
# Add examples to initial message
|
||||
self.prompt_manager.add_examples_to_initial_message(msg)
|
||||
if msg.content[0].text != action.content:
|
||||
self.event_stream.add_event(
|
||||
PromptExtensionAction(
|
||||
content=msg.content[0].text, extension_type='examples'
|
||||
),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
|
||||
# Add info to initial message
|
||||
self.prompt_manager.add_info_to_initial_message(msg)
|
||||
if msg.content[0].text != action.content:
|
||||
self.event_stream.add_event(
|
||||
PromptExtensionAction(
|
||||
content=msg.content[0].text, extension_type='info'
|
||||
),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
|
||||
# Enhance message
|
||||
self.prompt_manager.enhance_message(msg)
|
||||
if msg.content[0].text != action.content:
|
||||
self.event_stream.add_event(
|
||||
PromptExtensionAction(
|
||||
content=msg.content[0].text, extension_type='enhance'
|
||||
),
|
||||
EventSource.AGENT,
|
||||
)
|
||||
|
||||
if (
|
||||
self.state.traffic_control_state == TrafficControlState.THROTTLING
|
||||
or self.state.traffic_control_state == TrafficControlState.PAUSED
|
||||
):
|
||||
self.state.traffic_control_state = TrafficControlState.NORMAL
|
||||
self.log(
|
||||
'debug',
|
||||
f'Extended max iterations to {self.state.max_iterations} after user message',
|
||||
@@ -574,7 +638,8 @@ class AgentController:
|
||||
delegate_state = self.delegate.get_agent_state()
|
||||
|
||||
# update iteration that is shared across agents
|
||||
self.state.iteration = self.delegate.state.iteration
|
||||
# Add 1 to account for the parent's step that initiated delegation
|
||||
self.state.iteration = self.delegate.state.iteration + 1
|
||||
|
||||
# close the delegate controller before adding new events
|
||||
asyncio.get_event_loop().run_until_complete(self.delegate.close())
|
||||
@@ -761,6 +826,10 @@ class AgentController:
|
||||
current_str = f'{current_value:.2f}'
|
||||
max_str = f'{max_value:.2f}'
|
||||
|
||||
# Sync metrics before handling the error
|
||||
await self.update_state_after_step()
|
||||
self.state.metrics.merge(self.state.local_metrics)
|
||||
|
||||
if self.headless_mode:
|
||||
e = RuntimeError(
|
||||
f'Agent reached maximum {limit_type} in headless mode. '
|
||||
@@ -1038,6 +1107,12 @@ class AgentController:
|
||||
# start_id points to first user message
|
||||
if first_user_msg:
|
||||
self.state.start_id = first_user_msg.id
|
||||
# Make sure start_id is not greater than truncation_id
|
||||
if (
|
||||
self.state.truncation_id is not None
|
||||
and self.state.start_id > self.state.truncation_id
|
||||
):
|
||||
self.state.truncation_id = self.state.start_id
|
||||
|
||||
return kept_events
|
||||
|
||||
|
||||
@@ -78,5 +78,11 @@ class ActionTypeSchema(BaseModel):
|
||||
SEND_PR: str = Field(default='send_pr')
|
||||
"""Send a PR to github."""
|
||||
|
||||
SYSTEM_MESSAGE: str = Field(default='system_message')
|
||||
"""Send a system message to the agent."""
|
||||
|
||||
PROMPT_EXTENSION: str = Field(default='prompt_extension')
|
||||
"""Add extensions to the prompt."""
|
||||
|
||||
|
||||
ActionType = ActionTypeSchema()
|
||||
|
||||
@@ -5,6 +5,8 @@ from openhands.events.action.agent import (
|
||||
AgentRejectAction,
|
||||
AgentSummarizeAction,
|
||||
ChangeAgentStateAction,
|
||||
PromptExtensionAction,
|
||||
SystemMessageAction,
|
||||
)
|
||||
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
||||
from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction
|
||||
@@ -33,4 +35,6 @@ __all__ = [
|
||||
'IPythonRunCellAction',
|
||||
'MessageAction',
|
||||
'ActionConfirmationStatus',
|
||||
'SystemMessageAction',
|
||||
'PromptExtensionAction',
|
||||
]
|
||||
|
||||
@@ -78,3 +78,39 @@ class AgentDelegateAction(Action):
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f"I'm asking {self.agent} for help with this task."
|
||||
|
||||
|
||||
@dataclass
class SystemMessageAction(Action):
    """Action that carries a system message for the agent.

    Attributes:
        content (str): Text of the system message.
        action (str): Action type tag; defaults to ActionType.SYSTEM_MESSAGE.
    """

    content: str
    action: str = ActionType.SYSTEM_MESSAGE

    @property
    def message(self) -> str:
        """Human-readable form: the content prefixed with 'System message: '."""
        return 'System message: {}'.format(self.content)
|
||||
|
||||
|
||||
@dataclass
class PromptExtensionAction(Action):
    """Action that carries text extending the prompt.

    Attributes:
        content (str): Text of the prompt extension.
        extension_type (str): Kind of extension (e.g., 'examples', 'info', 'enhance').
        action (str): Action type tag; defaults to ActionType.PROMPT_EXTENSION.
    """

    content: str
    extension_type: str
    action: str = ActionType.PROMPT_EXTENSION

    @property
    def message(self) -> str:
        """Human-readable form naming only the extension type, not the content."""
        return 'Prompt extension (type: {})'.format(self.extension_type)
|
||||
|
||||
@@ -313,13 +313,22 @@ class AgentSession:
|
||||
try:
|
||||
restored_state = State.restore_from_session(self.sid, self.file_store)
|
||||
logger.debug(f'Restored state from session, sid: {self.sid}')
|
||||
|
||||
# Set start_id to 0 since we want to include all events from the beginning
|
||||
restored_state.start_id = 0
|
||||
# Set end_id to -1 to indicate we want to include all events up to the latest
|
||||
restored_state.end_id = -1
|
||||
# Set truncation_id to -1 since we haven't truncated anything yet
|
||||
restored_state.truncation_id = -1
|
||||
|
||||
return restored_state
|
||||
except Exception as e:
|
||||
if self.event_stream.get_latest_event_id() > 0:
|
||||
# if we have events, we should have a state
|
||||
logger.warning(f'State could not be restored: {e}')
|
||||
else:
|
||||
logger.debug('No events found, no state to restore')
|
||||
return restored_state
|
||||
return None
|
||||
|
||||
def get_state(self) -> AgentState | None:
|
||||
controller = self.controller
|
||||
|
||||
@@ -12,7 +12,13 @@ from openhands.core.config import AppConfig
|
||||
from openhands.core.main import run_controller
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
|
||||
from openhands.events.action import (
|
||||
ChangeAgentStateAction,
|
||||
CmdRunAction,
|
||||
MessageAction,
|
||||
PromptExtensionAction,
|
||||
SystemMessageAction,
|
||||
)
|
||||
from openhands.events.observation import (
|
||||
ErrorObservation,
|
||||
)
|
||||
@@ -21,6 +27,7 @@ from openhands.llm import LLM
|
||||
from openhands.llm.metrics import Metrics
|
||||
from openhands.runtime.base import Runtime
|
||||
from openhands.storage.memory import InMemoryFileStore
|
||||
from openhands.utils.prompt import PromptManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -41,6 +48,10 @@ def mock_agent():
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = AppConfig().get_llm_config()
|
||||
agent.config = MagicMock()
|
||||
agent.config.enable_prompt_extensions = False
|
||||
agent.config.disabled_microagents = []
|
||||
agent.get_prompt_manager.return_value = None
|
||||
return agent
|
||||
|
||||
|
||||
@@ -144,7 +155,6 @@ async def test_run_controller_with_fatal_error():
|
||||
file_store = InMemoryFileStore({})
|
||||
event_stream = EventStream(sid='test', file_store=file_store)
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent = MagicMock(spec=Agent)
|
||||
|
||||
def agent_step_fn(state):
|
||||
@@ -155,6 +165,10 @@ async def test_run_controller_with_fatal_error():
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = config.get_llm_config()
|
||||
agent.config = MagicMock()
|
||||
agent.config.enable_prompt_extensions = False
|
||||
agent.config.disabled_microagents = []
|
||||
agent.get_prompt_manager.return_value = None
|
||||
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
|
||||
@@ -200,6 +214,10 @@ async def test_run_controller_stop_with_stuck():
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.config = config.get_llm_config()
|
||||
agent.config = MagicMock()
|
||||
agent.config.enable_prompt_extensions = False
|
||||
agent.config.disabled_microagents = []
|
||||
agent.get_prompt_manager.return_value = None
|
||||
runtime = MagicMock(spec=Runtime)
|
||||
|
||||
def on_event(event: Event):
|
||||
@@ -518,16 +536,23 @@ async def test_run_controller_max_iterations_has_metrics():
|
||||
|
||||
agent = MagicMock(spec=Agent)
|
||||
agent.llm = MagicMock(spec=LLM)
|
||||
agent.llm.metrics = Metrics()
|
||||
agent.llm.metrics = Metrics() # Start with fresh metrics
|
||||
agent.llm.config = config.get_llm_config()
|
||||
agent.config = MagicMock()
|
||||
agent.config.enable_prompt_extensions = False
|
||||
agent.config.disabled_microagents = []
|
||||
agent.get_prompt_manager.return_value = None
|
||||
|
||||
# Keep track of total cost
|
||||
total_cost = 0.0
|
||||
|
||||
def agent_step_fn(state):
|
||||
print(f'agent_step_fn received state: {state}')
|
||||
# Mock the cost of the LLM
|
||||
agent.llm.metrics.add_cost(10.0)
|
||||
print(
|
||||
f'agent.llm.metrics.accumulated_cost: {agent.llm.metrics.accumulated_cost}'
|
||||
)
|
||||
nonlocal total_cost
|
||||
total_cost += 10.0
|
||||
state.metrics.add_cost(10.0) # Add cost directly to state metrics
|
||||
print(f'state.metrics.accumulated_cost: {state.metrics.accumulated_cost}')
|
||||
return CmdRunAction(command='ls')
|
||||
|
||||
agent.step = agent_step_fn
|
||||
@@ -591,9 +616,12 @@ async def test_context_window_exceeded_error_handling(mock_agent, mock_event_str
|
||||
def step(self, state: State):
|
||||
# Append a few messages to the history -- these will be truncated when we throw the error
|
||||
state.history = [
|
||||
MessageAction(content='Test message 0'),
|
||||
MessageAction(content='Test message 1'),
|
||||
MessageAction(content='Test message 0'), # First user message
|
||||
MessageAction(content='Test message 1'), # Agent response
|
||||
MessageAction(content='Test message 2'), # Another agent message
|
||||
MessageAction(content='Test message 3'), # Another agent message
|
||||
]
|
||||
state.history[0]._source = EventSource.USER
|
||||
|
||||
error = ContextWindowExceededError(
|
||||
message='prompt is too long: 233885 tokens > 200000 maximum',
|
||||
@@ -623,7 +651,17 @@ async def test_context_window_exceeded_error_handling(mock_agent, mock_event_str
|
||||
|
||||
# Check that the error was thrown and the history has been truncated
|
||||
assert state.has_errored
|
||||
assert controller.state.history == [MessageAction(content='Test message 1')]
|
||||
# Should keep first user message and second half of history
|
||||
expected_history = [
|
||||
MessageAction(content='Test message 0'), # First user message always kept
|
||||
MessageAction(content='Test message 2'), # Second half of history
|
||||
MessageAction(content='Test message 3'),
|
||||
]
|
||||
expected_history[0]._source = EventSource.USER
|
||||
assert len(controller.state.history) == len(expected_history)
|
||||
for actual, expected in zip(controller.state.history, expected_history):
|
||||
assert actual.content == expected.content
|
||||
assert actual.source == expected.source
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -682,3 +720,246 @@ async def test_run_controller_with_context_window_exceeded(mock_agent, mock_runt
|
||||
|
||||
# Check that the context window exceeded error was raised during the run
|
||||
assert step_state.has_errored
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_prompt_manager_initialization(mock_agent, mock_event_stream):
    """Test that the prompt manager is properly initialized and sends system message."""
    # Prompt manager stub that reports a known system message.
    mock_prompt_manager = MagicMock(spec=PromptManager)
    mock_prompt_manager.get_system_message.return_value = 'Test system message'
    mock_agent.get_prompt_manager.return_value = mock_prompt_manager

    # Create controller
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    try:
        # assert_called_with inspects the most recent add_event call.
        mock_event_stream.add_event.assert_called_with(
            SystemMessageAction(content='Test system message'),
            EventSource.AGENT,
        )
    finally:
        # Close the controller even when the assertion above fails.
        await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_prompt_manager_extensions(mock_agent, mock_event_stream):
    """Test that prompt extensions are properly added to event stream."""
    # Mock the prompt manager and enable extensions
    mock_agent.config.enable_prompt_extensions = True
    mock_prompt_manager = MagicMock(spec=PromptManager)
    mock_prompt_manager.get_system_message.return_value = 'Test system message'

    # Each extension hook prepends a distinct marker to the message text so
    # the order of application can be read from the resulting content.
    def add_examples(msg):
        msg.content[0].text = 'Examples added: ' + msg.content[0].text

    def add_info(msg):
        msg.content[0].text = 'Info added: ' + msg.content[0].text

    def enhance(msg):
        msg.content[0].text = 'Enhanced: ' + msg.content[0].text

    mock_prompt_manager.add_examples_to_initial_message.side_effect = add_examples
    mock_prompt_manager.add_info_to_initial_message.side_effect = add_info
    mock_prompt_manager.enhance_message.side_effect = enhance

    mock_agent.get_prompt_manager.return_value = mock_prompt_manager

    # Create controller
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    try:
        # Send a user message
        message_action = MessageAction(content='Test message')
        message_action._source = EventSource.USER
        await controller._on_event(message_action)

        # Get all calls to add_event
        actual_calls = mock_event_stream.add_event.call_args_list

        # Verify that system message was added
        assert any(
            isinstance(args[0], SystemMessageAction)
            and args[0].content == 'Test system message'
            for args, _ in actual_calls
        )

        # Verify that prompt extensions were added; each expected content
        # includes the markers of all previously applied hooks.
        expected_extensions = [
            ('Examples added: Test message', 'examples'),
            ('Info added: Examples added: Test message', 'info'),
            ('Enhanced: Info added: Examples added: Test message', 'enhance'),
        ]
        for content, ext_type in expected_extensions:
            assert any(
                isinstance(args[0], PromptExtensionAction)
                and args[0].content == content
                and args[0].extension_type == ext_type
                for args, _ in actual_calls
            ), f'Missing extension: {ext_type}'
    finally:
        # Close the controller even when an assertion above fails.
        await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_prompt_manager_extensions_disabled(mock_agent, mock_event_stream):
    """Test that prompt extensions are not added when disabled."""
    # Mock the prompt manager but disable extensions
    mock_agent.config.enable_prompt_extensions = False
    mock_prompt_manager = MagicMock(spec=PromptManager)
    mock_prompt_manager.get_system_message.return_value = 'Test system message'
    mock_agent.get_prompt_manager.return_value = mock_prompt_manager

    # Create controller
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    try:
        # Send a user message
        message_action = MessageAction(content='Test message')
        message_action._source = EventSource.USER
        await controller._on_event(message_action)

        # Get all calls to add_event
        actual_calls = mock_event_stream.add_event.call_args_list

        # Verify that system message was added
        assert any(
            isinstance(args[0], SystemMessageAction)
            and args[0].content == 'Test system message'
            for args, _ in actual_calls
        )

        # Verify that no prompt extensions were added
        assert not any(
            isinstance(args[0], PromptExtensionAction) for args, _ in actual_calls
        )

        # Verify that extension methods were not called
        mock_prompt_manager.add_examples_to_initial_message.assert_not_called()
        mock_prompt_manager.add_info_to_initial_message.assert_not_called()
        mock_prompt_manager.enhance_message.assert_not_called()
    finally:
        # Close the controller even when an assertion above fails.
        await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_prompt_manager_extensions_delegate(mock_agent, mock_event_stream):
    """Test that prompt extensions are not added for delegate controllers."""
    # Mock the prompt manager
    mock_prompt_manager = MagicMock(spec=PromptManager)
    mock_prompt_manager.get_system_message.return_value = 'Test system message'
    mock_agent.get_prompt_manager.return_value = mock_prompt_manager

    # Create delegate controller
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
        is_delegate=True,  # This should prevent system message and extensions
    )

    try:
        # Send a user message
        message_action = MessageAction(content='Test message')
        message_action._source = EventSource.USER
        await controller._on_event(message_action)

        # Get all calls to add_event
        actual_calls = mock_event_stream.add_event.call_args_list

        # Verify that the original message was added (plus state changes).
        # The calls are attached to the failure message instead of being
        # printed unconditionally on every run.
        assert any(
            args[0] == message_action and args[1] == EventSource.USER
            for args, _ in actual_calls
        ), f'Message action {message_action} not found in calls: {actual_calls}'

        # Verify that no system message or prompt extensions were added
        assert not any(
            isinstance(args[0], (SystemMessageAction, PromptExtensionAction))
            for args, _ in actual_calls
        )

        # Verify that system message and extension methods were not called
        mock_prompt_manager.get_system_message.assert_not_called()
        mock_prompt_manager.add_examples_to_initial_message.assert_not_called()
        mock_prompt_manager.add_info_to_initial_message.assert_not_called()
        mock_prompt_manager.enhance_message.assert_not_called()
    finally:
        # Close the controller even when an assertion above fails.
        await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_prompt_manager_not_initialized(mock_agent, mock_event_stream):
    """Test that no system message is sent if prompt manager is not initialized."""
    # Set prompt manager to None
    mock_agent.prompt_manager = None

    # Create controller
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    try:
        # Verify that no system message was sent
        assert not any(
            isinstance(args[0], SystemMessageAction)
            for args, _ in mock_event_stream.add_event.call_args_list
        )
    finally:
        # Close the controller even when the assertion above fails.
        await controller.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_prompt_manager_delegate_initialization(mock_agent, mock_event_stream):
    """Test that system message is not sent for delegate controllers."""
    # Mock the prompt manager
    mock_agent.prompt_manager = MagicMock(spec=PromptManager)
    mock_agent.prompt_manager.get_system_message.return_value = 'Test system message'

    # Create controller with is_delegate=True
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
        is_delegate=True,  # This should prevent system message from being sent
    )

    try:
        # Verify that no system message was sent
        assert not any(
            isinstance(args[0], SystemMessageAction)
            for args, _ in mock_event_stream.add_event.call_args_list
        )
    finally:
        # Close the controller even when the assertion above fails.
        await controller.close()
|
||||
|
||||
@@ -499,6 +499,9 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
config = AgentConfig()
|
||||
config.enable_prompt_extensions = False
|
||||
agent = CodeActAgent(llm=llm, config=config)
|
||||
agent.prompt_manager = Mock()
|
||||
agent.prompt_manager.get_system_message.return_value = 'System message'
|
||||
agent.prompt_manager.add_examples_to_initial_message = Mock()
|
||||
|
||||
# Test step with no pending actions
|
||||
mock_state.latest_user_message = None
|
||||
@@ -508,6 +511,7 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
mock_state.latest_user_message_timeout = None
|
||||
mock_state.latest_user_message_llm_metrics = None
|
||||
mock_state.latest_user_message_tool_call_metadata = None
|
||||
mock_state.history = []
|
||||
|
||||
action = agent.step(mock_state)
|
||||
assert isinstance(action, MessageAction)
|
||||
@@ -516,7 +520,16 @@ def test_step_with_no_pending_actions(mock_state: State):
|
||||
|
||||
def test_mismatched_tool_call_events(mock_state: State):
|
||||
"""Tests that the agent can convert mismatched tool call events (i.e., an observation with no corresponding action) into messages."""
|
||||
agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
|
||||
llm = Mock()
|
||||
llm.is_function_calling_active = Mock(return_value=True) # Enable function calling
|
||||
llm.is_caching_prompt_active = Mock(return_value=False)
|
||||
llm.config = Mock()
|
||||
llm.config.max_message_chars = 1000
|
||||
|
||||
agent = CodeActAgent(llm=llm, config=AgentConfig())
|
||||
agent.prompt_manager = Mock()
|
||||
agent.prompt_manager.get_system_message.return_value = 'System message'
|
||||
agent.prompt_manager.add_examples_to_initial_message = Mock()
|
||||
|
||||
tool_call_metadata = Mock(
|
||||
spec=ToolCallMetadata,
|
||||
|
||||
@@ -79,23 +79,22 @@ def test_get_messages(codeact_agent: CodeActAgent):
|
||||
)
|
||||
|
||||
assert (
|
||||
len(messages) == 6
|
||||
) # System, initial user + user message, agent message, last user message
|
||||
assert messages[0].content[0].cache_prompt # system message
|
||||
assert messages[1].role == 'user'
|
||||
assert messages[1].content[0].text.endswith('Initial user message')
|
||||
len(messages) == 5
|
||||
) # Initial user message, agent message, user message, agent message, last user message
|
||||
assert messages[0].role == 'user'
|
||||
assert messages[0].content[0].text.endswith('Initial user message')
|
||||
# we add cache breakpoint to the last 3 user messages
|
||||
assert messages[1].content[0].cache_prompt
|
||||
assert messages[0].content[0].cache_prompt
|
||||
|
||||
assert messages[3].role == 'user'
|
||||
assert messages[3].content[0].text == ('Hello, agent!')
|
||||
assert messages[3].content[0].cache_prompt
|
||||
assert messages[4].role == 'assistant'
|
||||
assert messages[4].content[0].text == 'Hello, user!'
|
||||
assert not messages[4].content[0].cache_prompt
|
||||
assert messages[5].role == 'user'
|
||||
assert messages[5].content[0].text.startswith('Laaaaaaaast!')
|
||||
assert messages[5].content[0].cache_prompt
|
||||
assert messages[2].role == 'user'
|
||||
assert messages[2].content[0].text == ('Hello, agent!')
|
||||
assert messages[2].content[0].cache_prompt
|
||||
assert messages[3].role == 'assistant'
|
||||
assert messages[3].content[0].text == 'Hello, user!'
|
||||
assert not messages[3].content[0].cache_prompt
|
||||
assert messages[4].role == 'user'
|
||||
assert messages[4].content[0].text.startswith('Laaaaaaaast!')
|
||||
assert messages[4].content[0].cache_prompt
|
||||
|
||||
|
||||
def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
|
||||
@@ -116,15 +115,48 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
|
||||
|
||||
# Check that only the last two user messages have cache_prompt=True
|
||||
cached_user_messages = [
|
||||
msg
|
||||
for msg in messages
|
||||
if msg.role in ('user', 'system') and msg.content[0].cache_prompt
|
||||
msg for msg in messages if msg.role == 'user' and msg.content[0].cache_prompt
|
||||
]
|
||||
assert (
|
||||
len(cached_user_messages) == 4
|
||||
) # Including the initial system+user + 2 last user message
|
||||
len(cached_user_messages) == 3
|
||||
) # Including the initial user message + 2 last user messages
|
||||
|
||||
# Verify that these are indeed the last two user messages (from start)
|
||||
assert cached_user_messages[0].content[0].text.startswith('You are OpenHands agent')
|
||||
assert cached_user_messages[0].content[0].text.startswith('User message 0')
|
||||
assert cached_user_messages[1].content[0].text.startswith('User message 1')
|
||||
assert cached_user_messages[2].content[0].text.startswith('User message 1')
|
||||
assert cached_user_messages[3].content[0].text.startswith('User message 1')
|
||||
|
||||
|
||||
def test_prompt_caching_headers(codeact_agent: CodeActAgent):
    """Check the prompt-caching header handed to the LLM completion callable.

    A two-message history (one user message, one agent message) is fed to
    ``codeact_agent.step`` while ``llm._completion_unwrapped`` is replaced by
    a stub. The stub asserts that it receives an ``extra_headers`` keyword
    containing ``anthropic-beta: prompt-caching-2024-07-31`` and returns a
    canned response; the test then verifies that the step result is a
    ``MessageAction`` carrying that canned text.
    """
    expected_reply = 'Hello! How can I assist you today?'

    # Setup: conversation history with one user turn and one agent turn.
    user_msg = MessageAction('Hello, agent!')
    user_msg._source = 'user'
    agent_msg = MessageAction('Hello, user!')
    agent_msg._source = 'agent'

    # Mock keyword arguments are set as attributes on the created mock.
    mock_state = Mock(
        history=[user_msg, agent_msg],
        max_iterations=5,
        iteration=0,
        extra_data={},
    )

    codeact_agent.reset()

    # Stand-in for the completion call: validates the headers it is given,
    # then answers with a fixed response.
    def check_headers(**kwargs):
        assert 'extra_headers' in kwargs
        extra_headers = kwargs['extra_headers']
        assert 'anthropic-beta' in extra_headers
        assert extra_headers['anthropic-beta'] == 'prompt-caching-2024-07-31'
        return ModelResponse(choices=[{'message': {'content': expected_reply}}])

    codeact_agent.llm._completion_unwrapped = check_headers
    result = codeact_agent.step(mock_state)

    # Assert: the canned response is surfaced as a message action.
    assert isinstance(result, MessageAction)
    assert result.content == 'Hello! How can I assist you today?'
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
|
||||
from openhands.controller.agent_controller import AgentController
|
||||
from openhands.core.config import AgentConfig, LLMConfig
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.events import EventStream
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.storage import InMemoryFileStore
|
||||
@@ -11,7 +11,9 @@ from openhands.storage import InMemoryFileStore
|
||||
|
||||
@pytest.fixture
|
||||
def agent_controller():
|
||||
llm = LLM(config=LLMConfig())
|
||||
llm = MagicMock(spec=LLM)
|
||||
llm.config = MagicMock()
|
||||
llm.metrics = MagicMock()
|
||||
agent = MagicMock()
|
||||
agent.name = 'test_agent'
|
||||
agent.llm = llm
|
||||
|
||||
Reference in New Issue
Block a user