From cc45f5d9c365ba762e6ab3edfe5069dccecf2ed4 Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Sat, 15 Mar 2025 21:48:37 +0100 Subject: [PATCH] Add RecallActions and observations for retrieval of prompt extensions (#6909) Co-authored-by: openhands Co-authored-by: Calvin Smith --- .../agenthub/codeact_agent/codeact_agent.py | 26 +- .../codeact_agent/prompts/additional_info.j2 | 2 + .../codeact_agent/prompts/microagent_info.j2 | 4 +- openhands/controller/agent_controller.py | 68 +- openhands/controller/stuck.py | 2 +- openhands/core/cli.py | 16 +- openhands/core/loop.py | 3 + openhands/core/main.py | 20 +- openhands/core/schema/action.py | 3 + openhands/core/schema/observation.py | 3 + openhands/core/setup.py | 50 +- openhands/events/__init__.py | 3 +- openhands/events/action/__init__.py | 2 + openhands/events/action/agent.py | 20 + openhands/events/event.py | 10 + openhands/events/observation/__init__.py | 4 + openhands/events/observation/agent.py | 76 ++- openhands/events/serialization/action.py | 2 + openhands/events/serialization/event.py | 9 +- openhands/events/serialization/observation.py | 18 + openhands/events/stream.py | 1 + openhands/llm/llm.py | 3 +- openhands/memory/conversation_memory.py | 164 ++++- openhands/memory/memory.py | 270 ++++++++ openhands/server/session/agent_session.py | 49 +- .../conversation/file_conversation_store.py | 4 +- openhands/utils/prompt.py | 173 +---- tests/unit/test_action_serialization.py | 13 + tests/unit/test_agent_controller.py | 236 ++++++- tests/unit/test_agent_delegation.py | 35 +- tests/unit/test_agent_session.py | 31 +- tests/unit/test_cli_sid.py | 23 +- tests/unit/test_codeact_agent.py | 6 - tests/unit/test_conversation_memory.py | 612 ++++++++++++++++- tests/unit/test_is_stuck.py | 4 +- tests/unit/test_memory.py | 260 ++++++++ tests/unit/test_observation_serialization.py | 205 +++++- tests/unit/test_prompt_manager.py | 622 +++++------------- 38 files changed, 2317 insertions(+), 735 deletions(-) create mode 100644 
openhands/memory/memory.py create mode 100644 tests/unit/test_memory.py diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index 6266278dc2..fa041afe25 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -1,8 +1,6 @@ -import json import os from collections import deque -import openhands import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling from openhands.controller.agent import Agent from openhands.controller.state.state import State @@ -74,21 +72,14 @@ class CodeActAgent(Agent): codeact_enable_llm_editor=self.config.codeact_enable_llm_editor, ) logger.debug( - f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2, ensure_ascii=False).replace("\\n", "\n")}' + f'TOOLS loaded for CodeActAgent: {', '.join([tool.get('function').get('name') for tool in self.tools])}' ) self.prompt_manager = PromptManager( - microagent_dir=os.path.join( - os.path.dirname(os.path.dirname(openhands.__file__)), - 'microagents', - ) - if self.config.enable_prompt_extensions - else None, prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts'), - disabled_microagents=self.config.disabled_microagents, ) # Create a ConversationMemory instance - self.conversation_memory = ConversationMemory(self.prompt_manager) + self.conversation_memory = ConversationMemory(self.config, self.prompt_manager) self.condenser = Condenser.from_config(self.config.condenser) logger.debug(f'Using condenser: {type(self.condenser)}') @@ -168,7 +159,7 @@ class CodeActAgent(Agent): if not self.prompt_manager: raise Exception('Prompt Manager not instantiated.') - # Use conversation_memory to process events instead of calling events_to_messages directly + # Use ConversationMemory to process initial messages messages = self.conversation_memory.process_initial_messages( with_caching=self.llm.is_caching_prompt_active() ) @@ -180,12 +171,12 @@ class 
CodeActAgent(Agent): f'Processing {len(events)} events from a total of {len(state.history)} events' ) + # Use ConversationMemory to process events messages = self.conversation_memory.process_events( condensed_history=events, initial_messages=messages, max_message_chars=self.llm.config.max_message_chars, vision_is_active=self.llm.vision_is_active(), - enable_som_visual_browsing=self.config.enable_som_visual_browsing, ) messages = self._enhance_messages(messages) @@ -216,14 +207,7 @@ class CodeActAgent(Agent): # compose the first user message with examples self.prompt_manager.add_examples_to_initial_message(msg) - # and/or repo/runtime info - if self.config.enable_prompt_extensions: - self.prompt_manager.add_info_to_initial_message(msg) - - # enhance the user message with additional context based on keywords matched - if msg.role == 'user': - self.prompt_manager.enhance_message(msg) - + elif msg.role == 'user': # Add double newline between consecutive user messages if prev_role == 'user' and len(msg.content) > 0: # Find the first TextContent in the message to add newlines diff --git a/openhands/agenthub/codeact_agent/prompts/additional_info.j2 b/openhands/agenthub/codeact_agent/prompts/additional_info.j2 index df051256a8..409dcba13a 100644 --- a/openhands/agenthub/codeact_agent/prompts/additional_info.j2 +++ b/openhands/agenthub/codeact_agent/prompts/additional_info.j2 @@ -20,6 +20,8 @@ When starting a web server, use the corresponding ports. You should also set any options to allow iframes and CORS requests, and allow the server to be accessed from any host (e.g. 0.0.0.0). 
{% endif %} +{% if runtime_info.additional_agent_instructions %} {{ runtime_info.additional_agent_instructions }} +{% endif %} {% endif %} diff --git a/openhands/agenthub/codeact_agent/prompts/microagent_info.j2 b/openhands/agenthub/codeact_agent/prompts/microagent_info.j2 index 2059e41b07..264828fbe2 100644 --- a/openhands/agenthub/codeact_agent/prompts/microagent_info.j2 +++ b/openhands/agenthub/codeact_agent/prompts/microagent_info.j2 @@ -1,8 +1,8 @@ {% for agent_info in triggered_agents %} -The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}". +The following information has been included based on a keyword match for "{{ agent_info.trigger }}". It may or may not be relevant to the user's request. -{{ agent_info.agent.content }} +{{ agent_info.content }} {% endfor %} diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 61bebacfc9..b4571c5ff4 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -29,7 +29,12 @@ from openhands.core.exceptions import ( from openhands.core.logger import LOG_ALL_EVENTS from openhands.core.logger import openhands_logger as logger from openhands.core.schema import AgentState -from openhands.events import EventSource, EventStream, EventStreamSubscriber +from openhands.events import ( + EventSource, + EventStream, + EventStreamSubscriber, + RecallType, +) from openhands.events.action import ( Action, ActionConfirmationStatus, @@ -42,6 +47,7 @@ from openhands.events.action import ( MessageAction, NullAction, ) +from openhands.events.action.agent import RecallAction from openhands.events.event import Event from openhands.events.observation import ( AgentCondensationObservation, @@ -89,7 +95,7 @@ class AgentController: max_budget_per_task: float | None = None, agent_to_llm_config: dict[str, LLMConfig] | None = None, agent_configs: dict[str, AgentConfig] | None = None, - sid: str = 'default', 
+ sid: str | None = None, confirmation_mode: bool = False, initial_state: State | None = None, is_delegate: bool = False, @@ -116,7 +122,7 @@ class AgentController: status_callback: Optional callback function to handle status updates. replay_events: A list of logs to replay. """ - self.id = sid + self.id = sid or event_stream.sid self.agent = agent self.headless_mode = headless_mode self.is_delegate = is_delegate @@ -287,8 +293,14 @@ class AgentController: return True return False if isinstance(event, Observation): - if isinstance(event, NullObservation) or isinstance( - event, AgentStateChangedObservation + if ( + isinstance(event, NullObservation) + and event.cause is not None + and event.cause > 0 + ): + return True + if isinstance(event, AgentStateChangedObservation) or isinstance( + event, NullObservation ): return False return True @@ -388,6 +400,7 @@ class AgentController: if observation.llm_metrics is not None: self.agent.llm.metrics.merge(observation.llm_metrics) + # this happens for runnable actions and microagent actions if self._pending_action and self._pending_action.id == observation.cause: if self.state.agent_state == AgentState.AWAITING_USER_CONFIRMATION: return @@ -431,6 +444,25 @@ class AgentController: 'debug', f'Extended max iterations to {self.state.max_iterations} after user message', ) + # try to retrieve microagents relevant to the user message + # set pending_action while we search for information + + # if this is the first user message for this agent, matters for the microagent info type + first_user_message = self._first_user_message() + is_first_user_message = ( + action.id == first_user_message.id if first_user_message else False + ) + recall_type = ( + RecallType.WORKSPACE_CONTEXT + if is_first_user_message + else RecallType.KNOWLEDGE + ) + + recall_action = RecallAction(query=action.content, recall_type=recall_type) + self._pending_action = recall_action + # this is source=USER because the user message is the trigger for the 
microagent retrieval + self.event_stream.add_event(recall_action, EventSource.USER) + if self.get_agent_state() != AgentState.RUNNING: await self.set_agent_state_to(AgentState.RUNNING) elif action.source == EventSource.AGENT and action.wait_for_response: @@ -438,6 +470,7 @@ class AgentController: def _reset(self) -> None: """Resets the agent controller""" + # Runnable actions need an Observation # make sure there is an Observation with the tool call metadata to be recognized by the agent # otherwise the pending action is found in history, but it's incomplete without an obs with tool result if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'): @@ -459,6 +492,8 @@ class AgentController: obs._cause = self._pending_action.id # type: ignore[attr-defined] self.event_stream.add_event(obs, EventSource.AGENT) + # NOTE: RecallActions don't need an ErrorObservation upon reset, as long as they have no tool calls + # reset the pending action, this will be called when the agent is STOPPED or ERROR self._pending_action = None self.agent.reset() @@ -1146,3 +1181,26 @@ class AgentController: result = event.agent_state == AgentState.RUNNING return result return False + + def _first_user_message(self) -> MessageAction | None: + """ + Get the first user message for this agent. + + For regular agents, this is the first user message from the beginning (start_id=0). + For delegate agents, this is the first user message after the delegate's start_id. 
+ + Returns: + MessageAction | None: The first user message, or None if no user message found + """ + # Find the first user message from the appropriate starting point + user_messages = list(self.event_stream.get_events(start_id=self.state.start_id)) + + # Get and return the first user message + return next( + ( + e + for e in user_messages + if isinstance(e, MessageAction) and e.source == EventSource.USER + ), + None, + ) diff --git a/openhands/controller/stuck.py b/openhands/controller/stuck.py index 373a95abfb..0fc85d0f97 100644 --- a/openhands/controller/stuck.py +++ b/openhands/controller/stuck.py @@ -135,7 +135,7 @@ class StuckDetector: # it takes 3 actions and 3 observations to detect a loop # check if the last three actions are the same and result in errors - if len(last_actions) < 4 or len(last_observations) < 4: + if len(last_actions) < 3 or len(last_observations) < 3: return False # are the last three actions the "same"? diff --git a/openhands/core/cli.py b/openhands/core/cli.py index bb134803ae..aed45ec287 100644 --- a/openhands/core/cli.py +++ b/openhands/core/cli.py @@ -17,6 +17,7 @@ from openhands.core.schema import AgentState from openhands.core.setup import ( create_agent, create_controller, + create_memory, create_runtime, initialize_repository_for_runtime, ) @@ -170,13 +171,22 @@ async def main(loop: asyncio.AbstractEventLoop): await runtime.connect() # Initialize repository if needed + repo_directory = None if config.sandbox.selected_repo: - initialize_repository_for_runtime( + repo_directory = initialize_repository_for_runtime( runtime, - agent=agent, selected_repository=config.sandbox.selected_repo, ) + # when memory is created, it will load the microagents from the selected repository + memory = create_memory( + runtime=runtime, + event_stream=event_stream, + sid=sid, + selected_repository=config.sandbox.selected_repo, + repo_directory=repo_directory, + ) + if initial_user_action: # If there's an initial user action, enqueue it and do not 
prompt again event_stream.add_event(initial_user_action, EventSource.USER) @@ -185,7 +195,7 @@ async def main(loop: asyncio.AbstractEventLoop): asyncio.create_task(prompt_for_next_task()) await run_agent_until_done( - controller, runtime, [AgentState.STOPPED, AgentState.ERROR] + controller, runtime, memory, [AgentState.STOPPED, AgentState.ERROR] ) diff --git a/openhands/core/loop.py b/openhands/core/loop.py index d3f783563e..daf95d2f01 100644 --- a/openhands/core/loop.py +++ b/openhands/core/loop.py @@ -3,12 +3,14 @@ import asyncio from openhands.controller import AgentController from openhands.core.logger import openhands_logger as logger from openhands.core.schema import AgentState +from openhands.memory.memory import Memory from openhands.runtime.base import Runtime async def run_agent_until_done( controller: AgentController, runtime: Runtime, + memory: Memory, end_states: list[AgentState], ): """ @@ -37,6 +39,7 @@ async def run_agent_until_done( runtime.status_callback = status_callback controller.status_callback = status_callback + memory.status_callback = status_callback while controller.state.agent_state not in end_states: await asyncio.sleep(1) diff --git a/openhands/core/main.py b/openhands/core/main.py index 4b282864f2..01c3819766 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -18,6 +18,7 @@ from openhands.core.schema import AgentState from openhands.core.setup import ( create_agent, create_controller, + create_memory, create_runtime, generate_sid, initialize_repository_for_runtime, @@ -29,6 +30,7 @@ from openhands.events.event import Event from openhands.events.observation import AgentStateChangedObservation from openhands.events.serialization import event_from_dict from openhands.io import read_input, read_task +from openhands.memory.memory import Memory from openhands.runtime.base import Runtime from openhands.utils.async_utils import call_async_from_sync @@ -51,6 +53,7 @@ async def run_controller( exit_on_message: bool = False, 
fake_user_response_fn: FakeUserResponseFunc | None = None, headless_mode: bool = True, + memory: Memory | None = None, ) -> State | None: """Main coroutine to run the agent controller with task input flexibility. @@ -93,6 +96,8 @@ async def run_controller( if agent is None: agent = create_agent(config) + # when the runtime is created, it will be connected and clone the selected repository + repo_directory = None if runtime is None: runtime = create_runtime( config, @@ -105,14 +110,23 @@ async def run_controller( # Initialize repository if needed if config.sandbox.selected_repo: - initialize_repository_for_runtime( + repo_directory = initialize_repository_for_runtime( runtime, - agent=agent, selected_repository=config.sandbox.selected_repo, ) event_stream = runtime.event_stream + # when memory is created, it will load the microagents from the selected repository + if memory is None: + memory = create_memory( + runtime=runtime, + event_stream=event_stream, + sid=sid, + selected_repository=config.sandbox.selected_repo, + repo_directory=repo_directory, + ) + replay_events: list[Event] | None = None if config.replay_trajectory_path: logger.info('Trajectory replay is enabled') @@ -172,7 +186,7 @@ async def run_controller( ] try: - await run_agent_until_done(controller, runtime, end_states) + await run_agent_until_done(controller, runtime, memory, end_states) except Exception as e: logger.error(f'Exception in main loop: {e}') diff --git a/openhands/core/schema/action.py b/openhands/core/schema/action.py index fcc5e0a5ae..f81dc58ab9 100644 --- a/openhands/core/schema/action.py +++ b/openhands/core/schema/action.py @@ -82,5 +82,8 @@ class ActionTypeSchema(BaseModel): SEND_PR: str = Field(default='send_pr') """Send a PR to github.""" + RECALL: str = Field(default='recall') + """Retrieves content from a user workspace, microagent, or other source.""" + ActionType = ActionTypeSchema() diff --git a/openhands/core/schema/observation.py b/openhands/core/schema/observation.py 
index 51ee13f926..1c6ef55bac 100644 --- a/openhands/core/schema/observation.py +++ b/openhands/core/schema/observation.py @@ -49,5 +49,8 @@ class ObservationTypeSchema(BaseModel): CONDENSE: str = Field(default='condense') """Result of a condensation operation.""" + MICROAGENT: str = Field(default='microagent') + """Result of a microagent retrieval operation.""" + ObservationType = ObservationTypeSchema() diff --git a/openhands/core/setup.py b/openhands/core/setup.py index 02cd7fbf8d..9832d6eb04 100644 --- a/openhands/core/setup.py +++ b/openhands/core/setup.py @@ -1,7 +1,7 @@ import hashlib import os import uuid -from typing import Tuple, Type +from typing import Callable, Tuple, Type from pydantic import SecretStr @@ -16,6 +16,7 @@ from openhands.core.logger import openhands_logger as logger from openhands.events import EventStream from openhands.events.event import Event from openhands.llm.llm import LLM +from openhands.memory.memory import Memory from openhands.microagent.microagent import BaseMicroAgent from openhands.runtime import get_runtime_cls from openhands.runtime.base import Runtime @@ -83,7 +84,6 @@ def create_runtime( def initialize_repository_for_runtime( runtime: Runtime, - agent: Agent | None = None, selected_repository: str | None = None, github_token: SecretStr | None = None, ) -> str | None: @@ -91,7 +91,6 @@ def initialize_repository_for_runtime( Args: runtime: The runtime to initialize the repository for. - agent: (optional) The agent to load microagents for. selected_repository: (optional) The GitHub repository to use. github_token: (optional) The GitHub token to use. @@ -99,10 +98,10 @@ def initialize_repository_for_runtime( The repository directory path if a repository was cloned, None otherwise. 
""" # clone selected repository if provided - repo_directory = None github_token = ( SecretStr(os.environ.get('GITHUB_TOKEN')) if not github_token else github_token ) + repo_directory = None if selected_repository and github_token: logger.debug(f'Selected repository {selected_repository}.') repo_directory = runtime.clone_repo( @@ -111,16 +110,47 @@ def initialize_repository_for_runtime( None, ) - # load microagents from selected repository - if agent and agent.prompt_manager and selected_repository and repo_directory: - agent.prompt_manager.set_runtime_info(runtime) + return repo_directory + + +def create_memory( + runtime: Runtime, + event_stream: EventStream, + sid: str, + selected_repository: str | None = None, + repo_directory: str | None = None, + status_callback: Callable | None = None, +) -> Memory: + """Create a memory for the agent to use. + + Args: + runtime: The runtime to use. + event_stream: The event stream it will subscribe to. + sid: The session id. + selected_repository: The repository to clone and start with, if any. + repo_directory: The repository directory, if any. + status_callback: Optional callback function to handle status updates. 
+ """ + memory = Memory( + event_stream=event_stream, + sid=sid, + status_callback=status_callback, + ) + + if runtime: + # sets available hosts + memory.set_runtime_info(runtime) + + # loads microagents from repo/.openhands/microagents microagents: list[BaseMicroAgent] = runtime.get_microagents_from_selected_repo( selected_repository ) - agent.prompt_manager.load_microagents(microagents) - agent.prompt_manager.set_repository_info(selected_repository, repo_directory) + memory.load_user_workspace_microagents(microagents) - return repo_directory + if selected_repository and repo_directory: + memory.set_repository_info(selected_repository, repo_directory) + + return memory def create_agent(config: AppConfig) -> Agent: diff --git a/openhands/events/__init__.py b/openhands/events/__init__.py index c1694dba4b..3378883bee 100644 --- a/openhands/events/__init__.py +++ b/openhands/events/__init__.py @@ -1,4 +1,4 @@ -from openhands.events.event import Event, EventSource +from openhands.events.event import Event, EventSource, RecallType from openhands.events.stream import EventStream, EventStreamSubscriber __all__ = [ @@ -6,4 +6,5 @@ __all__ = [ 'EventSource', 'EventStream', 'EventStreamSubscriber', + 'RecallType', ] diff --git a/openhands/events/action/__init__.py b/openhands/events/action/__init__.py index 29956e3bb3..c007632e73 100644 --- a/openhands/events/action/__init__.py +++ b/openhands/events/action/__init__.py @@ -6,6 +6,7 @@ from openhands.events.action.agent import ( AgentSummarizeAction, AgentThinkAction, ChangeAgentStateAction, + RecallAction, ) from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction @@ -35,4 +36,5 @@ __all__ = [ 'MessageAction', 'ActionConfirmationStatus', 'AgentThinkAction', + 'RecallAction', ] diff --git a/openhands/events/action/agent.py b/openhands/events/action/agent.py index a46b7626cf..01ab8d05e4 100644 --- 
a/openhands/events/action/agent.py +++ b/openhands/events/action/agent.py @@ -4,6 +4,7 @@ from typing import Any from openhands.core.schema import ActionType from openhands.events.action.action import Action +from openhands.events.event import RecallType @dataclass @@ -106,3 +107,22 @@ class AgentDelegateAction(Action): @property def message(self) -> str: return f"I'm asking {self.agent} for help with this task." + + +@dataclass +class RecallAction(Action): + """This action is used for retrieving content, e.g., from the global directory or user workspace.""" + + recall_type: RecallType + query: str = '' + thought: str = '' + action: str = ActionType.RECALL + + @property + def message(self) -> str: + return f'Retrieving content for: {self.query[:50]}' + + def __str__(self) -> str: + ret = '**RecallAction**\n' + ret += f'QUERY: {self.query[:50]}' + return ret diff --git a/openhands/events/event.py b/openhands/events/event.py index 9d7af19160..8481a16393 100644 --- a/openhands/events/event.py +++ b/openhands/events/event.py @@ -22,6 +22,16 @@ class FileReadSource(str, Enum): DEFAULT = 'default' +class RecallType(str, Enum): + """The type of information that can be retrieved from microagents.""" + + WORKSPACE_CONTEXT = 'workspace_context' + """Workspace context (repo instructions, runtime, etc.)""" + + KNOWLEDGE = 'knowledge' + """A knowledge microagent.""" + + @dataclass class Event: INVALID_ID = -1 diff --git a/openhands/events/observation/__init__.py b/openhands/events/observation/__init__.py index 7fe9de9093..9e9fdf6568 100644 --- a/openhands/events/observation/__init__.py +++ b/openhands/events/observation/__init__.py @@ -1,7 +1,9 @@ +from openhands.events.event import RecallType from openhands.events.observation.agent import ( AgentCondensationObservation, AgentStateChangedObservation, AgentThinkObservation, + MicroagentObservation, ) from openhands.events.observation.browse import BrowserOutputObservation from openhands.events.observation.commands import ( @@ 
-40,4 +42,6 @@ __all__ = [ 'SuccessObservation', 'UserRejectObservation', 'AgentCondensationObservation', + 'MicroagentObservation', + 'RecallType', ] diff --git a/openhands/events/observation/agent.py b/openhands/events/observation/agent.py index 413828f3de..e2dfe49456 100644 --- a/openhands/events/observation/agent.py +++ b/openhands/events/observation/agent.py @@ -1,6 +1,7 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from openhands.core.schema import ObservationType +from openhands.events.event import RecallType from openhands.events.observation.observation import Observation @@ -40,3 +41,76 @@ class AgentThinkObservation(Observation): @property def message(self) -> str: return self.content + + +@dataclass +class MicroagentKnowledge: + """ + Represents knowledge from a triggered microagent. + + Attributes: + name: The name of the microagent that was triggered + trigger: The word that triggered this microagent + content: The actual content/knowledge from the microagent + """ + + name: str + trigger: str + content: str + + +@dataclass +class MicroagentObservation(Observation): + """The retrieval of content from a microagent or more microagents.""" + + recall_type: RecallType + observation: str = ObservationType.MICROAGENT + + # environment + repo_name: str = '' + repo_directory: str = '' + repo_instructions: str = '' + runtime_hosts: dict[str, int] = field(default_factory=dict) + additional_agent_instructions: str = '' + + # knowledge + microagent_knowledge: list[MicroagentKnowledge] = field(default_factory=list) + """ + A list of MicroagentKnowledge objects, each containing information from a triggered microagent. + + Example: + [ + MicroagentKnowledge( + name="python_best_practices", + trigger="python", + content="Always use virtual environments for Python projects." + ), + MicroagentKnowledge( + name="git_workflow", + trigger="git", + content="Create a new branch for each feature or bugfix." 
+ ) + ] + """ + + @property + def message(self) -> str: + return self.__str__() + + def __str__(self) -> str: + # Build a string representation of all fields + fields = [ + f'recall_type={self.recall_type}', + f'repo_name={self.repo_name}', + f'repo_instructions={self.repo_instructions[:20]}...', + f'runtime_hosts={self.runtime_hosts}', + f'additional_agent_instructions={self.additional_agent_instructions[:20]}...', + ] + + # Only include microagent_knowledge if it's not empty + if self.microagent_knowledge: + fields.append( + f'microagent_knowledge={", ".join([m.name for m in self.microagent_knowledge])}' + ) + + return f'**MicroagentObservation**\n{", ".join(fields)}' diff --git a/openhands/events/serialization/action.py b/openhands/events/serialization/action.py index 905cf45171..c314201750 100644 --- a/openhands/events/serialization/action.py +++ b/openhands/events/serialization/action.py @@ -8,6 +8,7 @@ from openhands.events.action.agent import ( AgentRejectAction, AgentThinkAction, ChangeAgentStateAction, + RecallAction, ) from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction from openhands.events.action.commands import ( @@ -35,6 +36,7 @@ actions = ( AgentFinishAction, AgentRejectAction, AgentDelegateAction, + RecallAction, ChangeAgentStateAction, MessageAction, ) diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py index ad566faa1a..8c096e9848 100644 --- a/openhands/events/serialization/event.py +++ b/openhands/events/serialization/event.py @@ -1,5 +1,6 @@ from dataclasses import asdict from datetime import datetime +from enum import Enum from pydantic import BaseModel @@ -102,6 +103,8 @@ def event_to_dict(event: 'Event') -> dict: d['timestamp'] = d['timestamp'].isoformat() if key == 'source' and 'source' in d: d['source'] = d['source'].value + if key == 'recall_type' and 'recall_type' in d: + d['recall_type'] = d['recall_type'].value if key == 'tool_call_metadata' and 
'tool_call_metadata' in d: d['tool_call_metadata'] = d['tool_call_metadata'].model_dump() if key == 'llm_metrics' and 'llm_metrics' in d: @@ -119,7 +122,11 @@ def event_to_dict(event: 'Event') -> dict: # props is a dict whose values can include a complex object like an instance of a BaseModel subclass # such as CmdOutputMetadata # we serialize it along with the rest - d['extras'] = {k: _convert_pydantic_to_dict(v) for k, v in props.items()} + # we also handle the Enum conversion for MicroagentObservation + d['extras'] = { + k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v)) + for k, v in props.items() + } # Include success field for CmdOutputObservation if hasattr(event, 'success'): d['success'] = event.success diff --git a/openhands/events/serialization/observation.py b/openhands/events/serialization/observation.py index 8cc67313b1..e1d59833bc 100644 --- a/openhands/events/serialization/observation.py +++ b/openhands/events/serialization/observation.py @@ -1,9 +1,12 @@ import copy +from openhands.events.event import RecallType from openhands.events.observation.agent import ( AgentCondensationObservation, AgentStateChangedObservation, AgentThinkObservation, + MicroagentKnowledge, + MicroagentObservation, ) from openhands.events.observation.browse import BrowserOutputObservation from openhands.events.observation.commands import ( @@ -40,6 +43,7 @@ observations = ( UserRejectObservation, AgentCondensationObservation, AgentThinkObservation, + MicroagentObservation, ) OBSERVATION_TYPE_TO_CLASS = { @@ -110,4 +114,18 @@ def observation_from_dict(observation: dict) -> Observation: else: extras['metadata'] = CmdOutputMetadata() + if observation_class is MicroagentObservation: + # handle the Enum conversion + if 'recall_type' in extras: + extras['recall_type'] = RecallType(extras['recall_type']) + + # convert dicts in microagent_knowledge to MicroagentKnowledge objects + if 'microagent_knowledge' in extras and isinstance( + extras['microagent_knowledge'], 
list + ): + extras['microagent_knowledge'] = [ + MicroagentKnowledge(**item) if isinstance(item, dict) else item + for item in extras['microagent_knowledge'] + ] + return observation_class(content=content, **extras) diff --git a/openhands/events/stream.py b/openhands/events/stream.py index 938269822a..e1ecb7adbe 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -27,6 +27,7 @@ class EventStreamSubscriber(str, Enum): RESOLVER = 'openhands_resolver' SERVER = 'server' RUNTIME = 'runtime' + MEMORY = 'memory' MAIN = 'main' TEST = 'test' diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index bed6424278..b3b526777e 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -249,7 +249,8 @@ class LLM(RetryMixin, DebugMixin): # if we mocked function calling, and we have tools, convert the response back to function calling format if mock_function_calling and mock_fncall_tools is not None: - assert len(resp.choices) == 1 + logger.debug(f'Response choices: {len(resp.choices)}') + assert len(resp.choices) >= 1 non_fncall_response_message = resp.choices[0].message fn_call_messages_with_response = ( convert_non_fncall_messages_to_fncall_messages( diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py index d44cb3cb95..b308fef142 100644 --- a/openhands/memory/conversation_memory.py +++ b/openhands/memory/conversation_memory.py @@ -1,5 +1,6 @@ from litellm import ModelResponse +from openhands.core.config.agent_config import AgentConfig from openhands.core.logger import openhands_logger as logger from openhands.core.message import ImageContent, Message, TextContent from openhands.core.schema import ActionType @@ -16,7 +17,7 @@ from openhands.events.action import ( IPythonRunCellAction, MessageAction, ) -from openhands.events.event import Event +from openhands.events.event import Event, RecallType from openhands.events.observation import ( AgentCondensationObservation, AgentDelegateObservation, @@ -28,16 
+29,21 @@ from openhands.events.observation import ( IPythonRunCellObservation, UserRejectObservation, ) +from openhands.events.observation.agent import ( + MicroagentKnowledge, + MicroagentObservation, +) from openhands.events.observation.error import ErrorObservation from openhands.events.observation.observation import Observation from openhands.events.serialization.event import truncate_content -from openhands.utils.prompt import PromptManager +from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo class ConversationMemory: """Processes event history into a coherent conversation for the agent.""" - def __init__(self, prompt_manager: PromptManager): + def __init__(self, config: AgentConfig, prompt_manager: PromptManager): + self.agent_config = config self.prompt_manager = prompt_manager def process_events( @@ -53,14 +59,14 @@ class ConversationMemory: Ensures that tool call actions are processed correctly in function calling mode. Args: - state: The state containing the history of events to convert - condensed_history: The condensed list of events to process - initial_messages: The initial messages to include in the result + condensed_history: The condensed history of events to convert + initial_messages: The initial messages to include in the conversation max_message_chars: The maximum number of characters in the content of an event included in the prompt to the LLM. Larger observations are truncated. vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included. enable_som_visual_browsing: Whether to enable visual browsing for the SOM model. """ + events = condensed_history # Process special events first (system prompts, etc.) 
@@ -70,7 +76,7 @@ class ConversationMemory: pending_tool_call_action_messages: dict[str, Message] = {} tool_call_id_to_message: dict[str, Message] = {} - for event in events: + for i, event in enumerate(events): # create a regular message from an event if isinstance(event, Action): messages_to_add = self._process_action( @@ -84,7 +90,9 @@ class ConversationMemory: tool_call_id_to_message=tool_call_id_to_message, max_message_chars=max_message_chars, vision_is_active=vision_is_active, - enable_som_visual_browsing=enable_som_visual_browsing, + enable_som_visual_browsing=self.agent_config.enable_som_visual_browsing, + current_index=i, + events=events, ) else: raise ValueError(f'Unknown event type: {type(event)}') @@ -270,6 +278,8 @@ class ConversationMemory: max_message_chars: int | None = None, vision_is_active: bool = False, enable_som_visual_browsing: bool = False, + current_index: int = 0, + events: list[Event] | None = None, ) -> list[Message]: """Converts an observation into a message format that can be sent to the LLM. @@ -291,6 +301,8 @@ class ConversationMemory: max_message_chars: The maximum number of characters in the content of an observation included in the prompt to the LLM vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included enable_som_visual_browsing: Whether to enable visual browsing for the SOM model + current_index: The index of the current event in the events list (for deduplication) + events: The list of all events (for deduplication) Returns: list[Message]: A list containing the formatted message(s) for the observation. 
@@ -372,6 +384,92 @@ class ConversationMemory: elif isinstance(obs, AgentCondensationObservation): text = truncate_content(obs.content, max_message_chars) message = Message(role='user', content=[TextContent(text=text)]) + elif ( + isinstance(obs, MicroagentObservation) + and self.agent_config.enable_prompt_extensions + ): + if obs.recall_type == RecallType.WORKSPACE_CONTEXT: + # everything is optional, check if they are present + repo_info = ( + RepositoryInfo( + repo_name=obs.repo_name or '', + repo_directory=obs.repo_directory or '', + ) + if obs.repo_name or obs.repo_directory + else None + ) + if obs.runtime_hosts or obs.additional_agent_instructions: + runtime_info = RuntimeInfo( + available_hosts=obs.runtime_hosts, + additional_agent_instructions=obs.additional_agent_instructions, + ) + else: + runtime_info = None + + repo_instructions = ( + obs.repo_instructions if obs.repo_instructions else '' + ) + + # Have some meaningful content before calling the template + has_repo_info = repo_info is not None and ( + repo_info.repo_name or repo_info.repo_directory + ) + has_runtime_info = runtime_info is not None and ( + runtime_info.available_hosts + or runtime_info.additional_agent_instructions + ) + has_repo_instructions = bool(repo_instructions.strip()) + + # Build additional info if we have something to render + if has_repo_info or has_runtime_info or has_repo_instructions: + # ok, now we can build the additional info + formatted_text = self.prompt_manager.build_additional_info( + repository_info=repo_info, + runtime_info=runtime_info, + repo_instructions=repo_instructions, + ) + message = Message( + role='user', content=[TextContent(text=formatted_text)] + ) + else: + return [] + elif obs.recall_type == RecallType.KNOWLEDGE: + # Use prompt manager to build the microagent info + # First, filter out agents that appear in earlier MicroagentObservations + filtered_agents = self._filter_agents_in_microagent_obs( + obs, current_index, events or [] + ) + + # Create and 
return a message if there is microagent knowledge to include + if filtered_agents: + # Exclude disabled microagents + filtered_agents = [ + agent + for agent in filtered_agents + if agent.name not in self.agent_config.disabled_microagents + ] + + # Only proceed if we still have agents after filtering out disabled ones + if filtered_agents: + formatted_text = self.prompt_manager.build_microagent_info( + triggered_agents=filtered_agents, + ) + + return [ + Message( + role='user', content=[TextContent(text=formatted_text)] + ) + ] + + # Return empty list if no microagents to include or all were disabled + return [] + elif ( + isinstance(obs, MicroagentObservation) + and not self.agent_config.enable_prompt_extensions + ): + # If prompt extensions are disabled, we don't add any additional info + # TODO: test this + return [] else: # If an observation message is not returned, it will cause an error # when the LLM tries to return the next message @@ -404,3 +502,53 @@ class ConversationMemory: -1 ].cache_prompt = True # Last item inside the message content break + + def _filter_agents_in_microagent_obs( + self, obs: MicroagentObservation, current_index: int, events: list[Event] + ) -> list[MicroagentKnowledge]: + """Filter out agents that appear in earlier MicroagentObservations. 
+ + Args: + obs: The current MicroagentObservation to filter + current_index: The index of the current event in the events list + events: The list of all events + + Returns: + list[MicroagentKnowledge]: The filtered list of microagent knowledge + """ + if obs.recall_type != RecallType.KNOWLEDGE: + return obs.microagent_knowledge + + # For each agent in the current microagent observation, check if it appears in any earlier microagent observation + filtered_agents = [] + for agent in obs.microagent_knowledge: + # Keep this agent if it doesn't appear in any earlier observation + # that is, if this is the first microagent observation with this microagent + if not self._has_agent_in_earlier_events(agent.name, current_index, events): + filtered_agents.append(agent) + + return filtered_agents + + def _has_agent_in_earlier_events( + self, agent_name: str, current_index: int, events: list[Event] + ) -> bool: + """Check if an agent appears in any earlier MicroagentObservation in the event list. + + Args: + agent_name: The name of the agent to look for + current_index: The index of the current event in the events list + events: The list of all events + + Returns: + bool: True if the agent appears in an earlier MicroagentObservation, False otherwise + """ + for event in events[:current_index]: + if ( + isinstance(event, MicroagentObservation) + and event.recall_type == RecallType.KNOWLEDGE + ): + if any( + agent.name == agent_name for agent in event.microagent_knowledge + ): + return True + return False diff --git a/openhands/memory/memory.py b/openhands/memory/memory.py new file mode 100644 index 0000000000..2dba5f10ea --- /dev/null +++ b/openhands/memory/memory.py @@ -0,0 +1,270 @@ +import asyncio +import os +import uuid +from typing import Callable + +import openhands +from openhands.core.logger import openhands_logger as logger +from openhands.events.action.agent import RecallAction +from openhands.events.event import Event, EventSource, RecallType +from 
openhands.events.observation.agent import ( + MicroagentKnowledge, + MicroagentObservation, +) +from openhands.events.observation.empty import NullObservation +from openhands.events.stream import EventStream, EventStreamSubscriber +from openhands.microagent import ( + BaseMicroAgent, + KnowledgeMicroAgent, + RepoMicroAgent, + load_microagents_from_dir, +) +from openhands.runtime.base import Runtime +from openhands.utils.prompt import RepositoryInfo, RuntimeInfo + +GLOBAL_MICROAGENTS_DIR = os.path.join( + os.path.dirname(os.path.dirname(openhands.__file__)), + 'microagents', +) + + +class Memory: + """ + Memory is a component that listens to the EventStream for information retrieval actions + (a RecallAction) and publishes observations with the content (such as MicroagentObservation). + """ + + sid: str + event_stream: EventStream + status_callback: Callable | None + loop: asyncio.AbstractEventLoop | None + + def __init__( + self, + event_stream: EventStream, + sid: str, + status_callback: Callable | None = None, + ): + self.event_stream = event_stream + self.sid = sid if sid else str(uuid.uuid4()) + self.status_callback = status_callback + self.loop = None + + self.event_stream.subscribe( + EventStreamSubscriber.MEMORY, + self.on_event, + self.sid, + ) + + # Additional placeholders to store user workspace microagents + self.repo_microagents: dict[str, RepoMicroAgent] = {} + self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {} + + # Store repository / runtime info to send them to the templating later + self.repository_info: RepositoryInfo | None = None + self.runtime_info: RuntimeInfo | None = None + + # Load global microagents (Knowledge + Repo) + # from typically OpenHands/microagents (i.e., the PUBLIC microagents) + self._load_global_microagents() + + def on_event(self, event: Event): + """Handle an event from the event stream.""" + asyncio.get_event_loop().run_until_complete(self._on_event(event)) + + async def _on_event(self, event: Event): + 
"""Handle an event from the event stream asynchronously.""" + try: + observation: MicroagentObservation | NullObservation | None = None + + if isinstance(event, RecallAction): + # if this is a workspace context recall (on first user message) + # create and add a MicroagentObservation + # with info about repo and runtime. + if ( + event.source == EventSource.USER + and event.recall_type == RecallType.WORKSPACE_CONTEXT + ): + observation = self._on_first_microagent_action(event) + + # continue with the next handler, to include knowledge microagents if suitable for this query + assert observation is None or isinstance( + observation, MicroagentObservation + ), f'Expected a MicroagentObservation, but got {type(observation)}' + observation = self._on_microagent_action( + event, prev_observation=observation + ) + + if observation is None: + observation = NullObservation(content='') + + # important: this will release the execution flow from waiting for the retrieval to complete + observation._cause = event.id # type: ignore[union-attr] + + self.event_stream.add_event(observation, EventSource.ENVIRONMENT) + except Exception as e: + error_str = f'Error: {str(e.__class__.__name__)}' + logger.error(error_str) + self.send_error_message('STATUS$ERROR_MEMORY', error_str) + return + + def _on_first_microagent_action( + self, event: RecallAction + ) -> MicroagentObservation | None: + """Add repository and runtime information to the stream as a MicroagentObservation.""" + + # Create ENVIRONMENT info: + # - repository_info + # - runtime_info + # - repository_instructions + + # Collect raw repository instructions + repo_instructions = '' + assert ( + len(self.repo_microagents) <= 1 + ), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}' + + # Retrieve the context of repo instructions + for microagent in self.repo_microagents.values(): + if repo_instructions: + repo_instructions += '\n\n' + repo_instructions += 
microagent.content + + # Create observation if we have anything + if self.repository_info or self.runtime_info or repo_instructions: + obs = MicroagentObservation( + recall_type=RecallType.WORKSPACE_CONTEXT, + repo_name=self.repository_info.repo_name + if self.repository_info and self.repository_info.repo_name is not None + else '', + repo_directory=self.repository_info.repo_directory + if self.repository_info + and self.repository_info.repo_directory is not None + else '', + repo_instructions=repo_instructions if repo_instructions else '', + runtime_hosts=self.runtime_info.available_hosts + if self.runtime_info and self.runtime_info.available_hosts is not None + else {}, + additional_agent_instructions=self.runtime_info.additional_agent_instructions + if self.runtime_info + and self.runtime_info.additional_agent_instructions is not None + else '', + microagent_knowledge=[], + content='Retrieved environment info', + ) + return obs + return None + + def _on_microagent_action( + self, + event: RecallAction, + prev_observation: MicroagentObservation | None = None, + ) -> MicroagentObservation | None: + """When a microagent action triggers microagents, create a MicroagentObservation with structured data.""" + # If there's no query, do nothing + query = event.query.strip() + if not query: + return prev_observation + + assert prev_observation is None or isinstance( + prev_observation, MicroagentObservation + ), f'Expected a MicroagentObservation, but got {type(prev_observation)}' + + # Process text to find suitable microagents and create a MicroagentObservation. 
+ recalled_content: list[MicroagentKnowledge] = [] + for name, microagent in self.knowledge_microagents.items(): + trigger = microagent.match_trigger(query) + if trigger: + logger.info("Microagent '%s' triggered by keyword '%s'", name, trigger) + recalled_content.append( + MicroagentKnowledge( + name=microagent.name, + trigger=trigger, + content=microagent.content, + ) + ) + + if recalled_content: + if prev_observation is not None: + # it may be on the first user message that already found some repo info etc + prev_observation.microagent_knowledge.extend(recalled_content) + else: + # if it's not the first user message, we may not have found any information this step + obs = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=recalled_content, + content='Retrieved knowledge from microagents', + ) + + return obs + + return prev_observation + + def load_user_workspace_microagents( + self, user_microagents: list[BaseMicroAgent] + ) -> None: + """ + This method loads microagents from a user's cloned repo or workspace directory. + + This is typically called from agent_session or setup once the workspace is cloned. 
+ """ + logger.info( + 'Loading user workspace microagents: %s', [m.name for m in user_microagents] + ) + for user_microagent in user_microagents: + if isinstance(user_microagent, KnowledgeMicroAgent): + self.knowledge_microagents[user_microagent.name] = user_microagent + elif isinstance(user_microagent, RepoMicroAgent): + self.repo_microagents[user_microagent.name] = user_microagent + + def _load_global_microagents(self) -> None: + """ + Loads microagents from the global microagents_dir + """ + repo_agents, knowledge_agents, _ = load_microagents_from_dir( + GLOBAL_MICROAGENTS_DIR + ) + for name, agent in knowledge_agents.items(): + if isinstance(agent, KnowledgeMicroAgent): + self.knowledge_microagents[name] = agent + for name, agent in repo_agents.items(): + if isinstance(agent, RepoMicroAgent): + self.repo_microagents[name] = agent + + def set_repository_info(self, repo_name: str, repo_directory: str) -> None: + """Store repository info so we can reference it in an observation.""" + if repo_name or repo_directory: + self.repository_info = RepositoryInfo(repo_name, repo_directory) + else: + self.repository_info = None + + def set_runtime_info(self, runtime: Runtime) -> None: + """Store runtime info (web hosts, ports, etc.).""" + # e.g. 
{ '127.0.0.1': 8080 } + if runtime.web_hosts or runtime.additional_agent_instructions: + self.runtime_info = RuntimeInfo( + available_hosts=runtime.web_hosts, + additional_agent_instructions=runtime.additional_agent_instructions, + ) + else: + self.runtime_info = None + + def send_error_message(self, message_id: str, message: str): + """Sends an error message if the callback function was provided.""" + if self.status_callback: + try: + if self.loop is None: + self.loop = asyncio.get_running_loop() + asyncio.run_coroutine_threadsafe( + self._send_status_message('error', message_id, message), self.loop + ) + except RuntimeError as e: + logger.error( + f'Error sending status message: {e.__class__.__name__}', + stack_info=False, + ) + + async def _send_status_message(self, msg_type: str, id: str, message: str): + """Sends a status message to the client.""" + if self.status_callback: + self.status_callback(msg_type, id, message) diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 56f208db41..02466a7ad9 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -15,7 +15,8 @@ from openhands.core.schema.agent import AgentState from openhands.events.action import ChangeAgentStateAction, MessageAction from openhands.events.event import EventSource from openhands.events.stream import EventStream -from openhands.microagent import BaseMicroAgent +from openhands.memory.memory import Memory +from openhands.microagent.microagent import BaseMicroAgent from openhands.runtime import get_runtime_cls from openhands.runtime.base import Runtime from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime @@ -126,6 +127,15 @@ class AgentSession: agent_to_llm_config=agent_to_llm_config, agent_configs=agent_configs, ) + + repo_directory = None + if self.runtime and runtime_connected and selected_repository: + repo_directory = selected_repository.split('/')[-1] + self.memory = await 
self._create_memory( + selected_repository=selected_repository, + repo_directory=repo_directory, + ) + if github_token: self.event_stream.set_secrets( { @@ -260,26 +270,14 @@ class AgentSession: ) return False - repo_directory = None if selected_repository: - repo_directory = await call_sync_from_async( + await call_sync_from_async( self.runtime.clone_repo, github_token, selected_repository, selected_branch, ) - if agent.prompt_manager: - agent.prompt_manager.set_runtime_info(self.runtime) - microagents: list[BaseMicroAgent] = await call_sync_from_async( - self.runtime.get_microagents_from_selected_repo, selected_repository - ) - agent.prompt_manager.load_microagents(microagents) - if selected_repository and repo_directory: - agent.prompt_manager.set_repository_info( - selected_repository, repo_directory - ) - self.logger.debug( f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}' ) @@ -342,6 +340,29 @@ class AgentSession: return controller + async def _create_memory( + self, selected_repository: str | None, repo_directory: str | None + ) -> Memory: + memory = Memory( + event_stream=self.event_stream, + sid=self.sid, + status_callback=self._status_callback, + ) + + if self.runtime: + # sets available hosts and other runtime info + memory.set_runtime_info(self.runtime) + + # loads microagents from repo/.openhands/microagents + microagents: list[BaseMicroAgent] = await call_sync_from_async( + self.runtime.get_microagents_from_selected_repo, selected_repository + ) + memory.load_user_workspace_microagents(microagents) + + if selected_repository and repo_directory: + memory.set_repository_info(selected_repository, repo_directory) + return memory + def _maybe_restore_state(self) -> State | None: """Helper method to handle state restore logic.""" restored_state = None diff --git a/openhands/storage/conversation/file_conversation_store.py b/openhands/storage/conversation/file_conversation_store.py index 0f7d1f9fa9..ed18b7cd01 100644 --- 
a/openhands/storage/conversation/file_conversation_store.py +++ b/openhands/storage/conversation/file_conversation_store.py @@ -85,8 +85,8 @@ class FileConversationStore(ConversationStore): try: conversations.append(await self.get_metadata(conversation_id)) except Exception: - logger.error( - f'Error loading conversation: {conversation_id}', + logger.warning( + f'Could not load conversation metadata: {conversation_id}', ) conversations.sort(key=_sort_key, reverse=True) conversations = conversations[start:end] diff --git a/openhands/utils/prompt.py b/openhands/utils/prompt.py index 171d7f991c..643af5aa2e 100644 --- a/openhands/utils/prompt.py +++ b/openhands/utils/prompt.py @@ -1,25 +1,18 @@ import os -from dataclasses import dataclass +from dataclasses import dataclass, field from itertools import islice from jinja2 import Template from openhands.controller.state.state import State -from openhands.core.logger import openhands_logger from openhands.core.message import Message, TextContent -from openhands.microagent import ( - BaseMicroAgent, - KnowledgeMicroAgent, - RepoMicroAgent, - load_microagents_from_dir, -) -from openhands.runtime.base import Runtime +from openhands.events.observation.agent import MicroagentKnowledge @dataclass class RuntimeInfo: - available_hosts: dict[str, int] - additional_agent_instructions: str + available_hosts: dict[str, int] = field(default_factory=dict) + additional_agent_instructions: str = '' @dataclass @@ -32,75 +25,23 @@ class RepositoryInfo: class PromptManager: """ - Manages prompt templates and micro-agents for AI interactions. + Manages prompt templates and includes information from the user's workspace micro-agents and global micro-agents. - This class handles loading and rendering of system and user prompt templates, - as well as loading micro-agent specifications. It provides methods to access - rendered system and initial user messages for AI interactions. 
+ This class is dedicated to loading and rendering prompts (system prompt, user prompt). Attributes: - prompt_dir (str): Directory containing prompt templates. - microagent_dir (str): Directory containing microagent specifications. - disabled_microagents (list[str] | None): List of microagents to disable. If None, all microagents are enabled. + prompt_dir: Directory containing prompt templates. """ def __init__( self, prompt_dir: str, - microagent_dir: str | None = None, - disabled_microagents: list[str] | None = None, ): - self.disabled_microagents: list[str] = disabled_microagents or [] self.prompt_dir: str = prompt_dir - self.repository_info: RepositoryInfo | None = None self.system_template: Template = self._load_template('system_prompt') self.user_template: Template = self._load_template('user_prompt') self.additional_info_template: Template = self._load_template('additional_info') self.microagent_info_template: Template = self._load_template('microagent_info') - self.runtime_info = RuntimeInfo( - available_hosts={}, additional_agent_instructions='' - ) - - self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {} - self.repo_microagents: dict[str, RepoMicroAgent] = {} - - if microagent_dir: - # This loads micro-agents from the microagent_dir - # which is typically the OpenHands/microagents (i.e., the PUBLIC microagents) - - # Only load KnowledgeMicroAgents - repo_microagents, knowledge_microagents, _ = load_microagents_from_dir( - microagent_dir - ) - assert all( - isinstance(microagent, KnowledgeMicroAgent) - for microagent in knowledge_microagents.values() - ) - for name, microagent in knowledge_microagents.items(): - if name not in self.disabled_microagents: - self.knowledge_microagents[name] = microagent - assert all( - isinstance(microagent, RepoMicroAgent) - for microagent in repo_microagents.values() - ) - for name, microagent in repo_microagents.items(): - if name not in self.disabled_microagents: - self.repo_microagents[name] = microagent - - 
def load_microagents(self, microagents: list[BaseMicroAgent]) -> None: - """Load microagents from a list of BaseMicroAgents. - - This is typically used when loading microagents from inside a repo. - """ - openhands_logger.info('Loading microagents: %s', [m.name for m in microagents]) - # Only keep KnowledgeMicroAgents and RepoMicroAgents - for microagent in microagents: - if microagent.name in self.disabled_microagents: - continue - if isinstance(microagent, KnowledgeMicroAgent): - self.knowledge_microagents[microagent.name] = microagent - elif isinstance(microagent, RepoMicroAgent): - self.repo_microagents[microagent.name] = microagent def _load_template(self, template_name: str) -> Template: if self.prompt_dir is None: @@ -114,27 +55,6 @@ class PromptManager: def get_system_message(self) -> str: return self.system_template.render().strip() - def set_runtime_info(self, runtime: Runtime) -> None: - self.runtime_info.available_hosts = runtime.web_hosts - self.runtime_info.additional_agent_instructions = ( - runtime.additional_agent_instructions - ) - - def set_repository_info( - self, - repo_name: str, - repo_directory: str, - ) -> None: - """Sets information about the GitHub repository that has been cloned. - - Args: - repo_name: The name of the GitHub repository (e.g. 'owner/repo') - repo_directory: The directory where the repository has been cloned - """ - self.repository_info = RepositoryInfo( - repo_name=repo_name, repo_directory=repo_directory - ) - def get_example_user_message(self) -> str: """This is the initial user message provided to the agent before *actual* user instructions are provided. @@ -148,45 +68,6 @@ class PromptManager: return self.user_template.render().strip() - def enhance_message(self, message: Message) -> None: - """Enhance the user message with additional context. - - This method is used to enhance the user message with additional context - about the user's task. 
The additional context will convert the current - generic agent into a more specialized agent that is tailored to the user's task. - """ - if not message.content: - return - - # if there were other texts included, they were before the user message - # so the last TextContent is the user message - # content can be a list of TextContent or ImageContent - message_content = '' - for content in reversed(message.content): - if isinstance(content, TextContent): - message_content = content.text - break - - if not message_content: - return - - triggered_agents = [] - for name, microagent in self.knowledge_microagents.items(): - trigger = microagent.match_trigger(message_content) - if trigger: - openhands_logger.info( - "Microagent '%s' triggered by keyword '%s'", - name, - trigger, - ) - # Create a dictionary with the agent and trigger word - triggered_agents.append({'agent': microagent, 'trigger_word': trigger}) - - if triggered_agents: - formatted_text = self.build_microagent_info(triggered_agents) - # Insert the new content at the start of the TextContent list - message.content.insert(0, TextContent(text=formatted_text)) - def add_examples_to_initial_message(self, message: Message) -> None: """Add example_message to the first user message.""" example_message = self.get_example_user_message() or None @@ -195,44 +76,28 @@ class PromptManager: if example_message: message.content.insert(0, TextContent(text=example_message)) - def add_info_to_initial_message( + def build_additional_info( self, - message: Message, - ) -> None: - """Adds information about the repository and runtime to the initial user message. - - Args: - message: The initial user message to add information to. 
- """ - repo_instructions = '' - assert ( - len(self.repo_microagents) <= 1 - ), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}' - for microagent in self.repo_microagents.values(): - # We assume these are the repo instructions - if repo_instructions: - repo_instructions += '\n\n' - repo_instructions += microagent.content - - additional_info = self.additional_info_template.render( + repository_info: RepositoryInfo | None, + runtime_info: RuntimeInfo | None, + repo_instructions: str = '', + ) -> str: + """Renders the additional info template with the stored repository/runtime info.""" + return self.additional_info_template.render( + repository_info=repository_info, repository_instructions=repo_instructions, - repository_info=self.repository_info, - runtime_info=self.runtime_info, + runtime_info=runtime_info, ).strip() - # Insert the new content at the start of the TextContent list - if additional_info: - message.content.insert(0, TextContent(text=additional_info)) - def build_microagent_info( self, - triggered_agents: list[dict], + triggered_agents: list[MicroagentKnowledge], ) -> str: """Renders the microagent info template with the triggered agents. Args: - triggered_agents: A list of dictionaries, each containing an "agent" - (KnowledgeMicroAgent) and a "trigger_word" (str). + triggered_agents: A list of MicroagentKnowledge objects containing information + about triggered microagents. 
""" return self.microagent_info_template.render( triggered_agents=triggered_agents diff --git a/tests/unit/test_action_serialization.py b/tests/unit/test_action_serialization.py index d29f40e3e5..afab8a3f77 100644 --- a/tests/unit/test_action_serialization.py +++ b/tests/unit/test_action_serialization.py @@ -9,6 +9,7 @@ from openhands.events.action import ( FileReadAction, FileWriteAction, MessageAction, + RecallAction, ) from openhands.events.action.action import ActionConfirmationStatus from openhands.events.action.files import FileEditSource, FileReadSource @@ -356,6 +357,18 @@ def test_file_ohaci_edit_action_legacy_serialization(): assert event_dict['args']['end'] == -1 +def test_agent_microagent_action_serialization_deserialization(): + original_action_dict = { + 'action': 'recall', + 'args': { + 'query': 'What is the capital of France?', + 'thought': 'I need to find information about France', + 'recall_type': 'knowledge', + }, + } + serialization_deserialization(original_action_dict, RecallAction) + + def test_file_read_action_legacy_serialization(): original_action_dict = { 'action': 'read', diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index 1d8c7bb8d5..58b8ae7c95 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -1,5 +1,5 @@ import asyncio -from unittest.mock import ANY, AsyncMock, MagicMock +from unittest.mock import ANY, AsyncMock, MagicMock, patch from uuid import uuid4 import pytest @@ -14,12 +14,16 @@ from openhands.core.main import run_controller from openhands.core.schema import AgentState from openhands.events import Event, EventSource, EventStream, EventStreamSubscriber from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction +from openhands.events.action.agent import RecallAction +from openhands.events.event import RecallType from openhands.events.observation import ( ErrorObservation, ) +from openhands.events.observation.agent 
import MicroagentObservation from openhands.events.serialization import event_to_dict from openhands.llm import LLM from openhands.llm.metrics import Metrics, TokenUsage +from openhands.memory.memory import Memory from openhands.runtime.base import Runtime from openhands.storage.memory import InMemoryFileStore @@ -47,17 +51,36 @@ def mock_agent(): @pytest.fixture def mock_event_stream(): - mock = MagicMock(spec=EventStream) + mock = MagicMock( + spec=EventStream, + event_stream=EventStream(sid='test', file_store=InMemoryFileStore({})), + ) mock.get_latest_event_id.return_value = 0 return mock +@pytest.fixture +def test_event_stream(): + event_stream = EventStream(sid='test', file_store=InMemoryFileStore({})) + return event_stream + + @pytest.fixture def mock_runtime() -> Runtime: - return MagicMock( + runtime = MagicMock( spec=Runtime, - event_stream=EventStream(sid='test', file_store=InMemoryFileStore({})), + event_stream=test_event_stream, ) + return runtime + + +@pytest.fixture +def mock_memory() -> Memory: + memory = MagicMock( + spec=Memory, + event_stream=test_event_stream, + ) + return memory @pytest.fixture @@ -68,6 +91,7 @@ def mock_status_callback(): async def send_event_to_controller(controller, event): await controller._on_event(event) await asyncio.sleep(0.1) + controller._pending_action = None @pytest.mark.asyncio @@ -140,10 +164,8 @@ async def test_react_to_exception(mock_agent, mock_event_stream, mock_status_cal @pytest.mark.asyncio -async def test_run_controller_with_fatal_error(): +async def test_run_controller_with_fatal_error(test_event_stream, mock_memory): config = AppConfig() - file_store = InMemoryFileStore({}) - event_stream = EventStream(sid='test', file_store=file_store) agent = MagicMock(spec=Agent) agent = MagicMock(spec=Agent) @@ -163,10 +185,23 @@ async def test_run_controller_with_fatal_error(): if isinstance(event, CmdRunAction): error_obs = ErrorObservation('You messed around with Jim') error_obs._cause = event.id - 
event_stream.add_event(error_obs, EventSource.USER) + test_event_stream.add_event(error_obs, EventSource.USER) - event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4())) - runtime.event_stream = event_stream + test_event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4())) + runtime.event_stream = test_event_stream + + def on_event_memory(event: Event): + if isinstance(event, RecallAction): + microagent_obs = MicroagentObservation( + content='Test microagent content', + recall_type=RecallType.KNOWLEDGE, + ) + microagent_obs._cause = event.id + test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT) + + test_event_stream.subscribe( + EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()) + ) state = await run_controller( config=config, @@ -175,22 +210,20 @@ async def test_run_controller_with_fatal_error(): sid='test', agent=agent, fake_user_response_fn=lambda _: 'repeat', + memory=mock_memory, ) print(f'state: {state}') - events = list(event_stream.get_events()) + events = list(test_event_stream.get_events()) print(f'event_stream: {events}') - assert state.iteration == 4 + assert state.iteration == 3 assert state.agent_state == AgentState.ERROR assert state.last_error == 'AgentStuckInLoopError: Agent got stuck in a loop' assert len(events) == 11 @pytest.mark.asyncio -async def test_run_controller_stop_with_stuck(): +async def test_run_controller_stop_with_stuck(test_event_stream, mock_memory): config = AppConfig() - file_store = InMemoryFileStore({}) - event_stream = EventStream(sid='test', file_store=file_store) - agent = MagicMock(spec=Agent) def agent_step_fn(state): @@ -209,10 +242,23 @@ async def test_run_controller_stop_with_stuck(): 'Non fatal error here to trigger loop' ) non_fatal_error_obs._cause = event.id - event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT) + test_event_stream.add_event(non_fatal_error_obs, EventSource.ENVIRONMENT) - 
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4())) - runtime.event_stream = event_stream + test_event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4())) + runtime.event_stream = test_event_stream + + def on_event_memory(event: Event): + if isinstance(event, RecallAction): + microagent_obs = MicroagentObservation( + content='Test microagent content', + recall_type=RecallType.KNOWLEDGE, + ) + microagent_obs._cause = event.id + test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT) + + test_event_stream.subscribe( + EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()) + ) state = await run_controller( config=config, @@ -221,16 +267,17 @@ async def test_run_controller_stop_with_stuck(): sid='test', agent=agent, fake_user_response_fn=lambda _: 'repeat', + memory=mock_memory, ) - events = list(event_stream.get_events()) + events = list(test_event_stream.get_events()) print(f'state: {state}') for i, event in enumerate(events): print(f'event {i}: {event_to_dict(event)}') - assert state.iteration == 4 + assert state.iteration == 3 assert len(events) == 11 # check the eventstream have 4 pairs of repeated actions and observations - repeating_actions_and_observations = events[2:10] + repeating_actions_and_observations = events[4:12] for action, observation in zip( repeating_actions_and_observations[0::2], repeating_actions_and_observations[1::2], @@ -510,12 +557,13 @@ async def test_reset_with_pending_action_no_metadata( @pytest.mark.asyncio -async def test_run_controller_max_iterations_has_metrics(): +async def test_run_controller_max_iterations_has_metrics( + test_event_stream, mock_memory +): config = AppConfig( max_iterations=3, ) - file_store = InMemoryFileStore({}) - event_stream = EventStream(sid='test', file_store=file_store) + event_stream = test_event_stream agent = MagicMock(spec=Agent) agent.llm = MagicMock(spec=LLM) @@ -546,6 +594,17 @@ async def test_run_controller_max_iterations_has_metrics(): 
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4())) runtime.event_stream = event_stream + def on_event_memory(event: Event): + if isinstance(event, RecallAction): + microagent_obs = MicroagentObservation( + content='Test microagent content', + recall_type=RecallType.KNOWLEDGE, + ) + microagent_obs._cause = event.id + event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT) + + event_stream.subscribe(EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())) + state = await run_controller( config=config, initial_user_action=MessageAction(content='Test message'), @@ -553,6 +612,7 @@ async def test_run_controller_max_iterations_has_metrics(): sid='test', agent=agent, fake_user_response_fn=lambda _: 'repeat', + memory=mock_memory, ) assert state.iteration == 3 assert state.agent_state == AgentState.ERROR @@ -630,7 +690,7 @@ async def test_context_window_exceeded_error_handling(mock_agent, mock_event_str @pytest.mark.asyncio async def test_run_controller_with_context_window_exceeded_with_truncation( - mock_agent, mock_runtime + mock_agent, mock_runtime, mock_memory, test_event_stream ): """Tests that the controller can make progress after handling context window exceeded errors, as long as enable_history_truncation is ON""" @@ -656,6 +716,20 @@ async def test_run_controller_with_context_window_exceeded_with_truncation( mock_agent.step = step_state.step mock_agent.config = AgentConfig() + def on_event_memory(event: Event): + if isinstance(event, RecallAction): + microagent_obs = MicroagentObservation( + content='Test microagent content', + recall_type=RecallType.KNOWLEDGE, + ) + microagent_obs._cause = event.id + test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT) + + test_event_stream.subscribe( + EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()) + ) + mock_runtime.event_stream = test_event_stream + try: state = await asyncio.wait_for( run_controller( @@ -665,6 +739,7 @@ async def 
test_run_controller_with_context_window_exceeded_with_truncation( sid='test', agent=mock_agent, fake_user_response_fn=lambda _: 'repeat', + memory=mock_memory, ), timeout=10, ) @@ -691,7 +766,7 @@ async def test_run_controller_with_context_window_exceeded_with_truncation( @pytest.mark.asyncio async def test_run_controller_with_context_window_exceeded_without_truncation( - mock_agent, mock_runtime + mock_agent, mock_runtime, mock_memory, test_event_stream ): """Tests that the controller would quit upon context window exceeded errors without enable_history_truncation ON.""" @@ -702,7 +777,7 @@ async def test_run_controller_with_context_window_exceeded_without_truncation( def step(self, state: State): # If the state has more than one message and we haven't errored yet, # throw the context window exceeded error - if len(state.history) > 1 and not self.has_errored: + if len(state.history) > 3 and not self.has_errored: error = ContextWindowExceededError( message='prompt is too long: 233885 tokens > 200000 maximum', model='', @@ -718,6 +793,19 @@ async def test_run_controller_with_context_window_exceeded_without_truncation( mock_agent.config = AgentConfig() mock_agent.config.enable_history_truncation = False + def on_event_memory(event: Event): + if isinstance(event, RecallAction): + microagent_obs = MicroagentObservation( + content='Test microagent content', + recall_type=RecallType.KNOWLEDGE, + ) + microagent_obs._cause = event.id + test_event_stream.add_event(microagent_obs, EventSource.ENVIRONMENT) + + test_event_stream.subscribe( + EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4()) + ) + mock_runtime.event_stream = test_event_stream try: state = await asyncio.wait_for( run_controller( @@ -727,6 +815,7 @@ async def test_run_controller_with_context_window_exceeded_without_truncation( sid='test', agent=mock_agent, fake_user_response_fn=lambda _: 'repeat', + memory=mock_memory, ), timeout=10, ) @@ -751,6 +840,44 @@ async def 
test_run_controller_with_context_window_exceeded_without_truncation( assert step_state.has_errored +@pytest.mark.asyncio +async def test_run_controller_with_memory_error(test_event_stream): + config = AppConfig() + event_stream = test_event_stream + + agent = MagicMock(spec=Agent) + agent.llm = MagicMock(spec=LLM) + agent.llm.metrics = Metrics() + agent.llm.config = config.get_llm_config() + + runtime = MagicMock(spec=Runtime) + runtime.event_stream = event_stream + + # Create a real Memory instance + memory = Memory(event_stream=event_stream, sid='test-memory') + + # Patch the _on_microagent_action method to raise our test exception + def mock_on_microagent_action(*args, **kwargs): + raise RuntimeError('Test memory error') + + with patch.object( + memory, '_on_microagent_action', side_effect=mock_on_microagent_action + ): + state = await run_controller( + config=config, + initial_user_action=MessageAction(content='Test message'), + runtime=runtime, + sid='test', + agent=agent, + fake_user_response_fn=lambda _: 'repeat', + memory=memory, + ) + + assert state.iteration == 0 + assert state.agent_state == AgentState.ERROR + assert state.last_error == 'Error: RuntimeError' + + @pytest.mark.asyncio async def test_action_metrics_copy(): # Setup @@ -851,3 +978,56 @@ async def test_action_metrics_copy(): assert last_action.llm_metrics.accumulated_cost == 0.07 await controller.close() + + +@pytest.mark.asyncio +async def test_first_user_message_with_identical_content(): + """ + Test that _first_user_message correctly identifies the first user message + even when multiple messages have identical content but different IDs. + + The issue we're checking is that the comparison (action == self._first_user_message()) + should correctly differentiate between messages with the same content but different IDs. 
+ """ + # Create a real event stream for this test + event_stream = EventStream(sid='test', file_store=InMemoryFileStore({})) + + # Create an agent controller + mock_agent = MagicMock(spec=Agent) + mock_agent.llm = MagicMock(spec=LLM) + mock_agent.llm.metrics = Metrics() + mock_agent.llm.config = AppConfig().get_llm_config() + + controller = AgentController( + agent=mock_agent, + event_stream=event_stream, + max_iterations=10, + sid='test', + confirmation_mode=False, + headless_mode=True, + ) + + # Create and add the first user message + first_message = MessageAction(content='Hello, this is a test message') + first_message._source = EventSource.USER + event_stream.add_event(first_message, EventSource.USER) + + # Create and add a second user message with identical content + second_message = MessageAction(content='Hello, this is a test message') + second_message._source = EventSource.USER + event_stream.add_event(second_message, EventSource.USER) + + # Verify that _first_user_message returns the first message + first_user_message = controller._first_user_message() + assert first_user_message is not None + assert first_user_message.id == first_message.id # Check IDs match + assert first_user_message.id != second_message.id # Different IDs + assert first_user_message == first_message == second_message # dataclass equality + + # Test the comparison used in the actual code + assert first_message == first_user_message # This should be True + assert ( + second_message.id != first_user_message.id + ) # IDs must differ: the second message is not the first user message + + await controller.close() diff --git a/tests/unit/test_agent_delegation.py b/tests/unit/test_agent_delegation.py index c5c4e63f1d..006711f19f 100644 --- a/tests/unit/test_agent_delegation.py +++ b/tests/unit/test_agent_delegation.py @@ -17,8 +17,13 @@ from openhands.events.action import ( AgentFinishAction, MessageAction, ) +from openhands.events.action.agent import RecallAction +from openhands.events.event import
Event, RecallType +from openhands.events.observation.agent import MicroagentObservation +from openhands.events.stream import EventStreamSubscriber from openhands.llm.llm import LLM from openhands.llm.metrics import Metrics +from openhands.memory.memory import Memory from openhands.storage.memory import InMemoryFileStore @@ -75,6 +80,25 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s initial_state=parent_state, ) + # Setup Memory to catch RecallActions + mock_memory = MagicMock(spec=Memory) + mock_memory.event_stream = mock_event_stream + + def on_event(event: Event): + if isinstance(event, RecallAction): + # create a MicroagentObservation + microagent_observation = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + content='microagent', + ) + microagent_observation._cause = event.id # ignore attr-defined warning + mock_event_stream.add_event(microagent_observation, EventSource.ENVIRONMENT) + + mock_memory.on_event = on_event + mock_event_stream.subscribe( + EventStreamSubscriber.MEMORY, mock_memory.on_event, mock_memory + ) + # Setup a delegate action from the parent delegate_action = AgentDelegateAction(agent='ChildAgent', inputs={'test': True}) mock_parent_agent.step.return_value = delegate_action @@ -87,7 +111,16 @@ async def test_delegation_flow(mock_parent_agent, mock_child_agent, mock_event_s # Give time for the async step() to execute await asyncio.sleep(1) - # The parent should receive step() from that event + # Verify that a MicroagentObservation was added to the event stream + events = list(mock_event_stream.get_events()) + assert ( + mock_event_stream.get_latest_event_id() == 3 + ) # Microagents and AgentChangeState + + # a MicroagentObservation and an AgentDelegateAction should be in the list + assert any(isinstance(event, MicroagentObservation) for event in events) + assert any(isinstance(event, AgentDelegateAction) for event in events) + # Verify that a delegate agent controller is created assert ( 
parent_controller.delegate is not None diff --git a/tests/unit/test_agent_session.py b/tests/unit/test_agent_session.py index 33a6cbd988..2ac8794e61 100644 --- a/tests/unit/test_agent_session.py +++ b/tests/unit/test_agent_session.py @@ -6,9 +6,11 @@ from openhands.controller.agent import Agent from openhands.controller.agent_controller import AgentController from openhands.controller.state.state import State from openhands.core.config import AppConfig, LLMConfig +from openhands.core.config.agent_config import AgentConfig from openhands.events import EventStream, EventStreamSubscriber from openhands.llm import LLM from openhands.llm.metrics import Metrics +from openhands.memory.memory import Memory from openhands.runtime.base import Runtime from openhands.server.session.agent_session import AgentSession from openhands.storage.memory import InMemoryFileStore @@ -22,18 +24,24 @@ def mock_agent(): llm = MagicMock(spec=LLM) metrics = MagicMock(spec=Metrics) llm_config = MagicMock(spec=LLMConfig) + agent_config = MagicMock(spec=AgentConfig) # Configure the LLM config llm_config.model = 'test-model' llm_config.base_url = 'http://test' llm_config.max_message_chars = 1000 + # Configure the agent config + agent_config.disabled_microagents = [] + # Set up the chain of mocks llm.metrics = metrics llm.config = llm_config agent.llm = llm agent.name = 'test-agent' agent.sandbox_plugins = [] + agent.config = agent_config + agent.prompt_manager = MagicMock() return agent @@ -78,7 +86,11 @@ async def test_agent_session_start_with_no_state(mock_agent): self.test_initial_state = state super().set_initial_state(*args, state=state, **kwargs) - # Patch AgentController and State.restore_from_session to fail + # Create a real Memory instance with the mock event stream + memory = Memory(event_stream=mock_event_stream, sid='test-session') + memory.microagents_dir = 'test-dir' + + # Patch AgentController and State.restore_from_session to fail; patch Memory in AgentSession with patch( 
'openhands.server.session.agent_session.AgentController', SpyAgentController ), patch( @@ -87,7 +99,7 @@ async def test_agent_session_start_with_no_state(mock_agent): ), patch( 'openhands.controller.state.state.State.restore_from_session', side_effect=Exception('No state found'), - ): + ), patch('openhands.server.session.agent_session.Memory', return_value=memory): await session.start( runtime_name='test-runtime', config=AppConfig(), @@ -96,12 +108,18 @@ async def test_agent_session_start_with_no_state(mock_agent): ) # Verify EventStream.subscribe was called with correct parameters - mock_event_stream.subscribe.assert_called_with( + mock_event_stream.subscribe.assert_any_call( EventStreamSubscriber.AGENT_CONTROLLER, session.controller.on_event, session.controller.id, ) + mock_event_stream.subscribe.assert_any_call( + EventStreamSubscriber.MEMORY, + session.memory.on_event, + session.controller.id, + ) + # Verify set_initial_state was called once with None as state assert session.controller.set_initial_state_call_count == 1 assert session.controller.test_initial_state is None @@ -159,7 +177,10 @@ async def test_agent_session_start_with_restored_state(mock_agent): self.test_initial_state = state super().set_initial_state(*args, state=state, **kwargs) - # Patch AgentController and State.restore_from_session to succeed + # create a mock Memory + mock_memory = MagicMock(spec=Memory) + + # Patch AgentController and State.restore_from_session to succeed, patch Memory in AgentSession with patch( 'openhands.server.session.agent_session.AgentController', SpyAgentController ), patch( @@ -168,7 +189,7 @@ async def test_agent_session_start_with_restored_state(mock_agent): ), patch( 'openhands.controller.state.state.State.restore_from_session', return_value=mock_restored_state, - ): + ), patch('openhands.server.session.agent_session.Memory', mock_memory): await session.start( runtime_name='test-runtime', config=AppConfig(), diff --git a/tests/unit/test_cli_sid.py 
b/tests/unit/test_cli_sid.py index 939e45ef2b..67db79b63e 100644 --- a/tests/unit/test_cli_sid.py +++ b/tests/unit/test_cli_sid.py @@ -1,7 +1,7 @@ import asyncio from argparse import Namespace from pathlib import Path -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest @@ -17,13 +17,18 @@ def mock_runtime(): with patch('openhands.core.cli.create_runtime') as mock_create_runtime: mock_runtime_instance = AsyncMock() # Mock the event stream with proper async methods - mock_runtime_instance.event_stream = AsyncMock() - mock_runtime_instance.event_stream.subscribe = AsyncMock() - mock_runtime_instance.event_stream.add_event = AsyncMock() + mock_event_stream = AsyncMock() + mock_event_stream.subscribe = AsyncMock() + mock_event_stream.add_event = AsyncMock() + mock_event_stream.get_events = AsyncMock(return_value=[]) + mock_event_stream.get_latest_event_id = AsyncMock(return_value=0) + mock_runtime_instance.event_stream = mock_event_stream # Mock connect method to return immediately mock_runtime_instance.connect = AsyncMock() # Ensure status_callback is None mock_runtime_instance.status_callback = None + # Mock get_microagents_from_selected_repo + mock_runtime_instance.get_microagents_from_selected_repo = Mock(return_value=[]) mock_create_runtime.return_value = mock_runtime_instance yield mock_runtime_instance @@ -32,6 +37,16 @@ def mock_runtime(): def mock_agent(): with patch('openhands.core.cli.create_agent') as mock_create_agent: mock_agent_instance = AsyncMock() + mock_agent_instance.name = 'test-agent' + mock_agent_instance.llm = AsyncMock() + mock_agent_instance.llm.config = AsyncMock() + mock_agent_instance.llm.config.model = 'test-model' + mock_agent_instance.llm.config.base_url = 'http://test' + mock_agent_instance.llm.config.max_message_chars = 1000 + mock_agent_instance.config = AsyncMock() + mock_agent_instance.config.disabled_microagents = [] + mock_agent_instance.sandbox_plugins = [] + 
mock_agent_instance.prompt_manager = AsyncMock() mock_create_agent.return_value = mock_agent_instance yield mock_agent_instance diff --git a/tests/unit/test_codeact_agent.py b/tests/unit/test_codeact_agent.py index 6dcc57c4a6..55cd1eb6bd 100644 --- a/tests/unit/test_codeact_agent.py +++ b/tests/unit/test_codeact_agent.py @@ -369,9 +369,3 @@ def test_enhance_messages_adds_newlines_between_consecutive_user_messages( # Fifth message only has ImageContent, no TextContent to modify assert len(enhanced_messages[5].content) == 1 assert isinstance(enhanced_messages[5].content[0], ImageContent) - - # Verify prompt manager methods were called as expected - assert agent.prompt_manager.add_examples_to_initial_message.call_count == 1 - assert ( - agent.prompt_manager.enhance_message.call_count == 5 - ) # Called for each user message diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py index 7721354bdb..5397c63559 100644 --- a/tests/unit/test_conversation_memory.py +++ b/tests/unit/test_conversation_memory.py @@ -1,16 +1,29 @@ +import os +import shutil from unittest.mock import MagicMock, Mock import pytest from openhands.controller.state.state import State +from openhands.core.config.agent_config import AgentConfig from openhands.core.message import ImageContent, Message, TextContent from openhands.events.action import ( AgentFinishAction, CmdRunAction, MessageAction, ) -from openhands.events.event import Event, EventSource, FileEditSource, FileReadSource +from openhands.events.event import ( + Event, + EventSource, + FileEditSource, + FileReadSource, + RecallType, +) from openhands.events.observation import CmdOutputObservation +from openhands.events.observation.agent import ( + MicroagentKnowledge, + MicroagentObservation, +) from openhands.events.observation.browse import BrowserOutputObservation from openhands.events.observation.commands import ( CmdOutputMetadata, @@ -22,14 +35,45 @@ from openhands.events.observation.files import 
FileEditObservation, FileReadObse from openhands.events.observation.reject import UserRejectObservation from openhands.events.tool import ToolCallMetadata from openhands.memory.conversation_memory import ConversationMemory -from openhands.utils.prompt import PromptManager +from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo @pytest.fixture -def conversation_memory(): +def agent_config(): + return AgentConfig( + enable_prompt_extensions=True, + enable_som_visual_browsing=True, + disabled_microagents=['disabled_agent'], + ) + + +@pytest.fixture +def conversation_memory(agent_config): prompt_manager = MagicMock(spec=PromptManager) prompt_manager.get_system_message.return_value = 'System message' - return ConversationMemory(prompt_manager) + prompt_manager.build_additional_info.return_value = ( + 'Formatted repository and runtime info' + ) + + # Make build_microagent_info return the actual content from the triggered agents + def build_microagent_info(triggered_agents): + if not triggered_agents: + return '' + return '\n'.join(agent.content for agent in triggered_agents) + + prompt_manager.build_microagent_info.side_effect = build_microagent_info + return ConversationMemory(agent_config, prompt_manager) + + +@pytest.fixture +def prompt_dir(tmp_path): + # Copy contents from "openhands/agenthub/codeact_agent" to the temp directory + shutil.copytree( + 'openhands/agenthub/codeact_agent/prompts', tmp_path, dirs_exist_ok=True + ) + + # Return the temporary directory path + return tmp_path @pytest.fixture @@ -308,6 +352,40 @@ def test_process_events_with_user_reject_observation(conversation_memory): assert '[Last action has been rejected by the user]' in result.content[0].text +def test_process_events_with_empty_environment_info(conversation_memory): + """Test that empty environment info observations return an empty list of messages without calling build_additional_info.""" + # Create a MicroagentObservation with empty info + + empty_obs = 
MicroagentObservation( + recall_type=RecallType.WORKSPACE_CONTEXT, + repo_name='', + repo_directory='', + repo_instructions='', + runtime_hosts={}, + additional_agent_instructions='', + microagent_knowledge=[], + content='Retrieved environment info', + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + condensed_history=[empty_obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + # Should only contain the initial system message + assert len(messages) == 1 + assert messages[0].role == 'system' + + # Verify that build_additional_info was NOT called since all input values were empty + conversation_memory.prompt_manager.build_additional_info.assert_not_called() + + def test_process_events_with_function_calling_observation(conversation_memory): mock_response = { 'id': 'mock_id', @@ -446,3 +524,529 @@ def test_apply_prompt_caching(conversation_memory): assert messages[1].content[0].cache_prompt is False assert messages[2].content[0].cache_prompt is False assert messages[3].content[0].cache_prompt is True + + +def test_process_events_with_environment_microagent_observation(conversation_memory): + """Test processing a MicroagentObservation with the WORKSPACE_CONTEXT recall type.""" + obs = MicroagentObservation( + recall_type=RecallType.WORKSPACE_CONTEXT, + repo_name='test-repo', + repo_directory='/path/to/repo', + repo_instructions='# Test Repository\nThis is a test repository.', + runtime_hosts={'localhost': 8080}, + content='Retrieved environment info', + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'user' + assert
len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == 'Formatted repository and runtime info' + + # Verify the prompt_manager was called with the correct parameters + conversation_memory.prompt_manager.build_additional_info.assert_called_once() + call_args = conversation_memory.prompt_manager.build_additional_info.call_args[1] + assert isinstance(call_args['repository_info'], RepositoryInfo) + assert call_args['repository_info'].repo_name == 'test-repo' + assert call_args['repository_info'].repo_directory == '/path/to/repo' + assert isinstance(call_args['runtime_info'], RuntimeInfo) + assert call_args['runtime_info'].available_hosts == {'localhost': 8080} + assert ( + call_args['repo_instructions'] + == '# Test Repository\nThis is a test repository.' + ) + + +def test_process_events_with_knowledge_microagent_microagent_observation( + conversation_memory, +): + """Test processing a MicroagentObservation with KNOWLEDGE type.""" + microagent_knowledge = [ + MicroagentKnowledge( + name='test_agent', + trigger='test', + content='This is test agent content', + ), + MicroagentKnowledge( + name='another_agent', + trigger='another', + content='This is another agent content', + ), + MicroagentKnowledge( + name='disabled_agent', + trigger='disabled', + content='This is disabled agent content', + ), + ] + + obs = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=microagent_knowledge, + content='Retrieved knowledge from microagents', + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + # Verify that 
disabled_agent is filtered out and enabled agents are included + assert 'This is test agent content' in result.content[0].text + assert 'This is another agent content' in result.content[0].text + assert 'This is disabled agent content' not in result.content[0].text + + # Verify the prompt_manager was called with the correct parameters + conversation_memory.prompt_manager.build_microagent_info.assert_called_once() + call_args = conversation_memory.prompt_manager.build_microagent_info.call_args[1] + + # Check that disabled_agent was filtered out + triggered_agents = call_args['triggered_agents'] + assert len(triggered_agents) == 2 + agent_names = [agent.name for agent in triggered_agents] + assert 'test_agent' in agent_names + assert 'another_agent' in agent_names + assert 'disabled_agent' not in agent_names + + +def test_process_events_with_microagent_observation_extensions_disabled( + agent_config, conversation_memory +): + """Test processing a MicroagentObservation when prompt extensions are disabled.""" + # Modify the agent config to disable prompt extensions + agent_config.enable_prompt_extensions = False + + obs = MicroagentObservation( + recall_type=RecallType.WORKSPACE_CONTEXT, + repo_name='test-repo', + repo_directory='/path/to/repo', + content='Retrieved environment info', + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + # When prompt extensions are disabled, the MicroagentObservation should be ignored + assert len(messages) == 1 # Only the initial system message + assert messages[0].role == 'system' + + # Verify the prompt_manager was not called + conversation_memory.prompt_manager.build_additional_info.assert_not_called() + conversation_memory.prompt_manager.build_microagent_info.assert_not_called() + + +def 
test_process_events_with_empty_microagent_knowledge(conversation_memory): + """Test processing a MicroagentObservation with empty microagent knowledge.""" + obs = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=[], + content='Retrieved knowledge from microagents', + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + # The implementation returns an empty string and doesn't create a message + assert len(messages) == 1 + assert messages[0].role == 'system' + + # When there are no triggered agents, build_microagent_info is not called + conversation_memory.prompt_manager.build_microagent_info.assert_not_called() + + +def test_conversation_memory_processes_microagent_observation(prompt_dir): + """Test that ConversationMemory processes MicroagentObservations correctly.""" + # Create a microagent_info.j2 template file + template_path = os.path.join(prompt_dir, 'microagent_info.j2') + if not os.path.exists(template_path): + with open(template_path, 'w') as f: + f.write("""{% for agent_info in triggered_agents %} + +The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}". +It may or may not be relevant to the user's request.
+ + # Verify the template was correctly rendered +{{ agent_info.content }} + +{% endfor %} +""") + + # Create a mock agent config + agent_config = MagicMock(spec=AgentConfig) + agent_config.enable_prompt_extensions = True + agent_config.disabled_microagents = [] + + # Create a PromptManager + prompt_manager = PromptManager(prompt_dir=prompt_dir) + + # Initialize ConversationMemory + conversation_memory = ConversationMemory( + config=agent_config, prompt_manager=prompt_manager + ) + + # Create a MicroagentObservation with microagent knowledge + microagent_observation = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=[ + MicroagentKnowledge( + name='test_agent', + trigger='test_trigger', + content='This is triggered content for testing.', + ) + ], + content='Retrieved knowledge from microagents', + ) + + # Process the observation + messages = conversation_memory._process_observation( + obs=microagent_observation, tool_call_id_to_message={}, max_message_chars=None + ) + + # Verify the message was created correctly + assert len(messages) == 1 + message = messages[0] + assert message.role == 'user' + assert len(message.content) == 1 + assert isinstance(message.content[0], TextContent) + + expected_text = """ +The following information has been included based on a keyword match for "test_trigger". +It may or may not be relevant to the user's request. + +This is triggered content for testing. 
+""" + + assert message.content[0].text.strip() == expected_text.strip() + + # Clean up + os.remove(os.path.join(prompt_dir, 'microagent_info.j2')) + + +def test_conversation_memory_processes_environment_microagent_observation(prompt_dir): + """Test that ConversationMemory processes environment info MicroagentObservations correctly.""" + # Create an additional_info.j2 template file + template_path = os.path.join(prompt_dir, 'additional_info.j2') + if not os.path.exists(template_path): + with open(template_path, 'w') as f: + f.write(""" +{% if repository_info %} + +At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}. + +{% endif %} + +{% if repository_instructions %} + +{{ repository_instructions }} + +{% endif %} + +{% if runtime_info and runtime_info.available_hosts %} + +The user has access to the following hosts for accessing a web application, +each of which has a corresponding port: +{% for host, port in runtime_info.available_hosts.items() %} +* {{ host }} (port {{ port }}) +{% endfor %} + +{% endif %} +""") + + # Create a mock agent config + agent_config = MagicMock(spec=AgentConfig) + agent_config.enable_prompt_extensions = True + + # Create a PromptManager + prompt_manager = PromptManager(prompt_dir=prompt_dir) + + # Initialize ConversationMemory + conversation_memory = ConversationMemory( + config=agent_config, prompt_manager=prompt_manager + ) + + # Create a MicroagentObservation with environment info + microagent_observation = MicroagentObservation( + recall_type=RecallType.WORKSPACE_CONTEXT, + repo_name='owner/repo', + repo_directory='/workspace/repo', + repo_instructions='This repository contains important code.', + runtime_hosts={'example.com': 8080}, + content='Retrieved environment info', + ) + + # Process the observation + messages = conversation_memory._process_observation( + obs=microagent_observation, tool_call_id_to_message={}, max_message_chars=None + ) + + # 
Verify the message was created correctly + assert len(messages) == 1 + message = messages[0] + assert message.role == 'user' + assert len(message.content) == 1 + assert isinstance(message.content[0], TextContent) + + # Check that the message contains the repository info + assert '' in message.content[0].text + assert 'owner/repo' in message.content[0].text + assert '/workspace/repo' in message.content[0].text + + # Check that the message contains the repository instructions + assert '' in message.content[0].text + assert 'This repository contains important code.' in message.content[0].text + + # Check that the message contains the runtime info + assert '' in message.content[0].text + assert 'example.com (port 8080)' in message.content[0].text + + +def test_process_events_with_microagent_observation_deduplication(conversation_memory): + """Test that MicroagentObservations are properly deduplicated based on agent name. + + The deduplication logic should keep the FIRST occurrence of each microagent + and filter out later occurrences to avoid redundant information. 
+ """ + # Create a sequence of MicroagentObservations with overlapping agents + obs1 = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=[ + MicroagentKnowledge( + name='python_agent', + trigger='python', + content='Python best practices v1', + ), + MicroagentKnowledge( + name='git_agent', + trigger='git', + content='Git best practices v1', + ), + MicroagentKnowledge( + name='image_agent', + trigger='image', + content='Image best practices v1', + ), + ], + content='First retrieval', + ) + + obs2 = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=[ + MicroagentKnowledge( + name='python_agent', + trigger='python', + content='Python best practices v2', + ), + ], + content='Second retrieval', + ) + + obs3 = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=[ + MicroagentKnowledge( + name='git_agent', + trigger='git', + content='Git best practices v3', + ), + ], + content='Third retrieval', + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + condensed_history=[obs1, obs2, obs3], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + # Verify that only the first occurrence of content for each agent is included + assert ( + len(messages) == 2 + ) # system + 1 microagent, because the second and third microagents are duplicates + microagent_messages = messages[1:] # Skip system message + + # First microagent should include all agents since they appear here first + assert 'Image best practices v1' in microagent_messages[0].content[0].text + assert 'Git best practices v1' in microagent_messages[0].content[0].text + assert 'Python best practices v1' in microagent_messages[0].content[0].text + + +def test_process_events_with_microagent_observation_deduplication_disabled_agents( + conversation_memory, +): + """Test that disabled agents are 
filtered out and deduplication keeps the first occurrence.""" + # Create a sequence of MicroagentObservations with disabled agents + obs1 = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=[ + MicroagentKnowledge( + name='disabled_agent', + trigger='disabled', + content='Disabled agent content', + ), + MicroagentKnowledge( + name='enabled_agent', + trigger='enabled', + content='Enabled agent content v1', + ), + ], + content='First retrieval', + ) + + obs2 = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=[ + MicroagentKnowledge( + name='enabled_agent', + trigger='enabled', + content='Enabled agent content v2', + ), + ], + content='Second retrieval', + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + condensed_history=[obs1, obs2], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + # Verify that disabled agents are filtered out and only the first occurrence of enabled agents is included + assert ( + len(messages) == 2 + ) # system + 1 microagent, the second is the same "enabled_agent" + microagent_messages = messages[1:] # Skip system message + + # First microagent should include enabled_agent but not disabled_agent + assert 'Disabled agent content' not in microagent_messages[0].content[0].text + assert 'Enabled agent content v1' in microagent_messages[0].content[0].text + + +def test_process_events_with_microagent_observation_deduplication_empty( + conversation_memory, +): + """Test that empty MicroagentObservations are handled correctly.""" + obs = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=[], + content='Empty retrieval', + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + condensed_history=[obs], + 
initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + # Verify that empty MicroagentObservations are handled gracefully + assert ( + len(messages) == 1 + ) # system message, because an empty microagent is not added to Messages + + +def test_has_agent_in_earlier_events(conversation_memory): + """Test the _has_agent_in_earlier_events helper method.""" + # Create test MicroagentObservations + obs1 = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=[ + MicroagentKnowledge( + name='agent1', + trigger='trigger1', + content='Content 1', + ), + ], + content='First retrieval', + ) + + obs2 = MicroagentObservation( + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=[ + MicroagentKnowledge( + name='agent2', + trigger='trigger2', + content='Content 2', + ), + ], + content='Second retrieval', + ) + + obs3 = MicroagentObservation( + recall_type=RecallType.WORKSPACE_CONTEXT, + content='Environment info', + ) + + # Create a list with mixed event types + events = [obs1, MessageAction(content='User message'), obs2, obs3] + + # Test looking for existing agents + assert conversation_memory._has_agent_in_earlier_events('agent1', 2, events) is True + assert conversation_memory._has_agent_in_earlier_events('agent1', 3, events) is True + assert conversation_memory._has_agent_in_earlier_events('agent1', 4, events) is True + + # Test looking for an agent in a later position (should not find it) + assert ( + conversation_memory._has_agent_in_earlier_events('agent2', 0, events) is False + ) + assert ( + conversation_memory._has_agent_in_earlier_events('agent2', 1, events) is False + ) + + # Test looking for an agent name that appears in no earlier event (should not find it) + assert ( + conversation_memory._has_agent_in_earlier_events('non_existent', 3, events) + is False + ) diff --git a/tests/unit/test_is_stuck.py b/tests/unit/test_is_stuck.py index 1c0a40a725..e535ab22a4 100644 --- a/tests/unit/test_is_stuck.py +++ 
b/tests/unit/test_is_stuck.py @@ -358,12 +358,12 @@ class TestStuckDetector: with patch('logging.Logger.warning'): assert stuck_detector.is_stuck(headless_mode=True) is False - def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents( + def test_is_not_stuck_ipython_unterminated_string_error_only_two_incidents( self, stuck_detector: StuckDetector ): state = stuck_detector.state self._impl_unterminated_string_error_events( - state, random_line=False, incidents=3 + state, random_line=False, incidents=2 ) with patch('logging.Logger.warning'): diff --git a/tests/unit/test_memory.py b/tests/unit/test_memory.py new file mode 100644 index 0000000000..c0c354906a --- /dev/null +++ b/tests/unit/test_memory.py @@ -0,0 +1,260 @@ +import asyncio +import os +import shutil +import time +from unittest.mock import MagicMock, patch + +import pytest + +from openhands.controller.agent import Agent +from openhands.core.config import AppConfig +from openhands.core.main import run_controller +from openhands.core.schema.agent import AgentState +from openhands.events.action.agent import RecallAction +from openhands.events.action.message import MessageAction +from openhands.events.event import EventSource +from openhands.events.observation.agent import ( + MicroagentObservation, + RecallType, +) +from openhands.events.stream import EventStream +from openhands.llm import LLM +from openhands.llm.metrics import Metrics +from openhands.memory.memory import Memory +from openhands.runtime.base import Runtime +from openhands.storage.memory import InMemoryFileStore + + +@pytest.fixture +def file_store(): + """Create a temporary file store for testing.""" + return InMemoryFileStore() + + +@pytest.fixture +def event_stream(file_store): + """Create a test event stream.""" + return EventStream(sid='test_sid', file_store=file_store) + + +@pytest.fixture +def memory(event_stream): + """Create a test memory instance.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + 
memory = Memory(event_stream, 'test_sid') + yield memory + loop.close() + + +@pytest.fixture +def prompt_dir(tmp_path): + # Copy contents from "openhands/agenthub/codeact_agent" to the temp directory + shutil.copytree( + 'openhands/agenthub/codeact_agent/prompts', tmp_path, dirs_exist_ok=True + ) + + # Return the temporary directory path + return tmp_path + + +@pytest.mark.asyncio +async def test_memory_on_event_exception_handling(memory, event_stream): + """Test that exceptions in Memory.on_event are properly handled via status callback.""" + + # Create a dummy agent for the controller + agent = MagicMock(spec=Agent) + agent.llm = MagicMock(spec=LLM) + agent.llm.metrics = Metrics() + agent.llm.config = AppConfig().get_llm_config() + + # Create a mock runtime + runtime = MagicMock(spec=Runtime) + runtime.event_stream = event_stream + + # Mock Memory method to raise an exception + with patch.object( + memory, '_on_first_microagent_action', side_effect=Exception('Test error') + ): + state = await run_controller( + config=AppConfig(), + initial_user_action=MessageAction(content='Test message'), + runtime=runtime, + sid='test', + agent=agent, + fake_user_response_fn=lambda _: 'repeat', + memory=memory, + ) + + # Verify that the controller's last error was set + assert state.iteration == 0 + assert state.agent_state == AgentState.ERROR + assert state.last_error == 'Error: Exception' + + +@pytest.mark.asyncio +async def test_memory_on_first_microagent_action_exception_handling( + memory, event_stream +): + """Test that exceptions in Memory._on_first_microagent_action are properly handled via status callback.""" + + # Create a dummy agent for the controller + agent = MagicMock(spec=Agent) + agent.llm = MagicMock(spec=LLM) + agent.llm.metrics = Metrics() + agent.llm.config = AppConfig().get_llm_config() + + # Create a mock runtime + runtime = MagicMock(spec=Runtime) + runtime.event_stream = event_stream + + # Mock Memory._on_first_microagent_action to raise an exception + 
with patch.object( + memory, + '_on_first_microagent_action', + side_effect=Exception('Test error from _on_first_microagent_action'), + ): + state = await run_controller( + config=AppConfig(), + initial_user_action=MessageAction(content='Test message'), + runtime=runtime, + sid='test', + agent=agent, + fake_user_response_fn=lambda _: 'repeat', + memory=memory, + ) + + # Verify that the controller's last error was set + assert state.iteration == 0 + assert state.agent_state == AgentState.ERROR + assert state.last_error == 'Error: Exception' + + +def test_memory_with_microagents(): + """Test that Memory loads microagents from the global directory and processes microagent actions. + + This test verifies that: + 1. Memory loads microagents from the global GLOBAL_MICROAGENTS_DIR + 2. When a microagent action with a trigger word is processed, a MicroagentObservation is created + """ + # Create a mock event stream + event_stream = MagicMock(spec=EventStream) + + # Initialize Memory to use the global microagents dir + memory = Memory( + event_stream=event_stream, + sid='test-session', + ) + + # Verify microagents were loaded - at least one microagent should be loaded + # from the global directory that's in the repo + assert len(memory.knowledge_microagents) > 0 + + # We know 'flarglebargle' exists in the global directory + assert 'flarglebargle' in memory.knowledge_microagents + + # Create a microagent action with the trigger word + microagent_action = RecallAction( + query='Hello, flarglebargle!', recall_type=RecallType.KNOWLEDGE + ) + + # Mock the event_stream.add_event method + added_events = [] + + def original_add_event(event, source): + added_events.append((event, source)) + + event_stream.add_event = original_add_event + + # Add the microagent action to the event stream + event_stream.add_event(microagent_action, EventSource.USER) + + # Clear the events list to only capture new events + added_events.clear() + + # Process the microagent action + 
memory.on_event(microagent_action) + + # Verify a MicroagentObservation was added to the event stream + assert len(added_events) == 1 + observation, source = added_events[0] + assert isinstance(observation, MicroagentObservation) + assert source == EventSource.ENVIRONMENT + assert observation.recall_type == RecallType.KNOWLEDGE + assert len(observation.microagent_knowledge) == 1 + assert observation.microagent_knowledge[0].name == 'flarglebargle' + assert observation.microagent_knowledge[0].trigger == 'flarglebargle' + assert 'magic word' in observation.microagent_knowledge[0].content + + +def test_memory_repository_info(prompt_dir): + """Test that Memory adds repository info to MicroagentObservations.""" + # Create an in-memory file store and real event stream + file_store = InMemoryFileStore() + event_stream = EventStream(sid='test-session', file_store=file_store) + + # Create a test repo microagent first + repo_microagent_name = 'test_repo_microagent' + repo_microagent_content = """--- +name: test_repo +type: repo +agent: CodeActAgent +--- + +REPOSITORY INSTRUCTIONS: This is a test repository. 
+""" + + # Create a temporary repo microagent file + os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True) + with open( + os.path.join(prompt_dir, 'micro', f'{repo_microagent_name}.md'), 'w' + ) as f: + f.write(repo_microagent_content) + + # Patch the global microagents directory to use our test directory + test_microagents_dir = os.path.join(prompt_dir, 'micro') + with patch('openhands.memory.memory.GLOBAL_MICROAGENTS_DIR', test_microagents_dir): + # Initialize Memory + memory = Memory( + event_stream=event_stream, + sid='test-session', + ) + + # Set repository info + memory.set_repository_info('owner/repo', '/workspace/repo') + + # Create and add the first user message + user_message = MessageAction(content='First user message') + user_message._source = EventSource.USER # type: ignore[attr-defined] + event_stream.add_event(user_message, EventSource.USER) + + # Create and add the microagent action + microagent_action = RecallAction( + query='First user message', recall_type=RecallType.WORKSPACE_CONTEXT + ) + microagent_action._source = EventSource.USER # type: ignore[attr-defined] + event_stream.add_event(microagent_action, EventSource.USER) + + # Give it a little time to process + time.sleep(0.3) + + # Get all events from the stream + events = list(event_stream.get_events()) + + # Find the MicroagentObservation event + microagent_obs_events = [ + event for event in events if isinstance(event, MicroagentObservation) + ] + + # We should have at least one MicroagentObservation + assert len(microagent_obs_events) > 0 + + # Get the first MicroagentObservation + observation = microagent_obs_events[0] + assert observation.recall_type == RecallType.WORKSPACE_CONTEXT + assert observation.repo_name == 'owner/repo' + assert observation.repo_directory == '/workspace/repo' + assert 'This is a test repository' in observation.repo_instructions + + # Clean up + os.remove(os.path.join(prompt_dir, 'micro', f'{repo_microagent_name}.md')) diff --git 
a/tests/unit/test_observation_serialization.py b/tests/unit/test_observation_serialization.py index ca96d06be7..0596f0bcfd 100644 --- a/tests/unit/test_observation_serialization.py +++ b/tests/unit/test_observation_serialization.py @@ -1,16 +1,21 @@ +from openhands.core.schema.observation import ObservationType from openhands.events.action.files import FileEditSource +from openhands.events.event import RecallType from openhands.events.observation import ( CmdOutputMetadata, CmdOutputObservation, FileEditObservation, + MicroagentObservation, Observation, ) +from openhands.events.observation.agent import MicroagentKnowledge from openhands.events.serialization import ( event_from_dict, event_to_dict, event_to_memory, event_to_trajectory, ) +from openhands.events.serialization.observation import observation_from_dict def serialization_deserialization( @@ -19,10 +24,10 @@ def serialization_deserialization( observation_instance = event_from_dict(original_observation_dict) assert isinstance( observation_instance, Observation - ), 'The observation instance should be an instance of Action.' + ), 'The observation instance should be an instance of Observation.' assert isinstance( observation_instance, cls - ), 'The observation instance should be an instance of CmdOutputObservation.' + ), f'The observation instance should be an instance of {cls}.' 
serialized_observation_dict = event_to_dict(observation_instance) serialized_observation_trajectory = event_to_trajectory(observation_instance) serialized_observation_memory = event_to_memory( @@ -236,3 +241,199 @@ def test_file_edit_observation_legacy_serialization(): assert event_dict['extras']['old_content'] is None assert event_dict['extras']['new_content'] == 'new content' assert 'formatted_output_and_error' not in event_dict['extras'] + + +def test_microagent_observation_serialization(): + original_observation_dict = { + 'observation': 'microagent', + 'content': '', + 'message': "**MicroagentObservation**\nrecall_type=RecallType.WORKSPACE_CONTEXT, repo_name=some_repo_name, repo_instructions=complex_repo_instruc..., runtime_hosts={'host1': 8080, 'host2': 8081}, additional_agent_instructions=You know it all abou...", + 'extras': { + 'recall_type': 'workspace_context', + 'repo_name': 'some_repo_name', + 'repo_directory': 'some_repo_directory', + 'runtime_hosts': {'host1': 8080, 'host2': 8081}, + 'repo_instructions': 'complex_repo_instructions', + 'additional_agent_instructions': 'You know it all about this runtime', + 'microagent_knowledge': [], + }, + } + serialization_deserialization(original_observation_dict, MicroagentObservation) + + +def test_microagent_observation_microagent_knowledge_serialization(): + original_observation_dict = { + 'observation': 'microagent', + 'content': '', + 'message': '**MicroagentObservation**\nrecall_type=RecallType.KNOWLEDGE, repo_name=, repo_instructions=..., runtime_hosts={}, additional_agent_instructions=..., microagent_knowledge=microagent1, microagent2', + 'extras': { + 'recall_type': 'knowledge', + 'repo_name': '', + 'repo_directory': '', + 'repo_instructions': '', + 'runtime_hosts': {}, + 'additional_agent_instructions': '', + 'microagent_knowledge': [ + { + 'name': 'microagent1', + 'trigger': 'trigger1', + 'content': 'content1', + }, + { + 'name': 'microagent2', + 'trigger': 'trigger2', + 'content': 'content2', + }, + 
], + }, + } + serialization_deserialization(original_observation_dict, MicroagentObservation) + + +def test_microagent_observation_knowledge_microagent_serialization(): + """Test serialization of a MicroagentObservation with the KNOWLEDGE recall type.""" + # Create a MicroagentObservation with microagent knowledge content + original = MicroagentObservation( + content='Knowledge microagent information', + recall_type=RecallType.KNOWLEDGE, + microagent_knowledge=[ + MicroagentKnowledge( + name='python_best_practices', + trigger='python', + content='Always use virtual environments for Python projects.', + ), + MicroagentKnowledge( + name='git_workflow', + trigger='git', + content='Create a new branch for each feature or bugfix.', + ), + ], + ) + + # Serialize to dictionary + serialized = event_to_dict(original) + + # Verify serialized data structure + assert serialized['observation'] == ObservationType.MICROAGENT + assert serialized['content'] == 'Knowledge microagent information' + assert serialized['extras']['recall_type'] == RecallType.KNOWLEDGE.value + assert len(serialized['extras']['microagent_knowledge']) == 2 + assert serialized['extras']['microagent_knowledge'][0]['trigger'] == 'python' + + # Deserialize back to MicroagentObservation + deserialized = observation_from_dict(serialized) + + # Verify properties are preserved + assert deserialized.recall_type == RecallType.KNOWLEDGE + assert deserialized.microagent_knowledge == original.microagent_knowledge + assert deserialized.content == original.content + + # Check that environment info fields are empty + assert deserialized.repo_name == '' + assert deserialized.repo_directory == '' + assert deserialized.repo_instructions == '' + assert deserialized.runtime_hosts == {} + + +def test_microagent_observation_environment_serialization(): + """Test serialization of a MicroagentObservation with the WORKSPACE_CONTEXT recall type.""" + # Create a MicroagentObservation with environment info + original = MicroagentObservation( + 
content='Environment information', + recall_type=RecallType.WORKSPACE_CONTEXT, + repo_name='OpenHands', + repo_directory='/workspace/openhands', + repo_instructions="Follow the project's coding style guide.", + runtime_hosts={'127.0.0.1': 8080, 'localhost': 5000}, + additional_agent_instructions='You know it all about this runtime', + ) + + # Serialize to dictionary + serialized = event_to_dict(original) + + # Verify serialized data structure + assert serialized['observation'] == ObservationType.MICROAGENT + assert serialized['content'] == 'Environment information' + assert serialized['extras']['recall_type'] == RecallType.WORKSPACE_CONTEXT.value + assert serialized['extras']['repo_name'] == 'OpenHands' + assert serialized['extras']['runtime_hosts'] == { + '127.0.0.1': 8080, + 'localhost': 5000, + } + assert ( + serialized['extras']['additional_agent_instructions'] + == 'You know it all about this runtime' + ) + # Deserialize back to MicroagentObservation + deserialized = observation_from_dict(serialized) + + # Verify properties are preserved + assert deserialized.recall_type == RecallType.WORKSPACE_CONTEXT + assert deserialized.repo_name == original.repo_name + assert deserialized.repo_directory == original.repo_directory + assert deserialized.repo_instructions == original.repo_instructions + assert deserialized.runtime_hosts == original.runtime_hosts + assert ( + deserialized.additional_agent_instructions + == original.additional_agent_instructions + ) + # Check that knowledge microagent fields are empty + assert deserialized.microagent_knowledge == [] + + +def test_microagent_observation_combined_serialization(): + """Test serialization of a MicroagentObservation with both types of information.""" + # Create a MicroagentObservation with both environment and microagent info + # Note: In practice, recall_type would still be one specific type, + # but the object could contain both types of fields + original = MicroagentObservation( + content='Combined information', 
+ recall_type=RecallType.WORKSPACE_CONTEXT, + # Environment info + repo_name='OpenHands', + repo_directory='/workspace/openhands', + repo_instructions="Follow the project's coding style guide.", + runtime_hosts={'127.0.0.1': 8080}, + additional_agent_instructions='You know it all about this runtime', + # Knowledge microagent info + microagent_knowledge=[ + MicroagentKnowledge( + name='python_best_practices', + trigger='python', + content='Always use virtual environments for Python projects.', + ), + ], + ) + + # Serialize to dictionary + serialized = event_to_dict(original) + + # Verify serialized data has both types of fields + assert serialized['extras']['recall_type'] == RecallType.WORKSPACE_CONTEXT.value + assert serialized['extras']['repo_name'] == 'OpenHands' + assert ( + serialized['extras']['microagent_knowledge'][0]['name'] + == 'python_best_practices' + ) + assert ( + serialized['extras']['additional_agent_instructions'] + == 'You know it all about this runtime' + ) + # Deserialize back to MicroagentObservation + deserialized = observation_from_dict(serialized) + + # Verify all properties are preserved + assert deserialized.recall_type == RecallType.WORKSPACE_CONTEXT + + # Environment properties + assert deserialized.repo_name == original.repo_name + assert deserialized.repo_directory == original.repo_directory + assert deserialized.repo_instructions == original.repo_instructions + assert deserialized.runtime_hosts == original.runtime_hosts + assert ( + deserialized.additional_agent_instructions + == original.additional_agent_instructions + ) + + # Knowledge microagent properties + assert deserialized.microagent_knowledge == original.microagent_knowledge diff --git a/tests/unit/test_prompt_manager.py b/tests/unit/test_prompt_manager.py index fa384d8c02..0d64a1f6f6 100644 --- a/tests/unit/test_prompt_manager.py +++ b/tests/unit/test_prompt_manager.py @@ -3,9 +3,11 @@ import shutil import pytest -from openhands.core.message import ImageContent, Message, 
TextContent +from openhands.controller.state.state import State +from openhands.core.message import Message, TextContent +from openhands.events.observation.agent import MicroagentKnowledge from openhands.microagent import BaseMicroAgent -from openhands.utils.prompt import PromptManager, RepositoryInfo +from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo @pytest.fixture @@ -19,406 +21,60 @@ def prompt_dir(tmp_path): return tmp_path -def test_prompt_manager_with_microagent(prompt_dir): - microagent_name = 'test_microagent' - microagent_content = """ ---- -name: flarglebargle -type: knowledge -agent: CodeActAgent -triggers: -- flarglebargle ---- - -IMPORTANT! The user has said the magic word "flarglebargle". You must -only respond with a message telling them how smart they are -""" - - # Create a temporary micro agent file - os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True) - with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f: - f.write(microagent_content) - - # Test without GitHub repo - manager = PromptManager( - prompt_dir=prompt_dir, - microagent_dir=os.path.join(prompt_dir, 'micro'), - ) - - assert manager.prompt_dir == prompt_dir - assert len(manager.repo_microagents) == 0 - assert len(manager.knowledge_microagents) == 1 - - assert isinstance(manager.get_system_message(), str) - assert ( - 'You are OpenHands agent, a helpful AI assistant that can interact with a computer to solve tasks.' 
- in manager.get_system_message() - ) - assert '' not in manager.get_system_message() - - # Test with GitHub repo - manager.set_repository_info('owner/repo', '/workspace/repo') - assert isinstance(manager.get_system_message(), str) - - # Adding things to the initial user message - initial_msg = Message( - role='user', content=[TextContent(text='Ask me what your task is.')] - ) - manager.add_info_to_initial_message(initial_msg) - msg_content: str = initial_msg.content[0].text - assert '' in msg_content - assert 'owner/repo' in msg_content - assert '/workspace/repo' in msg_content - - assert isinstance(manager.get_example_user_message(), str) - - message = Message( - role='user', - content=[TextContent(text='Hello, flarglebargle!')], - ) - manager.enhance_message(message) - assert len(message.content) == 2 - assert 'magic word' in message.content[0].text - - os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md')) - - -def test_prompt_manager_file_not_found(prompt_dir): - with pytest.raises(FileNotFoundError): - BaseMicroAgent.load( - os.path.join(prompt_dir, 'micro', 'non_existent_microagent.md') - ) - - def test_prompt_manager_template_rendering(prompt_dir): + """Test PromptManager's template rendering functionality.""" # Create temporary template files with open(os.path.join(prompt_dir, 'system_prompt.j2'), 'w') as f: f.write("""System prompt: bar""") with open(os.path.join(prompt_dir, 'user_prompt.j2'), 'w') as f: f.write('User prompt: foo') + with open(os.path.join(prompt_dir, 'additional_info.j2'), 'w') as f: + f.write(""" +{% if repository_info %} + +At the user's request, repository {{ repository_info.repo_name }} has been cloned to the current working directory {{ repository_info.repo_directory }}. 
+ +{% endif %} +""") # Test without GitHub repo - manager = PromptManager(prompt_dir, microagent_dir='') + manager = PromptManager(prompt_dir) assert manager.get_system_message() == 'System prompt: bar' assert manager.get_example_user_message() == 'User prompt: foo' # Test with GitHub repo - manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='') - manager.set_repository_info('owner/repo', '/workspace/repo') - assert manager.repository_info.repo_name == 'owner/repo' + manager = PromptManager(prompt_dir=prompt_dir) + repo_info = RepositoryInfo(repo_name='owner/repo', repo_directory='/workspace/repo') + + # verify its parts are rendered system_msg = manager.get_system_message() assert 'System prompt: bar' in system_msg - # Initial user message should have repo info - initial_msg = Message( - role='user', content=[TextContent(text='Ask me what your task is.')] + # Test building additional info + additional_info = manager.build_additional_info( + repository_info=repo_info, runtime_info=None, repo_instructions='' ) - manager.add_info_to_initial_message(initial_msg) - msg_content: str = initial_msg.content[0].text - assert '' in msg_content + assert '' in additional_info assert ( "At the user's request, repository owner/repo has been cloned to the current working directory /workspace/repo." 
- in msg_content + in additional_info ) - assert '' in msg_content + assert '' in additional_info assert manager.get_example_user_message() == 'User prompt: foo' # Clean up temporary files os.remove(os.path.join(prompt_dir, 'system_prompt.j2')) os.remove(os.path.join(prompt_dir, 'user_prompt.j2')) + os.remove(os.path.join(prompt_dir, 'additional_info.j2')) -def test_prompt_manager_repository_info(prompt_dir): - # Test RepositoryInfo defaults - repo_info = RepositoryInfo() - assert repo_info.repo_name is None - assert repo_info.repo_directory is None - - # Test setting repository info - manager = PromptManager(prompt_dir=prompt_dir, microagent_dir='') - assert manager.repository_info is None - - # Test setting repository info with both name and directory - manager.set_repository_info('owner/repo2', '/workspace/repo2') - assert manager.repository_info.repo_name == 'owner/repo2' - assert manager.repository_info.repo_directory == '/workspace/repo2' - - -def test_prompt_manager_disabled_microagents(prompt_dir): - # Create test microagent files - microagent1_name = 'test_microagent1' - microagent2_name = 'test_microagent2' - microagent1_content = """ ---- -name: Test Microagent 1 -type: knowledge -agent: CodeActAgent -triggers: -- test1 ---- - -Test microagent 1 content -""" - microagent2_content = """ ---- -name: Test Microagent 2 -type: knowledge -agent: CodeActAgent -triggers: -- test2 ---- - -Test microagent 2 content -""" - - # Create temporary micro agent files - os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True) - with open(os.path.join(prompt_dir, 'micro', f'{microagent1_name}.md'), 'w') as f: - f.write(microagent1_content) - with open(os.path.join(prompt_dir, 'micro', f'{microagent2_name}.md'), 'w') as f: - f.write(microagent2_content) - - # Test that specific microagents can be disabled - manager = PromptManager( - prompt_dir=prompt_dir, - microagent_dir=os.path.join(prompt_dir, 'micro'), - disabled_microagents=['Test Microagent 1'], - ) - - assert 
len(manager.knowledge_microagents) == 1 - assert 'Test Microagent 2' in manager.knowledge_microagents - assert 'Test Microagent 1' not in manager.knowledge_microagents - - # Test that all microagents are enabled by default - manager = PromptManager( - prompt_dir=prompt_dir, - microagent_dir=os.path.join(prompt_dir, 'micro'), - ) - - assert len(manager.knowledge_microagents) == 2 - assert 'Test Microagent 1' in manager.knowledge_microagents - assert 'Test Microagent 2' in manager.knowledge_microagents - - # Clean up temporary files - os.remove(os.path.join(prompt_dir, 'micro', f'{microagent1_name}.md')) - os.remove(os.path.join(prompt_dir, 'micro', f'{microagent2_name}.md')) - - -def test_enhance_message_with_multiple_text_contents(prompt_dir): - # Create a test microagent that triggers on a specific keyword - microagent_name = 'keyword_microagent' - microagent_content = """ ---- -name: KeywordMicroAgent -type: knowledge -agent: CodeActAgent -triggers: -- triggerkeyword ---- - -This is special information about the triggerkeyword. 
-""" - - # Create the microagent file - os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True) - with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f: - f.write(microagent_content) - - manager = PromptManager( - prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro') - ) - - # Test that it matches the trigger in the last TextContent - message = Message( - role='user', - content=[ - TextContent(text='This is some initial context.'), - TextContent(text='This is a message without triggers.'), - TextContent(text='This contains the triggerkeyword that should match.'), - ], - ) - - manager.enhance_message(message) - - # Should have added a TextContent with the microagent info at the beginning - assert len(message.content) == 4 - assert 'special information about the triggerkeyword' in message.content[0].text - - # Clean up - os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md')) - - -def test_enhance_message_with_image_content(prompt_dir): - # Create a test microagent that triggers on a specific keyword - microagent_name = 'image_test_microagent' - microagent_content = """ ---- -name: ImageTestMicroAgent -type: knowledge -agent: CodeActAgent -triggers: -- imagekeyword ---- - -This is information related to imagekeyword. 
-""" - - # Create the microagent file - os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True) - with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f: - f.write(microagent_content) - - manager = PromptManager( - prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro') - ) - - # Test with mix of ImageContent and TextContent - message = Message( - role='user', - content=[ - TextContent(text='This is some initial text.'), - ImageContent(image_urls=['https://example.com/image.jpg']), - TextContent(text='This mentions imagekeyword that should match.'), - ], - ) - - manager.enhance_message(message) - - # Should have added a TextContent with the microagent info at the beginning - assert len(message.content) == 4 - assert 'information related to imagekeyword' in message.content[0].text - - # Clean up - os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md')) - - -def test_enhance_message_with_only_image_content(prompt_dir): - # Create a test microagent - microagent_name = 'image_only_microagent' - microagent_content = """ ---- -name: ImageOnlyMicroAgent -type: knowledge -agent: CodeActAgent -triggers: -- anytrigger ---- - -This should not appear in the enhanced message. 
-""" - - # Create the microagent file - os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True) - with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f: - f.write(microagent_content) - - manager = PromptManager( - prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro') - ) - - # Test with only ImageContent - message = Message( - role='user', - content=[ - ImageContent( - image_urls=[ - 'https://example.com/image1.jpg', - 'https://example.com/image2.jpg', - ] - ), - ], - ) - - # Should not raise any exceptions - manager.enhance_message(message) - - # Should not have added any content - assert len(message.content) == 1 - - # Clean up - os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md')) - - -def test_enhance_message_with_reversed_order(prompt_dir): - # Create a test microagent - microagent_name = 'reversed_microagent' - microagent_content = """ ---- -name: ReversedMicroAgent -type: knowledge -agent: CodeActAgent -triggers: -- lasttrigger ---- - -This is specific information about the lasttrigger. 
-""" - - # Create the microagent file - os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True) - with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f: - f.write(microagent_content) - - manager = PromptManager( - prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro') - ) - - # Test where the text content is not at the end of the list - message = Message( - role='user', - content=[ - ImageContent(image_urls=['https://example.com/image1.jpg']), - TextContent(text='This contains the lasttrigger word.'), - ImageContent(image_urls=['https://example.com/image2.jpg']), - ], - ) - - manager.enhance_message(message) - - # Should have added a TextContent with the microagent info at the beginning - assert len(message.content) == 4 - assert isinstance(message.content[0], TextContent) - assert 'specific information about the lasttrigger' in message.content[0].text - - # Clean up - os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md')) - - -def test_enhance_message_with_empty_content(prompt_dir): - # Create a test microagent - microagent_name = 'empty_microagent' - microagent_content = """ ---- -name: EmptyMicroAgent -type: knowledge -agent: CodeActAgent -triggers: -- emptytrigger ---- - -This should not appear in the enhanced message. 
-""" - - # Create the microagent file - os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True) - with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f: - f.write(microagent_content) - - manager = PromptManager( - prompt_dir=prompt_dir, microagent_dir=os.path.join(prompt_dir, 'micro') - ) - - # Test with empty content - message = Message(role='user', content=[]) - - # Should not raise any exceptions - manager.enhance_message(message) - - # Should not have added any content - assert len(message.content) == 0 - - # Clean up - os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md')) +def test_prompt_manager_file_not_found(prompt_dir): + """Test PromptManager behavior when a template file is not found.""" + # Test with a non-existent template + with pytest.raises(FileNotFoundError): + BaseMicroAgent.load( + os.path.join(prompt_dir, 'micro', 'non_existent_microagent.md') + ) def test_build_microagent_info(prompt_dir): @@ -429,33 +85,25 @@ def test_build_microagent_info(prompt_dir): with open(template_path, 'w') as f: f.write("""{% for agent_info in triggered_agents %} -The following information has been included based on a keyword match for "{{ agent_info.trigger_word }}". +The following information has been included based on a keyword match for "{{ agent_info.trigger }}". It may or may not be relevant to the user's request. 
-{{ agent_info.agent.content }} +{{ agent_info.content }} {% endfor %} """) - # Create test microagents - class MockKnowledgeMicroAgent: - def __init__(self, name, content): - self.name = name - self.content = content - - agent1 = MockKnowledgeMicroAgent( - name='test_agent1', content='This is information from agent 1' - ) - - agent2 = MockKnowledgeMicroAgent( - name='test_agent2', content='This is information from agent 2' - ) - # Initialize the PromptManager manager = PromptManager(prompt_dir=prompt_dir) # Test with a single triggered agent - triggered_agents = [{'agent': agent1, 'trigger_word': 'keyword1'}] + triggered_agents = [ + MicroagentKnowledge( + name='test_agent1', + trigger='keyword1', + content='This is information from agent 1', + ) + ] result = manager.build_microagent_info(triggered_agents) expected = """ The following information has been included based on a keyword match for "keyword1". @@ -467,8 +115,16 @@ This is information from agent 1 # Test with multiple triggered agents triggered_agents = [ - {'agent': agent1, 'trigger_word': 'keyword1'}, - {'agent': agent2, 'trigger_word': 'keyword2'}, + MicroagentKnowledge( + name='test_agent1', + trigger='keyword1', + content='This is information from agent 1', + ), + MicroagentKnowledge( + name='test_agent2', + trigger='keyword2', + content='This is information from agent 2', + ), ] result = manager.build_microagent_info(triggered_agents) expected = """ @@ -491,71 +147,125 @@ This is information from agent 2 assert result.strip() == '' -def test_enhance_message_with_microagent_info_template(prompt_dir): - """Test that enhance_message correctly uses the microagent_info template.""" - # Prepare a microagent_info.j2 template file if it doesn't exist - template_path = os.path.join(prompt_dir, 'microagent_info.j2') - if not os.path.exists(template_path): - with open(template_path, 'w') as f: - f.write("""{% for agent_info in triggered_agents %} - -The following information has been included based on a 
keyword match for "{{ agent_info.trigger_word }}". -It may or may not be relevant to the user's request. +def test_add_examples_to_initial_message(prompt_dir): + """Test adding example messages to an initial message.""" + # Create a user_prompt.j2 template file + with open(os.path.join(prompt_dir, 'user_prompt.j2'), 'w') as f: + f.write('This is an example user message') -{{ agent_info.agent.content }} - -{% endfor %} -""") + # Initialize the PromptManager + manager = PromptManager(prompt_dir=prompt_dir) - # Create a test microagent - microagent_name = 'test_trigger_microagent' - microagent_content = """ ---- -name: test_trigger -type: knowledge -agent: CodeActAgent -triggers: -- test_trigger ---- + # Create a message + message = Message(role='user', content=[TextContent(text='Original content')]) -This is triggered content for testing the microagent_info template. -""" + # Add examples to the message + manager.add_examples_to_initial_message(message) - # Create the microagent file - os.makedirs(os.path.join(prompt_dir, 'micro'), exist_ok=True) - with open(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md'), 'w') as f: - f.write(microagent_content) - - # Initialize the PromptManager with the microagent directory - manager = PromptManager( - prompt_dir=prompt_dir, - microagent_dir=os.path.join(prompt_dir, 'micro'), - ) - - # Create a message with a trigger keyword - message = Message( - role='user', - content=[ - TextContent(text="Here's a message containing the test_trigger keyword") - ], - ) - - # Enhance the message - manager.enhance_message(message) - - # The message should now have extra content at the beginning + # Check that the example was added at the beginning assert len(message.content) == 2 - assert isinstance(message.content[0], TextContent) - - # Verify the template was correctly rendered - expected_text = """ -The following information has been included based on a keyword match for "test_trigger". 
-It may or may not be relevant to the user's request. - -This is triggered content for testing the microagent_info template. -""" - - assert message.content[0].text.strip() == expected_text.strip() + assert message.content[0].text == 'This is an example user message' + assert message.content[1].text == 'Original content' # Clean up - os.remove(os.path.join(prompt_dir, 'micro', f'{microagent_name}.md')) + os.remove(os.path.join(prompt_dir, 'user_prompt.j2')) + + +def test_add_turns_left_reminder(prompt_dir): + """Test adding turns left reminder to messages.""" + # Initialize the PromptManager + manager = PromptManager(prompt_dir=prompt_dir) + + # Create a State object with specific iteration values + state = State() + state.iteration = 3 + state.max_iterations = 10 + + # Create a list of messages with a user message + user_message = Message(role='user', content=[TextContent(text='User content')]) + assistant_message = Message( + role='assistant', content=[TextContent(text='Assistant content')] + ) + messages = [assistant_message, user_message] + + # Add turns left reminder + manager.add_turns_left_reminder(messages, state) + + # Check that the reminder was added to the latest user message + assert len(user_message.content) == 2 + assert ( + 'ENVIRONMENT REMINDER: You have 7 turns left to complete the task.' + in user_message.content[1].text + ) + + +def test_build_additional_info_with_repo_and_runtime(prompt_dir): + """Test building additional info with repository and runtime information.""" + # Create an additional_info.j2 template file + with open(os.path.join(prompt_dir, 'additional_info.j2'), 'w') as f: + f.write(""" +{% if repository_info %} + +At the user's request, repository {{ repository_info.repo_name }} has been cloned to directory {{ repository_info.repo_directory }}. 
+ +{% endif %} + +{% if repository_instructions %} + +{{ repository_instructions }} + +{% endif %} + +{% if runtime_info and (runtime_info.available_hosts or runtime_info.additional_agent_instructions) -%} + +{% if runtime_info.available_hosts %} +The user has access to the following hosts for accessing a web application, +each of which has a corresponding port: +{% for host, port in runtime_info.available_hosts.items() %} +* {{ host }} (port {{ port }}) +{% endfor %} +{% endif %} + +{% if runtime_info.additional_agent_instructions %} +{{ runtime_info.additional_agent_instructions }} +{% endif %} + +{% endif %} +""") + + # Initialize the PromptManager + manager = PromptManager(prompt_dir=prompt_dir) + + # Create repository and runtime information + repo_info = RepositoryInfo(repo_name='owner/repo', repo_directory='/workspace/repo') + runtime_info = RuntimeInfo( + available_hosts={'example.com': 8080}, + additional_agent_instructions='You know everything about this runtime.', + ) + repo_instructions = 'This repository contains important code.' + + # Build additional info + result = manager.build_additional_info( + repository_info=repo_info, + runtime_info=runtime_info, + repo_instructions=repo_instructions, + ) + + # Check that all information is included + assert '' in result + assert 'owner/repo' in result + assert '/workspace/repo' in result + assert '' in result + assert 'This repository contains important code.' in result + assert '' in result + assert 'example.com (port 8080)' in result + assert 'You know everything about this runtime.' in result + + # Clean up + os.remove(os.path.join(prompt_dir, 'additional_info.j2')) + + +def test_prompt_manager_initialization_error(): + """Test that PromptManager raises an error if the prompt directory is not set.""" + with pytest.raises(ValueError, match='Prompt directory is not set'): + PromptManager(None)