mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Fix message history limiter for tool call (#3178)
* fix: message history limiter to support tool calls
* add: pytest and docs for message history limiter for tool calls
* Added keep_first_message for HistoryLimiter transform
* Update "inbetween" to "between"
* Updated keep_first_message to non-optional, logic for history limiter
* Update transforms.py
* Update test_transforms to match utils introduction, add keep_first_message testing
* Update test_transforms.py for pre-commit checks

---------

Co-authored-by: Mark Sze <66362098+marklysze@users.noreply.github.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
@@ -9,8 +9,8 @@ from autogen.agentchat.contrib.capabilities.transforms import (
|
||||
MessageHistoryLimiter,
|
||||
MessageTokenLimiter,
|
||||
TextMessageCompressor,
|
||||
_count_tokens,
|
||||
)
|
||||
from autogen.agentchat.contrib.capabilities.transforms_util import count_text_tokens
|
||||
|
||||
|
||||
class _MockTextCompressor:
|
||||
@@ -40,6 +40,26 @@ def get_no_content_messages() -> List[Dict]:
|
||||
return [{"role": "user", "function_call": "example"}, {"role": "assistant", "content": None}]
|
||||
|
||||
|
||||
def get_tool_messages() -> List[Dict]:
    """Return a five-message conversation containing one tool call/response pair.

    The ``tool`` response is the third message, so the last three messages
    (``tool``, ``user``, ``assistant``) begin with a tool response whose
    matching ``tool_calls`` message lies outside that window.

    Returns:
        A freshly built list of message dicts; callers may mutate it freely.
    """
    return [
        {"role": "user", "content": "hello"},
        {"role": "tool_calls", "content": "calling_tool"},
        {"role": "tool", "content": "tool_response"},
        {"role": "user", "content": "how are you"},
        {"role": "assistant", "content": [{"type": "text", "text": "are you doing?"}]},
    ]
|
||||
|
||||
|
||||
def get_tool_messages_kept() -> List[Dict]:
    """Return a five-message conversation that ends with a tool call/response pair.

    The sequence is ``user``, then two consecutive ``tool_calls``/``tool``
    pairs, so the final two messages form a complete call/response pair.

    Returns:
        A freshly built list of message dicts; callers may mutate it freely.
    """
    return [
        {"role": "user", "content": "hello"},
        {"role": "tool_calls", "content": "calling_tool"},
        {"role": "tool", "content": "tool_response"},
        {"role": "tool_calls", "content": "calling_tool"},
        {"role": "tool", "content": "tool_response"},
    ]
|
||||
|
||||
|
||||
def get_text_compressors() -> List[TextCompressor]:
|
||||
compressors: List[TextCompressor] = [_MockTextCompressor()]
|
||||
try:
|
||||
@@ -57,6 +77,11 @@ def message_history_limiter() -> MessageHistoryLimiter:
|
||||
return MessageHistoryLimiter(max_messages=3)
|
||||
|
||||
|
||||
@pytest.fixture
def message_history_limiter_keep_first() -> MessageHistoryLimiter:
    """Provide a MessageHistoryLimiter built with max_messages=3 and keep_first_message=True."""
    return MessageHistoryLimiter(max_messages=3, keep_first_message=True)
|
||||
|
||||
|
||||
@pytest.fixture
def message_token_limiter() -> MessageTokenLimiter:
    """Provide a MessageTokenLimiter built with max_tokens_per_message=3."""
    return MessageTokenLimiter(max_tokens_per_message=3)
|
||||
@@ -96,12 +121,43 @@ def _filter_dict_test(
|
||||
|
||||
@pytest.mark.parametrize(
    "messages, expected_messages_len",
    [
        (get_long_messages(), 3),
        (get_short_messages(), 3),
        (get_no_content_messages(), 2),
        (get_tool_messages(), 2),
        (get_tool_messages_kept(), 2),
    ],
)
def test_message_history_limiter_apply_transform(message_history_limiter, messages, expected_messages_len):
    """Check how many messages the ``message_history_limiter`` fixture keeps.

    Args:
        message_history_limiter: Limiter under test (pytest fixture).
        messages: Conversation passed to ``apply_transform``.
        expected_messages_len: Number of messages expected after the transform.
    """
    transformed_messages = message_history_limiter.apply_transform(messages)
    assert len(transformed_messages) == expected_messages_len

    # For the conversation that ends in a tool call/response pair, the two
    # surviving messages must be that pair, in order.
    if messages == get_tool_messages_kept():
        assert transformed_messages[0]["role"] == "tool_calls"
        assert transformed_messages[1]["role"] == "tool"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "messages, expected_messages_len",
    [
        (get_long_messages(), 3),
        (get_short_messages(), 3),
        (get_no_content_messages(), 2),
        (get_tool_messages(), 3),
        (get_tool_messages_kept(), 3),
    ],
)
def test_message_history_limiter_apply_transform_keep_first(
    message_history_limiter_keep_first, messages, expected_messages_len
):
    """Check how many messages the ``message_history_limiter_keep_first`` fixture keeps.

    Args:
        message_history_limiter_keep_first: Limiter under test (pytest fixture).
        messages: Conversation passed to ``apply_transform``.
        expected_messages_len: Number of messages expected after the transform.
    """
    transformed_messages = message_history_limiter_keep_first.apply_transform(messages)
    assert len(transformed_messages) == expected_messages_len

    # For the conversation that ends in a tool call/response pair, positions
    # 1 and 2 of the result must hold that pair, in order.
    if messages == get_tool_messages_kept():
        assert transformed_messages[1]["role"] == "tool_calls"
        assert transformed_messages[2]["role"] == "tool"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"messages, expected_logs, expected_effect",
|
||||
@@ -109,6 +165,8 @@ def test_message_history_limiter_apply_transform(message_history_limiter, messag
|
||||
(get_long_messages(), "Removed 2 messages. Number of messages reduced from 5 to 3.", True),
|
||||
(get_short_messages(), "No messages were removed.", False),
|
||||
(get_no_content_messages(), "No messages were removed.", False),
|
||||
(get_tool_messages(), "Removed 3 messages. Number of messages reduced from 5 to 2.", True),
|
||||
(get_tool_messages_kept(), "Removed 3 messages. Number of messages reduced from 5 to 2.", True),
|
||||
],
|
||||
)
|
||||
def test_message_history_limiter_get_logs(message_history_limiter, messages, expected_logs, expected_effect):
|
||||
@@ -131,7 +189,8 @@ def test_message_token_limiter_apply_transform(
|
||||
):
|
||||
transformed_messages = message_token_limiter.apply_transform(copy.deepcopy(messages))
|
||||
assert (
|
||||
sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
|
||||
sum(count_text_tokens(msg["content"]) for msg in transformed_messages if "content" in msg)
|
||||
== expected_token_count
|
||||
)
|
||||
assert len(transformed_messages) == expected_messages_len
|
||||
|
||||
@@ -167,7 +226,8 @@ def test_message_token_limiter_with_threshold_apply_transform(
|
||||
):
|
||||
transformed_messages = message_token_limiter_with_threshold.apply_transform(messages)
|
||||
assert (
|
||||
sum(_count_tokens(msg["content"]) for msg in transformed_messages if "content" in msg) == expected_token_count
|
||||
sum(count_text_tokens(msg["content"]) for msg in transformed_messages if "content" in msg)
|
||||
== expected_token_count
|
||||
)
|
||||
assert len(transformed_messages) == expected_messages_len
|
||||
|
||||
@@ -240,56 +300,31 @@ def test_text_compression_with_filter(messages, text_compressor):
|
||||
assert _filter_dict_test(post_transform, pre_transform, ["user"], exclude_filter=False)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("text_compressor", get_text_compressors())
|
||||
def test_text_compression_cache(text_compressor):
|
||||
messages = get_long_messages()
|
||||
mock_compressed_content = (1, {"content": "mock"})
|
||||
|
||||
with patch(
|
||||
"autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_get",
|
||||
MagicMock(return_value=(1, {"content": "mock"})),
|
||||
) as mocked_get, patch(
|
||||
"autogen.agentchat.contrib.capabilities.transforms.TextMessageCompressor._cache_set", MagicMock()
|
||||
) as mocked_set:
|
||||
compressor = TextMessageCompressor(text_compressor=text_compressor)
|
||||
|
||||
compressor.apply_transform(messages)
|
||||
compressor.apply_transform(messages)
|
||||
|
||||
assert mocked_get.call_count == len(messages)
|
||||
assert mocked_set.call_count == len(messages)
|
||||
|
||||
# We already populated the cache with the mock content
|
||||
# We need to test if we retrieve the correct content
|
||||
compressor = TextMessageCompressor(text_compressor=text_compressor)
|
||||
compressed_messages = compressor.apply_transform(messages)
|
||||
|
||||
for message in compressed_messages:
|
||||
assert message["content"] == mock_compressed_content[1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
long_messages = get_long_messages()
|
||||
short_messages = get_short_messages()
|
||||
no_content_messages = get_no_content_messages()
|
||||
tool_messages = get_tool_messages()
|
||||
msg_history_limiter = MessageHistoryLimiter(max_messages=3)
|
||||
# Same configuration as the message_history_limiter_keep_first fixture.
msg_history_limiter_keep_first = MessageHistoryLimiter(max_messages=3, keep_first_message=True)
|
||||
msg_token_limiter = MessageTokenLimiter(max_tokens_per_message=3)
|
||||
msg_token_limiter_with_threshold = MessageTokenLimiter(max_tokens_per_message=1, min_tokens=10)
|
||||
|
||||
# Test Parameters
|
||||
message_history_limiter_apply_transform_parameters = {
    "messages": [long_messages, short_messages, no_content_messages, tool_messages],
    # Expected lengths for MessageHistoryLimiter(max_messages=3). The tool_messages
    # entry matches the pytest parametrization of
    # test_message_history_limiter_apply_transform (2 messages).
    # NOTE(review): the keep-first run in this script reuses these expectations,
    # while the keep-first parametrization expects 3 for tool_messages — confirm
    # whether that run needs its own expectation list.
    "expected_messages_len": [3, 3, 2, 2],
}
|
||||
|
||||
message_history_limiter_get_logs_parameters = {
    "messages": [long_messages, short_messages, no_content_messages, tool_messages],
    # One expected log line per entry in "messages"; the tool_messages line
    # matches the pytest parametrization of test_message_history_limiter_get_logs.
    "expected_logs": [
        "Removed 2 messages. Number of messages reduced from 5 to 3.",
        "No messages were removed.",
        "No messages were removed.",
        "Removed 3 messages. Number of messages reduced from 5 to 2.",
    ],
    "expected_effect": [True, False, False, True],
}
|
||||
|
||||
message_token_limiter_apply_transform_parameters = {
|
||||
@@ -322,6 +357,14 @@ if __name__ == "__main__":
|
||||
):
|
||||
test_message_history_limiter_apply_transform(msg_history_limiter, messages, expected_messages_len)
|
||||
|
||||
for messages, expected_messages_len in zip(
|
||||
message_history_limiter_apply_transform_parameters["messages"],
|
||||
message_history_limiter_apply_transform_parameters["expected_messages_len"],
|
||||
):
|
||||
test_message_history_limiter_apply_transform_keep_first(
|
||||
msg_history_limiter_keep_first, messages, expected_messages_len
|
||||
)
|
||||
|
||||
for messages, expected_logs, expected_effect in zip(
|
||||
message_history_limiter_get_logs_parameters["messages"],
|
||||
message_history_limiter_get_logs_parameters["expected_logs"],
|
||||
|
||||
Reference in New Issue
Block a user