mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Remove deprecated openhands.events package (V0) (#14162)
Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
@@ -32,7 +32,6 @@ from openhands.app_server.sandbox.sandbox_models import (
|
||||
SandboxInfo,
|
||||
SandboxStatus,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -76,7 +75,10 @@ def conversation_state_update_event():
|
||||
|
||||
@pytest.fixture
|
||||
def wrong_event():
|
||||
return MessageAction(content='Hello world')
|
||||
"""Return a mock event that is not a ConversationStateUpdateEvent."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.id = uuid4()
|
||||
return mock_event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -28,7 +28,6 @@ from openhands.app_server.sandbox.sandbox_models import (
|
||||
SandboxInfo,
|
||||
SandboxStatus,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -73,7 +72,10 @@ def conversation_state_update_event():
|
||||
|
||||
@pytest.fixture
|
||||
def wrong_event():
|
||||
return MessageAction(content='Hello world')
|
||||
"""Return a mock event that is not a ConversationStateUpdateEvent."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.id = uuid4()
|
||||
return mock_event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -28,9 +28,16 @@ from openhands.app_server.sandbox.sandbox_models import (
|
||||
SandboxInfo,
|
||||
SandboxStatus,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
|
||||
|
||||
def _create_mock_event():
|
||||
"""Create a mock event that is not a ConversationStateUpdateEvent."""
|
||||
mock_event = MagicMock()
|
||||
mock_event.id = uuid4()
|
||||
return mock_event
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -105,8 +112,10 @@ class TestSlackV1CallbackProcessor:
|
||||
@pytest.mark.parametrize(
|
||||
'event,expected_result',
|
||||
[
|
||||
# Wrong event types should be ignored
|
||||
(MessageAction(content='Hello world'), None),
|
||||
# Wrong event types should be ignored (use lazy evaluation for mock)
|
||||
pytest.param(
|
||||
None, None, id='wrong_event_type', marks=pytest.mark.wrong_event_type
|
||||
),
|
||||
# Wrong state values should be ignored
|
||||
(
|
||||
ConversationStateUpdateEvent(key='execution_status', value='running'),
|
||||
@@ -120,9 +129,12 @@ class TestSlackV1CallbackProcessor:
|
||||
],
|
||||
)
|
||||
async def test_event_filtering(
|
||||
self, slack_callback_processor, event_callback, event, expected_result
|
||||
self, slack_callback_processor, event_callback, event, expected_result, request
|
||||
):
|
||||
"""Test that processor correctly filters events."""
|
||||
# Handle the mock event case specially
|
||||
if event is None and 'wrong_event_type' in request.node.name:
|
||||
event = _create_mock_event()
|
||||
result = await slack_callback_processor(uuid4(), event_callback, event)
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.recall_type import RecallType
|
||||
from openhands.events.stream import EventStream, EventStreamSubscriber
|
||||
|
||||
__all__ = [
|
||||
'Event',
|
||||
'EventSource',
|
||||
'EventStream',
|
||||
'EventStreamSubscriber',
|
||||
'RecallType',
|
||||
]
|
||||
@@ -1,50 +0,0 @@
|
||||
from openhands.events.action.action import (
|
||||
Action,
|
||||
ActionConfirmationStatus,
|
||||
ActionSecurityRisk,
|
||||
)
|
||||
from openhands.events.action.agent import (
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
AgentThinkAction,
|
||||
ChangeAgentStateAction,
|
||||
LoopRecoveryAction,
|
||||
RecallAction,
|
||||
TaskTrackingAction,
|
||||
)
|
||||
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
||||
from openhands.events.action.commands import CmdRunAction, IPythonRunCellAction
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.action.files import (
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.action.message import MessageAction, SystemMessageAction
|
||||
|
||||
__all__ = [
|
||||
'Action',
|
||||
'NullAction',
|
||||
'CmdRunAction',
|
||||
'BrowseURLAction',
|
||||
'BrowseInteractiveAction',
|
||||
'FileReadAction',
|
||||
'FileWriteAction',
|
||||
'FileEditAction',
|
||||
'AgentFinishAction',
|
||||
'AgentRejectAction',
|
||||
'AgentDelegateAction',
|
||||
'ChangeAgentStateAction',
|
||||
'IPythonRunCellAction',
|
||||
'MessageAction',
|
||||
'SystemMessageAction',
|
||||
'ActionConfirmationStatus',
|
||||
'AgentThinkAction',
|
||||
'RecallAction',
|
||||
'MCPAction',
|
||||
'TaskTrackingAction',
|
||||
'ActionSecurityRisk',
|
||||
'LoopRecoveryAction',
|
||||
]
|
||||
@@ -1,23 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import ClassVar
|
||||
|
||||
from openhands.events.event import Event
|
||||
|
||||
|
||||
class ActionConfirmationStatus(str, Enum):
|
||||
CONFIRMED = 'confirmed'
|
||||
REJECTED = 'rejected'
|
||||
AWAITING_CONFIRMATION = 'awaiting_confirmation'
|
||||
|
||||
|
||||
class ActionSecurityRisk(int, Enum):
|
||||
UNKNOWN = -1
|
||||
LOW = 0
|
||||
MEDIUM = 1
|
||||
HIGH = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class Action(Event):
|
||||
runnable: ClassVar[bool] = False
|
||||
@@ -1,242 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.recall_type import RecallType
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChangeAgentStateAction(Action):
|
||||
"""Fake action, just to notify the client that a task state has changed."""
|
||||
|
||||
agent_state: str
|
||||
thought: str = ''
|
||||
action: str = ActionType.CHANGE_AGENT_STATE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Agent state changed to {self.agent_state}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentFinishAction(Action):
|
||||
"""An action where the agent finishes the task.
|
||||
|
||||
Attributes:
|
||||
final_thought (str): The message to send to the user.
|
||||
outputs (dict): The other outputs of the agent, for instance "content".
|
||||
thought (str): The agent's explanation of its actions.
|
||||
action (str): The action type, namely ActionType.FINISH.
|
||||
"""
|
||||
|
||||
final_thought: str = ''
|
||||
outputs: dict[str, Any] = field(default_factory=dict)
|
||||
thought: str = ''
|
||||
action: str = ActionType.FINISH
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
if self.thought != '':
|
||||
return self.thought
|
||||
return "All done! What's next on the agenda?"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentThinkAction(Action):
|
||||
"""An action where the agent logs a thought.
|
||||
|
||||
Attributes:
|
||||
thought (str): The agent's explanation of its actions.
|
||||
action (str): The action type, namely ActionType.THINK.
|
||||
"""
|
||||
|
||||
thought: str = ''
|
||||
action: str = ActionType.THINK
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'I am thinking...: {self.thought}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentRejectAction(Action):
|
||||
outputs: dict = field(default_factory=dict)
|
||||
thought: str = ''
|
||||
action: str = ActionType.REJECT
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
msg: str = 'Task is rejected by the agent.'
|
||||
if 'reason' in self.outputs:
|
||||
msg += ' Reason: ' + self.outputs['reason']
|
||||
return msg
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentDelegateAction(Action):
|
||||
agent: str
|
||||
inputs: dict
|
||||
thought: str = ''
|
||||
action: str = ActionType.DELEGATE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f"I'm asking {self.agent} for help with this task."
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecallAction(Action):
|
||||
"""This action is used for retrieving content, e.g., from the global directory or user workspace."""
|
||||
|
||||
recall_type: RecallType
|
||||
query: str = ''
|
||||
thought: str = ''
|
||||
action: str = ActionType.RECALL
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Retrieving content for: {self.query[:50]}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**RecallAction**\n'
|
||||
ret += f'QUERY: {self.query[:50]}'
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class CondensationAction(Action):
|
||||
"""This action indicates a condensation of the conversation history is happening.
|
||||
|
||||
There are two ways to specify the events to be forgotten:
|
||||
1. By providing a list of event IDs.
|
||||
2. By providing the start and end IDs of a range of events.
|
||||
|
||||
In the second case, we assume that event IDs are monotonically increasing, and that _all_ events between the start and end IDs are to be forgotten.
|
||||
|
||||
Raises:
|
||||
ValueError: If the optional fields are not instantiated in a valid configuration.
|
||||
"""
|
||||
|
||||
action: str = ActionType.CONDENSATION
|
||||
|
||||
forgotten_event_ids: list[int] | None = None
|
||||
"""The IDs of the events that are being forgotten (removed from the `View` given to the LLM)."""
|
||||
|
||||
forgotten_events_start_id: int | None = None
|
||||
"""The ID of the first event to be forgotten in a range of events."""
|
||||
|
||||
forgotten_events_end_id: int | None = None
|
||||
"""The ID of the last event to be forgotten in a range of events."""
|
||||
|
||||
summary: str | None = None
|
||||
"""An optional summary of the events being forgotten."""
|
||||
|
||||
summary_offset: int | None = None
|
||||
"""An optional offset to the start of the resulting view indicating where the summary should be inserted."""
|
||||
|
||||
def _validate_field_polymorphism(self) -> bool:
|
||||
"""Check if the optional fields are instantiated in a valid configuration."""
|
||||
# For the forgotton events, there are only two valid configurations:
|
||||
# 1. We're forgetting events based on the list of provided IDs, or
|
||||
using_event_ids = self.forgotten_event_ids is not None
|
||||
# 2. We're forgetting events based on the range of IDs.
|
||||
using_event_range = (
|
||||
self.forgotten_events_start_id is not None
|
||||
and self.forgotten_events_end_id is not None
|
||||
)
|
||||
|
||||
# Either way, we can only have one of the two valid configurations.
|
||||
forgotten_event_configuration = using_event_ids ^ using_event_range
|
||||
|
||||
# We also need to check that if the summary is provided, so is the
|
||||
# offset (and vice versa).
|
||||
summary_configuration = (
|
||||
self.summary is None and self.summary_offset is None
|
||||
) or (self.summary is not None and self.summary_offset is not None)
|
||||
|
||||
return forgotten_event_configuration and summary_configuration
|
||||
|
||||
def __post_init__(self):
|
||||
if not self._validate_field_polymorphism():
|
||||
raise ValueError('Invalid configuration of the optional fields.')
|
||||
|
||||
@property
|
||||
def forgotten(self) -> list[int]:
|
||||
"""The list of event IDs that should be forgotten."""
|
||||
# Start by making sure the fields are instantiated in a valid
|
||||
# configuration. We check this whenever the event is initialized, but we
|
||||
# can't make the dataclass immutable so we need to check it again here
|
||||
# to make sure the configuration is still valid.
|
||||
if not self._validate_field_polymorphism():
|
||||
raise ValueError('Invalid configuration of the optional fields.')
|
||||
|
||||
if self.forgotten_event_ids is not None:
|
||||
return self.forgotten_event_ids
|
||||
|
||||
# If we've gotten this far, the start/end IDs are not None.
|
||||
assert self.forgotten_events_start_id is not None
|
||||
assert self.forgotten_events_end_id is not None
|
||||
return list(
|
||||
range(self.forgotten_events_start_id, self.forgotten_events_end_id + 1)
|
||||
)
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
if self.summary:
|
||||
return f'Summary: {self.summary}'
|
||||
return f'Condenser is dropping the events: {self.forgotten}.'
|
||||
|
||||
|
||||
@dataclass
|
||||
class CondensationRequestAction(Action):
|
||||
"""This action is used to request a condensation of the conversation history.
|
||||
|
||||
Attributes:
|
||||
action (str): The action type, namely ActionType.CONDENSATION_REQUEST.
|
||||
"""
|
||||
|
||||
action: str = ActionType.CONDENSATION_REQUEST
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return 'Requesting a condensation of the conversation history.'
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskTrackingAction(Action):
|
||||
"""An action where the agent writes or updates a task list for task management.
|
||||
Attributes:
|
||||
task_list (list): The list of task items with their status and metadata.
|
||||
thought (str): The agent's explanation of its actions.
|
||||
action (str): The action type, namely ActionType.TASK_TRACKING.
|
||||
"""
|
||||
|
||||
command: str = 'view'
|
||||
task_list: list[dict[str, Any]] = field(default_factory=list)
|
||||
thought: str = ''
|
||||
action: str = ActionType.TASK_TRACKING
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
num_tasks = len(self.task_list)
|
||||
if num_tasks == 0:
|
||||
return 'Clearing the task list.'
|
||||
elif num_tasks == 1:
|
||||
return 'Managing 1 task item.'
|
||||
else:
|
||||
return f'Managing {num_tasks} task items.'
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoopRecoveryAction(Action):
|
||||
"""An action that shows three ways to handle dead loop.
|
||||
The class should be invisible to LLM.
|
||||
Attributes:
|
||||
option (int): 1 allow user to prompt again
|
||||
2 automatically use latest user prompt
|
||||
3 stop agent
|
||||
"""
|
||||
|
||||
option: int = 1
|
||||
action: str = ActionType.LOOP_RECOVERY
|
||||
@@ -1,48 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action, ActionSecurityRisk
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrowseURLAction(Action):
|
||||
url: str
|
||||
thought: str = ''
|
||||
action: str = ActionType.BROWSE
|
||||
runnable: ClassVar[bool] = True
|
||||
security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN
|
||||
return_axtree: bool = False
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'I am browsing the URL: {self.url}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**BrowseURLAction**\n'
|
||||
if self.thought:
|
||||
ret += f'THOUGHT: {self.thought}\n'
|
||||
ret += f'URL: {self.url}'
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrowseInteractiveAction(Action):
|
||||
browser_actions: str
|
||||
thought: str = ''
|
||||
browsergym_send_msg_to_user: str = ''
|
||||
action: str = ActionType.BROWSE_INTERACTIVE
|
||||
runnable: ClassVar[bool] = True
|
||||
security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN
|
||||
return_axtree: bool = False
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'I am interacting with the browser:\n```\n{self.browser_actions}\n```'
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**BrowseInteractiveAction**\n'
|
||||
if self.thought:
|
||||
ret += f'THOUGHT: {self.thought}\n'
|
||||
ret += f'BROWSER_ACTIONS: {self.browser_actions}'
|
||||
return ret
|
||||
@@ -1,64 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import (
|
||||
Action,
|
||||
ActionConfirmationStatus,
|
||||
ActionSecurityRisk,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CmdRunAction(Action):
|
||||
command: (
|
||||
str # When `command` is empty, it will be used to print the current tmux window
|
||||
)
|
||||
is_input: bool = False # if True, the command is an input to the running process
|
||||
thought: str = ''
|
||||
blocking: bool = False # if True, the command will be run in a blocking manner, but a timeout must be set through _set_hard_timeout
|
||||
is_static: bool = False # if True, runs the command in a separate process
|
||||
cwd: str | None = None # current working directory, only used if is_static is True
|
||||
hidden: bool = (
|
||||
False # if True, this command does not go through the LLM or event stream
|
||||
)
|
||||
action: str = ActionType.RUN
|
||||
runnable: ClassVar[bool] = True
|
||||
confirmation_state: ActionConfirmationStatus = ActionConfirmationStatus.CONFIRMED
|
||||
security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Running command: {self.command}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = f'**CmdRunAction (source={self.source}, is_input={self.is_input})**\n'
|
||||
if self.thought:
|
||||
ret += f'THOUGHT: {self.thought}\n'
|
||||
ret += f'COMMAND:\n{self.command}'
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPythonRunCellAction(Action):
|
||||
code: str
|
||||
thought: str = ''
|
||||
include_extra: bool = (
|
||||
True # whether to include CWD & Python interpreter in the output
|
||||
)
|
||||
action: str = ActionType.RUN_IPYTHON
|
||||
runnable: ClassVar[bool] = True
|
||||
confirmation_state: ActionConfirmationStatus = ActionConfirmationStatus.CONFIRMED
|
||||
security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN
|
||||
kernel_init_code: str = '' # code to run in the kernel (if the kernel is restarted)
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**IPythonRunCellAction**\n'
|
||||
if self.thought:
|
||||
ret += f'THOUGHT: {self.thought}\n'
|
||||
ret += f'CODE:\n{self.code}'
|
||||
return ret
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Running Python code interactively: {self.code}'
|
||||
@@ -1,15 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action
|
||||
|
||||
|
||||
@dataclass
|
||||
class NullAction(Action):
|
||||
"""An action that does nothing."""
|
||||
|
||||
action: str = ActionType.NULL
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return 'No action'
|
||||
@@ -1,138 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action, ActionSecurityRisk
|
||||
from openhands.events.event import FileEditSource, FileReadSource
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileReadAction(Action):
|
||||
"""Reads a file from a given path.
|
||||
Can be set to read specific lines using start and end
|
||||
Default lines 0:-1 (whole file)
|
||||
"""
|
||||
|
||||
path: str
|
||||
start: int = 0
|
||||
end: int = -1
|
||||
thought: str = ''
|
||||
action: str = ActionType.READ
|
||||
runnable: ClassVar[bool] = True
|
||||
security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN
|
||||
impl_source: FileReadSource = FileReadSource.DEFAULT
|
||||
view_range: list[int] | None = None # ONLY used in OH_ACI mode
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Reading file: {self.path}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileWriteAction(Action):
|
||||
"""Writes a file to a given path.
|
||||
Can be set to write specific lines using start and end
|
||||
Default lines 0:-1 (whole file)
|
||||
"""
|
||||
|
||||
path: str
|
||||
content: str
|
||||
start: int = 0
|
||||
end: int = -1
|
||||
thought: str = ''
|
||||
action: str = ActionType.WRITE
|
||||
runnable: ClassVar[bool] = True
|
||||
security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Writing file: {self.path}'
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'**FileWriteAction**\n'
|
||||
f'Path: {self.path}\n'
|
||||
f'Range: [L{self.start}:L{self.end}]\n'
|
||||
f'Thought: {self.thought}\n'
|
||||
f'Content:\n```\n{self.content}\n```\n'
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileEditAction(Action):
|
||||
"""Edits a file using various commands including view, create, str_replace, insert, and undo_edit.
|
||||
|
||||
This class supports two main modes of operation:
|
||||
1. LLM-based editing (impl_source = FileEditSource.LLM_BASED_EDIT)
|
||||
2. ACI-based editing (impl_source = FileEditSource.OH_ACI)
|
||||
|
||||
Attributes:
|
||||
path (str): The path to the file being edited. Works for both LLM-based and OH_ACI editing.
|
||||
OH_ACI only arguments:
|
||||
command (str): The editing command to be performed (view, create, str_replace, insert, undo_edit, write).
|
||||
file_text (str): The content of the file to be created (used with 'create' command in OH_ACI mode).
|
||||
old_str (str): The string to be replaced (used with 'str_replace' command in OH_ACI mode).
|
||||
new_str (str): The string to replace old_str (used with 'str_replace' and 'insert' commands in OH_ACI mode).
|
||||
insert_line (int): The line number after which to insert new_str (used with 'insert' command in OH_ACI mode).
|
||||
LLM-based editing arguments:
|
||||
content (str): The content to be written or edited in the file (used in LLM-based editing and 'write' command).
|
||||
start (int): The starting line for editing (1-indexed, inclusive). Default is 1.
|
||||
end (int): The ending line for editing (1-indexed, inclusive). Default is -1 (end of file).
|
||||
thought (str): The reasoning behind the edit action.
|
||||
action (str): The type of action being performed (always ActionType.EDIT).
|
||||
runnable (bool): Indicates if the action can be executed (always True).
|
||||
security_risk (ActionSecurityRisk | None): Indicates any security risks associated with the action.
|
||||
impl_source (FileEditSource): The source of the implementation (LLM_BASED_EDIT or OH_ACI).
|
||||
|
||||
Usage:
|
||||
- For LLM-based editing: Use path, content, start, and end attributes.
|
||||
- For ACI-based editing: Use path, command, and the appropriate attributes for the specific command.
|
||||
|
||||
Note:
|
||||
- If start is set to -1 in LLM-based editing, the content will be appended to the file.
|
||||
- The 'write' command behaves similarly to LLM-based editing, using content, start, and end attributes.
|
||||
"""
|
||||
|
||||
path: str
|
||||
|
||||
# OH_ACI arguments
|
||||
command: str = ''
|
||||
file_text: str | None = None
|
||||
old_str: str | None = None
|
||||
new_str: str | None = None
|
||||
insert_line: int | None = None
|
||||
|
||||
# LLM-based editing arguments
|
||||
content: str = ''
|
||||
start: int = 1
|
||||
end: int = -1
|
||||
|
||||
# Shared arguments
|
||||
thought: str = ''
|
||||
action: str = ActionType.EDIT
|
||||
runnable: ClassVar[bool] = True
|
||||
security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN
|
||||
impl_source: FileEditSource = FileEditSource.OH_ACI
|
||||
|
||||
def __repr__(self) -> str:
|
||||
ret = '**FileEditAction**\n'
|
||||
ret += f'Path: [{self.path}]\n'
|
||||
ret += f'Thought: {self.thought}\n'
|
||||
|
||||
if self.impl_source == FileEditSource.LLM_BASED_EDIT:
|
||||
ret += f'Range: [L{self.start}:L{self.end}]\n'
|
||||
ret += f'Content:\n```\n{self.content}\n```\n'
|
||||
else: # OH_ACI mode
|
||||
ret += f'Command: {self.command}\n'
|
||||
if self.command == 'create':
|
||||
ret += f'Created File with Text:\n```\n{self.file_text}\n```\n'
|
||||
elif self.command == 'str_replace':
|
||||
ret += f'Old String: ```\n{self.old_str}\n```\n'
|
||||
ret += f'New String: ```\n{self.new_str}\n```\n'
|
||||
elif self.command == 'insert':
|
||||
ret += f'Insert Line: {self.insert_line}\n'
|
||||
ret += f'New String: ```\n{self.new_str}\n```\n'
|
||||
elif self.command == 'undo_edit':
|
||||
ret += 'Undo Edit\n'
|
||||
# We ignore "view" command because it will be mapped to a FileReadAction
|
||||
return ret
|
||||
@@ -1,32 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action, ActionSecurityRisk
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPAction(Action):
|
||||
name: str
|
||||
arguments: dict[str, Any] = field(default_factory=dict)
|
||||
thought: str = ''
|
||||
action: str = ActionType.MCP
|
||||
runnable: ClassVar[bool] = True
|
||||
security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return (
|
||||
f'I am interacting with the MCP server with name:\n'
|
||||
f'```\n{self.name}\n```\n'
|
||||
f'and arguments:\n'
|
||||
f'```\n{self.arguments}\n```'
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = '**MCPAction**\n'
|
||||
if self.thought:
|
||||
ret += f'THOUGHT: {self.thought}\n'
|
||||
ret += f'NAME: {self.name}\n'
|
||||
ret += f'ARGUMENTS: {self.arguments}'
|
||||
return ret
|
||||
@@ -1,66 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from openhands.core.schema import ActionType
|
||||
from openhands.events.action.action import Action, ActionSecurityRisk
|
||||
from openhands.version import get_version
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageAction(Action):
|
||||
content: str
|
||||
file_urls: list[str] | None = None
|
||||
image_urls: list[str] | None = None
|
||||
wait_for_response: bool = False
|
||||
action: str = ActionType.MESSAGE
|
||||
security_risk: ActionSecurityRisk = ActionSecurityRisk.UNKNOWN
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
|
||||
@property
|
||||
def images_urls(self) -> list[str] | None:
|
||||
# Deprecated alias for backward compatibility
|
||||
return self.image_urls
|
||||
|
||||
@images_urls.setter
|
||||
def images_urls(self, value: list[str] | None) -> None:
|
||||
self.image_urls = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = f'**MessageAction** (source={self.source})\n'
|
||||
ret += f'CONTENT: {self.content}'
|
||||
if self.image_urls:
|
||||
for url in self.image_urls:
|
||||
ret += f'\nIMAGE_URL: {url}'
|
||||
if self.file_urls:
|
||||
for url in self.file_urls:
|
||||
ret += f'\nFILE_URL: {url}'
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemMessageAction(Action):
|
||||
"""Action that represents a system message for an agent, including the system prompt
|
||||
and available tools. This should be the first message in the event stream.
|
||||
"""
|
||||
|
||||
content: str
|
||||
tools: list[Any] | None = None
|
||||
openhands_version: str | None = get_version()
|
||||
agent_class: str | None = None
|
||||
action: ActionType = ActionType.SYSTEM
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = f'**SystemMessageAction** (source={self.source})\n'
|
||||
ret += f'CONTENT: {self.content}'
|
||||
if self.tools:
|
||||
ret += f'\nTOOLS: {len(self.tools)} tools available'
|
||||
if self.agent_class:
|
||||
ret += f'\nAGENT_CLASS: {self.agent_class}'
|
||||
return ret
|
||||
@@ -1,121 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from openhands.events.metrics import Metrics
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
|
||||
|
||||
class EventSource(str, Enum):
|
||||
AGENT = 'agent'
|
||||
USER = 'user'
|
||||
ENVIRONMENT = 'environment'
|
||||
|
||||
|
||||
class FileEditSource(str, Enum):
|
||||
LLM_BASED_EDIT = 'llm_based_edit'
|
||||
OH_ACI = 'oh_aci' # openhands-aci
|
||||
|
||||
|
||||
class FileReadSource(str, Enum):
|
||||
OH_ACI = 'oh_aci' # openhands-aci
|
||||
DEFAULT = 'default'
|
||||
|
||||
|
||||
@dataclass
|
||||
class Event:
|
||||
INVALID_ID = -1
|
||||
|
||||
@property
|
||||
def message(self) -> str | None:
|
||||
if hasattr(self, '_message'):
|
||||
msg = getattr(self, '_message')
|
||||
return str(msg) if msg is not None else None
|
||||
return ''
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
if hasattr(self, '_id'):
|
||||
id_val = getattr(self, '_id')
|
||||
return int(id_val) if id_val is not None else Event.INVALID_ID
|
||||
return Event.INVALID_ID
|
||||
|
||||
@property
|
||||
def timestamp(self) -> str | None:
|
||||
if hasattr(self, '_timestamp') and isinstance(self._timestamp, str):
|
||||
ts = getattr(self, '_timestamp')
|
||||
return str(ts) if ts is not None else None
|
||||
return None
|
||||
|
||||
@timestamp.setter
|
||||
def timestamp(self, value: datetime) -> None:
|
||||
if isinstance(value, datetime):
|
||||
self._timestamp = value.isoformat()
|
||||
|
||||
@property
|
||||
def source(self) -> EventSource | None:
|
||||
if hasattr(self, '_source'):
|
||||
src = getattr(self, '_source')
|
||||
return EventSource(src) if src is not None else None
|
||||
return None
|
||||
|
||||
@property
|
||||
def cause(self) -> int | None:
|
||||
if hasattr(self, '_cause'):
|
||||
cause_val = getattr(self, '_cause')
|
||||
return int(cause_val) if cause_val is not None else None
|
||||
return None
|
||||
|
||||
@property
|
||||
def timeout(self) -> float | None:
|
||||
if hasattr(self, '_timeout'):
|
||||
timeout_val = getattr(self, '_timeout')
|
||||
return float(timeout_val) if timeout_val is not None else None
|
||||
return None
|
||||
|
||||
def set_hard_timeout(self, value: float | None, blocking: bool = True) -> None:
|
||||
"""Set the timeout for the event.
|
||||
|
||||
NOTE, this is a hard timeout, meaning that the event will be blocked
|
||||
until the timeout is reached.
|
||||
"""
|
||||
self._timeout = value
|
||||
# Check if .blocking is an attribute of the event
|
||||
if hasattr(self, 'blocking'):
|
||||
# .blocking needs to be set to True if .timeout is set
|
||||
self.blocking = blocking
|
||||
|
||||
# optional metadata, LLM call cost of the edit
|
||||
@property
|
||||
def llm_metrics(self) -> Metrics | None:
|
||||
if hasattr(self, '_llm_metrics'):
|
||||
metrics = getattr(self, '_llm_metrics')
|
||||
return metrics if isinstance(metrics, Metrics) else None
|
||||
return None
|
||||
|
||||
@llm_metrics.setter
|
||||
def llm_metrics(self, value: Metrics) -> None:
|
||||
self._llm_metrics = value
|
||||
|
||||
# optional field, metadata about the tool call, if the event has a tool call
|
||||
@property
|
||||
def tool_call_metadata(self) -> ToolCallMetadata | None:
|
||||
if hasattr(self, '_tool_call_metadata'):
|
||||
metadata = getattr(self, '_tool_call_metadata')
|
||||
return metadata if isinstance(metadata, ToolCallMetadata) else None
|
||||
return None
|
||||
|
||||
@tool_call_metadata.setter
|
||||
def tool_call_metadata(self, value: ToolCallMetadata) -> None:
|
||||
self._tool_call_metadata = value
|
||||
|
||||
# optional field, the id of the response from the LLM
|
||||
@property
|
||||
def response_id(self) -> str | None:
|
||||
if hasattr(self, '_response_id'):
|
||||
return self._response_id # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
@response_id.setter
|
||||
def response_id(self, value: str) -> None:
|
||||
self._response_id = value
|
||||
@@ -1,98 +0,0 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
|
||||
|
||||
@dataclass
class EventFilter:
    """A filter for Event objects in the event stream.

    EventFilter provides a flexible way to filter events based on various criteria
    such as event type, source, date range, and content. It can be used to include
    or exclude events from search results based on the specified criteria.

    Attributes:
        exclude_hidden: Whether to exclude events marked as hidden. Defaults to False.
        query: Text string to search for in event content. Case-insensitive. Defaults to None.
        include_types: Tuple of Event types to include. Only events of these types will pass the filter.
            Defaults to None (include all types).
        exclude_types: Tuple of Event types to exclude. Events of these types will be filtered out.
            Defaults to None (exclude no types).
        source: Filter by event source (e.g., 'agent', 'user', 'environment'). Defaults to None.
        start_date: ISO format date string. Only events after this date will pass the filter.
            Defaults to None.
        end_date: ISO format date string. Only events before this date will pass the filter.
            Defaults to None.
    """

    exclude_hidden: bool = False
    query: str | None = None
    include_types: tuple[type[Event], ...] | None = None
    exclude_types: tuple[type[Event], ...] | None = None
    source: str | None = None
    start_date: str | None = None
    end_date: str | None = None

    def include(self, event: Event) -> bool:
        """Determine if an event should be included based on the filter criteria.

        Every configured criterion acts as a guard: the first one the event
        fails short-circuits to False.

        Args:
            event: The Event object to check against the filter criteria.

        Returns:
            bool: True if the event passes all filter criteria and should be
            included, False otherwise.
        """
        if self.include_types and not isinstance(event, self.include_types):
            return False

        if self.exclude_types is not None and isinstance(event, self.exclude_types):
            return False

        if self.source:
            actual_source = event.source.value if event.source is not None else None
            if actual_source != self.source:
                return False

        # ISO-format timestamps compare correctly as plain strings.
        timestamp = event.timestamp
        if self.start_date and timestamp is not None and timestamp < self.start_date:
            return False
        if self.end_date and timestamp is not None and timestamp > self.end_date:
            return False

        if self.exclude_hidden and getattr(event, 'hidden', False):
            return False

        # Text search in event content if query provided
        if self.query:
            serialized = json.dumps(event_to_dict(event)).lower()
            if self.query.lower() not in serialized:
                return False

        return True

    def exclude(self, event: Event) -> bool:
        """Inverse of include(): True when the event should be filtered out.

        Args:
            event: The Event object to check against the filter criteria.

        Returns:
            bool: True if the event should be excluded, False if it should be included.
        """
        return not self.include(event)
|
||||
@@ -1,183 +0,0 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.event_store_abc import EventStoreABC
|
||||
from openhands.events.serialization.event import event_from_dict
|
||||
from openhands.storage.files import FileStore
|
||||
from openhands.storage.locations import (
|
||||
get_conversation_dir,
|
||||
get_conversation_event_filename,
|
||||
get_conversation_events_dir,
|
||||
)
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _CachePage:
    """An immutable page of serialized events covering global ids [start, end)."""

    events: list[dict] | None
    start: int
    end: int

    def covers(self, global_index: int) -> bool:
        """Return True when global_index lies inside this page's id range."""
        return self.start <= global_index < self.end

    def get_event(self, global_index: int) -> Event | None:
        """Deserialize the event at global_index, or None when no page data was cached."""
        if not self.events:
            return None
        return event_from_dict(self.events[global_index - self.start])


# Sentinel page with an empty (inverted) range, so the first lookup always misses.
_DUMMY_PAGE = _CachePage(None, 1, -1)
|
||||
|
||||
|
||||
@dataclass
class EventStore(EventStoreABC):
    """A stored list of events backing a conversation.

    Events are persisted as one JSON file each in the FileStore; bulk reads go
    through fixed-size cache pages because reading individual event files is
    slow when there are many of them.
    """

    sid: str  # conversation/session id
    file_store: FileStore
    user_id: str | None
    cache_size: int = 25  # number of events per cache page
    _cur_id: int | None = None  # Private field to cache the calculated value

    @property
    def cur_id(self) -> int:
        """Lazy calculated property for the current event ID."""
        if self._cur_id is None:
            self._cur_id = self._calculate_cur_id()
        return self._cur_id

    @cur_id.setter
    def cur_id(self, value: int) -> None:
        """Setter for cur_id to allow updates."""
        self._cur_id = value

    def _calculate_cur_id(self) -> int:
        """Calculate the current event ID based on file system content."""
        # Resolve the directory before the try-block so the except-branch log
        # below can never reference an unbound local.
        events_dir = get_conversation_events_dir(self.sid, self.user_id)
        events = []
        try:
            events = self.file_store.list(events_dir)
        except FileNotFoundError:
            logger.debug(f'No events found for session {self.sid} at {events_dir}')

        if not events:
            return 0

        # if we have events, we need to find the highest id to prepare for new events
        max_id = -1
        for event_str in events:
            id = self._get_id_from_filename(event_str)
            if id >= max_id:
                max_id = id
        return max_id + 1

    def search_events(
        self,
        start_id: int = 0,
        end_id: int | None = None,
        reverse: bool = False,
        filter: EventFilter | None = None,
        limit: int | None = None,
    ) -> Iterable[Event]:
        """Retrieve events from the event stream, optionally filtering out events of a given type
        and events marked as hidden.

        Args:
            start_id: The ID of the first event to retrieve. Defaults to 0.
            end_id: The ID of the last event to retrieve (inclusive). Defaults to the last event in the stream.
            reverse: Whether to retrieve events in reverse order. Defaults to False.
            filter: EventFilter to use
            limit: Maximum number of events to yield; None means unbounded.

        Yields:
            Events from the stream that match the criteria.
        """
        if end_id is None:
            end_id = self.cur_id
        else:
            end_id += 1  # From inclusive to exclusive

        if reverse:
            # Walk the same [start_id, end_id) range backwards: swap the
            # bounds and shift both by one so range() visits the right ids.
            step = -1
            start_id, end_id = end_id, start_id
            start_id -= 1
            end_id -= 1
        else:
            step = 1

        cache_page = _DUMMY_PAGE  # covers nothing, forcing a load on first use
        num_results = 0
        for index in range(start_id, end_id, step):
            # Bail out early if the process is shutting down.
            if not should_continue():
                return
            if not cache_page.covers(index):
                cache_page = self._load_cache_page_for_index(index)
            event = cache_page.get_event(index)
            if event is None:
                # Cache miss: fall back to the per-event file; missing files
                # are treated as gaps rather than errors.
                try:
                    event = self.get_event(index)
                except FileNotFoundError:
                    event = None
            if event:
                if not filter or filter.include(event):
                    yield event
                    num_results += 1
                    if limit and limit <= num_results:
                        return

    def get_event(self, id: int) -> Event:
        """Load and deserialize a single event; raises FileNotFoundError if absent."""
        filename = self._get_filename_for_id(id, self.user_id)
        content = self.file_store.read(filename)
        data = json.loads(content)
        return event_from_dict(data)

    def get_latest_event(self) -> Event:
        """Return the most recently stored event."""
        return self.get_event(self.cur_id - 1)

    def get_latest_event_id(self) -> int:
        """Return the id of the most recently stored event."""
        return self.cur_id - 1

    def filtered_events_by_source(self, source: EventSource) -> Iterable[Event]:
        """Yield all events whose source matches the given source."""
        for event in self.search_events():
            if event.source == source:
                yield event

    def _get_filename_for_id(self, id: int, user_id: str | None) -> str:
        """Path of the JSON file storing the event with the given id."""
        return get_conversation_event_filename(self.sid, id, user_id)

    def _get_filename_for_cache(self, start: int, end: int) -> str:
        """Path of the cache page file covering event ids [start, end)."""
        return f'{get_conversation_dir(self.sid, self.user_id)}event_cache/{start}-{end}.json'

    def _load_cache_page(self, start: int, end: int) -> _CachePage:
        """Read a page from the cache. Reading individual events is slow when there are a lot of them, so we use pages."""
        cache_filename = self._get_filename_for_cache(start, end)
        try:
            content = self.file_store.read(cache_filename)
            events = json.loads(content)
        except FileNotFoundError:
            # No cached page; the caller falls back to per-event files.
            events = None
        page = _CachePage(events, start, end)
        return page

    def _load_cache_page_for_index(self, index: int) -> _CachePage:
        """Load the cache page whose aligned range contains the given index."""
        offset = index % self.cache_size
        index -= offset
        return self._load_cache_page(index, index + self.cache_size)

    @staticmethod
    def _get_id_from_filename(filename: str) -> int:
        """Extract the numeric event id from a path like '<dir>/<id>.json'.

        Returns -1 for filenames that do not follow the convention.
        """
        try:
            return int(filename.split('/')[-1].split('.')[0])
        except ValueError:
            # Fixed: the original f-string had no placeholder, so the
            # offending filename was never logged.
            logger.warning(f'get id from filename {filename} failed.')
            return -1
|
||||
@@ -1,111 +0,0 @@
|
||||
from abc import abstractmethod
|
||||
from itertools import islice
|
||||
from typing import Iterable
|
||||
|
||||
from deprecated import deprecated # type: ignore
|
||||
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.event_filter import EventFilter
|
||||
|
||||
|
||||
class EventStoreABC:
    """A stored list of events backing a conversation"""

    sid: str
    user_id: str | None

    @abstractmethod
    def search_events(
        self,
        start_id: int = 0,
        end_id: int | None = None,
        reverse: bool = False,
        filter: EventFilter | None = None,
        limit: int | None = None,
    ) -> Iterable[Event]:
        """Retrieve events from the event stream, optionally excluding events using a filter.

        Args:
            start_id: The ID of the first event to retrieve. Defaults to 0.
            end_id: The ID of the last event to retrieve. Defaults to the last event in the stream.
            reverse: Whether to retrieve events in reverse order. Defaults to False.
            filter: An optional event filter
            limit: Maximum number of events to yield; None means unbounded.

        Yields:
            Events from the stream that match the criteria.
        """

    @deprecated('Use search_events instead')
    def get_events(
        self,
        start_id: int = 0,
        end_id: int | None = None,
        reverse: bool = False,
        filter_out_type: tuple[type[Event], ...] | None = None,
        filter_hidden: bool = False,
    ) -> Iterable[Event]:
        # Legacy entry point: translate the old keyword arguments into an
        # EventFilter and delegate to search_events.
        legacy_filter = EventFilter(
            exclude_types=filter_out_type, exclude_hidden=filter_hidden
        )
        yield from self.search_events(start_id, end_id, reverse, legacy_filter)

    @abstractmethod
    def get_event(self, id: int) -> Event:
        """Retrieve a single event from the event stream. Raise a FileNotFoundError if there was no such event"""

    @abstractmethod
    def get_latest_event(self) -> Event:
        """Get the latest event from the event stream"""

    @abstractmethod
    def get_latest_event_id(self) -> int:
        """Get the id of the latest event from the event stream"""

    @deprecated('use search_events instead')
    def filtered_events_by_source(self, source: EventSource) -> Iterable[Event]:
        # Legacy entry point: a source-only filter over the whole stream.
        yield from self.search_events(filter=EventFilter(source=source))

    @deprecated('use search_events instead')
    def get_matching_events(
        self,
        query: str | None = None,
        event_types: tuple[type[Event], ...] | None = None,
        source: str | None = None,
        start_date: str | None = None,
        end_date: str | None = None,
        start_id: int = 0,
        limit: int = 100,
        reverse: bool = False,
    ) -> list[Event]:
        """Get matching events from the event stream based on filters.

        Args:
            query: Text to search for in event content
            event_types: Filter by event type classes (e.g., (FileReadAction, ) ).
            source: Filter by event source
            start_date: Filter events after this date (ISO format)
            end_date: Filter events before this date (ISO format)
            start_id: Starting ID in the event stream. Defaults to 0
            limit: Maximum number of events to return. Must be between 1 and 100. Defaults to 100
            reverse: Whether to retrieve events in reverse order. Defaults to False.

        Returns:
            list: List of matching events (as dicts)

        Raises:
            ValueError: If limit is outside the range [1, 100].
        """
        if not 1 <= limit <= 100:
            raise ValueError('Limit must be between 1 and 100')

        event_filter = EventFilter(
            query=query,
            include_types=event_types,
            source=source,
            start_date=start_date,
            end_date=end_date,
        )
        matches = self.search_events(
            start_id=start_id, reverse=reverse, filter=event_filter
        )
        # islice caps the materialized result at `limit` entries.
        return list(islice(matches, limit))
|
||||
@@ -1,284 +0,0 @@
|
||||
# IMPORTANT: LEGACY V0 CODE - Deprecated since version 1.0.0, scheduled for removal April 1, 2026
|
||||
# This file is part of the legacy (V0) implementation of OpenHands and will be removed soon as we complete the migration to V1.
|
||||
# OpenHands V1 uses the Software Agent SDK for the agentic core and runs a new application server. Please refer to:
|
||||
# - V1 agentic core (SDK): https://github.com/OpenHands/software-agent-sdk
|
||||
# - V1 application server (in this repo): openhands/app_server/
|
||||
# Unless you are working on deprecation, please avoid extending this legacy file and consult the V1 codepaths above.
|
||||
# Tag: Legacy-V0
|
||||
import copy
|
||||
import time
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Cost(BaseModel):
    """A single recorded cost entry for one completion call."""

    # Name of the model the cost was incurred with
    model: str
    # Cost of the single call (unit defined by the provider's pricing)
    cost: float
    # Wall-clock time (seconds since epoch) when this entry was created
    timestamp: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
class ResponseLatency(BaseModel):
    """Metric tracking the round-trip time per completion call."""

    # Name of the model the call was made against
    model: str
    # Round-trip latency of the call, in seconds
    latency: float
    # Identifier of the LLM response this latency belongs to
    response_id: str
|
||||
|
||||
|
||||
class TokenUsage(BaseModel):
    """Metric tracking detailed token usage per completion call."""

    # Name of the model the usage was recorded for
    model: str = Field(default='')
    # Tokens sent in the prompt for this call
    prompt_tokens: int = Field(default=0)
    # Tokens generated by the model for this call
    completion_tokens: int = Field(default=0)
    # Prompt tokens served from the provider's prompt cache
    cache_read_tokens: int = Field(default=0)
    # Prompt tokens written into the provider's prompt cache
    cache_write_tokens: int = Field(default=0)
    # Context window size reported for this call
    context_window: int = Field(default=0)
    # prompt_tokens + completion_tokens for this single turn
    per_turn_token: int = Field(default=0)
    # Identifier of the LLM response this usage belongs to
    response_id: str = Field(default='')

    def __add__(self, other: 'TokenUsage') -> 'TokenUsage':
        """Add two TokenUsage instances together.

        Note the asymmetries: token counters are summed, but `model` and
        `response_id` are taken from the left operand, `per_turn_token` from
        the right operand (the most recent turn), and `context_window` is the
        max of the two.
        """
        return TokenUsage(
            model=self.model,
            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
            completion_tokens=self.completion_tokens + other.completion_tokens,
            cache_read_tokens=self.cache_read_tokens + other.cache_read_tokens,
            cache_write_tokens=self.cache_write_tokens + other.cache_write_tokens,
            context_window=max(self.context_window, other.context_window),
            per_turn_token=other.per_turn_token,
            response_id=self.response_id,
        )
|
||||
|
||||
|
||||
class Metrics:
|
||||
"""Metrics class can record various metrics during running and evaluation.
|
||||
We track:
|
||||
- accumulated_cost and costs
|
||||
- max_budget_per_task (budget limit)
|
||||
- A list of ResponseLatency
|
||||
- A list of TokenUsage (one per call).
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = 'default') -> None:
|
||||
self._accumulated_cost: float = 0.0
|
||||
self._max_budget_per_task: float | None = None
|
||||
self._costs: list[Cost] = []
|
||||
self._response_latencies: list[ResponseLatency] = []
|
||||
self.model_name = model_name
|
||||
self._token_usages: list[TokenUsage] = []
|
||||
self._accumulated_token_usage: TokenUsage = TokenUsage(
|
||||
model=model_name,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=0,
|
||||
response_id='',
|
||||
)
|
||||
|
||||
@property
|
||||
def accumulated_cost(self) -> float:
|
||||
return self._accumulated_cost
|
||||
|
||||
@accumulated_cost.setter
|
||||
def accumulated_cost(self, value: float) -> None:
|
||||
if value < 0:
|
||||
raise ValueError('Total cost cannot be negative.')
|
||||
self._accumulated_cost = value
|
||||
|
||||
@property
|
||||
def max_budget_per_task(self) -> float | None:
|
||||
return self._max_budget_per_task
|
||||
|
||||
@max_budget_per_task.setter
|
||||
def max_budget_per_task(self, value: float | None) -> None:
|
||||
self._max_budget_per_task = value
|
||||
|
||||
@property
|
||||
def costs(self) -> list[Cost]:
|
||||
return self._costs
|
||||
|
||||
@property
|
||||
def response_latencies(self) -> list[ResponseLatency]:
|
||||
if not hasattr(self, '_response_latencies'):
|
||||
self._response_latencies = []
|
||||
return self._response_latencies
|
||||
|
||||
@response_latencies.setter
|
||||
def response_latencies(self, value: list[ResponseLatency]) -> None:
|
||||
self._response_latencies = value
|
||||
|
||||
@property
|
||||
def token_usages(self) -> list[TokenUsage]:
|
||||
if not hasattr(self, '_token_usages'):
|
||||
self._token_usages = []
|
||||
return self._token_usages
|
||||
|
||||
@token_usages.setter
|
||||
def token_usages(self, value: list[TokenUsage]) -> None:
|
||||
self._token_usages = value
|
||||
|
||||
@property
|
||||
def accumulated_token_usage(self) -> TokenUsage:
|
||||
"""Get the accumulated token usage, initializing it if it doesn't exist."""
|
||||
if not hasattr(self, '_accumulated_token_usage'):
|
||||
self._accumulated_token_usage = TokenUsage(
|
||||
model=self.model_name,
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
context_window=0,
|
||||
response_id='',
|
||||
)
|
||||
return self._accumulated_token_usage
|
||||
|
||||
def add_cost(self, value: float) -> None:
|
||||
if value < 0:
|
||||
raise ValueError('Added cost cannot be negative.')
|
||||
self._accumulated_cost += value
|
||||
self._costs.append(Cost(cost=value, model=self.model_name))
|
||||
|
||||
def add_response_latency(self, value: float, response_id: str) -> None:
|
||||
self._response_latencies.append(
|
||||
ResponseLatency(
|
||||
latency=max(0.0, value), model=self.model_name, response_id=response_id
|
||||
)
|
||||
)
|
||||
|
||||
def add_token_usage(
|
||||
self,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cache_read_tokens: int,
|
||||
cache_write_tokens: int,
|
||||
context_window: int,
|
||||
response_id: str,
|
||||
) -> None:
|
||||
"""Add a single usage record."""
|
||||
# Token each turn for calculating context usage.
|
||||
per_turn_token = prompt_tokens + completion_tokens
|
||||
|
||||
usage = TokenUsage(
|
||||
model=self.model_name,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
context_window=context_window,
|
||||
per_turn_token=per_turn_token,
|
||||
response_id=response_id,
|
||||
)
|
||||
self._token_usages.append(usage)
|
||||
|
||||
# Update accumulated token usage using the __add__ operator
|
||||
self._accumulated_token_usage = self.accumulated_token_usage + TokenUsage(
|
||||
model=self.model_name,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
cache_read_tokens=cache_read_tokens,
|
||||
cache_write_tokens=cache_write_tokens,
|
||||
context_window=context_window,
|
||||
per_turn_token=per_turn_token,
|
||||
response_id='',
|
||||
)
|
||||
|
||||
def merge(self, other: 'Metrics') -> None:
|
||||
"""Merge 'other' metrics into this one."""
|
||||
self._accumulated_cost += other.accumulated_cost
|
||||
|
||||
# Keep the max_budget_per_task from other if it's set and this one isn't
|
||||
if self._max_budget_per_task is None and other.max_budget_per_task is not None:
|
||||
self._max_budget_per_task = other.max_budget_per_task
|
||||
|
||||
self._costs += other._costs
|
||||
# use the property so older picked objects that lack the field won't crash
|
||||
self.token_usages += other.token_usages
|
||||
self.response_latencies += other.response_latencies
|
||||
|
||||
# Merge accumulated token usage using the __add__ operator
|
||||
self._accumulated_token_usage = (
|
||||
self.accumulated_token_usage + other.accumulated_token_usage
|
||||
)
|
||||
|
||||
def get(self) -> dict:
|
||||
"""Return the metrics in a dictionary."""
|
||||
return {
|
||||
'accumulated_cost': self._accumulated_cost,
|
||||
'max_budget_per_task': self._max_budget_per_task,
|
||||
'accumulated_token_usage': self.accumulated_token_usage.model_dump(),
|
||||
'costs': [cost.model_dump() for cost in self._costs],
|
||||
'response_latencies': [
|
||||
latency.model_dump() for latency in self._response_latencies
|
||||
],
|
||||
'token_usages': [usage.model_dump() for usage in self._token_usages],
|
||||
}
|
||||
|
||||
def log(self) -> str:
|
||||
"""Log the metrics."""
|
||||
metrics = self.get()
|
||||
logs = ''
|
||||
for key, value in metrics.items():
|
||||
logs += f'{key}: {value}\n'
|
||||
return logs
|
||||
|
||||
def copy(self) -> 'Metrics':
|
||||
"""Create a deep copy of the Metrics object."""
|
||||
return copy.deepcopy(self)
|
||||
|
||||
def diff(self, baseline: 'Metrics') -> 'Metrics':
|
||||
"""Calculate the difference between current metrics and a baseline.
|
||||
|
||||
This is useful for tracking metrics for specific operations like delegates.
|
||||
|
||||
Args:
|
||||
baseline: A metrics object representing the baseline state
|
||||
|
||||
Returns:
|
||||
A new Metrics object containing only the differences since the baseline
|
||||
"""
|
||||
result = Metrics(self.model_name)
|
||||
|
||||
# Calculate cost difference
|
||||
result._accumulated_cost = self._accumulated_cost - baseline._accumulated_cost
|
||||
|
||||
# Include only costs that were added after the baseline
|
||||
if baseline._costs:
|
||||
last_baseline_timestamp = baseline._costs[-1].timestamp
|
||||
result._costs = [
|
||||
cost for cost in self._costs if cost.timestamp > last_baseline_timestamp
|
||||
]
|
||||
else:
|
||||
result._costs = self._costs.copy()
|
||||
|
||||
# Include only response latencies that were added after the baseline
|
||||
result._response_latencies = self._response_latencies[
|
||||
len(baseline._response_latencies) :
|
||||
]
|
||||
|
||||
# Include only token usages that were added after the baseline
|
||||
result._token_usages = self._token_usages[len(baseline._token_usages) :]
|
||||
|
||||
# Calculate accumulated token usage difference
|
||||
base_usage = baseline.accumulated_token_usage
|
||||
current_usage = self.accumulated_token_usage
|
||||
|
||||
result._accumulated_token_usage = TokenUsage(
|
||||
model=self.model_name,
|
||||
prompt_tokens=current_usage.prompt_tokens - base_usage.prompt_tokens,
|
||||
completion_tokens=current_usage.completion_tokens
|
||||
- base_usage.completion_tokens,
|
||||
cache_read_tokens=current_usage.cache_read_tokens
|
||||
- base_usage.cache_read_tokens,
|
||||
cache_write_tokens=current_usage.cache_write_tokens
|
||||
- base_usage.cache_write_tokens,
|
||||
context_window=current_usage.context_window,
|
||||
per_turn_token=0,
|
||||
response_id='',
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'Metrics({self.get()}'
|
||||
@@ -1,101 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx # type: ignore
|
||||
from fastapi import status
|
||||
|
||||
from openhands.events.event import Event
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.event_store_abc import EventStoreABC
|
||||
from openhands.events.serialization.event import event_from_dict
|
||||
|
||||
|
||||
@dataclass
class NestedEventStore(EventStoreABC):
    """EventStoreABC implementation that proxies reads to a remote event API.

    Events are fetched page-by-page via HTTP GET requests against
    ``{base_url}/events``, authenticated with an optional session API key.
    """

    # Base URL of the remote event service (no trailing slash expected)
    base_url: str
    # Conversation/session id
    sid: str
    user_id: str | None
    # Optional key sent as the X-Session-API-Key header
    session_api_key: str | None = None

    def search_events(
        self,
        start_id: int = 0,
        end_id: int | None = None,
        reverse: bool = False,
        filter: EventFilter | None = None,
        limit: int | None = None,
    ) -> Iterable[Event]:
        """Yield events from the remote store, paginating until exhausted.

        Note: ``end_id`` is not sent to the server; it is honored client-side
        by stopping (inclusively) when the event with that id is seen.
        Filtering via ``filter`` is also applied client-side.
        """
        # Maintain explicit cursors for pagination to avoid accidental mutation.
        start_cursor = start_id
        end_cursor: int | None = None  # Used only for reverse pagination
        while True:
            search_params: dict[str, int | bool] = {
                'start_id': start_cursor,
                'reverse': reverse,
            }
            if reverse and end_cursor is not None:
                # Bound the upper end when scanning backwards to avoid duplicates
                search_params['end_id'] = end_cursor
            if limit is not None:
                # Server pages are capped at 100 entries per request
                search_params['limit'] = min(100, limit)
            search_str = urlencode(search_params)
            url = f'{self.base_url}/events?{search_str}'
            headers: dict[str, str] = {}
            if self.session_api_key:
                headers['X-Session-API-Key'] = self.session_api_key
            response = httpx.get(url, headers=headers)
            if response.status_code == status.HTTP_404_NOT_FOUND:
                # Follow pattern of event store not throwing errors on not found
                return
            result_set = response.json()

            # Track the ids seen on this page to compute the next cursor.
            page_min_id: int | None = None
            forward_next_start = start_cursor
            for result in result_set['events']:
                event = event_from_dict(result)
                if reverse:
                    page_min_id = (
                        event.id if page_min_id is None else min(page_min_id, event.id)
                    )
                else:
                    forward_next_start = max(forward_next_start, event.id + 1)
                if end_id == event.id:
                    # Inclusive stop: yield the boundary event if it passes the
                    # filter, then terminate the generator.
                    if not filter or filter.include(event):
                        yield event
                    return
                if filter and filter.exclude(event):
                    continue
                yield event
                if limit is not None:
                    # limit counts yielded (post-filter) events only
                    limit -= 1
                    if limit <= 0:
                        return

            # Update pagination cursor for next request
            if reverse and page_min_id is not None:
                # Next page should end strictly before the smallest ID we just saw
                end_cursor = page_min_id - 1
            elif not reverse:
                start_cursor = forward_next_start

            if not result_set['has_more']:
                return

    def get_event(self, id: int) -> Event:
        """Fetch a single event by id; raises FileNotFoundError when absent."""
        events = list(self.search_events(start_id=id, limit=1))
        if not events:
            raise FileNotFoundError('no_event')
        return events[0]

    def get_latest_event(self) -> Event:
        """Fetch the most recent event; raises FileNotFoundError when the stream is empty."""
        events = list(self.search_events(reverse=True, limit=1))
        if not events:
            raise FileNotFoundError('no_event')
        return events[0]

    def get_latest_event_id(self) -> int:
        """Return the id of the most recent event (raises if the stream is empty)."""
        event = self.get_latest_event()
        return event.id
|
||||
@@ -1,55 +0,0 @@
|
||||
from openhands.events.observation.agent import (
|
||||
AgentCondensationObservation,
|
||||
AgentStateChangedObservation,
|
||||
AgentThinkObservation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
CmdOutputMetadata,
|
||||
CmdOutputObservation,
|
||||
IPythonRunCellObservation,
|
||||
)
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.empty import (
|
||||
NullObservation,
|
||||
)
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.file_download import FileDownloadObservation
|
||||
from openhands.events.observation.files import (
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
)
|
||||
from openhands.events.observation.loop_recovery import LoopDetectionObservation
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
from openhands.events.observation.success import SuccessObservation
|
||||
from openhands.events.observation.task_tracking import TaskTrackingObservation
|
||||
from openhands.events.recall_type import RecallType
|
||||
|
||||
__all__ = [
|
||||
'Observation',
|
||||
'NullObservation',
|
||||
'AgentThinkObservation',
|
||||
'CmdOutputObservation',
|
||||
'CmdOutputMetadata',
|
||||
'IPythonRunCellObservation',
|
||||
'BrowserOutputObservation',
|
||||
'FileReadObservation',
|
||||
'FileWriteObservation',
|
||||
'FileEditObservation',
|
||||
'ErrorObservation',
|
||||
'AgentStateChangedObservation',
|
||||
'AgentDelegateObservation',
|
||||
'SuccessObservation',
|
||||
'UserRejectObservation',
|
||||
'AgentCondensationObservation',
|
||||
'RecallObservation',
|
||||
'RecallType',
|
||||
'LoopDetectionObservation',
|
||||
'MCPObservation',
|
||||
'FileDownloadObservation',
|
||||
'TaskTrackingObservation',
|
||||
]
|
||||
@@ -1,138 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.recall_type import RecallType
|
||||
|
||||
|
||||
@dataclass
class AgentStateChangedObservation(Observation):
    """Observation emitted when the agent's state changes.

    NOTE(review): the previous docstring ("result from delegating to another
    agent") appears to have been copy-pasted from AgentDelegateObservation;
    this class carries agent state transitions.
    """

    # The new agent state, serialized as a string
    agent_state: str
    # Optional human-readable reason for the transition
    reason: str = ''
    # Discriminator value used by the event serialization layer
    observation: str = ObservationType.AGENT_STATE_CHANGED

    @property
    def message(self) -> str:
        # State changes carry no displayable message content.
        return ''
|
||||
|
||||
|
||||
@dataclass
class AgentCondensationObservation(Observation):
    """The output of a condensation action."""

    # Discriminator value used by the event serialization layer
    observation: str = ObservationType.CONDENSE

    @property
    def message(self) -> str:
        # The condensation text itself is the message.
        return self.content
|
||||
|
||||
|
||||
@dataclass
class AgentThinkObservation(Observation):
    """The output of a think action.

    In practice, this is a no-op, since it will just reply a static message to the agent
    acknowledging that the thought has been logged.
    """

    # Discriminator value used by the event serialization layer
    observation: str = ObservationType.THINK

    @property
    def message(self) -> str:
        # The acknowledgement text is carried in the observation content.
        return self.content
|
||||
|
||||
|
||||
@dataclass
class MicroagentKnowledge:
    """Represents knowledge from a triggered microagent.

    Attributes:
        name: The name of the microagent that was triggered
        trigger: The word that triggered this microagent
        content: The actual content/knowledge from the microagent
    """

    # The name of the microagent that was triggered
    name: str
    # The word that triggered this microagent
    trigger: str
    # The actual content/knowledge from the microagent
    content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecallObservation(Observation):
|
||||
"""The retrieval of content from a microagent or more microagents."""
|
||||
|
||||
recall_type: RecallType
|
||||
observation: str = ObservationType.RECALL
|
||||
|
||||
# workspace context
|
||||
repo_name: str = ''
|
||||
repo_directory: str = ''
|
||||
repo_branch: str = ''
|
||||
repo_instructions: str = ''
|
||||
runtime_hosts: dict[str, int] = field(default_factory=dict)
|
||||
additional_agent_instructions: str = ''
|
||||
date: str = ''
|
||||
custom_secrets_descriptions: dict[str, str] = field(default_factory=dict)
|
||||
conversation_instructions: str = ''
|
||||
working_dir: str = ''
|
||||
|
||||
# knowledge
|
||||
microagent_knowledge: list[MicroagentKnowledge] = field(default_factory=list)
|
||||
"""
|
||||
A list of MicroagentKnowledge objects, each containing information from a triggered microagent.
|
||||
|
||||
Example:
|
||||
[
|
||||
MicroagentKnowledge(
|
||||
name="python_best_practices",
|
||||
trigger="python",
|
||||
content="Always use virtual environments for Python projects."
|
||||
),
|
||||
MicroagentKnowledge(
|
||||
name="git_workflow",
|
||||
trigger="git",
|
||||
content="Create a new branch for each feature or bugfix."
|
||||
)
|
||||
]
|
||||
"""
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return (
|
||||
'Added workspace context'
|
||||
if self.recall_type == RecallType.WORKSPACE_CONTEXT
|
||||
else 'Added microagent knowledge'
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
# Build a string representation
|
||||
fields = []
|
||||
if self.recall_type == RecallType.WORKSPACE_CONTEXT:
|
||||
fields.extend(
|
||||
[
|
||||
f'recall_type={self.recall_type}',
|
||||
f'repo_name={self.repo_name}',
|
||||
f'repo_instructions={self.repo_instructions[:20]}...',
|
||||
f'runtime_hosts={self.runtime_hosts}',
|
||||
f'additional_agent_instructions={self.additional_agent_instructions[:20]}...',
|
||||
f'date={self.date}'
|
||||
f'custom_secrets_descriptions={self.custom_secrets_descriptions}',
|
||||
f'conversation_instructions={self.conversation_instructions[0:20]}...',
|
||||
]
|
||||
)
|
||||
else:
|
||||
fields.extend(
|
||||
[
|
||||
f'recall_type={self.recall_type}',
|
||||
]
|
||||
)
|
||||
if self.microagent_knowledge:
|
||||
fields.extend(
|
||||
[
|
||||
f'microagent_knowledge={", ".join([m.name for m in self.microagent_knowledge])}',
|
||||
]
|
||||
)
|
||||
|
||||
return f'**RecallObservation**\n{", ".join(fields)}'
|
||||
@@ -1,56 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrowserOutputObservation(Observation):
|
||||
"""This data class represents the output of a browser."""
|
||||
|
||||
url: str
|
||||
trigger_by_action: str
|
||||
screenshot: str = field(repr=False, default='') # don't show in repr
|
||||
screenshot_path: str | None = field(default=None) # path to saved screenshot file
|
||||
set_of_marks: str = field(default='', repr=False) # don't show in repr
|
||||
error: bool = False
|
||||
observation: str = ObservationType.BROWSE
|
||||
goal_image_urls: list[str] = field(default_factory=list)
|
||||
# do not include in the memory
|
||||
open_pages_urls: list[str] = field(default_factory=list)
|
||||
active_page_index: int = -1
|
||||
dom_object: dict[str, Any] = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
axtree_object: dict[str, Any] = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
extra_element_properties: dict[str, Any] = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
last_browser_action: str = ''
|
||||
last_browser_action_error: str = ''
|
||||
focused_element_bid: str = ''
|
||||
filter_visible_only: bool = False
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return 'Visited ' + self.url
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = (
|
||||
'**BrowserOutputObservation**\n'
|
||||
f'URL: {self.url}\n'
|
||||
f'Error: {self.error}\n'
|
||||
f'Open pages: {self.open_pages_urls}\n'
|
||||
f'Active page index: {self.active_page_index}\n'
|
||||
f'Last browser action: {self.last_browser_action}\n'
|
||||
f'Last browser action error: {self.last_browser_action_error}\n'
|
||||
f'Focused element bid: {self.focused_element_bid}\n'
|
||||
)
|
||||
if self.screenshot_path:
|
||||
ret += f'Screenshot saved to: {self.screenshot_path}\n'
|
||||
ret += '--- Agent Observation ---\n'
|
||||
ret += self.content
|
||||
return ret
|
||||
@@ -1,231 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
CMD_OUTPUT_PS1_BEGIN = '\n###PS1JSON###\n'
|
||||
CMD_OUTPUT_PS1_END = '\n###PS1END###'
|
||||
CMD_OUTPUT_METADATA_PS1_REGEX = re.compile(
|
||||
f'^{CMD_OUTPUT_PS1_BEGIN.strip()}(.*?){CMD_OUTPUT_PS1_END.strip()}',
|
||||
re.DOTALL | re.MULTILINE,
|
||||
)
|
||||
|
||||
# Default max size for command output content
|
||||
# to prevent too large observations from being saved in the stream
|
||||
# This matches the default max_message_chars in LLMConfig
|
||||
MAX_CMD_OUTPUT_SIZE: int = 30000
|
||||
|
||||
|
||||
class CmdOutputMetadata(BaseModel):
|
||||
"""Additional metadata captured from PS1"""
|
||||
|
||||
exit_code: int = -1
|
||||
pid: int = -1
|
||||
username: str | None = None
|
||||
hostname: str | None = None
|
||||
working_dir: str | None = None
|
||||
py_interpreter_path: str | None = None
|
||||
prefix: str = '' # Prefix to add to command output
|
||||
suffix: str = '' # Suffix to add to command output
|
||||
|
||||
@classmethod
|
||||
def to_ps1_prompt(cls) -> str:
|
||||
"""Convert the required metadata into a PS1 prompt."""
|
||||
prompt = CMD_OUTPUT_PS1_BEGIN
|
||||
json_str = json.dumps(
|
||||
{
|
||||
'pid': '$!',
|
||||
'exit_code': '$?',
|
||||
'username': r'\u',
|
||||
'hostname': r'\h',
|
||||
'working_dir': r'$(pwd)',
|
||||
'py_interpreter_path': r'$(which python 2>/dev/null || echo "")',
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
# Make sure we escape double quotes in the JSON string
|
||||
# So that PS1 will keep them as part of the output
|
||||
prompt += json_str.replace('"', r'\"')
|
||||
prompt += CMD_OUTPUT_PS1_END + '\n' # Ensure there's a newline at the end
|
||||
return prompt
|
||||
|
||||
@classmethod
|
||||
def matches_ps1_metadata(cls, string: str) -> list[re.Match[str]]:
|
||||
matches = []
|
||||
for match in CMD_OUTPUT_METADATA_PS1_REGEX.finditer(string):
|
||||
try:
|
||||
json.loads(match.group(1).strip()) # Try to parse as JSON
|
||||
matches.append(match)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(
|
||||
f'Failed to parse PS1 metadata: {match.group(1)}. Skipping.',
|
||||
exc_info=True,
|
||||
)
|
||||
continue # Skip if not valid JSON
|
||||
return matches
|
||||
|
||||
@classmethod
|
||||
def from_ps1_match(cls, match: re.Match[str]) -> Self:
|
||||
"""Extract the required metadata from a PS1 prompt."""
|
||||
metadata = json.loads(match.group(1))
|
||||
# Create a copy of metadata to avoid modifying the original
|
||||
processed = metadata.copy()
|
||||
# Convert numeric fields
|
||||
if 'pid' in metadata:
|
||||
try:
|
||||
processed['pid'] = int(float(str(metadata['pid'])))
|
||||
except (ValueError, TypeError):
|
||||
processed['pid'] = -1
|
||||
if 'exit_code' in metadata:
|
||||
try:
|
||||
processed['exit_code'] = int(float(str(metadata['exit_code'])))
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
f'Failed to parse exit code: {metadata["exit_code"]}. Setting to -1.'
|
||||
)
|
||||
processed['exit_code'] = -1
|
||||
return cls(**processed)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CmdOutputObservation(Observation):
|
||||
"""This data class represents the output of a command."""
|
||||
|
||||
command: str
|
||||
observation: str = ObservationType.RUN
|
||||
# Additional metadata captured from PS1
|
||||
metadata: CmdOutputMetadata = field(default_factory=CmdOutputMetadata)
|
||||
# Whether the command output should be hidden from the user
|
||||
hidden: bool = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content: str,
|
||||
command: str,
|
||||
observation: str = ObservationType.RUN,
|
||||
metadata: dict[str, Any] | CmdOutputMetadata | None = None,
|
||||
hidden: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Truncate content before passing it to parent
|
||||
# Hidden commands don't go through LLM/event stream, so no need to truncate
|
||||
truncate = not hidden
|
||||
if truncate:
|
||||
content = self._maybe_truncate(content)
|
||||
|
||||
super().__init__(content)
|
||||
|
||||
self.command = command
|
||||
self.observation = observation
|
||||
self.hidden = hidden
|
||||
if isinstance(metadata, dict):
|
||||
self.metadata = CmdOutputMetadata(**metadata)
|
||||
else:
|
||||
self.metadata = metadata or CmdOutputMetadata()
|
||||
|
||||
# Handle legacy attribute
|
||||
if 'exit_code' in kwargs:
|
||||
self.metadata.exit_code = kwargs['exit_code']
|
||||
if 'command_id' in kwargs:
|
||||
self.metadata.pid = kwargs['command_id']
|
||||
|
||||
@staticmethod
|
||||
def _maybe_truncate(content: str, max_size: int = MAX_CMD_OUTPUT_SIZE) -> str:
|
||||
"""Truncate the content if it's too large.
|
||||
|
||||
This helps avoid storing unnecessarily large content in the event stream.
|
||||
|
||||
Args:
|
||||
content: The content to truncate
|
||||
max_size: Maximum size before truncation. Defaults to MAX_CMD_OUTPUT_SIZE.
|
||||
|
||||
Returns:
|
||||
Original content if not too large, or truncated content otherwise
|
||||
"""
|
||||
if len(content) <= max_size:
|
||||
return content
|
||||
|
||||
# Truncate the middle and include a message about it
|
||||
half = max_size // 2
|
||||
original_length = len(content)
|
||||
truncated = (
|
||||
content[:half]
|
||||
+ '\n[... Observation truncated due to length ...]\n'
|
||||
+ content[-half:]
|
||||
)
|
||||
logger.debug(
|
||||
f'Truncated large command output: {original_length} -> {len(truncated)} chars'
|
||||
)
|
||||
return truncated
|
||||
|
||||
@property
|
||||
def command_id(self) -> int:
|
||||
return self.metadata.pid
|
||||
|
||||
@property
|
||||
def exit_code(self) -> int:
|
||||
return self.metadata.exit_code
|
||||
|
||||
@property
|
||||
def error(self) -> bool:
|
||||
return self.exit_code != 0
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Command `{self.command}` executed with exit code {self.exit_code}.'
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return not self.error
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'**CmdOutputObservation (source={self.source}, exit code={self.exit_code}, '
|
||||
f'metadata={json.dumps(self.metadata.model_dump(), indent=2)})**\n'
|
||||
'--BEGIN AGENT OBSERVATION--\n'
|
||||
f'{self.to_agent_observation()}\n'
|
||||
'--END AGENT OBSERVATION--'
|
||||
)
|
||||
|
||||
def to_agent_observation(self) -> str:
|
||||
ret = f'{self.metadata.prefix}{self.content}{self.metadata.suffix}'
|
||||
if self.metadata.working_dir:
|
||||
ret += f'\n[Current working directory: {self.metadata.working_dir}]'
|
||||
if self.metadata.py_interpreter_path:
|
||||
ret += f'\n[Python interpreter: {self.metadata.py_interpreter_path}]'
|
||||
if self.metadata.exit_code != -1:
|
||||
ret += f'\n[Command finished with exit code {self.metadata.exit_code}]'
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPythonRunCellObservation(Observation):
|
||||
"""This data class represents the output of a IPythonRunCellAction."""
|
||||
|
||||
code: str
|
||||
observation: str = ObservationType.RUN_IPYTHON
|
||||
image_urls: list[str] | None = None
|
||||
|
||||
@property
|
||||
def error(self) -> bool:
|
||||
return False # IPython cells do not return exit codes
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return 'Code executed in IPython cell.'
|
||||
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return True # IPython cells are always considered successful
|
||||
|
||||
def __str__(self) -> str:
|
||||
result = f'**IPythonRunCellObservation**\n{self.content}'
|
||||
if self.image_urls:
|
||||
result += f'\nImages: {len(self.image_urls)}'
|
||||
return result
|
||||
@@ -1,22 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentDelegateObservation(Observation):
|
||||
"""This data class represents the result from delegating to another agent.
|
||||
|
||||
Attributes:
|
||||
content (str): The content of the observation.
|
||||
outputs (dict): The outputs of the delegated agent.
|
||||
observation (str): The type of observation.
|
||||
"""
|
||||
|
||||
outputs: dict
|
||||
observation: str = ObservationType.DELEGATE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return ''
|
||||
@@ -1,17 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class NullObservation(Observation):
|
||||
"""This data class represents a null observation.
|
||||
This is used when the produced action is NOT executable.
|
||||
"""
|
||||
|
||||
observation: str = ObservationType.NULL
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return 'No observation'
|
||||
@@ -1,23 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class ErrorObservation(Observation):
|
||||
"""This data class represents an error encountered by the agent.
|
||||
|
||||
This is the type of error that LLM can recover from.
|
||||
E.g., Linter error after editing a file.
|
||||
"""
|
||||
|
||||
observation: str = ObservationType.ERROR
|
||||
error_id: str = ''
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'**ErrorObservation**\n{self.content}'
|
||||
@@ -1,21 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileDownloadObservation(Observation):
|
||||
file_path: str
|
||||
observation: str = ObservationType.DOWNLOAD
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Downloaded the file at location: {self.file_path}'
|
||||
|
||||
def __str__(self) -> str:
|
||||
ret = (
|
||||
'**FileDownloadObservation**\n'
|
||||
f'Location of downloaded file: {self.file_path}\n'
|
||||
)
|
||||
return ret
|
||||
@@ -1,195 +0,0 @@
|
||||
"""File-related observation classes for tracking file operations."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.event import FileEditSource, FileReadSource
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileReadObservation(Observation):
|
||||
"""This data class represents the content of a file."""
|
||||
|
||||
path: str
|
||||
observation: str = ObservationType.READ
|
||||
impl_source: FileReadSource = FileReadSource.DEFAULT
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
"""Get a human-readable message describing the file read operation."""
|
||||
return f'I read the file {self.path}.'
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Get a string representation of the file read observation."""
|
||||
return f'[Read from {self.path} is successful.]\n{self.content}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileWriteObservation(Observation):
|
||||
"""This data class represents a file write operation."""
|
||||
|
||||
path: str
|
||||
observation: str = ObservationType.WRITE
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
"""Get a human-readable message describing the file write operation."""
|
||||
return f'I wrote to the file {self.path}.'
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Get a string representation of the file write observation."""
|
||||
return f'[Write to {self.path} is successful.]\n{self.content}'
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileEditObservation(Observation):
|
||||
"""This data class represents a file edit operation.
|
||||
|
||||
The observation includes both the old and new content of the file, and can
|
||||
generate a diff visualization showing the changes. The diff is computed lazily
|
||||
and cached to improve performance.
|
||||
|
||||
The .content property can either be:
|
||||
- Git diff in LLM-based editing mode
|
||||
- the rendered message sent to the LLM in OH_ACI mode (e.g., "The file /path/to/file.txt is created with the provided content.")
|
||||
"""
|
||||
|
||||
path: str = ''
|
||||
prev_exist: bool = False
|
||||
old_content: str | None = None
|
||||
new_content: str | None = None
|
||||
observation: str = ObservationType.EDIT
|
||||
impl_source: FileEditSource = FileEditSource.LLM_BASED_EDIT
|
||||
diff: str | None = (
|
||||
None # The raw diff between old and new content, used in OH_ACI mode
|
||||
)
|
||||
_diff_cache: str | None = (
|
||||
None # Cache for the diff visualization, used in LLM-based editing mode
|
||||
)
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
"""Get a human-readable message describing the file edit operation."""
|
||||
return f'I edited the file {self.path}.'
|
||||
|
||||
def get_edit_groups(self, n_context_lines: int = 2) -> list[dict[str, list[str]]]:
|
||||
"""Get the edit groups showing changes between old and new content.
|
||||
|
||||
Args:
|
||||
n_context_lines: Number of context lines to show around each change.
|
||||
|
||||
Returns:
|
||||
A list of edit groups, where each group contains before/after edits.
|
||||
"""
|
||||
if self.old_content is None or self.new_content is None:
|
||||
return []
|
||||
old_lines = self.old_content.split('\n')
|
||||
new_lines = self.new_content.split('\n')
|
||||
# Borrowed from difflib.unified_diff to directly parse into structured format
|
||||
edit_groups: list[dict] = []
|
||||
for group in SequenceMatcher(None, old_lines, new_lines).get_grouped_opcodes(
|
||||
n_context_lines
|
||||
):
|
||||
# Take the max line number in the group
|
||||
_indent_pad_size = len(str(group[-1][3])) + 1 # +1 for "*" prefix
|
||||
cur_group: dict[str, list[str]] = {
|
||||
'before_edits': [],
|
||||
'after_edits': [],
|
||||
}
|
||||
for tag, i1, i2, j1, j2 in group:
|
||||
if tag == 'equal':
|
||||
for idx, line in enumerate(old_lines[i1:i2]):
|
||||
line_num = i1 + idx + 1
|
||||
cur_group['before_edits'].append(
|
||||
f'{line_num:>{_indent_pad_size}}|{line}'
|
||||
)
|
||||
for idx, line in enumerate(new_lines[j1:j2]):
|
||||
line_num = j1 + idx + 1
|
||||
cur_group['after_edits'].append(
|
||||
f'{line_num:>{_indent_pad_size}}|{line}'
|
||||
)
|
||||
continue
|
||||
if tag in {'replace', 'delete'}:
|
||||
for idx, line in enumerate(old_lines[i1:i2]):
|
||||
line_num = i1 + idx + 1
|
||||
cur_group['before_edits'].append(
|
||||
f'-{line_num:>{_indent_pad_size - 1}}|{line}'
|
||||
)
|
||||
if tag in {'replace', 'insert'}:
|
||||
for idx, line in enumerate(new_lines[j1:j2]):
|
||||
line_num = j1 + idx + 1
|
||||
cur_group['after_edits'].append(
|
||||
f'+{line_num:>{_indent_pad_size - 1}}|{line}'
|
||||
)
|
||||
edit_groups.append(cur_group)
|
||||
return edit_groups
|
||||
|
||||
def visualize_diff(
|
||||
self,
|
||||
n_context_lines: int = 2,
|
||||
change_applied: bool = True,
|
||||
) -> str:
|
||||
"""Visualize the diff of the file edit. Used in the LLM-based editing mode.
|
||||
|
||||
Instead of showing the diff line by line, this function shows each hunk
|
||||
of changes as a separate entity.
|
||||
|
||||
Args:
|
||||
n_context_lines: Number of context lines to show before/after changes.
|
||||
change_applied: Whether changes are applied. If false, shows as
|
||||
attempted edit.
|
||||
|
||||
Returns:
|
||||
A string containing the formatted diff visualization.
|
||||
"""
|
||||
# Use cached diff if available
|
||||
if self._diff_cache is not None:
|
||||
return self._diff_cache
|
||||
|
||||
# Check if there are any changes
|
||||
if change_applied and self.old_content == self.new_content:
|
||||
msg = '(no changes detected. Please make sure your edits change '
|
||||
msg += 'the content of the existing file.)\n'
|
||||
self._diff_cache = msg
|
||||
return self._diff_cache
|
||||
|
||||
edit_groups = self.get_edit_groups(n_context_lines=n_context_lines)
|
||||
|
||||
if change_applied:
|
||||
header = f'[Existing file {self.path} is edited with '
|
||||
header += f'{len(edit_groups)} changes.]'
|
||||
else:
|
||||
header = f"[Changes are NOT applied to {self.path} - Here's how "
|
||||
header += 'the file looks like if changes are applied.]'
|
||||
result = [header]
|
||||
|
||||
op_type = 'edit' if change_applied else 'ATTEMPTED edit'
|
||||
for i, cur_edit_group in enumerate(edit_groups):
|
||||
if i != 0:
|
||||
result.append('-------------------------')
|
||||
result.append(f'[begin of {op_type} {i + 1} / {len(edit_groups)}]')
|
||||
result.append(f'(content before {op_type})')
|
||||
result.extend(cur_edit_group['before_edits'])
|
||||
result.append(f'(content after {op_type})')
|
||||
result.extend(cur_edit_group['after_edits'])
|
||||
result.append(f'[end of {op_type} {i + 1} / {len(edit_groups)}]')
|
||||
|
||||
# Cache the result
|
||||
self._diff_cache = '\n'.join(result)
|
||||
return self._diff_cache
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Get a string representation of the file edit observation."""
|
||||
if self.impl_source == FileEditSource.OH_ACI:
|
||||
return self.content
|
||||
|
||||
if not self.prev_exist:
|
||||
assert self.old_content == '', (
|
||||
'old_content should be empty if the file is new (prev_exist=False).'
|
||||
)
|
||||
return f'[New file {self.path} is created with the provided content.]\n'
|
||||
|
||||
# Use cached diff if available, otherwise compute it
|
||||
return self.visualize_diff().rstrip() + '\n'
|
||||
@@ -1,18 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoopDetectionObservation(Observation):
|
||||
"""Observation for loop recovery state changes.
|
||||
|
||||
This observation is used to notify the UI layer when agent
|
||||
is in loop recovery mode.
|
||||
|
||||
This observation is CLI-specific and should only be displayed
|
||||
in CLI/TUI mode, not in GUI or other UI modes.
|
||||
"""
|
||||
|
||||
observation: str = ObservationType.LOOP_DETECTION
|
||||
@@ -1,20 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class MCPObservation(Observation):
|
||||
"""This data class represents the result of a MCP Server operation."""
|
||||
|
||||
observation: str = ObservationType.MCP
|
||||
name: str = '' # The name of the MCP tool that was called
|
||||
arguments: dict[str, Any] = field(
|
||||
default_factory=dict
|
||||
) # The arguments passed to the MCP tool
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
@@ -1,15 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.events.event import Event
|
||||
|
||||
|
||||
@dataclass
|
||||
class Observation(Event):
|
||||
"""Base class for observations from the environment.
|
||||
|
||||
Attributes:
|
||||
content: The content of the observation. For large observations,
|
||||
this might be truncated when stored.
|
||||
"""
|
||||
|
||||
content: str
|
||||
@@ -1,15 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserRejectObservation(Observation):
|
||||
"""This data class represents the result of a rejected action."""
|
||||
|
||||
observation: str = ObservationType.USER_REJECTED
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
@@ -1,15 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class SuccessObservation(Observation):
|
||||
"""This data class represents the result of a successful action."""
|
||||
|
||||
observation: str = ObservationType.SUCCESS
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
@@ -1,18 +0,0 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from openhands.core.schema import ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskTrackingObservation(Observation):
|
||||
"""This data class represents the result of a task tracking operation."""
|
||||
|
||||
observation: str = ObservationType.TASK_TRACKING
|
||||
command: str = ''
|
||||
task_list: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return self.content
|
||||
@@ -1,11 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class RecallType(str, Enum):
|
||||
"""The type of information that can be retrieved from microagents."""
|
||||
|
||||
WORKSPACE_CONTEXT = 'workspace_context'
|
||||
"""Workspace context (repo instructions, runtime, etc.)"""
|
||||
|
||||
KNOWLEDGE = 'knowledge'
|
||||
"""A knowledge microagent."""
|
||||
@@ -1,19 +0,0 @@
|
||||
from openhands.events.serialization.action import (
|
||||
action_from_dict,
|
||||
)
|
||||
from openhands.events.serialization.event import (
|
||||
event_from_dict,
|
||||
event_to_dict,
|
||||
event_to_trajectory,
|
||||
)
|
||||
from openhands.events.serialization.observation import (
|
||||
observation_from_dict,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'action_from_dict',
|
||||
'event_from_dict',
|
||||
'event_to_dict',
|
||||
'event_to_trajectory',
|
||||
'observation_from_dict',
|
||||
]
|
||||
@@ -1,156 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from openhands.core.exceptions import LLMMalformedActionError
|
||||
from openhands.events.action.action import Action, ActionSecurityRisk
|
||||
from openhands.events.action.agent import (
|
||||
AgentDelegateAction,
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
AgentThinkAction,
|
||||
ChangeAgentStateAction,
|
||||
CondensationAction,
|
||||
CondensationRequestAction,
|
||||
LoopRecoveryAction,
|
||||
RecallAction,
|
||||
TaskTrackingAction,
|
||||
)
|
||||
from openhands.events.action.browse import BrowseInteractiveAction, BrowseURLAction
|
||||
from openhands.events.action.commands import (
|
||||
CmdRunAction,
|
||||
IPythonRunCellAction,
|
||||
)
|
||||
from openhands.events.action.empty import NullAction
|
||||
from openhands.events.action.files import (
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.action.message import MessageAction, SystemMessageAction
|
||||
|
||||
actions = (
|
||||
NullAction,
|
||||
CmdRunAction,
|
||||
IPythonRunCellAction,
|
||||
BrowseURLAction,
|
||||
BrowseInteractiveAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
FileEditAction,
|
||||
AgentThinkAction,
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
AgentDelegateAction,
|
||||
RecallAction,
|
||||
ChangeAgentStateAction,
|
||||
MessageAction,
|
||||
SystemMessageAction,
|
||||
CondensationAction,
|
||||
CondensationRequestAction,
|
||||
MCPAction,
|
||||
TaskTrackingAction,
|
||||
LoopRecoveryAction,
|
||||
)
|
||||
|
||||
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def handle_action_deprecated_args(args: dict[str, Any]) -> dict[str, Any]:
|
||||
# keep_prompt has been deprecated in https://github.com/OpenHands/OpenHands/pull/4881
|
||||
if 'keep_prompt' in args:
|
||||
args.pop('keep_prompt')
|
||||
|
||||
# task_completed has been deprecated - remove it from args to maintain backward compatibility
|
||||
if 'task_completed' in args:
|
||||
args.pop('task_completed')
|
||||
|
||||
# Handle translated_ipython_code deprecation
|
||||
if 'translated_ipython_code' in args:
|
||||
code = args.pop('translated_ipython_code')
|
||||
|
||||
# Check if it's a file_editor call using a prefix check for efficiency
|
||||
file_editor_prefix = 'print(file_editor(**'
|
||||
if (
|
||||
code is not None
|
||||
and code.startswith(file_editor_prefix)
|
||||
and code.endswith('))')
|
||||
):
|
||||
try:
|
||||
# Extract and evaluate the dictionary string
|
||||
import ast
|
||||
|
||||
# Extract the dictionary string between the prefix and the closing parentheses
|
||||
dict_str = code[len(file_editor_prefix) : -2] # Remove prefix and '))'
|
||||
file_args = ast.literal_eval(dict_str)
|
||||
|
||||
# Update args with the extracted file editor arguments
|
||||
args.update(file_args)
|
||||
except (ValueError, SyntaxError):
|
||||
# If parsing fails, just remove the translated_ipython_code
|
||||
pass
|
||||
|
||||
if args.get('command') == 'view':
|
||||
args.pop(
|
||||
'command'
|
||||
) # "view" will be translated to FileReadAction which doesn't have a command argument
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def action_from_dict(action: dict) -> Action:
    """Deserialize a raw event dict into a concrete ``Action`` instance.

    Looks up the class via ``ACTION_TYPE_TO_CLASS`` using the ``'action'`` key,
    translates legacy argument names, and instantiates the class with the
    remaining ``'args'``. Raises ``LLMMalformedActionError`` for any shape or
    argument problem so callers have a single exception type to catch.
    """
    if not isinstance(action, dict):
        raise LLMMalformedActionError('action must be a dictionary')
    # Copy so legacy-key rewrites below don't mutate the caller's dict.
    action = action.copy()
    if 'action' not in action:
        raise LLMMalformedActionError(f"'action' key is not found in {action=}")
    if not isinstance(action['action'], str):
        raise LLMMalformedActionError(
            f"'{action['action']=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}"
        )
    action_class = ACTION_TYPE_TO_CLASS.get(action['action'])
    if action_class is None:
        raise LLMMalformedActionError(
            f"'{action['action']=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}"
        )
    args = action.get('args', {})
    # Remove timestamp from args if present; it is re-applied to the private
    # attribute after construction since constructors don't accept it.
    timestamp = args.pop('timestamp', None)

    # compatibility for older event streams
    # is_confirmed has been renamed to confirmation_state
    is_confirmed = args.pop('is_confirmed', None)
    if is_confirmed is not None:
        args['confirmation_state'] = is_confirmed

    # images_urls has been renamed to image_urls
    if 'images_urls' in args:
        args['image_urls'] = args.pop('images_urls')

    # Handle security_risk deserialization
    if 'security_risk' in args and args['security_risk'] is not None:
        try:
            # Convert numeric value (int) back to enum
            args['security_risk'] = ActionSecurityRisk(args['security_risk'])
        except (ValueError, TypeError):
            # If conversion fails, remove the invalid value
            args.pop('security_risk')

    # handle deprecated args
    args = handle_action_deprecated_args(args)

    try:
        decoded_action = action_class(**args)
        if 'timeout' in action:
            blocking = args.get('blocking', False)
            decoded_action.set_hard_timeout(action['timeout'], blocking=blocking)

        # Set timestamp if it was provided
        if timestamp:
            decoded_action._timestamp = timestamp

    except TypeError as e:
        # Wrong/unexpected keyword arguments for the target class.
        raise LLMMalformedActionError(
            f'action={action} has the wrong arguments: {str(e)}'
        )
    assert isinstance(decoded_action, Action)
    return decoded_action
|
||||
@@ -1,178 +0,0 @@
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.events import Event, EventSource
|
||||
from openhands.events.metrics import Cost, Metrics, ResponseLatency, TokenUsage
|
||||
from openhands.events.serialization.action import action_from_dict
|
||||
from openhands.events.serialization.observation import observation_from_dict
|
||||
from openhands.events.serialization.utils import remove_fields
|
||||
from openhands.events.tool import ToolCallMetadata
|
||||
|
||||
# TODO: move `content` into `extras`

# Event attributes serialized at the top level of the event dict (everything
# else lands under 'args' for actions or 'extras' for observations).
TOP_KEYS = [
    'id',
    'timestamp',
    'source',
    'message',
    'cause',
    'action',
    'observation',
    'tool_call_metadata',
    'llm_metrics',
]

# Subset of TOP_KEYS stored on events as private attributes (`_id`, `_source`,
# ...); `event_from_dict` writes them back via setattr('_' + key, ...).
UNDERSCORE_KEYS = [
    'id',
    'timestamp',
    'source',
    'cause',
    'tool_call_metadata',
    'llm_metrics',
]

# Bulky browser-observation extras stripped from trajectory exports.
DELETE_FROM_TRAJECTORY_EXTRAS = {
    'dom_object',
    'axtree_object',
    'active_page_index',
    'last_browser_action',
    'last_browser_action_error',
    'focused_element_bid',
    'extra_element_properties',
}

# Same as above, additionally dropping screenshot payloads.
DELETE_FROM_TRAJECTORY_EXTRAS_AND_SCREENSHOTS = DELETE_FROM_TRAJECTORY_EXTRAS | {
    'screenshot',
    'set_of_marks',
}
|
||||
|
||||
|
||||
def event_from_dict(data: dict[str, Any]) -> 'Event':
    """Deserialize an event dict into an Action or Observation instance.

    Dispatches on whether the dict carries an ``'action'`` or ``'observation'``
    key, then restores the private top-level attributes listed in
    ``UNDERSCORE_KEYS`` (id, timestamp, source, cause, tool_call_metadata,
    llm_metrics) onto the instance.

    Raises:
        ValueError: if the dict has neither 'action' nor 'observation'.
    """
    evt: Event
    if 'action' in data:
        evt = action_from_dict(data)
    elif 'observation' in data:
        evt = observation_from_dict(data)
    else:
        raise ValueError(f'Unknown event type: {data}')
    for key in UNDERSCORE_KEYS:
        if key in data:
            value = data[key]
            # Stored form is the ISO string, not a datetime object.
            if key == 'timestamp' and isinstance(value, datetime):
                value = value.isoformat()
            if key == 'source':
                value = EventSource(value)
            if key == 'tool_call_metadata':
                value = ToolCallMetadata(**value)
            if key == 'llm_metrics':
                # Rebuild a Metrics object field-by-field from its dict form.
                metrics = Metrics()
                if isinstance(value, dict):
                    metrics.accumulated_cost = value.get('accumulated_cost', 0.0)
                    # Set max_budget_per_task if available
                    metrics.max_budget_per_task = value.get('max_budget_per_task')
                    for cost in value.get('costs', []):
                        metrics._costs.append(Cost(**cost))
                    metrics.response_latencies = [
                        ResponseLatency(**latency)
                        for latency in value.get('response_latencies', [])
                    ]
                    metrics.token_usages = [
                        TokenUsage(**usage) for usage in value.get('token_usages', [])
                    ]
                    # Set accumulated token usage if available
                    if 'accumulated_token_usage' in value:
                        metrics._accumulated_token_usage = TokenUsage(
                            **value.get('accumulated_token_usage', {})
                        )
                value = metrics
            # These live as private attributes on the event (see UNDERSCORE_KEYS).
            setattr(evt, '_' + key, value)
    return evt
|
||||
|
||||
|
||||
def _convert_pydantic_to_dict(obj: BaseModel | dict) -> dict:
    """Return ``obj`` as a plain dict, dumping pydantic models when needed."""
    return obj.model_dump() if isinstance(obj, BaseModel) else obj
|
||||
|
||||
|
||||
def event_to_dict(event: 'Event') -> dict:
    """Serialize an Action or Observation event to its dict wire format.

    Top-level keys come from ``TOP_KEYS`` (read from the public attribute or
    its ``_``-prefixed private twin); everything else goes under ``'args'``
    for actions or ``'content'``/``'extras'`` for observations.

    Raises:
        ValueError: if the event exposes neither 'action' nor 'observation'.
    """
    props = asdict(event)
    d = {}
    for key in TOP_KEYS:
        # Prefer the public attribute, fall back to the private `_key` form.
        if hasattr(event, key) and getattr(event, key) is not None:
            d[key] = getattr(event, key)
        elif hasattr(event, f'_{key}') and getattr(event, f'_{key}') is not None:
            d[key] = getattr(event, f'_{key}')
        # -1 is the "unassigned id" sentinel; omit it from serialized output.
        if key == 'id' and d.get('id') == -1:
            d.pop('id', None)
        if key == 'timestamp' and 'timestamp' in d:
            if isinstance(d['timestamp'], datetime):
                d['timestamp'] = d['timestamp'].isoformat()
        if key == 'source' and 'source' in d:
            d['source'] = d['source'].value
        if key == 'recall_type' and 'recall_type' in d:
            d['recall_type'] = d['recall_type'].value
        if key == 'tool_call_metadata' and 'tool_call_metadata' in d:
            d['tool_call_metadata'] = d['tool_call_metadata'].model_dump()
        if key == 'llm_metrics' and 'llm_metrics' in d:
            d['llm_metrics'] = d['llm_metrics'].get()
        # Whatever was promoted to the top level must not also appear in props.
        props.pop(key, None)

    if 'security_risk' in props and props['security_risk'] is None:
        props.pop('security_risk')

    # Remove task_completed from serialization when it's None (backward compatibility)
    if 'task_completed' in props and props['task_completed'] is None:
        props.pop('task_completed')
    if 'action' in d:
        # Handle security_risk for actions - include it in args
        if 'security_risk' in props:
            props['security_risk'] = props['security_risk'].value
        d['args'] = props
        if event.timeout is not None:
            d['timeout'] = event.timeout
    elif 'observation' in d:
        d['content'] = props.pop('content', '')

        # props is a dict whose values can include a complex object like an instance of a BaseModel subclass
        # such as CmdOutputMetadata
        # we serialize it along with the rest
        # we also handle the Enum conversion for RecallObservation
        d['extras'] = {
            k: (v.value if isinstance(v, Enum) else _convert_pydantic_to_dict(v))
            for k, v in props.items()
        }
        # Include success field for CmdOutputObservation
        if hasattr(event, 'success'):
            d['success'] = event.success
    else:
        raise ValueError(f'Event must be either action or observation. has: {event}')
    return d
|
||||
|
||||
|
||||
def event_to_trajectory(event: 'Event', include_screenshots: bool = False) -> dict:
    """Serialize *event* for trajectory export, stripping bulky browser extras.

    Screenshot payloads are additionally dropped unless ``include_screenshots``
    is True.
    """
    serialized = event_to_dict(event)
    if 'extras' in serialized:
        if include_screenshots:
            fields_to_drop = DELETE_FROM_TRAJECTORY_EXTRAS
        else:
            fields_to_drop = DELETE_FROM_TRAJECTORY_EXTRAS_AND_SCREENSHOTS
        remove_fields(serialized['extras'], fields_to_drop)
    return serialized
|
||||
|
||||
|
||||
def truncate_content(content: str, max_chars: int | None = None) -> str:
|
||||
"""Truncate the middle of the observation content if it is too long."""
|
||||
if max_chars is None or len(content) <= max_chars or max_chars < 0:
|
||||
return content
|
||||
|
||||
# truncate the middle and include a message to the LLM about it
|
||||
half = max_chars // 2
|
||||
return (
|
||||
content[:half]
|
||||
+ '\n[... Observation truncated due to length ...]\n'
|
||||
+ content[-half:]
|
||||
)
|
||||
@@ -1,142 +0,0 @@
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from openhands.events.observation.agent import (
|
||||
AgentCondensationObservation,
|
||||
AgentStateChangedObservation,
|
||||
AgentThinkObservation,
|
||||
MicroagentKnowledge,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.browse import BrowserOutputObservation
|
||||
from openhands.events.observation.commands import (
|
||||
CmdOutputMetadata,
|
||||
CmdOutputObservation,
|
||||
IPythonRunCellObservation,
|
||||
)
|
||||
from openhands.events.observation.delegate import AgentDelegateObservation
|
||||
from openhands.events.observation.empty import (
|
||||
NullObservation,
|
||||
)
|
||||
from openhands.events.observation.error import ErrorObservation
|
||||
from openhands.events.observation.file_download import FileDownloadObservation
|
||||
from openhands.events.observation.files import (
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
)
|
||||
from openhands.events.observation.loop_recovery import LoopDetectionObservation
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
from openhands.events.observation.observation import Observation
|
||||
from openhands.events.observation.reject import UserRejectObservation
|
||||
from openhands.events.observation.success import SuccessObservation
|
||||
from openhands.events.observation.task_tracking import TaskTrackingObservation
|
||||
from openhands.events.recall_type import RecallType
|
||||
|
||||
# All concrete Observation subclasses known to the serializer.
observations = (
    NullObservation,
    CmdOutputObservation,
    IPythonRunCellObservation,
    BrowserOutputObservation,
    FileReadObservation,
    FileWriteObservation,
    FileEditObservation,
    AgentDelegateObservation,
    SuccessObservation,
    ErrorObservation,
    AgentStateChangedObservation,
    UserRejectObservation,
    AgentCondensationObservation,
    AgentThinkObservation,
    RecallObservation,
    MCPObservation,
    FileDownloadObservation,
    TaskTrackingObservation,
    LoopDetectionObservation,
)

# Lookup table from each class's `observation` string identifier to the class,
# used by `observation_from_dict` to pick the class to instantiate.
OBSERVATION_TYPE_TO_CLASS = {
    observation_class.observation: observation_class  # type: ignore[attr-defined]
    for observation_class in observations
}
|
||||
|
||||
|
||||
def _update_cmd_output_metadata(
    metadata: dict[str, Any] | CmdOutputMetadata | None, **kwargs: Any
) -> dict[str, Any] | CmdOutputMetadata:
    """Merge ``kwargs`` into command-output metadata and return it.

    ``None`` yields a fresh ``CmdOutputMetadata`` built from ``kwargs``; a dict
    is updated in place; a ``CmdOutputMetadata`` instance has the corresponding
    attributes assigned.
    """
    if metadata is None:
        return CmdOutputMetadata(**kwargs)

    if isinstance(metadata, dict):
        metadata.update(kwargs)
    elif isinstance(metadata, CmdOutputMetadata):
        for attr_name, attr_value in kwargs.items():
            setattr(metadata, attr_name, attr_value)
    return metadata
|
||||
|
||||
|
||||
def handle_observation_deprecated_extras(extras: dict) -> dict:
    """Rewrite deprecated keys in an observation's ``extras`` dict (in place).

    ``exit_code`` and ``command_id`` (deprecated in
    https://github.com/OpenHands/OpenHands/pull/4881) are folded into the
    ``metadata`` entry; ``formatted_output_and_error`` (deprecated in
    https://github.com/OpenHands/OpenHands/pull/6671) is dropped. The same
    dict is returned.
    """
    for legacy_key, metadata_key in (('exit_code', 'exit_code'), ('command_id', 'pid')):
        if legacy_key in extras:
            extras['metadata'] = _update_cmd_output_metadata(
                extras.get('metadata'), **{metadata_key: extras.pop(legacy_key)}
            )

    extras.pop('formatted_output_and_error', None)
    return extras
|
||||
|
||||
|
||||
def observation_from_dict(observation: dict) -> Observation:
    """Deserialize a raw event dict into a concrete ``Observation`` instance.

    Looks up the class via ``OBSERVATION_TYPE_TO_CLASS``, normalizes deprecated
    extras, rehydrates nested value objects (``CmdOutputMetadata``,
    ``RecallType``, ``MicroagentKnowledge``), then instantiates the class with
    ``content`` plus the remaining extras as keyword arguments.

    Raises:
        KeyError: if the 'observation' key is missing or names an unknown type.
    """
    observation = observation.copy()
    if 'observation' not in observation:
        raise KeyError(f"'observation' key is not found in {observation=}")
    observation_class = OBSERVATION_TYPE_TO_CLASS.get(observation['observation'])
    if observation_class is None:
        raise KeyError(
            f"'{observation['observation']=}' is not defined. Available observations: {OBSERVATION_TYPE_TO_CLASS.keys()}"
        )
    observation.pop('observation')
    observation.pop('message', None)
    content = observation.pop('content', '')
    # Deep copy so mutations below never leak back into the caller's dict.
    extras = copy.deepcopy(observation.pop('extras', {}))

    extras = handle_observation_deprecated_extras(extras)

    # convert metadata to CmdOutputMetadata if it is a dict
    if observation_class is CmdOutputObservation:
        if 'metadata' in extras and isinstance(extras['metadata'], dict):
            extras['metadata'] = CmdOutputMetadata(**extras['metadata'])
        elif 'metadata' in extras and isinstance(extras['metadata'], CmdOutputMetadata):
            pass
        else:
            extras['metadata'] = CmdOutputMetadata()

    if observation_class is RecallObservation:
        # handle the Enum conversion
        if 'recall_type' in extras:
            extras['recall_type'] = RecallType(extras['recall_type'])

        # convert dicts in microagent_knowledge to MicroagentKnowledge objects
        if 'microagent_knowledge' in extras and isinstance(
            extras['microagent_knowledge'], list
        ):
            extras['microagent_knowledge'] = [
                MicroagentKnowledge(**item) if isinstance(item, dict) else item
                for item in extras['microagent_knowledge']
            ]

    obs = observation_class(content=content, **extras)
    assert isinstance(obs, Observation)
    return obs
|
||||
@@ -1,20 +0,0 @@
|
||||
def remove_fields(obj: dict | list | tuple, fields: set[str]) -> None:
|
||||
"""Remove fields from an object.
|
||||
|
||||
Parameters:
|
||||
- obj: The dictionary, or list of dictionaries to remove fields from
|
||||
- fields (set[str]): A set of field names to remove from the object
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
for field in fields:
|
||||
if field in obj:
|
||||
del obj[field]
|
||||
for _, value in obj.items():
|
||||
remove_fields(value, fields)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
remove_fields(item, fields)
|
||||
if hasattr(obj, '__dataclass_fields__'):
|
||||
raise ValueError(
|
||||
'Object must not contain dataclass, consider converting to dict first'
|
||||
)
|
||||
@@ -1,291 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import queue
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import Event, EventSource
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.events.serialization.event import event_from_dict, event_to_dict
|
||||
from openhands.storage import FileStore
|
||||
from openhands.storage.locations import (
|
||||
get_conversation_dir,
|
||||
)
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
|
||||
|
||||
class EventStreamSubscriber(str, Enum):
    """Well-known subscriber identifiers for EventStream callbacks."""

    AGENT_CONTROLLER = 'agent_controller'
    RESOLVER = 'openhands_resolver'
    SERVER = 'server'
    RUNTIME = 'runtime'
    MEMORY = 'memory'
    MAIN = 'main'
    TEST = 'test'
|
||||
|
||||
|
||||
async def session_exists(
    sid: str, file_store: FileStore, user_id: str | None = None
) -> bool:
    """Return True when conversation *sid* has a directory in *file_store*."""
    conversation_dir = get_conversation_dir(sid, user_id)
    try:
        await call_sync_from_async(file_store.list, conversation_dir)
    except FileNotFoundError:
        return False
    return True
|
||||
|
||||
|
||||
class EventStream(EventStore):
    """Append-only event store with fan-out to per-callback worker threads.

    Events are added via ``add_event`` (assigned an id under a lock, secrets
    scrubbed, persisted to the file store) and then dispatched asynchronously:
    a background queue thread hands each event to every subscribed callback,
    each of which runs on its own single-worker thread pool with its own
    asyncio event loop.
    """

    secrets: dict[str, str]
    # For each subscriber ID, there is a map of callback functions - useful
    # when there are multiple listeners
    _subscribers: dict[str, dict[str, Callable]]
    _lock: threading.Lock
    _queue: queue.Queue[Event]
    _queue_thread: threading.Thread
    _queue_loop: asyncio.AbstractEventLoop | None
    _thread_pools: dict[str, dict[str, ThreadPoolExecutor]]
    _thread_loops: dict[str, dict[str, asyncio.AbstractEventLoop]]
    _write_page_cache: list[dict]

    def __init__(self, sid: str, file_store: FileStore, user_id: str | None = None):
        super().__init__(sid, file_store, user_id)
        self._stop_flag = threading.Event()
        self._queue: queue.Queue[Event] = queue.Queue()
        self._thread_pools = {}
        self._thread_loops = {}
        self._queue_loop = None
        # Daemon thread so a forgotten close() never blocks interpreter exit.
        self._queue_thread = threading.Thread(target=self._run_queue_loop)
        self._queue_thread.daemon = True
        self._queue_thread.start()
        self._subscribers = {}
        self._lock = threading.Lock()
        self.secrets = {}
        self._write_page_cache = []

    def _init_thread_loop(self, subscriber_id: str, callback_id: str) -> None:
        """Thread-pool initializer: give each callback worker its own loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        if subscriber_id not in self._thread_loops:
            self._thread_loops[subscriber_id] = {}
        self._thread_loops[subscriber_id][callback_id] = loop

    def close(self) -> None:
        """Stop the dispatch thread, tear down all subscribers, drain the queue."""
        self._stop_flag.set()
        if self._queue_thread.is_alive():
            self._queue_thread.join()

        subscriber_ids = list(self._subscribers.keys())
        for subscriber_id in subscriber_ids:
            callback_ids = list(self._subscribers[subscriber_id].keys())
            for callback_id in callback_ids:
                self._clean_up_subscriber(subscriber_id, callback_id)

        # Clear queue
        while not self._queue.empty():
            self._queue.get()

    def _clean_up_subscriber(self, subscriber_id: str, callback_id: str) -> None:
        """Cancel pending tasks, close the loop and pool for one callback."""
        if subscriber_id not in self._subscribers:
            logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')
            return
        if callback_id not in self._subscribers[subscriber_id]:
            logger.warning(f'Callback not found during cleanup: {callback_id}')
            return
        if (
            subscriber_id in self._thread_loops
            and callback_id in self._thread_loops[subscriber_id]
        ):
            loop = self._thread_loops[subscriber_id][callback_id]
            current_task = asyncio.current_task(loop)
            pending = [
                task for task in asyncio.all_tasks(loop) if task is not current_task
            ]
            for task in pending:
                task.cancel()
            try:
                loop.stop()
                loop.close()
            except Exception as e:
                logger.warning(
                    f'Error closing loop for {subscriber_id}/{callback_id}: {e}'
                )
            del self._thread_loops[subscriber_id][callback_id]

        if (
            subscriber_id in self._thread_pools
            and callback_id in self._thread_pools[subscriber_id]
        ):
            pool = self._thread_pools[subscriber_id][callback_id]
            pool.shutdown()
            del self._thread_pools[subscriber_id][callback_id]

        del self._subscribers[subscriber_id][callback_id]

    def subscribe(
        self,
        subscriber_id: EventStreamSubscriber,
        callback: Callable[[Event], None],
        callback_id: str,
    ) -> None:
        """Register *callback* under (subscriber_id, callback_id).

        Raises:
            ValueError: if that callback_id is already registered for the
                subscriber.
        """
        initializer = partial(self._init_thread_loop, subscriber_id, callback_id)
        pool = ThreadPoolExecutor(max_workers=1, initializer=initializer)
        if subscriber_id not in self._subscribers:
            self._subscribers[subscriber_id] = {}
            self._thread_pools[subscriber_id] = {}

        if callback_id in self._subscribers[subscriber_id]:
            raise ValueError(
                f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
            )

        self._subscribers[subscriber_id][callback_id] = callback
        self._thread_pools[subscriber_id][callback_id] = pool

    def unsubscribe(
        self, subscriber_id: EventStreamSubscriber, callback_id: str
    ) -> None:
        """Remove one callback registration; warn (don't raise) when missing."""
        if subscriber_id not in self._subscribers:
            logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
            return

        if callback_id not in self._subscribers[subscriber_id]:
            logger.warning(f'Callback not found during unsubscribe: {callback_id}')
            return

        self._clean_up_subscriber(subscriber_id, callback_id)

    def add_event(self, event: Event, source: EventSource) -> None:
        """Assign an id, scrub secrets, persist and enqueue *event*.

        Raises:
            ValueError: if the event already carries an id (re-adding an event
                from inside a handler would loop forever).
        """
        if event.id != Event.INVALID_ID:
            raise ValueError(
                f'Event already has an ID:{event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.'
            )
        event._timestamp = datetime.now().isoformat()
        event._source = source  # type: ignore [attr-defined]
        with self._lock:
            event._id = self.cur_id  # type: ignore [attr-defined]
            self.cur_id += 1

            # Take a copy of the current write page
            current_write_page = self._write_page_cache

            # Round-trip through dict form so the stored and in-memory event
            # both have secrets replaced.
            data = event_to_dict(event)
            data = self._replace_secrets(data)
            event = event_from_dict(data)
            current_write_page.append(data)

            # If the page is full, create a new page for future events / other threads to use
            if len(current_write_page) == self.cache_size:
                self._write_page_cache = []

        if event.id is not None:
            # Write the event to the store - this can take some time
            event_json = json.dumps(data)
            filename = self._get_filename_for_id(event.id, self.user_id)
            if len(event_json) > 1_000_000:  # Roughly 1MB in bytes, ignoring encoding
                # NOTE(review): the message text below looks garbled by the
                # source rendering (likely a `{filename}` placeholder) — confirm.
                logger.warning(
                    f'Saving event JSON over 1MB: {len(event_json):,} bytes, filename: (unknown)',
                    extra={
                        'user_id': self.user_id,
                        'session_id': self.sid,
                        'size': len(event_json),
                    },
                )
            self.file_store.write(filename, event_json)

        # Store the cache page last - if it is not present during reads then it will simply be bypassed.
        self._store_cache_page(current_write_page)
        self._queue.put(event)

    def _store_cache_page(self, current_write_page: list[dict]):
        """Store a page in the cache. Reading individual events is slow when there are a lot of them, so we use pages."""
        if len(current_write_page) < self.cache_size:
            return
        start = current_write_page[0]['id']
        end = start + self.cache_size
        contents = json.dumps(current_write_page)
        cache_filename = self._get_filename_for_cache(start, end)
        self.file_store.write(cache_filename, contents)

    def set_secrets(self, secrets: dict[str, str]) -> None:
        """Replace the secret map used for redaction with a copy of *secrets*."""
        self.secrets = secrets.copy()

    def update_secrets(self, secrets: dict[str, str]) -> None:
        """Merge *secrets* into the redaction map."""
        self.secrets.update(secrets)

    def _replace_secrets(
        self, data: dict[str, Any], is_top_level: bool = True
    ) -> dict[str, Any]:
        """Recursively replace known secret values in *data* with a placeholder."""
        # Fields that should not have secrets replaced (only at top level - system metadata)
        TOP_LEVEL_PROTECTED_FIELDS = {
            'timestamp',
            'id',
            'source',
            'cause',
            'action',
            'observation',
            'message',
        }

        for key in data:
            if is_top_level and key in TOP_LEVEL_PROTECTED_FIELDS:
                # Skip secret replacement for protected system fields at top level only
                continue
            elif isinstance(data[key], dict):
                data[key] = self._replace_secrets(data[key], is_top_level=False)
            elif isinstance(data[key], str):
                for secret in self.secrets.values():
                    data[key] = data[key].replace(secret, '<secret_hidden>')
        return data

    def _run_queue_loop(self) -> None:
        """Thread target: own an event loop and run the dispatch coroutine."""
        self._queue_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self._queue_loop)
        try:
            self._queue_loop.run_until_complete(self._process_queue())
        finally:
            self._queue_loop.close()

    async def _process_queue(self) -> None:
        """Drain the queue, fanning each event out to all subscriber pools."""
        while should_continue() and not self._stop_flag.is_set():
            event = None
            try:
                event = self._queue.get(timeout=0.1)
            except queue.Empty:
                continue

            # pass each event to each callback in order
            for key in sorted(self._subscribers.keys()):
                callbacks = self._subscribers[key]
                # Create a copy of the keys to avoid "dictionary changed size during iteration" error
                callback_ids = list(callbacks.keys())
                for callback_id in callback_ids:
                    # Check if callback_id still exists (might have been removed during iteration)
                    if callback_id in callbacks:
                        callback = callbacks[callback_id]
                        pool = self._thread_pools[key][callback_id]
                        future = pool.submit(callback, event)
                        future.add_done_callback(
                            self._make_error_handler(callback_id, key)
                        )

    def _make_error_handler(
        self, callback_id: str, subscriber_id: str
    ) -> Callable[[Any], None]:
        """Build a done-callback that logs and re-raises callback failures."""

        def _handle_callback_error(fut: Any) -> None:
            try:
                # This will raise any exception that occurred during callback execution
                fut.result()
            except Exception as e:
                logger.error(
                    f'Error in event callback {callback_id} for subscriber {subscriber_id}: {str(e)}',
                )
                # Re-raise in the main thread so the error is not swallowed
                raise e

        return _handle_callback_error
|
||||
@@ -1,11 +0,0 @@
|
||||
from litellm import ModelResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ToolCallMetadata(BaseModel):
    """Metadata linking an event to the LLM tool call that produced it."""

    # See https://docs.litellm.ai/docs/completion/function_call#step-3---second-litellmcompletion-call
    function_name: str  # Name of the function that was called
    tool_call_id: str  # ID of the tool call

    # Raw litellm response that contained this tool call, and how many tool
    # calls that response carried in total.
    model_response: ModelResponse
    total_calls_in_response: int
|
||||
@@ -529,18 +529,19 @@ class TestOnEventStatsProcessing:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_event_skips_non_stats_events(self):
|
||||
"""Test that on_event skips non-stats events."""
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from openhands.app_server.event_callback.webhook_router import on_event
|
||||
from openhands.events.action.message import MessageAction
|
||||
|
||||
conversation_id = uuid4()
|
||||
sandbox_id = 'sandbox_123'
|
||||
|
||||
# Create non-stats events
|
||||
# Create non-stats events (use MagicMock for non-ConversationStateUpdateEvent)
|
||||
mock_other_event = MagicMock()
|
||||
mock_other_event.id = uuid4()
|
||||
events = [
|
||||
ConversationStateUpdateEvent(key='execution_status', value='running'),
|
||||
MessageAction(content='test'),
|
||||
mock_other_event,
|
||||
]
|
||||
|
||||
mock_app_conversation_info = AppConversationInfo(
|
||||
|
||||
@@ -1,439 +0,0 @@
|
||||
from openhands.events.action import (
|
||||
Action,
|
||||
AgentFinishAction,
|
||||
AgentRejectAction,
|
||||
BrowseInteractiveAction,
|
||||
BrowseURLAction,
|
||||
CmdRunAction,
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
MessageAction,
|
||||
RecallAction,
|
||||
)
|
||||
from openhands.events.action.action import ActionConfirmationStatus
|
||||
from openhands.events.action.files import FileEditSource, FileReadSource
|
||||
from openhands.events.serialization import (
|
||||
event_from_dict,
|
||||
event_to_dict,
|
||||
)
|
||||
|
||||
|
||||
def serialization_deserialization(
    original_action_dict, cls, max_message_chars: int = 10000
):
    """Round-trip helper: deserialize a dict, check the class, re-serialize.

    Asserts the re-serialized dict (minus the UI-only 'message' key) equals
    the original input exactly.
    """
    # NOTE(review): max_message_chars is currently unused here.
    action_instance = event_from_dict(original_action_dict)
    assert isinstance(action_instance, Action), (
        'The action instance should be an instance of Action.'
    )
    assert isinstance(action_instance, cls), (
        f'The action instance should be an instance of {cls.__name__}.'
    )

    # event_to_dict is the regular serialization of an event
    serialized_action_dict = event_to_dict(action_instance)

    # it has an extra message property, for the UI
    serialized_action_dict.pop('message')
    assert serialized_action_dict == original_action_dict, (
        'The serialized action should match the original action dict.'
    )
|
||||
|
||||
|
||||
def test_event_props_serialization_deserialization():
    """Top-level event props (id/source/timestamp) survive the round trip."""
    original_action_dict = {
        'id': 42,
        'source': 'agent',
        'timestamp': '2021-08-01T12:00:00',
        'action': 'message',
        'args': {
            'content': 'This is a test.',
            'image_urls': None,
            'file_urls': None,
            'wait_for_response': False,
            'security_risk': -1,
        },
    }
    serialization_deserialization(original_action_dict, MessageAction)
|
||||
|
||||
|
||||
def test_message_action_serialization_deserialization():
    """MessageAction round-trips through event_from_dict/event_to_dict."""
    original_action_dict = {
        'action': 'message',
        'args': {
            'content': 'This is a test.',
            'image_urls': None,
            'file_urls': None,
            'wait_for_response': False,
            'security_risk': -1,
        },
    }
    serialization_deserialization(original_action_dict, MessageAction)
|
||||
|
||||
|
||||
def test_agent_finish_action_serialization_deserialization():
    """AgentFinishAction round-trips through serialization."""
    original_action_dict = {
        'action': 'finish',
        'args': {
            'outputs': {},
            'thought': '',
            'final_thought': '',
        },
    }
    serialization_deserialization(original_action_dict, AgentFinishAction)
|
||||
|
||||
|
||||
def test_agent_finish_action_legacy_task_completed_serialization():
    """Test that old conversations with task_completed can still be loaded."""
    original_action_dict = {
        'action': 'finish',
        'args': {
            'outputs': {},
            'thought': '',
            'final_thought': 'Task completed',
            'task_completed': 'true',  # This should be ignored during deserialization
        },
    }
    # This should work without errors - task_completed should be stripped out
    event = event_from_dict(original_action_dict)
    assert isinstance(event, Action)
    assert isinstance(event, AgentFinishAction)
    assert event.final_thought == 'Task completed'
    # task_completed attribute should not exist anymore
    assert not hasattr(event, 'task_completed')

    # When serialized back, task_completed should not be present
    event_dict = event_to_dict(event)
    assert 'task_completed' not in event_dict['args']
|
||||
|
||||
|
||||
def test_agent_reject_action_serialization_deserialization():
    """AgentRejectAction round-trips through serialization."""
    original_action_dict = {
        'action': 'reject',
        'args': {'outputs': {}, 'thought': ''},
    }
    serialization_deserialization(original_action_dict, AgentRejectAction)
|
||||
|
||||
|
||||
def test_cmd_run_action_serialization_deserialization():
    """CmdRunAction (incl. confirmation_state enum) round-trips."""
    original_action_dict = {
        'action': 'run',
        'args': {
            'blocking': False,
            'command': 'echo "Hello world"',
            'is_input': False,
            'thought': '',
            'hidden': False,
            'confirmation_state': ActionConfirmationStatus.CONFIRMED,
            'is_static': False,
            'cwd': None,
            'security_risk': -1,
        },
    }
    serialization_deserialization(original_action_dict, CmdRunAction)
|
||||
|
||||
|
||||
def test_browse_url_action_serialization_deserialization():
|
||||
original_action_dict = {
|
||||
'action': 'browse',
|
||||
'args': {
|
||||
'thought': '',
|
||||
'url': 'https://www.example.com',
|
||||
'return_axtree': False,
|
||||
'security_risk': -1,
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_action_dict, BrowseURLAction)
|
||||
|
||||
|
||||
def test_browse_interactive_action_serialization_deserialization():
|
||||
original_action_dict = {
|
||||
'action': 'browse_interactive',
|
||||
'args': {
|
||||
'thought': '',
|
||||
'browser_actions': 'goto("https://www.example.com")',
|
||||
'browsergym_send_msg_to_user': '',
|
||||
'return_axtree': False,
|
||||
'security_risk': -1,
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_action_dict, BrowseInteractiveAction)
|
||||
|
||||
|
||||
def test_file_read_action_serialization_deserialization():
|
||||
original_action_dict = {
|
||||
'action': 'read',
|
||||
'args': {
|
||||
'path': '/path/to/file.txt',
|
||||
'start': 0,
|
||||
'end': -1,
|
||||
'thought': 'None',
|
||||
'impl_source': 'default',
|
||||
'view_range': None,
|
||||
'security_risk': -1,
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_action_dict, FileReadAction)
|
||||
|
||||
|
||||
def test_file_write_action_serialization_deserialization():
|
||||
original_action_dict = {
|
||||
'action': 'write',
|
||||
'args': {
|
||||
'path': '/path/to/file.txt',
|
||||
'content': 'Hello world',
|
||||
'start': 0,
|
||||
'end': 1,
|
||||
'thought': 'None',
|
||||
'security_risk': -1,
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_action_dict, FileWriteAction)
|
||||
|
||||
|
||||
def test_file_edit_action_aci_serialization_deserialization():
|
||||
original_action_dict = {
|
||||
'action': 'edit',
|
||||
'args': {
|
||||
'path': '/path/to/file.txt',
|
||||
'command': 'str_replace',
|
||||
'file_text': None,
|
||||
'old_str': 'old text',
|
||||
'new_str': 'new text',
|
||||
'insert_line': None,
|
||||
'content': '',
|
||||
'start': 1,
|
||||
'end': -1,
|
||||
'thought': 'Replacing text',
|
||||
'impl_source': 'oh_aci',
|
||||
'security_risk': -1,
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_action_dict, FileEditAction)
|
||||
|
||||
|
||||
def test_file_edit_action_llm_serialization_deserialization():
|
||||
original_action_dict = {
|
||||
'action': 'edit',
|
||||
'args': {
|
||||
'path': '/path/to/file.txt',
|
||||
'command': None,
|
||||
'file_text': None,
|
||||
'old_str': None,
|
||||
'new_str': None,
|
||||
'insert_line': None,
|
||||
'content': 'Updated content',
|
||||
'start': 1,
|
||||
'end': 10,
|
||||
'thought': 'Updating file content',
|
||||
'impl_source': 'llm_based_edit',
|
||||
'security_risk': -1,
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_action_dict, FileEditAction)
|
||||
|
||||
|
||||
def test_cmd_run_action_legacy_serialization():
|
||||
original_action_dict = {
|
||||
'action': 'run',
|
||||
'args': {
|
||||
'blocking': False,
|
||||
'command': 'echo "Hello world"',
|
||||
'thought': '',
|
||||
'hidden': False,
|
||||
'confirmation_state': ActionConfirmationStatus.CONFIRMED,
|
||||
'keep_prompt': False, # will be treated as no-op
|
||||
},
|
||||
}
|
||||
event = event_from_dict(original_action_dict)
|
||||
assert isinstance(event, Action)
|
||||
assert isinstance(event, CmdRunAction)
|
||||
assert event.command == 'echo "Hello world"'
|
||||
assert event.hidden is False
|
||||
assert not hasattr(event, 'keep_prompt')
|
||||
|
||||
event_dict = event_to_dict(event)
|
||||
assert 'keep_prompt' not in event_dict['args']
|
||||
assert (
|
||||
event_dict['args']['confirmation_state'] == ActionConfirmationStatus.CONFIRMED
|
||||
)
|
||||
assert event_dict['args']['blocking'] is False
|
||||
assert event_dict['args']['command'] == 'echo "Hello world"'
|
||||
assert event_dict['args']['thought'] == ''
|
||||
assert event_dict['args']['is_input'] is False
|
||||
|
||||
|
||||
def test_file_llm_based_edit_action_legacy_serialization():
|
||||
original_action_dict = {
|
||||
'action': 'edit',
|
||||
'args': {
|
||||
'path': '/path/to/file.txt',
|
||||
'content': 'dummy content',
|
||||
'start': 1,
|
||||
'end': -1,
|
||||
'thought': 'Replacing text',
|
||||
'impl_source': 'oh_aci',
|
||||
'translated_ipython_code': None,
|
||||
},
|
||||
}
|
||||
event = event_from_dict(original_action_dict)
|
||||
assert isinstance(event, Action)
|
||||
assert isinstance(event, FileEditAction)
|
||||
|
||||
# Common arguments
|
||||
assert event.path == '/path/to/file.txt'
|
||||
assert event.thought == 'Replacing text'
|
||||
assert event.impl_source == FileEditSource.OH_ACI
|
||||
assert not hasattr(event, 'translated_ipython_code')
|
||||
|
||||
# OH_ACI arguments
|
||||
assert event.command == ''
|
||||
assert event.file_text is None
|
||||
assert event.old_str is None
|
||||
assert event.new_str is None
|
||||
assert event.insert_line is None
|
||||
|
||||
# LLM-based editing arguments
|
||||
assert event.content == 'dummy content'
|
||||
assert event.start == 1
|
||||
assert event.end == -1
|
||||
|
||||
event_dict = event_to_dict(event)
|
||||
assert 'translated_ipython_code' not in event_dict['args']
|
||||
|
||||
# Common arguments
|
||||
assert event_dict['args']['path'] == '/path/to/file.txt'
|
||||
assert event_dict['args']['impl_source'] == 'oh_aci'
|
||||
assert event_dict['args']['thought'] == 'Replacing text'
|
||||
|
||||
# OH_ACI arguments
|
||||
assert event_dict['args']['command'] == ''
|
||||
assert event_dict['args']['file_text'] is None
|
||||
assert event_dict['args']['old_str'] is None
|
||||
assert event_dict['args']['new_str'] is None
|
||||
assert event_dict['args']['insert_line'] is None
|
||||
|
||||
# LLM-based editing arguments
|
||||
assert event_dict['args']['content'] == 'dummy content'
|
||||
assert event_dict['args']['start'] == 1
|
||||
assert event_dict['args']['end'] == -1
|
||||
|
||||
|
||||
def test_file_ohaci_edit_action_legacy_serialization():
|
||||
original_action_dict = {
|
||||
'action': 'edit',
|
||||
'args': {
|
||||
'path': '/workspace/game_2048.py',
|
||||
'content': '',
|
||||
'start': 1,
|
||||
'end': -1,
|
||||
'thought': "I'll help you create a simple 2048 game in Python. I'll use the str_replace_editor to create the file.",
|
||||
'impl_source': 'oh_aci',
|
||||
'translated_ipython_code': "print(file_editor(**{'command': 'create', 'path': '/workspace/game_2048.py', 'file_text': 'New file content'}))",
|
||||
},
|
||||
}
|
||||
event = event_from_dict(original_action_dict)
|
||||
assert isinstance(event, Action)
|
||||
assert isinstance(event, FileEditAction)
|
||||
|
||||
# Common arguments
|
||||
assert event.path == '/workspace/game_2048.py'
|
||||
assert (
|
||||
event.thought
|
||||
== "I'll help you create a simple 2048 game in Python. I'll use the str_replace_editor to create the file."
|
||||
)
|
||||
assert event.impl_source == FileEditSource.OH_ACI
|
||||
assert not hasattr(event, 'translated_ipython_code')
|
||||
|
||||
# OH_ACI arguments
|
||||
assert event.command == 'create'
|
||||
assert event.file_text == 'New file content'
|
||||
assert event.old_str is None
|
||||
assert event.new_str is None
|
||||
assert event.insert_line is None
|
||||
|
||||
# LLM-based editing arguments
|
||||
assert event.content == ''
|
||||
assert event.start == 1
|
||||
assert event.end == -1
|
||||
|
||||
event_dict = event_to_dict(event)
|
||||
assert 'translated_ipython_code' not in event_dict['args']
|
||||
|
||||
# Common arguments
|
||||
assert event_dict['args']['path'] == '/workspace/game_2048.py'
|
||||
assert event_dict['args']['impl_source'] == 'oh_aci'
|
||||
assert (
|
||||
event_dict['args']['thought']
|
||||
== "I'll help you create a simple 2048 game in Python. I'll use the str_replace_editor to create the file."
|
||||
)
|
||||
|
||||
# OH_ACI arguments
|
||||
assert event_dict['args']['command'] == 'create'
|
||||
assert event_dict['args']['file_text'] == 'New file content'
|
||||
assert event_dict['args']['old_str'] is None
|
||||
assert event_dict['args']['new_str'] is None
|
||||
assert event_dict['args']['insert_line'] is None
|
||||
|
||||
# LLM-based editing arguments
|
||||
assert event_dict['args']['content'] == ''
|
||||
assert event_dict['args']['start'] == 1
|
||||
assert event_dict['args']['end'] == -1
|
||||
|
||||
|
||||
def test_agent_microagent_action_serialization_deserialization():
|
||||
original_action_dict = {
|
||||
'action': 'recall',
|
||||
'args': {
|
||||
'query': 'What is the capital of France?',
|
||||
'thought': 'I need to find information about France',
|
||||
'recall_type': 'knowledge',
|
||||
},
|
||||
}
|
||||
serialization_deserialization(original_action_dict, RecallAction)
|
||||
|
||||
|
||||
def test_file_read_action_legacy_serialization():
|
||||
original_action_dict = {
|
||||
'action': 'read',
|
||||
'args': {
|
||||
'path': '/workspace/test.txt',
|
||||
'start': 0,
|
||||
'end': -1,
|
||||
'thought': 'Reading the file contents',
|
||||
'impl_source': 'oh_aci',
|
||||
'translated_ipython_code': "print(file_editor(**{'command': 'view', 'path': '/workspace/test.txt'}))",
|
||||
},
|
||||
}
|
||||
|
||||
event = event_from_dict(original_action_dict)
|
||||
assert isinstance(event, Action)
|
||||
assert isinstance(event, FileReadAction)
|
||||
|
||||
# Common arguments
|
||||
assert event.path == '/workspace/test.txt'
|
||||
assert event.thought == 'Reading the file contents'
|
||||
assert event.impl_source == FileReadSource.OH_ACI
|
||||
assert not hasattr(event, 'translated_ipython_code')
|
||||
assert not hasattr(
|
||||
event, 'command'
|
||||
) # FileReadAction should not have command attribute
|
||||
|
||||
# Read-specific arguments
|
||||
assert event.start == 0
|
||||
assert event.end == -1
|
||||
|
||||
event_dict = event_to_dict(event)
|
||||
assert 'translated_ipython_code' not in event_dict['args']
|
||||
assert (
|
||||
'command' not in event_dict['args']
|
||||
) # command should not be in serialized args
|
||||
|
||||
# Common arguments in serialized form
|
||||
assert event_dict['args']['path'] == '/workspace/test.txt'
|
||||
assert event_dict['args']['impl_source'] == 'oh_aci'
|
||||
assert event_dict['args']['thought'] == 'Reading the file contents'
|
||||
|
||||
# Read-specific arguments in serialized form
|
||||
assert event_dict['args']['start'] == 0
|
||||
assert event_dict['args']['end'] == -1
|
||||
@@ -1,32 +0,0 @@
|
||||
from openhands.events.observation.commands import (
|
||||
CmdOutputMetadata,
|
||||
CmdOutputObservation,
|
||||
IPythonRunCellObservation,
|
||||
)
|
||||
|
||||
|
||||
def test_cmd_output_success():
|
||||
# Test successful command
|
||||
obs = CmdOutputObservation(
|
||||
command='ls',
|
||||
content='file1.txt\nfile2.txt',
|
||||
metadata=CmdOutputMetadata(exit_code=0),
|
||||
)
|
||||
assert obs.success is True
|
||||
assert obs.error is False
|
||||
|
||||
# Test failed command
|
||||
obs = CmdOutputObservation(
|
||||
command='ls',
|
||||
content='No such file or directory',
|
||||
metadata=CmdOutputMetadata(exit_code=1),
|
||||
)
|
||||
assert obs.success is False
|
||||
assert obs.error is True
|
||||
|
||||
|
||||
def test_ipython_cell_success():
|
||||
# IPython cells are always successful
|
||||
obs = IPythonRunCellObservation(code='print("Hello")', content='Hello')
|
||||
assert obs.success is True
|
||||
assert obs.error is False
|
||||
@@ -1,152 +0,0 @@
|
||||
from openhands.events.action import CmdRunAction, MessageAction
|
||||
from openhands.events.action.action import ActionSecurityRisk
|
||||
from openhands.events.metrics import Cost, Metrics, ResponseLatency, TokenUsage
|
||||
from openhands.events.observation import CmdOutputMetadata, CmdOutputObservation
|
||||
from openhands.events.serialization import event_from_dict, event_to_dict
|
||||
|
||||
|
||||
def test_command_output_success_serialization():
|
||||
# Test successful command
|
||||
obs = CmdOutputObservation(
|
||||
command='ls',
|
||||
content='file1.txt\nfile2.txt',
|
||||
metadata=CmdOutputMetadata(exit_code=0),
|
||||
)
|
||||
serialized = event_to_dict(obs)
|
||||
assert serialized['success'] is True
|
||||
|
||||
# Test failed command
|
||||
obs = CmdOutputObservation(
|
||||
command='ls',
|
||||
content='No such file or directory',
|
||||
metadata=CmdOutputMetadata(exit_code=1),
|
||||
)
|
||||
serialized = event_to_dict(obs)
|
||||
assert serialized['success'] is False
|
||||
|
||||
|
||||
def test_metrics_basic_serialization():
|
||||
# Create a basic action with only accumulated_cost
|
||||
action = MessageAction(content='Hello, world!')
|
||||
metrics = Metrics()
|
||||
metrics.accumulated_cost = 0.03
|
||||
action._llm_metrics = metrics
|
||||
|
||||
# Test serialization
|
||||
serialized = event_to_dict(action)
|
||||
assert 'llm_metrics' in serialized
|
||||
assert serialized['llm_metrics']['accumulated_cost'] == 0.03
|
||||
assert serialized['llm_metrics']['costs'] == []
|
||||
assert serialized['llm_metrics']['response_latencies'] == []
|
||||
assert serialized['llm_metrics']['token_usages'] == []
|
||||
|
||||
# Test deserialization
|
||||
deserialized = event_from_dict(serialized)
|
||||
assert deserialized.llm_metrics is not None
|
||||
assert deserialized.llm_metrics.accumulated_cost == 0.03
|
||||
assert len(deserialized.llm_metrics.costs) == 0
|
||||
assert len(deserialized.llm_metrics.response_latencies) == 0
|
||||
assert len(deserialized.llm_metrics.token_usages) == 0
|
||||
|
||||
|
||||
def test_metrics_full_serialization():
|
||||
# Create an observation with all metrics fields
|
||||
obs = CmdOutputObservation(
|
||||
command='ls',
|
||||
content='test.txt',
|
||||
metadata=CmdOutputMetadata(exit_code=0),
|
||||
)
|
||||
metrics = Metrics(model_name='test-model')
|
||||
metrics.accumulated_cost = 0.03
|
||||
|
||||
# Add a cost
|
||||
cost = Cost(model='test-model', cost=0.02)
|
||||
metrics._costs.append(cost)
|
||||
|
||||
# Add a response latency
|
||||
latency = ResponseLatency(model='test-model', latency=0.5, response_id='test-id')
|
||||
metrics.response_latencies = [latency]
|
||||
|
||||
# Add token usage
|
||||
usage = TokenUsage(
|
||||
model='test-model',
|
||||
prompt_tokens=10,
|
||||
completion_tokens=20,
|
||||
cache_read_tokens=0,
|
||||
cache_write_tokens=0,
|
||||
response_id='test-id',
|
||||
)
|
||||
metrics.token_usages = [usage]
|
||||
|
||||
obs._llm_metrics = metrics
|
||||
|
||||
# Test serialization
|
||||
serialized = event_to_dict(obs)
|
||||
assert 'llm_metrics' in serialized
|
||||
metrics_dict = serialized['llm_metrics']
|
||||
assert metrics_dict['accumulated_cost'] == 0.03
|
||||
assert len(metrics_dict['costs']) == 1
|
||||
assert metrics_dict['costs'][0]['cost'] == 0.02
|
||||
assert len(metrics_dict['response_latencies']) == 1
|
||||
assert metrics_dict['response_latencies'][0]['latency'] == 0.5
|
||||
assert len(metrics_dict['token_usages']) == 1
|
||||
assert metrics_dict['token_usages'][0]['prompt_tokens'] == 10
|
||||
assert metrics_dict['token_usages'][0]['completion_tokens'] == 20
|
||||
|
||||
# Test deserialization
|
||||
deserialized = event_from_dict(serialized)
|
||||
assert deserialized.llm_metrics is not None
|
||||
assert deserialized.llm_metrics.accumulated_cost == 0.03
|
||||
assert len(deserialized.llm_metrics.costs) == 1
|
||||
assert deserialized.llm_metrics.costs[0].cost == 0.02
|
||||
assert len(deserialized.llm_metrics.response_latencies) == 1
|
||||
assert deserialized.llm_metrics.response_latencies[0].latency == 0.5
|
||||
assert len(deserialized.llm_metrics.token_usages) == 1
|
||||
assert deserialized.llm_metrics.token_usages[0].prompt_tokens == 10
|
||||
assert deserialized.llm_metrics.token_usages[0].completion_tokens == 20
|
||||
|
||||
|
||||
def test_metrics_none_serialization():
|
||||
# Test when metrics is None
|
||||
obs = CmdOutputObservation(
|
||||
command='ls',
|
||||
content='test.txt',
|
||||
metadata=CmdOutputMetadata(exit_code=0),
|
||||
)
|
||||
obs._llm_metrics = None
|
||||
|
||||
# Test serialization
|
||||
serialized = event_to_dict(obs)
|
||||
assert 'llm_metrics' not in serialized
|
||||
|
||||
# Test deserialization
|
||||
deserialized = event_from_dict(serialized)
|
||||
assert deserialized.llm_metrics is None
|
||||
|
||||
|
||||
def test_action_risk_serialization():
|
||||
# Test action with security risk
|
||||
action = CmdRunAction(command='rm -rf /tmp/test')
|
||||
action.security_risk = ActionSecurityRisk.HIGH
|
||||
|
||||
# Test serialization
|
||||
serialized = event_to_dict(action)
|
||||
assert 'security_risk' in serialized['args']
|
||||
assert serialized['args']['security_risk'] == ActionSecurityRisk.HIGH.value
|
||||
|
||||
# Test deserialization
|
||||
deserialized = event_from_dict(serialized)
|
||||
assert deserialized.security_risk == ActionSecurityRisk.HIGH
|
||||
|
||||
# Test action with no security risk
|
||||
action = CmdRunAction(command='ls')
|
||||
# Don't set action_risk
|
||||
|
||||
# Test serialization
|
||||
serialized = event_to_dict(action)
|
||||
assert 'security_risk' in serialized['args']
|
||||
assert serialized['args']['security_risk'] == ActionSecurityRisk.UNKNOWN.value
|
||||
|
||||
# Test deserialization
|
||||
deserialized = event_from_dict(serialized)
|
||||
assert deserialized.security_risk == ActionSecurityRisk.UNKNOWN
|
||||
@@ -1,865 +0,0 @@
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import psutil
|
||||
import pytest
|
||||
from pytest import TempPathFactory
|
||||
|
||||
from openhands.core.schema import ActionType, ObservationType
|
||||
from openhands.events import EventSource, EventStream, EventStreamSubscriber
|
||||
from openhands.events.action import (
|
||||
CmdRunAction,
|
||||
NullAction,
|
||||
)
|
||||
from openhands.events.action.files import (
|
||||
FileEditAction,
|
||||
FileReadAction,
|
||||
FileWriteAction,
|
||||
)
|
||||
from openhands.events.action.message import MessageAction
|
||||
from openhands.events.event import FileEditSource, FileReadSource
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.observation import NullObservation
|
||||
from openhands.events.observation.files import (
|
||||
FileEditObservation,
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
)
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.locations import (
|
||||
get_conversation_event_filename,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir(tmp_path_factory: TempPathFactory) -> str:
|
||||
return str(tmp_path_factory.mktemp('test_event_stream'))
|
||||
|
||||
|
||||
def collect_events(stream):
|
||||
return [event for event in stream.get_events()]
|
||||
|
||||
|
||||
def test_basic_flow(temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
event_stream.add_event(NullAction(), EventSource.AGENT)
|
||||
assert len(collect_events(event_stream)) == 1
|
||||
|
||||
|
||||
def test_stream_storage(temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
event_stream.add_event(NullObservation(''), EventSource.AGENT)
|
||||
assert len(collect_events(event_stream)) == 1
|
||||
content = event_stream.file_store.read(get_conversation_event_filename('abc', 0))
|
||||
assert content is not None
|
||||
data = json.loads(content)
|
||||
assert 'timestamp' in data
|
||||
del data['timestamp']
|
||||
assert data == {
|
||||
'id': 0,
|
||||
'source': 'agent',
|
||||
'observation': 'null',
|
||||
'content': '',
|
||||
'extras': {},
|
||||
'message': 'No observation',
|
||||
}
|
||||
|
||||
|
||||
def test_rehydration(temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
event_stream.add_event(NullObservation('obs1'), EventSource.AGENT)
|
||||
event_stream.add_event(NullObservation('obs2'), EventSource.AGENT)
|
||||
assert len(collect_events(event_stream)) == 2
|
||||
|
||||
stream2 = EventStream('es2', file_store)
|
||||
assert len(collect_events(stream2)) == 0
|
||||
|
||||
stream1rehydrated = EventStream('abc', file_store)
|
||||
events = collect_events(stream1rehydrated)
|
||||
assert len(events) == 2
|
||||
assert events[0].content == 'obs1'
|
||||
assert events[1].content == 'obs2'
|
||||
|
||||
|
||||
def test_get_matching_events_type_filter(temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
|
||||
# Add mixed event types
|
||||
event_stream.add_event(NullAction(), EventSource.AGENT)
|
||||
event_stream.add_event(NullObservation('test'), EventSource.AGENT)
|
||||
event_stream.add_event(NullAction(), EventSource.AGENT)
|
||||
event_stream.add_event(MessageAction(content='test'), EventSource.AGENT)
|
||||
|
||||
# Filter by NullAction
|
||||
events = event_stream.get_matching_events(event_types=(NullAction,))
|
||||
assert len(events) == 2
|
||||
assert all(isinstance(e, NullAction) for e in events)
|
||||
|
||||
# Filter by NullObservation
|
||||
events = event_stream.get_matching_events(event_types=(NullObservation,))
|
||||
assert len(events) == 1
|
||||
assert (
|
||||
isinstance(events[0], NullObservation)
|
||||
and events[0].observation == ObservationType.NULL
|
||||
)
|
||||
|
||||
# Filter by NullAction and MessageAction
|
||||
events = event_stream.get_matching_events(event_types=(NullAction, MessageAction))
|
||||
assert len(events) == 3
|
||||
|
||||
# Filter in reverse
|
||||
events = event_stream.get_matching_events(reverse=True, limit=3)
|
||||
assert len(events) == 3
|
||||
assert isinstance(events[0], MessageAction) and events[0].content == 'test'
|
||||
assert isinstance(events[2], NullObservation) and events[2].content == 'test'
|
||||
|
||||
|
||||
def test_get_matching_events_query_search(temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
|
||||
event_stream.add_event(NullObservation('hello world'), EventSource.AGENT)
|
||||
event_stream.add_event(NullObservation('test message'), EventSource.AGENT)
|
||||
event_stream.add_event(NullObservation('another hello'), EventSource.AGENT)
|
||||
|
||||
# Search for 'hello'
|
||||
events = event_stream.get_matching_events(query='hello')
|
||||
assert len(events) == 2
|
||||
|
||||
# Search should be case-insensitive
|
||||
events = event_stream.get_matching_events(query='HELLO')
|
||||
assert len(events) == 2
|
||||
|
||||
# Search for non-existent text
|
||||
events = event_stream.get_matching_events(query='nonexistent')
|
||||
assert len(events) == 0
|
||||
|
||||
|
||||
def test_get_matching_events_source_filter(temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
|
||||
event_stream.add_event(NullObservation('test1'), EventSource.AGENT)
|
||||
event_stream.add_event(NullObservation('test2'), EventSource.ENVIRONMENT)
|
||||
event_stream.add_event(NullObservation('test3'), EventSource.AGENT)
|
||||
|
||||
# Filter by AGENT source
|
||||
events = event_stream.get_matching_events(source='agent')
|
||||
assert len(events) == 2
|
||||
assert all(
|
||||
isinstance(e, NullObservation) and e.source == EventSource.AGENT for e in events
|
||||
)
|
||||
|
||||
# Filter by ENVIRONMENT source
|
||||
events = event_stream.get_matching_events(source='environment')
|
||||
assert len(events) == 1
|
||||
assert (
|
||||
isinstance(events[0], NullObservation)
|
||||
and events[0].source == EventSource.ENVIRONMENT
|
||||
)
|
||||
|
||||
# Test that source comparison works correctly with None source
|
||||
null_source_event = NullObservation('test4')
|
||||
event_stream.add_event(null_source_event, EventSource.AGENT)
|
||||
event = event_stream.get_event(event_stream.get_latest_event_id())
|
||||
event._source = None # type: ignore
|
||||
|
||||
# Update the serialized version
|
||||
data = event_to_dict(event)
|
||||
event_stream.file_store.write(
|
||||
event_stream._get_filename_for_id(event.id, event_stream.user_id),
|
||||
json.dumps(data),
|
||||
)
|
||||
|
||||
# Verify that source comparison works correctly
|
||||
assert EventFilter(source='agent').exclude(event)
|
||||
assert EventFilter(source=None).include(event)
|
||||
|
||||
# Filter by AGENT source again
|
||||
events = event_stream.get_matching_events(source='agent')
|
||||
assert len(events) == 2 # Should not include the None source event
|
||||
|
||||
|
||||
def test_get_matching_events_pagination(temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
|
||||
# Add 5 events
|
||||
for i in range(5):
|
||||
event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT)
|
||||
|
||||
# Test limit
|
||||
events = event_stream.get_matching_events(limit=3)
|
||||
assert len(events) == 3
|
||||
|
||||
# Test start_id
|
||||
events = event_stream.get_matching_events(start_id=2)
|
||||
assert len(events) == 3
|
||||
assert isinstance(events[0], NullObservation) and events[0].content == 'test2'
|
||||
|
||||
# Test combination of start_id and limit
|
||||
events = event_stream.get_matching_events(start_id=1, limit=2)
|
||||
assert len(events) == 2
|
||||
assert isinstance(events[0], NullObservation) and events[0].content == 'test1'
|
||||
assert isinstance(events[1], NullObservation) and events[1].content == 'test2'
|
||||
|
||||
|
||||
def test_get_matching_events_limit_validation(temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('abc', file_store)
|
||||
|
||||
# Test limit less than 1
|
||||
with pytest.raises(ValueError, match='Limit must be between 1 and 100'):
|
||||
event_stream.get_matching_events(limit=0)
|
||||
|
||||
# Test limit greater than 100
|
||||
with pytest.raises(ValueError, match='Limit must be between 1 and 100'):
|
||||
event_stream.get_matching_events(limit=101)
|
||||
|
||||
# Test valid limits work
|
||||
event_stream.add_event(NullObservation('test'), EventSource.AGENT)
|
||||
events = event_stream.get_matching_events(limit=1)
|
||||
assert len(events) == 1
|
||||
events = event_stream.get_matching_events(limit=100)
|
||||
assert len(events) == 1
|
||||
|
||||
|
||||
def test_memory_usage_file_operations(temp_dir: str):
|
||||
"""Test memory usage during file operations in EventStream.
|
||||
|
||||
This test verifies that memory usage during file operations is reasonable
|
||||
and that memory is properly cleaned up after operations complete.
|
||||
"""
|
||||
|
||||
def get_memory_mb():
|
||||
"""Get current memory usage in MB."""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss / 1024 / 1024
|
||||
|
||||
# Create a test file with 100kb content
|
||||
test_file = os.path.join(temp_dir, 'test_file.txt')
|
||||
test_content = 'x' * (100 * 1024) # 100kb of data
|
||||
with open(test_file, 'w') as f:
|
||||
f.write(test_content)
|
||||
|
||||
# Initialize FileStore and EventStream
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
|
||||
# Record initial memory usage
|
||||
gc.collect()
|
||||
initial_memory = get_memory_mb()
|
||||
max_memory_increase = 0
|
||||
|
||||
# Perform operations 20 times
|
||||
for i in range(20):
|
||||
event_stream = EventStream('test_session', file_store)
|
||||
|
||||
# 1. Read file
|
||||
read_action = FileReadAction(
|
||||
path=test_file,
|
||||
start=0,
|
||||
end=-1,
|
||||
thought='Reading file',
|
||||
action=ActionType.READ,
|
||||
impl_source=FileReadSource.DEFAULT,
|
||||
)
|
||||
event_stream.add_event(read_action, EventSource.AGENT)
|
||||
|
||||
read_obs = FileReadObservation(
|
||||
path=test_file, impl_source=FileReadSource.DEFAULT, content=test_content
|
||||
)
|
||||
event_stream.add_event(read_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
# 2. Write file
|
||||
write_action = FileWriteAction(
|
||||
path=test_file,
|
||||
content=test_content,
|
||||
start=0,
|
||||
end=-1,
|
||||
thought='Writing file',
|
||||
action=ActionType.WRITE,
|
||||
)
|
||||
event_stream.add_event(write_action, EventSource.AGENT)
|
||||
|
||||
write_obs = FileWriteObservation(path=test_file, content=test_content)
|
||||
event_stream.add_event(write_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
# 3. Edit file
|
||||
edit_action = FileEditAction(
|
||||
path=test_file,
|
||||
content=test_content,
|
||||
start=1,
|
||||
end=-1,
|
||||
thought='Editing file',
|
||||
action=ActionType.EDIT,
|
||||
impl_source=FileEditSource.LLM_BASED_EDIT,
|
||||
)
|
||||
event_stream.add_event(edit_action, EventSource.AGENT)
|
||||
|
||||
edit_obs = FileEditObservation(
|
||||
path=test_file,
|
||||
prev_exist=True,
|
||||
old_content=test_content,
|
||||
new_content=test_content,
|
||||
impl_source=FileEditSource.LLM_BASED_EDIT,
|
||||
content=test_content,
|
||||
)
|
||||
event_stream.add_event(edit_obs, EventSource.ENVIRONMENT)
|
||||
|
||||
# Close event stream and force garbage collection
|
||||
event_stream.close()
|
||||
gc.collect()
|
||||
|
||||
# Check memory usage
|
||||
current_memory = get_memory_mb()
|
||||
memory_increase = current_memory - initial_memory
|
||||
max_memory_increase = max(max_memory_increase, memory_increase)
|
||||
|
||||
# Clean up
|
||||
os.remove(test_file)
|
||||
|
||||
# Memory increase should be reasonable (less than 50MB after 20 iterations)
|
||||
assert max_memory_increase < 50, (
|
||||
f'Memory increase of {max_memory_increase:.1f}MB exceeds limit of 50MB'
|
||||
)
|
||||
|
||||
|
||||
def test_cache_page_creation(temp_dir: str):
|
||||
"""Test that cache pages are created correctly when adding events."""
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
event_stream = EventStream('cache_test', file_store)
|
||||
|
||||
# Set a smaller cache size for testing
|
||||
event_stream.cache_size = 5
|
||||
|
||||
# Add events up to the cache size threshold
|
||||
for i in range(10):
|
||||
event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT)
|
||||
|
||||
# Check that a cache page was created after adding the 5th event
|
||||
cache_filename = event_stream._get_filename_for_cache(0, 5)
|
||||
|
||||
try:
|
||||
# Verify the content of the cache page
|
||||
cache_content = file_store.read(cache_filename)
|
||||
cache_exists = True
|
||||
except FileNotFoundError:
|
||||
cache_exists = False
|
||||
|
||||
assert cache_exists, f'Cache file {cache_filename} should exist'
|
||||
|
||||
# If cache exists, verify its content
|
||||
if cache_exists:
|
||||
cache_data = json.loads(cache_content)
|
||||
assert len(cache_data) == 5, 'Cache page should contain 5 events'
|
||||
|
||||
# Verify each event in the cache
|
||||
for i, event_data in enumerate(cache_data):
|
||||
assert event_data['content'] == f'test{i}', (
|
||||
f"Event {i} content should be 'test{i}'"
|
||||
)
|
||||
|
||||
|
||||
def test_cache_page_loading(temp_dir: str):
|
||||
"""Test that cache pages are loaded correctly when retrieving events."""
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
|
||||
# Create an event stream with a small cache size
|
||||
event_stream = EventStream('cache_load_test', file_store)
|
||||
event_stream.cache_size = 5
|
||||
|
||||
# Add enough events to create multiple cache pages
|
||||
for i in range(15):
|
||||
event_stream.add_event(NullObservation(f'test{i}'), EventSource.AGENT)
|
||||
|
||||
# Create a new event stream to force loading from cache
|
||||
new_stream = EventStream('cache_load_test', file_store)
|
||||
new_stream.cache_size = 5
|
||||
|
||||
# Get all events and verify they're correct
|
||||
events = collect_events(new_stream)
|
||||
|
||||
# Check that we have a reasonable number of events (may not be exactly 15 due to implementation details)
|
||||
assert len(events) > 10, 'Should retrieve most of the events'
|
||||
|
||||
# Verify the events we did get are in the correct order and format
|
||||
for i, event in enumerate(events):
|
||||
assert isinstance(event, NullObservation), (
|
||||
f'Event {i} should be a NullObservation'
|
||||
)
|
||||
assert event.content == f'test{i}', f"Event {i} content should be 'test{i}'"
|
||||
|
||||
|
||||
def test_cache_page_performance(temp_dir: str):
    """Test that using cache pages improves performance when retrieving many events."""
    file_store = get_file_store('local', temp_dir)
    num_events = 50

    def _populate(stream):
        # Fill a stream with num_events observations.
        for idx in range(num_events):
            stream.add_event(NullObservation(f'test{idx}'), EventSource.AGENT)

    def _timed_collect(stream):
        # Return (events, elapsed seconds) for a full retrieval.
        started = time.time()
        collected = collect_events(stream)
        return collected, time.time() - started

    # Two identical streams under different session IDs, both cache-enabled.
    cached_stream = EventStream('perf_test_cached', file_store)
    cached_stream.cache_size = 10
    _populate(cached_stream)

    uncached_stream = EventStream('perf_test_uncached', file_store)
    uncached_stream.cache_size = 10
    _populate(uncached_stream)

    cached_events, cached_time = _timed_collect(cached_stream)
    uncached_events, uncached_time = _timed_collect(uncached_stream)

    # Both must return the bulk of the events regardless of timing.
    assert len(cached_events) > 40, 'Cached stream should return most of the events'
    assert len(uncached_events) > 40, 'Uncached stream should return most of the events'

    # Log the performance difference for manual inspection.
    print(f'Cached time: {cached_time:.4f}s, Uncached time: {uncached_time:.4f}s')

    # We're primarily checking functionality here, not strict performance metrics.
    # In real-world scenarios with many more events, the performance difference
    # would be more significant.
|
||||
|
||||
|
||||
def test_search_events_limit(temp_dir: str):
    """Test that the search_events method correctly applies the limit parameter."""
    file_store = get_file_store('local', temp_dir)
    event_stream = EventStream('abc', file_store)

    for idx in range(10):
        event_stream.add_event(NullObservation(f'test{idx}'), EventSource.AGENT)

    # No limit: every event comes back.
    assert len(list(event_stream.search_events())) == 10

    # limit=5 returns the first five events, in order.
    first_five = list(event_stream.search_events(limit=5))
    assert len(first_five) == 5
    assert all(isinstance(e, NullObservation) for e in first_five)
    assert [e.content for e in first_five] == [
        'test0',
        'test1',
        'test2',
        'test3',
        'test4',
    ]

    # start_id combines with limit: three events starting from ID 5.
    window = list(event_stream.search_events(start_id=5, limit=3))
    assert len(window) == 3
    assert [e.content for e in window] == ['test5', 'test6', 'test7']

    # reverse=True returns the newest events first.
    newest = list(event_stream.search_events(reverse=True, limit=4))
    assert len(newest) == 4
    assert [e.content for e in newest] == ['test9', 'test8', 'test7', 'test6']

    # The limit applies after filtering.
    event_stream.add_event(NullObservation('filter_me'), EventSource.AGENT)
    event_stream.add_event(NullObservation('filter_me_too'), EventSource.AGENT)
    filtered = list(
        event_stream.search_events(filter=EventFilter(query='filter'), limit=1)
    )
    assert len(filtered) == 1
    assert filtered[0].content == 'filter_me'
|
||||
|
||||
|
||||
def test_search_events_limit_with_complex_filters(temp_dir: str):
    """Test the interaction between limit and various filter combinations in search_events."""
    file_store = get_file_store('local', temp_dir)
    event_stream = EventStream('abc', file_store)

    # A mix of sources and event types; IDs are assigned in insertion order 0-6.
    seeded = [
        (NullAction(), EventSource.AGENT),  # id 0
        (NullObservation('test1'), EventSource.AGENT),  # id 1
        (MessageAction(content='hello'), EventSource.USER),  # id 2
        (NullObservation('test2'), EventSource.ENVIRONMENT),  # id 3
        (NullAction(), EventSource.AGENT),  # id 4
        (MessageAction(content='world'), EventSource.USER),  # id 5
        (NullObservation('hello world'), EventSource.AGENT),  # id 6
    ]
    for event, source in seeded:
        event_stream.add_event(event, source)

    # Type filter + limit.
    by_type = list(
        event_stream.search_events(
            filter=EventFilter(include_types=(NullAction,)), limit=1
        )
    )
    assert len(by_type) == 1
    assert isinstance(by_type[0], NullAction)
    assert by_type[0].id == 0

    # Source filter + limit.
    by_source = list(
        event_stream.search_events(filter=EventFilter(source='user'), limit=1)
    )
    assert len(by_source) == 1
    assert by_source[0].source == EventSource.USER
    assert by_source[0].id == 2

    # Query filter + limit.
    by_query = list(
        event_stream.search_events(filter=EventFilter(query='hello'), limit=2)
    )
    assert len(by_query) == 2
    assert [e.id for e in by_query] == [2, 6]

    # Combined source and type filters + limit.
    combined = list(
        event_stream.search_events(
            filter=EventFilter(source='agent', include_types=(NullObservation,)),
            limit=1,
        )
    )
    assert len(combined) == 1
    assert isinstance(combined[0], NullObservation)
    assert combined[0].source == EventSource.AGENT
    assert combined[0].id == 1

    # Reverse ordering + filter + limit.
    reversed_events = list(
        event_stream.search_events(
            filter=EventFilter(source='agent'), reverse=True, limit=2
        )
    )
    assert len(reversed_events) == 2
    assert [e.id for e in reversed_events] == [6, 4]
|
||||
|
||||
|
||||
def test_search_events_limit_edge_cases(temp_dir: str):
    """Test edge cases for the limit parameter in search_events."""
    file_store = get_file_store('local', temp_dir)
    event_stream = EventStream('abc', file_store)

    for idx in range(5):
        event_stream.add_event(NullObservation(f'test{idx}'), EventSource.AGENT)

    # limit=None means "no limit".
    assert len(list(event_stream.search_events(limit=None))) == 5

    # A limit larger than the number of events is harmless.
    assert len(list(event_stream.search_events(limit=10))) == 5

    # limit=0: implementation may treat it as "no events" or "no limit";
    # accept either observed behavior.
    zero_limited = list(event_stream.search_events(limit=0))
    assert len(zero_limited) in [0, 5]

    # Negative limit: implementation returns only the first event.
    assert len(list(event_stream.search_events(limit=-1))) == 1

    # An empty result set stays empty regardless of limit.
    no_match = list(
        event_stream.search_events(filter=EventFilter(query='nonexistent'), limit=5)
    )
    assert len(no_match) == 0

    # A start_id past the last event yields nothing.
    assert len(list(event_stream.search_events(start_id=10, limit=5))) == 0
|
||||
|
||||
|
||||
def test_callback_dictionary_modification(temp_dir: str):
    """Test that the event stream can handle dictionary modification during iteration.

    This test verifies that the fix for the 'dictionary changed size during
    iteration' error works. It registers a callback that subscribes another
    callback while the stream is dispatching, which would raise without the fix.
    """
    file_store = get_file_store('local', temp_dir)
    event_stream = EventStream('callback_test', file_store)

    # One flag per callback so we can see exactly which ones ran.
    callback_executed = [False, False, False]

    def late_callback(event):
        # Registered mid-dispatch by first_callback.
        callback_executed[2] = True

    def first_callback(event):
        callback_executed[0] = True
        # Subscribing here mutates the callback dict during iteration; without
        # the fix this raises "dictionary changed size during iteration".
        event_stream.subscribe(
            EventStreamSubscriber.TEST, late_callback, 'callback3'
        )

    def second_callback(event):
        callback_executed[1] = True

    event_stream.subscribe(EventStreamSubscriber.TEST, first_callback, 'callback1')
    event_stream.subscribe(EventStreamSubscriber.TEST, second_callback, 'callback2')

    # Trigger the two registered callbacks and wait for async dispatch.
    event_stream.add_event(NullObservation('test'), EventSource.AGENT)
    time.sleep(0.5)

    assert callback_executed[0] is True, 'First callback should have been executed'
    assert callback_executed[1] is True, 'Second callback should have been executed'
    # The late subscriber must not fire for the event that registered it.
    assert callback_executed[2] is False, (
        'Third callback should not have been executed for this event'
    )

    # A second event reaches all three callbacks, including the late one.
    callback_executed = [False, False, False]  # Reset execution tracking
    event_stream.add_event(NullObservation('test2'), EventSource.AGENT)
    time.sleep(0.5)

    assert callback_executed[0] is True, 'First callback should have been executed'
    assert callback_executed[1] is True, 'Second callback should have been executed'
    assert callback_executed[2] is True, 'Third callback should have been executed'

    event_stream.close()
|
||||
|
||||
|
||||
def test_cache_page_partial_retrieval(temp_dir: str):
    """Test retrieving events with start_id and end_id parameters using the cache."""
    file_store = get_file_store('local', temp_dir)

    event_stream = EventStream('partial_test', file_store)
    event_stream.cache_size = 5
    for idx in range(20):
        event_stream.add_event(NullObservation(f'test{idx}'), EventSource.AGENT)

    # A slice that spans multiple cache pages.
    forward = list(event_stream.get_events(start_id=3, end_id=12))
    assert len(forward) >= 8, 'Should retrieve most events in the range'

    # Events arrive in ascending order starting at id 3.
    for offset, event in enumerate(forward):
        expected_content = f'test{offset + 3}'
        assert event.content == expected_content, (
            f"Event {offset} content should be '{expected_content}'"
        )

    # The same slice, reversed.
    backward = list(event_stream.get_events(start_id=3, end_id=12, reverse=True))
    assert len(backward) >= 8, 'Should retrieve most events in reverse'

    if len(backward) >= 3:
        # The first reversed event sits near the top of the range
        # (test10..test12), and contents must be strictly descending.
        assert backward[0].content.startswith('test1'), (
            'First reverse event should be near the end of the range'
        )
        assert int(backward[0].content[4:]) > int(backward[1].content[4:]), (
            'Events should be in descending order'
        )
|
||||
|
||||
|
||||
def test_cache_page_with_missing_events(temp_dir: str):
    """Test cache behavior when some events are missing."""
    file_store = get_file_store('local', temp_dir)

    event_stream = EventStream('missing_test', file_store)
    event_stream.cache_size = 5
    for idx in range(10):
        event_stream.add_event(NullObservation(f'test{idx}'), EventSource.AGENT)

    # Reload from disk to establish the baseline event count.
    new_stream = EventStream('missing_test', file_store)
    new_stream.cache_size = 5
    initial_events = list(new_stream.get_events())
    initial_count = len(initial_events)

    # Simulate a lost event by deleting one file from the middle of the stream
    # (not the first or last id, so both neighbors remain).
    missing_id = 5
    missing_filename = new_stream._get_filename_for_id(missing_id, new_stream.user_id)
    try:
        file_store.delete(missing_filename)

        # Force a reload after the deletion.
        reload_stream = EventStream('missing_test', file_store)
        reload_stream.cache_size = 5
        events_after_deletion = list(reload_stream.get_events())

        # The stream must degrade gracefully rather than fail outright.
        assert len(events_after_deletion) <= initial_count, (
            'Should have fewer or equal events after deletion'
        )
        assert len(events_after_deletion) > 0, 'Should still retrieve some events'
    except Exception as e:
        # If the delete operation is unsupported, fall back to a basic sanity check.
        print(f'Note: Could not delete file {missing_filename}: {e}')
        assert len(initial_events) > 0, 'Should retrieve events successfully'
|
||||
|
||||
|
||||
def test_secrets_replaced_in_content(temp_dir: str):
    """Test that secrets are properly replaced in event content."""
    file_store = get_file_store('local', temp_dir)
    stream = EventStream('test_session', file_store)
    stream.set_secrets({'api_key': 'secret123'})

    # Embed the secret value inside a command action.
    action = CmdRunAction(
        command='curl -H "Authorization: Bearer secret123" https://api.example.com'
    )
    action._timestamp = datetime.now().isoformat()

    # Serialize and run the redaction pass.
    data = event_to_dict(action)
    data_with_secrets_replaced = stream._replace_secrets(data)

    # The secret must be masked out of the command string.
    assert '<secret_hidden>' in data_with_secrets_replaced['args']['command']
    assert 'secret123' not in data_with_secrets_replaced['args']['command']
|
||||
|
||||
|
||||
def test_timestamp_not_affected_by_secret_replacement(temp_dir: str):
    """Test that timestamps are not corrupted by secret replacement."""
    file_store = get_file_store('local', temp_dir)
    stream = EventStream('test_session', file_store)

    # The secret value '18' also appears inside the timestamp below (2025-07-18).
    stream.set_secrets({'test_secret': '18'})

    action = CmdRunAction(command='echo "hello world"')
    action._timestamp = '2025-07-18T17:01:36.799608'  # Contains "18"

    data = event_to_dict(action)
    original_timestamp = data['timestamp']
    data_with_secrets_replaced = stream._replace_secrets(data)

    # The timestamp field is protected: it must survive redaction unchanged.
    assert data_with_secrets_replaced['timestamp'] == original_timestamp
    assert '<secret_hidden>' not in data_with_secrets_replaced['timestamp']
    assert '18' in data_with_secrets_replaced['timestamp']  # Original value preserved
|
||||
|
||||
|
||||
def test_protected_fields_not_affected_by_secret_replacement(temp_dir: str):
    """Test that protected system fields are not affected by secret replacement."""
    file_store = get_file_store('local', temp_dir)
    stream = EventStream('test_session', file_store)

    # Each secret value deliberately collides with some protected system field.
    stream.set_secrets(
        {
            'secret1': '123',  # Could appear in ID
            'secret2': 'user',  # Could appear in source
            'secret3': 'run',  # Could appear in action/observation
            'secret4': 'Running',  # Could appear in message
        }
    )

    # Raw event dict mixing protected system fields with ordinary content.
    data = {
        'id': 123,
        'timestamp': '2025-07-18T17:01:36.799608',
        'source': 'user',
        'cause': 123,
        'action': 'run',
        'observation': 'run',
        'message': 'Running command: echo hello',
        'content': 'This contains secret1: 123 and secret2: user and secret3: run',
    }

    data_with_secrets_replaced = stream._replace_secrets(data)

    # Top-level protected fields pass through untouched.
    assert data_with_secrets_replaced['id'] == 123
    assert data_with_secrets_replaced['timestamp'] == '2025-07-18T17:01:36.799608'
    assert data_with_secrets_replaced['source'] == 'user'
    assert data_with_secrets_replaced['cause'] == 123
    assert data_with_secrets_replaced['action'] == 'run'
    assert data_with_secrets_replaced['observation'] == 'run'
    assert data_with_secrets_replaced['message'] == 'Running command: echo hello'

    # Ordinary fields like 'content' are still redacted.
    assert '<secret_hidden>' in data_with_secrets_replaced['content']
    assert '123' not in data_with_secrets_replaced['content']
    assert 'user' not in data_with_secrets_replaced['content']
    # Note: 'run' should still be replaced in content since it's not a protected field
|
||||
|
||||
|
||||
def test_nested_dict_secret_replacement(temp_dir: str):
    """Test that secrets are replaced in nested dictionaries while preserving protected fields."""
    file_store = get_file_store('local', temp_dir)
    stream = EventStream('test_session', file_store)
    stream.set_secrets({'secret': 'password123'})

    # Nested structure with the secret at several depths, plus a nested key
    # named 'timestamp' that must NOT be treated as protected.
    data = {
        'timestamp': '2025-07-18T17:01:36.799608',
        'args': {
            'command': 'login --password password123',
            'env': {
                'SECRET_KEY': 'password123',
                'timestamp': 'password123_timestamp',  # not top-level, so redacted
            },
        },
    }

    data_with_secrets_replaced = stream._replace_secrets(data)

    # Only the top-level timestamp is protected.
    assert data_with_secrets_replaced['timestamp'] == '2025-07-18T17:01:36.799608'

    # Every nested occurrence of the secret is masked...
    assert '<secret_hidden>' in data_with_secrets_replaced['args']['command']
    assert data_with_secrets_replaced['args']['env']['SECRET_KEY'] == '<secret_hidden>'
    assert '<secret_hidden>' in data_with_secrets_replaced['args']['env']['timestamp']

    # ...and the raw value is gone everywhere.
    assert 'password123' not in data_with_secrets_replaced['args']['command']
    assert 'password123' not in data_with_secrets_replaced['args']['env']['SECRET_KEY']
    assert 'password123' not in data_with_secrets_replaced['args']['env']['timestamp']
|
||||
@@ -1,135 +0,0 @@
|
||||
"""Tests for FileEditObservation class."""
|
||||
|
||||
from openhands.events.event import FileEditSource
|
||||
from openhands.events.observation.files import FileEditObservation
|
||||
|
||||
|
||||
def test_file_edit_observation_basic():
    """Test basic properties of FileEditObservation."""
    kwargs = dict(
        path='/test/file.txt',
        prev_exist=True,
        old_content='Hello\nWorld\n',
        new_content='Hello\nNew World\n',
        impl_source=FileEditSource.LLM_BASED_EDIT,
        content='Hello\nWorld\n',  # Initial content is old_content
    )
    obs = FileEditObservation(**kwargs)

    # Constructor arguments round-trip as attributes.
    assert obs.path == '/test/file.txt'
    assert obs.prev_exist is True
    assert obs.old_content == 'Hello\nWorld\n'
    assert obs.new_content == 'Hello\nNew World\n'
    assert obs.impl_source == FileEditSource.LLM_BASED_EDIT
    # The human-readable summary names the edited file.
    assert obs.message == 'I edited the file /test/file.txt.'
|
||||
|
||||
|
||||
def test_file_edit_observation_diff_cache():
    """Test that diff visualization is cached."""
    obs = FileEditObservation(
        path='/test/file.txt',
        prev_exist=True,
        old_content='Hello\nWorld\n',
        new_content='Hello\nNew World\n',
        impl_source=FileEditSource.LLM_BASED_EDIT,
        content='Hello\nWorld\n',  # Initial content is old_content
    )

    # The first call computes the diff and populates the private cache...
    first_render = obs.visualize_diff()
    assert obs._diff_cache is not None

    # ...and a repeat call serves the identical rendering from that cache.
    assert first_render == obs.visualize_diff()
|
||||
|
||||
|
||||
def test_file_edit_observation_no_changes():
    """Test behavior when content hasn't changed."""
    unchanged = 'Hello\nWorld\n'
    obs = FileEditObservation(
        path='/test/file.txt',
        prev_exist=True,
        old_content=unchanged,
        new_content=unchanged,
        impl_source=FileEditSource.LLM_BASED_EDIT,
        content=unchanged,  # Initial content is old_content
    )

    # Identical old/new content produces the "no changes" diff message.
    assert '(no changes detected' in obs.visualize_diff()
|
||||
|
||||
|
||||
def test_file_edit_observation_get_edit_groups():
    """Test the get_edit_groups method."""
    obs = FileEditObservation(
        path='/test/file.txt',
        prev_exist=True,
        old_content='Line 1\nLine 2\nLine 3\nLine 4\n',
        new_content='Line 1\nNew Line 2\nLine 3\nNew Line 4\n',
        impl_source=FileEditSource.LLM_BASED_EDIT,
        content='Line 1\nLine 2\nLine 3\nLine 4\n',  # Initial content is old_content
    )

    groups = obs.get_edit_groups(n_context_lines=1)
    assert len(groups) > 0

    # Every group exposes before/after line lists.
    for group in groups:
        assert 'before_edits' in group
        assert 'after_edits' in group
        assert isinstance(group['before_edits'], list)
        assert isinstance(group['after_edits'], list)

    # The first group covers the 'Line 2' -> 'New Line 2' change.
    first_group = groups[0]
    assert any('Line 2' in line for line in first_group['before_edits'])
    assert any('New Line 2' in line for line in first_group['after_edits'])
|
||||
|
||||
|
||||
def test_file_edit_observation_new_file():
    """Test behavior when editing a new file."""
    obs = FileEditObservation(
        path='/test/new_file.txt',
        prev_exist=False,
        old_content='',
        new_content='Hello\nWorld\n',
        impl_source=FileEditSource.LLM_BASED_EDIT,
        content='',  # Initial content is old_content (empty for new file)
    )

    assert obs.prev_exist is False
    assert obs.old_content == ''

    # New files get a dedicated string form instead of a diff.
    expected = '[New file /test/new_file.txt is created with the provided content.]\n'
    assert str(obs) == expected

    # Visualizing the diff of a brand-new file must still succeed.
    assert obs.visualize_diff() is not None
|
||||
|
||||
|
||||
def test_file_edit_observation_context_lines():
    """Test diff visualization with different context line settings."""
    obs = FileEditObservation(
        path='/test/file.txt',
        prev_exist=True,
        old_content='Line 1\nLine 2\nLine 3\nLine 4\nLine 5\n',
        new_content='Line 1\nNew Line 2\nLine 3\nNew Line 4\nLine 5\n',
        impl_source=FileEditSource.LLM_BASED_EDIT,
        content='Line 1\nLine 2\nLine 3\nLine 4\nLine 5\n',  # Initial content is old_content
    )

    def _total_lines(groups):
        # Count all before+after lines across every edit group.
        return sum(len(g['before_edits']) + len(g['after_edits']) for g in groups)

    narrow = _total_lines(obs.get_edit_groups(n_context_lines=0))
    wide = _total_lines(obs.get_edit_groups(n_context_lines=2))

    # A wider context window must include strictly more surrounding lines.
    assert wide > narrow
|
||||
@@ -1,146 +0,0 @@
|
||||
import json
|
||||
|
||||
from openhands.core.schema import ActionType, ObservationType
|
||||
from openhands.events.action.action import ActionSecurityRisk
|
||||
from openhands.events.action.mcp import MCPAction
|
||||
from openhands.events.observation.mcp import MCPObservation
|
||||
|
||||
|
||||
def test_mcp_action_creation():
    """Test creating an MCPAction."""
    args = {'arg1': 'value1', 'arg2': 42}
    action = MCPAction(name='test_tool', arguments=args)

    # Explicit constructor arguments.
    assert action.name == 'test_tool'
    assert action.arguments == {'arg1': 'value1', 'arg2': 42}
    # Defaults and class-level invariants.
    assert action.action == ActionType.MCP
    assert action.thought == ''
    assert action.runnable is True
    assert action.security_risk == ActionSecurityRisk.UNKNOWN
|
||||
|
||||
|
||||
def test_mcp_action_with_thought():
    """Test creating an MCPAction with a thought."""
    action = MCPAction(
        name='test_tool',
        arguments={'arg1': 'value1', 'arg2': 42},
        thought='This is a test thought',
    )

    # The optional thought is stored alongside the tool call.
    assert action.name == 'test_tool'
    assert action.arguments == {'arg1': 'value1', 'arg2': 42}
    assert action.thought == 'This is a test thought'
|
||||
|
||||
|
||||
def test_mcp_action_message():
    """Test the message property of MCPAction."""
    action = MCPAction(name='test_tool', arguments={'arg1': 'value1', 'arg2': 42})

    message = action.message
    # The message must mention the tool name and every argument key/value.
    for fragment in ('test_tool', 'arg1', 'value1', '42'):
        assert fragment in message
|
||||
|
||||
|
||||
def test_mcp_action_str_representation():
    """Test the string representation of MCPAction."""
    action = MCPAction(
        name='test_tool',
        arguments={'arg1': 'value1', 'arg2': 42},
        thought='This is a test thought',
    )

    text = str(action)
    # The repr labels the class, thought, name, and arguments explicitly.
    expected_fragments = (
        'MCPAction',
        'THOUGHT: This is a test thought',
        'NAME: test_tool',
        'ARGUMENTS:',
        'arg1',
        'value1',
        '42',
    )
    for fragment in expected_fragments:
        assert fragment in text
|
||||
|
||||
|
||||
def test_mcp_observation_creation():
    """Test creating an MCPObservation."""
    payload = json.dumps({'result': 'success', 'data': 'test data'})
    observation = MCPObservation(content=payload)

    # Content is stored verbatim and the observation type is fixed.
    assert observation.content == payload
    assert observation.observation == ObservationType.MCP
|
||||
|
||||
|
||||
def test_mcp_observation_message():
    """Test the message property of MCPObservation."""
    payload = json.dumps({'result': 'success', 'data': 'test data'})
    observation = MCPObservation(content=payload)

    message = observation.message
    # The message is exactly the raw JSON payload...
    assert message == payload
    # ...so every key and value of the payload appears in it.
    for fragment in ('result', 'success', 'data', 'test data'):
        assert fragment in message
|
||||
|
||||
|
||||
def test_mcp_action_with_complex_arguments():
    """Test MCPAction with complex nested arguments."""
    complex_args = {
        'simple_arg': 'value',
        'number_arg': 42,
        'boolean_arg': True,
        'nested_arg': {'inner_key': 'inner_value', 'inner_list': [1, 2, 3]},
        'list_arg': ['a', 'b', 'c'],
    }

    action = MCPAction(name='complex_tool', arguments=complex_args)

    # Nested structures survive construction untouched.
    assert action.name == 'complex_tool'
    assert action.arguments == complex_args
    assert action.arguments['nested_arg']['inner_key'] == 'inner_value'
    assert action.arguments['list_arg'] == ['a', 'b', 'c']

    # The message renders the nested argument structure too.
    message = action.message
    for fragment in ('complex_tool', 'nested_arg', 'inner_key', 'inner_value'):
        assert fragment in message
|
||||
|
||||
|
||||
def test_mcp_observation_with_arguments():
    """Test MCPObservation with arguments."""
    complex_args = {
        'simple_arg': 'value',
        'number_arg': 42,
        'boolean_arg': True,
        'nested_arg': {'inner_key': 'inner_value', 'inner_list': [1, 2, 3]},
        'list_arg': ['a', 'b', 'c'],
    }
    payload = json.dumps({'result': 'success', 'data': 'test data'})

    observation = MCPObservation(
        content=payload,
        name='test_tool',
        arguments=complex_args,
    )

    # Content, type, and the optional name/arguments all round-trip.
    assert observation.content == payload
    assert observation.observation == ObservationType.MCP
    assert observation.name == 'test_tool'
    assert observation.arguments == complex_args
    assert observation.arguments['nested_arg']['inner_key'] == 'inner_value'
    assert observation.arguments['list_arg'] == ['a', 'b', 'c']

    # Round-trip through serialization and verify the 'extras' payload.
    from openhands.events.serialization import event_to_dict

    serialized = event_to_dict(observation)
    assert serialized['observation'] == ObservationType.MCP
    assert serialized['content'] == payload
    assert serialized['extras']['name'] == 'test_tool'
    assert serialized['extras']['arguments'] == complex_args
    assert serialized['extras']['arguments']['nested_arg']['inner_key'] == 'inner_value'
|
||||
@@ -1,437 +0,0 @@
|
||||
"""Unit tests for the NestedEventStore class.
|
||||
|
||||
These tests focus on the search_events method, which retrieves events from a remote API
|
||||
and applies filtering based on various criteria.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.event import EventSource
|
||||
from openhands.events.event_filter import EventFilter
|
||||
from openhands.events.nested_event_store import NestedEventStore
|
||||
|
||||
|
||||
def create_mock_event(
    id: int, content: str, source: str = 'user', hidden: bool = False
) -> dict[str, Any]:
    """Create a properly formatted mock event dictionary."""
    event_dict: dict[str, Any] = {
        'id': id,
        'action': 'message',  # This is required for the event_from_dict function
        'args': {'content': content},
        'source': source,
    }
    # 'hidden' is a property set on the event after deserialization, so the
    # key is only included when it is actually true.
    if hidden:
        event_dict['hidden'] = True
    return event_dict
|
||||
|
||||
|
||||
def create_mock_response(
    events: list[dict[str, Any]], has_more: bool = False
) -> MagicMock:
    """Helper function to create a mock HTTP response."""
    response = MagicMock()
    # Only .json() is stubbed; the store under test reads nothing else.
    response.json.return_value = {'events': events, 'has_more': has_more}
    return response
|
||||
|
||||
|
||||
class TestNestedEventStore:
    """Tests for the NestedEventStore class.

    Every test patches ``httpx.get`` so no network traffic occurs; the remote
    /events endpoint is simulated via mock responses built by the module-level
    helpers.
    """

    @pytest.fixture
    def event_store(self):
        """Create a NestedEventStore instance for testing."""
        return NestedEventStore(
            base_url='http://test-api.example.com',
            sid='test-session',
            user_id='test-user',
            session_api_key='test-api-key',
        )

    @patch('httpx.get')
    def test_search_events_basic(self, mock_get, event_store):
        """Test basic event retrieval without filters."""
        # Setup mock response with two events
        mock_events = [
            create_mock_event(1, 'Hello', 'user'),
            create_mock_event(2, 'World', 'agent'),
        ]
        mock_get.return_value = create_mock_response(mock_events)

        # Call the method
        events = list(event_store.search_events())

        # Verify the results
        assert len(events) == 2
        assert events[0].id == 1
        assert events[1].id == 2

        # Verify the API call
        mock_get.assert_called_once_with(
            'http://test-api.example.com/events?start_id=0&reverse=False',
            headers={'X-Session-API-Key': 'test-api-key'},
        )

    @patch('httpx.get')
    def test_search_events_with_limit(self, mock_get, event_store):
        """Test event retrieval with a limit."""
        # Setup mock response
        mock_events = [
            create_mock_event(1, 'Hello', 'user'),
            create_mock_event(2, 'World', 'agent'),
        ]
        mock_get.return_value = create_mock_response(mock_events)

        # Call the method with a limit
        events = list(event_store.search_events(limit=1))

        # Verify the results (limit is applied even though the server
        # returned two events)
        assert len(events) == 1
        assert events[0].id == 1

        # Verify the API call includes the limit parameter
        mock_get.assert_called_once_with(
            'http://test-api.example.com/events?start_id=0&reverse=False&limit=1',
            headers={'X-Session-API-Key': 'test-api-key'},
        )

    @patch('httpx.get')
    def test_search_events_with_start_id(self, mock_get, event_store):
        """Test event retrieval with a specific start_id."""
        # Setup mock response
        mock_events = [
            create_mock_event(5, 'Hello', 'user'),
            create_mock_event(6, 'World', 'agent'),
        ]
        mock_get.return_value = create_mock_response(mock_events)

        # Call the method with a start_id
        events = list(event_store.search_events(start_id=5))

        # Verify the results
        assert len(events) == 2
        assert events[0].id == 5
        assert events[1].id == 6

        # Verify the API call includes the correct start_id
        mock_get.assert_called_once_with(
            'http://test-api.example.com/events?start_id=5&reverse=False',
            headers={'X-Session-API-Key': 'test-api-key'},
        )

    @patch('httpx.get')
    def test_search_events_reverse_order(self, mock_get, event_store):
        """Test event retrieval in reverse order."""
        # Setup mock response (server already returns events in reverse)
        mock_events = [
            create_mock_event(3, 'World', 'agent'),
            create_mock_event(2, 'Hello', 'user'),
        ]
        mock_get.return_value = create_mock_response(mock_events)

        # Call the method with reverse=True
        events = list(event_store.search_events(reverse=True))

        # Verify the results
        assert len(events) == 2
        assert events[0].id == 3
        assert events[1].id == 2

        # Verify the API call includes reverse=True
        mock_get.assert_called_once_with(
            'http://test-api.example.com/events?start_id=0&reverse=True',
            headers={'X-Session-API-Key': 'test-api-key'},
        )

    @patch('httpx.get')
    def test_search_events_with_end_id(self, mock_get, event_store):
        """Test event retrieval with a specific end_id."""
        # Setup mock response
        mock_events = [
            create_mock_event(1, 'Hello', 'user'),
            create_mock_event(2, 'World', 'agent'),
            create_mock_event(3, 'End', 'user'),
        ]
        mock_get.return_value = create_mock_response(mock_events)

        # Call the method with an end_id
        events = list(event_store.search_events(end_id=3))

        # Verify the results
        assert len(events) == 3
        assert events[0].id == 1
        assert events[1].id == 2
        assert events[2].id == 3

        # Verify the API call
        # NOTE(review): the asserted URL carries no end_id parameter, so
        # end_id is presumably applied client-side — confirm against the
        # NestedEventStore implementation.
        mock_get.assert_called_once_with(
            'http://test-api.example.com/events?start_id=0&reverse=False',
            headers={'X-Session-API-Key': 'test-api-key'},
        )

    @patch('httpx.get')
    @patch('openhands.events.event_filter.EventFilter.exclude')
    def test_search_events_with_filter(self, mock_exclude, mock_get, event_store):
        """Test event retrieval with an EventFilter."""
        # Setup mock response with mixed events
        mock_events = [
            create_mock_event(1, 'Hello', 'user'),
            create_mock_event(2, 'World', 'agent'),
            create_mock_event(3, 'Hidden', 'user'),
        ]
        mock_get.return_value = create_mock_response(mock_events)

        # Configure the mock to exclude the third event (simulating exclude_hidden=True)
        # Return True for the third event (to exclude it) and False for others
        mock_exclude.side_effect = [False, False, True]

        # Create a filter (the actual implementation doesn't matter since we're mocking exclude)
        event_filter = EventFilter()

        # Call the method with the filter
        events = list(event_store.search_events(filter=event_filter))

        # Verify the results (should exclude the third event)
        assert len(events) == 2
        assert events[0].id == 1
        assert events[1].id == 2

        # Verify the API call
        mock_get.assert_called_once_with(
            'http://test-api.example.com/events?start_id=0&reverse=False',
            headers={'X-Session-API-Key': 'test-api-key'},
        )

    @patch('httpx.get')
    def test_search_events_with_source_filter(self, mock_get, event_store):
        """Test event retrieval with a source filter."""
        # Setup mock response with mixed sources
        mock_events = [
            create_mock_event(1, 'Hello', 'user'),
            create_mock_event(2, 'World', 'agent'),
            create_mock_event(3, 'Another', 'user'),
        ]
        mock_get.return_value = create_mock_response(mock_events)

        # Create a filter to include only user events
        event_filter = EventFilter(source='user')

        # Call the method with the filter
        events = list(event_store.search_events(filter=event_filter))

        # Verify the results (should only include user events)
        assert len(events) == 2
        assert events[0].id == 1
        assert events[0].source == EventSource.USER
        assert events[1].id == 3
        assert events[1].source == EventSource.USER

        # Verify the API call
        mock_get.assert_called_once_with(
            'http://test-api.example.com/events?start_id=0&reverse=False',
            headers={'X-Session-API-Key': 'test-api-key'},
        )

    @patch('httpx.get')
    def test_search_events_with_type_filter(self, mock_get, event_store):
        """Test event retrieval with a type filter."""
        # Setup mock response with different event types
        mock_events = [
            create_mock_event(1, 'Hello', 'user'),
            # Create a different type of event (read)
            {
                'id': 2,
                'action': 'read',  # Using the correct ActionType.READ value
                'args': {
                    'path': '/test/file.txt',
                },
                'source': 'agent',
            },
            create_mock_event(3, 'Another', 'user'),
        ]
        mock_get.return_value = create_mock_response(mock_events)

        # Create a filter to include only MessageAction events
        event_filter = EventFilter(include_types=(MessageAction,))

        # Call the method with the filter
        events = list(event_store.search_events(filter=event_filter))

        # Verify the results (should only include MessageAction events)
        assert len(events) == 2
        assert events[0].id == 1
        assert events[1].id == 3

        # Verify the API call
        mock_get.assert_called_once_with(
            'http://test-api.example.com/events?start_id=0&reverse=False',
            headers={'X-Session-API-Key': 'test-api-key'},
        )

    @patch('httpx.get')
    def test_search_events_pagination(self, mock_get, event_store):
        """Test event retrieval with pagination (has_more=True)."""
        # Setup first page response
        first_page_events = [
            create_mock_event(1, 'Hello', 'user'),
            create_mock_event(2, 'World', 'agent'),
        ]
        first_response = create_mock_response(first_page_events, has_more=True)

        # Setup second page response
        second_page_events = [
            create_mock_event(3, 'More', 'user'),
            create_mock_event(4, 'Data', 'agent'),
        ]
        second_response = create_mock_response(second_page_events, has_more=False)

        # Configure mock to return different responses on consecutive calls
        mock_get.side_effect = [first_response, second_response]

        # Call the method
        events = list(event_store.search_events())

        # Verify the results (should include all events from both pages)
        assert len(events) == 4
        assert events[0].id == 1
        assert events[1].id == 2
        assert events[2].id == 3
        assert events[3].id == 4

        # Verify the API calls
        assert mock_get.call_count == 2
        # First call with start_id=0
        mock_get.assert_any_call(
            'http://test-api.example.com/events?start_id=0&reverse=False',
            headers={'X-Session-API-Key': 'test-api-key'},
        )
        # Second call with start_id=3 (after processing events with IDs 1 and 2)
        mock_get.assert_any_call(
            'http://test-api.example.com/events?start_id=3&reverse=False',
            headers={'X-Session-API-Key': 'test-api-key'},
        )

    @patch('httpx.get')
    def test_search_events_no_session_api_key(self, mock_get):
        """Test event retrieval without a session API key."""
        # Create event store without session_api_key
        event_store = NestedEventStore(
            base_url='http://test-api.example.com',
            sid='test-session',
            user_id='test-user',
        )

        # Setup mock response
        mock_events = [create_mock_event(1, 'Hello', 'user')]
        mock_get.return_value = create_mock_response(mock_events)

        # Call the method
        events = list(event_store.search_events())

        # Verify the results
        assert len(events) == 1

        # Verify the API call has no headers
        mock_get.assert_called_once_with(
            'http://test-api.example.com/events?start_id=0&reverse=False', headers={}
        )

    @patch('httpx.get')
    def test_search_events_with_query_filter(self, mock_get, event_store):
        """Test event retrieval with a text query filter."""
        # Setup mock response with different content
        mock_events = [
            create_mock_event(1, 'Hello world', 'user'),
            create_mock_event(2, 'Python is great', 'agent'),
            create_mock_event(3, 'Hello Python', 'user'),
        ]
        mock_get.return_value = create_mock_response(mock_events)

        # Create a filter to search for 'Python'
        event_filter = EventFilter(query='Python')

        # Call the method with the filter
        events = list(event_store.search_events(filter=event_filter))

        # Verify the results (should only include events with 'Python' in content)
        assert len(events) == 2
        assert events[0].id == 2
        assert events[1].id == 3

        # Verify the API call
        mock_get.assert_called_once_with(
            'http://test-api.example.com/events?start_id=0&reverse=False',
            headers={'X-Session-API-Key': 'test-api-key'},
        )

    @patch('httpx.get')
    def test_search_events_reverse_pagination_multiple_pages(
        self, mock_get, event_store
    ):
        """Ensure reverse pagination works across multiple server pages.

        We emulate the remote /events endpoint by using an in-memory EventStream as the
        backing store and having httpx.get return paginated JSON responses derived from it.
        """
        from urllib.parse import parse_qs, urlparse

        from openhands.events.event import EventSource
        from openhands.events.observation import NullObservation
        from openhands.events.serialization.event import event_to_dict
        from openhands.events.stream import EventStream
        from openhands.storage.memory import InMemoryFileStore

        # Build a fake server-side store with many events so that server pagination kicks in
        fs = InMemoryFileStore()
        server_stream = EventStream('test-session', fs, user_id='test-user')
        total_events = 50
        for i in range(total_events):
            server_stream.add_event(NullObservation(f'e{i}'), EventSource.AGENT)

        def server_side_get(url: str, headers: dict | None = None):
            # Parse query params like the FastAPI layer would receive
            parsed = urlparse(url)
            qs = parse_qs(parsed.query)
            start_id = int(qs.get('start_id', ['0'])[0])
            reverse = qs.get('reverse', ['False'])[0] == 'True'
            end_id = int(qs['end_id'][0]) if 'end_id' in qs else None
            limit = int(qs.get('limit', ['20'])[0])  # server default = 20

            # Emulate server route logic: request limit+1 to compute has_more
            events = list(
                server_stream.search_events(
                    start_id=start_id, end_id=end_id, reverse=reverse, limit=limit + 1
                )
            )
            has_more = len(events) > limit
            if has_more:
                events = events[:limit]

            mock_response = MagicMock()
            mock_response.status_code = 200
            mock_response.json.return_value = {
                'events': [event_to_dict(e) for e in events],
                'has_more': has_more,
            }
            return mock_response

        mock_get.side_effect = server_side_get

        # Execute the nested search in reverse without a client-side limit
        results = list(event_store.search_events(reverse=True))

        # Verify we received all events in descending order
        assert len(results) == total_events
        assert [e.id for e in results] == list(range(total_events - 1, -1, -1))

        # Ensure multiple HTTP calls were made due to pagination
        assert mock_get.call_count >= 2
|
||||
@@ -1,496 +0,0 @@
|
||||
from openhands.core.schema.observation import ObservationType
|
||||
from openhands.events.action.files import FileEditSource
|
||||
from openhands.events.observation import (
|
||||
CmdOutputMetadata,
|
||||
CmdOutputObservation,
|
||||
FileEditObservation,
|
||||
Observation,
|
||||
RecallObservation,
|
||||
)
|
||||
from openhands.events.observation.agent import MicroagentKnowledge
|
||||
from openhands.events.observation.commands import MAX_CMD_OUTPUT_SIZE
|
||||
from openhands.events.recall_type import RecallType
|
||||
from openhands.events.serialization import (
|
||||
event_from_dict,
|
||||
event_to_dict,
|
||||
event_to_trajectory,
|
||||
)
|
||||
from openhands.events.serialization.observation import observation_from_dict
|
||||
|
||||
|
||||
def serialization_deserialization(
    original_observation_dict, cls, max_message_chars: int = 10000
):
    """Round-trip a raw observation dict and check it survives unchanged.

    Deserializes the dict, verifies the resulting type, then re-serializes
    via both event_to_dict and event_to_trajectory and compares each result
    against the original input.
    """
    event = event_from_dict(original_observation_dict)
    assert isinstance(event, Observation), (
        'The observation instance should be an instance of Observation.'
    )
    assert isinstance(event, cls), (
        f'The observation instance should be an instance of {cls}.'
    )
    round_tripped = event_to_dict(event)
    trajectory_form = event_to_trajectory(event)
    assert round_tripped == original_observation_dict, (
        'The serialized observation should match the original observation dict.'
    )
    assert trajectory_form == original_observation_dict, (
        'The serialized observation trajectory should match the original observation dict.'
    )
|
||||
|
||||
|
||||
# Additional tests for various observation subclasses can be included here
|
||||
def test_observation_event_props_serialization_deserialization():
    """Round-trip a CmdOutputObservation dict that also carries event-level
    properties (id, source, timestamp) alongside the observation payload."""
    original_observation_dict = {
        'id': 42,
        'source': 'agent',
        'timestamp': '2021-08-01T12:00:00',
        'observation': 'run',
        'message': 'Command `ls -l` executed with exit code 0.',
        'extras': {
            'command': 'ls -l',
            'hidden': False,
            'metadata': {
                'exit_code': 0,
                'hostname': None,
                'pid': -1,
                'prefix': '',
                'py_interpreter_path': None,
                'suffix': '',
                'username': None,
                'working_dir': None,
            },
        },
        'content': 'foo.txt',
        'success': True,
    }
    serialization_deserialization(original_observation_dict, CmdOutputObservation)
|
||||
|
||||
|
||||
def test_command_output_observation_serialization_deserialization():
    """Round-trip a CmdOutputObservation dict with the nested metadata form
    (no event-level id/source/timestamp keys)."""
    original_observation_dict = {
        'observation': 'run',
        'extras': {
            'command': 'ls -l',
            'hidden': False,
            'metadata': {
                'exit_code': 0,
                'hostname': None,
                'pid': -1,
                'prefix': '',
                'py_interpreter_path': None,
                'suffix': '',
                'username': None,
                'working_dir': None,
            },
        },
        'message': 'Command `ls -l` executed with exit code 0.',
        'content': 'foo.txt',
        'success': True,
    }
    serialization_deserialization(original_observation_dict, CmdOutputObservation)
|
||||
|
||||
|
||||
def test_success_field_serialization():
    """The serialized 'success' flag must mirror the command's exit code:
    0 serializes as True, non-zero as False."""
    cases = [
        ('Command succeeded', 0, True),
        ('No such file or directory', 1, False),
    ]
    for content, exit_code, expected in cases:
        obs = CmdOutputObservation(
            content=content,
            command='ls -l',
            metadata=CmdOutputMetadata(
                exit_code=exit_code,
            ),
        )
        serialized = event_to_dict(obs)
        assert serialized['success'] is expected
|
||||
|
||||
|
||||
def test_cmd_output_truncation():
    """Test that large command outputs are truncated during initialization."""
    # Create a large content string that exceeds MAX_CMD_OUTPUT_SIZE (50000 characters)
    large_content = 'a' * 60000  # 60k characters

    # Create a CmdOutputObservation with the large content
    obs = CmdOutputObservation(
        content=large_content,
        command='ls -R',
        metadata=CmdOutputMetadata(
            exit_code=0,
        ),
    )

    # Verify the content was truncated
    assert len(obs.content) < 60000

    # The truncated content might be slightly larger than MAX_CMD_OUTPUT_SIZE
    # due to the added truncation message
    truncation_msg = '[... Observation truncated due to length ...]'
    assert truncation_msg in obs.content

    # The truncation algorithm might add a few extra characters due to the truncation message
    # We'll allow a small margin (1% of MAX_CMD_OUTPUT_SIZE) for the total content length
    margin = int(MAX_CMD_OUTPUT_SIZE * 0.01)  # 1% margin
    assert len(obs.content) <= MAX_CMD_OUTPUT_SIZE + margin

    # Verify the beginning and end of the content are preserved
    # (head-and-tail truncation: the middle is what gets dropped)
    half_size = MAX_CMD_OUTPUT_SIZE // 2
    assert obs.content.startswith('a' * half_size)
    assert obs.content.endswith('a' * half_size)
|
||||
|
||||
|
||||
def test_cmd_output_no_truncation():
    """Content below MAX_CMD_OUTPUT_SIZE (50000 chars) must pass through
    unmodified — no truncation marker, no length change."""
    # 1k characters is far below the truncation threshold.
    payload = 'a' * 1000
    observation = CmdOutputObservation(
        content=payload,
        command='ls',
        metadata=CmdOutputMetadata(
            exit_code=0,
        ),
    )
    assert observation.content == payload
    assert len(observation.content) == 1000
|
||||
|
||||
|
||||
def test_legacy_serialization():
    """A legacy dict with flat extras (exit_code/command_id instead of a
    nested metadata object) still deserializes into a CmdOutputObservation
    and re-serializes in the current nested-metadata shape."""
    original_observation_dict = {
        'id': 42,
        'source': 'agent',
        'timestamp': '2021-08-01T12:00:00',
        'observation': 'run',
        'message': 'Command `ls -l` executed with exit code 0.',
        'extras': {
            'command': 'ls -l',
            'hidden': False,
            'exit_code': 0,
            'command_id': 3,
        },
        'content': 'foo.txt',
        'success': True,
    }
    event = event_from_dict(original_observation_dict)
    assert isinstance(event, Observation)
    assert isinstance(event, CmdOutputObservation)
    assert event.metadata.exit_code == 0
    assert event.success is True
    assert event.command == 'ls -l'
    assert event.hidden is False

    event_dict = event_to_dict(event)
    assert event_dict['success'] is True
    assert event_dict['extras']['metadata']['exit_code'] == 0
    # the legacy command_id value ends up as metadata.pid
    assert event_dict['extras']['metadata']['pid'] == 3
    assert event_dict['extras']['command'] == 'ls -l'
    assert event_dict['extras']['hidden'] is False
|
||||
|
||||
|
||||
def test_file_edit_observation_serialization():
    """Round-trip a FileEditObservation dict for an edit to an existing file."""
    original_observation_dict = {
        'observation': 'edit',
        'extras': {
            '_diff_cache': None,
            'impl_source': FileEditSource.LLM_BASED_EDIT,
            'new_content': None,
            'old_content': None,
            'path': '',
            'prev_exist': False,
            'diff': None,
        },
        'message': 'I edited the file .',
        'content': '[Existing file /path/to/file.txt is edited with 1 changes.]',
    }
    serialization_deserialization(original_observation_dict, FileEditObservation)
|
||||
|
||||
|
||||
def test_file_edit_observation_new_file_serialization():
    """Round-trip a FileEditObservation dict for a newly created file."""
    original_observation_dict = {
        'observation': 'edit',
        'content': '[New file /path/to/newfile.txt is created with the provided content.]',
        'extras': {
            '_diff_cache': None,
            'impl_source': FileEditSource.LLM_BASED_EDIT,
            'new_content': None,
            'old_content': None,
            'path': '',
            'prev_exist': False,
            'diff': None,
        },
        'message': 'I edited the file .',
    }

    serialization_deserialization(original_observation_dict, FileEditObservation)
|
||||
|
||||
|
||||
def test_file_edit_observation_oh_aci_serialization():
    """Round-trip a FileEditObservation dict with OH-ACI style content.

    NOTE(review): the test name says oh_aci but impl_source is
    FileEditSource.LLM_BASED_EDIT — confirm whether this is intentional or
    should be FileEditSource.OH_ACI.
    """
    original_observation_dict = {
        'observation': 'edit',
        'content': 'The file /path/to/file.txt is edited with the provided content.',
        'extras': {
            '_diff_cache': None,
            'impl_source': FileEditSource.LLM_BASED_EDIT,
            'new_content': None,
            'old_content': None,
            'path': '',
            'prev_exist': False,
            'diff': None,
        },
        'message': 'I edited the file .',
    }
    serialization_deserialization(original_observation_dict, FileEditObservation)
|
||||
|
||||
|
||||
def test_file_edit_observation_legacy_serialization():
    """A legacy oh_aci dict carrying the obsolete formatted_output_and_error
    field deserializes cleanly, and that field is dropped on both the event
    and the re-serialized dict."""
    original_observation_dict = {
        'observation': 'edit',
        'content': 'content',
        'extras': {
            'path': '/workspace/game_2048.py',
            'prev_exist': False,
            'old_content': None,
            'new_content': 'new content',
            'impl_source': 'oh_aci',
            'formatted_output_and_error': 'File created successfully at: /workspace/game_2048.py',
        },
    }

    event = event_from_dict(original_observation_dict)
    assert isinstance(event, Observation)
    assert isinstance(event, FileEditObservation)
    assert event.impl_source == FileEditSource.OH_ACI
    assert event.path == '/workspace/game_2048.py'
    assert event.prev_exist is False
    assert event.old_content is None
    assert event.new_content == 'new content'
    # the legacy-only field must not survive deserialization
    assert not hasattr(event, 'formatted_output_and_error')

    event_dict = event_to_dict(event)
    assert event_dict['extras']['impl_source'] == 'oh_aci'
    assert event_dict['extras']['path'] == '/workspace/game_2048.py'
    assert event_dict['extras']['prev_exist'] is False
    assert event_dict['extras']['old_content'] is None
    assert event_dict['extras']['new_content'] == 'new content'
    assert 'formatted_output_and_error' not in event_dict['extras']
|
||||
|
||||
|
||||
def test_microagent_observation_serialization():
    """Round-trip a RecallObservation dict of the workspace_context type
    (environment info populated, no microagent knowledge)."""
    original_observation_dict = {
        'observation': 'recall',
        'content': '',
        'message': 'Added workspace context',
        'extras': {
            'recall_type': 'workspace_context',
            'repo_name': 'some_repo_name',
            'repo_directory': 'some_repo_directory',
            'repo_branch': '',
            'working_dir': '',
            'runtime_hosts': {'host1': 8080, 'host2': 8081},
            'repo_instructions': 'complex_repo_instructions',
            'additional_agent_instructions': 'You know it all about this runtime',
            'custom_secrets_descriptions': {'SECRET': 'CUSTOM'},
            'date': '04/12/1023',
            'microagent_knowledge': [],
            'conversation_instructions': 'additional_context',
        },
    }
    serialization_deserialization(original_observation_dict, RecallObservation)
|
||||
|
||||
|
||||
def test_microagent_observation_microagent_knowledge_serialization():
    """Round-trip a RecallObservation dict of the knowledge type
    (microagent_knowledge populated, environment fields empty)."""
    original_observation_dict = {
        'observation': 'recall',
        'content': '',
        'message': 'Added microagent knowledge',
        'extras': {
            'recall_type': 'knowledge',
            'repo_name': '',
            'repo_directory': '',
            'repo_branch': '',
            'repo_instructions': '',
            'runtime_hosts': {},
            'working_dir': '',
            'additional_agent_instructions': '',
            'custom_secrets_descriptions': {},
            'conversation_instructions': 'additional_context',
            'date': '',
            'microagent_knowledge': [
                {
                    'name': 'microagent1',
                    'trigger': 'trigger1',
                    'content': 'content1',
                },
                {
                    'name': 'microagent2',
                    'trigger': 'trigger2',
                    'content': 'content2',
                },
            ],
        },
    }
    serialization_deserialization(original_observation_dict, RecallObservation)
|
||||
|
||||
|
||||
def test_microagent_observation_knowledge_microagent_serialization():
|
||||
"""Test serialization of a RecallObservation with KNOWLEDGE_MICROAGENT type."""
|
||||
# Create a RecallObservation with microagent knowledge content
|
||||
original = RecallObservation(
|
||||
content='Knowledge microagent information',
|
||||
recall_type=RecallType.KNOWLEDGE,
|
||||
repo_branch='',
|
||||
microagent_knowledge=[
|
||||
MicroagentKnowledge(
|
||||
name='python_best_practices',
|
||||
trigger='python',
|
||||
content='Always use virtual environments for Python projects.',
|
||||
),
|
||||
MicroagentKnowledge(
|
||||
name='git_workflow',
|
||||
trigger='git',
|
||||
content='Create a new branch for each feature or bugfix.',
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Serialize to dictionary
|
||||
serialized = event_to_dict(original)
|
||||
|
||||
# Verify serialized data structure
|
||||
assert serialized['observation'] == ObservationType.RECALL
|
||||
assert serialized['content'] == 'Knowledge microagent information'
|
||||
assert serialized['extras']['recall_type'] == RecallType.KNOWLEDGE.value
|
||||
assert len(serialized['extras']['microagent_knowledge']) == 2
|
||||
assert serialized['extras']['microagent_knowledge'][0]['trigger'] == 'python'
|
||||
|
||||
# Deserialize back to RecallObservation
|
||||
deserialized = observation_from_dict(serialized)
|
||||
|
||||
# Verify properties are preserved
|
||||
assert deserialized.recall_type == RecallType.KNOWLEDGE
|
||||
assert deserialized.microagent_knowledge == original.microagent_knowledge
|
||||
assert deserialized.content == original.content
|
||||
|
||||
# Check that environment info fields are empty
|
||||
assert deserialized.repo_name == ''
|
||||
assert deserialized.repo_directory == ''
|
||||
assert deserialized.repo_instructions == ''
|
||||
assert deserialized.runtime_hosts == {}
|
||||
|
||||
|
||||
def test_microagent_observation_environment_serialization():
|
||||
"""Test serialization of a RecallObservation with ENVIRONMENT type."""
|
||||
# Create a RecallObservation with environment info
|
||||
original = RecallObservation(
|
||||
content='Environment information',
|
||||
recall_type=RecallType.WORKSPACE_CONTEXT,
|
||||
repo_name='OpenHands',
|
||||
repo_directory='/workspace/openhands',
|
||||
repo_branch='main',
|
||||
repo_instructions="Follow the project's coding style guide.",
|
||||
runtime_hosts={'127.0.0.1': 8080, 'localhost': 5000},
|
||||
additional_agent_instructions='You know it all about this runtime',
|
||||
)
|
||||
|
||||
# Serialize to dictionary
|
||||
serialized = event_to_dict(original)
|
||||
|
||||
# Verify serialized data structure
|
||||
assert serialized['observation'] == ObservationType.RECALL
|
||||
assert serialized['content'] == 'Environment information'
|
||||
assert serialized['extras']['recall_type'] == RecallType.WORKSPACE_CONTEXT.value
|
||||
assert serialized['extras']['repo_name'] == 'OpenHands'
|
||||
assert serialized['extras']['runtime_hosts'] == {
|
||||
'127.0.0.1': 8080,
|
||||
'localhost': 5000,
|
||||
}
|
||||
assert (
|
||||
serialized['extras']['additional_agent_instructions']
|
||||
== 'You know it all about this runtime'
|
||||
)
|
||||
# Deserialize back to RecallObservation
|
||||
deserialized = observation_from_dict(serialized)
|
||||
|
||||
# Verify properties are preserved
|
||||
assert deserialized.recall_type == RecallType.WORKSPACE_CONTEXT
|
||||
assert deserialized.repo_name == original.repo_name
|
||||
assert deserialized.repo_directory == original.repo_directory
|
||||
assert deserialized.repo_instructions == original.repo_instructions
|
||||
assert deserialized.runtime_hosts == original.runtime_hosts
|
||||
assert (
|
||||
deserialized.additional_agent_instructions
|
||||
== original.additional_agent_instructions
|
||||
)
|
||||
# Check that knowledge microagent fields are empty
|
||||
assert deserialized.microagent_knowledge == []
|
||||
|
||||
|
||||
def test_microagent_observation_combined_serialization():
    """Round-trip a RecallObservation carrying environment AND knowledge data."""
    # Note: in practice recall_type is still one specific value, but the
    # observation object may populate both families of fields at once.
    knowledge = [
        MicroagentKnowledge(
            name='python_best_practices',
            trigger='python',
            content='Always use virtual environments for Python projects.',
        )
    ]
    original = RecallObservation(
        content='Combined information',
        recall_type=RecallType.WORKSPACE_CONTEXT,
        # Environment info
        repo_name='OpenHands',
        repo_directory='/workspace/openhands',
        repo_branch='main',
        repo_instructions="Follow the project's coding style guide.",
        runtime_hosts={'127.0.0.1': 8080},
        additional_agent_instructions='You know it all about this runtime',
        # Knowledge microagent info
        microagent_knowledge=knowledge,
    )

    serialized = event_to_dict(original)
    extras = serialized['extras']

    # Both field families must survive serialization side by side.
    assert extras['recall_type'] == RecallType.WORKSPACE_CONTEXT.value
    assert extras['repo_name'] == 'OpenHands'
    assert extras['microagent_knowledge'][0]['name'] == 'python_best_practices'
    assert (
        extras['additional_agent_instructions']
        == 'You know it all about this runtime'
    )

    restored = observation_from_dict(serialized)
    assert restored.recall_type == RecallType.WORKSPACE_CONTEXT

    # Environment attributes round-trip unchanged.
    for attr in (
        'repo_name',
        'repo_directory',
        'repo_instructions',
        'runtime_hosts',
        'additional_agent_instructions',
    ):
        assert getattr(restored, attr) == getattr(original, attr)

    # Knowledge entries round-trip unchanged as well.
    assert restored.microagent_knowledge == original.microagent_knowledge