mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): address PR review comments for vector search
- Make EmbeddingService API key validation lazy (doesn't break startup) - Make chat service OpenAI client creation lazy with @functools.cache - Add thread-safe double-checked locking for client initialization - Remove unnecessary catch/re-raise error handling in embeddings - Replace async lock singleton with simpler @functools.cache pattern - Update backfill script to use get_embedding_service() singleton - Fix test mocks to use MagicMock instead of AsyncMock for sync function
This commit is contained in:
@@ -4,8 +4,9 @@ Embedding service for generating text embeddings using OpenAI.
|
||||
Used for vector-based semantic search in the store.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import openai
|
||||
@@ -28,7 +29,12 @@ MAX_BATCH_SIZE = 100 # maximum texts per batch request
|
||||
|
||||
|
||||
class EmbeddingService:
|
||||
"""Service for generating text embeddings using OpenAI."""
|
||||
"""Service for generating text embeddings using OpenAI.
|
||||
|
||||
The service can be created without an API key - the key is validated
|
||||
only when the client property is first accessed. This allows the service
|
||||
to be instantiated at module load time without requiring configuration.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
settings = Settings()
|
||||
@@ -37,12 +43,25 @@ class EmbeddingService:
|
||||
or settings.secrets.openai_internal_api_key
|
||||
or settings.secrets.openai_api_key
|
||||
)
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"OpenAI API key not configured. "
|
||||
"Set OPENAI_API_KEY or OPENAI_INTERNAL_API_KEY environment variable."
|
||||
)
|
||||
self.client = openai.AsyncOpenAI(api_key=self.api_key)
|
||||
self._client: Optional[openai.AsyncOpenAI] = None
|
||||
self._client_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def client(self) -> openai.AsyncOpenAI:
|
||||
"""Lazily create the OpenAI client, raising if no API key is configured.
|
||||
|
||||
Uses double-checked locking for thread-safe lazy initialization.
|
||||
"""
|
||||
if self._client is None:
|
||||
with self._client_lock:
|
||||
if self._client is None:
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"OpenAI API key not configured. "
|
||||
"Set OPENAI_API_KEY or OPENAI_INTERNAL_API_KEY environment variable."
|
||||
)
|
||||
self._client = openai.AsyncOpenAI(api_key=self.api_key)
|
||||
return self._client
|
||||
|
||||
async def generate_embedding(self, text: str) -> list[float]:
|
||||
"""
|
||||
@@ -66,16 +85,12 @@ class EmbeddingService:
|
||||
f"Text exceeds maximum length of {MAX_TEXT_LENGTH} characters"
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=text,
|
||||
dimensions=EMBEDDING_DIMENSIONS,
|
||||
)
|
||||
return response.data[0].embedding
|
||||
except openai.APIError as e:
|
||||
logger.error(f"OpenAI API error generating embedding: {e}")
|
||||
raise
|
||||
response = await self.client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=text,
|
||||
dimensions=EMBEDDING_DIMENSIONS,
|
||||
)
|
||||
return response.data[0].embedding
|
||||
|
||||
async def generate_embeddings(self, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
@@ -104,18 +119,14 @@ class EmbeddingService:
|
||||
f"Text at index {i} exceeds maximum length of {MAX_TEXT_LENGTH} characters"
|
||||
)
|
||||
|
||||
try:
|
||||
response = await self.client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=texts,
|
||||
dimensions=EMBEDDING_DIMENSIONS,
|
||||
)
|
||||
# Sort by index to ensure correct ordering
|
||||
sorted_data = sorted(response.data, key=lambda x: x.index)
|
||||
return [item.embedding for item in sorted_data]
|
||||
except openai.APIError as e:
|
||||
logger.error(f"OpenAI API error generating embeddings: {e}")
|
||||
raise
|
||||
response = await self.client.embeddings.create(
|
||||
model=EMBEDDING_MODEL,
|
||||
input=texts,
|
||||
dimensions=EMBEDDING_DIMENSIONS,
|
||||
)
|
||||
# Sort by index to ensure correct ordering
|
||||
sorted_data = sorted(response.data, key=lambda x: x.index)
|
||||
return [item.embedding for item in sorted_data]
|
||||
|
||||
|
||||
def create_search_text(name: str, sub_heading: str, description: str) -> str:
|
||||
@@ -137,27 +148,17 @@ def create_search_text(name: str, sub_heading: str, description: str) -> str:
|
||||
return " ".join(filter(None, parts)).strip()
|
||||
|
||||
|
||||
# Singleton instance with lock for thread-safe initialization
|
||||
_embedding_service: Optional[EmbeddingService] = None
|
||||
_embedding_service_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_embedding_service() -> EmbeddingService:
|
||||
@functools.cache
|
||||
def get_embedding_service() -> EmbeddingService:
|
||||
"""
|
||||
Get or create the embedding service singleton.
|
||||
|
||||
Uses double-checked locking to prevent race conditions in concurrent
|
||||
async environments while avoiding lock overhead after initialization.
|
||||
Uses functools.cache for thread-safe lazy initialization.
|
||||
|
||||
Returns:
|
||||
The shared EmbeddingService instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If OpenAI API key is not configured.
|
||||
ValueError: If OpenAI API key is not configured (when generating embeddings).
|
||||
"""
|
||||
global _embedding_service
|
||||
if _embedding_service is None:
|
||||
async with _embedding_service_lock:
|
||||
if _embedding_service is None:
|
||||
_embedding_service = EmbeddingService()
|
||||
return _embedding_service
|
||||
return EmbeddingService()
|
||||
|
||||
@@ -76,19 +76,25 @@ class TestEmbeddingServiceValidation:
|
||||
@pytest.fixture
|
||||
def service(self, mock_settings):
|
||||
"""Create an EmbeddingService instance with mocked settings."""
|
||||
with patch("backend.integrations.embeddings.openai.AsyncOpenAI"):
|
||||
return EmbeddingService()
|
||||
service = EmbeddingService()
|
||||
# Inject a mock client directly to avoid lazy initialization errors
|
||||
service._client = MagicMock()
|
||||
return service
|
||||
|
||||
def test_init_requires_api_key(self):
|
||||
"""Test that initialization fails without an API key."""
|
||||
def test_client_access_requires_api_key(self):
|
||||
"""Test that accessing client fails without an API key."""
|
||||
with patch("backend.integrations.embeddings.Settings") as mock:
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.secrets.openai_internal_api_key = ""
|
||||
mock_instance.secrets.openai_api_key = ""
|
||||
mock.return_value = mock_instance
|
||||
|
||||
# Service creation should succeed
|
||||
service = EmbeddingService()
|
||||
|
||||
# But accessing client should fail
|
||||
with pytest.raises(ValueError, match="OpenAI API key not configured"):
|
||||
EmbeddingService()
|
||||
_ = service.client
|
||||
|
||||
def test_init_accepts_explicit_api_key(self):
|
||||
"""Test that explicit API key overrides settings."""
|
||||
@@ -167,12 +173,10 @@ class TestEmbeddingServiceAPI:
|
||||
mock_instance.secrets.openai_api_key = ""
|
||||
mock_settings.return_value = mock_instance
|
||||
|
||||
with patch(
|
||||
"backend.integrations.embeddings.openai.AsyncOpenAI"
|
||||
) as mock_openai:
|
||||
mock_openai.return_value = mock_openai_client
|
||||
service = EmbeddingService()
|
||||
return service, mock_openai_client
|
||||
service = EmbeddingService()
|
||||
# Directly inject mock client to bypass lazy initialization
|
||||
service._client = mock_openai_client
|
||||
return service, mock_openai_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_embedding_success(self, service_with_mock_client):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import functools
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
@@ -32,7 +33,12 @@ from backend.util.exceptions import NotFoundError
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
config = backend.server.v2.chat.config.ChatConfig()
|
||||
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_openai_client() -> AsyncOpenAI:
|
||||
"""Lazily create the OpenAI client singleton."""
|
||||
return AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
@@ -355,7 +361,7 @@ async def _stream_chat_chunks(
|
||||
logger.info("Creating OpenAI chat completion stream...")
|
||||
|
||||
# Create the stream with proper types
|
||||
stream = await client.chat.completions.create(
|
||||
stream = await get_openai_client().chat.completions.create(
|
||||
model=model,
|
||||
messages=session.to_openai_messages(),
|
||||
tools=tools,
|
||||
|
||||
@@ -16,7 +16,7 @@ import logging
|
||||
import sys
|
||||
|
||||
from backend.data.db import connect, disconnect, query_raw_with_schema
|
||||
from backend.integrations.embeddings import EmbeddingService, create_search_text
|
||||
from backend.integrations.embeddings import create_search_text, get_embedding_service
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -48,7 +48,7 @@ async def backfill_embeddings(
|
||||
await connect()
|
||||
|
||||
try:
|
||||
embedding_service = EmbeddingService()
|
||||
embedding_service = get_embedding_service()
|
||||
|
||||
# Get all versions without embeddings
|
||||
versions = await query_raw_with_schema(
|
||||
|
||||
@@ -67,7 +67,7 @@ async def get_store_agents(
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Generate embedding for search query
|
||||
embedding_service = await get_embedding_service()
|
||||
embedding_service = get_embedding_service()
|
||||
query_embedding = await embedding_service.generate_embedding(search_query)
|
||||
# Convert embedding to PostgreSQL array format
|
||||
embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
|
||||
@@ -291,7 +291,7 @@ async def _generate_and_store_embedding(
|
||||
description: The agent description.
|
||||
"""
|
||||
try:
|
||||
embedding_service = await get_embedding_service()
|
||||
embedding_service = get_embedding_service()
|
||||
search_text = create_search_text(name, sub_heading, description)
|
||||
|
||||
if not search_text:
|
||||
|
||||
@@ -423,7 +423,7 @@ async def test_get_store_agents_vector_search_mocked(mocker):
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.db.get_embedding_service",
|
||||
mocker.AsyncMock(return_value=mock_embedding_service),
|
||||
mocker.MagicMock(return_value=mock_embedding_service),
|
||||
)
|
||||
|
||||
# Mock query_raw_with_schema to return empty results
|
||||
@@ -456,7 +456,7 @@ async def test_get_store_agents_vector_search_with_results(mocker):
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.db.get_embedding_service",
|
||||
mocker.AsyncMock(return_value=mock_embedding_service),
|
||||
mocker.MagicMock(return_value=mock_embedding_service),
|
||||
)
|
||||
|
||||
# Mock query results
|
||||
@@ -508,7 +508,7 @@ async def test_get_store_agents_vector_search_with_filters(mocker):
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.db.get_embedding_service",
|
||||
mocker.AsyncMock(return_value=mock_embedding_service),
|
||||
mocker.MagicMock(return_value=mock_embedding_service),
|
||||
)
|
||||
|
||||
# Mock query_raw_with_schema
|
||||
@@ -554,7 +554,7 @@ async def test_generate_and_store_embedding_success(mocker):
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.db.get_embedding_service",
|
||||
mocker.AsyncMock(return_value=mock_embedding_service),
|
||||
mocker.MagicMock(return_value=mock_embedding_service),
|
||||
)
|
||||
|
||||
# Mock query_raw_with_schema
|
||||
@@ -592,7 +592,7 @@ async def test_generate_and_store_embedding_empty_text(mocker):
|
||||
mock_embedding_service.generate_embedding = mocker.AsyncMock()
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.db.get_embedding_service",
|
||||
mocker.AsyncMock(return_value=mock_embedding_service),
|
||||
mocker.MagicMock(return_value=mock_embedding_service),
|
||||
)
|
||||
|
||||
# Mock query_raw_with_schema
|
||||
@@ -626,7 +626,7 @@ async def test_generate_and_store_embedding_handles_error(mocker):
|
||||
)
|
||||
mocker.patch(
|
||||
"backend.server.v2.store.db.get_embedding_service",
|
||||
mocker.AsyncMock(return_value=mock_embedding_service),
|
||||
mocker.MagicMock(return_value=mock_embedding_service),
|
||||
)
|
||||
|
||||
# Call should not raise - errors are logged but not propagated
|
||||
|
||||
Reference in New Issue
Block a user