mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
[Refactor] Transforms Utils (#2863)
* wip * tests + docstrings * improves tests * fix import
This commit is contained in:
72
test/agentchat/contrib/capabilities/test_transforms_util.py
Normal file
72
test/agentchat/contrib/capabilities/test_transforms_util.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import itertools
|
||||
import tempfile
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from autogen.agentchat.contrib.capabilities import transforms_util
|
||||
from autogen.cache.cache import Cache
|
||||
from autogen.types import MessageContentType
|
||||
|
||||
MESSAGES = {
|
||||
"message1": {
|
||||
"content": [{"text": "Hello"}, {"image_url": {"url": "https://example.com/image.jpg"}}],
|
||||
"text_tokens": 1,
|
||||
},
|
||||
"message2": {"content": [{"image_url": {"url": "https://example.com/image.jpg"}}], "text_tokens": 0},
|
||||
"message3": {"content": [{"text": "Hello"}, {"text": "World"}], "text_tokens": 2},
|
||||
"message4": {"content": None, "text_tokens": 0},
|
||||
"message5": {"content": "Hello there!", "text_tokens": 3},
|
||||
"message6": {"content": ["Hello there!", "Hello there!"], "text_tokens": 6},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("message", MESSAGES.values())
|
||||
def test_cache_content(message: Dict[str, MessageContentType]) -> None:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
cache = Cache.disk(tmpdirname)
|
||||
cache_key_1 = "test_string"
|
||||
|
||||
transforms_util.cache_content_set(cache, cache_key_1, message["content"])
|
||||
assert transforms_util.cache_content_get(cache, cache_key_1) == (message["content"],)
|
||||
|
||||
cache_key_2 = "test_list"
|
||||
cache_value_2 = [message["content"], 1, "some_string", {"new_key": "new_value"}]
|
||||
transforms_util.cache_content_set(cache, cache_key_2, *cache_value_2)
|
||||
assert transforms_util.cache_content_get(cache, cache_key_2) == tuple(cache_value_2)
|
||||
assert isinstance(cache_value_2[1], int)
|
||||
assert isinstance(cache_value_2[2], str)
|
||||
assert isinstance(cache_value_2[3], dict)
|
||||
|
||||
cache_key_3 = "test_None"
|
||||
transforms_util.cache_content_set(None, cache_key_3, message["content"])
|
||||
assert transforms_util.cache_content_get(cache, cache_key_3) is None
|
||||
assert transforms_util.cache_content_get(None, cache_key_3) is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("messages", itertools.product(MESSAGES.values(), MESSAGES.values()))
|
||||
def test_cache_key(messages: Tuple[Dict[str, MessageContentType], Dict[str, MessageContentType]]) -> None:
|
||||
message_1, message_2 = messages
|
||||
cache_1 = transforms_util.cache_key(message_1["content"], 10)
|
||||
cache_2 = transforms_util.cache_key(message_2["content"], 10)
|
||||
if message_1 == message_2:
|
||||
assert cache_1 == cache_2
|
||||
else:
|
||||
assert cache_1 != cache_2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("message", MESSAGES.values())
|
||||
def test_min_tokens_reached(message: Dict[str, MessageContentType]):
|
||||
assert transforms_util.min_tokens_reached([message], None)
|
||||
assert transforms_util.min_tokens_reached([message], 0)
|
||||
assert not transforms_util.min_tokens_reached([message], message["text_tokens"] + 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("message", MESSAGES.values())
|
||||
def test_count_text_tokens(message: Dict[str, MessageContentType]):
|
||||
assert transforms_util.count_text_tokens(message["content"]) == message["text_tokens"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("message", MESSAGES.values())
|
||||
def test_is_content_text_empty(message: Dict[str, MessageContentType]):
|
||||
assert transforms_util.is_content_text_empty(message["content"]) == (message["text_tokens"] == 0)
|
||||
Reference in New Issue
Block a user