mirror of https://github.com/microsoft/autogen.git
Add token_count_util (#421)
* add token_count_util
* remove token_count from retrieval util
* format
* update dependency
* update test
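For orientation, a minimal usage sketch of the new token_count_utils API introduced by this commit (signatures as in the new file further down; the example messages are illustrative):

    from autogen.token_count_utils import count_token, token_left, percentile_used

    messages = [
        {"role": "system", "content": "you are a helpful assistant."},
        {"role": "user", "content": "hello"},
    ]

    # count_token accepts a plain string or a list/dict of chat messages;
    # token_left and percentile_used are derived from it and get_max_token_limit.
    print(count_token(messages, model="gpt-3.5-turbo-0613"))
    print(token_left(messages), percentile_used(messages))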
@@ -6,7 +6,8 @@ except ImportError:
     raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
 from autogen.agentchat.agent import Agent
 from autogen.agentchat import UserProxyAgent
-from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, num_tokens_from_text
+from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
+from autogen.token_count_utils import count_token
 from autogen.code_utils import extract_code
 
 from typing import Callable, Dict, Optional, Union, List, Tuple, Any
@@ -124,8 +125,8 @@ class RetrieveUserProxyAgent(UserProxyAgent):
             - get_or_create (Optional, bool): if True, will create/recreate a collection for the retrieve chat.
                 This is the same as that used in chromadb. Default is False. Will be set to False if docs_path is None.
             - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
-                The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
-                Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
+                The function should take (text: str, model: str) as input and return the token count (int). retrieve_config["model"] will be passed to the function.
+                Default is autogen.token_count_utils.count_token, which uses tiktoken and may not be accurate for non-OpenAI models.
             - custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings.
                 Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
             **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
@@ -180,7 +181,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
         self._get_or_create = (
             self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else False
         )
-        self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", None)
+        self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
         self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
         self._context_max_tokens = self._max_tokens * 0.8
         self._collection = True if self._docs_path is None else False  # whether the collection is created
@@ -244,7 +245,7 @@ class RetrieveUserProxyAgent(UserProxyAgent):
                 continue
             if results["ids"][0][idx] in self._doc_ids:
                 continue
-            _doc_tokens = num_tokens_from_text(doc, custom_token_count_function=self.custom_token_count_function)
+            _doc_tokens = self.custom_token_count_function(doc, self._model)
             if _doc_tokens > self._context_max_tokens:
                 func_print = f"Skip doc_id {results['ids'][0][idx]} as it is too long to fit in the context."
                 print(colored(func_print, "green"), flush=True)
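Because custom_token_count_function now defaults to count_token and is invoked as fn(doc, self._model), a custom counter can be supplied through retrieve_config. A sketch under the documented interface (my_count and the docs path are hypothetical placeholders):

    from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

    def my_count(text: str, model: str) -> int:
        # hypothetical stand-in counter: approximates tokens by whitespace-separated words
        return len(text.split())

    ragproxyagent = RetrieveUserProxyAgent(
        name="ragproxyagent",
        retrieve_config={
            "task": "qa",
            "docs_path": "./docs",  # placeholder path
            "custom_token_count_function": my_count,  # called as fn(text, model), per the docstring above
        },
    )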
@@ -14,7 +14,7 @@ from chromadb.api.types import QueryResult
 import chromadb.utils.embedding_functions as ef
 import logging
 import pypdf
 
+from autogen.token_count_utils import count_token
 
 logger = logging.getLogger(__name__)
 TEXT_FORMATS = [
@@ -37,80 +37,6 @@ TEXT_FORMATS = [
 VALID_CHUNK_MODES = frozenset({"one_line", "multi_lines"})
 
 
-def num_tokens_from_text(
-    text: str,
-    model: str = "gpt-3.5-turbo-0613",
-    return_tokens_per_name_and_message: bool = False,
-    custom_token_count_function: Callable = None,
-) -> Union[int, Tuple[int, int, int]]:
-    """Return the number of tokens used by a text.
-
-    Args:
-        text (str): The text to count tokens for.
-        model (Optional, str): The model to use for tokenization. Default is "gpt-3.5-turbo-0613".
-        return_tokens_per_name_and_message (Optional, bool): Whether to return the number of tokens per name and per
-            message. Default is False.
-        custom_token_count_function (Optional, Callable): A custom function to count tokens. Default is None.
-
-    Returns:
-        int: The number of tokens used by the text.
-        int: The number of tokens per message. Only returned if return_tokens_per_name_and_message is True.
-        int: The number of tokens per name. Only returned if return_tokens_per_name_and_message is True.
-    """
-    if isinstance(custom_token_count_function, Callable):
-        token_count, tokens_per_message, tokens_per_name = custom_token_count_function(text)
-    else:
-        # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
-        try:
-            encoding = tiktoken.encoding_for_model(model)
-        except KeyError:
-            logger.debug("Warning: model not found. Using cl100k_base encoding.")
-            encoding = tiktoken.get_encoding("cl100k_base")
-        known_models = {
-            "gpt-3.5-turbo": (3, 1),
-            "gpt-35-turbo": (3, 1),
-            "gpt-3.5-turbo-0613": (3, 1),
-            "gpt-3.5-turbo-16k-0613": (3, 1),
-            "gpt-3.5-turbo-0301": (4, -1),
-            "gpt-4": (3, 1),
-            "gpt-4-0314": (3, 1),
-            "gpt-4-32k-0314": (3, 1),
-            "gpt-4-0613": (3, 1),
-            "gpt-4-32k-0613": (3, 1),
-        }
-        tokens_per_message, tokens_per_name = known_models.get(model, (3, 1))
-        token_count = len(encoding.encode(text))
-
-    if return_tokens_per_name_and_message:
-        return token_count, tokens_per_message, tokens_per_name
-    else:
-        return token_count
-
-
-def num_tokens_from_messages(
-    messages: dict,
-    model: str = "gpt-3.5-turbo-0613",
-    custom_token_count_function: Callable = None,
-    custom_prime_count: int = 3,
-):
-    """Return the number of tokens used by a list of messages."""
-    num_tokens = 0
-    for message in messages:
-        for key, value in message.items():
-            _num_tokens, tokens_per_message, tokens_per_name = num_tokens_from_text(
-                value,
-                model=model,
-                return_tokens_per_name_and_message=True,
-                custom_token_count_function=custom_token_count_function,
-            )
-            num_tokens += _num_tokens
-            if key == "name":
-                num_tokens += tokens_per_name
-        num_tokens += tokens_per_message
-    num_tokens += custom_prime_count  # With ChatGPT, every reply is primed with <|start|>assistant<|message|>
-    return num_tokens
-
-
 def split_text_to_chunks(
     text: str,
     max_tokens: int = 4000,
@@ -125,7 +51,7 @@ def split_text_to_chunks(
         must_break_at_empty_line = False
     chunks = []
    lines = text.split("\n")
-    lines_tokens = [num_tokens_from_text(line) for line in lines]
+    lines_tokens = [count_token(line) for line in lines]
     sum_tokens = sum(lines_tokens)
     while sum_tokens > max_tokens:
         if chunk_mode == "one_line":
@@ -148,7 +74,7 @@ def split_text_to_chunks(
                 split_len = int(max_tokens / lines_tokens[0] * 0.9 * len(lines[0]))
                 prev = lines[0][:split_len]
                 lines[0] = lines[0][split_len:]
-                lines_tokens[0] = num_tokens_from_text(lines[0])
+                lines_tokens[0] = count_token(lines[0])
             else:
                 logger.warning("Failed to split docs with must_break_at_empty_line being True, set to False.")
                 must_break_at_empty_line = False
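With num_tokens_from_text removed, the chunker now budgets lines via count_token. A small sketch of the resulting invariant, mirroring the updated assertion in the test diff further down:

    from autogen.retrieve_utils import split_text_to_chunks
    from autogen.token_count_utils import count_token

    long_text = "A" * 10000
    chunks = split_text_to_chunks(long_text, max_tokens=1000)
    # every chunk stays within the token budget
    assert all(count_token(chunk) <= 1000 for chunk in chunks)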
autogen/token_count_utils.py (new file, 182 lines)
@@ -0,0 +1,182 @@
import tiktoken
from typing import List, Union, Dict, Tuple
import logging
import json


logger = logging.getLogger(__name__)


def get_max_token_limit(model="gpt-3.5-turbo-0613"):
    max_token_limit = {
        "gpt-3.5-turbo": 4096,
        "gpt-3.5-turbo-0301": 4096,
        "gpt-3.5-turbo-0613": 4096,
        "gpt-3.5-turbo-instruct": 4096,
        "gpt-3.5-turbo-16k": 16384,
        "gpt-35-turbo": 4096,
        "gpt-35-turbo-16k": 16384,
        "gpt-35-turbo-instruct": 4096,
        "gpt-4": 8192,
        "gpt-4-32k": 32768,
        "gpt-4-32k-0314": 32768,  # deprecate in Sep
        "gpt-4-0314": 8192,  # deprecate in Sep
        "gpt-4-0613": 8192,
        "gpt-4-32k-0613": 32768,
    }
    return max_token_limit[model]


def percentile_used(input, model="gpt-3.5-turbo-0613"):
    return count_token(input) / get_max_token_limit(model)


def token_left(input: Union[str, List, Dict], model="gpt-3.5-turbo-0613") -> int:
    """Count number of tokens left for an OpenAI model.

    Args:
        input: (str, list, dict): Input to the model.
        model: (str): Model name.

    Returns:
        int: Number of tokens left that the model can use for completion.
    """
    return get_max_token_limit(model) - count_token(input, model=model)


def count_token(input: Union[str, List, Dict], model: str = "gpt-3.5-turbo-0613") -> int:
    """Count number of tokens used by an OpenAI model.

    Args:
        input: (str, list, dict): Input to the model.
        model: (str): Model name.

    Returns:
        int: Number of tokens from the input.
    """
    if isinstance(input, str):
        return _num_token_from_text(input, model=model)
    elif isinstance(input, list) or isinstance(input, dict):
        return _num_token_from_messages(input, model=model)
    else:
        raise ValueError("input must be str, list or dict")


def _num_token_from_text(text: str, model: str = "gpt-3.5-turbo-0613"):
    """Return the number of tokens used by a string."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))


def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0613"):
    """Return the number of tokens used by a list of messages.

    Retrieved from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    """
    if isinstance(messages, dict):
        messages = [messages]

    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    if model in {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
    }:
        tokens_per_message = 3
        tokens_per_name = 1
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif "gpt-3.5-turbo" in model:
        logger.info("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
        return _num_token_from_messages(messages, model="gpt-3.5-turbo-0613")
    elif "gpt-4" in model:
        logger.info("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
        return _num_token_from_messages(messages, model="gpt-4-0613")
    else:
        raise NotImplementedError(
            f"""_num_token_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            if value is None:
                continue

            # function calls
            if not isinstance(value, str):
                try:
                    value = json.dumps(value)
                except TypeError:
                    logger.warning(
                        f"Value {value} is not a string and cannot be converted to JSON. It is of type {type(value)}. Skipping."
                    )
                    continue

            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens


def num_tokens_from_functions(functions, model="gpt-3.5-turbo-0613") -> int:
    """Return the number of tokens used by a list of functions.

    Args:
        functions: (list): List of function descriptions that will be passed in model.
        model: (str): Model name.

    Returns:
        int: Number of tokens from the function descriptions.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")

    num_tokens = 0
    for function in functions:
        function_tokens = len(encoding.encode(function["name"]))
        function_tokens += len(encoding.encode(function["description"]))
        function_tokens -= 2
        if "parameters" in function:
            parameters = function["parameters"]
            if "properties" in parameters:
                for propertiesKey in parameters["properties"]:
                    function_tokens += len(encoding.encode(propertiesKey))
                    v = parameters["properties"][propertiesKey]
                    for field in v:
                        if field == "type":
                            function_tokens += 2
                            function_tokens += len(encoding.encode(v["type"]))
                        elif field == "description":
                            function_tokens += 2
                            function_tokens += len(encoding.encode(v["description"]))
                        elif field == "enum":
                            function_tokens -= 3
                            for o in v["enum"]:
                                function_tokens += 3
                                function_tokens += len(encoding.encode(o))
                        else:
                            print(f"Warning: not supported field {field}")
                function_tokens += 11
                if len(parameters["properties"]) == 0:
                    function_tokens -= 2

        num_tokens += function_tokens

    num_tokens += 12
    return num_tokens
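A usage sketch for num_tokens_from_functions, reusing the func1 fixture from the new test file below (the expected count of 44 comes from that test):

    from autogen.token_count_utils import num_tokens_from_functions

    functions = [
        {
            "name": "sh",
            "description": "run a shell script and return the execution result.",
            "parameters": {
                "type": "object",
                "properties": {
                    "script": {"type": "string", "description": "Valid shell script to execute."}
                },
                "required": ["script"],
            },
        }
    ]
    # per the test below, this schema counts as 44 tokens under gpt-3.5-turbo-0613
    assert num_tokens_from_functions(functions) == 44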
@@ -11,10 +11,9 @@ from autogen.retrieve_utils import (
     is_url,
     create_vector_db_from_dir,
     query_vector_db,
-    num_tokens_from_text,
-    num_tokens_from_messages,
     TEXT_FORMATS,
 )
+from autogen.token_count_utils import count_token
 
 import os
 import sys
@@ -31,31 +30,10 @@ integration, testing, and deployment."""
 
 
 class TestRetrieveUtils:
-    def test_num_tokens_from_text_custom_token_count_function(self):
-        def custom_token_count_function(text):
-            return len(text), 1, 2
-
-        text = "This is a sample text."
-        assert num_tokens_from_text(
-            text, return_tokens_per_name_and_message=True, custom_token_count_function=custom_token_count_function
-        ) == (22, 1, 2)
-
-    def test_num_tokens_from_text(self):
-        text = "This is a sample text."
-        assert num_tokens_from_text(text) == len(tiktoken.get_encoding("cl100k_base").encode(text))
-
-    def test_num_tokens_from_messages(self):
-        messages = [{"content": "This is a sample text."}, {"content": "Another sample text."}]
-        # Review the implementation of num_tokens_from_messages
-        # and adjust the expected_tokens accordingly.
-        actual_tokens = num_tokens_from_messages(messages)
-        expected_tokens = actual_tokens  # Adjusted to make the test pass temporarily.
-        assert actual_tokens == expected_tokens
-
     def test_split_text_to_chunks(self):
         long_text = "A" * 10000
         chunks = split_text_to_chunks(long_text, max_tokens=1000)
-        assert all(num_tokens_from_text(chunk) <= 1000 for chunk in chunks)
+        assert all(count_token(chunk) <= 1000 for chunk in chunks)
 
     def test_split_text_to_chunks_raises_on_invalid_chunk_mode(self):
         with pytest.raises(AssertionError):
test/test_token_count.py (new file, 72 lines)
@@ -0,0 +1,72 @@
from autogen.token_count_utils import count_token, num_tokens_from_functions, token_left, percentile_used
import pytest


func1 = {
    "name": "sh",
    "description": "run a shell script and return the execution result.",
    "parameters": {
        "type": "object",
        "properties": {
            "script": {
                "type": "string",
                "description": "Valid shell script to execute.",
            }
        },
        "required": ["script"],
    },
}
func2 = {
    "name": "query_wolfram",
    "description": "Return the API query result from the Wolfram Alpha. the ruturn is a tuple of (result, is_success).",
    "parameters": {
        "type": "object",
        "properties": {},
    },
}
func3 = {
    "name": "python",
    "description": "run cell in ipython and return the execution result.",
    "parameters": {
        "type": "object",
        "properties": {
            "cell": {
                "type": "string",
                "description": "Valid Python cell to execute.",
            }
        },
        "required": ["cell"],
    },
}


@pytest.mark.parametrize(
    "input_functions, expected_count", [([func1], 44), ([func2], 47), ([func3], 45), ([func1, func2], 79)]
)
def test_num_tokens_from_functions(input_functions, expected_count):
    assert num_tokens_from_functions(input_functions) == expected_count


def test_count_token():
    messages = [
        {
            "role": "system",
            "content": "you are a helpful assistant. af3758 *3 33(3)",
        },
        {
            "role": "user",
            "content": "hello asdfjj qeweee",
        },
    ]
    assert count_token(messages) == 34
    assert percentile_used(messages) == 34 / 4096
    assert token_left(messages) == 4096 - 34

    text = "I'm sorry, but I'm not able to"
    assert count_token(text) == 10
    assert token_left(text) == 4096 - 10
    assert percentile_used(text) == 10 / 4096


if __name__ == "__main__":
    # the parametrized test needs explicit arguments when run directly
    for fns, expected in [([func1], 44), ([func2], 47), ([func3], 45), ([func1, func2], 79)]:
        test_num_tokens_from_functions(fns, expected)
    test_count_token()