Compare commits

...

23 Commits

Author SHA1 Message Date
openhands
34ad595fcd Fix mypy error in events directory by adding type ignore comment for browsergym import 2025-03-09 18:48:04 +00:00
Graham Neubig
10ae123349 Update linting 2025-03-03 16:56:56 -05:00
openhands
9fac84170f Fix type comparisons in test_security.py 2025-03-03 21:02:15 +00:00
openhands
4469278434 Fix docstring in test_event_stream.py 2025-03-03 20:59:22 +00:00
openhands
ccd00be2e7 Fix source comparison with None source and add unit test 2025-03-03 20:56:09 +00:00
Graham Neubig
667f373a54 Merge branch 'main' into fix/mypy-events-errors 2025-03-03 14:52:52 -05:00
openhands
670ee91a41 Format test_event_stream.py with ruff 2025-03-03 18:59:45 +00:00
openhands
3153c9d9ba Add test for source comparison with None source 2025-03-03 18:56:30 +00:00
Graham Neubig
7a37d99731 Merge branch 'main' into fix/mypy-events-errors 2025-02-21 20:51:53 -05:00
Graham Neubig
ffd7c231ad Merge branch 'main' into fix/mypy-events-errors 2025-02-20 23:55:13 -05:00
openhands
67677e6b49 Fix EventSource usage in tests and apply linting fixes 2025-02-21 04:49:14 +00:00
Engel Nyst
c6c327d26d Merge branch 'main' into fix/mypy-events-errors 2025-02-19 19:39:55 +01:00
Xingyao Wang
2575df8696 Merge branch 'main' into fix/mypy-events-errors 2025-02-19 12:19:11 -05:00
Graham Neubig
8aa628ba34 Merge branch 'main' into fix/mypy-events-errors 2025-02-19 05:21:24 -05:00
openhands
db47d6e78b Fix mypy and ruff errors in events module 2025-02-19 02:30:49 +00:00
openhands
b36254b7b4 Revert changes outside openhands/events directory 2025-02-19 01:57:30 +00:00
openhands
a974164397 Fix mypy errors in events directory 2025-02-19 01:57:04 +00:00
Graham Neubig
592aca05e1 Merge branch 'main' into feature/strict-mypy-checks 2025-02-18 20:14:17 -05:00
Graham Neubig
d309455733 Merge branch 'main' into feature/strict-mypy-checks 2025-02-11 10:19:36 -05:00
Graham Neubig
66a7920539 Merge branch 'main' into feature/strict-mypy-checks 2025-02-10 13:06:49 -05:00
Graham Neubig
64ebef3646 Update .github/workflows/lint.yml 2025-01-21 14:52:56 -05:00
Graham Neubig
7a259915c1 Update .github/workflows/lint.yml 2025-01-21 14:47:51 -05:00
openhands
66bd8fdbcd Enable strict type checking with mypy
- Update mypy configuration with stricter type checking rules
- Add more type stubs to pre-commit configuration
- Run mypy both through pre-commit and directly in CI
- Install project in editable mode for better type checking
- Set correct PYTHONPATH in CI environment
2025-01-21 19:12:09 +00:00
15 changed files with 133 additions and 70 deletions

View File

@@ -8,7 +8,7 @@
* - Please do NOT serve this file on production.
*/
const PACKAGE_VERSION = '2.7.0'
const PACKAGE_VERSION = '2.7.3'
const INTEGRITY_CHECKSUM = '00729d72e3b82faf54ca8b9621dbb96f'
const IS_MOCKED_RESPONSE = Symbol('isMockedResponse')
const activeClientIds = new Set()

View File

@@ -17,12 +17,12 @@ class MessageAction(Action):
return self.content
@property
def images_urls(self):
def images_urls(self) -> list[str] | None:
# Deprecated alias for backward compatibility
return self.image_urls
@images_urls.setter
def images_urls(self, value):
def images_urls(self, value: list[str] | None) -> None:
self.image_urls = value
def __str__(self) -> str:

View File

@@ -29,19 +29,23 @@ class Event:
@property
def message(self) -> str | None:
if hasattr(self, '_message'):
return self._message # type: ignore[attr-defined]
msg = getattr(self, '_message')
return str(msg) if msg is not None else None
return ''
@property
def id(self) -> int:
if hasattr(self, '_id'):
return self._id # type: ignore[attr-defined]
id_val = getattr(self, '_id')
return int(id_val) if id_val is not None else Event.INVALID_ID
return Event.INVALID_ID
@property
def timestamp(self):
def timestamp(self) -> str | None:
if hasattr(self, '_timestamp') and isinstance(self._timestamp, str):
return self._timestamp
ts = getattr(self, '_timestamp')
return str(ts) if ts is not None else None
return None
@timestamp.setter
def timestamp(self, value: datetime) -> None:
@@ -51,19 +55,22 @@ class Event:
@property
def source(self) -> EventSource | None:
if hasattr(self, '_source'):
return self._source # type: ignore[attr-defined]
src = getattr(self, '_source')
return EventSource(src) if src is not None else None
return None
@property
def cause(self) -> int | None:
if hasattr(self, '_cause'):
return self._cause # type: ignore[attr-defined]
cause_val = getattr(self, '_cause')
return int(cause_val) if cause_val is not None else None
return None
@property
def timeout(self) -> int | None:
if hasattr(self, '_timeout'):
return self._timeout # type: ignore[attr-defined]
timeout_val = getattr(self, '_timeout')
return int(timeout_val) if timeout_val is not None else None
return None
def set_hard_timeout(self, value: int | None, blocking: bool = True) -> None:
@@ -90,7 +97,8 @@ class Event:
@property
def llm_metrics(self) -> Metrics | None:
if hasattr(self, '_llm_metrics'):
return self._llm_metrics # type: ignore[attr-defined]
metrics = getattr(self, '_llm_metrics')
return metrics if isinstance(metrics, Metrics) else None
return None
@llm_metrics.setter
@@ -101,7 +109,8 @@ class Event:
@property
def tool_call_metadata(self) -> ToolCallMetadata | None:
if hasattr(self, '_tool_call_metadata'):
return self._tool_call_metadata # type: ignore[attr-defined]
metadata = getattr(self, '_tool_call_metadata')
return metadata if isinstance(metadata, ToolCallMetadata) else None
return None
@tool_call_metadata.setter

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass, field
from typing import Any
from browsergym.utils.obs import flatten_axtree_to_str
from browsergym.utils.obs import flatten_axtree_to_str # type: ignore
from openhands.core.schema import ActionType, ObservationType
from openhands.events.observation.observation import Observation
@@ -16,13 +17,17 @@ class BrowserOutputObservation(Observation):
set_of_marks: str = field(default='', repr=False) # don't show in repr
error: bool = False
observation: str = ObservationType.BROWSE
goal_image_urls: list = field(default_factory=list)
goal_image_urls: list[str] = field(default_factory=list)
# do not include in the memory
open_pages_urls: list = field(default_factory=list)
open_pages_urls: list[str] = field(default_factory=list)
active_page_index: int = -1
dom_object: dict = field(default_factory=dict, repr=False) # don't show in repr
axtree_object: dict = field(default_factory=dict, repr=False) # don't show in repr
extra_element_properties: dict = field(
dom_object: dict[str, Any] = field(
default_factory=dict, repr=False
) # don't show in repr
axtree_object: dict[str, Any] = field(
default_factory=dict, repr=False
) # don't show in repr
extra_element_properties: dict[str, Any] = field(
default_factory=dict, repr=False
) # don't show in repr
last_browser_action: str = ''
@@ -102,4 +107,4 @@ class BrowserOutputObservation(Observation):
skip_generic=False,
filter_visible_only=filter_visible_only,
)
return cur_axtree_txt
return str(cur_axtree_txt)

View File

@@ -2,7 +2,7 @@ import json
import re
import traceback
from dataclasses import dataclass, field
from typing import Self
from typing import Any, Self
from pydantic import BaseModel
@@ -105,10 +105,10 @@ class CmdOutputObservation(Observation):
content: str,
command: str,
observation: str = ObservationType.RUN,
metadata: dict | CmdOutputMetadata | None = None,
metadata: dict[str, Any] | CmdOutputMetadata | None = None,
hidden: bool = False,
**kwargs,
):
**kwargs: Any,
) -> None:
super().__init__(content)
self.command = command
self.observation = observation

View File

@@ -1,4 +1,5 @@
import re
from typing import Any
from openhands.core.exceptions import LLMMalformedActionError
from openhands.events.action.action import Action
@@ -42,7 +43,7 @@ actions = (
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
def handle_action_deprecated_args(args: dict) -> dict:
def handle_action_deprecated_args(args: dict[str, Any]) -> dict[str, Any]:
# keep_prompt has been deprecated in https://github.com/All-Hands-AI/OpenHands/pull/4881
if 'keep_prompt' in args:
args.pop('keep_prompt')
@@ -120,4 +121,5 @@ def action_from_dict(action: dict) -> Action:
raise LLMMalformedActionError(
f'action={action} has the wrong arguments: {str(e)}'
)
assert isinstance(decoded_action, Action)
return decoded_action

View File

@@ -1,5 +1,6 @@
from dataclasses import asdict
from datetime import datetime
from typing import Any
from pydantic import BaseModel
@@ -37,14 +38,14 @@ DELETE_FROM_TRAJECTORY_EXTRAS = {
DELETE_FROM_MEMORY_EXTRAS = DELETE_FROM_TRAJECTORY_EXTRAS | {'open_pages_urls'}
def event_from_dict(data) -> 'Event':
def event_from_dict(data: dict[str, Any]) -> 'Event':
evt: Event
if 'action' in data:
evt = action_from_dict(data)
elif 'observation' in data:
evt = observation_from_dict(data)
else:
raise ValueError('Unknown event type: ' + data)
raise ValueError(f'Unknown event type: {data}')
for key in UNDERSCORE_KEYS:
if key in data:
value = data[key]

View File

@@ -1,4 +1,5 @@
import copy
from typing import Any
from openhands.events.observation.agent import (
AgentCondensationObservation,
@@ -49,8 +50,8 @@ OBSERVATION_TYPE_TO_CLASS = {
def _update_cmd_output_metadata(
metadata: dict | CmdOutputMetadata | None, **kwargs
) -> dict | CmdOutputMetadata:
metadata: dict[str, Any] | CmdOutputMetadata | None, **kwargs: Any
) -> dict[str, Any] | CmdOutputMetadata:
"""Update the metadata of a CmdOutputObservation.
If metadata is None, create a new CmdOutputMetadata instance.
@@ -110,4 +111,6 @@ def observation_from_dict(observation: dict) -> Observation:
else:
extras['metadata'] = CmdOutputMetadata()
return observation_class(content=content, **extras)
obs = observation_class(content=content, **extras)
assert isinstance(obs, Observation)
return obs

View File

@@ -1,4 +1,4 @@
def remove_fields(obj, fields: set[str]):
def remove_fields(obj: dict | list | tuple, fields: set[str]) -> None:
"""Remove fields from an object.
Parameters:
@@ -14,7 +14,7 @@ def remove_fields(obj, fields: set[str]):
elif isinstance(obj, (list, tuple)):
for item in obj:
remove_fields(item, fields)
elif hasattr(obj, '__dataclass_fields__'):
if hasattr(obj, '__dataclass_fields__'):
raise ValueError(
'Object must not contain dataclass, consider converting to dict first'
)

View File

@@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Callable, Iterable
from typing import Any, AsyncIterator, Callable, Iterable
from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event, EventSource
@@ -40,18 +40,21 @@ async def session_exists(sid: str, file_store: FileStore) -> bool:
class AsyncEventStreamWrapper:
def __init__(self, event_stream, *args, **kwargs):
def __init__(self, event_stream: 'EventStream', *args: Any, **kwargs: Any) -> None:
self.event_stream = event_stream
self.args = args
self.kwargs = kwargs
async def __aiter__(self):
async def __aiter__(self) -> AsyncIterator[Event]:
loop = asyncio.get_running_loop()
# Create an async generator that yields events
for event in self.event_stream.get_events(*self.args, **self.kwargs):
# Run the blocking get_events() in a thread pool
yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore
def get_event(e: Event = event) -> Event:
return e
yield await loop.run_in_executor(None, get_event)
class EventStream:
@@ -102,14 +105,14 @@ class EventStream:
if id >= self._cur_id:
self._cur_id = id + 1
def _init_thread_loop(self, subscriber_id: str, callback_id: str):
def _init_thread_loop(self, subscriber_id: str, callback_id: str) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if subscriber_id not in self._thread_loops:
self._thread_loops[subscriber_id] = {}
self._thread_loops[subscriber_id][callback_id] = loop
def close(self):
def close(self) -> None:
self._stop_flag.set()
if self._queue_thread.is_alive():
self._queue_thread.join()
@@ -124,7 +127,7 @@ class EventStream:
while not self._queue.empty():
self._queue.get()
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str):
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str) -> None:
if subscriber_id not in self._subscribers:
logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')
return
@@ -172,7 +175,7 @@ class EventStream:
end_id: int | None = None,
reverse: bool = False,
filter_out_type: tuple[type[Event], ...] | None = None,
filter_hidden=False,
filter_hidden: bool = False,
) -> Iterable[Event]:
"""
Retrieve events from the event stream, optionally filtering out events of a given type
@@ -189,7 +192,7 @@ class EventStream:
Events from the stream that match the criteria.
"""
def should_filter(event: Event):
def should_filter(event: Event) -> bool:
if filter_hidden and hasattr(event, 'hidden') and event.hidden:
return True
if filter_out_type is not None and isinstance(event, filter_out_type):
@@ -234,8 +237,11 @@ class EventStream:
return self._cur_id - 1
def subscribe(
self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
):
self,
subscriber_id: EventStreamSubscriber,
callback: Callable[[Event], None],
callback_id: str,
) -> None:
initializer = partial(self._init_thread_loop, subscriber_id, callback_id)
pool = ThreadPoolExecutor(max_workers=1, initializer=initializer)
if subscriber_id not in self._subscribers:
@@ -250,7 +256,9 @@ class EventStream:
self._subscribers[subscriber_id][callback_id] = callback
self._thread_pools[subscriber_id][callback_id] = pool
def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
def unsubscribe(
self, subscriber_id: EventStreamSubscriber, callback_id: str
) -> None:
if subscriber_id not in self._subscribers:
logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
return
@@ -261,7 +269,7 @@ class EventStream:
self._clean_up_subscriber(subscriber_id, callback_id)
def add_event(self, event: Event, source: EventSource):
def add_event(self, event: Event, source: EventSource) -> None:
if hasattr(event, '_id') and event.id is not None:
raise ValueError(
f'Event already has an ID:{event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.'
@@ -279,13 +287,13 @@ class EventStream:
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
self._queue.put(event)
def set_secrets(self, secrets: dict[str, str]):
def set_secrets(self, secrets: dict[str, str]) -> None:
self.secrets = secrets.copy()
def update_secrets(self, secrets: dict[str, str]):
def update_secrets(self, secrets: dict[str, str]) -> None:
self.secrets.update(secrets)
def _replace_secrets(self, data: dict) -> dict:
def _replace_secrets(self, data: dict[str, Any]) -> dict[str, Any]:
for key in data:
if isinstance(data[key], dict):
data[key] = self._replace_secrets(data[key])
@@ -294,7 +302,7 @@ class EventStream:
data[key] = data[key].replace(secret, '<secret_hidden>')
return data
def _run_queue_loop(self):
def _run_queue_loop(self) -> None:
self._queue_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._queue_loop)
try:
@@ -302,7 +310,7 @@ class EventStream:
finally:
self._queue_loop.close()
async def _process_queue(self):
async def _process_queue(self) -> None:
while should_continue() and not self._stop_flag.is_set():
event = None
try:
@@ -319,8 +327,10 @@ class EventStream:
future = pool.submit(callback, event)
future.add_done_callback(self._make_error_handler(callback_id, key))
def _make_error_handler(self, callback_id: str, subscriber_id: str):
def _handle_callback_error(fut):
def _make_error_handler(
self, callback_id: str, subscriber_id: str
) -> Callable[[Any], None]:
def _handle_callback_error(fut: Any) -> None:
try:
# This will raise any exception that occurred during callback execution
fut.result()
@@ -333,14 +343,14 @@ class EventStream:
return _handle_callback_error
def filtered_events_by_source(self, source: EventSource):
def filtered_events_by_source(self, source: EventSource) -> Iterable[Event]:
for event in self.get_events():
if event.source == source:
yield event
def _should_filter_event(
self,
event,
event: Event,
query: str | None = None,
event_types: tuple[type[Event], ...] | None = None,
source: str | None = None,
@@ -363,13 +373,14 @@ class EventStream:
if event_types and not isinstance(event, event_types):
return True
if source and not event.source.value == source:
if source:
if event.source is None or event.source.value != source:
return True
if start_date and event.timestamp is not None and event.timestamp < start_date:
return True
if start_date and event.timestamp < start_date:
return True
if end_date and event.timestamp > end_date:
if end_date and event.timestamp is not None and event.timestamp > end_date:
return True
# Text search in event content if query provided
@@ -391,7 +402,7 @@ class EventStream:
start_id: int = 0,
limit: int = 100,
reverse: bool = False,
) -> list[type[Event]]:
) -> list[Event]:
"""Get matching events from the event stream based on filters.
Args:

View File

@@ -1,3 +1,4 @@
import asyncio
from typing import Any
from uuid import uuid4
@@ -19,8 +20,12 @@ class SecurityAnalyzer:
event_stream: The event stream to listen for events.
"""
self.event_stream = event_stream
def sync_on_event(event: Event) -> None:
asyncio.create_task(self.on_event(event))
self.event_stream.subscribe(
EventStreamSubscriber.SECURITY_ANALYZER, self.on_event, str(uuid4())
EventStreamSubscriber.SECURITY_ANALYZER, sync_on_event, str(uuid4())
)
async def on_event(self, event: Event) -> None:

View File

@@ -50,6 +50,8 @@ async def connect(connection_id: str, environ):
)
agent_state_changed = None
if event_stream is None:
raise ConnectionRefusedError('Failed to join conversation')
async_stream = AsyncEventStreamWrapper(event_stream, latest_event_id + 1)
async for event in async_stream:
if isinstance(

View File

@@ -24,6 +24,7 @@ from openhands.events.observation.files import (
FileReadObservation,
FileWriteObservation,
)
from openhands.events.serialization.event import event_to_dict
from openhands.storage import get_file_store
@@ -156,6 +157,30 @@ def test_get_matching_events_source_filter(temp_dir: str):
and events[0].source == EventSource.ENVIRONMENT
)
# Test that source comparison works correctly with None source
null_source_event = NullObservation('test4')
event_stream.add_event(null_source_event, EventSource.AGENT)
event = event_stream.get_event(event_stream.get_latest_event_id())
event._source = None # type: ignore
# Update the serialized version
data = event_to_dict(event)
event_stream.file_store.write(
event_stream._get_filename_for_id(event.id), json.dumps(data)
)
# Verify that source comparison works correctly
assert event_stream._should_filter_event(
event, source='agent'
) # Should filter out None source events
assert not event_stream._should_filter_event(
event, source=None
) # Should not filter out when source filter is None
# Filter by AGENT source again
events = event_stream.get_matching_events(source='agent')
assert len(events) == 2 # Should not include the None source event
def test_get_matching_events_pagination(temp_dir: str):
file_store = get_file_store('local', temp_dir)
@@ -209,7 +234,7 @@ def test_memory_usage_file_operations(temp_dir: str):
"""
def get_memory_mb():
"""Get current memory usage in MB"""
"""Get current memory usage in MB."""
process = psutil.Process(os.getpid())
return process.memory_info().rss / 1024 / 1024

View File

@@ -61,13 +61,13 @@ def test_get_messages(codeact_agent: CodeActAgent):
message_action_1._source = 'user'
history.append(message_action_1)
message_action_2 = MessageAction('Sure!')
message_action_2._source = 'assistant'
message_action_2._source = 'agent'
history.append(message_action_2)
message_action_3 = MessageAction('Hello, agent!')
message_action_3._source = 'user'
history.append(message_action_3)
message_action_4 = MessageAction('Hello, user!')
message_action_4._source = 'assistant'
message_action_4._source = 'agent'
history.append(message_action_4)
message_action_5 = MessageAction('Laaaaaaaast!')
message_action_5._source = 'user'
@@ -106,7 +106,7 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
message_action_user._source = 'user'
history.append(message_action_user)
message_action_agent = MessageAction(f'Agent message {i}')
message_action_agent._source = 'assistant'
message_action_agent._source = 'agent'
history.append(message_action_agent)
codeact_agent.reset()

View File

@@ -603,11 +603,11 @@ async def test_check_usertask(
if is_appropriate == 'No':
assert len(event_list) == 2
assert type(event_list[0]) == MessageAction
assert type(event_list[1]) == ChangeAgentStateAction
assert isinstance(event_list[0], MessageAction)
assert isinstance(event_list[1], ChangeAgentStateAction)
elif is_appropriate == 'Yes':
assert len(event_list) == 1
assert type(event_list[0]) == MessageAction
assert isinstance(event_list[0], MessageAction)
@pytest.mark.parametrize(
@@ -665,8 +665,8 @@ async def test_check_fillaction(
if is_harmful == 'Yes':
assert len(event_list) == 2
assert type(event_list[0]) == BrowseInteractiveAction
assert type(event_list[1]) == ChangeAgentStateAction
assert isinstance(event_list[0], BrowseInteractiveAction)
assert isinstance(event_list[1], ChangeAgentStateAction)
elif is_harmful == 'No':
assert len(event_list) == 1
assert type(event_list[0]) == BrowseInteractiveAction
assert isinstance(event_list[0], BrowseInteractiveAction)