mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-08 22:38:05 -05:00
Fix mypy errors in events directory (#6810)
Co-authored-by: openhands <openhands@all-hands.dev> Co-authored-by: Xingyao Wang <xingyao@all-hands.dev> Co-authored-by: Engel Nyst <enyst@users.noreply.github.com>
This commit is contained in:
@@ -17,12 +17,12 @@ class MessageAction(Action):
|
||||
return self.content
|
||||
|
||||
@property
|
||||
def images_urls(self):
|
||||
def images_urls(self) -> list[str] | None:
|
||||
# Deprecated alias for backward compatibility
|
||||
return self.image_urls
|
||||
|
||||
@images_urls.setter
|
||||
def images_urls(self, value):
|
||||
def images_urls(self, value: list[str] | None) -> None:
|
||||
self.image_urls = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@@ -39,19 +39,23 @@ class Event:
|
||||
@property
|
||||
def message(self) -> str | None:
|
||||
if hasattr(self, '_message'):
|
||||
return self._message # type: ignore[attr-defined]
|
||||
msg = getattr(self, '_message')
|
||||
return str(msg) if msg is not None else None
|
||||
return ''
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
if hasattr(self, '_id'):
|
||||
return self._id # type: ignore[attr-defined]
|
||||
id_val = getattr(self, '_id')
|
||||
return int(id_val) if id_val is not None else Event.INVALID_ID
|
||||
return Event.INVALID_ID
|
||||
|
||||
@property
|
||||
def timestamp(self):
|
||||
def timestamp(self) -> str | None:
|
||||
if hasattr(self, '_timestamp') and isinstance(self._timestamp, str):
|
||||
return self._timestamp
|
||||
ts = getattr(self, '_timestamp')
|
||||
return str(ts) if ts is not None else None
|
||||
return None
|
||||
|
||||
@timestamp.setter
|
||||
def timestamp(self, value: datetime) -> None:
|
||||
@@ -61,22 +65,25 @@ class Event:
|
||||
@property
|
||||
def source(self) -> EventSource | None:
|
||||
if hasattr(self, '_source'):
|
||||
return self._source # type: ignore[attr-defined]
|
||||
src = getattr(self, '_source')
|
||||
return EventSource(src) if src is not None else None
|
||||
return None
|
||||
|
||||
@property
|
||||
def cause(self) -> int | None:
|
||||
if hasattr(self, '_cause'):
|
||||
return self._cause # type: ignore[attr-defined]
|
||||
cause_val = getattr(self, '_cause')
|
||||
return int(cause_val) if cause_val is not None else None
|
||||
return None
|
||||
|
||||
@property
|
||||
def timeout(self) -> int | None:
|
||||
def timeout(self) -> float | None:
|
||||
if hasattr(self, '_timeout'):
|
||||
return self._timeout # type: ignore[attr-defined]
|
||||
timeout_val = getattr(self, '_timeout')
|
||||
return float(timeout_val) if timeout_val is not None else None
|
||||
return None
|
||||
|
||||
def set_hard_timeout(self, value: int | None, blocking: bool = True) -> None:
|
||||
def set_hard_timeout(self, value: float | None, blocking: bool = True) -> None:
|
||||
"""Set the timeout for the event.
|
||||
|
||||
NOTE, this is a hard timeout, meaning that the event will be blocked
|
||||
@@ -100,7 +107,8 @@ class Event:
|
||||
@property
|
||||
def llm_metrics(self) -> Metrics | None:
|
||||
if hasattr(self, '_llm_metrics'):
|
||||
return self._llm_metrics # type: ignore[attr-defined]
|
||||
metrics = getattr(self, '_llm_metrics')
|
||||
return metrics if isinstance(metrics, Metrics) else None
|
||||
return None
|
||||
|
||||
@llm_metrics.setter
|
||||
@@ -111,7 +119,8 @@ class Event:
|
||||
@property
|
||||
def tool_call_metadata(self) -> ToolCallMetadata | None:
|
||||
if hasattr(self, '_tool_call_metadata'):
|
||||
return self._tool_call_metadata # type: ignore[attr-defined]
|
||||
metadata = getattr(self, '_tool_call_metadata')
|
||||
return metadata if isinstance(metadata, ToolCallMetadata) else None
|
||||
return None
|
||||
|
||||
@tool_call_metadata.setter
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from browsergym.utils.obs import flatten_axtree_to_str
|
||||
|
||||
@@ -16,13 +17,17 @@ class BrowserOutputObservation(Observation):
|
||||
set_of_marks: str = field(default='', repr=False) # don't show in repr
|
||||
error: bool = False
|
||||
observation: str = ObservationType.BROWSE
|
||||
goal_image_urls: list = field(default_factory=list)
|
||||
goal_image_urls: list[str] = field(default_factory=list)
|
||||
# do not include in the memory
|
||||
open_pages_urls: list = field(default_factory=list)
|
||||
open_pages_urls: list[str] = field(default_factory=list)
|
||||
active_page_index: int = -1
|
||||
dom_object: dict = field(default_factory=dict, repr=False) # don't show in repr
|
||||
axtree_object: dict = field(default_factory=dict, repr=False) # don't show in repr
|
||||
extra_element_properties: dict = field(
|
||||
dom_object: dict[str, Any] = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
axtree_object: dict[str, Any] = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
extra_element_properties: dict[str, Any] = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
last_browser_action: str = ''
|
||||
@@ -102,4 +107,4 @@ class BrowserOutputObservation(Observation):
|
||||
skip_generic=False,
|
||||
filter_visible_only=filter_visible_only,
|
||||
)
|
||||
return cur_axtree_txt
|
||||
return str(cur_axtree_txt)
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
from typing import Any, Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -105,10 +105,10 @@ class CmdOutputObservation(Observation):
|
||||
content: str,
|
||||
command: str,
|
||||
observation: str = ObservationType.RUN,
|
||||
metadata: dict | CmdOutputMetadata | None = None,
|
||||
metadata: dict[str, Any] | CmdOutputMetadata | None = None,
|
||||
hidden: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(content)
|
||||
self.command = command
|
||||
self.observation = observation
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import re
|
||||
from typing import Any
|
||||
from openhands.core.exceptions import LLMMalformedActionError
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.agent import (
|
||||
@@ -42,7 +44,7 @@ actions = (
|
||||
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def handle_action_deprecated_args(args: dict) -> dict:
|
||||
def handle_action_deprecated_args(args: dict[str, Any]) -> dict[str, Any]:
|
||||
# keep_prompt has been deprecated in https://github.com/All-Hands-AI/OpenHands/pull/4881
|
||||
if 'keep_prompt' in args:
|
||||
args.pop('keep_prompt')
|
||||
@@ -126,4 +128,5 @@ def action_from_dict(action: dict) -> Action:
|
||||
raise LLMMalformedActionError(
|
||||
f'action={action} has the wrong arguments: {str(e)}'
|
||||
)
|
||||
assert isinstance(decoded_action, Action)
|
||||
return decoded_action
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -48,14 +49,14 @@ DELETE_FROM_TRAJECTORY_EXTRAS_AND_SCREENSHOTS = DELETE_FROM_TRAJECTORY_EXTRAS |
|
||||
}
|
||||
|
||||
|
||||
def event_from_dict(data) -> 'Event':
|
||||
def event_from_dict(data: dict[str, Any]) -> 'Event':
|
||||
evt: Event
|
||||
if 'action' in data:
|
||||
evt = action_from_dict(data)
|
||||
elif 'observation' in data:
|
||||
evt = observation_from_dict(data)
|
||||
else:
|
||||
raise ValueError('Unknown event type: ' + data)
|
||||
raise ValueError(f'Unknown event type: {data}')
|
||||
for key in UNDERSCORE_KEYS:
|
||||
if key in data:
|
||||
value = data[key]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
@@ -53,8 +54,8 @@ OBSERVATION_TYPE_TO_CLASS = {
|
||||
|
||||
|
||||
def _update_cmd_output_metadata(
|
||||
metadata: dict | CmdOutputMetadata | None, **kwargs
|
||||
) -> dict | CmdOutputMetadata:
|
||||
metadata: dict[str, Any] | CmdOutputMetadata | None, **kwargs: Any
|
||||
) -> dict[str, Any] | CmdOutputMetadata:
|
||||
"""Update the metadata of a CmdOutputObservation.
|
||||
|
||||
If metadata is None, create a new CmdOutputMetadata instance.
|
||||
@@ -128,4 +129,6 @@ def observation_from_dict(observation: dict) -> Observation:
|
||||
for item in extras['microagent_knowledge']
|
||||
]
|
||||
|
||||
return observation_class(content=content, **extras)
|
||||
obs = observation_class(content=content, **extras)
|
||||
assert isinstance(obs, Observation)
|
||||
return obs
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
def remove_fields(obj, fields: set[str]):
|
||||
def remove_fields(obj: dict | list | tuple, fields: set[str]) -> None:
|
||||
"""Remove fields from an object.
|
||||
|
||||
Parameters:
|
||||
@@ -14,7 +14,7 @@ def remove_fields(obj, fields: set[str]):
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
remove_fields(item, fields)
|
||||
elif hasattr(obj, '__dataclass_fields__'):
|
||||
if hasattr(obj, '__dataclass_fields__'):
|
||||
raise ValueError(
|
||||
'Object must not contain dataclass, consider converting to dict first'
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Callable, Iterable
|
||||
from typing import Any, AsyncIterator, Callable, Iterable
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import Event, EventSource
|
||||
@@ -43,18 +43,21 @@ async def session_exists(
|
||||
|
||||
|
||||
class AsyncEventStreamWrapper:
|
||||
def __init__(self, event_stream, *args, **kwargs):
|
||||
def __init__(self, event_stream: 'EventStream', *args: Any, **kwargs: Any) -> None:
|
||||
self.event_stream = event_stream
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
async def __aiter__(self):
|
||||
async def __aiter__(self) -> AsyncIterator[Event]:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Create an async generator that yields events
|
||||
for event in self.event_stream.get_events(*self.args, **self.kwargs):
|
||||
# Run the blocking get_events() in a thread pool
|
||||
yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore
|
||||
def get_event(e: Event = event) -> Event:
|
||||
return e
|
||||
|
||||
yield await loop.run_in_executor(None, get_event)
|
||||
|
||||
|
||||
class EventStream:
|
||||
@@ -121,14 +124,14 @@ class EventStream:
|
||||
if id >= self._cur_id:
|
||||
self._cur_id = id + 1
|
||||
|
||||
def _init_thread_loop(self, subscriber_id: str, callback_id: str):
|
||||
def _init_thread_loop(self, subscriber_id: str, callback_id: str) -> None:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
if subscriber_id not in self._thread_loops:
|
||||
self._thread_loops[subscriber_id] = {}
|
||||
self._thread_loops[subscriber_id][callback_id] = loop
|
||||
|
||||
def close(self):
|
||||
def close(self) -> None:
|
||||
self._stop_flag.set()
|
||||
if self._queue_thread.is_alive():
|
||||
self._queue_thread.join()
|
||||
@@ -143,7 +146,7 @@ class EventStream:
|
||||
while not self._queue.empty():
|
||||
self._queue.get()
|
||||
|
||||
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str):
|
||||
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str) -> None:
|
||||
if subscriber_id not in self._subscribers:
|
||||
logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')
|
||||
return
|
||||
@@ -191,7 +194,7 @@ class EventStream:
|
||||
end_id: int | None = None,
|
||||
reverse: bool = False,
|
||||
filter_out_type: tuple[type[Event], ...] | None = None,
|
||||
filter_hidden=False,
|
||||
filter_hidden: bool = False,
|
||||
) -> Iterable[Event]:
|
||||
"""
|
||||
Retrieve events from the event stream, optionally filtering out events of a given type
|
||||
@@ -208,7 +211,7 @@ class EventStream:
|
||||
Events from the stream that match the criteria.
|
||||
"""
|
||||
|
||||
def should_filter(event: Event):
|
||||
def should_filter(event: Event) -> bool:
|
||||
if filter_hidden and hasattr(event, 'hidden') and event.hidden:
|
||||
return True
|
||||
if filter_out_type is not None and isinstance(event, filter_out_type):
|
||||
@@ -263,8 +266,11 @@ class EventStream:
|
||||
return self._cur_id - 1
|
||||
|
||||
def subscribe(
|
||||
self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
|
||||
):
|
||||
self,
|
||||
subscriber_id: EventStreamSubscriber,
|
||||
callback: Callable[[Event], None],
|
||||
callback_id: str,
|
||||
) -> None:
|
||||
initializer = partial(self._init_thread_loop, subscriber_id, callback_id)
|
||||
pool = ThreadPoolExecutor(max_workers=1, initializer=initializer)
|
||||
if subscriber_id not in self._subscribers:
|
||||
@@ -279,7 +285,9 @@ class EventStream:
|
||||
self._subscribers[subscriber_id][callback_id] = callback
|
||||
self._thread_pools[subscriber_id][callback_id] = pool
|
||||
|
||||
def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
|
||||
def unsubscribe(
|
||||
self, subscriber_id: EventStreamSubscriber, callback_id: str
|
||||
) -> None:
|
||||
if subscriber_id not in self._subscribers:
|
||||
logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
|
||||
return
|
||||
@@ -290,8 +298,8 @@ class EventStream:
|
||||
|
||||
self._clean_up_subscriber(subscriber_id, callback_id)
|
||||
|
||||
def add_event(self, event: Event, source: EventSource):
|
||||
if hasattr(event, '_id') and event.id is not None:
|
||||
def add_event(self, event: Event, source: EventSource) -> None:
|
||||
if event.id != Event.INVALID_ID:
|
||||
raise ValueError(
|
||||
f'Event already has an ID:{event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.'
|
||||
)
|
||||
@@ -310,13 +318,13 @@ class EventStream:
|
||||
)
|
||||
self._queue.put(event)
|
||||
|
||||
def set_secrets(self, secrets: dict[str, str]):
|
||||
def set_secrets(self, secrets: dict[str, str]) -> None:
|
||||
self.secrets = secrets.copy()
|
||||
|
||||
def update_secrets(self, secrets: dict[str, str]):
|
||||
def update_secrets(self, secrets: dict[str, str]) -> None:
|
||||
self.secrets.update(secrets)
|
||||
|
||||
def _replace_secrets(self, data: dict) -> dict:
|
||||
def _replace_secrets(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
for key in data:
|
||||
if isinstance(data[key], dict):
|
||||
data[key] = self._replace_secrets(data[key])
|
||||
@@ -325,7 +333,7 @@ class EventStream:
|
||||
data[key] = data[key].replace(secret, '<secret_hidden>')
|
||||
return data
|
||||
|
||||
def _run_queue_loop(self):
|
||||
def _run_queue_loop(self) -> None:
|
||||
self._queue_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._queue_loop)
|
||||
try:
|
||||
@@ -333,7 +341,7 @@ class EventStream:
|
||||
finally:
|
||||
self._queue_loop.close()
|
||||
|
||||
async def _process_queue(self):
|
||||
async def _process_queue(self) -> None:
|
||||
while should_continue() and not self._stop_flag.is_set():
|
||||
event = None
|
||||
try:
|
||||
@@ -350,8 +358,10 @@ class EventStream:
|
||||
future = pool.submit(callback, event)
|
||||
future.add_done_callback(self._make_error_handler(callback_id, key))
|
||||
|
||||
def _make_error_handler(self, callback_id: str, subscriber_id: str):
|
||||
def _handle_callback_error(fut):
|
||||
def _make_error_handler(
|
||||
self, callback_id: str, subscriber_id: str
|
||||
) -> Callable[[Any], None]:
|
||||
def _handle_callback_error(fut: Any) -> None:
|
||||
try:
|
||||
# This will raise any exception that occurred during callback execution
|
||||
fut.result()
|
||||
@@ -364,14 +374,14 @@ class EventStream:
|
||||
|
||||
return _handle_callback_error
|
||||
|
||||
def filtered_events_by_source(self, source: EventSource):
|
||||
def filtered_events_by_source(self, source: EventSource) -> Iterable[Event]:
|
||||
for event in self.get_events():
|
||||
if event.source == source:
|
||||
yield event
|
||||
|
||||
def _should_filter_event(
|
||||
self,
|
||||
event,
|
||||
event: Event,
|
||||
query: str | None = None,
|
||||
event_types: tuple[type[Event], ...] | None = None,
|
||||
source: str | None = None,
|
||||
@@ -394,13 +404,14 @@ class EventStream:
|
||||
if event_types and not isinstance(event, event_types):
|
||||
return True
|
||||
|
||||
if source and not event.source.value == source:
|
||||
if source:
|
||||
if event.source is None or event.source.value != source:
|
||||
return True
|
||||
|
||||
if start_date and event.timestamp is not None and event.timestamp < start_date:
|
||||
return True
|
||||
|
||||
if start_date and event.timestamp < start_date:
|
||||
return True
|
||||
|
||||
if end_date and event.timestamp > end_date:
|
||||
if end_date and event.timestamp is not None and event.timestamp > end_date:
|
||||
return True
|
||||
|
||||
# Text search in event content if query provided
|
||||
@@ -422,7 +433,7 @@ class EventStream:
|
||||
start_id: int = 0,
|
||||
limit: int = 100,
|
||||
reverse: bool = False,
|
||||
) -> list[type[Event]]:
|
||||
) -> list[Event]:
|
||||
"""Get matching events from the event stream based on filters.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -19,8 +20,12 @@ class SecurityAnalyzer:
|
||||
event_stream: The event stream to listen for events.
|
||||
"""
|
||||
self.event_stream = event_stream
|
||||
|
||||
def sync_on_event(event: Event) -> None:
|
||||
asyncio.create_task(self.on_event(event))
|
||||
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.SECURITY_ANALYZER, self.on_event, str(uuid4())
|
||||
EventStreamSubscriber.SECURITY_ANALYZER, sync_on_event, str(uuid4())
|
||||
)
|
||||
|
||||
async def on_event(self, event: Event) -> None:
|
||||
|
||||
@@ -58,6 +58,8 @@ async def connect(connection_id: str, environ):
|
||||
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'
|
||||
)
|
||||
agent_state_changed = None
|
||||
if event_stream is None:
|
||||
raise ConnectionRefusedError('Failed to join conversation')
|
||||
async_stream = AsyncEventStreamWrapper(event_stream, latest_event_id + 1)
|
||||
async for event in async_stream:
|
||||
logger.info(f'oh_event: {event.__class__.__name__}')
|
||||
|
||||
@@ -52,7 +52,7 @@ def test_bash_server(temp_dir, runtime_cls, run_as_openhands):
|
||||
assert obs.exit_code == -1
|
||||
assert 'Serving HTTP on 0.0.0.0 port 8080' in obs.content
|
||||
assert (
|
||||
"[The command timed out after 1 seconds. You may wait longer to see additional output by sending empty command '', send other commands to interact with the current process, or send keys to interrupt/kill the command.]"
|
||||
"[The command timed out after 1.0 seconds. You may wait longer to see additional output by sending empty command '', send other commands to interact with the current process, or send keys to interrupt/kill the command.]"
|
||||
in obs.metadata.suffix
|
||||
)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from openhands.events.observation.files import (
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
)
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.locations import get_conversation_event_filename
|
||||
|
||||
@@ -157,6 +158,31 @@ def test_get_matching_events_source_filter(temp_dir: str):
|
||||
and events[0].source == EventSource.ENVIRONMENT
|
||||
)
|
||||
|
||||
# Test that source comparison works correctly with None source
|
||||
null_source_event = NullObservation('test4')
|
||||
event_stream.add_event(null_source_event, EventSource.AGENT)
|
||||
event = event_stream.get_event(event_stream.get_latest_event_id())
|
||||
event._source = None # type: ignore
|
||||
|
||||
# Update the serialized version
|
||||
data = event_to_dict(event)
|
||||
event_stream.file_store.write(
|
||||
event_stream._get_filename_for_id(event.id, event_stream.user_id),
|
||||
json.dumps(data),
|
||||
)
|
||||
|
||||
# Verify that source comparison works correctly
|
||||
assert event_stream._should_filter_event(
|
||||
event, source='agent'
|
||||
) # Should filter out None source events
|
||||
assert not event_stream._should_filter_event(
|
||||
event, source=None
|
||||
) # Should not filter out when source filter is None
|
||||
|
||||
# Filter by AGENT source again
|
||||
events = event_stream.get_matching_events(source='agent')
|
||||
assert len(events) == 2 # Should not include the None source event
|
||||
|
||||
|
||||
def test_get_matching_events_pagination(temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
@@ -210,7 +236,7 @@ def test_memory_usage_file_operations(temp_dir: str):
|
||||
"""
|
||||
|
||||
def get_memory_mb():
|
||||
"""Get current memory usage in MB"""
|
||||
"""Get current memory usage in MB."""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss / 1024 / 1024
|
||||
|
||||
|
||||
@@ -61,13 +61,13 @@ def test_get_messages(codeact_agent: CodeActAgent):
|
||||
message_action_1._source = 'user'
|
||||
history.append(message_action_1)
|
||||
message_action_2 = MessageAction('Sure!')
|
||||
message_action_2._source = 'assistant'
|
||||
message_action_2._source = 'agent'
|
||||
history.append(message_action_2)
|
||||
message_action_3 = MessageAction('Hello, agent!')
|
||||
message_action_3._source = 'user'
|
||||
history.append(message_action_3)
|
||||
message_action_4 = MessageAction('Hello, user!')
|
||||
message_action_4._source = 'assistant'
|
||||
message_action_4._source = 'agent'
|
||||
history.append(message_action_4)
|
||||
message_action_5 = MessageAction('Laaaaaaaast!')
|
||||
message_action_5._source = 'user'
|
||||
@@ -106,7 +106,7 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
|
||||
message_action_user._source = 'user'
|
||||
history.append(message_action_user)
|
||||
message_action_agent = MessageAction(f'Agent message {i}')
|
||||
message_action_agent._source = 'assistant'
|
||||
message_action_agent._source = 'agent'
|
||||
history.append(message_action_agent)
|
||||
|
||||
codeact_agent.reset()
|
||||
|
||||
@@ -603,11 +603,11 @@ async def test_check_usertask(
|
||||
|
||||
if is_appropriate == 'No':
|
||||
assert len(event_list) == 2
|
||||
assert type(event_list[0]) == MessageAction
|
||||
assert type(event_list[1]) == ChangeAgentStateAction
|
||||
assert isinstance(event_list[0], MessageAction)
|
||||
assert isinstance(event_list[1], ChangeAgentStateAction)
|
||||
elif is_appropriate == 'Yes':
|
||||
assert len(event_list) == 1
|
||||
assert type(event_list[0]) == MessageAction
|
||||
assert isinstance(event_list[0], MessageAction)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -665,8 +665,8 @@ async def test_check_fillaction(
|
||||
|
||||
if is_harmful == 'Yes':
|
||||
assert len(event_list) == 2
|
||||
assert type(event_list[0]) == BrowseInteractiveAction
|
||||
assert type(event_list[1]) == ChangeAgentStateAction
|
||||
assert isinstance(event_list[0], BrowseInteractiveAction)
|
||||
assert isinstance(event_list[1], ChangeAgentStateAction)
|
||||
elif is_harmful == 'No':
|
||||
assert len(event_list) == 1
|
||||
assert type(event_list[0]) == BrowseInteractiveAction
|
||||
assert isinstance(event_list[0], BrowseInteractiveAction)
|
||||
|
||||
Reference in New Issue
Block a user