Organize some more modules (#48)

* Organize some more modules

* cleanup model_client
This commit is contained in:
Jack Gerrits
2024-06-04 11:13:13 -04:00
committed by GitHub
parent 19570fdd98
commit ed0229734d
16 changed files with 166 additions and 172 deletions

View File

@@ -20,7 +20,7 @@ from typing import (
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Literal
from .pydantic_compat import evaluate_forwardref, model_dump, type2schema
from ._pydantic_compat import evaluate_forwardref, model_dump, type2schema
logger = getLogger(__name__)

View File

@@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, Protocol, TypedDict, Union, runtime_chec
from typing_extensions import NotRequired, Required
from ..function_utils import get_function_schema
from .._function_utils import get_function_schema
from ..types import FunctionSignature

View File

@@ -0,0 +1,32 @@
from ._model_client import ModelCapabilities, ModelClient
from ._openai_client import (
AzureOpenAI,
OpenAI,
)
from ._types import (
AssistantMessage,
CreateResult,
FinishReasons,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
RequestUsage,
SystemMessage,
UserMessage,
)
__all__ = [
"AzureOpenAI",
"OpenAI",
"ModelCapabilities",
"ModelClient",
"SystemMessage",
"UserMessage",
"AssistantMessage",
"FunctionExecutionResult",
"FunctionExecutionResultMessage",
"LLMMessage",
"RequestUsage",
"FinishReasons",
"CreateResult",
]

View File

@@ -11,7 +11,8 @@ from typing_extensions import (
Union,
)
from ..types import CreateResult, FunctionSignature, LLMMessage, RequestUsage
from ..types import FunctionSignature
from ._types import CreateResult, LLMMessage, RequestUsage
class ModelCapabilities(TypedDict, total=False):

View File

@@ -4,11 +4,8 @@ import warnings
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
@@ -30,25 +27,26 @@ from openai.types.chat import (
ChatCompletionUserMessageParam,
completion_create_params,
)
from typing_extensions import Required, TypedDict, Unpack
from typing_extensions import Unpack
from ...application.logging import EVENT_LOGGER_NAME, LLMCallEvent
# from ..._pydantic import type2schema
from ..image import Image
from ..types import (
FunctionCall,
FunctionSignature,
)
from . import _model_info
from ._model_client import ModelCapabilities, ModelClient
from ._types import (
AssistantMessage,
CreateResult,
FunctionCall,
FunctionExecutionResultMessage,
FunctionSignature,
LLMMessage,
RequestUsage,
SystemMessage,
UserMessage,
)
from . import _model_info
from ._model_client import ModelCapabilities, ModelClient
from .config import AzureOpenAIClientConfiguration, OpenAIClientConfiguration
logger = logging.getLogger(EVENT_LOGGER_NAME)
@@ -202,53 +200,6 @@ def _add_usage(usage1: RequestUsage, usage2: RequestUsage) -> RequestUsage:
)
class ResponseFormat(TypedDict):
type: Literal["text", "json_object"]
class CreateArguments(TypedDict, total=False):
frequency_penalty: Optional[float]
logit_bias: Optional[Dict[str, int]]
max_tokens: Optional[int]
n: Optional[int]
presence_penalty: Optional[float]
response_format: ResponseFormat
seed: Optional[int]
stop: Union[Optional[str], List[str]]
temperature: Optional[float]
top_p: Optional[float]
user: str
AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]]
class BaseOpenAIClientConfiguration(CreateArguments, total=False):
model: str
api_key: str
timeout: Union[float, None]
max_retries: int
# See OpenAI docs for explanation of these parameters
class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
organization: str
base_url: str
# Not required
model_capabilities: ModelCapabilities
class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
# Azure specific
azure_endpoint: Required[str]
azure_deployment: str
api_version: Required[str]
azure_ad_token: str
azure_ad_token_provider: AsyncAzureADTokenProvider
# Must be provided
model_capabilities: Required[ModelCapabilities]
def convert_functions(
functions: Sequence[FunctionSignature],
) -> List[ChatCompletionToolParam]:

View File

@@ -0,0 +1,57 @@
from dataclasses import dataclass
from typing import List, Literal, Union
from ..image import Image
from ..types import FunctionCall
@dataclass
class SystemMessage:
content: str
@dataclass
class UserMessage:
content: Union[str, List[Union[str, Image]]]
# Name of the agent that sent this message
source: str
@dataclass
class AssistantMessage:
content: Union[str, List[FunctionCall]]
# Name of the agent that sent this message
source: str
@dataclass
class FunctionExecutionResult:
content: str
call_id: str
@dataclass
class FunctionExecutionResultMessage:
content: List[FunctionExecutionResult]
LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
@dataclass
class RequestUsage:
prompt_tokens: int
completion_tokens: int
FinishReasons = Literal["stop", "length", "function_calls", "content_filter"]
@dataclass
class CreateResult:
finish_reason: FinishReasons
content: Union[str, List[FunctionCall]]
usage: RequestUsage
cached: bool

View File

@@ -0,0 +1,52 @@
from typing import Awaitable, Callable, Dict, List, Literal, Optional, Union
from typing_extensions import Required, TypedDict
from .._model_client import ModelCapabilities
class ResponseFormat(TypedDict):
type: Literal["text", "json_object"]
class CreateArguments(TypedDict, total=False):
frequency_penalty: Optional[float]
logit_bias: Optional[Dict[str, int]]
max_tokens: Optional[int]
n: Optional[int]
presence_penalty: Optional[float]
response_format: ResponseFormat
seed: Optional[int]
stop: Union[Optional[str], List[str]]
temperature: Optional[float]
top_p: Optional[float]
user: str
AsyncAzureADTokenProvider = Callable[[], Union[str, Awaitable[str]]]
class BaseOpenAIClientConfiguration(CreateArguments, total=False):
model: str
api_key: str
timeout: Union[float, None]
max_retries: int
# See OpenAI docs for explanation of these parameters
class OpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
organization: str
base_url: str
# Not required
model_capabilities: ModelCapabilities
class AzureOpenAIClientConfiguration(BaseOpenAIClientConfiguration, total=False):
# Azure specific
azure_endpoint: Required[str]
azure_deployment: str
api_version: Required[str]
azure_ad_token: str
azure_ad_token_provider: AsyncAzureADTokenProvider
# Must be provided
model_capabilities: Required[ModelCapabilities]

View File

@@ -1,24 +0,0 @@
from ._model_client import ModelCapabilities, ModelClient
from ._openai_client import (
AsyncAzureADTokenProvider,
AzureOpenAI,
AzureOpenAIClientConfiguration,
BaseOpenAIClientConfiguration,
CreateArguments,
OpenAI,
OpenAIClientConfiguration,
ResponseFormat,
)
__all__ = [
"AzureOpenAI",
"OpenAI",
"OpenAIClientConfiguration",
"AzureOpenAIClientConfiguration",
"ResponseFormat",
"CreateArguments",
"AsyncAzureADTokenProvider",
"BaseOpenAIClientConfiguration",
"ModelCapabilities",
"ModelClient",
]

View File

@@ -1,11 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from typing_extensions import Literal
from .image import Image
from typing import Any, Dict
@dataclass
@@ -22,55 +18,3 @@ class FunctionSignature:
name: str
parameters: Dict[str, Any]
description: str
@dataclass
class RequestUsage:
prompt_tokens: int
completion_tokens: int
@dataclass
class SystemMessage:
content: str
@dataclass
class UserMessage:
content: Union[str, List[Union[str, Image]]]
# Name of the agent that sent this message
source: str
@dataclass
class AssistantMessage:
content: Union[str, List[FunctionCall]]
# Name of the agent that sent this message
source: str
@dataclass
class FunctionExecutionResult:
content: str
call_id: str
@dataclass
class FunctionExecutionResultMessage:
content: List[FunctionExecutionResult]
LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
FinishReasons = Literal["stop", "length", "function_calls", "content_filter"]
@dataclass
class CreateResult:
finish_reason: FinishReasons
content: Union[str, List[FunctionCall]]
usage: RequestUsage
cached: bool