feat!: Add support for model family specification (#4856)

* Add support for model family specification

* spelling mistake

* lint, etc

* fixes
This commit is contained in:
Jack Gerrits
2024-12-30 15:09:21 -05:00
committed by GitHub
parent 190fcd15ed
commit cb1633b501
13 changed files with 152 additions and 58 deletions

View File

@@ -1,4 +1,4 @@
from ._model_client import ChatCompletionClient, ModelCapabilities
from ._model_client import ChatCompletionClient, ModelCapabilities, ModelFamily, ModelInfo # type: ignore
from ._types import (
AssistantMessage,
ChatCompletionTokenLogprob,
@@ -27,4 +27,6 @@ __all__ = [
"CreateResult",
"TopLogprob",
"ChatCompletionTokenLogprob",
"ModelFamily",
"ModelInfo",
]

View File

@@ -1,15 +1,10 @@
from __future__ import annotations
import warnings
from abc import ABC, abstractmethod
from typing import Mapping, Optional, Sequence
from typing import Literal, Mapping, Optional, Sequence, TypeAlias
from typing_extensions import (
Any,
AsyncGenerator,
Required,
TypedDict,
Union,
)
from typing_extensions import Any, AsyncGenerator, Required, TypedDict, Union, deprecated
from .. import CancellationToken
from .._component_config import ComponentLoader
@@ -17,12 +12,41 @@ from ..tools import Tool, ToolSchema
from ._types import CreateResult, LLMMessage, RequestUsage
class ModelFamily:
"""A model family is a group of models that share similar characteristics from a capabilities perspective. This is different to discrete supported features such as vision, function calling, and JSON output.
This namespace class holds constants for the model families that AutoGen understands. Other families definitely exist and can be represented by a string, however, AutoGen will treat them as unknown."""
GPT_4O = "gpt-4o"
O1 = "o1"
GPT_4 = "gpt-4"
GPT_35 = "gpt-35"
UNKNOWN = "unknown"
ANY: TypeAlias = Literal["gpt-4o", "o1", "gpt-4", "gpt-35", "unknown"]
def __new__(cls, *args: Any, **kwargs: Any) -> ModelFamily:
raise TypeError(f"{cls.__name__} is a namespace class and cannot be instantiated.")
@deprecated("Use the ModelInfo class instead ModelCapabilities.")
class ModelCapabilities(TypedDict, total=False):
    """Legacy capability flags for a model; ``ModelInfo`` is the replacement.

    All three keys are required even though the dict is declared ``total=False``.
    """

    vision: Required[bool]
    function_calling: Required[bool]
    json_output: Required[bool]
class ModelInfo(TypedDict, total=False):
vision: Required[bool]
"""True if the model supports vision, aka image input, otherwise False."""
function_calling: Required[bool]
"""True if the model supports function calling, otherwise False."""
json_output: Required[bool]
"""True if the model supports json output, otherwise False. Note: this is different to structured json."""
family: Required[ModelFamily.ANY | str]
"""Model family should be one of the constants from :py:class:`ModelFamily` or a string representing an unknown model family."""
class ChatCompletionClient(ABC, ComponentLoader):
# Caching has to be handled internally as they can depend on the create args that were stored in the constructor
@abstractmethod
@@ -63,6 +87,18 @@ class ChatCompletionClient(ABC, ComponentLoader):
# Abstract hook with no body here: every concrete client has to provide its own implementation.
# Takes the message sequence plus an optional keyword-only `tools` sequence and returns an int;
# presumably the number of tokens still available after accounting for them (confirm against implementations).
# NOTE(review): `tools` defaults to a shared list literal; overrides copying this signature should not mutate it.
@abstractmethod
def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int: ...
# Deprecated: the fallback warning in `model_info` states this property will be removed and
# asks client classes to implement `model_info` instead. It remains abstract, so concrete clients must still define it.
# NOTE(review): the two `def capabilities` lines below look like the before/after sides of a diff
# rendered without +/- markers; they differ only by the trailing `# type: ignore`.
@property
@abstractmethod
def capabilities(self) -> ModelCapabilities: ...
def capabilities(self) -> ModelCapabilities: ... # type: ignore
@property
@abstractmethod
def model_info(self) -> ModelInfo:
    """Return the capability flags and model family for this client.

    The property is abstract, so concrete clients must override it. The body
    below only runs when an override explicitly delegates to it (for example
    through ``super().model_info``). In that case it emits a warning that the
    client relies on the deprecated ``capabilities`` property and derives the
    result from it.

    Returns:
        A new ``ModelInfo`` dict containing every entry of ``capabilities``
        plus ``family`` set to ``ModelFamily.UNKNOWN``.
    """
    warnings.warn(
        "Model client in use does not implement model_info property. Falling back to capabilities property. The capabilities property is deprecated and will be removed soon, please implement model_info instead in the model client class.",
        stacklevel=2,
    )
    # Build a fresh dict instead of writing into the object returned by
    # ``capabilities``: a client may return a dict it keeps around, and adding
    # ``family`` in place would leak that key into the deprecated property.
    base_info: ModelInfo = {**self.capabilities, "family": ModelFamily.UNKNOWN}  # type: ignore
    return base_info

View File

@@ -11,10 +11,11 @@ from autogen_core.models import (
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
ModelCapabilities,
ModelCapabilities, # type: ignore
RequestUsage,
UserMessage,
)
from autogen_core.models._model_client import ModelFamily, ModelInfo
from autogen_core.tool_agent import (
InvalidToolArgumentsException,
ToolAgent,
@@ -138,8 +139,12 @@ async def test_caller_loop() -> None:
return 0
@property
def capabilities(self) -> ModelCapabilities:
return ModelCapabilities(vision=False, function_calling=True, json_output=False)
def capabilities(self) -> ModelCapabilities: # type: ignore
return ModelCapabilities(vision=False, function_calling=True, json_output=False) # type: ignore
@property
def model_info(self) -> ModelInfo:
return ModelInfo(vision=False, function_calling=True, json_output=False, family=ModelFamily.UNKNOWN)
client = MockChatCompletionClient()
tools: List[Tool] = [FunctionTool(_pass_function, name="pass", description="Pass function")]