Compare commits

...

26 Commits

Author SHA1 Message Date
openhands
ac07afbb3c fix: make LLM response handling more robust for deepseek and other providers
- Replace type-based checks with dictionary-based checks
- Update message access to use dictionary syntax
- Remove unused imports
2025-03-25 06:28:34 +00:00
openhands
b9f330db1b Merge main into fix/llm-mypy-errors (keeping our poetry.lock) 2025-03-24 23:35:44 +00:00
openhands
096b259694 Replace partial() with proper type hints
- Remove partial() since we're not using its argument binding functionality
- Add proper type hints for class attributes and variables
- Fix mypy errors with proper type annotations
2025-03-03 20:00:53 +00:00
openhands
7f3202d0e0 Replace partial() with direct function wrappers
- Replace partial() with direct function wrappers to make the code clearer
- Keep type hints using cast() to satisfy mypy
- Add proper type hints for all functions
2025-03-03 19:56:30 +00:00
openhands
e63dfca517 Replace partial() with cast() for type hints
- Replace partial() with cast() since we're not using partial's argument binding
- Add proper type hints for the cast() calls
- Add missing imports for type hints
2025-03-03 19:50:30 +00:00
openhands
cb705b736f Remove duplicate default values in retry_mixin.py
- Remove hardcoded default values for retry parameters in retry_mixin.py
  since they are already defined in LLMConfig
- Add comment explaining where the default values come from
- Keep partial() usage to maintain type compatibility
2025-03-03 18:50:33 +00:00
Graham Neubig
1e5c4da0fc Merge branch 'main' into fix/llm-mypy-errors 2025-03-03 13:46:18 -05:00
Graham Neubig
526618753d Update .github/workflows/py-unit-tests.yml 2025-02-22 23:40:37 -05:00
openhands
e25e6766fb Fix ruff and ruff-format issues 2025-02-22 23:27:10 +00:00
Graham Neubig
1b9a2b43c3 Merge branch 'main' into fix/llm-mypy-errors 2025-02-22 17:27:14 -05:00
openhands
7886c1f920 Fix litellm type imports to use OpenAI types directly 2025-02-21 22:21:40 +00:00
openhands
4b49ffb01d Fix litellm type imports to use OpenAI types directly 2025-02-21 17:30:43 +00:00
openhands
f74ce56a35 Fix litellm import path for Choices and StreamingChoices 2025-02-21 17:29:12 +00:00
openhands
9ccf680c38 Remove parallel test execution to fix failing tests 2025-02-21 17:25:59 +00:00
Graham Neubig
41068f6ea1 Merge branch 'main' into fix/llm-mypy-errors 2025-02-21 08:22:40 -05:00
Engel Nyst
5b1c8bc2e8 Merge branch 'main' into fix/llm-mypy-errors 2025-02-19 19:28:48 +01:00
Xingyao Wang
5b8db983b7 Merge branch 'main' into fix/llm-mypy-errors 2025-02-19 12:19:39 -05:00
Graham Neubig
18543a2efa Merge branch 'main' into fix/llm-mypy-errors 2025-02-19 05:22:06 -05:00
openhands
63d9c3d668 Revert changes outside openhands/llm directory 2025-02-19 02:43:02 +00:00
openhands
2589b13815 Fix mypy errors in openhands/llm directory 2025-02-19 02:42:17 +00:00
Graham Neubig
592aca05e1 Merge branch 'main' into feature/strict-mypy-checks 2025-02-18 20:14:17 -05:00
Graham Neubig
d309455733 Merge branch 'main' into feature/strict-mypy-checks 2025-02-11 10:19:36 -05:00
Graham Neubig
66a7920539 Merge branch 'main' into feature/strict-mypy-checks 2025-02-10 13:06:49 -05:00
Graham Neubig
64ebef3646 Update .github/workflows/lint.yml 2025-01-21 14:52:56 -05:00
Graham Neubig
7a259915c1 Update .github/workflows/lint.yml 2025-01-21 14:47:51 -05:00
openhands
66bd8fdbcd Enable strict type checking with mypy
- Update mypy configuration with stricter type checking rules
- Add more type stubs to pre-commit configuration
- Run mypy both through pre-commit and directly in CI
- Install project in editable mode for better type checking
- Set correct PYTHONPATH in CI environment
2025-01-21 19:12:09 +00:00
8 changed files with 134 additions and 94 deletions

View File

@@ -1,8 +1,9 @@
import asyncio
from functools import partial
from typing import Any
from typing import Any, Callable, Coroutine
from litellm import acompletion as litellm_acompletion
from litellm.types.utils import ModelResponse
from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
@@ -17,7 +18,9 @@ from openhands.utils.shutdown_listener import should_continue
class AsyncLLM(LLM):
"""Asynchronous LLM class."""
def __init__(self, *args, **kwargs):
_async_completion: Callable[..., Coroutine[Any, Any, ModelResponse]]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._async_completion = partial(
@@ -37,7 +40,9 @@ class AsyncLLM(LLM):
seed=self.config.seed,
)
async_completion_unwrapped = self._async_completion
async_completion_unwrapped: Callable[
..., Coroutine[Any, Any, ModelResponse]
] = self._async_completion
@self.retry_decorator(
num_retries=self.config.num_retries,
@@ -46,7 +51,7 @@ class AsyncLLM(LLM):
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
async def async_completion_wrapper(*args, **kwargs):
async def async_completion_wrapper(*args: Any, **kwargs: Any) -> dict[str, Any]:
"""Wrapper for the litellm acompletion function that adds logging and cost tracking."""
messages: list[dict[str, Any]] | dict[str, Any] = []
@@ -77,7 +82,7 @@ class AsyncLLM(LLM):
self.log_prompt(messages)
async def check_stopped():
async def check_stopped() -> None:
while should_continue():
if (
hasattr(self.config, 'on_cancel_requested_fn')
@@ -97,10 +102,8 @@ class AsyncLLM(LLM):
self.log_response(message_back)
# log costs and tokens used
self._post_completion(resp)
# We do not support streaming in this method, thus return resp
return resp
return dict(resp)
except UserCancelledError:
logger.debug('LLM request cancelled by user.')
@@ -117,14 +120,15 @@ class AsyncLLM(LLM):
except asyncio.CancelledError:
pass
self._async_completion = async_completion_wrapper # type: ignore
self._async_completion = async_completion_wrapper
async def _call_acompletion(self, *args: Any, **kwargs: Any) -> ModelResponse:
    """Await ``litellm_acompletion`` with the given arguments and return its result.

    Args:
        *args: Positional arguments forwarded unchanged to ``litellm_acompletion``.
        **kwargs: Keyword arguments forwarded unchanged to ``litellm_acompletion``.

    Returns:
        The awaited response, passed through without modification.
    """
    # Used in testing?
    # Return the awaited response as-is. Re-building it via
    # ``ModelResponse(**resp)`` needs ``resp`` to be a mapping, which the
    # response object is not guaranteed to be, and the copy adds nothing.
    return await litellm_acompletion(*args, **kwargs)
@property
def async_completion(self):
def async_completion(self) -> Callable[..., Coroutine[Any, Any, ModelResponse]]:
"""Decorator for the async litellm acompletion function."""
return self._async_completion

View File

@@ -28,5 +28,5 @@ def list_foundation_models(
return []
def remove_error_modelId(model_list):
def remove_error_modelId(model_list: list[str]) -> list[str]:
return list(filter(lambda m: not m.startswith('bedrock'), model_list))

View File

@@ -7,7 +7,7 @@ MESSAGE_SEPARATOR = '\n\n----------\n\n'
class DebugMixin:
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]) -> None:
if not messages:
logger.debug('No completion messages!')
return
@@ -24,11 +24,11 @@ class DebugMixin:
else:
logger.debug('No completion messages!')
def log_response(self, message_back: str):
def log_response(self, message_back: str) -> None:
if message_back:
llm_response_logger.debug(message_back)
def _format_message_content(self, message: dict[str, Any]):
def _format_message_content(self, message: dict[str, Any]) -> str:
content = message['content']
if isinstance(content, list):
return '\n'.join(
@@ -36,18 +36,18 @@ class DebugMixin:
)
return str(content)
def _format_content_element(self, element: dict[str, Any]):
def _format_content_element(self, element: dict[str, Any]) -> str:
if isinstance(element, dict):
if 'text' in element:
return element['text']
return str(element['text'])
if (
self.vision_is_active()
and 'image_url' in element
and 'url' in element['image_url']
):
return element['image_url']['url']
return str(element['image_url']['url'])
return str(element)
# This method should be implemented in the class that uses DebugMixin
def vision_is_active(self):
def vision_is_active(self) -> bool:
raise NotImplementedError

View File

@@ -9,7 +9,7 @@ We follow format from: https://docs.litellm.ai/docs/completion/function_call
import copy
import json
import re
from typing import Iterable
from typing import Any, Iterable
from litellm import ChatCompletionToolParam
@@ -265,7 +265,7 @@ def convert_tool_call_to_string(tool_call: dict) -> str:
return ret
def convert_tools_to_description(tools: list[dict]) -> str:
def convert_tools_to_description(tools: list[ChatCompletionToolParam]) -> str:
ret = ''
for i, tool in enumerate(tools):
assert tool['type'] == 'function'
@@ -474,8 +474,8 @@ def convert_fncall_messages_to_non_fncall_messages(
def _extract_and_validate_params(
matching_tool: dict, param_matches: Iterable[re.Match], fn_name: str
) -> dict:
matching_tool: dict[str, Any], param_matches: Iterable[re.Match], fn_name: str
) -> dict[str, Any]:
params = {}
# Parse and validate parameters
required_params = set()
@@ -712,7 +712,7 @@ def convert_non_fncall_messages_to_fncall_messages(
# Parse parameters
param_matches = re.finditer(FN_PARAM_REGEX_PATTERN, fn_body, re.DOTALL)
params = _extract_and_validate_params(
matching_tool, param_matches, fn_name
dict(matching_tool), param_matches, fn_name
)
# Create tool call with unique ID

View File

@@ -2,7 +2,6 @@ import copy
import os
import time
import warnings
from functools import partial
from typing import Any, Callable
import requests
@@ -13,14 +12,15 @@ with warnings.catch_warnings():
warnings.simplefilter('ignore')
import litellm
from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
from litellm import Message as LiteLLMMessage
from litellm import ChatCompletionMessageToolCall, PromptTokensDetails
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
RateLimitError,
)
from litellm.types.router import ModelInfo as RouterModelInfo
from litellm.types.utils import CostPerToken, ModelResponse, Usage
from litellm.types.utils import ModelInfo as UtilsModelInfo
from litellm.utils import create_pretrained_tokenizer
from openhands.core.logger import openhands_logger as logger
@@ -87,6 +87,8 @@ class LLM(RetryMixin, DebugMixin):
config: an LLMConfig object specifying the configuration of the LLM.
"""
_completion: Callable[..., ModelResponse]
def __init__(
self,
config: LLMConfig,
@@ -108,7 +110,7 @@ class LLM(RetryMixin, DebugMixin):
self.cost_metric_supported: bool = True
self.config: LLMConfig = copy.deepcopy(config)
self.model_info: ModelInfo | None = None
self.model_info: RouterModelInfo | UtilsModelInfo | None = None
self.retry_listener = retry_listener
if self.config.log_completions:
if self.config.log_completions_folder is None:
@@ -153,23 +155,26 @@ class LLM(RetryMixin, DebugMixin):
kwargs['max_tokens'] = self.config.max_output_tokens
kwargs.pop('max_completion_tokens')
self._completion = partial(
litellm_completion,
model=self.config.model,
api_key=self.config.api_key.get_secret_value()
if self.config.api_key
else None,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
timeout=self.config.timeout,
top_p=self.config.top_p,
drop_params=self.config.drop_params,
seed=self.config.seed,
**kwargs,
)
# Create a wrapper function that captures the config values
def completion_with_config(*args: Any, **user_kwargs: Any) -> ModelResponse:
    """Call ``litellm_completion`` with the configured defaults applied.

    The defaults are read from ``self.config`` on every call. The enclosing
    scope's ``kwargs`` are layered on top of them, and ``user_kwargs``
    supplied by the caller take precedence over both.

    Args:
        *args: Positional arguments forwarded to ``litellm_completion``.
        **user_kwargs: Per-call keyword overrides.

    Returns:
        Whatever ``litellm_completion`` returns for the merged arguments.
    """
    api_key = self.config.api_key
    merged_kwargs = {
        'model': self.config.model,
        'api_key': api_key.get_secret_value() if api_key else None,
        'base_url': self.config.base_url,
        'api_version': self.config.api_version,
        'custom_llm_provider': self.config.custom_llm_provider,
        'timeout': self.config.timeout,
        'top_p': self.config.top_p,
        'drop_params': self.config.drop_params,
        # Keep forwarding the configured seed, as the partial()-based
        # setup did; omitting it silently ignores the config value.
        'seed': self.config.seed,
        # Later entries win: construction-time kwargs override the config
        # defaults, and per-call kwargs override everything.
        **kwargs,
        **user_kwargs,
    }
    return litellm_completion(*args, **merged_kwargs)
self._completion_unwrapped = self._completion
self._completion_unwrapped = completion_with_config
@self.retry_decorator(
num_retries=self.config.num_retries,
@@ -179,7 +184,7 @@ class LLM(RetryMixin, DebugMixin):
retry_multiplier=self.config.retry_multiplier,
retry_listener=self.retry_listener,
)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> ModelResponse:
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
from openhands.io import json
@@ -257,20 +262,18 @@ class LLM(RetryMixin, DebugMixin):
# if we mocked function calling, and we have tools, convert the response back to function calling format
if mock_function_calling and mock_fncall_tools is not None:
logger.debug(f'Response choices: {len(resp.choices)}')
assert len(resp.choices) >= 1
non_fncall_response_message = resp.choices[0].message
fn_call_messages_with_response = (
convert_non_fncall_messages_to_fncall_messages(
messages + [non_fncall_response_message], mock_fncall_tools
assert len(resp.choices) == 1
if isinstance(resp.choices[0], dict) and 'message' in resp.choices[0]:
non_fncall_response_message = resp.choices[0]['message']
fn_call_messages_with_response = (
convert_non_fncall_messages_to_fncall_messages(
messages + [dict(non_fncall_response_message)],
mock_fncall_tools,
)
)
)
fn_call_response_message = fn_call_messages_with_response[-1]
if not isinstance(fn_call_response_message, LiteLLMMessage):
fn_call_response_message = LiteLLMMessage(
**fn_call_response_message
)
resp.choices[0].message = fn_call_response_message
fn_call_response_message = fn_call_messages_with_response[-1]
fn_call_response_message = dict(fn_call_response_message)
resp.choices[0]['message'] = fn_call_response_message
message_back: str = resp['choices'][0]['message']['content'] or ''
tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
@@ -327,14 +330,14 @@ class LLM(RetryMixin, DebugMixin):
self._completion = wrapper
@property
def completion(self):
def completion(self) -> Callable[..., ModelResponse]:
"""Decorator for the litellm completion function.
Check the complete documentation at https://litellm.vercel.app/docs/completion
"""
return self._completion
def init_model_info(self):
def init_model_info(self) -> None:
if self._tried_model_info:
return
self._tried_model_info = True
@@ -464,11 +467,11 @@ class LLM(RetryMixin, DebugMixin):
# remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
# Check both the full model name and the name after proxy prefix for vision support
return (
litellm.supports_vision(self.config.model)
or litellm.supports_vision(self.config.model.split('/')[-1])
bool(litellm.supports_vision(self.config.model))
or bool(litellm.supports_vision(self.config.model.split('/')[-1]))
or (
self.model_info is not None
and self.model_info.get('supports_vision', False)
and bool(self.model_info.get('supports_vision', False))
)
)
@@ -624,7 +627,7 @@ class LLM(RetryMixin, DebugMixin):
return True
return False
def _completion_cost(self, response) -> float:
def _completion_cost(self, response: ModelResponse) -> float:
"""Calculate completion cost and update metrics with running total.
Calculate the cost of a completion response based on the model. Local models are treated as free.
@@ -663,35 +666,45 @@ class LLM(RetryMixin, DebugMixin):
try:
if cost is None:
try:
cost = litellm_completion_cost(
completion_response=response, **extra_kwargs
cost = float(
litellm_completion_cost(
completion_response=response,
custom_cost_per_token=extra_kwargs.get(
'custom_cost_per_token'
),
)
)
except Exception as e:
logger.error(f'Error getting cost from litellm: {e}')
if cost is None:
_model_name = '/'.join(self.config.model.split('/')[1:])
cost = litellm_completion_cost(
completion_response=response, model=_model_name, **extra_kwargs
cost = float(
litellm_completion_cost(
completion_response=response,
model=_model_name,
custom_cost_per_token=extra_kwargs.get('custom_cost_per_token'),
)
)
logger.debug(
f'Using fallback model name {_model_name} to get cost: {cost}'
)
self.metrics.add_cost(cost)
return cost
cost_float = float(cost)
self.metrics.add_cost(cost_float)
return cost_float
except Exception:
self.cost_metric_supported = False
logger.debug('Cost calculation not supported for this model.')
return 0.0
def __str__(self):
def __str__(self) -> str:
if self.config.api_version:
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
elif self.config.base_url:
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
return f'LLM(model={self.config.model})'
def __repr__(self):
def __repr__(self) -> str:
return str(self)
def reset(self) -> None:

View File

@@ -129,13 +129,13 @@ class Metrics:
'token_usages': [usage.model_dump() for usage in self._token_usages],
}
def reset(self):
def reset(self) -> None:
self._accumulated_cost = 0.0
self._costs = []
self._response_latencies = []
self._token_usages = []
def log(self):
def log(self) -> str:
"""Log the metrics."""
metrics = self.get()
logs = ''
@@ -143,5 +143,5 @@ class Metrics:
logs += f'{key}: {value}\n'
return logs
def __repr__(self) -> str:
    """Return a debug representation wrapping the value of ``self.get()``."""
    # The format string previously lacked its closing parenthesis, which
    # produced an unbalanced repr such as "Metrics({...}".
    return f'Metrics({self.get()})'

View File

@@ -1,3 +1,5 @@
from typing import Any, Callable
from tenacity import (
retry,
retry_if_exception_type,
@@ -12,25 +14,35 @@ from openhands.utils.tenacity_stop import stop_if_should_exit
class RetryMixin:
"""Mixin class for retry logic."""
def retry_decorator(self, **kwargs):
def retry_decorator(
self,
*,
num_retries: int | None = None,
retry_exceptions: tuple = (),
retry_min_wait: int | None = None,
retry_max_wait: int | None = None,
retry_multiplier: float | None = None,
retry_listener: Any | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Create a LLM retry decorator with customizable parameters. This is used for 429 errors, and a few other exceptions in LLM classes.
Args:
**kwargs: Keyword arguments to override default retry behavior.
Keys: num_retries, retry_exceptions, retry_min_wait, retry_max_wait, retry_multiplier
num_retries: Number of retries before giving up
retry_exceptions: Tuple of exception types to retry on
retry_min_wait: Minimum wait time between retries in seconds
retry_max_wait: Maximum wait time between retries in seconds
retry_multiplier: Multiplier for exponential backoff
retry_listener: Optional callback for retry events
Returns:
A retry decorator with the parameters customizable in configuration.
"""
num_retries = kwargs.get('num_retries')
retry_exceptions: tuple = kwargs.get('retry_exceptions', ())
retry_min_wait = kwargs.get('retry_min_wait')
retry_max_wait = kwargs.get('retry_max_wait')
retry_multiplier = kwargs.get('retry_multiplier')
retry_listener = kwargs.get('retry_listener')
# Use the values from config if not provided
# Note: These values are already set in LLMConfig with appropriate defaults
# See openhands/core/config/llm_config.py for the actual default values
def before_sleep(retry_state):
def before_sleep(retry_state: Any) -> None:
self.log_retry_attempt(retry_state)
if retry_listener:
retry_listener(retry_state.attempt_number, num_retries)
@@ -49,7 +61,7 @@ class RetryMixin:
),
)
def log_retry_attempt(self, retry_state):
def log_retry_attempt(self, retry_state: Any) -> None:
"""Log retry attempts."""
exception = retry_state.outcome.exception()
logger.error(

View File

@@ -1,6 +1,8 @@
import asyncio
from functools import partial
from typing import Any
from typing import Any, AsyncIterator, Callable, Coroutine
from litellm.types.utils import ModelResponse
from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
@@ -11,7 +13,11 @@ from openhands.llm.llm import REASONING_EFFORT_SUPPORTED_MODELS
class StreamingLLM(AsyncLLM):
"""Streaming LLM class."""
def __init__(self, *args, **kwargs):
_async_streaming_completion: Callable[
..., Coroutine[Any, Any, AsyncIterator[ModelResponse]]
]
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._async_streaming_completion = partial(
@@ -31,7 +37,9 @@ class StreamingLLM(AsyncLLM):
stream=True, # Ensure streaming is enabled
)
async_streaming_completion_unwrapped = self._async_streaming_completion
async_streaming_completion_unwrapped: Callable[
..., Coroutine[Any, Any, AsyncIterator[ModelResponse]]
] = self._async_streaming_completion
@self.retry_decorator(
num_retries=self.config.num_retries,
@@ -40,7 +48,7 @@ class StreamingLLM(AsyncLLM):
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
async def async_streaming_completion_wrapper(*args, **kwargs):
async def async_streaming_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
messages: list[dict[str, Any]] | dict[str, Any] = []
# some callers might send the model and messages directly
@@ -89,9 +97,10 @@ class StreamingLLM(AsyncLLM):
message_back = chunk['choices'][0]['delta'].get('content', '')
if message_back:
self.log_response(message_back)
self._post_completion(chunk)
chunk_dict = dict(chunk)
self._post_completion(ModelResponse(**chunk_dict))
yield chunk
yield chunk_dict
except UserCancelledError:
logger.debug('LLM request cancelled by user.')
@@ -108,6 +117,8 @@ class StreamingLLM(AsyncLLM):
self._async_streaming_completion = async_streaming_completion_wrapper
@property
def async_streaming_completion(self):
def async_streaming_completion(
self,
) -> Callable[..., Coroutine[Any, Any, AsyncIterator[ModelResponse]]]:
"""Decorator for the async litellm acompletion function with streaming."""
return self._async_streaming_completion