diff --git a/docs/src/guides/azure_openai_with_aad_auth.md b/docs/src/guides/azure_openai_with_aad_auth.md index 5dc5daaf4..c4dcc1d41 100644 --- a/docs/src/guides/azure_openai_with_aad_auth.md +++ b/docs/src/guides/azure_openai_with_aad_auth.md @@ -15,7 +15,7 @@ pip install azure-identity ## Using the Model Client ```python -from agnext.components.model_client import AzureOpenAI +from agnext.components.llm import AzureOpenAI from azure.identity import DefaultAzureCredential, get_bearer_token_provider # Create the token provider diff --git a/examples/orchestrator.py b/examples/orchestrator.py index f2bde7f6a..d0c576c79 100644 --- a/examples/orchestrator.py +++ b/examples/orchestrator.py @@ -16,8 +16,7 @@ from agnext.chat.types import TextMessage from agnext.components.function_executor._impl.in_process_function_executor import ( InProcessFunctionExecutor, ) -from agnext.components.model_client import OpenAI -from agnext.components.types import SystemMessage +from agnext.components.llm import OpenAI, SystemMessage from agnext.core import Agent, AgentRuntime from agnext.core.intervention import DefaultInterventionHandler, DropMessage from tavily import TavilyClient diff --git a/src/agnext/chat/agents/chat_completion_agent.py b/src/agnext/chat/agents/chat_completion_agent.py index 78c28bfb5..84537a75d 100644 --- a/src/agnext/chat/agents/chat_completion_agent.py +++ b/src/agnext/chat/agents/chat_completion_agent.py @@ -5,8 +5,6 @@ from typing import Any, Coroutine, Dict, List, Mapping, Tuple from agnext.chat.agents.base import BaseChatAgent from agnext.chat.types import ( FunctionCallMessage, - FunctionExecutionResult, - FunctionExecutionResultMessage, Message, Reset, RespondNow, @@ -15,12 +13,11 @@ from agnext.chat.types import ( ) from agnext.chat.utils import convert_messages_to_llm_messages from agnext.components.function_executor import FunctionExecutor -from agnext.components.model_client import ModelClient +from agnext.components.llm import FunctionExecutionResult, FunctionExecutionResultMessage, ModelClient, SystemMessage from agnext.components.type_routed_agent import TypeRoutedAgent, message_handler from agnext.components.types import ( FunctionCall, FunctionSignature, - SystemMessage, ) from agnext.core import AgentRuntime, CancellationToken @@ -141,7 +138,7 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent): results.append(FunctionExecutionResult(content=execution_result, call_id=call_id)) # Create a tool call result message. - tool_call_result_msg = FunctionExecutionResultMessage(content=results, source=self.name) + tool_call_result_msg = FunctionExecutionResultMessage(content=results) # Add tool call result message. self._chat_messages.append(tool_call_result_msg) diff --git a/src/agnext/chat/types.py b/src/agnext/chat/types.py index 159b6994f..9eda88e18 100644 --- a/src/agnext/chat/types.py +++ b/src/agnext/chat/types.py @@ -5,6 +5,7 @@ from enum import Enum from typing import List, Union from agnext.components.image import Image +from agnext.components.llm import FunctionExecutionResultMessage from agnext.components.types import FunctionCall @@ -29,17 +30,6 @@ class FunctionCallMessage(BaseMessage): content: List[FunctionCall] -@dataclass -class FunctionExecutionResult: - content: str - call_id: str - - -@dataclass -class FunctionExecutionResultMessage(BaseMessage): - content: List[FunctionExecutionResult] - - Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage] diff --git a/src/agnext/chat/utils.py b/src/agnext/chat/utils.py index 493926ad9..321771ddb 100644 --- a/src/agnext/chat/utils.py +++ b/src/agnext/chat/utils.py @@ -4,22 +4,17 @@ from typing_extensions import Literal from agnext.chat.types import ( FunctionCallMessage, - FunctionExecutionResultMessage, Message, MultiModalMessage, TextMessage, ) -from agnext.components.types import ( +from agnext.components.llm import ( AssistantMessage, + FunctionExecutionResult, + FunctionExecutionResultMessage, LLMMessage, UserMessage, ) -from agnext.components.types import ( - FunctionExecutionResult as FunctionExecutionResultType, -) -from agnext.components.types import ( - FunctionExecutionResultMessage as FunctionExecutionResultMessageType, -) def convert_content_message_to_assistant_message( @@ -61,11 +56,11 @@ def convert_content_message_to_user_message( def convert_tool_call_response_message( message: FunctionExecutionResultMessage, handle_unrepresentable: Literal["error", "ignore", "try_slice"] = "error", -) -> Optional[FunctionExecutionResultMessageType]: +) -> Optional[FunctionExecutionResultMessage]: match message: case FunctionExecutionResultMessage(): - return FunctionExecutionResultMessageType( - content=[FunctionExecutionResultType(content=x.content, call_id=x.call_id) for x in message.content] + return FunctionExecutionResultMessage( + content=[FunctionExecutionResult(content=x.content, call_id=x.call_id) for x in message.content] ) @@ -93,7 +88,7 @@ def convert_messages_to_llm_messages( converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable) if converted_message_2 is not None: result.append(converted_message_2) - case FunctionExecutionResultMessage(_, source=source) if source == self_name: + case FunctionExecutionResultMessage(_): converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable) if converted_message_3 is not None: result.append(converted_message_3) diff --git a/src/agnext/components/function_utils.py b/src/agnext/components/_function_utils.py similarity index 99% rename from src/agnext/components/function_utils.py rename to src/agnext/components/_function_utils.py index 230b6f679..15cce58c1 100644 --- a/src/agnext/components/function_utils.py +++ b/src/agnext/components/_function_utils.py @@ -20,7 +20,7 @@ from typing import ( from pydantic import BaseModel, Field from typing_extensions import Annotated, Literal -from .pydantic_compat import evaluate_forwardref, model_dump, type2schema +from ._pydantic_compat import evaluate_forwardref, model_dump, type2schema logger = getLogger(__name__) diff --git a/src/agnext/components/pydantic_compat.py b/src/agnext/components/_pydantic_compat.py similarity index 100% rename from src/agnext/components/pydantic_compat.py rename to src/agnext/components/_pydantic_compat.py diff --git a/src/agnext/components/function_executor/_base.py b/src/agnext/components/function_executor/_base.py index e80db24ba..1b3284b62 100644 --- a/src/agnext/components/function_executor/_base.py +++ b/src/agnext/components/function_executor/_base.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, Protocol, TypedDict, Union, runtime_chec from typing_extensions import NotRequired, Required -from ..function_utils import get_function_schema +from .._function_utils import get_function_schema from ..types import FunctionSignature diff --git a/src/agnext/components/llm/__init__.py b/src/agnext/components/llm/__init__.py new file mode 100644 index 000000000..f5a77ab04 --- /dev/null +++ b/src/agnext/components/llm/__init__.py @@ -0,0 +1,32 @@ +from ._model_client import ModelCapabilities, ModelClient +from ._openai_client import ( + AzureOpenAI, + OpenAI, +) +from ._types import ( + AssistantMessage, + CreateResult, + FinishReasons, + FunctionExecutionResult, + FunctionExecutionResultMessage, + LLMMessage, + RequestUsage, + SystemMessage, + UserMessage, +) + +__all__ = [ + "AzureOpenAI", + "OpenAI", + "ModelCapabilities", + "ModelClient", + "SystemMessage", + "UserMessage", + "AssistantMessage", + "FunctionExecutionResult", + "FunctionExecutionResultMessage", + "LLMMessage", + "RequestUsage", + "FinishReasons", + "CreateResult", +] diff --git a/src/agnext/components/model_client/_model_client.py b/src/agnext/components/llm/_model_client.py similarity index 94% rename from src/agnext/components/model_client/_model_client.py rename to src/agnext/components/llm/_model_client.py index 209292b6d..d582e689e 100644 --- a/src/agnext/components/model_client/_model_client.py +++ b/src/agnext/components/llm/_model_client.py @@ -11,7 +11,8 @@ from typing_extensions import ( Union, ) -from ..types import CreateResult, FunctionSignature, LLMMessage, RequestUsage +from ..types import FunctionSignature +from ._types import CreateResult, LLMMessage, RequestUsage class ModelCapabilities(TypedDict, total=False): diff --git a/src/agnext/components/model_client/_model_info.py b/src/agnext/components/llm/_model_info.py similarity index 100% rename from src/agnext/components/model_client/_model_info.py rename to src/agnext/components/llm/_model_info.py diff --git a/src/agnext/components/model_client/_openai_client.py b/src/agnext/components/llm/_openai_client.py similarity index 93% rename from src/agnext/components/model_client/_openai_client.py rename to src/agnext/components/llm/_openai_client.py index be890dbdd..619e9b05c 100644 --- a/src/agnext/components/model_client/_openai_client.py +++ b/src/agnext/components/llm/_openai_client.py @@ -4,11 +4,8 @@ import warnings from typing import ( Any, AsyncGenerator, - Awaitable, - Callable, Dict, List, - Literal, Mapping, Optional, Sequence, @@ -30,25 +27,26 @@ from openai.types.chat import ( ChatCompletionUserMessageParam, completion_create_params, ) -from typing_extensions import Required, TypedDict, Unpack +from typing_extensions import Unpack from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent - -# from ..._pydantic import type2schema from ..image import Image from ..types import ( + FunctionCall, + FunctionSignature, +) +from . import _model_info +from ._model_client import ModelCapabilities, ModelClient +from ._types import ( AssistantMessage, CreateResult, - FunctionCall, FunctionExecutionResultMessage, - FunctionSignature, LLMMessage, RequestUsage, SystemMessage, UserMessage, ) -from . import _model_info -from ._model_client import ModelCapabilities, ModelClient +from .config import AzureOpenAIClientConfiguration, OpenAIClientConfiguration logger = logging.getLogger(EVENT_LOGGER_NAME) @@ -202,53 +200,6 @@ def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage: ) -class ResponseFormat(TypedDict): - type: Literal["text", "json_object"] - - -class CreateArguments(TypedDict, total=False): - frequency_penalty: Optional[float] - logit_bias: Optional[Dict[str, int]] - max_tokens: Optional[int] - n: Optional[int] - presence_penalty: Optional[float] - response_format: ResponseFormat - seed: Optional[int] - stop: Union[Optional[str], List[str]] - temperature: Optional[float] - top_p: Optional[float] - user: str - - -AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]] - - -class BaseOpenAIClientConfiguration(CreateArguments, total=False): - model: str - api_key: str - timeout: Union[float, None] - max_retries: int - - -# See OpenAI docs for explanation of these parameters -class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False): - organization: str - base_url: str - # Not required - model_capabilities: ModelCapabilities - - -class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False): - # Azure specific - azure_endpoint: Required[str] - azure_deployment: str - api_version: Required[str] - azure_ad_token: str - azure_ad_token_provider: AsyncAzureADTokenProvider - # Must be provided - model_capabilities: Required[ModelCapabilities] - - def convert_functions( functions: Sequence[FunctionSignature], ) -> List[ChatCompletionToolParam]: diff --git a/src/agnext/components/llm/_types.py b/src/agnext/components/llm/_types.py new file mode 100644 index 000000000..3995303fc --- /dev/null +++ b/src/agnext/components/llm/_types.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass +from typing import List, Literal, Union + +from ..image import Image +from ..types import FunctionCall + + +@dataclass +class SystemMessage: + content: str + + +@dataclass +class UserMessage: + content: Union[str, List[Union[str, Image]]] + + # Name of the agent that sent this message + source: str + + +@dataclass +class AssistantMessage: + content: Union[str, List[FunctionCall]] + + # Name of the agent that sent this message + source: str + + +@dataclass +class FunctionExecutionResult: + content: str + call_id: str + + +@dataclass +class FunctionExecutionResultMessage: + content: List[FunctionExecutionResult] + + +LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage] + + +@dataclass +class RequestUsage: + prompt_tokens: int + completion_tokens: int + + +FinishReasons = Literal["stop", "length", "function_calls", "content_filter"] + + +@dataclass +class CreateResult: + finish_reason: FinishReasons + content: Union[str, List[FunctionCall]] + usage: RequestUsage + cached: bool diff --git a/src/agnext/components/llm/config/__init__.py b/src/agnext/components/llm/config/__init__.py new file mode 100644 index 000000000..d1edcf8c6 --- /dev/null +++ b/src/agnext/components/llm/config/__init__.py @@ -0,0 +1,52 @@ +from typing import Awaitable, Callable, Dict, List, Literal, Optional, Union + +from typing_extensions import Required, TypedDict + +from .._model_client import ModelCapabilities + + +class ResponseFormat(TypedDict): + type: Literal["text", "json_object"] + + +class CreateArguments(TypedDict, total=False): + frequency_penalty: Optional[float] + logit_bias: Optional[Dict[str, int]] + max_tokens: Optional[int] + n: Optional[int] + presence_penalty: Optional[float] + response_format: ResponseFormat + seed: Optional[int] + stop: Union[Optional[str], List[str]] + temperature: Optional[float] + top_p: Optional[float] + user: str + + +AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]] + + +class BaseOpenAIClientConfiguration(CreateArguments, total=False): + model: str + api_key: str + timeout: Union[float, None] + max_retries: int + + +# See OpenAI docs for explanation of these parameters +class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False): + organization: str + base_url: str + # Not required + model_capabilities: ModelCapabilities + + +class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False): + # Azure specific + azure_endpoint: Required[str] + azure_deployment: str + api_version: Required[str] + azure_ad_token: str + azure_ad_token_provider: AsyncAzureADTokenProvider + # Must be provided + model_capabilities: Required[ModelCapabilities] diff --git a/src/agnext/components/model_client/__init__.py b/src/agnext/components/model_client/__init__.py deleted file mode 100644 index d5a5ec187..000000000 --- a/src/agnext/components/model_client/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from ._model_client import ModelCapabilities, ModelClient -from ._openai_client import ( - AsyncAzureADTokenProvider, - AzureOpenAI, - AzureOpenAIClientConfiguration, - BaseOpenAIClientConfiguration, - CreateArguments, - OpenAI, - OpenAIClientConfiguration, - ResponseFormat, -) - -__all__ = [ - "AzureOpenAI", - "OpenAI", - "OpenAIClientConfiguration", - "AzureOpenAIClientConfiguration", - "ResponseFormat", - "CreateArguments", - "AsyncAzureADTokenProvider", - "BaseOpenAIClientConfiguration", - "ModelCapabilities", - "ModelClient", -] diff --git a/src/agnext/components/types.py b/src/agnext/components/types.py index 22903d6ee..a62762465 100644 --- a/src/agnext/components/types.py +++ b/src/agnext/components/types.py @@ -1,11 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, List, Union - -from typing_extensions import Literal - -from .image import Image +from typing import Any, Dict @dataclass @@ -22,55 +18,3 @@ class FunctionSignature: name: str parameters: Dict[str, Any] description: str - - -@dataclass -class RequestUsage: - prompt_tokens: int - completion_tokens: int - - -@dataclass -class SystemMessage: - content: str - - -@dataclass -class UserMessage: - content: Union[str, List[Union[str, Image]]] - - # Name of the agent that sent this message - source: str - - -@dataclass -class AssistantMessage: - content: Union[str, List[FunctionCall]] - - # Name of the agent that sent this message - source: str - - -@dataclass -class FunctionExecutionResult: - content: str - call_id: str - - -@dataclass -class FunctionExecutionResultMessage: - content: List[FunctionExecutionResult] - - -LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage] - - -FinishReasons = Literal["stop", "length", "function_calls", "content_filter"] - - -@dataclass -class CreateResult: - finish_reason: FinishReasons - content: Union[str, List[FunctionCall]] - usage: RequestUsage - cached: bool