Refactor system message handling to use event stream (#7824)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
Xingyao Wang
2025-04-17 10:30:19 -04:00
committed by GitHub
parent caf34d83bd
commit 93e9db3206
19 changed files with 446 additions and 321 deletions

View File

@@ -11,6 +11,7 @@ from openhands.events.action import (
Action,
AgentFinishAction,
)
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import Event
from openhands.llm.llm import LLM
from openhands.memory.condenser import Condenser
@@ -166,8 +167,8 @@ class CodeActAgent(Agent):
message flow and function-calling scenarios.
The method performs the following steps:
1. Initializes with system prompt and optional initial user message
2. Processes events (Actions and Observations) into messages
1. Checks for SystemMessageAction in events, adds one if missing (legacy support)
2. Processes events (Actions and Observations) into messages, including SystemMessageAction
3. Handles tool calls and their responses in function-calling mode
4. Manages message role alternation (user/assistant/tool)
5. Applies caching for specific LLM providers (e.g., Anthropic)
@@ -178,8 +179,7 @@ class CodeActAgent(Agent):
Returns:
list[Message]: A list of formatted messages ready for LLM consumption, including:
- System message with prompt
- Initial user message (if configured)
- System message with prompt (from SystemMessageAction)
- Action messages (from both user and assistant)
- Observation messages (including tool responses)
- Environment reminders (in non-function-calling mode)
@@ -193,15 +193,32 @@ class CodeActAgent(Agent):
if not self.prompt_manager:
raise Exception('Prompt Manager not instantiated.')
# Use ConversationMemory to process initial messages
messages = self.conversation_memory.process_initial_messages(
with_caching=self.llm.is_caching_prompt_active()
# Check if there's a SystemMessageAction in the events
has_system_message = any(
isinstance(event, SystemMessageAction) for event in events
)
# Use ConversationMemory to process events
# Legacy behavior: If no SystemMessageAction is found, add one
if not has_system_message:
logger.warning(
f'[{self.name}] No SystemMessageAction found in events. '
'Adding one for backward compatibility. '
'This is deprecated behavior and will be removed in a future version.'
)
system_message = self.get_system_message()
if system_message:
# Create a copy and insert at the beginning of the list
processed_events = list(events)
processed_events.insert(0, system_message)
logger.debug(
f'[{self.name}] Added SystemMessageAction for backward compatibility'
)
else:
processed_events = events
# Use ConversationMemory to process events (including SystemMessageAction)
messages = self.conversation_memory.process_events(
condensed_history=events,
initial_messages=messages,
condensed_history=processed_events,
max_message_chars=self.llm.config.max_message_chars,
vision_is_active=self.llm.vision_is_active(),
)

View File

@@ -5,10 +5,13 @@ if TYPE_CHECKING:
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.events.action import Action
from openhands.events.action.message import SystemMessageAction
from openhands.core.exceptions import (
AgentAlreadyRegisteredError,
AgentNotRegisteredError,
)
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import EventSource
from openhands.llm.llm import LLM
from openhands.runtime.plugins import PluginRequirement
@@ -38,6 +41,42 @@ class Agent(ABC):
self._complete = False
self.prompt_manager: 'PromptManager' | None = None
self.mcp_tools: list[dict] = []
self.tools: list = []
def get_system_message(self) -> 'SystemMessageAction | None':
    """Build the SystemMessageAction carrying this agent's system prompt and tools.

    Returns:
        SystemMessageAction | None: An action whose content is obtained from the
        prompt manager and whose tools are read from ``self.tools`` (``None`` if
        that attribute is absent), with its source set to ``EventSource.AGENT``.
        ``None`` is returned when the prompt manager is not initialized or when
        building the action raises an exception.
    """
    # Deferred to call time: a module-level import would be circular.
    from openhands.events.action.message import SystemMessageAction

    try:
        if not self.prompt_manager:
            logger.warning(
                f'[{self.name}] Prompt manager not initialized before getting system message'
            )
            return None

        prompt_text = self.prompt_manager.get_system_message()
        action = SystemMessageAction(
            content=prompt_text,
            # Tolerate instances that never had a ``tools`` attribute assigned.
            tools=getattr(self, 'tools', None),
        )
        # Mark the agent as the originator of this event.
        action._source = EventSource.AGENT  # type: ignore
    except Exception as e:
        # Best effort: report the failure and let the caller handle a missing message.
        logger.warning(f'[{self.name}] Failed to generate system message: {e}')
        return None
    else:
        return action
@property
def complete(self) -> bool:

View File

@@ -54,6 +54,7 @@ from openhands.events.action import (
IPythonRunCellAction,
MessageAction,
NullAction,
SystemMessageAction,
)
from openhands.events.action.agent import CondensationAction, RecallAction
from openhands.events.event import Event
@@ -163,6 +164,31 @@ class AgentController:
# replay-related
self._replay_manager = ReplayManager(replay_events)
# Add the system message to the event stream
self._add_system_message()
def _add_system_message(self):
    """Add the agent's system message to the event stream unless one is already implied.

    Events are scanned from ``self.state.start_id``. Nothing is added when the scan
    reaches either a user-sourced MessageAction or an existing SystemMessageAction
    first. Otherwise the agent is asked for its system message and, if it returns
    one, that message is added to the stream with ``EventSource.AGENT``.
    """
    for event in self.event_stream.get_events(start_id=self.state.start_id):
        # FIXME: Remove this after 6/1/2025
        # A user message seen before any SystemMessageAction means the event
        # stream exists from before SystemMessageAction was introduced; do not
        # add one here. We expect the *agent* to handle this case gracefully.
        from_user = (
            isinstance(event, MessageAction) and event.source == EventSource.USER
        )
        # An existing SystemMessageAction likewise means there is nothing to add.
        if from_user or isinstance(event, SystemMessageAction):
            return

    # Done for every agent, delegates included.
    system_message = self.agent.get_system_message()
    logger.debug(f'System message got from agent: {system_message}')
    if not system_message:
        return

    self.event_stream.add_event(system_message, EventSource.AGENT)
    logger.debug(f'System message added to event stream: {system_message}')
async def close(self, set_stop_state=True) -> None:
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.

View File

@@ -6,6 +6,10 @@ class ActionType(str, Enum):
"""Represents a message.
"""
SYSTEM = 'system'
"""Represents a system message.
"""
START = 'start'
"""Starts a new development task OR send chat from the user. Only sent by the client.
"""

View File

@@ -16,7 +16,7 @@ from openhands.events.action.files import (
FileWriteAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.message import MessageAction
from openhands.events.action.message import MessageAction, SystemMessageAction
__all__ = [
'Action',
@@ -33,6 +33,7 @@ __all__ = [
'ChangeAgentStateAction',
'IPythonRunCellAction',
'MessageAction',
'SystemMessageAction',
'ActionConfirmationStatus',
'AgentThinkAction',
'RecallAction',

View File

@@ -1,5 +1,7 @@
from dataclasses import dataclass
from typing import Any
import openhands
from openhands.core.schema import ActionType
from openhands.events.action.action import Action, ActionSecurityRisk
@@ -32,3 +34,27 @@ class MessageAction(Action):
for url in self.image_urls:
ret += f'\nIMAGE_URL: {url}'
return ret
@dataclass
class SystemMessageAction(Action):
    """Action holding an agent's system message: the system prompt plus its tools.

    Intended to be the first message in the event stream.
    """

    # Text of the system prompt.
    content: str
    # Tools available to the agent, if any.
    tools: list[Any] | None = None
    # Defaults to the version of the installed ``openhands`` package.
    openhands_version: str | None = openhands.__version__
    action: ActionType = ActionType.SYSTEM

    @property
    def message(self) -> str:
        """Return the system message text (same value as ``content``)."""
        return self.content

    def __str__(self) -> str:
        """Render the source, the content, and the tool count when tools are present."""
        pieces = [
            f'**SystemMessageAction** (source={self.source})\n',
            f'CONTENT: {self.content}',
        ]
        if self.tools:
            pieces.append(f'\nTOOLS: {len(self.tools)} tools available')
        return ''.join(pieces)

View File

@@ -23,7 +23,7 @@ from openhands.events.action.files import (
FileWriteAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.message import MessageAction
from openhands.events.action.message import MessageAction, SystemMessageAction
actions = (
NullAction,
@@ -41,6 +41,7 @@ actions = (
RecallAction,
ChangeAgentStateAction,
MessageAction,
SystemMessageAction,
CondensationAction,
McpAction,
)

View File

@@ -20,6 +20,7 @@ from openhands.events.action import (
MessageAction,
)
from openhands.events.action.mcp import McpAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import Event, RecallType
from openhands.events.observation import (
AgentCondensationObservation,
@@ -53,7 +54,6 @@ class ConversationMemory:
def process_events(
self,
condensed_history: list[Event],
initial_messages: list[Message],
max_message_chars: int | None = None,
vision_is_active: bool = False,
) -> list[Message]:
@@ -63,7 +63,6 @@ class ConversationMemory:
Args:
condensed_history: The condensed history of events to convert
initial_messages: The initial messages to include in the conversation
max_message_chars: The maximum number of characters in the content of an event included
in the prompt to the LLM. Larger observations are truncated.
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
@@ -74,8 +73,8 @@ class ConversationMemory:
# log visual browsing status
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
# Process special events first (system prompts, etc.)
messages = initial_messages
# Initialize empty messages list
messages = []
# Process regular events
pending_tool_call_action_messages: dict[str, Message] = {}
@@ -132,20 +131,6 @@ class ConversationMemory:
messages = list(ConversationMemory._filter_unmatched_tool_calls(messages))
return messages
def process_initial_messages(self, with_caching: bool = False) -> list[Message]:
    """Return the conversation's opening messages: a single system message.

    Args:
        with_caching: Passed through as ``cache_prompt`` on the system prompt's
            text content.

    Returns:
        list[Message]: A one-element list containing a ``system``-role message
        whose text comes from the prompt manager.
    """
    system_text = TextContent(
        text=self.prompt_manager.get_system_message(),
        cache_prompt=with_caching,
    )
    system_message = Message(role='system', content=[system_text])
    return [system_message]
def _process_action(
self,
action: Action,
@@ -275,6 +260,16 @@ class ConversationMemory:
content=content,
)
]
elif isinstance(action, SystemMessageAction):
# Convert SystemMessageAction to a system message
return [
Message(
role='system',
content=[TextContent(text=action.content)],
# Include tools if function calling is enabled
tool_calls=None,
)
]
return []
def _process_observation(
@@ -546,6 +541,8 @@ class ConversationMemory:
For new Anthropic API, we only need to mark the last user or tool message as cacheable.
"""
if len(messages) > 0 and messages[0].role == 'system':
messages[0].content[-1].cache_prompt = True
# NOTE: this is only needed for anthropic
for message in reversed(messages):
if message.role in ('user', 'tool'):

View File

@@ -12,9 +12,16 @@ from openhands.resolver.utils import extract_issue_references
class GithubIssueHandler(IssueHandlerInterface):
def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"):
def __init__(
self,
owner: str,
repo: str,
token: str,
username: str | None = None,
base_domain: str = 'github.com',
):
"""Initialize a GitHub issue handler.
Args:
owner: The owner of the repository
repo: The name of the repository
@@ -42,7 +49,7 @@ class GithubIssueHandler(IssueHandlerInterface):
}
def get_base_url(self) -> str:
if self.base_domain == "github.com":
if self.base_domain == 'github.com':
return f'https://api.github.com/repos/{self.owner}/{self.repo}'
else:
return f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}'
@@ -65,7 +72,7 @@ class GithubIssueHandler(IssueHandlerInterface):
return f'https://{username_and_token}@{self.base_domain}/{self.owner}/{self.repo}.git'
def get_graphql_url(self) -> str:
if self.base_domain == "github.com":
if self.base_domain == 'github.com':
return 'https://api.github.com/graphql'
else:
return f'https://{self.base_domain}/api/v3/graphql'
@@ -302,9 +309,16 @@ class GithubIssueHandler(IssueHandlerInterface):
class GithubPRHandler(GithubIssueHandler):
def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"):
def __init__(
self,
owner: str,
repo: str,
token: str,
username: str | None = None,
base_domain: str = 'github.com',
):
"""Initialize a GitHub PR handler.
Args:
owner: The owner of the repository
repo: The name of the repository
@@ -313,8 +327,10 @@ class GithubPRHandler(GithubIssueHandler):
base_domain: The domain for GitHub Enterprise (default: "github.com")
"""
super().__init__(owner, repo, token, username, base_domain)
if self.base_domain == "github.com":
self.download_url = f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls'
if self.base_domain == 'github.com':
self.download_url = (
f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls'
)
else:
self.download_url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/pulls'
@@ -470,7 +486,7 @@ class GithubPRHandler(GithubIssueHandler):
self, pr_number: int, comment_id: int | None = None
) -> list[str] | None:
"""Download comments for a specific pull request from Github."""
if self.base_domain == "github.com":
if self.base_domain == 'github.com':
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
else:
url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
@@ -542,7 +558,7 @@ class GithubPRHandler(GithubIssueHandler):
for issue_number in unique_issue_references:
try:
if self.base_domain == "github.com":
if self.base_domain == 'github.com':
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}'
else:
url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{issue_number}'

View File

@@ -134,10 +134,7 @@ async def reset_settings(request: Request) -> JSONResponse:
)
async def check_provider_tokens(request: Request,
settings: POSTSettingsModel) -> str:
async def check_provider_tokens(request: Request, settings: POSTSettingsModel) -> str:
if settings.provider_tokens:
# Remove extraneous token types
provider_types = [provider.value for provider in ProviderType]
@@ -152,17 +149,13 @@ async def check_provider_tokens(request: Request,
SecretStr(token_value)
)
if not confirmed_token_type or confirmed_token_type.value != token_type:
return f"Invalid token. Please make sure it is a valid {token_type} token."
return ""
return f'Invalid token. Please make sure it is a valid {token_type} token.'
return ''
async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
existing_settings = await settings_store.load()
if existing_settings:
if settings.provider_tokens:
@@ -188,19 +181,17 @@ async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
else: # nothing passed in means keep current settings
provider_tokens = existing_settings.secrets_store.provider_tokens
settings.provider_tokens = {
provider.value: data.token.get_secret_value()
if data.token
else None
provider.value: data.token.get_secret_value() if data.token else None
for provider, data in provider_tokens.items()
}
return settings
async def store_llm_settings(request: Request, settings: POSTSettingsModel) -> POSTSettingsModel:
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
async def store_llm_settings(
request: Request, settings: POSTSettingsModel
) -> POSTSettingsModel:
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
existing_settings = await settings_store.load()
# Convert to Settings model and merge with existing settings
@@ -215,6 +206,7 @@ async def store_llm_settings(request: Request, settings: POSTSettingsModel) -> P
return settings
@app.post('/settings', response_model=dict[str, str])
async def store_settings(
request: Request,
@@ -225,11 +217,8 @@ async def store_settings(
if provider_err_msg:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
'error': provider_err_msg
},
content={'error': provider_err_msg},
)
try:
settings_store = await SettingsStoreImpl.get_instance(
@@ -248,7 +237,6 @@ async def store_settings(
)
settings = await store_provider_tokens(request, settings)
# Update sandbox config with new settings
if settings.remote_runtime_resource_factor is not None:

View File

@@ -139,7 +139,7 @@ class Session:
condensers=[
BrowserOutputCondenserConfig(),
LLMSummarizingCondenserConfig(
llm_config=llm.config, keep_first=3, max_size=80
llm_config=llm.config, keep_first=4, max_size=80
),
]
)