mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-14 17:47:57 -05:00
Compare commits
75 Commits
figure-out
...
feat/backf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
06b07604b4 | ||
|
|
9f0c8c06c5 | ||
|
|
3ba374286c | ||
|
|
f4da46cb57 | ||
|
|
10e385612e | ||
|
|
0db134fdd9 | ||
|
|
461bf25bc1 | ||
|
|
f45ef091e2 | ||
|
|
61efee4139 | ||
|
|
83f46d373d | ||
|
|
07153d5536 | ||
|
|
f3c747027b | ||
|
|
764e1026e5 | ||
|
|
0890ce00b5 | ||
|
|
7f952900ae | ||
|
|
dc5da41703 | ||
|
|
1f3a9d0922 | ||
|
|
c5c1d8d605 | ||
|
|
9ae54e2975 | ||
|
|
8063bb4503 | ||
|
|
2b28023266 | ||
|
|
1b8d8e3772 | ||
|
|
34eb6bdca1 | ||
|
|
44610bb778 | ||
|
|
9afa8a739b | ||
|
|
a76fa0f0a9 | ||
|
|
b0b556e24e | ||
|
|
60ba50431d | ||
|
|
4b8332a14f | ||
|
|
7097cedc1d | ||
|
|
5a60618c2d | ||
|
|
547c6f93d4 | ||
|
|
6dbd45eaf0 | ||
|
|
ca398f3cc5 | ||
|
|
16a14ca09e | ||
|
|
704b8a9207 | ||
|
|
1a5abcc36a | ||
|
|
419b966db1 | ||
|
|
9b8d917d99 | ||
|
|
6432d35db2 | ||
|
|
7d46a5c1dc | ||
|
|
a63370bc30 | ||
|
|
6a86f2e3ea | ||
|
|
679c7806f2 | ||
|
|
5c7391fcd7 | ||
|
|
faf9ad9b57 | ||
|
|
f5899acac0 | ||
|
|
e539280e98 | ||
|
|
72783dcc02 | ||
|
|
af13badf8f | ||
|
|
b491610ebf | ||
|
|
0b022073eb | ||
|
|
01eef83809 | ||
|
|
4644c09b9e | ||
|
|
374860ff2c | ||
|
|
e7e09ef4e1 | ||
|
|
5e691661a8 | ||
|
|
b0e8c17419 | ||
|
|
5a7c1e39dd | ||
|
|
53b03e746a | ||
|
|
db8b43bb3d | ||
|
|
923d8baedc | ||
|
|
a55b2e02dc | ||
|
|
6b6648b290 | ||
|
|
c0a9c0410b | ||
|
|
17a77b02c7 | ||
|
|
701fce83ca | ||
|
|
78d89d0faf | ||
|
|
f482eb668b | ||
|
|
4a52b7eca0 | ||
|
|
5aaf07fbaf | ||
|
|
0d2996e501 | ||
|
|
9e37a66bca | ||
|
|
429a074848 | ||
|
|
7f1245dc42 |
2
.github/workflows/platform-backend-ci.yml
vendored
2
.github/workflows/platform-backend-ci.yml
vendored
@@ -176,7 +176,7 @@ jobs:
|
||||
}
|
||||
|
||||
- name: Run Database Migrations
|
||||
run: poetry run prisma migrate dev --name updates
|
||||
run: poetry run prisma migrate deploy
|
||||
env:
|
||||
DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
DIRECT_URL: ${{ steps.supabase.outputs.DB_URL }}
|
||||
|
||||
1
autogpt_platform/backend/.gitignore
vendored
1
autogpt_platform/backend/.gitignore
vendored
@@ -18,3 +18,4 @@ load-tests/results/
|
||||
load-tests/*.json
|
||||
load-tests/*.log
|
||||
load-tests/node_modules/*
|
||||
migrations/*/rollback*.sql
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import orjson
|
||||
import pytest
|
||||
@@ -17,6 +18,17 @@ setup_test_data = setup_test_data
|
||||
setup_firecrawl_test_data = setup_firecrawl_test_data
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def mock_embedding_functions():
    """Replace the store's ``ensure_embedding`` with an async stub for the session.

    The stub is an ``AsyncMock`` that resolves to ``True``; it is installed
    before the first test runs and removed when the session fixture is torn
    down.
    """
    embedding_patch = patch(
        "backend.api.features.store.db.ensure_embedding",
        new_callable=AsyncMock,
        return_value=True,
    )
    embedding_patch.start()
    try:
        yield
    finally:
        # Always undo the patch, even if teardown is triggered by an error.
        embedding_patch.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_run_agent(setup_test_data):
|
||||
"""Test that the run_agent tool successfully executes an approved agent"""
|
||||
|
||||
@@ -0,0 +1,417 @@
|
||||
"""
|
||||
Content Type Handlers for Unified Embeddings
|
||||
|
||||
Pluggable system for different content sources (store agents, blocks, docs).
|
||||
Each handler knows how to fetch and process its content type for embedding.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ContentItem:
    """One unit of content that is ready to be embedded.

    Attributes:
        content_id: Unique identifier (DB ID or file path).
        content_type: The ContentType this item belongs to.
        searchable_text: Combined text that is fed to the embedding step.
        metadata: Content-specific metadata.
        user_id: Owner for user-scoped content; ``None`` when not user-scoped.
    """

    content_id: str
    content_type: ContentType
    searchable_text: str
    metadata: dict[str, Any]
    user_id: str | None = None
|
||||
|
||||
|
||||
class ContentHandler(ABC):
    """Abstract interface every embeddable content source must implement."""

    @property
    @abstractmethod
    def content_type(self) -> ContentType:
        """Return the ContentType this handler is responsible for."""
        ...

    @abstractmethod
    async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
        """Collect items that do not have an embedding yet.

        Args:
            batch_size: Upper bound on the number of items returned.

        Returns:
            ContentItem objects ready to be embedded.
        """
        ...

    @abstractmethod
    async def get_stats(self) -> dict[str, int]:
        """Report embedding coverage for this content type.

        Returns:
            A dict with the keys ``total``, ``with_embeddings`` and
            ``without_embeddings``.
        """
        ...
|
||||
|
||||
|
||||
class StoreAgentHandler(ContentHandler):
    """Content handler for marketplace store agent listings."""

    @property
    def content_type(self) -> ContentType:
        """This handler serves ``ContentType.STORE_AGENT``."""
        return ContentType.STORE_AGENT

    async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
        """Return approved, non-deleted listing versions without an embedding row.

        Args:
            batch_size: Value bound to the query's ``LIMIT`` clause.

        Returns:
            One ContentItem per row returned by the query.
        """
        # Imported lazily, as in the rest of this module's handlers.
        from backend.api.features.store.embeddings import build_searchable_text

        rows = await query_raw_with_schema(
            """
            SELECT
                slv.id,
                slv.name,
                slv.description,
                slv."subHeading",
                slv.categories
            FROM {schema_prefix}"StoreListingVersion" slv
            LEFT JOIN {schema_prefix}"UnifiedContentEmbedding" uce
                ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
            WHERE slv."submissionStatus" = 'APPROVED'
            AND slv."isDeleted" = false
            AND uce."contentId" IS NULL
            LIMIT $1
            """,
            batch_size,
        )

        items: list[ContentItem] = []
        for row in rows:
            text = build_searchable_text(
                name=row["name"],
                description=row["description"],
                sub_heading=row["subHeading"],
                categories=row["categories"] or [],
            )
            items.append(
                ContentItem(
                    content_id=row["id"],
                    content_type=ContentType.STORE_AGENT,
                    searchable_text=text,
                    metadata={
                        "name": row["name"],
                        "categories": row["categories"] or [],
                    },
                    user_id=None,  # Store agents are public
                )
            )
        return items

    async def get_stats(self) -> dict[str, int]:
        """Count approved listing versions and how many of them have embeddings.

        Returns:
            Dict with ``total``, ``with_embeddings`` and ``without_embeddings``
            (the latter computed as the difference of the first two).
        """
        # All approved, non-deleted versions.
        approved_rows = await query_raw_with_schema(
            """
            SELECT COUNT(*) as count
            FROM {schema_prefix}"StoreListingVersion"
            WHERE "submissionStatus" = 'APPROVED'
            AND "isDeleted" = false
            """
        )
        if approved_rows:
            approved_count = approved_rows[0]["count"]
        else:
            approved_count = 0

        # The subset of those versions that join to an embedding row.
        embedded_rows = await query_raw_with_schema(
            """
            SELECT COUNT(*) as count
            FROM {schema_prefix}"StoreListingVersion" slv
            JOIN {schema_prefix}"UnifiedContentEmbedding" uce ON slv.id = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{schema_prefix}"ContentType"
            WHERE slv."submissionStatus" = 'APPROVED'
            AND slv."isDeleted" = false
            """
        )
        if embedded_rows:
            embedded_count = embedded_rows[0]["count"]
        else:
            embedded_count = 0

        return {
            "total": approved_count,
            "with_embeddings": embedded_count,
            "without_embeddings": approved_count - embedded_count,
        }
|
||||
|
||||
|
||||
class BlockHandler(ContentHandler):
    """Handler for block definitions (Python classes)."""

    @property
    def content_type(self) -> ContentType:
        """This handler serves ``ContentType.BLOCK``."""
        return ContentType.BLOCK

    @staticmethod
    def _placeholders(count: int) -> str:
        """Build the ``$1,$2,...,$count`` positional-parameter list."""
        return ",".join(f"${index}" for index in range(1, count + 1))

    @staticmethod
    def _category_names(block_instance: Any) -> list[str]:
        """Return the block's categories as plain strings (each entry's ``.value``)."""
        categories = getattr(block_instance, "categories", None) or []
        return [str(category.value) for category in categories]

    @classmethod
    def _build_searchable_text(cls, block_instance: Any) -> str:
        """Join name, description, categories and input-field descriptions."""
        parts: list[str] = []

        name = getattr(block_instance, "name", None)
        if name:
            parts.append(name)

        description = getattr(block_instance, "description", None)
        if description:
            parts.append(description)

        category_names = cls._category_names(block_instance)
        if category_names:
            parts.append(" ".join(category_names))

        # Input field descriptions come from the schema's JSON-schema export,
        # when the block exposes one.
        input_schema = getattr(block_instance, "input_schema", None)
        if hasattr(input_schema, "model_json_schema"):
            schema_dict = input_schema.model_json_schema()
            if "properties" in schema_dict:
                for prop_name, prop_info in schema_dict["properties"].items():
                    if "description" in prop_info:
                        parts.append(f"{prop_name}: {prop_info['description']}")

        return " ".join(parts)

    async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
        """Fetch blocks without embeddings.

        Blocks that fail to instantiate or describe are logged and skipped
        without counting against ``batch_size``, so a handful of permanently
        broken blocks cannot occupy every slot of every batch and stall the
        backfill.

        Args:
            batch_size: Maximum number of items to return.

        Returns:
            Up to ``batch_size`` ContentItem objects for blocks whose IDs are
            not present in the embedding table.
        """
        # Imported lazily, matching the original block's function-scope import.
        from backend.data.block import get_blocks

        all_blocks = get_blocks()
        if not all_blocks:
            return []

        block_ids = list(all_blocks.keys())
        placeholders = self._placeholders(len(block_ids))
        existing_result = await query_raw_with_schema(
            f"""
            SELECT "contentId"
            FROM {{schema_prefix}}"UnifiedContentEmbedding"
            WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
            AND "contentId" = ANY(ARRAY[{placeholders}])
            """,
            *block_ids,
        )
        existing_ids = {row["contentId"] for row in existing_result}

        items: list[ContentItem] = []
        for block_id, block_cls in all_blocks.items():
            if len(items) >= batch_size:
                break
            if block_id in existing_ids:
                continue

            try:
                block_instance = block_cls()
                items.append(
                    ContentItem(
                        content_id=block_id,
                        content_type=ContentType.BLOCK,
                        searchable_text=self._build_searchable_text(block_instance),
                        metadata={
                            "name": getattr(block_instance, "name", ""),
                            # Plain strings, consistent with what goes into
                            # the searchable text.
                            "categories": self._category_names(block_instance),
                        },
                        user_id=None,  # Blocks are public
                    )
                )
            except Exception as e:
                # Best effort: one bad block must not abort the whole batch.
                logger.warning(f"Failed to process block {block_id}: {e}")
                continue

        return items

    async def get_stats(self) -> dict[str, int]:
        """Get statistics about block embedding coverage.

        Returns:
            Dict with ``total`` (number of known blocks), ``with_embeddings``
            (embedding rows matching those block IDs) and
            ``without_embeddings`` (their difference).
        """
        from backend.data.block import get_blocks

        all_blocks = get_blocks()
        total_blocks = len(all_blocks)
        if total_blocks == 0:
            return {"total": 0, "with_embeddings": 0, "without_embeddings": 0}

        block_ids = list(all_blocks.keys())
        placeholders = self._placeholders(len(block_ids))
        embedded_result = await query_raw_with_schema(
            f"""
            SELECT COUNT(*) as count
            FROM {{schema_prefix}}"UnifiedContentEmbedding"
            WHERE "contentType" = 'BLOCK'::{{schema_prefix}}"ContentType"
            AND "contentId" = ANY(ARRAY[{placeholders}])
            """,
            *block_ids,
        )
        with_embeddings = embedded_result[0]["count"] if embedded_result else 0

        return {
            "total": total_blocks,
            "with_embeddings": with_embeddings,
            "without_embeddings": total_blocks - with_embeddings,
        }
|
||||
|
||||
|
||||
class DocumentationHandler(ContentHandler):
    """Handler for documentation files (.md/.mdx)."""

    @property
    def content_type(self) -> ContentType:
        """This handler serves ``ContentType.DOCUMENTATION``."""
        return ContentType.DOCUMENTATION

    def _get_docs_root(self) -> Path:
        """Get the documentation root directory.

        NOTE(review): the path is derived purely from this file's location
        (four parents up, then two more, then ``docs``) — confirm it matches
        the repository layout.
        """
        backend_root = Path(__file__).parent.parent.parent.parent
        return backend_root.parent.parent / "docs"

    def _list_docs(self, docs_root: Path) -> list[tuple[str, Path]]:
        """Return ``(content_id, file)`` pairs for all .md/.mdx files under the root.

        Content IDs are POSIX-style relative paths so they are identical on
        every platform, and the list is sorted so batches are deterministic.
        """
        files = [*docs_root.rglob("*.md"), *docs_root.rglob("*.mdx")]
        pairs = [(doc.relative_to(docs_root).as_posix(), doc) for doc in files]
        pairs.sort(key=lambda pair: pair[0])
        return pairs

    def _read_doc(self, file_path: Path) -> str | None:
        """Return the file's text, or ``None`` (after logging) if it is unreadable."""
        try:
            return file_path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as e:
            logger.warning(f"Failed to read {file_path}: {e}")
            return None

    @staticmethod
    def _split_title_and_body(text: str, file_path: Path) -> tuple[str, str]:
        """Split markdown text into a title and the remaining body.

        The first line starting with ``# `` outside a fenced code block is the
        title and is removed from the body; without one, the title is derived
        from the file name.
        """
        title = ""
        body_lines: list[str] = []
        in_fence = False

        for line in text.split("\n"):
            if line.lstrip().startswith(("```", "~~~")):
                # A `# comment` inside a fenced code block is not a heading.
                in_fence = not in_fence
            elif not in_fence and not title and line.startswith("# "):
                title = line[2:].strip()
                continue
            body_lines.append(line)

        # If no title found, use filename
        if not title:
            title = file_path.stem.replace("-", " ").replace("_", " ").title()

        return title, "\n".join(body_lines)

    def _extract_title_and_content(self, file_path: Path) -> tuple[str, str]:
        """Extract title and content from markdown file.

        Args:
            file_path: Markdown file to read (decoded as UTF-8).

        Returns:
            ``(title, body)``. If the file cannot be read, a warning is logged
            and ``(file_path.stem, "")`` is returned.
        """
        text = self._read_doc(file_path)
        if text is None:
            return file_path.stem, ""
        return self._split_title_and_body(text, file_path)

    async def get_missing_items(self, batch_size: int) -> list[ContentItem]:
        """Fetch documentation files without embeddings.

        Args:
            batch_size: Maximum number of items to return.

        Returns:
            Up to ``batch_size`` ContentItem objects, keyed by the file's path
            relative to the docs root. Empty if the root does not exist or
            holds no .md/.mdx files. Files that cannot be read or processed are
            logged and skipped without counting against ``batch_size``.
        """
        docs_root = self._get_docs_root()
        if not docs_root.exists():
            logger.warning(f"Documentation root not found: {docs_root}")
            return []

        docs = self._list_docs(docs_root)
        if not docs:
            return []

        doc_paths = [doc_path for doc_path, _ in docs]
        placeholders = ",".join(f"${i}" for i in range(1, len(doc_paths) + 1))
        existing_result = await query_raw_with_schema(
            f"""
            SELECT "contentId"
            FROM {{schema_prefix}}"UnifiedContentEmbedding"
            WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
            AND "contentId" = ANY(ARRAY[{placeholders}])
            """,
            *doc_paths,
        )
        existing_ids = {row["contentId"] for row in existing_result}

        items: list[ContentItem] = []
        for doc_path, doc_file in docs:
            if len(items) >= batch_size:
                break
            if doc_path in existing_ids:
                continue

            try:
                text = self._read_doc(doc_file)
                if text is None:
                    # Unreadable: already logged; keep the slot for a good file.
                    continue
                title, content = self._split_title_and_body(text, doc_file)
                items.append(
                    ContentItem(
                        content_id=doc_path,
                        content_type=ContentType.DOCUMENTATION,
                        searchable_text=f"{title} {content}",
                        metadata={
                            "title": title,
                            "path": doc_path,
                        },
                        user_id=None,  # Documentation is public
                    )
                )
            except Exception as e:
                # Best effort: one bad file must not abort the whole batch.
                logger.warning(f"Failed to process doc {doc_path}: {e}")
                continue

        return items

    async def get_stats(self) -> dict[str, int]:
        """Get statistics about documentation embedding coverage.

        Returns:
            Dict with ``total`` (number of .md/.mdx files found),
            ``with_embeddings`` (embedding rows matching those paths) and
            ``without_embeddings`` (their difference). All zeros when the docs
            root is missing or empty.
        """
        empty_stats = {"total": 0, "with_embeddings": 0, "without_embeddings": 0}

        docs_root = self._get_docs_root()
        if not docs_root.exists():
            return empty_stats

        docs = self._list_docs(docs_root)
        total_docs = len(docs)
        if total_docs == 0:
            return empty_stats

        doc_paths = [doc_path for doc_path, _ in docs]
        placeholders = ",".join(f"${i}" for i in range(1, len(doc_paths) + 1))
        embedded_result = await query_raw_with_schema(
            f"""
            SELECT COUNT(*) as count
            FROM {{schema_prefix}}"UnifiedContentEmbedding"
            WHERE "contentType" = 'DOCUMENTATION'::{{schema_prefix}}"ContentType"
            AND "contentId" = ANY(ARRAY[{placeholders}])
            """,
            *doc_paths,
        )
        with_embeddings = embedded_result[0]["count"] if embedded_result else 0

        return {
            "total": total_docs,
            "with_embeddings": with_embeddings,
            "without_embeddings": total_docs - with_embeddings,
        }
|
||||
|
||||
|
||||
# Registry: one shared handler instance per ContentType, keyed by the type
# each handler reports through its ``content_type`` property.
CONTENT_HANDLERS: dict[ContentType, ContentHandler] = {
    handler.content_type: handler
    for handler in (StoreAgentHandler(), BlockHandler(), DocumentationHandler())
}
|
||||
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Integration tests for content handlers using real DB.
|
||||
|
||||
Run with: poetry run pytest backend/api/features/store/content_handlers_integration_test.py -xvs
|
||||
|
||||
These tests use the real database but mock OpenAI calls.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.content_handlers import (
|
||||
CONTENT_HANDLERS,
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
)
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_all_content_types,
|
||||
ensure_content_embedding,
|
||||
get_embedding_stats,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_real_db():
    """StoreAgentHandler returns well-formed stats and items from the real DB."""
    store_handler = StoreAgentHandler()

    coverage = await store_handler.get_stats()

    # The stats dict must expose all three counters, none of them negative.
    stat_keys = ("total", "with_embeddings", "without_embeddings")
    for key in stat_keys:
        assert key in coverage
    for key in stat_keys:
        assert coverage[key] >= 0

    # A batch of one keeps the test fast; the result may legitimately be empty.
    missing = await store_handler.get_missing_items(batch_size=1)
    assert isinstance(missing, list)

    if not missing:
        return

    first = missing[0]
    assert first.content_id is not None
    assert first.content_type.value == "STORE_AGENT"
    assert first.searchable_text != ""
    assert first.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_real_db():
    """BlockHandler returns well-formed stats and items from the real DB."""
    block_handler = BlockHandler()

    coverage = await block_handler.get_stats()

    # The stats dict must expose all three counters, none of them negative.
    stat_keys = ("total", "with_embeddings", "without_embeddings")
    for key in stat_keys:
        assert key in coverage
    for key in stat_keys:
        assert coverage[key] >= 0

    # A batch of one keeps the test fast.
    missing = await block_handler.get_missing_items(batch_size=1)
    assert isinstance(missing, list)

    if not missing:
        return

    first = missing[0]
    assert first.content_id is not None  # expected to be the block's ID
    assert first.content_type.value == "BLOCK"
    assert first.searchable_text != ""
    assert first.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_real_fs():
    """DocumentationHandler returns well-formed stats and items from the real filesystem."""
    docs_handler = DocumentationHandler()

    coverage = await docs_handler.get_stats()

    # The stats dict must expose all three counters, none of them negative.
    stat_keys = ("total", "with_embeddings", "without_embeddings")
    for key in stat_keys:
        assert key in coverage
    for key in stat_keys:
        assert coverage[key] >= 0

    # A batch of one keeps the test fast.
    missing = await docs_handler.get_missing_items(batch_size=1)
    assert isinstance(missing, list)

    if not missing:
        return

    first = missing[0]
    assert first.content_id is not None  # expected to be a relative path
    assert first.content_type.value == "DOCUMENTATION"
    assert first.searchable_text != ""
    assert first.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_stats_all_types():
    """get_embedding_stats reports every content type plus aggregated totals."""
    report = await get_embedding_stats()

    # Top-level structure.
    assert "by_type" in report
    assert "totals" in report

    # Every content type has an entry.
    per_type = report["by_type"]
    for type_name in ("STORE_AGENT", "BLOCK", "DOCUMENTATION"):
        assert type_name in per_type

    # Aggregated counters are non-negative and a coverage figure is present.
    aggregate = report["totals"]
    for counter in ("total", "with_embeddings", "without_embeddings"):
        assert aggregate[counter] >= 0
    assert "coverage_percent" in aggregate
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
async def test_ensure_content_embedding_blocks(mock_generate):
    """ensure_content_embedding stores an embedding for a block (generation mocked)."""
    # Stand-in vector returned instead of calling the real embedding backend.
    mock_generate.return_value = [0.1] * 1536

    # Pick a single block that still lacks an embedding.
    candidates = await BlockHandler().get_missing_items(batch_size=1)
    if not candidates:
        pytest.skip("No blocks without embeddings")

    block_item = candidates[0]

    created = await ensure_content_embedding(
        content_type=block_item.content_type,
        content_id=block_item.content_id,
        searchable_text=block_item.searchable_text,
        metadata=block_item.metadata,
        user_id=block_item.user_id,
    )

    # With generation mocked, the call must succeed and hit the mock exactly once.
    assert created is True
    mock_generate.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
async def test_backfill_all_content_types_dry_run(mock_generate):
    """backfill_all_content_types reports per-type and aggregated results."""
    # Stand-in vector returned instead of calling the real embedding backend.
    mock_generate.return_value = [0.1] * 1536

    # batch_size=1 limits the run to at most one item per content type.
    outcome = await backfill_all_content_types(batch_size=1)

    assert "by_type" in outcome
    assert "totals" in outcome

    per_type = outcome["by_type"]
    for type_name in ("BLOCK", "STORE_AGENT", "DOCUMENTATION"):
        assert type_name in per_type

    # Every per-type entry carries the three counters.
    for type_outcome in per_type.values():
        assert "processed" in type_outcome
        assert "success" in type_outcome
        assert "failed" in type_outcome

    # The aggregated counters are non-negative.
    aggregate = outcome["totals"]
    for counter in ("processed", "success", "failed"):
        assert aggregate[counter] >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_content_handler_registry():
    """Each content type is registered with a handler of the matching class."""
    from prisma.enums import ContentType

    expected_handlers = {
        ContentType.STORE_AGENT: StoreAgentHandler,
        ContentType.BLOCK: BlockHandler,
        ContentType.DOCUMENTATION: DocumentationHandler,
    }

    # Presence first, so a missing key fails as an assertion, not a KeyError.
    for content_type in expected_handlers:
        assert content_type in CONTENT_HANDLERS

    for content_type, handler_cls in expected_handlers.items():
        assert isinstance(CONTENT_HANDLERS[content_type], handler_cls)
|
||||
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
E2E tests for content handlers (blocks, store agents, documentation).
|
||||
|
||||
Tests the full flow: discovering content → generating embeddings → storing.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store.content_handlers import (
|
||||
CONTENT_HANDLERS,
|
||||
BlockHandler,
|
||||
DocumentationHandler,
|
||||
StoreAgentHandler,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_get_missing_items(mocker):
    """StoreAgentHandler converts a listing row from the query into a ContentItem."""
    listing_row = {
        "id": "agent-1",
        "name": "Test Agent",
        "description": "A test agent",
        "subHeading": "Test heading",
        "categories": ["AI", "Testing"],
    }
    query_patch = patch(
        "backend.api.features.store.content_handlers.query_raw_with_schema",
        return_value=[listing_row],
    )

    store_handler = StoreAgentHandler()
    with query_patch:
        found = await store_handler.get_missing_items(batch_size=10)

    assert len(found) == 1
    item = found[0]
    assert item.content_id == "agent-1"
    assert item.content_type == ContentType.STORE_AGENT
    assert "Test Agent" in item.searchable_text
    assert "A test agent" in item.searchable_text
    assert item.metadata["name"] == "Test Agent"
    assert item.user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_store_agent_handler_get_stats(mocker):
    """StoreAgentHandler derives its stats from the two count queries."""
    # The handler issues the approved-count query first, then the embedded-count one.
    query_results = [[{"count": 50}], [{"count": 30}]]
    query_patch = patch(
        "backend.api.features.store.content_handlers.query_raw_with_schema",
        side_effect=query_results,
    )

    store_handler = StoreAgentHandler()
    with query_patch:
        coverage = await store_handler.get_stats()

    assert coverage["total"] == 50
    assert coverage["with_embeddings"] == 30
    assert coverage["without_embeddings"] == 20
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_missing_items(mocker):
    """Test BlockHandler discovers blocks without embeddings."""
    handler = BlockHandler()

    # A fake block class whose instance exposes name, description, categories
    # and an input schema with one described property.
    mock_block_instance = MagicMock()
    mock_block_instance.name = "Calculator Block"
    mock_block_instance.description = "Performs calculations"
    mock_block_instance.categories = [MagicMock(value="MATH")]
    mock_block_instance.input_schema.model_json_schema.return_value = {
        "properties": {"expression": {"description": "Math expression to evaluate"}}
    }
    mock_block_class = MagicMock(return_value=mock_block_instance)

    mock_blocks = {"block-uuid-1": mock_block_class}

    # No embeddings exist yet.
    mock_existing = []

    # get_blocks is imported inside the handler method from backend.data.block,
    # so it must be patched where it is defined: content_handlers has no
    # module-level `get_blocks` attribute and patching it there raises
    # AttributeError.
    with patch(
        "backend.data.block.get_blocks",
        return_value=mock_blocks,
    ), patch(
        "backend.api.features.store.content_handlers.query_raw_with_schema",
        return_value=mock_existing,
    ):
        items = await handler.get_missing_items(batch_size=10)

    assert len(items) == 1
    assert items[0].content_id == "block-uuid-1"
    assert items[0].content_type == ContentType.BLOCK
    assert "Calculator Block" in items[0].searchable_text
    assert "Performs calculations" in items[0].searchable_text
    assert "MATH" in items[0].searchable_text
    assert "expression: Math expression" in items[0].searchable_text
    assert items[0].user_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_get_stats(mocker):
    """Test BlockHandler returns correct stats."""
    handler = BlockHandler()

    # Three known blocks...
    mock_blocks = {
        "block-1": MagicMock(),
        "block-2": MagicMock(),
        "block-3": MagicMock(),
    }

    # ...two of which already have embeddings.
    mock_embedded = [{"count": 2}]

    # get_blocks is imported inside the handler method from backend.data.block,
    # so it must be patched where it is defined: content_handlers has no
    # module-level `get_blocks` attribute and patching it there raises
    # AttributeError.
    with patch(
        "backend.data.block.get_blocks",
        return_value=mock_blocks,
    ), patch(
        "backend.api.features.store.content_handlers.query_raw_with_schema",
        return_value=mock_embedded,
    ):
        stats = await handler.get_stats()

    assert stats["total"] == 3
    assert stats["with_embeddings"] == 2
    assert stats["without_embeddings"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_get_missing_items(tmp_path, mocker):
    """DocumentationHandler returns an item for each doc file lacking an embedding."""
    # Build a throwaway docs tree with one .md and one .mdx file.
    docs_dir = tmp_path / "docs"
    docs_dir.mkdir()
    (docs_dir / "guide.md").write_text("# Getting Started\n\nThis is a guide.")
    (docs_dir / "api.mdx").write_text("# API Reference\n\nAPI documentation.")

    docs_handler = DocumentationHandler()

    # Point the handler at the temp tree and report no existing embeddings.
    with patch.object(docs_handler, "_get_docs_root", return_value=docs_dir), patch(
        "backend.api.features.store.content_handlers.query_raw_with_schema",
        return_value=[],
    ):
        found = await docs_handler.get_missing_items(batch_size=10)

    assert len(found) == 2
    by_id = {}
    for item in found:
        # Keep the first item per ID, mirroring a first-match lookup.
        by_id.setdefault(item.content_id, item)

    guide = by_id.get("guide.md")
    assert guide is not None
    assert guide.content_type == ContentType.DOCUMENTATION
    assert "Getting Started" in guide.searchable_text
    assert "This is a guide" in guide.searchable_text
    assert guide.metadata["title"] == "Getting Started"
    assert guide.user_id is None

    api_doc = by_id.get("api.mdx")
    assert api_doc is not None
    assert "API Reference" in api_doc.searchable_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_get_stats(tmp_path, mocker):
    """DocumentationHandler counts doc files and subtracts the embedded ones."""
    # Three documentation files in a throwaway docs tree.
    docs_dir = tmp_path / "docs"
    docs_dir.mkdir()
    for file_name, text in (
        ("doc1.md", "# Doc 1"),
        ("doc2.md", "# Doc 2"),
        ("doc3.mdx", "# Doc 3"),
    ):
        (docs_dir / file_name).write_text(text)

    docs_handler = DocumentationHandler()

    # The count query reports that exactly one doc already has an embedding.
    with patch.object(docs_handler, "_get_docs_root", return_value=docs_dir), patch(
        "backend.api.features.store.content_handlers.query_raw_with_schema",
        return_value=[{"count": 1}],
    ):
        coverage = await docs_handler.get_stats()

    assert coverage["total"] == 3
    assert coverage["with_embeddings"] == 1
    assert coverage["without_embeddings"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_title_extraction(tmp_path):
    """Check title extraction from a markdown heading, with a filename fallback."""
    handler = DocumentationHandler()

    # Case 1: the leading "# " heading becomes the title and is removed
    # from the returned content.
    headed_doc = tmp_path / "with_heading.md"
    headed_doc.write_text("# My Title\n\nContent here")
    headed_title, headed_body = handler._extract_title_and_content(headed_doc)
    assert headed_title == "My Title"
    assert "# My Title" not in headed_body
    assert "Content here" in headed_body

    # Case 2: without a heading the title is derived from the file name.
    plain_doc = tmp_path / "no-heading.md"
    plain_doc.write_text("Just content, no heading")
    plain_title, plain_body = handler._extract_title_and_content(plain_doc)
    assert plain_title == "No Heading"  # Uses filename
    assert "Just content" in plain_body
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_content_handlers_registry():
    """Check that every expected content type maps to the right handler class."""
    expected_handlers = {
        ContentType.STORE_AGENT: StoreAgentHandler,
        ContentType.BLOCK: BlockHandler,
        ContentType.DOCUMENTATION: DocumentationHandler,
    }

    # First make sure every content type is registered at all ...
    for content_type in expected_handlers:
        assert content_type in CONTENT_HANDLERS

    # ... then that each one is served by an instance of the expected class.
    for content_type, handler_cls in expected_handlers.items():
        assert isinstance(CONTENT_HANDLERS[content_type], handler_cls)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_missing_attributes():
    """Check that BlockHandler copes with a block exposing only a name."""
    handler = BlockHandler()

    # Build a block instance that has a name but no description, categories
    # or input schema: deleting an attribute from a MagicMock makes later
    # access raise AttributeError instead of auto-creating a child mock.
    minimal_instance = MagicMock()
    minimal_instance.name = "Minimal Block"
    for missing_attr in ("description", "categories", "input_schema"):
        delattr(minimal_instance, missing_attr)

    block_cls = MagicMock(return_value=minimal_instance)
    registry = {"block-minimal": block_cls}

    with patch(
        "backend.api.features.store.content_handlers.get_blocks",
        return_value=registry,
    ), patch(
        "backend.api.features.store.content_handlers.query_raw_with_schema",
        return_value=[],
    ):
        items = await handler.get_missing_items(batch_size=10)

    assert len(items) == 1
    assert items[0].searchable_text == "Minimal Block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_skips_failed_blocks():
    """Check that a block whose constructor raises is left out of the results."""
    handler = BlockHandler()

    # A block class that instantiates normally.
    working_instance = MagicMock()
    working_instance.name = "Good Block"
    working_instance.description = "Works fine"
    working_instance.categories = []
    working_cls = MagicMock(return_value=working_instance)

    # A block class whose instantiation always raises.
    broken_cls = MagicMock(side_effect=Exception("Instantiation failed"))

    registry = {"good-block": working_cls, "bad-block": broken_cls}

    with patch(
        "backend.api.features.store.content_handlers.get_blocks",
        return_value=registry,
    ), patch(
        "backend.api.features.store.content_handlers.query_raw_with_schema",
        return_value=[],
    ):
        items = await handler.get_missing_items(batch_size=10)

    # Only the block that could be instantiated is reported.
    assert len(items) == 1
    assert items[0].content_id == "good-block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_documentation_handler_missing_docs_directory():
    """Check that a non-existent docs root yields no items and all-zero stats."""
    handler = DocumentationHandler()
    missing_root = Path("/nonexistent/docs")

    # Point the handler at a directory that does not exist.
    with patch.object(handler, "_get_docs_root", return_value=missing_root):
        items = await handler.get_missing_items(batch_size=10)
        assert items == []

        stats = await handler.get_stats()
        for counter in ("total", "with_embeddings", "without_embeddings"):
            assert stats[counter] == 0
|
||||
@@ -1,8 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import typing
|
||||
from datetime import datetime, timezone
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import fastapi
|
||||
import prisma.enums
|
||||
@@ -10,7 +9,7 @@ import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
|
||||
from backend.data.db import query_raw_with_schema, transaction
|
||||
from backend.data.db import transaction
|
||||
from backend.data.graph import (
|
||||
GraphMeta,
|
||||
GraphModel,
|
||||
@@ -30,6 +29,8 @@ from backend.util.settings import Settings
|
||||
|
||||
from . import exceptions as store_exceptions
|
||||
from . import model as store_model
|
||||
from .embeddings import ensure_embedding
|
||||
from .hybrid_search import hybrid_search
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
settings = Settings()
|
||||
@@ -50,128 +51,77 @@ async def get_store_agents(
|
||||
page_size: int = 20,
|
||||
) -> store_model.StoreAgentsResponse:
|
||||
"""
|
||||
Get PUBLIC store agents from the StoreAgent view
|
||||
Get PUBLIC store agents from the StoreAgent view.
|
||||
|
||||
Search behavior:
|
||||
- With search_query: Uses hybrid search (semantic + lexical)
|
||||
- Fallback: If embeddings unavailable, gracefully degrades to lexical-only
|
||||
- Rationale: User-facing endpoint prioritizes availability over accuracy
|
||||
|
||||
Note: Admin operations (approval) use fail-fast to prevent inconsistent state.
|
||||
"""
|
||||
logger.debug(
|
||||
f"Getting store agents. featured={featured}, creators={creators}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
|
||||
)
|
||||
|
||||
search_used_hybrid = False
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
agents: list[dict[str, Any]] = []
|
||||
total = 0
|
||||
total_pages = 0
|
||||
|
||||
try:
|
||||
# If search_query is provided, use full-text search
|
||||
# If search_query is provided, use hybrid search (embeddings + tsvector)
|
||||
if search_query:
|
||||
offset = (page - 1) * page_size
|
||||
# Try hybrid search combining semantic and lexical signals
|
||||
# Falls back to lexical-only if OpenAI unavailable (user-facing, high SLA)
|
||||
try:
|
||||
agents, total = await hybrid_search(
|
||||
query=search_query,
|
||||
featured=featured,
|
||||
creators=creators,
|
||||
category=category,
|
||||
sorted_by="relevance", # Use hybrid scoring for relevance
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
search_used_hybrid = True
|
||||
except Exception as e:
|
||||
# Log error but fall back to lexical search for better UX
|
||||
logger.error(
|
||||
f"Hybrid search failed (likely OpenAI unavailable), "
|
||||
f"falling back to lexical search: {e}"
|
||||
)
|
||||
# search_used_hybrid remains False, will use fallback path below
|
||||
|
||||
# Whitelist allowed order_by columns
|
||||
ALLOWED_ORDER_BY = {
|
||||
"rating": "rating DESC, rank DESC",
|
||||
"runs": "runs DESC, rank DESC",
|
||||
"name": "agent_name ASC, rank ASC",
|
||||
"updated_at": "updated_at DESC, rank DESC",
|
||||
}
|
||||
# Convert hybrid search results (dict format) if hybrid succeeded
|
||||
if search_used_hybrid:
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
store_agent = store_model.StoreAgent(
|
||||
slug=agent["slug"],
|
||||
agent_name=agent["agent_name"],
|
||||
agent_image=(
|
||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||
),
|
||||
creator=agent["creator_username"] or "Needs Profile",
|
||||
creator_avatar=agent["creator_avatar"] or "",
|
||||
sub_heading=agent["sub_heading"],
|
||||
description=agent["description"],
|
||||
runs=agent["runs"],
|
||||
rating=agent["rating"],
|
||||
)
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error parsing Store agent from hybrid search results: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Validate and get order clause
|
||||
if sorted_by and sorted_by in ALLOWED_ORDER_BY:
|
||||
order_by_clause = ALLOWED_ORDER_BY[sorted_by]
|
||||
else:
|
||||
order_by_clause = "updated_at DESC, rank DESC"
|
||||
|
||||
# Build WHERE conditions and parameters list
|
||||
where_parts: list[str] = []
|
||||
params: list[typing.Any] = [search_query] # $1 - search term
|
||||
param_index = 2 # Start at $2 for next parameter
|
||||
|
||||
# Always filter for available agents
|
||||
where_parts.append("is_available = true")
|
||||
|
||||
if featured:
|
||||
where_parts.append("featured = true")
|
||||
|
||||
if creators and creators:
|
||||
# Use ANY with array parameter
|
||||
where_parts.append(f"creator_username = ANY(${param_index})")
|
||||
params.append(creators)
|
||||
param_index += 1
|
||||
|
||||
if category and category:
|
||||
where_parts.append(f"${param_index} = ANY(categories)")
|
||||
params.append(category)
|
||||
param_index += 1
|
||||
|
||||
sql_where_clause: str = " AND ".join(where_parts) if where_parts else "1=1"
|
||||
|
||||
# Add pagination params
|
||||
params.extend([page_size, offset])
|
||||
limit_param = f"${param_index}"
|
||||
offset_param = f"${param_index + 1}"
|
||||
|
||||
# Execute full-text search query with parameterized values
|
||||
sql_query = f"""
|
||||
SELECT
|
||||
slug,
|
||||
agent_name,
|
||||
agent_image,
|
||||
creator_username,
|
||||
creator_avatar,
|
||||
sub_heading,
|
||||
description,
|
||||
runs,
|
||||
rating,
|
||||
categories,
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
ts_rank_cd(search, query) AS rank
|
||||
FROM {{schema_prefix}}"StoreAgent",
|
||||
plainto_tsquery('english', $1) AS query
|
||||
WHERE {sql_where_clause}
|
||||
AND search @@ query
|
||||
ORDER BY {order_by_clause}
|
||||
LIMIT {limit_param} OFFSET {offset_param}
|
||||
"""
|
||||
|
||||
# Count query for pagination - only uses search term parameter
|
||||
count_query = f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {{schema_prefix}}"StoreAgent",
|
||||
plainto_tsquery('english', $1) AS query
|
||||
WHERE {sql_where_clause}
|
||||
AND search @@ query
|
||||
"""
|
||||
|
||||
# Execute both queries with parameters
|
||||
agents = await query_raw_with_schema(sql_query, *params)
|
||||
|
||||
# For count, use params without pagination (last 2 params)
|
||||
count_params = params[:-2]
|
||||
count_result = await query_raw_with_schema(count_query, *count_params)
|
||||
|
||||
total = count_result[0]["count"] if count_result else 0
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
# Convert raw results to StoreAgent models
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
try:
|
||||
store_agent = store_model.StoreAgent(
|
||||
slug=agent["slug"],
|
||||
agent_name=agent["agent_name"],
|
||||
agent_image=(
|
||||
agent["agent_image"][0] if agent["agent_image"] else ""
|
||||
),
|
||||
creator=agent["creator_username"] or "Needs Profile",
|
||||
creator_avatar=agent["creator_avatar"] or "",
|
||||
sub_heading=agent["sub_heading"],
|
||||
description=agent["description"],
|
||||
runs=agent["runs"],
|
||||
rating=agent["rating"],
|
||||
)
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing Store agent from search results: {e}")
|
||||
continue
|
||||
|
||||
else:
|
||||
# Non-search query path (original logic)
|
||||
if not search_used_hybrid:
|
||||
# Fallback path - use basic search or no search
|
||||
where_clause: prisma.types.StoreAgentWhereInput = {"is_available": True}
|
||||
if featured:
|
||||
where_clause["featured"] = featured
|
||||
@@ -180,6 +130,14 @@ async def get_store_agents(
|
||||
if category:
|
||||
where_clause["categories"] = {"has": category}
|
||||
|
||||
# Add basic text search if search_query provided but hybrid failed
|
||||
if search_query:
|
||||
where_clause["OR"] = [
|
||||
{"agent_name": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"sub_heading": {"contains": search_query, "mode": "insensitive"}},
|
||||
{"description": {"contains": search_query, "mode": "insensitive"}},
|
||||
]
|
||||
|
||||
order_by = []
|
||||
if sorted_by == "rating":
|
||||
order_by.append({"rating": "desc"})
|
||||
@@ -188,7 +146,7 @@ async def get_store_agents(
|
||||
elif sorted_by == "name":
|
||||
order_by.append({"agent_name": "asc"})
|
||||
|
||||
agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
db_agents = await prisma.models.StoreAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
@@ -199,7 +157,7 @@ async def get_store_agents(
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
|
||||
store_agents: list[store_model.StoreAgent] = []
|
||||
for agent in agents:
|
||||
for agent in db_agents:
|
||||
try:
|
||||
# Create the StoreAgent object safely
|
||||
store_agent = store_model.StoreAgent(
|
||||
@@ -1577,7 +1535,7 @@ async def review_store_submission(
|
||||
)
|
||||
|
||||
# Update the AgentGraph with store listing data
|
||||
await prisma.models.AgentGraph.prisma().update(
|
||||
await prisma.models.AgentGraph.prisma(tx).update(
|
||||
where={
|
||||
"graphVersionId": {
|
||||
"id": store_listing_version.agentGraphId,
|
||||
@@ -1592,6 +1550,23 @@ async def review_store_submission(
|
||||
},
|
||||
)
|
||||
|
||||
# Generate embedding for approved listing (blocking - admin operation)
|
||||
# Inside transaction: if embedding fails, entire transaction rolls back
|
||||
embedding_success = await ensure_embedding(
|
||||
version_id=store_listing_version_id,
|
||||
name=store_listing_version.name,
|
||||
description=store_listing_version.description,
|
||||
sub_heading=store_listing_version.subHeading,
|
||||
categories=store_listing_version.categories or [],
|
||||
tx=tx,
|
||||
)
|
||||
if not embedding_success:
|
||||
raise ValueError(
|
||||
f"Failed to generate embedding for listing {store_listing_version_id}. "
|
||||
"This is likely due to OpenAI API being unavailable. "
|
||||
"Please try again later or contact support if the issue persists."
|
||||
)
|
||||
|
||||
await prisma.models.StoreListing.prisma(tx).update(
|
||||
where={"id": store_listing_version.StoreListing.id},
|
||||
data={
|
||||
|
||||
@@ -0,0 +1,628 @@
|
||||
"""
|
||||
Unified Content Embeddings Service
|
||||
|
||||
Handles generation and storage of OpenAI embeddings for all content types
|
||||
(store listings, blocks, documentation, library agents) to enable semantic/hybrid search.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import prisma
|
||||
from prisma.enums import ContentType
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.api.features.store.content_handlers import CONTENT_HANDLERS
|
||||
from backend.data.db import execute_raw_with_schema, query_raw_with_schema
|
||||
from backend.util.clients import get_openai_client
|
||||
from backend.util.json import dumps
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# OpenAI embedding model configuration
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
# OpenAI embedding token limit (8,191 with 1 token buffer for safety)
|
||||
EMBEDDING_MAX_TOKENS = 8191
|
||||
|
||||
|
||||
def build_searchable_text(
    name: str,
    description: str,
    sub_heading: str,
    categories: list[str],
) -> str:
    """
    Join a listing version's text fields into one string for embedding.

    Non-empty fields are emitted in the order name, sub-heading, description,
    followed by the space-joined categories, all separated by single spaces.
    """
    # Order matters: name first, then the sub-heading for context, then the
    # description as the main body. Empty fields are dropped.
    fields = [field for field in (name, sub_heading, description) if field]

    # Categories are appended as one space-separated chunk.
    if categories:
        fields.append(" ".join(categories))

    return " ".join(fields)
|
||||
|
||||
|
||||
async def generate_embedding(text: str) -> list[float] | None:
    """
    Generate embedding for text using OpenAI API.

    The text is tokenized with tiktoken and truncated to EMBEDDING_MAX_TOKENS
    before being sent to the embeddings endpoint.

    Fail-fast: no retries to maintain consistency with approval flow.

    Args:
        text: The text to embed.

    Returns:
        The embedding vector, or None if no OpenAI client is configured or
        any step (tokenization, API call) fails.
    """
    try:
        client = get_openai_client()
        if not client:
            logger.error("openai_internal_api_key not set, cannot generate embedding")
            return None

        # Truncate text to token limit using tiktoken
        # Character-based truncation is insufficient because token ratios vary by content type
        enc = encoding_for_model(EMBEDDING_MODEL)
        # disallowed_special=() makes tiktoken encode literal special-token
        # strings (e.g. "<|endoftext|>") found in user content as plain text
        # instead of raising ValueError, which would fail the whole embedding.
        tokens = enc.encode(text, disallowed_special=())
        # Remember the pre-truncation length so the text is encoded only once.
        original_token_count = len(tokens)
        if original_token_count > EMBEDDING_MAX_TOKENS:
            tokens = tokens[:EMBEDDING_MAX_TOKENS]
            truncated_text = enc.decode(tokens)
            logger.info(
                f"Truncated text from {original_token_count} to {len(tokens)} tokens"
            )
        else:
            truncated_text = text

        # perf_counter is monotonic, so the measured latency cannot be skewed
        # by wall-clock adjustments.
        start_time = time.perf_counter()
        response = await client.embeddings.create(
            model=EMBEDDING_MODEL,
            input=truncated_text,
        )
        latency_ms = (time.perf_counter() - start_time) * 1000

        embedding = response.data[0].embedding
        logger.info(
            f"Generated embedding: {len(embedding)} dims, "
            f"{len(tokens)} tokens, {latency_ms:.0f}ms"
        )
        return embedding

    except Exception as e:
        # Deliberate best-effort boundary: callers treat None as "unavailable".
        logger.error(f"Failed to generate embedding: {e}")
        return None
|
||||
|
||||
|
||||
async def store_embedding(
    version_id: str,
    embedding: list[float],
    tx: prisma.Prisma | None = None,
) -> bool:
    """
    Persist an embedding for a store listing version.

    BACKWARD COMPATIBILITY: kept for existing store listing callers.
    DEPRECATED: prefer ensure_embedding(), which also records searchable_text.

    Args:
        version_id: The StoreListingVersion ID, used as the content ID.
        embedding: The embedding vector to store.
        tx: Optional transaction client

    Returns:
        Whatever store_content_embedding() returns for the STORE_AGENT row.
    """
    stored = await store_content_embedding(
        content_type=ContentType.STORE_AGENT,
        content_id=version_id,
        embedding=embedding,
        searchable_text="",  # Empty for backward compat; ensure_embedding() populates this
        metadata=None,
        user_id=None,  # Store agents are public
        tx=tx,
    )
    return stored
|
||||
|
||||
|
||||
async def store_content_embedding(
    content_type: ContentType,
    content_id: str,
    embedding: list[float],
    searchable_text: str,
    metadata: dict | None = None,
    user_id: str | None = None,
    tx: prisma.Prisma | None = None,
) -> bool:
    """
    Upsert an embedding row in the unified content embeddings table.

    Raw SQL is used because Prisma has no native pgvector support.

    Args:
        content_type: ContentType enum value for the row.
        content_id: Unique identifier for the content
        embedding: The embedding vector to store.
        searchable_text: Text the embedding was generated from.
        metadata: Optional metadata; stored as an empty JSON object when omitted.
        user_id: Owning user ID, or None.
        tx: Optional transaction client; the global Prisma client is used otherwise.

    Returns:
        True if the upsert statement ran without raising, False on any error.
    """
    # Upsert the embedding
    # WHERE clause in DO UPDATE prevents PostgreSQL 15 bug with NULLS NOT DISTINCT
    upsert_sql = """
            INSERT INTO {schema_prefix}"UnifiedContentEmbedding" (
                "id", "contentType", "contentId", "userId", "embedding", "searchableText", "metadata", "createdAt", "updatedAt"
            )
            VALUES (gen_random_uuid()::text, $1::{schema_prefix}"ContentType", $2, $3, $4::vector, $5, $6::jsonb, NOW(), NOW())
            ON CONFLICT ("contentType", "contentId", "userId")
            DO UPDATE SET
                "embedding" = $4::vector,
                "searchableText" = $5,
                "metadata" = $6::jsonb,
                "updatedAt" = NOW()
            WHERE {schema_prefix}"UnifiedContentEmbedding"."contentType" = $1::{schema_prefix}"ContentType"
                AND {schema_prefix}"UnifiedContentEmbedding"."contentId" = $2
                AND ({schema_prefix}"UnifiedContentEmbedding"."userId" = $3 OR ($3 IS NULL AND {schema_prefix}"UnifiedContentEmbedding"."userId" IS NULL))
            """

    try:
        db_client = tx or prisma.get_client()

        # Serialize the vector into PostgreSQL's textual vector form and the
        # metadata into JSON text for the ::jsonb cast.
        vector_literal = embedding_to_vector_string(embedding)
        metadata_payload = dumps(metadata or {})

        await execute_raw_with_schema(
            upsert_sql,
            content_type,
            content_id,
            user_id,
            vector_literal,
            searchable_text,
            metadata_payload,
            client=db_client,
            set_public_search_path=True,
        )
    except Exception as e:
        logger.error(f"Failed to store embedding for {content_type}:{content_id}: {e}")
        return False

    logger.info(f"Stored embedding for {content_type}:{content_id}")
    return True
|
||||
|
||||
|
||||
async def get_embedding(version_id: str) -> dict[str, Any] | None:
    """
    Fetch the embedding record of a store listing version.

    BACKWARD COMPATIBILITY: kept for existing store listing callers; the
    unified record is reshaped into the legacy dict layout.

    Args:
        version_id: The StoreListingVersion ID

    Returns:
        Dict with storeListingVersionId, embedding, createdAt and updatedAt,
        or None if no record was found.
    """
    record = await get_content_embedding(
        ContentType.STORE_AGENT, version_id, user_id=None
    )
    if not record:
        return None

    # Map the unified column names onto the legacy key names.
    legacy_record = {
        "storeListingVersionId": record["contentId"],
        "embedding": record["embedding"],
        "createdAt": record["createdAt"],
        "updatedAt": record["updatedAt"],
    }
    return legacy_record
|
||||
|
||||
|
||||
async def get_content_embedding(
    content_type: ContentType, content_id: str, user_id: str | None = None
) -> dict[str, Any] | None:
    """
    Fetch the embedding record for a piece of content of any type.

    Args:
        content_type: ContentType enum value to look up.
        content_id: Unique identifier for the content
        user_id: Owning user ID; None matches rows whose userId is NULL.

    Returns:
        The first matching row (contentType, contentId, userId, embedding as
        text, searchableText, metadata, timestamps), or None when nothing
        matches or the query fails.
    """
    select_sql = """
            SELECT
                "contentType",
                "contentId",
                "userId",
                "embedding"::text as "embedding",
                "searchableText",
                "metadata",
                "createdAt",
                "updatedAt"
            FROM {schema_prefix}"UnifiedContentEmbedding"
            WHERE "contentType" = $1::{schema_prefix}"ContentType" AND "contentId" = $2 AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
            """

    try:
        rows = await query_raw_with_schema(
            select_sql,
            content_type,
            content_id,
            user_id,
            set_public_search_path=True,
        )
        return rows[0] if rows else None
    except Exception as e:
        logger.error(f"Failed to get embedding for {content_type}:{content_id}: {e}")
        return None
|
||||
|
||||
|
||||
async def ensure_embedding(
    version_id: str,
    name: str,
    description: str,
    sub_heading: str,
    categories: list[str],
    force: bool = False,
    tx: prisma.Prisma | None = None,
) -> bool:
    """
    Make sure a store listing version has an embedding stored.

    An existing embedding is reused unless force=True; otherwise a new one is
    generated from the listing's text fields and stored together with its
    searchable text and metadata. Backward-compatible wrapper for store
    listings.

    Args:
        version_id: The StoreListingVersion ID
        name: Agent name
        description: Agent description
        sub_heading: Agent sub-heading
        categories: Agent categories
        force: Force regeneration even if embedding exists
        tx: Optional transaction client

    Returns:
        True if embedding exists/was created, False on failure
    """
    try:
        # Short-circuit when an embedding is already present and regeneration
        # was not requested.
        if not force:
            current = await get_embedding(version_id)
            if current and current.get("embedding"):
                logger.debug(f"Embedding for version {version_id} already exists")
                return True

        text_for_embedding = build_searchable_text(
            name, description, sub_heading, categories
        )

        vector = await generate_embedding(text_for_embedding)
        if vector is None:
            logger.warning(f"Could not generate embedding for version {version_id}")
            return False

        # Persist the vector alongside the text it was derived from.
        return await store_content_embedding(
            content_type=ContentType.STORE_AGENT,
            content_id=version_id,
            embedding=vector,
            searchable_text=text_for_embedding,
            metadata={
                "name": name,
                "subHeading": sub_heading,
                "categories": categories,
            },
            user_id=None,  # Store agents are public
            tx=tx,
        )

    except Exception as e:
        logger.error(f"Failed to ensure embedding for version {version_id}: {e}")
        return False
|
||||
|
||||
|
||||
async def delete_embedding(version_id: str) -> bool:
    """
    Remove the embedding of a store listing version.

    BACKWARD COMPATIBILITY: kept for existing store listing callers.
    Note: CASCADE delete normally takes care of this; the function exists
    for manual cleanup.

    Args:
        version_id: The StoreListingVersion ID

    Returns:
        The result of delete_content_embedding() for the STORE_AGENT row.
    """
    deleted = await delete_content_embedding(ContentType.STORE_AGENT, version_id)
    return deleted
|
||||
|
||||
|
||||
async def delete_content_embedding(
    content_type: ContentType, content_id: str, user_id: str | None = None
) -> bool:
    """
    Remove the embedding row of a piece of content of any type.

    Note: CASCADE delete normally takes care of this; the function exists
    for manual cleanup.

    Args:
        content_type: The type of content (STORE_AGENT, LIBRARY_AGENT, etc.)
        content_id: The unique identifier for the content
        user_id: Optional user ID. For public content (STORE_AGENT, BLOCK), pass None.
            For user-scoped content (LIBRARY_AGENT), pass the user's ID to avoid
            deleting embeddings belonging to other users.

    Returns:
        True if deletion succeeded, False otherwise
    """
    delete_sql = """
            DELETE FROM {schema_prefix}"UnifiedContentEmbedding"
            WHERE "contentType" = $1::{schema_prefix}"ContentType"
              AND "contentId" = $2
              AND ("userId" = $3 OR ($3 IS NULL AND "userId" IS NULL))
            """

    try:
        db_client = prisma.get_client()
        await execute_raw_with_schema(
            delete_sql,
            content_type,
            content_id,
            user_id,
            client=db_client,
        )
    except Exception as e:
        logger.error(f"Failed to delete embedding for {content_type}:{content_id}: {e}")
        return False

    # Mention the user in the log line only for user-scoped deletions.
    if user_id:
        user_str = f" (user: {user_id})"
    else:
        user_str = ""
    logger.info(f"Deleted embedding for {content_type}:{content_id}{user_str}")
    return True
|
||||
|
||||
|
||||
async def get_embedding_stats() -> dict[str, Any]:
    """
    Collect embedding coverage statistics for every registered content type.

    Returns:
        Dict with "by_type" (per content type: total, with_embeddings,
        without_embeddings, coverage_percent, plus "error" if that handler
        failed) and "totals" aggregated over the handlers that succeeded.
        On an unexpected failure, empty/zeroed stats with a top-level "error".
    """

    def coverage(embedded: int, total: int):
        # Percentage rounded to one decimal; 0 when there is nothing to cover.
        return round(embedded / total * 100, 1) if total > 0 else 0

    try:
        per_type: dict[str, Any] = {}
        grand_total = 0
        grand_with = 0
        grand_without = 0

        for content_type, handler in CONTENT_HANDLERS.items():
            type_key = content_type.value
            try:
                stats = await handler.get_stats()
                count_total = stats["total"]
                count_with = stats["with_embeddings"]
                count_without = stats["without_embeddings"]

                per_type[type_key] = {
                    "total": count_total,
                    "with_embeddings": count_with,
                    "without_embeddings": count_without,
                    "coverage_percent": coverage(count_with, count_total),
                }

                grand_total += count_total
                grand_with += count_with
                grand_without += count_without

            except Exception as e:
                # One failing handler must not hide the stats of the others.
                logger.error(f"Failed to get stats for {content_type.value}: {e}")
                per_type[type_key] = {
                    "total": 0,
                    "with_embeddings": 0,
                    "without_embeddings": 0,
                    "coverage_percent": 0,
                    "error": str(e),
                }

        return {
            "by_type": per_type,
            "totals": {
                "total": grand_total,
                "with_embeddings": grand_with,
                "without_embeddings": grand_without,
                "coverage_percent": coverage(grand_with, grand_total),
            },
        }

    except Exception as e:
        logger.error(f"Failed to get embedding stats: {e}")
        return {
            "by_type": {},
            "totals": {
                "total": 0,
                "with_embeddings": 0,
                "without_embeddings": 0,
                "coverage_percent": 0,
            },
            "error": str(e),
        }
|
||||
|
||||
|
||||
async def backfill_missing_embeddings(batch_size: int = 10) -> dict[str, Any]:
    """
    Backfill missing embeddings and report aggregate counts.

    BACKWARD COMPATIBILITY: kept for existing callers. The work is delegated
    to backfill_all_content_types(), so all content types handled there are
    processed, not only store listings.

    Args:
        batch_size: Number of embeddings to generate per content type

    Returns:
        The "totals" section of the backfill_all_content_types() result
        (the legacy return format).
    """
    full_report = await backfill_all_content_types(batch_size)

    # Legacy callers only expect the aggregated section.
    return full_report["totals"]
|
||||
|
||||
|
||||
async def backfill_all_content_types(batch_size: int = 10) -> dict[str, Any]:
    """
    Generate embeddings for all content types using registered handlers.

    Processes content types in order: BLOCK → STORE_AGENT → DOCUMENTATION.
    This ensures foundational content (blocks) are searchable first.

    A failure while processing one content type is recorded under that type's
    "error" key and does not stop the remaining types.

    Args:
        batch_size: Number of embeddings to generate per content type

    Returns:
        Dict with stats per content type ("by_type") and overall totals
        ("totals": processed, success, failed, message)
    """
    results_by_type = {}
    total_processed = 0
    total_success = 0
    total_failed = 0

    # Process content types in explicit order
    processing_order = [
        ContentType.BLOCK,
        ContentType.STORE_AGENT,
        ContentType.DOCUMENTATION,
    ]

    for content_type in processing_order:
        handler = CONTENT_HANDLERS.get(content_type)
        if not handler:
            logger.warning(f"No handler registered for {content_type.value}")
            continue
        try:
            logger.info(f"Processing {content_type.value} content type...")

            # Get missing items from handler
            missing_items = await handler.get_missing_items(batch_size)

            if not missing_items:
                results_by_type[content_type.value] = {
                    "processed": 0,
                    "success": 0,
                    "failed": 0,
                    "message": "No missing embeddings",
                }
                continue

            # Process embeddings concurrently for better performance
            embedding_tasks = [
                ensure_content_embedding(
                    content_type=item.content_type,
                    content_id=item.content_id,
                    searchable_text=item.searchable_text,
                    metadata=item.metadata,
                    user_id=item.user_id,
                )
                for item in missing_items
            ]

            # return_exceptions=True keeps one failing task from aborting the batch.
            results = await asyncio.gather(*embedding_tasks, return_exceptions=True)

            # Exceptions captured by gather() are counted as failures below;
            # log them here so the cause is not silently lost.
            for item, outcome in zip(missing_items, results):
                if isinstance(outcome, BaseException):
                    logger.error(
                        f"Embedding task for {content_type.value}:{item.content_id} "
                        f"raised: {outcome!r}"
                    )

            # Only an explicit True counts as success; False and exceptions fail.
            success = sum(1 for result in results if result is True)
            failed = len(results) - success

            results_by_type[content_type.value] = {
                "processed": len(missing_items),
                "success": success,
                "failed": failed,
                "message": f"Backfilled {success} embeddings, {failed} failed",
            }

            total_processed += len(missing_items)
            total_success += success
            total_failed += failed

            logger.info(
                f"{content_type.value}: processed {len(missing_items)}, "
                f"success {success}, failed {failed}"
            )

        except Exception as e:
            logger.error(f"Failed to process {content_type.value}: {e}")
            results_by_type[content_type.value] = {
                "processed": 0,
                "success": 0,
                "failed": 0,
                "error": str(e),
            }

    return {
        "by_type": results_by_type,
        "totals": {
            "processed": total_processed,
            "success": total_success,
            "failed": total_failed,
            "message": f"Overall: {total_success} succeeded, {total_failed} failed",
        },
    }
|
||||
|
||||
|
||||
async def embed_query(query: str) -> list[float] | None:
    """
    Produce the embedding for a search query.

    Thin alias over generate_embedding so that search call sites read
    clearly; the value returned is exactly what generate_embedding returns.
    """
    query_vector = await generate_embedding(query)
    return query_vector
|
||||
|
||||
|
||||
def embedding_to_vector_string(embedding: list[float]) -> str:
    """Render an embedding as a PostgreSQL vector string, e.g. ``[0.1,0.2]``."""
    # Each component is rendered with str() and joined without spaces.
    components = map(str, embedding)
    return f"[{','.join(components)}]"
|
||||
|
||||
|
||||
async def ensure_content_embedding(
    content_type: ContentType,
    content_id: str,
    searchable_text: str,
    metadata: dict | None = None,
    user_id: str | None = None,
    force: bool = False,
    tx: prisma.Prisma | None = None,
) -> bool:
    """
    Make sure an embedding is stored for a piece of content of any type.

    Unless ``force`` is set, an already-stored embedding short-circuits the
    call. Otherwise a new embedding is generated from ``searchable_text``
    and handed to store_content_embedding.

    Args:
        content_type: ContentType enum value (STORE_AGENT, BLOCK, etc.)
        content_id: Unique identifier for the content
        searchable_text: Text the embedding is generated from
        metadata: Optional metadata stored with the embedding ({} if omitted)
        user_id: Forwarded to both the existence lookup and the store call
        force: Skip the existence check and always regenerate
        tx: Optional transaction client, forwarded to the store call

    Returns:
        True if the embedding already existed or was stored; False when
        generation yields None or any exception is raised along the way.
    """
    try:
        if not force:
            # Reuse what is already stored when the row carries an embedding.
            current = await get_content_embedding(content_type, content_id, user_id)
            already_embedded = bool(current and current.get("embedding"))
            if already_embedded:
                logger.debug(
                    f"Embedding for {content_type}:{content_id} already exists"
                )
                return True

        fresh_embedding = await generate_embedding(searchable_text)
        if fresh_embedding is None:
            logger.warning(
                f"Could not generate embedding for {content_type}:{content_id}"
            )
            return False

        stored = await store_content_embedding(
            content_type=content_type,
            content_id=content_id,
            embedding=fresh_embedding,
            searchable_text=searchable_text,
            metadata=metadata or {},
            user_id=user_id,
            tx=tx,
        )
        return stored

    except Exception as e:
        # Best-effort contract: failures are logged and reported as False.
        logger.error(f"Failed to ensure embedding for {content_type}:{content_id}: {e}")
        return False
|
||||
@@ -0,0 +1,329 @@
|
||||
"""
|
||||
Integration tests for embeddings with schema handling.
|
||||
|
||||
These tests verify that embeddings operations work correctly across different database schemas.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store import embeddings
|
||||
|
||||
# Schema prefix tests removed - functionality moved to db.raw_with_schema() helper
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_store_content_embedding_with_schema():
    """Storing an embedding runs SQL against the schema-qualified table."""
    db_client = AsyncMock()

    with patch(
        "backend.data.db.get_database_schema", return_value="platform"
    ), patch("prisma.get_client", return_value=db_client):
        stored = await embeddings.store_content_embedding(
            content_type=ContentType.STORE_AGENT,
            content_id="test-id",
            embedding=[0.1] * 1536,
            searchable_text="test text",
            metadata={"test": "data"},
            user_id=None,
        )

    # A raw statement must have been executed.
    assert db_client.execute_raw.called

    # The first positional argument of the last call is the SQL text.
    executed_sql = db_client.execute_raw.call_args.args[0]
    assert '"platform"."UnifiedContentEmbedding"' in executed_sql

    assert stored is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_get_content_embedding_with_schema():
    """Fetching an embedding queries the schema-qualified table."""
    stored_row = {
        "contentType": "STORE_AGENT",
        "contentId": "test-id",
        "userId": None,
        "embedding": "[0.1, 0.2]",
        "searchableText": "test",
        "metadata": {},
        "createdAt": "2024-01-01",
        "updatedAt": "2024-01-01",
    }
    db_client = AsyncMock()
    db_client.query_raw.return_value = [stored_row]

    with patch(
        "backend.data.db.get_database_schema", return_value="platform"
    ), patch("prisma.get_client", return_value=db_client):
        fetched = await embeddings.get_content_embedding(
            ContentType.STORE_AGENT,
            "test-id",
            user_id=None,
        )

    # A raw query must have been issued.
    assert db_client.query_raw.called

    # The first positional argument of the last call is the SQL text.
    executed_sql = db_client.query_raw.call_args.args[0]
    assert '"platform"."UnifiedContentEmbedding"' in executed_sql

    assert fetched is not None
    assert fetched["contentId"] == "test-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_delete_content_embedding_with_schema():
    """Deleting an embedding runs SQL against the schema-qualified table."""
    db_client = AsyncMock()

    with patch(
        "backend.data.db.get_database_schema", return_value="platform"
    ), patch("prisma.get_client", return_value=db_client):
        deleted = await embeddings.delete_content_embedding(
            ContentType.STORE_AGENT,
            "test-id",
        )

    # A raw statement must have been executed.
    assert db_client.execute_raw.called

    # The first positional argument of the last call is the SQL text.
    executed_sql = db_client.execute_raw.call_args.args[0]
    assert '"platform"."UnifiedContentEmbedding"' in executed_sql

    assert deleted is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_get_embedding_stats_with_schema():
    """Both statistics queries use schema-qualified table names."""
    db_client = AsyncMock()
    # Successive query results: total approved, then those with embeddings.
    db_client.query_raw.side_effect = [
        [{"count": 100}],
        [{"count": 80}],
    ]

    with patch(
        "backend.data.db.get_database_schema", return_value="platform"
    ), patch("prisma.get_client", return_value=db_client):
        stats = await embeddings.get_embedding_stats()

    # Exactly two queries are expected.
    assert db_client.query_raw.call_count == 2

    approved_call, embedded_call = db_client.query_raw.call_args_list[:2]
    approved_sql = approved_call.args[0]
    embedded_sql = embedded_call.args[0]

    # Schema prefix must appear in both statements.
    assert '"platform"."StoreListingVersion"' in approved_sql
    assert '"platform"."StoreListingVersion"' in embedded_sql
    assert '"platform"."UnifiedContentEmbedding"' in embedded_sql

    assert stats["total_approved"] == 100
    assert stats["with_embeddings"] == 80
    assert stats["without_embeddings"] == 20
    assert stats["coverage_percent"] == 80.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backfill_missing_embeddings_with_schema():
    """The backfill query for missing embeddings is schema-qualified."""
    missing_row = {
        "id": "version-1",
        "name": "Test Agent",
        "description": "Test description",
        "subHeading": "Test heading",
        "categories": ["test"],
    }
    db_client = AsyncMock()
    db_client.query_raw.return_value = [missing_row]

    with patch(
        "backend.data.db.get_database_schema", return_value="platform"
    ), patch("prisma.get_client", return_value=db_client), patch(
        "backend.api.features.store.embeddings.ensure_embedding",
        return_value=True,
    ) as mock_ensure:
        outcome = await embeddings.backfill_missing_embeddings(batch_size=10)

    # A raw query must have been issued.
    assert db_client.query_raw.called

    # The first positional argument of the last call is the SQL text.
    executed_sql = db_client.query_raw.call_args.args[0]
    assert '"platform"."StoreListingVersion"' in executed_sql
    assert '"platform"."UnifiedContentEmbedding"' in executed_sql

    # The per-item helper must have been invoked.
    assert mock_ensure.called

    assert outcome["processed"] == 1
    assert outcome["success"] == 1
    assert outcome["failed"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_ensure_content_embedding_with_schema():
    """With no stored embedding, the lookup, generate and store steps all run."""
    with patch(
        "backend.data.db.get_database_schema", return_value="platform"
    ), patch(
        "backend.api.features.store.embeddings.get_content_embedding",
        return_value=None,  # simulate no existing embedding
    ) as mock_get, patch(
        "backend.api.features.store.embeddings.generate_embedding",
        return_value=[0.1] * 1536,
    ) as mock_generate, patch(
        "backend.api.features.store.embeddings.store_content_embedding",
        return_value=True,
    ) as mock_store:
        ensured = await embeddings.ensure_content_embedding(
            content_type=ContentType.STORE_AGENT,
            content_id="test-id",
            searchable_text="test text",
            metadata={"test": "data"},
            user_id=None,
            force=False,
        )

    # Every stage of the flow must have been reached.
    assert mock_get.called
    assert mock_generate.called
    assert mock_store.called
    assert ensured is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backward_compatibility_store_embedding():
    """The legacy store_embedding wrapper delegates to store_content_embedding."""
    with patch(
        "backend.api.features.store.embeddings.store_content_embedding",
        return_value=True,
    ) as mock_store:
        stored = await embeddings.store_embedding(
            version_id="test-version-id",
            embedding=[0.1] * 1536,
            tx=None,
        )

    # The wrapper must forward to the new function with mapped keywords.
    assert mock_store.called
    forwarded = mock_store.call_args.kwargs

    assert forwarded["content_type"] == ContentType.STORE_AGENT
    assert forwarded["content_id"] == "test-version-id"
    assert forwarded["user_id"] is None
    assert stored is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_backward_compatibility_get_embedding():
    """The legacy get_embedding wrapper returns the old-style key names."""
    new_style_row = {
        "contentType": "STORE_AGENT",
        "contentId": "test-version-id",
        "embedding": "[0.1, 0.2]",
        "createdAt": "2024-01-01",
        "updatedAt": "2024-01-01",
    }

    with patch(
        "backend.api.features.store.embeddings.get_content_embedding",
        return_value=new_style_row,
    ) as mock_get:
        legacy_row = await embeddings.get_embedding("test-version-id")

    # The wrapper must go through the new function...
    assert mock_get.called

    # ...and translate the result into the old format.
    assert legacy_row is not None
    assert legacy_row["storeListingVersionId"] == "test-version-id"
    assert "embedding" in legacy_row
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_schema_handling_error_cases():
    """A database failure while storing yields False instead of raising."""
    failing_client = AsyncMock()
    failing_client.execute_raw.side_effect = Exception("Database error")

    with patch(
        "backend.data.db.get_database_schema", return_value="platform"
    ), patch("prisma.get_client", return_value=failing_client):
        stored = await embeddings.store_content_embedding(
            content_type=ContentType.STORE_AGENT,
            content_id="test-id",
            embedding=[0.1] * 1536,
            searchable_text="test",
            metadata=None,
            user_id=None,
        )

    # The error is swallowed and reported through the return value.
    assert stored is False
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly: verbose, output not captured.
    pytest_args = [__file__, "-v", "-s"]
    pytest.main(pytest_args)
|
||||
@@ -0,0 +1,387 @@
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import prisma
|
||||
import pytest
|
||||
from prisma import Prisma
|
||||
from prisma.enums import ContentType
|
||||
|
||||
from backend.api.features.store import embeddings
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
async def setup_prisma():
    """Instantiate a Prisma client before each test, tolerating an existing one."""
    try:
        Prisma()
    except prisma.errors.ClientAlreadyRegisteredError:
        # A client was registered earlier; nothing more to do.
        pass
    yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_build_searchable_text():
    """All listing fields are combined into a single searchable string."""
    listing_fields = {
        "name": "AI Assistant",
        "description": "A helpful AI assistant for productivity",
        "sub_heading": "Boost your productivity",
        "categories": ["AI", "Productivity"],
    }

    combined = embeddings.build_searchable_text(**listing_fields)

    assert combined == (
        "AI Assistant Boost your productivity A helpful AI assistant for productivity AI Productivity"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_build_searchable_text_empty_fields():
    """Empty fields contribute nothing to the searchable string."""
    combined = embeddings.build_searchable_text(
        name="",
        description="Test description",
        sub_heading="",
        categories=[],
    )

    assert combined == "Test description"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_success():
    """A successful API response is returned as the embedding vector."""
    # Fake OpenAI response carrying one 1536-dimensional embedding.
    embedding_item = MagicMock()
    embedding_item.embedding = [0.1, 0.2, 0.3] * 512
    api_response = MagicMock()
    api_response.data = [embedding_item]

    # embeddings.create is awaited, so it has to be an AsyncMock.
    openai_client = MagicMock()
    openai_client.embeddings.create = AsyncMock(return_value=api_response)

    # Patch at the point of use in embeddings.py.
    with patch(
        "backend.api.features.store.embeddings.get_openai_client",
        return_value=openai_client,
    ):
        vector = await embeddings.generate_embedding("test text")

    assert vector is not None
    assert len(vector) == 1536
    assert vector[0] == 0.1

    openai_client.embeddings.create.assert_called_once_with(
        model="text-embedding-3-small", input="test text"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_no_api_key():
    """When no OpenAI client is available, no embedding is produced."""
    # Patch at the point of use in embeddings.py.
    with patch(
        "backend.api.features.store.embeddings.get_openai_client",
        return_value=None,
    ):
        vector = await embeddings.generate_embedding("test text")

    assert vector is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_api_error():
    """An exception from the embeddings API results in None."""
    openai_client = MagicMock()
    openai_client.embeddings.create = AsyncMock(side_effect=Exception("API Error"))

    # Patch at the point of use in embeddings.py.
    with patch(
        "backend.api.features.store.embeddings.get_openai_client",
        return_value=openai_client,
    ):
        vector = await embeddings.generate_embedding("test text")

    assert vector is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_generate_embedding_text_truncation():
    """Over-long input reaches the API truncated to the 8191-token limit."""
    from tiktoken import encoding_for_model

    embedding_item = MagicMock()
    embedding_item.embedding = [0.1] * 1536
    api_response = MagicMock()
    api_response.data = [embedding_item]

    # embeddings.create is awaited, so it has to be an AsyncMock.
    openai_client = MagicMock()
    openai_client.embeddings.create = AsyncMock(return_value=api_response)

    # 10000 distinct words, intended to exceed 8191 tokens.
    long_text = " ".join(f"word{i}" for i in range(10000))

    # Patch at the point of use in embeddings.py.
    with patch(
        "backend.api.features.store.embeddings.get_openai_client",
        return_value=openai_client,
    ):
        await embeddings.generate_embedding(long_text)

    # Inspect the text that was actually sent to the API.
    sent_text = openai_client.embeddings.create.call_args.kwargs["input"]

    encoder = encoding_for_model("text-embedding-3-small")
    token_count = len(encoder.encode(sent_text))

    # At or under the limit...
    assert token_count <= 8191
    # ...but not truncated far below it.
    assert token_count >= 8100
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_store_embedding_success(mocker):
    """store_embedding succeeds and issues two raw statements on the given tx."""
    tx_client = mocker.AsyncMock()
    tx_client.execute_raw = mocker.AsyncMock()

    stored = await embeddings.store_embedding(
        version_id="test-version-id", embedding=[0.1, 0.2, 0.3], tx=tx_client
    )

    assert stored is True
    # Two statements are expected: the search_path setup, then the data write.
    assert tx_client.execute_raw.call_count == 2

    search_path_call, write_call = tx_client.execute_raw.call_args_list[:2]

    # First statement sets the search path.
    assert "SET search_path" in search_path_call.args[0]

    # Second statement carries the actual data as positional arguments.
    write_args = write_call.args
    assert "test-version-id" in write_args
    assert "[0.1,0.2,0.3]" in write_args
    assert None in write_args  # userId should be None for store agents
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_store_embedding_database_error(mocker):
    """A failing execute_raw makes store_embedding return False."""
    tx_client = mocker.AsyncMock()
    tx_client.execute_raw.side_effect = Exception("Database error")

    stored = await embeddings.store_embedding(
        version_id="test-version-id", embedding=[0.1, 0.2, 0.3], tx=tx_client
    )

    assert stored is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_success():
    """A stored row is returned by get_embedding in the legacy shape."""
    stored_row = {
        "contentType": "STORE_AGENT",
        "contentId": "test-version-id",
        "userId": None,
        "embedding": "[0.1,0.2,0.3]",
        "searchableText": "Test text",
        "metadata": {},
        "createdAt": "2024-01-01T00:00:00Z",
        "updatedAt": "2024-01-01T00:00:00Z",
    }

    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=[stored_row],
    ):
        fetched = await embeddings.get_embedding("test-version-id")

    assert fetched is not None
    assert fetched["storeListingVersionId"] == "test-version-id"
    assert fetched["embedding"] == "[0.1,0.2,0.3]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_not_found():
    """An empty query result makes get_embedding return None."""
    no_rows = []

    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=no_rows,
    ):
        fetched = await embeddings.get_embedding("test-version-id")

    assert fetched is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.store_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_already_exists(mock_get, mock_store, mock_generate):
    """An existing embedding means nothing is generated or stored."""
    mock_get.return_value = {"embedding": "[0.1,0.2,0.3]"}
    listing = {
        "version_id": "test-id",
        "name": "Test",
        "description": "Test description",
        "sub_heading": "Test heading",
        "categories": ["test"],
    }

    ensured = await embeddings.ensure_embedding(**listing)

    assert ensured is True
    mock_generate.assert_not_called()
    mock_store.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.store_content_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_create_new(mock_get, mock_store, mock_generate):
    """With no existing embedding, one is generated and stored with metadata."""
    mock_get.return_value = None
    mock_generate.return_value = [0.1, 0.2, 0.3]
    mock_store.return_value = True

    ensured = await embeddings.ensure_embedding(
        version_id="test-id",
        name="Test",
        description="Test description",
        sub_heading="Test heading",
        categories=["test"],
    )

    assert ensured is True

    # Generation receives the combined searchable text.
    mock_generate.assert_called_once_with("Test Test heading Test description test")

    # Storage receives the generic content-embedding keyword arguments.
    expected_store_kwargs = {
        "content_type": ContentType.STORE_AGENT,
        "content_id": "test-id",
        "embedding": [0.1, 0.2, 0.3],
        "searchable_text": "Test Test heading Test description test",
        "metadata": {"name": "Test", "subHeading": "Test heading", "categories": ["test"]},
        "user_id": None,
        "tx": None,
    }
    mock_store.assert_called_once_with(**expected_store_kwargs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.generate_embedding")
@patch("backend.api.features.store.embeddings.get_embedding")
async def test_ensure_embedding_generation_fails(mock_get, mock_generate):
    """If generation returns None, ensure_embedding reports failure."""
    mock_get.return_value = None
    mock_generate.return_value = None
    listing = {
        "version_id": "test-id",
        "name": "Test",
        "description": "Test description",
        "sub_heading": "Test heading",
        "categories": ["test"],
    }

    ensured = await embeddings.ensure_embedding(**listing)

    assert ensured is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_get_embedding_stats():
    """Coverage statistics are derived from the two count queries."""
    # Successive query results: approved count, then embedded count.
    query_results = [[{"count": 100}], [{"count": 75}]]

    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        side_effect=query_results,
    ):
        stats = await embeddings.get_embedding_stats()

    assert stats["total_approved"] == 100
    assert stats["with_embeddings"] == 75
    assert stats["without_embeddings"] == 25
    assert stats["coverage_percent"] == 75.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@patch("backend.api.features.store.embeddings.ensure_embedding")
async def test_backfill_missing_embeddings_success(mock_ensure):
    """Backfill tallies one success and one failure across two listings."""
    first_listing = {
        "id": "version-1",
        "name": "Agent 1",
        "description": "Description 1",
        "subHeading": "Heading 1",
        "categories": ["AI"],
    }
    second_listing = {
        "id": "version-2",
        "name": "Agent 2",
        "description": "Description 2",
        "subHeading": "Heading 2",
        "categories": ["Productivity"],
    }

    # ensure_embedding succeeds on its first call and fails on its second.
    mock_ensure.side_effect = [True, False]

    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=[first_listing, second_listing],
    ):
        outcome = await embeddings.backfill_missing_embeddings(batch_size=5)

    assert outcome["processed"] == 2
    assert outcome["success"] == 1
    assert outcome["failed"] == 1
    assert mock_ensure.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_backfill_missing_embeddings_no_missing():
    """With nothing missing, backfill reports zero counts and a message."""
    no_rows = []

    with patch(
        "backend.api.features.store.embeddings.query_raw_with_schema",
        return_value=no_rows,
    ):
        outcome = await embeddings.backfill_missing_embeddings(batch_size=5)

    assert outcome["processed"] == 0
    assert outcome["success"] == 0
    assert outcome["failed"] == 0
    assert outcome["message"] == "No missing embeddings"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_embedding_to_vector_string():
    """Embeddings are rendered as bracketed, comma-separated vector strings."""
    vector_text = embeddings.embedding_to_vector_string([0.1, 0.2, 0.3, -0.4])

    assert vector_text == "[0.1,0.2,0.3,-0.4]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
async def test_embed_query():
    """embed_query forwards its argument to generate_embedding unchanged."""
    with patch(
        "backend.api.features.store.embeddings.generate_embedding",
        return_value=[0.1, 0.2, 0.3],
    ) as mock_generate:
        vector = await embeddings.embed_query("test query")

    assert vector == [0.1, 0.2, 0.3]
    mock_generate.assert_called_once_with("test query")
|
||||
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
Hybrid Search for Store Agents
|
||||
|
||||
Combines semantic (embedding) search with lexical (tsvector) search
|
||||
for improved relevance in marketplace agent discovery.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from backend.api.features.store.embeddings import (
|
||||
embed_query,
|
||||
embedding_to_vector_string,
|
||||
)
|
||||
from backend.data.db import query_raw_with_schema
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class HybridSearchWeights:
    """Relative weights of the signals blended into the hybrid search score."""

    semantic: float = 0.30  # Embedding cosine similarity
    lexical: float = 0.30  # tsvector ts_rank_cd score
    category: float = 0.20  # Category match boost
    recency: float = 0.10  # Newer agents ranked higher
    popularity: float = 0.10  # Agent usage/runs (PageRank-like)

    def __post_init__(self):
        """Reject negative weights and totals outside the 0.99-1.01 band.

        Raises:
            ValueError: if any weight is negative, or the weights do not
                sum to approximately 1.0.
        """
        signal_weights = (
            self.semantic,
            self.lexical,
            self.category,
            self.recency,
            self.popularity,
        )
        semantic, lexical, category, recency, popularity = signal_weights
        weight_sum = semantic + lexical + category + recency + popularity

        for weight in signal_weights:
            if weight < 0:
                raise ValueError("All weights must be non-negative")

        within_tolerance = 0.99 <= weight_sum <= 1.01
        if not within_tolerance:
            raise ValueError(f"Weights must sum to ~1.0, got {weight_sum:.3f}")
|
||||
|
||||
|
||||
DEFAULT_WEIGHTS = HybridSearchWeights()
|
||||
|
||||
# Minimum relevance score threshold - agents below this are filtered out
|
||||
# With weights (0.30 semantic + 0.30 lexical + 0.20 category + 0.10 recency + 0.10 popularity):
|
||||
# - 0.20 means a semantic match on its own must score ~0.67 (0.20 / 0.30), OR a comparably strong lexical match is required
|
||||
# - Ensures only genuinely relevant results are returned
|
||||
# - Recency/popularity alone (0.10 each) won't pass the threshold
|
||||
DEFAULT_MIN_SCORE = 0.20
|
||||
|
||||
|
||||
@dataclass
class HybridSearchResult:
    """One search hit: listing fields plus the per-signal score breakdown."""

    # Listing fields
    slug: str
    agent_name: str
    agent_image: str
    creator_username: str
    creator_avatar: str
    sub_heading: str
    description: str
    runs: int
    rating: float
    categories: list[str]
    featured: bool
    is_available: bool
    updated_at: datetime

    # Scores (kept for debugging/tuning); each component defaults to 0.0
    combined_score: float
    semantic_score: float = 0.0
    lexical_score: float = 0.0
    category_score: float = 0.0
    recency_score: float = 0.0
    popularity_score: float = 0.0
|
||||
|
||||
|
||||
_RELEVANCE_ORDER_SQL = "combined_score DESC"

# Whitelisted ORDER BY clauses keyed by the public ``sorted_by`` value.
# Only these hardcoded strings are ever interpolated into the SQL text;
# relevance is used as the tie-breaker for every non-relevance ordering.
_SORT_ORDER_SQL: dict[str, str] = {
    "relevance": _RELEVANCE_ORDER_SQL,
    "rating": f"rating DESC, {_RELEVANCE_ORDER_SQL}",
    "runs": f"runs DESC, {_RELEVANCE_ORDER_SQL}",
    "name": f"agent_name ASC, {_RELEVANCE_ORDER_SQL}",
    "updated_at": f"updated_at DESC, {_RELEVANCE_ORDER_SQL}",
}


def _escape_like_pattern(value: str) -> str:
    """Backslash-escape LIKE metacharacters so *value* is matched literally."""
    return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


async def hybrid_search(
    query: str,
    featured: bool = False,
    creators: list[str] | None = None,
    category: str | None = None,
    sorted_by: (
        Literal["relevance", "rating", "runs", "name", "updated_at"] | None
    ) = None,
    page: int = 1,
    page_size: int = 20,
    weights: HybridSearchWeights | None = None,
    min_score: float | None = None,
) -> tuple[list[dict[str, Any]], int]:
    """
    Perform hybrid search combining semantic and lexical signals.

    Candidates are the union of full-text matches and the 200 nearest
    embedding neighbours. Each candidate is scored as a weighted sum of
    semantic, lexical, category, recency and popularity signals, and rows
    scoring below ``min_score`` are dropped.

    Args:
        query: Search query string. Surrounding whitespace is stripped.
        featured: Filter for featured agents only
        creators: Filter by creator usernames
        category: Filter by category
        sorted_by: Sort order. ``None`` and ``"relevance"`` order by the
            combined score; the other values order by that column first and
            use the combined score as tie-breaker.
        page: Page number (1-indexed). Values below 1 are treated as 1.
        page_size: Results per page, clamped to the range 1-100.
        weights: Custom weights for search signals
        min_score: Minimum relevance score threshold (0-1). Results below
            this score are filtered out. Defaults to DEFAULT_MIN_SCORE.

    Returns:
        Tuple of (results list, total count). Returns ``([], 0)`` for a blank
        query or when no results meet the minimum relevance threshold. The
        total is read from the returned rows, so it is also 0 when ``page``
        lies past the last result.

    Raises:
        ValueError: If no query embedding could be generated.
    """
    query = query.strip()
    if not query:
        return [], 0  # Empty query returns no results

    page = max(page, 1)
    # Cap at a reasonable limit to prevent performance issues
    page_size = min(max(page_size, 1), 100)

    if weights is None:
        weights = DEFAULT_WEIGHTS
    if min_score is None:
        min_score = DEFAULT_MIN_SCORE

    offset = (page - 1) * page_size

    # Embedding is required for hybrid search - fail fast if unavailable
    query_embedding = await embed_query(query)
    if not query_embedding:
        # Log detailed error server-side
        logger.error(
            "Failed to generate query embedding. "
            "Check that openai_internal_api_key is configured and OpenAI API is accessible."
        )
        # Raise generic error to client
        raise ValueError("Search service temporarily unavailable")

    params: list[Any] = []

    def bind(value: Any) -> str:
        """Append *value* to the parameter list and return its ``$N`` placeholder."""
        params.append(value)
        return f"${len(params)}"

    # Search text for lexical matching
    query_param = bind(query)

    # Lowercased text for category matching. LIKE wildcards typed by the user
    # are escaped so that e.g. "%" cannot match every category.
    category_pattern_param = bind(_escape_like_pattern(query.lower()))

    # Safe: where_parts only contains hardcoded strings with $N parameter placeholders
    # No user input is concatenated directly into the SQL string
    where_parts: list[str] = ["sa.is_available = true"]
    if featured:
        where_parts.append("sa.featured = true")
    if creators:
        where_parts.append(f"sa.creator_username = ANY({bind(creators)})")
    if category:
        where_parts.append(f"{bind(category)} = ANY(sa.categories)")
    where_clause = " AND ".join(where_parts)

    embedding_param = bind(embedding_to_vector_string(query_embedding))

    # Weights and threshold are bound as parameters, never formatted into SQL
    weight_semantic_param = bind(weights.semantic)
    weight_lexical_param = bind(weights.lexical)
    weight_category_param = bind(weights.category)
    weight_recency_param = bind(weights.recency)
    weight_popularity_param = bind(weights.popularity)
    min_score_param = bind(min_score)

    # Pagination values stay the last two parameters
    limit_param = bind(page_size)
    offset_param = bind(offset)

    # Unknown sort keys fall back to relevance; caller input never reaches the SQL text
    order_by = _SORT_ORDER_SQL.get(sorted_by or "relevance", _RELEVANCE_ORDER_SQL)

    # Optimized hybrid search query:
    # 1. Direct join to UnifiedContentEmbedding via contentId=storeListingVersionId (no redundant JOINs)
    # 2. UNION approach (deduplicates agents matching both branches)
    # 3. COUNT(*) OVER() to get total count in single query
    # 4. Optimized category matching with EXISTS + unnest
    # 5. Pre-calculated max values for lexical and popularity normalization
    # 6. Simplified recency calculation with linear decay
    # 7. Logarithmic popularity scaling to prevent viral agents from dominating
    sql_query = f"""
        WITH candidates AS (
            -- Lexical matches (uses GIN index on search column)
            SELECT sa."storeListingVersionId"
            FROM {{schema_prefix}}"StoreAgent" sa
            WHERE {where_clause}
            AND sa.search @@ plainto_tsquery('english', {query_param})

            UNION

            -- Semantic matches (uses HNSW index on embedding with KNN)
            SELECT "storeListingVersionId"
            FROM (
                SELECT sa."storeListingVersionId", uce.embedding
                FROM {{schema_prefix}}"StoreAgent" sa
                INNER JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
                    ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
                WHERE {where_clause}
                ORDER BY uce.embedding <=> {embedding_param}::vector
                LIMIT 200
            ) semantic_results
        ),
        search_scores AS (
            SELECT
                sa.slug,
                sa.agent_name,
                sa.agent_image,
                sa.creator_username,
                sa.creator_avatar,
                sa.sub_heading,
                sa.description,
                sa.runs,
                sa.rating,
                sa.categories,
                sa.featured,
                sa.is_available,
                sa.updated_at,
                -- Semantic score: cosine similarity (1 - distance)
                COALESCE(1 - (uce.embedding <=> {embedding_param}::vector), 0) as semantic_score,
                -- Lexical score: ts_rank_cd (will be normalized later)
                COALESCE(ts_rank_cd(sa.search, plainto_tsquery('english', {query_param})), 0) as lexical_raw,
                -- Category match: optimized with unnest for better performance
                CASE
                    WHEN EXISTS (
                        SELECT 1 FROM unnest(sa.categories) cat
                        WHERE LOWER(cat) LIKE '%' || {category_pattern_param} || '%'
                    )
                    THEN 1.0
                    ELSE 0.0
                END as category_score,
                -- Recency score: linear decay over 90 days (simpler than exponential)
                GREATEST(0, 1 - EXTRACT(EPOCH FROM (NOW() - sa.updated_at)) / (90 * 24 * 3600)) as recency_score,
                -- Popularity raw: agent runs count (will be normalized with log scaling)
                sa.runs as popularity_raw
            FROM candidates c
            INNER JOIN {{schema_prefix}}"StoreAgent" sa
                ON c."storeListingVersionId" = sa."storeListingVersionId"
            LEFT JOIN {{schema_prefix}}"UnifiedContentEmbedding" uce
                ON sa."storeListingVersionId" = uce."contentId" AND uce."contentType" = 'STORE_AGENT'::{{schema_prefix}}"ContentType"
        ),
        max_lexical AS (
            SELECT MAX(lexical_raw) as max_val FROM search_scores
        ),
        max_popularity AS (
            SELECT MAX(popularity_raw) as max_val FROM search_scores
        ),
        normalized AS (
            SELECT
                ss.*,
                -- Normalize lexical score by pre-calculated max
                CASE
                    WHEN ml.max_val > 0
                    THEN ss.lexical_raw / ml.max_val
                    ELSE 0
                END as lexical_score,
                -- Normalize popularity with logarithmic scaling to prevent viral agents from dominating
                -- LOG(1 + runs) / LOG(1 + max_runs) ensures score is 0-1 range
                CASE
                    WHEN mp.max_val > 0 AND ss.popularity_raw > 0
                    THEN LN(1 + ss.popularity_raw) / LN(1 + mp.max_val)
                    ELSE 0
                END as popularity_score
            FROM search_scores ss
            CROSS JOIN max_lexical ml
            CROSS JOIN max_popularity mp
        ),
        scored AS (
            SELECT
                slug,
                agent_name,
                agent_image,
                creator_username,
                creator_avatar,
                sub_heading,
                description,
                runs,
                rating,
                categories,
                featured,
                is_available,
                updated_at,
                semantic_score,
                lexical_score,
                category_score,
                recency_score,
                popularity_score,
                (
                    {weight_semantic_param} * semantic_score +
                    {weight_lexical_param} * lexical_score +
                    {weight_category_param} * category_score +
                    {weight_recency_param} * recency_score +
                    {weight_popularity_param} * popularity_score
                ) as combined_score
            FROM normalized
        ),
        filtered AS (
            SELECT
                *,
                COUNT(*) OVER () as total_count
            FROM scored
            WHERE combined_score >= {min_score_param}
        )
        SELECT * FROM filtered
        ORDER BY {order_by}
        LIMIT {limit_param} OFFSET {offset_param}
    """

    # Execute search query - includes total_count via window function
    results = await query_raw_with_schema(
        sql_query, *params, set_public_search_path=True
    )

    # Every row carries the same window-function count; read it from the first
    total = results[0]["total_count"] if results else 0

    # Remove total_count from results before returning
    for row in results:
        row.pop("total_count", None)

    # Log without sensitive query content
    logger.info("Hybrid search: %s results, %s total", len(results), total)

    return results, total
|
||||
|
||||
|
||||
async def hybrid_search_simple(
    query: str,
    page: int = 1,
    page_size: int = 20,
) -> tuple[list[dict[str, Any]], int]:
    """
    Run ``hybrid_search`` with only a query and pagination.

    No filters, weights or score threshold are passed, so ``hybrid_search``
    applies its own defaults for those.
    """
    search_kwargs = {"query": query, "page": page, "page_size": page_size}
    outcome = await hybrid_search(**search_kwargs)
    return outcome
|
||||
@@ -0,0 +1,334 @@
|
||||
"""
|
||||
Integration tests for hybrid search with schema handling.
|
||||
|
||||
These tests verify that hybrid search works correctly across different database schemas.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.store.hybrid_search import HybridSearchWeights, hybrid_search
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_schema_handling():
    """Hybrid search should pass a schema-templated SQL string to the DB helper."""
    search_text = "test agent"

    # Single row handed back by the mocked database helper
    fake_row = {
        "slug": "test/agent",
        "agent_name": "Test Agent",
        "agent_image": "test.png",
        "creator_username": "test",
        "creator_avatar": "avatar.png",
        "sub_heading": "Test sub-heading",
        "description": "Test description",
        "runs": 10,
        "rating": 4.5,
        "categories": ["test"],
        "featured": False,
        "is_available": True,
        "updated_at": "2024-01-01T00:00:00Z",
        "combined_score": 0.8,
        "semantic_score": 0.7,
        "lexical_score": 0.6,
        "category_score": 0.5,
        "recency_score": 0.4,
        "total_count": 1,
    }

    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_query.return_value = [fake_row]
        mock_embed.return_value = [0.1] * 1536  # Mock embedding

        results, total = await hybrid_search(
            query=search_text,
            page=1,
            page_size=20,
        )

        # The database helper must have been invoked...
        assert mock_query.called
        # ...with a template that still carries the schema_prefix placeholder
        sql_template = mock_query.call_args[0][0]
        assert "{schema_prefix}" in sql_template

        # The mocked row comes back as the single result
        assert len(results) == 1
        assert total == 1
        assert results[0]["slug"] == "test/agent"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_public_schema():
    """Test hybrid search when using public schema (no prefix needed)."""
    with patch("backend.data.db.get_database_schema") as mock_schema, patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_schema.return_value = "public"
        mock_query.return_value = []
        mock_embed.return_value = [0.1] * 1536

        results, total = await hybrid_search(
            query="test",
            page=1,
            page_size=20,
        )

    # query_raw_with_schema is mocked, so schema resolution itself is not
    # exercised here. What can be verified is that hybrid_search leaves the
    # placeholder in the template and asks for the public search path.
    assert mock_query.called
    sql_template = mock_query.call_args[0][0]
    assert "{schema_prefix}" in sql_template
    assert mock_query.call_args.kwargs.get("set_public_search_path") is True

    # Results should work even with empty results
    assert results == []
    assert total == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_custom_schema():
    """Test hybrid search when using custom schema (e.g., 'platform')."""
    with patch("backend.data.db.get_database_schema") as mock_schema, patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_schema.return_value = "platform"
        mock_query.return_value = []
        mock_embed.return_value = [0.1] * 1536

        results, total = await hybrid_search(
            query="test",
            page=1,
            page_size=20,
        )

    # query_raw_with_schema is mocked, so schema resolution itself is not
    # exercised here. What can be verified is that hybrid_search leaves the
    # placeholder in the template and asks for the public search path.
    assert mock_query.called
    sql_template = mock_query.call_args[0][0]
    assert "{schema_prefix}" in sql_template
    assert mock_query.call_args.kwargs.get("set_public_search_path") is True

    assert results == []
    assert total == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_without_embeddings():
    """Hybrid search must raise instead of searching when no embedding is produced."""
    # Patch the name inside the hybrid_search module, where it is looked up
    embed_target = "backend.api.features.store.hybrid_search.embed_query"

    with patch(embed_target) as mock_embed:
        mock_embed.return_value = None  # Simulate embedding failure

        # The error text must stay generic (no implementation details leaked)
        with pytest.raises(
            ValueError, match="Search service temporarily unavailable"
        ):
            await hybrid_search(
                query="test",
                page=1,
                page_size=20,
            )
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_with_filters():
    """Test that featured, creator and category filters reach the SQL query."""
    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_query.return_value = []
        mock_embed.return_value = [0.1] * 1536

        results, total = await hybrid_search(
            query="test",
            featured=True,
            creators=["user1", "user2"],
            category="productivity",
            page=1,
            page_size=10,
        )

    call_args = mock_query.call_args
    sql_template = call_args[0][0]
    params = call_args[0][1:]  # Skip SQL template

    # Each filter must contribute its clause to the WHERE conditions
    assert "sa.featured = true" in sql_template
    assert "sa.creator_username = ANY(" in sql_template
    assert "= ANY(sa.categories)" in sql_template

    # Filter values are passed as bound parameters, not formatted into the SQL
    assert ["user1", "user2"] in params
    assert "productivity" in params

    # Should have at least query, query_lower, creators array, category
    assert len(params) >= 4

    assert results == []
    assert total == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_weights():
    """Custom weights should be forwarded to the query as bound parameters."""
    tuned = HybridSearchWeights(
        semantic=0.5,
        lexical=0.3,
        category=0.1,
        recency=0.1,
        popularity=0.0,
    )

    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_query.return_value = []
        mock_embed.return_value = [0.1] * 1536

        results, total = await hybrid_search(
            query="test",
            weights=tuned,
            page=1,
            page_size=20,
        )

        positional = mock_query.call_args[0]
        sql_template, params = positional[0], positional[1:]

        # Weights are parameterized rather than interpolated into the SQL
        assert "$" in sql_template

        # Each distinct custom weight value shows up among the parameters
        assert 0.5 in params  # semantic weight
        assert 0.3 in params  # lexical weight
        assert 0.1 in params  # category and recency weights
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_min_score_filtering():
    """A custom ``min_score`` should be applied through a bound SQL parameter."""
    # Minimal row: only the fields this test touches
    high_scoring_row = {
        "slug": "high-score/agent",
        "agent_name": "High Score Agent",
        "combined_score": 0.8,
        "total_count": 1,
    }

    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_query.return_value = [high_scoring_row]
        mock_embed.return_value = [0.1] * 1536

        results, total = await hybrid_search(
            query="test",
            min_score=0.5,  # High threshold
            page=1,
            page_size=20,
        )

        positional = mock_query.call_args[0]
        sql_template, params = positional[0], positional[1:]

        # The threshold comparison is present and parameterized
        assert "combined_score >=" in sql_template
        assert "$" in sql_template

        # The custom threshold travels as a parameter value
        assert 0.5 in params
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_pagination():
    """Page and page size should translate into the trailing LIMIT/OFFSET parameters."""
    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        mock_query.return_value = []
        mock_embed.return_value = [0.1] * 1536

        # Second page, ten results per page
        results, total = await hybrid_search(
            query="test",
            page=2,
            page_size=10,
        )

        # LIMIT and OFFSET are the final two positional arguments
        *_, limit, offset = mock_query.call_args[0]

        assert limit == 10  # page_size
        assert offset == 10  # (page - 1) * page_size = (2 - 1) * 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.integration
async def test_hybrid_search_error_handling():
    """Errors raised by the database helper should propagate out of hybrid_search."""
    with patch(
        "backend.api.features.store.hybrid_search.query_raw_with_schema"
    ) as mock_query, patch(
        "backend.api.features.store.hybrid_search.embed_query"
    ) as mock_embed:
        # Simulate database error
        mock_query.side_effect = Exception("Database connection error")
        mock_embed.return_value = [0.1] * 1536

        # The original exception surfaces with its message intact
        with pytest.raises(Exception, match="Database connection error"):
            await hybrid_search(
                query="test",
                page=1,
                page_size=20,
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
@@ -18,6 +18,7 @@ from backend.data.model import (
|
||||
SchemaField,
|
||||
)
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util.request import DEFAULT_USER_AGENT
|
||||
|
||||
|
||||
class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||
@@ -39,17 +40,27 @@ class GetWikipediaSummaryBlock(Block, GetRequest):
|
||||
output_schema=GetWikipediaSummaryBlock.Output,
|
||||
test_input={"topic": "Artificial Intelligence"},
|
||||
test_output=("summary", "summary content"),
|
||||
test_mock={"get_request": lambda url, json: {"extract": "summary content"}},
|
||||
test_mock={
|
||||
"get_request": lambda url, headers, json: {"extract": "summary content"}
|
||||
},
|
||||
)
|
||||
|
||||
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
|
||||
topic = input_data.topic
|
||||
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{topic}"
|
||||
# URL-encode the topic to handle spaces and special characters
|
||||
encoded_topic = quote(topic, safe="")
|
||||
url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{encoded_topic}"
|
||||
|
||||
# Set headers per Wikimedia robot policy (https://w.wiki/4wJS)
|
||||
# - User-Agent: Required, must identify the bot
|
||||
# - Accept-Encoding: gzip recommended to reduce bandwidth
|
||||
headers = {
|
||||
"User-Agent": DEFAULT_USER_AGENT,
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
}
|
||||
|
||||
# Note: User-Agent is now automatically set by the request library
|
||||
# to comply with Wikimedia's robot policy (https://w.wiki/4wJS)
|
||||
try:
|
||||
response = await self.get_request(url, json=True)
|
||||
response = await self.get_request(url, headers=headers, json=True)
|
||||
if "extract" not in response:
|
||||
raise ValueError(f"Unable to parse Wikipedia response: {response}")
|
||||
yield "summary", response["extract"]
|
||||
|
||||
@@ -391,8 +391,12 @@ class SmartDecisionMakerBlock(Block):
|
||||
"""
|
||||
block = sink_node.block
|
||||
|
||||
# Use custom name from node metadata if set, otherwise fall back to block.name
|
||||
custom_name = sink_node.metadata.get("customized_name")
|
||||
tool_name = custom_name if custom_name else block.name
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": SmartDecisionMakerBlock.cleanup(block.name),
|
||||
"name": SmartDecisionMakerBlock.cleanup(tool_name),
|
||||
"description": block.description,
|
||||
}
|
||||
sink_block_input_schema = block.input_schema
|
||||
@@ -489,14 +493,24 @@ class SmartDecisionMakerBlock(Block):
|
||||
f"Sink graph metadata not found: {graph_id} {graph_version}"
|
||||
)
|
||||
|
||||
# Use custom name from node metadata if set, otherwise fall back to graph name
|
||||
custom_name = sink_node.metadata.get("customized_name")
|
||||
tool_name = custom_name if custom_name else sink_graph_meta.name
|
||||
|
||||
tool_function: dict[str, Any] = {
|
||||
"name": SmartDecisionMakerBlock.cleanup(sink_graph_meta.name),
|
||||
"name": SmartDecisionMakerBlock.cleanup(tool_name),
|
||||
"description": sink_graph_meta.description,
|
||||
}
|
||||
|
||||
properties = {}
|
||||
field_mapping = {}
|
||||
|
||||
for link in links:
|
||||
field_name = link.sink_name
|
||||
|
||||
clean_field_name = SmartDecisionMakerBlock.cleanup(field_name)
|
||||
field_mapping[clean_field_name] = field_name
|
||||
|
||||
sink_block_input_schema = sink_node.input_default["input_schema"]
|
||||
sink_block_properties = sink_block_input_schema.get("properties", {}).get(
|
||||
link.sink_name, {}
|
||||
@@ -506,7 +520,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
if "description" in sink_block_properties
|
||||
else f"The {link.sink_name} of the tool"
|
||||
)
|
||||
properties[link.sink_name] = {
|
||||
properties[clean_field_name] = {
|
||||
"type": "string",
|
||||
"description": description,
|
||||
"default": json.dumps(sink_block_properties.get("default", None)),
|
||||
@@ -519,7 +533,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
# Store node info for later use in output processing
|
||||
tool_function["_field_mapping"] = field_mapping
|
||||
tool_function["_sink_node_id"] = sink_node.id
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
@@ -1147,8 +1161,9 @@ class SmartDecisionMakerBlock(Block):
|
||||
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
|
||||
arg_value = tool_args.get(clean_arg_name)
|
||||
|
||||
sanitized_arg_name = self.cleanup(original_field_name)
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_arg_name}"
|
||||
# Use original_field_name directly (not sanitized) to match link sink_name
|
||||
# The field_mapping already translates from LLM's cleaned names to original names
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{original_field_name}"
|
||||
|
||||
logger.debug(
|
||||
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
|
||||
|
||||
@@ -1057,3 +1057,153 @@ async def test_smart_decision_maker_traditional_mode_default():
|
||||
) # Should yield individual tool parameters
|
||||
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
|
||||
assert "conversations" in outputs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_smart_decision_maker_uses_customized_name_for_blocks():
    """A block node's ``customized_name`` metadata should become the cleaned tool name."""
    from unittest.mock import MagicMock

    from backend.blocks.basic import StoreValueBlock
    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
    from backend.data.graph import Link, Node

    # Sink node carrying a user-chosen display name
    sink_node = MagicMock(spec=Node)
    sink_node.configure_mock(
        id="test-node-id",
        block_id=StoreValueBlock().id,
        metadata={"customized_name": "My Custom Tool Name"},
        block=StoreValueBlock(),
    )

    # One link feeding the block's "input" field
    input_link = MagicMock(spec=Link)
    input_link.sink_name = "input"

    signature = await SmartDecisionMakerBlock._create_block_function_signature(
        sink_node, [input_link]
    )

    # The tool name is the cleaned-up customized name
    assert signature["type"] == "function"
    tool = signature["function"]
    assert tool["name"] == "my_custom_tool_name"  # Cleaned version
    assert tool["_sink_node_id"] == "test-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_smart_decision_maker_falls_back_to_block_name():
    """Without ``customized_name`` metadata the tool name should come from ``block.name``."""
    from unittest.mock import MagicMock

    from backend.blocks.basic import StoreValueBlock
    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
    from backend.data.graph import Link, Node

    # Sink node with empty metadata, i.e. no customized_name
    sink_node = MagicMock(spec=Node)
    sink_node.configure_mock(
        id="test-node-id",
        block_id=StoreValueBlock().id,
        metadata={},
        block=StoreValueBlock(),
    )

    # One link feeding the block's "input" field
    input_link = MagicMock(spec=Link)
    input_link.sink_name = "input"

    signature = await SmartDecisionMakerBlock._create_block_function_signature(
        sink_node, [input_link]
    )

    # The tool name is the block's default name, cleaned
    assert signature["type"] == "function"
    tool = signature["function"]
    assert tool["name"] == "storevalueblock"  # Default block name cleaned
    assert tool["_sink_node_id"] == "test-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_smart_decision_maker_uses_customized_name_for_agents():
    """An agent node's ``customized_name`` metadata should override the graph name."""
    from unittest.mock import AsyncMock, MagicMock, patch

    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
    from backend.data.graph import Link, Node

    # Agent node carrying a user-chosen display name
    agent_node = MagicMock(spec=Node)
    agent_node.configure_mock(
        id="test-agent-node-id",
        metadata={"customized_name": "My Custom Agent"},
        input_default={
            "graph_id": "test-graph-id",
            "graph_version": 1,
            "input_schema": {
                "properties": {"test_input": {"description": "Test input"}}
            },
        },
    )

    input_link = MagicMock(spec=Link)
    input_link.sink_name = "test_input"

    # Graph metadata returned by the mocked database client
    graph_meta = MagicMock()
    graph_meta.name = "Original Agent Name"
    graph_meta.description = "Agent description"

    db_client = AsyncMock()
    db_client.get_graph_metadata.return_value = graph_meta

    with patch(
        "backend.blocks.smart_decision_maker.get_database_manager_async_client",
        return_value=db_client,
    ):
        signature = await SmartDecisionMakerBlock._create_agent_function_signature(
            agent_node, [input_link]
        )

    # The tool name is the cleaned-up customized name, not the graph name
    assert signature["type"] == "function"
    tool = signature["function"]
    assert tool["name"] == "my_custom_agent"  # Cleaned version
    assert tool["_sink_node_id"] == "test-agent-node-id"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_smart_decision_maker_agent_falls_back_to_graph_name():
    """Without ``customized_name`` metadata an agent tool should use the graph's name."""
    from unittest.mock import AsyncMock, MagicMock, patch

    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
    from backend.data.graph import Link, Node

    # Agent node with empty metadata, i.e. no customized_name
    agent_node = MagicMock(spec=Node)
    agent_node.configure_mock(
        id="test-agent-node-id",
        metadata={},
        input_default={
            "graph_id": "test-graph-id",
            "graph_version": 1,
            "input_schema": {
                "properties": {"test_input": {"description": "Test input"}}
            },
        },
    )

    input_link = MagicMock(spec=Link)
    input_link.sink_name = "test_input"

    # Graph metadata returned by the mocked database client
    graph_meta = MagicMock()
    graph_meta.name = "Original Agent Name"
    graph_meta.description = "Agent description"

    db_client = AsyncMock()
    db_client.get_graph_metadata.return_value = graph_meta

    with patch(
        "backend.blocks.smart_decision_maker.get_database_manager_async_client",
        return_value=db_client,
    ):
        signature = await SmartDecisionMakerBlock._create_agent_function_signature(
            agent_node, [input_link]
        )

    # The tool name is the graph's own name, cleaned
    assert signature["type"] == "function"
    tool = signature["function"]
    assert tool["name"] == "original_agent_name"  # Graph name cleaned
    assert tool["_sink_node_id"] == "test-agent-node-id"
|
||||
|
||||
@@ -15,6 +15,7 @@ async def test_smart_decision_maker_handles_dynamic_dict_fields():
|
||||
mock_node.block = CreateDictionaryBlock()
|
||||
mock_node.block_id = CreateDictionaryBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic dictionary fields
|
||||
mock_links = [
|
||||
@@ -77,6 +78,7 @@ async def test_smart_decision_maker_handles_dynamic_list_fields():
|
||||
mock_node.block = AddToListBlock()
|
||||
mock_node.block_id = AddToListBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic list fields
|
||||
mock_links = [
|
||||
|
||||
@@ -44,6 +44,7 @@ async def test_create_block_function_signature_with_dict_fields():
|
||||
mock_node.block = CreateDictionaryBlock()
|
||||
mock_node.block_id = CreateDictionaryBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic dictionary fields (source sanitized, sink original)
|
||||
mock_links = [
|
||||
@@ -106,6 +107,7 @@ async def test_create_block_function_signature_with_list_fields():
|
||||
mock_node.block = AddToListBlock()
|
||||
mock_node.block_id = AddToListBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic list fields
|
||||
mock_links = [
|
||||
@@ -159,6 +161,7 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
mock_node.block = MatchTextPatternBlock()
|
||||
mock_node.block_id = MatchTextPatternBlock().id
|
||||
mock_node.input_default = {}
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Create mock links with dynamic object fields
|
||||
mock_links = [
|
||||
@@ -208,11 +211,13 @@ async def test_create_tool_node_signatures():
|
||||
mock_dict_node.block = CreateDictionaryBlock()
|
||||
mock_dict_node.block_id = CreateDictionaryBlock().id
|
||||
mock_dict_node.input_default = {}
|
||||
mock_dict_node.metadata = {}
|
||||
|
||||
mock_list_node = Mock()
|
||||
mock_list_node.block = AddToListBlock()
|
||||
mock_list_node.block_id = AddToListBlock().id
|
||||
mock_list_node.input_default = {}
|
||||
mock_list_node.metadata = {}
|
||||
|
||||
# Mock links with dynamic fields
|
||||
dict_link1 = Mock(
|
||||
@@ -423,6 +428,7 @@ async def test_mixed_regular_and_dynamic_fields():
|
||||
mock_node.block.name = "TestBlock"
|
||||
mock_node.block.description = "A test block"
|
||||
mock_node.block.input_schema = Mock()
|
||||
mock_node.metadata = {}
|
||||
|
||||
# Mock the get_field_schema to return a proper schema for regular fields
|
||||
def get_field_schema(field_name):
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .blog import WordPressCreatePostBlock
|
||||
from .blog import WordPressCreatePostBlock, WordPressGetAllPostsBlock
|
||||
|
||||
__all__ = ["WordPressCreatePostBlock"]
|
||||
__all__ = ["WordPressCreatePostBlock", "WordPressGetAllPostsBlock"]
|
||||
|
||||
@@ -161,7 +161,7 @@ async def oauth_exchange_code_for_tokens(
|
||||
grant_type="authorization_code",
|
||||
).model_dump(exclude_none=True)
|
||||
|
||||
response = await Requests().post(
|
||||
response = await Requests(raise_for_status=False).post(
|
||||
f"{WORDPRESS_BASE_URL}oauth2/token",
|
||||
headers=headers,
|
||||
data=data,
|
||||
@@ -205,7 +205,7 @@ async def oauth_refresh_tokens(
|
||||
grant_type="refresh_token",
|
||||
).model_dump(exclude_none=True)
|
||||
|
||||
response = await Requests().post(
|
||||
response = await Requests(raise_for_status=False).post(
|
||||
f"{WORDPRESS_BASE_URL}oauth2/token",
|
||||
headers=headers,
|
||||
data=data,
|
||||
@@ -252,7 +252,7 @@ async def validate_token(
|
||||
"token": token,
|
||||
}
|
||||
|
||||
response = await Requests().get(
|
||||
response = await Requests(raise_for_status=False).get(
|
||||
f"{WORDPRESS_BASE_URL}oauth2/token-info",
|
||||
params=params,
|
||||
)
|
||||
@@ -296,7 +296,7 @@ async def make_api_request(
|
||||
|
||||
url = f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}"
|
||||
|
||||
request_method = getattr(Requests(), method.lower())
|
||||
request_method = getattr(Requests(raise_for_status=False), method.lower())
|
||||
response = await request_method(
|
||||
url,
|
||||
headers=headers,
|
||||
@@ -476,6 +476,7 @@ async def create_post(
|
||||
data["tags"] = ",".join(str(t) for t in data["tags"])
|
||||
|
||||
# Make the API request
|
||||
site = normalize_site(site)
|
||||
endpoint = f"/rest/v1.1/sites/{site}/posts/new"
|
||||
|
||||
headers = {
|
||||
@@ -483,7 +484,7 @@ async def create_post(
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
}
|
||||
|
||||
response = await Requests().post(
|
||||
response = await Requests(raise_for_status=False).post(
|
||||
f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
|
||||
headers=headers,
|
||||
data=data,
|
||||
@@ -499,3 +500,132 @@ async def create_post(
|
||||
)
|
||||
error_message = error_data.get("message", response.text)
|
||||
raise ValueError(f"Failed to create post: {response.status} - {error_message}")
|
||||
|
||||
|
||||
class Post(BaseModel):
    """A single entry of a WordPress posts-list response.

    Deliberately slimmer than PostResponse: the list endpoint returns less
    detail than the create / get-single-post endpoints, so most fields below
    are optional.
    """

    # --- identity and authorship ---
    ID: int
    site_ID: int
    author: PostAuthor
    date: datetime
    modified: datetime

    # --- content and addressing ---
    title: str
    URL: str
    short_URL: str
    content: str | None = None
    excerpt: str | None = None
    slug: str
    guid: str

    # --- publication state ---
    status: str
    sticky: bool
    password: str | None = ""
    parent: Union[Dict[str, Any], bool, None] = None
    type: str

    # --- discussion and engagement ---
    discussion: Dict[str, Union[str, bool, int]] | None = None
    likes_enabled: bool | None = None
    sharing_enabled: bool | None = None
    like_count: int | None = None
    i_like: bool | None = None
    is_reblogged: bool | None = None
    is_following: bool | None = None
    global_ID: str | None = None

    # --- presentation ---
    featured_image: str | None = None
    post_thumbnail: Dict[str, Any] | None = None
    format: str | None = None
    geo: Union[Dict[str, Any], bool, None] = None
    menu_order: int | None = None
    page_template: str | None = None
    publicize_URLs: List[str] | None = None

    # --- taxonomy and attachments ---
    terms: Dict[str, Dict[str, Any]] | None = None
    tags: Dict[str, Dict[str, Any]] | None = None
    categories: Dict[str, Dict[str, Any]] | None = None
    attachments: Dict[str, Dict[str, Any]] | None = None
    attachment_count: int | None = None

    # --- miscellaneous metadata ---
    metadata: List[Dict[str, Any]] | None = None
    meta: Dict[str, Any] | None = None
    capabilities: Dict[str, bool] | None = None
    revisions: List[int] | None = None
    other_URLs: Dict[str, Any] | None = None
|
||||
|
||||
|
||||
class PostsResponse(BaseModel):
    """Envelope of a WordPress posts-list response."""

    # Post count reported by the API.
    found: int
    # The posts contained in this response.
    posts: List[Post]
    # Free-form metadata object accompanying the list.
    meta: Dict[str, Any]
|
||||
|
||||
|
||||
def normalize_site(site: str) -> str:
    """
    Normalize a site identifier by stripping protocol and trailing slashes.

    Surrounding whitespace is removed first. A leading ``http://`` or
    ``https://`` is then dropped; the scheme is matched case-insensitively
    because URL schemes are not case-sensitive. The case of the rest of the
    identifier is left untouched.

    Args:
        site: Site URL, domain, or ID (e.g., "https://myblog.wordpress.com/", "myblog.wordpress.com", "123456789")

    Returns:
        Normalized site identifier (domain or ID only)
    """
    site = site.strip()
    for scheme in ("https://", "http://"):
        # Compare only the prefix, lower-cased, so "HTTPS://" is recognised
        # without altering the case of the remaining identifier.
        if site[: len(scheme)].lower() == scheme:
            site = site[len(scheme) :]
            break
    return site.rstrip("/")
|
||||
|
||||
|
||||
async def get_posts(
    credentials: Credentials,
    site: str,
    status: PostStatus | None = None,
    number: int = 100,
    offset: int = 0,
) -> PostsResponse:
    """
    Fetch one page of posts from a WordPress site.

    Args:
        credentials: OAuth credentials used for the Authorization header
        site: Site ID or domain (e.g., "myblog.wordpress.com" or "123456789");
            passed through normalize_site() before use
        status: Post status to filter on, or None to omit the filter
        number: How many posts to request; clamped to the range 1-100
        offset: How many posts to skip (for pagination)

    Returns:
        PostsResponse validated from the JSON body of a successful response

    Raises:
        ValueError: If the response is not OK; the message carries the status
            and the API's "message" field when the body is JSON.
    """
    site = normalize_site(site)
    endpoint = f"/rest/v1.1/sites/{site}/posts"

    query: Dict[str, Any] = {
        "number": max(1, min(number, 100)),  # 1–100 posts per request
        "offset": offset,
    }
    if status:
        query["status"] = status.value

    # raise_for_status=False: non-OK responses are turned into ValueError below.
    response = await Requests(raise_for_status=False).get(
        f"{WORDPRESS_BASE_URL.rstrip('/')}{endpoint}",
        headers={"Authorization": credentials.auth_header()},
        params=query,
    )

    if not response.ok:
        content_type = response.headers.get("content-type", "")
        if content_type.startswith("application/json"):
            error_data = response.json()
        else:
            error_data = {}
        error_message = error_data.get("message", response.text)
        raise ValueError(f"Failed to get posts: {response.status} - {error_message}")

    return PostsResponse.model_validate(response.json())
|
||||
|
||||
@@ -9,7 +9,15 @@ from backend.sdk import (
|
||||
SchemaField,
|
||||
)
|
||||
|
||||
from ._api import CreatePostRequest, PostResponse, PostStatus, create_post
|
||||
from ._api import (
|
||||
CreatePostRequest,
|
||||
Post,
|
||||
PostResponse,
|
||||
PostsResponse,
|
||||
PostStatus,
|
||||
create_post,
|
||||
get_posts,
|
||||
)
|
||||
from ._config import wordpress
|
||||
|
||||
|
||||
@@ -49,8 +57,15 @@ class WordPressCreatePostBlock(Block):
|
||||
media_urls: list[str] = SchemaField(
|
||||
description="URLs of images to sideload and attach to the post", default=[]
|
||||
)
|
||||
publish_as_draft: bool = SchemaField(
|
||||
description="If True, publishes the post as a draft. If False, publishes it publicly.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
site: str = SchemaField(
|
||||
description="The site ID or domain (pass-through for chaining with other blocks)"
|
||||
)
|
||||
post_id: int = SchemaField(description="The ID of the created post")
|
||||
post_url: str = SchemaField(description="The full URL of the created post")
|
||||
short_url: str = SchemaField(description="The shortened wp.me URL")
|
||||
@@ -78,7 +93,9 @@ class WordPressCreatePostBlock(Block):
|
||||
tags=input_data.tags,
|
||||
featured_image=input_data.featured_image,
|
||||
media_urls=input_data.media_urls,
|
||||
status=PostStatus.PUBLISH,
|
||||
status=(
|
||||
PostStatus.DRAFT if input_data.publish_as_draft else PostStatus.PUBLISH
|
||||
),
|
||||
)
|
||||
|
||||
post_response: PostResponse = await create_post(
|
||||
@@ -87,7 +104,69 @@ class WordPressCreatePostBlock(Block):
|
||||
post_data=post_request,
|
||||
)
|
||||
|
||||
yield "site", input_data.site
|
||||
yield "post_id", post_response.ID
|
||||
yield "post_url", post_response.URL
|
||||
yield "short_url", post_response.short_URL
|
||||
yield "post_data", post_response.model_dump()
|
||||
|
||||
|
||||
class WordPressGetAllPostsBlock(Block):
    """
    Block that lists posts of a WordPress.com or Jetpack-enabled site.

    Each run performs a single get_posts() call; the status filter and the
    number/offset inputs select which page of posts is returned.
    """

    class Input(BlockSchemaInput):
        credentials: CredentialsMetaInput = wordpress.credentials_field()
        site: str = SchemaField(
            description="Site ID or domain (e.g., 'myblog.wordpress.com' or '123456789')"
        )
        status: PostStatus | None = SchemaField(
            description="Filter by post status, or None for all",
            default=None,
        )
        number: int = SchemaField(
            description="Number of posts to retrieve (max 100 per request)", default=20
        )
        offset: int = SchemaField(
            description="Number of posts to skip (for pagination)", default=0
        )

    class Output(BlockSchemaOutput):
        site: str = SchemaField(
            description="The site ID or domain (pass-through for chaining with other blocks)"
        )
        found: int = SchemaField(description="Total number of posts found")
        posts: list[Post] = SchemaField(
            description="List of post objects with their details"
        )
        post: Post = SchemaField(
            description="Individual post object (yielded for each post)"
        )

    def __init__(self):
        super().__init__(
            id="97728fa7-7f6f-4789-ba0c-f2c114119536",
            description="Fetch all posts from WordPress.com or Jetpack sites",
            categories={BlockCategory.SOCIAL},
            input_schema=self.Input,
            output_schema=self.Output,
        )

    async def run(
        self, input_data: Input, *, credentials: Credentials, **kwargs
    ) -> BlockOutput:
        """Fetch the requested page and emit it as a whole and per post."""
        page: PostsResponse = await get_posts(
            credentials=credentials,
            site=input_data.site,
            status=input_data.status,
            number=input_data.number,
            offset=input_data.offset,
        )
        fetched = page.posts

        # Aggregate outputs first, then one "post" output per item.
        yield "site", input_data.site
        yield "found", page.found
        yield "posts", fetched
        for item in fetched:
            yield "post", item
|
||||
|
||||
@@ -38,6 +38,20 @@ POOL_TIMEOUT = os.getenv("DB_POOL_TIMEOUT")
|
||||
if POOL_TIMEOUT:
|
||||
DATABASE_URL = add_param(DATABASE_URL, "pool_timeout", POOL_TIMEOUT)
|
||||
|
||||
# Put the public schema on the connection's search_path so the pgvector
# ``vector`` type (created in public) can be referenced as ::vector without
# schema qualification. The leading schema comes from the ``schema`` query
# parameter of DATABASE_URL, falling back to 'platform'.
# NOTE(review): get_database_schema() in this module falls back to 'public'
# rather than 'platform' - confirm the two defaults are meant to differ.
parsed_url = urlparse(DATABASE_URL)
url_params = dict(parse_qsl(parsed_url.query))
db_schema = url_params.get("schema", "platform")
if db_schema == "public":
    # Already public: list it once instead of "public,public".
    search_path_schemas = ["public"]
else:
    search_path_schemas = [db_schema, "public"]
search_path = ",".join(search_path_schemas)
DATABASE_URL = add_param(DATABASE_URL, "options", f"-c search_path={search_path}")
|
||||
|
||||
HTTP_TIMEOUT = int(POOL_TIMEOUT) if POOL_TIMEOUT else None
|
||||
|
||||
prisma = Prisma(
|
||||
@@ -108,21 +122,102 @@ def get_database_schema() -> str:
|
||||
return query_params.get("schema", "public")
|
||||
|
||||
|
||||
async def query_raw_with_schema(query_template: str, *args) -> list[dict]:
|
||||
"""Execute raw SQL query with proper schema handling."""
|
||||
async def _raw_with_schema(
    query_template: str,
    *args,
    execute: bool = False,
    client: Prisma | None = None,
    set_public_search_path: bool = False,
) -> list[dict] | int:
    """Internal: Execute raw SQL with proper schema handling.

    Use query_raw_with_schema() or execute_raw_with_schema() instead.

    Args:
        query_template: SQL query with {schema_prefix} placeholder
        *args: Query parameters
        execute: If False, runs the statement via query_raw (SELECT). If True,
            runs it via execute_raw (INSERT/UPDATE/DELETE).
        client: Optional Prisma client (e.g. a transaction). When given it is
            used for every statement issued here; otherwise the globally
            registered Prisma client is used.
        set_public_search_path: If True, issues a SET search_path statement
            naming the configured schema followed by public before the query.
            Needed for pgvector types and other public schema objects.

    Returns:
        - list[dict] if execute=False (query results)
        - int if execute=True (number of affected rows)
    """
    schema = get_database_schema()
    schema_prefix = f'"{schema}".' if schema != "public" else ""
    formatted_query = query_template.format(schema_prefix=schema_prefix)

    import prisma as prisma_module

    db_client = client if client else prisma_module.get_client()

    if set_public_search_path:
        # Quote the schema exactly like schema_prefix does, so that a
        # mixed-case or otherwise special schema name resolves to the same
        # schema in both places (unquoted identifiers are case-folded).
        quoted_schema = '"' + schema.replace('"', '""') + '"'
        if schema == "public":
            search_path = quoted_schema
        else:
            search_path = f'{quoted_schema}, "public"'
        # NOTE(review): SET applies to the session that runs it; with pooled
        # connections the following statement presumably needs to run on the
        # same connection (e.g. inside a transaction client) - confirm.
        await db_client.execute_raw(f"SET search_path = {search_path}")  # type: ignore

    if execute:
        result = await db_client.execute_raw(formatted_query, *args)  # type: ignore
    else:
        result = await db_client.query_raw(formatted_query, *args)  # type: ignore

    return result
|
||||
|
||||
|
||||
async def query_raw_with_schema(
    query_template: str, *args, set_public_search_path: bool = False
) -> list[dict]:
    """Run a raw SQL SELECT with the schema prefix filled in.

    Thin wrapper that delegates to _raw_with_schema() with execute=False.

    Args:
        query_template: SQL query with {schema_prefix} placeholder
        *args: Query parameters
        set_public_search_path: If True, sets search_path to include public schema.
            Needed for pgvector types and other public schema objects.

    Returns:
        List of result rows as dictionaries

    Example:
        results = await query_raw_with_schema(
            'SELECT * FROM {schema_prefix}"User" WHERE id = $1',
            user_id
        )
    """
    rows = await _raw_with_schema(
        query_template,
        *args,
        execute=False,
        set_public_search_path=set_public_search_path,
    )
    return rows  # type: ignore
|
||||
|
||||
|
||||
async def execute_raw_with_schema(
    query_template: str,
    *args,
    client: Prisma | None = None,
    set_public_search_path: bool = False,
) -> int:
    """Run a raw SQL command (INSERT/UPDATE/DELETE) with the schema prefix filled in.

    Thin wrapper that delegates to _raw_with_schema() with execute=True.

    Args:
        query_template: SQL query with {schema_prefix} placeholder
        *args: Query parameters
        client: Optional Prisma client for transactions
        set_public_search_path: If True, sets search_path to include public schema.
            Needed for pgvector types and other public schema objects.

    Returns:
        Number of affected rows

    Example:
        await execute_raw_with_schema(
            'INSERT INTO {schema_prefix}"User" (id, name) VALUES ($1, $2)',
            user_id, name,
            client=tx  # Optional transaction client
        )
    """
    affected = await _raw_with_schema(
        query_template,
        *args,
        execute=True,
        client=client,
        set_public_search_path=set_public_search_path,
    )
    return affected  # type: ignore
|
||||
|
||||
|
||||
class BaseDbModel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import fastapi.exceptions
|
||||
@@ -18,6 +19,17 @@ from backend.usecases.sample import create_test_user
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def mock_embedding_functions():
    """Replace ensure_embedding with an AsyncMock for the whole test session.

    Keeps the tests free of embedding-related database/API dependencies; the
    mock resolves to True.
    """
    embedding_patch = patch(
        "backend.api.features.store.db.ensure_embedding",
        new_callable=AsyncMock,
        return_value=True,
    )
    embedding_patch.start()
    try:
        yield
    finally:
        # Undo the patch once the session-scoped fixture is torn down.
        embedding_patch.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_graph_creation(server: SpinTestServer, snapshot: Snapshot):
|
||||
"""
|
||||
|
||||
@@ -7,6 +7,10 @@ from backend.api.features.library.db import (
|
||||
list_library_agents,
|
||||
)
|
||||
from backend.api.features.store.db import get_store_agent_details, get_store_agents
|
||||
from backend.api.features.store.embeddings import (
|
||||
backfill_missing_embeddings,
|
||||
get_embedding_stats,
|
||||
)
|
||||
from backend.data import db
|
||||
from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
@@ -208,6 +212,10 @@ class DatabaseManager(AppService):
|
||||
get_store_agents = _(get_store_agents)
|
||||
get_store_agent_details = _(get_store_agent_details)
|
||||
|
||||
# Store Embeddings
|
||||
get_embedding_stats = _(get_embedding_stats)
|
||||
backfill_missing_embeddings = _(backfill_missing_embeddings)
|
||||
|
||||
# Summary data - async
|
||||
get_user_execution_summary_data = _(get_user_execution_summary_data)
|
||||
|
||||
@@ -259,6 +267,10 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
get_store_agents = _(d.get_store_agents)
|
||||
get_store_agent_details = _(d.get_store_agent_details)
|
||||
|
||||
# Store Embeddings
|
||||
get_embedding_stats = _(d.get_embedding_stats)
|
||||
backfill_missing_embeddings = _(d.backfill_missing_embeddings)
|
||||
|
||||
|
||||
class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
d = DatabaseManager
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import fastapi.responses
|
||||
import pytest
|
||||
@@ -19,6 +20,17 @@ from backend.util.test import SpinTestServer, wait_execution
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def mock_embedding_functions():
    """Replace ensure_embedding with an AsyncMock for the whole test session.

    Keeps the tests free of embedding-related database/API dependencies; the
    mock resolves to True.
    """
    embedding_patch = patch(
        "backend.api.features.store.db.ensure_embedding",
        new_callable=AsyncMock,
        return_value=True,
    )
    embedding_patch.start()
    try:
        yield
    finally:
        # Undo the patch once the session-scoped fixture is torn down.
        embedding_patch.stop()
|
||||
|
||||
|
||||
async def create_graph(s: SpinTestServer, g: graph.Graph, u: User) -> graph.Graph:
    """Create graph ``g`` through the test agent server on behalf of user ``u``."""
    owner_id = u.id
    logger.info(f"Creating graph for user {owner_id}")
    request = CreateGraph(graph=g)
    return await s.agent_server.test_create_graph(request, owner_id)
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
@@ -37,7 +38,7 @@ from backend.monitoring import (
|
||||
report_execution_accuracy_alerts,
|
||||
report_late_executions,
|
||||
)
|
||||
from backend.util.clients import get_scheduler_client
|
||||
from backend.util.clients import get_database_manager_client, get_scheduler_client
|
||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||
from backend.util.exceptions import (
|
||||
GraphNotFoundError,
|
||||
@@ -254,6 +255,88 @@ def execution_accuracy_alerts():
|
||||
return report_execution_accuracy_alerts()
|
||||
|
||||
|
||||
def ensure_embeddings_coverage():
    """
    Ensure all content types (store agents, blocks, docs) have embeddings for search.

    Reads the embedding statistics from the database manager and, when any
    content lacks an embedding, calls backfill_missing_embeddings() in batches
    of 10 until a batch processes nothing or a whole batch fails.
    Missing embeddings = content invisible in hybrid search.

    Schedule: Runs every 6 hours (balanced between coverage and API costs).
    - Catches new content added between scheduled runs
    - Batch size 10 per content type: gradual processing to avoid rate limits
    - Manual trigger available via execute_ensure_embeddings_coverage endpoint

    Returns:
        dict with the accumulated "processed", "success" and "failed" counts;
        an additional "error" key is present when the stats lookup reported one.
    """
    db_client = get_database_manager_client()
    stats = db_client.get_embedding_stats()

    # The stats call reports failure in-band; bail out before touching anything.
    if "error" in stats:
        logger.error(
            f"Failed to get embedding stats: {stats['error']} - skipping backfill"
        )
        return {"processed": 0, "success": 0, "failed": 0, "error": stats["error"]}

    totals = stats.get("totals", {})
    without_embeddings = totals.get("without_embeddings", 0)
    coverage_percent = totals.get("coverage_percent", 0)

    if without_embeddings == 0:
        logger.info("All content has embeddings, skipping backfill")
        return {"processed": 0, "success": 0, "failed": 0}

    # Per-content-type visibility. Both keys are read defensively so that a
    # missing field in the stats payload cannot abort the job while logging.
    for content_type, type_stats in stats.get("by_type", {}).items():
        missing = type_stats.get("without_embeddings", 0)
        if missing > 0:
            logger.info(
                f"{content_type}: {missing} items without embeddings "
                f"({type_stats.get('coverage_percent', 0)}% coverage)"
            )

    logger.info(
        f"Total: {without_embeddings} items without embeddings "
        f"({coverage_percent}% coverage) - processing all"
    )

    total_processed = 0
    total_success = 0
    total_failed = 0

    # Process in batches until no more missing embeddings
    while True:
        result = db_client.backfill_missing_embeddings(batch_size=10)

        total_processed += result["processed"]
        total_success += result["success"]
        total_failed += result["failed"]

        if result["processed"] == 0:
            # No more missing embeddings
            break

        if result["success"] == 0:
            # All attempts in this batch failed - stop to avoid infinite loop
            logger.error(
                f"All {result['processed']} embedding attempts failed - stopping backfill"
            )
            break

        # Small delay between batches to avoid rate limits
        time.sleep(1)

    logger.info(
        f"Embedding backfill completed: {total_success}/{total_processed} succeeded, "
        f"{total_failed} failed"
    )
    return {
        "processed": total_processed,
        "success": total_success,
        "failed": total_failed,
    }
|
||||
|
||||
|
||||
# Monitoring functions are now imported from monitoring module
|
||||
|
||||
|
||||
@@ -475,6 +558,19 @@ class Scheduler(AppService):
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
# Embedding Coverage - Every 6 hours
|
||||
# Ensures all content types (store agents, blocks, docs) have embeddings for hybrid search
|
||||
# Critical: missing embeddings = content invisible in search
|
||||
self.scheduler.add_job(
|
||||
ensure_embeddings_coverage,
|
||||
id="ensure_embeddings_coverage",
|
||||
trigger="interval",
|
||||
hours=6,
|
||||
replace_existing=True,
|
||||
max_instances=1, # Prevent overlapping runs
|
||||
jobstore=Jobstores.EXECUTION.value,
|
||||
)
|
||||
|
||||
self.scheduler.add_listener(job_listener, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR)
|
||||
self.scheduler.add_listener(job_missed_listener, EVENT_JOB_MISSED)
|
||||
self.scheduler.add_listener(job_max_instances_listener, EVENT_JOB_MAX_INSTANCES)
|
||||
@@ -632,6 +728,11 @@ class Scheduler(AppService):
|
||||
"""Manually trigger execution accuracy alert checking."""
|
||||
return execution_accuracy_alerts()
|
||||
|
||||
@expose
def execute_ensure_embeddings_coverage(self):
    """Manually trigger the embedding backfill and return its summary dict."""
    summary = ensure_embeddings_coverage()
    return summary
|
||||
|
||||
|
||||
class SchedulerClient(AppServiceClient):
|
||||
@classmethod
|
||||
|
||||
@@ -10,6 +10,7 @@ from backend.util.settings import Settings
|
||||
settings = Settings()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
from supabase import AClient, Client
|
||||
|
||||
from backend.data.execution import (
|
||||
@@ -139,6 +140,24 @@ async def get_async_supabase() -> "AClient":
|
||||
)
|
||||
|
||||
|
||||
# ============ OpenAI Client ============ #
|
||||
|
||||
|
||||
@cached(ttl_seconds=3600)
def get_openai_client() -> "AsyncOpenAI | None":
    """
    Build the async OpenAI client used for embeddings.

    The result is memoised by the ``cached`` decorator (ttl_seconds=3600).

    Returns:
        An AsyncOpenAI instance, or None when the internal OpenAI API key
        setting is empty.
    """
    # Imported lazily; at module level the name is only available to type checkers.
    from openai import AsyncOpenAI

    if api_key := settings.secrets.openai_internal_api_key:
        return AsyncOpenAI(api_key=api_key)
    return None
|
||||
|
||||
|
||||
# ============ Notification Queue Helpers ============ #
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
-- CreateExtension
|
||||
-- Supabase: pgvector must be enabled via Dashboard → Database → Extensions first
|
||||
-- Create in public schema so vector type is available across all schemas
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "vector" WITH SCHEMA "public";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'vector extension not available or already exists, skipping';
|
||||
END $$;
|
||||
|
||||
-- CreateEnum
|
||||
CREATE TYPE "ContentType" AS ENUM ('STORE_AGENT', 'BLOCK', 'INTEGRATION', 'DOCUMENTATION', 'LIBRARY_AGENT');
|
||||
|
||||
-- CreateTable
|
||||
CREATE TABLE "UnifiedContentEmbedding" (
|
||||
"id" TEXT NOT NULL,
|
||||
"createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updatedAt" TIMESTAMP(3) NOT NULL,
|
||||
"contentType" "ContentType" NOT NULL,
|
||||
"contentId" TEXT NOT NULL,
|
||||
"userId" TEXT,
|
||||
"embedding" public.vector(1536) NOT NULL,
|
||||
"searchableText" TEXT NOT NULL,
|
||||
"metadata" JSONB NOT NULL DEFAULT '{}',
|
||||
|
||||
CONSTRAINT "UnifiedContentEmbedding_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UnifiedContentEmbedding_contentType_idx" ON "UnifiedContentEmbedding"("contentType");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UnifiedContentEmbedding_userId_idx" ON "UnifiedContentEmbedding"("userId");
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "UnifiedContentEmbedding_contentType_userId_idx" ON "UnifiedContentEmbedding"("contentType", "userId");
|
||||
|
||||
-- CreateIndex
|
||||
-- NULLS NOT DISTINCT ensures only one public (NULL userId) embedding per contentType+contentId
|
||||
-- Requires PostgreSQL 15+. Supabase uses PostgreSQL 15+.
|
||||
CREATE UNIQUE INDEX "UnifiedContentEmbedding_contentType_contentId_userId_key" ON "UnifiedContentEmbedding"("contentType", "contentId", "userId") NULLS NOT DISTINCT;
|
||||
|
||||
-- CreateIndex
|
||||
-- HNSW index for fast vector similarity search on embeddings
|
||||
-- Uses cosine distance operator (<=>), which matches the query in hybrid_search.py
|
||||
CREATE INDEX "UnifiedContentEmbedding_embedding_idx" ON "UnifiedContentEmbedding" USING hnsw ("embedding" public.vector_cosine_ops);
|
||||
@@ -0,0 +1,71 @@
|
||||
-- Acknowledge Supabase-managed extensions to prevent drift warnings
|
||||
-- These extensions are pre-installed by Supabase in specific schemas
|
||||
-- This migration ensures they exist where available (Supabase) or skips gracefully (CI)
|
||||
|
||||
-- Create schemas (safe in both CI and Supabase)
|
||||
CREATE SCHEMA IF NOT EXISTS "extensions";
|
||||
|
||||
-- Extensions that exist in both CI and Supabase
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgcrypto extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'uuid-ossp extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
-- Supabase-specific extensions (skip gracefully in CI)
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_stat_statements" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_stat_statements extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_net" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_net extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE EXTENSION IF NOT EXISTS "pgjwt" WITH SCHEMA "extensions";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgjwt extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "graphql";
|
||||
CREATE EXTENSION IF NOT EXISTS "pg_graphql" WITH SCHEMA "graphql";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pg_graphql extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "pgsodium";
|
||||
CREATE EXTENSION IF NOT EXISTS "pgsodium" WITH SCHEMA "pgsodium";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'pgsodium extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
CREATE SCHEMA IF NOT EXISTS "vault";
|
||||
CREATE EXTENSION IF NOT EXISTS "supabase_vault" WITH SCHEMA "vault";
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RAISE NOTICE 'supabase_vault extension not available, skipping';
|
||||
END $$;
|
||||
|
||||
|
||||
-- Ensure the platform schema exists
|
||||
CREATE SCHEMA IF NOT EXISTS "platform";
|
||||
@@ -1,14 +1,15 @@
|
||||
datasource db {
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
directUrl = env("DIRECT_URL")
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
directUrl = env("DIRECT_URL")
|
||||
extensions = [pgvector(map: "vector")]
|
||||
}
|
||||
|
||||
generator client {
|
||||
provider = "prisma-client-py"
|
||||
recursive_type_depth = -1
|
||||
interface = "asyncio"
|
||||
previewFeatures = ["views", "fullTextSearch"]
|
||||
previewFeatures = ["views", "fullTextSearch", "postgresqlExtensions"]
|
||||
partial_type_generator = "backend/data/partial_types.py"
|
||||
}
|
||||
|
||||
@@ -127,8 +128,8 @@ model BuilderSearchHistory {
|
||||
updatedAt DateTime @default(now()) @updatedAt
|
||||
|
||||
searchQuery String
|
||||
filter String[] @default([])
|
||||
byCreator String[] @default([])
|
||||
filter String[] @default([])
|
||||
byCreator String[] @default([])
|
||||
|
||||
userId String
|
||||
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
|
||||
@@ -721,26 +722,25 @@ view StoreAgent {
|
||||
storeListingVersionId String
|
||||
updated_at DateTime
|
||||
|
||||
slug String
|
||||
agent_name String
|
||||
agent_video String?
|
||||
agent_output_demo String?
|
||||
agent_image String[]
|
||||
slug String
|
||||
agent_name String
|
||||
agent_video String?
|
||||
agent_output_demo String?
|
||||
agent_image String[]
|
||||
|
||||
featured Boolean @default(false)
|
||||
creator_username String?
|
||||
creator_avatar String?
|
||||
sub_heading String
|
||||
description String
|
||||
categories String[]
|
||||
search Unsupported("tsvector")? @default(dbgenerated("''::tsvector"))
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
agentGraphVersions String[]
|
||||
agentGraphId String
|
||||
is_available Boolean @default(true)
|
||||
useForOnboarding Boolean @default(false)
|
||||
featured Boolean @default(false)
|
||||
creator_username String?
|
||||
creator_avatar String?
|
||||
sub_heading String
|
||||
description String
|
||||
categories String[]
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
agentGraphVersions String[]
|
||||
agentGraphId String
|
||||
is_available Boolean @default(true)
|
||||
useForOnboarding Boolean @default(false)
|
||||
|
||||
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
||||
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
||||
@@ -856,14 +856,14 @@ model StoreListingVersion {
|
||||
AgentGraph AgentGraph @relation(fields: [agentGraphId, agentGraphVersion], references: [id, version])
|
||||
|
||||
// Content fields
|
||||
name String
|
||||
subHeading String
|
||||
videoUrl String?
|
||||
agentOutputDemoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
name String
|
||||
subHeading String
|
||||
videoUrl String?
|
||||
agentOutputDemoUrl String?
|
||||
imageUrls String[]
|
||||
description String
|
||||
instructions String?
|
||||
categories String[]
|
||||
|
||||
isFeatured Boolean @default(false)
|
||||
|
||||
@@ -899,6 +899,9 @@ model StoreListingVersion {
|
||||
// Reviews for this specific version
|
||||
Reviews StoreListingReview[]
|
||||
|
||||
// Note: Embeddings now stored in UnifiedContentEmbedding table
|
||||
// Use contentType=STORE_AGENT and contentId=storeListingVersionId
|
||||
|
||||
@@unique([storeListingId, version])
|
||||
@@index([storeListingId, submissionStatus, isAvailable])
|
||||
@@index([submissionStatus])
|
||||
@@ -906,6 +909,42 @@ model StoreListingVersion {
|
||||
@@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
|
||||
}
|
||||
|
||||
// Content type enum for unified search across store agents, blocks, docs
|
||||
// Note: BLOCK/INTEGRATION are file-based (Python classes), not DB records
|
||||
// DOCUMENTATION are file-based (.md files), not DB records
|
||||
// Only STORE_AGENT and LIBRARY_AGENT are stored in database
|
||||
enum ContentType {
  // Database: StoreListingVersion
  STORE_AGENT

  // File-based: Python classes in /backend/blocks/
  BLOCK

  // File-based: Python classes (blocks with credentials)
  INTEGRATION

  // File-based: .md/.mdx files
  DOCUMENTATION

  // Database: User's personal agents
  LIBRARY_AGENT
}
|
||||
|
||||
// Unified embeddings table for all searchable content types
|
||||
// Supports both public content (userId=null) and user-specific content (userId=userID)
|
||||
model UnifiedContentEmbedding {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  // ---- Content identification ----
  contentType ContentType
  // DB ID (storeListingVersionId) or file identifier (block.id, file_path)
  contentId   String
  // NULL for public content (store, blocks, docs), userId for private content (library agents)
  userId      String?

  // ---- Search data ----
  // pgvector embedding (extension in platform schema)
  embedding      Unsupported("vector(1536)")
  // Combined text for search and fallback
  searchableText String
  // Content-specific metadata
  metadata       Json                        @default("{}")

  // NOTE(review): userId is nullable and PostgreSQL treats NULLs as distinct in
  // a plain unique constraint, so this may not deduplicate public rows
  // (userId = NULL) on its own - confirm how the migration enforces that.
  @@unique([contentType, contentId, userId], map: "UnifiedContentEmbedding_contentType_contentId_userId_key")
  @@index([contentType])
  @@index([userId])
  @@index([contentType, userId])
  // NOTE(review): presumably the actual vector index is created by hand in the
  // migration under this name - verify it matches.
  @@index([embedding], map: "UnifiedContentEmbedding_embedding_idx")
}
|
||||
|
||||
model StoreListingReview {
|
||||
id String @id @default(uuid())
|
||||
createdAt DateTime @default(now())
|
||||
@@ -998,16 +1037,16 @@ model OAuthApplication {
|
||||
updatedAt DateTime @updatedAt
|
||||
|
||||
// Application metadata
|
||||
name String
|
||||
description String?
|
||||
logoUrl String? // URL to app logo stored in GCS
|
||||
clientId String @unique
|
||||
clientSecret String // Hashed with Scrypt (same as API keys)
|
||||
clientSecretSalt String // Salt for Scrypt hashing
|
||||
name String
|
||||
description String?
|
||||
logoUrl String? // URL to app logo stored in GCS
|
||||
clientId String @unique
|
||||
clientSecret String // Hashed with Scrypt (same as API keys)
|
||||
clientSecretSalt String // Salt for Scrypt hashing
|
||||
|
||||
// OAuth configuration
|
||||
redirectUris String[] // Allowed callback URLs
|
||||
grantTypes String[] @default(["authorization_code", "refresh_token"])
|
||||
grantTypes String[] @default(["authorization_code", "refresh_token"])
|
||||
scopes APIKeyPermission[] // Which permissions the app can request
|
||||
|
||||
// Application management
|
||||
|
||||
@@ -68,7 +68,10 @@ export const NodeHeader = ({ data, nodeId }: Props) => {
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div>
|
||||
<Text variant="large-semibold" className="line-clamp-1">
|
||||
<Text
|
||||
variant="large-semibold"
|
||||
className="line-clamp-1 hover:cursor-text"
|
||||
>
|
||||
{beautifyString(title).replace("Block", "").trim()}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
@@ -89,6 +89,18 @@ export function extractOptions(
|
||||
|
||||
// get display type and color for schema types [need for type display next to field name]
|
||||
export const getTypeDisplayInfo = (schema: any) => {
|
||||
if (
|
||||
schema?.type === "array" &&
|
||||
"format" in schema &&
|
||||
schema.format === "table"
|
||||
) {
|
||||
return {
|
||||
displayType: "table",
|
||||
colorClass: "!text-indigo-500",
|
||||
hexColor: "#6366f1",
|
||||
};
|
||||
}
|
||||
|
||||
if (schema?.type === "string" && schema?.format) {
|
||||
const formatMap: Record<
|
||||
string,
|
||||
|
||||
@@ -36,6 +36,7 @@ type Props = {
|
||||
readOnly?: boolean;
|
||||
isOptional?: boolean;
|
||||
showTitle?: boolean;
|
||||
variant?: "default" | "node";
|
||||
};
|
||||
|
||||
export function CredentialsInput({
|
||||
@@ -48,6 +49,7 @@ export function CredentialsInput({
|
||||
readOnly = false,
|
||||
isOptional = false,
|
||||
showTitle = true,
|
||||
variant = "default",
|
||||
}: Props) {
|
||||
const hookData = useCredentialsInput({
|
||||
schema,
|
||||
@@ -123,6 +125,7 @@ export function CredentialsInput({
|
||||
onClearCredential={() => onSelectCredential(undefined)}
|
||||
readOnly={readOnly}
|
||||
allowNone={isOptional}
|
||||
variant={variant}
|
||||
/>
|
||||
) : (
|
||||
<div className="mb-4 space-y-2">
|
||||
|
||||
@@ -30,6 +30,8 @@ type CredentialRowProps = {
|
||||
readOnly?: boolean;
|
||||
showCaret?: boolean;
|
||||
asSelectTrigger?: boolean;
|
||||
/** When "node", applies compact styling for node context */
|
||||
variant?: "default" | "node";
|
||||
};
|
||||
|
||||
export function CredentialRow({
|
||||
@@ -41,14 +43,22 @@ export function CredentialRow({
|
||||
readOnly = false,
|
||||
showCaret = false,
|
||||
asSelectTrigger = false,
|
||||
variant = "default",
|
||||
}: CredentialRowProps) {
|
||||
const ProviderIcon = providerIcons[provider] || fallbackIcon;
|
||||
const isNodeVariant = variant === "node";
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center gap-3 rounded-medium border border-zinc-200 bg-white p-3 transition-colors",
|
||||
asSelectTrigger ? "border-0 bg-transparent" : readOnly ? "w-fit" : "",
|
||||
asSelectTrigger && isNodeVariant
|
||||
? "min-w-0 flex-1 overflow-hidden border-0 bg-transparent"
|
||||
: asSelectTrigger
|
||||
? "border-0 bg-transparent"
|
||||
: readOnly
|
||||
? "w-fit"
|
||||
: "",
|
||||
)}
|
||||
onClick={readOnly || showCaret || asSelectTrigger ? undefined : onSelect}
|
||||
style={
|
||||
@@ -61,19 +71,31 @@ export function CredentialRow({
|
||||
<ProviderIcon className="h-3 w-3 text-white" />
|
||||
</div>
|
||||
<IconKey className="h-5 w-5 shrink-0 text-zinc-800" />
|
||||
<div className="flex min-w-0 flex-1 flex-nowrap items-center gap-4">
|
||||
<div
|
||||
className={cn(
|
||||
"flex min-w-0 flex-1 flex-nowrap items-center gap-4",
|
||||
isNodeVariant && "overflow-hidden",
|
||||
)}
|
||||
>
|
||||
<Text
|
||||
variant="body"
|
||||
className="line-clamp-1 flex-[0_0_50%] text-ellipsis tracking-tight"
|
||||
className={cn(
|
||||
"tracking-tight",
|
||||
isNodeVariant
|
||||
? "truncate"
|
||||
: "line-clamp-1 flex-[0_0_50%] text-ellipsis",
|
||||
)}
|
||||
>
|
||||
{getCredentialDisplayName(credential, displayName)}
|
||||
</Text>
|
||||
<Text
|
||||
variant="large"
|
||||
className="lex-[0_0_40%] relative top-1 hidden overflow-hidden whitespace-nowrap font-mono tracking-tight md:block"
|
||||
>
|
||||
{"*".repeat(MASKED_KEY_LENGTH)}
|
||||
</Text>
|
||||
{!(asSelectTrigger && isNodeVariant) && (
|
||||
<Text
|
||||
variant="large"
|
||||
className="relative top-1 hidden overflow-hidden whitespace-nowrap font-mono tracking-tight md:block"
|
||||
>
|
||||
{"*".repeat(MASKED_KEY_LENGTH)}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
{showCaret && !asSelectTrigger && (
|
||||
<CaretDown className="h-4 w-4 shrink-0 text-gray-400" />
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
} from "@/components/__legacy__/ui/select";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { CredentialsMetaInput } from "@/lib/autogpt-server-api/types";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useEffect } from "react";
|
||||
import { getCredentialDisplayName } from "../../helpers";
|
||||
import { CredentialRow } from "../CredentialRow/CredentialRow";
|
||||
@@ -26,6 +27,8 @@ interface Props {
|
||||
onClearCredential?: () => void;
|
||||
readOnly?: boolean;
|
||||
allowNone?: boolean;
|
||||
/** When "node", applies compact styling for node context */
|
||||
variant?: "default" | "node";
|
||||
}
|
||||
|
||||
export function CredentialsSelect({
|
||||
@@ -37,6 +40,7 @@ export function CredentialsSelect({
|
||||
onClearCredential,
|
||||
readOnly = false,
|
||||
allowNone = true,
|
||||
variant = "default",
|
||||
}: Props) {
|
||||
// Auto-select first credential if none is selected (only if allowNone is false)
|
||||
useEffect(() => {
|
||||
@@ -59,7 +63,12 @@ export function CredentialsSelect({
|
||||
value={selectedCredentials?.id || (allowNone ? "__none__" : "")}
|
||||
onValueChange={handleValueChange}
|
||||
>
|
||||
<SelectTrigger className="h-auto min-h-12 w-full rounded-medium border-zinc-200 p-0 pr-4 shadow-none">
|
||||
<SelectTrigger
|
||||
className={cn(
|
||||
"h-auto min-h-12 w-full rounded-medium border-zinc-200 p-0 pr-4 shadow-none",
|
||||
variant === "node" && "overflow-hidden",
|
||||
)}
|
||||
>
|
||||
{selectedCredentials ? (
|
||||
<SelectValue key={selectedCredentials.id} asChild>
|
||||
<CredentialRow
|
||||
@@ -75,6 +84,7 @@ export function CredentialsSelect({
|
||||
onDelete={() => {}}
|
||||
readOnly={readOnly}
|
||||
asSelectTrigger={true}
|
||||
variant={variant}
|
||||
/>
|
||||
</SelectValue>
|
||||
) : (
|
||||
|
||||
@@ -29,7 +29,7 @@ export default function Layout({ children }: { children: React.ReactNode }) {
|
||||
href: "/profile/dashboard",
|
||||
icon: <StorefrontIcon className="size-5" />,
|
||||
},
|
||||
...(isPaymentEnabled || true
|
||||
...(isPaymentEnabled
|
||||
? [
|
||||
{
|
||||
text: "Billing",
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
import type { Meta, StoryObj } from "@storybook/nextjs";
|
||||
import { TooltipProvider } from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { Table } from "./Table";
|
||||
|
||||
// Storybook metadata for the Table molecule.
const meta = {
  title: "Molecules/Table",
  component: Table,
  tags: ["autodocs"],
  parameters: { layout: "centered" },
  // Every story is rendered inside a TooltipProvider.
  decorators: [
    (Story) => (
      <TooltipProvider>
        <Story />
      </TooltipProvider>
    ),
  ],
  // Controls exposed in the Storybook panel, with their help text.
  argTypes: {
    allowAddRow: {
      description: "Whether to show the Add row button",
      control: "boolean",
    },
    allowDeleteRow: {
      description: "Whether to show delete buttons for each row",
      control: "boolean",
    },
    readOnly: {
      description:
        "Whether the table is read-only (renders text instead of inputs)",
      control: "boolean",
    },
    addRowLabel: {
      description: "Label for the Add row button",
      control: "text",
    },
  },
} satisfies Meta<typeof Table>;
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof meta>;
|
||||
|
||||
// Fixtures shared by the stories below (module-private, so not stories).
const contactColumns = ["name", "email"];
const userColumns = [...contactColumns, "role"];
const keyValueColumns = ["key", "value"];
const contactRows = [
  { name: "John Doe", email: "john@example.com" },
  { name: "Jane Smith", email: "jane@example.com" },
];
// Both editing controls switched on.
const editable = { allowAddRow: true, allowDeleteRow: true };

/** Three columns, no preset rows, add and delete enabled. */
export const Default: Story = {
  args: { columns: userColumns, ...editable },
};

/** Three columns pre-populated with three rows. */
export const WithDefaultValues: Story = {
  args: {
    columns: userColumns,
    defaultValues: [
      { name: "John Doe", email: "john@example.com", role: "Admin" },
      { name: "Jane Smith", email: "jane@example.com", role: "User" },
      { name: "Bob Wilson", email: "bob@example.com", role: "Editor" },
    ],
    ...editable,
  },
};

/** Two preset rows with `readOnly` set. */
export const ReadOnly: Story = {
  args: {
    columns: contactColumns,
    defaultValues: contactRows,
    readOnly: true,
  },
};

/** Two preset rows with both the add and delete controls disabled. */
export const NoAddOrDelete: Story = {
  args: {
    columns: contactColumns,
    defaultValues: contactRows,
    allowAddRow: false,
    allowDeleteRow: false,
  },
};

/** A single "item" column with a custom add-button label. */
export const SingleColumn: Story = {
  args: { columns: ["item"], ...editable, addRowLabel: "Add item" },
};

/** Key/value columns with a custom add-button label and no preset rows. */
export const CustomAddLabel: Story = {
  args: { columns: keyValueColumns, ...editable, addRowLabel: "Add new entry" },
};

/** Key/value columns pre-populated with two rows. */
export const KeyValuePairs: Story = {
  args: {
    columns: keyValueColumns,
    defaultValues: [
      { key: "API_KEY", value: "sk-..." },
      { key: "DATABASE_URL", value: "postgres://..." },
    ],
    ...editable,
    addRowLabel: "Add variable",
  },
};
|
||||
@@ -0,0 +1,133 @@
|
||||
import * as React from "react";
|
||||
import {
|
||||
Table as BaseTable,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/__legacy__/ui/table";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Plus, Trash2 } from "lucide-react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useTable, RowData } from "./useTable";
|
||||
import { formatColumnTitle, formatPlaceholder } from "./helpers";
|
||||
|
||||
/** Props accepted by the `Table` component. */
export interface TableProps {
  /** Column keys, rendered left to right; each row is read by these keys. */
  columns: Array<string>;
  /** Rows forwarded to `useTable` as its `defaultValues`. */
  defaultValues?: Array<RowData>;
  /** Forwarded to `useTable`. */
  onChange?: (rows: Array<RowData>) => void;
  /** Whether the add-row button may be shown. */
  allowAddRow?: boolean;
  /** Whether a delete button may be shown on each row. */
  allowDeleteRow?: boolean;
  /** Text shown on the add-row button. */
  addRowLabel?: string;
  /** Extra class names merged onto the outer wrapper. */
  className?: string;
  /** Render cell text instead of inputs and hide the add/delete controls. */
  readOnly?: boolean;
}
|
||||
|
||||
/**
 * Renders the rows managed by `useTable` as a grid with one column per entry
 * in `columns`: text inputs when editable, plain text when `readOnly`.
 *
 * @param columns - Column keys; the header shows `formatColumnTitle(key)`.
 * @param defaultValues - Rows forwarded to `useTable`.
 * @param onChange - Forwarded to `useTable`.
 * @param allowAddRow - Show the add-row button (ignored when `readOnly`).
 * @param allowDeleteRow - Show a delete button per row (ignored when `readOnly`).
 * @param addRowLabel - Text of the add-row button.
 * @param className - Extra class names merged onto the outer wrapper.
 * @param readOnly - Render text instead of inputs and hide all row controls.
 */
export function Table({
  columns,
  defaultValues,
  onChange,
  allowAddRow = true,
  allowDeleteRow = true,
  addRowLabel = "Add row",
  className,
  readOnly = false,
}: TableProps) {
  const { rows, handleAddRow, handleDeleteRow, handleCellChange } = useTable({
    columns,
    defaultValues,
    onChange,
  });

  // Read-only mode hides every mutating control regardless of the allow* flags.
  const showDeleteColumn = allowDeleteRow && !readOnly;
  const showAddButton = allowAddRow && !readOnly;

  return (
    <div className={cn("flex flex-col gap-3", className)}>
      <div className="overflow-hidden rounded-xl border border-zinc-200 bg-white">
        <BaseTable>
          <TableHeader>
            <TableRow className="border-b border-zinc-100 bg-zinc-50/50">
              {columns.map((column) => (
                <TableHead
                  key={column}
                  className="h-10 px-3 text-sm font-medium text-zinc-600"
                >
                  {formatColumnTitle(column)}
                </TableHead>
              ))}
              {/* Empty header cell above the per-row delete buttons. */}
              {showDeleteColumn && <TableHead className="w-[50px]" />}
            </TableRow>
          </TableHeader>
          <TableBody>
            {rows.map((row, rowIndex) => (
              <TableRow key={rowIndex} className="border-none">
                {columns.map((column) => (
                  <TableCell key={`${rowIndex}-${column}`} className="p-2">
                    {readOnly ? (
                      <Text
                        variant="body"
                        className="px-3 py-2 text-sm text-zinc-800"
                      >
                        {row[column] || "-"}
                      </Text>
                    ) : (
                      <Input
                        id={`table-${rowIndex}-${column}`}
                        label={formatColumnTitle(column)}
                        hideLabel
                        value={row[column] ?? ""}
                        onChange={(e) =>
                          handleCellChange(rowIndex, column, e.target.value)
                        }
                        placeholder={formatPlaceholder(column)}
                        size="small"
                        wrapperClassName="mb-0"
                      />
                    )}
                  </TableCell>
                ))}
                {showDeleteColumn && (
                  <TableCell className="p-2">
                    {/* type="button": a <button> without a type submits its enclosing form. */}
                    <Button
                      variant="icon"
                      size="icon"
                      type="button"
                      onClick={() => handleDeleteRow(rowIndex)}
                      aria-label="Delete row"
                      className="text-zinc-400 transition-colors hover:text-red-500"
                    >
                      <Trash2 className="h-4 w-4" />
                    </Button>
                  </TableCell>
                )}
              </TableRow>
            ))}
            {showAddButton && (
              <TableRow>
                <TableCell
                  colSpan={columns.length + (showDeleteColumn ? 1 : 0)}
                  className="p-2"
                >
                  {/* type="button" for the same reason as the delete button. */}
                  <Button
                    variant="outline"
                    size="small"
                    type="button"
                    onClick={handleAddRow}
                    leftIcon={<Plus className="h-4 w-4" />}
                    className="w-fit"
                  >
                    {addRowLabel}
                  </Button>
                </TableCell>
              </TableRow>
            )}
          </TableBody>
        </BaseTable>
      </div>
    </div>
  );
}
|
||||
|
||||
export { type RowData } from "./useTable";
|
||||
@@ -0,0 +1,7 @@
|
||||
/** Upper-cases the first character of a column key; the rest is left as is. */
export const formatColumnTitle = (key: string): string => {
  const initial = key.slice(0, 1);
  const remainder = key.slice(1);
  return `${initial.toUpperCase()}${remainder}`;
};
|
||||
|
||||
/** Builds input placeholder text: "Enter " followed by the lower-cased key. */
export const formatPlaceholder = (key: string): string => {
  const lowered = key.toLowerCase();
  return "Enter " + lowered;
};
|
||||
@@ -0,0 +1,81 @@
|
||||
import { useState, useEffect } from "react";
|
||||
|
||||
export type RowData = Record<string, string>;
|
||||
|
||||
/** Options accepted by `useTable`. */
interface UseTableOptions {
  /** Column keys; a newly created empty row gets one "" entry per key. */
  columns: Array<string>;
  /** Initial rows; whenever this is defined and changes, the rows are reset to it. */
  defaultValues?: Array<RowData>;
  /** Called with the full row list after every add, delete, edit or clear. */
  onChange?: (rows: Array<RowData>) => void;
}
|
||||
|
||||
/**
 * Holds the row state for a string-cell table and exposes the handlers that
 * change it.
 *
 * Rows start as `defaultValues` when that is a non-empty array, otherwise as
 * an empty list, and are reset to `defaultValues` whenever that prop is
 * defined and changes. Every handler stores the new rows and then reports
 * them through `onChange`.
 *
 * @returns The current rows, the add/delete/cell-change/clear handlers, and
 *   the `createEmptyRow` factory.
 */
export function useTable({
  columns,
  defaultValues,
  onChange,
}: UseTableOptions) {
  const [rows, setRows] = useState<RowData[]>(() =>
    defaultValues && defaultValues.length > 0 ? defaultValues : [],
  );

  // Follow the prop when the parent supplies new default rows.
  useEffect(() => {
    if (defaultValues === undefined) {
      return;
    }
    setRows(defaultValues);
  }, [defaultValues]);

  // One "" cell per configured column.
  function createEmptyRow(): RowData {
    return columns.reduce<RowData>((blank, column) => {
      blank[column] = "";
      return blank;
    }, {});
  }

  // Single write path: store the rows locally, then notify the listener.
  function commit(nextRows: RowData[]) {
    setRows(nextRows);
    onChange?.(nextRows);
  }

  function handleAddRow() {
    commit([...rows, createEmptyRow()]);
  }

  function handleDeleteRow(rowIndex: number) {
    commit(rows.filter((_, position) => position !== rowIndex));
  }

  function handleCellChange(
    rowIndex: number,
    columnKey: string,
    value: string,
  ) {
    commit(
      rows.map((row, position) =>
        position === rowIndex ? { ...row, [columnKey]: value } : row,
      ),
    );
  }

  function clearAll() {
    commit([]);
  }

  return {
    rows,
    handleAddRow,
    handleDeleteRow,
    handleCellChange,
    clearAll,
    createEmptyRow,
  };
}
|
||||
@@ -30,6 +30,8 @@ export const FormRenderer = ({
|
||||
return generateUiSchemaForCustomFields(preprocessedSchema, uiSchema);
|
||||
}, [preprocessedSchema, uiSchema]);
|
||||
|
||||
console.log("preprocessedSchema", preprocessedSchema);
|
||||
|
||||
return (
|
||||
<div className={"mb-6 mt-4"}>
|
||||
<Form
|
||||
|
||||
@@ -5,19 +5,14 @@ import { useAnyOfField } from "./useAnyOfField";
|
||||
import { getHandleId, updateUiOption } from "../../helpers";
|
||||
import { useEdgeStore } from "@/app/(platform)/build/stores/edgeStore";
|
||||
import { ANY_OF_FLAG } from "../../constants";
|
||||
import { findCustomFieldId } from "../../registry";
|
||||
|
||||
export const AnyOfField = (props: FieldProps) => {
|
||||
const { registry, schema } = props;
|
||||
const { fields } = registry;
|
||||
const { SchemaField: _SchemaField } = fields;
|
||||
const { nodeId } = registry.formContext;
|
||||
|
||||
const { isInputConnected } = useEdgeStore();
|
||||
|
||||
const uiOptions = getUiOptions(props.uiSchema, props.globalUiOptions);
|
||||
|
||||
const Widget = getWidget({ type: "string" }, "select", registry.widgets);
|
||||
|
||||
const {
|
||||
handleOptionChange,
|
||||
enumOptions,
|
||||
@@ -26,6 +21,15 @@ export const AnyOfField = (props: FieldProps) => {
|
||||
field_id,
|
||||
} = useAnyOfField(props);
|
||||
|
||||
const parentCustomFieldId = findCustomFieldId(schema);
|
||||
if (parentCustomFieldId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const uiOptions = getUiOptions(props.uiSchema, props.globalUiOptions);
|
||||
|
||||
const Widget = getWidget({ type: "string" }, "select", registry.widgets);
|
||||
|
||||
const handleId = getHandleId({
|
||||
uiOptions,
|
||||
id: field_id + ANY_OF_FLAG,
|
||||
@@ -40,12 +44,21 @@ export const AnyOfField = (props: FieldProps) => {
|
||||
|
||||
const isHandleConnected = isInputConnected(nodeId, handleId);
|
||||
|
||||
// Now anyOf can render - custom fields if the option schema matches a custom field
|
||||
const optionCustomFieldId = optionSchema
|
||||
? findCustomFieldId(optionSchema)
|
||||
: null;
|
||||
|
||||
const optionUiSchema = optionCustomFieldId
|
||||
? { ...updatedUiSchema, "ui:field": optionCustomFieldId }
|
||||
: updatedUiSchema;
|
||||
|
||||
const optionsSchemaField =
|
||||
(optionSchema && optionSchema.type !== "null" && (
|
||||
<_SchemaField
|
||||
{...props}
|
||||
schema={optionSchema}
|
||||
uiSchema={updatedUiSchema}
|
||||
uiSchema={optionUiSchema}
|
||||
/>
|
||||
)) ||
|
||||
null;
|
||||
|
||||
@@ -17,6 +17,7 @@ interface InputExpanderModalProps {
|
||||
defaultValue: string;
|
||||
description?: string;
|
||||
placeholder?: string;
|
||||
inputType?: "text" | "json";
|
||||
}
|
||||
|
||||
export const InputExpanderModal: FC<InputExpanderModalProps> = ({
|
||||
@@ -27,6 +28,7 @@ export const InputExpanderModal: FC<InputExpanderModalProps> = ({
|
||||
defaultValue,
|
||||
description,
|
||||
placeholder,
|
||||
inputType = "text",
|
||||
}) => {
|
||||
const [tempValue, setTempValue] = useState(defaultValue);
|
||||
const [isCopied, setIsCopied] = useState(false);
|
||||
@@ -78,7 +80,10 @@ export const InputExpanderModal: FC<InputExpanderModalProps> = ({
|
||||
hideLabel
|
||||
id="input-expander-modal"
|
||||
value={tempValue}
|
||||
className="!min-h-[300px] rounded-2xlarge"
|
||||
className={cn(
|
||||
"!min-h-[300px] rounded-2xlarge",
|
||||
inputType === "json" && "font-mono text-sm",
|
||||
)}
|
||||
onChange={(e) => setTempValue(e.target.value)}
|
||||
placeholder={placeholder || "Enter text..."}
|
||||
autoFocus
|
||||
|
||||
@@ -88,6 +88,8 @@ export const CredentialsField = (props: FieldProps) => {
|
||||
showTitle={false}
|
||||
readOnly={formContext?.readOnly}
|
||||
isOptional={!isRequired}
|
||||
className="w-full"
|
||||
variant="node"
|
||||
/>
|
||||
|
||||
{/* Optional credentials toggle - only show in builder canvas, not run dialogs */}
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
"use client";
|
||||
|
||||
import { FieldProps, getTemplate, getUiOptions } from "@rjsf/utils";
|
||||
import { Input } from "@/components/atoms/Input/Input";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/atoms/Tooltip/BaseTooltip";
|
||||
import { ArrowsOutIcon } from "@phosphor-icons/react";
|
||||
import { InputExpanderModal } from "../../base/standard/widgets/TextInput/TextInputExpanderModal";
|
||||
import { getHandleId, updateUiOption } from "../../helpers";
|
||||
import { useJsonTextField } from "./useJsonTextField";
|
||||
import { getPlaceholder } from "./helpers";
|
||||
|
||||
/**
 * RJSF field that edits an arbitrary JSON value as text.
 *
 * Renders the field title, a monospace textarea bound to `useJsonTextField`,
 * a button that opens `InputExpanderModal` for a larger editor, and the schema
 * description when present. While the typed text is not valid JSON the hook
 * reports `hasError`, and an inline hint is shown so the user knows why the
 * value is not being applied.
 */
export const JsonTextField = (props: FieldProps) => {
  const {
    formData,
    onChange,
    schema,
    registry,
    uiSchema,
    required,
    name,
    fieldPathId,
  } = props;

  const uiOptions = getUiOptions(uiSchema);

  const TitleFieldTemplate = getTemplate(
    "TitleFieldTemplate",
    registry,
    uiOptions,
  );

  // Prefer the RJSF path id, then the plain id, then a fixed fallback.
  const fieldId = fieldPathId?.$id ?? props.id ?? "json-field";

  const handleId = getHandleId({
    uiOptions,
    id: fieldId,
    schema: schema,
  });

  // Hand the computed handle id to the title template through the ui schema.
  const updatedUiSchema = updateUiOption(uiSchema, {
    handleId: handleId,
  });

  const {
    textValue,
    isModalOpen,
    hasError,
    handleChange,
    handleModalOpen,
    handleModalClose,
    handleModalSave,
  } = useJsonTextField({
    formData,
    onChange,
    path: fieldPathId?.path,
  });

  const placeholder = getPlaceholder(schema);
  const title = schema.title || name || "JSON Value";

  return (
    <div className="flex flex-col gap-2">
      <TitleFieldTemplate
        id={fieldId}
        title={title}
        required={required}
        schema={schema}
        uiSchema={updatedUiSchema}
        registry={registry}
      />
      <div className="nodrag relative flex items-center gap-2">
        <Input
          id={fieldId}
          hideLabel={true}
          type="textarea"
          label=""
          size="small"
          wrapperClassName="mb-0 flex-1 "
          value={textValue}
          onChange={handleChange}
          placeholder={placeholder}
          required={required}
          disabled={props.disabled}
          className="min-h-[60px] pr-8 font-mono text-xs"
        />

        <Tooltip delayDuration={0}>
          <TooltipTrigger asChild>
            <Button
              variant="ghost"
              size="icon"
              onClick={handleModalOpen}
              type="button"
              className="p-1"
            >
              <ArrowsOutIcon className="size-4" />
            </Button>
          </TooltipTrigger>
          <TooltipContent>Expand input</TooltipContent>
        </Tooltip>
      </div>
      {/* The hook only forwards parseable text, so tell the user when it is not. */}
      {hasError && (
        <span role="alert" className="text-xs text-red-500">
          Invalid JSON
        </span>
      )}
      {schema.description && (
        <span className="text-xs text-gray-500">{schema.description}</span>
      )}

      <InputExpanderModal
        isOpen={isModalOpen}
        onClose={handleModalClose}
        onSave={handleModalSave}
        title={`Edit ${title}`}
        description={schema.description || "Enter valid JSON"}
        defaultValue={textValue}
        placeholder={placeholder}
        inputType="json"
      />
    </div>
  );
};
|
||||
|
||||
export default JsonTextField;
|
||||
@@ -0,0 +1,67 @@
|
||||
import { RJSFSchema } from "@rjsf/utils";
|
||||
|
||||
/**
 * Converts form data to a pretty-printed (2-space) JSON string for display.
 *
 * @param formData - The data to stringify
 * @returns The JSON text, or an empty string when the data is null/undefined,
 *   has no JSON representation, or cannot be serialized
 */
export function stringifyFormData(formData: unknown): string {
  if (formData === undefined || formData === null) {
    return "";
  }
  try {
    // JSON.stringify returns `undefined` rather than a string for values with
    // no JSON representation (functions, symbols); keep the return a string.
    return JSON.stringify(formData, null, 2) ?? "";
  } catch {
    // Serialization threw (e.g. a circular structure).
    return "";
  }
}
|
||||
|
||||
/**
 * Parses JSON text after trimming surrounding whitespace.
 *
 * @param value - The JSON string to parse
 * @returns The parsed value, or undefined when the trimmed text is empty or
 *   is not valid JSON
 */
export function parseJsonValue(value: string): unknown | undefined {
  const text = value.trim();
  let parsed: unknown = undefined;

  if (text !== "") {
    try {
      parsed = JSON.parse(text);
    } catch {
      parsed = undefined;
    }
  }

  return parsed;
}
|
||||
|
||||
/**
 * Picks example placeholder text for a JSON input from the schema's `type`.
 *
 * @param schema - The JSON schema
 * @returns An array example for "array", an object example for "object",
 *   and a generic prompt for anything else
 */
export function getPlaceholder(schema: RJSFSchema): string {
  switch (schema.type) {
    case "array":
      return '["item1", "item2"] or [{"key": "value"}]';
    case "object":
      return '{"key": "value"}';
    default:
      return "Enter JSON value...";
  }
}
|
||||
|
||||
/**
 * Checks whether a string is valid JSON once surrounding whitespace is removed.
 *
 * The text is trimmed before parsing so the verdict matches what
 * `parseJsonValue` accepts: `String.prototype.trim` strips characters (such
 * as a non-breaking space) that `JSON.parse` itself rejects as whitespace.
 *
 * @param value - The JSON string to validate
 * @returns true if the trimmed text is empty or valid JSON, false otherwise
 */
export function isValidJson(value: string): boolean {
  const trimmed = value.trim();
  if (trimmed === "") {
    return true; // Empty is considered valid (will be undefined)
  }
  try {
    JSON.parse(trimmed);
    return true;
  } catch {
    return false;
  }
}
|
||||
@@ -0,0 +1,107 @@
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
import { FieldProps } from "@rjsf/utils";
|
||||
import { stringifyFormData, parseJsonValue, isValidJson } from "./helpers";
|
||||
|
||||
// Aliases taken from RJSF's FieldProps so the hook's types follow the library's.
type FieldOnChange = FieldProps["onChange"];
type FieldPathId = FieldProps["fieldPathId"];

/** Inputs accepted by useJsonTextField. */
interface UseJsonTextFieldOptions {
  /** Current field value; it is serialized to produce the displayed text. */
  formData: unknown;
  /** Change callback, typed as FieldProps["onChange"]. */
  onChange: FieldOnChange;
  /** Path forwarded to onChange; the hook passes [] when this is omitted. */
  path?: FieldPathId["path"];
}
|
||||
|
||||
/** State and handlers returned by useJsonTextField. */
interface UseJsonTextFieldReturn {
  /** Text currently held for the input. */
  textValue: string;
  /** Whether the modal is open. */
  isModalOpen: boolean;
  /** True when the most recently entered text is not valid JSON. */
  hasError: boolean;
  /** Change handler for an input or textarea element. */
  handleChange: (
    e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement>,
  ) => void;
  /** Opens the modal. */
  handleModalOpen: () => void;
  /** Closes the modal. */
  handleModalClose: () => void;
  /** Closes the modal and applies the given text. */
  handleModalSave: (value: string) => void;
}
|
||||
|
||||
/**
 * State and event handlers backing a JSON text field.
 *
 * Keeps an editable text copy of `formData`, tracks whether that text is valid
 * JSON, and controls the open state of a modal.
 */
export function useJsonTextField({
  formData,
  onChange,
  path,
}: UseJsonTextFieldOptions): UseJsonTextFieldReturn {
  const [textValue, setTextValue] = useState(() => stringifyFormData(formData));
  const [isModalOpen, setIsModalOpen] = useState(false);
  const [hasError, setHasError] = useState(false);

  // Re-serialize whenever formData changes and clear any stale error flag.
  useEffect(() => {
    setTextValue(stringifyFormData(formData));
    setHasError(false);
  }, [formData]);

  // Shared by typing and modal save: store the text, flag invalid JSON, and
  // forward the parsed value (or undefined for blank text) through onChange.
  const commitText = useCallback(
    (text: string) => {
      setTextValue(text);
      setHasError(!isValidJson(text));

      const target = path ?? [];
      if (text.trim() === "") {
        onChange(undefined, target);
        return;
      }

      const parsed = parseJsonValue(text);
      // Text that does not parse leaves formData untouched.
      if (parsed !== undefined) {
        onChange(parsed, target);
      }
    },
    [onChange, path],
  );

  const handleChange = useCallback(
    (e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement>) => {
      commitText(e.target.value);
    },
    [commitText],
  );

  const handleModalOpen = useCallback(() => setIsModalOpen(true), []);

  const handleModalClose = useCallback(() => setIsModalOpen(false), []);

  const handleModalSave = useCallback(
    (value: string) => {
      setIsModalOpen(false);
      commitText(value);
    },
    [commitText],
  );

  return {
    textValue,
    isModalOpen,
    hasError,
    handleChange,
    handleModalOpen,
    handleModalClose,
    handleModalSave,
  };
}
|
||||
@@ -0,0 +1,57 @@
|
||||
import React from "react";
|
||||
import { FieldProps, getUiOptions } from "@rjsf/utils";
|
||||
import { BlockIOObjectSubSchema } from "@/lib/autogpt-server-api/types";
|
||||
import {
|
||||
MultiSelector,
|
||||
MultiSelectorContent,
|
||||
MultiSelectorInput,
|
||||
MultiSelectorItem,
|
||||
MultiSelectorList,
|
||||
MultiSelectorTrigger,
|
||||
} from "@/components/__legacy__/ui/multiselect";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useMultiSelectField } from "./useMultiSelectField";
|
||||
|
||||
/**
 * RJSF field that renders the options derived by useMultiSelectField as a
 * multi-select dropdown.
 */
export const MultiSelectField = (props: FieldProps) => {
  const { schema, formData, onChange, fieldPathId } = props;
  const uiOptions = getUiOptions(props.uiSchema);

  const multiSelect = useMultiSelectField({
    schema: schema as BlockIOObjectSubSchema,
    formData,
  });
  const { optionSchema, options, selection } = multiSelect;
  const handleValuesChange = multiSelect.createChangeHandler(
    onChange,
    fieldPathId,
  );

  const displayName = schema.title || "options";
  // A `placeholder` on the schema wins over the generated prompt.
  const placeholder = (schema as any).placeholder ?? `Select ${displayName}...`;

  // One item per option key; the option's title is the label (falling back to
  // the key) and its description becomes the item's title attribute.
  const items = options.map((key) => {
    const { title, description } = { ...optionSchema[key] };
    return (
      <MultiSelectorItem key={key} value={key} title={description}>
        {title ?? key}
      </MultiSelectorItem>
    );
  });

  return (
    <div className={cn("flex flex-col", uiOptions.className)}>
      <MultiSelector
        className="nodrag"
        values={selection}
        onValuesChange={handleValuesChange}
      >
        <MultiSelectorTrigger className="rounded-3xl border border-zinc-200 bg-white px-2 shadow-none">
          <MultiSelectorInput placeholder={placeholder} />
        </MultiSelectorTrigger>
        <MultiSelectorContent className="nowheel">
          <MultiSelectorList>{items}</MultiSelectorList>
        </MultiSelectorContent>
      </MultiSelector>
    </div>
  );
};
|
||||
@@ -0,0 +1 @@
|
||||
export { MultiSelectField } from "./MultiSelectField";
|
||||
@@ -0,0 +1,65 @@
|
||||
import { FieldProps } from "@rjsf/utils";
|
||||
import { BlockIOObjectSubSchema } from "@/lib/autogpt-server-api/types";
|
||||
|
||||
// Value shape handled by the multi-select: option key -> boolean flag, or nothing.
type FormData = Record<string, boolean> | null | undefined;

/** Inputs accepted by useMultiSelectField. */
interface UseMultiSelectFieldOptions {
  /** Schema whose `properties` (or first anyOf entry's `properties`) list the options. */
  schema: BlockIOObjectSubSchema;
  /** Current value; keys mapped to exactly `true` count as selected. */
  formData: FormData;
}
|
||||
|
||||
/**
 * Derives the option list and current selection for a multi-select field.
 *
 * @returns `optionSchema` (option key -> sub-schema), `options` (the keys),
 *          `selection` (keys whose form value is `true`), and
 *          `createChangeHandler`, which builds a callback that reports a
 *          key -> boolean object for a list of selected keys.
 */
export function useMultiSelectField({
  schema,
  formData,
}: UseMultiSelectFieldOptions) {
  // Options come from the schema's own properties, otherwise from the
  // properties of the first anyOf entry, otherwise there are none.
  let optionSchema: Record<string, BlockIOObjectSubSchema> = {};
  if (schema.properties) {
    optionSchema = schema.properties as Record<string, BlockIOObjectSubSchema>;
  } else if (
    "anyOf" in schema &&
    Array.isArray(schema.anyOf) &&
    schema.anyOf.length > 0 &&
    "properties" in schema.anyOf[0]
  ) {
    optionSchema = (schema.anyOf[0] as BlockIOObjectSubSchema)
      .properties as Record<string, BlockIOObjectSubSchema>;
  }
  const options = Object.keys(optionSchema);

  // Only values that are exactly `true` count as selected.
  const selection: string[] = [];
  if (formData && typeof formData === "object") {
    for (const [key, flag] of Object.entries(formData)) {
      if (flag === true) {
        selection.push(key);
      }
    }
  }

  const createChangeHandler =
    (
      onChange: FieldProps["onChange"],
      fieldPathId: FieldProps["fieldPathId"],
    ) =>
    (values: string[]) => {
      // Every known option is reported, flagged by whether it was picked.
      const entries = options.map((opt) => [opt, values.includes(opt)]);
      onChange(Object.fromEntries(entries), fieldPathId?.path);
    };

  return { optionSchema, options, selection, createChangeHandler };
}
|
||||
@@ -0,0 +1,52 @@
|
||||
import { descriptionId, FieldProps, getTemplate, titleId } from "@rjsf/utils";
|
||||
import { Table, RowData } from "@/components/molecules/Table/Table";
|
||||
import { useMemo } from "react";
|
||||
|
||||
/**
 * RJSF field that renders an array schema as an editable table.
 *
 * Column names are the property keys of the schema's `items`; the title and
 * description are rendered with the registry's templates.
 */
export const TableField = (props: FieldProps) => {
  const { schema, formData, onChange, fieldPathId, registry, uiSchema } = props;

  // Use the raw (possibly undefined) reference as the memo dependency.
  // Defaulting to `{}` before memoizing would create a new object on every
  // render whenever `properties` is absent, so the memo would never hit.
  const itemProperties = (schema.items as any)?.properties;
  const columns: string[] = useMemo(
    () => Object.keys(itemProperties || {}),
    [itemProperties],
  );

  const handleChange = (rows: RowData[]) => {
    // NOTE(review): the last path segment is dropped before reporting the
    // change - confirm this is the intended target path.
    onChange(rows, fieldPathId?.path.slice(0, -1));
  };

  const TitleFieldTemplate = getTemplate("TitleFieldTemplate", registry);
  const DescriptionFieldTemplate = getTemplate(
    "DescriptionFieldTemplate",
    registry,
  );

  return (
    <div className="flex flex-col gap-2">
      <TitleFieldTemplate
        id={titleId(fieldPathId)}
        title={schema.title || ""}
        required={true}
        schema={schema}
        uiSchema={uiSchema}
        registry={registry}
      />
      <DescriptionFieldTemplate
        id={descriptionId(fieldPathId)}
        description={schema.description || ""}
        schema={schema}
        registry={registry}
      />

      <Table
        columns={columns}
        defaultValues={formData}
        onChange={handleChange}
        allowAddRow={true}
        allowDeleteRow={true}
        addRowLabel="Add row"
      />
    </div>
  );
};
|
||||
@@ -1,6 +1,10 @@
|
||||
import { FieldProps, RJSFSchema, RegistryFieldsType } from "@rjsf/utils";
|
||||
import { CredentialsField } from "./CredentialField/CredentialField";
|
||||
import { GoogleDrivePickerField } from "./GoogleDrivePickerField/GoogleDrivePickerField";
|
||||
import { JsonTextField } from "./JsonTextField/JsonTextField";
|
||||
import { MultiSelectField } from "./MultiSelectField/MultiSelectField";
|
||||
import { isMultiSelectSchema } from "../utils/schema-utils";
|
||||
import { TableField } from "./TableField/TableField";
|
||||
|
||||
export interface CustomFieldDefinition {
|
||||
id: string;
|
||||
@@ -8,6 +12,9 @@ export interface CustomFieldDefinition {
|
||||
component: (props: FieldProps<any, RJSFSchema, any>) => JSX.Element | null;
|
||||
}
|
||||
|
||||
/** Field ID for JsonTextField - used to render nested complex types as text input */
|
||||
export const JSON_TEXT_FIELD_ID = "custom/json_text_field";
|
||||
|
||||
export const CUSTOM_FIELDS: CustomFieldDefinition[] = [
|
||||
{
|
||||
id: "custom/credential_field",
|
||||
@@ -30,6 +37,28 @@ export const CUSTOM_FIELDS: CustomFieldDefinition[] = [
|
||||
},
|
||||
component: GoogleDrivePickerField,
|
||||
},
|
||||
{
|
||||
id: "custom/json_text_field",
|
||||
// Not matched by schema - assigned via uiSchema for nested complex types
|
||||
matcher: () => false,
|
||||
component: JsonTextField,
|
||||
},
|
||||
{
|
||||
id: "custom/multi_select_field",
|
||||
matcher: isMultiSelectSchema,
|
||||
component: MultiSelectField,
|
||||
},
|
||||
{
|
||||
id: "custom/table_field",
|
||||
matcher: (schema: any) => {
|
||||
return (
|
||||
schema.type === "array" &&
|
||||
"format" in schema &&
|
||||
schema.format === "table"
|
||||
);
|
||||
},
|
||||
component: TableField,
|
||||
},
|
||||
];
|
||||
|
||||
export function findCustomFieldId(schema: any): string | null {
|
||||
|
||||
@@ -1,19 +1,46 @@
|
||||
import { RJSFSchema, UiSchema } from "@rjsf/utils";
|
||||
import { findCustomFieldId } from "../custom/custom-registry";
|
||||
import {
|
||||
findCustomFieldId,
|
||||
JSON_TEXT_FIELD_ID,
|
||||
} from "../custom/custom-registry";
|
||||
|
||||
/** Reports whether the schema's `type` is exactly "object" or "array". */
function isComplexType(schema: RJSFSchema): boolean {
  switch (schema.type) {
    case "object":
    case "array":
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
/**
 * Reports whether any entry of the schema's `anyOf` (or, when `anyOf` is
 * falsy, `oneOf`) is an object whose `type` is "object" or "array".
 */
function hasComplexAnyOfOptions(schema: RJSFSchema): boolean {
  const candidates = schema.anyOf || schema.oneOf;
  if (!Array.isArray(candidates)) {
    return false;
  }
  for (const candidate of candidates as any[]) {
    // Skip entries that are not schema objects (e.g. boolean schemas, null).
    if (!candidate || typeof candidate !== "object") {
      continue;
    }
    if (candidate.type === "object" || candidate.type === "array") {
      return true;
    }
  }
  return false;
}
|
||||
|
||||
/**
|
||||
* Generates uiSchema with ui:field settings for custom fields based on schema matchers.
|
||||
* This is the standard RJSF way to route fields to custom components.
|
||||
*
|
||||
* Nested complex types (arrays/objects inside arrays/objects) are rendered as JsonTextField
|
||||
* to avoid deeply nested form UIs. Users can enter raw JSON for these fields.
|
||||
*
|
||||
* @param schema - The JSON schema
|
||||
* @param existingUiSchema - Existing uiSchema to merge with
|
||||
* @param insideComplexType - Whether we're already inside a complex type (object/array)
|
||||
*/
|
||||
export function generateUiSchemaForCustomFields(
|
||||
schema: RJSFSchema,
|
||||
existingUiSchema: UiSchema = {},
|
||||
insideComplexType: boolean = false,
|
||||
): UiSchema {
|
||||
const uiSchema: UiSchema = { ...existingUiSchema };
|
||||
|
||||
if (schema.properties) {
|
||||
for (const [key, propSchema] of Object.entries(schema.properties)) {
|
||||
if (propSchema && typeof propSchema === "object") {
|
||||
// First check for custom field matchers (credentials, google drive, etc.)
|
||||
const customFieldId = findCustomFieldId(propSchema);
|
||||
|
||||
if (customFieldId) {
|
||||
@@ -21,8 +48,33 @@ export function generateUiSchemaForCustomFields(
|
||||
...(uiSchema[key] as object),
|
||||
"ui:field": customFieldId,
|
||||
};
|
||||
// Skip further processing for custom fields
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle nested complex types - render as JsonTextField
|
||||
if (insideComplexType && isComplexType(propSchema as RJSFSchema)) {
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
"ui:field": JSON_TEXT_FIELD_ID,
|
||||
};
|
||||
// Don't recurse further - this field is now a text input
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle anyOf/oneOf inside complex types
|
||||
if (
|
||||
insideComplexType &&
|
||||
hasComplexAnyOfOptions(propSchema as RJSFSchema)
|
||||
) {
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
"ui:field": JSON_TEXT_FIELD_ID,
|
||||
};
|
||||
continue;
|
||||
}
|
||||
|
||||
// Recurse into object properties
|
||||
if (
|
||||
propSchema.type === "object" &&
|
||||
propSchema.properties &&
|
||||
@@ -31,6 +83,7 @@ export function generateUiSchemaForCustomFields(
|
||||
const nestedUiSchema = generateUiSchemaForCustomFields(
|
||||
propSchema as RJSFSchema,
|
||||
(uiSchema[key] as UiSchema) || {},
|
||||
true, // Now inside a complex type
|
||||
);
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
@@ -38,9 +91,11 @@ export function generateUiSchemaForCustomFields(
|
||||
};
|
||||
}
|
||||
|
||||
// Handle arrays
|
||||
if (propSchema.type === "array" && propSchema.items) {
|
||||
const itemsSchema = propSchema.items as RJSFSchema;
|
||||
if (itemsSchema && typeof itemsSchema === "object") {
|
||||
// Check for custom field on array items
|
||||
const itemsCustomFieldId = findCustomFieldId(itemsSchema);
|
||||
if (itemsCustomFieldId) {
|
||||
uiSchema[key] = {
|
||||
@@ -49,10 +104,28 @@ export function generateUiSchemaForCustomFields(
|
||||
"ui:field": itemsCustomFieldId,
|
||||
},
|
||||
};
|
||||
} else if (isComplexType(itemsSchema)) {
|
||||
// Array items that are complex types become JsonTextField
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
items: {
|
||||
"ui:field": JSON_TEXT_FIELD_ID,
|
||||
},
|
||||
};
|
||||
} else if (hasComplexAnyOfOptions(itemsSchema)) {
|
||||
// Array items with anyOf containing complex types become JsonTextField
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
items: {
|
||||
"ui:field": JSON_TEXT_FIELD_ID,
|
||||
},
|
||||
};
|
||||
} else if (itemsSchema.properties) {
|
||||
// Recurse into object items (but they're now inside a complex type)
|
||||
const itemsUiSchema = generateUiSchemaForCustomFields(
|
||||
itemsSchema,
|
||||
((uiSchema[key] as UiSchema)?.items as UiSchema) || {},
|
||||
true, // Inside complex type (array)
|
||||
);
|
||||
if (Object.keys(itemsUiSchema).length > 0) {
|
||||
uiSchema[key] = {
|
||||
@@ -63,6 +136,61 @@ export function generateUiSchemaForCustomFields(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle anyOf/oneOf at root level - process complex options
|
||||
if (!insideComplexType) {
|
||||
const anyOfOptions = propSchema.anyOf || propSchema.oneOf;
|
||||
|
||||
if (Array.isArray(anyOfOptions)) {
|
||||
for (let i = 0; i < anyOfOptions.length; i++) {
|
||||
const option = anyOfOptions[i] as RJSFSchema;
|
||||
if (option && typeof option === "object") {
|
||||
// Handle anyOf array options with complex items
|
||||
if (option.type === "array" && option.items) {
|
||||
const itemsSchema = option.items as RJSFSchema;
|
||||
if (itemsSchema && typeof itemsSchema === "object") {
|
||||
// Array items that are complex types become JsonTextField
|
||||
if (isComplexType(itemsSchema)) {
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
items: {
|
||||
"ui:field": JSON_TEXT_FIELD_ID,
|
||||
},
|
||||
};
|
||||
} else if (hasComplexAnyOfOptions(itemsSchema)) {
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
items: {
|
||||
"ui:field": JSON_TEXT_FIELD_ID,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recurse into anyOf object options with properties
|
||||
if (
|
||||
option.type === "object" &&
|
||||
option.properties &&
|
||||
typeof option.properties === "object"
|
||||
) {
|
||||
const optionUiSchema = generateUiSchemaForCustomFields(
|
||||
option,
|
||||
{},
|
||||
true, // Inside complex type (anyOf object option)
|
||||
);
|
||||
if (Object.keys(optionUiSchema).length > 0) {
|
||||
// Store under the property key - RJSF will apply it
|
||||
uiSchema[key] = {
|
||||
...(uiSchema[key] as object),
|
||||
...optionUiSchema,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import { getUiOptions, RJSFSchema, UiSchema } from "@rjsf/utils";
|
||||
|
||||
/**
 * Reports whether the schema should be treated as an anyOf schema: it has a
 * non-empty `anyOf` array and does not define `enum`.
 */
export function isAnyOfSchema(schema: RJSFSchema | undefined): boolean {
  // The earlier unconditional `return` that preceded this check made the
  // `enum` exclusion unreachable; only the combined condition remains.
  const options = schema?.anyOf;
  return (
    Array.isArray(options) &&
    options.length > 0 &&
    schema?.enum === undefined
  );
}
|
||||
|
||||
export const isAnyOfChild = (
|
||||
@@ -33,3 +37,21 @@ export function isOptionalType(schema: RJSFSchema | undefined): {
|
||||
/** Reports whether the given name contains the substring "anyof_select". */
export function isAnyOfSelector(name: string) {
  const marker = "anyof_select";
  return name.indexOf(marker) !== -1;
}
|
||||
|
||||
/**
 * Reports whether a schema describes a multi-select: an object schema without
 * `anyOf`/`oneOf` whose properties are all of type "boolean".
 *
 * A schema with no properties at all is not a multi-select, because there
 * would be nothing to choose from.
 */
export function isMultiSelectSchema(schema: RJSFSchema | undefined): boolean {
  if (typeof schema !== "object" || schema === null) {
    return false;
  }

  if ("anyOf" in schema || "oneOf" in schema) {
    return false;
  }

  if (schema.type !== "object" || !schema.properties) {
    return false;
  }

  const props = Object.values(schema.properties);
  // `every` is vacuously true for an empty list, so require at least one
  // property. Entries that are not schema objects (null, boolean schemas)
  // count as non-boolean and are rejected without being dereferenced.
  return (
    props.length > 0 &&
    props.every(
      (prop: any) =>
        prop !== null && typeof prop === "object" && prop.type === "boolean",
    )
  );
}
|
||||
|
||||
Reference in New Issue
Block a user