mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Implement functions API compatibility mode for older OpenAI models
This commit is contained in:
@@ -3,7 +3,7 @@ import functools
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
from typing import Callable, Optional, ParamSpec, TypeVar
|
||||
|
||||
import openai
|
||||
import tiktoken
|
||||
@@ -16,6 +16,7 @@ from autogpt.core.configuration import (
|
||||
)
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
AssistantChatMessageDict,
|
||||
AssistantFunctionCallDict,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
@@ -33,6 +34,7 @@ from autogpt.core.resource.model_providers.schema import (
|
||||
ModelProviderUsage,
|
||||
ModelTokenizer,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
@@ -263,11 +265,17 @@ class OpenAIProvider(
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: OpenAIModelName,
|
||||
completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
|
||||
functions: list[CompletionModelFunction] = [],
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
"""Create a completion using the OpenAI API."""
|
||||
|
||||
completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs)
|
||||
functions_compat_mode = functions and "functions" not in completion_kwargs
|
||||
if "messages" in completion_kwargs:
|
||||
model_prompt += completion_kwargs["messages"]
|
||||
del completion_kwargs["messages"]
|
||||
|
||||
response = await self._create_chat_completion(
|
||||
messages=model_prompt,
|
||||
**completion_kwargs,
|
||||
@@ -279,6 +287,10 @@ class OpenAIProvider(
|
||||
}
|
||||
|
||||
response_message = response.choices[0].message.to_dict_recursive()
|
||||
if functions_compat_mode:
|
||||
response_message["function_call"] = _functions_compat_extract_call(
|
||||
response_message["content"]
|
||||
)
|
||||
response = ChatModelResponse(
|
||||
response=response_message,
|
||||
parsed_result=completion_parser(response_message),
|
||||
@@ -313,7 +325,7 @@ class OpenAIProvider(
|
||||
def _get_completion_kwargs(
|
||||
self,
|
||||
model_name: OpenAIModelName,
|
||||
functions: list[CompletionModelFunction],
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Get kwargs for completion API call.
|
||||
@@ -331,8 +343,13 @@ class OpenAIProvider(
|
||||
**kwargs,
|
||||
**self._credentials.unmasked(),
|
||||
}
|
||||
|
||||
if functions:
|
||||
completion_kwargs["functions"] = [f.schema for f in functions]
|
||||
if OPEN_AI_CHAT_MODELS[model_name].has_function_call_api:
|
||||
completion_kwargs["functions"] = [f.schema for f in functions]
|
||||
else:
|
||||
# Provide compatibility with older models
|
||||
_functions_compat_fix_kwargs(functions, completion_kwargs)
|
||||
|
||||
return completion_kwargs
|
||||
|
||||
@@ -459,3 +476,129 @@ class _OpenAIRetryHandler:
|
||||
self._backoff(attempt)
|
||||
|
||||
return _wrapped
|
||||
|
||||
|
||||
def format_function_specs_as_typescript_ns(
|
||||
functions: list[CompletionModelFunction],
|
||||
) -> str:
|
||||
"""Returns a function signature block in the format used by OpenAI internally:
|
||||
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
|
||||
|
||||
For use with `count_tokens` to determine token usage of provided functions.
|
||||
|
||||
Example:
|
||||
```ts
|
||||
namespace functions {
|
||||
|
||||
// Get the current weather in a given location
|
||||
type get_current_weather = (_: {
|
||||
// The city and state, e.g. San Francisco, CA
|
||||
location: string,
|
||||
unit?: "celsius" | "fahrenheit",
|
||||
}) => any;
|
||||
|
||||
} // namespace functions
|
||||
```
|
||||
"""
|
||||
|
||||
return (
|
||||
"namespace functions {\n\n"
|
||||
+ "\n\n".join(format_openai_function_for_prompt(f) for f in functions)
|
||||
+ "\n\n} // namespace functions"
|
||||
)
|
||||
|
||||
|
||||
def format_openai_function_for_prompt(func: CompletionModelFunction) -> str:
|
||||
"""Returns the function formatted similarly to the way OpenAI does it internally:
|
||||
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
|
||||
|
||||
Example:
|
||||
```ts
|
||||
// Get the current weather in a given location
|
||||
type get_current_weather = (_: {
|
||||
// The city and state, e.g. San Francisco, CA
|
||||
location: string,
|
||||
unit?: "celsius" | "fahrenheit",
|
||||
}) => any;
|
||||
```
|
||||
"""
|
||||
|
||||
def param_signature(name: str, spec: JSONSchema) -> str:
|
||||
return (
|
||||
f"// {spec.description}\n" if spec.description else ""
|
||||
) + f"{name}{'' if spec.required else '?'}: {spec.typescript_type},"
|
||||
|
||||
return "\n".join(
|
||||
[
|
||||
f"// {func.description}",
|
||||
f"type {func.name} = (_ :{{",
|
||||
*[param_signature(name, p) for name, p in func.parameters.items()],
|
||||
"}) => any;",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def count_openai_functions_tokens(
|
||||
functions: list[CompletionModelFunction], count_tokens: Callable[[str], int]
|
||||
) -> int:
|
||||
"""Returns the number of tokens taken up by a set of function definitions
|
||||
|
||||
Reference: https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/18
|
||||
"""
|
||||
return count_tokens(
|
||||
f"# Tools\n\n## functions\n\n{format_function_specs_as_typescript_ns(functions)}"
|
||||
)
|
||||
|
||||
|
||||
def _functions_compat_fix_kwargs(
|
||||
functions: list[CompletionModelFunction],
|
||||
completion_kwargs: dict,
|
||||
):
|
||||
function_definitions = format_function_specs_as_typescript_ns(functions)
|
||||
function_call_schema = JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
properties={
|
||||
"name": JSONSchema(
|
||||
description="The name of the function to call",
|
||||
enum=[f.name for f in functions],
|
||||
required=True,
|
||||
),
|
||||
"arguments": JSONSchema(
|
||||
description="The arguments for the function call",
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
completion_kwargs["messages"] = [
|
||||
ChatMessage.system(
|
||||
"# function_call instructions\n\n"
|
||||
"Specify a '```function_call' block in your response,"
|
||||
" enclosing a function call in the form of a valid JSON object"
|
||||
" that adheres to the following schema:\n\n"
|
||||
f"{function_call_schema.to_dict()}\n\n"
|
||||
"Put the function_call block at the end of your response"
|
||||
" and include its fences if it is not the only content.\n\n"
|
||||
"## functions\n\n"
|
||||
"For the function call itself, use one of the following"
|
||||
f" functions:\n\n{function_definitions}"
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _functions_compat_extract_call(response: str) -> AssistantFunctionCallDict:
|
||||
import json
|
||||
import re
|
||||
|
||||
logging.debug(f"Trying to extract function call from response:\n{response}")
|
||||
|
||||
if response[0] == "{":
|
||||
function_call = json.loads(response)
|
||||
else:
|
||||
block = re.search(r"```(?:function_call)?\n(.*)\n```\s*$", response, re.DOTALL)
|
||||
if not block:
|
||||
raise ValueError("Could not find function call block in response")
|
||||
function_call = json.loads(block.group(1))
|
||||
|
||||
function_call["arguments"] = str(function_call["arguments"]) # HACK
|
||||
return function_call
|
||||
|
||||
@@ -333,7 +333,7 @@ class ChatModelProvider(ModelProvider):
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: str,
|
||||
completion_parser: Callable[[AssistantChatMessageDict], _T] = lambda _: None,
|
||||
functions: list[CompletionModelFunction] = [],
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
...
|
||||
|
||||
Reference in New Issue
Block a user