mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-24 19:28:12 -05:00
[Refactor] Transforms Utils (#2863)
* wip * tests + docstrings * improves tests * fix import
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
|
||||
|
||||
@@ -8,8 +7,9 @@ from termcolor import colored
|
||||
|
||||
from autogen import token_count_utils
|
||||
from autogen.cache import AbstractCache, Cache
|
||||
from autogen.oai.openai_utils import filter_config
|
||||
from autogen.types import MessageContentType
|
||||
|
||||
from . import transforms_util
|
||||
from .text_compressors import LLMLingua, TextCompressor
|
||||
|
||||
|
||||
@@ -169,7 +169,7 @@ class MessageTokenLimiter:
|
||||
assert self._min_tokens is not None
|
||||
|
||||
# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
|
||||
if not _min_tokens_reached(messages, self._min_tokens):
|
||||
if not transforms_util.min_tokens_reached(messages, self._min_tokens):
|
||||
return messages
|
||||
|
||||
temp_messages = copy.deepcopy(messages)
|
||||
@@ -178,13 +178,13 @@ class MessageTokenLimiter:
|
||||
|
||||
for msg in reversed(temp_messages):
|
||||
# Some messages may not have content.
|
||||
if not _is_content_right_type(msg.get("content")):
|
||||
if not transforms_util.is_content_right_type(msg.get("content")):
|
||||
processed_messages.insert(0, msg)
|
||||
continue
|
||||
|
||||
if not _should_transform_message(msg, self._filter_dict, self._exclude_filter):
|
||||
if not transforms_util.should_transform_message(msg, self._filter_dict, self._exclude_filter):
|
||||
processed_messages.insert(0, msg)
|
||||
processed_messages_tokens += _count_tokens(msg["content"])
|
||||
processed_messages_tokens += transforms_util.count_text_tokens(msg["content"])
|
||||
continue
|
||||
|
||||
expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message
|
||||
@@ -199,7 +199,7 @@ class MessageTokenLimiter:
|
||||
break
|
||||
|
||||
msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message)
|
||||
msg_tokens = _count_tokens(msg["content"])
|
||||
msg_tokens = transforms_util.count_text_tokens(msg["content"])
|
||||
|
||||
# prepend the message to the list to preserve order
|
||||
processed_messages_tokens += msg_tokens
|
||||
@@ -209,10 +209,10 @@ class MessageTokenLimiter:
|
||||
|
||||
def get_logs(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]) -> Tuple[str, bool]:
|
||||
pre_transform_messages_tokens = sum(
|
||||
_count_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
|
||||
transforms_util.count_text_tokens(msg["content"]) for msg in pre_transform_messages if "content" in msg
|
||||
)
|
||||
post_transform_messages_tokens = sum(
|
||||
_count_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
|
||||
transforms_util.count_text_tokens(msg["content"]) for msg in post_transform_messages if "content" in msg
|
||||
)
|
||||
|
||||
if post_transform_messages_tokens < pre_transform_messages_tokens:
|
||||
@@ -349,31 +349,32 @@ class TextMessageCompressor:
|
||||
return messages
|
||||
|
||||
# if the total number of tokens in the messages is less than the min_tokens, return the messages as is
|
||||
if not _min_tokens_reached(messages, self._min_tokens):
|
||||
if not transforms_util.min_tokens_reached(messages, self._min_tokens):
|
||||
return messages
|
||||
|
||||
total_savings = 0
|
||||
processed_messages = messages.copy()
|
||||
for message in processed_messages:
|
||||
# Some messages may not have content.
|
||||
if not _is_content_right_type(message.get("content")):
|
||||
if not transforms_util.is_content_right_type(message.get("content")):
|
||||
continue
|
||||
|
||||
if not _should_transform_message(message, self._filter_dict, self._exclude_filter):
|
||||
if not transforms_util.should_transform_message(message, self._filter_dict, self._exclude_filter):
|
||||
continue
|
||||
|
||||
if _is_content_text_empty(message["content"]):
|
||||
if transforms_util.is_content_text_empty(message["content"]):
|
||||
continue
|
||||
|
||||
cached_content = self._cache_get(message["content"])
|
||||
cache_key = transforms_util.cache_key(message["content"], self._min_tokens)
|
||||
cached_content = transforms_util.cache_content_get(self._cache, cache_key)
|
||||
if cached_content is not None:
|
||||
savings, compressed_content = cached_content
|
||||
message["content"], savings = cached_content
|
||||
else:
|
||||
savings, compressed_content = self._compress(message["content"])
|
||||
message["content"], savings = self._compress(message["content"])
|
||||
|
||||
self._cache_set(message["content"], compressed_content, savings)
|
||||
transforms_util.cache_content_set(self._cache, cache_key, message["content"], savings)
|
||||
|
||||
message["content"] = compressed_content
|
||||
assert isinstance(savings, int)
|
||||
total_savings += savings
|
||||
|
||||
self._recent_tokens_savings = total_savings
|
||||
@@ -385,24 +386,29 @@ class TextMessageCompressor:
|
||||
else:
|
||||
return "No tokens saved with text compression.", False
|
||||
|
||||
def _compress(self, content: Union[str, List[Dict]]) -> Tuple[int, Union[str, List[Dict]]]:
|
||||
def _compress(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
|
||||
"""Compresses the given text or multimodal content using the specified compression method."""
|
||||
if isinstance(content, str):
|
||||
return self._compress_text(content)
|
||||
elif isinstance(content, list):
|
||||
return self._compress_multimodal(content)
|
||||
else:
|
||||
return 0, content
|
||||
return content, 0
|
||||
|
||||
def _compress_multimodal(self, content: List[Dict]) -> Tuple[int, List[Dict]]:
|
||||
def _compress_multimodal(self, content: MessageContentType) -> Tuple[MessageContentType, int]:
|
||||
tokens_saved = 0
|
||||
for msg in content:
|
||||
if "text" in msg:
|
||||
savings, msg["text"] = self._compress_text(msg["text"])
|
||||
for item in content:
|
||||
if isinstance(item, dict) and "text" in item:
|
||||
item["text"], savings = self._compress_text(item["text"])
|
||||
tokens_saved += savings
|
||||
return tokens_saved, content
|
||||
|
||||
def _compress_text(self, text: str) -> Tuple[int, str]:
|
||||
elif isinstance(item, str):
|
||||
item, savings = self._compress_text(item)
|
||||
tokens_saved += savings
|
||||
|
||||
return content, tokens_saved
|
||||
|
||||
def _compress_text(self, text: str) -> Tuple[str, int]:
|
||||
"""Compresses the given text using the specified compression method."""
|
||||
compressed_text = self._text_compressor.compress_text(text, **self._compression_args)
|
||||
|
||||
@@ -410,63 +416,8 @@ class TextMessageCompressor:
|
||||
if "origin_tokens" in compressed_text and "compressed_tokens" in compressed_text:
|
||||
savings = compressed_text["origin_tokens"] - compressed_text["compressed_tokens"]
|
||||
|
||||
return savings, compressed_text["compressed_prompt"]
|
||||
|
||||
def _cache_get(self, content: Union[str, List[Dict]]) -> Optional[Tuple[int, Union[str, List[Dict]]]]:
|
||||
if self._cache:
|
||||
cached_value = self._cache.get(self._cache_key(content))
|
||||
if cached_value:
|
||||
return cached_value
|
||||
|
||||
def _cache_set(
|
||||
self, content: Union[str, List[Dict]], compressed_content: Union[str, List[Dict]], tokens_saved: int
|
||||
):
|
||||
if self._cache:
|
||||
value = (tokens_saved, compressed_content)
|
||||
self._cache.set(self._cache_key(content), value)
|
||||
|
||||
def _cache_key(self, content: Union[str, List[Dict]]) -> str:
|
||||
return f"{json.dumps(content)}_{self._min_tokens}"
|
||||
return compressed_text["compressed_prompt"], savings
|
||||
|
||||
def _validate_min_tokens(self, min_tokens: Optional[int]):
|
||||
if min_tokens is not None and min_tokens <= 0:
|
||||
raise ValueError("min_tokens must be greater than 0 or None")
|
||||
|
||||
|
||||
def _min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
|
||||
"""Returns True if the total number of tokens in the messages is greater than or equal to the specified value."""
|
||||
if not min_tokens:
|
||||
return True
|
||||
|
||||
messages_tokens = sum(_count_tokens(msg["content"]) for msg in messages if "content" in msg)
|
||||
return messages_tokens >= min_tokens
|
||||
|
||||
|
||||
def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
|
||||
token_count = 0
|
||||
if isinstance(content, str):
|
||||
token_count = token_count_utils.count_token(content)
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
token_count += _count_tokens(item.get("text", ""))
|
||||
return token_count
|
||||
|
||||
|
||||
def _is_content_right_type(content: Any) -> bool:
|
||||
return isinstance(content, (str, list))
|
||||
|
||||
|
||||
def _is_content_text_empty(content: Union[str, List[Dict[str, Any]]]) -> bool:
|
||||
if isinstance(content, str):
|
||||
return content == ""
|
||||
elif isinstance(content, list):
|
||||
return all(_is_content_text_empty(item.get("text", "")) for item in content)
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def _should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
|
||||
if not filter_dict:
|
||||
return True
|
||||
|
||||
return len(filter_config([message], filter_dict, exclude)) > 0
|
||||
|
||||
114
autogen/agentchat/contrib/capabilities/transforms_util.py
Normal file
114
autogen/agentchat/contrib/capabilities/transforms_util.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import Any, Dict, Hashable, List, Optional, Tuple
|
||||
|
||||
from autogen import token_count_utils
|
||||
from autogen.cache.abstract_cache_base import AbstractCache
|
||||
from autogen.oai.openai_utils import filter_config
|
||||
from autogen.types import MessageContentType
|
||||
|
||||
|
||||
def cache_key(content: MessageContentType, *args: Hashable) -> str:
    """Calculates the cache key for the given message content and any other hashable args.

    The key is the string form of the tuple ``(content, *args)``. A tuple renders each element
    with ``repr`` and separates the elements, so the parts stay delimited and typed. Plain
    concatenation of ``str()`` values would map ``("1", 23)`` and ``("12", 3)``, or ``None`` and
    ``"None"``, onto the same key.

    Args:
        content (MessageContentType): The message content to calculate the cache key for.
        *args: Any additional hashable args to include in the cache key.

    Returns:
        str: The cache key; identical inputs always produce the identical key.
    """
    return str((content, *args))
|
||||
|
||||
|
||||
def cache_content_get(cache: Optional[AbstractCache], key: str) -> Optional[Tuple[MessageContentType, ...]]:
    """Retrieves cached content from the cache.

    Args:
        cache (None or AbstractCache): The cache to retrieve the content from. If None, the cache is ignored.
        key (str): The key to retrieve the content from.

    Returns:
        The value stored under ``key``, or None when no cache was given or the lookup
        produced nothing (any falsy lookup result is reported as a miss).
    """
    # Compare against None instead of relying on truthiness: a cache object that defines
    # `__len__` or `__bool__` may evaluate as False even though it is a usable cache.
    if cache is None:
        return None

    cached_value = cache.get(key)
    return cached_value if cached_value else None
|
||||
|
||||
|
||||
def cache_content_set(cache: Optional[AbstractCache], key: str, content: MessageContentType, *extra_values) -> None:
    """Sets content into the cache.

    The stored value is the tuple ``(content, *extra_values)``.

    Args:
        cache (None or AbstractCache): The cache to set the content into. If None, the cache is ignored.
        key (str): The key to set the content into.
        content (MessageContentType): The message content to set into the cache.
        *extra_values: Additional values to be stored alongside the content.
    """
    # Compare against None instead of relying on truthiness: a cache that reports itself as
    # falsy while empty (for example through `__len__`) would otherwise never receive an entry.
    if cache is None:
        return

    cache.set(key, (content, *extra_values))
|
||||
|
||||
|
||||
def min_tokens_reached(messages: List[Dict], min_tokens: Optional[int]) -> bool:
    """Returns True if the total number of tokens in the messages is greater than or equal to the specified value.

    Args:
        messages (List[Dict]): A list of messages to check. Messages without a "content" key are ignored.
        min_tokens (None or int): The threshold to compare against. A falsy value (None or 0)
            means there is no threshold, so the result is always True.
    """
    if not min_tokens:
        return True

    total_tokens = 0
    for message in messages:
        if "content" in message:
            total_tokens += count_text_tokens(message["content"])

    return total_tokens >= min_tokens
|
||||
|
||||
|
||||
def count_text_tokens(content: MessageContentType) -> int:
    """Calculates the number of text tokens in the given message content.

    Args:
        content (MessageContentType): The message content to calculate the number of text tokens for.

    Returns:
        int: The token count of a string, or the summed token count of the text parts of a
            list. Content that is neither a string nor a list yields 0.
    """
    if isinstance(content, str):
        return token_count_utils.count_token(content)

    token_count = 0
    if isinstance(content, list):
        for item in content:
            if isinstance(item, str):
                token_count += token_count_utils.count_token(item)
            elif isinstance(item, dict):
                # A part without a "text" entry is counted as the empty string.
                token_count += count_text_tokens(item.get("text", ""))
            # Parts of any other type carry no text and are skipped, matching how
            # `is_content_text_empty` treats them, instead of failing on a missing `.get`.

    return token_count
|
||||
|
||||
|
||||
def is_content_right_type(content: Any) -> bool:
    """Checks whether the passed in content is a string or a list."""
    supported_types = (str, list)
    return isinstance(content, supported_types)
|
||||
|
||||
|
||||
def is_content_text_empty(content: MessageContentType) -> bool:
    """Checks if the content of the message does not contain any text.

    Args:
        content (MessageContentType): The message content to check.

    Returns:
        bool: True when a string is empty, when no part of a list carries non-empty text, or
            when the content is neither a string nor a list.
    """
    if isinstance(content, str):
        return content == ""

    if not isinstance(content, list):
        return True

    for part in content:
        if isinstance(part, str):
            text = part
        elif isinstance(part, dict):
            text = part.get("text", "")
        else:
            # Parts of any other type carry no text.
            continue

        if text:
            return False

    return True
|
||||
|
||||
|
||||
def should_transform_message(message: Dict[str, Any], filter_dict: Optional[Dict[str, Any]], exclude: bool) -> bool:
    """Validates whether the transform should be applied according to the filter dictionary.

    Args:
        message (Dict[str, Any]): The message to validate.
        filter_dict (None or Dict[str, Any]): The filter dictionary to validate against. If None
            or empty, the transform is always applied.
        exclude (bool): Whether to exclude messages that match the filter dictionary.

    Returns:
        bool: True when there is no filter, or when ``filter_config`` keeps the message.
    """
    if filter_dict:
        # The message passes when it survives filtering as a one-element list.
        kept_messages = filter_config([message], filter_dict, exclude)
        return len(kept_messages) > 0

    return True
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import Dict, List, Literal, TypedDict, Union
|
||||
|
||||
MessageContentType = Union[str, List[Union[Dict, str]], None]
|
||||
|
||||
|
||||
class UserMessageTextContentPart(TypedDict):
|
||||
type: Literal["text"]
|
||||
|
||||
72
test/agentchat/contrib/capabilities/test_transforms_util.py
Normal file
72
test/agentchat/contrib/capabilities/test_transforms_util.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import itertools
|
||||
import tempfile
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from autogen.agentchat.contrib.capabilities import transforms_util
|
||||
from autogen.cache.cache import Cache
|
||||
from autogen.types import MessageContentType
|
||||
|
||||
# Fixture messages shared by the tests below. Each entry pairs a "content" payload with the
# number of text tokens ("text_tokens") the tests expect for that payload.
MESSAGES = {
    # Multimodal: one text part and one image part.
    "message1": {
        "content": [
            {"text": "Hello"},
            {"image_url": {"url": "https://example.com/image.jpg"}},
        ],
        "text_tokens": 1,
    },
    # Multimodal: image only, no text.
    "message2": {
        "content": [{"image_url": {"url": "https://example.com/image.jpg"}}],
        "text_tokens": 0,
    },
    # Multimodal: two text parts.
    "message3": {
        "content": [{"text": "Hello"}, {"text": "World"}],
        "text_tokens": 2,
    },
    # No content at all.
    "message4": {
        "content": None,
        "text_tokens": 0,
    },
    # Plain string content.
    "message5": {
        "content": "Hello there!",
        "text_tokens": 3,
    },
    # List of plain strings.
    "message6": {
        "content": ["Hello there!", "Hello there!"],
        "text_tokens": 6,
    },
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("message", MESSAGES.values())
def test_cache_content(message: Dict[str, MessageContentType]) -> None:
    """Checks that content written with cache_content_set is read back by cache_content_get."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        cache = Cache.disk(tmpdirname)

        # Content stored on its own comes back as a 1-tuple.
        cache_key_1 = "test_string"
        transforms_util.cache_content_set(cache, cache_key_1, message["content"])
        assert transforms_util.cache_content_get(cache, cache_key_1) == (message["content"],)

        # Content stored with extra values comes back as a tuple in the same order.
        cache_key_2 = "test_list"
        cache_value_2 = [message["content"], 1, "some_string", {"new_key": "new_value"}]
        transforms_util.cache_content_set(cache, cache_key_2, *cache_value_2)
        cached_value_2 = transforms_util.cache_content_get(cache, cache_key_2)
        assert cached_value_2 == tuple(cache_value_2)
        # Inspect the value read back from the cache, not the local input list, so these
        # checks actually verify that the extra values keep their types through the cache.
        assert isinstance(cached_value_2[1], int)
        assert isinstance(cached_value_2[2], str)
        assert isinstance(cached_value_2[3], dict)

        # With no cache nothing is stored, and lookups return None.
        cache_key_3 = "test_None"
        transforms_util.cache_content_set(None, cache_key_3, message["content"])
        assert transforms_util.cache_content_get(cache, cache_key_3) is None
        assert transforms_util.cache_content_get(None, cache_key_3) is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("messages", itertools.product(MESSAGES.values(), MESSAGES.values()))
def test_cache_key(messages: Tuple[Dict[str, MessageContentType], Dict[str, MessageContentType]]) -> None:
    """Checks that two fixture messages share a cache key exactly when the messages are equal."""
    first_message, second_message = messages
    first_key = transforms_util.cache_key(first_message["content"], 10)
    second_key = transforms_util.cache_key(second_message["content"], 10)

    keys_match = first_key == second_key
    messages_match = first_message == second_message
    assert keys_match == messages_match
|
||||
|
||||
|
||||
@pytest.mark.parametrize("message", MESSAGES.values())
def test_min_tokens_reached(message: Dict[str, MessageContentType]):
    """Checks min_tokens_reached with no threshold, at the threshold, and just above it."""
    # No threshold: always reached.
    assert transforms_util.min_tokens_reached([message], None)
    assert transforms_util.min_tokens_reached([message], 0)

    # The comparison is inclusive, so a threshold equal to the token count is reached.
    assert transforms_util.min_tokens_reached([message], message["text_tokens"])

    # One token more than the message holds is not reached.
    assert not transforms_util.min_tokens_reached([message], message["text_tokens"] + 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("message", MESSAGES.values())
def test_count_text_tokens(message: Dict[str, MessageContentType]):
    """Checks that count_text_tokens matches the expected token count of each fixture."""
    expected_tokens = message["text_tokens"]
    counted_tokens = transforms_util.count_text_tokens(message["content"])
    assert counted_tokens == expected_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("message", MESSAGES.values())
def test_is_content_text_empty(message: Dict[str, MessageContentType]):
    """Checks that content is reported as empty exactly when the fixture expects zero text tokens."""
    expected_empty = message["text_tokens"] == 0
    assert transforms_util.is_content_text_empty(message["content"]) == expected_empty
|
||||
Reference in New Issue
Block a user