Compare commits

..

26 Commits

Author SHA1 Message Date
openhands ac07afbb3c fix: make LLM response handling more robust for deepseek and other providers
- Replace type-based checks with dictionary-based checks
- Update message access to use dictionary syntax
- Remove unused imports
2025-03-25 06:28:34 +00:00
openhands b9f330db1b Merge main into fix/llm-mypy-errors (keeping our poetry.lock) 2025-03-24 23:35:44 +00:00
openhands 096b259694 Replace partial() with proper type hints
- Remove partial() since we're not using its argument binding functionality
- Add proper type hints for class attributes and variables
- Fix mypy errors with proper type annotations
2025-03-03 20:00:53 +00:00
openhands 7f3202d0e0 Replace partial() with direct function wrappers
- Replace partial() with direct function wrappers to make the code clearer
- Keep type hints using cast() to satisfy mypy
- Add proper type hints for all functions
2025-03-03 19:56:30 +00:00
openhands e63dfca517 Replace partial() with cast() for type hints
- Replace partial() with cast() since we're not using partial's argument binding
- Add proper type hints for the cast() calls
- Add missing imports for type hints
2025-03-03 19:50:30 +00:00
openhands cb705b736f Remove duplicate default values in retry_mixin.py
- Remove hardcoded default values for retry parameters in retry_mixin.py
  since they are already defined in LLMConfig
- Add comment explaining where the default values come from
- Keep partial() usage to maintain type compatibility
2025-03-03 18:50:33 +00:00
Graham Neubig 1e5c4da0fc Merge branch 'main' into fix/llm-mypy-errors 2025-03-03 13:46:18 -05:00
Graham Neubig 526618753d Update .github/workflows/py-unit-tests.yml 2025-02-22 23:40:37 -05:00
openhands e25e6766fb Fix ruff and ruff-format issues 2025-02-22 23:27:10 +00:00
Graham Neubig 1b9a2b43c3 Merge branch 'main' into fix/llm-mypy-errors 2025-02-22 17:27:14 -05:00
openhands 7886c1f920 Fix litellm type imports to use OpenAI types directly 2025-02-21 22:21:40 +00:00
openhands 4b49ffb01d Fix litellm type imports to use OpenAI types directly 2025-02-21 17:30:43 +00:00
openhands f74ce56a35 Fix litellm import path for Choices and StreamingChoices 2025-02-21 17:29:12 +00:00
openhands 9ccf680c38 Remove parallel test execution to fix failing tests 2025-02-21 17:25:59 +00:00
Graham Neubig 41068f6ea1 Merge branch 'main' into fix/llm-mypy-errors 2025-02-21 08:22:40 -05:00
Engel Nyst 5b1c8bc2e8 Merge branch 'main' into fix/llm-mypy-errors 2025-02-19 19:28:48 +01:00
Xingyao Wang 5b8db983b7 Merge branch 'main' into fix/llm-mypy-errors 2025-02-19 12:19:39 -05:00
Graham Neubig 18543a2efa Merge branch 'main' into fix/llm-mypy-errors 2025-02-19 05:22:06 -05:00
openhands 63d9c3d668 Revert changes outside openhands/llm directory 2025-02-19 02:43:02 +00:00
openhands 2589b13815 Fix mypy errors in openhands/llm directory 2025-02-19 02:42:17 +00:00
Graham Neubig 592aca05e1 Merge branch 'main' into feature/strict-mypy-checks 2025-02-18 20:14:17 -05:00
Graham Neubig d309455733 Merge branch 'main' into feature/strict-mypy-checks 2025-02-11 10:19:36 -05:00
Graham Neubig 66a7920539 Merge branch 'main' into feature/strict-mypy-checks 2025-02-10 13:06:49 -05:00
Graham Neubig 64ebef3646 Update .github/workflows/lint.yml 2025-01-21 14:52:56 -05:00
Graham Neubig 7a259915c1 Update .github/workflows/lint.yml 2025-01-21 14:47:51 -05:00
openhands 66bd8fdbcd Enable strict type checking with mypy
- Update mypy configuration with stricter type checking rules
- Add more type stubs to pre-commit configuration
- Run mypy both through pre-commit and directly in CI
- Install project in editable mode for better type checking
- Set correct PYTHONPATH in CI environment
2025-01-21 19:12:09 +00:00
16 changed files with 213 additions and 465 deletions
@@ -10,7 +10,6 @@ import { addUserMessage } from "#/state/chat-slice";
import { RootState } from "#/store";
import { AgentState } from "#/types/agent-state";
import { generateAgentStateChangeEvent } from "#/services/agent-state-service";
import { getStopProcessesCommand } from "#/services/terminal-service";
import { FeedbackModal } from "../feedback/feedback-modal";
import { useScrollToBottom } from "#/hooks/use-scroll-to-bottom";
import { TypingIndicator } from "./typing-indicator";
@@ -83,8 +82,7 @@ export function ChatInterface() {
const handleStop = () => {
posthog.capture("stop_button_clicked");
send(getStopProcessesCommand()); // First kill all processes
send(generateAgentStateChangeEvent(AgentState.STOPPED)); // Then change agent state
send(generateAgentStateChangeEvent(AgentState.STOPPED));
};
const onClickShareFeedbackActionButton = async (
@@ -4,8 +4,3 @@ export function getTerminalCommand(command: string, hidden: boolean = false) {
const event = { action: ActionType.RUN, args: { command, hidden } };
return event;
}
export function getStopProcessesCommand() {
const event = { action: ActionType.RUN, args: { command: "pkill -P $$" } };
return event;
}
-2
View File
@@ -86,6 +86,4 @@ class ActionTypeSchema(BaseModel):
"""Retrieves content from a user workspace, microagent, or other source."""
ActionType = ActionTypeSchema()
+1 -1
View File
@@ -36,5 +36,5 @@ __all__ = [
'MessageAction',
'ActionConfirmationStatus',
'AgentThinkAction',
'RecallAction',
]
-3
View File
@@ -60,6 +60,3 @@ class IPythonRunCellAction(Action):
@property
def message(self) -> str:
return f'Running Python code interactively: {self.code}'
+16 -12
View File
@@ -1,8 +1,9 @@
import asyncio
from functools import partial
from typing import Any
from typing import Any, Callable, Coroutine
from litellm import acompletion as litellm_acompletion
from litellm.types.utils import ModelResponse
from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
@@ -17,7 +18,9 @@ from openhands.utils.shutdown_listener import should_continue
class AsyncLLM(LLM):
"""Asynchronous LLM class."""
def __init__(self, *args, **kwargs):
_async_completion: Callable[..., Coroutine[Any, Any, ModelResponse]]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._async_completion = partial(
@@ -37,7 +40,9 @@ class AsyncLLM(LLM):
seed=self.config.seed,
)
async_completion_unwrapped = self._async_completion
async_completion_unwrapped: Callable[
..., Coroutine[Any, Any, ModelResponse]
] = self._async_completion
@self.retry_decorator(
num_retries=self.config.num_retries,
@@ -46,7 +51,7 @@ class AsyncLLM(LLM):
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
async def async_completion_wrapper(*args, **kwargs):
async def async_completion_wrapper(*args: Any, **kwargs: Any) -> dict[str, Any]:
"""Wrapper for the litellm acompletion function that adds logging and cost tracking."""
messages: list[dict[str, Any]] | dict[str, Any] = []
@@ -77,7 +82,7 @@ class AsyncLLM(LLM):
self.log_prompt(messages)
async def check_stopped():
async def check_stopped() -> None:
while should_continue():
if (
hasattr(self.config, 'on_cancel_requested_fn')
@@ -97,10 +102,8 @@ class AsyncLLM(LLM):
self.log_response(message_back)
# log costs and tokens used
self._post_completion(resp)
# We do not support streaming in this method, thus return resp
return resp
return dict(resp)
except UserCancelledError:
logger.debug('LLM request cancelled by user.')
@@ -117,14 +120,15 @@ class AsyncLLM(LLM):
except asyncio.CancelledError:
pass
self._async_completion = async_completion_wrapper # type: ignore
self._async_completion = async_completion_wrapper
async def _call_acompletion(self, *args, **kwargs):
async def _call_acompletion(self, *args: Any, **kwargs: Any) -> ModelResponse:
"""Wrapper for the litellm acompletion function."""
# Used in testing?
return await litellm_acompletion(*args, **kwargs)
resp = await litellm_acompletion(*args, **kwargs)
return ModelResponse(**resp)
@property
def async_completion(self):
def async_completion(self) -> Callable[..., Coroutine[Any, Any, ModelResponse]]:
"""Decorator for the async litellm acompletion function."""
return self._async_completion
+1 -1
View File
@@ -28,5 +28,5 @@ def list_foundation_models(
return []
def remove_error_modelId(model_list):
def remove_error_modelId(model_list: list[str]) -> list[str]:
return list(filter(lambda m: not m.startswith('bedrock'), model_list))
+7 -7
View File
@@ -7,7 +7,7 @@ MESSAGE_SEPARATOR = '\n\n----------\n\n'
class DebugMixin:
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]) -> None:
if not messages:
logger.debug('No completion messages!')
return
@@ -24,11 +24,11 @@ class DebugMixin:
else:
logger.debug('No completion messages!')
def log_response(self, message_back: str):
def log_response(self, message_back: str) -> None:
if message_back:
llm_response_logger.debug(message_back)
def _format_message_content(self, message: dict[str, Any]):
def _format_message_content(self, message: dict[str, Any]) -> str:
content = message['content']
if isinstance(content, list):
return '\n'.join(
@@ -36,18 +36,18 @@ class DebugMixin:
)
return str(content)
def _format_content_element(self, element: dict[str, Any]):
def _format_content_element(self, element: dict[str, Any]) -> str:
if isinstance(element, dict):
if 'text' in element:
return element['text']
return str(element['text'])
if (
self.vision_is_active()
and 'image_url' in element
and 'url' in element['image_url']
):
return element['image_url']['url']
return str(element['image_url']['url'])
return str(element)
# This method should be implemented in the class that uses DebugMixin
def vision_is_active(self):
def vision_is_active(self) -> bool:
raise NotImplementedError
+5 -5
View File
@@ -9,7 +9,7 @@ We follow format from: https://docs.litellm.ai/docs/completion/function_call
import copy
import json
import re
from typing import Iterable
from typing import Any, Iterable
from litellm import ChatCompletionToolParam
@@ -265,7 +265,7 @@ def convert_tool_call_to_string(tool_call: dict) -> str:
return ret
def convert_tools_to_description(tools: list[dict]) -> str:
def convert_tools_to_description(tools: list[ChatCompletionToolParam]) -> str:
ret = ''
for i, tool in enumerate(tools):
assert tool['type'] == 'function'
@@ -474,8 +474,8 @@ def convert_fncall_messages_to_non_fncall_messages(
def _extract_and_validate_params(
matching_tool: dict, param_matches: Iterable[re.Match], fn_name: str
) -> dict:
matching_tool: dict[str, Any], param_matches: Iterable[re.Match], fn_name: str
) -> dict[str, Any]:
params = {}
# Parse and validate parameters
required_params = set()
@@ -712,7 +712,7 @@ def convert_non_fncall_messages_to_fncall_messages(
# Parse parameters
param_matches = re.finditer(FN_PARAM_REGEX_PATTERN, fn_body, re.DOTALL)
params = _extract_and_validate_params(
matching_tool, param_matches, fn_name
dict(matching_tool), param_matches, fn_name
)
# Create tool call with unique ID
+61 -48
View File
@@ -2,7 +2,6 @@ import copy
import os
import time
import warnings
from functools import partial
from typing import Any, Callable
import requests
@@ -13,14 +12,15 @@ with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm
from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
from litellm import Message as LiteLLMMessage
from litellm import ChatCompletionMessageToolCall, PromptTokensDetails
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
RateLimitError,
)
from litellm.types.router import ModelInfo as RouterModelInfo
from litellm.types.utils import CostPerToken, ModelResponse, Usage
from litellm.types.utils import ModelInfo as UtilsModelInfo
from litellm.utils import create_pretrained_tokenizer
from openhands.core.logger import openhands_logger as logger
@@ -87,6 +87,8 @@ class LLM(RetryMixin, DebugMixin):
config: an LLMConfig object specifying the configuration of the LLM.
"""
_completion: Callable[..., ModelResponse]
def __init__(
self,
config: LLMConfig,
@@ -108,7 +110,7 @@ class LLM(RetryMixin, DebugMixin):
self.cost_metric_supported: bool = True
self.config: LLMConfig = copy.deepcopy(config)
self.model_info: ModelInfo | None = None
self.model_info: RouterModelInfo | UtilsModelInfo | None = None
self.retry_listener = retry_listener
if self.config.log_completions:
if self.config.log_completions_folder is None:
@@ -153,23 +155,26 @@ class LLM(RetryMixin, DebugMixin):
kwargs['max_tokens'] = self.config.max_output_tokens
kwargs.pop('max_completion_tokens')
self._completion = partial(
litellm_completion,
model=self.config.model,
api_key=self.config.api_key.get_secret_value()
if self.config.api_key
else None,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
timeout=self.config.timeout,
top_p=self.config.top_p,
drop_params=self.config.drop_params,
seed=self.config.seed,
**kwargs,
)
# Create a wrapper function that captures the config values
def completion_with_config(*args: Any, **user_kwargs: Any) -> ModelResponse:
"""Wrapper for litellm_completion that includes the config values."""
merged_kwargs = {
'model': self.config.model,
'api_key': self.config.api_key.get_secret_value()
if self.config.api_key
else None,
'base_url': self.config.base_url,
'api_version': self.config.api_version,
'custom_llm_provider': self.config.custom_llm_provider,
'timeout': self.config.timeout,
'top_p': self.config.top_p,
'drop_params': self.config.drop_params,
**kwargs,
**user_kwargs,
}
return litellm_completion(*args, **merged_kwargs)
self._completion_unwrapped = self._completion
self._completion_unwrapped = completion_with_config
@self.retry_decorator(
num_retries=self.config.num_retries,
@@ -179,7 +184,7 @@ class LLM(RetryMixin, DebugMixin):
retry_multiplier=self.config.retry_multiplier,
retry_listener=self.retry_listener,
)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> ModelResponse:
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
from openhands.io import json
@@ -257,20 +262,18 @@ class LLM(RetryMixin, DebugMixin):
# if we mocked function calling, and we have tools, convert the response back to function calling format
if mock_function_calling and mock_fncall_tools is not None:
logger.debug(f'Response choices: {len(resp.choices)}')
assert len(resp.choices) >= 1
non_fncall_response_message = resp.choices[0].message
fn_call_messages_with_response = (
convert_non_fncall_messages_to_fncall_messages(
messages + [non_fncall_response_message], mock_fncall_tools
assert len(resp.choices) == 1
if isinstance(resp.choices[0], dict) and 'message' in resp.choices[0]:
non_fncall_response_message = resp.choices[0]['message']
fn_call_messages_with_response = (
convert_non_fncall_messages_to_fncall_messages(
messages + [dict(non_fncall_response_message)],
mock_fncall_tools,
)
)
)
fn_call_response_message = fn_call_messages_with_response[-1]
if not isinstance(fn_call_response_message, LiteLLMMessage):
fn_call_response_message = LiteLLMMessage(
**fn_call_response_message
)
resp.choices[0].message = fn_call_response_message
fn_call_response_message = fn_call_messages_with_response[-1]
fn_call_response_message = dict(fn_call_response_message)
resp.choices[0]['message'] = fn_call_response_message
message_back: str = resp['choices'][0]['message']['content'] or ''
tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
@@ -327,14 +330,14 @@ class LLM(RetryMixin, DebugMixin):
self._completion = wrapper
@property
def completion(self):
def completion(self) -> Callable[..., ModelResponse]:
"""Decorator for the litellm completion function.
Check the complete documentation at https://litellm.vercel.app/docs/completion
"""
return self._completion
def init_model_info(self):
def init_model_info(self) -> None:
if self._tried_model_info:
return
self._tried_model_info = True
@@ -464,11 +467,11 @@ class LLM(RetryMixin, DebugMixin):
# remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
# Check both the full model name and the name after proxy prefix for vision support
return (
litellm.supports_vision(self.config.model)
or litellm.supports_vision(self.config.model.split('/')[-1])
bool(litellm.supports_vision(self.config.model))
or bool(litellm.supports_vision(self.config.model.split('/')[-1]))
or (
self.model_info is not None
and self.model_info.get('supports_vision', False)
and bool(self.model_info.get('supports_vision', False))
)
)
@@ -624,7 +627,7 @@ class LLM(RetryMixin, DebugMixin):
return True
return False
def _completion_cost(self, response) -> float:
def _completion_cost(self, response: ModelResponse) -> float:
"""Calculate completion cost and update metrics with running total.
Calculate the cost of a completion response based on the model. Local models are treated as free.
@@ -663,35 +666,45 @@ class LLM(RetryMixin, DebugMixin):
try:
if cost is None:
try:
cost = litellm_completion_cost(
completion_response=response, **extra_kwargs
cost = float(
litellm_completion_cost(
completion_response=response,
custom_cost_per_token=extra_kwargs.get(
'custom_cost_per_token'
),
)
)
except Exception as e:
logger.error(f'Error getting cost from litellm: {e}')
if cost is None:
_model_name = '/'.join(self.config.model.split('/')[1:])
cost = litellm_completion_cost(
completion_response=response, model=_model_name, **extra_kwargs
cost = float(
litellm_completion_cost(
completion_response=response,
model=_model_name,
custom_cost_per_token=extra_kwargs.get('custom_cost_per_token'),
)
)
logger.debug(
f'Using fallback model name {_model_name} to get cost: {cost}'
)
self.metrics.add_cost(cost)
return cost
cost_float = float(cost)
self.metrics.add_cost(cost_float)
return cost_float
except Exception:
self.cost_metric_supported = False
logger.debug('Cost calculation not supported for this model.')
return 0.0
def __str__(self):
def __str__(self) -> str:
if self.config.api_version:
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
elif self.config.base_url:
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
return f'LLM(model={self.config.model})'
def __repr__(self):
def __repr__(self) -> str:
return str(self)
def reset(self) -> None:
+3 -3
View File
@@ -129,13 +129,13 @@ class Metrics:
'token_usages': [usage.model_dump() for usage in self._token_usages],
}
def reset(self):
def reset(self) -> None:
self._accumulated_cost = 0.0
self._costs = []
self._response_latencies = []
self._token_usages = []
def log(self):
def log(self) -> str:
"""Log the metrics."""
metrics = self.get()
logs = ''
@@ -143,5 +143,5 @@ class Metrics:
logs += f'{key}: {value}\n'
return logs
def __repr__(self):
def __repr__(self) -> str:
return f'Metrics({self.get()}'
+23 -11
View File
@@ -1,3 +1,5 @@
from typing import Any, Callable
from tenacity import (
retry,
retry_if_exception_type,
@@ -12,25 +14,35 @@ from openhands.utils.tenacity_stop import stop_if_should_exit
class RetryMixin:
"""Mixin class for retry logic."""
def retry_decorator(self, **kwargs):
def retry_decorator(
self,
*,
num_retries: int | None = None,
retry_exceptions: tuple = (),
retry_min_wait: int | None = None,
retry_max_wait: int | None = None,
retry_multiplier: float | None = None,
retry_listener: Any | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Create a LLM retry decorator with customizable parameters. This is used for 429 errors, and a few other exceptions in LLM classes.
Args:
**kwargs: Keyword arguments to override default retry behavior.
Keys: num_retries, retry_exceptions, retry_min_wait, retry_max_wait, retry_multiplier
num_retries: Number of retries before giving up
retry_exceptions: Tuple of exception types to retry on
retry_min_wait: Minimum wait time between retries in seconds
retry_max_wait: Maximum wait time between retries in seconds
retry_multiplier: Multiplier for exponential backoff
retry_listener: Optional callback for retry events
Returns:
A retry decorator with the parameters customizable in configuration.
"""
num_retries = kwargs.get('num_retries')
retry_exceptions: tuple = kwargs.get('retry_exceptions', ())
retry_min_wait = kwargs.get('retry_min_wait')
retry_max_wait = kwargs.get('retry_max_wait')
retry_multiplier = kwargs.get('retry_multiplier')
retry_listener = kwargs.get('retry_listener')
# Use the values from config if not provided
# Note: These values are already set in LLMConfig with appropriate defaults
# See openhands/core/config/llm_config.py for the actual default values
def before_sleep(retry_state):
def before_sleep(retry_state: Any) -> None:
self.log_retry_attempt(retry_state)
if retry_listener:
retry_listener(retry_state.attempt_number, num_retries)
@@ -49,7 +61,7 @@ class RetryMixin:
),
)
def log_retry_attempt(self, retry_state):
def log_retry_attempt(self, retry_state: Any) -> None:
"""Log retry attempts."""
exception = retry_state.outcome.exception()
logger.error(
+18 -7
View File
@@ -1,6 +1,8 @@
import asyncio
from functools import partial
from typing import Any
from typing import Any, AsyncIterator, Callable, Coroutine
from litellm.types.utils import ModelResponse
from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
@@ -11,7 +13,11 @@ from openhands.llm.llm import REASONING_EFFORT_SUPPORTED_MODELS
class StreamingLLM(AsyncLLM):
"""Streaming LLM class."""
def __init__(self, *args, **kwargs):
_async_streaming_completion: Callable[
..., Coroutine[Any, Any, AsyncIterator[ModelResponse]]
]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._async_streaming_completion = partial(
@@ -31,7 +37,9 @@ class StreamingLLM(AsyncLLM):
stream=True, # Ensure streaming is enabled
)
async_streaming_completion_unwrapped = self._async_streaming_completion
async_streaming_completion_unwrapped: Callable[
..., Coroutine[Any, Any, AsyncIterator[ModelResponse]]
] = self._async_streaming_completion
@self.retry_decorator(
num_retries=self.config.num_retries,
@@ -40,7 +48,7 @@ class StreamingLLM(AsyncLLM):
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
async def async_streaming_completion_wrapper(*args, **kwargs):
async def async_streaming_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
messages: list[dict[str, Any]] | dict[str, Any] = []
# some callers might send the model and messages directly
@@ -89,9 +97,10 @@ class StreamingLLM(AsyncLLM):
message_back = chunk['choices'][0]['delta'].get('content', '')
if message_back:
self.log_response(message_back)
self._post_completion(chunk)
chunk_dict = dict(chunk)
self._post_completion(ModelResponse(**chunk_dict))
yield chunk
yield chunk_dict
except UserCancelledError:
logger.debug('LLM request cancelled by user.')
@@ -108,6 +117,8 @@ class StreamingLLM(AsyncLLM):
self._async_streaming_completion = async_streaming_completion_wrapper
@property
def async_streaming_completion(self):
def async_streaming_completion(
self,
) -> Callable[..., Coroutine[Any, Any, AsyncIterator[ModelResponse]]]:
"""Decorator for the async litellm acompletion function with streaming."""
return self._async_streaming_completion
+77 -281
View File
@@ -7,10 +7,9 @@ from enum import Enum
import bashlex
import libtmux
import psutil
from openhands.core.logger import openhands_logger as logger
from openhands.events.action import Action, CmdRunAction
from openhands.events.action import CmdRunAction
from openhands.events.observation import ErrorObservation
from openhands.events.observation.commands import (
CMD_OUTPUT_PS1_END,
@@ -257,41 +256,10 @@ class BashSession:
)
return content
def kill_process(self, pid: int) -> bool:
"""Kill a process by its PID.
Args:
pid (int): The PID of the process to kill.
Returns:
bool: True if the process was killed successfully, False otherwise.
"""
try:
process = psutil.Process(pid)
process.kill()
return True
except (psutil.NoSuchProcess, psutil.AccessDenied):
return False
def kill_all_processes(self) -> bool:
"""Kill all processes associated with the current command.
Returns:
bool: True if any processes were killed successfully, False otherwise.
"""
process_info = self.get_running_processes()
success = False
for pid in process_info['process_pids']:
if pid != int(self.pane.cmd('display-message', '-p', '#{pane_pid}').stdout[0].strip()):
if self.kill_process(pid):
success = True
return success
def close(self):
"""Clean up the session."""
if self._closed:
return
self.kill_all_processes() # Kill any remaining processes
self.session.kill_session()
self._closed = True
@@ -461,119 +429,6 @@ class BashSession:
# Clear the current content
self._clear_screen()
def get_running_processes(self):
"""Get a list of processes that are currently running in the bash session.
Returns:
dict: A dictionary containing:
- 'is_command_running': Boolean indicating if the last command is still running
- 'current_command_pid': PID of the currently running command (if any)
- 'processes': List of all processes visible to this bash session
- 'command_processes': List of processes that are likely part of the current command
- 'process_pids': List of PIDs of all processes
- 'command_pids': List of PIDs of processes that are likely part of the current command
"""
# Check if a command is running in this session
is_command_running = False
# Get the shell's PID directly from tmux
shell_pid_str = (
self.pane.cmd('display-message', '-p', '#{pane_pid}').stdout[0].strip()
)
shell_pid = int(shell_pid_str)
try:
# Get process information for the shell
shell_process = psutil.Process(shell_pid)
process_list = []
command_processes = []
current_command_pid = None
# Get all child processes recursively
children = shell_process.children(recursive=True)
# Add the shell process first
process_str = f"{shell_pid} {shell_process.ppid()} {shell_process.status()[0]} {' '.join(shell_process.cmdline())}"
process_list.append(process_str)
for child in children:
try:
# Skip if no cmdline (might be a kernel process)
cmdline = child.cmdline()
if not cmdline:
continue
# Format the process info
status_flag = child.status()[0]
# Build process string (PID PPID STATUS COMMAND)
cmd_str = ' '.join(cmdline)
process_str = f'{child.pid} {child.ppid()} {status_flag} {cmd_str}'
process_list.append(process_str)
# Identify processes that are likely part of current command
child_ppid = child.ppid()
# Direct child of shell = likely current command
if child_ppid == shell_pid:
if not current_command_pid:
current_command_pid = child.pid
is_command_running = True
command_processes.append(process_str)
# Child of identified command process = part of current command
elif current_command_pid and (
child_ppid == current_command_pid
or any(
p.pid == child_ppid
for p in children
if p.pid == current_command_pid
or p.ppid() == current_command_pid
)
):
command_processes.append(process_str)
except (psutil.NoSuchProcess, psutil.AccessDenied):
# Process may have terminated while we were examining it
continue
# If we have no command processes, we're not running anything
if not command_processes:
is_command_running = False
current_command_pid = None
except (psutil.NoSuchProcess, psutil.AccessDenied) as e:
logger.warning(f'Error accessing process information: {e}')
return {
'is_command_running': is_command_running,
'current_command_pid': None,
'processes': [],
'command_processes': [],
}
# Extract PIDs from process strings
process_pids = []
command_pids = []
for proc in process_list:
try:
pid = int(proc.split()[0])
process_pids.append(pid)
except (ValueError, IndexError):
continue
for proc in command_processes:
try:
pid = int(proc.split()[0])
command_pids.append(pid)
except (ValueError, IndexError):
continue
return {
'is_command_running': is_command_running,
'current_command_pid': current_command_pid,
'processes': process_list,
'command_processes': command_processes,
'process_pids': process_pids,
'command_pids': command_pids,
}
def _combine_outputs_between_matches(
self,
pane_content: str,
@@ -609,95 +464,34 @@ class BashSession:
logger.debug(f'COMBINED OUTPUT: {combined_output}')
return combined_output
def execute(self, action: Action) -> CmdOutputObservation | ErrorObservation:
def execute(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservation:
"""Execute a command in the bash session."""
if not self._initialized:
raise RuntimeError('Bash session is not initialized')
# Strip the command of any leading/trailing whitespace
logger.debug(f'RECEIVED ACTION: {action}')
# Handle CmdRunAction
if not isinstance(action, CmdRunAction):
return ErrorObservation(f"Unsupported action type: {type(action)}")
command = action.command.strip()
is_input = action.is_input
is_input: bool = action.is_input
# Handle different command types
if command == '':
return self._handle_empty_command(action)
elif is_input:
return self._handle_input_command(action)
else:
return self._handle_normal_command(action)
def _handle_empty_command(self, action: CmdRunAction) -> CmdOutputObservation:
"""Handle an empty command (usually to retrieve more output from a running command)."""
assert action.command.strip() == ''
# If the previous command is not in a continuing state, return an error
# If the previous command is not completed, we need to check if the command is empty
if self.prev_status not in {
BashCommandStatus.CONTINUE,
BashCommandStatus.NO_CHANGE_TIMEOUT,
BashCommandStatus.HARD_TIMEOUT,
}:
return CmdOutputObservation(
content='ERROR: No previous running command to retrieve logs from.',
command='',
metadata=CmdOutputMetadata(),
)
# Start polling for command completion
return self._poll_for_command_completion('', action)
def _handle_input_command(self, action: CmdRunAction) -> CmdOutputObservation:
"""Handle an input command (sent to a running process)."""
command = action.command.strip()
# If the previous command is not in a continuing state, return an error
if self.prev_status not in {
BashCommandStatus.CONTINUE,
BashCommandStatus.NO_CHANGE_TIMEOUT,
BashCommandStatus.HARD_TIMEOUT,
}:
return CmdOutputObservation(
content='ERROR: No previous running command to interact with.',
command='',
metadata=CmdOutputMetadata(),
)
# Check if it's a special key
is_special_key = self._is_special_key(command)
# Send the input to the pane
logger.debug(f'SENDING INPUT TO RUNNING PROCESS: {command!r}')
self.pane.send_keys(
command,
enter=not is_special_key,
)
# Start polling for command completion
return self._poll_for_command_completion(command, action)
def _handle_normal_command(
self, action: CmdRunAction
) -> CmdOutputObservation | ErrorObservation:
"""Handle a normal command."""
command = action.command.strip()
# Check if command is running previous command first
last_pane_output = self._get_pane_content()
if (
self.prev_status
in {
BashCommandStatus.HARD_TIMEOUT,
BashCommandStatus.NO_CHANGE_TIMEOUT,
}
and not last_pane_output.endswith(
CMD_OUTPUT_PS1_END
) # prev command is not completed
):
return self._handle_interrupted_command(command, last_pane_output)
if command == '':
return CmdOutputObservation(
content='ERROR: No previous running command to retrieve logs from.',
command='',
metadata=CmdOutputMetadata(),
)
if is_input:
return CmdOutputObservation(
content='ERROR: No previous running command to interact with.',
command='',
metadata=CmdOutputMetadata(),
)
# Check if the command is a single command or multiple commands
splited_commands = split_bash_commands(command)
@@ -710,56 +504,67 @@ class BashSession:
)
)
# Convert command to raw string and send it
is_special_key = self._is_special_key(command)
command = escape_bash_special_chars(command)
logger.debug(f'SENDING COMMAND: {command!r}')
self.pane.send_keys(
command,
enter=not is_special_key,
)
# Start polling for command completion
return self._poll_for_command_completion(command, action)
def _handle_interrupted_command(
self, command: str, last_pane_output: str
) -> CmdOutputObservation:
"""Handle the case where a new command is sent while a previous command is still running."""
_ps1_matches = CmdOutputMetadata.matches_ps1_metadata(last_pane_output)
raw_command_output = self._combine_outputs_between_matches(
last_pane_output, _ps1_matches
)
metadata = CmdOutputMetadata() # No metadata available
metadata.suffix = (
f'\n[Your command "{command}" is NOT executed. '
f'The previous command is still running - You CANNOT send new commands until the previous command is completed. '
'By setting `is_input` to `true`, you can interact with the current process: '
"You may wait longer to see additional output of the previous command by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys ("C-c", "C-z", "C-d") to interrupt/kill the previous command before sending your new command.]'
)
logger.debug(f'PREVIOUS COMMAND OUTPUT: {raw_command_output}')
command_output = self._get_command_output(
command,
raw_command_output,
metadata,
continue_prefix='[Below is the output of the previous command.]\n',
)
return CmdOutputObservation(
command=command,
content=command_output,
metadata=metadata,
)
def _poll_for_command_completion(
self, command: str, action: CmdRunAction
) -> CmdOutputObservation:
"""Poll for command completion and handle timeouts."""
start_time = time.time()
last_change_time = start_time
last_pane_output = self._get_pane_content()
# When prev command is still running, and we are trying to send a new command
if (
self.prev_status
in {
BashCommandStatus.HARD_TIMEOUT,
BashCommandStatus.NO_CHANGE_TIMEOUT,
}
and not last_pane_output.endswith(
CMD_OUTPUT_PS1_END
) # prev command is not completed
and not is_input
and command != '' # not input and not empty command
):
_ps1_matches = CmdOutputMetadata.matches_ps1_metadata(last_pane_output)
raw_command_output = self._combine_outputs_between_matches(
last_pane_output, _ps1_matches
)
metadata = CmdOutputMetadata() # No metadata available
metadata.suffix = (
f'\n[Your command "{command}" is NOT executed. '
f'The previous command is still running - You CANNOT send new commands until the previous command is completed. '
'By setting `is_input` to `true`, you can interact with the current process: '
"You may wait longer to see additional output of the previous command by sending empty command '', "
'send other commands to interact with the current process, '
'or send keys ("C-c", "C-z", "C-d") to interrupt/kill the previous command before sending your new command.]'
)
logger.debug(f'PREVIOUS COMMAND OUTPUT: {raw_command_output}')
command_output = self._get_command_output(
command,
raw_command_output,
metadata,
continue_prefix='[Below is the output of the previous command.]\n',
)
return CmdOutputObservation(
command=command,
content=command_output,
metadata=metadata,
)
# Send actual command/inputs to the pane
if command != '':
is_special_key = self._is_special_key(command)
if is_input:
logger.debug(f'SENDING INPUT TO RUNNING PROCESS: {command!r}')
self.pane.send_keys(
command,
enter=not is_special_key,
)
else:
# convert command to raw string
command = escape_bash_special_chars(command)
logger.debug(f'SENDING COMMAND: {command!r}')
self.pane.send_keys(
command,
enter=not is_special_key,
)
# Loop until the command completes or times out
while should_continue():
_start_time = time.time()
@@ -770,18 +575,6 @@ class BashSession:
)
logger.debug(f'BEGIN OF PANE CONTENT: {cur_pane_output.split("\n")[:10]}')
logger.debug(f'END OF PANE CONTENT: {cur_pane_output.split("\n")[-10:]}')
# Log running processes for debugging
try:
process_info = self.get_running_processes()
logger.debug(
f'RUNNING PROCESSES: is_command_running={process_info["is_command_running"]}, '
f'current_command_pid={process_info["current_command_pid"]}, '
f'command_processes_count={len(process_info["command_processes"])}'
)
except Exception as e:
logger.warning(f'Failed to get running processes: {e}')
ps1_matches = CmdOutputMetadata.matches_ps1_metadata(cur_pane_output)
if cur_pane_output != last_pane_output:
last_pane_output = cur_pane_output
@@ -789,6 +582,7 @@ class BashSession:
logger.debug(f'CONTENT UPDATED DETECTED at {last_change_time}')
# 1) Execution completed
# if the last command output contains the end marker
if cur_pane_output.rstrip().endswith(CMD_OUTPUT_PS1_END.rstrip()):
return self._handle_completed_command(
command,
@@ -797,6 +591,8 @@ class BashSession:
)
# 2) Execution timed out since there's no change in output
# for a while (self.NO_CHANGE_TIMEOUT_SECONDS)
# We ignore this if the command is *blocking
time_since_last_change = time.time() - last_change_time
logger.debug(
f'CHECKING NO CHANGE TIMEOUT ({self.NO_CHANGE_TIMEOUT_SECONDS}s): elapsed {time_since_last_change}. Action blocking: {action.blocking}'
-37
View File
@@ -386,40 +386,3 @@ def test_python_interactive_input():
assert session.prev_status == BashCommandStatus.COMPLETED
session.close()
def test_get_running_processes():
"""Test the get_running_processes method to detect running processes."""
session = BashSession(work_dir=os.getcwd(), no_change_timeout_seconds=2)
session.initialize()
# First check with no running command
process_info = session.get_running_processes()
assert isinstance(process_info, dict)
assert 'is_command_running' in process_info
assert process_info['is_command_running'] is False
assert 'processes' in process_info
assert len(process_info['processes']) == 1 # should have the shell process
assert 'command_processes' in process_info
assert len(process_info['command_processes']) == 0
assert 'current_command_pid' in process_info
assert process_info['current_command_pid'] is None
session.execute(CmdRunAction('sleep 120', blocking=False))
# Check running processes
process_info = session.get_running_processes()
assert process_info['is_command_running'] is True
assert process_info['current_command_pid'] is not None
assert len(process_info['command_processes']) > 0
# Send Ctrl+C to terminate the process
session.execute(CmdRunAction('C-c', is_input=True))
# Verify process is no longer running
process_info = session.get_running_processes()
assert process_info['is_command_running'] is False
assert process_info['current_command_pid'] is None
assert len(process_info['command_processes']) == 0
session.close()
-39
View File
@@ -1,39 +0,0 @@
import time
from openhands.events.action import CmdRunAction
from openhands.runtime.utils.bash import BashSession
def test_stop_button_background_process():
session = BashSession(work_dir='/tmp', no_change_timeout_seconds=2)
session.initialize()
# Start a process that runs indefinitely and detaches from the terminal
session.execute(
CmdRunAction(
'nohup sleep 60 > /dev/null 2>&1 &'
) # Background process that detaches from terminal
)
time.sleep(2) # Give time for the process to start
# Get initial process info
process_info = session.get_running_processes()
print('Initial process info:', process_info) # Debug output
assert any(
'sleep' in p for p in process_info['processes']
), 'Expected to find sleep process'
initial_processes = [p for p in process_info['processes'] if 'sleep' in p]
assert len(initial_processes) > 0, 'Expected at least one sleep process'
# Send kill command to stop it
session.execute(CmdRunAction('pkill -P $$'))
time.sleep(1) # Give time for processes to be killed
# Check if process is still running (it should be terminated)
process_info = session.get_running_processes()
print('Process info after kill command:', process_info) # Debug output
assert not any(
'sleep' in p for p in process_info['processes']
), 'Background process should be terminated'
session.close()