diff --git a/openhands/llm/async_llm.py b/openhands/llm/async_llm.py
index a9a9224a9b..ef3d4e1848 100644
--- a/openhands/llm/async_llm.py
+++ b/openhands/llm/async_llm.py
@@ -1,6 +1,6 @@
 import asyncio
 from functools import partial
-from typing import Any
+from typing import Any, Callable
 
 from litellm import acompletion as litellm_acompletion
 
@@ -17,7 +17,7 @@ from openhands.utils.shutdown_listener import should_continue
 class AsyncLLM(LLM):
     """Asynchronous LLM class."""
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
 
         self._async_completion = partial(
@@ -46,7 +46,7 @@ class AsyncLLM(LLM):
             retry_max_wait=self.config.retry_max_wait,
             retry_multiplier=self.config.retry_multiplier,
         )
-        async def async_completion_wrapper(*args, **kwargs):
+        async def async_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
            """Wrapper for the litellm acompletion function that adds logging and cost tracking."""
            messages: list[dict[str, Any]] | dict[str, Any] = []
 
@@ -77,7 +77,7 @@ class AsyncLLM(LLM):
 
            self.log_prompt(messages)
 
-            async def check_stopped():
+            async def check_stopped() -> None:
                while should_continue():
                    if (
                        hasattr(self.config, 'on_cancel_requested_fn')
@@ -117,14 +117,14 @@ class AsyncLLM(LLM):
                except asyncio.CancelledError:
                    pass
 
-        self._async_completion = async_completion_wrapper  # type: ignore
+        self._async_completion = async_completion_wrapper
 
-    async def _call_acompletion(self, *args, **kwargs):
+    async def _call_acompletion(self, *args: Any, **kwargs: Any) -> Any:
        """Wrapper for the litellm acompletion function."""
        # Used in testing?
        return await litellm_acompletion(*args, **kwargs)
 
    @property
-    def async_completion(self):
+    def async_completion(self) -> Callable:
        """Decorator for the async litellm acompletion function."""
        return self._async_completion
diff --git a/openhands/llm/bedrock.py b/openhands/llm/bedrock.py
index 2f32b2d79e..62cfe1780a 100644
--- a/openhands/llm/bedrock.py
+++ b/openhands/llm/bedrock.py
@@ -28,5 +28,5 @@ def list_foundation_models(
        return []
 
 
-def remove_error_modelId(model_list):
+def remove_error_modelId(model_list: list[str]) -> list[str]:
    return list(filter(lambda m: not m.startswith('bedrock'), model_list))
diff --git a/openhands/llm/debug_mixin.py b/openhands/llm/debug_mixin.py
index 6a247471ee..f80d98d7ad 100644
--- a/openhands/llm/debug_mixin.py
+++ b/openhands/llm/debug_mixin.py
@@ -7,7 +7,7 @@ MESSAGE_SEPARATOR = '\n\n----------\n\n'
 
 
 class DebugMixin:
-    def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
+    def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]) -> None:
        if not messages:
            logger.debug('No completion messages!')
            return
@@ -24,11 +24,11 @@ class DebugMixin:
        else:
            logger.debug('No completion messages!')
 
-    def log_response(self, message_back: str):
+    def log_response(self, message_back: str) -> None:
        if message_back:
            llm_response_logger.debug(message_back)
 
-    def _format_message_content(self, message: dict[str, Any]):
+    def _format_message_content(self, message: dict[str, Any]) -> str:
        content = message['content']
        if isinstance(content, list):
            return '\n'.join(
@@ -36,18 +36,18 @@ class DebugMixin:
            )
        return str(content)
 
-    def _format_content_element(self, element: dict[str, Any]):
+    def _format_content_element(self, element: dict[str, Any] | Any) -> str:
        if isinstance(element, dict):
            if 'text' in element:
-                return element['text']
+                return str(element['text'])
            if (
                self.vision_is_active()
                and 'image_url' in element
                and 'url' in element['image_url']
            ):
-                return element['image_url']['url']
+                return str(element['image_url']['url'])
        return str(element)
 
    # This method should be implemented in the class that uses DebugMixin
-    def vision_is_active(self):
+    def vision_is_active(self) -> bool:
        raise NotImplementedError
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 27a948fd1c..f834f039d9 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -186,7 +186,7 @@ class LLM(RetryMixin, DebugMixin):
            retry_multiplier=self.config.retry_multiplier,
            retry_listener=self.retry_listener,
        )
-        def wrapper(*args, **kwargs):
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            from openhands.io import json
 
@@ -355,14 +355,14 @@ class LLM(RetryMixin, DebugMixin):
        self._completion = wrapper
 
    @property
-    def completion(self):
+    def completion(self) -> Callable:
        """Decorator for the litellm completion function.
 
        Check the complete documentation at https://litellm.vercel.app/docs/completion
        """
        return self._completion
 
-    def init_model_info(self):
+    def init_model_info(self) -> None:
        if self._tried_model_info:
            return
        self._tried_model_info = True
@@ -622,10 +622,12 @@ class LLM(RetryMixin, DebugMixin):
        # try to get the token count with the default litellm tokenizers
        # or the custom tokenizer if set for this LLM configuration
        try:
-            return litellm.token_counter(
-                model=self.config.model,
-                messages=messages,
-                custom_tokenizer=self.tokenizer,
+            return int(
+                litellm.token_counter(
+                    model=self.config.model,
+                    messages=messages,
+                    custom_tokenizer=self.tokenizer,
+                )
            )
        except Exception as e:
            # limit logspam in case token count is not supported
@@ -654,7 +656,7 @@ class LLM(RetryMixin, DebugMixin):
            return True
        return False
 
-    def _completion_cost(self, response) -> float:
+    def _completion_cost(self, response: Any) -> float:
        """Calculate completion cost and update metrics with running total.
 
        Calculate the cost of a completion response based on the model. Local models are treated as free.
@@ -707,21 +709,21 @@ class LLM(RetryMixin, DebugMixin):
                logger.debug(
                    f'Using fallback model name {_model_name} to get cost: {cost}'
                )
-            self.metrics.add_cost(cost)
-            return cost
+            self.metrics.add_cost(float(cost))
+            return float(cost)
        except Exception:
            self.cost_metric_supported = False
            logger.debug('Cost calculation not supported for this model.')
        return 0.0
 
-    def __str__(self):
+    def __str__(self) -> str:
        if self.config.api_version:
            return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
        elif self.config.base_url:
            return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
        return f'LLM(model={self.config.model})'
 
-    def __repr__(self):
+    def __repr__(self) -> str:
        return str(self)
 
    def reset(self) -> None:
diff --git a/openhands/llm/metrics.py b/openhands/llm/metrics.py
index d155dc23ee..3dfa963bf1 100644
--- a/openhands/llm/metrics.py
+++ b/openhands/llm/metrics.py
@@ -177,7 +177,7 @@ class Metrics:
            'token_usages': [usage.model_dump() for usage in self._token_usages],
        }
 
-    def reset(self):
+    def reset(self) -> None:
        self._accumulated_cost = 0.0
        self._costs = []
        self._response_latencies = []
@@ -192,7 +192,7 @@ class Metrics:
            response_id='',
        )
 
-    def log(self):
+    def log(self) -> str:
        """Log the metrics."""
        metrics = self.get()
        logs = ''
@@ -200,5 +200,5 @@ class Metrics:
            logs += f'{key}: {value}\n'
        return logs
 
-    def __repr__(self):
+    def __repr__(self) -> str:
        return f'Metrics({self.get()}'
diff --git a/openhands/llm/retry_mixin.py b/openhands/llm/retry_mixin.py
index 19b7c9689f..367bcbd97d 100644
--- a/openhands/llm/retry_mixin.py
+++ b/openhands/llm/retry_mixin.py
@@ -1,3 +1,5 @@
+from typing import Any, Callable
+
 from tenacity import (
    retry,
    retry_if_exception_type,
@@ -13,7 +15,7 @@ from openhands.utils.tenacity_stop import stop_if_should_exit
 class RetryMixin:
    """Mixin class for retry logic."""
 
-    def retry_decorator(self, **kwargs):
+    def retry_decorator(self, **kwargs: Any) -> Callable:
        """
        Create a LLM retry decorator with customizable parameters. This is used for 429 errors, and a few other exceptions in LLM classes.
 
@@ -31,7 +33,7 @@ class RetryMixin:
        retry_multiplier = kwargs.get('retry_multiplier')
        retry_listener = kwargs.get('retry_listener')
 
-        def before_sleep(retry_state):
+        def before_sleep(retry_state: Any) -> None:
            self.log_retry_attempt(retry_state)
            if retry_listener:
                retry_listener(retry_state.attempt_number, num_retries)
@@ -52,7 +54,7 @@ class RetryMixin:
                        f'LLMNoResponseError detected with temperature={current_temp}, keeping original temperature'
                    )
 
-        return retry(
+        retry_decorator: Callable = retry(
            before_sleep=before_sleep,
            stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
            reraise=True,
@@ -65,8 +67,9 @@ class RetryMixin:
                max=retry_max_wait,
            ),
        )
+        return retry_decorator
 
-    def log_retry_attempt(self, retry_state):
+    def log_retry_attempt(self, retry_state: Any) -> None:
        """Log retry attempts."""
        exception = retry_state.outcome.exception()
        logger.error(
diff --git a/openhands/llm/streaming_llm.py b/openhands/llm/streaming_llm.py
index 2a0e5b2d9d..d722f80d06 100644
--- a/openhands/llm/streaming_llm.py
+++ b/openhands/llm/streaming_llm.py
@@ -1,6 +1,6 @@
 import asyncio
 from functools import partial
-from typing import Any
+from typing import Any, Callable
 
 from openhands.core.exceptions import UserCancelledError
 from openhands.core.logger import openhands_logger as logger
@@ -11,7 +11,7 @@ from openhands.llm.llm import REASONING_EFFORT_SUPPORTED_MODELS
 class StreamingLLM(AsyncLLM):
    """Streaming LLM class."""
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
 
        self._async_streaming_completion = partial(
@@ -40,7 +40,7 @@ class StreamingLLM(AsyncLLM):
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
-        async def async_streaming_completion_wrapper(*args, **kwargs):
+        async def async_streaming_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
            messages: list[dict[str, Any]] | dict[str, Any] = []
 
            # some callers might send the model and messages directly
@@ -108,6 +108,6 @@ class StreamingLLM(AsyncLLM):
        self._async_streaming_completion = async_streaming_completion_wrapper
 
    @property
-    def async_streaming_completion(self):
+    def async_streaming_completion(self) -> Callable:
        """Decorator for the async litellm acompletion function with streaming."""
        return self._async_streaming_completion