mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
Allow model client to accept the tool schema directly (#196)
This commit is contained in:
@@ -11,7 +11,7 @@ from typing_extensions import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from ..tools import Tool
|
||||
from ..tools import Tool, ToolSchema
|
||||
from ._types import CreateResult, LLMMessage, RequestUsage
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ class ChatCompletionClient(Protocol):
|
||||
async def create(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
# None means do not override the default
|
||||
# A value means to override the client default - often specified in the constructor
|
||||
json_output: Optional[bool] = None,
|
||||
@@ -37,7 +37,7 @@ class ChatCompletionClient(Protocol):
|
||||
def create_stream(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
# None means do not override the default
|
||||
# A value means to override the client default - often specified in the constructor
|
||||
json_output: Optional[bool] = None,
|
||||
|
||||
@@ -38,7 +38,7 @@ from .. import (
|
||||
FunctionCall,
|
||||
Image,
|
||||
)
|
||||
from ..tools import Tool
|
||||
from ..tools import Tool, ToolSchema
|
||||
from . import _model_info
|
||||
from ._model_client import ChatCompletionClient, ModelCapabilities
|
||||
from ._types import (
|
||||
@@ -205,11 +205,16 @@ def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
|
||||
|
||||
|
||||
def convert_tools(
|
||||
tools: Sequence[Tool],
|
||||
tools: Sequence[Tool | ToolSchema],
|
||||
) -> List[ChatCompletionToolParam]:
|
||||
result: List[ChatCompletionToolParam] = []
|
||||
for tool in tools:
|
||||
tool_schema = tool.schema
|
||||
if isinstance(tool, Tool):
|
||||
tool_schema = tool.schema
|
||||
else:
|
||||
assert isinstance(tool, dict)
|
||||
tool_schema = tool
|
||||
|
||||
result.append(
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
@@ -287,7 +292,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
async def create(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
json_output: Optional[bool] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
) -> CreateResult:
|
||||
@@ -393,7 +398,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
async def create_stream(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
tools: Sequence[Tool] = [],
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
json_output: Optional[bool] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
) -> AsyncGenerator[Union[str, CreateResult], None]:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from ._base import BaseTool, BaseToolWithState, Tool
|
||||
from ._base import BaseTool, BaseToolWithState, Tool, ToolSchema
|
||||
from ._code_execution import CodeExecutionInput, CodeExecutionResult, PythonCodeExecutionTool
|
||||
from ._function_tool import FunctionTool
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"ToolSchema",
|
||||
"BaseTool",
|
||||
"BaseToolWithState",
|
||||
"PythonCodeExecutionTool",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar
|
||||
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import NotRequired
|
||||
@@ -24,6 +24,7 @@ class ToolSchema(TypedDict):
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Tool(Protocol):
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
Reference in New Issue
Block a user