feat: add structured output to model clients (#5936)

Author: Eric Zhu
Date: 2025-03-15 07:58:13 -07:00
Committed by: GitHub
Parent: 9bde5ef911
Commit: aba41d74d3
27 changed files with 1629 additions and 404 deletions
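This commit widens the `json_output` parameter of `ChatCompletionClient.create` and `create_stream` from `Optional[bool]` to `Optional[bool | type[BaseModel]]`, adds a `structured_output` capability flag to `ModelInfo`, and wires structured output through the OpenAI, Ollama and llama-cpp clients (Anthropic, Azure AI and the Semantic Kernel adapter reject it for now). A minimal caller-side sketch of the new parameter, assuming an OpenAI client with credentials in the environment; the `MovieReview` schema and prompt are illustrative only:

import asyncio

from pydantic import BaseModel
from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


class MovieReview(BaseModel):
    # Hypothetical schema, used only to illustrate the new parameter.
    title: str
    rating: int


async def main() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")
    # json_output=True        -> plain JSON mode
    # json_output=MovieReview -> structured output constrained to this schema
    result = await client.create(
        [UserMessage(content="Review the movie 'Inception' as JSON.", source="user")],
        json_output=MovieReview,
    )
    assert isinstance(result.content, str)
    print(MovieReview.model_validate_json(result.content))


asyncio.run(main())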

View File

@@ -13,6 +13,7 @@ from autogen_core.models import (
RequestUsage,
)
from autogen_core.tools import Tool, ToolSchema
from pydantic import BaseModel
from .page_logger import PageLogger
@@ -87,7 +88,7 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
@@ -154,7 +155,7 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:

View File

@@ -61,7 +61,7 @@ from autogen_core.models import (
validate_model_info,
)
from autogen_core.tools import Tool, ToolSchema
from pydantic import SecretStr
from pydantic import BaseModel, SecretStr
from typing_extensions import Self, Unpack
from . import _model_info
@@ -413,7 +413,7 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
@@ -435,6 +435,8 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
if json_output is True:
create_args["response_format"] = {"type": "json_object"}
elif isinstance(json_output, type):
raise ValueError("Structured output is currently not supported for Anthropic models")
# Process system message separately
system_message = None
@@ -568,7 +570,7 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
max_consecutive_empty_chunk_tolerance: int = 0,
@@ -595,6 +597,9 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
if json_output is True:
create_args["response_format"] = {"type": "json_object"}
if isinstance(json_output, type):
raise ValueError("Structured output is currently not supported for Anthropic models")
# Process system message separately
system_message = None
anthropic_messages: List[MessageParam] = []
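Per the hunks above, the Anthropic client keeps JSON mode for `json_output=True` but raises when a Pydantic class is passed. A rough sketch of the caller-visible behaviour, assuming an Anthropic API key in the environment; the model name and `Answer` schema are illustrative:

import asyncio

from pydantic import BaseModel
from autogen_core.models import UserMessage
from autogen_ext.models.anthropic import AnthropicChatCompletionClient


class Answer(BaseModel):  # hypothetical schema for illustration
    text: str


async def main() -> None:
    client = AnthropicChatCompletionClient(model="claude-3-5-sonnet-20240620")
    message = UserMessage(content="Reply in JSON.", source="user")

    # Still supported: JSON mode maps to response_format={"type": "json_object"}.
    await client.create([message], json_output=True)

    # Not supported yet: a BaseModel subclass is rejected up front.
    try:
        await client.create([message], json_output=Answer)
    except ValueError as e:
        print(e)


asyncio.run(main())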

View File

@@ -12,6 +12,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
"function_calling": True,
"json_output": True,
"family": ModelFamily.CLAUDE_3_7_SONNET,
"structured_output": False,
},
# Claude 3.7 Sonnet latest alias
"claude-3-7-sonnet-latest": {
@@ -19,6 +20,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
"function_calling": True,
"json_output": True,
"family": ModelFamily.CLAUDE_3_7_SONNET,
"structured_output": False,
},
# Claude 3 Opus (most powerful)
"claude-3-opus-20240229": {
@@ -26,6 +28,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
"function_calling": True,
"json_output": True,
"family": ModelFamily.CLAUDE_3_5_SONNET,
"structured_output": False,
},
# Claude 3 Sonnet (balanced)
"claude-3-sonnet-20240229": {
@@ -33,6 +36,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
"function_calling": True,
"json_output": True,
"family": ModelFamily.CLAUDE_3_5_SONNET,
"structured_output": False,
},
# Claude 3 Haiku (fastest)
"claude-3-haiku-20240307": {
@@ -40,6 +44,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
"function_calling": True,
"json_output": True,
"family": ModelFamily.CLAUDE_3_5_SONNET,
"structured_output": False,
},
# Claude 3.5 Sonnet
"claude-3-5-sonnet-20240620": {
@@ -47,6 +52,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
"function_calling": True,
"json_output": True,
"family": ModelFamily.CLAUDE_3_5_SONNET,
"structured_output": False,
},
# Claude Instant v1 (legacy)
"claude-instant-1.2": {
@@ -54,6 +60,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
"function_calling": False,
"json_output": True,
"family": ModelFamily.CLAUDE_3_5_SONNET,
"structured_output": False,
},
# Claude 2 (legacy)
"claude-2.0": {
@@ -61,6 +68,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
"function_calling": False,
"json_output": True,
"family": ModelFamily.CLAUDE_3_5_SONNET,
"structured_output": False,
},
# Claude 2.1 (legacy)
"claude-2.1": {
@@ -68,6 +76,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
"function_calling": False,
"json_output": True,
"family": ModelFamily.CLAUDE_3_5_SONNET,
"structured_output": False,
},
}
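Each Claude entry now carries the `structured_output` capability flag required by `ModelInfo`. A short sketch of consulting the flag before choosing between schema-constrained output and plain JSON mode; the `StockInfo` schema is illustrative:

from pydantic import BaseModel
from autogen_core.models import ModelFamily, ModelInfo, validate_model_info


class StockInfo(BaseModel):  # hypothetical schema for illustration
    ticker: str
    price: float


info: ModelInfo = {
    "vision": False,
    "function_calling": True,
    "json_output": True,
    "family": ModelFamily.CLAUDE_3_5_SONNET,
    "structured_output": False,
}
validate_model_info(info)

# Prefer structured output only when the model advertises support for it.
json_output: bool | type[BaseModel] = StockInfo if info["structured_output"] else True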

View File

@@ -52,6 +52,7 @@ from azure.ai.inference.models import (
from azure.ai.inference.models import (
UserMessage as AzureUserMessage,
)
from pydantic import BaseModel
from typing_extensions import AsyncGenerator, Union, Unpack
from autogen_ext.models.azure.config import (
@@ -226,6 +227,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
"function_calling": False,
"vision": False,
"family": "unknown",
"structured_output": False,
},
)
@@ -263,6 +265,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
"function_calling": False,
"vision": False,
"family": "unknown",
"structured_output": False,
},
)
@@ -323,7 +326,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool | ToolSchema],
json_output: Optional[bool],
json_output: Optional[bool | type[BaseModel]],
create_args: Dict[str, Any],
) -> None:
if self.model_info["vision"] is False:
@@ -336,6 +339,10 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output")
if isinstance(json_output, type):
# TODO: we should support this in the future.
raise ValueError("Structured output is not currently supported for AzureAIChatCompletionClient")
if json_output is True and "response_format" not in create_args:
create_args["response_format"] = "json_object"
@@ -349,7 +356,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
@@ -442,7 +449,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
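The Azure AI client accepts the widened parameter, but `_validate_config` rejects a Pydantic class before any request is issued, so the failure is immediate. A configuration sketch; the endpoint, key, model name and constructor arguments shown are placeholders/assumptions:

from azure.core.credentials import AzureKeyCredential
from autogen_ext.models.azure import AzureAIChatCompletionClient

client = AzureAIChatCompletionClient(
    endpoint="https://<your-endpoint>.models.ai.azure.com",  # placeholder
    credential=AzureKeyCredential("<key>"),                  # placeholder
    model="model",                                           # placeholder
    model_info={
        "json_output": True,
        "function_calling": False,
        "vision": False,
        "family": "unknown",
        "structured_output": False,  # new required capability flag
    },
)
# json_output=True maps to response_format="json_object";
# json_output=<BaseModel subclass> raises ValueError for this client.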

View File

@@ -102,7 +102,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
self,
messages: Sequence[LLMMessage],
tools: Sequence[Tool | ToolSchema],
json_output: Optional[bool],
json_output: Optional[bool | type[BaseModel]],
extra_create_args: Mapping[str, Any],
) -> tuple[Optional[Union[CreateResult, List[Union[str, CreateResult]]]], str]:
"""
@@ -110,10 +110,17 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
Returns a tuple of (cached_result, cache_key).
"""
json_output_data: str | bool | None = None
if isinstance(json_output, type) and issubclass(json_output, BaseModel):
json_output_data = json.dumps(json_output.model_json_schema())
elif isinstance(json_output, bool):
json_output_data = json_output
data = {
"messages": [message.model_dump() for message in messages],
"tools": [(tool.schema if isinstance(tool, Tool) else tool) for tool in tools],
"json_output": json_output,
"json_output": json_output_data,
"extra_create_args": extra_create_args,
}
serialized_data = json.dumps(data, sort_keys=True)
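A `BaseModel` subclass is not itself JSON-serializable, so the cache key now stores its JSON schema (booleans are stored as-is), keeping cache hits stable across calls that pass the same model class. The same normalization in isolation, with an illustrative `Answer` schema:

import json

from pydantic import BaseModel


class Answer(BaseModel):  # hypothetical schema for illustration
    text: str


def normalize_json_output(json_output: bool | type[BaseModel] | None) -> str | bool | None:
    # Mirrors the cache-key handling: schemas are compared structurally.
    if isinstance(json_output, type) and issubclass(json_output, BaseModel):
        return json.dumps(json_output.model_json_schema())
    if isinstance(json_output, bool):
        return json_output
    return None


assert normalize_json_output(Answer) == normalize_json_output(Answer)
assert normalize_json_output(True) is True
assert normalize_json_output(None) is None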
@@ -130,7 +137,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
@@ -162,7 +169,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:

View File

@@ -1,3 +1,4 @@
import asyncio
import logging # added import
import re
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, TypedDict, Union, cast
@@ -11,6 +12,7 @@ from autogen_core.models import (
FinishReasons,
FunctionExecutionResultMessage,
LLMMessage,
ModelFamily,
ModelInfo,
RequestUsage,
SystemMessage,
@@ -30,6 +32,7 @@ from llama_cpp import (
Llama,
llama_chat_format,
)
from pydantic import BaseModel
from typing_extensions import Unpack
logger = logging.getLogger(EVENT_LOGGER_NAME) # initialize logger
@@ -172,6 +175,7 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
This client allows you to interact with LlamaCpp models, either by specifying a local model path or by downloading a model from Hugging Face Hub.
Args:
model_info (optional, ModelInfo): The information about the model. Defaults to :attr:`~LlamaCppChatCompletionClient.DEFAULT_MODEL_INFO`.
model_path (optional, str): The path to the LlamaCpp model file. Required if repo_id and filename are not provided.
repo_id (optional, str): The Hugging Face Hub repository ID. Required if model_path is not provided.
filename (optional, str): The filename of the model within the Hugging Face Hub repository. Required if model_path is not provided.
@@ -179,7 +183,6 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
n_ctx (optional, int): The context size.
n_batch (optional, int): The batch size.
verbose (optional, bool): Whether to print verbose output.
model_info (optional, ModelInfo): The capabilities of the model. Defaults to a ModelInfo instance with function_calling set to True.
**kwargs: Additional parameters to pass to the Llama class.
Examples:
@@ -223,6 +226,10 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
asyncio.run(main())
"""
DEFAULT_MODEL_INFO: ModelInfo = ModelInfo(
vision=False, json_output=True, family=ModelFamily.UNKNOWN, function_calling=True, structured_output=True
)
def __init__(
self,
model_info: Optional[ModelInfo] = None,
@@ -234,6 +241,10 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
if model_info:
validate_model_info(model_info)
self._model_info = model_info
else:
# Default model info.
self._model_info = self.DEFAULT_MODEL_INFO
if "repo_id" in kwargs and "filename" in kwargs and kwargs["repo_id"] and kwargs["filename"]:
repo_id: str = cast(str, kwargs.pop("repo_id"))
@@ -255,10 +266,11 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
create_args = dict(extra_create_args)
# Convert LLMMessage objects to dictionaries with 'role' and 'content'
# converted_messages: List[Dict[str, str | Image | list[str | Image] | list[FunctionCall]]] = []
converted_messages: list[
@@ -283,12 +295,28 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
else:
raise ValueError(f"Unsupported message type: {type(msg)}")
if isinstance(json_output, type) and issubclass(json_output, BaseModel):
create_args["response_format"] = {"type": "json_object", "schema": json_output.model_json_schema()}
elif json_output is True:
create_args["response_format"] = {"type": "json_object"}
elif json_output is not False and json_output is not None:
raise ValueError("json_output must be a boolean, a BaseModel subclass or None.")
if self.model_info["function_calling"]:
response = self.llm.create_chat_completion(
messages=converted_messages, tools=convert_tools(tools), stream=False
# Run this in an executor to avoid blocking the event loop.
response_future = asyncio.get_event_loop().run_in_executor(
None,
lambda: self.llm.create_chat_completion(
messages=converted_messages, tools=convert_tools(tools), stream=False, **create_args
),
)
else:
response = self.llm.create_chat_completion(messages=converted_messages, stream=False)
response_future = asyncio.get_event_loop().run_in_executor(
None, lambda: self.llm.create_chat_completion(messages=converted_messages, stream=False, **create_args)
)
if cancellation_token:
cancellation_token.link_future(response_future)
response = await response_future
if not isinstance(response, dict):
raise ValueError("Unexpected response type from LlamaCpp model.")
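`Llama.create_chat_completion` is a blocking call, so the client now runs it in the default executor and links the future to the `cancellation_token`. A stripped-down sketch of the pattern, with a stand-in blocking function in place of a real `Llama` instance:

import asyncio
import time
from typing import Any, Dict

from autogen_core import CancellationToken


def blocking_completion() -> Dict[str, Any]:
    # Stand-in for llm.create_chat_completion(..., stream=False).
    time.sleep(1)
    return {"choices": [{"message": {"content": "{}"}}]}


async def run(cancellation_token: CancellationToken | None = None) -> Dict[str, Any]:
    # Dispatch the blocking call to the default executor so the event loop stays responsive.
    future = asyncio.get_event_loop().run_in_executor(None, blocking_completion)
    if cancellation_token:
        # Lets the caller cancel the pending completion, as the client does.
        cancellation_token.link_future(future)
    return await future


print(asyncio.run(run()))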
@@ -371,7 +399,7 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
tools: Sequence[Tool | ToolSchema] = [],
# None means do not override the default
# A value means to override the client default - often specified in the constructor
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
@@ -403,7 +431,7 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
@property
def model_info(self) -> ModelInfo:
return ModelInfo(vision=False, json_output=False, family="llama-cpp", function_calling=True)
return self._model_info
def remaining_tokens(
self,

View File

@@ -6,73 +6,300 @@ from autogen_core.models import ModelFamily, ModelInfo
# TODO: fix model family?
# TODO: json_output is True for all models because ollama supports structured output via pydantic. How to handle this situation?
_MODEL_INFO: Dict[str, ModelInfo] = {
"all-minilm": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"bge-m3": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"codegemma": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"codellama": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"command-r": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
"deepseek-coder": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"all-minilm": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"bge-m3": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"codegemma": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"codellama": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"command-r": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"deepseek-coder": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"deepseek-coder-v2": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"deepseek-r1": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.R1,
"structured_output": True,
},
"dolphin-llama3": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"dolphin-mistral": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"dolphin-mixtral": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"gemma": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"gemma2": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"llama2": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"deepseek-r1": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.R1},
"dolphin-llama3": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"dolphin-mistral": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"dolphin-mixtral": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"gemma": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"gemma2": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"llama2": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"llama2-uncensored": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"llama3": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"llama3.1": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"llama3.2": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"llama3.2-vision": {
"vision": True,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"llama3.3": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"llava": {
"vision": True,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"llava-llama3": {
"vision": True,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"mistral": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"mistral-nemo": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"mixtral": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"llama3": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
"llama3.1": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
"llama3.2": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
"llama3.2-vision": {"vision": True, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"llama3.3": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"llava": {"vision": True, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"llava-llama3": {"vision": True, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"mistral": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
"mistral-nemo": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
"mixtral": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
"mxbai-embed-large": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"nomic-embed-text": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"orca-mini": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"phi": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"phi3": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"phi3.5": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"phi4": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"qwen": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"qwen2": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"qwen2.5": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"qwen2.5-coder": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"orca-mini": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"phi": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"phi3": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"phi3.5": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"phi4": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"qwen": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"qwen2": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
"qwen2.5": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
"qwen2.5-coder": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
"snowflake-arctic-embed": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"starcoder2": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"tinyllama": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"wizardlm2": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"yi": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"zephyr": {
"vision": False,
"function_calling": False,
"json_output": True,
"family": ModelFamily.UNKNOWN,
"structured_output": True,
},
"starcoder2": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"tinyllama": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"wizardlm2": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"yi": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
"zephyr": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
}
# TODO: the ollama model cards for some of these models were incorrect. I made a best effort to get the actual values, but they aren't guaranteed to be correct.

View File

@@ -5,12 +5,13 @@ import logging
import math
import re
import warnings
from asyncio import Task
from dataclasses import dataclass
from typing import (
Any,
AsyncGenerator,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
@@ -48,6 +49,7 @@ from ollama import Image as OllamaImage
from ollama import Tool as OllamaTool
from ollama._types import ChatRequest
from pydantic import BaseModel
from pydantic.json_schema import JsonSchemaValue
from typing_extensions import Self, Unpack
from . import _model_info
@@ -328,6 +330,7 @@ def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
stop_reason = stop_reason.lower()
KNOWN_STOP_MAPPINGS: Dict[str, FinishReasons] = {
"stop": "stop",
"end_turn": "stop",
"tool_calls": "function_calls",
}
@@ -335,6 +338,14 @@ def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
return KNOWN_STOP_MAPPINGS.get(stop_reason, "unknown")
@dataclass
class CreateParams:
messages: Sequence[Message]
tools: Sequence[OllamaTool]
format: Optional[Union[Literal["", "json"], JsonSchemaValue]]
create_args: Dict[str, Any]
class BaseOllamaChatCompletionClient(ChatCompletionClient):
def __init__(
self,
@@ -390,39 +401,61 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
def get_create_args(self) -> Mapping[str, Any]:
return self._create_args
async def create(
def _process_create_args(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
# Make sure all extra_create_args are valid
# TODO: kwarg checking logic
# extra_create_args_keys = set(extra_create_args.keys())
# if not create_kwargs.issuperset(extra_create_args_keys):
# raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
tools: Sequence[Tool | ToolSchema],
json_output: Optional[bool | type[BaseModel]],
extra_create_args: Mapping[str, Any],
) -> CreateParams:
# Copy the create args and overwrite anything in extra_create_args
create_args = self._create_args.copy()
create_args.update(extra_create_args)
response_format_value: Optional[Mapping[str, Any]] = None
response_format_value: JsonSchemaValue | Literal["json"] | None = None
if "response_format" in create_args:
warnings.warn(
"Using response_format will be deprecated. Use json_output instead.",
DeprecationWarning,
stacklevel=2,
)
value = create_args["response_format"]
# If value is a Pydantic model class, use the beta client
if isinstance(value, type) and issubclass(value, BaseModel):
response_format_value = value.model_json_schema()
# Remove response_format from create_args to prevent passing it twice.
del create_args["response_format"]
else:
# response_format_value is not a Pydantic model class
# TODO: Should this be a warning/error?
response_format_value = None
raise ValueError(f"response_format must be a Pydantic model class, not {type(value)}")
# Remove 'response_format' from create_args to prevent passing it twice
create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}
if json_output is not None:
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.")
if json_output is True:
# JSON mode.
response_format_value = "json"
elif json_output is False:
# Text mode.
response_format_value = None
elif isinstance(json_output, type) and issubclass(json_output, BaseModel):
if response_format_value is not None:
raise ValueError(
"response_format and json_output cannot be set to a Pydantic model class at the same time. "
"Use json_output instead."
)
# Beta client mode with Pydantic model class.
response_format_value = json_output.model_json_schema()
else:
raise ValueError(f"json_output must be a boolean or a Pydantic model class, got {type(json_output)}")
if "format" in create_args:
# Handle the case where format is set from create_args.
if json_output is not None:
raise ValueError("json_output and format cannot be set at the same time. Use json_output instead.")
assert response_format_value is None
response_format_value = create_args["format"]
# Remove format from create_args to prevent passing it twice.
del create_args["format"]
# TODO: allow custom handling.
# For now we raise an error if images are present and vision is not supported
@@ -432,15 +465,6 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
raise ValueError("Model does not support vision and image was provided")
if json_output is not None:
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.")
if json_output is True:
create_args["response_format"] = {"type": "json_object"}
else:
create_args["response_format"] = {"type": "text"}
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.")
@@ -448,30 +472,47 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
ollama_messages = [item for sublist in ollama_messages_nested for item in sublist]
if self.model_info["function_calling"] is False and len(tools) > 0:
raise ValueError("Model does not support function calling")
future: Task[ChatResponse]
if len(tools) > 0:
converted_tools = convert_tools(tools)
future = asyncio.ensure_future(
self._client.chat( # type: ignore
# model=self._model_name,
messages=ollama_messages,
tools=converted_tools,
stream=False,
format=response_format_value,
**create_args_no_response_format,
)
)
else:
future = asyncio.ensure_future(
self._client.chat( # type: ignore
# model=self._model_name,
messages=ollama_messages,
stream=False,
format=response_format_value,
**create_args_no_response_format,
)
raise ValueError("Model does not support function calling and tools were provided")
converted_tools = convert_tools(tools)
return CreateParams(
messages=ollama_messages,
tools=converted_tools,
format=response_format_value,
create_args=create_args,
)
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
# Make sure all extra_create_args are valid
# TODO: kwarg checking logic
# extra_create_args_keys = set(extra_create_args.keys())
# if not create_kwargs.issuperset(extra_create_args_keys):
# raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
create_params = self._process_create_args(
messages,
tools,
json_output,
extra_create_args,
)
future = asyncio.ensure_future(
self._client.chat( # type: ignore
# model=self._model_name,
messages=create_params.messages,
tools=create_params.tools if len(create_params.tools) > 0 else None,
stream=False,
format=create_params.format,
**create_params.create_args,
)
)
if cancellation_token is not None:
cancellation_token.link_future(future)
result: ChatResponse = await future
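`_process_create_args` folds `json_output` into the `format` argument of the Ollama chat call: `True` becomes `"json"`, a `BaseModel` subclass becomes its JSON schema, and the legacy `response_format`/`format` keys now warn or error. A usage sketch, assuming a local Ollama server; the model name and `CityInfo` schema are illustrative:

import asyncio

from pydantic import BaseModel
from autogen_core.models import UserMessage
from autogen_ext.models.ollama import OllamaChatCompletionClient


class CityInfo(BaseModel):  # hypothetical schema for illustration
    name: str
    population: int


async def main() -> None:
    client = OllamaChatCompletionClient(model="llama3.2")
    # Sent to the server as format=CityInfo.model_json_schema().
    result = await client.create(
        [UserMessage(content="Describe Tokyo as JSON.", source="user")],
        json_output=CityInfo,
    )
    assert isinstance(result.content, str)
    print(CityInfo.model_validate_json(result.content))


asyncio.run(main())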
@@ -484,7 +525,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
logger.info(
LLMCallEvent(
messages=[m.model_dump() for m in ollama_messages],
messages=[m.model_dump() for m in create_params.messages],
response=result.model_dump(),
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
@@ -564,109 +605,31 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
max_consecutive_empty_chunk_tolerance: int = 0,
) -> AsyncGenerator[Union[str, CreateResult], None]:
"""
Creates an AsyncGenerator that will yield a stream of chat completions based on the provided messages and tools.
Args:
messages (Sequence[LLMMessage]): A sequence of messages to be processed.
tools (Sequence[Tool | ToolSchema], optional): A sequence of tools to be used in the completion. Defaults to `[]`.
json_output (Optional[bool], optional): If True, the output will be in JSON format. Defaults to None.
extra_create_args (Mapping[str, Any], optional): Additional arguments for the creation process. Defaults to `{}`.
cancellation_token (Optional[CancellationToken], optional): A token to cancel the operation. Defaults to None.
max_consecutive_empty_chunk_tolerance (int): The maximum number of consecutive empty chunks to tolerate before raising a ValueError. This only seems to be needed when using `AzureOpenAIChatCompletionClient`. Defaults to 0.
Yields:
AsyncGenerator[Union[str, CreateResult], None]: A generator yielding the completion results as they are produced.
In streaming, the default behaviour is to not return token usage counts. See: [OpenAI API reference for possible args](https://platform.openai.com/docs/api-reference/chat/create).
However `extra_create_args={"stream_options": {"include_usage": True}}` will (if supported by the accessed API)
return a final chunk with usage set to a RequestUsage object having prompt and completion token counts,
all preceding chunks will have usage as None. See: [stream_options](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options).
Other examples of OPENAI supported arguments that can be included in `extra_create_args`:
- `temperature` (float): Controls the randomness of the output. Higher values (e.g., 0.8) make the output more random, while lower values (e.g., 0.2) make it more focused and deterministic.
- `max_tokens` (int): The maximum number of tokens to generate in the completion.
- `top_p` (float): An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
- `frequency_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on their existing frequency in the text so far, decreasing the likelihood of repeated phrases.
- `presence_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on whether they appear in the text so far, encouraging the model to talk about new topics.
"""
# Make sure all extra_create_args are valid
# TODO: kwarg checking logic
# extra_create_args_keys = set(extra_create_args.keys())
# if not create_kwargs.issuperset(extra_create_args_keys):
# raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
# Copy the create args and overwrite anything in extra_create_args
create_args = self._create_args.copy()
create_args.update(extra_create_args)
response_format_value: Optional[Mapping[str, Any]] = None
if "response_format" in create_args:
value = create_args["response_format"]
# If value is a Pydantic model class, use the beta client
if isinstance(value, type) and issubclass(value, BaseModel):
response_format_value = value.model_json_schema()
else:
# response_format_value is not a Pydantic model class
response_format_value = None
# Remove 'response_format' from create_args to prevent passing it twice
create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}
# TODO: allow custom handling.
# For now we raise an error if images are present and vision is not supported
if self.model_info["vision"] is False:
for message in messages:
if isinstance(message, UserMessage):
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
raise ValueError("Model does not support vision and image was provided")
if json_output is not None:
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.")
if json_output is True:
create_args["response_format"] = {"type": "json_object"}
else:
create_args["response_format"] = {"type": "text"}
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.")
ollama_messages_nested = [to_ollama_type(m) for m in messages]
ollama_messages = [item for sublist in ollama_messages_nested for item in sublist]
if self.model_info["function_calling"] is False and len(tools) > 0:
raise ValueError("Model does not support function calling")
if len(tools) > 0:
converted_tools = convert_tools(tools)
stream_future = asyncio.ensure_future(
self._client.chat( # type: ignore
# model=self._model_name,
messages=ollama_messages,
tools=converted_tools,
stream=True,
format=response_format_value,
**create_args_no_response_format,
)
)
else:
stream_future = asyncio.ensure_future(
self._client.chat( # type: ignore
# model=self._model_name,
messages=ollama_messages,
stream=True,
format=response_format_value,
**create_args_no_response_format,
)
create_params = self._process_create_args(
messages,
tools,
json_output,
extra_create_args,
)
stream_future = asyncio.ensure_future(
self._client.chat( # type: ignore
# model=self._model_name,
messages=create_params.messages,
tools=create_params.tools if len(create_params.tools) > 0 else None,
stream=True,
format=create_params.format,
**create_params.create_args,
)
)
if cancellation_token is not None:
cancellation_token.link_future(stream_future)
stream = await stream_future
@@ -689,7 +652,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
# Emit the start event.
logger.info(
LLMStreamStartEvent(
messages=[m.model_dump() for m in ollama_messages],
messages=[m.model_dump() for m in create_params.messages],
)
)
# set the stop_reason for the usage chunk to the prior stop_reason

View File

@@ -25,144 +25,168 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
"function_calling": True,
"json_output": True,
"family": ModelFamily.O3,
"structured_output": True,
},
"o1-2024-12-17": {
"vision": False,
"function_calling": False,
"json_output": False,
"family": ModelFamily.O1,
"structured_output": True,
},
"o1-preview-2024-09-12": {
"vision": False,
"function_calling": False,
"json_output": False,
"family": ModelFamily.O1,
"structured_output": True,
},
"o1-mini-2024-09-12": {
"vision": False,
"function_calling": False,
"json_output": False,
"family": ModelFamily.O1,
"structured_output": False,
},
"gpt-4o-2024-11-20": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_4O,
"structured_output": True,
},
"gpt-4o-2024-08-06": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_4O,
"structured_output": True,
},
"gpt-4o-2024-05-13": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_4O,
"structured_output": False,
},
"gpt-4o-mini-2024-07-18": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_4O,
"structured_output": True,
},
"gpt-4-turbo-2024-04-09": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_4,
"structured_output": False,
},
"gpt-4-0125-preview": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_4,
"structured_output": False,
},
"gpt-4-1106-preview": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_4,
"structured_output": False,
},
"gpt-4-1106-vision-preview": {
"vision": True,
"function_calling": False,
"json_output": False,
"family": ModelFamily.GPT_4,
"structured_output": False,
},
"gpt-4-0613": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_4,
"structured_output": False,
},
"gpt-4-32k-0613": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_4,
"structured_output": False,
},
"gpt-3.5-turbo-0125": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_35,
"structured_output": False,
},
"gpt-3.5-turbo-1106": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_35,
"structured_output": False,
},
"gpt-3.5-turbo-instruct": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_35,
"structured_output": False,
},
"gpt-3.5-turbo-0613": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_35,
"structured_output": False,
},
"gpt-3.5-turbo-16k-0613": {
"vision": False,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_35,
"structured_output": False,
},
"gemini-1.5-flash": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GEMINI_1_5_FLASH,
"structured_output": True,
},
"gemini-1.5-flash-8b": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GEMINI_1_5_FLASH,
"structured_output": True,
},
"gemini-1.5-pro": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GEMINI_1_5_PRO,
"structured_output": True,
},
"gemini-2.0-flash": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GEMINI_2_0_FLASH,
"structured_output": True,
},
"gemini-2.0-flash-lite-preview-02-05": {
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GEMINI_2_0_FLASH,
"structured_output": True,
},
}

View File

@@ -7,6 +7,7 @@ import os
import re
import warnings
from asyncio import Task
from dataclasses import dataclass
from typing import (
Any,
AsyncGenerator,
@@ -69,7 +70,12 @@ from openai.types.chat import (
)
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.shared_params import FunctionDefinition, FunctionParameters
from openai.types.shared_params import (
FunctionDefinition,
FunctionParameters,
ResponseFormatJSONObject,
ResponseFormatText,
)
from pydantic import BaseModel, SecretStr
from typing_extensions import Self, Unpack
@@ -348,6 +354,14 @@ def assert_valid_name(name: str) -> str:
return name
@dataclass
class CreateParams:
messages: List[ChatCompletionMessageParam]
tools: List[ChatCompletionToolParam]
response_format: Optional[Type[BaseModel]]
create_args: Dict[str, Any]
class BaseOpenAIChatCompletionClient(ChatCompletionClient):
def __init__(
self,
@@ -400,15 +414,13 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient:
return OpenAIChatCompletionClient(**config)
async def create(
def _process_create_args(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
tools: Sequence[Tool | ToolSchema],
json_output: Optional[bool | type[BaseModel]],
extra_create_args: Mapping[str, Any],
) -> CreateParams:
# Make sure all extra_create_args are valid
extra_create_args_keys = set(extra_create_args.keys())
if not create_kwargs.issuperset(extra_create_args_keys):
@@ -418,23 +430,56 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
create_args = self._create_args.copy()
create_args.update(extra_create_args)
# Declare use_beta_client
use_beta_client: bool = False
# The response format value to use for the beta client.
response_format_value: Optional[Type[BaseModel]] = None
if "response_format" in create_args:
# Legacy support for getting beta client mode from response_format.
value = create_args["response_format"]
# If value is a Pydantic model class, use the beta client
if isinstance(value, type) and issubclass(value, BaseModel):
if self.model_info["structured_output"] is False:
raise ValueError("Model does not support structured output.")
warnings.warn(
"Using response_format to specify structured output type will be deprecated. "
"Use json_output instead.",
DeprecationWarning,
stacklevel=2,
)
response_format_value = value
use_beta_client = True
else:
# response_format_value is not a Pydantic model class
use_beta_client = False
response_format_value = None
# Remove response_format from create_args to prevent passing it twice.
del create_args["response_format"]
# Remove 'response_format' from create_args to prevent passing it twice
create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}
if json_output is not None:
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.")
if json_output is True:
# JSON mode.
create_args["response_format"] = ResponseFormatJSONObject(type="json_object")
elif json_output is False:
# Text mode.
create_args["response_format"] = ResponseFormatText(type="text")
elif isinstance(json_output, type) and issubclass(json_output, BaseModel):
if self.model_info["structured_output"] is False:
raise ValueError("Model does not support structured output.")
if response_format_value is not None:
raise ValueError(
"response_format and json_output cannot be set to a Pydantic model class at the same time."
)
# Beta client mode with Pydantic model class.
response_format_value = json_output
else:
raise ValueError(f"json_output must be a boolean or a Pydantic model class, got {type(json_output)}")
if response_format_value is not None and "response_format" in create_args:
warnings.warn(
"response_format is found in extra_create_args while json_output is set to a Pydantic model class. "
"Skipping the response_format in extra_create_args in favor of the json_output. "
"Structured output will be used.",
UserWarning,
stacklevel=2,
)
# If using beta client, remove response_format from create_args to prevent passing it twice
del create_args["response_format"]
# TODO: allow custom handling.
# For now we raise an error if images are present and vision is not supported
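The OpenAI client keeps the legacy path of passing a Pydantic class through `response_format` in the create args, but it now emits a `DeprecationWarning` and `json_output` is the preferred spelling; both routes go through `beta.chat.completions.parse`. Sketch of the two spellings, assuming an API key in the environment; the model name and `AgentResponse` schema are illustrative:

import asyncio

from pydantic import BaseModel
from autogen_core.models import UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):  # hypothetical schema for illustration
    thoughts: str
    response: str


async def main() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")
    message = UserMessage(content="I'm happy today!", source="user")

    # Preferred: structured output requested through json_output.
    await client.create([message], json_output=AgentResponse)

    # Legacy: response_format in the create args still works but emits a
    # DeprecationWarning, and cannot be combined with a Pydantic json_output.
    await client.create([message], extra_create_args={"response_format": AgentResponse})


asyncio.run(main())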
@@ -444,15 +489,6 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
raise ValueError("Model does not support vision and image was provided")
if json_output is not None:
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.")
if json_output is True:
create_args["response_format"] = {"type": "json_object"}
else:
create_args["response_format"] = {"type": "text"}
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output.")
@@ -461,67 +497,57 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
if self.model_info["function_calling"] is False and len(tools) > 0:
raise ValueError("Model does not support function calling")
converted_tools = convert_tools(tools)
return CreateParams(
messages=oai_messages,
tools=converted_tools,
response_format=response_format_value,
create_args=create_args,
)
async def create(
self,
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
create_params = self._process_create_args(
messages,
tools,
json_output,
extra_create_args,
)
future: Union[Task[ParsedChatCompletion[BaseModel]], Task[ChatCompletion]]
if len(tools) > 0:
converted_tools = convert_tools(tools)
if use_beta_client:
# Pass response_format_value if it's not None
if response_format_value is not None:
future = asyncio.ensure_future(
self._client.beta.chat.completions.parse(
messages=oai_messages,
tools=converted_tools,
response_format=response_format_value,
**create_args_no_response_format,
)
)
else:
future = asyncio.ensure_future(
self._client.beta.chat.completions.parse(
messages=oai_messages,
tools=converted_tools,
**create_args_no_response_format,
)
)
else:
future = asyncio.ensure_future(
self._client.chat.completions.create(
messages=oai_messages,
stream=False,
tools=converted_tools,
**create_args,
)
if create_params.response_format is not None:
# Use beta client if response_format is not None
future = asyncio.ensure_future(
self._client.beta.chat.completions.parse(
messages=create_params.messages,
tools=create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN,
response_format=create_params.response_format,
**create_params.create_args,
)
)
else:
if use_beta_client:
if response_format_value is not None:
future = asyncio.ensure_future(
self._client.beta.chat.completions.parse(
messages=oai_messages,
response_format=response_format_value,
**create_args_no_response_format,
)
)
else:
future = asyncio.ensure_future(
self._client.beta.chat.completions.parse(
messages=oai_messages,
**create_args_no_response_format,
)
)
else:
future = asyncio.ensure_future(
self._client.chat.completions.create(
messages=oai_messages,
stream=False,
**create_args,
)
# Use the regular client
future = asyncio.ensure_future(
self._client.chat.completions.create(
messages=create_params.messages,
stream=False,
tools=create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN,
**create_params.create_args,
)
)
if cancellation_token is not None:
cancellation_token.link_future(future)
result: Union[ParsedChatCompletion[BaseModel], ChatCompletion] = await future
if use_beta_client:
if create_params.response_format is not None:
result = cast(ParsedChatCompletion[Any], result)
usage = RequestUsage(
@@ -532,7 +558,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
logger.info(
LLMCallEvent(
messages=cast(List[Dict[str, Any]], oai_messages),
messages=cast(List[Dict[str, Any]], create_params.messages),
response=result.model_dump(),
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
@@ -627,7 +653,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
max_consecutive_empty_chunk_tolerance: int = 0,
@@ -638,7 +664,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
Args:
messages (Sequence[LLMMessage]): A sequence of messages to be processed.
tools (Sequence[Tool | ToolSchema], optional): A sequence of tools to be used in the completion. Defaults to `[]`.
json_output (Optional[bool], optional): If True, the output will be in JSON format. Defaults to None.
json_output (Optional[bool | type[BaseModel]], optional): If True, the output will be in JSON format. If a Pydantic model class, the output will be in that format. Defaults to None.
extra_create_args (Mapping[str, Any], optional): Additional arguments for the creation process. Defaults to `{}`.
cancellation_token (Optional[CancellationToken], optional): A token to cancel the operation. Defaults to None.
max_consecutive_empty_chunk_tolerance (int): [Deprecated] This parameter is deprecated, empty chunks will be skipped.
@@ -663,55 +689,12 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
- `frequency_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on their existing frequency in the text so far, decreasing the likelihood of repeated phrases.
- `presence_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on whether they appear in the text so far, encouraging the model to talk about new topics.
"""
# Make sure all extra_create_args are valid
extra_create_args_keys = set(extra_create_args.keys())
if not create_kwargs.issuperset(extra_create_args_keys):
raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
# Copy the create args and overwrite anything in extra_create_args
create_args = self._create_args.copy()
create_args.update(extra_create_args)
# Declare use_beta_client
use_beta_client: bool = False
response_format_value: Optional[Type[BaseModel]] = None
if "response_format" in create_args:
value = create_args["response_format"]
# If value is a Pydantic model class, use the beta client
if isinstance(value, type) and issubclass(value, BaseModel):
response_format_value = value
use_beta_client = True
else:
# response_format_value is not a Pydantic model class
use_beta_client = False
response_format_value = None
# Remove 'response_format' from create_args to prevent passing it twice
create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}
# TODO: allow custom handling.
# For now we raise an error if images are present and vision is not supported
if self.model_info["vision"] is False:
for message in messages:
if isinstance(message, UserMessage):
if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
raise ValueError("Model does not support vision and image was provided")
if json_output is not None:
if self.model_info["json_output"] is False and json_output is True:
raise ValueError("Model does not support JSON output")
if json_output is True:
create_args["response_format"] = {"type": "json_object"}
else:
create_args["response_format"] = {"type": "text"}
oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
oai_messages = [item for sublist in oai_messages_nested for item in sublist]
if self.model_info["function_calling"] is False and len(tools) > 0:
raise ValueError("Model does not support function calling")
create_params = self._process_create_args(
messages,
tools,
json_output,
extra_create_args,
)
if max_consecutive_empty_chunk_tolerance != 0:
warnings.warn(
@@ -720,22 +703,19 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
stacklevel=2,
)
tool_params = convert_tools(tools)
# Get the async generator of chunks.
if use_beta_client:
if create_params.response_format is not None:
chunks = self._create_stream_chunks_beta_client(
tool_params=tool_params,
oai_messages=oai_messages,
response_format=response_format_value,
create_args_no_response_format=create_args_no_response_format,
tool_params=create_params.tools,
oai_messages=create_params.messages,
response_format=create_params.response_format,
create_args_no_response_format=create_params.create_args,
cancellation_token=cancellation_token,
)
else:
chunks = self._create_stream_chunks(
tool_params=tool_params,
oai_messages=oai_messages,
create_args=create_args,
tool_params=create_params.tools,
oai_messages=create_params.messages,
create_args=create_params.create_args,
cancellation_token=cancellation_token,
)
@@ -762,7 +742,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
# Emit the start event.
logger.info(
LLMStreamStartEvent(
messages=cast(List[Dict[str, Any]], oai_messages),
messages=cast(List[Dict[str, Any]], create_params.messages),
)
)
# Empty chunks has been observed when the endpoint is under heavy load.
@@ -839,7 +819,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
raise ValueError("Function calls are not supported in this context")
# We need to get the model from the last chunk, if available.
model = maybe_model or create_args["model"]
model = maybe_model or create_params.create_args["model"]
model = model.replace("gpt-35", "gpt-3.5") # hack for Azure API
# Because the usage chunk is not guaranteed to be the last chunk, we need to check if it is available.
@@ -1151,6 +1131,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
"function_calling": False,
"json_output": False,
"family": ModelFamily.R1,
"structured_output": True,
},
)
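`create_stream` accepts the same `json_output` values; when a Pydantic class is given, chunks are produced by the beta streaming path and the final `CreateResult` carries schema-conforming JSON. A streaming sketch, again with an illustrative model name and schema:

import asyncio

from pydantic import BaseModel
from autogen_core.models import CreateResult, UserMessage
from autogen_ext.models.openai import OpenAIChatCompletionClient


class AgentResponse(BaseModel):  # hypothetical schema for illustration
    thoughts: str
    response: str


async def main() -> None:
    client = OpenAIChatCompletionClient(model="gpt-4o-2024-08-06")
    stream = client.create_stream(
        [UserMessage(content="I'm happy today!", source="user")],
        json_output=AgentResponse,
    )
    async for item in stream:
        if isinstance(item, CreateResult):
            print("final:", item.content, item.usage)
        else:
            print("chunk:", item)


asyncio.run(main())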

View File

@@ -139,7 +139,11 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
validate_model_info(self._model_info)
else:
self._model_info = ModelInfo(
vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
vision=False,
function_calling=False,
json_output=False,
family=ModelFamily.UNKNOWN,
structured_output=False,
)
self._total_available_tokens = 10000
self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
@@ -158,7 +162,7 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
@@ -197,7 +201,7 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:

View File

@@ -135,6 +135,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
"json_output": True,
"vision": True,
"family": ModelFamily.CLAUDE_3_5_SONNET,
"structured_output": True,
},
)
@@ -201,6 +202,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
"function_calling": True,
"json_output": True,
"vision": True,
"structured_output": True,
},
)
@@ -283,7 +285,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
self._prompt_settings = prompt_settings
self._sk_client = sk_client
self._model_info = model_info or ModelInfo(
vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN, structured_output=False
)
validate_model_info(self._model_info)
self._total_prompt_tokens = 0
@@ -437,7 +439,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> CreateResult:
@@ -465,6 +467,9 @@ class SKChatCompletionAdapter(ChatCompletionClient):
Returns:
CreateResult: The result of the chat completion.
"""
if isinstance(json_output, type) and issubclass(json_output, BaseModel):
raise ValueError("structured output is not currently supported in SKChatCompletionAdapter")
kernel = self._get_kernel(extra_create_args)
chat_history = self._convert_to_chat_history(messages)
@@ -545,7 +550,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
messages: Sequence[LLMMessage],
*,
tools: Sequence[Tool | ToolSchema] = [],
json_output: Optional[bool] = None,
json_output: Optional[bool | type[BaseModel]] = None,
extra_create_args: Mapping[str, Any] = {},
cancellation_token: Optional[CancellationToken] = None,
) -> AsyncGenerator[Union[str, CreateResult], None]:
@@ -574,6 +579,9 @@ class SKChatCompletionAdapter(ChatCompletionClient):
Union[str, CreateResult]: Either a string chunk of the response or a CreateResult containing function calls.
"""
if isinstance(json_output, type) and issubclass(json_output, BaseModel):
raise ValueError("structured output is not currently supported in SKChatCompletionAdapter")
kernel = self._get_kernel(extra_create_args)
chat_history = self._convert_to_chat_history(messages)
user_settings = self._get_prompt_settings(extra_create_args)

View File

@@ -88,6 +88,7 @@ def azure_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient
"function_calling": False,
"vision": False,
"family": "unknown",
"structured_output": False,
},
model="model",
)
@@ -101,6 +102,7 @@ def azure_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient
"function_calling": False,
"vision": False,
"family": "unknown",
"structured_output": False,
},
model="model",
)
@@ -117,6 +119,7 @@ async def test_azure_ai_chat_completion_client_validation() -> None:
"function_calling": False,
"vision": False,
"family": "unknown",
"structured_output": False,
},
)
@@ -129,6 +132,7 @@ async def test_azure_ai_chat_completion_client_validation() -> None:
"function_calling": False,
"vision": False,
"family": "unknown",
"structured_output": False,
},
)
@@ -141,6 +145,7 @@ async def test_azure_ai_chat_completion_client_validation() -> None:
"function_calling": False,
"vision": False,
"family": "unknown",
"structured_output": False,
},
)
@@ -267,6 +272,7 @@ def function_calling_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompl
"function_calling": True,
"vision": False,
"family": "function_calling_model",
"structured_output": False,
},
model="model",
)
@@ -354,6 +360,7 @@ async def test_multimodal_supported(monkeypatch: pytest.MonkeyPatch) -> None:
"function_calling": False,
"vision": True,
"family": "vision_model",
"structured_output": False,
},
model="model",
)
@@ -436,6 +443,7 @@ async def test_r1_content(monkeypatch: pytest.MonkeyPatch) -> None:
"function_calling": False,
"vision": True,
"family": ModelFamily.R1,
"structured_output": False,
},
model="model",
)
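Note: these fixture edits only add the new structured_output capability flag. A hedged sketch of the shape a ModelInfo declaration now takes, with illustrative values:

from autogen_core.models import ModelFamily, ModelInfo

# Illustrative values; structured_output should be True only for models that can
# produce schema-constrained output.
info = ModelInfo(
    vision=False,
    function_calling=True,
    json_output=True,
    family=ModelFamily.UNKNOWN,
    structured_output=False,
)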

View File

@@ -11,10 +11,12 @@ from autogen_core.models import (
)
from autogen_ext.models.cache import ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel
def get_test_data() -> Tuple[list[str], list[str], SystemMessage, ChatCompletionClient, ChatCompletionCache]:
num_messages = 3
def get_test_data(
num_messages: int = 3,
) -> Tuple[list[str], list[str], SystemMessage, ChatCompletionClient, ChatCompletionCache]:
responses = [f"This is dummy message number {i}" for i in range(num_messages)]
prompts = [f"This is dummy prompt number {i}" for i in range(num_messages)]
system_prompt = SystemMessage(content="This is a system prompt")
@@ -53,6 +55,54 @@ async def test_cache_basic_with_args() -> None:
assert response2.content == responses[2]
@pytest.mark.asyncio
async def test_cache_structured_output_with_args() -> None:
responses, prompts, system_prompt, _, cached_client = get_test_data(num_messages=4)
class Answer(BaseModel):
thought: str
answer: str
class Answer2(BaseModel):
calculation: str
answer: str
response0 = await cached_client.create(
[system_prompt, UserMessage(content=prompts[0], source="user")], json_output=Answer
)
assert isinstance(response0, CreateResult)
assert not response0.cached
assert response0.content == responses[0]
response1 = await cached_client.create(
[system_prompt, UserMessage(content=prompts[1], source="user")], json_output=Answer
)
assert not response1.cached
assert response1.content == responses[1]
# Cached output.
response0_cached = await cached_client.create(
[system_prompt, UserMessage(content=prompts[0], source="user")], json_output=Answer
)
assert isinstance(response0_cached, CreateResult)
assert response0_cached.cached
assert response0_cached.content == responses[0]
# Without the json_output argument, the cache should not be hit.
response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
assert isinstance(response0, CreateResult)
assert not response0.cached
assert response0.content == responses[2]
# With a different output type, the cache should not be hit.
response0 = await cached_client.create(
[system_prompt, UserMessage(content=prompts[1], source="user")], json_output=Answer2
)
assert isinstance(response0, CreateResult)
assert not response0.cached
assert response0.content == responses[3]
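Note: the test above pins down that cache entries are keyed on the requested output type as well as the prompt. A hedged sketch of one way such a key can be formed (cache_key is a hypothetical helper, not the ChatCompletionCache implementation):

import hashlib
import json
from typing import Optional, Type

from pydantic import BaseModel


def cache_key(prompt: str, json_output: Optional[bool | Type[BaseModel]] = None) -> str:
    # Distinguish entries by output type so Answer and Answer2 do not collide and a
    # plain call does not reuse a structured entry.
    if isinstance(json_output, type) and issubclass(json_output, BaseModel):
        marker = json.dumps(json_output.model_json_schema(), sort_keys=True)
    else:
        marker = repr(json_output)  # None, True, or False
    return hashlib.sha256(f"{prompt}|{marker}".encode()).hexdigest()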
@pytest.mark.asyncio
async def test_cache_model_and_count_api() -> None:
_, prompts, system_prompt, replay_client, cached_client = get_test_data()

View File

@@ -9,6 +9,8 @@ import torch
# from autogen_agentchat.messages import TextMessage
# from autogen_core import CancellationToken
from autogen_core.models import RequestUsage, SystemMessage, UserMessage
from llama_cpp import ChatCompletionRequestResponseFormat
from pydantic import BaseModel
# from autogen_core.tools import FunctionTool
try:
@@ -21,6 +23,13 @@ except ImportError:
pytest.skip("Skipping LlamaCppChatCompletionClient tests: llama-cpp-python not installed", allow_module_level=True)
class AgentResponse(BaseModel):
"""A response from the agent."""
thoughts: str
content: str
# Fake Llama class to simulate responses
class FakeLlama:
def __init__(
@@ -30,16 +39,29 @@ class FakeLlama:
) -> None:
self.model_path = model_path
self.n_ctx = lambda: 1024
self._structured_response = AgentResponse(thoughts="Test thoughts", content="Test content")
# Added tokenize method for testing purposes.
def tokenize(self, b: bytes) -> list[int]:
return list(b)
def create_chat_completion(
self, messages: Any, tools: List[ChatCompletionMessageToolCalls] | None, stream: bool = False
self,
messages: Any,
tools: List[ChatCompletionMessageToolCalls] | None,
stream: bool = False,
response_format: ChatCompletionRequestResponseFormat | None = None,
) -> dict[str, Any]:
# Return a fake non-streaming response.
if response_format is not None:
assert self._structured_response is not None
# If response_format is provided, return a different format.
return {
"usage": {"prompt_tokens": 1, "completion_tokens": 2},
"choices": [{"message": {"content": self._structured_response.model_dump_json()}}],
}
return {
"usage": {"prompt_tokens": 1, "completion_tokens": 2},
"choices": [{"message": {"content": "Fake response"}}],
@@ -81,6 +103,22 @@ async def test_llama_cpp_create(get_completion_client: "ContextManager[type[Llam
assert result.finish_reason in ("stop", "unknown")
@pytest.mark.asyncio
async def test_llama_cpp_create_structured_output(
get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]",
) -> None:
with get_completion_client as Client:
client = Client(model_path="dummy")
messages: Sequence[Union[SystemMessage, UserMessage]] = [
SystemMessage(content="Test system"),
UserMessage(content="Test user", source="user"),
]
result = await client.create(messages=messages, json_output=AgentResponse)
assert isinstance(result.content, str)
assert AgentResponse.model_validate_json(result.content).thoughts == "Test thoughts"
assert AgentResponse.model_validate_json(result.content).content == "Test content"
# Commented out because it raises NotImplementedError; kept in case streaming is supported in the future.
# @pytest.mark.asyncio
# async def test_llama_cpp_create_stream(
@@ -133,7 +171,12 @@ async def test_llama_cpp_integration_non_streaming() -> None:
from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient
client = LlamaCppChatCompletionClient(
repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000
repo_id="unsloth/phi-4-GGUF",
filename="phi-4-Q2_K_L.gguf",
n_gpu_layers=-1,
seed=1337,
n_ctx=5000,
verbose=False,
)
messages: Sequence[Union[SystemMessage, UserMessage]] = [
SystemMessage(content="You are a helpful assistant."),
@@ -143,6 +186,30 @@ async def test_llama_cpp_integration_non_streaming() -> None:
assert isinstance(result.content, str) and len(result.content.strip()) > 0
@pytest.mark.asyncio
async def test_llama_cpp_integration_non_streaming_structured_output() -> None:
if not ((hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) or torch.cuda.is_available()):
pytest.skip("Skipping LlamaCpp integration tests: GPU not available not set")
from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient
client = LlamaCppChatCompletionClient(
repo_id="unsloth/phi-4-GGUF",
filename="phi-4-Q2_K_L.gguf",
n_gpu_layers=-1,
seed=1337,
n_ctx=5000,
verbose=False,
)
messages: Sequence[Union[SystemMessage, UserMessage]] = [
SystemMessage(content="You are a helpful assistant."),
UserMessage(content="Hello, how are you?", source="user"),
]
result = await client.create(messages=messages, json_output=AgentResponse)
assert isinstance(result.content, str) and len(result.content.strip()) > 0
assert AgentResponse.model_validate_json(result.content)
# Commented out because it raises NotImplementedError; kept in case streaming is supported in the future.
# @pytest.mark.asyncio
# async def test_llama_cpp_integration_streaming() -> None:
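Note: FakeLlama above asserts that a response_format is forwarded when json_output is a Pydantic class. A hedged sketch of that translation, assuming llama-cpp-python accepts a JSON schema under the "schema" key of response_format (to_response_format is a hypothetical helper, not the client internals):

from typing import Any, Optional, Type

from pydantic import BaseModel


def to_response_format(json_output: Optional[bool | Type[BaseModel]]) -> Optional[dict[str, Any]]:
    # None/False: unconstrained; True: plain JSON mode; Pydantic class: constrain
    # generation to the model's JSON schema.
    if json_output is None or json_output is False:
        return None
    if json_output is True:
        return {"type": "json_object"}
    return {"type": "json_object", "schema": json_output.model_json_schema()}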

View File

@@ -1,11 +1,24 @@
from typing import Any, Mapping
import json
import logging
from typing import Any, AsyncGenerator, List, Mapping
import httpx
import pytest
from autogen_core.models._types import UserMessage
import pytest_asyncio
from autogen_core import FunctionCall
from autogen_core.models import (
AssistantMessage,
CreateResult,
FunctionExecutionResult,
FunctionExecutionResultMessage,
UserMessage,
)
from autogen_core.tools import FunctionTool
from autogen_ext.models.ollama import OllamaChatCompletionClient
from autogen_ext.models.ollama._ollama_client import OLLAMA_VALID_CREATE_KWARGS_KEYS
from httpx import Response
from ollama import AsyncClient
from ollama import AsyncClient, ChatResponse, Message
from pydantic import BaseModel
def _mock_request(*args: Any, **kwargs: Any) -> Response:
@@ -47,3 +60,544 @@ def test_create_args_from_config_drops_unexpected_kwargs() -> None:
for arg in final_create_args.keys():
assert arg in OLLAMA_VALID_CREATE_KWARGS_KEYS
@pytest.mark.asyncio
async def test_create(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
model = "llama3.2"
content_raw = "Hello world! This is a test response. Test response."
async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
return ChatResponse(
model=model,
done=True,
done_reason="stop",
message=Message(
role="assistant",
content=content_raw,
),
prompt_eval_count=10,
eval_count=12,
)
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
with caplog.at_level(logging.INFO):
client = OllamaChatCompletionClient(model=model)
create_result = await client.create(
messages=[
UserMessage(content="hi", source="user"),
],
)
assert "LLMCall" in caplog.text and content_raw in caplog.text
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0
assert create_result.finish_reason == "stop"
assert create_result.usage is not None
assert create_result.usage.prompt_tokens == 10
assert create_result.usage.completion_tokens == 12
assert create_result.content == content_raw
@pytest.mark.asyncio
async def test_create_stream(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
model = "llama3.2"
content_raw = "Hello world! This is a test response. Test response."
async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
assert "stream" in kwargs
assert kwargs["stream"] is True
async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
chunks = [content_raw[i : i + 5] for i in range(0, len(content_raw), 5)]
# Simulate streaming by yielding chunks of the response
for chunk in chunks[:-1]:
yield ChatResponse(
model=model,
done=False,
message=Message(
role="assistant",
content=chunk,
),
)
yield ChatResponse(
model=model,
done=True,
done_reason="stop",
message=Message(
role="assistant",
content=chunks[-1],
),
prompt_eval_count=10,
eval_count=12,
)
return _mock_stream()
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
client = OllamaChatCompletionClient(model=model)
with caplog.at_level(logging.INFO):
stream = client.create_stream(
messages=[
UserMessage(content="hi", source="user"),
],
)
chunks: List[str | CreateResult] = []
async for chunk in stream:
chunks.append(chunk)
assert "LLMStreamStart" in caplog.text and "hi" in caplog.text
assert "LLMStreamEnd" in caplog.text and content_raw in caplog.text
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, str)
assert chunks[-1].content == content_raw
assert chunks[-1].finish_reason == "stop"
assert chunks[-1].usage is not None
assert chunks[-1].usage.prompt_tokens == 10
assert chunks[-1].usage.completion_tokens == 12
@pytest.mark.asyncio
async def test_create_tools(monkeypatch: pytest.MonkeyPatch) -> None:
def add(x: int, y: int) -> str:
return str(x + y)
add_tool = FunctionTool(add, description="Add two numbers")
model = "llama3.2"
async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
return ChatResponse(
model=model,
done=True,
done_reason="stop",
message=Message(
role="assistant",
tool_calls=[
Message.ToolCall(
function=Message.ToolCall.Function(
name=add_tool.name,
arguments={"x": 2, "y": 2},
),
),
],
),
prompt_eval_count=10,
eval_count=12,
)
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
client = OllamaChatCompletionClient(model=model)
create_result = await client.create(
messages=[
UserMessage(content="hi", source="user"),
],
tools=[add_tool],
)
assert isinstance(create_result.content, list)
assert len(create_result.content) > 0
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == add_tool.name
assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
assert create_result.finish_reason == "function_calls"
assert create_result.usage is not None
assert create_result.usage.prompt_tokens == 10
assert create_result.usage.completion_tokens == 12
@pytest.mark.asyncio
async def test_create_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
class ResponseType(BaseModel):
response: str
model = "llama3.2"
async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
return ChatResponse(
model=model,
done=True,
done_reason="stop",
message=Message(
role="assistant",
content=json.dumps({"response": "Hello world!"}),
),
prompt_eval_count=10,
eval_count=12,
)
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
client = OllamaChatCompletionClient(model=model)
create_result = await client.create(
messages=[
UserMessage(content="hi", source="user"),
],
json_output=ResponseType,
)
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0
assert create_result.finish_reason == "stop"
assert create_result.usage is not None
assert create_result.usage.prompt_tokens == 10
assert create_result.usage.completion_tokens == 12
assert ResponseType.model_validate_json(create_result.content)
create_result = await client.create(
messages=[
UserMessage(content="hi", source="user"),
],
extra_create_args={"format": ResponseType.model_json_schema()},
)
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0
assert create_result.finish_reason == "stop"
assert create_result.usage is not None
assert create_result.usage.prompt_tokens == 10
assert create_result.usage.completion_tokens == 12
assert ResponseType.model_validate_json(create_result.content)
# Test case when response_format is in extra_create_args.
with pytest.warns(DeprecationWarning, match="Using response_format will be deprecated. Use json_output instead."):
create_result = await client.create(
messages=[
UserMessage(content="hi", source="user"),
],
extra_create_args={"response_format": ResponseType},
)
# Test case when response_format is in extra_create_args but is not a pydantic model.
with pytest.raises(ValueError, match="response_format must be a Pydantic model class"):
create_result = await client.create(
messages=[
UserMessage(content="hi", source="user"),
],
extra_create_args={"response_format": "json"},
)
# Test case when response_format is in extra_create_args and json_output is also set.
with pytest.raises(
ValueError,
match="response_format and json_output cannot be set to a Pydantic model class at the same time. Use json_output instead.",
):
create_result = await client.create(
messages=[
UserMessage(content="hi", source="user"),
],
extra_create_args={"response_format": ResponseType},
json_output=ResponseType,
)
# Test case when format is in extra_create_args and json_output is also set.
with pytest.raises(
ValueError, match="json_output and format cannot be set at the same time. Use json_output instead."
):
create_result = await client.create(
messages=[
UserMessage(content="hi", source="user"),
],
extra_create_args={"format": ResponseType.model_json_schema()},
json_output=ResponseType,
)
@pytest.mark.asyncio
async def test_create_stream_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
class ResponseType(BaseModel):
response: str
model = "llama3.2"
content_raw = json.dumps({"response": "Hello world! This is a test response. Test response."})
async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
assert "stream" in kwargs
assert kwargs["stream"] is True
async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
chunks = [content_raw[i : i + 5] for i in range(0, len(content_raw), 5)]
# Simulate streaming by yielding chunks of the response
for chunk in chunks[:-1]:
yield ChatResponse(
model=model,
done=False,
message=Message(
role="assistant",
content=chunk,
),
)
yield ChatResponse(
model=model,
done=True,
done_reason="stop",
message=Message(
role="assistant",
content=chunks[-1],
),
prompt_eval_count=10,
eval_count=12,
)
return _mock_stream()
monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
client = OllamaChatCompletionClient(model=model)
stream = client.create_stream(
messages=[
UserMessage(content="hi", source="user"),
],
json_output=ResponseType,
)
chunks: List[str | CreateResult] = []
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert isinstance(chunks[-1].content, str)
assert chunks[-1].content == content_raw
assert chunks[-1].finish_reason == "stop"
assert chunks[-1].usage is not None
assert chunks[-1].usage.prompt_tokens == 10
assert chunks[-1].usage.completion_tokens == 12
assert ResponseType.model_validate_json(chunks[-1].content)
@pytest_asyncio.fixture # type: ignore
async def ollama_client(request: pytest.FixtureRequest) -> OllamaChatCompletionClient:
model = request.node.callspec.params["model"] # type: ignore
assert isinstance(model, str)
# Check if the model is running locally.
try:
async with httpx.AsyncClient() as client:
response = await client.get(f"http://localhost:11434/v1/models/{model}")
response.raise_for_status()
except httpx.HTTPStatusError as e:
pytest.skip(f"{model} model is not running locally: {e}")
except httpx.ConnectError as e:
pytest.skip(f"Ollama is not running locally: {e}")
return OllamaChatCompletionClient(model=model)
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b", "llama3.2:1b"])
async def test_ollama_create(model: str, ollama_client: OllamaChatCompletionClient) -> None:
create_result = await ollama_client.create(
messages=[
UserMessage(
content="Taking two balls from a bag of 10 green balls and 20 red balls, "
"what is the probability of getting a green and a red balls?",
source="user",
),
]
)
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0
assert create_result.finish_reason == "stop"
assert create_result.usage is not None
chunks: List[str | CreateResult] = []
async for chunk in ollama_client.create_stream(
messages=[
UserMessage(
content="Taking two balls from a bag of 10 green balls and 20 red balls, "
"what is the probability of getting a green and a red balls?",
source="user",
),
]
):
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert chunks[-1].finish_reason == "stop"
assert len(chunks[-1].content) > 0
assert chunks[-1].usage is not None
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b", "llama3.2:1b"])
async def test_ollama_create_structured_output(model: str, ollama_client: OllamaChatCompletionClient) -> None:
class ResponseType(BaseModel):
calculation: str
result: str
create_result = await ollama_client.create(
messages=[
UserMessage(
content="Taking two balls from a bag of 10 green balls and 20 red balls, "
"what is the probability of getting a green and a red balls?",
source="user",
),
],
json_output=ResponseType,
)
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0
assert create_result.finish_reason == "stop"
assert create_result.usage is not None
assert ResponseType.model_validate_json(create_result.content)
# Test streaming completion with structured output using the same model.
chunks: List[str | CreateResult] = []
async for chunk in ollama_client.create_stream(
messages=[
UserMessage(
content="Taking two balls from a bag of 10 green balls and 20 red balls, "
"what is the probability of getting a green and a red balls?",
source="user",
),
],
json_output=ResponseType,
):
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
assert chunks[-1].finish_reason == "stop"
assert isinstance(chunks[-1].content, str)
assert len(chunks[-1].content) > 0
assert chunks[-1].usage is not None
assert ResponseType.model_validate_json(chunks[-1].content)
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["qwen2.5:0.5b", "llama3.2:1b"])
async def test_ollama_create_tools(model: str, ollama_client: OllamaChatCompletionClient) -> None:
def add(x: int, y: int) -> str:
return str(x + y)
add_tool = FunctionTool(add, description="Add two numbers")
create_result = await ollama_client.create(
messages=[
UserMessage(
content="What is 2 + 2? Use the add tool.",
source="user",
),
],
tools=[add_tool],
)
assert isinstance(create_result.content, list)
assert len(create_result.content) > 0
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == add_tool.name
assert create_result.finish_reason == "function_calls"
execution_result = FunctionExecutionResult(
content="4",
name=add_tool.name,
call_id=create_result.content[0].id,
is_error=False,
)
create_result = await ollama_client.create(
messages=[
UserMessage(
content="What is 2 + 2? Use the add tool.",
source="user",
),
AssistantMessage(
content=create_result.content,
source="assistant",
),
FunctionExecutionResultMessage(
content=[execution_result],
),
],
)
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0
assert create_result.finish_reason == "stop"
@pytest.mark.skip("TODO: Does Ollama support structured outputs with tools?")
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["llama3.2:1b"])
async def test_ollama_create_structured_output_with_tools(
model: str, ollama_client: OllamaChatCompletionClient
) -> None:
class ResponseType(BaseModel):
calculation: str
result: str
def add(x: int, y: int) -> str:
return str(x + y)
add_tool = FunctionTool(add, description="Add two numbers")
create_result = await ollama_client.create(
messages=[
UserMessage(
content="What is 2 + 2? Use the add tool.",
source="user",
),
],
tools=[add_tool],
json_output=ResponseType,
)
assert isinstance(create_result.content, list)
assert len(create_result.content) > 0
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == add_tool.name
assert create_result.finish_reason == "function_calls"
assert create_result.thought is not None
assert ResponseType.model_validate_json(create_result.thought)
@pytest.mark.skip("TODO: Fix streaming with tools")
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["qwen2.5:0.5b", "llama3.2:1b"])
async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatCompletionClient) -> None:
def add(x: int, y: int) -> str:
return str(x + y)
add_tool = FunctionTool(add, description="Add two numbers")
stream = ollama_client.create_stream(
messages=[
UserMessage(
content="What is 2 + 2? Use the add tool.",
source="user",
),
],
tools=[add_tool],
)
chunks: List[str | CreateResult] = []
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
create_result = chunks[-1]
assert isinstance(create_result.content, list)
assert len(create_result.content) > 0
assert isinstance(create_result.content[0], FunctionCall)
assert create_result.content[0].name == add_tool.name
assert create_result.finish_reason == "function_calls"
execution_result = FunctionExecutionResult(
content="4",
name=add_tool.name,
call_id=create_result.content[0].id,
is_error=False,
)
stream = ollama_client.create_stream(
messages=[
UserMessage(
content="What is 2 + 2? Use the add tool.",
source="user",
),
AssistantMessage(
content=create_result.content,
source="assistant",
),
FunctionExecutionResultMessage(
content=[execution_result],
),
],
)
chunks = []
async for chunk in stream:
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
create_result = chunks[-1]
assert isinstance(create_result.content, str)
assert len(create_result.content) > 0
assert create_result.finish_reason == "stop"
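Note: the structured output tests above encode precedence rules between json_output, extra_create_args["format"], and the deprecated extra_create_args["response_format"]. A hedged sketch of those rules (resolve_format is a hypothetical helper, not the Ollama client implementation):

import warnings
from typing import Any, Mapping, Optional, Type

from pydantic import BaseModel


def resolve_format(
    json_output: Optional[bool | Type[BaseModel]],
    extra_create_args: Mapping[str, Any],
) -> Any:
    response_format = extra_create_args.get("response_format")
    fmt = extra_create_args.get("format")
    if response_format is not None:
        if not (isinstance(response_format, type) and issubclass(response_format, BaseModel)):
            raise ValueError("response_format must be a Pydantic model class")
        if isinstance(json_output, type):
            raise ValueError(
                "response_format and json_output cannot be set to a Pydantic model class "
                "at the same time. Use json_output instead."
            )
        warnings.warn(
            "Using response_format will be deprecated. Use json_output instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return response_format.model_json_schema()
    if fmt is not None and json_output is not None:
        raise ValueError("json_output and format cannot be set at the same time. Use json_output instead.")
    if isinstance(json_output, type) and issubclass(json_output, BaseModel):
        return json_output.model_json_schema()
    return fmt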

View File

@@ -218,7 +218,13 @@ async def test_custom_model_with_capabilities() -> None:
model="dummy_model",
base_url="https://api.dummy.com/v0",
api_key="api_key",
model_info={"vision": False, "function_calling": False, "json_output": False, "family": ModelFamily.UNKNOWN},
model_info={
"vision": False,
"function_calling": False,
"json_output": False,
"family": ModelFamily.UNKNOWN,
"structured_output": False,
},
)
assert client
@@ -231,7 +237,13 @@ async def test_azure_openai_chat_completion_client() -> None:
api_key="api_key",
api_version="2020-08-04",
azure_endpoint="https://dummy.com",
model_info={"vision": True, "function_calling": True, "json_output": True, "family": ModelFamily.GPT_4O},
model_info={
"vision": True,
"function_calling": True,
"json_output": True,
"family": ModelFamily.GPT_4O,
"structured_output": True,
},
)
assert client
@@ -446,6 +458,97 @@ def test_convert_tools_accepts_both_tool_and_schema() -> None:
assert converted_tool_schema[0] == converted_tool_schema[1]
@pytest.mark.asyncio
async def test_json_mode(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-11-20"
called_args = {}
async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion:
# Capture the arguments passed to the function
called_args["kwargs"] = kwargs
return ChatCompletion(
id="id1",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(
content=json.dumps({"thoughts": "happy", "response": "happy"}),
role="assistant",
),
)
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
)
monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
model_client = OpenAIChatCompletionClient(model=model, api_key="")
# Test that the openai client was called with the correct response format.
create_result = await model_client.create(
messages=[UserMessage(content="I am happy.", source="user")], json_output=True
)
assert isinstance(create_result.content, str)
response = json.loads(create_result.content)
assert response["thoughts"] == "happy"
assert response["response"] == "happy"
assert called_args["kwargs"]["response_format"] == {"type": "json_object"}
# Make sure that the response format is set to json_object when json_output is True, regardless of the extra_create_args.
create_result = await model_client.create(
messages=[UserMessage(content="I am happy.", source="user")],
json_output=True,
extra_create_args={"response_format": "json_object"},
)
assert isinstance(create_result.content, str)
response = json.loads(create_result.content)
assert response["thoughts"] == "happy"
assert response["response"] == "happy"
assert called_args["kwargs"]["response_format"] == {"type": "json_object"}
create_result = await model_client.create(
messages=[UserMessage(content="I am happy.", source="user")],
json_output=True,
extra_create_args={"response_format": "text"},
)
assert isinstance(create_result.content, str)
response = json.loads(create_result.content)
assert response["thoughts"] == "happy"
assert response["response"] == "happy"
# Check that the openai client was called with the correct response format.
assert called_args["kwargs"]["response_format"] == {"type": "json_object"}
# Make sure when json_output is set to False, the response format is always set to text.
create_result = await model_client.create(
messages=[UserMessage(content="I am happy.", source="user")],
json_output=False,
extra_create_args={"response_format": "text"},
)
assert called_args["kwargs"]["response_format"] == {"type": "text"}
create_result = await model_client.create(
messages=[UserMessage(content="I am happy.", source="user")],
json_output=False,
extra_create_args={"response_format": "json_object"},
)
assert called_args["kwargs"]["response_format"] == {"type": "text"}
# Make sure that when response_format is set, it is used if json_output is not set.
create_result = await model_client.create(
messages=[UserMessage(content="I am happy.", source="user")],
extra_create_args={"response_format": {"type": "json_object"}},
)
assert isinstance(create_result.content, str)
response = json.loads(create_result.content)
assert response["thoughts"] == "happy"
assert response["response"] == "happy"
assert called_args["kwargs"]["response_format"] == {"type": "json_object"}
@pytest.mark.asyncio
async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
class AgentResponse(BaseModel):
@@ -483,11 +586,12 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
create_result = await model_client.create(
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
)
assert isinstance(create_result.content, str)
response = AgentResponse.model_validate(json.loads(create_result.content))
assert (
@@ -496,6 +600,36 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
)
assert response.response == "happy"
# Test that a warning will be raised if response_format is set to a dict.
with pytest.warns(
UserWarning,
match="response_format is found in extra_create_args while json_output is set to a Pydantic model class.",
):
create_result = await model_client.create(
messages=[UserMessage(content="I am happy.", source="user")],
json_output=AgentResponse,
extra_create_args={"response_format": {"type": "json_object"}},
)
# Test that a warning will be raised if response_format is set to a pydantic model.
with pytest.warns(
DeprecationWarning, match="Using response_format to specify structured output type will be deprecated."
):
create_result = await model_client.create(
messages=[UserMessage(content="I am happy.", source="user")],
extra_create_args={"response_format": AgentResponse},
)
# Test that a ValueError will be raised if response_format and json_output are set to a pydantic model.
with pytest.raises(
ValueError, match="response_format and json_output cannot be set to a Pydantic model class at the same time."
):
create_result = await model_client.create(
messages=[UserMessage(content="I am happy.", source="user")],
json_output=AgentResponse,
extra_create_args={"response_format": AgentResponse},
)
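Note: the json-mode and structured output tests above fix the mapping from json_output to the OpenAI response_format argument. A hedged restatement of that mapping (expected_response_format is a hypothetical helper used only to summarize the asserted behavior):

from typing import Any, Optional, Type

from pydantic import BaseModel


def expected_response_format(json_output: Optional[bool | Type[BaseModel]]) -> Any:
    # True/False always win over any response_format in extra_create_args; a Pydantic
    # class selects the structured output (beta parse) path; None defers to whatever
    # extra_create_args provides.
    if json_output is True:
        return {"type": "json_object"}
    if json_output is False:
        return {"type": "text"}
    if isinstance(json_output, type) and issubclass(json_output, BaseModel):
        return json_output
    return None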
@pytest.mark.asyncio
async def test_structured_output_with_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -544,11 +678,12 @@ async def test_structured_output_with_tool_calls(monkeypatch: pytest.MonkeyPatch
model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
create_result = await model_client.create(
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
)
assert isinstance(create_result.content, list)
assert len(create_result.content) == 1
assert create_result.content[0] == FunctionCall(
@@ -617,12 +752,13 @@ async def test_structured_output_with_streaming(monkeypatch: pytest.MonkeyPatch)
model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
chunks: List[str | CreateResult] = []
async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
async for chunk in model_client.create_stream(
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
):
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
@@ -726,12 +862,13 @@ async def test_structured_output_with_streaming_tool_calls(monkeypatch: pytest.M
model_client = OpenAIChatCompletionClient(
model=model,
api_key="",
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
chunks: List[str | CreateResult] = []
async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
async for chunk in model_client.create_stream(
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
):
chunks.append(chunk)
assert len(chunks) > 0
assert isinstance(chunks[-1], CreateResult)
@@ -801,7 +938,13 @@ async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None:
model_client = OpenAIChatCompletionClient(
model="r1",
api_key="",
model_info={"family": ModelFamily.R1, "vision": False, "function_calling": False, "json_output": False},
model_info={
"family": ModelFamily.R1,
"vision": False,
"function_calling": False,
"json_output": False,
"structured_output": False,
},
)
# Successful completion with think field.
@@ -874,7 +1017,13 @@ async def test_r1_think_field_not_present(monkeypatch: pytest.MonkeyPatch) -> No
model_client = OpenAIChatCompletionClient(
model="r1",
api_key="",
model_info={"family": ModelFamily.R1, "vision": False, "function_calling": False, "json_output": False},
model_info={
"family": ModelFamily.R1,
"vision": False,
"function_calling": False,
"json_output": False,
"structured_output": False,
},
)
# Warning completion when think field is not present.
@@ -1338,11 +1487,12 @@ async def test_openai_structured_output() -> None:
model_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key=api_key,
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
create_result = await model_client.create(
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
)
assert isinstance(create_result.content, str)
response = AgentResponse.model_validate(json.loads(create_result.content))
assert response.thoughts
@@ -1362,11 +1512,12 @@ async def test_openai_structured_output_with_streaming() -> None:
model_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key=api_key,
response_format=AgentResponse, # type: ignore
)
# Test that the openai client was called with the correct response format.
stream = model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")])
stream = model_client.create_stream(
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
)
chunks: List[str | CreateResult] = []
async for chunk in stream:
chunks.append(chunk)
@@ -1397,7 +1548,6 @@ async def test_openai_structured_output_with_tool_calls() -> None:
model_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key=api_key,
response_format=AgentResponse, # type: ignore
)
response1 = await model_client.create(
@@ -1407,6 +1557,7 @@ async def test_openai_structured_output_with_tool_calls() -> None:
],
tools=[tool],
extra_create_args={"tool_choice": "required"},
json_output=AgentResponse,
)
assert isinstance(response1.content, list)
assert len(response1.content) == 1
@@ -1428,6 +1579,7 @@ async def test_openai_structured_output_with_tool_calls() -> None:
]
),
],
json_output=AgentResponse,
)
assert isinstance(response2.content, str)
parsed_response = AgentResponse.model_validate(json.loads(response2.content))
@@ -1454,7 +1606,6 @@ async def test_openai_structured_output_with_streaming_tool_calls() -> None:
model_client = OpenAIChatCompletionClient(
model="gpt-4o-mini",
api_key=api_key,
response_format=AgentResponse, # type: ignore
)
chunks1: List[str | CreateResult] = []
@@ -1465,6 +1616,7 @@ async def test_openai_structured_output_with_streaming_tool_calls() -> None:
],
tools=[tool],
extra_create_args={"tool_choice": "required"},
json_output=AgentResponse,
)
async for chunk in stream1:
chunks1.append(chunk)
@@ -1491,6 +1643,7 @@ async def test_openai_structured_output_with_streaming_tool_calls() -> None:
]
),
],
json_output=AgentResponse,
)
chunks2: List[str | CreateResult] = []
async for chunk in stream2:
@@ -1532,6 +1685,7 @@ async def test_hugging_face() -> None:
"json_output": False,
"vision": False,
"family": ModelFamily.UNKNOWN,
"structured_output": False,
},
)
@@ -1546,6 +1700,7 @@ async def test_ollama() -> None:
"json_output": False,
"vision": False,
"family": ModelFamily.R1,
"structured_output": False,
}
# Check if the model is running locally.
try:

View File

@@ -453,7 +453,9 @@ async def test_sk_chat_completion_default_model_info(sk_client: AzureChatComplet
@pytest.mark.asyncio
async def test_sk_chat_completion_custom_model_info(sk_client: AzureChatCompletion) -> None:
# Create custom model info
custom_model_info = ModelInfo(vision=True, function_calling=True, json_output=True, family=ModelFamily.GPT_4)
custom_model_info = ModelInfo(
vision=True, function_calling=True, json_output=True, family=ModelFamily.GPT_4, structured_output=False
)
# Create adapter with custom model_info
adapter = SKChatCompletionAdapter(sk_client, model_info=custom_model_info)
@@ -522,7 +524,9 @@ async def test_sk_chat_completion_r1_content() -> None:
adapter = SKChatCompletionAdapter(
mock_client,
kernel=kernel,
model_info=ModelInfo(vision=False, function_calling=False, json_output=False, family=ModelFamily.R1),
model_info=ModelInfo(
vision=False, function_calling=False, json_output=False, family=ModelFamily.R1, structured_output=False
),
)
result = await adapter.create(messages=[UserMessage(content="Say hello!", source="user")])