Files
OpenHands/tests/unit/test_acompletion.py
tobitege a4cb880699 (feat) LLM class: added acompletion and streaming + unit test (#3202)
* LLM class: added acompletion and streaming, unit test test_acompletion.py

* LLM: cleanup of self.config defaults and their use

* added set_missing_attributes to LLMConfig

* move default checker up
2024-08-01 22:41:40 +02:00

188 lines
6.9 KiB
Python

import asyncio
from unittest.mock import AsyncMock, patch
import pytest
from opendevin.core.config import load_app_config
from opendevin.core.exceptions import UserCancelledError
from opendevin.llm.llm import LLM
config = load_app_config()
@pytest.fixture
def test_llm():
    """Provide an LLM instance built from the app's default LLM config."""
    llm_config = config.get_llm_config()
    return LLM(config=llm_config)
@pytest.fixture
def mock_response():
    """Streaming-style chunks that together spell out a short test message."""
    pieces = [
        'This is a', ' test', ' message.', ' It is', ' a bit',
        ' longer', ' than', ' the', ' previous', ' one,',
        ' but', ' hopefully', ' still', ' short', ' enough.',
    ]
    # Each piece is wrapped in the litellm streaming-delta chunk shape.
    return [{'choices': [{'delta': {'content': piece}}]} for piece in pieces]
@pytest.mark.asyncio
async def test_acompletion_non_streaming():
    """async_completion (stream=False) must pass the model reply through unchanged."""
    canned_reply = {
        'choices': [{'message': {'content': 'This is a test message.'}}]
    }
    with patch.object(LLM, '_call_acompletion') as mock_call_acompletion:
        mock_call_acompletion.return_value = canned_reply
        test_llm = LLM(config=config.get_llm_config())
        response = await test_llm.async_completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
            drop_params=True,
        )
        # The mocked completion content should surface as a non-empty message.
        assert response['choices'][0]['message']['content'] != ''
@pytest.mark.asyncio
async def test_acompletion_streaming(mock_response):
    """Every streamed delta must be one of the mocked response chunks."""
    expected_contents = [
        item['choices'][0]['delta']['content'] for item in mock_response
    ]
    with patch.object(LLM, '_call_acompletion') as mock_call_acompletion:
        # Make the mocked call behave as an async iterator over the chunks.
        mock_call_acompletion.return_value.__aiter__.return_value = iter(mock_response)
        test_llm = LLM(config=config.get_llm_config())
        async for chunk in test_llm.async_streaming_completion(
            messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
        ):
            content = chunk['choices'][0]['delta']['content']
            print(f'Chunk: {content}')
            # Assertions for streaming completion
            assert content in expected_contents
def test_completion(test_llm):
    """completion() should return the mocked choices payload unchanged.

    Fixed: the original was declared ``async`` with ``@pytest.mark.asyncio``
    although it never awaited anything — a plain sync test is correct here.
    Also verify the patched mock was actually invoked, so the assertion
    cannot pass tautologically if the patch ever stops taking effect.
    """
    with patch.object(LLM, 'completion') as mock_completion:
        mock_completion.return_value = {
            'choices': [{'message': {'content': 'This is a test message.'}}]
        }
        response = test_llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
        # Prove the call went through the patched entry point.
        mock_completion.assert_called_once_with(
            messages=[{'role': 'user', 'content': 'Hello!'}]
        )
        assert response['choices'][0]['message']['content'] == 'This is a test message.'
@pytest.mark.asyncio
@pytest.mark.parametrize('cancel_delay', [0.1, 0.3, 0.5, 0.7, 0.9])
async def test_async_completion_with_user_cancellation(cancel_delay):
    """A cancel request raised mid-call must surface as UserCancelledError."""
    cancel_event = asyncio.Event()

    async def mock_on_cancel_requested():
        is_set = cancel_event.is_set()
        print(f'Cancel requested: {is_set}')
        return is_set

    config = load_app_config()
    config.on_cancel_requested_fn = mock_on_cancel_requested

    async def mock_acompletion(*args, **kwargs):
        # Simulated long-running LLM call: poll for cancellation each tick.
        print('Starting mock_acompletion')
        for step in range(20):  # enough iterations for every cancel_delay to land
            print(f'mock_acompletion iteration {step}')
            await asyncio.sleep(0.1)
            if await mock_on_cancel_requested():
                print('Cancellation detected in mock_acompletion')
                raise UserCancelledError('LLM request cancelled by user')
        print('Completing mock_acompletion without cancellation')
        return {'choices': [{'message': {'content': 'This is a test message.'}}]}

    with patch.object(
        LLM, '_call_acompletion', new_callable=AsyncMock
    ) as mock_call_acompletion:
        mock_call_acompletion.side_effect = mock_acompletion
        test_llm = LLM(config=config.get_llm_config())

        async def cancel_after_delay():
            # Fire the cancel event once the parametrized delay elapses.
            print(f'Starting cancel_after_delay with delay {cancel_delay}')
            await asyncio.sleep(cancel_delay)
            print('Setting cancel event')
            cancel_event.set()

        with pytest.raises(UserCancelledError):
            await asyncio.gather(
                test_llm.async_completion(
                    messages=[{'role': 'user', 'content': 'Hello!'}],
                    stream=False,
                ),
                cancel_after_delay(),
            )
        # The mocked completion must have been entered exactly once.
        mock_call_acompletion.assert_called_once()
@pytest.mark.asyncio
@pytest.mark.parametrize('cancel_after_chunks', [1, 3, 5, 7, 9])
async def test_async_streaming_completion_with_user_cancellation(cancel_after_chunks):
    """Cancelling after N streamed chunks must stop with UserCancelledError."""
    cancel_requested = False

    async def mock_on_cancel_requested():
        nonlocal cancel_requested
        return cancel_requested

    config = load_app_config()
    config.on_cancel_requested_fn = mock_on_cancel_requested

    test_messages = [
        'This is ',
        'a test ',
        'message ',
        'with ',
        'multiple ',
        'chunks ',
        'to ',
        'simulate ',
        'a ',
        'longer ',
        'streaming ',
        'response.',
    ]

    async def mock_acompletion(*args, **kwargs):
        # Async generator standing in for a streaming LLM response; it
        # flips the cancel flag after the parametrized number of chunks.
        nonlocal cancel_requested
        for count, content in enumerate(test_messages, start=1):
            yield {'choices': [{'delta': {'content': content}}]}
            if count == cancel_after_chunks:
                cancel_requested = True
            if cancel_requested:
                raise UserCancelledError('LLM request cancelled by user')
            await asyncio.sleep(0.05)  # simulate some delay between chunks

    with patch.object(
        LLM, '_call_acompletion', new_callable=AsyncMock
    ) as mock_call_acompletion:
        mock_call_acompletion.return_value = mock_acompletion()
        test_llm = LLM(config=config.get_llm_config())
        received_chunks = []
        with pytest.raises(UserCancelledError):
            async for chunk in test_llm.async_streaming_completion(
                messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
            ):
                piece = chunk['choices'][0]['delta']['content']
                received_chunks.append(piece)
                print(f'Chunk: {piece}')
        # Exactly the chunks emitted before cancellation should have arrived.
        assert len(received_chunks) == cancel_after_chunks
        assert received_chunks == test_messages[:cancel_after_chunks]
        # The mocked streaming call must have been entered exactly once.
        mock_call_acompletion.assert_called_once()