mirror of
https://github.com/microsoft/autogen.git
synced 2026-02-12 00:04:56 -05:00
feat!: Add support for model family specification (#4856)
* Add support for model family specification * spelling mistake * lint, etc * fixes
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from ._model_client import ChatCompletionClient, ModelCapabilities
|
||||
from ._model_client import ChatCompletionClient, ModelCapabilities, ModelFamily, ModelInfo # type: ignore
|
||||
from ._types import (
|
||||
AssistantMessage,
|
||||
ChatCompletionTokenLogprob,
|
||||
@@ -27,4 +27,6 @@ __all__ = [
|
||||
"CreateResult",
|
||||
"TopLogprob",
|
||||
"ChatCompletionTokenLogprob",
|
||||
"ModelFamily",
|
||||
"ModelInfo",
|
||||
]
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Mapping, Optional, Sequence
|
||||
from typing import Literal, Mapping, Optional, Sequence, TypeAlias
|
||||
|
||||
from typing_extensions import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Required,
|
||||
TypedDict,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import Any, AsyncGenerator, Required, TypedDict, Union, deprecated
|
||||
|
||||
from .. import CancellationToken
|
||||
from .._component_config import ComponentLoader
|
||||
@@ -17,12 +12,41 @@ from ..tools import Tool, ToolSchema
|
||||
from ._types import CreateResult, LLMMessage, RequestUsage
|
||||
|
||||
|
||||
class ModelFamily:
|
||||
"""A model family is a group of models that share similar characteristics from a capabilities perspective. This is different to discrete supported features such as vision, function calling, and JSON output.
|
||||
|
||||
This namespace class holds constants for the model families that AutoGen understands. Other families definitely exist and can be represented by a string, however, AutoGen will treat them as unknown."""
|
||||
|
||||
GPT_4O = "gpt-4o"
|
||||
O1 = "o1"
|
||||
GPT_4 = "gpt-4"
|
||||
GPT_35 = "gpt-35"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "unknown"]
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
|
||||
raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
|
||||
|
||||
|
||||
@deprecated("Use the ModelInfo class instead ModelCapabilities.")
|
||||
class ModelCapabilities(TypedDict, total=False):
|
||||
vision: Required[bool]
|
||||
function_calling: Required[bool]
|
||||
json_output: Required[bool]
|
||||
|
||||
|
||||
class ModelInfo(TypedDict, total=False):
|
||||
vision: Required[bool]
|
||||
"""True if the model supports vision, aka image input, otherwise False."""
|
||||
function_calling: Required[bool]
|
||||
"""True if the model supports function calling, otherwise False."""
|
||||
json_output: Required[bool]
|
||||
"""True if the model supports json output, otherwise False. Note: this is different to structured json."""
|
||||
family: Required[ModelFamily.ANY | str]
|
||||
"""Model family should be one of the constants from :py:class:`ModelFamily` or a string representing an unknown model family."""
|
||||
|
||||
|
||||
class ChatCompletionClient(ABC, ComponentLoader):
|
||||
# Caching has to be handled internally as they can depend on the create args that were stored in the constructor
|
||||
@abstractmethod
|
||||
@@ -63,6 +87,18 @@ class ChatCompletionClient(ABC, ComponentLoader):
|
||||
@abstractmethod
|
||||
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...
|
||||
|
||||
# Deprecated
|
||||
@property
|
||||
@abstractmethod
|
||||
def capabilities(self) -> ModelCapabilities: ...
|
||||
def capabilities(self) -> ModelCapabilities: ... # type: ignore
|
||||
|
||||
@property
@abstractmethod
def model_info(self) -> ModelInfo:
    """Return the model's info, falling back to the deprecated ``capabilities``.

    The property is abstract, so concrete clients must override it. This
    body is the fallback for a client that has not implemented
    ``model_info`` itself (e.g. one that delegates here via ``super()``):
    it emits a warning and derives the info from ``capabilities``, with the
    family set to ``ModelFamily.UNKNOWN``.

    Returns:
        A new mapping holding the entries of ``self.capabilities`` plus a
        ``"family"`` key.
    """
    warnings.warn(
        "Model client in use does not implement model_info property. Falling back to capabilities property. The capabilities property is deprecated and will be removed soon, please implement model_info instead in the model client class.",
        stacklevel=2,
    )
    # Build a fresh dict instead of writing into the object returned by
    # `capabilities`: if a client returns a stored dict, mutating it would
    # add a "family" key to its capabilities state.
    base_info: ModelInfo = {**self.capabilities, "family": ModelFamily.UNKNOWN}  # type: ignore
    return base_info
|
||||
|
||||
@@ -11,10 +11,11 @@ from autogen_core.models import (
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
ModelCapabilities,
|
||||
ModelCapabilities, # type: ignore
|
||||
RequestUsage,
|
||||
UserMessage,
|
||||
)
|
||||
from autogen_core.models._model_client import ModelFamily, ModelInfo
|
||||
from autogen_core.tool_agent import (
|
||||
InvalidToolArgumentsException,
|
||||
ToolAgent,
|
||||
@@ -138,8 +139,12 @@ async def test_caller_loop() -> None:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def capabilities(self) -> ModelCapabilities:
|
||||
return ModelCapabilities(vision=False, function_calling=True, json_output=False)
|
||||
def capabilities(self) -> ModelCapabilities: # type: ignore
|
||||
return ModelCapabilities(vision=False, function_calling=True, json_output=False) # type: ignore
|
||||
|
||||
@property
|
||||
def model_info(self) -> ModelInfo:
|
||||
return ModelInfo(vision=False, function_calling=True, json_output=False, family=ModelFamily.UNKNOWN)
|
||||
|
||||
client = MockChatCompletionClient()
|
||||
tools: List[Tool] = [FunctionTool(_pass_function, name="pass", description="Pass function")]
|
||||
|
||||
Reference in New Issue
Block a user