Add more extensive typing to openhands/llm/ directory (#7727)

Co-authored-by: openhands <openhands@all-hands.dev>
This commit is contained in:
Graham Neubig
2025-04-06 13:59:25 -04:00
committed by GitHub
parent 288bcd254e
commit 9b8a628395
7 changed files with 43 additions and 38 deletions

View File

@@ -1,6 +1,6 @@
import asyncio
from functools import partial
from typing import Any
from typing import Any, Callable
from litellm import acompletion as litellm_acompletion
@@ -17,7 +17,7 @@ from openhands.utils.shutdown_listener import should_continue
class AsyncLLM(LLM):
"""Asynchronous LLM class."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._async_completion = partial(
@@ -46,7 +46,7 @@ class AsyncLLM(LLM):
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
async def async_completion_wrapper(*args, **kwargs):
async def async_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
"""Wrapper for the litellm acompletion function that adds logging and cost tracking."""
messages: list[dict[str, Any]] | dict[str, Any] = []
@@ -77,7 +77,7 @@ class AsyncLLM(LLM):
self.log_prompt(messages)
async def check_stopped():
async def check_stopped() -> None:
while should_continue():
if (
hasattr(self.config, 'on_cancel_requested_fn')
@@ -117,14 +117,14 @@ class AsyncLLM(LLM):
except asyncio.CancelledError:
pass
self._async_completion = async_completion_wrapper # type: ignore
self._async_completion = async_completion_wrapper
async def _call_acompletion(self, *args, **kwargs):
async def _call_acompletion(self, *args: Any, **kwargs: Any) -> Any:
"""Wrapper for the litellm acompletion function."""
# Used in testing?
return await litellm_acompletion(*args, **kwargs)
@property
def async_completion(self):
def async_completion(self) -> Callable:
"""Decorator for the async litellm acompletion function."""
return self._async_completion

View File

@@ -28,5 +28,5 @@ def list_foundation_models(
return []
def remove_error_modelId(model_list):
def remove_error_modelId(model_list: list[str]) -> list[str]:
return list(filter(lambda m: not m.startswith('bedrock'), model_list))

View File

@@ -7,7 +7,7 @@ MESSAGE_SEPARATOR = '\n\n----------\n\n'
class DebugMixin:
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]) -> None:
if not messages:
logger.debug('No completion messages!')
return
@@ -24,11 +24,11 @@ class DebugMixin:
else:
logger.debug('No completion messages!')
def log_response(self, message_back: str):
def log_response(self, message_back: str) -> None:
if message_back:
llm_response_logger.debug(message_back)
def _format_message_content(self, message: dict[str, Any]):
def _format_message_content(self, message: dict[str, Any]) -> str:
content = message['content']
if isinstance(content, list):
return '\n'.join(
@@ -36,18 +36,18 @@ class DebugMixin:
)
return str(content)
def _format_content_element(self, element: dict[str, Any]):
def _format_content_element(self, element: dict[str, Any] | Any) -> str:
if isinstance(element, dict):
if 'text' in element:
return element['text']
return str(element['text'])
if (
self.vision_is_active()
and 'image_url' in element
and 'url' in element['image_url']
):
return element['image_url']['url']
return str(element['image_url']['url'])
return str(element)
# This method should be implemented in the class that uses DebugMixin
def vision_is_active(self):
def vision_is_active(self) -> bool:
raise NotImplementedError

View File

@@ -186,7 +186,7 @@ class LLM(RetryMixin, DebugMixin):
retry_multiplier=self.config.retry_multiplier,
retry_listener=self.retry_listener,
)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> Any:
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
from openhands.io import json
@@ -355,14 +355,14 @@ class LLM(RetryMixin, DebugMixin):
self._completion = wrapper
@property
def completion(self):
def completion(self) -> Callable:
"""Decorator for the litellm completion function.
Check the complete documentation at https://litellm.vercel.app/docs/completion
"""
return self._completion
def init_model_info(self):
def init_model_info(self) -> None:
if self._tried_model_info:
return
self._tried_model_info = True
@@ -622,10 +622,12 @@ class LLM(RetryMixin, DebugMixin):
# try to get the token count with the default litellm tokenizers
# or the custom tokenizer if set for this LLM configuration
try:
return litellm.token_counter(
model=self.config.model,
messages=messages,
custom_tokenizer=self.tokenizer,
return int(
litellm.token_counter(
model=self.config.model,
messages=messages,
custom_tokenizer=self.tokenizer,
)
)
except Exception as e:
# limit logspam in case token count is not supported
@@ -654,7 +656,7 @@ class LLM(RetryMixin, DebugMixin):
return True
return False
def _completion_cost(self, response) -> float:
def _completion_cost(self, response: Any) -> float:
"""Calculate completion cost and update metrics with running total.
Calculate the cost of a completion response based on the model. Local models are treated as free.
@@ -707,21 +709,21 @@ class LLM(RetryMixin, DebugMixin):
logger.debug(
f'Using fallback model name {_model_name} to get cost: {cost}'
)
self.metrics.add_cost(cost)
return cost
self.metrics.add_cost(float(cost))
return float(cost)
except Exception:
self.cost_metric_supported = False
logger.debug('Cost calculation not supported for this model.')
return 0.0
def __str__(self):
def __str__(self) -> str:
if self.config.api_version:
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
elif self.config.base_url:
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
return f'LLM(model={self.config.model})'
def __repr__(self):
def __repr__(self) -> str:
return str(self)
def reset(self) -> None:

View File

@@ -177,7 +177,7 @@ class Metrics:
'token_usages': [usage.model_dump() for usage in self._token_usages],
}
def reset(self):
def reset(self) -> None:
self._accumulated_cost = 0.0
self._costs = []
self._response_latencies = []
@@ -192,7 +192,7 @@ class Metrics:
response_id='',
)
def log(self):
def log(self) -> str:
"""Log the metrics."""
metrics = self.get()
logs = ''
@@ -200,5 +200,5 @@ class Metrics:
logs += f'{key}: {value}\n'
return logs
def __repr__(self):
def __repr__(self) -> str:
return f'Metrics({self.get()}'

View File

@@ -1,3 +1,5 @@
from typing import Any, Callable
from tenacity import (
retry,
retry_if_exception_type,
@@ -13,7 +15,7 @@ from openhands.utils.tenacity_stop import stop_if_should_exit
class RetryMixin:
"""Mixin class for retry logic."""
def retry_decorator(self, **kwargs):
def retry_decorator(self, **kwargs: Any) -> Callable:
"""
Create a LLM retry decorator with customizable parameters. This is used for 429 errors, and a few other exceptions in LLM classes.
@@ -31,7 +33,7 @@ class RetryMixin:
retry_multiplier = kwargs.get('retry_multiplier')
retry_listener = kwargs.get('retry_listener')
def before_sleep(retry_state):
def before_sleep(retry_state: Any) -> None:
self.log_retry_attempt(retry_state)
if retry_listener:
retry_listener(retry_state.attempt_number, num_retries)
@@ -52,7 +54,7 @@ class RetryMixin:
f'LLMNoResponseError detected with temperature={current_temp}, keeping original temperature'
)
return retry(
retry_decorator: Callable = retry(
before_sleep=before_sleep,
stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
reraise=True,
@@ -65,8 +67,9 @@ class RetryMixin:
max=retry_max_wait,
),
)
return retry_decorator
def log_retry_attempt(self, retry_state):
def log_retry_attempt(self, retry_state: Any) -> None:
"""Log retry attempts."""
exception = retry_state.outcome.exception()
logger.error(

View File

@@ -1,6 +1,6 @@
import asyncio
from functools import partial
from typing import Any
from typing import Any, Callable
from openhands.core.exceptions import UserCancelledError
from openhands.core.logger import openhands_logger as logger
@@ -11,7 +11,7 @@ from openhands.llm.llm import REASONING_EFFORT_SUPPORTED_MODELS
class StreamingLLM(AsyncLLM):
"""Streaming LLM class."""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._async_streaming_completion = partial(
@@ -40,7 +40,7 @@ class StreamingLLM(AsyncLLM):
retry_max_wait=self.config.retry_max_wait,
retry_multiplier=self.config.retry_multiplier,
)
async def async_streaming_completion_wrapper(*args, **kwargs):
async def async_streaming_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
messages: list[dict[str, Any]] | dict[str, Any] = []
# some callers might send the model and messages directly
@@ -108,6 +108,6 @@ class StreamingLLM(AsyncLLM):
self._async_streaming_completion = async_streaming_completion_wrapper
@property
def async_streaming_completion(self):
def async_streaming_completion(self) -> Callable:
"""Decorator for the async litellm acompletion function with streaming."""
return self._async_streaming_completion