mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
fix: use tool_calls field to detect tool calls in OpenAI client; add integration tests for OpenAI and Gemini (#5122)
* fix: use tool_calls field to detect tool calls in OpenAI client
* Add unit tests for tool calling, and integration tests for OpenAI and Gemini
This commit is contained in:
@@ -539,20 +539,33 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
if self._resolved_model is not None:
|
||||
if self._resolved_model != result.model:
|
||||
warnings.warn(
|
||||
f"Resolved model mismatch: {self._resolved_model} != {result.model}. Model mapping may be incorrect.",
|
||||
f"Resolved model mismatch: {self._resolved_model} != {result.model}. "
|
||||
"Model mapping in autogen_ext.models.openai may be incorrect.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Limited to a single choice currently.
|
||||
choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], Choice] = result.choices[0]
|
||||
if choice.finish_reason == "function_call":
|
||||
raise ValueError("Function calls are not supported in this context")
|
||||
|
||||
# Detect whether it is a function call or not.
|
||||
# We don't rely on choice.finish_reason as it is not always accurate, depending on the API used.
|
||||
content: Union[str, List[FunctionCall]]
|
||||
if choice.finish_reason == "tool_calls":
|
||||
assert choice.message.tool_calls is not None
|
||||
assert choice.message.function_call is None
|
||||
|
||||
if choice.message.function_call is not None:
|
||||
raise ValueError("function_call is deprecated and is not supported by this model client.")
|
||||
elif choice.message.tool_calls is not None:
|
||||
if choice.finish_reason != "tool_calls":
|
||||
warnings.warn(
|
||||
f"Finish reason mismatch: {choice.finish_reason} != tool_calls "
|
||||
"when tool_calls are present. Finish reason may not be accurate. "
|
||||
"This may be due to the API used that is not returning the correct finish reason.",
|
||||
stacklevel=2,
|
||||
)
|
||||
if choice.message.content is not None and choice.message.content != "":
|
||||
warnings.warn(
|
||||
"Both tool_calls and content are present in the message. "
|
||||
"This is unexpected. content will be ignored, tool_calls will be used.",
|
||||
stacklevel=2,
|
||||
)
|
||||
# NOTE: If OAI response type changes, this will need to be updated
|
||||
content = [
|
||||
FunctionCall(
|
||||
@@ -562,10 +575,11 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
)
|
||||
for x in choice.message.tool_calls
|
||||
]
|
||||
finish_reason = "function_calls"
|
||||
finish_reason = "tool_calls"
|
||||
else:
|
||||
finish_reason = choice.finish_reason
|
||||
content = choice.message.content or ""
|
||||
|
||||
logprobs: Optional[List[ChatCompletionTokenLogprob]] = None
|
||||
if choice.logprobs and choice.logprobs.content:
|
||||
logprobs = [
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Annotated, Any, AsyncGenerator, Generic, List, Literal, Tuple, TypeVar
|
||||
import os
|
||||
from typing import Annotated, Any, AsyncGenerator, Dict, Generic, List, Literal, Tuple, TypeVar
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from autogen_core import CancellationToken, Image
|
||||
from autogen_core import CancellationToken, FunctionCall, Image
|
||||
from autogen_core.models import (
|
||||
AssistantMessage,
|
||||
CreateResult,
|
||||
@@ -26,10 +27,31 @@ from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta
|
||||
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class _MockChatCompletion:
    """Test double that replays a fixed list of chat completions.

    Each call to :meth:`mock_create` records the keyword arguments it was
    given and returns the next completion from the list supplied at
    construction time.
    """

    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
        # Canned responses, handed out in order, one per call.
        self._saved_chat_completions = chat_completions
        self.curr_index = 0
        # Keyword arguments of every call, so tests can inspect what was sent.
        self.calls: List[Dict[str, Any]] = []

    async def mock_create(
        self, *args: Any, **kwargs: Any
    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
        """Record the call's kwargs and return the next saved completion.

        Raises:
            IndexError: If called more times than there are saved completions.
        """
        self.calls.append(kwargs)  # Save the call
        await asyncio.sleep(0.1)
        completion = self._saved_chat_completions[self.curr_index]
        self.curr_index += 1
        return completion
|
||||
|
||||
|
||||
# Type variable for the parsed response model of a ParsedChatCompletion;
# restricted to pydantic BaseModel subclasses.
ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)
|
||||
|
||||
|
||||
class _MockBetaChatCompletion(Generic[ResponseFormatT]):
    """Test double that replays a fixed list of parsed chat completions.

    Each call to :meth:`mock_parse` records the full keyword arguments it was
    given and returns the next parsed completion from the list supplied at
    construction time.
    """

    def __init__(self, chat_completions: List[ParsedChatCompletion[ResponseFormatT]]) -> None:
        # Canned responses, handed out in order, one per call.
        self._saved_chat_completions = chat_completions
        self.curr_index = 0
        # Keyword arguments of every call (not just the messages), so tests
        # can inspect everything that was sent.
        self.calls: List[Dict[str, Any]] = []

    async def mock_parse(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> ParsedChatCompletion[ResponseFormatT]:
        """Record the call's kwargs and return the next saved completion.

        Raises:
            IndexError: If called more times than there are saved completions.
        """
        self.calls.append(kwargs)  # Save the call
        await asyncio.sleep(0.1)
        completion = self._saved_chat_completions[self.curr_index]
        self.curr_index += 1
        return completion
|
||||
|
||||
|
||||
def _pass_function(input: str) -> str:
    """Synchronous stub tool: ignore ``input`` and return ``"pass"``."""
    # The parameter name ``input`` is kept as-is: the canned tool calls in the
    # tests pass their arguments under the key "input".
    return "pass"
|
||||
|
||||
|
||||
async def _fail_function(input: str) -> str:
    """Asynchronous stub tool: ignore ``input`` and return ``"fail"``.

    Despite the name, this coroutine does not raise; it returns the string.
    """
    return "fail"
|
||||
|
||||
|
||||
async def _echo_function(input: str) -> str:
    """Asynchronous stub tool: return ``input`` unchanged."""
    return input
|
||||
|
||||
|
||||
class MyResult(BaseModel):
    """Pydantic model with a single described string field, ``result``."""

    result: str = Field(description="The other description.")
|
||||
|
||||
@@ -432,3 +466,292 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
== "The user explicitly states that they are happy without any indication of sadness or neutrality."
|
||||
)
|
||||
assert response.response == "happy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
    """Check tool-call detection against canned (mocked) chat completions.

    Four responses are replayed in order: a single tool call, three parallel
    tool calls, tool calls with a mismatching ``finish_reason``, and tool
    calls accompanied by text content. In every case the client is expected
    to return the tool calls as ``FunctionCall`` items and report the finish
    reason ``"function_calls"``; the last two cases must also emit a warning.
    """
    model = "gpt-4o-2024-05-13"
    # Every canned tool call carries the same JSON arguments.
    task_arguments = json.dumps({"input": "task"})

    def _tool_call_completion(
        completion_id: str,
        finish_reason: Literal["stop", "tool_calls"],
        tool_names: List[str],
        content: str | None = None,
    ) -> ChatCompletion:
        # Build one canned response whose single choice calls ``tool_names``
        # in order, with call ids "1", "2", ...
        return ChatCompletion(
            id=completion_id,
            choices=[
                Choice(
                    finish_reason=finish_reason,
                    index=0,
                    message=ChatCompletionMessage(
                        content=content,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id=str(call_number),
                                type="function",
                                function=Function(name=tool_name, arguments=task_arguments),
                            )
                            for call_number, tool_name in enumerate(tool_names, start=1)
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        )

    chat_completions = [
        # Successful completion, single tool call
        _tool_call_completion("id1", "tool_calls", ["_pass_function"]),
        # Successful completion, parallel tool calls
        _tool_call_completion("id2", "tool_calls", ["_pass_function", "_fail_function", "_echo_function"]),
        # Warning completion when finish reason is not tool_calls.
        _tool_call_completion("id3", "stop", ["_pass_function"]),
        # Warning completion when content is not None.
        _tool_call_completion("id4", "tool_calls", ["_pass_function"], content="I should make a tool call."),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    pass_tool = FunctionTool(_pass_function, description="pass tool.")
    fail_tool = FunctionTool(_fail_function, description="fail tool.")
    echo_tool = FunctionTool(_echo_function, description="echo tool.")
    model_client = OpenAIChatCompletionClient(model=model, api_key="")

    # Single tool call
    create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
    assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
    # Verify that the tool schema was passed to the model client.
    kwargs = mock.calls[0]
    assert kwargs["tools"] == [{"function": pass_tool.schema, "type": "function"}]
    # Verify finish reason
    assert create_result.finish_reason == "function_calls"

    # Parallel tool calls
    create_result = await model_client.create(
        messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool, fail_tool, echo_tool]
    )
    assert create_result.content == [
        FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function"),
        FunctionCall(id="2", arguments=r'{"input": "task"}', name="_fail_function"),
        FunctionCall(id="3", arguments=r'{"input": "task"}', name="_echo_function"),
    ]
    # Verify that the tool schema was passed to the model client.
    kwargs = mock.calls[1]
    assert kwargs["tools"] == [
        {"function": pass_tool.schema, "type": "function"},
        {"function": fail_tool.schema, "type": "function"},
        {"function": echo_tool.schema, "type": "function"},
    ]
    # Verify finish reason
    assert create_result.finish_reason == "function_calls"

    # Warning completion when finish reason is not tool_calls.
    with pytest.warns(UserWarning, match="Finish reason mismatch"):
        create_result = await model_client.create(
            messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
        )
    assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
    assert create_result.finish_reason == "function_calls"

    # Warning completion when content is not None.
    with pytest.warns(UserWarning, match="Both tool_calls and content are present in the message"):
        create_result = await model_client.create(
            messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
        )
    assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
    assert create_result.finish_reason == "function_calls"
|
||||
|
||||
|
||||
async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
    """Run a shared end-to-end scenario against ``model_client``.

    Covers, in order: a plain text completion, a single tool call, a follow-up
    completion after the tool result is appended, two parallel tool calls, and
    a follow-up completion after both tool results are appended.
    """
    # Test basic completion
    create_result = await model_client.create(
        messages=[
            SystemMessage(content="You are a helpful assistant."),
            UserMessage(content="Explain to me how AI works.", source="user"),
        ]
    )
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0

    # Test tool calling
    pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
    fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
    messages: List[LLMMessage] = [UserMessage(content="Call the pass tool with input 'task'", source="user")]
    create_result = await model_client.create(messages=messages, tools=[pass_tool, fail_tool])
    assert isinstance(create_result.content, list)
    assert len(create_result.content) == 1
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == "pass_tool"
    assert json.loads(create_result.content[0].arguments) == {"input": "task"}
    assert create_result.finish_reason == "function_calls"
    assert create_result.usage is not None

    # Test reflection on tool call response.
    messages.append(AssistantMessage(content=create_result.content, source="assistant"))
    messages.append(
        FunctionExecutionResultMessage(
            content=[FunctionExecutionResult(content="passed", call_id=create_result.content[0].id)]
        )
    )
    create_result = await model_client.create(messages=messages)
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0

    # Test parallel tool calling
    messages = [
        UserMessage(
            content="Call both the pass tool with input 'task' and the fail tool also with input 'task'", source="user"
        )
    ]
    create_result = await model_client.create(messages=messages, tools=[pass_tool, fail_tool])
    assert isinstance(create_result.content, list)
    assert len(create_result.content) == 2
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == "pass_tool"
    assert json.loads(create_result.content[0].arguments) == {"input": "task"}
    assert isinstance(create_result.content[1], FunctionCall)
    assert create_result.content[1].name == "fail_tool"
    assert json.loads(create_result.content[1].arguments) == {"input": "task"}
    assert create_result.finish_reason == "function_calls"
    assert create_result.usage is not None

    # Test reflection on parallel tool call response.
    messages.append(AssistantMessage(content=create_result.content, source="assistant"))
    messages.append(
        FunctionExecutionResultMessage(
            content=[
                FunctionExecutionResult(content="passed", call_id=create_result.content[0].id),
                FunctionExecutionResult(content="failed", call_id=create_result.content[1].id),
            ]
        )
    )
    create_result = await model_client.create(messages=messages)
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai() -> None:
    """Integration test against ``gpt-4o-mini``.

    Skipped when the ``OPENAI_API_KEY`` environment variable is unset or empty.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        pytest.skip("OPENAI_API_KEY not found in environment variables")

    model_client = OpenAIChatCompletionClient(
        model="gpt-4o-mini",
        api_key=api_key,
    )
    await _test_model_client(model_client)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gemini() -> None:
    """Integration test against ``gemini-1.5-flash`` using the OpenAI client.

    The client is pointed at the ``generativelanguage.googleapis.com``
    ``/v1beta/openai/`` base URL and given explicit ``model_info``. Skipped
    when the ``GEMINI_API_KEY`` environment variable is unset or empty.
    """
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        pytest.skip("GEMINI_API_KEY not found in environment variables")

    model_client = OpenAIChatCompletionClient(
        model="gemini-1.5-flash",
        api_key=api_key,
        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
        # Capabilities are declared explicitly via model_info for this model.
        model_info={
            "function_calling": True,
            "json_output": True,
            "vision": True,
            "family": ModelFamily.UNKNOWN,
        },
    )
    await _test_model_client(model_client)
|
||||
|
||||
|
||||
# TODO: add integration tests for Azure OpenAI using AAD token.
|
||||
|
||||
Reference in New Issue
Block a user