fix(backend): address PR review comments for vector search

- Make EmbeddingService API key validation lazy (doesn't break startup)
- Make chat service OpenAI client creation lazy with @functools.cache
- Add thread-safe double-checked locking for client initialization
- Remove unnecessary catch/re-raise error handling in embeddings
- Replace async lock singleton with simpler @functools.cache pattern
- Update backfill script to use get_embedding_service() singleton
- Fix test mocks to use MagicMock instead of AsyncMock for sync function
This commit is contained in:
Swifty
2025-12-09 08:54:11 +01:00
parent 521dbdc25f
commit 69d0c05017
6 changed files with 79 additions and 68 deletions

View File

@@ -4,8 +4,9 @@ Embedding service for generating text embeddings using OpenAI.
Used for vector-based semantic search in the store.
"""
import asyncio
import functools
import logging
import threading
from typing import Optional
import openai
@@ -28,7 +29,12 @@ MAX_BATCH_SIZE = 100 # maximum texts per batch request
class EmbeddingService:
"""Service for generating text embeddings using OpenAI."""
"""Service for generating text embeddings using OpenAI.
The service can be created without an API key - the key is validated
only when the client property is first accessed. This allows the service
to be instantiated at module load time without requiring configuration.
"""
def __init__(self, api_key: Optional[str] = None):
settings = Settings()
@@ -37,12 +43,25 @@ class EmbeddingService:
or settings.secrets.openai_internal_api_key
or settings.secrets.openai_api_key
)
if not self.api_key:
raise ValueError(
"OpenAI API key not configured. "
"Set OPENAI_API_KEY or OPENAI_INTERNAL_API_KEY environment variable."
)
self.client = openai.AsyncOpenAI(api_key=self.api_key)
self._client: Optional[openai.AsyncOpenAI] = None
self._client_lock = threading.Lock()
@property
def client(self) -> openai.AsyncOpenAI:
    """Return the shared AsyncOpenAI client, constructing it on first access.

    The unlocked fast path returns the cached client; only the first
    caller(s) take the lock, and the key check happens inside the lock so
    at most one client is ever constructed (double-checked locking).

    Raises:
        ValueError: If no OpenAI API key is configured.
    """
    cached = self._client
    if cached is not None:
        return cached
    with self._client_lock:
        # Re-check under the lock: another thread may have won the race
        # between our unlocked check and acquiring the lock.
        if self._client is None:
            if not self.api_key:
                raise ValueError(
                    "OpenAI API key not configured. "
                    "Set OPENAI_API_KEY or OPENAI_INTERNAL_API_KEY environment variable."
                )
            self._client = openai.AsyncOpenAI(api_key=self.api_key)
        return self._client
async def generate_embedding(self, text: str) -> list[float]:
"""
@@ -66,16 +85,12 @@ class EmbeddingService:
f"Text exceeds maximum length of {MAX_TEXT_LENGTH} characters"
)
try:
response = await self.client.embeddings.create(
model=EMBEDDING_MODEL,
input=text,
dimensions=EMBEDDING_DIMENSIONS,
)
return response.data[0].embedding
except openai.APIError as e:
logger.error(f"OpenAI API error generating embedding: {e}")
raise
response = await self.client.embeddings.create(
model=EMBEDDING_MODEL,
input=text,
dimensions=EMBEDDING_DIMENSIONS,
)
return response.data[0].embedding
async def generate_embeddings(self, texts: list[str]) -> list[list[float]]:
"""
@@ -104,18 +119,14 @@ class EmbeddingService:
f"Text at index {i} exceeds maximum length of {MAX_TEXT_LENGTH} characters"
)
try:
response = await self.client.embeddings.create(
model=EMBEDDING_MODEL,
input=texts,
dimensions=EMBEDDING_DIMENSIONS,
)
# Sort by index to ensure correct ordering
sorted_data = sorted(response.data, key=lambda x: x.index)
return [item.embedding for item in sorted_data]
except openai.APIError as e:
logger.error(f"OpenAI API error generating embeddings: {e}")
raise
response = await self.client.embeddings.create(
model=EMBEDDING_MODEL,
input=texts,
dimensions=EMBEDDING_DIMENSIONS,
)
# Sort by index to ensure correct ordering
sorted_data = sorted(response.data, key=lambda x: x.index)
return [item.embedding for item in sorted_data]
def create_search_text(name: str, sub_heading: str, description: str) -> str:
@@ -137,27 +148,17 @@ def create_search_text(name: str, sub_heading: str, description: str) -> str:
return " ".join(filter(None, parts)).strip()
# Singleton instance with lock for thread-safe initialization
_embedding_service: Optional[EmbeddingService] = None
_embedding_service_lock: asyncio.Lock = asyncio.Lock()
async def get_embedding_service() -> EmbeddingService:
@functools.cache
def get_embedding_service() -> EmbeddingService:
"""
Get or create the embedding service singleton.
Uses double-checked locking to prevent race conditions in concurrent
async environments while avoiding lock overhead after initialization.
Uses functools.cache for thread-safe lazy initialization.
Returns:
The shared EmbeddingService instance.
Raises:
ValueError: If OpenAI API key is not configured.
ValueError: If OpenAI API key is not configured (when generating embeddings).
"""
global _embedding_service
if _embedding_service is None:
async with _embedding_service_lock:
if _embedding_service is None:
_embedding_service = EmbeddingService()
return _embedding_service
return EmbeddingService()

View File

@@ -76,19 +76,25 @@ class TestEmbeddingServiceValidation:
@pytest.fixture
def service(self, mock_settings):
"""Create an EmbeddingService instance with mocked settings."""
with patch("backend.integrations.embeddings.openai.AsyncOpenAI"):
return EmbeddingService()
service = EmbeddingService()
# Inject a mock client directly to avoid lazy initialization errors
service._client = MagicMock()
return service
def test_init_requires_api_key(self):
"""Test that initialization fails without an API key."""
def test_client_access_requires_api_key(self):
"""Test that accessing client fails without an API key."""
with patch("backend.integrations.embeddings.Settings") as mock:
mock_instance = MagicMock()
mock_instance.secrets.openai_internal_api_key = ""
mock_instance.secrets.openai_api_key = ""
mock.return_value = mock_instance
# Service creation should succeed
service = EmbeddingService()
# But accessing client should fail
with pytest.raises(ValueError, match="OpenAI API key not configured"):
EmbeddingService()
_ = service.client
def test_init_accepts_explicit_api_key(self):
"""Test that explicit API key overrides settings."""
@@ -167,12 +173,10 @@ class TestEmbeddingServiceAPI:
mock_instance.secrets.openai_api_key = ""
mock_settings.return_value = mock_instance
with patch(
"backend.integrations.embeddings.openai.AsyncOpenAI"
) as mock_openai:
mock_openai.return_value = mock_openai_client
service = EmbeddingService()
return service, mock_openai_client
service = EmbeddingService()
# Directly inject mock client to bypass lazy initialization
service._client = mock_openai_client
return service, mock_openai_client
@pytest.mark.asyncio
async def test_generate_embedding_success(self, service_with_mock_client):

View File

@@ -1,3 +1,4 @@
import functools
import logging
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
@@ -32,7 +33,12 @@ from backend.util.exceptions import NotFoundError
logger = logging.getLogger(__name__)
config = backend.server.v2.chat.config.ChatConfig()
client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
@functools.cache
def get_openai_client() -> AsyncOpenAI:
    """Return the process-wide AsyncOpenAI client.

    ``functools.cache`` guarantees the client is built exactly once, on
    first call, from the module-level chat config's key and base URL.
    """
    openai_client = AsyncOpenAI(
        api_key=config.api_key,
        base_url=config.base_url,
    )
    return openai_client
async def create_chat_session(
@@ -355,7 +361,7 @@ async def _stream_chat_chunks(
logger.info("Creating OpenAI chat completion stream...")
# Create the stream with proper types
stream = await client.chat.completions.create(
stream = await get_openai_client().chat.completions.create(
model=model,
messages=session.to_openai_messages(),
tools=tools,

View File

@@ -16,7 +16,7 @@ import logging
import sys
from backend.data.db import connect, disconnect, query_raw_with_schema
from backend.integrations.embeddings import EmbeddingService, create_search_text
from backend.integrations.embeddings import create_search_text, get_embedding_service
logging.basicConfig(
level=logging.INFO,
@@ -48,7 +48,7 @@ async def backfill_embeddings(
await connect()
try:
embedding_service = EmbeddingService()
embedding_service = get_embedding_service()
# Get all versions without embeddings
versions = await query_raw_with_schema(

View File

@@ -67,7 +67,7 @@ async def get_store_agents(
offset = (page - 1) * page_size
# Generate embedding for search query
embedding_service = await get_embedding_service()
embedding_service = get_embedding_service()
query_embedding = await embedding_service.generate_embedding(search_query)
# Convert embedding to PostgreSQL array format
embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"
@@ -291,7 +291,7 @@ async def _generate_and_store_embedding(
description: The agent description.
"""
try:
embedding_service = await get_embedding_service()
embedding_service = get_embedding_service()
search_text = create_search_text(name, sub_heading, description)
if not search_text:

View File

@@ -423,7 +423,7 @@ async def test_get_store_agents_vector_search_mocked(mocker):
)
mocker.patch(
"backend.server.v2.store.db.get_embedding_service",
mocker.AsyncMock(return_value=mock_embedding_service),
mocker.MagicMock(return_value=mock_embedding_service),
)
# Mock query_raw_with_schema to return empty results
@@ -456,7 +456,7 @@ async def test_get_store_agents_vector_search_with_results(mocker):
)
mocker.patch(
"backend.server.v2.store.db.get_embedding_service",
mocker.AsyncMock(return_value=mock_embedding_service),
mocker.MagicMock(return_value=mock_embedding_service),
)
# Mock query results
@@ -508,7 +508,7 @@ async def test_get_store_agents_vector_search_with_filters(mocker):
)
mocker.patch(
"backend.server.v2.store.db.get_embedding_service",
mocker.AsyncMock(return_value=mock_embedding_service),
mocker.MagicMock(return_value=mock_embedding_service),
)
# Mock query_raw_with_schema
@@ -554,7 +554,7 @@ async def test_generate_and_store_embedding_success(mocker):
)
mocker.patch(
"backend.server.v2.store.db.get_embedding_service",
mocker.AsyncMock(return_value=mock_embedding_service),
mocker.MagicMock(return_value=mock_embedding_service),
)
# Mock query_raw_with_schema
@@ -592,7 +592,7 @@ async def test_generate_and_store_embedding_empty_text(mocker):
mock_embedding_service.generate_embedding = mocker.AsyncMock()
mocker.patch(
"backend.server.v2.store.db.get_embedding_service",
mocker.AsyncMock(return_value=mock_embedding_service),
mocker.MagicMock(return_value=mock_embedding_service),
)
# Mock query_raw_with_schema
@@ -626,7 +626,7 @@ async def test_generate_and_store_embedding_handles_error(mocker):
)
mocker.patch(
"backend.server.v2.store.db.get_embedding_service",
mocker.AsyncMock(return_value=mock_embedding_service),
mocker.MagicMock(return_value=mock_embedding_service),
)
# Call should not raise - errors are logged but not propagated