Allow model client to accept the tool schema directly (#196)

This commit is contained in:
Jack Gerrits
2024-07-09 16:44:58 -04:00
committed by GitHub
parent 2191b3144b
commit 05e72084e8
5 changed files with 40 additions and 10 deletions

View File

@@ -11,7 +11,7 @@ from typing_extensions import (
Union,
)
-from ..tools import Tool
+from ..tools import Tool, ToolSchema
from ._types import CreateResult, LLMMessage, RequestUsage
@@ -27,7 +27,7 @@ class ChatCompletionClient(Protocol):
async def create(
self,
messages: Sequence[LLMMessage],
-tools: Sequence[Tool] = [],
+tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,
@@ -37,7 +37,7 @@ class ChatCompletionClient(Protocol):
def create_stream(
self,
messages: Sequence[LLMMessage],
-tools: Sequence[Tool] = [],
+tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,

View File

@@ -38,7 +38,7 @@ from .. import (
FunctionCall,
Image,
)
-from ..tools import Tool
+from ..tools import Tool, ToolSchema
from . import _model_info
from ._model_client import ChatCompletionClient, ModelCapabilities
from ._types import (
@@ -205,11 +205,16 @@ def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
def convert_tools(
-tools: Sequence[Tool],
+tools: Sequence[Tool | ToolSchema],
) -> List[ChatCompletionToolParam]:
result: List[ChatCompletionToolParam] = []
for tool in tools:
-tool_schema = tool.schema
+if isinstance(tool, Tool):
+    tool_schema = tool.schema
+else:
+    assert isinstance(tool, dict)
+    tool_schema = tool
result.append(
ChatCompletionToolParam(
type="function",
@@ -287,7 +292,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
async def create(
self,
messages: Sequence[LLMMessage],
-tools: Sequence[Tool] = [],
+tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> CreateResult:
@@ -393,7 +398,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
async def create_stream(
self,
messages: Sequence[LLMMessage],
-tools: Sequence[Tool] = [],
+tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> AsyncGenerator[Union[str, CreateResult], None]:

View File

@@ -1,9 +1,10 @@
-from ._base import BaseTool, BaseToolWithState, Tool
+from ._base import BaseTool, BaseToolWithState, Tool, ToolSchema
from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool
from ._function_tool import FunctionTool
__all__ = [
"Tool",
+"ToolSchema",
"BaseTool",
"BaseToolWithState",
"PythonCodeExecutionTool",

View File

@@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Sequence
-from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar
+from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar, runtime_checkable
from pydantic import BaseModel
from typing_extensions import NotRequired
@@ -24,6 +24,7 @@ class ToolSchema(TypedDict):
description: NotRequired[str]
+@runtime_checkable
class Tool(Protocol):
@property
def name(self) -> str: ...