fix(backend/store): use tiktoken for embedding truncation and add user_id to delete

Critical:
- Replace character-based truncation (32k chars) with token-based (8,191 tokens)
- Fixes potential API failures when text has a high token-to-character ratio
  (see the sketch after this list)
- Use tiktoken.encoding_for_model() to match OpenAI's token counting
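
For context, a minimal standalone sketch of the same truncation approach. It assumes a tiktoken version recent enough to map text-embedding-3-small to an encoding; truncate_to_token_limit and the sample input are illustrative only, not part of this commit:

    import tiktoken

    EMBEDDING_MODEL = "text-embedding-3-small"
    EMBEDDING_MAX_TOKENS = 8191

    def truncate_to_token_limit(text: str) -> str:
        # Encode with the tokenizer OpenAI uses for this model, then cut at
        # a token boundary instead of a character offset.
        enc = tiktoken.encoding_for_model(EMBEDDING_MODEL)
        tokens = enc.encode(text)
        if len(tokens) <= EMBEDDING_MAX_TOKENS:
            return text
        return enc.decode(tokens[:EMBEDDING_MAX_TOKENS])

    # Token-dense input (e.g. CJK text) shows the failure mode: 32,000 chars
    # can hold far more than 8,191 tokens, so the old cutoff could still trip
    # the API limit.
    enc = tiktoken.encoding_for_model(EMBEDDING_MODEL)
    sample = ("データベース" * 6000)[:32000]  # exactly the old 32k-char cutoff
    print(len(sample), "chars ->", len(enc.encode(sample)), "tokens")  # typically well over 8,191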

Security:
- Add user_id parameter to delete_content_embedding()
- Prevents accidental deletion of other users' embeddings for LIBRARY_AGENT
  content (see the call-site sketch after this list)
- WHERE clause now filters by user_id for user-scoped content types
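
Call-site sketch for the new parameter (the import path and wrapper function are assumed for illustration; only delete_content_embedding and ContentType come from this change):

    from prisma.enums import ContentType

    # Module path assumed; the diff view does not show the filename.
    from backend.store.embeddings import delete_content_embedding

    async def cleanup_embeddings(version_id: str, agent_id: str, owner_id: str) -> None:
        # Public content has no owner: leave user_id as None, and the new
        # WHERE clause matches rows where "userId" IS NULL.
        await delete_content_embedding(ContentType.STORE_AGENT, version_id)

        # User-scoped content: passing the owner's ID guarantees the DELETE
        # cannot touch another user's LIBRARY_AGENT embedding that shares a
        # content_id.
        await delete_content_embedding(
            ContentType.LIBRARY_AGENT, agent_id, user_id=owner_id
        )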

Addresses CodeRabbit security and critical issues
Author: Zamil Majdy
Date: 2026-01-13 17:43:54 -06:00
parent c5c1d8d605
commit 1f3a9d0922


@@ -12,6 +12,7 @@ from typing import Any
 import prisma
 from prisma.enums import ContentType
+from tiktoken import encoding_for_model
 
 from backend.data.db import execute_raw_with_schema, query_raw_with_schema
 from backend.util.clients import get_openai_client
@@ -22,6 +23,8 @@ logger = logging.getLogger(__name__)
 # OpenAI embedding model configuration
 EMBEDDING_MODEL = "text-embedding-3-small"
+# OpenAI embedding token limit (8,191 with 1 token buffer for safety)
+EMBEDDING_MAX_TOKENS = 8191
 
 
 def build_searchable_text(
@@ -69,8 +72,18 @@ async def generate_embedding(text: str) -> list[float] | None:
         logger.error("openai_internal_api_key not set, cannot generate embedding")
         return None
 
-    # Truncate text to avoid token limits (~32k chars for safety)
-    truncated_text = text[:32000]
+    # Truncate text to token limit using tiktoken
+    # Character-based truncation is insufficient because token ratios vary by content type
+    enc = encoding_for_model(EMBEDDING_MODEL)
+    tokens = enc.encode(text)
+    if len(tokens) > EMBEDDING_MAX_TOKENS:
+        tokens = tokens[:EMBEDDING_MAX_TOKENS]
+        truncated_text = enc.decode(tokens)
+        logger.info(
+            f"Truncated text from {len(enc.encode(text))} to {len(tokens)} tokens"
+        )
+    else:
+        truncated_text = text
 
     start_time = time.time()
     response = await client.embeddings.create(
@@ -82,7 +95,7 @@ async def generate_embedding(text: str) -> list[float] | None:
     embedding = response.data[0].embedding
     logger.info(
         f"Generated embedding: {len(embedding)} dims, "
-        f"{len(truncated_text)} chars, {latency_ms:.0f}ms"
+        f"{len(tokens)} tokens, {latency_ms:.0f}ms"
     )
     return embedding
@@ -307,13 +320,25 @@ async def delete_embedding(version_id: str) -> bool:
     return await delete_content_embedding(ContentType.STORE_AGENT, version_id)
 
 
-async def delete_content_embedding(content_type: ContentType, content_id: str) -> bool:
+async def delete_content_embedding(
+    content_type: ContentType, content_id: str, user_id: str | None = None
+) -> bool:
     """
     Delete embedding for any content type.
 
     New function for unified content embedding deletion.
 
     Note: This is usually handled automatically by CASCADE delete,
     but provided for manual cleanup if needed.
+
+    Args:
+        content_type: The type of content (STORE_AGENT, LIBRARY_AGENT, etc.)
+        content_id: The unique identifier for the content
+        user_id: Optional user ID. For public content (STORE_AGENT, BLOCK), pass None.
+            For user-scoped content (LIBRARY_AGENT), pass the user's ID to avoid
+            deleting embeddings belonging to other users.
+
+    Returns:
+        True if deletion succeeded, False otherwise
     """
     try:
         client = prisma.get_client()
@@ -321,14 +346,18 @@ async def delete_content_embedding(content_type: ContentType, content_id: str) -> bool:
         await execute_raw_with_schema(
             """
             DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
-            WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2
+            WHERE "contentType" = $1::{schema_prefix}"ContentType"
+            AND "contentId" = $2
+            AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
             """,
             content_type,
             content_id,
+            user_id,
             client=client,
         )
 
-        logger.info(f"Deleted embedding for {content_type}:{content_id}")
+        user_str = f" (user: {user_id})" if user_id else ""
+        logger.info(f"Deleted embedding for {content_type}:{content_id}{user_str}")
         return True
     except Exception as e:
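
Note on the NULL handling in the new WHERE clause: a plain "userId" = $3 would never match when $3 is NULL, because NULL = NULL does not evaluate to true in SQL; the explicit ($3 IS NULL AND "userId" IS NULL) arm is what lets public-content deletes (STORE_AGENT, BLOCK) still match their ownerless rows.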