From 7fff2b654cb8dc3d99f6018e6827b56f230f3ad7 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Thu, 4 Dec 2025 18:05:50 -0500 Subject: [PATCH] fix: use HuggingFaceEmbeddingFunction for embeddings, update keys and add tests (#4005) --- docs/en/concepts/memory.mdx | 3 +- docs/ko/concepts/memory.mdx | 3 +- docs/pt-BR/concepts/memory.mdx | 3 +- .../huggingface/huggingface_provider.py | 30 ++++++++++++++----- .../embeddings/providers/huggingface/types.py | 8 +++-- .../rag/embeddings/test_embedding_factory.py | 30 +++++++++++++++++++ 6 files changed, 61 insertions(+), 16 deletions(-) diff --git a/docs/en/concepts/memory.mdx b/docs/en/concepts/memory.mdx index deb9de07b..d931382e4 100644 --- a/docs/en/concepts/memory.mdx +++ b/docs/en/concepts/memory.mdx @@ -515,8 +515,7 @@ crew = Crew( "provider": "huggingface", "config": { "api_key": "your-hf-token", # Optional for public models - "model": "sentence-transformers/all-MiniLM-L6-v2", - "api_url": "https://api-inference.huggingface.co" # or your custom endpoint + "model": "sentence-transformers/all-MiniLM-L6-v2" } } ) diff --git a/docs/ko/concepts/memory.mdx b/docs/ko/concepts/memory.mdx index 3c6a21469..23a98e7fe 100644 --- a/docs/ko/concepts/memory.mdx +++ b/docs/ko/concepts/memory.mdx @@ -515,8 +515,7 @@ crew = Crew( "provider": "huggingface", "config": { "api_key": "your-hf-token", # Optional for public models - "model": "sentence-transformers/all-MiniLM-L6-v2", - "api_url": "https://api-inference.huggingface.co" # or your custom endpoint + "model": "sentence-transformers/all-MiniLM-L6-v2" } } ) diff --git a/docs/pt-BR/concepts/memory.mdx b/docs/pt-BR/concepts/memory.mdx index 05301ccaf..f7daa1560 100644 --- a/docs/pt-BR/concepts/memory.mdx +++ b/docs/pt-BR/concepts/memory.mdx @@ -515,8 +515,7 @@ crew = Crew( "provider": "huggingface", "config": { "api_key": "your-hf-token", # Opcional para modelos públicos - "model": "sentence-transformers/all-MiniLM-L6-v2", - "api_url": "https://api-inference.huggingface.co" # ou seu endpoint customizado + "model": "sentence-transformers/all-MiniLM-L6-v2" } } ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py index 481e9f8ba..8dc32b1f1 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py @@ -1,21 +1,35 @@ """HuggingFace embeddings provider.""" from chromadb.utils.embedding_functions.huggingface_embedding_function import ( - HuggingFaceEmbeddingServer, + HuggingFaceEmbeddingFunction, ) from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider -class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]): - """HuggingFace embeddings provider.""" +class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]): + """HuggingFace embeddings provider for the HuggingFace Inference API.""" - embedding_callable: type[HuggingFaceEmbeddingServer] = Field( - default=HuggingFaceEmbeddingServer, + embedding_callable: type[HuggingFaceEmbeddingFunction] = Field( + default=HuggingFaceEmbeddingFunction, description="HuggingFace embedding function class", ) - url: str = Field( - description="HuggingFace API URL", - validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"), + api_key: str | None = Field( + default=None, + description="HuggingFace API key", + validation_alias=AliasChoices( + "EMBEDDINGS_HUGGINGFACE_API_KEY", + "HUGGINGFACE_API_KEY", + "HF_TOKEN", + ), + ) + model_name: str = Field( + default="sentence-transformers/all-MiniLM-L6-v2", + description="Model name to use for embeddings", + validation_alias=AliasChoices( + "EMBEDDINGS_HUGGINGFACE_MODEL_NAME", + "HUGGINGFACE_MODEL_NAME", + "model", + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/types.py b/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/types.py index 48ff4f5b3..48d4211b0 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/types.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/types.py @@ -1,6 +1,6 @@ """Type definitions for HuggingFace embedding providers.""" -from typing import Literal +from typing import Annotated, Literal from typing_extensions import Required, TypedDict @@ -8,7 +8,11 @@ from typing_extensions import Required, TypedDict class HuggingFaceProviderConfig(TypedDict, total=False): """Configuration for HuggingFace provider.""" - url: str + api_key: str + model: Annotated[ + str, "sentence-transformers/all-MiniLM-L6-v2" + ] # alias for model_name for backward compat + model_name: Annotated[str, "sentence-transformers/all-MiniLM-L6-v2"] class HuggingFaceProviderSpec(TypedDict, total=False): diff --git a/lib/crewai/tests/rag/embeddings/test_embedding_factory.py b/lib/crewai/tests/rag/embeddings/test_embedding_factory.py index b5a33bd74..b173367a3 100644 --- a/lib/crewai/tests/rag/embeddings/test_embedding_factory.py +++ b/lib/crewai/tests/rag/embeddings/test_embedding_factory.py @@ -99,6 +99,36 @@ class TestEmbeddingFactory: "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider" ) + @patch("crewai.rag.embeddings.factory.import_and_validate_definition") + def test_build_embedder_huggingface(self, mock_import): + """Test building HuggingFace embedder.""" + mock_provider_class = MagicMock() + mock_provider_instance = MagicMock() + mock_embedding_function = MagicMock() + + mock_import.return_value = mock_provider_class + mock_provider_class.return_value = mock_provider_instance + mock_provider_instance.embedding_callable.return_value = mock_embedding_function + + config = { + "provider": "huggingface", + "config": { + "api_key": "hf-test-key", + "model": "sentence-transformers/all-MiniLM-L6-v2", + }, + } + + build_embedder(config) + + mock_import.assert_called_once_with( + "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider" + ) + mock_provider_class.assert_called_once() + + call_kwargs = mock_provider_class.call_args.kwargs + assert call_kwargs["api_key"] == "hf-test-key" + assert call_kwargs["model"] == "sentence-transformers/all-MiniLM-L6-v2" + @patch("crewai.rag.embeddings.factory.import_and_validate_definition") def test_build_embedder_cohere(self, mock_import): """Test building Cohere embedder."""