mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-09 22:38:08 -05:00
Fix not supported field warnings in count_tokens_openai (#6987)
This commit is contained in:
@@ -137,7 +137,10 @@ async def test_token_limited_model_context_with_token_limit(
|
|||||||
await model_context.add_message(msg)
|
await model_context.add_message(msg)
|
||||||
|
|
||||||
retrieved = await model_context.get_messages()
|
retrieved = await model_context.get_messages()
|
||||||
assert len(retrieved) == 1 # Token limit set very low, will remove 2 of the messages
|
# Token limit set low, will remove some messages
|
||||||
|
# OpenAI: keeps 2 messages (29 tokens with limit 30)
|
||||||
|
# Ollama: keeps 1 message (20 tokens with limit 20)
|
||||||
|
assert len(retrieved) < len(messages) # Some messages removed due to token limit
|
||||||
assert retrieved != messages # Will not be equal to the original messages
|
assert retrieved != messages # Will not be equal to the original messages
|
||||||
|
|
||||||
await model_context.clear()
|
await model_context.clear()
|
||||||
@@ -151,7 +154,7 @@ async def test_token_limited_model_context_with_token_limit(
|
|||||||
await model_context.clear()
|
await model_context.clear()
|
||||||
await model_context.load_state(state)
|
await model_context.load_state(state)
|
||||||
retrieved = await model_context.get_messages()
|
retrieved = await model_context.get_messages()
|
||||||
assert len(retrieved) == 1
|
assert len(retrieved) < len(messages) # Some messages removed due to token limit
|
||||||
assert retrieved != messages
|
assert retrieved != messages
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -393,6 +393,17 @@ def count_tokens_openai(
|
|||||||
elif field == "description":
|
elif field == "description":
|
||||||
tool_tokens += 2
|
tool_tokens += 2
|
||||||
tool_tokens += len(encoding.encode(v["description"])) # pyright: ignore
|
tool_tokens += len(encoding.encode(v["description"])) # pyright: ignore
|
||||||
|
elif field == "anyOf":
|
||||||
|
tool_tokens -= 3
|
||||||
|
for o in v["anyOf"]: # type: ignore
|
||||||
|
tool_tokens += 3
|
||||||
|
tool_tokens += len(encoding.encode(str(o["type"]))) # pyright: ignore
|
||||||
|
elif field == "default":
|
||||||
|
tool_tokens += 2
|
||||||
|
tool_tokens += len(encoding.encode(json.dumps(v["default"])))
|
||||||
|
elif field == "title":
|
||||||
|
tool_tokens += 2
|
||||||
|
tool_tokens += len(encoding.encode(str(v["title"]))) # pyright: ignore
|
||||||
elif field == "enum":
|
elif field == "enum":
|
||||||
tool_tokens -= 3
|
tool_tokens -= 3
|
||||||
for o in v["enum"]: # pyright: ignore
|
for o in v["enum"]: # pyright: ignore
|
||||||
@@ -404,7 +415,9 @@ def count_tokens_openai(
|
|||||||
if len(parameters["properties"]) == 0: # pyright: ignore
|
if len(parameters["properties"]) == 0: # pyright: ignore
|
||||||
tool_tokens -= 2
|
tool_tokens -= 2
|
||||||
num_tokens += tool_tokens
|
num_tokens += tool_tokens
|
||||||
num_tokens += 12
|
|
||||||
|
if oai_tools:
|
||||||
|
num_tokens += 12
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar
|
from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Optional, Tuple, TypeVar
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -450,11 +450,27 @@ async def test_openai_chat_completion_client_count_tokens(monkeypatch: pytest.Mo
|
|||||||
def tool2(test1: int, test2: List[int]) -> str:
|
def tool2(test1: int, test2: List[int]) -> str:
|
||||||
return str(test1) + str(test2)
|
return str(test1) + str(test2)
|
||||||
|
|
||||||
tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")]
|
def tool3(test1: Annotated[Optional[str], "example"] = None, test2: Literal["1", "2"] = "2") -> str:
|
||||||
|
return str(test1) + str(test2)
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
FunctionTool(tool1, description="example tool 1"),
|
||||||
|
FunctionTool(tool2, description="example tool 2"),
|
||||||
|
FunctionTool(tool3, description="example tool 3"),
|
||||||
|
]
|
||||||
|
|
||||||
mockcalculate_vision_tokens = MagicMock()
|
mockcalculate_vision_tokens = MagicMock()
|
||||||
monkeypatch.setattr("autogen_ext.models.openai._openai_client.calculate_vision_tokens", mockcalculate_vision_tokens)
|
monkeypatch.setattr("autogen_ext.models.openai._openai_client.calculate_vision_tokens", mockcalculate_vision_tokens)
|
||||||
|
|
||||||
|
# Test count_tokens without tools
|
||||||
|
num_tokens = client.count_tokens(messages)
|
||||||
|
assert num_tokens
|
||||||
|
|
||||||
|
# Check that calculate_vision_tokens was called
|
||||||
|
mockcalculate_vision_tokens.assert_called_once()
|
||||||
|
mockcalculate_vision_tokens.reset_mock()
|
||||||
|
|
||||||
|
# Test count_tokens with tools
|
||||||
num_tokens = client.count_tokens(messages, tools=tools)
|
num_tokens = client.count_tokens(messages, tools=tools)
|
||||||
assert num_tokens
|
assert num_tokens
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user