mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Added support for streaming tool calls (#1184)
* added support for streaming tool calls * bug fix: removed tmp assert --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
|
||||
import sys
|
||||
@@ -13,12 +16,21 @@ except ImportError:
|
||||
else:
|
||||
skip = False or skip_openai
|
||||
|
||||
# raises exception if openai>=1 is installed and something is wrong with imports
|
||||
# otherwise the test will be skipped
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
ChoiceDeltaFunctionCall,
|
||||
ChoiceDeltaToolCall,
|
||||
ChoiceDeltaToolCallFunction,
|
||||
)
|
||||
from openai.types.chat.chat_completion import ChatCompletionMessage # type: ignore [attr-defined]
|
||||
|
||||
KEY_LOC = "notebook"
|
||||
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_aoai_chat_completion_stream():
|
||||
def test_aoai_chat_completion_stream() -> None:
|
||||
config_list = config_list_from_json(
|
||||
env_or_file=OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
@@ -31,7 +43,7 @@ def test_aoai_chat_completion_stream():
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_chat_completion_stream():
|
||||
def test_chat_completion_stream() -> None:
|
||||
config_list = config_list_from_json(
|
||||
env_or_file=OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
@@ -43,8 +55,147 @@ def test_chat_completion_stream():
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
# no need for OpenAI, works with any model
|
||||
def test__update_dict_from_chunk() -> None:
|
||||
# dictionaries and lists are not supported
|
||||
mock = MagicMock()
|
||||
empty_collections: List[Union[List[Any], Dict[str, Any]]] = [{}, []]
|
||||
for c in empty_collections:
|
||||
mock.c = c
|
||||
with pytest.raises(NotImplementedError):
|
||||
OpenAIWrapper._update_dict_from_chunk(mock, {}, "c")
|
||||
|
||||
org_d: Dict[str, Any] = {}
|
||||
for i, v in enumerate([0, 1, False, True, 0.0, 1.0]):
|
||||
field = "abcedfghijklmnopqrstuvwxyz"[i]
|
||||
setattr(mock, field, v)
|
||||
|
||||
d = org_d.copy()
|
||||
OpenAIWrapper._update_dict_from_chunk(mock, d, field)
|
||||
|
||||
org_d[field] = v
|
||||
assert d == org_d
|
||||
|
||||
mock.s = "beginning"
|
||||
OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
|
||||
assert d["s"] == "beginning"
|
||||
|
||||
mock.s = " and"
|
||||
OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
|
||||
assert d["s"] == "beginning and"
|
||||
|
||||
mock.s = " end"
|
||||
OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
|
||||
assert d["s"] == "beginning and end"
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_chat_functions_stream():
|
||||
def test__update_function_call_from_chunk() -> None:
|
||||
function_call_chunks = [
|
||||
ChoiceDeltaFunctionCall(arguments=None, name="get_current_weather"),
|
||||
ChoiceDeltaFunctionCall(arguments='{"', name=None),
|
||||
ChoiceDeltaFunctionCall(arguments="location", name=None),
|
||||
ChoiceDeltaFunctionCall(arguments='":"', name=None),
|
||||
ChoiceDeltaFunctionCall(arguments="San", name=None),
|
||||
ChoiceDeltaFunctionCall(arguments=" Francisco", name=None),
|
||||
ChoiceDeltaFunctionCall(arguments='"}', name=None),
|
||||
]
|
||||
expected = {"name": "get_current_weather", "arguments": '{"location":"San Francisco"}'}
|
||||
|
||||
full_function_call = None
|
||||
completion_tokens = 0
|
||||
for function_call_chunk in function_call_chunks:
|
||||
# print(f"{function_call_chunk=}")
|
||||
full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
|
||||
function_call_chunk=function_call_chunk,
|
||||
full_function_call=full_function_call,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
print(f"{full_function_call=}")
|
||||
print(f"{completion_tokens=}")
|
||||
|
||||
assert full_function_call == expected
|
||||
assert completion_tokens == len(function_call_chunks)
|
||||
|
||||
ChatCompletionMessage(role="assistant", function_call=full_function_call, content=None)
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test__update_tool_calls_from_chunk() -> None:
|
||||
tool_calls_chunks = [
|
||||
ChoiceDeltaToolCall(
|
||||
index=0,
|
||||
id="call_D2HOWGMekmkxXu9Ix3DUqJRv",
|
||||
function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"),
|
||||
type="function",
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"lo', name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="catio", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='n": "S', name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="an F", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="ranci", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="sco, C", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='A"}', name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1,
|
||||
id="call_22HgJep4nwoKU3UOr96xaLmd",
|
||||
function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"),
|
||||
type="function",
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"lo', name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="catio", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='n": "N', name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="ew Y", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="ork, ", name=None), type=None
|
||||
),
|
||||
ChoiceDeltaToolCall(
|
||||
index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='NY"}', name=None), type=None
|
||||
),
|
||||
]
|
||||
|
||||
full_tool_calls: List[Optional[Dict[str, Any]]] = [None, None]
|
||||
completion_tokens = 0
|
||||
for tool_calls_chunk in tool_calls_chunks:
|
||||
index = tool_calls_chunk.index
|
||||
full_tool_calls[index], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk(
|
||||
tool_calls_chunk=tool_calls_chunk,
|
||||
full_tool_call=full_tool_calls[index],
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
print(f"{full_tool_calls=}")
|
||||
print(f"{completion_tokens=}")
|
||||
|
||||
ChatCompletionMessage(role="assistant", tool_calls=full_tool_calls, content=None)
|
||||
|
||||
|
||||
# todo: remove when OpenAI removes functions from the API
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_chat_functions_stream() -> None:
|
||||
config_list = config_list_from_json(
|
||||
env_or_file=OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
@@ -76,8 +227,63 @@ def test_chat_functions_stream():
|
||||
print(client.extract_text_or_completion_object(response))
|
||||
|
||||
|
||||
# test for tool support instead of the deprecated function calls
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_completion_stream():
|
||||
def test_chat_tools_stream() -> None:
|
||||
config_list = config_list_from_json(
|
||||
env_or_file=OAI_CONFIG_LIST,
|
||||
file_location=KEY_LOC,
|
||||
filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
|
||||
)
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
print(f"{config_list=}")
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(
|
||||
# the intention is to trigger two tool invocations as a response to a single message
|
||||
messages=[{"role": "user", "content": "What's the weather like today in San Francisco and New York?"}],
|
||||
tools=tools,
|
||||
stream=True,
|
||||
)
|
||||
print(f"{response=}")
|
||||
print(f"{type(response)=}")
|
||||
print(f"{client.extract_text_or_completion_object(response)=}")
|
||||
# check response
|
||||
choices = response.choices
|
||||
assert isinstance(choices, list)
|
||||
assert len(choices) == 1
|
||||
choice = choices[0]
|
||||
assert choice.finish_reason == "tool_calls"
|
||||
message = choice.message
|
||||
tool_calls = message.tool_calls
|
||||
assert isinstance(tool_calls, list)
|
||||
assert len(tool_calls) == 2
|
||||
arguments = [tool_call.function.arguments for tool_call in tool_calls]
|
||||
locations = [json.loads(argument)["location"] for argument in arguments]
|
||||
print(f"{locations=}")
|
||||
assert any(["San Francisco" in location for location in locations])
|
||||
assert any(["New York" in location for location in locations])
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
|
||||
def test_completion_stream() -> None:
|
||||
config_list = config_list_openai_aoai(KEY_LOC)
|
||||
client = OpenAIWrapper(config_list=config_list)
|
||||
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
|
||||
|
||||
Reference in New Issue
Block a user