Mirror of https://github.com/microsoft/autogen.git (synced 2026-04-20 03:02:16 -04:00)
Ignore Some Messages When Transforming (#2661)
* works
* spelling
* returned old docstring
* add cache fix
* spelling?
---------
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
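The change, in brief: both `MessageTokenLimiter` and `TextMessageCompressor` gain `filter_dict` and `exclude_filter` parameters so selected messages can be left untouched by the transform. A minimal usage sketch (hypothetical message values; assumes an autogen build that includes this commit):

```python
from autogen.agentchat.contrib.capabilities.transforms import MessageTokenLimiter

messages = [
    {"role": "system", "content": "You are a helpful assistant. " * 40},
    {"role": "user", "content": "Please summarize the attached report. " * 40},
]

# exclude_filter=True (the default): messages matching filter_dict are
# ignored by the transform, so user messages survive truncation intact.
limiter = MessageTokenLimiter(max_tokens_per_message=20, filter_dict={"role": "user"})
truncated = limiter.apply_transform(messages)
```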
@@ -8,6 +8,7 @@ from termcolor import colored
 
 from autogen import token_count_utils
 from autogen.cache import AbstractCache, Cache
+from autogen.oai.openai_utils import filter_config
 
 from .text_compressors import LLMLingua, TextCompressor
 
@@ -130,6 +131,8 @@ class MessageTokenLimiter:
         max_tokens: Optional[int] = None,
         min_tokens: Optional[int] = None,
         model: str = "gpt-3.5-turbo-0613",
+        filter_dict: Optional[Dict] = None,
+        exclude_filter: bool = True,
     ):
         """
         Args:
@@ -140,11 +143,17 @@ class MessageTokenLimiter:
             min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
                 Must be greater than or equal to 0 if not None.
             model (str): The target OpenAI model for tokenization alignment.
+            filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
+                If None, no filters will be applied.
+            exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
+                excluded from token truncation. If False, messages that match the filter will be truncated.
         """
         self._model = model
         self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
         self._max_tokens = self._validate_max_tokens(max_tokens)
         self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
+        self._filter_dict = filter_dict
+        self._exclude_filter = exclude_filter
 
     def apply_transform(self, messages: List[Dict]) -> List[Dict]:
         """Applies token truncation to the conversation history.
@@ -169,10 +178,15 @@ class MessageTokenLimiter:
 
         for msg in reversed(temp_messages):
             # Some messages may not have content.
-            if not isinstance(msg.get("content"), (str, list)):
+            if not _is_content_right_type(msg.get("content")):
                 processed_messages.insert(0, msg)
                 continue
 
+            if not _should_transform_message(msg, self._filter_dict, self._exclude_filter):
+                processed_messages.insert(0, msg)
+                processed_messages_tokens += _count_tokens(msg["content"])
+                continue
+
             expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
 
             # If adding this message would exceed the token limit, truncate the last message to meet the total token
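Worth noting in the hunk above: a message skipped by the filter is passed through unchanged, but its tokens are still added to `processed_messages_tokens`, so it still consumes the overall `max_tokens` budget. A simplified, self-contained sketch of that control flow (stub counter and predicate, not the real tokenizer):

```python
from typing import Any, Callable, Dict, List

def sketch_truncation_loop(
    messages: List[Dict[str, Any]],
    should_transform: Callable[[Dict[str, Any]], bool],
    count_tokens: Callable[[Any], int],
    max_tokens_per_message: int,
) -> List[Dict[str, Any]]:
    processed: List[Dict[str, Any]] = []
    total = 0
    for msg in reversed(messages):
        if not should_transform(msg):
            processed.insert(0, msg)               # passed through unchanged...
            total += count_tokens(msg["content"])  # ...but still billed to the budget
            continue
        # (the real transform truncates msg["content"] here before inserting)
        processed.insert(0, msg)
        total += min(count_tokens(msg["content"]), max_tokens_per_message)
    return processed
```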
@@ -282,6 +296,8 @@ class TextMessageCompressor:
         min_tokens: Optional[int] = None,
         compression_params: Dict = dict(),
         cache: Optional[AbstractCache] = Cache.disk(),
+        filter_dict: Optional[Dict] = None,
+        exclude_filter: bool = True,
     ):
         """
         Args:
@@ -293,6 +309,10 @@ class TextMessageCompressor:
                 dictionary.
             cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
                 If None, no caching will be used.
+            filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
+                If None, no filters will be applied.
+            exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
+                excluded from compression. If False, messages that match the filter will be compressed.
         """
 
         if text_compressor is None:
@@ -303,6 +323,8 @@ class TextMessageCompressor:
         self._text_compressor = text_compressor
         self._min_tokens = min_tokens
         self._compression_args = compression_params
+        self._filter_dict = filter_dict
+        self._exclude_filter = exclude_filter
         self._cache = cache
 
         # Optimizing savings calculations to optimize log generation
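The same pair of parameters is wired into `TextMessageCompressor`. A hedged sketch of the inverse filter (`exclude_filter=False` compresses only the matching messages); it assumes the optional `llmlingua` dependency is installed, since that backs the default compressor:

```python
from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor

messages = [
    {"role": "user", "content": "Background context, repeated at length. " * 30},
    {"role": "assistant", "content": "Short answer."},
]

# exclude_filter=False inverts the default: messages that MATCH the filter
# are compressed, everything else is left alone.
compressor = TextMessageCompressor(filter_dict={"role": "user"}, exclude_filter=False)
compressed = compressor.apply_transform(messages)
```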
@@ -334,7 +356,10 @@ class TextMessageCompressor:
         processed_messages = messages.copy()
         for message in processed_messages:
             # Some messages may not have content.
-            if not isinstance(message.get("content"), (str, list)):
+            if not _is_content_right_type(message.get("content")):
                 continue
 
+            if not _should_transform_message(message, self._filter_dict, self._exclude_filter):
+                continue
+
             if _is_content_text_empty(message["content"]):
@@ -397,7 +422,7 @@ class TextMessageCompressor:
         self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
     ):
         if self._cache:
-            value = (tokens_saved, json.dumps(compressed_content))
+            value = (tokens_saved, compressed_content)
            self._cache.set(self._cache_key(content), value)
 
     def _cache_key(self, content: Union[str, List[Dict]]) -> str:
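This is the "cache fix" from the commit message: the cached value keeps the compressed content as-is (a `(tokens_saved, compressed_content)` tuple) instead of a JSON string, so a cache hit can be dropped straight back into the message. A small sketch of the round trip (hypothetical messages; `Cache.disk()` is the default anyway, and the deep copies mirror the tests below, since the transform mutates message dicts in place):

```python
import copy

from autogen.cache import Cache
from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor

messages = [{"role": "user", "content": "A long, repetitive prompt. " * 30}]

compressor = TextMessageCompressor(cache=Cache.disk())
first = compressor.apply_transform(copy.deepcopy(messages))   # compresses, then caches
second = compressor.apply_transform(copy.deepcopy(messages))  # same content: cache hit
```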
@@ -427,6 +452,10 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
     return token_count
 
 
+def _is_content_right_type(content: Any) -> bool:
+    return isinstance(content, (str, list))
+
+
 def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
     if isinstance(content, str):
         return content == ""
@@ -434,3 +463,10 @@ def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
         return all(_is_content_text_empty(item.get("text", "")) for item in content)
     else:
         return False
+
+
+def _should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
+    if not filter_dict:
+        return True
+
+    return len(filter_config([message], filter_dict, exclude)) > 0
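`_should_transform_message` reuses `filter_config` by wrapping the single message in a one-element list: with `exclude=True`, a matching message is filtered *out* of that list, so the function returns False and the transform leaves the message alone. A hedged, self-contained check of that logic (assumes this commit's `filter_config` signature):

```python
from autogen.oai.openai_utils import filter_config

msg = {"role": "user", "content": "hi"}
filter_dict = {"role": ["user"]}

# exclude=True drops the matching message -> empty list -> "do not transform"
assert filter_config([msg], filter_dict, exclude=True) == []
# exclude=False keeps the match -> non-empty list -> "transform it"
assert filter_config([msg], filter_dict, exclude=False) == [msg]
```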
@@ -379,11 +379,10 @@ def config_list_gpt4_gpt35(
 def filter_config(
     config_list: List[Dict[str, Any]],
     filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]],
+    exclude: bool = False,
 ) -> List[Dict[str, Any]]:
-    """
-    This function filters `config_list` by checking each configuration dictionary against the
-    criteria specified in `filter_dict`. A configuration dictionary is retained if for every
-    key in `filter_dict`, see example below.
+    """This function filters `config_list` by checking each configuration dictionary against the criteria specified in
+    `filter_dict`. A configuration dictionary is retained if for every key in `filter_dict`, see example below.
 
     Args:
         config_list (list of dict): A list of configuration dictionaries to be filtered.
@@ -394,71 +393,68 @@ def filter_config(
                             when it is found in the list of acceptable values. If the configuration's
                             field's value is a list, then a match occurs if there is a non-empty
                             intersection with the acceptable values.
 
 
+        exclude (bool): If False (the default value), configs that match the filter will be included in the returned
+            list. If True, configs that match the filter will be excluded in the returned list.
     Returns:
         list of dict: A list of configuration dictionaries that meet all the criteria specified
                       in `filter_dict`.
 
     Example:
-    ```python
-    # Example configuration list with various models and API types
-    configs = [
-        {'model': 'gpt-3.5-turbo'},
-        {'model': 'gpt-4'},
-        {'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
-        {'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
-    ]
-
-    # Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
-    # that are also using the 'azure' API type
-    filter_criteria = {
-        'model': ['gpt-3.5-turbo'],  # Only accept configurations for 'gpt-3.5-turbo'
-        'api_type': ['azure']  # Only accept configurations for 'azure' API type
-    }
-
-    # Apply the filter to the configuration list
-    filtered_configs = filter_config(configs, filter_criteria)
-
-    # The resulting `filtered_configs` will be:
-    # [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
-
-
-    # Define a filter to select a given tag
-    filter_criteria = {
-        'tags': ['gpt35_turbo'],
-    }
-
-    # Apply the filter to the configuration list
-    filtered_configs = filter_config(configs, filter_criteria)
-
-    # The resulting `filtered_configs` will be:
-    # [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
-    ```
-
+    ```python
+    # Example configuration list with various models and API types
+    configs = [
+        {'model': 'gpt-3.5-turbo'},
+        {'model': 'gpt-4'},
+        {'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
+        {'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
+    ]
+    # Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
+    # that are also using the 'azure' API type
+    filter_criteria = {
+        'model': ['gpt-3.5-turbo'],  # Only accept configurations for 'gpt-3.5-turbo'
+        'api_type': ['azure']  # Only accept configurations for 'azure' API type
+    }
+    # Apply the filter to the configuration list
+    filtered_configs = filter_config(configs, filter_criteria)
+    # The resulting `filtered_configs` will be:
+    # [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
+    # Define a filter to select a given tag
+    filter_criteria = {
+        'tags': ['gpt35_turbo'],
+    }
+    # Apply the filter to the configuration list
+    filtered_configs = filter_config(configs, filter_criteria)
+    # The resulting `filtered_configs` will be:
+    # [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
+    ```
     Note:
     - If `filter_dict` is empty or None, no filtering is applied and `config_list` is returned as is.
     - If a configuration dictionary in `config_list` does not contain a key specified in `filter_dict`,
       it is considered a non-match and is excluded from the result.
     - If the list of acceptable values for a key in `filter_dict` includes None, then configuration
       dictionaries that do not have that key will also be considered a match.
 
     """
 
-    def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
-        if isinstance(config_value, list):
-            return bool(set(config_value) & set(acceptable_values))  # Non-empty intersection
-        else:
-            return config_value in acceptable_values
-
-    if filter_dict:
-        config_list = [
-            config
-            for config in config_list
-            if all(_satisfies(config.get(key), value) for key, value in filter_dict.items())
+    return [
+        item
+        for item in config_list
+        if all(_satisfies_criteria(item.get(key), values) != exclude for key, values in filter_dict.items())
     ]
-    return config_list
 
 
+def _satisfies_criteria(value: Any, criteria_values: Any) -> bool:
+    if value is None:
+        return False
+
+    if isinstance(value, list):
+        return bool(set(value) & set(criteria_values))  # Non-empty intersection
+    else:
+        return value in criteria_values
 
 
 def config_list_from_json(
     env_or_file: str,
     file_location: Optional[str] = "",
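The docstring examples above only show the default `exclude=False` path; here is a short hedged sketch of the new flag, modeled on the test cases added later in this commit:

```python
from autogen.oai.openai_utils import filter_config

configs = [
    {"model": "gpt-3.5-turbo", "tags": ["gpt35"]},
    {"model": "gpt-4", "tags": ["gpt4"]},
    {"model": "gpt-4"},  # no "tags" key: never a match
]

criteria = {"tags": ["gpt35", "gpt4"]}
kept = filter_config(configs, criteria)                   # the two tagged configs
dropped = filter_config(configs, criteria, exclude=True)  # just {"model": "gpt-4"}
```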
@@ -785,3 +781,10 @@ def update_gpt_assistant(client: OpenAI, assistant_id: str, assistant_config: Di
         assistant_update_kwargs["file_ids"] = assistant_config["file_ids"]
 
     return client.beta.assistants.update(assistant_id=assistant_id, **assistant_update_kwargs)
+
+
+def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
+    if isinstance(config_value, list):
+        return bool(set(config_value) & set(acceptable_values))  # Non-empty intersection
+    else:
+        return config_value in acceptable_values
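`_satisfies` (like its sibling `_satisfies_criteria`) branches on the config field's type: list-valued fields match on a non-empty intersection, scalar fields on plain membership. A standalone restatement, duplicated from the hunk above so it runs without autogen installed:

```python
from typing import Any

def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
    if isinstance(config_value, list):
        return bool(set(config_value) & set(acceptable_values))  # non-empty intersection
    return config_value in acceptable_values

assert _satisfies(["gpt35", "gpt4"], ["gpt35"])  # list field: intersection matches
assert not _satisfies("azure", ["openai"])       # scalar field: membership fails
```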
@@ -1,10 +1,21 @@
 import copy
-from typing import Dict, List
+from typing import Any, Dict, List
 from unittest.mock import MagicMock, patch
 
 import pytest
 
-from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter, _count_tokens
+from autogen.agentchat.contrib.capabilities.text_compressors import TextCompressor
+from autogen.agentchat.contrib.capabilities.transforms import (
+    MessageHistoryLimiter,
+    MessageTokenLimiter,
+    TextMessageCompressor,
+    _count_tokens,
+)
+
+
+class _MockTextCompressor:
+    def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
+        return {"compressed_prompt": ""}
 
 
 def get_long_messages() -> List[Dict]:
@@ -29,6 +40,18 @@ def get_no_content_messages() -> List[Dict]:
     return [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]
 
 
+def get_text_compressors() -> List[TextCompressor]:
+    compressors: List[TextCompressor] = [_MockTextCompressor()]
+    try:
+        from autogen.agentchat.contrib.capabilities.text_compressors import LLMLingua
+
+        compressors.append(LLMLingua())
+    except ImportError:
+        pass
+
+    return compressors
+
+
 @pytest.fixture
 def message_history_limiter() -> MessageHistoryLimiter:
     return MessageHistoryLimiter(max_messages=3)
@@ -44,6 +67,30 @@ def message_token_limiter_with_threshold() -> MessageTokenLimiter:
     return MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10)
 
 
+def _filter_dict_test(
+    post_transformed_message: Dict, pre_transformed_messages: Dict, roles: List[str], exclude_filter: bool
+) -> bool:
+    is_role = post_transformed_message["role"] in roles
+    if exclude_filter:
+        is_role = not is_role
+
+    if isinstance(post_transformed_message["content"], list):
+        condition = (
+            len(post_transformed_message["content"][0]["text"]) < len(pre_transformed_messages["content"][0]["text"])
+            if is_role
+            else len(post_transformed_message["content"][0]["text"])
+            == len(pre_transformed_messages["content"][0]["text"])
+        )
+    else:
+        condition = (
+            len(post_transformed_message["content"]) < len(pre_transformed_messages["content"])
+            if is_role
+            else len(post_transformed_message["content"]) == len(pre_transformed_messages["content"])
+        )
+
+    return condition
+
+
 # MessageHistoryLimiter
 
 
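A note on the `_filter_dict_test` helper just added: with `exclude_filter=True` it expects matching roles to be *unchanged* and every other role to have shrunk, because `is_role` is inverted before the length checks. Two assumed cases spelling that out:

```python
# With roles=["user"] and exclude_filter=True:
pre_user = {"role": "user", "content": "x" * 40}
post_user = {"role": "user", "content": "x" * 40}  # untouched -> passes
assert _filter_dict_test(post_user, pre_user, ["user"], exclude_filter=True)

pre_asst = {"role": "assistant", "content": "x" * 40}
post_asst = {"role": "assistant", "content": "x" * 10}  # shrunk -> passes
assert _filter_dict_test(post_asst, pre_asst, ["user"], exclude_filter=True)
```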
@@ -82,13 +129,35 @@ def test_message_history_limiter_get_logs(message_history_limiter, messages, exp
 def test_message_token_limiter_apply_transform(
     message_token_limiter, messages, expected_token_count, expected_messages_len
 ):
-    transformed_messages = message_token_limiter.apply_transform(messages)
+    transformed_messages = message_token_limiter.apply_transform(copy.deepcopy(messages))
     assert (
         sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
     )
     assert len(transformed_messages) == expected_messages_len
 
 
+@pytest.mark.parametrize("messages", [get_long_messages(), get_short_messages()])
+def test_message_token_limiter_with_filter(messages):
+    # Test truncating all messages except for user
+    message_token_limiter = MessageTokenLimiter(max_tokens_per_message=0, filter_dict={"role": "user"})
+    transformed_messages = message_token_limiter.apply_transform(copy.deepcopy(messages))
+
+    pre_post_messages = zip(messages, transformed_messages)
+
+    for pre_transform, post_transform in pre_post_messages:
+        assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=True)
+
+    # Test truncating all user messages only
+    message_token_limiter = MessageTokenLimiter(
+        max_tokens_per_message=0, filter_dict={"role": "user"}, exclude_filter=False
+    )
+    transformed_messages = message_token_limiter.apply_transform(copy.deepcopy(messages))
+
+    pre_post_messages = zip(messages, transformed_messages)
+    for pre_transform, post_transform in pre_post_messages:
+        assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=False)
+
+
 @pytest.mark.parametrize(
     "messages, expected_token_count, expected_messages_len",
     [(get_long_messages(), 5, 5), (get_short_messages(), 5, 3), (get_no_content_messages(), 0, 2)],
@@ -119,49 +188,60 @@ def test_message_token_limiter_get_logs(message_token_limiter, messages, expecte
     assert logs_str == expected_logs
 
 
-def test_text_compression():
-    """Test the TextMessageCompressor transform."""
-    try:
-        from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor
+# TextMessageCompressor tests
 
-        text_compressor = TextMessageCompressor()
-    except ImportError:
-        pytest.skip("LLM Lingua is not installed.")
 
+@pytest.mark.parametrize("text_compressor", get_text_compressors())
+def test_text_compression(text_compressor):
+    """Test the TextMessageCompressor transform."""
+
+    compressor = TextMessageCompressor(text_compressor=text_compressor)
 
     text = "Run this test with a long string. "
     messages = [
-        {
-            "role": "assistant",
-            "content": [{"type": "text", "text": "".join([text] * 3)}],
-        },
-        {
-            "role": "assistant",
-            "content": [{"type": "text", "text": "".join([text] * 3)}],
-        },
-        {
-            "role": "assistant",
-            "content": [{"type": "text", "text": "".join([text] * 3)}],
-        },
+        {"role": "assistant", "content": [{"type": "text", "text": "".join([text] * 3)}]},
+        {"role": "role", "content": [{"type": "text", "text": "".join([text] * 3)}]},
+        {"role": "assistant", "content": [{"type": "text", "text": "".join([text] * 3)}]},
+        {"role": "assistant", "content": [{"type": "text", "text": "".join([text] * 3)}]},
     ]
 
-    transformed_messages = text_compressor.apply_transform([{"content": text}])
+    transformed_messages = compressor.apply_transform([{"content": text}])
 
     assert len(transformed_messages[0]["content"]) < len(text)
 
     # Test compressing all messages
-    text_compressor = TextMessageCompressor()
-    transformed_messages = text_compressor.apply_transform(copy.deepcopy(messages))
-    for message in transformed_messages:
-        assert len(message["content"][0]["text"]) < len(messages[0]["content"][0]["text"])
+    compressor = TextMessageCompressor(text_compressor=text_compressor)
+    transformed_messages = compressor.apply_transform(copy.deepcopy(messages))
+
+    pre_post_messages = zip(messages, transformed_messages)
+    for pre_transform, post_transform in pre_post_messages:
+        assert len(post_transform["content"][0]["text"]) < len(pre_transform["content"][0]["text"])
 
 
-def test_text_compression_cache():
-    try:
-        from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor
+@pytest.mark.parametrize("messages", [get_long_messages(), get_short_messages()])
+@pytest.mark.parametrize("text_compressor", get_text_compressors())
+def test_text_compression_with_filter(messages, text_compressor):
+    # Test truncating all messages except for user
+    compressor = TextMessageCompressor(text_compressor=text_compressor, filter_dict={"role": "user"})
+    transformed_messages = compressor.apply_transform(copy.deepcopy(messages))
 
-    except ImportError:
-        pytest.skip("LLM Lingua is not installed.")
+    pre_post_messages = zip(messages, transformed_messages)
+    for pre_transform, post_transform in pre_post_messages:
+        assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=True)
+
+    # Test truncating all user messages only
+    compressor = TextMessageCompressor(
+        text_compressor=text_compressor, filter_dict={"role": "user"}, exclude_filter=False
+    )
+    transformed_messages = compressor.apply_transform(copy.deepcopy(messages))
+
+    pre_post_messages = zip(messages, transformed_messages)
+    for pre_transform, post_transform in pre_post_messages:
+        assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=False)
+
+
+@pytest.mark.parametrize("text_compressor", get_text_compressors())
+def test_text_compression_cache(text_compressor):
     messages = get_long_messages()
     mock_compressed_content = (1, {"content": "mock"})
 
@@ -171,18 +251,18 @@ def test_text_compression_cache():
     ) as mocked_get, patch(
         "autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_set", MagicMock()
     ) as mocked_set:
-        text_compressor = TextMessageCompressor()
+        compressor = TextMessageCompressor(text_compressor=text_compressor)
 
-        text_compressor.apply_transform(messages)
-        text_compressor.apply_transform(messages)
+        compressor.apply_transform(messages)
+        compressor.apply_transform(messages)
 
         assert mocked_get.call_count == len(messages)
         assert mocked_set.call_count == len(messages)
 
     # We already populated the cache with the mock content
     # We need to test if we retrieve the correct content
-    text_compressor = TextMessageCompressor()
-    compressed_messages = text_compressor.apply_transform(messages)
+    compressor = TextMessageCompressor(text_compressor=text_compressor)
+    compressed_messages = compressor.apply_transform(messages)
 
     for message in compressed_messages:
         assert message["content"] == mock_compressed_content[1]
@@ -4,6 +4,7 @@ import json
 import logging
 import os
 import tempfile
+from typing import Dict, List
 from unittest import mock
 from unittest.mock import patch
 
@@ -43,11 +44,13 @@ JSON_SAMPLE = """
 [
     {
         "model": "gpt-3.5-turbo",
-        "api_type": "openai"
+        "api_type": "openai",
+        "tags": ["gpt35"]
     },
     {
         "model": "gpt-4",
-        "api_type": "openai"
+        "api_type": "openai",
+        "tags": ["gpt4"]
     },
     {
         "model": "gpt-35-turbo-v0301",
@@ -65,6 +68,33 @@ JSON_SAMPLE = """
 ]
 """
 
+JSON_SAMPLE_DICT = json.loads(JSON_SAMPLE)
+
+
+FILTER_CONFIG_TEST = [
+    {
+        "filter_dict": {"tags": ["gpt35", "gpt4"]},
+        "exclude": False,
+        "expected": JSON_SAMPLE_DICT[0:2],
+    },
+    {
+        "filter_dict": {"tags": ["gpt35", "gpt4"]},
+        "exclude": True,
+        "expected": JSON_SAMPLE_DICT[2:4],
+    },
+    {
+        "filter_dict": {"api_type": "azure", "api_version": "2024-02-15-preview"},
+        "exclude": False,
+        "expected": [JSON_SAMPLE_DICT[2]],
+    },
+]
+
+
+def _compare_lists_of_dicts(list1: List[Dict], list2: List[Dict]) -> bool:
+    dump1 = sorted(json.dumps(d, sort_keys=True) for d in list1)
+    dump2 = sorted(json.dumps(d, sort_keys=True) for d in list2)
+    return dump1 == dump2
+
+
 @pytest.fixture
 def mock_os_environ():
@@ -72,6 +102,17 @@ def mock_os_environ():
     yield
 
 
+@pytest.mark.parametrize("test_case", FILTER_CONFIG_TEST)
+def test_filter_config(test_case):
+    filter_dict = test_case["filter_dict"]
+    exclude = test_case["exclude"]
+    expected = test_case["expected"]
+
+    config_list = filter_config(JSON_SAMPLE_DICT, filter_dict, exclude)
+
+    assert _compare_lists_of_dicts(config_list, expected)
+
+
 def test_config_list_from_json():
     with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp_file:
         json_data = json.loads(JSON_SAMPLE)