Add RecallActions and observations for retrieval of prompt extensions (#6909)

Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
Engel Nyst
2025-03-15 21:48:37 +01:00
committed by GitHub
parent e34a771e66
commit cc45f5d9c3
38 changed files with 2317 additions and 735 deletions

View File

@@ -29,7 +29,12 @@ from openhands.core.exceptions import (
from openhands.core.logger import LOG_ALL_EVENTS
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream, EventStreamSubscriber
from openhands.events import (
EventSource,
EventStream,
EventStreamSubscriber,
RecallType,
)
from openhands.events.action import (
Action,
ActionConfirmationStatus,
@@ -42,6 +47,7 @@ from openhands.events.action import (
MessageAction,
NullAction,
)
from openhands.events.action.agent import RecallAction
from openhands.events.event import Event
from openhands.events.observation import (
AgentCondensationObservation,
@@ -89,7 +95,7 @@ class AgentController:
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
sid: str = 'default',
sid: str | None = None,
confirmation_mode: bool = False,
initial_state: State | None = None,
is_delegate: bool = False,
@@ -116,7 +122,7 @@ class AgentController:
status_callback: Optional callback function to handle status updates.
replay_events: A list of logs to replay.
"""
self.id = sid
self.id = sid or event_stream.sid
self.agent = agent
self.headless_mode = headless_mode
self.is_delegate = is_delegate
@@ -287,8 +293,14 @@ class AgentController:
return True
return False
if isinstance(event, Observation):
if isinstance(event, NullObservation) or isinstance(
event, AgentStateChangedObservation
if (
isinstance(event, NullObservation)
and event.cause is not None
and event.cause > 0
):
return True
if isinstance(event, AgentStateChangedObservation) or isinstance(
event, NullObservation
):
return False
return True
@@ -388,6 +400,7 @@ class AgentController:
if observation.llm_metrics is not None:
self.agent.llm.metrics.merge(observation.llm_metrics)
# this happens for runnable actions and microagent actions
if self._pending_action and self._pending_action.id == observation.cause:
if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
return
@@ -431,6 +444,25 @@ class AgentController:
'debug',
f'Extended max iterations to {self.state.max_iterations} after user message',
)
# try to retrieve microagents relevant to the user message
# set pending_action while we search for information
# whether this is the first user message for this agent determines the recall type (workspace context vs. knowledge)
first_user_message = self._first_user_message()
is_first_user_message = (
action.id == first_user_message.id if first_user_message else False
)
recall_type = (
RecallType.WORKSPACE_CONTEXT
if is_first_user_message
else RecallType.KNOWLEDGE
)
recall_action = RecallAction(query=action.content, recall_type=recall_type)
self._pending_action = recall_action
# this is source=USER because the user message is the trigger for the microagent retrieval
self.event_stream.add_event(recall_action, EventSource.USER)
if self.get_agent_state() != AgentState.RUNNING:
await self.set_agent_state_to(AgentState.RUNNING)
elif action.source == EventSource.AGENT and action.wait_for_response:
@@ -438,6 +470,7 @@ class AgentController:
def _reset(self) -> None:
"""Resets the agent controller"""
# Runnable actions need an Observation
# make sure there is an Observation with the tool call metadata to be recognized by the agent
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
@@ -459,6 +492,8 @@ class AgentController:
obs._cause = self._pending_action.id # type: ignore[attr-defined]
self.event_stream.add_event(obs, EventSource.AGENT)
# NOTE: RecallActions don't need an ErrorObservation upon reset, as long as they have no tool calls
# reset the pending action, this will be called when the agent is STOPPED or ERROR
self._pending_action = None
self.agent.reset()
@@ -1146,3 +1181,26 @@ class AgentController:
result = event.agent_state == AgentState.RUNNING
return result
return False
def _first_user_message(self) -> MessageAction | None:
    """Return the first user message for this agent.

    Events are read from the event stream starting at ``self.state.start_id``.
    For regular agents that is the beginning of the stream (start_id=0);
    for delegate agents it is the delegate's own start_id.

    Returns:
        MessageAction | None: The first ``MessageAction`` whose source is
        ``EventSource.USER``, or ``None`` if no such event is found.
    """
    # Iterate lazily: `next` stops at the first match, so there is no need
    # to copy every event into a list before looking at the first one.
    return next(
        (
            event
            for event in self.event_stream.get_events(
                start_id=self.state.start_id
            )
            if isinstance(event, MessageAction)
            and event.source == EventSource.USER
        ),
        None,
    )