Refactor transform_messages (#1631)

* refactored code to simplify

* optimize function: instead of iterating over each character, guess at the size and then iterate by token.

* adding tests

* Add missing tests

* minor test fix

* simplified token truncation by using tiktoken to encode and decode

* updated truncated notification message

* Fix llm_config spec to use os.environ

* Add test case and fix bug in loop

---------

Co-authored-by: gagb <gagb@users.noreply.github.com>
dkirsche authored 2024-02-20 17:53:05 -05:00, committed by GitHub
parent d8a204a9a3
commit a34e4cc515
3 changed files with 487 additions and 279 deletions


@@ -3,6 +3,7 @@ from termcolor import colored
 from typing import Dict, Optional, List
 from autogen import ConversableAgent
 from autogen import token_count_utils
+import tiktoken
 
 
 class TransformChatHistory:
@@ -53,56 +54,70 @@ class TransformChatHistory:
             messages: List of messages to process.
 
         Returns:
-            List of messages with the first system message and the last max_messages messages.
+            List of messages with the first system message and the last max_messages messages,
+            ensuring each message does not exceed max_tokens_per_message.
         """
+        temp_messages = messages.copy()
         processed_messages = []
-        messages = messages.copy()
-        rest_messages = messages
-
-        # check if the first message is a system message and append it to the processed messages
-        if len(messages) > 0:
-            if messages[0]["role"] == "system":
-                msg = messages[0]
-                processed_messages.append(msg)
-                rest_messages = messages[1:]
-
+        system_message = None
         processed_messages_tokens = 0
-        for msg in messages:
-            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
 
-        # iterate through rest of the messages and append them to the processed messages
-        for msg in rest_messages[-self.max_messages :]:
+        if messages[0]["role"] == "system":
+            system_message = messages[0].copy()
+            temp_messages.pop(0)
+
+        total_tokens = sum(
+            token_count_utils.count_token(msg["content"]) for msg in temp_messages
+        )  # Calculate tokens for all messages
+
+        # Truncate each message's content to a maximum token limit of each message
+        # Process recent messages first
+        for msg in reversed(temp_messages[-self.max_messages :]):
+            msg["content"] = truncate_str_to_tokens(msg["content"], self.max_tokens_per_message)
             msg_tokens = token_count_utils.count_token(msg["content"])
             if processed_messages_tokens + msg_tokens > self.max_tokens:
                 break
-            processed_messages.append(msg)
+            # append the message to the beginning of the list to preserve order
+            processed_messages = [msg] + processed_messages
             processed_messages_tokens += msg_tokens
 
-        total_tokens = 0
-        for msg in messages:
-            total_tokens += token_count_utils.count_token(msg["content"])
+        if system_message:
+            processed_messages.insert(0, system_message)
 
+        # Optionally, log the number of truncated messages and tokens if needed
         num_truncated = len(messages) - len(processed_messages)
         if num_truncated > 0 or total_tokens > processed_messages_tokens:
-            print(colored(f"Truncated {len(messages) - len(processed_messages)} messages.", "yellow"))
-            print(colored(f"Truncated {total_tokens - processed_messages_tokens} tokens.", "yellow"))
+            print(
+                colored(
+                    f"Truncated {num_truncated} messages. Reduced from {len(messages)} to {len(processed_messages)}.",
+                    "yellow",
+                )
+            )
+            print(
+                colored(
+                    f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}",
+                    "yellow",
+                )
+            )
+
         return processed_messages
 
 
-def truncate_str_to_tokens(text: str, max_tokens: int) -> str:
-    """
-    Truncate a string so that number of tokens in less than max_tokens.
+def truncate_str_to_tokens(text: str, max_tokens: int, model: str = "gpt-3.5-turbo-0613") -> str:
+    """Truncate a string so that the number of tokens is less than or equal to max_tokens using tiktoken.
 
     Args:
-        content: String to process.
-        max_tokens: Maximum number of tokens to keep.
+        text: The string to truncate.
+        max_tokens: The maximum number of tokens to keep.
+        model: The target OpenAI model for tokenization alignment.
 
     Returns:
-        Truncated string.
+        The truncated string.
     """
-    truncated_string = ""
-    for char in text:
-        truncated_string += char
-        if token_count_utils.count_token(truncated_string) == max_tokens:
-            break
-    return truncated_string
+    encoding = tiktoken.encoding_for_model(model)  # Get the appropriate tokenizer
+    encoded_tokens = encoding.encode(text)
+    truncated_tokens = encoded_tokens[:max_tokens]
+    truncated_text = encoding.decode(truncated_tokens)  # Decode back to text
+    return truncated_text

File diff suppressed because one or more lines are too long


@@ -102,6 +102,129 @@ def test_transform_chat_history_with_agents():
         assert False, f"Chat initiation failed with error {str(e)}"
 
 
+def test_transform_messages():
+    """
+    Test TransformChatHistory._transform_messages.
+    """
+    # Test case 1: the order of messages is retained after transformation,
+    # and each message is properly truncated.
+    messages = [
+        {"role": "system", "content": "System message"},
+        {"role": "user", "content": "Hi"},
+        {"role": "user", "content": "user sending the 2nd test message"},
+        {"role": "assistant", "content": "assistant sending the 3rd test message"},
+        {"role": "assistant", "content": "assistant sending the 4th test message"},
+    ]
+    transform_chat_history = TransformChatHistory(max_messages=3, max_tokens_per_message=10, max_tokens=100)
+    transformed_messages = transform_chat_history._transform_messages(messages)
+
+    assert transformed_messages[0]["role"] == "system"
+    assert transformed_messages[0]["content"] == "System message"
+    assert transformed_messages[1]["role"] == "user"
+    assert transformed_messages[1]["content"] == "user sending the 2nd test message"
+    assert transformed_messages[2]["role"] == "assistant"
+    assert transformed_messages[2]["content"] == "assistant sending the 3rd test message"
+    assert transformed_messages[3]["role"] == "assistant"
+    assert transformed_messages[3]["content"] == "assistant sending the 4th test message"
+
+    # Test case 2: no system message present.
+    messages = [
+        {"role": "user", "content": "Hi"},
+        {"role": "user", "content": "user sending the 2nd test message"},
+        {"role": "assistant", "content": "assistant sending the 3rd test message"},
+        {"role": "assistant", "content": "assistant sending the 4th test message"},
+    ]
+    transform_chat_history = TransformChatHistory(max_messages=3, max_tokens_per_message=10, max_tokens=100)
+    transformed_messages = transform_chat_history._transform_messages(messages)
+
+    assert transformed_messages[0]["role"] == "user"
+    assert transformed_messages[0]["content"] == "user sending the 2nd test message"
+    assert transformed_messages[1]["role"] == "assistant"
+    assert transformed_messages[1]["content"] == "assistant sending the 3rd test message"
+    assert transformed_messages[2]["role"] == "assistant"
+    assert transformed_messages[2]["content"] == "assistant sending the 4th test message"
+
+    # Test case 3: enforce the overall max_tokens budget.
+    messages = [
+        {"role": "user", "content": "Out of max messages"},
+        {"role": "assistant", "content": "first second third fourth"},
+        {"role": "user", "content": "a"},
+    ]
+    print(f"----Messages (N={len(messages)})----")
+    original_tokens = 0
+    for i, msg in enumerate(messages):
+        print(f"[{msg['role']}-{i}]: {msg['content']}")
+        tokens = token_count_utils.count_token(msg["content"])
+        print("Number of tokens: ", tokens)
+        original_tokens += tokens
+    print("-----Total tokens: ", original_tokens, "-----")
+
+    allowed_max_tokens = 2
+    transform_chat_history = TransformChatHistory(max_messages=2, max_tokens=allowed_max_tokens)
+    transformed_messages = transform_chat_history._transform_messages(messages)
+
+    print("Max allowed tokens: ", allowed_max_tokens)
+    print("Transformed contents")
+    for msg in transformed_messages:
+        print(msg["content"])
+        print("Number of tokens: ", token_count_utils.count_token(msg["content"]))
+    assert len(transformed_messages) == 1
+    assert transformed_messages[0]["role"] == "user"
+
+
+def test_truncate_str_to_tokens():
+    """
+    Test the truncate_str_to_tokens function.
+    """
+    from autogen.agentchat.contrib.capabilities.context_handling import truncate_str_to_tokens
+
+    # Test case 1: Truncate string with fewer tokens than max_tokens
+    text = "This is a test"
+    max_tokens = 5
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    assert truncated_text == text
+
+    # Test case 2: Truncate string with more tokens than max_tokens
+    text = "This is a test"
+    max_tokens = 3
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    assert truncated_text == "This is a"
+
+    # Test case 3: Truncate empty string
+    text = ""
+    max_tokens = 5
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    assert truncated_text == ""
+
+    # Test case 4: Truncate string with exactly max_tokens tokens
+    text = "This is a test"
+    max_tokens = 4
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    assert truncated_text == "This is a test"
+
+    # Test case 5: Truncate string when max_tokens is zero
+    text = "This is a test"
+    max_tokens = 0
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    assert truncated_text == ""
+
+    # Test case 6: Truncate string when actual tokens exceed max_tokens
+    text = "This is a test with a looooooonngggg word"
+    max_tokens = 8
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    word_count = len(truncated_text.split())
+    assert word_count <= max_tokens
+
+    # Test case 7: Truncate string containing newlines
+    text = "This\nis\na test"
+    max_tokens = 4
+    truncated_text = truncate_str_to_tokens(text, max_tokens)
+    assert "This\nis" in truncated_text
+
+
 if __name__ == "__main__":
     test_transform_chat_history()
     test_transform_chat_history_with_agents()
+    test_truncate_str_to_tokens()
+    test_transform_messages()
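For context, a minimal sketch of how this capability is typically attached to an agent (the agent name and llm_config here are illustrative, and add_to_agent is the generic capability hook, not something introduced by this commit):

    from autogen import ConversableAgent
    from autogen.agentchat.contrib.capabilities.context_handling import TransformChatHistory

    # Illustrative agent; llm_config is stubbed out for the sketch.
    assistant = ConversableAgent("assistant", llm_config=False)

    # Keep the last 10 messages, each capped at 100 tokens, within a 1000-token budget.
    context_handling = TransformChatHistory(max_messages=10, max_tokens_per_message=100, max_tokens=1000)
    context_handling.add_to_agent(assistant)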