Compare commits

..

1 Commit

Author SHA1 Message Date
Swifty
80dd8932c5 remove langfuse tracing 2026-01-23 13:21:20 +01:00
18 changed files with 520 additions and 568 deletions

View File

@@ -5,9 +5,9 @@ from asyncio import CancelledError
from collections.abc import AsyncGenerator
from typing import Any
import openai
import orjson
from langfuse import get_client, propagate_attributes
from langfuse.openai import openai # type: ignore
from langfuse import get_client
from openai import (
APIConnectionError,
APIError,
@@ -276,347 +276,301 @@ async def stream_chat_completion(
# Build system prompt with business understanding
system_prompt, understanding = await _build_system_prompt(user_id)
# Create Langfuse trace for this LLM call (each call gets its own trace, grouped by session_id)
# Using v3 SDK: start_observation creates a root span, update_trace sets trace-level attributes
input = message
if not message and tool_call_response:
input = tool_call_response
# Initialize variables for streaming
assistant_response = ChatMessage(
role="assistant",
content="",
)
accumulated_tool_calls: list[dict[str, Any]] = []
has_saved_assistant_message = False
has_appended_streaming_message = False
last_cache_time = 0.0
last_cache_content_len = 0
langfuse = get_client()
with langfuse.start_as_current_observation(
as_type="span",
name="user-copilot-request",
input=input,
) as span:
with propagate_attributes(
session_id=session_id,
user_id=user_id,
tags=["copilot"],
metadata={
"users_information": format_understanding_for_prompt(understanding)[
:200
] # langfuse only accepts upto to 200 chars
},
has_yielded_end = False
has_yielded_error = False
has_done_tool_call = False
has_received_text = False
text_streaming_ended = False
tool_response_messages: list[ChatMessage] = []
should_retry = False
# Generate unique IDs for AI SDK protocol
import uuid as uuid_module
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())
# Yield message start
yield StreamStart(messageId=message_id)
try:
async for chunk in _stream_chat_chunks(
session=session,
tools=tools,
system_prompt=system_prompt,
text_block_id=text_block_id,
):
# Initialize variables that will be used in finally block (must be defined before try)
assistant_response = ChatMessage(
role="assistant",
content="",
)
accumulated_tool_calls: list[dict[str, Any]] = []
has_saved_assistant_message = False
has_appended_streaming_message = False
last_cache_time = 0.0
last_cache_content_len = 0
# Wrap main logic in try/finally to ensure Langfuse observations are always ended
has_yielded_end = False
has_yielded_error = False
has_done_tool_call = False
has_received_text = False
text_streaming_ended = False
tool_response_messages: list[ChatMessage] = []
should_retry = False
# Generate unique IDs for AI SDK protocol
import uuid as uuid_module
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())
# Yield message start
yield StreamStart(messageId=message_id)
try:
async for chunk in _stream_chat_chunks(
session=session,
tools=tools,
system_prompt=system_prompt,
text_block_id=text_block_id,
if isinstance(chunk, StreamTextStart):
# Emit text-start before first text delta
if not has_received_text:
yield chunk
elif isinstance(chunk, StreamTextDelta):
delta = chunk.delta or ""
assert assistant_response.content is not None
assistant_response.content += delta
has_received_text = True
if not has_appended_streaming_message:
session.messages.append(assistant_response)
has_appended_streaming_message = True
current_time = time.monotonic()
content_len = len(assistant_response.content)
if (
current_time - last_cache_time >= 1.0
and content_len > last_cache_content_len
):
if isinstance(chunk, StreamTextStart):
# Emit text-start before first text delta
if not has_received_text:
yield chunk
elif isinstance(chunk, StreamTextDelta):
delta = chunk.delta or ""
assert assistant_response.content is not None
assistant_response.content += delta
has_received_text = True
if not has_appended_streaming_message:
session.messages.append(assistant_response)
has_appended_streaming_message = True
current_time = time.monotonic()
content_len = len(assistant_response.content)
if (
current_time - last_cache_time >= 1.0
and content_len > last_cache_content_len
):
try:
await cache_chat_session(session)
except Exception as e:
logger.warning(
f"Failed to cache partial session {session.session_id}: {e}"
)
last_cache_time = current_time
last_cache_content_len = content_len
yield chunk
elif isinstance(chunk, StreamTextEnd):
# Emit text-end after text completes
if has_received_text and not text_streaming_ended:
text_streaming_ended = True
if assistant_response.content:
logger.warn(
f"StreamTextEnd: Attempting to set output {assistant_response.content}"
)
span.update_trace(output=assistant_response.content)
span.update(output=assistant_response.content)
yield chunk
elif isinstance(chunk, StreamToolInputStart):
# Emit text-end before first tool call, but only if we've received text
if has_received_text and not text_streaming_ended:
yield StreamTextEnd(id=text_block_id)
text_streaming_ended = True
yield chunk
elif isinstance(chunk, StreamToolInputAvailable):
# Accumulate tool calls in OpenAI format
accumulated_tool_calls.append(
{
"id": chunk.toolCallId,
"type": "function",
"function": {
"name": chunk.toolName,
"arguments": orjson.dumps(chunk.input).decode(
"utf-8"
),
},
}
)
elif isinstance(chunk, StreamToolOutputAvailable):
result_content = (
chunk.output
if isinstance(chunk.output, str)
else orjson.dumps(chunk.output).decode("utf-8")
)
tool_response_messages.append(
ChatMessage(
role="tool",
content=result_content,
tool_call_id=chunk.toolCallId,
)
)
has_done_tool_call = True
# Track if any tool execution failed
if not chunk.success:
logger.warning(
f"Tool {chunk.toolName} (ID: {chunk.toolCallId}) execution failed"
)
yield chunk
elif isinstance(chunk, StreamFinish):
if not has_done_tool_call:
# Emit text-end before finish if we received text but haven't closed it
if has_received_text and not text_streaming_ended:
yield StreamTextEnd(id=text_block_id)
text_streaming_ended = True
# Save assistant message before yielding finish to ensure it's persisted
# even if client disconnects immediately after receiving StreamFinish
if not has_saved_assistant_message:
messages_to_save_early: list[ChatMessage] = []
if accumulated_tool_calls:
assistant_response.tool_calls = (
accumulated_tool_calls
)
if not has_appended_streaming_message and (
assistant_response.content
or assistant_response.tool_calls
):
messages_to_save_early.append(assistant_response)
messages_to_save_early.extend(tool_response_messages)
if messages_to_save_early:
session.messages.extend(messages_to_save_early)
logger.info(
f"Saving assistant message before StreamFinish: "
f"content_len={len(assistant_response.content or '')}, "
f"tool_calls={len(assistant_response.tool_calls or [])}, "
f"tool_responses={len(tool_response_messages)}"
)
if (
messages_to_save_early
or has_appended_streaming_message
):
await upsert_chat_session(session)
has_saved_assistant_message = True
has_yielded_end = True
yield chunk
elif isinstance(chunk, StreamError):
has_yielded_error = True
yield chunk
elif isinstance(chunk, StreamUsage):
session.usage.append(
Usage(
prompt_tokens=chunk.promptTokens,
completion_tokens=chunk.completionTokens,
total_tokens=chunk.totalTokens,
)
)
else:
logger.error(
f"Unknown chunk type: {type(chunk)}", exc_info=True
)
if assistant_response.content:
langfuse.update_current_trace(output=assistant_response.content)
langfuse.update_current_span(output=assistant_response.content)
elif tool_response_messages:
langfuse.update_current_trace(output=str(tool_response_messages))
langfuse.update_current_span(output=str(tool_response_messages))
except CancelledError:
if not has_saved_assistant_message:
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
if assistant_response.content:
assistant_response.content = (
f"{assistant_response.content}\n\n[interrupted]"
)
else:
assistant_response.content = "[interrupted]"
if not has_appended_streaming_message:
session.messages.append(assistant_response)
if tool_response_messages:
session.messages.extend(tool_response_messages)
try:
await upsert_chat_session(session)
await cache_chat_session(session)
except Exception as e:
logger.warning(
f"Failed to save interrupted session {session.session_id}: {e}"
f"Failed to cache partial session {session.session_id}: {e}"
)
raise
except Exception as e:
logger.error(f"Error during stream: {e!s}", exc_info=True)
# Check if this is a retryable error (JSON parsing, incomplete tool calls, etc.)
is_retryable = isinstance(
e, (orjson.JSONDecodeError, KeyError, TypeError)
)
if is_retryable and retry_count < config.max_retries:
logger.info(
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
)
should_retry = True
else:
# Non-retryable error or max retries exceeded
# Save any partial progress before reporting error
messages_to_save: list[ChatMessage] = []
# Add assistant message if it has content or tool calls
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_streaming_message and (
assistant_response.content or assistant_response.tool_calls
):
messages_to_save.append(assistant_response)
# Add tool response messages after assistant message
messages_to_save.extend(tool_response_messages)
if not has_saved_assistant_message:
if messages_to_save:
session.messages.extend(messages_to_save)
if messages_to_save or has_appended_streaming_message:
await upsert_chat_session(session)
if not has_yielded_error:
error_message = str(e)
if not is_retryable:
error_message = f"Non-retryable error: {error_message}"
elif retry_count >= config.max_retries:
error_message = f"Max retries ({config.max_retries}) exceeded: {error_message}"
error_response = StreamError(errorText=error_message)
yield error_response
if not has_yielded_end:
yield StreamFinish()
return
# Handle retry outside of exception handler to avoid nesting
if should_retry and retry_count < config.max_retries:
logger.info(
f"Retrying stream_chat_completion for session {session_id}, attempt {retry_count + 1}"
)
async for chunk in stream_chat_completion(
session_id=session.session_id,
user_id=user_id,
retry_count=retry_count + 1,
session=session,
context=context,
):
last_cache_time = current_time
last_cache_content_len = content_len
yield chunk
elif isinstance(chunk, StreamTextEnd):
# Emit text-end after text completes
if has_received_text and not text_streaming_ended:
text_streaming_ended = True
yield chunk
return # Exit after retry to avoid double-saving in finally block
elif isinstance(chunk, StreamToolInputStart):
# Emit text-end before first tool call, but only if we've received text
if has_received_text and not text_streaming_ended:
yield StreamTextEnd(id=text_block_id)
text_streaming_ended = True
yield chunk
elif isinstance(chunk, StreamToolInputAvailable):
# Accumulate tool calls in OpenAI format
accumulated_tool_calls.append(
{
"id": chunk.toolCallId,
"type": "function",
"function": {
"name": chunk.toolName,
"arguments": orjson.dumps(chunk.input).decode("utf-8"),
},
}
)
elif isinstance(chunk, StreamToolOutputAvailable):
result_content = (
chunk.output
if isinstance(chunk.output, str)
else orjson.dumps(chunk.output).decode("utf-8")
)
tool_response_messages.append(
ChatMessage(
role="tool",
content=result_content,
tool_call_id=chunk.toolCallId,
)
)
has_done_tool_call = True
# Track if any tool execution failed
if not chunk.success:
logger.warning(
f"Tool {chunk.toolName} (ID: {chunk.toolCallId}) execution failed"
)
yield chunk
elif isinstance(chunk, StreamFinish):
if not has_done_tool_call:
# Emit text-end before finish if we received text but haven't closed it
if has_received_text and not text_streaming_ended:
yield StreamTextEnd(id=text_block_id)
text_streaming_ended = True
# Save assistant message before yielding finish to ensure it's persisted
# even if client disconnects immediately after receiving StreamFinish
if not has_saved_assistant_message:
messages_to_save_early: list[ChatMessage] = []
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_streaming_message and (
assistant_response.content or assistant_response.tool_calls
):
messages_to_save_early.append(assistant_response)
messages_to_save_early.extend(tool_response_messages)
if messages_to_save_early:
session.messages.extend(messages_to_save_early)
logger.info(
f"Saving assistant message before StreamFinish: "
f"content_len={len(assistant_response.content or '')}, "
f"tool_calls={len(assistant_response.tool_calls or [])}, "
f"tool_responses={len(tool_response_messages)}"
)
if messages_to_save_early or has_appended_streaming_message:
await upsert_chat_session(session)
has_saved_assistant_message = True
has_yielded_end = True
yield chunk
elif isinstance(chunk, StreamError):
has_yielded_error = True
yield chunk
elif isinstance(chunk, StreamUsage):
session.usage.append(
Usage(
prompt_tokens=chunk.promptTokens,
completion_tokens=chunk.completionTokens,
total_tokens=chunk.totalTokens,
)
)
else:
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
except CancelledError:
if not has_saved_assistant_message:
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
if assistant_response.content:
assistant_response.content = (
f"{assistant_response.content}\n\n[interrupted]"
)
else:
assistant_response.content = "[interrupted]"
if not has_appended_streaming_message:
session.messages.append(assistant_response)
if tool_response_messages:
session.messages.extend(tool_response_messages)
try:
await upsert_chat_session(session)
except Exception as e:
logger.warning(
f"Failed to save interrupted session {session.session_id}: {e}"
)
raise
except Exception as e:
logger.error(f"Error during stream: {e!s}", exc_info=True)
# Check if this is a retryable error (JSON parsing, incomplete tool calls, etc.)
is_retryable = isinstance(e, (orjson.JSONDecodeError, KeyError, TypeError))
if is_retryable and retry_count < config.max_retries:
logger.info(
f"Retryable error encountered. Attempt {retry_count + 1}/{config.max_retries}"
)
should_retry = True
else:
# Non-retryable error or max retries exceeded
# Save any partial progress before reporting error
messages_to_save: list[ChatMessage] = []
# Add assistant message if it has content or tool calls
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_streaming_message and (
assistant_response.content or assistant_response.tool_calls
):
messages_to_save.append(assistant_response)
# Add tool response messages after assistant message
messages_to_save.extend(tool_response_messages)
# Normal completion path - save session and handle tool call continuation
# Only save if we haven't already saved when StreamFinish was received
if not has_saved_assistant_message:
logger.info(
f"Normal completion path: session={session.session_id}, "
f"current message_count={len(session.messages)}"
)
# Build the messages list in the correct order
messages_to_save: list[ChatMessage] = []
# Add assistant message with tool_calls if any
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
logger.info(
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
)
if not has_appended_streaming_message and (
assistant_response.content or assistant_response.tool_calls
):
messages_to_save.append(assistant_response)
logger.info(
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
)
# Add tool response messages after assistant message
messages_to_save.extend(tool_response_messages)
logger.info(
f"Saving {len(tool_response_messages)} tool response messages, "
f"total_to_save={len(messages_to_save)}"
)
if messages_to_save:
session.messages.extend(messages_to_save)
logger.info(
f"Extended session messages, new message_count={len(session.messages)}"
)
if messages_to_save or has_appended_streaming_message:
await upsert_chat_session(session)
else:
logger.info(
"Assistant message already saved when StreamFinish was received, "
"skipping duplicate save"
)
# If we did a tool call, stream the chat completion again to get the next response
if has_done_tool_call:
logger.info(
"Tool call executed, streaming chat completion again to get assistant response"
)
async for chunk in stream_chat_completion(
session_id=session.session_id,
user_id=user_id,
session=session, # Pass session object to avoid Redis refetch
context=context,
tool_call_response=str(tool_response_messages),
):
yield chunk
if not has_yielded_error:
error_message = str(e)
if not is_retryable:
error_message = f"Non-retryable error: {error_message}"
elif retry_count >= config.max_retries:
error_message = (
f"Max retries ({config.max_retries}) exceeded: {error_message}"
)
error_response = StreamError(errorText=error_message)
yield error_response
if not has_yielded_end:
yield StreamFinish()
return
# Handle retry outside of exception handler to avoid nesting
if should_retry and retry_count < config.max_retries:
logger.info(
f"Retrying stream_chat_completion for session {session_id}, attempt {retry_count + 1}"
)
async for chunk in stream_chat_completion(
session_id=session.session_id,
user_id=user_id,
retry_count=retry_count + 1,
session=session,
context=context,
):
yield chunk
return # Exit after retry to avoid double-saving in finally block
# Normal completion path - save session and handle tool call continuation
# Only save if we haven't already saved when StreamFinish was received
if not has_saved_assistant_message:
logger.info(
f"Normal completion path: session={session.session_id}, "
f"current message_count={len(session.messages)}"
)
# Build the messages list in the correct order
messages_to_save: list[ChatMessage] = []
# Add assistant message with tool_calls if any
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
logger.info(
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
)
if not has_appended_streaming_message and (
assistant_response.content or assistant_response.tool_calls
):
messages_to_save.append(assistant_response)
logger.info(
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
)
# Add tool response messages after assistant message
messages_to_save.extend(tool_response_messages)
logger.info(
f"Saving {len(tool_response_messages)} tool response messages, "
f"total_to_save={len(messages_to_save)}"
)
if messages_to_save:
session.messages.extend(messages_to_save)
logger.info(
f"Extended session messages, new message_count={len(session.messages)}"
)
if messages_to_save or has_appended_streaming_message:
await upsert_chat_session(session)
else:
logger.info(
"Assistant message already saved when StreamFinish was received, "
"skipping duplicate save"
)
# If we did a tool call, stream the chat completion again to get the next response
if has_done_tool_call:
logger.info(
"Tool call executed, streaming chat completion again to get assistant response"
)
async for chunk in stream_chat_completion(
session_id=session.session_id,
user_id=user_id,
session=session, # Pass session object to avoid Redis refetch
context=context,
tool_call_response=str(tool_response_messages),
):
yield chunk
# Retry configuration for OpenAI API calls

View File

@@ -3,8 +3,6 @@
import logging
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
BusinessUnderstandingInput,
@@ -61,7 +59,6 @@ and automations for the user's specific needs."""
"""Requires authentication to store user-specific data."""
return True
@observe(as_type="tool", name="add_understanding")
async def _execute(
self,
user_id: str | None,

View File

@@ -5,7 +5,6 @@ import re
from datetime import datetime, timedelta, timezone
from typing import Any
from langfuse import observe
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
@@ -329,7 +328,6 @@ class AgentOutputTool(BaseTool):
total_executions=len(available_executions) if available_executions else 1,
)
@observe(as_type="tool", name="view_agent_output")
async def _execute(
self,
user_id: str | None,

View File

@@ -3,8 +3,6 @@
import logging
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
@@ -80,7 +78,6 @@ class CreateAgentTool(BaseTool):
"required": ["description"],
}
@observe(as_type="tool", name="create_agent")
async def _execute(
self,
user_id: str | None,

View File

@@ -3,8 +3,6 @@
import logging
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
@@ -87,7 +85,6 @@ class EditAgentTool(BaseTool):
"required": ["agent_id", "changes"],
}
@observe(as_type="tool", name="edit_agent")
async def _execute(
self,
user_id: str | None,

View File

@@ -2,8 +2,6 @@
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
@@ -37,7 +35,6 @@ class FindAgentTool(BaseTool):
"required": ["query"],
}
@observe(as_type="tool", name="find_agent")
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:

View File

@@ -1,7 +1,6 @@
import logging
from typing import Any
from langfuse import observe
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
@@ -56,7 +55,6 @@ class FindBlockTool(BaseTool):
def requires_auth(self) -> bool:
return True
@observe(as_type="tool", name="find_block")
async def _execute(
self,
user_id: str | None,

View File

@@ -2,8 +2,6 @@
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
@@ -43,7 +41,6 @@ class FindLibraryAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@observe(as_type="tool", name="find_library_agent")
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:

View File

@@ -4,8 +4,6 @@ import logging
from pathlib import Path
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
@@ -73,7 +71,6 @@ class GetDocPageTool(BaseTool):
url_path = path.rsplit(".", 1)[0] if "." in path else path
return f"{DOCS_BASE_URL}/{url_path}"
@observe(as_type="tool", name="get_doc_page")
async def _execute(
self,
user_id: str | None,

View File

@@ -3,7 +3,6 @@
import logging
from typing import Any
from langfuse import observe
from pydantic import BaseModel, Field, field_validator
from backend.api.features.chat.config import ChatConfig
@@ -155,7 +154,6 @@ class RunAgentTool(BaseTool):
"""All operations require authentication."""
return True
@observe(as_type="tool", name="run_agent")
async def _execute(
self,
user_id: str | None,

View File

@@ -4,8 +4,6 @@ import logging
from collections import defaultdict
from typing import Any
from langfuse import observe
from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
@@ -130,7 +128,6 @@ class RunBlockTool(BaseTool):
return matched_credentials, missing_credentials
@observe(as_type="tool", name="run_block")
async def _execute(
self,
user_id: str | None,

View File

@@ -3,7 +3,6 @@
import logging
from typing import Any
from langfuse import observe
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
@@ -88,7 +87,6 @@ class SearchDocsTool(BaseTool):
url_path = path.rsplit(".", 1)[0] if "." in path else path
return f"{DOCS_BASE_URL}/{url_path}"
@observe(as_type="tool", name="search_docs")
async def _execute(
self,
user_id: str | None,

View File

@@ -1552,7 +1552,7 @@ async def review_store_submission(
# Generate embedding for approved listing (blocking - admin operation)
# Inside transaction: if embedding fails, entire transaction rolls back
await ensure_embedding(
embedding_success = await ensure_embedding(
version_id=store_listing_version_id,
name=store_listing_version.name,
description=store_listing_version.description,
@@ -1560,6 +1560,12 @@ async def review_store_submission(
categories=store_listing_version.categories or [],
tx=tx,
)
if not embedding_success:
raise ValueError(
f"Failed to generate embedding for listing {store_listing_version_id}. "
"This is likely due to OpenAI API being unavailable. "
"Please try again later or contact support if the issue persists."
)
await prisma.models.StoreListing.prisma(tx).update(
where={"id": store_listing_version.StoreListing.id},

View File

@@ -63,42 +63,49 @@ def build_searchable_text(
return " ".join(parts)
async def generate_embedding(text: str) -> list[float]:
async def generate_embedding(text: str) -> list[float] | None:
"""
Generate embedding for text using OpenAI API.
Raises exceptions on failure - caller should handle.
Returns None if embedding generation fails.
Fail-fast: no retries to maintain consistency with approval flow.
"""
client = get_openai_client()
if not client:
raise RuntimeError("openai_internal_api_key not set, cannot generate embedding")
try:
client = get_openai_client()
if not client:
logger.error("openai_internal_api_key not set, cannot generate embedding")
return None
# Truncate text to token limit using tiktoken
# Character-based truncation is insufficient because token ratios vary by content type
enc = encoding_for_model(EMBEDDING_MODEL)
tokens = enc.encode(text)
if len(tokens) > EMBEDDING_MAX_TOKENS:
tokens = tokens[:EMBEDDING_MAX_TOKENS]
truncated_text = enc.decode(tokens)
logger.info(
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
# Truncate text to token limit using tiktoken
# Character-based truncation is insufficient because token ratios vary by content type
enc = encoding_for_model(EMBEDDING_MODEL)
tokens = enc.encode(text)
if len(tokens) > EMBEDDING_MAX_TOKENS:
tokens = tokens[:EMBEDDING_MAX_TOKENS]
truncated_text = enc.decode(tokens)
logger.info(
f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
)
else:
truncated_text = text
start_time = time.time()
response = await client.embeddings.create(
model=EMBEDDING_MODEL,
input=truncated_text,
)
else:
truncated_text = text
latency_ms = (time.time() - start_time) * 1000
start_time = time.time()
response = await client.embeddings.create(
model=EMBEDDING_MODEL,
input=truncated_text,
)
latency_ms = (time.time() - start_time) * 1000
embedding = response.data[0].embedding
logger.info(
f"Generated embedding: {len(embedding)} dims, "
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
)
return embedding
embedding = response.data[0].embedding
logger.info(
f"Generated embedding: {len(embedding)} dims, "
f"{len(tokens)} tokens, {latency_ms:.0f}ms"
)
return embedding
except Exception as e:
logger.error(f"Failed to generate embedding: {e}")
return None
async def store_embedding(
@@ -137,45 +144,48 @@ async def store_content_embedding(
New function for unified content embedding storage.
Uses raw SQL since Prisma doesn't natively support pgvector.
Raises exceptions on failure - caller should handle.
"""
client = tx if tx else prisma.get_client()
try:
client = tx if tx else prisma.get_client()
# Convert embedding to PostgreSQL vector format
embedding_str = embedding_to_vector_string(embedding)
metadata_json = dumps(metadata or {})
# Convert embedding to PostgreSQL vector format
embedding_str = embedding_to_vector_string(embedding)
metadata_json = dumps(metadata or {})
# Upsert the embedding
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
# Use unqualified ::vector - pgvector is in search_path on all environments
await execute_raw_with_schema(
"""
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
# Upsert the embedding
# WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
# Use unqualified ::vector - pgvector is in search_path on all environments
await execute_raw_with_schema(
"""
INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
"id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
)
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
ON CONFLICT ("contentType", "contentId", "userId")
DO UPDATE SET
"embedding" = $4::vector,
"searchableText" = $5,
"metadata" = $6::jsonb,
"updatedAt" = NOW()
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
""",
content_type,
content_id,
user_id,
embedding_str,
searchable_text,
metadata_json,
client=client,
)
VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
ON CONFLICT ("contentType", "contentId", "userId")
DO UPDATE SET
"embedding" = $4::vector,
"searchableText" = $5,
"metadata" = $6::jsonb,
"updatedAt" = NOW()
WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
""",
content_type,
content_id,
user_id,
embedding_str,
searchable_text,
metadata_json,
client=client,
)
logger.info(f"Stored embedding for {content_type}:{content_id}")
return True
logger.info(f"Stored embedding for {content_type}:{content_id}")
return True
except Exception as e:
logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
return False
async def get_embedding(version_id: str) -> dict[str, Any] | None:
@@ -207,31 +217,34 @@ async def get_content_embedding(
New function for unified content embedding retrieval.
Returns dict with contentType, contentId, embedding, timestamps or None if not found.
Raises exceptions on failure - caller should handle.
"""
result = await query_raw_with_schema(
"""
SELECT
"contentType",
"contentId",
"userId",
"embedding"::text as "embedding",
"searchableText",
"metadata",
"createdAt",
"updatedAt"
FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
""",
content_type,
content_id,
user_id,
)
try:
result = await query_raw_with_schema(
"""
SELECT
"contentType",
"contentId",
"userId",
"embedding"::text as "embedding",
"searchableText",
"metadata",
"createdAt",
"updatedAt"
FROM {schema_prefix}"UnifiedContentEmbedding"
WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
""",
content_type,
content_id,
user_id,
)
if result and len(result) > 0:
return result[0]
return None
if result and len(result) > 0:
return result[0]
return None
except Exception as e:
logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
return None
async def ensure_embedding(
@@ -259,38 +272,46 @@ async def ensure_embedding(
tx: Optional transaction client
Returns:
True if embedding exists/was created
Raises exceptions on failure - caller should handle.
True if embedding exists/was created, False on failure
"""
# Check if embedding already exists
if not force:
existing = await get_embedding(version_id)
if existing and existing.get("embedding"):
logger.debug(f"Embedding for version {version_id} already exists")
return True
try:
# Check if embedding already exists
if not force:
existing = await get_embedding(version_id)
if existing and existing.get("embedding"):
logger.debug(f"Embedding for version {version_id} already exists")
return True
# Build searchable text for embedding
searchable_text = build_searchable_text(name, description, sub_heading, categories)
# Build searchable text for embedding
searchable_text = build_searchable_text(
name, description, sub_heading, categories
)
# Generate new embedding
embedding = await generate_embedding(searchable_text)
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(f"Could not generate embedding for version {version_id}")
return False
# Store the embedding with metadata using new function
metadata = {
"name": name,
"subHeading": sub_heading,
"categories": categories,
}
return await store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=version_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata,
user_id=None, # Store agents are public
tx=tx,
)
# Store the embedding with metadata using new function
metadata = {
"name": name,
"subHeading": sub_heading,
"categories": categories,
}
return await store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id=version_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata,
user_id=None, # Store agents are public
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
return False
async def delete_embedding(version_id: str) -> bool:
@@ -500,24 +521,6 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
success = sum(1 for result in results if result is True)
failed = len(results) - success
# Aggregate unique errors to avoid Sentry spam
if failed > 0:
# Group errors by type and message
error_summary: dict[str, int] = {}
for result in results:
if isinstance(result, Exception):
error_key = f"{type(result).__name__}: {str(result)}"
error_summary[error_key] = error_summary.get(error_key, 0) + 1
# Log aggregated error summary
error_details = ", ".join(
f"{error} ({count}x)" for error, count in error_summary.items()
)
logger.error(
f"{content_type.value}: {failed}/{len(results)} embeddings failed. "
f"Errors: {error_details}"
)
results_by_type[content_type.value] = {
"processed": len(missing_items),
"success": success,
@@ -554,12 +557,11 @@ async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
}
async def embed_query(query: str) -> list[float]:
async def embed_query(query: str) -> list[float] | None:
"""
Generate embedding for a search query.
Same as generate_embedding but with clearer intent.
Raises exceptions on failure - caller should handle.
"""
return await generate_embedding(query)
@@ -592,30 +594,40 @@ async def ensure_content_embedding(
tx: Optional transaction client
Returns:
True if embedding exists/was created
Raises exceptions on failure - caller should handle.
True if embedding exists/was created, False on failure
"""
# Check if embedding already exists
if not force:
existing = await get_content_embedding(content_type, content_id, user_id)
if existing and existing.get("embedding"):
logger.debug(f"Embedding for {content_type}:{content_id} already exists")
return True
try:
# Check if embedding already exists
if not force:
existing = await get_content_embedding(content_type, content_id, user_id)
if existing and existing.get("embedding"):
logger.debug(
f"Embedding for {content_type}:{content_id} already exists"
)
return True
# Generate new embedding
embedding = await generate_embedding(searchable_text)
# Generate new embedding
embedding = await generate_embedding(searchable_text)
if embedding is None:
logger.warning(
f"Could not generate embedding for {content_type}:{content_id}"
)
return False
# Store the embedding
return await store_content_embedding(
content_type=content_type,
content_id=content_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata or {},
user_id=user_id,
tx=tx,
)
# Store the embedding
return await store_content_embedding(
content_type=content_type,
content_id=content_id,
embedding=embedding,
searchable_text=searchable_text,
metadata=metadata or {},
user_id=user_id,
tx=tx,
)
except Exception as e:
logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
return False
async def cleanup_orphaned_embeddings() -> dict[str, Any]:
@@ -842,8 +854,9 @@ async def semantic_search(
limit = 100
# Generate query embedding
try:
query_embedding = await embed_query(query)
query_embedding = await embed_query(query)
if query_embedding is not None:
# Semantic search with embeddings
embedding_str = embedding_to_vector_string(query_embedding)
@@ -894,21 +907,24 @@ async def semantic_search(
"""
)
results = await query_raw_with_schema(sql, *params)
return [
{
"content_id": row["content_id"],
"content_type": row["content_type"],
"searchable_text": row["searchable_text"],
"metadata": row["metadata"],
"similarity": float(row["similarity"]),
}
for row in results
]
except Exception as e:
logger.warning(f"Semantic search failed, falling back to lexical search: {e}")
try:
results = await query_raw_with_schema(sql, *params)
return [
{
"content_id": row["content_id"],
"content_type": row["content_type"],
"searchable_text": row["searchable_text"],
"metadata": row["metadata"],
"similarity": float(row["similarity"]),
}
for row in results
]
except Exception as e:
logger.error(f"Semantic search failed: {e}")
# Fall through to lexical search below
# Fallback to lexical search if embeddings unavailable
logger.warning("Falling back to lexical search (embeddings unavailable)")
params_lexical: list[Any] = [limit]
user_filter = ""

View File

@@ -298,16 +298,17 @@ async def test_schema_handling_error_cases():
mock_client.execute_raw.side_effect = Exception("Database error")
mock_get_client.return_value = mock_client
# Should raise exception on error
with pytest.raises(Exception, match="Database error"):
await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1] * EMBEDDING_DIM,
searchable_text="test",
metadata=None,
user_id=None,
)
result = await embeddings.store_content_embedding(
content_type=ContentType.STORE_AGENT,
content_id="test-id",
embedding=[0.1] * EMBEDDING_DIM,
searchable_text="test",
metadata=None,
user_id=None,
)
# Should return False on error, not raise
assert result is False
if __name__ == "__main__":

View File

@@ -80,8 +80,9 @@ async def test_generate_embedding_no_api_key():
) as mock_get_client:
mock_get_client.return_value = None
with pytest.raises(RuntimeError, match="openai_internal_api_key not set"):
await embeddings.generate_embedding("test text")
result = await embeddings.generate_embedding("test text")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
@@ -96,8 +97,9 @@ async def test_generate_embedding_api_error():
) as mock_get_client:
mock_get_client.return_value = mock_client
with pytest.raises(Exception, match="API Error"):
await embeddings.generate_embedding("test text")
result = await embeddings.generate_embedding("test text")
assert result is None
@pytest.mark.asyncio(loop_scope="session")
@@ -171,10 +173,11 @@ async def test_store_embedding_database_error(mocker):
embedding = [0.1, 0.2, 0.3]
with pytest.raises(Exception, match="Database error"):
await embeddings.store_embedding(
version_id="test-version-id", embedding=embedding, tx=mock_client
)
result = await embeddings.store_embedding(
version_id="test-version-id", embedding=embedding, tx=mock_client
)
assert result is False
@pytest.mark.asyncio(loop_scope="session")
@@ -274,16 +277,17 @@ async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
"""Test ensure_embedding when generation fails."""
mock_get.return_value = None
mock_generate.side_effect = Exception("Generation failed")
mock_generate.return_value = None
with pytest.raises(Exception, match="Generation failed"):
await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
result = await embeddings.ensure_embedding(
version_id="test-id",
name="Test",
description="Test description",
sub_heading="Test heading",
categories=["test"],
)
assert result is False
@pytest.mark.asyncio(loop_scope="session")

View File

@@ -186,12 +186,13 @@ async def unified_hybrid_search(
offset = (page - 1) * page_size
# Generate query embedding with graceful degradation
try:
query_embedding = await embed_query(query)
except Exception as e:
# Generate query embedding
query_embedding = await embed_query(query)
# Graceful degradation if embedding unavailable
if query_embedding is None or not query_embedding:
logger.warning(
f"Failed to generate query embedding - falling back to lexical-only search: {e}. "
"Failed to generate query embedding - falling back to lexical-only search. "
"Check that openai_internal_api_key is configured and OpenAI API is accessible."
)
query_embedding = [0.0] * EMBEDDING_DIM
@@ -463,12 +464,13 @@ async def hybrid_search(
offset = (page - 1) * page_size
# Generate query embedding with graceful degradation
try:
query_embedding = await embed_query(query)
except Exception as e:
# Generate query embedding
query_embedding = await embed_query(query)
# Graceful degradation
if query_embedding is None or not query_embedding:
logger.warning(
f"Failed to generate query embedding - falling back to lexical-only search: {e}"
"Failed to generate query embedding - falling back to lexical-only search."
)
query_embedding = [0.0] * EMBEDDING_DIM
total_non_semantic = (

View File

@@ -172,8 +172,8 @@ async def test_hybrid_search_without_embeddings():
with patch(
"backend.api.features.store.hybrid_search.query_raw_with_schema"
) as mock_query:
# Simulate embedding failure by raising exception
mock_embed.side_effect = Exception("Embedding generation failed")
# Simulate embedding failure
mock_embed.return_value = None
mock_query.return_value = mock_results
# Should NOT raise - graceful degradation
@@ -613,9 +613,7 @@ async def test_unified_hybrid_search_graceful_degradation():
"backend.api.features.store.hybrid_search.embed_query"
) as mock_embed:
mock_query.return_value = mock_results
mock_embed.side_effect = Exception(
"Embedding generation failed"
) # Embedding failure
mock_embed.return_value = None # Embedding failure
# Should NOT raise - graceful degradation
results, total = await unified_hybrid_search(