feat: add structured output to model clients (#5936)
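The change in one sentence: the json_output parameter of ChatCompletionClient.create and create_stream widens from Optional[bool] to Optional[bool | type[BaseModel]], so True still requests plain JSON mode while a Pydantic model class requests structured output constrained to that model's JSON schema. A minimal usage sketch against the OpenAI client (the model name and prompt are illustrative; an OpenAI API key is assumed in the environment):

    import asyncio

    from autogen_core.models import UserMessage
    from autogen_ext.models.openai import OpenAIChatCompletionClient
    from pydantic import BaseModel


    class AgentResponse(BaseModel):
        thoughts: str
        answer: str


    async def main() -> None:
        client = OpenAIChatCompletionClient(model="gpt-4o")
        # Passing the model class (not an instance) switches on structured output.
        result = await client.create(
            [UserMessage(content="What is 16 squared?", source="user")],
            json_output=AgentResponse,
        )
        assert isinstance(result.content, str)
        print(AgentResponse.model_validate_json(result.content).answer)


    asyncio.run(main())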
@@ -13,6 +13,7 @@ from autogen_core.models import (
     RequestUsage,
 )
 from autogen_core.tools import Tool, ToolSchema
+from pydantic import BaseModel

 from .page_logger import PageLogger

@@ -87,7 +88,7 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
-        json_output: Optional[bool] = None,
+        json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> CreateResult:
@@ -154,7 +155,7 @@ class ChatCompletionClientRecorder(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
-        json_output: Optional[bool] = None,
+        json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:

@@ -61,7 +61,7 @@ from autogen_core.models import (
     validate_model_info,
 )
 from autogen_core.tools import Tool, ToolSchema
-from pydantic import SecretStr
+from pydantic import BaseModel, SecretStr
 from typing_extensions import Self, Unpack

 from . import _model_info
@@ -413,7 +413,7 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
-        json_output: Optional[bool] = None,
+        json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> CreateResult:
@@ -435,6 +435,8 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):

         if json_output is True:
             create_args["response_format"] = {"type": "json_object"}
+        elif isinstance(json_output, type):
+            raise ValueError("Structured output is currently not supported for Anthropic models")

         # Process system message separately
         system_message = None
@@ -568,7 +570,7 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
-        json_output: Optional[bool] = None,
+        json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
         max_consecutive_empty_chunk_tolerance: int = 0,
@@ -595,6 +597,9 @@ class BaseAnthropicChatCompletionClient(ChatCompletionClient):
         if json_output is True:
             create_args["response_format"] = {"type": "json_object"}

+        if isinstance(json_output, type):
+            raise ValueError("Structured output is currently not supported for Anthropic models")
+
         # Process system message separately
         system_message = None
         anthropic_messages: List[MessageParam] = []

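Because every Claude entry below advertises "structured_output": False, callers can avoid the ValueError above by checking the capability before choosing how to ask for JSON. A hedged caller-side sketch (model name and fallback branch are illustrative):

    from autogen_ext.models.anthropic import AnthropicChatCompletionClient
    from pydantic import BaseModel


    class Verdict(BaseModel):
        label: str


    client = AnthropicChatCompletionClient(model="claude-3-5-sonnet-20240620")
    if client.model_info["structured_output"]:
        json_output: bool | type[BaseModel] = Verdict  # schema-constrained output
    else:
        json_output = True  # plain JSON mode; validate against Verdict manually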
@@ -12,6 +12,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.CLAUDE_3_7_SONNET,
+        "structured_output": False,
     },
     # Claude 3.7 Sonnet latest alias
     "claude-3-7-sonnet-latest": {
@@ -19,6 +20,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.CLAUDE_3_7_SONNET,
+        "structured_output": False,
     },
     # Claude 3 Opus (most powerful)
     "claude-3-opus-20240229": {
@@ -26,6 +28,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.CLAUDE_3_5_SONNET,
+        "structured_output": False,
     },
     # Claude 3 Sonnet (balanced)
     "claude-3-sonnet-20240229": {
@@ -33,6 +36,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.CLAUDE_3_5_SONNET,
+        "structured_output": False,
     },
     # Claude 3 Haiku (fastest)
     "claude-3-haiku-20240307": {
@@ -40,6 +44,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.CLAUDE_3_5_SONNET,
+        "structured_output": False,
     },
     # Claude 3.5 Sonnet
     "claude-3-5-sonnet-20240620": {
@@ -47,6 +52,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.CLAUDE_3_5_SONNET,
+        "structured_output": False,
     },
     # Claude Instant v1 (legacy)
     "claude-instant-1.2": {
@@ -54,6 +60,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
         "function_calling": False,
         "json_output": True,
         "family": ModelFamily.CLAUDE_3_5_SONNET,
+        "structured_output": False,
     },
     # Claude 2 (legacy)
     "claude-2.0": {
@@ -61,6 +68,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
         "function_calling": False,
         "json_output": True,
         "family": ModelFamily.CLAUDE_3_5_SONNET,
+        "structured_output": False,
     },
     # Claude 2.1 (legacy)
     "claude-2.1": {
@@ -68,6 +76,7 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
         "function_calling": False,
         "json_output": True,
         "family": ModelFamily.CLAUDE_3_5_SONNET,
+        "structured_output": False,
     },
 }

@@ -52,6 +52,7 @@ from azure.ai.inference.models import (
 from azure.ai.inference.models import (
     UserMessage as AzureUserMessage,
 )
+from pydantic import BaseModel
 from typing_extensions import AsyncGenerator, Union, Unpack

 from autogen_ext.models.azure.config import (
@@ -226,6 +227,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
                 "function_calling": False,
                 "vision": False,
                 "family": "unknown",
+                "structured_output": False,
             },
         )

@@ -263,6 +265,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
                 "function_calling": False,
                 "vision": False,
                 "family": "unknown",
+                "structured_output": False,
             },
         )

@@ -323,7 +326,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
         self,
         messages: Sequence[LLMMessage],
         tools: Sequence[Tool | ToolSchema],
-        json_output: Optional[bool],
+        json_output: Optional[bool | type[BaseModel]],
         create_args: Dict[str, Any],
     ) -> None:
         if self.model_info["vision"] is False:
@@ -336,6 +339,10 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
         if self.model_info["json_output"] is False and json_output is True:
             raise ValueError("Model does not support JSON output")

+        if isinstance(json_output, type):
+            # TODO: we should support this in the future.
+            raise ValueError("Structured output is not currently supported for AzureAIChatCompletionClient")
+
         if json_output is True and "response_format" not in create_args:
             create_args["response_format"] = "json_object"

@@ -349,7 +356,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
-        json_output: Optional[bool] = None,
+        json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> CreateResult:
@@ -442,7 +449,7 @@ class AzureAIChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
-        json_output: Optional[bool] = None,
+        json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:

@@ -102,7 +102,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
         self,
         messages: Sequence[LLMMessage],
         tools: Sequence[Tool | ToolSchema],
-        json_output: Optional[bool],
+        json_output: Optional[bool | type[BaseModel]],
         extra_create_args: Mapping[str, Any],
     ) -> tuple[Optional[Union[CreateResult, List[Union[str, CreateResult]]]], str]:
         """
@@ -110,10 +110,17 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
         Returns a tuple of (cached_result, cache_key).
         """

+        json_output_data: str | bool | None = None
+
+        if isinstance(json_output, type) and issubclass(json_output, BaseModel):
+            json_output_data = json.dumps(json_output.model_json_schema())
+        elif isinstance(json_output, bool):
+            json_output_data = json_output
+
         data = {
             "messages": [message.model_dump() for message in messages],
             "tools": [(tool.schema if isinstance(tool, Tool) else tool) for tool in tools],
-            "json_output": json_output,
+            "json_output": json_output_data,
             "extra_create_args": extra_create_args,
         }
         serialized_data = json.dumps(data, sort_keys=True)
@@ -130,7 +137,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
-        json_output: Optional[bool] = None,
+        json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> CreateResult:
@@ -162,7 +169,7 @@ class ChatCompletionCache(ChatCompletionClient, Component[ChatCompletionCacheCon
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
-        json_output: Optional[bool] = None,
+        json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:

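The point of the hunk above: a model class cannot be JSON-serialized directly, so the cache key reduces it to its (stable) JSON schema, while booleans pass through unchanged. A standalone sketch of that reduction (class name illustrative):

    import json

    from pydantic import BaseModel


    class AgentResponse(BaseModel):
        answer: str


    def cache_key_component(json_output: bool | type[BaseModel] | None) -> str | bool | None:
        # Mirrors _check_cache: the schema, not the class object, goes into the key.
        if isinstance(json_output, type) and issubclass(json_output, BaseModel):
            return json.dumps(json_output.model_json_schema())
        if isinstance(json_output, bool):
            return json_output
        return None


    print(json.dumps({"json_output": cache_key_component(AgentResponse)}, sort_keys=True))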
@@ -1,3 +1,4 @@
 import asyncio
+import logging  # added import
 import re
 from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Sequence, TypedDict, Union, cast
@@ -11,6 +12,7 @@ from autogen_core.models import (
     FinishReasons,
     FunctionExecutionResultMessage,
     LLMMessage,
+    ModelFamily,
     ModelInfo,
     RequestUsage,
     SystemMessage,
@@ -30,6 +32,7 @@ from llama_cpp import (
     Llama,
     llama_chat_format,
 )
+from pydantic import BaseModel
 from typing_extensions import Unpack

 logger = logging.getLogger(EVENT_LOGGER_NAME)  # initialize logger
@@ -172,6 +175,7 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
     This client allows you to interact with LlamaCpp models, either by specifying a local model path or by downloading a model from Hugging Face Hub.

     Args:
+        model_info (optional, ModelInfo): The information about the model. Defaults to :attr:`~LlamaCppChatCompletionClient.DEFAULT_MODEL_INFO`.
         model_path (optional, str): The path to the LlamaCpp model file. Required if repo_id and filename are not provided.
         repo_id (optional, str): The Hugging Face Hub repository ID. Required if model_path is not provided.
         filename (optional, str): The filename of the model within the Hugging Face Hub repository. Required if model_path is not provided.
@@ -179,7 +183,6 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
         n_ctx (optional, int): The context size.
         n_batch (optional, int): The batch size.
         verbose (optional, bool): Whether to print verbose output.
-        model_info (optional, ModelInfo): The capabilities of the model. Defaults to a ModelInfo instance with function_calling set to True.
         **kwargs: Additional parameters to pass to the Llama class.

     Examples:
@@ -223,6 +226,10 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
             asyncio.run(main())
     """

+    DEFAULT_MODEL_INFO: ModelInfo = ModelInfo(
+        vision=False, json_output=True, family=ModelFamily.UNKNOWN, function_calling=True, structured_output=True
+    )
+
     def __init__(
         self,
         model_info: Optional[ModelInfo] = None,
@@ -234,6 +241,10 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):

         if model_info:
             validate_model_info(model_info)
+            self._model_info = model_info
+        else:
+            # Default model info.
+            self._model_info = self.DEFAULT_MODEL_INFO

         if "repo_id" in kwargs and "filename" in kwargs and kwargs["repo_id"] and kwargs["filename"]:
             repo_id: str = cast(str, kwargs.pop("repo_id"))
@@ -255,10 +266,11 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
         tools: Sequence[Tool | ToolSchema] = [],
         # None means do not override the default
         # A value means to override the client default - often specified in the constructor
-        json_output: Optional[bool] = None,
+        json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> CreateResult:
+        create_args = dict(extra_create_args)
         # Convert LLMMessage objects to dictionaries with 'role' and 'content'
         # converted_messages: List[Dict[str, str | Image | list[str | Image] | list[FunctionCall]]] = []
         converted_messages: list[
@@ -283,12 +295,28 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
             else:
                 raise ValueError(f"Unsupported message type: {type(msg)}")

+        if isinstance(json_output, type) and issubclass(json_output, BaseModel):
+            create_args["response_format"] = {"type": "json_object", "schema": json_output.model_json_schema()}
+        elif json_output is True:
+            create_args["response_format"] = {"type": "json_object"}
+        elif json_output is not False and json_output is not None:
+            raise ValueError("json_output must be a boolean, a BaseModel subclass or None.")
+
         if self.model_info["function_calling"]:
-            response = self.llm.create_chat_completion(
-                messages=converted_messages, tools=convert_tools(tools), stream=False
+            # Run this in on the event loop to avoid blocking.
+            response_future = asyncio.get_event_loop().run_in_executor(
+                None,
+                lambda: self.llm.create_chat_completion(
+                    messages=converted_messages, tools=convert_tools(tools), stream=False, **create_args
+                ),
             )
         else:
-            response = self.llm.create_chat_completion(messages=converted_messages, stream=False)
+            response_future = asyncio.get_event_loop().run_in_executor(
+                None, lambda: self.llm.create_chat_completion(messages=converted_messages, stream=False, **create_args)
+            )
+        if cancellation_token:
+            cancellation_token.link_future(response_future)
+        response = await response_future

         if not isinstance(response, dict):
             raise ValueError("Unexpected response type from LlamaCpp model.")
@@ -371,7 +399,7 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):
         tools: Sequence[Tool | ToolSchema] = [],
         # None means do not override the default
         # A value means to override the client default - often specified in the constructor
-        json_output: Optional[bool] = None,
+        json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:
@@ -403,7 +431,7 @@ class LlamaCppChatCompletionClient(ChatCompletionClient):

     @property
     def model_info(self) -> ModelInfo:
-        return ModelInfo(vision=False, json_output=False, family="llama-cpp", function_calling=True)
+        return self._model_info

     def remaining_tokens(
         self,

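For llama-cpp the model class ends up as a schema-constrained response_format on the completion call. A hedged end-to-end sketch (repo_id and filename are placeholders for any GGUF chat model):

    import asyncio

    from autogen_core.models import UserMessage
    from autogen_ext.models.llama_cpp import LlamaCppChatCompletionClient
    from pydantic import BaseModel


    class Answer(BaseModel):
        value: int


    async def main() -> None:
        client = LlamaCppChatCompletionClient(
            repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_ctx=1024
        )
        # Internally this becomes:
        # create_args["response_format"] = {"type": "json_object", "schema": Answer.model_json_schema()}
        result = await client.create(
            [UserMessage(content="Reply with the number four.", source="user")],
            json_output=Answer,
        )
        print(result.content)


    asyncio.run(main())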
@@ -6,73 +6,300 @@ from autogen_core.models import ModelFamily, ModelInfo
 # TODO: fix model family?
+# TODO: json_output is True for all models because ollama supports structured output via pydantic. How to handle this situation?
 _MODEL_INFO: Dict[str, ModelInfo] = {
-    "all-minilm": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "bge-m3": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "codegemma": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "codellama": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "command-r": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "deepseek-coder": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
+    "all-minilm": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "bge-m3": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "codegemma": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "codellama": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "command-r": {
+        "vision": False,
+        "function_calling": True,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "deepseek-coder": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
     "deepseek-coder-v2": {
         "vision": False,
         "function_calling": False,
         "json_output": True,
         "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
     },
+    "deepseek-r1": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.R1,
+        "structured_output": True,
+    },
+    "dolphin-llama3": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "dolphin-mistral": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "dolphin-mixtral": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "gemma": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "gemma2": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "llama2": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
-    "deepseek-r1": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.R1},
-    "dolphin-llama3": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "dolphin-mistral": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "dolphin-mixtral": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "gemma": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "gemma2": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "llama2": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
     "llama2-uncensored": {
         "vision": False,
         "function_calling": False,
         "json_output": True,
         "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
     },
+    "llama3": {
+        "vision": False,
+        "function_calling": True,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "llama3.1": {
+        "vision": False,
+        "function_calling": True,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "llama3.2": {
+        "vision": False,
+        "function_calling": True,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "llama3.2-vision": {
+        "vision": True,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "llama3.3": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "llava": {
+        "vision": True,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "llava-llama3": {
+        "vision": True,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "mistral": {
+        "vision": False,
+        "function_calling": True,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "mistral-nemo": {
+        "vision": False,
+        "function_calling": True,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "mixtral": {
+        "vision": False,
+        "function_calling": True,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
-    "llama3": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "llama3.1": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "llama3.2": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "llama3.2-vision": {"vision": True, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "llama3.3": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "llava": {"vision": True, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "llava-llama3": {"vision": True, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "mistral": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "mistral-nemo": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "mixtral": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
     "mxbai-embed-large": {
         "vision": False,
         "function_calling": False,
         "json_output": True,
         "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
     },
     "nomic-embed-text": {
         "vision": False,
         "function_calling": False,
         "json_output": True,
         "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
     },
+    "orca-mini": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "phi": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "phi3": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "phi3.5": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "phi4": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "qwen": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "qwen2": {
+        "vision": False,
+        "function_calling": True,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "qwen2.5": {
+        "vision": False,
+        "function_calling": True,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "qwen2.5-coder": {
+        "vision": False,
+        "function_calling": True,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
-    "orca-mini": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "phi": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "phi3": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "phi3.5": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "phi4": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "qwen": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "qwen2": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "qwen2.5": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "qwen2.5-coder": {"vision": False, "function_calling": True, "json_output": True, "family": ModelFamily.UNKNOWN},
     "snowflake-arctic-embed": {
         "vision": False,
         "function_calling": False,
         "json_output": True,
         "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
     },
+    "starcoder2": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "tinyllama": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "wizardlm2": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "yi": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
+    "zephyr": {
+        "vision": False,
+        "function_calling": False,
+        "json_output": True,
+        "family": ModelFamily.UNKNOWN,
+        "structured_output": True,
+    },
-    "starcoder2": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "tinyllama": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "wizardlm2": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "yi": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
-    "zephyr": {"vision": False, "function_calling": False, "json_output": True, "family": ModelFamily.UNKNOWN},
 }

+# TODO: the ollama model card for some of these models were incorrect. I made a best effort to get the actual values, but they aren't guaranteed to be correct.

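Every entry above now carries "structured_output": True, since Ollama can constrain decoding to an arbitrary JSON schema regardless of the model. A quick illustration of what the table feeds (the private-module path is an assumption inferred from the client's `from . import _model_info`; importing it is for demonstration only):

    from autogen_ext.models.ollama._model_info import _MODEL_INFO  # private, illustration only

    assert _MODEL_INFO["llama3.2"]["structured_output"] is True
    assert _MODEL_INFO["llama3.2"]["function_calling"] is True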
@@ -5,12 +5,13 @@ import logging
 import math
 import re
 import warnings
 from asyncio import Task
+from dataclasses import dataclass
 from typing import (
     Any,
     AsyncGenerator,
     Dict,
     List,
     Literal,
     Mapping,
     Optional,
     Sequence,
@@ -48,6 +49,7 @@ from ollama import Image as OllamaImage
 from ollama import Tool as OllamaTool
 from ollama._types import ChatRequest
 from pydantic import BaseModel
+from pydantic.json_schema import JsonSchemaValue
 from typing_extensions import Self, Unpack

 from . import _model_info
@@ -328,6 +330,7 @@ def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
     stop_reason = stop_reason.lower()

     KNOWN_STOP_MAPPINGS: Dict[str, FinishReasons] = {
         "stop": "stop",
+        "end_turn": "stop",
         "tool_calls": "function_calls",
     }
@@ -335,6 +338,14 @@ def normalize_stop_reason(stop_reason: str | None) -> FinishReasons:
     return KNOWN_STOP_MAPPINGS.get(stop_reason, "unknown")


+@dataclass
+class CreateParams:
+    messages: Sequence[Message]
+    tools: Sequence[OllamaTool]
+    format: Optional[Union[Literal["", "json"], JsonSchemaValue]]
+    create_args: Dict[str, Any]
+
+
 class BaseOllamaChatCompletionClient(ChatCompletionClient):
     def __init__(
         self,
@@ -390,39 +401,61 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
     def get_create_args(self) -> Mapping[str, Any]:
         return self._create_args

-    async def create(
+    def _process_create_args(
         self,
         messages: Sequence[LLMMessage],
-        *,
-        tools: Sequence[Tool | ToolSchema] = [],
-        json_output: Optional[bool] = None,
-        extra_create_args: Mapping[str, Any] = {},
-        cancellation_token: Optional[CancellationToken] = None,
-    ) -> CreateResult:
-        # Make sure all extra_create_args are valid
-        # TODO: kwarg checking logic
-        # extra_create_args_keys = set(extra_create_args.keys())
-        # if not create_kwargs.issuperset(extra_create_args_keys):
-        #     raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
-
+        tools: Sequence[Tool | ToolSchema],
+        json_output: Optional[bool | type[BaseModel]],
+        extra_create_args: Mapping[str, Any],
+    ) -> CreateParams:
         # Copy the create args and overwrite anything in extra_create_args
         create_args = self._create_args.copy()
         create_args.update(extra_create_args)

-        response_format_value: Optional[Mapping[str, Any]] = None
+        response_format_value: JsonSchemaValue | Literal["json"] | None = None

         if "response_format" in create_args:
+            warnings.warn(
+                "Using response_format will be deprecated. Use json_output instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
             value = create_args["response_format"]
             # If value is a Pydantic model class, use the beta client
             if isinstance(value, type) and issubclass(value, BaseModel):
                 response_format_value = value.model_json_schema()
+                # Remove response_format from create_args to prevent passing it twice.
+                del create_args["response_format"]
             else:
                 # response_format_value is not a Pydantic model class
+                # TODO: Should this be an warning/error?
-                response_format_value = None
+                raise ValueError(f"response_format must be a Pydantic model class, not {type(value)}")

-        # Remove 'response_format' from create_args to prevent passing it twice
-        create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}
+        if json_output is not None:
+            if self.model_info["json_output"] is False and json_output is True:
+                raise ValueError("Model does not support JSON output.")
+            if json_output is True:
+                # JSON mode.
+                response_format_value = "json"
+            elif json_output is False:
+                # Text mode.
+                response_format_value = None
+            elif isinstance(json_output, type) and issubclass(json_output, BaseModel):
+                if response_format_value is not None:
+                    raise ValueError(
+                        "response_format and json_output cannot be set to a Pydantic model class at the same time. "
+                        "Use json_output instead."
+                    )
+                # Beta client mode with Pydantic model class.
+                response_format_value = json_output.model_json_schema()
+            else:
+                raise ValueError(f"json_output must be a boolean or a Pydantic model class, got {type(json_output)}")
+
+        if "format" in create_args:
+            # Handle the case where format is set from create_args.
+            if json_output is not None:
+                raise ValueError("json_output and format cannot be set at the same time. Use json_output instead.")
+            assert response_format_value is None
+            response_format_value = create_args["format"]
+            # Remove format from create_args to prevent passing it twice.
+            del create_args["format"]

         # TODO: allow custom handling.
         # For now we raise an error if images are present and vision is not supported
@@ -432,15 +465,6 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
                 if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
                     raise ValueError("Model does not support vision and image was provided")

-        if json_output is not None:
-            if self.model_info["json_output"] is False and json_output is True:
-                raise ValueError("Model does not support JSON output.")
-
-            if json_output is True:
-                create_args["response_format"] = {"type": "json_object"}
-            else:
-                create_args["response_format"] = {"type": "text"}
-
         if self.model_info["json_output"] is False and json_output is True:
             raise ValueError("Model does not support JSON output.")

@@ -448,30 +472,47 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         ollama_messages = [item for sublist in ollama_messages_nested for item in sublist]

         if self.model_info["function_calling"] is False and len(tools) > 0:
-            raise ValueError("Model does not support function calling")
-        future: Task[ChatResponse]
-        if len(tools) > 0:
-            converted_tools = convert_tools(tools)
-            future = asyncio.ensure_future(
-                self._client.chat(  # type: ignore
-                    # model=self._model_name,
-                    messages=ollama_messages,
-                    tools=converted_tools,
-                    stream=False,
-                    format=response_format_value,
-                    **create_args_no_response_format,
-                )
-            )
-        else:
-            future = asyncio.ensure_future(
-                self._client.chat(  # type: ignore
-                    # model=self._model_name,
-                    messages=ollama_messages,
-                    stream=False,
-                    format=response_format_value,
-                    **create_args_no_response_format,
-                )
+            raise ValueError("Model does not support function calling and tools were provided")
+
+        converted_tools = convert_tools(tools)
+
+        return CreateParams(
+            messages=ollama_messages,
+            tools=converted_tools,
+            format=response_format_value,
+            create_args=create_args,
+        )
+
+    async def create(
+        self,
+        messages: Sequence[LLMMessage],
+        *,
+        tools: Sequence[Tool | ToolSchema] = [],
+        json_output: Optional[bool | type[BaseModel]] = None,
+        extra_create_args: Mapping[str, Any] = {},
+        cancellation_token: Optional[CancellationToken] = None,
+    ) -> CreateResult:
+        # Make sure all extra_create_args are valid
+        # TODO: kwarg checking logic
+        # extra_create_args_keys = set(extra_create_args.keys())
+        # if not create_kwargs.issuperset(extra_create_args_keys):
+        #     raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
+        create_params = self._process_create_args(
+            messages,
+            tools,
+            json_output,
+            extra_create_args,
+        )
+        future = asyncio.ensure_future(
+            self._client.chat(  # type: ignore
+                # model=self._model_name,
+                messages=create_params.messages,
+                tools=create_params.tools if len(create_params.tools) > 0 else None,
+                stream=False,
+                format=create_params.format,
+                **create_params.create_args,
+            )
+        )
         if cancellation_token is not None:
             cancellation_token.link_future(future)
         result: ChatResponse = await future
@@ -484,7 +525,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):

         logger.info(
             LLMCallEvent(
-                messages=[m.model_dump() for m in ollama_messages],
+                messages=[m.model_dump() for m in create_params.messages],
                 response=result.model_dump(),
                 prompt_tokens=usage.prompt_tokens,
                 completion_tokens=usage.completion_tokens,
@@ -564,109 +605,31 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
         messages: Sequence[LLMMessage],
         *,
         tools: Sequence[Tool | ToolSchema] = [],
-        json_output: Optional[bool] = None,
+        json_output: Optional[bool | type[BaseModel]] = None,
         extra_create_args: Mapping[str, Any] = {},
         cancellation_token: Optional[CancellationToken] = None,
         max_consecutive_empty_chunk_tolerance: int = 0,
     ) -> AsyncGenerator[Union[str, CreateResult], None]:
-        """
-        Creates an AsyncGenerator that will yield a stream of chat completions based on the provided messages and tools.
-
-        Args:
-            messages (Sequence[LLMMessage]): A sequence of messages to be processed.
-            tools (Sequence[Tool | ToolSchema], optional): A sequence of tools to be used in the completion. Defaults to `[]`.
-            json_output (Optional[bool], optional): If True, the output will be in JSON format. Defaults to None.
-            extra_create_args (Mapping[str, Any], optional): Additional arguments for the creation process. Default to `{}`.
-            cancellation_token (Optional[CancellationToken], optional): A token to cancel the operation. Defaults to None.
-            max_consecutive_empty_chunk_tolerance (int): The maximum number of consecutive empty chunks to tolerate before raising a ValueError. This seems to only be needed to set when using `AzureOpenAIChatCompletionClient`. Defaults to 0.
-
-        Yields:
-            AsyncGenerator[Union[str, CreateResult], None]: A generator yielding the completion results as they are produced.
-
-        In streaming, the default behaviour is not return token usage counts. See: [OpenAI API reference for possible args](https://platform.openai.com/docs/api-reference/chat/create).
-        However `extra_create_args={"stream_options": {"include_usage": True}}` will (if supported by the accessed API)
-        return a final chunk with usage set to a RequestUsage object having prompt and completion token counts,
-        all preceding chunks will have usage as None. See: [stream_options](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options).
-
-        Other examples of OPENAI supported arguments that can be included in `extra_create_args`:
-            - `temperature` (float): Controls the randomness of the output. Higher values (e.g., 0.8) make the output more random, while lower values (e.g., 0.2) make it more focused and deterministic.
-            - `max_tokens` (int): The maximum number of tokens to generate in the completion.
-            - `top_p` (float): An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass.
-            - `frequency_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on their existing frequency in the text so far, decreasing the likelihood of repeated phrases.
-            - `presence_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on whether they appear in the text so far, encouraging the model to talk about new topics.
-        """
-        # Make sure all extra_create_args are valid
-        # TODO: kwarg checking logic
-        # extra_create_args_keys = set(extra_create_args.keys())
-        # if not create_kwargs.issuperset(extra_create_args_keys):
-        #     raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")
-
-        # Copy the create args and overwrite anything in extra_create_args
-        create_args = self._create_args.copy()
-        create_args.update(extra_create_args)
-
-        response_format_value: Optional[Mapping[str, Any]] = None
-
-        if "response_format" in create_args:
-            value = create_args["response_format"]
-            # If value is a Pydantic model class, use the beta client
-            if isinstance(value, type) and issubclass(value, BaseModel):
-                response_format_value = value.model_json_schema()
-            else:
-                # response_format_value is not a Pydantic model class
-                response_format_value = None
-
-        # Remove 'response_format' from create_args to prevent passing it twice
-        create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}
-
-        # TODO: allow custom handling.
-        # For now we raise an error if images are present and vision is not supported
-        if self.model_info["vision"] is False:
-            for message in messages:
-                if isinstance(message, UserMessage):
-                    if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
-                        raise ValueError("Model does not support vision and image was provided")
-
-        if json_output is not None:
-            if self.model_info["json_output"] is False and json_output is True:
-                raise ValueError("Model does not support JSON output.")
-
-            if json_output is True:
-                create_args["response_format"] = {"type": "json_object"}
-            else:
-                create_args["response_format"] = {"type": "text"}
-
-        if self.model_info["json_output"] is False and json_output is True:
-            raise ValueError("Model does not support JSON output.")
-
-        ollama_messages_nested = [to_ollama_type(m) for m in messages]
-        ollama_messages = [item for sublist in ollama_messages_nested for item in sublist]
-
-        if self.model_info["function_calling"] is False and len(tools) > 0:
-            raise ValueError("Model does not support function calling")
-
-        if len(tools) > 0:
-            converted_tools = convert_tools(tools)
-            stream_future = asyncio.ensure_future(
-                self._client.chat(  # type: ignore
-                    # model=self._model_name,
-                    messages=ollama_messages,
-                    tools=converted_tools,
-                    stream=True,
-                    format=response_format_value,
-                    **create_args_no_response_format,
-                )
-            )
-        else:
-            stream_future = asyncio.ensure_future(
-                self._client.chat(  # type: ignore
-                    # model=self._model_name,
-                    messages=ollama_messages,
-                    stream=True,
-                    format=response_format_value,
-                    **create_args_no_response_format,
-                )
+        create_params = self._process_create_args(
+            messages,
+            tools,
+            json_output,
+            extra_create_args,
+        )
+        stream_future = asyncio.ensure_future(
+            self._client.chat(  # type: ignore
+                # model=self._model_name,
+                messages=create_params.messages,
+                tools=create_params.tools if len(create_params.tools) > 0 else None,
+                stream=True,
+                format=create_params.format,
+                **create_params.create_args,
+            )
+        )
         if cancellation_token is not None:
             cancellation_token.link_future(stream_future)
         stream = await stream_future
@@ -689,7 +652,7 @@ class BaseOllamaChatCompletionClient(ChatCompletionClient):
             # Emit the start event.
             logger.info(
                 LLMStreamStartEvent(
-                    messages=[m.model_dump() for m in ollama_messages],
+                    messages=[m.model_dump() for m in create_params.messages],
                 )
             )
             # set the stop_reason for the usage chunk to the prior stop_reason

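Net effect of the refactor above: _process_create_args resolves json_output into the single format= argument of Ollama's chat call ("json" for JSON mode, a JSON schema for a model class, None for text), and both create and create_stream consume the same CreateParams. A hedged usage sketch (assumes a local Ollama server with llama3.2 pulled):

    import asyncio

    from autogen_core.models import UserMessage
    from autogen_ext.models.ollama import OllamaChatCompletionClient
    from pydantic import BaseModel


    class CityFact(BaseModel):
        city: str
        fact: str


    async def main() -> None:
        client = OllamaChatCompletionClient(model="llama3.2")
        # Resolved internally to format=CityFact.model_json_schema() on the chat call.
        result = await client.create(
            [UserMessage(content="Share one fact about Paris.", source="user")],
            json_output=CityFact,
        )
        print(result.content)


    asyncio.run(main())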
@@ -25,144 +25,168 @@ _MODEL_INFO: Dict[str, ModelInfo] = {
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.O3,
+        "structured_output": True,
     },
     "o1-2024-12-17": {
         "vision": False,
         "function_calling": False,
         "json_output": False,
         "family": ModelFamily.O1,
+        "structured_output": True,
     },
     "o1-preview-2024-09-12": {
         "vision": False,
         "function_calling": False,
         "json_output": False,
         "family": ModelFamily.O1,
+        "structured_output": True,
     },
     "o1-mini-2024-09-12": {
         "vision": False,
         "function_calling": False,
         "json_output": False,
         "family": ModelFamily.O1,
+        "structured_output": False,
     },
     "gpt-4o-2024-11-20": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_4O,
+        "structured_output": True,
     },
     "gpt-4o-2024-08-06": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_4O,
+        "structured_output": True,
     },
     "gpt-4o-2024-05-13": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_4O,
+        "structured_output": False,
     },
     "gpt-4o-mini-2024-07-18": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_4O,
+        "structured_output": True,
     },
     "gpt-4-turbo-2024-04-09": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_4,
+        "structured_output": False,
     },
     "gpt-4-0125-preview": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_4,
+        "structured_output": False,
     },
     "gpt-4-1106-preview": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_4,
+        "structured_output": False,
     },
     "gpt-4-1106-vision-preview": {
         "vision": True,
         "function_calling": False,
         "json_output": False,
         "family": ModelFamily.GPT_4,
+        "structured_output": False,
     },
     "gpt-4-0613": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_4,
+        "structured_output": False,
     },
     "gpt-4-32k-0613": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_4,
+        "structured_output": False,
     },
     "gpt-3.5-turbo-0125": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_35,
+        "structured_output": False,
     },
     "gpt-3.5-turbo-1106": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_35,
+        "structured_output": False,
     },
     "gpt-3.5-turbo-instruct": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_35,
+        "structured_output": False,
     },
     "gpt-3.5-turbo-0613": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_35,
+        "structured_output": False,
     },
     "gpt-3.5-turbo-16k-0613": {
         "vision": False,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GPT_35,
+        "structured_output": False,
     },
     "gemini-1.5-flash": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GEMINI_1_5_FLASH,
+        "structured_output": True,
     },
     "gemini-1.5-flash-8b": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GEMINI_1_5_FLASH,
+        "structured_output": True,
     },
     "gemini-1.5-pro": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GEMINI_1_5_PRO,
+        "structured_output": True,
     },
     "gemini-2.0-flash": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GEMINI_2_0_FLASH,
+        "structured_output": True,
     },
     "gemini-2.0-flash-lite-preview-02-05": {
         "vision": True,
         "function_calling": True,
         "json_output": True,
         "family": ModelFamily.GEMINI_2_0_FLASH,
+        "structured_output": True,
     },
 }

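The OpenAI client below accepts both spellings but funnels them into one CreateParams; the legacy response_format spelling keeps working and now raises the DeprecationWarning quoted in the diff. A hedged sketch of the two spellings side by side (model name and prompt illustrative):

    import asyncio

    from autogen_core.models import UserMessage
    from autogen_ext.models.openai import OpenAIChatCompletionClient
    from pydantic import BaseModel


    class AgentResponse(BaseModel):
        answer: str


    async def main() -> None:
        client = OpenAIChatCompletionClient(model="gpt-4o")
        messages = [UserMessage(content="Capital of France?", source="user")]

        # New spelling: structured output via json_output.
        await client.create(messages, json_output=AgentResponse)

        # Legacy spelling: still accepted, but warns
        # "Using response_format to specify structured output type will be deprecated. Use json_output instead."
        await client.create(messages, extra_create_args={"response_format": AgentResponse})


    asyncio.run(main())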
@@ -7,6 +7,7 @@ import os
|
||||
import re
|
||||
import warnings
|
||||
from asyncio import Task
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
@@ -69,7 +70,12 @@ from openai.types.chat import (
|
||||
)
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from openai.types.shared_params import FunctionDefinition, FunctionParameters
|
||||
from openai.types.shared_params import (
|
||||
FunctionDefinition,
|
||||
FunctionParameters,
|
||||
ResponseFormatJSONObject,
|
||||
ResponseFormatText,
|
||||
)
|
||||
from pydantic import BaseModel, SecretStr
|
||||
from typing_extensions import Self, Unpack
|
||||
|
||||
@@ -348,6 +354,14 @@ def assert_valid_name(name: str) -> str:
|
||||
return name
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateParams:
|
||||
messages: List[ChatCompletionMessageParam]
|
||||
tools: List[ChatCompletionToolParam]
|
||||
response_format: Optional[Type[BaseModel]]
|
||||
create_args: Dict[str, Any]
|
||||
|
||||
|
||||
class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -400,15 +414,13 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
def create_from_config(cls, config: Dict[str, Any]) -> ChatCompletionClient:
|
||||
return OpenAIChatCompletionClient(**config)
|
||||
|
||||
async def create(
|
||||
def _process_create_args(
|
||||
self,
|
||||
messages: Sequence[LLMMessage],
|
||||
*,
|
||||
tools: Sequence[Tool | ToolSchema] = [],
|
||||
json_output: Optional[bool] = None,
|
||||
extra_create_args: Mapping[str, Any] = {},
|
||||
cancellation_token: Optional[CancellationToken] = None,
|
||||
) -> CreateResult:
|
||||
tools: Sequence[Tool | ToolSchema],
|
||||
json_output: Optional[bool | type[BaseModel]],
|
||||
extra_create_args: Mapping[str, Any],
|
||||
) -> CreateParams:
|
||||
# Make sure all extra_create_args are valid
|
||||
extra_create_args_keys = set(extra_create_args.keys())
|
||||
if not create_kwargs.issuperset(extra_create_args_keys):
|
||||
@@ -418,23 +430,56 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
|
||||
create_args = self._create_args.copy()
|
||||
create_args.update(extra_create_args)
|
||||
|
||||
# Declare use_beta_client
|
||||
use_beta_client: bool = False
|
||||
# The response format value to use for the beta client.
|
||||
response_format_value: Optional[Type[BaseModel]] = None
|
||||
|
||||
if "response_format" in create_args:
|
||||
# Legacy support for getting beta client mode from response_format.
|
||||
value = create_args["response_format"]
|
||||
# If value is a Pydantic model class, use the beta client
|
||||
            if isinstance(value, type) and issubclass(value, BaseModel):
                if self.model_info["structured_output"] is False:
                    raise ValueError("Model does not support structured output.")
                warnings.warn(
                    "Using response_format to specify structured output type will be deprecated. "
                    "Use json_output instead.",
                    DeprecationWarning,
                    stacklevel=2,
                )
                response_format_value = value
                use_beta_client = True
            else:
                # response_format_value is not a Pydantic model class
                use_beta_client = False
                response_format_value = None
            # Remove response_format from create_args to prevent passing it twice.
            del create_args["response_format"]

        # Remove 'response_format' from create_args to prevent passing it twice
        create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}
        if json_output is not None:
            if self.model_info["json_output"] is False and json_output is True:
                raise ValueError("Model does not support JSON output.")
            if json_output is True:
                # JSON mode.
                create_args["response_format"] = ResponseFormatJSONObject(type="json_object")
            elif json_output is False:
                # Text mode.
                create_args["response_format"] = ResponseFormatText(type="text")
            elif isinstance(json_output, type) and issubclass(json_output, BaseModel):
                if self.model_info["structured_output"] is False:
                    raise ValueError("Model does not support structured output.")
                if response_format_value is not None:
                    raise ValueError(
                        "response_format and json_output cannot be set to a Pydantic model class at the same time."
                    )
                # Beta client mode with Pydantic model class.
                response_format_value = json_output
            else:
                raise ValueError(f"json_output must be a boolean or a Pydantic model class, got {type(json_output)}")

        if response_format_value is not None and "response_format" in create_args:
            warnings.warn(
                "response_format is found in extra_create_args while json_output is set to a Pydantic model class. "
                "Skipping the response_format in extra_create_args in favor of the json_output. "
                "Structured output will be used.",
                UserWarning,
                stacklevel=2,
            )
            # If using beta client, remove response_format from create_args to prevent passing it twice
            del create_args["response_format"]

        # TODO: allow custom handling.
        # For now we raise an error if images are present and vision is not supported
@@ -444,15 +489,6 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                    if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
                        raise ValueError("Model does not support vision and image was provided")

        if json_output is not None:
            if self.model_info["json_output"] is False and json_output is True:
                raise ValueError("Model does not support JSON output.")

            if json_output is True:
                create_args["response_format"] = {"type": "json_object"}
            else:
                create_args["response_format"] = {"type": "text"}

        if self.model_info["json_output"] is False and json_output is True:
            raise ValueError("Model does not support JSON output.")

@@ -461,67 +497,57 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):

        if self.model_info["function_calling"] is False and len(tools) > 0:
            raise ValueError("Model does not support function calling")

        converted_tools = convert_tools(tools)

        return CreateParams(
            messages=oai_messages,
            tools=converted_tools,
            response_format=response_format_value,
            create_args=create_args,
        )
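Taken together, the branches above give json_output three modes: True selects JSON mode, False selects plain text, and a Pydantic model class selects structured output routed through the beta parse endpoint. A minimal caller-side sketch (the Answer model and the already-constructed client are illustrative, not part of this diff):

from pydantic import BaseModel
from autogen_core.models import UserMessage

class Answer(BaseModel):
    answer: str

# json_output=True   -> response_format {"type": "json_object"} (JSON mode)
# json_output=False  -> response_format {"type": "text"} (text mode)
# json_output=Answer -> structured output via beta chat.completions.parse
result = await client.create(
    [UserMessage(content="What is 1 + 1?", source="user")],
    json_output=Answer,
)
answer = Answer.model_validate_json(result.content)  # content is a JSON string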

    async def create(
        self,
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
        create_params = self._process_create_args(
            messages,
            tools,
            json_output,
            extra_create_args,
        )
        future: Union[Task[ParsedChatCompletion[BaseModel]], Task[ChatCompletion]]
        if len(tools) > 0:
            converted_tools = convert_tools(tools)
            if use_beta_client:
                # Pass response_format_value if it's not None
                if response_format_value is not None:
                    future = asyncio.ensure_future(
                        self._client.beta.chat.completions.parse(
                            messages=oai_messages,
                            tools=converted_tools,
                            response_format=response_format_value,
                            **create_args_no_response_format,
                        )
                    )
                else:
                    future = asyncio.ensure_future(
                        self._client.beta.chat.completions.parse(
                            messages=oai_messages,
                            tools=converted_tools,
                            **create_args_no_response_format,
                        )
                    )
            else:
                future = asyncio.ensure_future(
                    self._client.chat.completions.create(
                        messages=oai_messages,
                        stream=False,
                        tools=converted_tools,
                        **create_args,
                    )
        if create_params.response_format is not None:
            # Use beta client if response_format is not None
            future = asyncio.ensure_future(
                self._client.beta.chat.completions.parse(
                    messages=create_params.messages,
                    tools=create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN,
                    response_format=create_params.response_format,
                    **create_params.create_args,
                )
            )
        else:
            if use_beta_client:
                if response_format_value is not None:
                    future = asyncio.ensure_future(
                        self._client.beta.chat.completions.parse(
                            messages=oai_messages,
                            response_format=response_format_value,
                            **create_args_no_response_format,
                        )
                    )
                else:
                    future = asyncio.ensure_future(
                        self._client.beta.chat.completions.parse(
                            messages=oai_messages,
                            **create_args_no_response_format,
                        )
                    )
            else:
                future = asyncio.ensure_future(
                    self._client.chat.completions.create(
                        messages=oai_messages,
                        stream=False,
                        **create_args,
                    )
            # Use the regular client
            future = asyncio.ensure_future(
                self._client.chat.completions.create(
                    messages=create_params.messages,
                    stream=False,
                    tools=create_params.tools if len(create_params.tools) > 0 else NOT_GIVEN,
                    **create_params.create_args,
                )
            )

        if cancellation_token is not None:
            cancellation_token.link_future(future)
        result: Union[ParsedChatCompletion[BaseModel], ChatCompletion] = await future
        if use_beta_client:
        if create_params.response_format is not None:
            result = cast(ParsedChatCompletion[Any], result)

        usage = RequestUsage(
@@ -532,7 +558,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):

        logger.info(
            LLMCallEvent(
                messages=cast(List[Dict[str, Any]], oai_messages),
                messages=cast(List[Dict[str, Any]], create_params.messages),
                response=result.model_dump(),
                prompt_tokens=usage.prompt_tokens,
                completion_tokens=usage.completion_tokens,
@@ -627,7 +653,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
        max_consecutive_empty_chunk_tolerance: int = 0,
@@ -638,7 +664,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
        Args:
            messages (Sequence[LLMMessage]): A sequence of messages to be processed.
            tools (Sequence[Tool | ToolSchema], optional): A sequence of tools to be used in the completion. Defaults to `[]`.
            json_output (Optional[bool], optional): If True, the output will be in JSON format. Defaults to None.
            json_output (Optional[bool | type[BaseModel]], optional): If True, the output will be in JSON format. If a Pydantic model class, the output will be in that format. Defaults to None.
            extra_create_args (Mapping[str, Any], optional): Additional arguments for the creation process. Defaults to `{}`.
            cancellation_token (Optional[CancellationToken], optional): A token to cancel the operation. Defaults to None.
            max_consecutive_empty_chunk_tolerance (int): [Deprecated] This parameter is deprecated, empty chunks will be skipped.
@@ -663,55 +689,12 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
            - `frequency_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on their existing frequency in the text so far, decreasing the likelihood of repeated phrases.
            - `presence_penalty` (float): A value between -2.0 and 2.0 that penalizes new tokens based on whether they appear in the text so far, encouraging the model to talk about new topics.
        """
        # Make sure all extra_create_args are valid
        extra_create_args_keys = set(extra_create_args.keys())
        if not create_kwargs.issuperset(extra_create_args_keys):
            raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}")

        # Copy the create args and overwrite anything in extra_create_args
        create_args = self._create_args.copy()
        create_args.update(extra_create_args)

        # Declare use_beta_client
        use_beta_client: bool = False
        response_format_value: Optional[Type[BaseModel]] = None

        if "response_format" in create_args:
            value = create_args["response_format"]
            # If value is a Pydantic model class, use the beta client
            if isinstance(value, type) and issubclass(value, BaseModel):
                response_format_value = value
                use_beta_client = True
            else:
                # response_format_value is not a Pydantic model class
                use_beta_client = False
                response_format_value = None

        # Remove 'response_format' from create_args to prevent passing it twice
        create_args_no_response_format = {k: v for k, v in create_args.items() if k != "response_format"}

        # TODO: allow custom handling.
        # For now we raise an error if images are present and vision is not supported
        if self.model_info["vision"] is False:
            for message in messages:
                if isinstance(message, UserMessage):
                    if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content):
                        raise ValueError("Model does not support vision and image was provided")

        if json_output is not None:
            if self.model_info["json_output"] is False and json_output is True:
                raise ValueError("Model does not support JSON output")

            if json_output is True:
                create_args["response_format"] = {"type": "json_object"}
            else:
                create_args["response_format"] = {"type": "text"}

        oai_messages_nested = [to_oai_type(m, prepend_name=self._add_name_prefixes) for m in messages]
        oai_messages = [item for sublist in oai_messages_nested for item in sublist]

        if self.model_info["function_calling"] is False and len(tools) > 0:
            raise ValueError("Model does not support function calling")
        create_params = self._process_create_args(
            messages,
            tools,
            json_output,
            extra_create_args,
        )

        if max_consecutive_empty_chunk_tolerance != 0:
            warnings.warn(
@@ -720,22 +703,19 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
                stacklevel=2,
            )

        tool_params = convert_tools(tools)

        # Get the async generator of chunks.
        if use_beta_client:
        if create_params.response_format is not None:
            chunks = self._create_stream_chunks_beta_client(
                tool_params=tool_params,
                oai_messages=oai_messages,
                response_format=response_format_value,
                create_args_no_response_format=create_args_no_response_format,
                tool_params=create_params.tools,
                oai_messages=create_params.messages,
                response_format=create_params.response_format,
                create_args_no_response_format=create_params.create_args,
                cancellation_token=cancellation_token,
            )
        else:
            chunks = self._create_stream_chunks(
                tool_params=tool_params,
                oai_messages=oai_messages,
                create_args=create_args,
                tool_params=create_params.tools,
                oai_messages=create_params.messages,
                create_args=create_params.create_args,
                cancellation_token=cancellation_token,
            )

@@ -762,7 +742,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
        # Emit the start event.
        logger.info(
            LLMStreamStartEvent(
                messages=cast(List[Dict[str, Any]], oai_messages),
                messages=cast(List[Dict[str, Any]], create_params.messages),
            )
        )
        # Empty chunks have been observed when the endpoint is under heavy load.
@@ -839,7 +819,7 @@ class BaseOpenAIChatCompletionClient(ChatCompletionClient):
            raise ValueError("Function calls are not supported in this context")

        # We need to get the model from the last chunk, if available.
        model = maybe_model or create_args["model"]
        model = maybe_model or create_params.create_args["model"]
        model = model.replace("gpt-35", "gpt-3.5")  # hack for Azure API

        # Because the usage chunk is not guaranteed to be the last chunk, we need to check if it is available.
@@ -1151,6 +1131,7 @@ class OpenAIChatCompletionClient(BaseOpenAIChatCompletionClient, Component[OpenA
                "function_calling": False,
                "json_output": False,
                "family": ModelFamily.R1,
                "structured_output": True,
            },
        )


@@ -139,7 +139,11 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
            validate_model_info(self._model_info)
        else:
            self._model_info = ModelInfo(
                vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
                vision=False,
                function_calling=False,
                json_output=False,
                family=ModelFamily.UNKNOWN,
                structured_output=False,
            )
        self._total_available_tokens = 10000
        self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
@@ -158,7 +162,7 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
@@ -197,7 +201,7 @@ class ReplayChatCompletionClient(ChatCompletionClient, Component[ReplayChatCompl
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:

@@ -135,6 +135,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
                "json_output": True,
                "vision": True,
                "family": ModelFamily.CLAUDE_3_5_SONNET,
                "structured_output": True,
            },
        )

@@ -201,6 +202,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
                "function_calling": True,
                "json_output": True,
                "vision": True,
                "structured_output": True,
            },
        )

@@ -283,7 +285,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
        self._prompt_settings = prompt_settings
        self._sk_client = sk_client
        self._model_info = model_info or ModelInfo(
            vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
            vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN, structured_output=False
        )
        validate_model_info(self._model_info)
        self._total_prompt_tokens = 0
@@ -437,7 +439,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> CreateResult:
@@ -465,6 +467,9 @@ class SKChatCompletionAdapter(ChatCompletionClient):
        Returns:
            CreateResult: The result of the chat completion.
        """
        if isinstance(json_output, type) and issubclass(json_output, BaseModel):
            raise ValueError("structured output is not currently supported in SKChatCompletionAdapter")

        kernel = self._get_kernel(extra_create_args)

        chat_history = self._convert_to_chat_history(messages)
@@ -545,7 +550,7 @@ class SKChatCompletionAdapter(ChatCompletionClient):
        messages: Sequence[LLMMessage],
        *,
        tools: Sequence[Tool | ToolSchema] = [],
        json_output: Optional[bool] = None,
        json_output: Optional[bool | type[BaseModel]] = None,
        extra_create_args: Mapping[str, Any] = {},
        cancellation_token: Optional[CancellationToken] = None,
    ) -> AsyncGenerator[Union[str, CreateResult], None]:
@@ -574,6 +579,9 @@ class SKChatCompletionAdapter(ChatCompletionClient):
            Union[str, CreateResult]: Either a string chunk of the response or a CreateResult containing function calls.
        """

        if isinstance(json_output, type) and issubclass(json_output, BaseModel):
            raise ValueError("structured output is not currently supported in SKChatCompletionAdapter")

        kernel = self._get_kernel(extra_create_args)
        chat_history = self._convert_to_chat_history(messages)
        user_settings = self._get_prompt_settings(extra_create_args)

@@ -88,6 +88,7 @@ def azure_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient
            "function_calling": False,
            "vision": False,
            "family": "unknown",
            "structured_output": False,
        },
        model="model",
    )
@@ -101,6 +102,7 @@ def azure_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompletionClient
            "function_calling": False,
            "vision": False,
            "family": "unknown",
            "structured_output": False,
        },
        model="model",
    )
@@ -117,6 +119,7 @@ async def test_azure_ai_chat_completion_client_validation() -> None:
            "function_calling": False,
            "vision": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

@@ -129,6 +132,7 @@ async def test_azure_ai_chat_completion_client_validation() -> None:
            "function_calling": False,
            "vision": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

@@ -141,6 +145,7 @@ async def test_azure_ai_chat_completion_client_validation() -> None:
            "function_calling": False,
            "vision": False,
            "family": "unknown",
            "structured_output": False,
        },
    )

@@ -267,6 +272,7 @@ def function_calling_client(monkeypatch: pytest.MonkeyPatch) -> AzureAIChatCompl
            "function_calling": True,
            "vision": False,
            "family": "function_calling_model",
            "structured_output": False,
        },
        model="model",
    )
@@ -354,6 +360,7 @@ async def test_multimodal_supported(monkeypatch: pytest.MonkeyPatch) -> None:
            "function_calling": False,
            "vision": True,
            "family": "vision_model",
            "structured_output": False,
        },
        model="model",
    )
@@ -436,6 +443,7 @@ async def test_r1_content(monkeypatch: pytest.MonkeyPatch) -> None:
            "function_calling": False,
            "vision": True,
            "family": ModelFamily.R1,
            "structured_output": False,
        },
        model="model",
    )
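Every model_info fixture above now carries the new structured_output capability flag, mirroring the ModelInfo change in the replay client. A minimal sketch of declaring it (whether validate_model_info treats a missing flag as an error is an assumption here):

from autogen_core.models import ModelFamily, ModelInfo, validate_model_info

model_info = ModelInfo(
    vision=False,
    function_calling=False,
    json_output=False,
    family=ModelFamily.UNKNOWN,
    structured_output=False,  # new capability flag added by this change
)
validate_model_info(model_info)  # assumed to flag configurations missing the field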

@@ -11,10 +11,12 @@ from autogen_core.models import (
)
from autogen_ext.models.cache import ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel


def get_test_data() -> Tuple[list[str], list[str], SystemMessage, ChatCompletionClient, ChatCompletionCache]:
    num_messages = 3
def get_test_data(
    num_messages: int = 3,
) -> Tuple[list[str], list[str], SystemMessage, ChatCompletionClient, ChatCompletionCache]:
    responses = [f"This is dummy message number {i}" for i in range(num_messages)]
    prompts = [f"This is dummy prompt number {i}" for i in range(num_messages)]
    system_prompt = SystemMessage(content="This is a system prompt")
@@ -53,6 +55,54 @@ async def test_cache_basic_with_args() -> None:
    assert response2.content == responses[2]


@pytest.mark.asyncio
async def test_cache_structured_output_with_args() -> None:
    responses, prompts, system_prompt, _, cached_client = get_test_data(num_messages=4)

    class Answer(BaseModel):
        thought: str
        answer: str

    class Answer2(BaseModel):
        calculation: str
        answer: str

    response0 = await cached_client.create(
        [system_prompt, UserMessage(content=prompts[0], source="user")], json_output=Answer
    )
    assert isinstance(response0, CreateResult)
    assert not response0.cached
    assert response0.content == responses[0]

    response1 = await cached_client.create(
        [system_prompt, UserMessage(content=prompts[1], source="user")], json_output=Answer
    )
    assert not response1.cached
    assert response1.content == responses[1]

    # Cached output.
    response0_cached = await cached_client.create(
        [system_prompt, UserMessage(content=prompts[0], source="user")], json_output=Answer
    )
    assert isinstance(response0_cached, CreateResult)
    assert response0_cached.cached
    assert response0_cached.content == responses[0]

    # Without the json_output argument, the cache should not be hit.
    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
    assert isinstance(response0, CreateResult)
    assert not response0.cached
    assert response0.content == responses[2]

    # With a different output type, the cache should not be hit.
    response0 = await cached_client.create(
        [system_prompt, UserMessage(content=prompts[1], source="user")], json_output=Answer2
    )
    assert isinstance(response0, CreateResult)
    assert not response0.cached
    assert response0.content == responses[3]
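The test above implies that the cache key incorporates json_output: the same prompt misses the cache when the schema differs or is absent. A sketch of the wiring used by get_test_data (constructor signatures per the imports above; the in-memory default store is an assumption):

from autogen_ext.models.cache import ChatCompletionCache
from autogen_ext.models.replay import ReplayChatCompletionClient

replay_client = ReplayChatCompletionClient(
    ["This is dummy message number 0", "This is dummy message number 1"]
)
cached_client = ChatCompletionCache(replay_client)  # wraps any ChatCompletionClient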


@pytest.mark.asyncio
async def test_cache_model_and_count_api() -> None:
    _, prompts, system_prompt, replay_client, cached_client = get_test_data()

@@ -9,6 +9,8 @@ import torch
# from autogen_agentchat.messages import TextMessage
# from autogen_core import CancellationToken
from autogen_core.models import RequestUsage, SystemMessage, UserMessage
from llama_cpp import ChatCompletionRequestResponseFormat
from pydantic import BaseModel

# from autogen_core.tools import FunctionTool
try:
@@ -21,6 +23,13 @@ except ImportError:
    pytest.skip("Skipping LlamaCppChatCompletionClient tests: llama-cpp-python not installed", allow_module_level=True)


class AgentResponse(BaseModel):
    """A response from the agent."""

    thoughts: str
    content: str


# Fake Llama class to simulate responses
class FakeLlama:
    def __init__(
@@ -30,16 +39,29 @@ class FakeLlama:
    ) -> None:
        self.model_path = model_path
        self.n_ctx = lambda: 1024
        self._structured_response = AgentResponse(thoughts="Test thoughts", content="Test content")

    # Added tokenize method for testing purposes.
    def tokenize(self, b: bytes) -> list[int]:
        return list(b)

    def create_chat_completion(
        self, messages: Any, tools: List[ChatCompletionMessageToolCalls] | None, stream: bool = False
        self,
        messages: Any,
        tools: List[ChatCompletionMessageToolCalls] | None,
        stream: bool = False,
        response_format: ChatCompletionRequestResponseFormat | None = None,
    ) -> dict[str, Any]:
        # Return fake non-streaming response.

        if response_format is not None:
            assert self._structured_response is not None
            # If response_format is provided, return a different format.
            return {
                "usage": {"prompt_tokens": 1, "completion_tokens": 2},
                "choices": [{"message": {"content": self._structured_response.model_dump_json()}}],
            }

        return {
            "usage": {"prompt_tokens": 1, "completion_tokens": 2},
            "choices": [{"message": {"content": "Fake response"}}],
@@ -81,6 +103,22 @@ async def test_llama_cpp_create(get_completion_client: "ContextManager[type[Llam
    assert result.finish_reason in ("stop", "unknown")


@pytest.mark.asyncio
async def test_llama_cpp_create_structured_output(
    get_completion_client: "ContextManager[type[LlamaCppChatCompletionClient]]",
) -> None:
    with get_completion_client as Client:
        client = Client(model_path="dummy")
        messages: Sequence[Union[SystemMessage, UserMessage]] = [
            SystemMessage(content="Test system"),
            UserMessage(content="Test user", source="user"),
        ]
        result = await client.create(messages=messages, json_output=AgentResponse)
        assert isinstance(result.content, str)
        assert AgentResponse.model_validate_json(result.content).thoughts == "Test thoughts"
        assert AgentResponse.model_validate_json(result.content).content == "Test content"
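Judging from the fake above, the client translates a Pydantic json_output into llama-cpp-python's response_format argument; a rough sketch against FakeLlama (the exact schema wiring in the real client is an assumption):

llama = FakeLlama(model_path="dummy")
response_format: ChatCompletionRequestResponseFormat = {
    "type": "json_object",
    "schema": AgentResponse.model_json_schema(),  # constrain decoding to the schema
}
out = llama.create_chat_completion(
    messages=[{"role": "user", "content": "Say something"}],
    tools=None,
    response_format=response_format,
)
assert AgentResponse.model_validate_json(out["choices"][0]["message"]["content"])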


# Commented out due to raising NotImplementedError; will leave in case streaming is supported in the future.
# @pytest.mark.asyncio
# async def test_llama_cpp_create_stream(
@@ -133,7 +171,12 @@ async def test_llama_cpp_integration_non_streaming() -> None:
    from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient

    client = LlamaCppChatCompletionClient(
        repo_id="unsloth/phi-4-GGUF", filename="phi-4-Q2_K_L.gguf", n_gpu_layers=-1, seed=1337, n_ctx=5000
        repo_id="unsloth/phi-4-GGUF",
        filename="phi-4-Q2_K_L.gguf",
        n_gpu_layers=-1,
        seed=1337,
        n_ctx=5000,
        verbose=False,
    )
    messages: Sequence[Union[SystemMessage, UserMessage]] = [
        SystemMessage(content="You are a helpful assistant."),
@@ -143,6 +186,30 @@ async def test_llama_cpp_integration_non_streaming() -> None:
    assert isinstance(result.content, str) and len(result.content.strip()) > 0


@pytest.mark.asyncio
async def test_llama_cpp_integration_non_streaming_structured_output() -> None:
    if not ((hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) or torch.cuda.is_available()):
        pytest.skip("Skipping LlamaCpp integration tests: GPU not available")

    from autogen_ext.models.llama_cpp._llama_cpp_completion_client import LlamaCppChatCompletionClient

    client = LlamaCppChatCompletionClient(
        repo_id="unsloth/phi-4-GGUF",
        filename="phi-4-Q2_K_L.gguf",
        n_gpu_layers=-1,
        seed=1337,
        n_ctx=5000,
        verbose=False,
    )
    messages: Sequence[Union[SystemMessage, UserMessage]] = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="Hello, how are you?", source="user"),
    ]
    result = await client.create(messages=messages, json_output=AgentResponse)
    assert isinstance(result.content, str) and len(result.content.strip()) > 0
    assert AgentResponse.model_validate_json(result.content)


# Commented out due to raising NotImplementedError; will leave in case streaming is supported in the future.
# @pytest.mark.asyncio
# async def test_llama_cpp_integration_streaming() -> None:

@@ -1,11 +1,24 @@
from typing import Any, Mapping
import json
import logging
from typing import Any, AsyncGenerator, List, Mapping

import httpx
import pytest
from autogen_core.models._types import UserMessage
import pytest_asyncio
from autogen_core import FunctionCall
from autogen_core.models import (
    AssistantMessage,
    CreateResult,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    UserMessage,
)
from autogen_core.tools import FunctionTool
from autogen_ext.models.ollama import OllamaChatCompletionClient
from autogen_ext.models.ollama._ollama_client import OLLAMA_VALID_CREATE_KWARGS_KEYS
from httpx import Response
from ollama import AsyncClient
from ollama import AsyncClient, ChatResponse, Message
from pydantic import BaseModel


def _mock_request(*args: Any, **kwargs: Any) -> Response:
@@ -47,3 +60,544 @@ def test_create_args_from_config_drops_unexpected_kwargs() -> None:

    for arg in final_create_args.keys():
        assert arg in OLLAMA_VALID_CREATE_KWARGS_KEYS


@pytest.mark.asyncio
async def test_create(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
    model = "llama3.2"
    content_raw = "Hello world! This is a test response. Test response."

    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
        return ChatResponse(
            model=model,
            done=True,
            done_reason="stop",
            message=Message(
                role="assistant",
                content=content_raw,
            ),
            prompt_eval_count=10,
            eval_count=12,
        )

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
    with caplog.at_level(logging.INFO):
        client = OllamaChatCompletionClient(model=model)
        create_result = await client.create(
            messages=[
                UserMessage(content="hi", source="user"),
            ],
        )
        assert "LLMCall" in caplog.text and content_raw in caplog.text
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
    assert create_result.finish_reason == "stop"
    assert create_result.usage is not None
    assert create_result.usage.prompt_tokens == 10
    assert create_result.usage.completion_tokens == 12
    assert create_result.content == content_raw


@pytest.mark.asyncio
async def test_create_stream(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
    model = "llama3.2"
    content_raw = "Hello world! This is a test response. Test response."

    async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
        assert "stream" in kwargs
        assert kwargs["stream"] is True

        async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
            chunks = [content_raw[i : i + 5] for i in range(0, len(content_raw), 5)]
            # Simulate streaming by yielding chunks of the response
            for chunk in chunks[:-1]:
                yield ChatResponse(
                    model=model,
                    done=False,
                    message=Message(
                        role="assistant",
                        content=chunk,
                    ),
                )
            yield ChatResponse(
                model=model,
                done=True,
                done_reason="stop",
                message=Message(
                    role="assistant",
                    content=chunks[-1],
                ),
                prompt_eval_count=10,
                eval_count=12,
            )

        return _mock_stream()

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)
    client = OllamaChatCompletionClient(model=model)
    with caplog.at_level(logging.INFO):
        stream = client.create_stream(
            messages=[
                UserMessage(content="hi", source="user"),
            ],
        )
        chunks: List[str | CreateResult] = []
        async for chunk in stream:
            chunks.append(chunk)

        assert "LLMStreamStart" in caplog.text and "hi" in caplog.text
        assert "LLMStreamEnd" in caplog.text and content_raw in caplog.text
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    assert isinstance(chunks[-1].content, str)
    assert chunks[-1].content == content_raw
    assert chunks[-1].finish_reason == "stop"
    assert chunks[-1].usage is not None
    assert chunks[-1].usage.prompt_tokens == 10
    assert chunks[-1].usage.completion_tokens == 12


@pytest.mark.asyncio
async def test_create_tools(monkeypatch: pytest.MonkeyPatch) -> None:
    def add(x: int, y: int) -> str:
        return str(x + y)

    add_tool = FunctionTool(add, description="Add two numbers")

    model = "llama3.2"

    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
        return ChatResponse(
            model=model,
            done=True,
            done_reason="stop",
            message=Message(
                role="assistant",
                tool_calls=[
                    Message.ToolCall(
                        function=Message.ToolCall.Function(
                            name=add_tool.name,
                            arguments={"x": 2, "y": 2},
                        ),
                    ),
                ],
            ),
            prompt_eval_count=10,
            eval_count=12,
        )

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client = OllamaChatCompletionClient(model=model)
    create_result = await client.create(
        messages=[
            UserMessage(content="hi", source="user"),
        ],
        tools=[add_tool],
    )
    assert isinstance(create_result.content, list)
    assert len(create_result.content) > 0
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == add_tool.name
    assert create_result.content[0].arguments == json.dumps({"x": 2, "y": 2})
    assert create_result.finish_reason == "function_calls"
    assert create_result.usage is not None
    assert create_result.usage.prompt_tokens == 10
    assert create_result.usage.completion_tokens == 12


@pytest.mark.asyncio
async def test_create_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
    class ResponseType(BaseModel):
        response: str

    model = "llama3.2"

    async def _mock_chat(*args: Any, **kwargs: Any) -> ChatResponse:
        return ChatResponse(
            model=model,
            done=True,
            done_reason="stop",
            message=Message(
                role="assistant",
                content=json.dumps({"response": "Hello world!"}),
            ),
            prompt_eval_count=10,
            eval_count=12,
        )

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client = OllamaChatCompletionClient(model=model)
    create_result = await client.create(
        messages=[
            UserMessage(content="hi", source="user"),
        ],
        json_output=ResponseType,
    )
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
    assert create_result.finish_reason == "stop"
    assert create_result.usage is not None
    assert create_result.usage.prompt_tokens == 10
    assert create_result.usage.completion_tokens == 12
    assert ResponseType.model_validate_json(create_result.content)

    create_result = await client.create(
        messages=[
            UserMessage(content="hi", source="user"),
        ],
        extra_create_args={"format": ResponseType.model_json_schema()},
    )
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
    assert create_result.finish_reason == "stop"
    assert create_result.usage is not None
    assert create_result.usage.prompt_tokens == 10
    assert create_result.usage.completion_tokens == 12
    assert ResponseType.model_validate_json(create_result.content)

    # Test case when response_format is in extra_create_args.
    with pytest.warns(DeprecationWarning, match="Using response_format will be deprecated. Use json_output instead."):
        create_result = await client.create(
            messages=[
                UserMessage(content="hi", source="user"),
            ],
            extra_create_args={"response_format": ResponseType},
        )

    # Test case when response_format is in extra_create_args but is not a pydantic model.
    with pytest.raises(ValueError, match="response_format must be a Pydantic model class"):
        create_result = await client.create(
            messages=[
                UserMessage(content="hi", source="user"),
            ],
            extra_create_args={"response_format": "json"},
        )

    # Test case when response_format is in extra_create_args and json_output is also set.
    with pytest.raises(
        ValueError,
        match="response_format and json_output cannot be set to a Pydantic model class at the same time. Use json_output instead.",
    ):
        create_result = await client.create(
            messages=[
                UserMessage(content="hi", source="user"),
            ],
            extra_create_args={"response_format": ResponseType},
            json_output=ResponseType,
        )

    # Test case when format is in extra_create_args and json_output is also set.
    with pytest.raises(
        ValueError, match="json_output and format cannot be set at the same time. Use json_output instead."
    ):
        create_result = await client.create(
            messages=[
                UserMessage(content="hi", source="user"),
            ],
            extra_create_args={"format": ResponseType.model_json_schema()},
            json_output=ResponseType,
        )
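The cases above pin down the Ollama client's precedence rules; read off the asserts and error messages, they amount to:

# json_output=SomeModel                    -> schema-constrained output (Ollama "format")
# extra_create_args={"format": <schema>}   -> passed through when json_output is unset
# extra_create_args={"response_format": M} -> deprecated spelling; warns, use json_output
# json_output together with format/response_format -> ValueError
result = await client.create(
    messages=[UserMessage(content="hi", source="user")],
    json_output=ResponseType,
)
ResponseType.model_validate_json(result.content)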


@pytest.mark.asyncio
async def test_create_stream_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
    class ResponseType(BaseModel):
        response: str

    model = "llama3.2"
    content_raw = json.dumps({"response": "Hello world! This is a test response. Test response."})

    async def _mock_chat(*args: Any, **kwargs: Any) -> AsyncGenerator[ChatResponse, None]:
        assert "stream" in kwargs
        assert kwargs["stream"] is True

        async def _mock_stream() -> AsyncGenerator[ChatResponse, None]:
            chunks = [content_raw[i : i + 5] for i in range(0, len(content_raw), 5)]
            # Simulate streaming by yielding chunks of the response
            for chunk in chunks[:-1]:
                yield ChatResponse(
                    model=model,
                    done=False,
                    message=Message(
                        role="assistant",
                        content=chunk,
                    ),
                )
            yield ChatResponse(
                model=model,
                done=True,
                done_reason="stop",
                message=Message(
                    role="assistant",
                    content=chunks[-1],
                ),
                prompt_eval_count=10,
                eval_count=12,
            )

        return _mock_stream()

    monkeypatch.setattr(AsyncClient, "chat", _mock_chat)

    client = OllamaChatCompletionClient(model=model)
    stream = client.create_stream(
        messages=[
            UserMessage(content="hi", source="user"),
        ],
        json_output=ResponseType,
    )
    chunks: List[str | CreateResult] = []
    async for chunk in stream:
        chunks.append(chunk)
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    assert isinstance(chunks[-1].content, str)
    assert chunks[-1].content == content_raw
    assert chunks[-1].finish_reason == "stop"
    assert chunks[-1].usage is not None
    assert chunks[-1].usage.prompt_tokens == 10
    assert chunks[-1].usage.completion_tokens == 12
    assert ResponseType.model_validate_json(chunks[-1].content)


@pytest_asyncio.fixture  # type: ignore
async def ollama_client(request: pytest.FixtureRequest) -> OllamaChatCompletionClient:
    model = request.node.callspec.params["model"]  # type: ignore
    assert isinstance(model, str)
    # Check if the model is running locally.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"http://localhost:11434/v1/models/{model}")
            response.raise_for_status()
    except httpx.HTTPStatusError as e:
        pytest.skip(f"{model} model is not running locally: {e}")
    except httpx.ConnectError as e:
        pytest.skip(f"Ollama is not running locally: {e}")
    return OllamaChatCompletionClient(model=model)


@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b", "llama3.2:1b"])
async def test_ollama_create(model: str, ollama_client: OllamaChatCompletionClient) -> None:
    create_result = await ollama_client.create(
        messages=[
            UserMessage(
                content="Taking two balls from a bag of 10 green balls and 20 red balls, "
                "what is the probability of getting a green and a red ball?",
                source="user",
            ),
        ]
    )
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
    assert create_result.finish_reason == "stop"
    assert create_result.usage is not None

    chunks: List[str | CreateResult] = []
    async for chunk in ollama_client.create_stream(
        messages=[
            UserMessage(
                content="Taking two balls from a bag of 10 green balls and 20 red balls, "
                "what is the probability of getting a green and a red ball?",
                source="user",
            ),
        ]
    ):
        chunks.append(chunk)
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    assert chunks[-1].finish_reason == "stop"
    assert len(chunks[-1].content) > 0
    assert chunks[-1].usage is not None


@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["deepseek-r1:1.5b", "llama3.2:1b"])
async def test_ollama_create_structured_output(model: str, ollama_client: OllamaChatCompletionClient) -> None:
    class ResponseType(BaseModel):
        calculation: str
        result: str

    create_result = await ollama_client.create(
        messages=[
            UserMessage(
                content="Taking two balls from a bag of 10 green balls and 20 red balls, "
                "what is the probability of getting a green and a red ball?",
                source="user",
            ),
        ],
        json_output=ResponseType,
    )
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
    assert create_result.finish_reason == "stop"
    assert create_result.usage is not None
    assert ResponseType.model_validate_json(create_result.content)

    # Test streaming completion with structured output.
    chunks: List[str | CreateResult] = []
    async for chunk in ollama_client.create_stream(
        messages=[
            UserMessage(
                content="Taking two balls from a bag of 10 green balls and 20 red balls, "
                "what is the probability of getting a green and a red ball?",
                source="user",
            ),
        ],
        json_output=ResponseType,
    ):
        chunks.append(chunk)
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    assert chunks[-1].finish_reason == "stop"
    assert isinstance(chunks[-1].content, str)
    assert len(chunks[-1].content) > 0
    assert chunks[-1].usage is not None
    assert ResponseType.model_validate_json(chunks[-1].content)


@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["qwen2.5:0.5b", "llama3.2:1b"])
async def test_ollama_create_tools(model: str, ollama_client: OllamaChatCompletionClient) -> None:
    def add(x: int, y: int) -> str:
        return str(x + y)

    add_tool = FunctionTool(add, description="Add two numbers")

    create_result = await ollama_client.create(
        messages=[
            UserMessage(
                content="What is 2 + 2? Use the add tool.",
                source="user",
            ),
        ],
        tools=[add_tool],
    )
    assert isinstance(create_result.content, list)
    assert len(create_result.content) > 0
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == add_tool.name
    assert create_result.finish_reason == "function_calls"

    execution_result = FunctionExecutionResult(
        content="4",
        name=add_tool.name,
        call_id=create_result.content[0].id,
        is_error=False,
    )
    create_result = await ollama_client.create(
        messages=[
            UserMessage(
                content="What is 2 + 2? Use the add tool.",
                source="user",
            ),
            AssistantMessage(
                content=create_result.content,
                source="assistant",
            ),
            FunctionExecutionResultMessage(
                content=[execution_result],
            ),
        ],
    )
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
    assert create_result.finish_reason == "stop"


@pytest.mark.skip("TODO: Does Ollama support structured outputs with tools?")
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["llama3.2:1b"])
async def test_ollama_create_structured_output_with_tools(
    model: str, ollama_client: OllamaChatCompletionClient
) -> None:
    class ResponseType(BaseModel):
        calculation: str
        result: str

    def add(x: int, y: int) -> str:
        return str(x + y)

    add_tool = FunctionTool(add, description="Add two numbers")

    create_result = await ollama_client.create(
        messages=[
            UserMessage(
                content="What is 2 + 2? Use the add tool.",
                source="user",
            ),
        ],
        tools=[add_tool],
        json_output=ResponseType,
    )
    assert isinstance(create_result.content, list)
    assert len(create_result.content) > 0
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == add_tool.name
    assert create_result.finish_reason == "function_calls"
    assert create_result.thought is not None
    assert ResponseType.model_validate_json(create_result.thought)
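Note where the structured payload lands when tool calls are present: content carries the FunctionCall list, while the schema-validated JSON is expected on the thought field. A consumer-side sketch:

if isinstance(create_result.content, list):  # the model produced tool calls
    calls = [c for c in create_result.content if isinstance(c, FunctionCall)]
    if create_result.thought is not None:
        structured = ResponseType.model_validate_json(create_result.thought)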


@pytest.mark.skip("TODO: Fix streaming with tools")
@pytest.mark.asyncio
@pytest.mark.parametrize("model", ["qwen2.5:0.5b", "llama3.2:1b"])
async def test_ollama_create_stream_tools(model: str, ollama_client: OllamaChatCompletionClient) -> None:
    def add(x: int, y: int) -> str:
        return str(x + y)

    add_tool = FunctionTool(add, description="Add two numbers")

    stream = ollama_client.create_stream(
        messages=[
            UserMessage(
                content="What is 2 + 2? Use the add tool.",
                source="user",
            ),
        ],
        tools=[add_tool],
    )
    chunks: List[str | CreateResult] = []
    async for chunk in stream:
        chunks.append(chunk)
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    create_result = chunks[-1]
    assert isinstance(create_result.content, list)
    assert len(create_result.content) > 0
    assert isinstance(create_result.content[0], FunctionCall)
    assert create_result.content[0].name == add_tool.name
    assert create_result.finish_reason == "function_calls"

    execution_result = FunctionExecutionResult(
        content="4",
        name=add_tool.name,
        call_id=create_result.content[0].id,
        is_error=False,
    )
    stream = ollama_client.create_stream(
        messages=[
            UserMessage(
                content="What is 2 + 2? Use the add tool.",
                source="user",
            ),
            AssistantMessage(
                content=create_result.content,
                source="assistant",
            ),
            FunctionExecutionResultMessage(
                content=[execution_result],
            ),
        ],
    )
    chunks = []
    async for chunk in stream:
        chunks.append(chunk)
    assert len(chunks) > 0
    assert isinstance(chunks[-1], CreateResult)
    create_result = chunks[-1]
    assert isinstance(create_result.content, str)
    assert len(create_result.content) > 0
    assert create_result.finish_reason == "stop"

@@ -218,7 +218,13 @@ async def test_custom_model_with_capabilities() -> None:
        model="dummy_model",
        base_url="https://api.dummy.com/v0",
        api_key="api_key",
        model_info={"vision": False, "function_calling": False, "json_output": False, "family": ModelFamily.UNKNOWN},
        model_info={
            "vision": False,
            "function_calling": False,
            "json_output": False,
            "family": ModelFamily.UNKNOWN,
            "structured_output": False,
        },
    )
    assert client

@@ -231,7 +237,13 @@ async def test_azure_openai_chat_completion_client() -> None:
        api_key="api_key",
        api_version="2020-08-04",
        azure_endpoint="https://dummy.com",
        model_info={"vision": True, "function_calling": True, "json_output": True, "family": ModelFamily.GPT_4O},
        model_info={
            "vision": True,
            "function_calling": True,
            "json_output": True,
            "family": ModelFamily.GPT_4O,
            "structured_output": True,
        },
    )
    assert client

@@ -446,6 +458,97 @@ def test_convert_tools_accepts_both_tool_and_schema() -> None:
    assert converted_tool_schema[0] == converted_tool_schema[1]


@pytest.mark.asyncio
async def test_json_mode(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-11-20"

    called_args = {}

    async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletion:
        # Capture the arguments passed to the function
        called_args["kwargs"] = kwargs
        return ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(
                        content=json.dumps({"thoughts": "happy", "response": "happy"}),
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
        )

    monkeypatch.setattr(AsyncCompletions, "create", _mock_create)
    model_client = OpenAIChatCompletionClient(model=model, api_key="")

    # Test that the openai client was called with the correct response format.
    create_result = await model_client.create(
        messages=[UserMessage(content="I am happy.", source="user")], json_output=True
    )
    assert isinstance(create_result.content, str)
    response = json.loads(create_result.content)
    assert response["thoughts"] == "happy"
    assert response["response"] == "happy"
    assert called_args["kwargs"]["response_format"] == {"type": "json_object"}

    # Make sure that the response format is set to json_object when json_output is True, regardless of the extra_create_args.
    create_result = await model_client.create(
        messages=[UserMessage(content="I am happy.", source="user")],
        json_output=True,
        extra_create_args={"response_format": "json_object"},
    )
    assert isinstance(create_result.content, str)
    response = json.loads(create_result.content)
    assert response["thoughts"] == "happy"
    assert response["response"] == "happy"
    assert called_args["kwargs"]["response_format"] == {"type": "json_object"}

    create_result = await model_client.create(
        messages=[UserMessage(content="I am happy.", source="user")],
        json_output=True,
        extra_create_args={"response_format": "text"},
    )
    assert isinstance(create_result.content, str)
    response = json.loads(create_result.content)
    assert response["thoughts"] == "happy"
    assert response["response"] == "happy"
    # Check that the openai client was called with the correct response format.
    assert called_args["kwargs"]["response_format"] == {"type": "json_object"}

    # Make sure when json_output is set to False, the response format is always set to text.
    create_result = await model_client.create(
        messages=[UserMessage(content="I am happy.", source="user")],
        json_output=False,
        extra_create_args={"response_format": "text"},
    )
    assert called_args["kwargs"]["response_format"] == {"type": "text"}

    create_result = await model_client.create(
        messages=[UserMessage(content="I am happy.", source="user")],
        json_output=False,
        extra_create_args={"response_format": "json_object"},
    )
    assert called_args["kwargs"]["response_format"] == {"type": "text"}

    # Make sure when response_format is set it is used when json_output is not set.
    create_result = await model_client.create(
        messages=[UserMessage(content="I am happy.", source="user")],
        extra_create_args={"response_format": {"type": "json_object"}},
    )
    assert isinstance(create_result.content, str)
    response = json.loads(create_result.content)
    assert response["thoughts"] == "happy"
    assert response["response"] == "happy"
    assert called_args["kwargs"]["response_format"] == {"type": "json_object"}
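In short, the mocked calls above establish that an explicit json_output always wins over a response_format passed in extra_create_args:

# json_output=True  -> {"type": "json_object"} is sent, whatever extra_create_args says
# json_output=False -> {"type": "text"} is sent
# json_output=None  -> a response_format in extra_create_args is passed through unchanged
await model_client.create(
    messages=[UserMessage(content="I am happy.", source="user")],
    json_output=True,
    extra_create_args={"response_format": "text"},  # overridden by json_output
)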


@pytest.mark.asyncio
async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
    class AgentResponse(BaseModel):
@@ -483,11 +586,12 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
    model_client = OpenAIChatCompletionClient(
        model=model,
        api_key="",
        response_format=AgentResponse,  # type: ignore
    )

    # Test that the openai client was called with the correct response format.
    create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
    create_result = await model_client.create(
        messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
    )
    assert isinstance(create_result.content, str)
    response = AgentResponse.model_validate(json.loads(create_result.content))
    assert (
@@ -496,6 +600,36 @@ async def test_structured_output(monkeypatch: pytest.MonkeyPatch) -> None:
    )
    assert response.response == "happy"

    # Test that a warning will be raised if response_format is set to a dict.
    with pytest.warns(
        UserWarning,
        match="response_format is found in extra_create_args while json_output is set to a Pydantic model class.",
    ):
        create_result = await model_client.create(
            messages=[UserMessage(content="I am happy.", source="user")],
            json_output=AgentResponse,
            extra_create_args={"response_format": {"type": "json_object"}},
        )

    # Test that a warning will be raised if response_format is set to a pydantic model.
    with pytest.warns(
        DeprecationWarning, match="Using response_format to specify structured output type will be deprecated."
    ):
        create_result = await model_client.create(
            messages=[UserMessage(content="I am happy.", source="user")],
            extra_create_args={"response_format": AgentResponse},
        )

    # Test that a ValueError will be raised if response_format and json_output are set to a pydantic model.
    with pytest.raises(
        ValueError, match="response_format and json_output cannot be set to a Pydantic model class at the same time."
    ):
        create_result = await model_client.create(
            messages=[UserMessage(content="I am happy.", source="user")],
            json_output=AgentResponse,
            extra_create_args={"response_format": AgentResponse},
        )
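The deprecation path exercised above suggests the following migration for callers (model name illustrative):

# Before: output type fixed at construction time (now deprecated).
old_client = OpenAIChatCompletionClient(model="gpt-4o-mini", api_key="", response_format=AgentResponse)  # type: ignore

# After: pass the Pydantic class per call via json_output.
new_client = OpenAIChatCompletionClient(model="gpt-4o-mini", api_key="")
result = await new_client.create(
    messages=[UserMessage(content="I am happy.", source="user")],
    json_output=AgentResponse,
)
AgentResponse.model_validate(json.loads(result.content))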
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_structured_output_with_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
@@ -544,11 +678,12 @@ async def test_structured_output_with_tool_calls(monkeypatch: pytest.MonkeyPatch
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model=model,
|
||||
api_key="",
|
||||
response_format=AgentResponse, # type: ignore
|
||||
)
|
||||
|
||||
# Test that the openai client was called with the correct response format.
|
||||
create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
|
||||
create_result = await model_client.create(
|
||||
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
|
||||
)
|
||||
assert isinstance(create_result.content, list)
|
||||
assert len(create_result.content) == 1
|
||||
assert create_result.content[0] == FunctionCall(
|
||||
@@ -617,12 +752,13 @@ async def test_structured_output_with_streaming(monkeypatch: pytest.MonkeyPatch)
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model=model,
|
||||
api_key="",
|
||||
response_format=AgentResponse, # type: ignore
|
||||
)
|
||||
|
||||
# Test that the openai client was called with the correct response format.
|
||||
chunks: List[str | CreateResult] = []
|
||||
async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
|
||||
async for chunk in model_client.create_stream(
|
||||
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
|
||||
):
|
||||
chunks.append(chunk)
|
||||
assert len(chunks) > 0
|
||||
assert isinstance(chunks[-1], CreateResult)
|
||||
@@ -726,12 +862,13 @@ async def test_structured_output_with_streaming_tool_calls(monkeypatch: pytest.M
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model=model,
|
||||
api_key="",
|
||||
response_format=AgentResponse, # type: ignore
|
||||
)
|
||||
|
||||
# Test that the openai client was called with the correct response format.
|
||||
chunks: List[str | CreateResult] = []
|
||||
async for chunk in model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")]):
|
||||
async for chunk in model_client.create_stream(
|
||||
messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
|
||||
):
|
||||
chunks.append(chunk)
|
||||
assert len(chunks) > 0
|
||||
assert isinstance(chunks[-1], CreateResult)
|
||||
@@ -801,7 +938,13 @@ async def test_r1_think_field(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model="r1",
|
||||
api_key="",
|
||||
model_info={"family": ModelFamily.R1, "vision": False, "function_calling": False, "json_output": False},
|
||||
model_info={
|
||||
"family": ModelFamily.R1,
|
||||
"vision": False,
|
||||
"function_calling": False,
|
||||
"json_output": False,
|
||||
"structured_output": False,
|
||||
},
|
||||
)
|
||||
|
||||
# Successful completion with think field.
|
||||
@@ -874,7 +1017,13 @@ async def test_r1_think_field_not_present(monkeypatch: pytest.MonkeyPatch) -> No
     model_client = OpenAIChatCompletionClient(
         model="r1",
         api_key="",
-        model_info={"family": ModelFamily.R1, "vision": False, "function_calling": False, "json_output": False},
+        model_info={
+            "family": ModelFamily.R1,
+            "vision": False,
+            "function_calling": False,
+            "json_output": False,
+            "structured_output": False,
+        },
     )

     # Warning completion when think field is not present.
@@ -1338,11 +1487,12 @@ async def test_openai_structured_output() -> None:
     model_client = OpenAIChatCompletionClient(
         model="gpt-4o-mini",
         api_key=api_key,
-        response_format=AgentResponse,  # type: ignore
     )

     # Test that the openai client was called with the correct response format.
-    create_result = await model_client.create(messages=[UserMessage(content="I am happy.", source="user")])
+    create_result = await model_client.create(
+        messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
+    )
     assert isinstance(create_result.content, str)
     response = AgentResponse.model_validate(json.loads(create_result.content))
     assert response.thoughts
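Removing response_format from the constructor also means one client can switch output modes call by call: json_output=True presumably keeps the older plain JSON mode, while a model type requests schema-constrained output. A sketch reusing the AgentResponse model from the earlier examples:

from autogen_core.models import ChatCompletionClient, UserMessage


async def demo(model_client: ChatCompletionClient) -> None:
    # JSON mode: only syntactic validity is enforced, so the desired keys
    # must be spelled out in the prompt itself.
    loose = await model_client.create(
        messages=[UserMessage(content="Answer as JSON with keys thoughts and response. I am happy.", source="user")],
        json_output=True,
    )
    # Structured output: the schema is derived from the Pydantic model.
    strict = await model_client.create(
        messages=[UserMessage(content="I am happy.", source="user")],
        json_output=AgentResponse,
    )
    print(loose.content, strict.content)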
@@ -1362,11 +1512,12 @@ async def test_openai_structured_output_with_streaming() -> None:
     model_client = OpenAIChatCompletionClient(
         model="gpt-4o-mini",
         api_key=api_key,
-        response_format=AgentResponse,  # type: ignore
     )

     # Test that the openai client was called with the correct response format.
-    stream = model_client.create_stream(messages=[UserMessage(content="I am happy.", source="user")])
+    stream = model_client.create_stream(
+        messages=[UserMessage(content="I am happy.", source="user")], json_output=AgentResponse
+    )
     chunks: List[str | CreateResult] = []
     async for chunk in stream:
         chunks.append(chunk)
@@ -1397,7 +1548,6 @@ async def test_openai_structured_output_with_tool_calls() -> None:
     model_client = OpenAIChatCompletionClient(
         model="gpt-4o-mini",
         api_key=api_key,
-        response_format=AgentResponse,  # type: ignore
     )

     response1 = await model_client.create(
@@ -1407,6 +1557,7 @@ async def test_openai_structured_output_with_tool_calls() -> None:
         ],
         tools=[tool],
         extra_create_args={"tool_choice": "required"},
+        json_output=AgentResponse,
     )
     assert isinstance(response1.content, list)
     assert len(response1.content) == 1
@@ -1428,6 +1579,7 @@ async def test_openai_structured_output_with_tool_calls() -> None:
                 ]
             ),
         ],
+        json_output=AgentResponse,
     )
     assert isinstance(response2.content, str)
     parsed_response = AgentResponse.model_validate(json.loads(response2.content))
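The two hunks above thread json_output=AgentResponse through both legs of a tool-call round trip: the first create returns FunctionCall items, and the follow-up call carrying the execution result returns the structured final answer. A condensed sketch of that flow; the tool, message field names, and IDs are illustrative assumptions about the surrounding test, not code from this commit:

import asyncio
import json

from autogen_core.models import (
    AssistantMessage,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    UserMessage,
)
from autogen_core.tools import FunctionTool
from autogen_ext.models.openai import OpenAIChatCompletionClient
from pydantic import BaseModel


class AgentResponse(BaseModel):
    thoughts: str
    response: str


def sentiment_analysis(text: str) -> str:
    # Toy tool standing in for whatever tool the test registers.
    return "happy" if "happy" in text else "neutral"


async def main() -> None:
    tool = FunctionTool(sentiment_analysis, description="Analyze the sentiment of a text.")
    model_client = OpenAIChatCompletionClient(model="gpt-4o-mini")

    # Leg 1: the model is forced to call the tool; content is a list of FunctionCall.
    response1 = await model_client.create(
        messages=[UserMessage(content="I am happy.", source="user")],
        tools=[tool],
        extra_create_args={"tool_choice": "required"},
        json_output=AgentResponse,
    )
    assert isinstance(response1.content, list)
    call = response1.content[0]

    # Leg 2: execute the tool and feed the result back; content is the structured answer.
    tool_output = sentiment_analysis(**json.loads(call.arguments))
    response2 = await model_client.create(
        messages=[
            UserMessage(content="I am happy.", source="user"),
            AssistantMessage(content=response1.content, source="assistant"),
            FunctionExecutionResultMessage(
                content=[FunctionExecutionResult(content=tool_output, call_id=call.id, is_error=False, name=tool.name)]
            ),
        ],
        json_output=AgentResponse,
    )
    assert isinstance(response2.content, str)
    print(AgentResponse.model_validate(json.loads(response2.content)))


asyncio.run(main())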
@@ -1454,7 +1606,6 @@ async def test_openai_structured_output_with_streaming_tool_calls() -> None:
     model_client = OpenAIChatCompletionClient(
         model="gpt-4o-mini",
         api_key=api_key,
-        response_format=AgentResponse,  # type: ignore
     )

     chunks1: List[str | CreateResult] = []
@@ -1465,6 +1616,7 @@ async def test_openai_structured_output_with_streaming_tool_calls() -> None:
         ],
         tools=[tool],
         extra_create_args={"tool_choice": "required"},
+        json_output=AgentResponse,
     )
     async for chunk in stream1:
         chunks1.append(chunk)
@@ -1491,6 +1643,7 @@ async def test_openai_structured_output_with_streaming_tool_calls() -> None:
                 ]
             ),
         ],
+        json_output=AgentResponse,
     )
     chunks2: List[str | CreateResult] = []
     async for chunk in stream2:
@@ -1532,6 +1685,7 @@ async def test_hugging_face() -> None:
             "json_output": False,
             "vision": False,
             "family": ModelFamily.UNKNOWN,
+            "structured_output": False,
         },
     )

@@ -1546,6 +1700,7 @@ async def test_ollama() -> None:
         "json_output": False,
         "vision": False,
         "family": ModelFamily.R1,
+        "structured_output": False,
     }
     # Check if the model is running locally.
     try:
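With every model_info now carrying the flag, callers can gate structured output on the advertised capability instead of attempting it and handling the failure. A small helper sketch; the mapping-style access mirrors the dict declarations above and is an assumption, not an API from this commit:

from autogen_core.models import ChatCompletionClient
from pydantic import BaseModel


def pick_json_output(model_client: ChatCompletionClient, schema: type[BaseModel]) -> bool | type[BaseModel]:
    # Prefer schema-constrained output when the model advertises support,
    # otherwise fall back to plain JSON mode (or no JSON constraint at all).
    info = model_client.model_info
    if info.get("structured_output", False):
        return schema
    return bool(info.get("json_output", False))

A caller would then pass json_output=pick_json_output(model_client, AgentResponse) into create() or create_stream().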
@@ -453,7 +453,9 @@ async def test_sk_chat_completion_default_model_info(sk_client: AzureChatComplet
 @pytest.mark.asyncio
 async def test_sk_chat_completion_custom_model_info(sk_client: AzureChatCompletion) -> None:
     # Create custom model info
-    custom_model_info = ModelInfo(vision=True, function_calling=True, json_output=True, family=ModelFamily.GPT_4)
+    custom_model_info = ModelInfo(
+        vision=True, function_calling=True, json_output=True, family=ModelFamily.GPT_4, structured_output=False
+    )

     # Create adapter with custom model_info
     adapter = SKChatCompletionAdapter(sk_client, model_info=custom_model_info)
@@ -522,7 +524,9 @@ async def test_sk_chat_completion_r1_content() -> None:
     adapter = SKChatCompletionAdapter(
         mock_client,
         kernel=kernel,
-        model_info=ModelInfo(vision=False, function_calling=False, json_output=False, family=ModelFamily.R1),
+        model_info=ModelInfo(
+            vision=False, function_calling=False, json_output=False, family=ModelFamily.R1, structured_output=False
+        ),
     )

     result = await adapter.create(messages=[UserMessage(content="Say hello!", source="user")])