[Refactor] Transforms Utils (#2863)

* wip

* tests + docstrings

* improves tests

* fix import
This commit is contained in:
Wael Karkoub
2024-06-06 22:49:22 +01:00
committed by GitHub
parent 102d36d98f
commit 8564bd4c48
4 changed files with 221 additions and 82 deletions

View File

@@ -1,5 +1,4 @@
import copy
import json
import sys
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
@@ -8,8 +7,9 @@ from termcolor import colored
from autogen import token_count_utils
from autogen.cache import AbstractCache, Cache
from autogen.oai.openai_utils import filter_config
from autogen.types import MessageContentType
from . import transforms_util
from .text_compressors import LLMLingua, TextCompressor
@@ -169,7 +169,7 @@ class MessageTokenLimiter:
assert self._min_tokens is not None
# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not _min_tokens_reached(messages, self._min_tokens):
if not transforms_util.min_tokens_reached(messages, self._min_tokens):
return messages
temp_messages = copy.deepcopy(messages)
@@ -178,13 +178,13 @@ class MessageTokenLimiter:
for msg in reversed(temp_messages):
# Some messages may not have content.
if not _is_content_right_type(msg.get("content")):
if not transforms_util.is_content_right_type(msg.get("content")):
processed_messages.insert(0, msg)
continue
if not _should_transform_message(msg, self._filter_dict, self._exclude_filter):
if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
processed_messages.insert(0, msg)
processed_messages_tokens += _count_tokens(msg["content"])
processed_messages_tokens += transforms_util.count_text_tokens(msg["content"])
continue
expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
@@ -199,7 +199,7 @@ class MessageTokenLimiter:
break
msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
msg_tokens = _count_tokens(msg["content"])
msg_tokens = transforms_util.count_text_tokens(msg["content"])
# prepend the message to the list to preserve order
processed_messages_tokens += msg_tokens
@@ -209,10 +209,10 @@ class MessageTokenLimiter:
def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
pre_transform_messages_tokens = sum(
_count_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
)
post_transform_messages_tokens = sum(
_count_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
)
if post_transform_messages_tokens < pre_transform_messages_tokens:
@@ -349,31 +349,32 @@ class TextMessageCompressor:
return messages
# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
if not _min_tokens_reached(messages, self._min_tokens):
if not transforms_util.min_tokens_reached(messages, self._min_tokens):
return messages
total_savings = 0
processed_messages = messages.copy()
for message in processed_messages:
# Some messages may not have content.
if not _is_content_right_type(message.get("content")):
if not transforms_util.is_content_right_type(message.get("content")):
continue
if not _should_transform_message(message, self._filter_dict, self._exclude_filter):
if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
continue
if _is_content_text_empty(message["content"]):
if transforms_util.is_content_text_empty(message["content"]):
continue
cached_content = self._cache_get(message["content"])
cache_key = transforms_util.cache_key(message["content"], self._min_tokens)
cached_content = transforms_util.cache_content_get(self._cache, cache_key)
if cached_content is not None:
savings, compressed_content = cached_content
message["content"], savings = cached_content
else:
savings, compressed_content = self._compress(message["content"])
message["content"], savings = self._compress(message["content"])
self._cache_set(message["content"], compressed_content, savings)
transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings)
message["content"] = compressed_content
assert isinstance(savings, int)
total_savings += savings
self._recent_tokens_savings = total_savings
@@ -385,24 +386,29 @@ class TextMessageCompressor:
else:
return "No tokens saved with text compression.", False
def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
def _compress(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
"""Compresses the given text or multimodal content using the specified compression method."""
if isinstance(content, str):
return self._compress_text(content)
elif isinstance(content, list):
return self._compress_multimodal(content)
else:
return 0, content
return content, 0
def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
def _compress_multimodal(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
tokens_saved = 0
for msg in content:
if "text" in msg:
savings, msg["text"] = self._compress_text(msg["text"])
for item in content:
if isinstance(item, dict) and "text" in item:
item["text"], savings = self._compress_text(item["text"])
tokens_saved += savings
return tokens_saved, content
def _compress_text(self, text: str) -> Tuple[int, str]:
elif isinstance(item, str):
item, savings = self._compress_text(item)
tokens_saved += savings
return content, tokens_saved
def _compress_text(self, text: str) -> Tuple[str, int]:
"""Compresses the given text using the specified compression method."""
compressed_text = self._text_compressor.compress_text(text, **self._compression_args)
@@ -410,63 +416,8 @@ class TextMessageCompressor:
if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]
return savings, compressed_text["compressed_prompt"]
def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
if self._cache:
cached_value = self._cache.get(self._cache_key(content))
if cached_value:
return cached_value
def _cache_set(
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
):
if self._cache:
value = (tokens_saved, compressed_content)
self._cache.set(self._cache_key(content), value)
def _cache_key(self, content: Union[str, List[Dict]]) -> str:
return f"{json.dumps(content)}_{self._min_tokens}"
return compressed_text["compressed_prompt"], savings
def _validate_min_tokens(self, min_tokens: Optional[int]):
if min_tokens is not None and min_tokens <= 0:
raise ValueError("min_tokens must be greater than 0 or None")
def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
if not min_tokens:
return True
messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= min_tokens
def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
token_count = 0
if isinstance(content, str):
token_count = token_count_utils.count_token(content)
elif isinstance(content, list):
for item in content:
token_count += _count_tokens(item.get("text", ""))
return token_count
def _is_content_right_type(content: Any) -> bool:
return isinstance(content, (str, list))
def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
return all(_is_content_text_empty(item.get("text", "")) for item in content)
else:
return False
def _should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
if not filter_dict:
return True
return len(filter_config([message], filter_dict, exclude)) > 0

View File

@@ -0,0 +1,114 @@
from typing import Any, Dict, Hashable, List, Optional, Tuple
from autogen import token_count_utils
from autogen.cache.abstract_cache_base import AbstractCache
from autogen.oai.openai_utils import filter_config
from autogen.types import MessageContentType
def cache_key(content: MessageContentType, *args: Hashable) -> str:
"""Calculates the cache key for the given message content and any other hashable args.
Args:
content (MessageContentType): The message content to calculate the cache key for.
*args: Any additional hashable args to include in the cache key.
"""
str_keys = [str(key) for key in (content, *args)]
return "".join(str_keys)
def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[Tuple[MessageContentType, ...]]:
"""Retrieves cachedd content from the cache.
Args:
cache (None or AbstractCache): The cache to retrieve the content from. If None, the cache is ignored.
key (str): The key to retrieve the content from.
"""
if cache:
cached_value = cache.get(key)
if cached_value:
return cached_value
def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values):
"""Sets content into the cache.
Args:
cache (None or AbstractCache): The cache to set the content into. If None, the cache is ignored.
key (str): The key to set the content into.
content (MessageContentType): The message content to set into the cache.
*extra_values: Additional values to be passed to the cache.
"""
if cache:
cache_value = (content, *extra_values)
cache.set(key, cache_value)
def min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value.
Args:
messages (List[Dict]): A list of messages to check.
"""
if not min_tokens:
return True
messages_tokens = sum(count_text_tokens(msg["content"]) for msg in messages if "content" in msg)
return messages_tokens >= min_tokens
def count_text_tokens(content: MessageContentType) -> int:
"""Calculates the number of text tokens in the given message content.
Args:
content (MessageContentType): The message content to calculate the number of text tokens for.
"""
token_count = 0
if isinstance(content, str):
token_count = token_count_utils.count_token(content)
elif isinstance(content, list):
for item in content:
if isinstance(item, str):
token_count += token_count_utils.count_token(item)
else:
token_count += count_text_tokens(item.get("text", ""))
return token_count
def is_content_right_type(content: Any) -> bool:
"""A helper function to check if the passed in content is of the right type."""
return isinstance(content, (str, list))
def is_content_text_empty(content: MessageContentType) -> bool:
"""Checks if the content of the message does not contain any text.
Args:
content (MessageContentType): The message content to check.
"""
if isinstance(content, str):
return content == ""
elif isinstance(content, list):
texts = []
for item in content:
if isinstance(item, str):
texts.append(item)
elif isinstance(item, dict):
texts.append(item.get("text", ""))
return not any(texts)
else:
return True
def should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
"""Validates whether the transform should be applied according to the filter dictionary.
Args:
message (Dict[str, Any]): The message to validate.
filter_dict (None or Dict[str, Any]): The filter dictionary to validate against. If None, the transform is always applied.
exclude (bool): Whether to exclude messages that match the filter dictionary.
"""
if not filter_dict:
return True
return len(filter_config([message], filter_dict, exclude)) > 0

View File

@@ -1,5 +1,7 @@
from typing import Dict, List, Literal, TypedDict, Union
MessageContentType = Union[str, List[Union[Dict, str]], None]
class UserMessageTextContentPart(TypedDict):
type: Literal["text"]

View File

@@ -0,0 +1,72 @@
import itertools
import tempfile
from typing import Dict, Tuple
import pytest
from autogen.agentchat.contrib.capabilities import transforms_util
from autogen.cache.cache import Cache
from autogen.types import MessageContentType
MESSAGES = {
"message1": {
"content": [{"text": "Hello"}, {"image_url": {"url": "https://example.com/image.jpg"}}],
"text_tokens": 1,
},
"message2": {"content": [{"image_url": {"url": "https://example.com/image.jpg"}}], "text_tokens": 0},
"message3": {"content": [{"text": "Hello"}, {"text": "World"}], "text_tokens": 2},
"message4": {"content": None, "text_tokens": 0},
"message5": {"content": "Hello there!", "text_tokens": 3},
"message6": {"content": ["Hello there!", "Hello there!"], "text_tokens": 6},
}
@pytest.mark.parametrize("message", MESSAGES.values())
def test_cache_content(message: Dict[str, MessageContentType]) -> None:
with tempfile.TemporaryDirectory() as tmpdirname:
cache = Cache.disk(tmpdirname)
cache_key_1 = "test_string"
transforms_util.cache_content_set(cache, cache_key_1, message["content"])
assert transforms_util.cache_content_get(cache, cache_key_1) == (message["content"],)
cache_key_2 = "test_list"
cache_value_2 = [message["content"], 1, "some_string", {"new_key": "new_value"}]
transforms_util.cache_content_set(cache, cache_key_2, *cache_value_2)
assert transforms_util.cache_content_get(cache, cache_key_2) == tuple(cache_value_2)
assert isinstance(cache_value_2[1], int)
assert isinstance(cache_value_2[2], str)
assert isinstance(cache_value_2[3], dict)
cache_key_3 = "test_None"
transforms_util.cache_content_set(None, cache_key_3, message["content"])
assert transforms_util.cache_content_get(cache, cache_key_3) is None
assert transforms_util.cache_content_get(None, cache_key_3) is None
@pytest.mark.parametrize("messages", itertools.product(MESSAGES.values(), MESSAGES.values()))
def test_cache_key(messages: Tuple[Dict[str, MessageContentType], Dict[str, MessageContentType]]) -> None:
message_1, message_2 = messages
cache_1 = transforms_util.cache_key(message_1["content"], 10)
cache_2 = transforms_util.cache_key(message_2["content"], 10)
if message_1 == message_2:
assert cache_1 == cache_2
else:
assert cache_1 != cache_2
@pytest.mark.parametrize("message", MESSAGES.values())
def test_min_tokens_reached(message: Dict[str, MessageContentType]):
assert transforms_util.min_tokens_reached([message], None)
assert transforms_util.min_tokens_reached([message], 0)
assert not transforms_util.min_tokens_reached([message], message["text_tokens"] + 1)
@pytest.mark.parametrize("message", MESSAGES.values())
def test_count_text_tokens(message: Dict[str, MessageContentType]):
assert transforms_util.count_text_tokens(message["content"]) == message["text_tokens"]
@pytest.mark.parametrize("message", MESSAGES.values())
def test_is_content_text_empty(message: Dict[str, MessageContentType]):
assert transforms_util.is_content_text_empty(message["content"]) == (message["text_tokens"] == 0)