mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-18 08:15:23 -05:00
Refactor transform_messages (#1631)
* refactored code to simplify
* optimize function. Instead of iterating over each character, guess at size and then iterate by token.
* adding tests
* Add missing tests
* minor test fix
* simplified token truncation by using tiktoken to encode and decode
* updated truncated notification message
* Fix llm_config spec to use os.environ
* Add test case and fix bug in loop

---------

Co-authored-by: gagb <gagb@users.noreply.github.com>
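The headline change swaps character-by-character truncation for tiktoken's encode/slice/decode. The sketch below contrasts the two strategies; it is illustrative only (it assumes tiktoken is installed), and count_token is a stand-in for autogen's token_count_utils.count_token rather than the repository's code.

import tiktoken

ENCODING = tiktoken.encoding_for_model("gpt-3.5-turbo-0613")


def count_token(text: str) -> int:
    # Stand-in for autogen's token_count_utils.count_token.
    return len(ENCODING.encode(text))


def truncate_by_char(text: str, max_tokens: int) -> str:
    # Old strategy: grow the string one character at a time and re-count
    # tokens on every step. This re-encodes the prefix repeatedly, and the
    # equality check can be skipped over when appending a character merges
    # tokens, so the loop may return an un-truncated string.
    truncated = ""
    for char in text:
        truncated += char
        if count_token(truncated) == max_tokens:
            break
    return truncated


def truncate_by_token(text: str, max_tokens: int) -> str:
    # New strategy: encode once, slice the token list, decode once.
    return ENCODING.decode(ENCODING.encode(text)[:max_tokens])


print(truncate_by_token("This is a test", 3))  # -> "This is a"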
@@ -3,6 +3,7 @@ from termcolor import colored
 from typing import Dict, Optional, List
 from autogen import ConversableAgent
 from autogen import token_count_utils
+import tiktoken
 
 
 class TransformChatHistory:
@@ -53,56 +54,70 @@ class TransformChatHistory:
             messages: List of messages to process.
 
         Returns:
-            List of messages with the first system message and the last max_messages messages.
+            List of messages with the first system message and the last max_messages messages,
+            ensuring each message does not exceed max_tokens_per_message.
         """
+        temp_messages = messages.copy()
         processed_messages = []
-        messages = messages.copy()
-        rest_messages = messages
-
-        # check if the first message is a system message and append it to the processed messages
-        if len(messages) > 0:
-            if messages[0]["role"] == "system":
-                msg = messages[0]
-                processed_messages.append(msg)
-                rest_messages = messages[1:]
-
+        system_message = None
         processed_messages_tokens = 0
-        for msg in messages:
-            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
 
-        # iterate through rest of the messages and append them to the processed messages
-        for msg in rest_messages[-self.max_messages :]:
+        if messages[0]["role"] == "system":
+            system_message = messages[0].copy()
+            temp_messages.pop(0)
+
+        total_tokens = sum(
+            token_count_utils.count_token(msg["content"]) for msg in temp_messages
+        )  # Calculate tokens for all messages
+
+        # Truncate each message's content to a maximum token limit of each message
+
+        # Process recent messages first
+        for msg in reversed(temp_messages[-self.max_messages :]):
+            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
             msg_tokens = token_count_utils.count_token(msg["content"])
             if processed_messages_tokens + msg_tokens > self.max_tokens:
                 break
-            processed_messages.append(msg)
+            # append the message to the beginning of the list to preserve order
+            processed_messages = [msg] + processed_messages
             processed_messages_tokens += msg_tokens
 
-        total_tokens = 0
-        for msg in messages:
-            total_tokens += token_count_utils.count_token(msg["content"])
-
+        if system_message:
+            processed_messages.insert(0, system_message)
+        # Optionally, log the number of truncated messages and tokens if needed
         num_truncated = len(messages) - len(processed_messages)
 
         if num_truncated > 0 or total_tokens > processed_messages_tokens:
-            print(colored(f"Truncated {len(messages) - len(processed_messages)} messages.", "yellow"))
-            print(colored(f"Truncated {total_tokens - processed_messages_tokens} tokens.", "yellow"))
+            print(
+                colored(
+                    f"Truncated {num_truncated} messages. Reduced from {len(messages)} to {len(processed_messages)}.",
+                    "yellow",
+                )
+            )
+            print(
+                colored(
+                    f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}",
+                    "yellow",
+                )
+            )
         return processed_messages
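Summarizing the refactored method above: it pops the system message aside, pre-computes the total token count, then walks the last max_messages messages newest-first, truncating each one and stopping as soon as the max_tokens budget would be exceeded; survivors are prepended to keep chronological order, and the system message is re-inserted at the front. The following self-contained sketch shows only that selection logic; count_token here is a whitespace stand-in for autogen's token_count_utils.count_token, and per-message truncation is elided.

from typing import Dict, List


def count_token(text: str) -> int:
    # Whitespace stand-in for token_count_utils.count_token.
    return len(text.split())


def keep_recent_within_budget(messages: List[Dict], max_messages: int, max_tokens: int) -> List[Dict]:
    kept: List[Dict] = []
    used = 0
    # Walk newest-first so the budget is spent on the most recent context.
    for msg in reversed(messages[-max_messages:]):
        tokens = count_token(msg["content"])
        if used + tokens > max_tokens:
            break
        kept = [msg] + kept  # prepend to preserve chronological order
        used += tokens
    return kept


history = [
    {"role": "user", "content": "Out of max messages"},
    {"role": "assistant", "content": "first second third fourth"},
    {"role": "user", "content": "a"},
]
# With a 2-token budget only the final one-token message survives,
# matching test case 3 in the test diff below.
print(keep_recent_within_budget(history, max_messages=2, max_tokens=2))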
-def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
-    """
-    Truncate a string so that number of tokens in less than max_tokens.
+def truncate_str_to_tokens(text: str, max_tokens: int, model: str = "gpt-3.5-turbo-0613") -> str:
+    """Truncate a string so that the number of tokens is less than or equal to max_tokens using tiktoken.
 
     Args:
-        content: String to process.
-        max_tokens: Maximum number of tokens to keep.
+        text: The string to truncate.
+        max_tokens: The maximum number of tokens to keep.
+        model: The target OpenAI model for tokenization alignment.
 
     Returns:
-        Truncated string.
+        The truncated string.
     """
-    truncated_string = ""
-    for char in text:
-        truncated_string += char
-        if token_count_utils.count_token(truncated_string) == max_tokens:
-            break
-    return truncated_string
+    encoding = tiktoken.encoding_for_model(model)  # Get the appropriate tokenizer
+
+    encoded_tokens = encoding.encode(text)
+    truncated_tokens = encoded_tokens[:max_tokens]
+    truncated_text = encoding.decode(truncated_tokens)  # Decode back to text
+
+    return truncated_text
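One behavioral consequence of slicing tokens rather than words, which test case 6 in the test diff below also exercises: an unusual word can span several BPE tokens, so max_tokens bounds the token count, not the word count, and a decoded slice can end partway through a word. A small illustration, assuming tiktoken is installed (the exact split depends on the tokenizer version):

import tiktoken

enc = tiktoken.encoding_for_model("gpt-3.5-turbo-0613")
text = "This is a test with a looooooonngggg word"
tokens = enc.encode(text)
# The nonsense word splits into several tokens, so the text has more
# tokens than whitespace-separated words.
print(len(tokens), len(text.split()))
# Keeping 8 tokens can therefore end mid-word.
print(repr(enc.decode(tokens[:8])))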
File diff suppressed because one or more lines are too long
@@ -102,6 +102,129 @@ def test_transform_chat_history_with_agents():
         assert False, f"Chat initiation failed with error {str(e)}"
 
 
+def test_transform_messages():
+    """
+    Test TransformChatHistory._transform_messages()
+    """
+    # Test case 1: the order of messages is retained after transformation
+    # and messages are properly truncated.
+    messages = [
+        {"role": "system", "content": "System message"},
+        {"role": "user", "content": "Hi"},
+        {"role": "user", "content": "user sending the 2nd test message"},
+        {"role": "assistant", "content": "assistant sending the 3rd test message"},
+        {"role": "assistant", "content": "assistant sending the 4th test message"},
+    ]
+
+    transform_chat_history = TransformChatHistory(max_messages=3, max_tokens_per_message=10, max_tokens=100)
+    transformed_messages = transform_chat_history._transform_messages(messages)
+
+    assert transformed_messages[0]["role"] == "system"
+    assert transformed_messages[0]["content"] == "System message"
+    assert transformed_messages[1]["role"] == "user"
+    assert transformed_messages[1]["content"] == "user sending the 2nd test message"
+    assert transformed_messages[2]["role"] == "assistant"
+    assert transformed_messages[2]["content"] == "assistant sending the 3rd test message"
+    assert transformed_messages[3]["role"] == "assistant"
+    assert transformed_messages[3]["content"] == "assistant sending the 4th test message"
+
+    # Test case 2: no system message
+    messages = [
+        {"role": "user", "content": "Hi"},
+        {"role": "user", "content": "user sending the 2nd test message"},
+        {"role": "assistant", "content": "assistant sending the 3rd test message"},
+        {"role": "assistant", "content": "assistant sending the 4th test message"},
+    ]
+
+    transform_chat_history = TransformChatHistory(max_messages=3, max_tokens_per_message=10, max_tokens=100)
+    transformed_messages = transform_chat_history._transform_messages(messages)
+
+    assert transformed_messages[0]["role"] == "user"
+    assert transformed_messages[0]["content"] == "user sending the 2nd test message"
+    assert transformed_messages[1]["role"] == "assistant"
+    assert transformed_messages[1]["content"] == "assistant sending the 3rd test message"
+    assert transformed_messages[2]["role"] == "assistant"
+    assert transformed_messages[2]["content"] == "assistant sending the 4th test message"
+
+    # Test case 3: the overall max_tokens budget is enforced
+    messages = [
+        {"role": "user", "content": "Out of max messages"},
+        {"role": "assistant", "content": "first second third fourth"},
+        {"role": "user", "content": "a"},
+    ]
+    print(f"----Messages (N={len(messages)})----")
+    original_tokens = 0
+    for i, msg in enumerate(messages):
+        print(f"[{msg['role']}-{i}]: {msg['content']}")
+        tokens = token_count_utils.count_token(msg["content"])
+        print("Number of tokens: ", tokens)
+        original_tokens += tokens
+    print("-----Total tokens: ", original_tokens, "-----")
+
+    allowed_max_tokens = 2
+    transform_chat_history = TransformChatHistory(max_messages=2, max_tokens=allowed_max_tokens)
+    transformed_messages = transform_chat_history._transform_messages(messages)
+
+    print("Max allowed tokens: ", allowed_max_tokens)
+
+    print("Transformed contents")
+    for msg in transformed_messages:
+        print(msg["content"])
+        print("Number of tokens: ", token_count_utils.count_token(msg["content"]))
+    assert len(transformed_messages) == 1
+    assert transformed_messages[0]["role"] == "user"
+
+
+def test_truncate_str_to_tokens():
+    """
+    Test the truncate_str_to_tokens function.
+    """
+    from autogen.agentchat.contrib.capabilities.context_handling import truncate_str_to_tokens
+
+    # Test case 1: truncate a string with fewer tokens than max_tokens
+    text = "This is a test"
+    max_tokens = 5
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    assert truncated_text == text
+
+    # Test case 2: truncate a string with more tokens than max_tokens
+    text = "This is a test"
+    max_tokens = 3
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    assert truncated_text == "This is a"
+
+    # Test case 3: truncate an empty string
+    text = ""
+    max_tokens = 5
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    assert truncated_text == ""
+
+    # Test case 4: truncate a string with exactly max_tokens tokens
+    text = "This is a test"
+    max_tokens = 4
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    assert truncated_text == "This is a test"
+
+    # Test case 5: max_tokens of zero keeps nothing
+    text = "This is a test"
+    max_tokens = 0
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    assert truncated_text == ""
+
+    # Test case 6: a long word can span several tokens, so the word count
+    # of the truncated text is at most max_tokens
+    text = "This is a test with a looooooonngggg word"
+    max_tokens = 8
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    word_count = len(truncated_text.split())
+    assert word_count <= max_tokens
+
+    # Test case 7: newlines are preserved within the surviving tokens
+    text = "This\nis\na test"
+    max_tokens = 4
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    assert "This\nis" in truncated_text
+
+
 if __name__ == "__main__":
     test_transform_chat_history()
     test_transform_chat_history_with_agents()
+    test_truncate_str_to_tokens()
+    test_transform_messages()
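For context, TransformChatHistory is normally attached to an agent rather than called directly. A sketch of that wiring, assuming autogen v0.2's capability API (add_to_agent) and an illustrative llm_config; none of these values come from this diff.

from autogen import ConversableAgent
from autogen.agentchat.contrib.capabilities.context_handling import TransformChatHistory

# Illustrative config; the PR's tests read credentials from os.environ instead.
assistant = ConversableAgent(
    "assistant",
    llm_config={"config_list": [{"model": "gpt-3.5-turbo", "api_key": "sk-..."}]},
)

manage_history = TransformChatHistory(max_messages=10, max_tokens_per_message=50, max_tokens=1000)
# add_to_agent is the assumed capability hook; once added, chat history is
# trimmed before each LLM call.
manage_history.add_to_agent(assistant)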