mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
28 Commits
fix/utils-
...
neubig/ref
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
84a53c0458 | ||
|
|
935cd9d9a2 | ||
|
|
e61b8aa31b | ||
|
|
5dbcf9b44c | ||
|
|
750b1dda2f | ||
|
|
dec6de531a | ||
|
|
c977657612 | ||
|
|
ab899831ce | ||
|
|
92b919cbc6 | ||
|
|
5e9729fbb9 | ||
|
|
3c5f432696 | ||
|
|
395e54bc77 | ||
|
|
6d8f983fb4 | ||
|
|
487b189025 | ||
|
|
e328c4d7b8 | ||
|
|
31fd6064ea | ||
|
|
1933180ba1 | ||
|
|
2e958d2e9c | ||
|
|
af9a2277b0 | ||
|
|
8c12c2edce | ||
|
|
2c4f3612bc | ||
|
|
6fd51a6ae7 | ||
|
|
926e3983d1 | ||
|
|
1975d39ec4 | ||
|
|
f557731ef7 | ||
|
|
26b857fc40 | ||
|
|
10fbf83dac | ||
|
|
8d0381aae7 |
@@ -10,6 +10,7 @@ import { addUserMessage } from "#/state/chat-slice";
|
||||
import { RootState } from "#/store";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { generateAgentStateChangeEvent } from "#/services/agent-state-service";
|
||||
import { getStopProcessesCommand } from "#/services/terminal-service";
|
||||
import { FeedbackModal } from "../feedback/feedback-modal";
|
||||
import { useScrollToBottom } from "#/hooks/use-scroll-to-bottom";
|
||||
import { TypingIndicator } from "./typing-indicator";
|
||||
@@ -82,7 +83,8 @@ export function ChatInterface() {
|
||||
|
||||
const handleStop = () => {
|
||||
posthog.capture("stop_button_clicked");
|
||||
send(generateAgentStateChangeEvent(AgentState.STOPPED));
|
||||
send(getStopProcessesCommand()); // First kill all processes
|
||||
send(generateAgentStateChangeEvent(AgentState.STOPPED)); // Then change agent state
|
||||
};
|
||||
|
||||
const onClickShareFeedbackActionButton = async (
|
||||
|
||||
@@ -4,3 +4,8 @@ export function getTerminalCommand(command: string, hidden: boolean = false) {
|
||||
const event = { action: ActionType.RUN, args: { command, hidden } };
|
||||
return event;
|
||||
}
|
||||
|
||||
export function getStopProcessesCommand() {
|
||||
const event = { action: ActionType.RUN, args: { command: "pkill -P $$" } };
|
||||
return event;
|
||||
}
|
||||
|
||||
@@ -86,4 +86,6 @@ class ActionTypeSchema(BaseModel):
|
||||
"""Retrieves content from a user workspace, microagent, or other source."""
|
||||
|
||||
|
||||
|
||||
|
||||
ActionType = ActionTypeSchema()
|
||||
|
||||
@@ -36,5 +36,5 @@ __all__ = [
|
||||
'MessageAction',
|
||||
'ActionConfirmationStatus',
|
||||
'AgentThinkAction',
|
||||
'RecallAction',
|
||||
|
||||
]
|
||||
|
||||
@@ -60,3 +60,6 @@ class IPythonRunCellAction(Action):
|
||||
@property
|
||||
def message(self) -> str:
|
||||
return f'Running Python code interactively: {self.code}'
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -17,12 +17,12 @@ class MessageAction(Action):
|
||||
return self.content
|
||||
|
||||
@property
|
||||
def images_urls(self) -> list[str] | None:
|
||||
def images_urls(self):
|
||||
# Deprecated alias for backward compatibility
|
||||
return self.image_urls
|
||||
|
||||
@images_urls.setter
|
||||
def images_urls(self, value: list[str] | None) -> None:
|
||||
def images_urls(self, value):
|
||||
self.image_urls = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
|
||||
@@ -39,23 +39,19 @@ class Event:
|
||||
@property
|
||||
def message(self) -> str | None:
|
||||
if hasattr(self, '_message'):
|
||||
msg = getattr(self, '_message')
|
||||
return str(msg) if msg is not None else None
|
||||
return self._message # type: ignore[attr-defined]
|
||||
return ''
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
if hasattr(self, '_id'):
|
||||
id_val = getattr(self, '_id')
|
||||
return int(id_val) if id_val is not None else Event.INVALID_ID
|
||||
return self._id # type: ignore[attr-defined]
|
||||
return Event.INVALID_ID
|
||||
|
||||
@property
|
||||
def timestamp(self) -> str | None:
|
||||
def timestamp(self):
|
||||
if hasattr(self, '_timestamp') and isinstance(self._timestamp, str):
|
||||
ts = getattr(self, '_timestamp')
|
||||
return str(ts) if ts is not None else None
|
||||
return None
|
||||
return self._timestamp
|
||||
|
||||
@timestamp.setter
|
||||
def timestamp(self, value: datetime) -> None:
|
||||
@@ -65,25 +61,22 @@ class Event:
|
||||
@property
|
||||
def source(self) -> EventSource | None:
|
||||
if hasattr(self, '_source'):
|
||||
src = getattr(self, '_source')
|
||||
return EventSource(src) if src is not None else None
|
||||
return self._source # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
@property
|
||||
def cause(self) -> int | None:
|
||||
if hasattr(self, '_cause'):
|
||||
cause_val = getattr(self, '_cause')
|
||||
return int(cause_val) if cause_val is not None else None
|
||||
return self._cause # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
@property
|
||||
def timeout(self) -> float | None:
|
||||
def timeout(self) -> int | None:
|
||||
if hasattr(self, '_timeout'):
|
||||
timeout_val = getattr(self, '_timeout')
|
||||
return float(timeout_val) if timeout_val is not None else None
|
||||
return self._timeout # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
def set_hard_timeout(self, value: float | None, blocking: bool = True) -> None:
|
||||
def set_hard_timeout(self, value: int | None, blocking: bool = True) -> None:
|
||||
"""Set the timeout for the event.
|
||||
|
||||
NOTE, this is a hard timeout, meaning that the event will be blocked
|
||||
@@ -107,8 +100,7 @@ class Event:
|
||||
@property
|
||||
def llm_metrics(self) -> Metrics | None:
|
||||
if hasattr(self, '_llm_metrics'):
|
||||
metrics = getattr(self, '_llm_metrics')
|
||||
return metrics if isinstance(metrics, Metrics) else None
|
||||
return self._llm_metrics # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
@llm_metrics.setter
|
||||
@@ -119,8 +111,7 @@ class Event:
|
||||
@property
|
||||
def tool_call_metadata(self) -> ToolCallMetadata | None:
|
||||
if hasattr(self, '_tool_call_metadata'):
|
||||
metadata = getattr(self, '_tool_call_metadata')
|
||||
return metadata if isinstance(metadata, ToolCallMetadata) else None
|
||||
return self._tool_call_metadata # type: ignore[attr-defined]
|
||||
return None
|
||||
|
||||
@tool_call_metadata.setter
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from browsergym.utils.obs import flatten_axtree_to_str
|
||||
|
||||
@@ -17,17 +16,13 @@ class BrowserOutputObservation(Observation):
|
||||
set_of_marks: str = field(default='', repr=False) # don't show in repr
|
||||
error: bool = False
|
||||
observation: str = ObservationType.BROWSE
|
||||
goal_image_urls: list[str] = field(default_factory=list)
|
||||
goal_image_urls: list = field(default_factory=list)
|
||||
# do not include in the memory
|
||||
open_pages_urls: list[str] = field(default_factory=list)
|
||||
open_pages_urls: list = field(default_factory=list)
|
||||
active_page_index: int = -1
|
||||
dom_object: dict[str, Any] = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
axtree_object: dict[str, Any] = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
extra_element_properties: dict[str, Any] = field(
|
||||
dom_object: dict = field(default_factory=dict, repr=False) # don't show in repr
|
||||
axtree_object: dict = field(default_factory=dict, repr=False) # don't show in repr
|
||||
extra_element_properties: dict = field(
|
||||
default_factory=dict, repr=False
|
||||
) # don't show in repr
|
||||
last_browser_action: str = ''
|
||||
@@ -107,4 +102,4 @@ class BrowserOutputObservation(Observation):
|
||||
skip_generic=False,
|
||||
filter_visible_only=filter_visible_only,
|
||||
)
|
||||
return str(cur_axtree_txt)
|
||||
return cur_axtree_txt
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Self
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -105,10 +105,10 @@ class CmdOutputObservation(Observation):
|
||||
content: str,
|
||||
command: str,
|
||||
observation: str = ObservationType.RUN,
|
||||
metadata: dict[str, Any] | CmdOutputMetadata | None = None,
|
||||
metadata: dict | CmdOutputMetadata | None = None,
|
||||
hidden: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(content)
|
||||
self.command = command
|
||||
self.observation = observation
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import re
|
||||
from typing import Any
|
||||
from openhands.core.exceptions import LLMMalformedActionError
|
||||
from openhands.events.action.action import Action
|
||||
from openhands.events.action.agent import (
|
||||
@@ -44,7 +42,7 @@ actions = (
|
||||
ACTION_TYPE_TO_CLASS = {action_class.action: action_class for action_class in actions} # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def handle_action_deprecated_args(args: dict[str, Any]) -> dict[str, Any]:
|
||||
def handle_action_deprecated_args(args: dict) -> dict:
|
||||
# keep_prompt has been deprecated in https://github.com/All-Hands-AI/OpenHands/pull/4881
|
||||
if 'keep_prompt' in args:
|
||||
args.pop('keep_prompt')
|
||||
@@ -128,5 +126,4 @@ def action_from_dict(action: dict) -> Action:
|
||||
raise LLMMalformedActionError(
|
||||
f'action={action} has the wrong arguments: {str(e)}'
|
||||
)
|
||||
assert isinstance(decoded_action, Action)
|
||||
return decoded_action
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -49,14 +48,14 @@ DELETE_FROM_TRAJECTORY_EXTRAS_AND_SCREENSHOTS = DELETE_FROM_TRAJECTORY_EXTRAS |
|
||||
}
|
||||
|
||||
|
||||
def event_from_dict(data: dict[str, Any]) -> 'Event':
|
||||
def event_from_dict(data) -> 'Event':
|
||||
evt: Event
|
||||
if 'action' in data:
|
||||
evt = action_from_dict(data)
|
||||
elif 'observation' in data:
|
||||
evt = observation_from_dict(data)
|
||||
else:
|
||||
raise ValueError(f'Unknown event type: {data}')
|
||||
raise ValueError('Unknown event type: ' + data)
|
||||
for key in UNDERSCORE_KEYS:
|
||||
if key in data:
|
||||
value = data[key]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import copy
|
||||
from typing import Any
|
||||
|
||||
from openhands.events.event import RecallType
|
||||
from openhands.events.observation.agent import (
|
||||
@@ -54,8 +53,8 @@ OBSERVATION_TYPE_TO_CLASS = {
|
||||
|
||||
|
||||
def _update_cmd_output_metadata(
|
||||
metadata: dict[str, Any] | CmdOutputMetadata | None, **kwargs: Any
|
||||
) -> dict[str, Any] | CmdOutputMetadata:
|
||||
metadata: dict | CmdOutputMetadata | None, **kwargs
|
||||
) -> dict | CmdOutputMetadata:
|
||||
"""Update the metadata of a CmdOutputObservation.
|
||||
|
||||
If metadata is None, create a new CmdOutputMetadata instance.
|
||||
@@ -129,6 +128,4 @@ def observation_from_dict(observation: dict) -> Observation:
|
||||
for item in extras['microagent_knowledge']
|
||||
]
|
||||
|
||||
obs = observation_class(content=content, **extras)
|
||||
assert isinstance(obs, Observation)
|
||||
return obs
|
||||
return observation_class(content=content, **extras)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
def remove_fields(obj: dict | list | tuple, fields: set[str]) -> None:
|
||||
def remove_fields(obj, fields: set[str]):
|
||||
"""Remove fields from an object.
|
||||
|
||||
Parameters:
|
||||
@@ -14,7 +14,7 @@ def remove_fields(obj: dict | list | tuple, fields: set[str]) -> None:
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
remove_fields(item, fields)
|
||||
if hasattr(obj, '__dataclass_fields__'):
|
||||
elif hasattr(obj, '__dataclass_fields__'):
|
||||
raise ValueError(
|
||||
'Object must not contain dataclass, consider converting to dict first'
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, AsyncIterator, Callable, Iterable
|
||||
from typing import Callable, Iterable
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.event import Event, EventSource
|
||||
@@ -43,21 +43,18 @@ async def session_exists(
|
||||
|
||||
|
||||
class AsyncEventStreamWrapper:
|
||||
def __init__(self, event_stream: 'EventStream', *args: Any, **kwargs: Any) -> None:
|
||||
def __init__(self, event_stream, *args, **kwargs):
|
||||
self.event_stream = event_stream
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
async def __aiter__(self) -> AsyncIterator[Event]:
|
||||
async def __aiter__(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Create an async generator that yields events
|
||||
for event in self.event_stream.get_events(*self.args, **self.kwargs):
|
||||
# Run the blocking get_events() in a thread pool
|
||||
def get_event(e: Event = event) -> Event:
|
||||
return e
|
||||
|
||||
yield await loop.run_in_executor(None, get_event)
|
||||
yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore
|
||||
|
||||
|
||||
class EventStream:
|
||||
@@ -124,14 +121,14 @@ class EventStream:
|
||||
if id >= self._cur_id:
|
||||
self._cur_id = id + 1
|
||||
|
||||
def _init_thread_loop(self, subscriber_id: str, callback_id: str) -> None:
|
||||
def _init_thread_loop(self, subscriber_id: str, callback_id: str):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
if subscriber_id not in self._thread_loops:
|
||||
self._thread_loops[subscriber_id] = {}
|
||||
self._thread_loops[subscriber_id][callback_id] = loop
|
||||
|
||||
def close(self) -> None:
|
||||
def close(self):
|
||||
self._stop_flag.set()
|
||||
if self._queue_thread.is_alive():
|
||||
self._queue_thread.join()
|
||||
@@ -146,7 +143,7 @@ class EventStream:
|
||||
while not self._queue.empty():
|
||||
self._queue.get()
|
||||
|
||||
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str) -> None:
|
||||
def _clean_up_subscriber(self, subscriber_id: str, callback_id: str):
|
||||
if subscriber_id not in self._subscribers:
|
||||
logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')
|
||||
return
|
||||
@@ -194,7 +191,7 @@ class EventStream:
|
||||
end_id: int | None = None,
|
||||
reverse: bool = False,
|
||||
filter_out_type: tuple[type[Event], ...] | None = None,
|
||||
filter_hidden: bool = False,
|
||||
filter_hidden=False,
|
||||
) -> Iterable[Event]:
|
||||
"""
|
||||
Retrieve events from the event stream, optionally filtering out events of a given type
|
||||
@@ -211,7 +208,7 @@ class EventStream:
|
||||
Events from the stream that match the criteria.
|
||||
"""
|
||||
|
||||
def should_filter(event: Event) -> bool:
|
||||
def should_filter(event: Event):
|
||||
if filter_hidden and hasattr(event, 'hidden') and event.hidden:
|
||||
return True
|
||||
if filter_out_type is not None and isinstance(event, filter_out_type):
|
||||
@@ -266,11 +263,8 @@ class EventStream:
|
||||
return self._cur_id - 1
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
subscriber_id: EventStreamSubscriber,
|
||||
callback: Callable[[Event], None],
|
||||
callback_id: str,
|
||||
) -> None:
|
||||
self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
|
||||
):
|
||||
initializer = partial(self._init_thread_loop, subscriber_id, callback_id)
|
||||
pool = ThreadPoolExecutor(max_workers=1, initializer=initializer)
|
||||
if subscriber_id not in self._subscribers:
|
||||
@@ -285,9 +279,7 @@ class EventStream:
|
||||
self._subscribers[subscriber_id][callback_id] = callback
|
||||
self._thread_pools[subscriber_id][callback_id] = pool
|
||||
|
||||
def unsubscribe(
|
||||
self, subscriber_id: EventStreamSubscriber, callback_id: str
|
||||
) -> None:
|
||||
def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
|
||||
if subscriber_id not in self._subscribers:
|
||||
logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
|
||||
return
|
||||
@@ -298,8 +290,8 @@ class EventStream:
|
||||
|
||||
self._clean_up_subscriber(subscriber_id, callback_id)
|
||||
|
||||
def add_event(self, event: Event, source: EventSource) -> None:
|
||||
if event.id != Event.INVALID_ID:
|
||||
def add_event(self, event: Event, source: EventSource):
|
||||
if hasattr(event, '_id') and event.id is not None:
|
||||
raise ValueError(
|
||||
f'Event already has an ID:{event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.'
|
||||
)
|
||||
@@ -318,13 +310,13 @@ class EventStream:
|
||||
)
|
||||
self._queue.put(event)
|
||||
|
||||
def set_secrets(self, secrets: dict[str, str]) -> None:
|
||||
def set_secrets(self, secrets: dict[str, str]):
|
||||
self.secrets = secrets.copy()
|
||||
|
||||
def update_secrets(self, secrets: dict[str, str]) -> None:
|
||||
def update_secrets(self, secrets: dict[str, str]):
|
||||
self.secrets.update(secrets)
|
||||
|
||||
def _replace_secrets(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||
def _replace_secrets(self, data: dict) -> dict:
|
||||
for key in data:
|
||||
if isinstance(data[key], dict):
|
||||
data[key] = self._replace_secrets(data[key])
|
||||
@@ -333,7 +325,7 @@ class EventStream:
|
||||
data[key] = data[key].replace(secret, '<secret_hidden>')
|
||||
return data
|
||||
|
||||
def _run_queue_loop(self) -> None:
|
||||
def _run_queue_loop(self):
|
||||
self._queue_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._queue_loop)
|
||||
try:
|
||||
@@ -341,7 +333,7 @@ class EventStream:
|
||||
finally:
|
||||
self._queue_loop.close()
|
||||
|
||||
async def _process_queue(self) -> None:
|
||||
async def _process_queue(self):
|
||||
while should_continue() and not self._stop_flag.is_set():
|
||||
event = None
|
||||
try:
|
||||
@@ -358,10 +350,8 @@ class EventStream:
|
||||
future = pool.submit(callback, event)
|
||||
future.add_done_callback(self._make_error_handler(callback_id, key))
|
||||
|
||||
def _make_error_handler(
|
||||
self, callback_id: str, subscriber_id: str
|
||||
) -> Callable[[Any], None]:
|
||||
def _handle_callback_error(fut: Any) -> None:
|
||||
def _make_error_handler(self, callback_id: str, subscriber_id: str):
|
||||
def _handle_callback_error(fut):
|
||||
try:
|
||||
# This will raise any exception that occurred during callback execution
|
||||
fut.result()
|
||||
@@ -374,14 +364,14 @@ class EventStream:
|
||||
|
||||
return _handle_callback_error
|
||||
|
||||
def filtered_events_by_source(self, source: EventSource) -> Iterable[Event]:
|
||||
def filtered_events_by_source(self, source: EventSource):
|
||||
for event in self.get_events():
|
||||
if event.source == source:
|
||||
yield event
|
||||
|
||||
def _should_filter_event(
|
||||
self,
|
||||
event: Event,
|
||||
event,
|
||||
query: str | None = None,
|
||||
event_types: tuple[type[Event], ...] | None = None,
|
||||
source: str | None = None,
|
||||
@@ -404,14 +394,13 @@ class EventStream:
|
||||
if event_types and not isinstance(event, event_types):
|
||||
return True
|
||||
|
||||
if source:
|
||||
if event.source is None or event.source.value != source:
|
||||
return True
|
||||
|
||||
if start_date and event.timestamp is not None and event.timestamp < start_date:
|
||||
if source and not event.source.value == source:
|
||||
return True
|
||||
|
||||
if end_date and event.timestamp is not None and event.timestamp > end_date:
|
||||
if start_date and event.timestamp < start_date:
|
||||
return True
|
||||
|
||||
if end_date and event.timestamp > end_date:
|
||||
return True
|
||||
|
||||
# Text search in event content if query provided
|
||||
@@ -433,7 +422,7 @@ class EventStream:
|
||||
start_id: int = 0,
|
||||
limit: int = 100,
|
||||
reverse: bool = False,
|
||||
) -> list[Event]:
|
||||
) -> list[type[Event]]:
|
||||
"""Get matching events from the event stream based on filters.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -7,9 +7,10 @@ from enum import Enum
|
||||
|
||||
import bashlex
|
||||
import libtmux
|
||||
import psutil
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.events.action import CmdRunAction
|
||||
from openhands.events.action import Action, CmdRunAction
|
||||
from openhands.events.observation import ErrorObservation
|
||||
from openhands.events.observation.commands import (
|
||||
CMD_OUTPUT_PS1_END,
|
||||
@@ -256,10 +257,41 @@ class BashSession:
|
||||
)
|
||||
return content
|
||||
|
||||
def kill_process(self, pid: int) -> bool:
|
||||
"""Kill a process by its PID.
|
||||
|
||||
Args:
|
||||
pid (int): The PID of the process to kill.
|
||||
|
||||
Returns:
|
||||
bool: True if the process was killed successfully, False otherwise.
|
||||
"""
|
||||
try:
|
||||
process = psutil.Process(pid)
|
||||
process.kill()
|
||||
return True
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
return False
|
||||
|
||||
def kill_all_processes(self) -> bool:
|
||||
"""Kill all processes associated with the current command.
|
||||
|
||||
Returns:
|
||||
bool: True if any processes were killed successfully, False otherwise.
|
||||
"""
|
||||
process_info = self.get_running_processes()
|
||||
success = False
|
||||
for pid in process_info['process_pids']:
|
||||
if pid != int(self.pane.cmd('display-message', '-p', '#{pane_pid}').stdout[0].strip()):
|
||||
if self.kill_process(pid):
|
||||
success = True
|
||||
return success
|
||||
|
||||
def close(self):
|
||||
"""Clean up the session."""
|
||||
if self._closed:
|
||||
return
|
||||
self.kill_all_processes() # Kill any remaining processes
|
||||
self.session.kill_session()
|
||||
self._closed = True
|
||||
|
||||
@@ -429,6 +461,119 @@ class BashSession:
|
||||
# Clear the current content
|
||||
self._clear_screen()
|
||||
|
||||
def get_running_processes(self):
|
||||
"""Get a list of processes that are currently running in the bash session.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing:
|
||||
- 'is_command_running': Boolean indicating if the last command is still running
|
||||
- 'current_command_pid': PID of the currently running command (if any)
|
||||
- 'processes': List of all processes visible to this bash session
|
||||
- 'command_processes': List of processes that are likely part of the current command
|
||||
- 'process_pids': List of PIDs of all processes
|
||||
- 'command_pids': List of PIDs of processes that are likely part of the current command
|
||||
"""
|
||||
# Check if a command is running in this session
|
||||
is_command_running = False
|
||||
|
||||
# Get the shell's PID directly from tmux
|
||||
shell_pid_str = (
|
||||
self.pane.cmd('display-message', '-p', '#{pane_pid}').stdout[0].strip()
|
||||
)
|
||||
shell_pid = int(shell_pid_str)
|
||||
|
||||
try:
|
||||
# Get process information for the shell
|
||||
shell_process = psutil.Process(shell_pid)
|
||||
process_list = []
|
||||
command_processes = []
|
||||
current_command_pid = None
|
||||
|
||||
# Get all child processes recursively
|
||||
children = shell_process.children(recursive=True)
|
||||
|
||||
# Add the shell process first
|
||||
process_str = f"{shell_pid} {shell_process.ppid()} {shell_process.status()[0]} {' '.join(shell_process.cmdline())}"
|
||||
process_list.append(process_str)
|
||||
|
||||
for child in children:
|
||||
try:
|
||||
# Skip if no cmdline (might be a kernel process)
|
||||
cmdline = child.cmdline()
|
||||
if not cmdline:
|
||||
continue
|
||||
|
||||
# Format the process info
|
||||
status_flag = child.status()[0]
|
||||
|
||||
# Build process string (PID PPID STATUS COMMAND)
|
||||
cmd_str = ' '.join(cmdline)
|
||||
process_str = f'{child.pid} {child.ppid()} {status_flag} {cmd_str}'
|
||||
process_list.append(process_str)
|
||||
|
||||
# Identify processes that are likely part of current command
|
||||
child_ppid = child.ppid()
|
||||
# Direct child of shell = likely current command
|
||||
if child_ppid == shell_pid:
|
||||
if not current_command_pid:
|
||||
current_command_pid = child.pid
|
||||
is_command_running = True
|
||||
command_processes.append(process_str)
|
||||
# Child of identified command process = part of current command
|
||||
elif current_command_pid and (
|
||||
child_ppid == current_command_pid
|
||||
or any(
|
||||
p.pid == child_ppid
|
||||
for p in children
|
||||
if p.pid == current_command_pid
|
||||
or p.ppid() == current_command_pid
|
||||
)
|
||||
):
|
||||
command_processes.append(process_str)
|
||||
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied):
|
||||
# Process may have terminated while we were examining it
|
||||
continue
|
||||
|
||||
# If we have no command processes, we're not running anything
|
||||
if not command_processes:
|
||||
is_command_running = False
|
||||
current_command_pid = None
|
||||
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
|
||||
logger.warning(f'Error accessing process information: {e}')
|
||||
return {
|
||||
'is_command_running': is_command_running,
|
||||
'current_command_pid': None,
|
||||
'processes': [],
|
||||
'command_processes': [],
|
||||
}
|
||||
|
||||
# Extract PIDs from process strings
|
||||
process_pids = []
|
||||
command_pids = []
|
||||
for proc in process_list:
|
||||
try:
|
||||
pid = int(proc.split()[0])
|
||||
process_pids.append(pid)
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
for proc in command_processes:
|
||||
try:
|
||||
pid = int(proc.split()[0])
|
||||
command_pids.append(pid)
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return {
|
||||
'is_command_running': is_command_running,
|
||||
'current_command_pid': current_command_pid,
|
||||
'processes': process_list,
|
||||
'command_processes': command_processes,
|
||||
'process_pids': process_pids,
|
||||
'command_pids': command_pids,
|
||||
}
|
||||
|
||||
def _combine_outputs_between_matches(
|
||||
self,
|
||||
pane_content: str,
|
||||
@@ -464,34 +609,95 @@ class BashSession:
|
||||
logger.debug(f'COMBINED OUTPUT: {combined_output}')
|
||||
return combined_output
|
||||
|
||||
def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservation:
|
||||
def execute(self, action: Action) -> CmdOutputObservation | ErrorObservation:
|
||||
"""Execute a command in the bash session."""
|
||||
if not self._initialized:
|
||||
raise RuntimeError('Bash session is not initialized')
|
||||
|
||||
# Strip the command of any leading/trailing whitespace
|
||||
logger.debug(f'RECEIVED ACTION: {action}')
|
||||
command = action.command.strip()
|
||||
is_input: bool = action.is_input
|
||||
|
||||
# If the previous command is not completed, we need to check if the command is empty
|
||||
|
||||
# Handle CmdRunAction
|
||||
if not isinstance(action, CmdRunAction):
|
||||
return ErrorObservation(f"Unsupported action type: {type(action)}")
|
||||
|
||||
command = action.command.strip()
|
||||
is_input = action.is_input
|
||||
|
||||
# Handle different command types
|
||||
if command == '':
|
||||
return self._handle_empty_command(action)
|
||||
elif is_input:
|
||||
return self._handle_input_command(action)
|
||||
else:
|
||||
return self._handle_normal_command(action)
|
||||
|
||||
def _handle_empty_command(self, action: CmdRunAction) -> CmdOutputObservation:
|
||||
"""Handle an empty command (usually to retrieve more output from a running command)."""
|
||||
assert action.command.strip() == ''
|
||||
# If the previous command is not in a continuing state, return an error
|
||||
if self.prev_status not in {
|
||||
BashCommandStatus.CONTINUE,
|
||||
BashCommandStatus.NO_CHANGE_TIMEOUT,
|
||||
BashCommandStatus.HARD_TIMEOUT,
|
||||
}:
|
||||
if command == '':
|
||||
return CmdOutputObservation(
|
||||
content='ERROR: No previous running command to retrieve logs from.',
|
||||
command='',
|
||||
metadata=CmdOutputMetadata(),
|
||||
)
|
||||
if is_input:
|
||||
return CmdOutputObservation(
|
||||
content='ERROR: No previous running command to interact with.',
|
||||
command='',
|
||||
metadata=CmdOutputMetadata(),
|
||||
)
|
||||
return CmdOutputObservation(
|
||||
content='ERROR: No previous running command to retrieve logs from.',
|
||||
command='',
|
||||
metadata=CmdOutputMetadata(),
|
||||
)
|
||||
|
||||
# Start polling for command completion
|
||||
return self._poll_for_command_completion('', action)
|
||||
|
||||
def _handle_input_command(self, action: CmdRunAction) -> CmdOutputObservation:
|
||||
"""Handle an input command (sent to a running process)."""
|
||||
command = action.command.strip()
|
||||
|
||||
# If the previous command is not in a continuing state, return an error
|
||||
if self.prev_status not in {
|
||||
BashCommandStatus.CONTINUE,
|
||||
BashCommandStatus.NO_CHANGE_TIMEOUT,
|
||||
BashCommandStatus.HARD_TIMEOUT,
|
||||
}:
|
||||
return CmdOutputObservation(
|
||||
content='ERROR: No previous running command to interact with.',
|
||||
command='',
|
||||
metadata=CmdOutputMetadata(),
|
||||
)
|
||||
|
||||
# Check if it's a special key
|
||||
is_special_key = self._is_special_key(command)
|
||||
|
||||
# Send the input to the pane
|
||||
logger.debug(f'SENDING INPUT TO RUNNING PROCESS: {command!r}')
|
||||
self.pane.send_keys(
|
||||
command,
|
||||
enter=not is_special_key,
|
||||
)
|
||||
|
||||
# Start polling for command completion
|
||||
return self._poll_for_command_completion(command, action)
|
||||
|
||||
def _handle_normal_command(
|
||||
self, action: CmdRunAction
|
||||
) -> CmdOutputObservation | ErrorObservation:
|
||||
"""Handle a normal command."""
|
||||
command = action.command.strip()
|
||||
|
||||
# Check if command is running previous command first
|
||||
last_pane_output = self._get_pane_content()
|
||||
if (
|
||||
self.prev_status
|
||||
in {
|
||||
BashCommandStatus.HARD_TIMEOUT,
|
||||
BashCommandStatus.NO_CHANGE_TIMEOUT,
|
||||
}
|
||||
and not last_pane_output.endswith(
|
||||
CMD_OUTPUT_PS1_END
|
||||
) # prev command is not completed
|
||||
):
|
||||
return self._handle_interrupted_command(command, last_pane_output)
|
||||
|
||||
# Check if the command is a single command or multiple commands
|
||||
splited_commands = split_bash_commands(command)
|
||||
@@ -504,67 +710,56 @@ class BashSession:
|
||||
)
|
||||
)
|
||||
|
||||
# Convert command to raw string and send it
|
||||
is_special_key = self._is_special_key(command)
|
||||
command = escape_bash_special_chars(command)
|
||||
logger.debug(f'SENDING COMMAND: {command!r}')
|
||||
self.pane.send_keys(
|
||||
command,
|
||||
enter=not is_special_key,
|
||||
)
|
||||
|
||||
# Start polling for command completion
|
||||
return self._poll_for_command_completion(command, action)
|
||||
|
||||
def _handle_interrupted_command(
|
||||
self, command: str, last_pane_output: str
|
||||
) -> CmdOutputObservation:
|
||||
"""Handle the case where a new command is sent while a previous command is still running."""
|
||||
_ps1_matches = CmdOutputMetadata.matches_ps1_metadata(last_pane_output)
|
||||
raw_command_output = self._combine_outputs_between_matches(
|
||||
last_pane_output, _ps1_matches
|
||||
)
|
||||
metadata = CmdOutputMetadata() # No metadata available
|
||||
metadata.suffix = (
|
||||
f'\n[Your command "{command}" is NOT executed. '
|
||||
f'The previous command is still running - You CANNOT send new commands until the previous command is completed. '
|
||||
'By setting `is_input` to `true`, you can interact with the current process: '
|
||||
"You may wait longer to see additional output of the previous command by sending empty command '', "
|
||||
'send other commands to interact with the current process, '
|
||||
'or send keys ("C-c", "C-z", "C-d") to interrupt/kill the previous command before sending your new command.]'
|
||||
)
|
||||
logger.debug(f'PREVIOUS COMMAND OUTPUT: {raw_command_output}')
|
||||
command_output = self._get_command_output(
|
||||
command,
|
||||
raw_command_output,
|
||||
metadata,
|
||||
continue_prefix='[Below is the output of the previous command.]\n',
|
||||
)
|
||||
return CmdOutputObservation(
|
||||
command=command,
|
||||
content=command_output,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _poll_for_command_completion(
|
||||
self, command: str, action: CmdRunAction
|
||||
) -> CmdOutputObservation:
|
||||
"""Poll for command completion and handle timeouts."""
|
||||
start_time = time.time()
|
||||
last_change_time = start_time
|
||||
last_pane_output = self._get_pane_content()
|
||||
|
||||
# When prev command is still running, and we are trying to send a new command
|
||||
if (
|
||||
self.prev_status
|
||||
in {
|
||||
BashCommandStatus.HARD_TIMEOUT,
|
||||
BashCommandStatus.NO_CHANGE_TIMEOUT,
|
||||
}
|
||||
and not last_pane_output.endswith(
|
||||
CMD_OUTPUT_PS1_END
|
||||
) # prev command is not completed
|
||||
and not is_input
|
||||
and command != '' # not input and not empty command
|
||||
):
|
||||
_ps1_matches = CmdOutputMetadata.matches_ps1_metadata(last_pane_output)
|
||||
raw_command_output = self._combine_outputs_between_matches(
|
||||
last_pane_output, _ps1_matches
|
||||
)
|
||||
metadata = CmdOutputMetadata() # No metadata available
|
||||
metadata.suffix = (
|
||||
f'\n[Your command "{command}" is NOT executed. '
|
||||
f'The previous command is still running - You CANNOT send new commands until the previous command is completed. '
|
||||
'By setting `is_input` to `true`, you can interact with the current process: '
|
||||
"You may wait longer to see additional output of the previous command by sending empty command '', "
|
||||
'send other commands to interact with the current process, '
|
||||
'or send keys ("C-c", "C-z", "C-d") to interrupt/kill the previous command before sending your new command.]'
|
||||
)
|
||||
logger.debug(f'PREVIOUS COMMAND OUTPUT: {raw_command_output}')
|
||||
command_output = self._get_command_output(
|
||||
command,
|
||||
raw_command_output,
|
||||
metadata,
|
||||
continue_prefix='[Below is the output of the previous command.]\n',
|
||||
)
|
||||
return CmdOutputObservation(
|
||||
command=command,
|
||||
content=command_output,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# Send actual command/inputs to the pane
|
||||
if command != '':
|
||||
is_special_key = self._is_special_key(command)
|
||||
if is_input:
|
||||
logger.debug(f'SENDING INPUT TO RUNNING PROCESS: {command!r}')
|
||||
self.pane.send_keys(
|
||||
command,
|
||||
enter=not is_special_key,
|
||||
)
|
||||
else:
|
||||
# convert command to raw string
|
||||
command = escape_bash_special_chars(command)
|
||||
logger.debug(f'SENDING COMMAND: {command!r}')
|
||||
self.pane.send_keys(
|
||||
command,
|
||||
enter=not is_special_key,
|
||||
)
|
||||
|
||||
# Loop until the command completes or times out
|
||||
while should_continue():
|
||||
_start_time = time.time()
|
||||
@@ -575,6 +770,18 @@ class BashSession:
|
||||
)
|
||||
logger.debug(f'BEGIN OF PANE CONTENT: {cur_pane_output.split("\n")[:10]}')
|
||||
logger.debug(f'END OF PANE CONTENT: {cur_pane_output.split("\n")[-10:]}')
|
||||
|
||||
# Log running processes for debugging
|
||||
try:
|
||||
process_info = self.get_running_processes()
|
||||
logger.debug(
|
||||
f'RUNNING PROCESSES: is_command_running={process_info["is_command_running"]}, '
|
||||
f'current_command_pid={process_info["current_command_pid"]}, '
|
||||
f'command_processes_count={len(process_info["command_processes"])}'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f'Failed to get running processes: {e}')
|
||||
|
||||
ps1_matches = CmdOutputMetadata.matches_ps1_metadata(cur_pane_output)
|
||||
if cur_pane_output != last_pane_output:
|
||||
last_pane_output = cur_pane_output
|
||||
@@ -582,7 +789,6 @@ class BashSession:
|
||||
logger.debug(f'CONTENT UPDATED DETECTED at {last_change_time}')
|
||||
|
||||
# 1) Execution completed
|
||||
# if the last command output contains the end marker
|
||||
if cur_pane_output.rstrip().endswith(CMD_OUTPUT_PS1_END.rstrip()):
|
||||
return self._handle_completed_command(
|
||||
command,
|
||||
@@ -591,8 +797,6 @@ class BashSession:
|
||||
)
|
||||
|
||||
# 2) Execution timed out since there's no change in output
|
||||
# for a while (self.NO_CHANGE_TIMEOUT_SECONDS)
|
||||
# We ignore this if the command is *blocking
|
||||
time_since_last_change = time.time() - last_change_time
|
||||
logger.debug(
|
||||
f'CHECKING NO CHANGE TIMEOUT ({self.NO_CHANGE_TIMEOUT_SECONDS}s): elapsed {time_since_last_change}. Action blocking: {action.blocking}'
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -20,12 +19,8 @@ class SecurityAnalyzer:
|
||||
event_stream: The event stream to listen for events.
|
||||
"""
|
||||
self.event_stream = event_stream
|
||||
|
||||
def sync_on_event(event: Event) -> None:
|
||||
asyncio.create_task(self.on_event(event))
|
||||
|
||||
self.event_stream.subscribe(
|
||||
EventStreamSubscriber.SECURITY_ANALYZER, sync_on_event, str(uuid4())
|
||||
EventStreamSubscriber.SECURITY_ANALYZER, self.on_event, str(uuid4())
|
||||
)
|
||||
|
||||
async def on_event(self, event: Event) -> None:
|
||||
|
||||
@@ -58,8 +58,6 @@ async def connect(connection_id: str, environ):
|
||||
f'Connected to conversation {conversation_id} with connection_id {connection_id}. Replaying event stream...'
|
||||
)
|
||||
agent_state_changed = None
|
||||
if event_stream is None:
|
||||
raise ConnectionRefusedError('Failed to join conversation')
|
||||
async_stream = AsyncEventStreamWrapper(event_stream, latest_event_id + 1)
|
||||
async for event in async_stream:
|
||||
logger.info(f'oh_event: {event.__class__.__name__}')
|
||||
|
||||
@@ -7,9 +7,7 @@ GENERAL_TIMEOUT: int = 15
|
||||
EXECUTOR = ThreadPoolExecutor()
|
||||
|
||||
|
||||
async def call_sync_from_async(
|
||||
fn: Callable[..., object], *args: object, **kwargs: object
|
||||
) -> object:
|
||||
async def call_sync_from_async(fn: Callable, *args, **kwargs):
|
||||
"""
|
||||
Shorthand for running a function in the default background thread pool executor
|
||||
and awaiting the result. The nature of synchronous code is that the future
|
||||
@@ -22,11 +20,8 @@ async def call_sync_from_async(
|
||||
|
||||
|
||||
def call_async_from_sync(
|
||||
corofn: Callable[..., Coroutine[object, object, object]],
|
||||
timeout: float = GENERAL_TIMEOUT,
|
||||
*args: object,
|
||||
**kwargs: object,
|
||||
) -> object:
|
||||
corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
|
||||
):
|
||||
"""
|
||||
Shorthand for running a coroutine in the default background thread pool executor
|
||||
and awaiting the result
|
||||
@@ -37,12 +32,12 @@ def call_async_from_sync(
|
||||
if not asyncio.iscoroutinefunction(corofn):
|
||||
raise ValueError('corofn is not a coroutine function')
|
||||
|
||||
async def arun() -> object:
|
||||
async def arun():
|
||||
coro = corofn(*args, **kwargs)
|
||||
result = await coro
|
||||
return result
|
||||
|
||||
def run() -> object:
|
||||
def run():
|
||||
loop_for_thread = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop_for_thread)
|
||||
@@ -57,15 +52,10 @@ def call_async_from_sync(
|
||||
|
||||
|
||||
async def call_coro_in_bg_thread(
|
||||
corofn: Callable[..., Coroutine[object, object, object]],
|
||||
timeout: float = GENERAL_TIMEOUT,
|
||||
*args: object,
|
||||
**kwargs: object,
|
||||
) -> object:
|
||||
corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs
|
||||
):
|
||||
"""Function for running a coroutine in a background thread."""
|
||||
return await call_sync_from_async(
|
||||
call_async_from_sync, corofn, timeout, *args, **kwargs
|
||||
)
|
||||
await call_sync_from_async(call_async_from_sync, corofn, timeout, *args, **kwargs)
|
||||
|
||||
|
||||
async def wait_all(
|
||||
@@ -100,8 +90,8 @@ async def wait_all(
|
||||
|
||||
|
||||
class AsyncException(Exception):
|
||||
def __init__(self, exceptions: list[Exception]) -> None:
|
||||
def __init__(self, exceptions):
|
||||
self.exceptions = exceptions
|
||||
|
||||
def __str__(self) -> str:
|
||||
def __str__(self):
|
||||
return '\n'.join(str(e) for e in self.exceptions)
|
||||
|
||||
@@ -25,7 +25,7 @@ class Chunk(BaseModel):
|
||||
return ret
|
||||
|
||||
|
||||
def _create_chunks_from_raw_string(content: str, size: int) -> list[Chunk]:
|
||||
def _create_chunks_from_raw_string(content: str, size: int):
|
||||
lines = content.split('\n')
|
||||
ret = []
|
||||
for i in range(0, len(lines), size):
|
||||
@@ -65,7 +65,7 @@ def normalized_lcs(chunk: str, query: str) -> float:
|
||||
"""
|
||||
if len(chunk) == 0:
|
||||
return 0.0
|
||||
_score = float(pylcs.lcs_sequence_length(chunk, query))
|
||||
_score = pylcs.lcs_sequence_length(chunk, query)
|
||||
return _score / len(chunk)
|
||||
|
||||
|
||||
|
||||
@@ -15,17 +15,15 @@ Hopefully, this will be fixed soon and we can remove this abomination.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
from typing import Any, Callable, Iterator, TypeVar
|
||||
from typing import Callable
|
||||
|
||||
import httpx
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ensure_httpx_close() -> Iterator[None]:
|
||||
def ensure_httpx_close():
|
||||
wrapped_class = httpx.Client
|
||||
proxys: list[Any] = []
|
||||
proxys = []
|
||||
|
||||
class ClientProxy:
|
||||
"""
|
||||
@@ -34,52 +32,47 @@ def ensure_httpx_close() -> Iterator[None]:
|
||||
where a client is reused, we need to be able to reuse the client even after closing it.
|
||||
"""
|
||||
|
||||
client_constructor: Callable[..., Any]
|
||||
args: tuple[Any, ...]
|
||||
kwargs: dict[str, Any]
|
||||
client: httpx.Client | None
|
||||
client_constructor: Callable
|
||||
args: tuple
|
||||
kwargs: dict
|
||||
client: httpx.Client
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.client = wrapped_class(*self.args, **self.kwargs)
|
||||
proxys.append(self)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
def __getattr__(self, name):
|
||||
# Invoke a method on the proxied client - create one if required
|
||||
if self.client is None:
|
||||
self.client = wrapped_class(*self.args, **self.kwargs)
|
||||
return getattr(self.client, name)
|
||||
|
||||
def close(self) -> None:
|
||||
def close(self):
|
||||
# Close the client if it is open
|
||||
if self.client:
|
||||
self.client.close()
|
||||
self.client = None
|
||||
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
def __iter__(self, *args, **kwargs):
|
||||
# We have to override this as debuggers invoke it causing the client to reopen
|
||||
if self.client:
|
||||
# Convert client to list first since it's not directly iterable
|
||||
return iter(list(self.client.__dict__.items()))
|
||||
return iter([])
|
||||
return self.client.iter(*args, **kwargs)
|
||||
return object.__getattribute__(self, 'iter')(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def is_closed(self) -> bool:
|
||||
def is_closed(self):
|
||||
# Check if closed
|
||||
if self.client is None:
|
||||
return True
|
||||
# Convert to bool to ensure we return a bool
|
||||
return bool(self.client.is_closed)
|
||||
return self.client.is_closed
|
||||
|
||||
# We need to monkey patch the Client class to track instances
|
||||
# This is a hack until LiteLLM fixes their client lifecycle management
|
||||
original_client = httpx.Client
|
||||
httpx.Client = ClientProxy
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
httpx.Client = original_client
|
||||
httpx.Client = wrapped_class
|
||||
while proxys:
|
||||
proxy = proxys.pop()
|
||||
proxy.close()
|
||||
|
||||
@@ -4,15 +4,12 @@ from typing import Type, TypeVar
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def import_from(qual_name: str) -> type:
|
||||
def import_from(qual_name: str):
|
||||
"""Import the value from the qualified name given"""
|
||||
parts = qual_name.split('.')
|
||||
module_name = '.'.join(parts[:-1])
|
||||
module = importlib.import_module(module_name)
|
||||
result = getattr(module, parts[-1])
|
||||
assert isinstance(
|
||||
result, type
|
||||
), f'Expected {qual_name} to be a type, got {type(result)}'
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import base64
|
||||
from typing import Any, AsyncIterator, Callable
|
||||
from typing import AsyncIterator, Callable
|
||||
|
||||
|
||||
def offset_to_page_id(offset: int, has_next: bool) -> str | None:
|
||||
@@ -16,7 +16,7 @@ def page_id_to_offset(page_id: str | None) -> int:
|
||||
return offset
|
||||
|
||||
|
||||
async def iterate(fn: Callable[..., Any], **kwargs: Any) -> AsyncIterator[Any]:
|
||||
async def iterate(fn: Callable, **kwargs) -> AsyncIterator:
|
||||
"""Iterate over paged result sets. Assumes that the results sets contain an array of result objects, and a next_page_id"""
|
||||
kwargs = {**kwargs}
|
||||
kwargs['page_id'] = None
|
||||
|
||||
@@ -22,7 +22,4 @@ def colorize(text: str, color: TermColor = TermColor.WARNING) -> str:
|
||||
Returns:
|
||||
str: Colored text
|
||||
"""
|
||||
# colored() returns a string with ANSI color codes
|
||||
result = colored(text, color.value)
|
||||
assert isinstance(result, str)
|
||||
return result
|
||||
return colored(text, color.value)
|
||||
|
||||
@@ -52,7 +52,7 @@ def test_bash_server(temp_dir, runtime_cls, run_as_openhands):
|
||||
assert obs.exit_code == -1
|
||||
assert 'Serving HTTP on 0.0.0.0 port 8080' in obs.content
|
||||
assert (
|
||||
"[The command timed out after 1.0 seconds. You may wait longer to see additional output by sending empty command '', send other commands to interact with the current process, or send keys to interrupt/kill the command.]"
|
||||
"[The command timed out after 1 seconds. You may wait longer to see additional output by sending empty command '', send other commands to interact with the current process, or send keys to interrupt/kill the command.]"
|
||||
in obs.metadata.suffix
|
||||
)
|
||||
|
||||
|
||||
@@ -386,3 +386,40 @@ def test_python_interactive_input():
|
||||
assert session.prev_status == BashCommandStatus.COMPLETED
|
||||
|
||||
session.close()
|
||||
|
||||
|
||||
def test_get_running_processes():
|
||||
"""Test the get_running_processes method to detect running processes."""
|
||||
session = BashSession(work_dir=os.getcwd(), no_change_timeout_seconds=2)
|
||||
session.initialize()
|
||||
|
||||
# First check with no running command
|
||||
process_info = session.get_running_processes()
|
||||
assert isinstance(process_info, dict)
|
||||
assert 'is_command_running' in process_info
|
||||
assert process_info['is_command_running'] is False
|
||||
assert 'processes' in process_info
|
||||
assert len(process_info['processes']) == 1 # should have the shell process
|
||||
assert 'command_processes' in process_info
|
||||
assert len(process_info['command_processes']) == 0
|
||||
assert 'current_command_pid' in process_info
|
||||
assert process_info['current_command_pid'] is None
|
||||
|
||||
session.execute(CmdRunAction('sleep 120', blocking=False))
|
||||
|
||||
# Check running processes
|
||||
process_info = session.get_running_processes()
|
||||
assert process_info['is_command_running'] is True
|
||||
assert process_info['current_command_pid'] is not None
|
||||
assert len(process_info['command_processes']) > 0
|
||||
|
||||
# Send Ctrl+C to terminate the process
|
||||
session.execute(CmdRunAction('C-c', is_input=True))
|
||||
|
||||
# Verify process is no longer running
|
||||
process_info = session.get_running_processes()
|
||||
assert process_info['is_command_running'] is False
|
||||
assert process_info['current_command_pid'] is None
|
||||
assert len(process_info['command_processes']) == 0
|
||||
|
||||
session.close()
|
||||
|
||||
@@ -24,7 +24,6 @@ from openhands.events.observation.files import (
|
||||
FileReadObservation,
|
||||
FileWriteObservation,
|
||||
)
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.locations import get_conversation_event_filename
|
||||
|
||||
@@ -158,31 +157,6 @@ def test_get_matching_events_source_filter(temp_dir: str):
|
||||
and events[0].source == EventSource.ENVIRONMENT
|
||||
)
|
||||
|
||||
# Test that source comparison works correctly with None source
|
||||
null_source_event = NullObservation('test4')
|
||||
event_stream.add_event(null_source_event, EventSource.AGENT)
|
||||
event = event_stream.get_event(event_stream.get_latest_event_id())
|
||||
event._source = None # type: ignore
|
||||
|
||||
# Update the serialized version
|
||||
data = event_to_dict(event)
|
||||
event_stream.file_store.write(
|
||||
event_stream._get_filename_for_id(event.id, event_stream.user_id),
|
||||
json.dumps(data),
|
||||
)
|
||||
|
||||
# Verify that source comparison works correctly
|
||||
assert event_stream._should_filter_event(
|
||||
event, source='agent'
|
||||
) # Should filter out None source events
|
||||
assert not event_stream._should_filter_event(
|
||||
event, source=None
|
||||
) # Should not filter out when source filter is None
|
||||
|
||||
# Filter by AGENT source again
|
||||
events = event_stream.get_matching_events(source='agent')
|
||||
assert len(events) == 2 # Should not include the None source event
|
||||
|
||||
|
||||
def test_get_matching_events_pagination(temp_dir: str):
|
||||
file_store = get_file_store('local', temp_dir)
|
||||
@@ -236,7 +210,7 @@ def test_memory_usage_file_operations(temp_dir: str):
|
||||
"""
|
||||
|
||||
def get_memory_mb():
|
||||
"""Get current memory usage in MB."""
|
||||
"""Get current memory usage in MB"""
|
||||
process = psutil.Process(os.getpid())
|
||||
return process.memory_info().rss / 1024 / 1024
|
||||
|
||||
|
||||
@@ -61,13 +61,13 @@ def test_get_messages(codeact_agent: CodeActAgent):
|
||||
message_action_1._source = 'user'
|
||||
history.append(message_action_1)
|
||||
message_action_2 = MessageAction('Sure!')
|
||||
message_action_2._source = 'agent'
|
||||
message_action_2._source = 'assistant'
|
||||
history.append(message_action_2)
|
||||
message_action_3 = MessageAction('Hello, agent!')
|
||||
message_action_3._source = 'user'
|
||||
history.append(message_action_3)
|
||||
message_action_4 = MessageAction('Hello, user!')
|
||||
message_action_4._source = 'agent'
|
||||
message_action_4._source = 'assistant'
|
||||
history.append(message_action_4)
|
||||
message_action_5 = MessageAction('Laaaaaaaast!')
|
||||
message_action_5._source = 'user'
|
||||
@@ -106,7 +106,7 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
|
||||
message_action_user._source = 'user'
|
||||
history.append(message_action_user)
|
||||
message_action_agent = MessageAction(f'Agent message {i}')
|
||||
message_action_agent._source = 'agent'
|
||||
message_action_agent._source = 'assistant'
|
||||
history.append(message_action_agent)
|
||||
|
||||
codeact_agent.reset()
|
||||
|
||||
@@ -603,11 +603,11 @@ async def test_check_usertask(
|
||||
|
||||
if is_appropriate == 'No':
|
||||
assert len(event_list) == 2
|
||||
assert isinstance(event_list[0], MessageAction)
|
||||
assert isinstance(event_list[1], ChangeAgentStateAction)
|
||||
assert type(event_list[0]) == MessageAction
|
||||
assert type(event_list[1]) == ChangeAgentStateAction
|
||||
elif is_appropriate == 'Yes':
|
||||
assert len(event_list) == 1
|
||||
assert isinstance(event_list[0], MessageAction)
|
||||
assert type(event_list[0]) == MessageAction
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -665,8 +665,8 @@ async def test_check_fillaction(
|
||||
|
||||
if is_harmful == 'Yes':
|
||||
assert len(event_list) == 2
|
||||
assert isinstance(event_list[0], BrowseInteractiveAction)
|
||||
assert isinstance(event_list[1], ChangeAgentStateAction)
|
||||
assert type(event_list[0]) == BrowseInteractiveAction
|
||||
assert type(event_list[1]) == ChangeAgentStateAction
|
||||
elif is_harmful == 'No':
|
||||
assert len(event_list) == 1
|
||||
assert isinstance(event_list[0], BrowseInteractiveAction)
|
||||
assert type(event_list[0]) == BrowseInteractiveAction
|
||||
|
||||
39
tests/unit/test_stop_button_issue.py
Normal file
39
tests/unit/test_stop_button_issue.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import time
|
||||
|
||||
from openhands.events.action import CmdRunAction
|
||||
from openhands.runtime.utils.bash import BashSession
|
||||
|
||||
|
||||
def test_stop_button_background_process():
|
||||
session = BashSession(work_dir='/tmp', no_change_timeout_seconds=2)
|
||||
session.initialize()
|
||||
|
||||
# Start a process that runs indefinitely and detaches from the terminal
|
||||
session.execute(
|
||||
CmdRunAction(
|
||||
'nohup sleep 60 > /dev/null 2>&1 &'
|
||||
) # Background process that detaches from terminal
|
||||
)
|
||||
time.sleep(2) # Give time for the process to start
|
||||
|
||||
# Get initial process info
|
||||
process_info = session.get_running_processes()
|
||||
print('Initial process info:', process_info) # Debug output
|
||||
assert any(
|
||||
'sleep' in p for p in process_info['processes']
|
||||
), 'Expected to find sleep process'
|
||||
initial_processes = [p for p in process_info['processes'] if 'sleep' in p]
|
||||
assert len(initial_processes) > 0, 'Expected at least one sleep process'
|
||||
|
||||
# Send kill command to stop it
|
||||
session.execute(CmdRunAction('pkill -P $$'))
|
||||
time.sleep(1) # Give time for processes to be killed
|
||||
|
||||
# Check if process is still running (it should be terminated)
|
||||
process_info = session.get_running_processes()
|
||||
print('Process info after kill command:', process_info) # Debug output
|
||||
assert not any(
|
||||
'sleep' in p for p in process_info['processes']
|
||||
), 'Background process should be terminated'
|
||||
|
||||
session.close()
|
||||
Reference in New Issue
Block a user