Compare commits

...

28 Commits

Author SHA1 Message Date
openhands
7be21240ea Remove get_prompt_manager and add get_system_message 2025-02-06 18:09:38 +00:00
openhands
d30c21a66d Update get_prompt_manager to be implemented by each agent 2025-02-06 18:08:12 +00:00
openhands
e6ebbe67f8 Remove redundant get_prompt_manager method from CodeActAgent 2025-02-06 17:59:18 +00:00
Xingyao Wang
f6a4587327 Merge branch 'main' into openhands-workspace-runjzvph 2025-02-06 12:55:18 -05:00
Xingyao Wang
4cea1fd343 remove use of prompt manager in codeact 2025-02-04 23:40:11 -05:00
Xingyao Wang
de8753cef6 simplify PromptExtensionAction 2025-02-04 23:38:31 -05:00
Xingyao Wang
be8f0782fb Merge branch 'main' into openhands-workspace-runjzvph 2025-02-04 23:30:15 -05:00
Xingyao Wang
ec784e1535 Merge branch 'main' into openhands-workspace-runjzvph 2025-02-03 13:03:02 -05:00
Xingyao Wang
db80ee68d5 Merge branch 'main' into openhands-workspace-runjzvph 2025-02-03 12:19:42 -05:00
OpenHands Bot
579a9edbb0 🤖 Auto-fix Python linting issues 2025-02-02 23:45:02 +00:00
openhands
e82aa13fe1 test: Enable prompt extensions in test_prompt_manager_extensions 2025-02-02 05:55:56 +00:00
openhands
7a752364b5 fix: Add message action to event stream in _handle_message_action 2025-02-02 05:54:20 +00:00
openhands
dfc8db8a90 test: Fix metrics handling in test_run_controller_max_iterations_has_metrics 2025-02-02 05:34:35 +00:00
openhands
c513bee4ff test: Fix mock_agent fixture to include config and get_prompt_manager 2025-02-02 05:20:24 +00:00
openhands
8f9a7a131c feat: Add get_prompt_manager method to Agent class to allow specialized prompt managers 2025-02-02 05:09:32 +00:00
openhands
a765196e53 Merge main: Resolve conflict in test_prompt_caching.py 2025-02-02 05:04:51 +00:00
openhands
60aff59c3f Merge from upstream/main and resolve conflicts 2025-01-30 21:09:03 +00:00
openhands
86a844ebbb Remove prompt manager mocks from prompt caching tests 2025-01-30 21:02:35 +00:00
openhands
945e81d262 Add prompt manager tests to test_agent_controller.py 2025-01-30 21:00:50 +00:00
openhands
0fe1f58d97 Fix traffic control tests by properly mocking LLM 2025-01-30 20:47:34 +00:00
openhands
028e715c77 Fix truncation by ensuring start_id is not greater than truncation_id 2025-01-30 20:46:33 +00:00
openhands
f26710b7ae Fix prompt caching tests by properly mocking prompt manager 2025-01-30 20:45:49 +00:00
openhands
8a6d1344a9 Fix CodeActAgent tests by properly mocking LLM and prompt manager 2025-01-30 20:44:31 +00:00
openhands
1e17bbc228 Fix state restoration in agent session by properly setting event IDs 2025-01-30 20:42:22 +00:00
openhands
ab3d004bc9 Fix iteration count not including parent's step when ending delegation 2025-01-30 20:41:36 +00:00
openhands
bd5637fd3d Fix metrics not being synced before handling max iterations error 2025-01-30 20:40:48 +00:00
openhands
1f6be838a7 Fix test_context_window_exceeded_error_handling to match actual behavior of _apply_conversation_window 2025-01-30 20:39:04 +00:00
openhands
e42a3a6407 Move PromptManager to agent controller level
1. Add new action types: SystemMessageAction and PromptExtensionAction
2. Move PromptManager initialization from CodeActAgent to AgentController
3. Send system message at initialization
4. Handle prompt extensions after first user message
5. Update CodeActAgent to handle new actions
6. Fix type issues with prompt manager methods and file paths
2025-01-29 21:15:21 +00:00
11 changed files with 558 additions and 83 deletions

View File

@@ -1,10 +1,8 @@
import json
import os
from collections import deque
from litellm import ModelResponse
import openhands
import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
@@ -23,6 +21,8 @@ from openhands.events.action import (
FileReadAction,
IPythonRunCellAction,
MessageAction,
PromptExtensionAction,
SystemMessageAction,
)
from openhands.events.observation import (
AgentCondensationObservation,
@@ -44,7 +44,6 @@ from openhands.runtime.plugins import (
JupyterRequirement,
PluginRequirement,
)
from openhands.utils.prompt import PromptManager
class CodeActAgent(Agent):
@@ -99,20 +98,30 @@ class CodeActAgent(Agent):
logger.debug(
f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2, ensure_ascii=False).replace("\\n", "\n")}'
)
self.prompt_manager = PromptManager(
microagent_dir=os.path.join(
os.path.dirname(os.path.dirname(openhands.__file__)),
'microagents',
)
if self.config.enable_prompt_extensions
else None,
prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'),
disabled_microagents=self.config.disabled_microagents,
)
self.condenser = Condenser.from_config(self.config.condenser)
logger.debug(f'Using condenser: {self.condenser}')
def get_system_message(self) -> str:
    """Return the system message for this agent.

    The message is read from ``prompts/system_message.txt`` in the directory
    that contains this module. If that file does not exist, a generic
    fallback message is returned instead.

    Returns:
        The file contents with leading/trailing whitespace stripped, or the
        fallback string when the file is missing.
    """
    # Imported locally so this method does not depend on a module-level
    # ``import os`` being present.
    import os

    system_message_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        'prompts',
        'system_message.txt',
    )
    # EAFP: open directly rather than exists()-then-open, which can race with
    # the file being removed in between. The encoding is explicit so the text
    # read does not depend on the platform's locale default.
    try:
        with open(system_message_path, encoding='utf-8') as f:
            return f.read().strip()
    except FileNotFoundError:
        return 'You are a helpful AI assistant.'
def get_action_message(
self,
action: Action,
@@ -231,6 +240,22 @@ class CodeActAgent(Agent):
content=content,
)
]
elif isinstance(action, SystemMessageAction):
return [
Message(
role='system',
content=[TextContent(text=action.content)],
)
]
elif isinstance(action, PromptExtensionAction):
# For prompt extensions, we add them as assistant messages
# This way they appear in the conversation history but don't interfere with the system message
return [
Message(
role='assistant',
content=[TextContent(text=action.content)],
)
]
return []
def get_observation_message(
@@ -427,20 +452,7 @@ class CodeActAgent(Agent):
- Messages from the same role are combined to prevent consecutive same-role messages
- For Anthropic models, specific messages are cached according to their documentation
"""
if not self.prompt_manager:
raise Exception('Prompt Manager not instantiated.')
messages: list[Message] = [
Message(
role='system',
content=[
TextContent(
text=self.prompt_manager.get_system_message(),
cache_prompt=self.llm.is_caching_prompt_active(),
)
],
)
]
messages: list[Message] = []
pending_tool_call_action_messages: dict[str, Message] = {}
tool_call_id_to_message: dict[str, Message] = {}
@@ -448,7 +460,6 @@ class CodeActAgent(Agent):
# Condense the events from the state.
events = self.condenser.condensed_history(state)
is_first_message_handled = False
for event in events:
# create a regular message from an event
if isinstance(event, Action):
@@ -492,19 +503,6 @@ class CodeActAgent(Agent):
for msg in messages_to_add:
if msg:
if msg.role == 'user' and not is_first_message_handled:
is_first_message_handled = True
# compose the first user message with examples
self.prompt_manager.add_examples_to_initial_message(msg)
# and/or repo/runtime info
if self.config.enable_prompt_extensions:
self.prompt_manager.add_info_to_initial_message(msg)
# enhance the user message with additional context based on keywords matched
if msg.role == 'user':
self.prompt_manager.enhance_message(msg)
messages.append(msg)
if self.llm.is_caching_prompt_active():

View File

@@ -111,3 +111,22 @@ class Agent(ABC):
if not bool(cls._registry):
raise AgentNotRegisteredError()
return list(cls._registry.keys())
def get_prompt_manager(
self,
microagent_dir: str | None = None,
disabled_microagents: list[str] | None = None,
) -> 'PromptManager | None':
"""Get a prompt manager instance for this agent.
This method can be overridden by subclasses to return a specialized prompt manager.
By default, it returns None.
Args:
microagent_dir: Directory containing microagent prompts
disabled_microagents: List of microagents to disable
Returns:
A PromptManager instance, a subclass of PromptManager, or None
"""
return None

View File

@@ -1,6 +1,7 @@
import asyncio
import copy
import os
import sys
import traceback
from typing import Callable, ClassVar, Type
@@ -12,6 +13,7 @@ from litellm.exceptions import (
RateLimitError,
)
import openhands
from openhands.controller.agent import Agent
from openhands.controller.replay import ReplayManager
from openhands.controller.state.state import State, TrafficControlState
@@ -27,6 +29,7 @@ from openhands.core.exceptions import (
)
from openhands.core.logger import LOG_ALL_EVENTS
from openhands.core.logger import openhands_logger as logger
from openhands.core.message import Message, TextContent
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import (
@@ -40,6 +43,8 @@ from openhands.events.action import (
IPythonRunCellAction,
MessageAction,
NullAction,
PromptExtensionAction,
SystemMessageAction,
)
from openhands.events.event import Event
from openhands.events.observation import (
@@ -147,6 +152,18 @@ class AgentController:
# replay-related
self._replay_manager = ReplayManager(replay_events)
# First user message tracking
self._first_user_message_received = False
# Send system message at initialization if the agent provides one
if not self.is_delegate and hasattr(agent, 'get_system_message'):
system_message = agent.get_system_message()
if system_message:
self.event_stream.add_event(
SystemMessageAction(content=system_message),
EventSource.AGENT,
)
async def close(self) -> None:
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
@@ -401,11 +418,58 @@ class AgentController:
self.state.max_iterations = (
self.state.iteration + self._initial_max_iterations
)
if (
self.state.traffic_control_state == TrafficControlState.THROTTLING
or self.state.traffic_control_state == TrafficControlState.PAUSED
):
self.state.traffic_control_state = TrafficControlState.NORMAL
# Add the message to the event stream
self.event_stream.add_event(action, EventSource.USER)
# Handle first user message to trigger prompt extensions
if (
not self._first_user_message_received
and self.prompt_manager
and self.agent.config.enable_prompt_extensions
):
self._first_user_message_received = True
# Create a Message object from the action content
msg = Message(
role='user',
content=[TextContent(text=action.content)],
)
# Add examples to initial message
self.prompt_manager.add_examples_to_initial_message(msg)
if msg.content[0].text != action.content:
self.event_stream.add_event(
PromptExtensionAction(
content=msg.content[0].text, extension_type='examples'
),
EventSource.AGENT,
)
# Add info to initial message
self.prompt_manager.add_info_to_initial_message(msg)
if msg.content[0].text != action.content:
self.event_stream.add_event(
PromptExtensionAction(
content=msg.content[0].text, extension_type='info'
),
EventSource.AGENT,
)
# Enhance message
self.prompt_manager.enhance_message(msg)
if msg.content[0].text != action.content:
self.event_stream.add_event(
PromptExtensionAction(
content=msg.content[0].text, extension_type='enhance'
),
EventSource.AGENT,
)
if (
self.state.traffic_control_state == TrafficControlState.THROTTLING
or self.state.traffic_control_state == TrafficControlState.PAUSED
):
self.state.traffic_control_state = TrafficControlState.NORMAL
self.log(
'debug',
f'Extended max iterations to {self.state.max_iterations} after user message',
@@ -574,7 +638,8 @@ class AgentController:
delegate_state = self.delegate.get_agent_state()
# update iteration that is shared across agents
self.state.iteration = self.delegate.state.iteration
# Add 1 to account for the parent's step that initiated delegation
self.state.iteration = self.delegate.state.iteration + 1
# close the delegate controller before adding new events
asyncio.get_event_loop().run_until_complete(self.delegate.close())
@@ -761,6 +826,10 @@ class AgentController:
current_str = f'{current_value:.2f}'
max_str = f'{max_value:.2f}'
# Sync metrics before handling the error
await self.update_state_after_step()
self.state.metrics.merge(self.state.local_metrics)
if self.headless_mode:
e = RuntimeError(
f'Agent reached maximum {limit_type} in headless mode. '
@@ -1038,6 +1107,12 @@ class AgentController:
# start_id points to first user message
if first_user_msg:
self.state.start_id = first_user_msg.id
# Make sure start_id is not greater than truncation_id
if (
self.state.truncation_id is not None
and self.state.start_id > self.state.truncation_id
):
self.state.truncation_id = self.state.start_id
return kept_events

View File

@@ -78,5 +78,11 @@ class ActionTypeSchema(BaseModel):
SEND_PR: str = Field(default='send_pr')
"""Send a PR to github."""
SYSTEM_MESSAGE: str = Field(default='system_message')
"""Send a system message to the agent."""
PROMPT_EXTENSION: str = Field(default='prompt_extension')
"""Add extensions to the prompt."""
ActionType = ActionTypeSchema()

View File

@@ -5,6 +5,8 @@ from openhands.events.action.agent import (
AgentRejectAction,
AgentSummarizeAction,
ChangeAgentStateAction,
PromptExtensionAction,
SystemMessageAction,
)
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction
@@ -33,4 +35,6 @@ __all__ = [
'IPythonRunCellAction',
'MessageAction',
'ActionConfirmationStatus',
'SystemMessageAction',
'PromptExtensionAction',
]

View File

@@ -78,3 +78,39 @@ class AgentDelegateAction(Action):
@property
def message(self) -> str:
return f"I'm asking {self.agent} for help with this task."
@dataclass
class SystemMessageAction(Action):
    """Action that carries a system message for the agent.

    Attributes:
        content (str): Text of the system message.
        action (str): The action type; defaults to ActionType.SYSTEM_MESSAGE.
    """

    content: str
    action: str = ActionType.SYSTEM_MESSAGE

    @property
    def message(self) -> str:
        """Human-readable rendering: a fixed label followed by the content."""
        return 'System message: {}'.format(self.content)
@dataclass
class PromptExtensionAction(Action):
    """Action that contributes an extension to the prompt.

    Attributes:
        content (str): Text of the prompt extension.
        extension_type (str): Kind of extension (e.g., 'examples', 'info', 'enhance').
        action (str): The action type; defaults to ActionType.PROMPT_EXTENSION.
    """

    content: str
    extension_type: str
    action: str = ActionType.PROMPT_EXTENSION

    @property
    def message(self) -> str:
        """Human-readable rendering that names only the extension type."""
        return 'Prompt extension (type: {})'.format(self.extension_type)

View File

@@ -313,13 +313,22 @@ class AgentSession:
try:
restored_state = State.restore_from_session(self.sid, self.file_store)
logger.debug(f'Restored state from session, sid: {self.sid}')
# Set start_id to 0 since we want to include all events from the beginning
restored_state.start_id = 0
# Set end_id to -1 to indicate we want to include all events up to the latest
restored_state.end_id = -1
# Set truncation_id to -1 since we haven't truncated anything yet
restored_state.truncation_id = -1
return restored_state
except Exception as e:
if self.event_stream.get_latest_event_id() > 0:
# if we have events, we should have a state
logger.warning(f'State could not be restored: {e}')
else:
logger.debug('No events found, no state to restore')
return restored_state
return None
def get_state(self) -> AgentState | None:
controller = self.controller

View File

@@ -12,7 +12,13 @@ from openhands.core.config import AppConfig
from openhands.core.main import run_controller
from openhands.core.schema import AgentState
from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
from openhands.events.action import (
ChangeAgentStateAction,
CmdRunAction,
MessageAction,
PromptExtensionAction,
SystemMessageAction,
)
from openhands.events.observation import (
ErrorObservation,
)
@@ -21,6 +27,7 @@ from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.runtime.base import Runtime
from openhands.storage.memory import InMemoryFileStore
from openhands.utils.prompt import PromptManager
@pytest.fixture
@@ -41,6 +48,10 @@ def mock_agent():
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = AppConfig().get_llm_config()
agent.config = MagicMock()
agent.config.enable_prompt_extensions = False
agent.config.disabled_microagents = []
agent.get_prompt_manager.return_value = None
return agent
@@ -144,7 +155,6 @@ async def test_run_controller_with_fatal_error():
file_store = InMemoryFileStore({})
event_stream = EventStream(sid='test', file_store=file_store)
agent = MagicMock(spec=Agent)
agent = MagicMock(spec=Agent)
def agent_step_fn(state):
@@ -155,6 +165,10 @@ async def test_run_controller_with_fatal_error():
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
agent.config = MagicMock()
agent.config.enable_prompt_extensions = False
agent.config.disabled_microagents = []
agent.get_prompt_manager.return_value = None
runtime = MagicMock(spec=Runtime)
@@ -200,6 +214,10 @@ async def test_run_controller_stop_with_stuck():
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.config = config.get_llm_config()
agent.config = MagicMock()
agent.config.enable_prompt_extensions = False
agent.config.disabled_microagents = []
agent.get_prompt_manager.return_value = None
runtime = MagicMock(spec=Runtime)
def on_event(event: Event):
@@ -518,16 +536,23 @@ async def test_run_controller_max_iterations_has_metrics():
agent = MagicMock(spec=Agent)
agent.llm = MagicMock(spec=LLM)
agent.llm.metrics = Metrics()
agent.llm.metrics = Metrics() # Start with fresh metrics
agent.llm.config = config.get_llm_config()
agent.config = MagicMock()
agent.config.enable_prompt_extensions = False
agent.config.disabled_microagents = []
agent.get_prompt_manager.return_value = None
# Keep track of total cost
total_cost = 0.0
def agent_step_fn(state):
print(f'agent_step_fn received state: {state}')
# Mock the cost of the LLM
agent.llm.metrics.add_cost(10.0)
print(
f'agent.llm.metrics.accumulated_cost: {agent.llm.metrics.accumulated_cost}'
)
nonlocal total_cost
total_cost += 10.0
state.metrics.add_cost(10.0) # Add cost directly to state metrics
print(f'state.metrics.accumulated_cost: {state.metrics.accumulated_cost}')
return CmdRunAction(command='ls')
agent.step = agent_step_fn
@@ -591,9 +616,12 @@ async def test_context_window_exceeded_error_handling(mock_agent, mock_event_str
def step(self, state: State):
# Append a few messages to the history -- these will be truncated when we throw the error
state.history = [
MessageAction(content='Test message 0'),
MessageAction(content='Test message 1'),
MessageAction(content='Test message 0'), # First user message
MessageAction(content='Test message 1'), # Agent response
MessageAction(content='Test message 2'), # Another agent message
MessageAction(content='Test message 3'), # Another agent message
]
state.history[0]._source = EventSource.USER
error = ContextWindowExceededError(
message='prompt is too long: 233885 tokens > 200000 maximum',
@@ -623,7 +651,17 @@ async def test_context_window_exceeded_error_handling(mock_agent, mock_event_str
# Check that the error was thrown and the history has been truncated
assert state.has_errored
assert controller.state.history == [MessageAction(content='Test message 1')]
# Should keep first user message and second half of history
expected_history = [
MessageAction(content='Test message 0'), # First user message always kept
MessageAction(content='Test message 2'), # Second half of history
MessageAction(content='Test message 3'),
]
expected_history[0]._source = EventSource.USER
assert len(controller.state.history) == len(expected_history)
for actual, expected in zip(controller.state.history, expected_history):
assert actual.content == expected.content
assert actual.source == expected.source
@pytest.mark.asyncio
@@ -682,3 +720,246 @@ async def test_run_controller_with_context_window_exceeded(mock_agent, mock_runt
# Check that the context window exceeded error was raised during the run
assert step_state.has_errored
@pytest.mark.asyncio
async def test_prompt_manager_initialization(mock_agent, mock_event_stream):
    """Creating a controller must push the system message onto the event stream."""
    # The agent hands out a stubbed prompt manager with a fixed system message.
    prompt_manager_stub = MagicMock(spec=PromptManager)
    prompt_manager_stub.get_system_message.return_value = 'Test system message'
    mock_agent.get_prompt_manager.return_value = prompt_manager_stub

    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # assert_called_with inspects the most recent add_event call only.
    expected_event = SystemMessageAction(content='Test system message')
    mock_event_stream.add_event.assert_called_with(
        expected_event,
        EventSource.AGENT,
    )

    await controller.close()
@pytest.mark.asyncio
async def test_prompt_manager_extensions(mock_agent, mock_event_stream):
    """The first user message must produce one PromptExtensionAction per extension step."""
    mock_agent.config.enable_prompt_extensions = True
    prompt_manager_stub = MagicMock(spec=PromptManager)
    prompt_manager_stub.get_system_message.return_value = 'Test system message'

    def prepend(prefix):
        """Build a side effect that prefixes the text of the message's first content item."""

        def _apply(msg):
            msg.content[0].text = prefix + msg.content[0].text

        return _apply

    # Each extension hook mutates the message in place with its own marker.
    prompt_manager_stub.add_examples_to_initial_message.side_effect = prepend(
        'Examples added: '
    )
    prompt_manager_stub.add_info_to_initial_message.side_effect = prepend(
        'Info added: '
    )
    prompt_manager_stub.enhance_message.side_effect = prepend('Enhanced: ')
    mock_agent.get_prompt_manager.return_value = prompt_manager_stub

    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Deliver a user message to the controller.
    user_message = MessageAction(content='Test message')
    user_message._source = EventSource.USER
    await controller._on_event(user_message)

    recorded_calls = mock_event_stream.add_event.call_args_list

    def emitted(check):
        """True if any recorded add_event call's first positional arg satisfies check."""
        return any(check(args[0]) for args, _ in recorded_calls)

    # The system message must be among the recorded events.
    assert emitted(
        lambda evt: isinstance(evt, SystemMessageAction)
        and evt.content == 'Test system message'
    )

    # Every extension step must have been recorded with its cumulative text.
    expected_extensions = [
        ('Examples added: Test message', 'examples'),
        ('Info added: Examples added: Test message', 'info'),
        ('Enhanced: Info added: Examples added: Test message', 'enhance'),
    ]
    for content, ext_type in expected_extensions:
        assert emitted(
            lambda evt: isinstance(evt, PromptExtensionAction)
            and evt.content == content
            and evt.extension_type == ext_type
        ), f'Missing extension: {ext_type}'

    await controller.close()
@pytest.mark.asyncio
async def test_prompt_manager_extensions_disabled(mock_agent, mock_event_stream):
    """With extensions disabled, no PromptExtensionAction is emitted and no hook runs."""
    mock_agent.config.enable_prompt_extensions = False
    prompt_manager_stub = MagicMock(spec=PromptManager)
    prompt_manager_stub.get_system_message.return_value = 'Test system message'
    mock_agent.get_prompt_manager.return_value = prompt_manager_stub

    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Deliver a user message to the controller.
    user_message = MessageAction(content='Test message')
    user_message._source = EventSource.USER
    await controller._on_event(user_message)

    recorded_calls = mock_event_stream.add_event.call_args_list

    # The system message must still be present.
    assert any(
        isinstance(args[0], SystemMessageAction)
        and args[0].content == 'Test system message'
        for args, _ in recorded_calls
    )

    # No recorded event may be a prompt extension.
    assert all(
        not isinstance(args[0], PromptExtensionAction) for args, _ in recorded_calls
    )

    # None of the extension hooks may have been invoked.
    for hook_name in (
        'add_examples_to_initial_message',
        'add_info_to_initial_message',
        'enhance_message',
    ):
        getattr(prompt_manager_stub, hook_name).assert_not_called()

    await controller.close()
@pytest.mark.asyncio
async def test_prompt_manager_extensions_delegate(mock_agent, mock_event_stream):
    """Test that prompt extensions are not added for delegate controllers.

    A controller created with ``is_delegate=True`` is expected to forward the
    user message to the event stream without emitting a system message or any
    prompt extension, and without calling the prompt manager.
    """
    # Mock the prompt manager
    mock_prompt_manager = MagicMock(spec=PromptManager)
    mock_prompt_manager.get_system_message.return_value = 'Test system message'
    mock_agent.get_prompt_manager.return_value = mock_prompt_manager

    # Create delegate controller
    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
        is_delegate=True,  # This should prevent system message and extensions
    )

    # Send a user message
    message_action = MessageAction(content='Test message')
    message_action._source = EventSource.USER
    await controller._on_event(message_action)

    # Get all calls to add_event
    actual_calls = mock_event_stream.add_event.call_args_list

    # Verify that the original message was added. The recorded calls go into
    # the failure message so they are shown only when the assertion fails.
    assert any(
        args[0] == message_action and args[1] == EventSource.USER
        for args, _ in actual_calls
    ), f'User message {message_action} not found in add_event calls: {actual_calls}'

    # Verify that no system message or prompt extensions were added
    assert not any(
        isinstance(args[0], (SystemMessageAction, PromptExtensionAction))
        for args, _ in actual_calls
    ), f'Unexpected system/extension event in add_event calls: {actual_calls}'

    # Verify that system message and extension methods were not called
    mock_prompt_manager.get_system_message.assert_not_called()
    mock_prompt_manager.add_examples_to_initial_message.assert_not_called()
    mock_prompt_manager.add_info_to_initial_message.assert_not_called()
    mock_prompt_manager.enhance_message.assert_not_called()

    await controller.close()
@pytest.mark.asyncio
async def test_prompt_manager_not_initialized(mock_agent, mock_event_stream):
    """With the agent's prompt_manager set to None, no system message is recorded."""
    mock_agent.prompt_manager = None

    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    # Inspect every add_event call: none may carry a SystemMessageAction.
    for positional, _ in mock_event_stream.add_event.call_args_list:
        assert not isinstance(positional[0], SystemMessageAction)

    await controller.close()
@pytest.mark.asyncio
async def test_prompt_manager_delegate_initialization(mock_agent, mock_event_stream):
    """A delegate controller must not record a system message on construction."""
    # Give the agent a stubbed prompt manager with a fixed system message.
    prompt_manager_stub = MagicMock(spec=PromptManager)
    mock_agent.prompt_manager = prompt_manager_stub
    mock_agent.prompt_manager.get_system_message.return_value = 'Test system message'

    controller = AgentController(
        agent=mock_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
        is_delegate=True,  # delegate controllers should skip the system message
    )

    # Inspect every add_event call: none may carry a SystemMessageAction.
    for positional, _ in mock_event_stream.add_event.call_args_list:
        assert not isinstance(positional[0], SystemMessageAction)

    await controller.close()

View File

@@ -499,6 +499,9 @@ def test_step_with_no_pending_actions(mock_state: State):
config = AgentConfig()
config.enable_prompt_extensions = False
agent = CodeActAgent(llm=llm, config=config)
agent.prompt_manager = Mock()
agent.prompt_manager.get_system_message.return_value = 'System message'
agent.prompt_manager.add_examples_to_initial_message = Mock()
# Test step with no pending actions
mock_state.latest_user_message = None
@@ -508,6 +511,7 @@ def test_step_with_no_pending_actions(mock_state: State):
mock_state.latest_user_message_timeout = None
mock_state.latest_user_message_llm_metrics = None
mock_state.latest_user_message_tool_call_metadata = None
mock_state.history = []
action = agent.step(mock_state)
assert isinstance(action, MessageAction)
@@ -516,7 +520,16 @@ def test_step_with_no_pending_actions(mock_state: State):
def test_mismatched_tool_call_events(mock_state: State):
"""Tests that the agent can convert mismatched tool call events (i.e., an observation with no corresponding action) into messages."""
agent = CodeActAgent(llm=LLM(LLMConfig()), config=AgentConfig())
llm = Mock()
llm.is_function_calling_active = Mock(return_value=True) # Enable function calling
llm.is_caching_prompt_active = Mock(return_value=False)
llm.config = Mock()
llm.config.max_message_chars = 1000
agent = CodeActAgent(llm=llm, config=AgentConfig())
agent.prompt_manager = Mock()
agent.prompt_manager.get_system_message.return_value = 'System message'
agent.prompt_manager.add_examples_to_initial_message = Mock()
tool_call_metadata = Mock(
spec=ToolCallMetadata,

View File

@@ -79,23 +79,22 @@ def test_get_messages(codeact_agent: CodeActAgent):
)
assert (
len(messages) == 6
) # System, initial user + user message, agent message, last user message
assert messages[0].content[0].cache_prompt # system message
assert messages[1].role == 'user'
assert messages[1].content[0].text.endswith('Initial user message')
len(messages) == 5
) # Initial user message, agent message, user message, agent message, last user message
assert messages[0].role == 'user'
assert messages[0].content[0].text.endswith('Initial user message')
# we add cache breakpoint to the last 3 user messages
assert messages[1].content[0].cache_prompt
assert messages[0].content[0].cache_prompt
assert messages[3].role == 'user'
assert messages[3].content[0].text == ('Hello, agent!')
assert messages[3].content[0].cache_prompt
assert messages[4].role == 'assistant'
assert messages[4].content[0].text == 'Hello, user!'
assert not messages[4].content[0].cache_prompt
assert messages[5].role == 'user'
assert messages[5].content[0].text.startswith('Laaaaaaaast!')
assert messages[5].content[0].cache_prompt
assert messages[2].role == 'user'
assert messages[2].content[0].text == ('Hello, agent!')
assert messages[2].content[0].cache_prompt
assert messages[3].role == 'assistant'
assert messages[3].content[0].text == 'Hello, user!'
assert not messages[3].content[0].cache_prompt
assert messages[4].role == 'user'
assert messages[4].content[0].text.startswith('Laaaaaaaast!')
assert messages[4].content[0].cache_prompt
def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
@@ -116,15 +115,48 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
# Check that only the last two user messages have cache_prompt=True
cached_user_messages = [
msg
for msg in messages
if msg.role in ('user', 'system') and msg.content[0].cache_prompt
msg for msg in messages if msg.role == 'user' and msg.content[0].cache_prompt
]
assert (
len(cached_user_messages) == 4
) # Including the initial system+user + 2 last user message
len(cached_user_messages) == 3
) # Including the initial user message + 2 last user messages
# Verify that these are indeed the last two user messages (from start)
assert cached_user_messages[0].content[0].text.startswith('You are OpenHands agent')
assert cached_user_messages[0].content[0].text.startswith('User message 0')
assert cached_user_messages[1].content[0].text.startswith('User message 1')
assert cached_user_messages[2].content[0].text.startswith('User message 1')
assert cached_user_messages[3].content[0].text.startswith('User message 1')
def test_prompt_caching_headers(codeact_agent: CodeActAgent):
    """The completion call made by step() must carry the Anthropic prompt-caching header."""
    # Setup: a two-turn history, one user message followed by one agent message.
    user_turn = MessageAction('Hello, agent!')
    user_turn._source = 'user'
    agent_turn = MessageAction('Hello, user!')
    agent_turn._source = 'agent'

    mock_state = Mock()
    mock_state.history = [user_turn, agent_turn]
    mock_state.max_iterations = 5
    mock_state.iteration = 0
    mock_state.extra_data = {}

    codeact_agent.reset()

    def fake_completion(**kwargs):
        """Stand-in for the completion call that validates the request headers."""
        assert 'extra_headers' in kwargs
        headers = kwargs['extra_headers']
        assert 'anthropic-beta' in headers
        assert headers['anthropic-beta'] == 'prompt-caching-2024-07-31'
        return ModelResponse(
            choices=[{'message': {'content': 'Hello! How can I assist you today?'}}]
        )

    codeact_agent.llm._completion_unwrapped = fake_completion

    result = codeact_agent.step(mock_state)

    # Assert: the canned reply comes back as a MessageAction.
    assert isinstance(result, MessageAction)
    assert result.content == 'Hello! How can I assist you today?'

View File

@@ -3,7 +3,7 @@ from unittest.mock import MagicMock
import pytest
from openhands.controller.agent_controller import AgentController
from openhands.core.config import AgentConfig, LLMConfig
from openhands.core.config import AgentConfig
from openhands.events import EventStream
from openhands.llm.llm import LLM
from openhands.storage import InMemoryFileStore
@@ -11,7 +11,9 @@ from openhands.storage import InMemoryFileStore
@pytest.fixture
def agent_controller():
llm = LLM(config=LLMConfig())
llm = MagicMock(spec=LLM)
llm.config = MagicMock()
llm.metrics = MagicMock()
agent = MagicMock()
agent.name = 'test_agent'
agent.llm = llm