fix: use huggingface_hub InferenceClient for HuggingFace embeddings

Fixes #4145

The HuggingFace embedder was failing with 'could not convert string to float: error'
because chromadb's HuggingFaceEmbeddingFunction uses the deprecated
api-inference.huggingface.co endpoint which returns error messages instead of embeddings.

This fix creates a custom HuggingFaceEmbeddingFunction that uses huggingface_hub's
InferenceClient with provider='hf-inference' instead of the deprecated endpoint.

Changes:
- Add custom embedding_callable.py using huggingface_hub.InferenceClient
- Update HuggingFaceProvider to use the new embedding callable
- Handle different embedding response formats (1D, 2D, 3D arrays)
- Add comprehensive error handling with actionable error messages
- Add 16 test cases covering initialization, embedding generation, error handling,
  and ChromaDB integration

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-12-22 22:44:13 +00:00
parent be70a04153
commit 9460e5e182
4 changed files with 455 additions and 3 deletions

View File

@@ -1,5 +1,8 @@
"""HuggingFace embedding providers."""
from crewai.rag.embeddings.providers.huggingface.embedding_callable import (
HuggingFaceEmbeddingFunction,
)
from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
HuggingFaceProvider,
)
@@ -10,6 +13,7 @@ from crewai.rag.embeddings.providers.huggingface.types import (
__all__ = [
"HuggingFaceEmbeddingFunction",
"HuggingFaceProvider",
"HuggingFaceProviderConfig",
"HuggingFaceProviderSpec",

View File

@@ -0,0 +1,158 @@
"""HuggingFace embedding function implementation using huggingface_hub."""
from typing import Any
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
import numpy as np
from typing_extensions import Unpack
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderConfig
class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for HuggingFace models using the Inference API.
This implementation uses huggingface_hub's InferenceClient instead of the
deprecated api-inference.huggingface.co endpoint that chromadb uses.
"""
def __init__(self, **kwargs: Unpack[HuggingFaceProviderConfig]) -> None:
"""Initialize HuggingFace embedding function.
Args:
**kwargs: Configuration parameters for HuggingFace.
- api_key: HuggingFace API key (optional for public models)
- model_name: Model name to use for embeddings
"""
try:
from huggingface_hub import InferenceClient
except ImportError as e:
raise ImportError(
"huggingface_hub is required for HuggingFace embeddings. "
"Install it with: uv add huggingface_hub"
) from e
self._config = kwargs
self._model_name = kwargs.get(
"model_name", "sentence-transformers/all-MiniLM-L6-v2"
)
api_key = kwargs.get("api_key")
self._client = InferenceClient(
provider="hf-inference",
token=api_key,
)
@staticmethod
def name() -> str:
"""Return the name of the embedding function for ChromaDB compatibility."""
return "huggingface"
def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for input documents.
Args:
input: List of documents to embed.
Returns:
List of embedding vectors.
Raises:
ValueError: If the API returns an error or unexpected response format.
"""
if isinstance(input, str):
input = [input]
embeddings: list[list[float]] = []
for text in input:
embedding = self._get_embedding_for_text(text)
embeddings.append(embedding)
return embeddings
def _get_embedding_for_text(self, text: str) -> list[float]:
"""Get embedding for a single text.
Args:
text: The text to embed.
Returns:
The embedding vector.
Raises:
ValueError: If the API returns an error.
"""
try:
result = self._client.feature_extraction(
text=text,
model=self._model_name,
)
# Handle different response formats
return self._process_embedding_result(result)
except Exception as e:
error_msg = str(e)
# Provide more helpful error messages for common issues
if "deprecated" in error_msg.lower() or "no longer supported" in error_msg.lower():
raise ValueError(
f"HuggingFace API endpoint error: {error_msg}. "
"Please ensure you have the latest version of huggingface_hub installed."
) from e
if "unauthorized" in error_msg.lower() or "401" in error_msg:
raise ValueError(
f"HuggingFace API authentication error: {error_msg}. "
"Please check your API key configuration."
) from e
if "not found" in error_msg.lower() or "404" in error_msg:
raise ValueError(
f"HuggingFace model not found: {error_msg}. "
f"Please verify the model name '{self._model_name}' is correct "
"and supports feature extraction."
) from e
raise ValueError(f"HuggingFace API error: {error_msg}") from e
def _process_embedding_result(self, result: Any) -> list[float]:
"""Process the embedding result from the API.
The HuggingFace API can return different formats depending on the model:
- 1D array: Direct embedding vector
- 2D array: Token-level embeddings (needs pooling)
- Nested structure: Various model-specific formats
Args:
result: The raw result from the API.
Returns:
A 1D list of floats representing the embedding.
Raises:
ValueError: If the result format is unexpected.
"""
# Convert to numpy array for easier processing
arr = np.array(result)
# Handle different dimensionalities
if arr.ndim == 1:
# Already a 1D embedding vector
return arr.astype(np.float32).tolist()
if arr.ndim == 2:
# Token-level embeddings - apply mean pooling
pooled = np.mean(arr, axis=0)
return pooled.astype(np.float32).tolist()
if arr.ndim == 3:
# Batch of token-level embeddings - take first and apply mean pooling
pooled = np.mean(arr[0], axis=0)
return pooled.astype(np.float32).tolist()
raise ValueError(
f"Unexpected embedding result shape: {arr.shape}. "
"Expected 1D, 2D, or 3D array."
)
def get_config(self) -> dict[str, Any]:
"""Return the configuration for serialization."""
return {
"model_name": self._model_name,
"api_key": self._config.get("api_key"),
}

View File

@@ -1,11 +1,11 @@
"""HuggingFace embeddings provider."""
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
)
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.huggingface.embedding_callable import (
HuggingFaceEmbeddingFunction,
)
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]):

View File

@@ -0,0 +1,290 @@
"""Tests for HuggingFace embedding function."""
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from crewai.rag.embeddings.providers.huggingface.embedding_callable import (
HuggingFaceEmbeddingFunction,
)
class TestHuggingFaceEmbeddingFunction:
"""Test HuggingFace embedding function."""
@patch("huggingface_hub.InferenceClient")
def test_initialization_with_api_key(self, mock_client_class):
"""Test initialization with API key."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
ef = HuggingFaceEmbeddingFunction(
api_key="test-api-key",
model_name="sentence-transformers/all-MiniLM-L6-v2",
)
mock_client_class.assert_called_once_with(
provider="hf-inference",
token="test-api-key",
)
assert ef._model_name == "sentence-transformers/all-MiniLM-L6-v2"
@patch("huggingface_hub.InferenceClient")
def test_initialization_without_api_key(self, mock_client_class):
"""Test initialization without API key (for public models)."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
ef = HuggingFaceEmbeddingFunction(
model_name="sentence-transformers/all-MiniLM-L6-v2",
)
mock_client_class.assert_called_once_with(
provider="hf-inference",
token=None,
)
assert ef._model_name == "sentence-transformers/all-MiniLM-L6-v2"
@patch("huggingface_hub.InferenceClient")
def test_initialization_with_default_model(self, mock_client_class):
"""Test initialization with default model name."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
ef = HuggingFaceEmbeddingFunction()
assert ef._model_name == "sentence-transformers/all-MiniLM-L6-v2"
@patch("huggingface_hub.InferenceClient")
def test_call_with_single_document(self, mock_client_class):
"""Test embedding generation for a single document."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
# Mock the feature_extraction response (1D embedding)
mock_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
mock_client.feature_extraction.return_value = mock_embedding
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
result = ef(["Hello, world!"])
mock_client.feature_extraction.assert_called_once_with(
text="Hello, world!",
model="sentence-transformers/all-MiniLM-L6-v2",
)
assert len(result) == 1
assert result[0] == pytest.approx(mock_embedding, rel=1e-5)
@patch("huggingface_hub.InferenceClient")
def test_call_with_multiple_documents(self, mock_client_class):
"""Test embedding generation for multiple documents."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
# Mock the feature_extraction response
mock_embedding1 = [0.1, 0.2, 0.3]
mock_embedding2 = [0.4, 0.5, 0.6]
mock_client.feature_extraction.side_effect = [mock_embedding1, mock_embedding2]
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
result = ef(["Hello", "World"])
assert mock_client.feature_extraction.call_count == 2
assert len(result) == 2
assert result[0] == pytest.approx(mock_embedding1, rel=1e-5)
assert result[1] == pytest.approx(mock_embedding2, rel=1e-5)
@patch("huggingface_hub.InferenceClient")
def test_call_with_string_input(self, mock_client_class):
"""Test that string input is converted to list."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_embedding = [0.1, 0.2, 0.3]
mock_client.feature_extraction.return_value = mock_embedding
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
result = ef("Hello") # type: ignore[arg-type]
mock_client.feature_extraction.assert_called_once()
assert len(result) == 1
@patch("huggingface_hub.InferenceClient")
def test_process_2d_embedding_result(self, mock_client_class):
"""Test processing of 2D token-level embeddings (mean pooling)."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
# Mock 2D token-level embeddings (3 tokens, 4 dimensions each)
mock_token_embeddings = [
[0.1, 0.2, 0.3, 0.4],
[0.2, 0.3, 0.4, 0.5],
[0.3, 0.4, 0.5, 0.6],
]
mock_client.feature_extraction.return_value = mock_token_embeddings
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
result = ef(["Hello"])
# Expected: mean pooling across tokens
expected = np.mean(mock_token_embeddings, axis=0).tolist()
assert len(result) == 1
assert result[0] == pytest.approx(expected, rel=1e-5)
@patch("huggingface_hub.InferenceClient")
def test_process_3d_embedding_result(self, mock_client_class):
"""Test processing of 3D batch token-level embeddings."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
# Mock 3D embeddings (1 batch, 3 tokens, 4 dimensions)
mock_batch_embeddings = [
[
[0.1, 0.2, 0.3, 0.4],
[0.2, 0.3, 0.4, 0.5],
[0.3, 0.4, 0.5, 0.6],
]
]
mock_client.feature_extraction.return_value = mock_batch_embeddings
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
result = ef(["Hello"])
# Expected: take first batch, then mean pooling
expected = np.mean(mock_batch_embeddings[0], axis=0).tolist()
assert len(result) == 1
assert result[0] == pytest.approx(expected, rel=1e-5)
@patch("huggingface_hub.InferenceClient")
def test_error_handling_deprecated_endpoint(self, mock_client_class):
"""Test error handling for deprecated endpoint error."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.feature_extraction.side_effect = Exception(
"https://api-inference.huggingface.co is no longer supported"
)
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
with pytest.raises(ValueError, match="HuggingFace API endpoint error"):
ef(["Hello"])
@patch("huggingface_hub.InferenceClient")
def test_error_handling_unauthorized(self, mock_client_class):
"""Test error handling for authentication error."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.feature_extraction.side_effect = Exception("401 Unauthorized")
ef = HuggingFaceEmbeddingFunction(api_key="invalid-key")
with pytest.raises(ValueError, match="HuggingFace API authentication error"):
ef(["Hello"])
@patch("huggingface_hub.InferenceClient")
def test_error_handling_model_not_found(self, mock_client_class):
"""Test error handling for model not found error."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.feature_extraction.side_effect = Exception("404 Not Found")
ef = HuggingFaceEmbeddingFunction(
api_key="test-key", model_name="nonexistent/model"
)
with pytest.raises(ValueError, match="HuggingFace model not found"):
ef(["Hello"])
@patch("huggingface_hub.InferenceClient")
def test_error_handling_generic_error(self, mock_client_class):
"""Test error handling for generic API error."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_client.feature_extraction.side_effect = Exception("Some unexpected error")
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
with pytest.raises(ValueError, match="HuggingFace API error"):
ef(["Hello"])
@patch("huggingface_hub.InferenceClient")
def test_name_method(self, mock_client_class):
"""Test the name() static method."""
assert HuggingFaceEmbeddingFunction.name() == "huggingface"
@patch("huggingface_hub.InferenceClient")
def test_get_config(self, mock_client_class):
"""Test get_config method."""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
ef = HuggingFaceEmbeddingFunction(
api_key="test-key",
model_name="custom/model",
)
config = ef.get_config()
assert config["model_name"] == "custom/model"
assert config["api_key"] == "test-key"
class TestHuggingFaceEmbeddingFunctionIntegration:
"""Integration tests for HuggingFace embedding function with RAGStorage."""
@patch("huggingface_hub.InferenceClient")
def test_embedding_function_works_with_rag_storage_validation(
self, mock_client_class
):
"""Test that the embedding function works with RAGStorage validation.
This test simulates the validation that happens in RAGStorage.__init__
where it calls embedding_function(["test"]) to verify the embedder works.
"""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
# Mock a valid embedding response
mock_embedding = [0.1] * 384 # 384 dimensions like all-MiniLM-L6-v2
mock_client.feature_extraction.return_value = mock_embedding
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
# This is what RAGStorage does to validate the embedder
result = ef(["test"])
assert len(result) == 1
assert len(result[0]) == 384
# Values should be numeric (float or numpy float)
assert all(isinstance(x, (int, float)) or hasattr(x, "__float__") for x in result[0])
@patch("huggingface_hub.InferenceClient")
def test_embedding_function_returns_correct_format_for_chromadb(
self, mock_client_class
):
"""Test that embeddings are in the correct format for ChromaDB.
ChromaDB expects embeddings as a sequence of embedding vectors where each
inner element is a 1D embedding vector with numeric values.
"""
mock_client = MagicMock()
mock_client_class.return_value = mock_client
mock_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
mock_client.feature_extraction.return_value = mock_embedding
ef = HuggingFaceEmbeddingFunction(api_key="test-key")
result = ef(["Hello", "World"])
# ChromaDB expects a sequence of embedding vectors
assert isinstance(result, list)
for embedding in result:
# Each embedding should be a sequence of numeric values
assert len(embedding) == 5
for value in embedding:
# Values should be numeric (float or numpy float)
assert isinstance(value, (int, float)) or hasattr(value, "__float__")