Mirror of https://github.com/crewAIInc/crewAI.git
feat: Add Azure AD token authentication support for Azure provider
This commit adds support for Azure AD token authentication (Microsoft Entra ID) to the Azure AI Inference native provider, addressing issue #4069.

Changes:
- Add credential parameter for passing TokenCredential directly
- Add azure_ad_token parameter and AZURE_AD_TOKEN env var support
- Add use_default_credential flag for DefaultAzureCredential
- Add _StaticTokenCredential class for wrapping static tokens
- Add _select_credential method with clear priority order
- Update error messages to reflect all authentication options
- Add comprehensive tests for all new authentication methods

Authentication Priority:
1. credential parameter (explicit TokenCredential)
2. azure_ad_token parameter or AZURE_AD_TOKEN env var
3. api_key parameter or AZURE_API_KEY env var
4. use_default_credential=True (DefaultAzureCredential)

Fixes #4069

Co-Authored-By: João <joao@crewai.com>
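The new options are used roughly as sketched below, mirroring the docstring examples added in this diff. Endpoint values are placeholders, and the LLM factory (from crewai import LLM) is assumed to route azure/... models to the AzureCompletion provider:

    import os

    from azure.identity import ManagedIdentityCredential
    from crewai import LLM  # assumed public import for the LLM factory

    # 1. Explicit TokenCredential (highest priority)
    llm = LLM(model="azure/gpt-4", endpoint="https://example.openai.azure.com",
              credential=ManagedIdentityCredential())

    # 2. Pre-fetched Azure AD (Microsoft Entra ID) token
    os.environ["AZURE_AD_TOKEN"] = "<pre-fetched bearer token>"
    llm = LLM(model="azure/gpt-4", endpoint="https://example.openai.azure.com")

    # 3. API key (existing behavior, unchanged)
    llm = LLM(model="azure/gpt-4", api_key="<api key>",
              endpoint="https://example.openai.azure.com")

    # 4. DefaultAzureCredential (Managed Identity, Azure CLI, etc.)
    llm = LLM(model="azure/gpt-4", endpoint="https://example.openai.azure.com",
              use_default_credential=True)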
@@ -3,6 +3,7 @@ from __future__ import annotations
 import json
 import logging
 import os
+import time
 from typing import TYPE_CHECKING, Any, TypedDict
 
 from pydantic import BaseModel
@@ -17,6 +18,8 @@ from crewai.utilities.types import LLMMessage
 
 
 if TYPE_CHECKING:
+    from azure.core.credentials import AccessToken, TokenCredential
+
     from crewai.llms.hooks.base import BaseInterceptor
 
 
@@ -51,6 +54,39 @@ except ImportError:
     ) from None
 
 
+class _StaticTokenCredential:
+    """A simple TokenCredential implementation for static Azure AD tokens.
+
+    This class wraps a static token string and provides it as a TokenCredential
+    that can be used with Azure SDK clients. The token is assumed to be valid
+    and the user is responsible for token rotation.
+    """
+
+    def __init__(self, token: str) -> None:
+        """Initialize with a static token.
+
+        Args:
+            token: The Azure AD bearer token string.
+        """
+        self._token = token
+
+    def get_token(
+        self, *scopes: str, **kwargs: Any
+    ) -> AccessToken:
+        """Get the static token as an AccessToken.
+
+        Args:
+            *scopes: Token scopes (ignored for static tokens).
+            **kwargs: Additional arguments (ignored).
+
+        Returns:
+            AccessToken with the static token and an expiry one hour in the future.
+        """
+        from azure.core.credentials import AccessToken
+
+        return AccessToken(self._token, int(time.time()) + 3600)
+
+
 class AzureCompletionParams(TypedDict, total=False):
     """Type definition for Azure chat completion parameters."""
 
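As a rough illustration (not part of the diff), the wrapper above duck-types azure-core's TokenCredential protocol, so SDK clients can call get_token() on it like any other credential; the token value below is a placeholder and is returned as-is with an expiry roughly one hour out:

    from crewai.llms.providers.azure.completion import _StaticTokenCredential

    cred = _StaticTokenCredential("<pre-fetched bearer token>")  # placeholder token
    access = cred.get_token("https://cognitiveservices.azure.com/.default")
    assert access.token == "<pre-fetched bearer token>"
    print(access.expires_on)  # roughly int(time.time()) + 3600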
@@ -92,6 +128,9 @@ class AzureCompletion(BaseLLM):
         stop: list[str] | None = None,
         stream: bool = False,
         interceptor: BaseInterceptor[Any, Any] | None = None,
+        credential: TokenCredential | None = None,
+        azure_ad_token: str | None = None,
+        use_default_credential: bool = False,
         **kwargs: Any,
     ):
         """Initialize Azure AI Inference chat completion client.
@@ -111,7 +150,36 @@ class AzureCompletion(BaseLLM):
             stop: Stop sequences
             stream: Enable streaming responses
             interceptor: HTTP interceptor (not yet supported for Azure).
+            credential: Azure TokenCredential for Azure AD authentication (e.g.,
+                DefaultAzureCredential, ManagedIdentityCredential). Takes precedence
+                over other authentication methods.
+            azure_ad_token: Static Azure AD token string (defaults to AZURE_AD_TOKEN
+                env var). Use this for scenarios where you have a pre-fetched token.
+            use_default_credential: If True, automatically use DefaultAzureCredential
+                for Azure AD authentication. Requires azure-identity package.
             **kwargs: Additional parameters
+
+        Authentication Priority:
+            1. credential parameter (explicit TokenCredential)
+            2. azure_ad_token parameter or AZURE_AD_TOKEN env var
+            3. api_key parameter or AZURE_API_KEY env var
+            4. use_default_credential=True (DefaultAzureCredential)
+
+        Example:
+            # Using API key (existing behavior)
+            llm = LLM(model="azure/gpt-4", api_key="...", endpoint="...")
+
+            # Using Azure AD token from environment
+            os.environ["AZURE_AD_TOKEN"] = token_provider()
+            llm = LLM(model="azure/gpt-4", endpoint="...")
+
+            # Using DefaultAzureCredential (Managed Identity, Azure CLI, etc.)
+            llm = LLM(model="azure/gpt-4", endpoint="...", use_default_credential=True)
+
+            # Using explicit TokenCredential
+            from azure.identity import ManagedIdentityCredential
+            llm = LLM(model="azure/gpt-4", endpoint="...",
+                      credential=ManagedIdentityCredential())
         """
         if interceptor is not None:
             raise NotImplementedError(
@@ -124,6 +192,9 @@ class AzureCompletion(BaseLLM):
             )
 
         self.api_key = api_key or os.getenv("AZURE_API_KEY")
+        self.azure_ad_token = azure_ad_token or os.getenv("AZURE_AD_TOKEN")
+        self._explicit_credential = credential
+        self.use_default_credential = use_default_credential
         self.endpoint = (
             endpoint
             or os.getenv("AZURE_ENDPOINT")
@@ -134,10 +205,6 @@ class AzureCompletion(BaseLLM):
         self.timeout = timeout
         self.max_retries = max_retries
 
-        if not self.api_key:
-            raise ValueError(
-                "Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
-            )
         if not self.endpoint:
             raise ValueError(
                 "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
@@ -146,19 +213,22 @@ class AzureCompletion(BaseLLM):
         # Validate and potentially fix Azure OpenAI endpoint URL
         self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
 
+        # Select credential based on priority
+        selected_credential = self._select_credential()
+
         # Build client kwargs
-        client_kwargs = {
+        client_kwargs: dict[str, Any] = {
             "endpoint": self.endpoint,
-            "credential": AzureKeyCredential(self.api_key),
+            "credential": selected_credential,
         }
 
         # Add api_version if specified (primarily for Azure OpenAI endpoints)
         if self.api_version:
             client_kwargs["api_version"] = self.api_version
 
-        self.client = ChatCompletionsClient(**client_kwargs)  # type: ignore[arg-type]
+        self.client = ChatCompletionsClient(**client_kwargs)
 
-        self.async_client = AsyncChatCompletionsClient(**client_kwargs)  # type: ignore[arg-type]
+        self.async_client = AsyncChatCompletionsClient(**client_kwargs)
 
         self.top_p = top_p
         self.frequency_penalty = frequency_penalty
@@ -175,6 +245,47 @@ class AzureCompletion(BaseLLM):
             and "/openai/deployments/" in self.endpoint
         )
 
+    def _select_credential(self) -> AzureKeyCredential | TokenCredential:
+        """Select the appropriate credential based on configuration priority.
+
+        Priority order:
+            1. Explicit credential parameter (TokenCredential)
+            2. azure_ad_token parameter or AZURE_AD_TOKEN env var
+            3. api_key parameter or AZURE_API_KEY env var
+            4. use_default_credential=True (DefaultAzureCredential)
+
+        Returns:
+            The selected credential for Azure authentication.
+
+        Raises:
+            ValueError: If no valid credentials are configured.
+        """
+        if self._explicit_credential is not None:
+            return self._explicit_credential
+
+        if self.azure_ad_token:
+            return _StaticTokenCredential(self.azure_ad_token)
+
+        if self.api_key:
+            return AzureKeyCredential(self.api_key)
+
+        if self.use_default_credential:
+            try:
+                from azure.identity import DefaultAzureCredential
+
+                return DefaultAzureCredential()
+            except ImportError:
+                raise ImportError(
+                    "azure-identity package is required for use_default_credential=True. "
+                    'Install it with: uv add "azure-identity"'
+                ) from None
+
+        raise ValueError(
+            "Azure credentials are required. Provide one of: "
+            "api_key / AZURE_API_KEY, azure_ad_token / AZURE_AD_TOKEN, "
+            "a TokenCredential via 'credential', or set use_default_credential=True."
+        )
+
     @staticmethod
     def _validate_and_fix_endpoint(endpoint: str, model: str) -> str:
         """Validate and fix Azure endpoint URL format.
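For the issue #4069 scenario, the AZURE_AD_TOKEN value would typically be fetched with azure-identity before constructing the LLM. A minimal sketch, assuming azure-identity is installed and that the Cognitive Services scope used in the tests below is appropriate for the target endpoint:

    import os

    from azure.identity import DefaultAzureCredential

    # Fetch a Microsoft Entra ID token and hand it to the provider as a static token.
    # _StaticTokenCredential does not refresh it, so re-fetch before it expires.
    access = DefaultAzureCredential().get_token(
        "https://cognitiveservices.azure.com/.default"
    )
    os.environ["AZURE_AD_TOKEN"] = access.token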
@@ -389,12 +389,12 @@ def test_azure_raises_error_when_endpoint_missing():
 
 
 def test_azure_raises_error_when_api_key_missing():
-    """Test that AzureCompletion raises ValueError when API key is missing"""
+    """Test that AzureCompletion raises ValueError when no credentials are provided"""
     from crewai.llms.providers.azure.completion import AzureCompletion
 
     # Clear environment variables
     with patch.dict(os.environ, {}, clear=True):
-        with pytest.raises(ValueError, match="Azure API key is required"):
+        with pytest.raises(ValueError, match="Azure credentials are required"):
             AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")
 
 
@@ -1127,3 +1127,288 @@ def test_azure_streaming_returns_usage_metrics():
     assert result.token_usage.prompt_tokens > 0
     assert result.token_usage.completion_tokens > 0
     assert result.token_usage.successful_requests >= 1
+
+
+def test_azure_ad_token_authentication():
+    """
+    Test that Azure AD token authentication works via AZURE_AD_TOKEN env var.
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion, _StaticTokenCredential
+
+    with patch.dict(os.environ, {
+        "AZURE_AD_TOKEN": "test-ad-token",
+        "AZURE_ENDPOINT": "https://test.openai.azure.com"
+    }, clear=True):
+        llm = LLM(model="azure/gpt-4")
+
+        assert isinstance(llm, AzureCompletion)
+        assert llm.azure_ad_token == "test-ad-token"
+        assert llm.api_key is None
+
+
+def test_azure_ad_token_parameter():
+    """
+    Test that azure_ad_token parameter works for Azure AD authentication.
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion
+
+    llm = LLM(
+        model="azure/gpt-4",
+        azure_ad_token="my-ad-token",
+        endpoint="https://test.openai.azure.com"
+    )
+
+    assert isinstance(llm, AzureCompletion)
+    assert llm.azure_ad_token == "my-ad-token"
+
+
+def test_azure_credential_parameter():
+    """
+    Test that credential parameter works for passing TokenCredential directly.
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion
+
+    class MockTokenCredential:
+        def get_token(self, *scopes, **kwargs):
+            from azure.core.credentials import AccessToken
+            return AccessToken("mock-token", 9999999999)
+
+    mock_credential = MockTokenCredential()
+
+    llm = LLM(
+        model="azure/gpt-4",
+        credential=mock_credential,
+        endpoint="https://test.openai.azure.com"
+    )
+
+    assert isinstance(llm, AzureCompletion)
+    assert llm._explicit_credential is mock_credential
+
+
+def test_azure_use_default_credential():
+    """
+    Test that use_default_credential=True uses DefaultAzureCredential.
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion
+
+    try:
+        from azure.identity import DefaultAzureCredential
+        azure_identity_available = True
+    except ImportError:
+        azure_identity_available = False
+
+    if azure_identity_available:
+        with patch('azure.identity.DefaultAzureCredential') as mock_default_cred:
+            mock_default_cred.return_value = MagicMock()
+
+            with patch.dict(os.environ, {
+                "AZURE_ENDPOINT": "https://test.openai.azure.com"
+            }, clear=True):
+                llm = LLM(
+                    model="azure/gpt-4",
+                    use_default_credential=True
+                )
+
+                assert isinstance(llm, AzureCompletion)
+                assert llm.use_default_credential is True
+                mock_default_cred.assert_called_once()
+    else:
+        with patch.dict(os.environ, {
+            "AZURE_ENDPOINT": "https://test.openai.azure.com"
+        }, clear=True):
+            with pytest.raises(ImportError, match="azure-identity package is required"):
+                LLM(
+                    model="azure/gpt-4",
+                    use_default_credential=True
+                )
+
+
+def test_azure_credential_priority_explicit_credential_first():
+    """
+    Test that explicit credential takes priority over other auth methods.
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion
+
+    class MockTokenCredential:
+        def get_token(self, *scopes, **kwargs):
+            from azure.core.credentials import AccessToken
+            return AccessToken("mock-token", 9999999999)
+
+    mock_credential = MockTokenCredential()
+
+    with patch.dict(os.environ, {
+        "AZURE_API_KEY": "test-key",
+        "AZURE_AD_TOKEN": "test-ad-token",
+        "AZURE_ENDPOINT": "https://test.openai.azure.com"
+    }):
+        llm = LLM(
+            model="azure/gpt-4",
+            credential=mock_credential,
+            api_key="another-key",
+            azure_ad_token="another-token"
+        )
+
+        assert isinstance(llm, AzureCompletion)
+        assert llm._explicit_credential is mock_credential
+
+
+def test_azure_credential_priority_ad_token_over_api_key():
+    """
+    Test that azure_ad_token takes priority over api_key.
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion, _StaticTokenCredential
+
+    with patch.dict(os.environ, {
+        "AZURE_ENDPOINT": "https://test.openai.azure.com"
+    }, clear=True):
+        llm = LLM(
+            model="azure/gpt-4",
+            azure_ad_token="my-ad-token",
+            api_key="my-api-key"
+        )
+
+        assert isinstance(llm, AzureCompletion)
+        assert llm.azure_ad_token == "my-ad-token"
+        assert llm.api_key == "my-api-key"
+
+
+def test_azure_raises_error_when_no_credentials():
+    """
+    Test that AzureCompletion raises ValueError when no credentials are provided.
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion
+
+    with patch.dict(os.environ, {
+        "AZURE_ENDPOINT": "https://test.openai.azure.com"
+    }, clear=True):
+        with pytest.raises(ValueError, match="Azure credentials are required"):
+            AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")
+
+
+def test_azure_static_token_credential():
+    """
+    Test that _StaticTokenCredential properly wraps a static token.
+    """
+    from crewai.llms.providers.azure.completion import _StaticTokenCredential
+    from azure.core.credentials import AccessToken
+
+    token = "my-static-token"
+    credential = _StaticTokenCredential(token)
+
+    access_token = credential.get_token("https://cognitiveservices.azure.com/.default")
+
+    assert isinstance(access_token, AccessToken)
+    assert access_token.token == token
+    assert access_token.expires_on > 0
+
+
+def test_azure_ad_token_env_var_used_when_no_api_key():
+    """
+    Test that AZURE_AD_TOKEN env var is used when AZURE_API_KEY is not set.
+    This reproduces the scenario from GitHub issue #4069.
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion
+
+    with patch.dict(os.environ, {
+        "AZURE_AD_TOKEN": "token-from-provider",
+        "AZURE_ENDPOINT": "https://my-endpoint.openai.azure.com"
+    }, clear=True):
+        llm = LLM(
+            model="azure/gpt-4o-mini",
+            api_version="2024-02-01"
+        )
+
+        assert isinstance(llm, AzureCompletion)
+        assert llm.azure_ad_token == "token-from-provider"
+        assert llm.api_key is None
+
+
+def test_azure_backward_compatibility_api_key():
+    """
+    Test that existing API key authentication still works (backward compatibility).
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion
+    from azure.core.credentials import AzureKeyCredential
+
+    with patch.dict(os.environ, {
+        "AZURE_API_KEY": "test-api-key",
+        "AZURE_ENDPOINT": "https://test.openai.azure.com"
+    }, clear=True):
+        llm = LLM(model="azure/gpt-4")
+
+        assert isinstance(llm, AzureCompletion)
+        assert llm.api_key == "test-api-key"
+        assert llm.azure_ad_token is None
+
+
+def test_azure_select_credential_returns_correct_type():
+    """
+    Test that _select_credential returns the correct credential type based on config.
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion, _StaticTokenCredential
+    from azure.core.credentials import AzureKeyCredential
+
+    with patch.dict(os.environ, {
+        "AZURE_ENDPOINT": "https://test.openai.azure.com"
+    }, clear=True):
+        llm_api_key = AzureCompletion(
+            model="gpt-4",
+            api_key="test-key",
+            endpoint="https://test.openai.azure.com"
+        )
+        credential = llm_api_key._select_credential()
+        assert isinstance(credential, AzureKeyCredential)
+
+        llm_ad_token = AzureCompletion(
+            model="gpt-4",
+            azure_ad_token="test-ad-token",
+            endpoint="https://test.openai.azure.com"
+        )
+        credential = llm_ad_token._select_credential()
+        assert isinstance(credential, _StaticTokenCredential)
+
+
+def test_azure_use_default_credential_import_error():
+    """
+    Test that use_default_credential raises ImportError when azure-identity is not available.
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion
+    import builtins
+
+    original_import = builtins.__import__
+
+    def mock_import(name, *args, **kwargs):
+        if name == 'azure.identity':
+            raise ImportError("No module named 'azure.identity'")
+        return original_import(name, *args, **kwargs)
+
+    with patch.dict(os.environ, {
+        "AZURE_ENDPOINT": "https://test.openai.azure.com"
+    }, clear=True):
+        with patch.object(builtins, '__import__', side_effect=mock_import):
+            with pytest.raises(ImportError, match="azure-identity package is required"):
+                AzureCompletion(
+                    model="gpt-4",
+                    endpoint="https://test.openai.azure.com",
+                    use_default_credential=True
+                )
+
+
+def test_azure_improved_error_message_no_credentials():
+    """
+    Test that the error message when no credentials are provided is helpful.
+    """
+    from crewai.llms.providers.azure.completion import AzureCompletion
+
+    with patch.dict(os.environ, {}, clear=True):
+        with pytest.raises(ValueError) as excinfo:
+            AzureCompletion(model="gpt-4", endpoint="https://test.openai.azure.com")
+
+        error_message = str(excinfo.value)
+        assert "Azure credentials are required" in error_message
+        assert "api_key" in error_message
+        assert "AZURE_API_KEY" in error_message
+        assert "azure_ad_token" in error_message
+        assert "AZURE_AD_TOKEN" in error_message
+        assert "credential" in error_message
+        assert "use_default_credential" in error_message