mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-09 01:38:56 -05:00
Fix not supported field warnings in count_tokens_openai (#6987)
This commit is contained in:
@@ -137,7 +137,10 @@ async def test_token_limited_model_context_with_token_limit(
|
||||
await model_context.add_message(msg)
|
||||
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 1 # Token limit set very low, will remove 2 of the messages
|
||||
# Token limit set low, will remove some messages
|
||||
# OpenAI: keeps 2 messages (29 tokens with limit 30)
|
||||
# Ollama: keeps 1 message (20 tokens with limit 20)
|
||||
assert len(retrieved) < len(messages) # Some messages removed due to token limit
|
||||
assert retrieved != messages # Will not be equal to the original messages
|
||||
|
||||
await model_context.clear()
|
||||
@@ -151,7 +154,7 @@ async def test_token_limited_model_context_with_token_limit(
|
||||
await model_context.clear()
|
||||
await model_context.load_state(state)
|
||||
retrieved = await model_context.get_messages()
|
||||
assert len(retrieved) == 1
|
||||
assert len(retrieved) < len(messages) # Some messages removed due to token limit
|
||||
assert retrieved != messages
|
||||
|
||||
|
||||
|
||||
@@ -393,6 +393,17 @@ def count_tokens_openai(
|
||||
elif field == "description":
|
||||
tool_tokens += 2
|
||||
tool_tokens += len(encoding.encode(v["description"])) # pyright: ignore
|
||||
elif field == "anyOf":
|
||||
tool_tokens -= 3
|
||||
for o in v["anyOf"]: # type: ignore
|
||||
tool_tokens += 3
|
||||
tool_tokens += len(encoding.encode(str(o["type"]))) # pyright: ignore
|
||||
elif field == "default":
|
||||
tool_tokens += 2
|
||||
tool_tokens += len(encoding.encode(json.dumps(v["default"])))
|
||||
elif field == "title":
|
||||
tool_tokens += 2
|
||||
tool_tokens += len(encoding.encode(str(v["title"]))) # pyright: ignore
|
||||
elif field == "enum":
|
||||
tool_tokens -= 3
|
||||
for o in v["enum"]: # pyright: ignore
|
||||
@@ -404,7 +415,9 @@ def count_tokens_openai(
|
||||
if len(parameters["properties"]) == 0: # pyright: ignore
|
||||
tool_tokens -= 2
|
||||
num_tokens += tool_tokens
|
||||
num_tokens += 12
|
||||
|
||||
if oai_tools:
|
||||
num_tokens += 12
|
||||
return num_tokens
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Tuple, TypeVar
|
||||
from typing import Annotated, Any, AsyncGenerator, Dict, List, Literal, Optional, Tuple, TypeVar
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpx
|
||||
@@ -450,11 +450,27 @@ async def test_openai_chat_completion_client_count_tokens(monkeypatch: pytest.Mo
|
||||
def tool2(test1: int, test2: List[int]) -> str:
|
||||
return str(test1) + str(test2)
|
||||
|
||||
tools = [FunctionTool(tool1, description="example tool 1"), FunctionTool(tool2, description="example tool 2")]
|
||||
def tool3(test1: Annotated[Optional[str], "example"] = None, test2: Literal["1", "2"] = "2") -> str:
|
||||
return str(test1) + str(test2)
|
||||
|
||||
tools = [
|
||||
FunctionTool(tool1, description="example tool 1"),
|
||||
FunctionTool(tool2, description="example tool 2"),
|
||||
FunctionTool(tool3, description="example tool 3"),
|
||||
]
|
||||
|
||||
mockcalculate_vision_tokens = MagicMock()
|
||||
monkeypatch.setattr("autogen_ext.models.openai._openai_client.calculate_vision_tokens", mockcalculate_vision_tokens)
|
||||
|
||||
# Test count_tokens without tools
|
||||
num_tokens = client.count_tokens(messages)
|
||||
assert num_tokens
|
||||
|
||||
# Check that calculate_vision_tokens was called
|
||||
mockcalculate_vision_tokens.assert_called_once()
|
||||
mockcalculate_vision_tokens.reset_mock()
|
||||
|
||||
# Test count_tokens with tools
|
||||
num_tokens = client.count_tokens(messages, tools=tools)
|
||||
assert num_tokens
|
||||
|
||||
|
||||
Reference in New Issue
Block a user