Allow model client to accept the tool schema directly (#196)

This commit is contained in:
Jack Gerrits
2024-07-09 16:44:58 -04:00
committed by GitHub
parent 2191b3144b
commit 05e72084e8
5 changed files with 40 additions and 10 deletions

View File

@@ -11,7 +11,7 @@ from typing_extensions import (
Union,
)
from ..tools import Tool
from ..tools import Tool, ToolSchema
from ._types import CreateResult, LLMMessage, RequestUsage
@@ -27,7 +27,7 @@ class ChatCompletionClient(Protocol):
async def create(
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool] = [],
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,
@@ -37,7 +37,7 @@ class ChatCompletionClient(Protocol):
def create_stream(
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool] = [],
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,

View File

@@ -38,7 +38,7 @@ from .. import (
FunctionCall,
Image,
)
from ..tools import Tool
from ..tools import Tool, ToolSchema
from . import _model_info
from ._model_client import ChatCompletionClient, ModelCapabilities
from ._types import (
@@ -205,11 +205,16 @@ def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
def convert_tools(
tools: Sequence[Tool],
tools: Sequence[Tool | ToolSchema],
) -> List[ChatCompletionToolParam]:
result: List[ChatCompletionToolParam] = []
for tool in tools:
tool_schema = tool.schema
if isinstance(tool, Tool):
tool_schema = tool.schema
else:
assert isinstance(tool, dict)
tool_schema = tool
result.append(
ChatCompletionToolParam(
type="function",
@@ -287,7 +292,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
async def create(
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool] = [],
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> CreateResult:
@@ -393,7 +398,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
async def create_stream(
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool] = [],
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
) -> AsyncGenerator[Union[str, CreateResult], None]:

View File

@@ -1,9 +1,10 @@
from ._base import BaseTool, BaseToolWithState, Tool
from ._base import BaseTool, BaseToolWithState, Tool, ToolSchema
from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool
from ._function_tool import FunctionTool
__all__ = [
"Tool",
"ToolSchema",
"BaseTool",
"BaseToolWithState",
"PythonCodeExecutionTool",

View File

@@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar, runtime_checkable
from pydantic import BaseModel
from typing_extensions import NotRequired
@@ -24,6 +24,7 @@ class ToolSchema(TypedDict):
description: NotRequired[str]
@runtime_checkable
class Tool(Protocol):
@property
def name(self) -> str: ...

View File

@@ -5,6 +5,7 @@ from typing import Annotated
import pytest
from agnext.components._function_utils import get_typed_signature
from agnext.components.tools import BaseTool, FunctionTool
from agnext.components.models._openai_client import convert_tools
from agnext.core import CancellationToken
from pydantic import BaseModel, Field, model_serializer
from pydantic_core import PydanticUndefined
@@ -286,3 +287,25 @@ async def test_func_int_res()-> None:
tool = FunctionTool(my_function, description="Function tool.")
result = await tool.run_json({"arg": 5}, CancellationToken())
assert tool.return_value_as_string(result) == "5"
def test_convert_tools_accepts_both_func_tool_and_schema() -> None:
def my_function(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5) -> MyResult:
return MyResult(result="test")
tool = FunctionTool(my_function, description="Function tool.")
schema = tool.schema
converted_tool_schema = convert_tools([tool, schema])
assert len(converted_tool_schema) == 2
assert converted_tool_schema[0] == converted_tool_schema[1]
def test_convert_tools_accepts_both_tool_and_schema() -> None:
tool = MyTool()
schema = tool.schema
converted_tool_schema = convert_tools([tool, schema])
assert len(converted_tool_schema) == 2
assert converted_tool_schema[0] == converted_tool_schema[1]