Implement functions API compatibility mode for older OpenAI models

Reinier van der Leer
2023-10-08 18:05:08 -07:00
parent aae650fe3a
commit ad0c3ebf07
2 changed files with 148 additions and 5 deletions
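How the compat mode works, in short: when the target model has no native function call API, the function specs are rendered as a TypeScript-style `namespace functions` block and injected as a system message, and the model is asked to answer with a fenced `function_call` block containing a JSON object, which is then parsed back into an `AssistantFunctionCallDict`. The sketch below is a standalone illustration of that response contract, not part of the diff; the reply text and function name are made up, and the parsing simply mirrors `_functions_compat_extract_call` further down.

```python
import json
import re

FENCE = "`" * 3

# Hypothetical assistant reply from a model without the native functions API,
# following the injected "# function_call instructions" system message.
sample_reply = (
    "Looking that up now.\n\n"
    f"{FENCE}function_call\n"
    '{"name": "get_current_weather", '
    '"arguments": {"location": "San Francisco, CA"}}\n'
    f"{FENCE}"
)

# Same idea as _functions_compat_extract_call: take the JSON object from the
# trailing function_call block, or the whole reply if it is bare JSON.
match = re.search(r"```(?:function_call)?\n(.*)\n```\s*$", sample_reply, re.DOTALL)
call = json.loads(match.group(1)) if match else json.loads(sample_reply)
print(call["name"])       # get_current_weather
print(call["arguments"])  # {'location': 'San Francisco, CA'}
```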

View File

@@ -3,7 +3,7 @@ import functools
import logging
import math
import time
from typing import Callable, ParamSpec, TypeVar
from typing import Callable, Optional, ParamSpec, TypeVar
import openai
import tiktoken
@@ -16,6 +16,7 @@ from autogpt.core.configuration import (
)
from autogpt.core.resource.model_providers.schema import (
AssistantChatMessageDict,
AssistantFunctionCallDict,
ChatMessage,
ChatModelInfo,
ChatModelProvider,
@@ -33,6 +34,7 @@ from autogpt.core.resource.model_providers.schema import (
ModelProviderUsage,
ModelTokenizer,
)
from autogpt.core.utils.json_schema import JSONSchema
_T = TypeVar("_T")
_P = ParamSpec("_P")
@@ -263,11 +265,17 @@ class OpenAIProvider(
model_prompt: list[ChatMessage],
model_name: OpenAIModelName,
completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
functions: list[CompletionModelFunction] = [],
functions: Optional[list[CompletionModelFunction]] = None,
**kwargs,
) -> ChatModelResponse[_T]:
"""Create a completion using the OpenAI API."""
completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs)
# Fall back to prompt-based function calling if the model has no functions API
functions_compat_mode = functions and "functions" not in completion_kwargs
if "messages" in completion_kwargs:
model_prompt += completion_kwargs["messages"]
del completion_kwargs["messages"]
response = await self._create_chat_completion(
messages=model_prompt,
**completion_kwargs,
@@ -279,6 +287,10 @@ class OpenAIProvider(
}
response_message = response.choices[0].message.to_dict_recursive()
if functions_compat_mode:
response_message["function_call"] = _functions_compat_extract_call(
response_message["content"]
)
response = ChatModelResponse(
response=response_message,
parsed_result=completion_parser(response_message),
@@ -313,7 +325,7 @@ class OpenAIProvider(
def _get_completion_kwargs(
self,
model_name: OpenAIModelName,
functions: list[CompletionModelFunction],
functions: Optional[list[CompletionModelFunction]] = None,
**kwargs,
) -> dict:
"""Get kwargs for completion API call.
@@ -331,8 +343,13 @@ class OpenAIProvider(
**kwargs,
**self._credentials.unmasked(),
}
if functions:
completion_kwargs["functions"] = [f.schema for f in functions]
if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api:
completion_kwargs["functions"] = [f.schema for f in functions]
else:
# Provide compatibility with older models that have no native function call API
_functions_compat_fix_kwargs(functions, completion_kwargs)
return completion_kwargs
@@ -459,3 +476,129 @@ class _OpenAIRetryHandler:
self._backoff(attempt)
return _wrapped
def format_function_specs_as_typescript_ns(
functions: list[CompletionModelFunction],
) -> str:
"""Returns a function signature block in the format used by OpenAI internally:
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
For use with `count_tokens` to determine token usage of provided functions.
Example:
```ts
namespace functions {
// Get the current weather in a given location
type get_current_weather = (_: {
// The city and state, e.g. San Francisco, CA
location: string,
unit?: "celsius" | "fahrenheit",
}) => any;
} // namespace functions
```
"""
return (
"namespace functions {\n\n"
+ "\n\n".join(format_openai_function_for_prompt(f) for f in functions)
+ "\n\n} // namespace functions"
)
def format_openai_function_for_prompt(func: CompletionModelFunction) -> str:
"""Returns the function formatted similarly to the way OpenAI does it internally:
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
Example:
```ts
// Get the current weather in a given location
type get_current_weather = (_: {
// The city and state, e.g. San Francisco, CA
location: string,
unit?: "celsius" | "fahrenheit",
}) => any;
```
"""
def param_signature(name: str, spec: JSONSchema) -> str:
return (
f"// {spec.description}\n" if spec.description else ""
) + f"{name}{'' if spec.required else '?'}: {spec.typescript_type},"
return "\n".join(
[
f"// {func.description}",
f"type {func.name} = (_ :{{",
*[param_signature(name, p) for name, p in func.parameters.items()],
"}) => any;",
]
)
def count_openai_functions_tokens(
functions: list[CompletionModelFunction], count_tokens: Callable[[str], int]
) -> int:
"""Returns the number of tokens taken up by a set of function definitions
Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
"""
return count_tokens(
f"# Tools\n\n## functions\n\n{format_function_specs_as_typescript_ns(functions)}"
)
# Compat path: rewrite completion_kwargs so the function specs and call
# instructions are carried in a system message instead of the `functions` kwarg.
def _functions_compat_fix_kwargs(
functions: list[CompletionModelFunction],
completion_kwargs: dict,
):
function_definitions = format_function_specs_as_typescript_ns(functions)
function_call_schema = JSONSchema(
type=JSONSchema.Type.OBJECT,
properties={
"name": JSONSchema(
description="The name of the function to call",
enum=[f.name for f in functions],
required=True,
),
"arguments": JSONSchema(
description="The arguments for the function call",
type=JSONSchema.Type.OBJECT,
required=True,
),
},
)
completion_kwargs["messages"] = [
ChatMessage.system(
"# function_call instructions\n\n"
"Specify a '```function_call' block in your response,"
" enclosing a function call in the form of a valid JSON object"
" that adheres to the following schema:\n\n"
f"{function_call_schema.to_dict()}\n\n"
"Put the function_call block at the end of your response"
" and include its fences if it is not the only content.\n\n"
"## functions\n\n"
"For the function call itself, use one of the following"
f" functions:\n\n{function_definitions}"
),
]
# Parses the function call from a compat-mode response: either a bare JSON
# object or a trailing fenced function_call block.
def _functions_compat_extract_call(response: str) -> AssistantFunctionCallDict:
import json
import re
logging.debug(f"Trying to extract function call from response:\n{response}")
if response[0] == "{":
function_call = json.loads(response)
else:
block = re.search(r"```(?:function_call)?\n(.*)\n```\s*$", response, re.DOTALL)
if not block:
raise ValueError("Could not find function call block in response")
function_call = json.loads(block.group(1))
function_call["arguments"] = str(function_call["arguments"]) # HACK
return function_call
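For callers, the result shape is the same whichever path is taken: `_get_completion_kwargs` either forwards `functions` natively or rewrites the kwargs via `_functions_compat_fix_kwargs`, and `create_chat_completion` then ensures the parsed call ends up under `function_call` in the response message. A minimal, hypothetical call site (the provider instance, prompt, and function list are placeholders, not part of this diff):

```python
from typing import Optional

from autogpt.core.resource.model_providers.schema import AssistantFunctionCallDict


async def request_function_call(provider, model_name, prompt_messages, functions):
    # Works for both native-functions models and older ones; in the latter case
    # the provider injects the compat system message and extracts the call from
    # the fenced function_call block in the reply.
    response = await provider.create_chat_completion(
        model_prompt=prompt_messages,
        model_name=model_name,
        functions=functions,
    )
    call: Optional[AssistantFunctionCallDict] = response.response.get("function_call")
    return call
```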

View File

@@ -333,7 +333,7 @@ class ChatModelProvider(ModelProvider):
model_prompt: list[ChatMessage],
model_name: str,
completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
functions: list[CompletionModelFunction] = [],
functions: Optional[list[CompletionModelFunction]] = None,
**kwargs,
) -> ChatModelResponse[_T]:
...
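A side note on the signature change in both files: moving the default from `functions: list[CompletionModelFunction] = []` to `Optional[...] = None` avoids Python's shared-mutable-default pitfall and makes the "no functions" case explicit for `_get_completion_kwargs`. A quick standalone demonstration of that pitfall, unrelated to the diff's own identifiers:

```python
def append_to(item, bucket=[]):  # the default list is created once, at definition time
    bucket.append(item)
    return bucket

print(append_to(1))  # [1]
print(append_to(2))  # [1, 2]  <- the same list object is reused across calls
```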