(feat) LLM class: added acompletion and streaming + unit test (#3202)

* LLM class: added acompletion and streaming, unit test test_acompletion.py

* LLM: cleanup of self.config defaults and their use

* added set_missing_attributes to LLMConfig

* move default checker up
Author: tobitege
Date: 2024-08-01 22:41:40 +02:00
Committed by: GitHub
Parent: 8d11e0eac9
Commit: a4cb880699
4 changed files with 416 additions and 11 deletions
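
For orientation (this sketch is not part of the diff), the two entry points added below, async_completion and async_streaming_completion, could be exercised roughly as follows; it assumes a working app config is available via load_app_config() and that provider credentials are already configured:

import asyncio

from opendevin.core.config import load_app_config
from opendevin.llm.llm import LLM

async def demo():
    llm = LLM(config=load_app_config().get_llm_config())

    # non-streaming: await the full response object
    resp = await llm.async_completion(
        messages=[{'role': 'user', 'content': 'Hello!'}], stream=False
    )
    print(resp['choices'][0]['message']['content'])

    # streaming: iterate chunks; each chunk carries a 'delta', not a 'message'
    async for chunk in llm.async_streaming_completion(
        messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
    ):
        print(chunk['choices'][0]['delta']['content'], end='')

asyncio.run(demo())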

View File

@@ -111,6 +111,12 @@ class LLMConfig:
ret[k] = '******' if v else None
return ret
def set_missing_attributes(self):
"""Set any missing attributes to their default values."""
for field_name, field_obj in self.__dataclass_fields__.items():
if not hasattr(self, field_name):
setattr(self, field_name, field_obj.default)
@dataclass
class AgentConfig:
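
As a side note on the helper above: it walks the class's __dataclass_fields__ mapping and fills in any attribute that hasattr reports as missing. A minimal, self-contained sketch of that mapping (DemoConfig is a toy class, not from the commit):

from dataclasses import MISSING, dataclass, field

@dataclass
class DemoConfig:
    num_retries: int = 5
    tags: list = field(default_factory=list)

for name, f in DemoConfig.__dataclass_fields__.items():
    # plain defaults are exposed on f.default; default_factory fields report MISSING there
    print(name, f.default if f.default is not MISSING else f.default_factory)

Note that fields declared with default_factory expose MISSING as their default, which is worth keeping in mind when relying on field_obj.default.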

View File

@@ -67,3 +67,8 @@ class LLMNoActionError(Exception):
class LLMResponseError(Exception):
def __init__(self, message='Failed to retrieve action from LLM response'):
super().__init__(message)
class UserCancelledError(Exception):
def __init__(self, message='User cancelled the request'):
super().__init__(message)
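
A hedged sketch (not part of the commit) of how a caller might surface the new exception; run_request is a hypothetical helper:

from opendevin.core.exceptions import UserCancelledError

async def run_request(llm, messages):
    # the async wrapper in the LLM class logs and re-raises UserCancelledError
    # when a cancel is detected, so callers can catch it here
    try:
        return await llm.async_completion(messages=messages, stream=False)
    except UserCancelledError:
        return None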

View File

@@ -1,3 +1,4 @@
import asyncio
import copy
import warnings
from functools import partial
@@ -13,6 +14,7 @@ from litellm.exceptions import (
APIConnectionError,
ContentPolicyViolationError,
InternalServerError,
OpenAIError,
RateLimitError,
ServiceUnavailableError,
)
@@ -24,6 +26,7 @@ from tenacity import (
wait_random_exponential,
)
from opendevin.core.exceptions import UserCancelledError
from opendevin.core.logger import llm_prompt_logger, llm_response_logger
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.metrics import Metrics
@@ -56,6 +59,9 @@ class LLM:
self.metrics = metrics if metrics is not None else Metrics()
self.cost_metric_supported = True
# Set up config attributes with default values to prevent AttributeError
LLMConfig.set_missing_attributes(self.config)
# litellm actually uses base Exception here for unknown model
self.model_info = None
try:
@@ -66,11 +72,11 @@ class LLM:
self.config.model.split(':')[0]
)
# noinspection PyBroadException
- except Exception:
- logger.warning(f'Could not get model info for {config.model}')
+ except Exception as e:
+ logger.warning(f'Could not get model info for {config.model}:\n{e}')
# Set the max tokens in an LM-specific way if not set
- if config.max_input_tokens is None:
+ if self.config.max_input_tokens is None:
if (
self.model_info is not None
and 'max_input_tokens' in self.model_info
@@ -81,7 +87,7 @@ class LLM:
# Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
self.config.max_input_tokens = 4096
- if config.max_output_tokens is None:
+ if self.config.max_output_tokens is None:
if (
self.model_info is not None
and 'max_output_tokens' in self.model_info
@@ -119,11 +125,11 @@ class LLM:
@retry(
reraise=True,
- stop=stop_after_attempt(config.num_retries),
+ stop=stop_after_attempt(self.config.num_retries),
wait=wait_random_exponential(
- multiplier=config.retry_multiplier,
- min=config.retry_min_wait,
- max=config.retry_max_wait,
+ multiplier=self.config.retry_multiplier,
+ min=self.config.retry_min_wait,
+ max=self.config.retry_max_wait,
),
retry=retry_if_exception_type(
(
@@ -147,11 +153,15 @@ class LLM:
# log the prompt
debug_message = ''
for message in messages:
- debug_message += message_separator + message['content']
+ if message['content'].strip():
+ debug_message += message_separator + message['content']
llm_prompt_logger.debug(debug_message)
- # call the completion function
- resp = completion_unwrapped(*args, **kwargs)
+ # skip if messages is empty (thus debug_message is empty)
+ if debug_message:
+ resp = completion_unwrapped(*args, **kwargs)
+ else:
+ resp = {'choices': [{'message': {'content': ''}}]}
# log the response
message_back = resp['choices'][0]['message']['content']
@@ -159,10 +169,191 @@ class LLM:
# post-process to log costs
self._post_completion(resp)
return resp
self._completion = wrapper # type: ignore
# Async version
self._async_completion = partial(
self._call_acompletion,
model=self.config.model,
api_key=self.config.api_key,
base_url=self.config.base_url,
api_version=self.config.api_version,
custom_llm_provider=self.config.custom_llm_provider,
max_tokens=self.config.max_output_tokens,
timeout=self.config.timeout,
temperature=self.config.temperature,
top_p=self.config.top_p,
drop_params=True,
)
async_completion_unwrapped = self._async_completion
@retry(
reraise=True,
stop=stop_after_attempt(self.config.num_retries),
wait=wait_random_exponential(
multiplier=self.config.retry_multiplier,
min=self.config.retry_min_wait,
max=self.config.retry_max_wait,
),
retry=retry_if_exception_type(
(
RateLimitError,
APIConnectionError,
ServiceUnavailableError,
InternalServerError,
ContentPolicyViolationError,
)
),
after=attempt_on_error,
)
async def async_completion_wrapper(*args, **kwargs):
"""Async wrapper for the litellm acompletion function."""
# some callers might just send the messages directly
if 'messages' in kwargs:
messages = kwargs['messages']
else:
messages = args[1]
# log the prompt
debug_message = ''
for message in messages:
debug_message += message_separator + message['content']
llm_prompt_logger.debug(debug_message)
async def check_stopped():
while True:
if (
hasattr(self.config, 'on_cancel_requested_fn')
and self.config.on_cancel_requested_fn is not None
and await self.config.on_cancel_requested_fn()
):
raise UserCancelledError('LLM request cancelled by user')
await asyncio.sleep(0.1)
stop_check_task = asyncio.create_task(check_stopped())
try:
# Directly call and await litellm_acompletion
resp = await async_completion_unwrapped(*args, **kwargs)
# skip if messages is empty (thus debug_message is empty)
if debug_message:
message_back = resp['choices'][0]['message']['content']
llm_response_logger.debug(message_back)
else:
resp = {'choices': [{'message': {'content': ''}}]}
self._post_completion(resp)
# We do not support streaming in this method, thus return resp
return resp
except UserCancelledError:
logger.info('LLM request cancelled by user.')
raise
except OpenAIError as e:
logger.error(f'OpenAIError occurred:\n{e}')
raise
except (
RateLimitError,
APIConnectionError,
ServiceUnavailableError,
InternalServerError,
) as e:
logger.error(f'Completion Error occurred:\n{e}')
raise
finally:
await asyncio.sleep(0.1)
stop_check_task.cancel()
try:
await stop_check_task
except asyncio.CancelledError:
pass
@retry(
reraise=True,
stop=stop_after_attempt(self.config.num_retries),
wait=wait_random_exponential(
multiplier=self.config.retry_multiplier,
min=self.config.retry_min_wait,
max=self.config.retry_max_wait,
),
retry=retry_if_exception_type(
(
RateLimitError,
APIConnectionError,
ServiceUnavailableError,
InternalServerError,
ContentPolicyViolationError,
)
),
after=attempt_on_error,
)
async def async_acompletion_stream_wrapper(*args, **kwargs):
"""Async wrapper for the litellm acompletion with streaming function."""
# some callers might just send the messages directly
if 'messages' in kwargs:
messages = kwargs['messages']
else:
messages = args[1]
# log the prompt
debug_message = ''
for message in messages:
debug_message += message_separator + message['content']
llm_prompt_logger.debug(debug_message)
try:
# Directly call and await litellm_acompletion
resp = await async_completion_unwrapped(*args, **kwargs)
# For streaming we iterate over the chunks
async for chunk in resp:
# Check for cancellation before yielding the chunk
if (
hasattr(self.config, 'on_cancel_requested_fn')
and self.config.on_cancel_requested_fn is not None
and await self.config.on_cancel_requested_fn()
):
raise UserCancelledError(
'LLM request cancelled due to CANCELLED state'
)
# with streaming, it is "delta", not "message"!
message_back = chunk['choices'][0]['delta']['content']
llm_response_logger.debug(message_back)
self._post_completion(chunk)
yield chunk
except UserCancelledError:
logger.info('LLM request cancelled by user.')
raise
except OpenAIError as e:
logger.error(f'OpenAIError occurred:\n{e}')
raise
except (
RateLimitError,
APIConnectionError,
ServiceUnavailableError,
InternalServerError,
) as e:
logger.error(f'Completion Error occurred:\n{e}')
raise
finally:
if kwargs.get('stream', False):
await asyncio.sleep(0.1)
self._async_completion = async_completion_wrapper # type: ignore
self._async_streaming_completion = async_acompletion_stream_wrapper # type: ignore
async def _call_acompletion(self, *args, **kwargs):
return await litellm.acompletion(*args, **kwargs)
@property
def completion(self):
"""Decorator for the litellm completion function.
@@ -171,6 +362,22 @@ class LLM:
"""
return self._completion
@property
def async_completion(self):
"""Decorator for the async litellm acompletion function.
Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
"""
return self._async_completion
@property
def async_streaming_completion(self):
"""Decorator for the async litellm acompletion function with streaming.
Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
"""
return self._async_streaming_completion
def _post_completion(self, response: str) -> None:
"""Post-process the completion response."""
try:
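
To round off this file: a minimal sketch (mirroring the unit tests below, not a stable API) of wiring a cancel callback onto the config. The wrapper above looks the callback up with hasattr(self.config, 'on_cancel_requested_fn'), so the sketch attaches it directly to the LLM config object and assumes that object accepts the extra attribute:

import asyncio

from opendevin.core.config import load_app_config
from opendevin.llm.llm import LLM

cancel_event = asyncio.Event()

async def on_cancel_requested():
    # polled roughly every 0.1 s by the wrapper's check_stopped task
    return cancel_event.is_set()

llm_config = load_app_config().get_llm_config()
llm_config.on_cancel_requested_fn = on_cancel_requested  # assumed to be accepted as an ad-hoc attribute
llm = LLM(config=llm_config)
# elsewhere, cancel_event.set() asks an in-flight async_completion call to
# abort with UserCancelledError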

View File

@@ -0,0 +1,187 @@
import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from opendevin.core.config import load_app_config
from opendevin.core.exceptions import UserCancelledError
from opendevin.llm.llm import LLM
config = load_app_config()
@pytest.fixture
def test_llm():
# Create a mock config for testing
return LLM(config=config.get_llm_config())
@pytest.fixture
def mock_response():
return [
{'choices': [{'delta': {'content': 'This is a'}}]},
{'choices': [{'delta': {'content': ' test'}}]},
{'choices': [{'delta': {'content': ' message.'}}]},
{'choices': [{'delta': {'content': ' It is'}}]},
{'choices': [{'delta': {'content': ' a bit'}}]},
{'choices': [{'delta': {'content': ' longer'}}]},
{'choices': [{'delta': {'content': ' than'}}]},
{'choices': [{'delta': {'content': ' the'}}]},
{'choices': [{'delta': {'content': ' previous'}}]},
{'choices': [{'delta': {'content': ' one,'}}]},
{'choices': [{'delta': {'content': ' but'}}]},
{'choices': [{'delta': {'content': ' hopefully'}}]},
{'choices': [{'delta': {'content': ' still'}}]},
{'choices': [{'delta': {'content': ' short'}}]},
{'choices': [{'delta': {'content': ' enough.'}}]},
]
@pytest.mark.asyncio
async def test_acompletion_non_streaming():
with patch.object(LLM, '_call_acompletion') as mock_call_acompletion:
mock_response = {
'choices': [{'message': {'content': 'This is a test message.'}}]
}
mock_call_acompletion.return_value = mock_response
test_llm = LLM(config=config.get_llm_config())
response = await test_llm.async_completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
drop_params=True,
)
# Assertions for non-streaming completion
assert response['choices'][0]['message']['content'] != ''
@pytest.mark.asyncio
async def test_acompletion_streaming(mock_response):
with patch.object(LLM, '_call_acompletion') as mock_call_acompletion:
mock_call_acompletion.return_value.__aiter__.return_value = iter(mock_response)
test_llm = LLM(config=config.get_llm_config())
async for chunk in test_llm.async_streaming_completion(
messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
):
print(f"Chunk: {chunk['choices'][0]['delta']['content']}")
# Assertions for streaming completion
assert chunk['choices'][0]['delta']['content'] in [
r['choices'][0]['delta']['content'] for r in mock_response
]
@pytest.mark.asyncio
async def test_completion(test_llm):
with patch.object(LLM, 'completion') as mock_completion:
mock_completion.return_value = {
'choices': [{'message': {'content': 'This is a test message.'}}]
}
response = test_llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
assert response['choices'][0]['message']['content'] == 'This is a test message.'
@pytest.mark.asyncio
@pytest.mark.parametrize('cancel_delay', [0.1, 0.3, 0.5, 0.7, 0.9])
async def test_async_completion_with_user_cancellation(cancel_delay):
cancel_event = asyncio.Event()
async def mock_on_cancel_requested():
is_set = cancel_event.is_set()
print(f'Cancel requested: {is_set}')
return is_set
config = load_app_config()
config.on_cancel_requested_fn = mock_on_cancel_requested
async def mock_acompletion(*args, **kwargs):
print('Starting mock_acompletion')
for i in range(20): # Increased iterations for longer running task
print(f'mock_acompletion iteration {i}')
await asyncio.sleep(0.1)
if await mock_on_cancel_requested():
print('Cancellation detected in mock_acompletion')
raise UserCancelledError('LLM request cancelled by user')
print('Completing mock_acompletion without cancellation')
return {'choices': [{'message': {'content': 'This is a test message.'}}]}
with patch.object(
LLM, '_call_acompletion', new_callable=AsyncMock
) as mock_call_acompletion:
mock_call_acompletion.side_effect = mock_acompletion
test_llm = LLM(config=config.get_llm_config())
async def cancel_after_delay():
print(f'Starting cancel_after_delay with delay {cancel_delay}')
await asyncio.sleep(cancel_delay)
print('Setting cancel event')
cancel_event.set()
with pytest.raises(UserCancelledError):
await asyncio.gather(
test_llm.async_completion(
messages=[{'role': 'user', 'content': 'Hello!'}],
stream=False,
),
cancel_after_delay(),
)
# Ensure the mock was called
mock_call_acompletion.assert_called_once()
@pytest.mark.asyncio
@pytest.mark.parametrize('cancel_after_chunks', [1, 3, 5, 7, 9])
async def test_async_streaming_completion_with_user_cancellation(cancel_after_chunks):
cancel_requested = False
async def mock_on_cancel_requested():
nonlocal cancel_requested
return cancel_requested
config = load_app_config()
config.on_cancel_requested_fn = mock_on_cancel_requested
test_messages = [
'This is ',
'a test ',
'message ',
'with ',
'multiple ',
'chunks ',
'to ',
'simulate ',
'a ',
'longer ',
'streaming ',
'response.',
]
async def mock_acompletion(*args, **kwargs):
for i, content in enumerate(test_messages):
yield {'choices': [{'delta': {'content': content}}]}
if i + 1 == cancel_after_chunks:
nonlocal cancel_requested
cancel_requested = True
if cancel_requested:
raise UserCancelledError('LLM request cancelled by user')
await asyncio.sleep(0.05) # Simulate some delay between chunks
with patch.object(
LLM, '_call_acompletion', new_callable=AsyncMock
) as mock_call_acompletion:
mock_call_acompletion.return_value = mock_acompletion()
test_llm = LLM(config=config.get_llm_config())
received_chunks = []
with pytest.raises(UserCancelledError):
async for chunk in test_llm.async_streaming_completion(
messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
):
received_chunks.append(chunk['choices'][0]['delta']['content'])
print(f"Chunk: {chunk['choices'][0]['delta']['content']}")
# Assert that we received the expected number of chunks before cancellation
assert len(received_chunks) == cancel_after_chunks
assert received_chunks == test_messages[:cancel_after_chunks]
# Ensure the mock was called
mock_call_acompletion.assert_called_once()
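
Note: the async tests above depend on pytest-asyncio via the @pytest.mark.asyncio marker; assuming the repository's usual test layout, they can be run with a plain pytest invocation against test_acompletion.py.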