mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 21:54:10 -05:00
fix: use huggingface_hub InferenceClient for HuggingFace embeddings
Fixes #4145 The HuggingFace embedder was failing with 'could not convert string to float: error' because chromadb's HuggingFaceEmbeddingFunction uses the deprecated api-inference.huggingface.co endpoint which returns error messages instead of embeddings. This fix creates a custom HuggingFaceEmbeddingFunction that uses huggingface_hub's InferenceClient with provider='hf-inference' instead of the deprecated endpoint. Changes: - Add custom embedding_callable.py using huggingface_hub.InferenceClient - Update HuggingFaceProvider to use the new embedding callable - Handle different embedding response formats (1D, 2D, 3D arrays) - Add comprehensive error handling with actionable error messages - Add 16 test cases covering initialization, embedding generation, error handling, and ChromaDB integration Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
"""HuggingFace embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.huggingface.embedding_callable import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
|
||||
HuggingFaceProvider,
|
||||
)
|
||||
@@ -10,6 +13,7 @@ from crewai.rag.embeddings.providers.huggingface.types import (
|
||||
|
||||
|
||||
__all__ = [
|
||||
"HuggingFaceEmbeddingFunction",
|
||||
"HuggingFaceProvider",
|
||||
"HuggingFaceProviderConfig",
|
||||
"HuggingFaceProviderSpec",
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
"""HuggingFace embedding function implementation using huggingface_hub."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||
import numpy as np
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderConfig
|
||||
|
||||
|
||||
class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"""Embedding function for HuggingFace models using the Inference API.
|
||||
|
||||
This implementation uses huggingface_hub's InferenceClient instead of the
|
||||
deprecated api-inference.huggingface.co endpoint that chromadb uses.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Unpack[HuggingFaceProviderConfig]) -> None:
|
||||
"""Initialize HuggingFace embedding function.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration parameters for HuggingFace.
|
||||
- api_key: HuggingFace API key (optional for public models)
|
||||
- model_name: Model name to use for embeddings
|
||||
"""
|
||||
try:
|
||||
from huggingface_hub import InferenceClient
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"huggingface_hub is required for HuggingFace embeddings. "
|
||||
"Install it with: uv add huggingface_hub"
|
||||
) from e
|
||||
|
||||
self._config = kwargs
|
||||
self._model_name = kwargs.get(
|
||||
"model_name", "sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
api_key = kwargs.get("api_key")
|
||||
|
||||
self._client = InferenceClient(
|
||||
provider="hf-inference",
|
||||
token=api_key,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||
return "huggingface"
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
"""Generate embeddings for input documents.
|
||||
|
||||
Args:
|
||||
input: List of documents to embed.
|
||||
|
||||
Returns:
|
||||
List of embedding vectors.
|
||||
|
||||
Raises:
|
||||
ValueError: If the API returns an error or unexpected response format.
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
embeddings: list[list[float]] = []
|
||||
|
||||
for text in input:
|
||||
embedding = self._get_embedding_for_text(text)
|
||||
embeddings.append(embedding)
|
||||
|
||||
return embeddings
|
||||
|
||||
def _get_embedding_for_text(self, text: str) -> list[float]:
|
||||
"""Get embedding for a single text.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
The embedding vector.
|
||||
|
||||
Raises:
|
||||
ValueError: If the API returns an error.
|
||||
"""
|
||||
try:
|
||||
result = self._client.feature_extraction(
|
||||
text=text,
|
||||
model=self._model_name,
|
||||
)
|
||||
|
||||
# Handle different response formats
|
||||
return self._process_embedding_result(result)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# Provide more helpful error messages for common issues
|
||||
if "deprecated" in error_msg.lower() or "no longer supported" in error_msg.lower():
|
||||
raise ValueError(
|
||||
f"HuggingFace API endpoint error: {error_msg}. "
|
||||
"Please ensure you have the latest version of huggingface_hub installed."
|
||||
) from e
|
||||
if "unauthorized" in error_msg.lower() or "401" in error_msg:
|
||||
raise ValueError(
|
||||
f"HuggingFace API authentication error: {error_msg}. "
|
||||
"Please check your API key configuration."
|
||||
) from e
|
||||
if "not found" in error_msg.lower() or "404" in error_msg:
|
||||
raise ValueError(
|
||||
f"HuggingFace model not found: {error_msg}. "
|
||||
f"Please verify the model name '{self._model_name}' is correct "
|
||||
"and supports feature extraction."
|
||||
) from e
|
||||
raise ValueError(f"HuggingFace API error: {error_msg}") from e
|
||||
|
||||
def _process_embedding_result(self, result: Any) -> list[float]:
|
||||
"""Process the embedding result from the API.
|
||||
|
||||
The HuggingFace API can return different formats depending on the model:
|
||||
- 1D array: Direct embedding vector
|
||||
- 2D array: Token-level embeddings (needs pooling)
|
||||
- Nested structure: Various model-specific formats
|
||||
|
||||
Args:
|
||||
result: The raw result from the API.
|
||||
|
||||
Returns:
|
||||
A 1D list of floats representing the embedding.
|
||||
|
||||
Raises:
|
||||
ValueError: If the result format is unexpected.
|
||||
"""
|
||||
# Convert to numpy array for easier processing
|
||||
arr = np.array(result)
|
||||
|
||||
# Handle different dimensionalities
|
||||
if arr.ndim == 1:
|
||||
# Already a 1D embedding vector
|
||||
return arr.astype(np.float32).tolist()
|
||||
if arr.ndim == 2:
|
||||
# Token-level embeddings - apply mean pooling
|
||||
pooled = np.mean(arr, axis=0)
|
||||
return pooled.astype(np.float32).tolist()
|
||||
if arr.ndim == 3:
|
||||
# Batch of token-level embeddings - take first and apply mean pooling
|
||||
pooled = np.mean(arr[0], axis=0)
|
||||
return pooled.astype(np.float32).tolist()
|
||||
raise ValueError(
|
||||
f"Unexpected embedding result shape: {arr.shape}. "
|
||||
"Expected 1D, 2D, or 3D array."
|
||||
)
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return the configuration for serialization."""
|
||||
return {
|
||||
"model_name": self._model_name,
|
||||
"api_key": self._config.get("api_key"),
|
||||
}
|
||||
@@ -1,11 +1,11 @@
|
||||
"""HuggingFace embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.huggingface.embedding_callable import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]):
|
||||
|
||||
290
lib/crewai/tests/rag/embeddings/test_huggingface_embedder.py
Normal file
290
lib/crewai/tests/rag/embeddings/test_huggingface_embedder.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""Tests for HuggingFace embedding function."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from crewai.rag.embeddings.providers.huggingface.embedding_callable import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class TestHuggingFaceEmbeddingFunction:
|
||||
"""Test HuggingFace embedding function."""
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_initialization_with_api_key(self, mock_client_class):
|
||||
"""Test initialization with API key."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(
|
||||
api_key="test-api-key",
|
||||
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
|
||||
mock_client_class.assert_called_once_with(
|
||||
provider="hf-inference",
|
||||
token="test-api-key",
|
||||
)
|
||||
assert ef._model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_initialization_without_api_key(self, mock_client_class):
|
||||
"""Test initialization without API key (for public models)."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(
|
||||
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
|
||||
mock_client_class.assert_called_once_with(
|
||||
provider="hf-inference",
|
||||
token=None,
|
||||
)
|
||||
assert ef._model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_initialization_with_default_model(self, mock_client_class):
|
||||
"""Test initialization with default model name."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction()
|
||||
|
||||
assert ef._model_name == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_call_with_single_document(self, mock_client_class):
|
||||
"""Test embedding generation for a single document."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock the feature_extraction response (1D embedding)
|
||||
mock_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
mock_client.feature_extraction.return_value = mock_embedding
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
|
||||
result = ef(["Hello, world!"])
|
||||
|
||||
mock_client.feature_extraction.assert_called_once_with(
|
||||
text="Hello, world!",
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0] == pytest.approx(mock_embedding, rel=1e-5)
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_call_with_multiple_documents(self, mock_client_class):
|
||||
"""Test embedding generation for multiple documents."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock the feature_extraction response
|
||||
mock_embedding1 = [0.1, 0.2, 0.3]
|
||||
mock_embedding2 = [0.4, 0.5, 0.6]
|
||||
mock_client.feature_extraction.side_effect = [mock_embedding1, mock_embedding2]
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
|
||||
result = ef(["Hello", "World"])
|
||||
|
||||
assert mock_client.feature_extraction.call_count == 2
|
||||
assert len(result) == 2
|
||||
assert result[0] == pytest.approx(mock_embedding1, rel=1e-5)
|
||||
assert result[1] == pytest.approx(mock_embedding2, rel=1e-5)
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_call_with_string_input(self, mock_client_class):
|
||||
"""Test that string input is converted to list."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_embedding = [0.1, 0.2, 0.3]
|
||||
mock_client.feature_extraction.return_value = mock_embedding
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
|
||||
result = ef("Hello") # type: ignore[arg-type]
|
||||
|
||||
mock_client.feature_extraction.assert_called_once()
|
||||
assert len(result) == 1
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_process_2d_embedding_result(self, mock_client_class):
|
||||
"""Test processing of 2D token-level embeddings (mean pooling)."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock 2D token-level embeddings (3 tokens, 4 dimensions each)
|
||||
mock_token_embeddings = [
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
[0.2, 0.3, 0.4, 0.5],
|
||||
[0.3, 0.4, 0.5, 0.6],
|
||||
]
|
||||
mock_client.feature_extraction.return_value = mock_token_embeddings
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
|
||||
result = ef(["Hello"])
|
||||
|
||||
# Expected: mean pooling across tokens
|
||||
expected = np.mean(mock_token_embeddings, axis=0).tolist()
|
||||
assert len(result) == 1
|
||||
assert result[0] == pytest.approx(expected, rel=1e-5)
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_process_3d_embedding_result(self, mock_client_class):
|
||||
"""Test processing of 3D batch token-level embeddings."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock 3D embeddings (1 batch, 3 tokens, 4 dimensions)
|
||||
mock_batch_embeddings = [
|
||||
[
|
||||
[0.1, 0.2, 0.3, 0.4],
|
||||
[0.2, 0.3, 0.4, 0.5],
|
||||
[0.3, 0.4, 0.5, 0.6],
|
||||
]
|
||||
]
|
||||
mock_client.feature_extraction.return_value = mock_batch_embeddings
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
|
||||
result = ef(["Hello"])
|
||||
|
||||
# Expected: take first batch, then mean pooling
|
||||
expected = np.mean(mock_batch_embeddings[0], axis=0).tolist()
|
||||
assert len(result) == 1
|
||||
assert result[0] == pytest.approx(expected, rel=1e-5)
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_error_handling_deprecated_endpoint(self, mock_client_class):
|
||||
"""Test error handling for deprecated endpoint error."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_client.feature_extraction.side_effect = Exception(
|
||||
"https://api-inference.huggingface.co is no longer supported"
|
||||
)
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
|
||||
|
||||
with pytest.raises(ValueError, match="HuggingFace API endpoint error"):
|
||||
ef(["Hello"])
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_error_handling_unauthorized(self, mock_client_class):
|
||||
"""Test error handling for authentication error."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_client.feature_extraction.side_effect = Exception("401 Unauthorized")
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(api_key="invalid-key")
|
||||
|
||||
with pytest.raises(ValueError, match="HuggingFace API authentication error"):
|
||||
ef(["Hello"])
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_error_handling_model_not_found(self, mock_client_class):
|
||||
"""Test error handling for model not found error."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_client.feature_extraction.side_effect = Exception("404 Not Found")
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(
|
||||
api_key="test-key", model_name="nonexistent/model"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="HuggingFace model not found"):
|
||||
ef(["Hello"])
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_error_handling_generic_error(self, mock_client_class):
|
||||
"""Test error handling for generic API error."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_client.feature_extraction.side_effect = Exception("Some unexpected error")
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
|
||||
|
||||
with pytest.raises(ValueError, match="HuggingFace API error"):
|
||||
ef(["Hello"])
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_name_method(self, mock_client_class):
|
||||
"""Test the name() static method."""
|
||||
assert HuggingFaceEmbeddingFunction.name() == "huggingface"
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_get_config(self, mock_client_class):
|
||||
"""Test get_config method."""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(
|
||||
api_key="test-key",
|
||||
model_name="custom/model",
|
||||
)
|
||||
|
||||
config = ef.get_config()
|
||||
assert config["model_name"] == "custom/model"
|
||||
assert config["api_key"] == "test-key"
|
||||
|
||||
|
||||
class TestHuggingFaceEmbeddingFunctionIntegration:
|
||||
"""Integration tests for HuggingFace embedding function with RAGStorage."""
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_embedding_function_works_with_rag_storage_validation(
|
||||
self, mock_client_class
|
||||
):
|
||||
"""Test that the embedding function works with RAGStorage validation.
|
||||
|
||||
This test simulates the validation that happens in RAGStorage.__init__
|
||||
where it calls embedding_function(["test"]) to verify the embedder works.
|
||||
"""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
# Mock a valid embedding response
|
||||
mock_embedding = [0.1] * 384 # 384 dimensions like all-MiniLM-L6-v2
|
||||
mock_client.feature_extraction.return_value = mock_embedding
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
|
||||
|
||||
# This is what RAGStorage does to validate the embedder
|
||||
result = ef(["test"])
|
||||
|
||||
assert len(result) == 1
|
||||
assert len(result[0]) == 384
|
||||
# Values should be numeric (float or numpy float)
|
||||
assert all(isinstance(x, (int, float)) or hasattr(x, "__float__") for x in result[0])
|
||||
|
||||
@patch("huggingface_hub.InferenceClient")
|
||||
def test_embedding_function_returns_correct_format_for_chromadb(
|
||||
self, mock_client_class
|
||||
):
|
||||
"""Test that embeddings are in the correct format for ChromaDB.
|
||||
|
||||
ChromaDB expects embeddings as a sequence of embedding vectors where each
|
||||
inner element is a 1D embedding vector with numeric values.
|
||||
"""
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
mock_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
|
||||
mock_client.feature_extraction.return_value = mock_embedding
|
||||
|
||||
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
|
||||
result = ef(["Hello", "World"])
|
||||
|
||||
# ChromaDB expects a sequence of embedding vectors
|
||||
assert isinstance(result, list)
|
||||
for embedding in result:
|
||||
# Each embedding should be a sequence of numeric values
|
||||
assert len(embedding) == 5
|
||||
for value in embedding:
|
||||
# Values should be numeric (float or numpy float)
|
||||
assert isinstance(value, (int, float)) or hasattr(value, "__float__")
|
||||
Reference in New Issue
Block a user