diff --git a/opendevin/core/config.py b/opendevin/core/config.py
index 86a21a9d82..673973146f 100644
--- a/opendevin/core/config.py
+++ b/opendevin/core/config.py
@@ -111,6 +111,23 @@ class LLMConfig:
                 ret[k] = '******' if v else None
         return ret
 
+    def set_missing_attributes(self):
+        """Set any missing attributes to their default values.
+
+        Fields declared with a `default_factory` get a freshly built value;
+        fields without any default are left unset instead of being assigned
+        the dataclasses `MISSING` sentinel.
+        """
+        from dataclasses import MISSING
+
+        for field_name, field_obj in self.__dataclass_fields__.items():
+            if hasattr(self, field_name):
+                continue
+            if field_obj.default is not MISSING:
+                setattr(self, field_name, field_obj.default)
+            elif field_obj.default_factory is not MISSING:
+                setattr(self, field_name, field_obj.default_factory())
+
 
 @dataclass
 class AgentConfig:
diff --git a/opendevin/core/exceptions.py b/opendevin/core/exceptions.py
index fe97308396..c7e9da7614 100644
--- a/opendevin/core/exceptions.py
+++ b/opendevin/core/exceptions.py
@@ -67,3 +67,8 @@ class LLMNoActionError(Exception):
 class LLMResponseError(Exception):
     def __init__(self, message='Failed to retrieve action from LLM response'):
         super().__init__(message)
+
+
+class UserCancelledError(Exception):
+    def __init__(self, message='User cancelled the request'):
+        super().__init__(message)
diff --git a/opendevin/llm/llm.py b/opendevin/llm/llm.py
index 5d4e8b55ab..ac502a07a9 100644
--- a/opendevin/llm/llm.py
+++ b/opendevin/llm/llm.py
@@ -1,3 +1,4 @@
+import asyncio
 import copy
 import warnings
 from functools import partial
@@ -13,6 +14,7 @@ from litellm.exceptions import (
     APIConnectionError,
     ContentPolicyViolationError,
     InternalServerError,
+    OpenAIError,
     RateLimitError,
     ServiceUnavailableError,
 )
@@ -24,6 +26,7 @@ from tenacity import (
     wait_random_exponential,
 )
 
+from opendevin.core.exceptions import UserCancelledError
 from opendevin.core.logger import llm_prompt_logger, llm_response_logger
 from opendevin.core.logger import opendevin_logger as logger
 from opendevin.core.metrics import Metrics
@@ -56,6 +59,9 @@ class LLM:
         self.metrics = metrics if metrics is not None else Metrics()
         self.cost_metric_supported = True
 
+        # Set up config attributes with default values to prevent AttributeError
+        LLMConfig.set_missing_attributes(self.config)
+
         # litellm actually uses base Exception here for unknown model
         self.model_info = None
         try:
@@ -66,11 +72,11 @@ class LLM:
                     self.config.model.split(':')[0]
                 )
         # noinspection PyBroadException
-        except Exception:
-            logger.warning(f'Could not get model info for {config.model}')
+        except Exception as e:
+            logger.warning(f'Could not get model info for {config.model}:\n{e}')
 
         # Set the max tokens in an LM-specific way if not set
-        if config.max_input_tokens is None:
+        if self.config.max_input_tokens is None:
             if (
                 self.model_info is not None
                 and 'max_input_tokens' in self.model_info
@@ -81,7 +87,7 @@ class LLM:
                 # Max input tokens for gpt3.5, so this is a safe fallback for any potentially viable model
                 self.config.max_input_tokens = 4096
 
-        if config.max_output_tokens is None:
+        if self.config.max_output_tokens is None:
             if (
                 self.model_info is not None
                 and 'max_output_tokens' in self.model_info
@@ -119,11 +125,11 @@ class LLM:
 
         @retry(
             reraise=True,
-            stop=stop_after_attempt(config.num_retries),
+            stop=stop_after_attempt(self.config.num_retries),
             wait=wait_random_exponential(
-                multiplier=config.retry_multiplier,
-                min=config.retry_min_wait,
-                max=config.retry_max_wait,
+                multiplier=self.config.retry_multiplier,
+                min=self.config.retry_min_wait,
+                max=self.config.retry_max_wait,
             ),
             retry=retry_if_exception_type(
                 (
@@ -147,11 +153,15 @@ class LLM:
             # log the prompt
             debug_message = ''
             for message in messages:
-                debug_message += message_separator + message['content']
+                if message['content'].strip():
+                    debug_message += message_separator + message['content']
             llm_prompt_logger.debug(debug_message)
 
-            # call the completion function
-            resp = completion_unwrapped(*args, **kwargs)
+            # skip if messages is empty (thus debug_message is empty)
+            if debug_message:
+                resp = completion_unwrapped(*args, **kwargs)
+            else:
+                resp = {'choices': [{'message': {'content': ''}}]}
 
             # log the response
             message_back = resp['choices'][0]['message']['content']
@@ -159,10 +169,208 @@ class LLM:
 
             # post-process to log costs
             self._post_completion(resp)
+
             return resp
 
         self._completion = wrapper  # type: ignore
 
+        # Async version
+        self._async_completion = partial(
+            self._call_acompletion,
+            model=self.config.model,
+            api_key=self.config.api_key,
+            base_url=self.config.base_url,
+            api_version=self.config.api_version,
+            custom_llm_provider=self.config.custom_llm_provider,
+            max_tokens=self.config.max_output_tokens,
+            timeout=self.config.timeout,
+            temperature=self.config.temperature,
+            top_p=self.config.top_p,
+            drop_params=True,
+        )
+
+        async_completion_unwrapped = self._async_completion
+
+        # a single retry policy shared by both async entry points
+        async_retry = retry(
+            reraise=True,
+            stop=stop_after_attempt(self.config.num_retries),
+            wait=wait_random_exponential(
+                multiplier=self.config.retry_multiplier,
+                min=self.config.retry_min_wait,
+                max=self.config.retry_max_wait,
+            ),
+            retry=retry_if_exception_type(
+                (
+                    RateLimitError,
+                    APIConnectionError,
+                    ServiceUnavailableError,
+                    InternalServerError,
+                    ContentPolicyViolationError,
+                )
+            ),
+            after=attempt_on_error,
+        )
+
+        @async_retry
+        async def async_completion_wrapper(*args, **kwargs):
+            """Async wrapper for the litellm acompletion function.
+
+            The request is raced against a watcher that polls the optional
+            `on_cancel_requested_fn` callback of the config, so that a cancel
+            request interrupts the call while it is still in flight.
+
+            Raises:
+                UserCancelledError: if the callback reports a cancellation before
+                    the completion has finished.
+            """
+            # some callers might just send the messages directly
+            if 'messages' in kwargs:
+                messages = kwargs['messages']
+            else:
+                messages = args[1]
+
+            # log the prompt, skipping blank messages like the sync wrapper does
+            debug_message = ''
+            for message in messages:
+                if message['content'].strip():
+                    debug_message += message_separator + message['content']
+            llm_prompt_logger.debug(debug_message)
+
+            # skip if messages is empty (thus debug_message is empty)
+            if not debug_message:
+                resp = {'choices': [{'message': {'content': ''}}]}
+                self._post_completion(resp)
+                return resp
+
+            cancel_requested_fn = getattr(
+                self.config, 'on_cancel_requested_fn', None
+            )
+
+            async def check_stopped():
+                # returns as soon as the callback reports a cancel request
+                while not await cancel_requested_fn():
+                    await asyncio.sleep(0.1)
+
+            completion_task = asyncio.ensure_future(
+                async_completion_unwrapped(*args, **kwargs)
+            )
+            tasks = [completion_task]
+            if cancel_requested_fn is not None:
+                tasks.append(asyncio.ensure_future(check_stopped()))
+
+            try:
+                # wait for whichever finishes first: the request or the watcher
+                await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+                if not completion_task.done():
+                    # the watcher won: surface a failing callback, else cancel
+                    tasks[-1].result()
+                    raise UserCancelledError('LLM request cancelled by user')
+
+                resp = completion_task.result()
+                message_back = resp['choices'][0]['message']['content']
+                llm_response_logger.debug(message_back)
+                self._post_completion(resp)
+
+                # We do not support streaming in this method, thus return resp
+                return resp
+
+            except UserCancelledError:
+                logger.info('LLM request cancelled by user.')
+                raise
+            except OpenAIError as e:
+                logger.error(f'OpenAIError occurred:\n{e}')
+                raise
+            except (
+                RateLimitError,
+                APIConnectionError,
+                ServiceUnavailableError,
+                InternalServerError,
+            ) as e:
+                logger.error(f'Completion Error occurred:\n{e}')
+                raise
+
+            finally:
+                # never leave the watcher or an abandoned request running
+                for task in tasks:
+                    task.cancel()
+                await asyncio.gather(*tasks, return_exceptions=True)
+
+        @async_retry
+        async def open_completion_stream(*args, **kwargs):
+            # Retrying applies to opening the stream only: decorating the async
+            # generator below would have no effect, since calling it merely
+            # creates the generator without running any of its code.
+            return await async_completion_unwrapped(*args, **kwargs)
+
+        async def async_acompletion_stream_wrapper(*args, **kwargs):
+            """Async wrapper for the litellm acompletion with streaming function.
+
+            Yields the chunks of the response one by one, checking the optional
+            `on_cancel_requested_fn` callback of the config before each of them.
+
+            Raises:
+                UserCancelledError: if the callback reports a cancellation.
+            """
+            # some callers might just send the messages directly
+            if 'messages' in kwargs:
+                messages = kwargs['messages']
+            else:
+                messages = args[1]
+
+            # log the prompt
+            debug_message = ''
+            for message in messages:
+                debug_message += message_separator + message['content']
+            llm_prompt_logger.debug(debug_message)
+
+            cancel_requested_fn = getattr(
+                self.config, 'on_cancel_requested_fn', None
+            )
+
+            try:
+                # open the stream (with retries), then iterate over the chunks
+                resp = await open_completion_stream(*args, **kwargs)
+
+                async for chunk in resp:
+                    # Check for cancellation before yielding the chunk
+                    if cancel_requested_fn is not None and await cancel_requested_fn():
+                        raise UserCancelledError(
+                            'LLM request cancelled due to CANCELLED state'
+                        )
+                    # with streaming, it is "delta", not "message"!
+                    message_back = chunk['choices'][0]['delta']['content']
+                    llm_response_logger.debug(message_back)
+                    self._post_completion(chunk)
+
+                    yield chunk
+
+            except UserCancelledError:
+                logger.info('LLM request cancelled by user.')
+                raise
+            except OpenAIError as e:
+                logger.error(f'OpenAIError occurred:\n{e}')
+                raise
+            except (
+                RateLimitError,
+                APIConnectionError,
+                ServiceUnavailableError,
+                InternalServerError,
+            ) as e:
+                logger.error(f'Completion Error occurred:\n{e}')
+                raise
+
+            finally:
+                if kwargs.get('stream', False):
+                    await asyncio.sleep(0.1)
+
+        self._async_completion = async_completion_wrapper  # type: ignore
+        self._async_streaming_completion = async_acompletion_stream_wrapper  # type: ignore
+
+    async def _call_acompletion(self, *args, **kwargs):
+        """Thin awaitable wrapper around `litellm.acompletion`."""
+        return await litellm.acompletion(*args, **kwargs)
+
     @property
     def completion(self):
         """Decorator for the litellm completion function.
@@ -171,6 +379,22 @@ class LLM:
         """
         return self._completion
 
+    @property
+    def async_completion(self):
+        """Decorator for the async litellm acompletion function.
+
+        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
+        """
+        return self._async_completion
+
+    @property
+    def async_streaming_completion(self):
+        """Decorator for the async litellm acompletion function with streaming.
+
+        Check the complete documentation at https://litellm.vercel.app/docs/providers/ollama#example-usage---streaming--acompletion
+        """
+        return self._async_streaming_completion
+
     def _post_completion(self, response: str) -> None:
         """Post-process the completion response."""
         try:
diff --git a/tests/unit/test_acompletion.py b/tests/unit/test_acompletion.py
new file mode 100644
index 0000000000..b3de1f4656
--- /dev/null
+++ b/tests/unit/test_acompletion.py
@@ -0,0 +1,187 @@
+import asyncio
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from opendevin.core.config import load_app_config
+from opendevin.core.exceptions import UserCancelledError
+from opendevin.llm.llm import LLM
+
+config = load_app_config()
+
+
+@pytest.fixture
+def test_llm():
+    # Create a mock config for testing
+    return LLM(config=config.get_llm_config())
+
+
+@pytest.fixture
+def mock_response():
+    return [
+        {'choices': [{'delta': {'content': 'This is a'}}]},
+        {'choices': [{'delta': {'content': ' test'}}]},
+        {'choices': [{'delta': {'content': ' message.'}}]},
+        {'choices': [{'delta': {'content': ' It is'}}]},
+        {'choices': [{'delta': {'content': ' a bit'}}]},
+        {'choices': [{'delta': {'content': ' longer'}}]},
+        {'choices': [{'delta': {'content': ' than'}}]},
+        {'choices': [{'delta': {'content': ' the'}}]},
+        {'choices': [{'delta': {'content': ' previous'}}]},
+        {'choices': [{'delta': {'content': ' one,'}}]},
+        {'choices': [{'delta': {'content': ' but'}}]},
+        {'choices': [{'delta': {'content': ' hopefully'}}]},
+        {'choices': [{'delta': {'content': ' still'}}]},
+        {'choices': [{'delta': {'content': ' short'}}]},
+        {'choices': [{'delta': {'content': ' enough.'}}]},
+    ]
+
+
+@pytest.mark.asyncio
+async def test_acompletion_non_streaming():
+    with patch.object(LLM, '_call_acompletion') as mock_call_acompletion:
+        mock_response = {
+            'choices': [{'message': {'content': 'This is a test message.'}}]
+        }
+        mock_call_acompletion.return_value = mock_response
+        test_llm = LLM(config=config.get_llm_config())
+        response = await test_llm.async_completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+            drop_params=True,
+        )
+        # Assertions for non-streaming completion
+        assert response['choices'][0]['message']['content'] != ''
+
+
+@pytest.mark.asyncio
+async def test_acompletion_streaming(mock_response):
+    with patch.object(LLM, '_call_acompletion') as mock_call_acompletion:
+        mock_call_acompletion.return_value.__aiter__.return_value = iter(mock_response)
+        test_llm = LLM(config=config.get_llm_config())
+        async for chunk in test_llm.async_streaming_completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
+        ):
+            print(f"Chunk: {chunk['choices'][0]['delta']['content']}")
+            # Assertions for streaming completion
+            assert chunk['choices'][0]['delta']['content'] in [
+                r['choices'][0]['delta']['content'] for r in mock_response
+            ]
+
+
+@pytest.mark.asyncio
+async def test_completion(test_llm):
+    with patch.object(LLM, 'completion') as mock_completion:
+        mock_completion.return_value = {
+            'choices': [{'message': {'content': 'This is a test message.'}}]
+        }
+        response = test_llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
+        assert response['choices'][0]['message']['content'] == 'This is a test message.'
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('cancel_delay', [0.1, 0.3, 0.5, 0.7, 0.9])
+async def test_async_completion_with_user_cancellation(cancel_delay):
+    cancel_event = asyncio.Event()
+
+    async def mock_on_cancel_requested():
+        is_set = cancel_event.is_set()
+        print(f'Cancel requested: {is_set}')
+        return is_set
+
+    # The LLM reads the callback from its own config, not from the app config.
+    llm_config = load_app_config().get_llm_config()
+    llm_config.on_cancel_requested_fn = mock_on_cancel_requested
+
+    async def mock_acompletion(*args, **kwargs):
+        # A slow request that never cancels itself: the UserCancelledError
+        # must come from the LLM wrapper, not from this mock.
+        print('Starting mock_acompletion')
+        for i in range(20):  # Increased iterations for longer running task
+            print(f'mock_acompletion iteration {i}')
+            await asyncio.sleep(0.1)
+        print('Completing mock_acompletion without cancellation')
+        return {'choices': [{'message': {'content': 'This is a test message.'}}]}
+
+    with patch.object(
+        LLM, '_call_acompletion', new_callable=AsyncMock
+    ) as mock_call_acompletion:
+        mock_call_acompletion.side_effect = mock_acompletion
+        test_llm = LLM(config=llm_config)
+
+        async def cancel_after_delay():
+            print(f'Starting cancel_after_delay with delay {cancel_delay}')
+            await asyncio.sleep(cancel_delay)
+            print('Setting cancel event')
+            cancel_event.set()
+
+        with pytest.raises(UserCancelledError):
+            await asyncio.gather(
+                test_llm.async_completion(
+                    messages=[{'role': 'user', 'content': 'Hello!'}],
+                    stream=False,
+                ),
+                cancel_after_delay(),
+            )
+
+    # Ensure the mock was called
+    mock_call_acompletion.assert_called_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('cancel_after_chunks', [1, 3, 5, 7, 9])
+async def test_async_streaming_completion_with_user_cancellation(cancel_after_chunks):
+    cancel_requested = False
+
+    async def mock_on_cancel_requested():
+        return cancel_requested
+
+    # Attach the callback to the config object the LLM actually reads.
+    llm_config = load_app_config().get_llm_config()
+    llm_config.on_cancel_requested_fn = mock_on_cancel_requested
+
+    test_messages = [
+        'This is ',
+        'a test ',
+        'message ',
+        'with ',
+        'multiple ',
+        'chunks ',
+        'to ',
+        'simulate ',
+        'a ',
+        'longer ',
+        'streaming ',
+        'response.',
+    ]
+
+    async def mock_acompletion(*args, **kwargs):
+        nonlocal cancel_requested
+        for i, content in enumerate(test_messages):
+            yield {'choices': [{'delta': {'content': content}}]}
+            if i + 1 == cancel_after_chunks:
+                # Only raise the flag: the LLM wrapper has to notice it and
+                # stop before handing out the next chunk.
+                cancel_requested = True
+            await asyncio.sleep(0.05)  # Simulate some delay between chunks
+
+    with patch.object(
+        LLM, '_call_acompletion', new_callable=AsyncMock
+    ) as mock_call_acompletion:
+        mock_call_acompletion.return_value = mock_acompletion()
+        test_llm = LLM(config=llm_config)
+
+        received_chunks = []
+        with pytest.raises(UserCancelledError):
+            async for chunk in test_llm.async_streaming_completion(
+                messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
+            ):
+                received_chunks.append(chunk['choices'][0]['delta']['content'])
+                print(f"Chunk: {chunk['choices'][0]['delta']['content']}")
+
+        # Assert that we received the expected number of chunks before cancellation
+        assert len(received_chunks) == cancel_after_chunks
+        assert received_chunks == test_messages[:cancel_after_chunks]
+
+    # Ensure the mock was called
+    mock_call_acompletion.assert_called_once()