mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Allow model client to accept the tool schema directly (#196)
This commit is contained in:
@@ -11,7 +11,7 @@ from typing_extensions import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from ..tools import Tool
|
||||
from ..tools import Tool, ToolSchema
|
||||
from ._types import CreateResult, LLMMessage, RequestUsage
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ class ChatCompletionClient(Protocol):
|
||||
async def create(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
# None means do not override the default
|
||||
# A value means to override the client default - often specified in the constructor
|
||||
json_output: Optional[bool] = None,
|
||||
@@ -37,7 +37,7 @@ class ChatCompletionClient(Protocol):
|
||||
def create_stream(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
# None means do not override the default
|
||||
# A value means to override the client default - often specified in the constructor
|
||||
json_output: Optional[bool] = None,
|
||||
|
||||
@@ -38,7 +38,7 @@ from .. import (
|
||||
FunctionCall,
|
||||
Image,
|
||||
)
|
||||
from ..tools import Tool
|
||||
from ..tools import Tool, ToolSchema
|
||||
from . import _model_info
|
||||
from ._model_client import ChatCompletionClient, ModelCapabilities
|
||||
from ._types import (
|
||||
@@ -205,11 +205,16 @@ def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
|
||||
|
||||
|
||||
def convert_tools(
|
||||
tools: Sequence[Tool],
|
||||
tools: Sequence[Tool | ToolSchema],
|
||||
) -> List[ChatCompletionToolParam]:
|
||||
result: List[ChatCompletionToolParam] = []
|
||||
for tool in tools:
|
||||
tool_schema = tool.schema
|
||||
if isinstance(tool, Tool):
|
||||
tool_schema = tool.schema
|
||||
else:
|
||||
assert isinstance(tool, dict)
|
||||
tool_schema = tool
|
||||
|
||||
result.append(
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
@@ -287,7 +292,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
async def create(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
json_output: Optional[bool] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
) -> CreateResult:
|
||||
@@ -393,7 +398,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
async def create_stream(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
json_output: Optional[bool] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
) -> AsyncGenerator[Union[str, CreateResult], None]:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from ._base import BaseTool, BaseToolWithState, Tool
|
||||
from ._base import BaseTool, BaseToolWithState, Tool, ToolSchema
|
||||
from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool
|
||||
from ._function_tool import FunctionTool
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"ToolSchema",
|
||||
"BaseTool",
|
||||
"BaseToolWithState",
|
||||
"PythonCodeExecutionTool",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired
|
||||
@@ -24,6 +24,7 @@ class ToolSchema(TypedDict):
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Tool(Protocol):
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Annotated
|
||||
import pytest
|
||||
from agnext.components._function_utils import get_typed_signature
|
||||
from agnext.components.tools import BaseTool, FunctionTool
|
||||
from agnext.components.models._openai_client import convert_tools
|
||||
from agnext.core import CancellationToken
|
||||
from pydantic import BaseModel, Field, model_serializer
|
||||
from pydantic_core import PydanticUndefined
|
||||
@@ -286,3 +287,25 @@ async def test_func_int_res()-> None:
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
result = await tool.run_json({"arg": 5}, CancellationToken())
|
||||
assert tool.return_value_as_string(result) == "5"
|
||||
|
||||
|
||||
def test_convert_tools_accepts_both_func_tool_and_schema() -> None:
|
||||
def my_function(arg: str, other: Annotated[int, "int arg"], nonrequired: int = 5) -> MyResult:
|
||||
return MyResult(result="test")
|
||||
tool = FunctionTool(my_function, description="Function tool.")
|
||||
schema = tool.schema
|
||||
|
||||
converted_tool_schema = convert_tools([tool, schema])
|
||||
|
||||
assert len(converted_tool_schema) == 2
|
||||
assert converted_tool_schema[0] == converted_tool_schema[1]
|
||||
|
||||
|
||||
def test_convert_tools_accepts_both_tool_and_schema() -> None:
|
||||
tool = MyTool()
|
||||
schema = tool.schema
|
||||
|
||||
converted_tool_schema = convert_tools([tool, schema])
|
||||
|
||||
assert len(converted_tool_schema) == 2
|
||||
assert converted_tool_schema[0] == converted_tool_schema[1]
|
||||
|
||||
Reference in New Issue
Block a user