fix: use tool_calls field to detect tool calls in OpenAI client; add integration tests for OpenAI and Gemini (#5122)

* fix: use tool_calls field to detect tool calls in OpenAI client

* Add unit tests for tool calling, and integration tests for OpenAI and Gemini
Author: Eric Zhu
Date: 2025-01-21 06:06:19 -08:00 (committed by GitHub)
Parent: e0a6a86b12
Commit: da1c2bf12e
2 changed files with 349 additions and 12 deletions
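
In short, the client now keys tool-call detection on the presence of choice.message.tool_calls rather than on choice.finish_reason. A condensed sketch of the new branching, simplified from the diff below (warning and error messages abbreviated; the FunctionCall list construction is elided, see the diff for the exact code):

    # Sketch only -- simplified from the change below.
    if choice.message.function_call is not None:
        raise ValueError("function_call is deprecated and is not supported by this model client.")
    elif choice.message.tool_calls is not None:
        if choice.finish_reason != "tool_calls":
            warnings.warn("Finish reason mismatch ...", stacklevel=2)  # the API's finish_reason is not trusted
        if choice.message.content is not None and choice.message.content != "":
            warnings.warn("Both tool_calls and content are present ...", stacklevel=2)  # content is ignored
        content = [...]  # one FunctionCall per entry in choice.message.tool_calls
        finish_reason = "tool_calls"
    else:
        finish_reason = choice.finish_reason
        content = choice.message.content or ""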


@@ -539,20 +539,33 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
         if self._resolved_model is not None:
             if self._resolved_model != result.model:
                 warnings.warn(
-                    f"Resolved model mismatch: {self._resolved_model} != {result.model}. Model mapping may be incorrect.",
+                    f"Resolved model mismatch: {self._resolved_model} != {result.model}. "
+                    "Model mapping in autogen_ext.models.openai may be incorrect.",
                     stacklevel=2,
                 )
         # Limited to a single choice currently.
         choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], Choice] = result.choices[0]
-        if choice.finish_reason == "function_call":
-            raise ValueError("Function calls are not supported in this context")
+        # Detect whether it is a function call or not.
+        # We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
         content: Union[str, List[FunctionCall]]
-        if choice.finish_reason == "tool_calls":
-            assert choice.message.tool_calls is not None
-            assert choice.message.function_call is None
+        if choice.message.function_call is not None:
+            raise ValueError("function_call is deprecated and is not supported by this model client.")
+        elif choice.message.tool_calls is not None:
+            if choice.finish_reason != "tool_calls":
+                warnings.warn(
+                    f"Finish reason mismatch: {choice.finish_reason} != tool_calls "
+                    "when tool_calls are present. Finish reason may not be accurate. "
+                    "This may be due to the API used that is not returning the correct finish reason.",
+                    stacklevel=2,
+                )
+            if choice.message.content is not None and choice.message.content != "":
+                warnings.warn(
+                    "Both tool_calls and content are present in the message. "
+                    "This is unexpected. content will be ignored, tool_calls will be used.",
+                    stacklevel=2,
+                )
             # NOTE: If OAI response type changes, this will need to be updated
             content = [
                 FunctionCall(
@@ -562,10 +575,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                 )
                 for x in choice.message.tool_calls
             ]
-            finish_reason = "function_calls"
+            finish_reason = "tool_calls"
         else:
             finish_reason = choice.finish_reason
             content = choice.message.content or ""
         logprobs: Optional[List[ChatCompletionTokenLogprob]] = None
         if choice.logprobs and choice.logprobs.content:
             logprobs = [
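
For callers, the result shape is unchanged: when the model issues tool calls, CreateResult.content is a list of FunctionCall objects and finish_reason is "function_calls", as the tests below assert. A minimal usage sketch mirroring those tests; the placeholder API key and the autogen_core.tools import path for FunctionTool are assumptions for illustration, not part of this diff:

    import asyncio
    import json

    from autogen_core.models import UserMessage
    from autogen_core.tools import FunctionTool  # assumed import path for FunctionTool
    from autogen_ext.models.openai import OpenAIChatCompletionClient


    def _pass_function(input: str) -> str:
        return "pass"


    async def main() -> None:
        client = OpenAIChatCompletionClient(model="gpt-4o-mini", api_key="...")  # placeholder credentials
        pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass tool.")
        result = await client.create(
            messages=[UserMessage(content="Call the pass tool with input 'task'", source="user")],
            tools=[pass_tool],
        )
        if isinstance(result.content, list):  # tool calls were detected via message.tool_calls
            for call in result.content:  # each entry is an autogen_core.FunctionCall
                print(call.name, json.loads(call.arguments))
            assert result.finish_reason == "function_calls"
        else:
            print(result.content)  # plain text reply


    asyncio.run(main())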


@@ -1,10 +1,11 @@
 import asyncio
 import json
-from typing import Annotated, Any, AsyncGenerator, Generic, List, Literal, Tuple, TypeVar
+import os
+from typing import Annotated, Any, AsyncGenerator, Dict, Generic, List, Literal, Tuple, TypeVar
 from unittest.mock import MagicMock
 import pytest
-from autogen_core import CancellationToken, Image
+from autogen_core import CancellationToken, FunctionCall, Image
 from autogen_core.models import (
     AssistantMessage,
     CreateResult,
@@ -26,10 +27,31 @@ from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta
 from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
+from openai.types.chat.chat_completion_message_tool_call import (
+    ChatCompletionMessageToolCall,
+    Function,
+)
 from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
 from openai.types.completion_usage import CompletionUsage
 from pydantic import BaseModel, Field
+
+
+class _MockChatCompletion:
+    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
+        self._saved_chat_completions = chat_completions
+        self.curr_index = 0
+        self.calls: List[Dict[str, Any]] = []
+
+    async def mock_create(
+        self, *args: Any, **kwargs: Any
+    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
+        self.calls.append(kwargs)  # Save the call
+        await asyncio.sleep(0.1)
+        completion = self._saved_chat_completions[self.curr_index]
+        self.curr_index += 1
+        return completion
+
+
 ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)
@@ -37,20 +59,32 @@ class _MockBetaChatCompletion(Generic[ResponseFormatT]):
     def __init__(self, chat_completions: List[ParsedChatCompletion[ResponseFormatT]]) -> None:
         self._saved_chat_completions = chat_completions
         self.curr_index = 0
-        self.calls: List[List[LLMMessage]] = []
+        self.calls: List[Dict[str, Any]] = []

     async def mock_parse(
         self,
         *args: Any,
         **kwargs: Any,
     ) -> ParsedChatCompletion[ResponseFormatT]:
-        self.calls.append(kwargs["messages"])
+        self.calls.append(kwargs)  # Save the call
         await asyncio.sleep(0.1)
         completion = self._saved_chat_completions[self.curr_index]
         self.curr_index += 1
         return completion
+
+
+def _pass_function(input: str) -> str:
+    return "pass"
+
+
+async def _fail_function(input: str) -> str:
+    return "fail"
+
+
+async def _echo_function(input: str) -> str:
+    return input
+
+
 class MyResult(BaseModel):
     result: str = Field(description="The other description.")
@@ -432,3 +466,292 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
)
assert response.response == "happy"
@pytest.mark.asyncio
async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
# Successful completion, single tool call
ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task"}),
),
)
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
# Successful completion, parallel tool calls
ChatCompletion(
id="id2",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task"}),
),
),
ChatCompletionMessageToolCall(
id="2",
type="function",
function=Function(
name="_fail_function",
arguments=json.dumps({"input": "task"}),
),
),
ChatCompletionMessageToolCall(
id="3",
type="function",
function=Function(
name="_echo_function",
arguments=json.dumps({"input": "task"}),
),
),
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
# Warning completion when finish reason is not tool_calls.
ChatCompletion(
id="id3",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task"}),
),
)
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
# Warning completion when content is not None.
ChatCompletion(
id="id4",
choices=[
Choice(
finish_reason="tool_calls",
index=0,
message=ChatCompletionMessage(
content="I should make a tool call.",
tool_calls=[
ChatCompletionMessageToolCall(
id="1",
type="function",
function=Function(
name="_pass_function",
arguments=json.dumps({"input": "task"}),
),
)
],
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
pass_tool = FunctionTool(_pass_function, description="pass tool.")
fail_tool = FunctionTool(_fail_function, description="fail tool.")
echo_tool = FunctionTool(_echo_function, description="echo tool.")
model_client = OpenAIChatCompletionClient(model=model, api_key="")
# Single tool call
create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
# Verify that the tool schema was passed to the model client.
kwargs = mock.calls[0]
assert kwargs["tools"] == [{"function": pass_tool.schema, "type": "function"}]
# Verify finish reason
assert create_result.finish_reason == "function_calls"
# Parallel tool calls
create_result = await model_client.create(
messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool, fail_tool, echo_tool]
)
assert create_result.content == [
FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function"),
FunctionCall(id="2", arguments=r'{"input": "task"}', name="_fail_function"),
FunctionCall(id="3", arguments=r'{"input": "task"}', name="_echo_function"),
]
# Verify that the tool schema was passed to the model client.
kwargs = mock.calls[1]
assert kwargs["tools"] == [
{"function": pass_tool.schema, "type": "function"},
{"function": fail_tool.schema, "type": "function"},
{"function": echo_tool.schema, "type": "function"},
]
# Verify finish reason
assert create_result.finish_reason == "function_calls"
# Warning completion when finish reason is not tool_calls.
with pytest.warns(UserWarning, match="Finish reason mismatch"):
create_result = await model_client.create(
messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
)
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
assert create_result.finish_reason == "function_calls"
# Warning completion when content is not None.
with pytest.warns(UserWarning, match="Both tool_calls and content are present in the message"):
create_result = await model_client.create(
messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
)
assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
assert create_result.finish_reason == "function_calls"
async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
# Test basic completion
create_result = await model_client.create(
messages=[
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Explain to me how AI works.", source="user"),
]
)
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0
# Test tool calling
pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
messages: List[LLMMessage] = [UserMessage(content="Call the pass tool with input 'task'", source="user")]
create_result = await model_client.create(messages=messages, tools=[pass_tool, fail_tool])
assert isinstance(create_result.content, list)
assert len(create_result.content) == 1
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == "pass_tool"
assert json.loads(create_result.content[0].arguments) == {"input": "task"}
assert create_result.finish_reason == "function_calls"
assert create_result.usage is not None
# Test reflection on tool call response.
messages.append(AssistantMessage(content=create_result.content, source="assistant"))
messages.append(
FunctionExecutionResultMessage(
content=[FunctionExecutionResult(content="passed", call_id=create_result.content[0].id)]
)
)
create_result = await model_client.create(messages=messages)
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0
# Test parallel tool calling
messages = [
UserMessage(
content="Call both the pass tool with input 'task' and the fail tool also with input 'task'", source="user"
)
]
create_result = await model_client.create(messages=messages, tools=[pass_tool, fail_tool])
assert isinstance(create_result.content, list)
assert len(create_result.content) == 2
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == "pass_tool"
assert json.loads(create_result.content[0].arguments) == {"input": "task"}
assert isinstance(create_result.content[1], FunctionCall)
assert create_result.content[1].name == "fail_tool"
assert json.loads(create_result.content[1].arguments) == {"input": "task"}
assert create_result.finish_reason == "function_calls"
assert create_result.usage is not None
# Test reflection on parallel tool call response.
messages.append(AssistantMessage(content=create_result.content, source="assistant"))
messages.append(
FunctionExecutionResultMessage(
content=[
FunctionExecutionResult(content="passed", call_id=create_result.content[0].id),
FunctionExecutionResult(content="failed", call_id=create_result.content[1].id),
]
)
)
create_result = await model_client.create(messages=messages)
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0
@pytest.mark.asyncio
async def test_openai() -> None:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
pytest.skip("OPENAI_API_KEY not found in environment variables")
model_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key=api_key,
)
await _test_model_client(model_client)
@pytest.mark.asyncio
async def test_gemini() -> None:
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
pytest.skip("GEMINI_API_KEY not found in environment variables")
model_client = OpenAIChatCompletionClient(
model="gemini-1.5-flash",
api_key=api_key,
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
model_info={
"function_calling": True,
"json_output": True,
"vision": True,
"family": ModelFamily.UNKNOWN,
},
)
await _test_model_client(model_client)
# TODO: add integration tests for Azure OpenAI using AAD token.
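
Note on running the integration tests: test_openai skips itself unless OPENAI_API_KEY is set and test_gemini skips itself unless GEMINI_API_KEY is set, so a plain pytest run stays offline; with a key exported, a selective invocation such as pytest -k "test_openai or test_gemini" (run from this package's test directory) exercises only the new end-to-end checks.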