mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
26 Commits
replace-si
...
fix/llm-my
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac07afbb3c | ||
|
|
b9f330db1b | ||
|
|
096b259694 | ||
|
|
7f3202d0e0 | ||
|
|
e63dfca517 | ||
|
|
cb705b736f | ||
|
|
1e5c4da0fc | ||
|
|
526618753d | ||
|
|
e25e6766fb | ||
|
|
1b9a2b43c3 | ||
|
|
7886c1f920 | ||
|
|
4b49ffb01d | ||
|
|
f74ce56a35 | ||
|
|
9ccf680c38 | ||
|
|
41068f6ea1 | ||
|
|
5b1c8bc2e8 | ||
|
|
5b8db983b7 | ||
|
|
18543a2efa | ||
|
|
63d9c3d668 | ||
|
|
2589b13815 | ||
|
|
592aca05e1 | ||
|
|
d309455733 | ||
|
|
66a7920539 | ||
|
|
64ebef3646 | ||
|
|
7a259915c1 | ||
|
|
66bd8fdbcd |
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
from litellm import acompletion as litellm_acompletion
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from openhands.core.exceptions import UserCancelledError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -17,7 +18,9 @@ from openhands.utils.shutdown_listener import should_continue
|
||||
class AsyncLLM(LLM):
|
||||
"""Asynchronous LLM class."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
_async_completion: Callable[..., Coroutine[Any, Any, ModelResponse]]
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._async_completion = partial(
|
||||
@@ -37,7 +40,9 @@ class AsyncLLM(LLM):
|
||||
seed=self.config.seed,
|
||||
)
|
||||
|
||||
async_completion_unwrapped = self._async_completion
|
||||
async_completion_unwrapped: Callable[
|
||||
..., Coroutine[Any, Any, ModelResponse]
|
||||
] = self._async_completion
|
||||
|
||||
@self.retry_decorator(
|
||||
num_retries=self.config.num_retries,
|
||||
@@ -46,7 +51,7 @@ class AsyncLLM(LLM):
|
||||
retry_max_wait=self.config.retry_max_wait,
|
||||
retry_multiplier=self.config.retry_multiplier,
|
||||
)
|
||||
async def async_completion_wrapper(*args, **kwargs):
|
||||
async def async_completion_wrapper(*args: Any, **kwargs: Any) -> dict[str, Any]:
|
||||
"""Wrapper for the litellm acompletion function that adds logging and cost tracking."""
|
||||
messages: list[dict[str, Any]] | dict[str, Any] = []
|
||||
|
||||
@@ -77,7 +82,7 @@ class AsyncLLM(LLM):
|
||||
|
||||
self.log_prompt(messages)
|
||||
|
||||
async def check_stopped():
|
||||
async def check_stopped() -> None:
|
||||
while should_continue():
|
||||
if (
|
||||
hasattr(self.config, 'on_cancel_requested_fn')
|
||||
@@ -97,10 +102,8 @@ class AsyncLLM(LLM):
|
||||
self.log_response(message_back)
|
||||
|
||||
# log costs and tokens used
|
||||
self._post_completion(resp)
|
||||
|
||||
# We do not support streaming in this method, thus return resp
|
||||
return resp
|
||||
return dict(resp)
|
||||
|
||||
except UserCancelledError:
|
||||
logger.debug('LLM request cancelled by user.')
|
||||
@@ -117,14 +120,15 @@ class AsyncLLM(LLM):
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._async_completion = async_completion_wrapper # type: ignore
|
||||
self._async_completion = async_completion_wrapper
|
||||
|
||||
async def _call_acompletion(self, *args, **kwargs):
|
||||
async def _call_acompletion(self, *args: Any, **kwargs: Any) -> ModelResponse:
|
||||
"""Wrapper for the litellm acompletion function."""
|
||||
# Used in testing?
|
||||
return await litellm_acompletion(*args, **kwargs)
|
||||
resp = await litellm_acompletion(*args, **kwargs)
|
||||
return ModelResponse(**resp)
|
||||
|
||||
@property
|
||||
def async_completion(self):
|
||||
def async_completion(self) -> Callable[..., Coroutine[Any, Any, ModelResponse]]:
|
||||
"""Decorator for the async litellm acompletion function."""
|
||||
return self._async_completion
|
||||
|
||||
@@ -28,5 +28,5 @@ def list_foundation_models(
|
||||
return []
|
||||
|
||||
|
||||
def remove_error_modelId(model_list):
|
||||
def remove_error_modelId(model_list: list[str]) -> list[str]:
|
||||
return list(filter(lambda m: not m.startswith('bedrock'), model_list))
|
||||
|
||||
@@ -7,7 +7,7 @@ MESSAGE_SEPARATOR = '\n\n----------\n\n'
|
||||
|
||||
|
||||
class DebugMixin:
|
||||
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]):
|
||||
def log_prompt(self, messages: list[dict[str, Any]] | dict[str, Any]) -> None:
|
||||
if not messages:
|
||||
logger.debug('No completion messages!')
|
||||
return
|
||||
@@ -24,11 +24,11 @@ class DebugMixin:
|
||||
else:
|
||||
logger.debug('No completion messages!')
|
||||
|
||||
def log_response(self, message_back: str):
|
||||
def log_response(self, message_back: str) -> None:
|
||||
if message_back:
|
||||
llm_response_logger.debug(message_back)
|
||||
|
||||
def _format_message_content(self, message: dict[str, Any]):
|
||||
def _format_message_content(self, message: dict[str, Any]) -> str:
|
||||
content = message['content']
|
||||
if isinstance(content, list):
|
||||
return '\n'.join(
|
||||
@@ -36,18 +36,18 @@ class DebugMixin:
|
||||
)
|
||||
return str(content)
|
||||
|
||||
def _format_content_element(self, element: dict[str, Any]):
|
||||
def _format_content_element(self, element: dict[str, Any]) -> str:
|
||||
if isinstance(element, dict):
|
||||
if 'text' in element:
|
||||
return element['text']
|
||||
return str(element['text'])
|
||||
if (
|
||||
self.vision_is_active()
|
||||
and 'image_url' in element
|
||||
and 'url' in element['image_url']
|
||||
):
|
||||
return element['image_url']['url']
|
||||
return str(element['image_url']['url'])
|
||||
return str(element)
|
||||
|
||||
# This method should be implemented in the class that uses DebugMixin
|
||||
def vision_is_active(self):
|
||||
def vision_is_active(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -9,7 +9,7 @@ We follow format from: https://docs.litellm.ai/docs/completion/function_call
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
from typing import Iterable
|
||||
from typing import Any, Iterable
|
||||
|
||||
from litellm import ChatCompletionToolParam
|
||||
|
||||
@@ -265,7 +265,7 @@ def convert_tool_call_to_string(tool_call: dict) -> str:
|
||||
return ret
|
||||
|
||||
|
||||
def convert_tools_to_description(tools: list[dict]) -> str:
|
||||
def convert_tools_to_description(tools: list[ChatCompletionToolParam]) -> str:
|
||||
ret = ''
|
||||
for i, tool in enumerate(tools):
|
||||
assert tool['type'] == 'function'
|
||||
@@ -474,8 +474,8 @@ def convert_fncall_messages_to_non_fncall_messages(
|
||||
|
||||
|
||||
def _extract_and_validate_params(
|
||||
matching_tool: dict, param_matches: Iterable[re.Match], fn_name: str
|
||||
) -> dict:
|
||||
matching_tool: dict[str, Any], param_matches: Iterable[re.Match], fn_name: str
|
||||
) -> dict[str, Any]:
|
||||
params = {}
|
||||
# Parse and validate parameters
|
||||
required_params = set()
|
||||
@@ -712,7 +712,7 @@ def convert_non_fncall_messages_to_fncall_messages(
|
||||
# Parse parameters
|
||||
param_matches = re.finditer(FN_PARAM_REGEX_PATTERN, fn_body, re.DOTALL)
|
||||
params = _extract_and_validate_params(
|
||||
matching_tool, param_matches, fn_name
|
||||
dict(matching_tool), param_matches, fn_name
|
||||
)
|
||||
|
||||
# Create tool call with unique ID
|
||||
|
||||
@@ -2,7 +2,6 @@ import copy
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Any, Callable
|
||||
|
||||
import requests
|
||||
@@ -13,14 +12,15 @@ with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore')
|
||||
import litellm
|
||||
|
||||
from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
|
||||
from litellm import Message as LiteLLMMessage
|
||||
from litellm import ChatCompletionMessageToolCall, PromptTokensDetails
|
||||
from litellm import completion as litellm_completion
|
||||
from litellm import completion_cost as litellm_completion_cost
|
||||
from litellm.exceptions import (
|
||||
RateLimitError,
|
||||
)
|
||||
from litellm.types.router import ModelInfo as RouterModelInfo
|
||||
from litellm.types.utils import CostPerToken, ModelResponse, Usage
|
||||
from litellm.types.utils import ModelInfo as UtilsModelInfo
|
||||
from litellm.utils import create_pretrained_tokenizer
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -87,6 +87,8 @@ class LLM(RetryMixin, DebugMixin):
|
||||
config: an LLMConfig object specifying the configuration of the LLM.
|
||||
"""
|
||||
|
||||
_completion: Callable[..., ModelResponse]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LLMConfig,
|
||||
@@ -108,7 +110,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
self.cost_metric_supported: bool = True
|
||||
self.config: LLMConfig = copy.deepcopy(config)
|
||||
|
||||
self.model_info: ModelInfo | None = None
|
||||
self.model_info: RouterModelInfo | UtilsModelInfo | None = None
|
||||
self.retry_listener = retry_listener
|
||||
if self.config.log_completions:
|
||||
if self.config.log_completions_folder is None:
|
||||
@@ -153,23 +155,26 @@ class LLM(RetryMixin, DebugMixin):
|
||||
kwargs['max_tokens'] = self.config.max_output_tokens
|
||||
kwargs.pop('max_completion_tokens')
|
||||
|
||||
self._completion = partial(
|
||||
litellm_completion,
|
||||
model=self.config.model,
|
||||
api_key=self.config.api_key.get_secret_value()
|
||||
if self.config.api_key
|
||||
else None,
|
||||
base_url=self.config.base_url,
|
||||
api_version=self.config.api_version,
|
||||
custom_llm_provider=self.config.custom_llm_provider,
|
||||
timeout=self.config.timeout,
|
||||
top_p=self.config.top_p,
|
||||
drop_params=self.config.drop_params,
|
||||
seed=self.config.seed,
|
||||
**kwargs,
|
||||
)
|
||||
# Create a wrapper function that captures the config values
|
||||
def completion_with_config(*args: Any, **user_kwargs: Any) -> ModelResponse:
|
||||
"""Wrapper for litellm_completion that includes the config values."""
|
||||
merged_kwargs = {
|
||||
'model': self.config.model,
|
||||
'api_key': self.config.api_key.get_secret_value()
|
||||
if self.config.api_key
|
||||
else None,
|
||||
'base_url': self.config.base_url,
|
||||
'api_version': self.config.api_version,
|
||||
'custom_llm_provider': self.config.custom_llm_provider,
|
||||
'timeout': self.config.timeout,
|
||||
'top_p': self.config.top_p,
|
||||
'drop_params': self.config.drop_params,
|
||||
**kwargs,
|
||||
**user_kwargs,
|
||||
}
|
||||
return litellm_completion(*args, **merged_kwargs)
|
||||
|
||||
self._completion_unwrapped = self._completion
|
||||
self._completion_unwrapped = completion_with_config
|
||||
|
||||
@self.retry_decorator(
|
||||
num_retries=self.config.num_retries,
|
||||
@@ -179,7 +184,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
retry_multiplier=self.config.retry_multiplier,
|
||||
retry_listener=self.retry_listener,
|
||||
)
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args: Any, **kwargs: Any) -> ModelResponse:
|
||||
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
|
||||
from openhands.io import json
|
||||
|
||||
@@ -257,20 +262,18 @@ class LLM(RetryMixin, DebugMixin):
|
||||
|
||||
# if we mocked function calling, and we have tools, convert the response back to function calling format
|
||||
if mock_function_calling and mock_fncall_tools is not None:
|
||||
logger.debug(f'Response choices: {len(resp.choices)}')
|
||||
assert len(resp.choices) >= 1
|
||||
non_fncall_response_message = resp.choices[0].message
|
||||
fn_call_messages_with_response = (
|
||||
convert_non_fncall_messages_to_fncall_messages(
|
||||
messages + [non_fncall_response_message], mock_fncall_tools
|
||||
assert len(resp.choices) == 1
|
||||
if isinstance(resp.choices[0], dict) and 'message' in resp.choices[0]:
|
||||
non_fncall_response_message = resp.choices[0]['message']
|
||||
fn_call_messages_with_response = (
|
||||
convert_non_fncall_messages_to_fncall_messages(
|
||||
messages + [dict(non_fncall_response_message)],
|
||||
mock_fncall_tools,
|
||||
)
|
||||
)
|
||||
)
|
||||
fn_call_response_message = fn_call_messages_with_response[-1]
|
||||
if not isinstance(fn_call_response_message, LiteLLMMessage):
|
||||
fn_call_response_message = LiteLLMMessage(
|
||||
**fn_call_response_message
|
||||
)
|
||||
resp.choices[0].message = fn_call_response_message
|
||||
fn_call_response_message = fn_call_messages_with_response[-1]
|
||||
fn_call_response_message = dict(fn_call_response_message)
|
||||
resp.choices[0]['message'] = fn_call_response_message
|
||||
|
||||
message_back: str = resp['choices'][0]['message']['content'] or ''
|
||||
tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
|
||||
@@ -327,14 +330,14 @@ class LLM(RetryMixin, DebugMixin):
|
||||
self._completion = wrapper
|
||||
|
||||
@property
|
||||
def completion(self):
|
||||
def completion(self) -> Callable[..., ModelResponse]:
|
||||
"""Decorator for the litellm completion function.
|
||||
|
||||
Check the complete documentation at https://litellm.vercel.app/docs/completion
|
||||
"""
|
||||
return self._completion
|
||||
|
||||
def init_model_info(self):
|
||||
def init_model_info(self) -> None:
|
||||
if self._tried_model_info:
|
||||
return
|
||||
self._tried_model_info = True
|
||||
@@ -464,11 +467,11 @@ class LLM(RetryMixin, DebugMixin):
|
||||
# remove when litellm is updated to fix https://github.com/BerriAI/litellm/issues/5608
|
||||
# Check both the full model name and the name after proxy prefix for vision support
|
||||
return (
|
||||
litellm.supports_vision(self.config.model)
|
||||
or litellm.supports_vision(self.config.model.split('/')[-1])
|
||||
bool(litellm.supports_vision(self.config.model))
|
||||
or bool(litellm.supports_vision(self.config.model.split('/')[-1]))
|
||||
or (
|
||||
self.model_info is not None
|
||||
and self.model_info.get('supports_vision', False)
|
||||
and bool(self.model_info.get('supports_vision', False))
|
||||
)
|
||||
)
|
||||
|
||||
@@ -624,7 +627,7 @@ class LLM(RetryMixin, DebugMixin):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _completion_cost(self, response) -> float:
|
||||
def _completion_cost(self, response: ModelResponse) -> float:
|
||||
"""Calculate completion cost and update metrics with running total.
|
||||
|
||||
Calculate the cost of a completion response based on the model. Local models are treated as free.
|
||||
@@ -663,35 +666,45 @@ class LLM(RetryMixin, DebugMixin):
|
||||
try:
|
||||
if cost is None:
|
||||
try:
|
||||
cost = litellm_completion_cost(
|
||||
completion_response=response, **extra_kwargs
|
||||
cost = float(
|
||||
litellm_completion_cost(
|
||||
completion_response=response,
|
||||
custom_cost_per_token=extra_kwargs.get(
|
||||
'custom_cost_per_token'
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Error getting cost from litellm: {e}')
|
||||
|
||||
if cost is None:
|
||||
_model_name = '/'.join(self.config.model.split('/')[1:])
|
||||
cost = litellm_completion_cost(
|
||||
completion_response=response, model=_model_name, **extra_kwargs
|
||||
cost = float(
|
||||
litellm_completion_cost(
|
||||
completion_response=response,
|
||||
model=_model_name,
|
||||
custom_cost_per_token=extra_kwargs.get('custom_cost_per_token'),
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Using fallback model name {_model_name} to get cost: {cost}'
|
||||
)
|
||||
self.metrics.add_cost(cost)
|
||||
return cost
|
||||
cost_float = float(cost)
|
||||
self.metrics.add_cost(cost_float)
|
||||
return cost_float
|
||||
except Exception:
|
||||
self.cost_metric_supported = False
|
||||
logger.debug('Cost calculation not supported for this model.')
|
||||
return 0.0
|
||||
|
||||
def __str__(self):
|
||||
def __str__(self) -> str:
|
||||
if self.config.api_version:
|
||||
return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
|
||||
elif self.config.base_url:
|
||||
return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
|
||||
return f'LLM(model={self.config.model})'
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return str(self)
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -129,13 +129,13 @@ class Metrics:
|
||||
'token_usages': [usage.model_dump() for usage in self._token_usages],
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self._accumulated_cost = 0.0
|
||||
self._costs = []
|
||||
self._response_latencies = []
|
||||
self._token_usages = []
|
||||
|
||||
def log(self):
|
||||
def log(self) -> str:
|
||||
"""Log the metrics."""
|
||||
metrics = self.get()
|
||||
logs = ''
|
||||
@@ -143,5 +143,5 @@ class Metrics:
|
||||
logs += f'{key}: {value}\n'
|
||||
return logs
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'Metrics({self.get()}'
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
@@ -12,25 +14,35 @@ from openhands.utils.tenacity_stop import stop_if_should_exit
|
||||
class RetryMixin:
|
||||
"""Mixin class for retry logic."""
|
||||
|
||||
def retry_decorator(self, **kwargs):
|
||||
def retry_decorator(
|
||||
self,
|
||||
*,
|
||||
num_retries: int | None = None,
|
||||
retry_exceptions: tuple = (),
|
||||
retry_min_wait: int | None = None,
|
||||
retry_max_wait: int | None = None,
|
||||
retry_multiplier: float | None = None,
|
||||
retry_listener: Any | None = None,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""
|
||||
Create a LLM retry decorator with customizable parameters. This is used for 429 errors, and a few other exceptions in LLM classes.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to override default retry behavior.
|
||||
Keys: num_retries, retry_exceptions, retry_min_wait, retry_max_wait, retry_multiplier
|
||||
num_retries: Number of retries before giving up
|
||||
retry_exceptions: Tuple of exception types to retry on
|
||||
retry_min_wait: Minimum wait time between retries in seconds
|
||||
retry_max_wait: Maximum wait time between retries in seconds
|
||||
retry_multiplier: Multiplier for exponential backoff
|
||||
retry_listener: Optional callback for retry events
|
||||
|
||||
Returns:
|
||||
A retry decorator with the parameters customizable in configuration.
|
||||
"""
|
||||
num_retries = kwargs.get('num_retries')
|
||||
retry_exceptions: tuple = kwargs.get('retry_exceptions', ())
|
||||
retry_min_wait = kwargs.get('retry_min_wait')
|
||||
retry_max_wait = kwargs.get('retry_max_wait')
|
||||
retry_multiplier = kwargs.get('retry_multiplier')
|
||||
retry_listener = kwargs.get('retry_listener')
|
||||
# Use the values from config if not provided
|
||||
# Note: These values are already set in LLMConfig with appropriate defaults
|
||||
# See openhands/core/config/llm_config.py for the actual default values
|
||||
|
||||
def before_sleep(retry_state):
|
||||
def before_sleep(retry_state: Any) -> None:
|
||||
self.log_retry_attempt(retry_state)
|
||||
if retry_listener:
|
||||
retry_listener(retry_state.attempt_number, num_retries)
|
||||
@@ -49,7 +61,7 @@ class RetryMixin:
|
||||
),
|
||||
)
|
||||
|
||||
def log_retry_attempt(self, retry_state):
|
||||
def log_retry_attempt(self, retry_state: Any) -> None:
|
||||
"""Log retry attempts."""
|
||||
exception = retry_state.outcome.exception()
|
||||
logger.error(
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import Any, AsyncIterator, Callable, Coroutine
|
||||
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from openhands.core.exceptions import UserCancelledError
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -11,7 +13,11 @@ from openhands.llm.llm import REASONING_EFFORT_SUPPORTED_MODELS
|
||||
class StreamingLLM(AsyncLLM):
|
||||
"""Streaming LLM class."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
_async_streaming_completion: Callable[
|
||||
..., Coroutine[Any, Any, AsyncIterator[ModelResponse]]
|
||||
]
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._async_streaming_completion = partial(
|
||||
@@ -31,7 +37,9 @@ class StreamingLLM(AsyncLLM):
|
||||
stream=True, # Ensure streaming is enabled
|
||||
)
|
||||
|
||||
async_streaming_completion_unwrapped = self._async_streaming_completion
|
||||
async_streaming_completion_unwrapped: Callable[
|
||||
..., Coroutine[Any, Any, AsyncIterator[ModelResponse]]
|
||||
] = self._async_streaming_completion
|
||||
|
||||
@self.retry_decorator(
|
||||
num_retries=self.config.num_retries,
|
||||
@@ -40,7 +48,7 @@ class StreamingLLM(AsyncLLM):
|
||||
retry_max_wait=self.config.retry_max_wait,
|
||||
retry_multiplier=self.config.retry_multiplier,
|
||||
)
|
||||
async def async_streaming_completion_wrapper(*args, **kwargs):
|
||||
async def async_streaming_completion_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
messages: list[dict[str, Any]] | dict[str, Any] = []
|
||||
|
||||
# some callers might send the model and messages directly
|
||||
@@ -89,9 +97,10 @@ class StreamingLLM(AsyncLLM):
|
||||
message_back = chunk['choices'][0]['delta'].get('content', '')
|
||||
if message_back:
|
||||
self.log_response(message_back)
|
||||
self._post_completion(chunk)
|
||||
chunk_dict = dict(chunk)
|
||||
self._post_completion(ModelResponse(**chunk_dict))
|
||||
|
||||
yield chunk
|
||||
yield chunk_dict
|
||||
|
||||
except UserCancelledError:
|
||||
logger.debug('LLM request cancelled by user.')
|
||||
@@ -108,6 +117,8 @@ class StreamingLLM(AsyncLLM):
|
||||
self._async_streaming_completion = async_streaming_completion_wrapper
|
||||
|
||||
@property
|
||||
def async_streaming_completion(self):
|
||||
def async_streaming_completion(
|
||||
self,
|
||||
) -> Callable[..., Coroutine[Any, Any, AsyncIterator[ModelResponse]]]:
|
||||
"""Decorator for the async litellm acompletion function with streaming."""
|
||||
return self._async_streaming_completion
|
||||
|
||||
Reference in New Issue
Block a user