mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Refactor system message handling to use event stream (#7824)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
@@ -11,6 +11,7 @@ from openhands.events.action import (
|
||||
Action,
|
||||
AgentFinishAction,
|
||||
)
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.memory.condenser import Condenser
|
||||
@@ -166,8 +167,8 @@ class CodeActAgent(Agent):
|
||||
message flow and function-calling scenarios.
|
||||
|
||||
The method performs the following steps:
|
||||
1. Initializes with system prompt and optional initial user message
|
||||
2. Processes events (Actions and Observations) into messages
|
||||
1. Checks for SystemMessageAction in events, adds one if missing (legacy support)
|
||||
2. Processes events (Actions and Observations) into messages, including SystemMessageAction
|
||||
3. Handles tool calls and their responses in function-calling mode
|
||||
4. Manages message role alternation (user/assistant/tool)
|
||||
5. Applies caching for specific LLM providers (e.g., Anthropic)
|
||||
@@ -178,8 +179,7 @@ class CodeActAgent(Agent):
|
||||
|
||||
Returns:
|
||||
list[Message]: A list of formatted messages ready for LLM consumption, including:
|
||||
- System message with prompt
|
||||
- Initial user message (if configured)
|
||||
- System message with prompt (from SystemMessageAction)
|
||||
- Action messages (from both user and assistant)
|
||||
- Observation messages (including tool responses)
|
||||
- Environment reminders (in non-function-calling mode)
|
||||
@@ -193,15 +193,32 @@ class CodeActAgent(Agent):
|
||||
if not self.prompt_manager:
|
||||
raise Exception('Prompt Manager not instantiated.')
|
||||
|
||||
# Use ConversationMemory to process initial messages
|
||||
messages = self.conversation_memory.process_initial_messages(
|
||||
with_caching=self.llm.is_caching_prompt_active()
|
||||
# Check if there's a SystemMessageAction in the events
|
||||
has_system_message = any(
|
||||
isinstance(event, SystemMessageAction) for event in events
|
||||
)
|
||||
|
||||
# Use ConversationMemory to process events
|
||||
# Legacy behavior: If no SystemMessageAction is found, add one
|
||||
if not has_system_message:
|
||||
logger.warning(
|
||||
f'[{self.name}] No SystemMessageAction found in events. '
|
||||
'Adding one for backward compatibility. '
|
||||
'This is deprecated behavior and will be removed in a future version.'
|
||||
)
|
||||
system_message = self.get_system_message()
|
||||
if system_message:
|
||||
# Create a copy and insert at the beginning of the list
|
||||
processed_events = list(events)
|
||||
processed_events.insert(0, system_message)
|
||||
logger.debug(
|
||||
f'[{self.name}] Added SystemMessageAction for backward compatibility'
|
||||
)
|
||||
else:
|
||||
processed_events = events
|
||||
|
||||
# Use ConversationMemory to process events (including SystemMessageAction)
|
||||
messages = self.conversation_memory.process_events(
|
||||
condensed_history=events,
|
||||
initial_messages=messages,
|
||||
condensed_history=processed_events,
|
||||
max_message_chars=self.llm.config.max_message_chars,
|
||||
vision_is_active=self.llm.vision_is_active(),
|
||||
)
|
||||
|
||||
@@ -5,10 +5,13 @@ if TYPE_CHECKING:
|
||||
from openhands.controller.state.state import State
|
||||
from openhands.core.config import AgentConfig
|
||||
from openhands.events.action import Action
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.core.exceptions import (
|
||||
AgentAlreadyRegisteredError,
|
||||
AgentNotRegisteredError,
|
||||
)
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.llm.llm import LLM
|
||||
from openhands.runtime.plugins import PluginRequirement
|
||||
|
||||
@@ -38,6 +41,42 @@ class Agent(ABC):
|
||||
self._complete = False
|
||||
self.prompt_manager: 'PromptManager' | None = None
|
||||
self.mcp_tools: list[dict] = []
|
||||
self.tools: list = []
|
||||
|
||||
def get_system_message(self) -> 'SystemMessageAction | None':
|
||||
"""
|
||||
Returns a SystemMessageAction containing the system message and tools.
|
||||
This will be added to the event stream as the first message.
|
||||
|
||||
Returns:
|
||||
SystemMessageAction: The system message action with content and tools
|
||||
None: If there was an error generating the system message
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
|
||||
try:
|
||||
if not self.prompt_manager:
|
||||
logger.warning(
|
||||
f'[{self.name}] Prompt manager not initialized before getting system message'
|
||||
)
|
||||
return None
|
||||
|
||||
system_message = self.prompt_manager.get_system_message()
|
||||
|
||||
# Get tools if available
|
||||
tools = getattr(self, 'tools', None)
|
||||
|
||||
system_message_action = SystemMessageAction(
|
||||
content=system_message, tools=tools
|
||||
)
|
||||
# Set the source attribute
|
||||
system_message_action._source = EventSource.AGENT # type: ignore
|
||||
|
||||
return system_message_action
|
||||
except Exception as e:
|
||||
logger.warning(f'[{self.name}] Failed to generate system message: {e}')
|
||||
return None
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
|
||||
@@ -54,6 +54,7 @@ from openhands.events.action import (
|
||||
IPythonRunCellAction,
|
||||
MessageAction,
|
||||
NullAction,
|
||||
SystemMessageAction,
|
||||
)
|
||||
from openhands.events.action.agent import CondensationAction, RecallAction
|
||||
from openhands.events.event import Event
|
||||
@@ -163,6 +164,31 @@ class AgentController:
|
||||
# replay-related
|
||||
self._replay_manager = ReplayManager(replay_events)
|
||||
|
||||
# Add the system message to the event stream
|
||||
self._add_system_message()
|
||||
|
||||
def _add_system_message(self):
|
||||
for event in self.event_stream.get_events(start_id=self.state.start_id):
|
||||
if isinstance(event, MessageAction) and event.source == EventSource.USER:
|
||||
# FIXME: Remove this after 6/1/2025
|
||||
# Do not try to add a system message if we first run into
|
||||
# a user message -- this means the eventstream exits before
|
||||
# SystemMessageAction is introduced.
|
||||
# We expect *agent* to handle this case gracefully.
|
||||
return
|
||||
|
||||
if isinstance(event, SystemMessageAction):
|
||||
# Do not try to add the system message if it already exists
|
||||
return
|
||||
|
||||
# Add the system message to the event stream
|
||||
# This should be done for all agents, including delegates
|
||||
system_message = self.agent.get_system_message()
|
||||
logger.debug(f'System message got from agent: {system_message}')
|
||||
if system_message:
|
||||
self.event_stream.add_event(system_message, EventSource.AGENT)
|
||||
logger.debug(f'System message added to event stream: {system_message}')
|
||||
|
||||
async def close(self, set_stop_state=True) -> None:
|
||||
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.
|
||||
|
||||
|
||||
@@ -6,6 +6,10 @@ class ActionType(str, Enum):
|
||||
"""Represents a message.
|
||||
"""
|
||||
|
||||
SYSTEM = 'system'
|
||||
"""Represents a system message.
|
||||
"""
|
||||
|
||||
START = 'start'
|
||||
"""Starts a new development task OR send chat from the user. Only sent by the client.
|
||||
"""
|
||||
|
||||
@@ -16,7 +16,7 @@ from openhands.events.action.files import (
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.action.message import MessageAction, SystemMessageAction
|
||||
|
||||
__all__ = [
|
||||
'Action',
|
||||
@@ -33,6 +33,7 @@ __all__ = [
|
||||
'ChangeAgentStateAction',
|
||||
'IPythonRunCellAction',
|
||||
'MessageAction',
|
||||
'SystemMessageAction',
|
||||
'ActionConfirmationStatus',
|
||||
'AgentThinkAction',
|
||||
'RecallAction',
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import openhands
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action, ActionSecurityRisk
|
||||
|
||||
@@ -32,3 +34,27 @@ class MessageAction(Action):
|
||||
for url in self.image_urls:
|
||||
ret += f'\nIMAGE_URL: {url}'
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemMessageAction(Action):
|
||||
"""
|
||||
Action that represents a system message for an agent, including the system prompt
|
||||
and available tools. This should be the first message in the event stream.
|
||||
"""
|
||||
|
||||
content: str
|
||||
tools: list[Any] | None = None
|
||||
openhands_version: str | None = openhands.__version__
|
||||
action: ActionType = ActionType.SYSTEM
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = f'**SystemMessageAction** (source={self.source})\n'
|
||||
ret += f'CONTENT: {self.content}'
|
||||
if self.tools:
|
||||
ret += f'\nTOOLS: {len(self.tools)} tools available'
|
||||
return ret
|
||||
|
||||
@@ -23,7 +23,7 @@ from openhands.events.action.files import (
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.action.message import MessageAction, SystemMessageAction
|
||||
|
||||
actions = (
|
||||
NullAction,
|
||||
@@ -41,6 +41,7 @@ actions = (
|
||||
RecallAction,
|
||||
ChangeAgentStateAction,
|
||||
MessageAction,
|
||||
SystemMessageAction,
|
||||
CondensationAction,
|
||||
McpAction,
|
||||
)
|
||||
|
||||
@@ -20,6 +20,7 @@ from openhands.events.action import (
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.action.mcp import McpAction
|
||||
from openhands.events.action.message import SystemMessageAction
|
||||
from openhands.events.event import Event, RecallType
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
@@ -53,7 +54,6 @@ class ConversationMemory:
|
||||
def process_events(
|
||||
self,
|
||||
condensed_history: list[Event],
|
||||
initial_messages: list[Message],
|
||||
max_message_chars: int | None = None,
|
||||
vision_is_active: bool = False,
|
||||
) -> list[Message]:
|
||||
@@ -63,7 +63,6 @@ class ConversationMemory:
|
||||
|
||||
Args:
|
||||
condensed_history: The condensed history of events to convert
|
||||
initial_messages: The initial messages to include in the conversation
|
||||
max_message_chars: The maximum number of characters in the content of an event included
|
||||
in the prompt to the LLM. Larger observations are truncated.
|
||||
vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included.
|
||||
@@ -74,8 +73,8 @@ class ConversationMemory:
|
||||
# log visual browsing status
|
||||
logger.debug(f'Visual browsing: {self.agent_config.enable_som_visual_browsing}')
|
||||
|
||||
# Process special events first (system prompts, etc.)
|
||||
messages = initial_messages
|
||||
# Initialize empty messages list
|
||||
messages = []
|
||||
|
||||
# Process regular events
|
||||
pending_tool_call_action_messages: dict[str, Message] = {}
|
||||
@@ -132,20 +131,6 @@ class ConversationMemory:
|
||||
messages = list(ConversationMemory._filter_unmatched_tool_calls(messages))
|
||||
return messages
|
||||
|
||||
def process_initial_messages(self, with_caching: bool = False) -> list[Message]:
|
||||
"""Create the initial messages for the conversation."""
|
||||
return [
|
||||
Message(
|
||||
role='system',
|
||||
content=[
|
||||
TextContent(
|
||||
text=self.prompt_manager.get_system_message(),
|
||||
cache_prompt=with_caching,
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
def _process_action(
|
||||
self,
|
||||
action: Action,
|
||||
@@ -275,6 +260,16 @@ class ConversationMemory:
|
||||
content=content,
|
||||
)
|
||||
]
|
||||
elif isinstance(action, SystemMessageAction):
|
||||
# Convert SystemMessageAction to a system message
|
||||
return [
|
||||
Message(
|
||||
role='system',
|
||||
content=[TextContent(text=action.content)],
|
||||
# Include tools if function calling is enabled
|
||||
tool_calls=None,
|
||||
)
|
||||
]
|
||||
return []
|
||||
|
||||
def _process_observation(
|
||||
@@ -546,6 +541,8 @@ class ConversationMemory:
|
||||
|
||||
For new Anthropic API, we only need to mark the last user or tool message as cacheable.
|
||||
"""
|
||||
if len(messages) > 0 and messages[0].role == 'system':
|
||||
messages[0].content[-1].cache_prompt = True
|
||||
# NOTE: this is only needed for anthropic
|
||||
for message in reversed(messages):
|
||||
if message.role in ('user', 'tool'):
|
||||
|
||||
@@ -12,9 +12,16 @@ from openhands.resolver.utils import extract_issue_references
|
||||
|
||||
|
||||
class GithubIssueHandler(IssueHandlerInterface):
|
||||
def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"):
|
||||
def __init__(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
token: str,
|
||||
username: str | None = None,
|
||||
base_domain: str = 'github.com',
|
||||
):
|
||||
"""Initialize a GitHub issue handler.
|
||||
|
||||
|
||||
Args:
|
||||
owner: The owner of the repository
|
||||
repo: The name of the repository
|
||||
@@ -42,7 +49,7 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
}
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
if self.base_domain == "github.com":
|
||||
if self.base_domain == 'github.com':
|
||||
return f'https://api.github.com/repos/{self.owner}/{self.repo}'
|
||||
else:
|
||||
return f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}'
|
||||
@@ -65,7 +72,7 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
return f'https://{username_and_token}@{self.base_domain}/{self.owner}/{self.repo}.git'
|
||||
|
||||
def get_graphql_url(self) -> str:
|
||||
if self.base_domain == "github.com":
|
||||
if self.base_domain == 'github.com':
|
||||
return 'https://api.github.com/graphql'
|
||||
else:
|
||||
return f'https://{self.base_domain}/api/v3/graphql'
|
||||
@@ -302,9 +309,16 @@ class GithubIssueHandler(IssueHandlerInterface):
|
||||
|
||||
|
||||
class GithubPRHandler(GithubIssueHandler):
|
||||
def __init__(self, owner: str, repo: str, token: str, username: str | None = None, base_domain: str = "github.com"):
|
||||
def __init__(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
token: str,
|
||||
username: str | None = None,
|
||||
base_domain: str = 'github.com',
|
||||
):
|
||||
"""Initialize a GitHub PR handler.
|
||||
|
||||
|
||||
Args:
|
||||
owner: The owner of the repository
|
||||
repo: The name of the repository
|
||||
@@ -313,8 +327,10 @@ class GithubPRHandler(GithubIssueHandler):
|
||||
base_domain: The domain for GitHub Enterprise (default: "github.com")
|
||||
"""
|
||||
super().__init__(owner, repo, token, username, base_domain)
|
||||
if self.base_domain == "github.com":
|
||||
self.download_url = f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls'
|
||||
if self.base_domain == 'github.com':
|
||||
self.download_url = (
|
||||
f'https://api.github.com/repos/{self.owner}/{self.repo}/pulls'
|
||||
)
|
||||
else:
|
||||
self.download_url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/pulls'
|
||||
|
||||
@@ -470,7 +486,7 @@ class GithubPRHandler(GithubIssueHandler):
|
||||
self, pr_number: int, comment_id: int | None = None
|
||||
) -> list[str] | None:
|
||||
"""Download comments for a specific pull request from Github."""
|
||||
if self.base_domain == "github.com":
|
||||
if self.base_domain == 'github.com':
|
||||
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
|
||||
else:
|
||||
url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{pr_number}/comments'
|
||||
@@ -542,7 +558,7 @@ class GithubPRHandler(GithubIssueHandler):
|
||||
|
||||
for issue_number in unique_issue_references:
|
||||
try:
|
||||
if self.base_domain == "github.com":
|
||||
if self.base_domain == 'github.com':
|
||||
url = f'https://api.github.com/repos/{self.owner}/{self.repo}/issues/{issue_number}'
|
||||
else:
|
||||
url = f'https://{self.base_domain}/api/v3/repos/{self.owner}/{self.repo}/issues/{issue_number}'
|
||||
|
||||
@@ -134,10 +134,7 @@ async def reset_settings(request: Request) -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
|
||||
async def check_provider_tokens(request: Request,
|
||||
settings: POSTSettingsModel) -> str:
|
||||
|
||||
async def check_provider_tokens(request: Request, settings: POSTSettingsModel) -> str:
|
||||
if settings.provider_tokens:
|
||||
# Remove extraneous token types
|
||||
provider_types = [provider.value for provider in ProviderType]
|
||||
@@ -152,17 +149,13 @@ async def check_provider_tokens(request: Request,
|
||||
SecretStr(token_value)
|
||||
)
|
||||
if not confirmed_token_type or confirmed_token_type.value != token_type:
|
||||
return f"Invalid token. Please make sure it is a valid {token_type} token."
|
||||
|
||||
|
||||
return ""
|
||||
return f'Invalid token. Please make sure it is a valid {token_type} token.'
|
||||
|
||||
return ''
|
||||
|
||||
|
||||
async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
|
||||
existing_settings = await settings_store.load()
|
||||
if existing_settings:
|
||||
if settings.provider_tokens:
|
||||
@@ -188,19 +181,17 @@ async def store_provider_tokens(request: Request, settings: POSTSettingsModel):
|
||||
else: # nothing passed in means keep current settings
|
||||
provider_tokens = existing_settings.secrets_store.provider_tokens
|
||||
settings.provider_tokens = {
|
||||
provider.value: data.token.get_secret_value()
|
||||
if data.token
|
||||
else None
|
||||
provider.value: data.token.get_secret_value() if data.token else None
|
||||
for provider, data in provider_tokens.items()
|
||||
}
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
async def store_llm_settings(request: Request, settings: POSTSettingsModel) -> POSTSettingsModel:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
config, get_user_id(request)
|
||||
)
|
||||
async def store_llm_settings(
|
||||
request: Request, settings: POSTSettingsModel
|
||||
) -> POSTSettingsModel:
|
||||
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
|
||||
existing_settings = await settings_store.load()
|
||||
|
||||
# Convert to Settings model and merge with existing settings
|
||||
@@ -215,6 +206,7 @@ async def store_llm_settings(request: Request, settings: POSTSettingsModel) -> P
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
@app.post('/settings', response_model=dict[str, str])
|
||||
async def store_settings(
|
||||
request: Request,
|
||||
@@ -225,11 +217,8 @@ async def store_settings(
|
||||
if provider_err_msg:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': provider_err_msg
|
||||
},
|
||||
content={'error': provider_err_msg},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
settings_store = await SettingsStoreImpl.get_instance(
|
||||
@@ -248,7 +237,6 @@ async def store_settings(
|
||||
)
|
||||
|
||||
settings = await store_provider_tokens(request, settings)
|
||||
|
||||
|
||||
# Update sandbox config with new settings
|
||||
if settings.remote_runtime_resource_factor is not None:
|
||||
|
||||
@@ -139,7 +139,7 @@ class Session:
|
||||
condensers=[
|
||||
BrowserOutputCondenserConfig(),
|
||||
LLMSummarizingCondenserConfig(
|
||||
llm_config=llm.config, keep_first=3, max_size=80
|
||||
llm_config=llm.config, keep_first=4, max_size=80
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user