Mirror of https://github.com/All-Hands-AI/OpenHands.git (synced 2026-01-09 14:57:59 -05:00)
Add more extensive typing to openhands/llm/ directory (#7727)
Co-authored-by: openhands <openhands@all-hands.dev>
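The hunks below apply one pattern throughout openhands/llm/: previously un-annotated signatures gain explicit parameter and return types (Any for pass-through *args/**kwargs, Callable for returned or exposed functions, None for procedures). A minimal before/after sketch of that pattern, using hypothetical names that do not appear in the diff:

from typing import Any, Callable

# before: parameters and the return value are implicitly untyped
def make_logger(prefix, **kwargs):
    def log(message):
        print(f'{prefix}: {message}')
    return log

# after: Any for pass-through kwargs, None for procedures, Callable for factories
def make_logger_typed(prefix: str, **kwargs: Any) -> Callable:
    def log(message: str) -> None:
        print(f'{prefix}: {message}')
    return log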
@@ -1,6 +1,6 @@
 import asyncio
 from functools import partial
-from typing import Any
+from typing import Any, Callable

 from litellm import acompletion as litellm_acompletion

@@ -17,7 +17,7 @@ from openhands.utils.shutdown_listener import should_continue
 class AsyncLLM(LLM):
     """Asynchronous LLM class."""

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)

         self._async_completion = partial(
@@ -46,7 +46,7 @@ class AsyncLLM(LLM):
             retry_max_wait=self.config.retry_max_wait,
             retry_multiplier=self.config.retry_multiplier,
         )
-        async def async_completion_wrapper(*args, **kwargs):
+        async def async_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
             """Wrapper for the litellm acompletion function that adds logging and cost tracking."""
             messages: list[dict[str, Any]] | dict[str, Any] = []

@@ -77,7 +77,7 @@ class AsyncLLM(LLM):

             self.log_prompt(messages)

-            async def check_stopped():
+            async def check_stopped() -> None:
                 while should_continue():
                     if (
                         hasattr(self.config, 'on_cancel_requested_fn')
@@ -117,14 +117,14 @@ class AsyncLLM(LLM):
                 except asyncio.CancelledError:
                     pass

-        self._async_completion = async_completion_wrapper # type: ignore
+        self._async_completion = async_completion_wrapper

-    async def _call_acompletion(self, *args, **kwargs):
+    async def _call_acompletion(self, *args: Any, **kwargs: Any) -> Any:
         """Wrapper for the litellm acompletion function."""
         # Used in testing?
         return await litellm_acompletion(*args, **kwargs)

     @property
-    def async_completion(self):
+    def async_completion(self) -> Callable:
         """Decorator for the async litellm acompletion function."""
         return self._async_completion
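For orientation, a hedged sketch of how the typed async_completion property might be awaited by a caller; the keyword-only messages call mirrors litellm's acompletion, and the surrounding setup (an AsyncLLM instance named llm) is assumed rather than taken from this diff:

from typing import Any

async def ask(llm: Any, prompt: str) -> Any:
    # async_completion is a property returning a Callable, so it is called
    # and awaited like litellm's acompletion
    return await llm.async_completion(
        messages=[{'role': 'user', 'content': prompt}]
    )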
@@ -28,5 +28,5 @@ def list_foundation_models(
         return []


-def remove_error_modelId(model_list):
+def remove_error_modelId(model_list: list[str]) -> list[str]:
     return list(filter(lambda m: not m.startswith('bedrock'), model_list))
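remove_error_modelId is a pure filter, so the new annotations are easy to sanity-check with a small usage example (the model ids below are made up):

def remove_error_modelId(model_list: list[str]) -> list[str]:
    return list(filter(lambda m: not m.startswith('bedrock'), model_list))

# ids that still carry the raw 'bedrock' prefix are dropped
print(remove_error_modelId(['bedrock/anthropic.claude-3', 'gpt-4o']))  # ['gpt-4o']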
@@ -7,7 +7,7 @@ MESSAGE_SEPARATOR = '\n\n----------\n\n'


 class DebugMixin:
-    def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
+    def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]) -> None:
         if not messages:
             logger.debug('No completion messages!')
             return
@@ -24,11 +24,11 @@ class DebugMixin:
         else:
             logger.debug('No completion messages!')

-    def log_response(self, message_back: str):
+    def log_response(self, message_back: str) -> None:
         if message_back:
             llm_response_logger.debug(message_back)

-    def _format_message_content(self, message: dict[str, Any]):
+    def _format_message_content(self, message: dict[str, Any]) -> str:
         content = message['content']
         if isinstance(content, list):
             return '\n'.join(
@@ -36,18 +36,18 @@ class DebugMixin:
             )
         return str(content)

-    def _format_content_element(self, element: dict[str, Any]):
+    def _format_content_element(self, element: dict[str, Any] | Any) -> str:
         if isinstance(element, dict):
             if 'text' in element:
-                return element['text']
+                return str(element['text'])
             if (
                 self.vision_is_active()
                 and 'image_url' in element
                 and 'url' in element['image_url']
             ):
-                return element['image_url']['url']
+                return str(element['image_url']['url'])
         return str(element)

     # This method should be implemented in the class that uses DebugMixin
-    def vision_is_active(self):
+    def vision_is_active(self) -> bool:
         raise NotImplementedError
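The _format_content_element change only wraps the returned values in str() so the declared -> str holds on every path. A standalone sketch of the same dispatch, rewritten as a free function with the vision check reduced to a plain flag for illustration:

from typing import Any

def format_content_element(element: dict[str, Any] | Any, vision_active: bool = False) -> str:
    # mirrors DebugMixin._format_content_element, but as a free function
    if isinstance(element, dict):
        if 'text' in element:
            return str(element['text'])
        if vision_active and 'image_url' in element and 'url' in element['image_url']:
            return str(element['image_url']['url'])
    return str(element)

print(format_content_element({'type': 'text', 'text': 'hello'}))  # hello
print(format_content_element({'image_url': {'url': 'https://example/img.png'}}, True))  # https://example/img.png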
@@ -186,7 +186,7 @@ class LLM(RetryMixin, DebugMixin):
             retry_multiplier=self.config.retry_multiplier,
             retry_listener=self.retry_listener,
         )
-        def wrapper(*args, **kwargs):
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
             """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
             from openhands.io import json

@@ -355,14 +355,14 @@ class LLM(RetryMixin, DebugMixin):
         self._completion = wrapper

     @property
-    def completion(self):
+    def completion(self) -> Callable:
         """Decorator for the litellm completion function.

         Check the complete documentation at https://litellm.vercel.app/docs/completion
         """
         return self._completion

-    def init_model_info(self):
+    def init_model_info(self) -> None:
         if self._tried_model_info:
             return
         self._tried_model_info = True
@@ -622,10 +622,12 @@ class LLM(RetryMixin, DebugMixin):
         # try to get the token count with the default litellm tokenizers
         # or the custom tokenizer if set for this LLM configuration
         try:
-            return litellm.token_counter(
-                model=self.config.model,
-                messages=messages,
-                custom_tokenizer=self.tokenizer,
+            return int(
+                litellm.token_counter(
+                    model=self.config.model,
+                    messages=messages,
+                    custom_tokenizer=self.tokenizer,
+                )
             )
         except Exception as e:
             # limit logspam in case token count is not supported
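Wrapping the litellm.token_counter call in int() is presumably there so the enclosing method can keep a narrow int return type even though the library call is loosely typed. The same idiom in isolation, as a simplified sketch (custom_tokenizer omitted, helper name hypothetical):

import litellm

def count_tokens(model: str, messages: list[dict]) -> int:
    # token_counter returns an int at runtime; int() keeps the declared
    # return type satisfied regardless of how the library is annotated
    return int(litellm.token_counter(model=model, messages=messages))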
@@ -654,7 +656,7 @@ class LLM(RetryMixin, DebugMixin):
                 return True
         return False

-    def _completion_cost(self, response) -> float:
+    def _completion_cost(self, response: Any) -> float:
         """Calculate completion cost and update metrics with running total.

         Calculate the cost of a completion response based on the model. Local models are treated as free.
@@ -707,21 +709,21 @@ class LLM(RetryMixin, DebugMixin):
                 logger.debug(
                     f'Using fallback model name {_model_name} to get cost: {cost}'
                 )
-            self.metrics.add_cost(cost)
-            return cost
+            self.metrics.add_cost(float(cost))
+            return float(cost)
         except Exception:
             self.cost_metric_supported = False
             logger.debug('Cost calculation not supported for this model.')
         return 0.0

-    def __str__(self):
+    def __str__(self) -> str:
         if self.config.api_version:
             return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
         elif self.config.base_url:
             return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
         return f'LLM(model={self.config.model})'

-    def __repr__(self):
+    def __repr__(self) -> str:
         return str(self)

     def reset(self) -> None:
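The float(cost) wrapping follows the same idea: litellm's cost helper is loosely typed, while _completion_cost promises -> float. A reduced sketch of that shape, with the repository's fallback and metrics logic trimmed away (the helper name here is hypothetical):

from typing import Any

import litellm

def completion_cost_or_zero(response: Any) -> float:
    # completion_cost may come back as an un-annotated numeric value;
    # float() makes the declared return type explicit
    try:
        return float(litellm.completion_cost(completion_response=response))
    except Exception:
        return 0.0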
@@ -177,7 +177,7 @@ class Metrics:
             'token_usages': [usage.model_dump() for usage in self._token_usages],
         }

-    def reset(self):
+    def reset(self) -> None:
         self._accumulated_cost = 0.0
         self._costs = []
         self._response_latencies = []
@@ -192,7 +192,7 @@ class Metrics:
             response_id='',
         )

-    def log(self):
+    def log(self) -> str:
         """Log the metrics."""
         metrics = self.get()
         logs = ''
@@ -200,5 +200,5 @@ class Metrics:
             logs += f'{key}: {value}\n'
         return logs

-    def __repr__(self):
+    def __repr__(self) -> str:
         return f'Metrics({self.get()}'
@@ -1,3 +1,5 @@
+from typing import Any, Callable
+
 from tenacity import (
     retry,
     retry_if_exception_type,
@@ -13,7 +15,7 @@ from openhands.utils.tenacity_stop import stop_if_should_exit
 class RetryMixin:
     """Mixin class for retry logic."""

-    def retry_decorator(self, **kwargs):
+    def retry_decorator(self, **kwargs: Any) -> Callable:
         """
         Create a LLM retry decorator with customizable parameters. This is used for 429 errors, and a few other exceptions in LLM classes.

@@ -31,7 +33,7 @@ class RetryMixin:
         retry_multiplier = kwargs.get('retry_multiplier')
         retry_listener = kwargs.get('retry_listener')

-        def before_sleep(retry_state):
+        def before_sleep(retry_state: Any) -> None:
             self.log_retry_attempt(retry_state)
             if retry_listener:
                 retry_listener(retry_state.attempt_number, num_retries)
@@ -52,7 +54,7 @@ class RetryMixin:
                         f'LLMNoResponseError detected with temperature={current_temp}, keeping original temperature'
                     )

-        return retry(
+        retry_decorator: Callable = retry(
             before_sleep=before_sleep,
             stop=stop_after_attempt(num_retries) | stop_if_should_exit(),
             reraise=True,
@@ -65,8 +67,9 @@ class RetryMixin:
                 max=retry_max_wait,
             ),
         )
+        return retry_decorator

-    def log_retry_attempt(self, retry_state):
+    def log_retry_attempt(self, retry_state: Any) -> None:
         """Log retry attempts."""
         exception = retry_state.outcome.exception()
         logger.error(
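The last retry_mixin hunk replaces the direct return retry(...) with a named retry_decorator: Callable binding so the annotation has somewhere to live. The tenacity pieces visible in the diff compose roughly as below; this trimmed sketch drops the repository-specific stop_if_should_exit and exception filtering:

from typing import Any, Callable

from tenacity import retry, stop_after_attempt, wait_exponential

def make_retry_decorator(num_retries: int, retry_min_wait: int, retry_max_wait: int) -> Callable:
    def before_sleep(retry_state: Any) -> None:
        print(f'retry attempt {retry_state.attempt_number}')

    retry_decorator: Callable = retry(
        before_sleep=before_sleep,
        stop=stop_after_attempt(num_retries),
        reraise=True,
        wait=wait_exponential(multiplier=1, min=retry_min_wait, max=retry_max_wait),
    )
    return retry_decorator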
@@ -1,6 +1,6 @@
 import asyncio
 from functools import partial
-from typing import Any
+from typing import Any, Callable

 from openhands.core.exceptions import UserCancelledError
 from openhands.core.logger import openhands_logger as logger
@@ -11,7 +11,7 @@ from openhands.llm.llm import REASONING_EFFORT_SUPPORTED_MODELS
 class StreamingLLM(AsyncLLM):
     """Streaming LLM class."""

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)

         self._async_streaming_completion = partial(
@@ -40,7 +40,7 @@ class StreamingLLM(AsyncLLM):
             retry_max_wait=self.config.retry_max_wait,
             retry_multiplier=self.config.retry_multiplier,
         )
-        async def async_streaming_completion_wrapper(*args, **kwargs):
+        async def async_streaming_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
             messages: list[dict[str, Any]] | dict[str, Any] = []

             # some callers might send the model and messages directly
@@ -108,6 +108,6 @@ class StreamingLLM(AsyncLLM):
         self._async_streaming_completion = async_streaming_completion_wrapper

     @property
-    def async_streaming_completion(self):
+    def async_streaming_completion(self) -> Callable:
         """Decorator for the async litellm acompletion function with streaming."""
         return self._async_streaming_completion