From a928cde6ee1cade3dfa06c2098238c2aaad201c7 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Mon, 24 Nov 2025 16:51:28 -0500 Subject: [PATCH] fix: rag tool embeddings config * fix: ensure config is not flattened, add tests * chore: refactor inits to model_validator * chore: refactor rag tool config parsing * chore: add initial docs * chore: add additional validation aliases for provider env vars * chore: add solid docs * chore: move imports to top * fix: revert circular import * fix: lazy import qdrant-client * fix: allow collection name config * chore: narrow model names for google * chore: update additional docs * chore: add backward compat on model name aliases * chore: add tests for config changes --- docs/en/concepts/knowledge.mdx | 8 +- docs/en/concepts/memory.mdx | 18 +- docs/en/tools/ai-ml/ragtool.mdx | 528 +++++++++++++++++- docs/en/tools/database-data/mysqltool.mdx | 6 +- docs/en/tools/database-data/pgsearchtool.mdx | 6 +- .../en/tools/file-document/jsonsearchtool.mdx | 6 +- docs/en/tools/file-document/pdfsearchtool.mdx | 8 +- docs/en/tools/file-document/txtsearchtool.mdx | 4 +- .../search-research/codedocssearchtool.mdx | 6 +- .../search-research/githubsearchtool.mdx | 6 +- .../search-research/websitesearchtool.mdx | 6 +- .../youtubechannelsearchtool.mdx | 6 +- .../youtubevideosearchtool.mdx | 6 +- .../adapters/crewai_rag_adapter.py | 36 +- .../tools/pdf_search_tool/pdf_search_tool.py | 16 +- .../src/crewai_tools/tools/rag/__init__.py | 10 + .../src/crewai_tools/tools/rag/rag_tool.py | 237 ++++---- .../src/crewai_tools/tools/rag/types.py | 32 ++ .../tools/txt_search_tool/txt_search_tool.py | 16 +- .../tests/tools/rag/rag_tool_test.py | 135 ++++- .../tools/rag/test_rag_tool_validation.py | 66 +++ .../tools/test_pdf_search_tool_config.py | 116 ++++ .../tools/test_txt_search_tool_config.py | 104 ++++ .../src/crewai/rag/embeddings/factory.py | 1 + .../rag/embeddings/providers/aws/bedrock.py | 11 +- .../providers/cohere/cohere_provider.py | 10 +- .../providers/google/generative_ai.py | 25 +- .../rag/embeddings/providers/google/types.py | 17 +- .../rag/embeddings/providers/google/vertex.py | 21 +- .../huggingface/huggingface_provider.py | 5 +- .../rag/embeddings/providers/ibm/watsonx.py | 85 ++- .../instructor/instructor_provider.py | 16 +- .../providers/jina/jina_provider.py | 11 +- .../embeddings/providers/microsoft/azure.py | 47 +- .../embeddings/providers/microsoft/types.py | 2 +- .../providers/ollama/ollama_provider.py | 11 +- .../providers/onnx/onnx_provider.py | 6 +- .../providers/openai/openai_provider.py | 30 +- .../providers/openclip/openclip_provider.py | 14 +- .../providers/roboflow/roboflow_provider.py | 10 +- .../sentence_transformer_provider.py | 17 +- .../providers/text2vec/text2vec_provider.py | 8 +- .../providers/voyageai/voyageai_provider.py | 33 +- lib/crewai/src/crewai/rag/embeddings/types.py | 2 +- lib/crewai/src/crewai/rag/qdrant/config.py | 13 +- .../embeddings/test_backward_compatibility.py | 364 ++++++++++++ 46 files changed, 1850 insertions(+), 291 deletions(-) create mode 100644 lib/crewai-tools/src/crewai_tools/tools/rag/types.py create mode 100644 lib/crewai-tools/tests/tools/rag/test_rag_tool_validation.py create mode 100644 lib/crewai-tools/tests/tools/test_pdf_search_tool_config.py create mode 100644 lib/crewai-tools/tests/tools/test_txt_search_tool_config.py create mode 100644 lib/crewai/tests/rag/embeddings/test_backward_compatibility.py diff --git a/docs/en/concepts/knowledge.mdx b/docs/en/concepts/knowledge.mdx index dfd74949a..937cca1fd 100644 --- a/docs/en/concepts/knowledge.mdx +++ b/docs/en/concepts/knowledge.mdx @@ -388,8 +388,8 @@ crew = Crew( agents=[sales_agent, tech_agent, support_agent], tasks=[...], embedder={ # Fallback embedder for agents without their own - "provider": "google", - "config": {"model": "text-embedding-004"} + "provider": "google-generativeai", + "config": {"model_name": "gemini-embedding-001"} } ) @@ -629,9 +629,9 @@ agent = Agent( backstory="Expert researcher", knowledge_sources=[knowledge_source], embedder={ - "provider": "google", + "provider": "google-generativeai", "config": { - "model": "models/text-embedding-004", + "model_name": "gemini-embedding-001", "api_key": "your-google-key" } } diff --git a/docs/en/concepts/memory.mdx b/docs/en/concepts/memory.mdx index 27390395b..deb9de07b 100644 --- a/docs/en/concepts/memory.mdx +++ b/docs/en/concepts/memory.mdx @@ -341,7 +341,7 @@ crew = Crew( embedder={ "provider": "openai", "config": { - "model": "text-embedding-3-small" # or "text-embedding-3-large" + "model_name": "text-embedding-3-small" # or "text-embedding-3-large" } } ) @@ -353,7 +353,7 @@ crew = Crew( "provider": "openai", "config": { "api_key": "your-openai-api-key", # Optional: override env var - "model": "text-embedding-3-large", + "model_name": "text-embedding-3-large", "dimensions": 1536, # Optional: reduce dimensions for smaller storage "organization_id": "your-org-id" # Optional: for organization accounts } @@ -375,7 +375,7 @@ crew = Crew( "api_base": "https://your-resource.openai.azure.com/", "api_type": "azure", "api_version": "2023-05-15", - "model": "text-embedding-3-small", + "model_name": "text-embedding-3-small", "deployment_id": "your-deployment-name" # Azure deployment name } } @@ -390,10 +390,10 @@ Use Google's text embedding models for integration with Google Cloud services. crew = Crew( memory=True, embedder={ - "provider": "google", + "provider": "google-generativeai", "config": { "api_key": "your-google-api-key", - "model": "text-embedding-004" # or "text-embedding-preview-0409" + "model_name": "gemini-embedding-001" # or "text-embedding-005", "text-multilingual-embedding-002" } } ) @@ -461,7 +461,7 @@ crew = Crew( "provider": "cohere", "config": { "api_key": "your-cohere-api-key", - "model": "embed-english-v3.0" # or "embed-multilingual-v3.0" + "model_name": "embed-english-v3.0" # or "embed-multilingual-v3.0" } } ) @@ -478,7 +478,7 @@ crew = Crew( "provider": "voyageai", "config": { "api_key": "your-voyage-api-key", - "model": "voyage-large-2", # or "voyage-code-2" for code + "model": "voyage-3", # or "voyage-3-lite", "voyage-code-3" "input_type": "document" # or "query" } } @@ -912,10 +912,10 @@ crew = Crew( crew = Crew( memory=True, embedder={ - "provider": "google", + "provider": "google-generativeai", "config": { "api_key": "your-api-key", - "model": "text-embedding-004" + "model_name": "gemini-embedding-001" } } ) diff --git a/docs/en/tools/ai-ml/ragtool.mdx b/docs/en/tools/ai-ml/ragtool.mdx index 547ec94da..0380c4bac 100644 --- a/docs/en/tools/ai-ml/ragtool.mdx +++ b/docs/en/tools/ai-ml/ragtool.mdx @@ -77,7 +77,7 @@ The `RagTool` accepts the following parameters: - **summarize**: Optional. Whether to summarize the retrieved content. Default is `False`. - **adapter**: Optional. A custom adapter for the knowledge base. If not provided, a CrewAIRagAdapter will be used. -- **config**: Optional. Configuration for the underlying CrewAI RAG system. +- **config**: Optional. Configuration for the underlying CrewAI RAG system. Accepts a `RagToolConfig` TypedDict with optional `embedding_model` (ProviderSpec) and `vectordb` (VectorDbConfig) keys. All configuration values provided programmatically take precedence over environment variables. ## Adding Content @@ -127,26 +127,528 @@ You can customize the behavior of the `RagTool` by providing a configuration dic ```python Code from crewai_tools import RagTool +from crewai_tools.tools.rag import RagToolConfig, VectorDbConfig, ProviderSpec # Create a RAG tool with custom configuration -config = { - "vectordb": { - "provider": "qdrant", - "config": { - "collection_name": "my-collection" - } - }, - "embedding_model": { - "provider": "openai", - "config": { - "model": "text-embedding-3-small" - } + +vectordb: VectorDbConfig = { + "provider": "qdrant", + "config": { + "collection_name": "my-collection" } } +embedding_model: ProviderSpec = { + "provider": "openai", + "config": { + "model_name": "text-embedding-3-small" + } +} + +config: RagToolConfig = { + "vectordb": vectordb, + "embedding_model": embedding_model +} + rag_tool = RagTool(config=config, summarize=True) ``` +## Embedding Model Configuration + +The `embedding_model` parameter accepts a `crewai.rag.embeddings.types.ProviderSpec` dictionary with the structure: + +```python +{ + "provider": "provider-name", # Required + "config": { # Optional + # Provider-specific configuration + } +} +``` + +### Supported Providers + + + + ```python main.py + from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec + + embedding_model: OpenAIProviderSpec = { + "provider": "openai", + "config": { + "api_key": "your-api-key", + "model_name": "text-embedding-ada-002", + "dimensions": 1536, + "organization_id": "your-org-id", + "api_base": "https://api.openai.com/v1", + "api_version": "v1", + "default_headers": {"Custom-Header": "value"} + } + } + ``` + + **Config Options:** + - `api_key` (str): OpenAI API key + - `model_name` (str): Model to use. Default: `text-embedding-ada-002`. Options: `text-embedding-3-small`, `text-embedding-3-large`, `text-embedding-ada-002` + - `dimensions` (int): Number of dimensions for the embedding + - `organization_id` (str): OpenAI organization ID + - `api_base` (str): Custom API base URL + - `api_version` (str): API version + - `default_headers` (dict): Custom headers for API requests + + **Environment Variables:** + - `OPENAI_API_KEY` or `EMBEDDINGS_OPENAI_API_KEY`: `api_key` + - `OPENAI_ORGANIZATION_ID` or `EMBEDDINGS_OPENAI_ORGANIZATION_ID`: `organization_id` + - `OPENAI_MODEL_NAME` or `EMBEDDINGS_OPENAI_MODEL_NAME`: `model_name` + - `OPENAI_API_BASE` or `EMBEDDINGS_OPENAI_API_BASE`: `api_base` + - `OPENAI_API_VERSION` or `EMBEDDINGS_OPENAI_API_VERSION`: `api_version` + - `OPENAI_DIMENSIONS` or `EMBEDDINGS_OPENAI_DIMENSIONS`: `dimensions` + + + + ```python main.py + from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec + + embedding_model: CohereProviderSpec = { + "provider": "cohere", + "config": { + "api_key": "your-api-key", + "model_name": "embed-english-v3.0" + } + } + ``` + + **Config Options:** + - `api_key` (str): Cohere API key + - `model_name` (str): Model to use. Default: `large`. Options: `embed-english-v3.0`, `embed-multilingual-v3.0`, `large`, `small` + + **Environment Variables:** + - `COHERE_API_KEY` or `EMBEDDINGS_COHERE_API_KEY`: `api_key` + - `EMBEDDINGS_COHERE_MODEL_NAME`: `model_name` + + + + ```python main.py + from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec + + embedding_model: VoyageAIProviderSpec = { + "provider": "voyageai", + "config": { + "api_key": "your-api-key", + "model": "voyage-3", + "input_type": "document", + "truncation": True, + "output_dtype": "float32", + "output_dimension": 1024, + "max_retries": 3, + "timeout": 60.0 + } + } + ``` + + **Config Options:** + - `api_key` (str): VoyageAI API key + - `model` (str): Model to use. Default: `voyage-2`. Options: `voyage-3`, `voyage-3-lite`, `voyage-code-3`, `voyage-large-2` + - `input_type` (str): Type of input. Options: `document` (for storage), `query` (for search) + - `truncation` (bool): Whether to truncate inputs that exceed max length. Default: `True` + - `output_dtype` (str): Output data type + - `output_dimension` (int): Dimension of output embeddings + - `max_retries` (int): Maximum number of retry attempts. Default: `0` + - `timeout` (float): Request timeout in seconds + + **Environment Variables:** + - `VOYAGEAI_API_KEY` or `EMBEDDINGS_VOYAGEAI_API_KEY`: `api_key` + - `VOYAGEAI_MODEL` or `EMBEDDINGS_VOYAGEAI_MODEL`: `model` + - `VOYAGEAI_INPUT_TYPE` or `EMBEDDINGS_VOYAGEAI_INPUT_TYPE`: `input_type` + - `VOYAGEAI_TRUNCATION` or `EMBEDDINGS_VOYAGEAI_TRUNCATION`: `truncation` + - `VOYAGEAI_OUTPUT_DTYPE` or `EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE`: `output_dtype` + - `VOYAGEAI_OUTPUT_DIMENSION` or `EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION`: `output_dimension` + - `VOYAGEAI_MAX_RETRIES` or `EMBEDDINGS_VOYAGEAI_MAX_RETRIES`: `max_retries` + - `VOYAGEAI_TIMEOUT` or `EMBEDDINGS_VOYAGEAI_TIMEOUT`: `timeout` + + + + ```python main.py + from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec + + embedding_model: OllamaProviderSpec = { + "provider": "ollama", + "config": { + "model_name": "llama2", + "url": "http://localhost:11434/api/embeddings" + } + } + ``` + + **Config Options:** + - `model_name` (str): Ollama model name (e.g., `llama2`, `mistral`, `nomic-embed-text`) + - `url` (str): Ollama API endpoint URL. Default: `http://localhost:11434/api/embeddings` + + **Environment Variables:** + - `OLLAMA_MODEL` or `EMBEDDINGS_OLLAMA_MODEL`: `model_name` + - `OLLAMA_URL` or `EMBEDDINGS_OLLAMA_URL`: `url` + + + + ```python main.py + from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec + + embedding_model: BedrockProviderSpec = { + "provider": "amazon-bedrock", + "config": { + "model_name": "amazon.titan-embed-text-v2:0", + "session": boto3_session + } + } + ``` + + **Config Options:** + - `model_name` (str): Bedrock model ID. Default: `amazon.titan-embed-text-v1`. Options: `amazon.titan-embed-text-v1`, `amazon.titan-embed-text-v2:0`, `cohere.embed-english-v3`, `cohere.embed-multilingual-v3` + - `session` (Any): Boto3 session object for AWS authentication + + **Environment Variables:** + - `AWS_ACCESS_KEY_ID`: AWS access key + - `AWS_SECRET_ACCESS_KEY`: AWS secret key + - `AWS_REGION`: AWS region (e.g., `us-east-1`) + + + + ```python main.py + from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec + + embedding_model: AzureProviderSpec = { + "provider": "azure", + "config": { + "deployment_id": "your-deployment-id", + "api_key": "your-api-key", + "api_base": "https://your-resource.openai.azure.com", + "api_version": "2024-02-01", + "model_name": "text-embedding-ada-002", + "api_type": "azure" + } + } + ``` + + **Config Options:** + - `deployment_id` (str): **Required** - Azure OpenAI deployment ID + - `api_key` (str): Azure OpenAI API key + - `api_base` (str): Azure OpenAI resource endpoint + - `api_version` (str): API version. Example: `2024-02-01` + - `model_name` (str): Model name. Default: `text-embedding-ada-002` + - `api_type` (str): API type. Default: `azure` + - `dimensions` (int): Output dimensions + - `default_headers` (dict): Custom headers + + **Environment Variables:** + - `AZURE_OPENAI_API_KEY` or `EMBEDDINGS_AZURE_API_KEY`: `api_key` + - `AZURE_OPENAI_ENDPOINT` or `EMBEDDINGS_AZURE_API_BASE`: `api_base` + - `EMBEDDINGS_AZURE_DEPLOYMENT_ID`: `deployment_id` + - `EMBEDDINGS_AZURE_API_VERSION`: `api_version` + - `EMBEDDINGS_AZURE_MODEL_NAME`: `model_name` + - `EMBEDDINGS_AZURE_API_TYPE`: `api_type` + - `EMBEDDINGS_AZURE_DIMENSIONS`: `dimensions` + + + + ```python main.py + from crewai.rag.embeddings.providers.google.types import GenerativeAiProviderSpec + + embedding_model: GenerativeAiProviderSpec = { + "provider": "google-generativeai", + "config": { + "api_key": "your-api-key", + "model_name": "gemini-embedding-001", + "task_type": "RETRIEVAL_DOCUMENT" + } + } + ``` + + **Config Options:** + - `api_key` (str): Google AI API key + - `model_name` (str): Model name. Default: `gemini-embedding-001`. Options: `gemini-embedding-001`, `text-embedding-005`, `text-multilingual-embedding-002` + - `task_type` (str): Task type for embeddings. Default: `RETRIEVAL_DOCUMENT`. Options: `RETRIEVAL_DOCUMENT`, `RETRIEVAL_QUERY` + + **Environment Variables:** + - `GOOGLE_API_KEY`, `GEMINI_API_KEY`, or `EMBEDDINGS_GOOGLE_API_KEY`: `api_key` + - `EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME`: `model_name` + - `EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE`: `task_type` + + + + ```python main.py + from crewai.rag.embeddings.providers.google.types import VertexAIProviderSpec + + embedding_model: VertexAIProviderSpec = { + "provider": "google-vertex", + "config": { + "model_name": "text-embedding-004", + "project_id": "your-project-id", + "region": "us-central1", + "api_key": "your-api-key" + } + } + ``` + + **Config Options:** + - `model_name` (str): Model name. Default: `textembedding-gecko`. Options: `text-embedding-004`, `textembedding-gecko`, `textembedding-gecko-multilingual` + - `project_id` (str): Google Cloud project ID. Default: `cloud-large-language-models` + - `region` (str): Google Cloud region. Default: `us-central1` + - `api_key` (str): API key for authentication + + **Environment Variables:** + - `GOOGLE_APPLICATION_CREDENTIALS`: Path to service account JSON file + - `GOOGLE_CLOUD_PROJECT` or `EMBEDDINGS_GOOGLE_VERTEX_PROJECT_ID`: `project_id` + - `EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME`: `model_name` + - `EMBEDDINGS_GOOGLE_VERTEX_REGION`: `region` + - `EMBEDDINGS_GOOGLE_VERTEX_API_KEY`: `api_key` + + + + ```python main.py + from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec + + embedding_model: JinaProviderSpec = { + "provider": "jina", + "config": { + "api_key": "your-api-key", + "model_name": "jina-embeddings-v3" + } + } + ``` + + **Config Options:** + - `api_key` (str): Jina AI API key + - `model_name` (str): Model name. Default: `jina-embeddings-v2-base-en`. Options: `jina-embeddings-v3`, `jina-embeddings-v2-base-en`, `jina-embeddings-v2-small-en` + + **Environment Variables:** + - `JINA_API_KEY` or `EMBEDDINGS_JINA_API_KEY`: `api_key` + - `EMBEDDINGS_JINA_MODEL_NAME`: `model_name` + + + + ```python main.py + from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec + + embedding_model: HuggingFaceProviderSpec = { + "provider": "huggingface", + "config": { + "url": "https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2" + } + } + ``` + + **Config Options:** + - `url` (str): Full URL to HuggingFace inference API endpoint + + **Environment Variables:** + - `HUGGINGFACE_URL` or `EMBEDDINGS_HUGGINGFACE_URL`: `url` + + + + ```python main.py + from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec + + embedding_model: InstructorProviderSpec = { + "provider": "instructor", + "config": { + "model_name": "hkunlp/instructor-xl", + "device": "cuda", + "instruction": "Represent the document" + } + } + ``` + + **Config Options:** + - `model_name` (str): HuggingFace model ID. Default: `hkunlp/instructor-base`. Options: `hkunlp/instructor-xl`, `hkunlp/instructor-large`, `hkunlp/instructor-base` + - `device` (str): Device to run on. Default: `cpu`. Options: `cpu`, `cuda`, `mps` + - `instruction` (str): Instruction prefix for embeddings + + **Environment Variables:** + - `EMBEDDINGS_INSTRUCTOR_MODEL_NAME`: `model_name` + - `EMBEDDINGS_INSTRUCTOR_DEVICE`: `device` + - `EMBEDDINGS_INSTRUCTOR_INSTRUCTION`: `instruction` + + + + ```python main.py + from crewai.rag.embeddings.providers.sentence_transformer.types import SentenceTransformerProviderSpec + + embedding_model: SentenceTransformerProviderSpec = { + "provider": "sentence-transformer", + "config": { + "model_name": "all-mpnet-base-v2", + "device": "cuda", + "normalize_embeddings": True + } + } + ``` + + **Config Options:** + - `model_name` (str): Sentence Transformers model name. Default: `all-MiniLM-L6-v2`. Options: `all-mpnet-base-v2`, `all-MiniLM-L6-v2`, `paraphrase-multilingual-MiniLM-L12-v2` + - `device` (str): Device to run on. Default: `cpu`. Options: `cpu`, `cuda`, `mps` + - `normalize_embeddings` (bool): Whether to normalize embeddings. Default: `False` + + **Environment Variables:** + - `EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME`: `model_name` + - `EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE`: `device` + - `EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS`: `normalize_embeddings` + + + + ```python main.py + from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec + + embedding_model: ONNXProviderSpec = { + "provider": "onnx", + "config": { + "preferred_providers": ["CUDAExecutionProvider", "CPUExecutionProvider"] + } + } + ``` + + **Config Options:** + - `preferred_providers` (list[str]): List of ONNX execution providers in order of preference + + **Environment Variables:** + - `EMBEDDINGS_ONNX_PREFERRED_PROVIDERS`: `preferred_providers` (comma-separated list) + + + + ```python main.py + from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec + + embedding_model: OpenCLIPProviderSpec = { + "provider": "openclip", + "config": { + "model_name": "ViT-B-32", + "checkpoint": "laion2b_s34b_b79k", + "device": "cuda" + } + } + ``` + + **Config Options:** + - `model_name` (str): OpenCLIP model architecture. Default: `ViT-B-32`. Options: `ViT-B-32`, `ViT-B-16`, `ViT-L-14` + - `checkpoint` (str): Pretrained checkpoint name. Default: `laion2b_s34b_b79k`. Options: `laion2b_s34b_b79k`, `laion400m_e32`, `openai` + - `device` (str): Device to run on. Default: `cpu`. Options: `cpu`, `cuda` + + **Environment Variables:** + - `EMBEDDINGS_OPENCLIP_MODEL_NAME`: `model_name` + - `EMBEDDINGS_OPENCLIP_CHECKPOINT`: `checkpoint` + - `EMBEDDINGS_OPENCLIP_DEVICE`: `device` + + + + ```python main.py + from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec + + embedding_model: Text2VecProviderSpec = { + "provider": "text2vec", + "config": { + "model_name": "shibing624/text2vec-base-multilingual" + } + } + ``` + + **Config Options:** + - `model_name` (str): Text2Vec model name from HuggingFace. Default: `shibing624/text2vec-base-chinese`. Options: `shibing624/text2vec-base-multilingual`, `shibing624/text2vec-base-chinese` + + **Environment Variables:** + - `EMBEDDINGS_TEXT2VEC_MODEL_NAME`: `model_name` + + + + ```python main.py + from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec + + embedding_model: RoboflowProviderSpec = { + "provider": "roboflow", + "config": { + "api_key": "your-api-key", + "api_url": "https://infer.roboflow.com" + } + } + ``` + + **Config Options:** + - `api_key` (str): Roboflow API key. Default: `""` (empty string) + - `api_url` (str): Roboflow inference API URL. Default: `https://infer.roboflow.com` + + **Environment Variables:** + - `ROBOFLOW_API_KEY` or `EMBEDDINGS_ROBOFLOW_API_KEY`: `api_key` + - `ROBOFLOW_API_URL` or `EMBEDDINGS_ROBOFLOW_API_URL`: `api_url` + + + + ```python main.py + from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderSpec + + embedding_model: WatsonXProviderSpec = { + "provider": "watsonx", + "config": { + "model_id": "ibm/slate-125m-english-rtrvr", + "url": "https://us-south.ml.cloud.ibm.com", + "api_key": "your-api-key", + "project_id": "your-project-id", + "batch_size": 100, + "concurrency_limit": 10, + "persistent_connection": True + } + } + ``` + + **Config Options:** + - `model_id` (str): WatsonX model identifier + - `url` (str): WatsonX API endpoint + - `api_key` (str): IBM Cloud API key + - `project_id` (str): WatsonX project ID + - `space_id` (str): WatsonX space ID (alternative to project_id) + - `batch_size` (int): Batch size for embeddings. Default: `100` + - `concurrency_limit` (int): Maximum concurrent requests. Default: `10` + - `persistent_connection` (bool): Use persistent connections. Default: `True` + - Plus 20+ additional authentication and configuration options + + **Environment Variables:** + - `WATSONX_API_KEY` or `EMBEDDINGS_WATSONX_API_KEY`: `api_key` + - `WATSONX_URL` or `EMBEDDINGS_WATSONX_URL`: `url` + - `WATSONX_PROJECT_ID` or `EMBEDDINGS_WATSONX_PROJECT_ID`: `project_id` + - `EMBEDDINGS_WATSONX_MODEL_ID`: `model_id` + - `EMBEDDINGS_WATSONX_SPACE_ID`: `space_id` + - `EMBEDDINGS_WATSONX_BATCH_SIZE`: `batch_size` + - `EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT`: `concurrency_limit` + - `EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION`: `persistent_connection` + + + + ```python main.py + from crewai.rag.core.base_embeddings_callable import EmbeddingFunction + from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec + + class MyEmbeddingFunction(EmbeddingFunction): + def __call__(self, input): + # Your custom embedding logic + return embeddings + + embedding_model: CustomProviderSpec = { + "provider": "custom", + "config": { + "embedding_callable": MyEmbeddingFunction + } + } + ``` + + **Config Options:** + - `embedding_callable` (type[EmbeddingFunction]): Custom embedding function class + + **Note:** Custom embedding functions must implement the `EmbeddingFunction` protocol defined in `crewai.rag.core.base_embeddings_callable`. The `__call__` method should accept input data and return embeddings as a list of numpy arrays (or compatible format that will be normalized). The returned embeddings are automatically normalized and validated. + + + +### Notes +- All config fields are optional unless marked as **Required** +- API keys can typically be provided via environment variables instead of config +- Default values are shown where applicable + ## Conclusion The `RagTool` provides a powerful way to create and query knowledge bases from various data sources. By leveraging Retrieval-Augmented Generation, it enables agents to access and retrieve relevant information efficiently, enhancing their ability to provide accurate and contextually appropriate responses. diff --git a/docs/en/tools/database-data/mysqltool.mdx b/docs/en/tools/database-data/mysqltool.mdx index efdd3371f..c66176297 100644 --- a/docs/en/tools/database-data/mysqltool.mdx +++ b/docs/en/tools/database-data/mysqltool.mdx @@ -58,10 +58,10 @@ tool = MySQLSearchTool( ), ), embedder=dict( - provider="google", + provider="google-generativeai", config=dict( - model="models/embedding-001", - task_type="retrieval_document", + model_name="gemini-embedding-001", + task_type="RETRIEVAL_DOCUMENT", # title="Embeddings", ), ), diff --git a/docs/en/tools/database-data/pgsearchtool.mdx b/docs/en/tools/database-data/pgsearchtool.mdx index d4a228fdd..cb021d4d9 100644 --- a/docs/en/tools/database-data/pgsearchtool.mdx +++ b/docs/en/tools/database-data/pgsearchtool.mdx @@ -71,10 +71,10 @@ tool = PGSearchTool( ), ), embedder=dict( - provider="google", # or openai, ollama, ... + provider="google-generativeai", # or openai, ollama, ... config=dict( - model="models/embedding-001", - task_type="retrieval_document", + model_name="gemini-embedding-001", + task_type="RETRIEVAL_DOCUMENT", # title="Embeddings", ), ), diff --git a/docs/en/tools/file-document/jsonsearchtool.mdx b/docs/en/tools/file-document/jsonsearchtool.mdx index 6228ccbc2..7b1737faa 100644 --- a/docs/en/tools/file-document/jsonsearchtool.mdx +++ b/docs/en/tools/file-document/jsonsearchtool.mdx @@ -64,10 +64,10 @@ tool = JSONSearchTool( }, }, "embedding_model": { - "provider": "google", # or openai, ollama, ... + "provider": "google-generativeai", # or openai, ollama, ... "config": { - "model": "models/embedding-001", - "task_type": "retrieval_document", + "model_name": "gemini-embedding-001", + "task_type": "RETRIEVAL_DOCUMENT", # Further customization options can be added here. }, }, diff --git a/docs/en/tools/file-document/pdfsearchtool.mdx b/docs/en/tools/file-document/pdfsearchtool.mdx index cede7cfe2..32e05669e 100644 --- a/docs/en/tools/file-document/pdfsearchtool.mdx +++ b/docs/en/tools/file-document/pdfsearchtool.mdx @@ -63,15 +63,15 @@ tool = PDFSearchTool( "config": { # Model identifier for the chosen provider. "model" will be auto-mapped to "model_name" internally. "model": "text-embedding-3-small", - # Optional: API key. If omitted, the tool will use provider-specific env vars when available - # (e.g., OPENAI_API_KEY for provider="openai"). + # Optional: API key. If omitted, the tool will use provider-specific env vars + # (e.g., OPENAI_API_KEY or EMBEDDINGS_OPENAI_API_KEY for OpenAI). # "api_key": "sk-...", # Provider-specific examples: # --- Google Generative AI --- # (Set provider="google-generativeai" above) - # "model": "models/embedding-001", - # "task_type": "retrieval_document", + # "model_name": "gemini-embedding-001", + # "task_type": "RETRIEVAL_DOCUMENT", # "title": "Embeddings", # --- Cohere --- diff --git a/docs/en/tools/file-document/txtsearchtool.mdx b/docs/en/tools/file-document/txtsearchtool.mdx index 4c4b0d91d..fda46180a 100644 --- a/docs/en/tools/file-document/txtsearchtool.mdx +++ b/docs/en/tools/file-document/txtsearchtool.mdx @@ -66,9 +66,9 @@ tool = TXTSearchTool( "provider": "openai", # or google-generativeai, cohere, ollama, ... "config": { "model": "text-embedding-3-small", - # "api_key": "sk-...", # optional if env var is set + # "api_key": "sk-...", # optional if env var is set (e.g., OPENAI_API_KEY or EMBEDDINGS_OPENAI_API_KEY) # Provider examples: - # Google → model: "models/embedding-001", task_type: "retrieval_document" + # Google → model_name: "gemini-embedding-001", task_type: "RETRIEVAL_DOCUMENT" # Cohere → model: "embed-english-v3.0" # Ollama → model: "nomic-embed-text" }, diff --git a/docs/en/tools/search-research/codedocssearchtool.mdx b/docs/en/tools/search-research/codedocssearchtool.mdx index 2c5890280..0635509e3 100644 --- a/docs/en/tools/search-research/codedocssearchtool.mdx +++ b/docs/en/tools/search-research/codedocssearchtool.mdx @@ -73,10 +73,10 @@ tool = CodeDocsSearchTool( ), ), embedder=dict( - provider="google", # or openai, ollama, ... + provider="google-generativeai", # or openai, ollama, ... config=dict( - model="models/embedding-001", - task_type="retrieval_document", + model_name="gemini-embedding-001", + task_type="RETRIEVAL_DOCUMENT", # title="Embeddings", ), ), diff --git a/docs/en/tools/search-research/githubsearchtool.mdx b/docs/en/tools/search-research/githubsearchtool.mdx index b512ea43c..f5ea710cc 100644 --- a/docs/en/tools/search-research/githubsearchtool.mdx +++ b/docs/en/tools/search-research/githubsearchtool.mdx @@ -75,10 +75,10 @@ tool = GithubSearchTool( ), ), embedder=dict( - provider="google", # or openai, ollama, ... + provider="google-generativeai", # or openai, ollama, ... config=dict( - model="models/embedding-001", - task_type="retrieval_document", + model_name="gemini-embedding-001", + task_type="RETRIEVAL_DOCUMENT", # title="Embeddings", ), ), diff --git a/docs/en/tools/search-research/websitesearchtool.mdx b/docs/en/tools/search-research/websitesearchtool.mdx index ad60c76bd..52c163fd1 100644 --- a/docs/en/tools/search-research/websitesearchtool.mdx +++ b/docs/en/tools/search-research/websitesearchtool.mdx @@ -66,10 +66,10 @@ tool = WebsiteSearchTool( ), ), embedder=dict( - provider="google", # or openai, ollama, ... + provider="google-generativeai", # or openai, ollama, ... config=dict( - model="models/embedding-001", - task_type="retrieval_document", + model_name="gemini-embedding-001", + task_type="RETRIEVAL_DOCUMENT", # title="Embeddings", ), ), diff --git a/docs/en/tools/search-research/youtubechannelsearchtool.mdx b/docs/en/tools/search-research/youtubechannelsearchtool.mdx index c024fd7ce..8d53134f3 100644 --- a/docs/en/tools/search-research/youtubechannelsearchtool.mdx +++ b/docs/en/tools/search-research/youtubechannelsearchtool.mdx @@ -106,10 +106,10 @@ youtube_channel_tool = YoutubeChannelSearchTool( ), ), embedder=dict( - provider="google", # or openai, ollama, ... + provider="google-generativeai", # or openai, ollama, ... config=dict( - model="models/embedding-001", - task_type="retrieval_document", + model_name="gemini-embedding-001", + task_type="RETRIEVAL_DOCUMENT", # title="Embeddings", ), ), diff --git a/docs/en/tools/search-research/youtubevideosearchtool.mdx b/docs/en/tools/search-research/youtubevideosearchtool.mdx index b52ffa448..668e48833 100644 --- a/docs/en/tools/search-research/youtubevideosearchtool.mdx +++ b/docs/en/tools/search-research/youtubevideosearchtool.mdx @@ -108,10 +108,10 @@ youtube_search_tool = YoutubeVideoSearchTool( ), ), embedder=dict( - provider="google", # or openai, ollama, ... + provider="google-generativeai", # or openai, ollama, ... config=dict( - model="models/embedding-001", - task_type="retrieval_document", + model_name="gemini-embedding-001", + task_type="RETRIEVAL_DOCUMENT", # title="Embeddings", ), ), diff --git a/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py b/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py index 9716ca4e9..fb0a22791 100644 --- a/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py +++ b/lib/crewai-tools/src/crewai_tools/adapters/crewai_rag_adapter.py @@ -1,28 +1,51 @@ """Adapter for CrewAI's native RAG system.""" +from __future__ import annotations + import hashlib from pathlib import Path -from typing import Any, TypeAlias, TypedDict +from typing import TYPE_CHECKING, Any, TypeAlias, TypedDict, cast import uuid from crewai.rag.config.types import RagConfigType from crewai.rag.config.utils import get_rag_client from crewai.rag.core.base_client import BaseClient from crewai.rag.factory import create_client -from crewai.rag.qdrant.config import QdrantConfig from crewai.rag.types import BaseRecord, SearchResult from pydantic import PrivateAttr -from qdrant_client.models import VectorParams -from typing_extensions import Unpack +from pydantic.dataclasses import is_pydantic_dataclass +from typing_extensions import TypeIs, Unpack from crewai_tools.rag.data_types import DataType from crewai_tools.rag.misc import sanitize_metadata_for_chromadb from crewai_tools.tools.rag.rag_tool import Adapter +if TYPE_CHECKING: + from crewai.rag.qdrant.config import QdrantConfig + + ContentItem: TypeAlias = str | Path | dict[str, Any] +def _is_qdrant_config(config: Any) -> TypeIs[QdrantConfig]: + """Check if config is a QdrantConfig using safe duck typing. + + Args: + config: RAG configuration to check. + + Returns: + True if config is a QdrantConfig instance. + """ + if not is_pydantic_dataclass(config): + return False + + try: + return cast(bool, config.provider == "qdrant") # type: ignore[attr-defined] + except (AttributeError, ImportError): + return False + + class AddDocumentParams(TypedDict, total=False): """Parameters for adding documents to the RAG system.""" @@ -56,8 +79,9 @@ class CrewAIRagAdapter(Adapter): else: self._client = get_rag_client() collection_params: dict[str, Any] = {"collection_name": self.collection_name} - if isinstance(self.config, QdrantConfig) and self.config.vectors_config: - if isinstance(self.config.vectors_config, VectorParams): + + if self.config is not None and _is_qdrant_config(self.config): + if self.config.vectors_config is not None: collection_params["vectors_config"] = self.config.vectors_config self._client.get_or_create_collection(**collection_params) diff --git a/lib/crewai-tools/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py b/lib/crewai-tools/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py index 049745d45..3689b8925 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/pdf_search_tool/pdf_search_tool.py @@ -1,4 +1,5 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self from crewai_tools.rag.data_types import DataType from crewai_tools.tools.rag.rag_tool import RagTool @@ -24,14 +25,17 @@ class PDFSearchTool(RagTool): "A tool that can be used to semantic search a query from a PDF's content." ) args_schema: type[BaseModel] = PDFSearchToolSchema + pdf: str | None = None - def __init__(self, pdf: str | None = None, **kwargs): - super().__init__(**kwargs) - if pdf is not None: - self.add(pdf) - self.description = f"A tool that can be used to semantic search a query the {pdf} PDF's content." + @model_validator(mode="after") + def _configure_for_pdf(self) -> Self: + """Configure tool for specific PDF if provided.""" + if self.pdf is not None: + self.add(self.pdf) + self.description = f"A tool that can be used to semantic search a query the {self.pdf} PDF's content." self.args_schema = FixedPDFSearchToolSchema self._generate_description() + return self def add(self, pdf: str) -> None: super().add(pdf, data_type=DataType.PDF_FILE) diff --git a/lib/crewai-tools/src/crewai_tools/tools/rag/__init__.py b/lib/crewai-tools/src/crewai_tools/tools/rag/__init__.py index e69de29bb..9985f63f7 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/rag/__init__.py +++ b/lib/crewai-tools/src/crewai_tools/tools/rag/__init__.py @@ -0,0 +1,10 @@ +from crewai.rag.embeddings.types import ProviderSpec + +from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig + + +__all__ = [ + "ProviderSpec", + "RagToolConfig", + "VectorDbConfig", +] diff --git a/lib/crewai-tools/src/crewai_tools/tools/rag/rag_tool.py b/lib/crewai-tools/src/crewai_tools/tools/rag/rag_tool.py index 743946226..549a01062 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/rag/rag_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/rag/rag_tool.py @@ -1,10 +1,74 @@ from abc import ABC, abstractmethod -import os -from typing import Any, cast +from typing import Any, Literal, cast -from crewai.rag.embeddings.factory import get_embedding_function +from crewai.rag.core.base_embeddings_callable import EmbeddingFunction +from crewai.rag.embeddings.factory import build_embedder +from crewai.rag.embeddings.types import ProviderSpec from crewai.tools import BaseTool -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + TypeAdapter, + ValidationError, + field_validator, + model_validator, +) +from typing_extensions import Self + +from crewai_tools.tools.rag.types import RagToolConfig, VectorDbConfig + + +def _validate_embedding_config( + value: dict[str, Any] | ProviderSpec, +) -> dict[str, Any] | ProviderSpec: + """Validate embedding config and provide clearer error messages for union validation. + + This pre-validator catches Pydantic ValidationErrors from the ProviderSpec union + and provides a cleaner, more focused error message that only shows the relevant + provider's validation errors instead of all 18 union members. + + Args: + value: The embedding configuration dictionary or validated ProviderSpec. + + Returns: + A validated ProviderSpec instance, or the original value if already validated + or missing required fields. + + Raises: + ValueError: If the configuration is invalid for the specified provider. + """ + if not isinstance(value, dict): + return value + + provider = value.get("provider") + if not provider: + return value + + try: + type_adapter: TypeAdapter[ProviderSpec] = TypeAdapter(ProviderSpec) + return type_adapter.validate_python(value) + except ValidationError as e: + provider_key = f"{provider.lower()}providerspec" + provider_errors = [ + err for err in e.errors() if provider_key in str(err.get("loc", "")).lower() + ] + + if provider_errors: + error_msgs = [] + for err in provider_errors: + loc_parts = err["loc"] + if str(loc_parts[0]).lower() == provider_key: + loc_parts = loc_parts[1:] + loc = ".".join(str(x) for x in loc_parts) + error_msgs.append(f" - {loc}: {err['msg']}") + + raise ValueError( + f"Invalid configuration for embedding provider '{provider}':\n" + + "\n".join(error_msgs) + ) from e + + raise class Adapter(BaseModel, ABC): @@ -46,139 +110,100 @@ class RagTool(BaseTool): summarize: bool = False similarity_threshold: float = 0.6 limit: int = 5 + collection_name: str = "rag_tool_collection" adapter: Adapter = Field(default_factory=_AdapterPlaceholder) - config: Any | None = None + config: RagToolConfig = Field( + default_factory=RagToolConfig, + description="Configuration format accepted by RagTool.", + ) + + @field_validator("config", mode="before") + @classmethod + def _validate_config(cls, value: Any) -> Any: + """Validate config with improved error messages for embedding providers.""" + if not isinstance(value, dict): + return value + + embedding_model = value.get("embedding_model") + if embedding_model: + try: + value["embedding_model"] = _validate_embedding_config(embedding_model) + except ValueError: + raise + + return value @model_validator(mode="after") - def _set_default_adapter(self): + def _ensure_adapter(self) -> Self: if isinstance(self.adapter, RagTool._AdapterPlaceholder): from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter - parsed_config = self._parse_config(self.config) - + provider_cfg = self._parse_config(self.config) self.adapter = CrewAIRagAdapter( - collection_name="rag_tool_collection", + collection_name=self.collection_name, summarize=self.summarize, similarity_threshold=self.similarity_threshold, limit=self.limit, - config=parsed_config, + config=provider_cfg, ) - return self - def _parse_config(self, config: Any) -> Any: - """Parse complex config format to extract provider-specific config. + def _parse_config(self, config: RagToolConfig) -> Any: + """Normalize the RagToolConfig into a provider-specific config object. - Raises: - ValueError: If the config format is invalid or uses unsupported providers. + Defaults to 'chromadb' with no extra provider config if none is supplied. """ - if config is None: - return None + if not config: + return self._create_provider_config("chromadb", {}, None) - if isinstance(config, dict) and "provider" in config: - return config + vectordb_cfg = cast(VectorDbConfig, config.get("vectordb", {})) + provider: Literal["chromadb", "qdrant"] = vectordb_cfg.get( + "provider", "chromadb" + ) + provider_config: dict[str, Any] = vectordb_cfg.get("config", {}) - if isinstance(config, dict): - if "vectordb" in config: - vectordb_config = config["vectordb"] - if isinstance(vectordb_config, dict) and "provider" in vectordb_config: - provider = vectordb_config["provider"] - provider_config = vectordb_config.get("config", {}) + supported = ("chromadb", "qdrant") + if provider not in supported: + raise ValueError( + f"Unsupported vector database provider: '{provider}'. " + f"CrewAI RAG currently supports: {', '.join(supported)}." + ) - supported_providers = ["chromadb", "qdrant"] - if provider not in supported_providers: - raise ValueError( - f"Unsupported vector database provider: '{provider}'. " - f"CrewAI RAG currently supports: {', '.join(supported_providers)}." - ) + embedding_spec: ProviderSpec | None = config.get("embedding_model") + if embedding_spec: + embedding_spec = cast( + ProviderSpec, _validate_embedding_config(embedding_spec) + ) - embedding_config = config.get("embedding_model") - embedding_function = None - if embedding_config and isinstance(embedding_config, dict): - embedding_function = self._create_embedding_function( - embedding_config, provider - ) - - return self._create_provider_config( - provider, provider_config, embedding_function - ) - return None - embedding_config = config.get("embedding_model") - embedding_function = None - if embedding_config and isinstance(embedding_config, dict): - embedding_function = self._create_embedding_function( - embedding_config, "chromadb" - ) - - return self._create_provider_config("chromadb", {}, embedding_function) - return config - - @staticmethod - def _create_embedding_function(embedding_config: dict, provider: str) -> Any: - """Create embedding function for the specified vector database provider.""" - embedding_provider = embedding_config.get("provider") - embedding_model_config = embedding_config.get("config", {}).copy() - - if "model" in embedding_model_config: - embedding_model_config["model_name"] = embedding_model_config.pop("model") - - factory_config = {"provider": embedding_provider, **embedding_model_config} - - if embedding_provider == "openai" and "api_key" not in factory_config: - api_key = os.getenv("OPENAI_API_KEY") - if api_key: - factory_config["api_key"] = api_key - - if provider == "chromadb": - return get_embedding_function(factory_config) # type: ignore[call-overload] - - if provider == "qdrant": - chromadb_func = get_embedding_function(factory_config) # type: ignore[call-overload] - - def qdrant_embed_fn(text: str) -> list[float]: - """Embed text using ChromaDB function and convert to list of floats for Qdrant. - - Args: - text: The input text to embed. - - Returns: - A list of floats representing the embedding. - """ - embeddings = chromadb_func([text]) - return embeddings[0] if embeddings and len(embeddings) > 0 else [] - - return cast(Any, qdrant_embed_fn) - - return None + embedding_function = build_embedder(embedding_spec) if embedding_spec else None + return self._create_provider_config( + provider, provider_config, embedding_function + ) @staticmethod def _create_provider_config( - provider: str, provider_config: dict, embedding_function: Any + provider: Literal["chromadb", "qdrant"], + provider_config: dict[str, Any], + embedding_function: EmbeddingFunction[Any] | None, ) -> Any: - """Create proper provider config object.""" + """Instantiate provider config with optional embedding_function injected.""" if provider == "chromadb": from crewai.rag.chromadb.config import ChromaDBConfig - config_kwargs = {} - if embedding_function: - config_kwargs["embedding_function"] = embedding_function - - config_kwargs.update(provider_config) - - return ChromaDBConfig(**config_kwargs) + kwargs = dict(provider_config) + if embedding_function is not None: + kwargs["embedding_function"] = embedding_function + return ChromaDBConfig(**kwargs) if provider == "qdrant": from crewai.rag.qdrant.config import QdrantConfig - config_kwargs = {} - if embedding_function: - config_kwargs["embedding_function"] = embedding_function + kwargs = dict(provider_config) + if embedding_function is not None: + kwargs["embedding_function"] = embedding_function + return QdrantConfig(**kwargs) - config_kwargs.update(provider_config) - - return QdrantConfig(**config_kwargs) - - return None + raise ValueError(f"Unhandled provider: {provider}") def add( self, diff --git a/lib/crewai-tools/src/crewai_tools/tools/rag/types.py b/lib/crewai-tools/src/crewai_tools/tools/rag/types.py new file mode 100644 index 000000000..1077c7b9b --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/tools/rag/types.py @@ -0,0 +1,32 @@ +"""Type definitions for RAG tool configuration.""" + +from typing import Any, Literal + +from crewai.rag.embeddings.types import ProviderSpec +from typing_extensions import TypedDict + + +class VectorDbConfig(TypedDict): + """Configuration for vector database provider. + + Attributes: + provider: RAG provider literal. + config: RAG configuration options. + """ + + provider: Literal["chromadb", "qdrant"] + config: dict[str, Any] + + +class RagToolConfig(TypedDict, total=False): + """Configuration accepted by RAG tools. + + Supports embedding model and vector database configuration. + + Attributes: + embedding_model: Embedding model configuration accepted by RAG tools. + vectordb: Vector database configuration accepted by RAG tools. + """ + + embedding_model: ProviderSpec + vectordb: VectorDbConfig diff --git a/lib/crewai-tools/src/crewai_tools/tools/txt_search_tool/txt_search_tool.py b/lib/crewai-tools/src/crewai_tools/tools/txt_search_tool/txt_search_tool.py index 12bb00a18..e287b504e 100644 --- a/lib/crewai-tools/src/crewai_tools/tools/txt_search_tool/txt_search_tool.py +++ b/lib/crewai-tools/src/crewai_tools/tools/txt_search_tool/txt_search_tool.py @@ -1,4 +1,5 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self from crewai_tools.tools.rag.rag_tool import RagTool @@ -24,14 +25,17 @@ class TXTSearchTool(RagTool): "A tool that can be used to semantic search a query from a txt's content." ) args_schema: type[BaseModel] = TXTSearchToolSchema + txt: str | None = None - def __init__(self, txt: str | None = None, **kwargs): - super().__init__(**kwargs) - if txt is not None: - self.add(txt) - self.description = f"A tool that can be used to semantic search a query the {txt} txt's content." + @model_validator(mode="after") + def _configure_for_txt(self) -> Self: + """Configure tool for specific TXT file if provided.""" + if self.txt is not None: + self.add(self.txt) + self.description = f"A tool that can be used to semantic search a query the {self.txt} txt's content." self.args_schema = FixedTXTSearchToolSchema self._generate_description() + return self def _run( # type: ignore[override] self, diff --git a/lib/crewai-tools/tests/tools/rag/rag_tool_test.py b/lib/crewai-tools/tests/tools/rag/rag_tool_test.py index 5298ce1e2..48411699e 100644 --- a/lib/crewai-tools/tests/tools/rag/rag_tool_test.py +++ b/lib/crewai-tools/tests/tools/rag/rag_tool_test.py @@ -1,5 +1,3 @@ -"""Tests for RAG tool with mocked embeddings and vector database.""" - from pathlib import Path from tempfile import TemporaryDirectory from typing import cast @@ -117,15 +115,15 @@ def test_rag_tool_with_file( assert "Python is a programming language" in result -@patch("crewai_tools.tools.rag.rag_tool.RagTool._create_embedding_function") +@patch("crewai_tools.tools.rag.rag_tool.build_embedder") @patch("crewai_tools.adapters.crewai_rag_adapter.create_client") def test_rag_tool_with_custom_embeddings( - mock_create_client: Mock, mock_create_embedding: Mock + mock_create_client: Mock, mock_build_embedder: Mock ) -> None: """Test RagTool with custom embeddings configuration to ensure no API calls.""" mock_embedding_func = MagicMock() mock_embedding_func.return_value = [[0.2] * 1536] - mock_create_embedding.return_value = mock_embedding_func + mock_build_embedder.return_value = mock_embedding_func mock_client = MagicMock() mock_client.get_or_create_collection = MagicMock(return_value=None) @@ -153,7 +151,7 @@ def test_rag_tool_with_custom_embeddings( assert "Relevant Content:" in result assert "Test content" in result - mock_create_embedding.assert_called() + mock_build_embedder.assert_called() @patch("crewai_tools.adapters.crewai_rag_adapter.get_rag_client") @@ -176,3 +174,128 @@ def test_rag_tool_no_results( result = tool._run(query="Non-existent content") assert "Relevant Content:" in result assert "No relevant content found" in result + + +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_rag_tool_with_azure_config_without_env_vars( + mock_create_client: Mock, +) -> None: + """Test that RagTool accepts Azure config without requiring env vars. + + This test verifies the fix for the issue where RAG tools were ignoring + the embedding configuration passed via the config parameter and instead + requiring environment variables like EMBEDDINGS_OPENAI_API_KEY. + """ + mock_embedding_func = MagicMock() + mock_embedding_func.return_value = [[0.1] * 1536] + + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_client.add_documents = MagicMock(return_value=None) + mock_create_client.return_value = mock_client + + # Patch the embedding function builder to avoid actual API calls + with patch( + "crewai_tools.tools.rag.rag_tool.build_embedder", + return_value=mock_embedding_func, + ): + + class MyTool(RagTool): + pass + + # Configuration with explicit Azure credentials - should work without env vars + config = { + "embedding_model": { + "provider": "azure", + "config": { + "model": "text-embedding-3-small", + "api_key": "test-api-key", + "api_base": "https://test.openai.azure.com/", + "api_version": "2024-02-01", + "api_type": "azure", + "deployment_id": "test-deployment", + }, + } + } + + # This should not raise a validation error about missing env vars + tool = MyTool(config=config) + + assert tool.adapter is not None + assert isinstance(tool.adapter, CrewAIRagAdapter) + + +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_rag_tool_with_openai_config_without_env_vars( + mock_create_client: Mock, +) -> None: + """Test that RagTool accepts OpenAI config without requiring env vars.""" + mock_embedding_func = MagicMock() + mock_embedding_func.return_value = [[0.1] * 1536] + + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_create_client.return_value = mock_client + + with patch( + "crewai_tools.tools.rag.rag_tool.build_embedder", + return_value=mock_embedding_func, + ): + + class MyTool(RagTool): + pass + + config = { + "embedding_model": { + "provider": "openai", + "config": { + "model": "text-embedding-3-small", + "api_key": "sk-test123", + }, + } + } + + tool = MyTool(config=config) + + assert tool.adapter is not None + assert isinstance(tool.adapter, CrewAIRagAdapter) + + +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_rag_tool_config_with_qdrant_and_azure_embeddings( + mock_create_client: Mock, +) -> None: + """Test RagTool with Qdrant vector DB and Azure embeddings config.""" + mock_embedding_func = MagicMock() + mock_embedding_func.return_value = [[0.1] * 1536] + + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_create_client.return_value = mock_client + + with patch( + "crewai_tools.tools.rag.rag_tool.build_embedder", + return_value=mock_embedding_func, + ): + + class MyTool(RagTool): + pass + + config = { + "vectordb": {"provider": "qdrant", "config": {}}, + "embedding_model": { + "provider": "azure", + "config": { + "model": "text-embedding-3-large", + "api_key": "test-key", + "api_base": "https://test.openai.azure.com/", + "api_version": "2024-02-01", + "deployment_id": "test-deployment", + }, + }, + } + + tool = MyTool(config=config) + + assert tool.adapter is not None + assert isinstance(tool.adapter, CrewAIRagAdapter) diff --git a/lib/crewai-tools/tests/tools/rag/test_rag_tool_validation.py b/lib/crewai-tools/tests/tools/rag/test_rag_tool_validation.py new file mode 100644 index 000000000..773425187 --- /dev/null +++ b/lib/crewai-tools/tests/tools/rag/test_rag_tool_validation.py @@ -0,0 +1,66 @@ +"""Tests for improved RAG tool validation error messages.""" + +from unittest.mock import MagicMock, Mock, patch + +import pytest +from pydantic import ValidationError + +from crewai_tools.tools.rag.rag_tool import RagTool + + +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_azure_missing_deployment_id_gives_clear_error(mock_create_client: Mock) -> None: + """Test that missing deployment_id for Azure gives a clear, focused error message.""" + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_create_client.return_value = mock_client + + class MyTool(RagTool): + pass + + config = { + "embedding_model": { + "provider": "azure", + "config": { + "api_base": "http://localhost:4000/v1", + "api_key": "test-key", + "api_version": "2024-02-01", + }, + } + } + + with pytest.raises(ValueError) as exc_info: + MyTool(config=config) + + error_msg = str(exc_info.value) + assert "azure" in error_msg.lower() + assert "deployment_id" in error_msg.lower() + assert "bedrock" not in error_msg.lower() + assert "cohere" not in error_msg.lower() + assert "huggingface" not in error_msg.lower() + + +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_valid_azure_config_works(mock_create_client: Mock) -> None: + """Test that valid Azure config works without errors.""" + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_create_client.return_value = mock_client + + class MyTool(RagTool): + pass + + config = { + "embedding_model": { + "provider": "azure", + "config": { + "api_base": "http://localhost:4000/v1", + "api_key": "test-key", + "api_version": "2024-02-01", + "deployment_id": "text-embedding-3-small", + }, + } + } + + tool = MyTool(config=config) + assert tool is not None \ No newline at end of file diff --git a/lib/crewai-tools/tests/tools/test_pdf_search_tool_config.py b/lib/crewai-tools/tests/tools/test_pdf_search_tool_config.py new file mode 100644 index 000000000..fc893dfb2 --- /dev/null +++ b/lib/crewai-tools/tests/tools/test_pdf_search_tool_config.py @@ -0,0 +1,116 @@ +from unittest.mock import MagicMock, Mock, patch + +from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter +from crewai_tools.tools.pdf_search_tool.pdf_search_tool import PDFSearchTool + + +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_pdf_search_tool_with_azure_config_without_env_vars( + mock_create_client: Mock, +) -> None: + """Test PDFSearchTool accepts Azure config without requiring env vars. + + This verifies the fix for the reported issue where PDFSearchTool would + throw a validation error: + pydantic_core._pydantic_core.ValidationError: 1 validation error for PDFSearchTool + EMBEDDINGS_OPENAI_API_KEY + Field required [type=missing, input_value={}, input_type=dict] + """ + mock_embedding_func = MagicMock() + mock_embedding_func.return_value = [[0.1] * 1536] + + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_create_client.return_value = mock_client + + # Patch the embedding function builder to avoid actual API calls + with patch( + "crewai_tools.tools.rag.rag_tool.build_embedder", + return_value=mock_embedding_func, + ): + # This is the exact config format from the bug report + config = { + "embedding_model": { + "provider": "azure", + "config": { + "model": "text-embedding-3-small", + "api_key": "test-litellm-api-key", + "api_base": "https://test.litellm.proxy/", + "api_version": "2024-02-01", + "api_type": "azure", + "deployment_id": "test-deployment", + }, + } + } + + # This should not raise a validation error about missing env vars + tool = PDFSearchTool(config=config) + + assert tool.adapter is not None + assert isinstance(tool.adapter, CrewAIRagAdapter) + assert tool.name == "Search a PDF's content" + + +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_pdf_search_tool_with_openai_config_without_env_vars( + mock_create_client: Mock, +) -> None: + """Test PDFSearchTool accepts OpenAI config without requiring env vars.""" + mock_embedding_func = MagicMock() + mock_embedding_func.return_value = [[0.1] * 1536] + + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_create_client.return_value = mock_client + + with patch( + "crewai_tools.tools.rag.rag_tool.build_embedder", + return_value=mock_embedding_func, + ): + config = { + "embedding_model": { + "provider": "openai", + "config": { + "model": "text-embedding-3-small", + "api_key": "sk-test123", + }, + } + } + + tool = PDFSearchTool(config=config) + + assert tool.adapter is not None + assert isinstance(tool.adapter, CrewAIRagAdapter) + + +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_pdf_search_tool_with_vectordb_and_embedding_config( + mock_create_client: Mock, +) -> None: + """Test PDFSearchTool with both vector DB and embedding config.""" + mock_embedding_func = MagicMock() + mock_embedding_func.return_value = [[0.1] * 1536] + + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_create_client.return_value = mock_client + + with patch( + "crewai_tools.tools.rag.rag_tool.build_embedder", + return_value=mock_embedding_func, + ): + config = { + "vectordb": {"provider": "chromadb", "config": {}}, + "embedding_model": { + "provider": "openai", + "config": { + "model": "text-embedding-3-large", + "api_key": "sk-test-key", + }, + }, + } + + tool = PDFSearchTool(config=config) + + assert tool.adapter is not None + assert isinstance(tool.adapter, CrewAIRagAdapter) \ No newline at end of file diff --git a/lib/crewai-tools/tests/tools/test_txt_search_tool_config.py b/lib/crewai-tools/tests/tools/test_txt_search_tool_config.py new file mode 100644 index 000000000..266f9bef7 --- /dev/null +++ b/lib/crewai-tools/tests/tools/test_txt_search_tool_config.py @@ -0,0 +1,104 @@ +from unittest.mock import MagicMock, Mock, patch + +from crewai_tools.adapters.crewai_rag_adapter import CrewAIRagAdapter +from crewai_tools.tools.txt_search_tool.txt_search_tool import TXTSearchTool + + +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_txt_search_tool_with_azure_config_without_env_vars( + mock_create_client: Mock, +) -> None: + """Test TXTSearchTool accepts Azure config without requiring env vars.""" + mock_embedding_func = MagicMock() + mock_embedding_func.return_value = [[0.1] * 1536] + + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_create_client.return_value = mock_client + + with patch( + "crewai_tools.tools.rag.rag_tool.build_embedder", + return_value=mock_embedding_func, + ): + config = { + "embedding_model": { + "provider": "azure", + "config": { + "model": "text-embedding-3-small", + "api_key": "test-api-key", + "api_base": "https://test.openai.azure.com/", + "api_version": "2024-02-01", + "api_type": "azure", + "deployment_id": "test-deployment", + }, + } + } + + # This should not raise a validation error about missing env vars + tool = TXTSearchTool(config=config) + + assert tool.adapter is not None + assert isinstance(tool.adapter, CrewAIRagAdapter) + assert tool.name == "Search a txt's content" + + +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_txt_search_tool_with_openai_config_without_env_vars( + mock_create_client: Mock, +) -> None: + """Test TXTSearchTool accepts OpenAI config without requiring env vars.""" + mock_embedding_func = MagicMock() + mock_embedding_func.return_value = [[0.1] * 1536] + + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_create_client.return_value = mock_client + + with patch( + "crewai_tools.tools.rag.rag_tool.build_embedder", + return_value=mock_embedding_func, + ): + config = { + "embedding_model": { + "provider": "openai", + "config": { + "model": "text-embedding-3-small", + "api_key": "sk-test123", + }, + } + } + + tool = TXTSearchTool(config=config) + + assert tool.adapter is not None + assert isinstance(tool.adapter, CrewAIRagAdapter) + + +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_txt_search_tool_with_cohere_config(mock_create_client: Mock) -> None: + """Test TXTSearchTool with Cohere embedding provider.""" + mock_embedding_func = MagicMock() + mock_embedding_func.return_value = [[0.1] * 1024] + + mock_client = MagicMock() + mock_client.get_or_create_collection = MagicMock(return_value=None) + mock_create_client.return_value = mock_client + + with patch( + "crewai_tools.tools.rag.rag_tool.build_embedder", + return_value=mock_embedding_func, + ): + config = { + "embedding_model": { + "provider": "cohere", + "config": { + "model": "embed-english-v3.0", + "api_key": "test-cohere-key", + }, + } + } + + tool = TXTSearchTool(config=config) + + assert tool.adapter is not None + assert isinstance(tool.adapter, CrewAIRagAdapter) \ No newline at end of file diff --git a/lib/crewai/src/crewai/rag/embeddings/factory.py b/lib/crewai/src/crewai/rag/embeddings/factory.py index f5d5d9559..eacf67b82 100644 --- a/lib/crewai/src/crewai/rag/embeddings/factory.py +++ b/lib/crewai/src/crewai/rag/embeddings/factory.py @@ -91,6 +91,7 @@ PROVIDER_PATHS = { "cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider", "custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider", "google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider", + "google": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider", "google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider", "huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider", "instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider", diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/aws/bedrock.py b/lib/crewai/src/crewai/rag/embeddings/providers/aws/bedrock.py index 7d7c7bae4..1a0665110 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/aws/bedrock.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/aws/bedrock.py @@ -5,7 +5,7 @@ from typing import Any from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import ( AmazonBedrockEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -21,7 +21,7 @@ def create_aws_session() -> Any: ValueError: If AWS session creation fails """ try: - import boto3 # type: ignore[import] + import boto3 return boto3.Session() except ImportError as e: @@ -46,7 +46,12 @@ class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]): model_name: str = Field( default="amazon.titan-embed-text-v1", description="Model name to use for embeddings", - validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME", + validation_alias=AliasChoices( + "EMBEDDINGS_BEDROCK_MODEL_NAME", + "BEDROCK_MODEL_NAME", + "AWS_BEDROCK_MODEL_NAME", + "model", + ), ) session: Any = Field( default_factory=create_aws_session, description="AWS session object" diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/cohere/cohere_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/cohere/cohere_provider.py index b5df0f7dd..90f49eb2c 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/cohere/cohere_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/cohere/cohere_provider.py @@ -3,7 +3,7 @@ from chromadb.utils.embedding_functions.cohere_embedding_function import ( CohereEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -15,10 +15,14 @@ class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]): default=CohereEmbeddingFunction, description="Cohere embedding function class" ) api_key: str = Field( - description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY" + description="Cohere API key", + validation_alias=AliasChoices("EMBEDDINGS_COHERE_API_KEY", "COHERE_API_KEY"), ) model_name: str = Field( default="large", description="Model name to use for embeddings", - validation_alias="EMBEDDINGS_COHERE_MODEL_NAME", + validation_alias=AliasChoices( + "EMBEDDINGS_COHERE_MODEL_NAME", + "model", + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/google/generative_ai.py b/lib/crewai/src/crewai/rag/embeddings/providers/google/generative_ai.py index 6bd6a8c58..28e3db690 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/google/generative_ai.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/google/generative_ai.py @@ -1,9 +1,11 @@ """Google Generative AI embeddings provider.""" +from typing import Literal + from chromadb.utils.embedding_functions.google_embedding_function import ( GoogleGenerativeAiEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -15,16 +17,27 @@ class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFun default=GoogleGenerativeAiEmbeddingFunction, description="Google Generative AI embedding function class", ) - model_name: str = Field( - default="models/embedding-001", + model_name: Literal[ + "gemini-embedding-001", "text-embedding-005", "text-multilingual-embedding-002" + ] = Field( + default="gemini-embedding-001", description="Model name to use for embeddings", - validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME", + validation_alias=AliasChoices( + "EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME", "model" + ), ) api_key: str = Field( - description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY" + description="Google API key", + validation_alias=AliasChoices( + "EMBEDDINGS_GOOGLE_API_KEY", "GOOGLE_API_KEY", "GEMINI_API_KEY" + ), ) task_type: str = Field( default="RETRIEVAL_DOCUMENT", description="Task type for embeddings", - validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE", + validation_alias=AliasChoices( + "EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE", + "GOOGLE_GENERATIVE_AI_TASK_TYPE", + "GEMINI_TASK_TYPE", + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/google/types.py b/lib/crewai/src/crewai/rag/embeddings/providers/google/types.py index b97ec9474..eaec1cd8f 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/google/types.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/google/types.py @@ -6,10 +6,23 @@ from typing_extensions import Required, TypedDict class GenerativeAiProviderConfig(TypedDict, total=False): - """Configuration for Google Generative AI provider.""" + """Configuration for Google Generative AI provider. + + Attributes: + api_key: Google API key for authentication. + model_name: Embedding model name. + task_type: Task type for embeddings. Default is "RETRIEVAL_DOCUMENT". + """ api_key: str - model_name: Annotated[str, "models/embedding-001"] + model_name: Annotated[ + Literal[ + "gemini-embedding-001", + "text-embedding-005", + "text-multilingual-embedding-002", + ], + "gemini-embedding-001", + ] task_type: Annotated[str, "RETRIEVAL_DOCUMENT"] diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/google/vertex.py b/lib/crewai/src/crewai/rag/embeddings/providers/google/vertex.py index ab14177b9..6547e4abc 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/google/vertex.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/google/vertex.py @@ -3,7 +3,7 @@ from chromadb.utils.embedding_functions.google_embedding_function import ( GoogleVertexEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -18,18 +18,29 @@ class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]): model_name: str = Field( default="textembedding-gecko", description="Model name to use for embeddings", - validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME", + validation_alias=AliasChoices( + "EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME", + "GOOGLE_VERTEX_MODEL_NAME", + "model", + ), ) api_key: str = Field( - description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY" + description="Google API key", + validation_alias=AliasChoices( + "EMBEDDINGS_GOOGLE_CLOUD_API_KEY", "GOOGLE_CLOUD_API_KEY" + ), ) project_id: str = Field( default="cloud-large-language-models", description="GCP project ID", - validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT", + validation_alias=AliasChoices( + "EMBEDDINGS_GOOGLE_CLOUD_PROJECT", "GOOGLE_CLOUD_PROJECT" + ), ) region: str = Field( default="us-central1", description="GCP region", - validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION", + validation_alias=AliasChoices( + "EMBEDDINGS_GOOGLE_CLOUD_REGION", "GOOGLE_CLOUD_REGION" + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py index cd774c17f..481e9f8ba 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/huggingface/huggingface_provider.py @@ -3,7 +3,7 @@ from chromadb.utils.embedding_functions.huggingface_embedding_function import ( HuggingFaceEmbeddingServer, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -16,5 +16,6 @@ class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]): description="HuggingFace embedding function class", ) url: str = Field( - description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL" + description="HuggingFace API URL", + validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/ibm/watsonx.py b/lib/crewai/src/crewai/rag/embeddings/providers/ibm/watsonx.py index bcd52804c..062ca31f8 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/ibm/watsonx.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/ibm/watsonx.py @@ -2,7 +2,7 @@ from typing import Any -from pydantic import Field, model_validator +from pydantic import AliasChoices, Field, model_validator from typing_extensions import Self from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -21,7 +21,10 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]): default=WatsonXEmbeddingFunction, description="WatsonX embedding function class" ) model_id: str = Field( - description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID" + description="WatsonX model ID", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_MODEL_ID", "WATSONX_MODEL_ID" + ), ) params: dict[str, str | dict[str, str]] | None = Field( default=None, description="Additional parameters" @@ -30,109 +33,143 @@ class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]): project_id: str | None = Field( default=None, description="WatsonX project ID", - validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_PROJECT_ID", "WATSONX_PROJECT_ID" + ), ) space_id: str | None = Field( default=None, description="WatsonX space ID", - validation_alias="EMBEDDINGS_WATSONX_SPACE_ID", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_SPACE_ID", "WATSONX_SPACE_ID" + ), ) api_client: Any | None = Field(default=None, description="WatsonX API client") verify: bool | str | None = Field( default=None, description="SSL verification", - validation_alias="EMBEDDINGS_WATSONX_VERIFY", + validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERIFY", "WATSONX_VERIFY"), ) persistent_connection: bool = Field( default=True, description="Use persistent connection", - validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION", "WATSONX_PERSISTENT_CONNECTION" + ), ) batch_size: int = Field( default=100, description="Batch size for processing", - validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_BATCH_SIZE", "WATSONX_BATCH_SIZE" + ), ) concurrency_limit: int = Field( default=10, description="Concurrency limit", - validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT", "WATSONX_CONCURRENCY_LIMIT" + ), ) max_retries: int | None = Field( default=None, description="Maximum retries", - validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_MAX_RETRIES", "WATSONX_MAX_RETRIES" + ), ) delay_time: float | None = Field( default=None, description="Delay time between retries", - validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_DELAY_TIME", "WATSONX_DELAY_TIME" + ), ) retry_status_codes: list[int] | None = Field( default=None, description="HTTP status codes to retry on" ) url: str = Field( - description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL" + description="WatsonX API URL", + validation_alias=AliasChoices("EMBEDDINGS_WATSONX_URL", "WATSONX_URL"), ) api_key: str = Field( - description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY" + description="WatsonX API key", + validation_alias=AliasChoices("EMBEDDINGS_WATSONX_API_KEY", "WATSONX_API_KEY"), ) name: str | None = Field( default=None, description="Service name", - validation_alias="EMBEDDINGS_WATSONX_NAME", + validation_alias=AliasChoices("EMBEDDINGS_WATSONX_NAME", "WATSONX_NAME"), ) iam_serviceid_crn: str | None = Field( default=None, description="IAM service ID CRN", - validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN", "WATSONX_IAM_SERVICEID_CRN" + ), ) trusted_profile_id: str | None = Field( default=None, description="Trusted profile ID", - validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID", "WATSONX_TRUSTED_PROFILE_ID" + ), ) token: str | None = Field( default=None, description="Bearer token", - validation_alias="EMBEDDINGS_WATSONX_TOKEN", + validation_alias=AliasChoices("EMBEDDINGS_WATSONX_TOKEN", "WATSONX_TOKEN"), ) projects_token: str | None = Field( default=None, description="Projects token", - validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_PROJECTS_TOKEN", "WATSONX_PROJECTS_TOKEN" + ), ) username: str | None = Field( default=None, description="Username", - validation_alias="EMBEDDINGS_WATSONX_USERNAME", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_USERNAME", "WATSONX_USERNAME" + ), ) password: str | None = Field( default=None, description="Password", - validation_alias="EMBEDDINGS_WATSONX_PASSWORD", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_PASSWORD", "WATSONX_PASSWORD" + ), ) instance_id: str | None = Field( default=None, description="Service instance ID", - validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_INSTANCE_ID", "WATSONX_INSTANCE_ID" + ), ) version: str | None = Field( default=None, description="API version", - validation_alias="EMBEDDINGS_WATSONX_VERSION", + validation_alias=AliasChoices("EMBEDDINGS_WATSONX_VERSION", "WATSONX_VERSION"), ) bedrock_url: str | None = Field( default=None, description="Bedrock URL", - validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_BEDROCK_URL", "WATSONX_BEDROCK_URL" + ), ) platform_url: str | None = Field( default=None, description="Platform URL", - validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL", + validation_alias=AliasChoices( + "EMBEDDINGS_WATSONX_PLATFORM_URL", "WATSONX_PLATFORM_URL" + ), + ) + proxies: dict[str, Any] | None = Field( + default=None, description="Proxy configuration" ) - proxies: dict | None = Field(default=None, description="Proxy configuration") @model_validator(mode="after") def validate_space_or_project(self) -> Self: diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/instructor/instructor_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/instructor/instructor_provider.py index 030dbac30..d569d989d 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/instructor/instructor_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/instructor/instructor_provider.py @@ -3,7 +3,7 @@ from chromadb.utils.embedding_functions.instructor_embedding_function import ( InstructorEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -18,15 +18,23 @@ class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]): model_name: str = Field( default="hkunlp/instructor-base", description="Model name to use", - validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME", + validation_alias=AliasChoices( + "EMBEDDINGS_INSTRUCTOR_MODEL_NAME", + "INSTRUCTOR_MODEL_NAME", + "model", + ), ) device: str = Field( default="cpu", description="Device to run model on (cpu or cuda)", - validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE", + validation_alias=AliasChoices( + "EMBEDDINGS_INSTRUCTOR_DEVICE", "INSTRUCTOR_DEVICE" + ), ) instruction: str | None = Field( default=None, description="Instruction for embeddings", - validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION", + validation_alias=AliasChoices( + "EMBEDDINGS_INSTRUCTOR_INSTRUCTION", "INSTRUCTOR_INSTRUCTION" + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/jina/jina_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/jina/jina_provider.py index 8b5fab3f2..0c85fedbf 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/jina/jina_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/jina/jina_provider.py @@ -3,7 +3,7 @@ from chromadb.utils.embedding_functions.jina_embedding_function import ( JinaEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -15,10 +15,15 @@ class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]): default=JinaEmbeddingFunction, description="Jina embedding function class" ) api_key: str = Field( - description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY" + description="Jina API key", + validation_alias=AliasChoices("EMBEDDINGS_JINA_API_KEY", "JINA_API_KEY"), ) model_name: str = Field( default="jina-embeddings-v2-base-en", description="Model name to use for embeddings", - validation_alias="EMBEDDINGS_JINA_MODEL_NAME", + validation_alias=AliasChoices( + "EMBEDDINGS_JINA_MODEL_NAME", + "JINA_MODEL_NAME", + "model", + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/microsoft/azure.py b/lib/crewai/src/crewai/rag/embeddings/providers/microsoft/azure.py index d798b5e0d..e1d03dd19 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/microsoft/azure.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/microsoft/azure.py @@ -5,7 +5,7 @@ from typing import Any from chromadb.utils.embedding_functions.openai_embedding_function import ( OpenAIEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -18,27 +18,39 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]): description="Azure OpenAI embedding function class", ) api_key: str = Field( - description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY" + description="Azure API key", + validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"), ) api_base: str | None = Field( default=None, description="Azure endpoint URL", - validation_alias="EMBEDDINGS_OPENAI_API_BASE", + validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"), ) api_type: str = Field( default="azure", description="API type for Azure", - validation_alias="EMBEDDINGS_OPENAI_API_TYPE", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE", "AZURE_OPENAI_API_TYPE" + ), ) api_version: str | None = Field( - default=None, + default="2024-02-01", description="Azure API version", - validation_alias="EMBEDDINGS_OPENAI_API_VERSION", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENAI_API_VERSION", + "OPENAI_API_VERSION", + "AZURE_OPENAI_API_VERSION", + ), ) model_name: str = Field( default="text-embedding-ada-002", description="Model name to use for embeddings", - validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENAI_MODEL_NAME", + "OPENAI_MODEL_NAME", + "AZURE_OPENAI_MODEL_NAME", + "model", + ), ) default_headers: dict[str, Any] | None = Field( default=None, description="Default headers for API requests" @@ -46,15 +58,26 @@ class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]): dimensions: int | None = Field( default=None, description="Embedding dimensions", - validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENAI_DIMENSIONS", + "OPENAI_DIMENSIONS", + "AZURE_OPENAI_DIMENSIONS", + ), ) - deployment_id: str | None = Field( - default=None, + deployment_id: str = Field( description="Azure deployment ID", - validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENAI_DEPLOYMENT_ID", + "AZURE_OPENAI_DEPLOYMENT", + "AZURE_DEPLOYMENT_ID", + ), ) organization_id: str | None = Field( default=None, description="Organization ID", - validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENAI_ORGANIZATION_ID", + "OPENAI_ORGANIZATION_ID", + "AZURE_OPENAI_ORGANIZATION_ID", + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/microsoft/types.py b/lib/crewai/src/crewai/rag/embeddings/providers/microsoft/types.py index 2f00cf787..45dc2b2ef 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/microsoft/types.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/microsoft/types.py @@ -15,7 +15,7 @@ class AzureProviderConfig(TypedDict, total=False): model_name: Annotated[str, "text-embedding-ada-002"] default_headers: dict[str, Any] dimensions: int - deployment_id: str + deployment_id: Required[str] organization_id: str diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/ollama/ollama_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/ollama/ollama_provider.py index 8f16f92dd..0dd1024de 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/ollama/ollama_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/ollama/ollama_provider.py @@ -3,7 +3,7 @@ from chromadb.utils.embedding_functions.ollama_embedding_function import ( OllamaEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -17,9 +17,14 @@ class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]): url: str = Field( default="http://localhost:11434/api/embeddings", description="Ollama API endpoint URL", - validation_alias="EMBEDDINGS_OLLAMA_URL", + validation_alias=AliasChoices("EMBEDDINGS_OLLAMA_URL", "OLLAMA_URL"), ) model_name: str = Field( description="Model name to use for embeddings", - validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME", + validation_alias=AliasChoices( + "EMBEDDINGS_OLLAMA_MODEL_NAME", + "OLLAMA_MODEL_NAME", + "OLLAMA_MODEL", + "model", + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/onnx/onnx_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/onnx/onnx_provider.py index 17ec7bdf3..7cc3b9739 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/onnx/onnx_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/onnx/onnx_provider.py @@ -1,7 +1,7 @@ """ONNX embeddings provider.""" from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2 -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -15,5 +15,7 @@ class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]): preferred_providers: list[str] | None = Field( default=None, description="Preferred ONNX execution providers", - validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS", + validation_alias=AliasChoices( + "EMBEDDINGS_ONNX_PREFERRED_PROVIDERS", "ONNX_PREFERRED_PROVIDERS" + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/openai/openai_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/openai/openai_provider.py index a4531dafe..67017add4 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/openai/openai_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/openai/openai_provider.py @@ -5,7 +5,7 @@ from typing import Any from chromadb.utils.embedding_functions.openai_embedding_function import ( OpenAIEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -20,27 +20,33 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]): api_key: str | None = Field( default=None, description="OpenAI API key", - validation_alias="EMBEDDINGS_OPENAI_API_KEY", + validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_KEY", "OPENAI_API_KEY"), ) model_name: str = Field( default="text-embedding-ada-002", description="Model name to use for embeddings", - validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENAI_MODEL_NAME", + "OPENAI_MODEL_NAME", + "model", + ), ) api_base: str | None = Field( default=None, description="Base URL for API requests", - validation_alias="EMBEDDINGS_OPENAI_API_BASE", + validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_BASE", "OPENAI_API_BASE"), ) api_type: str | None = Field( default=None, description="API type (e.g., 'azure')", - validation_alias="EMBEDDINGS_OPENAI_API_TYPE", + validation_alias=AliasChoices("EMBEDDINGS_OPENAI_API_TYPE", "OPENAI_API_TYPE"), ) api_version: str | None = Field( default=None, description="API version", - validation_alias="EMBEDDINGS_OPENAI_API_VERSION", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENAI_API_VERSION", "OPENAI_API_VERSION" + ), ) default_headers: dict[str, Any] | None = Field( default=None, description="Default headers for API requests" @@ -48,15 +54,21 @@ class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]): dimensions: int | None = Field( default=None, description="Embedding dimensions", - validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENAI_DIMENSIONS", "OPENAI_DIMENSIONS" + ), ) deployment_id: str | None = Field( default=None, description="Azure deployment ID", - validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENAI_DEPLOYMENT_ID", "OPENAI_DEPLOYMENT_ID" + ), ) organization_id: str | None = Field( default=None, description="OpenAI organization ID", - validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENAI_ORGANIZATION_ID", "OPENAI_ORGANIZATION_ID" + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/openclip/openclip_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/openclip/openclip_provider.py index 74b025087..b790ce897 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/openclip/openclip_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/openclip/openclip_provider.py @@ -3,7 +3,7 @@ from chromadb.utils.embedding_functions.open_clip_embedding_function import ( OpenCLIPEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -18,15 +18,21 @@ class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]): model_name: str = Field( default="ViT-B-32", description="Model name to use", - validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENCLIP_MODEL_NAME", + "OPENCLIP_MODEL_NAME", + "model", + ), ) checkpoint: str = Field( default="laion2b_s34b_b79k", description="Model checkpoint", - validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT", + validation_alias=AliasChoices( + "EMBEDDINGS_OPENCLIP_CHECKPOINT", "OPENCLIP_CHECKPOINT" + ), ) device: str | None = Field( default="cpu", description="Device to run model on", - validation_alias="EMBEDDINGS_OPENCLIP_DEVICE", + validation_alias=AliasChoices("EMBEDDINGS_OPENCLIP_DEVICE", "OPENCLIP_DEVICE"), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/roboflow/roboflow_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/roboflow/roboflow_provider.py index 5ac33b4bb..a6f310d21 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/roboflow/roboflow_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/roboflow/roboflow_provider.py @@ -3,7 +3,7 @@ from chromadb.utils.embedding_functions.roboflow_embedding_function import ( RoboflowEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -18,10 +18,14 @@ class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]): api_key: str = Field( default="", description="Roboflow API key", - validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY", + validation_alias=AliasChoices( + "EMBEDDINGS_ROBOFLOW_API_KEY", "ROBOFLOW_API_KEY" + ), ) api_url: str = Field( default="https://infer.roboflow.com", description="Roboflow API URL", - validation_alias="EMBEDDINGS_ROBOFLOW_API_URL", + validation_alias=AliasChoices( + "EMBEDDINGS_ROBOFLOW_API_URL", "ROBOFLOW_API_URL" + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/sentence_transformer/sentence_transformer_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/sentence_transformer/sentence_transformer_provider.py index af9bc0195..c045b4ebd 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/sentence_transformer/sentence_transformer_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/sentence_transformer/sentence_transformer_provider.py @@ -3,7 +3,7 @@ from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import ( SentenceTransformerEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -20,15 +20,24 @@ class SentenceTransformerProvider( model_name: str = Field( default="all-MiniLM-L6-v2", description="Model name to use", - validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME", + validation_alias=AliasChoices( + "EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME", + "SENTENCE_TRANSFORMER_MODEL_NAME", + "model", + ), ) device: str = Field( default="cpu", description="Device to run model on (cpu or cuda)", - validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE", + validation_alias=AliasChoices( + "EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE", "SENTENCE_TRANSFORMER_DEVICE" + ), ) normalize_embeddings: bool = Field( default=False, description="Whether to normalize embeddings", - validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS", + validation_alias=AliasChoices( + "EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS", + "SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS", + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/text2vec/text2vec_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/text2vec/text2vec_provider.py index e6ebb1c08..83497aa7e 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/text2vec/text2vec_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/text2vec/text2vec_provider.py @@ -3,7 +3,7 @@ from chromadb.utils.embedding_functions.text2vec_embedding_function import ( Text2VecEmbeddingFunction, ) -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider @@ -18,5 +18,9 @@ class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]): model_name: str = Field( default="shibing624/text2vec-base-chinese", description="Model name to use", - validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME", + validation_alias=AliasChoices( + "EMBEDDINGS_TEXT2VEC_MODEL_NAME", + "TEXT2VEC_MODEL_NAME", + "model", + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/voyageai/voyageai_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/voyageai/voyageai_provider.py index 133b02db7..75fcf59e7 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/voyageai/voyageai_provider.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/voyageai/voyageai_provider.py @@ -1,6 +1,6 @@ """Voyage AI embeddings provider.""" -from pydantic import Field +from pydantic import AliasChoices, Field from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider from crewai.rag.embeddings.providers.voyageai.embedding_callable import ( @@ -18,38 +18,53 @@ class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]): model: str = Field( default="voyage-2", description="Model to use for embeddings", - validation_alias="EMBEDDINGS_VOYAGEAI_MODEL", + validation_alias=AliasChoices("EMBEDDINGS_VOYAGEAI_MODEL", "VOYAGEAI_MODEL"), ) api_key: str = Field( - description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY" + description="Voyage AI API key", + validation_alias=AliasChoices( + "EMBEDDINGS_VOYAGEAI_API_KEY", "VOYAGEAI_API_KEY" + ), ) input_type: str | None = Field( default=None, description="Input type for embeddings", - validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE", + validation_alias=AliasChoices( + "EMBEDDINGS_VOYAGEAI_INPUT_TYPE", "VOYAGEAI_INPUT_TYPE" + ), ) truncation: bool = Field( default=True, description="Whether to truncate inputs", - validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION", + validation_alias=AliasChoices( + "EMBEDDINGS_VOYAGEAI_TRUNCATION", "VOYAGEAI_TRUNCATION" + ), ) output_dtype: str | None = Field( default=None, description="Output data type", - validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE", + validation_alias=AliasChoices( + "EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE", "VOYAGEAI_OUTPUT_DTYPE" + ), ) output_dimension: int | None = Field( default=None, description="Output dimension", - validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION", + validation_alias=AliasChoices( + "EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION", "VOYAGEAI_OUTPUT_DIMENSION" + ), ) max_retries: int = Field( default=0, description="Maximum retries for API calls", - validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES", + validation_alias=AliasChoices( + "EMBEDDINGS_VOYAGEAI_MAX_RETRIES", "VOYAGEAI_MAX_RETRIES" + ), ) timeout: float | None = Field( default=None, description="Timeout for API calls", - validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT", + validation_alias=AliasChoices( + "EMBEDDINGS_VOYAGEAI_TIMEOUT", "VOYAGEAI_TIMEOUT" + ), ) diff --git a/lib/crewai/src/crewai/rag/embeddings/types.py b/lib/crewai/src/crewai/rag/embeddings/types.py index 1c8ea1ca0..794f4c6f9 100644 --- a/lib/crewai/src/crewai/rag/embeddings/types.py +++ b/lib/crewai/src/crewai/rag/embeddings/types.py @@ -29,7 +29,7 @@ from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec -ProviderSpec = ( +ProviderSpec: TypeAlias = ( AzureProviderSpec | BedrockProviderSpec | CohereProviderSpec diff --git a/lib/crewai/src/crewai/rag/qdrant/config.py b/lib/crewai/src/crewai/rag/qdrant/config.py index 0926c3385..1ff6ed159 100644 --- a/lib/crewai/src/crewai/rag/qdrant/config.py +++ b/lib/crewai/src/crewai/rag/qdrant/config.py @@ -1,16 +1,23 @@ """Qdrant configuration model.""" +from __future__ import annotations + from dataclasses import field -from typing import Literal, cast +from typing import TYPE_CHECKING, Any, Literal, cast from pydantic.dataclasses import dataclass as pyd_dataclass -from qdrant_client.models import VectorParams from crewai.rag.config.base import BaseRagConfig from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper +if TYPE_CHECKING: + from qdrant_client.models import VectorParams +else: + VectorParams = Any + + def _default_options() -> QdrantClientParams: """Create default Qdrant client options. @@ -26,7 +33,7 @@ def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper: Returns: Default embedding function using fastembed with all-MiniLM-L6-v2. """ - from fastembed import TextEmbedding # type: ignore[import-not-found] + from fastembed import TextEmbedding model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL) diff --git a/lib/crewai/tests/rag/embeddings/test_backward_compatibility.py b/lib/crewai/tests/rag/embeddings/test_backward_compatibility.py new file mode 100644 index 000000000..d10a75cde --- /dev/null +++ b/lib/crewai/tests/rag/embeddings/test_backward_compatibility.py @@ -0,0 +1,364 @@ +"""Tests for backward compatibility of embedding provider configurations.""" + +from crewai.rag.embeddings.factory import build_embedder, PROVIDER_PATHS +from crewai.rag.embeddings.providers.openai.openai_provider import OpenAIProvider +from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider +from crewai.rag.embeddings.providers.google.generative_ai import GenerativeAiProvider +from crewai.rag.embeddings.providers.google.vertex import VertexAIProvider +from crewai.rag.embeddings.providers.microsoft.azure import AzureProvider +from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider +from crewai.rag.embeddings.providers.ollama.ollama_provider import OllamaProvider +from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider +from crewai.rag.embeddings.providers.text2vec.text2vec_provider import Text2VecProvider +from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import ( + SentenceTransformerProvider, +) +from crewai.rag.embeddings.providers.instructor.instructor_provider import InstructorProvider +from crewai.rag.embeddings.providers.openclip.openclip_provider import OpenCLIPProvider + + +class TestGoogleProviderAlias: + """Test that 'google' provider name alias works for backward compatibility.""" + + def test_google_alias_in_provider_paths(self): + """Verify 'google' is registered as an alias for google-generativeai.""" + assert "google" in PROVIDER_PATHS + assert "google-generativeai" in PROVIDER_PATHS + assert PROVIDER_PATHS["google"] == PROVIDER_PATHS["google-generativeai"] + + +class TestModelKeyBackwardCompatibility: + """Test that 'model' config key works as alias for 'model_name'.""" + + def test_openai_provider_accepts_model_key(self): + """Test OpenAI provider accepts 'model' as alias for 'model_name'.""" + provider = OpenAIProvider( + api_key="test-key", + model="text-embedding-3-small", + ) + assert provider.model_name == "text-embedding-3-small" + + def test_openai_provider_model_name_takes_precedence(self): + """Test that model_name takes precedence when both are provided.""" + provider = OpenAIProvider( + api_key="test-key", + model_name="text-embedding-3-large", + ) + assert provider.model_name == "text-embedding-3-large" + + def test_cohere_provider_accepts_model_key(self): + """Test Cohere provider accepts 'model' as alias for 'model_name'.""" + provider = CohereProvider( + api_key="test-key", + model="embed-english-v3.0", + ) + assert provider.model_name == "embed-english-v3.0" + + def test_google_generativeai_provider_accepts_model_key(self): + """Test Google Generative AI provider accepts 'model' as alias.""" + provider = GenerativeAiProvider( + api_key="test-key", + model="gemini-embedding-001", + ) + assert provider.model_name == "gemini-embedding-001" + + def test_google_vertex_provider_accepts_model_key(self): + """Test Google Vertex AI provider accepts 'model' as alias.""" + provider = VertexAIProvider( + api_key="test-key", + model="text-embedding-004", + ) + assert provider.model_name == "text-embedding-004" + + def test_azure_provider_accepts_model_key(self): + """Test Azure provider accepts 'model' as alias for 'model_name'.""" + provider = AzureProvider( + api_key="test-key", + deployment_id="test-deployment", + model="text-embedding-ada-002", + ) + assert provider.model_name == "text-embedding-ada-002" + + def test_jina_provider_accepts_model_key(self): + """Test Jina provider accepts 'model' as alias for 'model_name'.""" + provider = JinaProvider( + api_key="test-key", + model="jina-embeddings-v3", + ) + assert provider.model_name == "jina-embeddings-v3" + + def test_ollama_provider_accepts_model_key(self): + """Test Ollama provider accepts 'model' as alias for 'model_name'.""" + provider = OllamaProvider( + model="nomic-embed-text", + ) + assert provider.model_name == "nomic-embed-text" + + def test_text2vec_provider_accepts_model_key(self): + """Test Text2Vec provider accepts 'model' as alias for 'model_name'.""" + provider = Text2VecProvider( + model="shibing624/text2vec-base-multilingual", + ) + assert provider.model_name == "shibing624/text2vec-base-multilingual" + + def test_sentence_transformer_provider_accepts_model_key(self): + """Test SentenceTransformer provider accepts 'model' as alias.""" + provider = SentenceTransformerProvider( + model="all-mpnet-base-v2", + ) + assert provider.model_name == "all-mpnet-base-v2" + + def test_instructor_provider_accepts_model_key(self): + """Test Instructor provider accepts 'model' as alias for 'model_name'.""" + provider = InstructorProvider( + model="hkunlp/instructor-xl", + ) + assert provider.model_name == "hkunlp/instructor-xl" + + def test_openclip_provider_accepts_model_key(self): + """Test OpenCLIP provider accepts 'model' as alias for 'model_name'.""" + provider = OpenCLIPProvider( + model="ViT-B-16", + ) + assert provider.model_name == "ViT-B-16" + + +class TestTaskTypeConfiguration: + """Test that task_type configuration works correctly.""" + + def test_google_provider_accepts_lowercase_task_type(self): + """Test Google provider accepts lowercase task_type.""" + provider = GenerativeAiProvider( + api_key="test-key", + task_type="retrieval_document", + ) + assert provider.task_type == "retrieval_document" + + def test_google_provider_accepts_uppercase_task_type(self): + """Test Google provider accepts uppercase task_type.""" + provider = GenerativeAiProvider( + api_key="test-key", + task_type="RETRIEVAL_QUERY", + ) + assert provider.task_type == "RETRIEVAL_QUERY" + + def test_google_provider_default_task_type(self): + """Test Google provider has correct default task_type.""" + provider = GenerativeAiProvider( + api_key="test-key", + ) + assert provider.task_type == "RETRIEVAL_DOCUMENT" + + +class TestFactoryBackwardCompatibility: + """Test factory function with backward compatible configurations.""" + + def test_factory_with_google_alias(self): + """Test factory resolves 'google' to google-generativeai provider.""" + config = { + "provider": "google", + "config": { + "api_key": "test-key", + "model": "gemini-embedding-001", + }, + } + + from unittest.mock import patch, MagicMock + + with patch("crewai.rag.embeddings.factory.import_and_validate_definition") as mock_import: + mock_provider_class = MagicMock() + mock_provider_instance = MagicMock() + mock_import.return_value = mock_provider_class + mock_provider_class.return_value = mock_provider_instance + + build_embedder(config) + + mock_import.assert_called_once_with( + "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider" + ) + + def test_factory_with_model_key_openai(self): + """Test factory passes 'model' config to OpenAI provider.""" + config = { + "provider": "openai", + "config": { + "api_key": "test-key", + "model": "text-embedding-3-small", + }, + } + + from unittest.mock import patch, MagicMock + + with patch("crewai.rag.embeddings.factory.import_and_validate_definition") as mock_import: + mock_provider_class = MagicMock() + mock_provider_instance = MagicMock() + mock_import.return_value = mock_provider_class + mock_provider_class.return_value = mock_provider_instance + + build_embedder(config) + + call_kwargs = mock_provider_class.call_args.kwargs + assert call_kwargs["model"] == "text-embedding-3-small" + + +class TestDocumentationCodeSnippets: + """Test code snippets from documentation work correctly.""" + + def test_memory_openai_config(self): + """Test OpenAI config from memory.mdx documentation.""" + provider = OpenAIProvider( + model_name="text-embedding-3-small", + ) + assert provider.model_name == "text-embedding-3-small" + + def test_memory_openai_config_with_options(self): + """Test OpenAI config with all options from memory.mdx.""" + provider = OpenAIProvider( + api_key="your-openai-api-key", + model_name="text-embedding-3-large", + dimensions=1536, + organization_id="your-org-id", + ) + assert provider.model_name == "text-embedding-3-large" + assert provider.dimensions == 1536 + + def test_memory_azure_config(self): + """Test Azure config from memory.mdx documentation.""" + provider = AzureProvider( + api_key="your-azure-key", + api_base="https://your-resource.openai.azure.com/", + api_type="azure", + api_version="2023-05-15", + model_name="text-embedding-3-small", + deployment_id="your-deployment-name", + ) + assert provider.model_name == "text-embedding-3-small" + assert provider.api_type == "azure" + + def test_memory_google_generativeai_config(self): + """Test Google Generative AI config from memory.mdx documentation.""" + provider = GenerativeAiProvider( + api_key="your-google-api-key", + model_name="gemini-embedding-001", + ) + assert provider.model_name == "gemini-embedding-001" + + def test_memory_cohere_config(self): + """Test Cohere config from memory.mdx documentation.""" + provider = CohereProvider( + api_key="your-cohere-api-key", + model_name="embed-english-v3.0", + ) + assert provider.model_name == "embed-english-v3.0" + + def test_knowledge_agent_embedder_config(self): + """Test agent embedder config from knowledge.mdx documentation.""" + provider = GenerativeAiProvider( + model_name="gemini-embedding-001", + api_key="your-google-key", + ) + assert provider.model_name == "gemini-embedding-001" + + def test_ragtool_openai_config(self): + """Test RagTool OpenAI config from ragtool.mdx documentation.""" + provider = OpenAIProvider( + model_name="text-embedding-3-small", + ) + assert provider.model_name == "text-embedding-3-small" + + def test_ragtool_cohere_config(self): + """Test RagTool Cohere config from ragtool.mdx documentation.""" + provider = CohereProvider( + api_key="your-api-key", + model_name="embed-english-v3.0", + ) + assert provider.model_name == "embed-english-v3.0" + + def test_ragtool_ollama_config(self): + """Test RagTool Ollama config from ragtool.mdx documentation.""" + provider = OllamaProvider( + model_name="llama2", + url="http://localhost:11434/api/embeddings", + ) + assert provider.model_name == "llama2" + + def test_ragtool_azure_config(self): + """Test RagTool Azure config from ragtool.mdx documentation.""" + provider = AzureProvider( + deployment_id="your-deployment-id", + api_key="your-api-key", + api_base="https://your-resource.openai.azure.com", + api_version="2024-02-01", + model_name="text-embedding-ada-002", + api_type="azure", + ) + assert provider.model_name == "text-embedding-ada-002" + assert provider.deployment_id == "your-deployment-id" + + def test_ragtool_google_generativeai_config(self): + """Test RagTool Google Generative AI config from ragtool.mdx.""" + provider = GenerativeAiProvider( + api_key="your-api-key", + model_name="gemini-embedding-001", + task_type="RETRIEVAL_DOCUMENT", + ) + assert provider.model_name == "gemini-embedding-001" + assert provider.task_type == "RETRIEVAL_DOCUMENT" + + def test_ragtool_jina_config(self): + """Test RagTool Jina config from ragtool.mdx documentation.""" + provider = JinaProvider( + api_key="your-api-key", + model_name="jina-embeddings-v3", + ) + assert provider.model_name == "jina-embeddings-v3" + + def test_ragtool_sentence_transformer_config(self): + """Test RagTool SentenceTransformer config from ragtool.mdx.""" + provider = SentenceTransformerProvider( + model_name="all-mpnet-base-v2", + device="cuda", + normalize_embeddings=True, + ) + assert provider.model_name == "all-mpnet-base-v2" + assert provider.device == "cuda" + assert provider.normalize_embeddings is True + + +class TestLegacyConfigurationFormats: + """Test legacy configuration formats that should still work.""" + + def test_legacy_google_with_model_key(self): + """Test legacy Google config using 'model' instead of 'model_name'.""" + provider = GenerativeAiProvider( + api_key="test-key", + model="text-embedding-005", + task_type="retrieval_document", + ) + assert provider.model_name == "text-embedding-005" + assert provider.task_type == "retrieval_document" + + def test_legacy_openai_with_model_key(self): + """Test legacy OpenAI config using 'model' instead of 'model_name'.""" + provider = OpenAIProvider( + api_key="test-key", + model="text-embedding-ada-002", + ) + assert provider.model_name == "text-embedding-ada-002" + + def test_legacy_cohere_with_model_key(self): + """Test legacy Cohere config using 'model' instead of 'model_name'.""" + provider = CohereProvider( + api_key="test-key", + model="embed-multilingual-v3.0", + ) + assert provider.model_name == "embed-multilingual-v3.0" + + def test_legacy_azure_with_model_key(self): + """Test legacy Azure config using 'model' instead of 'model_name'.""" + provider = AzureProvider( + api_key="test-key", + deployment_id="test-deployment", + model="text-embedding-3-large", + ) + assert provider.model_name == "text-embedding-3-large" \ No newline at end of file