Ignore Some Messages When Transforming (#2661)

* works

* spelling

* returned old docstring

* add cache fix

* spelling?

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Wael Karkoub
2024-05-22 21:22:17 +01:00
committed by GitHub
parent 3e11b07d1d
commit 4ebfb82186
4 changed files with 253 additions and 93 deletions

View File

@@ -8,6 +8,7 @@ from termcolor import colored
from autogen import token_count_utils
from autogen.cache import AbstractCache, Cache
from autogen.oai.openai_utils import filter_config
from .text_compressors import LLMLingua, TextCompressor
@@ -130,6 +131,8 @@ class MessageTokenLimiter:
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
model: str = "gpt-3.5-turbo-0613",
filter_dict: Optional[Dict] = None,
exclude_filter: bool = True,
):
"""
Args:
@@ -140,11 +143,17 @@ class MessageTokenLimiter:
min_tokens (Optional[int]): Minimum number of tokens in messages to apply the transformation.
Must be greater than or equal to 0 if not None.
model (str): The target OpenAI model for tokenization alignment.
filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
If None, no filters will be applied.
exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
excluded from token truncation. If False, messages that match the filter will be truncated.
"""
self._model = model
self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
self._max_tokens = self._validate_max_tokens(max_tokens)
self._min_tokens = self._validate_min_tokens(min_tokens, max_tokens)
self._filter_dict = filter_dict
self._exclude_filter = exclude_filter
def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies token truncation to the conversation history.
@@ -169,10 +178,15 @@ class MessageTokenLimiter:
for msg in reversed(temp_messages):
# Some messages may not have content.
if not isinstance(msg.get("content"), (str, list)):
if not _is_content_right_type(msg.get("content")):
processed_messages.insert(0, msg)
continue
if not _should_transform_message(msg, self._filter_dict, self._exclude_filter):
processed_messages.insert(0, msg)
processed_messages_tokens += _count_tokens(msg["content"])
continue
expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
# If adding this message would exceed the token limit, truncate the last message to meet the total token
@@ -282,6 +296,8 @@ class TextMessageCompressor:
min_tokens: Optional[int] = None,
compression_params: Dict = dict(),
cache: Optional[AbstractCache] = Cache.disk(),
filter_dict: Optional[Dict] = None,
exclude_filter: bool = True,
):
"""
Args:
@@ -293,6 +309,10 @@ class TextMessageCompressor:
dictionary.
cache (None or AbstractCache): The cache client to use to store and retrieve previously compressed messages.
If None, no caching will be used.
filter_dict (None or dict): A dictionary to filter out messages that you want/don't want to compress.
If None, no filters will be applied.
exclude_filter (bool): If exclude filter is True (the default value), messages that match the filter will be
excluded from compression. If False, messages that match the filter will be compressed.
"""
if text_compressor is None:
@@ -303,6 +323,8 @@ class TextMessageCompressor:
self._text_compressor = text_compressor
self._min_tokens = min_tokens
self._compression_args = compression_params
self._filter_dict = filter_dict
self._exclude_filter = exclude_filter
self._cache = cache
# Optimizing savings calculations to optimize log generation
@@ -334,7 +356,10 @@ class TextMessageCompressor:
processed_messages = messages.copy()
for message in processed_messages:
# Some messages may not have content.
if not isinstance(message.get("content"), (str, list)):
if not _is_content_right_type(message.get("content")):
continue
if not _should_transform_message(message, self._filter_dict, self._exclude_filter):
continue
if _is_content_text_empty(message["content"]):
@@ -397,7 +422,7 @@ class TextMessageCompressor:
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
):
if self._cache:
value = (tokens_saved, json.dumps(compressed_content))
value = (tokens_saved, compressed_content)
self._cache.set(self._cache_key(content), value)
def _cache_key(self, content: Union[str, List[Dict]]) -> str:
@@ -427,6 +452,10 @@ def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
return token_count
def _is_content_right_type(content: Any) -> bool:
return isinstance(content, (str, list))
def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
if isinstance(content, str):
return content == ""
@@ -434,3 +463,10 @@ def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
return all(_is_content_text_empty(item.get("text", "")) for item in content)
else:
return False
def _should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
    """Return True when `message` survives the filter and should be transformed.

    With no filter configured every message is transformed; otherwise the
    message is transformed only if `filter_config` keeps it (honoring `exclude`).
    """
    if filter_dict:
        return bool(filter_config([message], filter_dict, exclude))
    return True

View File

@@ -379,11 +379,10 @@ def config_list_gpt4_gpt35(
def filter_config(
config_list: List[Dict[str, Any]],
filter_dict: Optional[Dict[str, Union[List[Union[str, None]], Set[Union[str, None]]]]],
exclude: bool = False,
) -> List[Dict[str, Any]]:
"""
This function filters `config_list` by checking each configuration dictionary against the
criteria specified in `filter_dict`. A configuration dictionary is retained if for every
key in `filter_dict`, see example below.
"""This function filters `config_list` by checking each configuration dictionary against the criteria specified in
`filter_dict`. A configuration dictionary is retained if for every key in `filter_dict`, see example below.
Args:
config_list (list of dict): A list of configuration dictionaries to be filtered.
@@ -394,71 +393,68 @@ def filter_config(
when it is found in the list of acceptable values. If the configuration's
field's value is a list, then a match occurs if there is a non-empty
intersection with the acceptable values.
exclude (bool): If False (the default value), configs that match the filter will be included in the returned
list. If True, configs that match the filter will be excluded in the returned list.
Returns:
list of dict: A list of configuration dictionaries that meet all the criteria specified
in `filter_dict`.
Example:
```python
# Example configuration list with various models and API types
configs = [
{'model': 'gpt-3.5-turbo'},
{'model': 'gpt-4'},
{'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
]
# Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
# that are also using the 'azure' API type
filter_criteria = {
'model': ['gpt-3.5-turbo'], # Only accept configurations for 'gpt-3.5-turbo'
'api_type': ['azure'] # Only accept configurations for 'azure' API type
}
# Apply the filter to the configuration list
filtered_configs = filter_config(configs, filter_criteria)
# The resulting `filtered_configs` will be:
# [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
# Define a filter to select a given tag
filter_criteria = {
'tags': ['gpt35_turbo'],
}
# Apply the filter to the configuration list
filtered_configs = filter_config(configs, filter_criteria)
# The resulting `filtered_configs` will be:
# [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
```
```python
# Example configuration list with various models and API types
configs = [
{'model': 'gpt-3.5-turbo'},
{'model': 'gpt-4'},
{'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']},
]
# Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
# that are also using the 'azure' API type
filter_criteria = {
'model': ['gpt-3.5-turbo'], # Only accept configurations for 'gpt-3.5-turbo'
'api_type': ['azure'] # Only accept configurations for 'azure' API type
}
# Apply the filter to the configuration list
filtered_configs = filter_config(configs, filter_criteria)
# The resulting `filtered_configs` will be:
# [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
# Define a filter to select a given tag
filter_criteria = {
'tags': ['gpt35_turbo'],
}
# Apply the filter to the configuration list
filtered_configs = filter_config(configs, filter_criteria)
# The resulting `filtered_configs` will be:
# [{'model': 'gpt-3.5-turbo', 'tags': ['gpt35_turbo', 'gpt-35-turbo']}]
```
Note:
- If `filter_dict` is empty or None, no filtering is applied and `config_list` is returned as is.
- If a configuration dictionary in `config_list` does not contain a key specified in `filter_dict`,
it is considered a non-match and is excluded from the result.
- If the list of acceptable values for a key in `filter_dict` includes None, then configuration
dictionaries that do not have that key will also be considered a match.
"""
def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
if isinstance(config_value, list):
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
else:
return config_value in acceptable_values
if filter_dict:
config_list = [
config
for config in config_list
if all(_satisfies(config.get(key), value) for key, value in filter_dict.items())
return [
item
for item in config_list
if all(_satisfies_criteria(item.get(key), values) != exclude for key, values in filter_dict.items())
]
return config_list
def _satisfies_criteria(value: Any, criteria_values: Any) -> bool:
if value is None:
return False
if isinstance(value, list):
return bool(set(value) & set(criteria_values)) # Non-empty intersection
else:
return value in criteria_values
def config_list_from_json(
env_or_file: str,
file_location: Optional[str] = "",
@@ -785,3 +781,10 @@ def update_gpt_assistant(client: OpenAI, assistant_id: str, assistant_config: Di
assistant_update_kwargs["file_ids"] = assistant_config["file_ids"]
return client.beta.assistants.update(assistant_id=assistant_id, **assistant_update_kwargs)
def _satisfies(config_value: Any, acceptable_values: Any) -> bool:
if isinstance(config_value, list):
return bool(set(config_value) & set(acceptable_values)) # Non-empty intersection
else:
return config_value in acceptable_values

View File

@@ -1,10 +1,21 @@
import copy
from typing import Dict, List
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch
import pytest
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter, _count_tokens
from autogen.agentchat.contrib.capabilities.text_compressors import TextCompressor
from autogen.agentchat.contrib.capabilities.transforms import (
MessageHistoryLimiter,
MessageTokenLimiter,
TextMessageCompressor,
_count_tokens,
)
class _MockTextCompressor:
def compress_text(self, text: str, **compression_params) -> Dict[str, Any]:
return {"compressed_prompt": ""}
def get_long_messages() -> List[Dict]:
@@ -29,6 +40,18 @@ def get_no_content_messages() -> List[Dict]:
return [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]
def get_text_compressors() -> List[TextCompressor]:
    """Return the compressors to test: the mock, plus LLMLingua when it is installed."""
    available: List[TextCompressor] = [_MockTextCompressor()]
    try:
        from autogen.agentchat.contrib.capabilities.text_compressors import LLMLingua
    except ImportError:
        # Optional dependency missing — test only against the mock.
        pass
    else:
        available.append(LLMLingua())
    return available
@pytest.fixture
def message_history_limiter() -> MessageHistoryLimiter:
return MessageHistoryLimiter(max_messages=3)
@@ -44,6 +67,30 @@ def message_token_limiter_with_threshold() -> MessageTokenLimiter:
return MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10)
def _filter_dict_test(
post_transformed_message: Dict, pre_transformed_messages: Dict, roles: List[str], exclude_filter: bool
) -> bool:
is_role = post_transformed_message["role"] in roles
if exclude_filter:
is_role = not is_role
if isinstance(post_transformed_message["content"], list):
condition = (
len(post_transformed_message["content"][0]["text"]) < len(pre_transformed_messages["content"][0]["text"])
if is_role
else len(post_transformed_message["content"][0]["text"])
== len(pre_transformed_messages["content"][0]["text"])
)
else:
condition = (
len(post_transformed_message["content"]) < len(pre_transformed_messages["content"])
if is_role
else len(post_transformed_message["content"]) == len(pre_transformed_messages["content"])
)
return condition
# MessageHistoryLimiter
@@ -82,13 +129,35 @@ def test_message_history_limiter_get_logs(message_history_limiter, messages, exp
def test_message_token_limiter_apply_transform(
message_token_limiter, messages, expected_token_count, expected_messages_len
):
transformed_messages = message_token_limiter.apply_transform(messages)
transformed_messages = message_token_limiter.apply_transform(copy.deepcopy(messages))
assert (
sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
)
assert len(transformed_messages) == expected_messages_len
@pytest.mark.parametrize("messages", [get_long_messages(), get_short_messages()])
def test_message_token_limiter_with_filter(messages):
    """Token truncation honors filter_dict in both exclude and include modes."""
    # Exclude mode (default): every message except "user" ones is truncated.
    limiter = MessageTokenLimiter(max_tokens_per_message=0, filter_dict={"role": "user"})
    transformed = limiter.apply_transform(copy.deepcopy(messages))
    for original, result in zip(messages, transformed):
        assert _filter_dict_test(result, original, ["user"], exclude_filter=True)

    # Include mode: only "user" messages are truncated.
    limiter = MessageTokenLimiter(
        max_tokens_per_message=0, filter_dict={"role": "user"}, exclude_filter=False
    )
    transformed = limiter.apply_transform(copy.deepcopy(messages))
    for original, result in zip(messages, transformed):
        assert _filter_dict_test(result, original, ["user"], exclude_filter=False)
@pytest.mark.parametrize(
"messages, expected_token_count, expected_messages_len",
[(get_long_messages(), 5, 5), (get_short_messages(), 5, 3), (get_no_content_messages(), 0, 2)],
@@ -119,49 +188,60 @@ def test_message_token_limiter_get_logs(message_token_limiter, messages, expecte
assert logs_str == expected_logs
def test_text_compression():
"""Test the TextMessageCompressor transform."""
try:
from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor
# TextMessageCompressor tests
text_compressor = TextMessageCompressor()
except ImportError:
pytest.skip("LLM Lingua is not installed.")
@pytest.mark.parametrize("text_compressor", get_text_compressors())
def test_text_compression(text_compressor):
"""Test the TextMessageCompressor transform."""
compressor = TextMessageCompressor(text_compressor=text_compressor)
text = "Run this test with a long string. "
messages = [
{
"role": "assistant",
"content": [{"type": "text", "text": "".join([text] * 3)}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "".join([text] * 3)}],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "".join([text] * 3)}],
},
{"role": "assistant", "content": [{"type": "text", "text": "".join([text] * 3)}]},
{"role": "role", "content": [{"type": "text", "text": "".join([text] * 3)}]},
{"role": "assistant", "content": [{"type": "text", "text": "".join([text] * 3)}]},
{"role": "assistant", "content": [{"type": "text", "text": "".join([text] * 3)}]},
]
transformed_messages = text_compressor.apply_transform([{"content": text}])
transformed_messages = compressor.apply_transform([{"content": text}])
assert len(transformed_messages[0]["content"]) < len(text)
# Test compressing all messages
text_compressor = TextMessageCompressor()
transformed_messages = text_compressor.apply_transform(copy.deepcopy(messages))
for message in transformed_messages:
assert len(message["content"][0]["text"]) < len(messages[0]["content"][0]["text"])
compressor = TextMessageCompressor(text_compressor=text_compressor)
transformed_messages = compressor.apply_transform(copy.deepcopy(messages))
pre_post_messages = zip(messages, transformed_messages)
for pre_transform, post_transform in pre_post_messages:
assert len(post_transform["content"][0]["text"]) < len(pre_transform["content"][0]["text"])
def test_text_compression_cache():
try:
from autogen.agentchat.contrib.capabilities.transforms import TextMessageCompressor
@pytest.mark.parametrize("messages", [get_long_messages(), get_short_messages()])
@pytest.mark.parametrize("text_compressor", get_text_compressors())
def test_text_compression_with_filter(messages, text_compressor):
# Test truncating all messages except for user
compressor = TextMessageCompressor(text_compressor=text_compressor, filter_dict={"role": "user"})
transformed_messages = compressor.apply_transform(copy.deepcopy(messages))
except ImportError:
pytest.skip("LLM Lingua is not installed.")
pre_post_messages = zip(messages, transformed_messages)
for pre_transform, post_transform in pre_post_messages:
assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=True)
# Test truncating all user messages only
compressor = TextMessageCompressor(
text_compressor=text_compressor, filter_dict={"role": "user"}, exclude_filter=False
)
transformed_messages = compressor.apply_transform(copy.deepcopy(messages))
pre_post_messages = zip(messages, transformed_messages)
for pre_transform, post_transform in pre_post_messages:
assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=False)
@pytest.mark.parametrize("text_compressor", get_text_compressors())
def test_text_compression_cache(text_compressor):
messages = get_long_messages()
mock_compressed_content = (1, {"content": "mock"})
@@ -171,18 +251,18 @@ def test_text_compression_cache():
) as mocked_get, patch(
"autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_set", MagicMock()
) as mocked_set:
text_compressor = TextMessageCompressor()
compressor = TextMessageCompressor(text_compressor=text_compressor)
text_compressor.apply_transform(messages)
text_compressor.apply_transform(messages)
compressor.apply_transform(messages)
compressor.apply_transform(messages)
assert mocked_get.call_count == len(messages)
assert mocked_set.call_count == len(messages)
# We already populated the cache with the mock content
# We need to test if we retrieve the correct content
text_compressor = TextMessageCompressor()
compressed_messages = text_compressor.apply_transform(messages)
compressor = TextMessageCompressor(text_compressor=text_compressor)
compressed_messages = compressor.apply_transform(messages)
for message in compressed_messages:
assert message["content"] == mock_compressed_content[1]

View File

@@ -4,6 +4,7 @@ import json
import logging
import os
import tempfile
from typing import Dict, List
from unittest import mock
from unittest.mock import patch
@@ -43,11 +44,13 @@ JSON_SAMPLE = """
[
{
"model": "gpt-3.5-turbo",
"api_type": "openai"
"api_type": "openai",
"tags": ["gpt35"]
},
{
"model": "gpt-4",
"api_type": "openai"
"api_type": "openai",
"tags": ["gpt4"]
},
{
"model": "gpt-35-turbo-v0301",
@@ -65,6 +68,33 @@ JSON_SAMPLE = """
]
"""
# Parsed form of the JSON_SAMPLE fixture string defined above.
JSON_SAMPLE_DICT = json.loads(JSON_SAMPLE)

# Table-driven cases for test_filter_config: each entry pairs a filter_dict and
# an exclude flag with the subset of JSON_SAMPLE_DICT expected to be returned.
FILTER_CONFIG_TEST = [
    {
        "filter_dict": {"tags": ["gpt35", "gpt4"]},
        "exclude": False,
        "expected": JSON_SAMPLE_DICT[0:2],
    },
    {
        "filter_dict": {"tags": ["gpt35", "gpt4"]},
        "exclude": True,
        "expected": JSON_SAMPLE_DICT[2:4],
    },
    {
        # NOTE(review): these values are bare strings rather than lists, so the
        # match goes through Python's substring `in` check — confirm intended.
        "filter_dict": {"api_type": "azure", "api_version": "2024-02-15-preview"},
        "exclude": False,
        "expected": [JSON_SAMPLE_DICT[2]],
    },
]
def _compare_lists_of_dicts(list1: List[Dict], list2: List[Dict]) -> bool:
dump1 = sorted(json.dumps(d, sort_keys=True) for d in list1)
dump2 = sorted(json.dumps(d, sort_keys=True) for d in list2)
return dump1 == dump2
@pytest.fixture
def mock_os_environ():
@@ -72,6 +102,17 @@ def mock_os_environ():
yield
@pytest.mark.parametrize("test_case", FILTER_CONFIG_TEST)
def test_filter_config(test_case):
    """filter_config returns exactly the configs selected by filter_dict and exclude."""
    actual = filter_config(JSON_SAMPLE_DICT, test_case["filter_dict"], test_case["exclude"])
    assert _compare_lists_of_dicts(actual, test_case["expected"])
def test_config_list_from_json():
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tmp_file:
json_data = json.loads(JSON_SAMPLE)