Handle streamed function calls (#1118)

* Handle streamed function calls

* Apply black formatting

* Remove unnecessary stdout print

* Bug fix

---------

Co-authored-by: Davor Runje <davor@airt.ai>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Author: bitnom
Date: 2024-01-08 01:22:31 -05:00
Committed by: GitHub
Parent: 1c4ae3d303
Commit: 78a2d8448d


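Background for the diff below: with stream=True, the OpenAI chat API delivers a function call incrementally. The first delta typically carries the function name, and later deltas carry fragments of the JSON arguments string, so the fragments must be concatenated before the call is usable. A minimal sketch of that accumulation against the raw openai 1.x client (the model name and function schema are illustrative, not taken from this commit):

import openai

client = openai.OpenAI()
full_function_call = None
stream = client.chat.completions.create(
    model="gpt-4",  # illustrative model name
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    functions=[{
        "name": "get_weather",  # hypothetical function schema
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }],
    stream=True,
)
for chunk in stream:
    for choice in chunk.choices:
        fc = choice.delta.function_call
        if fc:
            if full_function_call is None:
                full_function_call = {"name": "", "arguments": ""}
            # The name arrives once; the arguments arrive as string fragments
            full_function_call["name"] += fc.name or ""
            full_function_call["arguments"] += fc.arguments or ""
print(full_function_call)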
@@ -287,9 +287,8 @@ class OpenAIWrapper:
     def _completions_create(self, client, params):
         completions = client.chat.completions if "messages" in params else client.completions
-        # If streaming is enabled, has messages, and does not have functions or tools, then
-        # iterate over the chunks of the response
-        if params.get("stream", False) and "messages" in params and "functions" not in params and "tools" not in params:
+        # If streaming is enabled and has messages, then iterate over the chunks of the response.
+        if params.get("stream", False) and "messages" in params:
             response_contents = [""] * params.get("n", 1)
             finish_reasons = [""] * params.get("n", 1)
             completion_tokens = 0
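With the relaxed condition above, a request that includes "functions" (or "tools") now takes the streaming path instead of being routed to the non-streaming branch. A hypothetical params dict that the old check rejected and the new check accepts:

params = {
    "model": "gpt-4",  # placeholder model
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    "functions": [{"name": "get_weather", "parameters": {"type": "object", "properties": {}}}],  # hypothetical schema
    "stream": True,
    "n": 1,
}
assert params.get("stream", False) and "messages" in params  # new condition holds
assert "functions" in params  # the old condition would have routed this to the non-streaming branch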
@@ -297,12 +296,33 @@ class OpenAIWrapper:
             # Set the terminal text color to green
             print("\033[32m", end="")
+            # Prepare for a potential streamed function call
+            full_function_call = None
             # Send the chat completion request to OpenAI's API and process the response in chunks
             for chunk in completions.create(**params):
                 if chunk.choices:
                     for choice in chunk.choices:
                         content = choice.delta.content
+                        function_call_chunk = choice.delta.function_call
                         finish_reasons[choice.index] = choice.finish_reason
+                        # Handle function call: accumulate the streamed name and argument fragments
+                        if function_call_chunk:
+                            if full_function_call is None:
+                                full_function_call = {"name": "", "arguments": ""}
+                            if hasattr(function_call_chunk, "name") and function_call_chunk.name:
+                                full_function_call["name"] += function_call_chunk.name
+                                completion_tokens += 1
+                            if hasattr(function_call_chunk, "arguments") and function_call_chunk.arguments:
+                                full_function_call["arguments"] += function_call_chunk.arguments
+                                completion_tokens += 1
+                            # Nothing extra is needed when finish_reason == "function_call";
+                            # it was already recorded in finish_reasons above.
+                            if not content:
+                                continue
+                        # End handle function call
                         # If content is present, print it to the terminal and update response variables
                         if content is not None:
                             print(content, end="", flush=True)
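After the stream is exhausted, full_function_call["arguments"] holds the complete JSON string assembled from the fragments, and callers typically parse it before dispatching. A sketch of that step; dispatch_function and registry are hypothetical helpers, not part of this commit:

import json

def dispatch_function(full_function_call, registry):
    # Hypothetical helper: parse the accumulated arguments and invoke the named function
    if full_function_call is None:
        return None
    args = json.loads(full_function_call["arguments"] or "{}")
    return registry[full_function_call["name"]](**args)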
@@ -336,7 +356,7 @@ class OpenAIWrapper:
                         index=i,
                         finish_reason=finish_reasons[i],
                         message=ChatCompletionMessage(
-                            role="assistant", content=response_contents[i], function_call=None
+                            role="assistant", content=response_contents[i], function_call=full_function_call
                         ),
                         logprobs=None,
                     )
@@ -346,17 +366,17 @@ class OpenAIWrapper:
                         index=i,
                         finish_reason=finish_reasons[i],
                         message=ChatCompletionMessage(
-                            role="assistant", content=response_contents[i], function_call=None
+                            role="assistant", content=response_contents[i], function_call=full_function_call
                         ),
                     )
                 response.choices.append(choice)
         else:
-            # If streaming is not enabled, using functions, or tools, send a regular chat completion request
-            # Functions and Tools are not supported, so ensure streaming is disabled
+            # If streaming is not enabled, send a regular chat completion request
             params = params.copy()
             params["stream"] = False
             response = completions.create(**params)
         return response

     def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None:
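
With the hunks above applied, a streamed request made through OpenAIWrapper can carry functions, and the accumulated call ends up on the final message. A usage sketch, assuming a configured OpenAI key; the function schema and key are placeholders:

from autogen import OpenAIWrapper

client = OpenAIWrapper(api_key="sk-...")  # placeholder key
response = client.create(
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    functions=[{
        "name": "get_weather",  # hypothetical function schema
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }],
    stream=True,
)
message = response.choices[0].message
if message.function_call:  # populated from the accumulated stream deltas
    print(message.function_call.name, message.function_call.arguments)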