mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
23 Commits
api/conver
...
fix-events
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
34ad595fcd | ||
|
|
10ae123349 | ||
|
|
9fac84170f | ||
|
|
4469278434 | ||
|
|
ccd00be2e7 | ||
|
|
667f373a54 | ||
|
|
670ee91a41 | ||
|
|
3153c9d9ba | ||
|
|
7a37d99731 | ||
|
|
ffd7c231ad | ||
|
|
67677e6b49 | ||
|
|
c6c327d26d | ||
|
|
2575df8696 | ||
|
|
8aa628ba34 | ||
|
|
db47d6e78b | ||
|
|
b36254b7b4 | ||
|
|
a974164397 | ||
|
|
592aca05e1 | ||
|
|
d309455733 | ||
|
|
66a7920539 | ||
|
|
64ebef3646 | ||
|
|
7a259915c1 | ||
|
|
66bd8fdbcd |
@@ -8,7 +8,7 @@
|
||||
* - Please do NOT serve this file on production.
|
||||
*/
|
||||
|
||||
const PACKAGE_VERSION = '2.7.0'
|
||||
const PACKAGE_VERSION = '2.7.3'
|
||||
const INTEGRITY_CHECKSUM = '00729d72e3b82faf54ca8b9621dbb96f'
|
||||
const IS_MOCKED_RESPONSE = Symbol('isMockedResponse')
|
||||
const activeClientIds = new Set()
|
||||
|
||||
@@ -17,12 +17,12 @@ class MessageAction(Action):
|
||||
return self.content
|
||||
|
||||
@property
|
||||
def images_urls(self):
|
||||
def images_urls(self) -> list[str] | None:
|
||||
# Deprecated alias for backward compatibility
|
||||
return self.image_urls
|
||||
|
||||
@images_urls.setter
|
||||
def images_urls(self, value):
|
||||
def images_urls(self, value: list[str] | None) -> None:
|
||||
self.image_urls = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@@ -29,19 +29,23 @@ class Event:
|
||||
@property
|
||||
def message(self) -> str | None:
|
||||
if hasattr(self, '_message'):
|
||||
return self._message # type: ignore[attr-defined]
|
||||
msg = getattr(self, '_message')
|
||||
return str(msg) if msg is not None else None
|
||||
return ''
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
if hasattr(self, '_id'):
|
||||
return self._id # type: ignore[attr-defined]
|
||||
id_val = getattr(self, '_id')
|
||||
return int(id_val) if id_val is not None else Event.INVALID_ID
|
||||
return Event.INVALID_ID
|
||||
|
||||
@property
|
||||
def timestamp(self):
|
||||
def timestamp(self) -> str | None:
|
||||
if hasattr(self, '_timestamp') and isinstance(self._timestamp, str):
|
||||
return self._timestamp
|
||||
ts = getattr(self, '_timestamp')
|
||||
return str(ts) if ts is not None else None
|
||||
return None
|
||||
|
||||
@timestamp.setter
|
||||
def timestamp(self, value: datetime) -> None:
|
||||
@@ -51,19 +55,22 @@ class Event:
|
||||
@property
|
||||
def source(self) -> EventSource | None:
|
||||
if hasattr(self, '_source'):
|
||||
return self._source # type: ignore[attr-defined]
|
||||
src = getattr(self, '_source')
|
||||
return EventSource(src) if src is not None else None
|
||||
return None
|
||||
|
||||
@property
|
||||
def cause(self) -> int | None:
|
||||
if hasattr(self, '_cause'):
|
||||
return self._cause # type: ignore[attr-defined]
|
||||
cause_val = getattr(self, '_cause')
|
||||
return int(cause_val) if cause_val is not None else None
|
||||
return None
|
||||
|
||||
@property
|
||||
def timeout(self) -> int | None:
|
||||
if hasattr(self, '_timeout'):
|
||||
return self._timeout # type: ignore[attr-defined]
|
||||
timeout_val = getattr(self, '_timeout')
|
||||
return int(timeout_val) if timeout_val is not None else None
|
||||
return None
|
||||
|
||||
def set_hard_timeout(self, value: int | None, blocking: bool = True) -> None:
|
||||
@@ -90,7 +97,8 @@ class Event:
|
||||
@property
|
||||
def llm_metrics(self) -> Metrics | None:
|
||||
if hasattr(self, '_llm_metrics'):
|
||||
return self._llm_metrics # type: ignore[attr-defined]
|
||||
metrics = getattr(self, '_llm_metrics')
|
||||
return metrics if isinstance(metrics, Metrics) else None
|
||||
return None
|
||||
|
||||
@llm_metrics.setter
|
||||
@@ -101,7 +109,8 @@ class Event:
|
||||
@property
|
||||
def tool_call_metadata(self) -> ToolCallMetadata | None:
|
||||
if hasattr(self, '_tool_call_metadata'):
|
||||
return self._tool_call_metadata # type: ignore[attr-defined]
|
||||
metadata = getattr(self, '_tool_call_metadata')
|
||||
return metadata if isinstance(metadata, ToolCallMetadata) else None
|
||||
return None
|
||||
|
||||
@tool_call_metadata.setter
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from browsergym.utils.obs import flatten_axtree_to_str
|
||||
from browsergym.utils.obs import flatten_axtree_to_str # type: ignore
|
||||
|
||||
from openhands.core.schema import ActionType, ObservationType
|
||||
from openhands.events.observation.observation import Observation
|
||||
@@ -16,13 +17,17 @@ class BrowserOutputObservation(Observation):
|
||||
set_of_marks: str = field(default='', repr=False) # don't show in repr
|
||||
error: bool = False
|
||||
observation: str = ObservationType.BROWSE
|
||||
goal_image_urls: list = field(default_factory=list)
|
||||
goal_image_urls: list[str] = field(default_factory=list)
|
||||
# do not include in the memory
|
||||
open_pages_urls: list = field(default_factory=list)
|
||||
open_pages_urls: list[str] = field(default_factory=list)
|
||||
active_page_index: int = -1
|
||||
dom_object: dict = field(default_factory=dict, repr=False) # don't show in repr
|
||||
axtree_object: dict = field(default_factory=dict, repr=False) # don't show in repr
|
||||
extra_element_properties: dict = field(
|
||||
dom_object: dict[str, Any] = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
axtree_object: dict[str, Any] = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
extra_element_properties: dict[str, Any] = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
last_browser_action: str = ''
|
||||
@@ -102,4 +107,4 @@ class BrowserOutputObservation(Observation):
|
||||
skip_generic=False,
|
||||
filter_visible_only=filter_visible_only,
|
||||
)
|
||||
return cur_axtree_txt
|
||||
return str(cur_axtree_txt)
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
from typing import Any, Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -105,10 +105,10 @@ class CmdOutputObservation(Observation):
|
||||
content: str,
|
||||
command: str,
|
||||
observation: str = ObservationType.RUN,
|
||||
metadata: dict | CmdOutputMetadata | None = None,
|
||||
metadata: dict[str, Any] | CmdOutputMetadata | None = None,
|
||||
hidden: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(content)
|
||||
self.command = command
|
||||
self.observation = observation
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from openhands.core.exceptions import LLMMalformedActionError
|
||||
from openhands.events.action.action import Action
|
||||
@@ -42,7 +43,7 @@ actions = (
|
||||
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def handle_action_deprecated_args(args: dict) -> dict:
|
||||
def handle_action_deprecated_args(args: dict[str, Any]) -> dict[str, Any]:
|
||||
# keep_prompt has been deprecated in https://github.com/All-Hands-AI/OpenHands/pull/4881
|
||||
if 'keep_prompt' in args:
|
||||
args.pop('keep_prompt')
|
||||
@@ -120,4 +121,5 @@ def action_from_dict(action: dict) -> Action:
|
||||
raise LLMMalformedActionError(
|
||||
f'action={action} has the wrong arguments: {str(e)}'
|
||||
)
|
||||
assert isinstance(decoded_action, Action)
|
||||
return decoded_action
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -37,14 +38,14 @@ DELETE_FROM_TRAJECTORY_EXTRAS = {
|
||||
DELETE_FROM_MEMORY_EXTRAS = DELETE_FROM_TRAJECTORY_EXTRAS | {'open_pages_urls'}
|
||||
|
||||
|
||||
def event_from_dict(data) -> 'Event':
|
||||
def event_from_dict(data: dict[str, Any]) -> 'Event':
|
||||
evt: Event
|
||||
if 'action' in data:
|
||||
evt = action_from_dict(data)
|
||||
elif 'observation' in data:
|
||||
evt = observation_from_dict(data)
|
||||
else:
|
||||
raise ValueError('Unknown event type: ' + data)
|
||||
raise ValueError(f'Unknown event type: {data}')
|
||||
for key in UNDERSCORE_KEYS:
|
||||
if key in data:
|
||||
value = data[key]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from openhands.events.observation.agent import (
|
||||
AgentCondensationObservation,
|
||||
@@ -49,8 +50,8 @@ OBSERVATION_TYPE_TO_CLASS = {
|
||||
|
||||
|
||||
def _update_cmd_output_metadata(
|
||||
metadata: dict | CmdOutputMetadata | None, **kwargs
|
||||
) -> dict | CmdOutputMetadata:
|
||||
metadata: dict[str, Any] | CmdOutputMetadata | None, **kwargs: Any
|
||||
) -> dict[str, Any] | CmdOutputMetadata:
|
||||
"""Update the metadata of a CmdOutputObservation.
|
||||
|
||||
If metadata is None, create a new CmdOutputMetadata instance.
|
||||
@@ -110,4 +111,6 @@ def observation_from_dict(observation: dict) -> Observation:
|
||||
else:
|
||||
extras['metadata'] = CmdOutputMetadata()
|
||||
|
||||
return observation_class(content=content, **extras)
|
||||
obs = observation_class(content=content, **extras)
|
||||
assert isinstance(obs, Observation)
|
||||
return obs
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
def remove_fields(obj, fields: set[str]):
|
||||
def remove_fields(obj: dict | list | tuple, fields: set[str]) -> None:
|
||||
"""Remove fields from an object.
|
||||
|
||||
Parameters:
|
||||
@@ -14,7 +14,7 @@ def remove_fields(obj, fields: set[str]):
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
remove_fields(item, fields)
|
||||
elif hasattr(obj, '__dataclass_fields__'):
|
||||
if hasattr(obj, '__dataclass_fields__'):
|
||||
raise ValueError(
|
||||
'Object must not contain dataclass, consider converting to dict first'
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Callable, Iterable
|
||||
from typing import Any, AsyncIterator, Callable, Iterable
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import Event, EventSource
|
||||
@@ -40,18 +40,21 @@ async def session_exists(sid: str, file_store: FileStore) -> bool:
|
||||
|
||||
|
||||
class AsyncEventStreamWrapper:
|
||||
def __init__(self, event_stream, *args, **kwargs):
|
||||
def __init__(self, event_stream: 'EventStream', *args: Any, **kwargs: Any) -> None:
|
||||
self.event_stream = event_stream
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
async def __aiter__(self):
|
||||
async def __aiter__(self) -> AsyncIterator[Event]:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Create an async generator that yields events
|
||||
for event in self.event_stream.get_events(*self.args, **self.kwargs):
|
||||
# Run the blocking get_events() in a thread pool
|
||||
yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore
|
||||
def get_event(e: Event = event) -> Event:
|
||||
return e
|
||||
|
||||
yield await loop.run_in_executor(None, get_event)
|
||||
|
||||
|
||||
class EventStream:
|
||||
@@ -102,14 +105,14 @@ class EventStream:
|
||||
if id >= self._cur_id:
|
||||
self._cur_id = id + 1
|
||||
|
||||
def _init_thread_loop(self, subscriber_id: str, callback_id: str):
|
||||
def _init_thread_loop(self, subscriber_id: str, callback_id: str) -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
if subscriber_id not in self._thread_loops:
|
||||
self._thread_loops[subscriber_id] = {}
|
||||
self._thread_loops[subscriber_id][callback_id] = loop
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self._stop_flag.set()
|
||||
if self._queue_thread.is_alive():
|
||||
self._queue_thread.join()
|
||||
@@ -124,7 +127,7 @@ class EventStream:
|
||||
while not self._queue.empty():
|
||||
self._queue.get()
|
||||
|
||||
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str):
|
||||
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str) -> None:
|
||||
if subscriber_id not in self._subscribers:
|
||||
logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')
|
||||
return
|
||||
@@ -172,7 +175,7 @@ class EventStream:
|
||||
end_id: int | None = None,
|
||||
reverse: bool = False,
|
||||
filter_out_type: tuple[type[Event], ...] | None = None,
|
||||
filter_hidden=False,
|
||||
filter_hidden: bool = False,
|
||||
) -> Iterable[Event]:
|
||||
"""
|
||||
Retrieve events from the event stream, optionally filtering out events of a given type
|
||||
@@ -189,7 +192,7 @@ class EventStream:
|
||||
Events from the stream that match the criteria.
|
||||
"""
|
||||
|
||||
def should_filter(event: Event):
|
||||
def should_filter(event: Event) -> bool:
|
||||
if filter_hidden and hasattr(event, 'hidden') and event.hidden:
|
||||
return True
|
||||
if filter_out_type is not None and isinstance(event, filter_out_type):
|
||||
@@ -234,8 +237,11 @@ class EventStream:
|
||||
return self._cur_id - 1
|
||||
|
||||
def subscribe(
|
||||
self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
|
||||
):
|
||||
self,
|
||||
subscriber_id: EventStreamSubscriber,
|
||||
callback: Callable[[Event], None],
|
||||
callback_id: str,
|
||||
) -> None:
|
||||
initializer = partial(self._init_thread_loop, subscriber_id, callback_id)
|
||||
pool = ThreadPoolExecutor(max_workers=1, initializer=initializer)
|
||||
if subscriber_id not in self._subscribers:
|
||||
@@ -250,7 +256,9 @@ class EventStream:
|
||||
self._subscribers[subscriber_id][callback_id] = callback
|
||||
self._thread_pools[subscriber_id][callback_id] = pool
|
||||
|
||||
def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
|
||||
def unsubscribe(
|
||||
self, subscriber_id: EventStreamSubscriber, callback_id: str
|
||||
) -> None:
|
||||
if subscriber_id not in self._subscribers:
|
||||
logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
|
||||
return
|
||||
@@ -261,7 +269,7 @@ class EventStream:
|
||||
|
||||
self._clean_up_subscriber(subscriber_id, callback_id)
|
||||
|
||||
def add_event(self, event: Event, source: EventSource):
|
||||
def add_event(self, event: Event, source: EventSource) -> None:
|
||||
if hasattr(event, '_id') and event.id is not None:
|
||||
raise ValueError(
|
||||
f'Event already has an ID:{event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.'
|
||||
@@ -279,13 +287,13 @@ class EventStream:
|
||||
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
|
||||
self._queue.put(event)
|
||||
|
||||
def set_secrets(self, secrets: dict[str, str]):
|
||||
def set_secrets(self, secrets: dict[str, str]) -> None:
|
||||
self.secrets = secrets.copy()
|
||||
|
||||
def update_secrets(self, secrets: dict[str, str]):
|
||||
def update_secrets(self, secrets: dict[str, str]) -> None:
|
||||
self.secrets.update(secrets)
|
||||
|
||||
def _replace_secrets(self, data: dict) -> dict:
|
||||
def _replace_secrets(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
for key in data:
|
||||
if isinstance(data[key], dict):
|
||||
data[key] = self._replace_secrets(data[key])
|
||||
@@ -294,7 +302,7 @@ class EventStream:
|
||||
data[key] = data[key].replace(secret, '<secret_hidden>')
|
||||
return data
|
||||
|
||||
def _run_queue_loop(self):
|
||||
def _run_queue_loop(self) -> None:
|
||||
self._queue_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._queue_loop)
|
||||
try:
|
||||
@@ -302,7 +310,7 @@ class EventStream:
|
||||
finally:
|
||||
self._queue_loop.close()
|
||||
|
||||
async def _process_queue(self):
|
||||
async def _process_queue(self) -> None:
|
||||
while should_continue() and not self._stop_flag.is_set():
|
||||
event = None
|
||||
try:
|
||||
@@ -319,8 +327,10 @@ class EventStream:
|
||||
future = pool.submit(callback, event)
|
||||
future.add_done_callback(self._make_error_handler(callback_id, key))
|
||||
|
||||
def _make_error_handler(self, callback_id: str, subscriber_id: str):
|
||||
def _handle_callback_error(fut):
|
||||
def _make_error_handler(
|
||||
self, callback_id: str, subscriber_id: str
|
||||
) -> Callable[[Any], None]:
|
||||
def _handle_callback_error(fut: Any) -> None:
|
||||
try:
|
||||
# This will raise any exception that occurred during callback execution
|
||||
fut.result()
|
||||
@@ -333,14 +343,14 @@ class EventStream:
|
||||
|
||||
return _handle_callback_error
|
||||
|
||||
def filtered_events_by_source(self, source: EventSource):
|
||||
def filtered_events_by_source(self, source: EventSource) -> Iterable[Event]:
|
||||
for event in self.get_events():
|
||||
if event.source == source:
|
||||
yield event
|
||||
|
||||
def _should_filter_event(
|
||||
self,
|
||||
event,
|
||||
event: Event,
|
||||
query: str | None = None,
|
||||
event_types: tuple[type[Event], ...] | None = None,
|
||||
source: str | None = None,
|
||||
@@ -363,13 +373,14 @@ class EventStream:
|
||||
if event_types and not isinstance(event, event_types):
|
||||
return True
|
||||
|
||||
if source and not event.source.value == source:
|
||||
if source:
|
||||
if event.source is None or event.source.value != source:
|
||||
return True
|
||||
|
||||
if start_date and event.timestamp is not None and event.timestamp < start_date:
|
||||
return True
|
||||
|
||||
if start_date and event.timestamp < start_date:
|
||||
return True
|
||||
|
||||
if end_date and event.timestamp > end_date:
|
||||
if end_date and event.timestamp is not None and event.timestamp > end_date:
|
||||
return True
|
||||
|
||||
# Text search in event content if query provided
|
||||
@@ -391,7 +402,7 @@ class EventStream:
|
||||
start_id: int = 0,
|
||||
limit: int = 100,
|
||||
reverse: bool = False,
|
||||
) -> list[type[Event]]:
|
||||
) -> list[Event]:
|
||||
"""Get matching events from the event stream based on filters.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -19,8 +20,12 @@ class SecurityAnalyzer:
|
||||
event_stream: The event stream to listen for events.
|
||||
"""
|
||||
self.event_stream = event_stream
|
||||
|
||||
def sync_on_event(event: Event) -> None:
|
||||
asyncio.create_task(self.on_event(event))
|
||||
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.SECURITY_ANALYZER, self.on_event, str(uuid4())
|
||||
EventStreamSubscriber.SECURITY_ANALYZER, sync_on_event, str(uuid4())
|
||||
)
|
||||
|
||||
async def on_event(self, event: Event) -> None:
|
||||
|
||||
@@ -50,6 +50,8 @@ async def connect(connection_id: str, environ):
|
||||
)
|
||||
|
||||
agent_state_changed = None
|
||||
if event_stream is None:
|
||||
raise ConnectionRefusedError('Failed to join conversation')
|
||||
async_stream = AsyncEventStreamWrapper(event_stream, latest_event_id + 1)
|
||||
async for event in async_stream:
|
||||
if isinstance(
|
||||
|
||||
@@ -24,6 +24,7 @@ from openhands.events.observation.files import (
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
)
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.storage import get_file_store
|
||||
|
||||
|
||||
@@ -156,6 +157,30 @@ def test_get_matching_events_source_filter(temp_dir: str):
|
||||
and events[0].source == EventSource.ENVIRONMENT
|
||||
)
|
||||
|
||||
# Test that source comparison works correctly with None source
|
||||
null_source_event = NullObservation('test4')
|
||||
event_stream.add_event(null_source_event, EventSource.AGENT)
|
||||
event = event_stream.get_event(event_stream.get_latest_event_id())
|
||||
event._source = None # type: ignore
|
||||
|
||||
# Update the serialized version
|
||||
data = event_to_dict(event)
|
||||
event_stream.file_store.write(
|
||||
event_stream._get_filename_for_id(event.id), json.dumps(data)
|
||||
)
|
||||
|
||||
# Verify that source comparison works correctly
|
||||
assert event_stream._should_filter_event(
|
||||
event, source='agent'
|
||||
) # Should filter out None source events
|
||||
assert not event_stream._should_filter_event(
|
||||
event, source=None
|
||||
) # Should not filter out when source filter is None
|
||||
|
||||
# Filter by AGENT source again
|
||||
events = event_stream.get_matching_events(source='agent')
|
||||
assert len(events) == 2 # Should not include the None source event
|
||||
|
||||
|
||||
def test_get_matching_events_pagination(temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
@@ -209,7 +234,7 @@ def test_memory_usage_file_operations(temp_dir: str):
|
||||
"""
|
||||
|
||||
def get_memory_mb():
|
||||
"""Get current memory usage in MB"""
|
||||
"""Get current memory usage in MB."""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss / 1024 / 1024
|
||||
|
||||
|
||||
@@ -61,13 +61,13 @@ def test_get_messages(codeact_agent: CodeActAgent):
|
||||
message_action_1._source = 'user'
|
||||
history.append(message_action_1)
|
||||
message_action_2 = MessageAction('Sure!')
|
||||
message_action_2._source = 'assistant'
|
||||
message_action_2._source = 'agent'
|
||||
history.append(message_action_2)
|
||||
message_action_3 = MessageAction('Hello, agent!')
|
||||
message_action_3._source = 'user'
|
||||
history.append(message_action_3)
|
||||
message_action_4 = MessageAction('Hello, user!')
|
||||
message_action_4._source = 'assistant'
|
||||
message_action_4._source = 'agent'
|
||||
history.append(message_action_4)
|
||||
message_action_5 = MessageAction('Laaaaaaaast!')
|
||||
message_action_5._source = 'user'
|
||||
@@ -106,7 +106,7 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
|
||||
message_action_user._source = 'user'
|
||||
history.append(message_action_user)
|
||||
message_action_agent = MessageAction(f'Agent message {i}')
|
||||
message_action_agent._source = 'assistant'
|
||||
message_action_agent._source = 'agent'
|
||||
history.append(message_action_agent)
|
||||
|
||||
codeact_agent.reset()
|
||||
|
||||
@@ -603,11 +603,11 @@ async def test_check_usertask(
|
||||
|
||||
if is_appropriate == 'No':
|
||||
assert len(event_list) == 2
|
||||
assert type(event_list[0]) == MessageAction
|
||||
assert type(event_list[1]) == ChangeAgentStateAction
|
||||
assert isinstance(event_list[0], MessageAction)
|
||||
assert isinstance(event_list[1], ChangeAgentStateAction)
|
||||
elif is_appropriate == 'Yes':
|
||||
assert len(event_list) == 1
|
||||
assert type(event_list[0]) == MessageAction
|
||||
assert isinstance(event_list[0], MessageAction)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -665,8 +665,8 @@ async def test_check_fillaction(
|
||||
|
||||
if is_harmful == 'Yes':
|
||||
assert len(event_list) == 2
|
||||
assert type(event_list[0]) == BrowseInteractiveAction
|
||||
assert type(event_list[1]) == ChangeAgentStateAction
|
||||
assert isinstance(event_list[0], BrowseInteractiveAction)
|
||||
assert isinstance(event_list[1], ChangeAgentStateAction)
|
||||
elif is_harmful == 'No':
|
||||
assert len(event_list) == 1
|
||||
assert type(event_list[0]) == BrowseInteractiveAction
|
||||
assert isinstance(event_list[0], BrowseInteractiveAction)
|
||||
|
||||
Reference in New Issue
Block a user