Mirror of https://github.com/Pythagora-io/gpt-pilot.git
add first-class support for Azure OpenAI
core/config.py
@@ -1,6 +1,6 @@
 from enum import Enum
 from os.path import abspath, dirname, isdir, join
-from typing import Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator
 from typing_extensions import Annotated
@@ -55,6 +55,7 @@ class LLMProvider(str, Enum):
     ANTHROPIC = "anthropic"
     GROQ = "groq"
     LM_STUDIO = "lm-studio"
+    AZURE = "azure"
 
 
 class UIAdapter(str, Enum):
@@ -89,6 +90,10 @@ class ProviderConfig(_StrictModel):
         description="Timeout (in seconds) for receiving a new chunk of data from the response stream",
         ge=0.0,
     )
+    extra: Optional[dict[str, Any]] = Field(
+        None,
+        description="Extra provider-specific configuration",
+    )
 
 
 class AgentLLMConfig(_StrictModel):
@@ -140,6 +145,10 @@ class LLMConfig(_StrictModel):
         description="Timeout (in seconds) for receiving a new chunk of data from the response stream",
         ge=0.0,
     )
+    extra: Optional[dict[str, Any]] = Field(
+        None,
+        description="Extra provider-specific configuration",
+    )
 
     @classmethod
     def from_provider_and_agent_configs(cls, provider: ProviderConfig, agent: AgentLLMConfig):
@@ -151,6 +160,7 @@ class LLMConfig(_StrictModel):
             temperature=agent.temperature,
             connect_timeout=provider.connect_timeout,
             read_timeout=provider.read_timeout,
+            extra=provider.extra,
         )
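The `extra` field added to both config models is a plain optional dict, so provider-specific settings (such as Azure's deployment name) get one namespaced home without loosening the strict models. A minimal self-contained sketch of the pattern; `_StrictModel` is reconstructed from its name and the other fields are abridged:

from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, Field


class _StrictModel(BaseModel):
    # Assumption: the real base class forbids unknown top-level keys.
    model_config = ConfigDict(extra="forbid")


class ProviderConfig(_StrictModel):
    connect_timeout: float = Field(60.0, ge=0.0)
    read_timeout: float = Field(10.0, ge=0.0)
    # New in this commit: provider-specific keys go here, not top-level.
    extra: Optional[dict[str, Any]] = Field(
        None,
        description="Extra provider-specific configuration",
    )


cfg = ProviderConfig(extra={"azure_deployment": "my-gpt4", "api_version": "2024-02-01"})
print(cfg.extra["azure_deployment"])  # my-gpt4
print(ProviderConfig().extra)         # None -- the field is optional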
core/llm/azure_client.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+from httpx import Timeout
+from openai import AsyncAzureOpenAI
+
+from core.config import LLMProvider
+from core.llm.openai_client import OpenAIClient
+from core.log import get_logger
+
+log = get_logger(__name__)
+
+
+class AzureClient(OpenAIClient):
+    provider = LLMProvider.AZURE
+    stream_options = None
+
+    def _init_client(self):
+        azure_deployment = self.config.extra.get("azure_deployment")
+        api_version = self.config.extra.get("api_version")
+
+        self.client = AsyncAzureOpenAI(
+            api_key=self.config.api_key,
+            azure_endpoint=self.config.base_url,
+            azure_deployment=azure_deployment,
+            api_version=api_version,
+            timeout=Timeout(
+                max(self.config.connect_timeout, self.config.read_timeout),
+                connect=self.config.connect_timeout,
+                read=self.config.read_timeout,
+            ),
+        )
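One subtlety worth noting: httpx's Timeout uses its first positional argument as the default for any timeout not named explicitly, so `max(connect, read)` above becomes the write and pool timeouts. A quick check of that behavior:

from httpx import Timeout

connect_timeout, read_timeout = 60.0, 10.0

# The first positional arg fills in whatever isn't named (write, pool here).
t = Timeout(
    max(connect_timeout, read_timeout),
    connect=connect_timeout,
    read=read_timeout,
)
print(t.connect, t.read, t.write, t.pool)  # 60.0 10.0 60.0 60.0

Also note that `_init_client` calls `self.config.extra.get(...)` unconditionally; since `extra` defaults to None in the config models, an "azure" provider entry without an `extra` block would raise AttributeError, so the example config always includes one.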
core/llm/base.py
@@ -316,6 +316,7 @@ class BaseLLMClient:
         :return: Client class for the specified provider.
         """
         from .anthropic_client import AnthropicClient
+        from .azure_client import AzureClient
         from .groq_client import GroqClient
         from .openai_client import OpenAIClient
 
@@ -325,6 +326,8 @@ class BaseLLMClient:
             return AnthropicClient
         elif provider == LLMProvider.GROQ:
             return GroqClient
+        elif provider == LLMProvider.AZURE:
+            return AzureClient
         else:
             raise ValueError(f"Unsupported LLM provider: {provider.value}")
 
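The dispatch itself is a plain enum-to-class mapping with imports deferred into the method body (visible in the hunk above), so a provider's SDK is only imported when that provider is selected. A stubbed, self-contained sketch; the class and function names below are placeholders, since only the method body is visible in these hunks:

from enum import Enum
from typing import Type


class LLMProvider(str, Enum):
    OPENAI = "openai"
    AZURE = "azure"


class OpenAIClientStub: ...


class AzureClientStub(OpenAIClientStub): ...


def client_class_for(provider: LLMProvider) -> Type[OpenAIClientStub]:
    """Placeholder for the BaseLLMClient classmethod shown above."""
    if provider == LLMProvider.OPENAI:
        return OpenAIClientStub
    elif provider == LLMProvider.AZURE:
        return AzureClientStub
    else:
        raise ValueError(f"Unsupported LLM provider: {provider.value}")


assert client_class_for(LLMProvider.AZURE) is AzureClientStub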
core/llm/openai_client.py
@@ -17,6 +17,7 @@ tokenizer = tiktoken.get_encoding("cl100k_base")
 
 class OpenAIClient(BaseLLMClient):
     provider = LLMProvider.OPENAI
+    stream_options = {"include_usage": True}
 
     def _init_client(self):
         self.client = AsyncOpenAI(
@@ -40,10 +41,10 @@ class OpenAIClient(BaseLLMClient):
             "messages": convo.messages,
             "temperature": self.config.temperature if temperature is None else temperature,
             "stream": True,
-            "stream_options": {
-                "include_usage": True,
-            },
         }
+        if self.stream_options:
+            completion_kwargs["stream_options"] = self.stream_options
 
         if json_mode:
             completion_kwargs["response_format"] = {"type": "json_object"}
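This refactor is what makes the Azure subclass work: `stream_options` moves from a hard-coded request kwarg to a class attribute, so `AzureClient` can opt out by setting it to None (presumably because Azure's chat completions endpoint rejected the parameter at the time). A condensed, self-contained sketch of the override pattern, with stubbed class names:

from typing import Optional


class OpenAIStyleClient:
    # Class attribute: subclasses override it to change streaming kwargs.
    stream_options: Optional[dict] = {"include_usage": True}

    def build_kwargs(self) -> dict:
        kwargs = {"stream": True}
        # Only sent when the (sub)class defines one -- mirrors the diff above.
        if self.stream_options:
            kwargs["stream_options"] = self.stream_options
        return kwargs


class AzureStyleClient(OpenAIStyleClient):
    stream_options = None  # omit the parameter entirely for Azure


print(OpenAIStyleClient().build_kwargs())  # includes stream_options
print(AzureStyleClient().build_kwargs())   # {'stream': True}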
example-config.json
@@ -1,6 +1,6 @@
 {
   // Configuration for the LLM providers that can be used. Pythagora supports
-  // OpenAI, Anthropic and Groq. Azure and OpenRouter and local LLMs (such as LM-Studio)
+  // OpenAI, Azure, Anthropic and Groq. OpenRouter and local LLMs (such as LM-Studio)
   // also work, you can use "openai" provider to define these.
   "llm": {
     "openai": {
@@ -9,6 +9,17 @@
       "api_key": null,
       "connect_timeout": 60.0,
       "read_timeout": 10.0
+    },
+    // Example config for Azure OpenAI (see https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions)
+    "azure": {
+      "base_url": "https://your-resource-name.openai.azure.com/",
+      "api_key": "your-api-key",
+      "connect_timeout": 60.0,
+      "read_timeout": 10.0,
+      "extra": {
+        "azure_deployment": "your-azure-deployment-id",
+        "api_version": "2024-02-01"
+      }
     }
   },
   // Each agent can use a different model or configuration. The default, as before, is GPT4 Turbo
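Since the example config uses // comments, it is JSON-with-comments rather than strict JSON, and json.loads alone will reject it. A naive illustration of reading it (the loader gpt-pilot actually uses is not part of this diff):

import json
import re


def load_jsonc(path: str) -> dict:
    """Naive JSON-with-comments reader, for illustration only.

    Strips whole-line // comments and would break on "//" appearing at
    the start of a line inside a string; gpt-pilot's real loader is not
    shown in this commit.
    """
    with open(path, encoding="utf-8") as f:
        text = f.read()
    text = re.sub(r"^\s*//.*$", "", text, flags=re.MULTILINE)
    return json.loads(text)


config = load_jsonc("example-config.json")
azure = config["llm"]["azure"]
# These two keys reach AsyncAzureOpenAI via config.extra in AzureClient:
print(azure["extra"]["azure_deployment"], azure["extra"]["api_version"])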