diff --git a/enterprise/tests/unit/integrations/github/test_github_v1_callback_processor.py b/enterprise/tests/unit/integrations/github/test_github_v1_callback_processor.py index 4e8e500318..7ee32d6731 100644 --- a/enterprise/tests/unit/integrations/github/test_github_v1_callback_processor.py +++ b/enterprise/tests/unit/integrations/github/test_github_v1_callback_processor.py @@ -32,7 +32,6 @@ from openhands.app_server.sandbox.sandbox_models import ( SandboxInfo, SandboxStatus, ) -from openhands.events.action.message import MessageAction from openhands.sdk.event import ConversationStateUpdateEvent # --------------------------------------------------------------------------- @@ -76,7 +75,10 @@ def conversation_state_update_event(): @pytest.fixture def wrong_event(): - return MessageAction(content='Hello world') + """Return a mock event that is not a ConversationStateUpdateEvent.""" + mock_event = MagicMock() + mock_event.id = uuid4() + return mock_event @pytest.fixture diff --git a/enterprise/tests/unit/integrations/gitlab/test_gitlab_v1_callback_processor.py b/enterprise/tests/unit/integrations/gitlab/test_gitlab_v1_callback_processor.py index 43fab5d862..9590314247 100644 --- a/enterprise/tests/unit/integrations/gitlab/test_gitlab_v1_callback_processor.py +++ b/enterprise/tests/unit/integrations/gitlab/test_gitlab_v1_callback_processor.py @@ -28,7 +28,6 @@ from openhands.app_server.sandbox.sandbox_models import ( SandboxInfo, SandboxStatus, ) -from openhands.events.action.message import MessageAction from openhands.sdk.event import ConversationStateUpdateEvent # --------------------------------------------------------------------------- @@ -73,7 +72,10 @@ def conversation_state_update_event(): @pytest.fixture def wrong_event(): - return MessageAction(content='Hello world') + """Return a mock event that is not a ConversationStateUpdateEvent.""" + mock_event = MagicMock() + mock_event.id = uuid4() + return mock_event @pytest.fixture diff --git a/enterprise/tests/unit/integrations/slack/test_slack_v1_callback_processor.py b/enterprise/tests/unit/integrations/slack/test_slack_v1_callback_processor.py index e3baf86f12..4ca6902562 100644 --- a/enterprise/tests/unit/integrations/slack/test_slack_v1_callback_processor.py +++ b/enterprise/tests/unit/integrations/slack/test_slack_v1_callback_processor.py @@ -28,9 +28,16 @@ from openhands.app_server.sandbox.sandbox_models import ( SandboxInfo, SandboxStatus, ) -from openhands.events.action.message import MessageAction from openhands.sdk.event import ConversationStateUpdateEvent + +def _create_mock_event(): + """Create a mock event that is not a ConversationStateUpdateEvent.""" + mock_event = MagicMock() + mock_event.id = uuid4() + return mock_event + + # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @@ -105,8 +112,10 @@ class TestSlackV1CallbackProcessor: @pytest.mark.parametrize( 'event,expected_result', [ - # Wrong event types should be ignored - (MessageAction(content='Hello world'), None), + # Wrong event types should be ignored (use lazy evaluation for mock) + pytest.param( + None, None, id='wrong_event_type', marks=pytest.mark.wrong_event_type + ), # Wrong state values should be ignored ( ConversationStateUpdateEvent(key='execution_status', value='running'), @@ -120,9 +129,12 @@ class TestSlackV1CallbackProcessor: ], ) async def test_event_filtering( - self, slack_callback_processor, event_callback, event, expected_result + self, slack_callback_processor, event_callback, event, expected_result, request ): """Test that processor correctly filters events.""" + # Handle the mock event case specially + if event is None and 'wrong_event_type' in request.node.name: + event = _create_mock_event() result = await slack_callback_processor(uuid4(), event_callback, event) assert result == expected_result diff --git a/openhands/events/__init__.py b/openhands/events/__init__.py deleted file mode 100644 index 36edbb7772..0000000000 --- a/openhands/events/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from openhands.events.event import Event, EventSource -from openhands.events.recall_type import RecallType -from openhands.events.stream import EventStream, EventStreamSubscriber - -__all__ = [ - 'Event', - 'EventSource', - 'EventStream', - 'EventStreamSubscriber', - 'RecallType', -] diff --git a/openhands/events/action/__init__.py b/openhands/events/action/__init__.py deleted file mode 100644 index 4731d68e9e..0000000000 --- a/openhands/events/action/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -from openhands.events.action.action import ( - Action, - ActionConfirmationStatus, - ActionSecurityRisk, -) -from openhands.events.action.agent import ( - AgentDelegateAction, - AgentFinishAction, - AgentRejectAction, - AgentThinkAction, - ChangeAgentStateAction, - LoopRecoveryAction, - RecallAction, - TaskTrackingAction, -) -from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction -from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction -from openhands.events.action.empty import NullAction -from openhands.events.action.files import ( - FileEditAction, - FileReadAction, - FileWriteAction, -) -from openhands.events.action.mcp import MCPAction -from openhands.events.action.message import MessageAction, SystemMessageAction - -__all__ = [ - 'Action', - 'NullAction', - 'CmdRunAction', - 'BrowseURLAction', - 'BrowseInteractiveAction', - 'FileReadAction', - 'FileWriteAction', - 'FileEditAction', - 'AgentFinishAction', - 'AgentRejectAction', - 'AgentDelegateAction', - 'ChangeAgentStateAction', - 'IPythonRunCellAction', - 'MessageAction', - 'SystemMessageAction', - 'ActionConfirmationStatus', - 'AgentThinkAction', - 'RecallAction', - 'MCPAction', - 'TaskTrackingAction', - 'ActionSecurityRisk', - 'LoopRecoveryAction', -] diff --git a/openhands/events/action/action.py b/openhands/events/action/action.py deleted file mode 100644 index 0605af7ed5..0000000000 --- a/openhands/events/action/action.py +++ /dev/null @@ -1,23 +0,0 @@ -from dataclasses import dataclass -from enum import Enum -from typing import ClassVar - -from openhands.events.event import Event - - -class ActionConfirmationStatus(str, Enum): - CONFIRMED = 'confirmed' - REJECTED = 'rejected' - AWAITING_CONFIRMATION = 'awaiting_confirmation' - - -class ActionSecurityRisk(int, Enum): - UNKNOWN = -1 - LOW = 0 - MEDIUM = 1 - HIGH = 2 - - -@dataclass -class Action(Event): - runnable: ClassVar[bool] = False diff --git a/openhands/events/action/agent.py b/openhands/events/action/agent.py deleted file mode 100644 index 41fe0a4d02..0000000000 --- a/openhands/events/action/agent.py +++ /dev/null @@ -1,242 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any - -from openhands.core.schema import ActionType -from openhands.events.action.action import Action -from openhands.events.recall_type import RecallType - - -@dataclass -class ChangeAgentStateAction(Action): - """Fake action, just to notify the client that a task state has changed.""" - - agent_state: str - thought: str = '' - action: str = ActionType.CHANGE_AGENT_STATE - - @property - def message(self) -> str: - return f'Agent state changed to {self.agent_state}' - - -@dataclass -class AgentFinishAction(Action): - """An action where the agent finishes the task. - - Attributes: - final_thought (str): The message to send to the user. - outputs (dict): The other outputs of the agent, for instance "content". - thought (str): The agent's explanation of its actions. - action (str): The action type, namely ActionType.FINISH. - """ - - final_thought: str = '' - outputs: dict[str, Any] = field(default_factory=dict) - thought: str = '' - action: str = ActionType.FINISH - - @property - def message(self) -> str: - if self.thought != '': - return self.thought - return "All done! What's next on the agenda?" - - -@dataclass -class AgentThinkAction(Action): - """An action where the agent logs a thought. - - Attributes: - thought (str): The agent's explanation of its actions. - action (str): The action type, namely ActionType.THINK. - """ - - thought: str = '' - action: str = ActionType.THINK - - @property - def message(self) -> str: - return f'I am thinking...: {self.thought}' - - -@dataclass -class AgentRejectAction(Action): - outputs: dict = field(default_factory=dict) - thought: str = '' - action: str = ActionType.REJECT - - @property - def message(self) -> str: - msg: str = 'Task is rejected by the agent.' - if 'reason' in self.outputs: - msg += ' Reason: ' + self.outputs['reason'] - return msg - - -@dataclass -class AgentDelegateAction(Action): - agent: str - inputs: dict - thought: str = '' - action: str = ActionType.DELEGATE - - @property - def message(self) -> str: - return f"I'm asking {self.agent} for help with this task." - - -@dataclass -class RecallAction(Action): - """This action is used for retrieving content, e.g., from the global directory or user workspace.""" - - recall_type: RecallType - query: str = '' - thought: str = '' - action: str = ActionType.RECALL - - @property - def message(self) -> str: - return f'Retrieving content for: {self.query[:50]}' - - def __str__(self) -> str: - ret = '**RecallAction**\n' - ret += f'QUERY: {self.query[:50]}' - return ret - - -@dataclass -class CondensationAction(Action): - """This action indicates a condensation of the conversation history is happening. - - There are two ways to specify the events to be forgotten: - 1. By providing a list of event IDs. - 2. By providing the start and end IDs of a range of events. - - In the second case, we assume that event IDs are monotonically increasing, and that _all_ events between the start and end IDs are to be forgotten. - - Raises: - ValueError: If the optional fields are not instantiated in a valid configuration. - """ - - action: str = ActionType.CONDENSATION - - forgotten_event_ids: list[int] | None = None - """The IDs of the events that are being forgotten (removed from the `View` given to the LLM).""" - - forgotten_events_start_id: int | None = None - """The ID of the first event to be forgotten in a range of events.""" - - forgotten_events_end_id: int | None = None - """The ID of the last event to be forgotten in a range of events.""" - - summary: str | None = None - """An optional summary of the events being forgotten.""" - - summary_offset: int | None = None - """An optional offset to the start of the resulting view indicating where the summary should be inserted.""" - - def _validate_field_polymorphism(self) -> bool: - """Check if the optional fields are instantiated in a valid configuration.""" - # For the forgotton events, there are only two valid configurations: - # 1. We're forgetting events based on the list of provided IDs, or - using_event_ids = self.forgotten_event_ids is not None - # 2. We're forgetting events based on the range of IDs. - using_event_range = ( - self.forgotten_events_start_id is not None - and self.forgotten_events_end_id is not None - ) - - # Either way, we can only have one of the two valid configurations. - forgotten_event_configuration = using_event_ids ^ using_event_range - - # We also need to check that if the summary is provided, so is the - # offset (and vice versa). - summary_configuration = ( - self.summary is None and self.summary_offset is None - ) or (self.summary is not None and self.summary_offset is not None) - - return forgotten_event_configuration and summary_configuration - - def __post_init__(self): - if not self._validate_field_polymorphism(): - raise ValueError('Invalid configuration of the optional fields.') - - @property - def forgotten(self) -> list[int]: - """The list of event IDs that should be forgotten.""" - # Start by making sure the fields are instantiated in a valid - # configuration. We check this whenever the event is initialized, but we - # can't make the dataclass immutable so we need to check it again here - # to make sure the configuration is still valid. - if not self._validate_field_polymorphism(): - raise ValueError('Invalid configuration of the optional fields.') - - if self.forgotten_event_ids is not None: - return self.forgotten_event_ids - - # If we've gotten this far, the start/end IDs are not None. - assert self.forgotten_events_start_id is not None - assert self.forgotten_events_end_id is not None - return list( - range(self.forgotten_events_start_id, self.forgotten_events_end_id + 1) - ) - - @property - def message(self) -> str: - if self.summary: - return f'Summary: {self.summary}' - return f'Condenser is dropping the events: {self.forgotten}.' - - -@dataclass -class CondensationRequestAction(Action): - """This action is used to request a condensation of the conversation history. - - Attributes: - action (str): The action type, namely ActionType.CONDENSATION_REQUEST. - """ - - action: str = ActionType.CONDENSATION_REQUEST - - @property - def message(self) -> str: - return 'Requesting a condensation of the conversation history.' - - -@dataclass -class TaskTrackingAction(Action): - """An action where the agent writes or updates a task list for task management. - Attributes: - task_list (list): The list of task items with their status and metadata. - thought (str): The agent's explanation of its actions. - action (str): The action type, namely ActionType.TASK_TRACKING. - """ - - command: str = 'view' - task_list: list[dict[str, Any]] = field(default_factory=list) - thought: str = '' - action: str = ActionType.TASK_TRACKING - - @property - def message(self) -> str: - num_tasks = len(self.task_list) - if num_tasks == 0: - return 'Clearing the task list.' - elif num_tasks == 1: - return 'Managing 1 task item.' - else: - return f'Managing {num_tasks} task items.' - - -@dataclass -class LoopRecoveryAction(Action): - """An action that shows three ways to handle dead loop. - The class should be invisible to LLM. - Attributes: - option (int): 1 allow user to prompt again - 2 automatically use latest user prompt - 3 stop agent - """ - - option: int = 1 - action: str = ActionType.LOOP_RECOVERY diff --git a/openhands/events/action/browse.py b/openhands/events/action/browse.py deleted file mode 100644 index 481549cffb..0000000000 --- a/openhands/events/action/browse.py +++ /dev/null @@ -1,48 +0,0 @@ -from dataclasses import dataclass -from typing import ClassVar - -from openhands.core.schema import ActionType -from openhands.events.action.action import Action, ActionSecurityRisk - - -@dataclass -class BrowseURLAction(Action): - url: str - thought: str = '' - action: str = ActionType.BROWSE - runnable: ClassVar[bool] = True - security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN - return_axtree: bool = False - - @property - def message(self) -> str: - return f'I am browsing the URL: {self.url}' - - def __str__(self) -> str: - ret = '**BrowseURLAction**\n' - if self.thought: - ret += f'THOUGHT: {self.thought}\n' - ret += f'URL: {self.url}' - return ret - - -@dataclass -class BrowseInteractiveAction(Action): - browser_actions: str - thought: str = '' - browsergym_send_msg_to_user: str = '' - action: str = ActionType.BROWSE_INTERACTIVE - runnable: ClassVar[bool] = True - security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN - return_axtree: bool = False - - @property - def message(self) -> str: - return f'I am interacting with the browser:\n```\n{self.browser_actions}\n```' - - def __str__(self) -> str: - ret = '**BrowseInteractiveAction**\n' - if self.thought: - ret += f'THOUGHT: {self.thought}\n' - ret += f'BROWSER_ACTIONS: {self.browser_actions}' - return ret diff --git a/openhands/events/action/commands.py b/openhands/events/action/commands.py deleted file mode 100644 index 2590a2b141..0000000000 --- a/openhands/events/action/commands.py +++ /dev/null @@ -1,64 +0,0 @@ -from dataclasses import dataclass -from typing import ClassVar - -from openhands.core.schema import ActionType -from openhands.events.action.action import ( - Action, - ActionConfirmationStatus, - ActionSecurityRisk, -) - - -@dataclass -class CmdRunAction(Action): - command: ( - str # When `command` is empty, it will be used to print the current tmux window - ) - is_input: bool = False # if True, the command is an input to the running process - thought: str = '' - blocking: bool = False # if True, the command will be run in a blocking manner, but a timeout must be set through _set_hard_timeout - is_static: bool = False # if True, runs the command in a separate process - cwd: str | None = None # current working directory, only used if is_static is True - hidden: bool = ( - False # if True, this command does not go through the LLM or event stream - ) - action: str = ActionType.RUN - runnable: ClassVar[bool] = True - confirmation_state: ActionConfirmationStatus = ActionConfirmationStatus.CONFIRMED - security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN - - @property - def message(self) -> str: - return f'Running command: {self.command}' - - def __str__(self) -> str: - ret = f'**CmdRunAction (source={self.source}, is_input={self.is_input})**\n' - if self.thought: - ret += f'THOUGHT: {self.thought}\n' - ret += f'COMMAND:\n{self.command}' - return ret - - -@dataclass -class IPythonRunCellAction(Action): - code: str - thought: str = '' - include_extra: bool = ( - True # whether to include CWD & Python interpreter in the output - ) - action: str = ActionType.RUN_IPYTHON - runnable: ClassVar[bool] = True - confirmation_state: ActionConfirmationStatus = ActionConfirmationStatus.CONFIRMED - security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN - kernel_init_code: str = '' # code to run in the kernel (if the kernel is restarted) - - def __str__(self) -> str: - ret = '**IPythonRunCellAction**\n' - if self.thought: - ret += f'THOUGHT: {self.thought}\n' - ret += f'CODE:\n{self.code}' - return ret - - @property - def message(self) -> str: - return f'Running Python code interactively: {self.code}' diff --git a/openhands/events/action/empty.py b/openhands/events/action/empty.py deleted file mode 100644 index 32e0346001..0000000000 --- a/openhands/events/action/empty.py +++ /dev/null @@ -1,15 +0,0 @@ -from dataclasses import dataclass - -from openhands.core.schema import ActionType -from openhands.events.action.action import Action - - -@dataclass -class NullAction(Action): - """An action that does nothing.""" - - action: str = ActionType.NULL - - @property - def message(self) -> str: - return 'No action' diff --git a/openhands/events/action/files.py b/openhands/events/action/files.py deleted file mode 100644 index a8a00bc113..0000000000 --- a/openhands/events/action/files.py +++ /dev/null @@ -1,138 +0,0 @@ -from dataclasses import dataclass -from typing import ClassVar - -from openhands.core.schema import ActionType -from openhands.events.action.action import Action, ActionSecurityRisk -from openhands.events.event import FileEditSource, FileReadSource - - -@dataclass -class FileReadAction(Action): - """Reads a file from a given path. - Can be set to read specific lines using start and end - Default lines 0:-1 (whole file) - """ - - path: str - start: int = 0 - end: int = -1 - thought: str = '' - action: str = ActionType.READ - runnable: ClassVar[bool] = True - security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN - impl_source: FileReadSource = FileReadSource.DEFAULT - view_range: list[int] | None = None # ONLY used in OH_ACI mode - - @property - def message(self) -> str: - return f'Reading file: {self.path}' - - -@dataclass -class FileWriteAction(Action): - """Writes a file to a given path. - Can be set to write specific lines using start and end - Default lines 0:-1 (whole file) - """ - - path: str - content: str - start: int = 0 - end: int = -1 - thought: str = '' - action: str = ActionType.WRITE - runnable: ClassVar[bool] = True - security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN - - @property - def message(self) -> str: - return f'Writing file: {self.path}' - - def __repr__(self) -> str: - return ( - f'**FileWriteAction**\n' - f'Path: {self.path}\n' - f'Range: [L{self.start}:L{self.end}]\n' - f'Thought: {self.thought}\n' - f'Content:\n```\n{self.content}\n```\n' - ) - - -@dataclass -class FileEditAction(Action): - """Edits a file using various commands including view, create, str_replace, insert, and undo_edit. - - This class supports two main modes of operation: - 1. LLM-based editing (impl_source = FileEditSource.LLM_BASED_EDIT) - 2. ACI-based editing (impl_source = FileEditSource.OH_ACI) - - Attributes: - path (str): The path to the file being edited. Works for both LLM-based and OH_ACI editing. - OH_ACI only arguments: - command (str): The editing command to be performed (view, create, str_replace, insert, undo_edit, write). - file_text (str): The content of the file to be created (used with 'create' command in OH_ACI mode). - old_str (str): The string to be replaced (used with 'str_replace' command in OH_ACI mode). - new_str (str): The string to replace old_str (used with 'str_replace' and 'insert' commands in OH_ACI mode). - insert_line (int): The line number after which to insert new_str (used with 'insert' command in OH_ACI mode). - LLM-based editing arguments: - content (str): The content to be written or edited in the file (used in LLM-based editing and 'write' command). - start (int): The starting line for editing (1-indexed, inclusive). Default is 1. - end (int): The ending line for editing (1-indexed, inclusive). Default is -1 (end of file). - thought (str): The reasoning behind the edit action. - action (str): The type of action being performed (always ActionType.EDIT). - runnable (bool): Indicates if the action can be executed (always True). - security_risk (ActionSecurityRisk | None): Indicates any security risks associated with the action. - impl_source (FileEditSource): The source of the implementation (LLM_BASED_EDIT or OH_ACI). - - Usage: - - For LLM-based editing: Use path, content, start, and end attributes. - - For ACI-based editing: Use path, command, and the appropriate attributes for the specific command. - - Note: - - If start is set to -1 in LLM-based editing, the content will be appended to the file. - - The 'write' command behaves similarly to LLM-based editing, using content, start, and end attributes. - """ - - path: str - - # OH_ACI arguments - command: str = '' - file_text: str | None = None - old_str: str | None = None - new_str: str | None = None - insert_line: int | None = None - - # LLM-based editing arguments - content: str = '' - start: int = 1 - end: int = -1 - - # Shared arguments - thought: str = '' - action: str = ActionType.EDIT - runnable: ClassVar[bool] = True - security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN - impl_source: FileEditSource = FileEditSource.OH_ACI - - def __repr__(self) -> str: - ret = '**FileEditAction**\n' - ret += f'Path: [{self.path}]\n' - ret += f'Thought: {self.thought}\n' - - if self.impl_source == FileEditSource.LLM_BASED_EDIT: - ret += f'Range: [L{self.start}:L{self.end}]\n' - ret += f'Content:\n```\n{self.content}\n```\n' - else: # OH_ACI mode - ret += f'Command: {self.command}\n' - if self.command == 'create': - ret += f'Created File with Text:\n```\n{self.file_text}\n```\n' - elif self.command == 'str_replace': - ret += f'Old String: ```\n{self.old_str}\n```\n' - ret += f'New String: ```\n{self.new_str}\n```\n' - elif self.command == 'insert': - ret += f'Insert Line: {self.insert_line}\n' - ret += f'New String: ```\n{self.new_str}\n```\n' - elif self.command == 'undo_edit': - ret += 'Undo Edit\n' - # We ignore "view" command because it will be mapped to a FileReadAction - return ret diff --git a/openhands/events/action/mcp.py b/openhands/events/action/mcp.py deleted file mode 100644 index 40334fe2fc..0000000000 --- a/openhands/events/action/mcp.py +++ /dev/null @@ -1,32 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any, ClassVar - -from openhands.core.schema import ActionType -from openhands.events.action.action import Action, ActionSecurityRisk - - -@dataclass -class MCPAction(Action): - name: str - arguments: dict[str, Any] = field(default_factory=dict) - thought: str = '' - action: str = ActionType.MCP - runnable: ClassVar[bool] = True - security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN - - @property - def message(self) -> str: - return ( - f'I am interacting with the MCP server with name:\n' - f'```\n{self.name}\n```\n' - f'and arguments:\n' - f'```\n{self.arguments}\n```' - ) - - def __str__(self) -> str: - ret = '**MCPAction**\n' - if self.thought: - ret += f'THOUGHT: {self.thought}\n' - ret += f'NAME: {self.name}\n' - ret += f'ARGUMENTS: {self.arguments}' - return ret diff --git a/openhands/events/action/message.py b/openhands/events/action/message.py deleted file mode 100644 index 3b9b14fff4..0000000000 --- a/openhands/events/action/message.py +++ /dev/null @@ -1,66 +0,0 @@ -from dataclasses import dataclass -from typing import Any - -from openhands.core.schema import ActionType -from openhands.events.action.action import Action, ActionSecurityRisk -from openhands.version import get_version - - -@dataclass -class MessageAction(Action): - content: str - file_urls: list[str] | None = None - image_urls: list[str] | None = None - wait_for_response: bool = False - action: str = ActionType.MESSAGE - security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN - - @property - def message(self) -> str: - return self.content - - @property - def images_urls(self) -> list[str] | None: - # Deprecated alias for backward compatibility - return self.image_urls - - @images_urls.setter - def images_urls(self, value: list[str] | None) -> None: - self.image_urls = value - - def __str__(self) -> str: - ret = f'**MessageAction** (source={self.source})\n' - ret += f'CONTENT: {self.content}' - if self.image_urls: - for url in self.image_urls: - ret += f'\nIMAGE_URL: {url}' - if self.file_urls: - for url in self.file_urls: - ret += f'\nFILE_URL: {url}' - return ret - - -@dataclass -class SystemMessageAction(Action): - """Action that represents a system message for an agent, including the system prompt - and available tools. This should be the first message in the event stream. - """ - - content: str - tools: list[Any] | None = None - openhands_version: str | None = get_version() - agent_class: str | None = None - action: ActionType = ActionType.SYSTEM - - @property - def message(self) -> str: - return self.content - - def __str__(self) -> str: - ret = f'**SystemMessageAction** (source={self.source})\n' - ret += f'CONTENT: {self.content}' - if self.tools: - ret += f'\nTOOLS: {len(self.tools)} tools available' - if self.agent_class: - ret += f'\nAGENT_CLASS: {self.agent_class}' - return ret diff --git a/openhands/events/event.py b/openhands/events/event.py deleted file mode 100644 index b459293a8a..0000000000 --- a/openhands/events/event.py +++ /dev/null @@ -1,121 +0,0 @@ -from dataclasses import dataclass -from datetime import datetime -from enum import Enum - -from openhands.events.metrics import Metrics -from openhands.events.tool import ToolCallMetadata - - -class EventSource(str, Enum): - AGENT = 'agent' - USER = 'user' - ENVIRONMENT = 'environment' - - -class FileEditSource(str, Enum): - LLM_BASED_EDIT = 'llm_based_edit' - OH_ACI = 'oh_aci' # openhands-aci - - -class FileReadSource(str, Enum): - OH_ACI = 'oh_aci' # openhands-aci - DEFAULT = 'default' - - -@dataclass -class Event: - INVALID_ID = -1 - - @property - def message(self) -> str | None: - if hasattr(self, '_message'): - msg = getattr(self, '_message') - return str(msg) if msg is not None else None - return '' - - @property - def id(self) -> int: - if hasattr(self, '_id'): - id_val = getattr(self, '_id') - return int(id_val) if id_val is not None else Event.INVALID_ID - return Event.INVALID_ID - - @property - def timestamp(self) -> str | None: - if hasattr(self, '_timestamp') and isinstance(self._timestamp, str): - ts = getattr(self, '_timestamp') - return str(ts) if ts is not None else None - return None - - @timestamp.setter - def timestamp(self, value: datetime) -> None: - if isinstance(value, datetime): - self._timestamp = value.isoformat() - - @property - def source(self) -> EventSource | None: - if hasattr(self, '_source'): - src = getattr(self, '_source') - return EventSource(src) if src is not None else None - return None - - @property - def cause(self) -> int | None: - if hasattr(self, '_cause'): - cause_val = getattr(self, '_cause') - return int(cause_val) if cause_val is not None else None - return None - - @property - def timeout(self) -> float | None: - if hasattr(self, '_timeout'): - timeout_val = getattr(self, '_timeout') - return float(timeout_val) if timeout_val is not None else None - return None - - def set_hard_timeout(self, value: float | None, blocking: bool = True) -> None: - """Set the timeout for the event. - - NOTE, this is a hard timeout, meaning that the event will be blocked - until the timeout is reached. - """ - self._timeout = value - # Check if .blocking is an attribute of the event - if hasattr(self, 'blocking'): - # .blocking needs to be set to True if .timeout is set - self.blocking = blocking - - # optional metadata, LLM call cost of the edit - @property - def llm_metrics(self) -> Metrics | None: - if hasattr(self, '_llm_metrics'): - metrics = getattr(self, '_llm_metrics') - return metrics if isinstance(metrics, Metrics) else None - return None - - @llm_metrics.setter - def llm_metrics(self, value: Metrics) -> None: - self._llm_metrics = value - - # optional field, metadata about the tool call, if the event has a tool call - @property - def tool_call_metadata(self) -> ToolCallMetadata | None: - if hasattr(self, '_tool_call_metadata'): - metadata = getattr(self, '_tool_call_metadata') - return metadata if isinstance(metadata, ToolCallMetadata) else None - return None - - @tool_call_metadata.setter - def tool_call_metadata(self, value: ToolCallMetadata) -> None: - self._tool_call_metadata = value - - # optional field, the id of the response from the LLM - @property - def response_id(self) -> str | None: - if hasattr(self, '_response_id'): - return self._response_id # type: ignore[attr-defined] - return None - - @response_id.setter - def response_id(self, value: str) -> None: - self._response_id = value diff --git a/openhands/events/event_filter.py b/openhands/events/event_filter.py deleted file mode 100644 index cfd41b568c..0000000000 --- a/openhands/events/event_filter.py +++ /dev/null @@ -1,98 +0,0 @@ -import json -from dataclasses import dataclass - -from openhands.events.event import Event -from openhands.events.serialization.event import event_to_dict - - -@dataclass -class EventFilter: - """A filter for Event objects in the event stream. - - EventFilter provides a flexible way to filter events based on various criteria - such as event type, source, date range, and content. It can be used to include - or exclude events from search results based on the specified criteria. - - Attributes: - exclude_hidden: Whether to exclude events marked as hidden. Defaults to False. - query: Text string to search for in event content. Case-insensitive. Defaults to None. - include_types: Tuple of Event types to include. Only events of these types will pass the filter. - Defaults to None (include all types). - exclude_types: Tuple of Event types to exclude. Events of these types will be filtered out. - Defaults to None (exclude no types). - source: Filter by event source (e.g., 'agent', 'user', 'environment'). Defaults to None. - start_date: ISO format date string. Only events after this date will pass the filter. - Defaults to None. - end_date: ISO format date string. Only events before this date will pass the filter. - Defaults to None. - """ - - exclude_hidden: bool = False - query: str | None = None - include_types: tuple[type[Event], ...] | None = None - exclude_types: tuple[type[Event], ...] | None = None - source: str | None = None - start_date: str | None = None - end_date: str | None = None - - def include(self, event: Event) -> bool: - """Determine if an event should be included based on the filter criteria. - - This method checks if the given event matches all the filter criteria. - If any criterion fails, the event is excluded. - - Args: - event: The Event object to check against the filter criteria. - - Returns: - bool: True if the event passes all filter criteria and should be included, - False otherwise. - """ - if self.include_types and not isinstance(event, self.include_types): - return False - - if self.exclude_types is not None and isinstance(event, self.exclude_types): - return False - - if self.source: - if event.source is None or event.source.value != self.source: - return False - - if ( - self.start_date - and event.timestamp is not None - and event.timestamp < self.start_date - ): - return False - - if ( - self.end_date - and event.timestamp is not None - and event.timestamp > self.end_date - ): - return False - - if self.exclude_hidden and getattr(event, 'hidden', False): - return False - - # Text search in event content if query provided - if self.query: - event_dict = event_to_dict(event) - event_str = json.dumps(event_dict).lower() - if self.query.lower() not in event_str: - return False - - return True - - def exclude(self, event: Event) -> bool: - """Determine if an event should be excluded based on the filter criteria. - - This is the inverse of the include method. - - Args: - event: The Event object to check against the filter criteria. - - Returns: - bool: True if the event should be excluded, False if it should be included. - """ - return not self.include(event) diff --git a/openhands/events/event_store.py b/openhands/events/event_store.py deleted file mode 100644 index 07ea385efa..0000000000 --- a/openhands/events/event_store.py +++ /dev/null @@ -1,183 +0,0 @@ -import json -from dataclasses import dataclass -from typing import Iterable - -from openhands.core.logger import openhands_logger as logger -from openhands.events.event import Event, EventSource -from openhands.events.event_filter import EventFilter -from openhands.events.event_store_abc import EventStoreABC -from openhands.events.serialization.event import event_from_dict -from openhands.storage.files import FileStore -from openhands.storage.locations import ( - get_conversation_dir, - get_conversation_event_filename, - get_conversation_events_dir, -) -from openhands.utils.shutdown_listener import should_continue - - -@dataclass(frozen=True) -class _CachePage: - events: list[dict] | None - start: int - end: int - - def covers(self, global_index: int) -> bool: - if global_index < self.start: - return False - if global_index >= self.end: - return False - return True - - def get_event(self, global_index: int) -> Event | None: - # If there was not actually a cached page, return None - if not self.events: - return None - local_index = global_index - self.start - return event_from_dict(self.events[local_index]) - - -_DUMMY_PAGE = _CachePage(None, 1, -1) - - -@dataclass -class EventStore(EventStoreABC): - """A stored list of events backing a conversation""" - - sid: str - file_store: FileStore - user_id: str | None - cache_size: int = 25 - _cur_id: int | None = None # Private field to cache the calculated value - - @property - def cur_id(self) -> int: - """Lazy calculated property for the current event ID.""" - if self._cur_id is None: - self._cur_id = self._calculate_cur_id() - return self._cur_id - - @cur_id.setter - def cur_id(self, value: int) -> None: - """Setter for cur_id to allow updates.""" - self._cur_id = value - - def _calculate_cur_id(self) -> int: - """Calculate the current event ID based on file system content.""" - events = [] - try: - events_dir = get_conversation_events_dir(self.sid, self.user_id) - events = self.file_store.list(events_dir) - except FileNotFoundError: - logger.debug(f'No events found for session {self.sid} at {events_dir}') - - if not events: - return 0 - - # if we have events, we need to find the highest id to prepare for new events - max_id = -1 - for event_str in events: - id = self._get_id_from_filename(event_str) - if id >= max_id: - max_id = id - return max_id + 1 - - def search_events( - self, - start_id: int = 0, - end_id: int | None = None, - reverse: bool = False, - filter: EventFilter | None = None, - limit: int | None = None, - ) -> Iterable[Event]: - """Retrieve events from the event stream, optionally filtering out events of a given type - and events marked as hidden. - - Args: - start_id: The ID of the first event to retrieve. Defaults to 0. - end_id: The ID of the last event to retrieve. Defaults to the last event in the stream. - reverse: Whether to retrieve events in reverse order. Defaults to False. - filter: EventFilter to use - - Yields: - Events from the stream that match the criteria. - """ - if end_id is None: - end_id = self.cur_id - else: - end_id += 1 # From inclusive to exclusive - - if reverse: - step = -1 - start_id, end_id = end_id, start_id - start_id -= 1 - end_id -= 1 - else: - step = 1 - - cache_page = _DUMMY_PAGE - num_results = 0 - for index in range(start_id, end_id, step): - if not should_continue(): - return - if not cache_page.covers(index): - cache_page = self._load_cache_page_for_index(index) - event = cache_page.get_event(index) - if event is None: - try: - event = self.get_event(index) - except FileNotFoundError: - event = None - if event: - if not filter or filter.include(event): - yield event - num_results += 1 - if limit and limit <= num_results: - return - - def get_event(self, id: int) -> Event: - filename = self._get_filename_for_id(id, self.user_id) - content = self.file_store.read(filename) - data = json.loads(content) - return event_from_dict(data) - - def get_latest_event(self) -> Event: - return self.get_event(self.cur_id - 1) - - def get_latest_event_id(self) -> int: - return self.cur_id - 1 - - def filtered_events_by_source(self, source: EventSource) -> Iterable[Event]: - for event in self.search_events(): - if event.source == source: - yield event - - def _get_filename_for_id(self, id: int, user_id: str | None) -> str: - return get_conversation_event_filename(self.sid, id, user_id) - - def _get_filename_for_cache(self, start: int, end: int) -> str: - return f'{get_conversation_dir(self.sid, self.user_id)}event_cache/{start}-{end}.json' - - def _load_cache_page(self, start: int, end: int) -> _CachePage: - """Read a page from the cache. Reading individual events is slow when there are a lot of them, so we use pages.""" - cache_filename = self._get_filename_for_cache(start, end) - try: - content = self.file_store.read(cache_filename) - events = json.loads(content) - except FileNotFoundError: - events = None - page = _CachePage(events, start, end) - return page - - def _load_cache_page_for_index(self, index: int) -> _CachePage: - offset = index % self.cache_size - index -= offset - return self._load_cache_page(index, index + self.cache_size) - - @staticmethod - def _get_id_from_filename(filename: str) -> int: - try: - return int(filename.split('/')[-1].split('.')[0]) - except ValueError: - logger.warning(f'get id from filename ({filename}) failed.') - return -1 diff --git a/openhands/events/event_store_abc.py b/openhands/events/event_store_abc.py deleted file mode 100644 index 840410ce34..0000000000 --- a/openhands/events/event_store_abc.py +++ /dev/null @@ -1,111 +0,0 @@ -from abc import abstractmethod -from itertools import islice -from typing import Iterable - -from deprecated import deprecated # type: ignore - -from openhands.events.event import Event, EventSource -from openhands.events.event_filter import EventFilter - - -class EventStoreABC: - """A stored list of events backing a conversation""" - - sid: str - user_id: str | None - - @abstractmethod - def search_events( - self, - start_id: int = 0, - end_id: int | None = None, - reverse: bool = False, - filter: EventFilter | None = None, - limit: int | None = None, - ) -> Iterable[Event]: - """Retrieve events from the event stream, optionally excluding events using a filter - - Args: - start_id: The ID of the first event to retrieve. Defaults to 0. - end_id: The ID of the last event to retrieve. Defaults to the last event in the stream. - reverse: Whether to retrieve events in reverse order. Defaults to False. - filter: An optional event filter - - Yields: - Events from the stream that match the criteria. - """ - - @deprecated('Use search_events instead') - def get_events( - self, - start_id: int = 0, - end_id: int | None = None, - reverse: bool = False, - filter_out_type: tuple[type[Event], ...] | None = None, - filter_hidden: bool = False, - ) -> Iterable[Event]: - yield from self.search_events( - start_id, - end_id, - reverse, - EventFilter(exclude_types=filter_out_type, exclude_hidden=filter_hidden), - ) - - @abstractmethod - def get_event(self, id: int) -> Event: - """Retrieve a single event from the event stream. Raise a FileNotFoundError if there was no such event""" - - @abstractmethod - def get_latest_event(self) -> Event: - """Get the latest event from the event stream""" - - @abstractmethod - def get_latest_event_id(self) -> int: - """Get the id of the latest event from the event stream""" - - @deprecated('use search_events instead') - def filtered_events_by_source(self, source: EventSource) -> Iterable[Event]: - yield from self.search_events(filter=EventFilter(source=source)) - - @deprecated('use search_events instead') - def get_matching_events( - self, - query: str | None = None, - event_types: tuple[type[Event], ...] | None = None, - source: str | None = None, - start_date: str | None = None, - end_date: str | None = None, - start_id: int = 0, - limit: int = 100, - reverse: bool = False, - ) -> list[Event]: - """Get matching events from the event stream based on filters. - - Args: - query: Text to search for in event content - event_types: Filter by event type classes (e.g., (FileReadAction, ) ). - source: Filter by event source - start_date: Filter events after this date (ISO format) - end_date: Filter events before this date (ISO format) - start_id: Starting ID in the event stream. Defaults to 0 - limit: Maximum number of events to return. Must be between 1 and 100. Defaults to 100 - reverse: Whether to retrieve events in reverse order. Defaults to False. - - Returns: - list: List of matching events (as dicts) - """ - if limit < 1 or limit > 100: - raise ValueError('Limit must be between 1 and 100') - - events = self.search_events( - start_id=start_id, - reverse=reverse, - filter=EventFilter( - query=query, - include_types=event_types, - source=source, - start_date=start_date, - end_date=end_date, - ), - ) - return list(islice(events, limit)) diff --git a/openhands/events/metrics.py b/openhands/events/metrics.py deleted file mode 100644 index a4204d3a06..0000000000 --- a/openhands/events/metrics.py +++ /dev/null @@ -1,284 +0,0 @@ -# IMPORTANT: LEGACY V0 CODE - Deprecated since version 1.0.0, scheduled for removal April 1, 2026 -# This file is part of the legacy (V0) implementation of OpenHands and will be removed soon as we complete the migration to V1. -# OpenHands V1 uses the Software Agent SDK for the agentic core and runs a new application server. Please refer to: -# - V1 agentic core (SDK): https://github.com/OpenHands/software-agent-sdk -# - V1 application server (in this repo): openhands/app_server/ -# Unless you are working on deprecation, please avoid extending this legacy file and consult the V1 codepaths above. -# Tag: Legacy-V0 -import copy -import time - -from pydantic import BaseModel, Field - - -class Cost(BaseModel): - model: str - cost: float - timestamp: float = Field(default_factory=time.time) - - -class ResponseLatency(BaseModel): - """Metric tracking the round-trip time per completion call.""" - - model: str - latency: float - response_id: str - - -class TokenUsage(BaseModel): - """Metric tracking detailed token usage per completion call.""" - - model: str = Field(default='') - prompt_tokens: int = Field(default=0) - completion_tokens: int = Field(default=0) - cache_read_tokens: int = Field(default=0) - cache_write_tokens: int = Field(default=0) - context_window: int = Field(default=0) - per_turn_token: int = Field(default=0) - response_id: str = Field(default='') - - def __add__(self, other: 'TokenUsage') -> 'TokenUsage': - """Add two TokenUsage instances together.""" - return TokenUsage( - model=self.model, - prompt_tokens=self.prompt_tokens + other.prompt_tokens, - completion_tokens=self.completion_tokens + other.completion_tokens, - cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens, - cache_write_tokens=self.cache_write_tokens + other.cache_write_tokens, - context_window=max(self.context_window, other.context_window), - per_turn_token=other.per_turn_token, - response_id=self.response_id, - ) - - -class Metrics: - """Metrics class can record various metrics during running and evaluation. - We track: - - accumulated_cost and costs - - max_budget_per_task (budget limit) - - A list of ResponseLatency - - A list of TokenUsage (one per call). - """ - - def __init__(self, model_name: str = 'default') -> None: - self._accumulated_cost: float = 0.0 - self._max_budget_per_task: float | None = None - self._costs: list[Cost] = [] - self._response_latencies: list[ResponseLatency] = [] - self.model_name = model_name - self._token_usages: list[TokenUsage] = [] - self._accumulated_token_usage: TokenUsage = TokenUsage( - model=model_name, - prompt_tokens=0, - completion_tokens=0, - cache_read_tokens=0, - cache_write_tokens=0, - context_window=0, - response_id='', - ) - - @property - def accumulated_cost(self) -> float: - return self._accumulated_cost - - @accumulated_cost.setter - def accumulated_cost(self, value: float) -> None: - if value < 0: - raise ValueError('Total cost cannot be negative.') - self._accumulated_cost = value - - @property - def max_budget_per_task(self) -> float | None: - return self._max_budget_per_task - - @max_budget_per_task.setter - def max_budget_per_task(self, value: float | None) -> None: - self._max_budget_per_task = value - - @property - def costs(self) -> list[Cost]: - return self._costs - - @property - def response_latencies(self) -> list[ResponseLatency]: - if not hasattr(self, '_response_latencies'): - self._response_latencies = [] - return self._response_latencies - - @response_latencies.setter - def response_latencies(self, value: list[ResponseLatency]) -> None: - self._response_latencies = value - - @property - def token_usages(self) -> list[TokenUsage]: - if not hasattr(self, '_token_usages'): - self._token_usages = [] - return self._token_usages - - @token_usages.setter - def token_usages(self, value: list[TokenUsage]) -> None: - self._token_usages = value - - @property - def accumulated_token_usage(self) -> TokenUsage: - """Get the accumulated token usage, initializing it if it doesn't exist.""" - if not hasattr(self, '_accumulated_token_usage'): - self._accumulated_token_usage = TokenUsage( - model=self.model_name, - prompt_tokens=0, - completion_tokens=0, - cache_read_tokens=0, - cache_write_tokens=0, - context_window=0, - response_id='', - ) - return self._accumulated_token_usage - - def add_cost(self, value: float) -> None: - if value < 0: - raise ValueError('Added cost cannot be negative.') - self._accumulated_cost += value - self._costs.append(Cost(cost=value, model=self.model_name)) - - def add_response_latency(self, value: float, response_id: str) -> None: - self._response_latencies.append( - ResponseLatency( - latency=max(0.0, value), model=self.model_name, response_id=response_id - ) - ) - - def add_token_usage( - self, - prompt_tokens: int, - completion_tokens: int, - cache_read_tokens: int, - cache_write_tokens: int, - context_window: int, - response_id: str, - ) -> None: - """Add a single usage record.""" - # Token each turn for calculating context usage. - per_turn_token = prompt_tokens + completion_tokens - - usage = TokenUsage( - model=self.model_name, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - cache_read_tokens=cache_read_tokens, - cache_write_tokens=cache_write_tokens, - context_window=context_window, - per_turn_token=per_turn_token, - response_id=response_id, - ) - self._token_usages.append(usage) - - # Update accumulated token usage using the __add__ operator - self._accumulated_token_usage = self.accumulated_token_usage + TokenUsage( - model=self.model_name, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - cache_read_tokens=cache_read_tokens, - cache_write_tokens=cache_write_tokens, - context_window=context_window, - per_turn_token=per_turn_token, - response_id='', - ) - - def merge(self, other: 'Metrics') -> None: - """Merge 'other' metrics into this one.""" - self._accumulated_cost += other.accumulated_cost - - # Keep the max_budget_per_task from other if it's set and this one isn't - if self._max_budget_per_task is None and other.max_budget_per_task is not None: - self._max_budget_per_task = other.max_budget_per_task - - self._costs += other._costs - # use the property so older picked objects that lack the field won't crash - self.token_usages += other.token_usages - self.response_latencies += other.response_latencies - - # Merge accumulated token usage using the __add__ operator - self._accumulated_token_usage = ( - self.accumulated_token_usage + other.accumulated_token_usage - ) - - def get(self) -> dict: - """Return the metrics in a dictionary.""" - return { - 'accumulated_cost': self._accumulated_cost, - 'max_budget_per_task': self._max_budget_per_task, - 'accumulated_token_usage': self.accumulated_token_usage.model_dump(), - 'costs': [cost.model_dump() for cost in self._costs], - 'response_latencies': [ - latency.model_dump() for latency in self._response_latencies - ], - 'token_usages': [usage.model_dump() for usage in self._token_usages], - } - - def log(self) -> str: - """Log the metrics.""" - metrics = self.get() - logs = '' - for key, value in metrics.items(): - logs += f'{key}: {value}\n' - return logs - - def copy(self) -> 'Metrics': - """Create a deep copy of the Metrics object.""" - return copy.deepcopy(self) - - def diff(self, baseline: 'Metrics') -> 'Metrics': - """Calculate the difference between current metrics and a baseline. - - This is useful for tracking metrics for specific operations like delegates. - - Args: - baseline: A metrics object representing the baseline state - - Returns: - A new Metrics object containing only the differences since the baseline - """ - result = Metrics(self.model_name) - - # Calculate cost difference - result._accumulated_cost = self._accumulated_cost - baseline._accumulated_cost - - # Include only costs that were added after the baseline - if baseline._costs: - last_baseline_timestamp = baseline._costs[-1].timestamp - result._costs = [ - cost for cost in self._costs if cost.timestamp > last_baseline_timestamp - ] - else: - result._costs = self._costs.copy() - - # Include only response latencies that were added after the baseline - result._response_latencies = self._response_latencies[ - len(baseline._response_latencies) : - ] - - # Include only token usages that were added after the baseline - result._token_usages = self._token_usages[len(baseline._token_usages) :] - - # Calculate accumulated token usage difference - base_usage = baseline.accumulated_token_usage - current_usage = self.accumulated_token_usage - - result._accumulated_token_usage = TokenUsage( - model=self.model_name, - prompt_tokens=current_usage.prompt_tokens - base_usage.prompt_tokens, - completion_tokens=current_usage.completion_tokens - - base_usage.completion_tokens, - cache_read_tokens=current_usage.cache_read_tokens - - base_usage.cache_read_tokens, - cache_write_tokens=current_usage.cache_write_tokens - - base_usage.cache_write_tokens, - context_window=current_usage.context_window, - per_turn_token=0, - response_id='', - ) - - return result - - def __repr__(self) -> str: - return f'Metrics({self.get()}' diff --git a/openhands/events/nested_event_store.py b/openhands/events/nested_event_store.py deleted file mode 100644 index cd7819eb1a..0000000000 --- a/openhands/events/nested_event_store.py +++ /dev/null @@ -1,101 +0,0 @@ -from dataclasses import dataclass -from typing import Iterable -from urllib.parse import urlencode - -import httpx # type: ignore -from fastapi import status - -from openhands.events.event import Event -from openhands.events.event_filter import EventFilter -from openhands.events.event_store_abc import EventStoreABC -from openhands.events.serialization.event import event_from_dict - - -@dataclass -class NestedEventStore(EventStoreABC): - """A stored list of events backing a conversation""" - - base_url: str - sid: str - user_id: str | None - session_api_key: str | None = None - - def search_events( - self, - start_id: int = 0, - end_id: int | None = None, - reverse: bool = False, - filter: EventFilter | None = None, - limit: int | None = None, - ) -> Iterable[Event]: - # Maintain explicit cursors for pagination to avoid accidental mutation. - start_cursor = start_id - end_cursor: int | None = None # Used only for reverse pagination - while True: - search_params: dict[str, int | bool] = { - 'start_id': start_cursor, - 'reverse': reverse, - } - if reverse and end_cursor is not None: - # Bound the upper end when scanning backwards to avoid duplicates - search_params['end_id'] = end_cursor - if limit is not None: - search_params['limit'] = min(100, limit) - search_str = urlencode(search_params) - url = f'{self.base_url}/events?{search_str}' - headers: dict[str, str] = {} - if self.session_api_key: - headers['X-Session-API-Key'] = self.session_api_key - response = httpx.get(url, headers=headers) - if response.status_code == status.HTTP_404_NOT_FOUND: - # Follow pattern of event store not throwing errors on not found - return - result_set = response.json() - - page_min_id: int | None = None - forward_next_start = start_cursor - for result in result_set['events']: - event = event_from_dict(result) - if reverse: - page_min_id = ( - event.id if page_min_id is None else min(page_min_id, event.id) - ) - else: - forward_next_start = max(forward_next_start, event.id + 1) - if end_id == event.id: - if not filter or filter.include(event): - yield event - return - if filter and filter.exclude(event): - continue - yield event - if limit is not None: - limit -= 1 - if limit <= 0: - return - - # Update pagination cursor for next request - if reverse and page_min_id is not None: - # Next page should end strictly before the smallest ID we just saw - end_cursor = page_min_id - 1 - elif not reverse: - start_cursor = forward_next_start - - if not result_set['has_more']: - return - - def get_event(self, id: int) -> Event: - events = list(self.search_events(start_id=id, limit=1)) - if not events: - raise FileNotFoundError('no_event') - return events[0] - - def get_latest_event(self) -> Event: - events = list(self.search_events(reverse=True, limit=1)) - if not events: - raise FileNotFoundError('no_event') - return events[0] - - def get_latest_event_id(self) -> int: - event = self.get_latest_event() - return event.id diff --git a/openhands/events/observation/__init__.py b/openhands/events/observation/__init__.py deleted file mode 100644 index 9145b2bbe0..0000000000 --- a/openhands/events/observation/__init__.py +++ /dev/null @@ -1,55 +0,0 @@ -from openhands.events.observation.agent import ( - AgentCondensationObservation, - AgentStateChangedObservation, - AgentThinkObservation, - RecallObservation, -) -from openhands.events.observation.browse import BrowserOutputObservation -from openhands.events.observation.commands import ( - CmdOutputMetadata, - CmdOutputObservation, - IPythonRunCellObservation, -) -from openhands.events.observation.delegate import AgentDelegateObservation -from openhands.events.observation.empty import ( - NullObservation, -) -from openhands.events.observation.error import ErrorObservation -from openhands.events.observation.file_download import FileDownloadObservation -from openhands.events.observation.files import ( - FileEditObservation, - FileReadObservation, - FileWriteObservation, -) -from openhands.events.observation.loop_recovery import LoopDetectionObservation -from openhands.events.observation.mcp import MCPObservation -from openhands.events.observation.observation import Observation -from openhands.events.observation.reject import UserRejectObservation -from openhands.events.observation.success import SuccessObservation -from openhands.events.observation.task_tracking import TaskTrackingObservation -from openhands.events.recall_type import RecallType - -__all__ = [ - 'Observation', - 'NullObservation', - 'AgentThinkObservation', - 'CmdOutputObservation', - 'CmdOutputMetadata', - 'IPythonRunCellObservation', - 'BrowserOutputObservation', - 'FileReadObservation', - 'FileWriteObservation', - 'FileEditObservation', - 'ErrorObservation', - 'AgentStateChangedObservation', - 'AgentDelegateObservation', - 'SuccessObservation', - 'UserRejectObservation', - 'AgentCondensationObservation', - 'RecallObservation', - 'RecallType', - 'LoopDetectionObservation', - 'MCPObservation', - 'FileDownloadObservation', - 'TaskTrackingObservation', -] diff --git a/openhands/events/observation/agent.py b/openhands/events/observation/agent.py deleted file mode 100644 index 2b73d2ff07..0000000000 --- a/openhands/events/observation/agent.py +++ /dev/null @@ -1,138 +0,0 @@ -from dataclasses import dataclass, field - -from openhands.core.schema import ObservationType -from openhands.events.observation.observation import Observation -from openhands.events.recall_type import RecallType - - -@dataclass -class AgentStateChangedObservation(Observation): - """This data class represents the result from delegating to another agent""" - - agent_state: str - reason: str = '' - observation: str = ObservationType.AGENT_STATE_CHANGED - - @property - def message(self) -> str: - return '' - - -@dataclass -class AgentCondensationObservation(Observation): - """The output of a condensation action.""" - - observation: str = ObservationType.CONDENSE - - @property - def message(self) -> str: - return self.content - - -@dataclass -class AgentThinkObservation(Observation): - """The output of a think action. - - In practice, this is a no-op, since it will just reply a static message to the agent - acknowledging that the thought has been logged. - """ - - observation: str = ObservationType.THINK - - @property - def message(self) -> str: - return self.content - - -@dataclass -class MicroagentKnowledge: - """Represents knowledge from a triggered microagent. - - Attributes: - name: The name of the microagent that was triggered - trigger: The word that triggered this microagent - content: The actual content/knowledge from the microagent - """ - - name: str - trigger: str - content: str - - -@dataclass -class RecallObservation(Observation): - """The retrieval of content from a microagent or more microagents.""" - - recall_type: RecallType - observation: str = ObservationType.RECALL - - # workspace context - repo_name: str = '' - repo_directory: str = '' - repo_branch: str = '' - repo_instructions: str = '' - runtime_hosts: dict[str, int] = field(default_factory=dict) - additional_agent_instructions: str = '' - date: str = '' - custom_secrets_descriptions: dict[str, str] = field(default_factory=dict) - conversation_instructions: str = '' - working_dir: str = '' - - # knowledge - microagent_knowledge: list[MicroagentKnowledge] = field(default_factory=list) - """ - A list of MicroagentKnowledge objects, each containing information from a triggered microagent. - - Example: - [ - MicroagentKnowledge( - name="python_best_practices", - trigger="python", - content="Always use virtual environments for Python projects." - ), - MicroagentKnowledge( - name="git_workflow", - trigger="git", - content="Create a new branch for each feature or bugfix." - ) - ] - """ - - @property - def message(self) -> str: - return ( - 'Added workspace context' - if self.recall_type == RecallType.WORKSPACE_CONTEXT - else 'Added microagent knowledge' - ) - - def __str__(self) -> str: - # Build a string representation - fields = [] - if self.recall_type == RecallType.WORKSPACE_CONTEXT: - fields.extend( - [ - f'recall_type={self.recall_type}', - f'repo_name={self.repo_name}', - f'repo_instructions={self.repo_instructions[:20]}...', - f'runtime_hosts={self.runtime_hosts}', - f'additional_agent_instructions={self.additional_agent_instructions[:20]}...', - f'date={self.date}' - f'custom_secrets_descriptions={self.custom_secrets_descriptions}', - f'conversation_instructions={self.conversation_instructions[0:20]}...', - ] - ) - else: - fields.extend( - [ - f'recall_type={self.recall_type}', - ] - ) - if self.microagent_knowledge: - fields.extend( - [ - f'microagent_knowledge={", ".join([m.name for m in self.microagent_knowledge])}', - ] - ) - - return f'**RecallObservation**\n{", ".join(fields)}' diff --git a/openhands/events/observation/browse.py b/openhands/events/observation/browse.py deleted file mode 100644 index 9a565767ac..0000000000 --- a/openhands/events/observation/browse.py +++ /dev/null @@ -1,56 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any - -from openhands.core.schema import ObservationType -from openhands.events.observation.observation import Observation - - -@dataclass -class BrowserOutputObservation(Observation): - """This data class represents the output of a browser.""" - - url: str - trigger_by_action: str - screenshot: str = field(repr=False, default='') # don't show in repr - screenshot_path: str | None = field(default=None) # path to saved screenshot file - set_of_marks: str = field(default='', repr=False) # don't show in repr - error: bool = False - observation: str = ObservationType.BROWSE - goal_image_urls: list[str] = field(default_factory=list) - # do not include in the memory - open_pages_urls: list[str] = field(default_factory=list) - active_page_index: int = -1 - dom_object: dict[str, Any] = field( - default_factory=dict, repr=False - ) # don't show in repr - axtree_object: dict[str, Any] = field( - default_factory=dict, repr=False - ) # don't show in repr - extra_element_properties: dict[str, Any] = field( - default_factory=dict, repr=False - ) # don't show in repr - last_browser_action: str = '' - last_browser_action_error: str = '' - focused_element_bid: str = '' - filter_visible_only: bool = False - - @property - def message(self) -> str: - return 'Visited ' + self.url - - def __str__(self) -> str: - ret = ( - '**BrowserOutputObservation**\n' - f'URL: {self.url}\n' - f'Error: {self.error}\n' - f'Open pages: {self.open_pages_urls}\n' - f'Active page index: {self.active_page_index}\n' - f'Last browser action: {self.last_browser_action}\n' - f'Last browser action error: {self.last_browser_action_error}\n' - f'Focused element bid: {self.focused_element_bid}\n' - ) - if self.screenshot_path: - ret += f'Screenshot saved to: {self.screenshot_path}\n' - ret += '--- Agent Observation ---\n' - ret += self.content - return ret diff --git a/openhands/events/observation/commands.py b/openhands/events/observation/commands.py deleted file mode 100644 index da49ef641a..0000000000 --- a/openhands/events/observation/commands.py +++ /dev/null @@ -1,231 +0,0 @@ -import json -import re -from dataclasses import dataclass, field -from typing import Any, Self - -from pydantic import BaseModel - -from openhands.core.logger import openhands_logger as logger -from openhands.core.schema import ObservationType -from openhands.events.observation.observation import Observation - -CMD_OUTPUT_PS1_BEGIN = '\n###PS1JSON###\n' -CMD_OUTPUT_PS1_END = '\n###PS1END###' -CMD_OUTPUT_METADATA_PS1_REGEX = re.compile( - f'^{CMD_OUTPUT_PS1_BEGIN.strip()}(.*?){CMD_OUTPUT_PS1_END.strip()}', - re.DOTALL | re.MULTILINE, -) - -# Default max size for command output content -# to prevent too large observations from being saved in the stream -# This matches the default max_message_chars in LLMConfig -MAX_CMD_OUTPUT_SIZE: int = 30000 - - -class CmdOutputMetadata(BaseModel): - """Additional metadata captured from PS1""" - - exit_code: int = -1 - pid: int = -1 - username: str | None = None - hostname: str | None = None - working_dir: str | None = None - py_interpreter_path: str | None = None - prefix: str = '' # Prefix to add to command output - suffix: str = '' # Suffix to add to command output - - @classmethod - def to_ps1_prompt(cls) -> str: - """Convert the required metadata into a PS1 prompt.""" - prompt = CMD_OUTPUT_PS1_BEGIN - json_str = json.dumps( - { - 'pid': '$!', - 'exit_code': '$?', - 'username': r'\u', - 'hostname': r'\h', - 'working_dir': r'$(pwd)', - 'py_interpreter_path': r'$(which python 2>/dev/null || echo "")', - }, - indent=2, - ) - # Make sure we escape double quotes in the JSON string - # So that PS1 will keep them as part of the output - prompt += json_str.replace('"', r'\"') - prompt += CMD_OUTPUT_PS1_END + '\n' # Ensure there's a newline at the end - return prompt - - @classmethod - def matches_ps1_metadata(cls, string: str) -> list[re.Match[str]]: - matches = [] - for match in CMD_OUTPUT_METADATA_PS1_REGEX.finditer(string): - try: - json.loads(match.group(1).strip()) # Try to parse as JSON - matches.append(match) - except json.JSONDecodeError: - logger.warning( - f'Failed to parse PS1 metadata: {match.group(1)}. Skipping.', - exc_info=True, - ) - continue # Skip if not valid JSON - return matches - - @classmethod - def from_ps1_match(cls, match: re.Match[str]) -> Self: - """Extract the required metadata from a PS1 prompt.""" - metadata = json.loads(match.group(1)) - # Create a copy of metadata to avoid modifying the original - processed = metadata.copy() - # Convert numeric fields - if 'pid' in metadata: - try: - processed['pid'] = int(float(str(metadata['pid']))) - except (ValueError, TypeError): - processed['pid'] = -1 - if 'exit_code' in metadata: - try: - processed['exit_code'] = int(float(str(metadata['exit_code']))) - except (ValueError, TypeError): - logger.warning( - f'Failed to parse exit code: {metadata["exit_code"]}. Setting to -1.' - ) - processed['exit_code'] = -1 - return cls(**processed) - - -@dataclass -class CmdOutputObservation(Observation): - """This data class represents the output of a command.""" - - command: str - observation: str = ObservationType.RUN - # Additional metadata captured from PS1 - metadata: CmdOutputMetadata = field(default_factory=CmdOutputMetadata) - # Whether the command output should be hidden from the user - hidden: bool = False - - def __init__( - self, - content: str, - command: str, - observation: str = ObservationType.RUN, - metadata: dict[str, Any] | CmdOutputMetadata | None = None, - hidden: bool = False, - **kwargs: Any, - ) -> None: - # Truncate content before passing it to parent - # Hidden commands don't go through LLM/event stream, so no need to truncate - truncate = not hidden - if truncate: - content = self._maybe_truncate(content) - - super().__init__(content) - - self.command = command - self.observation = observation - self.hidden = hidden - if isinstance(metadata, dict): - self.metadata = CmdOutputMetadata(**metadata) - else: - self.metadata = metadata or CmdOutputMetadata() - - # Handle legacy attribute - if 'exit_code' in kwargs: - self.metadata.exit_code = kwargs['exit_code'] - if 'command_id' in kwargs: - self.metadata.pid = kwargs['command_id'] - - @staticmethod - def _maybe_truncate(content: str, max_size: int = MAX_CMD_OUTPUT_SIZE) -> str: - """Truncate the content if it's too large. - - This helps avoid storing unnecessarily large content in the event stream. - - Args: - content: The content to truncate - max_size: Maximum size before truncation. Defaults to MAX_CMD_OUTPUT_SIZE. - - Returns: - Original content if not too large, or truncated content otherwise - """ - if len(content) <= max_size: - return content - - # Truncate the middle and include a message about it - half = max_size // 2 - original_length = len(content) - truncated = ( - content[:half] - + '\n[... Observation truncated due to length ...]\n' - + content[-half:] - ) - logger.debug( - f'Truncated large command output: {original_length} -> {len(truncated)} chars' - ) - return truncated - - @property - def command_id(self) -> int: - return self.metadata.pid - - @property - def exit_code(self) -> int: - return self.metadata.exit_code - - @property - def error(self) -> bool: - return self.exit_code != 0 - - @property - def message(self) -> str: - return f'Command `{self.command}` executed with exit code {self.exit_code}.' - - @property - def success(self) -> bool: - return not self.error - - def __str__(self) -> str: - return ( - f'**CmdOutputObservation (source={self.source}, exit code={self.exit_code}, ' - f'metadata={json.dumps(self.metadata.model_dump(), indent=2)})**\n' - '--BEGIN AGENT OBSERVATION--\n' - f'{self.to_agent_observation()}\n' - '--END AGENT OBSERVATION--' - ) - - def to_agent_observation(self) -> str: - ret = f'{self.metadata.prefix}{self.content}{self.metadata.suffix}' - if self.metadata.working_dir: - ret += f'\n[Current working directory: {self.metadata.working_dir}]' - if self.metadata.py_interpreter_path: - ret += f'\n[Python interpreter: {self.metadata.py_interpreter_path}]' - if self.metadata.exit_code != -1: - ret += f'\n[Command finished with exit code {self.metadata.exit_code}]' - return ret - - -@dataclass -class IPythonRunCellObservation(Observation): - """This data class represents the output of a IPythonRunCellAction.""" - - code: str - observation: str = ObservationType.RUN_IPYTHON - image_urls: list[str] | None = None - - @property - def error(self) -> bool: - return False # IPython cells do not return exit codes - - @property - def message(self) -> str: - return 'Code executed in IPython cell.' - - @property - def success(self) -> bool: - return True # IPython cells are always considered successful - - def __str__(self) -> str: - result = f'**IPythonRunCellObservation**\n{self.content}' - if self.image_urls: - result += f'\nImages: {len(self.image_urls)}' - return result diff --git a/openhands/events/observation/delegate.py b/openhands/events/observation/delegate.py deleted file mode 100644 index 9e98c6b598..0000000000 --- a/openhands/events/observation/delegate.py +++ /dev/null @@ -1,22 +0,0 @@ -from dataclasses import dataclass - -from openhands.core.schema import ObservationType -from openhands.events.observation.observation import Observation - - -@dataclass -class AgentDelegateObservation(Observation): - """This data class represents the result from delegating to another agent. - - Attributes: - content (str): The content of the observation. - outputs (dict): The outputs of the delegated agent. - observation (str): The type of observation. - """ - - outputs: dict - observation: str = ObservationType.DELEGATE - - @property - def message(self) -> str: - return '' diff --git a/openhands/events/observation/empty.py b/openhands/events/observation/empty.py deleted file mode 100644 index 9d7d0f18a7..0000000000 --- a/openhands/events/observation/empty.py +++ /dev/null @@ -1,17 +0,0 @@ -from dataclasses import dataclass - -from openhands.core.schema import ObservationType -from openhands.events.observation.observation import Observation - - -@dataclass -class NullObservation(Observation): - """This data class represents a null observation. - This is used when the produced action is NOT executable. - """ - - observation: str = ObservationType.NULL - - @property - def message(self) -> str: - return 'No observation' diff --git a/openhands/events/observation/error.py b/openhands/events/observation/error.py deleted file mode 100644 index 4ed05b89ac..0000000000 --- a/openhands/events/observation/error.py +++ /dev/null @@ -1,23 +0,0 @@ -from dataclasses import dataclass - -from openhands.core.schema import ObservationType -from openhands.events.observation.observation import Observation - - -@dataclass -class ErrorObservation(Observation): - """This data class represents an error encountered by the agent. - - This is the type of error that LLM can recover from. - E.g., Linter error after editing a file. - """ - - observation: str = ObservationType.ERROR - error_id: str = '' - - @property - def message(self) -> str: - return self.content - - def __str__(self) -> str: - return f'**ErrorObservation**\n{self.content}' diff --git a/openhands/events/observation/file_download.py b/openhands/events/observation/file_download.py deleted file mode 100644 index f80b6019d4..0000000000 --- a/openhands/events/observation/file_download.py +++ /dev/null @@ -1,21 +0,0 @@ -from dataclasses import dataclass - -from openhands.core.schema import ObservationType -from openhands.events.observation.observation import Observation - - -@dataclass -class FileDownloadObservation(Observation): - file_path: str - observation: str = ObservationType.DOWNLOAD - - @property - def message(self) -> str: - return f'Downloaded the file at location: {self.file_path}' - - def __str__(self) -> str: - ret = ( - '**FileDownloadObservation**\n' - f'Location of downloaded file: {self.file_path}\n' - ) - return ret diff --git a/openhands/events/observation/files.py b/openhands/events/observation/files.py deleted file mode 100644 index 673018a833..0000000000 --- a/openhands/events/observation/files.py +++ /dev/null @@ -1,195 +0,0 @@ -"""File-related observation classes for tracking file operations.""" - -from dataclasses import dataclass -from difflib import SequenceMatcher - -from openhands.core.schema import ObservationType -from openhands.events.event import FileEditSource, FileReadSource -from openhands.events.observation.observation import Observation - - -@dataclass -class FileReadObservation(Observation): - """This data class represents the content of a file.""" - - path: str - observation: str = ObservationType.READ - impl_source: FileReadSource = FileReadSource.DEFAULT - - @property - def message(self) -> str: - """Get a human-readable message describing the file read operation.""" - return f'I read the file {self.path}.' - - def __str__(self) -> str: - """Get a string representation of the file read observation.""" - return f'[Read from {self.path} is successful.]\n{self.content}' - - -@dataclass -class FileWriteObservation(Observation): - """This data class represents a file write operation.""" - - path: str - observation: str = ObservationType.WRITE - - @property - def message(self) -> str: - """Get a human-readable message describing the file write operation.""" - return f'I wrote to the file {self.path}.' - - def __str__(self) -> str: - """Get a string representation of the file write observation.""" - return f'[Write to {self.path} is successful.]\n{self.content}' - - -@dataclass -class FileEditObservation(Observation): - """This data class represents a file edit operation. - - The observation includes both the old and new content of the file, and can - generate a diff visualization showing the changes. The diff is computed lazily - and cached to improve performance. - - The .content property can either be: - - Git diff in LLM-based editing mode - - the rendered message sent to the LLM in OH_ACI mode (e.g., "The file /path/to/file.txt is created with the provided content.") - """ - - path: str = '' - prev_exist: bool = False - old_content: str | None = None - new_content: str | None = None - observation: str = ObservationType.EDIT - impl_source: FileEditSource = FileEditSource.LLM_BASED_EDIT - diff: str | None = ( - None # The raw diff between old and new content, used in OH_ACI mode - ) - _diff_cache: str | None = ( - None # Cache for the diff visualization, used in LLM-based editing mode - ) - - @property - def message(self) -> str: - """Get a human-readable message describing the file edit operation.""" - return f'I edited the file {self.path}.' - - def get_edit_groups(self, n_context_lines: int = 2) -> list[dict[str, list[str]]]: - """Get the edit groups showing changes between old and new content. - - Args: - n_context_lines: Number of context lines to show around each change. - - Returns: - A list of edit groups, where each group contains before/after edits. - """ - if self.old_content is None or self.new_content is None: - return [] - old_lines = self.old_content.split('\n') - new_lines = self.new_content.split('\n') - # Borrowed from difflib.unified_diff to directly parse into structured format - edit_groups: list[dict] = [] - for group in SequenceMatcher(None, old_lines, new_lines).get_grouped_opcodes( - n_context_lines - ): - # Take the max line number in the group - _indent_pad_size = len(str(group[-1][3])) + 1 # +1 for "*" prefix - cur_group: dict[str, list[str]] = { - 'before_edits': [], - 'after_edits': [], - } - for tag, i1, i2, j1, j2 in group: - if tag == 'equal': - for idx, line in enumerate(old_lines[i1:i2]): - line_num = i1 + idx + 1 - cur_group['before_edits'].append( - f'{line_num:>{_indent_pad_size}}|{line}' - ) - for idx, line in enumerate(new_lines[j1:j2]): - line_num = j1 + idx + 1 - cur_group['after_edits'].append( - f'{line_num:>{_indent_pad_size}}|{line}' - ) - continue - if tag in {'replace', 'delete'}: - for idx, line in enumerate(old_lines[i1:i2]): - line_num = i1 + idx + 1 - cur_group['before_edits'].append( - f'-{line_num:>{_indent_pad_size - 1}}|{line}' - ) - if tag in {'replace', 'insert'}: - for idx, line in enumerate(new_lines[j1:j2]): - line_num = j1 + idx + 1 - cur_group['after_edits'].append( - f'+{line_num:>{_indent_pad_size - 1}}|{line}' - ) - edit_groups.append(cur_group) - return edit_groups - - def visualize_diff( - self, - n_context_lines: int = 2, - change_applied: bool = True, - ) -> str: - """Visualize the diff of the file edit. Used in the LLM-based editing mode. - - Instead of showing the diff line by line, this function shows each hunk - of changes as a separate entity. - - Args: - n_context_lines: Number of context lines to show before/after changes. - change_applied: Whether changes are applied. If false, shows as - attempted edit. - - Returns: - A string containing the formatted diff visualization. - """ - # Use cached diff if available - if self._diff_cache is not None: - return self._diff_cache - - # Check if there are any changes - if change_applied and self.old_content == self.new_content: - msg = '(no changes detected. Please make sure your edits change ' - msg += 'the content of the existing file.)\n' - self._diff_cache = msg - return self._diff_cache - - edit_groups = self.get_edit_groups(n_context_lines=n_context_lines) - - if change_applied: - header = f'[Existing file {self.path} is edited with ' - header += f'{len(edit_groups)} changes.]' - else: - header = f"[Changes are NOT applied to {self.path} - Here's how " - header += 'the file looks like if changes are applied.]' - result = [header] - - op_type = 'edit' if change_applied else 'ATTEMPTED edit' - for i, cur_edit_group in enumerate(edit_groups): - if i != 0: - result.append('-------------------------') - result.append(f'[begin of {op_type} {i + 1} / {len(edit_groups)}]') - result.append(f'(content before {op_type})') - result.extend(cur_edit_group['before_edits']) - result.append(f'(content after {op_type})') - result.extend(cur_edit_group['after_edits']) - result.append(f'[end of {op_type} {i + 1} / {len(edit_groups)}]') - - # Cache the result - self._diff_cache = '\n'.join(result) - return self._diff_cache - - def __str__(self) -> str: - """Get a string representation of the file edit observation.""" - if self.impl_source == FileEditSource.OH_ACI: - return self.content - - if not self.prev_exist: - assert self.old_content == '', ( - 'old_content should be empty if the file is new (prev_exist=False).' - ) - return f'[New file {self.path} is created with the provided content.]\n' - - # Use cached diff if available, otherwise compute it - return self.visualize_diff().rstrip() + '\n' diff --git a/openhands/events/observation/loop_recovery.py b/openhands/events/observation/loop_recovery.py deleted file mode 100644 index 34c11870a2..0000000000 --- a/openhands/events/observation/loop_recovery.py +++ /dev/null @@ -1,18 +0,0 @@ -from dataclasses import dataclass - -from openhands.core.schema import ObservationType -from openhands.events.observation.observation import Observation - - -@dataclass -class LoopDetectionObservation(Observation): - """Observation for loop recovery state changes. - - This observation is used to notify the UI layer when agent - is in loop recovery mode. - - This observation is CLI-specific and should only be displayed - in CLI/TUI mode, not in GUI or other UI modes. - """ - - observation: str = ObservationType.LOOP_DETECTION diff --git a/openhands/events/observation/mcp.py b/openhands/events/observation/mcp.py deleted file mode 100644 index 019e1c72e8..0000000000 --- a/openhands/events/observation/mcp.py +++ /dev/null @@ -1,20 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any - -from openhands.core.schema import ObservationType -from openhands.events.observation.observation import Observation - - -@dataclass -class MCPObservation(Observation): - """This data class represents the result of a MCP Server operation.""" - - observation: str = ObservationType.MCP - name: str = '' # The name of the MCP tool that was called - arguments: dict[str, Any] = field( - default_factory=dict - ) # The arguments passed to the MCP tool - - @property - def message(self) -> str: - return self.content diff --git a/openhands/events/observation/observation.py b/openhands/events/observation/observation.py deleted file mode 100644 index f0407d0f3f..0000000000 --- a/openhands/events/observation/observation.py +++ /dev/null @@ -1,15 +0,0 @@ -from dataclasses import dataclass - -from openhands.events.event import Event - - -@dataclass -class Observation(Event): - """Base class for observations from the environment. - - Attributes: - content: The content of the observation. For large observations, - this might be truncated when stored. - """ - - content: str diff --git a/openhands/events/observation/reject.py b/openhands/events/observation/reject.py deleted file mode 100644 index 16973c8083..0000000000 --- a/openhands/events/observation/reject.py +++ /dev/null @@ -1,15 +0,0 @@ -from dataclasses import dataclass - -from openhands.core.schema import ObservationType -from openhands.events.observation.observation import Observation - - -@dataclass -class UserRejectObservation(Observation): - """This data class represents the result of a rejected action.""" - - observation: str = ObservationType.USER_REJECTED - - @property - def message(self) -> str: - return self.content diff --git a/openhands/events/observation/success.py b/openhands/events/observation/success.py deleted file mode 100644 index 27ef979977..0000000000 --- a/openhands/events/observation/success.py +++ /dev/null @@ -1,15 +0,0 @@ -from dataclasses import dataclass - -from openhands.core.schema import ObservationType -from openhands.events.observation.observation import Observation - - -@dataclass -class SuccessObservation(Observation): - """This data class represents the result of a successful action.""" - - observation: str = ObservationType.SUCCESS - - @property - def message(self) -> str: - return self.content diff --git a/openhands/events/observation/task_tracking.py b/openhands/events/observation/task_tracking.py deleted file mode 100644 index 50595398ff..0000000000 --- a/openhands/events/observation/task_tracking.py +++ /dev/null @@ -1,18 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any - -from openhands.core.schema import ObservationType -from openhands.events.observation.observation import Observation - - -@dataclass -class TaskTrackingObservation(Observation): - """This data class represents the result of a task tracking operation.""" - - observation: str = ObservationType.TASK_TRACKING - command: str = '' - task_list: list[dict[str, Any]] = field(default_factory=list) - - @property - def message(self) -> str: - return self.content diff --git a/openhands/events/recall_type.py b/openhands/events/recall_type.py deleted file mode 100644 index 2031950f1a..0000000000 --- a/openhands/events/recall_type.py +++ /dev/null @@ -1,11 +0,0 @@ -from enum import Enum - - -class RecallType(str, Enum): - """The type of information that can be retrieved from microagents.""" - - WORKSPACE_CONTEXT = 'workspace_context' - """Workspace context (repo instructions, runtime, etc.)""" - - KNOWLEDGE = 'knowledge' - """A knowledge microagent.""" diff --git a/openhands/events/serialization/__init__.py b/openhands/events/serialization/__init__.py deleted file mode 100644 index 76793821b4..0000000000 --- a/openhands/events/serialization/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from openhands.events.serialization.action import ( - action_from_dict, -) -from openhands.events.serialization.event import ( - event_from_dict, - event_to_dict, - event_to_trajectory, -) -from openhands.events.serialization.observation import ( - observation_from_dict, -) - -__all__ = [ - 'action_from_dict', - 'event_from_dict', - 'event_to_dict', - 'event_to_trajectory', - 'observation_from_dict', -] diff --git a/openhands/events/serialization/action.py b/openhands/events/serialization/action.py deleted file mode 100644 index 98f2b89e61..0000000000 --- a/openhands/events/serialization/action.py +++ /dev/null @@ -1,156 +0,0 @@ -from typing import Any - -from openhands.core.exceptions import LLMMalformedActionError -from openhands.events.action.action import Action, ActionSecurityRisk -from openhands.events.action.agent import ( - AgentDelegateAction, - AgentFinishAction, - AgentRejectAction, - AgentThinkAction, - ChangeAgentStateAction, - CondensationAction, - CondensationRequestAction, - LoopRecoveryAction, - RecallAction, - TaskTrackingAction, -) -from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction -from openhands.events.action.commands import ( - CmdRunAction, - IPythonRunCellAction, -) -from openhands.events.action.empty import NullAction -from openhands.events.action.files import ( - FileEditAction, - FileReadAction, - FileWriteAction, -) -from openhands.events.action.mcp import MCPAction -from openhands.events.action.message import MessageAction, SystemMessageAction - -actions = ( - NullAction, - CmdRunAction, - IPythonRunCellAction, - BrowseURLAction, - BrowseInteractiveAction, - FileReadAction, - FileWriteAction, - FileEditAction, - AgentThinkAction, - AgentFinishAction, - AgentRejectAction, - AgentDelegateAction, - RecallAction, - ChangeAgentStateAction, - MessageAction, - SystemMessageAction, - CondensationAction, - CondensationRequestAction, - MCPAction, - TaskTrackingAction, - LoopRecoveryAction, -) - -ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined] - - -def handle_action_deprecated_args(args: dict[str, Any]) -> dict[str, Any]: - # keep_prompt has been deprecated in https://github.com/OpenHands/OpenHands/pull/4881 - if 'keep_prompt' in args: - args.pop('keep_prompt') - - # task_completed has been deprecated - remove it from args to maintain backward compatibility - if 'task_completed' in args: - args.pop('task_completed') - - # Handle translated_ipython_code deprecation - if 'translated_ipython_code' in args: - code = args.pop('translated_ipython_code') - - # Check if it's a file_editor call using a prefix check for efficiency - file_editor_prefix = 'print(file_editor(**' - if ( - code is not None - and code.startswith(file_editor_prefix) - and code.endswith('))') - ): - try: - # Extract and evaluate the dictionary string - import ast - - # Extract the dictionary string between the prefix and the closing parentheses - dict_str = code[len(file_editor_prefix) : -2] # Remove prefix and '))' - file_args = ast.literal_eval(dict_str) - - # Update args with the extracted file editor arguments - args.update(file_args) - except (ValueError, SyntaxError): - # If parsing fails, just remove the translated_ipython_code - pass - - if args.get('command') == 'view': - args.pop( - 'command' - ) # "view" will be translated to FileReadAction which doesn't have a command argument - - return args - - -def action_from_dict(action: dict) -> Action: - if not isinstance(action, dict): - raise LLMMalformedActionError('action must be a dictionary') - action = action.copy() - if 'action' not in action: - raise LLMMalformedActionError(f"'action' key is not found in {action=}") - if not isinstance(action['action'], str): - raise LLMMalformedActionError( - f"'{action['action']=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}" - ) - action_class = ACTION_TYPE_TO_CLASS.get(action['action']) - if action_class is None: - raise LLMMalformedActionError( - f"'{action['action']=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}" - ) - args = action.get('args', {}) - # Remove timestamp from args if present - timestamp = args.pop('timestamp', None) - - # compatibility for older event streams - # is_confirmed has been renamed to confirmation_state - is_confirmed = args.pop('is_confirmed', None) - if is_confirmed is not None: - args['confirmation_state'] = is_confirmed - - # images_urls has been renamed to image_urls - if 'images_urls' in args: - args['image_urls'] = args.pop('images_urls') - - # Handle security_risk deserialization - if 'security_risk' in args and args['security_risk'] is not None: - try: - # Convert numeric value (int) back to enum - args['security_risk'] = ActionSecurityRisk(args['security_risk']) - except (ValueError, TypeError): - # If conversion fails, remove the invalid value - args.pop('security_risk') - - # handle deprecated args - args = handle_action_deprecated_args(args) - - try: - decoded_action = action_class(**args) - if 'timeout' in action: - blocking = args.get('blocking', False) - decoded_action.set_hard_timeout(action['timeout'], blocking=blocking) - - # Set timestamp if it was provided - if timestamp: - decoded_action._timestamp = timestamp - - except TypeError as e: - raise LLMMalformedActionError( - f'action={action} has the wrong arguments: {str(e)}' - ) - assert isinstance(decoded_action, Action) - return decoded_action diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py deleted file mode 100644 index 5c17068553..0000000000 --- a/openhands/events/serialization/event.py +++ /dev/null @@ -1,178 +0,0 @@ -from dataclasses import asdict -from datetime import datetime -from enum import Enum -from typing import Any - -from pydantic import BaseModel - -from openhands.events import Event, EventSource -from openhands.events.metrics import Cost, Metrics, ResponseLatency, TokenUsage -from openhands.events.serialization.action import action_from_dict -from openhands.events.serialization.observation import observation_from_dict -from openhands.events.serialization.utils import remove_fields -from openhands.events.tool import ToolCallMetadata - -# TODO: move `content` into `extras` -TOP_KEYS = [ - 'id', - 'timestamp', - 'source', - 'message', - 'cause', - 'action', - 'observation', - 'tool_call_metadata', - 'llm_metrics', -] -UNDERSCORE_KEYS = [ - 'id', - 'timestamp', - 'source', - 'cause', - 'tool_call_metadata', - 'llm_metrics', -] - -DELETE_FROM_TRAJECTORY_EXTRAS = { - 'dom_object', - 'axtree_object', - 'active_page_index', - 'last_browser_action', - 'last_browser_action_error', - 'focused_element_bid', - 'extra_element_properties', -} - -DELETE_FROM_TRAJECTORY_EXTRAS_AND_SCREENSHOTS = DELETE_FROM_TRAJECTORY_EXTRAS | { - 'screenshot', - 'set_of_marks', -} - - -def event_from_dict(data: dict[str, Any]) -> 'Event': - evt: Event - if 'action' in data: - evt = action_from_dict(data) - elif 'observation' in data: - evt = observation_from_dict(data) - else: - raise ValueError(f'Unknown event type: {data}') - for key in UNDERSCORE_KEYS: - if key in data: - value = data[key] - if key == 'timestamp' and isinstance(value, datetime): - value = value.isoformat() - if key == 'source': - value = EventSource(value) - if key == 'tool_call_metadata': - value = ToolCallMetadata(**value) - if key == 'llm_metrics': - metrics = Metrics() - if isinstance(value, dict): - metrics.accumulated_cost = value.get('accumulated_cost', 0.0) - # Set max_budget_per_task if available - metrics.max_budget_per_task = value.get('max_budget_per_task') - for cost in value.get('costs', []): - metrics._costs.append(Cost(**cost)) - metrics.response_latencies = [ - ResponseLatency(**latency) - for latency in value.get('response_latencies', []) - ] - metrics.token_usages = [ - TokenUsage(**usage) for usage in value.get('token_usages', []) - ] - # Set accumulated token usage if available - if 'accumulated_token_usage' in value: - metrics._accumulated_token_usage = TokenUsage( - **value.get('accumulated_token_usage', {}) - ) - value = metrics - setattr(evt, '_' + key, value) - return evt - - -def _convert_pydantic_to_dict(obj: BaseModel | dict) -> dict: - if isinstance(obj, BaseModel): - return obj.model_dump() - return obj - - -def event_to_dict(event: 'Event') -> dict: - props = asdict(event) - d = {} - for key in TOP_KEYS: - if hasattr(event, key) and getattr(event, key) is not None: - d[key] = getattr(event, key) - elif hasattr(event, f'_{key}') and getattr(event, f'_{key}') is not None: - d[key] = getattr(event, f'_{key}') - if key == 'id' and d.get('id') == -1: - d.pop('id', None) - if key == 'timestamp' and 'timestamp' in d: - if isinstance(d['timestamp'], datetime): - d['timestamp'] = d['timestamp'].isoformat() - if key == 'source' and 'source' in d: - d['source'] = d['source'].value - if key == 'recall_type' and 'recall_type' in d: - d['recall_type'] = d['recall_type'].value - if key == 'tool_call_metadata' and 'tool_call_metadata' in d: - d['tool_call_metadata'] = d['tool_call_metadata'].model_dump() - if key == 'llm_metrics' and 'llm_metrics' in d: - d['llm_metrics'] = d['llm_metrics'].get() - props.pop(key, None) - - if 'security_risk' in props and props['security_risk'] is None: - props.pop('security_risk') - - # Remove task_completed from serialization when it's None (backward compatibility) - if 'task_completed' in props and props['task_completed'] is None: - props.pop('task_completed') - if 'action' in d: - # Handle security_risk for actions - include it in args - if 'security_risk' in props: - props['security_risk'] = props['security_risk'].value - d['args'] = props - if event.timeout is not None: - d['timeout'] = event.timeout - elif 'observation' in d: - d['content'] = props.pop('content', '') - - # props is a dict whose values can include a complex object like an instance of a BaseModel subclass - # such as CmdOutputMetadata - # we serialize it along with the rest - # we also handle the Enum conversion for RecallObservation - d['extras'] = { - k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v)) - for k, v in props.items() - } - # Include success field for CmdOutputObservation - if hasattr(event, 'success'): - d['success'] = event.success - else: - raise ValueError(f'Event must be either action or observation. has: {event}') - return d - - -def event_to_trajectory(event: 'Event', include_screenshots: bool = False) -> dict: - d = event_to_dict(event) - if 'extras' in d: - remove_fields( - d['extras'], - DELETE_FROM_TRAJECTORY_EXTRAS - if include_screenshots - else DELETE_FROM_TRAJECTORY_EXTRAS_AND_SCREENSHOTS, - ) - return d - - -def truncate_content(content: str, max_chars: int | None = None) -> str: - """Truncate the middle of the observation content if it is too long.""" - if max_chars is None or len(content) <= max_chars or max_chars < 0: - return content - - # truncate the middle and include a message to the LLM about it - half = max_chars // 2 - return ( - content[:half] - + '\n[... Observation truncated due to length ...]\n' - + content[-half:] - ) diff --git a/openhands/events/serialization/observation.py b/openhands/events/serialization/observation.py deleted file mode 100644 index a7e50754a3..0000000000 --- a/openhands/events/serialization/observation.py +++ /dev/null @@ -1,142 +0,0 @@ -import copy -from typing import Any - -from openhands.events.observation.agent import ( - AgentCondensationObservation, - AgentStateChangedObservation, - AgentThinkObservation, - MicroagentKnowledge, - RecallObservation, -) -from openhands.events.observation.browse import BrowserOutputObservation -from openhands.events.observation.commands import ( - CmdOutputMetadata, - CmdOutputObservation, - IPythonRunCellObservation, -) -from openhands.events.observation.delegate import AgentDelegateObservation -from openhands.events.observation.empty import ( - NullObservation, -) -from openhands.events.observation.error import ErrorObservation -from openhands.events.observation.file_download import FileDownloadObservation -from openhands.events.observation.files import ( - FileEditObservation, - FileReadObservation, - FileWriteObservation, -) -from openhands.events.observation.loop_recovery import LoopDetectionObservation -from openhands.events.observation.mcp import MCPObservation -from openhands.events.observation.observation import Observation -from openhands.events.observation.reject import UserRejectObservation -from openhands.events.observation.success import SuccessObservation -from openhands.events.observation.task_tracking import TaskTrackingObservation -from openhands.events.recall_type import RecallType - -observations = ( - NullObservation, - CmdOutputObservation, - IPythonRunCellObservation, - BrowserOutputObservation, - FileReadObservation, - FileWriteObservation, - FileEditObservation, - AgentDelegateObservation, - SuccessObservation, - ErrorObservation, - AgentStateChangedObservation, - UserRejectObservation, - AgentCondensationObservation, - AgentThinkObservation, - RecallObservation, - MCPObservation, - FileDownloadObservation, - TaskTrackingObservation, - LoopDetectionObservation, -) - -OBSERVATION_TYPE_TO_CLASS = { - observation_class.observation: observation_class # type: ignore[attr-defined] - for observation_class in observations -} - - -def _update_cmd_output_metadata( - metadata: dict[str, Any] | CmdOutputMetadata | None, **kwargs: Any -) -> dict[str, Any] | CmdOutputMetadata: - """Update the metadata of a CmdOutputObservation. - - If metadata is None, create a new CmdOutputMetadata instance. - If metadata is a dict, update the dict. - If metadata is a CmdOutputMetadata instance, update the instance. - """ - if metadata is None: - return CmdOutputMetadata(**kwargs) - - if isinstance(metadata, dict): - metadata.update(**kwargs) - elif isinstance(metadata, CmdOutputMetadata): - for key, value in kwargs.items(): - setattr(metadata, key, value) - return metadata - - -def handle_observation_deprecated_extras(extras: dict) -> dict: - # These are deprecated in https://github.com/OpenHands/OpenHands/pull/4881 - if 'exit_code' in extras: - extras['metadata'] = _update_cmd_output_metadata( - extras.get('metadata', None), exit_code=extras.pop('exit_code') - ) - if 'command_id' in extras: - extras['metadata'] = _update_cmd_output_metadata( - extras.get('metadata', None), pid=extras.pop('command_id') - ) - - # formatted_output_and_error has been deprecated in https://github.com/OpenHands/OpenHands/pull/6671 - if 'formatted_output_and_error' in extras: - extras.pop('formatted_output_and_error') - return extras - - -def observation_from_dict(observation: dict) -> Observation: - observation = observation.copy() - if 'observation' not in observation: - raise KeyError(f"'observation' key is not found in {observation=}") - observation_class = OBSERVATION_TYPE_TO_CLASS.get(observation['observation']) - if observation_class is None: - raise KeyError( - f"'{observation['observation']=}' is not defined. Available observations: {OBSERVATION_TYPE_TO_CLASS.keys()}" - ) - observation.pop('observation') - observation.pop('message', None) - content = observation.pop('content', '') - extras = copy.deepcopy(observation.pop('extras', {})) - - extras = handle_observation_deprecated_extras(extras) - - # convert metadata to CmdOutputMetadata if it is a dict - if observation_class is CmdOutputObservation: - if 'metadata' in extras and isinstance(extras['metadata'], dict): - extras['metadata'] = CmdOutputMetadata(**extras['metadata']) - elif 'metadata' in extras and isinstance(extras['metadata'], CmdOutputMetadata): - pass - else: - extras['metadata'] = CmdOutputMetadata() - - if observation_class is RecallObservation: - # handle the Enum conversion - if 'recall_type' in extras: - extras['recall_type'] = RecallType(extras['recall_type']) - - # convert dicts in microagent_knowledge to MicroagentKnowledge objects - if 'microagent_knowledge' in extras and isinstance( - extras['microagent_knowledge'], list - ): - extras['microagent_knowledge'] = [ - MicroagentKnowledge(**item) if isinstance(item, dict) else item - for item in extras['microagent_knowledge'] - ] - - obs = observation_class(content=content, **extras) - assert isinstance(obs, Observation) - return obs diff --git a/openhands/events/serialization/utils.py b/openhands/events/serialization/utils.py deleted file mode 100644 index c2eac8a996..0000000000 --- a/openhands/events/serialization/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -def remove_fields(obj: dict | list | tuple, fields: set[str]) -> None: - """Remove fields from an object. - - Parameters: - - obj: The dictionary, or list of dictionaries to remove fields from - - fields (set[str]): A set of field names to remove from the object - """ - if isinstance(obj, dict): - for field in fields: - if field in obj: - del obj[field] - for _, value in obj.items(): - remove_fields(value, fields) - elif isinstance(obj, (list, tuple)): - for item in obj: - remove_fields(item, fields) - if hasattr(obj, '__dataclass_fields__'): - raise ValueError( - 'Object must not contain dataclass, consider converting to dict first' - ) diff --git a/openhands/events/stream.py b/openhands/events/stream.py deleted file mode 100644 index d464527a48..0000000000 --- a/openhands/events/stream.py +++ /dev/null @@ -1,291 +0,0 @@ -import asyncio -import json -import queue -import threading -from concurrent.futures import ThreadPoolExecutor -from datetime import datetime -from enum import Enum -from functools import partial -from typing import Any, Callable - -from openhands.core.logger import openhands_logger as logger -from openhands.events.event import Event, EventSource -from openhands.events.event_store import EventStore -from openhands.events.serialization.event import event_from_dict, event_to_dict -from openhands.storage import FileStore -from openhands.storage.locations import ( - get_conversation_dir, -) -from openhands.utils.async_utils import call_sync_from_async -from openhands.utils.shutdown_listener import should_continue - - -class EventStreamSubscriber(str, Enum): - AGENT_CONTROLLER = 'agent_controller' - RESOLVER = 'openhands_resolver' - SERVER = 'server' - RUNTIME = 'runtime' - MEMORY = 'memory' - MAIN = 'main' - TEST = 'test' - - -async def session_exists( - sid: str, file_store: FileStore, user_id: str | None = None -) -> bool: - try: - await call_sync_from_async(file_store.list, get_conversation_dir(sid, user_id)) - return True - except FileNotFoundError: - return False - - -class EventStream(EventStore): - secrets: dict[str, str] - # For each subscriber ID, there is a map of callback functions - useful - # when there are multiple listeners - _subscribers: dict[str, dict[str, Callable]] - _lock: threading.Lock - _queue: queue.Queue[Event] - _queue_thread: threading.Thread - _queue_loop: asyncio.AbstractEventLoop | None - _thread_pools: dict[str, dict[str, ThreadPoolExecutor]] - _thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]] - _write_page_cache: list[dict] - - def __init__(self, sid: str, file_store: FileStore, user_id: str | None = None): - super().__init__(sid, file_store, user_id) - self._stop_flag = threading.Event() - self._queue: queue.Queue[Event] = queue.Queue() - self._thread_pools = {} - self._thread_loops = {} - self._queue_loop = None - self._queue_thread = threading.Thread(target=self._run_queue_loop) - self._queue_thread.daemon = True - self._queue_thread.start() - self._subscribers = {} - self._lock = threading.Lock() - self.secrets = {} - self._write_page_cache = [] - - def _init_thread_loop(self, subscriber_id: str, callback_id: str) -> None: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - if subscriber_id not in self._thread_loops: - self._thread_loops[subscriber_id] = {} - self._thread_loops[subscriber_id][callback_id] = loop - - def close(self) -> None: - self._stop_flag.set() - if self._queue_thread.is_alive(): - self._queue_thread.join() - - subscriber_ids = list(self._subscribers.keys()) - for subscriber_id in subscriber_ids: - callback_ids = list(self._subscribers[subscriber_id].keys()) - for callback_id in callback_ids: - self._clean_up_subscriber(subscriber_id, callback_id) - - # Clear queue - while not self._queue.empty(): - self._queue.get() - - def _clean_up_subscriber(self, subscriber_id: str, callback_id: str) -> None: - if subscriber_id not in self._subscribers: - logger.warning(f'Subscriber not found during cleanup: {subscriber_id}') - return - if callback_id not in self._subscribers[subscriber_id]: - logger.warning(f'Callback not found during cleanup: {callback_id}') - return - if ( - subscriber_id in self._thread_loops - and callback_id in self._thread_loops[subscriber_id] - ): - loop = self._thread_loops[subscriber_id][callback_id] - current_task = asyncio.current_task(loop) - pending = [ - task for task in asyncio.all_tasks(loop) if task is not current_task - ] - for task in pending: - task.cancel() - try: - loop.stop() - loop.close() - except Exception as e: - logger.warning( - f'Error closing loop for {subscriber_id}/{callback_id}: {e}' - ) - del self._thread_loops[subscriber_id][callback_id] - - if ( - subscriber_id in self._thread_pools - and callback_id in self._thread_pools[subscriber_id] - ): - pool = self._thread_pools[subscriber_id][callback_id] - pool.shutdown() - del self._thread_pools[subscriber_id][callback_id] - - del self._subscribers[subscriber_id][callback_id] - - def subscribe( - self, - subscriber_id: EventStreamSubscriber, - callback: Callable[[Event], None], - callback_id: str, - ) -> None: - initializer = partial(self._init_thread_loop, subscriber_id, callback_id) - pool = ThreadPoolExecutor(max_workers=1, initializer=initializer) - if subscriber_id not in self._subscribers: - self._subscribers[subscriber_id] = {} - self._thread_pools[subscriber_id] = {} - - if callback_id in self._subscribers[subscriber_id]: - raise ValueError( - f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}' - ) - - self._subscribers[subscriber_id][callback_id] = callback - self._thread_pools[subscriber_id][callback_id] = pool - - def unsubscribe( - self, subscriber_id: EventStreamSubscriber, callback_id: str - ) -> None: - if subscriber_id not in self._subscribers: - logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}') - return - - if callback_id not in self._subscribers[subscriber_id]: - logger.warning(f'Callback not found during unsubscribe: {callback_id}') - return - - self._clean_up_subscriber(subscriber_id, callback_id) - - def add_event(self, event: Event, source: EventSource) -> None: - if event.id != Event.INVALID_ID: - raise ValueError( - f'Event already has an ID:{event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.' - ) - event._timestamp = datetime.now().isoformat() - event._source = source # type: ignore [attr-defined] - with self._lock: - event._id = self.cur_id # type: ignore [attr-defined] - self.cur_id += 1 - - # Take a copy of the current write page - current_write_page = self._write_page_cache - - data = event_to_dict(event) - data = self._replace_secrets(data) - event = event_from_dict(data) - current_write_page.append(data) - - # If the page is full, create a new page for future events / other threads to use - if len(current_write_page) == self.cache_size: - self._write_page_cache = [] - - if event.id is not None: - # Write the event to the store - this can take some time - event_json = json.dumps(data) - filename = self._get_filename_for_id(event.id, self.user_id) - if len(event_json) > 1_000_000: # Roughly 1MB in bytes, ignoring encoding - logger.warning( - f'Saving event JSON over 1MB: {len(event_json):,} bytes, filename: {filename}', - extra={ - 'user_id': self.user_id, - 'session_id': self.sid, - 'size': len(event_json), - }, - ) - self.file_store.write(filename, event_json) - - # Store the cache page last - if it is not present during reads then it will simply be bypassed. - self._store_cache_page(current_write_page) - self._queue.put(event) - - def _store_cache_page(self, current_write_page: list[dict]): - """Store a page in the cache. Reading individual events is slow when there are a lot of them, so we use pages.""" - if len(current_write_page) < self.cache_size: - return - start = current_write_page[0]['id'] - end = start + self.cache_size - contents = json.dumps(current_write_page) - cache_filename = self._get_filename_for_cache(start, end) - self.file_store.write(cache_filename, contents) - - def set_secrets(self, secrets: dict[str, str]) -> None: - self.secrets = secrets.copy() - - def update_secrets(self, secrets: dict[str, str]) -> None: - self.secrets.update(secrets) - - def _replace_secrets( - self, data: dict[str, Any], is_top_level: bool = True - ) -> dict[str, Any]: - # Fields that should not have secrets replaced (only at top level - system metadata) - TOP_LEVEL_PROTECTED_FIELDS = { - 'timestamp', - 'id', - 'source', - 'cause', - 'action', - 'observation', - 'message', - } - - for key in data: - if is_top_level and key in TOP_LEVEL_PROTECTED_FIELDS: - # Skip secret replacement for protected system fields at top level only - continue - elif isinstance(data[key], dict): - data[key] = self._replace_secrets(data[key], is_top_level=False) - elif isinstance(data[key], str): - for secret in self.secrets.values(): - data[key] = data[key].replace(secret, '') - return data - - def _run_queue_loop(self) -> None: - self._queue_loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._queue_loop) - try: - self._queue_loop.run_until_complete(self._process_queue()) - finally: - self._queue_loop.close() - - async def _process_queue(self) -> None: - while should_continue() and not self._stop_flag.is_set(): - event = None - try: - event = self._queue.get(timeout=0.1) - except queue.Empty: - continue - - # pass each event to each callback in order - for key in sorted(self._subscribers.keys()): - callbacks = self._subscribers[key] - # Create a copy of the keys to avoid "dictionary changed size during iteration" error - callback_ids = list(callbacks.keys()) - for callback_id in callback_ids: - # Check if callback_id still exists (might have been removed during iteration) - if callback_id in callbacks: - callback = callbacks[callback_id] - pool = self._thread_pools[key][callback_id] - future = pool.submit(callback, event) - future.add_done_callback( - self._make_error_handler(callback_id, key) - ) - - def _make_error_handler( - self, callback_id: str, subscriber_id: str - ) -> Callable[[Any], None]: - def _handle_callback_error(fut: Any) -> None: - try: - # This will raise any exception that occurred during callback execution - fut.result() - except Exception as e: - logger.error( - f'Error in event callback {callback_id} for subscriber {subscriber_id}: {str(e)}', - ) - # Re-raise in the main thread so the error is not swallowed - raise e - - return _handle_callback_error diff --git a/openhands/events/tool.py b/openhands/events/tool.py deleted file mode 100644 index 30e288dc2f..0000000000 --- a/openhands/events/tool.py +++ /dev/null @@ -1,11 +0,0 @@ -from litellm import ModelResponse -from pydantic import BaseModel - - -class ToolCallMetadata(BaseModel): - # See https://docs.litellm.ai/docs/completion/function_call#step-3---second-litellmcompletion-call - function_name: str # Name of the function that was called - tool_call_id: str # ID of the tool call - - model_response: ModelResponse - total_calls_in_response: int diff --git a/tests/unit/app_server/test_webhook_router_stats.py b/tests/unit/app_server/test_webhook_router_stats.py index 3384180a3b..a3175fc835 100644 --- a/tests/unit/app_server/test_webhook_router_stats.py +++ b/tests/unit/app_server/test_webhook_router_stats.py @@ -529,18 +529,19 @@ class TestOnEventStatsProcessing: @pytest.mark.asyncio async def test_on_event_skips_non_stats_events(self): """Test that on_event skips non-stats events.""" - from unittest.mock import patch + from unittest.mock import MagicMock, patch from openhands.app_server.event_callback.webhook_router import on_event - from openhands.events.action.message import MessageAction conversation_id = uuid4() sandbox_id = 'sandbox_123' - # Create non-stats events + # Create non-stats events (use MagicMock for non-ConversationStateUpdateEvent) + mock_other_event = MagicMock() + mock_other_event.id = uuid4() events = [ ConversationStateUpdateEvent(key='execution_status', value='running'), - MessageAction(content='test'), + mock_other_event, ] mock_app_conversation_info = AppConversationInfo( diff --git a/tests/unit/events/test_action_serialization.py b/tests/unit/events/test_action_serialization.py deleted file mode 100644 index eda90ebe33..0000000000 --- a/tests/unit/events/test_action_serialization.py +++ /dev/null @@ -1,439 +0,0 @@ -from openhands.events.action import ( - Action, - AgentFinishAction, - AgentRejectAction, - BrowseInteractiveAction, - BrowseURLAction, - CmdRunAction, - FileEditAction, - FileReadAction, - FileWriteAction, - MessageAction, - RecallAction, -) -from openhands.events.action.action import ActionConfirmationStatus -from openhands.events.action.files import FileEditSource, FileReadSource -from openhands.events.serialization import ( - event_from_dict, - event_to_dict, -) - - -def serialization_deserialization( - original_action_dict, cls, max_message_chars: int = 10000 -): - action_instance = event_from_dict(original_action_dict) - assert isinstance(action_instance, Action), ( - 'The action instance should be an instance of Action.' - ) - assert isinstance(action_instance, cls), ( - f'The action instance should be an instance of {cls.__name__}.' - ) - - # event_to_dict is the regular serialization of an event - serialized_action_dict = event_to_dict(action_instance) - - # it has an extra message property, for the UI - serialized_action_dict.pop('message') - assert serialized_action_dict == original_action_dict, ( - 'The serialized action should match the original action dict.' - ) - - -def test_event_props_serialization_deserialization(): - original_action_dict = { - 'id': 42, - 'source': 'agent', - 'timestamp': '2021-08-01T12:00:00', - 'action': 'message', - 'args': { - 'content': 'This is a test.', - 'image_urls': None, - 'file_urls': None, - 'wait_for_response': False, - 'security_risk': -1, - }, - } - serialization_deserialization(original_action_dict, MessageAction) - - -def test_message_action_serialization_deserialization(): - original_action_dict = { - 'action': 'message', - 'args': { - 'content': 'This is a test.', - 'image_urls': None, - 'file_urls': None, - 'wait_for_response': False, - 'security_risk': -1, - }, - } - serialization_deserialization(original_action_dict, MessageAction) - - -def test_agent_finish_action_serialization_deserialization(): - original_action_dict = { - 'action': 'finish', - 'args': { - 'outputs': {}, - 'thought': '', - 'final_thought': '', - }, - } - serialization_deserialization(original_action_dict, AgentFinishAction) - - -def test_agent_finish_action_legacy_task_completed_serialization(): - """Test that old conversations with task_completed can still be loaded.""" - original_action_dict = { - 'action': 'finish', - 'args': { - 'outputs': {}, - 'thought': '', - 'final_thought': 'Task completed', - 'task_completed': 'true', # This should be ignored during deserialization - }, - } - # This should work without errors - task_completed should be stripped out - event = event_from_dict(original_action_dict) - assert isinstance(event, Action) - assert isinstance(event, AgentFinishAction) - assert event.final_thought == 'Task completed' - # task_completed attribute should not exist anymore - assert not hasattr(event, 'task_completed') - - # When serialized back, task_completed should not be present - event_dict = event_to_dict(event) - assert 'task_completed' not in event_dict['args'] - - -def test_agent_reject_action_serialization_deserialization(): - original_action_dict = { - 'action': 'reject', - 'args': {'outputs': {}, 'thought': ''}, - } - serialization_deserialization(original_action_dict, AgentRejectAction) - - -def test_cmd_run_action_serialization_deserialization(): - original_action_dict = { - 'action': 'run', - 'args': { - 'blocking': False, - 'command': 'echo "Hello world"', - 'is_input': False, - 'thought': '', - 'hidden': False, - 'confirmation_state': ActionConfirmationStatus.CONFIRMED, - 'is_static': False, - 'cwd': None, - 'security_risk': -1, - }, - } - serialization_deserialization(original_action_dict, CmdRunAction) - - -def test_browse_url_action_serialization_deserialization(): - original_action_dict = { - 'action': 'browse', - 'args': { - 'thought': '', - 'url': 'https://www.example.com', - 'return_axtree': False, - 'security_risk': -1, - }, - } - serialization_deserialization(original_action_dict, BrowseURLAction) - - -def test_browse_interactive_action_serialization_deserialization(): - original_action_dict = { - 'action': 'browse_interactive', - 'args': { - 'thought': '', - 'browser_actions': 'goto("https://www.example.com")', - 'browsergym_send_msg_to_user': '', - 'return_axtree': False, - 'security_risk': -1, - }, - } - serialization_deserialization(original_action_dict, BrowseInteractiveAction) - - -def test_file_read_action_serialization_deserialization(): - original_action_dict = { - 'action': 'read', - 'args': { - 'path': '/path/to/file.txt', - 'start': 0, - 'end': -1, - 'thought': 'None', - 'impl_source': 'default', - 'view_range': None, - 'security_risk': -1, - }, - } - serialization_deserialization(original_action_dict, FileReadAction) - - -def test_file_write_action_serialization_deserialization(): - original_action_dict = { - 'action': 'write', - 'args': { - 'path': '/path/to/file.txt', - 'content': 'Hello world', - 'start': 0, - 'end': 1, - 'thought': 'None', - 'security_risk': -1, - }, - } - serialization_deserialization(original_action_dict, FileWriteAction) - - -def test_file_edit_action_aci_serialization_deserialization(): - original_action_dict = { - 'action': 'edit', - 'args': { - 'path': '/path/to/file.txt', - 'command': 'str_replace', - 'file_text': None, - 'old_str': 'old text', - 'new_str': 'new text', - 'insert_line': None, - 'content': '', - 'start': 1, - 'end': -1, - 'thought': 'Replacing text', - 'impl_source': 'oh_aci', - 'security_risk': -1, - }, - } - serialization_deserialization(original_action_dict, FileEditAction) - - -def test_file_edit_action_llm_serialization_deserialization(): - original_action_dict = { - 'action': 'edit', - 'args': { - 'path': '/path/to/file.txt', - 'command': None, - 'file_text': None, - 'old_str': None, - 'new_str': None, - 'insert_line': None, - 'content': 'Updated content', - 'start': 1, - 'end': 10, - 'thought': 'Updating file content', - 'impl_source': 'llm_based_edit', - 'security_risk': -1, - }, - } - serialization_deserialization(original_action_dict, FileEditAction) - - -def test_cmd_run_action_legacy_serialization(): - original_action_dict = { - 'action': 'run', - 'args': { - 'blocking': False, - 'command': 'echo "Hello world"', - 'thought': '', - 'hidden': False, - 'confirmation_state': ActionConfirmationStatus.CONFIRMED, - 'keep_prompt': False, # will be treated as no-op - }, - } - event = event_from_dict(original_action_dict) - assert isinstance(event, Action) - assert isinstance(event, CmdRunAction) - assert event.command == 'echo "Hello world"' - assert event.hidden is False - assert not hasattr(event, 'keep_prompt') - - event_dict = event_to_dict(event) - assert 'keep_prompt' not in event_dict['args'] - assert ( - event_dict['args']['confirmation_state'] == ActionConfirmationStatus.CONFIRMED - ) - assert event_dict['args']['blocking'] is False - assert event_dict['args']['command'] == 'echo "Hello world"' - assert event_dict['args']['thought'] == '' - assert event_dict['args']['is_input'] is False - - -def test_file_llm_based_edit_action_legacy_serialization(): - original_action_dict = { - 'action': 'edit', - 'args': { - 'path': '/path/to/file.txt', - 'content': 'dummy content', - 'start': 1, - 'end': -1, - 'thought': 'Replacing text', - 'impl_source': 'oh_aci', - 'translated_ipython_code': None, - }, - } - event = event_from_dict(original_action_dict) - assert isinstance(event, Action) - assert isinstance(event, FileEditAction) - - # Common arguments - assert event.path == '/path/to/file.txt' - assert event.thought == 'Replacing text' - assert event.impl_source == FileEditSource.OH_ACI - assert not hasattr(event, 'translated_ipython_code') - - # OH_ACI arguments - assert event.command == '' - assert event.file_text is None - assert event.old_str is None - assert event.new_str is None - assert event.insert_line is None - - # LLM-based editing arguments - assert event.content == 'dummy content' - assert event.start == 1 - assert event.end == -1 - - event_dict = event_to_dict(event) - assert 'translated_ipython_code' not in event_dict['args'] - - # Common arguments - assert event_dict['args']['path'] == '/path/to/file.txt' - assert event_dict['args']['impl_source'] == 'oh_aci' - assert event_dict['args']['thought'] == 'Replacing text' - - # OH_ACI arguments - assert event_dict['args']['command'] == '' - assert event_dict['args']['file_text'] is None - assert event_dict['args']['old_str'] is None - assert event_dict['args']['new_str'] is None - assert event_dict['args']['insert_line'] is None - - # LLM-based editing arguments - assert event_dict['args']['content'] == 'dummy content' - assert event_dict['args']['start'] == 1 - assert event_dict['args']['end'] == -1 - - -def test_file_ohaci_edit_action_legacy_serialization(): - original_action_dict = { - 'action': 'edit', - 'args': { - 'path': '/workspace/game_2048.py', - 'content': '', - 'start': 1, - 'end': -1, - 'thought': "I'll help you create a simple 2048 game in Python. I'll use the str_replace_editor to create the file.", - 'impl_source': 'oh_aci', - 'translated_ipython_code': "print(file_editor(**{'command': 'create', 'path': '/workspace/game_2048.py', 'file_text': 'New file content'}))", - }, - } - event = event_from_dict(original_action_dict) - assert isinstance(event, Action) - assert isinstance(event, FileEditAction) - - # Common arguments - assert event.path == '/workspace/game_2048.py' - assert ( - event.thought - == "I'll help you create a simple 2048 game in Python. I'll use the str_replace_editor to create the file." - ) - assert event.impl_source == FileEditSource.OH_ACI - assert not hasattr(event, 'translated_ipython_code') - - # OH_ACI arguments - assert event.command == 'create' - assert event.file_text == 'New file content' - assert event.old_str is None - assert event.new_str is None - assert event.insert_line is None - - # LLM-based editing arguments - assert event.content == '' - assert event.start == 1 - assert event.end == -1 - - event_dict = event_to_dict(event) - assert 'translated_ipython_code' not in event_dict['args'] - - # Common arguments - assert event_dict['args']['path'] == '/workspace/game_2048.py' - assert event_dict['args']['impl_source'] == 'oh_aci' - assert ( - event_dict['args']['thought'] - == "I'll help you create a simple 2048 game in Python. I'll use the str_replace_editor to create the file." - ) - - # OH_ACI arguments - assert event_dict['args']['command'] == 'create' - assert event_dict['args']['file_text'] == 'New file content' - assert event_dict['args']['old_str'] is None - assert event_dict['args']['new_str'] is None - assert event_dict['args']['insert_line'] is None - - # LLM-based editing arguments - assert event_dict['args']['content'] == '' - assert event_dict['args']['start'] == 1 - assert event_dict['args']['end'] == -1 - - -def test_agent_microagent_action_serialization_deserialization(): - original_action_dict = { - 'action': 'recall', - 'args': { - 'query': 'What is the capital of France?', - 'thought': 'I need to find information about France', - 'recall_type': 'knowledge', - }, - } - serialization_deserialization(original_action_dict, RecallAction) - - -def test_file_read_action_legacy_serialization(): - original_action_dict = { - 'action': 'read', - 'args': { - 'path': '/workspace/test.txt', - 'start': 0, - 'end': -1, - 'thought': 'Reading the file contents', - 'impl_source': 'oh_aci', - 'translated_ipython_code': "print(file_editor(**{'command': 'view', 'path': '/workspace/test.txt'}))", - }, - } - - event = event_from_dict(original_action_dict) - assert isinstance(event, Action) - assert isinstance(event, FileReadAction) - - # Common arguments - assert event.path == '/workspace/test.txt' - assert event.thought == 'Reading the file contents' - assert event.impl_source == FileReadSource.OH_ACI - assert not hasattr(event, 'translated_ipython_code') - assert not hasattr( - event, 'command' - ) # FileReadAction should not have command attribute - - # Read-specific arguments - assert event.start == 0 - assert event.end == -1 - - event_dict = event_to_dict(event) - assert 'translated_ipython_code' not in event_dict['args'] - assert ( - 'command' not in event_dict['args'] - ) # command should not be in serialized args - - # Common arguments in serialized form - assert event_dict['args']['path'] == '/workspace/test.txt' - assert event_dict['args']['impl_source'] == 'oh_aci' - assert event_dict['args']['thought'] == 'Reading the file contents' - - # Read-specific arguments in serialized form - assert event_dict['args']['start'] == 0 - assert event_dict['args']['end'] == -1 diff --git a/tests/unit/events/test_command_success.py b/tests/unit/events/test_command_success.py deleted file mode 100644 index 298a3bcb4f..0000000000 --- a/tests/unit/events/test_command_success.py +++ /dev/null @@ -1,32 +0,0 @@ -from openhands.events.observation.commands import ( - CmdOutputMetadata, - CmdOutputObservation, - IPythonRunCellObservation, -) - - -def test_cmd_output_success(): - # Test successful command - obs = CmdOutputObservation( - command='ls', - content='file1.txt\nfile2.txt', - metadata=CmdOutputMetadata(exit_code=0), - ) - assert obs.success is True - assert obs.error is False - - # Test failed command - obs = CmdOutputObservation( - command='ls', - content='No such file or directory', - metadata=CmdOutputMetadata(exit_code=1), - ) - assert obs.success is False - assert obs.error is True - - -def test_ipython_cell_success(): - # IPython cells are always successful - obs = IPythonRunCellObservation(code='print("Hello")', content='Hello') - assert obs.success is True - assert obs.error is False diff --git a/tests/unit/events/test_event_serialization.py b/tests/unit/events/test_event_serialization.py deleted file mode 100644 index 3fdaab61bb..0000000000 --- a/tests/unit/events/test_event_serialization.py +++ /dev/null @@ -1,152 +0,0 @@ -from openhands.events.action import CmdRunAction, MessageAction -from openhands.events.action.action import ActionSecurityRisk -from openhands.events.metrics import Cost, Metrics, ResponseLatency, TokenUsage -from openhands.events.observation import CmdOutputMetadata, CmdOutputObservation -from openhands.events.serialization import event_from_dict, event_to_dict - - -def test_command_output_success_serialization(): - # Test successful command - obs = CmdOutputObservation( - command='ls', - content='file1.txt\nfile2.txt', - metadata=CmdOutputMetadata(exit_code=0), - ) - serialized = event_to_dict(obs) - assert serialized['success'] is True - - # Test failed command - obs = CmdOutputObservation( - command='ls', - content='No such file or directory', - metadata=CmdOutputMetadata(exit_code=1), - ) - serialized = event_to_dict(obs) - assert serialized['success'] is False - - -def test_metrics_basic_serialization(): - # Create a basic action with only accumulated_cost - action = MessageAction(content='Hello, world!') - metrics = Metrics() - metrics.accumulated_cost = 0.03 - action._llm_metrics = metrics - - # Test serialization - serialized = event_to_dict(action) - assert 'llm_metrics' in serialized - assert serialized['llm_metrics']['accumulated_cost'] == 0.03 - assert serialized['llm_metrics']['costs'] == [] - assert serialized['llm_metrics']['response_latencies'] == [] - assert serialized['llm_metrics']['token_usages'] == [] - - # Test deserialization - deserialized = event_from_dict(serialized) - assert deserialized.llm_metrics is not None - assert deserialized.llm_metrics.accumulated_cost == 0.03 - assert len(deserialized.llm_metrics.costs) == 0 - assert len(deserialized.llm_metrics.response_latencies) == 0 - assert len(deserialized.llm_metrics.token_usages) == 0 - - -def test_metrics_full_serialization(): - # Create an observation with all metrics fields - obs = CmdOutputObservation( - command='ls', - content='test.txt', - metadata=CmdOutputMetadata(exit_code=0), - ) - metrics = Metrics(model_name='test-model') - metrics.accumulated_cost = 0.03 - - # Add a cost - cost = Cost(model='test-model', cost=0.02) - metrics._costs.append(cost) - - # Add a response latency - latency = ResponseLatency(model='test-model', latency=0.5, response_id='test-id') - metrics.response_latencies = [latency] - - # Add token usage - usage = TokenUsage( - model='test-model', - prompt_tokens=10, - completion_tokens=20, - cache_read_tokens=0, - cache_write_tokens=0, - response_id='test-id', - ) - metrics.token_usages = [usage] - - obs._llm_metrics = metrics - - # Test serialization - serialized = event_to_dict(obs) - assert 'llm_metrics' in serialized - metrics_dict = serialized['llm_metrics'] - assert metrics_dict['accumulated_cost'] == 0.03 - assert len(metrics_dict['costs']) == 1 - assert metrics_dict['costs'][0]['cost'] == 0.02 - assert len(metrics_dict['response_latencies']) == 1 - assert metrics_dict['response_latencies'][0]['latency'] == 0.5 - assert len(metrics_dict['token_usages']) == 1 - assert metrics_dict['token_usages'][0]['prompt_tokens'] == 10 - assert metrics_dict['token_usages'][0]['completion_tokens'] == 20 - - # Test deserialization - deserialized = event_from_dict(serialized) - assert deserialized.llm_metrics is not None - assert deserialized.llm_metrics.accumulated_cost == 0.03 - assert len(deserialized.llm_metrics.costs) == 1 - assert deserialized.llm_metrics.costs[0].cost == 0.02 - assert len(deserialized.llm_metrics.response_latencies) == 1 - assert deserialized.llm_metrics.response_latencies[0].latency == 0.5 - assert len(deserialized.llm_metrics.token_usages) == 1 - assert deserialized.llm_metrics.token_usages[0].prompt_tokens == 10 - assert deserialized.llm_metrics.token_usages[0].completion_tokens == 20 - - -def test_metrics_none_serialization(): - # Test when metrics is None - obs = CmdOutputObservation( - command='ls', - content='test.txt', - metadata=CmdOutputMetadata(exit_code=0), - ) - obs._llm_metrics = None - - # Test serialization - serialized = event_to_dict(obs) - assert 'llm_metrics' not in serialized - - # Test deserialization - deserialized = event_from_dict(serialized) - assert deserialized.llm_metrics is None - - -def test_action_risk_serialization(): - # Test action with security risk - action = CmdRunAction(command='rm -rf /tmp/test') - action.security_risk = ActionSecurityRisk.HIGH - - # Test serialization - serialized = event_to_dict(action) - assert 'security_risk' in serialized['args'] - assert serialized['args']['security_risk'] == ActionSecurityRisk.HIGH.value - - # Test deserialization - deserialized = event_from_dict(serialized) - assert deserialized.security_risk == ActionSecurityRisk.HIGH - - # Test action with no security risk - action = CmdRunAction(command='ls') - # Don't set action_risk - - # Test serialization - serialized = event_to_dict(action) - assert 'security_risk' in serialized['args'] - assert serialized['args']['security_risk'] == ActionSecurityRisk.UNKNOWN.value - - # Test deserialization - deserialized = event_from_dict(serialized) - assert deserialized.security_risk == ActionSecurityRisk.UNKNOWN diff --git a/tests/unit/events/test_event_stream.py b/tests/unit/events/test_event_stream.py deleted file mode 100644 index 0895259b02..0000000000 --- a/tests/unit/events/test_event_stream.py +++ /dev/null @@ -1,865 +0,0 @@ -import gc -import json -import os -import time -from datetime import datetime - -import psutil -import pytest -from pytest import TempPathFactory - -from openhands.core.schema import ActionType, ObservationType -from openhands.events import EventSource, EventStream, EventStreamSubscriber -from openhands.events.action import ( - CmdRunAction, - NullAction, -) -from openhands.events.action.files import ( - FileEditAction, - FileReadAction, - FileWriteAction, -) -from openhands.events.action.message import MessageAction -from openhands.events.event import FileEditSource, FileReadSource -from openhands.events.event_filter import EventFilter -from openhands.events.observation import NullObservation -from openhands.events.observation.files import ( - FileEditObservation, - FileReadObservation, - FileWriteObservation, -) -from openhands.events.serialization.event import event_to_dict -from openhands.storage import get_file_store -from openhands.storage.locations import ( - get_conversation_event_filename, -) - - -@pytest.fixture -def temp_dir(tmp_path_factory: TempPathFactory) -> str: - return str(tmp_path_factory.mktemp('test_event_stream')) - - -def collect_events(stream): - return [event for event in stream.get_events()] - - -def test_basic_flow(temp_dir: str): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('abc', file_store) - event_stream.add_event(NullAction(), EventSource.AGENT) - assert len(collect_events(event_stream)) == 1 - - -def test_stream_storage(temp_dir: str): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('abc', file_store) - event_stream.add_event(NullObservation(''), EventSource.AGENT) - assert len(collect_events(event_stream)) == 1 - content = event_stream.file_store.read(get_conversation_event_filename('abc', 0)) - assert content is not None - data = json.loads(content) - assert 'timestamp' in data - del data['timestamp'] - assert data == { - 'id': 0, - 'source': 'agent', - 'observation': 'null', - 'content': '', - 'extras': {}, - 'message': 'No observation', - } - - -def test_rehydration(temp_dir: str): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('abc', file_store) - event_stream.add_event(NullObservation('obs1'), EventSource.AGENT) - event_stream.add_event(NullObservation('obs2'), EventSource.AGENT) - assert len(collect_events(event_stream)) == 2 - - stream2 = EventStream('es2', file_store) - assert len(collect_events(stream2)) == 0 - - stream1rehydrated = EventStream('abc', file_store) - events = collect_events(stream1rehydrated) - assert len(events) == 2 - assert events[0].content == 'obs1' - assert events[1].content == 'obs2' - - -def test_get_matching_events_type_filter(temp_dir: str): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('abc', file_store) - - # Add mixed event types - event_stream.add_event(NullAction(), EventSource.AGENT) - event_stream.add_event(NullObservation('test'), EventSource.AGENT) - event_stream.add_event(NullAction(), EventSource.AGENT) - event_stream.add_event(MessageAction(content='test'), EventSource.AGENT) - - # Filter by NullAction - events = event_stream.get_matching_events(event_types=(NullAction,)) - assert len(events) == 2 - assert all(isinstance(e, NullAction) for e in events) - - # Filter by NullObservation - events = event_stream.get_matching_events(event_types=(NullObservation,)) - assert len(events) == 1 - assert ( - isinstance(events[0], NullObservation) - and events[0].observation == ObservationType.NULL - ) - - # Filter by NullAction and MessageAction - events = event_stream.get_matching_events(event_types=(NullAction, MessageAction)) - assert len(events) == 3 - - # Filter in reverse - events = event_stream.get_matching_events(reverse=True, limit=3) - assert len(events) == 3 - assert isinstance(events[0], MessageAction) and events[0].content == 'test' - assert isinstance(events[2], NullObservation) and events[2].content == 'test' - - -def test_get_matching_events_query_search(temp_dir: str): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('abc', file_store) - - event_stream.add_event(NullObservation('hello world'), EventSource.AGENT) - event_stream.add_event(NullObservation('test message'), EventSource.AGENT) - event_stream.add_event(NullObservation('another hello'), EventSource.AGENT) - - # Search for 'hello' - events = event_stream.get_matching_events(query='hello') - assert len(events) == 2 - - # Search should be case-insensitive - events = event_stream.get_matching_events(query='HELLO') - assert len(events) == 2 - - # Search for non-existent text - events = event_stream.get_matching_events(query='nonexistent') - assert len(events) == 0 - - -def test_get_matching_events_source_filter(temp_dir: str): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('abc', file_store) - - event_stream.add_event(NullObservation('test1'), EventSource.AGENT) - event_stream.add_event(NullObservation('test2'), EventSource.ENVIRONMENT) - event_stream.add_event(NullObservation('test3'), EventSource.AGENT) - - # Filter by AGENT source - events = event_stream.get_matching_events(source='agent') - assert len(events) == 2 - assert all( - isinstance(e, NullObservation) and e.source == EventSource.AGENT for e in events - ) - - # Filter by ENVIRONMENT source - events = event_stream.get_matching_events(source='environment') - assert len(events) == 1 - assert ( - isinstance(events[0], NullObservation) - and events[0].source == EventSource.ENVIRONMENT - ) - - # Test that source comparison works correctly with None source - null_source_event = NullObservation('test4') - event_stream.add_event(null_source_event, EventSource.AGENT) - event = event_stream.get_event(event_stream.get_latest_event_id()) - event._source = None # type: ignore - - # Update the serialized version - data = event_to_dict(event) - event_stream.file_store.write( - event_stream._get_filename_for_id(event.id, event_stream.user_id), - json.dumps(data), - ) - - # Verify that source comparison works correctly - assert EventFilter(source='agent').exclude(event) - assert EventFilter(source=None).include(event) - - # Filter by AGENT source again - events = event_stream.get_matching_events(source='agent') - assert len(events) == 2 # Should not include the None source event - - -def test_get_matching_events_pagination(temp_dir: str): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('abc', file_store) - - # Add 5 events - for i in range(5): - event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT) - - # Test limit - events = event_stream.get_matching_events(limit=3) - assert len(events) == 3 - - # Test start_id - events = event_stream.get_matching_events(start_id=2) - assert len(events) == 3 - assert isinstance(events[0], NullObservation) and events[0].content == 'test2' - - # Test combination of start_id and limit - events = event_stream.get_matching_events(start_id=1, limit=2) - assert len(events) == 2 - assert isinstance(events[0], NullObservation) and events[0].content == 'test1' - assert isinstance(events[1], NullObservation) and events[1].content == 'test2' - - -def test_get_matching_events_limit_validation(temp_dir: str): - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('abc', file_store) - - # Test limit less than 1 - with pytest.raises(ValueError, match='Limit must be between 1 and 100'): - event_stream.get_matching_events(limit=0) - - # Test limit greater than 100 - with pytest.raises(ValueError, match='Limit must be between 1 and 100'): - event_stream.get_matching_events(limit=101) - - # Test valid limits work - event_stream.add_event(NullObservation('test'), EventSource.AGENT) - events = event_stream.get_matching_events(limit=1) - assert len(events) == 1 - events = event_stream.get_matching_events(limit=100) - assert len(events) == 1 - - -def test_memory_usage_file_operations(temp_dir: str): - """Test memory usage during file operations in EventStream. - - This test verifies that memory usage during file operations is reasonable - and that memory is properly cleaned up after operations complete. - """ - - def get_memory_mb(): - """Get current memory usage in MB.""" - process = psutil.Process(os.getpid()) - return process.memory_info().rss / 1024 / 1024 - - # Create a test file with 100kb content - test_file = os.path.join(temp_dir, 'test_file.txt') - test_content = 'x' * (100 * 1024) # 100kb of data - with open(test_file, 'w') as f: - f.write(test_content) - - # Initialize FileStore and EventStream - file_store = get_file_store('local', temp_dir) - - # Record initial memory usage - gc.collect() - initial_memory = get_memory_mb() - max_memory_increase = 0 - - # Perform operations 20 times - for i in range(20): - event_stream = EventStream('test_session', file_store) - - # 1. Read file - read_action = FileReadAction( - path=test_file, - start=0, - end=-1, - thought='Reading file', - action=ActionType.READ, - impl_source=FileReadSource.DEFAULT, - ) - event_stream.add_event(read_action, EventSource.AGENT) - - read_obs = FileReadObservation( - path=test_file, impl_source=FileReadSource.DEFAULT, content=test_content - ) - event_stream.add_event(read_obs, EventSource.ENVIRONMENT) - - # 2. Write file - write_action = FileWriteAction( - path=test_file, - content=test_content, - start=0, - end=-1, - thought='Writing file', - action=ActionType.WRITE, - ) - event_stream.add_event(write_action, EventSource.AGENT) - - write_obs = FileWriteObservation(path=test_file, content=test_content) - event_stream.add_event(write_obs, EventSource.ENVIRONMENT) - - # 3. Edit file - edit_action = FileEditAction( - path=test_file, - content=test_content, - start=1, - end=-1, - thought='Editing file', - action=ActionType.EDIT, - impl_source=FileEditSource.LLM_BASED_EDIT, - ) - event_stream.add_event(edit_action, EventSource.AGENT) - - edit_obs = FileEditObservation( - path=test_file, - prev_exist=True, - old_content=test_content, - new_content=test_content, - impl_source=FileEditSource.LLM_BASED_EDIT, - content=test_content, - ) - event_stream.add_event(edit_obs, EventSource.ENVIRONMENT) - - # Close event stream and force garbage collection - event_stream.close() - gc.collect() - - # Check memory usage - current_memory = get_memory_mb() - memory_increase = current_memory - initial_memory - max_memory_increase = max(max_memory_increase, memory_increase) - - # Clean up - os.remove(test_file) - - # Memory increase should be reasonable (less than 50MB after 20 iterations) - assert max_memory_increase < 50, ( - f'Memory increase of {max_memory_increase:.1f}MB exceeds limit of 50MB' - ) - - -def test_cache_page_creation(temp_dir: str): - """Test that cache pages are created correctly when adding events.""" - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('cache_test', file_store) - - # Set a smaller cache size for testing - event_stream.cache_size = 5 - - # Add events up to the cache size threshold - for i in range(10): - event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT) - - # Check that a cache page was created after adding the 5th event - cache_filename = event_stream._get_filename_for_cache(0, 5) - - try: - # Verify the content of the cache page - cache_content = file_store.read(cache_filename) - cache_exists = True - except FileNotFoundError: - cache_exists = False - - assert cache_exists, f'Cache file {cache_filename} should exist' - - # If cache exists, verify its content - if cache_exists: - cache_data = json.loads(cache_content) - assert len(cache_data) == 5, 'Cache page should contain 5 events' - - # Verify each event in the cache - for i, event_data in enumerate(cache_data): - assert event_data['content'] == f'test{i}', ( - f"Event {i} content should be 'test{i}'" - ) - - -def test_cache_page_loading(temp_dir: str): - """Test that cache pages are loaded correctly when retrieving events.""" - file_store = get_file_store('local', temp_dir) - - # Create an event stream with a small cache size - event_stream = EventStream('cache_load_test', file_store) - event_stream.cache_size = 5 - - # Add enough events to create multiple cache pages - for i in range(15): - event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT) - - # Create a new event stream to force loading from cache - new_stream = EventStream('cache_load_test', file_store) - new_stream.cache_size = 5 - - # Get all events and verify they're correct - events = collect_events(new_stream) - - # Check that we have a reasonable number of events (may not be exactly 15 due to implementation details) - assert len(events) > 10, 'Should retrieve most of the events' - - # Verify the events we did get are in the correct order and format - for i, event in enumerate(events): - assert isinstance(event, NullObservation), ( - f'Event {i} should be a NullObservation' - ) - assert event.content == f'test{i}', f"Event {i} content should be 'test{i}'" - - -def test_cache_page_performance(temp_dir: str): - """Test that using cache pages improves performance when retrieving many events.""" - file_store = get_file_store('local', temp_dir) - - # Create an event stream with cache enabled - cached_stream = EventStream('perf_test_cached', file_store) - cached_stream.cache_size = 10 - - # Add a significant number of events to the cached stream - num_events = 50 - for i in range(num_events): - cached_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT) - - # Create a second event stream with a different session ID but same cache size - uncached_stream = EventStream('perf_test_uncached', file_store) - uncached_stream.cache_size = 10 - - # Add the same number of events to the uncached stream - for i in range(num_events): - uncached_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT) - - # Measure time to retrieve all events from cached stream - start_time = time.time() - cached_events = collect_events(cached_stream) - cached_time = time.time() - start_time - - # Measure time to retrieve all events from uncached stream - start_time = time.time() - uncached_events = collect_events(uncached_stream) - uncached_time = time.time() - start_time - - # Verify both streams returned a reasonable number of events - assert len(cached_events) > 40, 'Cached stream should return most of the events' - assert len(uncached_events) > 40, 'Uncached stream should return most of the events' - - # Log the performance difference - logger_message = ( - f'Cached time: {cached_time:.4f}s, Uncached time: {uncached_time:.4f}s' - ) - print(logger_message) - - # We're primarily checking functionality here, not strict performance metrics - # In real-world scenarios with many more events, the performance difference would be more significant. - - -def test_search_events_limit(temp_dir: str): - """Test that the search_events method correctly applies the limit parameter.""" - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('abc', file_store) - - # Add 10 events - for i in range(10): - event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT) - - # Test with no limit (should return all events) - events = list(event_stream.search_events()) - assert len(events) == 10 - - # Test with limit=5 (should return first 5 events) - events = list(event_stream.search_events(limit=5)) - assert len(events) == 5 - assert all(isinstance(e, NullObservation) for e in events) - assert [e.content for e in events] == ['test0', 'test1', 'test2', 'test3', 'test4'] - - # Test with limit=3 and start_id=5 (should return 3 events starting from ID 5) - events = list(event_stream.search_events(start_id=5, limit=3)) - assert len(events) == 3 - assert [e.content for e in events] == ['test5', 'test6', 'test7'] - - # Test with limit and reverse=True (should return events in reverse order) - events = list(event_stream.search_events(reverse=True, limit=4)) - assert len(events) == 4 - assert [e.content for e in events] == ['test9', 'test8', 'test7', 'test6'] - - # Test with limit and filter (should apply limit after filtering) - # Add some events with different content for filtering - event_stream.add_event(NullObservation('filter_me'), EventSource.AGENT) - event_stream.add_event(NullObservation('filter_me_too'), EventSource.AGENT) - - events = list( - event_stream.search_events(filter=EventFilter(query='filter'), limit=1) - ) - assert len(events) == 1 - assert events[0].content == 'filter_me' - - -def test_search_events_limit_with_complex_filters(temp_dir: str): - """Test the interaction between limit and various filter combinations in search_events.""" - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('abc', file_store) - - # Add events with different sources and types - event_stream.add_event(NullAction(), EventSource.AGENT) # id 0 - event_stream.add_event(NullObservation('test1'), EventSource.AGENT) # id 1 - event_stream.add_event(MessageAction(content='hello'), EventSource.USER) # id 2 - event_stream.add_event(NullObservation('test2'), EventSource.ENVIRONMENT) # id 3 - event_stream.add_event(NullAction(), EventSource.AGENT) # id 4 - event_stream.add_event(MessageAction(content='world'), EventSource.USER) # id 5 - event_stream.add_event(NullObservation('hello world'), EventSource.AGENT) # id 6 - - # Test limit with type filter - events = list( - event_stream.search_events( - filter=EventFilter(include_types=(NullAction,)), limit=1 - ) - ) - assert len(events) == 1 - assert isinstance(events[0], NullAction) - assert events[0].id == 0 - - # Test limit with source filter - events = list( - event_stream.search_events(filter=EventFilter(source='user'), limit=1) - ) - assert len(events) == 1 - assert events[0].source == EventSource.USER - assert events[0].id == 2 - - # Test limit with query filter - events = list( - event_stream.search_events(filter=EventFilter(query='hello'), limit=2) - ) - assert len(events) == 2 - assert [e.id for e in events] == [2, 6] - - # Test limit with combined filters - events = list( - event_stream.search_events( - filter=EventFilter(source='agent', include_types=(NullObservation,)), - limit=1, - ) - ) - assert len(events) == 1 - assert isinstance(events[0], NullObservation) - assert events[0].source == EventSource.AGENT - assert events[0].id == 1 - - # Test limit with reverse and filter - events = list( - event_stream.search_events( - filter=EventFilter(source='agent'), reverse=True, limit=2 - ) - ) - assert len(events) == 2 - assert [e.id for e in events] == [6, 4] - - -def test_search_events_limit_edge_cases(temp_dir: str): - """Test edge cases for the limit parameter in search_events.""" - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('abc', file_store) - - # Add some events - for i in range(5): - event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT) - - # Test with limit=None (should return all events) - events = list(event_stream.search_events(limit=None)) - assert len(events) == 5 - - # Test with limit larger than number of events - events = list(event_stream.search_events(limit=10)) - assert len(events) == 5 - - # Test with limit=0 (let's check actual behavior) - events = list(event_stream.search_events(limit=0)) - # If it returns all events, assert len(events) == 5 - # If it returns no events, assert len(events) == 0 - # Let's check the actual behavior - assert len(events) in [0, 5] - - # Test with negative limit (implementation returns only first event) - events = list(event_stream.search_events(limit=-1)) - assert len(events) == 1 - - # Test with empty result set and limit - events = list( - event_stream.search_events(filter=EventFilter(query='nonexistent'), limit=5) - ) - assert len(events) == 0 - - # Test with start_id beyond available events - events = list(event_stream.search_events(start_id=10, limit=5)) - assert len(events) == 0 - - -def test_callback_dictionary_modification(temp_dir: str): - """Test that the event stream can handle dictionary modification during iteration. - - This test verifies that the fix for the 'dictionary changed size during iteration' error works. - The test adds a callback that adds a new callback during iteration, which would cause an error - without the fix. - """ - file_store = get_file_store('local', temp_dir) - event_stream = EventStream('callback_test', file_store) - - # Track callback execution - callback_executed = [False, False, False] - - # Define a callback that will be added during iteration - def callback_added_during_iteration(event): - callback_executed[2] = True - - # First callback that will be called - def callback1(event): - callback_executed[0] = True - # This callback will add a new callback during iteration - # Without our fix, this would cause a "dictionary changed size during iteration" error - event_stream.subscribe( - EventStreamSubscriber.TEST, callback_added_during_iteration, 'callback3' - ) - - # Second callback that will be called - def callback2(event): - callback_executed[1] = True - - # Subscribe both callbacks - event_stream.subscribe(EventStreamSubscriber.TEST, callback1, 'callback1') - event_stream.subscribe(EventStreamSubscriber.TEST, callback2, 'callback2') - - # Add an event to trigger callbacks - event_stream.add_event(NullObservation('test'), EventSource.AGENT) - - # Give some time for the callbacks to execute - time.sleep(0.5) - - # Verify that the first two callbacks were executed - assert callback_executed[0] is True, 'First callback should have been executed' - assert callback_executed[1] is True, 'Second callback should have been executed' - - # The third callback should not have been executed for this event - # since it was added during iteration - assert callback_executed[2] is False, ( - 'Third callback should not have been executed for this event' - ) - - # Add another event to trigger all callbacks including the newly added one - callback_executed = [False, False, False] # Reset execution tracking - event_stream.add_event(NullObservation('test2'), EventSource.AGENT) - - # Give some time for the callbacks to execute - time.sleep(0.5) - - # Now all three callbacks should have been executed - assert callback_executed[0] is True, 'First callback should have been executed' - assert callback_executed[1] is True, 'Second callback should have been executed' - assert callback_executed[2] is True, 'Third callback should have been executed' - - # Clean up - event_stream.close() - - -def test_cache_page_partial_retrieval(temp_dir: str): - """Test retrieving events with start_id and end_id parameters using the cache.""" - file_store = get_file_store('local', temp_dir) - - # Create an event stream with a small cache size - event_stream = EventStream('partial_test', file_store) - event_stream.cache_size = 5 - - # Add events - for i in range(20): - event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT) - - # Test retrieving a subset of events that spans multiple cache pages - events = list(event_stream.get_events(start_id=3, end_id=12)) - - # Verify we got a reasonable number of events - assert len(events) >= 8, 'Should retrieve most events in the range' - - # Verify the events we did get are in the correct order - for i, event in enumerate(events): - expected_content = f'test{i + 3}' - assert event.content == expected_content, ( - f"Event {i} content should be '{expected_content}'" - ) - - # Test retrieving events in reverse order - reverse_events = list(event_stream.get_events(start_id=3, end_id=12, reverse=True)) - - # Verify we got a reasonable number of events in reverse - assert len(reverse_events) >= 8, 'Should retrieve most events in reverse' - - # Check the first few events to ensure they're in reverse order - if len(reverse_events) >= 3: - assert reverse_events[0].content.startswith('test1'), ( - 'First reverse event should be near the end of the range' - ) - assert int(reverse_events[0].content[4:]) > int( - reverse_events[1].content[4:] - ), 'Events should be in descending order' - - -def test_cache_page_with_missing_events(temp_dir: str): - """Test cache behavior when some events are missing.""" - file_store = get_file_store('local', temp_dir) - - # Create an event stream with a small cache size - event_stream = EventStream('missing_test', file_store) - event_stream.cache_size = 5 - - # Add events - for i in range(10): - event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT) - - # Create a new event stream to force reloading events - new_stream = EventStream('missing_test', file_store) - new_stream.cache_size = 5 - - # Get the initial count of events - initial_events = list(new_stream.get_events()) - initial_count = len(initial_events) - - # Delete an event file to simulate a missing event - # Choose an ID that's not at the beginning or end - missing_id = 5 - missing_filename = new_stream._get_filename_for_id(missing_id, new_stream.user_id) - try: - file_store.delete(missing_filename) - - # Create another stream to force reloading after deletion - reload_stream = EventStream('missing_test', file_store) - reload_stream.cache_size = 5 - - # Retrieve events after deletion - events_after_deletion = list(reload_stream.get_events()) - - # We should have fewer events than before - assert len(events_after_deletion) <= initial_count, ( - 'Should have fewer or equal events after deletion' - ) - - # Test that we can still retrieve events successfully - assert len(events_after_deletion) > 0, 'Should still retrieve some events' - - except Exception as e: - # If the delete operation fails, we'll just verify that the basic functionality works - print(f'Note: Could not delete file {missing_filename}: {e}') - assert len(initial_events) > 0, 'Should retrieve events successfully' - - -def test_secrets_replaced_in_content(temp_dir: str): - """Test that secrets are properly replaced in event content.""" - file_store = get_file_store('local', temp_dir) - stream = EventStream('test_session', file_store) - - # Set up a secret - stream.set_secrets({'api_key': 'secret123'}) - - # Create an event with the secret in the command - action = CmdRunAction( - command='curl -H "Authorization: Bearer secret123" https://api.example.com' - ) - action._timestamp = datetime.now().isoformat() - - # Convert to dict and apply secret replacement - data = event_to_dict(action) - data_with_secrets_replaced = stream._replace_secrets(data) - - # The secret should be replaced in the command - assert '' in data_with_secrets_replaced['args']['command'] - assert 'secret123' not in data_with_secrets_replaced['args']['command'] - - -def test_timestamp_not_affected_by_secret_replacement(temp_dir: str): - """Test that timestamps are not corrupted by secret replacement.""" - file_store = get_file_store('local', temp_dir) - stream = EventStream('test_session', file_store) - - # Set up a secret that appears in the current date (e.g., "18" for 2025-07-18) - stream.set_secrets({'test_secret': '18'}) - - # Create an event with a timestamp - action = CmdRunAction(command='echo "hello world"') - action._timestamp = '2025-07-18T17:01:36.799608' # Contains "18" - - # Convert to dict and apply secret replacement - data = event_to_dict(action) - original_timestamp = data['timestamp'] - data_with_secrets_replaced = stream._replace_secrets(data) - - # The timestamp should NOT be affected by secret replacement - assert data_with_secrets_replaced['timestamp'] == original_timestamp - assert '' not in data_with_secrets_replaced['timestamp'] - assert '18' in data_with_secrets_replaced['timestamp'] # Original value preserved - - -def test_protected_fields_not_affected_by_secret_replacement(temp_dir: str): - """Test that protected system fields are not affected by secret replacement.""" - file_store = get_file_store('local', temp_dir) - stream = EventStream('test_session', file_store) - - # Set up secrets that might appear in system fields - stream.set_secrets( - { - 'secret1': '123', # Could appear in ID - 'secret2': 'user', # Could appear in source - 'secret3': 'run', # Could appear in action/observation - 'secret4': 'Running', # Could appear in message - } - ) - - # Create test data with protected fields - data = { - 'id': 123, - 'timestamp': '2025-07-18T17:01:36.799608', - 'source': 'user', - 'cause': 123, - 'action': 'run', - 'observation': 'run', - 'message': 'Running command: echo hello', - 'content': 'This contains secret1: 123 and secret2: user and secret3: run', - } - - data_with_secrets_replaced = stream._replace_secrets(data) - - # Protected fields should not be affected at top level - assert data_with_secrets_replaced['id'] == 123 - assert data_with_secrets_replaced['timestamp'] == '2025-07-18T17:01:36.799608' - assert data_with_secrets_replaced['source'] == 'user' - assert data_with_secrets_replaced['cause'] == 123 - assert data_with_secrets_replaced['action'] == 'run' - assert data_with_secrets_replaced['observation'] == 'run' - assert data_with_secrets_replaced['message'] == 'Running command: echo hello' - - # But non-protected fields should have secrets replaced - assert '' in data_with_secrets_replaced['content'] - assert '123' not in data_with_secrets_replaced['content'] - assert 'user' not in data_with_secrets_replaced['content'] - # Note: 'run' should still be replaced in content since it's not a protected field - - -def test_nested_dict_secret_replacement(temp_dir: str): - """Test that secrets are replaced in nested dictionaries while preserving protected fields.""" - file_store = get_file_store('local', temp_dir) - stream = EventStream('test_session', file_store) - - stream.set_secrets({'secret': 'password123'}) - - # Create nested data structure - data = { - 'timestamp': '2025-07-18T17:01:36.799608', - 'args': { - 'command': 'login --password password123', - 'env': { - 'SECRET_KEY': 'password123', - 'timestamp': 'password123_timestamp', # This should be replaced since it's not top-level - }, - }, - } - - data_with_secrets_replaced = stream._replace_secrets(data) - - # Top-level timestamp should be protected - assert data_with_secrets_replaced['timestamp'] == '2025-07-18T17:01:36.799608' - - # Nested secrets should be replaced - assert '' in data_with_secrets_replaced['args']['command'] - assert data_with_secrets_replaced['args']['env']['SECRET_KEY'] == '' - assert '' in data_with_secrets_replaced['args']['env']['timestamp'] - - # Original secret should not appear in nested content - assert 'password123' not in data_with_secrets_replaced['args']['command'] - assert 'password123' not in data_with_secrets_replaced['args']['env']['SECRET_KEY'] - assert 'password123' not in data_with_secrets_replaced['args']['env']['timestamp'] diff --git a/tests/unit/events/test_file_edit_observation.py b/tests/unit/events/test_file_edit_observation.py deleted file mode 100644 index 4c8cfef780..0000000000 --- a/tests/unit/events/test_file_edit_observation.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Tests for FileEditObservation class.""" - -from openhands.events.event import FileEditSource -from openhands.events.observation.files import FileEditObservation - - -def test_file_edit_observation_basic(): - """Test basic properties of FileEditObservation.""" - obs = FileEditObservation( - path='/test/file.txt', - prev_exist=True, - old_content='Hello\nWorld\n', - new_content='Hello\nNew World\n', - impl_source=FileEditSource.LLM_BASED_EDIT, - content='Hello\nWorld\n', # Initial content is old_content - ) - - assert obs.path == '/test/file.txt' - assert obs.prev_exist is True - assert obs.old_content == 'Hello\nWorld\n' - assert obs.new_content == 'Hello\nNew World\n' - assert obs.impl_source == FileEditSource.LLM_BASED_EDIT - assert obs.message == 'I edited the file /test/file.txt.' - - -def test_file_edit_observation_diff_cache(): - """Test that diff visualization is cached.""" - obs = FileEditObservation( - path='/test/file.txt', - prev_exist=True, - old_content='Hello\nWorld\n', - new_content='Hello\nNew World\n', - impl_source=FileEditSource.LLM_BASED_EDIT, - content='Hello\nWorld\n', # Initial content is old_content - ) - - # First call should compute diff - diff1 = obs.visualize_diff() - assert obs._diff_cache is not None - - # Second call should use cache - diff2 = obs.visualize_diff() - assert diff1 == diff2 - - -def test_file_edit_observation_no_changes(): - """Test behavior when content hasn't changed.""" - content = 'Hello\nWorld\n' - obs = FileEditObservation( - path='/test/file.txt', - prev_exist=True, - old_content=content, - new_content=content, - impl_source=FileEditSource.LLM_BASED_EDIT, - content=content, # Initial content is old_content - ) - - diff = obs.visualize_diff() - assert '(no changes detected' in diff - - -def test_file_edit_observation_get_edit_groups(): - """Test the get_edit_groups method.""" - obs = FileEditObservation( - path='/test/file.txt', - prev_exist=True, - old_content='Line 1\nLine 2\nLine 3\nLine 4\n', - new_content='Line 1\nNew Line 2\nLine 3\nNew Line 4\n', - impl_source=FileEditSource.LLM_BASED_EDIT, - content='Line 1\nLine 2\nLine 3\nLine 4\n', # Initial content is old_content - ) - - groups = obs.get_edit_groups(n_context_lines=1) - assert len(groups) > 0 - - # Check structure of edit groups - for group in groups: - assert 'before_edits' in group - assert 'after_edits' in group - assert isinstance(group['before_edits'], list) - assert isinstance(group['after_edits'], list) - - # Verify line numbers and content - first_group = groups[0] - assert any('Line 2' in line for line in first_group['before_edits']) - assert any('New Line 2' in line for line in first_group['after_edits']) - - -def test_file_edit_observation_new_file(): - """Test behavior when editing a new file.""" - obs = FileEditObservation( - path='/test/new_file.txt', - prev_exist=False, - old_content='', - new_content='Hello\nWorld\n', - impl_source=FileEditSource.LLM_BASED_EDIT, - content='', # Initial content is old_content (empty for new file) - ) - - assert obs.prev_exist is False - assert obs.old_content == '' - assert ( - str(obs) - == '[New file /test/new_file.txt is created with the provided content.]\n' - ) - - # Test that trying to visualize diff for a new file works - diff = obs.visualize_diff() - assert diff is not None - - -def test_file_edit_observation_context_lines(): - """Test diff visualization with different context line settings.""" - obs = FileEditObservation( - path='/test/file.txt', - prev_exist=True, - old_content='Line 1\nLine 2\nLine 3\nLine 4\nLine 5\n', - new_content='Line 1\nNew Line 2\nLine 3\nNew Line 4\nLine 5\n', - impl_source=FileEditSource.LLM_BASED_EDIT, - content='Line 1\nLine 2\nLine 3\nLine 4\nLine 5\n', # Initial content is old_content - ) - - # Test with 0 context lines - groups_0 = obs.get_edit_groups(n_context_lines=0) - # Test with 2 context lines - groups_2 = obs.get_edit_groups(n_context_lines=2) - - # More context should mean more lines in the groups - total_lines_0 = sum( - len(g['before_edits']) + len(g['after_edits']) for g in groups_0 - ) - total_lines_2 = sum( - len(g['before_edits']) + len(g['after_edits']) for g in groups_2 - ) - assert total_lines_2 > total_lines_0 diff --git a/tests/unit/events/test_mcp_action_observation.py b/tests/unit/events/test_mcp_action_observation.py deleted file mode 100644 index e6da007e9a..0000000000 --- a/tests/unit/events/test_mcp_action_observation.py +++ /dev/null @@ -1,146 +0,0 @@ -import json - -from openhands.core.schema import ActionType, ObservationType -from openhands.events.action.action import ActionSecurityRisk -from openhands.events.action.mcp import MCPAction -from openhands.events.observation.mcp import MCPObservation - - -def test_mcp_action_creation(): - """Test creating an MCPAction.""" - action = MCPAction(name='test_tool', arguments={'arg1': 'value1', 'arg2': 42}) - - assert action.name == 'test_tool' - assert action.arguments == {'arg1': 'value1', 'arg2': 42} - assert action.action == ActionType.MCP - assert action.thought == '' - assert action.runnable is True - assert action.security_risk == ActionSecurityRisk.UNKNOWN - - -def test_mcp_action_with_thought(): - """Test creating an MCPAction with a thought.""" - action = MCPAction( - name='test_tool', - arguments={'arg1': 'value1', 'arg2': 42}, - thought='This is a test thought', - ) - - assert action.name == 'test_tool' - assert action.arguments == {'arg1': 'value1', 'arg2': 42} - assert action.thought == 'This is a test thought' - - -def test_mcp_action_message(): - """Test the message property of MCPAction.""" - action = MCPAction(name='test_tool', arguments={'arg1': 'value1', 'arg2': 42}) - - message = action.message - assert 'test_tool' in message - assert 'arg1' in message - assert 'value1' in message - assert '42' in message - - -def test_mcp_action_str_representation(): - """Test the string representation of MCPAction.""" - action = MCPAction( - name='test_tool', - arguments={'arg1': 'value1', 'arg2': 42}, - thought='This is a test thought', - ) - - str_repr = str(action) - assert 'MCPAction' in str_repr - assert 'THOUGHT: This is a test thought' in str_repr - assert 'NAME: test_tool' in str_repr - assert 'ARGUMENTS:' in str_repr - assert 'arg1' in str_repr - assert 'value1' in str_repr - assert '42' in str_repr - - -def test_mcp_observation_creation(): - """Test creating an MCPObservation.""" - observation = MCPObservation( - content=json.dumps({'result': 'success', 'data': 'test data'}) - ) - - assert observation.content == json.dumps({'result': 'success', 'data': 'test data'}) - assert observation.observation == ObservationType.MCP - - -def test_mcp_observation_message(): - """Test the message property of MCPObservation.""" - observation = MCPObservation( - content=json.dumps({'result': 'success', 'data': 'test data'}) - ) - - message = observation.message - assert message == json.dumps({'result': 'success', 'data': 'test data'}) - assert 'result' in message - assert 'success' in message - assert 'data' in message - assert 'test data' in message - - -def test_mcp_action_with_complex_arguments(): - """Test MCPAction with complex nested arguments.""" - complex_args = { - 'simple_arg': 'value', - 'number_arg': 42, - 'boolean_arg': True, - 'nested_arg': {'inner_key': 'inner_value', 'inner_list': [1, 2, 3]}, - 'list_arg': ['a', 'b', 'c'], - } - - action = MCPAction(name='complex_tool', arguments=complex_args) - - assert action.name == 'complex_tool' - assert action.arguments == complex_args - assert action.arguments['nested_arg']['inner_key'] == 'inner_value' - assert action.arguments['list_arg'] == ['a', 'b', 'c'] - - # Check that the message contains the complex arguments - message = action.message - assert 'complex_tool' in message - assert 'nested_arg' in message - assert 'inner_key' in message - assert 'inner_value' in message - - -def test_mcp_observation_with_arguments(): - """Test MCPObservation with arguments.""" - complex_args = { - 'simple_arg': 'value', - 'number_arg': 42, - 'boolean_arg': True, - 'nested_arg': {'inner_key': 'inner_value', 'inner_list': [1, 2, 3]}, - 'list_arg': ['a', 'b', 'c'], - } - - observation = MCPObservation( - content=json.dumps({'result': 'success', 'data': 'test data'}), - name='test_tool', - arguments=complex_args, - ) - - assert observation.content == json.dumps({'result': 'success', 'data': 'test data'}) - assert observation.observation == ObservationType.MCP - assert observation.name == 'test_tool' - assert observation.arguments == complex_args - assert observation.arguments['nested_arg']['inner_key'] == 'inner_value' - assert observation.arguments['list_arg'] == ['a', 'b', 'c'] - - # Test serialization - from openhands.events.serialization import event_to_dict - - serialized = event_to_dict(observation) - - assert serialized['observation'] == ObservationType.MCP - assert serialized['content'] == json.dumps( - {'result': 'success', 'data': 'test data'} - ) - assert serialized['extras']['name'] == 'test_tool' - assert serialized['extras']['arguments'] == complex_args - assert serialized['extras']['arguments']['nested_arg']['inner_key'] == 'inner_value' diff --git a/tests/unit/events/test_nested_event_store.py b/tests/unit/events/test_nested_event_store.py deleted file mode 100644 index af9b6ea7f0..0000000000 --- a/tests/unit/events/test_nested_event_store.py +++ /dev/null @@ -1,437 +0,0 @@ -"""Unit tests for the NestedEventStore class. - -These tests focus on the search_events method, which retrieves events from a remote API -and applies filtering based on various criteria. -""" - -from typing import Any -from unittest.mock import MagicMock, patch - -import pytest - -from openhands.events.action import MessageAction -from openhands.events.event import EventSource -from openhands.events.event_filter import EventFilter -from openhands.events.nested_event_store import NestedEventStore - - -def create_mock_event( - id: int, content: str, source: str = 'user', hidden: bool = False -) -> dict[str, Any]: - """Create a properly formatted mock event dictionary.""" - event_dict = { - 'id': id, - 'action': 'message', # This is required for the event_from_dict function - 'args': { - 'content': content, - }, - 'source': source, - } - - # Add hidden as a property that will be set on the event after deserialization - if hidden: - event_dict['hidden'] = True - - return event_dict - - -def create_mock_response( - events: list[dict[str, Any]], has_more: bool = False -) -> MagicMock: - """Helper function to create a mock HTTP response.""" - mock_response = MagicMock() - mock_response.json.return_value = {'events': events, 'has_more': has_more} - return mock_response - - -class TestNestedEventStore: - """Tests for the NestedEventStore class.""" - - @pytest.fixture - def event_store(self): - """Create a NestedEventStore instance for testing.""" - return NestedEventStore( - base_url='http://test-api.example.com', - sid='test-session', - user_id='test-user', - session_api_key='test-api-key', - ) - - @patch('httpx.get') - def test_search_events_basic(self, mock_get, event_store): - """Test basic event retrieval without filters.""" - # Setup mock response with two events - mock_events = [ - create_mock_event(1, 'Hello', 'user'), - create_mock_event(2, 'World', 'agent'), - ] - mock_get.return_value = create_mock_response(mock_events) - - # Call the method - events = list(event_store.search_events()) - - # Verify the results - assert len(events) == 2 - assert events[0].id == 1 - assert events[1].id == 2 - - # Verify the API call - mock_get.assert_called_once_with( - 'http://test-api.example.com/events?start_id=0&reverse=False', - headers={'X-Session-API-Key': 'test-api-key'}, - ) - - @patch('httpx.get') - def test_search_events_with_limit(self, mock_get, event_store): - """Test event retrieval with a limit.""" - # Setup mock response - mock_events = [ - create_mock_event(1, 'Hello', 'user'), - create_mock_event(2, 'World', 'agent'), - ] - mock_get.return_value = create_mock_response(mock_events) - - # Call the method with a limit - events = list(event_store.search_events(limit=1)) - - # Verify the results - assert len(events) == 1 - assert events[0].id == 1 - - # Verify the API call includes the limit parameter - mock_get.assert_called_once_with( - 'http://test-api.example.com/events?start_id=0&reverse=False&limit=1', - headers={'X-Session-API-Key': 'test-api-key'}, - ) - - @patch('httpx.get') - def test_search_events_with_start_id(self, mock_get, event_store): - """Test event retrieval with a specific start_id.""" - # Setup mock response - mock_events = [ - create_mock_event(5, 'Hello', 'user'), - create_mock_event(6, 'World', 'agent'), - ] - mock_get.return_value = create_mock_response(mock_events) - - # Call the method with a start_id - events = list(event_store.search_events(start_id=5)) - - # Verify the results - assert len(events) == 2 - assert events[0].id == 5 - assert events[1].id == 6 - - # Verify the API call includes the correct start_id - mock_get.assert_called_once_with( - 'http://test-api.example.com/events?start_id=5&reverse=False', - headers={'X-Session-API-Key': 'test-api-key'}, - ) - - @patch('httpx.get') - def test_search_events_reverse_order(self, mock_get, event_store): - """Test event retrieval in reverse order.""" - # Setup mock response - mock_events = [ - create_mock_event(3, 'World', 'agent'), - create_mock_event(2, 'Hello', 'user'), - ] - mock_get.return_value = create_mock_response(mock_events) - - # Call the method with reverse=True - events = list(event_store.search_events(reverse=True)) - - # Verify the results - assert len(events) == 2 - assert events[0].id == 3 - assert events[1].id == 2 - - # Verify the API call includes reverse=True - mock_get.assert_called_once_with( - 'http://test-api.example.com/events?start_id=0&reverse=True', - headers={'X-Session-API-Key': 'test-api-key'}, - ) - - @patch('httpx.get') - def test_search_events_with_end_id(self, mock_get, event_store): - """Test event retrieval with a specific end_id.""" - # Setup mock response - mock_events = [ - create_mock_event(1, 'Hello', 'user'), - create_mock_event(2, 'World', 'agent'), - create_mock_event(3, 'End', 'user'), - ] - mock_get.return_value = create_mock_response(mock_events) - - # Call the method with an end_id - events = list(event_store.search_events(end_id=3)) - - # Verify the results - assert len(events) == 3 - assert events[0].id == 1 - assert events[1].id == 2 - assert events[2].id == 3 - - # Verify the API call - mock_get.assert_called_once_with( - 'http://test-api.example.com/events?start_id=0&reverse=False', - headers={'X-Session-API-Key': 'test-api-key'}, - ) - - @patch('httpx.get') - @patch('openhands.events.event_filter.EventFilter.exclude') - def test_search_events_with_filter(self, mock_exclude, mock_get, event_store): - """Test event retrieval with an EventFilter.""" - # Setup mock response with mixed events - mock_events = [ - create_mock_event(1, 'Hello', 'user'), - create_mock_event(2, 'World', 'agent'), - create_mock_event(3, 'Hidden', 'user'), - ] - mock_get.return_value = create_mock_response(mock_events) - - # Configure the mock to exclude the third event (simulating exclude_hidden=True) - # Return True for the third event (to exclude it) and False for others - mock_exclude.side_effect = [False, False, True] - - # Create a filter (the actual implementation doesn't matter since we're mocking exclude) - event_filter = EventFilter() - - # Call the method with the filter - events = list(event_store.search_events(filter=event_filter)) - - # Verify the results (should exclude the third event) - assert len(events) == 2 - assert events[0].id == 1 - assert events[1].id == 2 - - # Verify the API call - mock_get.assert_called_once_with( - 'http://test-api.example.com/events?start_id=0&reverse=False', - headers={'X-Session-API-Key': 'test-api-key'}, - ) - - @patch('httpx.get') - def test_search_events_with_source_filter(self, mock_get, event_store): - """Test event retrieval with a source filter.""" - # Setup mock response with mixed sources - mock_events = [ - create_mock_event(1, 'Hello', 'user'), - create_mock_event(2, 'World', 'agent'), - create_mock_event(3, 'Another', 'user'), - ] - mock_get.return_value = create_mock_response(mock_events) - - # Create a filter to include only user events - event_filter = EventFilter(source='user') - - # Call the method with the filter - events = list(event_store.search_events(filter=event_filter)) - - # Verify the results (should only include user events) - assert len(events) == 2 - assert events[0].id == 1 - assert events[0].source == EventSource.USER - assert events[1].id == 3 - assert events[1].source == EventSource.USER - - # Verify the API call - mock_get.assert_called_once_with( - 'http://test-api.example.com/events?start_id=0&reverse=False', - headers={'X-Session-API-Key': 'test-api-key'}, - ) - - @patch('httpx.get') - def test_search_events_with_type_filter(self, mock_get, event_store): - """Test event retrieval with a type filter.""" - # Setup mock response with different event types - mock_events = [ - create_mock_event(1, 'Hello', 'user'), - # Create a different type of event (read) - { - 'id': 2, - 'action': 'read', # Using the correct ActionType.READ value - 'args': { - 'path': '/test/file.txt', - }, - 'source': 'agent', - }, - create_mock_event(3, 'Another', 'user'), - ] - mock_get.return_value = create_mock_response(mock_events) - - # Create a filter to include only MessageAction events - event_filter = EventFilter(include_types=(MessageAction,)) - - # Call the method with the filter - events = list(event_store.search_events(filter=event_filter)) - - # Verify the results (should only include MessageAction events) - assert len(events) == 2 - assert events[0].id == 1 - assert events[1].id == 3 - - # Verify the API call - mock_get.assert_called_once_with( - 'http://test-api.example.com/events?start_id=0&reverse=False', - headers={'X-Session-API-Key': 'test-api-key'}, - ) - - @patch('httpx.get') - def test_search_events_pagination(self, mock_get, event_store): - """Test event retrieval with pagination (has_more=True).""" - # Setup first page response - first_page_events = [ - create_mock_event(1, 'Hello', 'user'), - create_mock_event(2, 'World', 'agent'), - ] - first_response = create_mock_response(first_page_events, has_more=True) - - # Setup second page response - second_page_events = [ - create_mock_event(3, 'More', 'user'), - create_mock_event(4, 'Data', 'agent'), - ] - second_response = create_mock_response(second_page_events, has_more=False) - - # Configure mock to return different responses on consecutive calls - mock_get.side_effect = [first_response, second_response] - - # Call the method - events = list(event_store.search_events()) - - # Verify the results (should include all events from both pages) - assert len(events) == 4 - assert events[0].id == 1 - assert events[1].id == 2 - assert events[2].id == 3 - assert events[3].id == 4 - - # Verify the API calls - assert mock_get.call_count == 2 - # First call with start_id=0 - mock_get.assert_any_call( - 'http://test-api.example.com/events?start_id=0&reverse=False', - headers={'X-Session-API-Key': 'test-api-key'}, - ) - # Second call with start_id=3 (after processing events with IDs 1 and 2) - mock_get.assert_any_call( - 'http://test-api.example.com/events?start_id=3&reverse=False', - headers={'X-Session-API-Key': 'test-api-key'}, - ) - - @patch('httpx.get') - def test_search_events_no_session_api_key(self, mock_get): - """Test event retrieval without a session API key.""" - # Create event store without session_api_key - event_store = NestedEventStore( - base_url='http://test-api.example.com', - sid='test-session', - user_id='test-user', - ) - - # Setup mock response - mock_events = [create_mock_event(1, 'Hello', 'user')] - mock_get.return_value = create_mock_response(mock_events) - - # Call the method - events = list(event_store.search_events()) - - # Verify the results - assert len(events) == 1 - - # Verify the API call has no headers - mock_get.assert_called_once_with( - 'http://test-api.example.com/events?start_id=0&reverse=False', headers={} - ) - - @patch('httpx.get') - def test_search_events_with_query_filter(self, mock_get, event_store): - """Test event retrieval with a text query filter.""" - # Setup mock response with different content - mock_events = [ - create_mock_event(1, 'Hello world', 'user'), - create_mock_event(2, 'Python is great', 'agent'), - create_mock_event(3, 'Hello Python', 'user'), - ] - mock_get.return_value = create_mock_response(mock_events) - - # Create a filter to search for 'Python' - event_filter = EventFilter(query='Python') - - # Call the method with the filter - events = list(event_store.search_events(filter=event_filter)) - - # Verify the results (should only include events with 'Python' in content) - assert len(events) == 2 - assert events[0].id == 2 - assert events[1].id == 3 - - # Verify the API call - mock_get.assert_called_once_with( - 'http://test-api.example.com/events?start_id=0&reverse=False', - headers={'X-Session-API-Key': 'test-api-key'}, - ) - - @patch('httpx.get') - def test_search_events_reverse_pagination_multiple_pages( - self, mock_get, event_store - ): - """Ensure reverse pagination works across multiple server pages. - - We emulate the remote /events endpoint by using an in-memory EventStream as the - backing store and having httpx.get return paginated JSON responses derived from it. - """ - from urllib.parse import parse_qs, urlparse - - from openhands.events.event import EventSource - from openhands.events.observation import NullObservation - from openhands.events.serialization.event import event_to_dict - from openhands.events.stream import EventStream - from openhands.storage.memory import InMemoryFileStore - - # Build a fake server-side store with many events so that server pagination kicks in - fs = InMemoryFileStore() - server_stream = EventStream('test-session', fs, user_id='test-user') - total_events = 50 - for i in range(total_events): - server_stream.add_event(NullObservation(f'e{i}'), EventSource.AGENT) - - def server_side_get(url: str, headers: dict | None = None): - # Parse query params like the FastAPI layer would receive - parsed = urlparse(url) - qs = parse_qs(parsed.query) - start_id = int(qs.get('start_id', ['0'])[0]) - reverse = qs.get('reverse', ['False'])[0] == 'True' - end_id = int(qs['end_id'][0]) if 'end_id' in qs else None - limit = int(qs.get('limit', ['20'])[0]) # server default = 20 - - # Emulate server route logic: request limit+1 to compute has_more - events = list( - server_stream.search_events( - start_id=start_id, end_id=end_id, reverse=reverse, limit=limit + 1 - ) - ) - has_more = len(events) > limit - if has_more: - events = events[:limit] - - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - 'events': [event_to_dict(e) for e in events], - 'has_more': has_more, - } - return mock_response - - mock_get.side_effect = server_side_get - - # Execute the nested search in reverse without a client-side limit - results = list(event_store.search_events(reverse=True)) - - # Verify we received all events in descending order - assert len(results) == total_events - assert [e.id for e in results] == list(range(total_events - 1, -1, -1)) - - # Ensure multiple HTTP calls were made due to pagination - assert mock_get.call_count >= 2 diff --git a/tests/unit/events/test_observation_serialization.py b/tests/unit/events/test_observation_serialization.py deleted file mode 100644 index d681a286cd..0000000000 --- a/tests/unit/events/test_observation_serialization.py +++ /dev/null @@ -1,496 +0,0 @@ -from openhands.core.schema.observation import ObservationType -from openhands.events.action.files import FileEditSource -from openhands.events.observation import ( - CmdOutputMetadata, - CmdOutputObservation, - FileEditObservation, - Observation, - RecallObservation, -) -from openhands.events.observation.agent import MicroagentKnowledge -from openhands.events.observation.commands import MAX_CMD_OUTPUT_SIZE -from openhands.events.recall_type import RecallType -from openhands.events.serialization import ( - event_from_dict, - event_to_dict, - event_to_trajectory, -) -from openhands.events.serialization.observation import observation_from_dict - - -def serialization_deserialization( - original_observation_dict, cls, max_message_chars: int = 10000 -): - observation_instance = event_from_dict(original_observation_dict) - assert isinstance(observation_instance, Observation), ( - 'The observation instance should be an instance of Observation.' - ) - assert isinstance(observation_instance, cls), ( - f'The observation instance should be an instance of {cls}.' - ) - serialized_observation_dict = event_to_dict(observation_instance) - serialized_observation_trajectory = event_to_trajectory(observation_instance) - assert serialized_observation_dict == original_observation_dict, ( - 'The serialized observation should match the original observation dict.' - ) - assert serialized_observation_trajectory == original_observation_dict, ( - 'The serialized observation trajectory should match the original observation dict.' - ) - - -# Additional tests for various observation subclasses can be included here -def test_observation_event_props_serialization_deserialization(): - original_observation_dict = { - 'id': 42, - 'source': 'agent', - 'timestamp': '2021-08-01T12:00:00', - 'observation': 'run', - 'message': 'Command `ls -l` executed with exit code 0.', - 'extras': { - 'command': 'ls -l', - 'hidden': False, - 'metadata': { - 'exit_code': 0, - 'hostname': None, - 'pid': -1, - 'prefix': '', - 'py_interpreter_path': None, - 'suffix': '', - 'username': None, - 'working_dir': None, - }, - }, - 'content': 'foo.txt', - 'success': True, - } - serialization_deserialization(original_observation_dict, CmdOutputObservation) - - -def test_command_output_observation_serialization_deserialization(): - original_observation_dict = { - 'observation': 'run', - 'extras': { - 'command': 'ls -l', - 'hidden': False, - 'metadata': { - 'exit_code': 0, - 'hostname': None, - 'pid': -1, - 'prefix': '', - 'py_interpreter_path': None, - 'suffix': '', - 'username': None, - 'working_dir': None, - }, - }, - 'message': 'Command `ls -l` executed with exit code 0.', - 'content': 'foo.txt', - 'success': True, - } - serialization_deserialization(original_observation_dict, CmdOutputObservation) - - -def test_success_field_serialization(): - # Test success=True - obs = CmdOutputObservation( - content='Command succeeded', - command='ls -l', - metadata=CmdOutputMetadata( - exit_code=0, - ), - ) - serialized = event_to_dict(obs) - assert serialized['success'] is True - - # Test success=False - obs = CmdOutputObservation( - content='No such file or directory', - command='ls -l', - metadata=CmdOutputMetadata( - exit_code=1, - ), - ) - serialized = event_to_dict(obs) - assert serialized['success'] is False - - -def test_cmd_output_truncation(): - """Test that large command outputs are truncated during initialization.""" - # Create a large content string that exceeds MAX_CMD_OUTPUT_SIZE (50000 characters) - large_content = 'a' * 60000 # 60k characters - - # Create a CmdOutputObservation with the large content - obs = CmdOutputObservation( - content=large_content, - command='ls -R', - metadata=CmdOutputMetadata( - exit_code=0, - ), - ) - - # Verify the content was truncated - assert len(obs.content) < 60000 - - # The truncated content might be slightly larger than MAX_CMD_OUTPUT_SIZE - # due to the added truncation message - truncation_msg = '[... Observation truncated due to length ...]' - assert truncation_msg in obs.content - - # The truncation algorithm might add a few extra characters due to the truncation message - # We'll allow a small margin (1% of MAX_CMD_OUTPUT_SIZE) for the total content length - margin = int(MAX_CMD_OUTPUT_SIZE * 0.01) # 1% margin - assert len(obs.content) <= MAX_CMD_OUTPUT_SIZE + margin - - # Verify the beginning and end of the content are preserved - half_size = MAX_CMD_OUTPUT_SIZE // 2 - assert obs.content.startswith('a' * half_size) - assert obs.content.endswith('a' * half_size) - - -def test_cmd_output_no_truncation(): - """Test that small command outputs are not truncated.""" - # Create a content string that doesn't exceed MAX_CMD_OUTPUT_SIZE (50000 characters) - # We use a much smaller value for testing efficiency - small_content = 'a' * 1000 # 1k characters - - # Create a CmdOutputObservation with the small content - obs = CmdOutputObservation( - content=small_content, - command='ls', - metadata=CmdOutputMetadata( - exit_code=0, - ), - ) - - # Verify the content was not truncated - assert len(obs.content) == 1000 - assert obs.content == small_content - - -def test_legacy_serialization(): - original_observation_dict = { - 'id': 42, - 'source': 'agent', - 'timestamp': '2021-08-01T12:00:00', - 'observation': 'run', - 'message': 'Command `ls -l` executed with exit code 0.', - 'extras': { - 'command': 'ls -l', - 'hidden': False, - 'exit_code': 0, - 'command_id': 3, - }, - 'content': 'foo.txt', - 'success': True, - } - event = event_from_dict(original_observation_dict) - assert isinstance(event, Observation) - assert isinstance(event, CmdOutputObservation) - assert event.metadata.exit_code == 0 - assert event.success is True - assert event.command == 'ls -l' - assert event.hidden is False - - event_dict = event_to_dict(event) - assert event_dict['success'] is True - assert event_dict['extras']['metadata']['exit_code'] == 0 - assert event_dict['extras']['metadata']['pid'] == 3 - assert event_dict['extras']['command'] == 'ls -l' - assert event_dict['extras']['hidden'] is False - - -def test_file_edit_observation_serialization(): - original_observation_dict = { - 'observation': 'edit', - 'extras': { - '_diff_cache': None, - 'impl_source': FileEditSource.LLM_BASED_EDIT, - 'new_content': None, - 'old_content': None, - 'path': '', - 'prev_exist': False, - 'diff': None, - }, - 'message': 'I edited the file .', - 'content': '[Existing file /path/to/file.txt is edited with 1 changes.]', - } - serialization_deserialization(original_observation_dict, FileEditObservation) - - -def test_file_edit_observation_new_file_serialization(): - original_observation_dict = { - 'observation': 'edit', - 'content': '[New file /path/to/newfile.txt is created with the provided content.]', - 'extras': { - '_diff_cache': None, - 'impl_source': FileEditSource.LLM_BASED_EDIT, - 'new_content': None, - 'old_content': None, - 'path': '', - 'prev_exist': False, - 'diff': None, - }, - 'message': 'I edited the file .', - } - - serialization_deserialization(original_observation_dict, FileEditObservation) - - -def test_file_edit_observation_oh_aci_serialization(): - original_observation_dict = { - 'observation': 'edit', - 'content': 'The file /path/to/file.txt is edited with the provided content.', - 'extras': { - '_diff_cache': None, - 'impl_source': FileEditSource.LLM_BASED_EDIT, - 'new_content': None, - 'old_content': None, - 'path': '', - 'prev_exist': False, - 'diff': None, - }, - 'message': 'I edited the file .', - } - serialization_deserialization(original_observation_dict, FileEditObservation) - - -def test_file_edit_observation_legacy_serialization(): - original_observation_dict = { - 'observation': 'edit', - 'content': 'content', - 'extras': { - 'path': '/workspace/game_2048.py', - 'prev_exist': False, - 'old_content': None, - 'new_content': 'new content', - 'impl_source': 'oh_aci', - 'formatted_output_and_error': 'File created successfully at: /workspace/game_2048.py', - }, - } - - event = event_from_dict(original_observation_dict) - assert isinstance(event, Observation) - assert isinstance(event, FileEditObservation) - assert event.impl_source == FileEditSource.OH_ACI - assert event.path == '/workspace/game_2048.py' - assert event.prev_exist is False - assert event.old_content is None - assert event.new_content == 'new content' - assert not hasattr(event, 'formatted_output_and_error') - - event_dict = event_to_dict(event) - assert event_dict['extras']['impl_source'] == 'oh_aci' - assert event_dict['extras']['path'] == '/workspace/game_2048.py' - assert event_dict['extras']['prev_exist'] is False - assert event_dict['extras']['old_content'] is None - assert event_dict['extras']['new_content'] == 'new content' - assert 'formatted_output_and_error' not in event_dict['extras'] - - -def test_microagent_observation_serialization(): - original_observation_dict = { - 'observation': 'recall', - 'content': '', - 'message': 'Added workspace context', - 'extras': { - 'recall_type': 'workspace_context', - 'repo_name': 'some_repo_name', - 'repo_directory': 'some_repo_directory', - 'repo_branch': '', - 'working_dir': '', - 'runtime_hosts': {'host1': 8080, 'host2': 8081}, - 'repo_instructions': 'complex_repo_instructions', - 'additional_agent_instructions': 'You know it all about this runtime', - 'custom_secrets_descriptions': {'SECRET': 'CUSTOM'}, - 'date': '04/12/1023', - 'microagent_knowledge': [], - 'conversation_instructions': 'additional_context', - }, - } - serialization_deserialization(original_observation_dict, RecallObservation) - - -def test_microagent_observation_microagent_knowledge_serialization(): - original_observation_dict = { - 'observation': 'recall', - 'content': '', - 'message': 'Added microagent knowledge', - 'extras': { - 'recall_type': 'knowledge', - 'repo_name': '', - 'repo_directory': '', - 'repo_branch': '', - 'repo_instructions': '', - 'runtime_hosts': {}, - 'working_dir': '', - 'additional_agent_instructions': '', - 'custom_secrets_descriptions': {}, - 'conversation_instructions': 'additional_context', - 'date': '', - 'microagent_knowledge': [ - { - 'name': 'microagent1', - 'trigger': 'trigger1', - 'content': 'content1', - }, - { - 'name': 'microagent2', - 'trigger': 'trigger2', - 'content': 'content2', - }, - ], - }, - } - serialization_deserialization(original_observation_dict, RecallObservation) - - -def test_microagent_observation_knowledge_microagent_serialization(): - """Test serialization of a RecallObservation with KNOWLEDGE_MICROAGENT type.""" - # Create a RecallObservation with microagent knowledge content - original = RecallObservation( - content='Knowledge microagent information', - recall_type=RecallType.KNOWLEDGE, - repo_branch='', - microagent_knowledge=[ - MicroagentKnowledge( - name='python_best_practices', - trigger='python', - content='Always use virtual environments for Python projects.', - ), - MicroagentKnowledge( - name='git_workflow', - trigger='git', - content='Create a new branch for each feature or bugfix.', - ), - ], - ) - - # Serialize to dictionary - serialized = event_to_dict(original) - - # Verify serialized data structure - assert serialized['observation'] == ObservationType.RECALL - assert serialized['content'] == 'Knowledge microagent information' - assert serialized['extras']['recall_type'] == RecallType.KNOWLEDGE.value - assert len(serialized['extras']['microagent_knowledge']) == 2 - assert serialized['extras']['microagent_knowledge'][0]['trigger'] == 'python' - - # Deserialize back to RecallObservation - deserialized = observation_from_dict(serialized) - - # Verify properties are preserved - assert deserialized.recall_type == RecallType.KNOWLEDGE - assert deserialized.microagent_knowledge == original.microagent_knowledge - assert deserialized.content == original.content - - # Check that environment info fields are empty - assert deserialized.repo_name == '' - assert deserialized.repo_directory == '' - assert deserialized.repo_instructions == '' - assert deserialized.runtime_hosts == {} - - -def test_microagent_observation_environment_serialization(): - """Test serialization of a RecallObservation with ENVIRONMENT type.""" - # Create a RecallObservation with environment info - original = RecallObservation( - content='Environment information', - recall_type=RecallType.WORKSPACE_CONTEXT, - repo_name='OpenHands', - repo_directory='/workspace/openhands', - repo_branch='main', - repo_instructions="Follow the project's coding style guide.", - runtime_hosts={'127.0.0.1': 8080, 'localhost': 5000}, - additional_agent_instructions='You know it all about this runtime', - ) - - # Serialize to dictionary - serialized = event_to_dict(original) - - # Verify serialized data structure - assert serialized['observation'] == ObservationType.RECALL - assert serialized['content'] == 'Environment information' - assert serialized['extras']['recall_type'] == RecallType.WORKSPACE_CONTEXT.value - assert serialized['extras']['repo_name'] == 'OpenHands' - assert serialized['extras']['runtime_hosts'] == { - '127.0.0.1': 8080, - 'localhost': 5000, - } - assert ( - serialized['extras']['additional_agent_instructions'] - == 'You know it all about this runtime' - ) - # Deserialize back to RecallObservation - deserialized = observation_from_dict(serialized) - - # Verify properties are preserved - assert deserialized.recall_type == RecallType.WORKSPACE_CONTEXT - assert deserialized.repo_name == original.repo_name - assert deserialized.repo_directory == original.repo_directory - assert deserialized.repo_instructions == original.repo_instructions - assert deserialized.runtime_hosts == original.runtime_hosts - assert ( - deserialized.additional_agent_instructions - == original.additional_agent_instructions - ) - # Check that knowledge microagent fields are empty - assert deserialized.microagent_knowledge == [] - - -def test_microagent_observation_combined_serialization(): - """Test serialization of a RecallObservation with both types of information.""" - # Create a RecallObservation with both environment and microagent info - # Note: In practice, recall_type would still be one specific type, - # but the object could contain both types of fields - original = RecallObservation( - content='Combined information', - recall_type=RecallType.WORKSPACE_CONTEXT, - # Environment info - repo_name='OpenHands', - repo_directory='/workspace/openhands', - repo_branch='main', - repo_instructions="Follow the project's coding style guide.", - runtime_hosts={'127.0.0.1': 8080}, - additional_agent_instructions='You know it all about this runtime', - # Knowledge microagent info - microagent_knowledge=[ - MicroagentKnowledge( - name='python_best_practices', - trigger='python', - content='Always use virtual environments for Python projects.', - ), - ], - ) - - # Serialize to dictionary - serialized = event_to_dict(original) - - # Verify serialized data has both types of fields - assert serialized['extras']['recall_type'] == RecallType.WORKSPACE_CONTEXT.value - assert serialized['extras']['repo_name'] == 'OpenHands' - assert ( - serialized['extras']['microagent_knowledge'][0]['name'] - == 'python_best_practices' - ) - assert ( - serialized['extras']['additional_agent_instructions'] - == 'You know it all about this runtime' - ) - # Deserialize back to RecallObservation - deserialized = observation_from_dict(serialized) - - # Verify all properties are preserved - assert deserialized.recall_type == RecallType.WORKSPACE_CONTEXT - - # Environment properties - assert deserialized.repo_name == original.repo_name - assert deserialized.repo_directory == original.repo_directory - assert deserialized.repo_instructions == original.repo_instructions - assert deserialized.runtime_hosts == original.runtime_hosts - assert ( - deserialized.additional_agent_instructions - == original.additional_agent_instructions - ) - - # Knowledge microagent properties - assert deserialized.microagent_knowledge == original.microagent_knowledge