From 78a2d8448ddf38e53cab85d67e5541de63ce8f95 Mon Sep 17 00:00:00 2001 From: bitnom <14287229+bitnom@users.noreply.github.com> Date: Mon, 8 Jan 2024 01:22:31 -0500 Subject: [PATCH] Handle streamed function calls (#1118) * Handle streamed function calls * apply black formatting * rm unnecessary stdout print * bug fix --------- Co-authored-by: Davor Runje Co-authored-by: Eric Zhu --- autogen/oai/client.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index fdba6e108..feaf56a76 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -287,9 +287,8 @@ class OpenAIWrapper: def _completions_create(self, client, params): completions = client.chat.completions if "messages" in params else client.completions - # If streaming is enabled, has messages, and does not have functions or tools, then - # iterate over the chunks of the response - if params.get("stream", False) and "messages" in params and "functions" not in params and "tools" not in params: + # If streaming is enabled and has messages, then iterate over the chunks of the response. + if params.get("stream", False) and "messages" in params: response_contents = [""] * params.get("n", 1) finish_reasons = [""] * params.get("n", 1) completion_tokens = 0 @@ -297,12 +296,33 @@ class OpenAIWrapper: # Set the terminal text color to green print("\033[32m", end="") + # Prepare for potential function call + full_function_call = None # Send the chat completion request to OpenAI's API and process the response in chunks for chunk in completions.create(**params): if chunk.choices: for choice in chunk.choices: content = choice.delta.content + function_call_chunk = choice.delta.function_call finish_reasons[choice.index] = choice.finish_reason + + # Handle function call + if function_call_chunk: + if hasattr(function_call_chunk, "name") and function_call_chunk.name: + if full_function_call is None: + full_function_call = {"name": "", "arguments": ""} + full_function_call["name"] += function_call_chunk.name + completion_tokens += 1 + if hasattr(function_call_chunk, "arguments") and function_call_chunk.arguments: + full_function_call["arguments"] += function_call_chunk.arguments + completion_tokens += 1 + if choice.finish_reason == "function_call": + # Need something here? I don't think so. + pass + if not content: + continue + # End handle function call + # If content is present, print it to the terminal and update response variables if content is not None: print(content, end="", flush=True) @@ -336,7 +356,7 @@ class OpenAIWrapper: index=i, finish_reason=finish_reasons[i], message=ChatCompletionMessage( - role="assistant", content=response_contents[i], function_call=None + role="assistant", content=response_contents[i], function_call=full_function_call ), logprobs=None, ) @@ -346,17 +366,17 @@ class OpenAIWrapper: index=i, finish_reason=finish_reasons[i], message=ChatCompletionMessage( - role="assistant", content=response_contents[i], function_call=None + role="assistant", content=response_contents[i], function_call=full_function_call ), ) response.choices.append(choice) else: - # If streaming is not enabled, using functions, or tools, send a regular chat completion request - # Functions and Tools are not supported, so ensure streaming is disabled + # If streaming is not enabled, send a regular chat completion request params = params.copy() params["stream"] = False response = completions.create(**params) + return response def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None: