add first-class support for Azure OpenAI
core/config.py
@@ -1,6 +1,6 @@
 from enum import Enum
 from os.path import abspath, dirname, isdir, join
-from typing import Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 from typing_extensions import Annotated
@@ -55,6 +55,7 @@ class LLMProvider(str, Enum):
     ANTHROPIC = "anthropic"
     GROQ = "groq"
     LM_STUDIO = "lm-studio"
+    AZURE = "azure"
 
 
 class UIAdapter(str, Enum):
@@ -89,6 +90,10 @@ class ProviderConfig(_StrictModel):
         description="Timeout (in seconds) for receiving a new chunk of data from the response stream",
         ge=0.0,
     )
+    extra: Optional[dict[str, Any]] = Field(
+        None,
+        description="Extra provider-specific configuration",
+    )
 
 
 class AgentLLMConfig(_StrictModel):
@@ -140,6 +145,10 @@ class LLMConfig(_StrictModel):
         description="Timeout (in seconds) for receiving a new chunk of data from the response stream",
         ge=0.0,
     )
+    extra: Optional[dict[str, Any]] = Field(
+        None,
+        description="Extra provider-specific configuration",
+    )
 
     @classmethod
     def from_provider_and_agent_configs(cls, provider: ProviderConfig, agent: AgentLLMConfig):
@@ -151,6 +160,7 @@ class LLMConfig(_StrictModel):
             temperature=agent.temperature,
             connect_timeout=provider.connect_timeout,
             read_timeout=provider.read_timeout,
+            extra=provider.extra,
         )
 
 
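Taken together, the config changes above add an optional `extra` dict at the provider level and thread it through into the merged per-agent `LLMConfig` via the new `extra=provider.extra` line. A minimal sketch of that flow; fields other than `extra` and `temperature` are not visible in this diff, so the constructor arguments below are assumptions, not confirmed API:

    from core.config import AgentLLMConfig, LLMConfig, LLMProvider, ProviderConfig

    # Hypothetical construction: base_url/api_key on ProviderConfig and the
    # AgentLLMConfig arguments are assumed for illustration.
    provider = ProviderConfig(
        base_url="https://your-resource-name.openai.azure.com/",
        api_key="your-api-key",
        extra={"azure_deployment": "your-azure-deployment-id", "api_version": "2024-02-01"},
    )
    agent = AgentLLMConfig(temperature=0.5)

    # The merged config now carries the provider-specific keys along.
    merged = LLMConfig.from_provider_and_agent_configs(provider, agent)
    assert merged.extra == provider.extra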
core/llm/azure_client.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+from httpx import Timeout
+from openai import AsyncAzureOpenAI
+
+from core.config import LLMProvider
+from core.llm.openai_client import OpenAIClient
+from core.log import get_logger
+
+log = get_logger(__name__)
+
+
+class AzureClient(OpenAIClient):
+    provider = LLMProvider.AZURE
+    stream_options = None
+
+    def _init_client(self):
+        azure_deployment = self.config.extra.get("azure_deployment")
+        api_version = self.config.extra.get("api_version")
+
+        self.client = AsyncAzureOpenAI(
+            api_key=self.config.api_key,
+            azure_endpoint=self.config.base_url,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            timeout=Timeout(
+                max(self.config.connect_timeout, self.config.read_timeout),
+                connect=self.config.connect_timeout,
+                read=self.config.read_timeout,
+            ),
+        )
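A note on the `timeout=Timeout(...)` construction above: httpx requires either a default timeout or all four phase timeouts set explicitly, so the client passes the larger of the two configured values as the default and then overrides connect and read individually. A standalone sketch using the values from the example config further down:

    from httpx import Timeout

    # connect_timeout=60.0, read_timeout=10.0 (the example-config values):
    # 60s default (also covering write/pool), 60s to establish a connection,
    # and at most 10s waiting for each new chunk of the response stream.
    timeout = Timeout(max(60.0, 10.0), connect=60.0, read=10.0)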
core/llm/base.py
@@ -316,6 +316,7 @@ class BaseLLMClient:
         :return: Client class for the specified provider.
         """
         from .anthropic_client import AnthropicClient
+        from .azure_client import AzureClient
         from .groq_client import GroqClient
         from .openai_client import OpenAIClient
 
@@ -325,6 +326,8 @@ class BaseLLMClient:
             return AnthropicClient
         elif provider == LLMProvider.GROQ:
             return GroqClient
+        elif provider == LLMProvider.AZURE:
+            return AzureClient
         else:
             raise ValueError(f"Unsupported LLM provider: {provider.value}")
 
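For illustration, the dispatch above would be invoked along these lines. The enclosing classmethod's name is not visible in this hunk, so `for_provider` is an assumed name:

    from core.config import LLMProvider
    from core.llm.base import BaseLLMClient

    # Hypothetical call site; only the method body is shown in the hunk above.
    client_class = BaseLLMClient.for_provider(LLMProvider.AZURE)
    # client_class is now AzureClient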
core/llm/openai_client.py
@@ -17,6 +17,7 @@ tokenizer = tiktoken.get_encoding("cl100k_base")
 
 class OpenAIClient(BaseLLMClient):
     provider = LLMProvider.OPENAI
+    stream_options = {"include_usage": True}
 
     def _init_client(self):
         self.client = AsyncOpenAI(
@@ -40,10 +41,10 @@ class OpenAIClient(BaseLLMClient):
             "messages": convo.messages,
             "temperature": self.config.temperature if temperature is None else temperature,
             "stream": True,
-            "stream_options": {
-                "include_usage": True,
-            },
         }
+
+        if self.stream_options:
+            completion_kwargs["stream_options"] = self.stream_options
 
         if json_mode:
             completion_kwargs["response_format"] = {"type": "json_object"}
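The second hunk above replaces the hard-coded `stream_options` argument with a lookup of the new class attribute, so each subclass controls whether the key is sent at all: `OpenAIClient` keeps `{"include_usage": True}`, while `AzureClient` sets `stream_options = None` and the key is omitted. A self-contained sketch of the pattern, with class names shortened for brevity:

    class Base:
        stream_options = {"include_usage": True}

        def completion_kwargs(self):
            kwargs = {"stream": True}
            if self.stream_options:  # falsy for subclasses that opt out
                kwargs["stream_options"] = self.stream_options
            return kwargs

    class Azure(Base):
        stream_options = None  # the key is never added

    assert "stream_options" in Base().completion_kwargs()
    assert "stream_options" not in Azure().completion_kwargs()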
example-config.json
@@ -1,6 +1,6 @@
 {
     // Configuration for the LLM providers that can be used. Pythagora supports
-    // OpenAI, Anthropic and Groq. Azure and OpenRouter and local LLMs (such as LM-Studio)
+    // OpenAI, Azure, Anthropic and Groq. OpenRouter and local LLMs (such as LM-Studio)
     // also work, you can use "openai" provider to define these.
     "llm": {
         "openai": {
@@ -9,6 +9,17 @@
             "api_key": null,
             "connect_timeout": 60.0,
             "read_timeout": 10.0
         },
+        // Example config for Azure OpenAI (see https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions)
+        "azure": {
+            "base_url": "https://your-resource-name.openai.azure.com/",
+            "api_key": "your-api-key",
+            "connect_timeout": 60.0,
+            "read_timeout": 10.0,
+            "extra": {
+                "azure_deployment": "your-azure-deployment-id",
+                "api_version": "2024-02-01"
+            }
+        }
     },
     // Each agent can use a different model or configuration. The default, as before, is GPT4 Turbo
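To close the loop: the values in the "azure" block above are exactly what `AzureClient._init_client()` feeds to the SDK, with "base_url" becoming `azure_endpoint` and the two "extra" keys becoming `azure_deployment` and `api_version`. Roughly:

    from openai import AsyncAzureOpenAI

    # Placeholder values copied from the example config above.
    client = AsyncAzureOpenAI(
        api_key="your-api-key",
        azure_endpoint="https://your-resource-name.openai.azure.com/",
        azure_deployment="your-azure-deployment-id",
        api_version="2024-02-01",
    )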