mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-10 07:18:10 -05:00
Add RecallActions and observations for retrieval of prompt extensions (#6909)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Calvin Smith <email@cjsmith.io>
This commit is contained in:
@@ -29,7 +29,12 @@ from openhands.core.exceptions import (
|
||||
from openhands.core.logger import LOG_ALL_EVENTS
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import AgentState
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events import (
|
||||
EventSource,
|
||||
EventStream,
|
||||
EventStreamSubscriber,
|
||||
RecallType,
|
||||
)
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
ActionConfirmationStatus,
|
||||
@@ -42,6 +47,7 @@ from openhands.events.action import (
|
||||
MessageAction,
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.action.agent import RecallAction
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.observation import (
|
||||
AgentCondensationObservation,
|
||||
@@ -89,7 +95,7 @@ class AgentController:
|
||||
max_budget_per_task: float | None = None,
|
||||
agent_to_llm_config: dict[str, LLMConfig] | None = None,
|
||||
agent_configs: dict[str, AgentConfig] | None = None,
|
||||
sid: str = 'default',
|
||||
sid: str | None = None,
|
||||
confirmation_mode: bool = False,
|
||||
initial_state: State | None = None,
|
||||
is_delegate: bool = False,
|
||||
@@ -116,7 +122,7 @@ class AgentController:
|
||||
status_callback: Optional callback function to handle status updates.
|
||||
replay_events: A list of logs to replay.
|
||||
"""
|
||||
self.id = sid
|
||||
self.id = sid or event_stream.sid
|
||||
self.agent = agent
|
||||
self.headless_mode = headless_mode
|
||||
self.is_delegate = is_delegate
|
||||
@@ -287,8 +293,14 @@ class AgentController:
|
||||
return True
|
||||
return False
|
||||
if isinstance(event, Observation):
|
||||
if isinstance(event, NullObservation) or isinstance(
|
||||
event, AgentStateChangedObservation
|
||||
if (
|
||||
isinstance(event, NullObservation)
|
||||
and event.cause is not None
|
||||
and event.cause > 0
|
||||
):
|
||||
return True
|
||||
if isinstance(event, AgentStateChangedObservation) or isinstance(
|
||||
event, NullObservation
|
||||
):
|
||||
return False
|
||||
return True
|
||||
@@ -388,6 +400,7 @@ class AgentController:
|
||||
if observation.llm_metrics is not None:
|
||||
self.agent.llm.metrics.merge(observation.llm_metrics)
|
||||
|
||||
# this happens for runnable actions and microagent actions
|
||||
if self._pending_action and self._pending_action.id == observation.cause:
|
||||
if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION:
|
||||
return
|
||||
@@ -431,6 +444,25 @@ class AgentController:
|
||||
'debug',
|
||||
f'Extended max iterations to {self.state.max_iterations} after user message',
|
||||
)
|
||||
# try to retrieve microagents relevant to the user message
|
||||
# set pending_action while we search for information
|
||||
|
||||
# if this is the first user message for this agent, matters for the microagent info type
|
||||
first_user_message = self._first_user_message()
|
||||
is_first_user_message = (
|
||||
action.id == first_user_message.id if first_user_message else False
|
||||
)
|
||||
recall_type = (
|
||||
RecallType.WORKSPACE_CONTEXT
|
||||
if is_first_user_message
|
||||
else RecallType.KNOWLEDGE
|
||||
)
|
||||
|
||||
recall_action = RecallAction(query=action.content, recall_type=recall_type)
|
||||
self._pending_action = recall_action
|
||||
# this is source=USER because the user message is the trigger for the microagent retrieval
|
||||
self.event_stream.add_event(recall_action, EventSource.USER)
|
||||
|
||||
if self.get_agent_state() != AgentState.RUNNING:
|
||||
await self.set_agent_state_to(AgentState.RUNNING)
|
||||
elif action.source == EventSource.AGENT and action.wait_for_response:
|
||||
@@ -438,6 +470,7 @@ class AgentController:
|
||||
|
||||
def _reset(self) -> None:
|
||||
"""Resets the agent controller"""
|
||||
# Runnable actions need an Observation
|
||||
# make sure there is an Observation with the tool call metadata to be recognized by the agent
|
||||
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
|
||||
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
|
||||
@@ -459,6 +492,8 @@ class AgentController:
|
||||
obs._cause = self._pending_action.id # type: ignore[attr-defined]
|
||||
self.event_stream.add_event(obs, EventSource.AGENT)
|
||||
|
||||
# NOTE: RecallActions don't need an ErrorObservation upon reset, as long as they have no tool calls
|
||||
|
||||
# reset the pending action, this will be called when the agent is STOPPED or ERROR
|
||||
self._pending_action = None
|
||||
self.agent.reset()
|
||||
@@ -1146,3 +1181,26 @@ class AgentController:
|
||||
result = event.agent_state == AgentState.RUNNING
|
||||
return result
|
||||
return False
|
||||
|
||||
def _first_user_message(self) -> MessageAction | None:
|
||||
"""
|
||||
Get the first user message for this agent.
|
||||
|
||||
For regular agents, this is the first user message from the beginning (start_id=0).
|
||||
For delegate agents, this is the first user message after the delegate's start_id.
|
||||
|
||||
Returns:
|
||||
MessageAction | None: The first user message, or None if no user message found
|
||||
"""
|
||||
# Find the first user message from the appropriate starting point
|
||||
user_messages = list(self.event_stream.get_events(start_id=self.state.start_id))
|
||||
|
||||
# Get and return the first user message
|
||||
return next(
|
||||
(
|
||||
e
|
||||
for e in user_messages
|
||||
if isinstance(e, MessageAction) and e.source == EventSource.USER
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user