Mirror of https://github.com/All-Hands-AI/OpenHands.git, synced 2026-01-10 07:18:10 -05:00
(feat) LLM class: added acompletion and streaming + unit test (#3202)
* LLM class: added acompletion and streaming, unit test test_acompletion.py
* LLM: cleanup of self.config defaults and their use
* added set_missing_attributes to LLMConfig
* move default checker up
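The new entry points are exposed as LLM.async_completion and LLM.async_streaming_completion (see the llm.py hunks below). A minimal usage sketch, assembled from the calls the unit tests make; the surrounding script (demo(), asyncio.run) is illustrative and not part of the diff, and it assumes a working provider configuration:

    import asyncio

    from opendevin.core.config import load_app_config
    from opendevin.llm.llm import LLM


    async def demo():
        llm = LLM(config=load_app_config().get_llm_config())

        # non-streaming async call
        resp = await llm.async_completion(
            messages=[{'role': 'user', 'content': 'Hello!'}], stream=False
        )
        print(resp['choices'][0]['message']['content'])

        # streaming async call: chunks carry 'delta', not 'message'
        async for chunk in llm.async_streaming_completion(
            messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
        ):
            print(chunk['choices'][0]['delta']['content'], end='')


    asyncio.run(demo())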
@@ -111,6 +111,12 @@ class LLMConfig:
             ret[k] = '******' if v else None
         return ret

+    def set_missing_attributes(self):
+        """Set any missing attributes to their default values."""
+        for field_name, field_obj in self.__dataclass_fields__.items():
+            if not hasattr(self, field_name):
+                setattr(self, field_name, field_obj.default)
+

 @dataclass
 class AgentConfig:
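set_missing_attributes walks the dataclass field registry (__dataclass_fields__) and backfills any attribute that is absent from the instance, which is what the LLM constructor below relies on to avoid AttributeError. A standalone sketch of the same pattern, using a hypothetical ExampleConfig rather than the real LLMConfig:

    from dataclasses import dataclass, field


    @dataclass
    class ExampleConfig:
        model: str = 'some-model'
        # factory-defaulted fields keep no class-level fallback value
        stop_words: list = field(default_factory=list)

        def set_missing_attributes(self):
            """Set any missing attributes to their default values."""
            for field_name, field_obj in self.__dataclass_fields__.items():
                if not hasattr(self, field_name):
                    setattr(self, field_name, field_obj.default)


    cfg = object.__new__(ExampleConfig)    # bypass __init__ to simulate a partially built config
    assert not hasattr(cfg, 'stop_words')  # genuinely missing on the instance
    cfg.set_missing_attributes()
    assert hasattr(cfg, 'stop_words')      # backfilled with the field's stored default
    assert cfg.model == 'some-model'       # simple defaults already resolve via the class attribute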
@@ -67,3 +67,8 @@ class LLMNoActionError(Exception):
 class LLMResponseError(Exception):
     def __init__(self, message='Failed to retrieve action from LLM response'):
         super().__init__(message)
+
+
+class UserCancelledError(Exception):
+    def __init__(self, message='User cancelled the request'):
+        super().__init__(message)
@@ -1,3 +1,4 @@
+import asyncio
 import copy
 import warnings
 from functools import partial
@@ -13,6 +14,7 @@ from litellm.exceptions import (
     APIConnectionError,
     ContentPolicyViolationError,
     InternalServerError,
+    OpenAIError,
     RateLimitError,
     ServiceUnavailableError,
 )
@@ -24,6 +26,7 @@ from tenacity import (
     wait_random_exponential,
 )

+from opendevin.core.exceptions import UserCancelledError
 from opendevin.core.logger import llm_prompt_logger, llm_response_logger
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.metrics import Metrics
@@ -56,6 +59,9 @@ class LLM:
         self.metrics = metrics if metrics is not None else Metrics()
         self.cost_metric_supported = True

+        # Set up config attributes with default values to prevent AttributeError
+        LLMConfig.set_missing_attributes(self.config)
+
         # litellm actually uses base Exception here for unknown model
         self.model_info = None
         try:
@@ -66,11 +72,11 @@ class LLM:
                     self.config.model.split(':')[0]
                 )
         # noinspection PyBroadException
-        except Exception:
-            logger.warning(f'Could not get model info for {config.model}')
+        except Exception as e:
+            logger.warning(f'Could not get model info for {config.model}:\n{e}')

         # Set the max tokens in an LM-specific way if not set
-        if config.max_input_tokens is None:
+        if self.config.max_input_tokens is None:
             if (
                 self.model_info is not None
                 and 'max_input_tokens' in self.model_info
@@ -81,7 +87,7 @@ class LLM:
                 # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
                 self.config.max_input_tokens = 4096

-        if config.max_output_tokens is None:
+        if self.config.max_output_tokens is None:
             if (
                 self.model_info is not None
                 and 'max_output_tokens' in self.model_info
@@ -119,11 +125,11 @@ class LLM:

         @retry(
             reraise=True,
-            stop=stop_after_attempt(config.num_retries),
+            stop=stop_after_attempt(self.config.num_retries),
             wait=wait_random_exponential(
-                multiplier=config.retry_multiplier,
-                min=config.retry_min_wait,
-                max=config.retry_max_wait,
+                multiplier=self.config.retry_multiplier,
+                min=self.config.retry_min_wait,
+                max=self.config.retry_max_wait,
             ),
             retry=retry_if_exception_type(
                 (
@@ -147,11 +153,15 @@ class LLM:
             # log the prompt
             debug_message = ''
             for message in messages:
-                debug_message += message_separator + message['content']
+                if message['content'].strip():
+                    debug_message += message_separator + message['content']
             llm_prompt_logger.debug(debug_message)

-            # call the completion function
-            resp = completion_unwrapped(*args, **kwargs)
+            # skip if messages is empty (thus debug_message is empty)
+            if debug_message:
+                resp = completion_unwrapped(*args, **kwargs)
+            else:
+                resp = {'choices': [{'message': {'content': ''}}]}

             # log the response
             message_back = resp['choices'][0]['message']['content']
@@ -159,10 +169,191 @@ class LLM:

             # post-process to log costs
             self._post_completion(resp)

             return resp

         self._completion = wrapper  # type: ignore
+
+        # Async version
+        self._async_completion = partial(
+            self._call_acompletion,
+            model=self.config.model,
+            api_key=self.config.api_key,
+            base_url=self.config.base_url,
+            api_version=self.config.api_version,
+            custom_llm_provider=self.config.custom_llm_provider,
+            max_tokens=self.config.max_output_tokens,
+            timeout=self.config.timeout,
+            temperature=self.config.temperature,
+            top_p=self.config.top_p,
+            drop_params=True,
+        )
+
+        async_completion_unwrapped = self._async_completion
+
+        @retry(
+            reraise=True,
+            stop=stop_after_attempt(self.config.num_retries),
+            wait=wait_random_exponential(
+                multiplier=self.config.retry_multiplier,
+                min=self.config.retry_min_wait,
+                max=self.config.retry_max_wait,
+            ),
+            retry=retry_if_exception_type(
+                (
+                    RateLimitError,
+                    APIConnectionError,
+                    ServiceUnavailableError,
+                    InternalServerError,
+                    ContentPolicyViolationError,
+                )
+            ),
+            after=attempt_on_error,
+        )
+        async def async_completion_wrapper(*args, **kwargs):
+            """Async wrapper for the litellm acompletion function."""
+            # some callers might just send the messages directly
+            if 'messages' in kwargs:
+                messages = kwargs['messages']
+            else:
+                messages = args[1]
+
+            # log the prompt
+            debug_message = ''
+            for message in messages:
+                debug_message += message_separator + message['content']
+            llm_prompt_logger.debug(debug_message)
+
+            async def check_stopped():
+                while True:
+                    if (
+                        hasattr(self.config, 'on_cancel_requested_fn')
+                        and self.config.on_cancel_requested_fn is not None
+                        and await self.config.on_cancel_requested_fn()
+                    ):
+                        raise UserCancelledError('LLM request cancelled by user')
+                    await asyncio.sleep(0.1)
+
+            stop_check_task = asyncio.create_task(check_stopped())
+
+            try:
+                # Directly call and await litellm_acompletion
+                resp = await async_completion_unwrapped(*args, **kwargs)
+
+                # skip if messages is empty (thus debug_message is empty)
+                if debug_message:
+                    message_back = resp['choices'][0]['message']['content']
+                    llm_response_logger.debug(message_back)
+                else:
+                    resp = {'choices': [{'message': {'content': ''}}]}
+                self._post_completion(resp)
+
+                # We do not support streaming in this method, thus return resp
+                return resp
+
+            except UserCancelledError:
+                logger.info('LLM request cancelled by user.')
+                raise
+            except OpenAIError as e:
+                logger.error(f'OpenAIError occurred:\n{e}')
+                raise
+            except (
+                RateLimitError,
+                APIConnectionError,
+                ServiceUnavailableError,
+                InternalServerError,
+            ) as e:
+                logger.error(f'Completion Error occurred:\n{e}')
+                raise
+
+            finally:
+                await asyncio.sleep(0.1)
+                stop_check_task.cancel()
+                try:
+                    await stop_check_task
+                except asyncio.CancelledError:
+                    pass
+
+        @retry(
+            reraise=True,
+            stop=stop_after_attempt(self.config.num_retries),
+            wait=wait_random_exponential(
+                multiplier=self.config.retry_multiplier,
+                min=self.config.retry_min_wait,
+                max=self.config.retry_max_wait,
+            ),
+            retry=retry_if_exception_type(
+                (
+                    RateLimitError,
+                    APIConnectionError,
+                    ServiceUnavailableError,
+                    InternalServerError,
+                    ContentPolicyViolationError,
+                )
+            ),
+            after=attempt_on_error,
+        )
+        async def async_acompletion_stream_wrapper(*args, **kwargs):
+            """Async wrapper for the litellm acompletion with streaming function."""
+            # some callers might just send the messages directly
+            if 'messages' in kwargs:
+                messages = kwargs['messages']
+            else:
+                messages = args[1]
+
+            # log the prompt
+            debug_message = ''
+            for message in messages:
+                debug_message += message_separator + message['content']
+            llm_prompt_logger.debug(debug_message)
+
+            try:
+                # Directly call and await litellm_acompletion
+                resp = await async_completion_unwrapped(*args, **kwargs)
+
+                # For streaming we iterate over the chunks
+                async for chunk in resp:
+                    # Check for cancellation before yielding the chunk
+                    if (
+                        hasattr(self.config, 'on_cancel_requested_fn')
+                        and self.config.on_cancel_requested_fn is not None
+                        and await self.config.on_cancel_requested_fn()
+                    ):
+                        raise UserCancelledError(
+                            'LLM request cancelled due to CANCELLED state'
+                        )
+                    # with streaming, it is "delta", not "message"!
+                    message_back = chunk['choices'][0]['delta']['content']
+                    llm_response_logger.debug(message_back)
+                    self._post_completion(chunk)
+
+                    yield chunk
+
+            except UserCancelledError:
+                logger.info('LLM request cancelled by user.')
+                raise
+            except OpenAIError as e:
+                logger.error(f'OpenAIError occurred:\n{e}')
+                raise
+            except (
+                RateLimitError,
+                APIConnectionError,
+                ServiceUnavailableError,
+                InternalServerError,
+            ) as e:
+                logger.error(f'Completion Error occurred:\n{e}')
+                raise
+
+            finally:
+                if kwargs.get('stream', False):
+                    await asyncio.sleep(0.1)
+
+        self._async_completion = async_completion_wrapper  # type: ignore
+        self._async_streaming_completion = async_acompletion_stream_wrapper  # type: ignore
+
+    async def _call_acompletion(self, *args, **kwargs):
+        return await litellm.acompletion(*args, **kwargs)
+
     @property
     def completion(self):
         """Decorator for the litellm completion function.
@@ -171,6 +362,22 @@ class LLM:
         """
         return self._completion

+    @property
+    def async_completion(self):
+        """Decorator for the async litellm acompletion function.
+
+        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
+        """
+        return self._async_completion
+
+    @property
+    def async_streaming_completion(self):
+        """Decorator for the async litellm acompletion function with streaming.
+
+        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
+        """
+        return self._async_streaming_completion
+
     def _post_completion(self, response: str) -> None:
         """Post-process the completion response."""
         try:
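The async wrappers above poll an optional on_cancel_requested_fn on the config (a check_stopped task for the non-streaming path, a per-chunk check for streaming) and surface UserCancelledError. A standalone sketch of the underlying technique, polling-based cancellation that races a watcher task against the awaited work; this illustrates the general pattern rather than the wrapper's exact control flow, and run_with_cancellation and its parameters are hypothetical:

    import asyncio


    class UserCancelledError(Exception):
        def __init__(self, message='User cancelled the request'):
            super().__init__(message)


    async def run_with_cancellation(work_coro, on_cancel_requested_fn, poll_interval=0.1):
        """Await work_coro, but raise UserCancelledError once the poll callback returns True."""

        async def watch():
            while not await on_cancel_requested_fn():
                await asyncio.sleep(poll_interval)
            raise UserCancelledError('LLM request cancelled by user')

        work = asyncio.create_task(work_coro)
        watcher = asyncio.create_task(watch())
        try:
            done, _ = await asyncio.wait({work, watcher}, return_when=asyncio.FIRST_COMPLETED)
            if work in done:
                return work.result()
            watcher.result()  # watcher finished first: re-raises UserCancelledError
        finally:
            for task in (work, watcher):
                if not task.done():
                    task.cancel()

    # Usage (illustrative):
    #     resp = await run_with_cancellation(
    #         llm.async_completion(messages=[{'role': 'user', 'content': 'Hello!'}]),
    #         on_cancel_requested_fn=my_cancel_flag_check,
    #     )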
tests/unit/test_acompletion.py (new file, 187 lines)
@@ -0,0 +1,187 @@
import asyncio
from unittest.mock import AsyncMock, patch

import pytest

from opendevin.core.config import load_app_config
from opendevin.core.exceptions import UserCancelledError
from opendevin.llm.llm import LLM

config = load_app_config()


@pytest.fixture
def test_llm():
    # Create a mock config for testing
    return LLM(config=config.get_llm_config())


@pytest.fixture
def mock_response():
    return [
        {'choices': [{'delta': {'content': 'This is a'}}]},
        {'choices': [{'delta': {'content': ' test'}}]},
        {'choices': [{'delta': {'content': ' message.'}}]},
        {'choices': [{'delta': {'content': ' It is'}}]},
        {'choices': [{'delta': {'content': ' a bit'}}]},
        {'choices': [{'delta': {'content': ' longer'}}]},
        {'choices': [{'delta': {'content': ' than'}}]},
        {'choices': [{'delta': {'content': ' the'}}]},
        {'choices': [{'delta': {'content': ' previous'}}]},
        {'choices': [{'delta': {'content': ' one,'}}]},
        {'choices': [{'delta': {'content': ' but'}}]},
        {'choices': [{'delta': {'content': ' hopefully'}}]},
        {'choices': [{'delta': {'content': ' still'}}]},
        {'choices': [{'delta': {'content': ' short'}}]},
        {'choices': [{'delta': {'content': ' enough.'}}]},
    ]


@pytest.mark.asyncio
async def test_acompletion_non_streaming():
    with patch.object(LLM, '_call_acompletion') as mock_call_acompletion:
        mock_response = {
            'choices': [{'message': {'content': 'This is a test message.'}}]
        }
        mock_call_acompletion.return_value = mock_response
        test_llm = LLM(config=config.get_llm_config())
        response = await test_llm.async_completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
            drop_params=True,
        )
        # Assertions for non-streaming completion
        assert response['choices'][0]['message']['content'] != ''


@pytest.mark.asyncio
async def test_acompletion_streaming(mock_response):
    with patch.object(LLM, '_call_acompletion') as mock_call_acompletion:
        mock_call_acompletion.return_value.__aiter__.return_value = iter(mock_response)
        test_llm = LLM(config=config.get_llm_config())
        async for chunk in test_llm.async_streaming_completion(
            messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
        ):
            print(f"Chunk: {chunk['choices'][0]['delta']['content']}")
            # Assertions for streaming completion
            assert chunk['choices'][0]['delta']['content'] in [
                r['choices'][0]['delta']['content'] for r in mock_response
            ]


@pytest.mark.asyncio
async def test_completion(test_llm):
    with patch.object(LLM, 'completion') as mock_completion:
        mock_completion.return_value = {
            'choices': [{'message': {'content': 'This is a test message.'}}]
        }
        response = test_llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
        assert response['choices'][0]['message']['content'] == 'This is a test message.'


@pytest.mark.asyncio
@pytest.mark.parametrize('cancel_delay', [0.1, 0.3, 0.5, 0.7, 0.9])
async def test_async_completion_with_user_cancellation(cancel_delay):
    cancel_event = asyncio.Event()

    async def mock_on_cancel_requested():
        is_set = cancel_event.is_set()
        print(f'Cancel requested: {is_set}')
        return is_set

    config = load_app_config()
    config.on_cancel_requested_fn = mock_on_cancel_requested

    async def mock_acompletion(*args, **kwargs):
        print('Starting mock_acompletion')
        for i in range(20):  # Increased iterations for longer running task
            print(f'mock_acompletion iteration {i}')
            await asyncio.sleep(0.1)
            if await mock_on_cancel_requested():
                print('Cancellation detected in mock_acompletion')
                raise UserCancelledError('LLM request cancelled by user')
        print('Completing mock_acompletion without cancellation')
        return {'choices': [{'message': {'content': 'This is a test message.'}}]}

    with patch.object(
        LLM, '_call_acompletion', new_callable=AsyncMock
    ) as mock_call_acompletion:
        mock_call_acompletion.side_effect = mock_acompletion
        test_llm = LLM(config=config.get_llm_config())

        async def cancel_after_delay():
            print(f'Starting cancel_after_delay with delay {cancel_delay}')
            await asyncio.sleep(cancel_delay)
            print('Setting cancel event')
            cancel_event.set()

        with pytest.raises(UserCancelledError):
            await asyncio.gather(
                test_llm.async_completion(
                    messages=[{'role': 'user', 'content': 'Hello!'}],
                    stream=False,
                ),
                cancel_after_delay(),
            )

        # Ensure the mock was called
        mock_call_acompletion.assert_called_once()


@pytest.mark.asyncio
@pytest.mark.parametrize('cancel_after_chunks', [1, 3, 5, 7, 9])
async def test_async_streaming_completion_with_user_cancellation(cancel_after_chunks):
    cancel_requested = False

    async def mock_on_cancel_requested():
        nonlocal cancel_requested
        return cancel_requested

    config = load_app_config()
    config.on_cancel_requested_fn = mock_on_cancel_requested

    test_messages = [
        'This is ',
        'a test ',
        'message ',
        'with ',
        'multiple ',
        'chunks ',
        'to ',
        'simulate ',
        'a ',
        'longer ',
        'streaming ',
        'response.',
    ]

    async def mock_acompletion(*args, **kwargs):
        for i, content in enumerate(test_messages):
            yield {'choices': [{'delta': {'content': content}}]}
            if i + 1 == cancel_after_chunks:
                nonlocal cancel_requested
                cancel_requested = True
            if cancel_requested:
                raise UserCancelledError('LLM request cancelled by user')
            await asyncio.sleep(0.05)  # Simulate some delay between chunks

    with patch.object(
        LLM, '_call_acompletion', new_callable=AsyncMock
    ) as mock_call_acompletion:
        mock_call_acompletion.return_value = mock_acompletion()
        test_llm = LLM(config=config.get_llm_config())

        received_chunks = []
        with pytest.raises(UserCancelledError):
            async for chunk in test_llm.async_streaming_completion(
                messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
            ):
                received_chunks.append(chunk['choices'][0]['delta']['content'])
                print(f"Chunk: {chunk['choices'][0]['delta']['content']}")

        # Assert that we received the expected number of chunks before cancellation
        assert len(received_chunks) == cancel_after_chunks
        assert received_chunks == test_messages[:cancel_after_chunks]

        # Ensure the mock was called
        mock_call_acompletion.assert_called_once()
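These tests exercise the async paths entirely against mocks of _call_acompletion, so no provider credentials are needed; the @pytest.mark.asyncio markers assume pytest-asyncio (or an equivalent async plugin) is available, in which case something like pytest tests/unit/test_acompletion.py -v runs them. The streaming test leans on MagicMock's built-in __aiter__ support to make a plain list consumable by async for; a minimal standalone sketch of that trick (consume and fake_stream are illustrative names, not from the test file):

    import asyncio
    from unittest.mock import MagicMock


    async def consume(stream):
        # works because MagicMock pre-configures the __aiter__/__anext__ protocol
        return [chunk async for chunk in stream]


    fake_stream = MagicMock()
    fake_stream.__aiter__.return_value = [
        {'choices': [{'delta': {'content': 'Hello'}}]},
        {'choices': [{'delta': {'content': ' world'}}]},
    ]

    chunks = asyncio.run(consume(fake_stream))
    assert [c['choices'][0]['delta']['content'] for c in chunks] == ['Hello', ' world']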