refactor(backend): simplify EmbeddingService client property

Address PR review: replace manual double-checked locking with
@functools.cached_property for cleaner, simpler lazy initialization.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Swifty
2025-12-18 19:16:11 +01:00
parent c7063a46a6
commit 8268d919f5
2 changed files with 15 additions and 22 deletions

View File

@@ -6,7 +6,6 @@ Used for vector-based semantic search in the store.
import functools
import logging
import threading
from typing import Optional
import openai
@@ -43,25 +42,16 @@ class EmbeddingService:
or settings.secrets.openai_internal_api_key
or settings.secrets.openai_api_key
)
self._client: Optional[openai.AsyncOpenAI] = None
self._client_lock = threading.Lock()
@property
@functools.cached_property
def client(self) -> openai.AsyncOpenAI:
"""Lazily create the OpenAI client, raising if no API key is configured.
Uses double-checked locking for thread-safe lazy initialization.
"""
if self._client is None:
with self._client_lock:
if self._client is None:
if not self.api_key:
raise ValueError(
"OpenAI API key not configured. "
"Set OPENAI_API_KEY or OPENAI_INTERNAL_API_KEY environment variable."
)
self._client = openai.AsyncOpenAI(api_key=self.api_key)
return self._client
"""Lazily create the OpenAI client, raising if no API key is configured."""
if not self.api_key:
raise ValueError(
"OpenAI API key not configured. "
"Set OPENAI_API_KEY or OPENAI_INTERNAL_API_KEY environment variable."
)
return openai.AsyncOpenAI(api_key=self.api_key)
async def generate_embedding(self, text: str) -> list[float]:
"""
@@ -90,6 +80,8 @@ class EmbeddingService:
input=text,
dimensions=EMBEDDING_DIMENSIONS,
)
if not response.data:
raise ValueError("OpenAI API returned empty embedding data")
return response.data[0].embedding
async def generate_embeddings(self, texts: list[str]) -> list[list[float]]:
@@ -145,6 +137,7 @@ def create_search_text(name: str, sub_heading: str, description: str) -> str:
A single string combining all non-empty fields.
"""
parts = [name or "", sub_heading or "", description or ""]
# filter(None, parts) removes empty strings since empty string is falsy
return " ".join(filter(None, parts)).strip()

View File

@@ -77,8 +77,8 @@ class TestEmbeddingServiceValidation:
def service(self, mock_settings):
"""Create an EmbeddingService instance with mocked settings."""
service = EmbeddingService()
# Inject a mock client directly to avoid lazy initialization errors
service._client = MagicMock()
# Inject a mock client by setting the cached_property directly
service.__dict__["client"] = MagicMock()
return service
def test_client_access_requires_api_key(self):
@@ -174,8 +174,8 @@ class TestEmbeddingServiceAPI:
mock_settings.return_value = mock_instance
service = EmbeddingService()
# Directly inject mock client to bypass lazy initialization
service._client = mock_openai_client
# Inject mock client by setting the cached_property directly
service.__dict__["client"] = mock_openai_client
return service, mock_openai_client
@pytest.mark.asyncio